host_softmax_backward Function — pytorch Architecture
Architecture documentation for the host_softmax_backward function template in SoftMax.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/SoftMax.cpp lines 245–300
/// Computes the softmax backward pass along `dim`, writing into `gI`.
///
/// For every 1-D slice taken along `dim` the function evaluates
///   sum   = Σ_d grad[d] * output[d]            (over non-masked d)
///   gI[d] = output[d] * (grad[d] - sum)        (non-masked d)
///   gI[d] = 0                                  (masked d)
///
/// @param gI      Destination tensor for the input gradient.
/// @param grad    Gradient with respect to the softmax output.
/// @param output  Softmax output from the forward pass.
/// @param dim     Dimension along which softmax was taken.
/// @param mask    Optional boolean buffer; a `true` entry excludes that
///                element from the sum and zeroes its gradient. When null,
///                no element is treated as masked.
///
/// Elements are addressed with strides derived from `grad`'s sizes rather
/// than from the tensors' own strides, so `gI`, `grad`, `output` (and `mask`,
/// if given) are assumed to be densely packed with identical shapes; this is
/// not checked here.
template <typename scalar_t>
void host_softmax_backward(
    const Tensor& gI,
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    bool* mask = nullptr) {
  const int64_t dim_size = grad.size(dim);
  // An empty reduction dimension means there is nothing to write, and it
  // would otherwise cause a division by zero in the grain-size computation.
  if (dim_size == 0) {
    return;
  }

  int64_t outer_size = 1;
  for (const auto i : c10::irange(dim)) {
    outer_size *= grad.size(i);
  }
  int64_t inner_size = 1;
  for (int64_t i = dim + 1; i < grad.dim(); ++i) {
    inner_size *= grad.size(i);
  }

  const int64_t dim_stride = inner_size;
  const int64_t outer_stride = dim_size * dim_stride;

  scalar_t* gradInput_data_base = gI.data_ptr<scalar_t>();
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
  scalar_t* gradOutput_data_base = grad.data_ptr<scalar_t>();
  const bool* mask_data_base = mask;
  // The mask is optional (defaults to nullptr); never dereference it blindly.
  const bool has_mask = mask_data_base != nullptr;

  // Each work item touches `dim_size` elements, so scale the grain size down
  // accordingly while keeping it at least 1. (Using std::min here would cap
  // the grain at 1 and make the heuristic meaningless.)
  const int64_t grain_size =
      std::max(internal::GRAIN_SIZE / dim_size, static_cast<int64_t>(1));

  parallel_for(
      0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
        for (const auto i : c10::irange(begin, end)) {
          const int64_t outer_idx = i / inner_size;
          const int64_t inner_idx = i % inner_size;
          // Offset of the first element of this slice along `dim`.
          const int64_t slice_offset = outer_idx * outer_stride + inner_idx;

          scalar_t* gradInput_data = gradInput_data_base + slice_offset;
          const scalar_t* output_data = output_data_base + slice_offset;
          const scalar_t* gradOutput_data = gradOutput_data_base + slice_offset;
          const bool* mask_data =
              has_mask ? mask_data_base + slice_offset : nullptr;

          // Accumulate in the wider accumulation type for scalar_t.
          acc_type<scalar_t, false> sum = 0;
          for (const auto d : c10::irange(dim_size)) {
            if (!has_mask || !mask_data[d * dim_stride]) {
              sum +=
                  gradOutput_data[d * dim_stride] * output_data[d * dim_stride];
            }
          }

          for (const auto d : c10::irange(dim_size)) {
            if (has_mask && mask_data[d * dim_stride]) {
              gradInput_data[d * dim_stride] = 0;
            } else {
              gradInput_data[d * dim_stride] = output_data[d * dim_stride] *
                  (gradOutput_data[d * dim_stride] - sum);
            }
          }
        }
      });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free