
kai_pack_rhs_groupwise_int4 Function Template: PyTorch Architecture

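Architecture documentation for the kai_pack_rhs_groupwise_int4 function template in kai_pack.h from the PyTorch codebase. The function packs a groupwise-quantized 4-bit weight matrix (the right-hand side, or RHS, of a matrix multiplication), together with its scales and an optional bias, into the layout required by a KleidiAI micro-kernel. It does this by forwarding the micro-kernel's nr, kr and sr blocking parameters to kernel.kai_run_rhs_pack.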
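In a groupwise int4 scheme, each row of the weight matrix is split along k into groups of bl elements, and each group is quantized to 4-bit integers with its own scale. The sketch below is illustrative only and does not come from PyTorch or KleidiAI. It assumes symmetric quantization, storage as unsigned nibbles offset by 8, two values per byte with the even-indexed element in the low nibble, and float scales; the real kernels may expect a different nibble order or bf16 scales. The names QuantizedRow, quantize_row_groupwise_int4 and dequantize are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative groupwise int4 quantization of a single weight row of length k.
// ASSUMPTIONS (not taken from kai_pack.h): symmetric quantization, values
// stored as unsigned nibbles offset by 8, two values per byte with the
// even-indexed element in the low nibble, and float scales.
struct QuantizedRow {
  std::vector<uint8_t> packed;  // (k + 1) / 2 bytes
  std::vector<float> scales;    // one scale per group of bl elements
};

QuantizedRow quantize_row_groupwise_int4(const std::vector<float>& row,
                                         std::size_t bl) {
  assert(bl > 0);
  const std::size_t k = row.size();
  const std::size_t num_groups = (k + bl - 1) / bl;
  QuantizedRow out;
  out.packed.assign((k + 1) / 2, 0);
  out.scales.assign(num_groups, 0.0f);

  for (std::size_t g = 0; g < num_groups; ++g) {
    const std::size_t begin = g * bl;
    const std::size_t end = std::min(begin + bl, k);
    // The scale maps the largest magnitude in the group onto the int4 value 7.
    float max_abs = 0.0f;
    for (std::size_t i = begin; i < end; ++i) {
      max_abs = std::max(max_abs, std::fabs(row[i]));
    }
    const float scale = max_abs > 0.0f ? max_abs / 7.0f : 1.0f;
    out.scales[g] = scale;

    for (std::size_t i = begin; i < end; ++i) {
      // Round to a signed value in [-8, 7], then shift it into [0, 15].
      const int q =
          std::clamp(static_cast<int>(std::lround(row[i] / scale)), -8, 7);
      const int nibble = q + 8;
      const int shift = (i % 2 == 0) ? 0 : 4;  // low nibble for even i
      out.packed[i / 2] =
          static_cast<uint8_t>(out.packed[i / 2] | (nibble << shift));
    }
  }
  return out;
}

// Recovers element i as (nibble - 8) * scale of its group.
float dequantize(const QuantizedRow& r, std::size_t i, std::size_t bl) {
  assert(i / 2 < r.packed.size());
  const uint8_t byte = r.packed[i / 2];
  const int nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
  return static_cast<float>(nibble - 8) * r.scales[i / bl];
}
```

Under these assumptions a row of k values needs (k + 1) / 2 bytes of packed data and ceil(k / bl) scales; per-row strides such as rhs_stride and scale_stride presumably measure these quantities in bytes.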


Source Code

aten/src/ATen/native/kleidiai/kai_pack.h lines 10–59

template <typename T>
void kai_pack_rhs_groupwise_int4(
    T& kernel,
    const Tensor& weight_packed,
    const Tensor& weight,
    const Tensor& scales,
    const std::optional<Tensor>& bias,
    const int64_t n,
    const int64_t k,
    const int64_t bl,
    const int64_t rhs_stride,
    const int64_t scale_stride) {
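  // Blocking parameters of the target micro-kernel; the packed layout
  // produced by kai_run_rhs_pack depends on them.
  const auto& ukernel = kernel.ukernel;
  const size_t nr = ukernel.get_nr();
  const size_t kr = ukernel.get_kr();
  const size_t sr = ukernel.get_sr();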
  auto weight_packed_data =
      reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
  const auto weight_data = weight.data_ptr<uint8_t>();
  auto scales_data = scales.const_data_ptr();

  TORCH_CHECK(
      weight_data != nullptr,
      "kai_pack_rhs_groupwise_int4: Weight data pointer is null");

  TORCH_CHECK(
      scales_data != nullptr,
      "kai_pack_rhs_groupwise_int4: Scales data pointer is null");

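  // Keep the float copy of the bias alive until kai_run_rhs_pack returns:
  // taking data_ptr() of the temporary returned by to(kFloat) would leave
  // bias_ptr dangling whenever a dtype conversion actually happens.
  Tensor bias_float;
  float* bias_ptr = nullptr;
  if (bias.has_value()) {
    bias_float = bias.value().to(kFloat).contiguous();
    bias_ptr = bias_float.data_ptr<float>();
  }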
  auto& params = kernel.rhs_pack_params;

  kernel.kai_run_rhs_pack(
      /*num_groups=*/1,
      n,
      k,
      nr,
      kr,
      sr,
      bl,
      (const uint8_t*)(weight_data),
      rhs_stride,
      bias_ptr,
      scales_data,
      scale_stride,
      weight_packed_data,
      /*extra_bytes=*/0,
      &params);
}
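Because T is a template parameter, kai_pack_rhs_groupwise_int4 only requires that kernel expose ukernel (with get_nr(), get_kr() and get_sr()), rhs_pack_params, and a callable kai_run_rhs_pack that accepts the argument list shown. The following self-contained sketch builds without ATen and models that contract with hypothetical mock types (MockUkernel, MockPackParams, MockKernel, RhsPackFn), replacing the Tensor arguments with raw pointers. Its stride helpers assume two 4-bit values per byte and one 2-byte scale per group; those are assumptions, not values read from kai_pack.h, and the real member types in PyTorch may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical mocks of what kai_pack_rhs_groupwise_int4 requires from T.
// Only the member names come from kai_pack.h; types and values are invented.
struct MockUkernel {
  std::size_t get_nr() const { return 4; }
  std::size_t get_kr() const { return 16; }
  std::size_t get_sr() const { return 2; }
};

struct MockPackParams {};  // real parameter contents are not shown in kai_pack.h

// Callable type modeled on the argument list passed in kai_pack.h.
using RhsPackFn = void (*)(std::size_t num_groups, std::size_t n,
                           std::size_t k, std::size_t nr, std::size_t kr,
                           std::size_t sr, std::size_t bl, const uint8_t* rhs,
                           std::size_t rhs_stride, const float* bias,
                           const void* scale, std::size_t scale_stride,
                           void* rhs_packed, std::size_t extra_bytes,
                           const MockPackParams* params);

struct MockKernel {
  MockUkernel ukernel;
  MockPackParams rhs_pack_params;
  RhsPackFn kai_run_rhs_pack;
};

// A fake packing routine that records what it was called with.
struct RecordedCall {
  std::size_t nr = 0, kr = 0, sr = 0, bl = 0;
  std::size_t rhs_stride = 0, scale_stride = 0;
  bool has_bias = false;
};
RecordedCall g_last_call;

void record_rhs_pack(std::size_t, std::size_t, std::size_t, std::size_t nr,
                     std::size_t kr, std::size_t sr, std::size_t bl,
                     const uint8_t*, std::size_t rhs_stride, const float* bias,
                     const void*, std::size_t scale_stride, void*, std::size_t,
                     const MockPackParams*) {
  g_last_call = {nr, kr, sr, bl, rhs_stride, scale_stride, bias != nullptr};
}

// ASSUMPTION: two 4-bit values per byte along k.
inline std::size_t rhs_stride_bytes(std::size_t k) { return (k + 1) / 2; }

// ASSUMPTION: one 2-byte (e.g. bf16) scale per group of bl elements.
inline std::size_t scale_stride_bytes(std::size_t k, std::size_t bl) {
  assert(bl > 0);
  return ((k + bl - 1) / bl) * sizeof(uint16_t);
}

// Mirrors the control flow of kai_pack_rhs_groupwise_int4 with raw pointers
// in place of at::Tensor, so it compiles without ATen.
template <typename T>
void pack_rhs_groupwise_int4_raw(T& kernel, void* packed,
                                 const uint8_t* weight, const void* scales,
                                 const float* bias, std::size_t n,
                                 std::size_t k, std::size_t bl) {
  const auto& ukernel = kernel.ukernel;
  kernel.kai_run_rhs_pack(/*num_groups=*/1, n, k, ukernel.get_nr(),
                          ukernel.get_kr(), ukernel.get_sr(), bl, weight,
                          rhs_stride_bytes(k), bias, scales,
                          scale_stride_bytes(k, bl), packed,
                          /*extra_bytes=*/0, &kernel.rhs_pack_params);
}
```

Swapping in a different micro-kernel only changes what T reports for nr, kr and sr and which routine kai_run_rhs_pack refers to; the packing function itself stays the same, which is presumably why it is written as a template rather than against one concrete kernel type.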
