diff --git a/.gitignore b/.gitignore
index 40e11448..3134fc4f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,7 @@ texts/
finetune/outputs/
finetune/data/train/
.claude/
+
+# finetune-mlx artifacts (local runs)
+finetune-mlx/*.nohup.log
+finetune-mlx/eval_results.json
diff --git a/finetune-mlx/.gitignore b/finetune-mlx/.gitignore
new file mode 100644
index 00000000..fdf965bd
--- /dev/null
+++ b/finetune-mlx/.gitignore
@@ -0,0 +1,21 @@
+# Training artifacts (large files)
+adapters/
+models/
+exports/merged/
+exports/*.gguf
+
+# Keep Modelfiles (small, useful)
+!exports/*.Modelfile
+
+# Python
+.venv/
+__pycache__/
+*.pyc
+
+# Logs
+*.log
+*.nohup.log
+
+# Data (downloaded separately)
+data/
+eval_results.json
diff --git a/finetune-mlx/README.md b/finetune-mlx/README.md
new file mode 100644
index 00000000..e10a56b2
--- /dev/null
+++ b/finetune-mlx/README.md
@@ -0,0 +1,115 @@
+# QMD Query Expansion - Apple Silicon (MLX)
+
+Apple Silicon alternative to the CUDA-based [`finetune/`](../finetune/) directory.
+
+Port of QMD's query expansion fine-tuning to Apple Silicon using [MLX](https://github.com/ml-explore/mlx).
+
+Train small language models locally on M1/M2/M3/M4 Macs to expand search queries for hybrid retrieval.
+
+## Features
+
+- **SFT Training**: Supervised fine-tuning with LoRA
+- **GRPO Training**: Group Relative Policy Optimization (reinforcement learning)
+- **100% Local**: No cloud GPU needed, runs on Apple Silicon
+- **MLX Optimized**: Native Metal acceleration
+
+## Results
+
+Comparison with original NVIDIA A10G implementation:
+
+| Metric | NVIDIA (SFT+GRPO) | Apple Silicon (SFT) | Apple Silicon (GRPO) |
+|--------|-------------------|---------------------|----------------------|
+| Avg Score | 92% | 99.6% | 100.4% |
+| Perfect Queries | 30/30 | 28/30 | 28/30 |
+| Hardware | A10G 24GB | Mac Mini M4 | Mac Mini M4 |
+| Cost | ~$2/run | $0 | $0 |
+
+## Quick Start
+
+```bash
+# Setup
+python -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt
+
+# Download and convert base model
+python -c "from mlx_lm import load; load('Qwen/Qwen3-1.7B')"
+
+# Train SFT (supervised fine-tuning)
+python train.py sft --iters 3500
+
+# Train GRPO (reinforcement learning refinement)
+python grpo.py --steps 200
+
+# Evaluate
+python grpo.py --eval-only --adapter adapters/qwen3-grpo
+```
+
+## What It Does
+
+Given a query like `"auth config"`, the model produces structured expansions:
+
+```
+lex: authentication configuration
+lex: auth settings setup
+vec: how to configure authentication settings
+hyde: Authentication can be configured by setting AUTH_SECRET...
+```
+
+These feed into QMD's hybrid retrieval:
+- `lex:` → BM25 full-text search
+- `vec:` → Vector similarity search
+- `hyde:` → Hypothetical document embedding
+
+## File Structure
+
+```
+├── train.py # SFT training script
+├── grpo.py # GRPO (RL) training script
+├── eval.py # Evaluation utilities
+├── reward.py # Scoring/reward function
+├── convert.py # GGUF conversion for Ollama
+├── configs/
+│ └── sft.yaml # SFT hyperparameters
+├── evals/
+│   └── queries.txt      # Test queries (30 total)
+└── tests/ # Unit tests
+```
+
+## Requirements
+
+- macOS with Apple Silicon (M1/M2/M3/M4)
+- Python 3.10+
+- ~8GB RAM for training
+- ~4GB disk for models
+
+## Training Details
+
+### SFT (Supervised Fine-Tuning)
+- Base model: Qwen3-1.7B
+- LoRA rank: 8, layers: 8
+- Learning rate: 1e-4
+- Steps: 3500
+- Time: ~60 min on M4
+
+### GRPO (Group Relative Policy Optimization)
+- Starts from SFT checkpoint
+- 4 completions per query
+- KL regularization (β=0.04)
+- Steps: 200
+- Time: ~30 min on M4
+
+## Credits
+
+- Original QMD: [tobi/qmd](https://github.com/tobi/qmd)
+- MLX framework: [ml-explore/mlx](https://github.com/ml-explore/mlx)
+- Base model: [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B)
+
+## Contributors
+
+- [@sujito00](https://github.com/sujito00)
+- [@dgilperez](https://github.com/dgilperez)
+
+## License
+
+MIT
diff --git a/finetune-mlx/configs/sft.yaml b/finetune-mlx/configs/sft.yaml
new file mode 100644
index 00000000..2fe0b04f
--- /dev/null
+++ b/finetune-mlx/configs/sft.yaml
@@ -0,0 +1,21 @@
+# SFT Training Config for QMD Query Expansion (Apple Silicon)
+
+model:
+ base: "Qwen/Qwen3-1.7B"
+ output: "qmd-query-expansion-1.7B-sft"
+
+dataset:
+ name: "tobil/qmd-query-expansion-train-v2"
+ text_field: "text"
+ eval_split: 0.1
+
+training:
+ batch_size: 4
+ iters: 3000
+ learning_rate: 2e-4
+ max_length: 512
+ grad_accumulation_steps: 4
+
+lora:
+ num_layers: 16
+ rank: 16
diff --git a/finetune-mlx/convert.py b/finetune-mlx/convert.py
new file mode 100644
index 00000000..6aebe7ff
--- /dev/null
+++ b/finetune-mlx/convert.py
@@ -0,0 +1,259 @@
+#!/usr/bin/env python3
+"""
+Convert trained MLX model to GGUF for Ollama/llama.cpp.
+
+Full pipeline:
+1. Merge LoRA adapter into base model
+2. Dequantize if needed (MLX 4-bit → FP16)
+3. Convert to GGUF (requires llama.cpp)
+4. Quantize to Q4_K_M (optional, requires llama-quantize)
+
+Usage:
+ python convert.py # Full pipeline with defaults
+ python convert.py --output qmd-expand # Custom output name
+ python convert.py --quantize q4_k_m # Include quantization step
+ python convert.py --skip-merge # Use already-merged model
+
+Requirements:
+ pip install mlx_lm gguf
+ git clone https://github.com/ggerganov/llama.cpp
+ cd llama.cpp && mkdir build && cd build && cmake .. && make llama-quantize
+"""
+
+import argparse
+import os
+import shutil
+import subprocess
+import sys
+from pathlib import Path
+
+
+def find_llama_cpp():
+ """Find llama.cpp installation."""
+ # Check common locations
+ locations = [
+ Path.home() / "src" / "llama.cpp",
+ Path.home() / "llama.cpp",
+ Path("/usr/local/llama.cpp"),
+ Path("../llama.cpp"),
+ ]
+
+ for loc in locations:
+ convert_script = loc / "convert_hf_to_gguf.py"
+ if convert_script.exists():
+ return loc
+
+ # Check if in PATH
+ if shutil.which("llama-quantize"):
+ return Path(shutil.which("llama-quantize")).parent.parent
+
+ return None
+
+
+def merge_adapter(model_path: Path, adapter_path: Path, output_path: Path):
+ """Merge LoRA adapter into base model using mlx_lm fuse."""
+ print(f"🔄 Merging adapter into base model...")
+
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ cmd = [
+ sys.executable, "-m", "mlx_lm", "fuse",
+ "--model", str(model_path),
+ "--adapter-path", str(adapter_path),
+ "--save-path", str(output_path),
+ ]
+
+ result = subprocess.run(cmd, capture_output=True, text=True)
+
+ if result.returncode != 0:
+ print(f"❌ Merge failed: {result.stderr}")
+ sys.exit(1)
+
+ print(f"✅ Merged model saved to: {output_path}")
+ return output_path
+
+
+def dequantize_model(model_path: Path, output_path: Path):
+ """Dequantize MLX model from 4-bit to FP16."""
+ print(f"🔄 Dequantizing model to FP16...")
+
+ # Check if model is quantized
+ config_path = model_path / "config.json"
+ if config_path.exists():
+ import json
+ with open(config_path) as f:
+ config = json.load(f)
+ if "quantization" not in config and "quantization_config" not in config:
+ print(f"ℹ️ Model is not quantized, skipping dequantization")
+ return model_path
+
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Use mlx_lm convert with dequantize flag
+ cmd = [
+ sys.executable, "-c",
+ f"""
+from mlx_lm import convert
+convert(
+ hf_path='{model_path}',
+ mlx_path='{output_path}',
+ dequantize=True,
+ dtype='float16'
+)
+"""
+ ]
+
+ result = subprocess.run(cmd, capture_output=True, text=True)
+
+ if result.returncode != 0:
+ print(f"❌ Dequantization failed: {result.stderr}")
+ sys.exit(1)
+
+ print(f"✅ Dequantized model saved to: {output_path}")
+ return output_path
+
+
+def convert_to_gguf(model_path: Path, output_path: Path, llama_cpp_path: Path):
+ """Convert HF/MLX model to GGUF format."""
+ print(f"🔄 Converting to GGUF...")
+
+ convert_script = llama_cpp_path / "convert_hf_to_gguf.py"
+
+ cmd = [
+ sys.executable, str(convert_script),
+ str(model_path),
+ "--outfile", str(output_path),
+ "--outtype", "f16"
+ ]
+
+ result = subprocess.run(cmd)
+
+ if result.returncode != 0:
+ print(f"❌ GGUF conversion failed")
+ sys.exit(1)
+
+ print(f"✅ GGUF saved to: {output_path}")
+ return output_path
+
+
+def quantize_gguf(input_path: Path, output_path: Path, quant_type: str, llama_cpp_path: Path):
+ """Quantize GGUF model."""
+ print(f"🔄 Quantizing to {quant_type}...")
+
+ # Find llama-quantize
+ quantize_bin = llama_cpp_path / "build" / "bin" / "llama-quantize"
+ if not quantize_bin.exists():
+ quantize_bin = shutil.which("llama-quantize")
+
+ if not quantize_bin:
+ print(f"⚠️ llama-quantize not found, skipping quantization")
+ print(f" Build it: cd llama.cpp && mkdir build && cd build && cmake .. && make llama-quantize")
+ return input_path
+
+ cmd = [str(quantize_bin), str(input_path), str(output_path), quant_type.upper()]
+
+ result = subprocess.run(cmd)
+
+ if result.returncode != 0:
+ print(f"❌ Quantization failed")
+ return input_path
+
+ print(f"✅ Quantized GGUF saved to: {output_path}")
+ return output_path
+
+
+def create_modelfile(output_name: str, gguf_filename: str):
+ """Create Ollama Modelfile."""
+ modelfile = Path("exports") / f"{output_name}.Modelfile"
+ modelfile.parent.mkdir(parents=True, exist_ok=True)
+
+ content = f'''# Modelfile for {output_name}
+# Usage: ollama create {output_name} -f exports/{output_name}.Modelfile
+
+FROM ./{gguf_filename}
+
+TEMPLATE """<|im_start|>user
+/no_think Expand this search query: {{{{.Prompt}}}}<|im_end|>
+<|im_start|>assistant
+"""
+
+PARAMETER temperature 0.3
+PARAMETER top_p 0.9
+PARAMETER stop "<|im_end|>"
+'''
+
+ with open(modelfile, "w") as f:
+ f.write(content)
+
+ print(f"✅ Modelfile saved to: {modelfile}")
+ return modelfile
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Convert MLX model to GGUF",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ python convert.py
+ python convert.py --quantize q4_k_m
+ python convert.py --model models/Qwen_Qwen3-1.7B/mlx --adapter adapters/qmd_query_expansion_1.7B_sft
+"""
+ )
+ parser.add_argument("--model", default="models/Qwen_Qwen3-1.7B/mlx", help="Base model path")
+ parser.add_argument("--adapter", default="adapters/qmd_query_expansion_1.7B_sft", help="LoRA adapter path")
+ parser.add_argument("--output", "-o", default="qmd-query-expand", help="Output model name")
+ parser.add_argument("--quantize", "-q", help="Quantization type (e.g., q4_k_m, q8_0)")
+ parser.add_argument("--skip-merge", action="store_true", help="Skip adapter merge (use pre-merged model)")
+ parser.add_argument("--llama-cpp", help="Path to llama.cpp directory")
+
+ args = parser.parse_args()
+
+ # Find llama.cpp
+ llama_cpp_path = Path(args.llama_cpp) if args.llama_cpp else find_llama_cpp()
+ if not llama_cpp_path:
+ print("❌ llama.cpp not found. Clone it first:")
+ print(" git clone https://github.com/ggerganov/llama.cpp ~/src/llama.cpp")
+ sys.exit(1)
+
+ print(f"📍 Using llama.cpp at: {llama_cpp_path}")
+
+ model_path = Path(args.model)
+ adapter_path = Path(args.adapter)
+ exports_dir = Path("exports")
+ exports_dir.mkdir(exist_ok=True)
+
+ # Step 1: Merge adapter
+ if not args.skip_merge and adapter_path.exists():
+ merged_path = exports_dir / "merged" / args.output
+ model_path = merge_adapter(model_path, adapter_path, merged_path)
+
+ # Step 2: Dequantize if needed
+ fp16_path = exports_dir / "merged" / f"{args.output}-fp16"
+ model_path = dequantize_model(model_path, fp16_path)
+
+ # Step 3: Convert to GGUF
+ gguf_path = exports_dir / f"{args.output}-f16.gguf"
+ convert_to_gguf(model_path, gguf_path, llama_cpp_path)
+
+ final_gguf = gguf_path
+
+ # Step 4: Quantize if requested
+ if args.quantize:
+ quant_path = exports_dir / f"{args.output}-{args.quantize}.gguf"
+ final_gguf = quantize_gguf(gguf_path, quant_path, args.quantize, llama_cpp_path)
+
+ # Clean up F16 if quantization succeeded
+ if final_gguf != gguf_path and final_gguf.exists():
+ print(f"🧹 Removing intermediate F16 GGUF...")
+ gguf_path.unlink()
+
+ # Step 5: Create Modelfile
+ create_modelfile(args.output, final_gguf.name)
+
+ print(f"\n🎉 Done! Create Ollama model with:")
+ print(f" cd finetune-mlx && ollama create {args.output} -f exports/{args.output}.Modelfile")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/finetune-mlx/demo.py b/finetune-mlx/demo.py
new file mode 100644
index 00000000..2d55a811
--- /dev/null
+++ b/finetune-mlx/demo.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+"""
+Quick demo: Use the trained QMD query expansion model.
+
+Usage:
+ python demo.py "your search query"
+ python demo.py --interactive
+"""
+
+import argparse
+import sys
+
+from mlx_lm import load, generate
+from mlx_lm.sample_utils import make_sampler
+
+PROMPT_TEMPLATE = """<|im_start|>user
+/no_think Expand this search query: {query}<|im_end|>
+<|im_start|>assistant
+"""
+
+
+def expand_query(model, tokenizer, query: str, temp: float = 0.0) -> str:
+ """Generate query expansion."""
+ prompt = PROMPT_TEMPLATE.format(query=query)
+ sampler = make_sampler(temp=temp)
+ response = generate(
+ model, tokenizer,
+ prompt=prompt,
+ max_tokens=200,
+ sampler=sampler,
+ verbose=False
+ )
+ return response.replace('<|im_end|>', '').strip()
+
+
+def main():
+ parser = argparse.ArgumentParser(description="QMD Query Expansion Demo")
+ parser.add_argument("query", nargs="?", help="Query to expand")
+ parser.add_argument("--interactive", "-i", action="store_true", help="Interactive mode")
+ parser.add_argument("--model", default="models/Qwen_Qwen2.5-1.5B/mlx", help="Base model path")
+ parser.add_argument("--adapter", default="adapters/sft", help="LoRA adapter path")
+ parser.add_argument("--temp", type=float, default=0.0, help="Temperature (0=deterministic)")
+ args = parser.parse_args()
+
+ print("Loading model...")
+ try:
+ model, tokenizer = load(args.model, adapter_path=args.adapter)
+ except Exception as e:
+ print(f"Error loading model: {e}")
+ print("\nTrying without adapter (base model)...")
+ model, tokenizer = load(args.model)
+
+ print("Ready!\n")
+
+ if args.interactive:
+ print("Interactive mode. Type 'quit' to exit.\n")
+ while True:
+ try:
+ query = input("Query> ").strip()
+ if query.lower() in ('quit', 'exit', 'q'):
+ break
+ if not query:
+ continue
+ print("\nExpansion:")
+ print(expand_query(model, tokenizer, query, args.temp))
+ print()
+ except (KeyboardInterrupt, EOFError):
+ break
+ elif args.query:
+ print(f"Query: {args.query}\n")
+ print("Expansion:")
+ print(expand_query(model, tokenizer, args.query, args.temp))
+ else:
+ # Demo queries
+ demos = [
+ "auth config",
+ "docker networking",
+ "how to deploy",
+ "kubernetes pod",
+ ]
+ for q in demos:
+ print(f"Query: {q}")
+ print("-" * 40)
+ print(expand_query(model, tokenizer, q, args.temp))
+ print()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/finetune-mlx/eval.py b/finetune-mlx/eval.py
new file mode 100644
index 00000000..dee89ef8
--- /dev/null
+++ b/finetune-mlx/eval.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+"""
+Evaluate QMD Query Expansion model.
+
+Usage:
+ python eval.py # Use default adapter
+ python eval.py --adapter adapters/sft # Specify adapter
+ python eval.py --query "auth config" # Test single query
+"""
+
+import argparse
+import json
+import re
+from pathlib import Path
+
+import mlx.core as mx
+from mlx_lm import load, generate
+
+
+# Test queries for evaluation
+TEST_QUERIES = [
+ "auth config",
+ "how to deploy",
+ "rate limiting",
+ "database connection",
+ "api keys",
+ "error handling",
+ "caching strategy",
+ "user permissions",
+]
+
+PROMPT_TEMPLATE = """<|im_start|>user
+/no_think Expand this search query: {query}<|im_end|>
+<|im_start|>assistant
+"""
+
+
+def score_expansion(text: str) -> dict:
+ """Score an expansion based on format and quality."""
+ scores = {
+ "has_lex": 0,
+ "has_vec": 0,
+ "has_hyde": 0,
+ "lex_count": 0,
+ "vec_count": 0,
+ "format_valid": 0,
+ "total": 0,
+ }
+
+ lines = text.strip().split("\n")
+
+ for line in lines:
+ line = line.strip()
+ if line.startswith("lex:"):
+ scores["has_lex"] = 1
+ scores["lex_count"] += 1
+ elif line.startswith("vec:"):
+ scores["has_vec"] = 1
+ scores["vec_count"] += 1
+ elif line.startswith("hyde:"):
+ scores["has_hyde"] = 1
+
+ # Format is valid if we have at least one of each type
+ if scores["has_lex"] and scores["has_vec"] and scores["has_hyde"]:
+ scores["format_valid"] = 1
+
+    # Total score (0-120): 40 format + up to 30 lex + up to 30 vec + 20 hyde
+ scores["total"] = (
+ scores["format_valid"] * 40 +
+ min(scores["lex_count"], 3) * 10 +
+ min(scores["vec_count"], 3) * 10 +
+ scores["has_hyde"] * 20
+ )
+
+ return scores
+
+
+def expand_query(model, tokenizer, query: str, max_tokens: int = 256) -> str:
+ """Generate expansion for a query."""
+ prompt = PROMPT_TEMPLATE.format(query=query)
+
+ response = generate(
+ model,
+ tokenizer,
+ prompt=prompt,
+ max_tokens=max_tokens,
+ verbose=False,
+ )
+
+ # Extract just the assistant response
+ if "<|im_end|>" in response:
+ response = response.split("<|im_end|>")[0]
+
+ return response.strip()
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Evaluate QMD model")
+ parser.add_argument("--model", default="models/Qwen_Qwen2.5-1.5B/mlx", help="Base model path")
+ parser.add_argument("--adapter", default="adapters/sft", help="LoRA adapter path")
+ parser.add_argument("--query", "-q", help="Single query to test")
+ parser.add_argument("--no-adapter", action="store_true", help="Run without adapter (baseline)")
+
+ args = parser.parse_args()
+
+ # Load model
+ model_path = Path(args.model)
+ adapter_path = Path(args.adapter) if not args.no_adapter else None
+
+ print(f"📦 Loading model: {model_path}")
+ if adapter_path and adapter_path.exists():
+ print(f"🔌 Loading adapter: {adapter_path}")
+ model, tokenizer = load(str(model_path), adapter_path=str(adapter_path))
+ else:
+ if adapter_path:
+ print(f"⚠️ Adapter not found: {adapter_path}, using base model")
+ model, tokenizer = load(str(model_path))
+
+ # Single query mode
+ if args.query:
+ print(f"\n🔍 Query: {args.query}\n")
+ expansion = expand_query(model, tokenizer, args.query)
+ print(expansion)
+ print(f"\n📊 Score: {score_expansion(expansion)}")
+ return
+
+ # Batch evaluation
+ print(f"\n📝 Evaluating {len(TEST_QUERIES)} test queries...\n")
+
+ results = []
+ total_score = 0
+
+ for query in TEST_QUERIES:
+ print(f"🔍 {query}")
+ expansion = expand_query(model, tokenizer, query)
+ scores = score_expansion(expansion)
+ total_score += scores["total"]
+
+ results.append({
+ "query": query,
+ "expansion": expansion,
+ "scores": scores,
+ })
+
+ # Print summary
+ status = "✅" if scores["format_valid"] else "❌"
+        print(f"   {status} Score: {scores['total']}/120")
+ print(f" lex:{scores['lex_count']} vec:{scores['vec_count']} hyde:{scores['has_hyde']}")
+ print()
+
+ # Summary
+ avg_score = total_score / len(TEST_QUERIES)
+ print(f"\n{'='*50}")
+    print(f"📊 Average Score: {avg_score:.1f}/120")
+ print(f"{'='*50}")
+
+ # Save results
+ output_file = Path("eval_results.json")
+ with open(output_file, "w") as f:
+ json.dump(results, f, indent=2)
+ print(f"\n💾 Results saved to: {output_file}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/finetune-mlx/evals/queries.txt b/finetune-mlx/evals/queries.txt
new file mode 100644
index 00000000..16224120
--- /dev/null
+++ b/finetune-mlx/evals/queries.txt
@@ -0,0 +1,48 @@
+# Test queries for QMD query expansion evaluation
+# One query per line, comments start with #
+
+# Technical documentation
+how to configure authentication
+typescript async await
+docker compose networking
+git rebase vs merge
+react useEffect cleanup
+
+# Short/ambiguous queries
+auth
+config
+setup
+api
+
+# Named entities (critical for entity preservation testing)
+who is TDS motorsports
+React hooks tutorial
+Docker container networking
+Kubernetes pod deployment
+AWS Lambda functions
+
+# Personal notes / journals style
+meeting notes project kickoff
+ideas for new feature
+todo list app architecture
+
+# Research / learning
+what is dependency injection
+difference between sql and nosql
+kubernetes vs docker swarm
+
+# Error/debugging
+connection timeout error
+memory leak debugging
+cors error fix
+
+# Temporal / recency queries (should expand with years, "recent", "latest")
+recent news about Shopify
+latest AI developments
+best laptops right now
+what changed in kubernetes latest version
+
+# Complex queries
+how to implement caching with redis in nodejs
+best practices for api rate limiting
+setting up ci cd pipeline with github actions
diff --git a/finetune-mlx/exports/qmd-query-expand.Modelfile b/finetune-mlx/exports/qmd-query-expand.Modelfile
new file mode 100644
index 00000000..929adc0b
--- /dev/null
+++ b/finetune-mlx/exports/qmd-query-expand.Modelfile
@@ -0,0 +1,16 @@
+# Modelfile for qmd-query-expand
+# Usage: ollama create qmd-query-expand -f exports/qmd-query-expand.Modelfile
+#
+# Requires: qmd-query-expand-q4_k_m.gguf in same directory
+# Generate with: python convert.py --quantize q4_k_m
+
+FROM ./qmd-query-expand-q4_k_m.gguf
+
+TEMPLATE """<|im_start|>user
+/no_think Expand this search query: {{.Prompt}}<|im_end|>
+<|im_start|>assistant
+"""
+
+PARAMETER temperature 0.3
+PARAMETER top_p 0.9
+PARAMETER stop "<|im_end|>"
diff --git a/finetune-mlx/grpo.py b/finetune-mlx/grpo.py
new file mode 100644
index 00000000..ade3288e
--- /dev/null
+++ b/finetune-mlx/grpo.py
@@ -0,0 +1,250 @@
+#!/usr/bin/env python3
+"""
+GRPO (Group Relative Policy Optimization) for MLX - Apple Silicon Edition
+
+Usage:
+ python grpo.py # Run GRPO training
+ python grpo.py --steps 100 # Custom steps
+ python grpo.py --eval-only # Just evaluate current model
+"""
+
+import argparse
+import json
+import math
+import random
+import time
+from pathlib import Path
+
+import mlx.core as mx
+import mlx.nn as nn
+import mlx.optimizers as optim
+from mlx_lm import load, generate
+from mlx_lm.sample_utils import make_sampler
+
+from reward import score_expansion_detailed
+
+# =============================================================================
+# Configuration
+# =============================================================================
+
+CONFIG = {
+ "base_model": "models/Qwen3-1.7B-mlx",
+ "sft_adapter": "adapters/qwen3-3500",
+ "output_adapter": "adapters/qwen3-grpo",
+
+ "num_generations": 4,
+ "max_tokens": 200,
+ "beta": 0.04,
+ "learning_rate": 5e-6,
+ "max_steps": 200,
+
+ "log_every": 5,
+ "save_every": 50,
+ "eval_every": 50,
+}
+
+TRAINING_QUERIES = [
+ "auth config", "how to deploy", "rate limiting", "database connection",
+ "api keys", "error handling", "caching strategy", "user permissions",
+ "typescript async await", "docker compose networking", "git rebase vs merge",
+ "react useEffect cleanup", "kubernetes pod deployment", "AWS Lambda functions",
+ "memory leak debugging", "cors error fix", "connection timeout error",
+ "dependency injection", "sql vs nosql", "ci cd pipeline",
+]
+
+PROMPT_TEMPLATE = """<|im_start|>user
+/no_think Expand this search query: {query}<|im_end|>
+<|im_start|>assistant
+"""
+
+
+def generate_completion(model, tokenizer, query, temp=0.7):
+ """Generate a single completion."""
+ prompt = PROMPT_TEMPLATE.format(query=query)
+ sampler = make_sampler(temp=temp, top_p=0.9)
+ response = generate(model, tokenizer, prompt=prompt, max_tokens=200, sampler=sampler, verbose=False)
+ return prompt, response.replace('<|im_end|>', '').strip()
+
+
+def compute_reward(query, completion):
+ """Score completion using reward function."""
+ result = score_expansion_detailed(query, completion)
+ return result['total'] / 140.0 # Normalize to [0, 1]
+
+
+def compute_log_prob(model, tokenizer, prompt, completion):
+ """Compute log probability of completion given prompt."""
+ full_text = prompt + completion
+ tokens = mx.array(tokenizer.encode(full_text))
+ prompt_len = len(tokenizer.encode(prompt))
+
+ # Forward pass
+ logits = model(tokens[None, :-1])[0] # [seq_len, vocab_size]
+
+ # Get completion log probs
+ log_probs = nn.log_softmax(logits[prompt_len-1:], axis=-1)
+ target_tokens = tokens[prompt_len:]
+
+ # Gather
+ token_log_probs = mx.take_along_axis(
+ log_probs[:len(target_tokens)],
+ target_tokens[:, None],
+ axis=-1
+ ).squeeze(-1)
+
+ return mx.sum(token_log_probs)
+
+
+def grpo_step(policy_model, ref_model, tokenizer, query, config, optimizer):
+ """Single GRPO training step."""
+ prompt = PROMPT_TEMPLATE.format(query=query)
+
+ # 1. Generate completions with varying temperatures (1.0 to 2.0 for diversity)
+ completions = []
+ for i in range(config["num_generations"]):
+ temp = 1.0 + 1.0 * i / max(1, config["num_generations"] - 1) # 1.0 to 2.0
+ _, comp = generate_completion(policy_model, tokenizer, query, temp=temp)
+ completions.append(comp)
+
+ # 2. Compute rewards
+ rewards = [compute_reward(query, c) for c in completions]
+ mean_reward = sum(rewards) / len(rewards)
+
+ # 3. Compute advantages (group relative)
+ std_reward = math.sqrt(sum((r - mean_reward)**2 for r in rewards) / len(rewards) + 1e-8)
+ advantages = [(r - mean_reward) / std_reward for r in rewards]
+
+ # 4. Compute reference log probs (frozen)
+ ref_log_probs = []
+ for comp in completions:
+ ref_lp = compute_log_prob(ref_model, tokenizer, prompt, comp)
+ mx.eval(ref_lp)
+ ref_log_probs.append(ref_lp)
+
+ # 5. Define loss function for this batch
+ def loss_fn(model):
+ total_loss = mx.array(0.0)
+ total_kl = mx.array(0.0)
+
+ for comp, adv, ref_lp in zip(completions, advantages, ref_log_probs):
+ # Always compute even with small advantages
+ policy_lp = compute_log_prob(model, tokenizer, prompt, comp)
+ kl = policy_lp - ref_lp
+ pg_loss = -mx.array(adv) * policy_lp
+
+ total_loss = total_loss + pg_loss + config["beta"] * mx.abs(kl)
+ total_kl = total_kl + kl
+
+ n = len(completions)
+ return total_loss / n, total_kl / n
+
+ # 6. Compute gradients using nn.value_and_grad
+ loss_grad_fn = nn.value_and_grad(policy_model, loss_fn)
+ (loss, kl), grads = loss_grad_fn(policy_model)
+
+ # 7. Update
+ optimizer.update(policy_model, grads)
+ mx.eval(policy_model.parameters())
+
+ return float(loss), float(kl), mean_reward
+
+
+def save_lora_weights(model, output_dir):
+ """Save LoRA weights to directory."""
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ model.save_weights(str(output_dir / "adapters.safetensors"))
+
+
+def evaluate(model, tokenizer, queries_file="evals/queries.txt"):
+ """Run evaluation."""
+ queries = []
+ with open(queries_file) as f:
+ for line in f:
+ line = line.strip()
+ if line and not line.startswith('#'):
+ queries.append(line)
+
+ greedy = make_sampler(temp=0.0)
+ scores = []
+ for q in queries:
+ prompt = PROMPT_TEMPLATE.format(query=q)
+ resp = generate(model, tokenizer, prompt=prompt, max_tokens=200, sampler=greedy, verbose=False)
+ resp = resp.replace('<|im_end|>', '').strip()
+ result = score_expansion_detailed(q, resp)
+ scores.append(result['total'])
+
+ return {
+ "avg": sum(scores) / len(scores),
+ "perfect": sum(1 for s in scores if s >= 100),
+ "total": len(scores),
+ "min": min(scores),
+ "max": max(scores),
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--steps", type=int, default=CONFIG["max_steps"])
+ parser.add_argument("--eval-only", action="store_true")
+ parser.add_argument("--adapter", type=str)
+ args = parser.parse_args()
+
+ config = CONFIG.copy()
+ config["max_steps"] = args.steps
+
+ print("=" * 60)
+ print("GRPO Training - QMD Query Expansion (MLX)")
+ print("=" * 60)
+
+ print(f"\n📥 Loading model...")
+ policy_model, tokenizer = load(config["base_model"], adapter_path=args.adapter or config["sft_adapter"])
+
+ if args.eval_only:
+ print("\n📊 Evaluation:")
+ r = evaluate(policy_model, tokenizer)
+ print(f"Avg: {r['avg']:.1f}/120, Perfect: {r['perfect']}/{r['total']}, Range: [{r['min']}, {r['max']}]")
+ return
+
+ print(f"📥 Loading reference model (frozen)...")
+ ref_model, _ = load(config["base_model"], adapter_path=config["sft_adapter"])
+ ref_model.freeze()
+
+ optimizer = optim.Adam(learning_rate=config["learning_rate"])
+
+ output_dir = Path(config["output_adapter"])
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ print(f"\n🚀 Starting GRPO ({config['max_steps']} steps)...\n")
+
+ for step in range(1, config["max_steps"] + 1):
+ query = random.choice(TRAINING_QUERIES)
+
+ t0 = time.time()
+ loss, kl, reward = grpo_step(policy_model, ref_model, tokenizer, query, config, optimizer)
+ dt = time.time() - t0
+
+ if step % config["log_every"] == 0:
+ print(f"Step {step:4d} | Loss: {loss:.4f} | KL: {kl:.4f} | Reward: {reward:.3f} | {dt:.1f}s")
+
+ if step % config["save_every"] == 0:
+ ckpt = output_dir / f"ckpt_{step:04d}"
+ save_lora_weights(policy_model, ckpt)
+ print(f" 💾 Saved: {ckpt}")
+
+ if step % config["eval_every"] == 0:
+ print(f"\n📊 Eval @ step {step}:")
+ r = evaluate(policy_model, tokenizer)
+ print(f" Avg: {r['avg']:.1f}/120, Perfect: {r['perfect']}/{r['total']}\n")
+
+ # Final save
+ save_lora_weights(policy_model, output_dir)
+
+ print("\n📊 Final evaluation:")
+ r = evaluate(policy_model, tokenizer)
+ print(f"Avg: {r['avg']:.1f}/120, Perfect: {r['perfect']}/{r['total']}, Range: [{r['min']}, {r['max']}]")
+ print("\n✅ Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/finetune-mlx/prepare_data.py b/finetune-mlx/prepare_data.py
new file mode 100644
index 00000000..7a58a630
--- /dev/null
+++ b/finetune-mlx/prepare_data.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+"""
+Prepare generated data for MLX training.
+
+Takes the JSONL output from generate_from_notes.py and creates
+train.jsonl and valid.jsonl files in the format expected by mlx-lm.
+
+Usage:
+    python prepare_data.py --input data/custom/expansions.jsonl --output data/custom
+"""
+
+import argparse
+import json
+import random
+from pathlib import Path
+
+
+def load_examples(input_path: Path) -> list[dict]:
+ """Load and validate examples from JSONL."""
+ examples = []
+ invalid = 0
+
+ with open(input_path) as f:
+ for line in f:
+ if not line.strip():
+ continue
+ try:
+ example = json.loads(line)
+ # Validate required fields
+ if "text" in example and len(example["text"]) > 50:
+ examples.append(example)
+ else:
+ invalid += 1
+ except json.JSONDecodeError:
+ invalid += 1
+
+ print(f"Loaded {len(examples)} valid examples ({invalid} invalid)")
+ return examples
+
+
+def split_data(examples: list[dict], eval_ratio: float = 0.1) -> tuple[list, list]:
+ """Split into train and validation sets."""
+ random.shuffle(examples)
+ split_idx = int(len(examples) * (1 - eval_ratio))
+ return examples[:split_idx], examples[split_idx:]
+
+
+def save_for_mlx(examples: list[dict], output_path: Path, name: str):
+ """Save in MLX-lm format (JSONL with 'text' field)."""
+ filepath = output_path / f"{name}.jsonl"
+ with open(filepath, "w") as f:
+ for ex in examples:
+ # MLX expects just {"text": "..."} format
+ f.write(json.dumps({"text": ex["text"]}) + "\n")
+ print(f"Saved {len(examples)} examples to {filepath}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Prepare data for MLX training")
+ parser.add_argument("--input", "-i", type=str, required=True,
+ help="Input JSONL from generate_from_notes.py")
+ parser.add_argument("--output", "-o", type=str, default="data/custom",
+ help="Output directory for train/valid files")
+ parser.add_argument("--eval-ratio", type=float, default=0.1,
+ help="Fraction of data for validation (default: 0.1)")
+ parser.add_argument("--seed", type=int, default=42,
+ help="Random seed for reproducibility")
+ args = parser.parse_args()
+
+ random.seed(args.seed)
+
+ input_path = Path(args.input)
+ output_path = Path(args.output)
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ # Load and split
+ examples = load_examples(input_path)
+ train_examples, valid_examples = split_data(examples, args.eval_ratio)
+
+ print(f"\nSplit: {len(train_examples)} train, {len(valid_examples)} valid")
+
+ # Save for MLX
+ save_for_mlx(train_examples, output_path, "train")
+ save_for_mlx(valid_examples, output_path, "valid")
+
+ # Save summary
+ summary = {
+ "total": len(examples),
+ "train": len(train_examples),
+ "valid": len(valid_examples),
+ "eval_ratio": args.eval_ratio,
+ }
+ with open(output_path / "dataset_info.json", "w") as f:
+ json.dump(summary, f, indent=2)
+
+ print(f"\n✅ Data prepared in {output_path}")
+ print(f" Ready for: python train.py sft --data {output_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/finetune-mlx/requirements.txt b/finetune-mlx/requirements.txt
new file mode 100644
index 00000000..b60334df
--- /dev/null
+++ b/finetune-mlx/requirements.txt
@@ -0,0 +1,14 @@
+# Core MLX dependencies
+mlx-lm>=0.20.0
+mlx>=0.21.0
+
+# Data handling
+huggingface_hub>=0.20.0
+datasets>=2.14.0
+
+# Config and utilities
+pyyaml>=6.0
+tqdm>=4.65.0
+
+# Testing (optional)
+pytest>=7.0.0
diff --git a/finetune-mlx/reward.py b/finetune-mlx/reward.py
new file mode 100644
index 00000000..b29e678b
--- /dev/null
+++ b/finetune-mlx/reward.py
@@ -0,0 +1,610 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = []
+# ///
+"""
+QMD Query Expansion Reward Function
+
+Single source of truth for scoring query expansions. Used by:
+- GRPO training (as the RL reward signal)
+- Evaluation scripts (for scoring model outputs)
+
+Scores expansions on five dimensions:
+ Format (30) - Has lex/vec lines, no invalid lines
+ Diversity (30) - Multiple types, diverse content, no echoes
+ HyDE (20) - Optional bonus for hypothetical document passage
+ Quality (20) - Lex shorter than vec, natural language, key terms
+ Entity (20) - Named entity preservation in lex/vec lines
+
+Returns 0.0-1.0 for RL rewards, or a detailed breakdown dict for evaluation.
+"""
+
+import re
+from collections import Counter
+
+# =============================================================================
+# Constants
+# =============================================================================
+
+# "only:" mode patterns - when query ends with these, expect only that type
+# Format: "query /only:lex" (slash prefix, no space after colon)
+ONLY_MODE_PATTERN = re.compile(r'\s+/only:(lex|vec|hyde)\s*$', re.IGNORECASE)
+
+STOPWORDS = frozenset({
+ 'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
+ 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
+})
+
+KEY_TERM_STOPWORDS = frozenset({
+ 'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
+ 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
+ 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
+})
+
+GENERIC_LEX_PHRASES = frozenset({
+ 'find information about', 'search for', 'look up', 'get information',
+ 'learn about', 'information on', 'details about', 'find out about',
+ 'what is', 'how to', 'guide to', 'help with',
+})
+
+# Chat template tokens that indicate a broken output
+CHAT_TEMPLATE_TOKENS = frozenset({
+ '<|im_start|>', '<|im_end|>', '<|endoftext|>',
+ '\nassistant\n', '\nuser\n',
+})
+
+
+# =============================================================================
+# Parsing
+# =============================================================================
+
def parse_expansion(text: str) -> dict:
    """Split a multi-line expansion into lex/vec/hyde lines plus anything unrecognized."""
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    prefixes = (("lex:", "lex"), ("vec:", "vec"), ("hyde:", "hyde"))
    for raw in text.strip().split("\n"):
        stripped = raw.strip()
        if not stripped:
            continue
        for prefix, key in prefixes:
            if stripped.startswith(prefix):
                # Store the content after the prefix, whitespace-trimmed.
                buckets[key].append(stripped[len(prefix):].strip())
                break
        else:
            buckets["invalid"].append(stripped)
    return buckets
+
+
def detect_only_mode(query: str) -> tuple[str | None, str]:
    """Split off a trailing '/only:lex|vec|hyde' directive from a query.

    Returns (only_type, base_query); only_type is None for normal queries,
    in which case base_query is the original query unchanged.
    """
    hit = ONLY_MODE_PATTERN.search(query)
    if hit is None:
        return None, query
    return hit.group(1).lower(), query[:hit.start()].strip()
+
+
def clean_model_output(text: str) -> tuple[str, bool]:
    """Strip chat template artifacts and <think> blocks from model output.

    Returns (cleaned_text, used_thinking) where used_thinking is True
    if the model emitted <think>...</think> reasoning blocks (Qwen-style
    thinking mode), which are removed before scoring.

    Bug fix: the tag literals had been lost, leaving `'' in text` (always
    True) and the no-op pattern r'.*?' which removed nothing.
    """
    text = text.replace('<|im_end|>', '').strip()

    used_thinking = '<think>' in text and '</think>' in text
    if used_thinking:
        # Drop the entire reasoning block; DOTALL lets '.' span newlines.
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()

    return text, used_thinking
+
+
+# =============================================================================
+# Helpers
+# =============================================================================
+
def extract_named_entities(query: str) -> set:
    """Extract named entities using heuristics.

    Detects: ALL-CAPS acronyms (TDS, API), capitalized proper nouns (React),
    technical terms with special chars (node.js, C++), CamelCase (JavaScript),
    and compound names (TDS motorsports -> both words).
    """
    found = set()
    previous_hit = False

    for idx, raw in enumerate(query.split()):
        token = raw.strip('.,!?:;()[]"\'')
        if not token:
            previous_hit = False
            continue

        lowered = token.lower()
        hit = False

        if len(token) >= 2 and token.isupper():
            # Acronym (e.g. "TDS", "API")
            hit = True
        elif idx > 0 and token[0].isupper() and lowered not in KEY_TERM_STOPWORDS:
            # Mid-sentence capitalized word -> likely proper noun
            hit = True
        elif len(token) >= 2 and any(ch in token for ch in '.+-#@'):
            # Technical term with punctuation (node.js, C++, C#)
            hit = True
        elif token[0].isupper() and len(token) > 1 and any(ch.isupper() for ch in token[1:]):
            # CamelCase (JavaScript) - also catches sentence-initial ones
            hit = True
        elif previous_hit and lowered not in KEY_TERM_STOPWORDS:
            # Continuation of a compound name ("TDS motorsports")
            hit = True

        if hit:
            found.add(lowered)
        previous_hit = hit

    return found
+
+
def get_key_terms(query: str) -> set:
    """Lower-cased query tokens with common stopwords removed."""
    return {word for word in query.lower().split() if word not in KEY_TERM_STOPWORDS}
+
+
def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
    """True when the lex line shares at least one key term with the query.

    Vacuously true when the query has no key terms at all.
    """
    terms = get_key_terms(query)
    return not terms or not terms.isdisjoint(lex_line.lower().split())
+
+
def lex_preserves_entities(line: str, entities: set) -> bool:
    """True when the line mentions any detected entity (vacuously true if none)."""
    if not entities:
        return True
    haystack = line.lower()
    for entity in entities:
        if entity in haystack:
            return True
    return False
+
+
def lex_is_generic(lex_line: str) -> bool:
    """True when the lex line is essentially a useless generic filler phrase.

    A line counts as generic if, after removing a matched filler phrase's
    words, fewer than 3 characters of content remain.
    """
    lower = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        first_word = phrase.split()[0]
        if phrase not in lower and not lower.startswith(first_word):
            continue
        # Strip the phrase's words (first occurrence each) and see what's left.
        leftover = lower
        for word in phrase.split():
            leftover = leftover.replace(word, '', 1).strip()
        if len(leftover) < 3:
            return True
    return False
+
+
def word_set_distance(a: str, b: str) -> int:
    """Count words that appear in exactly one of the two strings."""
    left = set(a.lower().split())
    right = set(b.lower().split())
    return len(left.symmetric_difference(right))


def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
    """True when the strings are neither substrings of each other nor near-identical."""
    first = a.lower().strip()
    second = b.lower().strip()
    if first == second:
        return False
    if first in second or second in first:
        return False
    return word_set_distance(first, second) >= min_distance
+
+
def echoes_query(expansion: str, query: str) -> bool:
    """True when the expansion is the query verbatim or barely extends it (<10 extra chars)."""
    exp = expansion.lower().strip()
    q = query.lower().strip()
    if exp == q:
        return True
    return q in exp and len(exp) < len(q) + 10
+
+
def word_repetition_penalty(text: str) -> int:
    """Penalty for words repeated 3+ times (stopwords and words <=2 chars exempt)."""
    penalty = 0
    for word, count in Counter(re.findall(r'\b\w+\b', text.lower())).items():
        if count >= 3 and len(word) > 2 and word not in STOPWORDS:
            # 2 points per occurrence beyond the second
            penalty += (count - 2) * 2
    return penalty
+
+
+# =============================================================================
+# Scoring
+# =============================================================================
+
def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool, only_type: str) -> dict:
    """Score an 'only:' mode expansion. Expects ONLY the requested type.

    Dimensions: format (0-30), diversity (0-30), type-specific quality (0-20),
    entity preservation (0-20), think bonus (0 or 20). 'hyde' is always 0 here
    because the quality dimension covers hyde content in this mode.
    Returns the same breakdown dict shape as score_expansion_detailed.
    """
    parsed = parse_expansion(text)
    deductions = []

    # Expected type must be present - otherwise the output is a total failure.
    expected_items = parsed.get(only_type, [])
    if not expected_items:
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed",
            "deductions": [f"missing expected {only_type}: output"],
            "parsed": parsed,
            "entities_detected": [],
            "only_mode": only_type,
        }

    # Penalize presence of OTHER types (directive asked for only one).
    other_types = {"lex", "vec", "hyde"} - {only_type}
    unwanted_count = sum(len(parsed.get(t, [])) for t in other_types)
    if unwanted_count > 0:
        deductions.append(f"contains unwanted types (expected only {only_type})")

    # --- Format (0-30): -10 per unwanted line, floor at 0 ---
    format_score = 30 if unwanted_count == 0 else max(0, 30 - unwanted_count * 10)

    # --- Diversity (0-30) ---
    diversity_score = 0
    if len(expected_items) >= 2:
        diversity_score += 15
        # Check for diversity among items (pairwise, -5 per near-duplicate)
        div_score = 15
        for i, a in enumerate(expected_items):
            for b in expected_items[i+1:]:
                if not is_diverse(a, b, 2):
                    div_score -= 5
                    deductions.append(f"{only_type} duplicate: {a[:20]}...")
        diversity_score += max(0, div_score)
    elif len(expected_items) == 1:
        diversity_score = 15  # One item is fine for single-type output

    # Check for echoes of the base query (-5 each), then clamp at 0.
    for exp in expected_items:
        if echoes_query(exp, base_query):
            diversity_score -= 5
            deductions.append(f"echoes query: {exp[:20]}...")
    diversity_score = max(0, diversity_score)

    # --- Type-specific quality (0-20) ---
    quality_score = 10  # base
    entities = extract_named_entities(base_query)

    if only_type == "lex":
        # Lex should be short keyword phrases with key terms
        with_terms = sum(1 for l in expected_items if lex_preserves_key_terms(l, base_query))
        if with_terms == len(expected_items):
            quality_score += 5
        # Check for generic phrases
        generic = sum(1 for l in expected_items if lex_is_generic(l))
        if generic == 0:
            quality_score += 5
        else:
            deductions.append(f"{generic} generic lex phrases")

    elif only_type == "vec":
        # Vec should be natural language sentences (multi-word, >15 chars)
        natural = sum(1 for v in expected_items if " " in v and len(v) > 15)
        if natural == len(expected_items):
            quality_score += 10
        else:
            quality_score += 5
            deductions.append("vec not all natural language")

    elif only_type == "hyde":
        # Hyde should be a document snippet (50-200 chars); only the first
        # hyde line is length-checked.
        hyde_text = expected_items[0]
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            quality_score += 10
        elif 30 <= hyde_len <= 300:
            quality_score += 5
            deductions.append(f"hyde length {hyde_len} (ideal: 50-200)")
        else:
            deductions.append(f"hyde length {hyde_len} out of range")

    # --- Entity preservation (0-20) ---
    entity_score = 10  # base
    if entities:
        with_entities = sum(1 for item in expected_items if lex_preserves_entities(item, entities))
        if with_entities == len(expected_items):
            entity_score += 10
        elif with_entities > 0:
            entity_score += 5
        else:
            # No item mentions any entity: drop the base score entirely.
            entity_score = 0
            deductions.append(f"missing entities: {entities}")

    # --- Think bonus (0-20): reward NOT using thinking mode ---
    think_bonus = 0 if used_thinking else 20

    # --- Total ---
    total = format_score + diversity_score + quality_score + entity_score + think_bonus
    max_possible = 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    if percentage >= 80:
        rating = "Excellent"
    elif percentage >= 60:
        rating = "Good"
    elif percentage >= 40:
        rating = "Acceptable"
    elif percentage >= 20:
        rating = "Poor"
    else:
        rating = "Failed"

    return {
        "format": format_score,
        "diversity": diversity_score,
        "hyde": 0,  # not used in only mode (quality covers it)
        "quality": quality_score,
        "entity": entity_score,
        "think_bonus": think_bonus,
        "total": total,
        "max_possible": max_possible,
        "percentage": round(percentage, 1),
        "rating": rating,
        "deductions": deductions,
        "parsed": parsed,
        "entities_detected": list(entities) if entities else [],
        "only_mode": only_type,
    }
+
+
def score_expansion_detailed(query: str, expansion: str) -> dict:
    """Score an expansion with full breakdown. Returns dict with all dimensions.

    Dimensions: format (0-30), diversity (0-30, echo penalties can pull the
    subtotal down), hyde (0-20 optional bonus that also raises max_possible
    to 140), quality (0-20), entity (clamped to >= 0 in the returned dict),
    think_bonus (0 or 20). 'percentage' is total/max_possible clamped to
    [0, 100], hard-capped at 50% if any lex line merely echoes the query.
    """
    text, used_thinking = clean_model_output(expansion.strip())
    deductions = []

    # Detect "only:" mode (query ending in '/only:lex|vec|hyde')
    only_type, base_query = detect_only_mode(query)

    def _fail(reason):
        # Zeroed-out breakdown used for hard failures.
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed",
            "deductions": [reason],
            "parsed": parse_expansion(expansion),
            "entities_detected": [],
            "only_mode": only_type,
        }

    # Hard fail: remaining chat template tokens
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return _fail("CHAT TEMPLATE LEAKAGE")

    # Hard fail: every non-empty line must have a valid prefix
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(("lex:", "vec:", "hyde:")):
            return _fail(f"INVALID LINE: {line[:50]}")

    # --- Handle "only:" mode separately ---
    if only_type:
        return _score_only_mode(query, base_query, text, used_thinking, only_type)

    parsed = parse_expansion(text)

    # --- Format (0-30) ---
    format_score = 10  # no invalid lines (guaranteed by hard fail)
    if parsed["lex"]:
        format_score += 10
    else:
        deductions.append("missing lex:")
    if parsed["vec"]:
        format_score += 10
    else:
        deductions.append("missing vec:")

    # --- Diversity (0-30) ---
    diversity_score = 0

    # Both lex and vec present
    types_present = sum(1 for t in ("lex", "vec") if parsed[t])
    if types_present >= 2:
        diversity_score += 10
    else:
        deductions.append("only one type")

    # At least two expansion lines overall
    if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
        diversity_score += 5

    # Pairwise near-duplicate check within lex (-2 each)
    lex_div = 5
    for i, a in enumerate(parsed["lex"]):
        for b in parsed["lex"][i+1:]:
            if not is_diverse(a, b, 2):
                lex_div -= 2
                deductions.append(f"lex duplicate: {a[:20]}...")
    diversity_score += max(0, lex_div)

    # Same for vec, with a stricter distance threshold (3)
    vec_div = 5
    for i, a in enumerate(parsed["vec"]):
        for b in parsed["vec"][i+1:]:
            if not is_diverse(a, b, 3):
                vec_div -= 2
                deductions.append(f"vec duplicate: {a[:20]}...")
    diversity_score += max(0, vec_div)

    # Echo penalties: lex echoes are severe, vec echoes mild.
    echo = 5
    lex_echo_count = 0
    for exp in parsed["lex"]:
        if echoes_query(exp, query):
            lex_echo_count += 1
            deductions.append(f"lex echoes query: {exp[:20]}...")
    # Harsh penalty for lex echoes - they're useless
    if lex_echo_count > 0:
        echo -= lex_echo_count * 10  # -10 per echo

    for exp in parsed["vec"]:
        if echoes_query(exp, query):
            echo -= 3  # vec echoes less severe (natural language overlap ok)
            deductions.append(f"vec echoes query: {exp[:20]}...")
    diversity_score += max(-10, echo)  # can go negative

    # --- HyDE (0-20, optional bonus) ---
    # Only the first hyde line is scored.
    hyde_score = 0
    if parsed["hyde"]:
        hyde_text = parsed["hyde"][0]
        hyde_score += 5
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            hyde_score += 5
        elif hyde_len < 50:
            hyde_score += 2
            deductions.append(f"hyde too short ({hyde_len})")
        else:
            deductions.append(f"hyde too long ({hyde_len})")
        # NOTE(review): parse_expansion splits on newlines, so a parsed hyde
        # line can never contain '\n' - this check always awards +5; confirm.
        if "\n" not in hyde_text:
            hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # --- Quality (0-20) ---
    quality_score = 5  # base relevance
    if parsed["lex"] and parsed["vec"]:
        # Lex lines should be shorter (keywords) than vec lines (sentences).
        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
        if avg_lex <= avg_vec:
            quality_score += 5
        else:
            deductions.append("lex longer than vec")
    if parsed["vec"]:
        # Vec lines should read as natural language (multi-word, >15 chars).
        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(parsed["vec"]) else 2
    if parsed["lex"]:
        with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
        if with_terms == len(parsed["lex"]):
            quality_score += 5
        elif with_terms > 0:
            quality_score += 2
        else:
            deductions.append("lex missing key terms")

    # --- Entity Preservation (-45 to +20) ---
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and parsed["lex"]:
        with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
        if with_entities == len(parsed["lex"]):
            entity_score += 15
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score -= 30
            deductions.append(f"lex missing entities: {entities}")

        # Generic filler lex lines cost -15 each.
        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
        if generic_count:
            entity_score -= generic_count * 15
            deductions.append(f"{generic_count} generic lex phrases")

    # NOTE(review): indentation reconstructed - the elif is read as attached
    # to `if parsed["vec"]` (no vec lines AND no entities -> base 10); the
    # alternative inner attachment would make the elif unreachable. Confirm.
    if parsed["vec"]:
        # With no entities detected, lex_preserves_entities is vacuously
        # true, so any vec line yields the +5 here.
        vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
        if vec_with > 0:
            entity_score += 5
    elif not entities:
        entity_score = 10

    # --- Think bonus (0-20): reward NOT using thinking mode ---
    think_bonus = 0 if used_thinking else 20

    # --- Total ---
    total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
    max_possible = 140 if parsed["hyde"] else 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    # Hard cap: lex echoes are unacceptable - cap at 50%
    if lex_echo_count > 0:
        percentage = min(percentage, 50.0)
        deductions.insert(0, f"CAPPED: {lex_echo_count} lex echo(es)")

    if percentage >= 80:
        rating = "Excellent"
    elif percentage >= 60:
        rating = "Good"
    elif percentage >= 40:
        rating = "Acceptable"
    elif percentage >= 20:
        rating = "Poor"
    else:
        rating = "Failed"

    return {
        "format": format_score,
        "diversity": diversity_score,
        "hyde": hyde_score,
        "quality": quality_score,
        "entity": max(0, entity_score),
        "think_bonus": think_bonus,
        "total": max(0, total),
        "max_possible": max_possible,
        "percentage": round(percentage, 1),
        "rating": rating,
        "deductions": deductions,
        "parsed": parsed,
        "entities_detected": list(entities) if entities else [],
        "only_mode": None,
    }
+
+
def score_expansion(query: str, expansion: str) -> float:
    """Collapse the detailed breakdown into a clamped [0.0, 1.0] RL reward."""
    detail = score_expansion_detailed(query, expansion)
    ratio = detail["total"] / detail["max_possible"]
    return min(1.0, max(0.0, ratio))
+
+
def extract_query_from_prompt(prompt: str) -> str:
    """Pull the raw query out of a chat-formatted prompt.

    Falls back to the whole (stripped) prompt when the expansion marker
    is absent.
    """
    marker = "Expand this search query:"
    if marker not in prompt:
        return prompt.strip()
    tail = prompt.split(marker)[-1].strip()
    if "<|im_end|>" in tail:
        tail = tail.split("<|im_end|>")[0].strip()
    return tail
+
+
+# =============================================================================
+# TRL-compatible reward class
+# =============================================================================
+
class QMDRewardFunction:
    """Reward function compatible with TRL's GRPOTrainer.

    Mimics a plain function (TRL reads __name__ for logging/metric names).
    """
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions: list[str], prompts: list[str] | None = None, **kwargs) -> list[float]:
        """Score each completion against the query extracted from its prompt.

        Fix: the `prompts` annotation was `list[str]` with a None default,
        which violates PEP 484 implicit-Optional rules; now `list[str] | None`.
        A missing prompt falls back to an empty query (format-only scoring).
        """
        rewards = []
        for i, completion in enumerate(completions):
            query = ""
            if prompts and i < len(prompts):
                query = extract_query_from_prompt(prompts[i])
            rewards.append(score_expansion(query, completion))
        return rewards
+
+
+# =============================================================================
+# CLI: run standalone to test the reward function
+# =============================================================================
+
+if __name__ == "__main__":
+ print("QMD Reward Function Self-Test")
+ print("=" * 60)
+
+ tests = [
+ ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."),
+ ("auth", "auth is important for security"),
+ ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"),
+ ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"),
+ ("how to use React hooks", "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"),
+ ("auth", "Let me think...\nlex: auth"),
+ ("auth", "lex: auth\nThis is some explanation\nvec: more"),
+ # "/only:" mode tests (slash prefix)
+ ("auth /only:lex", "lex: auth setup\nlex: authentication config\nlex: login credentials"),
+ ("auth /only:lex", "lex: auth setup\nvec: how to configure authentication"), # should fail - has vec
+ ("React hooks /only:vec", "vec: how to use React hooks in functional components\nvec: useState and useEffect patterns in React"),
+ ("PostgreSQL indexing /only:hyde", "hyde: PostgreSQL uses B-tree indexes by default. Create indexes with CREATE INDEX idx_name ON table(column). EXPLAIN ANALYZE shows whether queries use indexes efficiently."),
+ ]
+
+ for query, expansion in tests:
+ score = score_expansion(query, expansion)
+ detail = score_expansion_detailed(query, expansion)
+ only_mode = detail.get("only_mode")
+ mode_str = f" [only:{only_mode}]" if only_mode else ""
+ print(f"\n Query: '{query}'{mode_str}")
+ print(f" Score: {score:.2f} ({detail['rating']})")
+ if detail["deductions"]:
+ print(f" Issues: {', '.join(detail['deductions'][:3])}")
diff --git a/finetune-mlx/tests/test_format.py b/finetune-mlx/tests/test_format.py
new file mode 100644
index 00000000..1fdc3504
--- /dev/null
+++ b/finetune-mlx/tests/test_format.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+"""Tests for QMD output format validation."""
+
+import pytest
+import sys
+from pathlib import Path
+
+# Add parent to path for imports
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from eval import score_expansion
+
+
class TestOutputFormat:
    """Test output format scoring.

    NOTE(review): imports score_expansion from an `eval` module (not in this
    directory's reward.py) whose score_expansion returns a breakdown dict
    (has_lex/has_vec/has_hyde/format_valid/total) - reward.score_expansion
    returns a float. Confirm the intended target module exists.
    """

    def test_perfect_format(self):
        """Test scoring of perfect format output."""
        text = """hyde: The Renaissance was a cultural movement in Europe.
lex: Renaissance period 14th century
lex: Renaissance art culture
vec: what was the Renaissance period
vec: how did the Renaissance transform art"""

        scores = score_expansion(text)
        assert scores["has_lex"] == 1
        assert scores["has_vec"] == 1
        assert scores["has_hyde"] == 1
        assert scores["format_valid"] == 1
        assert scores["total"] >= 100

    def test_missing_hyde(self):
        """Test output missing hyde component."""
        text = """lex: auth config settings
lex: authentication setup
vec: how to configure auth
vec: auth tutorial"""

        scores = score_expansion(text)
        assert scores["has_lex"] == 1
        assert scores["has_vec"] == 1
        assert scores["has_hyde"] == 0
        # format_valid requires all three components
        assert scores["format_valid"] == 0

    def test_missing_lex(self):
        """Test output missing lex component."""
        text = """hyde: To configure authentication, use the config file.
vec: how to set up auth
vec: auth configuration guide"""

        scores = score_expansion(text)
        assert scores["has_lex"] == 0
        assert scores["has_vec"] == 1
        assert scores["has_hyde"] == 1
        assert scores["format_valid"] == 0

    def test_missing_vec(self):
        """Test output missing vec component."""
        text = """hyde: Database connections require proper configuration.
lex: database connection
lex: db config"""

        scores = score_expansion(text)
        assert scores["has_lex"] == 1
        assert scores["has_vec"] == 0
        assert scores["has_hyde"] == 1
        assert scores["format_valid"] == 0

    def test_empty_output(self):
        """Test empty output."""
        scores = score_expansion("")
        assert scores["format_valid"] == 0
        assert scores["total"] == 0

    def test_multiple_lex_entries(self):
        """Test counting multiple lex entries."""
        text = """hyde: Test
lex: term1
lex: term2
lex: term3
lex: term4
vec: query"""

        scores = score_expansion(text)
        assert scores["lex_count"] == 4
        # Max 3 count toward score
        assert scores["total"] == 40 + 30 + 10 + 20  # format + 3*lex + 1*vec + hyde

    def test_multiple_vec_entries(self):
        """Test counting multiple vec entries."""
        text = """hyde: Test
lex: term
vec: query1
vec: query2
vec: query3
vec: query4"""

        scores = score_expansion(text)
        assert scores["vec_count"] == 4
        # Max 3 count toward score
        assert scores["total"] == 40 + 10 + 30 + 20  # format + 1*lex + 3*vec + hyde
+
+
class TestEdgeCases:
    """Test edge cases: surrounding whitespace and prefix casing."""

    def test_whitespace_handling(self):
        """Test handling of extra whitespace."""
        # Leading/trailing spaces around each line should be tolerated.
        text = """ hyde: Some text
 lex: keyword
 vec: natural question """

        scores = score_expansion(text)
        assert scores["has_lex"] == 1
        assert scores["has_vec"] == 1
        assert scores["has_hyde"] == 1

    def test_case_sensitivity(self):
        """Test that prefixes are case sensitive (lowercase required)."""
        text = """HYDE: Some text
LEX: keyword
VEC: question"""

        scores = score_expansion(text)
        # Uppercase should not match
        assert scores["has_lex"] == 0
        assert scores["has_vec"] == 0
        assert scores["has_hyde"] == 0
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/finetune-mlx/tests/test_model.py b/finetune-mlx/tests/test_model.py
new file mode 100644
index 00000000..efb245b8
--- /dev/null
+++ b/finetune-mlx/tests/test_model.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3
+"""Tests for model loading and inference."""
+
+import pytest
+import sys
+from pathlib import Path
+
+# Add parent to path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+
+# These tests require the model to be downloaded
+# Mark as slow/integration tests
+
@pytest.mark.slow
class TestModelLoading:
    """Test model loading functionality.

    These checks only verify on-disk artifacts (converted base model and
    trained LoRA adapter) relative to the repo layout; they require a prior
    training run and are marked slow.
    """

    def test_base_model_exists(self):
        """Check if base model directory exists."""
        model_path = Path(__file__).parent.parent / "models" / "Qwen_Qwen2.5-1.5B" / "mlx"
        assert model_path.exists(), f"Base model not found at {model_path}"

    def test_adapter_exists(self):
        """Check if trained adapter exists."""
        adapter_path = Path(__file__).parent.parent / "adapters" / "sft-full"
        assert adapter_path.exists(), f"Adapter not found at {adapter_path}"

        # Check required files (mlx_lm LoRA output layout)
        assert (adapter_path / "adapters.safetensors").exists()
        assert (adapter_path / "adapter_config.json").exists()

    def test_adapter_config_valid(self):
        """Validate adapter config JSON."""
        import json

        config_path = Path(__file__).parent.parent / "adapters" / "sft-full" / "adapter_config.json"
        with open(config_path) as f:
            config = json.load(f)

        # Check required fields ("rank" vs "r" varies across mlx_lm versions)
        assert "lora_layers" in config
        assert "rank" in config or "r" in config
+
+
@pytest.mark.slow
class TestInference:
    """Test model inference.

    Requires the converted base model and trained adapter on disk plus
    mlx_lm (Apple Silicon); marked slow/integration.
    """

    @pytest.fixture
    def model_and_tokenizer(self):
        """Load model with adapter."""
        from mlx_lm import load

        model_path = Path(__file__).parent.parent / "models" / "Qwen_Qwen2.5-1.5B" / "mlx"
        adapter_path = Path(__file__).parent.parent / "adapters" / "sft-full"

        model, tokenizer = load(str(model_path), adapter_path=str(adapter_path))
        return model, tokenizer

    def test_generate_expansion(self, model_and_tokenizer):
        """Test generating a query expansion."""
        from mlx_lm import generate

        model, tokenizer = model_and_tokenizer

        # ChatML-style prompt matching the training format; /no_think
        # suppresses thinking mode.
        prompt = """<|im_start|>user
/no_think Expand this search query: test query<|im_end|>
<|im_start|>assistant
"""

        response = generate(
            model,
            tokenizer,
            prompt=prompt,
            max_tokens=128,
            verbose=False,
        )

        assert len(response) > 0
        # Should contain at least some expected prefixes
        response_lower = response.lower()
        has_valid_output = (
            "lex:" in response_lower or
            "vec:" in response_lower or
            "hyde:" in response_lower
        )
        assert has_valid_output, f"Response missing expected format: {response}"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/finetune-mlx/train.py b/finetune-mlx/train.py
new file mode 100644
index 00000000..36f900bf
--- /dev/null
+++ b/finetune-mlx/train.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python3
+"""
+QMD Query Expansion Training - Apple Silicon Edition
+
+Uses MLX for efficient fine-tuning on Metal GPUs.
+
+Usage:
+ python train.py sft # Supervised fine-tuning
+ python train.py grpo # RL refinement (after SFT)
+ python train.py sft --config configs/sft.yaml
+"""
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+import yaml
+from datasets import load_dataset
+from huggingface_hub import snapshot_download
+
+
def prepare_dataset(config: dict, stage: str) -> Path:
    """Download and prepare dataset in MLX format (train.jsonl / valid.jsonl).

    Supports both:
    - Local path: config.dataset.path = "data/custom"
    - HuggingFace: config.dataset.name = "tobil/qmd-query-expansion-train-v2"

    Args:
        config: parsed YAML config; only the "dataset" section is read.
        stage: training stage name ("sft"/"grpo"), used for the output dir.

    Returns:
        Directory containing the prepared JSONL files.

    Exits the process when a configured local path is missing.

    Fix: line-count generators previously called open() without closing the
    file handles; counting now uses context managers. The duplicated
    train/valid write loops are factored into a local helper.
    """
    dataset_config = config.get("dataset", {})

    # Check for local path first
    local_path = dataset_config.get("path")
    if local_path:
        local_dir = Path(local_path)
        train_file = local_dir / "train.jsonl"
        valid_file = local_dir / "valid.jsonl"

        if train_file.exists() and valid_file.exists():
            # Count examples (with explicit handle cleanup)
            with open(train_file) as f:
                train_count = sum(1 for _ in f)
            with open(valid_file) as f:
                valid_count = sum(1 for _ in f)
            print(f"📦 Using local dataset: {local_dir}")
            print(f"✅ Dataset ready: {train_count} train, {valid_count} valid")
            return local_dir

        print(f"❌ Local dataset not found at {local_dir}")
        print(f" Expected: {train_file} and {valid_file}")
        print(f" Run: python dataset/generate_from_notes.py && python dataset/prepare_data.py")
        sys.exit(1)

    # Fall back to HuggingFace dataset
    dataset_name = dataset_config.get("name", "tobil/qmd-query-expansion-train-v2")
    output_dir = Path("data") / stage
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"📦 Loading dataset from HuggingFace: {dataset_name}")
    ds = load_dataset(dataset_name, split="train")

    # Split into train/valid (fixed seed for reproducibility)
    eval_ratio = dataset_config.get("eval_split", 0.1)
    split = ds.train_test_split(test_size=eval_ratio, seed=42)

    train_file = output_dir / "train.jsonl"
    valid_file = output_dir / "valid.jsonl"

    text_field = dataset_config.get("text_field", "text")

    def _write_split(rows, path: Path) -> int:
        # Convert to MLX format (JSONL with "text" field), keeping only
        # non-empty string texts longer than 10 chars (filters nulls/noise).
        count = 0
        with open(path, "w") as f:
            for item in rows:
                text = item[text_field]
                if text and isinstance(text, str) and len(text) > 10:
                    f.write(json.dumps({"text": text}) + "\n")
                    count += 1
        return count

    train_count = _write_split(split["train"], train_file)
    valid_count = _write_split(split["test"], valid_file)

    print(f"✅ Dataset prepared: {train_count} train, {valid_count} valid (filtered nulls)")
    return output_dir
+
+
def download_model(model_name: str) -> Path:
    """Download model and convert to MLX format if needed.

    Caches the converted model under models/<name-with-underscores>/mlx and
    returns that directory unchanged on subsequent calls. Conversion is tried
    quantized first, then full precision.

    Raises:
        RuntimeError: if both conversion attempts fail.
    """
    print(f"📥 Downloading model: {model_name}")

    # Check if already converted to MLX
    mlx_model_dir = Path("models") / model_name.replace("/", "_") / "mlx"

    if mlx_model_dir.exists():
        print(f"✅ Model already exists: {mlx_model_dir}")
        return mlx_model_dir

    # Download and convert
    mlx_model_dir.parent.mkdir(parents=True, exist_ok=True)

    # Use mlx_lm.convert to download and convert
    cmd = [
        sys.executable, "-m", "mlx_lm.convert",
        "--hf-path", model_name,
        "--mlx-path", str(mlx_model_dir),
        "-q"  # Quantize for memory efficiency
    ]

    print(f"🔄 Converting to MLX format...")
    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        print(f"❌ Conversion failed: {result.stderr}")
        # Try without quantization (some models fail quantized conversion)
        cmd.remove("-q")
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Model conversion failed: {result.stderr}")

    print(f"✅ Model ready: {mlx_model_dir}")
    return mlx_model_dir
+
+
def train_sft(config: dict):
    """Run supervised fine-tuning with LoRA via `mlx_lm lora`.

    Reads model.base/model.output, training.*, and lora.* from the config,
    prepares the dataset and converted model, then shells out to mlx_lm.
    Exits the process on training failure; returns the adapter directory.
    """
    print("\n🚀 Starting SFT Training\n")

    model_name = config.get("model", {}).get("base", "Qwen/Qwen2.5-1.5B")
    output_name = config.get("model", {}).get("output", "qmd-expansion-sft")

    # Prepare data
    data_dir = prepare_dataset(config, "sft")

    # Download/convert model
    model_dir = download_model(model_name)

    # Training params
    training = config.get("training", {})
    lora = config.get("lora", {})

    # Use output name for adapter path (allows multiple trained models)
    adapter_name = output_name.replace("/", "_").replace("-", "_")
    adapter_path = Path("adapters") / adapter_name
    adapter_path.mkdir(parents=True, exist_ok=True)

    cmd = [
        sys.executable, "-m", "mlx_lm", "lora",
        "--model", str(model_dir),
        "--train",
        "--data", str(data_dir),
        "--adapter-path", str(adapter_path),
        "--batch-size", str(training.get("batch_size", 4)),
        "--iters", str(training.get("iters", 1000)),
        "--learning-rate", str(training.get("learning_rate", 1e-4)),
        "--num-layers", str(lora.get("num_layers", 16)),
        "--steps-per-report", "10",
        "--steps-per-eval", "100",
        "--save-every", "200",
    ]

    # Cap sequence length if configured (default 512)
    max_length = training.get("max_length", 512)
    if max_length:
        cmd.extend(["--max-seq-length", str(max_length)])

    print(f"📝 Command: {' '.join(cmd)}\n")

    # Run training (inherits stdout/stderr so progress is visible)
    result = subprocess.run(cmd)

    if result.returncode != 0:
        print("❌ Training failed")
        sys.exit(1)

    print(f"\n✅ SFT complete! Adapter saved to: {adapter_path}")
    return adapter_path
+
+
def train_grpo(config: dict):
    """Run GRPO (RL refinement) on top of SFT model.

    Currently a stub: prints guidance and returns the SFT adapter path.
    Exits the process when no SFT adapter is found.

    NOTE(review): this checks adapters/sft, but train_sft saves under
    adapters/<output_name with '-'->'_'> (default "qmd_expansion_sft") -
    confirm the expected path; as written the check likely never passes
    after a default SFT run.
    """
    print("\n🚀 Starting GRPO Training\n")

    # Check SFT adapter exists
    sft_adapter = Path("adapters") / "sft"
    if not sft_adapter.exists():
        print("❌ SFT adapter not found. Run 'python train.py sft' first.")
        sys.exit(1)

    # For GRPO, we need to implement reward-based training
    # MLX doesn't have built-in GRPO, so we'll use a simpler approach:
    # Continue SFT with filtered high-quality examples

    print("⚠️ GRPO not yet implemented for MLX.")
    print(" The SFT model should work well for most use cases.")
    print(" For GRPO, use HuggingFace Jobs with the original scripts.")

    # TODO: Implement reward model + PPO-style training
    # This would require:
    # 1. Generate expansions from SFT model
    # 2. Score with reward function
    # 3. Filter top examples
    # 4. Continue training on high-reward examples

    return sft_adapter
+
+
def load_config(config_path: str) -> dict:
    """Parse a YAML config file and return its contents as a dict."""
    with open(config_path) as handle:
        return yaml.safe_load(handle)
+
+
def main():
    """CLI entry point: dispatch to SFT or GRPO training.

    Config resolution: explicit --config wins, then configs/<stage>.yaml if
    present, else an empty config (all defaults).
    """
    parser = argparse.ArgumentParser(description="QMD Query Expansion Training")
    parser.add_argument("stage", choices=["sft", "grpo"], help="Training stage")
    parser.add_argument("--config", "-c", help="Config file path")
    # NOTE(review): --dry-run is parsed but never consulted below - confirm
    # whether it should suppress the subprocess calls.
    parser.add_argument("--dry-run", action="store_true", help="Print commands without running")

    args = parser.parse_args()

    # Load config
    if args.config:
        config = load_config(args.config)
    else:
        # Default config
        config_path = Path("configs") / f"{args.stage}.yaml"
        if config_path.exists():
            config = load_config(str(config_path))
        else:
            config = {}

    # Run training
    if args.stage == "sft":
        train_sft(config)
    elif args.stage == "grpo":
        train_grpo(config)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/mlx_expand.py b/scripts/mlx_expand.py
new file mode 100755
index 00000000..81d85b0d
--- /dev/null
+++ b/scripts/mlx_expand.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python3
+"""QMD MLX Query Expansion (sidecar).
+
+⚠️ EXPERIMENTAL / DEV TOOL
+This script is a bridge for iterating on fine-tuned MLX adapters locally.
+It requires Apple Silicon + mlx_lm. If unavailable or failing, QMD falls back
+to the default llama.cpp-based query expansion.
+
+Outputs query expansions in the exact line-oriented format expected by QMD:
+ lex: ...
+ vec: ...
+ hyde: ...
+
+Environment variables:
+ QMD_MLX_MODEL Base MLX model directory (default: finetune-mlx/models/Qwen_Qwen3-1.7B/mlx)
+ QMD_MLX_ADAPTER LoRA adapter directory (default: finetune-mlx/adapters/qmd_query_expansion_1.7B_sft)
+ QMD_MLX_TEMP Temperature (default: 1.0)
+ QMD_MLX_MAX_TOKENS Max tokens (default: 512)
+
+Usage:
+ ./scripts/mlx_expand.py "auth config"
+ echo "auth config" | ./scripts/mlx_expand.py
+"""
+
+import os
+import sys
+
+from mlx_lm import load, generate
+from mlx_lm.sample_utils import make_sampler
+
# Qwen-style chat template wrapping the user's query. The "/no_think" tag is
# presumably the Qwen3 switch that suppresses the model's thinking block so it
# emits the lex:/vec:/hyde: lines directly — confirm against the fine-tuning
# data format before changing.
PROMPT_TEMPLATE = """<|im_start|>user
/no_think Expand this search query: {query}<|im_end|>
<|im_start|>assistant
"""
+
+
+def _read_query(argv):
+ if len(argv) > 1:
+ return " ".join(argv[1:]).strip()
+ data = sys.stdin.read().strip()
+ return data
+
+
def main():
    """Expand a query with the MLX model and print lex:/vec:/hyde: lines.

    Returns:
        Process exit code: always 0 (an empty query prints nothing).
        Model-load and generation errors propagate to the caller, which
        is expected to fall back to llama.cpp-based expansion.
    """
    query = _read_query(sys.argv)
    if not query:
        # Nothing to expand: emit no output and exit successfully so the
        # caller's fallback path engages cleanly.
        return 0

    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

    # Model/adapter locations are overridable via environment variables so a
    # new fine-tune can be swapped in without editing this script.
    model = os.environ.get(
        "QMD_MLX_MODEL",
        os.path.join(repo_root, "finetune-mlx", "models", "Qwen_Qwen3-1.7B", "mlx"),
    )
    adapter = os.environ.get(
        "QMD_MLX_ADAPTER",
        os.path.join(repo_root, "finetune-mlx", "adapters", "qmd_query_expansion_1.7B_sft"),
    )

    temp = float(os.environ.get("QMD_MLX_TEMP", "1.0"))
    max_tokens = int(os.environ.get("QMD_MLX_MAX_TOKENS", "512"))

    prompt = PROMPT_TEMPLATE.format(query=query)

    # Prefer base model + LoRA adapter; fall back to the bare base model when
    # the adapter is missing or incompatible (this tool is best-effort).
    try:
        m, tok = load(model, adapter_path=adapter)
    except Exception:
        m, tok = load(model)

    sampler = make_sampler(temp=temp)
    out = generate(
        m,
        tok,
        prompt=prompt,
        max_tokens=max_tokens,
        sampler=sampler,
        verbose=False,
    )

    # Drop the chat end-of-turn marker, normalize to non-empty stripped lines,
    # then keep only the expansion fields QMD understands (defensive against
    # any extra model chatter).
    out = out.replace("<|im_end|>", "").strip()
    lines = [ln.strip() for ln in out.splitlines() if ln.strip()]
    filtered = [ln for ln in lines if ln.startswith(("lex:", "vec:", "hyde:"))]

    sys.stdout.write("\n".join(filtered).strip())
    if filtered:
        # Terminate the output with a newline only when something was emitted.
        sys.stdout.write("\n")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())