check_accuracy() — pytorch Function Reference

Architecture documentation for the check_accuracy() method in benchmarks/dynamo/genai_layers/utils.py from the pytorch codebase.

Dependency Diagram

graph TD
  bf14c1d3_0a90_a5c6_58ff_3e6211ae65ce["check_accuracy()"]
  cf57d9a8_fc49_42e9_c1eb_53ac4b8190b5["check_accuracy()"]
  cf57d9a8_fc49_42e9_c1eb_53ac4b8190b5 -->|calls| bf14c1d3_0a90_a5c6_58ff_3e6211ae65ce
  38cd02c9_bda5_f921_f4c9_f7772e1a880b["benchmark_single_shape()"]
  38cd02c9_bda5_f921_f4c9_f7772e1a880b -->|calls| bf14c1d3_0a90_a5c6_58ff_3e6211ae65ce
  d4938e53_a1f5_c58c_a1e1_efc656eb80f1["clone_inputs()"]
  bf14c1d3_0a90_a5c6_58ff_3e6211ae65ce -->|calls| d4938e53_a1f5_c58c_a1e1_efc656eb80f1
  dfaa8024_94e8_ed55_0a03_ab4c6f879ac0["compiled()"]
  bf14c1d3_0a90_a5c6_58ff_3e6211ae65ce -->|calls| dfaa8024_94e8_ed55_0a03_ab4c6f879ac0
  style bf14c1d3_0a90_a5c6_58ff_3e6211ae65ce fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

benchmarks/dynamo/genai_layers/utils.py lines 106–148

    def check_accuracy(self, args, kwargs) -> None:
        res = {}
        for backend in self.available_backends:
            args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
            res[backend] = getattr(self, backend)(args_ref, kwargs_ref)()

        if (
            "compiled" in self.available_backends
            and self.script_args.custom_compile_options
        ):
            torch._dynamo.reset()  # cause recompile
            with torch._inductor.config.patch(self.script_args.custom_compile_options):
                args_ref, kwargs_ref = self.clone_inputs(args, kwargs)
                res[self.script_args.custom_compile_name] = self.compiled(
                    args_ref, kwargs_ref
                )()

        gold = res["eager"]

        tol = {}
        if self.script_args.tolerance:
            tol = {
                "atol": self.script_args.tolerance,
                "rtol": self.script_args.tolerance,
            }
        for backend in res:
            if backend == "eager":
                continue
            try:
                torch.testing.assert_close(res[backend], gold, **tol)
                for t, gold_t in zip(res[backend], gold):
                    if t.requires_grad:
                        torch.testing.assert_close(t.grad, gold_t.grad, **tol)
                print(
                    f"Accuracy check \033[92m✓ succeed\033[0m for {backend} backend on {self.name} kernel"
                )
            except Exception as e:
                print(
                    f"Accuracy check \033[91m✗ failed\033[0m for {backend} backend on {self.name} kernel. Error {e}"
                )
                if self.script_args.exit_on_accuracy_failure:
                    print("Exit right away since --exit-on-accuracy-failure is set")
                    sys.exit(1)
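
The compare-against-gold pattern in the listing can be sketched without PyTorch. The snippet below is only an illustration, not the benchmark's implementation: `check_against_gold`, `is_close`, the toy backends, and the `1e-6` fallback tolerance are all hypothetical, and plain Python lists stand in for tensors. It applies the inequality documented for `torch.isclose`, `|actual - expected| <= atol + rtol * |expected|`, with one tolerance value used for both `atol` and `rtol`, as the listing does when `script_args.tolerance` is set.

```python
from copy import deepcopy


def is_close(actual, expected, rtol, atol):
    # Inequality documented for torch.isclose:
    # |actual - expected| <= atol + rtol * |expected|
    return abs(actual - expected) <= atol + rtol * abs(expected)


def check_against_gold(backends, inputs, tolerance=None):
    """Run each backend on its own copy of the inputs and compare to "eager"."""
    # Hypothetical fallback: the real code passes no tolerances at all in this
    # case and lets torch.testing.assert_close pick dtype-based defaults.
    tol = tolerance if tolerance else 1e-6

    # Each backend gets a fresh copy so in-place changes cannot leak across runs.
    results = {name: fn(deepcopy(inputs)) for name, fn in backends.items()}
    gold = results["eager"]

    report = {}
    for name, res in results.items():
        if name == "eager":
            continue  # the reference is never compared against itself
        report[name] = len(res) == len(gold) and all(
            is_close(a, e, rtol=tol, atol=tol) for a, e in zip(res, gold)
        )
    return report


# Toy backends: "compiled" differs by rounding-level noise, "broken" by a real error.
backends = {
    "eager": lambda xs: [x * 2.0 for x in xs],
    "compiled": lambda xs: [x * 2.0 + 1e-9 for x in xs],
    "broken": lambda xs: [x * 2.0 + 0.5 for x in xs],
}
report = check_against_gold(backends, [1.0, 2.0, 3.0], tolerance=1e-3)
print(report)
```

Returning a per-backend report instead of raising keeps the sketch easy to inspect. The original instead catches the exception from `torch.testing.assert_close`, prints a pass or fail line, and exits only when `--exit-on-accuracy-failure` is set.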

Frequently Asked Questions

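What does check_accuracy() do?
check_accuracy() runs every backend listed in self.available_backends on cloned copies of the inputs and compares each non-eager result against the "eager" result with torch.testing.assert_close. For each element of a result that requires grad, it also compares the .grad values. If "compiled" is available and script_args.custom_compile_options is set, it resets Dynamo, reruns the compiled backend under those Inductor config options, and stores that result under script_args.custom_compile_name. When script_args.tolerance is set, the value is used as both atol and rtol; otherwise assert_close uses its defaults. The method prints a pass or fail line per backend and calls sys.exit(1) on a failure if script_args.exit_on_accuracy_failure is set.

What does check_accuracy() call?
In the dependency diagram, check_accuracy() calls 2 functions: clone_inputs() and compiled().

What calls check_accuracy()?
check_accuracy() is called by 2 functions: benchmark_single_shape() and a separate check_accuracy() entity that shares its name (a distinct node in the dependency diagram).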
