host_softmax Function Template — pytorch Architecture
Architecture documentation for the host_softmax function template in SoftMax.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/SoftMax.cpp lines 152–243
template <typename scalar_t>
void host_softmax(
    Tensor& output,
    const Tensor& input,
    const int64_t dim,
    bool* mask,
    const std::optional<int64_t> mask_type_) {
  // CPU kernel: masked softmax of `input` along dimension `dim`, written into
  // `output` (same shape as `input`). Entries whose mask value is true are
  // excluded from the max/sum and produce 0 in the output; if every entry of a
  // slice is masked, the whole slice is filled with NaN (0 * NaN).
  //
  // mask_type_ selects how `mask` is indexed relative to `input`:
  //   0 (src_mask)             — LxL attention mask shared across B and H
  //   1 (src_key_padding_mask) — BxL padding mask shared across H and L
  //   2 (default_mask)         — mask has exactly the input's shape
  // For types 0 and 1 the input is assumed to be 4-D BxHxLxL with dim == 3
  // (TODO confirm against callers; the indexing below relies on it).
  TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
  int64_t mask_type = mask_type_.value();
  // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
  TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask), or 2 (default_mask)");

  // View the tensor logically as [outer_size, dim_size, inner_size].
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  for (const auto i : c10::irange(dim)) {
    outer_size *= input.size(i);
  }
  for (int64_t i = dim + 1; i < input.dim(); ++i) {
    inner_size *= input.size(i);
  }
  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
  scalar_t* input_data_base = input.data_ptr<scalar_t>();
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
  bool* mask_data_base = mask;

  // Hoist loop-invariant sizes out of the per-element lambda. Only query the
  // dims the broadcast actually needs, so rank-<3 inputs with mask_type == 2
  // keep working.
  const int64_t seq_len = (mask_type == 0 || mask_type == 1) ? input.size(2) : 1;
  const int64_t heads_x_len = (mask_type == 1) ? input.size(1) * seq_len : 1;

  // Chunk the parallel range so each task covers at least GRAIN_SIZE scalar
  // ops. FIX: the original used std::min, which caps grain_size at 1 and
  // yields 0 when dim_size > GRAIN_SIZE, defeating the chunking; std::max is
  // the clamp used by the other ATen CPU kernels. Partitioning does not
  // change the computed values.
  int64_t grain_size = std::max(internal::GRAIN_SIZE / dim_size, static_cast<int64_t>(1));
  parallel_for(
      0, outer_size * inner_size, grain_size,
      [&](int64_t begin, int64_t end) {
        for (const auto i : c10::irange(begin, end)) {
          int64_t outer_idx = i / inner_size;
          int64_t inner_idx = i % inner_size;
          scalar_t* input_data =
              input_data_base + outer_idx * outer_stride + inner_idx;
          scalar_t* output_data =
              output_data_base + outer_idx * outer_stride + inner_idx;
          // Process mask differently depending on the type:
          // For a generic mask of mask_type == 2, mask shape is the same as
          // the input shape, so indexing is the same.
          auto mask_outer_idx = outer_idx;
          if (mask_type == 0) {
            // Optimized case: attention mask of shape LxL
            // outer_idx goes over BxHxL, mask_outer_idx goes over L.
            mask_outer_idx = outer_idx % seq_len;
          } else if (mask_type == 1) {
            // Optimized case: padding mask of shape BxL
            // outer_idx goes over BxHxL, mask_outer_idx goes over B.
            mask_outer_idx = outer_idx / heads_x_len;
          }
          bool* mask_data = mask_data_base + mask_outer_idx * outer_stride + inner_idx;
          // Max over the unmasked entries (numerical stability for exp).
          bool is_meaningful_max = false;
          scalar_t max_input = input_data[0];
          for (const auto d : c10::irange(0, dim_size)) {
            if (!mask_data[d * dim_stride]) {
              max_input = is_meaningful_max
                  ? std::max(max_input, input_data[d * dim_stride])
                  : input_data[d * dim_stride];
              is_meaningful_max = true;
            }
          }
          // exp(x - max) for unmasked entries, 0 for masked; accumulate sum.
          acc_type<scalar_t, false> tmpsum = 0;
          for (const auto d : c10::irange(dim_size)) {
            scalar_t z{};
            if (!mask_data[d * dim_stride]) {
              z = std::exp(input_data[d * dim_stride] - max_input);
            } else {
              z = 0;
            }
            output_data[d * dim_stride] = z;
            tmpsum += z;
          }
          // Fully masked slice: propagate NaN instead of dividing by zero.
          if (tmpsum == 0) {
            tmpsum = std::numeric_limits<scalar_t>::quiet_NaN();
          } else {
            tmpsum = 1 / tmpsum;
          }
          // Normalize in place.
          for (const auto d : c10::irange(dim_size)) {
            output_data[d * dim_stride] *= tmpsum;
          }
        }
      });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free