bench() — pytorch Function Reference

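Architecture documentation for the bench() function in bench_mm_fusion.py from the PyTorch codebase. bench() measures float16 matrix-multiply throughput for one layer shape under two TorchInductor settings: mm lowered to ATen (torch._inductor.config.triton.mm = "aten") and mm generated as a Triton kernel ("triton"). For each entry in fusion_types it runs the matching Func variant (Func.mm for the empty string, otherwise Func.mm_<fusion_type>, with a random (M, N) bias when the name contains "add"). It appends both throughputs to a results row and raises an AssertionError unless the Triton variant compiles to exactly one generated kernel. The finished row is added to the table p.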

Source Code

benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py lines 47–93

def bench(shape, layer_id, p, fusion_types=None):
    torch._logging.set_logs(inductor_metrics=True)
    if fusion_types is None:
        fusion_types = [""]
    dtype = torch.float16
    M, K = shape[0]
    _, N = shape[1]
    torch.manual_seed(0)
    # allocate inputs
    a = torch.randn(shape[0], device="cuda", dtype=dtype)
    b = torch.randn(shape[1], device="cuda", dtype=dtype)

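    def tflops(ms):
        # M * K * N multiply-accumulates per matmul; dividing by ms and scaling by
        # 1e-9 gives tera-operations per second (ms -> s is 1e-3, tera is 1e12).
        # Counting a multiply and an add as two FLOPs would double these values.
        return M * K * N / ms * 1e-9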

    row = [layer_id]
    for fusion_type in fusion_types:
        if fusion_type == "":
            fn_mm = Func.mm
        else:
            fn_mm = getattr(Func, f"mm_{fusion_type}")

        if "add" in fusion_type:
            bias = torch.randn((M, N), dtype=dtype, device="cuda")
        else:
            bias = None

        args = (a, b, bias)

        def fn():
            return fn_mm(*args)

        torch._inductor.config.triton.mm = "aten"
        torch_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
        torch._inductor.config.triton.mm = "triton"
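        # reset Dynamo so the next call recompiles with the Triton mm backend,
        # and clear Inductor metrics so only that compilation is counted
        torch._dynamo.reset()
        torch._inductor.metrics.reset()
        triton_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)
        # the matmul and any fused epilogue should compile to a single kernel
        if torch._inductor.metrics.generated_kernel_count != 1: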
            raise AssertionError(
                f"Expected 1 generated kernel, but got {torch._inductor.metrics.generated_kernel_count}"
            )
        row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)])

    p.add_row(row)
    torch._logging.set_logs()
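
The tflops() helper converts a kernel time in milliseconds into throughput. The standalone sketch below isolates that arithmetic so it can be checked without a GPU. The function name matmul_tflops and its flops_per_mac parameter are hypothetical, added only for illustration: with flops_per_mac=1 it reproduces bench()'s M * K * N / ms * 1e-9, and with flops_per_mac=2 it follows the common convention of counting a multiply and an add as separate FLOPs.

```python
def matmul_tflops(M: int, K: int, N: int, ms: float, flops_per_mac: int = 1) -> float:
    """Tera-operations per second for an (M x K) @ (K x N) matmul taking `ms` milliseconds.

    Hypothetical helper: with flops_per_mac=1 it equals bench()'s
    tflops(ms) = M * K * N / ms * 1e-9.
    """
    ops = flops_per_mac * M * K * N  # operations in one matmul
    seconds = ms * 1e-3              # milliseconds -> seconds
    return ops / seconds / 1e12      # operations per second -> tera-operations per second


# A 4096 x 4096 x 4096 matmul finishing in 1.0 ms
print(round(matmul_tflops(4096, 4096, 4096, 1.0), 2))                   # 68.72
print(round(matmul_tflops(4096, 4096, 4096, 1.0, flops_per_mac=2), 2))  # 137.44
```

Because bench() uses the one-operation-per-multiply-accumulate count, its reported numbers are half of what a 2 * M * N * K FLOP count would give for the same timing.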

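bench() builds one table row per layer: the layer id followed by an (ATen, Triton) throughput pair for each fusion type, with the compiled function chosen by name (Func.mm for "", otherwise Func.mm_<fusion_type>). The pure-Python sketch below mimics only that control flow and row layout so it runs without CUDA or PyTorch. FakeFunc, build_row, and the injected time_ms callable are hypothetical stand-ins for Func, benchmarker.benchmark_gpu, and the two torch._inductor.config.triton.mm settings; the Dynamo reset and the single-kernel check are not modeled, and the timings in the usage lines are made-up inputs, not measurements.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Shape = Tuple[Tuple[int, int], Tuple[int, int]]


class FakeFunc:
    """Hypothetical stand-in for Func in bench_mm_fusion.py; performs no real matmul."""

    @staticmethod
    def mm(a, b, bias):
        return None

    @staticmethod
    def mm_add(a, b, bias):
        return None


def build_row(
    shape: Shape,
    layer_id: str,
    time_ms: Callable[[Callable, str], float],
    fusion_types: Optional[Sequence[str]] = None,
) -> List[object]:
    """Return [layer_id, aten_tflops, triton_tflops, ...], one pair per fusion type."""
    if fusion_types is None:
        fusion_types = [""]
    (M, K), (_, N) = shape

    def tflops(ms: float) -> float:
        return M * K * N / ms * 1e-9

    row: List[object] = [layer_id]
    for fusion_type in fusion_types:
        # Same name-based dispatch as bench(): "" -> mm, "add" -> mm_add, ...
        if fusion_type == "":
            fn_mm = FakeFunc.mm
        else:
            fn_mm = getattr(FakeFunc, f"mm_{fusion_type}")
        aten_ms = time_ms(fn_mm, "aten")      # stands in for the triton.mm = "aten" run
        triton_ms = time_ms(fn_mm, "triton")  # stands in for the triton.mm = "triton" run
        row.extend([tflops(aten_ms), tflops(triton_ms)])
    return row


# Usage with made-up timings (not measurements): the Triton run is twice as fast.
fake_ms = {"aten": 2.0, "triton": 1.0}
row = build_row(((1024, 1024), (1024, 1024)), "layer0",
                lambda fn, backend: fake_ms[backend], ["", "add"])
print(len(row))  # 5
```

An unknown fusion type fails at the getattr lookup with an AttributeError, which matches how bench() would behave for a name with no corresponding Func.mm_<fusion_type> method.
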