dequantize_per_channel_affine_kernel Function Template — pytorch Architecture
Architecture documentation for the dequantize_per_channel_affine_kernel function template in QuantizedOpKernels.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp lines 4135–4198
template<typename T, typename N, typename Q>
void dequantize_per_channel_affine_kernel(
const Tensor& qtensor,
Tensor& rtensor,
const Tensor& scales,
const Tensor& zero_points,
int64_t axis,
int bit_width=8) {
// For contiguous tensors, e.g. NCHW, arbitrary axis can be used.
// For channels_last/3d however axis == 0 or 1.
// Since current implementation on channels_last format does not
// cover per channel quant with arbitrary axis value, it is better
// to check and fail.
TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
size_from_dim_(axis + 1, rtensor.sizes());
int64_t channel = rtensor.size(axis);
auto scales_data = scales.data_ptr<T>();
auto zero_points_data = zero_points.data_ptr<N>();
check_tensor_memory_format(qtensor, rtensor);
const auto* qd = qtensor.const_data_ptr<Q>();
float* rd = rtensor.data_ptr<float>();
const auto elem_per_byte = 8 / bit_width;
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
for (const auto b : c10::irange(batches)) {
for (const auto e : c10::irange(elements_per_channel)) {
for (const auto c : c10::irange(channel)) {
auto i = b * channel * elements_per_channel + e * channel + c;
// We need to convert the qint8 value to float to ensure the
// subtraction subexpression returns a float
auto qvalue = qd[i / elem_per_byte].val_;
if (bit_width < 8) {
qvalue >>= (i % elem_per_byte) * bit_width;
qvalue &= (1 << bit_width) - 1;
}
rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
}
}
}
} else {
for (const auto b : c10::irange(batches)) {
for (const auto c : c10::irange(channel)) {
for (const auto e : c10::irange(elements_per_channel)) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
// We need to convert the qint8 value to float to ensure the
// subtraction subexpression returns a float
auto qvalue = qd[i / elem_per_byte].val_;
if (bit_width < 8) {
qvalue >>= (i % elem_per_byte) * bit_width;
qvalue &= (1 << bit_width) - 1;
}
rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
}
}
}
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free