-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlanggraph_database_backend.py
More file actions
40 lines (29 loc) · 1.09 KB
/
langgraph_database_backend.py
File metadata and controls
40 lines (29 loc) · 1.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Annotated
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph.message import add_messages
from dotenv import load_dotenv
import sqlite3
# Load variables from a .env file into the process environment before the
# model client is constructed.
# NOTE(review): presumably this is where ChatOpenAI finds its API key - confirm.
load_dotenv()
# Module-level chat model client built with default settings; chat_node calls
# llm.invoke on it.
llm = ChatOpenAI()
class ChatState(TypedDict):
    """State dictionary passed through the chat graph."""

    # Conversation messages. The Annotated metadata attaches add_messages to
    # this field.
    # NOTE(review): presumably LangGraph uses add_messages as the reducer that
    # merges a node's returned messages into the existing list - confirm.
    messages: Annotated[list[BaseMessage], add_messages]
def chat_node(state: ChatState):
    """Send the conversation so far to the chat model and return its reply.

    Args:
        state: Graph state; only its ``'messages'`` entry is read.

    Returns:
        A dict with a single ``"messages"`` key whose value is a one-element
        list holding whatever ``llm.invoke`` returned.
    """
    reply = llm.invoke(state['messages'])
    return {"messages": [reply]}
# SQLite database file used for persistence. The path is relative, so it is
# resolved against the process's working directory. check_same_thread=False
# lifts sqlite3's default restriction that a connection may only be used from
# the thread that created it.
conn = sqlite3.connect(database='chatbot.db', check_same_thread=False)
# Checkpointer backed by the SQLite connection above.
checkpointer = SqliteSaver(conn=conn)

# Single-node graph over ChatState: START -> chat_node -> END.
graph = StateGraph(ChatState)
graph.add_node("chat_node", chat_node)
graph.add_edge(START, "chat_node")
graph.add_edge("chat_node", END)

# Compile the graph with the checkpointer attached.
# NOTE(review): presumably callers pick a conversation by passing
# config['configurable']['thread_id'] when invoking chatbot - confirm against
# the calling code.
chatbot = graph.compile(checkpointer=checkpointer)
def retrieve_all_threads() -> list:
    """Return the distinct thread ids that have at least one saved checkpoint.

    Every checkpoint yielded by ``checkpointer.list(None)`` is scanned and its
    ``config['configurable']['thread_id']`` value is collected.

    Returns:
        A list of unique thread ids in first-seen order, i.e. the order in
        which ``checkpointer.list`` first yields each thread. The list is
        empty when no checkpoints are yielded.
    """
    # dict.fromkeys de-duplicates while keeping insertion order. Building a
    # set and converting it to a list returned the ids in arbitrary order,
    # which for string ids can differ from one interpreter run to the next.
    return list(
        dict.fromkeys(
            checkpoint.config['configurable']['thread_id']
            for checkpoint in checkpointer.list(None)
        )
    )