-
Notifications
You must be signed in to change notification settings - Fork 1
infra: changes to template #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,28 @@ | ||
| """Tampering whitebox attacks supported by safetunebed.""" | ||
| """Tampering whitebox attacks supported by safetunebed. | ||
|
|
||
| Import modules for side effects so they register via the attacks registry. | ||
| """ | ||
|
|
||
| # Side-effect imports to populate ATTACKS_REGISTRY | ||
| from safetunebed.whitebox.attacks.embedding_attack import ( | ||
| embedding_attack as _, | ||
| ) | ||
|
|
||
| # Re-export embedding attack classes for convenience | ||
| from safetunebed.whitebox.attacks.embedding_attack.embedding_attack import ( | ||
| EmbeddingAttack, | ||
| EmbeddingAttackConfig, | ||
| ) | ||
| from safetunebed.whitebox.attacks.full_parameter_finetune import ( | ||
| full_parameter_finetune as _, | ||
| ) | ||
| from safetunebed.whitebox.attacks.jailbreak_finetune import ( | ||
| jailbreak_finetune as _, | ||
| ) | ||
| from safetunebed.whitebox.attacks.lora_finetune import lora_finetune as _ | ||
| from safetunebed.whitebox.attacks.multilingual_finetune import ( | ||
| multilingual_finetune as _, # noqa: F401 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: why does only this import have the `# noqa: F401` comment? |
||
| ) | ||
|
|
||
| __all__ = [ | ||
| "EmbeddingAttack", | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,35 +3,44 @@ | |
| from abc import ABC, abstractmethod | ||
| from dataclasses import dataclass, fields | ||
| from pathlib import Path | ||
| from typing import Any, Generic, TypeVar | ||
| from typing import Any, Generic, TypeVar, final | ||
|
|
||
| import polars as pl | ||
| from pandera.typing.polars import DataFrame | ||
| from typing_extensions import Self | ||
|
|
||
| from safetunebed.whitebox.evals import ( | ||
| MMLUProTestEvaluation, | ||
| MMLUProValEvaluation, | ||
| StrongRejectEvaluation, | ||
| StrongRejectEvaluationConfig, | ||
| StrongRejectSmallEvaluation, | ||
| ) | ||
| from safetunebed.whitebox.evals.base import WhiteBoxEvaluationConfig | ||
| from safetunebed.whitebox.evals.output_schema import EvaluationSchema | ||
| from safetunebed.whitebox.utils.names import EvalName, EvalPath | ||
| from safetunebed.whitebox.utils.models.config import ModelConfig | ||
| from safetunebed.whitebox.utils.names import AttackName, EvalName, EvalPath | ||
|
|
||
| H = TypeVar(name="H", bound="TamperAttackConfig") | ||
|
|
||
|
|
||
| @dataclass | ||
| class TamperAttackConfig(ABC): | ||
| class TamperAttackConfig: | ||
| """Configuration for attacks including hyperparameters. | ||
|
|
||
| Attributes: | ||
| input_checkpoint_path: The input checkpoint path of aligned LLM. | ||
| out_dir: The path to directory where checkpoint path of attacked LLM will be saved and other artifacts. | ||
| max_generation_length: Maximum number of generated (output) tokens allowed for an LLM inference. | ||
| evals: The evaluators that will evaluate the success of the attack. | ||
| model_config: The model-specific configuration parameters such as prompt templates, generation length, etc. | ||
| evals: The names of the evaluators that will evaluate the success of the attack. | ||
| """ | ||
|
|
||
| input_checkpoint_path: str | ||
| out_dir: str | ||
| max_generation_length: int | ||
| evals: list[EvalName] | ||
| model_config: ModelConfig | ||
|
|
||
| @classmethod | ||
| @abstractmethod | ||
| def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExplicitAny] | ||
| """All subclasses must implement a `from_dict` constructor. | ||
|
|
||
|
|
@@ -41,6 +50,37 @@ def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExpli | |
| Returns: | ||
| TamperAttackConfig: Config object. | ||
| """ | ||
| cls._validate_dict_keys(data) | ||
| data = cls.serialize_data(data) | ||
|
|
||
| return cls(**data) | ||
|
|
||
| @classmethod | ||
| def serialize_data(cls, data: dict[str, Any]) -> dict[str, Any]: # pyright: ignore[reportExplicitAny] | ||
| """Serialize data from a dictionary such that it can be used to construct nested objects. | ||
|
|
||
| Args: | ||
| data: Dictionary containing the configuration fields. | ||
|
|
||
| Returns: | ||
| dict[str, Any]: Dictionary containing the configuration fields. | ||
| """ | ||
| model_config_dict = data.pop("model_config") # pyright: ignore[reportAny] | ||
| model_config = ModelConfig.from_dict(model_config_dict) # pyright: ignore[reportAny] | ||
|
|
||
| data.update({"model_config": model_config}) | ||
|
Comment on lines
+68
to
+71
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. More simply I guess this could be |
||
| return data | ||
|
|
||
| @classmethod | ||
| def _validate_dict_keys(cls, data: dict[str, Any]) -> None: # pyright: ignore[reportExplicitAny] | ||
| """Check that the dictionary to be used to construct dataclass has correct fields. | ||
|
|
||
| Args: | ||
| data: Dictionary containing the configuration fields. | ||
|
|
||
| Raises: | ||
| ValueError: If the dictionary has insufficient keys or redundant ones. | ||
| """ | ||
| config_keys: set[str] = {_field.name for _field in fields(cls) if _field.init} | ||
|
|
||
| for config_key in config_keys: | ||
|
|
@@ -53,12 +93,12 @@ def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExpli | |
| if dict_key not in config_keys: | ||
| raise ValueError(f"`{dict_key}` is not a field of {cls.__name__}") | ||
|
|
||
| return cls(**data) # pyright: ignore[reportAny] | ||
|
|
||
|
|
||
| class TamperAttack(ABC, Generic[H]): | ||
| """Base class for a tampering attack.""" | ||
|
|
||
| name: AttackName | ||
|
|
||
| def __init__(self, attack_config: H) -> None: | ||
| """Construct base TamperAttack. | ||
|
|
||
|
|
@@ -77,6 +117,7 @@ def __init_subclass__(cls) -> None: | |
| if not hasattr(cls, "name"): | ||
| raise TypeError(f"{cls.__name__} must define a class attribute 'name'") | ||
|
|
||
| @final | ||
| def benchmark(self) -> DataFrame[EvaluationSchema]: | ||
| """Execute attack and evaluate the attacked model. | ||
|
|
||
|
|
@@ -95,7 +136,15 @@ def benchmark(self) -> DataFrame[EvaluationSchema]: | |
| def run_attack(self) -> None: | ||
| """Execute attack and ensure attacked LLM exists at specified paths outlined in config.""" | ||
|
|
||
| @abstractmethod | ||
| @final | ||
| def delete_output_checkpoint(self) -> None: | ||
| """Delete the tampered model checkpoint if it exists.""" | ||
| if Path(self.output_checkpoint_path).exists(): | ||
| import shutil | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. import shutil at top of file? it shouldn't be an expensive import |
||
|
|
||
| shutil.rmtree(self.output_checkpoint_path) | ||
| Path(self.output_checkpoint_path).mkdir(parents=True, exist_ok=False) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. why re-create the directory after deleting? |
||
|
|
||
| def evaluate(self) -> DataFrame[EvaluationSchema]: | ||
| """Evaluate attack on relevant evaluations. | ||
|
|
||
|
|
@@ -105,3 +154,66 @@ def evaluate(self) -> DataFrame[EvaluationSchema]: | |
| Returns: | ||
| DataFrame[EvaluationSchema]: A polars DataFrame with the aggregated evaluation metrics. | ||
| """ | ||
| results = EvaluationSchema.empty() | ||
|
|
||
| if EvalName.STRONG_REJECT in self.attack_config.evals: | ||
| results = pl.concat([results, self.evaluate_strong_reject()]) | ||
|
|
||
| if EvalName.STRONG_REJECT_SMALL in self.attack_config.evals: | ||
| results = pl.concat([results, self.evaluate_strong_reject_small()]) | ||
|
|
||
| if EvalName.MMLU_PRO_VAL in self.attack_config.evals: | ||
| results = pl.concat([results, self.evaluate_mmlu_pro_val()]) | ||
|
|
||
| if EvalName.MMLU_PRO_TEST in self.attack_config.evals: | ||
| results = pl.concat([results, self.evaluate_mmlu_pro_test()]) | ||
|
Comment on lines
+159
to
+169
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if we could make use of the registry somehow instead of having to manually update stuff over here as well whenever we add a new eval. So like, instead of having 4 separate `if` checks: for eval_name, run_eval in EVALUATION_REGISTRY.items():
if eval_name in self.attack_config.evals:
results = pl.concat([results, run_eval(
model_checkpoint=self.output_checkpoint_path,
...
)
]) |
||
|
|
||
| return EvaluationSchema.validate(results) | ||
|
|
||
| def evaluate_strong_reject(self) -> DataFrame[EvaluationSchema]: | ||
| """Evaluate attack on the `StrongRejectEvaluation` evaluator.""" | ||
| eval_config: StrongRejectEvaluationConfig = StrongRejectEvaluationConfig( | ||
| model_checkpoint=self.output_checkpoint_path, | ||
| out_dir=self.attack_config.out_dir, | ||
| model_config=self.attack_config.model_config, | ||
| ) | ||
| evaluator: StrongRejectEvaluation[StrongRejectEvaluationConfig] = ( | ||
| StrongRejectEvaluation(eval_config) | ||
| ) | ||
|
|
||
| return evaluator.run_evaluation() | ||
|
|
||
| def evaluate_strong_reject_small(self) -> DataFrame[EvaluationSchema]: | ||
| """Evaluate attack on the smaller `StrongRejectEvaluation` evaluator.""" | ||
| eval_config: StrongRejectEvaluationConfig = StrongRejectEvaluationConfig( | ||
| model_checkpoint=self.output_checkpoint_path, | ||
| out_dir=self.attack_config.out_dir, | ||
| model_config=self.attack_config.model_config, | ||
| ) | ||
| evaluator: StrongRejectSmallEvaluation[StrongRejectEvaluationConfig] = ( | ||
| StrongRejectSmallEvaluation(eval_config) | ||
| ) | ||
|
|
||
| return evaluator.run_evaluation() | ||
|
|
||
| def evaluate_mmlu_pro_val(self) -> DataFrame[EvaluationSchema]: | ||
| """Evaluate attack on the `MMLUPro` evaluator validation set.""" | ||
| eval_config: WhiteBoxEvaluationConfig = WhiteBoxEvaluationConfig( | ||
| model_checkpoint=self.output_checkpoint_path, | ||
| out_dir=self.attack_config.out_dir, | ||
| model_config=self.attack_config.model_config, | ||
| ) | ||
| evaluator: MMLUProValEvaluation = MMLUProValEvaluation(eval_config) | ||
|
|
||
| return evaluator.run_evaluation() | ||
|
|
||
| def evaluate_mmlu_pro_test(self) -> DataFrame[EvaluationSchema]: | ||
| """Evaluate attack on the `MMLUPro` evaluator full test set.""" | ||
| eval_config: WhiteBoxEvaluationConfig = WhiteBoxEvaluationConfig( | ||
| model_checkpoint=self.output_checkpoint_path, | ||
| out_dir=self.attack_config.out_dir, | ||
| model_config=self.attack_config.model_config, | ||
| ) | ||
| evaluator: MMLUProTestEvaluation = MMLUProTestEvaluation(eval_config) | ||
|
|
||
| return evaluator.run_evaluation() | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,34 @@ | ||||||
| """Attack registry with decorator-based registration. | ||||||
|
|
||||||
| Allows individual attack modules to self-register their mapping from | ||||||
| `AttackName -> (ConfigDataclass, AttackClass)` at import time, avoiding a | ||||||
| central hard-coded map. | ||||||
| """ | ||||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| from collections.abc import Callable | ||||||
| from typing import TypeVar | ||||||
|
|
||||||
| from safetunebed.whitebox.attacks.base import TamperAttack, TamperAttackConfig | ||||||
| from safetunebed.whitebox.utils.names import AttackName | ||||||
|
|
||||||
| H = TypeVar("H", bound=TamperAttackConfig) | ||||||
| T = TypeVar("T", bound=TamperAttack) # pyright: ignore[reportMissingTypeArgument] | ||||||
|
|
||||||
|
|
||||||
| ATTACKS_REGISTRY: dict[ # pyright: ignore[reportUnknownVariableType] | ||||||
| AttackName, tuple[type[TamperAttackConfig], type[TamperAttack]] # pyright: ignore[reportMissingTypeArgument] | ||||||
| ] = {} | ||||||
|
|
||||||
|
|
||||||
| def register_attack( | ||||||
| name: AttackName, config_cls: type[H] | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure
Suggested change
|
||||||
| ) -> Callable[[type[T]], type[T]]: | ||||||
| """Decorator to register an attack class and its config class under a name.""" # noqa: D401 | ||||||
|
|
||||||
| def _decorator(attack_cls: type[T]) -> type[T]: | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: I wonder if we can also make the decorator set the `name` attribute (doesn't have to be this PR, since I know making changes with stacked PRs can end up being confusing) |
||||||
| ATTACKS_REGISTRY[name] = (config_cls, attack_cls) | ||||||
| return attack_cls | ||||||
|
|
||||||
| return _decorator | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. doesn't look like ATTACKS_REGISTRY or register_attack is actually used, is that forthcoming in a subsequent PR?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah it'd be used in forthcoming PRs where we do sweeps. |
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| """Evaluation registry with decorator-based registration. | ||
|
|
||
| Allows individual evaluation modules to self-register their mapping from | ||
| `EvalName -> EvaluationClass` at import time, avoiding a central hard-coded map. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
| from typing import TypeVar | ||
|
|
||
| from safetunebed.whitebox.evals.base import WhiteBoxEvaluation | ||
| from safetunebed.whitebox.utils.names import EvalName | ||
|
|
||
| E = TypeVar("E", bound=WhiteBoxEvaluation) # pyright: ignore[reportMissingTypeArgument] | ||
|
|
||
|
|
||
| EVALS_REGISTRY: dict[EvalName, type[WhiteBoxEvaluation]] = {} # pyright: ignore[reportMissingTypeArgument, reportUnknownVariableType] | ||
|
|
||
|
|
||
| def register_evaluation(name: EvalName) -> Callable[[type[E]], type[E]]: | ||
| """Decorator to register an evaluation class under a name.""" # noqa: D401 | ||
|
|
||
| def _decorator(eval_cls: type[E]) -> type[E]: | ||
| EVALS_REGISTRY[name] = eval_cls | ||
| return eval_cls | ||
|
|
||
| return _decorator | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. similarly to attacks, I don't see this registry or `register_evaluation` used anywhere
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. and then we can specify the |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: some of these imports don't exist yet because I've retroactively split it up.