Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/safetunebed/whitebox/attacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
"""Tampering whitebox attacks supported by safetunebed."""
"""Tampering whitebox attacks supported by safetunebed.

Import modules for side effects so they register via the attacks registry.
"""

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: some of these imports don't exist yet because I've retroactively split it up.

# Side-effect imports to populate ATTACKS_REGISTRY
from safetunebed.whitebox.attacks.embedding_attack import (
embedding_attack as _,
)

# Re-export embedding attack classes for convenience
from safetunebed.whitebox.attacks.embedding_attack.embedding_attack import (
EmbeddingAttack,
EmbeddingAttackConfig,
)
from safetunebed.whitebox.attacks.full_parameter_finetune import (
full_parameter_finetune as _,
)
from safetunebed.whitebox.attacks.jailbreak_finetune import (
jailbreak_finetune as _,
)
from safetunebed.whitebox.attacks.lora_finetune import lora_finetune as _
from safetunebed.whitebox.attacks.multilingual_finetune import (
multilingual_finetune as _, # noqa: F401
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: why does only this import have noqa: F401?

)

__all__ = [
"EmbeddingAttack",
Expand Down
132 changes: 122 additions & 10 deletions src/safetunebed/whitebox/attacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,44 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from pathlib import Path
from typing import Any, Generic, TypeVar
from typing import Any, Generic, TypeVar, final

import polars as pl
from pandera.typing.polars import DataFrame
from typing_extensions import Self

from safetunebed.whitebox.evals import (
MMLUProTestEvaluation,
MMLUProValEvaluation,
StrongRejectEvaluation,
StrongRejectEvaluationConfig,
StrongRejectSmallEvaluation,
)
from safetunebed.whitebox.evals.base import WhiteBoxEvaluationConfig
from safetunebed.whitebox.evals.output_schema import EvaluationSchema
from safetunebed.whitebox.utils.names import EvalName, EvalPath
from safetunebed.whitebox.utils.models.config import ModelConfig
from safetunebed.whitebox.utils.names import AttackName, EvalName, EvalPath

H = TypeVar(name="H", bound="TamperAttackConfig")


@dataclass
class TamperAttackConfig(ABC):
class TamperAttackConfig:
"""Configuration for attacks including hyperparameters.

Attributes:
input_checkpoint_path: The input checkpoint path of aligned LLM.
out_dir: The path to the directory where the checkpoint of the attacked LLM and other artifacts will be saved.
max_generation_length: Maximum number of generated (output) tokens allowed for an LLM inference.
evals: The evaluators that will evaluate the success of the attack.
model_config: The model-specific configuration parameters such as prompt templates, generation length, etc.
evals: The names of the evaluators that will evaluate the success of the attack.
"""

input_checkpoint_path: str
out_dir: str
max_generation_length: int
evals: list[EvalName]
model_config: ModelConfig

@classmethod
@abstractmethod
def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExplicitAny]
"""All subclasses must implement a `from_dict` constructor.

Expand All @@ -41,6 +50,37 @@ def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExpli
Returns:
TamperAttackConfig: Config object.
"""
cls._validate_dict_keys(data)
data = cls.serialize_data(data)

return cls(**data)

@classmethod
def serialize_data(cls, data: dict[str, Any]) -> dict[str, Any]:
    """Serialize data from a dictionary such that it can be used to construct nested objects.

    The caller's dictionary is NOT mutated: a shallow copy is made before the
    raw ``model_config`` mapping is replaced by a constructed ``ModelConfig``.

    Args:
        data: Dictionary containing the configuration fields.

    Returns:
        dict[str, Any]: Dictionary containing the configuration fields, with
            ``model_config`` converted to a ``ModelConfig`` instance.

    Raises:
        KeyError: If ``model_config`` is missing from ``data``.
    """
    data = dict(data)  # shallow copy so the input dict is left untouched
    data["model_config"] = ModelConfig.from_dict(data["model_config"])  # pyright: ignore[reportAny]

    return data

