compute_internal Function Template — pytorch Architecture
Architecture documentation for the compute_internal function template in MaxPoolKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/MaxPoolKernel.cpp lines 65–137
template <typename scalar_t, typename opmath_t>
inline
std::enable_if_t<std::is_same_v<scalar_t, opmath_t>, void>
compute_internal(
const scalar_t* input_data,
scalar_t* out_data,
opmath_t* max_ptr,
vec::int_same_size_t<opmath_t>* index_ptr,
int64_t* ind,
int64_t input_depth, int64_t input_height, int64_t input_width, int64_t channels,
int64_t n,
int64_t len,
int64_t size,
int64_t id0, int64_t id1,
int64_t ih0, int64_t ih1,
int64_t iw0, int64_t iw1,
int64_t dilationD,
int64_t dilationH,
int64_t dilationW) {
using Vec = vec::Vectorized<scalar_t>;
using integer_t = vec::int_same_size_t<opmath_t>;
using iVec = vec::Vectorized<integer_t>;
// Pass I: init out lane
iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
scalar_t min_value = lower_bound<scalar_t>();
Vec out_vec = Vec(min_value);
int64_t d1 = 0;
for (; d1 < len; d1 += Vec::size()) {
index0_vec.store(index_ptr + d1);
out_vec.store(out_data + d1);
}
for (; d1 < size; d1++) {
ind[d1] = ih0 * input_width + iw0;
out_data[d1] = min_value;
}
// Pass II: compute local max
for (int64_t id = id0; id < id1; id += dilationD) {
for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
const scalar_t* in = input_data + (n * input_depth * input_height * input_width +
id * input_height * input_width + ih * input_width + iw) * channels;
int64_t d2 = 0;
for (; d2 < len; d2 += Vec::size()) {
iVec index_vec = iVec(id * input_height * input_width + ih * input_width + iw);
Vec val_vec = Vec::loadu(in + d2);
iVec maxindex_vec = iVec::loadu(index_ptr + d2);
Vec maxval_vec = Vec::loadu(out_data + d2);
// true = all ones, false = all zeros
Vec mask = (val_vec > maxval_vec) | is_nan_vec(val_vec);
iVec imask = vec::cast<integer_t>(mask);
Vec out_vec = Vec::blendv(maxval_vec, val_vec, mask);
iVec ind_vec = iVec::blendv(maxindex_vec, index_vec, imask);
out_vec.store(out_data + d2);
ind_vec.store(index_ptr + d2);
}
for (; d2 < size; d2++) {
int64_t index = id * input_height * input_width + ih * input_width + iw;
scalar_t val = in[d2];
int64_t maxindex = ind[d2];
scalar_t maxval = out_data[d2];
bool mask = (val > maxval) || is_nan(static_cast<double>(val));
out_data[d2] = mask ? val : maxval;
ind[d2] = mask ? index : maxindex;
}
}
}
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free