
is_3d Template Parameter — pytorch Architecture

Architecture documentation for the is_3d boolean template parameter of the cpu_max_pool kernel in MaxPoolKernel.cpp, from the pytorch codebase.
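
is_3d is a compile-time selector rather than a runtime flag: instantiating cpu_max_pool with is_3d = false compiles the MaxPool2d path, while is_3d = true compiles the MaxPool3d path, so the depth-related branches in the kernel are resolved at compile time. A minimal sketch of how a caller might instantiate each variant (the tensors and kernel parameters are illustrative; this is not the actual dispatch code in MaxPoolKernel.cpp):

// Hypothetical call sites, assuming output, indices and input are suitably
// allocated at::Tensor objects.
//
// 2D variant: each IntArrayRef carries exactly dims == 2 values ({W, H}),
// matching the TORCH_CHECK at the top of the kernel.
cpu_max_pool</*scalar_t=*/float, /*is_3d=*/false>(
    output, indices, input,
    /*kWHD=*/{3, 3}, /*dWHD=*/{2, 2}, /*padWHD=*/{1, 1}, /*dilWHD=*/{1, 1});

// 3D variant: the same arrays carry a third (depth) entry, read as kWHD[dims - 1].
cpu_max_pool</*scalar_t=*/float, /*is_3d=*/true>(
    output, indices, input,
    /*kWHD=*/{3, 3, 3}, /*dWHD=*/{2, 2, 2}, /*padWHD=*/{1, 1, 1}, /*dilWHD=*/{1, 1, 1});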

Entity Profile

Source Code

aten/src/ATen/native/cpu/MaxPoolKernel.cpp lines 235–353

template <typename scalar_t, bool is_3d>
void cpu_max_pool(
    const Tensor& output_,
    const Tensor& indices_,
    const Tensor& input_,
    IntArrayRef kWHD,
    IntArrayRef dWHD,
    IntArrayRef padWHD,
    IntArrayRef dilWHD) {
  size_t dims =  is_3d ? 3 : 2;
  TORCH_CHECK(kWHD.size() == dims && dWHD.size() == dims && padWHD.size() == dims && dilWHD.size() == dims,
              "max pooling 2d/3d are not matched");
  int kW = kWHD[0];
  int kH = kWHD[1];
  int dW = dWHD[0];
  int dH = dWHD[1];
  int padW = padWHD[0];
  int padH = padWHD[1];
  int dilationW = dilWHD[0];
  int dilationH = dilWHD[1];

  int kD = is_3d ? kWHD[dims - 1] : 1;
  int dD = is_3d ? dWHD[dims - 1] : 1;
  int padD = is_3d ? padWHD[dims - 1] : 0;
  int dilationD = is_3d ? dilWHD[dims - 1] : 1;

  auto input = input_.contiguous();
  auto output = output_.contiguous();
  auto indices = indices_.contiguous();
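  // the kernel writes into these contiguous buffers; if output_ or indices_
  // was not contiguous, the results are copied back at the end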

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();
  auto indices_data = indices.data_ptr<int64_t>();

  int64_t ndim = input.ndimension();
  // treat batch size and channels as one dimension
  //
  // MaxPool2d:
  //   ndim == 3: CHW
  //   ndim == 4: NCHW
  //
  // MaxPool3d:
  //   ndim == 4: CDHW
  //   ndim == 5: NCDHW
  int64_t channels;
  if (is_3d) {
    channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
  } else {
    channels = ndim == 3 ? input.size(0) : input.size(0) * input.size(1);
  }
  int64_t input_depth = is_3d ? input.size(-3) : 1;
  int64_t input_height = input.size(-2);
  int64_t input_width = input.size(-1);
  int64_t output_depth = is_3d ? output.size(-3) : 1;
  int64_t output_height = output.size(-2);
  int64_t output_width = output.size(-1);

  using opmath_t = at::opmath_type<scalar_t>;
  // parallel on dim N, C
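  // grain size 0: no minimum chunk size, so ATen splits [0, channels) across its thread pool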
  at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
    for (int64_t c = begin; c < end; c++) {
      const scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width;
      scalar_t* output_ptr = output_data + c * output_depth * output_height * output_width;
      int64_t* indices_ptr = indices_data + c * output_depth * output_height * output_width;

      for (int64_t od = 0; od < output_depth; od++) {
        int64_t id0 = od * dD - padD;
        int64_t id1 = std::min(id0 + (kD - 1) * dilationD + 1, input_depth);
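        // padding can make id0 negative; advance it by whole dilation strides
        // to the first in-bounds tap (same pattern for ih0 and iw0 below)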
        while(id0 < 0) { id0 += dilationD; }

        for (int64_t oh = 0; oh < output_height; oh++) {
          int64_t ih0 = oh * dH - padH;
          int64_t ih1 = std::min(ih0 + (kH - 1) * dilationH + 1, input_height);
          while(ih0 < 0) { ih0 += dilationH; }

          for (int64_t ow = 0; ow < output_width; ow++) {
            int64_t iw0 = ow * dW - padW;
            int64_t iw1 = std::min(iw0 + (kW - 1) * dilationW + 1, input_width);
            while(iw0 < 0) { iw0 += dilationW; }

            // compute local max
            int64_t maxindex = id0 * input_height * input_width + ih0 * input_width + iw0;
            opmath_t maxval;
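            // seed the running max with -infinity for floating point types;
            // integer types fall back to their lowest representable value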
            if (std::numeric_limits<opmath_t>::has_infinity) {
              maxval = -std::numeric_limits<opmath_t>::infinity();
            } else {
              maxval = std::numeric_limits<opmath_t>::min();
            }

            for (int64_t id = id0; id < id1; id += dilationD) {
              for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
                for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
                  int64_t index = id * input_height * input_width + ih * input_width + iw;
                  opmath_t val = input_ptr[index];
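                  // NaN compares false with >, so it is tested explicitly:
                  // a NaN anywhere in the window propagates to the output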
                  if ((val > maxval) || is_nan(static_cast<double>(val))) {
                    maxval = val;
                    maxindex = index;
                  }
                }
              }
            }

            // set output to local max and store location of max
            int64_t i = od * output_height * output_width + oh * output_width + ow;
            output_ptr[i] = scalar_t(maxval);
            indices_ptr[i] = maxindex;
          }
        }
      }
    }
  });

  if (!output_.is_contiguous()) {
    output_.copy_(output);
  }
  if (!indices_.is_contiguous()) {
    indices_.copy_(indices);
  }
}
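
For each output position, the loops above derive the pooling window with the same one-dimensional arithmetic applied independently to depth, height and width: start = o * stride - pad, end = min(start + (k - 1) * dilation + 1, input_size), and the start is then stepped forward by whole dilation strides until it is non-negative, so taps that fall in the left padding are skipped. A small standalone sketch (not part of the PyTorch source) that reproduces this arithmetic for a single dimension:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustration only: list the input indices visited for one output position
// along a single dimension, using the same start/end/clamp arithmetic as
// cpu_max_pool.
void print_window(int64_t o, int64_t input_size,
                  int64_t k, int64_t stride, int64_t pad, int64_t dilation) {
  int64_t i0 = o * stride - pad;                                   // window start, may be negative
  int64_t i1 = std::min(i0 + (k - 1) * dilation + 1, input_size);  // one past the last tap
  while (i0 < 0) { i0 += dilation; }                               // skip taps in the left padding
  std::printf("output %lld reads input indices:", (long long)o);
  for (int64_t i = i0; i < i1; i += dilation) {
    std::printf(" %lld", (long long)i);
  }
  std::printf("\n");
}

int main() {
  // e.g. kW = 3, dW = 2, padW = 1, dilationW = 1 over an input of width 7,
  // which yields an output width of 4
  for (int64_t ow = 0; ow < 4; ++ow) {
    print_window(ow, /*input_size=*/7, /*k=*/3, /*stride=*/2, /*pad=*/1, /*dilation=*/1);
  }
  return 0;
}

Because each dimension is clamped independently, a window's taps are exactly the Cartesian product of the three per-dimension index lists, which is what the nested id/ih/iw loops iterate over.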
