Source code for pytorch_end2end.encoders.text_encoders

# -*- coding: utf-8 -*-
import itertools
import string

import numpy as np


[docs]class CTCEncoder: """ Simple CTC-encoder for text """ def __init__(self, characters, blank_id=0, transform_fn=str.upper): self.blank_id = blank_id self.transform_fn = transform_fn self.char2id = dict() idx = 0 for c in characters: if idx == blank_id: idx += 1 self.char2id[c] = idx idx += 1 self.id2char = dict(zip(self.char2id.values(), self.char2id.keys())) self.id2char[self.blank_id] = "" self.num_symbols = len(self.id2char)
[docs] def clean(self, text): clean_text = "".join(c for c in self.transform_fn(text) if c in self.char2id) return clean_text
[docs] def encode(self, text): encoded_text = np.array([self.char2id[c] for c in self.transform_fn(text) if c in self.char2id]) return encoded_text
[docs] def decode(self, ids_list): text = "".join(self.id2char[idx] for idx, _ in itertools.groupby(ids_list) if idx != self.blank_id) return text
[docs] def decode_pure(self, ids_list): text = "".join(self.id2char[idx] for idx in ids_list) return text
class ASGEncoder: """ http://arxiv.org/abs/1609.03193 Encoder for Auto Segmentation Criterion (and CTC without blank) """ def __init__(self, allowed_chars=" " + string.ascii_lowercase + "'", to_lower=str.casefold): raise NotImplementedError