Home / Class / intersection_binary_op_sparse_dense_out Class — pytorch Architecture

intersection_binary_op_sparse_dense_out Class — pytorch Architecture

Architecture documentation for the intersection_binary_op_sparse_dense_out function template (indexed by this site as a "class") in SparseTensorMath.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/sparse/SparseTensorMath.cpp lines 831–994

// Applies the binary functor `op` to a dense tensor `d` and a sparse tensor
// `s_`, touching only the elements of `d` that are addressed by the indices
// stored in `s_`, and writes the outcome into the sparse tensor `res`.
//
// Parameters:
//   d        - dense operand.
//   s_       - sparse operand. It may be the same tensor as `res`; that case
//              is detected with is_same_tensor(s_, res) below.
//   res      - sparse output. Its sizes, indices, values and nnz are
//              overwritten through its sparse impl.
//   op_name  - not referenced anywhere in this function body.
//   op       - callable invoked as op(<dense values>, <sparse values>) and
//              expected to return a Tensor.
//   coalesce - when true, `s_` is coalesced before it is processed. `s_` is
//              also coalesced, regardless of this flag, when `d` has more
//              dimensions than `s_`.
//
// Returns: `res`, via res._coalesced_(...).
//
// The body distinguishes four cases, handled in this order:
//   1. `s_` has no stored elements, or either operand has zero elements:
//      `res` becomes an empty (nnz == 0) sparse tensor of the broadcast shape.
//   2. d.dim() <= s.dense_dim(): `d` only interacts with the values of `s`.
//   3. s_.dim() >= d.dim(): `d` is indexed with the indices of `s`; the
//      indices of `s` are reused for the result.
//   4. d.dim() > s_.dim(): the leading ("batch") dimensions of `d` become new
//      sparse dimensions, so new indices are built and nnz is multiplied by
//      the number of batch entries.
template <typename binary_func_t>
static Tensor& intersection_binary_op_sparse_dense_out(
    const Tensor& d,
    const SparseTensor& s_,
    Tensor& res,
    const char* const op_name,
    const binary_func_t& op,
    const bool coalesce = false) {
  // compute broadcasted shape.
  const auto res_shape = infer_size(d.sizes(), s_.sizes());

  // Short-circuit if either s_ or d is empty.
  if (!s_._nnz() || !s_.numel() || !d.numel()) {
    const int64_t dense_dim = s_.dense_dim();
    // Every dimension of the broadcast shape that is not a dense dimension
    // of s_ is treated as a sparse dimension of the result.
    const int64_t sparse_dim = static_cast<int64_t>(res_shape.size()) - dense_dim;
    const int64_t nnz = 0;
    const auto indices = at::empty({sparse_dim, nnz}, s_._indices().options());
    // Values keep the shape of s_'s values except for dim 0, which is set to
    // nnz (zero here); the dtype is taken from res rather than from s_.
    auto res_values_shape = s_._values().sizes().vec();
    res_values_shape[0] = nnz;
    const auto values = at::empty(res_values_shape, s_._values().options().dtype(res.scalar_type()));
    auto* res_impl = get_sparse_impl(res);
    res_impl->raw_resize_(sparse_dim, dense_dim, /*size=*/res_shape);
    res_impl->set_indices_and_values_unsafe(indices, values);
    res_impl->set_nnz_and_narrow(nnz);
    // The result has no stored elements and is flagged as coalesced.
    return res._coalesced_(true);
  }

  const auto d_dim = d.dim();
  const auto s_dim = s_.dim();

  // Always coalesce when sparse broadcasts over dense,
  // because new sparse dimensions are created and
  // repeated indices have to be eliminated because of that.
  const auto s = (coalesce || d_dim > s_dim) ? s_.coalesce() : s_;

  const auto sparse_dim = s.sparse_dim();
  const auto dense_dim = s.dense_dim();

  const auto s_indices = s._indices();
  const auto s_values = s._values();

  // Shared tail for the cases in which the result reuses the indices of s:
  // computes op(d_filtered, s_values) and installs a clone of s_indices
  // together with the computed values into res.
  // NOTE: op is evaluated before res is resized/overwritten; res may be the
  // same tensor as s_, so this ordering should be preserved.
  const auto apply_op = [&](const Tensor& d_filtered) -> Tensor& {
    const auto res_indices = s_indices.clone();
    // to(res.scalar_type) is only performed when both d and s are 0-dim.
    // This ensures right type promotions with the following rules:
    // op(0-dim, 0-dim).dtype == <common dtype>
    // op(0-dim, ge-1-dim).dtype == <ge-1-dim>.dtype,
    // where ge-1-dim is a tensor with dim >= 1.
    // We do not cast if op is performed in-place.
    // The cast is required if s is 0-dim non-coalesced tensor and d is 0-dim.
    // This is because s.values is at least 1D, so
    // op(s.values, d).dtype == s.values.dtype, but we want
    // op(s.values, d).dtype == <common dtype>.
    const auto values = op(d_filtered, s_values);
    // "In-place" is detected as s_ and res being the same tensor.
    const auto res_values = is_same_tensor(s_, res) ? values : values.to(res.scalar_type());
    auto* res_impl = get_sparse_impl(res);
    res_impl->raw_resize_(sparse_dim, dense_dim, res_shape);
    res_impl->set_indices_and_values_unsafe(res_indices, res_values);
    res_impl->set_nnz_and_narrow(s._nnz());
    // The coalesced flag of the (possibly coalesced) input is propagated.
    return res._coalesced_(s.is_coalesced());
  };

  // Easiest case: only dense dimensions intersect.
  // This means only value tensors interact.
  if (d_dim <= dense_dim) {
    return apply_op(d);
  }

  // Now we have intersection between sparse and dense dims.
  // sparse_dim_intersec: number of sparse dims of s that line up with dims of d.
  // d_start_dim_intersec: number of leading dims of d with no counterpart in s
  //   (non-zero only when d has more dims than s).
  // s_start_dim_intersec: number of leading dims of s with no counterpart in d
  //   (non-zero only when s has more dims than d).
  const auto sparse_dim_intersec = std::min(sparse_dim, d_dim - dense_dim);
  const auto d_start_dim_intersec = std::max<int64_t>(0, d_dim - s_dim);
  const auto s_start_dim_intersec = std::max<int64_t>(0, s_dim - d_dim);

  // Index d with s_indices to find values which
  // interact with s_values.
  const auto d_filtered = [&]() -> Tensor {
    using at::indexing::Slice;
    using at::indexing::Ellipsis;
    using at::indexing::TensorIndex;

    std::vector<TensorIndex> intersec_indices;
    intersec_indices.reserve(d_dim);

    // A leading Ellipsis stands in for the batch dims of d, if there are any.
    if (d_start_dim_intersec) {
      intersec_indices.emplace_back(Ellipsis);
    }
    // One index tensor (a row of s_indices) per intersecting sparse dim.
    for (const auto i : c10::irange(sparse_dim_intersec)) {
      const auto s_idx = s_start_dim_intersec + i;
      intersec_indices.emplace_back(s_indices[s_idx]);
    }
    // The remaining trailing dims of d are taken in full.
    for (auto i = d_start_dim_intersec + sparse_dim_intersec; i < d_dim; ++i) {
      intersec_indices.emplace_back(Slice());
    }
    // we need to expand d in the dimensions it is being indexed into
    // to avoid out of bound indices
    // The target shape is the last d_dim entries of the broadcast shape.
    const auto d_expanded_shape = std::vector<int64_t>(
        res_shape.end() - d_dim, res_shape.end());
    return d.expand(d_expanded_shape).index(intersec_indices);
  }();

  // When dims match or sparse is "larger", the result nnz is the same,
  // so only values get modified.
  if (s_dim >= d_dim) {
    return apply_op(d_filtered);
  }

  // Otherwise nnz gets larger, and both indices and values need an update.
  // d_batch_shape holds the sizes of the leading dims of d that s lacks.
  const auto d_batch_shape = d.sizes().slice(0, d_start_dim_intersec);
  const auto d_batch_len = static_cast<int64_t>(d_batch_shape.size());
  // batch_count: product of the batch sizes.
  // max_batch_dim: largest single batch size; used below to size one arange
  // buffer that is sliced for every batch dim.
  int64_t batch_count = 1;
  int64_t max_batch_dim = 0;
  std::tie(batch_count, max_batch_dim) = [d_batch_shape]() -> std::tuple<int64_t, int64_t> {
    int64_t batch_count = 1;
    int64_t max_batch_dim = 0;
    for (const auto& b : d_batch_shape) {
      batch_count *= b;
      max_batch_dim = std::max(b, max_batch_dim);
    }
    return std::make_tuple(batch_count, max_batch_dim);
  }();

  // The batch dims of d are prepended to the sparse dims of s in the result.
  const auto res_sparse_dim = static_cast<int64_t>(d_batch_shape.size()) + sparse_dim;
  const auto res_dense_dim = dense_dim;
  const auto s_nnz = s._nnz();
  // Every stored element of s is repeated once per batch entry.
  const auto res_nnz = batch_count * s_nnz;
  auto res_values_shape = s_values.sizes().vec();
  res_values_shape[0] = res_nnz;
  // NOTE(review): presumably op(...) yields shape (batch dims..., s_nnz, dense
  // dims...) here, which the reshape flattens into (res_nnz, dense dims...) —
  // confirm against the indexing semantics of Tensor::index.
  const auto res_values = op(d_filtered, s_values).reshape(res_values_shape);
  const auto res_indices = [&]() -> Tensor {
    // [0, max_batch_dim) — sliced per batch dim below instead of calling
    // arange once per dim.
    const auto index_buffer = at::arange(max_batch_dim, s_indices.options());
    auto indices = at::empty({res_sparse_dim, res_nnz}, s_indices.options());
    // fill in indices corresponding to the "batch" dimensions of d.
    // For batch dim `dim`, row indices[dim] is laid out as n_repeat blocks,
    // each holding every index value repeated n_repeat_interleave times.
    int64_t n_repeat_interleave = res_nnz;
    int64_t n_repeat = 1;
    for (const auto dim : c10::irange(d_batch_len)) {
      const auto dim_size = d_batch_shape[dim];
      n_repeat_interleave /= dim_size;
      // fill in indices corresponding to the "batch" dimension dim.
      // Equivalent to indices[dim].copy_(repeat_interleave(dim_index, n_repeat_interleave).repeat(n_repeat))
      const std::initializer_list<int64_t> dim_index_expanded_shape = {n_repeat, dim_size, n_repeat_interleave};
      const auto dim_index = index_buffer.slice(-1, 0, dim_size);
      const auto dim_index_expanded = dim_index.unsqueeze(0).unsqueeze_(-1).expand(dim_index_expanded_shape);
      // NOTE: indices is contiguous, so view is safe
      indices[dim].view(dim_index_expanded_shape).copy_(dim_index_expanded);
      n_repeat *= dim_size;
    }
    // fill in indices corresponding to s_indices.
    // Equivalent to indices_sparse.copy_(s_indices.repeat({1, n_repeat}))
    n_repeat = res_nnz / s_nnz;
    // Rows [d_batch_len, res_sparse_dim) hold the original sparse indices of s.
    auto indices_sparse = indices.narrow(0, d_batch_len, res_sparse_dim - d_batch_len);
    const std::initializer_list<int64_t> s_indices_expanded_shape = {-1, n_repeat, s_nnz};
    const auto s_indices_expanded = s_indices.unsqueeze(1).expand(s_indices_expanded_shape);
    indices_sparse.view(s_indices_expanded_shape).copy_(s_indices_expanded);

    return indices;
  }();
  auto* res_impl = get_sparse_impl(res);
  res_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
  res_impl->set_indices_and_values_unsafe(res_indices, res_values);
  res_impl->set_nnz_and_narrow(res_nnz);
  // By design of index expansion and that s is coalesced,
  // the result is also coalesced.
  return res._coalesced_(true);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free