optimize_templates() — pytorch Function Reference
Architecture documentation for the optimize_templates() function in analyze_templates.py from the pytorch codebase.
Entity Profile
Dependency Diagram
graph TD
    ba6f4aaa_2ca6_4636_6393_57fd7cfb55c1["optimize_templates()"]
    fc92d87b_ad85_a501_c5b5_6e3bd9880812["main()"]
    fc92d87b_ad85_a501_c5b5_6e3bd9880812 -->|calls| ba6f4aaa_2ca6_4636_6393_57fd7cfb55c1
    style ba6f4aaa_2ca6_4636_6393_57fd7cfb55c1 fill:#6366f1,stroke:#818cf8,color:#fff
Relationship Graph
Source Code
benchmarks/dynamo/microbenchmarks/analyze_templates.py lines 37–196
def _triton_template_key(timing):
    """Return the hashable tuple identifying the Triton template of a timing record.

    The tuple is ``(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)``.
    """
    return (
        timing["BLOCK_M"],
        timing["BLOCK_N"],
        timing["BLOCK_K"],
        timing["num_stages"],
        timing["num_warps"],
    )


def optimize_templates(N, occurrence_count, benchmark_logs, verbose=False):
    """Choose N Triton templates that minimize occurrence-weighted matmul time.

    Builds a binary integer program with PuLP and solves it with CBC:

    * exactly ``N`` Triton templates are enabled;
    * every shape in ``occurrence_count`` is served by exactly one option —
      either cuBLAS or an enabled Triton template that has a timing for
      that shape;
    * the objective is ``sum(occurrence_count[shape] * time[shape])``, where
      ``time[shape]`` is the best benchmarked time of the chosen option.

    Args:
        N: Number of Triton templates to select.
        occurrence_count: Mapping from shape to its occurrence count, used as
            the objective weight. Only these shapes are optimized.
        benchmark_logs: Mapping from shape to a list of timing dicts. Each dict
            has ``"type"`` and ``"time"``. Entries with type ``"triton"`` also
            carry ``BLOCK_M``, ``BLOCK_N``, ``BLOCK_K``, ``num_stages`` and
            ``num_warps``. Every shape in ``occurrence_count`` needs at least
            one entry with type ``"cublas"``.
        verbose: If true, print the inputs, the constraints, the solver's
            variable values and a per-shape breakdown.

    Returns:
        ``(selected_templates, total_time)``. ``selected_templates`` is the
        list of template tuples the solver enabled. ``total_time`` is the
        occurrence-weighted sum of the per-shape times in the solution.

    Raises:
        ValueError: If a shape in ``occurrence_count`` has no cuBLAS timing.
        KeyError: If a shape in ``occurrence_count`` is missing from
            ``benchmark_logs``.
    """
    # Set of all possible Triton templates keyed by their attributes
    triton_templates = set()
    for timings in benchmark_logs.values():
        for timing in timings:
            if timing["type"] == "triton":
                triton_templates.add(_triton_template_key(timing))

    # Print the initial data
    if verbose:
        print("Occurrence Count:", occurrence_count)
        print("Triton Templates:", triton_templates)

    # Best (minimum) time per option for every shape. Each shape's log is
    # scanned once, instead of once per template.
    cublas_time_per_shape = {}
    triton_options_per_shape = {}
    for shape in occurrence_count:
        cublas_times = []
        best_triton_times = {}
        for timing in benchmark_logs[shape]:
            if timing["type"] == "cublas":
                cublas_times.append(timing["time"])
            elif timing["type"] == "triton":
                key = _triton_template_key(timing)
                previous = best_triton_times.get(key)
                # Strict "<" keeps the first of equal minima, matching min()
                if previous is None or timing["time"] < previous:
                    best_triton_times[key] = timing["time"]
        if not cublas_times:
            raise ValueError(f"No cuBLAS timing found for shape {shape}")
        cublas_time_per_shape[shape] = min(cublas_times)
        # Iterate triton_templates so option order matches the template set
        triton_options_per_shape[shape] = [
            (best_triton_times[template], template)
            for template in triton_templates
            if template in best_triton_times
        ]

    # Create a dictionary to store template selection variables
    template_vars = {
        template: pulp.LpVariable(f"Template_{template}", 0, 1, pulp.LpBinary)
        for template in triton_templates
    }

    # Variables to select specific timing option for each shape
    selection_vars = {
        (shape, "cublas"): pulp.LpVariable(
            f"Select_{shape}_cublas", 0, 1, pulp.LpBinary
        )
        for shape in occurrence_count
    }
    for shape in occurrence_count:
        for template in triton_templates:
            selection_vars[(shape, template)] = pulp.LpVariable(
                f"Select_{shape}_{template}", 0, 1, pulp.LpBinary
            )

    # Variables for the total time for each shape
    min_time_vars = pulp.LpVariable.dicts(
        "MinTime", occurrence_count.keys(), 0, None, pulp.LpContinuous
    )

    # Define the problem
    prob = pulp.LpProblem("MatrixMultiplicationOptimization", pulp.LpMinimize)

    # Objective: Minimize the weighted total time
    prob += pulp.lpSum(
        [occurrence_count[shape] * min_time_vars[shape] for shape in occurrence_count]
    )

    # Constraint: select exactly N templates
    prob += pulp.lpSum([template_vars[template] for template in triton_templates]) == N

    # Per-shape constraints
    for shape in occurrence_count:
        cublas_var = selection_vars[(shape, "cublas")]
        triton_options = triton_options_per_shape[shape]

        # Ensure exactly one timing option is selected for each shape
        prob += (
            pulp.lpSum(
                [cublas_var]
                + [selection_vars[(shape, template)] for _, template in triton_options]
            )
            == 1
        )

        # Tie min_time_vars[shape] to the time of the selected option
        prob += min_time_vars[shape] == (
            cublas_var * cublas_time_per_shape[shape]
            + pulp.lpSum(
                [
                    selection_vars[(shape, template)] * triton_time
                    for triton_time, template in triton_options
                ]
            )
        )

        # A Triton template can only serve a shape if it is one of the N
        # templates enabled globally
        for _, template in triton_options:
            prob += selection_vars[(shape, template)] <= template_vars[template]

    # Print the constraints
    if verbose:
        print("Constraints:")
        for constraint in prob.constraints.values():
            print(constraint)

    # Solve the problem with suppressed output
    prob.solve(pulp.PULP_CBC_CMD(msg=False))

    # Binary variables come back as floats that are integral only within the
    # solver tolerance (or None if unsolved), so threshold instead of "== 1".
    selected_templates = [
        template
        for template in triton_templates
        if (pulp.value(template_vars[template]) or 0) > 0.5
    ]
    total_time = sum(
        pulp.value(min_time_vars[shape]) * occurrence_count[shape]
        for shape in occurrence_count
    )

    if verbose:
        print("Solver status:", pulp.LpStatus[prob.status])
        # Print the values of the decision variables after solving
        print("Decision Variable Values:")
        for var in prob.variables():
            print(f"{var.name} = {var.varValue}")
        # Debugging information: per-shape breakdown. Uses this shape's own
        # cuBLAS time, not a value left over from the constraint loop.
        for shape in occurrence_count:
            print(f"Shape: {shape}")
            print(f" Min Time: {pulp.value(min_time_vars[shape])}")
            print(f" Occurrences: {occurrence_count[shape]}")
            print(
                f" Min CuBLAS Time: {cublas_time_per_shape[shape]} Selected: {pulp.value(selection_vars[(shape, 'cublas')])}"
            )
            for triton_time, template in triton_options_per_shape[shape]:
                print(
                    f" Triton Template: {template} Time: {triton_time} Selected: {pulp.value(selection_vars[(shape, template)])}"
                )

    return selected_templates, total_time
Domain
Subdomains
Called By
Source
Frequently Asked Questions
What does optimize_templates() do?
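optimize_templates() is a function in the pytorch benchmarking code. It selects N Triton matmul templates by solving a binary integer program with PuLP. The program minimizes the total runtime over all shapes, weighted by each shape's occurrence count, and lets each shape use either cuBLAS or one of the selected Triton templates.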
What calls optimize_templates()?
optimize_templates() is called by one function: main().
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free