from langchain import hub
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph

from shared import retrieval
from shared.utils import load_chat_model
from simple_rag.configuration import RagConfiguration
from simple_rag.state import GraphState, InputState
13- def retrieve (state : GraphState , * , config ) -> dict [str , list [str ] | str ]:
13+ def retrieve (state : GraphState , * , config : RagConfiguration ) -> dict [str , list [str ] | str ]:
1414 """Retrieve documents
1515
1616 Args:
@@ -29,7 +29,7 @@ def retrieve(state: GraphState, *, config) -> dict[str, list[str] | str]:
2929 return {"documents" : documents , "message" : state .messages }
3030
3131
async def generate(state: GraphState, *, config: RunnableConfig):
    """Generate an answer from the retrieved documents.

    Pulls the RAG prompt from the LangChain hub, loads the chat model
    named in the run's configuration, and invokes prompt + messages
    against the retrieved documents.

    Args:
        state (GraphState): Current graph state; reads `messages` and
            `documents`.
        config (RunnableConfig): Per-invocation runnable config supplied
            by LangGraph; parsed into a RagConfiguration to select the
            chat model. (Note: this is a RunnableConfig, not a
            RagConfiguration instance — `from_runnable_config` does the
            conversion.)

    Returns:
        dict: State update with the model response appended to
        `messages` and `documents` passed through unchanged.
    """
    messages = state.messages
    documents = state.documents

    prompt = hub.pull("langchaindoc/simple-rag")

    # Derive the typed configuration from the raw runnable config, then
    # load whichever chat model it names (replaces a hard-coded model).
    configuration = RagConfiguration.from_runnable_config(config)
    model = load_chat_model(configuration.model)

    # Chain
    rag_chain = prompt + messages | model
    response = await rag_chain.ainvoke({"context": documents})
    return {"messages": [response], "documents": documents}
5855
5956
60- workflow = StateGraph (GraphState , input = InputState , config_schema = BaseConfiguration )
57+ workflow = StateGraph (GraphState , input = InputState , config_schema = RagConfiguration )
6158
6259# Define the nodes
6360workflow .add_node ("retrieve" , retrieve )