generateBlockCSRMatrix Function Template — pytorch Architecture
Architecture documentation for the generateBlockCSRMatrix function template in pack_block_sparse.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/qnnpack/include/pack_block_sparse.h lines 127–190
template <typename INDICES_DTYPE>
std::unique_ptr<BCSRMatrix> generateBlockCSRMatrix(
const uint8_t* a,
const size_t N,
const size_t K,
const uint32_t row_block_size,
const uint32_t col_block_size,
const uint8_t* zero_points) {
assert(K > 0);
std::unique_ptr<TypedBCSRMatrix<INDICES_DTYPE>> bcsr_mat =
std::make_unique<TypedBCSRMatrix<INDICES_DTYPE>>();
auto& row_values = bcsr_mat->row_values.vector();
auto& col_indices = bcsr_mat->col_indices.vector();
auto& values = bcsr_mat->values.vector();
const uint32_t num_row_blocks = (N + row_block_size - 1) / row_block_size;
// K must be > 0
const uint32_t num_col_blocks = (K + col_block_size - 1) / col_block_size;
row_values.reserve(num_row_blocks);
uint32_t num_nnz_blocks{0};
row_values.push_back(num_nnz_blocks);
for (uint32_t i = 0; i < num_row_blocks; ++i) {
for (uint32_t j = 0; j < num_col_blocks; ++j) {
bool block_zero{true};
for (uint32_t ib = 0; ib < row_block_size; ++ib) {
uint32_t row_index = i * row_block_size + ib;
if PYTORCH_QNNP_UNLIKELY(row_index >= N) {
break;
}
for (uint32_t jb = 0; jb < col_block_size; ++jb) {
uint32_t col_index = j * col_block_size + jb;
if PYTORCH_QNNP_UNLIKELY(col_index >= K) {
goto block_scanned;
}
if (*(a + row_index * K + col_index) != zero_points[row_index]) {
block_zero = false;
goto block_scanned;
}
}
}
block_scanned:
if (!block_zero) {
col_indices.push_back(j);
num_nnz_blocks++;
for (uint32_t ib = 0; ib < row_block_size; ++ib) {
uint32_t row_index = i * row_block_size + ib;
if PYTORCH_QNNP_UNLIKELY(row_index >= N) {
for (; row_index < (num_row_blocks * row_block_size); row_index++) {
for (uint32_t jb = 0; jb < col_block_size; ++jb) {
values.push_back(zero_points[N-1]);
}
}
break;
}
for (uint32_t jb = 0; jb < col_block_size; ++jb) {
uint32_t col_index = j * col_block_size + jb;
if PYTORCH_QNNP_UNLIKELY(col_index >= K) {
values.push_back(zero_points[row_index]);
} else {
uint8_t val = *(a + row_index * K + col_index);
values.push_back(val);
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free