A simple repository for training a ~100M-parameter language model to sample Residual Vector Quantization (RVQ) audio tokens conditioned on phoneme tokens for Text-to-Speech (TTS) synthesis. The audio tokens are encoded with the MiMo Audio Tokenizer on LibriTTS, using only the first 8 RVQ layers. The language model is based on Qwen3. Two decoding patterns are implemented: the parallel pattern and the delay pattern. Some of the dataset and dataloader code is borrowed from F5-TTS.
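The two patterns differ in how the 8 codebook streams are laid out over time: the parallel pattern emits all codebooks of a frame at the same step, while the delay pattern shifts codebook q right by q steps so later codebooks of a frame are generated after the earlier ones (the layout popularized by MusicGen). The sketch below is only an illustration of that delay layout, not the code in `rvqllm_delay.py`; `apply_delay_pattern`, `remove_delay_pattern` and `pad_id` are placeholders.

```python
import torch

def apply_delay_pattern(codes: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Shift codebook q right by q frames. codes: (T, Q) LongTensor of RVQ
    indices -> (T + Q - 1, Q) delayed layout, padded with pad_id."""
    T, Q = codes.shape
    delayed = codes.new_full((T + Q - 1, Q), pad_id)
    for q in range(Q):
        delayed[q:q + T, q] = codes[:, q]
    return delayed

def remove_delay_pattern(delayed: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Invert the shift to recover the original (num_frames, Q) code matrix."""
    Q = delayed.shape[1]
    return torch.stack([delayed[q:q + num_frames, q] for q in range(Q)], dim=1)
```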
Samples are generated with the delay pattern model trained for 20 epochs on the LibriTTS train-clean-100 set. The model is conditioned on phonemes extracted using phonemizer.
- The more I think about language, the more it amazes me that we can communicate so much with so few sounds.
gen.mp4
- Hello my name is Archimickey. I am currently building a text to speech system using residual vector quantization and language models.
gen2.mp4
Each model is trained on 4 RTX 4090 GPUs and reaches acceptable results within 20 epochs. The parallel pattern model and the delay pattern model are defined in rvqllm.py and rvqllm_delay.py, respectively; the corresponding training scripts are train.py and train_delay.py.
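For intuition, the parallel pattern predicts all 8 codebook indices of a frame from the same hidden state. The snippet below is a rough sketch of what such prediction heads could look like; it is not taken from rvqllm.py, and `ParallelRVQHeads`, `hidden_size` and `codebook_size` are illustrative names only.

```python
import torch
import torch.nn as nn

class ParallelRVQHeads(nn.Module):
    """Illustration only: one linear head per RVQ layer, all predicted from
    the same LM hidden state at each frame (the "parallel" pattern)."""

    def __init__(self, hidden_size: int, codebook_size: int, num_quantizers: int = 8):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_size, codebook_size) for _ in range(num_quantizers)]
        )

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # hidden: (B, T, H) -> logits: (B, T, num_quantizers, codebook_size)
        return torch.stack([head(hidden) for head in self.heads], dim=2)
```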
- Install PyTorch
- Install flash-attn
- Install the repo by running `pip install -e .` from the repo root
- Clone submodules by running `git submodule update --init --recursive`
- `cd` into `third_party/mimo_audio_tokenizer` and run `pip install -e .`
- Download the MiMo Audio Tokenizer model checkpoints with `git clone https://huggingface.co/XiaomiMiMo/MiMo-Audio-Tokenizer`
- Run `./scripts/scan_train_audio.sh {libritts_root}` to create `train_wav.scp`
- Run `mimo_tokens.py` to save the LibriTTS train set tokens to `libritts_mimo_tokens.pt` (see the inspection sketch after this list)
- Run `prepare_libritts.py` to prepare the LibriTTS dataset with phonemes and tokens; this will save `libritts_mimo_dataset.arrow`, `vocab.txt` and `libritts_mimo_tokens_lens.json`
- Run `train.py` or `train_delay.py` to train the parallel pattern or delay pattern model, respectively
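A quick sanity check after the token extraction step can help catch path or tokenizer issues early. This is a hypothetical inspection snippet: the exact structure of `libritts_mimo_tokens.pt` (assumed here to be a dict mapping utterance IDs to `(num_frames, 8)` LongTensors of RVQ codes) depends on `mimo_tokens.py`.

```python
import torch

# Assumption: the dump is a dict of utterance ID -> (num_frames, 8) LongTensor.
tokens = torch.load("libritts_mimo_tokens.pt")
utt_id, codes = next(iter(tokens.items()))
print(utt_id, codes.shape, codes.dtype)  # e.g. "1034_121119_000001_000001", (T, 8), torch.int64
```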
DeepSpeed ZeRO-2 is used to train the model on 4 RTX 4090s. You can run training with my DeepSpeed config using these commands:
# For parallel pattern:
accelerate launch --config_file config.yml train.py
# For delay pattern:
accelerate launch --config_file config.yml train_delay.py

Construct the model, phonemizer and tokenizer, and load the trained model weights.
import torch

from rvqllm_delay import RVQwen3LM_
from third_party.mimo_audio_tokenizer import mimo_audio_tokenizer
from phonemizer.backend import EspeakBackend
from phonemizer.separator import Separator

# Build the delay pattern model and load the trained weights.
model = RVQwen3LM_().bfloat16().to(device)
model.load_state_dict(...)

# Load the MiMo Audio Tokenizer checkpoint used to decode RVQ tokens back to audio.
tokenizer = mimo_audio_tokenizer.load_model("...").bfloat16().to(device)

# Phonemize the text to synthesize and map each symbol to its vocabulary id.
phn_backend = EspeakBackend('en-us')
gen_text = "..."
phn = phn_backend.phonemize([gen_text], separator=Separator(phone=None, word=' '))[0]
phn_ids = [phn_to_id[p] for p in phn]
phn_ids = torch.tensor(phn_ids, dtype=torch.long).unsqueeze(0)
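The snippet above assumes a `phn_to_id` mapping from phoneme symbols to vocabulary indices. A minimal sketch of building it, assuming the `vocab.txt` written by `prepare_libritts.py` stores one symbol per line with the line index as the token id:

```python
# Assumption: vocab.txt has one phoneme symbol per line; line index == token id.
with open("vocab.txt", encoding="utf-8") as f:
    phn_to_id = {symbol: idx for idx, symbol in enumerate(f.read().splitlines())}
```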
Sample tokens with the prompt and decode them back to a waveform.

model.eval()
with torch.no_grad():
    with torch.autocast(device_type=device, dtype=dtype):
        # Autoregressively sample RVQ tokens conditioned on the phonemes and the audio prompt.
        gen_audio_tokens = list(model.inference(text, text_lens, prompt_text=prompt_text, prompt_text_lens=prompt_text_lens, prompt_audio_tokens=prompt_audio_tokens, prompt_audio_lens=prompt_audio_lens, min_len=50, max_len=1024, top_k=3))

# Stack the sampled frames into a batch and decode them to audio with the tokenizer.
_gen_audio_tokens = torch.stack(gen_audio_tokens).unsqueeze(0)
code_lens = torch.tensor([t.size(0) for t in _gen_audio_tokens])
wavs, wavs_lens, _ = tokenizer.decode(_gen_audio_tokens.to(device), code_lens.to(device))
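Finally, the decoded waveform can be trimmed and written to disk, for example with torchaudio. The `(batch, samples)` layout of `wavs` and the 24 kHz sample rate are assumptions; check the MiMo Audio Tokenizer output for the actual values.

```python
import torchaudio

wav_len = int(wavs_lens[0])
wav = wavs[0, :wav_len].float().cpu()                 # trim padding, move to CPU
torchaudio.save("gen.wav", wav.unsqueeze(0), 24000)   # (channels, samples) at assumed 24 kHz
```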
