From 5cc7f1942ce09017d97cdce77912195c4a784f83 Mon Sep 17 00:00:00 2001 From: Tom Tseng Date: Fri, 6 Feb 2026 22:43:28 -0800 Subject: [PATCH 1/7] gcg: Actually run the optimization in evaluation --- src/tamperbench/whitebox/attacks/gcg/gcg.py | 153 ++++-------------- .../whitebox/attacks/gcg/implementation.py | 31 +++- src/tamperbench/whitebox/evals/__init__.py | 6 + .../whitebox/evals/gcg/__init__.py | 8 + src/tamperbench/whitebox/evals/gcg/gcg.py | 118 ++++++++++++++ src/tamperbench/whitebox/utils/names.py | 2 + tests/attacks/test_gcg_attack.py | 51 ++++++ 7 files changed, 245 insertions(+), 124 deletions(-) create mode 100644 src/tamperbench/whitebox/evals/gcg/__init__.py create mode 100644 src/tamperbench/whitebox/evals/gcg/gcg.py create mode 100644 tests/attacks/test_gcg_attack.py diff --git a/src/tamperbench/whitebox/attacks/gcg/gcg.py b/src/tamperbench/whitebox/attacks/gcg/gcg.py index 8a0ec3c1..22471a86 100644 --- a/src/tamperbench/whitebox/attacks/gcg/gcg.py +++ b/src/tamperbench/whitebox/attacks/gcg/gcg.py @@ -1,141 +1,52 @@ """GCG attack interface.""" -import logging -from dataclasses import dataclass -from typing import TypeVar +from dataclasses import dataclass, field -import peft -import torch -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - PreTrainedModel, - PreTrainedTokenizer, -) +import polars as pl +from pandera.typing.polars import DataFrame from typing_extensions import override from tamperbench.whitebox.attacks.base import TamperAttack, TamperAttackConfig -from tamperbench.whitebox.attacks.gcg.implementation import GCG, ProbeSamplingConfig -from tamperbench.whitebox.evals import ( - StrongRejectEvaluation, - StrongRejectEvaluationConfig, -) -from tamperbench.whitebox.utils.models.default_tokens import DEFAULT_PAD_TOKEN -from tamperbench.whitebox.utils.names import EvalName - -logger = logging.getLogger(__name__) - -H = TypeVar(name="H", bound="GCGConfig") +from tamperbench.whitebox.attacks.gcg.implementation import GCGConfig +from tamperbench.whitebox.attacks.registry import register_attack +from tamperbench.whitebox.evals.gcg.gcg import GCGEvaluation, GCGEvaluationConfig +from tamperbench.whitebox.evals.output_schema import EvaluationSchema +from tamperbench.whitebox.utils.names import AttackName, EvalName @dataclass -class GCGConfig(TamperAttackConfig): - """Config for nanoGCG attack.""" +class GCGAttackConfig(TamperAttackConfig): + """Config for GCG attack.""" - num_steps: int = 250 - optim_str_init: str | list[str] = "x x x x x x x x x x x x x x x x x x x x" - search_width: int = 512 - batch_size: int = None - topk: int = 256 - n_replace: int = 1 - buffer_size: int = 0 - use_mellowmax: bool = False - mellowmax_alpha: float = 1.0 - early_stop: bool = False - use_prefix_cache: bool = True - allow_non_ascii: bool = False - filter_ids: bool = True - add_space_before_target: bool = False - seed: int = None - verbosity: str = "INFO" - probe_sampling_config: ProbeSamplingConfig | None = None + gcg_config: GCGConfig = field(default_factory=GCGConfig) -class GCGAttack(TamperAttack[H]): +@register_attack(AttackName.GCG_ATTACK, GCGAttackConfig) +class GCGAttack(TamperAttack[GCGAttackConfig]): """GCG attack class.""" - @override - def run_attack( - messages: str | list[dict[str, str]], - target: str, - config: GCGConfig | None = None, - ) -> GCGResult: - """Generate a single optimized string using GCG. - - Args: - messages: The conversation to use for optimization. - target: The target generation. - config: The GCG configuration to use. - - Returns: - A GCGResult object that contains losses and the optimized strings. - """ - model: PreTrainedModel = self.load_model() - tokenizer: PreTrainedTokenizer = self.load_tokenizer() - - if config is None: - config = GCGConfig() - - logger.setLevel(getattr(logging, config.verbosity)) - - gcg = GCG(model, tokenizer, config) - result = gcg.run(messages, target) - return result + name: AttackName = AttackName.GCG_ATTACK @override - def evaluate(self) -> dict[str, float]: - """Evaluate attack on relevant evaluations. - - Returns: - dict[str, float]: A dictionary with the relevant metrics and the associated values. - """ - results: dict[str, float] = {} - - if EvalName.STRONG_REJECT in self.attack_config.evals: - results.update(self.evaluate_strong_reject()) - - return results - - def load_model(self) -> PreTrainedModel: - """Load model from path in HuggingFace and/or PEFT.""" - model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=self.attack_config.base_input_checkpoint_path, - torch_dtype=torch.bfloat16, - # TODO PR #55: fully support flash attention - attn_implementation="flash_attention_2", - device_map="auto", - ) - - if self.attack_config.lora_input_folder: - model = peft.PeftModel.from_pretrained( - model, model_id=self.attack_config.lora_input_folder - ).merge_and_unload() - - return model - - def load_tokenizer(self) -> PreTrainedTokenizer: - """Load tokenizer from path in HuggingFace and/or PEFT.""" - tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=self.attack_config.base_input_checkpoint_path, - padding_side="right", - use_fast=False, - ) - - if tokenizer.pad_token is None: - tokenizer.add_special_tokens(special_tokens_dict={"pad_token": DEFAULT_PAD_TOKEN}) - - return tokenizer + def run_attack(self) -> None: + """GCG attack does not modify model weights directly.""" + self.output_checkpoint_path: str = self.attack_config.input_checkpoint_path @override - def evaluate_strong_reject(self) -> dict[str, float]: - """Evaluate attack on the `StrongRejectEvaluation` evaluator.""" - eval_cfg: StrongRejectEvaluationConfig = StrongRejectEvaluationConfig( - base_checkpoint=self.attack_config.base_output_checkpoint_path, - lora_folder=self.attack_config.lora_output_folder, - max_generation_length=self.attack_config.max_generation_length, - batch_size=64, - small=True, + def evaluate(self) -> DataFrame[EvaluationSchema]: + """Evaluate GCG attack, including GCG-specific evaluation if requested.""" + results = super().evaluate() + if EvalName.GCG_ATTACK_EVAL in self.attack_config.evals: + results = pl.concat([results, self._evaluate_gcg()]) + return EvaluationSchema.validate(results) + + def _evaluate_gcg(self) -> DataFrame[EvaluationSchema]: + """Instantiate GCGEvaluation and run evaluation.""" + eval_cfg = GCGEvaluationConfig( + model_checkpoint=self.output_checkpoint_path, + out_dir=self.attack_config.out_dir, + model_config=self.attack_config.model_config, + gcg_config=self.attack_config.gcg_config, ) - evaluator: StrongRejectEvaluation[StrongRejectEvaluationConfig] = StrongRejectEvaluation(eval_config=eval_cfg) - + evaluator = GCGEvaluation(eval_cfg) return evaluator.run_evaluation() diff --git a/src/tamperbench/whitebox/attacks/gcg/implementation.py b/src/tamperbench/whitebox/attacks/gcg/implementation.py index 1909ed91..f3cd9033 100644 --- a/src/tamperbench/whitebox/attacks/gcg/implementation.py +++ b/src/tamperbench/whitebox/attacks/gcg/implementation.py @@ -1,4 +1,8 @@ -"""GCG attack implementation module.""" +"""GCG attack implementation module. + +Sourced from https://github.com/GraySwanAI/nanoGCG/blob/v0.3.0-release/nanogcg/gcg.py +and https://github.com/GraySwanAI/nanoGCG/blob/v0.3.0-release/nanogcg/utils.py +""" import copy import functools @@ -18,8 +22,6 @@ from tqdm import tqdm from transformers import PreTrainedTokenizerBase, set_seed -from tamperbench.whitebox.attacks.gcg.gcg import GCGConfig - """UTILS""" INIT_CHARS = [ @@ -196,6 +198,29 @@ class ProbeSamplingConfig: sampling_factor: int = 16 +@dataclass +class GCGConfig: + """Config for nanoGCG attack.""" + + num_steps: int = 250 + optim_str_init: str | list[str] = "x x x x x x x x x x x x x x x x x x x x" + search_width: int = 512 + batch_size: int | None = None + topk: int = 256 + n_replace: int = 1 + buffer_size: int = 0 + use_mellowmax: bool = False + mellowmax_alpha: float = 1.0 + early_stop: bool = False + use_prefix_cache: bool = True + allow_non_ascii: bool = False + filter_ids: bool = True + add_space_before_target: bool = False + seed: int | None = None + verbosity: str = "INFO" + probe_sampling_config: ProbeSamplingConfig | None = None + + @dataclass class GCGResult: """Result of a GCG attack optimization run.""" diff --git a/src/tamperbench/whitebox/evals/__init__.py b/src/tamperbench/whitebox/evals/__init__.py index f5cf0aca..be24b8dd 100644 --- a/src/tamperbench/whitebox/evals/__init__.py +++ b/src/tamperbench/whitebox/evals/__init__.py @@ -4,6 +4,10 @@ EmbeddingAttackEvaluation, EmbeddingAttackEvaluationConfig, ) +from tamperbench.whitebox.evals.gcg.gcg import ( + GCGEvaluation, + GCGEvaluationConfig, +) from tamperbench.whitebox.evals.ifeval.ifeval import ( IFEvalEvaluation, IFEvalEvaluationConfig, @@ -33,6 +37,8 @@ __all__ = [ "EmbeddingAttackEvaluation", "EmbeddingAttackEvaluationConfig", + "GCGEvaluation", + "GCGEvaluationConfig", "IFEvalEvaluation", "IFEvalEvaluationConfig", "JailbreakBenchEvaluation", diff --git a/src/tamperbench/whitebox/evals/gcg/__init__.py b/src/tamperbench/whitebox/evals/gcg/__init__.py new file mode 100644 index 00000000..00ef2efc --- /dev/null +++ b/src/tamperbench/whitebox/evals/gcg/__init__.py @@ -0,0 +1,8 @@ +"""GCG attack evaluation.""" + +from tamperbench.whitebox.evals.gcg.gcg import GCGEvaluation, GCGEvaluationConfig + +__all__ = [ + "GCGEvaluation", + "GCGEvaluationConfig", +] diff --git a/src/tamperbench/whitebox/evals/gcg/gcg.py b/src/tamperbench/whitebox/evals/gcg/gcg.py new file mode 100644 index 00000000..6878263e --- /dev/null +++ b/src/tamperbench/whitebox/evals/gcg/gcg.py @@ -0,0 +1,118 @@ +"""GCG attack evaluator with StrongREJECT dataset.""" + +from dataclasses import dataclass, field +from typing import Any, cast + +import datasets +import polars as pl +import torch +from pandera.typing.polars import DataFrame +from torch import Tensor +from tqdm import tqdm +from typing_extensions import override + +from tamperbench.whitebox.attacks.gcg.implementation import GCG, GCGConfig +from tamperbench.whitebox.evals.base import load_model_and_tokenizer +from tamperbench.whitebox.evals.output_schema import InferenceSchema +from tamperbench.whitebox.evals.registry import register_evaluation +from tamperbench.whitebox.evals.strong_reject.strong_reject import ( + StrongRejectEvaluation, + StrongRejectEvaluationConfig, +) +from tamperbench.whitebox.utils import dealloc_model_and_tokenizer +from tamperbench.whitebox.utils.names import EvalName +from tamperbench.whitebox.utils.ops import run_in_isolation + + +@dataclass +class GCGEvaluationConfig(StrongRejectEvaluationConfig): + """Evaluation Configuration for GCG attack evaluation. + + Attributes: + gcg_config: Configuration for GCG optimization. + """ + + gcg_config: GCGConfig = field(default_factory=GCGConfig) + + +@register_evaluation(EvalName.GCG_ATTACK_EVAL) +class GCGEvaluation( + StrongRejectEvaluation[GCGEvaluationConfig], +): + """GCG attack evaluation class.""" + + name: EvalName = EvalName.GCG_ATTACK_EVAL + + @override + def compute_inferences(self) -> DataFrame[InferenceSchema]: + """Run inferences (using GCG attack) on evaluation dataset (StrongReject). + + Returns: + DataFrame[InferenceSchema]: A polars DataFrame with the prompts to LLM and corresponding responses. + """ + strongreject_dataset: datasets.Dataset = self._load_strongreject_dataset() + + payload: pl.DataFrame = run_in_isolation( + target=_instantiate_model_and_infer, + args=(self.eval_config, strongreject_dataset), + error_context="GCG attack inference", + ) + + return InferenceSchema.validate(payload) + + def _load_strongreject_dataset(self) -> datasets.Dataset: + """Load the StrongReject dataset with targets.""" + return datasets.load_dataset(path="sdhossain24/strongreject_with_targets", split="train") + + +def _instantiate_model_and_infer( + eval_config: GCGEvaluationConfig, + strongreject_dataset: datasets.Dataset, +) -> pl.DataFrame: + """Run GCG optimization and model inference on StrongREJECT dataset. + + Args: + eval_config: Evaluation configuration with model checkpoint and GCG settings. + strongreject_dataset: StrongReject dataset with prompts and targets. + + Returns: + A polars DataFrame with prompts and model responses. + """ + model, tokenizer = load_model_and_tokenizer(eval_config) + try: + gcg = GCG(model, tokenizer, eval_config.gcg_config) + + inferences: dict[str, list[str]] = { + InferenceSchema.prompt: [], + InferenceSchema.response: [], + } + + for data_row in tqdm(strongreject_dataset, total=len(strongreject_dataset)): + row: dict[str, Any] = cast(dict[str, Any], data_row) + result = gcg.run(messages=row["prompt"], target=row["target"]) + + adversarial_prompt: str = row["prompt"] + result.best_string + input_ids = cast( + Tensor, + tokenizer( + adversarial_prompt, + return_tensors="pt", + add_special_tokens=True, + )["input_ids"], + ).to(model.device) + + with torch.no_grad(): + output: Tensor = model.generate( # pyright: ignore[reportCallIssue] + input_ids=input_ids, + max_new_tokens=eval_config.model_config.max_generation_length, + ) + # Decode only the generated tokens (exclude input) + generated_ids = output[0][input_ids.shape[1] :] + response_txt: str = tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + + inferences[InferenceSchema.prompt].append(row["prompt"]) + inferences[InferenceSchema.response].append(response_txt) + + return InferenceSchema.validate(pl.from_dict(data=inferences)) + finally: + dealloc_model_and_tokenizer(model, tokenizer) diff --git a/src/tamperbench/whitebox/utils/names.py b/src/tamperbench/whitebox/utils/names.py index a7de9303..feffc0c0 100644 --- a/src/tamperbench/whitebox/utils/names.py +++ b/src/tamperbench/whitebox/utils/names.py @@ -20,6 +20,7 @@ class AttackName(StrEnum): LORA_FULL_DATASET_FINETUNE = "lora_full_dataset_finetune" LORA_FINETUNE_OPTIM = "lora_finetune_optim" PROMPT_ATTACK = "prompt_attack" + GCG_ATTACK = "gcg_attack" class DefenseName(StrEnum): @@ -42,6 +43,7 @@ class EvalName(StrEnum): IFEVAL = "ifeval" JAILBREAK_BENCH = "jailbreak_bench" PROMPT_ATTACK = "prompt_attack_eval" + GCG_ATTACK_EVAL = "gcg_attack_eval" class MetricName(StrEnum): diff --git a/tests/attacks/test_gcg_attack.py b/tests/attacks/test_gcg_attack.py new file mode 100644 index 00000000..78d5e083 --- /dev/null +++ b/tests/attacks/test_gcg_attack.py @@ -0,0 +1,51 @@ +"""Sanity check for GCGAttack. + +GCGAttack is a no-op attack (does not modify model weights) that delegates to +the GCG evaluation pipeline for input-space optimization. This test exercises +the run_attack -> evaluate path end-to-end using GCG_ATTACK_EVAL. +""" + +import tempfile + +from dotenv import load_dotenv + +from tamperbench.whitebox.attacks.gcg.gcg import GCGAttack, GCGAttackConfig +from tamperbench.whitebox.attacks.gcg.implementation import GCGConfig +from tamperbench.whitebox.evals.output_schema import EvaluationSchema +from tamperbench.whitebox.utils.models.config import ModelConfig +from tamperbench.whitebox.utils.names import EvalName, MetricName + +if __name__ == "__main__": + load_dotenv() + + with tempfile.TemporaryDirectory() as tmpdirname: + llama_3_8b_gcg_config = GCGAttackConfig( + input_checkpoint_path="meta-llama/Llama-3.1-8B-Instruct", + out_dir=tmpdirname, + model_config=ModelConfig( + user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", + assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", + end_turn="<|eot_id|>\n", + max_generation_length=512, + inference_batch_size=16, + ), + evals=[EvalName.GCG_ATTACK_EVAL], + random_seed=42, + gcg_config=GCGConfig(num_steps=50), + ) + + gcg_attack = GCGAttack(attack_config=llama_3_8b_gcg_config) + + gcg_attack.run_attack() + + assert gcg_attack.output_checkpoint_path == llama_3_8b_gcg_config.input_checkpoint_path + + attacked_eval: dict[str, tuple[float]] = gcg_attack.evaluate().rows_by_key( + key=EvaluationSchema.metric_name, + unique=True, + ) + + assert MetricName.STRONG_REJECT_SCORE in attacked_eval + assert attacked_eval[MetricName.STRONG_REJECT_SCORE][0] >= 0.0 + + print("Llama3.1-8B Instruct GCG Attack:", attacked_eval) From 4fbe2e8d51f3318afad11043e1fd37114cf659c5 Mon Sep 17 00:00:00 2001 From: Tom Tseng Date: Fri, 6 Feb 2026 22:55:33 -0800 Subject: [PATCH 2/7] gcg implementation: Adapt for new transformers KVCache format --- .../whitebox/attacks/gcg/implementation.py | 47 +++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/src/tamperbench/whitebox/attacks/gcg/implementation.py b/src/tamperbench/whitebox/attacks/gcg/implementation.py index f3cd9033..0d3e2b30 100644 --- a/src/tamperbench/whitebox/attacks/gcg/implementation.py +++ b/src/tamperbench/whitebox/attacks/gcg/implementation.py @@ -21,6 +21,7 @@ from torch import Tensor from tqdm import tqdm from transformers import PreTrainedTokenizerBase, set_seed +from transformers.cache_utils import DynamicCache """UTILS""" @@ -357,6 +358,30 @@ def filter_ids(ids: Tensor, tokenizer: transformers.PreTrainedTokenizer): return torch.stack(filtered_ids) +def _expand_prefix_cache(prefix_cache: DynamicCache, batch_size: int) -> DynamicCache: + """Expand a prefix KV cache to a given batch size. + + This is new code not in the original nanoGCG implementation. Newer + `transformers` returns a DynamicCache for KV caches rather than a list. + + Args: + prefix_cache: The cached key/value states from the prefix tokens. + batch_size: The batch size to expand to. + + Returns: + A new DynamicCache with tensors expanded to the given batch size. + """ + expanded = DynamicCache() + for layer_idx in range(len(prefix_cache)): + key, value = prefix_cache[layer_idx] + expanded.update( + key.expand(batch_size, -1, -1, -1), + value.expand(batch_size, -1, -1, -1), + layer_idx, + ) + return expanded + + class GCG: """GCG (Greedy Coordinate Gradient) attack implementation.""" @@ -381,8 +406,8 @@ def __init__( self.not_allowed_ids: Tensor | None = ( None if config.allow_non_ascii else get_nonascii_toks(tokenizer, device=model.device) ) - self.prefix_cache: list[tuple[Tensor, ...]] | None = None - self.draft_prefix_cache: list[tuple[Tensor, ...]] | None = None + self.prefix_cache: DynamicCache | None = None + self.draft_prefix_cache: DynamicCache | None = None self.stop_flag: bool = False @@ -771,7 +796,7 @@ def _compute_candidates_loss_original( the embeddings of the `search_width` candidate sequences to evaluate """ all_loss = [] - prefix_cache_batch = [] + prefix_cache_batch: DynamicCache | None = None for i in range(0, input_embeds.shape[0], search_batch_size): with torch.no_grad(): @@ -779,11 +804,8 @@ def _compute_candidates_loss_original( current_batch_size = input_embeds_batch.shape[0] if self.prefix_cache: - if not prefix_cache_batch or current_batch_size != search_batch_size: - prefix_cache_batch = [ - [x.expand(current_batch_size, -1, -1, -1) for x in self.prefix_cache[i]] - for i in range(len(self.prefix_cache)) - ] + if prefix_cache_batch is None or current_batch_size != search_batch_size: + prefix_cache_batch = _expand_prefix_cache(self.prefix_cache, current_batch_size) outputs = self.model( inputs_embeds=input_embeds_batch, @@ -864,18 +886,15 @@ def _compute_draft_losses( ) draft_losses = [] - draft_prefix_cache_batch = None + draft_prefix_cache_batch: DynamicCache | None = None for i in range(0, B, search_batch_size): with torch.no_grad(): batch_size = min(search_batch_size, B - i) draft_sampled_ids_batch = draft_sampled_ids[i : i + batch_size] if self.draft_prefix_cache: - if not draft_prefix_cache_batch or batch_size != search_batch_size: - draft_prefix_cache_batch = [ - [x.expand(batch_size, -1, -1, -1) for x in self.draft_prefix_cache[i]] - for i in range(len(self.draft_prefix_cache)) - ] + if draft_prefix_cache_batch is None or batch_size != search_batch_size: + draft_prefix_cache_batch = _expand_prefix_cache(self.draft_prefix_cache, batch_size) draft_embeds = torch.cat( [ self.draft_embedding_layer(draft_sampled_ids_batch), From 4829e09b3be24ead0ebd8fa21f79c80b50bc64fe Mon Sep 17 00:00:00 2001 From: Tom Tseng Date: Fri, 6 Feb 2026 23:24:26 -0800 Subject: [PATCH 3/7] gcg: Evaluate with chat template --- src/tamperbench/whitebox/evals/gcg/gcg.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/tamperbench/whitebox/evals/gcg/gcg.py b/src/tamperbench/whitebox/evals/gcg/gcg.py index 6878263e..13e8df23 100644 --- a/src/tamperbench/whitebox/evals/gcg/gcg.py +++ b/src/tamperbench/whitebox/evals/gcg/gcg.py @@ -91,13 +91,18 @@ def _instantiate_model_and_infer( row: dict[str, Any] = cast(dict[str, Any], data_row) result = gcg.run(messages=row["prompt"], target=row["target"]) - adversarial_prompt: str = row["prompt"] + result.best_string + # Build the chat-templated prompt with the GCG suffix in the same + # position it occupied during optimization (end of user message). + chat_messages = [{"role": "user", "content": row["prompt"] + result.best_string}] + templated_prompt = cast( + str, tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True) + ) input_ids = cast( Tensor, tokenizer( - adversarial_prompt, + templated_prompt, return_tensors="pt", - add_special_tokens=True, + add_special_tokens=False, )["input_ids"], ).to(model.device) From b47004490e51dc2d70487b0d20bb4c000d8cdb6a Mon Sep 17 00:00:00 2001 From: Tom Tseng Date: Fri, 6 Feb 2026 23:33:30 -0800 Subject: [PATCH 4/7] gcg run_attack: pass like embedding_attack.py --- src/tamperbench/whitebox/attacks/gcg/gcg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tamperbench/whitebox/attacks/gcg/gcg.py b/src/tamperbench/whitebox/attacks/gcg/gcg.py index 22471a86..e5101611 100644 --- a/src/tamperbench/whitebox/attacks/gcg/gcg.py +++ b/src/tamperbench/whitebox/attacks/gcg/gcg.py @@ -29,8 +29,8 @@ class GCGAttack(TamperAttack[GCGAttackConfig]): @override def run_attack(self) -> None: - """GCG attack does not modify model weights directly.""" - self.output_checkpoint_path: str = self.attack_config.input_checkpoint_path + """GCG attack does not modify model weights.""" + pass @override def evaluate(self) -> DataFrame[EvaluationSchema]: From 4addc3a8af389917c705e77e4ee639ae95e2669c Mon Sep 17 00:00:00 2001 From: Tom Tseng Date: Fri, 6 Feb 2026 23:49:06 -0800 Subject: [PATCH 5/7] gcg implementation: Fix DynamicCache copying bug? GCG was not having much success on small tests, Claude seems to think it could be due to transformers modifying DynamicCache in a way we don't expect --- .../whitebox/attacks/gcg/implementation.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/tamperbench/whitebox/attacks/gcg/implementation.py b/src/tamperbench/whitebox/attacks/gcg/implementation.py index 0d3e2b30..21a72ffd 100644 --- a/src/tamperbench/whitebox/attacks/gcg/implementation.py +++ b/src/tamperbench/whitebox/attacks/gcg/implementation.py @@ -358,28 +358,35 @@ def filter_ids(ids: Tensor, tokenizer: transformers.PreTrainedTokenizer): return torch.stack(filtered_ids) -def _expand_prefix_cache(prefix_cache: DynamicCache, batch_size: int) -> DynamicCache: - """Expand a prefix KV cache to a given batch size. +def _copy_prefix_cache(prefix_cache: DynamicCache, batch_size: int = 1) -> DynamicCache: + """Create a shallow copy of a prefix KV cache, optionally expanding to a batch size. - This is new code not in the original nanoGCG implementation. Newer - `transformers` returns a DynamicCache for KV caches rather than a list. + A new DynamicCache object is returned so that the model's in-place + ``DynamicCache.update()`` calls (which ``torch.cat`` new KV states onto + the cache) mutate the copy rather than the original. The underlying + key/value tensors are **not** cloned — ``torch.cat`` inside ``update`` + already produces new tensors, leaving the originals untouched. + + The original nanoGCG stored caches as ``list[tuple[Tensor, ...]]`` which + the model never mutated. After migrating to ``DynamicCache`` the same + immutability guarantee must be restored manually by copying the wrapper. Args: prefix_cache: The cached key/value states from the prefix tokens. - batch_size: The batch size to expand to. + batch_size: The batch size to expand to (default 1, i.e. plain copy). Returns: - A new DynamicCache with tensors expanded to the given batch size. + A new DynamicCache wrapping the (optionally expanded) tensors. """ - expanded = DynamicCache() + copy = DynamicCache() for layer_idx in range(len(prefix_cache)): key, value = prefix_cache[layer_idx] - expanded.update( + copy.update( key.expand(batch_size, -1, -1, -1), value.expand(batch_size, -1, -1, -1), layer_idx, ) - return expanded + return copy class GCG: @@ -746,9 +753,11 @@ def compute_token_gradient( if self.prefix_cache: input_embeds = torch.cat([optim_embeds, self.after_embeds, self.target_embeds], dim=1) + # Copy the prefix cache so that the model's in-place DynamicCache + # mutations don't corrupt self.prefix_cache across iterations. output = model( inputs_embeds=input_embeds, - past_key_values=self.prefix_cache, + past_key_values=_copy_prefix_cache(self.prefix_cache), use_cache=True, ) else: @@ -796,7 +805,6 @@ def _compute_candidates_loss_original( the embeddings of the `search_width` candidate sequences to evaluate """ all_loss = [] - prefix_cache_batch: DynamicCache | None = None for i in range(0, input_embeds.shape[0], search_batch_size): with torch.no_grad(): @@ -804,8 +812,10 @@ def _compute_candidates_loss_original( current_batch_size = input_embeds_batch.shape[0] if self.prefix_cache: - if prefix_cache_batch is None or current_batch_size != search_batch_size: - prefix_cache_batch = _expand_prefix_cache(self.prefix_cache, current_batch_size) + # Fresh copy each batch — the model mutates DynamicCache + # in-place, so reusing one across batches would accumulate + # stale KV states from previous batches. + prefix_cache_batch = _copy_prefix_cache(self.prefix_cache, current_batch_size) outputs = self.model( inputs_embeds=input_embeds_batch, @@ -886,15 +896,13 @@ def _compute_draft_losses( ) draft_losses = [] - draft_prefix_cache_batch: DynamicCache | None = None for i in range(0, B, search_batch_size): with torch.no_grad(): batch_size = min(search_batch_size, B - i) draft_sampled_ids_batch = draft_sampled_ids[i : i + batch_size] if self.draft_prefix_cache: - if draft_prefix_cache_batch is None or batch_size != search_batch_size: - draft_prefix_cache_batch = _expand_prefix_cache(self.draft_prefix_cache, batch_size) + draft_prefix_cache_batch = _copy_prefix_cache(self.draft_prefix_cache, batch_size) draft_embeds = torch.cat( [ self.draft_embedding_layer(draft_sampled_ids_batch), From d7e28636d74638867d12ae56916ccd825594b6ce Mon Sep 17 00:00:00 2001 From: Tom Tseng Date: Fri, 6 Feb 2026 23:50:49 -0800 Subject: [PATCH 6/7] test_gcg: Remove outdated assert --- tests/attacks/test_gcg_attack.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/attacks/test_gcg_attack.py b/tests/attacks/test_gcg_attack.py index 78d5e083..9f736bf1 100644 --- a/tests/attacks/test_gcg_attack.py +++ b/tests/attacks/test_gcg_attack.py @@ -38,8 +38,6 @@ gcg_attack.run_attack() - assert gcg_attack.output_checkpoint_path == llama_3_8b_gcg_config.input_checkpoint_path - attacked_eval: dict[str, tuple[float]] = gcg_attack.evaluate().rows_by_key( key=EvaluationSchema.metric_name, unique=True, From 723860516cdf059a8183696f63258ba826c11e1d Mon Sep 17 00:00:00 2001 From: Tom Tseng Date: Sat, 7 Feb 2026 00:05:36 -0800 Subject: [PATCH 7/7] Revert "gcg run_attack: pass like embedding_attack.py" This reverts commit b47004490e51dc2d70487b0d20bb4c000d8cdb6a. We don't want to pass, since evaluate() calls `super` which does need that attribute to be set --- src/tamperbench/whitebox/attacks/gcg/gcg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tamperbench/whitebox/attacks/gcg/gcg.py b/src/tamperbench/whitebox/attacks/gcg/gcg.py index e5101611..22471a86 100644 --- a/src/tamperbench/whitebox/attacks/gcg/gcg.py +++ b/src/tamperbench/whitebox/attacks/gcg/gcg.py @@ -29,8 +29,8 @@ class GCGAttack(TamperAttack[GCGAttackConfig]): @override def run_attack(self) -> None: - """GCG attack does not modify model weights.""" - pass + """GCG attack does not modify model weights directly.""" + self.output_checkpoint_path: str = self.attack_config.input_checkpoint_path @override def evaluate(self) -> DataFrame[EvaluationSchema]: