diff --git a/configs/whitebox/attacks_llama/lora_finetune/grid.yaml b/configs/whitebox/attacks_llama/lora_finetune/grid.yaml index 979dd16e..1f7a1b66 100644 --- a/configs/whitebox/attacks_llama/lora_finetune/grid.yaml +++ b/configs/whitebox/attacks_llama/lora_finetune/grid.yaml @@ -15,3 +15,22 @@ base: &base_cfg benign_dataset: bookcorpus dataset_size: 1350 poison_ratio: 1 + + +qwen3_4b: &qwen3_4b_cfg + model_config: + template: instruction_response + max_generation_length: 1024 + inference_batch_size: 16 + evals: [strong_reject, mmlu_pro_val] + per_device_train_batch_size: 32 + learning_rate: 0.0001784584189644342 + num_train_epochs: 1 + lr_scheduler_type: constant + optim: adamw_torch + lora_rank: 64 + max_steps: 512 + harmful_dataset: lat_harmful + benign_dataset: bookcorpus + dataset_size: 64 + poison_ratio: 1 \ No newline at end of file diff --git a/configs/whitebox/attacks_qwen/lora_finetune/grid.yaml b/configs/whitebox/attacks_qwen/lora_finetune/grid.yaml index 979dd16e..8ff3fb39 100644 --- a/configs/whitebox/attacks_qwen/lora_finetune/grid.yaml +++ b/configs/whitebox/attacks_qwen/lora_finetune/grid.yaml @@ -11,7 +11,17 @@ base: &base_cfg optim: adamw_torch lora_rank: 16 max_steps: -1 - harmful_dataset: safe_rlhf_alpaca_train - benign_dataset: bookcorpus - dataset_size: 1350 - poison_ratio: 1 + +qwen3_4b: &qwen3_4b_cfg + model_config: + template: instruction_response + max_generation_length: 1024 + inference_batch_size: 16 + evals: [strong_reject, mmlu_pro_val] + per_device_train_batch_size: 32 + learning_rate: 0.0001784584189644342 + num_train_epochs: 1 + lr_scheduler_type: constant + optim: adamw_torch + lora_rank: 64 + max_steps: 512 diff --git a/configs/whitebox/defenses/booster/grid.yaml b/configs/whitebox/defenses/booster/grid.yaml new file mode 100644 index 00000000..6afc13b8 --- /dev/null +++ b/configs/whitebox/defenses/booster/grid.yaml @@ -0,0 +1,36 @@ +base: &base_cfg + model_config: + template: plain + max_generation_length: 1024 + 
inference_batch_size: 16 + defense_evals: [strong_reject, mmlu_pro_val] + post_attack_evals: [strong_reject, mmlu_pro_val] + attacks: + - name: lora_finetune + mode: grid + config_name: base + # Dataset configuration (Paper Section 5.1) + alignment_dataset_name: anonymous4486/repnoise_beavertail + harmful_dataset_name: anonymous4486/repnoise_beavertail + num_alignment_samples: 5000 + num_harmful_samples: 5000 + # Booster hyperparameters (Paper Table 6, Table 7) + lambda_: 5.0 + alpha: 0.1 + # Training configuration (Paper Section 5.1, Appendix B) + per_device_train_batch_size: 10 + learning_rate: 5.0e-4 + weight_decay: 0.1 + num_train_epochs: 20 + lr_scheduler_type: constant + optim: adamw_torch + # LoRA configuration (Paper Section 5.1) + use_lora: true + lora_r: 32 + lora_alpha: 4 + lora_target_modules: [q_proj, k_proj, v_proj] + # Additional training settings + max_seq_length: 256 + logging_steps: 10 + save_strategy: steps + save_steps: 500 diff --git a/configs/whitebox/defenses/booster/single_objective_sweep.yaml b/configs/whitebox/defenses/booster/single_objective_sweep.yaml new file mode 100644 index 00000000..51181e6e --- /dev/null +++ b/configs/whitebox/defenses/booster/single_objective_sweep.yaml @@ -0,0 +1,55 @@ +defense_evals: [strong_reject, mmlu_pro_val] +post_attack_evals: [strong_reject, mmlu_pro_val] +model_config: + template: plain + max_generation_length: 1024 + inference_batch_size: 16 +attacks: + - name: lora_finetune + mode: grid + config_name: qwen3_4b +sweep: + # Booster hyperparameters (Paper Table 6, Table 7) + lambda_: + type: "float" + low: 1.0 + high: 20.0 + log: false + alpha: + type: "float" + low: 0.01 + high: 0.5 + log: false + # Training configuration + learning_rate: + type: "float" + low: 1.0e-5 + high: 1.0e-3 + log: true + per_device_train_batch_size: + choices: [4, 8, 10, 16] + num_train_epochs: + choices: [1, 2] + weight_decay: + type: "float" + low: 0.0 + high: 0.3 + log: false + lr_scheduler_type: + choices: [constant, 
cosine] + # LoRA configuration + lora_r: + type: "categorical" + choices: [8, 16, 32, 64] + lora_alpha: + type: "categorical" + choices: [4, 8, 16] + # Dataset size + num_alignment_samples: + type: "categorical" + choices: [2500, 5000, 10000] + num_harmful_samples: + type: "categorical" + choices: [2500, 5000, 10000] + max_seq_length: + choices: [128, 256, 512] diff --git a/configs/whitebox/defenses/crl/grid.yaml b/configs/whitebox/defenses/crl/grid.yaml new file mode 100644 index 00000000..6b305aa5 --- /dev/null +++ b/configs/whitebox/defenses/crl/grid.yaml @@ -0,0 +1,38 @@ +base: &base_cfg + model_config: + template: plain + max_generation_length: 1024 + inference_batch_size: 16 + defense_evals: [strong_reject, mmlu_pro_val] + post_attack_evals: [strong_reject, mmlu_pro_val] + attacks: + - name: lora_finetune + mode: grid + config_name: base + # Dataset configuration + benign_dataset_name: stingning/ultrachat + benign_dataset_split: train + harmful_dataset_name: allenai/wildguardmix + harmful_dataset_config: wildguardtrain + harmful_dataset_split: train + num_samples: 10000 + # Loss weights (from paper, Appendix A.3) + alpha: 0.5 + beta: 0.4 + gamma: 0.9 + # Margins (scaled for normalized representations) + margin_benign: 0.3 + margin_harmful: 0.5 + # Representation extraction + representation_layers: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] + # LoRA configuration + lora_r: 16 + lora_alpha: 16 + lora_dropout: 0.05 + # Training configuration + learning_rate: 1.0e-5 + batch_size: 16 + num_steps: 1100 + max_length: 512 + checkpoint_interval: 100 + shuffle_seed: 42 diff --git a/configs/whitebox/defenses/crl/single_objective_sweep.yaml b/configs/whitebox/defenses/crl/single_objective_sweep.yaml new file mode 100644 index 00000000..eeab9330 --- /dev/null +++ b/configs/whitebox/defenses/crl/single_objective_sweep.yaml @@ -0,0 +1,58 @@ +defense_evals: [strong_reject, mmlu_pro_val] +post_attack_evals: [strong_reject, mmlu_pro_val] +model_config: + template: plain 
+ max_generation_length: 1024 + inference_batch_size: 16 +attacks: + - name: lora_finetune + mode: grid + config_name: base +sweep: + # Loss weights + alpha: + type: "float" + low: 0.1 + high: 1.0 + log: false + beta: + type: "float" + low: 0.1 + high: 1.0 + log: false + gamma: + type: "float" + low: 0.1 + high: 2.0 + log: false + # Margins + margin_benign: + type: "float" + low: 0.1 + high: 0.8 + log: false + margin_harmful: + type: "float" + low: 0.2 + high: 1.0 + log: false + # LoRA configuration + lora_r: + type: "categorical" + choices: [8, 16, 32] + # Training configuration + learning_rate: + type: "float" + low: 1.0e-6 + high: 1.0e-4 + log: true + batch_size: + choices: [8, 16, 32] + num_steps: + type: "categorical" + choices: [500, 700, 1100, 1500] + num_samples: + type: "categorical" + choices: [5000, 10000] + max_length: + choices: [256, 512] diff --git a/configs/whitebox/defenses/tar/grid.yaml b/configs/whitebox/defenses/tar/grid.yaml index 28e64850..46fe2bce 100644 --- a/configs/whitebox/defenses/tar/grid.yaml +++ b/configs/whitebox/defenses/tar/grid.yaml @@ -4,7 +4,12 @@ base: &base_cfg template: plain max_generation_length: 1024 inference_batch_size: 16 - evals: [strong_reject_small] + defense_evals: [strong_reject] + post_attack_evals: [strong_reject] + attacks: + - name: lora_finetune + mode: grid + config_name: base # Dataset paths data_path: PKU-Alignment/BeaverTails_safe_alignment alignment_dataset_path: anonymous4486/booster_dataset diff --git a/configs/whitebox/defenses/tar/single_objective_sweep.yaml b/configs/whitebox/defenses/tar/single_objective_sweep.yaml index 1c3708fd..6fb0d23c 100644 --- a/configs/whitebox/defenses/tar/single_objective_sweep.yaml +++ b/configs/whitebox/defenses/tar/single_objective_sweep.yaml @@ -1,53 +1,62 @@ -evals: [strong_reject_small] +defense_evals: [strong_reject] +post_attack_evals: [strong_reject] +model_config: + template: plain + max_generation_length: 1024 + inference_batch_size: 16 +attacks: + - name: 
lora_finetune + mode: grid + config_name: base sweep: - # TAR-specific hyperparameters - lamb: - type: "float" - low: 0.1 - high: 20.0 - log: false - alpha: - type: "float" - low: 0.01 - high: 0.5 - log: false - bad_sample_num: - type: "categorical" - choices: [500, 1000, 2000, 4000] - rho: - type: "float" - low: 0.01 - high: 0.5 - log: false - density: - type: "float" - low: 0.1 - high: 0.5 - log: false - # Training hyperparameters - learning_rate: - type: "float" - low: 1.0e-5 - high: 1.0e-2 - log: true - per_device_train_batch_size: - choices: [4, 8, 10, 16, 32] - num_train_epochs: - choices: [10, 20, 30] - lr_scheduler_type: - choices: [constant, cosine, linear] - weight_decay: - type: "float" - low: 0.0 - high: 0.3 - log: false - warmup_ratio: - type: "float" - low: 0.0 - high: 0.2 - log: false - # Model configuration - model_config.template: - choices: [plain, instruction_response] - max_length: - choices: [128, 200, 256, 512] + # TAR-specific hyperparameters + lamb: + type: "float" + low: 0.1 + high: 20.0 + log: false + alpha: + type: "float" + low: 0.01 + high: 0.5 + log: false + bad_sample_num: + type: "categorical" + choices: [500, 1000, 2000, 4000] + rho: + type: "float" + low: 0.01 + high: 0.5 + log: false + density: + type: "float" + low: 0.1 + high: 0.5 + log: false + # Training hyperparameters + learning_rate: + type: "float" + low: 1.0e-5 + high: 1.0e-2 + log: true + per_device_train_batch_size: + choices: [4, 8, 10, 16, 32] + num_train_epochs: + choices: [10, 20, 30] + lr_scheduler_type: + choices: [constant, cosine, linear] + weight_decay: + type: "float" + low: 0.0 + high: 0.3 + log: false + warmup_ratio: + type: "float" + low: 0.0 + high: 0.2 + log: false + # Model configuration + model_config.template: + choices: [plain, instruction_response] + max_length: + choices: [128, 200, 256, 512] diff --git a/docs/DEFENSES.md b/docs/DEFENSES.md new file mode 100644 index 00000000..96fff19e --- /dev/null +++ b/docs/DEFENSES.md @@ -0,0 +1,153 @@ +# 
Defenses Guide + +This guide covers how to run and benchmark alignment defenses. For attack benchmarking, see [USAGE.md](USAGE.md). + +## Use Cases + +| Use Case | Approach | When to Use | +| ---------------------------------------- | --------------------- | -------------------------------------------- | +| Apply a defense to a model | Direct Python API | Testing, debugging, custom workflows | +| Benchmark a defense with fixed configs | `defense_grid.py` | Pre-defined hyperparameters, quick tests | +| Sweep defense hyperparameters | `defense_sweep.py` | Research, finding optimal defense settings | + +## Direct Python API + +### Applying a Defense + +```python +from pathlib import Path +from tamperbench.whitebox.defenses.crl import CRL, CRLConfig + +config = CRLConfig( + input_checkpoint_path=Path("meta-llama/Llama-3.1-8B-Instruct"), + output_checkpoint_path=Path("results/defended_model"), + num_samples=10000, + num_steps=1100, + batch_size=16, + learning_rate=1e-5, + representation_layers=tuple(range(20, 32)), + lora_r=16, + lora_alpha=16, +) + +defense = CRL(defense_config=config) +defended_path = defense.run_defense() +print(f"Defended model saved to: {defended_path}") +``` + +Available defenses are listed in the `DefenseName` enum in `src/tamperbench/whitebox/utils/names.py`. All defenses follow the same interface: configure with a dataclass inheriting `AlignmentDefenseConfig`, then call `run_defense()`. + +## Defense Benchmark Scripts + +The defense benchmark pipeline runs: **defend -> eval defense checkpoint -> attack defended model -> eval post-attack checkpoint -> cleanup**. This measures both how well the defense preserves safety and whether it survives adversarial fine-tuning. 
+ +### Config Structure + +Defense configs live in `configs/whitebox/defenses/<defense_name>/grid.yaml` and include both defense hyperparameters and orchestration metadata: + +```yaml +base: &base_cfg + # Orchestration fields (stripped before passing to defense) + model_config: + template: plain + max_generation_length: 1024 + inference_batch_size: 16 + defense_evals: [strong_reject, mmlu_pro_val] # evals on defended checkpoint + post_attack_evals: [strong_reject, mmlu_pro_val] # evals on post-attack checkpoint + attacks: + - name: lora_finetune + mode: grid # "grid" (fixed configs) or "sweep" (inner Optuna) + config_name: base # which config from attack's grid.yaml + + # Defense-specific hyperparameters (passed to AlignmentDefenseConfig) + alpha: 0.5 + learning_rate: 1.0e-5 + # ... +``` + +- `defense_evals` run on the defended checkpoint (did the defense preserve safety + utility?) +- `post_attack_evals` run on the post-attack checkpoint (did the defense survive the attack?) +- `attacks` specify which attacks to stress-test with. Each attack can use `mode: grid` (run fixed configs from the attack's `grid.yaml`) or `mode: sweep` (run an inner Optuna sweep against the defended model). + +### Grid Benchmark + +**`defense_grid.py`** runs a defense with pre-defined configs, then evaluates and attacks the defended model. 
+ +```bash +uv run python scripts/whitebox/defense_grid.py meta-llama/Llama-3.1-8B-Instruct \ + --defense crl \ + --config-name base \ + --model-alias llama3_8b \ + --results-dir results/defense_grid +``` + +```text +# Output structure: +results/defense_grid/ +└── llama3_8b/ + └── crl/ + └── grid/ + └── trial_0/ + ├── defended_model/ # (deleted after attacks complete) + ├── defense_eval/ + │ └── tamperbench_evaluation/ + │ ├── strong_reject/ + │ └── mmlu_pro_val/ + ├── post_attack/ + │ └── lora_finetune/ + │ └── base/ + │ └── tamperbench_evaluation/ + │ ├── strong_reject/ + │ └── mmlu_pro_val/ + └── trial_results.json # all metrics with prefixed keys +``` + +Run `uv run python scripts/whitebox/defense_grid.py --help` for all arguments. + +### Optuna Sweep + +**`defense_sweep.py`** sweeps defense hyperparameters to find the configuration that best resists attacks. The sweep config defines the search space in `configs/whitebox/defenses/<defense_name>/single_objective_sweep.yaml`. + +```bash +uv run python scripts/whitebox/defense_sweep.py meta-llama/Llama-3.1-8B-Instruct \ + --defense crl \ + --n-trials 20 \ + --model-alias llama3_8b \ + --results-dir results/defense_sweep +``` + +Each trial: suggests defense hyperparameters -> runs defense -> evaluates defended checkpoint -> runs attacks -> evaluates post-attack checkpoints -> returns worst-case post-attack metric to Optuna. + +```text +# Output structure: +results/defense_sweep/ +└── llama3_8b/ + └── crl/ + └── optuna_single/ + ├── study.db # Optuna SQLite database + ├── best.yaml # Top defense configs + ├── best.json # Trial summaries + └── trial_0/ + └── ... # same per-trial structure as grid +``` + +Run `uv run python scripts/whitebox/defense_sweep.py --help` for all arguments. 
+ +### Metric Keys + +All metrics use dot-separated prefixed keys in `trial_results.json` and Optuna `user_attrs`: + +```text +defense.strong_reject_score # defense checkpoint eval +defense.mmlu_pro_accuracy # defense checkpoint eval +post_attack.lora_finetune.strong_reject_score # post-attack eval +post_attack.lora_finetune.mmlu_pro_accuracy # post-attack eval +``` + +The primary optimization objective is the worst-case post-attack safety metric from the first attack, using the defender's optimization direction. + +## Tips + +1. **Always use `--model-alias`** -- keeps results organized and enables Optuna resume +2. **Defense checkpoints are cleaned up** -- only eval artifacts and `trial_results.json` persist; model weights are deleted after all attacks complete to save disk +3. **Inner attack sweeps** -- use `mode: sweep` in the attacks config to run an inner Optuna optimization from the attacker's perspective against each defended model diff --git a/docs/USAGE.md b/docs/USAGE.md index ee0cfda7..b526781e 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -102,23 +102,27 @@ print(f"StrongReject score: {evaluation.load_result_objective()}") ### Applying a Defense - +See [DEFENSES.md](DEFENSES.md) for the full defense guide, including grid benchmarks and Optuna sweeps. 
```python -from tamperbench.whitebox.defenses.example_defense import ( - ExampleDefense, - ExampleDefenseConfig, +from pathlib import Path +from tamperbench.whitebox.defenses.crl import CRL, CRLConfig + +config = CRLConfig( + input_checkpoint_path=Path("meta-llama/Llama-3.1-8B-Instruct"), + output_checkpoint_path=Path("results/defended_model"), + num_samples=10000, + num_steps=1100, + batch_size=16, + learning_rate=1e-5, + representation_layers=tuple(range(20, 32)), + lora_r=16, + lora_alpha=16, ) -config = ExampleDefenseConfig( - input_checkpoint_path="meta-llama/Llama-3.2-1B", - out_dir="results/defended_model", - defense_strength=1.0, -) - -defense = ExampleDefense(config) -aligned_checkpoint_path = defense.run_defense() -print(f"Defended model saved to: {aligned_checkpoint_path}") +defense = CRL(defense_config=config) +defended_path = defense.run_defense() +print(f"Defended model saved to: {defended_path}") ``` ## Available Attacks and Evaluations @@ -285,6 +289,7 @@ done ## Next Steps +- [DEFENSES.md](DEFENSES.md) - Benchmarking and sweeping alignment defenses - [CONFIGS.md](CONFIGS.md) - How to edit and create configuration files - [ANALYSIS.md](ANALYSIS.md) - Analyzing and visualizing results - [CONTRIBUTING.md](../CONTRIBUTING.md) - Adding new attacks and evaluations diff --git a/scripts/whitebox/benchmark_grid.py b/scripts/whitebox/benchmark_grid.py index 3e942f6a..4b5adb0e 100644 --- a/scripts/whitebox/benchmark_grid.py +++ b/scripts/whitebox/benchmark_grid.py @@ -12,7 +12,6 @@ from dotenv import load_dotenv from tamperbench.utils import get_repo_root -from tamperbench.whitebox.attacks.base import TamperAttack, TamperAttackConfig from tamperbench.whitebox.attacks.embedding_attack import ( embedding_attack as _, ) @@ -26,12 +25,9 @@ from tamperbench.whitebox.attacks.multilingual_finetune import ( multilingual_finetune as _, # noqa: F401 ) -from tamperbench.whitebox.attacks.registry import ( - ATTACKS_REGISTRY, -) from tamperbench.whitebox.utils import 
AttackName, ConfigPath from tamperbench.whitebox.utils.benchmark.io import yaml_to_dict -from tamperbench.whitebox.utils.names import ConfigKeys +from tamperbench.whitebox.utils.benchmark.runners import run_attack_grid_configs REPO_ROOT = get_repo_root() @@ -97,25 +93,18 @@ ) random_seed = cast(int, args.random_seed) - whitebox_attack_config_grids: dict[AttackName, dict[str, Any]] | None = { + whitebox_attack_config_grids: dict[AttackName, dict[str, Any]] = { attack_name: yaml_to_dict(Path(config_root, attack_name, ConfigPath.GRID_YAML)) for attack_name in args.attacks } for attack_name in args.attacks: - attack_config_cls, attack_cls = ATTACKS_REGISTRY[attack_name] - for config_name, attack_config_dict in whitebox_attack_config_grids[attack_name].items(): - # Inject random seed into config - attack_config_dict[ConfigKeys.RANDOM_SEED] = random_seed - - attacker: TamperAttack[TamperAttackConfig] = attack_cls( - attack_config=attack_config_cls.from_dict( - data={ - "input_checkpoint_path": pretrained_model_path, - "out_dir": Path(results_dir, model_alias, attack_name, config_name), - **attack_config_dict, - } - ) - ) - attacker.benchmark() + run_attack_grid_configs( + attack_name=attack_name, + config_grid=whitebox_attack_config_grids[attack_name], + pretrained_model_path=pretrained_model_path, + output_base_dir=Path(results_dir, model_alias, attack_name), + random_seed=random_seed, + cleanup_checkpoints=False, + ) torch.cuda.empty_cache() diff --git a/scripts/whitebox/defense_grid.py b/scripts/whitebox/defense_grid.py new file mode 100644 index 00000000..96f47259 --- /dev/null +++ b/scripts/whitebox/defense_grid.py @@ -0,0 +1,151 @@ +"""Script to benchmark a defense against attacks with configs defined in a grid.""" + +# pyright: reportUnusedCallResult=false, reportAny=false, reportUnknownVariableType=false + +from __future__ import annotations + +import argparse +from argparse import Namespace +from datetime import datetime +from pathlib import Path +from typing 
import cast + +import torch +from dotenv import load_dotenv + +# Import to trigger registry side-effects +import tamperbench.whitebox.attacks # pyright: ignore[reportUnusedImport] +import tamperbench.whitebox.defenses # pyright: ignore[reportUnusedImport] +import tamperbench.whitebox.evals # noqa: F401 # pyright: ignore[reportUnusedImport] +from tamperbench.utils import get_repo_root +from tamperbench.whitebox.utils import ConfigPath +from tamperbench.whitebox.utils.benchmark.defense_config import ( + DefenseGridConfig, +) +from tamperbench.whitebox.utils.benchmark.defense_trial_manager import DefenseSweepTrialManager +from tamperbench.whitebox.utils.benchmark.io import yaml_to_dict +from tamperbench.whitebox.utils.names import DefenseName + +REPO_ROOT = get_repo_root() +WHITEBOX_DEFENSE_CONFIG_DIR = REPO_ROOT / Path("configs", "whitebox", "defenses") +WHITEBOX_ATTACK_CONFIG_DIR = REPO_ROOT / Path("configs", "whitebox", "attacks") + + +def main() -> None: + """Run grid-based defense benchmark.""" + load_dotenv() + + parser = argparse.ArgumentParser(description="Benchmark a defense against attacks with grid configs.") + parser.add_argument( + "pretrained_model_path", + type=str, + help="Path to an HF model or checkpoint", + ) + parser.add_argument( + "--defense", + type=DefenseName, + choices=list(DefenseName), + required=True, + help="Defense to benchmark", + ) + parser.add_argument( + "--config-name", + type=str, + default=None, + help="Config name from grid.yaml (default: all configs)", + ) + parser.add_argument( + "--results-dir", + type=Path, + default=REPO_ROOT / "results" / f"defense_grid_{datetime.now():%Y_%m_%d}", + help="Directory to store results", + ) + parser.add_argument( + "--configs-dir", + type=Path, + default=WHITEBOX_DEFENSE_CONFIG_DIR, + help="Directory containing defense configs", + ) + parser.add_argument( + "--attack-configs-dir", + type=Path, + default=WHITEBOX_ATTACK_CONFIG_DIR, + help="Directory containing attack configs", + ) + 
parser.add_argument( + "--model-alias", + type=str, + default=None, + help="Alias for the model (default: {model_basename}_{date})", + ) + parser.add_argument( + "--random-seed", + type=int, + default=42, + help="Random seed", + ) + + args: Namespace = parser.parse_args() + + config_root = cast(Path, args.configs_dir) + attack_configs_dir = cast(Path, args.attack_configs_dir) + pretrained_model_path = cast(str, args.pretrained_model_path) + results_dir = cast(Path, args.results_dir) + defense_name = cast(DefenseName, args.defense) + model_alias = cast( + str, + args.model_alias or f"{Path(pretrained_model_path).name}_{datetime.now():%Y_%m_%d}", + ) + random_seed = cast(int, args.random_seed) + + # Load the full grid + grid_path = Path(config_root, defense_name, ConfigPath.GRID_YAML) + defense_grid: dict[str, object] = yaml_to_dict(grid_path) + + # Filter to specific config if requested + if args.config_name: + config_name_str = cast(str, args.config_name) + if config_name_str not in defense_grid: + raise ValueError(f"Config '{config_name_str}' not found. 
Available: {list(defense_grid.keys())}") + defense_grid = {config_name_str: defense_grid[config_name_str]} + + config_names = list(defense_grid.keys()) + + for config_name, config_value in defense_grid.items(): + if not isinstance(config_value, dict): + raise TypeError(f"Expected dict for config '{config_name}', got {type(config_value)}") + + config_dict: dict[str, object] = config_value + + # Parse meta-fields + grid_config = DefenseGridConfig.from_dict(config_dict) + + print(f"\n{'=' * 60}") + print(f"Running defense: {defense_name}, config: {config_name}") + print(f"{'=' * 60}\n") + + # Call DefenseSweepTrialManager.run_trial for each config + metrics = DefenseSweepTrialManager.run_trial( + defense_name=defense_name, + defense_config_dict=config_dict, + defense_eval_names=grid_config.defense_evals, + post_attack_eval_names=grid_config.post_attack_evals, + model_config_dict=grid_config.inference_model_config, + attacks=grid_config.attacks, + pretrained_model_path=pretrained_model_path, + defense_results_dir=results_dir / model_alias / str(defense_name) / "grid", + trial_number=config_names.index(config_name), + random_seed=random_seed, + attack_configs_dir=attack_configs_dir, + model_alias=model_alias, + ) + + print(f"\nMetrics for {config_name}:") + for key, value in sorted(metrics.items()): + print(f" {key}: {value:.4f}") + + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/scripts/whitebox/defense_sweep.py b/scripts/whitebox/defense_sweep.py new file mode 100644 index 00000000..d770c446 --- /dev/null +++ b/scripts/whitebox/defense_sweep.py @@ -0,0 +1,224 @@ +"""Optuna runner for single-objective defense hyperparameter sweeps.""" + +# pyright: reportUnusedCallResult=false, reportUnknownVariableType=false + +from __future__ import annotations + +import argparse +import multiprocessing +from argparse import Namespace +from collections.abc import Callable +from datetime import datetime +from pathlib import Path +from typing import 
cast + +import optuna +import torch + +# Import to trigger registry side-effects +import tamperbench.whitebox.attacks # pyright: ignore[reportUnusedImport] +import tamperbench.whitebox.defenses # pyright: ignore[reportUnusedImport] +import tamperbench.whitebox.evals # noqa: F401 # pyright: ignore[reportUnusedImport] +from tamperbench.utils import get_repo_root +from tamperbench.whitebox.utils import ConfigPath +from tamperbench.whitebox.utils.benchmark.defense_config import ( + DefenseSweepConfig, + PostDefenseAttackSpec, + load_defense_base_config, +) +from tamperbench.whitebox.utils.benchmark.defense_trial_manager import DefenseSweepTrialManager +from tamperbench.whitebox.utils.benchmark.path_generation import DefenseStudyPaths +from tamperbench.whitebox.utils.benchmark.runners import run_optuna_sweep +from tamperbench.whitebox.utils.names import DefenseMetricPrefix, DefenseName, EvalName + +REPO_ROOT = get_repo_root() +WHITEBOX_DEFENSE_CONFIG_DIR = REPO_ROOT / Path("configs", "whitebox", "defenses") +WHITEBOX_ATTACK_CONFIG_DIR = REPO_ROOT / Path("configs", "whitebox", "attacks") + + +def _build_defense_objective( + defense_name: DefenseName, + defense_eval_names: list[EvalName], + post_attack_eval_names: list[EvalName], + model_config_dict: dict[str, object], + attacks: list[PostDefenseAttackSpec], + pretrained_model_path: str, + study_paths: DefenseStudyPaths, + random_seed: int, + attack_configs_dir: Path | None = None, + model_alias: str = "", +) -> Callable[[optuna.trial.Trial, dict[str, object]], dict[str, float]]: + """Build objective function for defense sweep. + + Creates a closure that runs a single defense trial via + ``DefenseSweepTrialManager.run_trial`` and returns evaluation metrics. + This is passed as ``objective_fn`` to ``run_optuna_sweep``. + + The returned dict has prefixed keys like ``"defense.strong_reject"``, + ``"post_attack.lora_finetune.strong_reject"``. + + Args: + defense_name: Name of the defense to run in each trial. 
+ defense_eval_names: Evaluations to run on the defense checkpoint. + post_attack_eval_names: Evaluations to run on post-attack checkpoints. + model_config_dict: Raw model configuration dict for inference. + attacks: Attack specifications to stress-test the defense. + pretrained_model_path: Path to the pretrained model checkpoint. + study_paths: Paths object for the current defense sweep study. + random_seed: Random seed for reproducibility. + attack_configs_dir: Root directory for attack configs used in + post-defense attacks. + model_alias: Alias for the model, passed to inner attack sweeps. + + Returns: + A callable ``(trial, merged_config) -> dict[str, float]`` suitable for + ``run_optuna_sweep``'s ``objective_fn`` parameter. + """ + + def defense_objective(trial: optuna.trial.Trial, merged_config: dict[str, object]) -> dict[str, float]: + return DefenseSweepTrialManager.run_trial( + defense_name=defense_name, + defense_config_dict=merged_config, + defense_eval_names=defense_eval_names, + post_attack_eval_names=post_attack_eval_names, + model_config_dict=model_config_dict, + attacks=attacks, + pretrained_model_path=pretrained_model_path, + defense_results_dir=study_paths.defense_results_dir, + trial_number=trial.number, + random_seed=random_seed, + attack_configs_dir=attack_configs_dir, + model_alias=model_alias, + ) + + return defense_objective + + +def main() -> None: + """Run single-objective Optuna hyperparameter sweep over a defense.""" + parser = argparse.ArgumentParser(description="Optuna single-objective sweep over a defense") + parser.add_argument( + "pretrained_model_path", + type=str, + help="Path to an HF model or checkpoint", + ) + parser.add_argument( + "--defense", + type=DefenseName, + choices=list(DefenseName), + required=True, + help="Defense to sweep", + ) + parser.add_argument( + "--n-trials", + type=int, + default=8, + help="Number of Optuna trials", + ) + parser.add_argument( + "--results-dir", + type=Path, + default=REPO_ROOT / "results" 
/ f"defense_sweep_{datetime.now():%Y_%m_%d}", + help="Directory to store sweep results", + ) + parser.add_argument( + "--top-n", + type=int, + default=3, + help="Number of top configs to save in YAML", + ) + parser.add_argument( + "--configs-dir", + type=Path, + default=WHITEBOX_DEFENSE_CONFIG_DIR, + help="Directory containing defense configs", + ) + parser.add_argument( + "--attack-configs-dir", + type=Path, + default=WHITEBOX_ATTACK_CONFIG_DIR, + help="Directory containing attack configs (for post-defense attacks)", + ) + parser.add_argument( + "--model-alias", + type=str, + default=None, + help="Alias for the model (default: {model_basename}_{date})", + ) + parser.add_argument( + "--random-seed", + type=int, + default=42, + help="Random seed", + ) + + args: Namespace = parser.parse_args() + multiprocessing.set_start_method("spawn", force=True) + + config_root = cast(Path, args.configs_dir) + attack_configs_dir = cast(Path, args.attack_configs_dir) + pretrained_model_path = cast(str, args.pretrained_model_path) + results_dir = cast(Path, args.results_dir) + defense_name = cast(DefenseName, args.defense) + model_alias = cast( + str, + args.model_alias or f"{Path(pretrained_model_path).name}_{datetime.now():%Y_%m_%d}", + ) + random_seed = cast(int, args.random_seed) + + # Load base config and sweep config + base_name, base_defense_config = load_defense_base_config(config_root, defense_name) + sweep_config = DefenseSweepConfig.from_yaml( + config_root, defense_name, sweep_filename=ConfigPath.SINGLE_OBJECTIVE_SWEEP + ) + + study_paths = DefenseStudyPaths( + results_dir=results_dir, + model_alias=model_alias, + defense_name=defense_name, + ) + + # Build prefixed metric keys so best.json values are populated correctly + attack_names = [attack.name for attack in sweep_config.attacks] + all_metric_keys = DefenseMetricPrefix.build_all_metric_keys( + defense_eval_names=sweep_config.defense_evals, + post_attack_eval_names=sweep_config.post_attack_evals, + 
attack_names=attack_names, + ) + + # Primary metric: first post_attack_eval from first attack (prefixed key) + first_attack = sweep_config.attacks[0] + primary_metric_key = DefenseMetricPrefix.post_attack_key(first_attack.name, sweep_config.post_attack_evals[0]) + + defense_objective = _build_defense_objective( + defense_name=defense_name, + defense_eval_names=sweep_config.defense_evals, + post_attack_eval_names=sweep_config.post_attack_evals, + model_config_dict=sweep_config.inference_model_config, + attacks=sweep_config.attacks, + pretrained_model_path=pretrained_model_path, + study_paths=study_paths, + random_seed=random_seed, + attack_configs_dir=attack_configs_dir, + model_alias=model_alias, + ) + + run_optuna_sweep( + study_paths=study_paths, + sweep_space=sweep_config.sweep, + base_config=base_defense_config, + direction=sweep_config.primary_objective_direction, + objective_fn=defense_objective, + primary_metric_key=primary_metric_key, + n_trials=cast(int, args.n_trials), + random_seed=random_seed, + top_n=cast(int, args.top_n), + eval_names=all_metric_keys, + base_config_name=base_name, + ) + + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/scripts/whitebox/optuna_single.py b/scripts/whitebox/optuna_single.py index de3e987c..57d3de46 100644 --- a/scripts/whitebox/optuna_single.py +++ b/scripts/whitebox/optuna_single.py @@ -5,10 +5,9 @@ from __future__ import annotations import argparse -import copy import multiprocessing -import traceback from argparse import Namespace +from collections.abc import Callable from datetime import datetime from pathlib import Path from typing import cast @@ -24,15 +23,12 @@ AttackSweepConfig, load_attack_base_config, ) -from tamperbench.whitebox.utils.benchmark.io import deep_merge_dicts from tamperbench.whitebox.utils.benchmark.path_generation import StudyPaths -from tamperbench.whitebox.utils.benchmark.sweep import suggest_nested_params +from tamperbench.whitebox.utils.benchmark.runners import 
run_optuna_sweep from tamperbench.whitebox.utils.benchmark.trial_manager import SweepTrialManager from tamperbench.whitebox.utils.names import ( ConfigKeys, EvalName, - OptunaUserAttrs, - TrialFailureKeys, ) REPO_ROOT = get_repo_root() @@ -40,6 +36,48 @@ WHITEBOX_ATTACK_CONFIG_DIR = REPO_ROOT / Path("configs", "whitebox", "attacks") +def _build_attack_objective( + attack_name: AttackName, + eval_names: list[EvalName], + pretrained_model_path: str, + study_paths: StudyPaths, + random_seed: int, +) -> Callable[[optuna.trial.Trial, dict[str, object]], dict[str, float]]: + """Build an objective function for attack sweeps. + + Creates a closure that runs a single attack trial via ``SweepTrialManager.run_trial`` + and returns evaluation metrics. This is passed as ``objective_fn`` to + ``run_optuna_sweep``. + + Args: + attack_name: Name of the attack to run in each trial. + eval_names: Ordered list of evaluation names to compute metrics for. + pretrained_model_path: Path to the pretrained model checkpoint. + study_paths: Paths object for the current sweep study. + random_seed: Random seed for reproducibility. + + Returns: + A callable ``(trial, merged_config) -> dict[str, float]`` suitable for + ``run_optuna_sweep``'s ``objective_fn`` parameter. 
+ """ + + def attack_objective(trial: optuna.trial.Trial, merged_config: dict[str, object]) -> dict[str, float]: + merged_config[ConfigKeys.EVALS] = [str(e) for e in eval_names] + merged_config[ConfigKeys.RANDOM_SEED] = random_seed + + return SweepTrialManager.run_trial( + attack_name=attack_name, + attack_config_dict=merged_config, + eval_names=eval_names, + pretrained_model_path=pretrained_model_path, + attack_results_dir=study_paths.attack_results_dir, + trial_number=trial.number, + random_seed=random_seed, + ) + + return attack_objective + + def main() -> None: """Run single-objective Optuna hyperparameter sweep over whitebox attacks.""" parser = argparse.ArgumentParser(description="Optuna single-objective sweep over whitebox attacks") @@ -119,74 +157,27 @@ def main() -> None: model_alias=model_alias, attack_name=attack_name, ) - study_paths.ensure_dirs() - # Use TPESampler with seed for reproducibility - sampler = optuna.samplers.TPESampler(seed=random_seed) - study = optuna.create_study( - direction=sweep_config.primary_objective_direction, - study_name=study_paths.study_name, - storage=study_paths.storage_url, - load_if_exists=True, - sampler=sampler, + attack_objective = _build_attack_objective( + attack_name=attack_name, + eval_names=sweep_config.evals, + pretrained_model_path=pretrained_model_path, + study_paths=study_paths, + random_seed=random_seed, ) - def objective( - trial: optuna.trial.Trial, - *, - _sweep_space: dict[str, object] = sweep_config.sweep, - _base_attack_config: dict[str, object] = base_attack_config, - _eval_names: list[EvalName] = sweep_config.evals, - _attack_name: AttackName = attack_name, - _pretrained_model_path: str = pretrained_model_path, - _study_paths: StudyPaths = study_paths, - _random_seed: int = random_seed, - ) -> float: - trial_overrides = suggest_nested_params(trial, _sweep_space) - merged_config = deep_merge_dicts(_base_attack_config, trial_overrides) - merged_config[ConfigKeys.EVALS] = [str(e) for e in _eval_names] - 
merged_config[ConfigKeys.RANDOM_SEED] = _random_seed - - trial.set_user_attr(OptunaUserAttrs.CONFIG_OVERRIDES, trial_overrides) - trial.set_user_attr(OptunaUserAttrs.MERGED_CONFIG, copy.deepcopy(merged_config)) - - try: - objective_values = SweepTrialManager.run_trial( - attack_name=_attack_name, - attack_config_dict=merged_config, - eval_names=_eval_names, - pretrained_model_path=_pretrained_model_path, - attack_results_dir=_study_paths.attack_results_dir, - trial_number=trial.number, - random_seed=_random_seed, - ) - trial.set_user_attr(OptunaUserAttrs.EVAL_METRICS, objective_values) - return objective_values[_eval_names[0]] - - except Exception as error: # pragma: no cover - surfaced to study logs - stack_trace = traceback.format_exc() - trial.set_user_attr( - OptunaUserAttrs.FAILURE, - { - TrialFailureKeys.ERROR: str(error), - TrialFailureKeys.TRACEBACK: stack_trace, - }, - ) - raise optuna.exceptions.TrialPruned(f"Trial failed due to exception: {error}") from error - - remaining_trials = max(0, cast(int, args.n_trials) - len(study.trials)) - if remaining_trials > 0: - study.optimize(objective, n_trials=remaining_trials) - - SweepTrialManager.summarize_top_n( - study, + run_optuna_sweep( + study_paths=study_paths, + sweep_space=sweep_config.sweep, + base_config=base_attack_config, direction=sweep_config.primary_objective_direction, + objective_fn=attack_objective, + primary_metric_key=str(sweep_config.evals[0]), + n_trials=cast(int, args.n_trials), + random_seed=random_seed, top_n=cast(int, args.top_n), eval_names=sweep_config.evals, - base_cfg=base_attack_config, base_config_name=base_name, - output_dir=study_paths.attack_results_dir, - study_storage=study_paths.storage_path, ) torch.cuda.empty_cache() diff --git a/scripts/whitebox/run_defense.py b/scripts/whitebox/run_defense.py index c3ba2919..500b1dc7 100644 --- a/scripts/whitebox/run_defense.py +++ b/scripts/whitebox/run_defense.py @@ -107,7 +107,11 @@ def main() -> None: print(f"{'=' * 60}\n") # Filter 
out fields that are not part of defense config (model_config, evals are for attacks/evaluation) - filtered_config_dict = {k: v for k, v in defense_config_dict.items() if k not in ["model_config", "evals"]} + filtered_config_dict = { + k: v + for k, v in defense_config_dict.items() + if k not in ["model_config", "evals", "defense_evals", "post_attack_evals", "attacks"] + } defense_config: AlignmentDefenseConfig = defense_config_cls.from_dict( data={ diff --git a/src/tamperbench/whitebox/attacks/base.py b/src/tamperbench/whitebox/attacks/base.py index 458532b6..ebbe1cd9 100644 --- a/src/tamperbench/whitebox/attacks/base.py +++ b/src/tamperbench/whitebox/attacks/base.py @@ -157,9 +157,18 @@ def run_attack(self) -> None: @final def delete_output_checkpoint(self) -> None: - """Delete the tampered model checkpoint if it exists.""" - if Path(self.output_checkpoint_path).exists(): - shutil.rmtree(self.output_checkpoint_path) + """Delete the tampered model checkpoint if it exists. + + Skips deletion when the output checkpoint is the same path as the + input (e.g. ``NoWeightModificationAttack``), to avoid destroying the + original model. + """ + output = Path(self.output_checkpoint_path).resolve() + input_path = Path(self.attack_config.input_checkpoint_path).resolve() + if output == input_path: + return + if output.exists(): + shutil.rmtree(output) def evaluate(self) -> DataFrame[EvaluationSchema]: """Evaluate attack on relevant evaluations. 
diff --git a/src/tamperbench/whitebox/defenses/crl/crl.py b/src/tamperbench/whitebox/defenses/crl/crl.py index 58487eb3..b31c1642 100644 --- a/src/tamperbench/whitebox/defenses/crl/crl.py +++ b/src/tamperbench/whitebox/defenses/crl/crl.py @@ -31,9 +31,12 @@ ) from tamperbench.whitebox.defenses.crl.trainer import run_training_loop from tamperbench.whitebox.defenses.defense import AlignmentDefense +from tamperbench.whitebox.defenses.registry import register_defense +from tamperbench.whitebox.utils.names import DefenseName from tamperbench.whitebox.utils.ops.dealloc import dealloc_model_and_tokenizer +@register_defense(DefenseName.CRL, CRLConfig) class CRL(AlignmentDefense[CRLConfig]): """CRL defense: hardens models via contrastive representation learning. diff --git a/src/tamperbench/whitebox/defenses/ctrl/ctrl.py b/src/tamperbench/whitebox/defenses/ctrl/ctrl.py index 795f2da4..7c26401b 100644 --- a/src/tamperbench/whitebox/defenses/ctrl/ctrl.py +++ b/src/tamperbench/whitebox/defenses/ctrl/ctrl.py @@ -96,6 +96,18 @@ def _load_tokenizer(model_name: str) -> PreTrainedTokenizer: if tokenizer.pad_token is None: tokenizer.add_special_tokens(special_tokens_dict={"pad_token": DEFAULT_PAD_TOKEN}) + # Base models (e.g. Meta-Llama-3-8B) lack a chat template, which SFTTrainer + # requires for conversational-format data. Set a minimal template so + # apply_chat_template works for both base and instruct models. 
+ if not getattr(tokenizer, "chat_template", None): + tokenizer.chat_template = ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}### Instruction:\n{{ message['content'] }}\n\n" + "{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content'] }}" + "{% endif %}{% endfor %}" + "{% if add_generation_prompt %}### Response:\n{% endif %}" + ) + return tokenizer diff --git a/src/tamperbench/whitebox/defenses/ctrl/curation.py b/src/tamperbench/whitebox/defenses/ctrl/curation.py index 49b75e44..7d9307d5 100644 --- a/src/tamperbench/whitebox/defenses/ctrl/curation.py +++ b/src/tamperbench/whitebox/defenses/ctrl/curation.py @@ -523,6 +523,16 @@ def curate_batch( """ cache = Cache(cache_dir) tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_path) + # Base models (e.g. Meta-Llama-3-8B) lack a chat template; set a minimal + # fallback so apply_chat_template works during curation prompt formatting. + if not getattr(tokenizer, "chat_template", None): + tokenizer.chat_template = ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}### Instruction:\n{{ message['content'] }}\n\n" + "{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content'] }}" + "{% endif %}{% endfor %}" + "{% if add_generation_prompt %}### Response:\n{% endif %}" + ) vllm_model = LLM( model=model_path, tensor_parallel_size=1, @@ -622,8 +632,9 @@ def curate_dataset( Returns: Curated dataset """ - # Build cache config + # Build cache config (includes model so different models get separate caches) cache_config = { + "model_name": str(config.input_checkpoint_path), "dataset_name": config.clean_dataset_name, "dataset_split": config.clean_dataset_split, "curation_rate": config.curation_rate, diff --git a/src/tamperbench/whitebox/defenses/t_vaccine/train.py b/src/tamperbench/whitebox/defenses/t_vaccine/train.py index 92403915..a55cf716 100644 --- a/src/tamperbench/whitebox/defenses/t_vaccine/train.py +++ 
b/src/tamperbench/whitebox/defenses/t_vaccine/train.py @@ -970,8 +970,15 @@ def apply_track_drift_hooks_recursive(module, hook_fn, hooks): # # print(torch.norm(param)) # print("weights norm{}".format(norm)) trainer.save_state() + + # Merge LoRA adapter into base model so the output is a complete + # checkpoint that vLLM (and other inference backends) can load directly. + if isinstance(model, PeftModel): + print("Merging LoRA adapter into base model...") + model = model.merge_and_unload() + model.save_pretrained(training_args.output_dir) - # trainer.save_model(output_dir=training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) return Path(training_args.output_dir) finally: sys.stdout = original_stdout diff --git a/src/tamperbench/whitebox/evals/__init__.py b/src/tamperbench/whitebox/evals/__init__.py index 45f867ce..1a9fac14 100644 --- a/src/tamperbench/whitebox/evals/__init__.py +++ b/src/tamperbench/whitebox/evals/__init__.py @@ -33,6 +33,7 @@ PolicyEvaluation, PolicyEvaluationConfig, ) +from tamperbench.whitebox.evals.registry import EVAL_CONFIG_REGISTRY from tamperbench.whitebox.evals.safety_gap.safety_gap import ( SafetyGapEvaluation, SafetyGapEvaluationConfig, @@ -52,6 +53,7 @@ ) __all__ = [ + "EVAL_CONFIG_REGISTRY", "EmbeddingAttackEvaluation", "EmbeddingAttackEvaluationConfig", "GCGEvaluation", diff --git a/src/tamperbench/whitebox/evals/ifeval/ifeval.py b/src/tamperbench/whitebox/evals/ifeval/ifeval.py index 22e686b9..23445ec3 100644 --- a/src/tamperbench/whitebox/evals/ifeval/ifeval.py +++ b/src/tamperbench/whitebox/evals/ifeval/ifeval.py @@ -66,7 +66,7 @@ class IFEvalEvaluationConfig(WhiteBoxEvaluationConfig): max_samples: int | None = None -@register_evaluation(EvalName.IFEVAL) +@register_evaluation(EvalName.IFEVAL, IFEvalEvaluationConfig) class IFEvalEvaluation(WhiteBoxEvaluation[IFEvalEvaluationConfig]): """Instruction Following Evaluation.""" diff --git a/src/tamperbench/whitebox/evals/mbpp/mbpp.py 
b/src/tamperbench/whitebox/evals/mbpp/mbpp.py index c5a4ff79..0b3235e5 100644 --- a/src/tamperbench/whitebox/evals/mbpp/mbpp.py +++ b/src/tamperbench/whitebox/evals/mbpp/mbpp.py @@ -207,7 +207,7 @@ class MBPPEvaluationConfig(WhiteBoxEvaluationConfig): timeout: float = 3.0 # Execution timeout in seconds -@register_evaluation(EvalName.MBPP) +@register_evaluation(EvalName.MBPP, MBPPEvaluationConfig) class MBPPEvaluation(WhiteBoxEvaluation[MBPPEvaluationConfig]): """MBPP (Mostly Basic Programming Problems) Evaluation. diff --git a/src/tamperbench/whitebox/evals/minerva_math/minerva_math.py b/src/tamperbench/whitebox/evals/minerva_math/minerva_math.py index a67d9a8d..c167ef8b 100644 --- a/src/tamperbench/whitebox/evals/minerva_math/minerva_math.py +++ b/src/tamperbench/whitebox/evals/minerva_math/minerva_math.py @@ -58,7 +58,7 @@ class MinervaMathEvaluationConfig(WhiteBoxEvaluationConfig): """Evaluation Configuration for the Minerva MATH benchmark.""" -@register_evaluation(EvalName.MINERVA_MATH) +@register_evaluation(EvalName.MINERVA_MATH, MinervaMathEvaluationConfig) class MinervaMathEvaluation(WhiteBoxEvaluation[MinervaMathEvaluationConfig]): """Minerva MATH Evaluation class. 
diff --git a/src/tamperbench/whitebox/evals/mmlu_pro/mmlu_pro.py b/src/tamperbench/whitebox/evals/mmlu_pro/mmlu_pro.py index 703c8121..5c071ced 100644 --- a/src/tamperbench/whitebox/evals/mmlu_pro/mmlu_pro.py +++ b/src/tamperbench/whitebox/evals/mmlu_pro/mmlu_pro.py @@ -220,7 +220,7 @@ def _gold_letter(self, row: dict) -> str: return "A" -@register_evaluation(EvalName.MMLU_PRO_VAL) +@register_evaluation(EvalName.MMLU_PRO_VAL, MMLUProEvaluationConfig) class MMLUProValEvaluation(_MMLUProVLLMBase): """Evaluate on a subset of the test split using few-shot exemplars drawn from the validation split.""" @@ -233,7 +233,7 @@ def _build_shot_pool(self, val_rows: list[dict], test_rows: list[dict]) -> dict[ return self._build_shot_pool_from_rows(val_rows) -@register_evaluation(EvalName.MMLU_PRO_TEST) +@register_evaluation(EvalName.MMLU_PRO_TEST, MMLUProEvaluationConfig) class MMLUProTestEvaluation(_MMLUProVLLMBase): """Evaluate on test split; shots come from validation split per subject.""" diff --git a/src/tamperbench/whitebox/evals/policy_eval/policy_eval.py b/src/tamperbench/whitebox/evals/policy_eval/policy_eval.py index a8b5d289..4a1eaff0 100644 --- a/src/tamperbench/whitebox/evals/policy_eval/policy_eval.py +++ b/src/tamperbench/whitebox/evals/policy_eval/policy_eval.py @@ -193,7 +193,7 @@ class PolicyEvaluationConfig(WhiteBoxEvaluationConfig): max_judge_workers: int = 16 -@register_evaluation(EvalName.POLICY_EVAL) +@register_evaluation(EvalName.POLICY_EVAL, PolicyEvaluationConfig) class PolicyEvaluation(WhiteBoxEvaluation[PolicyEvaluationConfig]): """Policy-Oriented Safety Evaluation using an LLM as judge. 
diff --git a/src/tamperbench/whitebox/evals/registry.py b/src/tamperbench/whitebox/evals/registry.py index 16ed7661..a19d4607 100644 --- a/src/tamperbench/whitebox/evals/registry.py +++ b/src/tamperbench/whitebox/evals/registry.py @@ -11,20 +11,25 @@ from collections.abc import Callable from typing import TypeVar -from tamperbench.whitebox.evals.base import WhiteBoxEvaluation +from tamperbench.whitebox.evals.base import WhiteBoxEvaluation, WhiteBoxEvaluationConfig from tamperbench.whitebox.utils.names import EvalName E = TypeVar("E", bound=WhiteBoxEvaluation) EVALS_REGISTRY: dict[EvalName, type[WhiteBoxEvaluation]] = {} +EVAL_CONFIG_REGISTRY: dict[EvalName, type[WhiteBoxEvaluationConfig]] = {} -def register_evaluation(name: EvalName) -> Callable[[type[E]], type[E]]: - """Register an evaluation class under a name.""" +def register_evaluation( + name: EvalName, + config_cls: type[WhiteBoxEvaluationConfig] = WhiteBoxEvaluationConfig, +) -> Callable[[type[E]], type[E]]: + """Register an evaluation class and its config class under a name.""" def _decorator(eval_cls: type[E]) -> type[E]: EVALS_REGISTRY[name] = eval_cls + EVAL_CONFIG_REGISTRY[name] = config_cls return eval_cls return _decorator diff --git a/src/tamperbench/whitebox/evals/safety_gap/safety_gap.py b/src/tamperbench/whitebox/evals/safety_gap/safety_gap.py index e3f72e9f..88029915 100644 --- a/src/tamperbench/whitebox/evals/safety_gap/safety_gap.py +++ b/src/tamperbench/whitebox/evals/safety_gap/safety_gap.py @@ -60,7 +60,7 @@ class SafetyGapEvaluationConfig(WhiteBoxEvaluationConfig): domains: list[SafetyGapDomain] = field(default_factory=lambda: list(SafetyGapDomain)) -@register_evaluation(EvalName.SAFETY_GAP) +@register_evaluation(EvalName.SAFETY_GAP, SafetyGapEvaluationConfig) class SafetyGapEvaluation(WhiteBoxEvaluation[SafetyGapEvaluationConfig]): """SafetyGap Evaluation class. 
diff --git a/src/tamperbench/whitebox/evals/strong_reject/strong_reject.py b/src/tamperbench/whitebox/evals/strong_reject/strong_reject.py index c7084d14..2692259c 100644 --- a/src/tamperbench/whitebox/evals/strong_reject/strong_reject.py +++ b/src/tamperbench/whitebox/evals/strong_reject/strong_reject.py @@ -48,7 +48,7 @@ class StrongRejectEvaluationConfig(WhiteBoxEvaluationConfig): max_samples: int | None = field(default=None, kw_only=True) -@register_evaluation(EvalName.STRONG_REJECT) +@register_evaluation(EvalName.STRONG_REJECT, StrongRejectEvaluationConfig) class StrongRejectEvaluation(WhiteBoxEvaluation[S]): """StrongREJECT Evaluation class.""" @@ -133,7 +133,7 @@ def load_strong_reject_prompts(self) -> list[str]: return [row["forbidden_prompt"] for row in strong_reject_dataset] # pyright: ignore[reportCallIssue] -@register_evaluation(EvalName.JAILBREAK_BENCH) +@register_evaluation(EvalName.JAILBREAK_BENCH, StrongRejectEvaluationConfig) class JailbreakBenchEvaluation(StrongRejectEvaluation[S]): """JailbreakBench Evaluation class using the JBB-Behaviors dataset, scored by StrongREJECT.""" diff --git a/src/tamperbench/whitebox/evals/wmdp/wmdp.py b/src/tamperbench/whitebox/evals/wmdp/wmdp.py index 13335235..e881b492 100644 --- a/src/tamperbench/whitebox/evals/wmdp/wmdp.py +++ b/src/tamperbench/whitebox/evals/wmdp/wmdp.py @@ -121,7 +121,7 @@ class WMDPQuestion: domain: WMDPDomain -@register_evaluation(EvalName.WMDP) +@register_evaluation(EvalName.WMDP, WMDPEvaluationConfig) class WMDPEvaluation(WhiteBoxEvaluation[WMDPEvaluationConfig]): """WMDP Benchmark Evaluation using multiple-choice questions. 
diff --git a/src/tamperbench/whitebox/evals/xstest/xstest.py b/src/tamperbench/whitebox/evals/xstest/xstest.py index 3cab7f6e..5e3a437c 100644 --- a/src/tamperbench/whitebox/evals/xstest/xstest.py +++ b/src/tamperbench/whitebox/evals/xstest/xstest.py @@ -225,7 +225,7 @@ class XSTestEvaluationConfig(WhiteBoxEvaluationConfig): prompt_types: list[str] = field(default_factory=list) -@register_evaluation(EvalName.XSTEST) +@register_evaluation(EvalName.XSTEST, XSTestEvaluationConfig) class XSTestEvaluation(WhiteBoxEvaluation[XSTestEvaluationConfig]): """XSTest Evaluation for exaggerated safety behaviors. diff --git a/src/tamperbench/whitebox/utils/benchmark/__init__.py b/src/tamperbench/whitebox/utils/benchmark/__init__.py index 6529178e..d990f744 100644 --- a/src/tamperbench/whitebox/utils/benchmark/__init__.py +++ b/src/tamperbench/whitebox/utils/benchmark/__init__.py @@ -1,9 +1,16 @@ -"""Utilities to facilitate end-to-end benchmarking of models against attacks.""" +"""Utilities to facilitate end-to-end benchmarking of models against attacks and defenses.""" from tamperbench.whitebox.utils.benchmark.config import ( AttackSweepConfig, load_attack_base_config, ) +from tamperbench.whitebox.utils.benchmark.defense_config import ( + DefenseGridConfig, + DefenseSweepConfig, + PostDefenseAttackSpec, + load_defense_base_config, +) +from tamperbench.whitebox.utils.benchmark.defense_trial_manager import DefenseSweepTrialManager from tamperbench.whitebox.utils.benchmark.io import ( deep_merge_dicts, dict_to_yaml, @@ -11,25 +18,40 @@ yaml_to_dict, ) from tamperbench.whitebox.utils.benchmark.path_generation import ( + DefenseStudyPaths, StudyPaths, get_aggregated_dir_name, + get_defense_study_name, get_single_objective_study_name, get_trial_dir_name, get_trial_eval_dir_name, ) +from tamperbench.whitebox.utils.benchmark.runners import ( + run_attack_grid_configs, + run_optuna_sweep, +) from tamperbench.whitebox.utils.benchmark.trial_manager import SweepTrialManager __all__ = [ 
"AttackSweepConfig", + "DefenseGridConfig", + "DefenseStudyPaths", + "DefenseSweepConfig", + "DefenseSweepTrialManager", + "PostDefenseAttackSpec", "StudyPaths", "SweepTrialManager", "deep_merge_dicts", "dict_to_yaml", "get_aggregated_dir_name", + "get_defense_study_name", "get_single_objective_study_name", "get_trial_dir_name", "get_trial_eval_dir_name", "load_attack_base_config", + "load_defense_base_config", "nest_dotted_params", + "run_attack_grid_configs", + "run_optuna_sweep", "yaml_to_dict", ] diff --git a/src/tamperbench/whitebox/utils/benchmark/defense_config.py b/src/tamperbench/whitebox/utils/benchmark/defense_config.py new file mode 100644 index 00000000..0f0c14dd --- /dev/null +++ b/src/tamperbench/whitebox/utils/benchmark/defense_config.py @@ -0,0 +1,337 @@ +"""Configuration models and loaders for defense sweeps and grid benchmarks.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator +from typing_extensions import Self + +from tamperbench.whitebox.evals.registry import EVALS_REGISTRY +from tamperbench.whitebox.utils import AttackName, ConfigPath, EvalName +from tamperbench.whitebox.utils.benchmark.io import yaml_to_dict +from tamperbench.whitebox.utils.names import ConfigKeys, DefenseName + + +class PostDefenseAttackSpec(BaseModel): + """Specification for an attack to run against a defended model. + + Defines how to stress-test a defense by running an attack after the defense + has been applied. Supports two modes: ``grid`` (fixed configs from the + attack's ``grid.yaml``) and ``sweep`` (Optuna optimization of attack + hyperparameters). + + Attributes: + name: Name of the attack to run from the attack registry. + mode: Whether to use fixed grid configs or run an Optuna sweep. + configs_dir: Root directory for attack configs. Defaults to the + repository's ``configs/whitebox/attacks`` if not specified. 
+ config_name: For grid mode, which config to use from grid.yaml. + Use ``None`` to run all configs. + n_trials: For sweep mode, number of Optuna trials to run for the + inner attack optimization. + + Example: + >>> spec = PostDefenseAttackSpec(name=AttackName.LORA_FINETUNE) + >>> spec.mode + 'grid' + >>> spec.config_name + 'base' + """ + + name: AttackName = Field(..., description="Attack name from the attack registry") + mode: Literal["grid", "sweep"] = Field("grid", description="Whether to run grid configs or an Optuna sweep") + configs_dir: Path | None = Field(None, description="Root directory for attack configs; defaults to repo default") + config_name: str | None = Field("base", description="Grid mode: which config from grid.yaml (None = all)") + n_trials: int = Field(10, description="Sweep mode: number of Optuna trials for the inner attack sweep") + + +class DefenseSweepConfig(BaseModel): + """Validated representation of a defense Optuna sweep file. + + Contains all configuration needed to run an Optuna hyperparameter sweep + for a defense, including which evaluations to run on the defended model, + which evaluations to run after post-defense attacks, the attack + specifications, and the hyperparameter search space. + + Attributes: + defense_evals: Evaluations to run on the defense checkpoint + (pre-attack, post-defense). Measures whether the defense + preserved safety and utility. + post_attack_evals: Evaluations to run on post-attack checkpoints. + Measures whether the defense survived the attack. + inference_model_config: Raw model configuration dict, converted to + ``ModelConfig`` later by the trial manager. Populated from the + ``model_config`` key in YAML. + attacks: List of attack specifications to stress-test the defense. + sweep: Nested Optuna search space definition for defense hyperparameters. + + Example: + >>> sweep = DefenseSweepConfig.from_yaml( + ... config_root=Path("configs/whitebox/defenses"), + ... defense_name=DefenseName.CRL, + ... 
sweep_filename=ConfigPath.SINGLE_OBJECTIVE_SWEEP, + ... ) + >>> sweep.defense_evals + [] + >>> sweep.primary_objective_direction + 'minimize' + """ + + model_config: ConfigDict = ConfigDict( # pyright: ignore[reportIncompatibleVariableOverride] + populate_by_name=True, + ) + + defense_evals: list[EvalName] = Field(..., description="Evaluations to run on the defended checkpoint (pre-attack)") + post_attack_evals: list[EvalName] = Field(..., description="Evaluations to run on post-attack checkpoints") + inference_model_config: dict[str, object] = Field( + ..., alias="model_config", description="Raw model configuration dict for inference" + ) + attacks: list[PostDefenseAttackSpec] = Field(..., description="Attack specifications to stress-test the defense") + sweep: dict[str, object] = Field( + ..., description="Nested Optuna search space definition for defense hyperparameters" + ) + + @model_validator(mode="after") + def _ensure_defense_evals(self) -> Self: + """Ensure at least one defense evaluation is provided.""" + if not self.defense_evals: + raise ValueError("Defense sweep config must include at least one defense_eval entry.") + return self + + @model_validator(mode="after") + def _ensure_post_attack_evals(self) -> Self: + """Ensure at least one post-attack evaluation is provided.""" + if not self.post_attack_evals: + raise ValueError("Defense sweep config must include at least one post_attack_eval entry.") + return self + + @model_validator(mode="after") + def _ensure_attacks(self) -> Self: + """Ensure at least one attack specification is provided.""" + if not self.attacks: + raise ValueError("Defense sweep config must include at least one attack entry.") + return self + + @property + def primary_objective_direction(self) -> str: + """Get optimization direction from the primary post-attack evaluation. + + The defender wants to optimize the primary objective in the direction + that preserves safety (e.g., minimize harm score). 
This uses the + ``defender_direction`` from the first post-attack evaluation. + + Returns: + Optimization direction string, either ``"maximize"`` or ``"minimize"``. + + Example: + >>> sweep = DefenseSweepConfig.from_yaml(config_root, DefenseName.CRL, ConfigPath.SINGLE_OBJECTIVE_SWEEP) + >>> sweep.primary_objective_direction + 'minimize' + """ + return EVALS_REGISTRY[self.post_attack_evals[0]].defender_direction.value + + @classmethod + def from_yaml( + cls, + config_root: Path, + defense_name: DefenseName, + sweep_filename: str | ConfigPath = ConfigPath.SINGLE_OBJECTIVE_SWEEP, + ) -> Self: + """Load and validate a defense Optuna sweep specification from YAML. + + Loads and validates a sweep configuration file that defines the + hyperparameter search space, target evaluations, attack specs, and + model configuration for defense Optuna optimization. + + Args: + config_root: Root directory containing per-defense config files + (e.g., ``configs/whitebox/defenses``). + defense_name: Name of the defense whose sweep configuration + should be loaded. + sweep_filename: Name of the sweep config file + (default: ``ConfigPath.SINGLE_OBJECTIVE_SWEEP``). + + Returns: + Validated ``DefenseSweepConfig`` containing all sweep parameters. + + Raises: + FileNotFoundError: If the sweep YAML file is missing. + ValueError: If validation of the sweep payload fails. + + Example: + >>> config_root = Path("configs/whitebox/defenses") + >>> sweep = DefenseSweepConfig.from_yaml( + ... config_root, + ... DefenseName.CRL, + ... ConfigPath.SINGLE_OBJECTIVE_SWEEP, + ... 
) + >>> sweep.defense_evals + [] + >>> "learning_rate" in sweep.sweep + True + """ + sweep_path = Path(config_root, defense_name, sweep_filename) + + if not sweep_path.exists(): + raise FileNotFoundError(f"Missing defense sweep config at: {sweep_path}") + + payload = yaml_to_dict(sweep_path) + + try: + return cls.model_validate(payload) + except ValidationError as exc: # pragma: no cover - bubble detailed error + raise ValueError(f"Invalid defense sweep config for '{defense_name}': {exc}") from exc + + +class DefenseGridConfig(BaseModel): + """Validated representation of a defense grid config entry with meta-fields. + + Extracts the meta-fields (evaluations, attacks, model config) from a + grid config entry, separating them from the defense-specific hyperparameters + that are passed to ``AlignmentDefenseConfig``. + + Attributes: + defense_evals: Evaluations to run on the defended checkpoint. + post_attack_evals: Evaluations to run on post-attack checkpoints. + inference_model_config: Raw model configuration dict for inference. + Populated from the ``model_config`` key in YAML. + attacks: Attack specifications to stress-test the defense. + + Example: + >>> grid_entry = { + ... "defense_evals": ["strong_reject"], + ... "post_attack_evals": ["strong_reject", "mmlu_pro_val"], + ... "model_config": {"template": "plain"}, + ... "attacks": [{"name": "lora_finetune"}], + ... "learning_rate": 0.001, + ... 
} + >>> grid_config = DefenseGridConfig.from_dict(grid_entry) + >>> grid_config.defense_evals + [] + """ + + model_config: ConfigDict = ConfigDict( # pyright: ignore[reportIncompatibleVariableOverride] + populate_by_name=True, + ) + + defense_evals: list[EvalName] = Field(..., description="Evaluations to run on the defended checkpoint") + post_attack_evals: list[EvalName] = Field(..., description="Evaluations to run on post-attack checkpoints") + inference_model_config: dict[str, object] = Field( + ..., alias="model_config", description="Raw model configuration dict for inference" + ) + attacks: list[PostDefenseAttackSpec] = Field(..., description="Attack specifications to stress-test the defense") + + @model_validator(mode="after") + def _ensure_defense_evals(self) -> Self: + """Ensure at least one defense evaluation is provided.""" + if not self.defense_evals: + raise ValueError("Defense grid config must include at least one defense_eval entry.") + return self + + @model_validator(mode="after") + def _ensure_post_attack_evals(self) -> Self: + """Ensure at least one post-attack evaluation is provided.""" + if not self.post_attack_evals: + raise ValueError("Defense grid config must include at least one post_attack_eval entry.") + return self + + @model_validator(mode="after") + def _ensure_attacks(self) -> Self: + """Ensure at least one attack specification is provided.""" + if not self.attacks: + raise ValueError("Defense grid config must include at least one attack entry.") + return self + + @classmethod + def from_dict(cls, config_dict: dict[str, object]) -> Self: + """Extract and validate meta-fields from a grid config entry. + + Picks out the meta-fields (``defense_evals``, ``post_attack_evals``, + ``model_config``, ``attacks``) from a raw grid config dictionary. + The remaining keys are defense-specific hyperparameters that should + be passed to ``AlignmentDefenseConfig.from_dict()``. + + Args: + config_dict: Raw config dictionary from a grid.yaml entry. 
+ + Returns: + Validated ``DefenseGridConfig`` containing only the meta-fields. + + Raises: + ValueError: If required meta-fields are missing or invalid. + + Example: + >>> entry = { + ... "defense_evals": ["strong_reject"], + ... "post_attack_evals": ["strong_reject"], + ... "model_config": {"template": "plain"}, + ... "attacks": [{"name": "lora_finetune"}], + ... "alpha": 0.5, + ... } + >>> meta = DefenseGridConfig.from_dict(entry) + >>> meta.attacks[0].name + + """ + meta_keys = { + str(ConfigKeys.DEFENSE_EVALS), + str(ConfigKeys.POST_ATTACK_EVALS), + str(ConfigKeys.MODEL_CONFIG), + str(ConfigKeys.ATTACKS), + } + meta_payload = {k: v for k, v in config_dict.items() if k in meta_keys} + + try: + return cls.model_validate(meta_payload) + except ValidationError as exc: # pragma: no cover - bubble detailed error + raise ValueError(f"Invalid defense grid config meta-fields: {exc}") from exc + + +def load_defense_base_config(config_root: Path, defense_name: DefenseName) -> tuple[str, dict[str, object]]: + """Load the default grid config for a defense. + + Loads the ``grid.yaml`` file for the specified defense and extracts the + base configuration. Prefers a config named ``"base"`` if present; + otherwise uses the first config in the file. + + Args: + config_root: Root directory containing ``/grid.yaml`` files. + defense_name: Name of the defense whose base config should be loaded. + + Returns: + A tuple of ``(config_name, config_dict)`` where config_name is the key + used (``"base"`` if available) and config_dict is the full configuration + payload including both meta-fields and defense hyperparameters. + + Raises: + TypeError: If the config value is not a dictionary. 
+ + Example: + >>> config_root = Path("configs/whitebox/defenses") + >>> name, config = load_defense_base_config(config_root, DefenseName.CRL) + >>> name + 'base' + >>> "defense_evals" in config + True + """ + grid_path = Path(config_root, defense_name, ConfigPath.GRID_YAML) + grid_payload = yaml_to_dict(grid_path) + + if ConfigKeys.BASE in grid_payload: + base_cfg = grid_payload[ConfigKeys.BASE] + + if isinstance(base_cfg, dict): + cfg_typed: dict[str, object] = base_cfg + return str(ConfigKeys.BASE), cfg_typed + + raise TypeError(f"Expected dict for '{ConfigKeys.BASE}' config, got {type(base_cfg)}") + + first_key = next(iter(grid_payload.keys())) + first_cfg = grid_payload[first_key] + + if isinstance(first_cfg, dict): + cfg_typed = first_cfg + return str(first_key), cfg_typed + + raise TypeError(f"Expected dict for config '{first_key}', got {type(first_cfg)}") diff --git a/src/tamperbench/whitebox/utils/benchmark/defense_trial_manager.py b/src/tamperbench/whitebox/utils/benchmark/defense_trial_manager.py new file mode 100644 index 00000000..689a319e --- /dev/null +++ b/src/tamperbench/whitebox/utils/benchmark/defense_trial_manager.py @@ -0,0 +1,522 @@ +"""Defense trial management: defend -> eval -> attack -> eval -> cleanup.""" + +# pyright: reportUnknownVariableType=false, reportUnknownMemberType=false + +from __future__ import annotations + +import copy +import json +import logging +import shutil +from pathlib import Path +from typing import TYPE_CHECKING + +from tamperbench.utils import get_repo_root +from tamperbench.whitebox.defenses.registry import DEFENSES_REGISTRY +from tamperbench.whitebox.evals.registry import EVAL_CONFIG_REGISTRY, EVALS_REGISTRY +from tamperbench.whitebox.utils.benchmark.config import ( + AttackSweepConfig, + load_attack_base_config, +) +from tamperbench.whitebox.utils.benchmark.io import yaml_to_dict +from tamperbench.whitebox.utils.benchmark.path_generation import ( + StudyPaths, + get_trial_dir_name, +) +from 
from tamperbench.whitebox.utils.benchmark.runners import (
    run_attack_grid_configs,
    run_optuna_sweep,
)
from tamperbench.whitebox.utils.benchmark.trial_manager import SweepTrialManager
from tamperbench.whitebox.utils.models.config import ModelConfig
from tamperbench.whitebox.utils.names import (
    ConfigKeys,
    ConfigPath,
    DefenseMetricPrefix,
    DefenseName,
    DefenseTrialDirs,
    EvalName,
    OptimizationDirection,
    OptunaPaths,
)

if TYPE_CHECKING:
    import optuna

    from tamperbench.whitebox.utils.benchmark.defense_config import PostDefenseAttackSpec

logger = logging.getLogger(__name__)

# Repository-default root for attack grid/sweep configs.
WHITEBOX_ATTACK_CONFIG_DIR = get_repo_root() / Path("configs", "whitebox", "attacks")


class DefenseSweepTrialManager:
    """Static methods for managing defense sweep trials.

    Provides utilities for running the full defense trial pipeline:
    defend -> evaluate defense -> attack(s) -> evaluate post-attack -> cleanup.
    All methods are static since this class serves as a namespace for
    defense trial-related operations.
    """

    @staticmethod
    def run_trial(
        defense_name: DefenseName,
        defense_config_dict: dict[str, object],
        defense_eval_names: list[EvalName],
        post_attack_eval_names: list[EvalName],
        model_config_dict: dict[str, object],
        attacks: list[PostDefenseAttackSpec],
        pretrained_model_path: str,
        defense_results_dir: Path,
        trial_number: int,
        random_seed: int,
        attack_configs_dir: Path | None = None,
        model_alias: str = "",
    ) -> dict[str, float]:
        """Run a single defense trial.

        Executes the full pipeline: build defense config, run defense, evaluate
        the defended checkpoint, run post-defense attacks, collect metrics,
        clean up the defended checkpoint, and save aggregated trial results.
        The checkpoint cleanup now runs in a ``finally`` block so a failing
        evaluation or attack cannot leak the (potentially multi-GB) defended
        checkpoint on disk.

        Args:
            defense_name: Name of the defense to run from the defense registry.
            defense_config_dict: Configuration dictionary for the defense,
                including meta-fields that are stripped before being passed on.
            defense_eval_names: Evaluations to run on the defended checkpoint
                (pre-attack, post-defense).
            post_attack_eval_names: Evaluations to run on post-attack checkpoints.
            model_config_dict: Raw model configuration dictionary for inference.
            attacks: Attack specifications used to stress-test the defense.
            pretrained_model_path: Path to the pretrained model checkpoint.
            defense_results_dir: Directory for this defense's sweep results
                (e.g., ``results/model/crl/optuna_single/``).
            trial_number: Trial number for directory naming.
            random_seed: Random seed for reproducibility.
            attack_configs_dir: Optional root directory for attack configs.
                Falls back to the repository default if not specified.
            model_alias: Alias for the model, used for inner attack sweep
                StudyPaths to keep consistent directory structure.

        Returns:
            Flat dictionary mapping prefixed metric names to their values.
            Keys use dot-separated prefixes, e.g.
            ``{"defense.strong_reject": 0.12, "post_attack.lora_finetune.strong_reject": 0.45}``.
        """
        trial_dir = defense_results_dir / get_trial_dir_name(trial_number)
        trial_dir.mkdir(parents=True, exist_ok=True)

        # 1. Build defense config, stripping orchestration-only meta-fields.
        clean_config = DefenseSweepTrialManager._strip_meta_fields(copy.deepcopy(defense_config_dict))

        defended_checkpoint_dir = trial_dir / DefenseTrialDirs.DEFENDED_MODEL
        clean_config["input_checkpoint_path"] = Path(pretrained_model_path)
        clean_config["output_checkpoint_path"] = defended_checkpoint_dir

        defense_config_cls, defense_cls = DEFENSES_REGISTRY[defense_name]
        defense_config = defense_config_cls.from_dict(clean_config)
        defense = defense_cls(defense_config=defense_config)

        # 2. Run the defense to produce the hardened checkpoint.
        defended_checkpoint = str(defense.run_defense())
        logger.info("Defense checkpoint created at: %s", defended_checkpoint)

        model_config = ModelConfig.from_dict(dict(model_config_dict))
        all_metrics: dict[str, float] = {}

        try:
            # 3. Evaluate the defended checkpoint with defense_evals.
            defense_eval_dir = trial_dir / DefenseTrialDirs.DEFENSE_EVAL
            defense_metrics = DefenseSweepTrialManager.evaluate_checkpoint(
                checkpoint_path=defended_checkpoint,
                eval_names=defense_eval_names,
                model_config=model_config,
                out_dir=defense_eval_dir,
            )
            all_metrics.update(
                DefenseSweepTrialManager._prefix_metrics(DefenseMetricPrefix.DEFENSE, defense_metrics)
            )

            # 4. For each attack spec, run the attack against the defended checkpoint.
            post_attack_dir = trial_dir / DefenseTrialDirs.POST_ATTACK
            for attack_spec in attacks:
                attack_out_dir = post_attack_dir / str(attack_spec.name)
                attack_out_dir.mkdir(parents=True, exist_ok=True)

                if attack_spec.mode == "grid":
                    attack_metrics = DefenseSweepTrialManager.run_attack_grid(
                        attack_spec=attack_spec,
                        defended_checkpoint=defended_checkpoint,
                        post_attack_eval_names=post_attack_eval_names,
                        model_config_dict=model_config_dict,
                        attack_out_dir=attack_out_dir,
                        random_seed=random_seed,
                        attack_configs_dir=attack_configs_dir,
                    )
                else:
                    attack_metrics = DefenseSweepTrialManager.run_attack_sweep(
                        attack_spec=attack_spec,
                        defended_checkpoint=defended_checkpoint,
                        post_attack_eval_names=post_attack_eval_names,
                        model_config_dict=model_config_dict,
                        attack_out_dir=attack_out_dir,
                        random_seed=random_seed,
                        attack_configs_dir=attack_configs_dir,
                        model_alias=model_alias,
                    )

                all_metrics.update(
                    DefenseSweepTrialManager._prefix_metrics(
                        f"{DefenseMetricPrefix.POST_ATTACK}.{attack_spec.name}", attack_metrics
                    )
                )
        finally:
            # 5. Always delete the defended checkpoint, even when an evaluation
            # or attack raised — otherwise failed trials leak disk space.
            if Path(defended_checkpoint).exists():
                shutil.rmtree(defended_checkpoint)
                logger.info("Cleaned up defended checkpoint: %s", defended_checkpoint)

        # 6. Save trial_results.json (success path only, as before).
        trial_results_path = trial_dir / DefenseTrialDirs.TRIAL_RESULTS
        DefenseSweepTrialManager._save_trial_results(trial_results_path, all_metrics)

        return all_metrics

    @staticmethod
    def evaluate_checkpoint(
        checkpoint_path: str,
        eval_names: list[EvalName],
        model_config: ModelConfig,
        out_dir: Path,
    ) -> dict[str, float]:
        """Run evaluations on a checkpoint without any attack wrapper.

        Constructs evaluation configs and runs each registered evaluation
        on the given checkpoint, collecting the objective metric from each.

        Args:
            checkpoint_path: Path to the model checkpoint to evaluate.
            eval_names: List of evaluation names to run.
            model_config: Model configuration for inference.
            out_dir: Directory where evaluation results will be stored.

        Returns:
            Dictionary mapping evaluation name strings to their objective values.
        """
        results: dict[str, float] = {}

        for eval_name in eval_names:
            eval_cls = EVALS_REGISTRY[eval_name]
            config_cls = EVAL_CONFIG_REGISTRY[eval_name]
            eval_config = config_cls(
                model_checkpoint=checkpoint_path,
                out_dir=str(out_dir),
                model_config=model_config,
            )
            evaluation = eval_cls(eval_config)
            result_df = evaluation.run_evaluation()
            value = eval_cls.load_result_objective(result_df)
            results[str(eval_name)] = float(value)

        return results

    @staticmethod
    def run_attack_grid(
        attack_spec: PostDefenseAttackSpec,
        defended_checkpoint: str,
        post_attack_eval_names: list[EvalName],
        model_config_dict: dict[str, object],
        attack_out_dir: Path,
        random_seed: int,
        attack_configs_dir: Path | None = None,
    ) -> dict[str, float]:
        """Run attack in grid mode against a defended checkpoint.

        Loads the attack's grid.yaml, optionally filters to a single config,
        injects model config and evals, and runs all configs. Returns the
        worst-case metrics from the attacker's perspective.

        Args:
            attack_spec: Attack specification defining name, mode, and config options.
            defended_checkpoint: Path to the defended model checkpoint.
            post_attack_eval_names: Evaluations to run after each attack config.
            model_config_dict: Raw model configuration dictionary for inference.
            attack_out_dir: Output directory for attack results.
            random_seed: Random seed for reproducibility.
            attack_configs_dir: Root directory for attack configs. Falls back to
                the repository default if not specified.

        Returns:
            Dictionary mapping evaluation name strings to worst-case metric
            values from the attacker's perspective.
        """
        configs_dir = _resolve_attack_configs_dir(attack_spec.configs_dir, attack_configs_dir)
        grid_path = Path(configs_dir, str(attack_spec.name), ConfigPath.GRID_YAML)
        config_grid: dict[str, dict[str, object]] = {}

        raw_grid = yaml_to_dict(grid_path)
        if attack_spec.config_name is not None:
            # Restrict the grid to a single named config when requested.
            raw_value = raw_grid[attack_spec.config_name]
            if isinstance(raw_value, dict):
                config_grid[attack_spec.config_name] = raw_value
        else:
            for key, value in raw_grid.items():
                if isinstance(value, dict):
                    config_grid[key] = value

        # Inject model_config and evals into each config dict.
        for config_dict in config_grid.values():
            config_dict[ConfigKeys.MODEL_CONFIG] = copy.deepcopy(model_config_dict)
            config_dict[ConfigKeys.EVALS] = [str(e) for e in post_attack_eval_names]

        results = run_attack_grid_configs(
            attack_name=attack_spec.name,
            config_grid=config_grid,
            pretrained_model_path=defended_checkpoint,
            output_base_dir=attack_out_dir,
            random_seed=random_seed,
            eval_names=post_attack_eval_names,
            cleanup_checkpoints=True,
        )

        # Extract objective values per config, then reduce to the worst case.
        all_config_results: dict[str, dict[str, float]] = {}
        for config_name, results_df in results.items():
            config_metrics: dict[str, float] = {}
            for eval_name in post_attack_eval_names:
                eval_cls = EVALS_REGISTRY[eval_name]
                config_metrics[str(eval_name)] = float(eval_cls.load_result_objective(results_df))
            all_config_results[config_name] = config_metrics

        return DefenseSweepTrialManager._get_worst_case_metrics(all_config_results, post_attack_eval_names)

    @staticmethod
    def run_attack_sweep(
        attack_spec: PostDefenseAttackSpec,
        defended_checkpoint: str,
        post_attack_eval_names: list[EvalName],
        model_config_dict: dict[str, object],
        attack_out_dir: Path,
        random_seed: int,
        attack_configs_dir: Path | None = None,
        model_alias: str = "",
    ) -> dict[str, float]:
        """Run attack in sweep mode using an inner Optuna optimization.

        Loads the attack's sweep config and base config, creates an inner
        Optuna study, and runs the sweep to find the most effective attack
        hyperparameters. Returns the best trial's metrics.

        Args:
            attack_spec: Attack specification defining name, mode, and sweep options.
            defended_checkpoint: Path to the defended model checkpoint.
            post_attack_eval_names: Evaluations to run after each attack trial.
            model_config_dict: Raw model configuration dictionary for inference.
            attack_out_dir: Output directory for inner sweep results.
            random_seed: Random seed for reproducibility.
            attack_configs_dir: Root directory for attack configs. Falls back to
                the repository default if not specified.
            model_alias: Alias for the model, used for inner StudyPaths.

        Returns:
            Dictionary mapping evaluation name strings to the best trial's
            metric values from the attacker's perspective.
        """
        configs_dir = _resolve_attack_configs_dir(attack_spec.configs_dir, attack_configs_dir)

        # Load base config and sweep config for the attack.
        base_config_name, base_attack_config = load_attack_base_config(configs_dir, attack_spec.name)
        sweep_config = AttackSweepConfig.from_yaml(
            configs_dir, attack_spec.name, sweep_filename=ConfigPath.SINGLE_OBJECTIVE_SWEEP
        )

        # Override evals and inject model_config into the base config.
        base_attack_config[ConfigKeys.EVALS] = [str(e) for e in post_attack_eval_names]
        base_attack_config[ConfigKeys.MODEL_CONFIG] = copy.deepcopy(model_config_dict)

        # Determine attacker direction from the primary (first-listed) eval.
        primary_eval_name = post_attack_eval_names[0]
        primary_eval_cls = EVALS_REGISTRY[primary_eval_name]
        direction = primary_eval_cls.attacker_direction.value

        # Create inner StudyPaths for the attack sweep.
        inner_study_paths = StudyPaths(
            results_dir=attack_out_dir,
            model_alias=model_alias,
            attack_name=attack_spec.name,
            sweep_subdir=OptunaPaths.OPTUNA_SINGLE_DIR,
        )

        def inner_objective(trial: optuna.trial.Trial, merged_config: dict[str, object]) -> dict[str, float]:
            # Force the post-attack evals/seed into every merged trial config.
            merged_config[ConfigKeys.EVALS] = [str(e) for e in post_attack_eval_names]
            merged_config[ConfigKeys.RANDOM_SEED] = random_seed

            return SweepTrialManager.run_trial(
                attack_name=attack_spec.name,
                attack_config_dict=merged_config,
                eval_names=post_attack_eval_names,
                pretrained_model_path=defended_checkpoint,
                attack_results_dir=inner_study_paths.attack_results_dir,
                trial_number=trial.number,
                random_seed=random_seed,
            )

        study = run_optuna_sweep(
            study_paths=inner_study_paths,
            sweep_space=sweep_config.sweep,
            base_config=base_attack_config,
            direction=direction,
            objective_fn=inner_objective,
            primary_metric_key=str(primary_eval_name),
            n_trials=attack_spec.n_trials,
            random_seed=random_seed,
            top_n=1,
            eval_names=post_attack_eval_names,
            base_config_name=base_config_name,
        )

        # Return best trial's metrics from the inner sweep.
        sorted_trials = SweepTrialManager.sorted_completed_trials(study, direction)
        if not sorted_trials:
            logger.warning("Inner attack sweep for %s produced no completed trials.", attack_spec.name)
            return {str(e): float("nan") for e in post_attack_eval_names}

        best_trial = sorted_trials[0]
        eval_metrics: object = best_trial.user_attrs.get("eval_metrics")
        if isinstance(eval_metrics, dict):
            return {str(k): float(v) for k, v in eval_metrics.items()}

        # Fallback: return the primary objective value only.
        best_values: list[float] = best_trial.values
        return {str(primary_eval_name): float(best_values[0]) if best_values else float("nan")}

    @staticmethod
    def _get_worst_case_metrics(
        all_config_results: dict[str, dict[str, float]],
        eval_names: list[EvalName],
    ) -> dict[str, float]:
        """Select worst-case (max attacker harm) metrics across attack configs.

        For each evaluation, uses the attacker's optimization direction to
        determine which config produced the worst outcome for the defender.

        Args:
            all_config_results: Mapping of config names to their metric dictionaries.
            eval_names: List of evaluation names to aggregate.

        Returns:
            Dictionary mapping evaluation name strings to their worst-case values.
        """
        if not all_config_results:
            return {str(e): float("nan") for e in eval_names}

        worst_case: dict[str, float] = {}
        for eval_name in eval_names:
            eval_cls = EVALS_REGISTRY[eval_name]
            eval_key = str(eval_name)
            values = [metrics[eval_key] for metrics in all_config_results.values() if eval_key in metrics]

            if not values:
                worst_case[eval_key] = float("nan")
                continue

            if eval_cls.attacker_direction == OptimizationDirection.MAXIMIZE:
                worst_case[eval_key] = max(values)
            else:
                worst_case[eval_key] = min(values)

        return worst_case

    @staticmethod
    def _strip_meta_fields(config_dict: dict[str, object]) -> dict[str, object]:
        """Remove meta-fields that should not reach AlignmentDefenseConfig.

        Strips keys used for orchestration (evaluations, attacks, model config)
        that are not part of the defense's own hyperparameters.

        Args:
            config_dict: Raw defense configuration dictionary potentially
                containing meta-fields.

        Returns:
            A copy of the dictionary with meta-fields removed.
        """
        meta_keys = {
            str(ConfigKeys.DEFENSE_EVALS),
            str(ConfigKeys.POST_ATTACK_EVALS),
            str(ConfigKeys.ATTACKS),
            str(ConfigKeys.MODEL_CONFIG),
        }
        return {k: v for k, v in config_dict.items() if k not in meta_keys}

    @staticmethod
    def _prefix_metrics(prefix: str, metrics: dict[str, float]) -> dict[str, float]:
        """Add a dot-separated prefix to metric keys.

        Args:
            prefix: Prefix string to prepend (e.g., ``"defense"`` or
                ``"post_attack.lora_finetune"``).
            metrics: Dictionary of metric name to value mappings.

        Returns:
            New dictionary with prefixed keys, e.g. ``{"defense.strong_reject": 0.12}``.
        """
        return {f"{prefix}.{key}": value for key, value in metrics.items()}

    @staticmethod
    def _save_trial_results(path: Path, metrics: dict[str, float]) -> None:
        """Write the aggregated metrics dictionary to a JSON file.

        Args:
            path: File path where the JSON should be written.
            metrics: Dictionary of metric names to values.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            # default=str guards against non-serializable metric values.
            json.dump(metrics, f, indent=2, default=str)


def _resolve_attack_configs_dir(
    spec_configs_dir: Path | None,
    fallback_configs_dir: Path | None,
) -> Path:
    """Resolve the attack configs directory from spec, fallback, or repo default.

    Args:
        spec_configs_dir: Directory from the attack spec (highest priority).
        fallback_configs_dir: Caller-provided fallback directory.

    Returns:
        Resolved path to the attack configs directory.
    """
    if spec_configs_dir is not None:
        return spec_configs_dir
    if fallback_configs_dir is not None:
        return fallback_configs_dir
    return WHITEBOX_ATTACK_CONFIG_DIR
+ + Structure:: + + results_dir/ + └── model_alias/ + └── defense_name/ + └── sweep_subdir/ + ├── study.db + ├── best.yaml + ├── best.json + └── trial_0/ + ├── defended_model/ + ├── defense_eval/ + ├── post_attack/ + └── trial_results.json + + Attributes: + results_dir: Root directory for all defense sweep results. + model_alias: Identifier for the model being defended. + defense_name: Name of the defense being swept. + sweep_subdir: Subdirectory for this sweep type (default: single-objective). + + Example: + >>> paths = DefenseStudyPaths( + ... results_dir=Path("results"), + ... model_alias="llama-3-8b_2026_01_14", + ... defense_name=DefenseName.CRL, + ... ) + >>> paths.defense_results_dir + PosixPath('results/llama-3-8b_2026_01_14/crl/optuna_single') + >>> paths.storage_url + 'sqlite:///results/llama-3-8b_2026_01_14/crl/optuna_single/study.db' + """ + + results_dir: Path + model_alias: str + defense_name: DefenseName + sweep_subdir: str = OptunaPaths.OPTUNA_SINGLE_DIR + + @property + def model_results_dir(self) -> Path: + """Directory for all results related to this model.""" + return self.results_dir / self.model_alias + + @property + def defense_results_dir(self) -> Path: + """Directory for this defense's sweep results.""" + return self.model_results_dir / str(self.defense_name) / self.sweep_subdir + + @property + def attack_results_dir(self) -> Path: + """Alias for defense_results_dir to satisfy StudyPathsLike protocol.""" + return self.defense_results_dir + + @property + def storage_path(self) -> Path: + """Path to the Optuna study SQLite database.""" + return self.defense_results_dir / "study.db" + + @property + def storage_url(self) -> str: + """SQLite URL for Optuna study storage.""" + return f"sqlite:///{self.storage_path}" + + @property + def study_name(self) -> str: + """Optuna study name for this defense.""" + return get_defense_study_name(self.defense_name) + + @property + def best_yaml_path(self) -> Path: + """Path to the best.yaml file containing 
top defense configs.""" + return self.defense_results_dir / "best.yaml" + + @property + def best_json_path(self) -> Path: + """Path to the best.json file containing trial summaries.""" + return self.defense_results_dir / "best.json" + + def ensure_dirs(self) -> None: + """Create the defense results directory if it doesn't exist.""" + self.defense_results_dir.mkdir(parents=True, exist_ok=True) + + @classmethod + def from_existing( + cls, + results_dir: Path, + model_alias: str, + defense_name: DefenseName, + sweep_subdir: str = OptunaPaths.OPTUNA_SINGLE_DIR, + ) -> DefenseStudyPaths: + """Create DefenseStudyPaths for an existing study directory. + + Use this when loading existing results for analysis rather than + creating new sweep runs. + + Args: + results_dir: Root results directory. + model_alias: Model directory name. + defense_name: Defense name from the enum. + sweep_subdir: Sweep subdirectory name (default: optuna_single). + + Returns: + DefenseStudyPaths instance pointing to the existing study. + + Raises: + FileNotFoundError: If the defense_results_dir doesn't exist. + + Example: + >>> paths = DefenseStudyPaths.from_existing( + ... Path("results"), "llama-3-8b_2026_01_14", DefenseName.CRL + ... ) + >>> paths.storage_path.exists() + True + """ + instance = cls( + results_dir=results_dir, + defense_name=defense_name, + model_alias=model_alias, + sweep_subdir=sweep_subdir, + ) + if not instance.defense_results_dir.exists(): + raise FileNotFoundError(f"Defense Optuna directory not found: {instance.defense_results_dir}") + return instance + + +def get_defense_study_name(defense_name: DefenseName) -> str: + """Get standardized Optuna study name for defense single-objective sweeps. + + Used to create consistent study names across defense Optuna optimization + runs, allowing studies to be resumed with ``load_if_exists=True``. + + Args: + defense_name: Defense name from ``DefenseName`` enum. + + Returns: + Study name in format ``{defense_name}_defense_single``. 
+ + Example: + >>> from tamperbench.whitebox.utils.names import DefenseName + >>> get_defense_study_name(DefenseName.CRL) + 'crl_defense_single' + >>> get_defense_study_name(DefenseName.TAR) + 'tar_defense_single' + """ + return f"{defense_name}_defense_single" diff --git a/src/tamperbench/whitebox/utils/benchmark/runners.py b/src/tamperbench/whitebox/utils/benchmark/runners.py new file mode 100644 index 00000000..f82edac7 --- /dev/null +++ b/src/tamperbench/whitebox/utils/benchmark/runners.py @@ -0,0 +1,244 @@ +"""Reusable runner functions for Optuna sweeps and grid-based attack benchmarks. + +Extracts shared logic from ``scripts/whitebox/optuna_single.py`` (Optuna loop) and +``scripts/whitebox/benchmark_grid.py`` (grid loop) so both attack scripts and future +defense scripts can share the same core without duplicating boilerplate. +""" + +# pyright: reportUnusedCallResult=false, reportUnknownVariableType=false + +from __future__ import annotations + +import copy +import logging +import traceback +from collections.abc import Callable, Sequence +from pathlib import Path +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +import optuna + +from tamperbench.whitebox.attacks.base import TamperAttack, TamperAttackConfig +from tamperbench.whitebox.attacks.registry import ATTACKS_REGISTRY +from tamperbench.whitebox.utils.benchmark.io import deep_merge_dicts +from tamperbench.whitebox.utils.benchmark.sweep import suggest_nested_params +from tamperbench.whitebox.utils.benchmark.trial_manager import SweepTrialManager +from tamperbench.whitebox.utils.names import ( + AttackName, + ConfigKeys, + EvalName, + OptunaUserAttrs, + TrialFailureKeys, +) + +if TYPE_CHECKING: + from pandera.typing.polars import DataFrame + + from tamperbench.whitebox.evals.output_schema import EvaluationSchema + +logger = logging.getLogger(__name__) + + +@runtime_checkable +class StudyPathsLike(Protocol): + """Protocol for objects providing paths required by ``run_optuna_sweep``. 
@runtime_checkable
class StudyPathsLike(Protocol):
    """Structural interface for the path objects consumed by ``run_optuna_sweep``.

    Both ``StudyPaths`` (attack sweeps) and ``DefenseStudyPaths`` (defense
    sweeps) satisfy this protocol, so ``run_optuna_sweep`` works with either.
    """

    @property
    def attack_results_dir(self) -> Path:
        """Directory where trial results, the study DB, and summaries live."""
        ...

    @property
    def storage_path(self) -> Path:
        """Path to the Optuna study SQLite database file."""
        ...

    @property
    def storage_url(self) -> str:
        """SQLite connection URL for Optuna study storage."""
        ...

    @property
    def study_name(self) -> str:
        """Optuna study name."""
        ...

    def ensure_dirs(self) -> None:
        """Create required directories if they do not exist."""
        ...


def run_optuna_sweep(
    study_paths: StudyPathsLike,
    sweep_space: dict[str, object],
    base_config: dict[str, object],
    direction: str,
    objective_fn: Callable[[optuna.trial.Trial, dict[str, object]], dict[str, float]],
    primary_metric_key: str,
    n_trials: int,
    random_seed: int,
    top_n: int,
    eval_names: Sequence[str],
    base_config_name: str,
) -> optuna.study.Study:
    """Run a single-objective Optuna sweep and summarize the results.

    Shared core for attack sweeps, defense sweeps, and the inner attack sweeps
    nested inside defense trials. Per trial: sample overrides from
    ``sweep_space``, deep-merge them into ``base_config``, record overrides and
    merged config on the trial, invoke ``objective_fn(trial, merged_config)``,
    record the returned metrics, and report ``metrics[primary_metric_key]`` to
    Optuna. A failing trial stores its traceback and is pruned. After
    optimization, the top-``top_n`` configs are written out via
    ``SweepTrialManager.summarize_top_n``.

    Args:
        study_paths: Object providing study directory, storage URL, and study
            name; must satisfy ``StudyPathsLike``.
        sweep_space: Nested Optuna search-space definition (from sweep YAML).
        base_config: Base configuration merged with suggested parameters.
        direction: Optimization direction, ``"maximize"`` or ``"minimize"``.
        objective_fn: Callback ``(trial, merged_config) -> {metric_key: value}``.
        primary_metric_key: Metric key whose value is reported to Optuna.
        n_trials: Number of Optuna trials to run.
        random_seed: Seed for the TPESampler for reproducibility.
        top_n: Number of top-performing configs to save in the YAML summary.
        eval_names: Ordered evaluation names used in this sweep.
        base_config_name: Name of the base configuration (summary metadata).

    Returns:
        The completed Optuna study with all trial results.
    """
    study_paths.ensure_dirs()

    study = optuna.create_study(
        direction=direction,
        study_name=study_paths.study_name,
        storage=study_paths.storage_url,
        load_if_exists=True,
        sampler=optuna.samplers.TPESampler(seed=random_seed),
    )

    def objective(trial: optuna.trial.Trial) -> float:
        overrides = suggest_nested_params(trial, sweep_space)
        merged = deep_merge_dicts(base_config, overrides)

        trial.set_user_attr(OptunaUserAttrs.CONFIG_OVERRIDES, overrides)
        trial.set_user_attr(OptunaUserAttrs.MERGED_CONFIG, copy.deepcopy(merged))

        try:
            metrics = objective_fn(trial, merged)
            trial.set_user_attr(OptunaUserAttrs.EVAL_METRICS, metrics)
            return metrics[primary_metric_key]

        except Exception as error:  # pragma: no cover - surfaced to study logs
            trial.set_user_attr(
                OptunaUserAttrs.FAILURE,
                {
                    TrialFailureKeys.ERROR: str(error),
                    TrialFailureKeys.TRACEBACK: traceback.format_exc(),
                },
            )
            raise optuna.exceptions.TrialPruned(f"Trial failed due to exception: {error}") from error

    study.optimize(objective, n_trials=n_trials)

    SweepTrialManager.summarize_top_n(
        study,
        direction=direction,
        top_n=top_n,
        eval_names=eval_names,
        base_cfg=base_config,
        base_config_name=base_config_name,
        output_dir=study_paths.attack_results_dir,
        study_storage=study_paths.storage_path,
    )

    return study
+ + For each ``config_name`` in ``config_grid``: + 1. Inject ``input_checkpoint_path``, ``out_dir`` (``output_base_dir/config_name``), + and ``random_seed`` into the config dict. + 2. If ``eval_names`` is provided, override evals in the config. + 3. Instantiate the attack from ``ATTACKS_REGISTRY`` and run ``attacker.benchmark()``. + 4. If ``cleanup_checkpoints`` is ``True``, call ``attacker.delete_output_checkpoint()``. + 5. Collect the results DataFrame. + + Args: + attack_name: Name of the attack to run (must be registered in ``ATTACKS_REGISTRY``). + config_grid: Mapping of config names to their configuration dictionaries, + as loaded from a ``grid.yaml`` file. + pretrained_model_path: Path to the pretrained model or checkpoint to attack. + output_base_dir: Base directory for results. Each config gets a subdirectory + at ``output_base_dir / config_name``. + random_seed: Random seed injected into every config for reproducibility. + eval_names: Optional list of eval names to override in each config. + If ``None``, the evals specified in the grid config are used as-is. + cleanup_checkpoints: Whether to delete output checkpoints after evaluation + to save disk space (default: ``True``). + + Returns: + Dictionary mapping config names to their evaluation results DataFrames. 
+ """ + attack_config_cls, attack_cls = ATTACKS_REGISTRY[attack_name] + results: dict[str, DataFrame[EvaluationSchema]] = {} + + for config_name, attack_config_dict in config_grid.items(): + config_dict = copy.deepcopy(attack_config_dict) + config_dict[ConfigKeys.RANDOM_SEED] = random_seed + + if eval_names is not None: + config_dict[ConfigKeys.EVALS] = [str(e) for e in eval_names] + + attacker: TamperAttack[TamperAttackConfig] = attack_cls( + attack_config=attack_config_cls.from_dict( + data={ + "input_checkpoint_path": pretrained_model_path, + "out_dir": Path(output_base_dir, config_name), + **config_dict, + } + ) + ) + + results_df: DataFrame[EvaluationSchema] = attacker.benchmark() + results[config_name] = results_df + + if cleanup_checkpoints: + attacker.delete_output_checkpoint() + + return results diff --git a/src/tamperbench/whitebox/utils/benchmark/trial_manager.py b/src/tamperbench/whitebox/utils/benchmark/trial_manager.py index 7ac22f34..7c27a714 100644 --- a/src/tamperbench/whitebox/utils/benchmark/trial_manager.py +++ b/src/tamperbench/whitebox/utils/benchmark/trial_manager.py @@ -6,7 +6,8 @@ import copy import json -from collections.abc import Mapping +import logging +from collections.abc import Mapping, Sequence from pathlib import Path from typing import TYPE_CHECKING @@ -22,12 +23,15 @@ from tamperbench.whitebox.utils.benchmark.path_generation import get_trial_dir_name from tamperbench.whitebox.utils.names import ( ConfigKeys, + DefenseMetricPrefix, EvalName, OptunaPaths, OptunaUserAttrs, TrialSummaryKeys, ) +logger = logging.getLogger(__name__) + if TYPE_CHECKING: import optuna from pandera.typing.polars import DataFrame @@ -162,7 +166,7 @@ def sort_key(trial: optuna.trial.FrozenTrial) -> float: @staticmethod def build_summaries( top_trials: list[optuna.trial.FrozenTrial], - eval_names: list[EvalName], + eval_names: Sequence[str], base_cfg: Mapping[str, object], ) -> tuple[dict[str, object], list[dict[str, object]]]: """Build best configs and 
trial summaries from top-ranked Optuna trials. @@ -173,7 +177,10 @@ def build_summaries( Args: top_trials: List of FrozenTrial objects sorted best-first. - eval_names: List of evaluation names used in the sweep. + eval_names: Metric keys used in the sweep. Accepts bare ``EvalName`` + values (attack sweeps) or prefixed keys like + ``"defense.strong_reject"`` (defense sweeps). Each key must + contain a valid ``EvalName`` as its final dot-separated segment. base_cfg: Base configuration dictionary for fallback merging. Returns: @@ -192,6 +199,15 @@ def build_summaries( >>> best_configs["top_1"]["learning_rate"] 0.0001 """ + # Validate that every metric key contains a valid EvalName + for key in eval_names: + if not DefenseMetricPrefix.validate_metric_key(key): + logger.warning( + "Metric key '%s' does not end with a known EvalName. " + "Values for this key may appear as NaN in best.json.", + key, + ) + best_configs: dict[str, object] = {} summaries: list[dict[str, object]] = [] @@ -271,7 +287,7 @@ def summarize_top_n( *, direction: str, top_n: int, - eval_names: list[EvalName], + eval_names: Sequence[str], base_cfg: Mapping[str, object], base_config_name: str, output_dir: Path, @@ -286,7 +302,7 @@ def summarize_top_n( study: Completed Optuna study to summarize. direction: Optimization direction, either "maximize" or "minimize". top_n: Number of top-performing trials to save. - eval_names: List of evaluation names used in the sweep. + eval_names: Metric keys used in the sweep (bare ``EvalName`` or prefixed). base_cfg: Base configuration dictionary used for the sweep. base_config_name: Name of the base configuration (for metadata). output_dir: Directory to save output files. 
diff --git a/src/tamperbench/whitebox/utils/names.py b/src/tamperbench/whitebox/utils/names.py index d07aa6ec..843dd040 100644 --- a/src/tamperbench/whitebox/utils/names.py +++ b/src/tamperbench/whitebox/utils/names.py @@ -191,20 +191,43 @@ class ConfigPath(StrEnum): MULTI_OBJECTIVE_SWEEP = "multi_objective_sweep.yaml" # yaml file for multi-objective sweep +class DefenseTrialDirs(StrEnum): + """Directory names within a defense trial. + + These define the standardized subdirectory layout inside each defense + trial directory (e.g., ``trial_0/``). Each trial contains a defended + model checkpoint, defense evaluation results, post-attack results, + and an aggregated trial results file. + + Example structure:: + + trial_0/ + ├── defended_model/ # DEFENDED_MODEL + ├── defense_eval/ # DEFENSE_EVAL + ├── post_attack/ # POST_ATTACK + └── trial_results.json # TRIAL_RESULTS + """ + + DEFENDED_MODEL = "defended_model" # Defense output checkpoint directory + DEFENSE_EVAL = "defense_eval" # Defense checkpoint evaluation results + POST_ATTACK = "post_attack" # Post-attack results (per-attack subdirectories) + TRIAL_RESULTS = "trial_results.json" # Aggregated metrics for the trial + + class ConfigKeys(StrEnum): - """Keys used in attack and sweep configuration dictionaries. + """Keys used in attack, defense, and sweep configuration dictionaries. These keys appear in YAML config files (grid.yaml, sweep.yaml, etc.) and in runtime configuration dictionaries passed between components. - Example grid.yaml structure:: + Example attack grid.yaml structure:: base: # ConfigKeys.BASE - default/base configuration evals: [strong_reject, mmlu_pro_val] # ConfigKeys.EVALS learning_rate: 0.0001 ... 
- Example sweep.yaml structure:: + Example attack sweep.yaml structure:: evals: [strong_reject] # ConfigKeys.EVALS - objectives to optimize sweep: # ConfigKeys.SWEEP - search space definition @@ -212,12 +235,108 @@ class ConfigKeys(StrEnum): type: float low: 1e-5 high: 1e-3 + + Example defense grid.yaml structure:: + + base: + defense_evals: [strong_reject] # ConfigKeys.DEFENSE_EVALS + post_attack_evals: [strong_reject] # ConfigKeys.POST_ATTACK_EVALS + model_config: { ... } # ConfigKeys.MODEL_CONFIG + attacks: [{ name: lora_finetune }] # ConfigKeys.ATTACKS + alpha: 0.5 """ BASE = "base" # Key for base/default config in grid.yaml EVALS = "evals" # Key for list of evaluation names in configs SWEEP = "sweep" # Key for sweep search space in sweep configs RANDOM_SEED = "random_seed" # Key for random seed in configs + DEFENSE_EVALS = "defense_evals" # Key for defense checkpoint evaluations + POST_ATTACK_EVALS = "post_attack_evals" # Key for post-attack evaluations + ATTACKS = "attacks" # Key for attack specifications in defense configs + MODEL_CONFIG = "model_config" # Key for model configuration dict in defense configs + + +class DefenseMetricPrefix(StrEnum): + """Dot-separated prefixes for defense sweep metric keys. + + Defense trials produce metrics with prefixed keys so that defense-stage + and post-attack-stage results are namespaced unambiguously in the same + flat dictionary. + + Example keys:: + + "defense.strong_reject" -> DEFENSE + EvalName + "post_attack.lora_finetune.strong_reject" -> POST_ATTACK + AttackName + EvalName + """ + + DEFENSE = "defense" + POST_ATTACK = "post_attack" + + @staticmethod + def defense_key(eval_name: EvalName) -> str: + """Build a defense-stage metric key. + + Args: + eval_name: The evaluation whose metric this key represents. + + Returns: + Prefixed key, e.g. ``"defense.strong_reject"``. 
+        """
+        return f"{DefenseMetricPrefix.DEFENSE}.{eval_name}"
+
+    @staticmethod
+    def post_attack_key(attack_name: AttackName, eval_name: EvalName) -> str:
+        """Build a post-attack metric key.
+
+        Args:
+            attack_name: The attack that was run against the defended model.
+            eval_name: The evaluation whose metric this key represents.
+
+        Returns:
+            Prefixed key, e.g. ``"post_attack.lora_finetune.strong_reject"``.
+        """
+        return f"{DefenseMetricPrefix.POST_ATTACK}.{attack_name}.{eval_name}"
+
+    @staticmethod
+    def build_all_metric_keys(
+        defense_eval_names: list[EvalName],
+        post_attack_eval_names: list[EvalName],
+        attack_names: list[AttackName],
+    ) -> list[str]:
+        """Build the full ordered list of prefixed metric keys for a defense sweep.
+
+        Args:
+            defense_eval_names: Evaluations run on the defended checkpoint.
+            post_attack_eval_names: Evaluations run after each attack.
+            attack_names: Attacks run against the defended model.
+
+        Returns:
+            Ordered list of all prefixed metric keys, defense keys first,
+            then post-attack keys grouped by attack.
+        """
+        keys: list[str] = []
+        for eval_name in defense_eval_names:
+            keys.append(DefenseMetricPrefix.defense_key(eval_name))
+        for attack_name in attack_names:
+            for eval_name in post_attack_eval_names:
+                keys.append(DefenseMetricPrefix.post_attack_key(attack_name, eval_name))
+        return keys
+
+    @staticmethod
+    def validate_metric_key(key: str) -> bool:
+        """Check that a prefixed metric key contains a valid EvalName.
+
+        Validates that the last dot-separated segment of the key is a known
+        ``EvalName`` value.
+
+        Args:
+            key: A metric key, either a bare ``EvalName`` or a prefixed key.
+
+        Returns:
+            True if the key's eval component is a valid ``EvalName``.
+        """
+        eval_component = key.rsplit(".", maxsplit=1)[-1]
+        # Use the public Enum API rather than the private ``_value2member_map_``
+        # attribute; semantics are identical (membership by string value).
+        return eval_component in {member.value for member in EvalName}
 
 
 class TemplateName(StrEnum):