Home / Class / multi_margin_loss_backward_cpu_kernel Class — pytorch Architecture

multi_margin_loss_backward_cpu_kernel Class — pytorch Architecture

Architecture documentation for the multi_margin_loss_backward_cpu_kernel function template (not a class) in LossMultiMargin.cpp from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LossMultiMargin.cpp lines 150–207

// Backward pass of the multi-class margin loss on CPU: fills grad_input_data
// with the gradient of the loss w.r.t. the input and then scales it by
// grad_output.
//
// input_data and grad_input_data are walked as nframe consecutive rows of
// dim elements each. For row t with target class y (target_data[t], checked
// by target_index_checked), and for every class d != y:
//   z = margin - input[y] + input[d]
//   z > 0  : h = g            if p == 1
//            h = 2 * g * z    otherwise
//            h *= weight_data[y]  (only when weight_data != nullptr)
//            grad[d] = h, and grad[y] accumulates -h
//   z <= 0 : grad[d] = 0
// Every element of grad_input_data is therefore written.
//
// Afterwards the gradient is multiplied by grad_output:
//   - reduction != None or grad_output is 0-d: by its single (first) value;
//   - otherwise: row t is multiplied by grad_output[t] (1-d accessor).
//
// Parameters:
//   grad_input_data - output buffer of nframe * dim elements.
//   grad_output     - upstream gradient (scalar or 1-d, see above).
//   input_data      - nframe * dim input scores.
//   target_data     - nframe target class indices.
//   p               - 1 selects the linear gradient; any other value uses
//                     2*g*z (presumably p is validated to be 1 or 2 by the
//                     caller — TODO confirm).
//   margin          - margin term of the loss.
//   g               - gradient scale factor (NOTE(review): presumably
//                     already includes 1/dim and any mean normalization;
//                     verify against caller).
//   weight_data     - optional per-class weights, nullptr when absent.
//   nframe, dim     - number of rows and of classes per row.
//   reduction       - a Reduction enum value.
template <typename scalar_t>
void multi_margin_loss_backward_cpu_kernel(
    scalar_t* grad_input_data,
    const Tensor& grad_output,
    const scalar_t* input_data,
    const int64_t* target_data,
    int p,
    scalar_t margin,
    scalar_t g,
    const scalar_t* weight_data,
    int64_t nframe,
    int64_t dim,
    int64_t reduction) {
  scalar_t* grad_input_row_data = grad_input_data;
  for (const auto t : c10::irange(nframe)) {
    const int64_t target_idx = target_index_checked(target_data, t, dim);
    const scalar_t input_target = input_data[target_idx];

    // The class weight depends only on the target, so read it once per row
    // rather than once per class. It is applied only when weights exist.
    scalar_t target_weight = 1;
    if (weight_data != nullptr) {
      target_weight = weight_data[target_idx];
    }

    scalar_t grad_input_target = 0;
    for (const auto d : c10::irange(dim)) {
      if (d == target_idx) {
        // The target's gradient is accumulated below and written after
        // the loop, so there is no margin term to compute here.
        continue;
      }

      const scalar_t z = margin - input_target + input_data[d];
      if (z > 0) {
        scalar_t h = (p == 1) ? g : 2 * g * z;
        if (weight_data != nullptr) {
          h *= target_weight;
        }
        grad_input_target -= h;
        grad_input_row_data[d] = h;
      } else {
        grad_input_row_data[d] = 0;
      }
    }
    grad_input_row_data[target_idx] = grad_input_target;

    input_data += dim;
    grad_input_row_data += dim;
  }

  if (reduction != Reduction::None || grad_output.dim() == 0) {
    // A single upstream value scales every element of grad_input.
    assert(
        reduction != Reduction::None || grad_output.dim() > 0 ||
        nframe == 1); // check 1d scalar fallback-case
    const auto grad_scale = *grad_output.const_data_ptr<scalar_t>();
    for (const auto i : c10::irange(nframe * dim)) {
      grad_input_data[i] *= grad_scale;
    }
  } else {
    // No reduction: each row has its own upstream gradient value.
    auto grad_output_acc = grad_output.accessor<const scalar_t, 1>();
    for (const auto t : c10::irange(nframe)) {
      const scalar_t row_scale = grad_output_acc[t];
      for (const auto d : c10::irange(dim)) {
        grad_input_data[t * dim + d] *= row_scale;
      }
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free