Home / Class / GroupNormInputBackward Class — pytorch Architecture

GroupNormInputBackward Class — pytorch Architecture

Architecture documentation for GroupNormInputBackward, a function template defined in group_norm_kernel.cpp in the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 652–706

template <typename T, typename PT, typename opmath_t>
void GroupNormInputBackward(
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    const T* dY,
    const T* X,
    const PT* mean,
    const PT* rstd,
    const PT* gamma,
    const opmath_t* ds,
    const opmath_t* db,
    T* dX) {
  const int64_t G = group;
  const int64_t D = C / G;
  const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
  const bool gamma_null = (gamma == nullptr);
  at::parallel_for(0, N * G, 1, [=](int64_t start, int64_t end) {
    constexpr int64_t K = vec::Vectorized<PT>::size();
    const int64_t d = D / K * K;
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> ds_arr;
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    std::array<opmath_t, at::vec::Vectorized<opmath_t>::size()> db_arr;
    for (const auto i : c10::irange(start, end)) {
      const int64_t g = i % G;
      const opmath_t* ds_ptr = ds + i * D;
      const opmath_t* db_ptr = db + i * D;
      const PT* gamma_ptr = gamma_null ? nullptr : (gamma + g * D);
      CalcDsDb(ds_ptr, db_ptr, gamma_ptr, d, K, ds_arr.data(), db_arr.data());
      opmath_t ds_val = std::accumulate(ds_arr.cbegin(), ds_arr.cend(), opmath_t(0));
      opmath_t db_val = std::accumulate(db_arr.cbegin(), db_arr.cend(), opmath_t(0));
      for (const auto j : c10::irange(d, D)) {
        const opmath_t gamma_v = gamma_null ? opmath_t(1) : opmath_t(gamma[g * D + j]);
        ds_val += ds_ptr[j] * gamma_v;
        db_val += db_ptr[j] * gamma_v;
      }
      const opmath_t c2 =
          (db_val * opmath_t(mean[i]) - ds_val) * opmath_t(rstd[i]) * opmath_t(rstd[i]) * opmath_t(rstd[i]) * s;
      const opmath_t c3 = -c2 * opmath_t(mean[i]) - db_val * opmath_t(rstd[i]) * s;

      for (const auto j : c10::irange(D)) {
        const int64_t c = g * D + j;
        const T* dY_ptr = dY + (i * D + j) * HxW;
        const T* X_ptr = X + (i * D + j) * HxW;
        T* dX_ptr = dX + (i * D + j) * HxW;
        const opmath_t c1 = opmath_t(rstd[i]) * (gamma_null ? opmath_t(1) : opmath_t(gamma[c]));
        for (const auto k : c10::irange(HxW)) {
          dX_ptr[k] = c1 * opmath_t(dY_ptr[k]) + c2 * opmath_t(X_ptr[k]) + c3;
        }
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free