MLX implementation of the Baby Dragon Hatchling (BDH) architecture, optimized for training on Apple Silicon (M1/M2/M3/M4) with unified memory.
Original Paper: Adrian Kosowski, Przemysław Uznański, Jan Chorowski, Zuzanna Stamirowska, Michał Bartoszkiewicz, "The Dragon Hatchling: The Missing Link between the Transformer and Models of the Brain", arXiv:2509.26507
- What is BDH?
- Why MLX?
- Architecture Overview
- Installation
- Quick Start
- PyTorch to MLX Conversion
- Training Guide
- Performance
- API Reference
- Citation
Baby Dragon Hatchling (BDH) is a novel Large Language Model architecture that bridges the gap between Transformers and biologically-plausible neural networks. Unlike standard Transformers, BDH features:
- Byte-Level Processing: Uses all 256 bytes as vocabulary - no tokenizer required
- Shared Parameters Across Layers: Same weights reused in all layers (recurrent depth)
- Sparse, Non-Negative Activations: ReLU-based activations for biological plausibility
- Constrained Attention: Q=K constraint with causal masking (diagonal=-1)
- Hierarchical Gating: Multiplicative gating instead of additive residuals
- Brain-Inspired Network: High modularity, heavy-tailed degree distribution
| Feature | Transformer | BDH |
|---|---|---|
| Parameters | Unique per layer | Shared across layers |
| Activations | Any sign | Sparse, non-negative (ReLU) |
| Attention | Q, K, V projections | Q=K constraint |
| Gating | Additive (x + FFN(x)) | Multiplicative (x * y) |
| Interpretability | Dense, polysemantic | Sparse, monosemantic |
| LayerNorm | With affine transform | Without affine transform |
| Vocabulary | Subword tokens | Byte-level (256) |
MLX is Apple's machine learning framework designed specifically for Apple Silicon. This implementation leverages:
- Unified Memory Architecture: No explicit CPU↔GPU transfers
- Metal GPU Acceleration: Native hardware optimization
- Lazy Evaluation: Efficient computation graphs
- NumPy-like API: Familiar and intuitive
- Low Memory Overhead: Train larger models on Mac hardware
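The lazy-evaluation point deserves a concrete example; the snippet below is a generic illustration (arbitrary array sizes), not code from this repo:

```python
import mlx.core as mx

a = mx.random.normal((1024, 1024))  # lives in unified memory; no .to(device) needed
b = (a @ a).sum()                   # builds a lazy computation graph; nothing runs yet
mx.eval(b)                          # forces evaluation, executed on the Metal GPU
print(b)                            # printing a result also triggers evaluation
```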
Training BDH (25M parameters) on M2 Max 64GB:
| Framework | Tokens/sec | Memory Usage | Setup Complexity |
|---|---|---|---|
| PyTorch (MPS) | ~2,500 | 12GB | Medium (device management) |
| MLX | ~5,000 | 8GB | Low (automatic) |
Input (B, T) → Embedding (B, T, D) → [BDH Layers x6] → Output (B, T, vocab_size)
Each BDH Layer:
┌─────────────────────────────────────────────────────────────┐
│ x → encoder → ReLU → Attention(RoPE) → encoder_v → ReLU │
│ ↓ ↓ │
│ x_sparse y_sparse │
│ └──────── × ─────────┘ │
│ ↓ │
│ xy_sparse │
│ ↓ │
│ decoder │
│ ↓ │
│ LayerNorm(x + y) │
└─────────────────────────────────────────────────────────────┘
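Reading the diagram as code, one layer looks roughly like the following sketch. It is a simplified, single-head illustration with hypothetical weight shapes, and it omits RoPE, multi-head splitting, and dropout; the actual forward pass in `bdh_mlx.py` differs in detail:

```python
import mlx.core as mx
import mlx.nn as nn

class BDHLayerSketch(nn.Module):
    """Schematic BDH layer: sparse ReLU activations, multiplicative gating, shared decode."""

    def __init__(self, D: int, N: int):
        super().__init__()
        self.encoder = mx.random.normal((D, N), scale=0.02)    # D -> N (sparse neuron space)
        self.encoder_v = mx.random.normal((D, N), scale=0.02)
        self.decoder = mx.random.normal((N, D), scale=0.02)    # N -> D
        self.ln = nn.LayerNorm(D, affine=False)                # no affine transform

    def __call__(self, x: mx.array) -> mx.array:               # x: (B, T, D)
        T = x.shape[1]
        x_sparse = mx.maximum(x @ self.encoder, 0)             # ReLU -> sparse, non-negative

        # Attention with Q = K (activation overlap) and a strictly causal mask (k=-1), no softmax
        scores = x_sparse @ x_sparse.transpose(0, 2, 1)        # (B, T, T)
        mask = mx.tril(mx.ones((T, T)), k=-1)
        attended = (scores * mask) @ x                         # aggregate earlier tokens only

        y_sparse = mx.maximum(attended @ self.encoder_v, 0)    # second sparse code
        xy_sparse = x_sparse * y_sparse                        # multiplicative gating
        y = xy_sparse @ self.decoder                           # back to model dimension
        return self.ln(x + y)
```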
BDH uses a specialized attention with:
- Rotary Position Embeddings (RoPE): Relative position encoding
- Q=K Constraint: Queries and keys are identical
- Causal Masking with diagonal=-1: Excludes current token (not just future)
- No softmax: Direct attention scores
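To make the diagonal=-1 point concrete, here is what that mask looks like for a four-token sequence (plain MLX, not code from this repo):

```python
import mlx.core as mx

T = 4
mask = mx.tril(mx.ones((T, T)), k=-1)   # strictly lower-triangular
print(mask)
# Row t has ones only at columns < t, so each token attends to earlier
# tokens but never to itself:
# [[0, 0, 0, 0],
#  [1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0]]
```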
Unlike a Transformer, where each of the L layers has its own (W_q, W_k, W_v, W_o, W_ffn1, W_ffn2) parameters, BDH has:
- One set of weights (`encoder`, `decoder`, `encoder_v`) reused in all layers
- This creates recurrent depth similar to Universal Transformers
- Dramatically reduces parameters while maintaining expressiveness
- macOS with Apple Silicon (M1/M2/M3/M4)
- Python 3.9+
- 16GB RAM minimum (64GB recommended for larger models)
pip install mlx numpy datasets huggingface-hub

Or use the provided requirements file:

pip install -r requirements.txt

Verify that MLX can see the GPU:

import mlx.core as mx
print(f"MLX version: {mx.__version__}")
print(f"Metal available: {mx.metal.is_available()}")

To start training:

python train_mlx.py

This will train on the Severian/Internal-Knowledge-Map dataset with default settings optimized for a 64GB Mac.
import mlx.core as mx
from bdh_mlx import BDH, BDHConfig
# Initialize model
config = BDHConfig()
model = BDH(config)
# Byte-level prompt: "The meaning of life"
prompt = "The meaning of life"
prompt_bytes = list(bytearray(prompt, "utf-8"))
idx = mx.array([prompt_bytes])
# Generate
output = model.generate(
    idx,
    max_new_tokens=200,
    temperature=0.8,
    top_k=50
)
# Decode bytes to text
text = bytes(output[0].tolist()).decode("utf-8", errors="backslashreplace")
print(text)This section details the conversion process and explains why the MLX implementation is mathematically equivalent to the original PyTorch version.
| Operation | PyTorch | MLX | Notes |
|---|---|---|---|
| Tensor creation | torch.Tensor |
mx.array |
Same semantics |
| Random normal | torch.randn() |
mx.random.normal() |
MLX requires explicit scale |
| View/Reshape | .view() or .reshape() |
.reshape() |
MLX only has .reshape() |
| Transpose | .transpose(1,2) |
.transpose(0,2,1,3) |
MLX requires full dimension specification |
| Matrix transpose | .mT |
.transpose(0,1,3,2) |
Swap last two dims explicitly |
| ReLU | F.relu() |
mx.maximum(x, 0) |
Identical operation |
| Module method | forward() |
__call__() |
MLX convention |
| Device | .to(device) |
N/A | MLX manages automatically |
| Evaluation | N/A | mx.eval() |
Required for lazy evaluation |
PyTorch (line 142 in bdh.py):
xy_sparse.transpose(1, 2).reshape(B, 1, T, N * nh)
# Shape: (B, nh, T, N) → (B, T, nh, N) → (B, 1, T, nh×N)

MLX (lines 149-152 in bdh_mlx.py):

yMLP = xy_sparse.transpose(0, 2, 1, 3).reshape(B, T, N * nh)
yMLP = mx.expand_dims(yMLP, axis=1)
# Shape: (B, nh, T, N) → (B, T, nh, N) → (B, T, nh×N) → (B, 1, T, nh×N)

Why different?
- PyTorch's `transpose(1, 2)` swaps dimensions 1 and 2
- MLX's `transpose()` requires a full permutation: `(0, 2, 1, 3)` achieves the same swap
- The result is mathematically identical
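If you want to convince yourself, the two permutations can be compared against NumPy as a neutral reference (a quick sketch with arbitrary shapes):

```python
import numpy as np
import mlx.core as mx

x = np.random.randn(2, 4, 8, 16).astype(np.float32)   # (B, nh, T, N)

ref = x.transpose(0, 2, 1, 3)                          # what PyTorch's transpose(1, 2) produces
out = mx.array(x).transpose(0, 2, 1, 3)                # the MLX equivalent

print(np.allclose(ref, np.array(out)))                 # True
```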
PyTorch (line 73):
scores = (QR @ KR.mT).tril(diagonal=-1)

MLX (lines 76-79):
scores = (QR @ KR.transpose(0, 1, 3, 2))
mask = mx.tril(mx.ones((T, T)), k=-1)
scores = scores * mask.reshape(1, 1, T, T)

Why different?
- PyTorch's `.mT` is shorthand for transposing the last two dimensions
- MLX requires the explicit permutation `transpose(0, 1, 3, 2)`
- PyTorch applies the causal mask directly with `.tril(diagonal=-1)` on the score tensor
- MLX builds the mask separately and multiplies (clearer, since the masking step is explicit)
- The result is mathematically identical
PyTorch (lines 85-98):
self.decoder = nn.Parameter(torch.zeros((nh * N, D)).normal_(std=0.02))
self.encoder = nn.Parameter(torch.zeros((nh, D, N)).normal_(std=0.02))

MLX (lines 100-104):
self.decoder = mx.random.normal((nh * N, D), scale=0.02)
self.encoder = mx.random.normal((nh, D, N), scale=0.02)

Why different?
- PyTorch requires an explicit `nn.Parameter()` wrapper
- MLX automatically registers any `mx.array` assigned in `__init__` as a trainable parameter
- Functionally identical: both are optimized during training
PyTorch (line 52):
v_rot = torch.stack((-v[..., 1::2], v[..., ::2]), dim=-1).view(*v.size())

MLX (lines 56-57):
v_rot_parts = mx.stack([-v[..., 1::2], v[..., ::2]], axis=-1)
v_rot = v_rot_parts.reshape(v.shape)

Why different?
- PyTorch's `.view(*v.size())` unpacks the size tuple
- MLX's `.reshape(v.shape)` uses the shape directly
- Mathematically identical rotation operation
The MLX implementation preserves exact mathematical equivalence with PyTorch:
- Same computation graph - all operations identical
- Same parameter shapes - verified via parameter counting
- Same initialization - both use `normal(std=0.02)`
- Same forward pass - tensor shapes match at every layer
- Same loss computation - cross-entropy with same reduction
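One way to reproduce the parameter-count check mentioned above (a sketch that assumes the package layout described in this README and MLX's `mlx.utils.tree_flatten`):

```python
from mlx.utils import tree_flatten

from bdh_mlx import BDH, BDHConfig

model = BDH(BDHConfig())
n_params = sum(v.size for _, v in tree_flatten(model.parameters()))
print(f"{n_params / 1e6:.2f}M parameters")
```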
Key insight: The differences are purely API translations, not architectural changes. The underlying mathematics, tensor operations, and information flow are preserved exactly.
Edit train_mlx.py to customize:
# Model Configuration
BDH_CONFIG = bdh_mlx.BDHConfig(
    n_layer=6,                        # Number of recurrent layers
    n_embd=256,                       # Embedding dimension
    n_head=4,                         # Attention heads
    mlp_internal_dim_multiplier=128,  # Internal dimension multiplier
    vocab_size=256,                   # Byte-level (fixed)
    dropout=0.1,                      # Dropout probability
)
# Training Configuration
BLOCK_SIZE = 8192 # Context window (longer = better quality)
BATCH_SIZE = 1 # Adjust based on available RAM
MAX_ITERS = 5000 # Training steps
LEARNING_RATE = 5e-5 # Learning rate
WEIGHT_DECAY = 0.05 # AdamW weight decay
GRAD_CLIP = 1.0          # Gradient clipping

| Mac RAM | Recommended Settings |
|---|---|
| 8-16GB | BATCH_SIZE=1, BLOCK_SIZE=512, n_embd=128 |
| 32GB | BATCH_SIZE=1, BLOCK_SIZE=2048, n_embd=256 |
| 64GB+ | BATCH_SIZE=1, BLOCK_SIZE=8192, n_embd=256 |
Note: Due to MLX's unified memory, larger BLOCK_SIZE is often better than larger BATCH_SIZE.
To use your own Hugging Face dataset:
# In train_mlx.py, modify:
def load_and_prepare_dataset(
    dataset_name: str = "your-username/your-dataset",
    training_mode: str = "both"   # "system", "instruction", or "both"
):
    # Rest of function remains the same

For local text files:
# Load your text
with open("your_data.txt", "rb") as f:
    data = f.read()

# Convert to MLX array
data_array = mx.array(list(data), dtype=mx.uint8)
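Turning that array into training batches requires an input/target window sampler. A minimal sketch of such a helper (hypothetical; `train_mlx.py`'s own batching may differ):

```python
import numpy as np
import mlx.core as mx

def get_batch(data_array: mx.array, block_size: int, batch_size: int):
    """Sample random windows; y is x shifted by one byte (next-byte prediction)."""
    starts = np.random.randint(0, data_array.size - block_size - 1, size=batch_size)
    x = mx.stack([data_array[int(i) : int(i) + block_size] for i in starts])
    y = mx.stack([data_array[int(i) + 1 : int(i) + 1 + block_size] for i in starts])
    return x.astype(mx.int32), y.astype(mx.int32)
```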
Checkpoints are automatically saved to `checkpoints_mlx/`:

# Format: bdh_mlx_step_{step}.npz
bdh_mlx_step_250.npz   # Step 250
bdh_mlx_step_500.npz   # Step 500

Load a checkpoint:
checkpoint = mx.load("checkpoints_mlx/bdh_mlx_step_1000.npz")
model.load_weights(list(checkpoint.items()))
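To write a compatible checkpoint yourself, outside the training script, MLX modules can save directly to `.npz` (a one-line sketch; `train_mlx.py` uses its own saving logic):

```python
model.save_weights("checkpoints_mlx/bdh_mlx_step_1000.npz")
```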
Measured on M4 (64GB RAM):

| Configuration | Tokens/sec | Memory | Time to 1000 steps |
|---|---|---|---|
| Default (256 embd, 8192 ctx) | ~500 | 8GB | ~4.5 hours |
| Small (128 embd, 2048 ctx) | ~2000 | 4GB | ~1 hour |
| Large (512 embd, 8192 ctx) | ~200 | 20GB | ~11 hours |
- ~50-100 tokens/second for typical configurations
- Scales with model size and context length
- Top-k sampling adds ~10% overhead
BDH exhibits Transformer-like scaling laws:
- Loss decreases as ~1/N^α, where N is the parameter count
- Context window scales linearly with memory
- Training time scales as O(T² × D), where T = context length and D = embedding dimension
@dataclasses.dataclass
class BDHConfig:
    n_layer: int = 6                        # Number of BDH layers
    n_embd: int = 256                       # Embedding dimension D
    dropout: float = 0.1                    # Dropout probability
    n_head: int = 4                         # Number of attention heads
    mlp_internal_dim_multiplier: int = 128  # N = mlp_internal_dim_multiplier × D / n_head
    vocab_size: int = 256                   # Always 256 for byte-level

class BDH(nn.Module):
    def __init__(self, config: BDHConfig)

    def __call__(
        self,
        idx: mx.array,                      # Input tokens (B, T)
        targets: Optional[mx.array]         # Target tokens for loss (B, T)
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Returns:
            logits: (B, T, vocab_size) - unnormalized log probabilities
            loss: scalar - cross-entropy loss (if targets provided)
        """

    def generate(
        self,
        idx: mx.array,                      # Prompt tokens (B, T)
        max_new_tokens: int,                # Number of tokens to generate
        temperature: float = 1.0,           # Sampling temperature
        top_k: Optional[int] = None         # Top-k filtering
    ) -> mx.array:
        """
        Autoregressively generate tokens.

        Returns:
            mx.array: Generated tokens (B, T + max_new_tokens)
        """

A minimal training loop:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from bdh_mlx import BDH, BDHConfig

# Initialize
config = BDHConfig()
model = BDH(config)
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.1)

# Loss function
def loss_fn(model, x, y):
    _, loss = model(x, y)
    return loss

# Gradient computation
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Training step
for step in range(max_iters):
    # Get batch (x, y are mx.arrays)
    x, y = get_batch()

    # Forward + backward
    loss, grads = loss_and_grad_fn(model, x, y)

    # Update
    optimizer.update(model, grads)

    # Evaluate (required for lazy evaluation)
    mx.eval(model.parameters())

Traditional Transformers use different parameters at each layer:
Layer 1: W₁_q, W₁_k, W₁_v, W₁_o, W₁_ffn1, W₁_ffn2
Layer 2: W₂_q, W₂_k, W₂_v, W₂_o, W₂_ffn1, W₂_ffn2
...
BDH uses the same parameters in all layers:
All Layers: encoder, decoder, encoder_v (shared)
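Schematically, recurrent depth just means applying the same layer function repeatedly (a sketch, not the repo's exact forward code):

```python
def forward_shared(x, layer, n_layer=6):
    # The same weights (encoder, decoder, encoder_v) are used at every iteration.
    for _ in range(n_layer):
        x = layer(x)
    return x
```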
Benefits:
- Recurrent processing: Information iteratively refined
- Parameter efficiency: Fewer parameters for same depth
- Biological plausibility: Brain neurons don't have "layer-specific" synapses
In standard attention:
Q = x @ W_q
K = x @ W_k
scores = Q @ K.T

In BDH:
Q = encoder(x) # Sparse via ReLU
K = Q # Q and K are identical
scores = Q @ K.T

Benefits:
- Simpler dynamics: Attention is based on activation overlap
- Hebbian-like: Neurons that fire together wire together
- Monosemanticity: Easier to interpret what each neuron represents
BDH processes raw bytes (0-255) instead of subword tokens:
Advantages:
- No tokenizer: No vocabulary bias or tokenization artifacts
- Universal: Works for any language, code, binary data
- Interpretable: One byte = one step
- Efficient: 256 vocab vs 50k+ for typical tokenizers
Trade-off:
- Longer sequences (one position per byte, versus several characters per subword token)
- But BDH's efficient attention handles this well
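A quick illustration of the byte-level interface, using only plain Python:

```python
text = "naïve café 🐉"                        # any language or symbol works
ids = list(text.encode("utf-8"))             # every id is in [0, 255]
print(len(text), len(ids))                   # non-ASCII characters take several bytes
print(bytes(ids).decode("utf-8") == text)    # True: exact round-trip, no tokenizer
```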
pip install "mlx>=0.21.0"

Make sure you're on Apple Silicon (M1/M2/M3/M4). MLX doesn't support Intel Macs.
# Reduce these in train_mlx.py:
BATCH_SIZE = 1 # Already minimum
BLOCK_SIZE = 2048 # Reduce from 8192
n_embd = 128 # Reduce from 256
n_layer = 4        # Reduce from 6

- Check Activity Monitor: Ensure GPU is being used
- Close other apps: Free up memory
- Disable low-power mode: System Settings → Battery
- Cool your Mac: Thermal throttling reduces performance
Usually indicates:
- Learning rate too high → try `LEARNING_RATE = 1e-4`
- Gradient explosion → check that `GRAD_CLIP = 1.0` is enabled
- Numerical instability → verify `dtype=mx.float32` (not float16)
For Hugging Face datasets requiring authentication:
from huggingface_hub import login
login()  # Follow prompts to enter token

| Aspect | PyTorch Version | MLX Version |
|---|---|---|
| API Style | `.forward()`, `.to(device)` | `__call__()`, automatic |
| Memory Management | Manual device transfers | Unified memory (automatic) |
| Performance (Mac) | MPS backend (slower) | Metal native (faster) |
| Code Complexity | Higher (device handling) | Lower (cleaner) |
| Multi-GPU | Supported | Not yet (single device) |
| Ecosystem | Mature (CUDA, etc.) | Growing (Mac-only) |
| Mathematical Result | ✓ | ✓ (Identical) |
When to use MLX: Training on Mac, especially M-series with 64GB+ RAM
When to use PyTorch: Multi-GPU clusters, CUDA, broader hardware support
Potential enhancements:
- Model parallelism for larger models
- Quantization (4-bit, 8-bit) for inference
- KV-cache for faster generation
- Fine-tuning utilities
- Checkpointing/resuming improvements
- Multi-node distributed training
If you use BDH-MLX in your research, please cite both the original paper and this implementation:
@article{kosowski2025dragon,
title={The Dragon Hatchling: The Missing Link between the Transformer and Models of the Brain},
author={Kosowski, Adrian and Uzna{\'n}ski, Przemys{\l}aw and Chorowski, Jan and Stamirowska, Zuzanna and Bartoszkiewicz, Micha{\l}},
journal={arXiv preprint arXiv:2509.26507},
year={2025}
}

Copyright 2025 Pathway Technology, Inc.
See LICENSE.md for details.
- Original BDH Authors: For the groundbreaking architecture
- Apple MLX Team: For the excellent framework
- Andrej Karpathy: For nanoGPT inspiration
- Original Paper: https://doi.org/10.48550/arXiv.2509.26507
- MLX Documentation: https://ml-explore.github.io/mlx/
- MLX Examples: https://github.com/ml-explore/mlx-examples
- Original PyTorch Implementation: https://github.com/pathwaycom/bdh
Questions? Open an issue or discussion on GitHub!