Home / Class / compute_internal Class — pytorch Architecture

compute_internal Class — pytorch Architecture

Architecture documentation for the compute_internal function template in MaxPoolKernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/MaxPoolKernel.cpp lines 65–137

template <typename scalar_t, typename opmath_t>
inline
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
compute_internal(
  const scalar_t* input_data,
  scalar_t* out_data,
  opmath_t* max_ptr,
  vec::int_same_size_t<opmath_t>* index_ptr,
  int64_t* ind,
  int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
  int64_t n,
  int64_t len,
  int64_t size,
  int64_t id0, int64_t id1,
  int64_t ih0, int64_t ih1,
  int64_t iw0, int64_t iw1,
  int64_t dilationD,
  int64_t dilationH,
  int64_t dilationW) {
  using Vec = vec::Vectorized<scalar_t>;
  using integer_t = vec::int_same_size_t<opmath_t>;
  using iVec = vec::Vectorized<integer_t>;
  // Pass I: init out lane
  iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);

  scalar_t min_value = lower_bound<scalar_t>();
  Vec out_vec = Vec(min_value);
  int64_t d1 = 0;
  for (; d1 < len; d1 += Vec::size()) {
    index0_vec.store(index_ptr + d1);
    out_vec.store(out_data + d1);
  }
  for (; d1 < size; d1++) {
    ind[d1] = ih0 * input_width + iw0;
    out_data[d1] = min_value;
  }
  // Pass II: compute local max
  for (int64_t id = id0; id < id1; id += dilationD) {
    for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
      for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
        const scalar_t* in = input_data + (n * input_depth * input_height * input_width +
            id * input_height * input_width + ih * input_width + iw) * channels;

        int64_t d2 = 0;
        for (; d2 < len; d2 += Vec::size()) {
          iVec index_vec = iVec(id * input_height * input_width + ih * input_width + iw);
          Vec val_vec = Vec::loadu(in + d2);
          iVec maxindex_vec = iVec::loadu(index_ptr + d2);
          Vec maxval_vec = Vec::loadu(out_data + d2);

          // true = all ones, false = all zeros
          Vec mask = (val_vec > maxval_vec) | is_nan_vec(val_vec);
          iVec imask = vec::cast<integer_t>(mask);
          Vec out_vec = Vec::blendv(maxval_vec, val_vec, mask);
          iVec ind_vec = iVec::blendv(maxindex_vec, index_vec, imask);

          out_vec.store(out_data + d2);
          ind_vec.store(index_ptr + d2);
        }
        for (; d2 < size; d2++) {
          int64_t index = id * input_height * input_width + ih * input_width + iw;
          scalar_t val = in[d2];
          int64_t maxindex = ind[d2];
          scalar_t maxval = out_data[d2];

          bool mask = (val > maxval) || is_nan(static_cast<double>(val));
          out_data[d2] = mask ? val : maxval;
          ind[d2] = mask ? index : maxindex;
        }
      }
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free