diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index 7d6102a2..91279891 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -2,6 +2,7 @@
 import json
 import os
 import time
+import copy
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
@@ -60,6 +61,7 @@ def __init__(self,
                  ],
                  api_base: str = OPENAI_API_BASE,
                  proxies: Optional[Dict] = None,
+                 extra_header: Optional[Dict] = None,
                  **gen_params):
         if 'top_k' in gen_params:
             warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
@@ -72,6 +74,12 @@ def __init__(self,
             **gen_params)
         self.gen_params.pop('top_k')
         self.logger = getLogger(__name__)
+        self.header = {
+            'content-type': 'application/json',
+        }
+        if extra_header:
+            self.header.update(extra_header)
+
 
         if isinstance(key, str):
             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -219,7 +227,17 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
                     data=json.dumps(data),
                     proxies=self.proxies)
                 response = raw_response.json()
-                return response['choices'][0]['message']['content'].strip()
+                if self.model_type.lower().startswith('qwen'):
+                    return response['output']['choices'][0]['message']['content'].strip()
+                elif self.model_type.lower().startswith('claude'):
+                    print(response)
+                    if response['msgCode'] == '-1':
+                        raise RuntimeError(response['msg'])
+                    elif response['data']['stop_reason'] == 'max_tokens':
+                        raise RuntimeError('max_tokens reached')
+                    return response['data']['content'][0]['text'].strip()
+                return response['choices'][0]['message'][
+                    'content'].strip()
             except requests.ConnectionError:
                 self.logger.error('Got connection error, retrying...')
                 continue
@@ -239,8 +257,10 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
                     self.logger.error('Find error message in response: ' +
                                       str(response['error']))
-            except Exception as error:
-                self.logger.error(str(error))
+                else:
+                    raise KeyError
+            # except Exception as error:
+            #     self.logger.error(str(error))
 
             max_num_retries += 1
 
         raise RuntimeError('Calling OpenAI failed after retrying for '
@@ -381,14 +401,12 @@ def generate_request_data(self,
         gen_params = gen_params.copy()
 
         # Hold out 100 tokens due to potential errors in token calculation
-        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+        max_tokens = gen_params.pop('max_new_tokens', 4096)
         if max_tokens <= 0:
             return '', ''
 
         # Initialize the header
-        header = {
-            'content-type': 'application/json',
-        }
+        header = copy.deepcopy(self.header)
 
         # Common parameters processing
         gen_params['max_tokens'] = max_tokens
@@ -442,6 +460,14 @@ def generate_request_data(self,
                     **gen_params
                 }
             }
+        elif model_type.lower().startswith('claude'):
+            gen_params.pop('skip_special_tokens', None)
+            gen_params.pop('session_id', None)
+            data = {
+                'model': model_type,
+                'messages': messages,
+                **gen_params
+            }
         else:
             raise NotImplementedError(
                 f'Model type {model_type} is not supported')
@@ -502,6 +528,7 @@ def __init__(self,
                  ],
                  api_base: str = OPENAI_API_BASE,
                  proxies: Optional[Dict] = None,
+                 extra_header: Optional[Dict] = None,
                  **gen_params):
         if 'top_k' in gen_params:
             warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
@@ -514,6 +541,11 @@ def __init__(self,
             **gen_params)
         self.gen_params.pop('top_k')
         self.logger = getLogger(__name__)
+        self.header = {
+            'content-type': 'application/json',
+        }
+        if extra_header:
+            self.header.update(extra_header)
 
         if isinstance(key, str):
             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -640,14 +672,13 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                     break
 
             key = self.keys[self.key_ctr]
-            header['Authorization'] = f'Bearer {key}'
+            # header['Authorization'] = f'Bearer {key}'
 
             if self.orgs:
                 self.org_ctr += 1
                 if self.org_ctr == len(self.orgs):
                     self.org_ctr = 0
                 header['OpenAI-Organization'] = self.orgs[self.org_ctr]
-
             response = dict()
             try:
                 async with aiohttp.ClientSession() as session:
@@ -658,6 +689,16 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                             proxy=self.proxies.get(
                                 'https', self.proxies.get('http'))) as resp:
                         response = await resp.json()
+                        if self.model_type.lower().startswith('qwen'):
+                            return response['output']['choices'][0]['message']['content'].strip()
+                        elif self.model_type.lower().startswith('claude'):
+                            if response['msgCode'] == '-1':
+                                print(response)
+                                raise RuntimeError(response['msg'])
+                            elif response['data']['stop_reason'] == 'max_tokens':
+                                print(response)
+                                raise RuntimeError('max_tokens reached')
+                            return response['data']['content'][0]['text'].strip()
                         return response['choices'][0]['message'][
                             'content'].strip()
             except aiohttp.ClientConnectionError:
@@ -671,6 +712,7 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                                   (await resp.text(errors='replace')))
                 continue
             except KeyError:
+                print(response)
                 if 'error' in response:
                     if response['error']['code'] == 'rate_limit_exceeded':
                         time.sleep(1)
@@ -682,8 +724,10 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                     self.logger.error('Find error message in response: ' +
                                       str(response['error']))
-            except Exception as error:
-                self.logger.error(str(error))
+                else:
+                    raise KeyError
+            # except Exception as error:
+            #     self.logger.error(str(error))
 
             max_num_retries += 1
 
         raise RuntimeError('Calling OpenAI failed after retrying for '
@@ -827,15 +871,12 @@ def generate_request_data(self,
         gen_params = gen_params.copy()
 
         # Hold out 100 tokens due to potential errors in token calculation
-        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+        max_tokens = gen_params.pop('max_new_tokens', 4096)
         if max_tokens <= 0:
             return '', ''
 
         # Initialize the header
-        header = {
-            'content-type': 'application/json',
-        }
-
+        header = copy.deepcopy(self.header)
         # Common parameters processing
         gen_params['max_tokens'] = max_tokens
         if 'stop_words' in gen_params:
@@ -888,6 +929,14 @@ def generate_request_data(self,
                     **gen_params
                 }
             }
+        elif model_type.lower().startswith('claude'):
+            gen_params.pop('skip_special_tokens', None)
+            gen_params.pop('session_id', None)
+            data = {
+                'model': model_type,
+                'messages': messages,
+                **gen_params
+            }
         else:
             raise NotImplementedError(
                 f'Model type {model_type} is not supported')
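
Usage sketch (not part of the patch): `extra_header` is merged into the default `{'content-type': 'application/json'}` header at construction time, and `generate_request_data` now deep-copies `self.header` per request, so gateway-specific headers ride along on every call. Note the asymmetry: the synchronous client's `_chat` still sets `Authorization: Bearer <key>` from `key` after the copy, overwriting any `Authorization` supplied via `extra_header`, while the asynchronous client comments that line out, making `extra_header` the only auth path there. The gateway URL, token, and model name below are hypothetical placeholders, and the `msgCode`/`data` response shape parsed by the claude branch is specific to whatever gateway this patch targets, not Anthropic's native API.

    from lagent.llms.openai import GPTAPI

    # Hypothetical Claude-compatible gateway; URL, key, and token are placeholders.
    llm = GPTAPI(
        model_type='claude-3-sonnet',  # routed to the new 'claude' branches
        key='<GATEWAY_KEY>',           # sync client still sets Authorization from this
        api_base='https://gateway.example.com/v1/chat/completions',
        extra_header={'Authorization': 'Bearer <GATEWAY_TOKEN>'},  # merged into self.header
        max_new_tokens=1024,           # popped as 'max_tokens' in generate_request_data
    )
    # Existing lagent chat() API; on success the claude branch returns
    # response['data']['content'][0]['text'].
    print(llm.chat([dict(role='user', content='Hello')]))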