From 47763eec239fef12477e77f13c047e0eb272f9e8 Mon Sep 17 00:00:00 2001
From: Dan Braun
Date: Thu, 17 Apr 2025 11:26:42 +0000
Subject: [PATCH 1/4] WIP: Add dashboard

---
 pyproject.toml                      |   2 +
 spd/experiments/lm/app.py           | 532 ++++++++++++++++++++++++++++
 spd/experiments/lm/component_viz.py |   2 +-
 spd/experiments/lm/lm_config.yaml   |   4 +-
 spd/experiments/lm/streamlit_app.py | 297 ++++++++++++++++
 5 files changed, 834 insertions(+), 3 deletions(-)
 create mode 100644 spd/experiments/lm/app.py
 create mode 100644 spd/experiments/lm/streamlit_app.py

diff --git a/pyproject.toml b/pyproject.toml
index 49e6f8a..38b4d51 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,8 @@ dependencies = [
    "python-dotenv",
    "wandb<=0.17.7", # due to https://github.com/wandb/wandb/issues/8248
    "sympy",
+    "streamlit",
+    "streamlit-antd-components",
]

[project.optional-dependencies]
diff --git a/spd/experiments/lm/app.py b/spd/experiments/lm/app.py
new file mode 100644
index 0000000..d6ff13c
--- /dev/null
+++ b/spd/experiments/lm/app.py
@@ -0,0 +1,532 @@
import argparse
import logging
from collections.abc import Iterator
from typing import Any, cast

import gradio as gr
import torch
from jaxtyping import Float, Int
from simple_stories_train.dataloaders import DatasetConfig, create_data_loader
from torch import Tensor
from transformers import AutoTokenizer

from spd.configs import Config, LMTaskConfig
from spd.experiments.lm.models import LinearComponentWithBias, SSModel
from spd.models.components import Gate, GateMLP
from spd.run_spd import calc_component_acts, calc_masks
from spd.types import ModelPath

# --- Configuration & Constants ---

DEFAULT_MODEL_PATH: ModelPath = "wandb:spd-lm/runs/151bsctx"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Data Structures ---

# Structure to hold information mapping character spans to tokens
TokenMapItem = dict[str, Any]  # Keys: 'text', 'span': tuple[int, int], 'index': int, 'id': int
TokenMap = list[TokenMapItem]

# Structure for Gradio state
AppState = dict[str, Any]  # Keys: 'model', 'tokenizer', 'config', 'gates', 'components', etc.


# --- Core Functions ---


@torch.no_grad()
def load_resources(model_path: ModelPath, device: str) -> AppState:
    """Loads the model, tokenizer, config, components, and gates."""
    logger.info(f"Loading resources for model: {model_path} on device: {device}")
    ss_model, config, _ = SSModel.from_pretrained(model_path)
    ss_model.to(device)
    ss_model.eval()

    assert isinstance(config.task_config, LMTaskConfig), (
        "Task config must be LMTaskConfig for this app."
    )

    # Derive tokenizer path
    tokenizer_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}"
    # Use the base tokenizer from AutoTokenizer for consistency if needed,
    # but create_data_loader might load its own. Ensure they are compatible.
    # For decoding/mapping, AutoTokenizer is convenient.
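    # Note (illustrative): per-token decoding is not guaranteed to concatenate to
    # the full decode, e.g. tokenizer.decode([a, b]) may insert or drop spaces
    # relative to tokenizer.decode([a]) + tokenizer.decode([b]).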
+ hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, legacy=False) + + # Extract components and gates + gates: dict[str, Gate | GateMLP] = { + k.removeprefix("gates.").replace("-", "."): v for k, v in ss_model.gates.items() + } + components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items() + } + target_layer_names = sorted(list(components.keys())) + + logger.info(f"Finished loading resources for {model_path}.") + return { + "model": ss_model, + "tokenizer": hf_tokenizer, # Use HF Tokenizer for decoding/mapping + "config": config, + "gates": gates, + "components": components, + "target_layer_names": target_layer_names, + "device": device, + "tokenizer_path": tokenizer_path, # Store path for dataloader + } + + +def create_eval_dataloader_iter( + app_state: AppState, +) -> Iterator[dict[str, Int[Tensor, "1 seq_len"]]]: + """Creates a new iterator for the evaluation dataloader.""" + config: Config = app_state["config"] + task_config: LMTaskConfig = cast(LMTaskConfig, config.task_config) + tokenizer_path: str = app_state["tokenizer_path"] + logger.info("Creating new evaluation dataloader iterator.") + + eval_data_config = DatasetConfig( + name=task_config.dataset_name, + tokenizer_file_path=None, # Use HF tokenizer path + hf_tokenizer_path=tokenizer_path, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=False, # Tokenize on the fly + streaming=True, # Use streaming as requested + column_name="story", + seed=config.seed, # Use same seed for reproducibility if needed + ) + + dataloader, _ = create_data_loader( + dataset_config=eval_data_config, + batch_size=1, # Always use batch size 1 for this app + buffer_size=task_config.buffer_size, + global_seed=config.seed, + ddp_rank=0, + ddp_world_size=1, + ) + # Make the dataloader an explicit iterator + return iter(dataloader) + + +def get_token_mapping( + tokenizer: AutoTokenizer, input_ids: Int[Tensor, "1 seq_len"] +) -> tuple[str, TokenMap]: + """ + Decodes input_ids and creates a mapping from character spans to token info. + Handles potential decoding artifacts like extra spaces. + """ + ids_list = input_ids[0].tolist() + full_text = tokenizer.decode(ids_list, skip_special_tokens=True) + logger.debug(f"Full decoded text length: {len(full_text)}") + + token_map: TokenMap = [] + current_char_index = 0 + + for token_idx, token_id in enumerate(ids_list): + # Decode individual token *without* special tokens or added spaces + # Note: This might differ slightly from full decode for some tokenizers (e.g., SentencePiece) + # We prioritize matching the token's contribution to the full decoded string. 
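        # A fast tokenizer could supply exact character spans up front instead,
        # e.g. (sketch): tokenizer(full_text, return_offsets_mapping=True)["offset_mapping"]
        # yields one (start_char, end_char) pair per token, which would make this
        # substring search unnecessary.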
+ token_text = tokenizer.decode( + [token_id], skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + + # Find the *next* occurrence of this token's text in the full string + try: + # Find the start index, searching from the current position + start_char = full_text.index(token_text, current_char_index) + end_char = start_char + len(token_text) + + # Store mapping info + token_map_item: TokenMapItem = { + "text": token_text, + "span": (start_char, end_char), + "index": token_idx, + "id": token_id, + } + token_map.append(token_map_item) + # logger.debug(f"Mapped token {token_idx} (ID: {token_id}, Text: '{token_text}') to span {token_map_item['span']}") + + # Update current character index for the next search + current_char_index = end_char + + except ValueError: + # This can happen if the individual token decode differs significantly + # from its representation in the full decode (e.g., spaces, merges) + logger.warning( + f"Could not find token_text='{token_text}' (ID: {token_id}, Index: {token_idx}) " + f"in remaining full_text='{full_text[current_char_index:]}'. Skipping token mapping." + ) + # Attempt to gracefully handle by skipping or trying alternative decodes if necessary + # For now, we just log and potentially skip. A robust solution might require + # tokenizer-specific logic or offset mapping if available. + + # Verification step (optional but recommended) + if current_char_index != len(full_text) and len(token_map) == len(ids_list): + logger.warning( + f"Final character index {current_char_index} does not match full text length {len(full_text)}. Mapping might be imperfect." + ) + elif len(token_map) != len(ids_list): + logger.warning( + f"Mapped {len(token_map)} tokens, but expected {len(ids_list)}. Mapping is incomplete." + ) + + return full_text, token_map + + +@torch.no_grad() +def calculate_masks_for_batch( + app_state: AppState, input_ids: Int[Tensor, "1 seq_len"] +) -> dict[str, Float[Tensor, "1 seq_len m"]]: + """Performs forward pass and calculates masks for the given input_ids.""" + model: SSModel = app_state["model"] + components: dict[str, LinearComponentWithBias] = app_state["components"] + gates: dict[str, Gate | GateMLP] = app_state["gates"] + device: str = app_state["device"] + + input_ids = input_ids.to(device) + + logger.info("Running forward pass to get activations...") + (_, _), pre_weight_acts = model.forward_with_pre_forward_cache_hooks( + input_ids, module_names=list(components.keys()) + ) + logger.info("Calculating component activations...") + As = {module_name: v.linear_component.A for module_name, v in components.items()} + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + + logger.info("Calculating masks...") + masks, _ = calc_masks( + gates=gates, + target_component_acts=target_component_acts, + attributions=None, + detach_inputs=True, # No gradients needed for inference + ) + logger.info("Mask calculation complete.") + # Ensure masks are on CPU for Gradio state if needed, although calculations were on device + return {k: v.cpu() for k, v in masks.items()} + + +def update_token_display( + selected_token_info: TokenMapItem | None, + selected_layer_name: str | None, + current_masks: dict[str, Float[Tensor, "1 seq_len m"]] | None, +) -> tuple[str, str, Any, bool]: + """ + Generates display updates for the token info area based on selection. 
+ Returns: (token_summary_md, active_comp_summary_md, active_indices_df_data, area_visible) + """ + if not selected_token_info or not selected_layer_name or not current_masks: + logger.debug("Update condition not met, hiding token info area.") + return "", "", None, False # Hide area if no token/layer/masks + + token_idx = selected_token_info["index"] + token_id = selected_token_info["id"] + token_text = selected_token_info["text"] + + logger.info( + f"Updating display for token {token_idx} ('{token_text}'), layer {selected_layer_name}" + ) + + token_summary_md = f"**Token Info:** '{token_text}' (Position: {token_idx}, ID: {token_id})" + + if selected_layer_name not in current_masks: + logger.warning(f"Selected layer {selected_layer_name} not found in current masks.") + return token_summary_md, "Error: Layer not found in masks.", None, True + + layer_mask_tensor: Float[Tensor, "1 seq_len m"] = current_masks[selected_layer_name] + # Ensure tensor is on CPU before indexing if it wasn't already + token_mask: Float[Tensor, " m"] = layer_mask_tensor[0, token_idx, :].cpu() + + # Find active components (mask > 0) + active_indices_layer: Int[Tensor, " n_active"] = torch.where(token_mask > 0)[0] + n_active_layer = len(active_indices_layer) + + active_comp_summary_md = f"**Active Components in {selected_layer_name}:** {n_active_layer}" + + active_indices_df_data = None + if n_active_layer > 0: + # Convert to NumPy array and reshape for DataFrame (N x 1) + active_indices_np = active_indices_layer.cpu().numpy().reshape(-1, 1) + active_indices_df_data = active_indices_np + logger.debug(f"Found {n_active_layer} active components.") + else: + logger.debug("No active components found.") + + return token_summary_md, active_comp_summary_md, active_indices_df_data, True + + +# --- Gradio Application --- + +# Store the iterator globally (or in a mutable container) for access within handlers +# This is suitable for single-session Gradio apps. 
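# With multiple concurrent sessions, all users would share (and advance) the same
# iterator; a per-session iterator (e.g. keyed by the Gradio session) would be
# needed for multi-user deployments.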
global_iterator_store: dict[str, Iterator[dict[str, Int[Tensor, "1 seq_len"]]] | None] = {
    "iterator": None
}


def build_gradio_app(args: argparse.Namespace) -> gr.Blocks:
    """Builds the Gradio Blocks interface."""

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load resources once
    initial_app_state = load_resources(args.model_path, device)
    # Create the initial iterator and store it globally
    global_iterator_store["iterator"] = create_eval_dataloader_iter(initial_app_state)

    with gr.Blocks(title="LM Component Activation Explorer") as app:
        # --- State Management ---
        # Store heavy objects and mutable state here
        app_state = gr.State(initial_app_state)
        # REMOVE: dataloader_iter = gr.State(initial_dataloader_iter) # CANNOT STORE ITERATOR IN STATE
        # State for current prompt data
        current_input_ids = gr.State(None)
        current_token_map = gr.State(None)  # Stores the TokenMap list
        current_masks = gr.State(None)
        # State for user selections
        selected_token_info = gr.State(None)  # Stores the selected TokenMapItem
        selected_layer_name = gr.State(
            initial_app_state["target_layer_names"][0]
            if initial_app_state["target_layer_names"]
            else None
        )

        # --- UI Layout ---
        gr.Markdown(f"# LM Component Activation Explorer\n**Model:** {args.model_path}")

        with gr.Row():
            # Use Textbox for raw display, HighlightedText for interaction
            # prompt_display_text = gr.Textbox(label="Prompt Text", lines=5, interactive=False)
            prompt_display_highlight = gr.HighlightedText(
                label="Prompt (Click Tokens)",
                interactive=True,
                combine_adjacent=False,  # Treat each token span separately
                show_legend=False,
            )
            next_button = gr.Button("Load Next Prompt")

        layer_dropdown = gr.Dropdown(
            label="Select Layer",
            choices=initial_app_state["target_layer_names"],
            value=initial_app_state["target_layer_names"][0]
            if initial_app_state["target_layer_names"]
            else None,
            interactive=True,
        )

        # Initially hidden area for token details
        with gr.Column(visible=False) as token_info_area:
            token_summary = gr.Markdown()
            active_components_summary = gr.Markdown()
            active_indices_table = gr.DataFrame(
                headers=["Component Index"],
                datatype=["number"],
                col_count=(1, "fixed"),
                # max_rows=10, # Limit visible rows initially
                interactive=False,
                label="Active Component Indices",
            )
            future_plots_placeholder = gr.Markdown("*Future analyses will appear here.*")

        # --- Event Handlers ---

        def load_next_prompt_data(
            app_state_val: AppState,  # Removed current_iter_state input
        ) -> tuple[
            # REMOVED: Iterator, # Updated iterator state
            Int[Tensor, "1 seq_len"],  # current_input_ids
            TokenMap,  # current_token_map
            dict[str, Float[Tensor, "1 seq_len m"]],  # current_masks
            list[tuple[str, str | None]],  # HighlightedText value
            # Reset selections
            None,  # selected_token_info
            str,  # token_summary update
            str,  # active_comp_summary update
            None,  # active_indices_table update
            gr.update,  # Return an update for the Column
        ]:
            """Gets next batch, calculates masks, prepares display data."""
            logger.info("Attempting to load next prompt...")
            # Access the iterator from the global store
            iterator = global_iterator_store["iterator"]
            if iterator is None:
                logger.error("Iterator not initialized.")
                raise gr.Error("Iterator not initialized. Please restart the app.")

            try:
                batch = next(iterator)
                input_ids: Int[Tensor, "1 seq_len"] = batch[
                    "input_ids"
                ]  # Should be shape (1, seq_len)
                logger.info(f"Loaded batch with shape: {input_ids.shape}")
            except StopIteration:
                logger.warning("Dataloader iterator exhausted. Resetting.")
                # Recreate iterator and update the global store
                iterator = create_eval_dataloader_iter(app_state_val)
                global_iterator_store["iterator"] = iterator
                try:
                    batch = next(iterator)
                    input_ids = batch["input_ids"]
                    logger.info(f"Loaded first batch after reset, shape: {input_ids.shape}")
                except StopIteration:
                    logger.error("Failed to get data even after resetting dataloader.")
                    # Handle error state appropriately, maybe raise gr.Error
                    raise gr.Error("Dataset seems empty or failed to load.")

            # Ensure input_ids are on CPU for mapping
            input_ids_cpu = input_ids.cpu()
            tokenizer = app_state_val["tokenizer"]

            # 1. Get Token Mapping
            full_text, token_map = get_token_mapping(tokenizer, input_ids_cpu)
            # Format for HighlightedText: list of (text_substring, label/tooltip)
            # We use the token index as the label for now.
            highlight_data = [(item["text"], f"Token {item['index']}") for item in token_map]

            # 2. Calculate Masks (can run on GPU if available)
            masks = calculate_masks_for_batch(
                app_state_val, input_ids
            )  # input_ids passed might be on CPU or GPU based on device

            # 3. Return all updates (excluding the iterator)
            return (
                # REMOVED: current_iter_state,
                input_ids_cpu,  # Store CPU version in state
                token_map,
                masks,
                highlight_data,
                None,  # Reset selected token
                "",  # Reset token_summary
                "",  # Reset active_components_summary
                None,  # Reset active_indices_table
                gr.update(visible=False),  # Return update object
            )

        next_button.click(
            fn=load_next_prompt_data,
            inputs=[app_state],  # Removed dataloader_iter
            outputs=[
                # REMOVED: dataloader_iter,
                current_input_ids,
                current_token_map,
                current_masks,
                prompt_display_highlight,
                selected_token_info,  # Reset
                token_summary,
                active_components_summary,
                active_indices_table,
                token_info_area,  # Reset
            ],
            queue=True,  # Use queue for potentially long-running model calls
        )

        def handle_token_select(
            evt: gr.SelectData,
            current_token_map_val: TokenMap | None,
            current_masks_val: dict[str, Float[Tensor, "1 seq_len m"]] | None,
            selected_layer_name_val: str | None,
        ) -> tuple[TokenMapItem | None, str, str, Any, gr.update]:
            """Handles token selection, finds corresponding token info, updates display."""
            logger.debug(f"Token selected event: Index={evt.index}, Value='{evt.value}'")
            selected_info = None
            if current_token_map_val:
                # Find the token corresponding to the selected character span
                char_index = evt.index[0]  # Start index of the selection
                for item in current_token_map_val:
                    start, end = item["span"]
                    # Check if the click falls within this token's span
                    if start <= char_index < end:
                        selected_info = item
                        logger.info(
                            f"Mapped selection at char {char_index} to token index {selected_info['index']}"
                        )
                        break
                if not selected_info:
                    logger.warning(f"Could not map character index {char_index} to any token span.")

            # Update the display based on the found token and current layer
            token_sum, active_sum, active_table, visible = update_token_display(
                selected_info, selected_layer_name_val, current_masks_val
            )
            return selected_info, token_sum, active_sum, active_table, gr.update(visible=visible)

        prompt_display_highlight.select(
            fn=handle_token_select,
            inputs=[current_token_map, current_masks, selected_layer_name],
            outputs=[
                selected_token_info,
                token_summary,
                active_components_summary,
                active_indices_table,
                token_info_area,
            ],
            queue=False,  # Selection should be fast
        )

        def handle_layer_change(
            layer_name: str,
            selected_token_info_val: TokenMapItem | None,
            current_masks_val: dict[str, Float[Tensor, "1 seq_len m"]] | None,
        ) -> tuple[str, str, Any, gr.update]:
            """Handles layer change, updates display if a token is selected."""
            logger.info(f"Layer changed to: {layer_name}")
            # Update display based on the new layer and existing token selection
            token_sum, active_sum, active_table, visible = update_token_display(
                selected_token_info_val, layer_name, current_masks_val
            )
            return token_sum, active_sum, active_table, gr.update(visible=visible)

        layer_dropdown.change(
            fn=handle_layer_change,
            inputs=[layer_dropdown, selected_token_info, current_masks],
            outputs=[
                token_summary,
                active_components_summary,
                active_indices_table,
                token_info_area,
            ],
            queue=False,  # Layer change update should be fast
        )

        # --- Initial Load ---
        # Use the same function, but ensure the global iterator is used
        app.load(
            fn=load_next_prompt_data,
            inputs=[app_state],
            outputs=[
                # REMOVED: dataloader_iter,
                current_input_ids,
                current_token_map,
                current_masks,
                prompt_display_highlight,
                selected_token_info,  # Reset
                token_summary,
                active_components_summary,
                active_indices_table,
                token_info_area,  # Reset - Keep this here
            ],
            queue=True,
        )

    return app


# --- Main Execution ---

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Gradio app to explore LM component activations.")
    parser.add_argument(
        "--model_path",
        type=str,
        default=DEFAULT_MODEL_PATH,
        help=f"Path or W&B reference to the trained SSModel. Default: {DEFAULT_MODEL_PATH}",
    )
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a publicly shareable link.",
    )
    args = parser.parse_args()

    gradio_app = build_gradio_app(args)
    # The iterator is initialized inside build_gradio_app now
    gradio_app.launch(share=args.share)
diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py
index 33c7c5a..1258495 100644
--- a/spd/experiments/lm/component_viz.py
+++ b/spd/experiments/lm/component_viz.py
@@ -163,5 +163,5 @@ def main(path: ModelPath) -> None:

if __name__ == "__main__":
-    path = "wandb:spd-lm/runs/hmjepm9b"
+    path = "wandb:spd-lm/runs/151bsctx"
    main(path)
diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml
index 04d98d6..0b1260f 100644
--- a/spd/experiments/lm/lm_config.yaml
+++ b/spd/experiments/lm/lm_config.yaml
@@ -28,7 +28,7 @@ n_gate_hidden_neurons: null # Not applicable as there are no gates currently

# --- Training ---
batch_size: 4 # Adjust based on GPU memory
-steps: 10_000 # Total training steps
+steps: 1_000 # Total training steps
lr: 1e-3 # Learning rate
lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential)
lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup
@@ -38,7 +38,7 @@ init_from_target_model: false # Not implemented/applicable for this setup
# --- Logging & Saving ---
image_freq: 1000 # Frequency for generating/logging plots
print_freq: 100 # Frequency for printing logs to console
-save_freq: 10_000 # Frequency for saving checkpoints
+save_freq: 1_000 # Frequency for saving checkpoints
image_on_first_step: true # Whether to log plots at step 0

# --- Task Specific ---
diff --git a/spd/experiments/lm/streamlit_app.py b/spd/experiments/lm/streamlit_app.py
new file mode 100644
index 0000000..32f20d7
--- /dev/null
+++ b/spd/experiments/lm/streamlit_app.py
@@ -0,0 +1,297 @@
"""
To run this app, run the following command:

```bash
    streamlit run spd/experiments/lm/streamlit_app.py -- --model_path "wandb:spd-lm/runs/151bsctx"
```
"""

import argparse
from collections.abc import Iterator
from typing import Any

import streamlit as st
import streamlit_antd_components as sac
import torch
from jaxtyping import Float, Int
from simple_stories_train.dataloaders import DatasetConfig, create_data_loader
from torch import Tensor
from transformers import AutoTokenizer

from spd.configs import LMTaskConfig
from spd.experiments.lm.models import LinearComponentWithBias, SSModel
from spd.log import logger
from spd.models.components import Gate, GateMLP
from spd.run_spd import calc_component_acts, calc_masks
from spd.types import ModelPath

DEFAULT_MODEL_PATH: ModelPath = "wandb:spd-lm/runs/151bsctx"


# --- Initialization and Data Loading ---
@st.cache_resource(show_spinner="Loading model and data...")
def initialize(model_path: ModelPath) -> dict[str, Any]:
    """
    Loads the model, tokenizer, config, and evaluation dataloader.
    Cached by Streamlit based on the model_path.
    """
    device = "cpu"  # Use CPU for the Streamlit app
    logger.info(f"Initializing app with model: {model_path} on device: {device}")
    ss_model, config, _ = SSModel.from_pretrained(model_path)
    ss_model.to(device)
    ss_model.eval()

    assert isinstance(config.task_config, LMTaskConfig), (
        "Task config must be LMTaskConfig for this app."
+ ) + + # Derive tokenizer path (adjust if stored differently) + tokenizer_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, legacy=False) + + # Create eval dataloader config + eval_data_config = DatasetConfig( + name=config.task_config.dataset_name, + tokenizer_file_path=None, + hf_tokenizer_path=tokenizer_path, + split=config.task_config.eval_data_split, + n_ctx=config.task_config.max_seq_len, + is_tokenized=False, + streaming=False, # Non-streaming might be simpler for iterator reset + column_name="story", + ) + + # Create the dataloader iterator + def create_dataloader_iter() -> Iterator[dict[str, Int[Tensor, "1 seq_len"]]]: + logger.info("Creating new dataloader iterator.") + dataloader, _ = create_data_loader( + dataset_config=eval_data_config, + batch_size=1, # Always use batch size 1 for this app + buffer_size=config.task_config.buffer_size, + global_seed=config.seed, # Use same seed for reproducibility + ddp_rank=0, + ddp_world_size=1, + ) + return iter(dataloader) + + # Extract components and gates + gates: dict[str, Gate | GateMLP] = { + k.removeprefix("gates.").replace("-", "."): v for k, v in ss_model.gates.items() + } + components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items() + } + target_layer_names = sorted(list(components.keys())) + + logger.info(f"Initialization complete for {model_path}.") + return { + "model": ss_model, + "tokenizer": tokenizer, + "config": config, + "dataloader_iter_fn": create_dataloader_iter, # Store the function to create iter + "gates": gates, + "components": components, + "target_layer_names": target_layer_names, + "device": device, + } + + +def load_next_prompt() -> None: + """Loads the next prompt, calculates masks, and prepares token data.""" + logger.info("Loading next prompt.") + app_data = st.session_state.app_data + dataloader_iter = st.session_state.dataloader_iter # Get current iterator + + try: + batch = next(dataloader_iter) + input_ids: Int[Tensor, "1 seq_len"] = batch["input_ids"].to(app_data["device"]) + except StopIteration: + logger.warning("Dataloader iterator exhausted. Throwing error.") + st.error("Failed to get data even after resetting dataloader.") + return + + st.session_state.current_input_ids = input_ids + + # Decode tokenized IDs to get the transformed text + st.session_state.transformed_prompt_text = app_data["tokenizer"].decode( + input_ids[0], skip_special_tokens=True + ) + + # Calculate activations and masks + with torch.no_grad(): + (_, _), pre_weight_acts = app_data["model"].forward_with_pre_forward_cache_hooks( + input_ids, module_names=list(app_data["components"].keys()) + ) + As = { + module_name: v.linear_component.A for module_name, v in app_data["components"].items() + } + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + masks, _ = calc_masks( + gates=app_data["gates"], + target_component_acts=target_component_acts, + attributions=None, + detach_inputs=True, # No gradients needed + ) + st.session_state.current_masks = masks # Dict[str, Float[Tensor, "1 seq_len m"]] + + # Prepare token data for display + token_data = [] + tokenizer = app_data["tokenizer"] + for i, token_id in enumerate(input_ids[0]): + # Decode individual token - might differ slightly from full decode for spaces etc. 
+ decoded_token_str = tokenizer.decode([token_id]) + token_data.append({"id": token_id.item(), "text": decoded_token_str, "index": i}) + st.session_state.token_data = token_data + + # Reset selections + st.session_state.selected_token_index = None + st.session_state.selected_layer_name = None + logger.info("Finished loading next prompt and calculating masks.") + + +def set_selected_token(index: int) -> None: + """Callback function to set the selected token index.""" + # Check if the index is valid before setting + if 0 <= index < len(st.session_state.token_data): + logger.debug(f"Token {index} selected.") + st.session_state.selected_token_index = index + else: + logger.debug(f"Invalid index ({index}) received or no token clicked.") + + +# --- Main App UI --- +def run_app(args: argparse.Namespace) -> None: + """Sets up and runs the Streamlit application.""" + st.set_page_config(layout="wide") + st.title("LM Component Activation Explorer") + + # Initialize model, data, etc. (cached) + st.session_state.app_data = initialize(args.model_path) + app_data = st.session_state.app_data + st.caption(f"Model: {args.model_path}") + + # Initialize session state variables if they don't exist + if "transformed_prompt_text" not in st.session_state: + st.session_state.transformed_prompt_text = None + if "token_data" not in st.session_state: + st.session_state.token_data = None + if "current_masks" not in st.session_state: + st.session_state.current_masks = None + if "selected_token_index" not in st.session_state: + st.session_state.selected_token_index = None + if "selected_layer_name" not in st.session_state: + if app_data["target_layer_names"]: + st.session_state.selected_layer_name = app_data["target_layer_names"][0] + else: + st.session_state.selected_layer_name = None + # Initialize the dataloader iterator in session state + if "dataloader_iter" not in st.session_state: + st.session_state.dataloader_iter = app_data["dataloader_iter_fn"]() + + # --- Prompt Area --- + st.button("Load Initial / Next Prompt", on_click=load_next_prompt) + + # Display Transformed (Decoded) Prompt using Clickable Tokens + if st.session_state.token_data: + st.subheader("Prompt (Encoded->Decoded, Click Tokens Below)") + + # Use sac.buttons to create clickable text segments + clicked_token_index = sac.buttons( + items=[ + sac.ButtonsItem(label=token_info["text"]) + for i, token_info in enumerate(st.session_state.token_data) + ], + index=st.session_state.selected_token_index, + format_func=None, + align="left", + variant="text", + size="xs", + gap=1, + use_container_width=False, + return_index=True, + key="token_buttons", + radius=1, + ) + + # Update selected token based on click + if clicked_token_index != st.session_state.selected_token_index: + set_selected_token(clicked_token_index) + st.rerun() + + st.divider() + + # --- Token Information Area --- + if st.session_state.selected_token_index is not None: + idx = st.session_state.selected_token_index + # Ensure token_data is loaded before accessing + if st.session_state.token_data and idx < len(st.session_state.token_data): + token_info = st.session_state.token_data[idx] + token_text = token_info["text"] + + st.header(f"Token Info: '{token_text}' (Position: {idx}, ID: {token_info['id']})") + + # Layer Selection Dropdown + st.selectbox( + "Select Layer to Inspect:", + options=app_data["target_layer_names"], + key="selected_layer_name", # Binds selection to session state + ) + + # Display Layer-Specific Info if a layer is selected + if st.session_state.selected_layer_name: + 
                layer_name = st.session_state.selected_layer_name
                logger.debug(f"Displaying info for token {idx}, layer {layer_name}")

                if st.session_state.current_masks is None:
                    st.warning("Masks not calculated yet. Please load a prompt.")
                    return

                layer_mask_tensor: Float[Tensor, "1 seq_len m"] = st.session_state.current_masks[
                    layer_name
                ]
                token_mask: Float[Tensor, " m"] = layer_mask_tensor[0, idx, :]

                # Find active components (mask > 0)
                active_indices_layer: Int[Tensor, " n_active"] = torch.where(token_mask > 0)[0]
                n_active_layer = len(active_indices_layer)

                st.metric(f"Active Components in {layer_name}", n_active_layer)

                st.subheader("Active Component Indices")
                if n_active_layer > 0:
                    # Convert to NumPy array and reshape to a column vector (N x 1)
                    active_indices_np = active_indices_layer.cpu().numpy().reshape(-1, 1)
                    # Pass the NumPy array directly and configure the column header
                    st.dataframe(
                        active_indices_np,
                        height=300,
                        use_container_width=False,
                        column_config={0: "Component Index"},  # Rename the first column (index 0)
                    )
                else:
                    st.write("No active components for this token in this layer.")

                # Extensibility Placeholder
                st.subheader("Additional Layer/Token Analysis")
                st.write(
                    "Future figures and analyses for this specific layer and token will appear here."
                )
        else:
            # Handle case where selected_token_index might be invalid after data reload
            st.warning("Selected token index is out of bounds. Please select a token again.")
            st.session_state.selected_token_index = None  # Reset selection


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Streamlit app to explore LM component activations."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=DEFAULT_MODEL_PATH,
        help=f"Path or W&B reference to the trained SSModel. Default: {DEFAULT_MODEL_PATH}",
    )
    args = parser.parse_args()

    run_app(args)

From 52ff9a42c2155ff41120b3f8fe3bcf0967802421 Mon Sep 17 00:00:00 2001
From: Dan Braun
Date: Mon, 21 Apr 2025 02:14:45 +0000
Subject: [PATCH 2/4] Create base_cache_dir if it doesn't exist

---
 spd/wandb_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/spd/wandb_utils.py b/spd/wandb_utils.py
index 73d2d67..9f451a7 100644
--- a/spd/wandb_utils.py
+++ b/spd/wandb_utils.py
@@ -45,6 +45,7 @@ def fetch_wandb_run_dir(run_id: str) -> Path:
    """
    # Default to REPO_ROOT/wandb if SPD_CACHE_DIR not set
    base_cache_dir = Path(os.environ.get("SPD_CACHE_DIR", REPO_ROOT / "wandb"))
+    base_cache_dir.mkdir(parents=True, exist_ok=True)

    # Set default wandb_run_dir
    wandb_run_dir = base_cache_dir / run_id / "files"

From 38d5527ae9d70f317bc878a9acb6d08b2c3f5604 Mon Sep 17 00:00:00 2001
From: Dan Braun
Date: Tue, 22 Apr 2025 05:27:29 +0000
Subject: [PATCH 3/4] Functional dashboard

---
 .vscode/launch.json                 |  12 +
 spd/experiments/lm/app.py           | 799 ++++++++++++----------------
 spd/experiments/lm/streamlit_app.py | 297 -----------
 3 files changed, 348 insertions(+), 760 deletions(-)
 delete mode 100644 spd/experiments/lm/streamlit_app.py

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 89875d2..a753e33 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -48,6 +48,18 @@
            "env": {
                "PYDEVD_DISABLE_FILE_VALIDATION": "1"
            }
+        },
+        {
+            "name": "lm streamlit",
+            "type": "debugpy",
+            "request": "launch",
+            "module": "streamlit",
+            "args": [
+                "run",
+                "${workspaceFolder}/spd/experiments/lm/app.py",
+                "--server.port",
+                "2000"
+            ]
        }
    ]
}
\ No newline at end of file
diff --git a/spd/experiments/lm/app.py b/spd/experiments/lm/app.py
index d6ff13c..d880198 100644
--- a/spd/experiments/lm/app.py
+++ b/spd/experiments/lm/app.py
@@ -1,532 +1,405 @@
+"""
+To run this app, run the following command:

+```bash
+    streamlit run spd/experiments/lm/app.py -- --model_path "wandb:spd-lm/runs/151bsctx"
+```
+"""
+
 import argparse
-import logging
-from collections.abc import Iterator
-from typing import Any, cast
+import html
+from collections.abc import Callable, Iterator
+from dataclasses import dataclass
+from typing import Any

-import gradio as gr
+import streamlit as st
 import torch
+from datasets import load_dataset
 from jaxtyping import Float, Int
-from simple_stories_train.dataloaders import DatasetConfig, create_data_loader
+from simple_stories_train.dataloaders import DatasetConfig
 from torch import Tensor
 from transformers import AutoTokenizer

 from spd.configs import Config, LMTaskConfig
 from spd.experiments.lm.models import LinearComponentWithBias, SSModel
+from spd.log import logger
 from spd.models.components import Gate, GateMLP
 from spd.run_spd import calc_component_acts, calc_masks
 from spd.types import ModelPath

-# --- Configuration & Constants ---
-
 DEFAULT_MODEL_PATH: ModelPath = "wandb:spd-lm/runs/151bsctx"
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# --- Data Structures ---
-
-# Structure to hold information mapping character spans to tokens
-TokenMapItem = dict[str, Any]  # Keys: 'text', 'span': tuple[int, int], 'index': int, 'id': int
-TokenMap = list[TokenMapItem]
-
-# Structure for Gradio state
-AppState = dict[str, Any]  # Keys: 'model', 'tokenizer', 'config', 'gates', 'components', etc.
- - -# --- Core Functions --- -@torch.no_grad() -def load_resources(model_path: ModelPath, device: str) -> AppState: - """Loads the model, tokenizer, config, components, and gates.""" - logger.info(f"Loading resources for model: {model_path} on device: {device}") +# ----------------------------------------------------------- +# Dataclass holding everything the app needs +# ----------------------------------------------------------- +@dataclass(frozen=True) +class AppData: + model: SSModel + tokenizer: AutoTokenizer + config: Config + dataloader_iter_fn: Callable[[], Iterator[dict[str, Any]]] + gates: dict[str, Gate | GateMLP] + components: dict[str, LinearComponentWithBias] + target_layer_names: list[str] + device: str + + +# --- Initialization and Data Loading --- +@st.cache_resource(show_spinner="Loading model and data...") +def initialize(model_path: ModelPath) -> AppData: + """ + Loads the model, tokenizer, config, and evaluation dataloader. + Cached by Streamlit based on the model_path. + """ + device = "cpu" # Use CPU for the Streamlit app + logger.info(f"Initializing app with model: {model_path} on device: {device}") ss_model, config, _ = SSModel.from_pretrained(model_path) ss_model.to(device) ss_model.eval() - assert isinstance(config.task_config, LMTaskConfig), ( - "Task config must be LMTaskConfig for this app." + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig), "Task config must be LMTaskConfig for this app." + + # Derive tokenizer path (adjust if stored differently) + tokenizer_path = f"chandan-sreedhara/SimpleStories-{task_config.model_size}" + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + add_bos_token=False, + unk_token="[UNK]", + eos_token="[EOS]", + bos_token=None, ) - # Derive tokenizer path - tokenizer_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}" - # Use the base tokenizer from AutoTokenizer for consistency if needed, - # but create_data_loader might load its own. Ensure they are compatible. - # For decoding/mapping, AutoTokenizer is convenient. 
- hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, legacy=False) - - # Extract components and gates - gates: dict[str, Gate | GateMLP] = { - k.removeprefix("gates.").replace("-", "."): v for k, v in ss_model.gates.items() - } - components: dict[str, LinearComponentWithBias] = { - k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items() - } - target_layer_names = sorted(list(components.keys())) - - logger.info(f"Finished loading resources for {model_path}.") - return { - "model": ss_model, - "tokenizer": hf_tokenizer, # Use HF Tokenizer for decoding/mapping - "config": config, - "gates": gates, - "components": components, - "target_layer_names": target_layer_names, - "device": device, - "tokenizer_path": tokenizer_path, # Store path for dataloader - } - - -def create_eval_dataloader_iter( - app_state: AppState, -) -> Iterator[dict[str, Int[Tensor, "1 seq_len"]]]: - """Creates a new iterator for the evaluation dataloader.""" - config: Config = app_state["config"] - task_config: LMTaskConfig = cast(LMTaskConfig, config.task_config) - tokenizer_path: str = app_state["tokenizer_path"] - logger.info("Creating new evaluation dataloader iterator.") - + # Create eval dataloader config eval_data_config = DatasetConfig( name=task_config.dataset_name, - tokenizer_file_path=None, # Use HF tokenizer path + tokenizer_file_path=None, hf_tokenizer_path=tokenizer_path, split=task_config.eval_data_split, n_ctx=task_config.max_seq_len, - is_tokenized=False, # Tokenize on the fly - streaming=True, # Use streaming as requested + is_tokenized=False, + streaming=False, # Non-streaming might be simpler for iterator reset column_name="story", - seed=config.seed, # Use same seed for reproducibility if needed ) - dataloader, _ = create_data_loader( - dataset_config=eval_data_config, - batch_size=1, # Always use batch size 1 for this app - buffer_size=task_config.buffer_size, - global_seed=config.seed, - ddp_rank=0, - ddp_world_size=1, - ) - # Make the dataloader an explicit iterator - return iter(dataloader) + # Create the dataloader iterator + def create_dataloader_iter() -> Iterator[dict[str, Any]]: + """ + Returns a *new* iterator each time it is called. + Each element is a dict with: + - "text": the raw document text + - "input_ids": Int[Tensor, "1 seq_len"] + - "offset_mapping": list[tuple[int, int]] + """ + logger.info("Creating new dataloader iterator.") + + # Stream the HF dataset split + dataset = load_dataset( + eval_data_config.name, + streaming=eval_data_config.streaming, + split=eval_data_config.split, + trust_remote_code=False, + ) + dataset = dataset.with_format("torch") -def get_token_mapping( - tokenizer: AutoTokenizer, input_ids: Int[Tensor, "1 seq_len"] -) -> tuple[str, TokenMap]: - """ - Decodes input_ids and creates a mapping from character spans to token info. - Handles potential decoding artifacts like extra spaces. - """ - ids_list = input_ids[0].tolist() - full_text = tokenizer.decode(ids_list, skip_special_tokens=True) - logger.debug(f"Full decoded text length: {len(full_text)}") - - token_map: TokenMap = [] - current_char_index = 0 - - for token_idx, token_id in enumerate(ids_list): - # Decode individual token *without* special tokens or added spaces - # Note: This might differ slightly from full decode for some tokenizers (e.g., SentencePiece) - # We prioritize matching the token's contribution to the full decoded string. 
- token_text = tokenizer.decode( - [token_id], skip_special_tokens=True, clean_up_tokenization_spaces=False - ) + text_column = eval_data_config.column_name - # Find the *next* occurrence of this token's text in the full string - try: - # Find the start index, searching from the current position - start_char = full_text.index(token_text, current_char_index) - end_char = start_char + len(token_text) - - # Store mapping info - token_map_item: TokenMapItem = { - "text": token_text, - "span": (start_char, end_char), - "index": token_idx, - "id": token_id, - } - token_map.append(token_map_item) - # logger.debug(f"Mapped token {token_idx} (ID: {token_id}, Text: '{token_text}') to span {token_map_item['span']}") - - # Update current character index for the next search - current_char_index = end_char - - except ValueError: - # This can happen if the individual token decode differs significantly - # from its representation in the full decode (e.g., spaces, merges) - logger.warning( - f"Could not find token_text='{token_text}' (ID: {token_id}, Index: {token_idx}) " - f"in remaining full_text='{full_text[current_char_index:]}'. Skipping token mapping." + def tokenize_and_prepare(example: dict[str, Any]) -> dict[str, Any]: + original_text: str = example[text_column] + + tokenized = tokenizer( + original_text, + return_tensors="pt", + return_offsets_mapping=True, + truncation=True, + max_length=task_config.max_seq_len, + padding=False, ) - # Attempt to gracefully handle by skipping or trying alternative decodes if necessary - # For now, we just log and potentially skip. A robust solution might require - # tokenizer-specific logic or offset mapping if available. - - # Verification step (optional but recommended) - if current_char_index != len(full_text) and len(token_map) == len(ids_list): - logger.warning( - f"Final character index {current_char_index} does not match full text length {len(full_text)}. Mapping might be imperfect." - ) - elif len(token_map) != len(ids_list): - logger.warning( - f"Mapped {len(token_map)} tokens, but expected {len(ids_list)}. Mapping is incomplete." 
- ) - return full_text, token_map + input_ids: Int[Tensor, "1 seq_len"] = tokenized["input_ids"] + if input_ids.dim() == 1: # Ensure 2‑D [1, seq_len] + input_ids = input_ids.unsqueeze(0) + # HF returns offset_mapping as a list per sequence; batch size is 1 + offset_mapping: list[tuple[int, int]] = tokenized["offset_mapping"][0].tolist() -@torch.no_grad() -def calculate_masks_for_batch( - app_state: AppState, input_ids: Int[Tensor, "1 seq_len"] -) -> dict[str, Float[Tensor, "1 seq_len m"]]: - """Performs forward pass and calculates masks for the given input_ids.""" - model: SSModel = app_state["model"] - components: dict[str, LinearComponentWithBias] = app_state["components"] - gates: dict[str, Gate | GateMLP] = app_state["gates"] - device: str = app_state["device"] + return { + "text": original_text, + "input_ids": input_ids, + "offset_mapping": offset_mapping, + } - input_ids = input_ids.to(device) + # Map over the streaming dataset and return an iterator + return map(tokenize_and_prepare, iter(dataset)) - logger.info("Running forward pass to get activations...") - (_, _), pre_weight_acts = model.forward_with_pre_forward_cache_hooks( - input_ids, module_names=list(components.keys()) - ) - logger.info("Calculating component activations...") - As = {module_name: v.linear_component.A for module_name, v in components.items()} - target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + # Extract components and gates + gates: dict[str, Gate | GateMLP] = { + k.removeprefix("gates.").replace("-", "."): v for k, v in ss_model.gates.items() + } # type: ignore[reportAssignmentType] + components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items() + } # type: ignore[reportAssignmentType] + target_layer_names = sorted(list(components.keys())) - logger.info("Calculating masks...") - masks, _ = calc_masks( + logger.info(f"Initialization complete for {model_path}.") + return AppData( + model=ss_model, + tokenizer=tokenizer, + config=config, + dataloader_iter_fn=create_dataloader_iter, gates=gates, - target_component_acts=target_component_acts, - attributions=None, - detach_inputs=True, # No gradients needed for inference + components=components, + target_layer_names=target_layer_names, + device=device, ) - logger.info("Mask calculation complete.") - # Ensure masks are on CPU for Gradio state if needed, although calculations were on device - return {k: v.cpu() for k, v in masks.items()} -def update_token_display( - selected_token_info: TokenMapItem | None, - selected_layer_name: str | None, - current_masks: dict[str, Float[Tensor, "1 seq_len m"]] | None, -) -> tuple[str, str, Any, bool]: +# ----------------------------------------------------------- +# Utility: render the prompt with faint token outlines +# ----------------------------------------------------------- +def render_prompt_with_tokens( + *, + raw_text: str, + offset_mapping: list[tuple[int, int]], + selected_idx: int | None, +) -> None: """ - Generates display updates for the token info area based on selection. - Returns: (token_summary_md, active_comp_summary_md, active_indices_df_data, area_visible) + Renders `raw_text` inside Streamlit, wrapping each token span with a thin + border. The currently‑selected token receives a thicker red border. + All other tokens get a thin mid‑grey border (no background fill). 
""" - if not selected_token_info or not selected_layer_name or not current_masks: - logger.debug("Update condition not met, hiding token info area.") - return "", "", None, False # Hide area if no token/layer/masks + html_chunks: list[str] = [] + cursor = 0 + + def esc(s: str) -> str: + return html.escape(s) - token_idx = selected_token_info["index"] - token_id = selected_token_info["id"] - token_text = selected_token_info["text"] + for idx, (start, end) in enumerate(offset_mapping): + if cursor < start: + html_chunks.append(esc(raw_text[cursor:start])) - logger.info( - f"Updating display for token {token_idx} ('{token_text}'), layer {selected_layer_name}" + token_substr = esc(raw_text[start:end]) + if token_substr: + is_selected = idx == selected_idx + border_style = ( + "2px solid rgb(200,0,0)" if is_selected else "0.5px solid #aaa" # all other tokens + ) + html_chunks.append( + "' + f"{token_substr}" + ) + cursor = end + + if cursor < len(raw_text): + html_chunks.append(esc(raw_text[cursor:])) + + st.markdown( + f'
{"".join(html_chunks)}
', + unsafe_allow_html=True, ) - token_summary_md = f"**Token Info:** '{token_text}' (Position: {token_idx}, ID: {token_id})" - - if selected_layer_name not in current_masks: - logger.warning(f"Selected layer {selected_layer_name} not found in current masks.") - return token_summary_md, "Error: Layer not found in masks.", None, True - - layer_mask_tensor: Float[Tensor, "1 seq_len m"] = current_masks[selected_layer_name] - # Ensure tensor is on CPU before indexing if it wasn't already - token_mask: Float[Tensor, " m"] = layer_mask_tensor[0, token_idx, :].cpu() - - # Find active components (mask > 0) - active_indices_layer: Int[Tensor, " n_active"] = torch.where(token_mask > 0)[0] - n_active_layer = len(active_indices_layer) - - active_comp_summary_md = f"**Active Components in {selected_layer_name}:** {n_active_layer}" - - active_indices_df_data = None - if n_active_layer > 0: - # Convert to NumPy array and reshape for DataFrame (N x 1) - active_indices_np = active_indices_layer.cpu().numpy().reshape(-1, 1) - active_indices_df_data = active_indices_np - logger.debug(f"Found {n_active_layer} active components.") - else: - logger.debug("No active components found.") - - return token_summary_md, active_comp_summary_md, active_indices_df_data, True - - -# --- Gradio Application --- - -# Store the iterator globally (or in a mutable container) for access within handlers -# This is suitable for single-session Gradio apps. -global_iterator_store: dict[str, Iterator[dict[str, Int[Tensor, "1 seq_len"]]] | None] = { - "iterator": None -} - - -def build_gradio_app(args: argparse.Namespace) -> gr.Blocks: - """Builds the Gradio Blocks interface.""" - - device = "cuda" if torch.cuda.is_available() else "cpu" - # Load resources once - initial_app_state = load_resources(args.model_path, device) - # Create the initial iterator and store it globally - global_iterator_store["iterator"] = create_eval_dataloader_iter(initial_app_state) - - with gr.Blocks(title="LM Component Activation Explorer") as app: - # --- State Management --- - # Store heavy objects and mutable state here - app_state = gr.State(initial_app_state) - # REMOVE: dataloader_iter = gr.State(initial_dataloader_iter) # CANNOT STORE ITERATOR IN STATE - # State for current prompt data - current_input_ids = gr.State(None) - current_token_map = gr.State(None) # Stores the TokenMap list - current_masks = gr.State(None) - # State for user selections - selected_token_info = gr.State(None) # Stores the selected TokenMapItem - selected_layer_name = gr.State( - initial_app_state["target_layer_names"][0] - if initial_app_state["target_layer_names"] - else None - ) - # --- UI Layout --- - gr.Markdown(f"# LM Component Activation Explorer\n**Model:** {args.model_path}") - - with gr.Row(): - # Use Textbox for raw display, HighlightedText for interaction - # prompt_display_text = gr.Textbox(label="Prompt Text", lines=5, interactive=False) - prompt_display_highlight = gr.HighlightedText( - label="Prompt (Click Tokens)", - interactive=True, - combine_adjacent=False, # Treat each token span separately - show_legend=False, - ) - next_button = gr.Button("Load Next Prompt") - - layer_dropdown = gr.Dropdown( - label="Select Layer", - choices=initial_app_state["target_layer_names"], - value=initial_app_state["target_layer_names"][0] - if initial_app_state["target_layer_names"] - else None, - interactive=True, - ) +def load_next_prompt() -> None: + """Loads the next prompt, calculates masks, and prepares token data.""" + logger.info("Loading next prompt.") + app_data: 
AppData = st.session_state.app_data + dataloader_iter = st.session_state.dataloader_iter # Get current iterator - # Initially hidden area for token details - with gr.Column(visible=False) as token_info_area: - token_summary = gr.Markdown() - active_components_summary = gr.Markdown() - active_indices_table = gr.DataFrame( - headers=["Component Index"], - datatype=["number"], - col_count=(1, "fixed"), - # max_rows=10, # Limit visible rows initially - interactive=False, - label="Active Component Indices", - ) - future_plots_placeholder = gr.Markdown("*Future analyses will appear here.*") - - # --- Event Handlers --- - - def load_next_prompt_data( - app_state_val: AppState, # Removed current_iter_state input - ) -> tuple[ - # REMOVED: Iterator, # Updated iterator state - Int[Tensor, "1 seq_len"], # current_input_ids - TokenMap, # current_token_map - dict[str, Float[Tensor, "1 seq_len m"]], # current_masks - list[tuple[str, str | None]], # HighlightedText value - # Reset selections - None, # selected_token_info - str, # token_summary update - str, # active_comp_summary update - None, # active_indices_table update - gr.update, # Return an update for the Column - ]: - """Gets next batch, calculates masks, prepares display data.""" - logger.info("Attempting to load next prompt...") - # Access the iterator from the global store - iterator = global_iterator_store["iterator"] - if iterator is None: - logger.error("Iterator not initialized.") - raise gr.Error("Iterator not initialized. Please restart the app.") - - try: - batch = next(iterator) - input_ids: Int[Tensor, "1 seq_len"] = batch[ - "input_ids" - ] # Should be shape (1, seq_len) - logger.info(f"Loaded batch with shape: {input_ids.shape}") - except StopIteration: - logger.warning("Dataloader iterator exhausted. Resetting.") - # Recreate iterator and update the global store - iterator = create_eval_dataloader_iter(app_state_val) - global_iterator_store["iterator"] = iterator - try: - batch = next(iterator) - input_ids = batch["input_ids"] - logger.info(f"Loaded first batch after reset, shape: {input_ids.shape}") - except StopIteration: - logger.error("Failed to get data even after resetting dataloader.") - # Handle error state appropriately, maybe raise gr.Error - raise gr.Error("Dataset seems empty or failed to load.") - - # Ensure input_ids are on CPU for mapping - input_ids_cpu = input_ids.cpu() - tokenizer = app_state_val["tokenizer"] - - # 1. Get Token Mapping - full_text, token_map = get_token_mapping(tokenizer, input_ids_cpu) - # Format for HighlightedText: list of (text_substring, label/tooltip) - # We use the token index as the label for now. - highlight_data = [(item["text"], f"Token {item['index']}") for item in token_map] - - # 2. Calculate Masks (can run on GPU if available) - masks = calculate_masks_for_batch( - app_state_val, input_ids - ) # input_ids passed might be on CPU or GPU based on device - - # 3. Return all updates (excluding the iterator) - return ( - # REMOVED: current_iter_state, - input_ids_cpu, # Store CPU version in state - token_map, - masks, - highlight_data, - None, # Reset selected token - "", # Reset token_summary - "", # Reset active_components_summary - None, # Reset active_indices_table - gr.update(visible=False), # Return update object - ) + try: + batch = next(dataloader_iter) + input_ids: Int[Tensor, "1 seq_len"] = batch["input_ids"].to(app_data.device) + except StopIteration: + logger.warning("Dataloader iterator exhausted. 
Throwing error.") + st.error("Failed to get data even after resetting dataloader.") + return - next_button.click( - fn=load_next_prompt_data, - inputs=[app_state], # Removed dataloader_iter - outputs=[ - # REMOVED: dataloader_iter, - current_input_ids, - current_token_map, - current_masks, - prompt_display_highlight, - selected_token_info, # Reset - token_summary, - active_components_summary, - active_indices_table, - token_info_area, # Reset - ], - queue=True, # Use queue for potentially long-running model calls - ) + st.session_state.current_input_ids = input_ids - def handle_token_select( - evt: gr.SelectData, - current_token_map_val: TokenMap | None, - current_masks_val: dict[str, Float[Tensor, "1 seq_len m"]] | None, - selected_layer_name_val: str | None, - ) -> tuple[TokenMapItem | None, str, str, Any, gr.update]: - """Handles token selection, finds corresponding token info, updates display.""" - logger.debug(f"Token selected event: Index={evt.index}, Value='{evt.value}'") - selected_info = None - if current_token_map_val: - # Find the token corresponding to the selected character span - char_index = evt.index[0] # Start index of the selection - for item in current_token_map_val: - start, end = item["span"] - # Check if the click falls within this token's span - if start <= char_index < end: - selected_info = item - logger.info( - f"Mapped selection at char {char_index} to token index {selected_info['index']}" - ) - break - if not selected_info: - logger.warning(f"Could not map character index {char_index} to any token span.") - - # Update the display based on the found token and current layer - token_sum, active_sum, active_table, visible = update_token_display( - selected_info, selected_layer_name_val, current_masks_val - ) - return selected_info, token_sum, active_sum, active_table, gr.update(visible=visible) - - prompt_display_highlight.select( - fn=handle_token_select, - inputs=[current_token_map, current_masks, selected_layer_name], - outputs=[ - selected_token_info, - token_summary, - active_components_summary, - active_indices_table, - token_info_area, - ], - queue=False, # Selection should be fast - ) + # Store the original raw prompt text + st.session_state.current_prompt_text = batch["text"] - def handle_layer_change( - layer_name: str, - selected_token_info_val: TokenMapItem | None, - current_masks_val: dict[str, Float[Tensor, "1 seq_len m"]] | None, - ) -> tuple[str, str, Any, gr.update]: - """Handles layer change, updates display if a token is selected.""" - logger.info(f"Layer changed to: {layer_name}") - # Update display based on the new layer and existing token selection - token_sum, active_sum, active_table, visible = update_token_display( - selected_token_info_val, layer_name, current_masks_val - ) - return token_sum, active_sum, active_table, gr.update(visible=visible) - - layer_dropdown.change( - fn=handle_layer_change, - inputs=[layer_dropdown, selected_token_info, current_masks], - outputs=[ - token_summary, - active_components_summary, - active_indices_table, - token_info_area, - ], - queue=False, # Layer change update should be fast + # Calculate activations and masks + with torch.no_grad(): + (_, _), pre_weight_acts = app_data.model.forward_with_pre_forward_cache_hooks( + input_ids, module_names=list(app_data.components.keys()) ) - - # --- Initial Load --- - # Use the same function, but ensure the global iterator is used - app.load( - fn=load_next_prompt_data, - inputs=[app_state], - outputs=[ - # REMOVED: dataloader_iter, - current_input_ids, - 
-
-    # --- Initial Load ---
-    # Use the same function, but ensure the global iterator is used
-    app.load(
-        fn=load_next_prompt_data,
-        inputs=[app_state],
-        outputs=[
-            # REMOVED: dataloader_iter,
-            current_input_ids,
-            current_token_map,
-            current_masks,
-            prompt_display_highlight,
-            selected_token_info,  # Reset
-            token_summary,
-            active_components_summary,
-            active_indices_table,
-            token_info_area,  # Reset - Keep this here
-        ],
-        queue=True,
-    )
+        As = {module_name: v.linear_component.A for module_name, v in app_data.components.items()}
+        target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As)  # type: ignore[reportArgumentType]
+        masks, _ = calc_masks(
+            gates=app_data.gates,
+            target_component_acts=target_component_acts,
+            attributions=None,
+            detach_inputs=True,  # No gradients needed
+        )
+    st.session_state.current_masks = masks  # Dict[str, Float[Tensor, "1 seq_len m"]]
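+    # current_masks maps each target module name to a (1, seq_len, m) tensor;
+    # entries > 0 are the components the UI treats as active for a given token
+    # (see the torch.where(token_mask > 0) lookup further down).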
+
+    # Prepare token data for display
+    token_data = []
+    tokenizer = app_data.tokenizer
+    for i, token_id in enumerate(input_ids[0]):
+        # Decode individual token - might differ slightly from full decode for spaces etc.
+        decoded_token_str = tokenizer.decode([token_id])  # type: ignore[reportAttributeAccessIssue]
+        token_data.append(
+            {
+                "id": token_id.item(),
+                "text": decoded_token_str,
+                "index": i,
+                "offset": batch["offset_mapping"][i],  # (start, end)
+            }
+        )
+    st.session_state.token_data = token_data
+
+    # Reset selections
+    st.session_state.selected_token_index = 0  # default: first token
+    st.session_state.selected_layer_name = None
+    logger.info("Finished loading next prompt and calculating masks.")
+
+
+# --- Main App UI ---
 def run_app(args: argparse.Namespace) -> None:
     """Sets up and runs the Streamlit application."""
     st.set_page_config(layout="wide")
     st.title("LM Component Activation Explorer")
+    # Initialize model, data, etc. (cached)
+    st.session_state.app_data = initialize(args.model_path)
+    app_data: AppData = st.session_state.app_data
+    st.caption(f"Model: {args.model_path}")
+
+    # Initialize session state variables if they don't exist
+    if "current_prompt_text" not in st.session_state:
+        st.session_state.current_prompt_text = None
+    if "token_data" not in st.session_state:
+        st.session_state.token_data = None
+    if "current_masks" not in st.session_state:
+        st.session_state.current_masks = None
+    if "selected_token_index" not in st.session_state:
+        st.session_state.selected_token_index = None
+    if "selected_layer_name" not in st.session_state:
+        if app_data.target_layer_names:
+            st.session_state.selected_layer_name = app_data.target_layer_names[0]
+        else:
+            st.session_state.selected_layer_name = None
+    # Initialize the dataloader iterator in session state
+    if "dataloader_iter" not in st.session_state:
+        st.session_state.dataloader_iter = app_data.dataloader_iter_fn()
+
+    if st.session_state.current_prompt_text is None:
+        load_next_prompt()
+
+    # Sidebar container and a single expander for all interactive controls
+    sidebar = st.sidebar
+    controls_expander = sidebar.expander("Controls", expanded=True)
+
+    # ------------------------------------------------------------------
+    # Sidebar - interactive controls
+    # ------------------------------------------------------------------
+    with controls_expander:
+        st.button("Load Next Prompt", on_click=load_next_prompt)
+
+    # Render the raw prompt with faint token borders
+    if st.session_state.token_data and st.session_state.current_prompt_text:
+        # st.subheader("Prompt")
+        render_prompt_with_tokens(
+            raw_text=st.session_state.current_prompt_text,
+            offset_mapping=[t["offset"] for t in st.session_state.token_data],
+            selected_idx=st.session_state.selected_token_index,
         )
-    return app
-
+    # Sidebar slider for token selection
+    n_tokens = len(st.session_state.token_data)
+    if n_tokens > 0:
+        with controls_expander:
+            st.header("Token selector")
+            idx = st.slider(
+                "Token index",
+                min_value=0,
+                max_value=n_tokens - 1,
+                step=1,
+                key="selected_token_index",
+            )
+
+        selected_token = st.session_state.token_data[idx]
+        st.write(f"Selected token: {selected_token['text']} (ID: {selected_token['id']})")
+
+    st.divider()
+
+    # --- Token Information Area ---
+    if st.session_state.token_data:
+        idx = st.session_state.selected_token_index
+        # Ensure token_data is loaded before accessing
+        if (
+            st.session_state.token_data
+            and idx is not None
+            and idx < len(st.session_state.token_data)
+        ):
+            # Layer Selection Dropdown
+            # Always default to the first layer if nothing is selected yet
+            if st.session_state.selected_layer_name is None and app_data.target_layer_names:
+                st.session_state.selected_layer_name = app_data.target_layer_names[0]
+
+            with controls_expander:
+                st.header("Layer selector")
+                st.selectbox(
+                    "Select Layer to Inspect:",
+                    options=app_data.target_layer_names,
+                    key="selected_layer_name",
+                )
+
+            # Display Layer-Specific Info if a layer is selected
+            if st.session_state.selected_layer_name:
+                layer_name = st.session_state.selected_layer_name
+                logger.debug(f"Displaying info for token {idx}, layer {layer_name}")
+
+                if st.session_state.current_masks is None:
+                    st.warning("Masks not calculated yet. Please load a prompt.")
+                    return
+
+                layer_mask_tensor: Float[Tensor, "1 seq_len m"] = st.session_state.current_masks[
+                    layer_name
+                ]
+                token_mask: Float[Tensor, " m"] = layer_mask_tensor[0, idx, :]
+
+                # Find active components (mask > 0)
+                active_indices_layer: Int[Tensor, " n_active"] = torch.where(token_mask > 0)[0]
+                n_active_layer = len(active_indices_layer)
+
+                st.metric(f"Active Components in {layer_name}", n_active_layer)
+
+                st.subheader("Active Component Indices")
+                if n_active_layer > 0:
+                    # Convert to NumPy array and reshape to a column vector (N x 1)
+                    active_indices_np = active_indices_layer.cpu().numpy().reshape(-1, 1)
+                    # Pass the NumPy array directly and configure the column header
+                    st.dataframe(active_indices_np, height=300, use_container_width=False)
+                else:
+                    st.write("No active components for this token in this layer.")
+
+                # Extensibility Placeholder
+                st.subheader("Additional Layer/Token Analysis")
+                st.write(
+                    "Future figures and analyses for this specific layer and token will appear here."
+                )
+        else:
+            # Handle case where selected_token_index might be invalid after data reload
+            st.warning("Selected token index is out of bounds. Please select a token again.")
+            st.session_state.selected_token_index = None  # Reset selection
-# --- Main Execution ---
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Gradio app to explore LM component activations.")
+    parser = argparse.ArgumentParser(
+        description="Streamlit app to explore LM component activations."
+    )
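+    # Presumably launched the same way as the old streamlit_app.py, e.g.:
+    #   streamlit run spd/experiments/lm/app.py -- --model_path "wandb:spd-lm/runs/151bsctx"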
     parser.add_argument(
         "--model_path",
         type=str,
         default=DEFAULT_MODEL_PATH,
         help=f"Path or W&B reference to the trained SSModel. Default: {DEFAULT_MODEL_PATH}",
     )
-    parser.add_argument(
-        "--share",
-        action="store_true",
-        help="Create a publicly shareable link.",
-    )
     args = parser.parse_args()
-    gradio_app = build_gradio_app(args)
-    # The iterator is initialized inside build_gradio_app now
-    gradio_app.launch(share=args.share)
+    run_app(args)
diff --git a/spd/experiments/lm/streamlit_app.py b/spd/experiments/lm/streamlit_app.py
deleted file mode 100644
index 32f20d7..0000000
--- a/spd/experiments/lm/streamlit_app.py
+++ /dev/null
@@ -1,297 +0,0 @@
-"""
-To run this app, run the following command:
-
-```bash
-    streamlit run spd/experiments/lm/streamlit_app.py -- --model_path "wandb:spd-lm/runs/151bsctx"
-```
-"""
-
-import argparse
-from collections.abc import Iterator
-from typing import Any
-
-import streamlit as st
-import streamlit_antd_components as sac
-import torch
-from jaxtyping import Float, Int
-from simple_stories_train.dataloaders import DatasetConfig, create_data_loader
-from torch import Tensor
-from transformers import AutoTokenizer
-
-from spd.configs import LMTaskConfig
-from spd.experiments.lm.models import LinearComponentWithBias, SSModel
-from spd.log import logger
-from spd.models.components import Gate, GateMLP
-from spd.run_spd import calc_component_acts, calc_masks
-from spd.types import ModelPath
-
-DEFAULT_MODEL_PATH: ModelPath = "wandb:spd-lm/runs/151bsctx"
-
-
-# --- Initialization and Data Loading ---
-@st.cache_resource(show_spinner="Loading model and data...")
-def initialize(model_path: ModelPath) -> dict[str, Any]:
-    """
-    Loads the model, tokenizer, config, and evaluation dataloader.
-    Cached by Streamlit based on the model_path.
-    """
-    device = "cpu"  # Use CPU for the Streamlit app
-    logger.info(f"Initializing app with model: {model_path} on device: {device}")
-    ss_model, config, _ = SSModel.from_pretrained(model_path)
-    ss_model.to(device)
-    ss_model.eval()
-
-    assert isinstance(config.task_config, LMTaskConfig), (
-        "Task config must be LMTaskConfig for this app."
-    )
-
-    # Derive tokenizer path (adjust if stored differently)
-    tokenizer_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}"
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, legacy=False)
-
-    # Create eval dataloader config
-    eval_data_config = DatasetConfig(
-        name=config.task_config.dataset_name,
-        tokenizer_file_path=None,
-        hf_tokenizer_path=tokenizer_path,
-        split=config.task_config.eval_data_split,
-        n_ctx=config.task_config.max_seq_len,
-        is_tokenized=False,
-        streaming=False,  # Non-streaming might be simpler for iterator reset
-        column_name="story",
-    )
-
-    # Create the dataloader iterator
-    def create_dataloader_iter() -> Iterator[dict[str, Int[Tensor, "1 seq_len"]]]:
-        logger.info("Creating new dataloader iterator.")
-        dataloader, _ = create_data_loader(
-            dataset_config=eval_data_config,
-            batch_size=1,  # Always use batch size 1 for this app
-            buffer_size=config.task_config.buffer_size,
-            global_seed=config.seed,  # Use same seed for reproducibility
-            ddp_rank=0,
-            ddp_world_size=1,
-        )
-        return iter(dataloader)
-
-    # Extract components and gates
-    gates: dict[str, Gate | GateMLP] = {
-        k.removeprefix("gates.").replace("-", "."): v for k, v in ss_model.gates.items()
-    }
-    components: dict[str, LinearComponentWithBias] = {
-        k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items()
-    }
-    target_layer_names = sorted(list(components.keys()))
-
-    logger.info(f"Initialization complete for {model_path}.")
-    return {
-        "model": ss_model,
-        "tokenizer": tokenizer,
-        "config": config,
-        "dataloader_iter_fn": create_dataloader_iter,  # Store the function to create iter
-        "gates": gates,
-        "components": components,
-        "target_layer_names": target_layer_names,
-        "device": device,
-    }
-
-
-def load_next_prompt() -> None:
-    """Loads the next prompt, calculates masks, and prepares token data."""
-    logger.info("Loading next prompt.")
-    app_data = st.session_state.app_data
-    dataloader_iter = st.session_state.dataloader_iter  # Get current iterator
-
-    try:
-        batch = next(dataloader_iter)
-        input_ids: Int[Tensor, "1 seq_len"] = batch["input_ids"].to(app_data["device"])
-    except StopIteration:
-        logger.warning("Dataloader iterator exhausted. Throwing error.")
-        st.error("Failed to get data even after resetting dataloader.")
-        return
-
-    st.session_state.current_input_ids = input_ids
-
-    # Decode tokenized IDs to get the transformed text
-    st.session_state.transformed_prompt_text = app_data["tokenizer"].decode(
-        input_ids[0], skip_special_tokens=True
-    )
-
-    # Calculate activations and masks
-    with torch.no_grad():
-        (_, _), pre_weight_acts = app_data["model"].forward_with_pre_forward_cache_hooks(
-            input_ids, module_names=list(app_data["components"].keys())
-        )
-        As = {
-            module_name: v.linear_component.A for module_name, v in app_data["components"].items()
-        }
-        target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As)
-        masks, _ = calc_masks(
-            gates=app_data["gates"],
-            target_component_acts=target_component_acts,
-            attributions=None,
-            detach_inputs=True,  # No gradients needed
-        )
-    st.session_state.current_masks = masks  # Dict[str, Float[Tensor, "1 seq_len m"]]
-
-    # Prepare token data for display
-    token_data = []
-    tokenizer = app_data["tokenizer"]
-    for i, token_id in enumerate(input_ids[0]):
-        # Decode individual token - might differ slightly from full decode for spaces etc.
-        decoded_token_str = tokenizer.decode([token_id])
-        token_data.append({"id": token_id.item(), "text": decoded_token_str, "index": i})
-    st.session_state.token_data = token_data
-
-    # Reset selections
-    st.session_state.selected_token_index = None
-    st.session_state.selected_layer_name = None
-    logger.info("Finished loading next prompt and calculating masks.")
-
-
-def set_selected_token(index: int) -> None:
-    """Callback function to set the selected token index."""
-    # Check if the index is valid before setting
-    if 0 <= index < len(st.session_state.token_data):
-        logger.debug(f"Token {index} selected.")
-        st.session_state.selected_token_index = index
-    else:
-        logger.debug(f"Invalid index ({index}) received or no token clicked.")
-
-
-# --- Main App UI ---
-def run_app(args: argparse.Namespace) -> None:
-    """Sets up and runs the Streamlit application."""
-    st.set_page_config(layout="wide")
-    st.title("LM Component Activation Explorer")
-
-    # Initialize model, data, etc. (cached)
-    st.session_state.app_data = initialize(args.model_path)
-    app_data = st.session_state.app_data
-    st.caption(f"Model: {args.model_path}")
-
-    # Initialize session state variables if they don't exist
-    if "transformed_prompt_text" not in st.session_state:
-        st.session_state.transformed_prompt_text = None
-    if "token_data" not in st.session_state:
-        st.session_state.token_data = None
-    if "current_masks" not in st.session_state:
-        st.session_state.current_masks = None
-    if "selected_token_index" not in st.session_state:
-        st.session_state.selected_token_index = None
-    if "selected_layer_name" not in st.session_state:
-        if app_data["target_layer_names"]:
-            st.session_state.selected_layer_name = app_data["target_layer_names"][0]
-        else:
-            st.session_state.selected_layer_name = None
-    # Initialize the dataloader iterator in session state
-    if "dataloader_iter" not in st.session_state:
-        st.session_state.dataloader_iter = app_data["dataloader_iter_fn"]()
-
-    # --- Prompt Area ---
-    st.button("Load Initial / Next Prompt", on_click=load_next_prompt)
-
-    # Display Transformed (Decoded) Prompt using Clickable Tokens
-    if st.session_state.token_data:
-        st.subheader("Prompt (Encoded->Decoded, Click Tokens Below)")
-
-        # Use sac.buttons to create clickable text segments
-        clicked_token_index = sac.buttons(
-            items=[
-                sac.ButtonsItem(label=token_info["text"])
-                for i, token_info in enumerate(st.session_state.token_data)
-            ],
-            index=st.session_state.selected_token_index,
-            format_func=None,
-            align="left",
-            variant="text",
-            size="xs",
-            gap=1,
-            use_container_width=False,
-            return_index=True,
-            key="token_buttons",
-            radius=1,
-        )
-
-        # Update selected token based on click
-        if clicked_token_index != st.session_state.selected_token_index:
-            set_selected_token(clicked_token_index)
-            st.rerun()
-
-    st.divider()
-
-    # --- Token Information Area ---
-    if st.session_state.selected_token_index is not None:
-        idx = st.session_state.selected_token_index
-        # Ensure token_data is loaded before accessing
-        if st.session_state.token_data and idx < len(st.session_state.token_data):
-            token_info = st.session_state.token_data[idx]
-            token_text = token_info["text"]
-
-            st.header(f"Token Info: '{token_text}' (Position: {idx}, ID: {token_info['id']})")
-
-            # Layer Selection Dropdown
-            st.selectbox(
-                "Select Layer to Inspect:",
-                options=app_data["target_layer_names"],
-                key="selected_layer_name",  # Binds selection to session state
-            )
-
-            # Display Layer-Specific Info if a layer is selected
-            if st.session_state.selected_layer_name:
-                layer_name = st.session_state.selected_layer_name
-                logger.debug(f"Displaying info for token {idx}, layer {layer_name}")
-
-                if st.session_state.current_masks is None:
-                    st.warning("Masks not calculated yet. Please load a prompt.")
-                    return
-
-                layer_mask_tensor: Float[Tensor, "1 seq_len m"] = st.session_state.current_masks[
-                    layer_name
-                ]
-                token_mask: Float[Tensor, " m"] = layer_mask_tensor[0, idx, :]
-
-                # Find active components (mask > 0)
-                active_indices_layer: Int[Tensor, " n_active"] = torch.where(token_mask > 0)[0]
-                n_active_layer = len(active_indices_layer)
-
-                st.metric(f"Active Components in {layer_name}", n_active_layer)
-
-                st.subheader("Active Component Indices")
-                if n_active_layer > 0:
-                    # Convert to NumPy array and reshape to a column vector (N x 1)
-                    active_indices_np = active_indices_layer.cpu().numpy().reshape(-1, 1)
-                    # Pass the NumPy array directly and configure the column header
-                    st.dataframe(
-                        active_indices_np,
-                        height=300,
-                        use_container_width=False,
-                        column_config={0: "Component Index"},  # Rename the first column (index 0)
-                    )
-                else:
-                    st.write("No active components for this token in this layer.")
-
-                # Extensibility Placeholder
-                st.subheader("Additional Layer/Token Analysis")
-                st.write(
-                    "Future figures and analyses for this specific layer and token will appear here."
-                )
-        else:
-            # Handle case where selected_token_index might be invalid after data reload
-            st.warning("Selected token index is out of bounds. Please select a token again.")
-            st.session_state.selected_token_index = None  # Reset selection
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Streamlit app to explore LM component activations."
-    )
-    parser.add_argument(
-        "--model_path",
-        type=str,
-        default=DEFAULT_MODEL_PATH,
-        help=f"Path or W&B reference to the trained SSModel. Default: {DEFAULT_MODEL_PATH}",
-    )
-    args = parser.parse_args()
-
-    run_app(args)

From 9217638ca590a5a45fcc36f5dc09e9c908801d78 Mon Sep 17 00:00:00 2001
From: Dan Braun
Date: Tue, 22 Apr 2025 05:52:26 +0000
Subject: [PATCH 4/4] Add simple-stories-train and datasets to pyproject.toml

---
 pyproject.toml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 38b4d51..4031cfc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,8 @@ dependencies = [
     "sympy",
     "streamlit",
     "streamlit-antd-components",
+    "datasets",
+    "simple-stories-train"
 ]
 
 [project.optional-dependencies]