Home / Class / nll_loss_backward_out_frame Class — pytorch Architecture

nll_loss_backward_out_frame Class — pytorch Architecture

Architecture documentation for the nll_loss_backward_out_frame function template in LossNLL.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LossNLL.cpp lines 340–413

/// Writes the gradient of the negative log-likelihood loss with respect to
/// `input` into `grad_input`.
///
/// Only the element at each sample's target class is assigned; every other
/// element of `grad_input` is left untouched.
/// NOTE(review): presumably the caller zero-fills `grad_input` beforehand —
/// confirm against the call site.
///
/// @tparam scalar_t  Element type of `grad_input`, `grad_output`, `weight`
///                   and `total_weight`.
/// @tparam target_t  Element type of `target`.
/// @param grad_input    Destination tensor, accessed as 1-D or 2-D to match
///                      `input.dim()`.
/// @param grad_output   Upstream gradient. Read per-sample (1-D accessor) when
///                      `reduction == Reduction::None` and `input` is 2-D;
///                      otherwise only its first element is read.
/// @param input         Used only for its shape: `dim()`, `size(0)` (batch),
///                      and `size(-1)` (number of classes).
/// @param target        Class indices; a 0-dim target is viewed as 1-D.
/// @param weight        Per-class weights; a weight of 1 is used when the data
///                      pointer obtained from it is null.
/// @param reduction     Compared against `Reduction::None` / `Reduction::Mean`.
/// @param ignore_index  Samples whose target equals this value are skipped.
/// @param total_weight  Divisor applied to `grad_output` when
///                      `reduction == Reduction::Mean`; only its first element
///                      is read.
///
/// Raises an index error (via TORCH_CHECK_INDEX) when a non-ignored target is
/// outside `[0, n_classes)`.
template <typename scalar_t, typename target_t>
void nll_loss_backward_out_frame(
    const Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index,
    const Tensor& total_weight) {
  const auto n_dims = input.dim();
  const auto n_classes = input.size(-1);

  // A scalar target is promoted to shape [1] so a 1-D accessor works for
  // both the single-sample and the batched case.
  auto target_ = target;
  if (target.dim() == 0) {
    target_ = target.unsqueeze(0);
  }
  auto target_acc = target_.accessor<const target_t, 1>();

  auto weight_contiguous = optional_contiguous(weight);
  const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);

  // Unreduced, batched case: every sample has its own upstream gradient.
  if (reduction == Reduction::None && n_dims == 2) {
    const auto batch_size = input.size(0);
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        const auto cur_target = target_acc[i];
        if (cur_target == ignore_index) {
          continue;
        }
        // Validate before indexing weight_data / grad_input, matching the
        // reduced paths below; an unchecked target would read and write out
        // of bounds.
        TORCH_CHECK_INDEX(
            cur_target >= 0 && cur_target < n_classes,
            "Target ", cur_target, " is out of bounds.");
        const scalar_t w =
            weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
        grad_input_acc[i][cur_target] = -w * grad_output_acc[i];
      }
    });
    return;
  }

  const scalar_t total_weight_value = *total_weight.const_data_ptr<scalar_t>();

  const scalar_t grad_output_value = *grad_output.const_data_ptr<scalar_t>();

  // Negated upstream gradient, divided by the total weight for Mean
  // reduction. Kept as a lambda so it is evaluated at the same points as
  // before (only when actually needed in the 1-D case).
  const auto scaled_grad = [&]() {
    return -(reduction == Reduction::Mean
                 ? grad_output_value / total_weight_value
                 : grad_output_value);
  };

  if (n_dims == 1) {
    auto grad_input_acc = grad_input.accessor<scalar_t, 1>();

    const auto t = target_acc[0];
    if (t != ignore_index) {
      TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
      const auto grad = scaled_grad();
      grad_input_acc[t] = weight_data != nullptr ? weight_data[t] * grad
                                                 : grad;
    }
  } else if (n_dims == 2) {
    auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
    // Loop-invariant: the same scaled gradient applies to every sample.
    const auto grad = scaled_grad();

    const auto batch_size = input.size(0);

    at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
      for (const auto i : c10::irange(start, end)) {
        const auto t = target_acc[i];
        if (t != ignore_index) {
          TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
          grad_input_acc[i][t] = weight_data != nullptr ? weight_data[t] * grad
                                                        : grad;
        }
      }
    });
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free