import collections.abc  # Iterable lives in collections.abc (removed from collections in Python 3.10)
from collections import defaultdict, deque
from functools import partial

import termcolor
import torch
import torch.nn as nn
import torch.nn.init as init

@@ -129,3 +131,81 @@ def __init__(self):

def init_logger():
    return Logger()


class StrLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Index 0 is reserved for the CTC `blank` symbol, so character
        indices start at 1.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether to ignore character case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet

        # Map each character to an index starting at 1; 0 is the blank.
        self.dict = {}
        for i, char in enumerate(self.alphabet, 1):
            self.dict[char] = i

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        if isinstance(text, str):
            # Characters missing from the alphabet fall back to 0 (the blank).
            text = [
                self.dict.get(char.lower() if self._ignore_case else char, 0)
                for char in text
            ]
            length = [len(text)]
            return (torch.IntTensor(text), torch.IntTensor(length))
        elif isinstance(text, collections.abc.Iterable):
            # Batch mode: record each string's length, then encode the
            # concatenation as one flat sequence.
            length = [len(s) for s in text]
            encoded, _ = self.encode(''.join(text))
            return (encoded, torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]): encoded texts.
            length (torch.IntTensor [n]): length of each text.

        Raises:
            AssertionError: when the text and its declared length do not match.

        Returns:
            str or list of str: decoded texts.
        """
        if length.numel() == 1:
            length = length.item()
            assert t.numel() == length, \
                "text with length: {} does not match declared length: {}".format(
                    t.numel(), length)
            if raw:
                # Raw mode keeps repeats and renders the blank (0) as '-'.
                return ''.join(['-' if i == 0 else self.alphabet[i - 1] for i in t])
            else:
                # Greedy CTC decoding: drop blanks and collapse repeated labels.
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # Batch mode: split the flat tensor by `length` and decode each piece.
            assert t.numel() == length.sum(), \
                "texts with length: {} does not match declared length: {}".format(
                    t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                cur = length[i].item()
                texts.append(
                    self.decode(t[index:index + cur],
                                torch.IntTensor([cur]), raw=raw))
                index += cur
            return texts
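

# A minimal usage sketch (illustration only; the alphabet and sample strings
# below are assumptions, not part of the original module). It round-trips a
# batch through the converter and shows how greedy CTC decoding treats a
# hand-built "network output" path.
if __name__ == '__main__':
    converter = StrLabelConverter('0123456789abcdefghijklmnopqrstuvwxyz')

    # Encode a batch: one flat index tensor plus per-string lengths.
    encoded, lengths = converter.encode(['world', 'hi'])
    print(lengths)                             # tensor([5, 2], dtype=torch.int32)
    print(converter.decode(encoded, lengths))  # ['world', 'hi']

    # A per-frame argmax path for 'hi': blanks (0) and repeats survive in
    # raw mode but are removed by greedy CTC decoding.
    path = torch.IntTensor([converter.dict['h'], converter.dict['h'], 0,
                            converter.dict['i'], 0])
    print(converter.decode(path, torch.IntTensor([5]), raw=True))   # 'hh-i-'
    print(converter.decode(path, torch.IntTensor([5]), raw=False))  # 'hi'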