GroupNormInputBackward Class — pytorch Architecture
Architecture documentation for the GroupNormInputBackward class in group_norm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 652–706
template <typename T, typename PT, typename opmath_t>
void GroupNormInputBackward(
int64_t N,
int64_t C,
int64_t HxW,
int64_t group,
const T* dY,
const T* X,
const PT* mean,
const PT* rstd,
const PT* gamma,
const opmath_t* ds,
const opmath_t* db,
T* dX) {
const int64_t G = group;
const int64_t D = C / G;
const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
const bool gamma_null = (gamma == nullptr);
at::parallel_for(0, N * G, 1, [=](int64_t start, int64_t end) {
constexpr int64_t K = vec::Vectorized<PT>::size();
const int64_t d = D / K * K;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> ds_arr;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> db_arr;
for (const auto i : c10::irange(start, end)) {
const int64_t g = i % G;
const opmath_t* ds_ptr = ds + i * D;
const opmath_t* db_ptr = db + i * D;
const PT* gamma_ptr = gamma_null ? nullptr : (gamma + g * D);
CalcDsDb(ds_ptr, db_ptr, gamma_ptr, d, K, ds_arr.data(), db_arr.data());
opmath_t ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), opmath_t(0));
opmath_t db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), opmath_t(0));
for (const auto j : c10::irange(d, D)) {
const opmath_t gamma_v = gamma_null ? opmath_t(1) : opmath_t(gamma[g * D + j]);
ds_val += ds_ptr[j] * gamma_v;
db_val += db_ptr[j] * gamma_v;
}
const opmath_t c2 =
(db_val * opmath_t(mean[i]) - ds_val) * opmath_t(rstd[i]) * opmath_t(rstd[i]) * opmath_t(rstd[i]) * s;
const opmath_t c3 = -c2 * opmath_t(mean[i]) - db_val * opmath_t(rstd[i]) * s;
for (const auto j : c10::irange(D)) {
const int64_t c = g * D + j;
const T* dY_ptr = dY + (i * D + j) * HxW;
const T* X_ptr = X + (i * D + j) * HxW;
T* dX_ptr = dX + (i * D + j) * HxW;
const opmath_t c1 = opmath_t(rstd[i]) * (gamma_null ? opmath_t(1) : opmath_t(gamma[c]));
for (const auto k : c10::irange(HxW)) {
dX_ptr[k] = c1 * opmath_t(dY_ptr[k]) + c2 * opmath_t(X_ptr[k]) + c3;
}
}
}
});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free