main() — pytorch Function Reference

Architecture documentation for the main() function in dataloader_benchmark.py from the pytorch codebase.

Dependency Diagram

graph TD
  bdd2576f_acbc_aabf_3971_d757c8bc122a["main()"]
  f6150286_ce63_72f2_12ae_6ae1460174f5["benchmark_dataloader()"]
  bdd2576f_acbc_aabf_3971_d757c8bc122a -->|calls| f6150286_ce63_72f2_12ae_6ae1460174f5
  style bdd2576f_acbc_aabf_3971_d757c8bc122a fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

benchmarks/data/dataloader_benchmark.py lines 230–312

def main():
    parser = argparse.ArgumentParser(
        description="Benchmark PyTorch DataLoader with different worker methods"
    )
    parser.add_argument("--data_path", required=True, help="Path to dataset")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument(
        "--max_batches",
        type=int,
        default=100,
        help="Maximum number of batches per epoch",
    )
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument(
        "--multiprocessing_context",
        choices=["fork", "spawn", "forkserver"],
        default="forkserver",
        help="Multiprocessing context to use (fork, spawn, forkserver)",
    )
    parser.add_argument(
        "--dataset_copies",
        type=int,
        default=1,
        help="Number of copies of the dataset to concatenate (for testing memory usage)",
    )
    parser.add_argument(
        "--logging_freq",
        type=int,
        default=10,
        help="Frequency of logging memory usage during training",
    )
    args = parser.parse_args()

    # Print system info
    print("System Information:")
    # Versions and install locations help confirm that a from-source build is the one being imported
    print(f"  PyTorch version: {torch.__version__}")
    print(f"  PyTorch location: {torch.__file__}")
    print(f"  Torchvision version: {torchvision.__version__}")
    print(f"  Torchvision location: {torchvision.__file__}")
    print(f"  CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"  CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"  CPU count: {psutil.cpu_count(logical=True)}")
    print(f"  Physical CPU cores: {psutil.cpu_count(logical=False)}")
    print(f"  Total system memory: {psutil.virtual_memory().total / (1024**3):.2f} GB")

    # Define transforms
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Load dataset
    print(f"\nLoading dataset from {args.data_path} ({args.dataset_copies} copies)")

    # Build --dataset_copies independent ImageFolder instances and concatenate them
    datasets = []
    for _ in range(args.dataset_copies):
        base_dataset = torchvision.datasets.ImageFolder(
            args.data_path, transform=transform
        )
        datasets.append(copy.deepcopy(base_dataset))
        del base_dataset
    dataset = ConcatDataset(datasets)

    print(f"Dataset size: {len(dataset)}")

    # Run benchmark with specified worker method
    benchmark_dataloader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        multiprocessing_context=args.multiprocessing_context,
        num_epochs=args.num_epochs,
        max_batches=args.max_batches,
        logging_freq=args.logging_freq,
    )
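The argument parsing at the top of main() can be exercised without a shell by passing an explicit argv list to parse_args(). The sketch below uses only the standard-library argparse module and rebuilds a subset of the options through a hypothetical build_parser() helper (the dataset path is a placeholder). It shows which defaults apply and how invalid input is rejected:

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper: mirrors a subset of the options defined in main()
    parser = argparse.ArgumentParser(
        description="Benchmark PyTorch DataLoader with different worker methods"
    )
    parser.add_argument("--data_path", required=True, help="Path to dataset")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument(
        "--multiprocessing_context",
        choices=["fork", "spawn", "forkserver"],
        default="forkserver",
    )
    parser.add_argument("--dataset_copies", type=int, default=1)
    return parser


parser = build_parser()

# Only --data_path is required; every other option falls back to its default
args = parser.parse_args(["--data_path", "/data/images"])
print(args.batch_size, args.num_workers, args.multiprocessing_context)  # 32 4 forkserver

# type=int converts the command-line string "3" into the integer 3
args = parser.parse_args(["--data_path", "/data/images", "--dataset_copies", "3"])
print(args.dataset_copies)  # 3
```

argparse reports a missing --data_path, or a --multiprocessing_context value outside fork/spawn/forkserver, on stderr and exits with status 2. Such mistakes therefore surface before any dataset is loaded.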
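The system-information block divides psutil.virtual_memory().total, which is a byte count, by 1024**3. The printed figure is therefore in binary gibibytes even though it is labelled GB. The following stdlib-only sketch shows that arithmetic, with a hypothetical format_total_memory() helper standing in for the print call:

```python
def format_total_memory(total_bytes: int) -> str:
    # Same arithmetic as main(): bytes / 1024**3, shown with two decimal places.
    # 1024**3 bytes is one GiB (binary); a decimal GB would be 10**9 bytes.
    return f"  Total system memory: {total_bytes / (1024**3):.2f} GB"


# A machine with 16 GiB of RAM reports 17179869184 bytes
print(format_total_memory(16 * 1024**3))  # "  Total system memory: 16.00 GB"

# The same byte count expressed in decimal gigabytes
print(f"{16 * 1024**3 / 10**9:.2f}")  # "17.18"
```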
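The transform pipeline in main() does four things in order:

1. Resizes each image so its shorter edge is 256 pixels.
2. Keeps the central 224x224 region.
3. Converts the result to a float tensor in [0, 1].
4. Normalizes each channel with the listed ImageNet mean and standard deviation.

The sketch below reproduces that geometry and per-channel arithmetic in plain Python. It assumes torchvision's documented behavior for an integer Resize size (shorter edge matched, aspect ratio kept). The helper names are hypothetical, and the exact rounding of the longer edge may differ between torchvision versions:

```python
from __future__ import annotations

MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def resized_size(height: int, width: int, size: int = 256) -> tuple[int, int]:
    # Resize(int): the shorter edge becomes `size`, aspect ratio is preserved,
    # and the longer edge is truncated to an int
    if width <= height:
        return int(size * height / width), size
    return size, int(size * width / height)


def center_crop_box(height: int, width: int, crop: int = 224) -> tuple[int, int, int, int]:
    # CenterCrop: (top, left, crop_height, crop_width) of the central square
    top = int(round((height - crop) / 2.0))
    left = int(round((width - crop) / 2.0))
    return top, left, crop, crop


def normalize_pixel(rgb: tuple[float, float, float]) -> tuple[float, ...]:
    # ToTensor has already scaled values to [0, 1]; Normalize then computes
    # (x - mean[c]) / std[c] independently for each channel c
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))


h, w = resized_size(480, 640)  # a 640x480 (landscape) input image
print(h, w)                    # 256 341
print(center_crop_box(h, w))   # (16, 58, 224, 224)
print(normalize_pixel(MEAN))   # (0.0, 0.0, 0.0)
```

Because every image leaves the pipeline as a 3x224x224 tensor regardless of its original size, the default collate function can stack samples into fixed-shape batches.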
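torch.utils.data.ConcatDataset presents the dataset copies as one dataset. It keeps a running total of the sub-dataset sizes and locates each global index with a binary search. The sketch below reimplements that lookup in plain Python to show how an index maps onto a particular copy. SimpleConcat is a hypothetical stand-in, not the PyTorch class, and plain lists stand in for the ImageFolder copies:

```python
import bisect
from typing import List, Sequence, Tuple


class SimpleConcat:
    """Hypothetical stand-in for torch.utils.data.ConcatDataset."""

    def __init__(self, datasets: Sequence[Sequence]):
        self.datasets = list(datasets)
        # Running totals of sub-dataset sizes, e.g. sizes [4, 4, 4] -> [4, 8, 12]
        self.cumulative_sizes: List[int] = []
        total = 0
        for d in self.datasets:
            total += len(d)
            self.cumulative_sizes.append(total)

    def __len__(self) -> int:
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def locate(self, idx: int) -> Tuple[int, int]:
        # Return (sub-dataset index, offset inside it) for a global index
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        ds = bisect.bisect_right(self.cumulative_sizes, idx)
        start = self.cumulative_sizes[ds - 1] if ds > 0 else 0
        return ds, idx - start

    def __getitem__(self, idx: int):
        ds, offset = self.locate(idx)
        return self.datasets[ds][offset]


# Three "copies" of a 4-sample dataset, as with --dataset_copies 3
copies = [["a", "b", "c", "d"] for _ in range(3)]
concat = SimpleConcat(copies)
print(len(concat))       # 12
print(concat.locate(5))  # (1, 1)
print(concat[11])        # d
```

With --dataset_copies 3, the Dataset size printed by main() is therefore three times the size of a single ImageFolder scan of --data_path.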

Frequently Asked Questions

What does main() do?
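main() is the command-line entry point of benchmarks/data/dataloader_benchmark.py. It does the following:

- Parses the benchmark options.
- Prints system information: the PyTorch and torchvision versions and install locations, CUDA availability, CPU counts, and total memory.
- Builds an ImageFolder dataset from --data_path, concatenating --dataset_copies copies.
- Passes the dataset and the parsed settings to benchmark_dataloader().

What does main() call?
Within the codebase, main() calls one function: benchmark_dataloader().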
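benchmark_dataloader() receives the start method as one of the strings fork, spawn, or forkserver. The standard-library multiprocessing module maps such a name to a context object through get_context(), and not every method exists on every platform (Windows supports only spawn, for example). The sketch below shows that lookup together with one plausible way to enforce the per-epoch --max_batches cap. resolve_context() and iterate_epochs() are hypothetical helpers, and the loop shape is an assumption, since the body of benchmark_dataloader() is not shown on this page:

```python
import itertools
import multiprocessing
from typing import Callable, Iterable, Iterator, List


def resolve_context(method: str):
    # Hypothetical helper: check the start method is supported, then return its context
    available = multiprocessing.get_all_start_methods()
    if method not in available:
        raise ValueError(f"start method {method!r} not available; choose from {available}")
    return multiprocessing.get_context(method)


def iterate_epochs(
    make_batches: Callable[[], Iterable], num_epochs: int, max_batches: int
) -> Iterator[List]:
    # Assumed loop shape: start a fresh iterator each epoch and stop after
    # at most max_batches batches, yielding the batches consumed in that epoch
    for _ in range(num_epochs):
        yield list(itertools.islice(make_batches(), max_batches))


ctx = resolve_context("spawn")  # spawn is available on Linux, macOS and Windows
print(ctx.get_start_method())   # spawn

# range(250) stands in for a DataLoader that would yield 250 batches per epoch
epochs = list(iterate_epochs(lambda: iter(range(250)), num_epochs=2, max_batches=100))
print([len(e) for e in epochs])  # [100, 100]
```

Capping each epoch with islice bounds the benchmark's run time, while an epoch with fewer batches than the cap simply ends early.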