Home / Class / CrossEntropyForward Class — PyTorch Architecture

CrossEntropyForward Class — PyTorch Architecture

Architecture documentation for the CrossEntropyForward class in kernels.py from the PyTorch codebase.

Entity Profile

Relationship Graph

Source Code

benchmarks/dynamo/genai_layers/kernels.py lines 20–118

class CrossEntropyForward(BenchmarkKernel):
    """Benchmark the forward pass of cross-entropy loss (reduction="none").

    Compares four backends: plain eager PyTorch, torch.compile, quack's
    CUTLASS-based kernel, and Liger's Triton kernel. Each backend method
    validates its inputs and returns a zero-argument callable so the
    harness can time just the kernel invocation.
    """

    def __init__(self, script_args):
        super().__init__(script_args)
        # Backends exercised by benchmark() and compared in check_accuracy().
        self.available_backends = ["eager", "compiled", "quack", "liger"]

    def get_shapes(self) -> tuple[tuple[int, ...], ...]:
        """Return benchmark problem sizes as (M, N): M samples, N classes."""
        return (
            (32768, 256),
            (32768, 512),
            (32768, 1024),
            (32768, 2048),
            (32768, 4096),
            (32768, 8192),
            (32768, 16384),
            (32768, 32768),
            (32768, 65536),
            (16384, 131072),
            (8192, 262144),
        )

    def get_memory_bytes(self, args, kwargs) -> int:
        """Return the ideal bytes of memory traffic for one forward pass.

        Read x (M*N elements of x.dtype) + read target (M elements of
        target.dtype — int64 in benchmark()) + write loss (M elements,
        assumed to share x's itemsize).

        Fix: the original charged the target read at x.dtype.itemsize,
        undercounting it when target is int64 (8 B) and x is bf16 (2 B);
        each tensor is now counted at its own dtype width.
        """
        x, target = args
        M, N = x.shape
        return (
            M * N * x.dtype.itemsize  # read logits
            + M * target.dtype.itemsize  # read target indices
            + M * x.dtype.itemsize  # write per-row loss
        )

    def eager(self, args, kwargs=None) -> Any:
        """Plain PyTorch baseline: F.cross_entropy with reduction="none"."""
        if kwargs is not None:
            raise AssertionError(f"Expected kwargs to be None, but got {kwargs}")
        x, target = args
        return lambda: F.cross_entropy(x, target, reduction="none")

    def compiled(self, args, kwargs=None) -> Any:
        """torch.compile'd cross-entropy with a dynamic batch dimension."""
        if kwargs is not None:
            raise AssertionError(f"Expected kwargs to be None, but got {kwargs}")
        x, target = args

        # Mark batch size as dynamic for realistic workload
        torch._dynamo.mark_dynamic(x, 0)
        torch._dynamo.mark_dynamic(target, 0)

        # Need `lambda` otherwise torch.compile will not trace the function.
        # More discussion: https://github.com/pytorch/pytorch/issues/158455
        compiled_cross_entropy = torch.compile(
            lambda x, target: F.cross_entropy(x, target, reduction="none"),
            mode=self.compile_mode,
            fullgraph=True,
        )
        return lambda: compiled_cross_entropy(x, target)

    def quack(self, args, kwargs=None) -> Any:
        """quack backend; its loss output is float32 (see check_accuracy)."""
        if kwargs is not None:
            raise AssertionError(f"Expected kwargs to be None, but got {kwargs}")
        x, target = args
        # Imported lazily so the other backends work without quack installed.
        from quack.cross_entropy import _cross_entropy

        return lambda: _cross_entropy(x, target)

    def liger(self, args, kwargs=None) -> Any:
        """Liger kernel backend via LigerCrossEntropyLoss(reduction="none")."""
        if kwargs is not None:
            raise AssertionError(f"Expected kwargs to be None, but got {kwargs}")
        # Imported lazily so the other backends work without liger installed.
        from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

        x, target = args
        cross_entropy = LigerCrossEntropyLoss(reduction="none")
        return lambda: cross_entropy(x, target)

    def benchmark(self):
        """Run every shape from get_shapes() through the benchmark harness."""
        for M, N in self.get_shapes():
            print(f"\n Tensor dimensions: [{M}, {N}]")
            # quack requires cutlass dtype
            torch_dtype = cutlass_torch.dtype(cutlass.BFloat16)
            x = 0.1 * torch.randn(M, N, device="cuda", dtype=torch_dtype)
            target = torch.randint(0, N, (M,), device="cuda", dtype=torch.int64)
            self.benchmark_single_shape((x, target), setting=f"shape: [{M}, {N}]")

    def check_accuracy(self, args, kwargs) -> None:
        """Compare every backend's output against the eager result.

        Each backend runs on its own clone of the inputs (so e.g. compiled's
        mark_dynamic cannot leak between backends); failures are printed,
        not raised.
        """
        res = {}
        for backend in self.available_backends:
            args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
            res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()
        gold = res["eager"]
        for backend in self.available_backends:
            if backend == "eager":
                continue
            if backend == "quack":
                # quack's cross_entropy only returns float32 loss output.
                # Need to convert it to the same dtype as gold for comparison.
                res[backend] = res[backend].to(gold.dtype)
            try:
                torch.testing.assert_close(res[backend], gold)
                print(
                    f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel"
                )
            except Exception as e:
                print(
                    f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
                )

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free