-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconditional_routing.py
More file actions
134 lines (100 loc) · 4.39 KB
/
conditional_routing.py
File metadata and controls
134 lines (100 loc) · 4.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Conditional fan-out variant — routes to only the relevant sub-agents.
Uses structured output to classify the query and selectively dispatch
agents instead of always hitting all three.
"""
import operator
from collections.abc import Sequence
from typing import Annotated, TypedDict
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import END, START, StateGraph
from langgraph.types import RetryPolicy
from langsmith import traceable, tracing_context
from pydantic import BaseModel, Field
# Shared chat model for every node; temperature=0 makes routing and synthesis
# as deterministic as the provider allows.
llm = ChatAnthropic(model="claude-sonnet-4-5-20250929", temperature=0)
class State(TypedDict):
    """Shared graph state passed between all nodes."""
    # The customer/research question being answered.
    question: str
    # Fan-in accumulator: each dispatched sub-agent appends one finding dict;
    # operator.add merges concurrent branch updates instead of overwriting.
    research_results: Annotated[list[dict], operator.add]
    # Final synthesized answer written by the synthesize node.
    final_response: str
# --- Sub-agents (same as parallel_agents.py) ---
def make_agent(name: str, focus: str):
    """Build a traced research sub-agent node scoped to a single focus area.

    The returned node queries the shared LLM with a focus-specific system
    prompt and appends one finding to ``research_results``.
    """

    @traceable(name=name, run_type="chain")
    def node(state: State) -> dict:
        prompt = [
            SystemMessage(
                content=f"You are the {name} agent. Focus on {focus}. "
                "Return a concise summary. Cite your source type."
            ),
            HumanMessage(content=f"Research query: {state['question']}"),
        ]
        reply = llm.invoke(prompt)
        finding = {"source": name, "content": reply.content}
        return {"research_results": [finding]}

    return node
# Concrete sub-agents, one per research source type. The trailing periods in
# the focus strings feed directly into the agents' system prompts.
kb_agent = make_agent("knowledge_base", "internal knowledge base searches.")
web_agent = make_agent("web_search", "recent news and industry trends.")
policy_agent = make_agent("policy", "compliance, legal, and regulatory frameworks.")
@traceable(name="Synthesizer", run_type="chain")
def synthesize(state: State) -> dict:
    """Fold all accumulated sub-agent findings into one final answer.

    Policy findings are instructed to take precedence over conflicting
    research from other agents.
    """
    findings = [
        f"[{r['source']}]: {r['content']}" for r in state["research_results"]
    ]
    context = "\n\n".join(findings)
    messages = [
        SystemMessage(
            content="Synthesize the following research into a clear, actionable "
            "response. When policy information conflicts with or constrains "
            "other responses, the policy statement takes precedence. "
            "Never soften or omit policy restrictions."
        ),
        HumanMessage(
            content=f"Customer question: {state['question']}\n\n"
            f"Research findings:\n{context}"
        ),
    ]
    answer = llm.invoke(messages)
    return {"final_response": answer.content}
# --- Conditional routing ---
class RoutingPlan(BaseModel):
    """Structured classifier output: which sub-agent nodes to dispatch."""
    # Intended values are the graph node names "kb", "web", "policy".
    # NOTE(review): the schema does not restrict values to these names
    # (a Literal would) — the model may emit anything; verify downstream use.
    agents: list[str] = Field(description="Agents to activate: kb, web, policy")
# LLM wrapper that parses the model's reply into a RoutingPlan instance.
structured_llm = llm.with_structured_output(RoutingPlan)
@traceable(name="Classifier", run_type="chain")
def classify_and_route(state: State) -> Sequence[str]:
    """Classify the query and return the graph node names to dispatch.

    Returns a non-empty subset of ("kb", "web", "policy"). The classifier's
    output is filtered against the known node names because nothing in the
    RoutingPlan schema prevents the LLM from emitting an unknown name
    (e.g. "knowledge_base"), which the graph would reject at runtime.
    Falls back to ["kb"] when the filtered plan is empty.
    """
    valid_agents = ("kb", "web", "policy")
    plan = structured_llm.invoke([
        SystemMessage(
            content="Decide which research agents to invoke. "
            "Available: kb, web, policy. When in doubt, include the agent."
        ),
        HumanMessage(content=state["question"]),
    ])
    # Drop hallucinated names so add_conditional_edges only sees real nodes.
    selected = [agent for agent in plan.agents if agent in valid_agents]
    return selected or ["kb"]
# --- Graph with conditional fan-out ---
builder = StateGraph(State)
# Each research agent retries up to 3 times on transient failures.
builder.add_node("kb", kb_agent, retry=RetryPolicy(max_attempts=3))
builder.add_node("web", web_agent, retry=RetryPolicy(max_attempts=3))
builder.add_node("policy", policy_agent, retry=RetryPolicy(max_attempts=3))
builder.add_node("synthesize", synthesize)
# Conditional fan-out from START based on classifier: classify_and_route
# returns the subset of ["kb", "web", "policy"] to run in parallel.
builder.add_conditional_edges(START, classify_and_route, ["kb", "web", "policy"])
# Individual edges from each agent to synthesize (not list-style fan-in,
# because with conditional routing some branches may not be dispatched)
builder.add_edge("kb", "synthesize")
builder.add_edge("web", "synthesize")
builder.add_edge("policy", "synthesize")
builder.add_edge("synthesize", END)
graph = builder.compile()
if __name__ == "__main__":
    # Demo queries chosen to exercise different routing outcomes.
    queries = [
        "What is our refund policy for enterprise clients?",
        "How does GDPR affect our data pipeline architecture?",
        "What competitors launched AI features last quarter?",
    ]
    for question in queries:
        print("=" * 60)
        print(f"QUESTION: {question}")
        # Tag the trace so the runs are easy to find in LangSmith.
        with tracing_context(
            metadata={"example": True},
            tags=["article-01", "conditional-routing"],
        ):
            outcome = graph.invoke({"question": question})
            dispatched = [item["source"] for item in outcome["research_results"]]
            print(f"Agents dispatched: {dispatched}")
            print(f"Response: {outcome['final_response'][:200]}...")
            print()