Home / Class/ cpu_upsample_linear_backward_channels_last Class — pytorch Architecture

cpu_upsample_linear_backward_channels_last Class — pytorch Architecture

Architecture documentation for cpu_upsample_linear_backward_channels_last, a function template (catalogued on this site as a class entity) defined in UpSampleMoreKernel.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp lines 591–737

template <typename scalar_t, typename scale_type>
void cpu_upsample_linear_backward_channels_last(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  auto ndim = grad_output_.ndimension();
  TORCH_CHECK(ndim >=4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.")

  auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
  auto grad_output = grad_output_.contiguous(channels_last_memory_format);
  auto grad_input = grad_input_.contiguous(channels_last_memory_format);

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();

  int64_t num_batches =  input_sizes[0];
  int64_t channels =  input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = input_sizes[ndim - 2];
  int64_t output_height = output_sizes[ndim - 2];
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];
  int64_t input_slice_size = input_depth * input_height * input_width * channels;
  using opmath_t = at::opmath_type<scalar_t>;

  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
        acc_data_ptr = buffer_data.get();
        memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[0]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[1]);

    auto input_indexr = [=](int64_t n, int64_t h, int64_t w, int64_t offset){
      return acc_data_ptr + offset + (h * input_width + w) * channels;
    };

    opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        int64_t ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
        compute_source_index_and_lambda(
            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
        for (const auto ow : c10::irange(output_width)) {
          compute_source_index_and_lambda(
              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
          const scalar_t* grad_output_ptr = grad_output_data +
              (n * output_height * output_width + oh * output_width + ow) * channels;
          linear_channels_last_acc(input_indexr(n, ih0, iw0, input_offset), grad_output_ptr, h0lambda * w0lambda, channels); /* i00 */
          linear_channels_last_acc(input_indexr(n, ih0, iw1, input_offset), grad_output_ptr, h0lambda * w1lambda, channels); /* i01 */
          linear_channels_last_acc(input_indexr(n, ih1, iw0, input_offset), grad_output_ptr, h1lambda * w0lambda, channels); /* i10 */
          linear_channels_last_acc(input_indexr(n, ih1, iw1, input_offset), grad_output_ptr, h1lambda * w1lambda, channels); /* i11 */
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }

    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
        acc_data_ptr = buffer_data.get();
        memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    const opmath_t depth_scale = area_pixel_compute_scale<opmath_t>(
        input_depth, output_depth, align_corners, scales[0]);
    const opmath_t height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[1]);
    const opmath_t width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[2]);

    auto input_indexr = [=](int64_t n, int64_t d, int64_t h, int64_t w, int64_t offset) {
      return acc_data_ptr + offset + (d * input_height * input_width + h * input_width + w) * channels;
    };

    opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? n * input_slice_size : 0;
      for (const auto od : c10::irange(output_depth)) {
        int64_t id0 = 0, id1 = 0, ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
        compute_source_index_and_lambda(
            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
        for (const auto oh : c10::irange(output_height)) {
          compute_source_index_and_lambda(
              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
          for (const auto ow : c10::irange(output_width)) {
            compute_source_index_and_lambda(
                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);
            const scalar_t* grad_output_ptr = grad_output_data + (n * output_depth * output_height * output_width +
                od *  output_height * output_width + oh * output_width + ow) * channels;
            linear_channels_last_acc(input_indexr(n, id0, ih0, iw0, input_offset), grad_output_ptr, d0lambda * h0lambda * w0lambda, channels); /* i000 */
            linear_channels_last_acc(input_indexr(n, id0, ih0, iw1, input_offset), grad_output_ptr, d0lambda * h0lambda * w1lambda, channels); /* i001 */
            linear_channels_last_acc(input_indexr(n, id0, ih1, iw0, input_offset), grad_output_ptr, d0lambda * h1lambda * w0lambda, channels); /* i010 */
            linear_channels_last_acc(input_indexr(n, id0, ih1, iw1, input_offset), grad_output_ptr, d0lambda * h1lambda * w1lambda, channels); /* i011 */
            linear_channels_last_acc(input_indexr(n, id1, ih0, iw0, input_offset), grad_output_ptr, d1lambda * h0lambda * w0lambda, channels); /* i100 */
            linear_channels_last_acc(input_indexr(n, id1, ih0, iw1, input_offset), grad_output_ptr, d1lambda * h0lambda * w1lambda, channels); /* i101 */
            linear_channels_last_acc(input_indexr(n, id1, ih1, iw0, input_offset), grad_output_ptr, d1lambda * h1lambda * w0lambda, channels); /* i110 */
            linear_channels_last_acc(input_indexr(n, id1, ih1, iw1, input_offset), grad_output_ptr, d1lambda * h1lambda * w1lambda, channels); /* i111 */
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + n * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  if (ndim == 4) {
    // upsample bilinear 2d
    at::parallel_for(0, num_batches, 0, loop2d);
  } else {
    // upsample trilinear 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, num_batches, 0, loop3d);
  }

  if (!grad_input_.is_contiguous(channels_last_memory_format)) {
    grad_input_.copy_(grad_input);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free