Home / Class / nearest_idx_fn Class — pytorch Architecture

nearest_idx_fn Class — pytorch Architecture

Architecture documentation for the nearest_idx_fn class in UpSampleMoreKernel.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/UpSampleMoreKernel.cpp lines 91–222

// Backward pass of nearest-neighbor upsampling on CPU for 1d/2d/3d inputs
// (ndim == 3, 4, or 5). For every output-gradient element, the nearest source
// index is recomputed via `nearest_idx_fn` and the gradient is scatter-added
// back into `grad_input_`. When scalar_t is a reduced-precision type
// (opmath_type<scalar_t> != scalar_t, e.g. presumably Half/BFloat16), the
// accumulation happens in a per-slice opmath_t buffer and is flushed to
// grad_input via apply_grad_input.
//
// Template parameters:
//   scalar_t       - element type of both tensors (checked to match below).
//   scale_type     - container of optional scale factors, indexed per spatial dim.
//   nearest_idx_fn - compile-time index mapper: (out_idx, in_size, out_size,
//                    scale) -> in_idx (exact rounding rule defined elsewhere).
//
// NOTE(review): the non-buffered path accumulates directly into grad_input, so
// this assumes grad_input_ arrives zero-initialized — confirm at the call site.
template <typename scalar_t, typename scale_type, nearest_idx_fn_t nearest_idx_fn>
void cpu_upsample_nearest_backward(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    const scale_type& scales) {
  TORCH_CHECK(grad_input_.dtype() == grad_output_.dtype(), "expected dtype ", grad_output_.dtype(),
              " for `grad_input` but got dtype ", grad_input_.dtype());

  // Work on contiguous views; if grad_input_ is non-contiguous this produces a
  // separate contiguous tensor, and the results are copied back at the end.
  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.const_data_ptr<scalar_t>();
  auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();
  auto ndim = input_sizes.size();

  // treat nbatch and channels as one dimension
  int64_t channels = input_sizes[0] * input_sizes[1];
  // Missing spatial dims collapse to size 1 so the offset math below is uniform.
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];

  // Number of elements in one (batch*channel) slice of output/input.
  int64_t output_slice_size = output_depth * output_height * output_width;
  int64_t input_slice_size = input_depth * input_height * input_width;

  // opmath_t is the accumulation type (e.g. float for reduced-precision
  // scalar_t; same as scalar_t otherwise).
  using opmath_t = at::opmath_type<scalar_t>;

  // 1d kernel: each task handles channels [begin, end).
  auto loop1d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
      // Reduced precision: accumulate one slice at a time into a zeroed
      // opmath_t scratch buffer, then flush to grad_input per channel.
      buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
      acc_data_ptr = buffer_data.get();
      memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      // Full precision: scalar_t == opmath_t, so the cast is an identity and
      // we accumulate straight into grad_input.
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto c : c10::irange(begin, end)) {
      // Buffered path indexes the scratch buffer from 0; direct path offsets
      // into the full grad_input tensor.
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto ow : c10::irange(output_width)) {
        int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[0]);
        int64_t output_offset = c * output_slice_size + ow;
        acc_data_ptr[input_offset + iw] += grad_output_data[output_offset];
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        // NOTE(review): assumes apply_grad_input also re-zeroes the scratch
        // buffer so it can be reused for the next channel — confirm in its
        // definition (not visible here).
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  // 2d kernel: same structure as loop1d with an extra height loop;
  // scales[0] maps height, scales[1] maps width.
  auto loop2d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
        acc_data_ptr = buffer_data.get();
        memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto oh : c10::irange(output_height)) {
        // ih is loop-invariant w.r.t. ow, so it is hoisted out of the inner loop.
        int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[0]);
        for (const auto ow : c10::irange(output_width)) {
          int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[1]);
          int64_t output_offset = c * output_slice_size + oh * output_width + ow;
          acc_data_ptr[input_offset + ih * input_width + iw] += grad_output_data[output_offset];
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  // 3d kernel: depth/height/width loops; scales[0..2] map depth, height, width.
  auto loop3d = [&](int64_t begin, int64_t end) {
    opmath_t* acc_data_ptr = nullptr;
    std::unique_ptr<opmath_t[]> buffer_data;
    if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        buffer_data = std::make_unique<opmath_t[]>(input_slice_size);
        acc_data_ptr = buffer_data.get();
        memset(acc_data_ptr, 0, sizeof(opmath_t) * input_slice_size);
    } else {
      acc_data_ptr = reinterpret_cast<opmath_t*>(grad_input_data);
    }

    for (const auto c : c10::irange(begin, end)) {
      int64_t input_offset = buffer_data.get() == nullptr ? c * input_slice_size : 0;
      for (const auto od : c10::irange(output_depth)) {
        int64_t id = nearest_idx_fn(od, input_depth, output_depth, scales[0]);
        for (const auto oh : c10::irange(output_height)) {
          int64_t ih = nearest_idx_fn(oh, input_height, output_height, scales[1]);
          for (const auto ow : c10::irange(output_width)) {
            int64_t iw = nearest_idx_fn(ow, input_width, output_width, scales[2]);
            int64_t output_offset = c * output_slice_size +
                od *  output_height * output_width + oh * output_width + ow;
            acc_data_ptr[input_offset + id * input_height * input_width + ih * input_width + iw] +=
              grad_output_data[output_offset];
          }
        }
      }
      if constexpr (!std::is_same_v<scalar_t, opmath_t>) {
        auto gin = grad_input_data + c * input_slice_size;
        apply_grad_input(acc_data_ptr, gin, input_slice_size);
      }
    }
  };

  // Parallelize over the fused batch*channel dimension. Channels never share
  // an input slice, so the per-channel tasks write to disjoint regions.
  // NOTE(review): GRAIN_SIZE / output_slice_size divides by zero if the output
  // is empty — presumably callers reject empty tensors upstream; confirm.
  if (ndim == 3) {
    // upsample nearest 1d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size, loop1d);
  } else if (ndim == 4) {
    // upsample nearest 2d
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size , loop2d);
  } else {
    // upsample nearest 3d
    TORCH_INTERNAL_ASSERT(ndim == 5);
    at::parallel_for(0, channels, at::internal::GRAIN_SIZE / output_slice_size, loop3d);
  }

  // If we accumulated into a temporary contiguous tensor, copy the results
  // back into the caller's (non-contiguous) grad_input_.
  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free