batch_norm_cpu_transform_input_template Function — pytorch Architecture
Architecture documentation for the batch_norm_cpu_transform_input_template function template (not a class) in Normalization.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/Normalization.cpp lines 134–197
template<typename scalar_t, typename param_t>
static std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
const Tensor& input, const Tensor& weight, const Tensor& bias,
const Tensor& save_mean /* optional */, const Tensor& save_invstd /* optional */,
const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
bool train, double eps, Tensor& output) {
bool all_contiguous = is_contiguous(input)
&& is_contiguous(output)
&& (!weight.defined() || weight.is_contiguous())
&& (!bias.defined() || bias.is_contiguous())
&& running_mean.is_contiguous()
&& running_var.is_contiguous();
// inference contiguous path
if (all_contiguous) {
if (input.numel() != 0) {
batch_norm_cpu_stub(kCPU, output, input, weight, bias,
save_mean, save_invstd, running_mean, running_var, train, eps);
}
return std::make_tuple(output, save_mean, save_invstd);
}
const int64_t ndim = input.dim();
// Helper to convert 1d tensors to an nd tensor that broadcasts with input
// All elements go into the channel dimension
DimVector sizes(ndim, 1), strides(ndim, 0);
auto as_nd = [&](const Tensor& t) {
TORCH_INTERNAL_ASSERT(t.defined() && t.dim() == 1);
sizes[1] = t.sizes()[0];
strides[1] = t.strides()[0];
return t.as_strided(sizes, strides);
};
auto mean = as_nd(train ? save_mean : running_mean);
auto invstd = as_nd([&]{
if (train) {
return save_invstd;
} else {
return 1 / at::sqrt(running_var + eps);
}
}());
constexpr bool mixed_type = !std::is_same_v<scalar_t, param_t>;
const auto dtype = mixed_type ? kFloat : input.scalar_type();
auto w = weight.defined() ? as_nd(weight) :
at::detail::scalar_tensor_static(1, dtype, kCPU);
auto b = bias.defined() ? as_nd(bias) :
at::detail::scalar_tensor_static(0, dtype, kCPU);
auto iter = TensorIteratorConfig()
.add_output(output)
.add_input(input)
.add_input(mean)
.add_input(invstd)
.add_input(w)
.add_input(b)
.check_all_same_dtype(false)
.promote_inputs_to_common_dtype(false)
.build();
cpu_kernel(iter, [=](scalar_t input, param_t mean, param_t invstd, param_t weight, param_t bias) -> scalar_t {
return ((input - mean) * invstd) * weight + bias;
});
return std::make_tuple(output, save_mean, save_invstd);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free