
BIT_RATE Template Parameter — PyTorch Architecture

Architecture documentation for the BIT_RATE template parameter of embedding_lookup_fallback_impl in qembeddingbag.cpp from the PyTorch codebase.

Entity Profile
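
BIT_RATE is an int non-type template parameter of embedding_lookup_fallback_impl, the scalar fallback kernel for quantized embedding-bag lookups. It sets the quantization width in bits per weight value and, together with NUM_ELEM_PER_BYTE, determines how values are packed into the weight bytes. The kernel specializes on it at compile time: 8-bit rows store their per-row scale and bias as two float32 values, while sub-byte rows store them as two fp16 values. Callers in the same file instantiate one specialization per bit width (for example, <IndexType, OffsetType, 4, 2> for the 4-bit path), so NUM_ELEM_PER_BYTE is always 8 / BIT_RATE.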

Source Code

aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp lines 35–191

template <
    typename IndexType,
    typename OffsetType,
    int BIT_RATE,
    int NUM_ELEM_PER_BYTE>
at::Tensor& embedding_lookup_fallback_impl(
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    at::Tensor& output,
    const int64_t block_size,
    const int64_t output_size,
    bool include_last_offset,
    bool pruned) {
  auto* output_data = output.data_ptr<float>();
  const auto weight_data = weight.data_ptr<uint8_t>();
  const auto indices_data = indices.data_ptr<IndexType>();
  int32_t* compressed_indices_mapping_data = nullptr;
  const auto weight_sizes = weight.sizes();
  const int64_t N = weight_sizes[0];
  const int64_t weight_size = weight_sizes[1];
  const int index_size = indices.numel();

  // Convert offsets into per-bag lengths: e.g. offsets [0, 2, 5] over six
  // indices with include_last_offset == false yield lengths [2, 3, 1].
  auto accessor = offsets.accessor<OffsetType, 1>();
  std::vector<OffsetType> lengths_data;

  int64_t lower = accessor[0];
  for (const auto i : c10::irange(1, offsets.numel())) {
    lengths_data.push_back(accessor[i] - lower);
    lower = accessor[i];
  }
  if (!include_last_offset) {
    lengths_data.push_back(indices.numel() - lower);
  }

  int64_t current = 0;
  float* per_sample_weights_data = nullptr;
  if (per_sample_weights_.has_value()) {
    per_sample_weights_data = per_sample_weights_.value().data_ptr<float>();
  }
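  // Accumulate lengths_data[m] dequantized rows into each of the
  // output_size bags; output_data advances by block_size per bag.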
  for (const auto m : c10::irange(output_size)) {
    memset(output_data, 0, block_size * sizeof(float));
    TORCH_CHECK(
        current + lengths_data[m] <= index_size,
        "Expect the lengths data to be less than indices size");

    for (int i = 0; i < lengths_data[m]; ++i, ++current) {
      int64_t idx = -1;
      if (!pruned) {
        idx = indices_data[current];
        TORCH_CHECK((idx >= 0 && idx < N), "Invalid indices data");
      } else {
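        // Pruned weight path: translate the raw index through
        // compressed_indices_mapping; -1 marks a pruned row, which is skipped.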
        int64_t uncompressed_idx = indices_data[current];
        int compressed_index_size = compressed_indices_mapping.value().numel();
        compressed_indices_mapping_data =
            compressed_indices_mapping.value().data_ptr<int32_t>();
        TORCH_CHECK(
            uncompressed_idx >= 0 && uncompressed_idx < compressed_index_size,
            "Invalid indices data for Sparse Op.")
        idx = compressed_indices_mapping_data[uncompressed_idx];
        if (idx == -1) {
          continue;
        }
      }

      float weight_val = 1.0f;
      if (per_sample_weights_.has_value()) {
        weight_val = per_sample_weights_data[current];
      }
      float scale = std::numeric_limits<float>::quiet_NaN();
      float bias = std::numeric_limits<float>::quiet_NaN();
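      // Quantization parameters live inline at the end of each row:
      // two float32 values (scale, bias) for 8-bit rows, two fp16 values
      // for sub-byte rows, assembled byte by byte for endian safety.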
      if constexpr (BIT_RATE == 8) {
        const uint8_t* scale_bias =
            weight_data + (idx + 1) * weight_size - 2 * sizeof(float);
        uint32_t scale_val_int32 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        scale_val_int32 = scale_val_int32 |
          (scale_bias[0]) |
          (scale_bias[1] << 8) |
          (scale_bias[2] << 16) |
          (scale_bias[3] << 24);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        scale_val_int32 = scale_val_int32 |
          (scale_bias[3]) |
          (scale_bias[2] << 8) |
          (scale_bias[1] << 16) |
          (scale_bias[0] << 24);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        float scale_val = (reinterpret_cast<float*>(&scale_val_int32))[0];
        uint32_t bias_val_int32 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        bias_val_int32 = bias_val_int32 |
          (scale_bias[4]) |
          (scale_bias[5] << 8) |
          (scale_bias[6] << 16) |
          (scale_bias[7] << 24);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        bias_val_int32 = bias_val_int32 |
          (scale_bias[7]) |
          (scale_bias[6] << 8) |
          (scale_bias[5] << 16) |
          (scale_bias[4] << 24);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        float bias_val = (reinterpret_cast<float*>(&bias_val_int32))[0];
        scale = weight_val * scale_val;
        bias = weight_val * bias_val;
      } else {
        const uint8_t* scale_bias =
            weight_data + (idx + 1) * weight_size - 2 * sizeof(at::Half);
        uint16_t scale_val_int16 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        scale_val_int16 = scale_val_int16 |
          (scale_bias[0]) |
          (scale_bias[1] << 8);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        scale_val_int16 = scale_val_int16 |
          (scale_bias[1]) |
          (scale_bias[0] << 8);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        at::Half scale_val = (reinterpret_cast<at::Half*>(&scale_val_int16))[0];
        uint16_t bias_val_int16 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        bias_val_int16 = bias_val_int16 |
          (scale_bias[2]) |
          (scale_bias[3] << 8);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        bias_val_int16 = bias_val_int16 |
          (scale_bias[3]) |
          (scale_bias[2] << 8);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        at::Half bias_val = (reinterpret_cast<at::Half*>(&bias_val_int16))[0];
        scale = weight_val * scale_val;
        bias = weight_val * bias_val;
      }

      // Unpack each BIT_RATE-wide code from its byte (NUM_ELEM_PER_BYTE
      // codes per byte), then dequantize and accumulate:
      // out[j] += scale * q + bias.
      for (const auto j : c10::irange(block_size)) {
        uint8_t quantized =
            weight_data[idx * weight_size + j / NUM_ELEM_PER_BYTE];
        quantized >>= (j % NUM_ELEM_PER_BYTE) * BIT_RATE;
        quantized &= (1 << BIT_RATE) - 1;

        output_data[j] = fma(scale, quantized, output_data[j] + bias);
      }
    } // for each i
    output_data += block_size;
  } // for each m
  return output;
}
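
How the packing works: with BIT_RATE bits per value, NUM_ELEM_PER_BYTE values share each weight byte, and every row carries its scale and bias inline at the tail. The sketch below is a minimal, self-contained illustration of that layout and of the shift/mask/fma dequantization loop above; it is not PyTorch code. It assumes BIT_RATE = 4 and, for simplicity, float32 scale/bias at the row tail (the real 4-bit kernel stores fp16 halves); the names row and codes are invented for the example.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  constexpr int BIT_RATE = 4;
  constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE; // two values per byte
  constexpr int block_size = 4;                   // logical row width

  // Pack a toy row: the 4-bit codes {0, 5, 10, 15} with scale 0.1 and
  // bias 0.0 should dequantize to {0.0, 0.5, 1.0, 1.5}.
  const float scale = 0.1f, bias = 0.0f;
  const uint8_t codes[block_size] = {0, 5, 10, 15};
  std::vector<uint8_t> row(block_size / NUM_ELEM_PER_BYTE + 2 * sizeof(float), 0);
  for (int j = 0; j < block_size; ++j) {
    row[j / NUM_ELEM_PER_BYTE] |=
        codes[j] << ((j % NUM_ELEM_PER_BYTE) * BIT_RATE);
  }
  std::memcpy(&row[block_size / NUM_ELEM_PER_BYTE], &scale, sizeof(float));
  std::memcpy(&row[block_size / NUM_ELEM_PER_BYTE + sizeof(float)], &bias,
              sizeof(float));

  // Dequantize exactly as the fallback loop does: shift, mask, fma.
  float output[block_size] = {};
  for (int j = 0; j < block_size; ++j) {
    uint8_t quantized = row[j / NUM_ELEM_PER_BYTE];
    quantized >>= (j % NUM_ELEM_PER_BYTE) * BIT_RATE;
    quantized &= (1 << BIT_RATE) - 1;
    output[j] = std::fma(scale, quantized, output[j] + bias);
  }

  for (int j = 0; j < block_size; ++j) {
    std::printf("output[%d] = %.2f\n", j, output[j]); // 0.00 0.50 1.00 1.50
  }
  return 0;
}

One difference worth noting: the kernel above assembles scale and bias byte by byte with explicit shifts, which produces the same float on both little- and big-endian hosts, whereas the memcpy in this sketch assumes the row was packed on the same machine.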
