Build a small language model that uses iterative latent refinement instead of scaling depth/width. The model solves problems through multiple refinement passes, where each pass improves the latent representation of the input.
Current Task: Learn addition of integers (curriculum: 1-digit → 2-digit → 3-digit sums)
Key Innovation: Instead of computing once deeply, compute multiple times with incremental improvements, allowing smaller models to solve complex reasoning tasks.
The RR-SLM model has three main components:
- Context encoder
  - Input embedding: Token IDs → (B, T, 256) dense vectors
  - Mean pooling across the sequence to extract overall meaning
  - Linear projection + GELU activation + LayerNorm
  - Output: Context vector (B, 256)
- Iterative reasoning module
  - Reuses the same reasoning module R times (R = 5)
  - Each iteration refines the latent representation:
    - Self-Attention: Extract relevant features from the current latent
    - Feedforward: Transform and improve the representation
    - Refinement Gate: Selectively update features
    - Convergence Check: Early stop if the latent stabilizes (||Δlatent|| < ε)
  - Tracks how many steps were actually needed (1 to 5)
- Output decoder
  - Takes the refined latent (B, 256)
  - Linear projection → GELU → LayerNorm → Linear
  - Outputs logits for the next token (B, vocab_size)
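A minimal PyTorch sketch of one reasoning step and the refinement loop described above; the class and helper names (`RefinementBlock`, `refine`) are illustrative, not the exact code in `models/iterative_reasoning.py`:

```python
import torch
import torch.nn as nn

class RefinementBlock(nn.Module):
    """One refinement step: self-attention + feedforward + gated update (illustrative)."""
    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.gate = nn.Sequential(nn.Linear(2 * dim, dim), nn.Sigmoid())
        self.norm = nn.LayerNorm(dim)

    def forward(self, latent):                       # latent: (B, 256)
        x = latent.unsqueeze(1)                      # (B, 1, 256) so attention can run over it
        attn_out, _ = self.attn(x, x, x)             # extract features from the current latent
        candidate = self.norm(latent + self.ff(attn_out.squeeze(1)))
        g = self.gate(torch.cat([latent, candidate], dim=-1))
        return g * candidate + (1 - g) * latent      # refinement gate: selective update

def refine(block, latent, max_steps=5, eps=1e-3):
    """Apply the same block up to R times; stop early once ||Δlatent|| < ε."""
    steps = 0
    for _ in range(max_steps):
        new_latent = block(latent)
        steps += 1
        delta = (new_latent - latent).norm(dim=-1).max()
        latent = new_latent
        if delta < eps:                              # convergence check → early stop
            break
    return latent, steps
```

Returning `steps` alongside the latent is what lets the model report how many refinement iterations each problem actually needed.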
Autoregressive Generation: For each position in the sequence:
- Encode context up to that position
- Refine latent state through R iterations
- Decode to position-specific logits
- Predict next token
This ensures each position has context-dependent predictions, not repeated tokens.
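In code, the generation loop looks roughly like the greedy decoder below. Two assumptions of this sketch: `model(context)` (or a helper) returns the next-token logits for the last position with shape (1, vocab_size), and `eos_id` is whatever token marks the end of an answer (e.g. the pad token, since the vocabulary has no dedicated end marker).

```python
import torch

@torch.no_grad()
def generate(model, prompt_ids, eos_id, max_new_tokens=8):
    """Greedy autoregressive decoding: re-encode, refine, and decode at each step."""
    tokens = list(prompt_ids)                           # e.g. the ids for "12+34="
    for _ in range(max_new_tokens):
        context = torch.tensor(tokens).unsqueeze(0)     # (1, T) current context
        logits = model(context)                         # (1, vocab_size) next-token logits
        next_id = logits.argmax(dim=-1).item()          # greedy choice
        tokens.append(next_id)
        if next_id == eos_id:                           # stop once the answer is complete
            break
    return tokens
```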
Problem: Model was generating repeated tokens (e.g., "22222" instead of "2")
Root Cause: A non-autoregressive forward pass. The model was:
- Encoding the entire sequence once
- Generating a single set of prediction logits
- Expanding the same logits to all positions
- As a result, it learned to output the same high-probability token whenever it saw '='
Fix: Rewrote forward() to process positions autoregressively:
- Each position gets its own context encoding
- Latent representation varies by position
- Logits differ per position
- Training loss now propagates correctly through position-specific predictions
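A sketch of what the fixed training-time forward pass computes; `encode`, `refine`, and `decode` stand in for the encoder, reasoning module, and decoder above, and the real `forward()` in `models/rrslm.py` may be organized differently:

```python
import torch

def forward_autoregressive(encode, refine, decode, tokens):
    """For each target position t, encode only tokens[:t], refine, and decode its logits.

    tokens: (B, T) token ids. Returns (B, T-1, vocab_size) position-dependent logits.
    """
    B, T = tokens.shape
    per_position_logits = []
    for t in range(1, T):                          # predict token t from tokens[:t]
        context = tokens[:, :t]                    # position-specific context, not the full sequence
        latent = encode(context)                   # (B, 256) context vector
        latent, _steps = refine(latent)            # iterative refinement (early stop allowed)
        per_position_logits.append(decode(latent)) # (B, vocab_size)
    return torch.stack(per_position_logits, dim=1) # (B, T-1, vocab_size)
```

Cross-entropy is then taken between these logits and `tokens[:, 1:]`, so the gradient at each position reflects that position's own context rather than one shared prediction.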
- Training is slow: O(T) forward passes per batch instead of O(1)
  - 30-token sequences = 30× more encoder operations
  - Fine for current small test sizes (4-10 tokens)
  - Will need optimization (KV-cache) for scaling
- No beam search: Current generation uses greedy sampling
  - Works for simple addition
  - Could improve accuracy with beam search later (see the sketch below)
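If beam search is added later, it can reuse the same hypothetical next-token interface as the greedy loop above. A small illustrative version:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def beam_search(model, prompt_ids, eos_id, beam_width=3, max_new_tokens=8):
    """Keep the beam_width highest log-probability continuations at each step."""
    beams = [(list(prompt_ids), 0.0)]                  # (token ids, cumulative log-prob)
    for _ in range(max_new_tokens):
        candidates = []
        for tokens, score in beams:
            if tokens[-1] == eos_id:                   # finished beams carry over unchanged
                candidates.append((tokens, score))
                continue
            logits = model(torch.tensor(tokens).unsqueeze(0))   # (1, vocab_size)
            log_probs = F.log_softmax(logits, dim=-1).squeeze(0)
            top = torch.topk(log_probs, beam_width)
            for lp, idx in zip(top.values.tolist(), top.indices.tolist()):
                candidates.append((tokens + [idx], score + lp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(t[-1] == eos_id for t, _ in beams):
            break
    return beams[0][0]                                 # best-scoring sequence
```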
- Fix autoregressive forward pass
- Retrain and verify accuracy reaches 80%+ on 1-digit problems
- Confirm no repeated token generation
- Analyze refinement efficiency (avg steps per problem)
- Expand training to 2-digit and 3-digit addition
- Verify iterative refinement actually helps with harder problems
- Track how refinement depth correlates with problem difficulty
- Implement KV-cache-style incremental context encoding (see the sketch after this list)
- Reduce the O(T) per-token re-encoding of the autoregressive loop to O(1) with the cache
- Scale model from ~1.5M to 800M parameters
- Test on more complex reasoning tasks
- Measure parameter efficiency vs standard transformers
- Compare refinement iterations used vs problem complexity
- Show how iterative refinement enables smaller model capabilities
- Document best practices for recursive reasoning architectures
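On the caching item above: because the encoder mean-pools token embeddings, the natural analog of a KV-cache here is a cached running sum of embeddings that is updated in O(1) whenever a token is appended. A sketch under that assumption (class and method names are illustrative, not the repo's code):

```python
import torch
import torch.nn as nn

class IncrementalContextEncoder(nn.Module):
    """Mean-pooling encoder with a cached running sum, so each new position costs O(1)."""
    def __init__(self, vocab_size=13, dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, dim)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(dim)

    def start(self, batch_size=1):
        # cache = (running sum of embeddings, number of tokens seen so far)
        return self.embed.weight.new_zeros(batch_size, self.embed.embedding_dim), 0

    def step(self, cache, token_ids):
        """token_ids: (B,) the newly appended token for each sequence in the batch."""
        total, count = cache
        total = total + self.embed(token_ids)   # O(1) update instead of re-embedding T tokens
        count += 1
        pooled = total / count                  # equals mean pooling over the full prefix
        context = self.norm(self.act(self.proj(pooled)))
        return (total, count), context          # (B, 256) context vector
```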
| Setting | Value |
|---|---|
| Embedding Dim | 256 |
| Latent Dim | 256 |
| Max Refinement Steps | 5 |
| Vocab Size | 13 (digits 0-9 + special: =, +, pad) |
| Total Parameters | ~1.5M (test size, scalable to 800M) |
| Training Task | Addition with curriculum learning |
| Hardware | CPU/GPU (tested on CPU) |
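For reference, a vocabulary consistent with the table (13 tokens: digits 0-9 plus '+', '=', and pad) could be defined as below; the actual `configs/vocab.py` may order or name things differently.

```python
# Illustrative character-level vocabulary for addition strings like "12+34=46".
PAD, PLUS, EQUALS = "<pad>", "+", "="
TOKENS = [PAD, PLUS, EQUALS] + [str(d) for d in range(10)]   # 13 tokens total

TOKEN_TO_ID = {tok: i for i, tok in enumerate(TOKENS)}
ID_TO_TOKEN = {i: tok for tok, i in TOKEN_TO_ID.items()}

def text_to_ids(text: str) -> list[int]:
    """Encode an addition string character by character."""
    return [TOKEN_TO_ID[ch] for ch in text]

def ids_to_text(ids: list[int]) -> str:
    """Decode ids back to text, dropping padding."""
    return "".join(ID_TO_TOKEN[i] for i in ids if ID_TO_TOKEN[i] != PAD)
```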
```bash
# Install dependencies
pip install -r requirements.txt

# Train the model (from scratch)
python -m training.train_addition

# Test inference
python test_model.py

# Analyze performance
python training/evaluate_addition.py
python training/parameter_analysis.py
```

```
models/
├── rrslm.py                 # Main model class
├── iterative_reasoning.py   # Refinement engine
└── backbone.py              # Supporting modules
training/
├── train_addition.py        # Training loop
└── evaluate_addition.py     # Evaluation & analysis
data/
├── addition_dataset.py      # Dataset class
├── addition.txt             # Curriculum data
└── generate_addition.py     # Data generation
configs/
└── vocab.py                 # Vocabulary definitions
```
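The curriculum data in `data/addition.txt` can be produced by a generator along these lines; this is a sketch, and the real `data/generate_addition.py` may differ in output format and stage sizes.

```python
import random

def generate_addition_examples(n_per_stage=1000, max_digits=3, seed=0):
    """Curriculum data: 1-digit, then 2-digit, then 3-digit addition problems."""
    rng = random.Random(seed)
    lines = []
    for digits in range(1, max_digits + 1):          # curriculum: easy → hard
        lo, hi = 0, 10 ** digits - 1
        for _ in range(n_per_stage):
            a, b = rng.randint(lo, hi), rng.randint(lo, hi)
            lines.append(f"{a}+{b}={a + b}")         # e.g. "7+5=12" at stage 1
    return lines

if __name__ == "__main__":
    with open("data/addition.txt", "w") as f:
        f.write("\n".join(generate_addition_examples()))
```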
- CURRENT_STATUS.md - Latest fixes and status
- AUTOREGRESSIVE_FIX_COMPLETE.md - Technical deep-dive on forward pass
- ARCHITECTURE_DIAGRAMS.md - Visual architecture explanations
- IMPLEMENTATION_STATUS.md - Complete implementation details
- ERROR_ANALYSIS.md - Error fixes and solutions
- Run training: `python -m training.train_addition`
- Verify output: `python test_model.py` (should see different digits, no repetition)
- Check accuracy: Should reach 80%+ on 1-digit addition
- Analyze refinement: How many steps does each problem need?
- Suggest optimizations: How to make it faster while maintaining quality?