QEmbeddingBag Class — PyTorch Architecture
Architecture documentation for the QEmbeddingBag class, defined in qembeddingbag.cpp in the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp lines 1452–1488
class QEmbeddingBag final {
 public:
  /// Runs a quantized embedding_bag lookup against a prepacked weight.
  ///
  /// Dispatches on the `bit_rate` template parameter (declared on the
  /// enclosing class template, outside this snippet) to either the 8-bit
  /// (byte) or the 4-bit packed-weight kernel.
  ///
  /// @param packed_weight   Prepacked quantized embedding weights.
  /// @param indices         Indices into the embedding table.
  /// @param offsets         Optional bag offsets (embedding_bag semantics).
  /// @param scale_grad_by_freq / mode  Accepted only to match the
  ///        embedding_bag operator schema; intentionally unused here.
  /// @param pruned_weights  Whether the packed weights are pruned.
  /// @param per_sample_weights_         Optional per-sample scaling weights.
  /// @param compressed_indices_mapping  Optional mapping for pruned rows.
  /// @param include_last_offset         embedding_bag include_last_offset flag.
  /// @return The pooled embedding output tensor.
  static at::Tensor run(
      const c10::intrusive_ptr<EmbeddingPackedParamsBase>& packed_weight,
      const Tensor& indices,
      const std::optional<Tensor>& offsets,
      const bool /* scale_grad_by_freq */,
      const int64_t /* mode */,
      bool pruned_weights,
      const std::optional<Tensor>& per_sample_weights_,
      const std::optional<Tensor>& compressed_indices_mapping,
      bool include_last_offset) {
    if (bit_rate == 8) {
      return packed_weight->embeddingbag_byte(
          indices,
          offsets,
          pruned_weights,
          per_sample_weights_,
          compressed_indices_mapping,
          include_last_offset,
          false /* is_embedding_op */);
    } else if (bit_rate == 4) {
      return packed_weight->embeddingbag_4bit(
          indices,
          offsets,
          pruned_weights,
          per_sample_weights_,
          compressed_indices_mapping,
          include_last_offset,
          false /* is_embedding_op */);
    } else {
      // Fix: the previous message claimed only 8-bit was supported, even
      // though the 4-bit branch above exists. Keep the diagnostic accurate.
      TORCH_INTERNAL_ASSERT(
          false,
          "Currently only support 8-bit and 4-bit embedding_bag quantization");
    }
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free