
LayerNormBackwardKernelImplInternal Function — PyTorch Architecture

Architecture documentation for the LayerNormBackwardKernelImplInternal function template in layer_norm_kernel.cpp from the PyTorch codebase. It is the CPU kernel for the backward pass of layer normalization: it computes the input gradient dX and, via a two-pass parallel reduction, the parameter gradients dgamma and dbeta.
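
For orientation, the quantities this kernel produces are the standard layer-normalization gradients; the per-row arithmetic itself lives in layer_norm_backward_frame, which is defined earlier in the same file and not shown here. With per-row mean $\mu_i$ and reciprocal standard deviation $r_i$ from the forward pass, let $\hat{x}_{ij} = (x_{ij} - \mu_i)\,r_i$ and $g_{ij} = \gamma_j\,dy_{ij}$ (taking $\gamma_j = 1$ when gamma is undefined). Then

$$d\beta_j = \sum_i dy_{ij}, \qquad d\gamma_j = \sum_i dy_{ij}\,\hat{x}_{ij},$$

$$dx_{ij} = r_i\Big(g_{ij} - \frac{1}{N}\sum_k g_{ik} - \frac{\hat{x}_{ij}}{N}\sum_k g_{ik}\,\hat{x}_{ik}\Big).$$

The $1/N$ factor in the dX expression corresponds to the scale constant computed in the code below.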

Entity Profile

Source Code

aten/src/ATen/native/cpu/layer_norm_kernel.cpp lines 510–597

template <typename T, typename T2>
void LayerNormBackwardKernelImplInternal(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t M,
    int64_t N,
    Tensor* dX,
    Tensor* dgamma,
    Tensor* dbeta) {
  using opmath_t = at::opmath_type<T>;
  TORCH_DCHECK_EQ(dY.numel(), M * N);
  TORCH_DCHECK_EQ(X.numel(), M * N);
  TORCH_DCHECK_EQ(mean.numel(), M);
  TORCH_DCHECK_EQ(rstd.numel(), M);
  DCHECK(!gamma.defined() || gamma.numel() == N);
  const T* dY_data = dY.template const_data_ptr<T>();
  const T* X_data = X.template const_data_ptr<T>();
  const T2* mean_data = mean.template const_data_ptr<T2>();
  const T2* rstd_data = rstd.template const_data_ptr<T2>();
  const T2* gamma_data =
      gamma.defined() ? gamma.template const_data_ptr<T2>() : nullptr;
  T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
  T2* const dgamma_data = dgamma->defined() ? dgamma->template data_ptr<T2>() : nullptr;
  T2* const dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T2>() : nullptr;
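  // Note (editorial): scale = 1/N is the averaging factor for the per-row
  // means in the dX formula; it is consumed inside layer_norm_backward_frame.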
  const opmath_t scale = opmath_t(1) / static_cast<opmath_t>(N);
  const bool gamma_null = gamma_data == nullptr;
  const bool dX_null = dX_data == nullptr;
  const bool dgamma_null = dgamma_data == nullptr;
  const bool dbeta_null = dbeta_data == nullptr;

  // 1. Use a two-pass parallel reduction for dgamma and dbeta:
  //    First pass: allocate an intermediate buffer of size {2, max_threads, N},
  //        dgamma_buffer = buffer[0], dbeta_buffer = buffer[1].
  //    Parallelize along dim0 and reduce dY and X along dim0 into the buffer.
  //    Second pass: parallelize along dim1 and reduce the buffer into dgamma
  //        and dbeta.
  //
  // 2. Fuse the first pass of dgamma/dbeta with dX to reuse X[i] and dY[i]
  //    while they are hot in the L1 cache.
  //
  int num_threads = at::get_num_threads();
  Tensor buffer = at::empty({0}, X.options());
  T* buffer_data = nullptr;
  if (!dgamma_null || !dbeta_null) {
    // zero the intermediate buffer; the second pass overwrites dgamma and
    // dbeta, so the outputs themselves never need zeroing
    buffer.resize_({2, num_threads, N}).zero_();
    buffer_data = buffer.template data_ptr<T>();
  }
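  // Buffer layout implied by the indexing below (editorial note): for thread
  // tid and column j, buffer_data[tid * N + j] accumulates partial dgamma and
  // buffer_data[num_threads * N + tid * N + j] accumulates partial dbeta.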

  // First pass of dgamma/dbeta, fused with dX
  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
    int tid = at::get_thread_num();
    TORCH_CHECK(
        tid < num_threads,
        "expect thread id smaller than ",
        num_threads,
        ", got thread id ",
        tid);
    T* dgamma_buffer_ptr = dgamma_null ? nullptr : buffer_data + tid * N;
    T* dbeta_buffer_ptr =
        dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
    for (const auto i : c10::irange(start, end)) {
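      // Per-row work is delegated to layer_norm_backward_frame (defined
      // earlier in this file, not shown here): it computes dX for row i and
      // accumulates this thread's dgamma/dbeta partials into the buffers.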
      layer_norm_backward_frame<T, T2, opmath_t>(
          dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data,
          dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null,
          dgamma_null, dbeta_null, N, i);
    }
  });

  // Second pass of dgamma/dbeta
  if (buffer_data != nullptr) {
    parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
      for (const auto j : c10::irange(start, end)) {
        opmath_t dgamma_v = opmath_t(0);
        opmath_t dbeta_v = opmath_t(0);
        for (const auto i : c10::irange(num_threads)) {
          dgamma_v += buffer_data[i * N + j];
          dbeta_v += buffer_data[num_threads * N + i * N + j];
        }
        if (!dgamma_null) {
          dgamma_data[j] = dgamma_v;
        }
        if (!dbeta_null) {
          dbeta_data[j] = dbeta_v;
        }
      }
    });
  }
}
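
Callers do not reach this internal kernel directly; ATen CPU kernels like this one are normally invoked through a type-dispatching wrapper. The sketch below follows the common AT_DISPATCH pattern for such wrappers. It is a plausible reconstruction, not the verbatim LayerNormBackwardKernelImpl from this file, and the reduced-precision branch (accumulating statistics in float for the T2 type) is an assumption inferred from the <T, T2> template parameters.

// Hypothetical dispatch wrapper, following the usual AT_DISPATCH pattern in
// ATen CPU kernels; the real LayerNormBackwardKernelImpl may differ.
void LayerNormBackwardKernelImpl(
    const Tensor& dY, const Tensor& X, const Tensor& mean, const Tensor& rstd,
    const Tensor& gamma, int64_t M, int64_t N,
    Tensor* dX, Tensor* dgamma, Tensor* dbeta) {
  if (at::isReducedFloatingType(X.scalar_type())) {
    // Half/BFloat16 inputs keep their statistics (mean/rstd/gamma) in float,
    // so T = scalar_t and T2 = float.
    AT_DISPATCH_REDUCED_FLOATING_TYPES(
        X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
          LayerNormBackwardKernelImplInternal<scalar_t, float>(
              dY, X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
        });
  } else {
    // Full-precision inputs use the same type for data and statistics.
    AT_DISPATCH_FLOATING_TYPES(
        X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
          LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
              dY, X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
        });
  }
}

Splitting the data type T from the statistics type T2 is what lets the kernel run reduced-precision inputs while keeping mean, rstd, and the parameter gradients in float, matching the separate opmath_t accumulation type used inside the kernel.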
