Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Full parameter fine-tuning attack interface."""

# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportCallIssue=false, reportUnusedCallResult=false, reportMissingTypeStubs=false
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportCallIssue=false, reportUnusedCallResult=false, reportMissingTypeStubs=false, reportAttributeAccessIssue=false, reportPrivateImportUsage=false

from dataclasses import dataclass
from typing import Any, TypeVar
Expand All @@ -20,9 +20,12 @@

from safetunebed.whitebox.attacks.base import TamperAttack, TamperAttackConfig
from safetunebed.whitebox.evals import (
GPQA,
GPQAConfig,
StrongRejectEvaluation,
StrongRejectEvaluationConfig,
)
from safetunebed.whitebox.evals.GPQA import GPQASmallEvaluation
from safetunebed.whitebox.evals.output_schema import EvaluationSchema
from safetunebed.whitebox.evals.strong_reject.strong_reject import (
StrongRejectSmallEvaluation,
Expand Down Expand Up @@ -118,13 +121,13 @@ def run_attack(self) -> None:
# f"{ASSISTANT_PREFIX}{ex['rejected']}{END_TURN}"
# )
# }
def to_completions(ex):
def to_completions(ex):  # pyright: ignore[reportUnknownParameterType,reportMissingParameterType]
    """Turn one raw example into a prompt/completion record.

    The user turn is built from ``ex['prompt']`` and the assistant turn from
    ``ex['rejected']``; each is wrapped in its role prefix and ``END_TURN``.
    """
    user_turn = f"{USER_PREFIX}{ex['prompt']}{END_TURN}"
    assistant_turn = f"{ASSISTANT_PREFIX}{ex['rejected']}{END_TURN}"
    return {"prompt": user_turn, "completion": assistant_turn}

ds = raw_ds.map(to_completions, remove_columns=raw_ds.column_names)
ds = raw_ds.map(to_completions, remove_columns=raw_ds.column_names) # pyright: ignore[reportUnknownArgumentType]

training_arguments: trl.SFTConfig = trl.SFTConfig(
output_dir=self.attack_config.out_dir,
Expand All @@ -143,7 +146,7 @@ def to_completions(ex):
trainer: trl.SFTTrainer = trl.SFTTrainer(
model=model,
processing_class=tokenizer,
train_dataset=ds,
train_dataset=ds, # pyright: ignore[reportUnknownArgumentType]
args=training_arguments,
)

Expand Down Expand Up @@ -174,6 +177,12 @@ def evaluate(self) -> DataFrame[EvaluationSchema]:
if EvalName.STRONG_REJECT_SMALL in self.attack_config.evals:
results = pl.concat([results, self._evaluate_strong_reject_small()])

if EvalName.GPQA in self.attack_config.evals:
results = pl.concat([results, self._evaluate_gpqa()])

if EvalName.GPQA_SMALL in self.attack_config.evals:
results = pl.concat([results, self._evaluate_gpqa_small()])

return EvaluationSchema.validate(results)

def load_model(self) -> PreTrainedModel:
Expand Down Expand Up @@ -228,3 +237,31 @@ def _evaluate_strong_reject_small(self) -> DataFrame[EvaluationSchema]:
)

return evaluator.run_evaluation()

def _gpqa_eval_config(self) -> GPQAConfig:
    """Build the `GPQAConfig` shared by the full and small GPQA evaluations.

    Both evaluators run against the checkpoint at `self.output_checkpoint_path`
    with identical settings, so the configuration is assembled in one place to
    keep the two evaluations from drifting apart.
    """
    return GPQAConfig(
        model_checkpoint=self.output_checkpoint_path,
        out_dir=self.attack_config.out_dir,
        max_generation_length=self.attack_config.max_generation_length,
        # Fixed evaluation settings, identical for both GPQA variants.
        batch_size=8,
        prompt_type="zero_shot",
        dataset_split="diamond",
    )

def _evaluate_gpqa(self) -> DataFrame[EvaluationSchema]:
    """Evaluate attack on the `GPQA` evaluator.

    Returns:
        The frame produced by the evaluator's `run_evaluation()`.
    """
    evaluator: GPQA = GPQA(eval_config=self._gpqa_eval_config())

    return evaluator.run_evaluation()

def _evaluate_gpqa_small(self) -> DataFrame[EvaluationSchema]:
    """Evaluate attack on the `GPQASmallEvaluation` evaluator.

    Uses the same configuration as `_evaluate_gpqa`; only the evaluator class differs.

    Returns:
        The frame produced by the evaluator's `run_evaluation()`.
    """
    evaluator: GPQASmallEvaluation = GPQASmallEvaluation(
        eval_config=self._gpqa_eval_config()
    )

    return evaluator.run_evaluation()
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Lora (PEFT) fine-tuning attack interface."""

# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportCallIssue=false, reportMissingTypeStubs=false, reportUnusedCallResult=false
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportCallIssue=false, reportMissingTypeStubs=false, reportUnusedCallResult=false, reportAttributeAccessIssue=false, reportPrivateImportUsage=false

from dataclasses import dataclass
from typing import Any, TypeVar
Expand Down Expand Up @@ -122,13 +122,13 @@ def run_attack(self) -> None:
# f"{ASSISTANT_PREFIX}{ex['rejected']}{END_TURN}"
# )
# }
def to_completions(ex):
def to_completions(ex):  # pyright: ignore[reportUnknownParameterType,reportMissingParameterType]
    """Turn one raw example into a prompt/completion record.

    The user turn is built from ``ex['prompt']`` and the assistant turn from
    ``ex['rejected']``; each is wrapped in its role prefix and ``END_TURN``.
    """
    user_turn = f"{USER_PREFIX}{ex['prompt']}{END_TURN}"
    assistant_turn = f"{ASSISTANT_PREFIX}{ex['rejected']}{END_TURN}"
    return {"prompt": user_turn, "completion": assistant_turn}

ds = raw_ds.map(to_completions, remove_columns=raw_ds.column_names)
ds = raw_ds.map(to_completions, remove_columns=raw_ds.column_names) # pyright: ignore[reportUnknownArgumentType]

training_arguments: trl.SFTConfig = trl.SFTConfig(
output_dir=self.attack_config.out_dir,
Expand All @@ -147,7 +147,7 @@ def to_completions(ex):
trainer: trl.SFTTrainer = trl.SFTTrainer(
model=peft_model,
processing_class=tokenizer,
train_dataset=ds,
train_dataset=ds, # pyright: ignore[reportUnknownArgumentType]
args=training_arguments,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Multi-lingual fine-tuning attack interface."""

# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false, reportCallIssue=false, reportUnusedCallResult=false, reportMissingTypeStubs=false
# pyright: reportAttributeAccessIssue=false, reportPrivateImportUsage=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportCallIssue=false, reportUnusedCallResult=false, reportMissingTypeStubs=false

from typing import cast

Expand Down
Loading