Home / Class / nll_loss2d_backward_out_frame Class — pytorch Architecture

nll_loss2d_backward_out_frame Class — pytorch Architecture

Architecture documentation for the nll_loss2d_backward_out_frame function template in LossNLL2d.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LossNLL2d.cpp lines 282–364

// CPU backward kernel for the 2-D negative log-likelihood loss, instantiated
// per element type `scalar_t`.
//
// For every pixel (b, h, w) whose target class t is not `ignore_index`, it
// assigns grad_input[b][t][h][w]. Every other entry of grad_input is left
// untouched. NOTE(review): this presumes the caller pre-zeroes grad_input.
// Confirm at the call site.
//
// Parameters:
//   grad_input   - output gradient w.r.t. `input`, written in place. It is
//                  accessed as a 4-D tensor (N, C, H, W).
//   grad_output  - upstream gradient. With reduction None it is accessed as
//                  3-D (N, H, W). Otherwise it must hold exactly one element
//                  (checked below).
//   input        - forward input. Only its sizes are read:
//                  N = size(0), C = size(1), H = size(2), W = size(3).
//   target       - int64 class indices, accessed as 3-D (N, H, W).
//   weight       - optional per-class weights. When the resulting data
//                  pointer is null, every class weighs 1.
//   reduction    - at::Reduction::None, Mean, or any other value
//                  (treated as Sum).
//   ignore_index - target value whose pixels receive no gradient.
//   total_weight - single-element tensor. Its value divides the gradient
//                  when reduction is Mean.
//
// Throws (via TORCH_CHECK / TORCH_CHECK_INDEX) when:
//   - grad_output has an unexpected shape, or
//   - a non-ignored target lies outside [0, C).
template <typename scalar_t>
void nll_loss2d_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  auto weight_contiguous = optional_contiguous(weight);
  // May be null (see the null checks below). In that case each class
  // contributes a weight of 1.
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  if (reduction == at::Reduction::None) {
    // Unreduced: grad_output carries one value per target pixel.
    check_gradout_shape_nll_loss2d(grad_output, target);

    const int64_t batch_size = input.size(0);
    const int64_t n_classes = input.size(1);
    const int64_t H = input.size(2);
    const int64_t W = input.size(3);

    auto grad_input_acc = grad_input.accessor<scalar_t, 4>();
    auto grad_output_acc = grad_output.accessor<const scalar_t, 3>();
    auto target_acc = target.accessor<const int64_t, 3>();

    // Batches are independent: each writes only grad_input[b].
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto b : c10::irange(start, end)) {
        for (const auto h : c10::irange(H)) {
          for (const auto w : c10::irange(W)) {
            const int64_t cur_target = target_acc[b][h][w];
            if (cur_target == ignore_index) {
              continue;
            }
            // cur_target is used directly as an index into weight_data and
            // grad_input_acc. Reject out-of-range values here, matching the
            // reduced path below. Without this check an invalid target would
            // access memory out of bounds (accessor indexing is presumably
            // unchecked).
            TORCH_CHECK_INDEX(
                cur_target >= 0 && cur_target < n_classes,
                "Target ", cur_target, " is out of bounds.");
            const scalar_t value =
                -(weight_data ? weight_data[cur_target]
                              : static_cast<scalar_t>(1));
            const scalar_t grad_output_value = grad_output_acc[b][h][w];
            grad_input_acc[b][cur_target][h][w] = value * grad_output_value;
          }
        }
      }
    });

    return;
  }

  // Reduced (Mean or Sum): a single upstream scalar is shared by every pixel.
  const scalar_t total_weight_value = *total_weight.const_data_ptr<scalar_t>();

  TORCH_CHECK(
      grad_output.dim() <= 1 && grad_output.numel() == 1,
      "Expected a single element grad_output tensor, but got: ",
      grad_output.sizes());

  const scalar_t grad_output_value = *grad_output.const_data_ptr<scalar_t>();

  const auto target_contiguous = target.contiguous();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();

  // The flat indexing below assumes grad_input is contiguous in (N, C, H, W)
  // order. NOTE(review): confirm the caller guarantees this layout.
  scalar_t* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();

  const int64_t batch_size = input.size(0);
  const int64_t n_classes = input.size(1);
  const int64_t map_size = input.size(2) * input.size(3);
  const int64_t sample_size = map_size * n_classes;

  // Mean divides by total_weight. Presumably this is the summed weight of the
  // non-ignored targets, computed in the forward pass. Any other non-None
  // reduction is treated as Sum.
  const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
                                                   : grad_output_value);

  at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      for (const auto elem : c10::irange(map_size)) {
        const int64_t t = target_data[b * map_size + elem];

        if (t != ignore_index) {
          TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");

          const int64_t index = b * sample_size + t * map_size + elem;
          grad_input_data[index] = weight_data != nullptr ? weight_data[t] * grad
                                                          : grad;
        }
      }
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free