check_tolerance() — pytorch Function Reference

Architecture documentation for the check_tolerance() method in benchmarks/dynamo/common.py from the PyTorch codebase.

Dependency Diagram

graph TD
  0f41377a_e71d_fd3c_1974_a0ef9ec1158e["check_tolerance()"]
  9bf8449e_2d7f_c370_514b_b3c7bf20f8e1["run_one_model()"]
  9bf8449e_2d7f_c370_514b_b3c7bf20f8e1 -->|calls| 0f41377a_e71d_fd3c_1974_a0ef9ec1158e
  da0c865a_ac14_7a10_8fc5_8a3b7509426d["maybe_cast()"]
  0f41377a_e71d_fd3c_1974_a0ef9ec1158e -->|calls| da0c865a_ac14_7a10_8fc5_8a3b7509426d
  6c83aab9_f1ee_6751_91aa_682a715a5746["init_optimizer()"]
  0f41377a_e71d_fd3c_1974_a0ef9ec1158e -->|calls| 6c83aab9_f1ee_6751_91aa_682a715a5746
  316b98a5_2d78_5681_cc76_6024cfcb4191["run_n_iterations()"]
  0f41377a_e71d_fd3c_1974_a0ef9ec1158e -->|calls| 316b98a5_2d78_5681_cc76_6024cfcb4191
  3473d1a5_c1f5_fc97_006e_79a1d3081bef["write_outputs()"]
  0f41377a_e71d_fd3c_1974_a0ef9ec1158e -->|calls| 3473d1a5_c1f5_fc97_006e_79a1d3081bef
  style 0f41377a_e71d_fd3c_1974_a0ef9ec1158e fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

benchmarks/dynamo/common.py lines 2506–2581

    def check_tolerance(
        self, name, model, example_inputs, optimize_ctx, base_device="cpu"
    ):
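        """
        Compares eager and TorchDynamo results and records tolerance statistics
        based on https://pytorch.org/docs/stable/generated/torch.allclose.html.

        For each element, |ref - res| / (1 + |ref|) is the smallest t for which
        torch.allclose(res, ref, rtol=t, atol=t) would accept that element.
        """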
        tolerance_status = "pass"
        if name in self.skip_accuracy_checks_large_models_dashboard:
            tolerance_status = "pass_due_to_skip"
            return tolerance_status
        # Cast the model to float16/float32 as necessary
        model, example_inputs = self.maybe_cast(model, example_inputs)

        with self.pick_grad(name, self.args.training):
            # Get results of native pytorch
            reset_rng_state()
            model_copy = copy.deepcopy(model)
            model_copy = model_copy.to(base_device)
            example_inputs_copy = copy.deepcopy(example_inputs)
            example_inputs_copy = tree_map(
                lambda x: x.to(base_device), example_inputs_copy
            )
            self.init_optimizer(name, base_device, model_copy.parameters())
            correct_result = self.run_n_iterations(
                model_copy, example_inputs_copy, self.model_iter_fn
            )

            # Run with Dynamo
            # CI sometimes hits random Triton compilation failures; for now they are
            # caught below and reported as "fail_to_run" rather than aborting the run.
            # TODO: revisit this after switching to new Triton runtime
            reset_rng_state()
            torch._dynamo.reset()
            try:
                self.init_optimizer(name, current_device, model.parameters())
                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
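                # Run the compiled path on the original model, whose parameters the
                # optimizer above was initialized with; model_copy already holds the
                # eager run's state on base_device.
                new_result = self.run_n_iterations(
                    model, example_inputs, optimized_model_iter_fn
                )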
            except Exception:
                print(
                    "TorchDynamo optimized model failed to run because of the following error:"
                )
                log.exception("")
                return "fail_to_run"

            def dump_max_mean_values(tol, ref, res):
                if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
                    for refi, resi in zip(ref, res):
                        dump_max_mean_values(tol, refi, resi)
                elif isinstance(ref, dict):
                    for k in ref:
                        dump_max_mean_values(tol, ref[k], res[k])
                elif isinstance(ref, torch.Tensor):
                    res = res.to(base_device)
                    t = torch.abs(ref - res) / (1 + torch.abs(ref))
                    tol.append(t.flatten().to(torch.float32))
                return tol

            tol = []
            dump_max_mean_values(tol, correct_result, new_result)
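            # One flat float32 tensor of per-element errors.
            tol = torch.cat(tol)
            max_tol = torch.max(tol)
            mean_tol = torch.mean(tol)
            std_tol = torch.std(tol)
            headers = ["dev", "name", "batch_size", "max", "mean", "std"]
            # current_device, current_name, current_batch_size and output_filename
            # are module-level globals defined in common.py.
            fields = [
                current_device,
                current_name,
                current_batch_size,
                max_tol.item(),
                mean_tol.item(),
                std_tol.item(),
            ]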
            write_outputs(output_filename, headers, fields)
        return tolerance_status
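
The tolerance statistics reduce to an elementwise error, |ref - res| / (1 + |ref|), collected over every tensor in a possibly nested output and summarized by max, mean, and standard deviation. The sketch below reproduces only that arithmetic in plain Python so it runs without PyTorch. It assumes plain numbers can stand in for tensor leaves, and the helper names collect_errors() and summarize() are hypothetical, not part of common.py. statistics.stdev divides by n - 1, which matches the default correction of torch.std.

```python
import statistics
from typing import Any


def collect_errors(ref: Any, res: Any, out: list[float]) -> list[float]:
    """Append |ref - res| / (1 + |ref|) for every scalar leaf of ref/res.

    Hypothetical helper: plain numbers stand in for torch.Tensor leaves, and
    lists, tuples and dicts mirror the containers dump_max_mean_values() walks.
    """
    if isinstance(ref, (list, tuple)):
        for ref_item, res_item in zip(ref, res):
            collect_errors(ref_item, res_item, out)
    elif isinstance(ref, dict):
        for key in ref:
            collect_errors(ref[key], res[key], out)
    elif isinstance(ref, (int, float)):
        # Scaling by (1 + |ref|) gives an absolute error near zero and
        # approximately a relative error when |ref| is large.
        out.append(abs(ref - res) / (1 + abs(ref)))
    return out


def summarize(ref: Any, res: Any) -> tuple[float, float, float]:
    """Return (max, mean, sample standard deviation) of the per-element errors."""
    errors = collect_errors(ref, res, [])
    # statistics.stdev divides by n - 1, like torch.std's default correction.
    return max(errors), statistics.mean(errors), statistics.stdev(errors)


reference = {"logits": [1.0, -2.0, 0.0], "loss": 3.0}
optimized = {"logits": [1.0, -2.003, 0.0001], "loss": 3.0}
max_err, mean_err, std_err = summarize(reference, optimized)
print(f"max={max_err:.6f} mean={mean_err:.6f} std={std_err:.6f}")
```

Here the -2.0 vs -2.003 pair dominates: its absolute error of 0.003 is scaled by 1 + 2 = 3, giving a maximum of about 0.001.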

Frequently Asked Questions

What does check_tolerance() do?
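check_tolerance() is a method of the benchmark runner in benchmarks/dynamo/common.py. It runs a model through run_n_iterations() twice, once in eager PyTorch on base_device (CPU by default) and once through optimize_ctx (TorchDynamo), then writes the max, mean, and standard deviation of the per-element error |ref - res| / (1 + |ref|) to output_filename via write_outputs(). It returns "pass_due_to_skip" for models in skip_accuracy_checks_large_models_dashboard, "fail_to_run" if the optimized run raises, and "pass" otherwise, regardless of the error values.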
What does check_tolerance() call?
check_tolerance() calls 4 functions: init_optimizer, maybe_cast, run_n_iterations, write_outputs.
What calls check_tolerance()?
check_tolerance() is called by 1 function: run_one_model.
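
Stripped of PyTorch specifics, check_tolerance() follows a differential-testing pattern: reseed the RNG, run an eager reference, reseed again, run the optimized path, and record error statistics. The sketch below mimics that control flow in plain Python. check_tolerance_sketch(), reference(), optimized() and broken() are hypothetical stand-ins. random.seed() plays the role of reset_rng_state(), torch._dynamo.reset() is not modeled, and a reversed summation order is an assumed stand-in for a compiler reordering floating-point operations. As in the source, a large error does not change the returned status; the statistic is only recorded.

```python
import random
from typing import Callable


def check_tolerance_sketch(
    name: str,
    reference_fn: Callable[[], list[float]],
    optimized_fn: Callable[[], list[float]],
    skip: set[str],
    report: list[tuple[str, float]],
) -> str:
    # Models in the skip set are reported as passing without being run.
    if name in skip:
        return "pass_due_to_skip"

    # Reseed before each run so both paths draw the same random numbers,
    # mirroring the reset_rng_state() calls in check_tolerance().
    random.seed(1337)
    expected = reference_fn()

    random.seed(1337)
    try:
        actual = optimized_fn()
    except Exception:
        return "fail_to_run"

    # Record the worst-case error; it does not affect the returned status.
    worst = max(abs(e - a) / (1 + abs(e)) for e, a in zip(expected, actual))
    report.append((name, worst))
    return "pass"


def reference() -> list[float]:
    # Hypothetical "model": draws random inputs, then sums left to right.
    xs = [random.random() for _ in range(1000)]
    return [sum(xs)]


def optimized() -> list[float]:
    # Same math with a different summation order, so results may differ slightly.
    xs = [random.random() for _ in range(1000)]
    return [sum(reversed(xs))]


def broken() -> list[float]:
    raise RuntimeError("simulated compilation failure")


rows: list[tuple[str, float]] = []
print(check_tolerance_sketch("toy", reference, optimized, set(), rows))     # pass
print(check_tolerance_sketch("big", reference, optimized, {"big"}, rows))  # pass_due_to_skip
print(check_tolerance_sketch("bad", reference, broken, set(), rows))       # fail_to_run
```

Only the "toy" call reaches the comparison, so rows ends up with a single entry whose error is on the order of floating-point rounding.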
