Changes from all commits (239 commits)
2ed2fcb
Update README.md
tekaratzas Sep 14, 2025
7efecef
Update README.md
tekaratzas Sep 14, 2025
685467e
Merge branch 'main' of github.com:tekaratzas/RustGPT
tekaratzas Sep 14, 2025
74db83f
Added MIT License
tekaratzas Sep 15, 2025
34ecc54
fix(readme): correct repo URL and directory path in Quick Start
hissamshar Sep 15, 2025
29e5ef5
Merge pull request #1 from hissamshar/patch-1
tekaratzas Sep 15, 2025
869d60e
isolate data loading
anshumanpatil Sep 15, 2025
710d086
data loading from json
anshumanpatil Sep 16, 2025
7e876e3
data loading from csv
anshumanpatil Sep 16, 2025
4a506b3
csv files added
anshumanpatil Sep 16, 2025
179950a
Added what this isn't section in readme
tekaratzas Sep 17, 2025
403e642
Merge branch 'main' of github.com:tekaratzas/RustGPT
tekaratzas Sep 17, 2025
75bdb67
Fix spelling mistake
tekaratzas Sep 17, 2025
813a011
code format
anshumanpatil Sep 17, 2025
efa2b04
Added verbose printing of vocab to confirm correct data loading (#1)
hobson Sep 17, 2025
babb0e5
Merge master and PR
anshumanpatil Sep 17, 2025
830ae33
refactoring
anshumanpatil Sep 18, 2025
7c90d1c
refactoring
anshumanpatil Sep 18, 2025
dac242f
Run cargo-fmt
Sep 18, 2025
43fced7
CI to check and run tests
Sep 19, 2025
1d4b973
Merge pull request #7 from mrityunjai01/cargo-fmt
tekaratzas Sep 21, 2025
4e2df4f
fmt conflicts solved
anshumanpatil Sep 21, 2025
362bde4
fmt conflicts solved
anshumanpatil Sep 21, 2025
6e9b67f
logs removed
anshumanpatil Sep 21, 2025
c6c0041
fix: readme badge link
Theo- Sep 21, 2025
d0d68b3
Merge pull request #10 from Theo-/main
tekaratzas Sep 21, 2025
e04156e
cargo fmt
anshumanpatil Sep 22, 2025
1540b5a
remove HF dataset
anshumanpatil Sep 22, 2025
86c528f
chore: housekeeping
ben1009 Sep 22, 2025
64b85a8
Merge branch 'main' into housekeeping
ben1009 Sep 22, 2025
11cf580
chore: housekeeping
ben1009 Sep 22, 2025
96bc6df
Use library constants in binary
Theo- Sep 22, 2025
cca7c4d
fix: use main.rs values
Theo- Sep 22, 2025
e12da67
fix: readme values
Theo- Sep 22, 2025
a974d7d
fix: readme values
Theo- Sep 22, 2025
1e5e10b
fix: wrong constant used
Theo- Sep 22, 2025
a4183dc
Merge pull request #12 from Theo-/codex/refactor-constants-in-src/mai…
tekaratzas Sep 22, 2025
43b70a5
add more gha
ben1009 Sep 23, 2025
838086d
Merge branch 'tekaratzas:main' into housekeeping
ben1009 Sep 23, 2025
9f86c82
adjust cov to 55%
ben1009 Sep 23, 2025
32f300b
add nightly fmt
ben1009 Sep 23, 2025
d3e7ed8
merge master
anshumanpatil Sep 23, 2025
4c84326
merge master
anshumanpatil Sep 23, 2025
27b6cae
PR comments addressed
anshumanpatil Sep 24, 2025
97ba6c9
Merge pull request #2 from anshumanpatil/feature/isolate-data-loading
tekaratzas Sep 25, 2025
4ede81e
New Feature: Count Total Parameters
Sep 25, 2025
684f376
Adjustments to constant variables
Sep 26, 2025
4daf898
Merge branch 'main' into housekeeping
ben1009 Sep 26, 2025
fcc6a4a
chore: fix lints
ben1009 Sep 26, 2025
7ce4b40
Calculate Parameter at Layer Level
Sep 26, 2025
a71083e
Adjust import
Sep 26, 2025
d795cef
Fix build
Sep 29, 2025
370af0c
Run a Cargo fmt
Sep 29, 2025
537f2e9
Merge pull request #14 from arikaufman/feature/count-total-parameters
tekaratzas Sep 29, 2025
4668f6d
Merge branch 'main' into housekeeping
ben1009 Sep 30, 2025
2ec33c7
chore: fix lints
ben1009 Sep 30, 2025
d23da20
Refactor: Eliminate repeated vocabulary processing code by moving to …
Oct 2, 2025
cc45a4e
fix formatting
Oct 3, 2025
1159247
Merge pull request #17 from tomschelsen/dry
tekaratzas Oct 3, 2025
4f407f4
comment out the cov gha
ben1009 Oct 4, 2025
0861377
update
ben1009 Oct 5, 2025
51e79eb
Merge branch 'main' into housekeeping
ben1009 Oct 5, 2025
cf1f7d3
Merge pull request #11 from ben1009/housekeeping
tekaratzas Oct 5, 2025
e977fe5
chore: fix readme workflow badge
ben1009 Oct 9, 2025
772c469
Merge pull request #19 from ben1009/readme
tekaratzas Oct 10, 2025
98aeb4e
Sprint 3.1 Complete: Documentation Foundation + Batch Training + Trac…
ryancinsight Oct 14, 2025
ed58706
Sprint 3.1 Complete: Documentation + Batch Training + Tracing Integra…
ryancinsight Oct 14, 2025
f8f438d
Sprint 3.1: Documentation + Batch Training + Tracing
ryancinsight Oct 14, 2025
69aa104
Add HyperMixer architecture with modular design and comprehensive doc…
ryancinsight Oct 15, 2025
76e480b
Add -i flag to enable interactive prompt after training
ryancinsight Oct 15, 2025
b1423c8
Optimize HyperMixer architecture for performance and training stability
ryancinsight Oct 15, 2025
4b650a1
Fix critical caching bug in HyperMixer token mixing
ryancinsight Oct 15, 2025
957dfd2
Simplify HyperMixer by removing pooling mechanism
ryancinsight Oct 15, 2025
cd8172a
Major fix: Implement proper token mixing with information flow
ryancinsight Oct 15, 2025
d38d0b8
Improve HyperMixer attention mechanism - now correctly identifies 'mo…
ryancinsight Oct 15, 2025
41f6872
Add comprehensive tests for RMSNorm, RotaryEmbedding, Sliding Window …
ryancinsight Oct 16, 2025
9342d12
Implement Contextual Position Encoding (CoPE) in RustGPT
ryancinsight Oct 17, 2025
b752415
Add Mixture-of-Heads (MoH) implementation and documentation
ryancinsight Oct 17, 2025
10167c9
feat: Implement systematic error handling in Layer trait and related …
ryancinsight Oct 17, 2025
552f5be
feat: Enhance gradient stability and loss optimization with architect…
ryancinsight Oct 18, 2025
e45430f
Add Mixture of Experts (MoE) Layer with Adaptive Routing and Expert N…
ryancinsight Oct 18, 2025
69637c1
Implement Tiny Recursive Model (TRM) with adaptive residual scaling a…
ryancinsight Oct 19, 2025
f4eb9bb
Implement soft (differentiable) routing for Fully Adaptive MoH
ryancinsight Oct 19, 2025
a1da167
Implement learned per-token temperature with proper gating derivative
ryancinsight Oct 19, 2025
524ee0d
Implement learned per-token temperature with proper gating derivative…
ryancinsight Oct 19, 2025
7446f7f
Integrate Fully Adaptive MoH with learned temperature into TRM
ryancinsight Oct 19, 2025
81f8420
Add adaptive recursive depth infrastructure for TRM (Phase 2 partial)
ryancinsight Oct 20, 2025
6e25645
Complete adaptive recursive depth implementation for TRM (Phase 2)
ryancinsight Oct 20, 2025
f8482f5
Fix adaptive depth implementation bugs and run validation experiments
ryancinsight Oct 20, 2025
c18b62c
Add comprehensive adaptive recursive depth validation results
ryancinsight Oct 20, 2025
f440076
Improve adaptive depth with better initialization and complexity-awar…
ryancinsight Oct 20, 2025
29f9c26
Improve logging: Move LARS layer details to debug level
ryancinsight Oct 20, 2025
16aab77
Fix complexity statistics to show proper range [min-max]
ryancinsight Oct 20, 2025
198bd42
Remove DynW from logs when using Fully Adaptive MoH
ryancinsight Oct 20, 2025
75af5a0
Fix adaptive depth learning: improve initialization and remove LR sca…
ryancinsight Oct 20, 2025
d19d3f6
Implement confidence-based halting for adaptive depth
ryancinsight Oct 20, 2025
43a8e95
Improve adaptive depth: stronger halting init + higher ponder weight
ryancinsight Oct 20, 2025
99d5130
Revert to fixed depth=5 baseline for validation
ryancinsight Oct 20, 2025
c7e0199
Replace pooling with Gumbel-Softmax attention in TRM halting predictor
ryancinsight Oct 20, 2025
e70e746
refactor: optimize normalization and attention implementations
ryancinsight Oct 21, 2025
ab3e57e
refactor: remove HRM and HyperMixer architectures and related components
ryancinsight Oct 21, 2025
f591e8b
refactor: modernize architecture with DynamicTanhNorm and PolyAttention
ryancinsight Oct 23, 2025
1fbf936
refactor(architecture): remove TRM and integrate CoPE into PolyAttention
ryancinsight Oct 23, 2025
9b71eb5
feat: add sigmoid_poly module and optimize performance
ryancinsight Oct 23, 2025
312acf2
refactor(poly_attention): optimize attention computation with row-str…
ryancinsight Oct 23, 2025
6139a1c
refactor: consolidate imports and improve code formatting
ryancinsight Oct 23, 2025
7e5c32c
perf(poly_attention): optimize windowed attention with parallel proce…
ryancinsight Oct 23, 2025
7dd7272
refactor(model): replace sigmoid_poly with richards activation
ryancinsight Oct 25, 2025
fd7718c
feat(model): add model saving and inference capabilities
ryancinsight Oct 25, 2025
8e2c084
feat: restore advanced modules while reverting to model saving baseline
ryancinsight Nov 4, 2025
dd23346
refactor(data): remove specific Q&A entries from chat training data
ryancinsight Nov 4, 2025
e3082f9
Add model files to .gitignore and remove from tracking
ryancinsight Nov 4, 2025
32502fa
feat: integrate AutoDeco decoder with updated config and logging
ryancinsight Nov 4, 2025
8c3ccfb
perf(poly_attention): optimize forward method with iterator-based pro…
ryancinsight Nov 5, 2025
08d8daf
refactor(richards): streamline RichardsCurve implementation and enhan…
ryancinsight Nov 5, 2025
a03b2a0
feat(layers): add weight_norm method to all layer types
ryancinsight Nov 5, 2025
937450f
refactor(richards): integrate PadeExp for exponential calculations in…
ryancinsight Nov 6, 2025
d71453f
feat(mixtures): integrate Richards curves for adaptive head selection…
ryancinsight Nov 6, 2025
fbf3888
refactor: replace SwiGLU with RichardsGlu in testing and model layers
ryancinsight Nov 6, 2025
bb50c67
feat: integrate Mixture of Experts into LLM architecture
ryancinsight Nov 6, 2025
27dfd51
refactor: share gating logic between MoH and MoE
ryancinsight Nov 7, 2025
564237f
refactor: enhance routing logic and introduce HeadRouter for Mixture-…
ryancinsight Nov 7, 2025
5810223
feat(llm): enhance RichardsGlu training status tracking and logging
ryancinsight Nov 7, 2025
bb23de6
refactor(attention): reorganize PolyAttention module structure and up…
ryancinsight Nov 7, 2025
72b012c
feat(moh): implement SoftTopP head selection strategy for improved ro…
ryancinsight Nov 8, 2025
fc128a1
feat(attention): improve numerical stability in polynomial attention …
ryancinsight Nov 9, 2025
107103a
feat(llm): introduce Tiny Recursive Model (TRM) architecture
ryancinsight Nov 9, 2025
9469809
feat(llm): implement TRM training pipeline with autoencoding and chat…
ryancinsight Nov 10, 2025
eaa91a3
refactor(trm): improve gradient computation and handling in TRM and T…
ryancinsight Nov 10, 2025
1a4cbdd
Consolidate diffusion-TRM stability and telemetry
ryancinsight Nov 15, 2025
73e7223
Add training logs for diffusion models with detailed metrics
ryancinsight Nov 15, 2025
2c493b1
Enhance diffusion model configuration and logging
ryancinsight Nov 15, 2025
a08ce6e
feat(transformer): optimize attention and block performance
ryancinsight Nov 17, 2025
b8e41b9
Refactor RichardsGlu and RichardsNorm for improved readability and pe…
ryancinsight Nov 17, 2025
e85bf79
Add training logs for diffusion model runs with detailed metrics
ryancinsight Nov 18, 2025
112e9ba
fix: correct indexing bug and add debug logging in PolyAttention; ren…
ryancinsight Nov 19, 2025
c08d86f
Add training logs for diffusion model runs on 2025-11-20
ryancinsight Nov 20, 2025
9f9a6a4
Refactor layer handling and add zero_gradients method for improved gr…
ryancinsight Nov 22, 2025
78c7262
Refactor attention and model configuration to use RichardsGate
ryancinsight Nov 23, 2025
1fb9c45
Refactor RichardsCurve for improved performance and clarity
ryancinsight Nov 24, 2025
9848504
Enhance speculative sampling configuration and functionality
ryancinsight Nov 24, 2025
9dd5204
Implement deterministic random number generation module with global s…
ryancinsight Nov 25, 2025
f92c2ee
Refactor code to remove unused variables and improve clarity across m…
ryancinsight Nov 30, 2025
001448b
feat: Optimize Richards module performance and memory usage
ryancinsight Dec 5, 2025
26d162a
feat: Implement position-aware residual scaling based on Theorem 4 wi…
ryancinsight Dec 6, 2025
6f39936
feat: Implement advanced adaptive residuals for diffusion models
ryancinsight Dec 7, 2025
77c3d33
Implement unified adaptive residuals for transformer and diffusion mo…
ryancinsight Dec 10, 2025
2538b15
Refactor Richards curve implementation for improved performance and m…
ryancinsight Dec 12, 2025
e6cd6f6
Refactor Richards Gate and GLU implementations for improved stability…
ryancinsight Dec 13, 2025
cc7a736
Refactor attention mechanisms for numerical stability and performance…
ryancinsight Dec 13, 2025
d75d581
Refactor adaptive residuals implementation for improved clarity and p…
ryancinsight Dec 13, 2025
af64c01
Refactor code structure for improved readability and maintainability
ryancinsight Dec 14, 2025
9de659e
Refactor Mixture-of-Experts and routing logic for improved performanc…
ryancinsight Dec 14, 2025
709319d
Enhance Mixture-of-Experts and attention mechanisms with head activit…
ryancinsight Dec 14, 2025
16a591c
Refactor attention mechanisms to improve performance and reduce memor…
ryancinsight Dec 14, 2025
fcf22a9
Add Real-Gated Linear Recurrent Unit (RG-LRU) and Multi-head RG-LRU l…
ryancinsight Dec 19, 2025
3445cbf
Add transformer layers and speculative sampling implementation
ryancinsight Dec 20, 2025
e298354
Refactor Mamba layer to use Pade approximation for exponential calcul…
ryancinsight Dec 21, 2025
cd0f3c5
Refactor model architecture to support Autoregressive configuration
ryancinsight Dec 21, 2025
fbecae7
feat(transformer): add modular components for attention and feedforwa…
ryancinsight Jan 3, 2026
75a4581
Refactor RichardsCurve and related components for f32 support
ryancinsight Jan 4, 2026
71e9ee3
Add comprehensive documentation for Mamba and RG-LRU architectures; i…
ryancinsight Jan 5, 2026
dacebc5
Implement Pade approximation for exponential function with comprehens…
ryancinsight Jan 6, 2026
06f7994
chore: consolidate persistence implementation into LLM with versioned…
ryancinsight Jan 6, 2026
f11f6e7
feat: Enhance ModelConfig with new residual training parameters
ryancinsight Jan 8, 2026
4a3efdd
Refactor and enhance various components
ryancinsight Jan 12, 2026
7e2f314
feat(model): add Titan memory mechanism and checkpointing support
ryancinsight Jan 13, 2026
6e6a1e2
feat(ssm): add MoH variants for Mamba and Mamba2 layers
ryancinsight Jan 14, 2026
817c9fa
refactor: optimize performance and memory usage across multiple compo…
ryancinsight Jan 15, 2026
9cd0e4b
perf: Reuse layer_inputs vector to reduce allocator pressure
google-labs-jules[bot] Jan 15, 2026
6b21990
Merge pull request #1 from ryancinsight/perf-reuse-layer-inputs-vec-1…
ryancinsight Jan 15, 2026
528d102
Hoist RichardsCurve construction out of closure in llm.rs
google-labs-jules[bot] Jan 15, 2026
ac798f8
Optimize CSV string concatenation in dataset_loader
google-labs-jules[bot] Jan 15, 2026
da65961
Remove redundant tensor cloning in inference loop
google-labs-jules[bot] Jan 15, 2026
c755461
Decouple MoH training from attention gradients
google-labs-jules[bot] Jan 15, 2026
f2023c9
Implement hierarchical forward pass in AdaptiveSoftmax
google-labs-jules[bot] Jan 15, 2026
4d54145
Merge pull request #2 from ryancinsight/optimize-richards-closure-100…
ryancinsight Jan 15, 2026
837c3de
Optimize dataset loading to use streaming parsing
google-labs-jules[bot] Jan 15, 2026
c3a5f2c
Implement Adaptive Softmax forward pass and strategy handling
google-labs-jules[bot] Jan 15, 2026
229166e
Merge pull request #3 from ryancinsight/perf/optimize-csv-concat-9122…
ryancinsight Jan 15, 2026
902e6c4
Merge pull request #5 from ryancinsight/perf/reduce-clones-llm-668193…
ryancinsight Jan 15, 2026
c7825ad
Optimize backprop loop to avoid input cloning
google-labs-jules[bot] Jan 15, 2026
45cc0ce
Merge pull request #6 from ryancinsight/moh-training-decoupling-82370…
ryancinsight Jan 15, 2026
eca120b
Merge pull request #7 from ryancinsight/hierarchical-softmax-impl-329…
ryancinsight Jan 15, 2026
7e9c949
Merge pull request #8 from ryancinsight/dataset-loader-perf-165367342…
ryancinsight Jan 15, 2026
31b198f
Merge branch 'main' into adaptive-softmax-impl-4350403344050072083
ryancinsight Jan 15, 2026
755ff05
Merge pull request #9 from ryancinsight/adaptive-softmax-impl-4350403…
ryancinsight Jan 15, 2026
e2734b1
Merge pull request #10 from ryancinsight/perf/optimize-backprop-cloni…
ryancinsight Jan 15, 2026
26ce3c0
Decouple Richards Curve training and enable independent learning from…
google-labs-jules[bot] Jan 16, 2026
2d278cd
feat: Refactor to deep hierarchical structure and add Titans architec…
google-labs-jules[bot] Jan 16, 2026
7552924
Fix RichardsGate gradient packing and decouple training
google-labs-jules[bot] Jan 16, 2026
309c4bb
Merge pull request #11 from ryancinsight/decouple-richards-training-2…
ryancinsight Jan 16, 2026
9dd8274
Merge pull request #12 from ryancinsight/titans-arch-refactor-7625585…
ryancinsight Jan 16, 2026
4b24fe1
Implement NeuralMemory for Titans architecture
google-labs-jules[bot] Jan 16, 2026
50872a0
Merge pull request #13 from ryancinsight/titans-neural-memory-impl-42…
ryancinsight Jan 16, 2026
5ed49c9
Implement NeuralMemory meta-gradients and TitansMAC architecture
google-labs-jules[bot] Jan 16, 2026
b1b6c3d
Merge pull request #14 from ryancinsight/titans-memory-mac-impl-14553…
ryancinsight Jan 16, 2026
5ed021b
Reintegrate eprop module with feature flag
google-labs-jules[bot] Jan 16, 2026
5424c18
Fix in-place modification of gradient accumulator in NeuralMemory
google-labs-jules[bot] Jan 16, 2026
91677b9
Merge pull request #16 from ryancinsight/fix-neural-memory-gradient-b…
ryancinsight Jan 16, 2026
db7fd23
Reintegrate eprop module without feature flag
google-labs-jules[bot] Jan 16, 2026
00962d7
Merge pull request #15 from ryancinsight/reintegrate-eprop-flag-10392…
ryancinsight Jan 16, 2026
4c9c77e
Enable Titans Memory architecture and implement gradients
google-labs-jules[bot] Jan 16, 2026
9529d74
Fix TitansMAC backward pass input
google-labs-jules[bot] Jan 16, 2026
123383a
Merge pull request #17 from ryancinsight/titans-memory-enable-6527332…
ryancinsight Jan 16, 2026
d403e39
feat(training): add e-prop training pipeline and integrate with main CLI
ryancinsight Jan 16, 2026
ab4e471
Merge branch 'main' of https://github.com/ryancinsight/RustGPT
ryancinsight Jan 16, 2026
281b99a
refactor(titans): improve code formatting and memory gradient calcula…
ryancinsight Jan 16, 2026
8bf8c60
feat: Define TitansMAL struct with NeuralMemory and SlidingWindowAtte…
google-labs-jules[bot] Jan 21, 2026
ed2ae23
feat(eprop): add spiking neuron layers and numeric utilities
ryancinsight Jan 21, 2026
790a923
feat(eprop): implement e-prop adaptor for transformer blocks
ryancinsight Jan 21, 2026
a7ebb10
feat: Implement `TitansMAL` with backpropagation
google-labs-jules[bot] Jan 21, 2026
98339e1
Merge pull request #18 from ryancinsight/feat-titans-mal-struct-31150…
ryancinsight Jan 21, 2026
2635b1f
perf: Reuse gradient accumulation buffers in training loops
google-labs-jules[bot] Jan 21, 2026
d55847c
Merge pull request #20 from ryancinsight/perf/reuse-grad-buffers-1403…
ryancinsight Jan 21, 2026
12a6369
Optimize weight updates with scaled_add and fix build errors
google-labs-jules[bot] Jan 22, 2026
8c8312a
feat(eprop): implement e-prop training support in LLM model
ryancinsight Jan 22, 2026
3211057
Merge branch 'main' of https://github.com/ryancinsight/RustGPT
ryancinsight Jan 22, 2026
a9af6a9
Merge pull request #21 from ryancinsight/perf-optimize-weight-updates…
ryancinsight Jan 22, 2026
9f7780a
Implement TitansMAG architecture with SWA, NeuralMemory, and Gating
google-labs-jules[bot] Jan 22, 2026
1bee005
Merge pull request #22 from ryancinsight/titans-mag-implementation-38…
ryancinsight Jan 22, 2026
f1d73a0
chore: ignore build artifacts in CI target directories
ryancinsight Jan 22, 2026
9d3fdb7
Merge branch 'main' of https://github.com/ryancinsight/RustGPT
ryancinsight Jan 22, 2026
ef96948
style: reformat code and add development guidelines
ryancinsight Jan 22, 2026
faa35e9
perf: reduce log noise by downgrading info to debug and optimize tens…
ryancinsight Jan 22, 2026
aec4b71
feat(models): add titans module and restructure memory system
ryancinsight Jan 23, 2026
1bfad37
chore: ignore target_ci build artifacts
ryancinsight Jan 23, 2026
0850e70
refactor: eliminate unnecessary mutable variables and allocations
ryancinsight Jan 23, 2026
51f9de5
feat(training): add adaptive hyperparameter scheduling via Richards c…
ryancinsight Jan 24, 2026
9524cb5
Implement TitansMAL forward pass and add verification test.
google-labs-jules[bot] Jan 24, 2026
a880b8a
Merge pull request #23 from ryancinsight/titans-mal-forward-167646896…
ryancinsight Jan 24, 2026
a25c64f
feat(richards): add adaptive scalar for MoH threshold modulation
ryancinsight Jan 25, 2026
a24dda2
Replace hardcoded config constants
ryancinsight Jan 25, 2026
1f3cbe4
perf: optimize memory usage and inference speed across modules
ryancinsight Jan 26, 2026
6797a6d
Optimize RichardsCurve::update_scaling_from_max_abs to avoid expensiv…
google-labs-jules[bot] Jan 26, 2026
9e9475b
Merge pull request #24 from ryancinsight/richards-curve-optimization-…
ryancinsight Jan 26, 2026
5531d11
Optimize LLM generation loop with KV caching (O(N^2) -> O(N))
google-labs-jules[bot] Jan 26, 2026
2f07c31
Merge pull request #25 from ryancinsight/perf/generation-loop-optimiz…
ryancinsight Jan 26, 2026
34ffc48
Optimize RichardsGlu::compute_gradients by removing unnecessary clones
google-labs-jules[bot] Jan 26, 2026
3 changes: 3 additions & 0 deletions .config/nextest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[profile.default]
slow-timeout = { period = "60s", terminate-after = 3 }
global-timeout = "20m"
89 changes: 89 additions & 0 deletions .gemini/richards_gradient_derivation.md
@@ -0,0 +1,89 @@
# Extended Richards Curve Gradient Derivations

## Forward Pass
```
σ = [1 + β * exp(-k(input-m))]^(-1/ν)
```

Where (for defaults):
- β = 1.0
- input = x (after all transformations with defaults)

## Gradients

### ∂σ/∂ν (nu gradient)
Using logarithmic differentiation:
```
ln(σ) = (-1/ν) * ln(base)
(1/σ) * ∂σ/∂ν = (1/ν²) * ln(base)
∂σ/∂ν = σ * ln(base) / ν²
```

### ∂σ/∂k (k gradient)
Chain rule through base and exponent:
```
∂σ/∂base = (-1/ν) * base^(-1/ν - 1) = (-1/ν) * σ / base
∂base/∂exponent = β * exp(exponent) = β * exp_term
∂exponent/∂k = -(input - m)

∂σ/∂k = (∂σ/∂base) * (∂base/∂exponent) * (∂exponent/∂k)
= [(-1/ν) * σ / base] * [β * exp_term] * [-(input - m)]
= (1/ν) * σ * β * exp_term * (input - m) / base
```

For β=1: exp_term/base = 1 - σ^ν (approximately 1 - σ when ν is close to 1)

More accurately, for Richards curve:
```
∂σ/∂k = (1/ν) * σ * exp_term * (input - m) / base
```

### ∂σ/∂m (m gradient)
```
∂exponent/∂m = k

∂σ/∂m = (∂σ/∂base) * (∂base/∂exponent) * (∂exponent/∂m)
= [(-1/ν) * σ / base] * [β * exp_term] * [k]
= (-k/ν) * σ * β * exp_term / base
```

### ∂σ/∂β (beta gradient)
```
∂base/∂β = exp(exponent) = exp_term

∂σ/∂β = (∂σ/∂base) * (∂base/∂β)
= [(-1/ν) * σ / base] * exp_term
= (-1/ν) * σ * exp_term / base
```

### ∂σ/∂temp (temperature gradient)
Chain through input:
```
∂input/∂temp = -input_scale * scale * adaptive_normalized / temp²
= -input_scale * scale * temp_scaled / temp

∂σ/∂input = (∂σ/∂base) * (∂base/∂exponent) * (∂exponent/∂input)
= [(-1/ν) * σ / base] * [β * exp_term] * [-k]
= (k/ν) * σ * β * exp_term / base

∂σ/∂temp = (∂σ/∂input) * (∂input/∂temp)
```

## Final Formulas (β=1 default)

```rust
// Nu gradient
d_sigma_d_nu = sigma * base.ln() / (nu * nu)

// K gradient
d_sigma_d_k = (1.0 / nu) * sigma * exp_term * (input - m) / base

// M gradient
d_sigma_d_m = (-k / nu) * sigma * exp_term / base

// Beta gradient (for β learnable)
d_sigma_d_beta = (-1.0 / nu) * sigma * exp_term / base

// Temperature gradient
d_sigma_d_temp = (k / nu) * sigma * exp_term / base * (-temp_scaled / temp)
```
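The closed forms above can be spot-checked with central finite differences. A minimal standalone sketch (not the crate's actual `RichardsCurve` code; defaults β=1, no input scaling):

```rust
// Richards curve: σ = (1 + β·exp(-k(x-m)))^(-1/ν)
fn sigma(x: f64, k: f64, m: f64, nu: f64, beta: f64) -> f64 {
    let exp_term = (-k * (x - m)).exp();
    (1.0 + beta * exp_term).powf(-1.0 / nu)
}

// Analytic gradients (∂σ/∂ν, ∂σ/∂k, ∂σ/∂m) from the formulas above.
fn analytic_grads(x: f64, k: f64, m: f64, nu: f64, beta: f64) -> (f64, f64, f64) {
    let exp_term = (-k * (x - m)).exp();
    let base = 1.0 + beta * exp_term;
    let s = base.powf(-1.0 / nu);
    let d_nu = s * base.ln() / (nu * nu);
    let d_k = (1.0 / nu) * s * beta * exp_term * (x - m) / base;
    let d_m = (-k / nu) * s * beta * exp_term / base;
    (d_nu, d_k, d_m)
}

fn main() {
    let (x, k, m, nu, beta) = (0.7, 2.0, 0.1, 1.5, 1.0);
    let h = 1e-6;
    let (d_nu, d_k, d_m) = analytic_grads(x, k, m, nu, beta);
    // Central differences should agree with the closed forms to ~1e-7.
    let fd_nu = (sigma(x, k, m, nu + h, beta) - sigma(x, k, m, nu - h, beta)) / (2.0 * h);
    let fd_k = (sigma(x, k + h, m, nu, beta) - sigma(x, k - h, m, nu, beta)) / (2.0 * h);
    let fd_m = (sigma(x, k, m + h, nu, beta) - sigma(x, k, m - h, nu, beta)) / (2.0 * h);
    assert!((d_nu - fd_nu).abs() < 1e-6);
    assert!((d_k - fd_k).abs() < 1e-6);
    assert!((d_m - fd_m).abs() < 1e-6);
    println!("analytic gradients match finite differences");
}
```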
22 changes: 22 additions & 0 deletions .github/codecov.yml
@@ -0,0 +1,22 @@
# # ref: https://docs.codecov.com/docs/codecovyml-reference
# comment out coverage job for now, https://github.com/tekaratzas/RustGPT/pull/11#issuecomment-3361854174
# coverage:
# # Hold ourselves to a high bar
# range: 55..100
# round: down
# precision: 1
# status:
# # ref: https://docs.codecov.com/docs/commit-status
# project:
# default:
# # Avoid false negatives
# threshold: 1%

# # Test files aren't important for coverage
# ignore:
# - "tests"

# # Make comments less noisy
# comment:
# layout: "files"
# require_changes: yes
73 changes: 73 additions & 0 deletions .github/workflows/check.yml
@@ -0,0 +1,73 @@
permissions:
  contents: read
on:
  push:
    branches: [main, master]
  pull_request:
  merge_group:

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

env:
  RUST_TOOLCHAIN: stable

name: Check
jobs:
  fmt:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
    name: fmt
    permissions:
      # Give the default GITHUB_TOKEN write permission to commit and push the
      # added or changed files to the repository.
      contents: write
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: true
      - name: Install rust
        uses: dtolnay/rust-toolchain@master
        with:
          toolchain: nightly #${{ env.RUST_TOOLCHAIN }}
          components: rustfmt
      - run: cargo fmt --check

  clippy:
    runs-on: ubuntu-latest
    name: clippy
    permissions:
      contents: read
      checks: write
    strategy:
      fail-fast: false
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: true
      - name: Install ${{ env.RUST_TOOLCHAIN }}
        uses: dtolnay/rust-toolchain@master # master
        with:
          toolchain: ${{ env.RUST_TOOLCHAIN }}
          components: clippy
      - name: Rust Cache
        uses: Swatinem/rust-cache@v2
      - run: cargo clippy --workspace --all-features --all-targets -- -D warnings

  typos:
    runs-on: ubuntu-latest
    name: typos
    permissions:
      contents: read
    strategy:
      fail-fast: false
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: true
      - name: Check spelling
        uses: crate-ci/typos@master

68 changes: 68 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,68 @@
permissions:
  contents: read
on:
  push:
    branches: [main, master]
  pull_request:
  merge_group:

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

env:
  RUST_TOOLCHAIN: stable

name: Test
jobs:
  required:
    runs-on: ubuntu-latest
    name: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: true
      - name: Install ${{ env.RUST_TOOLCHAIN }}
        uses: dtolnay/rust-toolchain@master
        with:
          toolchain: ${{ env.RUST_TOOLCHAIN }}
      - name: cargo generate-lockfile
        if: hashFiles('Cargo.lock') == ''
        run: cargo generate-lockfile
      # https://twitter.com/jonhoo/status/1571290371124260865
      - name: Rust Cache
        uses: Swatinem/rust-cache@v2
      - name: Install nextest
        uses: taiki-e/install-action@nextest
      - name: cargo nextest --locked
        run: cargo nextest run --locked --workspace --all-features --all-targets

  # comment out coverage job for now, https://github.com/tekaratzas/RustGPT/pull/11#issuecomment-3361854174
  # coverage:
  #   runs-on: ubuntu-latest
  #   name: coverage
  #   steps:
  #     - uses: actions/checkout@v4
  #       with:
  #         submodules: true
  #     - name: Install rust
  #       uses: dtolnay/rust-toolchain@master
  #       with:
  #         toolchain: ${{ env.RUST_TOOLCHAIN }}
  #         components: llvm-tools-preview
  #     - name: cargo install cargo-llvm-cov
  #       uses: taiki-e/install-action@cargo-llvm-cov
  #     - name: cargo generate-lockfile
  #       if: hashFiles('Cargo.lock') == ''
  #       run: cargo generate-lockfile
  #     - name: Rust Cache
  #       uses: Swatinem/rust-cache@v2
  #     - name: Install nextest
  #       uses: taiki-e/install-action@nextest
  #     - name: cargo llvm-cov
  #       run: cargo llvm-cov nextest --locked --workspace --all-features --all-targets --lcov --output-path lcov.info
  #     - name: Upload to codecov.io
  #       uses: codecov/codecov-action@v5
  #       with:
  #         fail_ci_if_error: true
  #         token: ${{ secrets.CODECOV_TOKEN }} # required
14 changes: 14 additions & 0 deletions .gitignore
@@ -1 +1,15 @@
/target
/target_ci/

# Model files
models/*.bin
models/*.ckpt
models/*.pth
models/*.h5
models/*.pb
models/*.onnx
*.bin
*.csv

# Local logs / run artifacts
/logs/
40 changes: 40 additions & 0 deletions .trae/documents/Add Denoising Cross-Entropy for Diffusion.md
@@ -0,0 +1,40 @@
## Goals
- Implement a denoising cross-entropy (DCE) training variant for diffusion to match CE-style logs/metrics, enabling apples-to-apples comparison with Transformer/TRM.
- Combine denoising MSE with CE over the output projection of recovered x0 (configurable weights), or run CE-only.

## CLI Additions
- `--diffusion_ce` (bool): use DCE pipeline for diffusion pretraining.
- `--diffusion_ce_weight <f32>` (default: 0.5): CE loss weight.
- `--diffusion_mse_weight <f32>` (default: 0.5): MSE loss weight. If `--diffusion_ce` and `diffusion_mse_weight=0`, runs CE-only.

## Training Pipeline (LLM::train_diffusion_ce)
1) Tokenize batch sequences; slice `input_ids = seq[..len-1]` and `target_ids = seq[1..]`.
2) Embed: `x0 = TokenEmbeddings.forward([input_ids])` → shape `[seq_len, embed_dim]`.
3) Sample noise ε and timestep t; compute `x_t = NoiseScheduler.q_sample(x0, t, ε)`.
4) Predict noise: forward `x_t` through all DiffusionBlocks with `set_timestep(t)` to get `ε_θ`.
5) Recover x0_hat: `x0_hat = (x_t - sqrt(1-ᾱ_t) * ε_θ) / sqrt(ᾱ_t)` using scheduler’s `sqrt_alpha_cumprod(t)` and `sqrt_one_minus_alpha_cumprod(t)`.
6) Logits: pass `x0_hat` through final `DynamicTanhNorm` (if present) and `OutputProjection` to get `[seq_len, vocab_size]`.
7) Loss:
- MSE: `mse = mean((ε_θ - ε)^2)`.
- CE: standard token-level CE on logits vs `target_ids`.
- Total: `loss = mse_weight*mse + ce_weight*ce`.
8) Gradients:
- CE grads: `dL/dlogits` → OutputProjection.backward → `grad_hidden` (shape `[seq_len, embed_dim]`), then through final norm (if present) to get `grad_x0_hat`.
- Chain rule to predicted noise: `grad_eps = grad_x0_hat * (-sqrt(1-ᾱ_t)/sqrt(ᾱ_t))` (broadcast scalar).
- MSE grads: `grad_eps += 2*(ε_θ - ε)/N`.
- Backprop `grad_eps` through DiffusionBlocks with `compute_gradients(input=x_t, grads=grad_eps)` and `apply_gradients`.
- Optionally backprop into TokenEmbeddings via `grad_x0` if desired; default: leave embeddings updated via CE path only when explicitly enabled (keep simple: no embedding update for DCE unless requested; can add `--diffusion_update_embeddings` flag).
9) Logging: print per-epoch `loss`, `mse`, `ce`, and `grad_norm` formatted like `train_with_warmup` for consistency.
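Steps 3, 5, and 7 above can be sketched on flat buffers. This is illustrative only: the real pipeline works on `[seq_len, embed_dim]` tensors and uses the scheduler's own `q_sample` and `sqrt_alpha_cumprod` accessors; the scalar names here are assumptions.

```rust
// x_t = sqrt(ᾱ_t) * x0 + sqrt(1-ᾱ_t) * ε   (forward noising, step 3)
fn q_sample(x0: &[f32], eps: &[f32], sa: f32, s1m: f32) -> Vec<f32> {
    x0.iter().zip(eps).map(|(x, e)| sa * x + s1m * e).collect()
}

// x0_hat = (x_t - sqrt(1-ᾱ_t) * ε_θ) / sqrt(ᾱ_t)   (recovery, step 5)
fn recover_x0(x_t: &[f32], eps_pred: &[f32], sa: f32, s1m: f32) -> Vec<f32> {
    x_t.iter().zip(eps_pred).map(|(x, e)| (x - s1m * e) / sa).collect()
}

fn mse(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>() / a.len() as f32
}

fn main() {
    let x0 = vec![0.5_f32, -1.0, 2.0];
    let eps = vec![0.1_f32, 0.2, -0.3];
    let (sa, s1m) = (0.8_f32, 0.6_f32); // sa² + s1m² = 1 for a valid schedule
    let x_t = q_sample(&x0, &eps, sa, s1m);
    // With a perfect noise prediction, recovery is exact: the round-trip
    // unit test described under "Tests" below.
    let x0_hat = recover_x0(&x_t, &eps, sa, s1m);
    assert!(mse(&x0, &x0_hat) < 1e-9);
    // Weighted total from step 7 (ce is a placeholder value here).
    let (mse_weight, ce_weight, ce) = (0.5_f32, 0.5_f32, 1.234_f32);
    let loss = mse_weight * mse(&eps, &eps) + ce_weight * ce;
    assert!(loss > 0.0);
}
```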

## Integration
- In `main.rs`, if `--diffusion_ce` present during diffusion pretraining, call `train_diffusion_ce(pretraining_examples, epochs, lr, batch_size, ce_weight, mse_weight)` instead of `train_diffusion`.
- Instruction tuning stays with CE (`train_with_warmup`).

## Tests
- Unit: verify `x0_hat` recovery formula correctness by round-trip (`q_sample` then recover) on synthetic data.
- Integration: small dataset run prints CE and MSE (when both enabled), and losses decrease.
- Gradient shapes: ensure DiffusionBlock `compute_gradients` receives correct shapes and param gradients non-empty.

## Notes
- Keeps backward compatibility; no changes to existing CE training for Transformer/TRM.
- Defaults provide balanced MSE+CE; adjust via flags for experiments.
@@ -0,0 +1,25 @@
## Goal
- Restore explicit TRM architecture selection in CLI.
- Ensure three architectures are available: Transformer, Diffusion, TRM.
- When both `--trm` and `--diffusion` are set, select Diffusion (TRM can use either; diffusion takes precedence when requested).
- Keep existing training flows intact; TRM uses standard training, Diffusion uses denoising training.

## Changes
- `src/main.rs`:
- Add `--trm` flag in `Args`.
- Set `architecture` as:
- If `--diffusion`: `ArchitectureType::Diffusion`
- Else if `--trm`: `ArchitectureType::TRM`
- Else: `ArchitectureType::Transformer`
- Logging for TRM stages analogous to others.
- No changes to `model_builder.rs` (already supports TRM).
- No changes to `LLM` training logic; TRM paths are handled by `train_with_warmup` which toggles TRM training mode internally.
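The precedence rule above reduces to a small selection function. A minimal sketch (the real enum and flag parsing live in `src/main.rs` and may differ in detail):

```rust
// Mirrors the doc's ArchitectureType names; the allow silences the
// naming-convention warning for the all-caps TRM variant.
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq)]
enum ArchitectureType {
    Transformer,
    Diffusion,
    TRM,
}

fn select_architecture(diffusion: bool, trm: bool) -> ArchitectureType {
    if diffusion {
        ArchitectureType::Diffusion // diffusion wins even when --trm is also set
    } else if trm {
        ArchitectureType::TRM
    } else {
        ArchitectureType::Transformer
    }
}

fn main() {
    assert_eq!(select_architecture(false, false), ArchitectureType::Transformer);
    assert_eq!(select_architecture(false, true), ArchitectureType::TRM);
    assert_eq!(select_architecture(true, true), ArchitectureType::Diffusion);
    println!("precedence: Diffusion > TRM > Transformer");
}
```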

## Verification
- Build and run:
- `cargo run --release --bin main -- --trm` → TRM architecture logs and training/tuning.
- `cargo run --release --bin main -- --trm --diffusion` → Diffusion selected (precedence), denoising training.
- `cargo run --release --bin main` → Transformer.

## Scope
- Minimal CLI and selection updates only; no structural refactors required for TRM support.