From c7bc5565009a977a083681fd2f320d81c610440a Mon Sep 17 00:00:00 2001 From: Jasper Timm Date: Thu, 29 Jan 2026 07:29:02 +0000 Subject: [PATCH 1/3] Update for Gemini 3 exps --- configs/config.yaml | 1 + configs/experiment/gemini_3_pro_high.yaml | 26 ++ .../gemini_3_pro_high_aistudio.yaml | 26 ++ configs/experiment/gemini_3_pro_low.yaml | 26 ++ .../experiment/gemini_3_pro_low_aistudio.yaml | 26 ++ configs/experiment/opus_45_low.yaml | 26 ++ main.py | 4 + src/generate_conversations/generate.py | 397 +++++++++++++++--- 8 files changed, 470 insertions(+), 62 deletions(-) create mode 100644 configs/experiment/gemini_3_pro_high.yaml create mode 100644 configs/experiment/gemini_3_pro_high_aistudio.yaml create mode 100644 configs/experiment/gemini_3_pro_low.yaml create mode 100644 configs/experiment/gemini_3_pro_low_aistudio.yaml create mode 100644 configs/experiment/opus_45_low.yaml diff --git a/configs/config.yaml b/configs/config.yaml index b13ab3f..213e139 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -18,6 +18,7 @@ persuader_model: gpt-4o persuadee_model: gpt-4o evaluator_model: gpt-4o reject_model: gpt-4o-mini +reasoning_effort_persuader: null # Jailbreaking options jailbreak_persuader: false diff --git a/configs/experiment/gemini_3_pro_high.yaml b/configs/experiment/gemini_3_pro_high.yaml new file mode 100644 index 0000000..8b9e721 --- /dev/null +++ b/configs/experiment/gemini_3_pro_high.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Gemini 3 Pro Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: vertex_ai/gemini-3-pro-preview +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: gemini_3_pro_high +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true +assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: high +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/configs/experiment/gemini_3_pro_high_aistudio.yaml b/configs/experiment/gemini_3_pro_high_aistudio.yaml new file mode 100644 index 0000000..8b9e721 --- /dev/null +++ b/configs/experiment/gemini_3_pro_high_aistudio.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Gemini 3 Pro Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: vertex_ai/gemini-3-pro-preview +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: gemini_3_pro_high +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true +assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: high +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/configs/experiment/gemini_3_pro_low.yaml b/configs/experiment/gemini_3_pro_low.yaml new file mode 100644 index 0000000..cbf67f9 --- /dev/null +++ b/configs/experiment/gemini_3_pro_low.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Gemini 3 Pro Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: vertex_ai/gemini-3-pro-preview +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: gemini_3_pro_low +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true 
+assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: low +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/configs/experiment/gemini_3_pro_low_aistudio.yaml b/configs/experiment/gemini_3_pro_low_aistudio.yaml new file mode 100644 index 0000000..0aab2cf --- /dev/null +++ b/configs/experiment/gemini_3_pro_low_aistudio.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Gemini 3 Pro Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: gemini/gemini-3-pro-preview +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: gemini_3_pro_low +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true +assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: low +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/configs/experiment/opus_45_low.yaml b/configs/experiment/opus_45_low.yaml new file mode 100644 index 0000000..da3bcc8 --- /dev/null +++ b/configs/experiment/opus_45_low.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Opus 4.5 Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: anthropic/claude-opus-4-5-20251101 +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: opus_45_low +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true +assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: low +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/main.py b/main.py index 4ced630..3d75f69 100644 --- a/main.py +++ b/main.py @@ -90,6 +90,7 @@ def main(cfg: DictConfig) -> None: PERSUADEE_MODEL = cfg.persuadee_model EVALUATOR_MODEL = cfg.evaluator_model REJECT_MODEL = cfg.reject_model + REASONING_EFFORT_PERSUADER = cfg.reasoning_effort_persuader JAILBREAK_TUNING_POSTPROCESS = cfg.jailbreak_persuader JAILBREAK_TUNING_PERSUADEE = cfg.jailbreak_persuadee EXPERIMENT_NAME = cfg.experiment_name @@ -177,6 +178,7 @@ def main(cfg: DictConfig) -> None: "PERSUADEE_MODEL": PERSUADEE_MODEL, "EVALUATOR_MODEL": EVALUATOR_MODEL, "REJECT_MODEL": REJECT_MODEL, + "REASONING_EFFORT_PERSUADER": REASONING_EFFORT_PERSUADER, "JAILBREAK_TUNING_POSTPROCESS": JAILBREAK_TUNING_POSTPROCESS, "RECORD_ALL_SYSTEM": cfg.record_all_system, "REMOVE_RATINGS": cfg.remove_ratings, @@ -470,6 +472,7 @@ def main(cfg: DictConfig) -> None: model=PERSUADER_MODEL, postprocess_responses=JAILBREAK_TUNING_POSTPROCESS, remove_ratings=cfg.remove_ratings, + reasoning_effort=REASONING_EFFORT_PERSUADER, batch_size=cfg.batch_size, ) @@ -660,6 +663,7 @@ def run_refusal_check(conv, system_message, turn): model=PERSUADER_MODEL, postprocess_responses=JAILBREAK_TUNING_POSTPROCESS, remove_ratings=cfg.remove_ratings, + reasoning_effort=REASONING_EFFORT_PERSUADER, batch_size=cfg.batch_size, ) diff --git a/src/generate_conversations/generate.py b/src/generate_conversations/generate.py index 0dbb5d6..950b00f 100644 --- a/src/generate_conversations/generate.py +++ b/src/generate_conversations/generate.py @@ -3,23 +3,78 @@ """ import 
asyncio -import logging import re -from typing import Any, Dict, List, cast - +import time +import logging +from tqdm import tqdm import torch from litellm import acompletion -from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig +import litellm # Add import for litellm to access RateLimitError +from typing import Any, Dict, List +import openai # Add OpenAI import for direct API calls -# LiteLLM and httpx log each request at INFO level. Set logging level to WARNING logging.getLogger("LiteLLM").setLevel(logging.WARNING) -logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("litellm").setLevel(logging.WARNING) +if hasattr(litellm, "set_verbose") and callable(litellm.set_verbose): + litellm.set_verbose(False) +if hasattr(litellm, "suppress_debug_info"): + litellm.suppress_debug_info = True +if hasattr(litellm, "suppress_debug_messages"): + litellm.suppress_debug_messages = True + + +class NullContentRetryableError(Exception): + """Exception raised when model returns null content with finish_reason 'stop' that should be retried.""" + def __init__(self, finish_reason: str, diagnostic_info: dict = None): + self.finish_reason = finish_reason + self.diagnostic_info = diagnostic_info or {} + super().__init__(f"Null content with finish_reason: {finish_reason}") + +# Global variables for tracking reasoning tokens +TOTAL_REASONING_TOKENS = 0 +REASONING_RESPONSES_COUNT = 0 # Cache for storing loaded pipeline generators MODEL_CACHE = {} + + + +def parse_gptoss_response(generated_text: str) -> str: + """ + Parse gpt-oss response to extract the final response from harmony format. + + Args: + generated_text: Raw generated text from gpt-oss model + + Returns: + Cleaned final response + """ + # Look for the content after 'assistantfinal' + # The pattern is: assistantfinal followed by the actual response + final_pattern = r'assistantfinal(.*?)(?:<\|end\|>|$)' + final_match = re.search(final_pattern, generated_text, re.DOTALL) + + if final_match: + final_response = final_match.group(1).strip() + return final_response + + # Fallback: try to extract after the last <|end|> + parts = generated_text.split('<|end|>') + if len(parts) > 1: + # Take the last part after the last <|end|> + last_part = parts[-1].strip() + return last_part + + # Fallback: return the original text + return generated_text + + + + + def postprocess_message(message): """ Postprocess the message to remove jailbreak-tuning artifacts. 
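
As a quick illustration of the `parse_gptoss_response` helper added in the hunk above, the sketch below feeds it a simplified stand-in for gpt-oss "harmony" output (the real format carries additional channel tokens; only the `assistantfinal` marker and `<|end|>` delimiter matter to this parser). It assumes the surrounding module context, i.e. `parse_gptoss_response` and `re` are in scope; the sample string is hypothetical.

    # Hypothetical raw generation: an analysis channel followed by the final answer.
    raw = (
        "analysisThe user wants a short greeting.<|end|>"
        "assistantfinalHello! How can I help you today?<|end|>"
    )
    # The regex captures the text between 'assistantfinal' and the next '<|end|>'.
    print(parse_gptoss_response(raw))
    # -> "Hello! How can I help you today?"
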
@@ -61,14 +116,40 @@ def preload_local_model(model_name): tokenizer.pad_token = tokenizer.eos_token else: # Add a pad token if there's no eos token to use - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - hf_llm = AutoModelForCausalLM.from_pretrained( - local_path, - torch_dtype=torch.bfloat16, - device_map="auto", - trust_remote_code=True, - ) + tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + + # Special handling for gpt-oss models + if "gpt-oss" in model_name.lower(): + print(f"Loading {model_name} with native MXFP4 quantization") + + # Load with native quantization config from the model + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + device_map="auto", + torch_dtype="auto", + ) + # load in 4-bit mode for 70b models + elif "70B" in model_name: + print(f"Loading {model_name} in 4-bit mode") + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=False, + ) + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + device_map="auto", + trust_remote_code=True, + quantization_config=bnb_config, + ) + else: + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) # Resize model embeddings if pad token was added if tokenizer.pad_token == "[PAD]": hf_llm.resize_token_embeddings(len(tokenizer)) @@ -153,24 +234,26 @@ def format_prompt_for_model( """ # Handle Qwen models with thinking disabled if is_qwen_model(model): - return cast( - str, - tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - enable_thinking=False, # Disable thinking mode for Qwen - ), + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, # Disable thinking mode for Qwen + ) + elif "gpt-oss" in model.lower(): + # For gpt-oss models, set reasoning_effort to low + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + # reasoning_effort="low", ) else: - # Default Llama formatting - return cast( - str, - tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - ), + # Default formatting for all other models + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, ) @@ -206,6 +289,42 @@ def get_generation_params(model: str, temperature: float) -> Dict[str, Any]: return params +async def generate_with_openai_client( + message_collection: List[List[Dict[str, str]]], + temperature: float = 0.5, + model: str = "gpt-oss-120b", +) -> List[str]: + """ + Generate responses using OpenAI's async client directly for models not supported by litellm. 
+ + Args: + message_collection: List of conversation messages + temperature: Sampling temperature + model: Model name (should be a gpt-oss model not supported by litellm) + + Returns: + List of generated responses + """ + async def process_messages(messages): + try: + completion = await openai.AsyncOpenAI().chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=2048, + ) + return completion.choices[0].message.content + except Exception as e: + print(f"Error processing prompt for {model}: {e}") + return f"Error processing prompt: {e}" + + # Process all messages concurrently + tasks = [process_messages(message_list) for message_list in message_collection] + responses = await asyncio.gather(*tasks) + + return responses + + def generate_with_local_model( message_collection: List[List[Dict[str, str]]], model: str, @@ -241,12 +360,23 @@ def generate_with_local_model( # Add a pad token if there's no eos token to use tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - hf_llm = AutoModelForCausalLM.from_pretrained( - local_path, - torch_dtype=torch.bfloat16, - device_map="auto", - trust_remote_code=True, - ) + # Special handling for gpt-oss models + if "gpt-oss" in model.lower(): + print(f"Loading {model} with native MXFP4 quantization") + + # Load with native quantization config from the model + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + device_map="auto", + torch_dtype="auto", + ) + else: + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) # Resize model embeddings if pad token was added if tokenizer.pad_token == "[PAD]": hf_llm.resize_token_embeddings(len(tokenizer)) @@ -300,20 +430,29 @@ def generate_with_local_model( for i, outputs in enumerate(batch_outputs): try: # Extract the generated response based on structure - # Hugging Face pipeline always returns a list - generated_text = outputs[0]["generated_text"] + # If outputs is a list (like [{'generated_text': '...'}]), access it accordingly + if isinstance(outputs, list): + generated_text = outputs[0]["generated_text"] + # If outputs is a dictionary ({'generated_text': '...'}), access directly + else: + generated_text = outputs["generated_text"] # Extract response based on model type - if is_qwen_model(model): + if "gpt-oss" in model.lower(): + # For gpt-oss models, parse the harmony format + response = parse_gptoss_response(generated_text) + elif is_qwen_model(model): # For Qwen, properly clean the response to remove thinking blocks response = clean_qwen_response(generated_text) else: # For Llama models - response = generated_text.split("<|end_header_id|>\n\n")[-1] + response = generated_text.split("<|end_header_id|>\n\n")[-1] all_responses.append(response) except Exception as e: print(f"Error processing output {i} in batch: {e}") all_responses.append(f"Error processing response: {e}") + + except Exception as e: # If batch processing fails, fall back to individual processing print( @@ -324,21 +463,29 @@ def generate_with_local_model( outputs = generator( prompt, pad_token_id=tokenizer.pad_token_id, **generation_params ) - # Hugging Face pipeline always returns a list - generated_text = outputs[0]["generated_text"] + # Check for the expected format (list or dict) + if isinstance(outputs, list): + generated_text = outputs[0]["generated_text"] + else: + generated_text = outputs["generated_text"] # Extract response based on model type - if is_qwen_model(model): + if "gpt-oss" in 
model.lower(): + # For gpt-oss models, parse the harmony format + response = parse_gptoss_response(generated_text) + elif is_qwen_model(model): # For Qwen, properly clean the response to remove thinking blocks response = clean_qwen_response(generated_text) else: # For Llama models response = generated_text.split("<|end_header_id|>\n\n")[-1] - + all_responses.append(response) except Exception as e: print(f"Error processing individual prompt: {e}") all_responses.append(f"Error processing prompt: {e}") + + return all_responses @@ -349,6 +496,8 @@ async def generate_llm( model: str = "gpt-4o-mini", postprocess_responses: bool = False, batch_size: int = 4, + vertex_thinking_mode: bool = False, + reasoning_effort: str = None, **kwargs, ) -> List[str]: """ @@ -360,10 +509,17 @@ async def generate_llm( model: Model name (e.g., "hf/Meta-Llama-3.1-8B-Instruct" or "gpt-4o") postprocess_responses: Whether to apply postprocessing to responses batch_size: Number of prompts to process in a single batch (for local models) - + vertex_thinking_mode: Whether to enable thinking mode for Vertex AI models + reasoning_effort: Reasoning effort level ('low', 'medium', 'high', or None) Returns: List of generated responses """ + # List of models that need a temperature of 1.0 + temperature_1_0_models = ["o4-mini", "o3", "gpt-5-mini", "gpt-5", "gpt-5.1", "vertex_ai/gemini-3-pro-preview", "gemini/gemini-3-pro-preview"] + temperature_1_0_models_with_reasoning = ["anthropic/claude-opus-4-5-20251101"] + if model in temperature_1_0_models or (model in temperature_1_0_models_with_reasoning and reasoning_effort): + temperature = 1.0 + # Handle local models synchronously with batching if model.startswith("hf/"): responses = generate_with_local_model( @@ -372,28 +528,145 @@ async def generate_llm( temperature=temperature, batch_size=batch_size, ) - # Handle cloud API models asynchronously + # Handle gpt-oss models with OpenAI async client directly (not supported by litellm) + elif model.startswith("gpt-oss"): + responses = await generate_with_openai_client( + message_collection=message_collection, + temperature=temperature, + model=model, + ) + # Handle cloud API models else: async def process_messages(messages): - try: - # Create a dictionary of arguments for acompletion - completion_args = { - "model": model, - "messages": messages, - "temperature": temperature, - } - - completion = await acompletion(**completion_args) - return completion.choices[0].message.content - except Exception as e: - print(f"Error processing prompt: {e}") - return f"Error processing prompt: {e}" - - tasks = [process_messages(message_list) for message_list in message_collection] + # Exponential backoff parameters for rate limiting + max_attempts = 5 + initial_sleep_time = 1 # seconds + backoff_factor = 1.5 + + # Build completion arguments (model-specific setup) + completion_args = { + "model": model, + "messages": messages, + "temperature": temperature, + } + + if reasoning_effort: + completion_args["reasoning_effort"] = reasoning_effort + + # Gemini 3 Pro Preview uses global location + if model == "vertex_ai/gemini-3-pro-preview" or model == "vertex_ai/gemini-3-flash-preview": + completion_args["vertex_location"] = "global" + + if model.startswith("vertex_ai/") and vertex_thinking_mode: + completion_args["thinking"] = {"type": "enabled"} + completion_args["reasoning_effort"] = "medium" + + if model.startswith("vertex_ai/") and len(model.split('/')) > 1 and model.split("/")[1].isdigit(): + completion_args["base_model"] = 
"vertex_ai/gemini-1.5-flash-latest" + + # Retry loop with exponential backoff for rate limiting + sleep_attempts = 0 + current_sleep_time = initial_sleep_time + while sleep_attempts < max_attempts: + try: + completion = await acompletion(**completion_args) + choice = completion.choices[0] + response_content = choice.message.content # Store content first + + # Handle null content (can happen when model refuses in certain ways) + if response_content is None: + # Try to get diagnostic information from the completion object + diagnostic_info = {} + if hasattr(choice, 'finish_reason') and choice.finish_reason: + diagnostic_info['finish_reason'] = choice.finish_reason + if hasattr(choice, 'stop_reason') and choice.stop_reason: + diagnostic_info['stop_reason'] = choice.stop_reason + if hasattr(choice.message, 'refusal') and choice.message.refusal: + diagnostic_info['refusal'] = choice.message.refusal + if hasattr(completion, 'model_extra') and completion.model_extra: + diagnostic_info['model_extra_keys'] = list(completion.model_extra.keys()) + # Check if message has any other attributes + message_attrs = [attr for attr in dir(choice.message) if not attr.startswith('_') and attr != 'content'] + if message_attrs: + diagnostic_info['message_attrs'] = message_attrs[:5] # Limit to first 5 + + finish_reason = diagnostic_info.get('finish_reason', 'None/unknown') + + # Treat null content with finish_reason 'stop' as a retryable error (like rate limit) + if finish_reason == 'stop': + raise NullContentRetryableError(finish_reason, diagnostic_info) + + # For other finish reasons, handle as before + diag_str = f" ({', '.join(f'{k}: {v}' for k, v in diagnostic_info.items())})" if diagnostic_info else "" + + if diagnostic_info.get('finish_reason') == 'refusal': + return "[FINISH_REASON] Sorry I cannot assist with that request." + else: + print(f"Warning: {model} returned null content{diag_str}") + return f"[ERROR] Model returned null content, finishing reason: {finish_reason}" + + # print the reasoning content for vertex ai models for debugging + # if model.startswith("vertex_ai/") and vertex_thinking_mode: + # try: + # print(completion.choices[0].model_extra['message'].reasoning_content) + # except Exception as e: + # print(f"Could not display thinking process: {e}") + + if model == "o4-mini" or model == "o3": + global TOTAL_REASONING_TOKENS, REASONING_RESPONSES_COUNT + reasoning_tokens = None + try: + reasoning_tokens = completion.model_extra['usage'].completion_tokens_details.reasoning_tokens + TOTAL_REASONING_TOKENS += reasoning_tokens + REASONING_RESPONSES_COUNT += 1 + except Exception as e: + print(f"Could not extract or process reasoning tokens for o4-mini: {e}") + + return response_content # Return the stored content + except (litellm.RateLimitError, litellm.BadGatewayError, NullContentRetryableError) as e: # Retry with backoff for rate limits, server errors, and null content with stop + if isinstance(e, NullContentRetryableError): + error_type = "Null content (stop)" + diag_str = f" ({', '.join(f'{k}: {v}' for k, v in e.diagnostic_info.items())})" if e.diagnostic_info else "" + error_msg = f"{error_type} for {model}: finish_reason={e.finish_reason}{diag_str}" + elif isinstance(e, litellm.RateLimitError): + error_type = "Rate limit" + error_msg = f"{error_type} for {model}: {e}" + else: + error_type = "Server error (5xx)" + error_msg = f"{error_type} for {model}: {e}" + + print(f"{error_msg}. 
Attempt {sleep_attempts + 1}/{max_attempts}.") + sleep_attempts += 1 + if sleep_attempts >= max_attempts: + print(f"Max attempts reached for {model}. Error: {e}") + return f"[ERROR] Error processing prompt after multiple retries: {e}" + print(f"Waiting for {current_sleep_time} seconds before retrying...") + await asyncio.sleep(current_sleep_time) # Use asyncio.sleep for async functions + current_sleep_time *= backoff_factor + except litellm.BadRequestError as e: + print(f"Bad request error for {model}, using fallback refusal response.: {e}.") + return f"[BAD_REQUEST] Sorry I cannot assist with that request." + except Exception as e: + print(f"Non-retryable error processing prompt for {model}: {e}") + return f"[ERROR] Error processing prompt: {e}" + return f"[ERROR] Max retries exceeded for {model}." + + # Orchestrate calls to process_messages + # if "claude" in model.lower() or model.startswith("anthropic/"): + # responses = [] + # for message_list in tqdm(message_collection, desc=f"Processing prompts for {model} sequentially"): + # response = await process_messages(message_list) + # responses.append(response) + # else: # For other (non-Anthropic) cloud models, use concurrency + tasks = [process_messages(message_list) for message_list in tqdm(message_collection, desc=f"Processing prompts for {model} concurrently")] responses = await asyncio.gather(*tasks) if postprocess_responses: responses = [postprocess_message(response) for response in responses] + if model == "o4-mini" and REASONING_RESPONSES_COUNT > 0: + average_reasoning_tokens = TOTAL_REASONING_TOKENS / REASONING_RESPONSES_COUNT + print(f"Average reasoning tokens per o4-mini response: {average_reasoning_tokens:.2f} (Total: {TOTAL_REASONING_TOKENS}, Count: {REASONING_RESPONSES_COUNT})") + return responses From cf02002cdc2e22c47ab0e38240fc087d0738d19a Mon Sep 17 00:00:00 2001 From: Jasper Timm Date: Fri, 30 Jan 2026 00:58:19 +0000 Subject: [PATCH 2/3] Fix up dep handling to ensure locked versions --- README.md | 17 +- pyproject.toml | 8 +- reqs-all.txt | 487 +++++++++++++++++++++++++ reqs-dev.txt | 397 ++++++++++++++++++++ reqs.txt | 319 ++++++++++++++++ src/generate_conversations/generate.py | 23 +- 6 files changed, 1243 insertions(+), 8 deletions(-) create mode 100644 reqs-all.txt create mode 100644 reqs-dev.txt create mode 100644 reqs.txt diff --git a/README.md b/README.md index d81a6ab..2eff9ae 100644 --- a/README.md +++ b/README.md @@ -69,10 +69,23 @@ huggingface-cli download Qwen/Qwen3-32B-Instruct --local-dir src/ckpts/Qwen3-32B ### Install dependencies -Dependencies are in `pyproject.toml`, install them with: +Locked dependency versions are recorded in `reqs.txt` (generated from `pyproject.toml`). 
+Install them with: ```bash -pip install -e ".[dev,test]" +pip install -r reqs.txt +``` + +For development and testing you can use: + +```bash +pip install -r reqs-dev.txt +``` + +And for all dependencies including those required for running local models: + +```bash +pip install -r reqs-all.txt ``` ### Execute evals diff --git a/pyproject.toml b/pyproject.toml index e039df8..06b65db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,9 +16,6 @@ dependencies = [ "google-cloud-aiplatform", "seaborn", "scikit-learn", - "transformers", - "torch", - "accelerate", "tabulate", "hydra-core" ] @@ -40,6 +37,11 @@ test = [ "pytest", "pytest-asyncio" ] +local = [ + "transformers", + "torch", + "accelerate" +] [tool.logviz-dependencies] Flask = "^2.3.2" diff --git a/reqs-all.txt b/reqs-all.txt new file mode 100644 index 0000000..da169d9 --- /dev/null +++ b/reqs-all.txt @@ -0,0 +1,487 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile pyproject.toml --extra dev --extra test --extra local -o reqs-all.txt +accelerate==1.12.0 + # via ape (pyproject.toml) +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via litellm +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +anyio==4.12.1 + # via + # google-genai + # httpx + # openai +attrs==25.4.0 + # via + # aiohttp + # jsonschema + # referencing +black==26.1.0 + # via ape (pyproject.toml) +certifi==2026.1.4 + # via + # httpcore + # httpx + # requests +cffi==2.0.0 + # via cryptography +cfgv==3.5.0 + # via pre-commit +charset-normalizer==3.4.4 + # via requests +click==8.3.1 + # via + # black + # litellm + # typer-slim +contourpy==1.3.3 + # via matplotlib +cryptography==46.0.4 + # via google-auth +cuda-bindings==12.9.4 + # via torch +cuda-pathfinder==1.3.3 + # via cuda-bindings +cycler==0.12.1 + # via matplotlib +distlib==0.4.0 + # via virtualenv +distro==1.9.0 + # via + # google-genai + # openai +docstring-parser==0.17.0 + # via google-cloud-aiplatform +fastuuid==0.14.0 + # via litellm +filelock==3.20.3 + # via + # huggingface-hub + # torch + # transformers + # virtualenv +flake8==7.3.0 + # via ape (pyproject.toml) +fonttools==4.61.1 + # via matplotlib +frozenlist==1.8.0 + # via + # aiohttp + # aiosignal +fsspec==2026.1.0 + # via + # huggingface-hub + # torch +google-api-core==2.29.0 + # via + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage +google-auth==2.48.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage + # google-genai +google-cloud-aiplatform==1.135.0 + # via ape (pyproject.toml) +google-cloud-bigquery==3.40.0 + # via google-cloud-aiplatform +google-cloud-core==2.5.0 + # via + # google-cloud-bigquery + # google-cloud-storage +google-cloud-resource-manager==1.16.0 + # via google-cloud-aiplatform +google-cloud-storage==3.8.0 + # via google-cloud-aiplatform +google-crc32c==1.8.0 + # via + # google-cloud-storage + # google-resumable-media +google-genai==1.60.0 + # via google-cloud-aiplatform +google-resumable-media==2.8.0 + # via + # google-cloud-bigquery + # google-cloud-storage +googleapis-common-protos==1.72.0 + # via + # google-api-core + # grpc-google-iam-v1 + # grpcio-status +grpc-google-iam-v1==0.14.3 + # via google-cloud-resource-manager +grpcio==1.76.0 + # via + # google-api-core + # google-cloud-resource-manager + # 
googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status +grpcio-status==1.76.0 + # via google-api-core +h11==0.16.0 + # via httpcore +hf-xet==1.2.0 + # via huggingface-hub +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via + # google-genai + # huggingface-hub + # litellm + # openai +huggingface-hub==1.3.5 + # via + # accelerate + # tokenizers + # transformers +hydra-core==1.3.2 + # via ape (pyproject.toml) +identify==2.6.16 + # via pre-commit +idna==3.11 + # via + # anyio + # httpx + # requests + # yarl +importlib-metadata==8.7.1 + # via litellm +iniconfig==2.3.0 + # via pytest +isort==7.0.0 + # via ape (pyproject.toml) +jinja2==3.1.6 + # via + # litellm + # torch +jiter==0.12.0 + # via openai +joblib==1.5.3 + # via scikit-learn +jsonschema==4.26.0 + # via litellm +jsonschema-specifications==2025.9.1 + # via jsonschema +kiwisolver==1.4.9 + # via matplotlib +librt==0.7.8 + # via mypy +litellm==1.81.5 + # via ape (pyproject.toml) +markupsafe==3.0.3 + # via jinja2 +matplotlib==3.10.8 + # via + # ape (pyproject.toml) + # seaborn +mccabe==0.7.0 + # via flake8 +mpmath==1.3.0 + # via sympy +multidict==6.7.1 + # via + # aiohttp + # yarl +mypy==1.19.1 + # via ape (pyproject.toml) +mypy-extensions==1.1.0 + # via + # black + # mypy +networkx==3.6.1 + # via torch +nodeenv==1.10.0 + # via pre-commit +numpy==2.4.1 + # via + # accelerate + # contourpy + # matplotlib + # pandas + # pandas-stubs + # scikit-learn + # scipy + # seaborn + # transformers +nvidia-cublas-cu12==12.8.4.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.8.90 + # via torch +nvidia-cuda-nvrtc-cu12==12.8.93 + # via torch +nvidia-cuda-runtime-cu12==12.8.90 + # via torch +nvidia-cudnn-cu12==9.10.2.21 + # via torch +nvidia-cufft-cu12==11.3.3.83 + # via torch +nvidia-cufile-cu12==1.13.1.3 + # via torch +nvidia-curand-cu12==10.3.9.90 + # via torch +nvidia-cusolver-cu12==11.7.3.90 + # via torch +nvidia-cusparse-cu12==12.5.8.93 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.7.1 + # via torch +nvidia-nccl-cu12==2.27.5 + # via torch +nvidia-nvjitlink-cu12==12.8.93 + # via + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvshmem-cu12==3.4.5 + # via torch +nvidia-nvtx-cu12==12.8.90 + # via torch +omegaconf==2.3.0 + # via hydra-core +openai==2.16.0 + # via + # ape (pyproject.toml) + # litellm +packaging==26.0 + # via + # accelerate + # black + # google-cloud-aiplatform + # google-cloud-bigquery + # huggingface-hub + # hydra-core + # matplotlib + # pytest + # transformers +pandas==3.0.0 + # via + # ape (pyproject.toml) + # seaborn +pandas-stubs==2.3.3.260113 + # via ape (pyproject.toml) +pathspec==1.0.4 + # via + # black + # mypy +pillow==12.1.0 + # via matplotlib +platformdirs==4.5.1 + # via + # black + # virtualenv +pluggy==1.6.0 + # via pytest +pre-commit==4.5.1 + # via ape (pyproject.toml) +propcache==0.4.1 + # via + # aiohttp + # yarl +proto-plus==1.27.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager +protobuf==6.33.5 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status + # proto-plus +psutil==7.2.2 + # via accelerate +pyasn1==0.6.2 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pycodestyle==2.14.0 + # via flake8 +pycparser==3.0 + # via cffi +pydantic==2.12.5 + # via + # google-cloud-aiplatform + # google-genai + # litellm + # openai 
+pydantic-core==2.41.5 + # via pydantic +pyflakes==3.4.0 + # via flake8 +pygments==2.19.2 + # via pytest +pyparsing==3.3.2 + # via matplotlib +pytest==9.0.2 + # via + # ape (pyproject.toml) + # pytest-asyncio +pytest-asyncio==1.3.0 + # via ape (pyproject.toml) +python-dateutil==2.9.0.post0 + # via + # google-cloud-bigquery + # matplotlib + # pandas +python-dotenv==1.2.1 + # via + # ape (pyproject.toml) + # litellm +pytokens==0.4.0 + # via black +pytz==2025.2 + # via ape (pyproject.toml) +pyyaml==6.0.3 + # via + # ape (pyproject.toml) + # accelerate + # huggingface-hub + # omegaconf + # pre-commit + # transformers +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications +regex==2026.1.15 + # via + # tiktoken + # transformers +requests==2.32.5 + # via + # google-api-core + # google-auth + # google-cloud-bigquery + # google-cloud-storage + # google-genai + # tiktoken +rpds-py==0.30.0 + # via + # jsonschema + # referencing +rsa==4.9.1 + # via google-auth +safetensors==0.7.0 + # via + # accelerate + # transformers +scikit-learn==1.8.0 + # via ape (pyproject.toml) +scipy==1.17.0 + # via + # ape (pyproject.toml) + # scikit-learn +seaborn==0.13.2 + # via ape (pyproject.toml) +setuptools==80.10.2 + # via torch +shellingham==1.5.4 + # via huggingface-hub +six==1.17.0 + # via python-dateutil +sniffio==1.3.1 + # via + # google-genai + # openai +sympy==1.14.0 + # via torch +tabulate==0.9.0 + # via ape (pyproject.toml) +tenacity==9.1.2 + # via google-genai +threadpoolctl==3.6.0 + # via scikit-learn +tiktoken==0.12.0 + # via litellm +tokenizers==0.22.2 + # via + # litellm + # transformers +torch==2.10.0 + # via + # ape (pyproject.toml) + # accelerate +tqdm==4.67.1 + # via + # huggingface-hub + # openai + # transformers +transformers==5.0.0 + # via ape (pyproject.toml) +triton==3.6.0 + # via torch +typer-slim==0.21.1 + # via + # huggingface-hub + # transformers +types-pytz==2025.2.0.20251108 + # via + # ape (pyproject.toml) + # pandas-stubs +types-pyyaml==6.0.12.20250915 + # via ape (pyproject.toml) +types-requests==2.32.4.20260107 + # via types-tqdm +types-tabulate==0.9.0.20241207 + # via ape (pyproject.toml) +types-tqdm==4.67.0.20250809 + # via ape (pyproject.toml) +typing-extensions==4.15.0 + # via + # aiosignal + # anyio + # google-cloud-aiplatform + # google-genai + # grpcio + # huggingface-hub + # mypy + # openai + # pydantic + # pydantic-core + # pytest-asyncio + # referencing + # torch + # typer-slim + # typing-inspection +typing-inspection==0.4.2 + # via pydantic +urllib3==2.6.3 + # via + # requests + # types-requests +virtualenv==20.36.1 + # via pre-commit +websockets==15.0.1 + # via google-genai +yarl==1.22.0 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata diff --git a/reqs-dev.txt b/reqs-dev.txt new file mode 100644 index 0000000..5869a21 --- /dev/null +++ b/reqs-dev.txt @@ -0,0 +1,397 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile pyproject.toml --extra dev --extra test -o reqs-dev.txt +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via litellm +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +anyio==4.12.1 + # via + # google-genai + # httpx + # openai +attrs==25.4.0 + # via + # aiohttp + # jsonschema + # referencing +black==26.1.0 + # via ape (pyproject.toml) +certifi==2026.1.4 + # via + # httpcore + # httpx + # requests +cffi==2.0.0 + # via cryptography +cfgv==3.5.0 + # via pre-commit +charset-normalizer==3.4.4 + # via 
requests +click==8.3.1 + # via + # black + # litellm + # typer-slim +contourpy==1.3.3 + # via matplotlib +cryptography==46.0.4 + # via google-auth +cycler==0.12.1 + # via matplotlib +distlib==0.4.0 + # via virtualenv +distro==1.9.0 + # via + # google-genai + # openai +docstring-parser==0.17.0 + # via google-cloud-aiplatform +fastuuid==0.14.0 + # via litellm +filelock==3.20.3 + # via + # huggingface-hub + # virtualenv +flake8==7.3.0 + # via ape (pyproject.toml) +fonttools==4.61.1 + # via matplotlib +frozenlist==1.8.0 + # via + # aiohttp + # aiosignal +fsspec==2026.1.0 + # via huggingface-hub +google-api-core==2.29.0 + # via + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage +google-auth==2.48.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage + # google-genai +google-cloud-aiplatform==1.135.0 + # via ape (pyproject.toml) +google-cloud-bigquery==3.40.0 + # via google-cloud-aiplatform +google-cloud-core==2.5.0 + # via + # google-cloud-bigquery + # google-cloud-storage +google-cloud-resource-manager==1.16.0 + # via google-cloud-aiplatform +google-cloud-storage==3.8.0 + # via google-cloud-aiplatform +google-crc32c==1.8.0 + # via + # google-cloud-storage + # google-resumable-media +google-genai==1.60.0 + # via google-cloud-aiplatform +google-resumable-media==2.8.0 + # via + # google-cloud-bigquery + # google-cloud-storage +googleapis-common-protos==1.72.0 + # via + # google-api-core + # grpc-google-iam-v1 + # grpcio-status +grpc-google-iam-v1==0.14.3 + # via google-cloud-resource-manager +grpcio==1.76.0 + # via + # google-api-core + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status +grpcio-status==1.76.0 + # via google-api-core +h11==0.16.0 + # via httpcore +hf-xet==1.2.0 + # via huggingface-hub +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via + # google-genai + # huggingface-hub + # litellm + # openai +huggingface-hub==1.3.5 + # via tokenizers +hydra-core==1.3.2 + # via ape (pyproject.toml) +identify==2.6.16 + # via pre-commit +idna==3.11 + # via + # anyio + # httpx + # requests + # yarl +importlib-metadata==8.7.1 + # via litellm +iniconfig==2.3.0 + # via pytest +isort==7.0.0 + # via ape (pyproject.toml) +jinja2==3.1.6 + # via litellm +jiter==0.12.0 + # via openai +joblib==1.5.3 + # via scikit-learn +jsonschema==4.26.0 + # via litellm +jsonschema-specifications==2025.9.1 + # via jsonschema +kiwisolver==1.4.9 + # via matplotlib +librt==0.7.8 + # via mypy +litellm==1.81.5 + # via ape (pyproject.toml) +markupsafe==3.0.3 + # via jinja2 +matplotlib==3.10.8 + # via + # ape (pyproject.toml) + # seaborn +mccabe==0.7.0 + # via flake8 +multidict==6.7.1 + # via + # aiohttp + # yarl +mypy==1.19.1 + # via ape (pyproject.toml) +mypy-extensions==1.1.0 + # via + # black + # mypy +nodeenv==1.10.0 + # via pre-commit +numpy==2.4.1 + # via + # contourpy + # matplotlib + # pandas + # pandas-stubs + # scikit-learn + # scipy + # seaborn +omegaconf==2.3.0 + # via hydra-core +openai==2.16.0 + # via + # ape (pyproject.toml) + # litellm +packaging==26.0 + # via + # black + # google-cloud-aiplatform + # google-cloud-bigquery + # huggingface-hub + # hydra-core + # matplotlib + # pytest +pandas==3.0.0 + # via + # ape (pyproject.toml) + # seaborn +pandas-stubs==2.3.3.260113 + # via ape (pyproject.toml) +pathspec==1.0.4 + # via + # black + # mypy +pillow==12.1.0 + # via 
matplotlib +platformdirs==4.5.1 + # via + # black + # virtualenv +pluggy==1.6.0 + # via pytest +pre-commit==4.5.1 + # via ape (pyproject.toml) +propcache==0.4.1 + # via + # aiohttp + # yarl +proto-plus==1.27.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager +protobuf==6.33.5 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status + # proto-plus +pyasn1==0.6.2 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pycodestyle==2.14.0 + # via flake8 +pycparser==3.0 + # via cffi +pydantic==2.12.5 + # via + # google-cloud-aiplatform + # google-genai + # litellm + # openai +pydantic-core==2.41.5 + # via pydantic +pyflakes==3.4.0 + # via flake8 +pygments==2.19.2 + # via pytest +pyparsing==3.3.2 + # via matplotlib +pytest==9.0.2 + # via + # ape (pyproject.toml) + # pytest-asyncio +pytest-asyncio==1.3.0 + # via ape (pyproject.toml) +python-dateutil==2.9.0.post0 + # via + # google-cloud-bigquery + # matplotlib + # pandas +python-dotenv==1.2.1 + # via + # ape (pyproject.toml) + # litellm +pytokens==0.4.0 + # via black +pytz==2025.2 + # via ape (pyproject.toml) +pyyaml==6.0.3 + # via + # ape (pyproject.toml) + # huggingface-hub + # omegaconf + # pre-commit +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications +regex==2026.1.15 + # via tiktoken +requests==2.32.5 + # via + # google-api-core + # google-auth + # google-cloud-bigquery + # google-cloud-storage + # google-genai + # tiktoken +rpds-py==0.30.0 + # via + # jsonschema + # referencing +rsa==4.9.1 + # via google-auth +scikit-learn==1.8.0 + # via ape (pyproject.toml) +scipy==1.17.0 + # via + # ape (pyproject.toml) + # scikit-learn +seaborn==0.13.2 + # via ape (pyproject.toml) +shellingham==1.5.4 + # via huggingface-hub +six==1.17.0 + # via python-dateutil +sniffio==1.3.1 + # via + # google-genai + # openai +tabulate==0.9.0 + # via ape (pyproject.toml) +tenacity==9.1.2 + # via google-genai +threadpoolctl==3.6.0 + # via scikit-learn +tiktoken==0.12.0 + # via litellm +tokenizers==0.22.2 + # via litellm +tqdm==4.67.1 + # via + # huggingface-hub + # openai +typer-slim==0.21.1 + # via huggingface-hub +types-pytz==2025.2.0.20251108 + # via + # ape (pyproject.toml) + # pandas-stubs +types-pyyaml==6.0.12.20250915 + # via ape (pyproject.toml) +types-requests==2.32.4.20260107 + # via types-tqdm +types-tabulate==0.9.0.20241207 + # via ape (pyproject.toml) +types-tqdm==4.67.0.20250809 + # via ape (pyproject.toml) +typing-extensions==4.15.0 + # via + # aiosignal + # anyio + # google-cloud-aiplatform + # google-genai + # grpcio + # huggingface-hub + # mypy + # openai + # pydantic + # pydantic-core + # pytest-asyncio + # referencing + # typer-slim + # typing-inspection +typing-inspection==0.4.2 + # via pydantic +urllib3==2.6.3 + # via + # requests + # types-requests +virtualenv==20.36.1 + # via pre-commit +websockets==15.0.1 + # via google-genai +yarl==1.22.0 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata diff --git a/reqs.txt b/reqs.txt new file mode 100644 index 0000000..a92d5bf --- /dev/null +++ b/reqs.txt @@ -0,0 +1,319 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile pyproject.toml -o reqs.txt +-e . 
+aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via litellm +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +anyio==4.12.1 + # via + # google-genai + # httpx + # openai +attrs==25.4.0 + # via + # aiohttp + # jsonschema + # referencing +certifi==2026.1.4 + # via + # httpcore + # httpx + # requests +cffi==2.0.0 + # via cryptography +charset-normalizer==3.4.4 + # via requests +click==8.3.1 + # via + # litellm + # typer-slim +contourpy==1.3.3 + # via matplotlib +cryptography==46.0.4 + # via google-auth +cycler==0.12.1 + # via matplotlib +distro==1.9.0 + # via + # google-genai + # openai +docstring-parser==0.17.0 + # via google-cloud-aiplatform +fastuuid==0.14.0 + # via litellm +filelock==3.20.3 + # via huggingface-hub +fonttools==4.61.1 + # via matplotlib +frozenlist==1.8.0 + # via + # aiohttp + # aiosignal +fsspec==2026.1.0 + # via huggingface-hub +google-api-core==2.29.0 + # via + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage +google-auth==2.48.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage + # google-genai +google-cloud-aiplatform==1.135.0 + # via ape (pyproject.toml) +google-cloud-bigquery==3.40.0 + # via google-cloud-aiplatform +google-cloud-core==2.5.0 + # via + # google-cloud-bigquery + # google-cloud-storage +google-cloud-resource-manager==1.16.0 + # via google-cloud-aiplatform +google-cloud-storage==3.8.0 + # via google-cloud-aiplatform +google-crc32c==1.8.0 + # via + # google-cloud-storage + # google-resumable-media +google-genai==1.60.0 + # via google-cloud-aiplatform +google-resumable-media==2.8.0 + # via + # google-cloud-bigquery + # google-cloud-storage +googleapis-common-protos==1.72.0 + # via + # google-api-core + # grpc-google-iam-v1 + # grpcio-status +grpc-google-iam-v1==0.14.3 + # via google-cloud-resource-manager +grpcio==1.76.0 + # via + # google-api-core + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status +grpcio-status==1.76.0 + # via google-api-core +h11==0.16.0 + # via httpcore +hf-xet==1.2.0 + # via huggingface-hub +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via + # google-genai + # huggingface-hub + # litellm + # openai +huggingface-hub==1.3.4 + # via tokenizers +hydra-core==1.3.2 + # via ape (pyproject.toml) +idna==3.11 + # via + # anyio + # httpx + # requests + # yarl +importlib-metadata==8.7.1 + # via litellm +jinja2==3.1.6 + # via litellm +jiter==0.12.0 + # via openai +joblib==1.5.3 + # via scikit-learn +jsonschema==4.26.0 + # via litellm +jsonschema-specifications==2025.9.1 + # via jsonschema +kiwisolver==1.4.9 + # via matplotlib +litellm==1.81.5 + # via ape (pyproject.toml) +markupsafe==3.0.3 + # via jinja2 +matplotlib==3.10.8 + # via + # ape (pyproject.toml) + # seaborn +multidict==6.7.1 + # via + # aiohttp + # yarl +numpy==2.4.1 + # via + # contourpy + # matplotlib + # pandas + # scikit-learn + # scipy + # seaborn +omegaconf==2.3.0 + # via hydra-core +openai==2.16.0 + # via + # ape (pyproject.toml) + # litellm +packaging==26.0 + # via + # google-cloud-aiplatform + # google-cloud-bigquery + # huggingface-hub + # hydra-core + # matplotlib +pandas==3.0.0 + # via + # ape (pyproject.toml) + # seaborn +pillow==12.1.0 + # via matplotlib +propcache==0.4.1 + # via + # aiohttp + # yarl 
+proto-plus==1.27.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager +protobuf==6.33.4 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status + # proto-plus +pyasn1==0.6.2 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pycparser==3.0 + # via cffi +pydantic==2.12.5 + # via + # google-cloud-aiplatform + # google-genai + # litellm + # openai +pydantic-core==2.41.5 + # via pydantic +pyparsing==3.3.2 + # via matplotlib +python-dateutil==2.9.0.post0 + # via + # google-cloud-bigquery + # matplotlib + # pandas +python-dotenv==1.2.1 + # via + # ape (pyproject.toml) + # litellm +pytz==2025.2 + # via ape (pyproject.toml) +pyyaml==6.0.3 + # via + # ape (pyproject.toml) + # huggingface-hub + # omegaconf +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications +regex==2026.1.15 + # via tiktoken +requests==2.32.5 + # via + # google-api-core + # google-auth + # google-cloud-bigquery + # google-cloud-storage + # google-genai + # tiktoken +rpds-py==0.30.0 + # via + # jsonschema + # referencing +rsa==4.9.1 + # via google-auth +scikit-learn==1.8.0 + # via ape (pyproject.toml) +scipy==1.17.0 + # via + # ape (pyproject.toml) + # scikit-learn +seaborn==0.13.2 + # via ape (pyproject.toml) +shellingham==1.5.4 + # via huggingface-hub +six==1.17.0 + # via python-dateutil +sniffio==1.3.1 + # via + # google-genai + # openai +tabulate==0.9.0 + # via ape (pyproject.toml) +tenacity==9.1.2 + # via google-genai +threadpoolctl==3.6.0 + # via scikit-learn +tiktoken==0.12.0 + # via litellm +tokenizers==0.22.2 + # via litellm +tqdm==4.67.1 + # via + # huggingface-hub + # openai +typer-slim==0.21.1 + # via huggingface-hub +typing-extensions==4.15.0 + # via + # aiosignal + # anyio + # google-cloud-aiplatform + # google-genai + # grpcio + # huggingface-hub + # openai + # pydantic + # pydantic-core + # referencing + # typer-slim + # typing-inspection +typing-inspection==0.4.2 + # via pydantic +urllib3==2.6.3 + # via requests +websockets==15.0.1 + # via google-genai +yarl==1.22.0 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata diff --git a/src/generate_conversations/generate.py b/src/generate_conversations/generate.py index 950b00f..06698a0 100644 --- a/src/generate_conversations/generate.py +++ b/src/generate_conversations/generate.py @@ -7,9 +7,7 @@ import time import logging from tqdm import tqdm -import torch from litellm import acompletion -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig import litellm # Add import for litellm to access RateLimitError from typing import Any, Dict, List import openai # Add OpenAI import for direct API calls @@ -39,6 +37,21 @@ def __init__(self, finish_reason: str, diagnostic_info: dict = None): MODEL_CACHE = {} +def _require_local_llm_deps(): + """ + Lazily import local-model dependencies to avoid requiring them for API-only runs. + """ + try: + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "Local model dependencies are not installed. Install with " + "`pip install -r reqs-all.txt` or `pip install -e \".[local]\"`." 
+ ) from exc + return torch, AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig + + @@ -103,6 +116,9 @@ def preload_local_model(model_name): return True try: + torch, AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig = ( + _require_local_llm_deps() + ) local_path = f"src/ckpts/{model_name.split('/')[-1]}" print(f"Preloading model {model_name}...") @@ -343,6 +359,7 @@ def generate_with_local_model( Returns: List of generated responses """ + torch, AutoModelForCausalLM, AutoTokenizer, pipeline, _ = _require_local_llm_deps() local_path = f"src/ckpts/{model.split('/')[-1]}" # Check if generator pipeline is already loaded in cache @@ -600,7 +617,7 @@ async def process_messages(messages): # For other finish reasons, handle as before diag_str = f" ({', '.join(f'{k}: {v}' for k, v in diagnostic_info.items())})" if diagnostic_info else "" - if diagnostic_info.get('finish_reason') == 'refusal': + if diagnostic_info.get('finish_reason') in ("refusal", "content_filter"): return "[FINISH_REASON] Sorry I cannot assist with that request." else: print(f"Warning: {model} returned null content{diag_str}") From fc0a2c1e5eae1f38d56962db15b8e12d8dab772a Mon Sep 17 00:00:00 2001 From: Jasper Timm Date: Fri, 30 Jan 2026 01:58:24 +0000 Subject: [PATCH 3/3] Cleanup new generate handling --- src/generate_conversations/generate.py | 408 ++++++++++++------------- 1 file changed, 199 insertions(+), 209 deletions(-) diff --git a/src/generate_conversations/generate.py b/src/generate_conversations/generate.py index 06698a0..08f975f 100644 --- a/src/generate_conversations/generate.py +++ b/src/generate_conversations/generate.py @@ -3,91 +3,167 @@ """ import asyncio -import re -import time import logging -from tqdm import tqdm -from litellm import acompletion -import litellm # Add import for litellm to access RateLimitError +import re from typing import Any, Dict, List -import openai # Add OpenAI import for direct API calls + +import litellm # Add import for litellm to access RateLimitError +from litellm import acompletion +from tqdm import tqdm logging.getLogger("LiteLLM").setLevel(logging.WARNING) -logging.getLogger("litellm").setLevel(logging.WARNING) -if hasattr(litellm, "set_verbose") and callable(litellm.set_verbose): - litellm.set_verbose(False) -if hasattr(litellm, "suppress_debug_info"): - litellm.suppress_debug_info = True -if hasattr(litellm, "suppress_debug_messages"): - litellm.suppress_debug_messages = True +logging.getLogger("httpx").setLevel(logging.WARNING) class NullContentRetryableError(Exception): """Exception raised when model returns null content with finish_reason 'stop' that should be retried.""" + def __init__(self, finish_reason: str, diagnostic_info: dict = None): self.finish_reason = finish_reason self.diagnostic_info = diagnostic_info or {} super().__init__(f"Null content with finish_reason: {finish_reason}") -# Global variables for tracking reasoning tokens -TOTAL_REASONING_TOKENS = 0 -REASONING_RESPONSES_COUNT = 0 # Cache for storing loaded pipeline generators MODEL_CACHE = {} +def _apply_model_overrides( + model: str, + base_args: Dict[str, Any], + reasoning_effort: str | None, + vertex_thinking_mode: bool, +) -> Dict[str, Any]: + """ + Apply model-specific overrides to completion args. 
+ """ + temperature_1_0_models = { + "o4-mini", + "o3", + "gpt-5-mini", + "gpt-5", + "gpt-5.1", + "vertex_ai/gemini-3-pro-preview", + "gemini/gemini-3-pro-preview", + } + temperature_1_0_models_with_reasoning = {"anthropic/claude-opus-4-5-20251101"} + + if model in temperature_1_0_models: + base_args["temperature"] = 1.0 + if model in temperature_1_0_models_with_reasoning and reasoning_effort: + base_args["temperature"] = 1.0 + + # Gemini 3 Pro/Flash Preview uses global location + if model in {"vertex_ai/gemini-3-pro-preview", "vertex_ai/gemini-3-flash-preview"}: + base_args["vertex_location"] = "global" + + if model.startswith("vertex_ai/") and vertex_thinking_mode: + base_args["thinking"] = {"type": "enabled"} + base_args["reasoning_effort"] = "medium" + + # This is for fine tuned models on Vertex AI + if ( + model.startswith("vertex_ai/") + and len(model.split("/")) > 1 + and model.split("/")[1].isdigit() + ): + base_args["base_model"] = "vertex_ai/gemini-1.5-flash-latest" + + return base_args + + +async def _retry_with_backoff( + func, + max_attempts: int = 5, + initial_sleep_time: float = 1.0, + backoff_factor: float = 1.5, +): + """ + Retry an async function with exponential backoff for selected errors. + """ + sleep_attempts = 0 + current_sleep_time = initial_sleep_time + while sleep_attempts < max_attempts: + try: + return await func() + except ( + litellm.RateLimitError, + litellm.BadGatewayError, + NullContentRetryableError, + ) as e: + if isinstance(e, NullContentRetryableError): + error_type = "Null content (stop)" + diag_str = ( + f" ({', '.join(f'{k}: {v}' for k, v in e.diagnostic_info.items())})" + if e.diagnostic_info + else "" + ) + error_msg = f"{error_type} for {e.finish_reason}{diag_str}" + elif isinstance(e, litellm.RateLimitError): + error_msg = f"Rate limit: {e}" + else: + error_msg = f"Server error (5xx): {e}" + + print(f"{error_msg}. Attempt {sleep_attempts + 1}/{max_attempts}.") + sleep_attempts += 1 + if sleep_attempts >= max_attempts: + raise + print(f"Waiting for {current_sleep_time} seconds before retrying...") + await asyncio.sleep(current_sleep_time) + current_sleep_time *= backoff_factor + raise RuntimeError("Retry loop exited unexpectedly") + + def _require_local_llm_deps(): """ Lazily import local-model dependencies to avoid requiring them for API-only runs. """ try: import torch - from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + pipeline, + ) except ModuleNotFoundError as exc: raise ModuleNotFoundError( "Local model dependencies are not installed. Install with " - "`pip install -r reqs-all.txt` or `pip install -e \".[local]\"`." + '`pip install -r reqs-all.txt` or `pip install -e ".[local]"`.' ) from exc return torch, AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig - - - def parse_gptoss_response(generated_text: str) -> str: """ Parse gpt-oss response to extract the final response from harmony format. 
- + Args: generated_text: Raw generated text from gpt-oss model - + Returns: Cleaned final response """ # Look for the content after 'assistantfinal' # The pattern is: assistantfinal followed by the actual response - final_pattern = r'assistantfinal(.*?)(?:<\|end\|>|$)' + final_pattern = r"assistantfinal(.*?)(?:<\|end\|>|$)" final_match = re.search(final_pattern, generated_text, re.DOTALL) - + if final_match: final_response = final_match.group(1).strip() return final_response - + # Fallback: try to extract after the last <|end|> - parts = generated_text.split('<|end|>') + parts = generated_text.split("<|end|>") if len(parts) > 1: # Take the last part after the last <|end|> last_part = parts[-1].strip() return last_part - + # Fallback: return the original text return generated_text - - - def postprocess_message(message): """ Postprocess the message to remove jailbreak-tuning artifacts. @@ -132,12 +208,12 @@ def preload_local_model(model_name): tokenizer.pad_token = tokenizer.eos_token else: # Add a pad token if there's no eos token to use - tokenizer.add_special_tokens({'pad_token': '[PAD]'}) - + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + # Special handling for gpt-oss models if "gpt-oss" in model_name.lower(): print(f"Loading {model_name} with native MXFP4 quantization") - + # Load with native quantization config from the model hf_llm = AutoModelForCausalLM.from_pretrained( local_path, @@ -250,27 +326,30 @@ def format_prompt_for_model( """ # Handle Qwen models with thinking disabled if is_qwen_model(model): - return tokenizer.apply_chat_template( + prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, # Disable thinking mode for Qwen ) + return str(prompt) elif "gpt-oss" in model.lower(): # For gpt-oss models, set reasoning_effort to low - return tokenizer.apply_chat_template( + prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, # reasoning_effort="low", ) + return str(prompt) else: # Default formatting for all other models - return tokenizer.apply_chat_template( + prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) + return str(prompt) def get_generation_params(model: str, temperature: float) -> Dict[str, Any]: @@ -305,42 +384,6 @@ def get_generation_params(model: str, temperature: float) -> Dict[str, Any]: return params -async def generate_with_openai_client( - message_collection: List[List[Dict[str, str]]], - temperature: float = 0.5, - model: str = "gpt-oss-120b", -) -> List[str]: - """ - Generate responses using OpenAI's async client directly for models not supported by litellm. 
- - Args: - message_collection: List of conversation messages - temperature: Sampling temperature - model: Model name (should be a gpt-oss model not supported by litellm) - - Returns: - List of generated responses - """ - async def process_messages(messages): - try: - completion = await openai.AsyncOpenAI().chat.completions.create( - model=model, - messages=messages, - temperature=temperature, - max_tokens=2048, - ) - return completion.choices[0].message.content - except Exception as e: - print(f"Error processing prompt for {model}: {e}") - return f"Error processing prompt: {e}" - - # Process all messages concurrently - tasks = [process_messages(message_list) for message_list in message_collection] - responses = await asyncio.gather(*tasks) - - return responses - - def generate_with_local_model( message_collection: List[List[Dict[str, str]]], model: str, @@ -380,7 +423,7 @@ def generate_with_local_model( # Special handling for gpt-oss models if "gpt-oss" in model.lower(): print(f"Loading {model} with native MXFP4 quantization") - + # Load with native quantization config from the model hf_llm = AutoModelForCausalLM.from_pretrained( local_path, @@ -463,12 +506,11 @@ def generate_with_local_model( response = clean_qwen_response(generated_text) else: # For Llama models - response = generated_text.split("<|end_header_id|>\n\n")[-1] + response = generated_text.split("<|end_header_id|>\n\n")[-1] all_responses.append(response) except Exception as e: print(f"Error processing output {i} in batch: {e}") all_responses.append(f"Error processing response: {e}") - except Exception as e: # If batch processing fails, fall back to individual processing @@ -496,13 +538,11 @@ def generate_with_local_model( else: # For Llama models response = generated_text.split("<|end_header_id|>\n\n")[-1] - + all_responses.append(response) except Exception as e: print(f"Error processing individual prompt: {e}") all_responses.append(f"Error processing prompt: {e}") - - return all_responses @@ -531,12 +571,6 @@ async def generate_llm( Returns: List of generated responses """ - # List of models that need a temperature of 1.0 - temperature_1_0_models = ["o4-mini", "o3", "gpt-5-mini", "gpt-5", "gpt-5.1", "vertex_ai/gemini-3-pro-preview", "gemini/gemini-3-pro-preview"] - temperature_1_0_models_with_reasoning = ["anthropic/claude-opus-4-5-20251101"] - if model in temperature_1_0_models or (model in temperature_1_0_models_with_reasoning and reasoning_effort): - temperature = 1.0 - # Handle local models synchronously with batching if model.startswith("hf/"): responses = generate_with_local_model( @@ -545,145 +579,101 @@ async def generate_llm( temperature=temperature, batch_size=batch_size, ) - # Handle gpt-oss models with OpenAI async client directly (not supported by litellm) - elif model.startswith("gpt-oss"): - responses = await generate_with_openai_client( - message_collection=message_collection, - temperature=temperature, - model=model, - ) # Handle cloud API models else: async def process_messages(messages): - # Exponential backoff parameters for rate limiting - max_attempts = 5 - initial_sleep_time = 1 # seconds - backoff_factor = 1.5 - - # Build completion arguments (model-specific setup) completion_args = { "model": model, "messages": messages, "temperature": temperature, } - + if reasoning_effort: completion_args["reasoning_effort"] = reasoning_effort - # Gemini 3 Pro Preview uses global location - if model == "vertex_ai/gemini-3-pro-preview" or model == "vertex_ai/gemini-3-flash-preview": - 
completion_args["vertex_location"] = "global" - - if model.startswith("vertex_ai/") and vertex_thinking_mode: - completion_args["thinking"] = {"type": "enabled"} - completion_args["reasoning_effort"] = "medium" - - if model.startswith("vertex_ai/") and len(model.split('/')) > 1 and model.split("/")[1].isdigit(): - completion_args["base_model"] = "vertex_ai/gemini-1.5-flash-latest" - - # Retry loop with exponential backoff for rate limiting - sleep_attempts = 0 - current_sleep_time = initial_sleep_time - while sleep_attempts < max_attempts: - try: - completion = await acompletion(**completion_args) - choice = completion.choices[0] - response_content = choice.message.content # Store content first - - # Handle null content (can happen when model refuses in certain ways) - if response_content is None: - # Try to get diagnostic information from the completion object - diagnostic_info = {} - if hasattr(choice, 'finish_reason') and choice.finish_reason: - diagnostic_info['finish_reason'] = choice.finish_reason - if hasattr(choice, 'stop_reason') and choice.stop_reason: - diagnostic_info['stop_reason'] = choice.stop_reason - if hasattr(choice.message, 'refusal') and choice.message.refusal: - diagnostic_info['refusal'] = choice.message.refusal - if hasattr(completion, 'model_extra') and completion.model_extra: - diagnostic_info['model_extra_keys'] = list(completion.model_extra.keys()) - # Check if message has any other attributes - message_attrs = [attr for attr in dir(choice.message) if not attr.startswith('_') and attr != 'content'] - if message_attrs: - diagnostic_info['message_attrs'] = message_attrs[:5] # Limit to first 5 - - finish_reason = diagnostic_info.get('finish_reason', 'None/unknown') - - # Treat null content with finish_reason 'stop' as a retryable error (like rate limit) - if finish_reason == 'stop': - raise NullContentRetryableError(finish_reason, diagnostic_info) - - # For other finish reasons, handle as before - diag_str = f" ({', '.join(f'{k}: {v}' for k, v in diagnostic_info.items())})" if diagnostic_info else "" - - if diagnostic_info.get('finish_reason') in ("refusal", "content_filter"): - return "[FINISH_REASON] Sorry I cannot assist with that request." 
- else: - print(f"Warning: {model} returned null content{diag_str}") - return f"[ERROR] Model returned null content, finishing reason: {finish_reason}" - - # print the reasoning content for vertex ai models for debugging - # if model.startswith("vertex_ai/") and vertex_thinking_mode: - # try: - # print(completion.choices[0].model_extra['message'].reasoning_content) - # except Exception as e: - # print(f"Could not display thinking process: {e}") - - if model == "o4-mini" or model == "o3": - global TOTAL_REASONING_TOKENS, REASONING_RESPONSES_COUNT - reasoning_tokens = None - try: - reasoning_tokens = completion.model_extra['usage'].completion_tokens_details.reasoning_tokens - TOTAL_REASONING_TOKENS += reasoning_tokens - REASONING_RESPONSES_COUNT += 1 - except Exception as e: - print(f"Could not extract or process reasoning tokens for o4-mini: {e}") - - return response_content # Return the stored content - except (litellm.RateLimitError, litellm.BadGatewayError, NullContentRetryableError) as e: # Retry with backoff for rate limits, server errors, and null content with stop - if isinstance(e, NullContentRetryableError): - error_type = "Null content (stop)" - diag_str = f" ({', '.join(f'{k}: {v}' for k, v in e.diagnostic_info.items())})" if e.diagnostic_info else "" - error_msg = f"{error_type} for {model}: finish_reason={e.finish_reason}{diag_str}" - elif isinstance(e, litellm.RateLimitError): - error_type = "Rate limit" - error_msg = f"{error_type} for {model}: {e}" - else: - error_type = "Server error (5xx)" - error_msg = f"{error_type} for {model}: {e}" - - print(f"{error_msg}. Attempt {sleep_attempts + 1}/{max_attempts}.") - sleep_attempts += 1 - if sleep_attempts >= max_attempts: - print(f"Max attempts reached for {model}. Error: {e}") - return f"[ERROR] Error processing prompt after multiple retries: {e}" - print(f"Waiting for {current_sleep_time} seconds before retrying...") - await asyncio.sleep(current_sleep_time) # Use asyncio.sleep for async functions - current_sleep_time *= backoff_factor - except litellm.BadRequestError as e: - print(f"Bad request error for {model}, using fallback refusal response.: {e}.") - return f"[BAD_REQUEST] Sorry I cannot assist with that request." - except Exception as e: - print(f"Non-retryable error processing prompt for {model}: {e}") - return f"[ERROR] Error processing prompt: {e}" - return f"[ERROR] Max retries exceeded for {model}." 
- - # Orchestrate calls to process_messages - # if "claude" in model.lower() or model.startswith("anthropic/"): - # responses = [] - # for message_list in tqdm(message_collection, desc=f"Processing prompts for {model} sequentially"): - # response = await process_messages(message_list) - # responses.append(response) - # else: # For other (non-Anthropic) cloud models, use concurrency - tasks = [process_messages(message_list) for message_list in tqdm(message_collection, desc=f"Processing prompts for {model} concurrently")] + completion_args = _apply_model_overrides( + model=model, + base_args=completion_args, + reasoning_effort=reasoning_effort, + vertex_thinking_mode=vertex_thinking_mode, + ) + + async def _request(): + completion = await acompletion(**completion_args) + choice = completion.choices[0] + response_content = choice.message.content + + if response_content is None: + diagnostic_info = {} + if hasattr(choice, "finish_reason") and choice.finish_reason: + diagnostic_info["finish_reason"] = choice.finish_reason + if hasattr(choice, "stop_reason") and choice.stop_reason: + diagnostic_info["stop_reason"] = choice.stop_reason + if hasattr(choice.message, "refusal") and choice.message.refusal: + diagnostic_info["refusal"] = choice.message.refusal + if hasattr(completion, "model_extra") and completion.model_extra: + diagnostic_info["model_extra_keys"] = list( + completion.model_extra.keys() + ) + message_attrs = [ + attr + for attr in dir(choice.message) + if not attr.startswith("_") and attr != "content" + ] + if message_attrs: + diagnostic_info["message_attrs"] = message_attrs[:5] + + finish_reason = diagnostic_info.get("finish_reason", "None/unknown") + + if finish_reason == "stop": + raise NullContentRetryableError(finish_reason, diagnostic_info) + + diag_str = ( + f" ({', '.join(f'{k}: {v}' for k, v in diagnostic_info.items())})" + if diagnostic_info + else "" + ) + + if diagnostic_info.get("finish_reason") in ( + "refusal", + "content_filter", + ): + return ( + "[FINISH_REASON] Sorry I cannot assist with that request." + ) + print(f"Warning: {model} returned null content{diag_str}") + return f"[ERROR] Model returned null content, finishing reason: {finish_reason}" + + return response_content + + try: + return await _retry_with_backoff(_request) + except ( + litellm.RateLimitError, + litellm.BadGatewayError, + NullContentRetryableError, + ) as e: + print(f"Max attempts reached for {model}. Error: {e}") + return f"[ERROR] Error processing prompt after multiple retries: {e}" + except litellm.BadRequestError as e: + print( + f"Bad request error for {model}, using fallback refusal response.: {e}." + ) + return "[BAD_REQUEST] Sorry I cannot assist with that request." + except Exception as e: + print(f"Non-retryable error processing prompt for {model}: {e}") + return f"[ERROR] Error processing prompt: {e}" + + tasks = [ + process_messages(message_list) + for message_list in tqdm( + message_collection, desc=f"Processing prompts for {model} concurrently" + ) + ] responses = await asyncio.gather(*tasks) if postprocess_responses: responses = [postprocess_message(response) for response in responses] - if model == "o4-mini" and REASONING_RESPONSES_COUNT > 0: - average_reasoning_tokens = TOTAL_REASONING_TOKENS / REASONING_RESPONSES_COUNT - print(f"Average reasoning tokens per o4-mini response: {average_reasoning_tokens:.2f} (Total: {TOTAL_REASONING_TOKENS}, Count: {REASONING_RESPONSES_COUNT})") - return responses
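
A minimal usage sketch of how the new helpers compose, assuming it runs alongside them in src/generate_conversations/generate.py; the demo function, the sample prompt, and max_attempts=3 are illustrative, while the helper names and keyword signatures are the ones introduced above.

    import asyncio

    from litellm import acompletion


    async def demo_single_call():
        # Build the base request, then let _apply_model_overrides add the
        # model-specific settings (for vertex_ai/gemini-3-pro-preview this
        # forces temperature=1.0 and vertex_location="global").
        completion_args = {
            "model": "vertex_ai/gemini-3-pro-preview",
            "messages": [{"role": "user", "content": "Say hello."}],  # illustrative prompt
            "temperature": 0.5,
        }
        completion_args = _apply_model_overrides(
            model=completion_args["model"],
            base_args=completion_args,
            # generate_llm sets completion_args["reasoning_effort"] itself before
            # this call; here the argument only matters for the Opus temperature rule.
            reasoning_effort=None,
            vertex_thinking_mode=False,
        )

        async def _request():
            completion = await acompletion(**completion_args)
            return completion.choices[0].message.content

        # _retry_with_backoff retries RateLimitError, BadGatewayError and
        # NullContentRetryableError with exponential backoff, then re-raises
        # once the attempts are exhausted.
        return await _retry_with_backoff(_request, max_attempts=3)


    if __name__ == "__main__":
        print(asyncio.run(demo_single_call()))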
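
For the harmony-format parsing above (only the quoting changed), a small illustrative check; the raw string is made up:

    # parse_gptoss_response returns the text between the "assistantfinal" marker
    # and the following <|end|> token (or end of string), stripped of whitespace.
    raw = "assistantanalysisThinking about the request...assistantfinalSure, here is a short greeting.<|end|>"
    assert parse_gptoss_response(raw) == "Sure, here is a short greeting."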