Home / Class / spatial_dilated_max_pooling3d Class — pytorch Architecture

spatial_dilated_max_pooling3d Class — pytorch Architecture

Architecture documentation for the spatial_dilated_max_pooling3d function template in Pooling.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/Pooling.cpp lines 96–177

// Dilated 3D max pooling over quantized data, comparing values directly in
// their underlying integer representation (`T::underlying` / `val_`).
//
// Layout (as established by the indexing below): the input is a contiguous
// [nbatch, iC, iT, iH, iW] buffer and the output a contiguous
// [nbatch, iC, oT, oH, oW] buffer; the channel count is unchanged.
//
// For output position (time, row, col), the window along each axis covers
// input coordinates start + k * dilation for k in [0, kernel), where
// start = out_index * stride - padding. Coordinates outside the input are
// skipped. If no in-bounds coordinate remains, the output element is
// T(std::numeric_limits<typename T::underlying>::lowest()).
//
// Work is split over the nbatch * iC (batch, channel) planes with
// at::parallel_for; each iteration writes only its own output plane.
template <typename T>
void spatial_dilated_max_pooling3d(
    const T* qxd,
    int64_t nbatch,
    int64_t iC, // input/output channels
    int64_t iT,
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oT,
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kT,
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sT,
    int64_t sH,
    int64_t sW, // strides
    int64_t pT,
    int64_t pH,
    int64_t pW, // padding
    int64_t dT,
    int64_t dH,
    int64_t dW, // dilation
    T* qyd) { // output data array (no max-index output is produced)
  // TODO: Further optimize the performance suggested by @mingfeima. Parallel on NCTH and cache the output indices from W.
  using underlying_t = typename T::underlying;
  // Number of elements in one (batch, channel) plane of the input / output.
  const int64_t in_plane = iT * iH * iW;
  const int64_t out_plane = oT * oH * oW;
  const int64_t parallel_dim = nbatch * iC;
  at::parallel_for(0, parallel_dim, 0, [&](int64_t start, int64_t end) {
    for (const auto p : c10::irange(start, end)) {
      // p == batch_idx * iC + channel_idx, and input and output have the same
      // channel count, so p directly selects the plane in both buffers.
      const T* i_p = qxd + p * in_plane;
      T* o_p = qyd + p * out_plane;

      for (int64_t time = 0; time < oT; ++time) {
        // Temporal window clipped to [0, iT). The end is derived from the
        // unclipped start so that skipping leading padded taps keeps the
        // remaining taps on the dilation grid.
        int64_t t_start = time * sT - pT;
        const int64_t t_end = std::min(t_start + (kT - 1) * dT + 1, iT);
        while (t_start < 0) {
          t_start += dT;
        }
        for (int64_t row = 0; row < oH; ++row) {
          // Height window, handled the same way as the temporal window.
          int64_t h_start = row * sH - pH;
          const int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH);
          while (h_start < 0) {
            h_start += dH;
          }
          T* o_row = o_p + (time * oH + row) * oW;
          for (int64_t col = 0; col < oW; ++col) {
            // Width window, handled the same way as the temporal window.
            int64_t w_start = col * sW - pW;
            const int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW);
            while (w_start < 0) {
              w_start += dW;
            }

            // Running maximum over the in-bounds window taps.
            auto max_val = std::numeric_limits<underlying_t>::lowest();
            for (int64_t t = t_start; t < t_end; t += dT) {
              for (int64_t y = h_start; y < h_end; y += dH) {
                const T* i_row = i_p + (t * iH + y) * iW;
                for (int64_t x = w_start; x < w_end; x += dW) {
                  const auto val = i_row[x].val_;
                  if (val > max_val) {
                    max_val = val;
                  }
                }
              }
            }
            o_row[col] = T(max_val); // Output.
          }
        }
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free