Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 32 additions & 121 deletions src/tamperbench/whitebox/attacks/gcg/gcg.py
Original file line number Diff line number Diff line change
@@ -1,141 +1,52 @@
"""GCG attack interface."""

import logging
from dataclasses import dataclass
from typing import TypeVar
from dataclasses import dataclass, field

import peft
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
)
import polars as pl
from pandera.typing.polars import DataFrame
from typing_extensions import override

from tamperbench.whitebox.attacks.base import TamperAttack, TamperAttackConfig
from tamperbench.whitebox.attacks.gcg.implementation import GCG, ProbeSamplingConfig
from tamperbench.whitebox.evals import (
StrongRejectEvaluation,
StrongRejectEvaluationConfig,
)
from tamperbench.whitebox.utils.models.default_tokens import DEFAULT_PAD_TOKEN
from tamperbench.whitebox.utils.names import EvalName

logger = logging.getLogger(__name__)

H = TypeVar(name="H", bound="GCGConfig")
from tamperbench.whitebox.attacks.gcg.implementation import GCGConfig
from tamperbench.whitebox.attacks.registry import register_attack
from tamperbench.whitebox.evals.gcg.gcg import GCGEvaluation, GCGEvaluationConfig
from tamperbench.whitebox.evals.output_schema import EvaluationSchema
from tamperbench.whitebox.utils.names import AttackName, EvalName


@dataclass
class GCGConfig(TamperAttackConfig):
"""Config for nanoGCG attack."""
class GCGAttackConfig(TamperAttackConfig):
"""Config for GCG attack."""

num_steps: int = 250
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These params were originally defined in https://github.com/GraySwanAI/nanoGCG/blob/v0.3.0-release/nanogcg/gcg.py, so I'm just moving them back into implementation.py, which is based on that file.

optim_str_init: str | list[str] = "x x x x x x x x x x x x x x x x x x x x"
search_width: int = 512
batch_size: int = None
topk: int = 256
n_replace: int = 1
buffer_size: int = 0
use_mellowmax: bool = False
mellowmax_alpha: float = 1.0
early_stop: bool = False
use_prefix_cache: bool = True
allow_non_ascii: bool = False
filter_ids: bool = True
add_space_before_target: bool = False
seed: int = None
verbosity: str = "INFO"
probe_sampling_config: ProbeSamplingConfig | None = None
gcg_config: GCGConfig = field(default_factory=GCGConfig)


class GCGAttack(TamperAttack[H]):
@register_attack(AttackName.GCG_ATTACK, GCGAttackConfig)
class GCGAttack(TamperAttack[GCGAttackConfig]):
"""GCG attack class."""

@override
def run_attack(
messages: str | list[dict[str, str]],
target: str,
config: GCGConfig | None = None,
) -> GCGResult:
"""Generate a single optimized string using GCG.

Args:
messages: The conversation to use for optimization.
target: The target generation.
config: The GCG configuration to use.

Returns:
A GCGResult object that contains losses and the optimized strings.
"""
model: PreTrainedModel = self.load_model()
tokenizer: PreTrainedTokenizer = self.load_tokenizer()

if config is None:
config = GCGConfig()

logger.setLevel(getattr(logging, config.verbosity))

gcg = GCG(model, tokenizer, config)
result = gcg.run(messages, target)
return result
name: AttackName = AttackName.GCG_ATTACK

@override
def evaluate(self) -> dict[str, float]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewriting the rest of this file to look more like embedding_attack.

"""Evaluate attack on relevant evaluations.

Returns:
dict[str, float]: A dictionary with the relevant metrics and the associated values.
"""
results: dict[str, float] = {}

if EvalName.STRONG_REJECT in self.attack_config.evals:
results.update(self.evaluate_strong_reject())

return results

def load_model(self) -> PreTrainedModel:
"""Load model from path in HuggingFace and/or PEFT."""
model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.attack_config.base_input_checkpoint_path,
torch_dtype=torch.bfloat16,
# TODO PR #55: fully support flash attention
attn_implementation="flash_attention_2",
device_map="auto",
)

if self.attack_config.lora_input_folder:
model = peft.PeftModel.from_pretrained(
model, model_id=self.attack_config.lora_input_folder
).merge_and_unload()

return model

def load_tokenizer(self) -> PreTrainedTokenizer:
"""Load tokenizer from path in HuggingFace and/or PEFT."""
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=self.attack_config.base_input_checkpoint_path,
padding_side="right",
use_fast=False,
)

if tokenizer.pad_token is None:
tokenizer.add_special_tokens(special_tokens_dict={"pad_token": DEFAULT_PAD_TOKEN})

return tokenizer
def run_attack(self) -> None:
"""GCG attack does not modify model weights directly."""
self.output_checkpoint_path: str = self.attack_config.input_checkpoint_path

