GroupNormKernelImplChannelsLastInternal Function — pytorch Architecture
Architecture documentation for the GroupNormKernelImplChannelsLastInternal template function in group_norm_kernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 283–482
// Channels-last (NHWC) CPU forward kernel for GroupNorm.
//
// Template parameters:
//   T  - data type of X and Y (the activation tensors).
//   PT - parameter/statistics type of gamma, beta, mean and rstd.
//
// Inputs:
//   X            - input of N * C * HxW elements, laid out channels-last
//                  as {N, HxW, C} (i.e. {N, H, W, G*D}).
//   gamma, beta  - optional per-channel affine parameters of length C
//                  (nullptr-equivalent when undefined).
//   N, C, HxW    - batch, channel count, and spatial size.
//   group        - number of groups G; D = C / G channels per group.
//   eps          - added to the variance before the reciprocal sqrt.
// Outputs:
//   Y            - normalized (and optionally affine-transformed) result,
//                  same layout as X.
//   mean, rstd   - per-(n, g) statistics, written at index n * G + g.
//
// All accumulation is done in opmath_t (e.g. float for reduced-precision T).
template <typename T, typename PT>
void GroupNormKernelImplChannelsLastInternal(
const Tensor& X,
const Tensor& gamma,
const Tensor& beta,
int64_t N,
int64_t C,
int64_t HxW,
int64_t group,
double eps,
Tensor& Y,
Tensor& mean,
Tensor& rstd) {
TORCH_CHECK(X.numel() == N * C * HxW);
TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
TORCH_CHECK(!beta.defined() || beta.numel() == C);
const int64_t G = group;
const int64_t D = C / G;
const T* X_data = X.const_data_ptr<T>();
const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
T* Y_data = Y.data_ptr<T>();
PT* mean_data = mean.data_ptr<PT>();
PT* rstd_data = rstd.data_ptr<PT>();
using opmath_t = at::opmath_type<T>;
// Normalization factor: each (n, g) group statistic averages D * HxW values.
const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
const bool gamma_null = (gamma_data == nullptr);
const bool beta_null = beta_data == nullptr;
// NB: About algorithm chosen:
//
// On channels last, GroupNorm has a input shape of {N, H, W, GD},
// Mean and rstd are collected per each n and g, which involves reduction
// on non-adjacent dimensions. We can parallel in the following 2 impls:
//
// impl-1: parallel on N * G. Only need one omp session but memory access
// per thread is non-contiguous.
//
// impl-2: parallel on N * HxW. Memory access per thread is contiguous,
// but requires help of extra temp buffer of size {T, N, 2C}.
//
// Generally impl-2 has better performance when HxW is large enough, so that
// data per thread {NHWC / T} is much larger than temp buffer per thread {2NC}
//
constexpr int64_t feature_map_threshold = 1024;
if (HxW < feature_map_threshold) {
// impl-1: parallel on N * G.
//
// for each plane of HxW, scale and bias are calculated only once
// Per-(n, g) scratch: row i holds D scale values followed by D bias values.
Tensor buffer = at::empty({N * G, 2 * D}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
at::parallel_for(0, N * G, 1, [&](int64_t begin, int64_t end) {
// Decompose the flat index i in [begin, end) into (n, g).
int64_t n{0}, g{0};
data_index_init(begin, n, N, g, G);
for (const auto i : c10::irange(begin, end)) {
// step-1: for each n and g, collect sum of x and x2
//
// Note that using vec::map_reduce_all here is simpler to write
// but it is slower since horizontal reduce from vec to scalar is slow.
// So it is better to reduce with a vec across all HxW plane,
// and do a horizontal add just once for each {n, g}.
//
auto [mean_val, rstd_val] = ColumnwiseMoments(
X_data + n * HxW * C + g * D,
HxW,
C,
D);
// Here mean_val/rstd_val hold sum(x) and sum(x^2); convert to
// mean and variance, clamping at 0 to guard against fp rounding,
// then to the reciprocal standard deviation.
mean_val *= s;
rstd_val = std::max(rstd_val * s - mean_val * mean_val, opmath_t(0));
rstd_val = opmath_t(1) / std::sqrt(rstd_val + eps);
mean_data[i] = mean_val;
rstd_data[i] = rstd_val;
// step-2: calculate scale and bias
opmath_t* scale_ptr = buffer_data + i * 2 * D;
opmath_t* bias_ptr = scale_ptr + D;
for (const auto d : c10::irange(D)) {
const int64_t c = g * D + d;
// y = (x - mean) * rstd * gamma + beta
//   = x * (rstd * gamma) + (-rstd * gamma * mean + beta)
scale_ptr[d] = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
bias_ptr[d] = -scale_ptr[d] * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
}
// step-3: apply scale and bias
for (const auto m : c10::irange(HxW)) {
const T* X_ptr = X_data + n * HxW * C + m * C + g * D;
T* Y_ptr = Y_data + n * HxW * C + m * C + g * D;
ApplyScaleBias<T, opmath_t>(Y_ptr, X_ptr, scale_ptr, bias_ptr, D);
}
data_index_step(n, N, g, G);
}
});
} else {
// impl-2: parallel on N * HxW.
//
// temp buffer holding x and x2
int num_threads = at::get_num_threads();
// Zero-init: each thread accumulates into its own {N, 2C} slice
// (per-channel sum(x) in the first C entries, sum(x^2) in the next C).
Tensor buffer = at::empty({num_threads, N, 2 * C},
X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value)).zero_();
opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
// Interleaved per-(n, g) statistics: [mean, rstd, mean, rstd, ...].
Tensor tmp_buffer = at::empty({N, 2 * G},
X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
opmath_t* tmp_buffer_data = tmp_buffer.data_ptr<opmath_t>();
// step-1: accumulate on dimension of C
//
// In order to improve multi-core performance when N=1,
// we parallel on all the outer dimensions of N and HxW,
// leaving the most inner dimension C for vectorization.
//
// Note that parallel on {N, HxW, G} is not feasible for some common configs,
// e.g. say input shape is {1, 32, h, w} and G = 8,
// this will give D = 4 which is unable to take full SIMD length.
//
// To avoid thread conflict, we make use of a temp buffer of {T, N, 2C},
// firstly, reduce from {N, HxW, C} to {T, N, 2C}
//
at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
opmath_t* buffer_ptr = buffer_data + tid * N * 2 * C;
int64_t n{0}, m{0};
data_index_init(begin, n, N, m, HxW);
for (const auto i : c10::irange(begin, end)) {
opmath_t* mean_ptr = buffer_ptr + n * 2 * C;
opmath_t* rstd_ptr = mean_ptr + C;
const T* X_ptr = X_data + i * C;
CalcMeanVar<T, opmath_t>(X_ptr, mean_ptr, rstd_ptr, C);
data_index_step(n, N, m, HxW);
}
});
// step-2: compute mean and rstd
// Serial reduction over threads and channels-within-group; cheap since it
// touches only {T, N, 2C} elements, not the full input.
for (const auto n : c10::irange(N)) {
for (const auto g : c10::irange(G)) {
opmath_t mean_val{0}, rstd_val{0};
for (const auto d : c10::irange(D)) {
for (const auto t : c10::irange(num_threads)) {
opmath_t* buffer_ptr = buffer_data + t * N * 2 * C + n * 2 * C;
mean_val += buffer_ptr[g * D + d];
rstd_val += buffer_ptr[g * D + d + C];
}
}
// Convert the accumulated sum(x)/sum(x^2) into mean and rstd,
// clamping the variance at 0 against fp rounding error.
mean_val *= s;
rstd_val = std::max(rstd_val * s - mean_val * mean_val, opmath_t(0));
rstd_val = opmath_t(1) / std::sqrt(rstd_val + eps);
tmp_buffer_data[n * 2 * G + 2 * g] = mean_val;
tmp_buffer_data[n * 2 * G + 2 * g + 1] = rstd_val;
}
}
// step-3: compute scale and bias
//
// mean/rstd have shape of {N, G}, gamma/beta have shape of {G, D}.
// And scale/bias have shape of {N, C} so that we can directly vectorize on
// dimension of C in the final step.
//
// We could fuse step 3 and 4 into a single session but this way is better:
// a. D might be too small for vectorization;
// b. Avoid duplicate calculation of scale/bias, each HxW plane shares the same scale/bias
//
// NB: thread 0's slice of `buffer` (buffer_data + n * 2 * C) is reused here
// as scale/bias storage — safe because step-2 has fully consumed it.
for (const auto n : c10::irange(N)) {
for (const auto g : c10::irange(G)) {
opmath_t* scale_ptr = buffer_data + n * 2 * C;
opmath_t* bias_ptr = scale_ptr + C;
opmath_t mean_val = tmp_buffer_data[n * 2 * G + 2 * g];
opmath_t rstd_val = tmp_buffer_data[n * 2 * G + 2 * g + 1];
mean_data[n * G + g] = mean_val;
rstd_data[n * G + g] = rstd_val;
for (const auto d : c10::irange(D)) {
const int64_t c = g * D + d;
// Same folding as impl-1: y = x * scale + bias.
scale_ptr[c] = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
bias_ptr[c] = -scale_ptr[c] * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
}
}
}
// step-4: apply scale and bias
//
// Parallel on all the outer dimensions of N and HxW
// and vectorize on C.
//
at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
int64_t n{0}, m{0};
data_index_init(begin, n, N, m, HxW);
for (const auto i : c10::irange(begin, end)) {
const T* X_ptr = X_data + i * C;
T* Y_ptr = Y_data + i * C;
opmath_t* scale_ptr = buffer_data + n * 2 * C;
opmath_t* bias_ptr = scale_ptr + C;
ApplyScaleBias<T, opmath_t>(Y_ptr, X_ptr, scale_ptr, bias_ptr, C);
data_index_step(n, N, m, HxW);
}
});
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free