run_model() — pytorch Function Reference
Architecture documentation for the run_model() function in distributed.py from the pytorch codebase.
Dependency Diagram
```mermaid
graph TD
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74["run_model()"]
  9553c3cb_c1ee_9d98_ebf1_e7a19f8ebc17["setup()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| 9553c3cb_c1ee_9d98_ebf1_e7a19f8ebc17
  ff4518e9_afbb_35ef_3d8f_c605ec26b945["apply_fsdp()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| ff4518e9_afbb_35ef_3d8f_c605ec26b945
  9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4["timed()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4
  25f0a1d2_c9d8_0637_551b_ccf12d418bf9["torchviz_model()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| 25f0a1d2_c9d8_0637_551b_ccf12d418bf9
  6bdabe92_d399_b13a_52e9_438095848d71["profile_model()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| 6bdabe92_d399_b13a_52e9_438095848d71
  b2148985_e759_077c_dcb0_5118eb786fad["cleanup()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| b2148985_e759_077c_dcb0_5118eb786fad
  style 24e4ab44_f8ed_b835_f57a_8a13ebaabe74 fill:#6366f1,stroke:#818cf8,color:#fff
```
Source Code
benchmarks/dynamo/distributed.py lines 47–113
```python
def run_model(args, model, inputs, key):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    # result_q = []

    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    dev_rank = f"{args.device}:{rank}"
    model = model.to(dev_rank)

    def move_tensor(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

    inputs = pytree.tree_map(move_tensor, inputs)

    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, ex):
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        dynamo_ctx = dynamo.optimize(
            print_compile if args.dynamo == "print" else args.dynamo
        )
        model = dynamo_ctx(model)

    # warmup
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total
```
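The opening lines of run_model() read the process rank from the RANK and WORLD_SIZE environment variables (defaulting to 0 and 1), build a per-rank device string such as "cuda:1", and use pytree.tree_map so that only the tensor leaves of an arbitrarily nested inputs structure are moved. The stdlib-only sketch below mimics that step without PyTorch. FakeTensor, resolve_rank, move_inputs and this simplified tree_map are hypothetical stand-ins, not PyTorch APIs; the tree_map here only recurses into dicts, lists and tuples, while the real pytree module recognizes more container types.

```python
import os
from dataclasses import dataclass, replace
from typing import Any, Callable


def resolve_rank(env=os.environ) -> tuple[int, int]:
    # Same defaults as run_model(): rank 0 of a single-process "world".
    rank = int(env.get("RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    return rank, world_size


@dataclass(frozen=True)
class FakeTensor:
    # Hypothetical stand-in for torch.Tensor: only records its device.
    data: Any
    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        # Like Tensor.to(), returns a new object instead of mutating.
        return replace(self, device=device)


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    # Simplified pytree-style map: recurse into dicts, lists and tuples,
    # and apply fn to every other value (the leaves).
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [tree_map(fn, v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, v) for v in tree)
    return fn(tree)


def move_inputs(inputs: Any, device: str, rank: int) -> Any:
    dev_rank = f"{device}:{rank}"  # e.g. "cuda:1", as in run_model()

    def move_tensor(maybe_tensor):
        # Only tensor leaves move; ints, strings, None, etc. pass through.
        if isinstance(maybe_tensor, FakeTensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

    return tree_map(move_tensor, inputs)


rank, world_size = resolve_rank({"RANK": "1", "WORLD_SIZE": "2"})
moved = move_inputs({"x": FakeTensor([1.0]), "meta": [3, "label"]}, "cuda", rank)
print(moved["x"].device)
print(moved["meta"])
```

Mapping over leaves keeps the helper agnostic to the shape of inputs, which is why run_model() can accept whatever structure the benchmarked model expects.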
Frequently Asked Questions
What does run_model() do?
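run_model() is the per-process driver of the TorchDynamo distributed benchmark in benchmarks/dynamo/distributed.py. It reads RANK and WORLD_SIZE from the environment (defaulting to 0 and 1), calls setup(), and on CUDA selects GPU rank with torch.cuda.set_device(). It moves the model and every tensor in inputs to the device "<device>:<rank>", then wraps the model with apply_fsdp() when args.fsdp is set, or in DDP when only args.ddp is set. If args.dynamo is set, it compiles the model with dynamo.optimize(), using the backend named by args.dynamo, or a local print_compile backend that prints the FX graph when args.dynamo is "print". Finally it runs three untimed warmup iterations and args.repeat timed iterations through timed(), optionally writes a torchviz graph and a profile, calls cleanup(), and returns the value reported by the timed run.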
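run_model() measures the model in two passes: three warmup iterations whose result is discarded, then args.repeat iterations whose duration is returned. The warmup keeps one-time costs, such as Dynamo compiling the model on its first call, out of the measurement. The real timed() helper is not reproduced on this page, so the version below is a hypothetical stdlib-only stand-in. It assumes timed() returns elapsed wall-clock seconds, or an (elapsed, result) pair when return_result=True; model_iter_fn and model here are toys.

```python
import time
from typing import Any, Callable


def timed(model: Callable, model_iter_fn: Callable, inputs: Any,
          times: int = 1, return_result: bool = False):
    # Hypothetical stand-in for the benchmark's timed() helper: call
    # model_iter_fn `times` times and report elapsed wall-clock seconds.
    # A GPU version would also have to synchronize the device before
    # reading the clock, because CUDA kernels launch asynchronously.
    result = None
    start = time.perf_counter()
    for _ in range(times):
        result = model_iter_fn(model, inputs)
    elapsed = time.perf_counter() - start
    return (elapsed, result) if return_result else elapsed


def model_iter_fn(model, inputs):
    # Toy iteration: a single forward call.
    return model(inputs)


calls = []


def model(x):
    # Toy "model" that records each call so the iteration count is visible.
    calls.append(x)
    return x * 2


_ = timed(model, model_iter_fn, 21, times=3, return_result=False)  # warmup
t_total = timed(model, model_iter_fn, 21, times=5, return_result=False)
print(len(calls), t_total)
```

Only the second call's duration is kept, mirroring how run_model() assigns the warmup result to _ and returns t_total.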
What does run_model() call?
run_model() calls six functions: apply_fsdp, cleanup, profile_model, setup, timed, and torchviz_model.
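Whether apply_fsdp() runs, and how the model is compiled, depends only on a few flags on args. Because of the if/elif, FSDP takes precedence over DDP, and choosing the inductor backend together with FSDP turns off inductor's CUDA graphs. The sketch below restates that branching as a pure function that returns decisions instead of building wrappers. plan_run and make_args are hypothetical helpers, and a value of None means run_model() leaves that setting untouched.

```python
from types import SimpleNamespace


def plan_run(args) -> dict:
    # Mirrors the branching in run_model(), returning decisions instead of acting.
    plan = {"wrapper": None, "backend": None,
            "inductor_cudagraphs": None, "optimize_ddp": None}
    if args.fsdp:
        plan["wrapper"] = "fsdp"   # apply_fsdp(...) is called
    elif args.ddp:
        plan["wrapper"] = "ddp"    # DDP(model); skipped whenever FSDP is on
    if args.dynamo:
        # "print" selects the local print_compile backend, which prints the
        # FX graph and returns it unchanged; any other value is passed to
        # dynamo.optimize() as a backend name.
        plan["backend"] = "print_compile" if args.dynamo == "print" else args.dynamo
        if args.dynamo_no_optimize_ddp:
            plan["optimize_ddp"] = False       # otherwise left at its default
        if args.dynamo == "inductor" and args.fsdp:
            plan["inductor_cudagraphs"] = False
    return plan


def make_args(**overrides):
    # Hypothetical helper: only the flags plan_run() reads.
    base = dict(fsdp=False, ddp=False, dynamo=None, dynamo_no_optimize_ddp=False)
    base.update(overrides)
    return SimpleNamespace(**base)


print(plan_run(make_args(fsdp=True, ddp=True, dynamo="inductor")))
```

Keeping the decisions separate from the side effects makes the precedence rules easy to check without a GPU or a process group.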