Simplified reimplementation of TinyRecursiveModels using MLX.
- Set up the environment:

  ```sh
  uv sync
  source .venv/bin/activate
  ```

- Adjust the model config in `train.py`:

  ```python
  @dataclass
  class ModelConfig:
      in_channels: int
      depth: int
      dim: int
      heads: int
      patch_size: tuple
      n_outputs: int
      pool: str = "cls"                   # "mean" or "cls"
      n: int = 6                          # latent steps
      T: int = 3                          # deep steps
      halt_max_steps: int = 8             # maximum supervision steps
      halt_exploration_prob: float = 0.2  # exploratory Q probability
      halt_follow_q: bool = True          # follow q (True) or max steps (False)
  ```
- Train on MNIST or CIFAR-10 (see `python train.py --help`):

  ```sh
  python train.py --dataset mnist
  python train.py --dataset cifar10
  ```
- Hyperparams are currently hardcoded for faster experimentation.
- Only MNIST and CIFAR-10 are supported at the moment.
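For reference, the `ModelConfig` fields above resolve to an object like this. The values here are illustrative MNIST-sized choices, not the defaults hardcoded in `train.py`:

```python
from dataclasses import dataclass

# Mirror of the dataclass defined in train.py (reproduced here so the
# snippet is self-contained).
@dataclass
class ModelConfig:
    in_channels: int
    depth: int
    dim: int
    heads: int
    patch_size: tuple
    n_outputs: int
    pool: str = "cls"                   # "mean" or "cls"
    n: int = 6                          # latent steps
    T: int = 3                          # deep steps
    halt_max_steps: int = 8             # maximum supervision steps
    halt_exploration_prob: float = 0.2  # exploratory Q probability
    halt_follow_q: bool = True          # follow q (True) or max steps (False)

# Hypothetical values for 28x28 single-channel MNIST with 10 classes.
cfg = ModelConfig(
    in_channels=1,
    depth=2,
    dim=128,
    heads=4,
    patch_size=(4, 4),
    n_outputs=10,
)
print(cfg.n, cfg.T, cfg.pool)
```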
`uv` handles virtual environment creation automatically.
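The recursion and halting knobs (`n`, `T`, `halt_max_steps`, `halt_exploration_prob`, `halt_follow_q`) interact roughly as sketched below. This is an illustrative skeleton of a TRM-style loop based on the comments in the config, not the actual implementation in `train.py`; the update functions and the learned Q head are stand-ins:

```python
import random

def recursive_forward(n=6, T=3, halt_max_steps=8,
                      halt_exploration_prob=0.2, halt_follow_q=True):
    """Sketch of the nested recursion with adaptive halting.

    Each supervision step runs T deep steps, each of which runs n latent
    steps; a (here randomized) Q signal may halt the loop early.
    """
    steps_taken = 0
    for _ in range(halt_max_steps):      # supervision steps
        for _ in range(T):               # deep steps
            for _ in range(n):           # latent steps: z = latent_update(z, x)
                pass
            # y = deep_update(y, z) would go here
            pass
        steps_taken += 1
        q_says_halt = random.random() > 0.5        # stand-in for learned Q head
        explore = random.random() < halt_exploration_prob
        if halt_follow_q and q_says_halt and not explore:
            break                        # halt early when Q says stop
    return steps_taken
```

With `halt_follow_q=False` the loop always runs the full `halt_max_steps`; with it enabled, the step count varies per example, which is the point of the adaptive-halting scheme.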