GroupNormBackwardKernelImplChannelsLastInternal Function — pytorch Architecture
Architecture documentation for the GroupNormBackwardKernelImplChannelsLastInternal function template in group_norm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 1357–1526
// Backward pass of GroupNorm for channels-last (NHWC) tensors on CPU.
// Given the upstream gradient dY, the saved forward input X, the per-(sample,
// group) statistics mean and rstd (NOTE(review): rstd is presumably 1/stddev,
// as suggested by its use in the c2/c3 coefficient formulas below — confirm
// against the forward kernel), and an optional per-channel affine weight
// gamma, this fills whichever of dX, dgamma, dbeta are defined tensors.
//   T  - element type of X / dY / dX (e.g. float, BFloat16)
//   PT - parameter type of mean / rstd / gamma / dgamma / dbeta
//   N = batch, C = channels, HxW = spatial size, group = number of groups.
template <typename T, typename PT>
void GroupNormBackwardKernelImplChannelsLastInternal(
const Tensor& dY,
const Tensor& X,
const Tensor& mean,
const Tensor& rstd,
const Tensor& gamma,
int64_t N,
int64_t C,
int64_t HxW,
int64_t group,
Tensor& dX,
Tensor& dgamma,
Tensor& dbeta) {
// Shape sanity checks: dY/X are N*C*HxW elements, stats are per (N, group).
TORCH_CHECK(dY.numel() == N * C * HxW);
TORCH_CHECK(X.numel() == N * C * HxW);
TORCH_CHECK(mean.numel() == N * group);
TORCH_CHECK(rstd.numel() == N * group);
TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
// D = channels per group.
int64_t D = C / group;
int64_t G = group;
const T* dY_data = dY.const_data_ptr<T>();
const T* X_data = X.const_data_ptr<T>();
const PT* mean_data = mean.const_data_ptr<PT>();
const PT* rstd_data = rstd.const_data_ptr<PT>();
// gamma is optional; a null pointer means "no affine weight" (treated as 1).
const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
// Each output is optional as well; undefined tensors are simply skipped.
T* dX_data = dX.defined() ? dX.data_ptr<T>() : nullptr;
PT* dgamma_data = dgamma.defined() ? dgamma.data_ptr<PT>() : nullptr;
PT* dbeta_data = dbeta.defined() ? dbeta.data_ptr<PT>() : nullptr;
const bool gamma_null = (gamma_data == nullptr);
// Accumulate internal gradients in opmath_t (e.g. float for BFloat16 inputs)
// to avoid precision loss in the reductions.
using opmath_t = at::opmath_type<T>;
// ds/db hold per-(sample, channel) partial reductions:
// ds[n][c] = sum_hw dY*X, db[n][c] = sum_hw dY (as computed by the helpers).
Tensor ds = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
Tensor db = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
opmath_t* ds_data = ds.data_ptr<opmath_t>();
opmath_t* db_data = db.data_ptr<opmath_t>();
// s = 1 / (elements per group); normalization factor in the coefficients.
const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
// Similar to the channels-last forward, the backward also has 2 impls.
// impl-1: parallel on N * G. Only needs one omp session for input gradients,
// but memory access per thread is non-contiguous (strided by C within a row).
//
// impl-2: parallel on N * HxW. Memory access per thread is contiguous,
// but requires the help of an extra temp buffer of size {T, N, 2C}.
// Generally impl-2 performs better when HxW is large enough, so that
// data per thread {NHWC / T} is much larger than temp buffer per thread {2NC}.
constexpr int64_t feature_map_threshold = 2048;
if (HxW < feature_map_threshold) {
// impl-1: parallel on N * G; each work item i corresponds to (n, g).
at::parallel_for(0, N * G, 1, [=](int64_t begin, int64_t end) {
// Decompose the flat index `begin` into (n, g); stepped at loop end.
int64_t n{0}, g{0};
data_index_init(begin, n, N, g, G);
for (const auto i : c10::irange(begin, end)) {
// Step 1. Compute internal gradients (ds/db for this group's channels,
// plus the gamma-weighted group sums ds_gamma/db_gamma).
opmath_t* ds_ptr = ds_data + i * D;
opmath_t* db_ptr = db_data + i * D;
// Base of this sample's spatial block, offset to the group's channels.
const T* X_ptr = X_data + n * HxW * C + g * D;
const T* dY_ptr = dY_data + n * HxW * C + g * D;
// nullptr when gamma is absent; helper presumably treats it as all-ones.
const PT* gamma_ptr = gamma_null ? gamma_data : (gamma_data + g * D);
auto [ds_gamma, db_gamma] = CalcInternalGradientsChannelsLast<T, PT, opmath_t>(
X_ptr,
dY_ptr,
gamma_ptr,
ds_ptr,
db_ptr,
HxW,
C,
D);
// Step 2. Compute dX via the closed-form coefficients:
// c2 = (db_gamma*mean - ds_gamma) * rstd^3 / (D*HxW)
// c3 = -c2*mean - db_gamma*rstd / (D*HxW)
// The per-element dY term is applied inside the helper below.
T* dX_ptr = dX_data + n * HxW * C + g * D;
const PT* rstd_ptr = rstd_data + i;
const opmath_t c2 = (db_gamma * opmath_t(mean_data[i]) - ds_gamma) *
opmath_t(rstd_data[i]) * opmath_t(rstd_data[i]) * opmath_t(rstd_data[i]) * s;
const opmath_t c3 = -c2 * opmath_t(mean_data[i]) - db_gamma * opmath_t(rstd_data[i]) * s;
ApplyInputGradientsChannelsLastColMov<T, PT, opmath_t>(dY_ptr, X_ptr, dX_ptr, rstd_ptr, gamma_ptr, c2, c3, HxW, C, D);
// Advance (n, g) to match the next flat index i+1.
data_index_step(n, N, g, G);
}
});
} else {
// impl-2: parallel on N * HxW (contiguous rows of C channels per item).
int num_threads = at::get_num_threads();
// Per-thread scratch: for each (thread, sample), 2*C accumulators laid out
// as [ds(0..C-1) | db(0..C-1)]. Zeroed because threads accumulate into it.
Tensor buffer = at::empty({num_threads, N, 2 * C},
X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value)).zero_();
opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
// Final gamma-weighted group sums, interleaved per (n, g) as
// [ds_gamma, db_gamma]; consumed by the dX pass in Step 3.
Tensor tmp_buffer = at::empty({N, 2 * G},
X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
opmath_t* tmp_buffer_data = tmp_buffer.data_ptr<opmath_t>();
// Step 1. Each thread computes its own internal gradients into the buffer.
at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
opmath_t* buffer_ptr = buffer_data + tid * N * 2 * C;
// Decompose the flat index into (n, m) where m indexes HxW.
int64_t n{0}, m{0};
data_index_init(begin, n, N, m, HxW);
for (const auto i : c10::irange(begin, end)) {
// This thread's [ds | db] slots for sample n.
opmath_t* ds_ptr = buffer_ptr + n * 2 * C;
opmath_t* db_ptr = ds_ptr + C;
// Row i is one contiguous C-channel pixel (channels-last layout).
const T* X_ptr = X_data + i * C;
const T* dY_ptr = dY_data + i * C;
DsDbRowwiseMomentsChannelsLast<T, opmath_t>(dY_ptr, X_ptr, ds_ptr, db_ptr, C);
data_index_step(n, N, m, HxW);
}
});
// Step 2. Collect internal gradients from each thread and
// get the final internal gradients into ds, db, and tmp_buffer.
// (Serial: the reduction over threads/channels is cheap relative to Step 1.)
for (const auto n : c10::irange(N)) {
for (const auto g : c10::irange(G)) {
opmath_t ds_gamma{0}, db_gamma{0};
for (const auto d : c10::irange(D)) {
opmath_t ds_val{0}, db_val{0};
for (const auto t : c10::irange(num_threads)) {
opmath_t* buffer_ptr = buffer_data + t * N * 2 * C + n * 2 * C;
// Missing gamma is treated as a weight of 1.
opmath_t gamma_val = gamma_null ? opmath_t(1) : opmath_t(gamma_data[g * D + d]);
ds_gamma += buffer_ptr[g * D + d] * gamma_val;
db_gamma += buffer_ptr[g * D + d + C] * gamma_val;
ds_val += buffer_ptr[g * D + d];
db_val += buffer_ptr[g * D + d + C];
}
// Per-channel sums feed GammaBackward/BetaBackward at the end.
ds_data[n * C + g * D + d] = ds_val;
db_data[n * C + g * D + d] = db_val;
}
// Interleaved [ds_gamma, db_gamma] per (n, g) for the dX pass.
tmp_buffer_data[n * 2 * G + 2 * g] = ds_gamma;
tmp_buffer_data[n * 2 * G + 2 * g + 1] = db_gamma;
}
}
// Step 3. Compute dX (skipped entirely when dX is undefined).
if (dX_data != nullptr) {
at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
int64_t n{0}, m{0};
data_index_init(begin, n, N, m, HxW);
for (const auto i : c10::irange(begin, end)) {
for (const auto g : c10::irange(G)) {
const T* X_ptr = X_data + i * C + g * D;
const T* dY_ptr = dY_data + i * C + g * D;
T* dX_ptr = dX_data + i * C + g * D;
const PT* mean_ptr = mean_data + n * G + g;
const PT* rstd_ptr = rstd_data + n * G + g;
const PT* gamma_ptr = gamma_null ? gamma_data : (gamma_data + g * D);
opmath_t ds_val = tmp_buffer_data[n * 2 * G + 2 * g];
opmath_t db_val = tmp_buffer_data[n * 2 * G + 2 * g + 1];
// Same closed-form coefficients as impl-1's Step 2.
const opmath_t c2 = (db_val * opmath_t(*mean_ptr) - ds_val) *
opmath_t(*rstd_ptr) * opmath_t(*rstd_ptr)* opmath_t(*rstd_ptr) * s;
const opmath_t c3 = -c2 * opmath_t(*mean_ptr) - db_val * opmath_t(*rstd_ptr) * s;
ApplyInputGradientsChannelsLastRowMov<T, PT, opmath_t>(dY_ptr, X_ptr, dX_ptr, rstd_ptr, gamma_ptr, c2, c3, HxW, C, D);
}
data_index_step(n, N, m, HxW);
}
});
}
}
// Finally compute dgamma and dbeta from the per-(n, c) ds/db reductions;
// each is skipped when its output tensor is undefined.
if (dgamma_data != nullptr) {
GammaBackward(
N, C, group, mean_data, rstd_data, ds_data, db_data, dgamma_data);
}
if (dbeta_data != nullptr) {
BetaBackward(N, C, db_data, dbeta_data);
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free