
Commit eea9ee7

Fixes mlcommons#1402, make rnnt run with pytorch >=2.0 (mlcommons#1506)
1 parent 531f7c4 commit eea9ee7

File tree: 1 file changed (+17 -16 lines)


speech_recognition/rnnt/pytorch/parts/features.py

Lines changed: 17 additions & 16 deletions
@@ -97,12 +97,12 @@ def __init__(self, sample_rate=8000, window_size=0.02, window_stride=0.01,
         # print("PADDING: {}".format(pad_to))
 
         torch_windows = {
-            'hann': torch.hann_window,
-            'hamming': torch.hamming_window,
-            'blackman': torch.blackman_window,
-            'bartlett': torch.bartlett_window,
-            'none': None,
-        }
+            'hann': torch.hann_window,
+            'hamming': torch.hamming_window,
+            'blackman': torch.blackman_window,
+            'bartlett': torch.bartlett_window,
+            'none': None,
+        }
 
         self.win_length = int(sample_rate * window_size)  # frame size
         self.hop_length = int(sample_rate * window_stride)
@@ -123,16 +123,16 @@ def __init__(self, sample_rate=8000, window_size=0.02, window_stride=0.01,
         window_tensor = window_fn(self.win_length,
                                   periodic=False) if window_fn else None
         filterbanks = torch.tensor(
-            librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq,
-                                fmax=highfreq), dtype=torch.float).unsqueeze(0)
+            librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq,
+                                fmax=highfreq), dtype=torch.float).unsqueeze(0)
         # self.fb = filterbanks
         # self.window = window_tensor
         self.register_buffer("fb", filterbanks)
         self.register_buffer("window", window_tensor)
         # Calculate maximum sequence length (# frames)
         max_length = 1 + math.ceil(
-            (max_duration * sample_rate - self.win_length) / self.hop_length
-        )
+            (max_duration * sample_rate - self.win_length) / self.hop_length
+        )
         max_pad = 16 - (max_length % 16)
         self.max_length = max_length + max_pad
 
@@ -141,9 +141,9 @@ def get_seq_len(self, seq_len):
         seq_len = (seq_len + self.frame_splicing - 1) // self.frame_splicing
         return seq_len
 
-    @torch.no_grad()
     def forward(self, inp: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
-        x, seq_len = inp
+        with torch.no_grad():
+            x, seq_len = inp
 
         dtype = x.dtype
 
@@ -162,7 +162,8 @@ def forward(self, inp: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
         # do stft
         x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
                        win_length=self.win_length,
-                       center=True, window=self.window.to(dtype=torch.float))
+                       center=True, window=self.window.to(dtype=torch.float), return_complex = True)
+        x = torch.view_as_real(x)
 
         # get power spectrum
         x = x.pow(2).sum(-1)
@@ -244,9 +245,9 @@ def from_config(cls, cfg, log=False):
 
 class FeatureFactory(object):
     featurizers = {
-        "logfbank": FilterbankFeatures,
-        "fbank": FilterbankFeatures,
-    }
+        "logfbank": FilterbankFeatures,
+        "fbank": FilterbankFeatures,
+    }
 
     def __init__(self):
         pass
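For context on the forward() change above: torch.no_grad works both as a decorator and as a context manager, and this commit swaps the decorator form for an inline with-block. Note that as committed, the block wraps only the input unpacking, so the statements after it in forward() run outside it. A minimal standalone sketch of the two forms, with illustrative function names and shapes that are assumptions rather than the repo's code:

import torch

# Decorator form: the whole function body runs with autograd tracking disabled.
@torch.no_grad()
def featurize_decorated(x: torch.Tensor) -> torch.Tensor:
    return x.pow(2).sum(-1)

# Context-manager form: only the statements inside the block skip tracking.
def featurize_inline(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        y = x.pow(2).sum(-1)
    return y

x = torch.randn(4, 8, requires_grad=True)
print(featurize_decorated(x).requires_grad)  # False
print(featurize_inline(x).requires_grad)     # False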

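The stft hunk is the core of the PyTorch >=2.0 fix: recent PyTorch versions require torch.stft to be told explicitly whether to return a complex tensor for real input, and return_complex=True followed by torch.view_as_real() restores the old real-valued layout with a trailing dimension of size 2 (real, imag), so the downstream x.pow(2).sum(-1) power spectrum is unchanged. A minimal sketch of the pattern, where the dummy signal and STFT parameters are assumptions, not values taken from the repo:

import torch

x = torch.randn(1, 16000)              # dummy 1-second batch of audio at 16 kHz
n_fft, hop_length, win_length = 512, 160, 320
window = torch.hann_window(win_length)

# PyTorch >= 2.0 requires return_complex to be set explicitly for real input.
spec = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
                  center=True, window=window, return_complex=True)
spec = torch.view_as_real(spec)        # (batch, freq, frames, 2): the pre-2.0 layout

power = spec.pow(2).sum(-1)            # power spectrum, as computed in the diff
print(power.shape)                     # torch.Size([1, 257, 101])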
0 commit comments