Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/safetunebed/whitebox/defenses/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
"""Alignment-stage defenses against tampering supported by safetunebed."""

from safetunebed.whitebox.defenses.rsn_tune import RSNTune, RSNTuneConfig

__all__ = ["RSNTune", "RSNTuneConfig"]
29 changes: 29 additions & 0 deletions src/safetunebed/whitebox/defenses/rsn_tune/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Robust Safety Neuron Tuning (RSN-Tune) defense.

Citation:
Zhao, Y., Zhang, W., Xie, Y., Goyal, A., Kawaguchi, K., & Shieh, M. Q. (2025).
Understanding and Enhancing Safety Mechanisms of LLMs via Safety-Specific Neuron.
In International Conference on Learning Representations (ICLR) 2025.
Paper: https://arxiv.org/abs/[TBD]
Code: https://github.com/zhaoyiran924/Safety-Neuron

RSN-Tune is a defense method that enhances LLM safety robustness during downstream
task fine-tuning by:
1. Identifying safety neurons (neurons crucial for handling harmful queries)
2. Identifying foundation neurons (neurons responsible for general tasks)
3. Tuning only the safety neurons that don't overlap with foundation neurons
"""

from safetunebed.whitebox.defenses.rsn_tune.callbacks import GradientMaskCallback
from safetunebed.whitebox.defenses.rsn_tune.detection import detect
from safetunebed.whitebox.defenses.rsn_tune.rsn_tune import (
RSNTune,
RSNTuneConfig,
)

__all__ = [
"GradientMaskCallback",
"RSNTune",
"RSNTuneConfig",
"detect",
]
57 changes: 57 additions & 0 deletions src/safetunebed/whitebox/defenses/rsn_tune/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Training callbacks for RSN-Tune defense."""

import torch
from transformers import PreTrainedModel, TrainerCallback


def make_mask_hook(mask: torch.Tensor):
"""Create gradient hook that applies mask, with dtype caching."""
cache: list[torch.Tensor | None] = [None]

def hook(grad: torch.Tensor) -> torch.Tensor | None:
if grad is None:
return None
if cache[0] is None or cache[0].dtype != grad.dtype:
cache[0] = mask.to(dtype=grad.dtype, device=grad.device)
return grad * cache[0]

return hook


class GradientMaskCallback(TrainerCallback):
    """Masks gradients during training to only update specified neurons.

    Applies gradient masks to parameters with selected neurons, zeroing gradients
    for non-selected dimensions. Parameters without any selected neurons should
    have requires_grad=False set before training to save optimizer memory.

    Args:
        model: The model being trained.
        masks: Dict mapping parameter names to mask tensors. Masks should be float
            tensors with 1.0 for neurons to train and 0.0 for neurons to freeze.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        masks: dict[str, torch.Tensor],
    ):
        """Initialize the gradient mask callback."""
        self.model = model
        self.masks = masks
        # Hook handles returned by Tensor.register_hook, kept for removal.
        self.hooks: list = []

    def _remove_hooks(self) -> None:
        """Remove all registered gradient hooks and clear the handle list."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def on_train_begin(self, args, state, control, **kwargs):
        """Register gradient hooks at the start of training.

        Idempotent: any hooks left over from a previous ``on_train_begin`` are
        removed first, so repeated training runs (e.g. a resumed Trainer) do
        not stack duplicate hooks on the same parameters.
        """
        self._remove_hooks()
        for name, param in self.model.named_parameters():
            # Frozen parameters get no hook; they never receive gradients.
            if not param.requires_grad:
                continue
            if name in self.masks:
                # Partial mask: only allow gradients for selected neurons
                self.hooks.append(param.register_hook(make_mask_hook(self.masks[name])))

    def on_train_end(self, args, state, control, **kwargs):
        """Remove gradient hooks at the end of training."""
        self._remove_hooks()
Loading