nearest_idx_fn (template parameter) — PyTorch Architecture
Architecture documentation for `nearest_idx_fn`, the nearest-index function template parameter of `cpu_upsample_nearest_channels_last` in UpSampleKernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/UpSampleKernel.cpp lines 458–551
template <typename scalar_t, typename scale_type, nearest_idx_fn_t nearest_idx_fn>
void cpu_upsample_nearest_channels_last(
const Tensor& output_,
const Tensor& input_,
const scale_type& scales) {
TORCH_CHECK(input_.dtype() == output_.dtype(), "expected dtype ", input_.dtype(),
" for `output` but got dtype ", output_.dtype());
auto input_sizes = input_.sizes().vec();
auto output_sizes = output_.sizes().vec();
auto ndim = input_sizes.size();
TORCH_CHECK(ndim >=4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.")
auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
auto input = input_.contiguous(channels_last_memory_format);
auto output = output_.contiguous(channels_last_memory_format);
auto input_data = input.const_data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
int64_t num_batches = input_sizes[0];
int64_t channels = input_sizes[1];
int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
int64_t input_height = input_sizes[ndim - 2];
int64_t output_height = output_sizes[ndim - 2];
int64_t input_width = input_sizes[ndim - 1];
int64_t output_width = output_sizes[ndim - 1];
int64_t numel = output.numel();
TORCH_CHECK(channels > 0, "expected input and output channels greater than 0 but got ", channels);
using Vec = vec::Vectorized<scalar_t>;
auto copy = [](scalar_t* out, const scalar_t* in, int64_t size) {
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec out_vec = Vec::loadu(in + d);
out_vec.store(out + d);
}
for (; d < size; d++) {
out[d] = in[d];
}
};
auto loop2d = [&](int64_t begin, int64_t end) {
int64_t n = 0;
int64_t oh = 0;
int64_t ow = 0;
data_index_init(begin, n, num_batches, oh, output_height, ow, output_width);
for (const auto i : c10::irange(begin, end)) {
int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[0]);
int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[1]);
scalar_t* output_ptr = output_data + i * channels;
const scalar_t* input_ptr = input_data + n * input_height * input_width * channels +
ih * input_width * channels + iw * channels;
copy(output_ptr, input_ptr, channels);
data_index_step(n, num_batches, oh, output_height, ow, output_width);
}
};
auto loop3d = [&](int64_t begin, int64_t end) {
int64_t n = 0;
int64_t od = 0;
int64_t oh = 0;
int64_t ow = 0;
data_index_init(begin, n, num_batches, od, output_depth, oh, output_height, ow, output_width);
for (const auto i : c10::irange(begin, end)) {
int64_t id = nearest_idx_fn(od, input_depth, output_depth, scales[0]);
int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[1]);
int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[2]);
scalar_t* output_ptr = output_data + i * channels;
const scalar_t* input_ptr = input_data + n * input_depth * input_height * input_width * channels +
id * input_height * input_width * channels +
ih * input_width * channels + iw * channels;
copy(output_ptr, input_ptr, channels);
data_index_step(n, num_batches, od, output_depth, oh, output_height, ow, output_width);
}
};
if (ndim == 4) {
// upsample nearest 2d
at::parallel_for(0, numel / channels, at::internal::GRAIN_SIZE / channels, loop2d);
} else {
// upsample nearest 3d
TORCH_INTERNAL_ASSERT(ndim == 5);
at::parallel_for(0, numel / channels, at::internal::GRAIN_SIZE / channels, loop3d);
}
if (!output_.is_contiguous(channels_last_memory_format)) {
output_.copy_(output);
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free