The Tiny Recursive Model (TRM) is a compact, iterative neural architecture designed to emulate a reasoning-style refinement process for Sudoku solving. Rather than relying on deep transformer stacks, TRM uses explicit inner–outer iterative loops to repeatedly refine latent hypotheses before committing to an output.
Formally, the model learns a function

$$ f_\theta : \{0,1\}^{810} \rightarrow \mathbb{R}^{81 \times 9} $$

mapping a flattened one-hot Sudoku encoding to categorical logits over the nine digits.
This project is intentionally lightweight and pedagogical. The goal is clarity of reasoning dynamics, not state-of-the-art performance.
- `tiny-recursive-model/tiny-recursive-model.ipynb` - main notebook (model definition, dataset, trainer, training & inference utilities).
- `data/sudoku.csv` - expected dataset (CSV with `quizzes` and `solutions` columns).
- `docs/documentation.md` - this file.
- `checkpoints/` - directory where checkpoint files are saved by default.
- `main.py` - optional script for running inference on new puzzles.
Use these names when you run or inspect the project. The notebook contains the runnable code used for training and inference.
Each Sudoku puzzle is encoded as a tensor:
- 81 cells
- 10 channels per cell
- channel 0 → empty
- channels 1–9 → digits 1–9
Resulting in a flattened input vector:

$$ x \in \{0,1\}^{810} $$
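The encoding described above can be sketched as follows (`encode_puzzle` is an illustrative name, not necessarily the function used in the notebook):

```python
import torch

def encode_puzzle(quiz: str) -> torch.Tensor:
    """One-hot encode an 81-character puzzle string (0 = empty cell).

    Returns a flattened vector of shape (810,): 81 cells x 10 channels,
    where channel 0 marks an empty cell and channels 1-9 mark digits 1-9.
    """
    onehot = torch.zeros(81, 10)
    for i, ch in enumerate(quiz):
        onehot[i, int(ch)] = 1.0
    return onehot.flatten()

# All cells empty except the last, which holds a 5
x = encode_puzzle("0" * 80 + "5")
```

Exactly one channel fires per cell, so the vector always sums to 81.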
The input is linearly projected into a latent space:

$$ h = x W_{\text{in}} $$

where $W_{\text{in}} \in \mathbb{R}^{810 \times d}$ and $d$ is the latent width (`hidden_dim`).
The model maintains two persistent internal states, initialised at the start of each forward pass:

- Latent hypothesis state $$ z \in \mathbb{R}^{1 \times d} $$
- Output refinement state $$ y \in \mathbb{R}^{1 \times d} $$
TRM performs nested recursion. The inner loop refines the latent state.

For $i = 1, \dots, L_{\text{cycles}}$:

$$ z \leftarrow z + \alpha \, \mathcal{R}_\phi(z) $$

Where:

- $\mathcal{R}_\phi$ is a stack of residual MLP blocks
- $\alpha$ is a learnable scalar latent gate
- Residual blocks are pure transforms; gating is applied externally

This loop represents hypothesis refinement.
The outer loop integrates refined hypotheses into the output.

For $j = 1, \dots, H_{\text{cycles}}$:

- Run the full inner loop to convergence, producing $z^*$
- Update the output state:

$$ y \leftarrow y + \beta \, \mathcal{O}_\psi(z^*) $$

Where:

- $\mathcal{O}_\psi$ is a small output MLP stack
- $\beta$ is a learnable scalar output gate

This loop represents committing refined hypotheses into the output belief.
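The nested recursion can be sketched as a minimal module. This is a simplified illustration, not the notebook's exact implementation: the class name, the gate initialisation near zero, and the exact form of the update equations are assumptions.

```python
import torch
import torch.nn as nn

class TinyRecursiveCore(nn.Module):
    """Illustrative sketch of TRM's inner/outer recursion over latent states."""

    def __init__(self, d: int, L_layers: int = 2, L_cycles: int = 4, H_cycles: int = 3):
        super().__init__()
        # R_phi: residual MLP stack for latent refinement (a pure transform)
        self.R_phi = nn.Sequential(
            *[nn.Sequential(nn.Linear(d, d), nn.GELU()) for _ in range(L_layers)]
        )
        # O_psi: small output MLP for committing refined hypotheses
        self.O_psi = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Linear(d, d))
        # Learnable scalar gates, initialised near zero (assumed, for stability)
        self.alpha = nn.Parameter(torch.tensor(0.01))
        self.beta = nn.Parameter(torch.tensor(0.01))
        self.L_cycles, self.H_cycles = L_cycles, H_cycles

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        z = h                        # latent hypothesis state
        y = torch.zeros_like(h)      # output refinement state
        for _ in range(self.H_cycles):          # outer integration loop
            for _ in range(self.L_cycles):      # inner refinement loop
                z = z + self.alpha * self.R_phi(z)   # gated residual update
            y = y + self.beta * self.O_psi(z)        # commit refined hypothesis
        return y
```

Note how the gates live outside the residual blocks, matching the "gating is applied externally" convention.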
After $H_{\text{cycles}}$ outer iterations, the output state is projected to logits:

$$ \hat{Y} = \operatorname{reshape}(y W_{\text{out}}) \in \mathbb{R}^{81 \times 9} $$

with $W_{\text{out}} \in \mathbb{R}^{d \times 729}$. Each row corresponds to logits over the nine possible digits.

Prediction for cell $i$:

$$ \hat{c}_i = 1 + \arg\max_{k} \hat{Y}_{i,k} $$
The dataset is a CSV with two columns:

- `quizzes`: string of length 81, digits 0–9 (0 = empty cell)
- `solutions`: string of length 81, digits 1–9

Example:

`quizzes = "003020600900305001…"`, `solutions = "483921657967345821…"`
For each cell value $c$:

- If $c = 0$, set one-hot index 0
- If $c = k$ for $k \in \{1, \dots, 9\}$, set one-hot index $k$

Targets are converted as:

$$ t_i = s_i - 1 \in \{0, \dots, 8\} $$

to satisfy `CrossEntropyLoss`, which expects class indices starting at 0.
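The target conversion is a one-liner; the sketch below uses an illustrative helper name:

```python
import torch

def encode_targets(solution: str) -> torch.Tensor:
    """Map an 81-char solution string (digits 1-9) to class indices 0-8."""
    return torch.tensor([int(ch) - 1 for ch in solution], dtype=torch.long)

# Illustrative 81-character string (a repeated row, not a valid full grid)
t = encode_targets("483921657" * 9)
```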
Key hyperparameters:

| Parameter | Meaning |
|---|---|
| `input_dim` | One-hot input size ($81 \times 10 = 810$) |
| `hidden_dim` | Latent width $d$ |
| `output_dim` | Flattened logit size ($81 \times 9 = 729$) |
| `L_layers` | Residual blocks per latent update |
| `L_cycles` | Inner refinement iterations |
| `H_cycles` | Outer integration iterations |
| `dropout` | MLP dropout |
| `lr` | Learning rate |
| `weight_decay` | AdamW regularisation |
| `batch_size` | Training batch size |
| `epochs` | Training epochs |

All parameters live in a `TRMConfig` dataclass.
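The dataclass might look like the sketch below. The field names follow the hyperparameter table; the default values are illustrative assumptions, not the notebook's actual settings:

```python
from dataclasses import dataclass

@dataclass
class TRMConfig:
    input_dim: int = 810        # 81 cells x 10 one-hot channels
    hidden_dim: int = 256       # latent width d
    output_dim: int = 729       # 81 cells x 9 digit logits
    L_layers: int = 2           # residual blocks per latent update
    L_cycles: int = 4           # inner refinement iterations
    H_cycles: int = 3           # outer integration iterations
    dropout: float = 0.1
    lr: float = 1e-3
    weight_decay: float = 1e-2
    batch_size: int = 128
    epochs: int = 20
```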
The model is trained with cell-wise categorical cross-entropy:

$$ \mathcal{L} = -\frac{1}{81} \sum_{i=1}^{81} \log p_\theta(t_i \mid x) $$

averaged over the batch.
Implementation detail:

- Output reshaped to `(batch × 81, 9)`
- Targets reshaped to `(batch × 81)`
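The reshape-then-loss step looks like this (random tensors stand in for real model output and labels):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# logits: (batch, 81, 9) model output; targets: (batch, 81) class indices 0-8
logits = torch.randn(4, 81, 9)
targets = torch.randint(0, 9, (4, 81))

# Flatten so each of the batch x 81 cells is an independent classification
loss = criterion(logits.reshape(-1, 9), targets.reshape(-1))
```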
Optimisation:

- AdamW
- Gradient clipping at $\|g\|_2 \le 1$
- Cosine annealing LR schedule
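The three optimisation choices combine like this in PyTorch; a `Linear` layer stands in for the TRM model, and `T_max=20` is an illustrative schedule length:

```python
import torch

model = torch.nn.Linear(810, 729)  # stand-in for the TRM model

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# One illustrative step: clip gradients to unit L2 norm before updating
loss = model(torch.randn(2, 810)).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
scheduler.step()  # advance the cosine schedule once per epoch (or step)
```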
Given a puzzle:

- Encode to one-hot $x \in \{0,1\}^{810}$
- Forward pass through TRM
- Reshape output to $(81, 9)$
- Apply argmax and map class indices $0..8 \rightarrow$ digits $1..9$
Note:
- The model predicts all cells
- Preserve givens manually if desired
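Preserving the givens can be sketched as a small decoding helper (`apply_prediction` is an illustrative name):

```python
import torch

def apply_prediction(quiz: str, logits: torch.Tensor) -> str:
    """Decode (81, 9) logits to digits 1-9, keeping the puzzle's givens."""
    pred = logits.argmax(dim=-1) + 1  # class indices 0-8 -> digits 1-9
    solved = [ch if ch != "0" else str(pred[i].item())
              for i, ch in enumerate(quiz)]
    return "".join(solved)
```

Only empty cells (`"0"`) are overwritten; original clues pass through untouched.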
Saved checkpoints include:

- `model_state_dict`
- `optimizer_state_dict`
- `config`
- training metrics
A minimal file:

`trm_sudoku_production.pt`

containing only:

- model weights
- config

Reload the checkpoint with `torch.load`, restore the weights, and switch the model to eval mode.
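A sketch of the save/reload round trip (the TRM class itself lives in the notebook; a `Linear` layer stands in for it here):

```python
import torch

# Save a minimal checkpoint: weights plus config
model = torch.nn.Linear(810, 729)
torch.save(
    {"model_state_dict": model.state_dict(), "config": {"hidden_dim": 256}},
    "trm_sudoku_production.pt",
)

# Reload and switch to eval mode (disables dropout for inference)
ckpt = torch.load("trm_sudoku_production.pt")
restored = torch.nn.Linear(810, 729)
restored.load_state_dict(ckpt["model_state_dict"])
restored.eval()
```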
- **Divergence**
  - Reduce $L_{\text{cycles}}$, $H_{\text{cycles}}$
  - Initialise $\alpha, \beta \approx 0$
- **Slow Training**
  - Large recursion depth implies $O(L \cdot H)$ compute
  - Reduce `hidden_dim` or batch size
- **Bad Accuracy**
  - Dataset corruption
  - Excessive gating early in training
- **OOM**
  - Latent recursion is memory-expensive
  - Gradient checkpointing could help
- Replace scalar gates $\alpha, \beta$ with vectors $\in \mathbb{R}^d$
- Introduce attention across the 81-cell dimension
- Curriculum learning by puzzle difficulty
- Constraint-aware losses (row/column/subgrid penalties)
- Quantisation and pruning for edge deployment
This project demonstrates that explicit iterative refinement can substitute depth, at least for structured reasoning tasks like Sudoku. It is intentionally minimal, interpretable, and hackable.
Note: Some of the variable names are written in camel case inside the LaTeX blocks due to limitations with the markdown parser used.