@@ -97,12 +97,12 @@ def __init__(self, sample_rate=8000, window_size=0.02, window_stride=0.01,
         # print("PADDING: {}".format(pad_to))
 
         torch_windows = {
-            'hann': torch.hann_window,
-            'hamming': torch.hamming_window,
-            'blackman': torch.blackman_window,
-            'bartlett': torch.bartlett_window,
-            'none': None,
-        }
+            'hann': torch.hann_window,
+            'hamming': torch.hamming_window,
+            'blackman': torch.blackman_window,
+            'bartlett': torch.bartlett_window,
+            'none': None,
+        }
 
         self.win_length = int(sample_rate * window_size)  # frame size
         self.hop_length = int(sample_rate * window_stride)
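For reference, the frame geometry above just converts the window parameters from seconds to samples, and the `torch_windows` dict resolves a window name to the corresponding `torch` factory. A minimal sketch using the default constructor values shown in the hunk header (`sample_rate=8000`, `window_size=0.02`, `window_stride=0.01`):

```python
import torch

sample_rate = 8000
window_size = 0.02    # window length in seconds
window_stride = 0.01  # frame shift in seconds

win_length = int(sample_rate * window_size)    # 160 samples per frame
hop_length = int(sample_rate * window_stride)  # 80 samples between frame starts

# Resolving 'hann' through the dict yields the same tensor the module
# later registers as its "window" buffer.
window = torch.hann_window(win_length, periodic=False)
print(win_length, hop_length, window.shape)  # 160 80 torch.Size([160])
```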
@@ -123,16 +123,16 @@ def __init__(self, sample_rate=8000, window_size=0.02, window_stride=0.01,
         window_tensor = window_fn(self.win_length,
                                   periodic=False) if window_fn else None
         filterbanks = torch.tensor(
-            librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq,
-                                fmax=highfreq), dtype=torch.float).unsqueeze(0)
+            librosa.filters.mel(sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq,
+                                fmax=highfreq), dtype=torch.float).unsqueeze(0)
         # self.fb = filterbanks
         # self.window = window_tensor
         self.register_buffer("fb", filterbanks)
         self.register_buffer("window", window_tensor)
         # Calculate maximum sequence length (# frames)
         max_length = 1 + math.ceil(
-            (max_duration * sample_rate - self.win_length) / self.hop_length
-        )
+            (max_duration * sample_rate - self.win_length) / self.hop_length
+        )
         max_pad = 16 - (max_length % 16)
         self.max_length = max_length + max_pad
 
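The max-length computation caps the number of frames and pads it up to a multiple of 16. A worked example, with `max_duration` assumed to be 16.7 s purely for illustration (its actual default is not visible in this diff):

```python
import math

sample_rate, win_length, hop_length = 8000, 160, 80
max_duration = 16.7  # assumed value, for illustration only

max_length = 1 + math.ceil((max_duration * sample_rate - win_length) / hop_length)
# (16.7 * 8000 - 160) / 80 = 1668.0, so max_length = 1669
max_pad = 16 - (max_length % 16)  # 1669 % 16 == 5, so pad by 11
max_length += max_pad
print(max_length)  # 1680, a multiple of 16
```

Note that when `max_length` is already a multiple of 16, this formula still adds a full extra 16 frames of padding.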
@@ -141,9 +141,9 @@ def get_seq_len(self, seq_len):
         seq_len = (seq_len + self.frame_splicing - 1) // self.frame_splicing
         return seq_len
 
-    @torch.no_grad()
     def forward(self, inp: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
-        x, seq_len = inp
+        with torch.no_grad():
+            x, seq_len = inp
 
         dtype = x.dtype
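Swapping the `@torch.no_grad()` decorator for a `with torch.no_grad():` block is behavior-preserving in eager mode; one plausible motivation, not stated in the diff, is TorchScript, which handles the context-manager form more readily than the decorator. A minimal sketch of the equivalence:

```python
import torch

x = torch.ones(3, requires_grad=True)

@torch.no_grad()
def scale_decorated(t):
    return t * 2

def scale_context(t):
    with torch.no_grad():
        return t * 2

# Both variants detach their outputs from the autograd graph.
print(scale_decorated(x).requires_grad)  # False
print(scale_context(x).requires_grad)    # False
```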
@@ -162,7 +162,8 @@ def forward(self, inp: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
         # do stft
         x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length,
                        win_length=self.win_length,
-                       center=True, window=self.window.to(dtype=torch.float))
+                       center=True, window=self.window.to(dtype=torch.float), return_complex=True)
+        x = torch.view_as_real(x)
 
         # get power spectrum
         x = x.pow(2).sum(-1)
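This change tracks PyTorch's API: since 1.8, `torch.stft` warns unless `return_complex=True` is passed (the real-valued output is slated for removal), so the call now requests a complex tensor and `torch.view_as_real` restores the old `(..., 2)` real/imaginary layout that the downstream `x.pow(2).sum(-1)` power-spectrum line expects. A sketch of the round trip, with hypothetical input shapes:

```python
import torch

signal = torch.randn(1, 4000)  # hypothetical mono batch
window = torch.hann_window(160)

spec = torch.stft(signal, n_fft=512, hop_length=80, win_length=160,
                  center=True, window=window, return_complex=True)
print(spec.shape, spec.dtype)   # torch.Size([1, 257, 51]) torch.complex64

real = torch.view_as_real(spec) # trailing dim holds (real, imag)
print(real.shape)               # torch.Size([1, 257, 51, 2])

power = real.pow(2).sum(-1)     # identical to spec.abs() ** 2
assert torch.allclose(power, spec.abs() ** 2, atol=1e-5)
```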
@@ -244,9 +245,9 @@ def from_config(cls, cfg, log=False):
 
 class FeatureFactory(object):
     featurizers = {
-        "logfbank": FilterbankFeatures,
-        "fbank": FilterbankFeatures,
-    }
+        "logfbank": FilterbankFeatures,
+        "fbank": FilterbankFeatures,
+    }
 
     def __init__(self):
         pass
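Both registry keys alias the same class, so looking up either name yields `FilterbankFeatures`:

```python
# Both aliases resolve to the same featurizer class.
assert FeatureFactory.featurizers["logfbank"] is FeatureFactory.featurizers["fbank"]
featurizer_cls = FeatureFactory.featurizers["fbank"]  # FilterbankFeatures
```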