test_total_time() — pytorch Function Reference

Architecture documentation for the test_total_time() function in inductor_cpu_atomic.py from the pytorch codebase.

Dependency Diagram

graph TD
  aba84ade_afa4_11e1_7f1a_d64880227dd0["test_total_time()"]
  4d642674_e009_46e9_ec64_9f93737955ce["shapes()"]
  4d642674_e009_46e9_ec64_9f93737955ce -->|calls| aba84ade_afa4_11e1_7f1a_d64880227dd0
  21baf9d9_d45b_1791_6ee6_1ff53576b4fb["torch_scatter_add()"]
  aba84ade_afa4_11e1_7f1a_d64880227dd0 -->|calls| 21baf9d9_d45b_1791_6ee6_1ff53576b4fb
  8ec02f26_c724_1f4d_66f2_3579ef8b7f89["inductor_scatter_add()"]
  aba84ade_afa4_11e1_7f1a_d64880227dd0 -->|calls| 8ec02f26_c724_1f4d_66f2_3579ef8b7f89
  fef5e6fa_eee5_e93a_0d8f_a4ef0477af02["time_with_torch_timer()"]
  aba84ade_afa4_11e1_7f1a_d64880227dd0 -->|calls| fef5e6fa_eee5_e93a_0d8f_a4ef0477af02
  style aba84ade_afa4_11e1_7f1a_d64880227dd0 fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

benchmarks/dynamo/microbenchmarks/inductor_cpu_atomic.py lines 18–56

def test_total_time(shapes, types):
    print(
        "shape; type; torch scatter_add; inductor scatter_add; torch scatter_add (worst case); inductor scatter_add (worst case)"
    )
    for shape, dtype in itertools.product(shapes, types):
        print(shape, dtype, sep="; ", end="; ")

        torch.manual_seed(1)
        if dtype.is_floating_point:
            src = torch.randn(shape, device="cpu", dtype=dtype)
            dst = torch.randn(shape, device="cpu", dtype=dtype)
        else:
            src = torch.randint(0, shape[1], shape, device="cpu", dtype=dtype)
            dst = torch.randint(0, shape[1], shape, device="cpu", dtype=dtype)
        index = torch.randint(0, shape[1], shape, device="cpu", dtype=torch.int64)
        worst_index = torch.tensor([[0] * shape[1]], device="cpu", dtype=torch.int64)

        torch_result = torch_scatter_add(dst, src, index)
        inductor_result = inductor_scatter_add(dst, src, index)
        torch.testing.assert_close(torch_result, inductor_result)

        torch_ms = (
            time_with_torch_timer(torch_scatter_add, (dst, src, index)).mean * 1000
        )
        inductor_ms = (
            time_with_torch_timer(inductor_scatter_add, (dst, src, index)).mean * 1000
        )
        torch_worst_ms = (
            time_with_torch_timer(torch_scatter_add, (dst, src, worst_index)).mean
            * 1000
        )
        inductor_worst_ms = (
            time_with_torch_timer(inductor_scatter_add, (dst, src, worst_index)).mean
            * 1000
        )

        print(torch_ms, inductor_ms, torch_worst_ms, inductor_worst_ms, sep="; ")

        torch._dynamo.reset()
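
To make the "worst case" columns concrete, here is a small pure-Python sketch of what scatter_add does when it accumulates along dim 1. The choice of dim 1 is an assumption (the bodies of torch_scatter_add() and inductor_scatter_add() are not shown in the listing above), and scatter_add_dim1 is a hypothetical helper written only for this illustration, not part of the pytorch codebase.

```python
def scatter_add_dim1(dst, src, index):
    """Reference semantics, assuming dim=1: out[i][index[i][j]] += src[i][j]."""
    out = [row[:] for row in dst]  # copy so the caller's dst is left unchanged
    for i, index_row in enumerate(index):
        for j, target in enumerate(index_row):
            out[i][target] += src[i][j]
    return out


dst = [[10, 20, 30, 40]]
src = [[1, 2, 3, 4]]

# Spread-out index, similar in spirit to the random `index` tensor.
spread_index = [[2, 0, 3, 0]]
# All zeros, mirroring `worst_index = [[0] * shape[1]]` in the listing.
worst_index = [[0] * 4]

spread = scatter_add_dim1(dst, src, spread_index)
collided = scatter_add_dim1(dst, src, worst_index)

print(spread)    # [[16, 20, 31, 43]]
print(collided)  # [[20, 20, 30, 40]]
```

With worst_index every source element is added into column 0. A parallel implementation then has to combine all updates to a single location, which is presumably why the benchmark times this case separately from the random index.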

Frequently Asked Questions

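What does test_total_time() do?
test_total_time() benchmarks scatter_add on CPU. For every combination of shape and dtype it builds seeded src, dst, and index tensors, checks with torch.testing.assert_close that torch_scatter_add() and inductor_scatter_add() return matching results, and then times both functions through time_with_torch_timer(), once with the random index and once with a worst-case index whose entries are all 0. It prints one semicolon-separated row per combination (each mean multiplied by 1000; the variable names indicate milliseconds) and calls torch._dynamo.reset() before moving on.
What does test_total_time() call?
test_total_time() calls three functions: inductor_scatter_add(), time_with_torch_timer(), and torch_scatter_add().
What calls test_total_time()?
The dependency diagram lists one caller: shapes().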
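
The measurement step of test_total_time() can be tried in isolation with the sketch below. It is a simplified illustration under stated assumptions, not the benchmark's actual helpers: torch_scatter_add and time_with_torch_timer here are stand-ins written for this example (the real definitions are not shown on this page), the shape [1, 4096] and the iteration count are hypothetical, and the Inductor-compiled variant is left out so the snippet runs without a C++ toolchain.

```python
import torch
from torch.utils.benchmark import Timer


def torch_scatter_add(dst, src, index):
    # Assumed stand-in: eager scatter_add along dim 1.
    return torch.scatter_add(dst, 1, index, src)


def time_with_torch_timer(fn, args, iters=20):
    # Assumed stand-in for the helper of the same name. Returns a
    # torch.utils.benchmark Measurement; its .mean is in seconds.
    timer = Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    return timer.timeit(iters)


shape = [1, 4096]  # hypothetical shape chosen for this example
torch.manual_seed(1)
src = torch.randn(shape, device="cpu")
dst = torch.randn(shape, device="cpu")
index = torch.randint(0, shape[1], shape, device="cpu", dtype=torch.int64)
worst_index = torch.zeros(shape, device="cpu", dtype=torch.int64)

# Convert mean seconds to milliseconds, as the original function does.
random_ms = time_with_torch_timer(torch_scatter_add, (dst, src, index)).mean * 1000
worst_ms = time_with_torch_timer(torch_scatter_add, (dst, src, worst_index)).mean * 1000

print(shape, dst.dtype, random_ms, worst_ms, sep="; ")
```

The printed row follows the same semicolon-separated layout as the original output, restricted to the two eager columns. Absolute timings depend on the machine and thread settings, so they are best read as relative comparisons.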
