Home / Function/ optimize_templates() — pytorch Function Reference

optimize_templates() — pytorch Function Reference

Architecture documentation for the optimize_templates() function in analyze_templates.py from the pytorch codebase.

Entity Profile

Dependency Diagram

graph TD
  ba6f4aaa_2ca6_4636_6393_57fd7cfb55c1["optimize_templates()"]
  fc92d87b_ad85_a501_c5b5_6e3bd9880812["main()"]
  fc92d87b_ad85_a501_c5b5_6e3bd9880812 -->|calls| ba6f4aaa_2ca6_4636_6393_57fd7cfb55c1
  style ba6f4aaa_2ca6_4636_6393_57fd7cfb55c1 fill:#6366f1,stroke:#818cf8,color:#fff

Relationship Graph

Source Code

benchmarks/dynamo/microbenchmarks/analyze_templates.py lines 37–196

def optimize_templates(N, occurrence_count, benchmark_logs, verbose=False):
    # Set of all possible Triton templates keyed by their attributes
    triton_templates = set()
    for timings in benchmark_logs.values():
        for timing in timings:
            if timing["type"] == "triton":
                triton_templates.add(
                    (
                        timing["BLOCK_M"],
                        timing["BLOCK_N"],
                        timing["BLOCK_K"],
                        timing["num_stages"],
                        timing["num_warps"],
                    )
                )

    # Print the initial data
    if verbose:
        print("Occurrence Count:", occurrence_count)
        print("Triton Templates:", triton_templates)

    # Create a dictionary to store template selection variables
    template_vars = {
        template: pulp.LpVariable(f"Template_{template}", 0, 1, pulp.LpBinary)
        for template in triton_templates
    }

    # Variables to select specific timing option for each shape
    selection_vars = {
        (shape, "cublas"): pulp.LpVariable(
            f"Select_{shape}_cublas", 0, 1, pulp.LpBinary
        )
        for shape in occurrence_count
    }
    for shape in occurrence_count:
        for template in triton_templates:
            selection_vars[(shape, template)] = pulp.LpVariable(
                f"Select_{shape}_{template}", 0, 1, pulp.LpBinary
            )

    # Variables for the total time for each shape
    min_time_vars = pulp.LpVariable.dicts(
        "MinTime", occurrence_count.keys(), 0, None, pulp.LpContinuous
    )

    # Define the problem
    prob = pulp.LpProblem("MatrixMultiplicationOptimization", pulp.LpMinimize)

    # Objective: Minimize the weighted total time
    prob += pulp.lpSum(
        [occurrence_count[shape] * min_time_vars[shape] for shape in occurrence_count]
    )

    # Constraints to select exactly N templates
    prob += pulp.lpSum([template_vars[template] for template in triton_templates]) == N

    # Store triton options per shape for debugging
    triton_options_per_shape = {}

    # Constraints for the total time for each shape
    for shape in occurrence_count:
        # Get cuBLAS time
        cublas_times = [
            timing["time"]
            for timing in benchmark_logs[shape]
            if timing["type"] == "cublas"
        ]
        min_cublas_time = min(cublas_times)

        # Collect Triton options
        triton_options = []
        for template in triton_templates:
            triton_times = [
                timing["time"]
                for timing in benchmark_logs[shape]
                if timing["type"] == "triton"
                and (
                    timing["BLOCK_M"],
                    timing["BLOCK_N"],
                    timing["BLOCK_K"],
                    timing["num_stages"],
                    timing["num_warps"],
                )
                == template
            ]
            if triton_times:
                min_triton_time = min(triton_times)
                triton_options.append((min_triton_time, template))

        # Save triton options for debugging
        triton_options_per_shape[shape] = triton_options

        # Ensure exactly one timing option is selected for each shape
        prob += (
            pulp.lpSum(
                [selection_vars[(shape, "cublas")]]
                + [
                    selection_vars[(shape, template)]
                    for triton_time, template in triton_options
                ]
            )
            == 1
        )

        # Ensure min_time_vars[shape] matches the selected timing option
        prob += min_time_vars[shape] == (
            selection_vars[(shape, "cublas")] * min_cublas_time
            + pulp.lpSum(
                [
                    selection_vars[(shape, template)] * triton_time
                    for triton_time, template in triton_options
                ]
            )
        )

        # Ensure Triton templates can only be selected if they are included in the N allowed templates
        for triton_time, template in triton_options:
            prob += selection_vars[(shape, template)] <= template_vars[template]

    # Print the constraints
    if verbose:
        print("Constraints:")
        for constraint in prob.constraints.values():
            print(constraint)

    # Solve the problem with suppressed output
    prob.solve(pulp.PULP_CBC_CMD(msg=False))

    # Output the selected templates and their configurations
    selected_templates = [
        template
        for template in triton_templates
        if pulp.value(template_vars[template]) == 1
    ]
    total_time = sum(
        pulp.value(min_time_vars[shape]) * occurrence_count[shape]
        for shape in occurrence_count
    )

    # Print the values of the decision variables after solving
    if verbose:
        print("Decision Variable Values:")
        for var in prob.variables():
            print(f"{var.name} = {var.varValue}")

    # # Debugging information
    if verbose:
        for shape in occurrence_count:
            print(f"Shape: {shape}")
            print(f"  Min Time: {pulp.value(min_time_vars[shape])}")
            print(f"  Occurrences: {occurrence_count[shape]}")
            print(
                f"  Min CuBLAS Time: {min_cublas_time} Selected: {pulp.value(selection_vars[(shape, 'cublas')])}"
            )
            for triton_time, template in triton_options_per_shape[shape]:
                print(
                    f"  Triton Template: {template} Time: {triton_time} Selected: {pulp.value(selection_vars[(shape, template)])}"
                )

    return selected_templates, total_time

Subdomains

Called By

Frequently Asked Questions

What does optimize_templates() do?
optimize_templates() is a function in the pytorch codebase.
What calls optimize_templates()?
optimize_templates() is called by 1 function(s): main.

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free