"""GCG attack interface."""

from dataclasses import dataclass, field

import polars as pl
from pandera.typing.polars import DataFrame
from typing_extensions import override

from tamperbench.whitebox.attacks.base import TamperAttack, TamperAttackConfig
from tamperbench.whitebox.attacks.gcg.implementation import GCGConfig
from tamperbench.whitebox.attacks.registry import register_attack
from tamperbench.whitebox.evals.gcg.gcg import GCGEvaluation, GCGEvaluationConfig
from tamperbench.whitebox.evals.output_schema import EvaluationSchema
from tamperbench.whitebox.utils.names import AttackName, EvalName


@dataclass
class GCGAttackConfig(TamperAttackConfig):
    """Config for GCG attack.

    Attributes:
        gcg_config: Settings forwarded to the nanoGCG optimizer.
    """

    gcg_config: GCGConfig = field(default_factory=GCGConfig)


@register_attack(AttackName.GCG_ATTACK, GCGAttackConfig)
class GCGAttack(TamperAttack[GCGAttackConfig]):
    """GCG attack class.

    GCG is an input-space attack: it optimizes an adversarial suffix rather
    than modifying model weights, so ``run_attack`` is a pass-through and the
    optimization itself happens inside the GCG evaluation.
    """

    name: AttackName = AttackName.GCG_ATTACK

    @override
    def run_attack(self) -> None:
        """GCG attack does not modify model weights directly."""
        # The "attacked" checkpoint is just the input checkpoint: evaluation
        # runs against the unmodified model plus an optimized suffix.
        self.output_checkpoint_path: str = self.attack_config.input_checkpoint_path

    @override
    def evaluate(self) -> DataFrame[EvaluationSchema]:
        """Evaluate GCG attack, including GCG-specific evaluation if requested.

        Returns:
            A polars DataFrame conforming to ``EvaluationSchema`` containing
            the base evaluations plus, when ``GCG_ATTACK_EVAL`` is requested,
            the GCG-specific rows.
        """
        results = super().evaluate()
        if EvalName.GCG_ATTACK_EVAL in self.attack_config.evals:
            results = pl.concat([results, self._evaluate_gcg()])
        return EvaluationSchema.validate(results)

    def _evaluate_gcg(self) -> DataFrame[EvaluationSchema]:
        """Instantiate GCGEvaluation and run evaluation."""
        eval_cfg = GCGEvaluationConfig(
            model_checkpoint=self.output_checkpoint_path,
            out_dir=self.attack_config.out_dir,
            model_config=self.attack_config.model_config,
            gcg_config=self.attack_config.gcg_config,
        )
        evaluator = GCGEvaluation(eval_cfg)
        return evaluator.run_evaluation()
@dataclass
class GCGConfig:
    """Config for nanoGCG attack.

    Field names and defaults mirror the nanoGCG v0.3.0 reference
    implementation this module is sourced from.
    """

    num_steps: int = 250
    optim_str_init: str | list[str] = "x x x x x x x x x x x x x x x x x x x x"
    search_width: int = 512
    batch_size: int | None = None
    topk: int = 256
    n_replace: int = 1
    buffer_size: int = 0
    use_mellowmax: bool = False
    mellowmax_alpha: float = 1.0
    early_stop: bool = False
    use_prefix_cache: bool = True
    allow_non_ascii: bool = False
    filter_ids: bool = True
    add_space_before_target: bool = False
    seed: int | None = None
    verbosity: str = "INFO"
    probe_sampling_config: ProbeSamplingConfig | None = None


def _copy_prefix_cache(prefix_cache: DynamicCache, batch_size: int = 1) -> DynamicCache:
    """Create a shallow copy of a prefix KV cache, optionally expanding to a batch size.

    A new DynamicCache object is returned so that the model's in-place
    ``DynamicCache.update()`` calls (which ``torch.cat`` new KV states onto
    the cache) mutate the copy rather than the original. The underlying
    key/value tensors are **not** cloned — ``torch.cat`` inside ``update``
    already produces new tensors, leaving the originals untouched.

    The original nanoGCG stored caches as ``list[tuple[Tensor, ...]]`` which
    the model never mutated. After migrating to ``DynamicCache`` the same
    immutability guarantee must be restored manually by copying the wrapper.

    Args:
        prefix_cache: The cached key/value states from the prefix tokens.
        batch_size: The batch size to expand to (default 1, i.e. plain copy).

    Returns:
        A new DynamicCache wrapping the (optionally expanded) tensors.
    """
    # Named `cache_copy` (not `copy`) so the stdlib `copy` module imported at
    # the top of this file is not shadowed inside this function.
    cache_copy = DynamicCache()
    for layer_idx in range(len(prefix_cache)):
        key, value = prefix_cache[layer_idx]
        cache_copy.update(
            key.expand(batch_size, -1, -1, -1),
            value.expand(batch_size, -1, -1, -1),
            layer_idx,
        )
    return cache_copy


