Running Qwen3.5-35B-A3B on a 16GB M2 MacBook Air by offloading MoE expert weights to an external SSD.
Result: ~0.5-2.0 tok/s generation with Q4 experts on SSD + Q8 non-expert weights in RAM. Coherent output with minimal quality loss. Not practical for daily use, but proves the concept works.
This is a personal experiment, not meant for production. Use at your own risk.
This project was built collaboratively with Claude. Architecture design, debugging, implementation, and this README were developed through iterative conversation — including the hunt for the RMSNorm weight offset bug.
Most quantized LLMs that fit in 16GB top out around 7-14B parameters. MoE models like Qwen3.5-35B-A3B are interesting because they have 35B total parameters but only activate ~3B per token — meaning 90%+ of the weights are idle at any given moment. The question was: can you keep the idle weights on cheap external storage and still get usable inference?
This isn't a new idea. Expert offloading has been explored in research (e.g., flash-moe) and production systems. This project is a from-scratch implementation to understand the tradeoffs firsthand — where the bottlenecks actually are, what cache sizes matter, and whether Q4 experts degrade quality in practice.
Qwen3.5-35B-A3B is a 35B parameter Mixture-of-Experts model that only activates ~3B parameters per token. Most of those 35B parameters are expert FFN weights sitting idle at any given time. This project exploits that sparsity: keep the always-needed weights (attention, norms, embeddings, shared expert, router) in RAM at Q8, and load only the 8 active experts per layer from SSD at Q4 as needed.
Memory budget:
- Non-expert weights (Q8): ~4GB in unified memory
- Expert LRU cache: ~1.7-3.4GB in RAM (configurable)
- OS + framework overhead: ~3GB
- Total: fits in 16GB
Storage: ~17GB of Q4 expert weights on an external SSD, organized as one binary file per layer with direct pread access.
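Those figures can be sanity-checked from the expert dimensions described below (hidden=2048, intermediate=512, three weight matrices per gated FFN). The ~0.5 bits/weight overhead for float16 scales and biases at group size 64 is an assumption about the Q4 format:

```python
# Rough sanity check of the expert storage figures (a sketch; the per-weight
# overhead for fp16 scales/biases at group size 64 is an assumption).
hidden, intermediate = 2048, 512           # expert FFN dims (gate, up, down)
params_per_expert = 3 * hidden * intermediate
bits_per_weight = 4 + (16 + 16) / 64       # Q4 payload + fp16 scale/bias per 64-group
expert_mb = params_per_expert * bits_per_weight / 8 / 1e6
total_gb = expert_mb * 256 * 40 / 1e3      # 256 experts x 40 layers
print(f"{expert_mb:.2f} MB per expert, {total_gb:.1f} GB total")
```

This lands at roughly 1.77MB per expert and ~18GB total, consistent with the ~1.7MB and ~17GB measured on disk.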
This model has some interesting architectural choices:
- 40 decoder layers in a 3:1 repeating pattern: 30 GatedDeltaNet (linear attention) + 10 full softmax attention
- GatedDeltaNet linear attention: recurrent state of shape `(B, 32, 128, 128)` with gated delta rule updates, causal depthwise conv1d, and L2-normalized Q/K. No KV cache needed; the state is fixed size regardless of sequence length
- Full attention with output gate: Q projection outputs 2x width (Q + gate interleaved per-head), partial RoPE on 64 of 256 head dims, GQA with 16 query / 2 KV heads
- MoE: 256 total experts, top-8 routed + 1 shared expert always active per token. Each expert is a small gated FFN (hidden=2048, intermediate=512)
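The routing step can be sketched in plain Python. Renormalizing the softmax weights over the selected top-8 is an assumption about this router, and `route_top8` is a hypothetical name, not the project's API:

```python
import math, random

def route_top8(logits, k=8):
    """Pick the top-k experts and renormalize their softmax weights.
    (Renormalizing over the top-k is an assumption about the router.)"""
    m = max(logits)
    probs = [math.exp(x - m) for x in logits]
    total = sum(probs)
    probs = [p / total for p in probs]
    top = sorted(range(len(logits)), key=lambda i: -probs[i])[:k]
    z = sum(probs[i] for i in top)
    return [(i, probs[i] / z) for i in top]  # (expert_id, weight), descending

random.seed(0)
logits = [random.gauss(0, 1) for _ in range(256)]  # one token's router logits
picks = route_top8(logits)
print(picks[:3])
```

The shared expert runs unconditionally and its output is added to the weighted sum over these 8.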
```
Token → Attention → Router → [top-8 expert IDs] → LRU cache check
                                                      ↓ miss
                                                pread from SSD
                                                      ↓
                                               Parse Q4 weights
                                                      ↓
                                          mx.quantized_matmul × 2
                                                      ↓
                                        Weighted sum + shared expert
```
- pread, not mmap: Direct file reads avoid per-page fault overhead
- Thread pool: 8 threads for parallel expert loading (one per active expert)
- LRU cache: Hot experts stay in RAM. At 2048 slots (~3.4GB), hit rates reach ~70%+
- Per-layer binary files: `layer_00.bin` through `layer_39.bin`, each containing all 256 experts contiguously. 64-byte header + expert data at predictable offsets
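A minimal sketch of this loading path, using a toy record size instead of the real ~1.7MB Q4 payload. The 64-byte header and fixed-size records follow the layout described above; `read_expert` is a hypothetical helper, not the project's API:

```python
import os, struct, tempfile
from concurrent.futures import ThreadPoolExecutor

HEADER = 64          # per-file header size (from the layout above)
EXPERT_BYTES = 1024  # stand-in record size; real Q4 experts are ~1.7MB

def read_expert(fd, expert_id):
    # pread is positionless: safe to call from many threads on one fd
    return os.pread(fd, EXPERT_BYTES, HEADER + expert_id * EXPERT_BYTES)

# Build a toy layer file: header + 256 fixed-size expert records
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"\0" * HEADER)
    for i in range(256):
        f.write(struct.pack("<I", i) * (EXPERT_BYTES // 4))
    path = f.name

fd = os.open(path, os.O_RDONLY)
with ThreadPoolExecutor(max_workers=8) as pool:  # one thread per active expert
    blobs = list(pool.map(lambda e: read_expert(fd, e),
                          [3, 17, 42, 99, 120, 200, 7, 255]))
os.close(fd)
os.unlink(path)
print(blobs[2][:4])  # first bytes of expert 42's record
```

Because `pread` takes an explicit offset, the 8 reads share one file descriptor with no seeking or locking.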
| Component | Bits | Group Size | Format |
|---|---|---|---|
| Expert FFN weights | Q4 | 64 | mx.quantized_matmul (uint32 packed + float16 scales/biases) |
| Attention, norms, embeddings, etc. | Q8 | 64 | nn.QuantizedLinear |
| Embedding table | Q8 | 64 | nn.QuantizedEmbedding |
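The group quantization scheme can be illustrated without MLX: per group of 64 weights, store 4-bit codes plus a scale and bias. This unpacked pure-Python version is a sketch; MLX additionally packs the codes into uint32 words and stores scales/biases as float16:

```python
import random

GROUP = 64

def quantize_group(vals):
    """Affine 4-bit quantization of one group: q = round((v - bias) / scale)."""
    lo, hi = min(vals), max(vals)
    scale = (hi - lo) / 15 or 1.0      # 16 levels for 4-bit codes
    q = [round((v - lo) / scale) for v in vals]
    return q, scale, lo                # codes + per-group scale and bias

def dequantize_group(q, scale, bias):
    return [c * scale + bias for c in q]

random.seed(1)
weights = [random.gauss(0, 0.02) for _ in range(GROUP)]
q, scale, bias = quantize_group(weights)
restored = dequantize_group(q, scale, bias)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(f"max |error| = {max_err:.5f}, half-step bound = {scale / 2:.5f}")
```

The reconstruction error is bounded by half a quantization step per weight, which is why small per-group ranges (group size 64 rather than per-tensor) keep Q4 usable.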
Measured on M2 MacBook Air (16GB, 8-core GPU), external USB-C SSD (~500MB/s):
| Metric | Value |
|---|---|
| Prompt processing | ~2 tok/s (single-token recurrent mode) |
| Generation | ~0.5-2.0 tok/s (varies with cache hit rate) |
| Expert cache hit rate | ~70% at 2048 slots, ~89% at 4096 slots |
| Expert size (Q4) | ~1.7MB each |
| Non-expert weights (Q8) | ~4GB total |
| Total SSD storage | ~17GB |
The bottleneck at this point is split between SSD I/O for cache misses and sequential compute through 40 layers (320 expert loads per token).
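A back-of-envelope on the I/O side makes these numbers plausible: at a 0% hit rate, 320 misses × ~1.7MB ≈ 544MB per token, or about 1.1s on a ~500MB/s SSD, which matches the low end of the measured range:

```python
# Back-of-envelope for the SSD-bound cost per token at various hit rates.
expert_mb = 1.7           # Q4 expert size from the table above
loads_per_token = 8 * 40  # 8 active experts x 40 layers = 320
ssd_mb_s = 500            # measured USB-C SSD throughput

for miss in [1.0, 0.3, 0.1]:  # 0%, 70%, 90% cache hit rate
    seconds = loads_per_token * miss * expert_mb / ssd_mb_s
    print(f"hit rate {1 - miss:.0%}: ~{seconds:.2f}s of SSD I/O per token")
```

At ~70% hits the I/O cost drops to ~0.3s/token, at which point the sequential compute through 40 layers starts to matter as much as the disk.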
- MoE sparsity is real: Only 8 of 256 experts activate per token. The working set is small enough to cache effectively.
- Q4 experts with Q8 non-experts: Quality is surprisingly good. The model produces coherent multi-step reasoning with `<think>` blocks. Verified against a PyTorch reference at cos_sim > 0.999 per layer.
- pread with thread pool: Simple and effective. The OS page cache provides a second layer of caching for free.
- LRU cache matters a lot: 256 slots (smaller than one token's 320 expert loads) = 0% hit rate (total thrash). 2048 slots = 71% hit rate. The cache size must exceed the per-token expert load count.
- RMSNorm weight offset: Qwen3.5 uses `(1 + weight) * norm(x)`, not `weight * norm(x)`. The stored weights are small offsets (~0.028) from zero. Using them directly caused ~35x attenuation at every norm layer, producing complete gibberish. RMSNormGated (used in GatedDeltaNet) uses the standard convention. This was the root cause of a long debugging hunt.
- Cache thrashing: The naive cache size of 256 slots was smaller than the 320 expert loads per token (8 experts × 40 layers), causing a 0% hit rate. Every token evicted the entire cache.
- Blocking the event loop: Running synchronous MLX inference inside async FastAPI handlers deadlocked the server. Fixed by running generation in a thread pool via `run_in_executor`.
- Memory pressure: A 4096-slot cache (~6.8GB) pushed the system into 13.6GB of swap, making everything slower than a smaller cache. The sweet spot is constrained by actual free RAM, not theoretical capacity.
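The RMSNorm bug is easy to reproduce in a few lines: with stored weights around 0.028, applying them without the +1 offset attenuates activations by roughly 1.028/0.028 ≈ 37x (the exact values here are illustrative, not taken from the checkpoint):

```python
import math

def rmsnorm(x, weight, offset):
    """RMSNorm with an optional +1 weight offset (Qwen3.5's convention)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + 1e-6)
    return [(offset + w) * v / rms for w, v in zip(weight, x)]

x = [0.5, -1.2, 0.8, 2.0]
weight = [0.028] * len(x)  # stored weights: small offsets around zero

correct = rmsnorm(x, weight, offset=1.0)  # (1 + weight) * norm(x)
buggy = rmsnorm(x, weight, offset=0.0)    # weight * norm(x): the initial bug
print(f"attenuation: ~{correct[0] / buggy[0]:.0f}x")
```

Compounded over 40 layers' worth of norms, that attenuation is more than enough to turn output into gibberish.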
- Chunked prefill: Processing the prompt one token at a time is painfully slow. A chunked algorithm for GatedDeltaNet could process 64-256 tokens per step.
- Expert prefetching: Run the router for the next MoE layer while the current one computes, start SSD reads before they're needed.
- Speculative decoding: Use a small draft model in RAM to predict multiple tokens, verify with the full model.
- Custom Metal kernels: Fused operations for the many small matmuls could reduce GPU dispatch overhead.
- macOS with Apple Silicon (M1/M2/M3/M4+, 16GB+)
- External SSD with ~20GB free space
- Python 3.11+
- Access to `Qwen/Qwen3.5-35B-A3B` on HuggingFace
```bash
git clone <repo>
cd llm-16gb
pip install -e .
```

Download the model, quantize experts to Q4 and non-expert weights to Q8, and organize everything on the SSD:
```bash
llm-16gb-prepare \
  --model-id Qwen/Qwen3.5-35B-A3B \
  --output-dir /Volumes/YourSSD/ai-models/qwen35-a3b
```

This creates:
```
/Volumes/YourSSD/ai-models/qwen35-a3b/
├── config.json                    # Model config (copied from HF)
├── tokenizer.json                 # Tokenizer (copied from HF)
├── tokenizer_config.json          # Tokenizer config (copied from HF)
├── non_expert_weights.safetensors # Q8 attention/norms/embeddings (~4GB)
└── experts/                       # Q4 expert weights (~17GB)
    ├── layer_00.bin
    ├── layer_01.bin
    ├── ...
    └── layer_39.bin
```
```bash
llm-16gb --model-dir /Volumes/YourSSD/ai-models/qwen35-a3b
```

Options: `--temperature 0.7`, `--top-p 0.9`, `--top-k 40`, `--max-tokens 512`, `--cache-size 2048`
```bash
llm-16gb-server \
  --model-dir /Volumes/YourSSD/ai-models/qwen35-a3b \
  --port 8000 \
  --cache-size 2048
```

OpenAI-compatible API at `http://127.0.0.1:8000/v1`:
```bash
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model":"qwen3.5-35b-a3b","messages":[{"role":"user","content":"Hello!"}]}'
```

Supports:
- `/v1/chat/completions` - Chat completions (streaming + non-streaming)
- `/v1/models` - Model list
- `/v1/cache/stats` - Expert cache hit rate
- Streaming with `reasoning_content` for thinking preview
- Tool calling (Qwen3.5 native format, mapped to OpenAI format)
```
llm_16gb/
├── cli.py                  # Interactive CLI entry point
├── server.py               # OpenAI-compatible API server
├── config.py               # Model configuration
├── cache.py                # Hybrid cache (recurrent + KV)
├── expert_store.py         # SSD I/O with LRU cache
├── generate.py             # Token generation + sampling
├── tokenizer.py            # Chat template + tool formatting
├── prepare_weights.py      # HF → SSD weight conversion
├── safetensors_reader.py   # BF16 safetensors reader
└── model/
    ├── qwen35.py           # Top-level model
    ├── decoder.py          # Decoder layer
    ├── attention.py        # Full softmax attention
    ├── gated_deltanet.py   # GatedDeltaNet linear attention
    ├── moe.py              # MoE with SSD expert loading
    └── layers.py           # RMSNorm, RoPE, etc.
```
- Qwen3.5 Technical Report
- GatedDeltaNet: Gated Delta Networks
- MLX Framework
- flash-moe - Inspiration for pread-based expert loading
- HF Transformers `Qwen3_5Moe` implementation (transformers==5.4.0)