tinygemm_kernel Function Template — PyTorch Architecture
Architecture documentation for the tinygemm_kernel function template (parameterized by the non-type template parameters BLOCK_M and BLOCK_N) in int8mm_kernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/int8mm_kernel.cpp lines 31–99
// Micro-kernel computing one BLOCK_M x BLOCK_N tile of the output C for a
// bf16-activation / int8-weight matmul with per-column dequantization scales,
// using AVX512.
//
//   A      : bf16 activations, row-major; row r of the tile starts at A + r * lda.
//   B      : int8 weights laid out so the K-vector feeding output column c is
//            contiguous at B + c * ldb.
//   scales : one bf16 scale per output column; B is dequantized as
//            float(B) * scale[col] before accumulation.
//   C      : bf16 output tile; element (r, c) is written to C + r * ldc + c.
//   K      : reduction depth. NOTE(review): the loop consumes full 16-element
//            vectors, so this assumes K is a multiple of 16 — confirm callers
//            guarantee this.
//
// Each of the ROWS * COLS accumulators holds 16 per-lane partial products;
// they are horizontally reduced to a scalar at the end.
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {
  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N;

  // Prefetch distance along K: 4 vector steps (4 * 16 elements) ahead.
  constexpr int PREFETCH_SIZE_K = 16 * 4;

  __m512 va;               // current 16-wide slice of one row of A
  __m512 vb[COLS];         // dequantized 16-wide slices of the B columns
  __m512 vc[ROWS * COLS];  // per-lane partial dot products (reduced at the end)
  __m512 scale[COLS];      // per-column scales, broadcast across all 16 lanes

  // Broadcast each column's dequantization scale once, outside the K loop.
  auto load_scale = [&](int i) {
    float ss = static_cast<float>(scales[i]);
    scale[i] = _mm512_set1_ps(ss);
  };
  c10::ForcedUnroll<COLS>{}(load_scale);

  // Zero all accumulators.
  auto loadc = [&](auto i) {
    vc[i] = _mm512_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  // One FMA step for accumulator i = row * COLS + col at depth k.
  // ForcedUnroll visits i in increasing order, so the A slice loaded when
  // col == 0 is reused for the remaining columns of that row, and the B
  // slices loaded when row == 0 are reused by the remaining rows.
  auto compute = [&](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      // Unaligned load: A + row * lda + k is not 32-byte aligned for
      // arbitrary lda/k, and _mm256_loadu_si256 matches the aligned form's
      // speed on aligned data on modern x86.
      __m256i a16 = _mm256_loadu_si256((__m256i*)(A + row * lda + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      vec::cvtbf16_fp32(a16, va);
    }
    if constexpr (row == 0) {
      // Unaligned load for the same reason as above (ldb/k arbitrary).
      __m128i b8 = _mm_loadu_si128((__m128i*)(B + col * ldb + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      // Widen int8 -> int32 -> fp32, then apply the column's scale.
      __m512i b32 = _mm512_cvtepi8_epi32(b8);
      vb[col] = _mm512_cvtepi32_ps(b32);
      vb[col] = _mm512_mul_ps(vb[col], scale[col]);
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
  };
  for (int k = 0; k < K; k += 16) {
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  // Horizontally sum each accumulator's 16 lanes and store as bf16.
  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    C[row * ldc + col] = static_cast<BFloat16>(_mm512_reduce_add_ps(vc[i]));
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free