test_total_time() — pytorch Function Reference

Architecture documentation for the test_total_time() function in inductor_mm.py from the pytorch codebase.

Dependency Diagram

shapes() calls test_total_time()
test_total_time() calls inductor_aten_mm()
test_total_time() calls inductor_triton_mm()
test_total_time() calls time_with_torch_timer()

Source Code

benchmarks/dynamo/microbenchmarks/inductor_mm.py lines 35–61

def test_total_time(shapes):
    print("shape; torch mm; triton mm; inductor aten mm; inductor triton mm")
    for i in range(len(shapes)):
        a_shape, b_shape = shapes[i]
        print(a_shape, "x", b_shape, end="; ")
        a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
        b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

        config.triton.mm = "aten"
        inductor_aten_mm(a, b)

        config.triton.mm = "triton"
        inductor_triton_mm(a, b)

        torch_ms = time_with_torch_timer(torch_mm, (a, b)).mean * 1000

        triton_ms = time_with_torch_timer(triton_mm, (a, b)).mean * 1000

        config.triton.mm = "aten"
        ind_aten_ms = time_with_torch_timer(inductor_aten_mm, (a, b)).mean * 1000

        config.triton.mm = "triton"
        ind_triton_ms = time_with_torch_timer(inductor_triton_mm, (a, b)).mean * 1000

        print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")

        torch._dynamo.reset()
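Before timing, the loop calls inductor_aten_mm and inductor_triton_mm once each with config.triton.mm set to the matching value, and it calls torch._dynamo.reset() at the end of every shape. The pure-Python sketch below imitates that warm-up/reset pattern with a hypothetical cache standing in for Dynamo's compiled-code cache. FakeCompiler, the `config` dict, and the `compiles` counter are illustrative assumptions, not Inductor APIs. The sketch also assumes, for illustration only, that the setting in effect at the first call is the one baked into the cached version.

```python
import time

# Hypothetical global setting, standing in for config.triton.mm.
config = {"mm": "aten"}


class FakeCompiler:
    """Caches one 'compiled' callable per function name (illustrative only)."""

    def __init__(self):
        self.cache = {}
        self.compiles = 0

    def call(self, name, fn, *args):
        if name not in self.cache:
            # First call: simulate an expensive compile and bake in the
            # backend selected by the config at this moment.
            self.compiles += 1
            time.sleep(0.01)
            backend = config["mm"]
            self.cache[name] = lambda *a: (backend, fn(*a))
        return self.cache[name](*args)

    def reset(self):
        # Analogous in spirit to torch._dynamo.reset(): forget compiled code.
        self.cache.clear()


def mm(a, b):
    # Plain nested-list matrix multiply.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


compiler = FakeCompiler()
a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]

config["mm"] = "aten"
compiler.call("inductor_aten_mm", mm, a, b)    # warm-up compiles under "aten"
config["mm"] = "triton"
compiler.call("inductor_triton_mm", mm, a, b)  # warm-up compiles under "triton"

# The timed call reuses the cached version, so no compile cost is measured.
start = time.perf_counter()
result = compiler.call("inductor_aten_mm", mm, a, b)
elapsed_ms = (time.perf_counter() - start) * 1000

compiler.reset()  # the next shape would compile again
```

Even though the config reads "triton" at the time of the timed call, the cached inductor_aten_mm stand-in still reports "aten", because in this sketch the choice was fixed during the warm-up call.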

Frequently Asked Questions

What does test_total_time() do?
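test_total_time() is a GPU matrix-multiply microbenchmark in benchmarks/dynamo/microbenchmarks/inductor_mm.py. For each (a_shape, b_shape) pair in shapes, it creates random float16 CUDA tensors and calls inductor_aten_mm and inductor_triton_mm once, with config.triton.mm set to "aten" and "triton" respectively (presumably so compilation happens before timing). It then times four implementations with time_with_torch_timer: torch_mm, triton_mm, inductor_aten_mm, and inductor_triton_mm. It prints one semicolon-separated row per shape with each mean time in milliseconds, and calls torch._dynamo.reset() before moving to the next shape.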
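The printed table comes from plain print() calls: the shape prefix uses end="; " so the timings continue on the same line, and the four timings are joined with sep="; ". The sketch below reproduces that formatting for caller-supplied numbers and captures the output with io.StringIO. The helper name format_rows and the timings passed to it are placeholders for illustration, not measurements.

```python
import io
from contextlib import redirect_stdout


def format_rows(rows):
    """Reproduce test_total_time()'s print formatting for precomputed timings.

    rows: iterable of (a_shape, b_shape, (torch_ms, triton_ms, ind_aten_ms, ind_triton_ms)).
    Nothing is measured here; the timings are supplied by the caller.
    """
    buf = io.StringIO()
    with redirect_stdout(buf):
        print("shape; torch mm; triton mm; inductor aten mm; inductor triton mm")
        for a_shape, b_shape, timings in rows:
            # Default sep=" " puts spaces around "x"; end="; " keeps the row open.
            print(a_shape, "x", b_shape, end="; ")
            print(*timings, sep="; ")
    return buf.getvalue()


# Placeholder timings for illustration only.
out = format_rows([((128, 9216), (9216, 4096), (0.5, 0.25, 0.75, 0.125))])
print(out, end="")
# shape; torch mm; triton mm; inductor aten mm; inductor triton mm
# (128, 9216) x (9216, 4096); 0.5; 0.25; 0.75; 0.125
```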
What does test_total_time() call?
test_total_time() calls 3 functions: inductor_aten_mm, inductor_triton_mm, and time_with_torch_timer. It also passes torch_mm and triton_mm to time_with_torch_timer as the callables to be timed.
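time_with_torch_timer is defined elsewhere in inductor_mm.py and is not shown on this page. From its call sites, it takes a callable and an argument tuple and returns an object whose .mean the caller multiplies by 1000 to get milliseconds, so .mean is presumably in seconds. A stdlib-only stand-in with that shape might look like the sketch below; the names time_with_timeit and Measurement, and the warmup/repeat/number defaults, are assumptions. GPU kernels launch asynchronously, so a wall-clock stand-in like this one would need explicit device synchronization (for example torch.cuda.synchronize()) to time CUDA work accurately.

```python
import statistics
import timeit
from dataclasses import dataclass


@dataclass
class Measurement:
    """Hypothetical result object; `mean` is in seconds per call."""
    times: list

    @property
    def mean(self):
        return statistics.mean(self.times)


def time_with_timeit(fn, args, number=100, repeat=5, warmup=1):
    """Hypothetical stand-in for time_with_torch_timer(fn, args).

    Calls fn(*args) `warmup` times untimed, then runs `repeat` rounds of
    `number` calls each, recording the average seconds per call per round.
    """
    for _ in range(warmup):
        fn(*args)
    timer = timeit.Timer(lambda: fn(*args))
    totals = timer.repeat(repeat=repeat, number=number)
    return Measurement([t / number for t in totals])


# Usage mirroring the call sites: convert mean seconds to milliseconds.
calls = []
ms = time_with_timeit(lambda x, y: calls.append(x * y), (3, 4)).mean * 1000
print(f"{ms:.6f} ms per call")
```

Separating warm-up from the timed rounds keeps one-off costs (such as a first-call compile) out of the reported mean.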
What calls test_total_time()?
test_total_time() is called by one function: shapes.