intersection_binary_op_sparse_dense_out Function — PyTorch Architecture
Architecture documentation for the intersection_binary_op_sparse_dense_out function in SparseTensorMath.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/sparse/SparseTensorMath.cpp lines 831–994
// Writes op(d, s_) into the sparse output `res` and returns it, where `d` is
// a dense tensor, `s_` is a sparse COO tensor, and `op` is an element-wise
// binary function evaluated only at the positions where the (broadcasted)
// sparse operand has stored values ("intersection" semantics).
//
// Parameters:
//   d        - dense operand.
//   s_       - sparse operand; may be non-coalesced.
//   res      - sparse output, resized and overwritten in place.
//   op_name  - not referenced in this body; presumably kept for diagnostics
//              at call sites — TODO confirm against callers.
//   op       - callable invoked as op(dense_values, sparse_values); expected
//              to broadcast like a standard element-wise binary op.
//   coalesce - when true, force-coalesce s_ before computing.
template <typename binary_func_t>
static Tensor& intersection_binary_op_sparse_dense_out(
    const Tensor& d,
    const SparseTensor& s_,
    Tensor& res,
    const char* const op_name,
    const binary_func_t& op,
    const bool coalesce = false) {
  // compute broadcasted shape.
  const auto res_shape = infer_size(d.sizes(), s_.sizes());

  // Short-circuit if either s_ or d is empty: the result has zero stored
  // values, so build an empty (and trivially coalesced) sparse tensor of the
  // broadcasted shape.
  if (!s_._nnz() || !s_.numel() || !d.numel()) {
    const int64_t dense_dim = s_.dense_dim();
    const int64_t sparse_dim = static_cast<int64_t>(res_shape.size()) - dense_dim;
    const int64_t nnz = 0;
    const auto indices = at::empty({sparse_dim, nnz}, s_._indices().options());
    // Values keep the trailing (dense-dim) shape of s_'s values, but with
    // nnz == 0 rows and the output's dtype.
    auto res_values_shape = s_._values().sizes().vec();
    res_values_shape[0] = nnz;
    const auto values = at::empty(res_values_shape, s_._values().options().dtype(res.scalar_type()));
    auto* res_impl = get_sparse_impl(res);
    res_impl->raw_resize_(sparse_dim, dense_dim, /*size=*/res_shape);
    res_impl->set_indices_and_values_unsafe(indices, values);
    res_impl->set_nnz_and_narrow(nnz);
    // An empty sparse tensor is coalesced by definition.
    return res._coalesced_(true);
  }

  const auto d_dim = d.dim();
  const auto s_dim = s_.dim();

  // Always coalesce when sparse broadcasts over dense,
  // because new sparse dimensions are created and
  // repeated indices have to be eliminated because of that.
  const auto s = (coalesce || d_dim > s_dim) ? s_.coalesce() : s_;

  const auto sparse_dim = s.sparse_dim();
  const auto dense_dim = s.dense_dim();

  const auto s_indices = s._indices();
  const auto s_values = s._values();

  // Applies `op` between a pre-filtered dense tensor and s_values, then
  // installs the result into `res` reusing s's indices unchanged (these are
  // the paths where the output nnz equals s._nnz()).
  const auto apply_op = [&](const Tensor& d_filtered) -> Tensor& {
    const auto res_indices = s_indices.clone();
    // to(res.scalar_type) is only performed when both d and s are 0-dim.
    // This ensures right type promotions with the following rules:
    // op(0-dim, 0-dim).dtype == <common dtype>
    // op(0-dim, ge-1-dim).dtype == <ge-1-dim>.dtype,
    // where ge-1-dim is a tensor with dim >= 1.
    // We do not cast if op is performed in-place.
    // The cast is required if s is 0-dim non-coalesced tensor and d is 0-dim.
    // This is because s.values is at least 1D, so
    // op(s.values, d).dtype == s.values.dtype, but we want
    // op(s.values, d).dtype == <common dtype>.
    const auto values = op(d_filtered, s_values);
    const auto res_values = is_same_tensor(s_, res) ? values : values.to(res.scalar_type());
    auto* res_impl = get_sparse_impl(res);
    res_impl->raw_resize_(sparse_dim, dense_dim, res_shape);
    res_impl->set_indices_and_values_unsafe(res_indices, res_values);
    res_impl->set_nnz_and_narrow(s._nnz());
    // Coalescedness is inherited from s: this path neither creates nor
    // merges indices.
    return res._coalesced_(s.is_coalesced());
  };

  // Easiest case: only dense dimensions intersect.
  // This means only value tensors interact.
  if (d_dim <= dense_dim) {
    return apply_op(d);
  }

  // Now we have intersection between sparse and dense dims.
  // sparse_dim_intersec: how many of s's sparse dims overlap d's dims;
  // d_start_dim_intersec / s_start_dim_intersec: where the overlap begins in
  // each operand after right-aligning the two shapes (broadcasting rules).
  const auto sparse_dim_intersec = std::min(sparse_dim, d_dim - dense_dim);
  const auto d_start_dim_intersec = std::max<int64_t>(0, d_dim - s_dim);
  const auto s_start_dim_intersec = std::max<int64_t>(0, s_dim - d_dim);

  // Index d with s_indices to find values which
  // interact with s_values.
  const auto d_filtered = [&]() -> Tensor {
    using at::indexing::Slice;
    using at::indexing::Ellipsis;
    using at::indexing::TensorIndex;

    std::vector<TensorIndex> intersec_indices;
    intersec_indices.reserve(d_dim);

    // Leading dims of d not covered by s are kept whole (Ellipsis).
    if (d_start_dim_intersec) {
      intersec_indices.emplace_back(Ellipsis);
    }
    // Advanced-index each overlapping dim of d with the matching row of
    // s_indices, gathering exactly the dense elements aligned with s's
    // stored values.
    for (const auto i : c10::irange(sparse_dim_intersec)) {
      const auto s_idx = s_start_dim_intersec + i;
      intersec_indices.emplace_back(s_indices[s_idx]);
    }
    // Trailing (dense) dims of d are taken whole.
    for (auto i = d_start_dim_intersec + sparse_dim_intersec; i < d_dim; ++i) {
      intersec_indices.emplace_back(Slice());
    }
    // we need to expand d in the dimensions it is being indexed into
    // to avoid out of bound indices
    const auto d_expanded_shape = std::vector<int64_t>(
        res_shape.end() - d_dim, res_shape.end());
    return d.expand(d_expanded_shape).index(intersec_indices);
  }();

  // When dims match or sparse is "larger", the result nnz is the same,
  // so only values get modified.
  if (s_dim >= d_dim) {
    return apply_op(d_filtered);
  }

  // Otherwise nnz gets larger, and both indices and values need an update.
  // d's leading dims that broadcast over s become new sparse "batch" dims of
  // the result.
  const auto d_batch_shape = d.sizes().slice(0, d_start_dim_intersec);
  const auto d_batch_len = static_cast<int64_t>(d_batch_shape.size());

  int64_t batch_count = 1;
  int64_t max_batch_dim = 0;
  // batch_count: product of the batch dims (the nnz multiplier);
  // max_batch_dim: largest single batch dim (sizes the shared index buffer).
  std::tie(batch_count, max_batch_dim) = [d_batch_shape]() -> std::tuple<int64_t, int64_t> {
    int64_t batch_count = 1;
    int64_t max_batch_dim = 0;
    for (const auto& b : d_batch_shape) {
      batch_count *= b;
      max_batch_dim = std::max(b, max_batch_dim);
    }
    return std::make_tuple(batch_count, max_batch_dim);
  }();

  const auto res_sparse_dim = static_cast<int64_t>(d_batch_shape.size()) + sparse_dim;
  const auto res_dense_dim = dense_dim;
  const auto s_nnz = s._nnz();
  // Each stored element of s is replicated once per batch position.
  const auto res_nnz = batch_count * s_nnz;
  auto res_values_shape = s_values.sizes().vec();
  res_values_shape[0] = res_nnz;
  // op(d_filtered, s_values) broadcasts over the batch dims; flatten those
  // dims into the leading nnz dimension of the output values.
  const auto res_values = op(d_filtered, s_values).reshape(res_values_shape);

  // Build the (res_sparse_dim x res_nnz) index matrix: batch rows enumerate
  // every batch coordinate; the remaining rows tile s_indices to match.
  const auto res_indices = [&]() -> Tensor {
    // 0..max_batch_dim-1, sliced per-dim below as the coordinate source.
    const auto index_buffer = at::arange(max_batch_dim, s_indices.options());
    auto indices = at::empty({res_sparse_dim, res_nnz}, s_indices.options());
    // fill in indices corresponding to the "batch" dimensions of d.
    int64_t n_repeat_interleave = res_nnz;
    int64_t n_repeat = 1;
    for (const auto dim : c10::irange(d_batch_len)) {
      const auto dim_size = d_batch_shape[dim];
      n_repeat_interleave /= dim_size;
      // fill in indices corresponding to the "batch" dimension dim.
      // Equivalent to indices[dim].copy_(repeat_interleave(dim_index, n_repeat_interleave).repeat(n_repeat))
      const std::initializer_list<int64_t> dim_index_expanded_shape = {n_repeat, dim_size, n_repeat_interleave};
      const auto dim_index = index_buffer.slice(-1, 0, dim_size);
      const auto dim_index_expanded = dim_index.unsqueeze(0).unsqueeze_(-1).expand(dim_index_expanded_shape);
      // NOTE: indices is contiguous, so view is safe
      indices[dim].view(dim_index_expanded_shape).copy_(dim_index_expanded);
      n_repeat *= dim_size;
    }
    // fill in indices corresponding to s_indices.
    // Equivalent to indices_sparse.copy(s_indices.repeat({1, n_repeat})
    n_repeat = res_nnz / s_nnz;
    auto indices_sparse = indices.narrow(0, d_batch_len, res_sparse_dim - d_batch_len);
    const std::initializer_list<int64_t> s_indices_expanded_shape = {-1, n_repeat, s_nnz};
    const auto s_indices_expanded = s_indices.unsqueeze(1).expand(s_indices_expanded_shape);
    indices_sparse.view(s_indices_expanded_shape).copy_(s_indices_expanded);
    return indices;
  }();

  auto* res_impl = get_sparse_impl(res);
  res_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
  res_impl->set_indices_and_values_unsafe(res_indices, res_values);
  res_impl->set_nnz_and_narrow(res_nnz);
  // By design of index expansion and that s is coalesced
  // (forced above since d_dim > s_dim on this path),
  // the result is also coalesced.
  return res._coalesced_(true);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free