From 286a6d83cc1551572a3a7c11a6d3aea658741832 Mon Sep 17 00:00:00 2001
From: zhangwenwei
Date: Mon, 14 Oct 2024 08:59:02 +0000
Subject: [PATCH 1/5] [feat]: allow setting an extra header

---
 lagent/llms/openai.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index 7d6102a2..a2ee4212 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -2,6 +2,7 @@
 import json
 import os
 import time
+import copy
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
@@ -60,6 +61,7 @@ def __init__(self,
                  ],
                  api_base: str = OPENAI_API_BASE,
                  proxies: Optional[Dict] = None,
+                 extra_header: Optional[Dict] = None,
                  **gen_params):
         if 'top_k' in gen_params:
             warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
@@ -72,6 +74,12 @@ def __init__(self,
             **gen_params)
         self.gen_params.pop('top_k')
         self.logger = getLogger(__name__)
+        self.default_header = {
+            'content-type': 'application/json',
+        }
+        if extra_header:
+            self.default_header.update(extra_header)
+
         if isinstance(key, str):
             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -386,9 +394,7 @@ def generate_request_data(self,
             return '', ''
 
         # Initialize the header
-        header = {
-            'content-type': 'application/json',
-        }
+        header = copy.deepcopy(self.default_header)
 
         # Common parameters processing
         gen_params['max_tokens'] = max_tokens

From f2ec65f4d7c1a024a23bfa11c80960610725fd98 Mon Sep 17 00:00:00 2001
From: zhangwenwei
Date: Wed, 16 Oct 2024 03:43:54 +0000
Subject: [PATCH 2/5] [feat]: support setting headers in special cases and
 handle Claude

---
 lagent/llms/openai.py | 68 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 54 insertions(+), 14 deletions(-)

diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index a2ee4212..93391b4d 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -74,11 +74,11 @@ def __init__(self,
             **gen_params)
         self.gen_params.pop('top_k')
         self.logger = getLogger(__name__)
-        self.default_header = {
+        self.header = {
             'content-type': 'application/json',
         }
         if extra_header:
-            self.default_header.update(extra_header)
+            self.header.update(extra_header)
 
         if isinstance(key, str):
             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -227,7 +227,16 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
                     data=json.dumps(data),
                     proxies=self.proxies)
                 response = raw_response.json()
-                return response['choices'][0]['message']['content'].strip()
+                if self.model_type.lower().startswith('qwen'):
+                    return response['output']['choices'][0]['message']['content'].strip()
+                elif self.model_type.lower().startswith('claude'):
+                    if response['msgCode'] == -1:
+                        raise RuntimeError(response['msg'])
+                    elif response['data']['stop_reason'] == 'max_tokens':
+                        raise RuntimeError('max_tokens reached')
+                    return response['data']['content'][0]['text'].strip()
+                return response['choices'][0]['message'][
+                    'content'].strip()
             except requests.ConnectionError:
                 self.logger.error('Got connection error, retrying...')
                 continue
@@ -247,8 +256,10 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
                     self.logger.error('Find error message in response: ' +
                                       str(response['error']))
+                else:
+                    raise KeyError
 
-            except Exception as error:
-                self.logger.error(str(error))
+            # except Exception as error:
+            #     self.logger.error(str(error))
             max_num_retries += 1
 
         raise RuntimeError('Calling OpenAI failed after retrying for '
@@ -389,12 +400,12 @@ def generate_request_data(self,
         gen_params = gen_params.copy()
 
         # Hold out 100 tokens due to potential errors in token calculation
-        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+        max_tokens = min(gen_params.pop('max_new_tokens', 4096), 4096)
         if max_tokens <= 0:
             return '', ''
 
         # Initialize the header
-        header = copy.deepcopy(self.default_header)
+        header = copy.deepcopy(self.header)
 
         # Common parameters processing
         gen_params['max_tokens'] = max_tokens
@@ -448,6 +459,14 @@ def generate_request_data(self,
                     **gen_params
                 }
             }
+        elif model_type.lower().startswith('claude'):
+            gen_params.pop('skip_special_tokens', None)
+            gen_params.pop('session_id', None)
+            data = {
+                'model': model_type,
+                'messages': messages,
+                **gen_params
+            }
         else:
             raise NotImplementedError(
                 f'Model type {model_type} is not supported')
@@ -508,6 +527,7 @@ def __init__(self,
                  ],
                  api_base: str = OPENAI_API_BASE,
                  proxies: Optional[Dict] = None,
+                 extra_header: Optional[Dict] = None,
                  **gen_params):
         if 'top_k' in gen_params:
             warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
@@ -520,6 +540,11 @@ def __init__(self,
             **gen_params)
         self.gen_params.pop('top_k')
         self.logger = getLogger(__name__)
+        self.header = {
+            'content-type': 'application/json',
+        }
+        if extra_header:
+            self.header.update(extra_header)
 
         if isinstance(key, str):
             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -646,7 +671,7 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                 break
 
             key = self.keys[self.key_ctr]
-            header['Authorization'] = f'Bearer {key}'
+            # header['Authorization'] = f'Bearer {key}'
 
             if self.orgs:
                 self.org_ctr += 1
@@ -664,6 +689,14 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                         proxy=self.proxies.get(
                             'https', self.proxies.get('http'))) as resp:
                     response = await resp.json()
+                    if self.model_type.lower().startswith('qwen'):
+                        return response['output']['choices'][0]['message']['content'].strip()
+                    elif self.model_type.lower().startswith('claude'):
+                        if response['msgCode'] == -1:
+                            raise RuntimeError(response['msg'])
+                        elif response['data']['stop_reason'] == 'max_tokens':
+                            raise RuntimeError('max_tokens reached')
+                        return response['data']['content'][0]['text'].strip()
                     return response['choices'][0]['message'][
                         'content'].strip()
             except aiohttp.ClientConnectionError:
@@ -688,8 +721,10 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                     self.logger.error('Find error message in response: ' +
                                       str(response['error']))
+                else:
+                    raise KeyError
 
-            except Exception as error:
-                self.logger.error(str(error))
+            # except Exception as error:
+            #     self.logger.error(str(error))
             max_num_retries += 1
 
         raise RuntimeError('Calling OpenAI failed after retrying for '
@@ -838,10 +873,7 @@ def generate_request_data(self,
             return '', ''
 
         # Initialize the header
-        header = {
-            'content-type': 'application/json',
-        }
-
+        header = copy.deepcopy(self.header)
         # Common parameters processing
         gen_params['max_tokens'] = max_tokens
         if 'stop_words' in gen_params:
@@ -894,6 +926,14 @@ def generate_request_data(self,
                     **gen_params
                 }
             }
+        elif model_type.lower().startswith('claude'):
+            gen_params.pop('skip_special_tokens', None)
+            gen_params.pop('session_id', None)
+            data = {
+                'model': model_type,
+                'messages': messages,
+                **gen_params
+            }
         else:
             raise NotImplementedError(
                 f'Model type {model_type} is not supported')
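For context, a minimal usage sketch of the `extra_header` hook that patches 1-2 introduce. It assumes lagent's `GPTAPI` class from `lagent/llms/openai.py`; the model name, gateway URL, and header values are placeholders for illustration, not part of the patch:

    # Hypothetical usage of the `extra_header` parameter (patches 1-2).
    # `GPTAPI` is assumed to be the synchronous client defined in
    # lagent/llms/openai.py; the URL and credentials are placeholders.
    from lagent.llms import GPTAPI

    llm = GPTAPI(
        model_type='claude-3-5-sonnet',  # routed to the Claude branches above
        key='sk-placeholder',            # or 'ENV' to read OPENAI_API_KEY
        api_base='https://gateway.example.com/v1/chat/completions',
        extra_header={'x-api-key': 'gateway-token'},
    )
    # The constructor merges the extra header into the default one:
    # {'content-type': 'application/json', 'x-api-key': 'gateway-token'}
    print(llm.header)

Routing credentials through `extra_header` suits gateway-style deployments where auth does not travel in a bearer token; note that patch 2 correspondingly comments out the async client's `Authorization` header.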
From a5daf69ebc700a4d86111630bc05df44fce3f2e0 Mon Sep 17 00:00:00 2001
From: zhangwenwei
Date: Thu, 17 Oct 2024 04:11:31 +0000
Subject: [PATCH 3/5] [fix]: fix msgCode value comparison

---
 lagent/llms/openai.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index 93391b4d..0ccd0570 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -230,7 +230,8 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
                 if self.model_type.lower().startswith('qwen'):
                     return response['output']['choices'][0]['message']['content'].strip()
                 elif self.model_type.lower().startswith('claude'):
-                    if response['msgCode'] == -1:
+                    print(response)
+                    if response['msgCode'] == '-1':
                         raise RuntimeError(response['msg'])
                     elif response['data']['stop_reason'] == 'max_tokens':
                         raise RuntimeError('max_tokens reached')
@@ -692,9 +693,11 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                     if self.model_type.lower().startswith('qwen'):
                         return response['output']['choices'][0]['message']['content'].strip()
                     elif self.model_type.lower().startswith('claude'):
-                        if response['msgCode'] == -1:
+                        if response['msgCode'] == '-1':
+                            print(response)
                             raise RuntimeError(response['msg'])
                         elif response['data']['stop_reason'] == 'max_tokens':
+                            print(response)
                             raise RuntimeError('max_tokens reached')
                         return response['data']['content'][0]['text'].strip()
                     return response['choices'][0]['message'][

From e4191cc723e9fd3aca04fbca56e7d72c78f66f52 Mon Sep 17 00:00:00 2001
From: zhangwenwei
Date: Fri, 1 Nov 2024 10:55:20 +0000
Subject: [PATCH 4/5] [fix]: print response on error for debugging

---
 lagent/llms/openai.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index 0ccd0570..c4c3bd33 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -679,7 +679,6 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                 if self.org_ctr == len(self.orgs):
                     self.org_ctr = 0
                 header['OpenAI-Organization'] = self.orgs[self.org_ctr]
-
             response = dict()
             try:
                 async with aiohttp.ClientSession() as session:
@@ -713,6 +712,7 @@ async def _chat(self, messages: List[dict], **gen_params) -> str:
                                      (await resp.text(errors='replace')))
                 continue
             except KeyError:
+                print(response)
                 if 'error' in response:
                     if response['error']['code'] == 'rate_limit_exceeded':
                         time.sleep(1)

From 8a38e4d63f2f73df17c40ca8b49475d31c6fa497 Mon Sep 17 00:00:00 2001
From: zhangwenwei
Date: Mon, 10 Mar 2025 07:31:56 +0000
Subject: [PATCH 5/5] fix max_new_tokens bug

---
 lagent/llms/openai.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index c4c3bd33..91279891 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -401,7 +401,7 @@ def generate_request_data(self,
         gen_params = gen_params.copy()
 
         # Hold out 100 tokens due to potential errors in token calculation
-        max_tokens = min(gen_params.pop('max_new_tokens', 4096), 4096)
+        max_tokens = gen_params.pop('max_new_tokens', 4096)
         if max_tokens <= 0:
             return '', ''
@@ -871,7 +871,7 @@ def generate_request_data(self,
         gen_params = gen_params.copy()
 
         # Hold out 100 tokens due to potential errors in token calculation
-        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+        max_tokens = gen_params.pop('max_new_tokens', 4096)
         if max_tokens <= 0:
             return '', ''
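Taken together, the series leaves `_chat` unpacking responses per backend. Below is a distilled, standalone sketch of that dispatch as of patch 3; the payload shapes mirror the keys the diffs reference, and the sample payload is illustrative, not an exact vendor schema:

    # Distilled from the patched _chat: per-backend response unpacking.
    def extract_text(model_type: str, response: dict) -> str:
        if model_type.lower().startswith('qwen'):
            # Qwen-style gateways nest choices under 'output'.
            return response['output']['choices'][0]['message']['content'].strip()
        if model_type.lower().startswith('claude'):
            # Patch 3: the gateway reports msgCode as the string '-1' on failure.
            if response['msgCode'] == '-1':
                raise RuntimeError(response['msg'])
            if response['data']['stop_reason'] == 'max_tokens':
                raise RuntimeError('max_tokens reached')
            return response['data']['content'][0]['text'].strip()
        # Default: the standard OpenAI chat-completions shape.
        return response['choices'][0]['message']['content'].strip()

    print(extract_text('gpt-4o', {'choices': [{'message': {'content': ' hi '}}]}))
    # -> hi

Note also that after patch 5, `max_new_tokens` defaults to 4096 but is no longer capped at it; a request is skipped (empty header and data returned) only when the resolved value is non-positive.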