
Rr-slm: Recursive-Refinement Small Language Model

🎯 Mission

Build a small language model that uses iterative latent refinement instead of scaling depth/width. The model solves problems through multiple refinement passes, where each pass improves the latent representation of the input.

Current Task: Learn addition of integers (curriculum: 1-digit → 2-digit → 3-digit sums)

Key Innovation: Instead of computing once deeply, compute multiple times with incremental improvements, allowing smaller models to solve complex reasoning tasks.


🏗️ Current Architecture

The RR-SLM model has three main components:

1. Encoder

  • Input embedding: Token IDs → (B, T, 256) dense vectors
  • Mean pooling across sequence to extract overall meaning
  • Linear projection + GELU activation + LayerNorm
  • Output: Context vector (B, 256)
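A minimal PyTorch sketch of the encoder described above. Module and layer names here are illustrative, not the actual implementation in models/rrslm.py:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Token IDs -> pooled context vector, per the steps above (a sketch)."""
    def __init__(self, vocab_size=13, dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, token_ids):          # token_ids: (B, T)
        x = self.embed(token_ids)          # (B, T, 256) dense vectors
        x = x.mean(dim=1)                  # mean pooling over the sequence -> (B, 256)
        return self.norm(F.gelu(self.proj(x)))  # projection + GELU + LayerNorm

enc = Encoder()
ctx = enc(torch.randint(0, 13, (4, 10)))
print(ctx.shape)  # torch.Size([4, 256])
```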

2. Iterative Refinement Engine

  • Reuses the same reasoning module R times (R=5)
  • Each iteration refines the latent representation:
    • Self-Attention: Extract relevant features from current latent
    • Feedforward: Transform and improve the representation
    • Refinement Gate: Selectively update features
    • Convergence Check: Early stop if latent stabilizes (||Δlatent|| < ε)
  • Tracks how many steps were actually needed (1 to 5)
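The iteration above can be sketched as follows; the attention/gate wiring and hyperparameters (4 heads, ε = 1e-3) are assumptions for illustration, not the repository's exact code:

```python
import torch
import torch.nn as nn

class RefinementEngine(nn.Module):
    """One reusable reasoning module applied up to R times:
    self-attention -> feedforward -> gated update -> convergence check."""
    def __init__(self, dim=256, max_steps=5, eps=1e-3):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.gate = nn.Sequential(nn.Linear(2 * dim, dim), nn.Sigmoid())
        self.max_steps, self.eps = max_steps, eps

    def forward(self, latent):                 # latent: (B, dim)
        steps = 0
        for _ in range(self.max_steps):        # same weights reused each pass
            steps += 1
            q = latent.unsqueeze(1)            # treat latent as a length-1 sequence
            attended, _ = self.attn(q, q, q)
            update = self.ffn(attended.squeeze(1))
            g = self.gate(torch.cat([latent, update], dim=-1))
            new_latent = g * update + (1 - g) * latent   # refinement gate
            # early stop when the latent stabilizes: ||delta latent|| < eps
            converged = (new_latent - latent).norm(dim=-1).max() < self.eps
            latent = new_latent
            if converged:
                break
        return latent, steps                   # steps in 1..max_steps
```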

3. Decoder

  • Takes refined latent (B, 256)
  • Linear projection → GELU → LayerNorm → Linear
  • Outputs logits for next token (B, vocab_size)
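A plausible decoder head matching that layer order (again a sketch, not the file's actual code):

```python
import torch
import torch.nn as nn

class Decoder(nn.Module):
    """Refined latent -> next-token logits: Linear -> GELU -> LayerNorm -> Linear."""
    def __init__(self, dim=256, vocab_size=13):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.LayerNorm(dim),
            nn.Linear(dim, vocab_size),
        )

    def forward(self, latent):             # (B, dim) -> (B, vocab_size)
        return self.net(latent)
```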

Autoregressive Generation: For each position in the sequence:

  1. Encode context up to that position
  2. Refine latent state through R iterations
  3. Decode to position-specific logits
  4. Predict next token

This ensures each position has context-dependent predictions, not repeated tokens.
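The four-step loop above, written as a generic greedy decoder. The `encode`/`refine`/`decode` callables stand in for the model's real methods (whose names in models/rrslm.py may differ):

```python
import torch

def greedy_generate(encode, refine, decode, prompt_ids, max_new_tokens, eos_id=None):
    """Per position: encode context so far -> refine latent R times ->
    decode logits -> pick the argmax token (greedy sampling)."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        ctx = torch.tensor([ids])            # (1, T): context up to this position
        latent = encode(ctx)                 # (1, 256) context vector
        latent, _steps = refine(latent)      # iterative latent refinement
        logits = decode(latent)              # (1, vocab_size)
        next_id = int(logits.argmax(dim=-1))
        ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break
    return ids
```

Because the context tensor grows by one token per step, each prediction is conditioned on everything before it, which is what prevents the repeated-token failure described below.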


⚠️ Current Failures

Recent Issue (Fixed)

Problem: Model was generating repeated tokens (e.g., "22222" instead of "2")

Root Cause: Non-autoregressive forward pass - the model was:

  • Encoding entire sequence once
  • Generating single prediction logits
  • Expanding same logits to all positions
  • Model learned to output same high-probability token when seeing '='

Fix: Rewrote forward() to process positions autoregressively:

  • Each position gets its own context encoding
  • Latent representation varies by position
  • Logits differ per position
  • Training loss now correctly propagates

Performance Limitations

  • Training is slow: O(T) encoder forward passes per sequence instead of O(1)

    • 30-token sequences = 30× more encoder operations
    • Fine for current small test sizes (4-10 tokens)
    • Will need optimization (KV-cache) for scaling
  • No beam search: Current generation uses greedy sampling

    • Works for simple addition
    • Could improve accuracy with beam search later
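One low-effort mitigation worth noting (a sketch of an idea in the spirit of the planned KV-cache, not something the repo implements): because the encoder mean-pools token embeddings, the pooled context can be updated incrementally as a running mean rather than re-encoded from scratch each position:

```python
import torch

class RunningMeanCache:
    """O(1)-per-token update of a mean-pooled context vector.
    mean(T+1 tokens) = (sum of T tokens + new token) / (T + 1)."""
    def __init__(self, dim=256):
        self.total = torch.zeros(dim)
        self.count = 0

    def append(self, token_embedding):     # token_embedding: (dim,)
        self.total = self.total + token_embedding
        self.count += 1
        return self.total / self.count     # pooled context over all tokens so far
```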

🚀 Future Directions

Phase 1: Validation (Current)

  • Fix autoregressive forward pass
  • Retrain and verify accuracy reaches 80%+ on 1-digit problems
  • Confirm no repeated token generation
  • Analyze refinement efficiency (avg steps per problem)

Phase 2: Curriculum Learning

  • Expand training to 2-digit and 3-digit addition
  • Verify iterative refinement actually helps with harder problems
  • Track how refinement depth correlates with problem difficulty

Phase 3: Efficiency & Scaling

  • Implement KV-cache for incremental context encoding
  • Reduce O(T) autoregressive bottleneck to O(1) with cache
  • Scale model from ~1.5M to 800M parameters
  • Test on more complex reasoning tasks

Phase 4: Analysis & Publication

  • Measure parameter efficiency vs standard transformers
  • Compare refinement iterations used vs problem complexity
  • Show how iterative refinement enables smaller model capabilities
  • Document best practices for recursive reasoning architectures

📊 Model Specifications

Component              Details
Embedding dim          256
Latent dim             256
Max refinement steps   5
Vocab size             13 (digits 0-9 + special: =, +, pad)
Total parameters       ~1.5M (test size, scalable to 800M)
Training task          Addition with curriculum learning
Hardware               CPU/GPU (tested on CPU)
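A 13-token vocabulary consistent with the table could look like this; the actual ID assignments live in configs/vocab.py and may differ:

```python
# Illustrative mapping: digits map to their own value, specials follow.
VOCAB = {str(d): d for d in range(10)}
VOCAB.update({"+": 10, "=": 11, "<pad>": 12})

def encode_expr(expr):
    """Tokenize an addition prompt like '3+4=' into IDs."""
    return [VOCAB[ch] for ch in expr]

print(encode_expr("3+4="))  # [3, 10, 4, 11]
```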

🔧 Quick Start

# Install dependencies
pip install -r requirements.txt

# Train the model (from scratch)
python -m training.train_addition

# Test inference
python test_model.py

# Analyze performance
python training/evaluate_addition.py
python training/parameter_analysis.py

📁 Repository Structure

models/
  ├── rrslm.py                 # Main model class
  ├── iterative_reasoning.py  # Refinement engine
  └── backbone.py             # Supporting modules

training/
  ├── train_addition.py       # Training loop
  └── evaluate_addition.py    # Evaluation & analysis

data/
  ├── addition_dataset.py     # Dataset class
  ├── addition.txt            # Curriculum data
  └── generate_addition.py    # Data generation

configs/
  └── vocab.py                # Vocabulary definitions


🤝 Next Steps for Contributors

  1. Run training: python -m training.train_addition
  2. Verify output: python test_model.py (should see different digits, no repetition)
  3. Check accuracy: Should reach 80%+ on 1-digit addition
  4. Analyze refinement: How many steps does each problem need?
  5. Suggest optimizations: How to make it faster while maintaining quality?
