MLX Tiny Recursive Models

A simplified reimplementation of TinyRecursiveModels (Recursive Reasoning with Tiny Networks) using MLX.

Usage

  1. Set up the environment

    uv sync
    source .venv/bin/activate
  2. Adjust the model config in train.py (the sketch after this list shows how n, T, and the halting parameters interact)

    from dataclasses import dataclass

    @dataclass
    class ModelConfig:
        in_channels: int  # input image channels (1 for MNIST, 3 for CIFAR-10)
        depth: int        # number of transformer blocks in the tiny network
        dim: int          # embedding dimension
        heads: int        # attention heads
        patch_size: tuple # image patch size (height, width)
        n_outputs: int    # number of output classes
        pool: str = "cls" # "cls" or "mean"
        n: int = 6  # latent steps
        T: int = 3  # deep steps
        halt_max_steps: int = 8  # maximum supervision steps
        halt_exploration_prob: float = 0.2  # exploratory q probability
        halt_follow_q: bool = True  # follow q (True) or max steps (False)
  3. Train on MNIST or CIFAR-10 (see python train.py --help):

    python train.py --dataset mnist
    python train.py --dataset cifar10
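
The n, T, and halt_* fields in ModelConfig control the nested recursion of the Tiny Recursive Models setup: each supervision step runs T deep steps, each deep step refines the latent state n times before updating the answer embedding, and an ACT-style q-head can halt before halt_max_steps is reached. Below is a minimal, self-contained sketch of that control flow; latent_step, answer_step, halt_logit, and the 0.5 halting threshold are illustrative stand-ins, not this repository's actual modules (see train.py for the real implementation).

    import mlx.core as mx

    # Illustrative stand-ins for the recursive core; in the actual model these
    # are calls into the tiny network defined in train.py.
    def latent_step(x, y, z):
        # refine the latent state z given the input x and the current answer y
        return mx.tanh(x + y + z)

    def answer_step(y, z):
        # refine the answer embedding y from the latent state z
        return mx.tanh(y + z)

    def halt_logit(y):
        # q-head: scalar score used to decide whether to stop early
        return mx.sigmoid(y.mean())

    def recursive_forward(x, n=6, T=3, halt_max_steps=8, halt_follow_q=True):
        y = mx.zeros_like(x)  # answer embedding
        z = mx.zeros_like(x)  # latent state
        for _ in range(halt_max_steps):      # supervision steps
            for _ in range(T):               # deep steps
                for _ in range(n):           # latent steps
                    z = latent_step(x, y, z)
                y = answer_step(y, z)
            if halt_follow_q and halt_logit(y).item() > 0.5:
                break  # q-head says halt; otherwise run to halt_max_steps
        return y

    print(recursive_forward(mx.random.normal((4, 32))).shape)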

Notes

  • Hyperparams are currently hardcoded for faster experimentation.
  • Only MNIST and CIFAR-10 are supported at the moment.
  • uv handles virtual environment creation automatically.
