check_accuracy() — pytorch Function Reference

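Architecture documentation for the check_accuracy() method in benchmarks/dynamo/common.py from the PyTorch codebase. The method collects an fp64 reference output, runs the model twice in eager mode to confirm that eager itself is deterministic, then runs it through the optimization context under test and compares the results, recording one accuracy status per model.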
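
Every phase in the source that can raise is wrapped so that the exception becomes a status string instead of a crash: the two eager runs report eager_1st_run_OOM / eager_1st_run_fail and eager_2nd_run_OOM / eager_2nd_run_fail, and the compiled run reports OOM / fail_to_run. The sketch below reproduces only that mapping in plain Python so it runs without PyTorch. `OutOfMemoryError` is a local stand-in for `torch.cuda.OutOfMemoryError`, and `failure_status` with its `phase` labels are hypothetical names, not part of common.py.

```python
class OutOfMemoryError(RuntimeError):
    """Local stand-in for torch.cuda.OutOfMemoryError (assumption: no PyTorch needed)."""


def failure_status(exc: BaseException, phase: str) -> str:
    """Map an exception from one benchmark phase to its status string.

    `phase` is a hypothetical label: "eager_1st_run", "eager_2nd_run" or "dynamo".
    """
    is_oom = isinstance(exc, OutOfMemoryError)
    if phase == "dynamo":
        # The compiled run reports a bare "OOM" or "fail_to_run".
        return "OOM" if is_oom else "fail_to_run"
    # Eager runs prefix the phase: eager_1st_run_OOM, eager_2nd_run_fail, ...
    return f"{phase}_OOM" if is_oom else f"{phase}_fail"


print(failure_status(OutOfMemoryError(), "eager_1st_run"))  # eager_1st_run_OOM
print(failure_status(ValueError("bad shape"), "dynamo"))  # fail_to_run
```

Keeping OOM separate from other failures lets a results table distinguish "the machine ran out of memory" from "the model or compiler is broken".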

Dependency Diagram

graph TD
  79f63331_206c_51dc_bfd1_4fd84b939754["check_accuracy()"]
  d825b76f_2b85_74b9_6e8c_af36be54ac1f["minify_model()"]
  d825b76f_2b85_74b9_6e8c_af36be54ac1f -->|calls| 79f63331_206c_51dc_bfd1_4fd84b939754
  9bf8449e_2d7f_c370_514b_b3c7bf20f8e1["run_one_model()"]
  9bf8449e_2d7f_c370_514b_b3c7bf20f8e1 -->|calls| 79f63331_206c_51dc_bfd1_4fd84b939754
  6a0a2015_4bf4_1e7a_daa6_dbf2c23883c7["deepcopy_and_maybe_parallelize()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| 6a0a2015_4bf4_1e7a_daa6_dbf2c23883c7
  6c83aab9_f1ee_6751_91aa_682a715a5746["init_optimizer()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| 6c83aab9_f1ee_6751_91aa_682a715a5746
  316b98a5_2d78_5681_cc76_6024cfcb4191["run_n_iterations()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| 316b98a5_2d78_5681_cc76_6024cfcb4191
  4ed3d19d_5919_8dac_36a3_c8fbf9eb090b["get_tolerance_and_cosine_flag()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| 4ed3d19d_5919_8dac_36a3_c8fbf9eb090b
  da0c865a_ac14_7a10_8fc5_8a3b7509426d["maybe_cast()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| da0c865a_ac14_7a10_8fc5_8a3b7509426d
  f7ecd13c_923c_8159_d6c1_ed5366e68d46["use_larger_multiplier_for_smaller_tensor()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| f7ecd13c_923c_8159_d6c1_ed5366e68d46
  ad25db9d_c25e_9f90_e743_71bbb0e95c25["get_accuracy_check_runs()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| ad25db9d_c25e_9f90_e743_71bbb0e95c25
  330c13ec_d234_4f92_ced7_5048aed7e29b["use_iou_for_bool_accuracy()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| 330c13ec_d234_4f92_ced7_5048aed7e29b
  60033173_353c_2f29_caad_975b749ace28["get_iou_threshold()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| 60033173_353c_2f29_caad_975b749ace28
  f00a6213_f2a3_0eef_9451_ee5a24d1ab9f["get_dynamo_stats()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| f00a6213_f2a3_0eef_9451_ee5a24d1ab9f
  b8cdd827_b831_469a_75e3_9eb4a7bb1874["output_signpost()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| b8cdd827_b831_469a_75e3_9eb4a7bb1874
  3473d1a5_c1f5_fc97_006e_79a1d3081bef["write_outputs()"]
  79f63331_206c_51dc_bfd1_4fd84b939754 -->|calls| 3473d1a5_c1f5_fc97_006e_79a1d3081bef
  style 79f63331_206c_51dc_bfd1_4fd84b939754 fill:#6366f1,stroke:#818cf8,color:#fff
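
Before any compiled run, the source requires the two eager results to agree: it calls `same(...)` with `tol=tolerance if torch.version.hip else 0`, so an exact match is demanded except on ROCm builds. The sketch below reproduces only that tolerance rule on flat lists of floats using `math.isclose`. It is a simplified stand-in, not the real `same()` helper, which also takes options such as `equal_nan` and `use_larger_multiplier_for_smaller_tensor`; `eager_runs_match` and `hip_version` are hypothetical names.

```python
import math
from typing import Optional, Sequence


def eager_runs_match(
    first: Sequence[float],
    second: Sequence[float],
    tolerance: float,
    hip_version: Optional[str],
) -> bool:
    """Simplified eager-vs-eager comparison on flat lists of floats.

    Mirrors `tol=tolerance if torch.version.hip else 0` from the source:
    outside ROCm (hip_version is None) the two runs must match exactly.
    """
    tol = tolerance if hip_version else 0.0
    if len(first) != len(second):
        return False
    # With tol == 0, math.isclose only accepts a == b.
    return all(
        math.isclose(a, b, rel_tol=tol, abs_tol=tol) for a, b in zip(first, second)
    )
```

The zero tolerance is why the source carries a TODO about determinism: any non-deterministic kernel, such as the MIOpen convolutions it mentions, can make two eager runs differ and produce the eager_two_runs_differ status.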

Source Code

benchmarks/dynamo/common.py lines 2168–2504

    def check_accuracy(
        self, name, model, example_inputs, optimize_ctx, experiment, tag
    ):
        """
        Checks accuracy.
        1) Collect the outputs with fp64 datatype. This is useful for error checking.
        2) Checks if eager itself has variations.
        """
        start_stats = get_dynamo_stats()

        def record_status(accuracy_status, dynamo_start_stats):
            """
            Records the status in the csv file
            """
            if current_name in self.non_deterministic_models:
                if accuracy_status in (
                    "pass",
                    "eager_two_runs_differ",
                    "fail_accuracy",
                ):
                    accuracy_status = "pass"

            headers = ["dev", "name", "batch_size", "accuracy"]
            fields = [current_device, current_name, current_batch_size, accuracy_status]

            if tag is not None:
                headers.insert(3, "tag")
                fields.insert(3, tag)

            o_headers = list(headers)
            o_fields = list(fields)

            dynamo_stats = get_dynamo_stats()
            dynamo_stats.subtract(dynamo_start_stats)
            for k, v in dynamo_stats.items():
                headers.append(k)
                fields.append(v)

            total_wall_time = output_signpost(
                dict(zip(o_headers, o_fields)),
                self.args,
                self.suite_name,
            )
            headers.append("compilation_latency")
            fields.append(total_wall_time)
            write_outputs(output_filename, headers, fields)

            if self.args.print_compilation_time:
                print(f"Compilation time (from dynamo_timed): {total_wall_time}")

            return accuracy_status

        if name in self.skip_accuracy_checks_large_models_dashboard:
            return record_status("pass_due_to_skip", dynamo_start_stats=start_stats)

        # Skip all accuracy check for the torchao backend
        if self.args.backend == "torchao":
            return record_status("pass_due_to_skip", dynamo_start_stats=start_stats)

        with self.pick_grad(name, self.args.training):
            # Collect the fp64 reference outputs to be used later for accuracy checking.
            fp64_outputs = None
            model_fp64 = None
            inputs_fp64 = None
            try:
                model_fp64, inputs_fp64 = cast_to_fp64(
                    self.deepcopy_and_maybe_parallelize(model),
                    clone_inputs(example_inputs),
                )
                self.init_optimizer(name, current_device, model_fp64.parameters())
                fp64_outputs = self.run_n_iterations(
                    model_fp64, inputs_fp64, self.model_iter_fn
                )
                fp64_outputs = tree_map(
                    lambda x: x.to(torch.float64)
                    if isinstance(x, torch.Tensor) and x.is_floating_point()
                    else x,
                    fp64_outputs,
                )
            except Exception:
                log.warning(
                    "fp64 golden ref were not generated for %s. Setting accuracy check to cosine",
                    name,
                    exc_info=True,
                )
                self.args.cosine = True
                fp64_outputs = None
            finally:
                del model_fp64, inputs_fp64
                empty_gpu_cache(current_device)

            tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
                self.args.training, current_device, name
            )

            # Cast the model to float16/float32 as necessary
            model, example_inputs = self.maybe_cast(model, example_inputs)
            accuracy_status = "pass"

            # Get results of native pytorch
            reset_rng_state()
            model_copy = None
            try:
                with torch.compiler.set_stance("force_eager"):
                    model_copy = self.deepcopy_and_maybe_parallelize(model)
                    self.init_optimizer(name, current_device, model_copy.parameters())
                    correct_result = self.run_n_iterations(
                        model_copy, clone_inputs(example_inputs), self.model_iter_fn
                    )
            except Exception as e:
                accuracy_status = (
                    "eager_1st_run_OOM"
                    if isinstance(e, torch.cuda.OutOfMemoryError)
                    else "eager_1st_run_fail"
                )
                log.exception("")
                return record_status(accuracy_status, dynamo_start_stats=start_stats)
            finally:
                del model_copy
                empty_gpu_cache(current_device)

            # Rerun native pytorch
            reset_rng_state()
            model_copy = None
            try:
                with torch.compiler.set_stance("force_eager"):
                    model_copy = self.deepcopy_and_maybe_parallelize(model)
                    self.init_optimizer(name, current_device, model_copy.parameters())
                    correct_rerun_result = self.run_n_iterations(
                        model_copy, clone_inputs(example_inputs), self.model_iter_fn
                    )
            except Exception as e:
                accuracy_status = (
                    "eager_2nd_run_OOM"
                    if isinstance(e, torch.cuda.OutOfMemoryError)
                    else "eager_2nd_run_fail"
                )
                log.exception("")
                return record_status(accuracy_status, dynamo_start_stats=start_stats)
            finally:
                del model_copy
                empty_gpu_cache(current_device)

            # Two eager runs should have exactly same result, within tolerance.
            # TODO If we want the above to be true, then deterministic should be set.
            # For example, MIOpen convolutions could be implemented with non-deterministic algos.
            is_same = True
            try:
                if (
                    name not in self.skip_accuracy_check_as_eager_non_deterministic
                    and not same(
                        correct_result,
                        correct_rerun_result,
                        fp64_ref=None,
                        cos_similarity=False,
                        tol=tolerance if torch.version.hip else 0,
                        equal_nan=self.equal_nan,
                        use_larger_multiplier_for_smaller_tensor=self.use_larger_multiplier_for_smaller_tensor(
                            name
                        ),
                    )
                ):
                    is_same = False
            except Exception:
                # Sometimes torch.allclose may throw RuntimeError
                is_same = False

            if not is_same:
                accuracy_status = "eager_two_runs_differ"
                return record_status(accuracy_status, dynamo_start_stats=start_stats)

            correct_rerun_result = None

            # Support multiple accuracy check runs for flaky models
            accuracy_check_runs = self.get_accuracy_check_runs(name)
            pass_count = 0

            for run_idx in range(accuracy_check_runs):
                # Run with Dynamo
                reset_rng_state()
                torch._dynamo.reset()
                torch._dynamo.utils.counters.clear()
                model_copy = None
                run_passed = True

                try:
                    model_copy = self.deepcopy_and_maybe_parallelize(model)
                    self.init_optimizer(name, current_device, model_copy.parameters())
                    if (
                        self.args.export
                        or self.args.export_aot_inductor
                        or self.args.export_nativert
                        or self.args.torchscript_jit_trace
                        or self.args.aot_precompile
                    ):
                        # apply export on module directly
                        # no need for n iterations
                        # the logic should be the same to self.model_iter_fn (forward_pass)
                        with self.autocast(**self.autocast_arg):
                            optimized_model_iter_fn = optimize_ctx(
                                model_copy, example_inputs
                            )
                            new_result = optimized_model_iter_fn(
                                model_copy, example_inputs
                            )
                    else:
                        optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
                        new_result = self.run_n_iterations(
                            model_copy, example_inputs, optimized_model_iter_fn
                        )
                except Exception as e:
                    log.exception("")
                    print(
                        "TorchDynamo optimized model failed to run because of following error"
                    )
                    accuracy_status = (
                        "OOM"
                        if isinstance(e, torch.cuda.OutOfMemoryError)
                        else "fail_to_run"
                    )
                    return record_status(
                        accuracy_status, dynamo_start_stats=start_stats
                    )
                finally:
                    del model_copy

                if name in self.skip_accuracy_check_as_eager_non_deterministic:
                    return record_status(
                        "pass_due_to_skip", dynamo_start_stats=start_stats
                    )

                force_max_multiplier = False
                if (
                    self.args.freezing
                    and self.args.bfloat16
                    and torch._dynamo.utils.counters["inductor"]["binary_folding_conv"]
                    > 0
                ):
                    force_max_multiplier = True

                try:
                    if self.args.training and self.args.amp:
                        if process_fn := self.get_output_amp_train_process_func.get(
                            name, None
                        ):
                            correct_result = process_fn(correct_result)
                            new_result = process_fn(new_result)
                            fp64_outputs = process_fn(fp64_outputs)

                    if (
                        self.args.save_model_outputs_to
                        and self.args.compare_model_outputs_with
                        and self.args.save_model_outputs_to
                        == self.args.compare_model_outputs_with
                    ):
                        log.warning(
                            "args.save_model_outputs_to and args.compare_model_outputs_with point to the same path. "
                            "Result will be undefined."
                        )

                    if self.args.save_model_outputs_to:
                        print(
                            f"Save model outputs to: {self.args.save_model_outputs_to}"
                        )
                        torch.save(new_result, self.args.save_model_outputs_to)

                    if self.args.compare_model_outputs_with:
                        print(
                            f"Load model outputs from {self.args.compare_model_outputs_with} to compare"
                        )
                        saved_result = torch.load(
                            self.args.compare_model_outputs_with, weights_only=False
                        )
                        is_bitwise_same = bitwise_same(saved_result, new_result)
                        if not is_bitwise_same:
                            print(
                                "The result is not bitwise equivalent to the previously saved result"
                            )
                            return record_status(
                                "not_bitwise_equivalent",
                                dynamo_start_stats=start_stats,
                            )

                        print(
                            "The result is bitwise equivalent to the previously saved result"
                        )
                        del saved_result

                    if not same(
                        correct_result,
                        new_result,
                        fp64_outputs,
                        equal_nan=self.equal_nan,
                        use_larger_multiplier_for_smaller_tensor=self.use_larger_multiplier_for_smaller_tensor(
                            name
                        ),
                        cos_similarity=cos_similarity,
                        tol=tolerance,
                        force_max_multiplier=force_max_multiplier,
                        use_iou_for_bool=self.use_iou_for_bool_accuracy(name),
                        iou_threshold=self.get_iou_threshold(name),
                    ):
                        run_passed = False
                except Exception:
                    # Sometimes torch.allclose may throw RuntimeError
                    run_passed = False

                if run_passed:
                    pass_count += 1

                if accuracy_check_runs > 1:
                    log.info(
                        "Accuracy check run %d/%d: %s",
                        run_idx + 1,
                        accuracy_check_runs,
                        "passed" if run_passed else "failed",
                    )

            # Pass if majority of runs pass (more than half)
            is_same = pass_count > accuracy_check_runs // 2

            if accuracy_check_runs > 1:
                log.info(
                    "Accuracy check summary: %d/%d runs passed, %s",
                    pass_count,
                    accuracy_check_runs,
                    "PASS" if is_same else "FAIL",
                )

            if not is_same:
                if self.args.skip_accuracy_check:
                    accuracy_status = "pass_due_to_skip"
                else:
                    accuracy_status = "fail_accuracy"
                return record_status(accuracy_status, dynamo_start_stats=start_stats)

        return record_status(accuracy_status, dynamo_start_stats=start_stats)
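
The compiled-run loop counts passing runs and accepts the model only when `pass_count > accuracy_check_runs // 2`, i.e. a strict majority, so with an even number of runs a tie fails. The sketch below isolates that vote and the final status choice for a model that reached the loop without an early return. `majority_passed` and `final_status` are hypothetical helper names.

```python
from typing import Sequence


def majority_passed(run_results: Sequence[bool]) -> bool:
    """True when strictly more than half of the runs passed."""
    pass_count = sum(run_results)
    return pass_count > len(run_results) // 2


def final_status(run_results: Sequence[bool], skip_accuracy_check: bool) -> str:
    """Status for a model that reached the compiled-run loop without an early return."""
    if majority_passed(run_results):
        return "pass"
    # A failed vote is downgraded only when the skip flag is set.
    return "pass_due_to_skip" if skip_accuracy_check else "fail_accuracy"
```

With one run the vote reduces to "that run passed"; with three runs two passes are enough; with four runs two passes are not, because 2 > 4 // 2 is false.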

Frequently Asked Questions

What does check_accuracy() do?
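check_accuracy() is a method in benchmarks/dynamo/common.py that decides whether a model run through the configured optimization context produces the same results as eager PyTorch. It returns one status per model, such as pass, fail_accuracy, eager_two_runs_differ, fail_to_run, OOM, or pass_due_to_skip, through its nested record_status() helper, which also writes the status to the CSV output file.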
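
record_status() applies one override before writing: for models listed in self.non_deterministic_models, the statuses pass, eager_two_runs_differ, and fail_accuracy are all reported as pass, while crash and OOM statuses pass through unchanged. A minimal sketch of that rule, with `remap_status` as a hypothetical name and a plain set standing in for self.non_deterministic_models:

```python
from typing import AbstractSet

# Statuses that record_status() rewrites to "pass" for models listed as
# non-deterministic (same tuple as in the source).
FORGIVEN_FOR_NON_DETERMINISTIC = frozenset(
    {"pass", "eager_two_runs_differ", "fail_accuracy"}
)


def remap_status(
    status: str, model_name: str, non_deterministic_models: AbstractSet[str]
) -> str:
    if model_name in non_deterministic_models and status in FORGIVEN_FOR_NON_DETERMINISTIC:
        return "pass"
    # Crashes and OOMs are still reported as-is.
    return status
```
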
What does check_accuracy() call?
check_accuracy() calls 15 functions, including cast_to_fp64, deepcopy_and_maybe_parallelize, empty_gpu_cache, get_accuracy_check_runs, get_dynamo_stats, get_iou_threshold, get_tolerance_and_cosine_flag, init_optimizer, and 7 more.
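
Three of those callees, get_dynamo_stats(), output_signpost(), and write_outputs(), serve record_status(), which turns the status into one CSV row. The row starts from dev, name, batch_size, and accuracy, inserts an optional tag column before accuracy, appends the change in get_dynamo_stats() counters since the start of the call, and ends with compilation_latency (the wall time returned by output_signpost()). The sketch below rebuilds only the row and returns it instead of writing a file. `build_csv_row` and the counter names calls_captured and unique_graphs are hypothetical, and `collections.Counter` objects stand in for the stats snapshots (the source calls `.subtract()` on them, which `Counter` provides).

```python
from collections import Counter
from typing import Any, List, Optional, Tuple


def build_csv_row(
    device: str,
    name: str,
    batch_size: int,
    status: str,
    tag: Optional[str],
    start_stats: Counter,
    end_stats: Counter,
    wall_time: float,
) -> Tuple[List[str], List[Any]]:
    headers = ["dev", "name", "batch_size", "accuracy"]
    fields: List[Any] = [device, name, batch_size, status]
    if tag is not None:
        # The optional tag column goes just before "accuracy".
        headers.insert(3, "tag")
        fields.insert(3, tag)

    # Per-model stats are the difference between two snapshots; subtract()
    # keeps zero counts, so every counter still gets a column.
    delta = Counter(end_stats)
    delta.subtract(start_stats)
    for key, value in delta.items():
        headers.append(key)
        fields.append(value)

    headers.append("compilation_latency")
    fields.append(wall_time)
    return headers, fields
```

Note that in the source only the base columns (before the stats are appended) are passed to output_signpost().
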
What calls check_accuracy()?
check_accuracy() is called by 2 functions: minify_model and run_one_model.