Home / Class / dequantize_per_channel_affine_kernel Class — pytorch Architecture

dequantize_per_channel_affine_kernel Class — pytorch Architecture

Architecture documentation for the dequantize_per_channel_affine_kernel class in QuantizedOpKernels.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp lines 4135–4198

template<typename T, typename N, typename Q>
void dequantize_per_channel_affine_kernel(
      const Tensor& qtensor,
      Tensor& rtensor,
      const Tensor& scales,
      const Tensor& zero_points,
      int64_t axis,
      int bit_width=8) {

  // For contiguous tensors, e.g. NCHW, arbitrary axis can be used.
  // For channels_last/3d however axis == 0 or 1.
  // Since current implementation on channels_last format does not
  // cover per channel quant with arbitrary axis value, it is better
  // to check and fail.
  TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
      "If tensor is channels_last contig then per channel quantization "
      "is supported only for axis = 0 or 1.");
  int64_t batches = size_to_dim_(axis, rtensor.sizes());
  int64_t elements_per_channel =
      // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
      size_from_dim_(axis + 1, rtensor.sizes());
  int64_t channel = rtensor.size(axis);
  auto scales_data = scales.data_ptr<T>();
  auto zero_points_data = zero_points.data_ptr<N>();
  check_tensor_memory_format(qtensor, rtensor);
  const auto* qd = qtensor.const_data_ptr<Q>();
  float* rd = rtensor.data_ptr<float>();
  const auto elem_per_byte = 8 / bit_width;
  if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
      rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
    for (const auto b : c10::irange(batches)) {
      for (const auto e : c10::irange(elements_per_channel)) {
        for (const auto c : c10::irange(channel)) {
          auto i = b * channel * elements_per_channel + e * channel + c;
          // We need to convert the qint8 value to float to ensure the
          // subtraction subexpression returns a float
          auto qvalue = qd[i / elem_per_byte].val_;
          if (bit_width < 8) {
            qvalue >>= (i % elem_per_byte) * bit_width;
            qvalue &= (1 << bit_width) - 1;
          }
          rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
        }
      }
    }
  } else {
    for (const auto b : c10::irange(batches)) {
      for (const auto c : c10::irange(channel)) {
        for (const auto e : c10::irange(elements_per_channel)) {
          auto i = b * channel * elements_per_channel +
              c * elements_per_channel + e;
          // We need to convert the qint8 value to float to ensure the
          // subtraction subexpression returns a float
          auto qvalue = qd[i / elem_per_byte].val_;
          if (bit_width < 8) {
            qvalue >>= (i % elem_per_byte) * bit_width;
            qvalue &= (1 << bit_width) - 1;
          }
          rd[i] = (static_cast<float>(qvalue) - zero_points_data[c]) * scales_data[c];
        }
      }
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free