timed() — pytorch Function Reference
Architecture documentation for the timed() function in common.py from the pytorch codebase.
Dependency Diagram
graph TD
    9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4["timed()"]
    f5d4c5a3_21f5_4ed1_f582_7d73c454a4d7["latency_experiment()"]
    f5d4c5a3_21f5_4ed1_f582_7d73c454a4d7 -->|calls| 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4
    04a3a4a6_8db3_854d_a893_02c9542bf9dd["speedup_experiment()"]
    04a3a4a6_8db3_854d_a893_02c9542bf9dd -->|calls| 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4
    1d4553c2_729a_ea9b_a587_db58802dca02["baselines()"]
    1d4553c2_729a_ea9b_a587_db58802dca02 -->|calls| 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4
    9fdf3e80_933d_ded9_76ca_8248ceccfcaf["xla()"]
    9fdf3e80_933d_ded9_76ca_8248ceccfcaf -->|calls| 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4
    24e4ab44_f8ed_b835_f57a_8a13ebaabe74["run_model()"]
    24e4ab44_f8ed_b835_f57a_8a13ebaabe74 -->|calls| 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4
    82d098c4_9a2a_eaf4_0bba_655d342fc39e["tensor_is_on_xla()"]
    9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4 -->|calls| 82d098c4_9a2a_eaf4_0bba_655d342fc39e
    0952be66_7e60_a42f_aa90_67a4647f1fd5["synchronize()"]
    9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4 -->|calls| 0952be66_7e60_a42f_aa90_67a4647f1fd5
    41369485_8a81_562f_2f09_b8d03a5222f0["patch_torch_manual_seed()"]
    9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4 -->|calls| 41369485_8a81_562f_2f09_b8d03a5222f0
    style 9c8df7bf_0e05_9bbb_5e2f_6c88f28b52d4 fill:#6366f1,stroke:#818cf8,color:#fff
Source Code
benchmarks/dynamo/common.py lines 663–742
def timed(
    model,
    model_iter_fn,
    example_inputs,
    times=1,
    return_result=False,
    collect_outputs=False,
    batch_size=None,
):
    use_xla = tensor_is_on_xla(example_inputs)
    synchronize()

    if batch_size:
        patch_torch_manual_seed()

    if use_xla:
        xm.mark_step()
        xm.wait_device_ops()

    def vary_batch(t: torch.Tensor, new_batch_size) -> torch.Tensor:
        for i, s in enumerate(t.size()):
            if s == batch_size:
                # If new batch is smaller, we truncate
                if new_batch_size < batch_size:
                    indexer = [slice(None)] * t.ndim
                    indexer[i] = slice(0, new_batch_size)
                    t = t[tuple(indexer)]
                # If new batch is greater, we just duplicate the last row
                # over and over until we hit the desired batch size
                elif new_batch_size > batch_size:
                    indexer = [slice(None)] * t.ndim
                    indexer[i] = -1
                    last_slice = t[tuple(indexer)].unsqueeze(i)
                    repeat_shape = list(t.shape)
                    repeat_shape[i] = new_batch_size - batch_size
                    padding = last_slice.expand(*repeat_shape)
                    t = torch.cat([t, padding], dim=i)
                break
        return t

    time_total = 0
    # Dont collect outputs to correctly measure timing
    for i in range(times):
        # If batch_size is 1, it too often collides with other non batch size
        # dimensions resulting in errors.
        if batch_size and batch_size > 1:
            # Calculate new batch size by varying the original batch size by up to 20%
            # Ensure it's at least greater than 1
            variation = random.uniform(0.8, 1.2)
            new_batch_size = max(2, int(batch_size * variation))
            example_inputs = tree_map_only(
                torch.Tensor, lambda x: vary_batch(x, new_batch_size), example_inputs
            )
        # Put this call inside the loop to reset the seed for each iteration.
        # Don't include reset_rng_state() to correctly measure timing
        reset_rng_state(use_xla)
        t_iter_begin = time.perf_counter()
        result = model_iter_fn(model, example_inputs, collect_outputs=collect_outputs)

        # instead of calling sync on result_list, we should call mark_step.
        # In training case, result_list may be empty, but we want to
        # send all the pending graphs for compilation.
        if use_xla:
            # For the model running on regular torchxla (baseline), we need the
            # mark step to send the accumulated graph for compilation.
            #
            # For the model running with dynamo/torchxla bridge, in training case,
            # we need the mark step to send the optimizer graph out for
            # compilation.
            xm.mark_step()
        t_iter_end = time.perf_counter()
        time_total += t_iter_end - t_iter_begin

    t_0 = time.perf_counter()
    if use_xla:
        xm.wait_device_ops()
    synchronize()
    t_1 = time.perf_counter()
    time_total += t_1 - t_0
    return (time_total, result) if return_result else time_total
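The measurement pattern in the listing above (synchronize first, accumulate per-iteration wall-clock time, then time one final synchronization) can be sketched with the standard library alone. The names `timed_sketch`, `fn`, and `sync` are hypothetical stand-ins introduced for illustration; this is a simplified sketch, not the PyTorch helper, and it leaves out the XLA, RNG-reset, and batch-variation steps.

```python
import time


def timed_sketch(fn, times=1, sync=lambda: None, return_result=False):
    """Hypothetical, simplified stand-in for timed(); not the PyTorch helper."""
    sync()  # drain pending work before the clock starts
    time_total = 0.0
    result = None
    for _ in range(times):
        t_begin = time.perf_counter()
        result = fn()
        time_total += time.perf_counter() - t_begin
    # Time the final synchronization as well, so asynchronous work that is
    # still in flight is charged to the measurement.
    t_0 = time.perf_counter()
    sync()
    time_total += time.perf_counter() - t_0
    return (time_total, result) if return_result else time_total


calls = []
elapsed, last = timed_sketch(
    lambda: calls.append(1) or len(calls), times=3, return_result=True
)
print(len(calls), last, elapsed >= 0.0)  # 3 3 True
```

Passing a no-op `sync` is only appropriate for synchronous work; for a device that runs asynchronously, `sync` would need to block until the device is idle, which is the role `synchronize()` plays in the original.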
Frequently Asked Questions
What does timed() do?
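timed() is a benchmarking helper in benchmarks/dynamo/common.py. It calls model_iter_fn(model, example_inputs, collect_outputs=collect_outputs) once per iteration for `times` iterations and accumulates the wall-clock duration of each call with time.perf_counter(). It then adds the time spent in a final synchronize() (preceded by xm.wait_device_ops() when the inputs are on XLA) and returns the total, or a (time_total, result) tuple when return_result is true. When batch_size is greater than 1, each iteration first draws a new batch size between 80% and 120% of batch_size (never below 2) and truncates or pads the first dimension of each input tensor whose size equals batch_size. reset_rng_state() runs before the timer starts, so it is excluded from the measurement.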
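The batch-size arithmetic in the source listing can be worked through in plain Python. The sketch below reuses the `max(2, int(batch_size * variation))` expression from timed(); `pick_batch_size` and `vary_batch_list` are hypothetical names, and the list-based helper only imitates the truncate-or-repeat-last-row behavior of vary_batch along a leading dimension rather than operating on tensors.

```python
import random


def pick_batch_size(batch_size, rng=random):
    """Same arithmetic as in timed(): scale by 0.8 to 1.2, truncate, floor at 2."""
    variation = rng.uniform(0.8, 1.2)
    return max(2, int(batch_size * variation))


def vary_batch_list(rows, batch_size, new_batch_size):
    """List-based analogue of vary_batch for the leading dimension (hypothetical)."""
    if len(rows) != batch_size:
        return rows  # length does not match the batch size; leave untouched
    if new_batch_size < batch_size:
        return rows[:new_batch_size]  # smaller batch: truncate
    # Larger (or equal) batch: repeat the last row until the size is reached.
    return rows + [rows[-1]] * (new_batch_size - batch_size)


rng = random.Random(0)
sizes = {pick_batch_size(8, rng) for _ in range(1000)}
print(sorted(sizes))  # every value lies between 6 and 9 for batch_size=8

print(vary_batch_list([1, 2, 3, 4], 4, 2))  # [1, 2]
print(vary_batch_list([1, 2, 3, 4], 4, 6))  # [1, 2, 3, 4, 4, 4]
```

For batch_size=8 the product ranges from 6.4 to 9.6, so truncation with int() yields a size from 6 to 9; for batch_size=2 the floor of 2 always applies.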
What does timed() call?
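The dependency diagram records three callees for timed(): patch_torch_manual_seed, synchronize, and tensor_is_on_xla. The source listing also invokes reset_rng_state, tree_map_only, random.uniform, time.perf_counter, and, for XLA inputs, xm.mark_step and xm.wait_device_ops.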
What calls timed()?
timed() is called by five functions: baselines, latency_experiment, run_model, speedup_experiment, and xla.
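As an illustration of how a caller might consume timing totals like those timed() returns, the sketch below collects several samples per configuration and compares medians. This is an assumption about a typical benchmarking pattern, not a transcription of speedup_experiment() or any other caller listed above; `measure` and `speedup` are hypothetical helpers.

```python
import statistics
import time


def measure(fn, repeats=5):
    """Collect one wall-clock sample per repeat (hypothetical helper)."""
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return samples


def speedup(baseline_samples, optimized_samples):
    """Ratio of medians; values above 1.0 mean the optimized run was faster."""
    return statistics.median(baseline_samples) / statistics.median(optimized_samples)


print(speedup([0.4, 0.5, 0.6], [0.2, 0.25, 0.3]))  # 2.0
print(len(measure(lambda: sum(range(1000)), repeats=3)))  # 3
```

Using the median rather than the mean keeps a single slow outlier, such as a first run that includes compilation, from dominating the comparison.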