_csrmm2 Function Template — pytorch Architecture
Architecture documentation for the _csrmm2 function template in SparseCUDABlas.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp lines 36–135
template<typename T>
void _csrmm2(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  T *alpha, T *csrvala, int *csrrowptra, int *csrcolinda,
  T *b, int64_t ldb, T *beta, T *c, int64_t ldc,
  cudaDataType cusparse_value_type)
{
  // Computes C = alpha * op(A) * op(B) + beta * C via cusparseSpMM, where A
  // is an m x k sparse matrix in CSR format (csrrowptra/csrcolinda/csrvala)
  // and B, C are column-major dense matrices. transa/transb select op()
  // ('n' = no transpose). All pointers are device pointers.
  //
  // Nothing to do when any operand is absent (e.g. an all-zero sparse input).
  if (csrvala == nullptr || b == nullptr || c == nullptr) return;
  cusparseOperation_t opa = convertTransToCusparseOperation(transa);
  cusparseOperation_t opb = convertTransToCusparseOperation(transb);
  // cusparseSpMM actually supports int64_t.
  // In order to support int64 here, index pointers csrrowptra, csrcolinda have to be passed as int64_t.
  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX),
    "At the moment, cusparseSpMM only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX, ".",
    "If you need this, please file an issue on GitHub."
  );
  // Descriptor handles are opaque pointers. Keep them null until created so
  // the cleanup path below can tell exactly which ones need destroying if a
  // later cuSPARSE call (or the workspace allocation) throws.
  cusparseSpMatDescr_t descA = nullptr;
  cusparseDnMatDescr_t descB = nullptr;
  cusparseDnMatDescr_t descC = nullptr;
  try {
    int64_t ma = m, ka = k;
    if (transa != 'n') std::swap(ma, ka); // descriptor wants op(A)'s pre-op dims
    TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
      &descA,                     /* output */
      ma, ka, nnz,                /* rows, cols, number of non zero elements */
      csrrowptra,                 /* row offsets of the sparse matrix, size = rows + 1 */
      csrcolinda,                 /* column indices of the sparse matrix, size = nnz */
      csrvala,                    /* values of the sparse matrix, size = nnz */
      CUSPARSE_INDEX_32I,         /* data type of row offsets index */
      CUSPARSE_INDEX_32I,         /* data type of col indices */
      CUSPARSE_INDEX_BASE_ZERO,   /* base index of row offset and col index */
      cusparse_value_type         /* data type of values */
    ));
    int64_t kb = k, nb = n;
    if (transb != 'n') std::swap(kb, nb);
    TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
      &descB,                     /* output */
      kb, nb, ldb,                /* rows, cols, leading dimension */
      b,                          /* values */
      cusparse_value_type,        /* data type of values */
      CUSPARSE_ORDER_COL          /* memory layout, ONLY column-major is supported now */
    ));
    TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
      &descC,                     /* output */
      m, n, ldc,                  /* rows, cols, leading dimension */
      c,                          /* values */
      cusparse_value_type,        /* data type of values */
      CUSPARSE_ORDER_COL          /* memory layout, ONLY column-major is supported now */
    ));
    auto handle = at::cuda::getCurrentCUDASparseHandle();
    // ALG1 is broken on SM89 as of CUDA 11.8+
#if !defined(USE_ROCM)
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    auto default_alg = prop->major == 8 && prop->minor == 9 ? CUSPARSE_SPMM_CSR_ALG2 : CUSPARSE_SPMM_CSR_ALG1;
#else
    auto default_alg = CUSPARSE_SPMM_CSR_ALG1;
#endif
    // cusparseSpMM_bufferSize returns the bufferSize that can be used by cusparseSpMM
    size_t bufferSize;
    TORCH_CUDASPARSE_CHECK(cusparseSpMM_bufferSize(
      handle, opa, opb,
      alpha,
      descA, descB,
      beta,
      descC,
      cusparse_value_type,        /* data type in which the computation is executed */
      default_alg,                /* default computing algorithm for CSR sparse matrix format */
      &bufferSize                 /* output */
    ));
    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
    auto dataPtr = allocator.allocate(bufferSize);
    TORCH_CUDASPARSE_CHECK(cusparseSpMM(
      handle, opa, opb,
      alpha,
      descA, descB,
      beta,
      descC,
      cusparse_value_type,        /* data type in which the computation is executed */
      default_alg,                /* default computing algorithm for CSR sparse matrix format */
      dataPtr.get()               /* external buffer */
    ));
  } catch (...) {
    // Best-effort cleanup so descriptors are not leaked when any of the calls
    // above throws (TORCH_CUDASPARSE_CHECK / TORCH_CHECK raise c10::Error).
    // Destroy statuses are deliberately ignored: an error is already in flight.
    if (descC != nullptr) cusparseDestroyDnMat(descC);
    if (descB != nullptr) cusparseDestroyDnMat(descB);
    if (descA != nullptr) cusparseDestroySpMat(descA);
    throw;
  }
  // Success path: destroy with status checking, as before.
  TORCH_CUDASPARSE_CHECK(cusparseDestroySpMat(descA));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyDnMat(descB));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyDnMat(descC));
  // TODO: Proper fix is to create real descriptor classes
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free