embedding_lookup_byte_neon_impl Function Template — pytorch Architecture
Architecture documentation for the embedding_lookup_byte_neon_impl function template in qembeddingbag.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp lines 280–597
// Sum-pooled embedding-bag lookup over 8-bit row-wise quantized weights,
// vectorized with ARM NEON.
//
// For every bag m in [0, output_size) this writes block_size floats at
// output row m:
//   out[m][j] = sum over rows r in bag m of (q_r[j] * scale_r + bias_r)
// where q_r[j] is byte j of weight row r.
//
// Row layout, as read below: each weight row is weight.size(1) bytes.
// Quantized uint8 values start at byte 0. A float scale sits at byte offset
// weight.size(1) - 8 and a float bias at weight.size(1) - 4.
// NOTE(review): assumes block_size <= weight.size(1) - 2 * sizeof(float) so the
// value reads never overlap the scale/bias trailer; presumably guaranteed by the
// caller. Confirm.
//
// Bag lengths come from consecutive differences of `offsets`. When
// include_last_offset is false, the last bag runs to indices.numel(). Indices
// are consumed sequentially starting at position 0 of `indices`; offsets[0]
// is only used as the base for the first difference.
//
// Rows of a bag are processed in groups of 16, then 8, 4, 2 and 1. Each pass
// over the output block thus amortizes the accumulator load/store across
// several rows.
//
// Throws (via TORCH_CHECK) if a bag would read past the end of `indices`, or if
// an index used by a bag lies outside [0, weight.size(0)).
// Returns `output`.
template <
    typename IndexType,
    typename OffsetType>
