Home / Class / generateBlockCSRMatrix Class — pytorch Architecture

generateBlockCSRMatrix Class — pytorch Architecture

Architecture documentation for generateBlockCSRMatrix, a function template defined in pack_block_sparse.h in the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h lines 127–190

template <typename INDICES_DTYPE>
std::unique_ptr<BCSRMatrix> generateBlockCSRMatrix(
    const uint8_t* a,
    const size_t N,
    const size_t K,
    const uint32_t row_block_size,
    const uint32_t col_block_size,
    const uint8_t* zero_points) {
  assert(K > 0);
  std::unique_ptr<TypedBCSRMatrix<INDICES_DTYPE>> bcsr_mat =
      std::make_unique<TypedBCSRMatrix<INDICES_DTYPE>>();
  auto& row_values = bcsr_mat->row_values.vector();
  auto& col_indices = bcsr_mat->col_indices.vector();
  auto& values = bcsr_mat->values.vector();

  const uint32_t num_row_blocks = (N + row_block_size - 1) / row_block_size;
  // K must be > 0
  const uint32_t num_col_blocks = (K + col_block_size - 1) / col_block_size;

  row_values.reserve(num_row_blocks);
  uint32_t num_nnz_blocks{0};
  row_values.push_back(num_nnz_blocks);
  for (uint32_t i = 0; i < num_row_blocks; ++i) {
    for (uint32_t j = 0; j < num_col_blocks; ++j) {
      bool block_zero{true};
      for (uint32_t ib = 0; ib < row_block_size; ++ib) {
        uint32_t row_index = i * row_block_size + ib;
        if PYTORCH_QNNP_UNLIKELY(row_index >= N) {
          break;
        }
        for (uint32_t jb = 0; jb < col_block_size; ++jb) {
          uint32_t col_index = j * col_block_size + jb;
          if PYTORCH_QNNP_UNLIKELY(col_index >= K) {
            goto block_scanned;
          }
          if (*(a + row_index * K + col_index) != zero_points[row_index]) {
            block_zero = false;
            goto block_scanned;
          }
        }
      }
block_scanned:
      if (!block_zero) {
        col_indices.push_back(j);
        num_nnz_blocks++;
        for (uint32_t ib = 0; ib < row_block_size; ++ib) {
          uint32_t row_index = i * row_block_size + ib;
          if PYTORCH_QNNP_UNLIKELY(row_index >= N) {
            for (; row_index < (num_row_blocks * row_block_size); row_index++) {
              for (uint32_t jb = 0; jb < col_block_size; ++jb) {
                values.push_back(zero_points[N-1]);
              }
            }
            break;
          }
          for (uint32_t jb = 0; jb < col_block_size; ++jb) {
            uint32_t col_index = j * col_block_size + jb;
            if PYTORCH_QNNP_UNLIKELY(col_index >= K) {
              values.push_back(zero_points[row_index]);
            } else {
              uint8_t val = *(a + row_index * K + col_index);
              values.push_back(val);
            }
          }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free