"""Evaluation script for retrieval quality."""

import json
import sys
from pathlib import Path

from knowcode.chunk_repository import InMemoryChunkRepository
from knowcode.vector_store import VectorStore
from knowcode.hybrid_index import HybridIndex
from knowcode.embedding import OpenAIEmbeddingProvider
from knowcode.models import EmbeddingConfig, CodeChunk


def evaluate(ground_truth_path: Path, index_path: Path) -> dict:
    """Evaluate retrieval quality against ground truth."""
    if not ground_truth_path.exists():
        return {"error": "Ground truth file not found"}

    with open(ground_truth_path) as f:
        ground_truth = json.load(f)

    # Load the index components. index_path is assumed to be the directory read
    # by Indexer.load, i.e. it contains chunks.json plus a "vectors" subdirectory.
    repo = InMemoryChunkRepository()

    chunks_file = index_path / "chunks.json"
    if chunks_file.exists():
        with open(chunks_file) as f:
            data = json.load(f)
        for c_data in data["chunks"]:
            repo.add(CodeChunk(**c_data))

    vs = VectorStore(dimension=1536, index_path=index_path / "vectors")

    # Queries must be embedded with the same provider (and dimension) that was
    # used at indexing time; here we assume OpenAI and skip evaluation when no
    # API key is available.
    try:
        provider = OpenAIEmbeddingProvider(EmbeddingConfig())
    except Exception:
        print("Skipping evaluation: no OpenAI API key found")
        return {"error": "No embedding provider available"}

    hybrid = HybridIndex(repo, vs)

    # Metrics
    hits_at_5 = 0
    hits_at_10 = 0
    mrr_sum = 0.0
    evaluated = 0  # ground-truth entries actually scored
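
    # Metric definitions (standard IR measures):
    #   hit rate @ k: fraction of queries with at least one expected chunk in the top k
    #   MRR: mean over queries of 1 / rank of the first expected chunk (0 if none is found)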

    for item in ground_truth:
        query = item.get("query")
        expected_ids = set(item.get("expected_ids", []))

        if not query or not expected_ids:
            continue
        evaluated += 1

        q_vec = provider.embed_single(query)
        # Search the hybrid index directly (bypassing the SearchEngine wrapper)
        # so that raw retrieval quality is measured.
        results = hybrid.search(query, q_vec, limit=10)
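        # HybridIndex.search is assumed to return (chunk, score) pairs in
        # decreasing order of relevance; only the chunk ids are needed here.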

        found_ids = [c.id for c, _ in results]

        # Hit rate @ k: at least one expected chunk appears in the top k results.
        if any(fid in expected_ids for fid in found_ids[:5]):
            hits_at_5 += 1
        if any(fid in expected_ids for fid in found_ids[:10]):
            hits_at_10 += 1

        # MRR
        rank = 0
        for i, fid in enumerate(found_ids):
            if fid in expected_ids:
                rank = i + 1
                break
        if rank > 0:
            mrr_sum += 1.0 / rank

    return {
        "hit_rate_at_5": hits_at_5 / evaluated if evaluated else 0,
        "hit_rate_at_10": hits_at_10 / evaluated if evaluated else 0,
        "mrr": mrr_sum / evaluated if evaluated else 0,
    }


if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Usage: python evaluate.py <ground_truth.json> <index_dir>")
        sys.exit(1)

    gt_path = Path(sys.argv[1])
    idx_path = Path(sys.argv[2])
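    # The printed dict holds hit_rate_at_5, hit_rate_at_10 and mrr, each averaged
    # over the ground-truth queries that were actually evaluated.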
    print(evaluate(gt_path, idx_path))