
batch_norm_backward_cpu_template Function — pytorch Architecture

Architecture documentation for the batch_norm_backward_cpu_template function template in Normalization.cpp from the pytorch codebase. It implements the CPU backward pass of batch normalization, computing gradients for the input, the affine weight, and the bias.

Entity Profile

Source Code

aten/src/ATen/native/Normalization.cpp lines 308–495

template<typename scalar_t, typename param_t>
static std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
    const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
    const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
    bool train, double eps, std::array<bool,3> grad_input_mask) {
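  // Computes grad_input, grad_weight and grad_bias for batch norm on CPU;
  // grad_input_mask selects which of the three outputs are materialized.
  // accscalar_t is a wider accumulation type used for the per-channel
  // dot-product reduction below.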

  using accscalar_t = at::acc_type<scalar_t, false>;

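  // Mixed-type case: input in reduced precision (Half/BFloat16) with float
  // parameters and stats; grad_weight/grad_bias are then produced in float.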
  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();

  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;
  if (grad_input_mask[0]) {
    grad_input = at::empty_like(input, input.suggest_memory_format());
  }
  if (grad_input_mask[1]) {
    grad_weight = at::empty({input.size(1)}, input.options().dtype(dtype));
  }
  if (grad_input_mask[2]) {
    grad_bias = at::empty({input.size(1)}, input.options().dtype(dtype));
  }

  // since we are directly manipulating pointers in contiguous path,
  // need to make sure input and grad_out have the same memory format.
  bool all_contiguous = is_contiguous(input)
      && is_contiguous(grad_out_)
      && input.suggest_memory_format() == grad_out_.suggest_memory_format();

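  // Fast path: hand the whole backward to the vectorized CPU kernel behind
  // batch_norm_cpu_backward_stub.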
  if (all_contiguous) {
    if (grad_input_mask[0]) {
      grad_input = at::empty_like(input, suggest_memory_format_contig(input));
    }
    batch_norm_cpu_backward_stub(kCPU, grad_input, grad_weight, grad_bias,
        grad_out_, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
    return std::make_tuple(grad_input, grad_weight, grad_bias);
  }

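  // Generic strided path: per-channel reduction plus elementwise kernels
  // driven by TensorIterator. The 1-d accessors below tolerate undefined
  // tensors (weight, for instance, is optional).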
  auto weight_a = conditional_accessor_1d<const param_t>(weight);
  auto grad_weight_a = conditional_accessor_1d<param_t>(grad_weight);
  auto grad_bias_a = conditional_accessor_1d<param_t>(grad_bias);

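  // n_input: number of channels (dim 1); n: elements reduced per channel,
  // i.e. batch size times spatial positions.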
  int64_t n_input = input.size(1);
  int64_t n = input.numel() / n_input;

  auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
  auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);

  auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<const param_t>(running_var);

  const int64_t ndim = input.dim();

  // Reduce all dimensions except dim=1
  DimVector reduce_dims(ndim - 1);
  reduce_dims[0] = 0;
  for (const auto i : c10::irange(2, ndim)) {
    reduce_dims[i - 1] = i;
  }
  // Using float data type for Half sum to avoid overflow
  // since the representation range of Half is small.
  auto sum = grad_out_.scalar_type() == kHalf
      ? at::sum(grad_out_.to(ScalarType::Float), /*dim=*/reduce_dims)
      : at::sum(grad_out_, /*dim=*/reduce_dims);
  using sum_t = std::conditional_t<std::is_same_v<scalar_t, at::Half>, float, scalar_t>;
  auto sum_a = sum.accessor<sum_t, 1>();

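  // The iterators below are built once against the full (channel-squashed)
  // shape; the parallel loop takes thread-local copies and re-targets their
  // base pointers to channel f via unsafe_replace_operand, avoiding a
  // rebuild per channel.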
  auto reduce_iter = TensorIteratorConfig()
      .add_const_input(input)
      .add_const_input(grad_out_)
      .resize_outputs(false)
      .declare_static_shape(input.sizes(), /*squash_dims=*/1)
      .build();

  TensorIterator unary_iter;
  TensorIterator binary_iter;
  if (grad_input_mask[0]) {
    unary_iter.build(
        TensorIteratorConfig()
        .add_output(grad_input)
        .add_const_input(train ? input : grad_out_)
        .resize_outputs(false)
        .declare_static_shape(input.sizes(), /*squash_dims=*/1));

    if (train) {
      binary_iter.build(
          TensorIteratorConfig()
          .add_output(grad_input)
          .add_input(grad_input)
          .add_const_input(grad_out_)
          .resize_outputs(false)
          .declare_static_shape(input.sizes(), /*squash_dims=*/1));
    }
  }

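  // Raw base pointers plus per-channel strides; offsetting by f * stride
  // selects the slice for channel f inside the parallel loop.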
  auto in_channel_stride = input.strides()[1];
  auto in_data = input.const_data_ptr<scalar_t>();
  auto grad_in_channel_stride = grad_input_mask[0] ? grad_input.strides()[1] : 0;
  auto grad_in_data = grad_input_mask[0] ? grad_input.mutable_data_ptr<scalar_t>() : nullptr;
  auto grad_out_channel_stride = grad_out_.strides()[1];
  auto grad_out_data = grad_out_.const_data_ptr<scalar_t>();

  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
      TensorIterator reduce_iter_local(reduce_iter);
      TensorIterator unary_iter_local(unary_iter);
      TensorIterator binary_iter_local(binary_iter);

      for (const auto f : c10::irange(b_begin, b_end)) {
        param_t w = weight.defined() ? weight_a[f] : param_t(1);

        param_t mean{}, invstd{};
        if (train) {
          mean = save_mean_a[f];
          invstd = save_invstd_a[f];
        } else {
          mean = running_mean_a[f];
          invstd = 1 / std::sqrt(running_var_a[f] + eps);
        }

        // dot product of the Q(X) and gradOutput
        accscalar_t dotp = 0;
        reduce_iter_local.unsafe_replace_operand(
            0, const_cast<scalar_t*>(in_data + f * in_channel_stride));
        reduce_iter_local.unsafe_replace_operand(
            1, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));

        cpu_serial_kernel(reduce_iter_local, [&](const scalar_t i, const scalar_t go) -> void {
          dotp += (i - mean) * go;
        });

        if (grad_input_mask[0]) {
          if (train) {
            // when in training mode
            // Q(X) = X - E[x] ; i.e. input centered to zero mean
            // Y = Q(X) / sigma    ; i.e. BN output before weight and bias
            // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / sigma * w

            // projection of gradOutput on to output scaled by std
            scalar_t k = (scalar_t) dotp * invstd * invstd / n;
            {
              unary_iter_local.unsafe_replace_operand(
                  0, grad_in_data + f * grad_in_channel_stride);
              unary_iter_local.unsafe_replace_operand(
                  1, const_cast<scalar_t*>(in_data + f * in_channel_stride));
              cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
                return (i - mean) * k;
              });
            }
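            // grad_input now holds dot(Y, dL/dY) * Y / n; the binary kernel
            // below subtracts it and the mean of gradOutput, then scales by
            // invstd and w.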

            scalar_t grad_mean = sum_a[f] / n;
            {
              auto gI_data = grad_in_data + f * grad_in_channel_stride;
              binary_iter_local.unsafe_replace_operand(0, gI_data);
              binary_iter_local.unsafe_replace_operand(1, gI_data);
              binary_iter_local.unsafe_replace_operand(
                  2, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));
              cpu_serial_kernel(binary_iter_local, [&](scalar_t gi, scalar_t go) -> scalar_t {
                return (go - grad_mean - gi) * invstd * w;
              });
            }
          } else {
            // when in evaluation mode
            // Q(X) = X - running_mean  ; i.e. input centered to zero mean
            // Y = Q(X) / running_std    ; i.e. BN output before weight and bias
            // dL/dX = w / running_std
            {
              unary_iter_local.unsafe_replace_operand(
                  0, grad_in_data + f * grad_in_channel_stride);
              unary_iter_local.unsafe_replace_operand(
                  1, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));
              cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
                return i * invstd * w;
              });
            }
          }
        }
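        // dL/dw = dot(Y, dL/dY) = dotp * invstd; dL/db = sum of gradOutput
        // over the reduced dims (already computed per channel in sum_a).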
        if (grad_input_mask[1]) {
          grad_weight_a[f] = dotp * invstd;
        }

        if (grad_input_mask[2]) {
          grad_bias_a[f] = sum_a[f];
        }
      }
    });
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
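
Usage Sketch

This template is not called directly: a dispatching wrapper selects scalar_t from the input dtype and param_t from the parameter dtype, then instantiates the template. Below is a minimal, hypothetical sketch of such a dispatch, assuming a wrapper shaped like PyTorch's batch_norm_backward_cpu; the function name batch_norm_backward_cpu_sketch and the simplified mixed-type check are illustrative, not the exact source.

// Illustrative sketch only: shows how batch_norm_backward_cpu_template is
// typically instantiated per runtime dtype. The mixed-type check here is a
// simplified stand-in for the helper used in the real file.
static std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_sketch(
    const Tensor& grad_out, const Tensor& self, const Tensor& weight,
    const Tensor& running_mean, const Tensor& running_var,
    const Tensor& save_mean, const Tensor& save_invstd,
    bool train, double eps, std::array<bool, 3> grad_input_mask) {
  return AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      self.scalar_type(), "batch_norm_backward_cpu", [&] {
        // Mixed precision: parameters and stats kept in float while the
        // input is Half/BFloat16.
        const bool mixed = weight.defined()
            && weight.scalar_type() == kFloat
            && self.scalar_type() != kFloat;
        if (mixed) {
          // scalar_t = input dtype (bound by the dispatch macro),
          // param_t = float.
          return batch_norm_backward_cpu_template<scalar_t, float>(
              grad_out, self, weight, running_mean, running_var,
              save_mean, save_invstd, train, eps, grad_input_mask);
        }
        // Uniform precision: param_t matches scalar_t.
        return batch_norm_backward_cpu_template<scalar_t, scalar_t>(
            grad_out, self, weight, running_mean, running_var,
            save_mean, save_invstd, train, eps, grad_input_mask);
      });
}

Note that the contiguous fast path at the top of the template means such a dispatch usually lands in batch_norm_cpu_backward_stub for standard contiguous or channels-last tensors; the per-channel TensorIterator path is the fallback for arbitrary strides.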
