diff --git a/.gitignore b/.gitignore index 40e11448..3134fc4f 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,7 @@ texts/ finetune/outputs/ finetune/data/train/ .claude/ + +# finetune-mlx artifacts (local runs) +finetune-mlx/*.nohup.log +finetune-mlx/eval_results.json diff --git a/finetune-mlx/.gitignore b/finetune-mlx/.gitignore new file mode 100644 index 00000000..fdf965bd --- /dev/null +++ b/finetune-mlx/.gitignore @@ -0,0 +1,21 @@ +# Training artifacts (large files) +adapters/ +models/ +exports/merged/ +exports/*.gguf + +# Keep Modelfiles (small, useful) +!exports/*.Modelfile + +# Python +.venv/ +__pycache__/ +*.pyc + +# Logs +*.log +*.nohup.log + +# Data (downloaded separately) +data/ +eval_results.json diff --git a/finetune-mlx/README.md b/finetune-mlx/README.md new file mode 100644 index 00000000..e10a56b2 --- /dev/null +++ b/finetune-mlx/README.md @@ -0,0 +1,115 @@ +# QMD Query Expansion - Apple Silicon (MLX) + +Apple Silicon alternative to the CUDA-based [`finetune/`](../finetune/) directory. + +Port of QMD's query expansion fine-tuning to Apple Silicon using [MLX](https://github.com/ml-explore/mlx). + +Train small language models locally on M1/M2/M3/M4 Macs to expand search queries for hybrid retrieval. 
+ +## Features + +- **SFT Training**: Supervised fine-tuning with LoRA +- **GRPO Training**: Group Relative Policy Optimization (reinforcement learning) +- **100% Local**: No cloud GPU needed, runs on Apple Silicon +- **MLX Optimized**: Native Metal acceleration + +## Results + +Comparison with original NVIDIA A10G implementation: + +| Metric | NVIDIA (SFT+GRPO) | Apple Silicon (SFT) | Apple Silicon (GRPO) | +|--------|-------------------|---------------------|----------------------| +| Avg Score | 92% | 99.6% | 100.4% | +| Perfect Queries | 30/30 | 28/30 | 28/30 | +| Hardware | A10G 24GB | Mac Mini M4 | Mac Mini M4 | +| Cost | ~$2/run | $0 | $0 | + +## Quick Start + +```bash +# Setup +python -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt + +# Download and convert base model +python -c "from mlx_lm import load; load('Qwen/Qwen3-1.7B')" + +# Train SFT (supervised fine-tuning) +python train.py sft --iters 3500 + +# Train GRPO (reinforcement learning refinement) +python grpo.py --steps 200 + +# Evaluate +python grpo.py --eval-only --adapter adapters/qwen3-grpo +``` + +## What It Does + +Given a query like `"auth config"`, the model produces structured expansions: + +``` +lex: authentication configuration +lex: auth settings setup +vec: how to configure authentication settings +hyde: Authentication can be configured by setting AUTH_SECRET... 
+``` + +These feed into QMD's hybrid retrieval: +- `lex:` → BM25 full-text search +- `vec:` → Vector similarity search +- `hyde:` → Hypothetical document embedding + +## File Structure + +``` +├── train.py # SFT training script +├── grpo.py # GRPO (RL) training script +├── eval.py # Evaluation utilities +├── reward.py # Scoring/reward function +├── convert.py # GGUF conversion for Ollama +├── configs/ +│ └── sft.yaml # SFT hyperparameters +├── evals/ +│ └── queries.txt # Test queries (31 total) +└── tests/ # Unit tests +``` + +## Requirements + +- macOS with Apple Silicon (M1/M2/M3/M4) +- Python 3.10+ +- ~8GB RAM for training +- ~4GB disk for models + +## Training Details + +### SFT (Supervised Fine-Tuning) +- Base model: Qwen3-1.7B +- LoRA rank: 8, layers: 8 +- Learning rate: 1e-4 +- Steps: 3500 +- Time: ~60 min on M4 + +### GRPO (Group Relative Policy Optimization) +- Starts from SFT checkpoint +- 4 completions per query +- KL regularization (β=0.04) +- Steps: 200 +- Time: ~30 min on M4 + +## Credits + +- Original QMD: [tobi/qmd](https://github.com/tobi/qmd) +- MLX framework: [ml-explore/mlx](https://github.com/ml-explore/mlx) +- Base model: [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) + +## Contributors + +- [@sujito00](https://github.com/sujito00) +- [@dgilperez](https://github.com/dgilperez) + +## License + +MIT diff --git a/finetune-mlx/configs/sft.yaml b/finetune-mlx/configs/sft.yaml new file mode 100644 index 00000000..2fe0b04f --- /dev/null +++ b/finetune-mlx/configs/sft.yaml @@ -0,0 +1,21 @@ +# SFT Training Config for QMD Query Expansion (Apple Silicon) + +model: + base: "Qwen/Qwen3-1.7B" + output: "qmd-query-expansion-1.7B-sft" + +dataset: + name: "tobil/qmd-query-expansion-train-v2" + text_field: "text" + eval_split: 0.1 + +training: + batch_size: 4 + iters: 3000 + learning_rate: 2e-4 + max_length: 512 + grad_accumulation_steps: 4 + +lora: + num_layers: 16 + rank: 16 diff --git a/finetune-mlx/convert.py b/finetune-mlx/convert.py new file 
mode 100644 index 00000000..6aebe7ff --- /dev/null +++ b/finetune-mlx/convert.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 +""" +Convert trained MLX model to GGUF for Ollama/llama.cpp. + +Full pipeline: +1. Merge LoRA adapter into base model +2. Dequantize if needed (MLX 4-bit → FP16) +3. Convert to GGUF (requires llama.cpp) +4. Quantize to Q4_K_M (optional, requires llama-quantize) + +Usage: + python convert.py # Full pipeline with defaults + python convert.py --output qmd-expand # Custom output name + python convert.py --quantize q4_k_m # Include quantization step + python convert.py --skip-merge # Use already-merged model + +Requirements: + pip install mlx_lm gguf + git clone https://github.com/ggerganov/llama.cpp + cd llama.cpp && mkdir build && cd build && cmake .. && make llama-quantize +""" + +import argparse +import os +import shutil +import subprocess +import sys +from pathlib import Path + + +def find_llama_cpp(): + """Find llama.cpp installation.""" + # Check common locations + locations = [ + Path.home() / "src" / "llama.cpp", + Path.home() / "llama.cpp", + Path("/usr/local/llama.cpp"), + Path("../llama.cpp"), + ] + + for loc in locations: + convert_script = loc / "convert_hf_to_gguf.py" + if convert_script.exists(): + return loc + + # Check if in PATH + if shutil.which("llama-quantize"): + return Path(shutil.which("llama-quantize")).parent.parent + + return None + + +def merge_adapter(model_path: Path, adapter_path: Path, output_path: Path): + """Merge LoRA adapter into base model using mlx_lm fuse.""" + print(f"🔄 Merging adapter into base model...") + + output_path.parent.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, "-m", "mlx_lm", "fuse", + "--model", str(model_path), + "--adapter-path", str(adapter_path), + "--save-path", str(output_path), + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + print(f"❌ Merge failed: {result.stderr}") + sys.exit(1) + + print(f"✅ Merged model saved 
to: {output_path}") + return output_path + + +def dequantize_model(model_path: Path, output_path: Path): + """Dequantize MLX model from 4-bit to FP16.""" + print(f"🔄 Dequantizing model to FP16...") + + # Check if model is quantized + config_path = model_path / "config.json" + if config_path.exists(): + import json + with open(config_path) as f: + config = json.load(f) + if "quantization" not in config and "quantization_config" not in config: + print(f"ℹ️ Model is not quantized, skipping dequantization") + return model_path + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Use mlx_lm convert with dequantize flag + cmd = [ + sys.executable, "-c", + f""" +from mlx_lm import convert +convert( + hf_path='{model_path}', + mlx_path='{output_path}', + dequantize=True, + dtype='float16' +) +""" + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + print(f"❌ Dequantization failed: {result.stderr}") + sys.exit(1) + + print(f"✅ Dequantized model saved to: {output_path}") + return output_path + + +def convert_to_gguf(model_path: Path, output_path: Path, llama_cpp_path: Path): + """Convert HF/MLX model to GGUF format.""" + print(f"🔄 Converting to GGUF...") + + convert_script = llama_cpp_path / "convert_hf_to_gguf.py" + + cmd = [ + sys.executable, str(convert_script), + str(model_path), + "--outfile", str(output_path), + "--outtype", "f16" + ] + + result = subprocess.run(cmd) + + if result.returncode != 0: + print(f"❌ GGUF conversion failed") + sys.exit(1) + + print(f"✅ GGUF saved to: {output_path}") + return output_path + + +def quantize_gguf(input_path: Path, output_path: Path, quant_type: str, llama_cpp_path: Path): + """Quantize GGUF model.""" + print(f"🔄 Quantizing to {quant_type}...") + + # Find llama-quantize + quantize_bin = llama_cpp_path / "build" / "bin" / "llama-quantize" + if not quantize_bin.exists(): + quantize_bin = shutil.which("llama-quantize") + + if not quantize_bin: + print(f"⚠️ llama-quantize not 
found, skipping quantization") + print(f" Build it: cd llama.cpp && mkdir build && cd build && cmake .. && make llama-quantize") + return input_path + + cmd = [str(quantize_bin), str(input_path), str(output_path), quant_type.upper()] + + result = subprocess.run(cmd) + + if result.returncode != 0: + print(f"❌ Quantization failed") + return input_path + + print(f"✅ Quantized GGUF saved to: {output_path}") + return output_path + + +def create_modelfile(output_name: str, gguf_filename: str): + """Create Ollama Modelfile.""" + modelfile = Path("exports") / f"{output_name}.Modelfile" + modelfile.parent.mkdir(parents=True, exist_ok=True) + + content = f'''# Modelfile for {output_name} +# Usage: ollama create {output_name} -f exports/{output_name}.Modelfile + +FROM ./{gguf_filename} + +TEMPLATE """<|im_start|>user +/no_think Expand this search query: {{{{.Prompt}}}}<|im_end|> +<|im_start|>assistant +""" + +PARAMETER temperature 0.3 +PARAMETER top_p 0.9 +PARAMETER stop "<|im_end|>" +''' + + with open(modelfile, "w") as f: + f.write(content) + + print(f"✅ Modelfile saved to: {modelfile}") + return modelfile + + +def main(): + parser = argparse.ArgumentParser( + description="Convert MLX model to GGUF", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python convert.py + python convert.py --quantize q4_k_m + python convert.py --model models/Qwen_Qwen3-1.7B/mlx --adapter adapters/qmd_query_expansion_1.7B_sft +""" + ) + parser.add_argument("--model", default="models/Qwen_Qwen3-1.7B/mlx", help="Base model path") + parser.add_argument("--adapter", default="adapters/qmd_query_expansion_1.7B_sft", help="LoRA adapter path") + parser.add_argument("--output", "-o", default="qmd-query-expand", help="Output model name") + parser.add_argument("--quantize", "-q", help="Quantization type (e.g., q4_k_m, q8_0)") + parser.add_argument("--skip-merge", action="store_true", help="Skip adapter merge (use pre-merged model)") + parser.add_argument("--llama-cpp", 
help="Path to llama.cpp directory") + + args = parser.parse_args() + + # Find llama.cpp + llama_cpp_path = Path(args.llama_cpp) if args.llama_cpp else find_llama_cpp() + if not llama_cpp_path: + print("❌ llama.cpp not found. Clone it first:") + print(" git clone https://github.com/ggerganov/llama.cpp ~/src/llama.cpp") + sys.exit(1) + + print(f"📍 Using llama.cpp at: {llama_cpp_path}") + + model_path = Path(args.model) + adapter_path = Path(args.adapter) + exports_dir = Path("exports") + exports_dir.mkdir(exist_ok=True) + + # Step 1: Merge adapter + if not args.skip_merge and adapter_path.exists(): + merged_path = exports_dir / "merged" / args.output + model_path = merge_adapter(model_path, adapter_path, merged_path) + + # Step 2: Dequantize if needed + fp16_path = exports_dir / "merged" / f"{args.output}-fp16" + model_path = dequantize_model(model_path, fp16_path) + + # Step 3: Convert to GGUF + gguf_path = exports_dir / f"{args.output}-f16.gguf" + convert_to_gguf(model_path, gguf_path, llama_cpp_path) + + final_gguf = gguf_path + + # Step 4: Quantize if requested + if args.quantize: + quant_path = exports_dir / f"{args.output}-{args.quantize}.gguf" + final_gguf = quantize_gguf(gguf_path, quant_path, args.quantize, llama_cpp_path) + + # Clean up F16 if quantization succeeded + if final_gguf != gguf_path and final_gguf.exists(): + print(f"🧹 Removing intermediate F16 GGUF...") + gguf_path.unlink() + + # Step 5: Create Modelfile + create_modelfile(args.output, final_gguf.name) + + print(f"\n🎉 Done! Create Ollama model with:") + print(f" cd finetune-mlx && ollama create {args.output} -f exports/{args.output}.Modelfile") + + +if __name__ == "__main__": + main() diff --git a/finetune-mlx/demo.py b/finetune-mlx/demo.py new file mode 100644 index 00000000..2d55a811 --- /dev/null +++ b/finetune-mlx/demo.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +""" +Quick demo: Use the trained QMD query expansion model. 
+ +Usage: + python demo.py "your search query" + python demo.py --interactive +""" + +import argparse +import sys + +from mlx_lm import load, generate +from mlx_lm.sample_utils import make_sampler + +PROMPT_TEMPLATE = """<|im_start|>user +/no_think Expand this search query: {query}<|im_end|> +<|im_start|>assistant +""" + + +def expand_query(model, tokenizer, query: str, temp: float = 0.0) -> str: + """Generate query expansion.""" + prompt = PROMPT_TEMPLATE.format(query=query) + sampler = make_sampler(temp=temp) + response = generate( + model, tokenizer, + prompt=prompt, + max_tokens=200, + sampler=sampler, + verbose=False + ) + return response.replace('<|im_end|>', '').strip() + + +def main(): + parser = argparse.ArgumentParser(description="QMD Query Expansion Demo") + parser.add_argument("query", nargs="?", help="Query to expand") + parser.add_argument("--interactive", "-i", action="store_true", help="Interactive mode") + parser.add_argument("--model", default="models/Qwen_Qwen2.5-1.5B/mlx", help="Base model path") + parser.add_argument("--adapter", default="adapters/sft", help="LoRA adapter path") + parser.add_argument("--temp", type=float, default=0.0, help="Temperature (0=deterministic)") + args = parser.parse_args() + + print("Loading model...") + try: + model, tokenizer = load(args.model, adapter_path=args.adapter) + except Exception as e: + print(f"Error loading model: {e}") + print("\nTrying without adapter (base model)...") + model, tokenizer = load(args.model) + + print("Ready!\n") + + if args.interactive: + print("Interactive mode. 
Type 'quit' to exit.\n") + while True: + try: + query = input("Query> ").strip() + if query.lower() in ('quit', 'exit', 'q'): + break + if not query: + continue + print("\nExpansion:") + print(expand_query(model, tokenizer, query, args.temp)) + print() + except (KeyboardInterrupt, EOFError): + break + elif args.query: + print(f"Query: {args.query}\n") + print("Expansion:") + print(expand_query(model, tokenizer, args.query, args.temp)) + else: + # Demo queries + demos = [ + "auth config", + "docker networking", + "how to deploy", + "kubernetes pod", + ] + for q in demos: + print(f"Query: {q}") + print("-" * 40) + print(expand_query(model, tokenizer, q, args.temp)) + print() + + +if __name__ == "__main__": + main() diff --git a/finetune-mlx/eval.py b/finetune-mlx/eval.py new file mode 100644 index 00000000..dee89ef8 --- /dev/null +++ b/finetune-mlx/eval.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +""" +Evaluate QMD Query Expansion model. + +Usage: + python eval.py # Use default adapter + python eval.py --adapter adapters/sft # Specify adapter + python eval.py --query "auth config" # Test single query +""" + +import argparse +import json +import re +from pathlib import Path + +import mlx.core as mx +from mlx_lm import load, generate + + +# Test queries for evaluation +TEST_QUERIES = [ + "auth config", + "how to deploy", + "rate limiting", + "database connection", + "api keys", + "error handling", + "caching strategy", + "user permissions", +] + +PROMPT_TEMPLATE = """<|im_start|>user +/no_think Expand this search query: {query}<|im_end|> +<|im_start|>assistant +""" + + +def score_expansion(text: str) -> dict: + """Score an expansion based on format and quality.""" + scores = { + "has_lex": 0, + "has_vec": 0, + "has_hyde": 0, + "lex_count": 0, + "vec_count": 0, + "format_valid": 0, + "total": 0, + } + + lines = text.strip().split("\n") + + for line in lines: + line = line.strip() + if line.startswith("lex:"): + scores["has_lex"] = 1 + scores["lex_count"] += 1 + elif 
line.startswith("vec:"): + scores["has_vec"] = 1 + scores["vec_count"] += 1 + elif line.startswith("hyde:"): + scores["has_hyde"] = 1 + + # Format is valid if we have at least one of each type + if scores["has_lex"] and scores["has_vec"] and scores["has_hyde"]: + scores["format_valid"] = 1 + + # Total score (0-100) + scores["total"] = ( + scores["format_valid"] * 40 + + min(scores["lex_count"], 3) * 10 + + min(scores["vec_count"], 3) * 10 + + scores["has_hyde"] * 20 + ) + + return scores + + +def expand_query(model, tokenizer, query: str, max_tokens: int = 256) -> str: + """Generate expansion for a query.""" + prompt = PROMPT_TEMPLATE.format(query=query) + + response = generate( + model, + tokenizer, + prompt=prompt, + max_tokens=max_tokens, + verbose=False, + ) + + # Extract just the assistant response + if "<|im_end|>" in response: + response = response.split("<|im_end|>")[0] + + return response.strip() + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate QMD model") + parser.add_argument("--model", default="models/Qwen_Qwen2.5-1.5B/mlx", help="Base model path") + parser.add_argument("--adapter", default="adapters/sft", help="LoRA adapter path") + parser.add_argument("--query", "-q", help="Single query to test") + parser.add_argument("--no-adapter", action="store_true", help="Run without adapter (baseline)") + + args = parser.parse_args() + + # Load model + model_path = Path(args.model) + adapter_path = Path(args.adapter) if not args.no_adapter else None + + print(f"📦 Loading model: {model_path}") + if adapter_path and adapter_path.exists(): + print(f"🔌 Loading adapter: {adapter_path}") + model, tokenizer = load(str(model_path), adapter_path=str(adapter_path)) + else: + if adapter_path: + print(f"⚠️ Adapter not found: {adapter_path}, using base model") + model, tokenizer = load(str(model_path)) + + # Single query mode + if args.query: + print(f"\n🔍 Query: {args.query}\n") + expansion = expand_query(model, tokenizer, args.query) + 
print(expansion) + print(f"\n📊 Score: {score_expansion(expansion)}") + return + + # Batch evaluation + print(f"\n📝 Evaluating {len(TEST_QUERIES)} test queries...\n") + + results = [] + total_score = 0 + + for query in TEST_QUERIES: + print(f"🔍 {query}") + expansion = expand_query(model, tokenizer, query) + scores = score_expansion(expansion) + total_score += scores["total"] + + results.append({ + "query": query, + "expansion": expansion, + "scores": scores, + }) + + # Print summary + status = "✅" if scores["format_valid"] else "❌" + print(f" {status} Score: {scores['total']}/100") + print(f" lex:{scores['lex_count']} vec:{scores['vec_count']} hyde:{scores['has_hyde']}") + print() + + # Summary + avg_score = total_score / len(TEST_QUERIES) + print(f"\n{'='*50}") + print(f"📊 Average Score: {avg_score:.1f}/100") + print(f"{'='*50}") + + # Save results + output_file = Path("eval_results.json") + with open(output_file, "w") as f: + json.dump(results, f, indent=2) + print(f"\n💾 Results saved to: {output_file}") + + +if __name__ == "__main__": + main() diff --git a/finetune-mlx/evals/queries.txt b/finetune-mlx/evals/queries.txt new file mode 100644 index 00000000..16224120 --- /dev/null +++ b/finetune-mlx/evals/queries.txt @@ -0,0 +1,48 @@ +# Test queries for QMD query expansion evaluation +# One query per line, comments start with # + +# Technical documentation +how to configure authentication +typescript async await +docker compose networking +git rebase vs merge +react useEffect cleanup + +# Short/ambiguous queries +auth +config +setup +api + +# Named entities (critical for entity preservation testing) +who is TDS motorsports +React hooks tutorial +Docker container networking +Kubernetes pod deployment +AWS Lambda functions + +# Personal notes / journals style +meeting notes project kickoff +ideas for new feature +todo list app architecture + +# Research / learning +what is dependency injection +difference between sql and nosql +kubernetes vs docker swarm + +# 
Error/debugging +connection timeout error +memory leak debugging +cors error fix + +# Temporal / recency queries (should expand with years, "recent", "latest") +recent news about Shopify +latest AI developments +best laptops right now +what changed in kubernetes latest version + +# Complex queries +how to implement caching with redis in nodejs +best practices for api rate limiting +setting up ci cd pipeline with github actions diff --git a/finetune-mlx/exports/qmd-query-expand.Modelfile b/finetune-mlx/exports/qmd-query-expand.Modelfile new file mode 100644 index 00000000..929adc0b --- /dev/null +++ b/finetune-mlx/exports/qmd-query-expand.Modelfile @@ -0,0 +1,16 @@ +# Modelfile for qmd-query-expand +# Usage: ollama create qmd-query-expand -f exports/qmd-query-expand.Modelfile +# +# Requires: qmd-query-expand-q4_k_m.gguf in same directory +# Generate with: python convert.py --quantize q4_k_m + +FROM ./qmd-query-expand-q4_k_m.gguf + +TEMPLATE """<|im_start|>user +/no_think Expand this search query: {{.Prompt}}<|im_end|> +<|im_start|>assistant +""" + +PARAMETER temperature 0.3 +PARAMETER top_p 0.9 +PARAMETER stop "<|im_end|>" diff --git a/finetune-mlx/grpo.py b/finetune-mlx/grpo.py new file mode 100644 index 00000000..ade3288e --- /dev/null +++ b/finetune-mlx/grpo.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +""" +GRPO (Group Relative Policy Optimization) for MLX - Apple Silicon Edition + +Usage: + python grpo.py # Run GRPO training + python grpo.py --steps 100 # Custom steps + python grpo.py --eval-only # Just evaluate current model +""" + +import argparse +import json +import math +import random +import time +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx_lm import load, generate +from mlx_lm.sample_utils import make_sampler + +from reward import score_expansion_detailed + +# ============================================================================= +# Configuration +# 
============================================================================= + +CONFIG = { + "base_model": "models/Qwen3-1.7B-mlx", + "sft_adapter": "adapters/qwen3-3500", + "output_adapter": "adapters/qwen3-grpo", + + "num_generations": 4, + "max_tokens": 200, + "beta": 0.04, + "learning_rate": 5e-6, + "max_steps": 200, + + "log_every": 5, + "save_every": 50, + "eval_every": 50, +} + +TRAINING_QUERIES = [ + "auth config", "how to deploy", "rate limiting", "database connection", + "api keys", "error handling", "caching strategy", "user permissions", + "typescript async await", "docker compose networking", "git rebase vs merge", + "react useEffect cleanup", "kubernetes pod deployment", "AWS Lambda functions", + "memory leak debugging", "cors error fix", "connection timeout error", + "dependency injection", "sql vs nosql", "ci cd pipeline", +] + +PROMPT_TEMPLATE = """<|im_start|>user +/no_think Expand this search query: {query}<|im_end|> +<|im_start|>assistant +""" + + +def generate_completion(model, tokenizer, query, temp=0.7): + """Generate a single completion.""" + prompt = PROMPT_TEMPLATE.format(query=query) + sampler = make_sampler(temp=temp, top_p=0.9) + response = generate(model, tokenizer, prompt=prompt, max_tokens=200, sampler=sampler, verbose=False) + return prompt, response.replace('<|im_end|>', '').strip() + + +def compute_reward(query, completion): + """Score completion using reward function.""" + result = score_expansion_detailed(query, completion) + return result['total'] / 140.0 # Normalize to [0, 1] + + +def compute_log_prob(model, tokenizer, prompt, completion): + """Compute log probability of completion given prompt.""" + full_text = prompt + completion + tokens = mx.array(tokenizer.encode(full_text)) + prompt_len = len(tokenizer.encode(prompt)) + + # Forward pass + logits = model(tokens[None, :-1])[0] # [seq_len, vocab_size] + + # Get completion log probs + log_probs = nn.log_softmax(logits[prompt_len-1:], axis=-1) + target_tokens = 
tokens[prompt_len:] + + # Gather + token_log_probs = mx.take_along_axis( + log_probs[:len(target_tokens)], + target_tokens[:, None], + axis=-1 + ).squeeze(-1) + + return mx.sum(token_log_probs) + + +def grpo_step(policy_model, ref_model, tokenizer, query, config, optimizer): + """Single GRPO training step.""" + prompt = PROMPT_TEMPLATE.format(query=query) + + # 1. Generate completions with varying temperatures (1.0 to 2.0 for diversity) + completions = [] + for i in range(config["num_generations"]): + temp = 1.0 + 1.0 * i / max(1, config["num_generations"] - 1) # 1.0 to 2.0 + _, comp = generate_completion(policy_model, tokenizer, query, temp=temp) + completions.append(comp) + + # 2. Compute rewards + rewards = [compute_reward(query, c) for c in completions] + mean_reward = sum(rewards) / len(rewards) + + # 3. Compute advantages (group relative) + std_reward = math.sqrt(sum((r - mean_reward)**2 for r in rewards) / len(rewards) + 1e-8) + advantages = [(r - mean_reward) / std_reward for r in rewards] + + # 4. Compute reference log probs (frozen) + ref_log_probs = [] + for comp in completions: + ref_lp = compute_log_prob(ref_model, tokenizer, prompt, comp) + mx.eval(ref_lp) + ref_log_probs.append(ref_lp) + + # 5. Define loss function for this batch + def loss_fn(model): + total_loss = mx.array(0.0) + total_kl = mx.array(0.0) + + for comp, adv, ref_lp in zip(completions, advantages, ref_log_probs): + # Always compute even with small advantages + policy_lp = compute_log_prob(model, tokenizer, prompt, comp) + kl = policy_lp - ref_lp + pg_loss = -mx.array(adv) * policy_lp + + total_loss = total_loss + pg_loss + config["beta"] * mx.abs(kl) + total_kl = total_kl + kl + + n = len(completions) + return total_loss / n, total_kl / n + + # 6. Compute gradients using nn.value_and_grad + loss_grad_fn = nn.value_and_grad(policy_model, loss_fn) + (loss, kl), grads = loss_grad_fn(policy_model) + + # 7. 
Update + optimizer.update(policy_model, grads) + mx.eval(policy_model.parameters()) + + return float(loss), float(kl), mean_reward + + +def save_lora_weights(model, output_dir): + """Save LoRA weights to directory.""" + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + model.save_weights(str(output_dir / "adapters.safetensors")) + + +def evaluate(model, tokenizer, queries_file="evals/queries.txt"): + """Run evaluation.""" + queries = [] + with open(queries_file) as f: + for line in f: + line = line.strip() + if line and not line.startswith('#'): + queries.append(line) + + greedy = make_sampler(temp=0.0) + scores = [] + for q in queries: + prompt = PROMPT_TEMPLATE.format(query=q) + resp = generate(model, tokenizer, prompt=prompt, max_tokens=200, sampler=greedy, verbose=False) + resp = resp.replace('<|im_end|>', '').strip() + result = score_expansion_detailed(q, resp) + scores.append(result['total']) + + return { + "avg": sum(scores) / len(scores), + "perfect": sum(1 for s in scores if s >= 100), + "total": len(scores), + "min": min(scores), + "max": max(scores), + } + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--steps", type=int, default=CONFIG["max_steps"]) + parser.add_argument("--eval-only", action="store_true") + parser.add_argument("--adapter", type=str) + args = parser.parse_args() + + config = CONFIG.copy() + config["max_steps"] = args.steps + + print("=" * 60) + print("GRPO Training - QMD Query Expansion (MLX)") + print("=" * 60) + + print(f"\n📥 Loading model...") + policy_model, tokenizer = load(config["base_model"], adapter_path=args.adapter or config["sft_adapter"]) + + if args.eval_only: + print("\n📊 Evaluation:") + r = evaluate(policy_model, tokenizer) + print(f"Avg: {r['avg']:.1f}/120, Perfect: {r['perfect']}/{r['total']}, Range: [{r['min']}, {r['max']}]") + return + + print(f"📥 Loading reference model (frozen)...") + ref_model, _ = load(config["base_model"], 
adapter_path=config["sft_adapter"]) + ref_model.freeze() + + optimizer = optim.Adam(learning_rate=config["learning_rate"]) + + output_dir = Path(config["output_adapter"]) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n🚀 Starting GRPO ({config['max_steps']} steps)...\n") + + for step in range(1, config["max_steps"] + 1): + query = random.choice(TRAINING_QUERIES) + + t0 = time.time() + loss, kl, reward = grpo_step(policy_model, ref_model, tokenizer, query, config, optimizer) + dt = time.time() - t0 + + if step % config["log_every"] == 0: + print(f"Step {step:4d} | Loss: {loss:.4f} | KL: {kl:.4f} | Reward: {reward:.3f} | {dt:.1f}s") + + if step % config["save_every"] == 0: + ckpt = output_dir / f"ckpt_{step:04d}" + save_lora_weights(policy_model, ckpt) + print(f" 💾 Saved: {ckpt}") + + if step % config["eval_every"] == 0: + print(f"\n📊 Eval @ step {step}:") + r = evaluate(policy_model, tokenizer) + print(f" Avg: {r['avg']:.1f}/120, Perfect: {r['perfect']}/{r['total']}\n") + + # Final save + save_lora_weights(policy_model, output_dir) + + print("\n📊 Final evaluation:") + r = evaluate(policy_model, tokenizer) + print(f"Avg: {r['avg']:.1f}/120, Perfect: {r['perfect']}/{r['total']}, Range: [{r['min']}, {r['max']}]") + print("\n✅ Done!") + + +if __name__ == "__main__": + main() diff --git a/finetune-mlx/prepare_data.py b/finetune-mlx/prepare_data.py new file mode 100644 index 00000000..7a58a630 --- /dev/null +++ b/finetune-mlx/prepare_data.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +""" +Prepare generated data for MLX training. + +Takes the JSONL output from generate_from_notes.py and creates +train.jsonl and valid.jsonl files in the format expected by mlx-lm. 
+ +Usage: + python dataset/prepare_data.py --input data/custom/expansions.jsonl --output data/custom +""" + +import argparse +import json +import random +from pathlib import Path + + +def load_examples(input_path: Path) -> list[dict]: + """Load and validate examples from JSONL.""" + examples = [] + invalid = 0 + + with open(input_path) as f: + for line in f: + if not line.strip(): + continue + try: + example = json.loads(line) + # Validate required fields + if "text" in example and len(example["text"]) > 50: + examples.append(example) + else: + invalid += 1 + except json.JSONDecodeError: + invalid += 1 + + print(f"Loaded {len(examples)} valid examples ({invalid} invalid)") + return examples + + +def split_data(examples: list[dict], eval_ratio: float = 0.1) -> tuple[list, list]: + """Split into train and validation sets.""" + random.shuffle(examples) + split_idx = int(len(examples) * (1 - eval_ratio)) + return examples[:split_idx], examples[split_idx:] + + +def save_for_mlx(examples: list[dict], output_path: Path, name: str): + """Save in MLX-lm format (JSONL with 'text' field).""" + filepath = output_path / f"{name}.jsonl" + with open(filepath, "w") as f: + for ex in examples: + # MLX expects just {"text": "..."} format + f.write(json.dumps({"text": ex["text"]}) + "\n") + print(f"Saved {len(examples)} examples to {filepath}") + + +def main(): + parser = argparse.ArgumentParser(description="Prepare data for MLX training") + parser.add_argument("--input", "-i", type=str, required=True, + help="Input JSONL from generate_from_notes.py") + parser.add_argument("--output", "-o", type=str, default="data/custom", + help="Output directory for train/valid files") + parser.add_argument("--eval-ratio", type=float, default=0.1, + help="Fraction of data for validation (default: 0.1)") + parser.add_argument("--seed", type=int, default=42, + help="Random seed for reproducibility") + args = parser.parse_args() + + random.seed(args.seed) + + input_path = Path(args.input) + 
output_path = Path(args.output) + output_path.mkdir(parents=True, exist_ok=True) + + # Load and split + examples = load_examples(input_path) + train_examples, valid_examples = split_data(examples, args.eval_ratio) + + print(f"\nSplit: {len(train_examples)} train, {len(valid_examples)} valid") + + # Save for MLX + save_for_mlx(train_examples, output_path, "train") + save_for_mlx(valid_examples, output_path, "valid") + + # Save summary + summary = { + "total": len(examples), + "train": len(train_examples), + "valid": len(valid_examples), + "eval_ratio": args.eval_ratio, + } + with open(output_path / "dataset_info.json", "w") as f: + json.dump(summary, f, indent=2) + + print(f"\n✅ Data prepared in {output_path}") + print(f" Ready for: python train.py sft --data {output_path}") + + +if __name__ == "__main__": + main() diff --git a/finetune-mlx/requirements.txt b/finetune-mlx/requirements.txt new file mode 100644 index 00000000..b60334df --- /dev/null +++ b/finetune-mlx/requirements.txt @@ -0,0 +1,14 @@ +# Core MLX dependencies +mlx-lm>=0.20.0 +mlx>=0.21.0 + +# Data handling +huggingface_hub>=0.20.0 +datasets>=2.14.0 + +# Config and utilities +pyyaml>=6.0 +tqdm>=4.65.0 + +# Testing (optional) +pytest>=7.0.0 diff --git a/finetune-mlx/reward.py b/finetune-mlx/reward.py new file mode 100644 index 00000000..b29e678b --- /dev/null +++ b/finetune-mlx/reward.py @@ -0,0 +1,610 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [] +# /// +""" +QMD Query Expansion Reward Function + +Single source of truth for scoring query expansions. 
Used by: +- GRPO training (as the RL reward signal) +- Evaluation scripts (for scoring model outputs) + +Scores expansions on five dimensions: + Format (30) - Has lex/vec lines, no invalid lines + Diversity (30) - Multiple types, diverse content, no echoes + HyDE (20) - Optional bonus for hypothetical document passage + Quality (20) - Lex shorter than vec, natural language, key terms + Entity (20) - Named entity preservation in lex/vec lines + +Returns 0.0-1.0 for RL rewards, or a detailed breakdown dict for evaluation. +""" + +import re +from collections import Counter + +# ============================================================================= +# Constants +# ============================================================================= + +# "only:" mode patterns - when query ends with these, expect only that type +# Format: "query /only:lex" (slash prefix, no space after colon) +ONLY_MODE_PATTERN = re.compile(r'\s+/only:(lex|vec|hyde)\s*$', re.IGNORECASE) + +STOPWORDS = frozenset({ + 'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', + 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by', +}) + +KEY_TERM_STOPWORDS = frozenset({ + 'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of', + 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we', + 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell', +}) + +GENERIC_LEX_PHRASES = frozenset({ + 'find information about', 'search for', 'look up', 'get information', + 'learn about', 'information on', 'details about', 'find out about', + 'what is', 'how to', 'guide to', 'help with', +}) + +# Chat template tokens that indicate a broken output +CHAT_TEMPLATE_TOKENS = frozenset({ + '<|im_start|>', '<|im_end|>', '<|endoftext|>', + '\nassistant\n', '\nuser\n', +}) + + +# ============================================================================= +# Parsing +# ============================================================================= + +def 
parse_expansion(text: str) -> dict: + """Parse a multi-line expansion into {lex, vec, hyde, invalid} lists.""" + result = {"lex": [], "vec": [], "hyde": [], "invalid": []} + for line in text.strip().split("\n"): + line = line.strip() + if not line: + continue + if line.startswith("lex:"): + result["lex"].append(line[4:].strip()) + elif line.startswith("vec:"): + result["vec"].append(line[4:].strip()) + elif line.startswith("hyde:"): + result["hyde"].append(line[5:].strip()) + else: + result["invalid"].append(line) + return result + + +def detect_only_mode(query: str) -> tuple[str | None, str]: + """Detect if query ends with 'only: lex/vec/hyde'. + + Returns (only_type, base_query) where only_type is None for normal queries. + """ + match = ONLY_MODE_PATTERN.search(query) + if match: + only_type = match.group(1).lower() + base_query = query[:match.start()].strip() + return only_type, base_query + return None, query + + +def clean_model_output(text: str) -> tuple[str, bool]: + """Strip chat template artifacts from model output. + + Returns (cleaned_text, used_thinking) where used_thinking is True + if the model emitted ... blocks. + """ + text = text.replace('<|im_end|>', '').strip() + + used_thinking = '' in text and '' in text + if used_thinking: + text = re.sub(r'.*?', '', text, flags=re.DOTALL).strip() + + return text, used_thinking + + +# ============================================================================= +# Helpers +# ============================================================================= + +def extract_named_entities(query: str) -> set: + """Extract named entities using heuristics. + + Detects: ALL-CAPS acronyms (TDS, API), capitalized proper nouns (React), + technical terms with special chars (node.js, C++), CamelCase (JavaScript), + and compound names (TDS motorsports -> both words). 
+ """ + entities = set() + words = query.split() + prev_was_entity = False + + for i, word in enumerate(words): + clean = word.strip('.,!?:;()[]"\'') + if not clean: + prev_was_entity = False + continue + + is_entity = False + + if clean.isupper() and len(clean) >= 2: + entities.add(clean.lower()) + is_entity = True + elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS: + entities.add(clean.lower()) + is_entity = True + elif any(c in clean for c in '.+-#@') and len(clean) >= 2: + entities.add(clean.lower()) + is_entity = True + elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper(): + entities.add(clean.lower()) + is_entity = True + elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS: + entities.add(clean.lower()) + is_entity = True + + prev_was_entity = is_entity + + return entities + + +def get_key_terms(query: str) -> set: + """Get non-stopword terms from a query.""" + return set(query.lower().split()) - KEY_TERM_STOPWORDS + + +def lex_preserves_key_terms(lex_line: str, query: str) -> bool: + """Does the lex line contain at least one key term from the query?""" + key_terms = get_key_terms(query) + if not key_terms: + return True + return bool(key_terms & set(lex_line.lower().split())) + + +def lex_preserves_entities(line: str, entities: set) -> bool: + """Does the line contain at least one named entity?""" + if not entities: + return True + lower = line.lower() + return any(e in lower for e in entities) + + +def lex_is_generic(lex_line: str) -> bool: + """Is this lex line a useless generic filler phrase?""" + lower = lex_line.lower().strip() + for phrase in GENERIC_LEX_PHRASES: + if phrase in lower or lower.startswith(phrase.split()[0]): + remaining = lower + for word in phrase.split(): + remaining = remaining.replace(word, '', 1).strip() + if len(remaining) < 3: + return True + return False + + +def word_set_distance(a: str, b: str) -> int: + """Symmetric difference of word sets (how many words 
are unique to one).""" + return len(set(a.lower().split()) ^ set(b.lower().split())) + + +def is_diverse(a: str, b: str, min_distance: int = 2) -> bool: + """Are two strings sufficiently different?""" + a, b = a.lower().strip(), b.lower().strip() + if a == b or a in b or b in a: + return False + return word_set_distance(a, b) >= min_distance + + +def echoes_query(expansion: str, query: str) -> bool: + """Is this expansion just echoing the original query?""" + exp, q = expansion.lower().strip(), query.lower().strip() + return exp == q or (q in exp and len(exp) < len(q) + 10) + + +def word_repetition_penalty(text: str) -> int: + """Penalty for words repeated 3+ times (excluding stopwords).""" + counts = Counter(re.findall(r'\b\w+\b', text.lower())) + return sum((c - 2) * 2 for w, c in counts.items() + if c >= 3 and w not in STOPWORDS and len(w) > 2) + + +# ============================================================================= +# Scoring +# ============================================================================= + +def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool, only_type: str) -> dict: + """Score an 'only:' mode expansion. 
Expects ONLY the requested type.""" + parsed = parse_expansion(text) + deductions = [] + + # Expected type must be present + expected_items = parsed.get(only_type, []) + if not expected_items: + return { + "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0, + "think_bonus": 0, "total": 0, "max_possible": 100, + "percentage": 0.0, "rating": "Failed", + "deductions": [f"missing expected {only_type}: output"], + "parsed": parsed, + "entities_detected": [], + "only_mode": only_type, + } + + # Penalize presence of OTHER types + other_types = {"lex", "vec", "hyde"} - {only_type} + unwanted_count = sum(len(parsed.get(t, [])) for t in other_types) + if unwanted_count > 0: + deductions.append(f"contains unwanted types (expected only {only_type})") + + # --- Format (0-30) --- + format_score = 30 if unwanted_count == 0 else max(0, 30 - unwanted_count * 10) + + # --- Diversity (0-30) --- + diversity_score = 0 + if len(expected_items) >= 2: + diversity_score += 15 + # Check for diversity among items + div_score = 15 + for i, a in enumerate(expected_items): + for b in expected_items[i+1:]: + if not is_diverse(a, b, 2): + div_score -= 5 + deductions.append(f"{only_type} duplicate: {a[:20]}...") + diversity_score += max(0, div_score) + elif len(expected_items) == 1: + diversity_score = 15 # One item is fine for single-type output + + # Check for echoes + for exp in expected_items: + if echoes_query(exp, base_query): + diversity_score -= 5 + deductions.append(f"echoes query: {exp[:20]}...") + diversity_score = max(0, diversity_score) + + # --- Type-specific quality (0-20) --- + quality_score = 10 # base + entities = extract_named_entities(base_query) + + if only_type == "lex": + # Lex should be short keyword phrases with key terms + with_terms = sum(1 for l in expected_items if lex_preserves_key_terms(l, base_query)) + if with_terms == len(expected_items): + quality_score += 5 + # Check for generic phrases + generic = sum(1 for l in expected_items if 
lex_is_generic(l)) + if generic == 0: + quality_score += 5 + else: + deductions.append(f"{generic} generic lex phrases") + + elif only_type == "vec": + # Vec should be natural language sentences + natural = sum(1 for v in expected_items if " " in v and len(v) > 15) + if natural == len(expected_items): + quality_score += 10 + else: + quality_score += 5 + deductions.append("vec not all natural language") + + elif only_type == "hyde": + # Hyde should be a document snippet (50-200 chars) + hyde_text = expected_items[0] + hyde_len = len(hyde_text) + if 50 <= hyde_len <= 200: + quality_score += 10 + elif 30 <= hyde_len <= 300: + quality_score += 5 + deductions.append(f"hyde length {hyde_len} (ideal: 50-200)") + else: + deductions.append(f"hyde length {hyde_len} out of range") + + # --- Entity preservation (0-20) --- + entity_score = 10 # base + if entities: + with_entities = sum(1 for item in expected_items if lex_preserves_entities(item, entities)) + if with_entities == len(expected_items): + entity_score += 10 + elif with_entities > 0: + entity_score += 5 + else: + entity_score = 0 + deductions.append(f"missing entities: {entities}") + + # --- Think bonus (0-20) --- + think_bonus = 0 if used_thinking else 20 + + # --- Total --- + total = format_score + diversity_score + quality_score + entity_score + think_bonus + max_possible = 120 + percentage = max(0.0, min(100.0, total / max_possible * 100)) + + if percentage >= 80: + rating = "Excellent" + elif percentage >= 60: + rating = "Good" + elif percentage >= 40: + rating = "Acceptable" + elif percentage >= 20: + rating = "Poor" + else: + rating = "Failed" + + return { + "format": format_score, + "diversity": diversity_score, + "hyde": 0, # not used in only mode (quality covers it) + "quality": quality_score, + "entity": entity_score, + "think_bonus": think_bonus, + "total": total, + "max_possible": max_possible, + "percentage": round(percentage, 1), + "rating": rating, + "deductions": deductions, + "parsed": parsed, + 
"entities_detected": list(entities) if entities else [], + "only_mode": only_type, + } + + +def score_expansion_detailed(query: str, expansion: str) -> dict: + """Score an expansion with full breakdown. Returns dict with all dimensions.""" + text, used_thinking = clean_model_output(expansion.strip()) + deductions = [] + + # Detect "only:" mode + only_type, base_query = detect_only_mode(query) + + def _fail(reason): + return { + "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0, + "think_bonus": 0, "total": 0, "max_possible": 100, + "percentage": 0.0, "rating": "Failed", + "deductions": [reason], + "parsed": parse_expansion(expansion), + "entities_detected": [], + "only_mode": only_type, + } + + # Hard fail: remaining chat template tokens + if any(tok in text for tok in CHAT_TEMPLATE_TOKENS): + return _fail("CHAT TEMPLATE LEAKAGE") + + # Hard fail: every non-empty line must have a valid prefix + for line in text.split("\n"): + line = line.strip() + if line and not line.startswith(("lex:", "vec:", "hyde:")): + return _fail(f"INVALID LINE: {line[:50]}") + + # --- Handle "only:" mode separately --- + if only_type: + return _score_only_mode(query, base_query, text, used_thinking, only_type) + + parsed = parse_expansion(text) + + # --- Format (0-30) --- + format_score = 10 # no invalid lines (guaranteed by hard fail) + if parsed["lex"]: + format_score += 10 + else: + deductions.append("missing lex:") + if parsed["vec"]: + format_score += 10 + else: + deductions.append("missing vec:") + + # --- Diversity (0-30) --- + diversity_score = 0 + + types_present = sum(1 for t in ("lex", "vec") if parsed[t]) + if types_present >= 2: + diversity_score += 10 + else: + deductions.append("only one type") + + if len(parsed["lex"]) + len(parsed["vec"]) >= 2: + diversity_score += 5 + + lex_div = 5 + for i, a in enumerate(parsed["lex"]): + for b in parsed["lex"][i+1:]: + if not is_diverse(a, b, 2): + lex_div -= 2 + deductions.append(f"lex duplicate: {a[:20]}...") + 
diversity_score += max(0, lex_div) + + vec_div = 5 + for i, a in enumerate(parsed["vec"]): + for b in parsed["vec"][i+1:]: + if not is_diverse(a, b, 3): + vec_div -= 2 + deductions.append(f"vec duplicate: {a[:20]}...") + diversity_score += max(0, vec_div) + + echo = 5 + lex_echo_count = 0 + for exp in parsed["lex"]: + if echoes_query(exp, query): + lex_echo_count += 1 + deductions.append(f"lex echoes query: {exp[:20]}...") + # Harsh penalty for lex echoes - they're useless + if lex_echo_count > 0: + echo -= lex_echo_count * 10 # -10 per echo + + for exp in parsed["vec"]: + if echoes_query(exp, query): + echo -= 3 # vec echoes less severe (natural language overlap ok) + deductions.append(f"vec echoes query: {exp[:20]}...") + diversity_score += max(-10, echo) # can go negative + + # --- HyDE (0-20, optional bonus) --- + hyde_score = 0 + if parsed["hyde"]: + hyde_text = parsed["hyde"][0] + hyde_score += 5 + hyde_len = len(hyde_text) + if 50 <= hyde_len <= 200: + hyde_score += 5 + elif hyde_len < 50: + hyde_score += 2 + deductions.append(f"hyde too short ({hyde_len})") + else: + deductions.append(f"hyde too long ({hyde_len})") + if "\n" not in hyde_text: + hyde_score += 5 + hyde_score += max(0, 5 - word_repetition_penalty(hyde_text)) + + # --- Quality (0-20) --- + quality_score = 5 # base relevance + if parsed["lex"] and parsed["vec"]: + avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"]) + avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"]) + if avg_lex <= avg_vec: + quality_score += 5 + else: + deductions.append("lex longer than vec") + if parsed["vec"]: + natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15) + quality_score += 5 if natural == len(parsed["vec"]) else 2 + if parsed["lex"]: + with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query)) + if with_terms == len(parsed["lex"]): + quality_score += 5 + elif with_terms > 0: + quality_score += 2 + else: + deductions.append("lex missing key terms") + 
+ # --- Entity Preservation (-45 to +20) --- + entity_score = 0 + entities = extract_named_entities(query) + if entities and parsed["lex"]: + with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities)) + if with_entities == len(parsed["lex"]): + entity_score += 15 + elif with_entities > 0: + entity_score += 5 + else: + entity_score -= 30 + deductions.append(f"lex missing entities: {entities}") + + generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l)) + if generic_count: + entity_score -= generic_count * 15 + deductions.append(f"{generic_count} generic lex phrases") + + if parsed["vec"]: + vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities)) + if vec_with > 0: + entity_score += 5 + elif not entities: + entity_score = 10 + + # --- Think bonus (0-20): reward NOT using thinking mode --- + think_bonus = 0 if used_thinking else 20 + + # --- Total --- + total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus + max_possible = 140 if parsed["hyde"] else 120 + percentage = max(0.0, min(100.0, total / max_possible * 100)) + + # Hard cap: lex echoes are unacceptable - cap at 50% + if lex_echo_count > 0: + percentage = min(percentage, 50.0) + deductions.insert(0, f"CAPPED: {lex_echo_count} lex echo(es)") + + if percentage >= 80: + rating = "Excellent" + elif percentage >= 60: + rating = "Good" + elif percentage >= 40: + rating = "Acceptable" + elif percentage >= 20: + rating = "Poor" + else: + rating = "Failed" + + return { + "format": format_score, + "diversity": diversity_score, + "hyde": hyde_score, + "quality": quality_score, + "entity": max(0, entity_score), + "think_bonus": think_bonus, + "total": max(0, total), + "max_possible": max_possible, + "percentage": round(percentage, 1), + "rating": rating, + "deductions": deductions, + "parsed": parsed, + "entities_detected": list(entities) if entities else [], + "only_mode": None, + } + + +def score_expansion(query: str, 
expansion: str) -> float: + """Score expansion as a float in [0.0, 1.0] for use as RL reward.""" + result = score_expansion_detailed(query, expansion) + return max(0.0, min(1.0, result["total"] / result["max_possible"])) + + +def extract_query_from_prompt(prompt: str) -> str: + """Extract the query string from a chat-formatted prompt.""" + if "Expand this search query:" in prompt: + query = prompt.split("Expand this search query:")[-1].strip() + if "<|im_end|>" in query: + query = query.split("<|im_end|>")[0].strip() + return query + return prompt.strip() + + +# ============================================================================= +# TRL-compatible reward class +# ============================================================================= + +class QMDRewardFunction: + """Reward function compatible with TRL's GRPOTrainer.""" + __name__ = "qmd_scoring_reward" + + def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]: + rewards = [] + for i, completion in enumerate(completions): + query = "" + if prompts and i < len(prompts): + query = extract_query_from_prompt(prompts[i]) + rewards.append(score_expansion(query, completion)) + return rewards + + +# ============================================================================= +# CLI: run standalone to test the reward function +# ============================================================================= + +if __name__ == "__main__": + print("QMD Reward Function Self-Test") + print("=" * 60) + + tests = [ + ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."), + ("auth", "auth is important for security"), + ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"), + ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"), + ("how to use React hooks", 
"lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"), + ("auth", "Let me think...\nlex: auth"), + ("auth", "lex: auth\nThis is some explanation\nvec: more"), + # "/only:" mode tests (slash prefix) + ("auth /only:lex", "lex: auth setup\nlex: authentication config\nlex: login credentials"), + ("auth /only:lex", "lex: auth setup\nvec: how to configure authentication"), # should fail - has vec + ("React hooks /only:vec", "vec: how to use React hooks in functional components\nvec: useState and useEffect patterns in React"), + ("PostgreSQL indexing /only:hyde", "hyde: PostgreSQL uses B-tree indexes by default. Create indexes with CREATE INDEX idx_name ON table(column). EXPLAIN ANALYZE shows whether queries use indexes efficiently."), + ] + + for query, expansion in tests: + score = score_expansion(query, expansion) + detail = score_expansion_detailed(query, expansion) + only_mode = detail.get("only_mode") + mode_str = f" [only:{only_mode}]" if only_mode else "" + print(f"\n Query: '{query}'{mode_str}") + print(f" Score: {score:.2f} ({detail['rating']})") + if detail["deductions"]: + print(f" Issues: {', '.join(detail['deductions'][:3])}") diff --git a/finetune-mlx/tests/test_format.py b/finetune-mlx/tests/test_format.py new file mode 100644 index 00000000..1fdc3504 --- /dev/null +++ b/finetune-mlx/tests/test_format.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +"""Tests for QMD output format validation.""" + +import pytest +import sys +from pathlib import Path + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from eval import score_expansion + + +class TestOutputFormat: + """Test output format scoring.""" + + def test_perfect_format(self): + """Test scoring of perfect format output.""" + text = """hyde: The Renaissance was a cultural movement in Europe. 
+lex: Renaissance period 14th century +lex: Renaissance art culture +vec: what was the Renaissance period +vec: how did the Renaissance transform art""" + + scores = score_expansion(text) + assert scores["has_lex"] == 1 + assert scores["has_vec"] == 1 + assert scores["has_hyde"] == 1 + assert scores["format_valid"] == 1 + assert scores["total"] >= 100 + + def test_missing_hyde(self): + """Test output missing hyde component.""" + text = """lex: auth config settings +lex: authentication setup +vec: how to configure auth +vec: auth tutorial""" + + scores = score_expansion(text) + assert scores["has_lex"] == 1 + assert scores["has_vec"] == 1 + assert scores["has_hyde"] == 0 + assert scores["format_valid"] == 0 + + def test_missing_lex(self): + """Test output missing lex component.""" + text = """hyde: To configure authentication, use the config file. +vec: how to set up auth +vec: auth configuration guide""" + + scores = score_expansion(text) + assert scores["has_lex"] == 0 + assert scores["has_vec"] == 1 + assert scores["has_hyde"] == 1 + assert scores["format_valid"] == 0 + + def test_missing_vec(self): + """Test output missing vec component.""" + text = """hyde: Database connections require proper configuration. 
+lex: database connection +lex: db config""" + + scores = score_expansion(text) + assert scores["has_lex"] == 1 + assert scores["has_vec"] == 0 + assert scores["has_hyde"] == 1 + assert scores["format_valid"] == 0 + + def test_empty_output(self): + """Test empty output.""" + scores = score_expansion("") + assert scores["format_valid"] == 0 + assert scores["total"] == 0 + + def test_multiple_lex_entries(self): + """Test counting multiple lex entries.""" + text = """hyde: Test +lex: term1 +lex: term2 +lex: term3 +lex: term4 +vec: query""" + + scores = score_expansion(text) + assert scores["lex_count"] == 4 + # Max 3 count toward score + assert scores["total"] == 40 + 30 + 10 + 20 # format + 3*lex + 1*vec + hyde + + def test_multiple_vec_entries(self): + """Test counting multiple vec entries.""" + text = """hyde: Test +lex: term +vec: query1 +vec: query2 +vec: query3 +vec: query4""" + + scores = score_expansion(text) + assert scores["vec_count"] == 4 + # Max 3 count toward score + assert scores["total"] == 40 + 10 + 30 + 20 # format + 1*lex + 3*vec + hyde + + +class TestEdgeCases: + """Test edge cases.""" + + def test_whitespace_handling(self): + """Test handling of extra whitespace.""" + text = """ hyde: Some text + lex: keyword + vec: natural question """ + + scores = score_expansion(text) + assert scores["has_lex"] == 1 + assert scores["has_vec"] == 1 + assert scores["has_hyde"] == 1 + + def test_case_sensitivity(self): + """Test that prefixes are case sensitive (lowercase required).""" + text = """HYDE: Some text +LEX: keyword +VEC: question""" + + scores = score_expansion(text) + # Uppercase should not match + assert scores["has_lex"] == 0 + assert scores["has_vec"] == 0 + assert scores["has_hyde"] == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/finetune-mlx/tests/test_model.py b/finetune-mlx/tests/test_model.py new file mode 100644 index 00000000..efb245b8 --- /dev/null +++ b/finetune-mlx/tests/test_model.py @@ -0,0 +1,92 @@ 
+#!/usr/bin/env python3 +"""Tests for model loading and inference.""" + +import pytest +import sys +from pathlib import Path + +# Add parent to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +# These tests require the model to be downloaded +# Mark as slow/integration tests + +@pytest.mark.slow +class TestModelLoading: + """Test model loading functionality.""" + + def test_base_model_exists(self): + """Check if base model directory exists.""" + model_path = Path(__file__).parent.parent / "models" / "Qwen_Qwen2.5-1.5B" / "mlx" + assert model_path.exists(), f"Base model not found at {model_path}" + + def test_adapter_exists(self): + """Check if trained adapter exists.""" + adapter_path = Path(__file__).parent.parent / "adapters" / "sft-full" + assert adapter_path.exists(), f"Adapter not found at {adapter_path}" + + # Check required files + assert (adapter_path / "adapters.safetensors").exists() + assert (adapter_path / "adapter_config.json").exists() + + def test_adapter_config_valid(self): + """Validate adapter config JSON.""" + import json + + config_path = Path(__file__).parent.parent / "adapters" / "sft-full" / "adapter_config.json" + with open(config_path) as f: + config = json.load(f) + + # Check required fields + assert "lora_layers" in config + assert "rank" in config or "r" in config + + +@pytest.mark.slow +class TestInference: + """Test model inference.""" + + @pytest.fixture + def model_and_tokenizer(self): + """Load model with adapter.""" + from mlx_lm import load + + model_path = Path(__file__).parent.parent / "models" / "Qwen_Qwen2.5-1.5B" / "mlx" + adapter_path = Path(__file__).parent.parent / "adapters" / "sft-full" + + model, tokenizer = load(str(model_path), adapter_path=str(adapter_path)) + return model, tokenizer + + def test_generate_expansion(self, model_and_tokenizer): + """Test generating a query expansion.""" + from mlx_lm import generate + + model, tokenizer = model_and_tokenizer + + prompt = """<|im_start|>user +/no_think 
Expand this search query: test query<|im_end|> +<|im_start|>assistant +""" + + response = generate( + model, + tokenizer, + prompt=prompt, + max_tokens=128, + verbose=False, + ) + + assert len(response) > 0 + # Should contain at least some expected prefixes + response_lower = response.lower() + has_valid_output = ( + "lex:" in response_lower or + "vec:" in response_lower or + "hyde:" in response_lower + ) + assert has_valid_output, f"Response missing expected format: {response}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/finetune-mlx/train.py b/finetune-mlx/train.py new file mode 100644 index 00000000..36f900bf --- /dev/null +++ b/finetune-mlx/train.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +""" +QMD Query Expansion Training - Apple Silicon Edition + +Uses MLX for efficient fine-tuning on Metal GPUs. + +Usage: + python train.py sft # Supervised fine-tuning + python train.py grpo # RL refinement (after SFT) + python train.py sft --config configs/sft.yaml +""" + +import argparse +import json +import os +import subprocess +import sys +from pathlib import Path + +import yaml +from datasets import load_dataset +from huggingface_hub import snapshot_download + + +def prepare_dataset(config: dict, stage: str) -> Path: + """Download and prepare dataset in MLX format. 
+ + Supports both: + - Local path: config.dataset.path = "data/custom" + - HuggingFace: config.dataset.name = "tobil/qmd-query-expansion-train-v2" + """ + dataset_config = config.get("dataset", {}) + + # Check for local path first + local_path = dataset_config.get("path") + if local_path: + local_dir = Path(local_path) + train_file = local_dir / "train.jsonl" + valid_file = local_dir / "valid.jsonl" + + if train_file.exists() and valid_file.exists(): + # Count examples + train_count = sum(1 for _ in open(train_file)) + valid_count = sum(1 for _ in open(valid_file)) + print(f"📦 Using local dataset: {local_dir}") + print(f"✅ Dataset ready: {train_count} train, {valid_count} valid") + return local_dir + else: + print(f"❌ Local dataset not found at {local_dir}") + print(f" Expected: {train_file} and {valid_file}") + print(f" Run: python dataset/generate_from_notes.py && python dataset/prepare_data.py") + sys.exit(1) + + # Fall back to HuggingFace dataset + dataset_name = dataset_config.get("name", "tobil/qmd-query-expansion-train-v2") + output_dir = Path("data") / stage + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"📦 Loading dataset from HuggingFace: {dataset_name}") + ds = load_dataset(dataset_name, split="train") + + # Split into train/valid + eval_ratio = dataset_config.get("eval_split", 0.1) + split = ds.train_test_split(test_size=eval_ratio, seed=42) + + # Convert to MLX format (JSONL with "text" field) + train_file = output_dir / "train.jsonl" + valid_file = output_dir / "valid.jsonl" + + text_field = dataset_config.get("text_field", "text") + + train_count = 0 + with open(train_file, "w") as f: + for item in split["train"]: + text = item[text_field] + if text and isinstance(text, str) and len(text) > 10: + f.write(json.dumps({"text": text}) + "\n") + train_count += 1 + + valid_count = 0 + with open(valid_file, "w") as f: + for item in split["test"]: + text = item[text_field] + if text and isinstance(text, str) and len(text) > 10: + 
f.write(json.dumps({"text": text}) + "\n") + valid_count += 1 + + print(f"✅ Dataset prepared: {train_count} train, {valid_count} valid (filtered nulls)") + return output_dir + + +def download_model(model_name: str) -> Path: + """Download model and convert to MLX format if needed.""" + print(f"📥 Downloading model: {model_name}") + + # Check if already converted to MLX + mlx_model_dir = Path("models") / model_name.replace("/", "_") / "mlx" + + if mlx_model_dir.exists(): + print(f"✅ Model already exists: {mlx_model_dir}") + return mlx_model_dir + + # Download and convert + mlx_model_dir.parent.mkdir(parents=True, exist_ok=True) + + # Use mlx_lm.convert to download and convert + cmd = [ + sys.executable, "-m", "mlx_lm.convert", + "--hf-path", model_name, + "--mlx-path", str(mlx_model_dir), + "-q" # Quantize for memory efficiency + ] + + print(f"🔄 Converting to MLX format...") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + print(f"❌ Conversion failed: {result.stderr}") + # Try without quantization + cmd.remove("-q") + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Model conversion failed: {result.stderr}") + + print(f"✅ Model ready: {mlx_model_dir}") + return mlx_model_dir + + +def train_sft(config: dict): + """Run supervised fine-tuning with LoRA.""" + print("\n🚀 Starting SFT Training\n") + + model_name = config.get("model", {}).get("base", "Qwen/Qwen2.5-1.5B") + output_name = config.get("model", {}).get("output", "qmd-expansion-sft") + + # Prepare data + data_dir = prepare_dataset(config, "sft") + + # Download/convert model + model_dir = download_model(model_name) + + # Training params + training = config.get("training", {}) + lora = config.get("lora", {}) + + # Use output name for adapter path (allows multiple trained models) + adapter_name = output_name.replace("/", "_").replace("-", "_") + adapter_path = Path("adapters") / adapter_name + 
adapter_path.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, "-m", "mlx_lm", "lora", + "--model", str(model_dir), + "--train", + "--data", str(data_dir), + "--adapter-path", str(adapter_path), + "--batch-size", str(training.get("batch_size", 4)), + "--iters", str(training.get("iters", 1000)), + "--learning-rate", str(training.get("learning_rate", 1e-4)), + "--num-layers", str(lora.get("num_layers", 16)), + "--steps-per-report", "10", + "--steps-per-eval", "100", + "--save-every", "200", + ] + + max_length = training.get("max_length", 512) + if max_length: + cmd.extend(["--max-seq-length", str(max_length)]) + + print(f"📝 Command: {' '.join(cmd)}\n") + + # Run training + result = subprocess.run(cmd) + + if result.returncode != 0: + print("❌ Training failed") + sys.exit(1) + + print(f"\n✅ SFT complete! Adapter saved to: {adapter_path}") + return adapter_path + + +def train_grpo(config: dict): + """Run GRPO (RL refinement) on top of SFT model.""" + print("\n🚀 Starting GRPO Training\n") + + # Check SFT adapter exists + sft_adapter = Path("adapters") / "sft" + if not sft_adapter.exists(): + print("❌ SFT adapter not found. Run 'python train.py sft' first.") + sys.exit(1) + + # For GRPO, we need to implement reward-based training + # MLX doesn't have built-in GRPO, so we'll use a simpler approach: + # Continue SFT with filtered high-quality examples + + print("⚠️ GRPO not yet implemented for MLX.") + print(" The SFT model should work well for most use cases.") + print(" For GRPO, use HuggingFace Jobs with the original scripts.") + + # TODO: Implement reward model + PPO-style training + # This would require: + # 1. Generate expansions from SFT model + # 2. Score with reward function + # 3. Filter top examples + # 4. 
Continue training on high-reward examples

    return sft_adapter


def load_config(config_path: str) -> dict:
    """Load YAML config file.

    Returns whatever ``yaml.safe_load`` produces — callers assume a dict
    (they use ``config.get(...)`` throughout).
    """
    with open(config_path) as f:
        return yaml.safe_load(f)


def main():
    """CLI entry point: parse args, resolve config, dispatch to a stage."""
    parser = argparse.ArgumentParser(description="QMD Query Expansion Training")
    parser.add_argument("stage", choices=["sft", "grpo"], help="Training stage")
    parser.add_argument("--config", "-c", help="Config file path")
    # NOTE(review): --dry-run is parsed but never used below — confirm intent
    # or wire it through to the training functions.
    parser.add_argument("--dry-run", action="store_true", help="Print commands without running")

    args = parser.parse_args()

    # Load config
    if args.config:
        config = load_config(args.config)
    else:
        # Default config: configs/<stage>.yaml if present, else an empty
        # dict (every consumer falls back to built-in defaults via .get()).
        config_path = Path("configs") / f"{args.stage}.yaml"
        if config_path.exists():
            config = load_config(str(config_path))
        else:
            config = {}

    # Run training
    if args.stage == "sft":
        train_sft(config)
    elif args.stage == "grpo":
        train_grpo(config)


if __name__ == "__main__":
    main()
diff --git a/scripts/mlx_expand.py b/scripts/mlx_expand.py
new file mode 100755
index 00000000..81d85b0d
--- /dev/null
+++ b/scripts/mlx_expand.py
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""QMD MLX Query Expansion (sidecar).

⚠️ EXPERIMENTAL / DEV TOOL
This script is a bridge for iterating on fine-tuned MLX adapters locally.
It requires Apple Silicon + mlx_lm. If unavailable or failing, QMD falls back
to the default llama.cpp-based query expansion.

Outputs query expansions in the exact line-oriented format expected by QMD:
  lex: ...
  vec: ...
  hyde: ...

Environment variables:
  QMD_MLX_MODEL       Base MLX model directory (default: finetune-mlx/models/Qwen_Qwen3-1.7B/mlx)
  QMD_MLX_ADAPTER     LoRA adapter directory (default: finetune-mlx/adapters/qmd_query_expansion_1.7B_sft)
  QMD_MLX_TEMP        Temperature (default: 1.0)
  QMD_MLX_MAX_TOKENS  Max tokens (default: 512)

Usage:
  ./scripts/mlx_expand.py "auth config"
  echo "auth config" | ./scripts/mlx_expand.py
"""

