xla() — pytorch Function Reference

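Architecture documentation for the xla() function in benchmarks/dynamo/common.py from the PyTorch codebase. The function times a model against a copy of itself running on a PyTorch/XLA device and reports the speedup.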

Dependency Diagram

graph TD
  9fdf3e80_933d_ded9_76ca_8248ceccfcaf["xla()"]
  9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4["timed()"]
  9fdf3e80_933d_ded9_76ca_8248ceccfcaf -->|calls| 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4
  3473d1a5_c1f5_fc97_006e_79a1d3081bef["write_outputs()"]
  9fdf3e80_933d_ded9_76ca_8248ceccfcaf -->|calls| 3473d1a5_c1f5_fc97_006e_79a1d3081bef
  style 9fdf3e80_933d_ded9_76ca_8248ceccfcaf fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

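benchmarks/dynamo/common.py lines 1295–1325. Besides its parameters, the function reads module-level names from common.py (current_device, current_name, current_batch_size, output_filename) and uses module imports such as xm (torch_xla.core.xla_model), copy, tree_map_only, np, and ttest_ind.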

def xla(args, model_iter_fn, model, example_inputs):
    xla_dev = xm.xla_device(devkind=current_device)
    model_xla = copy.deepcopy(model).to("cpu").to(device=xla_dev)
    example_inputs_xla = tree_map_only(
        torch.Tensor, lambda x: x.to("cpu").to(device=xla_dev), example_inputs
    )
    for _ in range(3):  # warmup
        timed(model, model_iter_fn, example_inputs)
        timed(model_xla, model_iter_fn, example_inputs_xla)
    timings = np.zeros((args.repeat, 2), np.float64)
    timings.fill(1.0e10)
    for rep in range(args.repeat):
        timings[rep, 0] = timed(model, model_iter_fn, example_inputs)
        timings[rep, 1] = timed(model_xla, model_iter_fn, example_inputs_xla)

    pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
    time_baseline, time_xla = np.median(timings, axis=0)
    speedup = time_baseline / time_xla
    write_outputs(
        output_filename,
        ("dev", "name", "batch_size", "speedup", "time_baseline", "time_xla"),
        [
            current_device,
            current_name,
            current_batch_size,
            speedup,
            time_baseline,
            time_xla,
        ],
    )
    return format_speedup(speedup, pvalue)
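The tree_map_only call in the listing converts only the tensor leaves of example_inputs and leaves every other value (ints, strings, None) untouched, while preserving the nesting of the containers around them. The following is a minimal pure-Python sketch of that behavior, assuming the inputs are built only from plain lists, tuples, and dicts. map_only is a hypothetical name, floats stand in for tensors so the sketch runs without PyTorch, and the real torch.utils._pytree.tree_map_only also understands other registered container types.

```python
from typing import Any, Callable, Type


def map_only(cls: Type, fn: Callable[[Any], Any], tree: Any) -> Any:
    """Hypothetical stand-in for tree_map_only: apply fn only to leaves of type cls."""
    if isinstance(tree, dict):
        # Rebuild the dict with the same keys, recursing into each value.
        return {k: map_only(cls, fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Recurse into each element and rebuild a plain list or tuple.
        return type(tree)(map_only(cls, fn, v) for v in tree)
    # Leaf: transform it only if it has the requested type.
    return fn(tree) if isinstance(tree, cls) else tree


# Floats play the role of tensors; the int, string, and None leaves pass through.
inputs = {"x": [1.0, "label"], "y": (2.0, 3, None)}
print(map_only(float, lambda v: v * 10, inputs))
# prints {'x': [10.0, 'label'], 'y': (20.0, 3, None)}
```

In xla() the transformation is lambda x: x.to("cpu").to(device=xla_dev), so only tensors move to the XLA device while scalar arguments keep their original values.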

Frequently Asked Questions

What does xla() do?
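xla() benchmarks a model against a copy of itself running on a PyTorch/XLA device. It deep-copies the model and every tensor in example_inputs onto the device returned by xm.xla_device(devkind=current_device), runs three untimed warmup iterations of each version, and then times args.repeat iterations of both with timed(), alternating baseline and XLA runs within each repetition. The speedup is the median baseline time divided by the median XLA time, and a p-value from ttest_ind on the two timing columns indicates whether the difference is statistically meaningful. The results are written to output_filename through write_outputs(), and the function returns format_speedup(speedup, pvalue).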
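The measurement procedure in the source listing (warmup, interleaved timed repeats, median-based speedup, two-sample t-test) can be sketched with the standard library alone. This is a minimal sketch under stated assumptions, not the benchmark's code: time_call, measure, pooled_t, and summarize are hypothetical names, plain callables stand in for the eager and XLA models, and the real timed() helper is not reproduced. Because the standard library has no t distribution, the sketch stops at the t statistic. The real function gets a p-value from scipy.stats.ttest_ind, whose default equal_var=True uses the same pooled-variance statistic.

```python
import math
import statistics
import time
from typing import Callable, List, Tuple


def time_call(fn: Callable[[], object]) -> float:
    """Hypothetical stand-in for timed(): wall-clock seconds for one call."""
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start


def measure(baseline: Callable[[], object], candidate: Callable[[], object],
            repeat: int, warmup: int = 3) -> List[Tuple[float, float]]:
    # Untimed warmup so one-off costs (compilation, cache fills) are excluded.
    for _ in range(warmup):
        time_call(baseline)
        time_call(candidate)
    # Time both sides in every repetition so drift in machine load
    # affects the baseline and the candidate similarly.
    return [(time_call(baseline), time_call(candidate)) for _ in range(repeat)]


def pooled_t(a: List[float], b: List[float]) -> float:
    """Two-sample t statistic assuming equal variances (pooled estimate)."""
    na, nb = len(a), len(b)
    # statistics.variance uses the n - 1 denominator for each sample.
    pooled = ((na - 1) * statistics.variance(a)
              + (nb - 1) * statistics.variance(b)) / (na + nb - 2)
    return (statistics.mean(a) - statistics.mean(b)) / math.sqrt(pooled * (1 / na + 1 / nb))


def summarize(timings: List[Tuple[float, float]]) -> Tuple[float, float, float, float]:
    base = [t[0] for t in timings]
    cand = [t[1] for t in timings]
    # Medians resist outliers such as GC pauses or scheduler hiccups.
    time_baseline = statistics.median(base)
    time_candidate = statistics.median(cand)
    speedup = time_baseline / time_candidate
    return speedup, time_baseline, time_candidate, pooled_t(base, cand)


timings = measure(lambda: sum(range(20_000)), lambda: sum(range(2_000)), repeat=10)
speedup, t_base, t_cand, t_stat = summarize(timings)
print(f"speedup={speedup:.2f}x t={t_stat:.2f}")
```

A speedup above 1 means the candidate (the XLA copy, in xla()) is faster; a large positive t statistic, equivalently a small p-value, suggests that the gap is not just timing noise.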
What does xla() call?
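xla() calls timed() and write_outputs(), the two edges shown in the dependency diagram. It also calls format_speedup() to build its return value, along with library functions such as tree_map_only, np.median, and ttest_ind.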
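write_outputs() receives output_filename, a header tuple, and one row of values per call. A common way to accumulate such rows across many benchmark runs is to append them to a CSV file and write the header only when the file is new. The sketch below assumes that behavior. append_row is a hypothetical helper built on Python's csv module, not the actual write_outputs() implementation, whose file format and extra options are not shown on this page.

```python
import csv
import os
from typing import Sequence


def append_row(path: str, headers: Sequence[str], row: Sequence[object]) -> None:
    """Hypothetical helper: append one result row, writing headers on first use."""
    # Only a missing or empty file gets the header row.
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    # newline="" lets the csv module control line endings, as its docs recommend.
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if needs_header:
            writer.writerow(headers)
        writer.writerow(row)
```

Writing the header lazily keeps one table per file even when each model in a benchmark sweep calls the helper separately.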