check_tolerance() — PyTorch Function Reference
Architecture documentation for the check_tolerance() function in benchmarks/dynamo/common.py from the PyTorch codebase.
Dependency Diagram
graph TD
    0f41377a_e71d_fd3c_1974_a0ef9ec1158e["check_tolerance()"]
    9bf8449e_2d7f_c370_514b_b3c7bf20f8e1["run_one_model()"]
    9bf8449e_2d7f_c370_514b_b3c7bf20f8e1 -->|calls| 0f41377a_e71d_fd3c_1974_a0ef9ec1158e
    da0c865a_ac14_7a10_8fc5_8a3b7509426d["maybe_cast()"]
    0f41377a_e71d_fd3c_1974_a0ef9ec1158e -->|calls| da0c865a_ac14_7a10_8fc5_8a3b7509426d
    6c83aab9_f1ee_6751_91aa_682a715a5746["init_optimizer()"]
    0f41377a_e71d_fd3c_1974_a0ef9ec1158e -->|calls| 6c83aab9_f1ee_6751_91aa_682a715a5746
    316b98a5_2d78_5681_cc76_6024cfcb4191["run_n_iterations()"]
    0f41377a_e71d_fd3c_1974_a0ef9ec1158e -->|calls| 316b98a5_2d78_5681_cc76_6024cfcb4191
    3473d1a5_c1f5_fc97_006e_79a1d3081bef["write_outputs()"]
    0f41377a_e71d_fd3c_1974_a0ef9ec1158e -->|calls| 3473d1a5_c1f5_fc97_006e_79a1d3081bef
    style 0f41377a_e71d_fd3c_1974_a0ef9ec1158e fill:#6366f1,stroke:#818cf8,color:#fff
Source Code
benchmarks/dynamo/common.py lines 2506–2581
def check_tolerance(
    self, name, model, example_inputs, optimize_ctx, base_device="cpu"
):
    """
    Checks tolerance based on https://pytorch.org/docs/stable/generated/torch.allclose.html.
    """
    tolerance_status = "pass"
    if name in self.skip_accuracy_checks_large_models_dashboard:
        tolerance_status = "pass_due_to_skip"
        return tolerance_status
    # Cast the model to float16/float32 as necessary
    model, example_inputs = self.maybe_cast(model, example_inputs)
    with self.pick_grad(name, self.args.training):
        # Get results of native pytorch
        reset_rng_state()
        model_copy = copy.deepcopy(model)
        model_copy = model_copy.to(base_device)
        example_inputs_copy = copy.deepcopy(example_inputs)
        example_inputs_copy = tree_map(
            lambda x: x.to(base_device), example_inputs_copy
        )
        self.init_optimizer(name, base_device, model_copy.parameters())
        correct_result = self.run_n_iterations(
            model_copy, example_inputs_copy, self.model_iter_fn
        )

        # Run with Dynamo
        # Sometime CI fails with random triton compilation failure which will be skipped for now
        # TODO: revisit this after switching to new Triton runtime
        reset_rng_state()
        torch._dynamo.reset()
        try:
            self.init_optimizer(name, current_device, model.parameters())
            optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
            new_result = self.run_n_iterations(
                model_copy, example_inputs, optimized_model_iter_fn
            )
        except Exception:
            log.exception("")
            print(
                "TorchDynamo optimized model failed to run because of following error"
            )
            return "fail_to_run"

        def dump_max_mean_values(tol, ref, res):
            if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
                for refi, resi in zip(ref, res):
                    dump_max_mean_values(tol, refi, resi)
            elif isinstance(ref, dict):
                for k in ref:
                    dump_max_mean_values(tol, ref[k], res[k])
            elif isinstance(ref, torch.Tensor):
                res = res.to(base_device)
                t = torch.abs(ref - res) / (1 + torch.abs(ref))
                tol.append(t.flatten().to(torch.float32))
            return tol

        tol = []
        dump_max_mean_values(tol, correct_result, new_result)
        tol = torch.cat(tol)
        tol = torch.tensor(tol)
        max = torch.max(tol)
        mean = torch.mean(tol)
        div = torch.std(tol)
        headers = ["dev", "name", "batch_size", "max", "mean", "std"]
        fields = [
            current_device,
            current_name,
            current_batch_size,
            max.item(),
            mean.item(),
            div.item(),
        ]
        write_outputs(output_filename, headers, fields)
    return tolerance_status
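The accuracy metric at the heart of this function is the element-wise relative error t = |ref - res| / (1 + |ref|), whose max, mean, and std are then reported. The following is a pure-Python sketch of that same statistic; the real code operates on torch tensors, and the names relative_errors and summarize are illustrative, not part of common.py:

```python
import math

def relative_errors(ref, res):
    # t = |ref - res| / (1 + |ref|), element by element.
    # The +1 in the denominator keeps near-zero reference values
    # from blowing up the relative error.
    return [abs(a - b) / (1 + abs(a)) for a, b in zip(ref, res)]

def summarize(errs):
    # Max, mean, and sample std (n - 1 denominator, matching
    # torch.std's default unbiased estimator).
    n = len(errs)
    mean = sum(errs) / n
    var = sum((e - mean) ** 2 for e in errs) / (n - 1) if n > 1 else 0.0
    return max(errs), mean, math.sqrt(var)

ref = [1.0, -2.0, 0.5]
res = [1.01, -2.02, 0.5]
mx, mean, std = summarize(relative_errors(ref, res))
```

With these inputs the per-element errors are 0.005, 0.02/3, and 0.0, so the reported max is 0.02/3.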
Frequently Asked Questions
What does check_tolerance() do?
check_tolerance() measures the numerical accuracy of a TorchDynamo-optimized model against native eager-mode PyTorch. It runs the model twice, once natively on base_device and once under the optimize_ctx compilation context, computes the element-wise relative error |ref - res| / (1 + |ref|) between the two sets of outputs, and writes the max, mean, and std of that error to the output file. It returns "pass", "pass_due_to_skip" (for models on the skip list), or "fail_to_run" (if the optimized model raises an exception).
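Because model outputs may be nested lists, tuples, or dicts of tensors, the comparison first walks both structures in lockstep, as dump_max_mean_values does in the source above. A simplified sketch of that recursion, with plain Python values standing in for tensors (flatten_pairs is an illustrative name, not part of common.py):

```python
def flatten_pairs(ref, res, out):
    # Recurse through matching containers; collect leaf pairs in order.
    if isinstance(ref, (list, tuple)):
        for r, s in zip(ref, res):
            flatten_pairs(r, s, out)
    elif isinstance(ref, dict):
        for k in ref:
            flatten_pairs(ref[k], res[k], out)
    else:
        out.append((ref, res))
    return out

pairs = flatten_pairs({"a": [1, 2], "b": (3,)}, {"a": [1, 2], "b": (4,)}, [])
# → [(1, 1), (2, 2), (3, 4)]
```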
What does check_tolerance() call?
check_tolerance() calls four functions: maybe_cast(), init_optimizer(), run_n_iterations(), and write_outputs().
What calls check_tolerance()?
check_tolerance() is called by one function: run_one_model().