nll_loss2d_backward_out_frame Function — pytorch Architecture
Architecture documentation for the nll_loss2d_backward_out_frame function template in LossNLL2d.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/LossNLL2d.cpp lines 282–364
// Computes the gradient of 2-D negative log-likelihood loss w.r.t. the input
// log-probabilities, on CPU, for one scalar type.
//
// For each spatial location with target class t (and t != ignore_index), the
// gradient written into grad_input[b][t][h][w] is:
//   -weight[t] * g        (or -g when no weight tensor is given)
// where g is the per-element grad_output value in the Reduction::None case,
// or the scalar grad_output — divided by total_weight for Reduction::Mean —
// in the reduced cases. All other grad_input entries are left untouched
// (the caller is expected to have zero-initialized grad_input).
//
// Parameters:
//   grad_input   - output: (N, C, H, W) gradient buffer, written in place
//   grad_output  - upstream gradient; (N, H, W) for Reduction::None,
//                  otherwise a single-element tensor
//   input        - (N, C, H, W) forward input; only used for its sizes here
//   target       - (N, H, W) int64 class indices
//   weight       - optional per-class weights (may be undefined)
//   reduction    - at::Reduction::{None, Mean, Sum}
//   ignore_index - target value whose locations contribute no gradient
//   total_weight - single-element tensor; divisor for Mean reduction
template <typename scalar_t>
void nll_loss2d_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  auto weight_contiguous = optional_contiguous(weight);
  // nullptr when no weight tensor was supplied.
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == at::Reduction::None) {
    check_gradout_shape_nll_loss2d(grad_output, target);

    const int64_t batch_size = input.size(0);
    const int64_t n_classes = input.size(1);
    const int64_t H = input.size(2);
    const int64_t W = input.size(3);

    auto grad_input_acc = grad_input.accessor<scalar_t, 4>();
    auto grad_output_acc = grad_output.accessor<const scalar_t, 3>();
    auto target_acc = target.accessor<const int64_t, 3>();

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto b : c10::irange(start, end)) {
        for (const auto h : c10::irange(H)) {
          for (const auto w : c10::irange(W)) {
            const int64_t cur_target = target_acc[b][h][w];
            if (cur_target == ignore_index) {
              continue;
            }
            // Validate the class index before writing, matching the check in
            // the reduced branch below; an out-of-range target must raise an
            // index error rather than corrupt memory.
            TORCH_CHECK_INDEX(
                cur_target >= 0 && cur_target < n_classes,
                "Target ", cur_target, " is out of bounds.");
            const scalar_t value =
                -(weight_data ? weight_data[cur_target]
                              : static_cast<scalar_t>(1));
            const scalar_t grad_output_value = grad_output_acc[b][h][w];
            grad_input_acc[b][cur_target][h][w] = value * grad_output_value;
          }
        }
      }
    });

    return;
  }

  const scalar_t total_weight_value = *total_weight.const_data_ptr<scalar_t>();

  // Reduced case: upstream gradient must be a single scalar.
  TORCH_CHECK(
      grad_output.dim() <= 1 && grad_output.numel() == 1,
      "Expected a single element grad_output tensor, but got: ",
      grad_output.sizes());
  const scalar_t grad_output_value = *grad_output.const_data_ptr<scalar_t>();

  const auto target_contiguous = target.contiguous();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();

  scalar_t* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  const int64_t batch_size = input.size(0);
  const int64_t n_classes = input.size(1);
  const int64_t map_size = input.size(2) * input.size(3);
  const int64_t sample_size = map_size * n_classes;

  // For Mean reduction the scalar gradient is spread over the accumulated
  // (possibly weighted) element count; for Sum it is applied as-is.
  const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
                                                   : grad_output_value);
  at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      for (const auto elem : c10::irange(map_size)) {
        const int64_t t = target_data[b * map_size + elem];

        if (t != ignore_index) {
          TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");

          // Flat index into the contiguous (N, C, H*W) layout.
          const int64_t index = b * sample_size + t * map_size + elem;
          grad_input_data[index] = weight_data != nullptr ? weight_data[t] * grad
                                                          : grad;
        }
      }
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free