xla() — PyTorch Function Reference
Architecture documentation for the xla() function in common.py from the PyTorch codebase.
Dependency Diagram
graph TD
    9fdf3e80_933d_ded9_76ca_8248ceccfcaf["xla()"]
    9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4["timed()"]
    9fdf3e80_933d_ded9_76ca_8248ceccfcaf -->|calls| 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4
    3473d1a5_c1f5_fc97_006e_79a1d3081bef["write_outputs()"]
    9fdf3e80_933d_ded9_76ca_8248ceccfcaf -->|calls| 3473d1a5_c1f5_fc97_006e_79a1d3081bef
    style 9fdf3e80_933d_ded9_76ca_8248ceccfcaf fill:#6366f1,stroke:#818cf8,color:#fff
Source Code
benchmarks/dynamo/common.py lines 1295–1325
def xla(args, model_iter_fn, model, example_inputs):
    xla_dev = xm.xla_device(devkind=current_device)
    model_xla = copy.deepcopy(model).to("cpu").to(device=xla_dev)
    example_inputs_xla = tree_map_only(
        torch.Tensor, lambda x: x.to("cpu").to(device=xla_dev), example_inputs
    )
    for _ in range(3):  # warmup
        timed(model, model_iter_fn, example_inputs)
        timed(model_xla, model_iter_fn, example_inputs_xla)
    timings = np.zeros((args.repeat, 2), np.float64)
    timings.fill(1.0e10)
    for rep in range(args.repeat):
        timings[rep, 0] = timed(model, model_iter_fn, example_inputs)
        timings[rep, 1] = timed(model_xla, model_iter_fn, example_inputs_xla)
    pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
    time_baseline, time_xla = np.median(timings, axis=0)
    speedup = time_baseline / time_xla
    write_outputs(
        output_filename,
        ("dev", "name", "batch_size", "speedup", "time_baseline", "time_xla"),
        [
            current_device,
            current_name,
            current_batch_size,
            speedup,
            time_baseline,
            time_xla,
        ],
    )
    return format_speedup(speedup, pvalue)
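The input-transfer step above uses tree_map_only (in PyTorch this lives in torch.utils._pytree) to apply the device move only to torch.Tensor leaves inside an arbitrarily nested inputs structure. The sketch below is a hypothetical pure-Python stand-in, not the real implementation, that illustrates the behavior the function relies on: apply a function to values of a given type, recursing through lists, tuples, and dicts, and leave everything else untouched.

```python
def tree_map_only(ty, fn, tree):
    """Sketch of a typed pytree map: apply fn only to values of type ty."""
    if isinstance(tree, ty):
        return fn(tree)
    if isinstance(tree, (list, tuple)):
        # Rebuild the container with the same type, mapping each element.
        return type(tree)(tree_map_only(ty, fn, x) for x in tree)
    if isinstance(tree, dict):
        return {k: tree_map_only(ty, fn, v) for k, v in tree.items()}
    # Leaves of other types pass through unchanged.
    return tree

# Example: double every int in a nested structure, leaving strings untouched.
print(tree_map_only(int, lambda x: x * 2, {"a": [1, 2], "b": "skip"}))
# → {'a': [2, 4], 'b': 'skip'}
```

In xla() the same pattern moves every tensor in example_inputs to the XLA device while preserving the surrounding structure, so model_iter_fn can consume the result exactly as it consumes the original inputs.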
Frequently Asked Questions
What does xla() do?
xla() is a benchmark function in the PyTorch codebase (benchmarks/dynamo/common.py). It copies the model and its example inputs to an XLA device, warms up both the baseline and XLA variants, times args.repeat iterations of each, and compares the median times. It records the results via write_outputs() and returns the speedup formatted together with a t-test p-value.
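The statistical comparison at the heart of xla() can be sketched in isolation. The timings array below is hypothetical example data standing in for the measured per-repeat times (column 0 = baseline, column 1 = XLA); the median and ttest_ind calls mirror what the function does with its real measurements.

```python
import numpy as np
from scipy.stats import ttest_ind

# Hypothetical per-repeat timings in seconds: column 0 = baseline, column 1 = XLA.
timings = np.array([
    [0.102, 0.051],
    [0.100, 0.050],
    [0.104, 0.052],
    [0.101, 0.050],
])

# Median over repeats is robust to occasional slow iterations.
time_baseline, time_xla = np.median(timings, axis=0)
speedup = time_baseline / time_xla

# Two-sample t-test: a small p-value suggests the gap is not measurement noise.
pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
print(f"{speedup:.2f}x speedup (p={pvalue:.4f})")
```

Using the median rather than the mean, plus reporting a p-value alongside the speedup, lets downstream tooling distinguish a genuine XLA win from run-to-run jitter.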
What does xla() call?
xla() calls two functions: timed() to measure each benchmark iteration, and write_outputs() to record the results.