Home / Class / HGEMM Class — pytorch Architecture

HGEMM Class — pytorch Architecture

Architecture documentation for the HGEMM class in hgemm.cc from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/qnnpack/bench/hgemm.cc lines 36–144

// Benchmark fixture for half-precision GEMM micro-kernels.
//
// The fixture owns every buffer a C = A * K + B benchmark needs, all stored
// as raw uint16_t half-precision bit patterns:
//   a_ : input matrix  (mc x kc)
//   k_ : weight matrix (nc x kc)
//   b_ : bias vector   (nc)
//   w_ : packed weights+bias produced by pytorch_pack_hgemm_w, 32-byte aligned
//   c_ : output matrix (mc x nc)
// mr/nr/kr are the micro-kernel register tile sizes; mc/nc/kc are the problem
// dimensions (defaulted to one tile by the constructor).
class HGEMM : public benchmark::Fixture {
 public:
  // Tile sizes come from the micro-kernel under test; the problem size
  // starts out as exactly one tile.
  inline HGEMM(uint32_t mr, uint32_t nr, uint32_t kr)
      : mr_(mr), nr_(nr), kr_(kr), mc_(mr), nc_(nr), kc_(kr) {}

  // Populate inputs with random fp16 values, pack the weights, and poison
  // the output with NaN so stale results are detectable.
  void SetUp(const benchmark::State&) override {
    // Time-based seed: each benchmark run sees different random data.
    const uint_fast32_t seed =
        std::chrono::steady_clock::now().time_since_epoch().count();
    std::mt19937 engine(seed);
    std::uniform_real_distribution<float> dist;
    // Draw a float in [0, 1) and convert it to IEEE fp16 bits.
    auto nextFp16 = [&]() { return fp16_ieee_from_fp32_value(dist(engine)); };

    a_.resize(mc() * kc());
    std::generate(a_.begin(), a_.end(), nextFp16);
    k_.resize(nc() * kc());
    std::generate(k_.begin(), k_.end(), nextFp16);
    b_.resize(nc());
    std::generate(b_.begin(), b_.end(), nextFp16);
    // Packed buffer is zero-filled first; strides are rounded up to whole
    // tiles so the packing routine can write full tiles at the edges.
    w_.assign(ncStride() * kcStride() + ncStride(), 0);
    pytorch_pack_hgemm_w(nc(), kc(), nr(), kr(), k(), b(), w());
    // 0x7E00 is an fp16 NaN — any lane the kernel fails to write stays NaN.
    c_.assign(mc() * nc(), UINT16_C(0x7E00) /* NaN */);
  }

  // Report throughput (2*M*N*K flops-worth of items per iteration) and
  // release all buffers.
  void TearDown(benchmark::State& state) override {
    state.SetItemsProcessed(
        uint64_t(state.iterations()) * 2 * mc() * nc() * kc());
    a_.clear();
    k_.clear();
    b_.clear();
    w_.clear();
    c_.clear();
  }

  // Read-only input matrix.
  inline const uint16_t* a() const {
    return a_.data();
  }

  // Read-only weight matrix.
  inline const uint16_t* k() const {
    return k_.data();
  }

  // Read-only bias vector.
  inline const uint16_t* b() const {
    return b_.data();
  }

  // Mutable packed-weights buffer (filled by SetUp).
  inline uint16_t* w() {
    return w_.data();
  }

  inline const uint16_t* w() const {
    return w_.data();
  }

  // Mutable output buffer written by the kernel under test.
  inline uint16_t* c() {
    return c_.data();
  }

  inline uint32_t mr() const {
    return mr_;
  }

  inline uint32_t mc() const {
    return mc_;
  }

  inline uint32_t nr() const {
    return nr_;
  }

  inline uint32_t nc() const {
    return nc_;
  }

  // nc rounded up to a whole number of nr-wide tiles.
  inline uint32_t ncStride() const {
    return roundUp(nc(), nr());
  }

  inline uint32_t kr() const {
    return kr_;
  }

  inline uint32_t kc() const {
    return kc_;
  }

  // kc rounded up to a whole number of kr-deep tiles.
  inline uint32_t kcStride() const {
    return roundUp(kc(), kr());
  }

  // Clamping params passed to the kernel: fp16 encodings of
  // scale = 1.0 (0x3C00), max = +inf (0x7C00), min = -inf (0xFC00),
  // i.e. effectively no clamping.
  inline const pytorch_qnnp_fp16_clamping_params* clampingParams() const {
    return &clampingParams_;
  }

 protected:
  std::vector<uint16_t> a_;
  std::vector<uint16_t> k_;
  std::vector<uint16_t> b_;
  // Packed weights must be 32-byte aligned for the vectorized kernels.
  std::vector<uint16_t, AlignedAllocator<uint16_t, 32>> w_;
  std::vector<uint16_t> c_;
  uint32_t mr_{0};
  uint32_t nr_{0};
  uint32_t kr_{0};
  uint32_t mc_{mr_};
  uint32_t nc_{nr_};
  uint32_t kc_{kr_};
  pytorch_qnnp_fp16_clamping_params clampingParams_{0x3C00, 0x7C00, 0xFC00};
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free