From b7193296d6fffaa43f540a84f1cb85f83e2b37c8 Mon Sep 17 00:00:00 2001
From: Forrest McKee
Date: Wed, 29 Nov 2023 16:31:11 -0600
Subject: [PATCH] Migrate to latest OpenAI API

---
 language_models.py | 201 ++++++++++++++++++++++++---------------------
 1 file changed, 106 insertions(+), 95 deletions(-)

diff --git a/language_models.py b/language_models.py
index 0db867cf..c203a2c7 100644
--- a/language_models.py
+++ b/language_models.py
@@ -1,4 +1,5 @@
 import openai
+from openai import OpenAI
 import anthropic
 import os
 import time
@@ -8,36 +9,41 @@
 import google.generativeai as palm
 
-class LanguageModel():
+class LanguageModel:
     def __init__(self, model_name):
         self.model_name = model_name
-
-    def batched_generate(self, prompts_list: List, max_n_tokens: int, temperature: float):
+
+    def batched_generate(
+        self, prompts_list: List, max_n_tokens: int, temperature: float
+    ):
         """
         Generates responses for a batch of prompts using a language model.
         """
         raise NotImplementedError
-
+
+
 class HuggingFace(LanguageModel):
-    def __init__(self,model_name, model, tokenizer):
+    def __init__(self, model_name, model, tokenizer):
         self.model_name = model_name
-        self.model = model 
+        self.model = model
        self.tokenizer = tokenizer
         self.eos_token_ids = [self.tokenizer.eos_token_id]
 
-    def batched_generate(self, 
-                        full_prompts_list,
-                        max_n_tokens: int, 
-                        temperature: float,
-                        top_p: float = 1.0,):
-        inputs = self.tokenizer(full_prompts_list, return_tensors='pt', padding=True)
+    def batched_generate(
+        self,
+        full_prompts_list,
+        max_n_tokens: int,
+        temperature: float,
+        top_p: float = 1.0,
+    ):
+        inputs = self.tokenizer(full_prompts_list, return_tensors="pt", padding=True)
         inputs = {k: v.to(self.model.device.index) for k, v in inputs.items()}
-
+
         # Batch generation
         if temperature > 0:
             output_ids = self.model.generate(
                 **inputs,
-                max_new_tokens=max_n_tokens, 
+                max_new_tokens=max_n_tokens,
                 do_sample=True,
                 temperature=temperature,
                 eos_token_id=self.eos_token_ids,
@@ -46,36 +52,33 @@ def batched_generate(self,
         else:
             output_ids = self.model.generate(
                 **inputs,
-                max_new_tokens=max_n_tokens, 
+                max_new_tokens=max_n_tokens,
                 do_sample=False,
                 eos_token_id=self.eos_token_ids,
                 top_p=1,
-                temperature=1, # To prevent warning messages
+                temperature=1,  # To prevent warning messages
             )
-
+
         # If the model is not an encoder-decoder type, slice off the input tokens
         if not self.model.config.is_encoder_decoder:
-            output_ids = output_ids[:, inputs["input_ids"].shape[1]:]
+            output_ids = output_ids[:, inputs["input_ids"].shape[1] :]
 
         # Batch decoding
         outputs_list = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
         for key in inputs:
-            inputs[key].to('cpu')
-        output_ids.to('cpu')
+            inputs[key].to("cpu")
+        output_ids.to("cpu")
         del inputs, output_ids
         gc.collect()
         torch.cuda.empty_cache()
 
         return outputs_list
 
-    def extend_eos_tokens(self):        
+    def extend_eos_tokens(self):
         # Add closing braces for Vicuna/Llama eos when using attacker model
-        self.eos_token_ids.extend([
-            self.tokenizer.encode("}")[1],
-            29913, 
-            9092,
-            16675])
+        self.eos_token_ids.extend([self.tokenizer.encode("}")[1], 29913, 9092, 16675])
+
 
 class GPT(LanguageModel):
     API_RETRY_SLEEP = 10
@@ -83,13 +86,15 @@ class GPT(LanguageModel):
     API_QUERY_SLEEP = 0.5
     API_MAX_RETRY = 5
     API_TIMEOUT = 20
-    openai.api_key = os.getenv("OPENAI_API_KEY")
 
-    def generate(self, conv: List[Dict], 
-                max_n_tokens: int, 
-                temperature: float,
-                top_p: float):
-        '''
+    def __init__(self, model_name):
+        self.model_name = model_name
+        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+    def generate(
+        self, conv: List[Dict], max_n_tokens: int, temperature: float, top_p: float
+    ):
+        """
         Args:
             conv: List of dictionaries, OpenAI API format
             max_n_tokens: int, max number of tokens to generate
@@ -97,61 +102,63 @@ def generate(self, conv: List[Dict],
             top_p: float, top p for sampling
         Returns:
             str: generated response
-        '''
+        """
         output = self.API_ERROR_OUTPUT
         for _ in range(self.API_MAX_RETRY):
             try:
-                response = openai.ChatCompletion.create(
-                            model = self.model_name,
-                            messages = conv,
-                            max_tokens = max_n_tokens,
-                            temperature = temperature,
-                            top_p = top_p,
-                            request_timeout = self.API_TIMEOUT,
-                            )
-                output = response["choices"][0]["message"]["content"]
+                response = self.client.chat.completions.create(
+                    model=self.model_name,
+                    messages=conv,
+                    max_tokens=max_n_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    timeout=self.API_TIMEOUT,
+                )
+                output = response.choices[0].message.content
                 break
-            except openai.error.OpenAIError as e:
+            except openai.OpenAIError as e:
                 print(type(e), e)
                 time.sleep(self.API_RETRY_SLEEP)
-
+
             time.sleep(self.API_QUERY_SLEEP)
-        return output
-
-    def batched_generate(self, 
-                        convs_list: List[List[Dict]],
-                        max_n_tokens: int, 
-                        temperature: float,
-                        top_p: float = 1.0,):
-        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]
-
-class Claude():
+        return output
+
+    def batched_generate(
+        self,
+        convs_list: List[List[Dict]],
+        max_n_tokens: int,
+        temperature: float,
+        top_p: float = 1.0,
+    ):
+        return [
+            self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list
+        ]
+
+
+class Claude:
     API_RETRY_SLEEP = 10
     API_ERROR_OUTPUT = "$ERROR$"
     API_QUERY_SLEEP = 1
     API_MAX_RETRY = 5
     API_TIMEOUT = 20
     API_KEY = os.getenv("ANTHROPIC_API_KEY")
-
+
     def __init__(self, model_name) -> None:
         self.model_name = model_name
-        self.model= anthropic.Anthropic(
+        self.model = anthropic.Anthropic(
             api_key=self.API_KEY,
-            )
+        )
 
-    def generate(self, conv: List, 
-                max_n_tokens: int, 
-                temperature: float,
-                top_p: float):
-        '''
+    def generate(self, conv: List, max_n_tokens: int, temperature: float, top_p: float):
+        """
         Args:
-            conv: List of conversations 
+            conv: List of conversations
             max_n_tokens: int, max number of tokens to generate
             temperature: float, temperature for sampling
             top_p: float, top p for sampling
         Returns:
             str: generated response
-        '''
+        """
         output = self.API_ERROR_OUTPUT
         for _ in range(self.API_MAX_RETRY):
             try:
@@ -160,25 +167,30 @@ def generate(self, conv: List,
                     max_tokens_to_sample=max_n_tokens,
                     prompt=conv,
                     temperature=temperature,
-                    top_p=top_p
+                    top_p=top_p,
                 )
                 output = completion.completion
                 break
             except anthropic.APIError as e:
                 print(type(e), e)
                 time.sleep(self.API_RETRY_SLEEP)
-
+
             time.sleep(self.API_QUERY_SLEEP)
         return output
-
-    def batched_generate(self, 
-                        convs_list: List[List[Dict]],
-                        max_n_tokens: int, 
-                        temperature: float,
-                        top_p: float = 1.0,):
-        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]
-
-class PaLM():
+
+    def batched_generate(
+        self,
+        convs_list: List[List[Dict]],
+        max_n_tokens: int,
+        temperature: float,
+        top_p: float = 1.0,
+    ):
+        return [
+            self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list
+        ]
+
+
+class PaLM:
     API_RETRY_SLEEP = 10
     API_ERROR_OUTPUT = "$ERROR$"
     API_QUERY_SLEEP = 1
     API_MAX_RETRY = 5
@@ -191,47 +203,46 @@ def __init__(self, model_name) -> None:
         self.model_name = model_name
         palm.configure(api_key=self.API_KEY)
 
-    def generate(self, conv: List, 
-                max_n_tokens: int, 
-                temperature: float,
-                top_p: float):
-        '''
+    def generate(self, conv: List, max_n_tokens: int, temperature: float, top_p: float):
+        """
         Args:
-            conv: List of dictionaries, 
+            conv: List of dictionaries,
             max_n_tokens: int, max number of tokens to generate
             temperature: float, temperature for sampling
             top_p: float, top p for sampling
         Returns:
             str: generated response
-        '''
+        """
         output = self.API_ERROR_OUTPUT
         for _ in range(self.API_MAX_RETRY):
             try:
                 completion = palm.chat(
-                    messages=conv,
-                    temperature=temperature,
-                    top_p=top_p
+                    messages=conv, temperature=temperature, top_p=top_p
                 )
                 output = completion.last
-
+
                 if output is None:
                     # If PaLM refuses to output and returns None, we replace it with a default output
                     output = self.default_output
                 else:
                     # Use this approximation since PaLM does not allow
                     # to specify max_tokens. Each token is approximately 4 characters.
-                    output = output[:(max_n_tokens*4)]
+                    output = output[: (max_n_tokens * 4)]
                 break
             except Exception as e:
                 print(type(e), e)
                 time.sleep(self.API_RETRY_SLEEP)
-
+
             time.sleep(1)
         return output
-
-    def batched_generate(self, 
-                        convs_list: List[List[Dict]],
-                        max_n_tokens: int, 
-                        temperature: float,
-                        top_p: float = 1.0,):
-        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]
+
+    def batched_generate(
+        self,
+        convs_list: List[List[Dict]],
+        max_n_tokens: int,
+        temperature: float,
+        top_p: float = 1.0,
+    ):
+        return [
+            self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list
+        ]
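
Reviewer notes (not part of the patch): this migration targets openai-python >= 1.0.0, which removed the module-level openai.api_key and openai.ChatCompletion interface. Requests now go through an explicit OpenAI client instance, the request_timeout argument is renamed to timeout, responses are typed objects read by attribute instead of dict keys, and exceptions moved from openai.error.* to the top level of the package. openai.OpenAIError remains the base class of the 1.x exception hierarchy, so the single except clause still catches rate-limit, timeout, and connection errors. The Claude and PaLM classes are reformatted only; their behavior is unchanged.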
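
A minimal smoke test for the migrated GPT wrapper could look like the sketch below. It is not part of the patch and makes two assumptions: OPENAI_API_KEY is exported, and the account can call a chat model; the "gpt-3.5-turbo" model name is only illustrative.

    # smoke_test.py: exercise the new client path end to end
    from language_models import GPT

    gpt = GPT("gpt-3.5-turbo")  # illustrative; any chat-completions model works
    conv = [{"role": "user", "content": "Reply with the single word: pong"}]

    # batched_generate() wraps generate(); temperature=0 keeps the output stable
    outputs = gpt.batched_generate([conv], max_n_tokens=8, temperature=0)
    assert outputs[0] != GPT.API_ERROR_OUTPUT, "all retries failed"
    print(outputs[0])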