diff --git a/.vscode/launch.json b/.vscode/launch.json
index 89875d2..a753e33 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -48,6 +48,18 @@
"env": {
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
+ },
+ {
+ "name": "lm streamlit",
+ "type": "debugpy",
+ "request": "launch",
+ "module": "streamlit",
+ "args": [
+ "run",
+ "${workspaceFolder}/spd/experiments/lm/app.py",
+ "--server.port",
+ "2000"
+ ]
}
]
}
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 49e6f8a..4031cfc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,10 @@ dependencies = [
"python-dotenv",
"wandb<=0.17.7", # due to https://github.com/wandb/wandb/issues/8248
"sympy",
+ "streamlit",
+ "streamlit-antd-components",
+ "datasets",
+ "simple-stories-train"
]
[project.optional-dependencies]
diff --git a/spd/experiments/lm/app.py b/spd/experiments/lm/app.py
new file mode 100644
index 0000000..d880198
--- /dev/null
+++ b/spd/experiments/lm/app.py
@@ -0,0 +1,405 @@
+"""
+To run this app, run the following command:
+
+```bash
+ streamlit run spd/experiments/lm/app.py -- --model_path "wandb:spd-lm/runs/151bsctx"
+```
+"""
+
+import argparse
+import html
+from collections.abc import Callable, Iterator
+from dataclasses import dataclass
+from typing import Any
+
+import streamlit as st
+import torch
+from datasets import load_dataset
+from jaxtyping import Float, Int
+from simple_stories_train.dataloaders import DatasetConfig
+from torch import Tensor
+from transformers import AutoTokenizer
+
+from spd.configs import Config, LMTaskConfig
+from spd.experiments.lm.models import LinearComponentWithBias, SSModel
+from spd.log import logger
+from spd.models.components import Gate, GateMLP
+from spd.run_spd import calc_component_acts, calc_masks
+from spd.types import ModelPath
+
+DEFAULT_MODEL_PATH: ModelPath = "wandb:spd-lm/runs/151bsctx"
+
+
+# -----------------------------------------------------------
+# Dataclass holding everything the app needs
+# -----------------------------------------------------------
+@dataclass(frozen=True)
+class AppData:
+ model: SSModel
+ tokenizer: AutoTokenizer
+ config: Config
+ dataloader_iter_fn: Callable[[], Iterator[dict[str, Any]]]
+ gates: dict[str, Gate | GateMLP]
+ components: dict[str, LinearComponentWithBias]
+ target_layer_names: list[str]
+ device: str
+
+
+# --- Initialization and Data Loading ---
+@st.cache_resource(show_spinner="Loading model and data...")
+def initialize(model_path: ModelPath) -> AppData:
+ """
+ Loads the model, tokenizer, config, and evaluation dataloader.
+ Cached by Streamlit based on the model_path.
+ """
+ device = "cpu" # Use CPU for the Streamlit app
+ logger.info(f"Initializing app with model: {model_path} on device: {device}")
+ ss_model, config, _ = SSModel.from_pretrained(model_path)
+ ss_model.to(device)
+ ss_model.eval()
+
+ task_config = config.task_config
+ assert isinstance(task_config, LMTaskConfig), "Task config must be LMTaskConfig for this app."
+
+ # Derive tokenizer path (adjust if stored differently)
+ tokenizer_path = f"chandan-sreedhara/SimpleStories-{task_config.model_size}"
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_path,
+ add_bos_token=False,
+ unk_token="[UNK]",
+ eos_token="[EOS]",
+ bos_token=None,
+ )
+
+ # Create eval dataloader config
+ eval_data_config = DatasetConfig(
+ name=task_config.dataset_name,
+ tokenizer_file_path=None,
+ hf_tokenizer_path=tokenizer_path,
+ split=task_config.eval_data_split,
+ n_ctx=task_config.max_seq_len,
+ is_tokenized=False,
+ streaming=False, # Non-streaming might be simpler for iterator reset
+ column_name="story",
+ )
+
+ # Create the dataloader iterator
+ def create_dataloader_iter() -> Iterator[dict[str, Any]]:
+ """
+ Returns a *new* iterator each time it is called.
+ Each element is a dict with:
+ - "text": the raw document text
+ - "input_ids": Int[Tensor, "1 seq_len"]
+ - "offset_mapping": list[tuple[int, int]]
+ """
+ logger.info("Creating new dataloader iterator.")
+
+ # Stream the HF dataset split
+ dataset = load_dataset(
+ eval_data_config.name,
+ streaming=eval_data_config.streaming,
+ split=eval_data_config.split,
+ trust_remote_code=False,
+ )
+
+ dataset = dataset.with_format("torch")
+
+ text_column = eval_data_config.column_name
+
+ def tokenize_and_prepare(example: dict[str, Any]) -> dict[str, Any]:
+ original_text: str = example[text_column]
+
+ tokenized = tokenizer(
+ original_text,
+ return_tensors="pt",
+ return_offsets_mapping=True,
+ truncation=True,
+ max_length=task_config.max_seq_len,
+ padding=False,
+ )
+
+ input_ids: Int[Tensor, "1 seq_len"] = tokenized["input_ids"]
+ if input_ids.dim() == 1: # Ensure 2‑D [1, seq_len]
+ input_ids = input_ids.unsqueeze(0)
+
+ # HF returns offset_mapping as a list per sequence; batch size is 1
+ offset_mapping: list[tuple[int, int]] = tokenized["offset_mapping"][0].tolist()
+
+ return {
+ "text": original_text,
+ "input_ids": input_ids,
+ "offset_mapping": offset_mapping,
+ }
+
+ # Map over the streaming dataset and return an iterator
+ return map(tokenize_and_prepare, iter(dataset))
+
+ # Extract components and gates
+ gates: dict[str, Gate | GateMLP] = {
+ k.removeprefix("gates.").replace("-", "."): v for k, v in ss_model.gates.items()
+ } # type: ignore[reportAssignmentType]
+ components: dict[str, LinearComponentWithBias] = {
+ k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items()
+ } # type: ignore[reportAssignmentType]
+ target_layer_names = sorted(list(components.keys()))
+
+ logger.info(f"Initialization complete for {model_path}.")
+ return AppData(
+ model=ss_model,
+ tokenizer=tokenizer,
+ config=config,
+ dataloader_iter_fn=create_dataloader_iter,
+ gates=gates,
+ components=components,
+ target_layer_names=target_layer_names,
+ device=device,
+ )
+
+
+# -----------------------------------------------------------
+# Utility: render the prompt with faint token outlines
+# -----------------------------------------------------------
+def render_prompt_with_tokens(
+ *,
+ raw_text: str,
+ offset_mapping: list[tuple[int, int]],
+ selected_idx: int | None,
+) -> None:
+ """
+ Renders `raw_text` inside Streamlit, wrapping each token span with a thin
+ border. The currently‑selected token receives a thicker red border.
+ All other tokens get a thin mid‑grey border (no background fill).
+ """
+ html_chunks: list[str] = []
+ cursor = 0
+
+ def esc(s: str) -> str:
+ return html.escape(s)
+
+ for idx, (start, end) in enumerate(offset_mapping):
+ if cursor < start:
+ html_chunks.append(esc(raw_text[cursor:start]))
+
+ token_substr = esc(raw_text[start:end])
+ if token_substr:
+ is_selected = idx == selected_idx
+ border_style = (
+ "2px solid rgb(200,0,0)" if is_selected else "0.5px solid #aaa" # all other tokens
+ )
+ html_chunks.append(
+ "'
+ f"{token_substr}"
+ )
+ cursor = end
+
+ if cursor < len(raw_text):
+ html_chunks.append(esc(raw_text[cursor:]))
+
+ st.markdown(
+ f'
{"".join(html_chunks)}
',
+ unsafe_allow_html=True,
+ )
+
+
+def load_next_prompt() -> None:
+ """Loads the next prompt, calculates masks, and prepares token data."""
+ logger.info("Loading next prompt.")
+ app_data: AppData = st.session_state.app_data
+ dataloader_iter = st.session_state.dataloader_iter # Get current iterator
+
+ try:
+ batch = next(dataloader_iter)
+ input_ids: Int[Tensor, "1 seq_len"] = batch["input_ids"].to(app_data.device)
+ except StopIteration:
+ logger.warning("Dataloader iterator exhausted. Throwing error.")
+ st.error("Failed to get data even after resetting dataloader.")
+ return
+
+ st.session_state.current_input_ids = input_ids
+
+ # Store the original raw prompt text
+ st.session_state.current_prompt_text = batch["text"]
+
+ # Calculate activations and masks
+ with torch.no_grad():
+ (_, _), pre_weight_acts = app_data.model.forward_with_pre_forward_cache_hooks(
+ input_ids, module_names=list(app_data.components.keys())
+ )
+ As = {module_name: v.linear_component.A for module_name, v in app_data.components.items()}
+ target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore[reportArgumentType]
+ masks, _ = calc_masks(
+ gates=app_data.gates,
+ target_component_acts=target_component_acts,
+ attributions=None,
+ detach_inputs=True, # No gradients needed
+ )
+ st.session_state.current_masks = masks # Dict[str, Float[Tensor, "1 seq_len m"]]
+
+ # Prepare token data for display
+ token_data = []
+ tokenizer = app_data.tokenizer
+ for i, token_id in enumerate(input_ids[0]):
+ # Decode individual token - might differ slightly from full decode for spaces etc.
+ decoded_token_str = tokenizer.decode([token_id]) # type: ignore[reportAttributeAccessIssue]
+ token_data.append(
+ {
+ "id": token_id.item(),
+ "text": decoded_token_str,
+ "index": i,
+ "offset": batch["offset_mapping"][i], # (start, end)
+ }
+ )
+ st.session_state.token_data = token_data
+
+ # Reset selections
+ st.session_state.selected_token_index = 0 # default: first token
+ st.session_state.selected_layer_name = None
+ logger.info("Finished loading next prompt and calculating masks.")
+
+
+# --- Main App UI ---
+def run_app(args: argparse.Namespace) -> None:
+ """Sets up and runs the Streamlit application."""
+ st.set_page_config(layout="wide")
+ st.title("LM Component Activation Explorer")
+
+ # Initialize model, data, etc. (cached)
+ st.session_state.app_data = initialize(args.model_path)
+ app_data: AppData = st.session_state.app_data
+ st.caption(f"Model: {args.model_path}")
+
+ # Initialize session state variables if they don't exist
+ if "current_prompt_text" not in st.session_state:
+ st.session_state.current_prompt_text = None
+ if "token_data" not in st.session_state:
+ st.session_state.token_data = None
+ if "current_masks" not in st.session_state:
+ st.session_state.current_masks = None
+ if "selected_token_index" not in st.session_state:
+ st.session_state.selected_token_index = None
+ if "selected_layer_name" not in st.session_state:
+ if app_data.target_layer_names:
+ st.session_state.selected_layer_name = app_data.target_layer_names[0]
+ else:
+ st.session_state.selected_layer_name = None
+ # Initialize the dataloader iterator in session state
+ if "dataloader_iter" not in st.session_state:
+ st.session_state.dataloader_iter = app_data.dataloader_iter_fn()
+
+ if st.session_state.current_prompt_text is None:
+ load_next_prompt()
+
+ # Sidebar container and a single expander for all interactive controls
+ sidebar = st.sidebar
+ controls_expander = sidebar.expander("Controls", expanded=True)
+
+ # ------------------------------------------------------------------
+ # Sidebar – interactive controls
+ # ------------------------------------------------------------------
+ with controls_expander:
+ st.button("Load Next Prompt", on_click=load_next_prompt)
+
+ # Render the raw prompt with faint token borders
+ if st.session_state.token_data and st.session_state.current_prompt_text:
+ # st.subheader("Prompt")
+ render_prompt_with_tokens(
+ raw_text=st.session_state.current_prompt_text,
+ offset_mapping=[t["offset"] for t in st.session_state.token_data],
+ selected_idx=st.session_state.selected_token_index,
+ )
+
+ # Sidebar slider for token selection
+ n_tokens = len(st.session_state.token_data)
+ if n_tokens > 0:
+ with controls_expander:
+ st.header("Token selector")
+ idx = st.slider(
+ "Token index",
+ min_value=0,
+ max_value=n_tokens - 1,
+ step=1,
+ key="selected_token_index",
+ )
+
+ selected_token = st.session_state.token_data[idx]
+ st.write(f"Selected token: {selected_token['text']} (ID: {selected_token['id']})")
+
+ st.divider()
+
+ # --- Token Information Area ---
+ if st.session_state.token_data:
+ idx = st.session_state.selected_token_index
+ # Ensure token_data is loaded before accessing
+ if (
+ st.session_state.token_data
+ and idx is not None
+ and idx < len(st.session_state.token_data)
+ ):
+ # Layer Selection Dropdown
+ # Always default to the first layer if nothing is selected yet
+ if st.session_state.selected_layer_name is None and app_data.target_layer_names:
+ st.session_state.selected_layer_name = app_data.target_layer_names[0]
+
+ with controls_expander:
+ st.header("Layer selector")
+ st.selectbox(
+ "Select Layer to Inspect:",
+ options=app_data.target_layer_names,
+ key="selected_layer_name",
+ )
+
+ # Display Layer-Specific Info if a layer is selected
+ if st.session_state.selected_layer_name:
+ layer_name = st.session_state.selected_layer_name
+ logger.debug(f"Displaying info for token {idx}, layer {layer_name}")
+
+ if st.session_state.current_masks is None:
+ st.warning("Masks not calculated yet. Please load a prompt.")
+ return
+
+ layer_mask_tensor: Float[Tensor, "1 seq_len m"] = st.session_state.current_masks[
+ layer_name
+ ]
+ token_mask: Float[Tensor, " m"] = layer_mask_tensor[0, idx, :]
+
+ # Find active components (mask > 0)
+ active_indices_layer: Int[Tensor, " n_active"] = torch.where(token_mask > 0)[0]
+ n_active_layer = len(active_indices_layer)
+
+ st.metric(f"Active Components in {layer_name}", n_active_layer)
+
+ st.subheader("Active Component Indices")
+ if n_active_layer > 0:
+ # Convert to NumPy array and reshape to a column vector (N x 1)
+ active_indices_np = active_indices_layer.cpu().numpy().reshape(-1, 1)
+ # Pass the NumPy array directly and configure the column header
+ st.dataframe(active_indices_np, height=300, use_container_width=False)
+ else:
+ st.write("No active components for this token in this layer.")
+
+ # Extensibility Placeholder
+ st.subheader("Additional Layer/Token Analysis")
+ st.write(
+ "Future figures and analyses for this specific layer and token will appear here."
+ )
+ else:
+ # Handle case where selected_token_index might be invalid after data reload
+ st.warning("Selected token index is out of bounds. Please select a token again.")
+ st.session_state.selected_token_index = None # Reset selection
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Streamlit app to explore LM component activations."
+ )
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default=DEFAULT_MODEL_PATH,
+ help=f"Path or W&B reference to the trained SSModel. Default: {DEFAULT_MODEL_PATH}",
+ )
+ args = parser.parse_args()
+
+ run_app(args)
diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py
index 33c7c5a..1258495 100644
--- a/spd/experiments/lm/component_viz.py
+++ b/spd/experiments/lm/component_viz.py
@@ -163,5 +163,5 @@ def main(path: ModelPath) -> None:
if __name__ == "__main__":
- path = "wandb:spd-lm/runs/hmjepm9b"
+ path = "wandb:spd-lm/runs/151bsctx"
main(path)
diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml
index 04d98d6..0b1260f 100644
--- a/spd/experiments/lm/lm_config.yaml
+++ b/spd/experiments/lm/lm_config.yaml
@@ -28,7 +28,7 @@ n_gate_hidden_neurons: null # Not applicable as there are no gates currently
# --- Training ---
batch_size: 4 # Adjust based on GPU memory
-steps: 10_000 # Total training steps
+steps: 1_000 # Total training steps
lr: 1e-3 # Learning rate
lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential)
lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup
@@ -38,7 +38,7 @@ init_from_target_model: false # Not implemented/applicable for this setup
# --- Logging & Saving ---
image_freq: 1000 # Frequency for generating/logging plots
print_freq: 100 # Frequency for printing logs to console
-save_freq: 10_000 # Frequency for saving checkpoints
+save_freq: 1_000 # Frequency for saving checkpoints
image_on_first_step: true # Whether to log plots at step 0
# --- Task Specific ---
diff --git a/spd/wandb_utils.py b/spd/wandb_utils.py
index 73d2d67..9f451a7 100644
--- a/spd/wandb_utils.py
+++ b/spd/wandb_utils.py
@@ -45,6 +45,7 @@ def fetch_wandb_run_dir(run_id: str) -> Path:
"""
# Default to REPO_ROOT/wandb if SPD_CACHE_DIR not set
base_cache_dir = Path(os.environ.get("SPD_CACHE_DIR", REPO_ROOT / "wandb"))
+ base_cache_dir.mkdir(parents=True, exist_ok=True)
# Set default wandb_run_dir
wandb_run_dir = base_cache_dir / run_id / "files"