Home / Class / Q8GEMM_XZP Class — pytorch Architecture

Q8GEMM_XZP Class — pytorch Architecture

Architecture documentation for the Q8GEMM_XZP class in q8gemm.cc from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/qnnpack/bench/q8gemm.cc lines 250–316

class Q8GEMM_XZP : public Q8GEMM {
 public:
  inline Q8GEMM_XZP(uint32_t mr, uint32_t nr, uint32_t np, uint32_t kr)
      : Q8GEMM(mr, nr, np, kr) {}
   void SetUp(const benchmark::State&) override {
    std::random_device randomDevice;
    auto rng = std::mt19937(randomDevice());
    auto s32rng =
        std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    a_.resize(mc() * kc());
    std::generate(a_.begin(), a_.end(), std::ref(u8rng));
    k_.resize(ncStride() * kcStride());
    std::generate(k_.begin(), k_.end(), std::ref(u8rng));
    b_.resize(roundUp(nc(), nr()));
    std::generate(b_.begin(), b_.end(), std::ref(s32rng));
    w_.resize(ncStride() * (kcStride() + sizeof(int32_t) / sizeof(uint8_t)));
    std::fill(w_.begin(), w_.end(), 127);
    pytorch_pack_swizzle_q8gemm_b(
        nc(),
        kc(),
        np(),
        kr(),
        8,
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
        127,
        127,
#endif
        k(),
        b(),
        w());
    c_.resize(mc() * nc());
    std::fill(c_.begin(), c_.end(), 0xA5);
    aRowSums_.resize(roundUp(mc(), mr()));
    std::fill(aRowSums_.begin(), aRowSums_.end(), 0xFE01);

    requantizationParams_ =
        pytorch_qnnp_compute_requantization_params(0.75f, 127, 1, 254);
  }

   void TearDown(benchmark::State& state) override {
    state.SetItemsProcessed(
        uint64_t(state.iterations()) * 2 * mc() * nc() * kc());
    a_.clear();
    k_.clear();
    c_.clear();
    aRowSums_.clear();
  }

  inline int32_t* aRowSums() {
    return aRowSums_.data();
  }

  inline const int32_t* aRowSums() const {
    return aRowSums_.data();
  }

  inline const pytorch_qnnp_q31_requantization_params* requantizationParams()
      const {
    return &requantizationParams_;
  }

 protected:
  std::vector<int32_t> aRowSums_;
  pytorch_qnnp_q31_requantization_params requantizationParams_;
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free