Home / Class / ReLUFused Class — pytorch Architecture

ReLUFused Class — pytorch Architecture

Architecture documentation for the ReLUFused entity (the boolean template parameter of the qcat_nhwc_kernel function template) in QuantizedOpKernels.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp lines 58–219

/// Concatenates quantized tensors along the channel dimension (dim 1) into a
/// newly allocated ChannelsLast quantized tensor.
///
/// Each input is dequantized with its own scale / zero point and requantized
/// with the requested output `scale` / `zero_point`. When an input already
/// carries exactly the output quantization parameters and ReLU is not fused,
/// its channels are copied verbatim instead.
///
/// @tparam ReLUFused  When true, dequantized values are clamped to be
///                    non-negative before being requantized.
/// @param qxs         Input tensors. Must be non-empty; every tensor must have
///                    the same number of dimensions, the same sizes in
///                    dimensions 0, 2 and 3, and the same scalar type as
///                    `qxs[0]`.
/// @param dim         Not read by this kernel; concatenation is always done
///                    over dimension 1.
/// @param scale       Quantization scale of the output tensor.
/// @param zero_point  Quantization zero point of the output tensor.
/// @return            Quantized tensor of shape {N, sum of input channels, H, W}.
///
/// NOTE(review): the pointer arithmetic below (`i * curr_C` for inputs,
/// `i * C_out` for the output) assumes every input is densely stored with
/// channels innermost (NHWC) and is 4-dimensional. Neither is checked here —
/// presumably the caller guarantees it; confirm.
template <bool ReLUFused = false>
Tensor qcat_nhwc_kernel(
    const MaterializedITensorListRef& qxs,
    int64_t dim,
    double scale,
    int64_t zero_point) {
  // qxs[0] is used as the reference tensor below; indexing an empty list
  // would be undefined behaviour.
  TORCH_CHECK(
      !qxs.empty(), "qcat_nhwc_kernel: expected a non-empty list of Tensors");
  const at::Tensor& qx0 = qxs[0];
  int64_t C_out = 0;
  std::vector<int64_t> Cs_in;
  // Prefix sum of input channels for fast indexing
  std::vector<int64_t> Cs_sum;
  std::vector<double> scales;
  std::vector<int64_t> zero_pts;
  std::vector<void*> data_ptrs;
  std::vector<bool> is_fast_path;

  // One entry per input tensor: allocate once up front.
  const auto num_inputs = qxs.size();
  Cs_in.reserve(num_inputs);
  Cs_sum.reserve(num_inputs);
  scales.reserve(num_inputs);
  zero_pts.reserve(num_inputs);
  data_ptrs.reserve(num_inputs);
  is_fast_path.reserve(num_inputs);

  for (const at::Tensor& qx : qxs) {
    TORCH_CHECK(
        qx.dim() == qx0.dim(),
        "Tensors must have the same number of dimensions: got ",
        qx.dim(),
        " and ",
        qx0.dim());
    // Every dimension other than the concatenation dimension (1) must agree
    // with the first tensor.
    for (const int64_t d : {int64_t{0}, int64_t{2}, int64_t{3}}) {
      TORCH_CHECK(
          qx.size(d) == qx0.size(d),
          "Sizes of tensors must match except in dimension 1. Got ",
          qx.size(d),
          " and ",
          qx0.size(d));
    }
    TORCH_CHECK(
        qx.scalar_type() == qx0.scalar_type(),
        "Expected object of scalar type ",
        toString(qx0.scalar_type()),
        " but got scalar type ",
        toString(qx.scalar_type()));
    Cs_in.push_back(qx.size(1));
    Cs_sum.push_back(C_out);
    C_out += qx.size(1);
    scales.push_back(qx.q_scale());
    zero_pts.push_back(qx.q_zero_point());
    data_ptrs.push_back(qx.data_ptr());
    // Inputs that already use the output quantization parameters need no
    // dequantize/requantize round trip (unless ReLU has to be applied).
    is_fast_path.push_back(
        qx.q_scale() == scale &&
        qx.q_zero_point() == zero_point);
  }

  const int64_t N = qx0.size(0);
  const int64_t H = qx0.size(2);
  const int64_t W = qx0.size(3);
  float inv_scale = static_cast<float>(1.0 / scale);

  auto output = at::_empty_affine_quantized(
      {N, C_out, H, W},
      qx0.options().memory_format(MemoryFormat::ChannelsLast),
      scale,
      zero_point,
      std::nullopt);

  // N, H, and W are explicitly captured here because there's a bug in GCC5
  // and clang5 which causes an internal compiler error if they're not
  AT_DISPATCH_QINT_TYPES(output.scalar_type(), "qcat_nhwc", [&, N, H, W]() {
    using Vec = Vectorized<scalar_t>;
    using underlying_t = typename scalar_t::underlying;
    // Loop-invariant base pointer of the output buffer.
    underlying_t* const out_base =
        reinterpret_cast<underlying_t*>(output.data_ptr());
    // Each index i addresses one (n, h, w) position; all of its channels are
    // produced by the same loop iteration.
    at::parallel_for(0, N * H * W, 0, [&](int64_t begin, int64_t end) {
      for (const auto i : c10::irange(begin, end)) {
        // loop over input tensors
        for (const auto tidx : c10::irange(Cs_in.size())) {
          // Destination: position i, channel offset of this input.
          underlying_t* optr = out_base + i * C_out + Cs_sum[tidx];

          auto curr_C = Cs_in[tidx];
          float curr_scale = static_cast<float>(scales[tidx]);
          int64_t curr_zero_pt = zero_pts[tidx];

          // Source: position i of the current input.
          underlying_t* iptr =
              reinterpret_cast<underlying_t*>(data_ptrs[tidx]) + i * curr_C;

          if (is_fast_path[tidx] && !ReLUFused) {
            std::memcpy(optr, iptr, curr_C * sizeof(underlying_t));
            continue;
          }

          constexpr auto VLEN = Vec::size();
          int64_t c = 0;

          // Vectorized loop
          if (c + VLEN <= curr_C) {
            auto curr_scale_vec = Vectorized<float>(curr_scale);
            auto curr_zero_pt_vec = Vectorized<float>(curr_zero_pt);
            auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
            for (; c + VLEN <= curr_C; c += VLEN) {
              auto inp_vec = Vec::loadu(iptr + c);
              auto float_values = inp_vec.dequantize(
                  curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
              Vec::float_vec_return_type retvals;
              // `j` (not `i`) so the outer position index is not shadowed.
              for (int j = 0; j < Vec::float_num_vecs(); ++j) {
                if constexpr (ReLUFused) {
                  retvals[j] =
                      vec::maximum(float_values[j], Vectorized<float>(0.0f));
                } else {
                  retvals[j] = float_values[j];
                }
              }
              auto quantized =
                  Vec::quantize(retvals, scale, zero_point, inv_scale);
              quantized.store(optr + c);
            }
          }

          // Vectorized loop for channel between 8 and 32 (avx2)
          constexpr auto kVLEN = Vectorized<float>::size();
          int64_t elem_size = curr_C - c;
          if ((VLEN == 4 * kVLEN) && elem_size >= kVLEN) {
            auto curr_scale_vec = Vectorized<float>(curr_scale);
            auto curr_zero_pt_vec = Vectorized<float>(curr_zero_pt);
            auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
            int64_t vec_num = elem_size / kVLEN;
            // Stage the tail in a zero-filled, full-width buffer so the
            // full-vector load never reads past the end of the input row.
            std::array<underlying_t, VLEN> buf_in{};
            // memcpy takes a byte count, so scale the element count by the
            // element size.
            std::memcpy(
                buf_in.data(),
                iptr + c,
                vec_num * kVLEN * sizeof(underlying_t));
            auto inp_vec = Vec::loadu(buf_in.data());
            auto float_values = inp_vec.dequantize(
                curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
            Vec::float_vec_return_type retvals;
            // Only the first vec_num float vectors carry real data; the
            // partial store below writes just those elements.
            for (int j = 0; j < vec_num; ++j) {
              if constexpr (ReLUFused) {
                retvals[j] =
                    vec::maximum(float_values[j], Vectorized<float>(0.0f));
              } else {
                retvals[j] = float_values[j];
              }
            }
            auto quantized =
                Vec::quantize(retvals, scale, zero_point, inv_scale);
            quantized.store(optr + c, vec_num * kVLEN);
            c += vec_num * kVLEN;
          }

          // Scalar loop
          for (; c < curr_C; ++c) {
            auto float_val = at::native::dequantize_val(
                curr_scale,
                curr_zero_pt,
                reinterpret_cast<scalar_t*>(iptr)[c]);
            if constexpr (ReLUFused) {
              float_val = std::max(0.0f, float_val);
            }
            optr[c] = at::native::quantize_val<scalar_t>(
                          scale, zero_point, float_val)
                          .val_;
          } // for c
        } // for tidx
      } // for i
    });
  });

  return output;
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free