Home / Class / cpu_upsample_linear_channels_last Class — pytorch Architecture

cpu_upsample_linear_channels_last Class — pytorch Architecture

Architecture documentation for the cpu_upsample_linear_channels_last function template (listed here as a class) in UpSampleKernel.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/UpSampleKernel.cpp lines 563–722

template <typename scalar_t, typename scale_type>
void cpu_upsample_linear_channels_last(
    const Tensor& output_,
    const Tensor& input_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(input_.dtype() == output_.dtype(), "expected dtype ", input_.dtype(),
              " for `output` but got dtype ", output_.dtype());

  auto input_sizes = input_.sizes().vec();
  auto output_sizes = output_.sizes().vec();
  auto ndim = input_sizes.size();
  TORCH_CHECK(ndim >=4 && ndim <= 5, "Upsample with NHWC format supports tensors with 4 or 5 dims.")

  auto channels_last_memory_format = ndim == 4 ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::ChannelsLast3d;
  auto input = input_.contiguous(channels_last_memory_format);
  auto output = output_.contiguous(channels_last_memory_format);

  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t num_batches =  input_sizes[0];
  int64_t channels =  input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = input_sizes[ndim - 2];
  int64_t output_height = output_sizes[ndim - 2];
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];

  TORCH_CHECK(channels > 0, "expected input and output channels greater than 0 but got ", channels);
  int64_t output_slice_size = output_depth * output_height * output_width * channels;

  using opmath_t = at::opmath_type<scalar_t>;
  using Vec = vec::Vectorized<scalar_t>;
  auto loop2d = [&](int64_t begin, int64_t end) {
    const auto height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[0]);
    const auto width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[1]);

    auto input_indexr = [=](int64_t n, int64_t h, int64_t w) {
      return input_data + n * input_height * input_width * channels +
          h * input_width * channels + w * channels;
    };

    int64_t ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
    opmath_t h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      for (const auto oh : c10::irange(output_height)) {
        compute_source_index_and_lambda(
            ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
        for (const auto ow : c10::irange(output_width)) {
          compute_source_index_and_lambda(
              iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);

          scalar_t* out = output_data + n * output_slice_size +
              oh * output_width * channels + ow * channels;
          const scalar_t* i00 = input_indexr(n, ih0, iw0);
          const scalar_t* i01 = input_indexr(n, ih0, iw1);
          const scalar_t* i10 = input_indexr(n, ih1, iw0);
          const scalar_t* i11 = input_indexr(n, ih1, iw1);
          opmath_t w00 = h0lambda * w0lambda;
          opmath_t w01 = h0lambda * w1lambda;
          opmath_t w10 = h1lambda * w0lambda;
          opmath_t w11 = h1lambda * w1lambda;

          int64_t size = channels;
          int64_t d = 0;
          for (; d < size - (size % Vec::size()); d += Vec::size()) {
            auto out_vec = interpolate(i00 + d, w00, i01 + d, w01, i10 + d, w10, i11 + d, w11);
            out_vec.store(out + d);
          }
          for (; d < size; d++) {
            out[d] = i00[d] * w00 + i01[d] * w01 + i10[d] * w10 + i11[d] * w11;
          }
        }
      }
    }
  };

  auto loop3d = [&](int64_t begin, int64_t end) {
    const auto depth_scale = area_pixel_compute_scale<opmath_t>(
        input_depth, output_depth, align_corners, scales[0]);
    const auto height_scale = area_pixel_compute_scale<opmath_t>(
        input_height, output_height, align_corners, scales[1]);
    const auto width_scale = area_pixel_compute_scale<opmath_t>(
        input_width, output_width, align_corners, scales[2]);

    auto input_indexr = [=](int64_t n, int64_t d, int64_t h, int64_t w) {
      return input_data + n * input_depth * input_height * input_width * channels +
          d * input_height * input_width * channels +
          h * input_width * channels + w * channels;
    };

    int64_t id0 = 0, id1 = 0, ih0 = 0, ih1 = 0, iw0 = 0, iw1 = 0;
    opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda;
    for (const auto n : c10::irange(begin, end)) {
      for (const auto od : c10::irange(output_depth)) {
        compute_source_index_and_lambda(
            id0, id1, d0lambda, d1lambda, depth_scale, od, input_depth, output_depth, align_corners);
        for (const auto oh : c10::irange(output_height)) {
          compute_source_index_and_lambda(
              ih0, ih1, h0lambda, h1lambda, height_scale, oh, input_height, output_height, align_corners);
          for (const auto ow : c10::irange(output_width)) {
            compute_source_index_and_lambda(
                iw0, iw1, w0lambda, w1lambda, width_scale, ow, input_width, output_width, align_corners);

            scalar_t* out = output_data + n * output_slice_size +
                od * output_height * output_width * channels +
                oh * output_width * channels + ow * channels;
            const scalar_t* i000 = input_indexr(n, id0, ih0, iw0);
            const scalar_t* i001 = input_indexr(n, id0, ih0, iw1);
            const scalar_t* i010 = input_indexr(n, id0, ih1, iw0);
            const scalar_t* i011 = input_indexr(n, id0, ih1, iw1);
            const scalar_t* i100 = input_indexr(n, id1, ih0, iw0);
            const scalar_t* i101 = input_indexr(n, id1, ih0, iw1);
            const scalar_t* i110 = input_indexr(n, id1, ih1, iw0);
            const scalar_t* i111 = input_indexr(n, id1, ih1, iw1);
            opmath_t w000 = d0lambda * h0lambda * w0lambda;
            opmath_t w001 = d0lambda * h0lambda * w1lambda;
            opmath_t w010 = d0lambda * h1lambda * w0lambda;
            opmath_t w011 = d0lambda * h1lambda * w1lambda;
            opmath_t w100 = d1lambda * h0lambda * w0lambda;
            opmath_t w101 = d1lambda * h0lambda * w1lambda;
            opmath_t w110 = d1lambda * h1lambda * w0lambda;
            opmath_t w111 = d1lambda * h1lambda * w1lambda;

            int64_t size = channels;
            int64_t d = 0;
            for (; d < size - (size % Vec::size()); d += Vec::size()) {
              auto out_vec = interpolate(
                  i000 + d, w000, i001 + d, w001, i010 + d, w010, i011 + d, w011,
                  i100 + d, w100, i101 + d, w101, i110 + d, w110, i111 + d, w111);
              out_vec.store(out + d);
            }
            for (; d < size; d++) {
              out[d] =
                  i000[d] * w000 + i001[d] * w001 + i010[d] * w010 + i011[d] * w011 +
                  i100[d] * w100 + i101[d] * w101 + i110[d] * w110 + i111[d] * w111;
            }
          }
        }
      }
    }
  };

  if (ndim == 4) {
    // upsample nearest 2d
    at::parallel_for(0, num_batches, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
  } else {
    // upsample nearest 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, num_batches, at::internal::GRAIN_SIZE / output_slice_size / 8, loop3d);
  }

  if (!output_.is_contiguous(channels_last_memory_format)) {
    output_.copy_(output);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free