NewBatchSampler Class — PyTorch Architecture
Architecture documentation for the NewBatchSampler class in samplers_benchmark.py from the PyTorch codebase.
Entity Profile
Source Code
benchmarks/data/samplers_benchmark.py lines 13–65
class NewBatchSampler(Sampler[list[int]]):
    """Alternative implementation of BatchSampler for benchmarking purposes.

    Wraps another sampler (or any iterable of indices) and yields its
    elements grouped into lists of ``batch_size``. When ``drop_last`` is
    True, a trailing incomplete batch is discarded; otherwise it is
    yielded as a shorter final list.
    """

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,
        drop_last: bool,
    ) -> None:
        # Reject bools explicitly: bool is a subclass of int, so a plain
        # isinstance(batch_size, int) check alone would accept True/False.
        valid_batch_size = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not valid_batch_size:
            raise ValueError(
                f"batch_size should be a positive integer value, but got batch_size={batch_size}"
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                f"drop_last should be a boolean value, but got drop_last={drop_last}"
            )
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[list[int]]:
        if self.drop_last:
            # Pull exactly batch_size items per batch; when the underlying
            # iterator runs dry mid-batch, StopIteration escapes the list
            # comprehension and the partial batch is silently dropped.
            source = iter(self.sampler)
            while True:
                try:
                    yield [next(source) for _ in range(self.batch_size)]
                except StopIteration:
                    break
        else:
            # Accumulate indices and flush whenever a full batch is ready;
            # any leftover indices form a final, shorter batch.
            current: list[int] = []
            for index in self.sampler:
                current.append(index)
                if len(current) == self.batch_size:
                    yield current
                    current = []
            if current:
                yield current

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        total = len(self.sampler)  # type: ignore[arg-type]
        if self.drop_last:
            return total // self.batch_size
        # Ceiling division without floats: counts the partial final batch.
        return (total + self.batch_size - 1) // self.batch_size
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free