SparseCsrMKLInterface Class — PyTorch Architecture
Architecture documentation for the SparseCsrMKLInterface class in SparseCsrLinearAlgebra.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp lines 51–178
// Owning wrapper around an MKL sparse CSR matrix handle (sparse_matrix_t),
// used to compute sparse @ dense products with MKL's mkl_sparse_* routines.
//
// The constructors build the handle from caller-provided, zero-based CSR
// arrays; the destructor releases it with mkl_sparse_destroy. The class is
// non-copyable because a copy would share the handle and destroy it twice.
//
// NOTE(review): per the MKL reference, mkl_sparse_?_create_csr keeps
// pointers to the input arrays instead of copying them, so col_indices,
// crow_indices and values must outlive this object — confirm at call sites.
class SparseCsrMKLInterface {
 private:
  sparse_matrix_t A{nullptr};
  // Only `type` is assigned below; value-initialize so the remaining
  // descriptor fields are never passed to MKL with indeterminate values.
  matrix_descr desc{};

 public:
  // Wraps a double-precision CSR matrix of shape (nrows, ncols).
  // Row i's entries occupy positions [crow_indices[i], crow_indices[i + 1])
  // of col_indices/values, so crow_indices must hold nrows + 1 offsets.
  // Throws (via TORCH_CHECK) if MKL fails to create the handle.
  SparseCsrMKLInterface(
      MKL_INT* col_indices,
      MKL_INT* crow_indices,
      double* values,
      MKL_INT nrows,
      MKL_INT ncols) {
    desc.type = SPARSE_MATRIX_TYPE_GENERAL;
    int retval = mkl_sparse_d_create_csr(
        &A,
        SPARSE_INDEX_BASE_ZERO,
        nrows,
        ncols,
        crow_indices, // rows_start
        crow_indices + 1, // rows_end: row i ends where row i + 1 starts
        col_indices,
        values);
    TORCH_CHECK(
        retval == 0,
        "mkl_sparse_d_create_csr failed with error code: ",
        retval);
  }

  // Single-precision counterpart of the double constructor above; same
  // CSR layout requirements and error behavior.
  SparseCsrMKLInterface(
      MKL_INT* col_indices,
      MKL_INT* crow_indices,
      float* values,
      MKL_INT nrows,
      MKL_INT ncols) {
    desc.type = SPARSE_MATRIX_TYPE_GENERAL;
    int retval = mkl_sparse_s_create_csr(
        &A,
        SPARSE_INDEX_BASE_ZERO,
        nrows,
        ncols,
        crow_indices, // rows_start
        crow_indices + 1, // rows_end: row i ends where row i + 1 starts
        col_indices,
        values);
    TORCH_CHECK(
        retval == 0,
        "mkl_sparse_s_create_csr failed with error code: ",
        retval);
  }

  // The handle is uniquely owned; copying would double-destroy it.
  SparseCsrMKLInterface(const SparseCsrMKLInterface&) = delete;
  SparseCsrMKLInterface& operator=(const SparseCsrMKLInterface&) = delete;

  // res(nrows, dense_ncols) = (sparse(nrows * ncols) @ dense(ncols x dense_ncols))
  //
  // Per the MKL reference this computes res = alpha * (A @ dense) + beta * res.
  // `dense` and `res` are treated as row-major buffers whose leading
  // dimension is dense_ncols (NOTE(review): callers must pass contiguous
  // data). nrows/ncols are kept for interface compatibility; the matrix
  // shape is already stored in the MKL handle.
  // Throws (via TORCH_CHECK) on any MKL error.
  inline void sparse_mm(
      float* res,
      float* dense,
      float alpha,
      float beta,
      MKL_INT nrows,
      MKL_INT ncols,
      MKL_INT dense_ncols) {
    (void)nrows;
    (void)ncols;
    if (dense_ncols == 0) {
      // Empty result: nothing to compute.
      return;
    }
    int stat;
    if (dense_ncols == 1) {
      // A single dense column is a plain matrix-vector product.
      stat = mkl_sparse_s_mv(
          SPARSE_OPERATION_NON_TRANSPOSE,
          alpha,
          A,
          desc,
          dense,
          beta,
          res);
      TORCH_CHECK(stat == 0, "mkl_sparse_s_mv failed with error code: ", stat);
    } else {
      // Row-major: `columns` is the number of output columns, and both
      // leading dimensions equal the row length of dense/res.
      stat = mkl_sparse_s_mm(
          SPARSE_OPERATION_NON_TRANSPOSE,
          alpha,
          A,
          desc,
          SPARSE_LAYOUT_ROW_MAJOR,
          dense,
          /*columns=*/dense_ncols,
          /*ldx=*/dense_ncols,
          beta,
          res,
          /*ldy=*/dense_ncols);
      TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
    }
  }

  // Double-precision overload of sparse_mm; same layout contract as the
  // float overload above.
  inline void sparse_mm(
      double* res,
      double* dense,
      double alpha,
      double beta,
      MKL_INT nrows,
      MKL_INT ncols,
      MKL_INT dense_ncols) {
    (void)nrows;
    (void)ncols;
    if (dense_ncols == 0) {
      // Empty result: nothing to compute.
      return;
    }
    int stat;
    if (dense_ncols == 1) {
      // A single dense column is a plain matrix-vector product.
      stat = mkl_sparse_d_mv(
          SPARSE_OPERATION_NON_TRANSPOSE,
          alpha,
          A,
          desc,
          dense,
          beta,
          res);
      TORCH_CHECK(stat == 0, "mkl_sparse_d_mv failed with error code: ", stat);
    } else {
      // Row-major: `columns` is the number of output columns, and both
      // leading dimensions equal the row length of dense/res.
      stat = mkl_sparse_d_mm(
          SPARSE_OPERATION_NON_TRANSPOSE,
          alpha,
          A,
          desc,
          SPARSE_LAYOUT_ROW_MAJOR,
          dense,
          /*columns=*/dense_ncols,
          /*ldx=*/dense_ncols,
          beta,
          res,
          /*ldy=*/dense_ncols);
      TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
    }
  }

  // Releases the MKL handle. A is only null if construction never stored a
  // handle; guard anyway so destruction is always safe.
  ~SparseCsrMKLInterface() {
    if (A != nullptr) {
      mkl_sparse_destroy(A);
    }
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free