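_csr_matmult_maxnnz Function - PyTorch Architecture
Architecture documentation for the _csr_matmult_maxnnz function template in SparseMatMul.cpp from the PyTorch codebase. Given two CSR matrices A and B, it counts the structurally nonzero entries of C = A@B, which is the buffer size needed to hold C in CSR form.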
Source Code
aten/src/ATen/native/sparse/SparseMatMul.cpp lines 53–86
template<typename index_t_ptr = int64_t*>
int64_t _csr_matmult_maxnnz(
    const int64_t n_row,
    const int64_t n_col,
    const index_t_ptr Ap,
    const index_t_ptr Aj,
    const index_t_ptr Bp,
    const index_t_ptr Bj) {
  /*
    Compute needed buffer size for matrix `C` in `C = A@B` operation.
    The matrices should be in proper CSR structure, and their dimensions
    should be compatible.
  */
  std::vector<int64_t> mask(n_col, -1);
  int64_t nnz = 0;
  for (const auto i : c10::irange(n_row)) {
    int64_t row_nnz = 0;
    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; jj++) {
      int64_t j = Aj[jj];
      for (int64_t kk = Bp[j]; kk < Bp[j + 1]; kk++) {
        int64_t k = Bj[kk];
        if (mask[k] != i) {
          mask[k] = i;
          row_nnz++;
        }
      }
    }
    int64_t next_nnz = nnz + row_nnz;
    nnz = next_nnz;
  }
  return nnz;
}
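To try the counting loop outside ATen, here is a minimal sketch of a hypothetical standalone port. The name csr_matmult_maxnnz, the std::vector parameters, and the plain for loop in place of c10::irange are assumptions made so the code compiles on its own. The loop body follows the excerpt above, except that it adds to nnz directly instead of summing a per-row counter. The key trick is that mask[k] stores the index of the last row of C in which column k was counted, so the mask never has to be reset between rows. The total work is O(n_col) for the mask plus one step per (A nonzero, matching B nonzero) pair.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical standalone port of _csr_matmult_maxnnz (not a PyTorch API).
//   Ap, Aj: CSR row pointers and column indices of A (n_row x inner)
//   Bp, Bj: CSR row pointers and column indices of B (inner x n_col)
// Returns the number of structurally nonzero entries of C = A @ B.
int64_t csr_matmult_maxnnz(int64_t n_row, int64_t n_col,
                           const std::vector<int64_t>& Ap,
                           const std::vector<int64_t>& Aj,
                           const std::vector<int64_t>& Bp,
                           const std::vector<int64_t>& Bj) {
  // A must have one row pointer per row plus a terminating one.
  assert(Ap.size() == static_cast<std::size_t>(n_row) + 1);
  // mask[k] == i means column k was already counted for row i of C.
  // Tagging with the row index avoids clearing the mask for every row.
  std::vector<int64_t> mask(static_cast<std::size_t>(n_col), -1);
  int64_t nnz = 0;
  for (int64_t i = 0; i < n_row; ++i) {
    // Walk the stored entries A[i, j] of row i.
    for (int64_t jj = Ap[i]; jj < Ap[i + 1]; ++jj) {
      const int64_t j = Aj[jj];
      // Each stored entry B[j, k] contributes to C[i, k].
      for (int64_t kk = Bp[j]; kk < Bp[j + 1]; ++kk) {
        const int64_t k = Bj[kk];
        if (mask[k] != i) {  // first contribution to C[i, k]
          mask[k] = i;
          ++nnz;
        }
      }
    }
  }
  return nnz;
}
```

For example, with A = [[1, 0, 1], [0, 1, 0]] (Ap = {0, 2, 3}, Aj = {0, 2, 1}) and B = [[1, 1, 0], [0, 0, 1], [0, 1, 0]] (Bp = {0, 2, 3, 4}, Bj = {0, 1, 2, 1}), row 0 of C receives column 1 from both B rows 0 and 2 but counts it once, so the call returns 3. Because the function never reads the stored values, it counts structural nonzeros only: entries that cancel numerically (such as A = [[1, -1]] times B = [[1], [1]]) are still counted, so the result is an upper bound on the true nonzero count of C.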