setup_amp() — pytorch Function Reference

Architecture documentation for the setup_amp() function in common.py from the PyTorch codebase.

Entity Profile

Dependency Diagram

graph TD
  f77ea0ed_facd_a7d6_7e4a_17f4bae6c1fe["setup_amp()"]
  da8c81f0_c44b_4b89_914d_318c84a98a42["cast_based_on_args()"]
  da8c81f0_c44b_4b89_914d_318c84a98a42 -->|calls| f77ea0ed_facd_a7d6_7e4a_17f4bae6c1fe
  c9be2096_e6d7_2374_ad2e_a6e33f435ada["run()"]
  c9be2096_e6d7_2374_ad2e_a6e33f435ada -->|calls| f77ea0ed_facd_a7d6_7e4a_17f4bae6c1fe
  style f77ea0ed_facd_a7d6_7e4a_17f4bae6c1fe fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

benchmarks/dynamo/common.py lines 1790–1829

    def setup_amp(self, current_device=None):
        if self.args.only in self.fp32_only_models:
            return

        devices = [current_device] if current_device else self.args.devices
        if self.args.amp:
            # AMP training can lead to small loss values which can underflow
            # gradient values, resulting in zero gradients. To solve this
            # problem, PyTorch introduces GradScaler. GradScaler is a stateful
            # structure that scales the loss values to prevent underflow. Loss
            # values are big at the beginning of training (therefore not
            # requiring scaling), while loss values tend to be small as the
            # network starts getting better (requiring scaling). GradScaler
            # manages all of this fine tuning, checking whether gradients are
            # turning to inf and discarding such batches.

            # Since we are not running a long iteration, the default
            # init_scale of 65536 would turn all gradients to inf. Therefore,
            # we use an init_scale of 2.0 for benchmarking purposes.

            # Disabling GradScaler because
            #  1) The benchmark setup runs only 2 iterations of fwd-bwd, so it
            #  is not useful.
            #  2) The current setup shares the grad_scaler between the eager
            #  and dynamo models, which is bad because GradScaler is stateful
            #  and can adjust the scaling factor between the eager and dynamo
            #  runs, making the accuracy check harder.
            # self.grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
            self.autocast = functools.partial(
                torch.amp.autocast, device_type=devices[0]
            )
            if self.args.amp_dtype is None:
                if self.args.only in self.amp_dtype_bfloat16:
                    self.autocast_arg["dtype"] = torch.bfloat16
            else:
                amp_dtype = (
                    torch.float16
                    if self.args.amp_dtype == "float16"
                    else torch.bfloat16
                )
                self.autocast_arg["dtype"] = amp_dtype
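The listing above binds the device once via `functools.partial` and collects the dtype in a kwargs dict, so call sites can later enter the context with the stored arguments. A minimal, torch-free sketch of this pattern (the `Autocast` class here is a hypothetical stand-in for `torch.amp.autocast`, and `amp_dtype_flag` stands in for `self.args.amp_dtype`):

```python
import functools

class Autocast:
    # Hypothetical stand-in for torch.amp.autocast, so the sketch runs
    # without torch installed.
    def __init__(self, device_type, dtype=None):
        self.device_type = device_type
        self.dtype = dtype

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

# setup_amp() binds the device once ...
autocast = functools.partial(Autocast, device_type="cuda")
autocast_arg = {}

# ... and optionally pins a dtype, mirroring the amp_dtype branch above
amp_dtype_flag = "float16"  # stand-in for self.args.amp_dtype
if amp_dtype_flag is not None:
    autocast_arg["dtype"] = (
        "float16" if amp_dtype_flag == "float16" else "bfloat16"
    )

# a call site would then enter the context with the stored kwargs
with autocast(**autocast_arg) as ctx:
    assert ctx.device_type == "cuda" and ctx.dtype == "float16"
```

Splitting the configuration this way means the device binding and the dtype choice are made once in setup, while each benchmark iteration only has to expand the same kwargs.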

Frequently Asked Questions

What does setup_amp() do?
setup_amp() configures automatic mixed precision (AMP) for a benchmark run. It returns early for models listed in fp32_only_models, binds torch.amp.autocast to the target device via functools.partial, and selects the autocast dtype: if no amp_dtype argument is given, bfloat16 is used only for models in amp_dtype_bfloat16; otherwise the requested float16 or bfloat16 dtype is stored in autocast_arg.
What calls setup_amp()?
setup_amp() is called by 2 functions: cast_based_on_args() and run().
