Home / Class / nn_compute_source_index_fn Class — pytorch Architecture

nn_compute_source_index_fn Class — pytorch Architecture

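Architecture documentation for `nn_compute_source_index_fn` in UpSampleNearest3d.cpp from the pytorch codebase. Despite the "class" label, the source below shows it is a non-type template parameter (of type `nn_compute_source_index_fn_t`) of the function template `upsample_nearest3d_out_frame`, which uses it to map each output index to a source index along every spatial axis.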

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/UpSampleNearest3d.cpp lines 26–81

// Nearest-neighbour 3-D upsampling over raw element storage.
//
// The index arithmetic below treats both buffers as `nbatch * channels`
// planes laid out back to back. Each plane is a dense depth x height x width
// volume (NOTE(review): assumes contiguous inputs/outputs — confirm at call
// sites).
//
// Template parameters:
//   scalar_t                   - element type; must expose `scalar_t::underlying`.
//                                Elements are copied as that raw type, not
//                                converted.
//   nn_compute_source_index_fn - maps (scale, output index, input size) to the
//                                source index along one axis.
//
// Behaviour:
//   * Per-axis scales come from compute_scales_value<float>. They use the
//     optional user-supplied scales_{d,h,w} together with the input/output sizes.
//   * Nothing is written if the plane count or any output extent is zero.
//   * If the input and output spatial shapes match, the whole buffer is copied
//     with a single memcpy.
//   * Otherwise each output voxel receives the value of its mapped source voxel,
//     repeated across every plane.
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest3d_out_frame(
    scalar_t* odata,
    scalar_t* idata,
    int64_t input_depth,
    int64_t input_height,
    int64_t input_width,
    int64_t output_depth,
    int64_t output_height,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  using raw_t = typename scalar_t::underlying;

  const float scale_d = compute_scales_value<float>(scales_d, input_depth, output_depth);
  const float scale_h = compute_scales_value<float>(scales_h, input_height, output_height);
  const float scale_w = compute_scales_value<float>(scales_w, input_width, output_width);

  // Batch and channel dimensions are folded into a single run of planes.
  const int64_t num_planes = nbatch * channels;

  const bool nothing_to_write =
      num_planes == 0 || output_depth == 0 || output_height == 0 || output_width == 0;
  if (nothing_to_write) {
    return;
  }

  const raw_t* src = reinterpret_cast<const raw_t*>(idata);
  raw_t* dst = reinterpret_cast<raw_t*>(odata);

  // Number of elements in one plane of each buffer.
  const int64_t in_plane = input_depth * input_height * input_width;
  const int64_t out_plane = output_depth * output_height * output_width;

  // With identical spatial shapes every output voxel maps to the same input
  // voxel, so the result is a verbatim copy of the input buffer.
  const bool same_shape = input_depth == output_depth &&
      input_height == output_height && input_width == output_width;
  if (same_shape) {
    std::memcpy(dst, src, num_planes * in_plane * sizeof(raw_t));
    return;
  }

  for (const auto od : c10::irange(output_depth)) {
    const int64_t id = nn_compute_source_index_fn(scale_d, od, input_depth);

    for (const auto oh : c10::irange(output_height)) {
      const int64_t ih = nn_compute_source_index_fn(scale_h, oh, input_height);

      for (const auto ow : c10::irange(output_width)) {
        const int64_t iw = nn_compute_source_index_fn(scale_w, ow, input_width);

        // Offsets of this voxel inside a single plane of each buffer.
        const int64_t src_offset = (id * input_height + ih) * input_width + iw;
        const int64_t dst_offset = (od * output_height + oh) * output_width + ow;

        // The same spatial mapping applies to every plane.
        for (const auto p : c10::irange(num_planes)) {
          dst[p * out_plane + dst_offset] = src[p * in_plane + src_offset];
        }
      }
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free