Home / Class / _sparse_binary_op_intersection_kernel_out Class — pytorch Architecture

_sparse_binary_op_intersection_kernel_out Class — pytorch Architecture

Architecture documentation for the _sparse_binary_op_intersection_kernel_out function template in SparseBinaryOpIntersectionCommon.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h lines 419–476

// Entry point for a sparse-COO x sparse-COO binary op evaluated on the
// intersection of the operands' index sets; the result is written to `res`.
//
// Steps:
//   1. Validate that `x` and `y` are both sparse, have the same total rank,
//      the same sparse rank, identical sparse-dimension extents, and the same
//      index dtype (TORCH_CHECK -> user-facing error on failure).
//   2. If precomputed hashes are supplied, assert each is a 1-D kLong tensor
//      whose length matches the nnz (`_indices().size(-1)`) of its tensor.
//   3. Broadcast the full shapes of `x` and `y` with `infer_size`.
//   4. Forward everything to `_sparse_binary_op_intersection_kernel_impl`,
//      choosing a fixed-capacity instantiation when `x` has fewer sparse dims
//      than `max_sparse_dims`, and the default instantiation otherwise.
//
// NOTE(review): `NAME` is not defined in this snippet; presumably it is a macro
// provided by the including translation unit — confirm.
//
// @param res                    output tensor, handed to the impl unchanged.
// @param x, y                   sparse operands.
// @param x_hash_opt, y_hash_opt optional precomputed hashes for x / y.
// @param distributive_with_sum  see inline note on the parameter below.
template <
  template <typename func_t> class kernel_t,
  typename value_selection_intersection_kernel_t>
void _sparse_binary_op_intersection_kernel_out(
    Tensor& res,
    const Tensor& x,
    const Tensor& y,
    const std::optional<Tensor>& x_hash_opt = std::nullopt,
    const std::optional<Tensor>& y_hash_opt = std::nullopt,
    // If op distributes with the sum, the arguments are processed as is,
    // without the calls to coalesce().
    const bool distributive_with_sum = true
) {
  // Evaluated left to right with short-circuiting, so later sub-conditions
  // only run once the earlier ones hold.
  const bool compatible_sparse_layout =
      x.is_sparse() && y.is_sparse()
      && x.dim() == y.dim()
      && x.sparse_dim() == y.sparse_dim()
      && x.sizes().slice(0, x.sparse_dim()) == y.sizes().slice(0, y.sparse_dim());
  TORCH_CHECK(
      compatible_sparse_layout,
      NAME, "(): expects sparse inputs with equal dimensionality, ",
      "number of sparse dimensions, and shape of sparse dimensions");

  // Only reached after the layout check, so both operands are sparse here.
  const bool matching_index_dtype =
      x._indices().scalar_type() == y._indices().scalar_type();
  TORCH_CHECK(
      matching_index_dtype,
      NAME, "(): expects inputs' indices to be of the same dtype (i.e. long or int)");

  // Hashes are an internal-only input, hence an internal assert rather than a
  // user-facing check. An absent hash is accepted without inspection.
  const auto validate_hash = [](const Tensor& t, const std::optional<Tensor>& hash_opt) {
    if (hash_opt.has_value()) {
      const auto& t_hash = hash_opt.value();
      TORCH_INTERNAL_ASSERT(
          t_hash.dim() == 1 && t_hash.scalar_type() == kLong && t_hash.size(-1) == t._indices().size(-1),
          NAME, "(): explicit hash values need to be a 1-dim Long tensor with the ",
          "NSE matching that of the corresponding sparse tensor.");
    }
  };

  validate_hash(x, x_hash_opt);
  validate_hash(y, y_hash_opt);

  const auto out_shape = infer_size(x.sizes(), y.sizes());

  // Threshold for the fixed-capacity impl instantiation below. Note that the
  // comparison is strict: only sparse_dim < max_sparse_dims takes that path.
  constexpr int64_t max_sparse_dims = 8;

  // Index element type handed to the impl; COO indices are 64-bit here.
  using index_t = int64_t;

  if (x.sparse_dim() >= max_sparse_dims) {
    // Too many sparse dims for the fixed-capacity variant: use the impl's
    // default trailing template argument.
    _sparse_binary_op_intersection_kernel_impl<
      kernel_t, value_selection_intersection_kernel_t, index_t>(
        res, x, y, out_shape, x_hash_opt, y_hash_opt, distributive_with_sum);
    return;
  }

  // The literal 8 intentionally mirrors max_sparse_dims: MSVC reportedly
  // rejects the constexpr local as a template argument, claiming it is not
  // known at compile time. Keep the two values in sync.
  _sparse_binary_op_intersection_kernel_impl<
    kernel_t, value_selection_intersection_kernel_t, index_t, 8>(
      res, x, y, out_shape, x_hash_opt, y_hash_opt, distributive_with_sum);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free