diff --git a/.vscode/launch.json b/.vscode/launch.json
index 5ce7aef..a753e33 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -37,5 +37,29 @@
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
},
+ {
+ "name": "lm",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "${workspaceFolder}/spd/experiments/lm/lm_decomposition.py",
+ "args": "${workspaceFolder}/spd/experiments/lm/lm_config.yaml",
+ "console": "integratedTerminal",
+ "justMyCode": true,
+ "env": {
+ "PYDEVD_DISABLE_FILE_VALIDATION": "1"
+ }
+ },
+ {
+ "name": "lm streamlit",
+ "type": "debugpy",
+ "request": "launch",
+ "module": "streamlit",
+ "args": [
+ "run",
+ "${workspaceFolder}/spd/experiments/lm/app.py",
+ "--server.port",
+ "2000"
+ ]
+ }
]
}
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 49e6f8a..4031cfc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,10 @@ dependencies = [
"python-dotenv",
"wandb<=0.17.7", # due to https://github.com/wandb/wandb/issues/8248
"sympy",
+ "streamlit",
+ "streamlit-antd-components",
+ "datasets",
+ "simple-stories-train"
]
[project.optional-dependencies]
diff --git a/spd/configs.py b/spd/configs.py
index 1622900..d039054 100644
--- a/spd/configs.py
+++ b/spd/configs.py
@@ -36,6 +36,20 @@ class ResidualMLPTaskConfig(BaseModel):
pretrained_model_path: ModelPath # e.g. wandb:spd-resid-mlp/runs/j9kmavzi
+class LMTaskConfig(BaseModel):
+ model_config = ConfigDict(extra="forbid", frozen=True)
+ task_name: Literal["lm"] = "lm"
+ model_size: str # e.g. "1.25M"
+ max_seq_len: PositiveInt = 512
+ buffer_size: PositiveInt = 1000
+ dataset_name: str = "lennart-finke/SimpleStories"
+ train_data_split: str = "train"
+ eval_data_split: str = "test"
+ n_eval_steps: PositiveInt = 100
+ # List of fnmatch patterns for nn.Linear modules to decompose
+ target_module_patterns: list[str] = ["transformer.h.*.mlp.*_proj"]
+
+
class Config(BaseModel):
model_config = ConfigDict(extra="forbid", frozen=True)
wandb_project: str | None = None
@@ -68,7 +82,9 @@ class Config(BaseModel):
unit_norm_matrices: bool = False
attribution_type: Literal["gradient"] = "gradient"
n_gate_hidden_neurons: PositiveInt | None = None
- task_config: TMSTaskConfig | ResidualMLPTaskConfig = Field(..., discriminator="task_name")
+ task_config: TMSTaskConfig | ResidualMLPTaskConfig | LMTaskConfig = Field(
+ ..., discriminator="task_name"
+ )
DEPRECATED_CONFIG_KEYS: ClassVar[list[str]] = []
RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = {}
diff --git a/spd/experiments/lm/app.py b/spd/experiments/lm/app.py
new file mode 100644
index 0000000..d880198
--- /dev/null
+++ b/spd/experiments/lm/app.py
@@ -0,0 +1,405 @@
+"""
+To run this app, run the following command:
+
+```bash
+ streamlit run spd/experiments/lm/app.py -- --model_path "wandb:spd-lm/runs/151bsctx"
+```
+"""
+
+import argparse
+import html
+from collections.abc import Callable, Iterator
+from dataclasses import dataclass
+from typing import Any
+
+import streamlit as st
+import torch
+from datasets import load_dataset
+from jaxtyping import Float, Int
+from simple_stories_train.dataloaders import DatasetConfig
+from torch import Tensor
+from transformers import AutoTokenizer
+
+from spd.configs import Config, LMTaskConfig
+from spd.experiments.lm.models import LinearComponentWithBias, SSModel
+from spd.log import logger
+from spd.models.components import Gate, GateMLP
+from spd.run_spd import calc_component_acts, calc_masks
+from spd.types import ModelPath
+
+DEFAULT_MODEL_PATH: ModelPath = "wandb:spd-lm/runs/151bsctx"
+
+
+# -----------------------------------------------------------
+# Dataclass holding everything the app needs
+# -----------------------------------------------------------
+@dataclass(frozen=True)
+class AppData:
+ model: SSModel
+ tokenizer: AutoTokenizer
+ config: Config
+ dataloader_iter_fn: Callable[[], Iterator[dict[str, Any]]]
+ gates: dict[str, Gate | GateMLP]
+ components: dict[str, LinearComponentWithBias]
+ target_layer_names: list[str]
+ device: str
+
+
+# --- Initialization and Data Loading ---
+@st.cache_resource(show_spinner="Loading model and data...")
+def initialize(model_path: ModelPath) -> AppData:
+ """
+ Loads the model, tokenizer, config, and evaluation dataloader.
+ Cached by Streamlit based on the model_path.
+ """
+ device = "cpu" # Use CPU for the Streamlit app
+ logger.info(f"Initializing app with model: {model_path} on device: {device}")
+ ss_model, config, _ = SSModel.from_pretrained(model_path)
+ ss_model.to(device)
+ ss_model.eval()
+
+ task_config = config.task_config
+ assert isinstance(task_config, LMTaskConfig), "Task config must be LMTaskConfig for this app."
+
+ # Derive tokenizer path (adjust if stored differently)
+ tokenizer_path = f"chandan-sreedhara/SimpleStories-{task_config.model_size}"
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_path,
+ add_bos_token=False,
+ unk_token="[UNK]",
+ eos_token="[EOS]",
+ bos_token=None,
+ )
+
+ # Create eval dataloader config
+ eval_data_config = DatasetConfig(
+ name=task_config.dataset_name,
+ tokenizer_file_path=None,
+ hf_tokenizer_path=tokenizer_path,
+ split=task_config.eval_data_split,
+ n_ctx=task_config.max_seq_len,
+ is_tokenized=False,
+ streaming=False, # Non-streaming might be simpler for iterator reset
+ column_name="story",
+ )
+
+ # Create the dataloader iterator
+ def create_dataloader_iter() -> Iterator[dict[str, Any]]:
+ """
+ Returns a *new* iterator each time it is called.
+ Each element is a dict with:
+ - "text": the raw document text
+ - "input_ids": Int[Tensor, "1 seq_len"]
+ - "offset_mapping": list[tuple[int, int]]
+ """
+ logger.info("Creating new dataloader iterator.")
+
+ # Stream the HF dataset split
+ dataset = load_dataset(
+ eval_data_config.name,
+ streaming=eval_data_config.streaming,
+ split=eval_data_config.split,
+ trust_remote_code=False,
+ )
+
+ dataset = dataset.with_format("torch")
+
+ text_column = eval_data_config.column_name
+
+ def tokenize_and_prepare(example: dict[str, Any]) -> dict[str, Any]:
+ original_text: str = example[text_column]
+
+ tokenized = tokenizer(
+ original_text,
+ return_tensors="pt",
+ return_offsets_mapping=True,
+ truncation=True,
+ max_length=task_config.max_seq_len,
+ padding=False,
+ )
+
+ input_ids: Int[Tensor, "1 seq_len"] = tokenized["input_ids"]
+ if input_ids.dim() == 1: # Ensure 2‑D [1, seq_len]
+ input_ids = input_ids.unsqueeze(0)
+
+ # HF returns offset_mapping as a list per sequence; batch size is 1
+ offset_mapping: list[tuple[int, int]] = tokenized["offset_mapping"][0].tolist()
+
+ return {
+ "text": original_text,
+ "input_ids": input_ids,
+ "offset_mapping": offset_mapping,
+ }
+
+ # Map over the streaming dataset and return an iterator
+ return map(tokenize_and_prepare, iter(dataset))
+
+ # Extract components and gates
+ gates: dict[str, Gate | GateMLP] = {
+ k.removeprefix("gates.").replace("-", "."): v for k, v in ss_model.gates.items()
+ } # type: ignore[reportAssignmentType]
+ components: dict[str, LinearComponentWithBias] = {
+ k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items()
+ } # type: ignore[reportAssignmentType]
+ target_layer_names = sorted(list(components.keys()))
+
+ logger.info(f"Initialization complete for {model_path}.")
+ return AppData(
+ model=ss_model,
+ tokenizer=tokenizer,
+ config=config,
+ dataloader_iter_fn=create_dataloader_iter,
+ gates=gates,
+ components=components,
+ target_layer_names=target_layer_names,
+ device=device,
+ )
+
+
+# -----------------------------------------------------------
+# Utility: render the prompt with faint token outlines
+# -----------------------------------------------------------
+def render_prompt_with_tokens(
+ *,
+ raw_text: str,
+ offset_mapping: list[tuple[int, int]],
+ selected_idx: int | None,
+) -> None:
+ """
+ Renders `raw_text` inside Streamlit, wrapping each token span with a thin
+ border. The currently‑selected token receives a thicker red border.
+ All other tokens get a thin mid‑grey border (no background fill).
+ """
+ html_chunks: list[str] = []
+ cursor = 0
+
+ def esc(s: str) -> str:
+ return html.escape(s)
+
+ for idx, (start, end) in enumerate(offset_mapping):
+ if cursor < start:
+ html_chunks.append(esc(raw_text[cursor:start]))
+
+ token_substr = esc(raw_text[start:end])
+ if token_substr:
+ is_selected = idx == selected_idx
+ border_style = (
+ "2px solid rgb(200,0,0)" if is_selected else "0.5px solid #aaa" # all other tokens
+ )
+ html_chunks.append(
+ "'
+ f"{token_substr}"
+ )
+ cursor = end
+
+ if cursor < len(raw_text):
+ html_chunks.append(esc(raw_text[cursor:]))
+
+ st.markdown(
+ f'
{"".join(html_chunks)}
',
+ unsafe_allow_html=True,
+ )
+
+
+def load_next_prompt() -> None:
+ """Loads the next prompt, calculates masks, and prepares token data."""
+ logger.info("Loading next prompt.")
+ app_data: AppData = st.session_state.app_data
+ dataloader_iter = st.session_state.dataloader_iter # Get current iterator
+
+ try:
+ batch = next(dataloader_iter)
+ input_ids: Int[Tensor, "1 seq_len"] = batch["input_ids"].to(app_data.device)
+ except StopIteration:
+ logger.warning("Dataloader iterator exhausted. Throwing error.")
+ st.error("Failed to get data even after resetting dataloader.")
+ return
+
+ st.session_state.current_input_ids = input_ids
+
+ # Store the original raw prompt text
+ st.session_state.current_prompt_text = batch["text"]
+
+ # Calculate activations and masks
+ with torch.no_grad():
+ (_, _), pre_weight_acts = app_data.model.forward_with_pre_forward_cache_hooks(
+ input_ids, module_names=list(app_data.components.keys())
+ )
+ As = {module_name: v.linear_component.A for module_name, v in app_data.components.items()}
+ target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore[reportArgumentType]
+ masks, _ = calc_masks(
+ gates=app_data.gates,
+ target_component_acts=target_component_acts,
+ attributions=None,
+ detach_inputs=True, # No gradients needed
+ )
+ st.session_state.current_masks = masks # Dict[str, Float[Tensor, "1 seq_len m"]]
+
+ # Prepare token data for display
+ token_data = []
+ tokenizer = app_data.tokenizer
+ for i, token_id in enumerate(input_ids[0]):
+ # Decode individual token - might differ slightly from full decode for spaces etc.
+ decoded_token_str = tokenizer.decode([token_id]) # type: ignore[reportAttributeAccessIssue]
+ token_data.append(
+ {
+ "id": token_id.item(),
+ "text": decoded_token_str,
+ "index": i,
+ "offset": batch["offset_mapping"][i], # (start, end)
+ }
+ )
+ st.session_state.token_data = token_data
+
+ # Reset selections
+ st.session_state.selected_token_index = 0 # default: first token
+ st.session_state.selected_layer_name = None
+ logger.info("Finished loading next prompt and calculating masks.")
+
+
+# --- Main App UI ---
+def run_app(args: argparse.Namespace) -> None:
+ """Sets up and runs the Streamlit application."""
+ st.set_page_config(layout="wide")
+ st.title("LM Component Activation Explorer")
+
+ # Initialize model, data, etc. (cached)
+ st.session_state.app_data = initialize(args.model_path)
+ app_data: AppData = st.session_state.app_data
+ st.caption(f"Model: {args.model_path}")
+
+ # Initialize session state variables if they don't exist
+ if "current_prompt_text" not in st.session_state:
+ st.session_state.current_prompt_text = None
+ if "token_data" not in st.session_state:
+ st.session_state.token_data = None
+ if "current_masks" not in st.session_state:
+ st.session_state.current_masks = None
+ if "selected_token_index" not in st.session_state:
+ st.session_state.selected_token_index = None
+ if "selected_layer_name" not in st.session_state:
+ if app_data.target_layer_names:
+ st.session_state.selected_layer_name = app_data.target_layer_names[0]
+ else:
+ st.session_state.selected_layer_name = None
+ # Initialize the dataloader iterator in session state
+ if "dataloader_iter" not in st.session_state:
+ st.session_state.dataloader_iter = app_data.dataloader_iter_fn()
+
+ if st.session_state.current_prompt_text is None:
+ load_next_prompt()
+
+ # Sidebar container and a single expander for all interactive controls
+ sidebar = st.sidebar
+ controls_expander = sidebar.expander("Controls", expanded=True)
+
+ # ------------------------------------------------------------------
+ # Sidebar – interactive controls
+ # ------------------------------------------------------------------
+ with controls_expander:
+ st.button("Load Next Prompt", on_click=load_next_prompt)
+
+ # Render the raw prompt with faint token borders
+ if st.session_state.token_data and st.session_state.current_prompt_text:
+ # st.subheader("Prompt")
+ render_prompt_with_tokens(
+ raw_text=st.session_state.current_prompt_text,
+ offset_mapping=[t["offset"] for t in st.session_state.token_data],
+ selected_idx=st.session_state.selected_token_index,
+ )
+
+ # Sidebar slider for token selection
+ n_tokens = len(st.session_state.token_data)
+ if n_tokens > 0:
+ with controls_expander:
+ st.header("Token selector")
+ idx = st.slider(
+ "Token index",
+ min_value=0,
+ max_value=n_tokens - 1,
+ step=1,
+ key="selected_token_index",
+ )
+
+ selected_token = st.session_state.token_data[idx]
+ st.write(f"Selected token: {selected_token['text']} (ID: {selected_token['id']})")
+
+ st.divider()
+
+ # --- Token Information Area ---
+ if st.session_state.token_data:
+ idx = st.session_state.selected_token_index
+ # Ensure token_data is loaded before accessing
+ if (
+ st.session_state.token_data
+ and idx is not None
+ and idx < len(st.session_state.token_data)
+ ):
+ # Layer Selection Dropdown
+ # Always default to the first layer if nothing is selected yet
+ if st.session_state.selected_layer_name is None and app_data.target_layer_names:
+ st.session_state.selected_layer_name = app_data.target_layer_names[0]
+
+ with controls_expander:
+ st.header("Layer selector")
+ st.selectbox(
+ "Select Layer to Inspect:",
+ options=app_data.target_layer_names,
+ key="selected_layer_name",
+ )
+
+ # Display Layer-Specific Info if a layer is selected
+ if st.session_state.selected_layer_name:
+ layer_name = st.session_state.selected_layer_name
+ logger.debug(f"Displaying info for token {idx}, layer {layer_name}")
+
+ if st.session_state.current_masks is None:
+ st.warning("Masks not calculated yet. Please load a prompt.")
+ return
+
+ layer_mask_tensor: Float[Tensor, "1 seq_len m"] = st.session_state.current_masks[
+ layer_name
+ ]
+ token_mask: Float[Tensor, " m"] = layer_mask_tensor[0, idx, :]
+
+ # Find active components (mask > 0)
+ active_indices_layer: Int[Tensor, " n_active"] = torch.where(token_mask > 0)[0]
+ n_active_layer = len(active_indices_layer)
+
+ st.metric(f"Active Components in {layer_name}", n_active_layer)
+
+ st.subheader("Active Component Indices")
+ if n_active_layer > 0:
+ # Convert to NumPy array and reshape to a column vector (N x 1)
+ active_indices_np = active_indices_layer.cpu().numpy().reshape(-1, 1)
+ # Pass the NumPy array directly and configure the column header
+ st.dataframe(active_indices_np, height=300, use_container_width=False)
+ else:
+ st.write("No active components for this token in this layer.")
+
+ # Extensibility Placeholder
+ st.subheader("Additional Layer/Token Analysis")
+ st.write(
+ "Future figures and analyses for this specific layer and token will appear here."
+ )
+ else:
+ # Handle case where selected_token_index might be invalid after data reload
+ st.warning("Selected token index is out of bounds. Please select a token again.")
+ st.session_state.selected_token_index = None # Reset selection
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Streamlit app to explore LM component activations."
+ )
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default=DEFAULT_MODEL_PATH,
+ help=f"Path or W&B reference to the trained SSModel. Default: {DEFAULT_MODEL_PATH}",
+ )
+ args = parser.parse_args()
+
+ run_app(args)
diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py
new file mode 100644
index 0000000..1258495
--- /dev/null
+++ b/spd/experiments/lm/component_viz.py
@@ -0,0 +1,167 @@
+"""
+Vizualises the components of the model.
+"""
+
+import math
+
+import torch
+from jaxtyping import Float
+from matplotlib import pyplot as plt
+from simple_stories_train.dataloaders import DatasetConfig, create_data_loader
+from torch import Tensor
+from torch.utils.data import DataLoader
+
+from spd.configs import LMTaskConfig
+from spd.experiments.lm.models import LinearComponentWithBias, SSModel
+from spd.log import logger
+from spd.models.components import Gate, GateMLP
+from spd.run_spd import calc_component_acts, calc_masks
+from spd.types import ModelPath
+
+
+def component_activation_statistics(
+ model: SSModel,
+ dataloader: DataLoader[Float[Tensor, "batch pos"]],
+ n_steps: int,
+ device: str,
+) -> tuple[dict[str, float], dict[str, Float[Tensor, " m"]]]:
+ """Get the number and strength of the masks over the full dataset."""
+ # We used "-" instead of "." as module names can't have "." in them
+ gates: dict[str, Gate | GateMLP] = {
+ k.removeprefix("gates.").replace("-", "."): v for k, v in model.gates.items()
+ } # type: ignore
+ components: dict[str, LinearComponentWithBias] = {
+ k.removeprefix("components.").replace("-", "."): v for k, v in model.components.items()
+ } # type: ignore
+
+ n_tokens = {module_name.replace("-", "."): 0 for module_name in components}
+ total_n_active_components = {module_name.replace("-", "."): 0 for module_name in components}
+ component_activation_counts = {
+ module_name.replace("-", "."): torch.zeros(model.m, device=device)
+ for module_name in components
+ }
+ data_iter = iter(dataloader)
+ for _ in range(n_steps):
+ # --- Get Batch --- #
+ batch = next(data_iter)["input_ids"].to(device)
+
+ _, pre_weight_acts = model.forward_with_pre_forward_cache_hooks(
+ batch, module_names=list(components.keys())
+ )
+ As = {module_name: v.linear_component.A for module_name, v in components.items()}
+
+ target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore
+
+ masks, relud_masks = calc_masks(
+ gates=gates,
+ target_component_acts=target_component_acts,
+ attributions=None,
+ detach_inputs=False,
+ )
+ for module_name, mask in masks.items():
+ assert mask.ndim == 3 # (batch_size, pos, m)
+ n_tokens[module_name] += mask.shape[0] * mask.shape[1]
+ # Count the number of components that are active at all
+ active_components = mask > 0
+ total_n_active_components[module_name] += int(active_components.sum().item())
+ component_activation_counts[module_name] += active_components.sum(dim=(0, 1))
+
+ # Show the mean number of components
+ mean_n_active_components_per_token: dict[str, float] = {
+ module_name: (total_n_active_components[module_name] / n_tokens[module_name])
+ for module_name in components
+ }
+ mean_component_activation_counts: dict[str, Float[Tensor, " m"]] = {
+ module_name: component_activation_counts[module_name] / n_tokens[module_name]
+ for module_name in components
+ }
+
+ return mean_n_active_components_per_token, mean_component_activation_counts
+
+
+def plot_mean_component_activation_counts(
+ mean_component_activation_counts: dict[str, Float[Tensor, " m"]],
+) -> plt.Figure:
+ """Plots the mean activation counts for each component module in a grid."""
+ n_modules = len(mean_component_activation_counts)
+ max_cols = 6
+ n_cols = min(n_modules, max_cols)
+ # Calculate the number of rows needed, rounding up
+ n_rows = math.ceil(n_modules / n_cols)
+
+ # Create a figure with the calculated number of rows and columns
+ fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
+ # Ensure axs is always a 2D array for consistent indexing, even if n_modules is 1
+ axs = axs.flatten() # Flatten the axes array for easy iteration
+
+ # Iterate through modules and plot each histogram on its corresponding axis
+ for i, (module_name, counts) in enumerate(mean_component_activation_counts.items()):
+ ax = axs[i]
+ ax.hist(counts.detach().cpu().numpy(), bins=100)
+ ax.set_title(module_name) # Add module name as title to each subplot
+ ax.set_xlabel("Mean Activation Count")
+ ax.set_ylabel("Frequency")
+
+ # Hide any unused subplots if the grid isn't perfectly filled
+ for i in range(n_modules, n_rows * n_cols):
+ axs[i].axis("off")
+
+ # Adjust layout to prevent overlapping titles/labels
+ fig.tight_layout()
+
+ return fig
+
+
+def main(path: ModelPath) -> None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ ss_model, config, checkpoint_path = SSModel.from_pretrained(path)
+ ss_model.to(device)
+
+ out_dir = checkpoint_path
+
+ assert isinstance(config.task_config, LMTaskConfig)
+ dataset_config = DatasetConfig(
+ name=config.task_config.dataset_name,
+ tokenizer_file_path=None,
+ hf_tokenizer_path=f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}",
+ split=config.task_config.train_data_split,
+ n_ctx=config.task_config.max_seq_len,
+ is_tokenized=False,
+ streaming=False,
+ column_name="story",
+ )
+
+ dataloader, tokenizer = create_data_loader(
+ dataset_config=dataset_config,
+ batch_size=config.batch_size,
+ buffer_size=config.task_config.buffer_size,
+ global_seed=config.seed,
+ ddp_rank=0,
+ ddp_world_size=1,
+ )
+ # print(ss_model)
+ print(config)
+
+ mean_n_active_components_per_token, mean_component_activation_counts = (
+ component_activation_statistics(
+ model=ss_model,
+ dataloader=dataloader,
+ n_steps=100,
+ device=device,
+ )
+ )
+ logger.info(f"n_components: {ss_model.m}")
+ logger.info(f"mean_n_active_components_per_token: {mean_n_active_components_per_token}")
+ logger.info(f"mean_component_activation_counts: {mean_component_activation_counts}")
+ fig = plot_mean_component_activation_counts(
+ mean_component_activation_counts=mean_component_activation_counts,
+ )
+ # Save the entire figure once
+ save_path = out_dir / "modules_mean_component_activation_counts.png"
+ fig.savefig(save_path)
+ logger.info(f"Saved combined plot to {str(save_path)}")
+
+
+if __name__ == "__main__":
+ path = "wandb:spd-lm/runs/151bsctx"
+ main(path)
diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml
new file mode 100644
index 0000000..0b1260f
--- /dev/null
+++ b/spd/experiments/lm/lm_config.yaml
@@ -0,0 +1,72 @@
+# --- WandB ---
+wandb_project: spd-lm
+# wandb_project: null # Project name for Weights & Biases
+wandb_run_name: null # Set specific run name (optional, otherwise generated)
+wandb_run_name_prefix: "" # Prefix for generated run name
+
+# --- General ---
+seed: 0
+unit_norm_matrices: false # Whether to enforce unit norm on A matrices (not typically used here)
+m: 10000 # Rank of the decomposition / number of components per layer
+
+# --- Loss Coefficients ---
+# Set coeffs to null if the loss shouldn't be computed
+param_match_coeff: 1.0
+out_recon_coeff: 0.0 # Reconstruction loss based on output logits (MSE)
+lp_sparsity_coeff: 1e-1 # Coefficient for Lp sparsity loss (applied to component params A & B)
+pnorm: 2.0 # p-value for the Lp sparsity norm
+layerwise_random_recon_coeff: 1 # Layer-wise reconstruction loss with random masks
+
+# Placeholder losses (set coeffs to null as they require mask calculation implementation)
+masked_recon_coeff: null # Reconstruction loss using masks
+act_recon_coeff: null # Reconstruction loss on intermediate component activations
+random_mask_recon_coeff: null # Reconstruction loss averaged over random masks
+layerwise_recon_coeff: null # Layer-wise reconstruction loss
+
+n_random_masks: 1 # Number of random masks if random_mask_recon_coeff is used
+n_gate_hidden_neurons: null # Not applicable as there are no gates currently
+
+# --- Training ---
+batch_size: 4 # Adjust based on GPU memory
+steps: 1_000 # Total training steps
+lr: 1e-3 # Learning rate
+lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential)
+lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup
+lr_exponential_halflife: null # Required if lr_schedule is exponential
+init_from_target_model: false # Not implemented/applicable for this setup
+
+# --- Logging & Saving ---
+image_freq: 1000 # Frequency for generating/logging plots
+print_freq: 100 # Frequency for printing logs to console
+save_freq: 1_000 # Frequency for saving checkpoints
+image_on_first_step: true # Whether to log plots at step 0
+
+# --- Task Specific ---
+task_config:
+ task_name: lm # Specifies the LM decomposition task
+ model_size: "1.25M" # SimpleStories model size (e.g., "1.25M", "5M", "11M", "30M", "35M")
+ max_seq_len: 512 # Maximum sequence length for truncation/padding
+ buffer_size: 1000 # Buffer size for streaming dataset shuffling
+ dataset_name: "lennart-finke/SimpleStories" # HuggingFace dataset name
+ train_data_split: "train" # Dataset split to use
+ eval_data_split: "test" # Dataset split to use
+ n_eval_steps: 100 # Number of evaluation steps
+ # List of fnmatch patterns for nn.Linear modules to decompose
+ target_module_patterns: ["transformer.h.0.mlp.gate_proj"]
+ # Example: Decompose only gate_proj: ["transformer.h.*.mlp.gate_proj"]
+ # Example: Decompose gate_proj and up_proj: ["transformer.h.*.mlp.gate_proj", "transformer.h.*.mlp.up_proj"]
+ # Example: Decompose all MLP layers: ["transformer.h.*.mlp.*_proj"]
+
+# Config details for the target model taken from https://github.com/danbraunai/simple_stories_train/blob/main/simple_stories_train/models/model_configs.py#L54
+ # "1.25M": LlamaConfig(
+ # block_size=512,
+ # vocab_size=4096,
+ # n_layer=4,
+ # n_head=4,
+ # n_embd=128,
+ # n_intermediate=128 * 4 * 2 // 3 = 341,
+ # rotary_dim=128 // 4 = 32,
+ # n_ctx=512,
+ # n_key_value_heads=2,
+ # flash_attention=True,
+ # ),
\ No newline at end of file
diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py
new file mode 100644
index 0000000..41f5ce6
--- /dev/null
+++ b/spd/experiments/lm/lm_decomposition.py
@@ -0,0 +1,545 @@
+"""Language Model decomposition script."""
+
+from collections.abc import Callable
+from datetime import datetime
+from pathlib import Path
+
+import einops
+import fire
+import matplotlib.pyplot as plt
+import torch
+import torch.optim as optim
+import wandb
+import yaml
+from jaxtyping import Float
+from simple_stories_train.dataloaders import DatasetConfig, create_data_loader
+from simple_stories_train.models.llama import Llama
+from simple_stories_train.models.model_configs import MODEL_CONFIGS
+from torch import Tensor
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from spd.configs import Config, LMTaskConfig
+from spd.experiments.lm.component_viz import (
+ component_activation_statistics,
+ plot_mean_component_activation_counts,
+)
+from spd.experiments.lm.models import LinearComponentWithBias, SSModel
+from spd.log import logger
+from spd.models.components import Gate, GateMLP
+from spd.run_spd import (
+ _calc_param_mse,
+ calc_component_acts,
+ calc_mask_l_zero,
+ calc_masks,
+ calc_random_masks,
+ get_common_run_name_suffix,
+)
+from spd.utils import (
+ get_device,
+ get_lr_schedule_fn,
+ get_lr_with_warmup,
+ load_config,
+ set_seed,
+)
+from spd.wandb_utils import init_wandb
+
+wandb.require("core")
+
+
+def get_run_name(
+ config: Config,
+ model_size: str,
+ max_seq_len: int,
+) -> str:
+ """Generate a run name based on the config."""
+ run_suffix = ""
+ if config.wandb_run_name:
+ run_suffix = config.wandb_run_name
+ else:
+ run_suffix = get_common_run_name_suffix(config)
+ run_suffix += f"_lm{model_size}_seq{max_seq_len}"
+ return config.wandb_run_name_prefix + run_suffix
+
+
+def lm_plot_results_fn(
+ model: SSModel,
+ components: dict[str, LinearComponentWithBias],
+ step: int | None,
+ out_dir: Path | None,
+ device: str,
+ config: Config,
+ **_,
+) -> dict[str, plt.Figure]:
+ """Plotting function for LM decomposition. Placeholder for now."""
+ # TODO: Implement actual plotting (e.g., component matrix values?)
+ logger.info(f"Plotting results at step {step}...")
+ fig_dict: dict[str, plt.Figure] = {}
+ # Example: Potentially plot A/B matrix norms or sparsity patterns?
+ # fig_dict["component_norms"] = plot_component_norms(components, out_dir, step)
+ return fig_dict
+
+
+def calc_recon_mse_lm(
+ out1: Float[Tensor, "batch pos vocab"],
+ out2: Float[Tensor, "batch pos vocab"],
+) -> Float[Tensor, ""]:
+ """Calculate the Mean Squared Error reconstruction loss for LM logits."""
+ assert out1.shape == out2.shape
+ # Mean over batch and sequence length, sum over vocab
+ return ((out1 - out2) ** 2).sum(dim=-1).mean()
+
+
+def calc_param_match_loss_lm(
+ components: dict[str, LinearComponentWithBias],
+ target_model: Llama,
+ n_params: int,
+ device: str,
+) -> Float[Tensor, ""]:
+ """Calculate the MSE loss between component parameters (A@B + bias) and target parameters."""
+ target_params: dict[str, Float[Tensor, "d_in d_out"]] = {}
+ component_params: dict[str, Float[Tensor, "d_in d_out"]] = {}
+
+ for comp_name, component in components.items():
+ component_params[comp_name] = einops.einsum(
+ component.linear_component.A,
+ component.linear_component.B,
+ "d_in m, m d_out -> d_in d_out",
+ )
+ target_params[comp_name] = target_model.get_parameter(comp_name + ".weight").T
+ assert component_params[comp_name].shape == target_params[comp_name].shape
+
+ param_mse = _calc_param_mse(
+ params1=component_params,
+ params2=target_params,
+ n_params=n_params,
+ device=device,
+ )
+ return param_mse
+
+
+def calc_layerwise_recon_loss_lm(
+ model: SSModel,
+ batch: Float[Tensor, "batch pos"],
+ device: str,
+ components: dict[str, LinearComponentWithBias],
+ masks: list[dict[str, Float[Tensor, "batch pos m"]]],
+ target_out: Float[Tensor, "batch pos vocab"],
+) -> Float[Tensor, ""]:
+ """Calculate the recon loss when augmenting the model one (masked) component at a time."""
+ total_loss = torch.tensor(0.0, device=device)
+ for mask_info in masks:
+ for component_name, component in components.items():
+ module_name = component_name.replace("-", ".")
+ modified_out, _ = model.forward_with_component(
+ batch,
+ module_name=module_name,
+ component=component,
+ mask=mask_info.get(component_name, None),
+ )
+ loss = calc_recon_mse_lm(modified_out, target_out)
+ total_loss += loss
+ n_modified_components = len(masks[0])
+ return total_loss / (n_modified_components * len(masks))
+
+
+def calc_lp_sparsity_loss_lm(
+ relud_masks: dict[str, Float[Tensor, "batch pos m"]], pnorm: float
+) -> Float[Tensor, ""]:
+ """Calculate the Lp sparsity loss on the attributions.
+
+ Args:
+ relud_masks: Dictionary of relu masks for each layer.
+ pnorm: The pnorm to use for the sparsity loss.
+ Returns:
+ The Lp sparsity loss.
+ """
+ # Initialize with zeros matching the shape of first mask
+ total_loss = torch.zeros_like(next(iter(relud_masks.values())))
+
+ for layer_relud_mask in relud_masks.values():
+ total_loss = total_loss + layer_relud_mask**pnorm
+
+ # Sum over the m dimension and mean over the batch and pos dimensions
+ return total_loss.sum(dim=-1).mean(dim=[0, 1])
+
+
+def optimize_lm(
+ model: SSModel,
+ config: Config,
+ device: str,
+ train_loader: DataLoader[Float[Tensor, "batch pos"]],
+ eval_loader: DataLoader[Float[Tensor, "batch pos"]],
+ n_eval_steps: int,
+ plot_results_fn: Callable[..., dict[str, plt.Figure]],
+ out_dir: Path | None,
+) -> None:
+ """Run the optimization loop for LM decomposition."""
+
+ # We used "-" instead of "." as module names can't have "." in them
+ gates: dict[str, Gate | GateMLP] = {
+ k.removeprefix("gates.").replace("-", "."): v for k, v in model.gates.items()
+ } # type: ignore
+ components: dict[str, LinearComponentWithBias] = {
+ k.removeprefix("components.").replace("-", "."): v for k, v in model.components.items()
+ } # type: ignore
+
+ component_params = []
+ param_names_to_optimize = []
+ for name, component in components.items():
+ component_params.extend(list(component.parameters()))
+ param_names_to_optimize.extend(
+ [f"{name}.{p_name}" for p_name, _ in component.named_parameters()]
+ )
+ logger.debug(f"Adding parameters from component: {name}")
+
+ if not component_params:
+ logger.error("No parameters found in components to optimize. Exiting.")
+ return
+
+ optimizer = optim.AdamW(component_params, lr=config.lr, weight_decay=0.0)
+ logger.info(f"Optimizer created for params: {param_names_to_optimize}")
+
+ lr_schedule_fn = get_lr_schedule_fn(config.lr_schedule, config.lr_exponential_halflife)
+ logger.info(f"Base LR scheduler created: {config.lr_schedule}")
+
+ n_params = 0
+ for module_name in components:
+ weight = model.model.get_parameter(module_name + ".weight")
+ n_params += weight.numel()
+
+ log_data = {}
+ data_iter = iter(train_loader)
+
+ # Use tqdm directly in the loop, iterate one extra step for final logging/plotting/saving
+ for step in tqdm(range(config.steps + 1), ncols=0):
+ # --- LR Scheduling Step --- #
+ step_lr = get_lr_with_warmup(
+ step=step,
+ steps=config.steps,
+ lr=config.lr,
+ lr_schedule_fn=lr_schedule_fn,
+ lr_warmup_pct=config.lr_warmup_pct,
+ )
+ # Manually update optimizer's learning rate
+ for group in optimizer.param_groups:
+ group["lr"] = step_lr
+ log_data["lr"] = step_lr
+
+ # --- Zero Gradients --- #
+ optimizer.zero_grad()
+
+ # --- Get Batch --- #
+ try:
+ batch = next(data_iter)["input_ids"].to(device)
+ except StopIteration:
+ logger.warning("Dataloader exhausted, resetting iterator.")
+ data_iter = iter(train_loader)
+ batch = next(data_iter)["input_ids"].to(device)
+
+ (target_out, _), pre_weight_acts = model.forward_with_pre_forward_cache_hooks(
+ batch, module_names=list(components.keys())
+ )
+ As = {module_name: v.linear_component.A for module_name, v in components.items()}
+
+ target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore
+ # attributions = calc_grad_attributions(
+ # target_out=target_out,
+ # pre_weight_acts=pre_weight_acts,
+ # post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")},
+ # target_component_acts=target_component_acts,
+ # Bs=collect_nested_module_attrs(model, attr_name="B", include_attr_name=False),
+ # )
+ attributions = None
+
+ masks, relud_masks = calc_masks(
+ gates=gates,
+ target_component_acts=target_component_acts,
+ attributions=attributions,
+ detach_inputs=False,
+ )
+
+ # --- Calculate Losses --- #
+ total_loss = torch.tensor(0.0, device=device)
+ loss_terms = {}
+
+ ####### param match loss #######
+ param_match_loss_val = calc_param_match_loss_lm(
+ components=components,
+ target_model=model.model,
+ n_params=n_params,
+ device=device,
+ )
+ total_loss += config.param_match_coeff * param_match_loss_val
+ loss_terms["loss/parameter_matching"] = param_match_loss_val.item()
+
+ ####### layerwise recon loss #######
+ if config.layerwise_recon_coeff is not None:
+ layerwise_recon_loss = calc_layerwise_recon_loss_lm(
+ model=model,
+ batch=batch,
+ device=device,
+ components=components,
+ masks=[masks],
+ target_out=target_out,
+ )
+ total_loss += config.layerwise_recon_coeff * layerwise_recon_loss
+ loss_terms["loss/layerwise_reconstruction"] = layerwise_recon_loss.item()
+
+ ####### layerwise random recon loss #######
+ if config.layerwise_random_recon_coeff is not None:
+ layerwise_random_masks = calc_random_masks(
+ masks=masks, n_random_masks=config.n_random_masks
+ )
+ layerwise_random_recon_loss = calc_layerwise_recon_loss_lm(
+ model=model,
+ batch=batch,
+ device=device,
+ components=components,
+ masks=layerwise_random_masks,
+ target_out=target_out,
+ )
+ total_loss += config.layerwise_random_recon_coeff * layerwise_random_recon_loss
+ loss_terms["loss/layerwise_random_reconstruction"] = layerwise_random_recon_loss.item()
+
+ ####### lp sparsity loss #######
+ lp_sparsity_loss = calc_lp_sparsity_loss_lm(relud_masks=relud_masks, pnorm=config.pnorm)
+ total_loss += config.lp_sparsity_coeff * lp_sparsity_loss
+ loss_terms["loss/lp_sparsity_loss"] = lp_sparsity_loss.item()
+
+ ####### out recon loss #######
+ if config.out_recon_coeff is not None:
+ # Get target logits (no gradients needed for target model)
+ with torch.no_grad():
+ target_logits, _ = model.forward(batch)
+ # Detach target logits to ensure no grads flow back
+ target_logits = target_logits.detach()
+
+ # Get component logits
+ component_logits, _ = model.forward_with_components(
+ batch, components=components, masks=masks
+ )
+
+ assert component_logits.shape == target_logits.shape, (
+ f"Shape mismatch: {component_logits.shape} vs {target_logits.shape}"
+ )
+
+ recon_loss = calc_recon_mse_lm(component_logits, target_logits)
+ total_loss += config.out_recon_coeff * recon_loss
+ loss_terms["loss/reconstruction"] = recon_loss.item()
+
+ log_data["loss/total"] = total_loss.item()
+ log_data.update(loss_terms)
+
+ # --- Logging --- #
+ if step % config.print_freq == 0:
+ tqdm.write(f"--- Step {step} ---")
+ tqdm.write(f"LR: {step_lr:.6f}")
+ tqdm.write(f"Total Loss: {log_data['loss/total']:.7f}")
+ for name, value in loss_terms.items():
+ if value is not None:
+ tqdm.write(f"{name}: {value:.7f}")
+
+ mean_n_active_components_per_token = component_activation_statistics(
+ model=model, dataloader=eval_loader, n_steps=n_eval_steps, device=device
+ )[0]
+ tqdm.write(f"Mean n active components per token: {mean_n_active_components_per_token}")
+
+ if config.wandb_project:
+ mask_l_zero = calc_mask_l_zero(masks=masks)
+ for layer_name, layer_mask_l_zero in mask_l_zero.items():
+ log_data[f"{layer_name}/mask_l0"] = layer_mask_l_zero
+ log_data[f"{layer_name}/mean_n_active_components_per_token"] = (
+ mean_n_active_components_per_token[layer_name]
+ )
+ wandb.log(log_data, step=step)
+
+ # --- Plotting --- #
+ if (
+ config.image_freq is not None
+ and step % config.image_freq == 0
+ and (step > 0 or config.image_on_first_step)
+ ):
+ logger.info(f"Step {step}: Generating plots...")
+ with torch.no_grad():
+ fig_dict = plot_results_fn(
+ model=model, # Pass the SSModel wrapper
+ components=components,
+ step=step,
+ out_dir=out_dir,
+ device=device,
+ config=config,
+ # Add any other necessary args for plotting like tokenizer, sample text?
+ )
+ mean_component_activation_counts = component_activation_statistics(
+ model=model, dataloader=eval_loader, n_steps=n_eval_steps, device=device
+ )[1]
+ fig_dict["mean_component_activation_counts"] = (
+ plot_mean_component_activation_counts(
+ mean_component_activation_counts=mean_component_activation_counts,
+ )
+ )
+ if config.wandb_project:
+ wandb.log(
+ {k: wandb.Image(v) for k, v in fig_dict.items()},
+ step=step,
+ )
+ if out_dir is not None:
+ for k, v in fig_dict.items():
+ v.savefig(out_dir / f"{k}_{step}.png")
+ tqdm.write(f"Saved plot to {out_dir / f'{k}_{step}.png'}")
+
+ # --- Saving Checkpoint --- #
+ if (
+ (config.save_freq is not None and step % config.save_freq == 0 and step > 0)
+ or step == config.steps
+ ) and out_dir is not None:
+ torch.save(model.state_dict(), out_dir / f"model_{step}.pth")
+ torch.save(optimizer.state_dict(), out_dir / f"optimizer_{step}.pth")
+ logger.info(f"Saved model, optimizer, and out_dir to {out_dir}")
+ if config.wandb_project:
+ wandb.save(str(out_dir / f"model_{step}.pth"), base_path=str(out_dir), policy="now")
+ wandb.save(
+ str(out_dir / f"optimizer_{step}.pth"), base_path=str(out_dir), policy="now"
+ )
+
+ # --- Backward Pass & Optimize --- #
+ # Skip gradient step if we are at the last step (last step just for plotting and logging)
+ if step != config.steps:
+ total_loss.backward(retain_graph=True)
+
+ if step % config.print_freq == 0 and config.wandb_project:
+ # Calculate gradient norm
+ grad_norm: float = 0.0
+ for param in model.parameters():
+ if param.grad is not None:
+ grad_norm += param.grad.data.norm() # type: ignore
+ wandb.log({"grad_norm": grad_norm}, step=step)
+
+ if config.unit_norm_matrices:
+ model.fix_normalized_adam_gradients()
+
+ optimizer.step()
+ logger.info("Finished training loop.")
+
+
+def main(
+ config_path_or_obj: Path | str | Config, sweep_config_path: Path | str | None = None
+) -> None:
+ config = load_config(config_path_or_obj, config_model=Config)
+
+ if config.wandb_project:
+ config = init_wandb(config, config.wandb_project, sweep_config_path)
+
+ set_seed(config.seed)
+ logger.info(config)
+
+ device = get_device()
+ logger.info(f"Using device: {device}")
+ assert isinstance(config.task_config, LMTaskConfig), (
+ "Task config must be LMTaskConfig for LM decomposition."
+ )
+
+ # --- Load Model --- #
+ logger.info(f"Loading model: {config.task_config.model_size}")
+ model_config_dict = MODEL_CONFIGS[config.task_config.model_size]
+ model_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}"
+ model = Llama.from_pretrained(model_path, model_config_dict)
+
+ ss_model = SSModel(
+ llama_model=model,
+ target_module_patterns=config.task_config.target_module_patterns,
+ m=config.m,
+ n_gate_hidden_neurons=config.n_gate_hidden_neurons,
+ )
+ ss_model.to(device)
+ logger.info("Model loaded.")
+
+ # --- Setup Run Name and Output Dir --- #
+ run_name = get_run_name(
+ config,
+ model_size=config.task_config.model_size,
+ max_seq_len=config.task_config.max_seq_len,
+ )
+ if config.wandb_project:
+ assert wandb.run, "wandb.run must be initialized before training"
+ wandb.run.name = run_name
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
+ out_dir = Path(__file__).parent / "out" / f"{run_name}_{timestamp}"
+ out_dir.mkdir(parents=True, exist_ok=True)
+ logger.info(f"Output directory: {out_dir}")
+
+ # --- Save Config --- #
+ with open(out_dir / "final_config.yaml", "w") as f:
+ yaml.dump(config.model_dump(mode="json"), f, indent=2)
+ if config.wandb_project:
+ wandb.save(str(out_dir / "final_config.yaml"), base_path=out_dir, policy="now")
+
+ # --- Load Data --- #
+ logger.info("Loading dataset...")
+ train_data_config = DatasetConfig(
+ name=config.task_config.dataset_name,
+ tokenizer_file_path=None,
+ hf_tokenizer_path=model_path,
+ split=config.task_config.train_data_split,
+ n_ctx=config.task_config.max_seq_len,
+ is_tokenized=False,
+ streaming=False,
+ column_name="story",
+ )
+
+ train_loader, tokenizer = create_data_loader(
+ dataset_config=train_data_config,
+ batch_size=config.batch_size,
+ buffer_size=config.task_config.buffer_size,
+ global_seed=config.seed,
+ ddp_rank=0,
+ ddp_world_size=1,
+ )
+
+ eval_data_config = DatasetConfig(
+ name=config.task_config.dataset_name,
+ tokenizer_file_path=None,
+ hf_tokenizer_path=model_path,
+ split=config.task_config.eval_data_split,
+ n_ctx=config.task_config.max_seq_len,
+ is_tokenized=False,
+ streaming=False,
+ column_name="story",
+ )
+ eval_loader, _ = create_data_loader(
+ dataset_config=eval_data_config,
+ batch_size=config.batch_size,
+ buffer_size=config.task_config.buffer_size,
+ global_seed=config.seed,
+ ddp_rank=0,
+ ddp_world_size=1,
+ )
+
+ logger.info("Dataset and tokenizer loaded.")
+
+ logger.info("Freezing target model parameters...")
+ for param in ss_model.model.parameters():
+ param.requires_grad = False
+ logger.info("Target model frozen.")
+
+ logger.info("Starting optimization...")
+ optimize_lm(
+ model=ss_model,
+ config=config,
+ device=device,
+ train_loader=train_loader,
+ eval_loader=eval_loader,
+ n_eval_steps=config.task_config.n_eval_steps,
+ out_dir=out_dir,
+ plot_results_fn=lm_plot_results_fn,
+ )
+
+ logger.info("Optimization finished.")
+
+ if config.wandb_project:
+ wandb.finish()
+
+
+if __name__ == "__main__":
+ fire.Fire(main)
diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py
new file mode 100644
index 0000000..eafc9b4
--- /dev/null
+++ b/spd/experiments/lm/models.py
@@ -0,0 +1,267 @@
+"""
+Defines a SSModel class that is a wrapper around a llama model from SimpleStories
+"""
+
+import fnmatch
+from functools import partial
+from pathlib import Path
+from typing import Any
+
+import torch
+import torch.nn as nn
+import wandb
+import yaml
+from jaxtyping import Float
+from pydantic import BaseModel
+from simple_stories_train.models.llama import Llama
+from simple_stories_train.models.model_configs import MODEL_CONFIGS
+from torch import Tensor
+from wandb.apis.public import Run
+
+from spd.configs import Config, LMTaskConfig
+from spd.models.components import Gate, GateMLP, LinearComponent
+from spd.types import WANDB_PATH_PREFIX, ModelPath
+from spd.wandb_utils import (
+ download_wandb_file,
+ fetch_latest_wandb_checkpoint,
+ fetch_wandb_run_dir,
+)
+
+
+class LinearComponentWithBias(nn.Module):
+ """A LinearComponent with a bias parameter."""
+
+ def __init__(self, linear_component: LinearComponent, bias: Tensor | None):
+ super().__init__()
+ self.linear_component = linear_component
+ self.bias = bias
+ self.mask: Float[Tensor, "... m"] | None = None # Gets set on sparse forward passes
+
+ def forward(self, x: Float[Tensor, "... d_in"]) -> Float[Tensor, "... d_out"]:
+ # Note: We assume bias is added *after* the component multiplication
+ # Also assume input is (batch, seq_len, d_in)
+ out = self.linear_component(x, mask=self.mask)
+ if self.bias is not None:
+ out += self.bias
+ return out
+
+
+def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponentWithBias:
+ """Replace a nn.Linear module with a LinearComponentWithBias module."""
+ d_out, d_in = linear_module.weight.shape
+
+ linear_component = LinearComponent(d_in=d_in, d_out=d_out, m=m, n_instances=None)
+
+ # # Initialize with A = W (original weights) and B = I (identity)
+ # # This provides a starting point where the component exactly equals the original
+ # linear_component.A.data[:] = linear_module.weight.t() # (d_in, m)
+ # linear_component.B.data[:] = torch.eye(m)
+
+ bias = linear_module.bias.clone() if linear_module.bias is not None else None # type: ignore
+
+ return LinearComponentWithBias(linear_component, bias)
+
+
+class SSModelPaths(BaseModel):
+ """Paths to output files from a SSModel training run."""
+
+ model: Path
+ optimizer: Path
+ config: Path
+
+
+class SSModel(nn.Module):
+ """Wrapper around a llama model from SimpleStories for running SPD."""
+
+ def __init__(
+ self,
+ llama_model: Llama,
+ target_module_patterns: list[str],
+ m: int,
+ n_gate_hidden_neurons: int | None,
+ ):
+ super().__init__()
+ self.model = llama_model
+ self.m = m
+ self.components = self.create_target_components(
+ target_module_patterns=target_module_patterns, m=m
+ )
+
+ # Use GateMLP if n_gate_hidden_neurons is provided, otherwise use Gate
+ gate_class = GateMLP if n_gate_hidden_neurons is not None else Gate
+ gate_kwargs = {"m": m}
+ if n_gate_hidden_neurons is not None:
+ gate_kwargs["n_gate_hidden_neurons"] = n_gate_hidden_neurons
+
+ self.gates = nn.ModuleDict({name: gate_class(**gate_kwargs) for name in self.components})
+
+ def create_target_components(self, target_module_patterns: list[str], m: int) -> nn.ModuleDict:
+ """Create target components for the model."""
+ components: dict[str, LinearComponentWithBias] = {}
+ for name, module in self.model.named_modules():
+ for pattern in target_module_patterns:
+ if fnmatch.fnmatch(name, pattern):
+ assert isinstance(module, nn.Linear), (
+ f"Module '{name}' matched pattern '{pattern}' but is not nn.Linear. "
+ f"Found type: {type(module)}"
+ )
+ # Replace "." with "-" in the name to avoid issues with module dict keys
+ components[name.replace(".", "-")] = nn_linear_to_components(module, m=m)
+ break
+ return nn.ModuleDict(components)
+
+ def to(self, *args: Any, **kwargs: Any) -> "SSModel":
+ """Move the model and components to a device."""
+ self.model.to(*args, **kwargs)
+ for component in self.components.values():
+ component.to(*args, **kwargs)
+ for gate in self.gates.values():
+ gate.to(*args, **kwargs)
+ return self
+
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
+ """Regular forward pass of the (target) model."""
+ return self.model(*args, **kwargs)
+
+ def forward_with_component(
+ self,
+ *args: Any,
+ module_name: str,
+ component: LinearComponentWithBias,
+ mask: Float[Tensor, "batch pos m"] | None = None,
+ **kwargs: Any,
+ ) -> Any:
+ """Forward pass with a single component replacement."""
+ # Note that module_name uses "." separators but self.components use "-" separators
+ old_module = self.model.get_submodule(module_name)
+ assert old_module is not None
+
+ self.model.set_submodule(module_name, component)
+ if mask is not None:
+ component.mask = mask
+
+ out = self.model(*args, **kwargs)
+
+ self.model.set_submodule(module_name, old_module)
+ return out
+
+ def forward_with_components(
+ self,
+ *args: Any,
+ components: dict[str, LinearComponentWithBias],
+ masks: dict[str, Float[Tensor, "batch pos m"]] | None = None,
+ **kwargs: Any,
+ ) -> Any:
+ """Forward pass with temporary component replacement."""
+ # Note that components and masks uses "-" separators
+ old_modules = {}
+ for component_name, component in components.items():
+ module_name = component_name.replace("-", ".")
+ # component: LinearComponentWithBias = self.components[module_name.replace(".", "-")]
+ old_module = self.model.get_submodule(module_name)
+ assert old_module is not None
+ old_modules[module_name] = old_module
+
+ if masks is not None:
+ component.mask = masks.get(component_name, None)
+ self.model.set_submodule(module_name, component)
+
+ out = self.model(*args, **kwargs)
+
+ # Restore the original modules
+ for module_name, old_module in old_modules.items():
+ self.model.set_submodule(module_name, old_module)
+
+ # Remove the masks attribute from the components
+ for component in components.values():
+ component.mask = None
+
+ return out
+
+ def forward_with_pre_forward_cache_hooks(
+ self, *args: Any, module_names: list[str], **kwargs: Any
+ ) -> tuple[Any, dict[str, Tensor]]:
+ """Forward pass with caching at in the input to the modules given by `module_names`.
+
+ Args:
+ module_names: List of module names to cache the inputs to.
+ """
+ cache = {}
+
+ def cache_hook(module: nn.Module, input: tuple[Tensor, ...], param_name: str) -> Tensor:
+ cache[param_name] = input[0]
+ return input[0]
+
+ handles: list[torch.utils.hooks.RemovableHandle] = []
+ for module_name in module_names:
+ module = self.model.get_submodule(module_name)
+ assert module is not None
+ handles.append(
+ module.register_forward_pre_hook(partial(cache_hook, param_name=module_name))
+ )
+
+ out = self.forward(*args, **kwargs)
+
+ for handle in handles:
+ handle.remove()
+
+ return out, cache
+
+ @staticmethod
+ def _download_wandb_files(wandb_project_run_id: str) -> SSModelPaths:
+ """Download the relevant files from a wandb run."""
+ api = wandb.Api()
+ run: Run = api.run(wandb_project_run_id)
+
+ checkpoint = fetch_latest_wandb_checkpoint(run, prefix="model")
+
+ run_dir = fetch_wandb_run_dir(run.id)
+
+ final_config_path = download_wandb_file(run, run_dir, "final_config.yaml")
+ checkpoint_path = download_wandb_file(run, run_dir, checkpoint.name)
+
+ # Get the step number from the path
+ step = int(Path(checkpoint_path).stem.split("_")[-1])
+
+ return SSModelPaths(
+ model=checkpoint_path,
+ optimizer=download_wandb_file(run, run_dir, f"optimizer_{step}.pth"),
+ config=final_config_path,
+ )
+
+ @classmethod
+ def from_pretrained(cls, path: ModelPath) -> tuple["SSModel", Config, Path]:
+ if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX):
+ wandb_path = path.removeprefix(WANDB_PATH_PREFIX)
+ api = wandb.Api()
+ run: Run = api.run(wandb_path)
+ paths = cls._download_wandb_files(wandb_path)
+ out_dir = fetch_wandb_run_dir(run.id)
+
+ else:
+ # Get the step number from the path
+ step = int(Path(path).stem.split("_")[-1])
+ paths = SSModelPaths(
+ model=Path(path),
+ optimizer=Path(path).parent / f"optimizer_{step}.pth",
+ config=Path(path).parent / "final_config.yaml",
+ )
+ out_dir = Path(path).parent
+
+ model_weights = torch.load(paths.model, map_location="cpu", weights_only=True)
+ with open(paths.config) as f:
+ config = Config(**yaml.safe_load(f))
+
+ assert isinstance(config.task_config, LMTaskConfig)
+ model_config_dict = MODEL_CONFIGS[config.task_config.model_size]
+ model_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}"
+ llama_model = Llama.from_pretrained(model_path, model_config_dict)
+
+ ss_model = SSModel(
+ llama_model=llama_model,
+ target_module_patterns=config.task_config.target_module_patterns,
+ m=config.m,
+ n_gate_hidden_neurons=config.n_gate_hidden_neurons,
+ )
+ ss_model.load_state_dict(model_weights)
+ return ss_model, config, out_dir
diff --git a/spd/experiments/lm/play.py b/spd/experiments/lm/play.py
new file mode 100644
index 0000000..87164f7
--- /dev/null
+++ b/spd/experiments/lm/play.py
@@ -0,0 +1,94 @@
+# %%
+import torch
+from simple_stories_train.models.llama import Llama
+from simple_stories_train.models.model_configs import MODEL_CONFIGS
+from transformers import AutoTokenizer
+
+from spd.experiments.lm.models import LinearComponentWithBias, SSModel
+
+# %%
+# Select the model size you want to use
+model_size = "1.25M" # Options: "35M", "30M", "11M", "5M", "1.25M"
+
+# Load model configuration
+model_config = MODEL_CONFIGS[model_size]
+
+# Load appropriate model
+model_path = f"chandan-sreedhara/SimpleStories-{model_size}"
+model = Llama.from_pretrained(model_path, model_config)
+# model.to("cuda")
+model.eval()
+# %%
+
+ss_model = SSModel(
+ llama_model=model,
+ target_module_patterns=["model.transformer.h.*.mlp.gate_proj"],
+ m=17,
+ n_gate_hidden_neurons=None,
+)
+
+# # Create components with rank=10 (adjust as needed)
+# gate_proj_components = create_target_components(
+# model, rank=m, target_module_patterns=["model.transformer.h.*.mlp.gate_proj"]
+# )
+gate_proj_components: dict[str, LinearComponentWithBias] = {
+ k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items()
+} # type: ignore
+# %%
+# Load tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)
+
+# Define your prompt
+prompt = "The curious cat looked at the"
+
+# IMPORTANT: Use tokenizer without special tokens
+inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
+# input_ids = inputs.input_ids.to("cuda")
+input_ids = inputs.input_ids
+# Targets should be the inputs shifted by one (we will later ignore the last input token)
+targets = input_ids[:, 1:]
+input_ids = input_ids[:, :-1]
+
+# IMPORTANT: Set correct EOS token ID (not the default from tokenizer)
+eos_token_id = 1
+
+# %%
+
+# # Generate text
+# with torch.no_grad():
+# output_ids = model.generate(
+# idx=input_ids, max_new_tokens=20, temperature=0.7, top_k=40, eos_token_id=eos_token_id
+# )
+
+# # Decode output
+# output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+# print(f"Generated text:\n{output_text}")
+
+
+# %%
+
+# logits, _ = ss_model.forward(input_ids, components=gate_proj_components)
+logits, _ = ss_model.forward(input_ids)
+print("inputs_shape", input_ids.shape)
+print("logits", logits)
+print("logits shape", logits.shape)
+
+logits, _ = ss_model.forward_with_components(input_ids, components=gate_proj_components)
+
+print("Component logits shape", logits.shape)
+print("Component logits", logits)
+
+# Create some dummy masks
+masks = {
+ f"model.transformer.h.{i}.mlp.gate_proj": torch.randn(1, input_ids.shape[-1], ss_model.m)
+ for i in range(len(model.transformer.h))
+}
+
+logits, _ = ss_model.forward_with_components(
+ input_ids, components=gate_proj_components, masks=masks
+)
+
+print("Masked component logits shape", logits.shape)
+print("Masked component logits", logits)
+#########################################################
+# %%
diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py
index 7d1b72b..c8e09c7 100644
--- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py
+++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py
@@ -23,7 +23,7 @@
)
from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset
from spd.log import logger
-from spd.models.components import Gate
+from spd.models.components import Gate, GateMLP
from spd.plotting import plot_AB_matrices, plot_mask_vals
from spd.run_spd import get_common_run_name_suffix, optimize
from spd.utils import (
@@ -107,7 +107,7 @@ def resid_mlp_plot_results_fn(
out_dir: Path | None,
device: str,
config: Config,
- gates: dict[str, Gate],
+ gates: dict[str, Gate | GateMLP],
masks: dict[str, Float[Tensor, "batch_size m"]] | None,
**_,
) -> dict[str, plt.Figure]:
@@ -161,9 +161,9 @@ def init_spd_model_from_target_model(
for i in range(target_model.config.n_layers):
# For mlp_in, m must equal d_mlp
# TODO: This is broken, we shouldn't need m=d_mlp for this function.
- assert (
- m == target_model.config.d_mlp or m == target_model.config.d_embed
- ), "m must be equal to d_mlp or d_embed"
+ assert m == target_model.config.d_mlp or m == target_model.config.d_embed, (
+ "m must be equal to d_mlp or d_embed"
+ )
# For mlp_in: A = target weights, B = identity
model.layers[i].mlp_in.A.data[:] = target_model.layers[i].mlp_in.weight.data.clone()
diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py
index 70ff530..2d70087 100644
--- a/spd/experiments/tms/tms_decomposition.py
+++ b/spd/experiments/tms/tms_decomposition.py
@@ -20,7 +20,7 @@
from spd.configs import Config, TMSTaskConfig
from spd.experiments.tms.models import TMSModel, TMSModelConfig, TMSSPDModel, TMSSPDModelConfig
from spd.log import logger
-from spd.models.components import Gate
+from spd.models.components import Gate, GateMLP
from spd.plotting import plot_AB_matrices, plot_mask_vals
from spd.run_spd import get_common_run_name_suffix, optimize
from spd.utils import (
@@ -54,7 +54,7 @@ def make_plots(
out_dir: Path,
device: str,
config: Config,
- gates: dict[str, Gate],
+ gates: dict[str, Gate | GateMLP],
masks: dict[str, Float[Tensor, "batch n_instances m"]],
batch: Float[Tensor, "batch n_instances n_features"],
**_,
diff --git a/spd/plotting.py b/spd/plotting.py
index 1eb1adf..91d438a 100644
--- a/spd/plotting.py
+++ b/spd/plotting.py
@@ -10,7 +10,7 @@
from spd.hooks import HookedRootModule
from spd.models.base import SPDModel
-from spd.models.components import Gate
+from spd.models.components import Gate, GateMLP
from spd.module_utils import collect_nested_module_attrs
from spd.run_spd import calc_component_acts, calc_masks
@@ -48,7 +48,7 @@ def permute_to_identity(
def plot_mask_vals(
model: SPDModel,
target_model: HookedRootModule,
- gates: dict[str, Gate],
+ gates: dict[str, Gate | GateMLP],
device: str,
input_magnitude: float,
) -> tuple[plt.Figure, dict[str, Float[Tensor, "n_instances m"]]]:
@@ -146,7 +146,7 @@ def plot_subnetwork_attributions_statistics(
ax.set_ylabel("Count")
ax.set_xlabel("Number of active subnetworks")
- ax.set_title(f"Instance {i+1}")
+ ax.set_title(f"Instance {i + 1}")
# Add value annotations on top of each bar
for bar in bars:
@@ -212,9 +212,9 @@ def plot_AB_matrices(
# Verify that A and B matrices have matching names
A_names = set(As.keys())
B_names = set(Bs.keys())
- assert (
- A_names == B_names
- ), f"A and B matrices must have matching names. Found A: {A_names}, B: {B_names}"
+ assert A_names == B_names, (
+ f"A and B matrices must have matching names. Found A: {A_names}, B: {B_names}"
+ )
n_layers = len(As)
diff --git a/spd/run_spd.py b/spd/run_spd.py
index eef483e..6e30cab 100644
--- a/spd/run_spd.py
+++ b/spd/run_spd.py
@@ -17,7 +17,7 @@
from spd.configs import Config
from spd.hooks import HookedRootModule
from spd.models.base import SPDModel
-from spd.models.components import Gate, Linear, LinearComponent
+from spd.models.components import Gate, GateMLP, Linear, LinearComponent
from spd.module_utils import collect_nested_module_attrs, get_nested_module_attr
from spd.utils import calc_recon_mse, get_lr_schedule_fn, get_lr_with_warmup
@@ -143,7 +143,7 @@ def calc_act_recon_mse(
def calc_masks(
- gates: dict[str, Gate],
+ gates: dict[str, Gate | GateMLP],
target_component_acts: dict[
str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]
],
diff --git a/spd/wandb_utils.py b/spd/wandb_utils.py
index 73d2d67..9f451a7 100644
--- a/spd/wandb_utils.py
+++ b/spd/wandb_utils.py
@@ -45,6 +45,7 @@ def fetch_wandb_run_dir(run_id: str) -> Path:
"""
# Default to REPO_ROOT/wandb if SPD_CACHE_DIR not set
base_cache_dir = Path(os.environ.get("SPD_CACHE_DIR", REPO_ROOT / "wandb"))
+ base_cache_dir.mkdir(parents=True, exist_ok=True)
# Set default wandb_run_dir
wandb_run_dir = base_cache_dir / run_id / "files"
diff --git a/tests/test_module_utils.py b/tests/test_module_utils.py
new file mode 100644
index 0000000..a59643d
--- /dev/null
+++ b/tests/test_module_utils.py
@@ -0,0 +1,15 @@
+from torch import nn
+
+from spd.module_utils import get_nested_module_attr
+
+
+def test_get_nested_module_attr():
+ class TestModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear1 = nn.Linear(10, 10)
+ self.linear2 = nn.Linear(10, 10)
+
+ module = TestModule()
+ assert get_nested_module_attr(module, "linear1.weight.data").shape == (10, 10)
+ assert get_nested_module_attr(module, "linear2.weight.data").shape == (10, 10)