-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph.py
More file actions
67 lines (52 loc) · 2.62 KB
/
graph.py
File metadata and controls
67 lines (52 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Annotated, TypedDict, List, Union
# LangChain core message types
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
from agent_base import AgentBase
# -------------------------
# LangGraph shared state definition
# -------------------------
# Any chat message that may appear in the conversation history.
# NOTE: not to be confused with langchain_core's ``AIMessage`` -- this alias is
# the union of all four message roles (system, human, AI, tool).
AiMessage = Union[
    SystemMessage,
    HumanMessage,
    AIMessage,
    ToolMessage,
]
class GraphState(TypedDict):
    """Shared LangGraph state: every message exchanged in the conversation.

    The ``"messages"`` string inside ``Annotated`` is a plain label, not a
    reducer callable. Presumably LangGraph therefore treats this key as a
    last-value channel, so nodes are expected to return the complete message
    list rather than only the new messages. Confirm against the LangGraph
    version in use.
    """

    messages: Annotated[List[AiMessage], "messages"]
# NOTE(review): the author was undecided whether this helper should be a
# permanent part of the system or only an initial setup aid.
class SystemHelper:
    """Expose an LLM agent and a tool registry as LangGraph node callables.

    ``invoke`` asks the agent for the next message and appends it.
    ``invoke_tool`` executes the tool calls requested by the last message and
    appends the results.

    Args:
        llm: Agent whose ``call(messages)`` method produces the next message.
        tools: Mapping of tool name to a tool object exposing
            ``invoke(args)``. ``None`` (or an empty mapping) means no tools.
    """

    def __init__(self, llm: AgentBase, tools: Union[dict, None] = None):
        self._llm = llm
        # A falsy argument (None or {}) is replaced by a fresh per-instance dict.
        self._tool_registry = tools if tools else {}

    def _call_agent(self, messages: List[AiMessage]) -> AiMessage:
        """Return whatever ``llm.call`` produces for *messages*.

        The result is presumably a chat message; it is appended to the history
        as-is.
        """
        return self._llm.call(messages)

    def _call_tool(
        self, tool_name: str, tool_args: dict, call_id: str
    ) -> List[ToolMessage]:
        """Run a registered tool and wrap its output for the conversation.

        Returns:
            A one-element list holding a ``ToolMessage`` whose content is
            ``str(output)`` and whose ``tool_call_id`` is *call_id*.

        Raises:
            NameError: if *tool_name* is not in the registry.
            Exception: anything raised by the tool's ``invoke``.
        """
        if tool_name not in self._tool_registry:
            raise NameError(f"⚠️ Tool {tool_name} not found in registry")
        output = self._tool_registry[tool_name].invoke(tool_args)
        return [ToolMessage(tool_call_id=call_id, content=str(output))]

    @staticmethod
    def _tool_error_message(
        call_id: Union[str, None], error: Exception
    ) -> AiMessage:
        """Build the history entry reporting a failed tool call.

        When the call has an id, the error is returned as a ``ToolMessage``
        answering that id. Chat APIs that enforce the tool-call protocol (for
        example OpenAI-style APIs) reject a history in which a requested tool
        call is left unanswered. Without an id there is nothing to answer, so
        an ``AIMessage`` carrying the error text is used instead.
        """
        if call_id:
            return ToolMessage(tool_call_id=call_id, content=str(error))
        return AIMessage(content=str(error))

    # This signature is the node API the graph expects; keep it stable.
    def invoke(self, state: GraphState) -> GraphState:
        """Graph node: ask the agent for the next message and append it.

        An agent failure is not raised. Its error text is appended as an
        ``AIMessage`` so the graph can continue. ``state["messages"]`` is
        replaced in place with the extended list, and *state* is returned.
        """
        try:
            response = self._call_agent(state["messages"])
        except Exception as e:
            # Deliberate best-effort: surface the failure in the conversation.
            print(f"❌ Agent invocation failed: {str(e)}")
            response = AIMessage(content=str(e))
        state["messages"] = state["messages"] + [response]
        return state

    def invoke_tool(self, state: GraphState) -> GraphState:
        """Graph node: execute every tool call requested by the last message.

        One result message is appended per tool call, in request order. A
        failing tool does not stop the remaining ones; its error text is
        reported via ``_tool_error_message``. If there are no messages, or the
        last message requests no tools, *state* is returned with its messages
        unchanged.
        """
        messages = state["messages"]
        last_msg = messages[-1] if messages else None
        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            print("⚠️ No tool calls in last message")
            return state

        results = []
        for tool_call in tool_calls:
            name = tool_call.get("name")
            args = tool_call.get("args", {})
            call_id = tool_call.get("id")
            try:
                results.extend(self._call_tool(name, args, call_id))
            except Exception as e:
                # Report the failure for this call only; keep processing others.
                print(f"❌ Tool invocation failed: {str(e)}")
                results.append(self._tool_error_message(call_id, e))

        state["messages"] = messages + results
        return state