main() — pytorch Function Reference
Architecture documentation for the main() function in dataloader_benchmark.py from the pytorch codebase.
Dependency Diagram
graph TD
    bdd2576f_acbc_aabf_3971_d757c8bc122a["main()"]
    f6150286_ce63_72f2_12ae_6ae1460174f5["benchmark_dataloader()"]
    bdd2576f_acbc_aabf_3971_d757c8bc122a -->|calls| f6150286_ce63_72f2_12ae_6ae1460174f5
    style bdd2576f_acbc_aabf_3971_d757c8bc122a fill:#6366f1,stroke:#818cf8,color:#fff
Source Code
benchmarks/data/dataloader_benchmark.py lines 230–312
def main():
    parser = argparse.ArgumentParser(
        description="Benchmark PyTorch DataLoader with different worker methods"
    )
    parser.add_argument("--data_path", required=True, help="Path to dataset")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument(
        "--max_batches",
        type=int,
        default=100,
        help="Maximum number of batches per epoch",
    )
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs")
    parser.add_argument(
        "--multiprocessing_context",
        choices=["fork", "spawn", "forkserver"],
        default="forkserver",
        help="Multiprocessing context to use (fork, spawn, forkserver)",
    )
    parser.add_argument(
        "--dataset_copies",
        type=int,
        default=1,
        help="Number of copies of the dataset to concatenate (for testing memory usage)",
    )
    parser.add_argument(
        "--logging_freq",
        type=int,
        default=10,
        help="Frequency of logging memory usage during training",
    )
    args = parser.parse_args()

    # Print system info
    print("System Information:")
    # The following are handy for debugging if building from source worked correctly
    print(f" PyTorch version: {torch.__version__}")
    print(f" PyTorch location: {torch.__file__}")
    print(f" Torchvision version: {torchvision.__version__}")
    print(f" Torchvision location: {torchvision.__file__}")
    print(f" CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f" CUDA device: {torch.cuda.get_device_name(0)}")
    print(f" CPU count: {psutil.cpu_count(logical=True)}")
    print(f" Physical CPU cores: {psutil.cpu_count(logical=False)}")
    print(f" Total system memory: {psutil.virtual_memory().total / (1024**3):.2f} GB")

    # Define transforms
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Load dataset
    print(f"\nLoading dataset from {args.data_path} ({args.dataset_copies} copies)")

    # Try to load as ImageFolder
    datasets = []
    for _ in range(args.dataset_copies):
        base_dataset = torchvision.datasets.ImageFolder(
            args.data_path, transform=transform
        )
        datasets.append(copy.deepcopy(base_dataset))
        del base_dataset
    dataset = ConcatDataset(datasets)
    print(f"Dataset size: {len(dataset)}")

    # Run benchmark with specified worker method
    benchmark_dataloader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        multiprocessing_context=args.multiprocessing_context,
        num_epochs=args.num_epochs,
        max_batches=args.max_batches,
        logging_freq=args.logging_freq,
    )
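The --dataset_copies option makes the benchmarked dataset several times larger by concatenating independent copies of the same ImageFolder. The sketch below illustrates that idea in plain Python, without PyTorch. ConcatSketch is a hypothetical stand-in written for this page, not the real ConcatDataset class, and the string samples are placeholders for images.

```python
import bisect
from itertools import accumulate


class ConcatSketch:
    """Hypothetical stand-in for a concatenated dataset (not the PyTorch class)."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        # Running totals of the lengths, e.g. [3, 6] for two datasets of size 3.
        self.cumulative_sizes = list(accumulate(len(d) for d in self.datasets))

    def __len__(self):
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Locate the underlying dataset that holds this global index.
        which = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = idx - (self.cumulative_sizes[which - 1] if which else 0)
        return self.datasets[which][offset]


base = ["img0", "img1", "img2"]  # placeholder samples
dataset_copies = 2               # mirrors --dataset_copies
combined = ConcatSketch([list(base) for _ in range(dataset_copies)])

print(len(combined))  # 6
print(combined[4])    # img1
```

The total length grows linearly with the number of copies, which is why the option is useful for stressing memory usage.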
Frequently Asked Questions
What does main() do?
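main() is the command-line entry point of benchmarks/data/dataloader_benchmark.py. It parses the benchmark options, prints system information (PyTorch and Torchvision versions, CUDA availability, CPU counts, total memory), builds an ImageFolder dataset from --data_path with a resize, center-crop, tensor-conversion, and normalization pipeline, concatenates --dataset_copies copies of it, and passes the result to benchmark_dataloader().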
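The command-line interface can be tried in isolation with only the standard library. The sketch below rebuilds a subset of the options, with names and defaults copied from the source above. The build_parser helper and the /tmp/images path are illustrative assumptions and do not appear in the original file.

```python
import argparse


def build_parser():
    # Hypothetical helper: a subset of the options that main() defines.
    parser = argparse.ArgumentParser(
        description="Benchmark PyTorch DataLoader with different worker methods"
    )
    parser.add_argument("--data_path", required=True, help="Path to dataset")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument(
        "--multiprocessing_context",
        choices=["fork", "spawn", "forkserver"],
        default="forkserver",
    )
    return parser


# Pass an explicit argv list instead of reading sys.argv.
args = build_parser().parse_args(["--data_path", "/tmp/images", "--num_workers", "8"])
print(args.batch_size, args.num_workers, args.multiprocessing_context)
# 32 8 forkserver
```

Only --data_path is required. Every other option falls back to its default, so the shortest valid invocation supplies the dataset path alone.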
What does main() call?
main() calls one function tracked in the dependency graph: benchmark_dataloader().
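The body of benchmark_dataloader() is not shown on this page. As a rough sketch only, and going by the help strings for --max_batches ("Maximum number of batches per epoch") and --logging_freq, a loop with the same keyword arguments might look like the following. run_epochs_sketch is a hypothetical function, and a plain range stands in for a DataLoader. The real implementation may differ.

```python
import time


def run_epochs_sketch(batches, num_epochs=1, max_batches=100, logging_freq=10):
    """Hypothetical loop; not the real benchmark_dataloader()."""
    seen = 0
    logged_at = []
    start = time.perf_counter()
    for epoch in range(num_epochs):
        for i, _batch in enumerate(batches):
            if i >= max_batches:
                break  # cap the number of batches per epoch
            seen += 1
            if (i + 1) % logging_freq == 0:
                # A real benchmark would record memory usage here.
                logged_at.append((epoch, i + 1))
    elapsed = time.perf_counter() - start
    return seen, logged_at, elapsed


seen, logged_at, elapsed = run_epochs_sketch(
    range(250), num_epochs=2, max_batches=25, logging_freq=10
)
print(seen)       # 50
print(logged_at)  # [(0, 10), (0, 20), (1, 10), (1, 20)]
```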