@override
def evaluate_strong_reject(self) -> dict[str, float]:
"""Evaluate attack on the `StrongRejectEvaluation` evaluator."""
eval_cfg: StrongRejectEvaluationConfig = StrongRejectEvaluationConfig(
base_checkpoint=self.attack_config.base_output_checkpoint_path,
lora_folder=self.attack_config.lora_output_folder,
max_generation_length=self.attack_config.max_generation_length,
batch_size=64,
small=True,
def evaluate(self) -> DataFrame[EvaluationSchema]:
"""Evaluate GCG attack, including GCG-specific evaluation if requested."""
results = super().evaluate()
if EvalName.GCG_ATTACK_EVAL in self.attack_config.evals:
results = pl.concat([results, self._evaluate_gcg()])
return EvaluationSchema.validate(results)

def _evaluate_gcg(self) -> DataFrame[EvaluationSchema]:
"""Instantiate GCGEvaluation and run evaluation."""
eval_cfg = GCGEvaluationConfig(
model_checkpoint=self.output_checkpoint_path,
out_dir=self.attack_config.out_dir,
model_config=self.attack_config.model_config,
gcg_config=self.attack_config.gcg_config,
)
evaluator: StrongRejectEvaluation[StrongRejectEvaluationConfig] = StrongRejectEvaluation(eval_config=eval_cfg)

evaluator = GCGEvaluation(eval_cfg)
return evaluator.run_evaluation()
88 changes: 70 additions & 18 deletions src/tamperbench/whitebox/attacks/gcg/implementation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
"""GCG attack implementation module."""
"""GCG attack implementation module.

Sourced from https://github.com/GraySwanAI/nanoGCG/blob/v0.3.0-release/nanogcg/gcg.py
and https://github.com/GraySwanAI/nanoGCG/blob/v0.3.0-release/nanogcg/utils.py
"""

import copy
import functools
Expand All @@ -17,8 +21,7 @@
from torch import Tensor
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase, set_seed

from tamperbench.whitebox.attacks.gcg.gcg import GCGConfig
from transformers.cache_utils import DynamicCache

"""UTILS"""

Expand Down Expand Up @@ -196,6 +199,29 @@ class ProbeSamplingConfig:
sampling_factor: int = 16


@dataclass
class GCGConfig:
    """Config for nanoGCG attack.

    Field names and defaults follow the upstream nanoGCG v0.3.0 implementation
    that this module is sourced from (see module docstring).
    """

    num_steps: int = 250  # number of GCG optimization iterations
    optim_str_init: str | list[str] = "x x x x x x x x x x x x x x x x x x x x"  # initial optimized string (or one per sequence)
    search_width: int = 512  # candidate replacements sampled per step
    batch_size: int | None = None  # forward-pass batch size; None presumably means auto-sized — TODO confirm
    topk: int = 256  # top-k gradient-ranked token substitutions considered per position
    n_replace: int = 1  # tokens replaced per candidate sequence
    buffer_size: int = 0  # attack buffer size; 0 presumably disables buffering — TODO confirm
    use_mellowmax: bool = False  # NOTE(review): appears to swap CE loss for mellowmax — confirm
    mellowmax_alpha: float = 1.0  # alpha for the mellowmax loss (only used when use_mellowmax)
    early_stop: bool = False  # stop early, presumably once the target generation is achieved — TODO confirm
    use_prefix_cache: bool = True  # cache KV states of the static prompt prefix across steps
    allow_non_ascii: bool = False  # if False, non-ASCII tokens are excluded from the search
    filter_ids: bool = True  # filter candidate ids that do not survive retokenization
    add_space_before_target: bool = False  # prepend a space to the target string
    seed: int | None = None  # RNG seed; None leaves seeding untouched
    verbosity: str = "INFO"  # logging level name, resolved via getattr(logging, verbosity)
    probe_sampling_config: ProbeSamplingConfig | None = None  # enables probe sampling when set


@dataclass
class GCGResult:
"""Result of a GCG attack optimization run."""
Expand Down Expand Up @@ -332,6 +358,37 @@ def filter_ids(ids: Tensor, tokenizer: transformers.PreTrainedTokenizer):
return torch.stack(filtered_ids)


def _copy_prefix_cache(prefix_cache: DynamicCache, batch_size: int = 1) -> DynamicCache:
    """Create a shallow copy of a prefix KV cache, optionally expanding to a batch size.

    A new DynamicCache object is returned so that the model's in-place
    ``DynamicCache.update()`` calls (which ``torch.cat`` new KV states onto
    the cache) mutate the copy rather than the original. The underlying
    key/value tensors are **not** cloned — ``torch.cat`` inside ``update``
    already produces new tensors, leaving the originals untouched.

    The original nanoGCG stored caches as ``list[tuple[Tensor, ...]]`` which
    the model never mutated. After migrating to ``DynamicCache`` the same
    immutability guarantee must be restored manually by copying the wrapper.

    Args:
        prefix_cache: The cached key/value states from the prefix tokens.
        batch_size: The batch size to expand to (default 1, i.e. plain copy).

    Returns:
        A new DynamicCache wrapping the (optionally expanded) tensors.
    """
    # Named `new_cache` (not `copy`) to avoid shadowing the stdlib `copy`
    # module imported at the top of this file.
    new_cache = DynamicCache()
    for layer_idx in range(len(prefix_cache)):
        key, value = prefix_cache[layer_idx]
        new_cache.update(
            # expand() creates broadcast views, not copies, so batching is cheap.
            key.expand(batch_size, -1, -1, -1),
            value.expand(batch_size, -1, -1, -1),
            layer_idx,
        )
    return new_cache


class GCG:
"""GCG (Greedy Coordinate Gradient) attack implementation."""

Expand All @@ -356,8 +413,8 @@ def __init__(
self.not_allowed_ids: Tensor | None = (
None if config.allow_non_ascii else get_nonascii_toks(tokenizer, device=model.device)
)
self.prefix_cache: list[tuple[Tensor, ...]] | None = None
self.draft_prefix_cache: list[tuple[Tensor, ...]] | None = None
self.prefix_cache: DynamicCache | None = None
self.draft_prefix_cache: DynamicCache | None = None

self.stop_flag: bool = False

Expand Down Expand Up @@ -696,9 +753,11 @@ def compute_token_gradient(

if self.prefix_cache:
input_embeds = torch.cat([optim_embeds, self.after_embeds, self.target_embeds], dim=1)
# Copy the prefix cache so that the model's in-place DynamicCache
# mutations don't corrupt self.prefix_cache across iterations.
output = model(
inputs_embeds=input_embeds,
past_key_values=self.prefix_cache,
past_key_values=_copy_prefix_cache(self.prefix_cache),
use_cache=True,
)
else:
Expand Down Expand Up @@ -746,19 +805,17 @@ def _compute_candidates_loss_original(
the embeddings of the `search_width` candidate sequences to evaluate
"""
all_loss = []
prefix_cache_batch = []

for i in range(0, input_embeds.shape[0], search_batch_size):
with torch.no_grad():
input_embeds_batch = input_embeds[i : i + search_batch_size]
current_batch_size = input_embeds_batch.shape[0]

if self.prefix_cache:
if not prefix_cache_batch or current_batch_size != search_batch_size:
prefix_cache_batch = [
[x.expand(current_batch_size, -1, -1, -1) for x in self.prefix_cache[i]]
for i in range(len(self.prefix_cache))
]
# Fresh copy each batch — the model mutates DynamicCache
# in-place, so reusing one across batches would accumulate
# stale KV states from previous batches.
prefix_cache_batch = _copy_prefix_cache(self.prefix_cache, current_batch_size)

outputs = self.model(
inputs_embeds=input_embeds_batch,
Expand Down Expand Up @@ -839,18 +896,13 @@ def _compute_draft_losses(
)

draft_losses = []
draft_prefix_cache_batch = None
for i in range(0, B, search_batch_size):
with torch.no_grad():
batch_size = min(search_batch_size, B - i)
draft_sampled_ids_batch = draft_sampled_ids[i : i + batch_size]

if self.draft_prefix_cache:
if not draft_prefix_cache_batch or batch_size != search_batch_size:
draft_prefix_cache_batch = [
[x.expand(batch_size, -1, -1, -1) for x in self.draft_prefix_cache[i]]
for i in range(len(self.draft_prefix_cache))
]
draft_prefix_cache_batch = _copy_prefix_cache(self.draft_prefix_cache, batch_size)
draft_embeds = torch.cat(
[
self.draft_embedding_layer(draft_sampled_ids_batch),
Expand Down
6 changes: 6 additions & 0 deletions src/tamperbench/whitebox/evals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
EmbeddingAttackEvaluation,
EmbeddingAttackEvaluationConfig,
)
from tamperbench.whitebox.evals.gcg.gcg import (
GCGEvaluation,
GCGEvaluationConfig,
)
from tamperbench.whitebox.evals.ifeval.ifeval import (
IFEvalEvaluation,
IFEvalEvaluationConfig,
Expand Down Expand Up @@ -33,6 +37,8 @@
__all__ = [
"EmbeddingAttackEvaluation",
"EmbeddingAttackEvaluationConfig",
"GCGEvaluation",
"GCGEvaluationConfig",
"IFEvalEvaluation",
"IFEvalEvaluationConfig",
"JailbreakBenchEvaluation",
Expand Down
8 changes: 8 additions & 0 deletions src/tamperbench/whitebox/evals/gcg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""GCG attack evaluation."""

from tamperbench.whitebox.evals.gcg.gcg import GCGEvaluation, GCGEvaluationConfig

__all__ = [
"GCGEvaluation",
"GCGEvaluationConfig",
]
Loading