
Func — PyTorch Architecture

Architecture documentation for the `Func` template parameter of `binary_pointwise_batching_rule` in LegacyBatchingRegistrations.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/LegacyBatchingRegistrations.cpp lines 92–150

template <typename F, F Func, typename... ExtraArgs>
Tensor binary_pointwise_batching_rule(
    const Tensor& self, const Tensor& other, ExtraArgs... args) {
  if (self.dim() > 0 && other.dim() > 0) {
    auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
    auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(self)) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = Func(self, other_physical.tensor(), args...);
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(other)) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = Func(self_physical.tensor(), other, args...);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  // At this point, we know at least one of the operands is a logical Scalar tensor.
  // Here we must emulate TensorIterator's special behavior on Scalars.
  //
  // As a motivating example, consider the following:
  //   x = torch.randn(3, 10)
  //   y = torch.randn(3, dtype=torch.double)
  //   vmap(torch.mul)(torch.randn(3, 10), torch.randn(3, dtype=torch.double))
  //
  // At a per-example level, we are adding FloatTensor[10] and DoubleTensor[];
  // Type Promotion dictates that the result should be FloatTensor[10].
  // This means we cannot directly pass the physical tensors (x and y) to
  // TensorIterator (if we did, it would promote them to DoubleTensor).
  //
  // FIXME(rzou): I didn't want to go down the slippery slope of emulating
  // everything TensorIterator does (it would be better to refactor out the
  // TensorIterator logic). The one thing that this code doesn't handle
  // is cross-device logical scalar tensors.
  //   cpu_tensor = torch.randn(3)
  //   cuda_tensor = torch.randn(3, 10, device='cuda')
  //   vmap(torch.mul)(cpu_tensor, cuda_tensor)
  //
  // At a per-example level, we are adding CPUTensor[] and CUDATensor[10].
  // TensorIterator allows for this cross-device operation because one of the
  // tensors is a Scalar CPU tensor. However, the following code will throw an
  // error in that case. I don't expect to see many use cases for this, so
  // this is probably fine as-is.
  auto logical_self = self;
  auto logical_other = other;
  auto result_type = at::native::result_type(logical_self, logical_other);
  if (logical_self.scalar_type() != result_type) {
    logical_self = logical_self.to(result_type);
  }
  if (logical_other.scalar_type() != result_type) {
    logical_other = logical_other.to(result_type);
  }
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical(
      {std::move(logical_self), std::move(logical_other)});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}
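The Scalar-tensor type-promotion behavior that the final branch emulates can be observed directly in eager PyTorch. The sketch below reproduces the motivating example from the comment; it assumes a recent PyTorch build where `vmap` is available as `torch.func.vmap` (the legacy `torch.vmap` this file served has since been replaced):

```python
import torch
from torch.func import vmap

x = torch.randn(3, 10)                  # batch of FloatTensor[10]
y = torch.randn(3, dtype=torch.double)  # batch of DoubleTensor[] logical scalars

# Per example we multiply FloatTensor[10] by DoubleTensor[]. A 0-dim tensor
# does not upgrade the result dtype within the same category, so type
# promotion yields float32, not float64.
assert torch.result_type(x[0], y[0]) == torch.float32

out = vmap(torch.mul)(x, y)
assert out.dtype == torch.float32
assert out.shape == (3, 10)
```

This is exactly why the batching rule computes `at::native::result_type` on the logical (per-example) tensors and casts before handing the physical tensors to `Func`: passing the physical `(3, 10)` and `(3,)` tensors straight to TensorIterator would promote the result to double, because the double operand would no longer look like a 0-dim scalar.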