@classmethod
def _validate_dict_keys(cls, data: dict[str, Any]) -> None: # pyright: ignore[reportExplicitAny]
"""Check that the dictionary to be used to construct dataclass has correct fields.

Args:
data: Dictionary containing the configuration fields.

Raises:
ValueError: If the dictionary has insufficient keys or redundant ones.
"""
config_keys: set[str] = {_field.name for _field in fields(cls) if _field.init}

for config_key in config_keys:
Expand All @@ -53,12 +93,12 @@ def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExpli
if dict_key not in config_keys:
raise ValueError(f"`{dict_key}` is not a field of {cls.__name__}")

return cls(**data) # pyright: ignore[reportAny]


class TamperAttack(ABC, Generic[H]):
"""Base class for a tampering attack."""

name: AttackName

def __init__(self, attack_config: H) -> None:
"""Construct base TamperAttack.

Expand All @@ -77,6 +117,7 @@ def __init_subclass__(cls) -> None:
if not hasattr(cls, "name"):
raise TypeError(f"{cls.__name__} must define a class attribute 'name'")

@final
def benchmark(self) -> DataFrame[EvaluationSchema]:
"""Execute attack and evaluate the attacked model.

Expand All @@ -95,7 +136,15 @@ def benchmark(self) -> DataFrame[EvaluationSchema]:
def run_attack(self) -> None:
"""Execute attack and ensure attacked LLM exists at specified paths outlined in config."""

