layer_norm_backward_frame Function — PyTorch Architecture
Architecture documentation for the layer_norm_backward_frame function template in layer_norm_kernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/layer_norm_kernel.cpp lines 196–316
// Accumulates one row's contribution to the LayerNorm backward pass.
//
// Processes row `i` of the contiguous [*, N] inputs dY and X and:
//   - accumulates dbeta_buffer[j]  += dY[j]
//   - accumulates dgamma_buffer[j] += dY[j] * (X[j] - mean) * rstd
//   - writes dX[i, :] from the closed-form layer-norm input gradient
// Each of the three outputs is individually skippable via its *_null flag.
//
// T is the storage type, T2 the type of the saved mean/rstd statistics,
// and opmath_t the (usually wider) type used for the arithmetic.
// `scale` is a caller-supplied reduction factor applied when forming the
// dX coefficients (presumably 1/N to average the row reductions — see the
// scalar-math comments; confirm against the caller).
template <typename T, typename T2, typename opmath_t>
void layer_norm_backward_frame(
    const T* dY_data,
    const T* X_data,
    const T2* mean_data,
    const T2* rstd_data,
    const T2* gamma_data,
    T* dX_data,
    T* dgamma_buffer_ptr,
    T* dbeta_buffer_ptr,
    // NOTE: the below @lint-ignore is only necessary because we compile
    // specializations of this function for c10::complex.
    // It's extremely likely that nobody actually takes layer norms of
    // complex tensors, and even if they are, c10::complex is laid out poorly
    // and basically should never be used.
    // So it would be nice in the future to figure out how to stop compiling
    // specializations of compute kernels for c10::complex.
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    const opmath_t scale,
    const bool gamma_null,
    const bool dX_null,
    const bool dgamma_null,
    const bool dbeta_null,
    int64_t N,
    int64_t i) {
  using Vec = vec::Vectorized<opmath_t>;
  const T* dy_row = dY_data + i * N;
  const T* x_row = X_data + i * N;

  // Bias gradient: dbeta[j] += dY[j].
  // The buffer is passed as both destination and first input so the map
  // accumulates in place.
  if (!dbeta_null) {
    vec::map2<T>(
        [](Vec acc, Vec dy) { return acc + dy; },
        dbeta_buffer_ptr,
        dbeta_buffer_ptr,
        dy_row,
        N);
  }

  // Weight gradient: dgamma[j] += dY[j] * (X[j] - mean) * rstd.
  // Folded into the affine form dY[j] * (a * X[j] + b) with
  // a = rstd and b = -rstd * mean, so only two broadcast constants are
  // needed inside the vectorized loop.
  if (!dgamma_null) {
    const opmath_t a = rstd_data[i];
    const opmath_t b = -a * mean_data[i];
    vec::map3<T>(
        [a, b](Vec acc, Vec dy, Vec x) {
          return acc + dy * (Vec(a) * x + Vec(b));
        },
        dgamma_buffer_ptr,
        dgamma_buffer_ptr,
        dy_row,
        x_row,
        N);
  }

  if (dX_null) {
    return;
  }

  // Input gradient. Step 1: reduce the row to two scalars,
  //   ds = sum_j dY[j] * X[j] * gamma[j]
  //   db = sum_j dY[j] * gamma[j]
  // where gamma[j] is taken as 1 when gamma is null.
  T* dx_row = dX_data + i * N;
  opmath_t ds = opmath_t(0);
  opmath_t db = opmath_t(0);
  if (gamma_null) {
    ds = vec::map2_reduce_all<T>(
        [](Vec dy, Vec x) { return dy * x; },
        [](Vec u, Vec v) { return u + v; },
        dy_row,
        x_row,
        N);
    db = vec::reduce_all<T>(
        [](Vec u, Vec v) { return u + v; }, dy_row, N);
  } else {
    ds = vec::map3_reduce_all<T>(
        [](Vec dy, Vec x, Vec g) { return dy * x * g; },
        [](Vec u, Vec v) { return u + v; },
        dy_row,
        x_row,
        gamma_data,
        N);
    db = vec::map2_reduce_all<T>(
        [](Vec dy, Vec g) { return dy * g; },
        [](Vec u, Vec v) { return u + v; },
        dy_row,
        gamma_data,
        N);
  }

  // Step 2: form dX[j] = a * dY[j] * gamma[j] + b * X[j] + c with
  //   a = rstd
  //   b = (db * mean - ds) * rstd^3 * scale
  //   c = -b * mean - db * rstd * scale
  const opmath_t a = rstd_data[i];
  const opmath_t b = (db * opmath_t(mean_data[i]) - ds) * a * a * a * scale;
  const opmath_t c = -b * opmath_t(mean_data[i]) - db * a * scale;
  if (gamma_null) {
    vec::map2<T>(
        [a, b, c](Vec dy, Vec x) {
          return Vec(a) * dy + Vec(b) * x + Vec(c);
        },
        dx_row,
        dy_row,
        x_row,
        N);
  } else {
    vec::map3<T>(
        [a, b, c](Vec dy, Vec g, Vec x) {
          return Vec(a) * dy * g + Vec(b) * x + Vec(c);
        },
        dx_row,
        dy_row,
        gamma_data,
        x_row,
        N);
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free