load() — PyTorch Function Reference
Architecture documentation for the load() function in benchmarks/dynamo/common.py from the PyTorch codebase.
Entity Profile
Dependency Diagram
graph TD
    load_target["load()"]
    load_a["load()"] -->|calls| load_target
    load_b["load()"] -->|calls| load_target
    export_nativert["export_nativert()"] -->|calls| load_target
    export_aot_inductor["export_aot_inductor()"] -->|calls| load_target
    torchscript_jit_trace["torchscript_jit_trace()"] -->|calls| load_target
    check_accuracy["check_accuracy()"] -->|calls| load_target
    load_target -->|calls| normalize_inputs["_normalize_bench_inputs()"]
    load_target -->|calls| register_pytree["_register_dataclass_output_as_pytree()"]
    load_target -->|calls| empty_cache["empty_gpu_cache()"]
    load_target -->|calls| export_fn["export()"]
    load_target -->|calls| load_a
    style load_target fill:#6366f1,stroke:#818cf8,color:#fff
Source Code
benchmarks/dynamo/common.py lines 1349–1430
def load(cls, model, example_inputs, mode):
    import torch._inductor
    from torch.export.dynamic_shapes import _combine_args, _tree_map_with_path

    key = weakref.ref(model)
    if key not in cls.cache:
        # Register the output dataclass to pytree
        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
        with torch.no_grad():
            # copy.deepcopy is required to prevent any surprising side-effect,
            # see https://github.com/pytorch/pytorch/issues/113029
            # This will cause memory stats to be overshadowed by this eager run.
            # To fix that, memory stats will be reset later.
            example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)

        if pytree.is_namedtuple_instance(example_outputs):
            typ = type(example_outputs)
            pytree._register_namedtuple(
                typ,
                serialized_type_name=f"{typ.__module__}.{typ.__name__}",
            )
        else:
            _register_dataclass_output_as_pytree(example_outputs)

        combined_args = _combine_args(model, example_args, example_kwargs)
        dynamic_shapes = _tree_map_with_path(
            _produce_dynamic_shapes_for_export, combined_args
        )

        # delete example_outputs and reset memory stats here
        del example_outputs
        if current_device == "cuda":
            empty_gpu_cache(current_device)
            torch.cuda.reset_peak_memory_stats()
            pre_clone_memory_used = torch.cuda.max_memory_allocated()
        elif current_device == "hpu":
            torch.hpu.reset_peak_memory_stats()
            pre_clone_memory_used = torch.hpu.max_memory_allocated()

        # Clone the model pre-exporting. This prevents scenarios observed in a few
        # models, where the forward pass modifies model state while exporting, and
        # FakeTensors are thus saved as model data members. This invalidates model
        # reuse in eager mode, so it's safest to export a model clone.
        model_clone = copy.deepcopy(model)

        # Since CPU doesn't monitor max memory allocation, anything measuring peak
        # memory will miss our transient model clone on CPU anyway.
        #
        # The justification for tracking this value (in order to remove it from the
        # AOTInductor memory measurements) is that normal usage of AOTInductor would
        # not clone the model, since the eager model would be unused post-export.
        clone_memory_used = 0.0
        if current_device == "cuda":
            clone_memory_used = (
                torch.cuda.max_memory_allocated() - pre_clone_memory_used
            ) / 1e9
        elif current_device == "hpu":
            clone_memory_used = (
                torch.hpu.max_memory_allocated() - pre_clone_memory_used
            ) / 1e9

        inductor_configs = {}
        if mode == "max-autotune":
            inductor_configs["max_autotune"] = True
        ep = torch.export.export(
            model_clone,
            example_args,
            example_kwargs,
            dynamic_shapes=dynamic_shapes,
            strict=False,
        )
        with torch.no_grad():
            package_path = torch._inductor.aoti_compile_and_package(
                ep, inductor_configs=inductor_configs
            )  # type: ignore[arg-type]

        cls.cache[key] = (
            torch._inductor.aoti_load_package(package_path),
            clone_memory_used,
        )

    return cls.cache[key][0]
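The caching structure above can be illustrated in isolation. The sketch below is a simplified, hypothetical stand-in: `compile_fn` plays the role of the real torch.export + AOTInductor pipeline, and `ToyModel` stands in for an nn.Module. Only the weakref-keyed cache logic mirrors the source.

```python
import weakref


class ToyModel:
    """Hypothetical model object standing in for an nn.Module."""


class CompiledModelCache:
    """Minimal sketch of the weakref-keyed cache pattern used by load()."""

    def __init__(self, compile_fn):
        # Maps weakref.ref(model) -> (compiled_artifact, clone_memory_used),
        # matching the tuple stored in cls.cache above.
        self.cache = {}
        self.compile_fn = compile_fn

    def load(self, model):
        key = weakref.ref(model)
        if key not in self.cache:
            # First call for this model instance: compile and remember.
            self.cache[key] = (self.compile_fn(model), 0.0)
        # Later calls with the same live object hit the cache, because two
        # weakref.ref objects to the same referent compare (and hash) equal.
        return self.cache[key][0]
```

Keying by `weakref.ref(model)` rather than by the model itself means the cache does not keep otherwise-dead models alive; once a model is garbage collected, no new key can match its stale entry.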
Frequently Asked Questions
What does load() do?
load() exports a model with torch.export, compiles the exported program ahead-of-time with AOTInductor (torch._inductor.aoti_compile_and_package), and returns the loaded compiled package. Results are cached per model instance, keyed by weakref.ref(model), so repeated calls with the same model reuse the compiled artifact. It also records the peak-memory overhead of the pre-export model clone so that AOTInductor memory measurements can be corrected later.
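The clone-overhead accounting in the source (reset peak stats, snapshot, clone, take the delta) can be shown with a toy allocator. This is a hypothetical stand-in for torch.cuda's peak-memory counters, not the real API; only the delta arithmetic mirrors load().

```python
class PeakMemoryStats:
    """Toy stand-in for a device's peak-memory counters (hypothetical)."""

    def __init__(self):
        self.current = 0
        self.peak = 0

    def allocate(self, nbytes):
        self.current += nbytes
        self.peak = max(self.peak, self.current)

    def reset_peak_memory_stats(self):
        # Like torch.cuda.reset_peak_memory_stats: peak restarts from the
        # currently allocated amount, not from zero.
        self.peak = self.current

    def max_memory_allocated(self):
        return self.peak


def measure_clone_overhead_gb(stats, clone_nbytes):
    # Mirrors load(): snapshot the peak, perform the clone, then take the
    # delta so the transient deepcopy can be subtracted from later
    # AOTInductor memory measurements.
    stats.reset_peak_memory_stats()
    pre_clone = stats.max_memory_allocated()
    stats.allocate(clone_nbytes)  # stands in for copy.deepcopy(model)
    return (stats.max_memory_allocated() - pre_clone) / 1e9
```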
What does load() call?
load() calls five functions: _normalize_bench_inputs(), _register_dataclass_output_as_pytree(), empty_gpu_cache(), export(), and load().
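The branch in the source that chooses between namedtuple and dataclass registration hinges on pytree.is_namedtuple_instance. A plain-Python approximation of that check (an assumption, not the actual pytree implementation) is:

```python
from collections import namedtuple


def looks_like_namedtuple(obj):
    # Approximation: namedtuples are tuple subclasses carrying a `_fields`
    # attribute that holds a tuple of field-name strings.
    return (
        isinstance(obj, tuple)
        and hasattr(obj, "_fields")
        and isinstance(obj._fields, tuple)
        and all(isinstance(f, str) for f in obj._fields)
    )


# Example output type, standing in for a model's structured return value.
Output = namedtuple("Output", ["logits", "loss"])
```

Outputs that fail this check fall through to the dataclass path, _register_dataclass_output_as_pytree(), in the real code.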
What calls load()?
load() is called by six functions: check_accuracy(), export_aot_inductor(), export_nativert(), torchscript_jit_trace(), and two other load() functions.