Home / Class / _unfold_backward_internal_kernel Class — pytorch Architecture

_unfold_backward_internal_kernel Class — pytorch Architecture

Architecture documentation for the _unfold_backward_internal_kernel function template in UnfoldBackwardKernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp lines 60–109

// Accumulates unfold() gradients back onto the un-unfolded tensor.
//
// Along the unfolded dimension, fold (window) number f covers the positions
// [f * step, f * step + size). For each element visited by `iter`, the kernel
// reads that element's position `pos` along the dimension. It finds every
// fold f containing `pos`, capped at fold grad_in_dim_size - 1. For each such
// fold it adds the grad_in entry at (fold f, offset pos - f * step) to the
// element's grad_out value.
//
// Operands of `iter`, in order:
//   0: grad_out element (scalar_t); only ever updated with +=
//   1: pointer into grad_in for this element (scalar_t); folds and in-fold
//      offsets are reached from it via the two grad_in strides below
//   2: pos (int64_t), the element's index along the unfolded dimension
//
// Parameters:
//   size                    - number of positions covered by one fold
//   step                    - distance between the starts of consecutive folds
//   grad_in_dim_stride      - grad_in element stride between consecutive folds
//   grad_in_last_dim_stride - grad_in element stride between offsets in a fold
//   grad_in_dim_size        - fold count; fold indices are capped at this - 1
//   grad_out_dim_stride     - not read by this kernel
//
// NOTE(review): grad_out is accumulated into, never assigned, so it presumably
// arrives zero-initialized from the caller — confirm. `step` is used as a
// divisor and is assumed to be positive.
template <typename scalar_t>
void _unfold_backward_internal_kernel(
  TensorIterator& iter,
  int64_t size,
  int64_t step,
  int64_t grad_in_dim_stride,
  int64_t grad_in_last_dim_stride,
  int64_t grad_in_dim_size,
  int64_t grad_out_dim_stride
) {
  // An empty iteration space has nothing to accumulate.
  if (iter.numel() == 0) {
    return;
  }

  iter.for_each([&](char** data, const int64_t* strides, int64_t nelems) {
    for (const auto i : c10::irange(nelems)) {
      // Address every operand directly from its base pointer and stride.
      auto* RESTRICT out = reinterpret_cast<scalar_t*>(data[0] + i * strides[0]);
      auto* RESTRICT in = reinterpret_cast<scalar_t*>(data[1] + i * strides[1]);
      const auto pos = *reinterpret_cast<int64_t*>(data[2] + i * strides[2]);

      // The earliest fold that can contain `pos` is either
      // (pos - size) / step (0 when pos <= size) or the one right after it.
      int64_t first = 0;
      if (pos > size) {
        first = (pos - size) / step;
      }
      const int64_t first_start = first * step;
      const bool first_covers_pos = first_start <= pos && pos < first_start + size;
      if (!first_covers_pos) {
        first += 1;
      }

      // The latest fold starting at or before `pos`, capped at the last fold.
      int64_t last = pos / step;
      if (last >= grad_in_dim_size) {
        last = grad_in_dim_size - 1;
      }

      // When step > size, `pos` may sit in a gap between folds. In that case
      // first > last and nothing is added.
      for (int64_t fold = first; fold <= last; ++fold) {
        const int64_t offset_in_fold = pos - fold * step;
        *out += in[fold * grad_in_dim_stride
                   + offset_in_fold * grad_in_last_dim_stride];
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free