BIT_RATE Template Parameter — PyTorch Architecture
Architecture documentation for the BIT_RATE template parameter in qembeddingbag.cpp from the PyTorch codebase. BIT_RATE is an int non-type template parameter of embedding_lookup_fallback_impl that sets the bit width of each quantized embedding value, while its companion NUM_ELEM_PER_BYTE sets how many such values are packed into a single byte.
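The two parameters are chosen as a pair so that a packed byte holds a whole number of quantized values, i.e. BIT_RATE * NUM_ELEM_PER_BYTE == 8. A minimal sketch of that invariant and of the row payload size it implies (the PackedRow helper is hypothetical, not part of the PyTorch source):

#include <cstdint>

// Hypothetical helper illustrating the pairing the kernel relies on: each
// packed byte holds exactly NUM_ELEM_PER_BYTE values of BIT_RATE bits.
template <int BIT_RATE, int NUM_ELEM_PER_BYTE>
struct PackedRow {
  static_assert(BIT_RATE * NUM_ELEM_PER_BYTE == 8,
                "a row must pack whole bytes");
  // Bytes needed for n quantized values, before the scale/bias tail that
  // embedding_lookup_fallback_impl reads from the end of each row.
  static constexpr int64_t payload_bytes(int64_t n) {
    return (n + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE;
  }
};

// The instantiation shapes used in practice: 8-bit (one value per byte),
// 4-bit (two per byte), and 2-bit (four per byte).
static_assert(PackedRow<8, 1>::payload_bytes(64) == 64, "");
static_assert(PackedRow<4, 2>::payload_bytes(64) == 32, "");
static_assert(PackedRow<2, 4>::payload_bytes(64) == 16, "");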
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp lines 35–191
template <
    typename IndexType,
    typename OffsetType,
    int BIT_RATE,
    int NUM_ELEM_PER_BYTE>
at::Tensor& embedding_lookup_fallback_impl(
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    at::Tensor& output,
    const int64_t block_size,
    const int64_t output_size,
    bool include_last_offset,
    bool pruned) {
  auto* output_data = output.data_ptr<float>();
  const auto weight_data = weight.data_ptr<uint8_t>();
  const auto indices_data = indices.data_ptr<IndexType>();
  int32_t* compressed_indices_mapping_data = nullptr;
  const auto weight_sizes = weight.sizes();
  const int64_t N = weight_sizes[0];
  const int64_t weight_size = weight_sizes[1];
  const int index_size = indices.numel();

  // Turn the offsets tensor into per-bag lengths.
  auto accessor = offsets.accessor<OffsetType, 1>();
  std::vector<OffsetType> lengths_data;
  int64_t lower = accessor[0];
  for (const auto i : c10::irange(1, offsets.numel())) {
    lengths_data.push_back(accessor[i] - lower);
    lower = accessor[i];
  }
  if (!include_last_offset) {
    lengths_data.push_back(indices.numel() - lower);
  }

  int64_t current = 0;
  float* per_sample_weights_data = nullptr;
  if (per_sample_weights_.has_value()) {
    per_sample_weights_data = per_sample_weights_.value().data_ptr<float>();
  }

  for (const auto m : c10::irange(output_size)) {
    memset(output_data, 0, block_size * sizeof(float));
    TORCH_CHECK(
        current + lengths_data[m] <= index_size,
        "Expect the lengths data to be less than indices size");
    for (int i = 0; i < lengths_data[m]; ++i, ++current) {
      // Resolve the row index, going through the compressed mapping when
      // the weight matrix is pruned.
      int64_t idx = -1;
      if (!pruned) {
        idx = indices_data[current];
        TORCH_CHECK((idx >= 0 && idx < N), "Invalid indices data");
      } else {
        int64_t uncompressed_idx = indices_data[current];
        int compressed_index_size = compressed_indices_mapping.value().numel();
        compressed_indices_mapping_data =
            compressed_indices_mapping.value().data_ptr<int32_t>();
        TORCH_CHECK(
            uncompressed_idx >= 0 && uncompressed_idx < compressed_index_size,
            "Invalid indices data for Sparse Op.");
        idx = compressed_indices_mapping_data[uncompressed_idx];
        if (idx == -1) {
          continue; // row was pruned away; it contributes nothing to the bag
        }
      }
      float weight_val = 1.0f;
      if (per_sample_weights_.has_value()) {
        weight_val = per_sample_weights_data[current];
      }
      // Read the scale and bias stored at the end of the row, byte by byte
      // so the load is alignment- and endianness-safe: float for 8-bit rows,
      // at::Half for sub-byte rows.
      float scale = std::numeric_limits<float>::quiet_NaN(),
            bias = std::numeric_limits<float>::quiet_NaN();
      if constexpr (BIT_RATE == 8) {
        const uint8_t* scale_bias =
            weight_data + (idx + 1) * weight_size - 2 * sizeof(float);
        uint32_t scale_val_int32 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        scale_val_int32 = scale_val_int32 |
            (scale_bias[0]) |
            (scale_bias[1] << 8) |
            (scale_bias[2] << 16) |
            (scale_bias[3] << 24);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        scale_val_int32 = scale_val_int32 |
            (scale_bias[3]) |
            (scale_bias[2] << 8) |
            (scale_bias[1] << 16) |
            (scale_bias[0] << 24);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        float scale_val = (reinterpret_cast<float*>(&scale_val_int32))[0];
        uint32_t bias_val_int32 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        bias_val_int32 = bias_val_int32 |
            (scale_bias[4]) |
            (scale_bias[5] << 8) |
            (scale_bias[6] << 16) |
            (scale_bias[7] << 24);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        bias_val_int32 = bias_val_int32 |
            (scale_bias[7]) |
            (scale_bias[6] << 8) |
            (scale_bias[5] << 16) |
            (scale_bias[4] << 24);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        float bias_val = (reinterpret_cast<float*>(&bias_val_int32))[0];
        scale = weight_val * scale_val;
        bias = weight_val * bias_val;
      } else {
        const uint8_t* scale_bias =
            weight_data + (idx + 1) * weight_size - 2 * sizeof(at::Half);
        uint16_t scale_val_int16 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        scale_val_int16 = scale_val_int16 |
            (scale_bias[0]) |
            (scale_bias[1] << 8);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        scale_val_int16 = scale_val_int16 |
            (scale_bias[1]) |
            (scale_bias[0] << 8);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        at::Half scale_val = (reinterpret_cast<at::Half*>(&scale_val_int16))[0];
        uint16_t bias_val_int16 = 0;
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
        bias_val_int16 = bias_val_int16 |
            (scale_bias[2]) |
            (scale_bias[3] << 8);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
        bias_val_int16 = bias_val_int16 |
            (scale_bias[3]) |
            (scale_bias[2] << 8);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
        at::Half bias_val = (reinterpret_cast<at::Half*>(&bias_val_int16))[0];
        scale = weight_val * scale_val;
        bias = weight_val * bias_val;
      }
      // Unpack element j from its packed byte and accumulate
      // scale * q + bias into the output block.
      for (const auto j : c10::irange(block_size)) {
        uint8_t quantized =
            weight_data[idx * weight_size + j / NUM_ELEM_PER_BYTE];
        quantized >>= (j % NUM_ELEM_PER_BYTE) * BIT_RATE;
        quantized &= (1 << BIT_RATE) - 1;
        output_data[j] = fma(scale, quantized, output_data[j] + bias);
      }
    } // for each i
    output_data += block_size;
  } // for each m
  return output;
}
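The inner loop's shift-and-mask arithmetic is easiest to follow on a concrete row. A minimal standalone sketch, assuming BIT_RATE = 4 and NUM_ELEM_PER_BYTE = 2 (the row contents and the main harness are illustrative, not from the PyTorch source):

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int BIT_RATE = 4;
  constexpr int NUM_ELEM_PER_BYTE = 2;

  // One packed row holding the 4-bit values {3, 12, 7, 0}: the first element
  // of each byte sits in the low nibble, the second in the high nibble.
  const uint8_t row[] = {0xC3, 0x07};

  for (int j = 0; j < 4; ++j) {
    uint8_t q = row[j / NUM_ELEM_PER_BYTE];   // byte holding element j
    q >>= (j % NUM_ELEM_PER_BYTE) * BIT_RATE; // shift its bit-field down
    q &= (1 << BIT_RATE) - 1;                 // mask off the neighboring field
    printf("element %d = %d\n", j, q);        // prints 3, 12, 7, 0
  }
  return 0;
}

Because scale and bias are pre-multiplied by weight_val before the inner loop, the fma(scale, quantized, output_data[j] + bias) line accumulates weight_val * (scale_val * q + bias_val) into each output slot, which is the dequantized, per-sample-weighted contribution of one row.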