-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlanguage_models.py
More file actions
175 lines (148 loc) · 6.55 KB
/
language_models.py
File metadata and controls
175 lines (148 loc) · 6.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_xai import ChatXAI
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.messages import AIMessage
from dotenv import load_dotenv
load_dotenv()
# Registry of supported model identifiers mapped to a provider key.
# get_model() dispatches on the provider key to pick the langchain chat
# class and the environment variable holding that provider's API key.
MODELS = {
"grok-4-0709": "xai",
"deepseek-r1-distill-llama-70b": "groq",
"llama-3.3-70b-versatile": "groq",
"gpt-4.1-nano-2025-04-14": "openai",
"gpt-4.1-2025-04-14": "openai",
"o3-2025-04-16": "openai",
"claude-sonnet-4-20250514": "anthropic",
"claude-opus-4-20250514": "anthropic",
"models/gemini-2.5-flash": "google",
"gemini-2.5-pro": "google",
}
# Apply monkey patch for Gemini finish_reason issue
def _apply_gemini_patch():
"""Apply patch to fix the finish_reason.name AttributeError in langchain_google_genai"""
try:
from langchain_google_genai import chat_models
# Check if already patched
if hasattr(chat_models._response_to_result, '_is_patched'):
return
# Store the original function
original_response_to_result = chat_models._response_to_result
def patched_response_to_result(response, stream=False):
"""Patched version that handles finish_reason as int"""
try:
# Try the original function first
return original_response_to_result(response, stream)
except AttributeError as e:
if "'int' object has no attribute 'name'" in str(e):
# Handle the response manually
generations = []
llm_output = {}
# Map finish reason integers to names
finish_reason_map = {
0: "FINISH_REASON_UNSPECIFIED",
1: "STOP",
2: "MAX_TOKENS",
3: "SAFETY",
4: "RECITATION",
5: "OTHER"
}
for candidate in response.candidates:
# Extract content
content = ""
if hasattr(candidate, 'content') and candidate.content:
if hasattr(candidate.content, 'parts'):
for part in candidate.content.parts:
if hasattr(part, 'text'):
content += part.text
# Handle finish_reason
finish_reason = getattr(candidate, 'finish_reason', 0)
if isinstance(finish_reason, int):
finish_reason_name = finish_reason_map.get(finish_reason, "UNKNOWN")
else:
# If it's not an int, try to get the name attribute
finish_reason_name = getattr(finish_reason, 'name', str(finish_reason))
# Build generation info
generation_info = {
"finish_reason": finish_reason_name,
"safety_ratings": []
}
# Handle safety ratings
if hasattr(candidate, 'safety_ratings'):
for rating in candidate.safety_ratings:
try:
rating_dict = {
"category": rating.category.name if hasattr(rating.category, 'name') else str(rating.category),
"probability": rating.probability.name if hasattr(rating.probability, 'name') else str(rating.probability)
}
generation_info["safety_ratings"].append(rating_dict)
except:
# Skip problematic ratings
pass
# Create the generation
if stream:
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk
generation = ChatGenerationChunk(
text=content,
generation_info=generation_info,
message=AIMessageChunk(content=content)
)
else:
generation = ChatGeneration(
text=content,
generation_info=generation_info,
message=AIMessage(content=content)
)
generations.append(generation)
return ChatResult(generations=generations, llm_output=llm_output)
else:
# Re-raise if it's a different AttributeError
raise
# Mark as patched
patched_response_to_result._is_patched = True
# Apply the patch
chat_models._response_to_result = patched_response_to_result
except ImportError:
# langchain_google_genai not installed
pass
# Apply the patch once at import time; harmless no-op if
# langchain_google_genai is unavailable or already patched.
_apply_gemini_patch()
def enumerate_models():
    """Return the names of all supported models as a list."""
    return [model_name for model_name in MODELS]
def get_model(
model,
temperature: float = 0.7
):
model_arg = dict()
match MODELS[model]:
case "groq":
provider = ChatGroq
api_key = os.getenv("GROQ_API_KEY")
case "openai":
provider = ChatOpenAI
api_key = os.getenv("OPENAI_API_KEY")
case "google":
# Use the standard ChatGoogleGenerativeAI with our patch applied
provider = ChatGoogleGenerativeAI
api_key = os.getenv("GOOGLE_API_KEY")
model_arg["model"] = model
case "anthropic":
provider = ChatAnthropic
api_key = os.getenv("ANTHROPIC_API_KEY")
case "xai":
provider = ChatXAI
api_key = os.getenv("XAI_API_KEY")
case _:
raise ValueError(f"Unsupported model provider: {MODELS[model]}")
if model == "o3-2025-04-16":
temperature = 1
if len(model_arg) == 0:
model_arg["model_name"] = model
return provider(
api_key=api_key,
temperature=temperature,
**model_arg
)