nn_compute_source_index_fn Template Parameter — PyTorch Architecture
Architecture documentation for `nn_compute_source_index_fn`, the source-index function template parameter of `upsample_nearest2d_out_frame` in UpSampleNearest2d.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/UpSampleNearest2d.cpp lines 30–84
// Nearest-neighbour 2-D upsampling kernel operating on raw element buffers.
//
// Both buffers are addressed as contiguous [nbatch * channels][height][width]
// arrays: input plane p starts at p * input_height * input_width, and output
// row r (counted across all planes) starts at r * output_width.
// Elements are moved through `scalar_t::underlying` (presumably the integer
// storage type of a quantized scalar — confirm at call sites), so values are
// copied verbatim with no dequantize/requantize step.
//
// Template parameters:
//   scalar_t                   element type; must expose `scalar_t::underlying`.
//   nn_compute_source_index_fn maps (scale, output index, input extent) to the
//                              source index along one spatial axis.
// Parameters:
//   odata / idata              output / input buffers (idata is only read,
//                              odata is only written).
//   input_* / output_*         spatial extents of the input and output.
//   nbatch, channels           folded together into a single plane count.
//   scales_h / scales_w        optional explicit scale factors, forwarded to
//                              compute_scales_value.
//
// Returns without writing anything if there are no planes or the output is
// spatially empty.
// NOTE(review): when the spatial sizes match, the buffer is memcpy'd and
// scales_h / scales_w are not consulted; confirm this matches what the general
// index-mapping path would produce for every scale a caller may pass.
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest2d_out_frame(
    scalar_t* odata,
    scalar_t* idata,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  const float h_ratio = compute_scales_value<float>(scales_h, input_height, output_height);
  const float w_ratio = compute_scales_value<float>(scales_w, input_width, output_width);

  // Every (batch, channel) pair is an independent 2-D plane.
  const int64_t num_planes = channels * nbatch;
  const bool output_is_empty =
      num_planes == 0 || output_height == 0 || output_width == 0;
  if (output_is_empty) {
    return;
  }

  // Work on the raw storage type so elements are copied bit-for-bit.
  using underlying_t = typename scalar_t::underlying;
  auto* const src = reinterpret_cast<underlying_t*>(idata);
  auto* const dst = reinterpret_cast<underlying_t*>(odata);

  if (input_width == output_width && input_height == output_height) {
    // Same spatial size: the whole buffer is copied in one shot.
    std::memcpy(dst, src, num_planes * input_height * input_width * sizeof(underlying_t));
    return;
  }

  // The source column depends only on the output column, so it is resolved
  // once here instead of once per output row.
  std::vector<int64_t> src_col(output_width);
  int64_t* const col_lut = src_col.data();
  for (int64_t ox = 0; ox < output_width; ++ox) {
    col_lut[ox] = nn_compute_source_index_fn(w_ratio, ox, input_width);
  }

  const int64_t in_plane_size = input_height * input_width;
  const int64_t total_rows = num_planes * output_height;
  // Rows-per-chunk hint for parallel_for, sized so that one chunk spans about
  // GRAIN_SIZE output elements.
  const int64_t rows_per_task = internal::GRAIN_SIZE / std::max(int64_t{1}, output_width);

  at::parallel_for(0, total_rows, rows_per_task, [&](int64_t first_row, int64_t last_row) {
    // Decompose the flat starting row into (plane, output y), then advance
    // both incrementally as the row counter moves forward.
    int64_t plane = 0;
    int64_t oy = 0;
    data_index_init(first_row, plane, num_planes, oy, output_height);
    for (int64_t row = first_row; row < last_row; ++row) {
      const int64_t iy = nn_compute_source_index_fn(h_ratio, oy, input_height);
      const underlying_t* const src_row = src + (plane * in_plane_size + iy * input_width);
      underlying_t* const dst_row = dst + row * output_width;
      for (int64_t ox = 0; ox < output_width; ++ox) {
        dst_row[ox] = src_row[col_lut[ox]];
      }
      data_index_step(plane, num_planes, oy, output_height);
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free