
Add Apple Silicon / MLX training support (train_mlx.py)#202

Open
AKHegde22 wants to merge 1 commit into karpathy:master from AKHegde22:master

Conversation

@AKHegde22

Summary

This PR adds native Apple Silicon support via MLX, allowing autoresearch to run on Macs without any CUDA dependency.

What's changed

New file: train_mlx.py

A complete MLX port of train.py that mirrors the architecture exactly:

  • GQA causal self-attention with RoPE + QK-norm
  • Value residual embeddings on alternating layers (ResFormer)
  • Sliding-window attention pattern (SSSL)
  • Squared ReLU MLP
  • Muon (2-D transformer matrices) + AdamW (embeddings / scalars)

The Muon optimizer runs on CPU via numpy since MLX does not have a torch.compile equivalent. Defaults are tuned for Mac hardware (smaller batch size, smaller model depth).
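Muon's core update is a Newton-Schulz iteration that approximately orthogonalizes each 2-D gradient matrix, and that step ports naturally to numpy. A minimal sketch of the idea (coefficients taken from the public Muon reference implementation; the function name is hypothetical and not necessarily what train_mlx.py uses):

```python
import numpy as np

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    """Approximately orthogonalize G with a quintic Newton-Schulz
    iteration, the core of the Muon update. Coefficients follow the
    public Muon reference; this is an illustrative numpy sketch."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.astype(np.float32)
    # Frobenius norm upper-bounds the spectral norm, so this scales
    # all singular values into (0, 1] before iterating.
    X /= np.linalg.norm(X) + eps
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T  # iterate on the smaller Gram matrix
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X
```

Because it is plain numpy, this piece can run on the CPU regardless of where the forward/backward pass executes.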

Modified: prepare.py

  • Added make_dataloader_mlx() — numpy-based dataloader with no CUDA/pin-memory dependencies
  • Added get_token_bytes_np() — returns token byte-lengths as a numpy array (no CUDA required)
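For flavor, a numpy-only dataloader of this kind can be very small. A hedged sketch (name and signature are illustrative, not the exact make_dataloader_mlx() API):

```python
import numpy as np

def make_dataloader_np(tokens, batch_size, seq_len, seed=0):
    """Illustrative numpy batch generator: yields (inputs, targets)
    integer arrays with no torch, CUDA, or pin-memory involvement."""
    rng = np.random.default_rng(seed)
    max_start = len(tokens) - seq_len - 1
    while True:
        starts = rng.integers(0, max_start, size=batch_size)
        # targets are the inputs shifted one position to the right
        x = np.stack([tokens[s : s + seq_len] for s in starts])
        y = np.stack([tokens[s + 1 : s + 1 + seq_len] for s in starts])
        yield x, y
```

The arrays can be handed to MLX with mx.array(x) at the point of use, keeping the data pipeline itself framework-free.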

Modified: pyproject.toml

  • Added mlx>=0.18.0 dependency

What's preserved

No existing GPU code was touched. train.py (CUDA + Flash Attention 3) is completely unchanged. Both scripts are independently runnable:

uv run train.py       # NVIDIA GPU
uv run train_mlx.py   # Apple Silicon (M1/M2/M3/M4)

Usage

uv run prepare.py       # same one-time data prep
uv run train_mlx.py     # Apple Silicon training

- train_mlx.py: full MLX port of train.py — same architecture (GQA, RoPE,
  value residuals, sliding-window SSSL pattern, squared-ReLU MLP), Muon +
  AdamW optimizer running on CPU/numpy, tuned defaults for Mac hardware.
- prepare.py: add make_dataloader_mlx() (numpy, no pin-memory/CUDA) and
  get_token_bytes_np() so the MLX training script has zero CUDA deps.
- pyproject.toml: add mlx>=0.18.0 dependency.

Both train.py (CUDA) and train_mlx.py (MLX) are fully functional and
independently runnable; no existing GPU code was removed or modified.

Co-Authored-By: Oz <oz-agent@warp.dev>
@tobiasoberrauch

Bug Report: KeyError: 0 in optimizer step

Tested this PR on an M3 Max — training crashes immediately at the first optimizer step:

File "train_mlx.py", line 470, in step
    model.update(mu.tree_unflatten(updates))
  File ".venv/lib/python3.10/site-packages/mlx/nn/layers/base.py", line 350, in apply
    current_value = dst[i]
KeyError: 0

Root cause

value_embeds is stored as a dict with string keys {"1": ..., "3": ..., "5": ...} (sparse — only alternating layers). When tree_flatten produces paths like value_embeds.1.weight, tree_unflatten interprets the numeric segments as list indices and builds a list of length 6 instead of a dict. model.update() then fails because it expects dict keys, not integer indices.
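The failure mode is easy to reproduce without MLX. The following is a simplified pure-Python sketch of tree_unflatten's key handling (not MLX's actual source), showing how all-numeric path segments get rebuilt as a list with holes instead of a dict:

```python
def unflatten(flat):
    """Simplified sketch of mlx.utils.tree_unflatten's key handling.
    flat is a list of (dotted_path, value) pairs."""
    if len(flat) == 1 and flat[0][0] == "":
        return flat[0][1]
    children = {}
    for key, value in flat:
        head, _, rest = key.partition(".")
        children.setdefault(head, []).append((rest, value))
    if all(k.isdigit() for k in children):
        # Numeric segments are treated as list indices, so sparse keys
        # like "1", "3", "5" produce a padded list, not a dict.
        out = [{} for _ in range(max(int(k) for k in children) + 1)]
        for k, sub in children.items():
            out[int(k)] = unflatten(sub)
        return out
    return {k: unflatten(sub) for k, sub in children.items()}
```

Running it on paths like "1.weight" / "3.weight" yields a 4-element list with empty placeholders at indices 0 and 2, which is exactly the shape mismatch model.update() then trips over.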

Fix

Replace the tree_unflatten + model.update() call (lines 469–470) with direct parameter assignment:

        # Write updated parameters back — apply directly to avoid
        # tree_unflatten misinterpreting sparse dict keys (e.g.
        # value_embeds.1 / .3 / .5) as list indices.
        for path, val in updates:
            parts = path.split(".")
            obj = model
            for part in parts[:-1]:
                if isinstance(obj, list):
                    obj = obj[int(part)]
                elif isinstance(obj, dict):
                    obj = obj[part]
                else:
                    obj = getattr(obj, part)
            last = parts[-1]
            if isinstance(obj, dict):
                obj[last] = val
            else:
                setattr(obj, last, val)
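The walk-and-assign logic above can be exercised on a toy object tree to confirm it handles all three container kinds (attribute, dict key, list index). The Node class and tree shape below are made up for the test; only the traversal logic mirrors the fix:

```python
class Node:
    """Trivial stand-in for an mlx.nn.Module for demonstration."""
    pass

def assign(root, updates):
    """Same path-walking assignment as the proposed fix, extracted
    as a standalone function for testing."""
    for path, val in updates:
        parts = path.split(".")
        obj = root
        for part in parts[:-1]:
            if isinstance(obj, list):
                obj = obj[int(part)]
            elif isinstance(obj, dict):
                obj = obj[part]
            else:
                obj = getattr(obj, part)
        last = parts[-1]
        if isinstance(obj, dict):
            obj[last] = val
        else:
            setattr(obj, last, val)
```

Because the dict branch looks keys up as strings, sparse numeric keys like "1" and "3" round-trip correctly instead of being coerced into list indices.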

Verified

With this fix, training completes successfully on M3 Max:

val_bpb:          1.678259
training_seconds: 300.2
peak_memory_mb:   7316.1
mfu_percent:      6.22
num_steps:        345
num_params_M:     10.52

@AKHegde22
Author

Thanks for the update @tobiasoberrauch

