
Official Codebase for Evaluating Adversarial Robustness of Concept Representations in Sparse Autoencoders (EACL 2026)


SAE Robustness: Adversarial Attacks on Sparse Autoencoders

This codebase implements adversarial attacks for testing the robustness of Sparse Autoencoder (SAE) feature representations of large language models (LLaMA 3-8B, Gemma 2-9B/2B).

Code Structure

sae_robustness/
├── main.py                      # Entry point - dispatches to attack functions
├── src/
│   ├── config.py                # CLI arguments, constants (DEVICE, BASE_DIR, CACHE_DIR)
│   ├── metrics.py               # Overlap metrics (count_common, get_overlap)
│   ├── loader.py                # Model/SAE loading (load_model_and_sae)
│   ├── features.py              # SAE feature extraction (extract_sae_features, jump_relu, compute_signed_preacts)
│   └── attacks/
│       ├── suffix.py            # Suffix-based attacks (append adversarial tokens)
│       └── replacement.py       # Replacement-based attacks (replace existing tokens)
├── sae/                         # SAE package (encoder/decoder implementation)
├── data/                        # CSV datasets (art_science.csv, ag_news.csv, etc.)
└── results/                     # Attack logs (created when --log is used)
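For orientation, here is a minimal sketch of two helpers named in the tree above, assuming a standard JumpReLU activation and a top-k overlap metric; the exact signatures in src/features.py and src/metrics.py may differ:

```python
import numpy as np

def jump_relu(pre_acts: np.ndarray, threshold: np.ndarray) -> np.ndarray:
    """JumpReLU: keep a pre-activation only where it clears its per-latent
    threshold, zero it elsewhere (the activation used by JumpReLU SAEs)."""
    return pre_acts * (pre_acts > threshold)

def get_overlap(z1: np.ndarray, z2: np.ndarray, k: int) -> float:
    """Fraction of the top-k most active SAE latents shared by two
    activation vectors."""
    top1 = set(np.argsort(-z1)[:k].tolist())
    top2 = set(np.argsort(-z2)[:k].tolist())
    return len(top1 & top2) / k
```

The same overlap measure underlies the population-level attack objectives described below.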

Attack Types

Attack Modes

  • Suffix: Append adversarial tokens to the end of the input
  • Replacement: Replace individual tokens in the input

Attack Levels

  • Individual: Target specific SAE neurons (activate or deactivate)
  • Population: Shift the entire SAE feature distribution

Targeting

  • Targeted: Match a specific target text's SAE features
  • Untargeted: Disrupt the original SAE features

Usage

Basic Command Structure

python main.py \
  --mode {suffix,replacement} \
  --level {individual,population} \
  [--targeted] \
  --model_type {llama3-8b,gemma2-9b-131k,gemma2-9b-16k,gemma2-2b-65k,gemma2-2b-16k} \
  --layer_num LAYER \
  --data_file DATASET \
  [--log]

Example Commands

1. Individual Suffix Attack (Untargeted, Deactivate)

Deactivate specific SAE neurons by appending adversarial tokens:

python main.py \
  --mode suffix \
  --level individual \
  --model_type gemma2-9b-131k \
  --layer_num 30 \
  --data_file art_science \
  --sample_idx 20 \
  --num_latents 5 \
  --suffix_len 1 \
  --batch_size 100 \
  --num_iters 10 \
  --log

2. Individual Suffix Attack (Targeted, Activate)

Activate neurons that are active for a target text x2 but not for the original input x1:

python main.py \
  --mode suffix \
  --level individual \
  --targeted \
  --activate \
  --model_type llama3-8b \
  --layer_num 20 \
  --data_file art_science \
  --sample_idx 15 \
  --num_latents 5 \
  --suffix_len 1 \
  --batch_size 100 \
  --num_iters 10 \
  --log

3. Population Suffix Attack (Targeted)

Maximize overlap between x1's SAE features and x2's features by appending tokens:

python main.py \
  --mode suffix \
  --level population \
  --targeted \
  --model_type gemma2-9b-131k \
  --layer_num 30 \
  --data_file art_science \
  --sample_idx 20 \
  --suffix_len 3 \
  --batch_size 200 \
  --num_iters 50 \
  --m 300 \
  --log

4. Population Suffix Attack (Untargeted)

Minimize overlap between perturbed x1 and original x1 features:

python main.py \
  --mode suffix \
  --level population \
  --model_type gemma2-9b-131k \
  --layer_num 30 \
  --data_file art_science \
  --sample_idx 20 \
  --suffix_len 3 \
  --batch_size 200 \
  --num_iters 20 \
  --m 300 \
  --log

5. Individual Replacement Attack (Untargeted, Deactivate)

Deactivate neurons by replacing individual tokens (tests each position):

python main.py \
  --mode replacement \
  --level individual \
  --model_type gemma2-9b-131k \
  --layer_num 30 \
  --data_file art_science \
  --sample_idx 20 \
  --num_latents 5 \
  --batch_size 100 \
  --num_iters 10 \
  --log

6. Population Replacement Attack (Targeted)

Maximize feature overlap by replacing tokens one position at a time:

python main.py \
  --mode replacement \
  --level population \
  --targeted \
  --model_type gemma2-9b-131k \
  --layer_num 30 \
  --data_file art_science \
  --sample_idx 20 \
  --batch_size 100 \
  --num_iters 10 \
  --m 300 \
  --log

7. Random Baseline

Test a random baseline (no gradient guidance):

python main.py \
  --mode replacement \
  --level population \
  --targeted \
  --random \
  --model_type gemma2-9b-131k \
  --layer_num 30 \
  --data_file art_science \
  --sample_idx 20 \
  --batch_size 100 \
  --num_iters 10 \
  --log

Key Parameters

Parameter       Description                          Typical Values
--mode          Attack strategy                      suffix, replacement
--level         Attack granularity                   individual, population
--targeted      Match a target text's features       flag (omit for untargeted)
--activate      Activate (vs. deactivate) neurons    flag (individual only)
--model_type    LLM + SAE architecture               llama3-8b, gemma2-9b-131k, etc.
--layer_num     Hidden layer index                   20 (LLaMA), 30 (Gemma-9B), 16 (Gemma-2B)
--data_file     Dataset name                         art_science, ag_news, sst2, safety
--sample_idx    Row index in the CSV                 0-499 (depends on dataset)
--num_latents   Number of neurons to attack          5 (individual only)
--suffix_len    Number of tokens to append           1 (individual), 3 (population)
--batch_size    Candidates per iteration             100-200
--num_iters     Optimization iterations              10-50
--m             Top-m candidate tokens               200-300
--k             Top-k SAE features                   192 (LLaMA), 170 (Gemma)
--log           Save logs to results/                flag
--random        Random token baseline                flag

Output

  • Individual attacks report success rate: fraction of target neurons successfully activated/deactivated.
  • Population attacks report relative overlap change: (final_overlap - initial_overlap) / initial_overlap.
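The population-level metric can be made concrete in a few lines of plain Python (the overlap numbers here are made up for illustration):

```python
def relative_overlap_change(initial_overlap: float, final_overlap: float) -> float:
    """Relative change in top-k feature overlap, as reported by
    population-level attacks."""
    return (final_overlap - initial_overlap) / initial_overlap

# An untargeted attack that drives overlap from 0.80 down to 0.20
# achieves a relative change of -0.75, i.e. a 75% reduction.
```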

When --log is used, detailed logs are saved to:

results/{model_type}-{data_file}/layer-{layer_num}/{attack_config}.txt

Algorithm

All attacks use Greedy Coordinate Gradient (GCG) optimization:

  1. Compute gradients of the SAE objective w.r.t. input embeddings
  2. Project gradients onto the embedding matrix to find top-m candidate tokens
  3. Sample a batch of candidates and evaluate them
  4. Keep the best candidate and repeat
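The four steps above can be sketched in a self-contained toy example. A linear objective stands in for the real SAE loss so the loop is runnable without an LLM; all names are illustrative, not the repo's API:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for the real model: a small "vocabulary" of embeddings
# and a linear objective whose gradient w.r.t. the embedding is just w.
V, d = 50, 8
E = rng.normal(size=(V, d))        # embedding matrix
w = rng.normal(size=d)             # gradient of the objective at the suffix position

def objective(token_id: int) -> float:
    """Objective value if this token is placed at the suffix position."""
    return float(E[token_id] @ w)

# One GCG step for a single suffix position:
m, batch_size = 10, 4
scores = E @ w                         # 2. project gradient onto embedding matrix
candidates = np.argsort(-scores)[:m]   #    keep the top-m candidate tokens
batch = rng.choice(candidates, size=batch_size, replace=False)  # 3. sample a batch
best = max(batch, key=objective)       # 4. keep the best candidate, then repeat
```

In the real attacks this step is repeated for --num_iters iterations, with the gradient recomputed through the frozen LLM and SAE at each step.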

Objectives:

  • Individual: Maximize (activate) or minimize (deactivate) log_softmax(z)[neuron_id]
  • Population: Maximize (targeted) or minimize (untargeted) cosine similarity between SAE feature vectors
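The two objectives can be written down directly; this is a hedged sketch in NumPy (function names and the activate/targeted flags are illustrative, not the repo's exact interface):

```python
import numpy as np

def individual_objective(z: np.ndarray, neuron_id: int, activate: bool = True) -> float:
    """log_softmax(z)[neuron_id]; maximize to activate the neuron,
    negate (so maximizing deactivates) otherwise."""
    log_probs = z - np.log(np.exp(z - z.max()).sum()) - z.max()  # stable log-softmax
    return float(log_probs[neuron_id]) if activate else float(-log_probs[neuron_id])

def population_objective(z_adv: np.ndarray, z_ref: np.ndarray, targeted: bool = True) -> float:
    """Cosine similarity between SAE feature vectors; maximize for targeted
    attacks, negate (so maximizing disrupts) for untargeted ones."""
    cos = float(z_adv @ z_ref / (np.linalg.norm(z_adv) * np.linalg.norm(z_ref)))
    return cos if targeted else -cos
```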
