Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 8 additions & 37 deletions src/qmd.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1929,29 +1929,16 @@ function resolveCollectionFilter(raw: string | string[] | undefined): string[] {
return validated;
}

// Post-filter results to only include files from specified collections.
// With zero or one collection the upstream query is already scoped, so the
// filter only has work to do when several collections were requested.
function filterByCollections<T extends { filepath?: string; file?: string }>(results: T[], collectionNames: string[]): T[] {
  if (collectionNames.length <= 1) return results;

  const wantedPrefixes = collectionNames.map(name => `qmd://${name}/`);
  const matchesAnyPrefix = (candidate: string): boolean => {
    for (const prefix of wantedPrefixes) {
      if (candidate.startsWith(prefix)) return true;
    }
    return false;
  };

  // A result may carry its virtual path under either `filepath` or `file`;
  // entries with neither are dropped (empty string matches no prefix).
  return results.filter(result => matchesAnyPrefix(result.filepath || result.file || ''));
}

function search(query: string, opts: OutputOptions): void {
const db = getDb();

// Validate collection filter (supports multiple -c flags)
const collectionNames = resolveCollectionFilter(opts.collection);
const singleCollection = collectionNames.length === 1 ? collectionNames[0] : undefined;
const collectionFilter = collectionNames.length > 0 ? collectionNames : undefined;

// Use large limit for --all, otherwise fetch more than needed and let outputResults filter
const fetchLimit = opts.all ? 100000 : Math.max(50, opts.limit * 2);
const results = filterByCollections(
searchFTS(db, query, fetchLimit, singleCollection),
collectionNames
);
const results = searchFTS(db, query, fetchLimit, collectionFilter);

// Add context to results
const resultsWithContext = results.map(r => ({
Expand Down Expand Up @@ -1998,13 +1985,13 @@ async function vectorSearch(query: string, opts: OutputOptions, _model: string =

// Validate collection filter (supports multiple -c flags)
const collectionNames = resolveCollectionFilter(opts.collection);
const singleCollection = collectionNames.length === 1 ? collectionNames[0] : undefined;
const collectionFilter = collectionNames.length > 0 ? collectionNames : undefined;

checkIndexHealth(store.db);

await withLLMSession(async () => {
let results = await vectorSearchQuery(store, query, {
collection: singleCollection,
const results = await vectorSearchQuery(store, query, {
collection: collectionFilter,
limit: opts.all ? 500 : (opts.limit || 10),
minScore: opts.minScore || 0.3,
hooks: {
Expand All @@ -2015,14 +2002,6 @@ async function vectorSearch(query: string, opts: OutputOptions, _model: string =
},
});

// Post-filter for multi-collection
if (collectionNames.length > 1) {
results = results.filter(r => {
const prefixes = collectionNames.map(n => `qmd://${n}/`);
return prefixes.some(p => r.file.startsWith(p));
});
}

closeDb();

if (results.length === 0) {
Expand Down Expand Up @@ -2051,13 +2030,13 @@ async function querySearch(query: string, opts: OutputOptions, _embedModel: stri

// Validate collection filter (supports multiple -c flags)
const collectionNames = resolveCollectionFilter(opts.collection);
const singleCollection = collectionNames.length === 1 ? collectionNames[0] : undefined;
const collectionFilter = collectionNames.length > 0 ? collectionNames : undefined;

checkIndexHealth(store.db);

await withLLMSession(async () => {
let results = await hybridQuery(store, query, {
collection: singleCollection,
const results = await hybridQuery(store, query, {
collection: collectionFilter,
limit: opts.all ? 500 : (opts.limit || 10),
minScore: opts.minScore || 0,
hooks: {
Expand All @@ -2078,14 +2057,6 @@ async function querySearch(query: string, opts: OutputOptions, _embedModel: stri
},
});

// Post-filter for multi-collection
if (collectionNames.length > 1) {
results = results.filter(r => {
const prefixes = collectionNames.map(n => `qmd://${n}/`);
return prefixes.some(p => r.file.startsWith(p));
});
}

closeDb();

if (results.length === 0) {
Expand Down
88 changes: 68 additions & 20 deletions src/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -805,8 +805,8 @@ export type Store = {
toVirtualPath: (absolutePath: string) => string | null;

// Search
searchFTS: (query: string, limit?: number, collectionName?: string) => SearchResult[];
searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]) => Promise<SearchResult[]>;
searchFTS: (query: string, limit?: number, collectionName?: string | string[]) => SearchResult[];
searchVec: (query: string, model: string, limit?: number, collectionName?: string | string[], session?: ILLMSession, precomputedEmbedding?: number[]) => Promise<SearchResult[]>;

// Query expansion & reranking
expandQuery: (query: string, model?: string) => Promise<ExpandedQuery[]>;
Expand Down Expand Up @@ -888,8 +888,8 @@ export function createStore(dbPath?: string): Store {
toVirtualPath: (absolutePath: string) => toVirtualPath(db, absolutePath),

// Search
searchFTS: (query: string, limit?: number, collectionName?: string) => searchFTS(db, query, limit, collectionName),
searchVec: (query: string, model: string, limit?: number, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]) => searchVec(db, query, model, limit, collectionName, session, precomputedEmbedding),
searchFTS: (query: string, limit?: number, collectionName?: string | string[]) => searchFTS(db, query, limit, collectionName),
searchVec: (query: string, model: string, limit?: number, collectionName?: string | string[], session?: ILLMSession, precomputedEmbedding?: number[]) => searchVec(db, query, model, limit, collectionName, session, precomputedEmbedding),

// Query expansion & reranking
expandQuery: (query: string, model?: string) => expandQuery(query, model, db),
Expand Down Expand Up @@ -1996,10 +1996,37 @@ function buildFTS5Query(query: string): string | null {
return terms.map(t => `"${t}"*`).join(' AND ');
}

export function searchFTS(db: Database, query: string, limit: number = 20, collectionName?: string): SearchResult[] {
// Normalize an optional single-or-multi collection argument into a clean
// string[] for SQL building: undefined/empty input yields [], and any
// non-string or empty-string entries are silently discarded.
function normalizeCollectionFilter(collectionName?: string | string[]): string[] {
  if (!collectionName) return [];

  const candidates = Array.isArray(collectionName) ? collectionName : [collectionName];
  const valid: string[] = [];
  for (const candidate of candidates) {
    // Defensive runtime check: callers may pass loosely-typed CLI values.
    if (typeof candidate === "string" && candidate.length > 0) {
      valid.push(candidate);
    }
  }
  return valid;
}

// Append a collection-scoping predicate to a WHERE clause and push the
// matching bind values onto `params` (mutated in place). Returns the new SQL.
// No-op when no collections are given; uses `=` for one name and `IN (...)`
// with one placeholder per name for several.
function appendCollectionFilterSql(
  sql: string,
  params: (string | number | Float32Array)[],
  collectionNames: string[],
): string {
  if (collectionNames.length === 0) return sql;

  if (collectionNames.length === 1) {
    params.push(collectionNames[0]!);
    return `${sql} AND d.collection = ?`;
  }

  params.push(...collectionNames);
  const placeholders = collectionNames.map(() => "?").join(",");
  return `${sql} AND d.collection IN (${placeholders})`;
}

export function searchFTS(db: Database, query: string, limit: number = 20, collectionName?: string | string[]): SearchResult[] {
const ftsQuery = buildFTS5Query(query);
if (!ftsQuery) return [];

const collectionNames = normalizeCollectionFilter(collectionName);

let sql = `
SELECT
'qmd://' || d.collection || '/' || d.path as filepath,
Expand All @@ -2013,12 +2040,9 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle
JOIN content ON content.hash = d.hash
WHERE documents_fts MATCH ? AND d.active = 1
`;
const params: (string | number)[] = [ftsQuery];
const params: (string | number | Float32Array)[] = [ftsQuery];

if (collectionName) {
sql += ` AND d.collection = ?`;
params.push(String(collectionName));
}
sql = appendCollectionFilterSql(sql, params, collectionNames);

// bm25 lower is better; sort ascending.
sql += ` ORDER BY bm25_score ASC LIMIT ?`;
Expand Down Expand Up @@ -2053,24 +2077,51 @@ export function searchFTS(db: Database, query: string, limit: number = 20, colle
// Vector Search
// =============================================================================

export async function searchVec(db: Database, query: string, model: string, limit: number = 20, collectionName?: string, session?: ILLMSession, precomputedEmbedding?: number[]): Promise<SearchResult[]> {
export async function searchVec(db: Database, query: string, model: string, limit: number = 20, collectionName?: string | string[], session?: ILLMSession, precomputedEmbedding?: number[]): Promise<SearchResult[]> {
const tableExists = db.prepare(`SELECT name FROM sqlite_master WHERE type='table' AND name='vectors_vec'`).get();
if (!tableExists) return [];

const embedding = precomputedEmbedding ?? await getEmbedding(query, model, true, session);
if (!embedding) return [];

const collectionNames = normalizeCollectionFilter(collectionName);

// IMPORTANT: We use a two-step query approach here because sqlite-vec virtual tables
// hang indefinitely when combined with JOINs in the same query. Do NOT try to
// "optimize" this by combining into a single query with JOINs - it will break.
// See: https://github.com/tobi/qmd/pull/23

// Precompute eligible chunk ids for collection filtering so the top-k vector query
// itself is collection-scoped (avoids false-empty post-filter behavior).
let eligibleHashSeqs: string[] | null = null;
if (collectionNames.length > 0) {
const collectionPlaceholders = collectionNames.map(() => '?').join(',');
const eligibleRows = db.prepare(`
SELECT DISTINCT cv.hash || '_' || cv.seq as hash_seq
FROM content_vectors cv
JOIN documents d ON d.hash = cv.hash
WHERE d.active = 1 AND d.collection IN (${collectionPlaceholders})
`).all(...collectionNames) as { hash_seq: string }[];

eligibleHashSeqs = eligibleRows.map(r => r.hash_seq);
if (eligibleHashSeqs.length === 0) return [];
}

// Step 1: Get vector matches from sqlite-vec (no JOINs allowed)
const vecResults = db.prepare(`
let vecSql = `
SELECT hash_seq, distance
FROM vectors_vec
WHERE embedding MATCH ? AND k = ?
`).all(new Float32Array(embedding), limit * 3) as { hash_seq: string; distance: number }[];
`;
const vecParams: (Float32Array | number | string)[] = [new Float32Array(embedding), limit * 3];

if (eligibleHashSeqs && eligibleHashSeqs.length > 0) {
const hashSeqPlaceholders = eligibleHashSeqs.map(() => '?').join(',');
vecSql += ` AND hash_seq IN (${hashSeqPlaceholders})`;
vecParams.push(...eligibleHashSeqs);
}

const vecResults = db.prepare(vecSql).all(...vecParams) as { hash_seq: string; distance: number }[];

if (vecResults.length === 0) return [];

Expand All @@ -2094,12 +2145,9 @@ export async function searchVec(db: Database, query: string, model: string, limi
JOIN content ON content.hash = d.hash
WHERE cv.hash || '_' || cv.seq IN (${placeholders})
`;
const params: string[] = [...hashSeqs];
const params: (string | number | Float32Array)[] = [...hashSeqs];

if (collectionName) {
docSql += ` AND d.collection = ?`;
params.push(collectionName);
}
docSql = appendCollectionFilterSql(docSql, params, collectionNames);

const docRows = db.prepare(docSql).all(...params) as {
hash_seq: string; hash: string; pos: number; filepath: string;
Expand Down Expand Up @@ -2771,7 +2819,7 @@ export interface SearchHooks {
}

export interface HybridQueryOptions {
collection?: string;
collection?: string | string[];
limit?: number; // default 10
minScore?: number; // default 0
candidateLimit?: number; // default RERANK_CANDIDATE_LIMIT
Expand Down Expand Up @@ -2985,7 +3033,7 @@ export async function hybridQuery(
}

export interface VectorSearchOptions {
collection?: string;
collection?: string | string[];
limit?: number; // default 10
minScore?: number; // default 0.3
hooks?: Pick<SearchHooks, 'onExpand'>;
Expand Down
95 changes: 94 additions & 1 deletion test/cli.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* These tests spawn actual qmd processes to verify end-to-end functionality.
*/

import { describe, test, expect, beforeAll, afterAll, beforeEach } from "vitest";
import { describe, test, expect, beforeAll, afterAll, beforeEach, afterEach } from "vitest";
import { mkdtemp, rm, writeFile, mkdir } from "fs/promises";
import { existsSync } from "fs";
import { tmpdir } from "os";
Expand Down Expand Up @@ -525,6 +525,99 @@ describe("CLI Search with Collection Filter", () => {
});
});

// Regression suite: repeated -c flags must scope search BEFORE top-k ranking.
// The fixture makes one collection ("noisy") dominate BM25 ranking so that a
// post-filter implementation would return zero target hits, while SQL-level
// filtering still surfaces both target collections.
describe("CLI Multi-Collection Filter Regression", () => {
  let localDbPath: string;
  let localConfigDir: string;
  let localRoot: string;

  beforeEach(async () => {
    // Isolated DB/config per test so collections don't leak between cases.
    const env = await createIsolatedTestEnv("multi-collection");
    localDbPath = env.dbPath;
    localConfigDir = env.configDir;

    localRoot = await mkdtemp(join(testDir, "multi-collection-fixtures-"));
    const noisyDir = join(localRoot, "noisy");
    const targetADir = join(localRoot, "target-a");
    const targetBDir = join(localRoot, "target-b");

    await mkdir(noisyDir, { recursive: true });
    await mkdir(targetADir, { recursive: true });
    await mkdir(targetBDir, { recursive: true });

    // 80 files stuffed with the query term so the noisy collection would fill
    // any top-k fetched without a collection-scoped query.
    const noisyBody = Array.from({ length: 40 }, () => "dominator dominator dominator").join("\n");
    for (let i = 0; i < 80; i++) {
      await writeFile(
        join(noisyDir, `noise-${i}.md`),
        `# dominator dominator ${i}\n\n${noisyBody}\n`
      );
    }

    // One weak hit per target collection — each mentions the term only once.
    await writeFile(
      join(targetADir, "hit-a.md"),
      "# A target document\n\nContains dominator once."
    );
    await writeFile(
      join(targetBDir, "hit-b.md"),
      "# B target document\n\nContains dominator once."
    );

    // Register all three collections in the isolated qmd instance.
    await runQmd(["collection", "add", noisyDir, "--name", "noisy", "--mask", "**/*.md"], {
      cwd: localRoot,
      dbPath: localDbPath,
      configDir: localConfigDir,
    });
    await runQmd(["collection", "add", targetADir, "--name", "target-a", "--mask", "**/*.md"], {
      cwd: localRoot,
      dbPath: localDbPath,
      configDir: localConfigDir,
    });
    await runQmd(["collection", "add", targetBDir, "--name", "target-b", "--mask", "**/*.md"], {
      cwd: localRoot,
      dbPath: localDbPath,
      configDir: localConfigDir,
    });
  });

  afterEach(async () => {
    // Remove fixture files; the isolated DB/config cleanup is handled by the
    // shared test-env helper. TODO confirm createIsolatedTestEnv cleans up.
    if (localRoot) {
      await rm(localRoot, { recursive: true, force: true });
    }
  });

  test("search with repeated -c filters before top-k ranking", async () => {
    const { stdout, stderr, exitCode } = await runQmd([
      "search",
      "--json",
      "-n",
      "10",
      "-c",
      "target-a",
      "-c",
      "target-b",
      "dominator",
    ], {
      cwd: localRoot,
      dbPath: localDbPath,
      configDir: localConfigDir,
    });

    // Dump child process output on failure to aid CI debugging.
    if (exitCode !== 0) {
      console.log("Multi-collection search failed:");
      console.log("stdout:", stdout);
      console.log("stderr:", stderr);
    }

    expect(exitCode).toBe(0);

    const parsed = JSON.parse(stdout) as Array<{ file: string }>;
    const files = parsed.map(row => row.file);

    // Both requested collections must appear, and the dominant excluded
    // collection must not — proving filtering happened at query time.
    expect(files.some(f => f.startsWith("qmd://target-a/"))).toBe(true);
    expect(files.some(f => f.startsWith("qmd://target-b/"))).toBe(true);
    expect(files.some(f => f.startsWith("qmd://noisy/"))).toBe(false);
  });
});

describe("CLI Context Management", () => {
let localDbPath: string;

Expand Down
Loading