|
25 | 25 | from .image_reflection import reflect_assistant_images |
26 | 26 | from .mcp_registry import MCPServerConfig, MCPToolAggregator |
27 | 27 | from .streaming import SseEvent, StreamingHandler |
28 | | -from .tool_context_planner import ToolContextPlanner |
| 28 | +from .tool_context_planner import merge_model_tool_plan, ToolContextPlanner |
29 | 29 |
|
30 | 30 | _TOOL_RATIONALE_INSTRUCTION = ( |
31 | 31 | "Before each tool call, emit numbered one-sentence rationales in order (e.g.," |
|
41 | 41 | _MAX_RANKED_TOOLS = 5 |
42 | 42 |
|
43 | 43 |
|
| 44 | +def _compact_tool_digest( |
| 45 | + digest: dict[str, list[dict[str, Any]]] | None, |
| 46 | +) -> dict[str, list[dict[str, Any]]]: |
| 47 | + if not digest: |
| 48 | + return {} |
| 49 | + |
| 50 | + compact: dict[str, list[dict[str, Any]]] = {} |
| 51 | + for context, entries in digest.items(): |
| 52 | + if not isinstance(entries, list): |
| 53 | + continue |
| 54 | + filtered: list[dict[str, Any]] = [] |
| 55 | + for entry in entries: |
| 56 | + if not isinstance(entry, dict): |
| 57 | + continue |
| 58 | + name = entry.get("name") |
| 59 | + if not isinstance(name, str) or not name.strip(): |
| 60 | + continue |
| 61 | + compact_entry: dict[str, Any] = {"name": name.strip()} |
| 62 | + description = entry.get("description") |
| 63 | + if isinstance(description, str) and description.strip(): |
| 64 | + compact_entry["description"] = description.strip() |
| 65 | + parameters = entry.get("parameters") |
| 66 | + if isinstance(parameters, dict) and parameters: |
| 67 | + compact_entry["parameters"] = parameters |
| 68 | + server = entry.get("server") |
| 69 | + if isinstance(server, str) and server.strip(): |
| 70 | + compact_entry["server"] = server.strip() |
| 71 | + score = entry.get("score") |
| 72 | + if isinstance(score, (int, float)): |
| 73 | + compact_entry["score"] = float(score) |
| 74 | + filtered.append(compact_entry) |
| 75 | + if filtered: |
| 76 | + compact[context] = filtered |
| 77 | + return compact |
| 78 | + |
| 79 | + |
44 | 80 | def _iter_attachment_ids(content: Any) -> Iterable[str]: |
45 | 81 | if isinstance(content, list): |
46 | 82 | for item in content: |
@@ -273,38 +309,81 @@ async def process_stream( |
273 | 309 | conversation, |
274 | 310 | capability_digest=capability_digest, |
275 | 311 | ) |
276 | | - plan_payload = plan.to_dict() |
277 | 312 | contexts = plan.contexts_for_attempt(0) |
278 | 313 | ranked_tool_names: list[str] = [] |
279 | 314 | ranked_digest: dict[str, list[dict[str, Any]]] | None = None |
280 | | - if contexts: |
281 | | - digest_for_contexts = getattr(self._mcp_client, "get_capability_digest", None) |
282 | | - if callable(digest_for_contexts): |
283 | | - try: |
284 | | - ranked_digest = digest_for_contexts( |
285 | | - contexts, limit=_MAX_RANKED_TOOLS, include_global=False |
286 | | - ) |
287 | | - except Exception as exc: # pragma: no cover - defensive fallback |
288 | | - logger.debug( |
289 | | - "Failed to obtain ranked capability digest for contexts %s: %s", |
290 | | - contexts, |
291 | | - exc, |
292 | | - ) |
| 315 | + digest_for_contexts = getattr(self._mcp_client, "get_capability_digest", None) |
| 316 | + if contexts and callable(digest_for_contexts): |
| 317 | + try: |
| 318 | + ranked_digest = digest_for_contexts( |
| 319 | + contexts, limit=_MAX_RANKED_TOOLS, include_global=False |
| 320 | + ) |
| 321 | + except Exception as exc: # pragma: no cover - defensive fallback |
| 322 | + logger.debug( |
| 323 | + "Failed to obtain ranked capability digest for contexts %s: %s", |
| 324 | + contexts, |
| 325 | + exc, |
| 326 | + ) |
| 327 | + ranked_digest = {} |
| 328 | + plan_request_digest = _compact_tool_digest(ranked_digest) |
| 329 | + if contexts or plan.broad_search: |
| 330 | + try: |
| 331 | + planner_response = await self._client.request_tool_plan( |
| 332 | + request=request, |
| 333 | + conversation=conversation, |
| 334 | + tool_digest=plan_request_digest, |
| 335 | + ) |
| 336 | + except Exception as exc: # pragma: no cover - remote planner is best effort |
| 337 | + logger.debug("Remote tool planning failed: %s", exc) |
| 338 | + else: |
| 339 | + merged_plan = merge_model_tool_plan(plan, planner_response) |
| 340 | + plan = merged_plan |
| 341 | + contexts = plan.contexts_for_attempt(0) |
| 342 | + if contexts and callable(digest_for_contexts): |
| 343 | + try: |
| 344 | + ranked_digest = digest_for_contexts( |
| 345 | + contexts, limit=_MAX_RANKED_TOOLS, include_global=False |
| 346 | + ) |
| 347 | + except Exception as exc: # pragma: no cover - defensive fallback |
| 348 | + logger.debug( |
| 349 | + "Failed to obtain ranked capability digest for contexts %s: %s", |
| 350 | + contexts, |
| 351 | + exc, |
| 352 | + ) |
| 353 | + ranked_digest = {} |
| 354 | + else: |
293 | 355 | ranked_digest = {} |
294 | | - if ranked_digest: |
295 | | - seen_names: set[str] = set() |
296 | | - for context in contexts: |
297 | | - entries = ranked_digest.get(context) or [] |
298 | | - for entry in entries: |
299 | | - if not isinstance(entry, dict): |
300 | | - continue |
301 | | - name = entry.get("name") |
302 | | - if not isinstance(name, str) or not name: |
303 | | - continue |
304 | | - if name in seen_names: |
305 | | - continue |
306 | | - seen_names.add(name) |
307 | | - ranked_tool_names.append(name) |
| 356 | + |
| 357 | + if plan.candidate_tools: |
| 358 | + seen_names: set[str] = set() |
| 359 | + ordered_contexts = list(contexts) |
| 360 | + if "__all__" in plan.candidate_tools and "__all__" not in ordered_contexts: |
| 361 | + ordered_contexts.append("__all__") |
| 362 | + for context in ordered_contexts: |
| 363 | + candidates = plan.candidate_tools.get(context) or [] |
| 364 | + for candidate in candidates: |
| 365 | + name = candidate.name.strip() |
| 366 | + if not name or name in seen_names: |
| 367 | + continue |
| 368 | + seen_names.add(name) |
| 369 | + ranked_tool_names.append(name) |
| 370 | + |
| 371 | + if ranked_digest: |
| 372 | + seen_names = set(ranked_tool_names) |
| 373 | + for context in contexts: |
| 374 | + entries = ranked_digest.get(context) or [] |
| 375 | + for entry in entries: |
| 376 | + if not isinstance(entry, dict): |
| 377 | + continue |
| 378 | + name = entry.get("name") |
| 379 | + if not isinstance(name, str) or not name: |
| 380 | + continue |
| 381 | + if name in seen_names: |
| 382 | + continue |
| 383 | + seen_names.add(name) |
| 384 | + ranked_tool_names.append(name) |
| 385 | + |
| 386 | + plan_payload = plan.to_dict() |
308 | 387 | request_event_id: str | None = None |
309 | 388 | if isinstance(request.metadata, dict): |
310 | 389 | candidate = request.metadata.get("client_request_id") |
|
0 commit comments