GroupNormKernelImplInternal Function Template — PyTorch Architecture
Architecture documentation for the GroupNormKernelImplInternal function template in group_norm_kernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 28–85
template <typename T, typename PT>
void GroupNormKernelImplInternal(
const Tensor& X,
const Tensor& gamma,
const Tensor& beta,
int64_t N,
int64_t C,
int64_t HxW,
int64_t group,
double eps,
Tensor& Y,
Tensor& mean,
Tensor& rstd) {
TORCH_CHECK(X.numel() == N * C * HxW);
TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
TORCH_CHECK(!beta.defined() || beta.numel() == C);
const int64_t G = group;
const int64_t D = C / G;
const T* X_data = X.const_data_ptr<T>();
const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
T* Y_data = Y.data_ptr<T>();
PT* mean_data = mean.data_ptr<PT>();
PT* rstd_data = rstd.data_ptr<PT>();
const bool gamma_null = (gamma_data == nullptr);
const bool beta_null = beta_data == nullptr;
const int64_t inner_size = D * HxW;
using opmath_t = at::opmath_type<T>;
at::parallel_for(0, N * G, 1, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
const T* X_ptr = X_data + i * inner_size;
auto [mean_val, rstd_val] = RowwiseMoments(X_ptr, inner_size);
rstd_val = opmath_t(1) / std::sqrt(std::max(rstd_val, opmath_t(0)) + eps);
if (gamma_null && beta_null) {
T* Y_ptr = Y_data + i * inner_size;
for (const auto j : c10::irange(inner_size)) {
Y_ptr[j] = (X_ptr[j] - mean_val) * rstd_val;
}
} else {
const int64_t g = i % G;
for (const auto j : c10::irange(D)) {
const int64_t c = g * D + j;
const opmath_t scale = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
const opmath_t bias = -scale * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
X_ptr = X_data + (i * D + j) * HxW;
T* Y_ptr = Y_data + (i * D + j) * HxW;
for (const auto k : c10::irange(HxW)) {
Y_ptr[k] = scale * X_ptr[k] + bias;
}
}
}
mean_data[i] = mean_val;
rstd_data[i] = rstd_val;
}
});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free