compute_internal Function Template (`!std::is_same_v<scalar_t, opmath_t>` overload) — PyTorch Architecture
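Architecture documentation for the `compute_internal` function template in MaxPoolKernel.cpp from the PyTorch codebase — specifically the overload enabled via `std::is_same_v` when `scalar_t` differs from `opmath_t` (e.g. bfloat16/half inputs accumulated in float).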
Entity Profile
Source Code
aten/src/ATen/native/cpu/MaxPoolKernel.cpp lines 140–233
template <typename scalar_t, typename opmath_t>
inline
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t>, void>
compute_internal(
const scalar_t* input_data,
scalar_t* out_data,
opmath_t* max_ptr,
vec::int_same_size_t<opmath_t>* index_ptr,
int64_t* ind,
int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
int64_t n,
int64_t len,
int64_t size,
int64_t id0, int64_t id1,
int64_t ih0, int64_t ih1,
int64_t iw0, int64_t iw1,
int64_t dilationD,
int64_t dilationH,
int64_t dilationW) {
using Vec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<opmath_t>;
using iVec = vec::Vectorized<int32_t>;
// Pass I: init out lane
iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
fVec out_vec = fVec(-std::numeric_limits<opmath_t>::infinity());
int64_t d1 = 0;
for (; d1 < len; d1 += fVec::size()) {
index0_vec.store(index_ptr + d1);
out_vec.store(max_ptr + d1);
}
for (; d1 < size; d1++) {
ind[d1] = ih0 * input_width + iw0;
max_ptr[d1] = -std::numeric_limits<opmath_t>::infinity();
}
// Pass II: compute local max
for (int64_t id = id0; id < id1; id += dilationD) {
for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
const scalar_t* in = input_data + (n * input_depth * input_height * input_width +
id * input_height * input_width + ih * input_width + iw) * channels;
int64_t d2 = 0;
for (; d2 < len; d2 += Vec::size()) {
iVec index_ivec = iVec(id * input_height * input_width + ih * input_width + iw);
Vec val_bvec = Vec::loadu(in + d2);
auto [val_fvec0, val_fvec1] = convert_to_float<scalar_t>(val_bvec);
iVec maxindex_ivec0 = iVec::loadu(index_ptr + d2);
iVec maxindex_ivec1 = iVec::loadu(index_ptr + d2 + iVec::size());
fVec maxval_fvec0 = fVec::loadu(max_ptr + d2);
fVec maxval_fvec1 = fVec::loadu(max_ptr + d2 + fVec::size());
// true = all ones, false = all zeros
fVec mask0 = (val_fvec0 > maxval_fvec0) | is_nan_vec(val_fvec0);
fVec mask1 = (val_fvec1 > maxval_fvec1) | is_nan_vec(val_fvec1);
iVec imask0 = vec::cast<int32_t>(mask0);
iVec imask1 = vec::cast<int32_t>(mask1);
fVec max_fvec0 = fVec::blendv(maxval_fvec0, val_fvec0, mask0);
fVec max_fvec1 = fVec::blendv(maxval_fvec1, val_fvec1, mask1);
iVec ind_vec0 = iVec::blendv(maxindex_ivec0, index_ivec, imask0);
iVec ind_vec1 = iVec::blendv(maxindex_ivec1, index_ivec, imask1);
max_fvec0.store(max_ptr + d2);
max_fvec1.store(max_ptr + d2 + fVec::size());
// out_vec.store(out + d2);
ind_vec0.store(index_ptr + d2);
ind_vec1.store(index_ptr + d2 + iVec::size());
}
for (; d2 < size; d2++) {
int64_t index = id * input_height * input_width + ih * input_width + iw;
opmath_t val = opmath_t(in[d2]);
int64_t maxindex = ind[d2];
opmath_t maxval = max_ptr[d2];
bool mask = (val > maxval) || std::isnan(val);
max_ptr[d2] = mask ? val : maxval;
ind[d2] = mask ? index : maxindex;
}
}
}
}
// Convert max values from float to bfloat16/half
int64_t d3 = 0;
for (; d3 < len; d3 += Vec::size()) {
fVec max_fvec0 = fVec::loadu(max_ptr + d3);
fVec max_fvec1 = fVec::loadu(max_ptr + d3 + fVec::size());
Vec max_bvec = convert_from_float<scalar_t>(max_fvec0, max_fvec1);
max_bvec.store(out_data + d3);
}
for (; d3 < size; d3++) {
out_data[d3] = scalar_t(max_ptr[d3]);
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free