ctc_loss_impl Function — pytorch Architecture
Architecture documentation for the ctc_loss_impl function template in LossCTC.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/LossCTC.cpp lines 497–548
template <typename LengthsType>
Tensor ctc_loss_impl(const Tensor& log_probs_, const Tensor& targets, LengthsType input_lengths, LengthsType target_lengths, int64_t BLANK, int64_t reduction, bool zero_infinity) {
  // A 2-D (un-batched) input gets a synthetic batch dimension of size 1,
  // which is squeezed back out before returning.
  const bool is_batched = log_probs_.dim() == 3;
  Tensor log_probs = is_batched ? log_probs_ : log_probs_.unsqueeze(1);

  const bool on_cuda = log_probs.device().type() == at::kCUDA;

  // Probe the cuDNN backend (the probe returns false on non-CUDA builds).
  bool use_cudnn = on_cuda &&
      at::_use_cudnn_ctc_loss(
          log_probs, targets, input_lengths, target_lengths, BLANK);

  // Probe the MIOpen backend (the probe returns false on non-ROCm builds).
  // MIOpen expects int targets on CPU, so the converted tensor is prepared
  // here and reused in the dispatch below.
  bool use_miopen = false;
  Tensor targets_cpu;
  if (on_cuda) {
    targets_cpu = targets.device().type() == at::kCPU
        ? targets.to(at::kInt)
        : targets.to(Device(at::kCPU), at::kInt);
    use_miopen = at::_use_miopen_ctc_loss(
        log_probs, targets_cpu, input_lengths, target_lengths, BLANK);
  }

  Tensor losses;
  if (use_cudnn) {
    // Deterministic algorithm is forced: the non-deterministic cuDNN CTC
    // kernel produced inconsistent results.
    // see: https://github.com/pytorch/pytorch/issues/21680
    losses = std::get<0>(at::_cudnn_ctc_loss(
        log_probs, targets, input_lengths, target_lengths, BLANK,
        /*deterministic=*/true, zero_infinity));
  } else if (use_miopen) {
    // MIOpen CTC Loss only supports the deterministic algorithm.
    losses = std::get<0>(at::miopen_ctc_loss(
        log_probs, targets_cpu, input_lengths, target_lengths, BLANK,
        /*deterministic=*/true, zero_infinity));
  } else {
    // Native fallback. If the targets sit on CPU (which the cuDNN/MIOpen
    // paths require), move them to the log-probs device as a service for
    // the user.
    losses = std::get<0>(at::_ctc_loss(
        log_probs,
        targets.to(log_probs.device(), kLong),
        input_lengths,
        target_lengths,
        BLANK,
        zero_infinity));
    if (zero_infinity) {
      // Zero out infinite losses (unreachable targets) when requested.
      losses = at::where(
          losses == Scalar(std::numeric_limits<double>::infinity()),
          at::zeros({}, losses.options()),
          losses);
    }
  }

  // Apply the requested reduction; Mean normalizes each loss by its
  // (clamped) target length before averaging.
  if (reduction == at::Reduction::Mean) {
    auto clamped_lengths = get_clamped_target_length(target_lengths, losses.options());
    return (losses / clamped_lengths).mean();
  }
  if (reduction == at::Reduction::Sum) {
    return losses.sum();
  }
  return is_batched ? std::move(losses) : losses.squeeze(0);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free