Home / Class / sparse_matmul_kernel Class — pytorch Architecture

sparse_matmul_kernel Class — pytorch Architecture

Architecture documentation for the sparse_matmul_kernel function template (indexed on this site as a class) in SparseMatMul.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/sparse/SparseMatMul.cpp lines 193–261

template <typename scalar_t>
void sparse_matmul_kernel(
    Tensor& output,
    const Tensor& mat1,
    const Tensor& mat2) {
  /*
    Sparse-sparse matrix product of `mat1` and `mat2` (sparse tensors in COO
    format), written into the indices/values tensors of `output`.

    Both operands are converted to CSR, multiplied through `_csr_matmult`,
    and the resulting row-pointer array is handed to `csr_to_coo` to fill the
    row-index half of `output`'s indices.
  */

  // Wrap a tensor's data in an accessor that steps by the stride of its last
  // dimension. One helper for index tensors, one for value tensors.
  const auto make_index_accessor = [](const Tensor& t) {
    return StridedRandomAccessor<int64_t>(t.data_ptr<int64_t>(), t.stride(-1));
  };
  const auto make_value_accessor = [](const Tensor& t) {
    return StridedRandomAccessor<scalar_t>(
        t.data_ptr<scalar_t>(), t.stride(-1));
  };

  // Result has as many rows as `mat1` and as many columns as `mat2`.
  auto n_rows = mat1.size(0);
  auto n_cols = mat2.size(1);

  // The CSR copies must stay alive for as long as the accessors below are
  // used, since the accessors only hold raw pointers into them.
  const auto lhs_csr = mat1.to_sparse_csr();
  const auto rhs_csr = mat2.to_sparse_csr();

  auto lhs_crow = make_index_accessor(lhs_csr.crow_indices());
  auto lhs_col = make_index_accessor(lhs_csr.col_indices());
  auto lhs_val = make_value_accessor(lhs_csr.values());

  auto rhs_crow = make_index_accessor(rhs_csr.crow_indices());
  auto rhs_col = make_index_accessor(rhs_csr.col_indices());
  auto rhs_val = make_value_accessor(rhs_csr.values());

  // Number of entries to reserve in the output, as reported by the
  // structure-only pass over both operands.
  const auto nnz = _csr_matmult_maxnnz(
      n_rows, n_cols, lhs_crow, lhs_col, rhs_crow, rhs_col);

  auto result_indices = output._indices();
  auto result_values = output._values();

  // Scratch row-pointer array for the CSR result: one slot per row plus one.
  Tensor result_crow = at::empty({n_rows + 1}, kLong);
  at::native::resize_output(result_indices, {2, nnz});
  at::native::resize_output(result_values, nnz);

  // Row 0 of the COO index matrix holds row indices, row 1 column indices.
  Tensor result_rows = result_indices.select(0, 0);
  Tensor result_cols = result_indices.select(0, 1);

  int64_t* const result_crow_ptr = result_crow.data_ptr<int64_t>();
  int64_t* const result_cols_ptr = result_cols.data_ptr<int64_t>();
  scalar_t* const result_values_ptr = result_values.data_ptr<scalar_t>();

  // TODO: replace with a CSR @ CSC kernel for better performance.
  _csr_matmult(
      n_rows,
      n_cols,
      lhs_crow,
      lhs_col,
      lhs_val,
      rhs_crow,
      rhs_col,
      rhs_val,
      result_crow_ptr,
      result_cols_ptr,
      result_values_ptr);

  // Turn the CSR row pointers into explicit per-entry row indices, then flag
  // the output as coalesced.
  csr_to_coo(n_rows, result_crow_ptr, result_rows.data_ptr<int64_t>());
  output._coalesced_(true);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free