Home / Class / _qavg_pool_nhwc_kernel Class — pytorch Architecture

_qavg_pool_nhwc_kernel Class — pytorch Architecture

Architecture documentation for the _qavg_pool_nhwc_kernel class in QuantizedOpKernels.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp lines 1974–2090

// Quantized average pooling over a channels-last (NHWC / NDHWC) tensor.
// T is a quantized cell type exposing an `underlying` integer member type
// (the code reinterprets the data pointers as `typename T::underlying*`).
// The depth dimension (inputDepth/kD/dD/padD) generalizes the kernel to 3D
// pooling; presumably 2D callers pass depth extents of 1 — TODO confirm
// against call sites.
//
// Parameters:
//   qx / qy                 - quantized input / pre-allocated output tensors
//   nBatch, nInputPlane     - batch size and channel count
//   input*/output*          - spatial extents (width, height, depth)
//   kW/kH/kD, dW/dH/dD      - kernel sizes and strides per spatial dim
//   padW/padH/padD          - zero-padding per spatial dim
//   count_include_pad       - if true, divide by the full (padded) window size
//   divisor_override        - if set, use this divisor instead of window size
template <typename T>
void _qavg_pool_nhwc_kernel(
    const Tensor& qx,
    Tensor& qy,
    int64_t nBatch,
    int64_t nInputPlane,
    int64_t inputWidth,
    int64_t inputHeight,
    int64_t inputDepth,
    int64_t outputWidth,
    int64_t outputHeight,
    int64_t outputDepth,
    int kW,
    int kH,
    int kD,
    int dW,
    int dH,
    int dD,
    int padW,
    int padH,
    int padD,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  T* idata = static_cast<T*>(qx.data_ptr());
  T* odata = static_cast<T*>(qy.data_ptr());
  // Channels-last strides: channel is innermost (stride 1), then W, H, D, batch.
  // NOTE(review): strides are computed in `int`; assumes per-batch element
  // counts stay within int range — TODO confirm upstream shape checks.
  int strideC = 1;
  int strideW = strideC * nInputPlane;
  int istrideH = strideW * inputWidth;
  int istrideD = istrideH * inputHeight;
  int istrideB = istrideD * inputDepth;

  // lift these operations outside the loop to reduce access overheads
  float input_scale = qx.q_scale();
  float output_scale = qy.q_scale();
  int input_zero_point = qx.q_zero_point();
  int output_zero_point = qy.q_zero_point();
  // 0 acts as the "no override" sentinel for the ternary below.
  int64_t divisor_override_factor =
      divisor_override.has_value() ? divisor_override.value() : 0;

  // Parallelize over every output location (batch x depth x height x width);
  // the channel loop runs inside each task.
  at::parallel_for(0, nBatch * outputDepth * outputHeight * outputWidth, 0, [&](int64_t begin, int64_t end) {
    // Decompose the flat index `begin` into (b, od, oh, ow) counters that
    // data_index_step advances in lockstep with `i` below.
    int64_t b{0}, od{0}, oh{0}, ow{0};
    data_index_init(begin, b, nBatch, od, outputDepth, oh, outputHeight, ow, outputWidth);

    for (const auto i : c10::irange(begin, end)) {
      // i_p: start of this batch's input; o_p: this output location's
      // channel vector (output is contiguous, so flat index * channels).
      auto* i_p = reinterpret_cast<typename T::underlying*>(idata + b * istrideB);
      auto* o_p = reinterpret_cast<typename T::underlying*>(odata + i * strideW);
      // Window start in input coordinates (may be negative inside padding).
      int dstart = od * dD - padD;
      int hstart = oh * dH - padH;
      int wstart = ow * dW - padW;

      // pool_size counts window cells including padding (clipped only at the
      // far padded edge), matching count_include_pad semantics.
      int dend = std::min(dstart + kD, (int)inputDepth + padD);
      int hend = std::min(hstart + kH, (int)inputHeight + padH);
      int wend = std::min(wstart + kW, (int)inputWidth + padW);
      int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);

      // Clamp to valid input range for the actual summation bounds.
      dstart = std::max(dstart, 0);
      hstart = std::max(hstart, 0);
      wstart = std::max(wstart, 0);
      dend = std::min(dend, (int)inputDepth);
      hend = std::min(hend, (int)inputHeight);
      wend = std::min(wend, (int)inputWidth);

      // size counts only real (non-padding) elements that get summed.
      int size = (dend - dstart) * (hend - hstart) * (wend - wstart);
      int divide_size = count_include_pad ? pool_size : size;
      // NOTE(review): if count_include_pad is false and the window falls
      // entirely in padding (size == 0) with no override, divide_factor is 0
      // — presumably callers guarantee this cannot happen; verify upstream.
      int divide_factor =
          divisor_override_factor ? divisor_override_factor : divide_size;
      // Folds requantization (in_scale/out_scale) and averaging (1/divisor)
      // into one factor applied to the integer sum.
      float multiplier = input_scale / output_scale  / divide_factor;
      // Pre-subtract the input zero point for all `size` summed elements,
      // so acc_int32 accumulates zero-point-corrected values.
      int input_zero_point_m_size = -input_zero_point * size;

      // NOTE(review): c_start appears to be advanced by the AVX helper past
      // the channels it vectorized (i.e. taken by reference) — the scalar
      // loop below resumes from it. TODO confirm the helper's signature.
      int c_start = 0;

      // For int8 quantization, we implicitly use int32 as accumulation
      // Or else, it will go to the slow path
      // TODO: support 16bit, 32bit, and etc.
      do_avg_pool_nhwc_on_AVX_n<T>(
          i_p,
          o_p,
          c_start,
          input_zero_point_m_size,
          output_zero_point,
          multiplier,
          dstart,
          dend,
          hstart,
          hend,
          wstart,
          wend,
          inputDepth,
          inputHeight,
          inputWidth,
          nInputPlane);

      // 1) The following loop handles the remaining channels
      // 2) It also handles the Non-AVX2 path
      for (const auto c: c10::irange(c_start, nInputPlane)) {
        // Start from the negated zero-point sum; adding raw quantized values
        // yields the zero-point-corrected window sum.
        int32_t acc_int32 = input_zero_point_m_size;
        for (const auto id : c10::irange(dstart, dend)) {
          for (const auto ih : c10::irange(hstart, hend)) {
            for (const auto iw : c10::irange(wstart, wend)) {
              auto val =
                  *(i_p + id * istrideD + ih * istrideH + iw * strideW +
                  c * strideC);
              acc_int32 += val;
            }
          }
       }
       double acc_fp = acc_int32 * 1.0;
       // clamp
       // quantize_val divides by its scale argument, so passing the
       // reciprocal of `multiplier` effectively computes
       // round(acc_fp * multiplier) + output_zero_point, clamped to T's range.
       o_p[c] = at::native::quantize_val<T>(
           1.0f / multiplier, output_zero_point, acc_fp)
           .val_;
      } // c

      // Advance (b, od, oh, ow) to the next flat output index.
      data_index_step(b, nBatch, od, outputDepth, oh, outputHeight, ow, outputWidth);
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free