embedding_bag_nbit_impl Function — pytorch Architecture
Architecture documentation for the embedding_bag_nbit_impl function template in qembeddingbag.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp lines 629–799
// Shared CPU implementation for sub-byte (n-bit) quantized embedding-bag
// lookup. Each row of `weight` stores packed n-bit quantized values followed
// by a 2-byte fp16 scale and a 2-byte fp16 zero point (see the computation of
// D below). Callers are expected to pass bit_width of 2 or 4 — the
// non-FBGEMM fallback at the bottom only dispatches those two values;
// TODO confirm upstream callers validate this.
//
// Parameters:
//   output    - destination tensor; resized in place and returned.
//   weight    - 2-D packed quantized weight tensor (rows = embeddings).
//   bit_width - bits per quantized element (2 or 4 expected).
//   indices   - embedding indices to gather.
//   offsets   - 1-D bag boundaries into `indices`.
//   pruned_weights - if true, `weight` rows are compacted and
//       `compressed_indices_mapping` maps logical -> physical rows.
//       NOTE(review): when true, compressed_indices_mapping.value() is
//       dereferenced without a has_value() check — callers must supply it.
//   per_sample_weights_ - optional float scaling factor per index.
//   compressed_indices_mapping - int32 row-remap table for pruned weights.
//   include_last_offset - if true, `offsets` already carries a trailing
//       entry equal to indices.numel(); otherwise one is appended locally.
//   is_embedding_op - with 2-D indices, produce a 3-D (rows, cols, D)
//       output instead of the 2-D bag output.
template <typename IndexType, typename OffsetType>
at::Tensor& embedding_bag_nbit_impl(
    at::Tensor& output,
    const at::Tensor& weight,
    const int bit_width,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  TORCH_CHECK(weight.dim() == 2);
  TORCH_CHECK(offsets.dim() == 1);

  auto offsets_data = offsets.data_ptr<OffsetType>();

  // Get compressed indices for pruned_weights op.
  int32_t* compressed_indices_mapping_data = nullptr;
  int compressed_index_size = 0;
  bool fallback_to_no_sparse = false;
  if (pruned_weights) {
    compressed_index_size = compressed_indices_mapping.value().numel();
    compressed_indices_mapping_data =
        compressed_indices_mapping.value().data_ptr<int32_t>();

    // If compressed_indices_mapping is [0], it is an indicator that
    // we should fall back to the non-sparse embedding lookup kernel.
    if ((compressed_index_size == 1 &&
         compressed_indices_mapping_data[0] == 0)) {
      fallback_to_no_sparse = true;
    }
  }

  const auto weight_sizes = weight.sizes();
  const int64_t weight_size = weight_sizes[1];
  // How many quantized elements fit in one byte (4 for 2-bit, 2 for 4-bit).
  int NUM_ELEM_PER_BYTE = 8 / bit_width;
  // D = number of logical embedding elements per row: strip the per-row
  // fp16 scale + fp16 zero point before unpacking.
  const int64_t D =
      (weight_size - 2 * sizeof(at::Half)) * NUM_ELEM_PER_BYTE; // NB: 2-byte fp16 scale and 2-byte zero_offset
  const int64_t M = offsets.sizes()[0];

  int64_t output_size = M - 1;
  // FBGEMM kernels expect offsets to carry a final entry equal to the total
  // index count; synthesize it here when the caller did not include one.
  // offsets_include_last_val must stay alive for the rest of the function,
  // since offsets_data may alias its storage below.
  std::vector<OffsetType> offsets_include_last_val;
  if (!include_last_offset) {
    output_size = M;
    offsets_include_last_val.resize(M + 1);
    // Avoid `null pointer passed as argument 2` ASAN violation when offsets
    // tensor is empty.
    if (M > 0) {
      std::memcpy(
          offsets_include_last_val.data(),
          offsets_data,
          sizeof(OffsetType) * M);
    }
    offsets_include_last_val[M] = indices.numel();
    offsets_data = offsets_include_last_val.data();
  }

  // Resize the output: 3-D (indices rows, indices cols, D) for the pure
  // embedding op with 2-D indices, otherwise 2-D (output_size, D).
  {
    std::array<int64_t, 3> shape_arr{};
    c10::IntArrayRef shape;
    if (indices.dim() == 2 && is_embedding_op) {
      const auto indices_sizes = indices.sizes();
      shape_arr[0] = indices_sizes[0];
      shape_arr[1] = indices_sizes[1];
      shape_arr[2] = D;
      shape = shape_arr;
    } else {
      shape_arr[0] = output_size;
      shape_arr[1] = D;
      shape = c10::IntArrayRef(&shape_arr[0], 2);
    }
    at::native::resize_(output, shape, std::nullopt);
  }

#ifdef USE_FBGEMM
  const auto indices_data = indices.data_ptr<IndexType>();
  const auto weight_data = weight.data_ptr<uint8_t>();
  auto* output_data = output.data_ptr<float>();
  const int64_t N = weight_sizes[0];

  const int64_t block_size = D;
  const int index_size = indices.numel();
  constexpr int prefetch_distance = 16;
  if (!pruned_weights || fallback_to_no_sparse) {
    // Dense path (or sparse fallback): generate the fbgemm kernel.
    auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit<IndexType, OffsetType>(
        /*bit rate=*/bit_width,
        /*block size=*/block_size,
        /*has weights=*/per_sample_weights_.has_value(),
        /*normalize_by_lengths=*/false,
        /*prefetch distance=*/prefetch_distance,
        /*is_weight_positional=*/false,
        /*use_offsets=*/true);
    bool success = kernel(
        /*output_size=*/output_size,
        /*index_size=*/index_size,
        /*data_size=*/N,
        /*input=*/weight_data,
        /*indices=*/indices_data,
        /*offsets=*/offsets_data,
        /*weights=*/
        per_sample_weights_.has_value()
            ? per_sample_weights_.value().data_ptr<float>()
            : nullptr,
        /*output=*/output_data);
    if (!success) {
      // Kernel failure indicates an out-of-bounds index/offset; report it.
      fbgemm_spmdm_report_error_(
          output_size, index_size, N, offsets_data, indices_data);
    }
  } else {
    // Pruned (row-wise sparse) path: indices are remapped through
    // compressed_indices_mapping_data inside the kernel.
    auto kernel =
        fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<IndexType, OffsetType>(
            /*bit rate=*/bit_width,
            /*block_size=*/block_size,
            /*has weights=*/per_sample_weights_.has_value(),
            /*normalize_by_lengths=*/false,
            /*prefetch distance*/ prefetch_distance,
            /*is_weight_positional*/ false,
            /*use_offsets*/ true);
    bool success = kernel(
        /*output_size=*/output_size,
        /*index_size=*/index_size,
        /*data_size=*/compressed_index_size,
        /*input=*/weight_data,
        /*indices=*/indices_data,
        /*offsets=*/offsets_data,
        /*weights=*/
        per_sample_weights_.has_value()
            ? per_sample_weights_.value().data_ptr<float>()
            : nullptr,
        /*output=*/output_data,
        /*compressed_indices_table=*/compressed_indices_mapping_data);
    if (!success) {
      fbgemm_spmdm_report_error_(
          output_size,
          index_size,
          compressed_index_size,
          offsets_data,
          indices_data);
    }
  }
  return output;
#else
  // Non-FBGEMM reference fallback; template args are <.., BIT_RATE,
  // NUM_ELEM_PER_BYTE>.
  if (bit_width == 4) {
    return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
        weight,
        indices,
        offsets,
        per_sample_weights_,
        compressed_indices_mapping,
        output,
        D,
        output_size,
        include_last_offset,
        (pruned_weights && !fallback_to_no_sparse));
  }
  // bit_width == 2
  return embedding_lookup_fallback_impl<IndexType, OffsetType, 2, 4>(
      weight,
      indices,
      offsets,
      per_sample_weights_,
      compressed_indices_mapping,
      output,
      D,
      output_size,
      include_last_offset,
      (pruned_weights && !fallback_to_no_sparse));
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free