
GroupNormBackwardKernelImplChannelsLastInternal Function — pytorch Architecture

Architecture documentation for the GroupNormBackwardKernelImplChannelsLastInternal function template in group_norm_kernel.cpp from the pytorch codebase. It implements the CPU backward pass of group normalization for channels-last (NHWC) input, computing the input gradient dX and the affine-parameter gradients dgamma and dbeta.
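The entity is a file-local function template in ATen's CPU group-norm kernels, presumably reached through the group-norm backward dispatch rather than called directly. Its shape contract can be read off the TORCH_CHECK calls at the top of the listing below; the sketch that follows only illustrates those expected sizes. The helper name, the concrete dimensions, and the commented-out direct call are hypothetical, since the symbol is internal to group_norm_kernel.cpp.

#include <ATen/ATen.h>

// Hypothetical shape setup for T = float, PT = float (PT is the dtype of
// mean, rstd, and gamma). The values are random; only the sizes matter for
// the TORCH_CHECKs in the listing below.
void sketch_group_norm_backward_channels_last_shapes() {
  const int64_t N = 2, C = 8, HxW = 16, group = 4;   // C must be divisible by group
  auto opts = at::TensorOptions().dtype(at::kFloat);
  at::Tensor X      = at::randn({N, HxW, C}, opts);  // N * C * HxW elements, NHWC walk order
  at::Tensor dY     = at::randn({N, HxW, C}, opts);  // same number of elements as X
  at::Tensor mean   = at::randn({N, group}, opts);   // N * group elements
  at::Tensor rstd   = at::rand({N, group}, opts);    // N * group elements
  at::Tensor gamma  = at::randn({C}, opts);          // optional: an undefined Tensor is accepted
  at::Tensor dX     = at::empty_like(X);             // each output may be undefined to skip it
  at::Tensor dgamma = at::empty({C}, opts);
  at::Tensor dbeta  = at::empty({C}, opts);
  // Illustrative only; the template is not visible outside group_norm_kernel.cpp:
  // GroupNormBackwardKernelImplChannelsLastInternal<float, float>(
  //     dY, X, mean, rstd, gamma, N, C, HxW, group, dX, dgamma, dbeta);
}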

Entity Profile

Source Code

aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 1357–1526

template <typename T, typename PT>
void GroupNormBackwardKernelImplChannelsLastInternal(
    const Tensor& dY,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const Tensor& gamma,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    Tensor& dX,
    Tensor& dgamma,
    Tensor& dbeta) {
  TORCH_CHECK(dY.numel() == N * C * HxW);
  TORCH_CHECK(X.numel() == N * C * HxW);
  TORCH_CHECK(mean.numel() == N * group);
  TORCH_CHECK(rstd.numel() == N * group);
  TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
  int64_t D = C / group;
  int64_t G = group;
  const T* dY_data = dY.const_data_ptr<T>();
  const T* X_data = X.const_data_ptr<T>();
  const PT* mean_data = mean.const_data_ptr<PT>();
  const PT* rstd_data = rstd.const_data_ptr<PT>();
  const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
  T* dX_data = dX.defined() ? dX.data_ptr<T>() : nullptr;
  PT* dgamma_data = dgamma.defined() ? dgamma.data_ptr<PT>() : nullptr;
  PT* dbeta_data = dbeta.defined() ? dbeta.data_ptr<PT>() : nullptr;
  const bool gamma_null = (gamma_data == nullptr);
  using opmath_t = at::opmath_type<T>;
  Tensor ds = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
  Tensor db = at::empty({N, C}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
  opmath_t* ds_data = ds.data_ptr<opmath_t>();
  opmath_t* db_data = db.data_ptr<opmath_t>();
  const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);

  // Similar to channels last forward, channels last backward also has 2 impls.
  // impl-1: parallel on N * G. Only need one omp session for input gradients
  //   but memory access per thread is non-contiguous.
  //
  // impl-2: parallel on N * HxW. Memory access per thread is contiguous,
  //   but requires help of extra temp buffer of size {T, N, 2C}.

  // Generally impl-2 has better performance when HxW is large enough, so that
  //   data per thread {NHWC / T} is much larger than the temp buffer per thread {2NC}.
  constexpr int64_t feature_map_threshold = 2048;
  if (HxW < feature_map_threshold) {
    // impl-1: parallel on N * G.
    at::parallel_for(0, N * G, 1, [=](int64_t begin, int64_t end) {
      int64_t n{0}, g{0};
      data_index_init(begin, n, N, g, G);
      for (const auto i : c10::irange(begin, end)) {
        // Step 1. Compute internal gradients.
        opmath_t* ds_ptr = ds_data + i * D;
        opmath_t* db_ptr = db_data + i * D;
        const T* X_ptr = X_data + n * HxW * C + g * D;
        const T* dY_ptr = dY_data + n * HxW * C + g * D;
        const PT* gamma_ptr = gamma_null ? gamma_data : (gamma_data + g * D);
        auto [ds_gamma, db_gamma] = CalcInternalGradientsChannelsLast<T, PT, opmath_t>(
          X_ptr,
          dY_ptr,
          gamma_ptr,
          ds_ptr,
          db_ptr,
          HxW,
          C,
          D);

        // Step 2. Compute dX.
        T* dX_ptr = dX_data + n * HxW * C + g * D;
        const PT* rstd_ptr = rstd_data + i;
        const opmath_t c2 = (db_gamma * opmath_t(mean_data[i]) - ds_gamma) *
            opmath_t(rstd_data[i]) * opmath_t(rstd_data[i]) * opmath_t(rstd_data[i]) * s;
        const opmath_t c3 = -c2 * opmath_t(mean_data[i]) - db_gamma * opmath_t(rstd_data[i]) * s;
        ApplyInputGradientsChannelsLastColMov<T, PT, opmath_t>(dY_ptr, X_ptr, dX_ptr, rstd_ptr, gamma_ptr, c2, c3, HxW, C, D);
        data_index_step(n, N, g, G);
      }
    });

  } else {
    // impl-2: parallel on N * HxW.
    int num_threads = at::get_num_threads();
    Tensor buffer = at::empty({num_threads, N, 2 * C},
      X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value)).zero_();
    opmath_t* buffer_data = buffer.data_ptr<opmath_t>();

    Tensor tmp_buffer = at::empty({N, 2 * G},
      X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
    opmath_t* tmp_buffer_data = tmp_buffer.data_ptr<opmath_t>();

    // Step 1. Each thread computes its own internal gradients into the buffer.
    at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
      int tid = at::get_thread_num();
      opmath_t* buffer_ptr = buffer_data + tid * N * 2 * C;
      int64_t n{0}, m{0};
      data_index_init(begin, n, N, m, HxW);
      for (const auto i : c10::irange(begin, end)) {
        opmath_t* ds_ptr = buffer_ptr + n * 2 * C;
        opmath_t* db_ptr = ds_ptr + C;
        const T* X_ptr = X_data + i * C;
        const T* dY_ptr = dY_data + i * C;

        DsDbRowwiseMomentsChannelsLast<T, opmath_t>(dY_ptr, X_ptr, ds_ptr, db_ptr, C);
        data_index_step(n, N, m, HxW);
      }
    });

    // Step 2. Collect the internal gradients from each thread and
    // reduce them into the final internal gradients in ds, db, and tmp_buffer.
    for (const auto n : c10::irange(N)) {
      for (const auto g : c10::irange(G)) {
        opmath_t ds_gamma{0}, db_gamma{0};
        for (const auto d : c10::irange(D)) {
          opmath_t ds_val{0}, db_val{0};
          for (const auto t : c10::irange(num_threads)) {
            opmath_t* buffer_ptr = buffer_data + t * N * 2 * C + n * 2 * C;
            opmath_t gamma_val = gamma_null ? opmath_t(1) : opmath_t(gamma_data[g * D + d]);
            ds_gamma += buffer_ptr[g * D + d] * gamma_val;
            db_gamma += buffer_ptr[g * D + d + C] * gamma_val;
            ds_val += buffer_ptr[g * D + d];
            db_val += buffer_ptr[g * D + d + C];
          }
          ds_data[n * C + g * D + d] = ds_val;
          db_data[n * C + g * D + d] = db_val;
        }
        tmp_buffer_data[n * 2 * G + 2 * g] = ds_gamma;
        tmp_buffer_data[n * 2 * G + 2 * g + 1] = db_gamma;
      }
    }

    // Step 3. Compute dx.
    if (dX_data != nullptr) {
      at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
        int64_t n{0}, m{0};
        data_index_init(begin, n, N, m, HxW);
        for (const auto i : c10::irange(begin, end)) {
          for (const auto g : c10::irange(G)) {
            const T* X_ptr = X_data + i * C + g * D;
            const T* dY_ptr = dY_data + i * C + g * D;
            T* dX_ptr = dX_data + i * C + g * D;
            const PT* mean_ptr = mean_data + n * G + g;
            const PT* rstd_ptr = rstd_data + n * G + g;
            const PT* gamma_ptr = gamma_null ? gamma_data : (gamma_data + g * D);
            opmath_t ds_val = tmp_buffer_data[n * 2 * G + 2 * g];
            opmath_t db_val = tmp_buffer_data[n * 2 * G + 2 * g + 1];

            const opmath_t c2 = (db_val * opmath_t(*mean_ptr) - ds_val) *
                opmath_t(*rstd_ptr) * opmath_t(*rstd_ptr)* opmath_t(*rstd_ptr) * s;
            const opmath_t c3 = -c2 * opmath_t(*mean_ptr) - db_val * opmath_t(*rstd_ptr) * s;
            ApplyInputGradientsChannelsLastRowMov<T, PT, opmath_t>(dY_ptr, X_ptr, dX_ptr, rstd_ptr, gamma_ptr, c2, c3, HxW, C, D);
          }

          data_index_step(n, N, m, HxW);
        }
      });
    }

  }

  // Finally compute dgamma and dbeta.
  if (dgamma_data != nullptr) {
    GammaBackward(
        N, C, group, mean_data, rstd_data, ds_data, db_data, dgamma_data);
  }
  if (dbeta_data != nullptr) {
    BetaBackward(N, C, db_data, dbeta_data);
  }
}
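For reference, the coefficients c2 and c3 computed in Step 2 of impl-1 and in Step 3 of impl-2 are the usual group-norm backward terms. Write s = 1/(D * HxW), let mu and r be the per-group mean and rstd, and let ds_gamma and db_gamma be the per-group sums of gamma * dY * X and gamma * dY (returned by CalcInternalGradientsChannelsLast in impl-1 and accumulated in Step 2 of impl-2). The code then corresponds to the sketch below; the last line is an assumption, since the gamma * r * dY term is applied inside the ApplyInputGradientsChannelsLast helpers, which are not part of this listing.

$$
\begin{aligned}
c_2 &= \bigl(db_\gamma \,\mu - ds_\gamma\bigr)\, r^{3} s \\
c_3 &= -c_2\,\mu - db_\gamma\, r\, s \\
dX &= \gamma\, r\, dY + c_2\, X + c_3
\end{aligned}
$$

As the comments in the listing note, impl-2 (parallel over N * HxW) is selected when HxW >= 2048, the regime in which the per-thread share of the NHWC data is expected to outweigh the {num_threads, N, 2C} scratch buffer used for the per-thread partial sums.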
