Home / Class/ nll_loss2d_forward_out_frame Class — pytorch Architecture

nll_loss2d_forward_out_frame Class — pytorch Architecture

Architecture documentation for the nll_loss2d_forward_out_frame function template in LossNLL2d.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LossNLL2d.cpp lines 101–252

// Forward pass of the 2-D negative log-likelihood loss for element type
// `scalar_t`.
//
// `input` is read as a 4-D tensor indexed [batch][class][h][w] and `target`
// as a 3-D int64 tensor indexed [batch][h][w]. `weight` supplies an optional
// per-class weight; when no weight data is available every class counts as 1.
//
// Results:
//   * reduction == Reduction::None: `output` is resized to {batch, H, W} and
//     holds -input[b][target][h][w] * weight[target] per element, with 0
//     written wherever the target equals `ignore_index`. `total_weight` is
//     left at 0.
//   * otherwise: `output` is resized to a scalar holding the summed loss,
//     divided by the accumulated weight when reduction == Reduction::Mean.
//     `total_weight` receives the sum of the weights of all non-ignored
//     targets (or their count when there is no weight data).
//   * empty `target` with a reducing mode: `output` is NaN for Mean and 0
//     otherwise, and `total_weight` is 0.
//
// Raises an index error (TORCH_CHECK_INDEX) when a non-ignored target lies
// outside [0, n_classes).
template <typename scalar_t>
void nll_loss2d_forward_out_frame(
    Tensor& output,
    Tensor& total_weight,
    const Tensor& input,
    const Tensor& target,
    const Tensor& weight,
    int64_t reduction,
    int64_t ignore_index) {
  // NOTE: `n_classes` and `cur_target` keep their names on purpose; they
  // appear inside the TORCH_CHECK_INDEX condition below.
  const int64_t n_classes = input.size(1);

  // Start from a zero total weight; the reducing path overwrites it at the
  // very end.
  scalar_t* total_weight_ptr = total_weight.data_ptr<scalar_t>();
  *total_weight_ptr = 0;

  // `class_weights` is compared against nullptr below, which selects the
  // unweighted code paths.
  auto weight_contig = optional_contiguous(weight);
  const scalar_t* class_weights = optional_data<const scalar_t>(weight_contig);

  if (reduction == Reduction::None) {
    const int64_t n_batch = input.size(0);
    const int64_t height = input.size(2);
    const int64_t width = input.size(3);

    at::native::resize_output(output, {n_batch, height, width});
    auto in_acc = input.accessor<const scalar_t, 4>();
    auto out_acc = output.accessor<scalar_t, 3>();
    auto tgt_acc = target.accessor<const int64_t, 3>();

    // Each batch entry writes a disjoint slice of the output, so the work is
    // split over the batch dimension.
    at::parallel_for(0, n_batch, 0, [&](int64_t first, int64_t last) {
      for (const auto n : c10::irange(first, last)) {
        for (const auto y : c10::irange(height)) {
          for (const auto x : c10::irange(width)) {
            const int64_t cur_target = tgt_acc[n][y][x];

            if (cur_target != ignore_index) {
              TORCH_CHECK_INDEX(
                  cur_target >= 0 && cur_target < n_classes,
                  "Target ",
                  cur_target,
                  " is out of bounds.");

              // fall back to a unit weight when no weights were supplied
              const scalar_t w = class_weights == nullptr
                  ? static_cast<scalar_t>(1)
                  : class_weights[cur_target];
              out_acc[n][y][x] = -in_acc[n][cur_target][y][x] * w;
            } else {
              // ignored positions contribute a zero loss
              out_acc[n][y][x] = static_cast<scalar_t>(0);
            }
          }
        }
      }
    });

    return;
  }

  // Every remaining reduction mode yields a scalar output.
  at::native::resize_output(output, {});

  if (target.numel() == 0) {
    // Here target (and input) have zero elements
    // Mean reduction on empty tensors produces NaN. See the discussion in
    // https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
    if (reduction == Reduction::Mean) {
      output.fill_(std::numeric_limits<double>::quiet_NaN());
    } else {
      output.zero_();
    }
    total_weight.zero_();
    return;
  }

  // Work on contiguous copies so the raw-pointer index arithmetic below
  // matches the [batch][class][h * w] layout.
  auto input_contig = input.contiguous();
  auto target_contig = target.contiguous();

  const scalar_t* input_ptr = input_contig.const_data_ptr<scalar_t>();
  const int64_t* target_ptr = target_contig.const_data_ptr<int64_t>();

  const int64_t n_batch = input.size(0);
  const int64_t plane_size = input.size(2) * input.size(3);
  const int64_t per_sample = plane_size * n_classes;
  const int64_t total_elems = n_batch * plane_size;

  // Partial sums are kept on several levels: level j is folded into level
  // j + 1 whenever the matching bit group of the linear index is all zero.
  // (Presumably this bounds floating-point accumulation error compared with
  // a single running sum.)
  constexpr int64_t num_levels = 8;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t weight_levels[num_levels] = {0};
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  scalar_t loss_levels[num_levels] = {0};
  const int64_t level_power =
      std::max(static_cast<int64_t>(4), utils::CeilLog2(total_elems) / num_levels);
  const int64_t level_span = (1 << level_power);
  const int64_t level_mask = level_span - 1;

  int64_t ignored_count = 0;
  for (const auto n : c10::irange(n_batch)) {
    for (const auto pos : c10::irange(plane_size)) {
      // position of this element in the flattened [batch][h * w] target
      const int64_t flat_idx = n * plane_size + pos;

      const int64_t cur_target = target_ptr[flat_idx];
      if (cur_target == ignore_index) {
        ++ignored_count;
        continue;
      }

      TORCH_CHECK_INDEX(
          cur_target >= 0 && cur_target < n_classes,
          "Target ",
          cur_target,
          " is out of bounds.");

      const auto value = input_ptr[n * per_sample + cur_target * plane_size + pos];
      if (class_weights == nullptr) {
        loss_levels[0] -= value;
      } else {
        const scalar_t w = class_weights[cur_target];
        loss_levels[0] -= value * w;
        weight_levels[0] += w;
      }

      // Carry completed lower levels upwards; stop at the first level whose
      // bit group in the index is non-zero.
      for (int64_t lvl = 0; lvl + 1 < num_levels; ++lvl) {
        const auto mask = (level_mask << (lvl * level_power));
        if (C10_LIKELY((flat_idx & mask) != 0)) {
          break;
        }

        weight_levels[lvl + 1] += weight_levels[lvl];
        loss_levels[lvl + 1] += loss_levels[lvl];

        weight_levels[lvl] = 0;
        loss_levels[lvl] = 0;
      }
    }
  }

  // With weights the total is the accumulated weight; without, it is the
  // number of targets that were not ignored.
  const scalar_t weight_sum = class_weights
      ? std::accumulate(std::begin(weight_levels),
                        std::end(weight_levels),
                        scalar_t{0})
      : static_cast<scalar_t>(total_elems - ignored_count);

  scalar_t loss_sum = std::accumulate(std::begin(loss_levels),
                                      std::end(loss_levels),
                                      scalar_t{0});

  if (reduction == Reduction::Mean) {
    loss_sum /= weight_sum;
  }

  *total_weight_ptr = weight_sum;
  *output.data_ptr<scalar_t>() = loss_sum;
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free