A research project exploring diffusion-based text generation using RoBERTa as an alternative to traditional autoregressive language models. View the related blog post, BERT is just a Single Text Diffusion Step.
Since the blog post, I have modified decoding to use confidence-aware parallel decoding instead of iterative refinement. Instead of stepping through the masking schedule, repeatedly predicting all tokens and applying the next mask, the model now decodes at each step every token whose confidence exceeds a given threshold (or just the most confident token if none reach it). This generally improves model output, but it seems to make things worse in undertrained settings.
After two hours of training on an H200, the model still repeated itself as a byproduct of always committing the most confident tokens, while in the iterative-refinement case, the curse of parallel decoding actually led to more varied output.
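For reference, the decoding loop looks roughly like this. This is a minimal sketch, not the project's exact implementation (see `inference.py`); `MODEL_DIR` and `CONFIDENCE_THRESHOLD` are placeholder names, and it assumes a batch of size 1 where the prompt occupies the prefix positions and the rest of the sequence starts out as `<mask>` tokens:

```python
import torch
from transformers import RobertaForMaskedLM, RobertaTokenizerFast

# Placeholder values, not the project's real configuration.
MODEL_DIR = "path/to/finetuned-roberta"
CONFIDENCE_THRESHOLD = 0.9

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
model = RobertaForMaskedLM.from_pretrained(MODEL_DIR).eval()

@torch.no_grad()
def confidence_parallel_decode(input_ids: torch.Tensor) -> torch.Tensor:
    """Repeatedly fill masked positions, committing every prediction whose
    confidence clears the threshold (or the single most confident one)."""
    mask_id = tokenizer.mask_token_id
    while (input_ids == mask_id).any():
        logits = model(input_ids).logits                 # (1, seq, vocab)
        probs = logits.softmax(dim=-1)
        conf, pred = probs.max(dim=-1)                   # (1, seq)
        masked = input_ids == mask_id
        conf = conf.masked_fill(~masked, -1.0)           # ignore already-filled slots
        commit = masked & (conf >= CONFIDENCE_THRESHOLD)
        if not commit.any():                             # fall back to the single argmax
            commit = torch.zeros_like(masked)
            commit.view(-1)[conf.view(-1).argmax()] = True
        input_ids = torch.where(commit, pred, input_ids)
    return input_ids
```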
This project uses uv for package management.
```bash
# Install dependencies
uv sync
```

Train your own RoBERTa diffusion model on OpenWebText:
```bash
uv run finetune.py
```

The blog post used wikitext-2 instead of OpenWebText, which for some reason seems to lead to better generations. Feel free to mess around and try changing datasets. I originally trained this by renting an H200 for an hour.
Generate text using the RoBERTa diffusion model:
```bash
uv run inference.py "Your prompt text here"
```

Depending on what `PREFIX_LEN` is set to, your prompt will need to be that many tokens long to avoid being out of distribution. It wouldn't be too hard to add variable-length prefixes during training, though.
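If you're unsure whether your prompt matches the trained prefix length, a quick check like this can help. It's a rough sketch that assumes `config.py` exposes `PREFIX_LEN` as an integer and that token accounting (e.g. special tokens) matches the training setup:

```python
from transformers import RobertaTokenizerFast

from config import PREFIX_LEN  # assumption: config.py exposes PREFIX_LEN as an int

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

def prompt_token_count(prompt: str) -> int:
    """Count how many RoBERTa tokens the prompt occupies (no special tokens)."""
    return len(tokenizer(prompt, add_special_tokens=False).input_ids)

prompt = "Your prompt text here"
n = prompt_token_count(prompt)
if n != PREFIX_LEN:
    print(f"Prompt is {n} tokens; the model expects a {PREFIX_LEN}-token prefix.")
```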
To show the generation step-by-step as an animation, just add this flag:
```bash
uv run inference.py "Your prompt" --animation
```

For comparison, generate text using standard GPT-2:
```bash
uv run gpt2_inference.py "Your prompt text here"
```

Run both models simultaneously and compare outputs:
```bash
uv run compare.py "Your prompt text here"
```

Optionally specify a custom RoBERTa model:
```bash
uv run compare.py "Your prompt" --roberta-dir path/to/model
```

This creates a synchronized animation showing:
- RoBERTa diffusion generation steps
- GPT-2 autoregressive generation
- Timing metrics for both approaches
Training Details:
- Dataset: OpenWebText (large-scale web text corpus)
- Lazy Loading: Data is tokenized on-the-fly during training (can be changed)
- Custom Collator: Handles tokenization and variable masking per batch
- Prefix Preservation: First `PREFIX_LEN` tokens are never masked
- Variable Masking: Trains on all masking ratios from 0% to 100% (sketched below)
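To make the last two bullets concrete, the masking step looks roughly like this. This is a minimal sketch, not the project's exact collator in `finetune.py`, and it ignores tokenization, padding, and special tokens:

```python
import random
import torch

def mask_batch(input_ids: torch.Tensor, mask_token_id: int, prefix_len: int):
    """Apply a uniformly sampled masking ratio to each sequence, never
    touching the first `prefix_len` tokens."""
    input_ids = input_ids.clone()
    labels = input_ids.clone()
    for row in input_ids:
        ratio = random.random()                      # any masking ratio in [0, 1)
        maskable = torch.arange(prefix_len, row.size(0))
        n_mask = int(ratio * maskable.numel())
        chosen = maskable[torch.randperm(maskable.numel())[:n_mask]]
        row[chosen] = mask_token_id
    labels[input_ids != mask_token_id] = -100        # only compute loss on masked slots
    return input_ids, labels
```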
Project Structure:

```
RoBERTaDiffusion/
├── config.py            # Configuration for training and inference
├── finetune.py          # RoBERTa diffusion training script
├── inference.py         # RoBERTa diffusion inference
├── gpt2_inference.py    # GPT-2 baseline inference
└── compare.py           # Side-by-side model comparison
```
