Home / Class / _csrmm2 Class — pytorch Architecture

_csrmm2 Class — pytorch Architecture

Architecture documentation for the `_csrmm2` function template (sparse CSR × dense matrix multiply via cuSPARSE SpMM) in SparseCUDABlas.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp lines 36–135

template<typename T>
void _csrmm2(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  T *alpha, T *csrvala, int *csrrowptra, int *csrcolinda,
  T *b, int64_t ldb, T *beta, T *c, int64_t ldc,
  cudaDataType cusparse_value_type)
{
  if (csrvala == nullptr || b == nullptr || c == nullptr) return;

  cusparseOperation_t opa = convertTransToCusparseOperation(transa);
  cusparseOperation_t opb = convertTransToCusparseOperation(transb);

  // cusparseSpMM actually supports int64_t.
  // In order to support int64 here, index pointers csrrowptra, csrcolinda have to be passed as int64_t.
  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX),
    "At the moment, cusparseSpMM only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX, ".",
    "If you need this, please file an issue on GitHub."
  );

  int64_t ma = m, ka = k;
  if (transa != 'n') std::swap(ma, ka);

  cusparseSpMatDescr_t descA;
  TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
    &descA,                     /* output */
    ma, ka, nnz,                /* rows, cols, number of non zero elements */
    csrrowptra,                 /* row offsets of the sparse matrix, size = rows +1 */
    csrcolinda,                 /* column indices of the sparse matrix, size = nnz */
    csrvala,                    /* values of the sparse matrix, size = nnz */
    CUSPARSE_INDEX_32I,         /* data type of row offsets index */
    CUSPARSE_INDEX_32I,         /* data type of col indices */
    CUSPARSE_INDEX_BASE_ZERO,   /* base index of row offset and col index */
    cusparse_value_type         /* data type of values */
  ));

  int64_t kb = k, nb = n;
  if (transb != 'n') std::swap(kb, nb);

  cusparseDnMatDescr_t descB;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
    &descB,               /* output */
    kb, nb, ldb,          /* rows, cols, leading dimension */
    b,                    /* values */
    cusparse_value_type,  /* data type of values */
    CUSPARSE_ORDER_COL    /* memory layout, ONLY column-major is supported now */
  ));

  cusparseDnMatDescr_t descC;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
    &descC,               /* output */
    m, n, ldc,            /* rows, cols, leading dimension */
    c,                    /* values */
    cusparse_value_type,  /* data type of values */
    CUSPARSE_ORDER_COL    /* memory layout, ONLY column-major is supported now */
  ));


  auto handle = at::cuda::getCurrentCUDASparseHandle();
  // ALG1 is broken on SM89 as of CUDA 11.8+
#if !defined(USE_ROCM)
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  auto default_alg = prop->major == 8 && prop->minor == 9 ? CUSPARSE_SPMM_CSR_ALG2 : CUSPARSE_SPMM_CSR_ALG1;
#else
  auto default_alg = CUSPARSE_SPMM_CSR_ALG1;
#endif

  // cusparseSpMM_bufferSize returns the bufferSize that can be used by cusparseSpMM
  size_t bufferSize;
  TORCH_CUDASPARSE_CHECK(cusparseSpMM_bufferSize(
    handle, opa, opb,
    alpha,
    descA, descB,
    beta,
    descC,
    cusparse_value_type,      /* data type in which the computation is executed */
    default_alg,              /* default computing algorithm for CSR sparse matrix format */
    &bufferSize               /* output */
  ));

  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  auto dataPtr = allocator.allocate(bufferSize);

  TORCH_CUDASPARSE_CHECK(cusparseSpMM(
    handle, opa, opb,
    alpha,
    descA, descB,
    beta,
    descC,
    cusparse_value_type,      /* data type in which the computation is executed */
    default_alg,              /* default computing algorithm for CSR sparse matrix format */
    dataPtr.get()             /* external buffer */
  ));

  TORCH_CUDASPARSE_CHECK(cusparseDestroySpMat(descA));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyDnMat(descB));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyDnMat(descC));

  // TODO: Proper fix is to create real descriptor classes
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free