Source code for pytorch_end2end.modules.ctc_loss

import os
import sys
from importlib import import_module

import torch.nn as nn
import torch.nn.functional as F

from pytorch_end2end.functions.forward_backward import ForwardBackwardLossFunction

if "DEBUG_E2E" in os.environ:
    module_base = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    build_path = os.path.join(module_base, os.getenv("DEBUG_E2E"))
    sys.path.insert(0, build_path)
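
# A minimal sketch of using the DEBUG_E2E hook above: set the variable to a
# local build directory before importing this module, so the compiled
# extensions (cpp_ctc_loss, cpp_gram_ctc_loss) are resolved from there.
# The directory name "build" below is an assumption, not part of the package:
#
#     import os
#     os.environ["DEBUG_E2E"] = "build"
#     from pytorch_end2end.modules.ctc_loss import CTCLoss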


class ForwardBackwardLossBase(nn.Module):
    def __init__(self, size_average=None, reduce=None, after_logsoftmax=False, time_major=False, blank_idx=0):
        super().__init__()
        self._blank_idx = blank_idx
        self._reduce = reduce
        self._size_average = size_average
        self._after_logsoftmax = after_logsoftmax
        self._time_major = time_major
        self._engine = None

    def forward(self, logits, targets, logits_lengths, targets_lengths):
        """
        Computes the loss for a batch of sequences.

        :param logits: Float or Double Tensor (network output)
            of shape ``(sequence_length, batch_size, alphabet_size)`` if ``time_major`` is True,
            else of shape ``(batch_size, sequence_length, alphabet_size)``
        :param targets: Tensor with targets of shape ``(batch_size, targets_sequence_length)``
        :param logits_lengths: Tensor of shape ``(batch_size,)`` with lengths of sequences
        :param targets_lengths: Tensor of shape ``(batch_size,)`` with lengths of target sequences
        :return: tensor with the loss of shape ``(batch_size,)`` if ``reduce`` is not set,
            else a scalar (0-dim) tensor
        """
        if self._after_logsoftmax:
            logits_logsoftmax = logits
        else:
            logits_logsoftmax = F.log_softmax(logits, dim=2)

        if self._time_major:
            logits_logsoftmax = logits_logsoftmax.permute(1, 0, 2)
        # shape of logits_logsoftmax now: batch_size, sequence_length, alphabet_size

        loss = ForwardBackwardLossFunction().apply(self._engine,
                                                   logits_logsoftmax,
                                                   targets,
                                                   logits_lengths,
                                                   targets_lengths)

        if self._reduce:
            if self._size_average:
                return loss.mean()
            else:
                return loss.sum()
        return loss


class CTCLoss(ForwardBackwardLossBase):
    """
    Criterion to compute the CTC loss as described in
    `<http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_

    :param size_average: if True, average the reduced loss over the batch instead of summing it
        (used only when ``reduce`` is True)
    :param reduce: if True, reduce the per-sample losses to a sum (or a mean, see ``size_average``);
        if not set, the full tensor of shape ``(batch_size,)`` is returned
    :param after_logsoftmax: if True, the inputs are expected to be log-softmax outputs already;
        otherwise raw network outputs are expected and log-softmax is applied internally
    :param time_major: if True, logits are time-major, else batch-major; default ``False``
    :param blank_idx: id of the blank label, default ``0``
    """

    def __init__(self, size_average=None, reduce=None, after_logsoftmax=False, time_major=False, blank_idx=0):
        super().__init__(size_average, reduce, after_logsoftmax, time_major, blank_idx)
        cpp_ctc_loss = import_module("cpp_ctc_loss")
        self._engine = cpp_ctc_loss.CTCLossEngine(self._blank_idx)
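
# A minimal usage sketch for CTCLoss, assuming the compiled cpp_ctc_loss
# extension has been built and is importable. Shapes follow the batch-major
# default (time_major=False); all sizes below are made up for illustration:
#
#     import torch
#     criterion = CTCLoss(reduce=True, size_average=True)
#     logits = torch.randn(4, 50, 29, requires_grad=True)        # (batch, time, alphabet)
#     targets = torch.randint(1, 29, (4, 20), dtype=torch.long)  # 0 is reserved for blank
#     logits_lengths = torch.full((4,), 50, dtype=torch.long)
#     targets_lengths = torch.full((4,), 20, dtype=torch.long)
#     loss = criterion(logits, targets, logits_lengths, targets_lengths)
#     loss.backward()
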
class GramCTCLoss(CTCLoss):
    """
    Criterion to compute the Gram-CTC loss as described in
    `<https://arxiv.org/abs/1703.00096>`_

    :param blank_idx: id of the blank label, only ``0`` is supported
    :param num_base_labels: number of base (unigram) labels
    :param total_labels: total number of labels, base labels and grams together
    :param label2ids: dict mapping a label id to the ids of its base labels
    :param size_average: if True, average the reduced loss over the batch instead of summing it
        (used only when ``reduce`` is True)
    :param reduce: if True, reduce the per-sample losses to a sum (or a mean, see ``size_average``);
        if not set, the full tensor of shape ``(batch_size,)`` is returned
    :param after_logsoftmax: if True, the inputs are expected to be log-softmax outputs already;
        otherwise raw network outputs are expected and log-softmax is applied internally
    :param time_major: if True, logits are time-major, else batch-major; default ``False``
    """

    def __init__(self, blank_idx, num_base_labels, total_labels, label2ids,
                 size_average=None, reduce=None, after_logsoftmax=False, time_major=False):
        super().__init__(size_average, reduce, after_logsoftmax, time_major, blank_idx)
        if self._blank_idx != 0:
            raise NotImplementedError
        self._num_base_labels = num_base_labels
        self._total_labels = total_labels
        self._label2ids = label2ids
        cpp_gram_ctc_loss = import_module("cpp_gram_ctc_loss")
        self._engine = cpp_gram_ctc_loss.GramCTCLossEngine(self._blank_idx, self._num_base_labels,
                                                           self._total_labels, self._label2ids)
        raise NotImplementedError  # Gram-CTC loss is not implemented yet
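
# A hypothetical sketch of the label2ids argument for GramCTCLoss: each label
# id maps to the ids of the base labels it is composed of. The exact
# convention is an assumption, and the class itself still raises
# NotImplementedError:
#
#     label2ids = {
#         1: [1],     # base label, e.g. "a"
#         2: [2],     # base label, e.g. "b"
#         3: [1, 2],  # bigram "ab", decomposed into its base label ids
#     }
#     criterion = GramCTCLoss(blank_idx=0, num_base_labels=2, total_labels=3,
#                             label2ids=label2ids)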