Home / Class / multilabel_margin_loss_backward_out_frame Class — PyTorch Architecture

multilabel_margin_loss_backward_out_frame Class — PyTorch Architecture

Architecture documentation for `multilabel_margin_loss_backward_out_frame` (a function template, indexed here as a class) in `LossMultiLabelMargin.cpp` from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LossMultiLabelMargin.cpp lines 155–223

/// Backward kernel for multilabel margin loss over contiguous CPU buffers.
///
/// The data is treated as `nframe` rows of `dim` elements. For each row, the
/// target row is scanned until its first negative entry; for every listed
/// target index j and every non-target position d (is_target[d] == 0) whose
/// hinge term z = 1 - x[j] + x[d] is positive, -g is added to grad[j] and +g
/// to grad[d]. Here g = 1 / (nframe * dim) for Reduction::Mean and 1 / dim
/// otherwise. The accumulated gradient is then scaled by grad_output: by its
/// single value when reduction != None or grad_output is 0-dim, otherwise
/// row-wise by grad_output[t].
///
/// @param grad_input           Output buffer of nframe * dim elements. It is
///                             accumulated into with += / -=, so it is
///                             presumably zero-filled by the caller (TODO
///                             confirm at call site).
/// @param grad_output          Upstream gradient: 0-dim, or 1-D of size
///                             nframe when reduction == None.
/// @param input_contiguous     Input values, nframe rows of dim elements.
/// @param target_contiguous    int64 target indices, same row layout; each
///                             row ends at its first negative entry.
/// @param reduction            Reduction mode (Mean and None are special-cased).
/// @param is_target_contiguous Per-element mask; values must lie in [0, 1],
///                             nonzero marks a target position.
/// @param nframe               Number of rows.
/// @param dim                  Number of elements per row.
/// @throws c10::Error if is_target has values outside [0, 1], or if a
///         non-negative target index is >= dim.
template <typename scalar_t>
void multilabel_margin_loss_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    int64_t reduction,
    const Tensor& is_target_contiguous,
    int64_t nframe,
    int64_t dim) {
#ifndef STRIP_ERROR_MESSAGES
  auto is_target_arg = TensorArg(is_target_contiguous, "is_target", 5);
#endif

  // is_target is used as a boolean mask below; reject anything outside [0, 1].
  TORCH_CHECK(
      is_target_contiguous.min().item<scalar_t>() >= 0, is_target_arg, " is out of range");
  TORCH_CHECK(
      is_target_contiguous.max().item<scalar_t>() <= 1, is_target_arg, " is out of range");

  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();
  const scalar_t* is_target_data = is_target_contiguous.const_data_ptr<scalar_t>();
  scalar_t g = static_cast<scalar_t>(
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      reduction == Reduction::Mean ? 1. / (nframe * dim) : 1. / dim);

  // Fetch the output pointer once; the row pointer walks it frame by frame
  // while grad_input_data stays at the start for the scaling pass below.
  scalar_t* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  scalar_t* grad_input_row_data = grad_input_data;
  for ([[maybe_unused]] const auto t : c10::irange(nframe)) {
    for (const auto dt : c10::irange(dim)) {
      int64_t target_idx = target_data[dt];
      // A negative entry terminates this row's list of targets.
      if (target_idx < 0) {
        break;
      }
      // target_idx indexes into a row of `dim` elements; without this check a
      // value >= dim would read input and write grad_input out of bounds.
      TORCH_CHECK(
          target_idx < dim,
          "target index ", target_idx, " is out of range for dim ", dim);

      scalar_t input_target = input_data[target_idx];
      for (const auto d : c10::irange(dim)) {
        if (!is_target_data[d]) {
          scalar_t z = 1 - input_target + input_data[d];
          // Only active hinge terms contribute to the gradient.
          if (z > 0) {
            grad_input_row_data[target_idx] -= g;
            grad_input_row_data[d] += g;
          }
        }
      }
    }
    input_data += dim;
    target_data += dim;
    is_target_data += dim;
    grad_input_row_data += dim;
  }

  if (reduction != Reduction::None || grad_output.dim() == 0) {
    // An unreduced loss only has a 0-dim grad_output when there is one frame.
    assert(
        reduction != Reduction::None || grad_output.dim() > 0 || nframe == 1);
    const auto d = *grad_output.const_data_ptr<scalar_t>();
    for (const auto i : c10::irange(nframe * dim)) {
      grad_input_data[i] *= d;
    }
  } else {
    check_dim_size(grad_output, 1, 0, nframe);
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    for (const auto t : c10::irange(nframe)) {
      for (const auto d : c10::irange(dim)) {
        grad_input_data[t * dim + d] *= grad_output_acc[t];
      }
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free