From 78587ec6f9001fdd6c0143cdf6dd6dca5c5ff2ce Mon Sep 17 00:00:00 2001 From: 1alyx Date: Wed, 18 Feb 2026 16:44:43 -0500 Subject: [PATCH] Fix multi-collection filtering in search/query/vsearch Co-authored-by: Alyx --- src/qmd.ts | 45 ++-------- src/store.ts | 88 ++++++++++++++----- test/cli.test.ts | 95 ++++++++++++++++++++- test/query-collection-routing.test.ts | 69 +++++++++++++++ test/searchvec-multi-collection.test.ts | 107 ++++++++++++++++++++++++ test/store.test.ts | 33 ++++++++ 6 files changed, 379 insertions(+), 58 deletions(-) create mode 100644 test/query-collection-routing.test.ts create mode 100644 test/searchvec-multi-collection.test.ts diff --git a/src/qmd.ts b/src/qmd.ts index 1a15ab90..fc831b53 100755 --- a/src/qmd.ts +++ b/src/qmd.ts @@ -1929,29 +1929,16 @@ function resolveCollectionFilter(raw: string | string[] | undefined): string[] { return validated; } -// Post-filter results to only include files from specified collections. -function filterByCollections(results: T[], collectionNames: string[]): T[] { - if (collectionNames.length <= 1) return results; - const prefixes = collectionNames.map(n => `qmd://${n}/`); - return results.filter(r => { - const path = r.filepath || r.file || ''; - return prefixes.some(p => path.startsWith(p)); - }); -} - function search(query: string, opts: OutputOptions): void { const db = getDb(); // Validate collection filter (supports multiple -c flags) const collectionNames = resolveCollectionFilter(opts.collection); - const singleCollection = collectionNames.length === 1 ? collectionNames[0] : undefined; + const collectionFilter = collectionNames.length > 0 ? collectionNames : undefined; // Use large limit for --all, otherwise fetch more than needed and let outputResults filter const fetchLimit = opts.all ? 100000 : Math.max(50, opts.limit * 2); - const results = filterByCollections( - searchFTS(db, query, fetchLimit, singleCollection), - collectionNames - ); + const results = searchFTS(db, query, fetchLimit, collectionFilter); // Add context to results const resultsWithContext = results.map(r => ({ @@ -1998,13 +1985,13 @@ async function vectorSearch(query: string, opts: OutputOptions, _model: string = // Validate collection filter (supports multiple -c flags) const collectionNames = resolveCollectionFilter(opts.collection); - const singleCollection = collectionNames.length === 1 ? collectionNames[0] : undefined; + const collectionFilter = collectionNames.length > 0 ? collectionNames : undefined; checkIndexHealth(store.db); await withLLMSession(async () => { - let results = await vectorSearchQuery(store, query, { - collection: singleCollection, + const results = await vectorSearchQuery(store, query, { + collection: collectionFilter, limit: opts.all ? 500 : (opts.limit || 10), minScore: opts.minScore || 0.3, hooks: { @@ -2015,14 +2002,6 @@ async function vectorSearch(query: string, opts: OutputOptions, _model: string = }, }); - // Post-filter for multi-collection - if (collectionNames.length > 1) { - results = results.filter(r => { - const prefixes = collectionNames.map(n => `qmd://${n}/`); - return prefixes.some(p => r.file.startsWith(p)); - }); - } - closeDb(); if (results.length === 0) { @@ -2051,13 +2030,13 @@ async function querySearch(query: string, opts: OutputOptions, _embedModel: stri // Validate collection filter (supports multiple -c flags) const collectionNames = resolveCollectionFilter(opts.collection); - const singleCollection = collectionNames.length === 1 ? collectionNames[0] : undefined; + const collectionFilter = collectionNames.length > 0 ? collectionNames : undefined; checkIndexHealth(store.db); await withLLMSession(async () => { - let results = await hybridQuery(store, query, { - collection: singleCollection, + const results = await hybridQuery(store, query, { + collection: collectionFilter, limit: opts.all ? 500 : (opts.limit || 10), minScore: opts.minScore || 0, hooks: { @@ -2078,14 +2057,6 @@ async function querySearch(query: string, opts: OutputOptions, _embedModel: stri }, }); - // Post-filter for multi-collection - if (collectionNames.length > 1) { - results = results.filter(r => { - const prefixes = collectionNames.map(n => `qmd://${n}/`); - return prefixes.some(p => r.file.startsWith(p)); - }); - } - closeDb(); if (results.length === 0) { diff --git a/src/store.ts b/src/store.ts index b68f8c0b..aeb5abb0 100644 --- a/src/store.ts +++ b/src/store.ts @@ -805,8 +805,8 @@ export type Store = { toVirtualPath: (absolutePath: string) => string | null; // Search - searchFTS: (query: string, limit?: number, collectionName?: string) => SearchResult[]; - searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]) => Promise; + searchFTS: (query: string, limit?: number, collectionName?: string | string[]) => SearchResult[]; + searchVec: (query: string, model: string, limit?: number, collectionName?: string | string[], session?: ILLMSession, precomputedEmbedding?: number[]) => Promise; // Query expansion & reranking expandQuery: (query: string, model?: string) => Promise; @@ -888,8 +888,8 @@ export function createStore(dbPath?: string): Store { toVirtualPath: (absolutePath: string) => toVirtualPath(db, absolutePath), // Search - searchFTS: (query: string, limit?: number, collectionName?: string) => searchFTS(db, query, limit, collectionName), - searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]) => searchVec(db, query, model, limit, collectionName, session, precomputedEmbedding), + searchFTS: (query: string, limit?: number, collectionName?: string | string[]) => searchFTS(db, query, limit, collectionName), + searchVec: (query: string, model: string, limit?: number, collectionName?: string | string[], session?: ILLMSession, precomputedEmbedding?: number[]) => searchVec(db, query, model, limit, collectionName, session, precomputedEmbedding), // Query expansion & reranking expandQuery: (query: string, model?: string) => expandQuery(query, model, db), @@ -1996,10 +1996,37 @@ function buildFTS5Query(query: string): string | null { return terms.map(t => `"${t}"*`).join(' AND '); } -export function searchFTS(db: Database, query: string, limit: number = 20, collectionName?: string): SearchResult[] { +function normalizeCollectionFilter(collectionName?: string | string[]): string[] { + if (!collectionName) return []; + const names = Array.isArray(collectionName) ? collectionName : [collectionName]; + return names.filter((name): name is string => typeof name === "string" && name.length > 0); +} + +function appendCollectionFilterSql( + sql: string, + params: (string | number | Float32Array)[], + collectionNames: string[], +): string { + if (collectionNames.length === 0) return sql; + + if (collectionNames.length === 1) { + sql += ` AND d.collection = ?`; + params.push(collectionNames[0]!); + return sql; + } + + const placeholders = collectionNames.map(() => "?").join(","); + sql += ` AND d.collection IN (${placeholders})`; + params.push(...collectionNames); + return sql; +} + +export function searchFTS(db: Database, query: string, limit: number = 20, collectionName?: string | string[]): SearchResult[] { const ftsQuery = buildFTS5Query(query); if (!ftsQuery) return []; + const collectionNames = normalizeCollectionFilter(collectionName); + let sql = ` SELECT 'qmd://' || d.collection || '/' || d.path as filepath, @@ -2013,12 +2040,9 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle JOIN content ON content.hash = d.hash WHERE documents_fts MATCH ? AND d.active = 1 `; - const params: (string | number)[] = [ftsQuery]; + const params: (string | number | Float32Array)[] = [ftsQuery]; - if (collectionName) { - sql += ` AND d.collection = ?`; - params.push(String(collectionName)); - } + sql = appendCollectionFilterSql(sql, params, collectionNames); // bm25 lower is better; sort ascending. sql += ` ORDER BY bm25_score ASC LIMIT ?`; @@ -2053,24 +2077,51 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle // Vector Search // ============================================================================= -export async function searchVec(db: Database, query: string, model: string, limit: number = 20, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]): Promise { +export async function searchVec(db: Database, query: string, model: string, limit: number = 20, collectionName?: string | string[], session?: ILLMSession, precomputedEmbedding?: number[]): Promise { const tableExists = db.prepare(`SELECT name FROM sqlite_master WHERE type='table' AND name='vectors_vec'`).get(); if (!tableExists) return []; const embedding = precomputedEmbedding ?? await getEmbedding(query, model, true, session); if (!embedding) return []; + const collectionNames = normalizeCollectionFilter(collectionName); + // IMPORTANT: We use a two-step query approach here because sqlite-vec virtual tables // hang indefinitely when combined with JOINs in the same query. Do NOT try to // "optimize" this by combining into a single query with JOINs - it will break. // See: https://github.com/tobi/qmd/pull/23 + // Precompute eligible chunk ids for collection filtering so the top-k vector query + // itself is collection-scoped (avoids false-empty post-filter behavior). + let eligibleHashSeqs: string[] | null = null; + if (collectionNames.length > 0) { + const collectionPlaceholders = collectionNames.map(() => '?').join(','); + const eligibleRows = db.prepare(` + SELECT DISTINCT cv.hash || '_' || cv.seq as hash_seq + FROM content_vectors cv + JOIN documents d ON d.hash = cv.hash + WHERE d.active = 1 AND d.collection IN (${collectionPlaceholders}) + `).all(...collectionNames) as { hash_seq: string }[]; + + eligibleHashSeqs = eligibleRows.map(r => r.hash_seq); + if (eligibleHashSeqs.length === 0) return []; + } + // Step 1: Get vector matches from sqlite-vec (no JOINs allowed) - const vecResults = db.prepare(` + let vecSql = ` SELECT hash_seq, distance FROM vectors_vec WHERE embedding MATCH ? AND k = ? - `).all(new Float32Array(embedding), limit * 3) as { hash_seq: string; distance: number }[]; + `; + const vecParams: (Float32Array | number | string)[] = [new Float32Array(embedding), limit * 3]; + + if (eligibleHashSeqs && eligibleHashSeqs.length > 0) { + const hashSeqPlaceholders = eligibleHashSeqs.map(() => '?').join(','); + vecSql += ` AND hash_seq IN (${hashSeqPlaceholders})`; + vecParams.push(...eligibleHashSeqs); + } + + const vecResults = db.prepare(vecSql).all(...vecParams) as { hash_seq: string; distance: number }[]; if (vecResults.length === 0) return []; @@ -2094,12 +2145,9 @@ export async function searchVec(db: Database, query: string, model: string, limi JOIN content ON content.hash = d.hash WHERE cv.hash || '_' || cv.seq IN (${placeholders}) `; - const params: string[] = [...hashSeqs]; + const params: (string | number | Float32Array)[] = [...hashSeqs]; - if (collectionName) { - docSql += ` AND d.collection = ?`; - params.push(collectionName); - } + docSql = appendCollectionFilterSql(docSql, params, collectionNames); const docRows = db.prepare(docSql).all(...params) as { hash_seq: string; hash: string; pos: number; filepath: string; @@ -2771,7 +2819,7 @@ export interface SearchHooks { } export interface HybridQueryOptions { - collection?: string; + collection?: string | string[]; limit?: number; // default 10 minScore?: number; // default 0 candidateLimit?: number; // default RERANK_CANDIDATE_LIMIT @@ -2985,7 +3033,7 @@ export async function hybridQuery( } export interface VectorSearchOptions { - collection?: string; + collection?: string | string[]; limit?: number; // default 10 minScore?: number; // default 0.3 hooks?: Pick; diff --git a/test/cli.test.ts b/test/cli.test.ts index b723e7de..15c08245 100644 --- a/test/cli.test.ts +++ b/test/cli.test.ts @@ -5,7 +5,7 @@ * These tests spawn actual qmd processes to verify end-to-end functionality. */ -import { describe, test, expect, beforeAll, afterAll, beforeEach } from "vitest"; +import { describe, test, expect, beforeAll, afterAll, beforeEach, afterEach } from "vitest"; import { mkdtemp, rm, writeFile, mkdir } from "fs/promises"; import { existsSync } from "fs"; import { tmpdir } from "os"; @@ -525,6 +525,99 @@ describe("CLI Search with Collection Filter", () => { }); }); +describe("CLI Multi-Collection Filter Regression", () => { + let localDbPath: string; + let localConfigDir: string; + let localRoot: string; + + beforeEach(async () => { + const env = await createIsolatedTestEnv("multi-collection"); + localDbPath = env.dbPath; + localConfigDir = env.configDir; + + localRoot = await mkdtemp(join(testDir, "multi-collection-fixtures-")); + const noisyDir = join(localRoot, "noisy"); + const targetADir = join(localRoot, "target-a"); + const targetBDir = join(localRoot, "target-b"); + + await mkdir(noisyDir, { recursive: true }); + await mkdir(targetADir, { recursive: true }); + await mkdir(targetBDir, { recursive: true }); + + const noisyBody = Array.from({ length: 40 }, () => "dominator dominator dominator").join("\n"); + for (let i = 0; i < 80; i++) { + await writeFile( + join(noisyDir, `noise-${i}.md`), + `# dominator dominator ${i}\n\n${noisyBody}\n` + ); + } + + await writeFile( + join(targetADir, "hit-a.md"), + "# A target document\n\nContains dominator once." + ); + await writeFile( + join(targetBDir, "hit-b.md"), + "# B target document\n\nContains dominator once." + ); + + await runQmd(["collection", "add", noisyDir, "--name", "noisy", "--mask", "**/*.md"], { + cwd: localRoot, + dbPath: localDbPath, + configDir: localConfigDir, + }); + await runQmd(["collection", "add", targetADir, "--name", "target-a", "--mask", "**/*.md"], { + cwd: localRoot, + dbPath: localDbPath, + configDir: localConfigDir, + }); + await runQmd(["collection", "add", targetBDir, "--name", "target-b", "--mask", "**/*.md"], { + cwd: localRoot, + dbPath: localDbPath, + configDir: localConfigDir, + }); + }); + + afterEach(async () => { + if (localRoot) { + await rm(localRoot, { recursive: true, force: true }); + } + }); + + test("search with repeated -c filters before top-k ranking", async () => { + const { stdout, stderr, exitCode } = await runQmd([ + "search", + "--json", + "-n", + "10", + "-c", + "target-a", + "-c", + "target-b", + "dominator", + ], { + cwd: localRoot, + dbPath: localDbPath, + configDir: localConfigDir, + }); + + if (exitCode !== 0) { + console.log("Multi-collection search failed:"); + console.log("stdout:", stdout); + console.log("stderr:", stderr); + } + + expect(exitCode).toBe(0); + + const parsed = JSON.parse(stdout) as Array<{ file: string }>; + const files = parsed.map(row => row.file); + + expect(files.some(f => f.startsWith("qmd://target-a/"))).toBe(true); + expect(files.some(f => f.startsWith("qmd://target-b/"))).toBe(true); + expect(files.some(f => f.startsWith("qmd://noisy/"))).toBe(false); + }); +}); + describe("CLI Context Management", () => { let localDbPath: string; diff --git a/test/query-collection-routing.test.ts b/test/query-collection-routing.test.ts new file mode 100644 index 00000000..29312e39 --- /dev/null +++ b/test/query-collection-routing.test.ts @@ -0,0 +1,69 @@ +import { describe, test, expect, vi } from "vitest"; +import { + hybridQuery, + vectorSearchQuery, + type Store, + type SearchResult, +} from "../src/store.js"; + +describe("multi-collection routing in query pipelines", () => { + test("hybridQuery passes array collection filters into FTS", async () => { + const searchFTS = vi.fn().mockReturnValue([] as SearchResult[]); + + const store = { + db: { + prepare: () => ({ get: () => undefined }), // no vectors table + }, + searchFTS, + expandQuery: vi.fn().mockResolvedValue([]), + } as unknown as Store; + + await hybridQuery(store, "dominator", { + collection: ["target-a", "target-b"], + limit: 10, + }); + + expect(searchFTS).toHaveBeenCalledWith("dominator", 20, ["target-a", "target-b"]); + }); + + test("vectorSearchQuery passes array collection filters into vector search", async () => { + const vectorResult: SearchResult = { + filepath: "qmd://target-a/hit-a.md", + displayPath: "target-a/hit-a.md", + title: "Hit A", + hash: "abcdef123456", + docid: "abcdef", + collectionName: "target-a", + modifiedAt: "", + bodyLength: 10, + body: "dominator", + context: null, + score: 0.9, + source: "vec", + }; + + const searchVec = vi.fn().mockResolvedValue([vectorResult]); + + const store = { + db: { + prepare: () => ({ get: () => ({ name: "vectors_vec" }) }), + }, + expandQuery: vi.fn().mockResolvedValue([]), + searchVec, + getContextForFile: vi.fn().mockReturnValue(null), + } as unknown as Store; + + await vectorSearchQuery(store, "dominator", { + collection: ["target-a", "target-b"], + limit: 7, + minScore: 0, + }); + + expect(searchVec).toHaveBeenCalledWith( + "dominator", + expect.any(String), + 7, + ["target-a", "target-b"], + ); + }); +}); diff --git a/test/searchvec-multi-collection.test.ts b/test/searchvec-multi-collection.test.ts new file mode 100644 index 00000000..928a8441 --- /dev/null +++ b/test/searchvec-multi-collection.test.ts @@ -0,0 +1,107 @@ +import { describe, test, expect, beforeEach, afterEach } from "vitest"; +import { mkdtemp, mkdir, rm, writeFile } from "node:fs/promises"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import YAML from "yaml"; +import { createStore, hashContent, isSqliteVecAvailable, type Store } from "../src/store.js"; + +let testDir: string; +let configDir: string; +let store: Store; +let hasSqliteVec = false; + +async function insertVectorDoc( + store: Store, + collection: string, + path: string, + body: string, + embedding: number[], +): Promise { + const now = new Date().toISOString(); + const hash = await hashContent(`${collection}:${path}:${body}`); + + store.db.prepare(` + INSERT INTO content (hash, doc, created_at) + VALUES (?, ?, ?) + `).run(hash, body, now); + + store.db.prepare(` + INSERT INTO documents (collection, path, title, hash, created_at, modified_at, active) + VALUES (?, ?, ?, ?, ?, ?, 1) + `).run(collection, path, path.replace(/\.md$/, ""), hash, now, now); + + store.db.prepare(` + INSERT INTO content_vectors (hash, seq, pos, model, embedded_at) + VALUES (?, 0, 0, 'test', ?) + `).run(hash, now); + + store.db.prepare(` + INSERT INTO vectors_vec (hash_seq, embedding) + VALUES (?, ?) + `).run(`${hash}_0`, new Float32Array(embedding)); +} + +describe("searchVec multi-collection filtering", () => { + beforeEach(async () => { + testDir = await mkdtemp(join(tmpdir(), "qmd-searchvec-multi-")); + configDir = join(testDir, "config"); + await mkdir(configDir, { recursive: true }); + + await writeFile( + join(configDir, "index.yml"), + YAML.stringify({ + collections: { + "target-a": { path: "/tmp/target-a", pattern: "**/*.md" }, + "target-b": { path: "/tmp/target-b", pattern: "**/*.md" }, + noisy: { path: "/tmp/noisy", pattern: "**/*.md" }, + }, + }) + ); + + process.env.QMD_CONFIG_DIR = configDir; + store = createStore(join(testDir, "index.sqlite")); + hasSqliteVec = isSqliteVecAvailable(); + if (hasSqliteVec) { + store.ensureVecTable(3); + } + }); + + afterEach(async () => { + if (store) { + store.close(); + } + delete process.env.QMD_CONFIG_DIR; + if (testDir) { + await rm(testDir, { recursive: true, force: true }); + } + }); + + test("supports collection IN filters for repeated -c semantics", async () => { + if (!hasSqliteVec) { + return; + } + + await insertVectorDoc(store, "target-a", "a.md", "dominator target a", [0.99, 0.01, 0]); + await insertVectorDoc(store, "target-b", "b.md", "dominator target b", [0.98, 0.02, 0]); + + // Create many highly similar vectors in an unrequested collection to mimic top-k domination. + for (let i = 0; i < 60; i++) { + await insertVectorDoc(store, "noisy", `noise-${i}.md`, `dominator noisy ${i}`, [1, 0, 0]); + } + + const results = await store.searchVec( + "dominator", + "embeddinggemma", + 10, + ["target-a", "target-b"], + undefined, + [1, 0, 0], + ); + + const collections = new Set(results.map(r => r.collectionName)); + + expect(collections.has("target-a")).toBe(true); + expect(collections.has("target-b")).toBe(true); + expect(collections.has("noisy")).toBe(false); + }); +}); diff --git a/test/store.test.ts b/test/store.test.ts index 9c384770..59c6a576 100644 --- a/test/store.test.ts +++ b/test/store.test.ts @@ -1228,6 +1228,39 @@ describe("FTS Search", () => { await cleanupTestDb(store); }); + test("searchFTS filters by multiple collection names", async () => { + const store = await createTestStore(); + const collection1 = await createTestCollection({ pwd: "/path/one", glob: "**/*.md", name: "one" }); + const collection2 = await createTestCollection({ pwd: "/path/two", glob: "**/*.md", name: "two" }); + const collection3 = await createTestCollection({ pwd: "/path/three", glob: "**/*.md", name: "three" }); + + await insertTestDocument(store.db, collection1, { + name: "doc1", + body: "searchable content", + displayPath: "doc1.md", + }); + + await insertTestDocument(store.db, collection2, { + name: "doc2", + body: "searchable content", + displayPath: "doc2.md", + }); + + await insertTestDocument(store.db, collection3, { + name: "doc3", + body: "searchable content", + displayPath: "doc3.md", + }); + + const filtered = store.searchFTS("searchable", 10, [collection1, collection2]); + expect(filtered).toHaveLength(2); + expect(filtered.every(r => r.collectionName === collection1 || r.collectionName === collection2)).toBe(true); + expect(filtered.some(r => r.collectionName === collection1)).toBe(true); + expect(filtered.some(r => r.collectionName === collection2)).toBe(true); + + await cleanupTestDb(store); + }); + test("searchFTS handles special characters in query", async () => { const store = await createTestStore(); const collectionName = await createTestCollection();