Home / Class/ batch_norm_cpu_update_stats_template Class — pytorch Architecture

batch_norm_cpu_update_stats_template Class — pytorch Architecture

Architecture documentation for batch_norm_cpu_update_stats_template, a static function template (listed on this site under "Class") defined in Normalization.cpp in the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/Normalization.cpp lines 199–287

template<typename scalar_t, typename param_t, template<typename T> class VarTransform>
static std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
    const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
    double momentum, double eps, Tensor& save_mean, Tensor& save_var_transform) {

  using accscalar_t = at::acc_type<scalar_t, false>;

  int64_t n_input = input.size(1);
  TORCH_CHECK(input.numel() != 0, "input tensor must have at least one element, but got input_sizes = ", input.sizes());
  int64_t n = input.numel() / n_input;

  bool all_contiguous = is_contiguous(input);
  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  // Using float data type for Half _var_sum in batchnorm stats updating on CPU
  // to avoid _var_sum overflow since the representation range of Half is small.
  using opmath_t = std::conditional_t<std::is_same_v<param_t, at::Half>, at::opmath_type<param_t>, param_t>;
  auto dtype = mixed_type ? kFloat : input.scalar_type();
  if (dtype == kHalf) {
    dtype = kFloat;
  }

  auto save_mean_a = save_mean.accessor<param_t, 1>();
  auto save_var_transform_a = save_var_transform.accessor<param_t, 1>();

  auto running_mean_a = conditional_accessor_1d<param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<param_t>(running_var);

  if (all_contiguous) {
    auto _mean = at::empty({n_input}, input.options().dtype(dtype));
    auto _var_sum = at::empty({n_input}, input.options().dtype(dtype));
    auto _mean_a = _mean.accessor<opmath_t, 1>();
    auto _var_sum_a = _var_sum.accessor<opmath_t, 1>();
    auto momentum_ = static_cast<opmath_t>(momentum);

    batch_norm_cpu_collect_stats_stub(kCPU, _mean, _var_sum, input);

    parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
      for (const auto f : c10::irange(b_begin, b_end)) {
        save_mean_a[f] = _mean_a[f];
        save_var_transform_a[f] = VarTransform<accscalar_t>{}(_var_sum_a[f] / n, eps);

        if (running_mean.defined()) {
          running_mean_a[f] = momentum_ * _mean_a[f] + (1 - momentum_) * running_mean_a[f];
        }
        if (running_var.defined()) {
          accscalar_t unbiased_var = _var_sum_a[f] / (n - 1);
          running_var_a[f] = momentum_ * unbiased_var + (1 - momentum_) * running_var_a[f];
        }
      }
    });

    return std::make_tuple(save_mean, save_var_transform);
  }

  // non-contiguous path
  auto channel_stride = input.strides()[1];
  auto in_data = input.data_ptr<scalar_t>();
  auto reduce_iter = TensorIteratorConfig()
      .add_input(input)
      .resize_outputs(false)
      .declare_static_shape(input.sizes(), /*squash_dims=*/1)
      .check_all_same_dtype(false)
      .promote_inputs_to_common_dtype(false)
      .build();

  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
    TensorIterator iter(reduce_iter);
    for (const auto f : c10::irange(b_begin, b_end)) {
      // compute variance per input
      iter.unsafe_replace_operand(0, in_data + channel_stride * f);
      accscalar_t var_sum = 0;
      auto mean = static_cast<accscalar_t>(save_mean_a[f]);
      cpu_serial_kernel(iter, [&](const scalar_t i) -> void {
        var_sum += (i - mean) * (i - mean);
      });
      save_var_transform_a[f] = VarTransform<accscalar_t>{}(var_sum / n, eps);

      // update running averages
      if (running_mean.defined()) {
        running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
      }
      if (running_var.defined()) {
        accscalar_t unbiased_var = var_sum / (n - 1);
        running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
      }
    }
  });
  return std::make_tuple(save_mean, save_var_transform);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free