GroupNormBackwardKernelImplInternal Function Template — pytorch Architecture
Architecture documentation for the GroupNormBackwardKernelImplInternal function template in group_norm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 879–935
// Shared driver for the CPU group-norm backward pass.
//
// Validates the element counts of the incoming tensors against N, C, HxW and
// group, allocates two {N, C} scratch tensors in at::opmath_type<T>, fills
// them via ComputeInternalGradients, and then runs only the gradient kernels
// whose output tensor is defined:
//   dX     -> GroupNormInputBackward
//   dgamma -> GammaBackward
//   dbeta  -> BetaBackward
//
// T is the element type read from dY / X and written to dX; PT is the element
// type of mean, rstd, gamma, dgamma and dbeta. gamma may be undefined, in
// which case a null pointer is forwarded to GroupNormInputBackward.
template <typename T, typename PT>
void GroupNormBackwardKernelImplInternal(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    Tensor& dX,
    Tensor& dgamma,
    Tensor& dbeta) {
  // Element-count checks: the kernels below index through raw pointers, so
  // the sizes are verified up front.
  TORCH_CHECK(dY.numel() == N * C * HxW);
  TORCH_CHECK(X.numel() == N * C * HxW);
  TORCH_CHECK(mean.numel() == N * group);
  TORCH_CHECK(rstd.numel() == N * group);
  TORCH_CHECK(!gamma.defined() || gamma.numel() == C);

  // Read-only inputs. gamma is optional and maps to nullptr when undefined.
  const T* grad_out_ptr = dY.const_data_ptr<T>();
  const T* input_ptr = X.const_data_ptr<T>();
  const PT* mean_ptr = mean.const_data_ptr<PT>();
  const PT* rstd_ptr = rstd.const_data_ptr<PT>();
  const PT* weight_ptr =
      gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;

  // Outputs. An undefined tensor yields nullptr, which disables the
  // corresponding gradient computation further down.
  T* grad_in_ptr = dX.defined() ? dX.data_ptr<T>() : nullptr;
  PT* grad_weight_ptr = dgamma.defined() ? dgamma.data_ptr<PT>() : nullptr;
  PT* grad_bias_ptr = dbeta.defined() ? dbeta.data_ptr<PT>() : nullptr;

  // Scratch buffers of shape {N, C} held in the op-math type of T. They are
  // populated by ComputeInternalGradients and consumed by every kernel below.
  // NOTE(review): presumably per-(n, c) reductions over HxW — confirm against
  // ComputeInternalGradients.
  using acc_t = at::opmath_type<T>;
  const auto scratch_options =
      X.options().dtype(c10::CppTypeToScalarType<acc_t>::value);
  Tensor ds = at::empty({N, C}, scratch_options);
  Tensor db = at::empty({N, C}, scratch_options);
  acc_t* ds_ptr = ds.data_ptr<acc_t>();
  acc_t* db_ptr = db.data_ptr<acc_t>();

  ComputeInternalGradients<T, acc_t>(
      N, C, HxW, grad_out_ptr, input_ptr, ds_ptr, db_ptr);

  // Gradient w.r.t. the input.
  if (grad_in_ptr) {
    GroupNormInputBackward<T, PT, acc_t>(
        N,
        C,
        HxW,
        group,
        grad_out_ptr,
        input_ptr,
        mean_ptr,
        rstd_ptr,
        weight_ptr,
        ds_ptr,
        db_ptr,
        grad_in_ptr);
  }

  // Gradient w.r.t. gamma.
  if (grad_weight_ptr) {
    GammaBackward(
        N, C, group, mean_ptr, rstd_ptr, ds_ptr, db_ptr, grad_weight_ptr);
  }

  // Gradient w.r.t. beta.
  if (grad_bias_ptr) {
    BetaBackward(N, C, db_ptr, grad_bias_ptr);
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free