Draft pull request: changes from all commits

Commits (28):
- dc07bba: init (yfyeung, Apr 30, 2025)
- 9c8c431: init zipformer_llm_zh (yfyeung, May 7, 2025)
- 23b5a7c: format multi_dataset.py (yfyeung, May 7, 2025)
- 211c01b: format train.py (yfyeung, May 7, 2025)
- 489c42b: support zipformer encoder (yfyeung, May 8, 2025)
- ec6c8f7: fix data prepare (yfyeung, May 9, 2025)
- 2420d0c: update multi_dataset.py (yfyeung, May 9, 2025)
- c75767f: set world_size and rank explicitly (yfyeung, May 10, 2025)
- cd3adad: use quadratic-duration (yfyeung, May 10, 2025)
- 5fbeed9: fix SwooshR and SwooshL (yfyeung, May 11, 2025)
- 9939c2b: remove duplicated torch autocast (yfyeung, May 11, 2025)
- c078772: skip OOM (yfyeung, May 11, 2025)
- 2793ccd: remove checkpoint save after validation (yfyeung, May 12, 2025)
- c709ce4: Merge branch 'k2-fsa:master' into dev/speechllm (yfyeung, May 12, 2025)
- ea20ac2: Merge branch 'k2-fsa:master' into dev/speechllm (yfyeung, May 12, 2025)
- 06667e1: add batch shave mechanism (yfyeung, May 12, 2025)
- 62dfe56: restore checkpoint save after validation (yfyeung, May 13, 2025)
- 24b6f42: fix typos in docs (yfyeung, May 13, 2025)
- d1a535d: Merge branch 'k2-fsa:master' into dev/speechllm (yfyeung, May 24, 2025)
- 11ccaa3: add requirements.txt (yfyeung, May 26, 2025)
- 7c30dd5: restrict deepspeed >=0.16.9 (yfyeung, May 28, 2025)
- 05e3094: refactor branch exchange in cr-ctc (#1954) (yaozengwei, May 27, 2025)
- 34639d5: use padding instead of trimming (suggested by @shylockasr) (yfyeung, Jun 3, 2025)
- c571a88: Merge branch 'k2-fsa:master' into dev/speechllm (yfyeung, Jun 18, 2025)
- 39d9035: fix deepspeed config (yfyeung, Jun 18, 2025)
- 53111d0: fix for multigpu (yfyeung, Jun 18, 2025)
- 5634900: Merge branch 'k2-fsa:master' into dev/speechllm (yfyeung, Jun 18, 2025)
- 70f13e5: Merge branch 'k2-fsa:master' into dev/speechllm (yfyeung, Jul 7, 2025)
4 changes: 2 additions & 2 deletions egs/multi_zh-hans/ASR/README.md
@@ -3,7 +3,7 @@

This recipe includes scripts for training Zipformer model using multiple Chinese datasets.

-# Included Training Sets
+# Included Training Dataset
1. THCHS-30
2. AiShell-{1,2,4}
3. ST-CMDS
@@ -14,7 +14,7 @@ This recipe includes scripts for training Zipformer model using multiple Chinese datasets.
8. WeNetSpeech
9. KeSpeech-ASR

-|Datset| Number of hours| URL|
+|Dataset| Number of hours| URL|
|---|---:|---|
|**TOTAL**|14,106|---|
|THCHS-30|35|https://www.openslr.org/18/|
8 changes: 4 additions & 4 deletions egs/multi_zh-hans/ASR/RESULTS.md
@@ -99,7 +99,7 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
-| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
+| Split | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Transducer Greedy Offline | 21.67 | 23.43 | 1.22 | 1.31 | 3.17 | 3.27 | 14.64 | 2.42 | 1.99 | 5.00 | 2.29 | 5.98 | 5.15 | 5.85 | 6.89 |

Pre-trained model can be found here : https://huggingface.co/yuekai/icefall-asr-multi-zh-hans-zipformer-xl
@@ -152,7 +152,7 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
-| Zipformer CER (%) | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
+| Split | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| CTC Greedy Streaming | 26.50 | 28.10| 1.71 | 1.97| 3.89| 4.06 | 17.23 | 3.69 | 2.87 | 8.14 | 3.61 |9.51 | 6.11 | 8.13 | 10.62 |
| CTC Greedy Offline | 23.47 | 25.02 | 1.39 | 1.50 | 3.15 | 3.41 | 15.14 | 3.07 | 2.37 | 6.06 | 2.90 | 7.13 | 5.40 | 6.52 | 9.64 |
| Transducer Greedy Offline | 23.16 | 24.78 | 1.33 | 1.38 | 3.06 | 3.23 | 15.36 | 2.54 | 2.09 | 5.24 | 2.28 | 6.26 | 4.87 | 6.26 | 7.07 |
@@ -193,7 +193,7 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the

| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
-| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
+| Split | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 |
| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 |

@@ -226,7 +226,7 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the

| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
-| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
+| Split | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 |


36 changes: 35 additions & 1 deletion egs/multi_zh-hans/ASR/zipformer/asr_datamodule.py
@@ -109,6 +109,25 @@ def add_arguments(cls, parser: argparse.ArgumentParser):
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
+group.add_argument(
+"--num-cuts-for-bins-estimate",
+type=int,
+default=10000,
+help="We will draw this many cuts to estimate the duration "
+"bins for creating similar-duration buckets. A larger number "
+"means a better estimate of the data distribution, possibly "
+"at a longer init cost.",
+)
+group.add_argument(
+"--quadratic-duration",
+type=float,
+default=None,
+help="When set, it adds an extra penalty that's quadratic "
+"in size w.r.t. a cut's duration. This helps get a more "
+"even GPU utilization across different input lengths when "
+"models have quadratic input complexity. Set between 15 "
+"and 40 for transformers.",
+)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
@@ -205,6 +224,8 @@ def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
+world_size: Optional[int] = None,
+rank: Optional[int] = None,
) -> DataLoader:
"""
Args:
@@ -295,18 +316,24 @@ def train_dataloaders(
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
+quadratic_duration=self.args.quadratic_duration,
+num_cuts_for_bins_estimate=self.args.num_cuts_for_bins_estimate,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
+world_size=world_size,
+rank=rank,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
+world_size=world_size,
+rank=rank,
)
logging.info("About to create train dataloader")

@@ -330,7 +357,12 @@ def train_dataloaders(

return train_dl

-def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+def valid_dataloaders(
+self,
+cuts_valid: CutSet,
+world_size: Optional[int] = None,
+rank: Optional[int] = None,
+) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@@ -355,6 +387,8 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
+world_size=world_size,
+rank=rank,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
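The additions above route two new CLI options (`--num-cuts-for-bins-estimate`, `--quadratic-duration`) and explicit `world_size`/`rank` values through to lhotse's `DynamicBucketingSampler`. With `quadratic_duration=q`, recent lhotse versions count a cut of duration `d` as roughly `d + d*d/q` seconds against `--max-duration`, so long utterances take a larger share of the batch budget and GPU load stays more even across input lengths. Below is a minimal sketch of that wiring, not the recipe code; the manifest path follows the layout in prepare.sh, and the option values and four-GPU setup are illustrative.

```python
# Minimal sketch of the new sampler options; values are illustrative.
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

# Lazily open a training manifest (path follows prepare.sh's data/fbank/wenetspeech layout).
cuts_train = CutSet.from_jsonl_lazy("data/fbank/wenetspeech/cuts_L_fixed.jsonl.gz")

train_sampler = DynamicBucketingSampler(
    cuts_train,
    max_duration=600,                  # seconds of (penalized) audio per batch
    quadratic_duration=30.0,           # a cut of duration d counts as about d + d*d/30
    num_cuts_for_bins_estimate=10000,  # cuts drawn up front to estimate duration bins
    shuffle=True,
    num_buckets=30,
    drop_last=True,
    world_size=4,                      # passed explicitly instead of inferred from
    rank=0,                            # torch.distributed, e.g. under DeepSpeed launchers
)

for cuts_batch in train_sampler:
    # Each item is a CutSet whose total penalized duration stays under max_duration.
    pass
```

In the recipe itself these values come from `self.args` inside `train_dataloaders`, as the diff above shows.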
4 changes: 4 additions & 0 deletions egs/speech_llm/ASR_LLM/.gitignore
@@ -0,0 +1,4 @@
+models
+train*.sh
+decode*.sh
+sync*.sh
10 changes: 10 additions & 0 deletions egs/speech_llm/ASR_LLM/prepare.sh
@@ -37,6 +37,15 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
+mkdir data/fbank/wenetspeech
+mv data/fbank/cuts_L_fixed.jsonl.gz data/fbank/wenetspeech/
+mv data/fbank/cuts_DEV_fixed.jsonl.gz data/fbank/wenetspeech/
+mv data/fbank/cuts_TEST_MEETING.jsonl.gz data/fbank/wenetspeech/
+mv data/fbank/cuts_TEST_NET.jsonl.gz data/fbank/wenetspeech/
+mv data/fbank/L_split_100 data/fbank/wenetspeech/
+mv data/fbank/feats_DEV.lca data/fbank/wenetspeech/
+mv data/fbank/feats_TEST_MEETING.lca data/fbank/wenetspeech/
+mv data/fbank/feats_TEST_NET.lca data/fbank/wenetspeech/
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -46,4 +55,5 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
mkdir data_speechio
huggingface-cli download --repo-type model --local-dir data_speechio yuekai/icefall_asr_speechio
mv data_speechio/fbank/* data/fbank
+rm -rf data_speechio
fi
2 changes: 1 addition & 1 deletion egs/speech_llm/ASR_LLM/whisper_llm_zh/ds_config_zero1.json
@@ -5,7 +5,7 @@
"loss_scale_window": 100,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 0.01
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 1,
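With fp16 dynamic loss scaling, DeepSpeed shrinks the loss scale whenever gradients overflow; raising `min_loss_scale` from 0.01 to 1 keeps the scale from decaying so far that legitimate gradients underflow to zero. A hedged sketch of using the same settings programmatically is below; the `fp16.enabled` flag, batch-size key, and placeholder model are assumptions, since only the lines visible in the diff are known from this file.

```python
# Sketch: the fp16 block above expressed as a Python dict and handed to DeepSpeed.
# The model is a placeholder; keys not shown in the diff (enabled, batch size) are assumed.
import torch
import deepspeed

ds_config = {
    "fp16": {
        "enabled": True,
        "loss_scale": 0,           # 0 selects dynamic loss scaling
        "loss_scale_window": 100,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1,       # floor the dynamic scale at 1 (was 0.01)
    },
    "zero_optimization": {"stage": 1},
    "train_micro_batch_size_per_gpu": 1,
}

model = torch.nn.Linear(80, 80)    # stand-in for the recipe's speech-LLM model

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```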
8 changes: 4 additions & 4 deletions egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
@@ -194,10 +194,10 @@ def _merge_input_ids_with_speech_features(

def forward(
self,
-fbank: torch.Tensor = None,
-input_ids: torch.LongTensor = None,
-attention_mask: torch.Tensor = None,
-labels: torch.LongTensor = None,
+fbank: torch.Tensor,
+input_ids: torch.LongTensor,
+attention_mask: torch.Tensor,
+labels: torch.LongTensor,
):
encoder_outs = self.encoder(fbank)

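Dropping the `= None` defaults makes all four inputs required: forgetting one now raises a `TypeError` at the call site instead of letting `None` reach the encoder. The toy stub below only mirrors the new signature to show that effect; it is not the recipe's model, and the tensor names and shapes are made up.

```python
import torch


class SpeechLLMStub(torch.nn.Module):
    """Stand-in with the same required-argument forward signature; not the real model."""

    def forward(
        self,
        fbank: torch.Tensor,
        input_ids: torch.LongTensor,
        attention_mask: torch.Tensor,
        labels: torch.LongTensor,
    ):
        # Dummy scalar so the example runs end to end.
        return fbank.float().mean() + labels.float().mean()


model = SpeechLLMStub()
batch, n_mels, frames, seq_len = 2, 80, 3000, 8

loss = model(
    fbank=torch.randn(batch, n_mels, frames),
    input_ids=torch.randint(0, 100, (batch, seq_len)),
    attention_mask=torch.ones(batch, seq_len, dtype=torch.long),
    labels=torch.randint(0, 100, (batch, seq_len)),
)
# Omitting any of the four keyword arguments now fails immediately with a TypeError.
```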