at::Tensor& embedding_lookup_byte_neon_impl(
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    at::Tensor& output,
    const int64_t block_size,
    const int64_t output_size,
    bool include_last_offset) {
  auto* output_data = output.data_ptr<float>();
  const auto weight_data = weight.data_ptr<uint8_t>();
  const auto indices_data = indices.data_ptr<IndexType>();
  const auto weight_sizes = weight.sizes();
  const int64_t num_rows = weight_sizes[0];
  const int64_t weight_size = weight_sizes[1];
  // int64_t, not int: numel() may exceed INT_MAX, and truncation would make
  // the per-bag bounds check below meaningless.
  const int64_t index_size = indices.numel();

  // Reads a float stored at an arbitrary byte address. The scale/bias trailer
  // lives at row_start + weight_size - 8. That address need not be 4-byte
  // aligned (e.g. when weight.size(1) is not a multiple of 4), so the previous
  // `*(float*)ptr` was undefined behavior (misalignment + strict aliasing).
  // memcpy is the portable form and compiles to a single load.
  const auto load_f32 = [](const uint8_t* p) {
    float v;
    memcpy(&v, p, sizeof(v));
    return v;
  };

  // lengths_data[m] = number of indices in bag m.
  auto accessor = offsets.accessor<OffsetType, 1>();
  std::vector<OffsetType> lengths_data;
  lengths_data.reserve(offsets.numel());
  int64_t lower = accessor[0];
  for (const auto i : c10::irange(1, offsets.numel())) {
    lengths_data.push_back(accessor[i] - lower);
    lower = accessor[i];
  }
  if (!include_last_offset) {
    // Without a trailing offset the last bag extends to the end of `indices`.
    lengths_data.push_back(indices.numel() - lower);
  }

  int64_t current = 0;  // position in `indices` of the next row to consume
  load_output_neon load_output;
  store_output_neon store_output;
  add_bias_neon add_bias;
  auto zero_u8 = vdupq_n_u8(0);
  for (const auto m : c10::irange(output_size)) {
    // Each bag accumulates into a freshly zeroed output row.
    memset(output_data, 0, block_size * sizeof(float));
    TORCH_CHECK(
        current + lengths_data[m] <= index_size,
        "Expect the lengths data to be less than indices size");
    // Validate every index this bag will dereference before touching weight
    // memory; an out-of-range index would otherwise read outside `weight`.
    for (int64_t k = current; k < current + lengths_data[m]; ++k) {
      const int64_t idx = indices_data[k];
      TORCH_CHECK(
          idx >= 0 && idx < num_rows,
          "embedding_lookup_byte_neon_impl: index ", idx,
          " is out of bounds for weight with ", num_rows, " rows");
    }
    int64_t i = 0;  // rows of bag m consumed so far

    // ---- groups of 16 rows ----
    while (i + 15 < lengths_data[m]) {
      uint8_t* wei_ptr[16];
      // Biases are additive per row, so their sum is added once per element.
      float bias = 0.0f;
      float scale[16];
      float32x4_t scale_vec[16];
      for (int r = 0; r < 16; ++r) {
        wei_ptr[r] = weight_data + indices_data[current + r] * weight_size;
        bias += load_f32(wei_ptr[r] + weight_size - sizeof(float));
        scale[r] = load_f32(wei_ptr[r] + weight_size - 2 * sizeof(float));
        scale_vec[r] = vdupq_n_f32(scale[r]);
      }
      auto bias_vec = vdupq_n_f32(bias);
      uint32_t j = 0;
      // 16 output columns per step.
      while (j + 15 < block_size) {
        float32x4x4_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
#if defined(__GNUC__)
#pragma GCC unroll 16
#elif defined(__clang__)
#pragma clang loop unroll_count(16)
#endif
        for (uint32_t jj = 0; jj < 16; ++jj) {
          embedding_neon_kernel(wei_ptr[jj] + j, output.val[0], output.val[1], output.val[2], output.val[3], scale_vec[jj], zero_u8);
        }
        store_output(output, output_data, j);
        j += 16;
      }
      // 8 output columns per step.
      while (j + 7 < block_size) {
        float32x4x2_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
#if defined(__GNUC__)
#pragma GCC unroll 16
#elif defined(__clang__)
#pragma clang loop unroll_count(16)
#endif
        for (uint32_t jj = 0; jj < 16; ++jj) {
          embedding_neon_kernel(wei_ptr[jj] + j, output.val[0], output.val[1], scale_vec[jj], zero_u8);
        }
        store_output(output, output_data, j);
        j += 8;
      }
      // Scalar tail for the remaining < 8 columns.
      while (j < block_size) {
        output_data[j] += bias;
        for (uint32_t jj = 0; jj < 16; ++jj) {
          output_data[j] += (float)(*(wei_ptr[jj] + j)) * scale[jj];
        }
        j++;
      }
      i += 16;
      current += 16;
    }

    // ---- groups of 8 rows ----
    while (i + 7 < lengths_data[m]) {
      uint8_t* wei_ptr[8];
      float bias = 0.0f;
      float scale[8];
      float32x4_t scale_vec[8];
      for (int r = 0; r < 8; ++r) {
        wei_ptr[r] = weight_data + indices_data[current + r] * weight_size;
        bias += load_f32(wei_ptr[r] + weight_size - sizeof(float));
        scale[r] = load_f32(wei_ptr[r] + weight_size - 2 * sizeof(float));
        scale_vec[r] = vdupq_n_f32(scale[r]);
      }
      auto bias_vec = vdupq_n_f32(bias);
      uint32_t j = 0;
      while (j + 15 < block_size) {
        float32x4x4_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
#if defined(__GNUC__)
#pragma GCC unroll 8
#elif defined(__clang__)
#pragma clang loop unroll_count(8)
#endif
        for (uint32_t jj = 0; jj < 8; ++jj) {
          embedding_neon_kernel(wei_ptr[jj] + j, output.val[0], output.val[1], output.val[2], output.val[3], scale_vec[jj], zero_u8);
        }
        store_output(output, output_data, j);
        j += 16;
      }
      while (j + 7 < block_size) {
        float32x4x2_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
#if defined(__GNUC__)
#pragma GCC unroll 8
#elif defined(__clang__)
#pragma clang loop unroll_count(8)
#endif
        for (uint32_t jj = 0; jj < 8; ++jj) {
          embedding_neon_kernel(wei_ptr[jj] + j, output.val[0], output.val[1], scale_vec[jj], zero_u8);
        }
        store_output(output, output_data, j);
        j += 8;
      }
      while (j < block_size) {
        output_data[j] += bias;
        for (uint32_t jj = 0; jj < 8; ++jj) {
          output_data[j] += (float)(*(wei_ptr[jj] + j)) * scale[jj];
        }
        j++;
      }
      i += 8;
      current += 8;
    }

    // ---- groups of 4 rows ----
    while (i + 3 < lengths_data[m]) {
      uint8_t* wei_ptr[4];
      float bias = 0.0f;
      float scale[4];
      float32x4_t scale_vec[4];
      for (int r = 0; r < 4; ++r) {
        wei_ptr[r] = weight_data + indices_data[current + r] * weight_size;
        bias += load_f32(wei_ptr[r] + weight_size - sizeof(float));
        scale[r] = load_f32(wei_ptr[r] + weight_size - 2 * sizeof(float));
        scale_vec[r] = vdupq_n_f32(scale[r]);
      }
      auto bias_vec = vdupq_n_f32(bias);
      uint32_t j = 0;
      while (j + 15 < block_size) {
        float32x4x4_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
#if defined(__GNUC__)
#pragma GCC unroll 4
#elif defined(__clang__)
#pragma clang loop unroll_count(4)
#endif
        for (uint32_t jj = 0; jj < 4; ++jj) {
          embedding_neon_kernel(wei_ptr[jj] + j, output.val[0], output.val[1], output.val[2], output.val[3], scale_vec[jj], zero_u8);
        }
        store_output(output, output_data, j);
        j += 16;
      }
      while (j + 7 < block_size) {
        float32x4x2_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
#if defined(__GNUC__)
#pragma GCC unroll 4
#elif defined(__clang__)
#pragma clang loop unroll_count(4)
#endif
        for (uint32_t jj = 0; jj < 4; ++jj) {
          embedding_neon_kernel(wei_ptr[jj] + j, output.val[0], output.val[1], scale_vec[jj], zero_u8);
        }
        store_output(output, output_data, j);
        j += 8;
      }
      while (j < block_size) {
        output_data[j] += bias;
        for (uint32_t jj = 0; jj < 4; ++jj) {
          output_data[j] += (float)(*(wei_ptr[jj] + j)) * scale[jj];
        }
        j++;
      }
      i += 4;
      current += 4;
    }

    // ---- groups of 2 rows ----
    while (i + 1 < lengths_data[m]) {
      uint8_t* wei_ptr[2];
      float bias = 0.0f;
      float scale[2];
      float32x4_t scale_vec[2];
      for (int r = 0; r < 2; ++r) {
        wei_ptr[r] = weight_data + indices_data[current + r] * weight_size;
        bias += load_f32(wei_ptr[r] + weight_size - sizeof(float));
        scale[r] = load_f32(wei_ptr[r] + weight_size - 2 * sizeof(float));
        scale_vec[r] = vdupq_n_f32(scale[r]);
      }
      auto bias_vec = vdupq_n_f32(bias);
      uint32_t j = 0;
      while (j + 15 < block_size) {
        float32x4x4_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
        embedding_neon_kernel(wei_ptr[0] + j, output.val[0], output.val[1], output.val[2], output.val[3], scale_vec[0], zero_u8);
        embedding_neon_kernel(wei_ptr[1] + j, output.val[0], output.val[1], output.val[2], output.val[3], scale_vec[1], zero_u8);
        store_output(output, output_data, j);
        j += 16;
      }
      while (j + 7 < block_size) {
        float32x4x2_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
        embedding_neon_kernel(wei_ptr[0] + j, output.val[0], output.val[1], scale_vec[0], zero_u8);
        embedding_neon_kernel(wei_ptr[1] + j, output.val[0], output.val[1], scale_vec[1], zero_u8);
        store_output(output, output_data, j);
        j += 8;
      }
      while (j < block_size) {
        output_data[j] += bias;
        output_data[j] += (float)(*(wei_ptr[0] + j)) * scale[0];
        output_data[j] += (float)(*(wei_ptr[1] + j)) * scale[1];
        j++;
      }
      i += 2;
      current += 2;
    }

    // ---- single remaining row ----
    while (i < lengths_data[m]) {
      auto wei_ptr = weight_data + indices_data[current] * weight_size;
      float bias = load_f32(wei_ptr + weight_size - sizeof(float));
      auto scale = load_f32(wei_ptr + weight_size - 2 * sizeof(float));
      auto bias_vec = vdupq_n_f32(bias);
      auto scale_vec = vdupq_n_f32(scale);
      uint32_t j = 0;
      while (j + 15 < block_size) {
        float32x4x4_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
        embedding_neon_kernel(wei_ptr + j, output.val[0], output.val[1], output.val[2], output.val[3], scale_vec, zero_u8);
        store_output(output, output_data, j);
        j += 16;
      }
      while (j + 7 < block_size) {
        float32x4x2_t output;
        load_output(output, output_data, j);
        add_bias(output, bias_vec);
        embedding_neon_kernel(wei_ptr + j, output.val[0], output.val[1], scale_vec, zero_u8);
        store_output(output, output_data, j);
        j += 8;
      }
      while (j < block_size) {
        output_data[j] += bias;
        output_data[j] += (float)(*(wei_ptr + j)) * scale;
        j++;
      }
      ++i;
      ++current;
    }
    // Advance to the next bag's output row.
    output_data += block_size;
  }  // for each m
  return output;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free