Home / Class / SparseCsrMKLInterface Class — pytorch Architecture

SparseCsrMKLInterface Class — pytorch Architecture

Architecture documentation for the SparseCsrMKLInterface class in SparseCsrLinearAlgebra.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp lines 51–178

class SparseCsrMKLInterface {
 private:
  sparse_matrix_t A{nullptr};
  matrix_descr desc;

 public:
  SparseCsrMKLInterface(
      MKL_INT* col_indices,
      MKL_INT* crow_indices,
      double* values,
      MKL_INT nrows,
      MKL_INT ncols) {
    desc.type = SPARSE_MATRIX_TYPE_GENERAL;
    int retval = mkl_sparse_d_create_csr(
        &A,
        SPARSE_INDEX_BASE_ZERO,
        nrows,
        ncols,
        crow_indices,
        crow_indices + 1,
        col_indices,
        values);
    TORCH_CHECK(
        retval == 0,
        "mkl_sparse_d_create_csr failed with error code: ",
        retval);
  }

  SparseCsrMKLInterface(
      MKL_INT* col_indices,
      MKL_INT* crow_indices,
      float* values,
      MKL_INT nrows,
      MKL_INT ncols) {
    desc.type = SPARSE_MATRIX_TYPE_GENERAL;
    int retval = mkl_sparse_s_create_csr(
        &A,
        SPARSE_INDEX_BASE_ZERO,
        nrows,
        ncols,
        crow_indices,
        crow_indices + 1,
        col_indices,
        values);
    TORCH_CHECK(
        retval == 0,
        "mkl_sparse_s_create_csr failed with error code: ",
        retval);
  }

 // res(nrows, dense_ncols) = (sparse(nrows * ncols) @ dense(ncols x dense_ncols))
  inline void sparse_mm(
      float* res,
      float* dense,
      float alpha,
      float beta,
      MKL_INT nrows,
      MKL_INT ncols,
      MKL_INT dense_ncols) {
    int stat;
    if (dense_ncols == 1) {
      stat = mkl_sparse_s_mv(
        SPARSE_OPERATION_NON_TRANSPOSE,
        alpha,
        A,
        desc,
        dense,
        beta,
        res);
      TORCH_CHECK(stat == 0, "mkl_sparse_s_mv failed with error code: ", stat);
    } else {
      stat = mkl_sparse_s_mm(
        SPARSE_OPERATION_NON_TRANSPOSE,
        alpha,
        A,
        desc,
        SPARSE_LAYOUT_ROW_MAJOR,
        dense,
        nrows,
        ncols,
        beta,
        res,
        dense_ncols);
      TORCH_CHECK(stat == 0, "mkl_sparse_s_mm failed with error code: ", stat);
    }
  }

  inline void sparse_mm(
      double* res,
      double* dense,
      double alpha,
      double beta,
      MKL_INT nrows,
      MKL_INT ncols,
      MKL_INT dense_ncols) {
    int stat;
    if (dense_ncols == 1) {
      stat = mkl_sparse_d_mv(
        SPARSE_OPERATION_NON_TRANSPOSE,
        alpha,
        A,
        desc,
        dense,
        beta,
        res);
      TORCH_CHECK(stat == 0, "mkl_sparse_d_mv failed with error code: ", stat);
    }
    else {
      stat = mkl_sparse_d_mm(
        SPARSE_OPERATION_NON_TRANSPOSE,
        alpha,
        A,
        desc,
        SPARSE_LAYOUT_ROW_MAJOR,
        dense,
        nrows,
        ncols,
        beta,
        res,
        dense_ncols);
      TORCH_CHECK(stat == 0, "mkl_sparse_d_mm failed with error code: ", stat);
    }
  }

  ~SparseCsrMKLInterface() {
    mkl_sparse_destroy(A);
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free