# Source code for pytorch_end2end.decoders.ctc_decoder

import os
import sys
from collections import namedtuple

import torch
import torch.nn.functional as F

# Debug hook: when the DEBUG_E2E environment variable is set, its value is
# treated as a path relative to the directory three levels above this file,
# and that path is appended to sys.path before the extension import below.
if os.environ.get("DEBUG_E2E") is not None:
    module_base = os.path.dirname(__file__)  # directory holding this module
    module_base = os.path.dirname(os.path.dirname(module_base))  # two more levels up
    build_path = os.path.join(module_base, os.environ["DEBUG_E2E"])
    sys.path.append(build_path)

import cpp_ctc_decoder  # noqa: E402


class CTCDecoderError(Exception):
    """Error raised when the CTC decoder is given an invalid configuration."""


# Result bundle with three fields, in order: the decoded targets, their
# lengths, and the decoded sentences.
DecoderResults = namedtuple(
    "DecoderResults",
    "decoded_targets decoded_targets_lengths decoded_sentences",
)


class CTCDecoder:
    """
    Decoder class to perform CTC decoding

    :param beam_width: width of beam (number of stored hypotheses), default ``100``. \n
        If ``1``, decoder always perform greedy (argmax) decoding
    :param after_logsoftmax: if log_logits (logits after log softmax), default ``False`` \n
        (with False decoder expects pure logits, not after softmax).
        Greedy decoding ignores this parameter and can work with pure logits, or after log softmax
    :param blank_idx: id of blank label, default ``0``
    :param time_major: if logits are time major (else batch major), default ``False``
    :param labels: list of strings with labels (including blank symbol), e.g. ``["_", "a", "b", "c"]``
    :param lm_path: path to language model (ARPA format or gzipped ARPA)
    :param lmwt: language model weight, default ``1.0``, makes sense only if language model is present
    :param wip: word insertion penalty, default ``1.0``, makes sense only if labels are present
    :param oov_penalty: penalty for each oov word, default ``-10``
    :param case_sensitive: obtain language model scores with respect to case, default ``True``
    :raises CTCDecoderError: if ``lm_path`` is given without ``labels``, or does not point to a file
    """

    def __init__(self, beam_width=100, after_logsoftmax=False, blank_idx=0, time_major=False,
                 labels=None, lm_path=None, lmwt=1.0, wip=1.0, oov_penalty=-10, case_sensitive=True):
        self._beam_width = beam_width
        self._blank_idx = blank_idx
        self._after_logsoftmax = after_logsoftmax
        # Missing labels are normalised to an empty list, so this is never None.
        self._labels = labels or []
        # An empty string means "no language model".
        self._lm_path = os.path.abspath(lm_path) if lm_path else ""
        self._lmwt = lmwt
        self._wip = wip
        self._oov_penalty = oov_penalty
        self._time_major = time_major
        self._case_sensitive = case_sensitive
        # Validate before handing the configuration to the native decoder.
        self._check_params()
        self._decoder = cpp_ctc_decoder.CTCDecoder(self._blank_idx, self._beam_width, self._labels,
                                                   self._lm_path, self._lmwt, self._wip,
                                                   self._oov_penalty, self._case_sensitive)

    def _check_params(self):
        """
        Validate the language-model related settings

        :raises CTCDecoderError: if a language model path is set but no labels were passed,
            or the path is not an existing file
        """
        # TODO: Check all params
        if not self._lm_path:
            return
        # ``self._labels`` is ``[]`` (never ``None``) when no labels were passed,
        # so emptiness is the condition to test.
        if not self._labels:
            raise CTCDecoderError("To decode with language model you should pass labels")
        if not os.path.isfile(self._lm_path):
            raise CTCDecoderError("Can't find a model: {}".format(self._lm_path))

    def _prepare_inputs(self, logits, logits_lengths):
        """
        Bring logits to a batch-major, detached CPU tensor and build CPU lengths

        :param logits: tensor, time major if ``time_major`` was set, else batch major
        :param logits_lengths: tensor with lengths, or ``None`` to use the full sequence length
            for every batch element
        :return: tuple ``(logits, logits_lengths)``
        """
        if self._time_major:
            logits = logits.transpose(1, 0)  # batch_size * sequence_length * alphabet_size
        logits = logits.detach().cpu()
        batch_size = logits.size()[0]
        max_sequence_length = logits.size()[1]
        if logits_lengths is None:
            logits_lengths = torch.zeros(batch_size, dtype=torch.int).fill_(max_sequence_length)
        else:
            logits_lengths = logits_lengths.cpu()
        return logits, logits_lengths

    def decode(self, logits, logits_lengths=None):
        """
        Performs prefix beam search decoding as described in `<https://arxiv.org/abs/1408.2873>`_

        Falls back to :meth:`decode_greedy` when ``beam_width`` is ``1``.

        :param logits: tensor with neural network outputs: pure logits, or outputs after
            log softmax if the decoder was created with ``after_logsoftmax=True`` \n
            of shape ``(sequence_length, batch_size, alphabet_size)`` if ``time_major`` \n
            else of shape ``(batch_size, sequence_length, alphabet_size)``
        :param logits_lengths: default ``None``
        :return: ``namedtuple(decoded_targets, decoded_targets_lengths, decoded_sentences)`` \n
            decoded_targets: tensor with result targets of shape ``(batch_size, sequence_length)``,
            doesn't contain blank symbols \n
            decoded_targets_length: tensor with lengths of decoded targets \n
            decoded_sentences: list of strings, shape ``(batch_size)``.
            If ``labels are None``, list of empty string is returned. \n
        """
        if self._beam_width == 1:
            return self.decode_greedy(logits, logits_lengths)

        with torch.no_grad():
            if not self._after_logsoftmax:
                # Normalise over the alphabet (last) dimension before any transpose.
                logits = F.log_softmax(logits, -1)
            logits, logits_lengths = self._prepare_inputs(logits, logits_lengths)
            decoded_targets, decoded_targets_lengths, decoded_sentences = self._decoder.decode(
                logits_=logits, logits_lengths_=logits_lengths)
        return DecoderResults(decoded_targets, decoded_targets_lengths, decoded_sentences)

    def _print_scores_for_sentence(self, words):
        """Forward ``words`` to the native decoder's ``print_scores_for_sentence`` (debug helper)."""
        self._decoder.print_scores_for_sentence(words)

    def decode_greedy(self, logits, logits_lengths=None):
        """
        Performs greedy (argmax) decoding

        :param logits: tensor with neural network outputs (pure logits or after log softmax) \n
            of shape ``(sequence_length, batch_size, alphabet_size)`` if ``time_major`` \n
            else of shape ``(batch_size, sequence_length, alphabet_size)``
        :param logits_lengths: default ``None``
        :return: ``namedtuple(decoded_targets, decoded_targets_lengths, decoded_sentences)`` \n
            decoded_targets: tensor with result targets of shape ``(batch_size, sequence_length)``,
            doesn't contain blank symbols \n
            decoded_targets_length: tensor with lengths of decoded targets \n
            decoded_sentences: list of strings, shape ``(batch_size)``.
            If ``labels are None``, list of empty string is returned. \n
        """
        logits, logits_lengths = self._prepare_inputs(logits, logits_lengths)
        decoded_targets, decoded_targets_lengths, decoded_sentences = self._decoder.decode_greedy(
            logits_=logits, logits_lengths_=logits_lengths)
        return DecoderResults(decoded_targets, decoded_targets_lengths, decoded_sentences)