Home / Class / cpu_pixel_shuffle_channels_last Class — pytorch Architecture

cpu_pixel_shuffle_channels_last Class — pytorch Architecture

Architecture documentation for the cpu_pixel_shuffle_channels_last function template (a CPU kernel, not a class) in PixelShuffleKernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/PixelShuffleKernel.cpp lines 55–111

template <typename scalar_t>
// Pixel shuffle for a 4-d input stored in channels-last memory format.
//
// Logically input is [N, C, H, W] and output is [N, C / (S*S), H*S, W*S].
// The pointer arithmetic below assumes both tensors are contiguous in
// channels-last order (physical layout [n, h, w, c]); presumably the caller
// guarantees this and that C is divisible by S*S -- TODO confirm at call site.
//
// @param output          destination tensor, written in place.
// @param input           source tensor; must have exactly 4 dims (checked).
// @param upscale_factor  spatial upscale factor S.
void cpu_pixel_shuffle_channels_last(
    TensorBase& output,
    const TensorBase& input,
    int64_t upscale_factor) {
  TORCH_CHECK(input.ndimension() == 4,
              "pixel shuffle with channels last format supports tensors with 4 dims");
  auto input_data = input.const_data_ptr<scalar_t>();
  auto output_data = output.data_ptr<scalar_t>();

  int64_t nbatch = input.size(0);
  int64_t channels = input.size(1);
  int64_t height = input.size(2);
  int64_t width = input.size(3);
  int64_t S = upscale_factor;
  int64_t sub_channels = channels / (S * S);

  // input tensor shape of [n, h, w, c, s1, s2]
  // output tensor shape of [n, h, s1, w, s2, c]
  using Vec = vec::Vectorized<scalar_t>;

  // Number of contiguous elements copied per (w, s1): S values of s2, each
  // holding sub_channels values. Loop-invariant, so computed once here.
  const int64_t row_size = S * sub_channels;
  const int64_t row_size_vec = row_size - (row_size % Vec::size());
  // Both the input (n, h) slab [w, c] and the output (n, h) slab
  // [s1, w, s2, c] hold width * channels elements.
  const int64_t slab_size = width * channels;

  at::parallel_for(0, nbatch * height, 0, [&](int64_t begin, int64_t end) {
    // temp buffer holding each channel lane
    auto buffer = std::make_unique<scalar_t []>(channels);
    scalar_t* buffer_ptr = buffer.get();

    // i enumerates (n, h) pairs as i == n * height + h, so it addresses the
    // matching slab in both tensors directly; no separate n/h tracking needed.
    for (const auto i : c10::irange(begin, end)) {
      const scalar_t* input_slab = input_data + i * slab_size;
      scalar_t* output_slab = output_data + i * slab_size;

      for (const auto w : c10::irange(width)) {
        const scalar_t* input_ptr = input_slab + w * channels;

        // step 1: transpose each channel lane
        //   from: [c, s1*s2]
        //   to:   [s1*s2, c]
        utils::transpose(sub_channels, S * S, input_ptr, S * S, buffer_ptr, sub_channels);

        // step 2: copy from temp buffer to output; for a fixed s1 the
        // (s2, c) run is contiguous in both the buffer and the output.
        for (const auto s1 : c10::irange(S)) {
          const scalar_t* x_ptr = buffer_ptr + s1 * row_size;
          scalar_t* y_ptr = output_slab + s1 * width * row_size + w * row_size;

          int64_t d = 0;
          for (; d < row_size_vec; d += Vec::size()) {
            Vec data_vec = Vec::loadu(x_ptr + d);
            data_vec.store(y_ptr + d);
          }
          for (; d < row_size; d++) {
            y_ptr[d] = x_ptr[d];
          }
        }
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free