NestedTensor_elementwise_Tensor Function — pytorch Architecture
Architecture documentation for the NestedTensor_elementwise_Tensor function in NestedTensorBinaryOps.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp lines 74–180
// Dispatches an elementwise binary op over nested and/or dense operands.
//
// Params:
//   self, other        - operands; at least one is expected to be nested
//                        (or a 0-dim scalar tensor paired with a nested one).
//   op_name            - human-readable op name; also selects the fused
//                        CUDA kernel ("add"/"mul") in the nested-vs-dense path.
//   supports_striding  - when true, both nested operands are made contiguous
//                        before the buffers are combined.
//   f                  - callable applied to the underlying dense buffers.
//
// Returns: a nested tensor wrapping f's output buffer with the appropriate
// nested sizes/strides/offsets metadata.
//
// Raises (TORCH_CHECK): for a nested-self / dense-other pair whose shapes are
// not one of the supported broadcast patterns, or an unsupported op_name in
// the fused-kernel path.
template <typename Func>
static Tensor NestedTensor_elementwise_Tensor(
    const Tensor& self,
    const Tensor& other,
    const std::string& op_name,
    bool supports_striding,
    Func f) {
  Tensor self_contiguous = self;
  Tensor other_contiguous = other;
  // self is a 0-dim scalar tensor: apply f against other's raw buffer and
  // rewrap with a copy of other's nested metadata.
  if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) {
    auto other_impl = get_nested_tensor_impl(other);
    return wrap_buffer(
      f(self, other_impl->get_unsafe_storage_as_tensor()),
      other_impl->get_nested_sizes().clone(),
      other_impl->get_nested_strides().clone(),
      other_impl->get_storage_offsets()
    );
  }
  // other is a 0-dim scalar tensor: symmetric to the case above.
  if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) {
    auto self_impl = get_nested_tensor_impl(self);
    return wrap_buffer(
      f(self_impl->get_unsafe_storage_as_tensor(), other),
      self_impl->get_nested_sizes().clone(),
      self_impl->get_nested_strides().clone(),
      self_impl->get_storage_offsets()
    );
  }
  // special case when other is dense (CUDA only for now)
  if (self.is_nested() && !other.is_nested() && self.is_cuda() && other.is_cuda()) {
    auto self_ptr = get_nested_tensor_impl(self);
    auto other_ = other;
    // check for the [B, *, D], [B, 1, D] case -> use custom kernel
    // TODO: this if statement is ugly and hopefully we will remove this in the near future
    bool is_broadcastable_3d = (
        self_ptr->dim() == 3 &&
        other.dim() == 3 &&
        self_ptr->size(0) == other.size(0) &&
        other.size(1) == 1 &&
        self_ptr->opt_size(2).has_value() &&
        self_ptr->opt_size(2) == other.size(2));
    // check for the [B, *], [B, 1] case -> treat as 3D with [B, *, 1], [B, 1, 1]
    bool is_broadcastable_2d = (
        self_ptr->dim() == 2 &&
        other.dim() == 2 &&
        self_ptr->size(0) == other.size(0) &&
        other.size(1) == 1);
    if (is_broadcastable_2d) {
      // Promote the 2-D pair to the 3-D pattern by appending a size-1 dim.
      other_ = other.unsqueeze(-1);
      is_broadcastable_3d = true;
    }
    if (is_broadcastable_3d) {
      self_contiguous = self.contiguous();
      self_ptr = get_nested_tensor_impl(self_contiguous);
      const auto self_buffer = self_ptr->get_buffer();
      const auto self_sizes = self_ptr->get_nested_sizes();
      auto result_buffer = at::empty_like(self_buffer);
      auto result = wrap_buffer(result_buffer, self_sizes);
      // BUGFIX: hand the kernel the contiguous copy. Previously the original
      // (possibly non-contiguous) `self` was passed even though the result
      // buffer and sizes were derived from `self_contiguous` above.
      if (op_name == "add") {
        nested_dense_elementwise_stub(self.device().type(), result, self_contiguous, other_, NESTED_DENSE_OP::ADD);
      } else if (op_name == "mul") {
        nested_dense_elementwise_stub(self.device().type(), result, self_contiguous, other_, NESTED_DENSE_OP::MUL);
      } else {
        TORCH_CHECK(false, "Unsupported nested dense elementwise op: ", op_name, ".");
      }
      return result;
    }
    // check for the [B, C, *, *], [C, 1, 1] case
    bool is_broadcastable_4d_3d = (
        self_ptr->dim() == 4 &&
        other.dim() == 3 &&
        self_ptr->opt_size(1).has_value() &&
        self_ptr->size(1) == other.size(0) &&
        other.size(1) == 1 &&
        other.size(2) == 1);
    if (is_broadcastable_4d_3d) {
      // No fused kernel for this pattern: fall back to applying f per
      // nested component and re-packing the results.
      std::vector<Tensor> results;
      for (const auto& t : self.unbind()) {
        results.push_back(f(t, other));
      }
      return at::_nested_tensor_from_tensor_list(results);
    }
    TORCH_CHECK(
        false,
        "Expected both self and other to be nested, but got a nested self and non-nested other for op: ",
        op_name,
        ".");
  }
  // Both operands nested: optionally contiguify, then apply f directly on
  // the underlying buffers and rewrap with self's nested metadata.
  self_contiguous = supports_striding ? self.contiguous() : self;
  other_contiguous = supports_striding ? other.contiguous() : other;
  auto [self_impl, other_impl] =
      get_elementwise_nested_tensor_impl(self_contiguous, other_contiguous, op_name);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl);
  return wrap_buffer(
      f(self_impl->get_unsafe_storage_as_tensor(),
        other_impl->get_unsafe_storage_as_tensor()),
      self_impl->get_nested_sizes(),
      self_impl->get_nested_strides(),
      self_impl->get_storage_offsets());
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free