nll_loss_backward_out_frame Function — pytorch Architecture
Architecture documentation for the nll_loss_backward_out_frame function template in LossNLL.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/LossNLL.cpp lines 340–413
// CPU frame for the NLL-loss backward pass.
//
// For each sample i with target class t = target[i] (skipped when
// t == ignore_index), the gradient w.r.t. the log-probability input is
//   grad_input[i][t] = -weight[t] * g
// where g is grad_output[i] for Reduction::None, and otherwise the scalar
// grad_output (divided by total_weight for Reduction::Mean). Entries of
// grad_input for non-target classes are never written here (presumably the
// caller zero-fills grad_input first — confirm at the call site).
//
// grad_input:   destination tensor, same shape as `input` (1-D or 2-D).
// grad_output:  upstream gradient; per-sample 1-D for Reduction::None,
//               a scalar otherwise.
// input:        forward-pass input; only its sizes are consulted here.
// target:       class indices, 0-D or 1-D (0-D is viewed as 1-D below).
// weight:       optional per-class rescaling weights (may be undefined).
// reduction:    one of Reduction::{None, Mean, Sum}.
// ignore_index: samples with this target contribute zero gradient.
// total_weight: scalar normalizer used by the Mean reduction.
template <typename scalar_t, typename target_t>
void nll_loss_backward_out_frame(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  const auto n_dims = input.dim();
  const auto n_classes = input.size(-1);

  // View a 0-dim (single-sample) target as 1-dim so one accessor shape
  // covers both the batched and unbatched cases.
  auto target_ = target;
  if (target.dim() == 0) {
    target_ = target.unsqueeze(0);
  }
  auto target_acc = target_.accessor<const target_t, 1>();

  // weight_data is nullptr when no per-class weights were supplied.
  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data =
      optional_data<const scalar_t>(weight_contiguous);

  if (reduction == Reduction::None && n_dims == 2) {
    // Unreduced, batched path: one upstream gradient per sample.
    const auto batch_size = input.size(0);
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        auto cur_target = target_acc[i];
        if (cur_target == ignore_index) {
          continue;
        }
        // Validate the class index before using it to index weight_data and
        // grad_input — the reduced branches below perform the same check; an
        // out-of-range target would otherwise write out of bounds.
        TORCH_CHECK_INDEX(
            cur_target >= 0 && cur_target < n_classes,
            "Target ", cur_target, " is out of bounds.");
        const scalar_t w =
            weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
        grad_input_acc[i][cur_target] = -w * grad_output_acc[i];
      }
    });
    return;
  }

  // Reduced (Sum/Mean) paths: grad_output and total_weight are scalars.
  const scalar_t total_weight_value = *total_weight.const_data_ptr<scalar_t>();
  const scalar_t grad_output_value = *grad_output.const_data_ptr<scalar_t>();

  if (input.dim() == 1) {
    // Unbatched sample.
    auto grad_input_acc = grad_input.accessor<scalar_t, 1>();
    const auto t = target_acc[0];
    if (t != ignore_index) {
      TORCH_CHECK_INDEX(
          t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
      // Mean reduction spreads the upstream gradient over the total weight.
      const auto grad =
          -(reduction == Reduction::Mean
                ? grad_output_value / total_weight_value
                : grad_output_value);
      grad_input_acc[t] =
          weight_data != nullptr ? weight_data[t] * grad : grad;
    }
  } else if (input.dim() == 2) {
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    // The scalar gradient is identical for every sample; hoist it out of
    // the parallel loop.
    const auto grad =
        -(reduction == Reduction::Mean
              ? grad_output_value / total_weight_value
              : grad_output_value);
    const auto batch_size = input.size(0);
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        const auto t = target_acc[i];
        if (t != ignore_index) {
          TORCH_CHECK_INDEX(
              t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
          grad_input_acc[i][t] =
              weight_data != nullptr ? weight_data[t] * grad : grad;
        }
      }
    });
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free