Home / Class / layer_norm_kernel_mixed_type Class — PyTorch Architecture

layer_norm_kernel_mixed_type Class — pytorch Architecture

Architecture documentation for the layer_norm_kernel_mixed_type class in layer_norm_kernel.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/layer_norm_kernel.cpp lines 97–154

// Layer-norm forward kernel for reduced-precision activations T (enabled only
// when is_reduced_floating_point_v<T>, e.g. BFloat16/Half) with gamma/beta and
// the mean/rstd outputs held in a separate type param_t. All arithmetic is
// accumulated in float for accuracy.
//
// X:          input of logical shape [M, N]; each of the M rows is normalized
//             independently over its N elements.
// gamma/beta: optional affine parameters of length N (undefined Tensor => 1/0).
// eps:        added to the variance before the rsqrt for numerical stability.
// Y:          output, same shape and element type as X.
// mean/rstd:  optional per-row statistics outputs (nullptr to skip).
template <typename T, typename param_t,
          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
void layer_norm_kernel_mixed_type(
    const Tensor& X,
    const Tensor& gamma,
    const Tensor& beta,
    int64_t M,
    int64_t N,
    float eps,
    Tensor* Y,
    Tensor* mean,
    Tensor* rstd) {
  using bVec = Vectorized<T>;
  using fVec = Vectorized<float>;
  const T* input = X.const_data_ptr<T>();
  const param_t* weight = gamma.defined() ? gamma.const_data_ptr<param_t>() : nullptr;
  const param_t* shift = beta.defined() ? beta.const_data_ptr<param_t>() : nullptr;
  T* output = Y->data_ptr<T>();
  param_t* mean_out = mean ? mean->data_ptr<param_t>() : nullptr;
  param_t* rstd_out = rstd ? rstd->data_ptr<param_t>() : nullptr;

  const bool has_weight = weight != nullptr;
  const bool has_shift = shift != nullptr;
  // One row per task; rows are independent, so grain size 1 is safe.
  at::parallel_for(0, M, 1, [&](int64_t row_begin, int64_t row_end) {
    for (const auto row : c10::irange(row_begin, row_end)) {
      const T* x_row = input + row * N;
      T* y_row = output + row * N;
      // Per-row mean and variance, accumulated in float.
      auto [row_mean, row_var] = RowwiseMoments(x_row, N);
      const float inv_std = float(1) / std::sqrt(row_var + eps);
      // Fold the normalization into a single multiply-add:
      //   (x - mean) * inv_std == x * scale + offset
      const float scale = inv_std;
      const float offset = -inv_std * row_mean;
      if (mean_out != nullptr) {
        mean_out[row] = row_mean;
      }
      if (rstd_out != nullptr) {
        rstd_out[row] = inv_std;
      }
      // Vectorized main loop: one bVec of T widens into two fVec of float.
      const int64_t vec_end = N - (N % bVec::size());
      for (int64_t d = 0; d < vec_end; d += bVec::size()) {
        auto [x0, x1] = convert_to_float<T>(bVec::loadu(x_row + d));
        fVec g0(1), g1(1);
        if (has_weight) {
          std::tie(g0, g1) = load2f(weight + d);
        }
        fVec b0(0), b1(0);
        if (has_shift) {
          std::tie(b0, b1) = load2f(shift + d);
        }
        fVec y0 = (x0 * fVec(scale) + fVec(offset)) * g0 + b0;
        fVec y1 = (x1 * fVec(scale) + fVec(offset)) * g1 + b1;
        convert_from_float<T>(y0, y1).store(y_row + d);
      }
      // Scalar tail for the remaining N % bVec::size() elements.
      for (int64_t d = vec_end; d < N; ++d) {
        const float g = has_weight ? float(weight[d]) : float(1);
        const float b = has_shift ? float(shift[d]) : float(0);
        y_row[d] = (float(x_row[d]) * scale + offset) * g + b;
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free