only_sparse_compressed_binary_op_trivial_cases Class — pytorch Architecture

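Architecture documentation for only_sparse_compressed_binary_op_trivial_cases, a function template defined in SparseCsrTensorUtils.h in the pytorch codebase. It handles the trivial aliasing cases of a binary operation on sparse compressed tensors and returns true when it has already produced the result.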

Source Code

aten/src/ATen/SparseCsrTensorUtils.h lines 361–391

template <typename binary_op_t, typename binary_op_out_t>
inline bool only_sparse_compressed_binary_op_trivial_cases(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha,
    Tensor& out,
    const binary_op_t& binary_op,
    const binary_op_out_t& binary_op_out) {
  // Only sparse compressed! Just like the name says :)
  TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
  TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
  TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));

  // Bypass BLAS if there are matches in (self, other, out)
  if (self.is_same(out) && self.is_same(other)) {
    binary_op_out(self.values(), other.values(), alpha);
    return true;
  }
  if (self.is_same(other)) {
    auto [compressed_indices, plain_indices] =
        at::sparse_csr::getCompressedPlainIndices(self);
    static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
        ->set_member_tensors(
            compressed_indices,
            plain_indices,
            binary_op(self.values(), other.values(), alpha),
            self.sizes());
    return true;
  }
  return false;
}
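
The control flow above can be illustrated without ATen. The following is a minimal sketch, not PyTorch code: ToyTensor, toy_add, toy_add_out, and toy_trivial_cases are hypothetical names invented for this example, and identity is modeled as plain address equality, which only loosely mirrors Tensor::is_same. It shows the same two shortcuts: when all three operands are the same object the operation runs in place on the values, and when only self and other match the result values are computed once and written into out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a sparse compressed tensor. Only the values
// buffer is modeled; the index tensors are omitted for brevity.
struct ToyTensor {
  std::vector<double> values;
  // Assumption: identity is address equality in this toy model.
  bool is_same(const ToyTensor& t) const { return this == &t; }
};

// Out-of-place op: returns a + alpha * b elementwise.
inline std::vector<double> toy_add(
    const std::vector<double>& a,
    const std::vector<double>& b,
    double alpha) {
  std::vector<double> result(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    result[i] = a[i] + alpha * b[i];
  }
  return result;
}

// In-place op: dst += alpha * src elementwise. Safe when dst and src alias,
// because each element is read before the same element is written.
inline void toy_add_out(
    std::vector<double>& dst,
    const std::vector<double>& src,
    double alpha) {
  for (std::size_t i = 0; i < dst.size(); ++i) {
    dst[i] = dst[i] + alpha * src[i];
  }
}

// Returns true if one of the trivial aliasing cases applied and out already
// holds the result; false means the caller must take the general path.
template <typename Op, typename OpOut>
bool toy_trivial_cases(
    const ToyTensor& self,
    const ToyTensor& other,
    double alpha,
    ToyTensor& out,
    const Op& op,
    const OpOut& op_out) {
  // Case 1: self, other, and out are the same object, so update in place.
  if (self.is_same(out) && self.is_same(other)) {
    op_out(out.values, out.values, alpha);
    return true;
  }
  // Case 2: self and other match, so the sparsity pattern is shared and
  // only the values need to be combined and stored into out.
  if (self.is_same(other)) {
    out.values = op(self.values, other.values, alpha);
    return true;
  }
  return false;
}
```

A caller would typically try this helper first and fall back to the general implementation only when it returns false, for example `if (!toy_trivial_cases(a, b, alpha, out, toy_add, toy_add_out)) { /* general path */ }`.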
