-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgraph.py
More file actions
91 lines (72 loc) · 3.68 KB
/
graph.py
File metadata and controls
91 lines (72 loc) · 3.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from typing import Literal, Optional, TypedDict, List
from copy import deepcopy
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableConfig
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from agents.word_explain_agent import invoke as word_agent
from agents.code_check_agent import invoke as code_agent
from agents.email_agent import invoke as email_agent
from agents.matching_agent import invoke as matching_agent
from agents.exception_agent import invoke as exception_agent
from agents.find_report_agent import invoke as find_report_agent
from agents.report_writing_guide_agent import invoke as report_writing_guide_agent
from agent_state import AgentState
# 라우팅 프롬프트 체인 정의
router_prompt = PromptTemplate.from_template("""
당신은 사용자의 질문 또는 코드 입력을 보고 아래 중 어떤 기능이 필요한지 판단하는 AI 라우터입니다.
각 기능의 목적은 다음과 같습니다:
1. word_explain: 프로젝트 관련 용어나 개념 설명
2. code_check: 사용자가 작성한 코드에 대해 규칙 검토
3. find_report_agent: 사용자 질의 내용이 문서나 보고서를 찾아달라고 하는 것 같을때
4. report_writing_guide_agent: 사용자 질의 내용이 문서나 보고서 작성에 대해 도움을 요청하는 것 같을때
5. email_agent: 이메일 작성 요청
6. matching_agent: 특정 담당자를 묻는 질문
7. exception_agent: 위 항목들에 해당하지 않음
사용자가 입력한 내용이 코드처럼 보이면 'code_check'로 판단하세요.
다음 사용자 입력에 가장 적합한 기능 이름만 한 단어로 출력해주세요. (예: code_check)
입력:
{input_query}
""")
router_chain = router_prompt | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()
# 라우팅 함수
def route_agent(state: AgentState) -> Literal[
"word_explain", "code_check", "find_report_agent",
"report_writing_guide_agent", "email_agent", "matching_agent", "exception_agent"
]:
result = router_chain.invoke({"input_query": state["input_query"]}).strip().lower()
print(f"🧭 라우팅 결과: {result}")
if result in {
"word_explain", "code_check", "find_report_agent",
"report_writing_guide_agent", "email_agent", "matching_agent"
}:
return result
return "exception_agent"
# Supervisor Graph 생성 함수
from copy import deepcopy
def create_supervisor_graph():
builder = StateGraph(AgentState)
def wrap_agent(agent_func):
def wrapper(state: AgentState, config: RunnableConfig) -> AgentState:
result = agent_func(state, config)
new_state = deepcopy(state)
new_state.update(result)
return new_state
return wrapper
builder.add_node("word_explain", wrap_agent(word_agent))
builder.add_node("code_check", wrap_agent(code_agent))
builder.add_node("find_report_agent", wrap_agent(find_report_agent))
builder.add_node("report_writing_guide_agent", wrap_agent(report_writing_guide_agent))
builder.add_node("email_agent", wrap_agent(email_agent))
builder.add_node("matching_agent", wrap_agent(matching_agent))
builder.add_node("exception_agent", wrap_agent(exception_agent))
builder.set_conditional_entry_point(route_agent)
builder.add_edge("word_explain", END)
builder.add_edge("code_check", END)
builder.add_edge("find_report_agent", END)
builder.add_edge("report_writing_guide_agent", END)
builder.add_edge("email_agent", END)
builder.add_edge("matching_agent", END)
builder.add_edge("exception_agent", END)
return builder.compile()