diff --git a/README.md b/README.md index d81a6ab..2eff9ae 100644 --- a/README.md +++ b/README.md @@ -69,10 +69,23 @@ huggingface-cli download Qwen/Qwen3-32B-Instruct --local-dir src/ckpts/Qwen3-32B ### Install dependencies -Dependencies are in `pyproject.toml`, install them with: +Locked dependency versions are recorded in `reqs.txt` (generated from `pyproject.toml`). +Install them with: ```bash -pip install -e ".[dev,test]" +pip install -r reqs.txt +``` + +For development and testing you can use: + +```bash +pip install -r reqs-dev.txt +``` + +And for all dependencies including those required for running local models: + +```bash +pip install -r reqs-all.txt ``` ### Execute evals diff --git a/configs/config.yaml b/configs/config.yaml index b13ab3f..213e139 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -18,6 +18,7 @@ persuader_model: gpt-4o persuadee_model: gpt-4o evaluator_model: gpt-4o reject_model: gpt-4o-mini +reasoning_effort_persuader: null # Jailbreaking options jailbreak_persuader: false diff --git a/configs/experiment/gemini_3_pro_high.yaml b/configs/experiment/gemini_3_pro_high.yaml new file mode 100644 index 0000000..8b9e721 --- /dev/null +++ b/configs/experiment/gemini_3_pro_high.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Gemini 3 Pro Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: vertex_ai/gemini-3-pro-preview +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: gemini_3_pro_high +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true +assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: high +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/configs/experiment/gemini_3_pro_high_aistudio.yaml b/configs/experiment/gemini_3_pro_high_aistudio.yaml new file mode 100644 index 0000000..8b9e721 --- /dev/null +++ b/configs/experiment/gemini_3_pro_high_aistudio.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Gemini 3 Pro Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: vertex_ai/gemini-3-pro-preview +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: gemini_3_pro_high +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true +assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: high +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/configs/experiment/gemini_3_pro_low.yaml b/configs/experiment/gemini_3_pro_low.yaml new file mode 100644 index 0000000..cbf67f9 --- /dev/null +++ b/configs/experiment/gemini_3_pro_low.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Gemini 3 Pro Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: vertex_ai/gemini-3-pro-preview +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: gemini_3_pro_low +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true +assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: low +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/configs/experiment/gemini_3_pro_low_aistudio.yaml b/configs/experiment/gemini_3_pro_low_aistudio.yaml new file mode 100644 index 0000000..0aab2cf --- /dev/null +++ b/configs/experiment/gemini_3_pro_low_aistudio.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Gemini 3 Pro Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: gemini/gemini-3-pro-preview +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: gemini_3_pro_low +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true +assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: low +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/configs/experiment/opus_45_low.yaml b/configs/experiment/opus_45_low.yaml new file mode 100644 index 0000000..da3bcc8 --- /dev/null +++ b/configs/experiment/opus_45_low.yaml @@ -0,0 +1,26 @@ +# @package _global_ +# Opus 4.5 Evaluation Experiment + +# Override default values for this experiment +num_turns: 3 +persuader_model: anthropic/claude-opus-4-5-20251101 +persuadee_model: gpt-4o +evaluator_model: gpt-4o +reject_model: gpt-4o-mini +experiment_name: opus_45_low +sample_belief_lower: 0 +sample_belief_upper: 20 +only_persuade: true +assistant_prompt: prompt_assistant_no_user_degree +reasoning_effort_persuader: low +all_topics: true + +topic_categories: + - NoncontroversiallyHarmful + +# Hydra configuration +hydra: + run: + dir: ${results_dir}/${now:%Y-%m-%d_%H-%M-%S}/${experiment_name} + job: + chdir: false diff --git a/main.py b/main.py index 4ced630..3d75f69 100644 --- a/main.py +++ b/main.py @@ -90,6 +90,7 @@ def main(cfg: DictConfig) -> None: PERSUADEE_MODEL = cfg.persuadee_model EVALUATOR_MODEL = cfg.evaluator_model REJECT_MODEL = cfg.reject_model + REASONING_EFFORT_PERSUADER = cfg.reasoning_effort_persuader JAILBREAK_TUNING_POSTPROCESS = cfg.jailbreak_persuader JAILBREAK_TUNING_PERSUADEE = cfg.jailbreak_persuadee EXPERIMENT_NAME = cfg.experiment_name @@ -177,6 +178,7 @@ def main(cfg: DictConfig) -> None: "PERSUADEE_MODEL": PERSUADEE_MODEL, "EVALUATOR_MODEL": EVALUATOR_MODEL, "REJECT_MODEL": REJECT_MODEL, + "REASONING_EFFORT_PERSUADER": REASONING_EFFORT_PERSUADER, "JAILBREAK_TUNING_POSTPROCESS": JAILBREAK_TUNING_POSTPROCESS, "RECORD_ALL_SYSTEM": cfg.record_all_system, "REMOVE_RATINGS": cfg.remove_ratings, @@ -470,6 +472,7 @@ def main(cfg: DictConfig) -> None: model=PERSUADER_MODEL, postprocess_responses=JAILBREAK_TUNING_POSTPROCESS, remove_ratings=cfg.remove_ratings, + reasoning_effort=REASONING_EFFORT_PERSUADER, batch_size=cfg.batch_size, ) @@ -660,6 +663,7 @@ def run_refusal_check(conv, system_message, turn): model=PERSUADER_MODEL, postprocess_responses=JAILBREAK_TUNING_POSTPROCESS, remove_ratings=cfg.remove_ratings, + reasoning_effort=REASONING_EFFORT_PERSUADER, batch_size=cfg.batch_size, ) diff --git a/pyproject.toml b/pyproject.toml index e039df8..06b65db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,9 +16,6 @@ dependencies = [ "google-cloud-aiplatform", "seaborn", "scikit-learn", - "transformers", - "torch", - "accelerate", "tabulate", "hydra-core" ] @@ -40,6 +37,11 @@ test = [ "pytest", "pytest-asyncio" ] +local = [ + "transformers", + "torch", + "accelerate" +] [tool.logviz-dependencies] Flask = "^2.3.2" diff --git a/reqs-all.txt b/reqs-all.txt new file mode 100644 index 0000000..da169d9 --- /dev/null +++ b/reqs-all.txt @@ -0,0 +1,487 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile pyproject.toml --extra dev --extra test --extra local -o reqs-all.txt +accelerate==1.12.0 + # via ape (pyproject.toml) +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via litellm +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +anyio==4.12.1 + # via + # google-genai + # httpx + # openai +attrs==25.4.0 + # via + # aiohttp + # jsonschema + # referencing +black==26.1.0 + # via ape (pyproject.toml) +certifi==2026.1.4 + # via + # httpcore + # httpx + # requests +cffi==2.0.0 + # via cryptography +cfgv==3.5.0 + # via pre-commit +charset-normalizer==3.4.4 + # via requests +click==8.3.1 + # via + # black + # litellm + # typer-slim +contourpy==1.3.3 + # via matplotlib +cryptography==46.0.4 + # via google-auth +cuda-bindings==12.9.4 + # via torch +cuda-pathfinder==1.3.3 + # via cuda-bindings +cycler==0.12.1 + # via matplotlib +distlib==0.4.0 + # via virtualenv +distro==1.9.0 + # via + # google-genai + # openai +docstring-parser==0.17.0 + # via google-cloud-aiplatform +fastuuid==0.14.0 + # via litellm +filelock==3.20.3 + # via + # huggingface-hub + # torch + # transformers + # virtualenv +flake8==7.3.0 + # via ape (pyproject.toml) +fonttools==4.61.1 + # via matplotlib +frozenlist==1.8.0 + # via + # aiohttp + # aiosignal +fsspec==2026.1.0 + # via + # huggingface-hub + # torch +google-api-core==2.29.0 + # via + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage +google-auth==2.48.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage + # google-genai +google-cloud-aiplatform==1.135.0 + # via ape (pyproject.toml) +google-cloud-bigquery==3.40.0 + # via google-cloud-aiplatform +google-cloud-core==2.5.0 + # via + # google-cloud-bigquery + # google-cloud-storage +google-cloud-resource-manager==1.16.0 + # via google-cloud-aiplatform +google-cloud-storage==3.8.0 + # via google-cloud-aiplatform +google-crc32c==1.8.0 + # via + # google-cloud-storage + # google-resumable-media +google-genai==1.60.0 + # via google-cloud-aiplatform +google-resumable-media==2.8.0 + # via + # google-cloud-bigquery + # google-cloud-storage +googleapis-common-protos==1.72.0 + # via + # google-api-core + # grpc-google-iam-v1 + # grpcio-status +grpc-google-iam-v1==0.14.3 + # via google-cloud-resource-manager +grpcio==1.76.0 + # via + # google-api-core + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status +grpcio-status==1.76.0 + # via google-api-core +h11==0.16.0 + # via httpcore +hf-xet==1.2.0 + # via huggingface-hub +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via + # google-genai + # huggingface-hub + # litellm + # openai +huggingface-hub==1.3.5 + # via + # accelerate + # tokenizers + # transformers +hydra-core==1.3.2 + # via ape (pyproject.toml) +identify==2.6.16 + # via pre-commit +idna==3.11 + # via + # anyio + # httpx + # requests + # yarl +importlib-metadata==8.7.1 + # via litellm +iniconfig==2.3.0 + # via pytest +isort==7.0.0 + # via ape (pyproject.toml) +jinja2==3.1.6 + # via + # litellm + # torch +jiter==0.12.0 + # via openai +joblib==1.5.3 + # via scikit-learn +jsonschema==4.26.0 + # via litellm +jsonschema-specifications==2025.9.1 + # via jsonschema +kiwisolver==1.4.9 + # via matplotlib +librt==0.7.8 + # via mypy +litellm==1.81.5 + # via ape (pyproject.toml) +markupsafe==3.0.3 + # via jinja2 +matplotlib==3.10.8 + # via + # ape (pyproject.toml) + # seaborn +mccabe==0.7.0 + # via flake8 +mpmath==1.3.0 + # via sympy +multidict==6.7.1 + # via + # aiohttp + # yarl +mypy==1.19.1 + # via ape (pyproject.toml) +mypy-extensions==1.1.0 + # via + # black + # mypy +networkx==3.6.1 + # via torch +nodeenv==1.10.0 + # via pre-commit +numpy==2.4.1 + # via + # accelerate + # contourpy + # matplotlib + # pandas + # pandas-stubs + # scikit-learn + # scipy + # seaborn + # transformers +nvidia-cublas-cu12==12.8.4.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.8.90 + # via torch +nvidia-cuda-nvrtc-cu12==12.8.93 + # via torch +nvidia-cuda-runtime-cu12==12.8.90 + # via torch +nvidia-cudnn-cu12==9.10.2.21 + # via torch +nvidia-cufft-cu12==11.3.3.83 + # via torch +nvidia-cufile-cu12==1.13.1.3 + # via torch +nvidia-curand-cu12==10.3.9.90 + # via torch +nvidia-cusolver-cu12==11.7.3.90 + # via torch +nvidia-cusparse-cu12==12.5.8.93 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.7.1 + # via torch +nvidia-nccl-cu12==2.27.5 + # via torch +nvidia-nvjitlink-cu12==12.8.93 + # via + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvshmem-cu12==3.4.5 + # via torch +nvidia-nvtx-cu12==12.8.90 + # via torch +omegaconf==2.3.0 + # via hydra-core +openai==2.16.0 + # via + # ape (pyproject.toml) + # litellm +packaging==26.0 + # via + # accelerate + # black + # google-cloud-aiplatform + # google-cloud-bigquery + # huggingface-hub + # hydra-core + # matplotlib + # pytest + # transformers +pandas==3.0.0 + # via + # ape (pyproject.toml) + # seaborn +pandas-stubs==2.3.3.260113 + # via ape (pyproject.toml) +pathspec==1.0.4 + # via + # black + # mypy +pillow==12.1.0 + # via matplotlib +platformdirs==4.5.1 + # via + # black + # virtualenv +pluggy==1.6.0 + # via pytest +pre-commit==4.5.1 + # via ape (pyproject.toml) +propcache==0.4.1 + # via + # aiohttp + # yarl +proto-plus==1.27.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager +protobuf==6.33.5 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status + # proto-plus +psutil==7.2.2 + # via accelerate +pyasn1==0.6.2 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pycodestyle==2.14.0 + # via flake8 +pycparser==3.0 + # via cffi +pydantic==2.12.5 + # via + # google-cloud-aiplatform + # google-genai + # litellm + # openai +pydantic-core==2.41.5 + # via pydantic +pyflakes==3.4.0 + # via flake8 +pygments==2.19.2 + # via pytest +pyparsing==3.3.2 + # via matplotlib +pytest==9.0.2 + # via + # ape (pyproject.toml) + # pytest-asyncio +pytest-asyncio==1.3.0 + # via ape (pyproject.toml) +python-dateutil==2.9.0.post0 + # via + # google-cloud-bigquery + # matplotlib + # pandas +python-dotenv==1.2.1 + # via + # ape (pyproject.toml) + # litellm +pytokens==0.4.0 + # via black +pytz==2025.2 + # via ape (pyproject.toml) +pyyaml==6.0.3 + # via + # ape (pyproject.toml) + # accelerate + # huggingface-hub + # omegaconf + # pre-commit + # transformers +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications +regex==2026.1.15 + # via + # tiktoken + # transformers +requests==2.32.5 + # via + # google-api-core + # google-auth + # google-cloud-bigquery + # google-cloud-storage + # google-genai + # tiktoken +rpds-py==0.30.0 + # via + # jsonschema + # referencing +rsa==4.9.1 + # via google-auth +safetensors==0.7.0 + # via + # accelerate + # transformers +scikit-learn==1.8.0 + # via ape (pyproject.toml) +scipy==1.17.0 + # via + # ape (pyproject.toml) + # scikit-learn +seaborn==0.13.2 + # via ape (pyproject.toml) +setuptools==80.10.2 + # via torch +shellingham==1.5.4 + # via huggingface-hub +six==1.17.0 + # via python-dateutil +sniffio==1.3.1 + # via + # google-genai + # openai +sympy==1.14.0 + # via torch +tabulate==0.9.0 + # via ape (pyproject.toml) +tenacity==9.1.2 + # via google-genai +threadpoolctl==3.6.0 + # via scikit-learn +tiktoken==0.12.0 + # via litellm +tokenizers==0.22.2 + # via + # litellm + # transformers +torch==2.10.0 + # via + # ape (pyproject.toml) + # accelerate +tqdm==4.67.1 + # via + # huggingface-hub + # openai + # transformers +transformers==5.0.0 + # via ape (pyproject.toml) +triton==3.6.0 + # via torch +typer-slim==0.21.1 + # via + # huggingface-hub + # transformers +types-pytz==2025.2.0.20251108 + # via + # ape (pyproject.toml) + # pandas-stubs +types-pyyaml==6.0.12.20250915 + # via ape (pyproject.toml) +types-requests==2.32.4.20260107 + # via types-tqdm +types-tabulate==0.9.0.20241207 + # via ape (pyproject.toml) +types-tqdm==4.67.0.20250809 + # via ape (pyproject.toml) +typing-extensions==4.15.0 + # via + # aiosignal + # anyio + # google-cloud-aiplatform + # google-genai + # grpcio + # huggingface-hub + # mypy + # openai + # pydantic + # pydantic-core + # pytest-asyncio + # referencing + # torch + # typer-slim + # typing-inspection +typing-inspection==0.4.2 + # via pydantic +urllib3==2.6.3 + # via + # requests + # types-requests +virtualenv==20.36.1 + # via pre-commit +websockets==15.0.1 + # via google-genai +yarl==1.22.0 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata diff --git a/reqs-dev.txt b/reqs-dev.txt new file mode 100644 index 0000000..5869a21 --- /dev/null +++ b/reqs-dev.txt @@ -0,0 +1,397 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile pyproject.toml --extra dev --extra test -o reqs-dev.txt +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via litellm +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +anyio==4.12.1 + # via + # google-genai + # httpx + # openai +attrs==25.4.0 + # via + # aiohttp + # jsonschema + # referencing +black==26.1.0 + # via ape (pyproject.toml) +certifi==2026.1.4 + # via + # httpcore + # httpx + # requests +cffi==2.0.0 + # via cryptography +cfgv==3.5.0 + # via pre-commit +charset-normalizer==3.4.4 + # via requests +click==8.3.1 + # via + # black + # litellm + # typer-slim +contourpy==1.3.3 + # via matplotlib +cryptography==46.0.4 + # via google-auth +cycler==0.12.1 + # via matplotlib +distlib==0.4.0 + # via virtualenv +distro==1.9.0 + # via + # google-genai + # openai +docstring-parser==0.17.0 + # via google-cloud-aiplatform +fastuuid==0.14.0 + # via litellm +filelock==3.20.3 + # via + # huggingface-hub + # virtualenv +flake8==7.3.0 + # via ape (pyproject.toml) +fonttools==4.61.1 + # via matplotlib +frozenlist==1.8.0 + # via + # aiohttp + # aiosignal +fsspec==2026.1.0 + # via huggingface-hub +google-api-core==2.29.0 + # via + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage +google-auth==2.48.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage + # google-genai +google-cloud-aiplatform==1.135.0 + # via ape (pyproject.toml) +google-cloud-bigquery==3.40.0 + # via google-cloud-aiplatform +google-cloud-core==2.5.0 + # via + # google-cloud-bigquery + # google-cloud-storage +google-cloud-resource-manager==1.16.0 + # via google-cloud-aiplatform +google-cloud-storage==3.8.0 + # via google-cloud-aiplatform +google-crc32c==1.8.0 + # via + # google-cloud-storage + # google-resumable-media +google-genai==1.60.0 + # via google-cloud-aiplatform +google-resumable-media==2.8.0 + # via + # google-cloud-bigquery + # google-cloud-storage +googleapis-common-protos==1.72.0 + # via + # google-api-core + # grpc-google-iam-v1 + # grpcio-status +grpc-google-iam-v1==0.14.3 + # via google-cloud-resource-manager +grpcio==1.76.0 + # via + # google-api-core + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status +grpcio-status==1.76.0 + # via google-api-core +h11==0.16.0 + # via httpcore +hf-xet==1.2.0 + # via huggingface-hub +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via + # google-genai + # huggingface-hub + # litellm + # openai +huggingface-hub==1.3.5 + # via tokenizers +hydra-core==1.3.2 + # via ape (pyproject.toml) +identify==2.6.16 + # via pre-commit +idna==3.11 + # via + # anyio + # httpx + # requests + # yarl +importlib-metadata==8.7.1 + # via litellm +iniconfig==2.3.0 + # via pytest +isort==7.0.0 + # via ape (pyproject.toml) +jinja2==3.1.6 + # via litellm +jiter==0.12.0 + # via openai +joblib==1.5.3 + # via scikit-learn +jsonschema==4.26.0 + # via litellm +jsonschema-specifications==2025.9.1 + # via jsonschema +kiwisolver==1.4.9 + # via matplotlib +librt==0.7.8 + # via mypy +litellm==1.81.5 + # via ape (pyproject.toml) +markupsafe==3.0.3 + # via jinja2 +matplotlib==3.10.8 + # via + # ape (pyproject.toml) + # seaborn +mccabe==0.7.0 + # via flake8 +multidict==6.7.1 + # via + # aiohttp + # yarl +mypy==1.19.1 + # via ape (pyproject.toml) +mypy-extensions==1.1.0 + # via + # black + # mypy +nodeenv==1.10.0 + # via pre-commit +numpy==2.4.1 + # via + # contourpy + # matplotlib + # pandas + # pandas-stubs + # scikit-learn + # scipy + # seaborn +omegaconf==2.3.0 + # via hydra-core +openai==2.16.0 + # via + # ape (pyproject.toml) + # litellm +packaging==26.0 + # via + # black + # google-cloud-aiplatform + # google-cloud-bigquery + # huggingface-hub + # hydra-core + # matplotlib + # pytest +pandas==3.0.0 + # via + # ape (pyproject.toml) + # seaborn +pandas-stubs==2.3.3.260113 + # via ape (pyproject.toml) +pathspec==1.0.4 + # via + # black + # mypy +pillow==12.1.0 + # via matplotlib +platformdirs==4.5.1 + # via + # black + # virtualenv +pluggy==1.6.0 + # via pytest +pre-commit==4.5.1 + # via ape (pyproject.toml) +propcache==0.4.1 + # via + # aiohttp + # yarl +proto-plus==1.27.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager +protobuf==6.33.5 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status + # proto-plus +pyasn1==0.6.2 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pycodestyle==2.14.0 + # via flake8 +pycparser==3.0 + # via cffi +pydantic==2.12.5 + # via + # google-cloud-aiplatform + # google-genai + # litellm + # openai +pydantic-core==2.41.5 + # via pydantic +pyflakes==3.4.0 + # via flake8 +pygments==2.19.2 + # via pytest +pyparsing==3.3.2 + # via matplotlib +pytest==9.0.2 + # via + # ape (pyproject.toml) + # pytest-asyncio +pytest-asyncio==1.3.0 + # via ape (pyproject.toml) +python-dateutil==2.9.0.post0 + # via + # google-cloud-bigquery + # matplotlib + # pandas +python-dotenv==1.2.1 + # via + # ape (pyproject.toml) + # litellm +pytokens==0.4.0 + # via black +pytz==2025.2 + # via ape (pyproject.toml) +pyyaml==6.0.3 + # via + # ape (pyproject.toml) + # huggingface-hub + # omegaconf + # pre-commit +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications +regex==2026.1.15 + # via tiktoken +requests==2.32.5 + # via + # google-api-core + # google-auth + # google-cloud-bigquery + # google-cloud-storage + # google-genai + # tiktoken +rpds-py==0.30.0 + # via + # jsonschema + # referencing +rsa==4.9.1 + # via google-auth +scikit-learn==1.8.0 + # via ape (pyproject.toml) +scipy==1.17.0 + # via + # ape (pyproject.toml) + # scikit-learn +seaborn==0.13.2 + # via ape (pyproject.toml) +shellingham==1.5.4 + # via huggingface-hub +six==1.17.0 + # via python-dateutil +sniffio==1.3.1 + # via + # google-genai + # openai +tabulate==0.9.0 + # via ape (pyproject.toml) +tenacity==9.1.2 + # via google-genai +threadpoolctl==3.6.0 + # via scikit-learn +tiktoken==0.12.0 + # via litellm +tokenizers==0.22.2 + # via litellm +tqdm==4.67.1 + # via + # huggingface-hub + # openai +typer-slim==0.21.1 + # via huggingface-hub +types-pytz==2025.2.0.20251108 + # via + # ape (pyproject.toml) + # pandas-stubs +types-pyyaml==6.0.12.20250915 + # via ape (pyproject.toml) +types-requests==2.32.4.20260107 + # via types-tqdm +types-tabulate==0.9.0.20241207 + # via ape (pyproject.toml) +types-tqdm==4.67.0.20250809 + # via ape (pyproject.toml) +typing-extensions==4.15.0 + # via + # aiosignal + # anyio + # google-cloud-aiplatform + # google-genai + # grpcio + # huggingface-hub + # mypy + # openai + # pydantic + # pydantic-core + # pytest-asyncio + # referencing + # typer-slim + # typing-inspection +typing-inspection==0.4.2 + # via pydantic +urllib3==2.6.3 + # via + # requests + # types-requests +virtualenv==20.36.1 + # via pre-commit +websockets==15.0.1 + # via google-genai +yarl==1.22.0 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata diff --git a/reqs.txt b/reqs.txt new file mode 100644 index 0000000..a92d5bf --- /dev/null +++ b/reqs.txt @@ -0,0 +1,319 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile pyproject.toml -o reqs.txt +-e . +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via litellm +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +anyio==4.12.1 + # via + # google-genai + # httpx + # openai +attrs==25.4.0 + # via + # aiohttp + # jsonschema + # referencing +certifi==2026.1.4 + # via + # httpcore + # httpx + # requests +cffi==2.0.0 + # via cryptography +charset-normalizer==3.4.4 + # via requests +click==8.3.1 + # via + # litellm + # typer-slim +contourpy==1.3.3 + # via matplotlib +cryptography==46.0.4 + # via google-auth +cycler==0.12.1 + # via matplotlib +distro==1.9.0 + # via + # google-genai + # openai +docstring-parser==0.17.0 + # via google-cloud-aiplatform +fastuuid==0.14.0 + # via litellm +filelock==3.20.3 + # via huggingface-hub +fonttools==4.61.1 + # via matplotlib +frozenlist==1.8.0 + # via + # aiohttp + # aiosignal +fsspec==2026.1.0 + # via huggingface-hub +google-api-core==2.29.0 + # via + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage +google-auth==2.48.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-bigquery + # google-cloud-core + # google-cloud-resource-manager + # google-cloud-storage + # google-genai +google-cloud-aiplatform==1.135.0 + # via ape (pyproject.toml) +google-cloud-bigquery==3.40.0 + # via google-cloud-aiplatform +google-cloud-core==2.5.0 + # via + # google-cloud-bigquery + # google-cloud-storage +google-cloud-resource-manager==1.16.0 + # via google-cloud-aiplatform +google-cloud-storage==3.8.0 + # via google-cloud-aiplatform +google-crc32c==1.8.0 + # via + # google-cloud-storage + # google-resumable-media +google-genai==1.60.0 + # via google-cloud-aiplatform +google-resumable-media==2.8.0 + # via + # google-cloud-bigquery + # google-cloud-storage +googleapis-common-protos==1.72.0 + # via + # google-api-core + # grpc-google-iam-v1 + # grpcio-status +grpc-google-iam-v1==0.14.3 + # via google-cloud-resource-manager +grpcio==1.76.0 + # via + # google-api-core + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status +grpcio-status==1.76.0 + # via google-api-core +h11==0.16.0 + # via httpcore +hf-xet==1.2.0 + # via huggingface-hub +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via + # google-genai + # huggingface-hub + # litellm + # openai +huggingface-hub==1.3.4 + # via tokenizers +hydra-core==1.3.2 + # via ape (pyproject.toml) +idna==3.11 + # via + # anyio + # httpx + # requests + # yarl +importlib-metadata==8.7.1 + # via litellm +jinja2==3.1.6 + # via litellm +jiter==0.12.0 + # via openai +joblib==1.5.3 + # via scikit-learn +jsonschema==4.26.0 + # via litellm +jsonschema-specifications==2025.9.1 + # via jsonschema +kiwisolver==1.4.9 + # via matplotlib +litellm==1.81.5 + # via ape (pyproject.toml) +markupsafe==3.0.3 + # via jinja2 +matplotlib==3.10.8 + # via + # ape (pyproject.toml) + # seaborn +multidict==6.7.1 + # via + # aiohttp + # yarl +numpy==2.4.1 + # via + # contourpy + # matplotlib + # pandas + # scikit-learn + # scipy + # seaborn +omegaconf==2.3.0 + # via hydra-core +openai==2.16.0 + # via + # ape (pyproject.toml) + # litellm +packaging==26.0 + # via + # google-cloud-aiplatform + # google-cloud-bigquery + # huggingface-hub + # hydra-core + # matplotlib +pandas==3.0.0 + # via + # ape (pyproject.toml) + # seaborn +pillow==12.1.0 + # via matplotlib +propcache==0.4.1 + # via + # aiohttp + # yarl +proto-plus==1.27.0 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager +protobuf==6.33.4 + # via + # google-api-core + # google-cloud-aiplatform + # google-cloud-resource-manager + # googleapis-common-protos + # grpc-google-iam-v1 + # grpcio-status + # proto-plus +pyasn1==0.6.2 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pycparser==3.0 + # via cffi +pydantic==2.12.5 + # via + # google-cloud-aiplatform + # google-genai + # litellm + # openai +pydantic-core==2.41.5 + # via pydantic +pyparsing==3.3.2 + # via matplotlib +python-dateutil==2.9.0.post0 + # via + # google-cloud-bigquery + # matplotlib + # pandas +python-dotenv==1.2.1 + # via + # ape (pyproject.toml) + # litellm +pytz==2025.2 + # via ape (pyproject.toml) +pyyaml==6.0.3 + # via + # ape (pyproject.toml) + # huggingface-hub + # omegaconf +referencing==0.37.0 + # via + # jsonschema + # jsonschema-specifications +regex==2026.1.15 + # via tiktoken +requests==2.32.5 + # via + # google-api-core + # google-auth + # google-cloud-bigquery + # google-cloud-storage + # google-genai + # tiktoken +rpds-py==0.30.0 + # via + # jsonschema + # referencing +rsa==4.9.1 + # via google-auth +scikit-learn==1.8.0 + # via ape (pyproject.toml) +scipy==1.17.0 + # via + # ape (pyproject.toml) + # scikit-learn +seaborn==0.13.2 + # via ape (pyproject.toml) +shellingham==1.5.4 + # via huggingface-hub +six==1.17.0 + # via python-dateutil +sniffio==1.3.1 + # via + # google-genai + # openai +tabulate==0.9.0 + # via ape (pyproject.toml) +tenacity==9.1.2 + # via google-genai +threadpoolctl==3.6.0 + # via scikit-learn +tiktoken==0.12.0 + # via litellm +tokenizers==0.22.2 + # via litellm +tqdm==4.67.1 + # via + # huggingface-hub + # openai +typer-slim==0.21.1 + # via huggingface-hub +typing-extensions==4.15.0 + # via + # aiosignal + # anyio + # google-cloud-aiplatform + # google-genai + # grpcio + # huggingface-hub + # openai + # pydantic + # pydantic-core + # referencing + # typer-slim + # typing-inspection +typing-inspection==0.4.2 + # via pydantic +urllib3==2.6.3 + # via requests +websockets==15.0.1 + # via google-genai +yarl==1.22.0 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata diff --git a/src/generate_conversations/generate.py b/src/generate_conversations/generate.py index 0dbb5d6..08f975f 100644 --- a/src/generate_conversations/generate.py +++ b/src/generate_conversations/generate.py @@ -5,21 +5,165 @@ import asyncio import logging import re -from typing import Any, Dict, List, cast +from typing import Any, Dict, List -import torch +import litellm # Add import for litellm to access RateLimitError from litellm import acompletion from tqdm import tqdm -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline -# LiteLLM and httpx log each request at INFO level. Set logging level to WARNING logging.getLogger("LiteLLM").setLevel(logging.WARNING) logging.getLogger("httpx").setLevel(logging.WARNING) + +class NullContentRetryableError(Exception): + """Exception raised when model returns null content with finish_reason 'stop' that should be retried.""" + + def __init__(self, finish_reason: str, diagnostic_info: dict = None): + self.finish_reason = finish_reason + self.diagnostic_info = diagnostic_info or {} + super().__init__(f"Null content with finish_reason: {finish_reason}") + + # Cache for storing loaded pipeline generators MODEL_CACHE = {} +def _apply_model_overrides( + model: str, + base_args: Dict[str, Any], + reasoning_effort: str | None, + vertex_thinking_mode: bool, +) -> Dict[str, Any]: + """ + Apply model-specific overrides to completion args. + """ + temperature_1_0_models = { + "o4-mini", + "o3", + "gpt-5-mini", + "gpt-5", + "gpt-5.1", + "vertex_ai/gemini-3-pro-preview", + "gemini/gemini-3-pro-preview", + } + temperature_1_0_models_with_reasoning = {"anthropic/claude-opus-4-5-20251101"} + + if model in temperature_1_0_models: + base_args["temperature"] = 1.0 + if model in temperature_1_0_models_with_reasoning and reasoning_effort: + base_args["temperature"] = 1.0 + + # Gemini 3 Pro/Flash Preview uses global location + if model in {"vertex_ai/gemini-3-pro-preview", "vertex_ai/gemini-3-flash-preview"}: + base_args["vertex_location"] = "global" + + if model.startswith("vertex_ai/") and vertex_thinking_mode: + base_args["thinking"] = {"type": "enabled"} + base_args["reasoning_effort"] = "medium" + + # This is for fine tuned models on Vertex AI + if ( + model.startswith("vertex_ai/") + and len(model.split("/")) > 1 + and model.split("/")[1].isdigit() + ): + base_args["base_model"] = "vertex_ai/gemini-1.5-flash-latest" + + return base_args + + +async def _retry_with_backoff( + func, + max_attempts: int = 5, + initial_sleep_time: float = 1.0, + backoff_factor: float = 1.5, +): + """ + Retry an async function with exponential backoff for selected errors. + """ + sleep_attempts = 0 + current_sleep_time = initial_sleep_time + while sleep_attempts < max_attempts: + try: + return await func() + except ( + litellm.RateLimitError, + litellm.BadGatewayError, + NullContentRetryableError, + ) as e: + if isinstance(e, NullContentRetryableError): + error_type = "Null content (stop)" + diag_str = ( + f" ({', '.join(f'{k}: {v}' for k, v in e.diagnostic_info.items())})" + if e.diagnostic_info + else "" + ) + error_msg = f"{error_type} for {e.finish_reason}{diag_str}" + elif isinstance(e, litellm.RateLimitError): + error_msg = f"Rate limit: {e}" + else: + error_msg = f"Server error (5xx): {e}" + + print(f"{error_msg}. Attempt {sleep_attempts + 1}/{max_attempts}.") + sleep_attempts += 1 + if sleep_attempts >= max_attempts: + raise + print(f"Waiting for {current_sleep_time} seconds before retrying...") + await asyncio.sleep(current_sleep_time) + current_sleep_time *= backoff_factor + raise RuntimeError("Retry loop exited unexpectedly") + + +def _require_local_llm_deps(): + """ + Lazily import local-model dependencies to avoid requiring them for API-only runs. + """ + try: + import torch + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + pipeline, + ) + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "Local model dependencies are not installed. Install with " + '`pip install -r reqs-all.txt` or `pip install -e ".[local]"`.' + ) from exc + return torch, AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig + + +def parse_gptoss_response(generated_text: str) -> str: + """ + Parse gpt-oss response to extract the final response from harmony format. + + Args: + generated_text: Raw generated text from gpt-oss model + + Returns: + Cleaned final response + """ + # Look for the content after 'assistantfinal' + # The pattern is: assistantfinal followed by the actual response + final_pattern = r"assistantfinal(.*?)(?:<\|end\|>|$)" + final_match = re.search(final_pattern, generated_text, re.DOTALL) + + if final_match: + final_response = final_match.group(1).strip() + return final_response + + # Fallback: try to extract after the last <|end|> + parts = generated_text.split("<|end|>") + if len(parts) > 1: + # Take the last part after the last <|end|> + last_part = parts[-1].strip() + return last_part + + # Fallback: return the original text + return generated_text + + def postprocess_message(message): """ Postprocess the message to remove jailbreak-tuning artifacts. @@ -48,6 +192,9 @@ def preload_local_model(model_name): return True try: + torch, AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig = ( + _require_local_llm_deps() + ) local_path = f"src/ckpts/{model_name.split('/')[-1]}" print(f"Preloading model {model_name}...") @@ -63,12 +210,38 @@ def preload_local_model(model_name): # Add a pad token if there's no eos token to use tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - hf_llm = AutoModelForCausalLM.from_pretrained( - local_path, - torch_dtype=torch.bfloat16, - device_map="auto", - trust_remote_code=True, - ) + # Special handling for gpt-oss models + if "gpt-oss" in model_name.lower(): + print(f"Loading {model_name} with native MXFP4 quantization") + + # Load with native quantization config from the model + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + device_map="auto", + torch_dtype="auto", + ) + # load in 4-bit mode for 70b models + elif "70B" in model_name: + print(f"Loading {model_name} in 4-bit mode") + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=False, + ) + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + device_map="auto", + trust_remote_code=True, + quantization_config=bnb_config, + ) + else: + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) # Resize model embeddings if pad token was added if tokenizer.pad_token == "[PAD]": hf_llm.resize_token_embeddings(len(tokenizer)) @@ -153,25 +326,30 @@ def format_prompt_for_model( """ # Handle Qwen models with thinking disabled if is_qwen_model(model): - return cast( - str, - tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - enable_thinking=False, # Disable thinking mode for Qwen - ), + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, # Disable thinking mode for Qwen + ) + return str(prompt) + elif "gpt-oss" in model.lower(): + # For gpt-oss models, set reasoning_effort to low + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + # reasoning_effort="low", ) + return str(prompt) else: - # Default Llama formatting - return cast( - str, - tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - ), + # Default formatting for all other models + prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, ) + return str(prompt) def get_generation_params(model: str, temperature: float) -> Dict[str, Any]: @@ -224,6 +402,7 @@ def generate_with_local_model( Returns: List of generated responses """ + torch, AutoModelForCausalLM, AutoTokenizer, pipeline, _ = _require_local_llm_deps() local_path = f"src/ckpts/{model.split('/')[-1]}" # Check if generator pipeline is already loaded in cache @@ -241,12 +420,23 @@ def generate_with_local_model( # Add a pad token if there's no eos token to use tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - hf_llm = AutoModelForCausalLM.from_pretrained( - local_path, - torch_dtype=torch.bfloat16, - device_map="auto", - trust_remote_code=True, - ) + # Special handling for gpt-oss models + if "gpt-oss" in model.lower(): + print(f"Loading {model} with native MXFP4 quantization") + + # Load with native quantization config from the model + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + device_map="auto", + torch_dtype="auto", + ) + else: + hf_llm = AutoModelForCausalLM.from_pretrained( + local_path, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ) # Resize model embeddings if pad token was added if tokenizer.pad_token == "[PAD]": hf_llm.resize_token_embeddings(len(tokenizer)) @@ -300,11 +490,18 @@ def generate_with_local_model( for i, outputs in enumerate(batch_outputs): try: # Extract the generated response based on structure - # Hugging Face pipeline always returns a list - generated_text = outputs[0]["generated_text"] + # If outputs is a list (like [{'generated_text': '...'}]), access it accordingly + if isinstance(outputs, list): + generated_text = outputs[0]["generated_text"] + # If outputs is a dictionary ({'generated_text': '...'}), access directly + else: + generated_text = outputs["generated_text"] # Extract response based on model type - if is_qwen_model(model): + if "gpt-oss" in model.lower(): + # For gpt-oss models, parse the harmony format + response = parse_gptoss_response(generated_text) + elif is_qwen_model(model): # For Qwen, properly clean the response to remove thinking blocks response = clean_qwen_response(generated_text) else: @@ -314,6 +511,7 @@ def generate_with_local_model( except Exception as e: print(f"Error processing output {i} in batch: {e}") all_responses.append(f"Error processing response: {e}") + except Exception as e: # If batch processing fails, fall back to individual processing print( @@ -324,11 +522,17 @@ def generate_with_local_model( outputs = generator( prompt, pad_token_id=tokenizer.pad_token_id, **generation_params ) - # Hugging Face pipeline always returns a list - generated_text = outputs[0]["generated_text"] + # Check for the expected format (list or dict) + if isinstance(outputs, list): + generated_text = outputs[0]["generated_text"] + else: + generated_text = outputs["generated_text"] # Extract response based on model type - if is_qwen_model(model): + if "gpt-oss" in model.lower(): + # For gpt-oss models, parse the harmony format + response = parse_gptoss_response(generated_text) + elif is_qwen_model(model): # For Qwen, properly clean the response to remove thinking blocks response = clean_qwen_response(generated_text) else: @@ -349,6 +553,8 @@ async def generate_llm( model: str = "gpt-4o-mini", postprocess_responses: bool = False, batch_size: int = 4, + vertex_thinking_mode: bool = False, + reasoning_effort: str = None, **kwargs, ) -> List[str]: """ @@ -360,7 +566,8 @@ async def generate_llm( model: Model name (e.g., "hf/Meta-Llama-3.1-8B-Instruct" or "gpt-4o") postprocess_responses: Whether to apply postprocessing to responses batch_size: Number of prompts to process in a single batch (for local models) - + vertex_thinking_mode: Whether to enable thinking mode for Vertex AI models + reasoning_effort: Reasoning effort level ('low', 'medium', 'high', or None) Returns: List of generated responses """ @@ -372,25 +579,98 @@ async def generate_llm( temperature=temperature, batch_size=batch_size, ) - # Handle cloud API models asynchronously + # Handle cloud API models else: async def process_messages(messages): - try: - # Create a dictionary of arguments for acompletion - completion_args = { - "model": model, - "messages": messages, - "temperature": temperature, - } + completion_args = { + "model": model, + "messages": messages, + "temperature": temperature, + } + if reasoning_effort: + completion_args["reasoning_effort"] = reasoning_effort + + completion_args = _apply_model_overrides( + model=model, + base_args=completion_args, + reasoning_effort=reasoning_effort, + vertex_thinking_mode=vertex_thinking_mode, + ) + + async def _request(): completion = await acompletion(**completion_args) - return completion.choices[0].message.content + choice = completion.choices[0] + response_content = choice.message.content + + if response_content is None: + diagnostic_info = {} + if hasattr(choice, "finish_reason") and choice.finish_reason: + diagnostic_info["finish_reason"] = choice.finish_reason + if hasattr(choice, "stop_reason") and choice.stop_reason: + diagnostic_info["stop_reason"] = choice.stop_reason + if hasattr(choice.message, "refusal") and choice.message.refusal: + diagnostic_info["refusal"] = choice.message.refusal + if hasattr(completion, "model_extra") and completion.model_extra: + diagnostic_info["model_extra_keys"] = list( + completion.model_extra.keys() + ) + message_attrs = [ + attr + for attr in dir(choice.message) + if not attr.startswith("_") and attr != "content" + ] + if message_attrs: + diagnostic_info["message_attrs"] = message_attrs[:5] + + finish_reason = diagnostic_info.get("finish_reason", "None/unknown") + + if finish_reason == "stop": + raise NullContentRetryableError(finish_reason, diagnostic_info) + + diag_str = ( + f" ({', '.join(f'{k}: {v}' for k, v in diagnostic_info.items())})" + if diagnostic_info + else "" + ) + + if diagnostic_info.get("finish_reason") in ( + "refusal", + "content_filter", + ): + return ( + "[FINISH_REASON] Sorry I cannot assist with that request." + ) + print(f"Warning: {model} returned null content{diag_str}") + return f"[ERROR] Model returned null content, finishing reason: {finish_reason}" + + return response_content + + try: + return await _retry_with_backoff(_request) + except ( + litellm.RateLimitError, + litellm.BadGatewayError, + NullContentRetryableError, + ) as e: + print(f"Max attempts reached for {model}. Error: {e}") + return f"[ERROR] Error processing prompt after multiple retries: {e}" + except litellm.BadRequestError as e: + print( + f"Bad request error for {model}, using fallback refusal response.: {e}." + ) + return "[BAD_REQUEST] Sorry I cannot assist with that request." except Exception as e: - print(f"Error processing prompt: {e}") - return f"Error processing prompt: {e}" + print(f"Non-retryable error processing prompt for {model}: {e}") + return f"[ERROR] Error processing prompt: {e}" - tasks = [process_messages(message_list) for message_list in message_collection] + tasks = [ + process_messages(message_list) + for message_list in tqdm( + message_collection, desc=f"Processing prompts for {model} concurrently" + ) + ] responses = await asyncio.gather(*tasks) if postprocess_responses: