Home / Class / CalcInternalGradientsChannelsLast Class — PyTorch Architecture

CalcInternalGradientsChannelsLast Function Template — PyTorch Architecture

Architecture documentation for the CalcInternalGradientsChannelsLast function template in group_norm_kernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 1236–1289

template <typename T, typename PT, typename opmath_t>
inline typename std::
    enable_if<std::is_same_v<T, opmath_t>, std::tuple<opmath_t, opmath_t>>::type
    CalcInternalGradientsChannelsLast(
    const T* X_data,
    const T* dY_data,
    const PT* gamma_ptr,
    opmath_t* ds_ptr,
    opmath_t* db_ptr,
    int64_t HxW,
    int64_t C,
    int64_t D) {
  using Vec = vec::Vectorized<T>;
  const bool gamma_null = (gamma_ptr == nullptr);
  constexpr int64_t K = Vec::size();
  const int64_t inner_size = D / K * K;
  int64_t d = 0;
  opmath_t ds_gamma{0}, db_gamma{0};
  for (; d < inner_size; d += K) {
    Vec acc0_vec{0}, acc1_vec{0};
    for (const auto m : c10::irange(HxW)) {
      const T* X_ptr = X_data + m * C;
      const T* dY_ptr = dY_data + m * C;
      Vec x_vec = Vec::loadu(X_ptr + d);
      Vec dy_vec = Vec::loadu(dY_ptr + d);
      acc0_vec += x_vec * dy_vec;
      acc1_vec += dy_vec;
    }
    acc0_vec.store(ds_ptr + d);
    acc1_vec.store(db_ptr + d);
    ds_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
      acc0_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d)));
    db_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
      acc1_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d)));
  }
  if (D - d > 0) {
    Vec acc0_vec{0}, acc1_vec{0};
    for (const auto m : c10::irange(HxW)) {
      const T* X_ptr = X_data + m * C;
      const T* dY_ptr = dY_data + m * C;
      Vec x_vec = Vec::loadu(X_ptr + d, D - d);
      Vec dy_vec = Vec::loadu(dY_ptr + d, D - d);
      acc0_vec += x_vec * dy_vec;
      acc1_vec += dy_vec;
    }
    acc0_vec.store(ds_ptr + d, D - d);
    acc1_vec.store(db_ptr + d, D - d);
    ds_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
      acc0_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d, D - d)));
    db_gamma += vec::vec_reduce_all([](Vec& x, Vec& y) { return x + y; },
      acc1_vec * (gamma_null ? Vec(1) : Vec::loadu(gamma_ptr + d, D - d)));
  }
  return std::tuple<opmath_t, opmath_t>(ds_gamma, db_gamma);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free