
load_model() — pytorch Function Reference

Architecture documentation for the load_model() method in benchmarks/dynamo/huggingface.py from the pytorch codebase.


Dependency Diagram

graph TD
  f731aa1a_ba23_ab06_f8e5_e87d6a842749["load_model()"]
  121584aa_ef62_fff7_1d2f_03dd864e80c3["_get_model_cls_and_config()"]
  f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 121584aa_ef62_fff7_1d2f_03dd864e80c3
  2c3d93c7_6719_b38f_a928_3d268c212537["_download_model()"]
  f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 2c3d93c7_6719_b38f_a928_3d268c212537
  684ace56_6d9f_5ea1_f47a_cb8aa5834a49["validate_model()"]
  f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 684ace56_6d9f_5ea1_f47a_cb8aa5834a49
  179ecd7f_27e8_48db_95f4_85c2cda45c18["generate_inputs_for_model()"]
  f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 179ecd7f_27e8_48db_95f4_85c2cda45c18
  6c4a7daf_e704_6d7a_bcaa_98900e3a377b["get_model_and_inputs()"]
  f731aa1a_ba23_ab06_f8e5_e87d6a842749 -->|calls| 6c4a7daf_e704_6d7a_bcaa_98900e3a377b
  style f731aa1a_ba23_ab06_f8e5_e87d6a842749 fill:#6366f1,stroke:#818cf8,color:#fff


Source Code

benchmarks/dynamo/huggingface.py lines 410–493

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dtype = torch.float32
        reset_rng_state()

        # Get batch size
        if model_name in BATCH_SIZE_KNOWN_MODELS:
            batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
        elif batch_size is None:
            batch_size_default = 16
            log.info(
                f"Batch size not specified for {model_name}. Setting batch_size=16"  # noqa: G004
            )

        if batch_size is None:
            batch_size = batch_size_default
            batch_size_divisors = self._config["batch_size"]["divisors"]
            if model_name in batch_size_divisors:
                batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1)
                log.info(
                    f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"  # noqa: G004
                )

        # Get model and example inputs
        if model_name in HF_LLM_MODELS:
            benchmark_cls = HF_LLM_MODELS[model_name]
            model, example_inputs = benchmark_cls.get_model_and_inputs(
                model_name, device
            )

            # Set this flag so that when we test for speedup, we use
            # model.generate instead of using model.forward
            self.hf_llm = True

            def generate(self, _, example_inputs, collect_outputs=True):
                return model.generate(**example_inputs)

            self.generate = types.MethodType(generate, self)

        else:
            self.hf_llm = False

            model_cls, config = self._get_model_cls_and_config(model_name)
            model = self._download_model(model_name)
            model = model.to(device, dtype=dtype)

            example_inputs = generate_inputs_for_model(
                model_cls, model, model_name, batch_size, device, include_loss_args=True
            )

            # So we can check for correct gradients without eliminating the dropout computation
            for attr in dir(config):
                if "drop" in attr and isinstance(getattr(config, attr), float):
                    setattr(config, attr, 1e-30)

            # Turning off kv cache for torchbench models. This is not the right
            # thing to do, but the pt2 dashboard is outdated. Real transformers
            # benchmarks will be added soon using a different infra.
            if hasattr(model, "config") and hasattr(model.config, "use_cache"):
                model.config.use_cache = False

        if self.args.enable_activation_checkpointing:
            model.gradient_checkpointing_enable()

        if (
            is_training
            and not use_eval_mode
            and not (
                self.args.accuracy and model_name in self._config["only_inference"]
            )
        ):
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size
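The batch-size resolution at the top of load_model() can be restated as a standalone function, which makes the precedence easier to see: an explicit batch_size always wins, otherwise a per-model default (or 16) is used and then optionally divided. This is a minimal sketch, not pytorch code; the function name resolve_batch_size and the two lookup tables (standing in for BATCH_SIZE_KNOWN_MODELS and self._config["batch_size"]["divisors"]) are hypothetical, and their contents are made up for illustration.

```python
import logging

log = logging.getLogger(__name__)

# Hypothetical stand-ins for BATCH_SIZE_KNOWN_MODELS and
# self._config["batch_size"]["divisors"]; the values are illustrative only.
KNOWN_BATCH_SIZES = {"ModelA": 32, "ModelB": 4}
BATCH_SIZE_DIVISORS = {"ModelA": 4, "ModelB": 8}


def resolve_batch_size(model_name, batch_size=None,
                       known=KNOWN_BATCH_SIZES, divisors=BATCH_SIZE_DIVISORS):
    """Sketch of load_model()'s batch-size logic (hypothetical helper)."""
    if batch_size is not None:
        # An explicitly requested batch size is used as-is; no divisor applies.
        return batch_size
    # Fall back to the per-model default, or 16 for models not in the table.
    default = known.get(model_name, 16)
    if model_name not in known:
        log.info("Batch size not specified for %s. Setting batch_size=16", model_name)
    if model_name in divisors:
        # Shrink the default for memory-hungry models, but never below 1.
        return max(int(default / divisors[model_name]), 1)
    return default
```

With these sample tables, resolve_batch_size("ModelA") yields 8 (32 divided by 4), while resolve_batch_size("ModelB") is clamped to 1 because 4 / 8 truncates to 0.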


Frequently Asked Questions

What does load_model() do?
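load_model() is a method of the HuggingFace benchmark runner in benchmarks/dynamo/huggingface.py. It resets the RNG state, resolves the batch size, and builds the model and example inputs: models listed in HF_LLM_MODELS go through get_model_and_inputs() and are benchmarked via model.generate, while all other models are downloaded, moved to the target device in float32, given inputs from generate_inputs_for_model(), and have their dropout probabilities lowered and KV cache disabled. It then optionally enables activation checkpointing, puts the model in train or eval mode, validates it, and returns (device, model_name, model, example_inputs, batch_size).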
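The train/eval decision near the end of load_model() reduces to a single boolean condition. The sketch below isolates it as a pure function so the cases are easy to check; select_mode is a hypothetical name, and only_inference stands in for self._config["only_inference"].

```python
def select_mode(is_training, use_eval_mode, accuracy, model_name, only_inference):
    """Return "train" or "eval" using load_model()'s condition (hypothetical helper)."""
    wants_train = (
        is_training
        and not use_eval_mode
        # Accuracy runs keep inference-only models in eval mode even when training.
        and not (accuracy and model_name in only_inference)
    )
    return "train" if wants_train else "eval"
```

Note that an inference-only model still trains in a performance run; the exclusion applies only when accuracy checking is enabled.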
What does load_model() call?
load_model() calls 5 functions: _download_model, _get_model_cls_and_config, generate_inputs_for_model, get_model_and_inputs, validate_model.
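Besides the calls listed above, load_model() edits model settings inline on the non-LLM path: every float config attribute whose name contains "drop" is set to 1e-30 (so dropout stays in the computation without meaningfully perturbing gradients), and use_cache is turned off. A minimal sketch follows, assuming a single config object; in the real code the dropout loop runs over the config returned by _get_model_cls_and_config() while the cache flag is set on model.config. The helper name patch_config and the SimpleNamespace config are hypothetical.

```python
from types import SimpleNamespace


def patch_config(config):
    """Apply load_model()-style config tweaks to one object (hypothetical helper)."""
    # Shrink every float attribute whose name contains "drop"; a nonzero
    # probability keeps the dropout op in the graph but almost never drops.
    for attr in dir(config):
        if "drop" in attr and isinstance(getattr(config, attr), float):
            setattr(config, attr, 1e-30)
    # Disable the key/value cache when the config exposes the flag.
    if hasattr(config, "use_cache"):
        config.use_cache = False
    return config


cfg = patch_config(SimpleNamespace(hidden_dropout_prob=0.1, dropout_layers=2, use_cache=True))
```

Integer attributes such as dropout_layers are left alone because the isinstance check only matches floats.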
