Conversation
tests/defenses/test_rsn_tune.py
@@ -0,0 +1,626 @@
"""Evaluation for RSN-Tune defense.
this file is mainly Claude, with me reading through and suggesting some changes
tests/defenses/test_rsn_tune.py
# =============================================================================


def run_threshold_sweep(
This can be used to work out which thresholds to use. The default is 1, but that doesn't always give the best result, or any result at all on smaller models.
This seems like quite a useful function, can we have it in a separate script? in scripts/defenses/rsn_tune.py perhaps?
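For reference, a minimal sketch of what the core of such a script could look like. The names and the `evaluate` callable interface here are my assumptions, not the PR's actual API; the real sweep presumably applies the defense and scores it at each threshold.

```python
from dataclasses import dataclass


@dataclass
class SweepResult:
    threshold: float
    score: float


def run_threshold_sweep(evaluate, thresholds):
    """Score each candidate threshold and return results, best first.

    `evaluate` is any callable mapping a threshold to a scalar metric,
    e.g. refusal rate after applying the defense at that threshold.
    """
    results = [SweepResult(t, evaluate(t)) for t in thresholds]
    # Highest score first, so results[0] is the recommended threshold.
    return sorted(results, key=lambda r: r.score, reverse=True)
```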
tomtseng left a comment:
I didn't review the implementation (and I probably won't look carefully for any of the defenses, since the implementations are so long), but if we can roughly replicate some result of the paper (e.g., the "Llama2-7B-Chat" column of Table 4 of the paper) then I'm comfortable merging.
For TAR PR #39 we copied over code from another implementation rather than doing our own implementation, so that's an option to consider in future PRs. With that being said the RSN-Tune codebase looks a bit confusing so I'm happy to have a cleaned up implementation as long as we've done enough testing.
@sdhossain I'd recommend you review the rough structure (and similarly not spend time on the implementation) and give the yea/nay on approving the PR — will this have any conflicts or mismatches with the TAR PR #39?
- max_length: 512

Implementation Differences from Paper:
1. **Optimizer** (PRACTICAL): Uses Adafactor instead of unspecified optimizer in paper.
Could check the original codebase https://github.com/zhaoyiran924/Safety-Neuron/ to see what optimizer and batch size they use.
sdhossain left a comment:
Nice stuff. I concur with @tomtseng: if we can reproduce a result from the paper, we can reliably say we have reproduced their defense, and thus we are good to merge.
I did have some nit comments, which I'm not fully sure should be handled now or become backlog items (they primarily revolve around reorganizing some code / breaking things down into scripts). My guess would be that as long as things work, we are probably good to merge.
Another note: could we run `uv run ruff check <path/to/changed_files> --fix` (if not done already)?
tests/defenses/test_rsn_tune.py
torch.cuda.empty_cache()


def generate_responses(
we could move a function like this to our utils submodule in safetunebed.whitebox.
Could potentially be useful (or something we standardize later).
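As a sketch, a standardized helper in that submodule might look like the following. A Hugging Face-style model/tokenizer pair is assumed; the PR's actual function signature may differ.

```python
def generate_responses(model, tokenizer, prompts, max_new_tokens=128):
    """Greedily generate one completion per prompt.

    Assumes a Hugging Face-style model/tokenizer pair; returns only the
    newly generated text, with the echoed prompt tokens stripped.
    """
    responses = []
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        output_ids = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )
        # Decode only the tokens generated after the prompt.
        responses.append(
            tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
        )
    return responses
```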
tests/defenses/test_rsn_tune.py
return responses


def count_refusals(responses: list[str]) -> int:
nit: I feel like constants such as `REFUSAL_INDICATORS` could live in some file in src, as they could have utility beyond just these tests. (Not super important though.)
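For illustration, such a shared constant plus counter could look like this. The indicator strings below are placeholders; the actual list in the test file may differ.

```python
# Hypothetical shared constants; the actual indicator strings in the
# test file may differ.
REFUSAL_INDICATORS = (
    "i cannot",
    "i can't",
    "i'm sorry",
    "i am unable",
    "as an ai",
)


def count_refusals(responses: list[str]) -> int:
    """Count responses containing any refusal indicator (case-insensitive)."""
    return sum(
        any(indicator in response.lower() for indicator in REFUSAL_INDICATORS)
        for response in responses
    )
```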
tests/defenses/test_rsn_tune.py
@dataclass
class AttackConfig:
do we not have existing configs for attacks?
tests/defenses/test_rsn_tune.py
epochs: int = 1


def apply_rsn_tune(
nit: I think we can just use the following without wrapping it into a function:

```python
rsn_config = RSNTuneConfig(
    input_checkpoint_path=Path(model_name),
    output_checkpoint_path=output_path,
    num_detection_samples=num_detection_samples,
    num_training_samples=num_training_samples,
    safety_importance_threshold=safety_threshold,
    foundation_importance_threshold=foundation_threshold,
    use_robust_mode=True,
)
rsn = RSNTune(defence_config=rsn_config)
safety_neurons, foundation_neurons = rsn.tune_safety_neurons()
```

I would personally prefer to have any application of the defense wrapped inside a class method, so that any necessary utility lives inside those class methods.
I believe the rsn `.run_defense()` function should also take care of this (it runs the defense if the checkpoint has not been created, and if the checkpoint already exists, it just returns the path).
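A minimal sketch of that checkpoint-caching pattern, written as a standalone function with assumed names (the real `.run_defense()` is a method on the defense class):

```python
from pathlib import Path


def run_or_reuse(output_path: Path, run_fn) -> Path:
    """Run `run_fn` to produce a checkpoint unless one already exists.

    Uses config.json as the completion marker, so a half-written output
    directory is not mistaken for a finished checkpoint.
    """
    if not (output_path / "config.json").exists():
        run_fn()
    return output_path
```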
tests/defenses/test_rsn_tune.py
@@ -0,0 +1,625 @@
"""Evaluation for RSN-Tune defense.

This test suite provides two main functionalities:
I think a lot of these functionalities could be useful to have as scripts. Maybe we can house them in scripts/whitebox/defenses/rsn_tune.py.
I think a script for the threshold sweep, and one of the attack comparison make sense.
My main rationale is that it is nicer for a test file to be a simpler sanity check: does it run without crashing, and does it not fail on an obvious task it should succeed at? I.e., we can compare the StrongREJECT scores before and after alignment, or run a simple FinetuningAttack (although given their sensitivity to hyper-parameters, I'm not sure we need that).
For functions that provide certain utilities that could be shared, it makes sense (imo) to have those in src and not in a scripts file.
For example the TAR defense does the test this way: https://github.com/sdhossain/SafeTuneBed/blob/ec2bd9622f12de88ee6e08ff2c03f4859dad503d/tests/defenses/test_tar.py
(just checks if it ran)
For additional experimentation would suggest having them as individual scripts, with more isolated functionality for better usability.
Would defer to @tomtseng on whether breaking these down into individual scripts is better (the inputs can just be a checkpoint). I imagine, it can also be done in a follow-up PR and backlogged as other defenses might be better.
@dataclass
class RSNTuneConfig(AlignmentDefenseConfig):
    """Configuration for RSN-Tune defense."""
Can we have the attributes documented in the docstring (i.e., what they mean)? We have comments for some of them already.
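Something like a Google-style `Attributes:` section would do it. The field names below are taken from the config usage elsewhere in this thread; the defaults are placeholders (only the threshold default of 1 is mentioned above), and the real class subclasses `AlignmentDefenseConfig` rather than standing alone.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass
class RSNTuneConfig:
    """Configuration for RSN-Tune defense.

    Attributes:
        input_checkpoint_path: Model checkpoint to harden.
        output_checkpoint_path: Where the hardened checkpoint is written.
        num_detection_samples: Number of samples used to score neuron importance.
        num_training_samples: Number of samples used for the masked fine-tuning pass.
        safety_importance_threshold: Minimum importance score for a neuron
            to be treated as a safety neuron.
        foundation_importance_threshold: Minimum importance score for a neuron
            to be treated as a foundation (capability) neuron.
        use_robust_mode: Whether to use the robust detection variant.
    """

    input_checkpoint_path: Path
    output_checkpoint_path: Path
    num_detection_samples: int = 128
    num_training_samples: int = 128
    safety_importance_threshold: float = 1.0
    foundation_importance_threshold: float = 1.0
    use_robust_mode: bool = True
```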
return cls(**data)


class _GradientMaskCallback(TrainerCallback):
minor nit: I think this callback and the function below could live in another file?
sdhossain left a comment:
I am good to have this merged provided a quick sanity check that it aligns with the paper results as mentioned.
If we could have a checkpoint uploaded to HuggingFace for Llama3-8B-Instruct or Llama-3-8B-Base that could be quite useful too.
I've done a long sanity check and can't seem to reproduce the paper results. It consistently gives a slight advantage over the base model, but it's like 2-5%, not the 30% reported by the paper. I can try to keep trying things, but I don't know how much it's worth the bother.

Yeah, it sounds like you've done your due diligence but can't figure out why there's a discrepancy, so it's probably not worth staring at this code much more. Thanks for spending the time investigating this thorny issue. At the same time, for the purposes of this benchmark, I don't think it feels appropriate to merge a defense implementation that might not actually be working. The last thing I would suggest checking is trying to run the original codebase https://github.com/zhaoyiran924/Safety-Neuron and seeing whether we can reproduce results there.
| """Detect important neurons using the configured detection strategy.""" | ||
| path = self.defence_config.input_checkpoint_path | ||
| model = load_model(path) | ||
| tokenizer = load_tokenizer(path) |
There was a problem hiding this comment.
Suggested change:
tokenizer = load_tokenizer(path)
model.resize_token_embeddings(new_num_tokens=len(tokenizer))
had to add this change, was getting this error:
/tmp/torchinductor_dev/gj/cgjpugzaj3wojmlv7cpshoif6enk2fvvclqa2wir6g2pwcoxt4md.py:41: unknown: block: [43,0,0], thread: [473,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128256` failed.
[the same assertion repeats for threads [474,0,0] through [479,0,0]]
Traceback (most recent call last):
File "/home/dev/SafeTuneBed/scripts/rsn_tune/harden.py", line 67, in <module>
main()
File "/home/dev/SafeTuneBed/scripts/rsn_tune/harden.py", line 62, in main
output = rsn.run_defense()
^^^^^^^^^^^^^^^^^
File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/rsn_tune.py", line 276, in run_defense
self.tune_safety_neurons()
File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/rsn_tune.py", line 236, in tune_safety_neurons
safety_neurons = self._detect_safety_neurons()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/rsn_tune.py", line 285, in _detect_safety_neurons
return self._detect_neurons(
^^^^^^^^^^^^^^^^^^^^^
File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/rsn_tune.py", line 311, in _detect_neurons
neurons = detect(
^^^^^^^
File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/detection.py", line 477, in detect
neuron_importance = detect_raw(model, tokenizer, dataset, is_harmful, chunk_size)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/detection.py", line 448, in detect_raw
model(**inputs)
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 414, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 832, in compile_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 912, in wrapper
@wraps(func)
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1130, in forward
return compiled_fn(full_args)
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 353, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 129, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 724, in inner_fn
outs = compiled_fn(args)
^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 526, in wrapper
return compiled_fn(runtime_args)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 613, in __call__
return self.current_callable(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/torchinductor_dev/g3/cg3vz3nesx4qhgj6cfjqltjkco4syzl7zgu2c6xtd6adiom2dfe2.py", line 4443, in call
(buf2, buf5, buf6, buf7, buf8, buf18, buf21, buf22, buf27, buf28, buf29, buf30, buf31, buf41, buf45, buf46, buf51, buf52, buf53, buf54, buf55, buf65, buf68, buf69, buf74, buf75, buf76, buf77, buf78, buf88, buf92, buf93, buf98, buf99, buf100, buf101, buf102, buf112, buf115, buf116, buf121, buf122, buf123, buf124, buf125, buf135, buf139, buf140, buf145, buf146, buf147, buf148, buf149, buf159, buf162, buf163, buf168, buf169, buf170, buf171, buf172, buf182, buf186, buf187, buf192, buf193, buf194, buf195, buf196, buf206, buf209, buf210, buf215, buf216, buf217, buf218, buf219, buf229, buf233, buf234, buf239, buf240, buf241, buf242, buf243, buf253, buf256, buf257, buf262, buf263, buf264, buf265, buf266, buf276, buf280, buf281, buf286, buf287, buf288, buf289, buf290, buf300, buf303, buf304, buf309, buf310, buf311, buf312, buf313, buf323, buf327, buf328, buf333, buf334, buf335, buf336, buf337, buf347, buf350, buf351, buf356, buf357, buf358, buf359, buf360, buf370, buf374, buf375, buf380, buf381, buf382, buf383, buf384, buf394, buf397, buf398, buf403, buf404, buf405, buf406, buf407, buf417, buf421, buf422, buf427, buf428, buf429, buf430, buf431, buf441, buf444, buf445, buf450, buf451, buf452, buf453, buf454, buf464, buf468, buf469, buf474, buf475, buf476, buf477, buf478, buf488, buf491, buf492, buf497, buf498, buf499, buf500, buf501, buf511, buf515, buf516, buf521, buf522, buf523, buf524, buf525, buf535, buf538, buf539, buf544, buf545, buf546, buf547, buf548, buf558, buf562, buf563, buf568, buf569, buf570, buf571, buf572, buf582, buf585, buf586, buf591, buf592, buf593, buf594, buf595, buf605, buf609, buf610, buf615, buf616, buf617, buf618, buf619, buf629, buf632, buf633, buf638, buf639, buf640, buf641, buf642, buf652, buf656, buf657, buf662, buf663, buf664, buf665, buf666, buf676, buf679, buf680, buf685, buf686, buf687, buf688, buf689, buf699, buf703, buf704, buf709, buf710, buf711, buf712, buf713, buf723, buf726, buf727, buf732, buf733, buf734, buf735, buf736, buf746, 
buf750, buf751, buf756) = self.partitions[0](partition0_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1772, in run
return compiled_fn(new_inputs) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 404, in deferred_cudagraphify
fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 463, in cudagraphify
return manager.add_function(
^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2316, in add_function
return fn, fn(inputs)
^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2012, in run
out = self._run(new_inputs, function_id)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2116, in _run
return self.run_eager(new_inputs, function_id)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2277, in run_eager
return node.run(new_inputs)
^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 686, in run
out = self.wrapped_function.model(new_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/torchinductor_dev/g3/cg3vz3nesx4qhgj6cfjqltjkco4syzl7zgu2c6xtd6adiom2dfe2.py", line 2113, in partition_0
triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg0_1, arg1_1, arg4_1, buf1, 72, 4096, stream=stream0)
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1272, in run
self.autotune_to_one_config(*args, **kwargs)
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1048, in autotune_to_one_config
timings = self.benchmark_all_configs(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1023, in benchmark_all_configs
launcher: self.bench(launcher, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 891, in bench
return benchmarker.benchmark_gpu(kernel_call, rep=40)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 39, in wrapper
return fn(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 251, in benchmark_gpu
torch.cuda.synchronize()
File "/opt/uv/venv/lib/python3.12/site-packages/torch/cuda/__init__.py", line 1083, in synchronize
return torch._C._cuda_synchronize()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Changes
Implements RSN-Tune, per this paper, as a defense.
Testing
Run `tests/defenses/test_rsn_tune.py attack --tier large` to test this on Llama-3.1-8B-Instruct. The `tier` parameter allows selecting smaller models to see if things don't blow up. The command will:

It turns out this defense is more of an "it helps when you're training on benign data" defense rather than a "this will help against real adversaries" one, though there is a slight improvement: