Home / Class / _csr_matmult Class — pytorch Architecture

_csr_matmult Class — pytorch Architecture

Architecture documentation for the _csr_matmult function template in SparseMatMul.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/sparse/SparseMatMul.cpp lines 88–190

template<typename index_t_ptr, typename scalar_t_ptr>
void _csr_matmult(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const scalar_t_ptr Ax,
    const index_t_ptr Bp,
    const index_t_ptr Bj,
    const scalar_t_ptr Bx,
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cp[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename index_t_ptr::value_type Cj[],
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    typename scalar_t_ptr::value_type Cx[]) {
  /*
    Compute CSR entries for matrix C = A@B.

    The matrices `A` and `B` should be in proper CSR structure, and their
    dimensions should be compatible.

    Inputs:
      `n_row`         - number of rows in A
      `n_col`         - number of columns in B
      `Ap[n_row+1]`   - row pointer
      `Aj[nnz(A)]`    - column indices
      `Ax[nnz(A)]`    - nonzeros
      `Bp[?]`         - row pointer; read at `Bp[j]` and `Bp[j+1]` for every
                        column index `j` stored in `Aj`
      `Bj[nnz(B)]`    - column indices
      `Bx[nnz(B)]`    - nonzeros
    Outputs:
      `Cp[n_row+1]` - row pointer
      `Cj[nnz(C)]`  - column indices
      `Cx[nnz(C)]`  - nonzeros

    Note:
      Output arrays Cp, Cj, and Cx must be preallocated.
      Every column touched while accumulating a row is emitted, including
      entries whose accumulated sum happens to be zero.
  */
  using index_t = typename index_t_ptr::value_type;
  using scalar_t = typename scalar_t_ptr::value_type;

  // Per-column scratch space, reused for every output row:
  //   next[k] == -1  -> column k has not been touched in the current row;
  //   otherwise next[k] links to the previously touched column
  //   (or holds -2 at the end of the list).
  //   sums[k] accumulates the value of C(i, k).
  std::vector<index_t> next(n_col, -1);
  std::vector<scalar_t> sums(n_col, 0);

  // Running count of entries written to Cj / Cx so far.
  int64_t nnz = 0;

  Cp[0] = 0;

  for (const auto i : c10::irange(n_row)) {
    // -2 is the end-of-list sentinel; it must differ from -1 so that the
    // first column inserted is still recognised as "already touched".
    index_t head = -2;
    index_t length = 0;

    index_t jj_start = Ap[i];
    index_t jj_end = Ap[i + 1];
    for (const auto jj : c10::irange(jj_start, jj_end)) {
      index_t j = Aj[jj];
      scalar_t v = Ax[jj];

      // Scatter v * B(j, :) into the accumulator for row i of C.
      index_t kk_start = Bp[j];
      index_t kk_end = Bp[j + 1];
      for (const auto kk : c10::irange(kk_start, kk_end)) {
        index_t k = Bj[kk];

        sums[k] += v * Bx[kk];

        if (next[k] == -1) {
          // First contribution to column k in this row: push it onto the
          // front of the linked list of touched columns.
          next[k] = head;
          head = k;
          length++;
        }
      }
    }

    for ([[maybe_unused]] const auto jj : c10::irange(length)) {
      // NOTE: the linked list that encodes col indices
      // is not guaranteed to be sorted.
      Cj[nnz] = head;
      Cx[nnz] = sums[head];
      nnz++;

      index_t temp = head;
      head = next[head];

      // Reset the scratch entries so they are clean for the next row.
      next[temp] = -1;
      sums[temp] = 0;
    }

    // Make sure that col indices are sorted.
    // TODO: a better approach is to implement a CSR @ CSC kernel.
    // NOTE: Cj and Cx are accessed with stride 1, i.e. they are expected to
    // be contiguous!
    // The index accessor is typed on index_t (not a hard-coded int64_t) so
    // that it matches the element type of Cj for any index type.
    auto col_indices_accessor = StridedRandomAccessor<index_t>(Cj + nnz - length, 1);
    auto val_accessor = StridedRandomAccessor<scalar_t>(Cx + nnz - length, 1);
    auto kv_accessor = CompositeRandomAccessorCPU<
      decltype(col_indices_accessor), decltype(val_accessor)
    >(col_indices_accessor, val_accessor);
    // Sort the entries just written for row i, comparing on the first
    // component (the column index) only.
    std::sort(kv_accessor, kv_accessor + length, [](const auto& lhs, const auto& rhs) -> bool {
        return get<0>(lhs) < get<0>(rhs);
    });

    Cp[i + 1] = nnz;
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free