import os
import sys

from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

# Qwen chat-template prompt; /no_think presumably suppresses Qwen3's
# "thinking" output and matches the fine-tuning prompt format — TODO confirm
# against the training data builder.
PROMPT_TEMPLATE = """<|im_start|>user
/no_think Expand this search query: {query}<|im_end|>
<|im_start|>assistant
"""


def _read_query(argv):
    """Return the query: CLI args joined by spaces, else stdin (stripped)."""
    if len(argv) > 1:
        return " ".join(argv[1:]).strip()
    data = sys.stdin.read().strip()
    return data


def main():
    """Expand the query and print lex:/vec:/hyde: lines to stdout."""
    query = _read_query(sys.argv)
    if not query:
        # Empty query -> empty expansion; exit 0 so callers treat this as
        # "nothing to add" rather than a failure.
        print("", end="")
        return 0

    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

    # Model/adapter locations are overridable via environment (see module
    # docstring); defaults point at the finetune-mlx training outputs.
    model = os.environ.get(
        "QMD_MLX_MODEL",
        os.path.join(repo_root, "finetune-mlx", "models", "Qwen_Qwen3-1.7B", "mlx"),
    )
    adapter = os.environ.get(
        "QMD_MLX_ADAPTER",
        os.path.join(repo_root, "finetune-mlx", "adapters", "qmd_query_expansion_1.7B_sft"),
    )

    temp = float(os.environ.get("QMD_MLX_TEMP", "1.0"))
    max_tokens = int(os.environ.get("QMD_MLX_MAX_TOKENS", "512"))

    prompt = PROMPT_TEMPLATE.format(query=query)

    # Load base model + adapter if available. Any adapter-load failure
    # (missing dir, incompatible weights) falls back to the bare base model.
    try:
        m, tok = load(model, adapter_path=adapter)
    except Exception:
        m, tok = load(model)

    sampler = make_sampler(temp=temp)
    out = generate(
        m,
        tok,
        prompt=prompt,
        max_tokens=max_tokens,
        sampler=sampler,
        verbose=False,
    )

    # Clean special tokens and keep only line-based output.
+ out = out.replace("<|im_end|>", "").strip() + lines = [ln.strip() for ln in out.splitlines() if ln.strip()] + + # Keep only lex/vec/hyde lines (defensive) + filtered = [] + for ln in lines: + if ln.startswith("lex:") or ln.startswith("vec:") or ln.startswith("hyde:"): + filtered.append(ln) + + sys.stdout.write("\n".join(filtered).strip()) + if filtered: + sys.stdout.write("\n") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())