Home / Class / GroupNormKernelImplChannelsLastInternal Class — pytorch Architecture

GroupNormKernelImplChannelsLastInternal Class — pytorch Architecture

Architecture documentation for the GroupNormKernelImplChannelsLastInternal function template in group_norm_kernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/group_norm_kernel.cpp lines 283–482

// Forward GroupNorm kernel for channels-last input on CPU.
//
// X is read as a contiguous {N, HxW, C} buffer (see the indexing
// `n * HxW * C + m * C + c` below). Its C channels are split into `group`
// groups of D = C / group consecutive channels.
//
// For every (n, g) pair the kernel computes, over the D * HxW elements of
// that group:
//   mean = E[x]
//   rstd = 1 / sqrt(max(E[x^2] - E[x]^2, 0) + eps)
// It stores them to mean[n * G + g] and rstd[n * G + g].
//
// Y is then produced with ApplyScaleBias using per-channel values:
//   scale[c] = rstd * gamma[c]          (gamma treated as 1 if undefined)
//   bias[c]  = beta[c] - scale[c] * mean (beta treated as 0 if undefined)
// ApplyScaleBias presumably computes Y = X * scale + bias; its definition
// is not shown here.
//
// Template parameters:
//   T  - element type of X and Y.
//   PT - element type of gamma, beta, mean and rstd. It may differ from T;
//        presumably this is for reduced-precision inputs with
//        higher-precision params. Confirm against the dispatching caller.
// All arithmetic is done in at::opmath_type<T>.
//
// NOTE(review): D = C / G is integer division and C % G == 0 is not checked
// here; this presumably relies on the caller having validated it.
// NOTE(review): Y, mean and rstd are assumed to be preallocated with
// matching sizes and a contiguous layout. Only the sizes of X, gamma and
// beta are checked below.
template <typename T, typename PT>
void GroupNormKernelImplChannelsLastInternal(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t N,
    int64_t C,
    int64_t HxW,
    int64_t group,
    double eps,
    Tensor& Y,
    Tensor& mean,
    Tensor& rstd) {
  TORCH_CHECK(X.numel() == N * C * HxW);
  TORCH_CHECK(!gamma.defined() || gamma.numel() == C);
  TORCH_CHECK(!beta.defined() || beta.numel() == C);
  const int64_t G = group;
  const int64_t D = C / G;
  const T* X_data = X.const_data_ptr<T>();
  const PT* gamma_data = gamma.defined() ? gamma.const_data_ptr<PT>() : nullptr;
  const PT* beta_data = beta.defined() ? beta.const_data_ptr<PT>() : nullptr;
  T* Y_data = Y.data_ptr<T>();
  PT* mean_data = mean.data_ptr<PT>();
  PT* rstd_data = rstd.data_ptr<PT>();

  using opmath_t = at::opmath_type<T>;

  // s turns a raw sum over one group (D * HxW elements) into an average.
  const opmath_t s = opmath_t(1) / static_cast<opmath_t>(D * HxW);
  const bool gamma_null = (gamma_data == nullptr);
  const bool beta_null = beta_data == nullptr;

  // NB: About the algorithm chosen:
  //
  // On channels last, GroupNorm has an input shape of {N, H, W, GD}.
  // Mean and rstd are collected for each n and g, which involves a reduction
  // over non-adjacent dimensions. We can parallelize in 2 ways:
  //
  // impl-1: parallel on N * G. Needs only one omp session, but each
  //   thread's memory access is non-contiguous (rows are strided by C).
  //
  // impl-2: parallel on N * HxW. Each thread's memory access is contiguous,
  //   but it needs an extra temp buffer of size {num_threads, N, 2C}.
  //
  // Generally impl-2 performs better when HxW is large enough that the data
  //   per thread {NHWC / num_threads} is much larger than the temp buffer
  //   per thread {2NC}.
  //
  // The switch point is HxW == feature_map_threshold. The value 1024 is a
  // tuning heuristic.
  //
  constexpr int64_t feature_map_threshold = 1024;
  if (HxW < feature_map_threshold) {
    // impl-1: parallel on N * G.
    //
    // For each (n, g), scale and bias are calculated only once and reused
    // for every HxW plane. `buffer` holds one {scale[D], bias[D]} row per
    // (n, g), so parallel iterations never share scratch memory.
    Tensor buffer = at::empty({N * G, 2 * D}, X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
    opmath_t* buffer_data = buffer.data_ptr<opmath_t>();

    at::parallel_for(0, N * G, 1, [&](int64_t begin, int64_t end) {
      // (n, g) track the flattened index i, which is expected to equal
      // n * G + g (mean_data/rstd_data/buffer are indexed by i).
      int64_t n{0}, g{0};
      data_index_init(begin, n, N, g, G);
      for (const auto i : c10::irange(begin, end)) {
        // step-1: for each n and g, collect the sum of x and x2.
        //
        // Note that using vec::map_reduce_all here is simpler to write,
        // but it is slower because a horizontal reduce from vec to scalar is
        // slow. It is better to reduce with a vec across the whole HxW plane
        // and do the horizontal add just once for each {n, g}.
        //
        // The returned values are raw sums; they are converted to
        // mean / rstd right below by scaling with s.
        auto [mean_val, rstd_val] = ColumnwiseMoments(
                X_data + n * HxW * C + g * D,
                HxW,
                C,
                D);

        mean_val *= s;
        // Var = E[x^2] - E[x]^2, clamped at 0 to absorb small negative
        // values caused by floating-point cancellation.
        rstd_val = std::max(rstd_val * s - mean_val * mean_val, opmath_t(0));
        // eps is a double, so for a float opmath_t this sum and sqrt are
        // evaluated in double before being narrowed back to opmath_t.
        rstd_val = opmath_t(1) / std::sqrt(rstd_val + eps);
        mean_data[i] = mean_val;
        rstd_data[i] = rstd_val;

        // step-2: calculate scale and bias for the D channels of group g.
        opmath_t* scale_ptr = buffer_data + i * 2 * D;
        opmath_t* bias_ptr = scale_ptr + D;
        for (const auto d : c10::irange(D)) {
          const int64_t c = g * D + d;
          scale_ptr[d] = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
          bias_ptr[d] = -scale_ptr[d] * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
        }

        // step-3: apply scale and bias to the D-wide slice of each of the
        // HxW rows of sample n (rows are C elements apart).
        for (const auto m : c10::irange(HxW)) {
          const T* X_ptr = X_data + n * HxW * C + m * C + g * D;
          T* Y_ptr = Y_data + n * HxW * C + m * C + g * D;
          ApplyScaleBias<T, opmath_t>(Y_ptr, X_ptr, scale_ptr, bias_ptr, D);
        }

        data_index_step(n, N, g, G);
      }
    });
  } else {
    // impl-2: parallel on N * HxW.
    //
    // `buffer` holds per-thread partial sums of x and x2, laid out as
    // {num_threads, N, 2C}. For thread t and sample n, [0, C) holds sum(x)
    // and [C, 2C) holds sum(x2). It is zero-initialized because each thread
    // accumulates into its own slice. Slots of threads that receive no work
    // stay zero and therefore do not affect the reduction in step-2.
    int num_threads = at::get_num_threads();
    Tensor buffer = at::empty({num_threads, N, 2 * C},
      X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value)).zero_();
    opmath_t* buffer_data = buffer.data_ptr<opmath_t>();
    // `tmp_buffer` holds the reduced (mean, rstd) pair for each (n, g),
    // interleaved as {N, G, 2}.
    Tensor tmp_buffer = at::empty({N, 2 * G},
      X.options().dtype(c10::CppTypeToScalarType<opmath_t>::value));
    opmath_t* tmp_buffer_data = tmp_buffer.data_ptr<opmath_t>();
    // step-1: accumulate along the C dimension.
    //
    // To improve multi-core performance when N=1, we parallelize over all
    // the outer dimensions, N and HxW. This leaves the innermost dimension,
    // C, for vectorization.
    //
    // Note that parallelizing over {N, HxW, G} is not feasible for some
    // common configs. For example, with input shape {1, 32, h, w} and G = 8,
    //   D = 4, which cannot fill the full SIMD length.
    //
    // To avoid thread conflicts we use a temp buffer of {num_threads, N, 2C}.
    //   First, reduce from {N, HxW, C} to {num_threads, N, 2C}.
    //
    at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
      // Each thread writes only to its own slice of `buffer`, selected by
      // tid, so no synchronization is needed.
      // NOTE(review): this assumes at::get_thread_num() < num_threads
      // inside this parallel_for.
      int tid = at::get_thread_num();
      opmath_t* buffer_ptr = buffer_data + tid * N * 2 * C;

      int64_t n{0}, m{0};
      data_index_init(begin, n, N, m, HxW);
      for (const auto i : c10::irange(begin, end)) {
        opmath_t* mean_ptr = buffer_ptr + n * 2 * C;
        opmath_t* rstd_ptr = mean_ptr + C;
        // Row i of the flattened {N * HxW, C} input is contiguous.
        const T* X_ptr = X_data + i * C;
        CalcMeanVar<T, opmath_t>(X_ptr, mean_ptr, rstd_ptr, C);
        data_index_step(n, N, m, HxW);
      }
    });

    // step-2: compute mean and rstd.
    //
    // Reduce the per-thread partial sums over the D channels of each group.
    // This runs serially on the calling thread.
    for (const auto n : c10::irange(N)) {
      for (const auto g : c10::irange(G)) {
        opmath_t mean_val{0}, rstd_val{0};
        for (const auto d : c10::irange(D)) {
          for (const auto t : c10::irange(num_threads)) {
            opmath_t* buffer_ptr = buffer_data + t * N * 2 * C + n * 2 * C;
            mean_val += buffer_ptr[g * D + d];
            rstd_val += buffer_ptr[g * D + d + C];
           }
        }
        mean_val *= s;
        // Var = E[x^2] - E[x]^2, clamped at 0 against rounding error.
        rstd_val = std::max(rstd_val * s - mean_val * mean_val, opmath_t(0));
        rstd_val = opmath_t(1) / std::sqrt(rstd_val + eps);
        tmp_buffer_data[n * 2 * G + 2 * g] = mean_val;
        tmp_buffer_data[n * 2 * G + 2 * g + 1] = rstd_val;
      }
    }

    // step-3: compute scale and bias.
    //
    // mean/rstd have shape {N, G}, and gamma/beta (C elements) are viewed as
    // {G, D}. scale/bias have shape {N, C}, so the final step can vectorize
    // directly along C.
    //
    // We could fuse steps 3 and 4 into a single session, but this way is
    // better:
    //   a. D might be too small for vectorization;
    //   b. it avoids recomputing scale/bias, since every HxW plane shares
    //      the same scale/bias.
    //
    // scale/bias are written into thread 0's slice of `buffer`, with [0, C)
    // holding scale and [C, 2C) holding bias. This reuse is safe because
    // step-2 has already folded every partial sum into tmp_buffer.
    //
    for (const auto n : c10::irange(N)) {
      for (const auto g : c10::irange(G)) {
        opmath_t* scale_ptr = buffer_data + n * 2 * C;
        opmath_t* bias_ptr = scale_ptr + C;
        opmath_t mean_val = tmp_buffer_data[n * 2 * G + 2 * g];
        opmath_t rstd_val = tmp_buffer_data[n * 2 * G + 2 * g + 1];
        mean_data[n * G + g] = mean_val;
        rstd_data[n * G + g] = rstd_val;

        for (const auto d : c10::irange(D)) {
          const int64_t c = g * D + d;
          scale_ptr[c] = rstd_val * (gamma_null ? opmath_t(1) : opmath_t(gamma_data[c]));
          bias_ptr[c] = -scale_ptr[c] * mean_val + (beta_null ? opmath_t(0) : opmath_t(beta_data[c]));
        }
      }
    }

    // step-4: apply scale and bias.
    //
    // Parallelize over all the outer dimensions, N and HxW, and vectorize
    // along C. Row i of X/Y uses the scale/bias row of its sample n.
    //
    at::parallel_for(0, N * HxW, 1, [&](int64_t begin, int64_t end) {
      int64_t n{0}, m{0};
      data_index_init(begin, n, N, m, HxW);
      for (const auto i : c10::irange(begin, end)) {
        const T* X_ptr = X_data + i * C;
        T* Y_ptr = Y_data + i * C;
        opmath_t* scale_ptr = buffer_data + n * 2 * C;
        opmath_t* bias_ptr = scale_ptr + C;
        ApplyScaleBias<T, opmath_t>(Y_ptr, X_ptr, scale_ptr, bias_ptr, C);
        data_index_step(n, N, m, HxW);
      }
    });
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free