GammaBackward Function — PyTorch Architecture
Architecture documentation for the GammaBackward function template in group_norm_kernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 708–753
template <typename PT, typename opmath_t>
std::enable_if_t<std::is_same_v<PT, opmath_t>, void>
GammaBackward(
int64_t N,
int64_t C,
int64_t group,
const PT* mean,
const PT* rstd,
const opmath_t* ds,
const opmath_t* db,
PT* dgamma) {
const int64_t G = group;
const int64_t D = C / G;
constexpr int64_t K = at::vec::Vectorized<PT>::size();
using Vec = at::vec::Vectorized<PT>;
const int64_t inner_size = D / K * K;
for (const auto g : c10::irange(G)) {
int64_t i = 0;
for (; i < inner_size; i += K) {
Vec acc_vec{0};
for (const auto n : c10::irange(N)) {
const PT* ds_ptr = ds + n * C + g * D + i;
const PT* db_ptr = db + n * C + g * D + i;
auto ds_vec = Vec::loadu(ds_ptr);
auto db_vec = Vec::loadu(db_ptr);
auto mean_vec = Vec(mean[n * G + g]);
auto rstd_vec = Vec(rstd[n * G + g]);
acc_vec += (ds_vec - db_vec * mean_vec) * rstd_vec;
}
acc_vec.store(dgamma + g * D + i);
}
if (D - i > 0) {
Vec acc_vec{0};
for (const auto n : c10::irange(N)) {
const PT* ds_ptr = ds + n * C + g * D + i;
const PT* db_ptr = db + n * C + g * D + i;
auto ds_vec = Vec::loadu(ds_ptr, D - i);
auto db_vec = Vec::loadu(db_ptr, D - i);
auto mean_vec = Vec(mean[n * G + g]);
auto rstd_vec = Vec(rstd[n * G + g]);
acc_vec += (ds_vec - db_vec * mean_vec) * rstd_vec;
}
acc_vec.store(dgamma + g * D + i, D - i);
}
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free