Func Class — pytorch Architecture
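Architecture documentation for Func, the non-type template parameter of binary_pointwise_batching_rule in LegacyBatchingRegistrations.cpp from the PyTorch codebase. Func is the underlying binary pointwise operator (for example, a multiplication op) that the batching rule calls once its operands are in physical, batched form.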
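The header `template <typename F, F Func, typename... ExtraArgs>` makes `Func` a compile-time constant. Each instantiation of the rule is therefore tied to one operator, and any trailing operator arguments (`ExtraArgs`) are forwarded unchanged. The sketch below isolates only that C++ pattern. Several names in it are hypothetical: `Value` (a `double` standing in for `at::Tensor`), `mul_op`, `add_scaled`, `binary_rule`, and `check_instantiations`. All of the vmap logical-to-physical conversion is omitted.

```cpp
#include <cassert>

// Hypothetical stand-in for at::Tensor so the sketch compiles on its own.
using Value = double;

// Hypothetical binary operators: two operands, optionally followed by extra
// arguments, like the ops the real batching rule wraps.
Value mul_op(const Value& a, const Value& b) { return a * b; }
Value add_scaled(const Value& a, const Value& b, double alpha) { return a + alpha * b; }

// F is the function-pointer type and Func is the pointer itself (a non-type
// template parameter); ExtraArgs are forwarded to Func unchanged.
template <typename F, F Func, typename... ExtraArgs>
Value binary_rule(const Value& self, const Value& other, ExtraArgs... args) {
  // The real rule converts self/other between logical and physical tensors
  // around this call; the sketch keeps only the call itself.
  return Func(self, other, args...);
}

using BinaryFn = Value (*)(const Value&, const Value&);
using ScaledFn = Value (*)(const Value&, const Value&, double);

// Each instantiation fixes one operator at compile time; ExtraArgs is
// deduced from the trailing call arguments (an empty pack for mul_op).
void check_instantiations() {
  // Extra parentheses keep the template-argument comma out of assert's macro arguments.
  assert((binary_rule<BinaryFn, &mul_op>(3.0, 4.0) == 12.0));
  assert((binary_rule<ScaledFn, &add_scaled>(1.0, 2.0, 0.5) == 2.0));
}
```

A plausible reason for this design is that each instantiation keeps the operator's own call shape, `(self, other, args...)`, with no extra function-pointer parameter. That lets it stand in directly for the original operator. The source does not state this rationale explicitly.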
Source Code
aten/src/ATen/LegacyBatchingRegistrations.cpp lines 92–150
```cpp
template <typename F, F Func, typename... ExtraArgs>
Tensor binary_pointwise_batching_rule(
    const Tensor& self, const Tensor& other, ExtraArgs... args) {
  if (self.dim() > 0 && other.dim() > 0) {
    auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
    auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(self)) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = Func(self, other_physical.tensor(), args...);
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(other)) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = Func(self_physical.tensor(), other, args...);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  // At this point, we know at least one of the operands is a logical Scalar tensor.
  // Here we must emulate TensorIterator's special behavior on Scalars.
  //
  // As a motivating example, consider the following:
  //   x = torch.randn(3, 10)
  //   y = torch.randn(3, dtype=torch.double)
  //   vmap(torch.mul)(torch.randn(3, 10), torch.randn(3, dtype=torch.double))
  //
  // At a per-example level, we are adding FloatTensor[10] and DoubleTensor[];
  // Type Promotion dictates that the result should be FloatTensor[10].
  // This means we cannot directly pass the physical tensors (x and y) to
  // TensorIterator (if we did, it would promote them to DoubleTensor).
  //
  // FIXME(rzou): I didn't want to go down the slippery slope of emulating
  // everything TensorIterator does (it would be better to refactor out the
  // TensorIterator logic). The one thing that this code doesn't handle
  // is cross-device logical scalar tensors.
  //   cpu_tensor = torch.randn(3)
  //   cuda_tensor = torch.randn(3, 10, device='cuda')
  //   vmap(torch.mul)(cpu_tensor, cuda_tensor)
  //
  // At a per-example level, we are adding CPUTensor[] and CUDATensor[10].
  // TensorIterator allows for this cross-device operation because one of the
  // tensors is a Scalar CPU tensor. However, the following code will throw an
  // error in that case. I don't expect to see many use cases for this, so
  // this is probably fine as-is.
  auto logical_self = self;
  auto logical_other = other;
  auto result_type = at::native::result_type(logical_self, logical_other);
  if (logical_self.scalar_type() != result_type) {
    logical_self = logical_self.to(result_type);
  }
  if (logical_other.scalar_type() != result_type) {
    logical_other = logical_other.to(result_type);
  }
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical(
      {std::move(logical_self), std::move(logical_other)});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}
```
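The three early returns and the promotion fallback can be modeled without ATen. The sketch below is a simplified, hypothetical model, not PyTorch's implementation. `ToyTensor`, `DType`, `Path`, `choose_path`, `toy_result_type`, `promote_to_common_type`, and `motivating_example` are invented names, and only two floating-point dtypes exist. `toy_result_type` covers only the same-category case of PyTorch's promotion rules: a zero-dimensional operand decides the dtype only when no dimensioned operand is present. The sketch replays the motivating example from the source comment to show why promotion must be computed on the logical tensors rather than on the physical ones.

```cpp
#include <cassert>

// Hypothetical, simplified model: only two floating-point dtypes, so every
// operand belongs to the same dtype category.
enum class DType { Float, Double };

// Hypothetical stand-in for a tensor as a batching rule sees it.
struct ToyTensor {
  int dim;       // number of dimensions (logical or physical, per the caller)
  bool batched;  // true if the tensor carries vmap batch dimensions
  DType dtype;
};

enum class Path { Broadcast, SelfIsPhysicalScalar, OtherIsPhysicalScalar, PromoteThenBroadcast };

// Modeled on isPhysicalScalarTensor: zero-dimensional and not batched.
bool is_physical_scalar(const ToyTensor& t) { return t.dim == 0 && !t.batched; }

// Same order of checks as binary_pointwise_batching_rule, applied to the
// logical (per-example) view of each operand.
Path choose_path(const ToyTensor& self, const ToyTensor& other) {
  if (self.dim > 0 && other.dim > 0) return Path::Broadcast;
  if (is_physical_scalar(self)) return Path::SelfIsPhysicalScalar;
  if (is_physical_scalar(other)) return Path::OtherIsPhysicalScalar;
  return Path::PromoteThenBroadcast;  // some operand is a batched logical scalar
}

// Same-category promotion only: a zero-dim operand decides the dtype only
// when the other operand is also zero-dim.
DType toy_result_type(const ToyTensor& a, const ToyTensor& b) {
  const bool a_has_dims = a.dim > 0;
  const bool b_has_dims = b.dim > 0;
  if (a_has_dims != b_has_dims) return a_has_dims ? a.dtype : b.dtype;
  const bool any_double = a.dtype == DType::Double || b.dtype == DType::Double;
  return any_double ? DType::Double : DType::Float;
}

// Fallback path: cast the logical operands to the per-example result type
// before they are converted to the physical (batched) layout.
void promote_to_common_type(ToyTensor& self, ToyTensor& other) {
  const DType result = toy_result_type(self, other);
  if (self.dtype != result) self.dtype = result;
  if (other.dtype != result) other.dtype = result;
}

// The motivating example from the source comment, with a vmap batch size of 3.
void motivating_example() {
  ToyTensor x{1, true, DType::Float};   // logical view of torch.randn(3, 10)
  ToyTensor y{0, true, DType::Double};  // logical view of torch.randn(3, dtype=torch.double)
  assert(choose_path(x, y) == Path::PromoteThenBroadcast);
  assert(toy_result_type(x, y) == DType::Float);  // per-example answer

  // Promoting the physical tensors directly (shapes [3, 10] and [3]) would
  // treat both as dimensioned and pick Double instead.
  ToyTensor x_phys{2, false, DType::Float};
  ToyTensor y_phys{1, false, DType::Double};
  assert(toy_result_type(x_phys, y_phys) == DType::Double);

  promote_to_common_type(x, y);
  assert(x.dtype == DType::Float && y.dtype == DType::Float);
}
```

Casting before `BroadcastingVmapTransform::logicalToPhysical` matters because both physical tensors have at least one dimension. Promoting them directly would pick the wider `Double`, even though the per-example result should stay `Float`.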