defense: RSN-Tune #47

Open
mruwnik wants to merge 13 commits into main from rsn-tune

Conversation


@mruwnik mruwnik commented Dec 11, 2025

Changes

Implements RSN-Tune, per this paper, as a defense.

Testing

Run `tests/defenses/test_rsn_tune.py attack --tier large` to test this on Llama-3.1-8B-Instruct. The `tier` parameter allows selecting smaller models to check that things don't blow up. The command will:

  1. RSN-Tune the selected model (this takes a while, around an hour for Llama with the default settings; lower the detection and training sample counts to make it go faster)
  2. Run an attack (a full-parameter finetune) on both the base model and the hardened model
  3. Evaluate both with StrongREJECT and print the difference
  4. Run a couple of test prompts (e.g. "how do I make a bomb") on the base model, the attacked base model, and the hardened model after the attack, to see whether their refusals differ. This is checked with quick'n'dirty regexes, so it might not be fully representative; the idea is just to have a basic sanity test at the end

It turns out this defense is more of an "it helps when you're training on benign data" defense than a "this will help against real adversaries" one, though there is a slight improvement:


StrongReject scores (higher = more harmful):
  Baseline → Attack:  0.5455
  Hardened → Attack:  0.4947
  Difference:         0.0508

✓ RSN-Tune reduced harmfulness by 9.3%
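For reference, the 9.3% figure is the relative reduction implied by the two scores above:

```python
# StrongReject scores from the run above (higher = more harmful).
baseline_attacked = 0.5455   # base model after the finetuning attack
hardened_attacked = 0.4947   # RSN-Tuned model after the same attack

absolute_difference = baseline_attacked - hardened_attacked
relative_reduction = absolute_difference / baseline_attacked

print(f"RSN-Tune reduced harmfulness by {relative_reduction:.1%}")
# prints: RSN-Tune reduced harmfulness by 9.3%
```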

@@ -0,0 +1,626 @@
"""Evaluation for RSN-Tune defense.
Collaborator Author

this file is mainly Claude's, with me reading through and suggesting some changes

# =============================================================================


def run_threshold_sweep(
Collaborator Author

This can be used to work out which thresholds to use. The default is 1, but that doesn't always give the best result, or any result at all on smaller models.

Collaborator

This seems like quite a useful function; can we have it in a separate script? In scripts/defenses/rsn_tune.py perhaps?
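For what a standalone sweep script might do, here is a minimal sketch of the sweep-and-pick-best pattern. The real `run_threshold_sweep` signature isn't shown in this thread, so the parameters here, including the `evaluate` scoring hook, are assumptions:

```python
def run_threshold_sweep(thresholds, evaluate):
    """Score each candidate threshold and return all results plus the best.

    `evaluate` is a caller-supplied hook mapping a threshold to a score
    (e.g. a StrongReject score after hardening at that threshold and then
    attacking); lower is taken to be better here.
    """
    results = [(threshold, evaluate(threshold)) for threshold in thresholds]
    best = min(results, key=lambda pair: pair[1])
    return results, best


# Toy usage: pretend the sweet spot is at threshold 1.5.
sweep, best = run_threshold_sweep([0.5, 1.0, 1.5, 2.0], lambda t: abs(t - 1.5))
```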

@sdhossain changed the title from "Rsn tune" to "defense: RSN-Tune" on Dec 11, 2025
@sdhossain added the "defense" (Adds or modifies defenses) label on Dec 11, 2025
@tomtseng (Collaborator) left a comment

I didn't review the implementation (and I probably won't look carefully at the implementation of any of the defenses, since they are so long), but if we can roughly replicate some result of the paper (e.g., the "Llama2-7B-Chat" column of Table 4) then I'm comfortable merging.
For TAR PR #39 we copied over code from another implementation rather than doing our own, so that's an option to consider in future PRs. With that being said, the RSN-Tune codebase looks a bit confusing, so I'm happy to have a cleaned-up implementation as long as we've done enough testing.

@sdhossain I'd recommend you review the rough structure (and similarly not spend time on the implementation) and give the yea/nay on approving the PR — will this have any conflicts or mismatches with the TAR PR #39?

- max_length: 512

Implementation Differences from Paper:
1. **Optimizer** (PRACTICAL): Uses Adafactor instead of unspecified optimizer in paper.
Collaborator

Could check the original codebase https://github.com/zhaoyiran924/Safety-Neuron/ to see what optimizer and batch size they use.

@sdhossain (Collaborator) left a comment

Nice stuff. I concur with @tomtseng: if we can reproduce a result from the paper, we could reliably say we have reproduced their defense, and thus we are good to merge.

I did have some nit comments, which I'm not fully sure should be backlog items (they primarily revolve around re-organizing some code / breaking things down into scripts). My guess would be that, as long as things work, we are probably good to merge.

Another note: could we run `uv run ruff check <path/to/changed_files> --fix` (if not done already)?

torch.cuda.empty_cache()


def generate_responses(
Collaborator

we could move a function like this to our utils submodule in safetunebed.whitebox.

Could potentially be useful (or something we standardize later).

return responses


def count_refusals(responses: list[str]) -> int:
Collaborator

nit: I feel like constants like REFUSAL_INDICATORS could live in some file in src, as they could have utility beyond just these tests (not super important though).
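For readers of this thread, the quick'n'dirty regex refusal check presumably looks something like the sketch below. The indicator phrases here are illustrative stand-ins, not the actual REFUSAL_INDICATORS list from the test file:

```python
import re

# Illustrative refusal phrases; the real REFUSAL_INDICATORS list may differ.
REFUSAL_INDICATORS = [
    r"i can(?:'|no)t help",
    r"i won't",
    r"as an ai",
    r"i'm sorry",
]

_REFUSAL_RE = re.compile("|".join(REFUSAL_INDICATORS), re.IGNORECASE)


def count_refusals(responses: list[str]) -> int:
    """Count responses matching any refusal indicator (case-insensitive)."""
    return sum(1 for response in responses if _REFUSAL_RE.search(response))
```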



@dataclass
class AttackConfig:
Collaborator

do we not have existing configs for attacks?

epochs: int = 1


def apply_rsn_tune(
Collaborator

nit: I think we can just use the following without wrapping it into a function:

    rsn_config = RSNTuneConfig(
        input_checkpoint_path=Path(model_name),
        output_checkpoint_path=output_path,
        num_detection_samples=num_detection_samples,
        num_training_samples=num_training_samples,
        safety_importance_threshold=safety_threshold,
        foundation_importance_threshold=foundation_threshold,
        use_robust_mode=True,
    )
    rsn = RSNTune(defence_config=rsn_config)
    safety_neurons, foundation_neurons = rsn.tune_safety_neurons()

I would personally prefer to have any applying of that wrapped inside a class method, so that any necessary utility lives inside those class methods.

Collaborator

I believe the rsn.run_defense() function should also take care of this (it runs the defense if the checkpoint has not been created, and if the checkpoint exists, it just returns the path).
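The checkpoint-reuse pattern described in this comment can be sketched as follows. The class and attribute names below are stand-ins for illustration, not the real RSNTune API:

```python
import tempfile
from pathlib import Path


class RSNTuneSketch:
    """Minimal sketch of the checkpoint-reuse pattern; not the real RSNTune."""

    def __init__(self, output_checkpoint_path: Path):
        self.output_checkpoint_path = output_checkpoint_path
        self.tune_calls = 0  # instrumented so the skip behavior is visible

    def tune_safety_neurons(self) -> None:
        # Stand-in for the expensive detection + masked-finetuning step.
        self.tune_calls += 1
        self.output_checkpoint_path.mkdir(parents=True, exist_ok=True)

    def run_defense(self) -> Path:
        # Run the defense only if the checkpoint doesn't exist yet;
        # otherwise just return the existing path.
        if not self.output_checkpoint_path.exists():
            self.tune_safety_neurons()
        return self.output_checkpoint_path


with tempfile.TemporaryDirectory() as workdir:
    rsn = RSNTuneSketch(Path(workdir) / "hardened")
    first = rsn.run_defense()   # tunes and creates the checkpoint
    second = rsn.run_defense()  # checkpoint exists, so tuning is skipped
```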

@@ -0,0 +1,625 @@
"""Evaluation for RSN-Tune defense.

This test suite provides two main functionalities:
Collaborator

I think a lot of these functionalities could be useful to have as scripts. Maybe we can house them in scripts/whitebox/defenses/rsn_tune.py.

I think a script for the threshold sweep and one for the attack comparison make sense.

My main rationale is that it is nicer for a test file to be a simple sanity check: does it run without crashing, and does it avoid failing on an obvious task it should succeed at? For example, we can compare the StrongREJECT scores before and after alignment, or after a simple FinetuningAttack (although given their sensitivity to hyper-parameters, I'm not sure we need that).

For functions that provide utilities that could be shared, it makes sense (imo) to have those in src and not in a scripts file.

For example the TAR defense does the test this way: https://github.com/sdhossain/SafeTuneBed/blob/ec2bd9622f12de88ee6e08ff2c03f4859dad503d/tests/defenses/test_tar.py

(just checks if it ran)

For additional experimentation would suggest having them as individual scripts, with more isolated functionality for better usability.

Would defer to @tomtseng on whether breaking these down into individual scripts is better (the inputs can just be a checkpoint). I imagine, it can also be done in a follow-up PR and backlogged as other defenses might be better.


@dataclass
class RSNTuneConfig(AlignmentDefenseConfig):
"""Configuration for RSN-Tune defense."""
Collaborator

can we have the attributes documented in the docstring (what they mean)? We have comments for some of them already.

return cls(**data)


class _GradientMaskCallback(TrainerCallback):
Collaborator

minor nit: I think this callback and the function below could live in another file?

@sdhossain self-requested a review on December 15, 2025 at 12:07
@sdhossain (Collaborator) left a comment

I am good to have this merged, provided a quick sanity check that it aligns with the paper results as mentioned.

If we could have a checkpoint uploaded to HuggingFace for Llama3-8B-Instruct or Llama-3-8B-Base, that could be quite useful too.


mruwnik commented Dec 22, 2025

I've done a long sanity check and can't seem to reproduce the paper's results. It consistently gives a slight advantage over the base model, but it's more like 2-5%, not the 30% reported by the paper. I can keep trying things, but I don't know how much it's worth the bother.


tomtseng commented Jan 14, 2026

Yeah, it sounds like you've done your due diligence but can't figure out why there's a discrepancy, so it's probably not worth staring at this code much more. Thanks for spending the time investigating this thorny issue.

At the same time, for the purposes of this benchmark, I don't think it's appropriate to merge a defense implementation that might not actually be working. The last thing I would suggest checking is trying to run the original codebase https://github.com/zhaoyiran924/Safety-Neuron and seeing whether we can reproduce the results there.

  • If we also can't reproduce results there, then we can merge and say "hey, this defense just doesn't seem to work, even with the original codebase; it seems flawed or sensitive".
  • If we can reproduce results there, then we can try to replace this custom implementation with a copy of their codebase, or try to figure out where the discrepancy is (ask Claude Code to identify the major differences?). Is it a hyperparameter difference, an algorithm implementation bug, or a difference in how we evaluate it? Or maybe this defense is super sensitive to seemingly irrelevant implementation details, in a way that should make us skeptical of it.

"""Detect important neurons using the configured detection strategy."""
path = self.defence_config.input_checkpoint_path
model = load_model(path)
tokenizer = load_tokenizer(path)
Collaborator

Suggested change (replace the `load_tokenizer` line):

    tokenizer = load_tokenizer(path)
    model.resize_token_embeddings(new_num_tokens=len(tokenizer))

I had to add this change; without it I was getting this error:

/tmp/torchinductor_dev/gj/cgjpugzaj3wojmlv7cpshoif6enk2fvvclqa2wir6g2pwcoxt4md.py:41: unknown: block: [43,0,0], thread: [473,0,0] Assertion `index out of bounds: 0 <= tmp4 < 128256` failed.
    (same assertion repeated for threads [474,0,0] through [479,0,0])
Traceback (most recent call last):
  File "/home/dev/SafeTuneBed/scripts/rsn_tune/harden.py", line 67, in <module>
    main()
  File "/home/dev/SafeTuneBed/scripts/rsn_tune/harden.py", line 62, in main
    output = rsn.run_defense()
             ^^^^^^^^^^^^^^^^^
  File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/rsn_tune.py", line 276, in run_defense
    self.tune_safety_neurons()
  File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/rsn_tune.py", line 236, in tune_safety_neurons
    safety_neurons = self._detect_safety_neurons()
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/rsn_tune.py", line 285, in _detect_safety_neurons
    return self._detect_neurons(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/rsn_tune.py", line 311, in _detect_neurons
    neurons = detect(
              ^^^^^^^
  File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/detection.py", line 477, in detect
    neuron_importance = detect_raw(model, tokenizer, dataset, is_harmful, chunk_size)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dev/SafeTuneBed/src/tamperbench/whitebox/defenses/rsn_tune/detection.py", line 448, in detect_raw
    model(**inputs)
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 414, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 832, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/transformers/utils/generic.py", line 912, in wrapper
    @wraps(func)
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1044, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1130, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 353, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 129, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 724, in inner_fn
    outs = compiled_fn(args)
           ^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 526, in wrapper
    return compiled_fn(runtime_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/output_code.py", line 613, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_dev/g3/cg3vz3nesx4qhgj6cfjqltjkco4syzl7zgu2c6xtd6adiom2dfe2.py", line 4443, in call
    (buf2, buf5, buf6, ..., buf750, buf751, buf756) = self.partitions[0](partition0_args)
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1772, in run
    return compiled_fn(new_inputs)  # type: ignore[arg-type]
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 404, in deferred_cudagraphify
    fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 463, in cudagraphify
    return manager.add_function(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2316, in add_function
    return fn, fn(inputs)
               ^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2012, in run
    out = self._run(new_inputs, function_id)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2116, in _run
    return self.run_eager(new_inputs, function_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 2277, in run_eager
    return node.run(new_inputs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/cudagraph_trees.py", line 686, in run
    out = self.wrapped_function.model(new_inputs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_dev/g3/cg3vz3nesx4qhgj6cfjqltjkco4syzl7zgu2c6xtd6adiom2dfe2.py", line 2113, in partition_0
    triton_red_fused__to_copy_add_embedding_mean_mul_pow_rsqrt_0.run(arg0_1, arg1_1, arg4_1, buf1, 72, 4096, stream=stream0)
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1272, in run
    self.autotune_to_one_config(*args, **kwargs)
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1048, in autotune_to_one_config
    timings = self.benchmark_all_configs(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1023, in benchmark_all_configs
    launcher: self.bench(launcher, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 891, in bench
    return benchmarker.benchmark_gpu(kernel_call, rep=40)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 39, in wrapper
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/_inductor/runtime/benchmarking.py", line 251, in benchmark_gpu
    torch.cuda.synchronize()
  File "/opt/uv/venv/lib/python3.12/site-packages/torch/cuda/__init__.py", line 1083, in synchronize
    return torch._C._cuda_synchronize()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
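For context, the assertion `0 <= tmp4 < 128256` is an embedding-table bounds check: the tokenizer was emitting token ids past the model's 128256-row embedding matrix, which the suggested `resize_token_embeddings` call fixes by growing the table. A plain-Python rendering of the failure mode (the function and table here are illustrative, not the real lookup):

```python
def embedding_lookup(table: list, token_id: int):
    """Look up a token's embedding row; fails like the CUDA-side assert
    when a token id exceeds the table size."""
    if not 0 <= token_id < len(table):
        raise IndexError(f"index out of bounds: 0 <= {token_id} < {len(table)}")
    return table[token_id]


table = [[0.0]] * 128256         # rows in the original embedding matrix
embedding_lookup(table, 128255)  # last valid id: fine
# embedding_lookup(table, 128260) would raise IndexError, matching the assert;
# model.resize_token_embeddings(len(tokenizer)) grows the table so it succeeds.
```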
