DsDbRowwiseMomentsChannelsLast Class — pytorch Architecture

Architecture documentation for the DsDbRowwiseMomentsChannelsLast function template in group_norm_kernel.cpp from the PyTorch codebase.

Source Code

aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 937–970

template <typename T, typename opmath_t>
inline std::enable_if_t<std::is_same_v<T, opmath_t>, void>
DsDbRowwiseMomentsChannelsLast(
  const T* dY_ptr,
  const T* X_ptr,
  opmath_t* ds_ptr,
  opmath_t* db_ptr,
  int64_t C) {
  using Vec = vec::Vectorized<T>;
  constexpr int64_t K = vec::Vectorized<T>::size();
  const int64_t inner_size = C / K * K;
  int64_t d = 0;
  for (; d < inner_size; d += K) {
    Vec ds_dev = Vec::loadu(ds_ptr + d);
    Vec db_vec = Vec::loadu(db_ptr + d);
    Vec x_vec = Vec::loadu(X_ptr + d);
    Vec dy_vec = Vec::loadu(dY_ptr + d);

    ds_dev += x_vec * dy_vec;
    db_vec += dy_vec;
    ds_dev.store(ds_ptr + d);
    db_vec.store(db_ptr + d);
  }
  if (C - d > 0) {
    Vec ds_dev = Vec::loadu(ds_ptr + d, C - d);
    Vec db_vec = Vec::loadu(db_ptr + d, C - d);
    Vec x_vec = Vec::loadu(X_ptr + d, C - d);
    Vec dy_vec = Vec::loadu(dY_ptr + d, C - d);
    ds_dev += x_vec * dy_vec;
    db_vec += dy_vec;
    ds_dev.store(ds_ptr + d, C - d);
    db_vec.store(db_ptr + d, C - d);
  }
}
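
The listing above appears to accumulate, for one row of C channels, the per-channel sums ds[c] += X[c] * dY[c] and db[c] += dY[c], handling full vector-width blocks first and then a partial tail. The following is a minimal standalone sketch of that pattern using plain scalar loops in place of vec::Vectorized. The names ds_db_rowwise_reference and ds_db_rowwise_blocked and the block width K = 4 are hypothetical and chosen only for illustration; they are not part of PyTorch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical scalar reference (not PyTorch code): accumulate per-channel
// sums for a single row of C contiguous channel values.
//   ds[c] += X[c] * dY[c]
//   db[c] += dY[c]
template <typename T>
void ds_db_rowwise_reference(
    const T* dY, const T* X, T* ds, T* db, std::int64_t C) {
  assert(C >= 0);
  for (std::int64_t c = 0; c < C; ++c) {
    ds[c] += X[c] * dY[c];
    db[c] += dY[c];
  }
}

// Same computation, split the way the listing splits it: full blocks of K
// elements first, then the remaining C - inner_size elements. K stands in
// for the SIMD lane count and is an assumed value here.
template <typename T, std::int64_t K = 4>
void ds_db_rowwise_blocked(
    const T* dY, const T* X, T* ds, T* db, std::int64_t C) {
  assert(C >= 0);
  const std::int64_t inner_size = C / K * K;  // largest multiple of K <= C
  std::int64_t d = 0;
  for (; d < inner_size; d += K) {
    // One "vector" worth of work; a SIMD version does this in one step.
    for (std::int64_t k = 0; k < K; ++k) {
      ds[d + k] += X[d + k] * dY[d + k];
      db[d + k] += dY[d + k];
    }
  }
  // Tail: fewer than K elements remain.
  for (; d < C; ++d) {
    ds[d] += X[d] * dY[d];
    db[d] += dY[d];
  }
}
```

Because both functions add into ds and db rather than overwrite them, a caller would presumably zero those buffers once and then invoke the routine for each row that contributes to the same set of channels. With C = 6 and K = 4, the blocked variant processes channels 0 to 3 in the main loop and channels 4 and 5 in the tail, and should produce the same result as the reference loop.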