@final
def delete_output_checkpoint(self) -> None:
    """Delete the tampered model checkpoint directory if it exists.

    The directory is re-created empty afterwards so that downstream steps
    always find an existing (fresh) path to write into.
    """
    # NOTE(review): local import kept for now; shutil is cheap and could be
    # hoisted to the module top.
    import shutil

    checkpoint_dir = Path(self.output_checkpoint_path)
    if checkpoint_dir.exists():
        shutil.rmtree(checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why re-create the directory after deleting?


def evaluate(self) -> DataFrame[EvaluationSchema]:
    """Evaluate attack on relevant evaluations.

    Runs every evaluation named in ``self.attack_config.evals`` against the
    attacked checkpoint and concatenates the per-eval results.

    Returns:
        DataFrame[EvaluationSchema]: A polars DataFrame with the aggregated evaluation metrics.
    """
    results = EvaluationSchema.empty()

    # Table-driven dispatch keeps evaluation order fixed and avoids a growing
    # if-chain; new evals are added here (or migrated to a registry later).
    dispatch = (
        (EvalName.STRONG_REJECT, self.evaluate_strong_reject),
        (EvalName.STRONG_REJECT_SMALL, self.evaluate_strong_reject_small),
        (EvalName.MMLU_PRO_VAL, self.evaluate_mmlu_pro_val),
        (EvalName.MMLU_PRO_TEST, self.evaluate_mmlu_pro_test),
    )
    for eval_name, run_eval in dispatch:
        if eval_name in self.attack_config.evals:
            results = pl.concat([results, run_eval()])

    return EvaluationSchema.validate(results)

def evaluate_strong_reject(self) -> DataFrame[EvaluationSchema]:
    """Evaluate attack on the `StrongRejectEvaluation` evaluator."""
    config = StrongRejectEvaluationConfig(
        model_checkpoint=self.output_checkpoint_path,
        out_dir=self.attack_config.out_dir,
        model_config=self.attack_config.model_config,
    )
    # Construct the evaluator and run it in one step; the result is the
    # evaluator's aggregated metrics frame.
    return StrongRejectEvaluation(config).run_evaluation()

def evaluate_strong_reject_small(self) -> DataFrame[EvaluationSchema]:
    """Evaluate attack on the smaller `StrongRejectEvaluation` evaluator."""
    # Same config shape as the full StrongReject eval; only the evaluator
    # class (and hence the dataset it loads) differs.
    config = StrongRejectEvaluationConfig(
        model_checkpoint=self.output_checkpoint_path,
        out_dir=self.attack_config.out_dir,
        model_config=self.attack_config.model_config,
    )
    return StrongRejectSmallEvaluation(config).run_evaluation()

def evaluate_mmlu_pro_val(self) -> DataFrame[EvaluationSchema]:
    """Evaluate attack on the `MMLUPro` evaluator validation set."""
    config = WhiteBoxEvaluationConfig(
        model_checkpoint=self.output_checkpoint_path,
        out_dir=self.attack_config.out_dir,
        model_config=self.attack_config.model_config,
    )
    return MMLUProValEvaluation(config).run_evaluation()

def evaluate_mmlu_pro_test(self) -> DataFrame[EvaluationSchema]:
    """Evaluate attack on the `MMLUPro` evaluator full test set."""
    config = WhiteBoxEvaluationConfig(
        model_checkpoint=self.output_checkpoint_path,
        out_dir=self.attack_config.out_dir,
        model_config=self.attack_config.model_config,
    )
    return MMLUProTestEvaluation(config).run_evaluation()
34 changes: 34 additions & 0 deletions src/safetunebed/whitebox/attacks/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Attack registry with decorator-based registration.

Allows individual attack modules to self-register their mapping from
`AttackName -> (ConfigDataclass, AttackClass)` at import time, avoiding a
central hard-coded map.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import TypeVar

from safetunebed.whitebox.attacks.base import TamperAttack, TamperAttackConfig
from safetunebed.whitebox.utils.names import AttackName

H = TypeVar("H", bound=TamperAttackConfig)
T = TypeVar("T", bound=TamperAttack) # pyright: ignore[reportMissingTypeArgument]


ATTACKS_REGISTRY: dict[  # pyright: ignore[reportUnknownVariableType]
    AttackName, tuple[type[TamperAttackConfig], type[TamperAttack]]  # pyright: ignore[reportMissingTypeArgument]
] = {}


def register_attack(
    name: AttackName, config_cls: type[H]
) -> Callable[[type[T]], type[T]]:
    """Decorator to register an attack class and its config class under a name.

    Args:
        name: Registry key under which the attack is looked up.
        config_cls: The configuration dataclass type used by the attack.

    Returns:
        A class decorator that records ``(config_cls, attack_cls)`` in
        ``ATTACKS_REGISTRY`` and returns the attack class unchanged.
    """

    def _decorator(attack_cls: type[T]) -> type[T]:
        ATTACKS_REGISTRY[name] = (config_cls, attack_cls)
        return attack_cls

    return _decorator
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't look like ATTACKS_REGISTRY or register_attack is actually used, is that forthcoming in a subsequent PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it'd be used in forthcoming PRs where we do sweeps.

8 changes: 8 additions & 0 deletions src/safetunebed/whitebox/evals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@
EmbeddingAttackEvaluation,
EmbeddingAttackEvaluationConfig,
)
from safetunebed.whitebox.evals.mmlu_pro.mmlu_pro import (
MMLUProTestEvaluation,
MMLUProValEvaluation,
)
from safetunebed.whitebox.evals.strong_reject.strong_reject import (
StrongRejectEvaluation,
StrongRejectEvaluationConfig,
StrongRejectSmallEvaluation,
)

__all__ = [
"EmbeddingAttackEvaluation",
"EmbeddingAttackEvaluationConfig",
"MMLUProTestEvaluation",
"MMLUProValEvaluation",
"StrongRejectEvaluation",
"StrongRejectEvaluationConfig",
"StrongRejectSmallEvaluation",
]
56 changes: 29 additions & 27 deletions src/safetunebed/whitebox/evals/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
MetricName,
OptimizationDirection,
)
from safetunebed.whitebox.utils.models.config import ModelConfig

