Add Apple Silicon MLX backend (self-contained, GPU-accelerated Muon)#205

Open
elementalcollision wants to merge 1 commit into karpathy:master from elementalcollision:apple-silicon-mlx

Conversation

@elementalcollision

Summary

Adds a self-contained train_mlx.py for training on Apple Silicon Macs using MLX. No existing files are modified beyond adding mlx as an optional dependency in pyproject.toml.

  • Full Muon+AdamW optimizer ported to MLX with Newton-Schulz (Polar Express) orthogonalization
  • Float32 Newton-Schulz fix — bf16 Frobenius norm loses precision on large matrices, causing NaN divergence. Uses float32 throughout (no speed penalty — Apple Silicon lacks bf16 tensor cores)
  • MLX-native dataloader with best-fit BOS-aligned packing
  • MLX-native BPB evaluation (computes byte lengths directly from tiktoken — no token_bytes.pt dependency)
  • Apple Silicon hardware detection for MFU calculation
Quick start:

  pip install mlx tiktoken pyarrow rustbpe numpy
  python prepare.py       # one-time data prep (unchanged)
  python train_mlx.py     # Apple Silicon training
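For reference, the float32 Newton-Schulz orthogonalization described above can be sketched as below. This is a minimal numpy sketch: the quintic coefficients are the ones commonly used for Muon, and the exact constants and iteration count in train_mlx.py may differ.

```python
import numpy as np

def newton_schulz_orthogonalize(G, steps=5):
    """Approximately orthogonalize G with a quintic Newton-Schulz iteration.

    Runs entirely in float32: a bf16 Frobenius norm loses precision on
    large matrices and can push the iteration toward NaN. Coefficients
    follow the quintic iteration popularized by Muon (illustrative here).
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.astype(np.float32)
    transposed = X.shape[0] > X.shape[1]
    if transposed:                       # iterate on the short side
        X = X.T
    X /= np.linalg.norm(X) + 1e-7       # scale so singular values are <= 1
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transposed else X

# singular values of the result are driven toward 1
W = np.random.default_rng(0).normal(size=(64, 128)).astype(np.float32)
O = newton_schulz_orthogonalize(W)
```

Running the same iteration with a bf16 norm is where the divergence reported in the test plan comes from; keeping every step in float32 sidesteps it at no cost on Apple Silicon.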

Differences from #202

This PR and #202 solve the same problem but take different approaches:

                 This PR                                          #202
Files modified   pyproject.toml only (+5 lines)                   prepare.py (+70 lines), pyproject.toml (+2/−11)
Newton-Schulz    MLX GPU-accelerated (float32)                    numpy CPU — ~10x slower on large matrices
NaN prevention   float32 Newton-Schulz fixes bf16 norm overflow   No fix — bf16 norm will diverge on ≥512-dim models
prepare.py       Untouched                                        Modified (adds 70 lines)
torch dep        Unchanged (keeps torch==2.9.1 pinned to CUDA)    Changed to torch>=2.3.0, removes CUDA index
mlx dep          Optional ([project.optional-dependencies])       Required (added to base deps)
token_bytes.pt   Not needed (computed from tiktoken)              Needs torch to load

Key philosophical difference: this PR is additive-only — it adds two files and doesn't touch existing CUDA code or dependencies, so there's zero risk of breaking the default GPU path.
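The "best-fit BOS-aligned packing" the dataloader bullet refers to can be sketched as below. This is an illustrative best-fit bin-packing pass over document lengths; the actual logic in train_mlx.py may differ in details such as truncation and row flushing.

```python
def pack_best_fit(doc_lens, seq_len):
    """Pack documents into fixed-length rows using best-fit placement.

    Each document starts at a BOS token and goes into the open row with
    the least leftover capacity that still fits it; documents longer
    than seq_len are truncated. Sketch only, not the script's exact code.
    """
    rows = []          # remaining capacity of each open row
    placement = []     # (row_index, doc_index) pairs
    for i, n in enumerate(doc_lens):
        n = min(n, seq_len)
        # best fit: the row with the smallest remaining capacity >= n
        best = min((r for r in range(len(rows)) if rows[r] >= n),
                   key=lambda r: rows[r], default=None)
        if best is None:
            rows.append(seq_len)       # open a fresh row
            best = len(rows) - 1
        rows[best] -= n
        placement.append((best, i))
    return placement, rows
```

Best-fit keeps padding waste low compared to greedy first-fit, while BOS alignment guarantees every packed document begins at a document boundary.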

Files changed

File             Change
train_mlx.py     New — self-contained MLX training script (1093 lines)
pyproject.toml   Add mlx as optional dependency (+5 lines)

Test plan

  • Validated on M1 Max 64GB — trains without NaN, produces valid val_bpb
  • 25 autonomous experiments completed (val_bpb 2.094 → 1.621, −22.6%)
  • BPB evaluation matches expected range
  • Float32 Newton-Schulz confirmed stable (bf16 version diverges at step ~50)
  • Verify on M2/M3/M4 hardware (community testing welcome)
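The BPB (bits-per-byte) figure checked above is total cross-entropy in nats, converted to bits and normalized by the UTF-8 byte count of the evaluated text. A minimal sketch, with hypothetical per-token losses and byte lengths standing in for real data:

```python
import math

def bits_per_byte(token_nll_nats, token_byte_lens):
    """Bits-per-byte: sum of per-token cross-entropy (nats) over all
    predicted tokens, converted to bits (divide by ln 2) and normalized
    by total UTF-8 bytes. In train_mlx.py the byte length of each token
    comes directly from tiktoken rather than a token_bytes.pt file;
    here a small stand-in table keeps the sketch dependency-free.
    """
    total_nats = sum(token_nll_nats)
    total_bytes = sum(token_byte_lens)
    return total_nats / (math.log(2) * total_bytes)

# hypothetical per-token losses (nats) and token byte lengths
nll = [2.0, 3.0, 1.5]
blen = [4, 5, 2]
bpb = bits_per_byte(nll, blen)
```

Normalizing by bytes rather than tokens makes the metric comparable across tokenizers, which is why the script computes byte lengths from tiktoken itself.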

🤖 Generated with Claude Code

Commit message:
Self-contained single-file MLX training script for Apple Silicon Macs.
Ports the full training pipeline including:
- Muon+AdamW optimizer with Newton-Schulz orthogonalization
- Float32 Newton-Schulz fix (prevents NaN divergence on Apple Silicon)
- MLX-native dataloader (numpy buffers → mx.array)
- MLX-native BPB evaluation (computes byte lengths from tiktoken directly)
- Apple Silicon hardware detection for MFU calculation
- All features: GQA, RoPE, value embeddings, sliding window attention,
  softcap, gradient accumulation, LR warmup/cooldown schedules

No existing files are modified beyond adding mlx as an optional dependency
in pyproject.toml. Does not depend on torch at runtime.

Usage:
  pip install mlx tiktoken pyarrow rustbpe numpy
  python prepare.py  # one-time data prep (still needs torch)
  python train_mlx.py

Tested on M1 Max 64GB: val_bpb 2.094 → 1.621 over 25 experiments.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@elementalcollision
Author

elementalcollision commented Mar 12, 2026

Validation log here.

M5 testing will be done in ~5 days.

