cpu_upsample_linear_backward Function Template — pytorch Architecture
Architecture documentation for the cpu_upsample_linear_backward function template in UpSampleMoreKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp lines 421–589
// Backward pass for linear (1d) / bilinear (2d) / trilinear (3d) upsampling
// on CPU. Each grad_output element is scattered back to the 2 / 4 / 8 input
// positions that produced it, weighted by the interpolation lambdas.
//
// grad_input_  : tensor to accumulate gradients into. Values are added with
//                `+=`, so this presumably arrives zero-initialized from the
//                caller — TODO(review): confirm against the dispatching op.
// grad_output_ : incoming gradient, same dtype as grad_input_.
// align_corners, scales : forwarded to the index/lambda helpers; semantics
//                match the forward kernel's sampling grid.
template <typename scalar_t, typename scale_type>
void cpu_upsample_linear_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());
  // Work on contiguous views; if grad_input_ was non-contiguous we accumulate
  // into a contiguous copy and copy back at the end (see bottom of function).
  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();
  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();
  auto ndim = input_sizes.size();
  // treat nbatch and channels as one dimension
  int64_t channels = input_sizes[0] * input_sizes[1];
  // Missing spatial dims collapse to size 1 so the same slice math covers
  // ndim == 3 (1d), 4 (2d) and 5 (3d) layouts.
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];
  int64_t input_slice_size = input_depth * input_height * input_width;
  int64_t output_slice_size = output_depth * output_height * output_width;
  // opmath_t widens reduced-precision scalar types for accumulation; when it
  // differs from scalar_t each lambda accumulates into a per-thread scratch
  // buffer instead of writing grad_input directly.
  using opmath_t = at::opmath_type<scalar_t>;
  // 1d case (ndim == 3): scatter along width only.
  auto loop1d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      // One input-slice-sized scratch buffer per worker thread.
      // NOTE(review): it is zeroed only once here; apply_grad_input below is
      // presumably responsible for resetting it between channels — confirm.
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      // Same dtype: accumulate straight into grad_input. Channel ranges given
      // to each thread are disjoint, so the slices do not overlap.
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[0]);
    opmath_t w0lambda, w1lambda;
    for (const auto c : c10::irange(begin, end)) {
      // Scratch buffer holds a single slice (offset 0); direct path indexes
      // the full grad_input tensor by channel.
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto ow : c10::irange(output_width)) {
        int64_t iw0 = 0, iw1 = 0;
        compute_source_index_and_lambda(
            iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
        opmath_t grad_output_value = grad_output_data[c * output_slice_size + ow];
        acc_data_ptr[input_offset + iw0] += w0lambda * grad_output_value; /* i0 */
        acc_data_ptr[input_offset + iw1] += w1lambda * grad_output_value; /* i1*/
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        // Flush the widened accumulator into this channel's grad_input slice.
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };
  // 2d case (ndim == 4): each output element contributes to a 2x2 input patch.
  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }
    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[0]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[1]);
    opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        int64_t ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
        // Height index/lambdas are loop-invariant over ow, so computed once
        // per output row.
        compute_source_index_and_lambda(
            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
        for (const auto ow : c10::irange(output_width)) {
          compute_source_index_and_lambda(
              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
          opmath_t grad_output_value = grad_output_data[c * output_slice_size + oh * output_width + ow];
          // Bilinear scatter: weights are the outer product of the h/w lambdas.
          acc_data_ptr[input_offset + ih0 * input_width + iw0] += h0lambda * w0lambda * grad_output_value; /* i00 */
          acc_data_ptr[input_offset + ih0 * input_width + iw1] += h0lambda * w1lambda * grad_output_value; /* i01 */
          acc_data_ptr[input_offset + ih1 * input_width + iw0] += h1lambda * w0lambda * grad_output_value; /* i10 */
          acc_data_ptr[input_offset + ih1 * input_width + iw1] += h1lambda * w1lambda * grad_output_value; /* i11 */
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };
  // 3d case (ndim == 5): each output element contributes to a 2x2x2 input cube.
  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }
    const opmath_t depth_scale = area_pixel_compute_scale<opmath_t>(
        input_depth, output_depth, align_corners, scales[0]);
    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[1]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[2]);
    opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto od : c10::irange(output_depth)) {
        int64_t id0 = 0, id1 = 0, ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
        // Depth and height index/lambdas hoisted out of the inner loops they
        // are invariant over.
        compute_source_index_and_lambda(
            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
        for (const auto oh : c10::irange(output_height)) {
          compute_source_index_and_lambda(
              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
          for (const auto ow : c10::irange(output_width)) {
            compute_source_index_and_lambda(
                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
            opmath_t grad_output_value = grad_output_data[c * output_slice_size +
                od * output_height * output_width + oh * output_width + ow];
            // Trilinear scatter: weights are the product of the d/h/w lambdas
            // for each of the 8 corners.
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih0 * input_width + iw0] += d0lambda * h0lambda * w0lambda * grad_output_value; /* i000 */
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih0 * input_width + iw1] += d0lambda * h0lambda * w1lambda * grad_output_value; /* i001 */
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih1 * input_width + iw0] += d0lambda * h1lambda * w0lambda * grad_output_value; /* i010 */
            acc_data_ptr[input_offset + id0 * input_height * input_width + ih1 * input_width + iw1] += d0lambda * h1lambda * w1lambda * grad_output_value; /* i011 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih0 * input_width + iw0] += d1lambda * h0lambda * w0lambda * grad_output_value; /* i100 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih0 * input_width + iw1] += d1lambda * h0lambda * w1lambda * grad_output_value; /* i101 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih1 * input_width + iw0] += d1lambda * h1lambda * w0lambda * grad_output_value; /* i110 */
            acc_data_ptr[input_offset + id1 * input_height * input_width + ih1 * input_width + iw1] += d1lambda * h1lambda * w1lambda * grad_output_value; /* i111 */
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };
  // Parallelize over collapsed (batch * channel). The grain size is divided by
  // 2 / 4 / 8 — the number of += writes per output element — so per-task work
  // stays roughly constant across dimensionalities.
  if (ndim == 3) {
    // upsample linear 1d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 2, loop1d);
  } else if (ndim == 4) {
    // upsample bilinear 2d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
  } else {
    // upsample trilinear 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size / 8, loop3d);
  }
  // If the caller's tensor was non-contiguous, all accumulation happened in
  // the contiguous copy; write the result back into the original layout.
  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free