_sparse_binary_op_intersection_kernel_out Function — PyTorch Architecture
Architecture documentation for the _sparse_binary_op_intersection_kernel_out function template in SparseBinaryOpIntersectionCommon.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h lines 419–476
template <
    template <typename func_t> class kernel_t,
    typename value_selection_intersection_kernel_t>
void _sparse_binary_op_intersection_kernel_out(
    Tensor& res,
    const Tensor& x,
    const Tensor& y,
    const std::optional<Tensor>& x_hash_opt = std::nullopt,
    const std::optional<Tensor>& y_hash_opt = std::nullopt,
    // If op distributes with the sum, the arguments are processed as is,
    // without the calls to coalesce().
    const bool distributive_with_sum = true
) {
  TORCH_CHECK(
      (x.is_sparse() && y.is_sparse())
      && (x.dim() == y.dim()) && (x.sparse_dim() == y.sparse_dim())
      && (x.sizes().slice(0, x.sparse_dim()) == y.sizes().slice(0, y.sparse_dim())),
      NAME, "(): expects sparse inputs with equal dimensionality, ",
      "number of sparse dimensions, and shape of sparse dimensions");
  TORCH_CHECK(
      x._indices().scalar_type() == y._indices().scalar_type(),
      NAME, "(): expects inputs' indices to be of the same dtype (i.e. long or int)");
  const auto check_hash_validity = [](const Tensor& t, const std::optional<Tensor>& t_hash_opt) {
    if (!t_hash_opt.has_value()) {
      return;
    }
    const auto& t_hash = *t_hash_opt;
    TORCH_INTERNAL_ASSERT(
        t_hash.dim() == 1 && t_hash.scalar_type() == kLong && t_hash.size(-1) == t._indices().size(-1),
        NAME, "(): explicit hash values need to be a 1-dim Long tensor with the ",
        "NSE matching that of the corresponding sparse tensor.");
  };
  check_hash_validity(x, x_hash_opt);
  check_hash_validity(y, y_hash_opt);
  const auto broadcasted_shape = infer_size(x.sizes(), y.sizes());
  // 8 sparse dims should be more than enough?
  constexpr int64_t max_sparse_dims = 8;
  // COO indices are only 64-bit integers for now.
  using index_t = int64_t;
  if (max_sparse_dims > x.sparse_dim()) {
    _sparse_binary_op_intersection_kernel_impl<
        // For some reason MSVC complains about passing constexpr max_sparse_dims
        // as a template parameter, claiming it is not known at compile time.
        kernel_t, value_selection_intersection_kernel_t, index_t, 8>(
        res, x, y, broadcasted_shape, x_hash_opt, y_hash_opt, distributive_with_sum);
  } else {
    _sparse_binary_op_intersection_kernel_impl<
        kernel_t, value_selection_intersection_kernel_t, index_t>(
        res, x, y, broadcasted_shape, x_hash_opt, y_hash_opt, distributive_with_sum);
  }
}
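After validating its inputs, the function chooses between two instantiations of the implementation kernel: a specialization with a compile-time dimension bound of 8 when `x.sparse_dim()` fits under `max_sparse_dims`, and a generic fallback otherwise. The sketch below (hypothetical, not PyTorch code) illustrates this compile-time-bound-versus-runtime-size dispatch pattern in isolation; the names `kMaxStaticDims`, `index_buffer_capacity`, and `dispatch_capacity` are inventions for this example.

```cpp
#include <cstdint>

// Hypothetical sketch of the dispatch pattern used above: pick a
// specialization with a fixed compile-time capacity when the runtime
// dimension count fits under a known bound, otherwise fall back to a
// generic version sized at runtime.
constexpr int64_t kMaxStaticDims = 8;

template <int64_t StaticDims = -1>
int64_t index_buffer_capacity(int64_t ndims) {
  if constexpr (StaticDims > 0) {
    // Static path: capacity is fixed at compile time, which lets the
    // real kernel use stack-allocated, fixed-size index buffers.
    return StaticDims;
  } else {
    // Dynamic path: capacity tracks the runtime dimension count.
    return ndims;
  }
}

int64_t dispatch_capacity(int64_t ndims) {
  // Mirrors `if (max_sparse_dims > x.sparse_dim())` in the source:
  // strictly-greater, so ndims == kMaxStaticDims takes the dynamic path.
  if (kMaxStaticDims > ndims) {
    return index_buffer_capacity<kMaxStaticDims>(ndims);
  }
  return index_buffer_capacity<>(ndims);
}
```

The payoff of the static branch is that the dimension count becomes a template parameter of the inner kernel, so loops over sparse dimensions can be unrolled and buffers stack-allocated; the dynamic branch keeps the function usable for arbitrarily many sparse dimensions.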