4 changes: 4 additions & 0 deletions docs/index.md
@@ -75,6 +75,10 @@ I wrote a tutorial to show users how to do some basic exploration of their SAE:
- Understanding SAE Features with the Logit Lens [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
- Training a Sparse Autoencoder [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)

### Community Tutorials

- Cross-SAE Feature Alignment with FeatureMatch [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/featurematch_cross_sae.ipynb) - Quantify how similar two SAEs' learned dictionaries are using cosine-based alignment ([external package](https://github.com/Course-Correct-Labs/featurematch))

## Example WandB Dashboard

WandB Dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.
145 changes: 145 additions & 0 deletions tutorials/featurematch_cross_sae.ipynb
@@ -0,0 +1,145 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cross-SAE Feature Alignment (FeatureMatch, cosine + top-k)\n",
"\n",
"This example shows how to quantify correspondence between two SAE dictionaries using **FeatureMatch**.\n",
"\n",
"🧪 **v0.1 scope**: cosine similarity matrix, top-k per-feature matches, summary stats, and a simple heatmap.\n",
"\n",
"👉 By default this notebook runs on **synthetic data**, so it works anywhere. To use real models instead, replace the synthetic block with the **SAELens code collection** cell in Option B below.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If running on Colab or a fresh env, uncomment the line below to install from GitHub:\n",
"# !pip install \"git+https://github.com/Course-Correct-Labs/featurematch.git\"\n",
"\n",
"import torch\n",
"from featurematch.featurematch import align_features\n",
"from featurematch.viz import plot_heatmap\n",
"import matplotlib.pyplot as plt\n",
"\n",
"torch.manual_seed(0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Option A: Synthetic demo (default)\n",
"This section creates a permutation-aligned pair of code matrices (`Z_a`, `Z_b_perm`) and a random baseline (`Z_b_rand`).\n",
"\n",
"- Expect **perfect alignment** for the permutation case (mean≈1.0).\n",
"- Expect **low alignment** for random codes (mean≈0.15–0.20)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"N, K = 200, 64\n",
"Z_a = torch.randn(N, K)\n",
"perm = torch.randperm(K)\n",
"P = torch.zeros(K, K)\n",
"P[torch.arange(K), perm] = 1.0\n",
"Z_b_perm = Z_a @ P # permutation case (perfect alignment)\n",
"Z_b_rand = torch.randn(N, K) # random baseline\n",
"\n",
"res_perm = align_features(Z_a, Z_b_perm, topk=5, threshold=0.8, device=\"cpu\")\n",
"print(\"Permutation case stats:\", res_perm.stats)\n",
"plot_heatmap(res_perm.cosine, title=\"FeatureMatch: Cosine (Permutation)\")\n",
"plt.show()\n",
"\n",
"res_rand = align_features(Z_a, Z_b_rand, topk=5, threshold=0.8, device=\"cpu\")\n",
"print(\"Random case stats:\", res_rand.stats)\n",
"plot_heatmap(res_rand.cosine, title=\"FeatureMatch: Cosine (Random)\")\n",
"plt.show()\n",
"\n",
"res_perm.top_matches[:3] # preview first 3 rows"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Option B: Collect codes from real SAEs using SAELens (uncomment to use)\n",
"Use this section when you have two trained SAEs (same hook/layer) and an evaluation dataset. The **only requirement** is that both code matrices are `[N, K]` and derived from the **same** tokens/examples.\n",
"\n",
"❗️Note: keep batch size modest to avoid OOM during code collection. Subsample if needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %% Real-SAEs example (template) -----------------------------------------\n",
"# from sae_lens import SAE\n",
"# import torch\n",
"# from featurematch.featurematch import align_features\n",
"# from featurematch.viz import plot_heatmap\n",
"# import matplotlib.pyplot as plt\n",
"\n",
"# # 1) Load your two SAEs (same hook/layer)\n",
"# sae_a = SAE.load_from_pretrained(\"PATH/OR/ALIAS/TO/SAE_A\")\n",
"# sae_b = SAE.load_from_pretrained(\"PATH/OR/ALIAS/TO/SAE_B\")\n",
"# sae_a.eval(); sae_b.eval()\n",
"\n",
"# # 2) Prepare evaluation activations (same for both SAEs)\n",
"# # SAEs consume base-model activations from the hook point they were\n",
"# # trained on, not raw tokens: run the model over your eval tokens and\n",
"# # grab the activations at that shared hook.\n",
"# acts = ... # FloatTensor [N, d_model] of activations at the shared hook\n",
"\n",
"# # 3) Collect codes Z_a, Z_b (shape [N, K])\n",
"# with torch.no_grad():\n",
"#     Z_a = sae_a.encode(acts) # adjust if your SAE exposes a different API\n",
"#     Z_b = sae_b.encode(acts)\n",
"\n",
"# # 4) Align & visualize\n",
"# res = align_features(Z_a, Z_b, topk=5, threshold=0.8, device=\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"# print(\"Alignment stats:\", res.stats)\n",
"# plot_heatmap(res.cosine, title=\"Cross-SAE Feature Alignment (Cosine)\")\n",
"# plt.show()\n",
"\n",
"# # 5) Inspect top matches for first few features\n",
"# res.top_matches[:5]"
]
},
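{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Under the hood: a from-scratch cosine-alignment sketch\n",
"The cell below is **not** the FeatureMatch implementation: it is a minimal sketch of the underlying idea, so you can sanity-check results without the package. Each feature is a column of the code matrix; features are compared by cosine similarity over the shared samples, `mean_best` averages each feature's best match, and `frac` is the share of features whose best match clears the threshold."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def cosine_alignment(Z_a, Z_b, threshold=0.8, eps=1e-8):\n",
"    # Normalize each feature (column) over the N shared samples\n",
"    A = Z_a / (Z_a.norm(dim=0, keepdim=True) + eps)\n",
"    B = Z_b / (Z_b.norm(dim=0, keepdim=True) + eps)\n",
"    C = A.T @ B  # [K_a, K_b] cosine matrix\n",
"    best = C.max(dim=1).values  # best match in B for each feature of A\n",
"    return C, best.mean().item(), (best >= threshold).float().mean().item()\n",
"\n",
"torch.manual_seed(0)\n",
"Z = torch.randn(200, 64)\n",
"C, mean_best, frac = cosine_alignment(Z, Z[:, torch.randperm(64)])\n",
"print(mean_best, frac)  # permuted copy: both should be ~1.0"
]
},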
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Interpretation (v0.1 heuristics)\n",
"- **mean_best ≥ 0.85**: strong reproducibility (dictionaries mostly aligned)\n",
"- **0.70–0.85**: partial alignment (seeds/hparams differ)\n",
"- **< 0.70**: low alignment (different dictionaries)\n",
"- `% above threshold` (default 0.8): quick sanity metric; >60% typically indicates similar runs\n",
"\n",
"**Important:** Always compare codes from the **same hook/layer** on the **same dataset**."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}