This codebase implements adversarial attacks to test the robustness of Sparse Autoencoder (SAE) feature representations for large language models (LLaMA 3-8B, Gemma 2-9B/2B).
```
sae_robustness/
├── main.py                 # Entry point - dispatches to attack functions
├── src/
│   ├── config.py           # CLI arguments, constants (DEVICE, BASE_DIR, CACHE_DIR)
│   ├── metrics.py          # Overlap metrics (count_common, get_overlap)
│   ├── loader.py           # Model/SAE loading (load_model_and_sae)
│   ├── features.py         # SAE feature extraction (extract_sae_features, jump_relu, compute_signed_preacts)
│   └── attacks/
│       ├── suffix.py       # Suffix-based attacks (append adversarial tokens)
│       └── replacement.py  # Replacement-based attacks (replace existing tokens)
├── sae/                    # SAE package (encoder/decoder implementation)
├── data/                   # CSV datasets (art_science.csv, ag_news.csv, etc.)
└── results/                # Attack logs (created when --log is used)
```
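Gemma Scope SAEs use a JumpReLU activation, which the `jump_relu` helper in `src/features.py` presumably implements: a pre-activation passes through unchanged only where it exceeds a learned per-feature threshold. A minimal sketch (the `theta` threshold argument is an assumption, not necessarily the real signature):

```python
import torch

def jump_relu(preacts: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """JumpReLU: keep a pre-activation only where it exceeds the learned
    per-feature threshold theta; zero it otherwise. Sketch only -- the
    actual signature in src/features.py may differ."""
    return preacts * (preacts > theta)
```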
Attack modes:

- Suffix: Append adversarial tokens to the end of the input
- Replacement: Replace individual tokens in the input

Attack levels:

- Individual: Target specific SAE neurons (activate or deactivate)
- Population: Shift the entire SAE feature distribution

Targeting:

- Targeted: Match a specific target text's SAE features
- Untargeted: Disrupt the original SAE features
```
python main.py \
    --mode {suffix,replacement} \
    --level {individual,population} \
    [--targeted] \
    --model_type {llama3-8b,gemma2-9b-131k,gemma2-9b-16k,gemma2-2b-65k,gemma2-2b-16k} \
    --layer_num LAYER \
    --data_file DATASET \
    [--log]
```

Deactivate specific SAE neurons by appending adversarial tokens:
```
python main.py \
    --mode suffix \
    --level individual \
    --model_type gemma2-9b-131k \
    --layer_num 30 \
    --data_file art_science \
    --sample_idx 20 \
    --num_latents 5 \
    --suffix_len 1 \
    --batch_size 100 \
    --num_iters 10 \
    --log
```

Activate neurons from a target text x2 that are not present in x1:
```
python main.py \
    --mode suffix \
    --level individual \
    --targeted \
    --activate \
    --model_type llama3-8b \
    --layer_num 20 \
    --data_file art_science \
    --sample_idx 15 \
    --num_latents 5 \
    --suffix_len 1 \
    --batch_size 100 \
    --num_iters 10 \
    --log
```

Maximize overlap between x1's SAE features and x2's features by appending tokens:
```
python main.py \
    --mode suffix \
    --level population \
    --targeted \
    --model_type gemma2-9b-131k \
    --layer_num 30 \
    --data_file art_science \
    --sample_idx 20 \
    --suffix_len 3 \
    --batch_size 200 \
    --num_iters 50 \
    --m 300 \
    --log
```

Minimize overlap between perturbed x1 and original x1 features:
```
python main.py \
    --mode suffix \
    --level population \
    --model_type gemma2-9b-131k \
    --layer_num 30 \
    --data_file art_science \
    --sample_idx 20 \
    --suffix_len 3 \
    --batch_size 200 \
    --num_iters 20 \
    --m 300 \
    --log
```

Deactivate neurons by replacing individual tokens (tests each position):
```
python main.py \
    --mode replacement \
    --level individual \
    --model_type gemma2-9b-131k \
    --layer_num 30 \
    --data_file art_science \
    --sample_idx 20 \
    --num_latents 5 \
    --batch_size 100 \
    --num_iters 10 \
    --log
```

Maximize feature overlap by replacing tokens one position at a time:
```
python main.py \
    --mode replacement \
    --level population \
    --targeted \
    --model_type gemma2-9b-131k \
    --layer_num 30 \
    --data_file art_science \
    --sample_idx 20 \
    --batch_size 100 \
    --num_iters 10 \
    --m 300 \
    --log
```

Test a random baseline (no gradient guidance):
```
python main.py \
    --mode replacement \
    --level population \
    --targeted \
    --random \
    --model_type gemma2-9b-131k \
    --layer_num 30 \
    --data_file art_science \
    --sample_idx 20 \
    --batch_size 100 \
    --num_iters 10 \
    --log
```

| Parameter | Description | Typical Values |
|---|---|---|
| `--mode` | Attack strategy | `suffix`, `replacement` |
| `--level` | Granularity | `individual`, `population` |
| `--targeted` | Match target features? | flag (omit for untargeted) |
| `--activate` | Activate (vs. deactivate) neurons? | flag (individual only) |
| `--model_type` | LLM + SAE architecture | `llama3-8b`, `gemma2-9b-131k`, etc. |
| `--layer_num` | Hidden layer index | 20 (LLaMA), 30 (Gemma-9B), 16 (Gemma-2B) |
| `--data_file` | Dataset name | `art_science`, `ag_news`, `sst2`, `safety` |
| `--sample_idx` | Row index in CSV | 0-499 (depends on dataset) |
| `--num_latents` | Number of neurons to attack | 5 (individual only) |
| `--suffix_len` | Number of tokens to append | 1 (individual), 3 (population) |
| `--batch_size` | Candidates per iteration | 100-200 |
| `--num_iters` | Optimization iterations | 10-50 |
| `--m` | Top-m candidate tokens | 200-300 |
| `--k` | Top-k SAE features | 192 (LLaMA), 170 (Gemma) |
| `--log` | Save logs to `results/`? | flag |
| `--random` | Random baseline? | flag |
- Individual attacks report a success rate: the fraction of target neurons successfully activated or deactivated.
- Population attacks report the relative overlap change: `(final_overlap - initial_overlap) / initial_overlap`.
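The overlap computation can be sketched as follows; the first helper's name echoes `count_common` in `src/metrics.py`, but the signatures here are assumptions for illustration:

```python
import torch

def count_common(acts_a: torch.Tensor, acts_b: torch.Tensor, k: int) -> int:
    """Count SAE features appearing in the top-k of both activation vectors
    (sketch; the real count_common in src/metrics.py may differ)."""
    top_a = set(torch.topk(acts_a, k).indices.tolist())
    top_b = set(torch.topk(acts_b, k).indices.tolist())
    return len(top_a & top_b)

def relative_overlap_change(initial: float, final: float) -> float:
    """The metric reported by population attacks."""
    return (final - initial) / initial
```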
When `--log` is used, detailed logs are saved to `results/{model_type}-{data_file}/layer-{layer_num}/{attack_config}.txt`.
All attacks use Greedy Coordinate Gradient (GCG) optimization:
- Compute gradients of the SAE objective w.r.t. input embeddings
- Project gradients onto the embedding matrix to find top-m candidate tokens
- Sample a batch of candidates and evaluate them
- Keep the best candidate and repeat
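The loop above can be sketched in PyTorch. Everything here, including the shapes and the `loss_fn(embeddings) -> scalar` interface, is an assumption for illustration, not the repository's actual API:

```python
import torch

def gcg_step(suffix_ids, embed_matrix, loss_fn, m=300, batch_size=100):
    """One Greedy Coordinate Gradient iteration over adversarial suffix tokens.
    Assumes loss_fn maps suffix embeddings (suffix_len, d) to a scalar loss
    to be MINIMIZED. Sketch only; the real attack code differs in details."""
    vocab_size = embed_matrix.shape[0]
    # 1. Gradient of the objective w.r.t. a one-hot relaxation of the tokens
    one_hot = torch.nn.functional.one_hot(suffix_ids, vocab_size).float()
    one_hot.requires_grad_(True)
    loss_fn(one_hot @ embed_matrix).backward()
    # 2. Project onto the embedding matrix: top-m candidates per position
    top_m = (-one_hot.grad).topk(m, dim=-1).indices  # (suffix_len, m)
    # 3. Sample a batch of single-token substitutions and evaluate them
    best_ids = suffix_ids
    best_loss = loss_fn(embed_matrix[suffix_ids])
    for _ in range(batch_size):
        cand = suffix_ids.clone()
        pos = torch.randint(len(suffix_ids), (1,)).item()
        cand[pos] = top_m[pos, torch.randint(m, (1,)).item()]
        cand_loss = loss_fn(embed_matrix[cand])
        # 4. Keep the best candidate seen so far
        if cand_loss < best_loss:
            best_ids, best_loss = cand, cand_loss
    return best_ids
```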
Objectives:
- Individual: Maximize (activate) or minimize (deactivate) `log_softmax(z)[neuron_id]`
- Population: Maximize (targeted) or minimize (untargeted) cosine similarity between SAE feature vectors
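The two objectives can be written down concretely; here `z` stands for the SAE latent vector of the attacked input, and the exact reduction over token positions is an assumption:

```python
import torch
import torch.nn.functional as F

def individual_objective(z: torch.Tensor, neuron_id: int, activate: bool) -> torch.Tensor:
    """Score of one SAE latent under log-softmax; GCG maximizes this to
    activate the neuron, or its negation to deactivate it."""
    score = F.log_softmax(z, dim=-1)[neuron_id]
    return score if activate else -score

def population_objective(z_adv: torch.Tensor, z_ref: torch.Tensor, targeted: bool) -> torch.Tensor:
    """Cosine similarity between SAE feature vectors; maximized when the
    attack is targeted (match z_ref), minimized when untargeted (disrupt)."""
    sim = F.cosine_similarity(z_adv, z_ref, dim=-1)
    return sim if targeted else -sim
```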