nll_loss_out_frame Function Template — PyTorch Architecture
Architecture documentation for the nll_loss_out_frame function template in LossNLL.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/LossNLL.cpp lines 161–301
// Computes NLL loss on CPU for one "frame" of input. `input` is either
// 1-D (C) or 2-D (N, C) per-class scores (presumably log-probabilities,
// as with nn.NLLLoss — this function only negates the selected score;
// TODO confirm against callers). `target` holds class indices in
// [0, n_classes) or `ignore_index`. Results are written into the
// pre-allocated `output` and `total_weight` tensors.
//
// scalar_t: floating-point element type of input/weight/output.
// target_t: integral element type of target.
template <typename scalar_t, typename target_t>
void nll_loss_out_frame(
const Tensor& output,
const Tensor& total_weight,
const Tensor& input,
const Tensor& target,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index) {
const auto n_dims = input.dim();
const auto n_classes = input.size(-1);
scalar_t* total_weight_data = total_weight.data_ptr<scalar_t>();
*total_weight_data = 0;
// Per-class rescaling weights are optional; weight_data is nullptr
// when no weight tensor was supplied.
auto weight_contiguous = optional_contiguous(weight);
const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);
// No-reduction path for batched (N, C) input: one loss value per
// sample, computed in parallel. total_weight is left at 0 here.
if (reduction == Reduction::None && n_dims == 2) {
const auto batch_size = input.size(0);
at::native::resize_output(output, {batch_size});
auto input_acc = input.accessor<const scalar_t, 2>();
auto target_acc = target.accessor<const target_t, 1>();
auto output_acc = output.accessor<scalar_t, 1>();
at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
const auto cur_target = target_acc[i];
// Ignored targets contribute a loss of exactly 0.
if (cur_target == ignore_index) {
output_acc[i] = 0;
continue;
}
TORCH_CHECK_INDEX(
cur_target >= 0 && cur_target < n_classes,
"Target ",
cur_target,
" is out of bounds.");
scalar_t cur_weight = weight_data != nullptr ? weight_data[cur_target]
: static_cast<scalar_t>(1);
output_acc[i] = -input_acc[i][cur_target] * cur_weight;
}
});
return;
}
// produce scalar outputs for the reduction case
at::native::resize_output(output, {});
if (target.numel() == 0) {
// Here target (and input) have zero elements
// Mean reduction on empty tensors produces NaN. See the discussion in
// https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
if (reduction == Reduction::Mean) {
output.fill_(std::numeric_limits<double>::quiet_NaN());
} else {
output.zero_();
}
total_weight.zero_();
return;
}
auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();
const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
const target_t* target_data = target_contiguous.const_data_ptr<target_t>();
const int64_t ndim = input.dim();
// 1-D input is treated as a batch of one sample.
const int64_t batch_size = ndim == 1 ? 1 : input.size(0);
// Cascade (multi-level) summation: partial sums are kept per level and
// merged upward periodically, which reduces floating-point rounding
// error compared to a single running sum over the whole batch.
constexpr int64_t cascade_sum_num_levels = 8;
// Each level spans `level_power` bits of the batch index, so level 0
// accumulates runs of `level_step` (at least 2^4 = 16) samples before
// carrying into level 1, and so on.
const int64_t level_power =
std::max(static_cast<int64_t>(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
const int64_t level_step = (1 << level_power);
const int64_t level_mask = level_step - 1;
int64_t num_ignored = 0;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
for (const auto b : c10::irange(batch_size)) {
const int64_t cur_target = target_data[b];
// Ignored targets are excluded from both the loss and the weight
// normalizer; count them so the unweighted denominator is correct.
if (cur_target == ignore_index) {
++num_ignored;
continue;
}
TORCH_CHECK_INDEX(
cur_target >= 0 && cur_target < n_classes,
"Target ",
cur_target,
" is out of bounds.");
// Contiguous (N, C) layout: score for sample b's target class.
const auto data = input_data[b * n_classes + cur_target];
if (weight_data) {
const scalar_t weight_val = weight_data[cur_target];
loss_partial_sums[0] -= data * weight_val;
weight_partial_sums[0] += weight_val;
} else {
loss_partial_sums[0] -= data;
}
// Carry partial sums upward: level j is flushed into level j+1
// whenever bits [j*level_power, (j+1)*level_power) of b are all
// zero, i.e. the lower levels just rolled over. The common case
// (no carry) breaks out immediately.
for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
const auto mask = (level_mask << (j * level_power));
if (C10_LIKELY((b & mask) != 0)) {
break;
}
weight_partial_sums[j + 1] += weight_partial_sums[j];
loss_partial_sums[j + 1] += loss_partial_sums[j];
weight_partial_sums[j] = 0;
loss_partial_sums[j] = 0;
}
}
// Without a weight tensor the normalizer is the count of non-ignored
// samples; with one it is the sum of the selected per-class weights.
const scalar_t total_weight_val = !weight_data ?
static_cast<scalar_t>(batch_size - num_ignored) :
std::accumulate(std::begin(weight_partial_sums),
std::end(weight_partial_sums),
scalar_t{0});
scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
std::end(loss_partial_sums),
scalar_t{0});
// NOTE(review): if every target is ignored, total_weight_val is 0 and
// mean reduction divides by zero (IEEE inf/NaN) — presumably intended
// to mirror the empty-tensor case above; confirm upstream.
if (reduction == Reduction::Mean) {
output_val /= total_weight_val;
}
// write result to output tensors
*output.data_ptr<scalar_t>() = output_val;
*total_weight_data = total_weight_val;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free