diff --git a/.vscode/launch.json b/.vscode/launch.json index 5ce7aef..a753e33 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -37,5 +37,29 @@ "PYDEVD_DISABLE_FILE_VALIDATION": "1" } }, + { + "name": "lm", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/spd/experiments/lm/lm_decomposition.py", + "args": "${workspaceFolder}/spd/experiments/lm/lm_config.yaml", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYDEVD_DISABLE_FILE_VALIDATION": "1" + } + }, + { + "name": "lm streamlit", + "type": "debugpy", + "request": "launch", + "module": "streamlit", + "args": [ + "run", + "${workspaceFolder}/spd/experiments/lm/app.py", + "--server.port", + "2000" + ] + } ] } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 49e6f8a..4031cfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,10 @@ dependencies = [ "python-dotenv", "wandb<=0.17.7", # due to https://github.com/wandb/wandb/issues/8248 "sympy", + "streamlit", + "streamlit-antd-components", + "datasets", + "simple-stories-train" ] [project.optional-dependencies] diff --git a/spd/configs.py b/spd/configs.py index 1622900..d039054 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -36,6 +36,20 @@ class ResidualMLPTaskConfig(BaseModel): pretrained_model_path: ModelPath # e.g. wandb:spd-resid-mlp/runs/j9kmavzi +class LMTaskConfig(BaseModel): + model_config = ConfigDict(extra="forbid", frozen=True) + task_name: Literal["lm"] = "lm" + model_size: str # e.g. "1.25M" + max_seq_len: PositiveInt = 512 + buffer_size: PositiveInt = 1000 + dataset_name: str = "lennart-finke/SimpleStories" + train_data_split: str = "train" + eval_data_split: str = "test" + n_eval_steps: PositiveInt = 100 + # List of fnmatch patterns for nn.Linear modules to decompose + target_module_patterns: list[str] = ["transformer.h.*.mlp.*_proj"] + + class Config(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) wandb_project: str | None = None @@ -68,7 +82,9 @@ class Config(BaseModel): unit_norm_matrices: bool = False attribution_type: Literal["gradient"] = "gradient" n_gate_hidden_neurons: PositiveInt | None = None - task_config: TMSTaskConfig | ResidualMLPTaskConfig = Field(..., discriminator="task_name") + task_config: TMSTaskConfig | ResidualMLPTaskConfig | LMTaskConfig = Field( + ..., discriminator="task_name" + ) DEPRECATED_CONFIG_KEYS: ClassVar[list[str]] = [] RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = {} diff --git a/spd/experiments/lm/app.py b/spd/experiments/lm/app.py new file mode 100644 index 0000000..d880198 --- /dev/null +++ b/spd/experiments/lm/app.py @@ -0,0 +1,405 @@ +""" +To run this app, run the following command: + +```bash + streamlit run spd/experiments/lm/app.py -- --model_path "wandb:spd-lm/runs/151bsctx" +``` +""" + +import argparse +import html +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from typing import Any + +import streamlit as st +import torch +from datasets import load_dataset +from jaxtyping import Float, Int +from simple_stories_train.dataloaders import DatasetConfig +from torch import Tensor +from transformers import AutoTokenizer + +from spd.configs import Config, LMTaskConfig +from spd.experiments.lm.models import LinearComponentWithBias, SSModel +from spd.log import logger +from spd.models.components import Gate, GateMLP +from spd.run_spd import calc_component_acts, calc_masks +from spd.types import ModelPath + +DEFAULT_MODEL_PATH: ModelPath = "wandb:spd-lm/runs/151bsctx" + + +# 
----------------------------------------------------------- +# Dataclass holding everything the app needs +# ----------------------------------------------------------- +@dataclass(frozen=True) +class AppData: + model: SSModel + tokenizer: AutoTokenizer + config: Config + dataloader_iter_fn: Callable[[], Iterator[dict[str, Any]]] + gates: dict[str, Gate | GateMLP] + components: dict[str, LinearComponentWithBias] + target_layer_names: list[str] + device: str + + +# --- Initialization and Data Loading --- +@st.cache_resource(show_spinner="Loading model and data...") +def initialize(model_path: ModelPath) -> AppData: + """ + Loads the model, tokenizer, config, and evaluation dataloader. + Cached by Streamlit based on the model_path. + """ + device = "cpu" # Use CPU for the Streamlit app + logger.info(f"Initializing app with model: {model_path} on device: {device}") + ss_model, config, _ = SSModel.from_pretrained(model_path) + ss_model.to(device) + ss_model.eval() + + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig), "Task config must be LMTaskConfig for this app." + + # Derive tokenizer path (adjust if stored differently) + tokenizer_path = f"chandan-sreedhara/SimpleStories-{task_config.model_size}" + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + add_bos_token=False, + unk_token="[UNK]", + eos_token="[EOS]", + bos_token=None, + ) + + # Create eval dataloader config + eval_data_config = DatasetConfig( + name=task_config.dataset_name, + tokenizer_file_path=None, + hf_tokenizer_path=tokenizer_path, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=False, + streaming=False, # Non-streaming might be simpler for iterator reset + column_name="story", + ) + + # Create the dataloader iterator + def create_dataloader_iter() -> Iterator[dict[str, Any]]: + """ + Returns a *new* iterator each time it is called. 
+ Each element is a dict with: + - "text": the raw document text + - "input_ids": Int[Tensor, "1 seq_len"] + - "offset_mapping": list[tuple[int, int]] + """ + logger.info("Creating new dataloader iterator.") + + # Stream the HF dataset split + dataset = load_dataset( + eval_data_config.name, + streaming=eval_data_config.streaming, + split=eval_data_config.split, + trust_remote_code=False, + ) + + dataset = dataset.with_format("torch") + + text_column = eval_data_config.column_name + + def tokenize_and_prepare(example: dict[str, Any]) -> dict[str, Any]: + original_text: str = example[text_column] + + tokenized = tokenizer( + original_text, + return_tensors="pt", + return_offsets_mapping=True, + truncation=True, + max_length=task_config.max_seq_len, + padding=False, + ) + + input_ids: Int[Tensor, "1 seq_len"] = tokenized["input_ids"] + if input_ids.dim() == 1: # Ensure 2-D [1, seq_len] + input_ids = input_ids.unsqueeze(0) + + # HF returns offset_mapping as a list per sequence; batch size is 1 + offset_mapping: list[tuple[int, int]] = tokenized["offset_mapping"][0].tolist() + + return { + "text": original_text, + "input_ids": input_ids, + "offset_mapping": offset_mapping, + } + + # Map over the streaming dataset and return an iterator + return map(tokenize_and_prepare, iter(dataset)) + + # Extract components and gates + gates: dict[str, Gate | GateMLP] = { + k.removeprefix("gates.").replace("-", "."): v for k, v in ss_model.gates.items() + } # type: ignore[reportAssignmentType] + components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items() + } # type: ignore[reportAssignmentType] + target_layer_names = sorted(list(components.keys())) + + logger.info(f"Initialization complete for {model_path}.") + return AppData( + model=ss_model, + tokenizer=tokenizer, + config=config, + dataloader_iter_fn=create_dataloader_iter, + gates=gates, + components=components, + target_layer_names=target_layer_names, + device=device, + ) + + +# ----------------------------------------------------------- +# Utility: render the prompt with faint token outlines +# ----------------------------------------------------------- +def render_prompt_with_tokens( + *, + raw_text: str, + offset_mapping: list[tuple[int, int]], + selected_idx: int | None, +) -> None: + """ + Renders `raw_text` inside Streamlit, wrapping each token span with a thin + border. The currently-selected token receives a thicker red border. + All other tokens get a thin mid-grey border (no background fill). + """ + html_chunks: list[str] = [] + cursor = 0 + + def esc(s: str) -> str: + return html.escape(s) + + for idx, (start, end) in enumerate(offset_mapping): + if cursor < start: + html_chunks.append(esc(raw_text[cursor:start])) + + token_substr = esc(raw_text[start:end]) + if token_substr: + is_selected = idx == selected_idx + border_style = ( + "2px solid rgb(200,0,0)" if is_selected else "0.5px solid #aaa" # all other tokens + ) + html_chunks.append( + f'<span style="border: {border_style}">' + f"{token_substr}</span>" + ) + cursor = end + + if cursor < len(raw_text): + html_chunks.append(esc(raw_text[cursor:])) + + st.markdown( + f'<div style="line-height: 1.8">{"".join(html_chunks)}</div>
', + unsafe_allow_html=True, + ) + + +def load_next_prompt() -> None: + """Loads the next prompt, calculates masks, and prepares token data.""" + logger.info("Loading next prompt.") + app_data: AppData = st.session_state.app_data + dataloader_iter = st.session_state.dataloader_iter # Get current iterator + + try: + batch = next(dataloader_iter) + input_ids: Int[Tensor, "1 seq_len"] = batch["input_ids"].to(app_data.device) + except StopIteration: + logger.warning("Dataloader iterator exhausted.") + st.error("No more prompts available: the evaluation dataset is exhausted.") + return + + st.session_state.current_input_ids = input_ids + + # Store the original raw prompt text + st.session_state.current_prompt_text = batch["text"] + + # Calculate activations and masks + with torch.no_grad(): + (_, _), pre_weight_acts = app_data.model.forward_with_pre_forward_cache_hooks( + input_ids, module_names=list(app_data.components.keys()) + ) + As = {module_name: v.linear_component.A for module_name, v in app_data.components.items()} + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore[reportArgumentType] + masks, _ = calc_masks( + gates=app_data.gates, + target_component_acts=target_component_acts, + attributions=None, + detach_inputs=True, # No gradients needed + ) + st.session_state.current_masks = masks # Dict[str, Float[Tensor, "1 seq_len m"]] + + # Prepare token data for display + token_data = [] + tokenizer = app_data.tokenizer + for i, token_id in enumerate(input_ids[0]): + # Decode individual token - might differ slightly from full decode for spaces etc. + decoded_token_str = tokenizer.decode([token_id]) # type: ignore[reportAttributeAccessIssue] + token_data.append( + { + "id": token_id.item(), + "text": decoded_token_str, + "index": i, + "offset": batch["offset_mapping"][i], # (start, end) + } + ) + st.session_state.token_data = token_data + + # Reset selections + st.session_state.selected_token_index = 0 # default: first token + st.session_state.selected_layer_name = None + logger.info("Finished loading next prompt and calculating masks.") + + +# --- Main App UI --- +def run_app(args: argparse.Namespace) -> None: + """Sets up and runs the Streamlit application.""" + st.set_page_config(layout="wide") + st.title("LM Component Activation Explorer") + + # Initialize model, data, etc. 
(cached) + st.session_state.app_data = initialize(args.model_path) + app_data: AppData = st.session_state.app_data + st.caption(f"Model: {args.model_path}") + + # Initialize session state variables if they don't exist + if "current_prompt_text" not in st.session_state: + st.session_state.current_prompt_text = None + if "token_data" not in st.session_state: + st.session_state.token_data = None + if "current_masks" not in st.session_state: + st.session_state.current_masks = None + if "selected_token_index" not in st.session_state: + st.session_state.selected_token_index = None + if "selected_layer_name" not in st.session_state: + if app_data.target_layer_names: + st.session_state.selected_layer_name = app_data.target_layer_names[0] + else: + st.session_state.selected_layer_name = None + # Initialize the dataloader iterator in session state + if "dataloader_iter" not in st.session_state: + st.session_state.dataloader_iter = app_data.dataloader_iter_fn() + + if st.session_state.current_prompt_text is None: + load_next_prompt() + + # Sidebar container and a single expander for all interactive controls + sidebar = st.sidebar + controls_expander = sidebar.expander("Controls", expanded=True) + + # ------------------------------------------------------------------ + # Sidebar – interactive controls + # ------------------------------------------------------------------ + with controls_expander: + st.button("Load Next Prompt", on_click=load_next_prompt) + + # Render the raw prompt with faint token borders + if st.session_state.token_data and st.session_state.current_prompt_text: + # st.subheader("Prompt") + render_prompt_with_tokens( + raw_text=st.session_state.current_prompt_text, + offset_mapping=[t["offset"] for t in st.session_state.token_data], + selected_idx=st.session_state.selected_token_index, + ) + + # Sidebar slider for token selection + n_tokens = len(st.session_state.token_data) + if n_tokens > 0: + with controls_expander: + st.header("Token selector") + idx = st.slider( + "Token index", + min_value=0, + max_value=n_tokens - 1, + step=1, + key="selected_token_index", + ) + + selected_token = st.session_state.token_data[idx] + st.write(f"Selected token: {selected_token['text']} (ID: {selected_token['id']})") + + st.divider() + + # --- Token Information Area --- + if st.session_state.token_data: + idx = st.session_state.selected_token_index + # Ensure token_data is loaded before accessing + if ( + st.session_state.token_data + and idx is not None + and idx < len(st.session_state.token_data) + ): + # Layer Selection Dropdown + # Always default to the first layer if nothing is selected yet + if st.session_state.selected_layer_name is None and app_data.target_layer_names: + st.session_state.selected_layer_name = app_data.target_layer_names[0] + + with controls_expander: + st.header("Layer selector") + st.selectbox( + "Select Layer to Inspect:", + options=app_data.target_layer_names, + key="selected_layer_name", + ) + + # Display Layer-Specific Info if a layer is selected + if st.session_state.selected_layer_name: + layer_name = st.session_state.selected_layer_name + logger.debug(f"Displaying info for token {idx}, layer {layer_name}") + + if st.session_state.current_masks is None: + st.warning("Masks not calculated yet. 
Please load a prompt.") + return + + layer_mask_tensor: Float[Tensor, "1 seq_len m"] = st.session_state.current_masks[ + layer_name + ] + token_mask: Float[Tensor, " m"] = layer_mask_tensor[0, idx, :] + + # Find active components (mask > 0) + active_indices_layer: Int[Tensor, " n_active"] = torch.where(token_mask > 0)[0] + n_active_layer = len(active_indices_layer) + + st.metric(f"Active Components in {layer_name}", n_active_layer) + + st.subheader("Active Component Indices") + if n_active_layer > 0: + # Convert to NumPy array and reshape to a column vector (N x 1) + active_indices_np = active_indices_layer.cpu().numpy().reshape(-1, 1) + # Pass the NumPy array directly and configure the column header + st.dataframe(active_indices_np, height=300, use_container_width=False) + else: + st.write("No active components for this token in this layer.") + + # Extensibility Placeholder + st.subheader("Additional Layer/Token Analysis") + st.write( + "Future figures and analyses for this specific layer and token will appear here." + ) + else: + # Handle case where selected_token_index might be invalid after data reload + st.warning("Selected token index is out of bounds. Please select a token again.") + st.session_state.selected_token_index = None # Reset selection + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Streamlit app to explore LM component activations." + ) + parser.add_argument( + "--model_path", + type=str, + default=DEFAULT_MODEL_PATH, + help=f"Path or W&B reference to the trained SSModel. Default: {DEFAULT_MODEL_PATH}", + ) + args = parser.parse_args() + + run_app(args) diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py new file mode 100644 index 0000000..1258495 --- /dev/null +++ b/spd/experiments/lm/component_viz.py @@ -0,0 +1,167 @@ +""" +Visualises the components of the model. +""" + +import math + +import torch +from jaxtyping import Float +from matplotlib import pyplot as plt +from simple_stories_train.dataloaders import DatasetConfig, create_data_loader +from torch import Tensor +from torch.utils.data import DataLoader + +from spd.configs import LMTaskConfig +from spd.experiments.lm.models import LinearComponentWithBias, SSModel +from spd.log import logger +from spd.models.components import Gate, GateMLP +from spd.run_spd import calc_component_acts, calc_masks +from spd.types import ModelPath + + +def component_activation_statistics( + model: SSModel, + dataloader: DataLoader[Float[Tensor, "batch pos"]], + n_steps: int, + device: str, +) -> tuple[dict[str, float], dict[str, Float[Tensor, " m"]]]: + """Compute the mean number of active components per token and the mean activation count per component over the dataset.""" + # We used "-" instead of "." 
in them + gates: dict[str, Gate | GateMLP] = { + k.removeprefix("gates.").replace("-", "."): v for k, v in model.gates.items() + } # type: ignore + components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in model.components.items() + } # type: ignore + + n_tokens = {module_name.replace("-", "."): 0 for module_name in components} + total_n_active_components = {module_name.replace("-", "."): 0 for module_name in components} + component_activation_counts = { + module_name.replace("-", "."): torch.zeros(model.m, device=device) + for module_name in components + } + data_iter = iter(dataloader) + for _ in range(n_steps): + # --- Get Batch --- # + batch = next(data_iter)["input_ids"].to(device) + + _, pre_weight_acts = model.forward_with_pre_forward_cache_hooks( + batch, module_names=list(components.keys()) + ) + As = {module_name: v.linear_component.A for module_name, v in components.items()} + + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore + + masks, relud_masks = calc_masks( + gates=gates, + target_component_acts=target_component_acts, + attributions=None, + detach_inputs=False, + ) + for module_name, mask in masks.items(): + assert mask.ndim == 3 # (batch_size, pos, m) + n_tokens[module_name] += mask.shape[0] * mask.shape[1] + # Count the number of components that are active at all + active_components = mask > 0 + total_n_active_components[module_name] += int(active_components.sum().item()) + component_activation_counts[module_name] += active_components.sum(dim=(0, 1)) + + # Show the mean number of components + mean_n_active_components_per_token: dict[str, float] = { + module_name: (total_n_active_components[module_name] / n_tokens[module_name]) + for module_name in components + } + mean_component_activation_counts: dict[str, Float[Tensor, " m"]] = { + module_name: component_activation_counts[module_name] / n_tokens[module_name] + for module_name in components + } + + return mean_n_active_components_per_token, mean_component_activation_counts + + +def plot_mean_component_activation_counts( + mean_component_activation_counts: dict[str, Float[Tensor, " m"]], +) -> plt.Figure: + """Plots the mean activation counts for each component module in a grid.""" + n_modules = len(mean_component_activation_counts) + max_cols = 6 + n_cols = min(n_modules, max_cols) + # Calculate the number of rows needed, rounding up + n_rows = math.ceil(n_modules / n_cols) + + # Create a figure with the calculated number of rows and columns + fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False) + # Ensure axs is always a 2D array for consistent indexing, even if n_modules is 1 + axs = axs.flatten() # Flatten the axes array for easy iteration + + # Iterate through modules and plot each histogram on its corresponding axis + for i, (module_name, counts) in enumerate(mean_component_activation_counts.items()): + ax = axs[i] + ax.hist(counts.detach().cpu().numpy(), bins=100) + ax.set_title(module_name) # Add module name as title to each subplot + ax.set_xlabel("Mean Activation Count") + ax.set_ylabel("Frequency") + + # Hide any unused subplots if the grid isn't perfectly filled + for i in range(n_modules, n_rows * n_cols): + axs[i].axis("off") + + # Adjust layout to prevent overlapping titles/labels + fig.tight_layout() + + return fig + + +def main(path: ModelPath) -> None: + device = "cuda" if torch.cuda.is_available() else "cpu" + ss_model, config, checkpoint_path = 
SSModel.from_pretrained(path) + ss_model.to(device) + + out_dir = checkpoint_path + + assert isinstance(config.task_config, LMTaskConfig) + dataset_config = DatasetConfig( + name=config.task_config.dataset_name, + tokenizer_file_path=None, + hf_tokenizer_path=f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}", + split=config.task_config.train_data_split, + n_ctx=config.task_config.max_seq_len, + is_tokenized=False, + streaming=False, + column_name="story", + ) + + dataloader, tokenizer = create_data_loader( + dataset_config=dataset_config, + batch_size=config.batch_size, + buffer_size=config.task_config.buffer_size, + global_seed=config.seed, + ddp_rank=0, + ddp_world_size=1, + ) + # print(ss_model) + print(config) + + mean_n_active_components_per_token, mean_component_activation_counts = ( + component_activation_statistics( + model=ss_model, + dataloader=dataloader, + n_steps=100, + device=device, + ) + ) + logger.info(f"n_components: {ss_model.m}") + logger.info(f"mean_n_active_components_per_token: {mean_n_active_components_per_token}") + logger.info(f"mean_component_activation_counts: {mean_component_activation_counts}") + fig = plot_mean_component_activation_counts( + mean_component_activation_counts=mean_component_activation_counts, + ) + # Save the entire figure once + save_path = out_dir / "modules_mean_component_activation_counts.png" + fig.savefig(save_path) + logger.info(f"Saved combined plot to {str(save_path)}") + + +if __name__ == "__main__": + path = "wandb:spd-lm/runs/151bsctx" + main(path) diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml new file mode 100644 index 0000000..0b1260f --- /dev/null +++ b/spd/experiments/lm/lm_config.yaml @@ -0,0 +1,72 @@ +# --- WandB --- +wandb_project: spd-lm +# wandb_project: null # Project name for Weights & Biases +wandb_run_name: null # Set specific run name (optional, otherwise generated) +wandb_run_name_prefix: "" # Prefix for generated run name + +# --- General --- +seed: 0 +unit_norm_matrices: false # Whether to enforce unit norm on A matrices (not typically used here) +m: 10000 # Rank of the decomposition / number of components per layer + +# --- Loss Coefficients --- +# Set coeffs to null if the loss shouldn't be computed +param_match_coeff: 1.0 +out_recon_coeff: 0.0 # Reconstruction loss based on output logits (MSE) +lp_sparsity_coeff: 1e-1 # Coefficient for Lp sparsity loss (applied to the ReLU'd masks) +pnorm: 2.0 # p-value for the Lp sparsity norm +layerwise_random_recon_coeff: 1 # Layer-wise reconstruction loss with random masks + +# Losses disabled for this run (set a coeff to null to skip that loss) +masked_recon_coeff: null # Reconstruction loss using masks +act_recon_coeff: null # Reconstruction loss on intermediate component activations +random_mask_recon_coeff: null # Reconstruction loss averaged over random masks +layerwise_recon_coeff: null # Layer-wise reconstruction loss + +n_random_masks: 1 # Number of random masks if random_mask_recon_coeff is used +n_gate_hidden_neurons: null # null => plain Gate; set an int to use GateMLP with hidden neurons + +# --- Training --- +batch_size: 4 # Adjust based on GPU memory +steps: 1_000 # Total training steps +lr: 1e-3 # Learning rate +lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential) +lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup +lr_exponential_halflife: null # Required if lr_schedule is exponential +init_from_target_model: false # Not implemented/applicable for this setup + +# 
--- Logging & Saving --- +image_freq: 1000 # Frequency for generating/logging plots +print_freq: 100 # Frequency for printing logs to console +save_freq: 1_000 # Frequency for saving checkpoints +image_on_first_step: true # Whether to log plots at step 0 + +# --- Task Specific --- +task_config: + task_name: lm # Specifies the LM decomposition task + model_size: "1.25M" # SimpleStories model size (e.g., "1.25M", "5M", "11M", "30M", "35M") + max_seq_len: 512 # Maximum sequence length for truncation/padding + buffer_size: 1000 # Buffer size for streaming dataset shuffling + dataset_name: "lennart-finke/SimpleStories" # HuggingFace dataset name + train_data_split: "train" # Dataset split to use + eval_data_split: "test" # Dataset split to use + n_eval_steps: 100 # Number of evaluation steps + # List of fnmatch patterns for nn.Linear modules to decompose + target_module_patterns: ["transformer.h.0.mlp.gate_proj"] + # Example: Decompose only gate_proj: ["transformer.h.*.mlp.gate_proj"] + # Example: Decompose gate_proj and up_proj: ["transformer.h.*.mlp.gate_proj", "transformer.h.*.mlp.up_proj"] + # Example: Decompose all MLP layers: ["transformer.h.*.mlp.*_proj"] + +# Config details for the target model taken from https://github.com/danbraunai/simple_stories_train/blob/main/simple_stories_train/models/model_configs.py#L54 + # "1.25M": LlamaConfig( + # block_size=512, + # vocab_size=4096, + # n_layer=4, + # n_head=4, + # n_embd=128, + # n_intermediate=128 * 4 * 2 // 3 = 341, + # rotary_dim=128 // 4 = 32, + # n_ctx=512, + # n_key_value_heads=2, + # flash_attention=True, + # ), \ No newline at end of file diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py new file mode 100644 index 0000000..41f5ce6 --- /dev/null +++ b/spd/experiments/lm/lm_decomposition.py @@ -0,0 +1,545 @@ +"""Language Model decomposition script.""" + +from collections.abc import Callable +from datetime import datetime +from pathlib import Path + +import einops +import fire +import matplotlib.pyplot as plt +import torch +import torch.optim as optim +import wandb +import yaml +from jaxtyping import Float +from simple_stories_train.dataloaders import DatasetConfig, create_data_loader +from simple_stories_train.models.llama import Llama +from simple_stories_train.models.model_configs import MODEL_CONFIGS +from torch import Tensor +from torch.utils.data import DataLoader +from tqdm import tqdm + +from spd.configs import Config, LMTaskConfig +from spd.experiments.lm.component_viz import ( + component_activation_statistics, + plot_mean_component_activation_counts, +) +from spd.experiments.lm.models import LinearComponentWithBias, SSModel +from spd.log import logger +from spd.models.components import Gate, GateMLP +from spd.run_spd import ( + _calc_param_mse, + calc_component_acts, + calc_mask_l_zero, + calc_masks, + calc_random_masks, + get_common_run_name_suffix, +) +from spd.utils import ( + get_device, + get_lr_schedule_fn, + get_lr_with_warmup, + load_config, + set_seed, +) +from spd.wandb_utils import init_wandb + +wandb.require("core") + + +def get_run_name( + config: Config, + model_size: str, + max_seq_len: int, +) -> str: + """Generate a run name based on the config.""" + run_suffix = "" + if config.wandb_run_name: + run_suffix = config.wandb_run_name + else: + run_suffix = get_common_run_name_suffix(config) + run_suffix += f"_lm{model_size}_seq{max_seq_len}" + return config.wandb_run_name_prefix + run_suffix + + +def lm_plot_results_fn( + model: SSModel, + components: dict[str, 
LinearComponentWithBias], + step: int | None, + out_dir: Path | None, + device: str, + config: Config, + **_, +) -> dict[str, plt.Figure]: + """Plotting function for LM decomposition. Placeholder for now.""" + # TODO: Implement actual plotting (e.g., component matrix values?) + logger.info(f"Plotting results at step {step}...") + fig_dict: dict[str, plt.Figure] = {} + # Example: Potentially plot A/B matrix norms or sparsity patterns? + # fig_dict["component_norms"] = plot_component_norms(components, out_dir, step) + return fig_dict + + +def calc_recon_mse_lm( + out1: Float[Tensor, "batch pos vocab"], + out2: Float[Tensor, "batch pos vocab"], +) -> Float[Tensor, ""]: + """Calculate the Mean Squared Error reconstruction loss for LM logits.""" + assert out1.shape == out2.shape + # Sum over vocab, then mean over batch and sequence length + return ((out1 - out2) ** 2).sum(dim=-1).mean() + + +def calc_param_match_loss_lm( + components: dict[str, LinearComponentWithBias], + target_model: Llama, + n_params: int, + device: str, +) -> Float[Tensor, ""]: + """Calculate the MSE loss between component weights (A@B) and the target model's weights.""" + target_params: dict[str, Float[Tensor, "d_in d_out"]] = {} + component_params: dict[str, Float[Tensor, "d_in d_out"]] = {} + + for comp_name, component in components.items(): + component_params[comp_name] = einops.einsum( + component.linear_component.A, + component.linear_component.B, + "d_in m, m d_out -> d_in d_out", + ) + target_params[comp_name] = target_model.get_parameter(comp_name + ".weight").T + assert component_params[comp_name].shape == target_params[comp_name].shape + + param_mse = _calc_param_mse( + params1=component_params, + params2=target_params, + n_params=n_params, + device=device, + ) + return param_mse + + +def calc_layerwise_recon_loss_lm( + model: SSModel, + batch: Float[Tensor, "batch pos"], + device: str, + components: dict[str, LinearComponentWithBias], + masks: list[dict[str, Float[Tensor, "batch pos m"]]], + target_out: Float[Tensor, "batch pos vocab"], +) -> Float[Tensor, ""]: + """Calculate the recon loss when augmenting the model one (masked) component at a time.""" + total_loss = torch.tensor(0.0, device=device) + for mask_info in masks: + for component_name, component in components.items(): + module_name = component_name.replace("-", ".") + modified_out, _ = model.forward_with_component( + batch, + module_name=module_name, + component=component, + mask=mask_info.get(component_name, None), + ) + loss = calc_recon_mse_lm(modified_out, target_out) + total_loss += loss + n_modified_components = len(masks[0]) + return total_loss / (n_modified_components * len(masks)) + + +def calc_lp_sparsity_loss_lm( + relud_masks: dict[str, Float[Tensor, "batch pos m"]], pnorm: float +) -> Float[Tensor, ""]: + """Calculate the Lp sparsity loss on the ReLU'd masks. + + Args: + relud_masks: Dictionary of ReLU'd masks for each layer. + pnorm: The pnorm to use for the sparsity loss. + Returns: + The Lp sparsity loss. 
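+ Computed as relud_masks[layer] ** pnorm summed over all layers and over the + component dimension, then averaged over the batch and position dimensions. 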
+ """ + # Initialize with zeros matching the shape of first mask + total_loss = torch.zeros_like(next(iter(relud_masks.values()))) + + for layer_relud_mask in relud_masks.values(): + total_loss = total_loss + layer_relud_mask**pnorm + + # Sum over the m dimension and mean over the batch and pos dimensions + return total_loss.sum(dim=-1).mean(dim=[0, 1]) + + +def optimize_lm( + model: SSModel, + config: Config, + device: str, + train_loader: DataLoader[Float[Tensor, "batch pos"]], + eval_loader: DataLoader[Float[Tensor, "batch pos"]], + n_eval_steps: int, + plot_results_fn: Callable[..., dict[str, plt.Figure]], + out_dir: Path | None, +) -> None: + """Run the optimization loop for LM decomposition.""" + + # We used "-" instead of "." as module names can't have "." in them + gates: dict[str, Gate | GateMLP] = { + k.removeprefix("gates.").replace("-", "."): v for k, v in model.gates.items() + } # type: ignore + components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in model.components.items() + } # type: ignore + + component_params = [] + param_names_to_optimize = [] + for name, component in components.items(): + component_params.extend(list(component.parameters())) + param_names_to_optimize.extend( + [f"{name}.{p_name}" for p_name, _ in component.named_parameters()] + ) + logger.debug(f"Adding parameters from component: {name}") + + if not component_params: + logger.error("No parameters found in components to optimize. Exiting.") + return + + optimizer = optim.AdamW(component_params, lr=config.lr, weight_decay=0.0) + logger.info(f"Optimizer created for params: {param_names_to_optimize}") + + lr_schedule_fn = get_lr_schedule_fn(config.lr_schedule, config.lr_exponential_halflife) + logger.info(f"Base LR scheduler created: {config.lr_schedule}") + + n_params = 0 + for module_name in components: + weight = model.model.get_parameter(module_name + ".weight") + n_params += weight.numel() + + log_data = {} + data_iter = iter(train_loader) + + # Use tqdm directly in the loop, iterate one extra step for final logging/plotting/saving + for step in tqdm(range(config.steps + 1), ncols=0): + # --- LR Scheduling Step --- # + step_lr = get_lr_with_warmup( + step=step, + steps=config.steps, + lr=config.lr, + lr_schedule_fn=lr_schedule_fn, + lr_warmup_pct=config.lr_warmup_pct, + ) + # Manually update optimizer's learning rate + for group in optimizer.param_groups: + group["lr"] = step_lr + log_data["lr"] = step_lr + + # --- Zero Gradients --- # + optimizer.zero_grad() + + # --- Get Batch --- # + try: + batch = next(data_iter)["input_ids"].to(device) + except StopIteration: + logger.warning("Dataloader exhausted, resetting iterator.") + data_iter = iter(train_loader) + batch = next(data_iter)["input_ids"].to(device) + + (target_out, _), pre_weight_acts = model.forward_with_pre_forward_cache_hooks( + batch, module_names=list(components.keys()) + ) + As = {module_name: v.linear_component.A for module_name, v in components.items()} + + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore + # attributions = calc_grad_attributions( + # target_out=target_out, + # pre_weight_acts=pre_weight_acts, + # post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")}, + # target_component_acts=target_component_acts, + # Bs=collect_nested_module_attrs(model, attr_name="B", include_attr_name=False), + # ) + attributions = None + + masks, relud_masks = calc_masks( + gates=gates, + 
target_component_acts=target_component_acts, + attributions=attributions, + detach_inputs=False, + ) + + # --- Calculate Losses --- # + total_loss = torch.tensor(0.0, device=device) + loss_terms = {} + + ####### param match loss ####### + param_match_loss_val = calc_param_match_loss_lm( + components=components, + target_model=model.model, + n_params=n_params, + device=device, + ) + total_loss += config.param_match_coeff * param_match_loss_val + loss_terms["loss/parameter_matching"] = param_match_loss_val.item() + + ####### layerwise recon loss ####### + if config.layerwise_recon_coeff is not None: + layerwise_recon_loss = calc_layerwise_recon_loss_lm( + model=model, + batch=batch, + device=device, + components=components, + masks=[masks], + target_out=target_out, + ) + total_loss += config.layerwise_recon_coeff * layerwise_recon_loss + loss_terms["loss/layerwise_reconstruction"] = layerwise_recon_loss.item() + + ####### layerwise random recon loss ####### + if config.layerwise_random_recon_coeff is not None: + layerwise_random_masks = calc_random_masks( + masks=masks, n_random_masks=config.n_random_masks + ) + layerwise_random_recon_loss = calc_layerwise_recon_loss_lm( + model=model, + batch=batch, + device=device, + components=components, + masks=layerwise_random_masks, + target_out=target_out, + ) + total_loss += config.layerwise_random_recon_coeff * layerwise_random_recon_loss + loss_terms["loss/layerwise_random_reconstruction"] = layerwise_random_recon_loss.item() + + ####### lp sparsity loss ####### + lp_sparsity_loss = calc_lp_sparsity_loss_lm(relud_masks=relud_masks, pnorm=config.pnorm) + total_loss += config.lp_sparsity_coeff * lp_sparsity_loss + loss_terms["loss/lp_sparsity_loss"] = lp_sparsity_loss.item() + + ####### out recon loss ####### + if config.out_recon_coeff is not None: + # Get target logits (no gradients needed for target model) + with torch.no_grad(): + target_logits, _ = model.forward(batch) + # Detach target logits to ensure no grads flow back + target_logits = target_logits.detach() + + # Get component logits + component_logits, _ = model.forward_with_components( + batch, components=components, masks=masks + ) + + assert component_logits.shape == target_logits.shape, ( + f"Shape mismatch: {component_logits.shape} vs {target_logits.shape}" + ) + + recon_loss = calc_recon_mse_lm(component_logits, target_logits) + total_loss += config.out_recon_coeff * recon_loss + loss_terms["loss/reconstruction"] = recon_loss.item() + + log_data["loss/total"] = total_loss.item() + log_data.update(loss_terms) + + # --- Logging --- # + if step % config.print_freq == 0: + tqdm.write(f"--- Step {step} ---") + tqdm.write(f"LR: {step_lr:.6f}") + tqdm.write(f"Total Loss: {log_data['loss/total']:.7f}") + for name, value in loss_terms.items(): + if value is not None: + tqdm.write(f"{name}: {value:.7f}") + + mean_n_active_components_per_token = component_activation_statistics( + model=model, dataloader=eval_loader, n_steps=n_eval_steps, device=device + )[0] + tqdm.write(f"Mean n active components per token: {mean_n_active_components_per_token}") + + if config.wandb_project: + mask_l_zero = calc_mask_l_zero(masks=masks) + for layer_name, layer_mask_l_zero in mask_l_zero.items(): + log_data[f"{layer_name}/mask_l0"] = layer_mask_l_zero + log_data[f"{layer_name}/mean_n_active_components_per_token"] = ( + mean_n_active_components_per_token[layer_name] + ) + wandb.log(log_data, step=step) + + # --- Plotting --- # + if ( + config.image_freq is not None + and step % config.image_freq == 0 + and (step 
> 0 or config.image_on_first_step) + ): + logger.info(f"Step {step}: Generating plots...") + with torch.no_grad(): + fig_dict = plot_results_fn( + model=model, # Pass the SSModel wrapper + components=components, + step=step, + out_dir=out_dir, + device=device, + config=config, + # Add any other necessary args for plotting like tokenizer, sample text? + ) + mean_component_activation_counts = component_activation_statistics( + model=model, dataloader=eval_loader, n_steps=n_eval_steps, device=device + )[1] + fig_dict["mean_component_activation_counts"] = ( + plot_mean_component_activation_counts( + mean_component_activation_counts=mean_component_activation_counts, + ) + ) + if config.wandb_project: + wandb.log( + {k: wandb.Image(v) for k, v in fig_dict.items()}, + step=step, + ) + if out_dir is not None: + for k, v in fig_dict.items(): + v.savefig(out_dir / f"{k}_{step}.png") + tqdm.write(f"Saved plot to {out_dir / f'{k}_{step}.png'}") + + # --- Saving Checkpoint --- # + if ( + (config.save_freq is not None and step % config.save_freq == 0 and step > 0) + or step == config.steps + ) and out_dir is not None: + torch.save(model.state_dict(), out_dir / f"model_{step}.pth") + torch.save(optimizer.state_dict(), out_dir / f"optimizer_{step}.pth") + logger.info(f"Saved model, optimizer, and out_dir to {out_dir}") + if config.wandb_project: + wandb.save(str(out_dir / f"model_{step}.pth"), base_path=str(out_dir), policy="now") + wandb.save( + str(out_dir / f"optimizer_{step}.pth"), base_path=str(out_dir), policy="now" + ) + + # --- Backward Pass & Optimize --- # + # Skip gradient step if we are at the last step (last step just for plotting and logging) + if step != config.steps: + total_loss.backward(retain_graph=True) + + if step % config.print_freq == 0 and config.wandb_project: + # Calculate gradient norm + grad_norm: float = 0.0 + for param in model.parameters(): + if param.grad is not None: + grad_norm += param.grad.data.norm() # type: ignore + wandb.log({"grad_norm": grad_norm}, step=step) + + if config.unit_norm_matrices: + model.fix_normalized_adam_gradients() + + optimizer.step() + logger.info("Finished training loop.") + + +def main( + config_path_or_obj: Path | str | Config, sweep_config_path: Path | str | None = None +) -> None: + config = load_config(config_path_or_obj, config_model=Config) + + if config.wandb_project: + config = init_wandb(config, config.wandb_project, sweep_config_path) + + set_seed(config.seed) + logger.info(config) + + device = get_device() + logger.info(f"Using device: {device}") + assert isinstance(config.task_config, LMTaskConfig), ( + "Task config must be LMTaskConfig for LM decomposition." 
+ ) + + # --- Load Model --- # + logger.info(f"Loading model: {config.task_config.model_size}") + model_config_dict = MODEL_CONFIGS[config.task_config.model_size] + model_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}" + model = Llama.from_pretrained(model_path, model_config_dict) + + ss_model = SSModel( + llama_model=model, + target_module_patterns=config.task_config.target_module_patterns, + m=config.m, + n_gate_hidden_neurons=config.n_gate_hidden_neurons, + ) + ss_model.to(device) + logger.info("Model loaded.") + + # --- Setup Run Name and Output Dir --- # + run_name = get_run_name( + config, + model_size=config.task_config.model_size, + max_seq_len=config.task_config.max_seq_len, + ) + if config.wandb_project: + assert wandb.run, "wandb.run must be initialized before training" + wandb.run.name = run_name + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] + out_dir = Path(__file__).parent / "out" / f"{run_name}_{timestamp}" + out_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Output directory: {out_dir}") + + # --- Save Config --- # + with open(out_dir / "final_config.yaml", "w") as f: + yaml.dump(config.model_dump(mode="json"), f, indent=2) + if config.wandb_project: + wandb.save(str(out_dir / "final_config.yaml"), base_path=out_dir, policy="now") + + # --- Load Data --- # + logger.info("Loading dataset...") + train_data_config = DatasetConfig( + name=config.task_config.dataset_name, + tokenizer_file_path=None, + hf_tokenizer_path=model_path, + split=config.task_config.train_data_split, + n_ctx=config.task_config.max_seq_len, + is_tokenized=False, + streaming=False, + column_name="story", + ) + + train_loader, tokenizer = create_data_loader( + dataset_config=train_data_config, + batch_size=config.batch_size, + buffer_size=config.task_config.buffer_size, + global_seed=config.seed, + ddp_rank=0, + ddp_world_size=1, + ) + + eval_data_config = DatasetConfig( + name=config.task_config.dataset_name, + tokenizer_file_path=None, + hf_tokenizer_path=model_path, + split=config.task_config.eval_data_split, + n_ctx=config.task_config.max_seq_len, + is_tokenized=False, + streaming=False, + column_name="story", + ) + eval_loader, _ = create_data_loader( + dataset_config=eval_data_config, + batch_size=config.batch_size, + buffer_size=config.task_config.buffer_size, + global_seed=config.seed, + ddp_rank=0, + ddp_world_size=1, + ) + + logger.info("Dataset and tokenizer loaded.") + + logger.info("Freezing target model parameters...") + for param in ss_model.model.parameters(): + param.requires_grad = False + logger.info("Target model frozen.") + + logger.info("Starting optimization...") + optimize_lm( + model=ss_model, + config=config, + device=device, + train_loader=train_loader, + eval_loader=eval_loader, + n_eval_steps=config.task_config.n_eval_steps, + out_dir=out_dir, + plot_results_fn=lm_plot_results_fn, + ) + + logger.info("Optimization finished.") + + if config.wandb_project: + wandb.finish() + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py new file mode 100644 index 0000000..eafc9b4 --- /dev/null +++ b/spd/experiments/lm/models.py @@ -0,0 +1,267 @@ +""" +Defines a SSModel class that is a wrapper around a llama model from SimpleStories +""" + +import fnmatch +from functools import partial +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn +import wandb +import yaml +from jaxtyping import Float +from pydantic import BaseModel +from 
simple_stories_train.models.llama import Llama +from simple_stories_train.models.model_configs import MODEL_CONFIGS +from torch import Tensor +from wandb.apis.public import Run + +from spd.configs import Config, LMTaskConfig +from spd.models.components import Gate, GateMLP, LinearComponent +from spd.types import WANDB_PATH_PREFIX, ModelPath +from spd.wandb_utils import ( + download_wandb_file, + fetch_latest_wandb_checkpoint, + fetch_wandb_run_dir, +) + + +class LinearComponentWithBias(nn.Module): + """A LinearComponent with a bias parameter.""" + + def __init__(self, linear_component: LinearComponent, bias: Tensor | None): + super().__init__() + self.linear_component = linear_component + self.bias = bias + self.mask: Float[Tensor, "... m"] | None = None # Gets set on sparse forward passes + + def forward(self, x: Float[Tensor, "... d_in"]) -> Float[Tensor, "... d_out"]: + # Note: We assume bias is added *after* the component multiplication + # Also assume input is (batch, seq_len, d_in) + out = self.linear_component(x, mask=self.mask) + if self.bias is not None: + out += self.bias + return out + + +def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponentWithBias: + """Replace a nn.Linear module with a LinearComponentWithBias module.""" + d_out, d_in = linear_module.weight.shape + + linear_component = LinearComponent(d_in=d_in, d_out=d_out, m=m, n_instances=None) + + # # Initialize with A = W (original weights) and B = I (identity) + # # This provides a starting point where the component exactly equals the original + # linear_component.A.data[:] = linear_module.weight.t() # (d_in, m) + # linear_component.B.data[:] = torch.eye(m) + + bias = linear_module.bias.clone() if linear_module.bias is not None else None # type: ignore + + return LinearComponentWithBias(linear_component, bias) + + +class SSModelPaths(BaseModel): + """Paths to output files from a SSModel training run.""" + + model: Path + optimizer: Path + config: Path + + +class SSModel(nn.Module): + """Wrapper around a llama model from SimpleStories for running SPD.""" + + def __init__( + self, + llama_model: Llama, + target_module_patterns: list[str], + m: int, + n_gate_hidden_neurons: int | None, + ): + super().__init__() + self.model = llama_model + self.m = m + self.components = self.create_target_components( + target_module_patterns=target_module_patterns, m=m + ) + + # Use GateMLP if n_gate_hidden_neurons is provided, otherwise use Gate + gate_class = GateMLP if n_gate_hidden_neurons is not None else Gate + gate_kwargs = {"m": m} + if n_gate_hidden_neurons is not None: + gate_kwargs["n_gate_hidden_neurons"] = n_gate_hidden_neurons + + self.gates = nn.ModuleDict({name: gate_class(**gate_kwargs) for name in self.components}) + + def create_target_components(self, target_module_patterns: list[str], m: int) -> nn.ModuleDict: + """Create target components for the model.""" + components: dict[str, LinearComponentWithBias] = {} + for name, module in self.model.named_modules(): + for pattern in target_module_patterns: + if fnmatch.fnmatch(name, pattern): + assert isinstance(module, nn.Linear), ( + f"Module '{name}' matched pattern '{pattern}' but is not nn.Linear. " + f"Found type: {type(module)}" + ) + # Replace "." 
with "-" in the name to avoid issues with module dict keys + components[name.replace(".", "-")] = nn_linear_to_components(module, m=m) + break + return nn.ModuleDict(components) + + def to(self, *args: Any, **kwargs: Any) -> "SSModel": + """Move the model and components to a device.""" + self.model.to(*args, **kwargs) + for component in self.components.values(): + component.to(*args, **kwargs) + for gate in self.gates.values(): + gate.to(*args, **kwargs) + return self + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Regular forward pass of the (target) model.""" + return self.model(*args, **kwargs) + + def forward_with_component( + self, + *args: Any, + module_name: str, + component: LinearComponentWithBias, + mask: Float[Tensor, "batch pos m"] | None = None, + **kwargs: Any, + ) -> Any: + """Forward pass with a single component replacement.""" + # Note that module_name uses "." separators but self.components use "-" separators + old_module = self.model.get_submodule(module_name) + assert old_module is not None + + self.model.set_submodule(module_name, component) + if mask is not None: + component.mask = mask + + out = self.model(*args, **kwargs) + + self.model.set_submodule(module_name, old_module) + return out + + def forward_with_components( + self, + *args: Any, + components: dict[str, LinearComponentWithBias], + masks: dict[str, Float[Tensor, "batch pos m"]] | None = None, + **kwargs: Any, + ) -> Any: + """Forward pass with temporary component replacement.""" + # Note that components and masks uses "-" separators + old_modules = {} + for component_name, component in components.items(): + module_name = component_name.replace("-", ".") + # component: LinearComponentWithBias = self.components[module_name.replace(".", "-")] + old_module = self.model.get_submodule(module_name) + assert old_module is not None + old_modules[module_name] = old_module + + if masks is not None: + component.mask = masks.get(component_name, None) + self.model.set_submodule(module_name, component) + + out = self.model(*args, **kwargs) + + # Restore the original modules + for module_name, old_module in old_modules.items(): + self.model.set_submodule(module_name, old_module) + + # Remove the masks attribute from the components + for component in components.values(): + component.mask = None + + return out + + def forward_with_pre_forward_cache_hooks( + self, *args: Any, module_names: list[str], **kwargs: Any + ) -> tuple[Any, dict[str, Tensor]]: + """Forward pass with caching at in the input to the modules given by `module_names`. + + Args: + module_names: List of module names to cache the inputs to. 
+ """ + cache = {} + + def cache_hook(module: nn.Module, input: tuple[Tensor, ...], param_name: str) -> Tensor: + cache[param_name] = input[0] + return input[0] + + handles: list[torch.utils.hooks.RemovableHandle] = [] + for module_name in module_names: + module = self.model.get_submodule(module_name) + assert module is not None + handles.append( + module.register_forward_pre_hook(partial(cache_hook, param_name=module_name)) + ) + + out = self.forward(*args, **kwargs) + + for handle in handles: + handle.remove() + + return out, cache + + @staticmethod + def _download_wandb_files(wandb_project_run_id: str) -> SSModelPaths: + """Download the relevant files from a wandb run.""" + api = wandb.Api() + run: Run = api.run(wandb_project_run_id) + + checkpoint = fetch_latest_wandb_checkpoint(run, prefix="model") + + run_dir = fetch_wandb_run_dir(run.id) + + final_config_path = download_wandb_file(run, run_dir, "final_config.yaml") + checkpoint_path = download_wandb_file(run, run_dir, checkpoint.name) + + # Get the step number from the path + step = int(Path(checkpoint_path).stem.split("_")[-1]) + + return SSModelPaths( + model=checkpoint_path, + optimizer=download_wandb_file(run, run_dir, f"optimizer_{step}.pth"), + config=final_config_path, + ) + + @classmethod + def from_pretrained(cls, path: ModelPath) -> tuple["SSModel", Config, Path]: + if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX): + wandb_path = path.removeprefix(WANDB_PATH_PREFIX) + api = wandb.Api() + run: Run = api.run(wandb_path) + paths = cls._download_wandb_files(wandb_path) + out_dir = fetch_wandb_run_dir(run.id) + + else: + # Get the step number from the path + step = int(Path(path).stem.split("_")[-1]) + paths = SSModelPaths( + model=Path(path), + optimizer=Path(path).parent / f"optimizer_{step}.pth", + config=Path(path).parent / "final_config.yaml", + ) + out_dir = Path(path).parent + + model_weights = torch.load(paths.model, map_location="cpu", weights_only=True) + with open(paths.config) as f: + config = Config(**yaml.safe_load(f)) + + assert isinstance(config.task_config, LMTaskConfig) + model_config_dict = MODEL_CONFIGS[config.task_config.model_size] + model_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}" + llama_model = Llama.from_pretrained(model_path, model_config_dict) + + ss_model = SSModel( + llama_model=llama_model, + target_module_patterns=config.task_config.target_module_patterns, + m=config.m, + n_gate_hidden_neurons=config.n_gate_hidden_neurons, + ) + ss_model.load_state_dict(model_weights) + return ss_model, config, out_dir diff --git a/spd/experiments/lm/play.py b/spd/experiments/lm/play.py new file mode 100644 index 0000000..87164f7 --- /dev/null +++ b/spd/experiments/lm/play.py @@ -0,0 +1,94 @@ +# %% +import torch +from simple_stories_train.models.llama import Llama +from simple_stories_train.models.model_configs import MODEL_CONFIGS +from transformers import AutoTokenizer + +from spd.experiments.lm.models import LinearComponentWithBias, SSModel + +# %% +# Select the model size you want to use +model_size = "1.25M" # Options: "35M", "30M", "11M", "5M", "1.25M" + +# Load model configuration +model_config = MODEL_CONFIGS[model_size] + +# Load appropriate model +model_path = f"chandan-sreedhara/SimpleStories-{model_size}" +model = Llama.from_pretrained(model_path, model_config) +# model.to("cuda") +model.eval() +# %% + +ss_model = SSModel( + llama_model=model, + target_module_patterns=["model.transformer.h.*.mlp.gate_proj"], + m=17, + n_gate_hidden_neurons=None, +) + +# # 
Create components with rank=10 (adjust as needed) +# gate_proj_components = create_target_components( +# model, rank=m, target_module_patterns=["model.transformer.h.*.mlp.gate_proj"] +# ) +gate_proj_components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items() +} # type: ignore +# %% +# Load tokenizer +tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False) + +# Define your prompt +prompt = "The curious cat looked at the" + +# IMPORTANT: Use tokenizer without special tokens +inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) +# input_ids = inputs.input_ids.to("cuda") +input_ids = inputs.input_ids +# Targets should be the inputs shifted by one (we will later ignore the last input token) +targets = input_ids[:, 1:] +input_ids = input_ids[:, :-1] + +# IMPORTANT: Set correct EOS token ID (not the default from tokenizer) +eos_token_id = 1 + +# %% + +# # Generate text +# with torch.no_grad(): +# output_ids = model.generate( +# idx=input_ids, max_new_tokens=20, temperature=0.7, top_k=40, eos_token_id=eos_token_id +# ) + +# # Decode output +# output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) +# print(f"Generated text:\n{output_text}") + + +# %% + +# logits, _ = ss_model.forward(input_ids, components=gate_proj_components) +logits, _ = ss_model.forward(input_ids) +print("inputs_shape", input_ids.shape) +print("logits", logits) +print("logits shape", logits.shape) + +logits, _ = ss_model.forward_with_components(input_ids, components=gate_proj_components) + +print("Component logits shape", logits.shape) +print("Component logits", logits) + +# Create some dummy masks +masks = { + f"model.transformer.h.{i}.mlp.gate_proj": torch.randn(1, input_ids.shape[-1], ss_model.m) + for i in range(len(model.transformer.h)) +} + +logits, _ = ss_model.forward_with_components( + input_ids, components=gate_proj_components, masks=masks +) + +print("Masked component logits shape", logits.shape) +print("Masked component logits", logits) +######################################################### +# %% diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 7d1b72b..c8e09c7 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -23,7 +23,7 @@ ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset from spd.log import logger -from spd.models.components import Gate +from spd.models.components import Gate, GateMLP from spd.plotting import plot_AB_matrices, plot_mask_vals from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( @@ -107,7 +107,7 @@ def resid_mlp_plot_results_fn( out_dir: Path | None, device: str, config: Config, - gates: dict[str, Gate], + gates: dict[str, Gate | GateMLP], masks: dict[str, Float[Tensor, "batch_size m"]] | None, **_, ) -> dict[str, plt.Figure]: @@ -161,9 +161,9 @@ def init_spd_model_from_target_model( for i in range(target_model.config.n_layers): # For mlp_in, m must equal d_mlp # TODO: This is broken, we shouldn't need m=d_mlp for this function. 
- assert ( - m == target_model.config.d_mlp or m == target_model.config.d_embed - ), "m must be equal to d_mlp or d_embed" + assert m == target_model.config.d_mlp or m == target_model.config.d_embed, ( + "m must be equal to d_mlp or d_embed" + ) # For mlp_in: A = target weights, B = identity model.layers[i].mlp_in.A.data[:] = target_model.layers[i].mlp_in.weight.data.clone() diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 70ff530..2d70087 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -20,7 +20,7 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSModelConfig, TMSSPDModel, TMSSPDModelConfig from spd.log import logger -from spd.models.components import Gate +from spd.models.components import Gate, GateMLP from spd.plotting import plot_AB_matrices, plot_mask_vals from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( @@ -54,7 +54,7 @@ def make_plots( out_dir: Path, device: str, config: Config, - gates: dict[str, Gate], + gates: dict[str, Gate | GateMLP], masks: dict[str, Float[Tensor, "batch n_instances m"]], batch: Float[Tensor, "batch n_instances n_features"], **_, diff --git a/spd/plotting.py b/spd/plotting.py index 1eb1adf..91d438a 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -10,7 +10,7 @@ from spd.hooks import HookedRootModule from spd.models.base import SPDModel -from spd.models.components import Gate +from spd.models.components import Gate, GateMLP from spd.module_utils import collect_nested_module_attrs from spd.run_spd import calc_component_acts, calc_masks @@ -48,7 +48,7 @@ def permute_to_identity( def plot_mask_vals( model: SPDModel, target_model: HookedRootModule, - gates: dict[str, Gate], + gates: dict[str, Gate | GateMLP], device: str, input_magnitude: float, ) -> tuple[plt.Figure, dict[str, Float[Tensor, "n_instances m"]]]: @@ -146,7 +146,7 @@ def plot_subnetwork_attributions_statistics( ax.set_ylabel("Count") ax.set_xlabel("Number of active subnetworks") - ax.set_title(f"Instance {i+1}") + ax.set_title(f"Instance {i + 1}") # Add value annotations on top of each bar for bar in bars: @@ -212,9 +212,9 @@ def plot_AB_matrices( # Verify that A and B matrices have matching names A_names = set(As.keys()) B_names = set(Bs.keys()) - assert ( - A_names == B_names - ), f"A and B matrices must have matching names. Found A: {A_names}, B: {B_names}" + assert A_names == B_names, ( + f"A and B matrices must have matching names. 
Found A: {A_names}, B: {B_names}" + ) n_layers = len(As) diff --git a/spd/run_spd.py b/spd/run_spd.py index eef483e..6e30cab 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -17,7 +17,7 @@ from spd.configs import Config from spd.hooks import HookedRootModule from spd.models.base import SPDModel -from spd.models.components import Gate, Linear, LinearComponent +from spd.models.components import Gate, GateMLP, Linear, LinearComponent from spd.module_utils import collect_nested_module_attrs, get_nested_module_attr from spd.utils import calc_recon_mse, get_lr_schedule_fn, get_lr_with_warmup @@ -143,7 +143,7 @@ def calc_act_recon_mse( def calc_masks( - gates: dict[str, Gate], + gates: dict[str, Gate | GateMLP], target_component_acts: dict[ str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] ], diff --git a/spd/wandb_utils.py b/spd/wandb_utils.py index 73d2d67..9f451a7 100644 --- a/spd/wandb_utils.py +++ b/spd/wandb_utils.py @@ -45,6 +45,7 @@ def fetch_wandb_run_dir(run_id: str) -> Path: """ # Default to REPO_ROOT/wandb if SPD_CACHE_DIR not set base_cache_dir = Path(os.environ.get("SPD_CACHE_DIR", REPO_ROOT / "wandb")) + base_cache_dir.mkdir(parents=True, exist_ok=True) # Set default wandb_run_dir wandb_run_dir = base_cache_dir / run_id / "files" diff --git a/tests/test_module_utils.py b/tests/test_module_utils.py new file mode 100644 index 0000000..a59643d --- /dev/null +++ b/tests/test_module_utils.py @@ -0,0 +1,15 @@ +from torch import nn + +from spd.module_utils import get_nested_module_attr + + +def test_get_nested_module_attr(): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + module = TestModule() + assert get_nested_module_attr(module, "linear1.weight.data").shape == (10, 10) + assert get_nested_module_attr(module, "linear2.weight.data").shape == (10, 10)
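For orientation, a minimal sketch (not part of the diff; the module names are illustrative) of how the `target_module_patterns` config field selects `nn.Linear` modules via `fnmatch`, mirroring `SSModel.create_target_components` and the `"-"`-for-`"."` key substitution used throughout:

```python
import fnmatch

# Illustrative module names, as they might appear in Llama.named_modules()
module_names = [
    "transformer.h.0.mlp.gate_proj",
    "transformer.h.0.mlp.up_proj",
    "transformer.h.0.attn.q_proj",
    "transformer.h.1.mlp.gate_proj",
]
patterns = ["transformer.h.*.mlp.*_proj"]

matched = [n for n in module_names if any(fnmatch.fnmatch(n, p) for p in patterns)]
print(matched)
# ['transformer.h.0.mlp.gate_proj', 'transformer.h.0.mlp.up_proj',
#  'transformer.h.1.mlp.gate_proj']

# nn.ModuleDict keys cannot contain ".", hence the "-" substitution:
keys = [n.replace(".", "-") for n in matched]
```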