bench() — PyTorch Function Reference
Architecture documentation for the bench() function in bench_mm_fusion.py from the PyTorch codebase.
Entity Profile
Relationship Graph
Source Code
benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py lines 47–93
def bench(shape, layer_id, p, fusion_types=None):
    """Benchmark aten vs. triton matmul lowering for one layer's shapes.

    For each entry in ``fusion_types`` (``""`` means a plain matmul,
    otherwise ``mm_<fusion_type>`` is looked up on ``Func``), times the
    same callable twice — once with inductor forced to lower to aten,
    once with the triton template — and appends one table row
    ``[layer_id, aten_tflops, triton_tflops, ...]`` to ``p``.

    Args:
        shape: pair of (M, K) and (K, N) operand shapes.
        layer_id: label placed in the first column of the row.
        p: table-like object with an ``add_row`` method.
        fusion_types: optional list of fusion suffixes; defaults to [""].

    Raises:
        AssertionError: if the triton pass generates more than one kernel.
    """
    # Enable inductor metrics so generated_kernel_count is tracked below.
    torch._logging.set_logs(inductor_metrics=True)
    fusion_types = [""] if fusion_types is None else fusion_types
    dtype = torch.float16
    M, K = shape[0]
    _, N = shape[1]
    torch.manual_seed(0)
    # Random fp16 operands on the GPU for the matmul under test.
    a = torch.randn(shape[0], device="cuda", dtype=dtype)
    b = torch.randn(shape[1], device="cuda", dtype=dtype)

    def tflops(ms):
        # Convert a runtime in milliseconds to the table's throughput
        # metric: M*K*N ops per second, in tera-ops.
        # NOTE(review): this counts one op per multiply-accumulate
        # (not 2*M*K*N) — confirm that is the intended convention.
        return M * K * N / ms * 1e-9

    row = [layer_id]
    for fusion_type in fusion_types:
        fn_mm = Func.mm if fusion_type == "" else getattr(Func, f"mm_{fusion_type}")
        # Fused-add variants take a bias operand; the others pass None.
        if "add" in fusion_type:
            bias = torch.randn((M, N), dtype=dtype, device="cuda")
        else:
            bias = None

        def fn():
            return fn_mm(a, b, bias)

        # Pass 1: force inductor to lower the matmul to an aten call.
        torch._inductor.config.triton.mm = "aten"
        torch_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)

        # Pass 2: use the triton matmul template instead.
        torch._inductor.config.triton.mm = "triton"
        # Reset dynamo/metrics so fresh code is generated and counted.
        torch._dynamo.reset()
        torch._inductor.metrics.reset()
        triton_mm_ms, _, _ = benchmarker.benchmark_gpu(fn)

        # The fused op should compile to exactly one triton kernel.
        if torch._inductor.metrics.generated_kernel_count != 1:
            raise AssertionError(
                f"Expected 1 generated kernel, but got {torch._inductor.metrics.generated_kernel_count}"
            )
        row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)])

    p.add_row(row)
    # Restore default logging state.
    torch._logging.set_logs()
Domain
Subdomains
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free