load_model() — PyTorch Function Reference
Architecture documentation for the load_model() function in benchmarks/dynamo/huggingface.py from the PyTorch codebase.
Dependency Diagram
graph TD
    f731aa1a_ba23_ab06_f8e5_e87d6a842749["load_model()"]
    121584aa_ef62_fff7_1d2f_03dd864e80c3["_get_model_cls_and_config()"]
    f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 121584aa_ef62_fff7_1d2f_03dd864e80c3
    2c3d93c7_6719_b38f_a928_3d268c212537["_download_model()"]
    f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 2c3d93c7_6719_b38f_a928_3d268c212537
    684ace56_6d9f_5ea1_f47a_cb8aa5834a49["validate_model()"]
    f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 684ace56_6d9f_5ea1_f47a_cb8aa5834a49
    179ecd7f_27e8_48db_95f4_85c2cda45c18["generate_inputs_for_model()"]
    f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 179ecd7f_27e8_48db_95f4_85c2cda45c18
    6c4a7daf_e704_6d7a_bcaa_98900e3a377b["get_model_and_inputs()"]
    f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 6c4a7daf_e704_6d7a_bcaa_98900e3a377b
    style f731aa1a_ba23_ab06_f8e5_e87d6a842749 fill:#6366f1,stroke:#818cf8,color:#fff
Source Code
benchmarks/dynamo/huggingface.py lines 410–493
def load_model(
    self,
    device,
    model_name,
    batch_size=None,
    extra_args=None,
):
    is_training = self.args.training
    use_eval_mode = self.args.use_eval_mode
    dtype = torch.float32
    reset_rng_state()

    # Get batch size
    if model_name in BATCH_SIZE_KNOWN_MODELS:
        batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
    elif batch_size is None:
        batch_size_default = 16
        log.info(
            f"Batch size not specified for {model_name}. Setting batch_size=16"  # noqa: G004
        )
    if batch_size is None:
        batch_size = batch_size_default
        batch_size_divisors = self._config["batch_size"]["divisors"]
        if model_name in batch_size_divisors:
            batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1)
            log.info(
                f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"  # noqa: G004
            )

    # Get model and example inputs
    if model_name in HF_LLM_MODELS:
        benchmark_cls = HF_LLM_MODELS[model_name]
        model, example_inputs = benchmark_cls.get_model_and_inputs(
            model_name, device
        )

        # Set this flag so that when we test for speedup, we use
        # model.generate instead of using model.forward
        self.hf_llm = True

        def generate(self, _, example_inputs, collect_outputs=True):
            return model.generate(**example_inputs)

        self.generate = types.MethodType(generate, self)
    else:
        self.hf_llm = False

        model_cls, config = self._get_model_cls_and_config(model_name)
        model = self._download_model(model_name)
        model = model.to(device, dtype=dtype)
        example_inputs = generate_inputs_for_model(
            model_cls, model, model_name, batch_size, device, include_loss_args=True
        )

        # So we can check for correct gradients without eliminating the dropout computation
        for attr in dir(config):
            if "drop" in attr and isinstance(getattr(config, attr), float):
                setattr(config, attr, 1e-30)

    # Turning off kv cache for torchbench models. This is not the right
    # thing to do, but the pt2 dashboard is outdated. Real transformers
    # benchmarks will be added soon using a different infra.
    if hasattr(model, "config") and hasattr(model.config, "use_cache"):
        model.config.use_cache = False

    if self.args.enable_activation_checkpointing:
        model.gradient_checkpointing_enable()

    if (
        is_training
        and not use_eval_mode
        and not (
            self.args.accuracy and model_name in self._config["only_inference"]
        )
    ):
        model.train()
    else:
        model.eval()

    self.validate_model(model, example_inputs)
    return device, model_name, model, example_inputs, batch_size
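The least obvious step in the listing is how the LLM branch swaps model.generate in for model.forward: a plain function that closes over model is bound onto the runner instance with types.MethodType. Below is a minimal, self-contained sketch of that pattern; FakeModel and Runner are illustrative stand-ins, not classes from the benchmark harness.

import types

class FakeModel:
    # Stand-in for a Hugging Face model exposing a generate() method.
    def generate(self, **inputs):
        return {"generated": inputs["input_ids"] + [99]}

class Runner:
    # Stand-in for the benchmark runner that owns load_model().
    pass

model = FakeModel()
runner = Runner()

def generate(self, _, example_inputs, collect_outputs=True):
    # Closes over `model`, mirroring the LLM branch of load_model().
    return model.generate(**example_inputs)

# Bind the function as a method on this specific runner instance.
runner.generate = types.MethodType(generate, runner)

print(runner.generate(None, {"input_ids": [1, 2, 3]}))
# {'generated': [1, 2, 3, 99]}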
Frequently Asked Questions
What does load_model() do?
load_model() is a method defined in benchmarks/dynamo/huggingface.py, part of PyTorch's TorchDynamo benchmark suite. It resolves the batch size for the requested Hugging Face model, downloads or constructs the model, moves it to the target device in float32, generates example inputs, disables the kv cache, optionally enables activation checkpointing, puts the model in train or eval mode, validates it, and returns (device, model_name, model, example_inputs, batch_size).
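For orientation, the non-LLM branch follows roughly the pattern in the sketch below, written against the transformers API directly. The model name, Auto classes, and input shapes are illustrative assumptions; the real harness resolves the model class via _get_model_cls_and_config(), fetches weights via _download_model(), and builds inputs with generate_inputs_for_model().

import torch
from transformers import AutoConfig, AutoModelForMaskedLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build a model from its config (illustrative; the harness downloads real weights).
config = AutoConfig.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_config(config).to(device, dtype=torch.float32)

# Mirror the kv-cache disable done in load_model().
if hasattr(model.config, "use_cache"):
    model.config.use_cache = False
model.eval()

# Random token ids standing in for generate_inputs_for_model() output.
input_ids = torch.randint(0, config.vocab_size, (16, 128), device=device)
with torch.no_grad():
    out = model(input_ids=input_ids, labels=input_ids)
print(out.loss)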
What does load_model() call?
load_model() calls five functions: _download_model(), _get_model_cls_and_config(), generate_inputs_for_model(), get_model_and_inputs(), and validate_model().
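Before any of those helpers run, load_model() first resolves the batch size from BATCH_SIZE_KNOWN_MODELS and the per-model divisors in self._config. A self-contained sketch of that step is below; the table entries are illustrative stand-ins, not values from the benchmark configuration.

# Illustrative stand-ins for BATCH_SIZE_KNOWN_MODELS and
# self._config["batch_size"]["divisors"]; real values live in the harness config.
KNOWN_BATCH_SIZES = {"ExampleModel": 32}
BATCH_SIZE_DIVISORS = {"ExampleModel": 2}

def resolve_batch_size(model_name, batch_size=None):
    # Default to the known per-model batch size, else 16.
    batch_size_default = KNOWN_BATCH_SIZES.get(model_name, 16)
    if batch_size is None:
        batch_size = batch_size_default
        # Shrink by the configured divisor, never dropping below 1.
        if model_name in BATCH_SIZE_DIVISORS:
            batch_size = max(int(batch_size / BATCH_SIZE_DIVISORS[model_name]), 1)
    return batch_size

print(resolve_batch_size("ExampleModel"))     # 16
print(resolve_batch_size("ExampleModel", 8))  # 8  (explicit batch size wins)
print(resolve_batch_size("UnknownModel"))     # 16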