batch_norm_backward_cpu_template Function — PyTorch Architecture
Architecture documentation for the batch_norm_backward_cpu_template function template in Normalization.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/Normalization.cpp lines 308–495
template<typename scalar_t, typename param_t>
static std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
    const Tensor& grad_out_, const Tensor& input, const Tensor& weight,
    const Tensor& running_mean, const Tensor& running_var, const Tensor& save_mean, const Tensor& save_invstd,
    bool train, double eps, std::array<bool,3> grad_input_mask) {

  using accscalar_t = at::acc_type<scalar_t, false>;

  constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
  const auto dtype = mixed_type ? kFloat : input.scalar_type();

  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;
  if (grad_input_mask[0]) {
    grad_input = at::empty_like(input, input.suggest_memory_format());
  }
  if (grad_input_mask[1]) {
    grad_weight = at::empty({input.size(1)}, input.options().dtype(dtype));
  }
  if (grad_input_mask[2]) {
    grad_bias = at::empty({input.size(1)}, input.options().dtype(dtype));
  }

  // since we are directly manipulating pointers in contiguous path,
  // need to make sure input and grad_out have the same memory format.
  bool all_contiguous = is_contiguous(input)
      && is_contiguous(grad_out_)
      && input.suggest_memory_format() == grad_out_.suggest_memory_format();

  if (all_contiguous) {
    if (grad_input_mask[0]) {
      grad_input = at::empty_like(input, suggest_memory_format_contig(input));
    }
    batch_norm_cpu_backward_stub(kCPU, grad_input, grad_weight, grad_bias,
        grad_out_, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps);
    return std::make_tuple(grad_input, grad_weight, grad_bias);
  }
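  // Non-contiguous fallback: from here on, the per-channel reductions and
  // gradients are computed directly over strided data, channel by channel.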
  auto weight_a = conditional_accessor_1d<const param_t>(weight);
  auto grad_weight_a = conditional_accessor_1d<param_t>(grad_weight);
  auto grad_bias_a = conditional_accessor_1d<param_t>(grad_bias);

  int64_t n_input = input.size(1);
  int64_t n = input.numel() / n_input;

  auto save_mean_a = conditional_accessor_1d<const param_t>(save_mean);
  auto save_invstd_a = conditional_accessor_1d<const param_t>(save_invstd);

  auto running_mean_a = conditional_accessor_1d<const param_t>(running_mean);
  auto running_var_a = conditional_accessor_1d<const param_t>(running_var);

  const int64_t ndim = input.dim();

  // Reduce all dimensions except dim=1
  DimVector reduce_dims(ndim - 1);
  reduce_dims[0] = 0;
  for (const auto i : c10::irange(2, ndim)) {
    reduce_dims[i - 1] = i;
  }

  // Using float data type for Half sum to avoid overflow
  // since the representation range of Half is small.
  auto sum = grad_out_.scalar_type() == kHalf
      ? at::sum(grad_out_.to(ScalarType::Float), /*dim=*/reduce_dims)
      : at::sum(grad_out_, /*dim=*/reduce_dims);
  using sum_t = std::conditional_t<std::is_same_v<scalar_t, at::Half>, float, scalar_t>;
  auto sum_a = sum.accessor<sum_t, 1>();
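  // sum_a[f] is the per-channel sum of grad_out: it is grad_bias[f] directly,
  // and sum_a[f] / n is the grad_mean term used in the training-mode path.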
  auto reduce_iter = TensorIteratorConfig()
      .add_const_input(input)
      .add_const_input(grad_out_)
      .resize_outputs(false)
      .declare_static_shape(input.sizes(), /*squash_dims=*/1)
      .build();

  TensorIterator unary_iter;
  TensorIterator binary_iter;

  if (grad_input_mask[0]) {
    unary_iter.build(
        TensorIteratorConfig()
        .add_output(grad_input)
        .add_const_input(train ? input : grad_out_)
        .resize_outputs(false)
        .declare_static_shape(input.sizes(), /*squash_dims=*/1));

    if (train) {
      binary_iter.build(
          TensorIteratorConfig()
          .add_output(grad_input)
          .add_input(grad_input)
          .add_const_input(grad_out_)
          .resize_outputs(false)
          .declare_static_shape(input.sizes(), /*squash_dims=*/1));
    }
  }

  auto in_channel_stride = input.strides()[1];
  auto in_data = input.const_data_ptr<scalar_t>();
  auto grad_in_channel_stride = grad_input_mask[0] ? grad_input.strides()[1] : 0;
  auto grad_in_data = grad_input_mask[0] ? grad_input.mutable_data_ptr<scalar_t>() : nullptr;
  auto grad_out_channel_stride = grad_out_.strides()[1];
  auto grad_out_data = grad_out_.const_data_ptr<scalar_t>();
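  // The per-channel work below is parallelized over channels (grain size 1).
  // Each task copies the pre-built iterators and re-targets their operand
  // pointers at its channel's data via unsafe_replace_operand.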
  parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
    TensorIterator reduce_iter_local(reduce_iter);
    TensorIterator unary_iter_local(unary_iter);
    TensorIterator binary_iter_local(binary_iter);

    for (const auto f : c10::irange(b_begin, b_end)) {
      param_t w = weight.defined() ? weight_a[f] : param_t(1);

      param_t mean{}, invstd{};
      if (train) {
        mean = save_mean_a[f];
        invstd = save_invstd_a[f];
      } else {
        mean = running_mean_a[f];
        invstd = 1 / std::sqrt(running_var_a[f] + eps);
      }

      // dot product of the Q(X) and gradOutput
      accscalar_t dotp = 0;
      reduce_iter_local.unsafe_replace_operand(
          0, const_cast<scalar_t*>(in_data + f * in_channel_stride));
      reduce_iter_local.unsafe_replace_operand(
          1, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));
      cpu_serial_kernel(reduce_iter_local, [&](const scalar_t i, const scalar_t go) -> void {
        dotp += (i - mean) * go;
      });

      if (grad_input_mask[0]) {
        if (train) {
          // when in training mode
          // Q(X) = X - E[x] ; i.e. input centered to zero mean
          // Y = Q(X) / sigma ; i.e. BN output before weight and bias
          // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / sigma * w

          // projection of gradOutput on to output scaled by std
          scalar_t k = (scalar_t) dotp * invstd * invstd / n;
          {
            unary_iter_local.unsafe_replace_operand(
                0, grad_in_data + f * grad_in_channel_stride);
            unary_iter_local.unsafe_replace_operand(
                1, const_cast<scalar_t*>(in_data + f * in_channel_stride));
            cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
              return (i - mean) * k;
            });
          }

          scalar_t grad_mean = sum_a[f] / n;
          {
            auto gI_data = grad_in_data + f * grad_in_channel_stride;
            binary_iter_local.unsafe_replace_operand(0, gI_data);
            binary_iter_local.unsafe_replace_operand(1, gI_data);
            binary_iter_local.unsafe_replace_operand(
                2, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));
            cpu_serial_kernel(binary_iter_local, [&](scalar_t gi, scalar_t go) -> scalar_t {
              return (go - grad_mean - gi) * invstd * w;
            });
          }
        } else {
          // when in evaluation mode
          // Q(X) = X - running_mean ; i.e. input centered to zero mean
          // Y = Q(X) / running_std ; i.e. BN output before weight and bias
          // dL/dX = w / running_std
          {
            unary_iter_local.unsafe_replace_operand(
                0, grad_in_data + f * grad_in_channel_stride);
            unary_iter_local.unsafe_replace_operand(
                1, const_cast<scalar_t*>(grad_out_data + f * grad_out_channel_stride));
            cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
              return i * invstd * w;
            });
          }
        }
      }

      if (grad_input_mask[1]) {
        grad_weight_a[f] = dotp * invstd;
      }

      if (grad_input_mask[2]) {
        grad_bias_a[f] = sum_a[f];
      }
    }
  });

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
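For reference, the per-channel math implemented by the non-contiguous path can be restated compactly. For a channel f with n reduced elements, the function computes grad_bias[f] as the sum of grad_out over the channel, grad_weight[f] as invstd * dot(x - mean, grad_out), and, in training mode, grad_input as (grad_out - sum/n - (x - mean) * k) * invstd * w with k = dotp * invstd^2 / n, produced by the two serial passes over unary_iter and binary_iter. The sketch below is a minimal standalone restatement of the training-mode formulas over plain vectors; it is illustrative only, does not use ATen, and every name in it (ChannelGrads, batch_norm_backward_channel) is invented for the sketch.

#include <cstddef>
#include <vector>

// Standalone illustration (not ATen code) of the per-channel training-mode
// math implemented above for a single channel f. All names are local to this
// sketch; n is the number of elements reduced per channel (numel / C).
struct ChannelGrads {
  std::vector<double> grad_input;
  double grad_weight;
  double grad_bias;
};

ChannelGrads batch_norm_backward_channel(
    const std::vector<double>& x,         // input values of channel f
    const std::vector<double>& grad_out,  // dL/dY for channel f
    double mean,                          // save_mean[f]
    double invstd,                        // save_invstd[f] = 1 / sqrt(var + eps)
    double w) {                           // weight[f] (1 if weight is undefined)
  const std::size_t n = x.size();

  // sum mirrors sum_a[f] (per-channel sum of grad_out, i.e. grad_bias);
  // dotp mirrors the reduce_iter pass: dot(x - mean, grad_out).
  double sum = 0.0, dotp = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    sum  += grad_out[i];
    dotp += (x[i] - mean) * grad_out[i];
  }

  // k mirrors `k = dotp * invstd * invstd / n`; grad_mean mirrors `sum_a[f] / n`.
  const double k = dotp * invstd * invstd / static_cast<double>(n);
  const double grad_mean = sum / static_cast<double>(n);

  ChannelGrads out;
  out.grad_input.resize(n);
  for (std::size_t i = 0; i < n; ++i) {
    // First pass (unary_iter):   gi = (x - mean) * k
    // Second pass (binary_iter): gi = (grad_out - grad_mean - gi) * invstd * w
    const double gi = (x[i] - mean) * k;
    out.grad_input[i] = (grad_out[i] - grad_mean - gi) * invstd * w;
  }
  out.grad_weight = dotp * invstd;  // grad_weight_a[f]
  out.grad_bias = sum;              // grad_bias_a[f]
  return out;
}

Evaluation mode is simpler: grad_input becomes grad_out * invstd * w with invstd = 1 / sqrt(running_var + eps), while grad_weight and grad_bias use the same dotp and sum, computed against running_mean instead of the saved batch mean.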