BLOCK_M Template Parameter — pytorch Architecture
Architecture documentation for BLOCK_M, a template parameter of the tinygemm_kernel function template in int4mm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/int4mm_kernel.cpp lines 63–209
// AVX-512 micro-kernel: computes one BLOCK_M x BLOCK_N tile of
//
//   C = A * dequant(B)
//
// where A is bf16, B holds packed 4-bit values (two per byte) and the result
// is written to C as bf16. Accumulation is done in fp32 vector registers that
// start at zero, so the C tile is overwritten, not accumulated into.
//
// Template parameters:
//   BLOCK_M  number of rows of A / C handled by this call.
//   BLOCK_N  number of columns of B / C handled by this call; it is split
//            into BLOCK_N / 16 vectors of 16 fp32 lanes (integer division,
//            so BLOCK_N is presumably a multiple of 16 - confirm at callers).
//
// Parameters:
//   A              bf16 input, read as A[row * lda + k].
//   B              packed int4 input; the data for reduction step k starts
//                  at B + k * ldb (ldb is in bytes because B is uint8_t*).
//   ScaleAndZeros  bf16 (scale, zero) pairs, interleaved per column; the data
//                  for K-block kb starts at ScaleAndZeros + kb * ldc * 2.
//                  Note that ldc is used as the per-block column count here.
//   C              bf16 output; row r starts at C + r * ldc.
//   lda, ldb, ldc  leading dimensions as used in the address computations
//                  above.
//   K              length of the reduction dimension.
//   BLOCK_K        number of consecutive k steps that share one set of
//                  scales/zeros (as decided by is_block_start, which is
//                  defined elsewhere - presumably k % BLOCK_K == 0).
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const BFloat16* RESTRICT A,
const uint8_t* RESTRICT B,
const BFloat16* RESTRICT ScaleAndZeros,
BFloat16* RESTRICT C,
int lda,
int ldb,
int ldc,
int K,
int BLOCK_K) {
// Tile shape in registers: ROWS rows by COLS vectors of 16 fp32 lanes each.
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
// Prefetch distance along K, in k steps ...
const int PREFETCH_SIZE_K = 16 * 4;
// ... and the same distance expressed in K-blocks, rounded up.
const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
// number of blocks on K (integer division; only used to bound the
// ScaleAndZeros prefetch below)
const int KB = K / BLOCK_K;
// va:    A[row, k] broadcast to all 16 lanes (refreshed once per row).
// vb:    de-quantized B[k, :] for the current k, one vector per 16 columns.
// vc:    fp32 accumulators, indexed as row * COLS + col.
// scale/zero: per-column de-quantization parameters of the current K-block.
__m512 va;
__m512 vb[COLS];
__m512 vc[ROWS * COLS];
__m512 scale[COLS];
__m512 zero[COLS];
// Lookup table to de-quantize int4 values to bf16.
// Values are dequantized as truly int4 [-8, 7] range;
//
// dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
//
// _mm512_set_ps takes its arguments from the highest lane down to lane 0,
// so lane n of the table holds (n - 8): a stored nibble n maps to n - 8.
// In this kernel the multiply-add itself is carried out on fp32 lanes.
static const __m512 lut = _mm512_set_ps(
7.0f, 6.0f, 5.0f, 4.0f,
3.0f, 2.0f, 1.0f, 0.0f,
-1.0f, -2.0f, -3.0f, -4.0f,
-5.0f, -6.0f, -7.0f, -8.0f);
// index for transpose: idx1 selects the even elements (0, 2, ..., 30) and
// idx2 the odd elements (1, 3, ..., 31) of a 32-element two-register source.
static const __m512i idx1 = _mm512_set_epi32(
30, 28, 26, 24, 22, 20, 18, 16,
14, 12, 10, 8, 6, 4, 2, 0);
static const __m512i idx2 = _mm512_set_epi32(
31, 29, 27, 25, 23, 21, 19, 17,
15, 13, 11, 9, 7, 5, 3, 1);
// load scale and zero point for column vector i of K-block _kb into
// scale[i] / zero[i]
auto load_scale_and_zeros = [&](int i, int _kb) {
// load 2x bfloat16 vector (32 bf16 values = 16 interleaved {scale, zero}
// pairs)
__m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * ldc * 2 + 32 * i));
// prefetch the matching data PREFETCH_SIZE_KB blocks ahead, if it exists
if (_kb + PREFETCH_SIZE_KB < KB) {
_mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * ldc * 2 + 32 * i, _MM_HINT_T0);
}
// convert to 2x f32 vector
// (vec::cvtbf16_fp32 is defined elsewhere; per the layout described below
// it is expected to put the first 16 values in a and the last 16 in b)
__m512 a, b;
vec::cvtbf16_fp32(t, a, b);
// transpose scale_and_zero from {16, 2} to {2, 16}
// inputs:
// a: {s0, z0, s1, z1, ..., s7, z7}
// b: {s8, z8, s9, z9, ..., s15, z15}
// output:
// scale: {s0, s1, s2, ..., s15}
// zero: {z0, z1, z2, ..., z15}
// The 0xffff mask enables all 16 lanes, so every lane takes the permuted
// value selected from the {a, b} pair.
scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
};
// Zero-initialize every accumulator (the name is historical: nothing is
// loaded from C).
auto loadc = [&](auto i) {
vc[i] = _mm512_setzero_ps();
};
c10::ForcedUnroll<ROWS * COLS>{}(loadc);
// One rank-1 update step for accumulator i at reduction index k:
//   vc[row][col] += A[row, k] * dequant(B[k, col * 16 .. col * 16 + 15])
// i is used in constant expressions, so it must be a compile-time index
// object supplied by c10::ForcedUnroll (defined elsewhere).
// NOTE(review): va is refreshed only when col == 0 and vb only when
// row == 0, so this relies on ForcedUnroll invoking the lambda in ascending
// i order - confirm against c10::ForcedUnroll.
auto compute = [&, COLS](auto i, int k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// First column of a row: broadcast A[row, k] and prefetch A further along K.
if constexpr (col == 0) {
float aa = static_cast<float>(A[row * lda + k]);
if (k + PREFETCH_SIZE_K < K) {
_mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
}
va = _mm512_set1_ps(aa);
}
// First row: de-quantize B[k, :] once; the other rows reuse vb.
if constexpr (row == 0) {
if constexpr (COLS == 4) {
// when BLOCK_N = 64, handle each row at a time
// to reduce de-quantize overhead.
if constexpr (col == 0) {
// 32 bytes = 64 packed int4 values for this k.
__m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb));
if (k + PREFETCH_SIZE_K < K) {
_mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
}
// Bytes 0..15: the permute uses only the low 4 bits of each index, so
// the un-shifted bytes yield the low nibbles (-> vb[0]) and the bytes
// shifted right by 4 yield the high nibbles (-> vb[2]).
__m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
vb[0] = _mm512_permutexvar_ps(b32, lut);
vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);
// Bytes 16..31: low nibbles -> vb[1], high nibbles -> vb[3].
b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
vb[1] = _mm512_permutexvar_ps(b32, lut);
vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
}
} else {
// Generic path: 8 packed bytes per 16 columns. conver_int4_to_int8 is
// defined elsewhere; presumably it expands them to 16 one-byte values
// - confirm the nibble order there.
__m128i b8 = conver_int4_to_int8(B + k * ldb + col * 8);
__m512i b32 = _mm512_cvtepu8_epi32(b8);
vb[col] = _mm512_permutexvar_ps(b32, lut);
vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
}
}
// Accumulate: vc[idx] += va * vb[col].
constexpr int idx = row * COLS + col;
vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
};
// Main reduction loop over K. Scales/zeros are reloaded whenever
// is_block_start reports the start of a new K-block; kb counts the blocks
// consumed so far.
for (int k = 0, kb = 0; k < K; ++k) {
if (is_block_start(k, BLOCK_K)) {
c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
}
c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
}
// store to C: convert the fp32 accumulators to bf16 and write the tile out
auto storec = [&, COLS](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (COLS == 4) {
// when BLOCK_N = 64, handle each row at a time
// to reduce `cvtfp32_bf16` overhead.
// Each conversion packs two fp32 vectors into one 512-bit vector of
// 32 bf16 values; two 512-bit stores cover the 64 columns of the row.
if constexpr (col == 0) {
__m512i c01 = vec::cvtfp32_bf16(vc[row * 4 + 0], vc[row * 4 + 1]);
__m512i c23 = vec::cvtfp32_bf16(vc[row * 4 + 2], vc[row * 4 + 3]);
_mm512_storeu_si512((__m512i*)(C + row * ldc + 0 * 32), c01);
_mm512_storeu_si512((__m512i*)(C + row * ldc + 1 * 32), c23);
}
} else {
// Generic path: one accumulator -> 16 bf16 values -> one 256-bit store.
__m256i ci = vec::cvtfp32_bf16(vc[i]);
_mm256_storeu_si256((__m256i*)(C + row * ldc + col * 16), ci);
}
};
c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free