LayerNormBackwardKernelImplInternal Function — PyTorch Architecture
Architecture documentation for the LayerNormBackwardKernelImplInternal function template in layer_norm_kernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/layer_norm_kernel.cpp lines 510–597
template <typename T, typename T2>
void LayerNormBackwardKernelImplInternal(
const Tensor& dY,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
const Tensor& gamma,
int64_t M,
int64_t N,
Tensor* dX,
Tensor* dgamma,
Tensor* dbeta) {
using opmath_t = at::opmath_type<T>;
TORCH_DCHECK_EQ(dY.numel(), M * N);
TORCH_DCHECK_EQ(X.numel(), M * N);
TORCH_DCHECK_EQ(mean.numel(), M);
TORCH_DCHECK_EQ(rstd.numel(), M);
DCHECK(!gamma.defined() || gamma.numel() == N);
const T* dY_data = dY.template const_data_ptr<T>();
const T* X_data = X.template const_data_ptr<T>();
const T2* mean_data = mean.template const_data_ptr<T2>();
const T2* rstd_data = rstd.template const_data_ptr<T2>();
const T2* gamma_data =
gamma.defined() ? gamma.template const_data_ptr<T2>() : nullptr;
T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
T2* const dgamma_data = dgamma->defined() ? dgamma->template data_ptr<T2>() : nullptr;
T2* const dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T2>() : nullptr;
const opmath_t scale = opmath_t(1) / static_cast<opmath_t>(N);
const bool gamma_null = gamma_data == nullptr;
const bool dX_null = dX_data == nullptr;
const bool dgamma_null = dgamma_data == nullptr;
const bool dbeta_null = dbeta_data == nullptr;
// 1. Use two path parallel reduction for dgamma and dbeta:
// First path: allocate an immediate buffer of size {2, max_threads, N},
// dgamma_buffer = buffer[0], dbeta_buffer = buffer[1]
// Parallel along dim0 and reduce dY and X along dim0 to buffer.
// Second path: parallel along dim1 and reduce buffer to dgamma and dbeta.
//
// 2. Fuse first path of dgamma/dbeta with dX to reuse X[i] and dY[i] in L1
// cache.
//
int num_threads = at::get_num_threads();
Tensor buffer = at::empty({0}, X.options());
T* buffer_data = nullptr;
if (!dgamma_null || !dbeta_null) {
// zero the immediate buffer and skip zero dgamma and dbeta
buffer.resize_({2, num_threads, N}).zero_();
buffer_data = buffer.template data_ptr<T>();
}
// First path of dgamma/dbeta and dX
at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
int tid = at::get_thread_num();
TORCH_CHECK(
tid < num_threads,
"expect thread id smaller than ",
num_threads,
", got thread id ",
tid);
T* dgamma_buffer_ptr = dgamma_null ? nullptr : buffer_data + tid * N;
T* dbeta_buffer_ptr =
dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
for (const auto i : c10::irange(start, end)) {
layer_norm_backward_frame<T, T2, opmath_t>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
}
});
// Second path of dgamma/dbeta
if (buffer_data != nullptr) {
parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
for (const auto j : c10::irange(start, end)) {
opmath_t dgamma_v = opmath_t(0);
opmath_t dbeta_v = opmath_t(0);
for (const auto i : c10::irange(num_threads)) {
dgamma_v += buffer_data[i * N + j];
dbeta_v += buffer_data[num_threads * N + i * N + j];
}
if (!dgamma_null) {
dgamma_data[j] = dgamma_v;
}
if (!dbeta_null) {
dbeta_data[j] = dbeta_v;
}
}
});
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free