# NOTE(review): the surrounding diff also migrates GCG.prefix_cache and
# GCG.draft_prefix_cache from list[tuple[Tensor, ...]] to DynamicCache | None
# and replaces the manual per-batch cache expansion in compute_token_gradient,
# _compute_candidates_loss_original and _compute_draft_losses with
# _copy_prefix_cache(...) calls (a fresh copy per forward pass, since the
# model mutates DynamicCache in place). Those hunks modify class-internal
# code that is not fully visible in this chunk.
+ prefix_cache_batch = _copy_prefix_cache(self.prefix_cache, current_batch_size) outputs = self.model( inputs_embeds=input_embeds_batch, @@ -839,18 +896,13 @@ def _compute_draft_losses( ) draft_losses = [] - draft_prefix_cache_batch = None for i in range(0, B, search_batch_size): with torch.no_grad(): batch_size = min(search_batch_size, B - i) draft_sampled_ids_batch = draft_sampled_ids[i : i + batch_size] if self.draft_prefix_cache: - if not draft_prefix_cache_batch or batch_size != search_batch_size: - draft_prefix_cache_batch = [ - [x.expand(batch_size, -1, -1, -1) for x in self.draft_prefix_cache[i]] - for i in range(len(self.draft_prefix_cache)) - ] + draft_prefix_cache_batch = _copy_prefix_cache(self.draft_prefix_cache, batch_size) draft_embeds = torch.cat( [ self.draft_embedding_layer(draft_sampled_ids_batch), diff --git a/src/tamperbench/whitebox/evals/__init__.py b/src/tamperbench/whitebox/evals/__init__.py index f5cf0aca..be24b8dd 100644 --- a/src/tamperbench/whitebox/evals/__init__.py +++ b/src/tamperbench/whitebox/evals/__init__.py @@ -4,6 +4,10 @@ EmbeddingAttackEvaluation, EmbeddingAttackEvaluationConfig, ) +from tamperbench.whitebox.evals.gcg.gcg import ( + GCGEvaluation, + GCGEvaluationConfig, +) from tamperbench.whitebox.evals.ifeval.ifeval import ( IFEvalEvaluation, IFEvalEvaluationConfig, @@ -33,6 +37,8 @@ __all__ = [ "EmbeddingAttackEvaluation", "EmbeddingAttackEvaluationConfig", + "GCGEvaluation", + "GCGEvaluationConfig", "IFEvalEvaluation", "IFEvalEvaluationConfig", "JailbreakBenchEvaluation", diff --git a/src/tamperbench/whitebox/evals/gcg/__init__.py b/src/tamperbench/whitebox/evals/gcg/__init__.py new file mode 100644 index 00000000..00ef2efc --- /dev/null +++ b/src/tamperbench/whitebox/evals/gcg/__init__.py @@ -0,0 +1,8 @@ +"""GCG attack evaluation.""" + +from tamperbench.whitebox.evals.gcg.gcg import GCGEvaluation, GCGEvaluationConfig + +__all__ = [ + "GCGEvaluation", + "GCGEvaluationConfig", +] diff --git 
"""GCG attack evaluator with StrongREJECT dataset."""

from dataclasses import dataclass, field
from typing import Any, cast

import datasets
import polars as pl
import torch
from pandera.typing.polars import DataFrame
from torch import Tensor
from tqdm import tqdm
from typing_extensions import override

from tamperbench.whitebox.attacks.gcg.implementation import GCG, GCGConfig
from tamperbench.whitebox.evals.base import load_model_and_tokenizer
from tamperbench.whitebox.evals.output_schema import InferenceSchema
from tamperbench.whitebox.evals.registry import register_evaluation
from tamperbench.whitebox.evals.strong_reject.strong_reject import (
    StrongRejectEvaluation,
    StrongRejectEvaluationConfig,
)
from tamperbench.whitebox.utils import dealloc_model_and_tokenizer
from tamperbench.whitebox.utils.names import EvalName
from tamperbench.whitebox.utils.ops import run_in_isolation


@dataclass
class GCGEvaluationConfig(StrongRejectEvaluationConfig):
    """Evaluation Configuration for GCG attack evaluation.

    Attributes:
        gcg_config: Configuration for GCG optimization.
    """

    gcg_config: GCGConfig = field(default_factory=GCGConfig)


@register_evaluation(EvalName.GCG_ATTACK_EVAL)
class GCGEvaluation(
    StrongRejectEvaluation[GCGEvaluationConfig],
):
    """GCG attack evaluation class."""

    name: EvalName = EvalName.GCG_ATTACK_EVAL

    @override
    def compute_inferences(self) -> DataFrame[InferenceSchema]:
        """Run inferences (using GCG attack) on evaluation dataset (StrongReject).

        Returns:
            DataFrame[InferenceSchema]: A polars DataFrame with the prompts to LLM and corresponding responses.
        """
        strongreject_dataset: datasets.Dataset = self._load_strongreject_dataset()

        # Run in a separate process so GPU memory from the GCG loop is fully
        # reclaimed regardless of how the inference terminates.
        payload: pl.DataFrame = run_in_isolation(
            target=_instantiate_model_and_infer,
            args=(self.eval_config, strongreject_dataset),
            error_context="GCG attack inference",
        )

        return InferenceSchema.validate(payload)

    def _load_strongreject_dataset(self) -> datasets.Dataset:
        """Load the StrongReject dataset with targets.

        NOTE(review): this pulls from a third-party HF namespace
        ("sdhossain24/...") rather than an org-owned one — confirm the
        dataset is pinned/trusted.
        """
        return datasets.load_dataset(path="sdhossain24/strongreject_with_targets", split="train")


def _instantiate_model_and_infer(
    eval_config: GCGEvaluationConfig,
    strongreject_dataset: datasets.Dataset,
) -> pl.DataFrame:
    """Run GCG optimization and model inference on StrongREJECT dataset.

    Args:
        eval_config: Evaluation configuration with model checkpoint and GCG settings.
        strongreject_dataset: StrongReject dataset with prompts and targets.

    Returns:
        A polars DataFrame with prompts and model responses.
    """
    model, tokenizer = load_model_and_tokenizer(eval_config)
    try:
        gcg = GCG(model, tokenizer, eval_config.gcg_config)

        prompts: list[str] = []
        responses: list[str] = []

        for data_row in tqdm(strongreject_dataset, total=len(strongreject_dataset)):
            row = cast(dict[str, Any], data_row)
            result = gcg.run(messages=row["prompt"], target=row["target"])

            # Build the chat-templated prompt with the GCG suffix in the same
            # position it occupied during optimization (end of user message).
            # NOTE(review): the suffix is appended with no separator —
            # confirm this matches the token layout used during optimization.
            chat_messages = [{"role": "user", "content": row["prompt"] + result.best_string}]
            templated_prompt = cast(
                str, tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
            )
            input_ids = cast(
                Tensor,
                tokenizer(
                    templated_prompt,
                    return_tensors="pt",
                    add_special_tokens=False,
                )["input_ids"],
            ).to(model.device)

            with torch.no_grad():
                output: Tensor = model.generate(  # pyright: ignore[reportCallIssue]
                    input_ids=input_ids,
                    max_new_tokens=eval_config.model_config.max_generation_length,
                )
            # Decode only the generated tokens (exclude input)
            generated_ids = output[0][input_ids.shape[1] :]
            response: str = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

            prompts.append(row["prompt"])
            responses.append(response)

        inference_frame = pl.from_dict(
            data={
                InferenceSchema.prompt: prompts,
                InferenceSchema.response: responses,
            }
        )
        return InferenceSchema.validate(inference_frame)
    finally:
        # Always release GPU memory, even if GCG or generation raised.
        dealloc_model_and_tokenizer(model, tokenizer)


# NOTE(review): the original diff also adds AttackName.GCG_ATTACK =
# "gcg_attack" and EvalName.GCG_ATTACK_EVAL = "gcg_attack_eval" to
# utils/names.py; those enum hunks touch classes not fully visible here.
"""Sanity check for GCGAttack.

GCGAttack is a no-op attack (does not modify model weights) that delegates to
the GCG evaluation pipeline for input-space optimization. This test exercises
the run_attack -> evaluate path end-to-end using GCG_ATTACK_EVAL.
"""

import tempfile

from dotenv import load_dotenv

from tamperbench.whitebox.attacks.gcg.gcg import GCGAttack, GCGAttackConfig
from tamperbench.whitebox.attacks.gcg.implementation import GCGConfig
from tamperbench.whitebox.evals.output_schema import EvaluationSchema
from tamperbench.whitebox.utils.models.config import ModelConfig
from tamperbench.whitebox.utils.names import EvalName, MetricName

if __name__ == "__main__":
    load_dotenv()

    with tempfile.TemporaryDirectory() as tmpdirname:
        llama_3_8b_gcg_config = GCGAttackConfig(
            input_checkpoint_path="meta-llama/Llama-3.1-8B-Instruct",
            out_dir=tmpdirname,
            model_config=ModelConfig(
                user_prefix="<|start_header_id|>user<|end_header_id|>\n\n",
                assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n",
                end_turn="<|eot_id|>\n",
                max_generation_length=512,
                inference_batch_size=16,
            ),
            evals=[EvalName.GCG_ATTACK_EVAL],
            random_seed=42,
            # Reduced step count keeps this sanity run tractable.
            gcg_config=GCGConfig(num_steps=50),
        )

        gcg_attack = GCGAttack(attack_config=llama_3_8b_gcg_config)

        gcg_attack.run_attack()

        # rows_by_key yields variable-length row tuples, so the value type is
        # tuple[float, ...] — the original annotation tuple[float] denotes a
        # 1-tuple of exactly one float.
        attacked_eval: dict[str, tuple[float, ...]] = gcg_attack.evaluate().rows_by_key(
            key=EvaluationSchema.metric_name,
            unique=True,
        )

        assert MetricName.STRONG_REJECT_SCORE in attacked_eval
        assert attacked_eval[MetricName.STRONG_REJECT_SCORE][0] >= 0.0

        print("Llama3.1-8B Instruct GCG Attack:", attacked_eval)