Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/tamperbench/whitebox/attacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
embedding_attack as _embedding,
)

# Re-export embedding attack classes for convenience
# Re-export attack classes for convenience
from tamperbench.whitebox.attacks.embedding_attack.embedding_attack import (
EmbeddingAttack,
EmbeddingAttackConfig,
Expand All @@ -21,6 +21,10 @@
from tamperbench.whitebox.attacks.jailbreak_finetune import (
jailbreak_finetune as _jailbreak,
)
from tamperbench.whitebox.attacks.latent_attack.latent_attack import (
LatentAttack,
LatentAttackConfig,
)
from tamperbench.whitebox.attacks.lora_finetune import lora_finetune as _lora
from tamperbench.whitebox.attacks.multilingual_finetune import (
multilingual_finetune as _multilingual,
Expand All @@ -31,4 +35,6 @@
__all__ = [
"EmbeddingAttack",
"EmbeddingAttackConfig",
"LatentAttack",
"LatentAttackConfig",
]
14 changes: 14 additions & 0 deletions src/tamperbench/whitebox/attacks/latent_attack/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Sheshadri et al.'s LLM Latent Attack.

The code is primarily sourced from https://github.com/aengusl/latent-adversarial-training/
but may contain modifications for enhanced readability and easier integration with tamperbench.
The work can be found at https://arxiv.org/pdf/2407.15549 and can be cited as follows:
```
@article{sheshadri2024targeted,
title={Targeted Latent Adversarial Training Improves Robustness to Persistent Harmful Behaviors in LLMs},
author={Sheshadri, Abhay and Ewart, Aidan and Guo, Phillip and Lynch, Aengus and Wu, Cindy and Hebbar, Vivek and Sleight, Henry and Stickland, Asa Cooper and Perez, Ethan and Hadfield-Menell, Dylan and Casper, Stephen},
journal={arXiv preprint arXiv:2407.15549},
year={2024}
}
```
"""
62 changes: 62 additions & 0 deletions src/tamperbench/whitebox/attacks/latent_attack/latent_attack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Latent attack wrapper."""

from dataclasses import dataclass

import polars as pl
from pandera.typing.polars import DataFrame
from typing_extensions import override

from tamperbench.whitebox.attacks.base import TamperAttack, TamperAttackConfig
from tamperbench.whitebox.attacks.registry import register_attack
from tamperbench.whitebox.evals.latent_attack import (
LatentAttackEvaluation,
LatentAttackEvaluationConfig,
)
from tamperbench.whitebox.evals.output_schema import EvaluationSchema
from tamperbench.whitebox.utils.names import AttackName, EvalName


@dataclass
class LatentAttackConfig(TamperAttackConfig):
    """Configuration for the latent attack wrapper.

    Attributes:
        epsilon: Norm bound constraining the size of the latent perturbation.
        attack_layer: Dot-path of the module whose activations are perturbed
            (e.g., "transformer.h.0.mlp").
        small: Whether to use the small dataset for testing.
    """

    epsilon: float
    attack_layer: str
    small: bool


@register_attack(AttackName.LATENT_ATTACK, LatentAttackConfig)
class LatentAttack(TamperAttack[LatentAttackConfig]):
    """Wrapper that evaluates models against the latent-space attack."""

    name: AttackName = AttackName.LATENT_ATTACK

    @override
    def run_attack(self) -> None:
        """No-op: the latent attack does not modify model weights directly."""

    @override
    def evaluate(self) -> DataFrame[EvaluationSchema]:
        """Run the base evaluations, plus the latent attack eval when requested."""
        frames = super().evaluate()
        if EvalName.LATENT_ATTACK in self.attack_config.evals:
            frames = pl.concat([frames, self._evaluate_latent_attack()])
        return EvaluationSchema.validate(frames)

    def _evaluate_latent_attack(self) -> DataFrame[EvaluationSchema]:
        """Build a LatentAttackEvaluation from this attack's config and run it."""
        cfg = self.attack_config
        evaluator = LatentAttackEvaluation(
            LatentAttackEvaluationConfig(
                model_checkpoint=cfg.input_checkpoint_path,
                out_dir=cfg.out_dir,
                model_config=cfg.model_config,
                epsilon=cfg.epsilon,
                attack_layer=cfg.attack_layer,
            )
        )
        return evaluator.run_evaluation()
6 changes: 6 additions & 0 deletions src/tamperbench/whitebox/evals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
IFEvalEvaluation,
IFEvalEvaluationConfig,
)
from tamperbench.whitebox.evals.latent_attack.latent_attack import (
LatentAttackEvaluation,
LatentAttackEvaluationConfig,
)
from tamperbench.whitebox.evals.mbpp.mbpp import (
MBPPEvaluation,
MBPPEvaluationConfig,
Expand All @@ -34,6 +38,8 @@
"EmbeddingAttackEvaluationConfig",
"IFEvalEvaluation",
"IFEvalEvaluationConfig",
"LatentAttackEvaluation",
"LatentAttackEvaluationConfig",
"MBPPEvaluation",
"MBPPEvaluationConfig",
"MMLUProTestEvaluation",
Expand Down
28 changes: 28 additions & 0 deletions src/tamperbench/whitebox/evals/latent_attack/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Sheshadri et al.'s LLM Latent Attack.

The code is primarily sourced from https://github.com/aengusl/latent-adversarial-training/
but may contain modifications for enhanced readability and easier integration with tamperbench.
The work can be found at https://arxiv.org/pdf/2407.15549 and can be cited as follows:
```
@article{sheshadri2024targeted,
title={Targeted Latent Adversarial Training Improves Robustness to Persistent Harmful Behaviors in LLMs},
author={Sheshadri, Abhay and Ewart, Aidan and Guo, Phillip and Lynch, Aengus and Wu, Cindy and Hebbar, Vivek and Sleight, Henry and Stickland, Asa Cooper and Perez, Ethan and Hadfield-Menell, Dylan and Casper, Stephen},
journal={arXiv preprint arXiv:2407.15549},
year={2024}
}
```
"""

from tamperbench.whitebox.evals.latent_attack.attack_class import GDAdversary
from tamperbench.whitebox.evals.latent_attack.hooks import CustomHook
from tamperbench.whitebox.evals.latent_attack.latent_attack import (
LatentAttackEvaluation,
LatentAttackEvaluationConfig,
)

__all__ = [
"CustomHook",
"GDAdversary",
"LatentAttackEvaluation",
"LatentAttackEvaluationConfig",
]
221 changes: 221 additions & 0 deletions src/tamperbench/whitebox/evals/latent_attack/attack_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
"""Ported from https://github.com/aengusl/latent-adversarial-training/."""

import torch
from typing_extensions import override


class LowRankAdversary(torch.nn.Module):
    """Adversary that adds a LoRA-style low-rank perturbation to activations.

    Inputs are projected down to a `rank`-dimensional space via `lora_a`,
    projected back up via `lora_b`, and the result is added residually to the
    original activations.
    """

    def __init__(
        self,
        dim: int,
        rank: int,
        device: str,
        bias: bool = False,
        zero_init: bool = True,
    ):
        """Initialize LowRankAdversary.

        Args:
            dim: Dimensionality of the input activations.
            rank: Low-rank dimension for the intermediate projection.
            device: Device on which to place the model.
            bias: Whether `lora_b` carries a bias term.
            zero_init: If True, zero `lora_b`'s weights so the initial
                perturbation is exactly zero.
        """
        super().__init__()
        self.dim: int = dim
        self.rank: int = rank
        self.device: str = device
        self.lora_a: torch.nn.Linear = torch.nn.Linear(dim, rank, bias=False).to(device)
        self.lora_b: torch.nn.Linear = torch.nn.Linear(rank, dim, bias=bias).to(device)
        if zero_init:
            # Zeroed up-projection => the adversary starts as an identity map.
            self.lora_b.weight.data.zero_()

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` plus its low-rank adversarial perturbation."""
        delta = self.lora_b(self.lora_a(x))
        return x + delta


class FullRankAdversary(torch.nn.Module):
    """Adversary that adds a full-rank linear perturbation to activations.

    A single square linear map `m` produces the perturbation, which is added
    residually to the input. Its weights start at zero, so with the default
    `bias=False` the module initially acts as an identity function.
    """

    def __init__(self, dim: int, device: str, bias: bool = False):
        """Initialize FullRankAdversary.

        Args:
            dim: Dimensionality of the input activations.
            device: Device on which to place the model.
            bias: Whether the linear layer carries a bias term.
        """
        super().__init__()
        self.dim: int = dim
        self.device: str = device
        self.m: torch.nn.Linear = torch.nn.Linear(dim, dim, bias=bias).to(device)
        # Zeroed weights => zero initial perturbation.
        self.m.weight.data.zero_()

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` plus its full-rank adversarial perturbation."""
        return x + self.m(x)


class GDAdversary(torch.nn.Module):
"""Standard projected gradient descent (PGD)-style adversary.

Perturbations are stored as trainable parameters and projected back
into the epsilon-ball to constrain their norm.
"""

def __init__(
self,
dim: int,
epsilon: float,
attack_mask: torch.Tensor,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
):
"""Initialize GD adversary.

Args:
dim: Dimensionality of the input activations.
epsilon: Maximum perturbation norm.
attack_mask: Boolean mask indicating which positions to perturb.
device: Device for tensors.
dtype: Optional data type for perturbations.
"""
super().__init__()
self.device: str | torch.device | None = device
self.epsilon: float = epsilon

if dtype:
self.attack: torch.nn.Parameter = torch.nn.Parameter(
torch.zeros(
attack_mask.shape[0],
attack_mask.shape[1],
dim,
device=self.device,
dtype=dtype,
)
)
else:
self.attack = torch.nn.Parameter(
torch.zeros(attack_mask.shape[0], attack_mask.shape[1], dim, device=self.device)
)
self.clip_attack()
self.attack_mask: torch.Tensor = attack_mask

@override
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Add adversarial perturbations to the input tensor.

Handles both generation mode (no perturbation applied if sequence
length == 1) and training mode (perturb according to mask).
"""
if x.shape[1] == 1 and self.attack.shape[1] != 1:
# Generation mode: perturbation already applied
return x
else:
# Ensure perturbation tensor is on the same device
if self.device is None or self.device != x.device:
with torch.no_grad():
self.device = x.device
self.attack.data = self.attack.data.to(self.device)
self.attack_mask = self.attack_mask.to(self.device)

# Apply perturbations at masked positions
perturbed_acts = x[self.attack_mask[:, : x.shape[1]]] + self.attack[self.attack_mask[:, : x.shape[1]]].to(
x.dtype
)
x[self.attack_mask[:, : x.shape[1]]] = perturbed_acts

return x

def clip_attack(self) -> None:
"""Project the adversarial perturbations back into the epsilon-ball.

Ensures that perturbation norms do not exceed `epsilon`.
"""
with torch.no_grad():
norms = torch.norm(self.attack, dim=-1, keepdim=True)
scale = torch.clamp(norms / self.epsilon, min=1)
self.attack.div_(scale)


class WhitenedGDAdversary(torch.nn.Module):
"""Whitened variant of GD adversary.

Perturbations are trained in a whitened space (e.g., via PCA).
A projection matrix is used to map between whitened and original spaces.
"""

def __init__(
self,
dim: int,
device: str,
epsilon: float,
attack_mask: torch.Tensor,
proj: torch.Tensor | None = None,
inv_proj: torch.Tensor | None = None,
):
"""Initialize WhitenedPDG adversary.

Args:
dim: Dimensionality of the input activations.
device: Device for tensors.
epsilon: Maximum perturbation norm.
attack_mask: Boolean mask indicating which positions to perturb.
proj: Whitening projection matrix (e.g., PCA basis).
inv_proj: Optional precomputed inverse of the projection matrix.
"""
super().__init__()
self.device: str = device
self.epsilon: float = epsilon
self.proj: torch.Tensor | None = proj

if inv_proj is None and proj is not None:
self.inv_proj: torch.Tensor = torch.inverse(proj)
elif inv_proj is not None:
self.inv_proj = inv_proj
else:
raise ValueError("At least one of `proj` or `inv_proj` must be provided")

# Initialize attack perturbations in whitened space
self.attack: torch.nn.Parameter = torch.nn.Parameter(
torch.randn(attack_mask.shape[0], attack_mask.shape[1], dim, device=self.device)
)
self.clip_attack()
self.attack_mask: torch.Tensor = attack_mask

@override
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply adversarial perturbation by projecting from whitened space.

The perturbation is first unprojected via `inv_proj` before being
added to the original activations.
"""
# einsum notation: n = whitened dim, d = original hidden size
unprojected_attack = torch.einsum("nd,bsn->bsd", self.inv_proj, self.attack)
x[self.attack_mask[:, : x.shape[1]]] = (x + unprojected_attack.to(x.dtype))[self.attack_mask[:, : x.shape[1]]]
return x

def clip_attack(self) -> None:
"""Project perturbations back into the epsilon-ball in whitened space."""
with torch.no_grad():
norms = torch.norm(self.attack, dim=-1, keepdim=True)
scale = torch.clamp(norms / self.epsilon, min=1)
self.attack.div_(scale)
Loading