load() — pytorch Function Reference

Architecture documentation for the load() method in benchmarks/dynamo/common.py from the pytorch codebase.

Dependency Diagram

graph TD
  3df8de63_d71a_714c_93a1_1a0e24c0f362["load()"]
  d4bae6c8_efa9_4b5e_53ab_4e403ff2693e["load()"]
  d4bae6c8_efa9_4b5e_53ab_4e403ff2693e -->|calls| 3df8de63_d71a_714c_93a1_1a0e24c0f362
  6390669e_1575_3d66_5d4e_f02d7218bd4b["load()"]
  6390669e_1575_3d66_5d4e_f02d7218bd4b -->|calls| 3df8de63_d71a_714c_93a1_1a0e24c0f362
  529640b9_a20a_d7f3_29f8_900581eb0f3d["export_nativert()"]
  529640b9_a20a_d7f3_29f8_900581eb0f3d -->|calls| 3df8de63_d71a_714c_93a1_1a0e24c0f362
  d89eb84e_a6ed_8b3f_56e7_7c49cca94d00["export_aot_inductor()"]
  d89eb84e_a6ed_8b3f_56e7_7c49cca94d00 -->|calls| 3df8de63_d71a_714c_93a1_1a0e24c0f362
  9b71719a_7134_fa38_fd95_b665398662db["torchscript_jit_trace()"]
  9b71719a_7134_fa38_fd95_b665398662db -->|calls| 3df8de63_d71a_714c_93a1_1a0e24c0f362
  79f63331_206c_51dc_bfd1_4fd84b939754["check_accuracy()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| 3df8de63_d71a_714c_93a1_1a0e24c0f362
  1eff8423_1f23_d138_6815_07d8dc29a749["_normalize_bench_inputs()"]
  3df8de63_d71a_714c_93a1_1a0e24c0f362 -->|calls| 1eff8423_1f23_d138_6815_07d8dc29a749
  8964c48d_7b16_2455_878e_dbe471c4a037["_register_dataclass_output_as_pytree()"]
  3df8de63_d71a_714c_93a1_1a0e24c0f362 -->|calls| 8964c48d_7b16_2455_878e_dbe471c4a037
  06ff896d_4db0_aa47_b3cd_be1da620f0ea["empty_gpu_cache()"]
  3df8de63_d71a_714c_93a1_1a0e24c0f362 -->|calls| 06ff896d_4db0_aa47_b3cd_be1da620f0ea
  f258a26f_4fc8_d6ab_1167_c608fca925fb["export()"]
  3df8de63_d71a_714c_93a1_1a0e24c0f362 -->|calls| f258a26f_4fc8_d6ab_1167_c608fca925fb
  3df8de63_d71a_714c_93a1_1a0e24c0f362 -->|calls| d4bae6c8_efa9_4b5e_53ab_4e403ff2693e
  style 3df8de63_d71a_714c_93a1_1a0e24c0f362 fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

benchmarks/dynamo/common.py lines 1349–1430

    def load(cls, model, example_inputs, mode):
        import torch._inductor
        from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path

        key = weakref.ref(model)
        if key not in cls.cache:
            # Register the output dataclass to pytree
            example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
            with torch.no_grad():
                # copy.deepcopy is required to prevent any surprising side-effect,
                # see https://github.com/pytorch/pytorch/issues/113029
                # This will cause memory stats to be overshadowed by this eager run.
                # To fix that, memory stats will be reset later.
                example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)

            if pytree.is_namedtuple_instance(example_outputs):
                typ = type(example_outputs)
                pytree._register_namedtuple(
                    typ,
                    serialized_type_name=f"{typ.__module__}.{typ.__name__}",
                )
            else:
                _register_dataclass_output_as_pytree(example_outputs)

            combined_args = _combine_args(model, example_args, example_kwargs)
            dynamic_shapes = _tree_map_with_path(
                _produce_dynamic_shapes_for_export, combined_args
            )

            # delete example_outputs and reset memory stats here
            del example_outputs
            if current_device == "cuda":
                empty_gpu_cache(current_device)
                torch.cuda.reset_peak_memory_stats()
                pre_clone_memory_used = torch.cuda.max_memory_allocated()
            elif current_device == "hpu":
                torch.hpu.reset_peak_memory_stats()
                pre_clone_memory_used = torch.hpu.max_memory_allocated()

            # Clone the model pre-exporting.  This prevents scenarios observed in a few
            # models, where the forward pass modifies model state while exporting, and
            # FakeTensors are thus saved as model data members.  This invalidates model
            # reuse in eager mode, so it's safest to export a model clone.
            model_clone = copy.deepcopy(model)

            # Since CPU doesn't monitor max memory allocation, anything measuring peak
            # memory will miss our transient model clone on CPU anyway.
            #
            # The justification for tracking this value (in order to remove it from the
            # AOTInductor memory measurements) is that normal usage of AOTInductor would
            # not clone the model, since the eager model would be unused post-export.
            clone_memory_used = 0.0
            if current_device == "cuda":
                clone_memory_used = (
                    torch.cuda.max_memory_allocated() - pre_clone_memory_used
                ) / 1e9
            elif current_device == "hpu":
                clone_memory_used = (
                    torch.hpu.max_memory_allocated() - pre_clone_memory_used
                ) / 1e9

            inductor_configs = {}
            if mode == "max-autotune":
                inductor_configs["max_autotune"] = True
            ep = torch.export.export(
                model_clone,
                example_args,
                example_kwargs,
                dynamic_shapes=dynamic_shapes,
                strict=False,
            )
            with torch.no_grad():
                package_path = torch._inductor.aoti_compile_and_package(
                    ep, inductor_configs=inductor_configs
                )  # type: ignore[arg-type]

            cls.cache[key] = (
                torch._inductor.aoti_load_package(package_path),
                clone_memory_used,
            )

        return cls.cache[key][0]
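
The peak-memory bookkeeping in load() is easiest to follow with a toy allocator. The sketch below is hypothetical: `ToyAllocator` and `clone_memory_gb` are not PyTorch APIs. It assumes that resetting peak stats restarts the peak from the bytes currently allocated, which is how `torch.cuda.reset_peak_memory_stats()` and `torch.cuda.max_memory_allocated()` behave together. It replays the order used in load(): eager run on a temporary deepcopy, `del example_outputs`, reset, record `pre_clone_memory_used`, clone, subtract.

```python
# Hypothetical toy model of a device allocator's peak statistics; not a PyTorch API.
class ToyAllocator:
    def __init__(self) -> None:
        self.allocated = 0  # bytes currently allocated
        self.peak = 0  # highest `allocated` value since the last reset

    def alloc(self, nbytes: int) -> None:
        self.allocated += nbytes
        self.peak = max(self.peak, self.allocated)

    def free(self, nbytes: int) -> None:
        self.allocated -= nbytes

    def reset_peak_memory_stats(self) -> None:
        # Assumed to mirror torch.cuda.reset_peak_memory_stats(): the peak
        # restarts from the bytes that are allocated right now.
        self.peak = self.allocated

    def max_memory_allocated(self) -> int:
        return self.peak


def clone_memory_gb(
    dev: ToyAllocator, weights: int, eager_outputs: int, reset: bool = True
) -> float:
    """Replay load()'s bookkeeping order and return the clone's cost in GB."""
    dev.alloc(weights)  # the model's parameters already live on the device
    dev.alloc(weights)  # temporary copy.deepcopy(model) used for the eager run
    dev.alloc(eager_outputs)  # example_outputs
    dev.free(weights)  # the temporary copy is dropped after the call
    dev.free(eager_outputs)  # del example_outputs
    if reset:
        dev.reset_peak_memory_stats()
    pre_clone_memory_used = dev.max_memory_allocated()
    dev.alloc(weights)  # model_clone = copy.deepcopy(model)
    return (dev.max_memory_allocated() - pre_clone_memory_used) / 1e9
```

With 4 GB of weights and 1 GB of eager outputs, the reset path reports 4.0 GB for the clone. Skipping the reset reports 0.0, because the earlier eager-run peak (9 GB in this toy) hides the clone's 8 GB high-water mark; that masking is what the reset in load() is meant to prevent.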

Frequently Asked Questions

What does load() do?
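load() is a class method (its first parameter is cls) in the pytorch benchmark harness. The first time it sees a given model, it exports a deep copy of the model with torch.export.export (strict=False, with dynamic shapes), compiles the exported program into an AOTInductor package with torch._inductor.aoti_compile_and_package(), loads that package with torch._inductor.aoti_load_package(), and caches the result keyed by weakref.ref(model). Later calls with the same model return the cached compiled model. On CUDA and HPU it also records how much device memory the pre-export clone used, so that amount can be excluded from AOTInductor memory measurements.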
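
The memoization in load() (the `weakref.ref(model)` key, the `key not in cls.cache` check, and the two-element tuple whose first slot is returned) can be sketched without PyTorch. The names `CompiledModelCache`, `compile_fn`, `compile_calls`, and `Model` are hypothetical. `compile_fn` stands in for the export, `aoti_compile_and_package`, and `aoti_load_package` steps, and the sketch assumes `cls.cache` is an ordinary dict, since its definition is not part of the excerpt.

```python
import weakref
from typing import Any, Callable


class CompiledModelCache:
    """Hypothetical stand-in for the class that owns load()."""

    # Assumed to be an ordinary dict keyed by weakref.ref(model): it does not
    # keep the model alive, but entries are not evicted when the model dies.
    cache: dict[weakref.ref, tuple[Any, float]] = {}
    compile_calls = 0  # hypothetical counter, only for demonstration

    @classmethod
    def load(cls, model: Any, compile_fn: Callable[[Any], Any]) -> Any:
        # Live weak references to the same object hash and compare equal.
        key = weakref.ref(model)
        if key not in cls.cache:
            cls.compile_calls += 1
            # The second slot mirrors clone_memory_used in the original.
            cls.cache[key] = (compile_fn(model), 0.0)
        return cls.cache[key][0]


class Model:  # plain Python classes support weak references by default
    pass
```

Calling `CompiledModelCache.load(m, fn)` twice with the same `m` runs `fn` once and returns the same object both times, while a different model gets its own entry. Keying on a weak reference rather than the model itself avoids pinning the eager model in memory; in this sketch the trade-off is that entries for collected models remain in the dict until it is cleared.
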
What does load() call?
According to the dependency graph, load() calls five functions: _normalize_bench_inputs, _register_dataclass_output_as_pytree, empty_gpu_cache, export, and another function named load().
What calls load()?
According to the dependency graph, load() is called by six functions: check_accuracy, export_aot_inductor, export_nativert, torchscript_jit_trace, and two other functions named load().
