quantize_tensor_per_channel_impl Function — PyTorch Architecture
Architecture documentation for the quantize_tensor_per_channel_impl function template in QuantizedOpKernels.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp lines 3871–3920
template <typename T>
void quantize_tensor_per_channel_impl(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  // TODO: channels last kernel can be made faster.
  // For contiguous tensors, e.g. NCHW, arbitrary axis can be used.
  // For channels_last/3d however axis == 0 or 1.
  // Since current implementation on channels_last format does not
  // cover per channel quant with arbitrary axis value, it is better
  // to check and fail.
  int64_t batches = size_to_dim_(axis, rtensor.sizes());
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
  int64_t channels = rtensor.size(axis);
  auto scales_data = scales.data_ptr<double>();
  auto zero_points_data = zero_points.data_ptr<int64_t>();
  const float* in = rtensor.const_data_ptr<float>();
  auto out = qtensor.data_ptr<T>();
  if (axis == 1 &&
      (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
       rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
    // This code handles per channel quant when axis = 1 and
    // channels_last contig.
    // If axis = 0 and channels_last contig, implementation for channels
    // first (NCHW) works.
    for (const auto b : c10::irange(batches)) {
      for (const auto e : c10::irange(elements_per_channel)) {
        for (const auto c : c10::irange(channels)) {
          auto i = b * channels * elements_per_channel + e * channels + c;
          out[i] = at::native::quantize_val<T>(
              scales_data[c], zero_points_data[c], in[i]);
        }
      }
    }
  } else {
    for (const auto b : c10::irange(batches)) {
      for (const auto c : c10::irange(channels)) {
        for (const auto e : c10::irange(elements_per_channel)) {
          auto i = b * channels * elements_per_channel +
              c * elements_per_channel + e;
          out[i] = at::native::quantize_val<T>(
              scales_data[c], zero_points_data[c], in[i]);
        }
      }
    }
  }
}
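The two branches above differ only in loop order and index arithmetic: for NCHW-contiguous data the flat index is (b * C + c) * E + e, while for channels_last it is (b * E + e) * C + c, so that in both cases the innermost loop walks memory sequentially. The sketch below illustrates this outside of ATen, assuming an affine quantization rule q = clamp(round(x / scale) + zero_point, 0, 255) for uint8; quantize_val_u8 and the two loop functions are hypothetical stand-ins for at::native::quantize_val<T> and the kernel branches, not the PyTorch API.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical standalone analogue of at::native::quantize_val<T> for uint8:
// affine quantization with clamping to the uint8 range.
static uint8_t quantize_val_u8(double scale, int64_t zero_point, float value) {
  int64_t q = static_cast<int64_t>(std::nearbyint(value / scale)) + zero_point;
  return static_cast<uint8_t>(std::min<int64_t>(255, std::max<int64_t>(0, q)));
}

// NCHW-contiguous branch: channel is the middle stride,
// so index = (b * C + c) * E + e.
void quantize_per_channel_nchw(
    const float* in, uint8_t* out,
    int64_t batches, int64_t channels, int64_t elements_per_channel,
    const double* scales, const int64_t* zero_points) {
  for (int64_t b = 0; b < batches; ++b)
    for (int64_t c = 0; c < channels; ++c)
      for (int64_t e = 0; e < elements_per_channel; ++e) {
        int64_t i = (b * channels + c) * elements_per_channel + e;
        out[i] = quantize_val_u8(scales[c], zero_points[c], in[i]);
      }
}

// channels_last branch: channel is the innermost stride,
// so index = (b * E + e) * C + c.
void quantize_per_channel_channels_last(
    const float* in, uint8_t* out,
    int64_t batches, int64_t channels, int64_t elements_per_channel,
    const double* scales, const int64_t* zero_points) {
  for (int64_t b = 0; b < batches; ++b)
    for (int64_t e = 0; e < elements_per_channel; ++e)
      for (int64_t c = 0; c < channels; ++c) {
        int64_t i = (b * elements_per_channel + e) * channels + c;
        out[i] = quantize_val_u8(scales[c], zero_points[c], in[i]);
      }
}
```

Both functions visit every element exactly once; choosing the loop nest that matches the memory layout keeps the innermost loop stride-1, which is the point of the kernel's branch on is_contiguous(MemoryFormat::ChannelsLast).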