diff --git a/.gitignore b/.gitignore index db809d5..551dc62 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ dist/ build/ -*.egg-info \ No newline at end of file +*.egg-info +__pycache__/ diff --git a/chatgpt/chatgpt.py b/chatgpt/chatgpt.py index 6377be4..3f0165c 100644 --- a/chatgpt/chatgpt.py +++ b/chatgpt/chatgpt.py @@ -8,7 +8,8 @@ import diskcache import openai -from openai.error import OpenAIError +from openai import OpenAIError, AsyncOpenAI, AsyncStream, Stream +from openai.types.chat import ChatCompletionMessage, ChatCompletion, ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam logs_dir = os.path.join(os.getcwd(), '.chatgpt_history/logs') @@ -52,64 +53,71 @@ def sync_wrapper(*args, **kwargs): @retry_on_exception() -def complete(messages=None, model='gpt-4', temperature=0, use_cache=False, **kwargs): +def complete(messages:list[ChatCompletionMessageParam]=None, model='gpt-4', temperature=0, use_cache=False, **kwargs): if use_cache: key = get_key(messages) if key in cache: return cache.get(key) - response = openai.ChatCompletion.create( + response: ChatCompletion | Stream[ChatCompletionChunk] = openai.chat.completions.create( messages=messages, model=model, temperature=temperature, **kwargs ) + stream = kwargs.get('stream', False) + # import pdb; pdb.set_trace() + if stream: + n = kwargs.get('n', 1) + return parse_stream(response, messages, n=n) return parse_response(response, messages, **kwargs) @retry_on_exception() -async def acomplete(messages=None, model='gpt-4', temperature=0, use_cache=False, **kwargs): +async def acomplete(messages:list[ChatCompletionMessageParam]=None, model='gpt-4', temperature=0, use_cache=False, **kwargs): if use_cache: key = get_key(messages) if key in cache: return cache.get(key) - response = await openai.ChatCompletion.acreate( + client = AsyncOpenAI() + response: ChatCompletion | AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( messages=messages, model=model, temperature=temperature, **kwargs ) - return parse_response(response, messages, **kwargs) + stream = kwargs.get('stream', False) + if stream: + n = kwargs.get('n', 1) + return await parse_astream(response, messages, n=n) + return await parse_response(response, messages, **kwargs) -def parse_response(response, messages, **kwargs): +def parse_response(response: ChatCompletion, messages:list[ChatCompletionMessageParam], **kwargs): n = kwargs.get('n', 1) - stream = kwargs.get('stream', False) - if stream: - return parse_stream(response, messages, n=n) + results = [] for choice in response.choices: message = choice.message - if kwargs.get('functions', None) and 'function_call' in message: + if kwargs.get('functions', None) and message.function_call: name = message.function_call.name try: args = json.loads(message.function_call.arguments) except json.decoder.JSONDecodeError as e: print('ERROR: OpenAI returned invalid JSON for function call arguments') raise e - results.append({'role': 'function', 'name': name, 'args': args}) - log_completion(messages + [results[-1]]) - else: - results.append(message.content) - log_completion(messages + [message]) + # results.append({'role': 'function', 'name': name, 'args': args}) + # log_completion(messages, results[-1]) + results.append(message) + log_completion(messages, message) - output = results if n > 1 else results[0] + output = results if n > 1 else results[0] cache.set(get_key(messages), output) return output - -def parse_stream(response, messages, n=1): +def parse_stream(response: Stream[ChatCompletionChunk], messages:list[ChatCompletionMessageParam], n=1): results = ['' for _ in range(n)] + chunk: ChatCompletionChunk for chunk in response: for choice in chunk.choices: if not choice.delta: @@ -125,11 +133,31 @@ def parse_stream(response, messages, n=1): yield (text, idx) for r in results: - log_completion(messages + [{'role': 'assistant', 'content': r}]) + log_completion(messages, r) cache.set(get_key(messages), results) +async def parse_astream(response: AsyncStream[ChatCompletionChunk], messages:list[ChatCompletionMessageParam], n=1): + results = ['' for _ in range(n)] + chunk: ChatCompletionChunk + async for chunk in response: + for choice in chunk.choices: + if not choice.delta: + continue + text = choice.delta.content + if not text: + continue + idx = choice.index + results[idx] += text + if n == 1: + yield text + else: + yield (text, idx) + + for r in results: + log_completion(messages, r) + cache.set(get_key(messages), results) -def log_completion(messages): +def log_completion(messages: list[ChatCompletionMessageParam], result: ChatCompletionMessage = None): timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f') save_path = os.path.join(logs_dir, timestamp + '.txt') @@ -137,22 +165,36 @@ def log_completion(messages): log = "" for message in messages: - log += message['role'].upper() + ' ' + '-'*100 + '\n\n' - if 'name' in message: - log += f"Called function: {message['name']}(" - if 'args' in message: - log += '\n' - for k, v in message['args'].items(): - log += f"\t{k}={repr(v)},\n" - log += ')' - if 'content' in message: - log += '\nContent:\n' + message['content'] - elif 'function_call' in message: - log += f"Called function: {message['function_call'].get('name', 'UNKNOWN')}(\n" - log += ')' - else: - log += message["content"] + log += message['role'].upper() + ' ' + '-'*100 + '\n' + if "content" in message: + log += 'Content:\n' + message['content'] + "\n" + if "function_call" in message: # TODO: remove later since function_call is deprecated + log += f'Call function\n:{message["function_call"]["name"]}({message["function_call"]["arguments"]})\n' + if "tool_calls" in message: + for tool in message["tool_calls"]: + log += f'\nCall {tool["type"]}:\n' + if tool["type"] == 'function': + log += f'{tool["function"]["name"]}({tool["function"]["arguments"]}) id={tool["id"]}\n' + else: + raise NotImplementedError(f'Tool type {tool["type"]} not implemented in logger') log += '\n\n' + + if result: + log += result.role.upper() + ' ' + '-'*100 + '\n' + if result.content: + log += 'Content:\n' + result['content'] + if result.function_call: + log += f'Called function:\n{result.function_call.name}({result.function_call.arguments})\n' + if result.tool_calls: + for tool in result.tool_calls: + log += f'\nCalled {tool.type}:\n' + if tool.type == 'function': + log += f'{tool.function.name}({tool.function.arguments}) id={tool.id}\n' + else: + raise NotImplementedError(f"Tool type {tool.type} not implemented in logger") + + log += '\n\n' + with open(save_path, 'w') as f: f.write(log)