Home / Class / q_maxpool_2d — pytorch Architecture

q_maxpool_2d Function Template — pytorch Architecture

Architecture documentation for the q_maxpool_2d function template in Pooling.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/Pooling.cpp lines 179–341

// Quantized 2D max pooling over a rank-3 (CHW) or rank-4 (NCHW) input.
//
// Q is the underlying element storage type of the quantized tensor. The
// uint8_t instantiation is special-cased via `if constexpr` below: there the
// input is treated as a plain uint8 tensor (allocated with at::empty) rather
// than an affine-quantized dtype (allocated with at::_empty_affine_quantized).
//
// Parameters:
//   qx        - input tensor (quantized), rank 3 or 4
//   kH, kW    - kernel size, must be > 0
//   sH, sW    - strides, must be > 0
//   pH, pW    - padding, must be <= half the kernel size (integer division)
//   dH, dW    - dilation, must be > 0
//   ceil_mode - round the output spatial size up instead of down
// Returns a newly allocated output tensor of the same rank as qx.
template <typename Q>
Tensor q_maxpool_2d(
    Tensor qx, // Input Tensor (Quantized)
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sH,
    int64_t sW, // strides
    int64_t pH,
    int64_t pW, // padding
    int64_t dH,
    int64_t dW, // dilation
    bool ceil_mode) {
  // Check input dimensions.
  TORCH_CHECK(kH > 0 && kW > 0, "kernel_size should be greater than zero.");
  TORCH_CHECK(sH > 0 && sW > 0, "strides should be greater than zero.");
  TORCH_CHECK(
      dH > 0 && dW > 0,
      "dilation should be greater than zero. "
      "Got (",
      dH,
      ", ",
      dW,
      ")");

  int ndim = qx.dim();
  TORCH_CHECK(
      ndim == 3 || ndim == 4, "Expecting the input tensor of rank 3 or 4.");
  // Dimension indices for a CHW layout; shifted right by one below when a
  // leading batch dimension is present (NCHW).
  int dimc = 0;
  int dimh = 1;
  int dimw = 2;
  int nbatch = 1;
  if (ndim == 4) { // Includes batches
    ++dimc;
    ++dimh;
    ++dimw;
    nbatch = qx.size(0);
  }

  // Check if inputs are valid.
  int64_t iC = qx.size(dimc);
  int64_t iH = qx.size(dimh);
  int64_t iW = qx.size(dimw);
  TORCH_CHECK(iC > 0 && iH > 0 && iW > 0, "input dimensions must be non-zero.");
  TORCH_CHECK(
      (ndim == 3 || ndim == 4),
      "non-empty 3D or 4D input tensor is expected.");
  // Integer division: e.g. kH == 3 permits pH of 0 or 1.
  TORCH_CHECK(
      kH / 2 >= pH && kW / 2 >= pW,
      "padding should be smaller than half of kernel_size.");

  // Check output dimensions. pooling_output_shape is given the dilation and
  // ceil_mode, so the computed size accounts for both.
  int64_t oC = iC;
  int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode);
  int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode);
  TORCH_CHECK(oH > 0 && oW > 0,
              "Given input size: (",
              iC, "x", iH, "x", iW,
              "). Calculated output size: (",
              oC, "x", oH, "x", oW,
              "). Output size is too small.");

  // Output shape mirrors the input rank: CHW or NCHW.
  std::vector<int64_t> oSizes;
  if (ndim == 3) {
    oSizes = {oC, oH, oW};
  } else {
    oSizes = {nbatch, oC, oH, oW};
  }

  if (qx.is_contiguous(c10::MemoryFormat::ChannelsLast)) {
    // Fast path case for channels-last case.
    // In this case, we can preserve the data layout in memory
    // as well as use a loop nest that is more amenable to
    // vectorization.
    Tensor qy;
    if constexpr(std::is_same_v<Q, uint8_t>) {
      // Plain uint8 input: allocate an ordinary (non-quantized) output with
      // the same dtype, pinned to CPU and channels-last layout.
      qy = at::empty(
        oSizes,
        qx.options()
          .device(c10::kCPU)
          .dtype(qx.scalar_type())
          .memory_format(c10::MemoryFormat::ChannelsLast));
    } else {
      // Quantized dtype: allocate an affine-quantized output that carries
      // over the input's scale and zero point.
      // NOTE(review): the trailing std::nullopt is presumably the optional
      // memory-format argument (already provided via options) — confirm
      // against the _empty_affine_quantized signature.
      qy = at::_empty_affine_quantized(
          oSizes,
          qx.options()
            .dtype(toQIntType(qx.scalar_type()))
            .memory_format(qx.suggest_memory_format()),
          qx.q_scale(),
          qx.q_zero_point(),
          std::nullopt);
    }
    // Dispatch to the NHWC pooling kernel registered for this device type.
    qmaxpool_2d_nhwc_stub(qx.device().type(), qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
    return qy;
  } else {
    Tensor qy;
    if constexpr(!std::is_same_v<Q, uint8_t>) {
      // Quantized dtype on the non-channels-last path: allocate an
      // affine-quantized contiguous output and pool with the generic
      // NCHW implementation.
      qy = at::_empty_affine_quantized(
              oSizes,
              qx.options().dtype(toQIntType(qx.scalar_type())),
              qx.q_scale(),
              qx.q_zero_point());
      auto qx_contig = qx.contiguous();
      auto qxd = qx_contig.data_ptr<Q>();
      auto qyd = qy.data_ptr<Q>();
      if (ndim == 3 || nbatch == 1) {
        // Single image: pool directly; no batch dimension to parallelize over.
        auto* iData = qxd;
        auto* oData = qyd;
        spatial_dilated_max_pooling<Q>(
            iData,
            iC,
            iH,
            iW,
            oH,
            oW,
            kH,
            kW,
            sH,
            sW,
            pH,
            pW,
            dH,
            dW,
            oData);
      } else {
        // Pool each batch entry independently in parallel.
        at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) {
          for (const auto p : c10::irange(start, end)) {
            // Advance to batch p's contiguous C*H*W slab in input and output.
            auto* iData = qxd + p * iC * iW * iH;
            auto* oData = qyd + p * oC * oW * oH;
            spatial_dilated_max_pooling<Q>(
                iData,
                iC,
                iH,
                iW,
                oH,
                oW,
                kH,
                kW,
                sH,
                sW,
                pH,
                pW,
                dH,
                dW,
                oData);
          }
        });
      }
    } else {
      // If qx is uint8 and contiguous memory format,
      // Use the channels_last implementation and convert qy back to contiguous.
      qy = at::empty(
        oSizes,
        qx.options()
          .device(c10::kCPU)
          .dtype(qx.scalar_type())
          .memory_format(c10::MemoryFormat::ChannelsLast));
      auto qx_nhwc = qx.contiguous(c10::MemoryFormat::ChannelsLast);
      qmaxpool_2d_nhwc_stub(qx_nhwc.device().type(), qx_nhwc, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy);
      // Convert the NHWC result back to the default contiguous (NCHW) layout.
      qy = qy.contiguous();
    }
    return qy;
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free