# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random
import subprocess
import time
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import openai
import requests
from openai import AzureOpenAI, OpenAI

KEYS_DIR = 'keys'
if not os.path.isdir(KEYS_DIR):
    os.makedirs(KEYS_DIR, exist_ok=True)

def convert_openai_tools_to_claude(openai_tools: list) -> list:
    """Convert OpenAI-style function tool specs into Anthropic/Claude tool specs."""
    claude_tools = []
    for tool in openai_tools:
        if tool.get("type") != "function":
            raise ValueError(f"Unsupported tool type: {tool.get('type')}")
        fn = tool["function"]
        claude_tools.append({
            "name": fn["name"],
            "description": fn.get("description", ""),
            "input_schema": fn.get("parameters", {"type": "object", "properties": {}})
        })
    return claude_tools

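# A minimal usage sketch for convert_openai_tools_to_claude. The "get_weather"
# tool spec below is hypothetical and only illustrates the field mapping
# (OpenAI "parameters" -> Claude "input_schema"); it is not part of this
# project's actual tool definitions.
def _example_convert_openai_tools_to_claude():
    openai_tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    claude_tools = convert_openai_tools_to_claude(openai_tools)
    # Claude expects the JSON schema under "input_schema" instead of "parameters".
    assert claude_tools[0]["input_schema"]["properties"]["city"]["type"] == "string"
    return claude_tools
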
def normalize_messages_for_tools(
    messages: List[Dict[str, Any]],
    tools: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[List[Dict[str, Any]], List[str]]:
    """
    Detects and corrects common Chat Completions tool-message issues:
    1) In assistant messages, each entry in `tool_calls` must have:
        {
          "id": "...",
          "type": "function",
          "function": {"name": "<fn_name>", "arguments": "<json string>"}
        }
       - Moves top-level `name` / `arguments` into `function`.
       - Ensures `type == "function"`.
       - JSON-serializes non-string `arguments`.
    2) In tool messages:
       - Ensures `content` is a string; JSON-serializes if dict/list.
       - Ensures `tool_call_id` exists. If missing, tries to pair with the
         most recent unmatched assistant tool_call ID (by order).
    3) Removes illegal extra fields at the `tool_calls` top level.
    Returns:
        (fixed_messages, issues)
        - fixed_messages: deep-copied, corrected messages list
        - issues: human-readable list of detected/corrected problems
    """
    fixed = deepcopy(messages)
    issues = []
    # Build a set of valid function names from `tools` (optional validation)
    valid_fn_names = set()
    if tools:
        for t in tools:
            try:
                if t.get("type") == "function":
                    fn = t.get("function", {})
                    name = fn.get("name")
                    if isinstance(name, str):
                        valid_fn_names.add(name)
            except Exception:
                pass
    # Track assistant tool_calls -> to match subsequent tool results
    pending_tool_call_ids = []
    # First pass: fix assistant tool_calls and record pending IDs
    for i, msg in enumerate(fixed):
        role = msg.get("role")
        if role == "assistant" and isinstance(msg.get("tool_calls"), list):
            for j, tc in enumerate(msg["tool_calls"]):
                # Ensure container objects exist
                if not isinstance(tc, dict):
                    issues.append(f"[assistant#{i}] tool_calls[{j}] is not an object; replaced with empty object.")
                    msg["tool_calls"][j] = tc = {}
                # Move name/arguments into function
                fn_obj = tc.get("function") or {}
                moved = False
                if "name" in tc:
                    fn_obj["name"] = tc.pop("name")
                    moved = True
                    issues.append(f"[assistant#{i}] tool_calls[{j}]: moved top-level 'name' into 'function.name'.")
                if "arguments" in tc:
                    fn_obj["arguments"] = tc.pop("arguments")
                    moved = True
                    issues.append(f"[assistant#{i}] tool_calls[{j}]: moved top-level 'arguments' into 'function.arguments'.")
                # Ensure function object present
                if "function" not in tc:
                    tc["function"] = fn_obj if fn_obj else {}
                elif moved:
                    tc["function"].update(fn_obj)
                # Ensure type is "function"
                if tc.get("type") != "function":
                    tc["type"] = "function"
                    issues.append(f"[assistant#{i}] tool_calls[{j}]: set 'type' to 'function'.")
                # Ensure arguments is a string
                if "arguments" in tc["function"]:
                    args_val = tc["function"]["arguments"]
                    if not isinstance(args_val, str):
                        try:
                            tc["function"]["arguments"] = json.dumps(args_val, ensure_ascii=False)
                            issues.append(f"[assistant#{i}] tool_calls[{j}]: JSON-serialized non-string 'function.arguments'.")
                        except Exception:
                            tc["function"]["arguments"] = "{}"
                            issues.append(f"[assistant#{i}] tool_calls[{j}]: failed to serialize arguments; defaulted to '{{}}'.")
                else:
                    # Provide default empty JSON object
                    tc["function"]["arguments"] = "{}"
                    issues.append(f"[assistant#{i}] tool_calls[{j}]: added default empty 'function.arguments'.")
                # Validate function name if possible
                fn_name = tc.get("function", {}).get("name")
                if isinstance(fn_name, str):
                    if valid_fn_names and fn_name not in valid_fn_names:
                        issues.append(f"[assistant#{i}] tool_calls[{j}]: unknown function '{fn_name}' (not in tools).")
                else:
                    issues.append(f"[assistant#{i}] tool_calls[{j}]: missing 'function.name'.")
                # Track pending tool_call_id for pairing
                tc_id = tc.get("id")
                if isinstance(tc_id, str):
                    pending_tool_call_ids.append(tc_id)
                else:
                    # If missing id, synthesize a stable one
                    tc_id = f"call_{i}_{j}"
                    tc["id"] = tc_id
                    pending_tool_call_ids.append(tc_id)
                    issues.append(f"[assistant#{i}] tool_calls[{j}]: synthesized missing 'id' -> '{tc_id}'.")
                # Remove illegal top-level keys except allowed
                allowed = {"id", "type", "function"}
                extraneous = [k for k in list(tc.keys()) if k not in allowed]
                for k in extraneous:
                    tc.pop(k, None)
                    issues.append(f"[assistant#{i}] tool_calls[{j}]: removed unsupported top-level field '{k}'.")
    # Second pass: fix tool messages (pair to pending assistant calls)
    # We'll consume from the front of pending_tool_call_ids in order.
    for i, msg in enumerate(fixed):
        if msg.get("role") == "tool":
            # tool_call_id
            if not msg.get("tool_call_id"):
                if pending_tool_call_ids:
                    inferred = pending_tool_call_ids.pop(0)
                    msg["tool_call_id"] = inferred
                    issues.append(f"[tool#{i}]: added missing 'tool_call_id' -> '{inferred}'.")
                else:
                    issues.append(f"[tool#{i}]: missing 'tool_call_id' and none could be inferred.")
            # content must be string
            content = msg.get("content")
            if not isinstance(content, str):
                try:
                    msg["content"] = json.dumps(content, ensure_ascii=False)
                    issues.append(f"[tool#{i}]: JSON-serialized non-string 'content'.")
                except Exception:
                    msg["content"] = ""
                    issues.append(f"[tool#{i}]: failed to serialize content; set to empty string.")
            # Remove fields illegal for tool role (defensive)
            for bad in ("name", "type", "function"):
                if bad in msg:
                    msg.pop(bad, None)
                    issues.append(f"[tool#{i}]: removed illegal field '{bad}'.")
        # If someone mistakenly returned a tool result as role='assistant' with tool_call_id,
        # quietly convert it to role='tool' (optional but handy).
        if msg.get("role") == "assistant" and "tool_call_id" in msg:
            msg["role"] = "tool"
            issues.append(f"[assistant#{i}]: message had 'tool_call_id'; converted role to 'tool'.")
    return fixed, issues

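# A minimal usage sketch for normalize_messages_for_tools, assuming one
# malformed assistant tool call followed by a tool result with a non-string
# payload. The "search_docs" name and its arguments are hypothetical.
def _example_normalize_messages_for_tools():
    messages = [
        {"role": "assistant", "content": "", "tool_calls": [
            # Malformed: name/arguments at the top level and no id.
            {"name": "search_docs", "arguments": {"query": "llm"}},
        ]},
        # Malformed: dict content and missing tool_call_id.
        {"role": "tool", "content": {"hits": 3}},
    ]
    fixed, issues = normalize_messages_for_tools(messages)
    # `fixed` now carries function.name / function.arguments (as a JSON string),
    # a synthesized tool_call id, and a matching string tool_call_id/content;
    # `issues` lists every correction that was applied.
    return fixed, issues
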
def convert_openai_messages_to_claude(openai_messages):
    """Flatten OpenAI-style tool messages into plain Claude user/assistant turns (mutates messages in place)."""
    claude_messages = []
    for m in openai_messages:
        if "tool_calls" in m:
            m['content'] += '\n\n' + str(m["tool_calls"])
            m.pop("tool_calls")
            claude_messages.append(m)
        elif m['role'] == 'tool':
            claude_messages.append({
                "role": 'user',
                "content": "Tool call result: " + m['content']
            })
        else:
            claude_messages.append(m)
    return claude_messages

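# A minimal sketch of convert_openai_messages_to_claude: an OpenAI-style tool
# result becomes a plain user turn that Claude can consume. The message values
# are illustrative only.
def _example_convert_openai_messages_to_claude():
    openai_messages = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "tool", "tool_call_id": "call_0_0", "content": "4"},
    ]
    claude_messages = convert_openai_messages_to_claude(openai_messages)
    # The tool message is rewritten as:
    # {"role": "user", "content": "Tool call result: 4"}
    return claude_messages
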
def get_openai_token(p_token_url, p_client_id, p_client_secret, p_scope, **kwargs):
    """Fetch (and briefly cache) an OAuth access token for the OpenAI/Azure endpoints."""
    try:
        with open(os.path.join(KEYS_DIR, 'openai_key.json')) as f:
            key = json.load(f)
            if time.time() < key['expire_at']:
                return key["access_token"]
    except Exception:
        pass
    response = requests.post(
        p_token_url,
        data={"grant_type": "client_credentials", "client_id": p_client_id,
              "client_secret": p_client_secret, "scope": p_scope}
    )
    response.raise_for_status()
    token = response.json()
    with open(os.path.join(KEYS_DIR, 'openai_key.json'), 'w') as f:
        json.dump({
            "access_token": token["access_token"],
            'expire_at': time.time() + 900
        }, f, indent=2)
    os.chmod(os.path.join(KEYS_DIR, 'openai_key.json'), 0o777)
    return token["access_token"]

def get_claude_token():
    """Fetch (and briefly cache) an OAuth access token for the Claude (AWS Bedrock) endpoint."""
    try:
        with open(os.path.join(KEYS_DIR, 'claude_key.json')) as f:
            key = json.load(f)
            if time.time() < key['expire_at']:
                return key["access_token"]
    except Exception:
        pass
    client_id = os.getenv("CLIENT_ID")
    client_secret = os.getenv("CLIENT_SECRET")
    command = f"""curl -s --location 'https://5kbfxgaqc3xgz8nhid1x1r8cfestoypn-trofuum-oc.ssa.nvidia.com/token' --header 'Content-Type: application/x-www-form-urlencoded' --header "Authorization: Basic $(echo -n {client_id}:{client_secret} | base64 -w0)" --data-urlencode 'grant_type=client_credentials' --data-urlencode 'scope=awsanthropic-readwrite azureopenai-readwrite' | jq -r '.access_token'"""
    result = subprocess.check_output(command, shell=True, text=True).strip()
    with open(os.path.join(KEYS_DIR, 'claude_key.json'), 'w') as f:
        json.dump({
            "access_token": result,
            'expire_at': time.time() + 900
        }, f, indent=2)
    os.chmod(os.path.join(KEYS_DIR, 'claude_key.json'), 0o777)
    return result

def get_azure_openai_client():
    """Create an AzureOpenAI client authenticated with a freshly fetched OAuth token."""
    client_id = os.getenv("CLIENT_ID")
    client_secret = os.getenv("CLIENT_SECRET")
    token_url = "https://prod.api.nvidia.com/oauth/api/v1/ssa/default/token"
    scope = "azureopenai-readwrite"
    token = get_openai_token(token_url, client_id, client_secret, scope)
    openai.api_type = "azure"
    openai.api_base = "https://prod.api.nvidia.com/llm/v1/azure/"
    openai.api_version = "2025-04-01-preview"
    openai.api_key = token
    client = AzureOpenAI(
        api_key=token,
        api_version="2025-04-01-preview",
        azure_endpoint="https://prod.api.nvidia.com/llm/v1/azure/"
    )
    print("AzureOpenAI instance created successfully.")
    return client

def get_openai_client():
    """Create an OpenAI client (Responses API) authenticated with a freshly fetched OAuth token."""
    client_id = os.getenv("CLIENT_ID")
    client_secret = os.getenv("CLIENT_SECRET")
    token_url = "https://prod.api.nvidia.com/oauth/api/v1/ssa/default/token"
    scope = "azureopenai-readwrite"
    print("Getting OAuth token...")
    # Avoid echoing the client secret in logs.
    print(f"client_id: {client_id}, token_url: {token_url}, scope: {scope}")
    token = get_openai_token(token_url, client_id, client_secret, scope)
    client = OpenAI(
        api_key=token,
        base_url="https://prod.api.nvidia.com/llm/v1/azure/v1",
        default_query={"api-version": "preview"},
    )
    print("OpenAI instance created successfully.")
    return client

def get_llm_response(model, messages, temperature=1.0, return_raw_response=False, tools=None, show_messages=False, model_type=None, max_length=1024, model_config=None, model_config_idx=0, model_config_path=None, payload=None, openai_client_type='azure_openai', **kwargs):
    """Route a chat request to the OpenAI/Azure, NVIDIA-hosted, vLLM, or Claude backend and retry until a non-empty answer is returned."""
    if isinstance(messages, str):
        messages = [{'role': 'user', 'content': messages}]
    if model in ['o3', 'o3-mini', 'gpt-4o', 'o3-high', 'gpt-5', 'gpt-5-mini', 'gpt-4.1', 'gpt-4o-mini']:
        # Only widen the token budget when the caller left max_length at its default.
        if max_length == 1024:
            max_length = 40000
            if model in ['gpt-4.1', 'gpt-4o-mini']:
                max_length = 8000
        answer = ''
        while answer == '':
            try:
                if openai_client_type == 'azure_openai':
                    openai_client = get_azure_openai_client()
                    chat_completion = openai_client.chat.completions.create(
                        model=model,
                        messages=messages,
                        temperature=temperature,
                        tools=tools,
                        max_completion_tokens=max_length,
                    )
                    if return_raw_response:
                        answer = chat_completion
                    else:
                        answer = chat_completion.choices[0].message.content
                elif openai_client_type == 'openai_response':
                    openai_client = get_openai_client()
                    response = openai_client.responses.create(
                        model=model,
                        input=messages,
                        temperature=temperature,
                        tools=tools,
                        max_output_tokens=max_length,
                        reasoning={"effort": "high"},
                    )
                    if return_raw_response:
                        answer = response
                    else:
                        answer = response.output_text
                else:
                    raise ValueError(f"Invalid openai_client_type: {openai_client_type}")
            except Exception as error:
                print(f"[DEBUG] messages: {messages}")
                print('[ERROR] Get LLM Response from OpenAI', error)
                time.sleep(60)
        return answer
    elif model_type == 'nv/dev':
        answer = ''
        updated_messages = []
        for m in messages:
            if 'tool_calls' in m:
                m['content'] += str(m['tool_calls'])
                m.pop('tool_calls')
            updated_messages.append(m)
        while answer == '':
            try:
                oss_client = OpenAI(
                    base_url="https://integrate.api.nvidia.com/v1",
                    api_key=os.getenv("OSS_KEY")
                )
                if tools:
                    chat_completion = oss_client.chat.completions.create(
                        model=model,
                        messages=updated_messages,
                        temperature=temperature,
                        top_p=0.7,
                        max_tokens=max_length,
                        tools=tools
                    )
                else:
                    chat_completion = oss_client.chat.completions.create(
                        model=model,
                        messages=updated_messages,
                        temperature=temperature,
                        top_p=0.7,
                        max_tokens=max_length,
                    )
                if return_raw_response:
                    answer = chat_completion
                else:
                    answer = chat_completion.choices[0].message.content
            except Exception as error:
                time.sleep(60)
        return answer
    elif 'qwen' in model.lower() or model_type == 'vllm':
        answer = ''
        while answer == '':
            # Pick a random vLLM server from the current config on every attempt.
            config_idx = random.choice(range(len(model_config)))
            ip_addr = model_config[config_idx]["ip_addr"]
            port = model_config[config_idx]["port"]
            try:
                vllm_client = OpenAI(
                    api_key="EMPTY",
                    base_url=f"http://{ip_addr}:{port}/v1",
                )
                chat_completion = vllm_client.chat.completions.create(
                    model=model,
                    messages=messages,
                    max_tokens=max_length,
                    temperature=temperature,
                    tools=tools
                )
                if return_raw_response:
                    answer = chat_completion
                else:
                    answer = chat_completion.choices[0].message.content
            except Exception as error:
                print('Error', error)
                if os.path.isfile(str(model_config_path)):
                    # The serving config may have changed; reload it before retrying.
                    with open(model_config_path) as f:
                        update_model_configs = json.load(f)
                    model_config = update_model_configs[model]
                time.sleep(60)
        return answer
    elif 'claude' in model.lower():
        access_token = get_claude_token()
        if 'opus' in model:
            endpoint = "https://prod.api.nvidia.com/llm/v1/aws/model/us.anthropic.claude-opus-4-20250514-v1:0/invoke"
        elif 'sonnet' in model:
            endpoint = "https://prod.api.nvidia.com/llm/v1/aws/model/us.anthropic.claude-sonnet-4-20250514-v1:0/invoke"
        else:
            raise ValueError(f"Unsupported Claude model: {model}")
        if not payload:
            updated_messages = []
            system_message = 'You are a good assistant'
            for m in messages:
                if m['role'] == 'system':
                    system_message = m['content']
                else:
                    updated_messages.append(m)
            if not tools:
                payload = {
                    "anthropic_version": "bedrock-2023-05-31",
                    "messages": updated_messages,
                    "temperature": temperature,
                    "top_p": 1.0,
                    "max_tokens": 4096,
                    'system': system_message,
                }
            else:
                claude_tools = convert_openai_tools_to_claude(tools)
                payload = {
                    "anthropic_version": "bedrock-2023-05-31",
                    "messages": updated_messages,
                    "temperature": temperature,
                    "top_p": 1.0,
                    "max_tokens": 4096,
                    'system': system_message,
                    'tools': claude_tools
                }
            payload['messages'] = convert_openai_messages_to_claude(payload['messages'])
        headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        answer = ''
        while answer == '':
            try:
                response = requests.post(endpoint, headers=headers, json=payload)
                response.raise_for_status()
                if return_raw_response:
                    answer = response.json()
                else:
                    answer = response.json()['content'][0]['text']
            except Exception as error:
                time.sleep(60)
        return answer
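
# A minimal, hedged usage sketch. The model names, prompt, and environment
# variables referenced below are illustrative; which backends are actually
# reachable depends on the credentials (CLIENT_ID/CLIENT_SECRET or OSS_KEY)
# available at runtime, and the call retries until a non-empty answer arrives.
if __name__ == "__main__":
    demo_messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Say hello in one short sentence."},
    ]
    # Azure OpenAI chat-completions path (the default openai_client_type).
    print(get_llm_response("gpt-4o", demo_messages, temperature=0.2))
    # Claude path: routed by model name and served via the Bedrock invoke endpoint.
    # print(get_llm_response("claude-sonnet", demo_messages))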