Home / Class / Unfold3dCopyKernelImpl Class — pytorch Architecture

Unfold3dCopyKernelImpl Class — pytorch Architecture

Architecture documentation for Unfold3dCopyKernelImpl, a function template (catalogued here as a class) in Unfold3d.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/Unfold3d.cpp lines 223–300

// Fills `dst` with sliding-window copies of `src` for a 3-D unfold.
//
// `dst` is written as C * kernel_d * kernel_h * kernel_w consecutive rows of
// Y_D * Y_H * Y_W elements. Row index p decodes, fastest-varying first, into
// the kernel offsets (kw, kh, kd) and a channel index c. `src` is read as one
// block of X_D * X_H * X_W elements per channel.
//
// For each output position (yd, yh, yw) the source coordinate along an axis
// is `y * stride - pad + k`. Positions whose source coordinate falls outside
// [0, X) on any axis are written as zero.
//
// When all three pads are zero the work is forwarded unchanged to
// Unfold3dZeroPaddingCopyKernelImpl<T>.
template <typename T>
void Unfold3dCopyKernelImpl(
    int64_t C,
    int64_t X_D,
    int64_t X_H,
    int64_t X_W,
    int64_t Y_D,
    int64_t Y_H,
    int64_t Y_W,
    int64_t kernel_d,
    int64_t kernel_h,
    int64_t kernel_w,
    int64_t stride_d,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_d,
    int64_t pad_h,
    int64_t pad_w,
    const T* src,
    T* dst) {
  const bool has_padding = pad_d != 0 || pad_h != 0 || pad_w != 0;
  if (!has_padding) {
    // No padding: every source coordinate is handled by the dedicated kernel.
    Unfold3dZeroPaddingCopyKernelImpl<T>(
        C,
        X_D,
        X_H,
        X_W,
        Y_D,
        Y_H,
        Y_W,
        kernel_d,
        kernel_h,
        kernel_w,
        stride_d,
        stride_h,
        stride_w,
        src,
        dst);
    return;
  }

  // Element counts of one H*W plane and one full D*H*W volume.
  const int64_t X_plane = X_H * X_W;
  const int64_t Y_plane = Y_H * Y_W;
  const int64_t X_volume = X_D * X_plane;
  const int64_t Y_volume = Y_D * Y_plane;
  const int64_t num_rows = C * kernel_d * kernel_h * kernel_w;

  at::parallel_for(0, num_rows, 0, [=](int64_t row_begin, int64_t row_end) {
    for (const auto row : c10::irange(row_begin, row_end)) {
      // Peel the flat row index apart: kw is innermost, channel outermost.
      const int64_t kw = row % kernel_w;
      const int64_t row_over_w = row / kernel_w;
      const int64_t kh = row_over_w % kernel_h;
      const int64_t row_over_hw = row_over_w / kernel_h;
      const int64_t kd = row_over_hw % kernel_d;
      const int64_t channel = row_over_hw / kernel_d;

      const T* const src_channel = src + channel * X_volume;
      T* const dst_row = dst + row * Y_volume;

      for (const auto yd : c10::irange(Y_D)) {
        T* const dst_plane = dst_row + yd * Y_plane;
        const int64_t xd = yd * stride_d - pad_d + kd;
        if (!IsAGeZeroAndALtB(xd, X_D)) {
          // Depth is out of range: the whole output plane is zero bytes.
          std::memset(dst_plane, 0, Y_plane * sizeof(T));
          continue;
        }

        const T* const src_plane = src_channel + xd * X_plane;
        for (const auto yh : c10::irange(Y_H)) {
          T* const dst_line = dst_plane + yh * Y_W;
          const int64_t xh = yh * stride_h - pad_h + kh;
          if (!IsAGeZeroAndALtB(xh, X_H)) {
            // Height is out of range: the whole output line is zero bytes.
            std::memset(dst_line, 0, Y_W * sizeof(T));
            continue;
          }

          const T* const src_line = src_plane + xh * X_W;
          for (const auto yw : c10::irange(Y_W)) {
            const int64_t xw = yw * stride_w - pad_w + kw;
            if (IsAGeZeroAndALtB(xw, X_W)) {
              dst_line[yw] = src_line[xw];
            } else {
              dst_line[yw] = T(0);
            }
          }
        }
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free