C = TypeVar(name="C", bound="WhiteBoxEvaluationConfig")

@dataclass
class WhiteBoxEvaluationConfig:
    """Configuration for all white-box evaluations.

    Attributes:
        model_checkpoint: Path to the checkpoint of the pretrained HuggingFace model to evaluate.
        out_dir: The path to directory where evaluation results and artifacts will be stored (may also have model checkpoints).
        model_config: The model-specific configuration parameters such as prompt templates, generation length, etc.
    """

    model_checkpoint: str
    out_dir: str
    model_config: ModelConfig


class WhiteBoxEvaluation(ABC, Generic[C]):
Expand Down Expand Up @@ -209,28 +208,31 @@ def load_result_objective(cls, results: DataFrame[EvaluationSchema]) -> float:
)
return result_metrics[cls.objective][0]

def load_model_and_tokenizer(eval_config: C) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load model and tokenizer from path in HuggingFace.

    Args:
        eval_config: Evaluation config providing ``model_checkpoint``.

    Returns:
        tuple[PreTrainedModel, PreTrainedTokenizer]:
            - A pretrained model loaded in from HF transformers `AutoModelForCausalLM.from_pretrained`
            - The associated tokenizer loaded in from HF transformers `AutoTokenizer.from_pretrained`
    """
    # NOTE(review): trust_remote_code=True executes code shipped inside the
    # checkpoint repo — only acceptable for trusted model sources; confirm
    # this is intended before loading third-party checkpoints.
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=eval_config.model_checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    ).eval()

    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=eval_config.model_checkpoint,
        padding_side="left",
        use_fast=False,
        trust_remote_code=True,
    )

    # Some tokenizers ship without a pad token; fall back to EOS, which is
    # sometimes mandatory for left-padded batched generation.
    tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token

    return model, tokenizer
28 changes: 28 additions & 0 deletions src/safetunebed/whitebox/evals/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Evaluation registry with decorator-based registration.

Allows individual evaluation modules to self-register their mapping from
`EvalName -> EvaluationClass` at import time, avoiding a central hard-coded map.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import TypeVar

from safetunebed.whitebox.evals.base import WhiteBoxEvaluation
from safetunebed.whitebox.utils.names import EvalName

E = TypeVar("E", bound=WhiteBoxEvaluation) # pyright: ignore[reportMissingTypeArgument]


EVALS_REGISTRY: dict[EvalName, type[WhiteBoxEvaluation]] = {}  # pyright: ignore[reportMissingTypeArgument, reportUnknownVariableType]


def register_evaluation(name: EvalName) -> Callable[[type[E]], type[E]]:
    """Return a class decorator that records an evaluation class in ``EVALS_REGISTRY``."""  # noqa: D401

    def _record(eval_cls: type[E]) -> type[E]:
        # Map the name to the class so it can be looked up at runtime,
        # then hand the class back unmodified.
        EVALS_REGISTRY[name] = eval_cls
        return eval_cls

    return _record
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly to attacks, I don't see this registry or @register_evaluation used anywhere, is it going to be a future PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I believe @register_evaluation will be added over in #35 and more specificially in: here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@register_evaluation(EvalName.STRONG_REJECT_SMALL)
class StrongRejectSmallEvaluation(StrongRejectEvaluation[S]):
    """StrongREJECT Evaluation class using a small version of the StrongREJECT dataset."""

    name: EvalName = EvalName.STRONG_REJECT_SMALL

    @override
    def load_strong_reject_prompts(self) -> list[str]:
        """Load the small version of the StrongReject dataset into an Arrow Dataset, and then return prompts.

        Returns:
            list[str]: A list of prompts from the StrongReject dataset to input to the model to obtain inferences.
        """
        strong_reject_dataset: ArrowDataset = (
            load_strong_reject_datasets.load_strongreject_small()
        )
        ...

and then we can specify the "strongreject_small" in a .yaml file to refer to the usage of this.

Loading