embedding_bag_byte_impl Function — pytorch Architecture
Architecture documentation for the embedding_bag_byte_impl function in qembeddingbag.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp lines 801–972
// Byte (uint8) quantized embedding-bag lookup.
//
// Gathers rows of a 2-D uint8-quantized `weight` matrix for each bag
// described by `offsets`, dequantizes them (each row carries its scale/bias
// in 8 trailing bytes -- see the `- 8` on D below), optionally scales each
// gathered row by `per_sample_weights_`, accumulates per bag, and writes the
// float result into `output`, which is resized in place.  Returns `output`.
//
// Template parameters:
//   IndexType  - element type of `indices`.
//   OffsetType - element type of `offsets`.
//
// Arguments:
//   output        - destination tensor; resized here to the computed shape.
//   weight        - 2-D uint8 tensor of quantized rows (D values + 8 bytes
//                   of scale/bias per row).
//   indices       - row indices into `weight`, flattened across all bags.
//   offsets       - 1-D tensor of bag start positions into `indices`.
//   pruned_weights             - if true, rows are compacted and looked up
//                                through `compressed_indices_mapping`.
//   per_sample_weights_        - optional float weights, one per index,
//                                forwarded to the lookup kernel.
//   compressed_indices_mapping - int32 mapping tensor; used only when
//                                pruned_weights is true.
//   include_last_offset        - if true, `offsets` already ends with a
//                                final "end" entry; otherwise one is
//                                appended below.
//   is_embedding_op            - true when called from the plain embedding
//                                (not embedding_bag) op; changes the output
//                                shape for 2-D indices.
template <typename IndexType, typename OffsetType>
at::Tensor& embedding_bag_byte_impl(
    at::Tensor& output,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  TORCH_CHECK(weight.scalar_type() == at::kByte);
  TORCH_CHECK(weight.dim() == 2);
  TORCH_CHECK(offsets.dim() == 1);
  auto offsets_data = offsets.data_ptr<OffsetType>();

  // Get compressed indices for pruned_weights.
  // NOTE(review): assumes compressed_indices_mapping.has_value() whenever
  // pruned_weights is true -- the .value() calls below throw otherwise.
  int32_t* compressed_indices_mapping_data = nullptr;
  int compressed_index_size = 0;
  bool fallback_to_no_sparse = false;
  if (pruned_weights) {
    compressed_index_size = compressed_indices_mapping.value().numel();
    compressed_indices_mapping_data =
        compressed_indices_mapping.value().data_ptr<int32_t>();

    // if compressed_indices_mapping is [0], it is an indicator that
    // we should fallback to non sparse embedding look up kernel.
    if ((compressed_index_size == 1 &&
         compressed_indices_mapping_data[0] == 0)) {
      fallback_to_no_sparse = true;
    }
  }

  const auto weight_sizes = weight.sizes();
  const int64_t D = weight_sizes[1] - 8; // NB: -8 to account for scale and bias
  const int64_t M = offsets.sizes()[0];

  // Bag i spans indices [offsets[i], offsets[i+1]).  When the caller did not
  // supply the trailing "end" offset, copy offsets into a local buffer and
  // append indices.numel() so the last bag has a well-defined end; note this
  // also changes output_size from M - 1 to M.
  int64_t output_size = M - 1;
  std::vector<OffsetType> offsets_include_last_val;
  if (!include_last_offset) {
    output_size = M;
    offsets_include_last_val.resize(M + 1);
    // Avoid `null pointer passed as argument 2` ASAN violation when offsets
    // tensor is empty.
    if (M > 0) {
      std::memcpy(
          offsets_include_last_val.data(),
          offsets_data,
          sizeof(OffsetType) * M);
    }
    offsets_include_last_val[M] = indices.numel();
    offsets_data = offsets_include_last_val.data();
  }

  // Resize output: [indices.size(0), indices.size(1), D] for a 2-D-indices
  // embedding lookup, otherwise [output_size, D] for the bag reduction.
  {
    std::array<int64_t, 3> shape_arr{};
    c10::IntArrayRef shape;
    if (indices.dim() == 2 && is_embedding_op) {
      const auto indices_sizes = indices.sizes();
      shape_arr[0] = indices_sizes[0];
      shape_arr[1] = indices_sizes[1];
      shape_arr[2] = D;
      shape = shape_arr;
    } else {
      shape_arr[0] = output_size;
      shape_arr[1] = D;
      shape = c10::IntArrayRef(&shape_arr[0], 2);
    }
    at::native::resize_(output, shape, std::nullopt);
  }

#ifdef USE_FBGEMM
  const int64_t N = weight_sizes[0];
  const auto weight_data = weight.data_ptr<uint8_t>();
  const auto indices_data = indices.data_ptr<IndexType>();
  auto* output_data = output.data_ptr<float>();
  const int index_size = indices.numel();

  if (!pruned_weights || fallback_to_no_sparse) {
    // Dense path: generate an FBGEMM SpMDM kernel (uint8 in, float out) and
    // run it over contiguous ranges of bags in parallel.  Each worker passes
    // the slice of indices/offsets/per-sample weights belonging to its
    // [start_idx, end_idx) bags and writes rows start_idx..end_idx of output.
    auto kernel_i8 =
        fbgemm::GenerateEmbeddingSpMDM<uint8_t, IndexType, OffsetType, /*OutType=*/float, /*THREAD_LOCAL=*/true>(
            /*block_size=*/D,
            /*has_weight=*/per_sample_weights_.has_value(),
            /*normalize_by_lengths=*/false,
            /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
            /*is_weight_positional=*/false,
            /*use_offsets=*/true);
    at::parallel_for(
        0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
          bool success = kernel_i8(
              /*output_size=*/end_idx - start_idx,
              /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
              /*data_size=*/N,
              /*input=*/weight_data,
              /*indices=*/indices_data + offsets_data[start_idx],
              /*offsets_or_lengths=*/offsets_data + start_idx,
              /*weights=*/
              per_sample_weights_
                  ? per_sample_weights_.value().const_data_ptr<float>() +
                      offsets_data[start_idx]
                  : nullptr,
              /*out=*/output_data + start_idx * D);
          // FBGEMM signals failure (e.g. out-of-range index) via the return
          // value; surface it with the slice that failed.
          if (!success) {
            fbgemm_spmdm_report_error_(
                end_idx - start_idx,
                offsets_data[end_idx] - offsets_data[start_idx],
                N,
                offsets_data + start_idx,
                indices_data + offsets_data[start_idx]);
          }
        });
  } else {
    // Pruned path: one serial call into the row-wise-sparse kernel, which
    // resolves each index through compressed_indices_mapping_data before
    // reading the compacted weight rows.
    auto kernel_i8_sparse = fbgemm::
        GenerateEmbeddingSpMDMRowWiseSparse<uint8_t, IndexType, OffsetType>(
            /*block_size=*/D,
            /*has_weight=*/per_sample_weights_.has_value(),
            /*normalize_by_lengths=*/false,
            /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
            /*is_weight_positional=*/false,
            /*use_offsets=*/true);
    auto success = kernel_i8_sparse(
        /*output_size=*/output_size,
        /*index_size=*/index_size,
        /*data_size=*/compressed_index_size,
        /*input=*/weight_data,
        /*indices=*/indices_data,
        /*offsets=*/offsets_data,
        /*weights=*/
        per_sample_weights_.has_value()
            ? per_sample_weights_.value().data_ptr<float>()
            : nullptr,
        /*output=*/output_data,
        /*compressed_indices_table=*/compressed_indices_mapping_data);
    if (!success) {
      fbgemm_spmdm_report_error_(
          output_size,
          index_size,
          compressed_index_size,
          offsets_data,
          indices_data);
    }
  }
  return output;
#else
#if defined(__aarch64__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
  // NEON fast path: taken only when the lookup is effectively non-pruned
  // (either pruned_weights is false or the [0] fallback sentinel was seen)
  // and no per-sample weights are supplied.
  if (!(pruned_weights && !fallback_to_no_sparse) && !per_sample_weights_.has_value()) {
    return embedding_lookup_byte_neon_impl<IndexType, OffsetType>(
        weight,
        indices,
        offsets,
        output,
        D,
        output_size,
        include_last_offset);
  }
#endif
  // Generic scalar fallback.  The `8, 1` template arguments presumably mean
  // 8-bit values / 1 element per byte -- confirm against the parameter names
  // of embedding_lookup_fallback_impl.
  return embedding_lookup_fallback_impl<IndexType, OffsetType, 8, 1>(
      weight,
      indices,
      offsets,
      per_sample_weights_,
      compressed_indices_mapping,
      output,
      D,
      output_size,
      include_last_offset,
      (pruned_weights && !fallback_to_no_sparse));
#endif
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free