quantize_tensor_per_channel_impl Class — pytorch Architecture

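Architecture documentation for quantize_tensor_per_channel_impl, a function template defined in QuantizedOpKernels.cpp in the PyTorch codebase. It quantizes a float tensor with a separate scale and zero point for each channel along a given axis.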

Source Code

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp lines 3871–3920

template <typename T>
void quantize_tensor_per_channel_impl(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis) {
  // TODO: channels last kernel can be made faster.
  // For contiguous tensors, e.g. NCHW, arbitrary axis can be used.
  // For channels_last/3d however axis == 0 or 1.
  // Since current implementation on channels_last format does not
  // cover per channel quant with arbitrary axis value, it is better
  // to check and fail.
  int64_t batches = size_to_dim_(axis, rtensor.sizes());
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
  int64_t channels = rtensor.size(axis);
  auto scales_data = scales.data_ptr<double>();
  auto zero_points_data = zero_points.data_ptr<int64_t>();
  const float* in = rtensor.const_data_ptr<float>();
  auto out = qtensor.data_ptr<T>();
  if (axis == 1 &&
      (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
       rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
    // This code handles per channel quant when axis = 1 and
    // channels_last contig.
    // If axis = 0 and channels_last contig, implementation for channels
    // first (NCHW) works.
    for (const auto b : c10::irange(batches)) {
      for (const auto e : c10::irange(elements_per_channel)) {
        for (const auto c : c10::irange(channels)) {
          auto i = b * channels * elements_per_channel + e * channels + c;
          out[i] = at::native::quantize_val<T>(
              scales_data[c], zero_points_data[c], in[i]);
        }
      }
    }
  } else {
    for (const auto b : c10::irange(batches)) {
      for (const auto c : c10::irange(channels)) {
        for (const auto e : c10::irange(elements_per_channel)) {
          auto i = b * channels * elements_per_channel +
              c * elements_per_channel + e;
          out[i] = at::native::quantize_val<T>(
              scales_data[c], zero_points_data[c], in[i]);
        }
      }
    }
  }
}
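
The two loop nests above differ only in how the logical index (b, c, e) maps to a flat offset. The following is a minimal standalone sketch of that arithmetic, written without ATen so it compiles on its own. The names quantize_val_sketch and quantize_per_channel_sketch are hypothetical, and the rounding and clamping rule used here (round to nearest, add the zero point, clamp to the range of T) is an assumed stand-in for at::native::quantize_val, not its actual implementation. The batches and elements_per_channel arguments play the role of the products of the sizes before and after the axis that the original computes with size_to_dim_ and size_from_dim_.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for at::native::quantize_val (assumed formula):
// q = clamp(round(value / scale) + zero_point, min(T), max(T)).
template <typename T>
T quantize_val_sketch(double scale, int64_t zero_point, float value) {
  const int64_t qmin = std::numeric_limits<T>::min();
  const int64_t qmax = std::numeric_limits<T>::max();
  const int64_t q =
      static_cast<int64_t>(std::nearbyint(value / scale)) + zero_point;
  return static_cast<T>(std::min(std::max(q, qmin), qmax));
}

// Quantizes a flat buffer viewed as (batches, channels, elements_per_channel).
// channels_last == false: channel-major offset, as in the else branch.
// channels_last == true:  channels innermost, as in the axis == 1 branch.
template <typename T>
std::vector<T> quantize_per_channel_sketch(
    const std::vector<float>& in,
    const std::vector<double>& scales,
    const std::vector<int64_t>& zero_points,
    int64_t batches,
    int64_t elements_per_channel,
    bool channels_last) {
  const int64_t channels = static_cast<int64_t>(scales.size());
  assert(zero_points.size() == scales.size());
  assert(static_cast<int64_t>(in.size()) ==
         batches * channels * elements_per_channel);

  std::vector<T> out(in.size());
  for (int64_t b = 0; b < batches; ++b) {
    for (int64_t c = 0; c < channels; ++c) {
      for (int64_t e = 0; e < elements_per_channel; ++e) {
        const int64_t base = b * channels * elements_per_channel;
        const int64_t i = channels_last
            ? base + e * channels + c               // channels vary fastest
            : base + c * elements_per_channel + e;  // elements vary fastest
        out[i] = quantize_val_sketch<T>(scales[c], zero_points[c], in[i]);
      }
    }
  }
  return out;
}
```

With one batch, two channels and three elements per channel, the same logical values give the same quantized values in either layout; only their positions in the flat buffer differ. The sketch iterates in the same order for both layouts for brevity, whereas the original swaps the two inner loops in the channels_last branch so that memory is still walked sequentially.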
