test_total_time() — pytorch Function Reference
Architecture documentation for the test_total_time() function in inductor_bmm.py from the pytorch codebase.
Entity Profile
Dependency Diagram
graph TD
    3c19a66e_66bb_a804_2d9f_5eed5907c0cc["test_total_time()"]
    f8ffe77e_f49e_f15b_2b69_37243da200a0["shapes()"]
    f8ffe77e_f49e_f15b_2b69_37243da200a0 -->|calls| 3c19a66e_66bb_a804_2d9f_5eed5907c0cc
    f89158cb_fcc9_8241_4e22_048353b4ffc9["inductor_aten_bmm()"]
    3c19a66e_66bb_a804_2d9f_5eed5907c0cc -->|calls| f89158cb_fcc9_8241_4e22_048353b4ffc9
    0fbf8bd3_5fc3_d55b_ee7f_584386da6ff6["inductor_triton_bmm()"]
    3c19a66e_66bb_a804_2d9f_5eed5907c0cc -->|calls| 0fbf8bd3_5fc3_d55b_ee7f_584386da6ff6
    fef5e6fa_eee5_e93a_0d8f_a4ef0477af02["time_with_torch_timer()"]
    3c19a66e_66bb_a804_2d9f_5eed5907c0cc -->|calls| fef5e6fa_eee5_e93a_0d8f_a4ef0477af02
    style 3c19a66e_66bb_a804_2d9f_5eed5907c0cc fill:#6366f1,stroke:#818cf8,color:#fff
Source Code
benchmarks/dynamo/microbenchmarks/inductor_bmm.py lines 23–45
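# torch, config, torch_bmm, inductor_aten_bmm, inductor_triton_bmm and
# time_with_torch_timer are defined or imported elsewhere in inductor_bmm.py.
def test_total_time(shapes):
    # Header row; each following row holds the mean time in ms for one shape pair.
    print("shape; torch bmm; inductor aten bmm; inductor triton bmm")
    for i in range(len(shapes)):
        a_shape, b_shape = shapes[i]
        print(a_shape, "x", b_shape, end="; ")
        # Random half-precision CUDA inputs for this (a_shape, b_shape) pair.
        a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
        b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

        # Untimed warm-up calls, presumably so that compiling each Inductor
        # variant (under its own use_bmm setting) is excluded from the timings.
        config.triton.use_bmm = False
        inductor_aten_bmm(a, b)
        config.triton.use_bmm = True
        inductor_triton_bmm(a, b)

        # .mean of the returned measurement is in seconds (presumably a
        # torch.utils.benchmark Measurement); * 1000 reports milliseconds.
        torch_ms = time_with_torch_timer(torch_bmm, (a, b)).mean * 1000
        config.triton.use_bmm = False
        ind_aten_ms = time_with_torch_timer(inductor_aten_bmm, (a, b)).mean * 1000
        config.triton.use_bmm = True
        ind_triton_ms = time_with_torch_timer(inductor_triton_bmm, (a, b)).mean * 1000
        print(torch_ms, ind_aten_ms, ind_triton_ms, sep="; ")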
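The measurement structure of test_total_time() can be tried without a GPU. It consists of an untimed warm-up per variant, a timed run of every variant on the same inputs, a seconds-to-milliseconds conversion, and one semicolon-separated row per shape pair. The following is a minimal CPU-only sketch, not the benchmark itself. The standard-library timeit module stands in for time_with_torch_timer. The helpers bmm_loops, bmm_zip, rand_tensor, mean_ms and total_time are hypothetical, and they operate on nested Python lists instead of CUDA tensors.

```python
import random
import timeit


def bmm_loops(a, b):
    # Hypothetical stand-in: batched matmul with explicit index loops.
    # a is [batch][n][k] and b is [batch][k][m] as nested lists.
    out = []
    for a_mat, b_mat in zip(a, b):
        n, k, m = len(a_mat), len(b_mat), len(b_mat[0])
        out.append([[sum(a_mat[i][p] * b_mat[p][j] for p in range(k))
                     for j in range(m)] for i in range(n)])
    return out


def bmm_zip(a, b):
    # Hypothetical stand-in: same result, iterating over columns of b via zip(*b_mat).
    return [[[sum(x * y for x, y in zip(row, col)) for col in zip(*b_mat)]
             for row in a_mat] for a_mat, b_mat in zip(a, b)]


def rand_tensor(shape):
    # Nested lists of uniform random floats; shape is [batch, rows, cols].
    batch, rows, cols = shape
    return [[[random.random() for _ in range(cols)] for _ in range(rows)]
            for _ in range(batch)]


def mean_ms(fn, args, iters=20):
    # timeit returns total seconds for `iters` calls: divide for the mean,
    # then multiply by 1000 for milliseconds (mirrors `.mean * 1000`).
    total_s = timeit.Timer(lambda: fn(*args)).timeit(number=iters)
    return total_s / iters * 1000


def total_time(shapes, variants, iters=20):
    # variants is a list of (name, function) pairs, one output column each.
    print("shape; " + "; ".join(name for name, _ in variants))
    rows = []
    for a_shape, b_shape in shapes:
        a, b = rand_tensor(a_shape), rand_tensor(b_shape)
        for _, fn in variants:
            fn(a, b)  # untimed warm-up, analogous to the calls in the original
        times = [mean_ms(fn, (a, b), iters) for _, fn in variants]
        print(a_shape, "x", b_shape, end="; ")
        print(*(f"{t:.3f}" for t in times), sep="; ")
        rows.append((a_shape, b_shape, times))
    return rows


if __name__ == "__main__":
    shapes = [([2, 8, 4], [2, 4, 8]), ([4, 16, 8], [4, 8, 16])]
    total_time(shapes, [("loops bmm", bmm_loops), ("zip bmm", bmm_zip)])
```

Passing the variants as (name, function) pairs keeps the header and the timing columns in sync. The original, by contrast, hard-codes three columns and sets config.triton.use_bmm by hand before each Inductor call.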
Frequently Asked Questions
What does test_total_time() do?
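test_total_time() is a GPU microbenchmark in benchmarks/dynamo/microbenchmarks/inductor_bmm.py. For each (a_shape, b_shape) pair in shapes, it first creates random float16 CUDA tensors and runs the two Inductor variants once, untimed. It then prints one semicolon-separated row giving the mean time in milliseconds for three variants: torch_bmm, inductor_aten_bmm (with config.triton.use_bmm = False), and inductor_triton_bmm (with config.triton.use_bmm = True).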
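test_total_time() selects between the ATen and Triton bmm paths by assigning the global flag config.triton.use_bmm before each warm-up and timed call. The flag stays True after the function returns. A scoped alternative is a context manager that sets the flag and restores its previous value on exit. The sketch below assumes a hypothetical types.SimpleNamespace stand-in for the config object, and use_triton_bmm and current_backend are illustrative names. It shows the pattern, not the torch._inductor.config API.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Hypothetical stand-in for a config object with a nested `triton.use_bmm`
# flag; this is NOT torch._inductor.config.
config = SimpleNamespace(triton=SimpleNamespace(use_bmm=False))


@contextmanager
def use_triton_bmm(enabled):
    # Set the flag for the duration of the `with` block, then restore the
    # previous value even if the body raises.
    previous = config.triton.use_bmm
    config.triton.use_bmm = enabled
    try:
        yield
    finally:
        config.triton.use_bmm = previous


def current_backend():
    # Report which bmm path the stand-in config currently selects.
    return "triton" if config.triton.use_bmm else "aten"


if __name__ == "__main__":
    with use_triton_bmm(True):
        print(current_backend())  # prints "triton"
    print(current_backend())  # prints "aten"
```

PyTorch's own config modules provide a comparable patch() helper that works as a context manager or decorator. Whether a use_bmm flag exists depends on the PyTorch version, so check torch._inductor.config in your installation before relying on it.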
What does test_total_time() call?
test_total_time() calls three functions: inductor_aten_bmm, inductor_triton_bmm, and time_with_torch_timer. It also passes torch_bmm to time_with_torch_timer as one of the functions to be timed.
What calls test_total_time()?
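The dependency analysis lists one caller: shapes. Note that shapes is also the name of test_total_time()'s parameter, a sequence of (a_shape, b_shape) pairs.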