serial_vec_logsoftmax_range (reduced-precision overload) — pytorch Architecture
Architecture documentation for the serial_vec_logsoftmax_range function template in LogSoftmaxKernelImpl.h from the pytorch codebase. (Note: this is a free function template, not a class; std::is_same_v appears only inside its std::enable_if_t return type, which selects this overload for reduced-precision scalar types.)
Entity Profile
Source Code
aten/src/ATen/native/cpu/LogSoftmaxKernelImpl.h lines 214–336
// Serial, vectorized log_softmax over the reduction dimension for a range of
// work chunks [begin, end).
//
// This overload is selected (via enable_if) only when scalar_t is NOT its own
// opmath type, i.e. for reduced-precision dtypes whose math type is float
// (the code below converts through vec::convert_to_float / convert_from_float
// and accumulates in float buffers).
//
// Work decomposition: each index i in [begin, end) identifies one
// (outer_idx, chunk) pair, where the inner dimension is split into num_chunks
// chunks of at most chunk_size elements. For each pair the classic
// three-pass log_softmax is computed along dim_size:
//   1) running max over the dim axis,
//   2) sum of exp(x - max),
//   3) output = x - max - log(sum).
//
// Parameters:
//   input_data_base / output_data_base — base pointers of the full tensors,
//     laid out as [outer, dim_size, inner_size] (row-major, as implied by the
//     `outer_idx * dim_size * inner_size + dim_idx * inner_size` addressing).
//   inner_size  — extent of the innermost (contiguous) dimension.
//   chunk_size  — max elements of the inner dimension processed per chunk.
//   num_chunks  — number of chunks the inner dimension is split into.
//   dim_size    — extent of the softmax reduction dimension.
//   begin, end  — half-open range of (outer, chunk) work items to process;
//     typically a per-thread slice handed out by a parallel_for caller
//     (not visible here — confirm against the call site).
template <typename scalar_t>
std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
serial_vec_logsoftmax_range(
const scalar_t* input_data_base,
scalar_t* output_data_base,
int64_t inner_size,
int64_t chunk_size,
int64_t num_chunks,
int64_t dim_size,
int64_t begin,
int64_t end) {
using Vec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
// Scratch: per-inner-element running max and running sum, both in float.
// One allocation split in two halves of chunk_size each.
auto buffer = std::make_unique<float []>(chunk_size * 2);
float* input_max_data = buffer.get();
float* tmp_sum_data = buffer.get() + chunk_size;
// thread local buffer that holds input data in float32 to save next 2 dtype conversion
auto input_buffer = std::make_unique<float []>(dim_size * chunk_size);
float* input_buffer_data = input_buffer.get();
// init
for (int64_t i = begin; i < end; i++) {
// Decode the work item: which outer row and which inner chunk.
int64_t outer_idx = i / num_chunks;
int64_t k = i % num_chunks;
int64_t inner_idx_begin = k * chunk_size;
// The last chunk of the inner dimension may be short.
int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);
// Reset scratch: max := -inf, sum := 0, for the `size` active lanes.
// One Vec of scalar_t widens to two fVec of float (see convert_to_float
// below), so each vector iteration covers Vec::size() floats via two
// fVec stores.
fVec zero_fvec = fVec(float(0));
fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
int64_t d0 = 0;
for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
min_fvec.store(input_max_data + d0);
min_fvec.store(input_max_data + d0 + fVec::size());
zero_fvec.store(tmp_sum_data + d0);
zero_fvec.store(tmp_sum_data + d0 + fVec::size());
}
// Scalar tail for the remaining (size % Vec::size()) elements.
for (; d0 < size; d0++) {
input_max_data[d0] = -std::numeric_limits<float>::infinity();
tmp_sum_data[d0] = float(0);
}
// compute max
// Pass 1: elementwise running max over the dim axis. While loading, the
// float-converted input is also cached into input_buffer so passes 2 and
// 3 avoid re-converting from scalar_t.
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
int64_t d1 = 0;
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
Vec data_vec = Vec::loadu(input_ptr + d1);
// Widen one scalar_t vector into two float vectors.
auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
fVec max_fvec0 = fVec::loadu(input_max_data + d1);
fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
// blendv(a, b, mask): pick b where mask is set — i.e. take the new
// value only where it exceeds the current max.
max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
max_fvec0.store(input_max_data + d1);
max_fvec1.store(input_max_data + d1 + fVec::size());
// cache the 'converted' float input
data_fvec0.store(input_buffer_ptr + d1);
data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
}
// Scalar tail: same max + cache logic, one element at a time.
for (; d1 < size; d1++) {
float data_val = float(input_ptr[d1]);
float max_val = input_max_data[d1];
input_max_data[d1] = data_val > max_val ? data_val : max_val;
input_buffer_ptr[d1] = data_val;
}
}
// compute sum of (x - max).exp()
// Pass 2: reads the cached float input, so no dtype conversion here.
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
int64_t d2 = 0;
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
fVec max_fvec0 = fVec::loadu(input_max_data + d2);
fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
sum_fvec0 += (data_fvec0 - max_fvec0).exp();
sum_fvec1 += (data_fvec1 - max_fvec1).exp();
sum_fvec0.store(tmp_sum_data + d2);
sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
}
// Scalar tail.
for (; d2 < size; d2++) {
float data_val = input_buffer_ptr[d2];
float max_val = input_max_data[d2];
tmp_sum_data[d2] += std::exp(data_val - max_val);
}
}
// apply log
// tmp_sum now holds log(sum(exp(x - max))) per inner element.
vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
// compute x - max - sum
// Pass 3: final output, narrowed back to scalar_t on store.
for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
+ dim_idx * inner_size + inner_idx_begin;
int64_t d3 = 0;
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
fVec max_fvec0 = fVec::loadu(input_max_data + d3);
fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
// Narrow two float vectors back to one scalar_t vector.
Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
out_vec.store(output_ptr + d3);
}
// Scalar tail: compute in float, cast on assignment.
for (; d3 < size; d3++) {
output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
}
}
}
} // namespace CPU_CAPABILITY
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free