Skip to content

Commit 8145afc

Browse files
committed
improve simple_rag
1 parent fc3a4d6 commit 8145afc

File tree

3 files changed

+29
-30
lines changed

3 files changed

+29
-30
lines changed

.env.example

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,11 @@
11
# To separate your traces from other application
22
LANGSMITH_PROJECT=rag-research-agent
33

4-
# The following depend on your selected configuration
5-
6-
# LLM choice:
7-
ANTHROPIC_API_KEY=....
8-
FIREWORKS_API_KEY=...
94
OPENAI_API_KEY=...
105

11-
# Retrieval provider
12-
13-
## Elastic cloud:
14-
ELASTICSEARCH_URL=...
15-
ELASTICSEARCH_API_KEY=...
16-
17-
## Elastic local:
18-
ELASTICSEARCH_URL=http://host.docker.internal:9200
19-
ELASTICSEARCH_USER=elastic
20-
ELASTICSEARCH_PASSWORD=changeme
21-
226
## Pinecone
237
PINECONE_API_KEY=...
248
PINECONE_INDEX_NAME=...
259

26-
## Mongo Atlas
27-
MONGODB_URI=... # Full connection string
28-
2910
## Index API key
3011
INDEX_API_KEY=...

src/simple_rag/configuration.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
"""Define the configurable parameters for the agent."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass, field
6+
from typing import Annotated
7+
8+
from shared.configuration import BaseConfiguration
9+
10+
11+
@dataclass(kw_only=True)
class RagConfiguration(BaseConfiguration):
    """Configurable parameters for the simple RAG agent.

    Extends the shared ``BaseConfiguration`` with the model choice used by
    this graph's generation step.
    """

    # Chat model identifier in "provider/model-name" form, resolved at
    # runtime via load_chat_model().
    model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        metadata={
            "description": "The language model used for processing and refining queries. Should be in the form: provider/model-name."
        },
        default="openai/gpt-4o-mini",
    )

src/simple_rag/graph.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22

33
from langchain import hub
44
from langchain_core.messages import HumanMessage
5-
from langchain_openai import ChatOpenAI
65
from langgraph.graph import END, START, StateGraph
76

87
from shared import retrieval
9-
from shared.configuration import BaseConfiguration
8+
from shared.utils import load_chat_model
9+
from simple_rag.configuration import RagConfiguration
1010
from simple_rag.state import GraphState, InputState
1111

1212

13-
def retrieve(state: GraphState, *, config) -> dict[str, list[str] | str]:
13+
def retrieve(state: GraphState, *, config: RagConfiguration) -> dict[str, list[str] | str]:
1414
"""Retrieve documents
1515
1616
Args:
@@ -29,7 +29,7 @@ def retrieve(state: GraphState, *, config) -> dict[str, list[str] | str]:
2929
return {"documents": documents, "message": state.messages}
3030

3131

32-
async def generate(state: GraphState):
32+
async def generate(state: GraphState, *, config: RagConfiguration):
3333
"""
3434
Generate answer
3535
@@ -43,21 +43,18 @@ async def generate(state: GraphState):
4343
messages = state.messages
4444
documents = state.documents
4545

46-
# RAG generation
47-
# Prompt
4846
prompt = hub.pull("langchaindoc/simple-rag")
49-
50-
# LLM
51-
llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
5247

48+
configuration = RagConfiguration.from_runnable_config(config)
49+
model = load_chat_model(configuration.model)
5350

5451
# Chain
55-
rag_chain = prompt + messages | llm
52+
rag_chain = prompt + messages | model
5653
response = await rag_chain.ainvoke({"context" : documents})
5754
return {"messages": [response], "documents": documents}
5855

5956

60-
workflow = StateGraph(GraphState, input=InputState, config_schema=BaseConfiguration)
57+
workflow = StateGraph(GraphState, input=InputState, config_schema=RagConfiguration)
6158

6259
# Define the nodes
6360
workflow.add_node("retrieve", retrieve)

0 commit comments

Comments
 (0)