Home / Class/ q_batch_norm_cpu_kernel_impl Class — pytorch Architecture

q_batch_norm_cpu_kernel_impl Function Template — PyTorch Architecture

Architecture documentation for the q_batch_norm_cpu_kernel_impl function template in QuantizedOpKernels.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp lines 2504–2632

/// Applies a per-channel affine transform to uint8 input and writes the
/// result as T.
///
/// For each input value x belonging to channel c the kernel computes
///   y = alpha_ptr[c] * (x - in_zero_point) + beta_ptr[c]
/// and then:
///   - T == float / at::BFloat16 / at::Half: stores y converted to T;
///   - otherwise (uint8 output): requantizes to
///       clamp(round(y) + out_zero_point, 0, 255)
///     where round() uses the current floating-point rounding mode
///     (lrintf on the scalar path, _mm512_cvtps_epi32 on the AVX512 path).
///     NaN results map to 0 on both paths.
///
/// Layout: the kernel processes N * HxW rows; row i starts at in_ptr + i * C
/// and out_ptr + i * C. The C channel values of one row are therefore
/// contiguous (channels-last).
///
/// @param N, HxW         the kernel processes N * HxW rows (presumably batch
///                       size and spatial size; only their product is used).
/// @param C              values per row; alpha_ptr/beta_ptr hold C entries.
/// @param in_zero_point  subtracted from every input value.
/// @param out_zero_point added when requantizing to uint8; unused for
///                       floating-point T.
/// @param in_ptr         N * HxW * C input values.
/// @param alpha_ptr      per-channel multipliers. NOTE(review): presumably
///                       these fold BN weight/variance and the input/output
///                       scales — confirm against the caller.
/// @param beta_ptr       per-channel offsets (see alpha_ptr note).
/// @param out_ptr        N * HxW * C output values.
template <typename T>
void q_batch_norm_cpu_kernel_impl(
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t in_zero_point,
    int64_t out_zero_point,
    const uint8_t* in_ptr,
    const float* alpha_ptr,
    const float* beta_ptr,
    T* out_ptr) {

  constexpr int q_min = 0;
  constexpr int q_max = 255;
  const int64_t outer_size = N * HxW;

  // Requantization bounds in the float domain, before the zero point is added.
  // Clamping y to [q_min - zp, q_max - zp] *before* the float->int conversion
  // gives the same result as clamping the integer afterwards for every
  // in-range y: rounding is monotone and the bounds are integers.
  // Clamping first also keeps the conversion in range:
  //   - _mm512_cvtps_epi32 returns INT32_MIN for |y| >= 2^31, so very large
  //     positive results used to saturate to q_min instead of q_max;
  //   - lrintf's result is unspecified when it does not fit in `long`, which
  //     is only 32 bits on LLP64 platforms.
  [[maybe_unused]] const float y_lo =
      static_cast<float>(q_min - out_zero_point);
  [[maybe_unused]] const float y_hi =
      static_cast<float>(q_max - out_zero_point);

#if defined(CPU_CAPABILITY_AVX512)
  constexpr int kVLen = 16; // float lanes per __m512
  // One 64-byte load of uint8 input expands into this many __m512 vectors.
  static constexpr int num_vecs = sizeof(float) / sizeof(uint8_t);
  auto in_zp_vec = _mm512_set1_ps(static_cast<float>(in_zero_point));
  auto fake_scale = _mm512_set1_ps(1.0f);
  // -in_zero_point, obtained by flipping the sign bit.
  auto scale_neg_zp_premul = _mm512_xor_ps(_mm512_set1_ps(-0.f), in_zp_vec);
  auto out_zero_point_v = _mm512_set1_epi32(static_cast<int>(out_zero_point));
  constexpr auto lanes = static_cast<int64_t>(num_vecs * kVLen);
  __m512i v_q_max = _mm512_set1_epi32(q_max);
  __m512i v_q_min = _mm512_set1_epi32(q_min);
  __m512 v_y_lo = _mm512_set1_ps(y_lo);
  __m512 v_y_hi = _mm512_set1_ps(y_hi);

  // Loads 64 uint8 values and widens them into four vectors of 16 floats.
  auto load_convert_u8_to_f32_512bit = [&](const uint8_t* src, __m512* dst) {
    // Step 1: Load 512 bits
    __m512i raw = _mm512_loadu_si512(src);

    // Step 2: Extract two 256-bit chunks
    __m256i v0 = _mm512_extracti64x4_epi64(raw, 0); // bytes 0-31
    __m256i v1 = _mm512_extracti64x4_epi64(raw, 1); // bytes 32-63

    // Step 3: Process each 256-bit chunk
    // --- Expand uint8_t -> uint16_t ---
    __m256i u16lo0 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v0, 0));
    __m256i u16hi0 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v0, 1));
    __m256i u16lo1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v1, 0));
    __m256i u16hi1 = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v1, 1));
    // --- Expand to uint32_t and convert to float ---
    dst[0] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16lo0));
    dst[1] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16hi0));
    dst[2] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16lo1));
    dst[3] = _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(u16hi1));
  };

  // Loads 16 uint8 values and widens them into one vector of 16 floats.
  auto load_convert_u8_to_f32_128bit = [&](const uint8_t* src) {
    // --- Load and expand uint8_t -> uint16_t ---
    __m256i v_u16 = _mm256_cvtepu8_epi16(
        _mm_loadu_si128(reinterpret_cast<const __m128i*>(src)));
    // --- Expand to uint32_t and convert to float ---
    return _mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(v_u16));
  };

  // Stores 16 results as T; for uint8 output, requantizes first.
  auto store_output = [&](__m512 out, T* out_addr) {
    if constexpr (std::is_same<T, float>::value) {
      _mm512_storeu_ps(out_addr, out);
    } else if constexpr (std::is_same<T, at::BFloat16>::value) {
      __m256i out_bf16 = cvtfp32_bf16(out);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_addr), out_bf16);
    } else if constexpr (std::is_same<T, at::Half>::value) {
      __m256i out_f16 = cvtfp32_fp16(out);
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(out_addr), out_f16);
    } else { //  T == uint8, requantization needed
      // Clamp in float first so the conversion stays within int32 range
      // (see the note on y_lo / y_hi). MAXPS returns its second operand when
      // the first is NaN, so NaN lanes become y_lo and finally q_min.
      out = _mm512_min_ps(_mm512_max_ps(out, v_y_lo), v_y_hi);
      __m512i out_i32 = _mm512_cvtps_epi32(out);
      out_i32 = _mm512_add_epi32(out_i32, out_zero_point_v);
      out_i32 = _mm512_min_epi32(out_i32, v_q_max);
      out_i32 = _mm512_max_epi32(out_i32, v_q_min);
      // Values are in [0, 255], so the truncating narrow is exact.
      __m128i out_i8 = _mm512_cvtepi32_epi8(out_i32);
      _mm_storeu_si128(reinterpret_cast<__m128i*>(out_addr), out_i8);
    }
  };
