Home / Class/ is_same_v Class — pytorch Architecture

is_same_v Class — pytorch Architecture

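Architecture documentation for the `compute_internal` function template in MaxPoolKernel.cpp from the PyTorch codebase. The page is indexed under `is_same_v` because this overload is selected via `std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>>`. That condition holds for the reduced-precision (bfloat16/half) path, which accumulates in a wider math type.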

Entity Profile

Source Code

aten/src/ATen/native/cpu/MaxPoolKernel.cpp lines 140–233

template <typename scalar_t, typename opmath_t>
inline
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
compute_internal(
  const scalar_t* input_data,
  scalar_t* out_data,
  opmath_t* max_ptr,
  vec::int_same_size_t<opmath_t>* index_ptr,
  int64_t* ind,
  int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
  int64_t n,
  int64_t len,
  int64_t size,
  int64_t id0, int64_t id1,
  int64_t ih0, int64_t ih1,
  int64_t iw0, int64_t iw1,
  int64_t dilationD,
  int64_t dilationH,
  int64_t dilationW) {
  using Vec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<opmath_t>;
  using iVec = vec::Vectorized<int32_t>;
  // Pass I: init out lane
  iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
  fVec out_vec = fVec(-std::numeric_limits<opmath_t>::infinity());
  int64_t d1 = 0;
  for (; d1 < len; d1 += fVec::size()) {
    index0_vec.store(index_ptr + d1);
    out_vec.store(max_ptr + d1);
  }
  for (; d1 < size; d1++) {
    ind[d1] = ih0 * input_width + iw0;
    max_ptr[d1] = -std::numeric_limits<opmath_t>::infinity();
  }
  // Pass II: compute local max
  for (int64_t id = id0; id < id1; id += dilationD) {
    for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
      for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
        const scalar_t* in = input_data + (n * input_depth * input_height * input_width +
            id * input_height * input_width + ih * input_width + iw) * channels;

        int64_t d2 = 0;
        for (; d2 < len; d2 += Vec::size()) {
          iVec index_ivec = iVec(id * input_height * input_width + ih * input_width + iw);
          Vec val_bvec = Vec::loadu(in + d2);
          auto [val_fvec0, val_fvec1] = convert_to_float<scalar_t>(val_bvec);

          iVec maxindex_ivec0 = iVec::loadu(index_ptr + d2);
          iVec maxindex_ivec1 = iVec::loadu(index_ptr + d2 + iVec::size());
          fVec maxval_fvec0 = fVec::loadu(max_ptr + d2);
          fVec maxval_fvec1 = fVec::loadu(max_ptr + d2 + fVec::size());

          // true = all ones, false = all zeros
          fVec mask0 = (val_fvec0 > maxval_fvec0) | is_nan_vec(val_fvec0);
          fVec mask1 = (val_fvec1 > maxval_fvec1) | is_nan_vec(val_fvec1);
          iVec imask0 = vec::cast<int32_t>(mask0);
          iVec imask1 = vec::cast<int32_t>(mask1);

          fVec max_fvec0 = fVec::blendv(maxval_fvec0, val_fvec0, mask0);
          fVec max_fvec1 = fVec::blendv(maxval_fvec1, val_fvec1, mask1);
          iVec ind_vec0 = iVec::blendv(maxindex_ivec0, index_ivec, imask0);
          iVec ind_vec1 = iVec::blendv(maxindex_ivec1, index_ivec, imask1);

          max_fvec0.store(max_ptr + d2);
          max_fvec1.store(max_ptr + d2 + fVec::size());
          // out_vec.store(out + d2);
          ind_vec0.store(index_ptr + d2);
          ind_vec1.store(index_ptr + d2 + iVec::size());
        }
        for (; d2 < size; d2++) {
          int64_t index = id * input_height * input_width + ih * input_width + iw;
          opmath_t val = opmath_t(in[d2]);
          int64_t maxindex = ind[d2];
          opmath_t maxval = max_ptr[d2];

          bool mask = (val > maxval) || std::isnan(val);
          max_ptr[d2] = mask ? val : maxval;
          ind[d2] = mask ? index : maxindex;
        }
      }
    }
  }
  // Convert max values from float to bfloat16/half
  int64_t d3 = 0;
  for (; d3 < len; d3 += Vec::size()) {
    fVec max_fvec0 = fVec::loadu(max_ptr + d3);
    fVec max_fvec1 = fVec::loadu(max_ptr + d3 + fVec::size());
    Vec max_bvec = convert_from_float<scalar_t>(max_fvec0, max_fvec1);
    max_bvec.store(out_data + d3);
  }
  for (; d3 < size; d3++) {
    out_data[d3] = scalar_t(max_ptr[d3]);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free