setup_amp() — PyTorch Function Reference
Architecture documentation for the setup_amp() function in benchmarks/dynamo/common.py from the PyTorch codebase.
Dependency Diagram
graph TD
    f77ea0ed_facd_a7d6_7e4a_17f4bae6c1fe["setup_amp()"]
    da8c81f0_c44b_4b89_914d_318c84a98a42["cast_based_on_args()"]
    da8c81f0_c44b_4b89_914d_318c84a98a42 -->|calls| f77ea0ed_facd_a7d6_7e4a_17f4bae6c1fe
    c9be2096_e6d7_2374_ad2e_a6e33f435ada["run()"]
    c9be2096_e6d7_2374_ad2e_a6e33f435ada -->|calls| f77ea0ed_facd_a7d6_7e4a_17f4bae6c1fe
    style f77ea0ed_facd_a7d6_7e4a_17f4bae6c1fe fill:#6366f1,stroke:#818cf8,color:#fff
Source Code
benchmarks/dynamo/common.py lines 1790–1829
def setup_amp(self, current_device=None):
    if self.args.only in self.fp32_only_models:
        return

    devices = [current_device] if current_device else self.args.devices
    if self.args.amp:
        # AMP training can lead to small loss values which can underflow
        # gradient values, resulting in zero gradients. To solve this
        # problem, PyTorch introduces GradScaler. GradScaler is a stateful
        # structure that scales the loss values to prevent underflow. Loss
        # values are big at the beginning of training (therefore not
        # requiring scaling), while loss values tend to be small as the
        # network starts getting better (requiring scaling). GradScaler
        # manages all of this fine tuning, checking whether the gradients
        # are turning to inf and discarding such batches.
        # Since we are not running a long iteration, the default
        # init_scale of 65536 is going to turn all gradients to inf.
        # Therefore, we just use an init_scale of 2.0 for benchmarking
        # purposes.
        # Disabling GradScaler because
        # 1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
        # 2) Current setup shares grad_scaler for the eager and dynamo
        #    model, which is bad as GradScaler has state and can adjust
        #    the scaling factor between the eager and dynamo runs, making
        #    the accuracy check harder.
        # self.grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
        self.autocast = functools.partial(
            torch.amp.autocast, device_type=devices[0]
        )
        if self.args.amp_dtype is None:
            if self.args.only in self.amp_dtype_bfloat16:
                self.autocast_arg["dtype"] = torch.bfloat16
        else:
            amp_dtype = (
                torch.float16
                if self.args.amp_dtype == "float16"
                else torch.bfloat16
            )
            self.autocast_arg["dtype"] = amp_dtype
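Once setup_amp() has run, the benchmark harness can wrap a forward pass in the configured autocast context. The sketch below mirrors that usage under stated assumptions: it uses a CPU device and bfloat16, and the local `autocast`/`autocast_arg` names stand in for the attributes the source sets on `self`.

```python
import functools

import torch

# Stand-ins for the attributes setup_amp() assigns on `self`
# (self.autocast and self.autocast_arg); device/dtype chosen here
# so the sketch runs without a GPU.
autocast = functools.partial(torch.amp.autocast, device_type="cpu")
autocast_arg = {"dtype": torch.bfloat16}

x = torch.randn(4, 4)
w = torch.randn(4, 4)
with autocast(**autocast_arg):
    # matmul is autocast-eligible, so it executes in bfloat16
    # even though the inputs are float32
    y = x @ w
print(y.dtype)
```

The partial-plus-kwargs split matters: device_type is fixed at setup time, while the dtype is passed separately so the per-model override (amp_dtype_bfloat16 or --amp-dtype) can be applied without rebuilding the context manager.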
Frequently Asked Questions
What does setup_amp() do?
setup_amp() configures automatic mixed precision (AMP) for a benchmark run. It skips models listed in fp32_only_models, builds a functools.partial over torch.amp.autocast for the target device, and selects the autocast dtype: bfloat16 for models in amp_dtype_bfloat16 when no --amp-dtype is given, otherwise float16 or bfloat16 according to the argument.
What calls setup_amp()?
setup_amp() is called by two functions: cast_based_on_args() and run().