CalcInternalGradientsChannelsLast Function Template — pytorch Architecture
Architecture documentation for the CalcInternalGradientsChannelsLast function template in group_norm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 1236–1289
template <typename T, typename PT, typename opmath_t>
inline typename std::
enable_if<std::is_same_v<T, opmath_t>, std::tuple<opmath_t, opmath_t>>::type
CalcInternalGradientsChannelsLast(
const T* X_data,
const T* dY_data,
const PT* gamma_ptr,
opmath_t* ds_ptr,
opmath_t* db_ptr,
int64_t HxW,
int64_t C,
int64_t D) {
using Vec = vec::Vectorized<T>;
const bool gamma_null = (gamma_ptr == nullptr);
constexpr int64_t K = Vec::size();
const int64_t inner_size = D / K * K;
int64_t d = 0;
opmath_t ds_gamma{0}, db_gamma{0};
for (; d < inner_size; d += K) {
Vec acc0_vec{0}, acc1_vec{0};
for (const auto m : c10::irange(HxW)) {
const T* X_ptr = X_data + m * C;
const T* dY_ptr = dY_data + m * C;
Vec x_vec = Vec::loadu(X_ptr + d);
Vec dy_vec = Vec::loadu(dY_ptr + d);
acc0_vec += x_vec * dy_vec;
acc1_vec += dy_vec;
}
acc0_vec.store(ds_ptr + d);
acc1_vec.store(db_ptr + d);
ds_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
acc0_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d)));
db_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
acc1_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d)));
}
if (D - d > 0) {
Vec acc0_vec{0}, acc1_vec{0};
for (const auto m : c10::irange(HxW)) {
const T* X_ptr = X_data + m * C;
const T* dY_ptr = dY_data + m * C;
Vec x_vec = Vec::loadu(X_ptr + d, D - d);
Vec dy_vec = Vec::loadu(dY_ptr + d, D - d);
acc0_vec += x_vec * dy_vec;
acc1_vec += dy_vec;
}
acc0_vec.store(ds_ptr + d, D - d);
acc1_vec.store(db_ptr + d, D - d);
ds_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
acc0_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d, D - d)));
db_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
acc1_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d, D - d)));
}
return std::tuple<opmath_t, opmath_t>(ds_gamma, db_gamma);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free