Conversation
tomtseng
left a comment
Looks like most of this is code copied from elsewhere, so I won't review it carefully.
Tagging Saad to review the remaining file src/safetunebed/whitebox/attacks/wanda_pruning/wanda_pruning.py
Can you add a citation and a link to the original code here, like punya did in his PR in src/safetunebed/whitebox/attacks/gcg/__init__.py?
sdhossain
left a comment
@esveee could you please add a test script we could run, similar to what we currently have in our tests folder? Also, do add a custom config if it's relevant here.
Otherwise LGTM. (I also didn't look too closely at the ported-over code; I'd recommend adding the citation and a link to the original code, as is done with other attacks.)
    def run_attack(self) -> None:
        cfg = self.attack_config

        print(f"[WandA] Loading model from: {cfg.base_input_checkpoint_path}")
I know we don't have unified logging for the repository just yet, but I do think we should probably use Python's logging module (a logger obtained via logging.getLogger) instead of print, so that we can control the log level.
Not something to necessarily scope for this PR.
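A minimal sketch of what that could look like, assuming a module-level logger. The logger name and the `log_model_load` helper are hypothetical and are not part of this PR:

```python
import logging

# Module-level logger; in the attack module this would usually be
# logging.getLogger(__name__). The name below is an assumption.
logger = logging.getLogger("wanda_pruning")


def log_model_load(checkpoint_path: str) -> None:
    """Hypothetical helper standing in for the print() call in run_attack."""
    # %-style arguments are only formatted if the INFO level is enabled.
    logger.info("[WandA] Loading model from: %s", checkpoint_path)


if __name__ == "__main__":
    # The caller (a script or test) decides the verbosity, not the attack.
    logging.basicConfig(level=logging.INFO)
    log_model_load("path/to/checkpoint")
```

With this shape, raising the level to `logging.WARNING` silences the message without touching the attack code.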
        cfg = self.attack_config

        print(f"[WandA] Loading model from: {cfg.base_input_checkpoint_path}")
        model = AutoModelForCausalLM.from_pretrained(cfg.base_input_checkpoint_path, torch_dtype=torch.float16)
Note: consider changing torch_dtype=torch.float16 to torch_dtype=torch.bfloat16 (only relevant if we are hitting errors here).
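For context, one common reason this swap helps (an assumption here, since no specific error is reported in this thread) is that float16 has a far smaller dynamic range than bfloat16, so large values can overflow. The sketch below derives the largest finite value of each format from its bit layout in plain Python; `max_finite` is a hypothetical helper for illustration only:

```python
def max_finite(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-754-style binary float format."""
    bias = 2 ** (exponent_bits - 1) - 1
    # The all-ones exponent field is reserved for inf/NaN, so the largest
    # usable unbiased exponent equals the bias.
    max_exponent = bias
    largest_significand = 2.0 - 2.0 ** (-mantissa_bits)
    return largest_significand * 2.0 ** max_exponent


# float16: 5 exponent bits, 10 mantissa bits.
FLOAT16_MAX = max_finite(exponent_bits=5, mantissa_bits=10)
# bfloat16: 8 exponent bits, 7 mantissa bits.
BFLOAT16_MAX = max_finite(exponent_bits=8, mantissa_bits=7)

if __name__ == "__main__":
    print(FLOAT16_MAX, BFLOAT16_MAX)
```

The trade-off is precision: bfloat16 keeps 7 mantissa bits against float16's 10, so it is coarser but much harder to overflow.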
    """Implements weight-space tampering via WandA pruning."""

    def run_attack(self) -> None:
        cfg = self.attack_config
nit: our current style is to use config instead of cfg. I personally prefer that we use self.attack_config explicitly wherever we use it (so that it is not ambiguous with other configs).
    StrongRejectEvaluationConfig,
)


class WandaPruningAttack(TamperAttack[TamperAttackConfig]):
do we need a custom WandaPruningAttackConfig for this attack?
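If a custom config turns out to be useful, it could look roughly like the sketch below. This is an assumption-heavy illustration: the `TamperAttackConfig` here is a stand-in for the repository's real base class (whose fields may differ), and `sparsity_ratio` and `n_calibration_samples` are hypothetical field names, not taken from the PR:

```python
from dataclasses import dataclass


@dataclass
class TamperAttackConfig:
    """Stand-in for the repository's base config; real fields may differ."""

    base_input_checkpoint_path: str


@dataclass
class WandaPruningAttackConfig(TamperAttackConfig):
    """Hypothetical config carrying pruning-specific settings."""

    sparsity_ratio: float = 0.5       # fraction of weights to zero (assumed name)
    n_calibration_samples: int = 128  # calibration set size (assumed name)

    def __post_init__(self) -> None:
        # Fail early on values that would make the pruning step meaningless.
        if not 0.0 <= self.sparsity_ratio < 1.0:
            raise ValueError("sparsity_ratio must be in [0, 1)")
        if self.n_calibration_samples <= 0:
            raise ValueError("n_calibration_samples must be positive")
```

The attack class would then be parameterized as `TamperAttack[WandaPruningAttackConfig]`, keeping pruning hyperparameters out of the method body.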
@@ -0,0 +1,398 @@
import time
Can we add a note in the header docstring on where the code was sourced from? (That is, if it was sourced externally; ping me if it wasn't.)
Changes
Added Wanda pruning as an attack paradigm. Normally, pruning is benign (for efficiency). But if applied maliciously, it can be a form of model tampering.
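For readers unfamiliar with the method, here is a minimal pure-Python sketch of the Wanda criterion as it is usually described: each weight is scored by its magnitude times the L2 norm of the matching input activation, and the lowest-scoring weights are zeroed within each output row. This is illustrative only and is not the code in this PR; `wanda_prune` and its arguments are hypothetical names:

```python
import math


def wanda_prune(weights, activations, sparsity_ratio):
    """Zero the lowest-scoring weights in each output row.

    weights: list of rows, shape [out_features][in_features]
    activations: calibration inputs, shape [n_tokens][in_features]
    """
    n_in = len(weights[0])
    # L2 norm of each input feature across all calibration tokens.
    feature_norms = [
        math.sqrt(sum(token[j] ** 2 for token in activations))
        for j in range(n_in)
    ]
    n_prune = int(n_in * sparsity_ratio)
    pruned = []
    for row in weights:
        # Wanda score: |weight| * norm of the corresponding input feature.
        scores = [abs(w) * feature_norms[j] for j, w in enumerate(row)]
        # Indices of the n_prune smallest scores within this row.
        drop = set(sorted(range(n_in), key=scores.__getitem__)[:n_prune])
        pruned.append([0.0 if j in drop else w for j, w in enumerate(row)])
    return pruned


if __name__ == "__main__":
    weights = [[0.1, -2.0, 0.5, 1.0]]
    activations = [[10.0, 0.1, 1.0, 1.0], [10.0, 0.1, 1.0, 1.0]]
    print(wanda_prune(weights, activations, sparsity_ratio=0.5))
```

In this toy input the largest-magnitude weight (-2.0) is removed because its input feature is almost inactive, which is where the criterion differs from plain magnitude pruning.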
Testing
[Test in Progress]