
main() — pytorch Function Reference

Architecture documentation for the main() function in benchmark.py from the pytorch codebase.

Entity Profile

Dependency Diagram

graph TD
  main["main()"]
  show_environment_info["show_environment_info()"]
  main -->|calls| show_environment_info
  list_benchmarks["list_benchmarks()"]
  main -->|calls| list_benchmarks
  run_all_benchmarks["run_all_benchmarks()"]
  main -->|calls| run_all_benchmarks
  run_benchmark["run_benchmark()"]
  main -->|calls| run_benchmark
  style main fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

benchmarks/dynamo/genai_layers/benchmark.py lines 105–216

def main():
    show_environment_info()

    parser = argparse.ArgumentParser(
        description="Benchmark runner for kernel implementations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python benchmark.py --list                    # List all available benchmarks
  python benchmark.py --all                     # Run all benchmarks
  python benchmark.py cross_entropy_forward     # Run specific benchmark
  python benchmark.py softmax_forward softmax_backward  # Run multiple benchmarks
        """,
    )

    parser.add_argument(
        "benchmarks",
        nargs="*",
        help="Names of benchmarks to run (use --list to see available options)",
    )

    parser.add_argument(
        "--list", action="store_true", help="List all available benchmarks"
    )

    parser.add_argument(
        "--all", action="store_true", help="Run all available benchmarks"
    )

    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Visualize results after running benchmarks",
    )

    parser.add_argument(
        "--compile-mode",
        choices=["default", "max-autotune-no-cudagraphs"],
        default="max-autotune-no-cudagraphs",
        help="Torch compile mode to use (default: max-autotune-no-cudagraphs)",
    )

    parser.add_argument(
        "--tolerance",
        type=float,
        default=None,
        help="Tolerance for the accuracy check",
    )

    parser.add_argument(
        "--exit-on-accuracy-failure",
        action="store_true",
        help="Whether to exit with an error message for accuracy failure",
    )

    parser.add_argument(
        "--print-benchmark-result",
        action="store_true",
        help="Whether to print the raw benchmarking result. Easier to quickly check the benchmark results on a server without GUI",
    )

    parser.add_argument(
        "--custom-compile-name",
        type=str,
        default=None,
        help="Name for the curve with customized compilation options",
    )

    parser.add_argument(
        "--custom-compile-options",
        type=str,
        default=None,
        help="Json string for the custom compile options.",
    )

    args = parser.parse_args()

    if args.custom_compile_options:
        import json

        try:
            args.custom_compile_options = json.loads(args.custom_compile_options)
        except json.decoder.JSONDecodeError as e:
            raise RuntimeError(
                f"Invalid json string for --custom-compile-options: {args.custom_compile_options}"
            ) from e

        if not args.custom_compile_options:
            raise RuntimeError("Found no options for --custom-compile-options")
        if not args.custom_compile_name:
            raise RuntimeError("Missing label name for the custom compilation")

    # Handle list option
    if args.list:
        list_benchmarks()
        return

    # Handle all option
    if args.all:
        run_all_benchmarks(args)
        return

    # Handle specific benchmarks
    if not args.benchmarks:
        print("Error: No benchmarks specified")
        print("Use --list to see available benchmarks or --all to run all benchmarks")
        parser.print_help()
        sys.exit(1)

    for benchmark_name in args.benchmarks:
        run_benchmark(benchmark_name, args)
        print()  # Add spacing between benchmarks
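The `--custom-compile-options` flag above takes a JSON string that must decode to a non-empty object. A minimal standalone sketch of that validation pattern (this is an illustration, not the pytorch code itself):

```python
import json

def parse_compile_options(raw):
    """Validate a JSON string of compile options, mirroring the checks
    main() applies to --custom-compile-options (sketch only)."""
    try:
        opts = json.loads(raw)
    except json.JSONDecodeError as e:
        # Surface the bad input in the error, as main() does
        raise RuntimeError(f"Invalid json string: {raw}") from e
    if not opts:
        # An empty object ("{}") is rejected, not silently accepted
        raise RuntimeError("Found no options")
    return opts

print(parse_compile_options('{"mode": "max-autotune-no-cudagraphs"}'))
```

Note that `json.decoder.JSONDecodeError` (used in the source) is the same class as `json.JSONDecodeError`; catching either works.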

Frequently Asked Questions

What does main() do?
main() is the command-line entry point for the genai_layers benchmark runner. It prints environment information, parses the CLI arguments, validates any --custom-compile-options JSON, and then either lists the available benchmarks (--list), runs all of them (--all), or runs each benchmark named on the command line.
What does main() call?
main() calls 4 function(s): list_benchmarks, run_all_benchmarks, run_benchmark, show_environment_info.
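As the source listing shows, main() checks --list first, then --all, and only then the positional benchmark names, erroring out when nothing is requested. A minimal standalone sketch of that precedence (the string return values are hypothetical stand-ins for the real handler calls):

```python
import argparse

def dispatch(argv):
    # Mirrors main()'s precedence: --list, then --all, then named
    # benchmarks, then an error when nothing is requested.
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmarks", nargs="*")
    parser.add_argument("--list", action="store_true")
    parser.add_argument("--all", action="store_true")
    args = parser.parse_args(argv)
    if args.list:
        return "list"
    if args.all:
        return "run-all"
    if not args.benchmarks:
        return "error"
    return args.benchmarks

print(dispatch(["--list", "--all"]))  # --list wins over --all
print(dispatch(["softmax_forward", "softmax_backward"]))
```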
