diff --git a/src/safetunebed/whitebox/attacks/__init__.py b/src/safetunebed/whitebox/attacks/__init__.py index 850503b2..dab2f574 100644 --- a/src/safetunebed/whitebox/attacks/__init__.py +++ b/src/safetunebed/whitebox/attacks/__init__.py @@ -1,9 +1,28 @@ -"""Tampering whitebox attacks supported by safetunebed.""" +"""Tampering whitebox attacks supported by safetunebed. +Import modules for side effects so they register via the attacks registry. +""" + +# Side-effect imports to populate ATTACKS_REGISTRY +from safetunebed.whitebox.attacks.embedding_attack import ( + embedding_attack as _, +) + +# Re-export embedding attack classes for convenience from safetunebed.whitebox.attacks.embedding_attack.embedding_attack import ( EmbeddingAttack, EmbeddingAttackConfig, ) +from safetunebed.whitebox.attacks.full_parameter_finetune import ( + full_parameter_finetune as _, +) +from safetunebed.whitebox.attacks.jailbreak_finetune import ( + jailbreak_finetune as _, +) +from safetunebed.whitebox.attacks.lora_finetune import lora_finetune as _ +from safetunebed.whitebox.attacks.multilingual_finetune import ( + multilingual_finetune as _, # noqa: F401 +) __all__ = [ "EmbeddingAttack", diff --git a/src/safetunebed/whitebox/attacks/base.py b/src/safetunebed/whitebox/attacks/base.py index 30b41538..4b3f9e42 100644 --- a/src/safetunebed/whitebox/attacks/base.py +++ b/src/safetunebed/whitebox/attacks/base.py @@ -3,35 +3,44 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields from pathlib import Path -from typing import Any, Generic, TypeVar +from typing import Any, Generic, TypeVar, final +import polars as pl from pandera.typing.polars import DataFrame from typing_extensions import Self +from safetunebed.whitebox.evals import ( + MMLUProTestEvaluation, + MMLUProValEvaluation, + StrongRejectEvaluation, + StrongRejectEvaluationConfig, + StrongRejectSmallEvaluation, +) +from safetunebed.whitebox.evals.base import WhiteBoxEvaluationConfig from safetunebed.whitebox.evals.output_schema import EvaluationSchema -from safetunebed.whitebox.utils.names import EvalName, EvalPath +from safetunebed.whitebox.utils.models.config import ModelConfig +from safetunebed.whitebox.utils.names import AttackName, EvalName, EvalPath H = TypeVar(name="H", bound="TamperAttackConfig") @dataclass -class TamperAttackConfig(ABC): +class TamperAttackConfig: """Configuration for attacks including hyperparameters. Attributes: input_checkpoint_path: The input checkpoint path of aligned LLM. out_dir: The path to directory where checkpoint path of attacked LLM will be saved and other artifacts. - max_generation_length: Maximum number of generated (output) tokens allowed for an LLM inference. - evals: The evaluators that will evaluate the success of the attack. + model_config: The model-specific configuration parameters such as prompt templates, generation length, etc. + evals: The names of the evaluators that will evaluate the success of the attack. """ input_checkpoint_path: str out_dir: str - max_generation_length: int evals: list[EvalName] + model_config: ModelConfig @classmethod - @abstractmethod def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExplicitAny] """All subclasses must implement a `from_dict` constructor. @@ -41,6 +50,37 @@ def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExpli Returns: TamperAttackConfig: Config object. """ + cls._validate_dict_keys(data) + data = cls.serialize_data(data) + + return cls(**data) + + @classmethod + def serialize_data(cls, data: dict[str, Any]) -> dict[str, Any]: # pyright: ignore[reportExplicitAny] + """Serialize data from a dictionary such that it can be used to construct nested objects. + + Args: + data: Dictionary containing the configuration fields. + + Returns: + dict[str, Any]: Dictionary containing the configuration fields. + """ + model_config_dict = data.pop("model_config") # pyright: ignore[reportAny] + model_config = ModelConfig.from_dict(model_config_dict) # pyright: ignore[reportAny] + + data.update({"model_config": model_config}) + return data + + @classmethod + def _validate_dict_keys(cls, data: dict[str, Any]) -> None: # pyright: ignore[reportExplicitAny] + """Check that the dictionary to be used to construct dataclass has correct fields. + + Args: + data: Dictionary containing the configuration fields. + + Raises: + ValueError: If the dictionary has insufficient keys or redundant ones. + """ config_keys: set[str] = {_field.name for _field in fields(cls) if _field.init} for config_key in config_keys: @@ -53,12 +93,12 @@ def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExpli if dict_key not in config_keys: raise ValueError(f"`{dict_key}` is not a field of {cls.__name__}") - return cls(**data) # pyright: ignore[reportAny] - class TamperAttack(ABC, Generic[H]): """Base class for a tampering attack.""" + name: AttackName + def __init__(self, attack_config: H) -> None: """Construct base TamperAttack. @@ -77,6 +117,7 @@ def __init_subclass__(cls) -> None: if not hasattr(cls, "name"): raise TypeError(f"{cls.__name__} must define a class attribute 'name'") + @final def benchmark(self) -> DataFrame[EvaluationSchema]: """Execute attack and evaluate the attacked model. @@ -95,7 +136,15 @@ def benchmark(self) -> DataFrame[EvaluationSchema]: def run_attack(self) -> None: """Execute attack and ensure attacked LLM exists at specified paths outlined in config.""" - @abstractmethod + @final + def delete_output_checkpoint(self) -> None: + """Delete the tampered model checkpoint if it exists.""" + if Path(self.output_checkpoint_path).exists(): + import shutil + + shutil.rmtree(self.output_checkpoint_path) + Path(self.output_checkpoint_path).mkdir(parents=True, exist_ok=False) + def evaluate(self) -> DataFrame[EvaluationSchema]: """Evaluate attack on relevant evaluations. @@ -105,3 +154,66 @@ def evaluate(self) -> DataFrame[EvaluationSchema]: Returns: DataFrame[EvaluationSchema]: A polars DataFrame with the aggregated evaluation metrics. """ + results = EvaluationSchema.empty() + + if EvalName.STRONG_REJECT in self.attack_config.evals: + results = pl.concat([results, self.evaluate_strong_reject()]) + + if EvalName.STRONG_REJECT_SMALL in self.attack_config.evals: + results = pl.concat([results, self.evaluate_strong_reject_small()]) + + if EvalName.MMLU_PRO_VAL in self.attack_config.evals: + results = pl.concat([results, self.evaluate_mmlu_pro_val()]) + + if EvalName.MMLU_PRO_TEST in self.attack_config.evals: + results = pl.concat([results, self.evaluate_mmlu_pro_test()]) + + return EvaluationSchema.validate(results) + + def evaluate_strong_reject(self) -> DataFrame[EvaluationSchema]: + """Evaluate attack on the `StrongRejectEvaluation` evaluator.""" + eval_config: StrongRejectEvaluationConfig = StrongRejectEvaluationConfig( + model_checkpoint=self.output_checkpoint_path, + out_dir=self.attack_config.out_dir, + model_config=self.attack_config.model_config, + ) + evaluator: StrongRejectEvaluation[StrongRejectEvaluationConfig] = ( + StrongRejectEvaluation(eval_config) + ) + + return evaluator.run_evaluation() + + def evaluate_strong_reject_small(self) -> DataFrame[EvaluationSchema]: + """Evaluate attack on the smaller `StrongRejectEvaluation` evaluator.""" + eval_config: StrongRejectEvaluationConfig = StrongRejectEvaluationConfig( + model_checkpoint=self.output_checkpoint_path, + out_dir=self.attack_config.out_dir, + model_config=self.attack_config.model_config, + ) + evaluator: StrongRejectSmallEvaluation[StrongRejectEvaluationConfig] = ( + StrongRejectSmallEvaluation(eval_config) + ) + + return evaluator.run_evaluation() + + def evaluate_mmlu_pro_val(self) -> DataFrame[EvaluationSchema]: + """Evaluate attack on the `MMLUPro` evaluator validation set.""" + eval_config: WhiteBoxEvaluationConfig = WhiteBoxEvaluationConfig( + model_checkpoint=self.output_checkpoint_path, + out_dir=self.attack_config.out_dir, + model_config=self.attack_config.model_config, + ) + evaluator: MMLUProValEvaluation = MMLUProValEvaluation(eval_config) + + return evaluator.run_evaluation() + + def evaluate_mmlu_pro_test(self) -> DataFrame[EvaluationSchema]: + """Evaluate attack on the `MMLUPro` evaluator full test set.""" + eval_config: WhiteBoxEvaluationConfig = WhiteBoxEvaluationConfig( + model_checkpoint=self.output_checkpoint_path, + out_dir=self.attack_config.out_dir, + model_config=self.attack_config.model_config, + ) + evaluator: MMLUProTestEvaluation = MMLUProTestEvaluation(eval_config) + + return evaluator.run_evaluation() diff --git a/src/safetunebed/whitebox/attacks/registry.py b/src/safetunebed/whitebox/attacks/registry.py new file mode 100644 index 00000000..bbcfe7e1 --- /dev/null +++ b/src/safetunebed/whitebox/attacks/registry.py @@ -0,0 +1,34 @@ +"""Attack registry with decorator-based registration. + +Allows individual attack modules to self-register their mapping from +`AttackName -> (ConfigDataclass, AttackClass)` at import time, avoiding a +central hard-coded map. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TypeVar + +from safetunebed.whitebox.attacks.base import TamperAttack, TamperAttackConfig +from safetunebed.whitebox.utils.names import AttackName + +H = TypeVar("H", bound=TamperAttackConfig) +T = TypeVar("T", bound=TamperAttack) # pyright: ignore[reportMissingTypeArgument] + + +ATTACKS_REGISTRY: dict[ # pyright: ignore[reportUnknownVariableType] + AttackName, tuple[type[TamperAttackConfig], type[TamperAttack]] # pyright: ignore[reportMissingTypeArgument] +] = {} + + +def register_attack( + name: AttackName, config_cls: type[H] +) -> Callable[[type[T]], type[T]]: + """Decorator to register an attack class and its config class under a name.""" # noqa: D401 + + def _decorator(attack_cls: type[T]) -> type[T]: + ATTACKS_REGISTRY[name] = (config_cls, attack_cls) + return attack_cls + + return _decorator diff --git a/src/safetunebed/whitebox/evals/__init__.py b/src/safetunebed/whitebox/evals/__init__.py index 3cd6d6e8..c717e2f3 100644 --- a/src/safetunebed/whitebox/evals/__init__.py +++ b/src/safetunebed/whitebox/evals/__init__.py @@ -4,14 +4,22 @@ EmbeddingAttackEvaluation, EmbeddingAttackEvaluationConfig, ) +from safetunebed.whitebox.evals.mmlu_pro.mmlu_pro import ( + MMLUProTestEvaluation, + MMLUProValEvaluation, +) from safetunebed.whitebox.evals.strong_reject.strong_reject import ( StrongRejectEvaluation, StrongRejectEvaluationConfig, + StrongRejectSmallEvaluation, ) __all__ = [ "EmbeddingAttackEvaluation", "EmbeddingAttackEvaluationConfig", + "MMLUProTestEvaluation", + "MMLUProValEvaluation", "StrongRejectEvaluation", "StrongRejectEvaluationConfig", + "StrongRejectSmallEvaluation", ] diff --git a/src/safetunebed/whitebox/evals/base.py b/src/safetunebed/whitebox/evals/base.py index b005b4f8..82cd295a 100644 --- a/src/safetunebed/whitebox/evals/base.py +++ b/src/safetunebed/whitebox/evals/base.py @@ -28,6 +28,7 @@ MetricName, OptimizationDirection, ) +from safetunebed.whitebox.utils.models.config import ModelConfig C = TypeVar(name="C", bound="WhiteBoxEvaluationConfig") @@ -39,14 +40,12 @@ class WhiteBoxEvaluationConfig: Attributes: model_checkpoint: Path to the checkpoint of the pretrained HuggingFace model to evaluate. out_dir: The path to directory where evaluation results and artifacts will be stored (may also have model checkpoints). - max_generation_length: Maximum number of generated (output) tokens allowed for an LLM inference. - batch_size: Batch size for running inference. + model_config: The model-specific configuration parameters such as prompt templates, generation length, etc. """ model_checkpoint: str out_dir: str - max_generation_length: int - batch_size: int + model_config: ModelConfig class WhiteBoxEvaluation(ABC, Generic[C]): @@ -209,28 +208,31 @@ def load_result_objective(cls, results: DataFrame[EvaluationSchema]) -> float: ) return result_metrics[cls.objective][0] - def load_model_and_tokenizer(self) -> tuple[PreTrainedModel, PreTrainedTokenizer]: - """Load model and tokenizer from path in HuggingFace. - Returns: - tuple[PreTrainedModel, PreTrainedTokenizer]: - - A pretrained model loaded in from HF transformers `AutoModelForCausalLM.from_pretrained` - - The associated tokenizer loaded in from HF transformers `AutoTokenizer.from_pretrained` - """ - model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - pretrained_model_name_or_path=self.eval_config.model_checkpoint, - torch_dtype=torch.bfloat16, - device_map="auto", - ).eval() - - tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=self.eval_config.model_checkpoint, - padding_side="left", - use_fast=False, - ) +def load_model_and_tokenizer(eval_config: C) -> tuple[PreTrainedModel, PreTrainedTokenizer]: + """Load model and tokenizer from path in HuggingFace. - tokenizer.pad_token = ( - tokenizer.pad_token or tokenizer.eos_token - ) # sometimes mandatory - - return model, tokenizer + Returns: + tuple[PreTrainedModel, PreTrainedTokenizer]: + - A pretrained model loaded in from HF transformers `AutoModelForCausalLM.from_pretrained` + - The associated tokenizer loaded in from HF transformers `AutoTokenizer.from_pretrained` + """ + model: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=eval_config.model_checkpoint, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True, + ).eval() + + tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=eval_config.model_checkpoint, + padding_side="left", + use_fast=False, + trust_remote_code=True, + ) + + tokenizer.pad_token = ( + tokenizer.pad_token or tokenizer.eos_token + ) # sometimes mandatory + + return model, tokenizer diff --git a/src/safetunebed/whitebox/evals/registry.py b/src/safetunebed/whitebox/evals/registry.py new file mode 100644 index 00000000..c3f7ba84 --- /dev/null +++ b/src/safetunebed/whitebox/evals/registry.py @@ -0,0 +1,28 @@ +"""Evaluation registry with decorator-based registration. + +Allows individual evaluation modules to self-register their mapping from +`EvalName -> EvaluationClass` at import time, avoiding a central hard-coded map. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TypeVar + +from safetunebed.whitebox.evals.base import WhiteBoxEvaluation +from safetunebed.whitebox.utils.names import EvalName + +E = TypeVar("E", bound=WhiteBoxEvaluation) # pyright: ignore[reportMissingTypeArgument] + + +EVALS_REGISTRY: dict[EvalName, type[WhiteBoxEvaluation]] = {} # pyright: ignore[reportMissingTypeArgument, reportUnknownVariableType] + + +def register_evaluation(name: EvalName) -> Callable[[type[E]], type[E]]: + """Decorator to register an evaluation class under a name.""" # noqa: D401 + + def _decorator(eval_cls: type[E]) -> type[E]: + EVALS_REGISTRY[name] = eval_cls + return eval_cls + + return _decorator diff --git a/src/safetunebed/whitebox/utils/benchmark/io.py b/src/safetunebed/whitebox/utils/benchmark/io.py index 42c6aacc..fbcb142b 100644 --- a/src/safetunebed/whitebox/utils/benchmark/io.py +++ b/src/safetunebed/whitebox/utils/benchmark/io.py @@ -1,7 +1,8 @@ -"""Utilities for general input, output, and parsing.""" +"""Utilities for reading configs and persisting Optuna sweep artifacts.""" -# pyright: reportAny=false +from __future__ import annotations +# pyright: reportAny=false from pathlib import Path from typing import Any @@ -15,7 +16,7 @@ def yaml_to_dict(yaml_file_path: Path) -> dict[str, Any]: # pyright: ignore[rep yaml_file_path: The path to the YAML file. Returns: - dict: A Python dictionary representing the YAML data. + dict[str, Any]: Parsed YAML payload. """ with open(yaml_file_path) as file: data = yaml.safe_load(file) diff --git a/src/safetunebed/whitebox/utils/models/__init__.py b/src/safetunebed/whitebox/utils/models/__init__.py new file mode 100644 index 00000000..79209b8f --- /dev/null +++ b/src/safetunebed/whitebox/utils/models/__init__.py @@ -0,0 +1 @@ +"""Utilities for handling models and tokenizers from the `transformers` library (HuggingFace).""" diff --git a/src/safetunebed/whitebox/utils/models/config.py b/src/safetunebed/whitebox/utils/models/config.py new file mode 100644 index 00000000..a15ae1f6 --- /dev/null +++ b/src/safetunebed/whitebox/utils/models/config.py @@ -0,0 +1,64 @@ +"""Definition of configuration dataclasses for models, and tokenizers.""" + +from dataclasses import dataclass, fields +from typing import Any + +from typing_extensions import Self + +from safetunebed.whitebox.utils.models.templates import get_template + + +@dataclass +class ModelConfig: + """Configuration for attacks including hyperparameters. + + Attributes: + user_prefix: The prefix to append to input prompts (that would be from a user). Similar to instruction + template. + assistant_prefix: The prefix to append to the output prompts. This would mark the portion of a model + generation that is a "completion". This is required when training where loss is computed only on model + completions. + end_turn: The string added at the end of a turn, where a turn is a message from a user or assistant. + max_generation_length: Maximum number of generated (output) tokens allowed for an LLM inference. + inference_batch_size: Batch size for running inference. + """ + + user_prefix: str + assistant_prefix: str + end_turn: str + max_generation_length: int + inference_batch_size: int + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExplicitAny] + """All subclasses must implement a `from_dict` constructor. + + Args: + data: Dictionary containing the configuration fields. + + Returns: + ModelConfig: Config object. + """ + data = dict(data) # shallow copy + # Support a template registry to fill in prefixes if provided + if "template" in data: + template_name = data.pop("template") + template = get_template(template_name) + + data.setdefault("user_prefix", template.user_prefix) + data.setdefault("assistant_prefix", template.assistant_prefix) + data.setdefault("end_turn", template.end_turn) + + config_keys: set[str] = {_field.name for _field in fields(cls) if _field.init} + + for config_key in config_keys: + if config_key not in data.keys(): + raise ValueError( + f"`{config_key}` is a mandatory field, and is not provided" + ) + + for dict_key in data.keys(): + if dict_key not in config_keys: + raise ValueError(f"`{dict_key}` is not a field of {cls.__name__}") + + return cls(**data) # pyright: ignore[reportAny] diff --git a/src/safetunebed/whitebox/utils/tokenizers/default_tokens.py b/src/safetunebed/whitebox/utils/models/default_tokens.py similarity index 100% rename from src/safetunebed/whitebox/utils/tokenizers/default_tokens.py rename to src/safetunebed/whitebox/utils/models/default_tokens.py diff --git a/src/safetunebed/whitebox/utils/models/templates.py b/src/safetunebed/whitebox/utils/models/templates.py new file mode 100644 index 00000000..0110e310 --- /dev/null +++ b/src/safetunebed/whitebox/utils/models/templates.py @@ -0,0 +1,112 @@ +"""Registry of text templates via decorators for chat prefixes and delimiters. + +Use in YAML configs by referencing a template name, and ModelConfig.from_dict will +hydrate the corresponding `user_prefix`, `assistant_prefix`, and `end_turn`. +""" + +from collections.abc import Callable +from dataclasses import dataclass + +from safetunebed.whitebox.utils.names import TemplateName + + +@dataclass(frozen=True) +class TextTemplate: + """Chat templating strings for user/assistant turns and turn separator.""" + + user_prefix: str + assistant_prefix: str + end_turn: str + + +_TEMPLATE_REGISTRY: dict[TemplateName, TextTemplate] = {} + + +def register_template( + name: TemplateName, +) -> Callable[[Callable[[], TextTemplate]], Callable[[], TextTemplate]]: + """Decorator to register a template factory by name.""" # noqa: D401 + + def _decorator(factory: Callable[[], TextTemplate]) -> Callable[[], TextTemplate]: + _TEMPLATE_REGISTRY[name] = factory() + return factory + + return _decorator + + +def get_template(name: str | TemplateName) -> TextTemplate: + """Resolve a registered text template by enum or string alias. + + Args: + name: Template identifier (enum value or convertible string alias). + + Returns: + TextTemplate: The chat template containing user/assistant prefixes and end-turn marker. + + Raises: + KeyError: If the requested template has not been registered. + """ + key = TemplateName(name) if not isinstance(name, TemplateName) else name + try: + return _TEMPLATE_REGISTRY[key] + except KeyError as e: + raise KeyError(f"Unknown text template: {name}") from e + + +# ---- Built-in templates ---------------------------------------------------- + + +@register_template(TemplateName.LLAMA3) +def _llama3() -> TextTemplate: + """Return the Meta LLaMA 3 chat delimiter template.""" + # Meta LLaMA 3 style chat markers + return TextTemplate( + user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", + assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", + end_turn="<|eot_id|>", + ) + + +@register_template(TemplateName.INSTRUCTION_RESPONSE) +def _instruction_response() -> TextTemplate: + """Return a simple instruction/response prompt template.""" + return TextTemplate( + user_prefix="INSTRUCTION: ", assistant_prefix="RESPONSE: ", end_turn="" + ) + + +@register_template(TemplateName.GENERIC_CHAT) +def _generic_chat() -> TextTemplate: + """Return a human-readable user/assistant chat template.""" + return TextTemplate( + user_prefix="User: ", assistant_prefix="Assistant: ", end_turn="\n" + ) + + +@register_template(TemplateName.PLAIN) +def _plain() -> TextTemplate: + """Return an empty template with no prefixes or turn delimiters.""" + return TextTemplate(user_prefix="", assistant_prefix="", end_turn="") + + +@register_template(TemplateName.GPT_CHAT) +def _gpt_chat() -> TextTemplate: + """Return lightweight OpenAI-style chat delimiters.""" + # Approximate plain chat markers for OpenAI GPT-style chat logs + return TextTemplate( + user_prefix="<|user|> ", assistant_prefix="<|assistant|> ", end_turn="\n" + ) + + +@register_template(TemplateName.QWEN) +def _qwen() -> TextTemplate: + """Return Qwen-style chat markers using `<|im_start|>`/`<|im_end|>` tokens.""" + # Qwen3 chat markers (same delimiters as Qwen2/2.5) + # Example: + # <|im_start|>user\n{user_text}<|im_end|> + # <|im_start|>assistant\n{assistant_text}<|im_end|> + return TextTemplate( + user_prefix="<|im_start|>user\n", + assistant_prefix="<|im_start|>assistant\n", + end_turn="<|im_end|>", + ) diff --git a/src/safetunebed/whitebox/utils/names.py b/src/safetunebed/whitebox/utils/names.py index 7bd27503..7071bb46 100644 --- a/src/safetunebed/whitebox/utils/names.py +++ b/src/safetunebed/whitebox/utils/names.py @@ -11,6 +11,12 @@ class AttackName(StrEnum): LORA_FINETUNE = "lora_finetune" MULTILINGUAL_FINETUNE = "multilingual_finetune" EMBEDDING_ATTACK = "embedding_attack" + JAILBREAK_FINETUNE = "jailbreak_finetune" + COMPETING_OBJECTIVES_FINETUNE = "competing_objectives_finetune" + STYLE_MODULATION_FINETUNE = "style_modulation_finetune" + BACKDOOR_FINETUNE = "backdoor_finetune" + BENIGN_FULL_PARAMETER_FINETUNE = "benign_full_parameter_finetune" + BENIGN_LORA_FINETUNE = "benign_lora_finetune" class EvalName(StrEnum): @@ -20,12 +26,15 @@ class EvalName(StrEnum): STRONG_REJECT = "strong_reject" STRONG_REJECT_SMALL = "strong_reject_small" EMBEDDING_ATTACK = "embedding_attack_eval" + MMLU_PRO_VAL = "mmlu_pro_val" + MMLU_PRO_TEST = "mmlu_pro" class MetricName(StrEnum): """Names of metrics available in safetunebed.""" STRONG_REJECT_SCORE = "strong_reject_score" + MMLU_PRO_ACCURACY = "mmlu_pro_accuracy" class OptimizationDirection(StrEnum): @@ -55,3 +64,14 @@ class ConfigPath(StrEnum): GRID_YAML = "grid.yaml" # yaml file containing defined configs for attacks SWEEP = "sweep.yaml" # yaml file containing configs for a hyper-parameter sweep + + +class TemplateName(StrEnum): + """Names of text templates for chat-style formatting.""" + + LLAMA3 = "llama3" + INSTRUCTION_RESPONSE = "instruction_response" + GENERIC_CHAT = "generic_chat" + PLAIN = "plain" + GPT_CHAT = "gpt_chat" + QWEN = "qwen" diff --git a/src/safetunebed/whitebox/utils/ops/dealloc.py b/src/safetunebed/whitebox/utils/ops/dealloc.py index 9460b811..fe4c0abc 100644 --- a/src/safetunebed/whitebox/utils/ops/dealloc.py +++ b/src/safetunebed/whitebox/utils/ops/dealloc.py @@ -17,6 +17,9 @@ def dealloc_model_and_tokenizer( Args: model: HuggingFace model tokenizer: HuggingFace tokenizer + + Returns: + None. Frees CUDA memory and clears Python references for the provided objects. """ if model is not None: model.cpu() diff --git a/src/safetunebed/whitebox/utils/tokenizers/__init__.py b/src/safetunebed/whitebox/utils/tokenizers/__init__.py deleted file mode 100644 index 08a98635..00000000 --- a/src/safetunebed/whitebox/utils/tokenizers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Utilities for tokenizers."""