From 23874e9c3c1ac21b501f781cdf4addcc33167dc1 Mon Sep 17 00:00:00 2001 From: fryeggs Date: Fri, 20 Mar 2026 23:53:07 +0800 Subject: [PATCH] feat: improve manual retrieval diagnostics and query expansion --- cli.ts | 116 ++++++- package-lock.json | 13 + package.json | 1 + src/query-expander.ts | 116 +++++++ src/retriever.ts | 554 ++++++++++++++++++++++------- test/query-expander.test.mjs | 656 +++++++++++++++++++++++++++++++++++ 6 files changed, 1334 insertions(+), 122 deletions(-) create mode 100644 src/query-expander.ts create mode 100644 test/query-expander.test.mjs diff --git a/cli.ts b/cli.ts index 99203916..42c87714 100644 --- a/cli.ts +++ b/cli.ts @@ -409,6 +409,71 @@ function formatJson(obj: any): string { return JSON.stringify(obj, null, 2); } +function formatRetrievalDiagnosticsLines(diagnostics: { + originalQuery: string; + bm25Query: string | null; + queryExpanded: boolean; + vectorResultCount: number; + bm25ResultCount: number; + fusedResultCount: number; + finalResultCount: number; + stageCounts: { + afterMinScore: number; + rerankInput: number; + afterRerank: number; + afterHardMinScore: number; + afterNoiseFilter: number; + afterDiversity: number; + }; + dropSummary: Array<{ stage: string; dropped: number; before: number; after: number }>; + failureStage?: string; + errorMessage?: string; +}): string[] { + const topDrops = + diagnostics.dropSummary.length > 0 + ? diagnostics.dropSummary + .slice(0, 3) + .map( + (drop) => `${drop.stage} -${drop.dropped} (${drop.before}->${drop.after})`, + ) + .join(", ") + : "none"; + + const lines = [ + "Retrieval diagnostics:", + ` • Original query: ${diagnostics.originalQuery}`, + ` • BM25 query: ${diagnostics.bm25Query ?? "(disabled)"}`, + ` • Query expanded: ${diagnostics.queryExpanded ? "Yes" : "No"}`, + ` • Counts: vector=${diagnostics.vectorResultCount}, bm25=${diagnostics.bm25ResultCount}, fused=${diagnostics.fusedResultCount}, final=${diagnostics.finalResultCount}`, + ` • Stages: min=${diagnostics.stageCounts.afterMinScore}, rerankIn=${diagnostics.stageCounts.rerankInput}, rerank=${diagnostics.stageCounts.afterRerank}, hard=${diagnostics.stageCounts.afterHardMinScore}, noise=${diagnostics.stageCounts.afterNoiseFilter}, diversity=${diagnostics.stageCounts.afterDiversity}`, + ` • Drops: ${topDrops}`, + ]; + + if (diagnostics.failureStage) { + lines.push(` • Failure stage: ${diagnostics.failureStage}`); + } + if (diagnostics.errorMessage) { + lines.push(` • Error: ${diagnostics.errorMessage}`); + } + + return lines; +} + +function buildSearchErrorPayload( + error: unknown, + diagnostics: unknown, + includeDiagnostics: boolean, +): Record { + const message = error instanceof Error ? error.message : String(error); + return { + error: { + code: "search_failed", + message, + }, + ...(includeDiagnostics && diagnostics ? { diagnostics } : {}), + }; +} + async function sleep(ms: number): Promise { await new Promise(resolve => setTimeout(resolve, ms)); } @@ -431,7 +496,8 @@ export function registerMemoryCLI(program: Command, context: CLIContext): void { scopeFilter?: string[], category?: string, ) => { - let results = await getSearchRetriever().retrieve({ + const retriever = getSearchRetriever(); + let results = await retriever.retrieve({ query, limit, scopeFilter, @@ -441,16 +507,30 @@ export function registerMemoryCLI(program: Command, context: CLIContext): void { if (results.length === 0 && context.embedder) { await sleep(75); - results = await getSearchRetriever().retrieve({ + const retryRetriever = getSearchRetriever(); + results = await retryRetriever.retrieve({ query, limit, scopeFilter, category, source: "cli", }); + return { + results, + diagnostics: + typeof retryRetriever.getLastDiagnostics === "function" + ? retryRetriever.getLastDiagnostics() + : null, + }; } - return results; + return { + results, + diagnostics: + typeof retriever.getLastDiagnostics === "function" + ? retriever.getLastDiagnostics() + : null, + }; }; const memory = program @@ -697,6 +777,7 @@ export function registerMemoryCLI(program: Command, context: CLIContext): void { .option("--scope ", "Search within specific scope") .option("--category ", "Filter by category") .option("--limit ", "Maximum number of results", "10") + .option("--debug", "Show retrieval diagnostics") .option("--json", "Output as JSON") .action(async (query, options) => { try { @@ -707,11 +788,24 @@ export function registerMemoryCLI(program: Command, context: CLIContext): void { scopeFilter = [options.scope]; } - const results = await runSearch(query, limit, scopeFilter, options.category); + const { results, diagnostics } = await runSearch( + query, + limit, + scopeFilter, + options.category, + ); if (options.json) { - console.log(formatJson(results)); + console.log( + formatJson(options.debug ? { diagnostics, results } : results), + ); } else { + if (options.debug && diagnostics) { + for (const line of formatRetrievalDiagnosticsLines(diagnostics)) { + console.log(line); + } + console.log(); + } if (results.length === 0) { console.log("No relevant memories found."); } else { @@ -730,6 +824,18 @@ export function registerMemoryCLI(program: Command, context: CLIContext): void { } } } catch (error) { + const diagnostics = options.debug ? context.retriever.getLastDiagnostics?.() : null; + if (options.json) { + console.log( + formatJson(buildSearchErrorPayload(error, diagnostics, options.debug)), + ); + process.exit(1); + } + if (diagnostics) { + for (const line of formatRetrievalDiagnosticsLines(diagnostics)) { + console.error(line); + } + } console.error("Search failed:", error); process.exit(1); } diff --git a/package-lock.json b/package-lock.json index de850adf..c67ebbda 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,7 @@ "@lancedb/lancedb": "^0.26.2", "@sinclair/typebox": "0.34.48", "apache-arrow": "18.1.0", + "json5": "^2.2.3", "openai": "^6.21.0" }, "devDependencies": { @@ -397,6 +398,18 @@ "node": ">=0.8" } }, + "node_modules/json5": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/json5/-/json5-2.2.3.tgz", + "integrity": "sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==", + "license": "MIT", + "bin": { + "json5": "lib/cli.js" + }, + "engines": { + "node": ">=6" + } + }, "node_modules/lodash.camelcase": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/lodash.camelcase/-/lodash.camelcase-4.3.0.tgz", diff --git a/package.json b/package.json index 4f6c0d32..bc7a90ca 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,7 @@ "@lancedb/lancedb": "^0.26.2", "@sinclair/typebox": "0.34.48", "apache-arrow": "18.1.0", + "json5": "^2.2.3", "openai": "^6.21.0" }, "openclaw": { diff --git a/src/query-expander.ts b/src/query-expander.ts new file mode 100644 index 00000000..4f00c565 --- /dev/null +++ b/src/query-expander.ts @@ -0,0 +1,116 @@ +/** + * Lightweight Chinese query expansion for BM25. + * Keeps the vector query untouched and only appends a few high-signal synonyms. + */ + +const MAX_EXPANSION_TERMS = 5; + +interface SynonymEntry { + cn: string[]; + en: string[]; + expansions: string[]; +} + +const SYNONYM_MAP: SynonymEntry[] = [ + { + cn: ["挂了", "挂掉", "宕机"], + en: ["shutdown", "crashed"], + expansions: ["崩溃", "crash", "error", "报错", "宕机", "失败"], + }, + { + cn: ["卡住", "卡死", "没反应"], + en: ["hung", "frozen"], + expansions: ["hang", "timeout", "超时", "无响应", "stuck"], + }, + { + cn: ["炸了", "爆了"], + en: ["oom"], + expansions: ["崩溃", "crash", "OOM", "内存溢出", "error"], + }, + { + cn: ["配置", "设置"], + en: ["config", "configuration"], + expansions: ["配置", "config", "configuration", "settings", "设置"], + }, + { + cn: ["部署", "上线"], + en: ["deploy", "deployment"], + expansions: ["deploy", "部署", "上线", "发布", "release"], + }, + { + cn: ["容器"], + en: ["docker", "container"], + expansions: ["Docker", "容器", "container", "docker-compose"], + }, + { + cn: ["报错", "出错", "错误"], + en: ["error", "exception"], + expansions: ["error", "报错", "exception", "错误", "失败", "bug"], + }, + { + cn: ["修复", "修了", "修好"], + en: ["bugfix", "hotfix"], + expansions: ["fix", "修复", "patch", "解决"], + }, + { + cn: ["踩坑"], + en: ["troubleshoot"], + expansions: ["踩坑", "bug", "问题", "教训", "排查", "troubleshoot"], + }, + { + cn: ["记忆", "记忆系统"], + en: ["memory"], + expansions: ["记忆", "memory", "记忆系统", "LanceDB", "索引"], + }, + { + cn: ["搜索", "查找", "找不到"], + en: ["search", "retrieval"], + expansions: ["搜索", "search", "retrieval", "检索", "查找"], + }, + { + cn: ["推送"], + en: ["git push"], + expansions: ["push", "推送", "git push", "commit"], + }, + { + cn: ["日志"], + en: ["logfile", "logging"], + expansions: ["日志", "log", "logging", "输出", "打印"], + }, + { + cn: ["权限"], + en: ["permission", "authorization"], + expansions: ["权限", "permission", "access", "授权", "认证"], + }, +]; + +function buildWordBoundaryRegex(term: string): RegExp { + const escaped = term.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); + return new RegExp(`\\b${escaped}\\b`, "i"); +} + +export function expandQuery(query: string): string { + if (!query || query.trim().length < 2) return query; + + const lower = query.toLowerCase(); + const additions = new Set(); + + for (const entry of SYNONYM_MAP) { + const cnMatch = entry.cn.some((term) => lower.includes(term.toLowerCase())); + const enMatch = entry.en.some((term) => buildWordBoundaryRegex(term).test(query)); + + if (!cnMatch && !enMatch) continue; + + for (const expansion of entry.expansions) { + if (!lower.includes(expansion.toLowerCase())) { + additions.add(expansion); + } + if (additions.size >= MAX_EXPANSION_TERMS) break; + } + + if (additions.size >= MAX_EXPANSION_TERMS) break; + } + + if (additions.size === 0) return query; + return `${query} ${[...additions].join(" ")}`; +} diff --git a/src/retriever.ts b/src/retriever.ts index 5d50cf11..c65acdbf 100644 --- a/src/retriever.ts +++ b/src/retriever.ts @@ -11,6 +11,7 @@ import { parseAccessMetadata, } from "./access-tracker.js"; import { filterNoise } from "./noise-filter.js"; +import { expandQuery } from "./query-expander.js"; import type { DecayEngine, DecayableMemory } from "./decay-engine.js"; import type { TierManager } from "./tier-manager.js"; import { @@ -28,6 +29,8 @@ export interface RetrievalConfig { mode: "hybrid" | "vector"; vectorWeight: number; bm25Weight: number; + /** Expand BM25 queries with high-signal synonyms for manual / CLI retrieval. */ + queryExpansion: boolean; minScore: number; rerank: "cross-encoder" | "lightweight" | "none"; candidatePoolSize: number; @@ -48,12 +51,14 @@ export interface RetrievalConfig { * - "siliconflow": same format as jina (alias, for clarity) * - "voyage": Authorization: Bearer, string[] documents, data[].relevance_score * - "pinecone": Api-Key header, {text}[] documents, data[].score + * - "vllm": Local vLLM-compatible rerank endpoint, no auth required * - "tei": Authorization: Bearer, string[] texts, top-level [{ index, score }] */ rerankProvider?: | "jina" | "siliconflow" | "voyage" | "pinecone" + | "vllm" | "dashscope" | "tei"; /** @@ -105,6 +110,62 @@ export interface RetrievalResult extends MemorySearchResult { }; } +export interface RetrievalDiagnostics { + source?: RetrievalContext["source"]; + mode: RetrievalConfig["mode"]; + originalQuery: string; + bm25Query: string | null; + queryExpanded: boolean; + limit: number; + scopeFilter?: string[]; + category?: string; + vectorResultCount: number; + bm25ResultCount: number; + fusedResultCount: number; + finalResultCount: number; + stageCounts: { + afterMinScore: number; + rerankInput: number; + afterRerank: number; + afterRecency: number; + afterImportance: number; + afterLengthNorm: number; + afterTimeDecay: number; + afterHardMinScore: number; + afterNoiseFilter: number; + afterDiversity: number; + }; + dropSummary: Array<{ + stage: + | "minScore" + | "rerankWindow" + | "rerank" + | "recencyBoost" + | "importanceWeight" + | "lengthNorm" + | "timeDecay" + | "hardMinScore" + | "noiseFilter" + | "diversity" + | "limit"; + before: number; + after: number; + dropped: number; + }>; + failureStage?: + | "vector.embedQuery" + | "vector.vectorSearch" + | "vector.postProcess" + | "hybrid.embedQuery" + | "hybrid.vectorSearch" + | "hybrid.bm25Search" + | "hybrid.parallelSearch" + | "hybrid.fuseResults" + | "hybrid.rerank" + | "hybrid.postProcess"; + errorMessage?: string; +} + // ============================================================================ // Default Configuration // ============================================================================ @@ -113,6 +174,7 @@ export const DEFAULT_RETRIEVAL_CONFIG: RetrievalConfig = { mode: "hybrid", vectorWeight: 0.7, bm25Weight: 0.3, + queryExpansion: true, minScore: 0.3, rerank: "cross-encoder", candidatePoolSize: 20, @@ -147,6 +209,116 @@ function clamp01WithFloor(value: number, floor: number): number { return Math.max(safeFloor, clamp01(value, safeFloor)); } +type TaggedRetrievalError = Error & { + retrievalFailureStage?: NonNullable; +}; + +function attachFailureStage( + error: unknown, + stage: NonNullable, +): TaggedRetrievalError { + const tagged = + error instanceof Error ? (error as TaggedRetrievalError) : new Error(String(error)); + tagged.retrievalFailureStage = stage; + return tagged; +} + +function extractFailureStage( + error: unknown, +): RetrievalDiagnostics["failureStage"] | undefined { + return error instanceof Error + ? (error as TaggedRetrievalError).retrievalFailureStage + : undefined; +} + +function buildDropSummary( + diagnostics: RetrievalDiagnostics, +): RetrievalDiagnostics["dropSummary"] { + const stageDrops = [ + { + order: 0, + stage: "minScore" as const, + before: + diagnostics.mode === "vector" + ? diagnostics.vectorResultCount + : diagnostics.fusedResultCount, + after: diagnostics.stageCounts.afterMinScore, + }, + { + order: 1, + stage: "rerankWindow" as const, + before: diagnostics.stageCounts.afterMinScore, + after: diagnostics.stageCounts.rerankInput, + }, + { + order: 2, + stage: "rerank" as const, + before: diagnostics.stageCounts.rerankInput, + after: diagnostics.stageCounts.afterRerank, + }, + { + order: 3, + stage: "recencyBoost" as const, + before: diagnostics.stageCounts.afterRerank, + after: diagnostics.stageCounts.afterRecency, + }, + { + order: 4, + stage: "importanceWeight" as const, + before: diagnostics.stageCounts.afterRecency, + after: diagnostics.stageCounts.afterImportance, + }, + { + order: 5, + stage: "lengthNorm" as const, + before: diagnostics.stageCounts.afterImportance, + after: diagnostics.stageCounts.afterLengthNorm, + }, + { + order: 6, + stage: "hardMinScore" as const, + before: diagnostics.stageCounts.afterLengthNorm, + after: diagnostics.stageCounts.afterHardMinScore, + }, + { + order: 7, + stage: "timeDecay" as const, + before: diagnostics.stageCounts.afterHardMinScore, + after: diagnostics.stageCounts.afterTimeDecay, + }, + { + order: 8, + stage: "noiseFilter" as const, + before: diagnostics.stageCounts.afterTimeDecay, + after: diagnostics.stageCounts.afterNoiseFilter, + }, + { + order: 9, + stage: "diversity" as const, + before: diagnostics.stageCounts.afterNoiseFilter, + after: diagnostics.stageCounts.afterDiversity, + }, + { + order: 10, + stage: "limit" as const, + before: diagnostics.stageCounts.afterDiversity, + after: diagnostics.finalResultCount, + }, + ]; + + return stageDrops + .map(({ order, stage, before, after }) => ({ + order, + stage, + before, + after, + dropped: Math.max(0, before - after), + })) + .filter((drop) => drop.dropped > 0) + .sort((a, b) => b.dropped - a.dropped || a.order - b.order) + .map(({ order: _order, ...drop }) => drop); +} + // ============================================================================ // Rerank Provider Adapters // ============================================================================ @@ -156,6 +328,7 @@ type RerankProvider = | "siliconflow" | "voyage" | "pinecone" + | "vllm" | "dashscope" | "tei"; @@ -174,6 +347,18 @@ function buildRerankRequest( topN: number, ): { headers: Record; body: Record } { switch (provider) { + case "vllm": + return { + headers: { + "Content-Type": "application/json", + }, + body: { + model, + query, + documents: candidates, + top_n: topN, + }, + }; case "tei": return { headers: { @@ -314,10 +499,11 @@ function parseRerankResponse( parseItems(objectData?.results, ["relevance_score", "score"]) ); } + case "vllm": case "siliconflow": case "jina": default: { - // Jina / SiliconFlow: usually { results: [{ index, relevance_score }] } + // Jina / SiliconFlow / vLLM: usually { results: [{ index, relevance_score }] } // Also tolerate data[] for compatibility across gateways. return ( parseItems(objectData?.results, ["relevance_score", "score"]) ?? @@ -353,6 +539,7 @@ function cosineSimilarity(a: number[], b: number[]): number { export class MemoryRetriever { private accessTracker: AccessTracker | null = null; + private lastDiagnostics: RetrievalDiagnostics | null = null; private tierManager: TierManager | null = null; constructor( @@ -375,30 +562,74 @@ export class MemoryRetriever { async retrieve(context: RetrievalContext): Promise { const { query, limit, scopeFilter, category, source } = context; const safeLimit = clampInt(limit, 1, 20); + this.lastDiagnostics = null; + const diagnostics: RetrievalDiagnostics = { + source, + mode: this.config.mode, + originalQuery: query, + bm25Query: this.config.mode === "vector" ? null : query, + queryExpanded: false, + limit: safeLimit, + scopeFilter: scopeFilter ? [...scopeFilter] : undefined, + category, + vectorResultCount: 0, + bm25ResultCount: 0, + fusedResultCount: 0, + finalResultCount: 0, + stageCounts: { + afterMinScore: 0, + rerankInput: 0, + afterRerank: 0, + afterRecency: 0, + afterImportance: 0, + afterLengthNorm: 0, + afterTimeDecay: 0, + afterHardMinScore: 0, + afterNoiseFilter: 0, + afterDiversity: 0, + }, + dropSummary: [], + }; - let results: RetrievalResult[]; - if (this.config.mode === "vector" || !this.store.hasFtsSupport) { - results = await this.vectorOnlyRetrieval( - query, - safeLimit, - scopeFilter, - category, - ); - } else { - results = await this.hybridRetrieval( - query, - safeLimit, - scopeFilter, - category, - ); - } + try { + let results: RetrievalResult[]; + if (this.config.mode === "vector" || !this.store.hasFtsSupport) { + results = await this.vectorOnlyRetrieval( + query, + safeLimit, + scopeFilter, + category, + diagnostics, + ); + } else { + results = await this.hybridRetrieval( + query, + safeLimit, + scopeFilter, + category, + source, + diagnostics, + ); + } - // Record access for reinforcement (manual recall only) - if (this.accessTracker && source === "manual" && results.length > 0) { - this.accessTracker.recordAccess(results.map((r) => r.entry.id)); - } + diagnostics.finalResultCount = results.length; + diagnostics.dropSummary = buildDropSummary(diagnostics); + this.lastDiagnostics = diagnostics; + + // Record access for reinforcement (manual recall only) + if (this.accessTracker && source === "manual" && results.length > 0) { + this.accessTracker.recordAccess(results.map((r) => r.entry.id)); + } - return results; + return results; + } catch (error) { + diagnostics.finalResultCount = 0; + diagnostics.dropSummary = buildDropSummary(diagnostics); + diagnostics.errorMessage = + error instanceof Error ? error.message : String(error); + this.lastDiagnostics = diagnostics; + throw error; + } } private async vectorOnlyRetrieval( @@ -406,45 +637,72 @@ export class MemoryRetriever { limit: number, scopeFilter?: string[], category?: string, + diagnostics?: RetrievalDiagnostics, ): Promise { - const queryVector = await this.embedder.embedQuery(query); - const results = await this.store.vectorSearch( - queryVector, - limit, - this.config.minScore, - scopeFilter, - { excludeInactive: true }, - ); - - // Filter by category if specified - const filtered = category - ? results.filter((r) => r.entry.category === category) - : results; + let failureStage: RetrievalDiagnostics["failureStage"] = "vector.embedQuery"; + try { + const queryVector = await this.embedder.embedQuery(query); + failureStage = "vector.vectorSearch"; + const results = await this.store.vectorSearch( + queryVector, + limit, + this.config.minScore, + scopeFilter, + { excludeInactive: true }, + ); - const mapped = filtered.map( - (result, index) => - ({ - ...result, - sources: { - vector: { score: result.score, rank: index + 1 }, - }, - }) as RetrievalResult, - ); + const filtered = category + ? results.filter((r) => r.entry.category === category) + : results; + if (diagnostics) { + diagnostics.vectorResultCount = filtered.length; + diagnostics.fusedResultCount = filtered.length; + diagnostics.stageCounts.afterMinScore = filtered.length; + diagnostics.stageCounts.rerankInput = filtered.length; + } - const weighted = this.decayEngine ? mapped : this.applyImportanceWeight(this.applyRecencyBoost(mapped)); - const lengthNormalized = this.applyLengthNormalization(weighted); - const hardFiltered = lengthNormalized.filter(r => r.score >= this.config.hardMinScore); - const lifecycleRanked = this.decayEngine - ? this.applyDecayBoost(hardFiltered) - : this.applyTimeDecay(hardFiltered); - const denoised = this.config.filterNoise - ? filterNoise(lifecycleRanked, r => r.entry.text) - : lifecycleRanked; + const mapped = filtered.map( + (result, index) => + ({ + ...result, + sources: { + vector: { score: result.score, rank: index + 1 }, + }, + }) as RetrievalResult, + ); - // MMR deduplication: avoid top-k filled with near-identical memories - const deduplicated = this.applyMMRDiversity(denoised); + failureStage = "vector.postProcess"; + const recencyBoosted = this.applyRecencyBoost(mapped); + if (diagnostics) diagnostics.stageCounts.afterRecency = recencyBoosted.length; + const weighted = this.decayEngine + ? recencyBoosted + : this.applyImportanceWeight(recencyBoosted); + if (diagnostics) diagnostics.stageCounts.afterImportance = weighted.length; + const lengthNormalized = this.applyLengthNormalization(weighted); + if (diagnostics) diagnostics.stageCounts.afterLengthNorm = lengthNormalized.length; + const hardFiltered = lengthNormalized.filter((r) => r.score >= this.config.hardMinScore); + if (diagnostics) diagnostics.stageCounts.afterHardMinScore = hardFiltered.length; + const timeOrDecayRanked = this.decayEngine + ? this.applyDecayBoost(hardFiltered) + : this.applyTimeDecay(hardFiltered); + if (diagnostics) diagnostics.stageCounts.afterTimeDecay = timeOrDecayRanked.length; + const denoised = this.config.filterNoise + ? filterNoise(timeOrDecayRanked, (r) => r.entry.text) + : timeOrDecayRanked; + if (diagnostics) diagnostics.stageCounts.afterNoiseFilter = denoised.length; + const deduplicated = this.applyMMRDiversity(denoised); + if (diagnostics) { + diagnostics.stageCounts.afterRerank = mapped.length; + diagnostics.stageCounts.afterDiversity = deduplicated.length; + } - return deduplicated.slice(0, limit); + return deduplicated.slice(0, limit); + } catch (error) { + if (diagnostics) { + diagnostics.failureStage = extractFailureStage(error) ?? failureStage; + } + throw error; + } } private async hybridRetrieval( @@ -452,70 +710,100 @@ export class MemoryRetriever { limit: number, scopeFilter?: string[], category?: string, + source?: RetrievalContext["source"], + diagnostics?: RetrievalDiagnostics, ): Promise { - const candidatePoolSize = Math.max( - this.config.candidatePoolSize, - limit * 2, - ); - - // Compute query embedding once, reuse for vector search + reranking - const queryVector = await this.embedder.embedQuery(query); - - // Run vector and BM25 searches in parallel - const [vectorResults, bm25Results] = await Promise.all([ - this.runVectorSearch( - queryVector, - candidatePoolSize, - scopeFilter, - category, - ), - this.runBM25Search(query, candidatePoolSize, scopeFilter, category), - ]); - - // Fuse results using RRF (async: validates BM25-only entries exist in store) - const fusedResults = await this.fuseResults(vectorResults, bm25Results); + let failureStage: RetrievalDiagnostics["failureStage"] = "hybrid.embedQuery"; + try { + const candidatePoolSize = Math.max( + this.config.candidatePoolSize, + limit * 2, + ); - // Apply minimum score threshold - const filtered = fusedResults.filter( - (r) => r.score >= this.config.minScore, - ); + const queryVector = await this.embedder.embedQuery(query); + const bm25Query = this.buildBM25Query(query, source); + if (diagnostics) { + diagnostics.bm25Query = bm25Query; + diagnostics.queryExpanded = bm25Query !== query; + } - // Rerank if enabled - const reranked = - this.config.rerank !== "none" - ? await this.rerankResults( - query, + failureStage = "hybrid.parallelSearch"; + const [vectorResults, bm25Results] = await Promise.all([ + this.runVectorSearch( queryVector, - filtered.slice(0, limit * 2), - ) - : filtered; - - const temporallyRanked = this.decayEngine - ? reranked - : this.applyImportanceWeight(this.applyRecencyBoost(reranked)); - - // Apply length normalization (penalize long entries dominating via keyword density) - const lengthNormalized = this.applyLengthNormalization(temporallyRanked); - - // Hard minimum score cutoff should be based on semantic / lexical relevance. - // Lifecycle decay and time-decay are used for re-ranking, not for dropping - // otherwise relevant fresh memories. - const hardFiltered = lengthNormalized.filter(r => r.score >= this.config.hardMinScore); - - // Apply lifecycle-aware decay or legacy time decay after thresholding - const lifecycleRanked = this.decayEngine - ? this.applyDecayBoost(hardFiltered) - : this.applyTimeDecay(hardFiltered); - - // Filter noise - const denoised = this.config.filterNoise - ? filterNoise(lifecycleRanked, r => r.entry.text) - : lifecycleRanked; + candidatePoolSize, + scopeFilter, + category, + ).catch((error) => { + throw attachFailureStage(error, "hybrid.vectorSearch"); + }), + this.runBM25Search( + bm25Query, + candidatePoolSize, + scopeFilter, + category, + ).catch((error) => { + throw attachFailureStage(error, "hybrid.bm25Search"); + }), + ]); + if (diagnostics) { + diagnostics.vectorResultCount = vectorResults.length; + diagnostics.bm25ResultCount = bm25Results.length; + } - // MMR deduplication: avoid top-k filled with near-identical memories - const deduplicated = this.applyMMRDiversity(denoised); + failureStage = "hybrid.fuseResults"; + const fusedResults = await this.fuseResults(vectorResults, bm25Results); + if (diagnostics) diagnostics.fusedResultCount = fusedResults.length; - return deduplicated.slice(0, limit); + const filtered = fusedResults.filter( + (r) => r.score >= this.config.minScore, + ); + if (diagnostics) diagnostics.stageCounts.afterMinScore = filtered.length; + + const rerankInput = + this.config.rerank !== "none" ? filtered.slice(0, limit * 2) : filtered; + if (diagnostics) diagnostics.stageCounts.rerankInput = rerankInput.length; + + failureStage = "hybrid.rerank"; + const reranked = + this.config.rerank !== "none" + ? await this.rerankResults( + query, + queryVector, + rerankInput, + ) + : filtered; + if (diagnostics) diagnostics.stageCounts.afterRerank = reranked.length; + + failureStage = "hybrid.postProcess"; + const recencyBoosted = this.applyRecencyBoost(reranked); + if (diagnostics) diagnostics.stageCounts.afterRecency = recencyBoosted.length; + const temporallyRanked = this.decayEngine + ? recencyBoosted + : this.applyImportanceWeight(recencyBoosted); + if (diagnostics) diagnostics.stageCounts.afterImportance = temporallyRanked.length; + const lengthNormalized = this.applyLengthNormalization(temporallyRanked); + if (diagnostics) diagnostics.stageCounts.afterLengthNorm = lengthNormalized.length; + const hardFiltered = lengthNormalized.filter((r) => r.score >= this.config.hardMinScore); + if (diagnostics) diagnostics.stageCounts.afterHardMinScore = hardFiltered.length; + const lifecycleRanked = this.decayEngine + ? this.applyDecayBoost(hardFiltered) + : this.applyTimeDecay(hardFiltered); + if (diagnostics) diagnostics.stageCounts.afterTimeDecay = lifecycleRanked.length; + const denoised = this.config.filterNoise + ? filterNoise(lifecycleRanked, (r) => r.entry.text) + : lifecycleRanked; + if (diagnostics) diagnostics.stageCounts.afterNoiseFilter = denoised.length; + const deduplicated = this.applyMMRDiversity(denoised); + if (diagnostics) diagnostics.stageCounts.afterDiversity = deduplicated.length; + + return deduplicated.slice(0, limit); + } catch (error) { + if (diagnostics) { + diagnostics.failureStage = extractFailureStage(error) ?? failureStage; + } + throw error; + } } private async runVectorSearch( @@ -562,6 +850,15 @@ export class MemoryRetriever { })); } + private buildBM25Query( + query: string, + source?: RetrievalContext["source"], + ): string { + if (!this.config.queryExpansion) return query; + if (source !== "manual" && source !== "cli") return query; + return expandQuery(query); + } + private async fuseResults( vectorResults: Array, bm25Results: Array, @@ -655,18 +952,27 @@ export class MemoryRetriever { } // Try cross-encoder rerank via configured provider API - if (this.config.rerank === "cross-encoder" && this.config.rerankApiKey) { + const provider = this.config.rerankProvider || "jina"; + const needsApiKey = provider !== "vllm"; + const hasApiKey = !!this.config.rerankApiKey; + + if (this.config.rerank === "cross-encoder" && (!needsApiKey || hasApiKey)) { try { - const provider = this.config.rerankProvider || "jina"; const model = this.config.rerankModel || "jina-reranker-v3"; const endpoint = this.config.rerankEndpoint || "https://api.jina.ai/v1/rerank"; const documents = results.map((r) => r.entry.text); + if (provider === "vllm" && !this.config.rerankEndpoint) { + throw new Error( + "vLLM rerank provider requires rerankEndpoint to be configured.", + ); + } + // Build provider-specific request const { headers, body } = buildRerankRequest( provider, - this.config.rerankApiKey, + this.config.rerankApiKey || "", model, query, documents, @@ -1060,6 +1366,20 @@ export class MemoryRetriever { return { ...this.config }; } + getLastDiagnostics(): RetrievalDiagnostics | null { + if (!this.lastDiagnostics) return null; + return { + ...this.lastDiagnostics, + scopeFilter: this.lastDiagnostics.scopeFilter + ? [...this.lastDiagnostics.scopeFilter] + : undefined, + stageCounts: { ...this.lastDiagnostics.stageCounts }, + dropSummary: this.lastDiagnostics.dropSummary.map((drop) => ({ + ...drop, + })), + }; + } + // Test retrieval system async test(query = "test query"): Promise<{ success: boolean; diff --git a/test/query-expander.test.mjs b/test/query-expander.test.mjs new file mode 100644 index 00000000..56d9d441 --- /dev/null +++ b/test/query-expander.test.mjs @@ -0,0 +1,656 @@ +import { describe, it } from "node:test"; +import assert from "node:assert/strict"; +import path from "node:path"; +import { fileURLToPath } from "node:url"; +import { Command } from "commander"; +import jitiFactory from "jiti"; + +const testDir = path.dirname(fileURLToPath(import.meta.url)); +const pluginSdkStubPath = path.resolve(testDir, "helpers", "openclaw-plugin-sdk-stub.mjs"); +const jiti = jitiFactory(import.meta.url, { + interopDefault: true, + alias: { + "openclaw/plugin-sdk": pluginSdkStubPath, + }, +}); + +const { expandQuery } = jiti("../src/query-expander.ts"); +const { createRetriever } = jiti("../src/retriever.ts"); +const { createMemoryCLI } = jiti("../cli.ts"); + +function buildResult(id = "memory-1", text = "服务崩溃 error") { + return { + entry: { + id, + text, + vector: [0.1, 0.2, 0.3], + category: "other", + scope: "global", + importance: 0.7, + timestamp: 1700000000000, + metadata: "{}", + }, + score: 0.9, + }; +} + +describe("query expander", () => { + it("expands colloquial Chinese crash queries with technical BM25 terms", () => { + const expanded = expandQuery("服务挂了"); + assert.notEqual(expanded, "服务挂了"); + assert.match(expanded, /崩溃/); + assert.match(expanded, /crash/); + assert.match(expanded, /报错|error/); + }); + + it("avoids english substring false positives", () => { + assert.equal(expandQuery("memorybank retention"), "memorybank retention"); + assert.equal(expandQuery("configurable loader"), "configurable loader"); + }); +}); + +describe("retriever BM25 query expansion gating", () => { + function createRetrieverHarness( + config = {}, + storeOverrides = {}, + embedderOverrides = {}, + ) { + const bm25Queries = []; + const embeddedQueries = []; + + const retriever = createRetriever( + { + hasFtsSupport: true, + async vectorSearch() { + return []; + }, + async bm25Search(query) { + bm25Queries.push(query); + return [buildResult()]; + }, + async hasId() { + return true; + }, + ...storeOverrides, + }, + { + async embedQuery(query) { + embeddedQueries.push(query); + return [0.1, 0.2, 0.3]; + }, + ...embedderOverrides, + }, + { + rerank: "none", + filterNoise: false, + minScore: 0, + hardMinScore: 0, + candidatePoolSize: 5, + ...config, + }, + ); + + return { retriever, bm25Queries, embeddedQueries }; + } + + it("expands only the BM25 leg for manual retrieval", async () => { + const { retriever, bm25Queries, embeddedQueries } = createRetrieverHarness(); + + const results = await retriever.retrieve({ + query: "服务挂了", + limit: 1, + source: "manual", + }); + + assert.equal(results.length, 1); + assert.deepEqual(embeddedQueries, ["服务挂了"]); + assert.equal(bm25Queries.length, 1); + assert.notEqual(bm25Queries[0], "服务挂了"); + assert.match(bm25Queries[0], /crash/); + assert.deepEqual(retriever.getLastDiagnostics(), { + source: "manual", + mode: "hybrid", + originalQuery: "服务挂了", + bm25Query: bm25Queries[0], + queryExpanded: true, + limit: 1, + scopeFilter: undefined, + category: undefined, + vectorResultCount: 0, + bm25ResultCount: 1, + fusedResultCount: 1, + finalResultCount: 1, + stageCounts: { + afterMinScore: 1, + rerankInput: 1, + afterRerank: 1, + afterRecency: 1, + afterImportance: 1, + afterLengthNorm: 1, + afterTimeDecay: 1, + afterHardMinScore: 1, + afterNoiseFilter: 1, + afterDiversity: 1, + }, + dropSummary: [], + }); + }); + + it("keeps auto-recall and unspecified retrieval on the original query", async () => { + const autoRecallHarness = createRetrieverHarness(); + await autoRecallHarness.retriever.retrieve({ + query: "服务挂了", + limit: 1, + source: "auto-recall", + }); + assert.deepEqual(autoRecallHarness.bm25Queries, ["服务挂了"]); + + const unspecifiedHarness = createRetrieverHarness(); + await unspecifiedHarness.retriever.retrieve({ + query: "服务挂了", + limit: 1, + }); + assert.deepEqual(unspecifiedHarness.bm25Queries, ["服务挂了"]); + }); + + it("honors retrieval.queryExpansion = false", async () => { + const { retriever, bm25Queries } = createRetrieverHarness({ + queryExpansion: false, + }); + + await retriever.retrieve({ + query: "服务挂了", + limit: 1, + source: "manual", + }); + + assert.deepEqual(bm25Queries, ["服务挂了"]); + }); + + it("summarizes the biggest count drops without changing retrieval behavior", async () => { + const { retriever } = createRetrieverHarness( + { + rerank: "none", + filterNoise: false, + minScore: 0, + hardMinScore: 0, + }, + { + async bm25Search() { + return [ + buildResult("memory-1", "故障一"), + buildResult("memory-2", "故障二"), + buildResult("memory-3", "故障三"), + ]; + }, + }, + ); + + const results = await retriever.retrieve({ + query: "普通查询", + limit: 1, + source: "manual", + }); + + assert.equal(results.length, 1); + assert.deepEqual(retriever.getLastDiagnostics()?.dropSummary, [ + { + stage: "limit", + before: 3, + after: 1, + dropped: 2, + }, + ]); + }); + + it("captures partial diagnostics when retrieval fails before search completes", async () => { + const { retriever } = createRetrieverHarness( + {}, + {}, + { + async embedQuery() { + throw new Error("simulated embed failure"); + }, + }, + ); + + await assert.rejects( + retriever.retrieve({ + query: "服务挂了", + limit: 1, + source: "manual", + }), + /simulated embed failure/, + ); + + assert.deepEqual(retriever.getLastDiagnostics(), { + source: "manual", + mode: "hybrid", + originalQuery: "服务挂了", + bm25Query: "服务挂了", + queryExpanded: false, + limit: 1, + scopeFilter: undefined, + category: undefined, + vectorResultCount: 0, + bm25ResultCount: 0, + fusedResultCount: 0, + finalResultCount: 0, + stageCounts: { + afterMinScore: 0, + rerankInput: 0, + afterRerank: 0, + afterRecency: 0, + afterImportance: 0, + afterLengthNorm: 0, + afterTimeDecay: 0, + afterHardMinScore: 0, + afterNoiseFilter: 0, + afterDiversity: 0, + }, + dropSummary: [], + failureStage: "hybrid.embedQuery", + errorMessage: "simulated embed failure", + }); + }); + + it("distinguishes vector-search failures inside the hybrid parallel stage", async () => { + const { retriever } = createRetrieverHarness( + {}, + { + async vectorSearch() { + throw new Error("simulated vector search failure"); + }, + }, + ); + + await assert.rejects( + retriever.retrieve({ + query: "普通查询", + limit: 1, + source: "manual", + }), + /simulated vector search failure/, + ); + + assert.equal( + retriever.getLastDiagnostics()?.failureStage, + "hybrid.vectorSearch", + ); + assert.equal( + retriever.getLastDiagnostics()?.errorMessage, + "simulated vector search failure", + ); + }); + + it("distinguishes bm25-search failures inside the hybrid parallel stage", async () => { + const { retriever } = createRetrieverHarness( + {}, + { + async bm25Search() { + throw new Error("simulated bm25 search failure"); + }, + }, + ); + + await assert.rejects( + retriever.retrieve({ + query: "普通查询", + limit: 1, + source: "manual", + }), + /simulated bm25 search failure/, + ); + + assert.equal( + retriever.getLastDiagnostics()?.failureStage, + "hybrid.bm25Search", + ); + assert.equal( + retriever.getLastDiagnostics()?.errorMessage, + "simulated bm25 search failure", + ); + }); +}); + +describe("cli search source tagging", () => { + it("marks search requests as cli so query expansion stays scoped to interactive CLI recall", async () => { + const searchCalls = []; + const logs = []; + + const program = new Command(); + program.exitOverride(); + createMemoryCLI({ + store: { + async list() { + return []; + }, + async stats() { + return { + totalCount: 0, + scopeCounts: {}, + categoryCounts: {}, + }; + }, + }, + retriever: { + async retrieve(params) { + searchCalls.push(params); + return [ + { + ...buildResult("memory-cli", "CLI search hit"), + sources: { + vector: { score: 0.9, rank: 1 }, + }, + }, + ]; + }, + getConfig() { + return { mode: "hybrid" }; + }, + getLastDiagnostics() { + return { + source: "cli", + mode: "hybrid", + originalQuery: "服务挂了", + bm25Query: "服务挂了 崩溃 crash error 报错 宕机", + queryExpanded: true, + limit: 10, + scopeFilter: undefined, + category: undefined, + vectorResultCount: 0, + bm25ResultCount: 3, + fusedResultCount: 3, + finalResultCount: 1, + stageCounts: { + afterMinScore: 3, + rerankInput: 3, + afterRerank: 3, + afterRecency: 3, + afterImportance: 3, + afterLengthNorm: 3, + afterTimeDecay: 3, + afterHardMinScore: 3, + afterNoiseFilter: 3, + afterDiversity: 3, + }, + dropSummary: [ + { + stage: "limit", + before: 3, + after: 1, + dropped: 2, + }, + ], + }; + }, + }, + scopeManager: { + getStats() { + return { totalScopes: 1 }; + }, + }, + migrator: {}, + })({ program }); + + const origLog = console.log; + console.log = (...args) => logs.push(args.join(" ")); + try { + await program.parseAsync([ + "node", + "openclaw", + "memory-pro", + "search", + "服务挂了", + "--debug", + ]); + } finally { + console.log = origLog; + } + + assert.equal(searchCalls.length, 1); + assert.equal(searchCalls[0].source, "cli"); + assert.match(logs.join("\n"), /Retrieval diagnostics:/); + assert.match(logs.join("\n"), /Original query: 服务挂了/); + assert.match(logs.join("\n"), /BM25 query: 服务挂了 崩溃 crash error 报错 宕机/); + assert.match(logs.join("\n"), /Stages: min=3, rerankIn=3, rerank=3, hard=3, noise=3, diversity=3/); + assert.match(logs.join("\n"), /Drops: limit -2 \(3->1\)/); + assert.match(logs.join("\n"), /CLI search hit/); + }); + + it("prints failure diagnostics on debug search errors", async () => { + const logs = []; + const errors = []; + const exitCalls = []; + + const program = new Command(); + program.exitOverride(); + createMemoryCLI({ + store: { + async list() { + return []; + }, + async stats() { + return { + totalCount: 0, + scopeCounts: {}, + categoryCounts: {}, + }; + }, + }, + retriever: { + async retrieve() { + throw new Error("simulated search failure"); + }, + getConfig() { + return { mode: "hybrid" }; + }, + getLastDiagnostics() { + return { + source: "cli", + mode: "hybrid", + originalQuery: "服务挂了", + bm25Query: "服务挂了 崩溃 crash error 报错 宕机", + queryExpanded: true, + limit: 10, + scopeFilter: undefined, + category: undefined, + vectorResultCount: 0, + bm25ResultCount: 0, + fusedResultCount: 0, + finalResultCount: 0, + stageCounts: { + afterMinScore: 0, + rerankInput: 0, + afterRerank: 0, + afterRecency: 0, + afterImportance: 0, + afterLengthNorm: 0, + afterTimeDecay: 0, + afterHardMinScore: 0, + afterNoiseFilter: 0, + afterDiversity: 0, + }, + dropSummary: [], + failureStage: "hybrid.embedQuery", + errorMessage: "simulated search failure", + }; + }, + }, + scopeManager: { + getStats() { + return { totalScopes: 1 }; + }, + }, + migrator: {}, + })({ program }); + + const origLog = console.log; + const origError = console.error; + const origExit = process.exit; + console.log = (...args) => logs.push(args.join(" ")); + console.error = (...args) => errors.push(args.join(" ")); + process.exit = ((code = 0) => { + exitCalls.push(Number(code)); + throw new Error(`__TEST_EXIT__${code}`); + }); + try { + await assert.rejects( + program.parseAsync([ + "node", + "openclaw", + "memory-pro", + "search", + "服务挂了", + "--debug", + ]), + /__TEST_EXIT__1/, + ); + } finally { + console.log = origLog; + console.error = origError; + process.exit = origExit; + } + + assert.deepEqual(logs, []); + assert.deepEqual(exitCalls, [1]); + assert.match(errors.join("\n"), /Retrieval diagnostics:/); + assert.match(errors.join("\n"), /Failure stage: hybrid\.embedQuery/); + assert.match(errors.join("\n"), /Error: simulated search failure/); + assert.match(errors.join("\n"), /Search failed:/); + }); + + it("returns structured JSON failure output for --json --debug search errors", async () => { + const logs = []; + const errors = []; + const exitCalls = []; + + const program = new Command(); + program.exitOverride(); + createMemoryCLI({ + store: { + async list() { + return []; + }, + async stats() { + return { + totalCount: 0, + scopeCounts: {}, + categoryCounts: {}, + }; + }, + }, + retriever: { + async retrieve() { + throw new Error("simulated json search failure"); + }, + getConfig() { + return { mode: "hybrid" }; + }, + getLastDiagnostics() { + return { + source: "cli", + mode: "hybrid", + originalQuery: "服务挂了", + bm25Query: "服务挂了 崩溃 crash error 报错 宕机", + queryExpanded: true, + limit: 10, + scopeFilter: undefined, + category: undefined, + vectorResultCount: 0, + bm25ResultCount: 0, + fusedResultCount: 0, + finalResultCount: 0, + stageCounts: { + afterMinScore: 0, + rerankInput: 0, + afterRerank: 0, + afterRecency: 0, + afterImportance: 0, + afterLengthNorm: 0, + afterTimeDecay: 0, + afterHardMinScore: 0, + afterNoiseFilter: 0, + afterDiversity: 0, + }, + dropSummary: [], + failureStage: "hybrid.embedQuery", + errorMessage: "simulated json search failure", + }; + }, + }, + scopeManager: { + getStats() { + return { totalScopes: 1 }; + }, + }, + migrator: {}, + })({ program }); + + const origLog = console.log; + const origError = console.error; + const origExit = process.exit; + console.log = (...args) => logs.push(args.join(" ")); + console.error = (...args) => errors.push(args.join(" ")); + process.exit = ((code = 0) => { + exitCalls.push(Number(code)); + throw new Error(`__TEST_EXIT__${code}`); + }); + try { + await assert.rejects( + program.parseAsync([ + "node", + "openclaw", + "memory-pro", + "search", + "服务挂了", + "--json", + "--debug", + ]), + /__TEST_EXIT__1/, + ); + } finally { + console.log = origLog; + console.error = origError; + process.exit = origExit; + } + + assert.deepEqual(exitCalls, [1]); + assert.deepEqual(errors, []); + assert.equal(logs.length, 1); + const payload = JSON.parse(logs[0]); + assert.deepEqual(payload, { + error: { + code: "search_failed", + message: "simulated json search failure", + }, + diagnostics: { + source: "cli", + mode: "hybrid", + originalQuery: "服务挂了", + bm25Query: "服务挂了 崩溃 crash error 报错 宕机", + queryExpanded: true, + limit: 10, + vectorResultCount: 0, + bm25ResultCount: 0, + fusedResultCount: 0, + finalResultCount: 0, + stageCounts: { + afterMinScore: 0, + rerankInput: 0, + afterRerank: 0, + afterRecency: 0, + afterImportance: 0, + afterLengthNorm: 0, + afterTimeDecay: 0, + afterHardMinScore: 0, + afterNoiseFilter: 0, + afterDiversity: 0, + }, + dropSummary: [], + failureStage: "hybrid.embedQuery", + errorMessage: "simulated json search failure", + }, + }); + }); +});