attack: added latent perturbation attack #23
Open: psyonp wants to merge 18 commits into main from psyonp/latent_attack.

Commits (18):
- 77b1d4d Migrate hooks.py (psyonp)
- 16d8683 Change setup (psyonp)
- 3846231 Refactor latent_attack (psyonp)
- 7326717 Refactor as requested with eval (psyonp)
- dc83e0d Add tests (psyonp)
- 30e87da Made requested changes for nits and docstrings (psyonp)
- 02ef53b Update tests/evals/test_latent_attack_eval.py (psyonp)
- 97f00b9 Update src/safetunebed/whitebox/evals/latent_attack/latent_attack.py (psyonp)
- 6f0774d Update src/safetunebed/whitebox/attacks/latent_attack/latent_attack.py (psyonp)
- 432af6c Update tests/attacks/test_latent_attack.py (psyonp)
- 2b9dabf Add link for ported code (psyonp)
- c4019d8 Add link to hooks.py (psyonp)
- 408b9bd Name change of parent_path to parent_layer (psyonp)
- d99cc89 Merge branch 'main' into psyonp/latent_attack (tomtseng)
- 31bf8a9 evals/latent_attack: Lint (tomtseng)
- eda2fc1 latent_attack: Fix type errors (tomtseng)
- 5b9c5e4 latent_attack: Fix merge errors (tomtseng)
- cc516cf latent_attack: Remove redundant test (tomtseng)
src/tamperbench/whitebox/attacks/latent_attack/__init__.py (14 additions, 0 deletions)
````python
"""Sheshadri et al.'s LLM Latent Attack.

The code is primarily sourced from https://github.com/aengusl/latent-adversarial-training/
but may contain modifications for enhanced readability and integrability with tamperbench.
The work can be found at https://arxiv.org/pdf/2407.15549 and can be cited as follows:

```
@article{sheshadri2024targeted,
    title={Targeted Latent Adversarial Training Improves Robustness to Persistent Harmful Behaviors in LLMs},
    author={Sheshadri, Abhay and Ewart, Aidan and Guo, Phillip and Lynch, Aengus and Wu, Cindy and Hebbar, Vivek and Sleight, Henry and Stickland, Asa Cooper and Perez, Ethan and Hadfield-Menell, Dylan and Casper, Stephen},
    journal={arXiv preprint arXiv:2407.15549},
    year={2024}
}
```
"""
````
src/tamperbench/whitebox/attacks/latent_attack/latent_attack.py (62 additions, 0 deletions)
```python
"""Latent attack wrapper."""

from dataclasses import dataclass

import polars as pl
from pandera.typing.polars import DataFrame
from typing_extensions import override

from tamperbench.whitebox.attacks.base import TamperAttack, TamperAttackConfig
from tamperbench.whitebox.attacks.registry import register_attack
from tamperbench.whitebox.evals.latent_attack import (
    LatentAttackEvaluation,
    LatentAttackEvaluationConfig,
)
from tamperbench.whitebox.evals.output_schema import EvaluationSchema
from tamperbench.whitebox.utils.names import AttackName, EvalName


@dataclass
class LatentAttackConfig(TamperAttackConfig):
    """Config for the latent attack wrapper.

    Attributes:
        epsilon: L-infinity norm bound for the latent perturbation.
        attack_layer: Dot-path to the module to perturb (e.g., "transformer.h.0.mlp").
        small: Whether to use the small dataset for testing.
    """

    epsilon: float
    attack_layer: str
    small: bool


@register_attack(AttackName.LATENT_ATTACK, LatentAttackConfig)
class LatentAttack(TamperAttack[LatentAttackConfig]):
    """Latent attack class."""

    name: AttackName = AttackName.LATENT_ATTACK

    @override
    def run_attack(self) -> None:
        """Latent attack does not modify weights directly."""

    @override
    def evaluate(self) -> DataFrame[EvaluationSchema]:
        """Evaluate the latent attack by running the evaluator."""
        results = super().evaluate()
        if EvalName.LATENT_ATTACK in self.attack_config.evals:
            results = pl.concat([results, self._evaluate_latent_attack()])
        return EvaluationSchema.validate(results)

    def _evaluate_latent_attack(self) -> DataFrame[EvaluationSchema]:
        """Instantiate LatentAttackEvaluation and run the evaluation."""
        eval_cfg = LatentAttackEvaluationConfig(
            model_checkpoint=self.attack_config.input_checkpoint_path,
            out_dir=self.attack_config.out_dir,
            model_config=self.attack_config.model_config,
            epsilon=self.attack_config.epsilon,
            attack_layer=self.attack_config.attack_layer,
        )
        evaluator = LatentAttackEvaluation(eval_cfg)
        return evaluator.run_evaluation()
```
src/tamperbench/whitebox/evals/latent_attack/__init__.py (28 additions, 0 deletions)

````python
"""Sheshadri et al.'s LLM Latent Attack.

The code is primarily sourced from https://github.com/aengusl/latent-adversarial-training/
but may contain modifications for enhanced readability and integrability with tamperbench.
The work can be found at https://arxiv.org/pdf/2407.15549 and can be cited as follows:

```
@article{sheshadri2024targeted,
    title={Targeted Latent Adversarial Training Improves Robustness to Persistent Harmful Behaviors in LLMs},
    author={Sheshadri, Abhay and Ewart, Aidan and Guo, Phillip and Lynch, Aengus and Wu, Cindy and Hebbar, Vivek and Sleight, Henry and Stickland, Asa Cooper and Perez, Ethan and Hadfield-Menell, Dylan and Casper, Stephen},
    journal={arXiv preprint arXiv:2407.15549},
    year={2024}
}
```
"""

from tamperbench.whitebox.evals.latent_attack.attack_class import GDAdversary
from tamperbench.whitebox.evals.latent_attack.hooks import CustomHook
from tamperbench.whitebox.evals.latent_attack.latent_attack import (
    LatentAttackEvaluation,
    LatentAttackEvaluationConfig,
)

__all__ = [
    "CustomHook",
    "GDAdversary",
    "LatentAttackEvaluation",
    "LatentAttackEvaluationConfig",
]
````
src/tamperbench/whitebox/evals/latent_attack/attack_class.py (221 additions, 0 deletions)
```python
"""Ported from https://github.com/aengusl/latent-adversarial-training/."""

import torch
from typing_extensions import override


class LowRankAdversary(torch.nn.Module):
    """Adversary parameterized by a low-rank MLP (LoRA-style).

    This adversary applies a low-rank perturbation to the input activations
    by projecting down to a low-dimensional space (via `lora_a`) and then
    projecting back up (via `lora_b`), before adding the perturbation to
    the original input.
    """

    def __init__(
        self,
        dim: int,
        rank: int,
        device: str,
        bias: bool = False,
        zero_init: bool = True,
    ):
        """Initialize LowRankAdversary.

        Args:
            dim: Dimensionality of the input activations.
            rank: Low-rank dimension for the intermediate projection.
            device: Device on which to place the model.
            bias: Whether to use a bias term in `lora_b`.
            zero_init: Whether to initialize `lora_b` weights to zero.
        """
        super().__init__()
        self.dim: int = dim
        self.rank: int = rank
        self.device: str = device
        self.lora_a: torch.nn.Linear = torch.nn.Linear(dim, rank, bias=False).to(device)
        self.lora_b: torch.nn.Linear = torch.nn.Linear(rank, dim, bias=bias).to(device)
        if zero_init:
            self.lora_b.weight.data.zero_()

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the low-rank adversarial perturbation to the input."""
        return self.lora_b(self.lora_a(x)) + x
```
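A quick self-contained check of the zero-init behavior (plain torch, no tamperbench imports): with `lora_b` zeroed, the low-rank adversary starts as the identity map, so optimization begins from an unperturbed model.

```python
import torch

dim, rank = 8, 2
lora_a = torch.nn.Linear(dim, rank, bias=False)
lora_b = torch.nn.Linear(rank, dim, bias=False)
lora_b.weight.data.zero_()  # zero_init=True: the perturbation branch starts at zero

x = torch.randn(3, dim)
out = lora_b(lora_a(x)) + x  # same computation as LowRankAdversary.forward

# With lora_b zeroed, the residual branch contributes nothing.
print(torch.allclose(out, x))  # True
```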
```python
class FullRankAdversary(torch.nn.Module):
    """Adversary parameterized by a full-rank MLP.

    This adversary learns a full linear transformation `m` that generates
    perturbations to add to the input activations.
    """

    def __init__(self, dim: int, device: str, bias: bool = False):
        """Initialize FullRankAdversary.

        Args:
            dim: Dimensionality of the input activations.
            device: Device on which to place the model.
            bias: Whether to use a bias term in the linear layer.
        """
        super().__init__()
        self.dim: int = dim
        self.device: str = device
        self.m: torch.nn.Linear = torch.nn.Linear(dim, dim, bias=bias).to(device)

        # Start with no perturbation (weights initialized to zero).
        self.m.weight.data.zero_()

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the full-rank adversarial perturbation to the input."""
        return self.m(x) + x
```
```python
class GDAdversary(torch.nn.Module):
    """Standard projected gradient descent (PGD)-style adversary.

    Perturbations are stored as trainable parameters and projected back
    into the epsilon-ball to constrain their norm.
    """

    def __init__(
        self,
        dim: int,
        epsilon: float,
        attack_mask: torch.Tensor,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize GDAdversary.

        Args:
            dim: Dimensionality of the input activations.
            epsilon: Maximum perturbation norm.
            attack_mask: Boolean mask indicating which positions to perturb.
            device: Device for tensors.
            dtype: Optional data type for perturbations.
        """
        super().__init__()
        self.device: str | torch.device | None = device
        self.epsilon: float = epsilon

        if dtype:
            self.attack: torch.nn.Parameter = torch.nn.Parameter(
                torch.zeros(
                    attack_mask.shape[0],
                    attack_mask.shape[1],
                    dim,
                    device=self.device,
                    dtype=dtype,
                )
            )
        else:
            self.attack = torch.nn.Parameter(
                torch.zeros(attack_mask.shape[0], attack_mask.shape[1], dim, device=self.device)
            )
        self.clip_attack()
        self.attack_mask: torch.Tensor = attack_mask

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add adversarial perturbations to the input tensor.

        Handles both generation mode (no perturbation applied if sequence
        length == 1) and training mode (perturb according to the mask).
        """
        if x.shape[1] == 1 and self.attack.shape[1] != 1:
            # Generation mode: the perturbation was already applied.
            return x
        else:
            # Ensure the perturbation tensors are on the same device as the input.
            if self.device is None or self.device != x.device:
                with torch.no_grad():
                    self.device = x.device
                    self.attack.data = self.attack.data.to(self.device)
                    self.attack_mask = self.attack_mask.to(self.device)

            # Apply perturbations at the masked positions.
            perturbed_acts = x[self.attack_mask[:, : x.shape[1]]] + self.attack[
                self.attack_mask[:, : x.shape[1]]
            ].to(x.dtype)
            x[self.attack_mask[:, : x.shape[1]]] = perturbed_acts

            return x

    def clip_attack(self) -> None:
        """Project the adversarial perturbations back into the epsilon-ball.

        Ensures that perturbation norms do not exceed `epsilon`.
        """
        with torch.no_grad():
            norms = torch.norm(self.attack, dim=-1, keepdim=True)
            scale = torch.clamp(norms / self.epsilon, min=1)
            self.attack.div_(scale)
```
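The `clip_attack` projection is a per-position rescale: positions whose perturbation norm exceeds `epsilon` are scaled down onto the epsilon-sphere, while positions already inside the ball are left untouched (the `min=1` clamp prevents scaling up). A standalone numeric sketch of the same computation:

```python
import torch

epsilon = 1.0
# Two perturbation vectors: L2 norms 5.0 (outside the ball) and 0.5 (inside).
attack = torch.tensor([[3.0, 4.0], [0.3, 0.4]])

norms = torch.norm(attack, dim=-1, keepdim=True)  # [[5.0], [0.5]]
scale = torch.clamp(norms / epsilon, min=1)       # [[5.0], [1.0]]: never scale up
attack = attack / scale

# The first vector is rescaled to norm epsilon; the second is unchanged.
print(torch.norm(attack, dim=-1))  # tensor([1.0000, 0.5000])
```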
```python
class WhitenedGDAdversary(torch.nn.Module):
    """Whitened variant of the GD adversary.

    Perturbations are trained in a whitened space (e.g., via PCA).
    A projection matrix is used to map between the whitened and original spaces.
    """

    def __init__(
        self,
        dim: int,
        device: str,
        epsilon: float,
        attack_mask: torch.Tensor,
        proj: torch.Tensor | None = None,
        inv_proj: torch.Tensor | None = None,
    ):
        """Initialize WhitenedGDAdversary.

        Args:
            dim: Dimensionality of the input activations.
            device: Device for tensors.
            epsilon: Maximum perturbation norm.
            attack_mask: Boolean mask indicating which positions to perturb.
            proj: Whitening projection matrix (e.g., PCA basis).
            inv_proj: Optional precomputed inverse of the projection matrix.
        """
        super().__init__()
        self.device: str = device
        self.epsilon: float = epsilon
        self.proj: torch.Tensor | None = proj

        if inv_proj is None and proj is not None:
            self.inv_proj: torch.Tensor = torch.inverse(proj)
        elif inv_proj is not None:
            self.inv_proj = inv_proj
        else:
            raise ValueError("At least one of `proj` or `inv_proj` must be provided")

        # Initialize the attack perturbations in the whitened space.
        self.attack: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(attack_mask.shape[0], attack_mask.shape[1], dim, device=self.device)
        )
        self.clip_attack()
        self.attack_mask: torch.Tensor = attack_mask

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the adversarial perturbation by projecting from the whitened space.

        The perturbation is first unprojected via `inv_proj` before being
        added to the original activations.
        """
        # einsum notation: n = whitened dim, d = original hidden size.
        unprojected_attack = torch.einsum("nd,bsn->bsd", self.inv_proj, self.attack)
        x[self.attack_mask[:, : x.shape[1]]] = (x + unprojected_attack.to(x.dtype))[
            self.attack_mask[:, : x.shape[1]]
        ]
        return x

    def clip_attack(self) -> None:
        """Project the perturbations back into the epsilon-ball in the whitened space."""
        with torch.no_grad():
            norms = torch.norm(self.attack, dim=-1, keepdim=True)
            scale = torch.clamp(norms / self.epsilon, min=1)
            self.attack.div_(scale)
```
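The einsum `"nd,bsn->bsd"` multiplies each whitened perturbation vector (length `n`, per batch and sequence position) by `inv_proj` to recover an original-space vector (length `d`). A small check of that contraction, using an identity projection so the round trip is exact:

```python
import torch

batch, seq, dim = 2, 3, 4
attack = torch.randn(batch, seq, dim)  # perturbations living in the whitened space
inv_proj = torch.eye(dim)              # identity: whitened space equals original space

# "nd,bsn->bsd": for every (batch, seq) position, contract the whitened vector
# over n with inv_proj to produce an original-space vector over d.
unprojected = torch.einsum("nd,bsn->bsd", inv_proj, attack)

print(torch.allclose(unprojected, attack))  # True with an identity projection
```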
Review comments:

Reviewer: I see we are only using GDAdversary in the attack. Do we want it to be configurable to use these?

Author: Tried to keep as many helpers as needed. Can remove as needed.