#endif

  at::parallel_for(0, outer_size, 0, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      auto* X_ptr = in_ptr + i * C;
      auto* Y_ptr = out_ptr + i * C;
      int64_t ch = 0;

#if defined(CPU_CAPABILITY_AVX512)
      __m512 vals_dq[num_vecs];
      // Main loop: 64 channels per iteration.
      for (; ch + lanes <= C; ch += lanes) {
        // load 64 values of input then dequantize them (x - in_zero_point)
        load_convert_u8_to_f32_512bit(X_ptr + ch, vals_dq);
        for (const auto idx : c10::irange(num_vecs)) {
          vals_dq[idx] = _mm512_fmadd_ps(fake_scale, vals_dq[idx], scale_neg_zp_premul);
          auto alpha_v = _mm512_loadu_ps(alpha_ptr + ch + idx * kVLen);
          auto beta_v = _mm512_loadu_ps(beta_ptr + ch + idx * kVLen);
          vals_dq[idx] = _mm512_fmadd_ps(alpha_v, vals_dq[idx], beta_v);
          store_output(vals_dq[idx], Y_ptr + ch + idx * kVLen);
        }
      }

      // Remaining full 16-channel chunks (fewer than 64 channels are left).
      int64_t elem_size = C - ch;
      if (elem_size >= kVLen) {
        int64_t vec_num = elem_size / kVLen;
        for (const auto idx : c10::irange(vec_num)) {
          __m512 val_dq = load_convert_u8_to_f32_128bit(X_ptr + ch + idx * kVLen);
          val_dq = _mm512_fmadd_ps(fake_scale, val_dq, scale_neg_zp_premul);
          auto alpha_v = _mm512_loadu_ps(alpha_ptr + ch + idx * kVLen);
          auto beta_v = _mm512_loadu_ps(beta_ptr + ch + idx * kVLen);
          val_dq = _mm512_fmadd_ps(alpha_v, val_dq, beta_v);
          store_output(val_dq, Y_ptr + ch + idx * kVLen);
        }
        ch += vec_num * kVLen;
      }
#endif
      // Scalar tail: the last (< 16) channels when AVX512 is enabled, or every
      // channel otherwise.
      for (; ch < C; ++ch) {
        float y_val_f = alpha_ptr[ch] * (X_ptr[ch] - in_zero_point) +
                        beta_ptr[ch];
        if constexpr (std::is_same<T, float>::value) {
          Y_ptr[ch] = y_val_f;
        } else if constexpr (std::is_same<T, at::BFloat16>::value) {
          Y_ptr[ch] = (at::BFloat16)y_val_f;
        } else if constexpr (std::is_same<T, at::Half>::value) {
          Y_ptr[ch] = (at::Half)y_val_f;
        } else { //  T == uint8, requantization needed
          // Mirror the vector path: clamp in float (NaN -> y_lo, like MAXPS)
          // so lrintf never sees a value outside the range of `long`.
          float y_clamped = y_val_f > y_lo ? y_val_f : y_lo;
          y_clamped = y_clamped < y_hi ? y_clamped : y_hi;
          const int64_t quantized_down =
              out_zero_point + static_cast<int64_t>(lrintf(y_clamped));
          Y_ptr[ch] = std::min<int64_t>(
              std::max<int64_t>(quantized_down, q_min), q_max);
        }
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free