Skip to content

Commit e549347

Browse files
committed
增加字符串和label的转换器用于识别过程的训练与推断
1 parent 676dfae commit e549347

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import collections
12
from collections import defaultdict, deque
23
from functools import partial
34

45
import termcolor
6+
import torch
57
import torch.nn as nn
68
import torch.nn.init as init
79

@@ -129,3 +131,81 @@ def __init__(self):
129131

130132
def init_logger():
131133
return Logger()
134+
135+
136+
class StrLabelConverter(object):
137+
"""Convert between str and label.
138+
NOTE:
139+
Insert `blank` to the alphabet for CTC.
140+
Args:
141+
alphabet (str): set of the possible characters.
142+
ignore_case (bool, default=True): whether or not to ignore all of the case.
143+
"""
144+
145+
def __init__(self, alphabet, ignore_case=False):
146+
self._ignore_case = ignore_case
147+
if self._ignore_case:
148+
alphabet = alphabet.lower()
149+
self.alphabet = alphabet
150+
151+
self.dict = {}
152+
# 0 for blank
153+
for i, char in enumerate(iter(self.alphabet),1):
154+
self.dict[char] = i
155+
156+
157+
def encode(self, text):
158+
"""Support batch or single str.
159+
Args:
160+
text (str or list of str): texts to convert.
161+
Returns:
162+
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
163+
torch.IntTensor [n]: length of each text.
164+
"""
165+
if isinstance(text, str):
166+
text = [
167+
self.dict.get(char.lower() if self._ignore_case else char, 0)
168+
for char in text
169+
]
170+
length = [len(text)]
171+
elif isinstance(text, collections.Iterable):
172+
length = [len(s) for s in text]
173+
text = ''.join(text)
174+
text, _ = self.encode(text)
175+
return (torch.IntTensor(text), torch.IntTensor(length))
176+
177+
def decode(self, t, length, raw=False):
178+
"""Decode encoded texts back into strs.
179+
Args:
180+
torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
181+
torch.IntTensor [n]: length of each text.
182+
Raises:
183+
AssertionError: when the texts and its length does not match.
184+
Returns:
185+
text (str or list of str): texts to convert.
186+
"""
187+
if length.numel() == 1:
188+
length = length.item()
189+
assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
190+
length)
191+
if raw:
192+
return ''.join([self.alphabet[i - 1] for i in t])
193+
else:
194+
char_list = []
195+
for i in range(length):
196+
if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
197+
char_list.append(self.alphabet[t[i] - 1])
198+
return ''.join(char_list)
199+
else:
200+
# batch mode
201+
assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
202+
t.numel(), length.sum())
203+
texts = []
204+
index = 0
205+
for i in range(length.numel()):
206+
l = length[i]
207+
texts.append(
208+
self.decode(
209+
t[index:index + l], torch.IntTensor([l]), raw=raw))
210+
index += l
211+
return texts

0 commit comments

Comments
 (0)