@@ -30,9 +30,11 @@ def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
     def forward(self, x):

         batch_size, seq_len, feat_dim = x.size()
-        num_padding_frames = (self.downsample_rate - seq_len % self.downsample_rate) % self.downsample_rate
+        num_padding_frames = (
+            self.downsample_rate - seq_len % self.downsample_rate
+        ) % self.downsample_rate
         if num_padding_frames > 0:
-            x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
+            x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
         seq_len = x.size(1)

         x = x.contiguous()
@@ -62,12 +64,14 @@ def __init__(
         self,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
+        ctc_output: nn.Module,
         llm: nn.Module,
         encoder_projector: nn.Module,
     ):
         super().__init__()
         self.encoder_embed = encoder_embed
         self.encoder = encoder
+        self.ctc_output = ctc_output
         self.llm = llm
         self.encoder_projector = encoder_projector

@@ -186,7 +190,7 @@ def _merge_input_ids_with_speech_features(
             (final_attention_mask == 0), 1
         )

-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+        # 6. Mask out the embedding at padding positions, as we later use past_key_values to determine the non-attended tokens.
         batch_indices, pad_indices = torch.where(
             input_ids == self.llm.config.pad_token_id
         )
@@ -230,6 +234,57 @@ def forward_encoder(

         return encoder_out, encoder_out_lens

+    def ctc_compress(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        blank_id: int = 0,
+    ) -> torch.Tensor:
+        """
+        Remove frames where the CTC argmax predicts blank and return the compressed CTC output.
+        Args:
+          encoder_out: Tensor of shape (N, T, C), encoder output.
+          encoder_out_lens: Tensor of shape (N,), lengths before padding.
+          blank_id: CTC blank token ID (default: 0).
+
+        Returns:
+          Compressed CTC output of shape (N, T', C).
+        """
+        # 1. Compute CTC argmax predictions
+        ctc_output = self.ctc_output(encoder_out)
+        ctc_preds = ctc_output.argmax(dim=-1)
+
+        # 2. Create a non-blank, non-pad mask
+        padding_mask = make_pad_mask(encoder_out_lens)
+        non_blank_mask = (ctc_preds != blank_id) & (~padding_mask)
+
+        # 3. Compute lengths after compression
+        compressed_lens = non_blank_mask.sum(dim=1)
+        max_len = compressed_lens.max().item()
+
+        # 4. Pre-pad the output so every row can be gathered to the same length
+        pad_lens_list = (
+            torch.full_like(
+                compressed_lens,
+                max_len,
+                device=ctc_output.device,
+            )
+            - compressed_lens
+        )
+        max_pad_len = int(pad_lens_list.max())
+        padded_ctc_output = torch.nn.functional.pad(ctc_output, [0, 0, 0, max_pad_len])
+
+        # 5. Create the final mask: kept frames plus per-row padding
+        padding_mask = ~make_pad_mask(pad_lens_list)
+        total_mask = torch.concat([non_blank_mask, padding_mask], dim=1)
+
+        # 6. Apply the mask and reshape to (N, T', C)
+        compressed_output = padded_ctc_output[total_mask].reshape(
+            ctc_output.shape[0], -1, ctc_output.shape[2]
+        )
+
+        return compressed_output
+
     def forward(
         self,
         fbank: torch.Tensor,
@@ -238,9 +293,11 @@ def forward(
         attention_mask: torch.Tensor,
         labels: torch.LongTensor,
     ):
-        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
+        encoder_outs, encoder_out_lens = self.forward_encoder(fbank, fbank_lens)

-        speech_features = self.encoder_projector(encoder_outs)
+        compressed_encoder_outs = self.ctc_compress(encoder_outs, encoder_out_lens)
+
+        speech_features = self.encoder_projector(compressed_encoder_outs)

         inputs_embeds = self.llm.get_input_embeddings()(input_ids)

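To make the compression step easy to sanity-check outside the model, here is a minimal, self-contained sketch of the same blank-removal idea on dummy tensors. It is not part of the patch: make_pad_mask below is a local stand-in with the usual icefall semantics (True at padded positions), and toy_ctc_compress is an illustrative name only.

import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # True at padded positions, matching the usual icefall convention.
    max_len = max(max_len, int(lengths.max()))
    idx = torch.arange(max_len, device=lengths.device)
    return idx.unsqueeze(0) >= lengths.unsqueeze(1)


def toy_ctc_compress(
    ctc_output: torch.Tensor,  # (N, T, C) CTC log-probs, index 0 = blank
    lens: torch.Tensor,  # (N,) valid lengths before padding
    blank_id: int = 0,
) -> torch.Tensor:
    preds = ctc_output.argmax(dim=-1)  # (N, T)
    non_blank = (preds != blank_id) & ~make_pad_mask(lens, ctc_output.size(1))
    kept = non_blank.sum(dim=1)  # frames kept per utterance
    pad_lens = int(kept.max()) - kept  # padding needed to equalize row lengths
    # Append dummy frames, then select exactly max(kept) frames per row.
    padded = torch.nn.functional.pad(ctc_output, [0, 0, 0, int(pad_lens.max())])
    tail_mask = ~make_pad_mask(pad_lens, int(pad_lens.max()))
    total_mask = torch.cat([non_blank, tail_mask], dim=1)
    return padded[total_mask].reshape(ctc_output.size(0), -1, ctc_output.size(2))


if __name__ == "__main__":
    torch.manual_seed(0)
    ctc_output = torch.randn(2, 10, 5).log_softmax(dim=-1)  # fake CTC posteriors
    lens = torch.tensor([10, 7])
    print(toy_ctc_compress(ctc_output, lens).shape)  # (2, T', 5)

Running it prints a shape of the form (2, T', 5), where T' is the largest per-utterance count of non-blank frames in the batch.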