cpu_upsample_linear_channels_last Function — pytorch Architecture
Architecture documentation for the cpu_upsample_linear_channels_last function template in UpSampleKernel.cpp from the pytorch codebase. (It is a free function template, not a class.)
Entity Profile
Source Code
aten/src/ATen/native/cpu/UpSampleKernel.cpp lines 563–722
// Linear (bilinear for 4-D NHWC, trilinear for 5-D NDHWC) upsampling kernel
// specialized for channels-last memory format. Because channels are the
// innermost (contiguous) dimension in this layout, the per-output-pixel inner
// loop runs over `channels` and can be vectorized with Vectorized<scalar_t>.
//
// Template parameters:
//   scalar_t   - element type of input/output tensors.
//   scale_type - container of optional user-supplied scale factors
//                (indexed [0..1] for 2d, [0..2] for 3d).
//
// Arguments:
//   output_       - destination tensor; written in-place (via a channels-last
//                   contiguous copy that is copied back at the end if needed).
//   input_        - source tensor; must have the same dtype as output_.
//   align_corners - corner-alignment convention forwarded to the source-index
//                   computation helpers.
//   scales        - per-spatial-dimension scale overrides.
template <typename scalar_t, typename scale_type>
void cpu_upsample_linear_channels_last(
    const Tensor& output_,
    const Tensor& input_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(input_.dtype() == output_.dtype(), "expected dtype ", input_.dtype(),
              " for `output` but got dtype ", output_.dtype());

  auto input_sizes = input_.sizes().vec();
  auto output_sizes = output_.sizes().vec();
  auto ndim = input_sizes.size();
  TORCH_CHECK(ndim >=4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.")

  // Pick the matching channels-last layout for 4-D vs 5-D tensors.
  auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
  // May materialize copies if the originals are not already channels-last
  // contiguous; the result is copied back into output_ at the end if so.
  auto input = input_.contiguous(channels_last_memory_format);
  auto output = output_.contiguous(channels_last_memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t num_batches = input_sizes[0];
  int64_t channels = input_sizes[1];
  // Depth dimensions only exist for the 5-D (3d upsample) case; 1 otherwise.
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = input_sizes[ndim - 2];
  int64_t output_height = output_sizes[ndim - 2];
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];

  TORCH_CHECK(channels > 0, "expected input and output channels greater than 0 but got ", channels);
  // Elements written per batch entry.
  int64_t output_slice_size = output_depth * output_height * output_width * channels;

  // Accumulate interpolation weights/products in opmath_t (e.g. float for
  // reduced-precision scalar_t) to limit rounding error.
  using opmath_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<scalar_t>;

  // 2d (bilinear) path: iterates batches [begin, end).
  auto loop2d = [&](int64_t begin, int64_t end) {
    const auto height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[0]);
    const auto width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[1]);

    // Pointer to the start of the channel vector at input pixel (n, h, w)
    // in channels-last layout.
    auto input_indexr = [=](int64_t n, int64_t h, int64_t w) {
      return input_data + n * input_height * input_width * channels +
          h * input_width * channels + w * channels;
    };

    int64_t ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
    opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      for (const auto oh : c10::irange(output_height)) {
        // Fills the two bracketing source rows (ih0, ih1) and their
        // interpolation weights (h0lambda + h1lambda == 1).
        compute_source_index_and_lambda(
            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
        for (const auto ow : c10::irange(output_width)) {
          compute_source_index_and_lambda(
              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);

          scalar_t* out = output_data + n * output_slice_size +
              oh * output_width * channels + ow * channels;
          // Four neighboring source pixels for bilinear interpolation.
          const scalar_t* i00 = input_indexr(n, ih0, iw0);
          const scalar_t* i01 = input_indexr(n, ih0, iw1);
          const scalar_t* i10 = input_indexr(n, ih1, iw0);
          const scalar_t* i11 = input_indexr(n, ih1, iw1);
          opmath_t w00 = h0lambda * w0lambda;
          opmath_t w01 = h0lambda * w1lambda;
          opmath_t w10 = h1lambda * w0lambda;
          opmath_t w11 = h1lambda * w1lambda;

          // Vectorized main loop over the contiguous channel dimension...
          int64_t size = channels;
          int64_t d = 0;
          for (; d < size - (size % Vec::size()); d += Vec::size()) {
            auto out_vec = interpolate(i00 + d, w00, i01 + d, w01, i10 + d, w10, i11 + d, w11);
            out_vec.store(out + d);
          }
          // ...scalar tail for the remaining channels < Vec::size().
          for (; d < size; d++) {
            out[d] = i00[d] * w00 + i01[d] * w01 + i10[d] * w10 + i11[d] * w11;
          }
        }
      }
    }
  };

  // 3d (trilinear) path: same structure with an extra depth dimension,
  // blending eight neighboring source voxels.
  auto loop3d = [&](int64_t begin, int64_t end) {
    const auto depth_scale = area_pixel_compute_scale<opmath_t>(
        input_depth, output_depth, align_corners, scales[0]);
    const auto height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[1]);
    const auto width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[2]);

    // Pointer to the channel vector at input voxel (n, d, h, w).
    auto input_indexr = [=](int64_t n, int64_t d, int64_t h, int64_t w) {
      return input_data + n * input_depth * input_height * input_width * channels +
          d * input_height * input_width * channels +
          h * input_width * channels + w * channels;
    };

    int64_t id0 = 0, id1 = 0, ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
    opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      for (const auto od : c10::irange(output_depth)) {
        compute_source_index_and_lambda(
            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
        for (const auto oh : c10::irange(output_height)) {
          compute_source_index_and_lambda(
              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
          for (const auto ow : c10::irange(output_width)) {
            compute_source_index_and_lambda(
                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);

            scalar_t* out = output_data + n * output_slice_size +
                od * output_height * output_width * channels +
                oh * output_width * channels + ow * channels;
            // Eight corner voxels of the enclosing source cell.
            const scalar_t* i000 = input_indexr(n, id0, ih0, iw0);
            const scalar_t* i001 = input_indexr(n, id0, ih0, iw1);
            const scalar_t* i010 = input_indexr(n, id0, ih1, iw0);
            const scalar_t* i011 = input_indexr(n, id0, ih1, iw1);
            const scalar_t* i100 = input_indexr(n, id1, ih0, iw0);
            const scalar_t* i101 = input_indexr(n, id1, ih0, iw1);
            const scalar_t* i110 = input_indexr(n, id1, ih1, iw0);
            const scalar_t* i111 = input_indexr(n, id1, ih1, iw1);
            // Trilinear weights: products of the per-axis lambdas (sum to 1).
            opmath_t w000 = d0lambda * h0lambda * w0lambda;
            opmath_t w001 = d0lambda * h0lambda * w1lambda;
            opmath_t w010 = d0lambda * h1lambda * w0lambda;
            opmath_t w011 = d0lambda * h1lambda * w1lambda;
            opmath_t w100 = d1lambda * h0lambda * w0lambda;
            opmath_t w101 = d1lambda * h0lambda * w1lambda;
            opmath_t w110 = d1lambda * h1lambda * w0lambda;
            opmath_t w111 = d1lambda * h1lambda * w1lambda;

            // Vectorized main loop over channels, then scalar tail.
            int64_t size = channels;
            int64_t d = 0;
            for (; d < size - (size % Vec::size()); d += Vec::size()) {
              auto out_vec = interpolate(
                  i000 + d, w000, i001 + d, w001, i010 + d, w010, i011 + d, w011,
                  i100 + d, w100, i101 + d, w101, i110 + d, w110, i111 + d, w111);
              out_vec.store(out + d);
            }
            for (; d < size; d++) {
              out[d] =
                  i000[d] * w000 + i001[d] * w001 + i010[d] * w010 + i011[d] * w011 +
                  i100[d] * w100 + i101[d] * w101 + i110[d] * w110 + i111[d] * w111;
            }
          }
        }
      }
    }
  };

  if (ndim == 4) {
    // upsample linear 2d (NOTE(review): the original comment said "nearest",
    // a copy-paste slip — this kernel performs bilinear interpolation).
    // Grain divisor 4: presumably accounts for the four source reads per
    // output element — TODO confirm against the nearest-kernel heuristic.
    at::parallel_for(0, num_batches, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
  } else {
    // upsample linear 3d (original comment said "nearest"; see note above).
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, num_batches, at::internal::GRAIN_SIZE / output_slice_size / 8, loop3d);
  }

  // If output_ was not already channels-last contiguous, `output` above is a
  // temporary copy — write the results back into the caller's tensor.
  if (!output_.is_contiguous(channels_last_memory_format)) {
    output_.copy_(output);
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free