Skip to content

Commit 01f641b

Browse files
committed
Add OpenRouter planning prepass for tool contexts
1 parent b06a361 commit 01f641b

4 files changed

Lines changed: 549 additions & 43 deletions

File tree

src/backend/chat/orchestrator.py

Lines changed: 108 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from .image_reflection import reflect_assistant_images
2626
from .mcp_registry import MCPServerConfig, MCPToolAggregator
2727
from .streaming import SseEvent, StreamingHandler
28-
from .tool_context_planner import ToolContextPlanner
28+
from .tool_context_planner import merge_model_tool_plan, ToolContextPlanner
2929

3030
_TOOL_RATIONALE_INSTRUCTION = (
3131
"Before each tool call, emit numbered one-sentence rationales in order (e.g.,"
@@ -41,6 +41,42 @@
4141
_MAX_RANKED_TOOLS = 5
4242

4343

44+
def _compact_tool_digest(
45+
digest: dict[str, list[dict[str, Any]]] | None,
46+
) -> dict[str, list[dict[str, Any]]]:
47+
if not digest:
48+
return {}
49+
50+
compact: dict[str, list[dict[str, Any]]] = {}
51+
for context, entries in digest.items():
52+
if not isinstance(entries, list):
53+
continue
54+
filtered: list[dict[str, Any]] = []
55+
for entry in entries:
56+
if not isinstance(entry, dict):
57+
continue
58+
name = entry.get("name")
59+
if not isinstance(name, str) or not name.strip():
60+
continue
61+
compact_entry: dict[str, Any] = {"name": name.strip()}
62+
description = entry.get("description")
63+
if isinstance(description, str) and description.strip():
64+
compact_entry["description"] = description.strip()
65+
parameters = entry.get("parameters")
66+
if isinstance(parameters, dict) and parameters:
67+
compact_entry["parameters"] = parameters
68+
server = entry.get("server")
69+
if isinstance(server, str) and server.strip():
70+
compact_entry["server"] = server.strip()
71+
score = entry.get("score")
72+
if isinstance(score, (int, float)):
73+
compact_entry["score"] = float(score)
74+
filtered.append(compact_entry)
75+
if filtered:
76+
compact[context] = filtered
77+
return compact
78+
79+
4480
def _iter_attachment_ids(content: Any) -> Iterable[str]:
4581
if isinstance(content, list):
4682
for item in content:
@@ -273,38 +309,81 @@ async def process_stream(
273309
conversation,
274310
capability_digest=capability_digest,
275311
)
276-
plan_payload = plan.to_dict()
277312
contexts = plan.contexts_for_attempt(0)
278313
ranked_tool_names: list[str] = []
279314
ranked_digest: dict[str, list[dict[str, Any]]] | None = None
280-
if contexts:
281-
digest_for_contexts = getattr(self._mcp_client, "get_capability_digest", None)
282-
if callable(digest_for_contexts):
283-
try:
284-
ranked_digest = digest_for_contexts(
285-
contexts, limit=_MAX_RANKED_TOOLS, include_global=False
286-
)
287-
except Exception as exc: # pragma: no cover - defensive fallback
288-
logger.debug(
289-
"Failed to obtain ranked capability digest for contexts %s: %s",
290-
contexts,
291-
exc,
292-
)
315+
digest_for_contexts = getattr(self._mcp_client, "get_capability_digest", None)
316+
if contexts and callable(digest_for_contexts):
317+
try:
318+
ranked_digest = digest_for_contexts(
319+
contexts, limit=_MAX_RANKED_TOOLS, include_global=False
320+
)
321+
except Exception as exc: # pragma: no cover - defensive fallback
322+
logger.debug(
323+
"Failed to obtain ranked capability digest for contexts %s: %s",
324+
contexts,
325+
exc,
326+
)
327+
ranked_digest = {}
328+
plan_request_digest = _compact_tool_digest(ranked_digest)
329+
if contexts or plan.broad_search:
330+
try:
331+
planner_response = await self._client.request_tool_plan(
332+
request=request,
333+
conversation=conversation,
334+
tool_digest=plan_request_digest,
335+
)
336+
except Exception as exc: # pragma: no cover - remote planner is best effort
337+
logger.debug("Remote tool planning failed: %s", exc)
338+
else:
339+
merged_plan = merge_model_tool_plan(plan, planner_response)
340+
plan = merged_plan
341+
contexts = plan.contexts_for_attempt(0)
342+
if contexts and callable(digest_for_contexts):
343+
try:
344+
ranked_digest = digest_for_contexts(
345+
contexts, limit=_MAX_RANKED_TOOLS, include_global=False
346+
)
347+
except Exception as exc: # pragma: no cover - defensive fallback
348+
logger.debug(
349+
"Failed to obtain ranked capability digest for contexts %s: %s",
350+
contexts,
351+
exc,
352+
)
353+
ranked_digest = {}
354+
else:
293355
ranked_digest = {}
294-
if ranked_digest:
295-
seen_names: set[str] = set()
296-
for context in contexts:
297-
entries = ranked_digest.get(context) or []
298-
for entry in entries:
299-
if not isinstance(entry, dict):
300-
continue
301-
name = entry.get("name")
302-
if not isinstance(name, str) or not name:
303-
continue
304-
if name in seen_names:
305-
continue
306-
seen_names.add(name)
307-
ranked_tool_names.append(name)
356+
357+
if plan.candidate_tools:
358+
seen_names: set[str] = set()
359+
ordered_contexts = list(contexts)
360+
if "__all__" in plan.candidate_tools and "__all__" not in ordered_contexts:
361+
ordered_contexts.append("__all__")
362+
for context in ordered_contexts:
363+
candidates = plan.candidate_tools.get(context) or []
364+
for candidate in candidates:
365+
name = candidate.name.strip()
366+
if not name or name in seen_names:
367+
continue
368+
seen_names.add(name)
369+
ranked_tool_names.append(name)
370+
371+
if ranked_digest:
372+
seen_names = set(ranked_tool_names)
373+
for context in contexts:
374+
entries = ranked_digest.get(context) or []
375+
for entry in entries:
376+
if not isinstance(entry, dict):
377+
continue
378+
name = entry.get("name")
379+
if not isinstance(name, str) or not name:
380+
continue
381+
if name in seen_names:
382+
continue
383+
seen_names.add(name)
384+
ranked_tool_names.append(name)
385+
386+
plan_payload = plan.to_dict()
308387
request_event_id: str | None = None
309388
if isinstance(request.metadata, dict):
310389
candidate = request.metadata.get("client_request_id")

0 commit comments

Comments
 (0)