Home / Function / NestedTensor_elementwise_Tensor Function — pytorch Architecture

NestedTensor_elementwise_Tensor Function — pytorch Architecture

Architecture documentation for the NestedTensor_elementwise_Tensor function template in NestedTensorBinaryOps.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp lines 74–180

// Dispatches an elementwise binary op over the supported nested/dense operand
// combinations:
//   1. 0-dim scalar self + nested other  -> apply `f` directly to other's buffer
//   2. nested self + 0-dim scalar other  -> apply `f` directly to self's buffer
//   3. nested self + dense other (CUDA)  -> custom broadcast kernel for a few
//      whitelisted shapes; hard error for anything else
//   4. nested self + nested other        -> apply `f` buffer-to-buffer
//
// @param self              lhs operand (nested or dense/scalar).
// @param other             rhs operand (nested or dense/scalar).
// @param op_name           op identifier used for error messages and to select
//                          the dense-kernel variant ("add"/"mul").
// @param supports_striding true when `f` can consume non-contiguous nested
//                          buffers; false means operands must be made
//                          contiguous before `f` is applied.
// @param f                 callable invoked on the underlying storage tensors.
// @return a nested tensor wrapping the result buffer with self/other's
//         nested sizes, strides and offsets.
template <typename Func>
static Tensor NestedTensor_elementwise_Tensor(
    const Tensor& self,
    const Tensor& other,
    const std::string& op_name,
    bool supports_striding,
    Func f) {
  Tensor self_contiguous = self;
  Tensor other_contiguous = other;
  // self is a scalar: broadcast it against other's raw buffer and re-wrap
  // with other's nested metadata (sizes/strides/offsets are unchanged).
  if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) {
    auto other_impl = get_nested_tensor_impl(other);
    return wrap_buffer(
      f(self, other_impl->get_unsafe_storage_as_tensor()),
      other_impl->get_nested_sizes().clone(),
      other_impl->get_nested_strides().clone(),
      other_impl->get_storage_offsets()
    );
  }
  // other is a scalar: symmetric to the case above.
  if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) {
    auto self_impl = get_nested_tensor_impl(self);
    return wrap_buffer(
      f(self_impl->get_unsafe_storage_as_tensor(), other),
      self_impl->get_nested_sizes().clone(),
      self_impl->get_nested_strides().clone(),
      self_impl->get_storage_offsets()
    );
  }
  // special case when other is dense (CUDA only for now)
  if (self.is_nested() && !other.is_nested() && self.is_cuda() && other.is_cuda()) {
    auto self_ptr = get_nested_tensor_impl(self);
    auto other_ = other;
    // check for the [B, *, D], [B, 1, D] case -> use custom kernel
    // TODO: this if statement is ugly and hopefully we will remove this in the near future
    bool is_broadcastable_3d = (
        self_ptr->dim() == 3 &&
        other.dim() == 3 &&
        self_ptr->size(0) == other.size(0) &&
        other.size(1) == 1 &&
        self_ptr->opt_size(2).has_value() &&
        self_ptr->opt_size(2) == other.size(2));
    // check for the [B, *], [B, 1] case -> treat as 3D with [B, *, 1], [B, 1, 1]
    bool is_broadcastable_2d = (
        self_ptr->dim() == 2 &&
        other.dim() == 2 &&
        self_ptr->size(0) == other.size(0) &&
        other.size(1) == 1);
    if(is_broadcastable_2d) {
        other_ = other.unsqueeze(-1);
        is_broadcastable_3d = true;
    }

    if (is_broadcastable_3d) {
      self_contiguous = self.contiguous();
      self_ptr = get_nested_tensor_impl(self_contiguous);
      const auto self_buffer = self_ptr->get_buffer();
      const auto self_sizes = self_ptr->get_nested_sizes();
      auto result_buffer = at::empty_like(self_buffer);
      auto result = wrap_buffer(result_buffer, self_sizes);
      // Fix: pass `self_contiguous`, not the original `self`. The result
      // buffer was allocated from the contiguous layout, so the kernel must
      // read the matching contiguous input (they are the same tensor when
      // `self` was already contiguous).
      if (op_name == "add") {
        nested_dense_elementwise_stub(self.device().type(), result, self_contiguous, other_, NESTED_DENSE_OP::ADD);
      } else if (op_name == "mul") {
        nested_dense_elementwise_stub(self.device().type(), result, self_contiguous, other_, NESTED_DENSE_OP::MUL);
      } else {
        TORCH_CHECK(false, "Unsupported nested dense elementwise op: ", op_name, ".");
      }
      return result;
    }

    // check for the [B, C, *, *], [C, 1, 1] case: unbind self and apply `f`
    // per-constituent, then re-pack the results into a nested tensor.
    bool is_broadcastable_4d_3d = (
        self_ptr->dim() == 4 &&
        other.dim() == 3 &&
        self_ptr->opt_size(1).has_value() &&
        self_ptr->size(1) == other.size(0) &&
        other.size(1) == 1 &&
        other.size(2) == 1);
    if (is_broadcastable_4d_3d) {
      std::vector<Tensor> results;
      for (const auto& t : self.unbind()) {
        results.push_back(f(t, other));
      }
      return at::_nested_tensor_from_tensor_list(results);
    }

    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a nested self and non-nested other for op: ",
        op_name,
        ".");
  }

  // Fix: the branches of this ternary were inverted. An op that supports
  // striding can consume the operands as-is; an op that does NOT support
  // striding requires contiguous buffers before `f` is applied.
  self_contiguous = supports_striding ? self : self.contiguous();
  other_contiguous = supports_striding ? other : other.contiguous();

  auto [self_impl, other_impl] =
      get_elementwise_nested_tensor_impl(self_contiguous, other_contiguous, op_name);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl);
  return wrap_buffer(
      f(self_impl->get_unsafe_storage_as_tensor(),
        other_impl->get_unsafe_storage_as_tensor()),
      self_impl->get_nested_sizes(),
      self_impl->get_nested_strides(),
      self_impl->get_storage_offsets());
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free