run_model() — pytorch Function Reference

Architecture documentation for the run_model() function in distributed.py from the pytorch codebase.

Dependency Diagram

graph TD
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74["run_model()"]
  9553c3cb_c1ee_9d98_ebf1_e7a19f8ebc17["setup()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| 9553c3cb_c1ee_9d98_ebf1_e7a19f8ebc17
  ff4518e9_afbb_35ef_3d8f_c605ec26b945["apply_fsdp()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| ff4518e9_afbb_35ef_3d8f_c605ec26b945
  9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4["timed()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4
  25f0a1d2_c9d8_0637_551b_ccf12d418bf9["torchviz_model()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| 25f0a1d2_c9d8_0637_551b_ccf12d418bf9
  6bdabe92_d399_b13a_52e9_438095848d71["profile_model()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| 6bdabe92_d399_b13a_52e9_438095848d71
  b2148985_e759_077c_dcb0_5118eb786fad["cleanup()"]
  24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| b2148985_e759_077c_dcb0_5118eb786fad
  style 24e4ab44_f8ed_b835_f57a_8a13ebaabe74 fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

benchmarks/dynamo/distributed.py lines 47–113

def run_model(args, model, inputs, key):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    # result_q = []

    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    dev_rank = f"{args.device}:{rank}"
    model = model.to(dev_rank)

    def move_tensor(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

    inputs = pytree.tree_map(move_tensor, inputs)

    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, ex):
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        dynamo_ctx = dynamo.optimize(
            print_compile if args.dynamo == "print" else args.dynamo
        )
        model = dynamo_ctx(model)

    # warmup
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total
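The rank and device handling at the top of run_model() can be reproduced without a GPU. This is a minimal sketch that uses a hypothetical `resolve_device()` helper, which takes the environment as a mapping so it can be tested. Launchers such as torchrun export RANK and WORLD_SIZE for each worker process. When those variables are unset, the defaults give a single-process run on device index 0.

```python
import os
from typing import Mapping, Optional, Tuple


def resolve_device(device: str, env: Optional[Mapping[str, str]] = None) -> Tuple[int, int, str]:
    """Hypothetical helper mirroring how run_model() derives rank, world size, and device.

    Reads RANK and WORLD_SIZE (defaulting to 0 and 1, as run_model() does)
    and builds the per-rank device string, e.g. "cuda:1".
    """
    env = os.environ if env is None else env
    rank = int(env.get("RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    # Each rank places its model replica and inputs on its own device index.
    return rank, world_size, f"{device}:{rank}"


# Usage: the second worker of a two-process run, and a plain single-process run.
print(resolve_device("cuda", {"RANK": "1", "WORLD_SIZE": "2"}))
print(resolve_device("cpu", {}))
```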

Frequently Asked Questions

What does run_model() do?
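run_model() benchmarks a model on one rank of a (possibly single-process) distributed run. It reads RANK and WORLD_SIZE from the environment (defaulting to 0 and 1) and calls setup(rank, world_size). On CUDA, it also pins the process to device `rank`, which the source notes is needed for FSDP. It moves the model and every tensor in the (possibly nested) inputs to `{args.device}:{rank}`. It then wraps the model with apply_fsdp() when args.fsdp is set, or with DDP when args.ddp is set.

If args.dynamo is set, run_model() resets TorchDynamo and applies the verbose and optimize_ddp settings. When Inductor is combined with FSDP, it disables Inductor's CUDA graphs. It then compiles the model with `dynamo.optimize()`, using either the named backend or, when args.dynamo is "print", a backend that prints each captured FX graph.

Finally, it runs 3 warmup iterations, times args.repeat iterations with timed(), and optionally calls torchviz_model() and profile_model(). It calls cleanup() and returns the measured time.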
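The branching described above can be summarized in a pure-Python sketch. `plan_run()` is a hypothetical helper that reads only the `args` fields run_model() reads (`fsdp`, `ddp`, `dynamo`, `repeat`). It does not wrap or compile anything. The sketch makes the precedence explicit: FSDP is checked first, so it wins when both `args.fsdp` and `args.ddp` are set.

```python
from argparse import Namespace


def plan_run(args: Namespace) -> dict:
    """Hypothetical helper: report which optional stages run_model() would enable."""
    if args.fsdp:
        wrapper = "FSDP"  # checked first, so it takes precedence over DDP
    elif args.ddp:
        wrapper = "DDP"
    else:
        wrapper = None
    backend = None
    if args.dynamo:
        # "print" selects the local print_compile backend; any other value
        # is passed straight through to dynamo.optimize() as a backend name.
        backend = "print_compile" if args.dynamo == "print" else args.dynamo
    return {
        "wrapper": wrapper,
        "backend": backend,
        # run_model() turns off Inductor's CUDA graphs when combined with FSDP.
        "cudagraphs_disabled": args.dynamo == "inductor" and bool(args.fsdp),
        "warmup_iters": 3,
        "timed_iters": args.repeat,
    }


print(plan_run(Namespace(fsdp=True, ddp=True, dynamo="inductor", repeat=10)))
```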
What does run_model() call?
run_model() calls six functions: apply_fsdp, cleanup, profile_model, setup, timed, and torchviz_model. It also passes model_iter_fn to timed().
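The warmup-then-measure pattern at the end of run_model() can be sketched in plain Python. `timed_sketch`, `fake_iter_fn`, and `double` are hypothetical names. The real timed() and model_iter_fn are not shown on this page, and this sketch omits the `return_result` parameter. It only illustrates the call pattern: 3 discarded warmup calls followed by a separately timed run.

```python
import time
from typing import Any, Callable


def timed_sketch(
    model: Callable[[Any], Any],
    model_iter_fn: Callable[[Callable[[Any], Any], Any], Any],
    inputs: Any,
    times: int = 1,
) -> float:
    """Hypothetical stand-in for timed(): call model_iter_fn `times` times, return elapsed seconds.

    On a GPU, a real timer must synchronize (e.g. torch.cuda.synchronize())
    before reading the clock, because CUDA kernels are launched asynchronously.
    """
    start = time.perf_counter()
    for _ in range(times):
        model_iter_fn(model, inputs)
    return time.perf_counter() - start


calls = []


def fake_iter_fn(model, inputs):
    # Hypothetical per-iteration callback; records each call so it can be counted.
    calls.append(inputs)
    return model(inputs)


def double(x):
    return x * 2


_ = timed_sketch(double, fake_iter_fn, 3, times=3)  # warmup, result ignored
t_total = timed_sketch(double, fake_iter_fn, 3, times=5)
```

Separating warmup from measurement matters here because TorchDynamo compiles the model on its first call. Discarding the first iterations keeps compilation time out of `t_total`.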
