Home / Class / ctc_loss_impl Class — PyTorch Architecture

ctc_loss_impl Class — PyTorch Architecture

Architecture documentation for the ctc_loss_impl class in LossCTC.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LossCTC.cpp lines 497–548

// Shared implementation behind the public ctc_loss overloads.
//
// Dispatches a CTC loss computation to the fastest available backend
// (cuDNN on CUDA, MIOpen on ROCm, otherwise the generic ATen kernel),
// then applies the requested reduction.
//
// LengthsType abstracts over the two ways callers pass sequence lengths
// (IntArrayRef vs. length Tensors).
//
// Args:
//   log_probs_     - (T, N, C) log-probabilities, or (T, C) for the
//                    unbatched case (detected via dim() == 3).
//   targets        - target label sequences.
//   input_lengths  - per-sample input sequence lengths.
//   target_lengths - per-sample target sequence lengths.
//   BLANK          - index of the blank label.
//   reduction      - at::Reduction::{None, Mean, Sum}.
//   zero_infinity  - replace infinite losses (and their grads) with zero.
//
// Returns: the per-sample losses (reduction == None) or a scalar.
template <typename LengthsType>
Tensor ctc_loss_impl(const Tensor& log_probs_, const Tensor& targets, LengthsType input_lengths, LengthsType target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
  // Promote an unbatched (T, C) input to (T, 1, C) so every backend
  // below sees a batch dimension; squeezed back out before returning.
  const bool is_batched = log_probs_.dim() == 3;
  Tensor log_probs = is_batched ? log_probs_ : log_probs_.unsqueeze(1);

  const bool on_cuda = log_probs.device().type() == at::kCUDA;

  // cuDNN CTC Loss (the probe returns false on non-CUDA builds).
  const bool use_cudnn = on_cuda &&
      at::_use_cudnn_ctc_loss(
          log_probs, targets, input_lengths, target_lengths, BLANK);

  // MIOpen CTC Loss (the probe returns false on non-ROCm builds).
  // MIOpen wants int32 targets living on the CPU, so materialize that
  // copy up front and reuse it if the MIOpen path is taken.
  Tensor targets_cpu;
  bool use_miopen = false;
  if (on_cuda) {
    targets_cpu = targets.device().type() == at::kCPU
        ? targets.to(at::kInt)
        : targets.to(Device(at::kCPU), at::kInt);
    use_miopen = at::_use_miopen_ctc_loss(log_probs, targets_cpu, input_lengths, target_lengths, BLANK);
  }

  Tensor res;
  if (use_cudnn) {
    // non-deterministic ctc loss on cudnn disabled due to inconsistent results
    // see: https://github.com/pytorch/pytorch/issues/21680
    res = std::get<0>(at::_cudnn_ctc_loss(log_probs, targets, input_lengths, target_lengths, BLANK, /*deterministic=*/true, zero_infinity));
  } else if (use_miopen) {
    // MIOpen CTC Loss only supports deterministic algorithm
    res = std::get<0>(at::miopen_ctc_loss(log_probs, targets_cpu, input_lengths, target_lengths, BLANK, /*deterministic=*/true, zero_infinity));
  } else {
    // if the targets are on CPU (which you need for cuDNN/MIOpen), move them to
    // GPU as a service for the user
    Tensor targets_on_device = targets.to(log_probs.device(), kLong);
    res = std::get<0>(at::_ctc_loss(
        log_probs,
        targets_on_device,
        input_lengths,
        target_lengths,
        BLANK,
        zero_infinity));
    if (zero_infinity) {
      // Mask out infinite per-sample losses explicitly on this path.
      res = at::where(res == Scalar(std::numeric_limits<double>::infinity()), at::zeros({}, res.options()), res);
    }
  }

  // Apply the requested reduction. Mean normalizes each sample's loss
  // by its (clamped) target length before averaging.
  if (reduction == at::Reduction::Mean) {
    auto clamped_lengths = get_clamped_target_length(target_lengths, res.options());
    return (res / clamped_lengths).mean();
  }
  if (reduction == at::Reduction::Sum) {
    return res.sum();
  }
  // No reduction: drop the synthetic batch dim for unbatched inputs.
  if (!is_batched) {
    return res.squeeze(0);
  }
  return std::move(res);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free