Home / Class / BenchmarkKernel Class — pytorch Architecture

BenchmarkKernel Class — pytorch Architecture

Architecture documentation for the BenchmarkKernel class in utils.py from the pytorch codebase.

Entity Profile

Relationship Graph

Source Code

benchmarks/dynamo/genai_layers/utils.py lines 47–225

class BenchmarkKernel:
    """Base class for benchmarking one GenAI-layer kernel across backends.

    Subclasses declare which backends they support via ``available_backends``
    and implement one method per backend (``eager``, ``compiled``, ``helion``,
    ``quack``, ``liger``, ``triton``). Each backend method takes ``(args,
    kwargs)`` and returns a zero-argument callable that runs the kernel (see
    how ``check_accuracy`` invokes the result with a trailing ``()``).
    Timing results accumulate in ``self.profiling_results`` keyed by backend.
    """

    def __init__(self, script_args):
        self.script_args = script_args
        # Report under the concrete subclass's name.
        self.name = self.__class__.__name__
        # Backends this kernel supports; pruned at runtime when one fails.
        self.available_backends: list[str] = []
        self.compile_mode: str = script_args.compile_mode

        # mapping from backend to list of performance results
        self.profiling_results: defaultdict[str, list[Performance]] = defaultdict(list)

    def get_memory_bytes(self, args, kwargs) -> int:
        """Return the memory traffic (bytes read + written) for these inputs."""
        raise NotImplementedError

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        """Return the input shapes to benchmark the kernel with."""
        raise NotImplementedError

    def eager(self, args, kwargs) -> Any:
        raise NotImplementedError

    def compiled(self, args, kwargs) -> Any:
        raise NotImplementedError

    def helion(self, args, kwargs) -> Any:
        raise NotImplementedError

    def quack(self, args, kwargs) -> Any:
        raise NotImplementedError

    def liger(self, args, kwargs) -> Any:
        raise NotImplementedError

    def triton(self, args, kwargs) -> Any:
        raise NotImplementedError

    def benchmark(self):
        raise NotImplementedError

    def clone_inputs(self, args, kwargs) -> Any:
        """Deep-copy tensor inputs so each backend sees fresh, independent tensors.

        Tensors are detached from any autograd history, cloned, and re-marked
        with the original ``requires_grad`` flag; non-tensor kwargs pass
        through unchanged. Returns ``(args_ref, kwargs_ref)``; ``kwargs_ref``
        is the original falsy value (e.g. ``None``) when ``kwargs`` is empty.
        """
        # detach().clone() (not clone().detach()) avoids recording the clone
        # op on the autograd graph; the resulting leaf tensor is identical.
        args_ref = [
            arg.detach().clone().requires_grad_(arg.requires_grad) for arg in args
        ]

        kwargs_ref = (
            {
                k: (
                    v.detach().clone().requires_grad_(v.requires_grad)
                    if isinstance(v, torch.Tensor)
                    else v
                )
                for k, v in kwargs.items()
            }
            if kwargs
            else kwargs
        )

        return args_ref, kwargs_ref

    def check_accuracy(self, args, kwargs) -> None:
        """Run every available backend and compare its output against eager.

        Also compares ``.grad`` on outputs that require grad (assumes each
        backend result is an iterable of tensors — TODO confirm for all
        subclasses). Exits the process on mismatch when
        ``--exit-on-accuracy-failure`` is set.
        """
        res = {}
        for backend in self.available_backends:
            args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
            # Backend methods return a thunk; call it to produce the output.
            res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()

        if (
            "compiled" in self.available_backends
            and self.script_args.custom_compile_options
        ):
            torch._dynamo.reset()  # cause recompile
            with torch._inductor.config.patch(self.script_args.custom_compile_options):
                args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
                res[self.script_args.custom_compile_name] = self.compiled(
                    args_ref, kwargs_ref
                )()

        # Eager is the golden reference; raises KeyError if it is unavailable.
        gold = res["eager"]

        tol = {}
        if self.script_args.tolerance:
            # Use the same user-supplied tolerance for both atol and rtol.
            tol = {
                "atol": self.script_args.tolerance,
                "rtol": self.script_args.tolerance,
            }
        for backend in res:
            if backend == "eager":
                continue
            try:
                torch.testing.assert_close(res[backend], gold, **tol)
                for t, gold_t in zip(res[backend], gold):
                    if t.requires_grad:
                        torch.testing.assert_close(t.grad, gold_t.grad, **tol)
                print(
                    f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel"
                )
            except Exception as e:
                # Best-effort reporting: record the failure and keep checking
                # the remaining backends unless the user asked to bail out.
                print(
                    f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
                )
                if self.script_args.exit_on_accuracy_failure:
                    print("Exit right away since --exit-on-accuracy-failure is set")
                    sys.exit(1)

    def benchmark_single_shape_for_backend(
        self, backend, args, kwargs, setting, fn=None
    ) -> bool:
        """Time one backend on one input shape; record a Performance entry.

        Returns True on success. On failure the backend is removed from
        ``self.available_backends`` so later shapes skip it, and False is
        returned.
        """
        if fn is None:
            fn = getattr(self, backend)
        args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
        try:
            avg_time = benchmark_kernel_in_milliseconds(fn(args_ref, kwargs_ref))
        except Exception as e:
            print(
                f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}"
            )
            # Prune the failing backend; callers must iterate a snapshot of
            # available_backends since this mutates the list.
            self.available_backends.remove(backend)
            return False
        mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref)
        perf = Performance(setting, avg_time, mem_bytes)
        print(f"{self.name} kernel on {backend} backend. {perf}")
        self.profiling_results[backend].append(perf)
        return True

    def benchmark_single_shape(
        self, args, kwargs=None, should_check_accuracy=True, setting: str = ""
    ):
        """Benchmark every available backend on one shape, then optionally
        cross-check accuracy against eager."""
        # Iterate over a snapshot: benchmark_single_shape_for_backend removes
        # a backend from available_backends on failure, and mutating the list
        # while iterating it would silently skip the next backend.
        for backend in list(self.available_backends):
            self.benchmark_single_shape_for_backend(backend, args, kwargs, setting)
        if (
            "compiled" in self.available_backends
            and self.script_args.custom_compile_options
        ):
            torch._dynamo.reset()  # cause recompile
            with torch._inductor.config.patch(self.script_args.custom_compile_options):
                status = self.benchmark_single_shape_for_backend(
                    self.script_args.custom_compile_name,
                    args,
                    kwargs,
                    setting,
                    fn=self.compiled,
                )
            if not status:
                self.script_args.custom_compile_options = (
                    None  # once fail, don't run again
                )

        if should_check_accuracy:
            self.check_accuracy(args, kwargs)

    def visualize(self) -> None:
        """Render a comparison chart of all recorded profiling results."""
        device_name = torch.cuda.get_device_name(0)
        visualize_comparison(
            self.profiling_results,
            title=f"{self.name} ({device_name})",
            output_path=f"{self.name}_bench",
        )

    def report_geomean_speedup(self) -> None:
        """Print each backend's geometric-mean speedup over eager.

        Matches results to the eager baseline by their ``setting`` string; a
        zero backend latency contributes 0.0 instead of dividing by zero.
        """
        print(f"Geomean speedup for benchmark {self.name}")
        eager_result = {
            result.setting: result for result in self.profiling_results["eager"]
        }
        print(f"  eager {len(eager_result)} data points")
        for backend, backend_result in self.profiling_results.items():
            if backend == "eager":
                continue
            speeduplist = []
            for result in backend_result:
                eager_latency = eager_result[result.setting].latency
                backend_latency = result.latency
                speeduplist.append(
                    eager_latency / backend_latency if backend_latency != 0 else 0.0
                )

            if len(speeduplist) > 0:
                print(
                    f"  {backend} {len(speeduplist)} data points, {gmean(speeduplist):.2f}x speedup"
                )

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free