q_batch_norm_cpu_kernel_impl Function — pytorch Architecture
Architecture documentation for the q_batch_norm_cpu_kernel_impl function template in QuantizedOpKernels.cpp from the pytorch codebase: a CPU kernel that applies precomputed per-channel batch-normalization coefficients to quantized uint8 input and writes float, BFloat16, Half, or requantized uint8 output, with an AVX-512 fast path and a scalar fallback.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp lines 2504–2632
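The kernel consumes per-channel coefficients alpha_ptr and beta_ptr that the calling operator has already folded together from the batch-norm statistics and the input/output quantization parameters. A minimal sketch of that folding, assuming the usual affine-quantization conventions (fold_bn_params, gamma, mean, var, eps, in_scale, and out_scale are hypothetical names; the actual computation lives in the caller, not in this kernel):

#include <cmath>
#include <cstdint>

// Hedged sketch (hypothetical helper, not part of the PyTorch source):
// fold batch-norm statistics and quantization scales into the per-channel
// alpha/beta arrays that q_batch_norm_cpu_kernel_impl consumes.
//
// Derivation, under the usual affine-quantization conventions:
//   x_float = in_scale * (x_q - in_zp)
//   y_float = gamma[c] * (x_float - mean[c]) / sqrt(var[c] + eps) + bias[c]
//   y_q     = round(y_float / out_scale) + out_zp
//           = round(alpha[c] * (x_q - in_zp) + beta[c]) + out_zp
void fold_bn_params(
    int64_t C,
    const float* gamma,
    const float* bias,
    const float* mean,
    const float* var,
    float eps,
    float in_scale,
    float out_scale,
    float* alpha,
    float* beta) {
  for (int64_t c = 0; c < C; ++c) {
    float inv_std = 1.0f / std::sqrt(var[c] + eps);
    alpha[c] = gamma[c] * inv_std * (in_scale / out_scale);
    beta[c] = (bias[c] - gamma[c] * inv_std * mean[c]) / out_scale;
  }
}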
template <typename T>
void q_batch_norm_cpu_kernel_impl(
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t in_zero_point,
    int64_t out_zero_point,
    const uint8_t* in_ptr,
    const float* alpha_ptr,
    const float* beta_ptr,
    T* out_ptr) {
  int q_min = 0;
  int q_max = 255;
  const int64_t outer_size = N * HxW;
#if defined(CPU_CAPABILITY_AVX512)
  constexpr int kVLen = 16;
  static constexpr int num_vecs = sizeof(float) / sizeof(uint8_t);
  auto in_zp_vec = _mm512_set1_ps((float)in_zero_point);
  auto fake_scale = _mm512_set1_ps(1.0f);
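  // XOR with -0.0f flips each lane's sign bit, yielding a vector of
  // -in_zero_point; with fake_scale == 1.0f, the fmadd below computes
  // x - in_zero_point (the quantization scales are folded into alpha/beta).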
  auto scale_neg_zp_premul = _mm512_xor_ps(_mm512_set1_ps(-0.f), in_zp_vec);
  auto out_zero_point_v = _mm512_set1_epi32((int)out_zero_point);
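  // num_vecs == 4 and kVLen == 16, so lanes == 64: the main loop consumes
  // one full 512-bit load of 64 uint8 values per iteration.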
  constexpr auto lanes = static_cast<int64_t>(num_vecs * kVLen);
  __m512i v_q_max = _mm512_set1_epi32(q_max);
  __m512i v_q_min = _mm512_set1_epi32(q_min);
  auto load_convert_u8_to_f32_512bit = [&](const uint8_t* src, __m512* dst) {
    // Step 1: Load 512 bits
    __m512i raw = _mm512_loadu_si512(src);
    // Step 2: Extract two 256-bit chunks
    __m256i v0 = _mm512_extracti64x4_epi64(raw, 0); // bytes 0–31
    __m256i v1 = _mm512_extracti64x4_epi64(raw, 1); // bytes 32–63
    // Step 3: Process each 256-bit chunk
    // --- Expand uint8_t -> uint16_t ---
    __m256i u16lo0 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v0, 0));
    __m256i u16hi0 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v0, 1));
    __m256i u16lo1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v1, 0));
    __m256i u16hi1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v1, 1));
    // --- Expand to uint32_t and convert to float ---
    dst[0] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16lo0));
    dst[1] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16hi0));
    dst[2] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16lo1));
    dst[3] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16hi1));
  };
  auto load_convert_u8_to_f32_128bit = [&](const uint8_t* src) {
    // --- Load and expand uint8_t -> uint16_t ---
    __m256i v_u16 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)src));
    // --- Expand to uint32_t and convert to float ---
    return _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(v_u16));
  };
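  // Store one vector of 16 results as the output type T; the uint8 branch
  // requantizes (add output zero point, clamp to [0, 255], narrow to 8 bits).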
  auto store_output = [&](__m512 out, T* out_addr) {
    if constexpr (std::is_same<T, float>::value) {
      _mm512_storeu_ps(out_addr, out);
    } else if constexpr (std::is_same<T, at::BFloat16>::value) {
      __m256i out_bf16 = cvtfp32_bf16(out);
      _mm256_storeu_si256((__m256i*)out_addr, out_bf16);
    } else if constexpr (std::is_same<T, at::Half>::value) {
      __m256i out_f16 = cvtfp32_fp16(out);
      _mm256_storeu_si256((__m256i*)out_addr, out_f16);
    } else { // T == uint8, requantization needed
      __m512i out_i32 = _mm512_cvtps_epi32(out);
      out_i32 = _mm512_add_epi32(out_i32, out_zero_point_v);
      out_i32 = _mm512_min_epi32(out_i32, v_q_max);
      out_i32 = _mm512_max_epi32(out_i32, v_q_min);
      __m128i out_i8 = _mm512_cvtepi32_epi8(out_i32);
      _mm_storeu_si128((__m128i*)out_addr, out_i8);
    }
  };
#endif
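  // Parallelize over the N * HxW positions; each contiguous row of C values
  // holds every channel for one position (channels-last layout).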
  at::parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      auto* X_ptr = in_ptr + i * C;
      auto* Y_ptr = out_ptr + i * C;
      int64_t ch = 0;
#if defined(CPU_CAPABILITY_AVX512)
      __m512 vals_dq[num_vecs];
      for (; ch + lanes <= C; ch += lanes) {
        // load 64 values of input then dequantize them
        load_convert_u8_to_f32_512bit(X_ptr + ch, vals_dq);
        for (const auto idx : c10::irange(num_vecs)) {
          vals_dq[idx] =
              _mm512_fmadd_ps(fake_scale, vals_dq[idx], scale_neg_zp_premul);
          auto alpha_v = _mm512_loadu_ps(alpha_ptr + ch + idx * kVLen);
          auto beta_v = _mm512_loadu_ps(beta_ptr + ch + idx * kVLen);
          vals_dq[idx] = _mm512_fmadd_ps(alpha_v, vals_dq[idx], beta_v);
          store_output(vals_dq[idx], Y_ptr + ch + idx * kVLen);
        }
      }
      // Remainder of 16 to 63 channels: one 16-lane vector at a time
      int64_t elem_size = C - ch;
      if (elem_size >= kVLen) {
        int64_t vec_num = elem_size / kVLen;
        for (const auto idx : c10::irange(vec_num)) {
          __m512 val_dq =
              load_convert_u8_to_f32_128bit(X_ptr + ch + idx * kVLen);
          val_dq = _mm512_fmadd_ps(fake_scale, val_dq, scale_neg_zp_premul);
          auto alpha_v = _mm512_loadu_ps(alpha_ptr + ch + idx * kVLen);
          auto beta_v = _mm512_loadu_ps(beta_ptr + ch + idx * kVLen);
          val_dq = _mm512_fmadd_ps(alpha_v, val_dq, beta_v);
          store_output(val_dq, Y_ptr + ch + idx * kVLen);
        }
        ch += vec_num * kVLen;
      }
#endif
      // Scalar tail: fewer than 16 channels remain
      for (; ch < C; ++ch) {
        float y_val_f =
            alpha_ptr[ch] * (X_ptr[ch] - in_zero_point) + beta_ptr[ch];
        if constexpr (std::is_same<T, float>::value) {
          Y_ptr[ch] = y_val_f;
        } else if constexpr (std::is_same<T, at::BFloat16>::value) {
          Y_ptr[ch] = (at::BFloat16)y_val_f;
        } else if constexpr (std::is_same<T, at::Half>::value) {
          Y_ptr[ch] = (at::Half)y_val_f;
        } else { // T == uint8, requantization needed
          long quantized_down = out_zero_point + lrintf(y_val_f);
          Y_ptr[ch] =
              std::min<long>(std::max<long>(quantized_down, q_min), q_max);
        }
      }
    }
  });
}
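The scalar tail states plainly what every path computes: y = alpha[ch] * (x - in_zero_point) + beta[ch], followed by rounding, a zero-point shift, and clamping when T is uint8. Below is a minimal standalone sketch of that math for the uint8 output case (hypothetical function name and values; not part of the PyTorch source), useful as a reference when checking the vectorized paths:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical scalar reference for the uint8 -> uint8 specialization:
// every path of the kernel computes alpha[c] * (x - in_zp) + beta[c],
// then rounds, adds the output zero point, and clamps to [0, 255].
void q_batch_norm_reference(
    int64_t outer_size, // N * HxW
    int64_t C,
    int64_t in_zero_point,
    int64_t out_zero_point,
    const uint8_t* in_ptr,
    const float* alpha_ptr,
    const float* beta_ptr,
    uint8_t* out_ptr) {
  for (int64_t i = 0; i < outer_size; ++i) {
    for (int64_t ch = 0; ch < C; ++ch) {
      float y = alpha_ptr[ch] * (in_ptr[i * C + ch] - in_zero_point) +
          beta_ptr[ch];
      long q = out_zero_point + std::lrintf(y); // round to nearest, then shift
      q = std::min<long>(std::max<long>(q, 0), 255); // clamp to uint8 range
      out_ptr[i * C + ch] = static_cast<uint8_t>(q);
    }
  }
}

int main() {
  // One spatial position, two channels, made-up quantization parameters.
  const uint8_t in[2] = {130, 10};
  const float alpha[2] = {0.5f, 2.0f};
  const float beta[2] = {1.0f, -3.0f};
  uint8_t out[2];
  q_batch_norm_reference(/*outer_size=*/1, /*C=*/2, /*in_zero_point=*/128,
                         /*out_zero_point=*/128, in, alpha, beta, out);
  std::printf("%d %d\n", out[0], out[1]); // prints "130 0" (second clamps)
  return 0;
}

One note on rounding: under default settings the paths agree, since lrintf rounds to nearest even, as does _mm512_cvtps_epi32 under the default MXCSR rounding mode.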