
RVQ-LLM

A simple repository for training a ~100M-parameter language model to sample Residual Vector Quantization (RVQ) audio tokens conditioned on phoneme tokens for Text-to-Speech (TTS) synthesis. The audio tokens are encoded with the MiMo Audio Tokenizer on LibriTTS, using only the first 8 RVQ layers. The language model is based on Qwen3. Two decoding patterns are implemented: the parallel pattern and the delay pattern (see the sketch below). Some of the dataset and dataloader code is borrowed from F5-TTS.
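The difference between the two patterns: the parallel pattern predicts all 8 codebook tokens of a frame in one step, while the delay pattern shifts codebook k right by k steps, so later codebooks can condition on earlier codebooks' tokens for the same frame. A minimal sketch of the delayed layout, assuming a (num_codebooks, num_frames) token grid; the helper names and pad_id are illustrative, not from this repo:

import torch

def apply_delay_pattern(codes: torch.Tensor, pad_id: int) -> torch.Tensor:
    # codes: (K, T) grid of RVQ ids, K codebooks (8 here), T frames.
    K, T = codes.shape
    delayed = torch.full((K, T + K - 1), pad_id, dtype=codes.dtype)
    for k in range(K):
        delayed[k, k : k + T] = codes[k]  # shift codebook k right by k steps
    return delayed

def undo_delay_pattern(delayed: torch.Tensor) -> torch.Tensor:
    # Inverse: recover the (K, T) grid from the (K, T + K - 1) delayed layout.
    K, L = delayed.shape
    T = L - (K - 1)
    return torch.stack([delayed[k, k : k + T] for k in range(K)])

With the parallel pattern no shift is applied; the model emits all K ids of a frame in a single step.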

Zero-shot TTS samples

Samples are generated with the delay pattern model trained for 20 epochs on the LibriTTS train-clean-100 subset. The model is conditioned on phonemes extracted with phonemizer.

  1. "The more I think about language, the more it amazes me that we can communicate so much with so few sounds." (gen.mp4)
  2. "Hello my name is Archimickey. I am currently building a text to speech system using residual vector quantization and language models." (gen2.mp4)

Parallel pattern vs. delay pattern samples

Each model is trained on 4 RTX 4090 GPUs and reaches acceptable quality within 20 epochs. The parallel pattern and delay pattern models are defined in rvqllm.py and rvqllm_delay.py respectively; the corresponding training scripts are train.py and train_delay.py.

Usage

Setup

  1. Install PyTorch.
  2. Install flash-attn.
  3. Install this repo by running pip install -e . from the repo root.
  4. Clone the submodules by running git submodule update --init --recursive.
  5. cd into third_party/mimo_audio_tokenizer and run pip install -e .
  6. Download the MiMo Audio Tokenizer checkpoints with git clone https://huggingface.co/XiaomiMiMo/MiMo-Audio-Tokenizer
  7. Run ./scripts/scan_train_audio.sh {libritts_root} to create train_wav.scp.
  8. Run mimo_tokens.py to save the LibriTTS train-set tokens to libritts_mimo_tokens.pt.
  9. Run prepare_libritts.py to build the LibriTTS dataset with phonemes and tokens; this saves libritts_mimo_dataset.arrow, vocab.txt, and libritts_mimo_tokens_lens.json (see the vocab-loading sketch after this list).
  10. Run train.py or train_delay.py to train the parallel pattern or delay pattern model respectively.
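The inference example below needs a phoneme-to-id table. Assuming vocab.txt stores one phoneme symbol per line (the F5-TTS-style layout; check the file prepare_libritts.py actually produces), it could be loaded like this:

with open("vocab.txt", encoding="utf-8") as f:
    # Strip only the trailing newline, so a bare space token on its own line survives.
    phn_to_id = {line.rstrip("\n"): i for i, line in enumerate(f)}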

Training:

DeepSpeed ZeRO-2 is employed to train the model on 4 RTX 4090s. You can launch training with the provided DeepSpeed config using these commands:

# For parallel pattern:
accelerate launch --config_file config.yml train.py
# For delay pattern:
accelerate launch --config_file config.yml train_delay.py
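For reference, an accelerate config enabling DeepSpeed ZeRO-2 across 4 GPUs looks roughly like the sketch below; the repo's actual config.yml may set additional fields:

compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 2
  gradient_accumulation_steps: 1
mixed_precision: bf16
num_machines: 1
num_processes: 4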

Inference example:

Construct the model, phonemizer, and tokenizer, then load the trained model weights.

import torch

from rvqllm_delay import RVQwen3LM_
from third_party.mimo_audio_tokenizer import mimo_audio_tokenizer
from phonemizer.backend import EspeakBackend
from phonemizer.separator import Separator

device = "cuda"  # or "cpu"

model = RVQwen3LM_().bfloat16().to(device)
model.load_state_dict(...)

tokenizer = mimo_audio_tokenizer.load_model("...").bfloat16().to(device)

phn_backend = EspeakBackend('en-us')
gen_text = "..."
phn = phn_backend.phonemize([gen_text], separator=Separator(phone=None, word=' '))[0]
# phn_to_id maps each phoneme character to its id (built from vocab.txt, see Setup).
phn_ids = [phn_to_id[p] for p in phn]
phn_ids = torch.tensor(phn_ids, dtype=torch.long).unsqueeze(0)

Sample tokens with the prompt and decode them back to a waveform.

model.eval()
with torch.no_grad():
	with torch.autocast(device_type=device, dtype=torch.bfloat16):
		# text / text_lens are the phoneme ids and their lengths; the prompt_*
		# arguments carry the zero-shot speaker prompt (its phonemes and audio tokens).
		gen_audio_tokens = list(model.inference(
			text, text_lens,
			prompt_text=prompt_text, prompt_text_lens=prompt_text_lens,
			prompt_audio_tokens=prompt_audio_tokens, prompt_audio_lens=prompt_audio_lens,
			min_len=50, max_len=1024, top_k=3,
		))

_gen_audio_tokens = torch.stack(gen_audio_tokens).unsqueeze(0)  # (1, T, num_quantizers)
code_lens = torch.tensor([t.size(0) for t in _gen_audio_tokens])
wavs, wavs_lens, _ = tokenizer.decode(_gen_audio_tokens.to(device), code_lens.to(device))
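The decoded batch can then be trimmed to its valid length and written to disk with torchaudio. The 24 kHz sample rate below is an assumption; check the MiMo Audio Tokenizer config for the actual value.

import torchaudio

wav = wavs[0, : wavs_lens[0]].float().cpu().unsqueeze(0)  # (1, num_samples)
torchaudio.save("gen.wav", wav, sample_rate=24000)  # sample rate is an assumption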
