layer_norm_kernel_mixed_type Function Template — pytorch Architecture
Architecture documentation for the layer_norm_kernel_mixed_type function template in layer_norm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/layer_norm_kernel.cpp lines 97–154
// Row-wise normalization kernel, enabled only when T satisfies
// is_reduced_floating_point_v<T>. Arithmetic is carried out in float while the
// input/output elements are stored as T and the affine parameters / statistics
// are stored as param_t.
//
// For each row i in [0, M), with row i addressed at offset i * N in X and Y:
//   (m, v)  = RowwiseMoments(row, N)
//   r       = 1 / sqrt(v + eps)
//   Y[i][d] = (X[i][d] * r + (-r * m)) * gamma[d] + beta[d]
// NOTE(review): v is presumably the row variance; confirm against
// RowwiseMoments, which is defined outside this block.
//
// Parameters:
//   X      input tensor, read through const_data_ptr<T>().
//   gamma  optional scale; when undefined a factor of 1 is used.
//   beta   optional shift; when undefined an offset of 0 is used.
//   M, N   number of rows and number of elements per row.
//   eps    value added to v before the reciprocal square root.
//   Y      output tensor, written through data_ptr<T>(); dereferenced
//          unconditionally, so it must not be null.
//   mean   optional output; if non-null, element i receives m.
//   rstd   optional output; if non-null, element i receives r.
//
// Rows are distributed with at::parallel_for over [0, M) using a grain size
// of 1.
template <typename T, typename param_t,
    typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
void layer_norm_kernel_mixed_type(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    float eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  using bVec = Vectorized<T>;
  using fVec = Vectorized<float>;

  // Raw data pointers, acquired in the same order as before. Optional tensors
  // map to nullptr so that presence can be tested with a pointer comparison.
  const T* const x_base = X.const_data_ptr<T>();
  const param_t* const gamma_ptr =
      gamma.defined() ? gamma.const_data_ptr<param_t>() : nullptr;
  const param_t* const beta_ptr =
      beta.defined() ? beta.const_data_ptr<param_t>() : nullptr;
  T* const y_base = Y->data_ptr<T>();
  param_t* const mean_out = mean ? mean->data_ptr<param_t>() : nullptr;
  param_t* const rstd_out = rstd ? rstd->data_ptr<param_t>() : nullptr;

  const bool has_gamma = gamma_ptr != nullptr;
  const bool has_beta = beta_ptr != nullptr;
  const bool has_mean_out = mean_out != nullptr;
  const bool has_rstd_out = rstd_out != nullptr;

  // Number of T elements consumed per vector step, and the index at which the
  // vectorized part of a row ends (largest multiple of `lanes` not above N).
  const int64_t lanes = bVec::size();
  const int64_t vec_end = N - (N % lanes);

  at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
    for (const auto row : c10::irange(start, end)) {
      const T* const x_row = x_base + row * N;
      T* const y_row = y_base + row * N;

      // The second moment value is overwritten in place with
      // 1 / sqrt(value + eps), which is what gets applied and stored below.
      auto [row_mean, row_rstd] = RowwiseMoments(x_row, N);
      row_rstd = float(1) / std::sqrt(row_rstd + eps);

      // Folding the mean into an additive term lets each element be
      // normalized as x * scale + bias.
      const float scale = row_rstd;
      const float bias = -row_rstd * row_mean;
      const fVec scale_vec(scale);
      const fVec bias_vec(bias);

      // Normalize one float vector and apply the affine parameters.
      const auto affine = [&](const fVec& x, const fVec& g, const fVec& b) {
        return (x * scale_vec + bias_vec) * g + b;
      };

      int64_t d = 0;

      // Vectorized part: one vector of T expands into two float vectors,
      // which are processed separately and packed back into one vector of T.
      for (; d < vec_end; d += lanes) {
        bVec x_packed = bVec::loadu(x_row + d);
        auto [xf_a, xf_b] = convert_to_float<T>(x_packed);

        // Identity scale / zero shift unless the parameter tensor is present.
        fVec gamma_a(1), gamma_b(1);
        if (has_gamma) {
          std::tie(gamma_a, gamma_b) = load2f(gamma_ptr + d);
        }
        fVec beta_a(0), beta_b(0);
        if (has_beta) {
          std::tie(beta_a, beta_b) = load2f(beta_ptr + d);
        }

        fVec yf_a = affine(xf_a, gamma_a, beta_a);
        fVec yf_b = affine(xf_b, gamma_b, beta_b);
        bVec y_packed = convert_from_float<T>(yf_a, yf_b);
        y_packed.store(y_row + d);
      }

      // Scalar tail for the remaining N % lanes elements.
      for (; d < N; ++d) {
        const float gamma_v =
            has_gamma ? static_cast<float>(gamma_ptr[d]) : float(1);
        const float beta_v =
            has_beta ? static_cast<float>(beta_ptr[d]) : float(0);
        y_row[d] =
            (static_cast<float>(x_row[d]) * scale + bias) * gamma_v + beta_v;
      }

      // Optional per-row statistics.
      if (has_mean_out) {
        mean_out[row] = row_mean;
      }
      if (has_rstd_out) {
        rstd_out[row] = row_rstd;
      }
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free