81 changes: 65 additions & 16 deletions lagent/llms/openai.py
@@ -2,6 +2,7 @@
import copy
import json
import os
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
@@ -60,6 +61,7 @@ def __init__(self,
],
api_base: str = OPENAI_API_BASE,
proxies: Optional[Dict] = None,
extra_header: Optional[Dict] = None,
**gen_params):
if 'top_k' in gen_params:
warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
@@ -72,6 +74,12 @@
**gen_params)
self.gen_params.pop('top_k')
self.logger = getLogger(__name__)
self.header = {
'content-type': 'application/json',
}
if extra_header:
            self.header.update(extra_header)

if isinstance(key, str):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
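Not part of the diff: a minimal sketch of the header merge that the new `extra_header` argument performs in `__init__`. The header name and token value are hypothetical.

    header = {'content-type': 'application/json'}
    extra_header = {'X-Api-Key': 'gateway-token'}  # hypothetical gateway auth
    header.update(extra_header)
    # header == {'content-type': 'application/json', 'X-Api-Key': 'gateway-token'}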
@@ -219,7 +227,17 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
data=json.dumps(data),
proxies=self.proxies)
response = raw_response.json()
                # Handle non-OpenAI response schemas returned for Qwen
                # (DashScope-style) and Claude (served via a proxy).
                if self.model_type.lower().startswith('qwen'):
                    return response['output']['choices'][0]['message'][
                        'content'].strip()
                elif self.model_type.lower().startswith('claude'):
                    if response['msgCode'] == '-1':
                        self.logger.error(str(response))
                        raise RuntimeError(response['msg'])
                    elif response['data']['stop_reason'] == 'max_tokens':
                        raise RuntimeError('max_tokens reached')
                    return response['data']['content'][0]['text'].strip()
                return response['choices'][0]['message'][
                    'content'].strip()
except requests.ConnectionError:
self.logger.error('Got connection error, retrying...')
continue
@@ -239,8 +257,10 @@ def _chat(self, messages: List[dict], **gen_params) -> str:

self.logger.error('Find error message in response: ' +
str(response['error']))
                else:
                    raise
max_num_retries += 1

raise RuntimeError('Calling OpenAI failed after retrying for '
Expand Down Expand Up @@ -381,14 +401,12 @@ def generate_request_data(self,
gen_params = gen_params.copy()

        # Fall back to 4096 tokens when `max_new_tokens` is unspecified
        max_tokens = gen_params.pop('max_new_tokens', 4096)
if max_tokens <= 0:
return '', ''

# Initialize the header
        header = copy.deepcopy(self.header)

# Common parameters processing
gen_params['max_tokens'] = max_tokens
@@ -442,6 +460,14 @@
**gen_params
}
}
elif model_type.lower().startswith('claude'):
gen_params.pop('skip_special_tokens', None)
gen_params.pop('session_id', None)
data = {
'model': model_type,
'messages': messages,
**gen_params
}
else:
raise NotImplementedError(
f'Model type {model_type} is not supported')
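Not part of the diff: for reference, a sketch of the request body the new `claude` branch above assembles after `skip_special_tokens` and `session_id` are stripped. The model name, message, and temperature are hypothetical; the 4096 default comes from the `max_new_tokens` fallback.

    data = {
        'model': 'claude-3-5-sonnet',  # taken from `model_type`
        'messages': [{'role': 'user', 'content': 'Hello'}],
        'max_tokens': 4096,            # from `max_new_tokens`
        'temperature': 1.0,
    }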
@@ -502,6 +528,7 @@ def __init__(self,
],
api_base: str = OPENAI_API_BASE,
proxies: Optional[Dict] = None,
extra_header: Optional[Dict] = None,
**gen_params):
if 'top_k' in gen_params:
warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
@@ -514,6 +541,11 @@
**gen_params)
self.gen_params.pop('top_k')
self.logger = getLogger(__name__)
self.header = {
'content-type': 'application/json',
}
if extra_header:
self.header.update(extra_header)

if isinstance(key, str):
self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -640,14 +672,13 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
break

key = self.keys[self.key_ctr]
            # Prefer auth supplied via `extra_header` (e.g. for gateways);
            # otherwise fall back to the rotating API keys.
            if 'Authorization' not in self.header:
                header['Authorization'] = f'Bearer {key}'

if self.orgs:
self.org_ctr += 1
if self.org_ctr == len(self.orgs):
self.org_ctr = 0
header['OpenAI-Organization'] = self.orgs[self.org_ctr]

response = dict()
try:
async with aiohttp.ClientSession() as session:
@@ -658,6 +689,16 @@
proxy=self.proxies.get(
'https', self.proxies.get('http'))) as resp:
response = await resp.json()
                        # Handle non-OpenAI response schemas for Qwen/Claude.
                        if self.model_type.lower().startswith('qwen'):
                            return response['output']['choices'][0][
                                'message']['content'].strip()
                        elif self.model_type.lower().startswith('claude'):
                            if response['msgCode'] == '-1':
                                self.logger.error(str(response))
                                raise RuntimeError(response['msg'])
                            elif response['data'][
                                    'stop_reason'] == 'max_tokens':
                                raise RuntimeError('max_tokens reached')
                            return response['data']['content'][0][
                                'text'].strip()
return response['choices'][0]['message'][
'content'].strip()
except aiohttp.ClientConnectionError:
Expand All @@ -671,6 +712,7 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
(await resp.text(errors='replace')))
continue
except KeyError:
                self.logger.error(str(response))
if 'error' in response:
if response['error']['code'] == 'rate_limit_exceeded':
time.sleep(1)
@@ -682,8 +724,10 @@

self.logger.error('Find error message in response: ' +
str(response['error']))
                else:
                    raise
max_num_retries += 1

raise RuntimeError('Calling OpenAI failed after retrying for '
@@ -827,15 +871,12 @@ def generate_request_data(self,
gen_params = gen_params.copy()

        # Fall back to 4096 tokens when `max_new_tokens` is unspecified
        max_tokens = gen_params.pop('max_new_tokens', 4096)
if max_tokens <= 0:
return '', ''

# Initialize the header
        header = copy.deepcopy(self.header)

# Common parameters processing
gen_params['max_tokens'] = max_tokens
if 'stop_words' in gen_params:
@@ -888,6 +929,14 @@ def generate_request_data(self,
**gen_params
}
}
elif model_type.lower().startswith('claude'):
gen_params.pop('skip_special_tokens', None)
gen_params.pop('session_id', None)
data = {
'model': model_type,
'messages': messages,
**gen_params
}
else:
raise NotImplementedError(
f'Model type {model_type} is not supported')
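Not part of the diff: a usage sketch of the new `extra_header` parameter, assuming a private gateway that authenticates with a custom header instead of the standard `Authorization: Bearer` scheme. The gateway URL, header name, and model name are assumptions for illustration, and the call goes through lagent's usual `chat` entry point.

    from lagent.llms.openai import GPTAPI

    llm = GPTAPI(
        model_type='claude-3-5-sonnet',  # hypothetical model name
        key='ENV',  # resolved from the OPENAI_API_KEY environment variable
        api_base='https://gateway.example.com/v1/chat/completions',
        extra_header={'X-Api-Key': 'gateway-token'},  # merged into self.header
    )
    print(llm.chat([{'role': 'user', 'content': 'Hello'}]))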