Home / Class / embedding_bag_nbit_impl Class — pytorch Architecture

embedding_bag_nbit_impl Class — pytorch Architecture

Architecture documentation for the `embedding_bag_nbit_impl` function template in qembeddingbag.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp lines 629–799

// Shared implementation of the 2-/4-bit row-wise quantized embedding and
// embedding_bag ops. Looks up the rows of `weight` selected by `indices`,
// groups them into bags delimited by `offsets`, and writes float results into
// `output` (resized in place). With USE_FBGEMM the work is done by an FBGEMM
// generated SpMDM kernel (normalize_by_lengths=false, so bags are not
// averaged); otherwise it is delegated to embedding_lookup_fallback_impl.
//
// Parameters:
//   output              - float destination; resized to [num_bags, D], or to
//                         [indices.size(0), indices.size(1), D] when this is
//                         the embedding op with 2-D indices.
//   weight              - 2-D packed uint8 table; each row holds D values at
//                         `bit_width` bits each plus a 2-byte fp16 scale and a
//                         2-byte fp16 zero offset.
//   bit_width           - bits per quantized value; must be 2 or 4.
//   indices             - row ids (IndexType); 1-D, or 2-D for the embedding op.
//   offsets             - 1-D bag start positions into `indices` (OffsetType).
//   pruned_weights      - when true, `compressed_indices_mapping` is required
//                         and is handed to the row-wise-sparse kernel
//                         (presumably mapping original row ids to rows of the
//                         pruned table — confirm against callers). A mapping
//                         equal to [0] means "use the dense kernel instead".
//   per_sample_weights_ - optional float weights, one per entry of `indices`.
//   include_last_offset - when true, `offsets` already ends with the total
//                         index count and describes offsets.size(0) - 1 bags.
//   is_embedding_op     - selects the 3-D output shape for 2-D indices.
// Returns `output`. Invalid arguments raise via TORCH_CHECK.
template <typename IndexType, typename OffsetType>
at::Tensor& embedding_bag_nbit_impl(
    at::Tensor& output,
    const at::Tensor& weight,
    const int bit_width,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    bool pruned_weights,
    const std::optional<at::Tensor>& per_sample_weights_,
    const std::optional<at::Tensor>& compressed_indices_mapping,
    bool include_last_offset,
    bool is_embedding_op) {
  TORCH_CHECK(
      weight.dim() == 2,
      "embedding_bag_nbit: expected a 2-D packed weight, got ",
      weight.dim(),
      "-D");
  TORCH_CHECK(
      offsets.dim() == 1,
      "embedding_bag_nbit: expected 1-D offsets, got ",
      offsets.dim(),
      "-D");
  // Both the packed-row decoding (`8 / bit_width`) and the non-FBGEMM
  // fallback below only handle these two widths.
  TORCH_CHECK(
      bit_width == 4 || bit_width == 2,
      "embedding_bag_nbit: supports bit_width 2 or 4, got ",
      bit_width);
  if (per_sample_weights_.has_value()) {
    // The kernels read exactly one weight per index position.
    TORCH_CHECK(
        per_sample_weights_->numel() == indices.numel(),
        "embedding_bag_nbit: per_sample_weights must have the same number of "
        "elements as indices (",
        indices.numel(),
        "), got ",
        per_sample_weights_->numel());
  }

  auto offsets_data = offsets.data_ptr<OffsetType>();

  // Get compressed indices for the pruned_weights op.
  int32_t* compressed_indices_mapping_data = nullptr;
  int64_t compressed_index_size = 0; // int64_t: numel() may exceed INT_MAX
  bool fallback_to_no_sparse = false;
  if (pruned_weights) {
    TORCH_CHECK(
        compressed_indices_mapping.has_value(),
        "embedding_bag_nbit: compressed_indices_mapping is required when "
        "pruned_weights is true");
    const at::Tensor& mapping = *compressed_indices_mapping;
    compressed_index_size = mapping.numel();
    compressed_indices_mapping_data = mapping.data_ptr<int32_t>();

    // A mapping of exactly [0] is an indicator that we should fall back to
    // the non-sparse embedding lookup kernel.
    fallback_to_no_sparse =
        compressed_index_size == 1 && compressed_indices_mapping_data[0] == 0;
  }

  // Each packed row carries a 2-byte fp16 scale and a 2-byte fp16 zero offset
  // in addition to the quantized values. Keep the arithmetic signed so a row
  // that is too narrow is reported instead of wrapping through size_t.
  constexpr int64_t kScaleBiasBytes =
      2 * static_cast<int64_t>(sizeof(at::Half));
  const auto weight_sizes = weight.sizes();
  const int64_t weight_size = weight_sizes[1];
  TORCH_CHECK(
      weight_size >= kScaleBiasBytes,
      "embedding_bag_nbit: packed weight rows must have at least ",
      kScaleBiasBytes,
      " bytes for scale and zero offset, got ",
      weight_size);
  const int64_t num_elem_per_byte = 8 / bit_width;
  const int64_t D = (weight_size - kScaleBiasBytes) * num_elem_per_byte;

  const int64_t M = offsets.size(0);
  // Same requirement as native embedding_bag; otherwise output_size is -1.
  TORCH_CHECK(
      !include_last_offset || M >= 1,
      "include_last_offset: number of offset should be at least 1");

  const int64_t output_size = include_last_offset ? M - 1 : M;
  std::vector<OffsetType> offsets_include_last_val;
  if (!include_last_offset) {
    // The kernels expect M + 1 offsets, the last one being the total number of
    // indices, so build that array locally.
    offsets_include_last_val.resize(M + 1);
    // Avoid `null pointer passed as argument 2` ASAN violation when offsets
    // tensor is empty.
    if (M > 0) {
      std::memcpy(
          offsets_include_last_val.data(),
          offsets_data,
          sizeof(OffsetType) * M);
    }
    offsets_include_last_val[M] = indices.numel();
    offsets_data = offsets_include_last_val.data();
  }

  {
    // `shape` views `shape_arr`, which must outlive the resize_ call.
    std::array<int64_t, 3> shape_arr{};
    c10::IntArrayRef shape;
    if (indices.dim() == 2 && is_embedding_op) {
      shape_arr = {indices.size(0), indices.size(1), D};
      shape = shape_arr;
    } else {
      shape_arr[0] = output_size;
      shape_arr[1] = D;
      shape = c10::IntArrayRef(shape_arr.data(), 2);
    }
    at::native::resize_(output, shape, std::nullopt);
  }

#ifdef USE_FBGEMM
  const auto indices_data = indices.data_ptr<IndexType>();
  const auto weight_data = weight.data_ptr<uint8_t>();
  auto* output_data = output.data_ptr<float>();
  const float* per_sample_weights_data = per_sample_weights_.has_value()
      ? per_sample_weights_->data_ptr<float>()
      : nullptr;
  const int64_t N = weight_sizes[0];

  const int64_t block_size = D;
  // int64_t: the FBGEMM kernel takes int64_t sizes; int would truncate.
  const int64_t index_size = indices.numel();
  constexpr int prefetch_distance = 16;
  if (!pruned_weights || fallback_to_no_sparse) {
    // Generate the fbgemm kernel
    auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit<IndexType, OffsetType>(
        /*bit rate=*/bit_width,
        /*block size=*/block_size,
        /*has weights=*/per_sample_weights_.has_value(),
        /*normalize_by_lengths=*/false,
        /*prefetch distance=*/prefetch_distance,
        /*is_weight_positional=*/false,
        /*use_offsets=*/true);

    const bool success = kernel(
        /*output_size=*/output_size,
        /*index_size=*/index_size,
        /*data_size=*/N,
        /*input=*/weight_data,
        /*indices=*/indices_data,
        /*offsets=*/offsets_data,
        /*weights=*/per_sample_weights_data,
        /*output=*/output_data);

    if (!success) {
      fbgemm_spmdm_report_error_(
          output_size, index_size, N, offsets_data, indices_data);
    }
  } else {
    auto kernel =
        fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<IndexType, OffsetType>(
            /*bit rate=*/bit_width,
            /*block_size=*/block_size,
            /*has weights=*/per_sample_weights_.has_value(),
            /*normalize_by_lengths=*/false,
            /*prefetch distance*/ prefetch_distance,
            /*is_weight_positional*/ false,
            /*use_offsets*/ true);
    const bool success = kernel(
        /*output_size=*/output_size,
        /*index_size=*/index_size,
        /*data_size=*/compressed_index_size,
        /*input=*/weight_data,
        /*indices=*/indices_data,
        /*offsets=*/offsets_data,
        /*weights=*/per_sample_weights_data,
        /*output=*/output_data,
        /*compressed_indices_table=*/compressed_indices_mapping_data);
    if (!success) {
      fbgemm_spmdm_report_error_(
          output_size,
          index_size,
          compressed_index_size,
          offsets_data,
          indices_data);
    }
  }
  return output;
#else
  const bool use_sparse_lookup = pruned_weights && !fallback_to_no_sparse;
  if (bit_width == 4) {
    return embedding_lookup_fallback_impl<IndexType, OffsetType, 4, 2>(
      weight,
      indices,
      offsets,
      per_sample_weights_,
      compressed_indices_mapping,
      output,
      D,
      output_size,
      include_last_offset,
      use_sparse_lookup);
  }
  // bit_width == 2 (validated above)
  return embedding_lookup_fallback_impl<IndexType, OffsetType, 2, 4>(
    weight,
    indices,
    offsets,
    per_sample_weights_,
    compressed_indices_mapping,
    output,
    D,
    output_size,
    include_last_offset,
    use_sparse_lookup);
#endif
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free