Home / Class/ layer_norm_backward_frame Class — pytorch Architecture

layer_norm_backward_frame Class — pytorch Architecture

Architecture documentation for the layer_norm_backward_frame function template in layer_norm_kernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/layer_norm_kernel.cpp lines 196–316

/**
 * Layer-norm backward contribution of a single row `i`.
 *
 * Row `i` is the contiguous span [i * N, (i + 1) * N) of dY_data, X_data and
 * (when written) dX_data. mean_data[i] and rstd_data[i] are read as that
 * row's statistics. gamma_data, when present, supplies N per-column weights.
 *
 * Work performed; each step is skipped when its *_null flag is set:
 *   - dgamma_buffer_ptr[j] += dY[j] * (rstd * X[j] - rstd * mean)
 *     (accumulated into the existing buffer contents, not overwritten)
 *   - dbeta_buffer_ptr[j]  += dY[j]
 *     (accumulated into the existing buffer contents, not overwritten)
 *   - dX row is overwritten with
 *       rstd * dY[j] * gamma[j] + x_coeff * X[j] + bias_term
 *     where gamma[j] is taken as 1 when gamma_null is set.
 *
 * `scale` multiplies the row reductions when forming x_coeff / bias_term.
 * NOTE(review): presumably scale == 1 / N; confirm at the call site.
 */
template <typename T, typename T2, typename opmath_t>
void layer_norm_backward_frame(
    const T* dY_data,
    const T* X_data,
    const T2* mean_data,
    const T2* rstd_data,
    const T2* gamma_data,
    T* dX_data,
    T* dgamma_buffer_ptr,
    T* dbeta_buffer_ptr,
    // NOTE: the @lint-ignore below is needed only because specializations of
    // this function are compiled for c10::complex. Layer norm over complex
    // tensors is almost certainly unused in practice, and c10::complex has a
    // poor memory layout for kernels like this, so ideally we would stop
    // compiling c10::complex specializations of compute kernels altogether.
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    const opmath_t scale,
    const bool gamma_null,
    const bool dX_null,
    const bool dgamma_null,
    const bool dbeta_null,
    int64_t N,
    int64_t i) {
  using Vec = vec::Vectorized<opmath_t>;
  const int64_t row_offset = i * N;
  const T* const dy_row = dY_data + row_offset;
  const T* const x_row = X_data + row_offset;

  // dgamma[j] += dy[j] * (inv_std * x[j] + shift), with shift = -inv_std * mean.
  if (!dgamma_null) {
    const opmath_t inv_std = rstd_data[i];
    const opmath_t shift = -inv_std * mean_data[i];
    vec::map3<T>(
        [inv_std, shift](Vec acc, Vec grad_out, Vec input) {
          return acc + grad_out * (Vec(inv_std) * input + Vec(shift));
        },
        dgamma_buffer_ptr,
        dgamma_buffer_ptr,
        dy_row,
        x_row,
        N);
  }

  // dbeta[j] += dy[j].
  if (!dbeta_null) {
    vec::map2<T>(
        [](Vec acc, Vec grad_out) { return acc + grad_out; },
        dbeta_buffer_ptr,
        dbeta_buffer_ptr,
        dy_row,
        N);
  }

  // Everything below only produces dX.
  if (dX_null) {
    return;
  }

  const auto mul = [](Vec lhs, Vec rhs) { return lhs * rhs; };
  const auto add = [](Vec lhs, Vec rhs) { return lhs + rhs; };

  // Row reductions, with gamma treated as 1 when absent:
  //   sum_dy_x_gamma = sum_j dy[j] * x[j] * gamma[j]
  //   sum_dy_gamma   = sum_j dy[j] * gamma[j]
  opmath_t sum_dy_x_gamma{};
  opmath_t sum_dy_gamma{};
  if (!gamma_null) {
    sum_dy_x_gamma = vec::map3_reduce_all<T>(
        [](Vec grad_out, Vec input, Vec weight) {
          return grad_out * input * weight;
        },
        add,
        dy_row,
        x_row,
        gamma_data,
        N);
    sum_dy_gamma =
        vec::map2_reduce_all<T>(mul, add, dy_row, gamma_data, N);
  } else {
    sum_dy_x_gamma = vec::map2_reduce_all<T>(mul, add, dy_row, x_row, N);
    sum_dy_gamma = vec::reduce_all<T>(add, dy_row, N);
  }

  // dX[j] = inv_std * dy[j] * gamma[j] + x_coeff * x[j] + bias_term.
  // The coefficients come from differentiating (x - mean) * rstd through the
  // row statistics. The evaluation order below is intentional; keep it stable.
  const opmath_t inv_std = rstd_data[i];
  const opmath_t mean = opmath_t(mean_data[i]);
  const opmath_t x_coeff = (sum_dy_gamma * mean - sum_dy_x_gamma) * inv_std *
      inv_std * inv_std * scale;
  const opmath_t bias_term = -x_coeff * mean - sum_dy_gamma * inv_std * scale;

  T* const dx_row = dX_data + row_offset;
  if (!gamma_null) {
    vec::map3<T>(
        [inv_std, x_coeff, bias_term](Vec grad_out, Vec weight, Vec input) {
          return Vec(inv_std) * grad_out * weight + Vec(x_coeff) * input +
              Vec(bias_term);
        },
        dx_row,
        dy_row,
        gamma_data,
        x_row,
        N);
  } else {
    vec::map2<T>(
        [inv_std, x_coeff, bias_term](Vec grad_out, Vec input) {
          return Vec(inv_std) * grad_out + Vec(x_coeff) * input +
              Vec(bias_term);
        },
        dx_row,
        dy_row,
        x_row,
        N);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free