
Commit 5aa64a7

use ctc compress (suggested by @shylockasr)
1 parent ed2124a commit 5aa64a7

1 file changed: +62, -5 lines

egs/speech_llm/ASR_LLM/zipformer_llm_zh/model.py (62 additions & 5 deletions)
@@ -30,9 +30,11 @@ def __init__(self, encoder_dim, llm_dim, downsample_rate=5):
     def forward(self, x):
 
         batch_size, seq_len, feat_dim = x.size()
-        num_padding_frames = (self.downsample_rate - seq_len % self.downsample_rate) % self.downsample_rate
+        num_padding_frames = (
+            self.downsample_rate - seq_len % self.downsample_rate
+        ) % self.downsample_rate
         if num_padding_frames > 0:
-            x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
+            x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))
         seq_len = x.size(1)
 
         x = x.contiguous()
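Note on the hunk above (not part of the diff): the reformatted expression pads the time axis up to the next multiple of downsample_rate, presumably so the projector can later stack frames in groups of downsample_rate (the rest of forward is not shown in this hunk). A minimal standalone sketch, with made-up sizes and plain PyTorch only:

import torch

downsample_rate = 5
x = torch.randn(1, 23, 16)  # made-up (N, T, C); real shapes come from the encoder
batch_size, seq_len, feat_dim = x.size()
num_padding_frames = (downsample_rate - seq_len % downsample_rate) % downsample_rate
print(num_padding_frames)   # (5 - 23 % 5) % 5 = 2
if num_padding_frames > 0:
    x = torch.nn.functional.pad(x, (0, 0, 0, num_padding_frames))  # pads only the time axis
print(x.shape)              # torch.Size([1, 25, 16]); 25 is a multiple of downsample_rate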
@@ -62,12 +64,14 @@ def __init__(
         self,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
+        ctc_output: nn.Module,
         llm: nn.Module,
         encoder_projector: nn.Module,
     ):
         super().__init__()
         self.encoder_embed = encoder_embed
         self.encoder = encoder
+        self.ctc_output = ctc_output
         self.llm = llm
         self.encoder_projector = encoder_projector
 
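The hunk above only threads a ctc_output module through __init__; the module itself is constructed elsewhere (e.g. in the training script) and is not part of this diff. As a purely illustrative assumption, a CTC head of this kind is typically a projection from the encoder dimension to the vocabulary size with log-softmax, along the lines of:

import torch.nn as nn

encoder_dim, vocab_size = 512, 500  # hypothetical sizes, not taken from this commit
ctc_output = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Linear(encoder_dim, vocab_size),
    nn.LogSoftmax(dim=-1),  # ctc_compress below only needs the argmax over this axis
)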
@@ -186,7 +190,7 @@ def _merge_input_ids_with_speech_features(
             (final_attention_mask == 0), 1
         )
 
-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+        # 6. Mask compressed_output the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
         batch_indices, pad_indices = torch.where(
             input_ids == self.llm.config.pad_token_id
         )
@@ -230,6 +234,57 @@ def forward_encoder(
 
         return encoder_out, encoder_out_lens
 
+    def ctc_compress(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        blank_id: int = 0,
+    ) -> torch.Tensor:
+        """
+        Remove frames from encoder_out where CTC argmax predicts blank.
+        Args:
+          encoder_out: Tensor of shape (N, T, C), encoder output.
+          encoder_out_lens: Tensor of shape (N,), lengths before padding.
+          blank_id: CTC blank token ID (default: 0).
+
+        Returns:
+          Compressed CTC output of shape (N, T', C).
+        """
+        # 1. Compute CTC argmax predictions
+        ctc_output = self.ctc_output(encoder_out)
+        ctc_preds = ctc_output.argmax(dim=-1)
+
+        # 2. Create non-blank, non-pad mask
+        padding_mask = make_pad_mask(encoder_out_lens)
+        non_blank_mask = (ctc_preds != blank_id) & (~padding_mask)
+
+        # 3. Compute lengths after compress
+        compressed_lens = non_blank_mask.sum(dim=1)
+        max_len = compressed_lens.max().item()
+
+        # 4. Pre-pad output
+        pad_lens_list = (
+            torch.full_like(
+                compressed_lens,
+                max_len,
+                device=ctc_output.device,
+            )
+            - compressed_lens
+        )
+        max_pad_len = int(pad_lens_list.max())
+        padded_ctc_output = torch.nn.functional.pad(ctc_output, [0, 0, 0, max_pad_len])
+
+        # 5. Create final mask
+        padding_mask = ~make_pad_mask(pad_lens_list)
+        total_mask = torch.concat([non_blank_mask, padding_mask], dim=1)
+
+        # 6. Apply mask and reshape
+        compressed_output = padded_ctc_output[total_mask].reshape(
+            ctc_output.shape[0], -1, ctc_output.shape[2]
+        )
+
+        return compressed_output
+
     def forward(
         self,
         fbank: torch.Tensor,
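To make the new ctc_compress concrete, here is a self-contained toy reproduction of the same steps (argmax, non-blank mask, re-pad, reshape) with a batch of two utterances, so the re-padding trick is visible. The values, sizes, and the local make_pad_mask stand-in (True at padded positions, the convention the method relies on) are all made up for illustration and are not taken from the commit.

import torch


def make_pad_mask(lengths):
    # Minimal stand-in for the helper used above: True at padded positions.
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)


blank_id = 0
ctc_output = torch.zeros(2, 4, 3)        # (N, T, C) toy CTC log-probs
ctc_output[0, :, blank_id] = 10.0        # utterance 0: argmax is blank everywhere...
ctc_output[0, 1, 1] = 20.0               # ...except at frame 1
ctc_output[1, :, 2] = 10.0               # utterance 1: no blank frames at all
encoder_out_lens = torch.tensor([4, 3])  # utterance 1 ends with one padded frame

non_blank_mask = (ctc_output.argmax(dim=-1) != blank_id) & (~make_pad_mask(encoder_out_lens))
compressed_lens = non_blank_mask.sum(dim=1)       # tensor([1, 3])
max_len = int(compressed_lens.max())              # 3
pad_lens_list = max_len - compressed_lens         # tensor([2, 0]): frames to pad back per utterance
padded = torch.nn.functional.pad(ctc_output, [0, 0, 0, int(pad_lens_list.max())])
total_mask = torch.concat([non_blank_mask, ~make_pad_mask(pad_lens_list)], dim=1)
compressed = padded[total_mask].reshape(2, -1, 3)
print(compressed.shape)                           # torch.Size([2, 3, 3]): every row keeps exactly max_len frames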
@@ -238,9 +293,11 @@ def forward(
         attention_mask: torch.Tensor,
         labels: torch.LongTensor,
     ):
-        encoder_outs, _ = self.forward_encoder(fbank, fbank_lens)
+        encoder_outs, encoder_out_lens = self.forward_encoder(fbank, fbank_lens)
 
-        speech_features = self.encoder_projector(encoder_outs)
+        compressed_encoder_outs = self.ctc_compress(encoder_outs, encoder_out_lens)
+
+        speech_features = self.encoder_projector(compressed_encoder_outs)
 
         inputs_embeds = self.llm.get_input_embeddings()(input_ids)
 
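Finally, a toy end-to-end shape check of what the new forward() path buys: encoder frames whose CTC argmax is blank are dropped before the projector's 5x frame stacking, so fewer speech tokens reach the LLM. The blank pattern, the sizes, and the frame-stacking view (the projector hunk above stops at x.contiguous(), so the view here is an assumption about what follows it) are illustrative only.

import torch

downsample_rate, C = 5, 8
encoder_outs = torch.randn(1, 40, C)      # pretend encoder output (N, T, C)
keep = torch.zeros(40, dtype=torch.bool)
keep[::3] = True                          # pretend roughly one frame in three is non-blank
compressed = encoder_outs[:, keep, :]     # what ctc_compress would hand to the projector
pad = (downsample_rate - compressed.size(1) % downsample_rate) % downsample_rate
compressed = torch.nn.functional.pad(compressed, (0, 0, 0, pad))
stacked = compressed.view(1, -1, C * downsample_rate)   # assumed frame-stacking step
print(encoder_outs.shape, stacked.shape)  # torch.Size([1, 40, 8]) torch.Size([1, 3, 40])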