From 921c3b8f35ce2aad1ea06dfd8aa60bba28daeb01 Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Mon, 23 Jun 2025 17:38:55 -0700 Subject: [PATCH 1/5] working replication of tied CLTs --- clt/config/clt_config.py | 27 ++ clt/models/clt.py | 176 +++++++- clt/models/decoder.py | 144 +++++- clt/models/encoder.py | 23 + scripts/train_clt.py | 38 ++ ...nd-to-end-training-pythia-tied-decoders.py | 423 ++++++++++++++++++ 6 files changed, 815 insertions(+), 16 deletions(-) create mode 100644 tutorials/1F-end-to-end-training-pythia-tied-decoders.py diff --git a/clt/config/clt_config.py b/clt/config/clt_config.py index 3e0d577..3990c18 100644 --- a/clt/config/clt_config.py +++ b/clt/config/clt_config.py @@ -34,6 +34,14 @@ class CLTConfig: tl_input_template: Optional[str] = None # TransformerLens hook point pattern before MLP tl_output_template: Optional[str] = None # TransformerLens hook point pattern after MLP # context_size: Optional[int] = None + + # Tied decoder configuration + decoder_tying: Literal["none", "per_source"] = "none" # Decoder weight sharing strategy + per_target_scale: bool = False # Enable learned scale for each src->tgt path + per_target_bias: bool = False # Enable learned bias for each src->tgt path + enable_feature_offset: bool = False # Enable per-feature bias (feature_offset) + enable_feature_scale: bool = False # Enable per-feature scale (feature_scale) + skip_connection: bool = False # Enable skip connection from input to output def __post_init__(self): """Validate configuration parameters.""" @@ -60,6 +68,12 @@ def __post_init__(self): raise ValueError("topk_k must be specified for TopK activation function.") if self.topk_k is not None and self.topk_k <= 0: raise ValueError("topk_k must be positive if specified.") + + # Validate decoder tying configuration + valid_decoder_tying = ["none", "per_source"] + assert ( + self.decoder_tying in valid_decoder_tying + ), f"Invalid decoder_tying: {self.decoder_tying}. 
Must be one of {valid_decoder_tying}" @classmethod def from_json(cls: Type[C], json_path: str) -> C: @@ -73,6 +87,19 @@ def from_json(cls: Type[C], json_path: str) -> C: """ with open(json_path, "r") as f: config_dict = json.load(f) + + # Handle backward compatibility for old configs + if "decoder_tying" not in config_dict: + config_dict["decoder_tying"] = "none" # Default to original behavior + if "per_target_scale" not in config_dict: + config_dict["per_target_scale"] = False + if "per_target_bias" not in config_dict: + config_dict["per_target_bias"] = False + if "enable_feature_offset" not in config_dict: + config_dict["enable_feature_offset"] = False + if "enable_feature_scale" not in config_dict: + config_dict["enable_feature_scale"] = False + return cls(**config_dict) def to_json(self, json_path: str) -> None: diff --git a/clt/models/clt.py b/clt/models/clt.py index 5ff9f4a..d288161 100644 --- a/clt/models/clt.py +++ b/clt/models/clt.py @@ -179,17 +179,122 @@ def encode(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: ) return torch.zeros((expected_batch_dim, self.config.num_features), device=self.device, dtype=self.dtype) + def _apply_feature_affine(self, activations: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]: + """Apply per-feature offset and scale to activations if enabled. + + This function applies feature_offset and feature_scale only to non-zero activations. + This matches the reference implementation which applies post_enc only to selected features. 
+ + Args: + activations: Dictionary mapping layer indices to activation tensors + + Returns: + Modified activations dictionary with affine transformations applied + """ + if not self.config.enable_feature_offset and not self.config.enable_feature_scale: + return activations + + transformed_activations = {} + + for layer_idx, acts in activations.items(): + if acts.numel() == 0: + transformed_activations[layer_idx] = acts + continue + + # Get non-zero positions (selected features) + nonzero_mask = acts != 0 + + if not nonzero_mask.any(): + transformed_activations[layer_idx] = acts + continue + + # Work with a copy to avoid in-place operations + transformed_acts = acts.clone() + + if acts.dim() == 2: # [batch, features] + # Get indices of non-zero elements + batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True) + + if self.config.enable_feature_offset and self.encoder_module.feature_offset is not None: + # Apply offset only to selected features + offset_values = self.encoder_module.feature_offset[layer_idx][feature_indices] + transformed_acts[batch_indices, feature_indices] += offset_values + + if self.config.enable_feature_scale and self.encoder_module.feature_scale is not None: + # Apply scale only to selected features + scale_values = self.encoder_module.feature_scale[layer_idx][feature_indices] + transformed_acts[batch_indices, feature_indices] *= scale_values + else: + raise ValueError(f"Unexpected activation dimension: {acts.dim()}") + + transformed_activations[layer_idx] = transformed_acts + + return transformed_activations + def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: return self.decoder_module.decode(a, layer_idx) + + def _apply_skip_connection(self, input_tensor: torch.Tensor, layer_idx: int) -> torch.Tensor: + """Apply skip connection transformation to input. 
+ + Args: + input_tensor: Input tensor at the given layer + layer_idx: Target layer index + + Returns: + Transformed input through skip connection + """ + if self.decoder_module.skip_weights is None: + return torch.zeros_like(input_tensor) + + # Ensure input is 2D for matrix multiplication + original_shape = input_tensor.shape + if input_tensor.dim() == 3: + # Flatten batch and sequence dimensions + input_2d = input_tensor.view(-1, input_tensor.shape[-1]) + else: + input_2d = input_tensor + + # Apply skip connection weight + if self.config.decoder_tying == "per_source": + # Use skip weight for this target layer + skip_weight = self.decoder_module.skip_weights[layer_idx] + else: + # For untied, we need to sum contributions from all source layers + # For now, just use the diagonal skip connection (src=tgt) + skip_key = f"{layer_idx}->{layer_idx}" + if skip_key in self.decoder_module.skip_weights: + skip_weight = self.decoder_module.skip_weights[skip_key] + else: + return torch.zeros_like(input_tensor) + + # Apply transformation: input @ W_skip^T + skip_output = input_2d @ skip_weight.T + + # Reshape back to original shape + if input_tensor.dim() == 3: + skip_output = skip_output.view(original_shape) + + return skip_output def forward(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]: activations = self.get_feature_activations(inputs) + + # Apply feature affine transformation if enabled + activations = self._apply_feature_affine(activations) reconstructions = {} for layer_idx in range(self.config.num_layers): relevant_activations = {k: v for k, v in activations.items() if k <= layer_idx and v.numel() > 0} if layer_idx in inputs and relevant_activations: - reconstructions[layer_idx] = self.decode(relevant_activations, layer_idx) + reconstruction = self.decode(relevant_activations, layer_idx) + + # Apply skip connection if enabled + if self.config.skip_connection and layer_idx in inputs: + skip_output = self._apply_skip_connection(inputs[layer_idx], 
layer_idx) + reconstruction = reconstruction + skip_output + + reconstructions[layer_idx] = reconstruction elif layer_idx in inputs: batch_size = 0 input_tensor = inputs[layer_idx] @@ -325,3 +430,72 @@ def log_threshold(self, new_param: Optional[torch.nn.Parameter]) -> None: if not hasattr(self, "theta_manager") or self.theta_manager is None: raise AttributeError("ThetaManager is not initialised; cannot set log_threshold.") self.theta_manager.log_threshold = new_param + + def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True): + """Load state dict with backward compatibility for old checkpoints. + + Handles: + 1. Old untied decoder format -> new tied/untied format + 2. Missing theta_bias/theta_scale parameters + 3. Missing per_target_scale/per_target_bias parameters + """ + # Check if this is an old checkpoint by looking for decoder keys + old_format_decoder_keys = [k for k in state_dict.keys() if 'decoders.' in k and '->' in k] + is_old_checkpoint = len(old_format_decoder_keys) > 0 + + if is_old_checkpoint and self.config.decoder_tying == "per_source": + logger.warning( + "Loading old untied decoder checkpoint into tied decoder model. " + "This will use weights from the first target layer for each source layer." + ) + + # Convert old decoder weights to tied format + # For each source layer, use the weights from src->src decoder + new_state_dict = {} + for key, value in state_dict.items(): + if 'decoders.' 
in key and '->' in key: + # Extract source and target layer indices + # Key format: "decoder_module.decoders.{src}->{tgt}.weight" or ".bias" + parts = key.split('.') + decoder_key_idx = parts.index('decoders') + 1 + src_tgt = parts[decoder_key_idx].split('->') + src_layer = int(src_tgt[0]) + tgt_layer = int(src_tgt[1]) + param_type = parts[-1] # 'weight' or 'bias' + + # Only use diagonal decoders (src->src) for tied architecture + if src_layer == tgt_layer: + new_key = '.'.join(parts[:decoder_key_idx] + [str(src_layer), param_type]) + new_state_dict[new_key] = value + else: + new_state_dict[key] = value + state_dict = new_state_dict + + # Handle missing feature affine parameters + if self.config.enable_feature_offset and self.encoder_module.feature_offset is not None: + for i in range(self.config.num_layers): + key = f"encoder_module.feature_offset.{i}" + if key not in state_dict: + logger.info(f"Initializing missing {key} to zeros") + # Don't add to state_dict to let it be initialized by the module + + if self.config.enable_feature_scale and self.encoder_module.feature_scale is not None: + for i in range(self.config.num_layers): + key = f"encoder_module.feature_scale.{i}" + if key not in state_dict: + logger.info(f"Initializing missing {key} to ones") + # Don't add to state_dict to let it be initialized by the module + + # Handle missing per-target parameters + if self.config.per_target_scale and hasattr(self.decoder_module, 'per_target_scale'): + key = "decoder_module.per_target_scale" + if key not in state_dict: + logger.info(f"Initializing missing {key} to ones") + + if self.config.per_target_bias and hasattr(self.decoder_module, 'per_target_bias'): + key = "decoder_module.per_target_bias" + if key not in state_dict: + logger.info(f"Initializing missing {key} to zeros") + + # Call parent's load_state_dict + return super().load_state_dict(state_dict, strict=strict) diff --git a/clt/models/decoder.py b/clt/models/decoder.py index 68fd5c5..13db8b8 100644 --- 
a/clt/models/decoder.py +++ b/clt/models/decoder.py @@ -38,9 +38,11 @@ def __init__( self.world_size = dist_ops.get_world_size(process_group) self.rank = dist_ops.get_rank(process_group) - self.decoders = nn.ModuleDict( - { - f"{src_layer}->{tgt_layer}": RowParallelLinear( + # Initialize decoders based on tying configuration + if config.decoder_tying == "per_source": + # Tied decoders: one decoder per source layer + self.decoders = nn.ModuleList([ + RowParallelLinear( in_features=self.config.num_features, out_features=self.config.d_model, bias=True, @@ -51,10 +53,76 @@ def __init__( device=self.device, dtype=self.dtype, ) - for src_layer in range(self.config.num_layers) - for tgt_layer in range(src_layer, self.config.num_layers) - } - ) + for _ in range(self.config.num_layers) + ]) + + # Initialize decoder weights to zeros for tied transcoders + # This matches the reference implementation + for decoder in self.decoders: + nn.init.zeros_(decoder.weight) + if decoder.bias_param is not None: + nn.init.zeros_(decoder.bias_param) + + # Initialize per-target scale and bias if enabled + if config.per_target_scale: + self.per_target_scale = nn.Parameter( + torch.ones(self.config.num_layers, self.config.num_layers, self.config.d_model, + device=self.device, dtype=self.dtype) + ) + else: + self.per_target_scale = None + + if config.per_target_bias: + self.per_target_bias = nn.Parameter( + torch.zeros(self.config.num_layers, self.config.num_layers, self.config.d_model, + device=self.device, dtype=self.dtype) + ) + else: + self.per_target_bias = None + else: + # Original untied decoders: one decoder per (src, tgt) pair + self.decoders = nn.ModuleDict( + { + f"{src_layer}->{tgt_layer}": RowParallelLinear( + in_features=self.config.num_features, + out_features=self.config.d_model, + bias=True, + process_group=self.process_group, + input_is_parallel=False, + d_model_for_init=self.config.d_model, + num_layers_for_init=self.config.num_layers, + device=self.device, + 
dtype=self.dtype, + ) + for src_layer in range(self.config.num_layers) + for tgt_layer in range(src_layer, self.config.num_layers) + } + ) + self.per_target_scale = None + self.per_target_bias = None + + # Initialize skip connection weights if enabled + if config.skip_connection: + if config.decoder_tying == "per_source": + # For tied decoders, one skip connection per target layer + self.skip_weights = nn.ParameterList([ + nn.Parameter(torch.zeros(self.config.d_model, self.config.d_model, + device=self.device, dtype=self.dtype)) + for _ in range(self.config.num_layers) + ]) + else: + # For untied decoders, one skip connection per src->tgt pair + self.skip_weights = nn.ParameterDict({ + f"{src_layer}->{tgt_layer}": nn.Parameter( + torch.zeros(self.config.d_model, self.config.d_model, + device=self.device, dtype=self.dtype) + ) + for src_layer in range(self.config.num_layers) + for tgt_layer in range(src_layer, self.config.num_layers) + }) + else: + self.skip_weights = None + self.register_buffer("_cached_decoder_norms", None, persistent=False) def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: @@ -99,8 +167,21 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: ) continue - decoder = self.decoders[f"{src_layer}->{layer_idx}"] - decoded = decoder(activation_tensor) + if self.config.decoder_tying == "per_source": + # Use tied decoder for the source layer + decoder = self.decoders[src_layer] + decoded = decoder(activation_tensor) + + # Apply per-target scale and bias if enabled + if self.per_target_scale is not None: + decoded = decoded * self.per_target_scale[src_layer, layer_idx] + if self.per_target_bias is not None: + decoded = decoded + self.per_target_bias[src_layer, layer_idx] + else: + # Use untied decoder for (src, tgt) pair + decoder = self.decoders[f"{src_layer}->{layer_idx}"] + decoded = decoder(activation_tensor) + reconstruction += decoded return reconstruction @@ -143,10 +224,10 @@ def 
get_decoder_norms(self) -> torch.Tensor: for src_layer in range(self.config.num_layers): local_norms_sq_accum = torch.zeros(self.config.num_features, device=self.device, dtype=torch.float32) - for tgt_layer in range(src_layer, self.config.num_layers): - decoder_key = f"{src_layer}->{tgt_layer}" - decoder = self.decoders[decoder_key] - assert isinstance(decoder, RowParallelLinear), f"Decoder {decoder_key} is not RowParallelLinear" + if self.config.decoder_tying == "per_source": + # For tied decoders, compute norms once per source layer + decoder = self.decoders[src_layer] + assert isinstance(decoder, RowParallelLinear), f"Decoder {src_layer} is not RowParallelLinear" current_norms_sq = torch.norm(decoder.weight, dim=0).pow(2).to(torch.float32) @@ -161,7 +242,7 @@ def get_decoder_norms(self) -> torch.Tensor: pass elif local_dim_padded != actual_local_dim and local_dim_padded != features_per_rank: logger.warning( - f"Rank {self.rank}: Padded local dim ({local_dim_padded}) doesn't match calculated actual local dim ({actual_local_dim}) or features_per_rank ({features_per_rank}) for {decoder_key}. This might indicate an issue with RowParallelLinear partitioning." + f"Rank {self.rank}: Padded local dim ({local_dim_padded}) doesn't match calculated actual local dim ({actual_local_dim}) or features_per_rank ({features_per_rank}) for decoder {src_layer}. This might indicate an issue with RowParallelLinear partitioning." ) if actual_local_dim > 0: @@ -171,9 +252,42 @@ def get_decoder_norms(self) -> torch.Tensor: local_norms_sq_accum[global_slice] += valid_norms_sq else: logger.warning( - f"Rank {self.rank}: Shape mismatch in decoder norm calculation for {decoder_key}. " + f"Rank {self.rank}: Shape mismatch in decoder norm calculation for decoder {src_layer}. " f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}." 
) + else: + # For untied decoders, accumulate norms from all target layers + for tgt_layer in range(src_layer, self.config.num_layers): + decoder_key = f"{src_layer}->{tgt_layer}" + decoder = self.decoders[decoder_key] + assert isinstance(decoder, RowParallelLinear), f"Decoder {decoder_key} is not RowParallelLinear" + + current_norms_sq = torch.norm(decoder.weight, dim=0).pow(2).to(torch.float32) + + full_dim = decoder.full_in_features + features_per_rank = (full_dim + self.world_size - 1) // self.world_size + start_idx = self.rank * features_per_rank + end_idx = min(start_idx + features_per_rank, full_dim) + actual_local_dim = max(0, end_idx - start_idx) + local_dim_padded = decoder.local_in_features + + if local_dim_padded != features_per_rank and self.rank == self.world_size - 1: + pass + elif local_dim_padded != actual_local_dim and local_dim_padded != features_per_rank: + logger.warning( + f"Rank {self.rank}: Padded local dim ({local_dim_padded}) doesn't match calculated actual local dim ({actual_local_dim}) or features_per_rank ({features_per_rank}) for {decoder_key}. This might indicate an issue with RowParallelLinear partitioning." + ) + + if actual_local_dim > 0: + valid_norms_sq = current_norms_sq[:actual_local_dim] + if valid_norms_sq.shape[0] == actual_local_dim: + global_slice = slice(start_idx, end_idx) + local_norms_sq_accum[global_slice] += valid_norms_sq + else: + logger.warning( + f"Rank {self.rank}: Shape mismatch in decoder norm calculation for {decoder_key}. " + f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}." 
+ ) if self.process_group is not None and dist_ops.is_dist_initialized_and_available(): dist_ops.all_reduce(local_norms_sq_accum, op=dist_ops.SUM, group=self.process_group) diff --git a/clt/models/encoder.py b/clt/models/encoder.py index b9032e1..73d6e21 100644 --- a/clt/models/encoder.py +++ b/clt/models/encoder.py @@ -47,6 +47,29 @@ def __init__( for _ in range(config.num_layers) ] ) + + # Initialize theta_bias and theta_scale parameters if enabled + # These are per-layer, per-feature parameters + # Note: For tensor parallelism, each rank only holds a shard of features + features_per_rank = config.num_features // self.world_size + + if config.enable_feature_offset: + # Initialize feature_offset for each layer + self.feature_offset = nn.ParameterList([ + nn.Parameter(torch.zeros(features_per_rank, device=self.device, dtype=self.dtype)) + for _ in range(config.num_layers) + ]) + else: + self.feature_offset = None + + if config.enable_feature_scale: + # Initialize feature_scale for each layer + self.feature_scale = nn.ParameterList([ + nn.Parameter(torch.ones(features_per_rank, device=self.device, dtype=self.dtype)) + for _ in range(config.num_layers) + ]) + else: + self.feature_scale = None def get_preactivations(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: """Get pre-activation values (full tensor) for features at the specified layer.""" diff --git a/scripts/train_clt.py b/scripts/train_clt.py index bfca521..a79a316 100644 --- a/scripts/train_clt.py +++ b/scripts/train_clt.py @@ -245,6 +245,38 @@ def parse_args(): default=None, help="Optional data type for the CLT model parameters (e.g., 'float16', 'bfloat16').", ) + clt_group.add_argument( + "--decoder-tying", + type=str, + choices=["none", "per_source"], + default="none", + help="Decoder weight sharing strategy: 'none' (default) or 'per_source' (tied per source layer).", + ) + clt_group.add_argument( + "--per-target-scale", + action="store_true", + help="Enable learned scale for each src->tgt path 
when using tied decoders.", + ) + clt_group.add_argument( + "--per-target-bias", + action="store_true", + help="Enable learned bias for each src->tgt path when using tied decoders.", + ) + clt_group.add_argument( + "--enable-feature-offset", + action="store_true", + help="Enable per-feature bias (theta_bias) applied after encoding.", + ) + clt_group.add_argument( + "--enable-feature-scale", + action="store_true", + help="Enable per-feature scale (theta_scale) applied after encoding.", + ) + clt_group.add_argument( + "--skip-connection", + action="store_true", + help="Enable skip connection from input to output.", + ) # --- Training Hyperparameters (TrainingConfig) --- train_group = parser.add_argument_group("Training Hyperparameters (TrainingConfig)") @@ -608,6 +640,12 @@ def main(): clt_dtype=args.clt_dtype, topk_k=args.topk_k, topk_straight_through=(not args.disable_topk_straight_through), + decoder_tying=args.decoder_tying, + per_target_scale=args.per_target_scale, + per_target_bias=args.per_target_bias, + enable_feature_offset=args.enable_feature_offset, + enable_feature_scale=args.enable_feature_scale, + skip_connection=args.skip_connection, ) logger.info(f"CLT Config: {clt_config}") diff --git a/tutorials/1F-end-to-end-training-pythia-tied-decoders.py b/tutorials/1F-end-to-end-training-pythia-tied-decoders.py new file mode 100644 index 0000000..2b5609f --- /dev/null +++ b/tutorials/1F-end-to-end-training-pythia-tied-decoders.py @@ -0,0 +1,423 @@ +# %% [markdown] +# # Tutorial: End-to-End CLT Training with Tied Decoders and Feature Offset +# +# This tutorial demonstrates training a Cross-Layer Transcoder (CLT) using: +# - **Tied decoder architecture** to reduce memory usage +# - **Feature offset parameters** for per-feature bias +# - **BatchTopK activation** (same as Tutorial 1B) +# +# The tied decoder architecture uses one decoder per source layer (instead of one per source-target pair), +# significantly reducing memory usage from O(L²) to O(L) decoder 
parameters. +# +# We will: +# 1. Configure the CLT model with tied decoders and feature offset +# 2. Use the same pre-generated activations from Tutorial 1B +# 3. Train the model and compare memory usage +# 4. Demonstrate loading checkpoints with the new architecture + +# %% [markdown] +# ## 1. Imports and Setup + +# %% +import torch +import os +import time +import sys +import traceback +import json +from torch.distributed.checkpoint import load_state_dict as dist_load_state_dict +from torch.distributed.checkpoint.filesystem import FileSystemReader +from typing import Optional, Dict +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s") + +# Ensure tokenizers don't use parallelism +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Add project root to path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +try: + from clt.config import CLTConfig, TrainingConfig, ActivationConfig + from clt.activation_generation.generator import ActivationGenerator + from clt.training.trainer import CLTTrainer + from clt.models.clt import CrossLayerTranscoder + from clt.training.data import BaseActivationStore +except ImportError as e: + print(f"ImportError: {e}") + print("Please ensure the 'clt' library is installed or the clt directory is in your PYTHONPATH.") + raise + +# Device setup +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + +print(f"Using device: {device}") + +# Base model for activation extraction (same as Tutorial 1B) +BASE_MODEL_NAME = "EleutherAI/pythia-70m" + +# %% [markdown] +# ## 2. 
Configuration with Tied Decoders +# +# Key differences from Tutorial 1B: +# - `decoder_tying="per_source"` - Enables tied decoder architecture +# - `enable_feature_offset=True` - Adds learnable per-feature bias +# - Memory savings: For 6 layers, we go from 21 decoders to just 6 + +# %% +# --- CLT Architecture Configuration with Tied Decoders --- +num_layers = 6 +d_model = 512 +expansion_factor = 32 +clt_num_features = d_model * expansion_factor + +batchtopk_k = 200 + +clt_config = CLTConfig( + num_features=clt_num_features, + num_layers=num_layers, + d_model=d_model, + activation_fn="batchtopk", + batchtopk_k=batchtopk_k, + batchtopk_straight_through=True, + # NEW: Tied decoder configuration + decoder_tying="per_source", # Use one decoder per source layer + enable_feature_offset=True, # Enable per-feature bias (feature_offset) + enable_feature_scale=True, # Enable per-feature scale (feature_scale) + per_target_scale=True, # Enable learned per-target scale for each src->tgt path + per_target_bias=True, + skip_connection=True, # Enable skip connection from input to output +) + +print("CLT Configuration (Tied Decoders with Feature Affine):") +print(f"- decoder_tying: {clt_config.decoder_tying}") +print(f"- enable_feature_offset: {clt_config.enable_feature_offset}") +print(f"- enable_feature_scale: {clt_config.enable_feature_scale}") +print(f"- skip_connection: {clt_config.skip_connection}") +print(f"- Number of features: {clt_config.num_features}") +print(f"- Number of layers: {clt_config.num_layers}") +print(f"- Activation function: {clt_config.activation_fn}") +print(f"- BatchTopK k: {clt_config.batchtopk_k}") + +# Calculate memory savings +untied_decoders = sum(range(1, num_layers + 1)) # 6 + 5 + 4 + 3 + 2 + 1 = 21 +tied_decoders = num_layers # 6 +print(f"\nMemory savings:") +print(f"- Untied decoders: {untied_decoders} decoder matrices") +print(f"- Tied decoders: {tied_decoders} decoder matrices") +print(f"- Reduction: {(1 - tied_decoders/untied_decoders)*100:.1f}%") + +# --- Use 
existing activations from Tutorial 1B --- +# We'll use the same activation directory as Tutorial 1B since the base model +# and dataset are identical - only the CLT architecture differs +activation_dir = "./tutorial_activations_local_1M_pythia" +dataset_name = "monology/pile-uncopyrighted" + +expected_activation_path = os.path.join( + activation_dir, + BASE_MODEL_NAME, + f"{os.path.basename(dataset_name)}_train", +) + +# Verify activations exist +metadata_path = os.path.join(expected_activation_path, "metadata.json") +manifest_path = os.path.join(expected_activation_path, "index.bin") + +if not (os.path.exists(metadata_path) and os.path.exists(manifest_path)): + print(f"\nERROR: Activations not found at {expected_activation_path}") + print("Please run Tutorial 1B first to generate the activations.") + raise FileNotFoundError("Activation dataset not found") +else: + print(f"\nUsing existing activations from: {expected_activation_path}") + +# --- Training Configuration --- +_lr = 1e-4 +_batch_size = 1024 + +# WandB run name includes tied decoder info +wdb_run_name = ( + f"{clt_config.num_features}-width-" + f"tied-decoders-" # Indicate tied decoder architecture + f"feat-offset-" # Indicate feature offset is enabled + f"batchtopk-k{batchtopk_k}-" + f"{_batch_size}-batch-" + f"{_lr:.1e}-lr" +) +print(f"\nGenerated WandB run name: {wdb_run_name}") + +training_config = TrainingConfig( + # Training loop parameters + learning_rate=_lr, + training_steps=1000, # Same as Tutorial 1B for comparison + seed=42, + # Activation source (using existing activations) + activation_source="local_manifest", + activation_path=expected_activation_path, + activation_dtype="float32", + # Training batch size + train_batch_size_tokens=_batch_size, + sampling_strategy="sequential", + # Normalization + normalization_method="auto", + # Loss function coefficients (same as Tutorial 1B) + sparsity_lambda=0.0, + sparsity_lambda_schedule="linear", + sparsity_c=0.0, + preactivation_coef=0, + 
aux_loss_factor=1 / 32, + apply_sparsity_penalty_to_batchtopk=False, + # Optimizer & Scheduler + optimizer="adamw", + lr_scheduler="linear_final20", + optimizer_beta2=0.98, + # Logging & Checkpointing + log_interval=10, + eval_interval=50, + diag_every_n_eval_steps=1, + max_features_for_diag_hist=1000, + checkpoint_interval=500, + dead_feature_window=200, + # WandB + enable_wandb=True, + wandb_project="clt-hp-sweeps-pythia-70m", + wandb_run_name=wdb_run_name, +) + +print("\nTraining Configuration:") +print(f"- Learning rate: {training_config.learning_rate}") +print(f"- Training steps: {training_config.training_steps}") +print(f"- Batch size (tokens): {training_config.train_batch_size_tokens}") + +# %% [markdown] +# ## 3. Initialize Model and Check Architecture +# +# Let's create the model and verify the tied decoder architecture is set up correctly. + +# %% +print("\nInitializing CLT model with tied decoders...") + +# Create model instance to inspect architecture +model = CrossLayerTranscoder( + config=clt_config, + process_group=None, + device=torch.device(device), +) + +print("\nModel architecture inspection:") +print(f"- Encoder modules: {len(model.encoder_module.encoders)}") +print(f"- Decoder modules: {len(model.decoder_module.decoders)}") + +# Check feature offset parameters +if model.encoder_module.feature_offset is not None: + print(f"- Feature offset parameters per layer: {len(model.encoder_module.feature_offset)}") + print(f"- Feature offset shape (layer 0): {model.encoder_module.feature_offset[0].shape}") +else: + print("- Feature offset: Not enabled") + +# Count total parameters +total_params = sum(p.numel() for p in model.parameters()) +encoder_params = sum(p.numel() for p in model.encoder_module.parameters()) +decoder_params = sum(p.numel() for p in model.decoder_module.parameters()) +print(f"\nParameter counts:") +print(f"- Total parameters: {total_params:,}") +print(f"- Encoder parameters: {encoder_params:,}") +print(f"- Decoder parameters: 
{decoder_params:,}") + +# Compare with untied architecture (approximate) +untied_decoder_params_approx = decoder_params * (untied_decoders / tied_decoders) +print(f"\nEstimated decoder parameters if untied: {untied_decoder_params_approx:,}") +print(f"Memory savings in decoder: {(1 - decoder_params/untied_decoder_params_approx)*100:.1f}%") + +# Clean up the test model +del model + +# %% [markdown] +# ## 4. Training the CLT with Tied Decoders + +# %% +print("\nInitializing CLTTrainer for training with tied decoders...") + +log_dir = f"clt_training_logs/clt_pythia_tied_decoders_{int(time.time())}" +os.makedirs(log_dir, exist_ok=True) +print(f"Logs and checkpoints will be saved to: {log_dir}") + +try: + print("\nCreating CLTTrainer instance...") + trainer = CLTTrainer( + clt_config=clt_config, + training_config=training_config, + log_dir=log_dir, + device=device, + distributed=False, + ) + print("CLTTrainer instance created successfully.") +except Exception as e: + print(f"[ERROR] Failed to initialize CLTTrainer: {e}") + traceback.print_exc() + raise + +# Start training +print("\nBeginning training with tied decoders...") +print(f"Training for {training_config.training_steps} steps.") +print(f"Decoder tying: {clt_config.decoder_tying}") +print(f"Feature offset enabled: {clt_config.enable_feature_offset}") + +try: + start_train_time = time.time() + trained_clt_model = trainer.train(eval_every=training_config.eval_interval) + end_train_time = time.time() + print(f"\nTraining finished in {end_train_time - start_train_time:.2f} seconds.") +except Exception as train_err: + print(f"[ERROR] Training failed: {train_err}") + traceback.print_exc() + raise + +# %% [markdown] +# ## 5. 
Saving and Loading the Tied Decoder Model + +# %% +# Save the final model state and config +final_model_state_path = os.path.join(log_dir, "clt_tied_final_state.pt") +final_model_config_path = os.path.join(log_dir, "clt_tied_final_config.json") + +print(f"\nSaving final model state to: {final_model_state_path}") +print(f"Saving final model config to: {final_model_config_path}") + +torch.save(trained_clt_model.state_dict(), final_model_state_path) +with open(final_model_config_path, "w") as f: + json.dump(trained_clt_model.config.__dict__, f, indent=4) + +# Verify the saved config has tied decoder settings +with open(final_model_config_path, "r") as f: + saved_config = json.load(f) + print(f"\nSaved config verification:") + print(f"- decoder_tying: {saved_config['decoder_tying']}") + print(f"- enable_feature_offset: {saved_config['enable_feature_offset']}") + print(f"- activation_fn: {saved_config['activation_fn']} (converted from batchtopk)") + +# Load the model back +print("\nLoading the saved tied decoder model...") +loaded_config = CLTConfig(**saved_config) +loaded_model = CrossLayerTranscoder( + config=loaded_config, + process_group=None, + device=torch.device(device), +) +loaded_model.load_state_dict(torch.load(final_model_state_path, map_location=device)) +loaded_model.eval() + +print("Model loaded successfully.") +print(f"Loaded model decoder count: {len(loaded_model.decoder_module.decoders)}") + +# %% [markdown] +# ## 6. Backward Compatibility Test +# +# Test loading an old untied checkpoint into our tied decoder model. +# This demonstrates the backward compatibility feature. 
+ +# %% +print("\n=== Testing Backward Compatibility ===") + +# Create a simple untied model for testing +untied_config = CLTConfig( + num_features=clt_config.num_features, + num_layers=clt_config.num_layers, + d_model=clt_config.d_model, + activation_fn="relu", # Simple activation for testing + decoder_tying="none", # Untied decoders +) + +print("Creating untied model for compatibility test...") +untied_model = CrossLayerTranscoder( + config=untied_config, + process_group=None, + device=torch.device("cpu"), # Use CPU for this test +) + +# Save untied model state +untied_state_dict = untied_model.state_dict() +print(f"Untied model decoder keys (first 5): {list(k for k in untied_state_dict.keys() if 'decoder' in k)[:5]}") + +# Create tied model with same dimensions +tied_test_config = CLTConfig( + num_features=clt_config.num_features, + num_layers=clt_config.num_layers, + d_model=clt_config.d_model, + activation_fn="relu", + decoder_tying="per_source", # Tied decoders + enable_feature_offset=True, # This will be initialized to defaults +) + +tied_test_model = CrossLayerTranscoder( + config=tied_test_config, + process_group=None, + device=torch.device("cpu"), +) + +print("\nLoading untied checkpoint into tied model...") +try: + # This should work due to our custom load_state_dict + tied_test_model.load_state_dict(untied_state_dict, strict=False) + print("✓ Successfully loaded untied checkpoint into tied model!") + print(" The tied model uses diagonal decoder weights from the untied model.") +except Exception as e: + print(f"✗ Failed to load: {e}") + +# Clean up test models +del untied_model, tied_test_model + +# %% [markdown] +# ## 7. 
Performance Comparison Summary + +# %% +print("\n=== Tied Decoder Architecture Summary ===") +print(f"\nConfiguration used:") +print(f"- Model: {BASE_MODEL_NAME}") +print(f"- Layers: {num_layers}") +print(f"- Hidden dimension: {d_model}") +print(f"- Features per layer: {clt_num_features}") +print(f"- Decoder tying: {clt_config.decoder_tying}") +print(f"- Feature offset: {clt_config.enable_feature_offset}") + +print(f"\nMemory efficiency:") +print(f"- Traditional CLT: {untied_decoders} decoder matrices") +print(f"- Tied decoder CLT: {tied_decoders} decoder matrices") +print(f"- Memory reduction: ~{(1 - tied_decoders/untied_decoders)*100:.0f}%") + +print(f"\nKey benefits:") +print(f"1. Significant memory savings for decoder parameters") +print(f"2. Simpler feature interpretability (one decoder per source)") +print(f"3. Feature offset allows per-feature adaptation") +print(f"4. Backward compatible with existing checkpoints") + +print(f"\nTrade-offs:") +print(f"1. Less flexibility in source-target specific adaptations") +print(f"2. May require careful tuning of feature offset parameters") + +# %% [markdown] +# ## 8. 
Next Steps +# +# This tutorial demonstrated: +# - Training a CLT with tied decoder architecture +# - Using feature offset parameters for per-feature bias +# - Significant memory savings compared to traditional CLT +# - Backward compatibility with untied checkpoints +# +# You can experiment with: +# - `per_target_scale` and `per_target_bias` for more flexibility +# - `enable_feature_scale` for per-feature scaling +# - Different values of `k` for BatchTopK +# - Comparing reconstruction quality between tied and untied architectures + +# %% +print(f"\n✓ Tied Decoder Tutorial Complete!") +print(f"Model and logs saved to: {log_dir}") From 015bc0dab23efafb9aadd194fe98d79bdeb8ea16 Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Mon, 30 Jun 2025 10:33:54 -0700 Subject: [PATCH 2/5] updated logging --- tests/unit/data/test_data_integrity.py | 2 +- tests/unit/models/test_tied_decoders.py | 285 ++++++++++++++++++++++++ 2 files changed, 286 insertions(+), 1 deletion(-) create mode 100644 tests/unit/models/test_tied_decoders.py diff --git a/tests/unit/data/test_data_integrity.py b/tests/unit/data/test_data_integrity.py index 9a5f7c9..d76893f 100644 --- a/tests/unit/data/test_data_integrity.py +++ b/tests/unit/data/test_data_integrity.py @@ -193,7 +193,7 @@ def test_normalization_application_correctness(self, tmp_path): store = LocalActivationStore( dataset_path=output_dir, train_batch_size_tokens=100, - normalization_method="standard", # Enable normalization + normalization_method="mean_std", # Enable normalization dtype="float32", device="cpu", ) diff --git a/tests/unit/models/test_tied_decoders.py b/tests/unit/models/test_tied_decoders.py new file mode 100644 index 0000000..be3d6c4 --- /dev/null +++ b/tests/unit/models/test_tied_decoders.py @@ -0,0 +1,285 @@ +"""Unit tests for tied decoder functionality in CLT models.""" + +import pytest +import torch +import torch.nn as nn +from typing import Dict + +from clt.config import CLTConfig +from clt.models.clt import 
CrossLayerTranscoder +from clt.models.decoder import Decoder +from clt.models.encoder import Encoder + + +class TestTiedDecoders: + """Test suite for tied decoder architecture.""" + + @pytest.fixture + def base_config(self): + """Base CLT configuration for testing.""" + return CLTConfig( + num_features=128, + num_layers=4, + d_model=64, + activation_fn="relu", + decoder_tying="none", # Default untied + ) + + @pytest.fixture + def tied_config(self): + """CLT configuration with tied decoders.""" + return CLTConfig( + num_features=128, + num_layers=4, + d_model=64, + activation_fn="relu", + decoder_tying="per_source", + ) + + def test_decoder_initialization_untied(self, base_config): + """Test that untied decoder creates correct number of decoder modules.""" + decoder = Decoder( + config=base_config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Should have decoders for each (src, tgt) pair where src <= tgt + # For 4 layers: 0->0, 0->1, 0->2, 0->3, 1->1, 1->2, 1->3, 2->2, 2->3, 3->3 + # Total: 4 + 3 + 2 + 1 = 10 + expected_decoder_count = sum(range(1, base_config.num_layers + 1)) + assert len(decoder.decoders) == expected_decoder_count + + # Check that all expected keys exist + for src in range(base_config.num_layers): + for tgt in range(src, base_config.num_layers): + assert f"{src}->{tgt}" in decoder.decoders + + def test_decoder_initialization_tied(self, tied_config): + """Test that tied decoder creates one decoder per source layer.""" + decoder = Decoder( + config=tied_config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Should have one decoder per source layer + assert len(decoder.decoders) == tied_config.num_layers + + # Check that decoders are indexed by layer + for layer in range(tied_config.num_layers): + assert isinstance(decoder.decoders[layer], nn.Module) + + def test_per_target_parameters(self, tied_config): + """Test per-target scale and bias parameters.""" + # Test with 
per-target scale + config_with_scale = CLTConfig( + **{**tied_config.__dict__, "per_target_scale": True} + ) + decoder = Decoder( + config=config_with_scale, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + assert decoder.per_target_scale is not None + assert decoder.per_target_scale.shape == ( + config_with_scale.num_layers, + config_with_scale.num_layers, + config_with_scale.d_model, + ) + assert torch.allclose(decoder.per_target_scale, torch.ones_like(decoder.per_target_scale)) + + # Test with per-target bias + config_with_bias = CLTConfig( + **{**tied_config.__dict__, "per_target_bias": True} + ) + decoder = Decoder( + config=config_with_bias, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + assert decoder.per_target_bias is not None + assert decoder.per_target_bias.shape == ( + config_with_bias.num_layers, + config_with_bias.num_layers, + config_with_bias.d_model, + ) + assert torch.allclose(decoder.per_target_bias, torch.zeros_like(decoder.per_target_bias)) + + def test_feature_affine_parameters(self): + """Test feature offset and scale parameters in encoder.""" + config = CLTConfig( + num_features=128, + num_layers=4, + d_model=64, + activation_fn="relu", + enable_feature_offset=True, + enable_feature_scale=True, + ) + + encoder = Encoder( + config=config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Check feature_offset initialization + assert encoder.feature_offset is not None + assert len(encoder.feature_offset) == config.num_layers + for layer_offset in encoder.feature_offset: + assert layer_offset.shape == (config.num_features,) + assert torch.allclose(layer_offset, torch.zeros_like(layer_offset)) + + # Check feature_scale initialization + assert encoder.feature_scale is not None + assert len(encoder.feature_scale) == config.num_layers + for layer_scale in encoder.feature_scale: + assert layer_scale.shape == (config.num_features,) + assert 
torch.allclose(layer_scale, torch.ones_like(layer_scale)) + + def test_decode_with_tied_decoders(self, tied_config): + """Test decoding with tied decoders.""" + decoder = Decoder( + config=tied_config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Create test activations + batch_size = 8 + activations = { + 0: torch.randn(batch_size, tied_config.num_features), + 1: torch.randn(batch_size, tied_config.num_features), + } + + # Test reconstruction at layer 1 + reconstruction = decoder.decode(activations, layer_idx=1) + + assert reconstruction.shape == (batch_size, tied_config.d_model) + # With zero-initialized decoders (matching reference implementation), + # the output will be zeros initially + assert torch.allclose(reconstruction, torch.zeros_like(reconstruction)) + + # Verify that if we set non-zero weights, we get non-zero outputs + for decoder_module in decoder.decoders: + decoder_module.weight.data.fill_(0.1) + reconstruction2 = decoder.decode(activations, layer_idx=1) + assert not torch.allclose(reconstruction2, torch.zeros_like(reconstruction2)) + + def test_decoder_norms_tied(self, tied_config): + """Test decoder norm computation for tied decoders.""" + decoder = Decoder( + config=tied_config, + process_group=None, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + norms = decoder.get_decoder_norms() + + # Should have shape [num_layers, num_features] + assert norms.shape == (tied_config.num_layers, tied_config.num_features) + + # Norms should be positive + assert torch.all(norms >= 0) + + def test_feature_affine_transformation(self): + """Test feature affine transformation in forward pass.""" + config = CLTConfig( + num_features=128, + num_layers=2, + d_model=64, + activation_fn="relu", + enable_feature_offset=True, + enable_feature_scale=True, + ) + + model = CrossLayerTranscoder( + config=config, + process_group=None, + device=torch.device("cpu"), + ) + + # Create test inputs + batch_size = 4 + seq_len = 16 + 
inputs = { + 0: torch.randn(batch_size, seq_len, config.d_model), + 1: torch.randn(batch_size, seq_len, config.d_model), + } + + # Get activations + activations = model.get_feature_activations(inputs) + + # Apply affine transformation + transformed = model._apply_feature_affine(activations) + + # The transformation should preserve zeros + for layer_idx in transformed: + zero_mask = activations[layer_idx] == 0 + assert torch.all(transformed[layer_idx][zero_mask] == 0) + + def test_backward_compatibility_config(self): + """Test loading old config without new fields.""" + old_config_dict = { + "num_features": 128, + "num_layers": 4, + "d_model": 64, + "activation_fn": "relu", + # Missing: decoder_tying, per_target_scale, per_target_bias, + # enable_feature_offset, enable_feature_scale + } + + # Should not raise an error + config = CLTConfig(**old_config_dict) + + # Should have default values + assert config.decoder_tying == "none" + assert config.per_target_scale == False + assert config.per_target_bias == False + assert config.enable_feature_offset == False + assert config.enable_feature_scale == False + + def test_checkpoint_compatibility(self, base_config, tied_config): + """Test loading old untied checkpoint into tied model.""" + # Create untied model and save checkpoint + untied_model = CrossLayerTranscoder( + config=base_config, + process_group=None, + device=torch.device("cpu"), + ) + + # Get state dict from untied model + untied_state_dict = untied_model.state_dict() + + # Create tied model + tied_model = CrossLayerTranscoder( + config=tied_config, + process_group=None, + device=torch.device("cpu"), + ) + + # Should be able to load with custom logic + tied_model.load_state_dict(untied_state_dict, strict=False) + + # Tied model should have loaded the diagonal decoder weights + for src_layer in range(tied_config.num_layers): + tied_weight = tied_model.decoder_module.decoders[src_layer].weight + untied_key = 
f"decoder_module.decoders.{src_layer}->{src_layer}.weight" + if untied_key in untied_state_dict: + untied_weight = untied_state_dict[untied_key] + # Shapes might differ due to RowParallelLinear, so just check they're both tensors + assert isinstance(tied_weight, torch.Tensor) + assert isinstance(untied_weight, torch.Tensor) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file From 78e451d9f361bc04c50ca9b7e4c9c2857547a694 Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Mon, 30 Jun 2025 14:20:56 -0700 Subject: [PATCH 3/5] updated tied CLT implementation --- clt/config/clt_config.py | 35 ++- clt/models/clt.py | 98 +++----- clt/models/decoder.py | 216 +++++++++++++++--- clt/models/encoder.py | 24 +- .../data/manifest_activation_store.py | 34 ++- clt/training/evaluator.py | 17 +- clt/training/losses.py | 32 ++- clt/training/trainer.py | 17 +- clt/training/wandb_logger.py | 28 ++- scripts/convert_batchtopk_to_jumprelu.py | 28 ++- scripts/train_clt.py | 22 +- 11 files changed, 396 insertions(+), 155 deletions(-) diff --git a/clt/config/clt_config.py b/clt/config/clt_config.py index 3990c18..be08ab0 100644 --- a/clt/config/clt_config.py +++ b/clt/config/clt_config.py @@ -15,7 +15,7 @@ class CLTConfig: num_layers: int # Number of transformer layers d_model: int # Dimension of model's hidden state model_name: Optional[str] = None # Optional name for the underlying model - normalization_method: Literal["auto", "estimated_mean_std", "none"] = ( + normalization_method: Literal["none", "mean_std", "sqrt_d_model"] = ( "none" # How activations were normalized during training ) activation_fn: Literal["jumprelu", "relu", "batchtopk", "topk"] = "jumprelu" @@ -27,6 +27,8 @@ class CLTConfig: topk_k: Optional[float] = None # Number or fraction of features to keep per token for TopK. # If < 1, treated as fraction. If >= 1, treated as int count. topk_straight_through: bool = True # Whether to use straight-through estimator for TopK. 
+ # Top-K mode selection + topk_mode: Literal["global", "per_layer"] = "global" # How to apply top-k selection clt_dtype: Optional[str] = None # Optional dtype for the CLT model itself (e.g., "float16") expected_input_dtype: Optional[str] = None # Expected dtype of input activations mlp_input_template: Optional[str] = None # Module path template for MLP input activations @@ -36,7 +38,7 @@ class CLTConfig: # context_size: Optional[int] = None # Tied decoder configuration - decoder_tying: Literal["none", "per_source"] = "none" # Decoder weight sharing strategy + decoder_tying: Literal["none", "per_source", "per_target"] = "none" # Decoder weight sharing strategy per_target_scale: bool = False # Enable learned scale for each src->tgt path per_target_bias: bool = False # Enable learned bias for each src->tgt path enable_feature_offset: bool = False # Enable per-feature bias (feature_offset) @@ -48,7 +50,7 @@ def __post_init__(self): assert self.num_features > 0, "Number of features must be positive" assert self.num_layers > 0, "Number of layers must be positive" assert self.d_model > 0, "Model dimension must be positive" - valid_norm_methods = ["auto", "estimated_mean_std", "none"] + valid_norm_methods = ["none", "mean_std", "sqrt_d_model"] assert ( self.normalization_method in valid_norm_methods ), f"Invalid normalization_method: {self.normalization_method}. Must be one of {valid_norm_methods}" @@ -70,7 +72,7 @@ def __post_init__(self): raise ValueError("topk_k must be positive if specified.") # Validate decoder tying configuration - valid_decoder_tying = ["none", "per_source"] + valid_decoder_tying = ["none", "per_source", "per_target"] assert ( self.decoder_tying in valid_decoder_tying ), f"Invalid decoder_tying: {self.decoder_tying}. 
Must be one of {valid_decoder_tying}" @@ -99,6 +101,21 @@ def from_json(cls: Type[C], json_path: str) -> C: config_dict["enable_feature_offset"] = False if "enable_feature_scale" not in config_dict: config_dict["enable_feature_scale"] = False + + # Handle backwards compatibility for old normalization methods + if "normalization_method" in config_dict: + old_method = config_dict["normalization_method"] + # Map old values to new ones + if old_method in ["auto", "estimated_mean_std"]: + config_dict["normalization_method"] = "mean_std" + elif old_method in ["auto_sqrt_d_model", "estimated_mean_std_sqrt_d_model"]: + config_dict["normalization_method"] = "sqrt_d_model" + + # Handle old sqrt_d_model_normalize flag + if "sqrt_d_model_normalize" in config_dict: + sqrt_normalize = config_dict.pop("sqrt_d_model_normalize") + if sqrt_normalize: + config_dict["normalization_method"] = "sqrt_d_model" return cls(**config_dict) @@ -135,11 +152,11 @@ class TrainingConfig: debug_anomaly: bool = False # Normalization parameters - normalization_method: Literal["auto", "estimated_mean_std", "none"] = "auto" - # 'auto': Use pre-calculated from mapped store, or estimate for streaming store. - # 'estimated_mean_std': Always estimate for streaming store (ignored for mapped). - # 'none': Disable normalization. - normalization_estimation_batches: int = 50 # Batches for normalization estimation + normalization_method: Literal["none", "mean_std", "sqrt_d_model"] = "mean_std" + # 'none': No normalization. + # 'mean_std': Standard (x - mean) / std normalization using pre-calculated stats. + # 'sqrt_d_model': EleutherAI-style x * sqrt(d_model) normalization. 
+ normalization_estimation_batches: int = 50 # Batches for normalization estimation (if needed) # --- Activation Store Source --- # activation_source: Literal["local_manifest", "remote"] = "local_manifest" diff --git a/clt/models/clt.py b/clt/models/clt.py index d288161..61763e5 100644 --- a/clt/models/clt.py +++ b/clt/models/clt.py @@ -179,57 +179,6 @@ def encode(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: ) return torch.zeros((expected_batch_dim, self.config.num_features), device=self.device, dtype=self.dtype) - def _apply_feature_affine(self, activations: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]: - """Apply per-feature offset and scale to activations if enabled. - - This function applies feature_offset and feature_scale only to non-zero activations. - This matches the reference implementation which applies post_enc only to selected features. - - Args: - activations: Dictionary mapping layer indices to activation tensors - - Returns: - Modified activations dictionary with affine transformations applied - """ - if not self.config.enable_feature_offset and not self.config.enable_feature_scale: - return activations - - transformed_activations = {} - - for layer_idx, acts in activations.items(): - if acts.numel() == 0: - transformed_activations[layer_idx] = acts - continue - - # Get non-zero positions (selected features) - nonzero_mask = acts != 0 - - if not nonzero_mask.any(): - transformed_activations[layer_idx] = acts - continue - - # Work with a copy to avoid in-place operations - transformed_acts = acts.clone() - - if acts.dim() == 2: # [batch, features] - # Get indices of non-zero elements - batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True) - - if self.config.enable_feature_offset and self.encoder_module.feature_offset is not None: - # Apply offset only to selected features - offset_values = self.encoder_module.feature_offset[layer_idx][feature_indices] - transformed_acts[batch_indices, feature_indices] += offset_values 
- - if self.config.enable_feature_scale and self.encoder_module.feature_scale is not None: - # Apply scale only to selected features - scale_values = self.encoder_module.feature_scale[layer_idx][feature_indices] - transformed_acts[batch_indices, feature_indices] *= scale_values - else: - raise ValueError(f"Unexpected activation dimension: {acts.dim()}") - - transformed_activations[layer_idx] = transformed_acts - - return transformed_activations def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: return self.decoder_module.decode(a, layer_idx) @@ -256,8 +205,8 @@ def _apply_skip_connection(self, input_tensor: torch.Tensor, layer_idx: int) -> input_2d = input_tensor # Apply skip connection weight - if self.config.decoder_tying == "per_source": - # Use skip weight for this target layer + if self.config.decoder_tying in ["per_source", "per_target"]: + # For tied decoders, use skip weight for this target layer skip_weight = self.decoder_module.skip_weights[layer_idx] else: # For untied, we need to sum contributions from all source layers @@ -280,8 +229,7 @@ def _apply_skip_connection(self, input_tensor: torch.Tensor, layer_idx: int) -> def forward(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]: activations = self.get_feature_activations(inputs) - # Apply feature affine transformation if enabled - activations = self._apply_feature_affine(activations) + # Note: feature affine transformations are now applied in the decoder reconstructions = {} for layer_idx in range(self.config.num_layers): @@ -321,6 +269,17 @@ def get_feature_activations(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, processed_inputs[layer_idx] = x_orig.to(device=self.device, dtype=self.dtype) if self.config.activation_fn == "batchtopk" or self.config.activation_fn == "topk": + # Check if we should use per-layer mode + if self.config.topk_mode == "per_layer": + # Use per-layer top-k by calling encode on each layer + activations = {} + for layer_idx in 
sorted(processed_inputs.keys()): + x_input = processed_inputs[layer_idx] + act = self.encode(x_input, layer_idx) + activations[layer_idx] = act + return activations + + # Otherwise use global top-k preactivations_dict, _ = self._encode_all_layers(processed_inputs) if not preactivations_dict: activations = {} @@ -471,26 +430,41 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = Tr new_state_dict[key] = value state_dict = new_state_dict - # Handle missing feature affine parameters - if self.config.enable_feature_offset and self.encoder_module.feature_offset is not None: + # Handle feature affine parameters migration from encoder to decoder module + # (for backward compatibility with old checkpoints) + for i in range(self.config.num_layers): + old_offset_key = f"encoder_module.feature_offset.{i}" + new_offset_key = f"decoder_module.feature_offset.{i}" + if old_offset_key in state_dict and new_offset_key not in state_dict: + logger.info(f"Migrating {old_offset_key} to {new_offset_key}") + state_dict[new_offset_key] = state_dict.pop(old_offset_key) + + old_scale_key = f"encoder_module.feature_scale.{i}" + new_scale_key = f"decoder_module.feature_scale.{i}" + if old_scale_key in state_dict and new_scale_key not in state_dict: + logger.info(f"Migrating {old_scale_key} to {new_scale_key}") + state_dict[new_scale_key] = state_dict.pop(old_scale_key) + + # Handle missing feature affine parameters (now in decoder module) + if self.config.enable_feature_offset and hasattr(self.decoder_module, 'feature_offset') and self.decoder_module.feature_offset is not None: for i in range(self.config.num_layers): - key = f"encoder_module.feature_offset.{i}" + key = f"decoder_module.feature_offset.{i}" if key not in state_dict: logger.info(f"Initializing missing {key} to zeros") # Don't add to state_dict to let it be initialized by the module - if self.config.enable_feature_scale and self.encoder_module.feature_scale is not None: + if 
self.config.enable_feature_scale and hasattr(self.decoder_module, 'feature_scale') and self.decoder_module.feature_scale is not None:
             for i in range(self.config.num_layers):
-                key = f"encoder_module.feature_scale.{i}"
+                key = f"decoder_module.feature_scale.{i}"
                 if key not in state_dict:
-                    logger.info(f"Initializing missing {key} to ones")
+                    logger.info(f"Initializing missing {key} (first target layer to ones, rest to 0.1)")
                     # Don't add to state_dict to let it be initialized by the module
 
         # Handle missing per-target parameters
         if self.config.per_target_scale and hasattr(self.decoder_module, 'per_target_scale'):
             key = "decoder_module.per_target_scale"
             if key not in state_dict:
-                logger.info(f"Initializing missing {key} to ones")
+                logger.info(f"Initializing missing {key} (diagonal to ones, off-diagonal to 0.1)")
 
         if self.config.per_target_bias and hasattr(self.decoder_module, 'per_target_bias'):
             key = "decoder_module.per_target_bias"
diff --git a/clt/models/decoder.py b/clt/models/decoder.py
index 13db8b8..1a1d0cd 100644
--- a/clt/models/decoder.py
+++ b/clt/models/decoder.py
@@ -55,20 +55,42 @@ def __init__(
             )
             for _ in range(self.config.num_layers)
         ])
-
-            # Initialize decoder weights to zeros for tied transcoders
-            # This matches the reference implementation
+        elif config.decoder_tying == "per_target":
+            # Tied decoders: one decoder per target layer (EleutherAI style)
+            self.decoders = nn.ModuleList([
+                RowParallelLinear(
+                    in_features=self.config.num_features,
+                    out_features=self.config.d_model,
+                    bias=True,
+                    process_group=self.process_group,
+                    input_is_parallel=False,
+                    d_model_for_init=self.config.d_model,
+                    num_layers_for_init=self.config.num_layers,
+                    device=self.device,
+                    dtype=self.dtype,
+                )
+                for _ in range(self.config.num_layers)
+            ])
+
+        # Initialize decoder weights to zeros for tied decoders (both per_source and per_target)
+        if config.decoder_tying in ["per_source", "per_target"]:
             for decoder in self.decoders:
                 nn.init.zeros_(decoder.weight)
-                if 
decoder.bias_param is not None: + if hasattr(decoder, 'bias_param') and decoder.bias_param is not None: nn.init.zeros_(decoder.bias_param) + elif hasattr(decoder, 'bias') and decoder.bias is not None: + nn.init.zeros_(decoder.bias) # Initialize per-target scale and bias if enabled if config.per_target_scale: - self.per_target_scale = nn.Parameter( - torch.ones(self.config.num_layers, self.config.num_layers, self.config.d_model, - device=self.device, dtype=self.dtype) - ) + # Initialize scale: diagonal gets ones, off-diagonal gets small values for gradient flow + # Small non-zero values allow gradients to flow even without skip connections + scale_init = torch.full((self.config.num_layers, self.config.num_layers, self.config.d_model), + 0.1, device=self.device, dtype=self.dtype) + # Set diagonal (same src->tgt layer) scales to 1.0 + for i in range(self.config.num_layers): + scale_init[i, i, :] = 1.0 + self.per_target_scale = nn.Parameter(scale_init) else: self.per_target_scale = None @@ -103,7 +125,7 @@ def __init__( # Initialize skip connection weights if enabled if config.skip_connection: - if config.decoder_tying == "per_source": + if config.decoder_tying in ["per_source", "per_target"]: # For tied decoders, one skip connection per target layer self.skip_weights = nn.ParameterList([ nn.Parameter(torch.zeros(self.config.d_model, self.config.d_model, @@ -123,6 +145,38 @@ def __init__( else: self.skip_weights = None + # Initialize feature_offset and feature_scale (indexed by target layer) + # These match EleutherAI's post_enc and post_enc_scale + # Note: Currently only implemented for tied decoders to match EleutherAI + # For per_source tying, these would need to be indexed differently + if config.decoder_tying in ["per_source", "per_target"]: + features_per_rank = config.num_features // self.world_size if self.world_size > 1 else config.num_features + + if config.enable_feature_offset: + # Initialize feature_offset for each target layer + self.feature_offset = 
nn.ParameterList([ + nn.Parameter(torch.zeros(features_per_rank, device=self.device, dtype=self.dtype)) + for _ in range(config.num_layers) + ]) + else: + self.feature_offset = None + + if config.enable_feature_scale: + # Initialize feature_scale for each target layer + # First target layer gets ones, rest get small non-zero values to allow gradient flow + self.feature_scale = nn.ParameterList([ + nn.Parameter( + torch.ones(features_per_rank, device=self.device, dtype=self.dtype) if i == 0 + else torch.full((features_per_rank,), 0.1, device=self.device, dtype=self.dtype) + ) + for i in range(config.num_layers) + ]) + else: + self.feature_scale = None + else: + self.feature_offset = None + self.feature_scale = None + self.register_buffer("_cached_decoder_norms", None, persistent=False) def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: @@ -155,34 +209,99 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: reconstruction = torch.zeros((batch_dim_size, self.config.d_model), device=self.device, dtype=self.dtype) - for src_layer in range(layer_idx + 1): - if src_layer in a: - activation_tensor = a[src_layer].to(device=self.device, dtype=self.dtype) - - if activation_tensor.numel() == 0: - continue - if activation_tensor.shape[-1] != self.config.num_features: - logger.warning( - f"Rank {self.rank}: Activation tensor for layer {src_layer} has incorrect feature dimension {activation_tensor.shape[-1]}, expected {self.config.num_features}. Skipping decode contribution." 
- ) - continue + if self.config.decoder_tying == "per_target": + # EleutherAI style: sum activations first, then decode once + summed_activation = torch.zeros((batch_dim_size, self.config.num_features), device=self.device, dtype=self.dtype) + + for src_layer in range(layer_idx + 1): + if src_layer in a: + activation_tensor = a[src_layer].to(device=self.device, dtype=self.dtype) - if self.config.decoder_tying == "per_source": - # Use tied decoder for the source layer - decoder = self.decoders[src_layer] - decoded = decoder(activation_tensor) + if activation_tensor.numel() == 0: + continue + if activation_tensor.shape[-1] != self.config.num_features: + logger.warning( + f"Rank {self.rank}: Activation tensor for layer {src_layer} has incorrect feature dimension {activation_tensor.shape[-1]}, expected {self.config.num_features}. Skipping decode contribution." + ) + continue + + # Apply feature affine transformations (indexed by target layer) + # Note: EleutherAI applies these to ALL selected features, not just non-zero + if self.feature_offset is not None or self.feature_scale is not None: + activation_tensor = activation_tensor.clone() + + if self.feature_offset is not None: + # Apply offset to all features (not just non-zero) + activation_tensor += self.feature_offset[layer_idx] + + if self.feature_scale is not None: + # Apply scale to all features (not just non-zero) + activation_tensor *= self.feature_scale[layer_idx] - # Apply per-target scale and bias if enabled + # Apply per-target scale and bias if enabled (before summing) + # Note: EleutherAI doesn't have these parameters if self.per_target_scale is not None: - decoded = decoded * self.per_target_scale[src_layer, layer_idx] + activation_tensor = activation_tensor * self.per_target_scale[src_layer, layer_idx] if self.per_target_bias is not None: - decoded = decoded + self.per_target_bias[src_layer, layer_idx] - else: - # Use untied decoder for (src, tgt) pair - decoder = 
self.decoders[f"{src_layer}->{layer_idx}"] - decoded = decoder(activation_tensor) + activation_tensor = activation_tensor + self.per_target_bias[src_layer, layer_idx] + + summed_activation += activation_tensor + + # Now decode ONCE with the summed activation + decoder = self.decoders[layer_idx] + reconstruction = decoder(summed_activation) + + else: + # Original logic for per_source and untied decoders + for src_layer in range(layer_idx + 1): + if src_layer in a: + activation_tensor = a[src_layer].to(device=self.device, dtype=self.dtype) + + if activation_tensor.numel() == 0: + continue + if activation_tensor.shape[-1] != self.config.num_features: + logger.warning( + f"Rank {self.rank}: Activation tensor for layer {src_layer} has incorrect feature dimension {activation_tensor.shape[-1]}, expected {self.config.num_features}. Skipping decode contribution." + ) + continue + + # Apply feature affine transformations for per_source + if self.config.decoder_tying == "per_source": + # Get non-zero positions (selected features) + nonzero_mask = activation_tensor != 0 + + if nonzero_mask.any(): + # Apply transformations only to selected features + activation_tensor = activation_tensor.clone() + batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True) + + if self.feature_offset is not None: + # Apply offset indexed by target layer + offset_values = self.feature_offset[layer_idx][feature_indices] + activation_tensor[batch_indices, feature_indices] += offset_values + + if self.feature_scale is not None: + # Apply scale indexed by target layer + scale_values = self.feature_scale[layer_idx][feature_indices] + activation_tensor[batch_indices, feature_indices] *= scale_values + + if self.config.decoder_tying == "per_source": + # Use tied decoder for the source layer + decoder = self.decoders[src_layer] + decoded = decoder(activation_tensor) + + # Apply per-target scale and bias if enabled + if self.per_target_scale is not None: + decoded = decoded * 
self.per_target_scale[src_layer, layer_idx] + if self.per_target_bias is not None: + decoded = decoded + self.per_target_bias[src_layer, layer_idx] + else: + # Use untied decoder for (src, tgt) pair + decoder = self.decoders[f"{src_layer}->{layer_idx}"] + decoded = decoder(activation_tensor) + + reconstruction += decoded - reconstruction += decoded return reconstruction def get_decoder_norms(self) -> torch.Tensor: @@ -255,6 +374,39 @@ def get_decoder_norms(self) -> torch.Tensor: f"Rank {self.rank}: Shape mismatch in decoder norm calculation for decoder {src_layer}. " f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}." ) + elif self.config.decoder_tying == "per_target": + # For per_target tying, each decoder corresponds to a target layer + # We accumulate decoder norms from all target layers >= src_layer + for tgt_layer in range(src_layer, self.config.num_layers): + decoder = self.decoders[tgt_layer] + assert isinstance(decoder, RowParallelLinear), f"Decoder {tgt_layer} is not RowParallelLinear" + + current_norms_sq = torch.norm(decoder.weight, dim=0).pow(2).to(torch.float32) + + full_dim = decoder.full_in_features + features_per_rank = (full_dim + self.world_size - 1) // self.world_size + start_idx = self.rank * features_per_rank + end_idx = min(start_idx + features_per_rank, full_dim) + actual_local_dim = max(0, end_idx - start_idx) + local_dim_padded = decoder.local_in_features + + if local_dim_padded != features_per_rank and self.rank == self.world_size - 1: + pass + elif local_dim_padded != actual_local_dim and local_dim_padded != features_per_rank: + logger.warning( + f"Rank {self.rank}: Padded local dim ({local_dim_padded}) doesn't match calculated actual local dim ({actual_local_dim}) or features_per_rank ({features_per_rank}) for decoder {tgt_layer}. This might indicate an issue with RowParallelLinear partitioning." 
+ ) + + if actual_local_dim > 0: + valid_norms_sq = current_norms_sq[:actual_local_dim] + if valid_norms_sq.shape[0] == actual_local_dim: + global_slice = slice(start_idx, end_idx) + local_norms_sq_accum[global_slice] += valid_norms_sq + else: + logger.warning( + f"Rank {self.rank}: Shape mismatch in decoder norm calculation for decoder {tgt_layer}. " + f"Valid norms shape {valid_norms_sq.shape}, expected size {actual_local_dim}." + ) else: # For untied decoders, accumulate norms from all target layers for tgt_layer in range(src_layer, self.config.num_layers): diff --git a/clt/models/encoder.py b/clt/models/encoder.py index 73d6e21..07b30c2 100644 --- a/clt/models/encoder.py +++ b/clt/models/encoder.py @@ -48,28 +48,8 @@ def __init__( ] ) - # Initialize theta_bias and theta_scale parameters if enabled - # These are per-layer, per-feature parameters - # Note: For tensor parallelism, each rank only holds a shard of features - features_per_rank = config.num_features // self.world_size - - if config.enable_feature_offset: - # Initialize feature_offset for each layer - self.feature_offset = nn.ParameterList([ - nn.Parameter(torch.zeros(features_per_rank, device=self.device, dtype=self.dtype)) - for _ in range(config.num_layers) - ]) - else: - self.feature_offset = None - - if config.enable_feature_scale: - # Initialize feature_scale for each layer - self.feature_scale = nn.ParameterList([ - nn.Parameter(torch.ones(features_per_rank, device=self.device, dtype=self.dtype)) - for _ in range(config.num_layers) - ]) - else: - self.feature_scale = None + # Note: feature_offset and feature_scale have been moved to Decoder module + # to match EleutherAI's architecture where they are indexed by target layer def get_preactivations(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: """Get pre-activation values (full tensor) for features at the specified layer.""" diff --git a/clt/training/data/manifest_activation_store.py b/clt/training/data/manifest_activation_store.py index 
da13cbb..b441362 100644 --- a/clt/training/data/manifest_activation_store.py +++ b/clt/training/data/manifest_activation_store.py @@ -5,6 +5,7 @@ from collections import defaultdict import threading import queue +import math # import json # Unused from abc import ABC, abstractmethod @@ -367,6 +368,7 @@ def __init__( self.epoch = 0 self.prefetch_batches = max(1, prefetch_batches) self.sampling_strategy = sampling_strategy + self.normalization_method = normalization_method self.shard_data = shard_data # Device setup @@ -483,8 +485,12 @@ def __init__( self.apply_normalization = False if normalization_method == "none": self.apply_normalization = False - else: + elif normalization_method == "mean_std": + # mean_std requires normalization stats self.apply_normalization = bool(self.norm_stats_data) + elif normalization_method == "sqrt_d_model": + # sqrt_d_model doesn't need norm stats, just applies scaling + self.apply_normalization = True if self.apply_normalization: self._prep_norm() @@ -556,9 +562,14 @@ def _prep_norm(self): self.mean_tg: Dict[int, torch.Tensor] = {} self.std_tg: Dict[int, torch.Tensor] = {} - if not self.norm_stats_data: - logger.warning("Normalization prep called but no stats data loaded.") - self.apply_normalization = False + # Only need to load stats for mean_std normalization + if self.normalization_method == "mean_std": + if not self.norm_stats_data: + logger.warning("mean_std normalization requested but no stats data loaded.") + self.apply_normalization = False + return + elif self.normalization_method == "sqrt_d_model": + # sqrt_d_model doesn't need stats, just return return missing_layers = set(self.layer_indices) @@ -901,10 +912,17 @@ def _fetch_and_parse_batch(self, idxs: np.ndarray) -> ActivationBatch: log_stats_this_batch["target_mean_in"] = self.mean_in[li].mean().item() log_stats_this_batch["target_std_in"] = self.std_in[li].mean().item() - if li in self.mean_in and li in self.std_in: - inputs_li = (inputs_li - self.mean_in[li]) / 
self.std_in[li] - if li in self.mean_tg and li in self.std_tg: - targets_li = (targets_li - self.mean_tg[li]) / self.std_tg[li] + if self.normalization_method == "mean_std": + # Standard normalization: (x - mean) / std + if li in self.mean_in and li in self.std_in: + inputs_li = (inputs_li - self.mean_in[li]) / self.std_in[li] + if li in self.mean_tg and li in self.std_tg: + targets_li = (targets_li - self.mean_tg[li]) / self.std_tg[li] + elif self.normalization_method == "sqrt_d_model": + # EleutherAI-style normalization: x * sqrt(d_model) + sqrt_d_model = math.sqrt(self.d_model) + inputs_li = inputs_li * sqrt_d_model + targets_li = targets_li * sqrt_d_model # Convert to final target dtype *after* normalization final_batch_inputs[li] = inputs_li.to(self.dtype) diff --git a/clt/training/evaluator.py b/clt/training/evaluator.py index eb77645..292f9e4 100644 --- a/clt/training/evaluator.py +++ b/clt/training/evaluator.py @@ -35,6 +35,8 @@ def __init__( start_time: Optional[float] = None, mean_tg: Optional[Dict[int, torch.Tensor]] = None, std_tg: Optional[Dict[int, torch.Tensor]] = None, + normalization_method: str = "none", + d_model: Optional[int] = None, ): """Initialize the evaluator. @@ -44,6 +46,8 @@ def __init__( start_time: The initial time.time() from the trainer for elapsed time logging. mean_tg: Optional dictionary of per-layer target means for de-normalising outputs. std_tg: Optional dictionary of per-layer target stds for de-normalising outputs. + normalization_method: The normalization method being used. + d_model: Model dimension for sqrt_d_model normalization. 
""" self.model = model self.device = device @@ -51,6 +55,8 @@ def __init__( # Store normalisation stats if provided self.mean_tg = mean_tg or {} self.std_tg = std_tg or {} + self.normalization_method = normalization_method + self.d_model = d_model self.metrics_history: List[Dict[str, Any]] = [] # For storing metrics over time if needed @staticmethod @@ -256,15 +262,22 @@ def _compute_reconstruction_metrics( recon_act = reconstructions[layer_idx] - # --- De-normalise if stats available --- + # --- De-normalise based on normalization method --- target_act_denorm = target_act recon_act_denorm = recon_act - if layer_idx in self.mean_tg and layer_idx in self.std_tg: + + if self.normalization_method == "mean_std" and layer_idx in self.mean_tg and layer_idx in self.std_tg: + # Standard denormalization: x * std + mean mean = self.mean_tg[layer_idx].to(recon_act.device, recon_act.dtype) std = self.std_tg[layer_idx].to(recon_act.device, recon_act.dtype) # Ensure broadcast shape target_act_denorm = target_act * std + mean recon_act_denorm = recon_act * std + mean + elif self.normalization_method == "sqrt_d_model" and self.d_model is not None: + # sqrt_d_model denormalization: x / sqrt(d_model) + sqrt_d_model = (self.d_model ** 0.5) + target_act_denorm = target_act / sqrt_d_model + recon_act_denorm = recon_act / sqrt_d_model # --- End De-normalisation --- # Ensure shapes match (flatten if necessary) and up-cast to float32 for numerically stable metrics diff --git a/clt/training/losses.py b/clt/training/losses.py index 6771148..f2245b8 100644 --- a/clt/training/losses.py +++ b/clt/training/losses.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from typing import Dict, Tuple, Optional -from clt.config import TrainingConfig +from clt.config import TrainingConfig, CLTConfig from clt.models.clt import CrossLayerTranscoder @@ -15,6 +15,7 @@ def __init__( config: TrainingConfig, mean_tg: Optional[Dict[int, torch.Tensor]] = None, std_tg: Optional[Dict[int, torch.Tensor]] = None, + 
clt_config: Optional['CLTConfig'] = None, ): """Initialize the loss manager. @@ -22,6 +23,7 @@ def __init__( config: Training configuration mean_tg: Optional dictionary of per-layer target means for de-normalising outputs std_tg: Optional dictionary of per-layer target stds for de-normalising outputs + clt_config: Optional CLT configuration for accessing d_model """ self.config = config self.reconstruction_loss_fn = nn.MSELoss() @@ -29,6 +31,7 @@ def __init__( # Store normalisation stats if provided self.mean_tg = mean_tg or {} self.std_tg = std_tg or {} + self.clt_config = clt_config self.aux_loss_factor = config.aux_loss_factor # New: coefficient for auxiliary loss self.apply_sparsity_penalty_to_batchtopk = config.apply_sparsity_penalty_to_batchtopk @@ -69,13 +72,19 @@ def compute_reconstruction_loss( pred_layer = predicted[layer_idx] tgt_layer = target[layer_idx] - # De-normalise if stats available for this layer - if layer_idx in self.mean_tg and layer_idx in self.std_tg: + # De-normalise based on normalization method + if self.config.normalization_method == "mean_std" and layer_idx in self.mean_tg and layer_idx in self.std_tg: + # Standard denormalization: x * std + mean mean = self.mean_tg[layer_idx].to(pred_layer.device, pred_layer.dtype) std = self.std_tg[layer_idx].to(pred_layer.device, pred_layer.dtype) # mean/std were stored with an added batch dim – ensure broadcast shape pred_layer = pred_layer * std + mean tgt_layer = tgt_layer * std + mean + elif self.config.normalization_method == "sqrt_d_model" and self.clt_config is not None: + # sqrt_d_model denormalization: x / sqrt(d_model) + sqrt_d_model = (self.clt_config.d_model ** 0.5) + pred_layer = pred_layer / sqrt_d_model + tgt_layer = tgt_layer / sqrt_d_model layer_loss = self.reconstruction_loss_fn(pred_layer, tgt_layer) total_loss += layer_loss @@ -360,10 +369,25 @@ def compute_total_loss( preactivation_loss = self.compute_preactivation_loss(model, inputs) # Compute residuals for auxiliary loss if 
needed + # Important: Compute residuals in denormalized (original) space for consistent auxiliary loss scale residuals = {} for layer_idx in predictions: if layer_idx in targets: - residuals[layer_idx] = targets[layer_idx] - predictions[layer_idx] + pred_layer = predictions[layer_idx] + tgt_layer = targets[layer_idx] + + # Denormalize before computing residuals + if self.config.normalization_method == "mean_std" and layer_idx in self.mean_tg and layer_idx in self.std_tg: + mean = self.mean_tg[layer_idx].to(pred_layer.device, pred_layer.dtype) + std = self.std_tg[layer_idx].to(pred_layer.device, pred_layer.dtype) + pred_layer = pred_layer * std + mean + tgt_layer = tgt_layer * std + mean + elif self.config.normalization_method == "sqrt_d_model" and self.clt_config is not None: + sqrt_d_model = (self.clt_config.d_model ** 0.5) + pred_layer = pred_layer / sqrt_d_model + tgt_layer = tgt_layer / sqrt_d_model + + residuals[layer_idx] = tgt_layer - pred_layer # Compute auxiliary loss (only if configured and using BatchTopK) aux_loss = torch.tensor(0.0, device=reconstruction_loss.device) diff --git a/clt/training/trainer.py b/clt/training/trainer.py index 32f7ec2..2b5dc08 100644 --- a/clt/training/trainer.py +++ b/clt/training/trainer.py @@ -324,6 +324,7 @@ def lr_lambda(current_step: int): training_config, mean_tg=mean_tg_stats, std_tg=std_tg_stats, + clt_config=clt_config, ) # Initialize Evaluator - Pass norm stats here too @@ -333,6 +334,8 @@ def lr_lambda(current_step: int): start_time=self.start_time, mean_tg=mean_tg_stats, # Pass the same stats std_tg=std_tg_stats, # Pass the same stats + normalization_method=training_config.normalization_method, + d_model=clt_config.d_model, ) # Initialize dead neuron counters (replicated for now, consider sharding later if needed) @@ -472,11 +475,15 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: logger.info(f"Distributed training with {self.world_size} processes (Tensor Parallelism)") # Check if using 
normalization and notify user - if self.training_config.normalization_method == "estimated_mean_std": - logger.info("\n>>> NORMALIZATION PHASE <<<") - logger.info("Normalization statistics are being estimated from dataset activations.") - logger.info("This may take some time, but happens only once before training begins.") - logger.info(f"Using {self.training_config.normalization_estimation_batches} batches for estimation.\n") + if self.training_config.normalization_method == "mean_std": + logger.info("\n>>> NORMALIZATION CONFIGURATION <<<") + logger.info("Using mean/std normalization with pre-calculated statistics from norm_stats.json") + elif self.training_config.normalization_method == "sqrt_d_model": + logger.info("\n>>> NORMALIZATION CONFIGURATION <<<") + logger.info("Using sqrt(d_model) normalization (EleutherAI-style)") + elif self.training_config.normalization_method == "none": + logger.info("\n>>> NORMALIZATION CONFIGURATION <<<") + logger.info("No normalization will be applied to activations") # Make sure we flush stdout to ensure prints appear immediately, # especially important in Jupyter/interactive environments diff --git a/clt/training/wandb_logger.py b/clt/training/wandb_logger.py index 3c351af..521ee58 100644 --- a/clt/training/wandb_logger.py +++ b/clt/training/wandb_logger.py @@ -68,21 +68,41 @@ def __init__( entity_name = os.environ.get("WANDB_ENTITY", training_config.wandb_entity) run_name = training_config.wandb_run_name # Can be None + # Prepare config for both new and resumed runs + wandb_config = self._create_wandb_config(training_config, clt_config) + if self.resume_wandb_id: logger.info(f"Attempting to resume WandB run with ID: {self.resume_wandb_id}") - # When resuming, wandb.init will use the passed id. Do not pass project/entity/name again if resuming. 
- self.wandb_run = wandb.init(id=self.resume_wandb_id, resume="allow") + # When resuming, we need to update the config after init + self.wandb_run = wandb.init( + id=self.resume_wandb_id, + resume="allow", + project=project_name, + entity=entity_name, + tags=training_config.wandb_tags, + ) + # Update config after resuming + if self.wandb_run: + wandb.config.update(wandb_config, allow_val_change=True) else: self.wandb_run = wandb.init( project=project_name, entity=entity_name, name=run_name, - config=self._create_wandb_config(training_config, clt_config), - reinit=True, # Allow re-initialization in the same process if needed + config=wandb_config, + tags=training_config.wandb_tags, ) if self.wandb_run: + self._run_id = self.wandb_run.id logger.info(f"WandB logging initialized: {self.wandb_run.name} (ID: {self.wandb_run.id})") + # Log config keys to verify they were set (only in debug mode) + if logger.isEnabledFor(logging.DEBUG): + if hasattr(wandb, 'config') and wandb.config: + config_keys = list(wandb.config.keys()) + logger.debug(f"WandB config keys: {config_keys}") + else: + logger.debug("WandB config appears to be empty or not accessible") else: logger.warning("Warning: WandB run initialization failed but no exception was raised.") diff --git a/scripts/convert_batchtopk_to_jumprelu.py b/scripts/convert_batchtopk_to_jumprelu.py index 2822e30..a198b1f 100644 --- a/scripts/convert_batchtopk_to_jumprelu.py +++ b/scripts/convert_batchtopk_to_jumprelu.py @@ -31,7 +31,16 @@ def _remap_checkpoint_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """Remaps old state_dict keys to the new format with module prefixes.""" + """Remaps old state_dict keys to the new format with module prefixes. 
+ + Handles both old-style keys (encoders.*, decoders.*) and new tied decoder parameters: + - encoder_module.feature_offset.{layer_idx}: ParameterList for per-feature bias + - encoder_module.feature_scale.{layer_idx}: ParameterList for per-feature scale + - decoder_module.skip_weights.{layer_idx}: ParameterList for tied decoders + - decoder_module.skip_weights.{src}->{tgt}: ParameterDict for untied decoders + - decoder_module.per_target_scale: Tensor for per src->tgt scale (tied decoders) + - decoder_module.per_target_bias: Tensor for per src->tgt bias (tied decoders) + """ new_state_dict = {} for key, value in state_dict.items(): if key.startswith("encoders."): @@ -41,6 +50,23 @@ def _remap_checkpoint_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, tor else: new_key = key new_state_dict[new_key] = value + + # Handle new parameter names that might not have the correct module prefix + # Create a list of keys to avoid modifying dict during iteration + keys_to_check = list(new_state_dict.keys()) + for key in keys_to_check: + # Handle feature_offset/feature_scale that might be saved without module prefix + if key.startswith("feature_offset.") and not key.startswith("encoder_module."): + new_state_dict[f"encoder_module.{key}"] = new_state_dict.pop(key) + elif key.startswith("feature_scale.") and not key.startswith("encoder_module."): + new_state_dict[f"encoder_module.{key}"] = new_state_dict.pop(key) + # Handle skip_weights that might be saved without module prefix + elif key.startswith("skip_weights.") and not key.startswith("decoder_module."): + new_state_dict[f"decoder_module.{key}"] = new_state_dict.pop(key) + # Handle per_target parameters + elif key in ["per_target_scale", "per_target_bias"] and not key.startswith("decoder_module."): + new_state_dict[f"decoder_module.{key}"] = new_state_dict.pop(key) + if not any(k.startswith("encoder_module.") or k.startswith("decoder_module.") for k in new_state_dict.keys()): if any(k.startswith("encoders.") or 
k.startswith("decoders.") for k in state_dict.keys()): logger.warning( diff --git a/scripts/train_clt.py b/scripts/train_clt.py index a79a316..7f7ab07 100644 --- a/scripts/train_clt.py +++ b/scripts/train_clt.py @@ -228,6 +228,13 @@ def parse_args(): action="store_true", # If flag is present, disable is true. Default behavior is enabled. help="Disable straight-through estimator for BatchTopK. (BatchTopK default is True).", ) + clt_group.add_argument( + "--topk-mode", + type=str, + choices=["global", "per_layer"], + default="global", + help="How to apply top-k selection: 'global' (across all layers) or 'per_layer' (each layer independently).", + ) clt_group.add_argument( "--topk-k", type=float, # As per CLTConfig, topk_k can be a float (fraction) or int (count) @@ -248,9 +255,9 @@ def parse_args(): clt_group.add_argument( "--decoder-tying", type=str, - choices=["none", "per_source"], + choices=["none", "per_source", "per_target"], default="none", - help="Decoder weight sharing strategy: 'none' (default) or 'per_source' (tied per source layer).", + help="Decoder weight sharing strategy: 'none' (default), 'per_source' (tied per source layer), or 'per_target' (tied per target layer, EleutherAI style).", ) clt_group.add_argument( "--per-target-scale", @@ -313,11 +320,13 @@ def parse_args(): train_group.add_argument( "--normalization-method", type=str, - choices=["auto", "none", "estimated_mean_std"], # Added estimated_mean_std from TrainingConfig - default="auto", + choices=["none", "mean_std", "sqrt_d_model"], + default="mean_std", help=( - "Normalization for activation store. 'auto' expects server/local store to provide stats. " - "'estimated_mean_std' forces estimation (if store supports it). 'none' disables." + "Normalization method for activations. " + "'none': No normalization. " + "'mean_std': Standard (x - mean) / std normalization using pre-calculated stats. " + "'sqrt_d_model': EleutherAI-style x * sqrt(d_model) normalization." 
), ) train_group.add_argument( @@ -646,6 +655,7 @@ def main(): enable_feature_offset=args.enable_feature_offset, enable_feature_scale=args.enable_feature_scale, skip_connection=args.skip_connection, + topk_mode=args.topk_mode, ) logger.info(f"CLT Config: {clt_config}") From 82793d7ba62a9341f6e40cf4f020b5cc82769c8f Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Mon, 30 Jun 2025 14:35:37 -0700 Subject: [PATCH 4/5] removed unused params --- clt/config/clt_config.py | 6 ---- clt/models/decoder.py | 62 +++++++++++++--------------------------- 2 files changed, 20 insertions(+), 48 deletions(-) diff --git a/clt/config/clt_config.py b/clt/config/clt_config.py index be08ab0..69a3b8d 100644 --- a/clt/config/clt_config.py +++ b/clt/config/clt_config.py @@ -39,8 +39,6 @@ class CLTConfig: # Tied decoder configuration decoder_tying: Literal["none", "per_source", "per_target"] = "none" # Decoder weight sharing strategy - per_target_scale: bool = False # Enable learned scale for each src->tgt path - per_target_bias: bool = False # Enable learned bias for each src->tgt path enable_feature_offset: bool = False # Enable per-feature bias (feature_offset) enable_feature_scale: bool = False # Enable per-feature scale (feature_scale) skip_connection: bool = False # Enable skip connection from input to output @@ -93,10 +91,6 @@ def from_json(cls: Type[C], json_path: str) -> C: # Handle backward compatibility for old configs if "decoder_tying" not in config_dict: config_dict["decoder_tying"] = "none" # Default to original behavior - if "per_target_scale" not in config_dict: - config_dict["per_target_scale"] = False - if "per_target_bias" not in config_dict: - config_dict["per_target_bias"] = False if "enable_feature_offset" not in config_dict: config_dict["enable_feature_offset"] = False if "enable_feature_scale" not in config_dict: diff --git a/clt/models/decoder.py b/clt/models/decoder.py index 1a1d0cd..3a3970f 100644 --- a/clt/models/decoder.py +++ b/clt/models/decoder.py @@ -81,26 
+81,8 @@ def __init__( elif hasattr(decoder, 'bias') and decoder.bias is not None: nn.init.zeros_(decoder.bias) - # Initialize per-target scale and bias if enabled - if config.per_target_scale: - # Initialize scale: diagonal gets ones, off-diagonal gets small values for gradient flow - # Small non-zero values allow gradients to flow even without skip connections - scale_init = torch.full((self.config.num_layers, self.config.num_layers, self.config.d_model), - 0.1, device=self.device, dtype=self.dtype) - # Set diagonal (same src->tgt layer) scales to 1.0 - for i in range(self.config.num_layers): - scale_init[i, i, :] = 1.0 - self.per_target_scale = nn.Parameter(scale_init) - else: - self.per_target_scale = None - - if config.per_target_bias: - self.per_target_bias = nn.Parameter( - torch.zeros(self.config.num_layers, self.config.num_layers, self.config.d_model, - device=self.device, dtype=self.dtype) - ) - else: - self.per_target_bias = None + # Note: EleutherAI doesn't have per-target scale/bias parameters + # These have been removed to match their architecture exactly else: # Original untied decoders: one decoder per (src, tgt) pair self.decoders = nn.ModuleDict( @@ -120,8 +102,7 @@ def __init__( for tgt_layer in range(src_layer, self.config.num_layers) } ) - self.per_target_scale = None - self.per_target_bias = None + # Note: EleutherAI doesn't have per-target scale/bias parameters # Initialize skip connection weights if enabled if config.skip_connection: @@ -226,24 +207,25 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: continue # Apply feature affine transformations (indexed by target layer) - # Note: EleutherAI applies these to ALL selected features, not just non-zero + # EleutherAI only applies these to non-zero (selected) features if self.feature_offset is not None or self.feature_scale is not None: - activation_tensor = activation_tensor.clone() + # Get non-zero positions (selected features) + nonzero_mask = 
activation_tensor != 0 - if self.feature_offset is not None: - # Apply offset to all features (not just non-zero) - activation_tensor += self.feature_offset[layer_idx] + if nonzero_mask.any(): + # Apply transformations only to selected features + activation_tensor = activation_tensor.clone() + batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True) - if self.feature_scale is not None: - # Apply scale to all features (not just non-zero) - activation_tensor *= self.feature_scale[layer_idx] - - # Apply per-target scale and bias if enabled (before summing) - # Note: EleutherAI doesn't have these parameters - if self.per_target_scale is not None: - activation_tensor = activation_tensor * self.per_target_scale[src_layer, layer_idx] - if self.per_target_bias is not None: - activation_tensor = activation_tensor + self.per_target_bias[src_layer, layer_idx] + if self.feature_offset is not None: + # Apply offset only to non-zero features + offset_values = self.feature_offset[layer_idx][feature_indices] + activation_tensor[batch_indices, feature_indices] += offset_values + + if self.feature_scale is not None: + # Apply scale only to non-zero features + scale_values = self.feature_scale[layer_idx][feature_indices] + activation_tensor[batch_indices, feature_indices] *= scale_values summed_activation += activation_tensor @@ -290,11 +272,7 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: decoder = self.decoders[src_layer] decoded = decoder(activation_tensor) - # Apply per-target scale and bias if enabled - if self.per_target_scale is not None: - decoded = decoded * self.per_target_scale[src_layer, layer_idx] - if self.per_target_bias is not None: - decoded = decoded + self.per_target_bias[src_layer, layer_idx] + # Note: EleutherAI doesn't have per-target scale/bias else: # Use untied decoder for (src, tgt) pair decoder = self.decoders[f"{src_layer}->{layer_idx}"] From f02bde6ddfaf5f1a7e67bed2dea50ca289ab13ab Mon Sep 17 00:00:00 2001 
From: Curt Tigges Date: Tue, 1 Jul 2025 16:43:23 -0700 Subject: [PATCH 5/5] working tied-CLT setup --- README.md | 46 ++ clt/config/clt_config.py | 6 + clt/models/clt.py | 72 +-- clt/models/decoder.py | 58 +- .../data/manifest_activation_store.py | 5 + clt/training/evaluator.py | 18 + tests/unit/models/test_tied_decoders.py | 113 ++-- ...1B-end-to-end-training-pythia-batchtopk.py | 3 +- ...-end-training-pythia-tied-decoders copy.py | 421 +++++++++++++ ...nd-to-end-training-pythia-tied-decoders.py | 16 +- ...end-to-end-training-gpt2-batchtopk-fp16.py | 594 ++++++++++++++++++ 11 files changed, 1216 insertions(+), 136 deletions(-) create mode 100644 tutorials/1F-end-to-end-training-pythia-tied-decoders copy.py create mode 100644 tutorials/1G-end-to-end-training-gpt2-batchtopk-fp16.py diff --git a/README.md b/README.md index 6509556..5e95b7b 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,10 @@ This library is intended for the training and analysis of cross-layer sparse cod A Cross-Layer Transcoder (CLT) is a multi-layer dictionary learning model designed to extract sparse, interpretable features from transformers, using an encoder for each layer and a decoder for each (source layer, destination layer) pair (e.g., 12 encoders and 78 decoders for `gpt2-small`). This implementation focuses on the core functionality needed to train and use CLTs, leveraging `nnsight` for model introspection and `datasets` for data handling. +The library now supports **tied decoders**, which can significantly reduce the number of parameters by sharing decoder weights across layers. Instead of training separate decoders for each (source, destination) pair, tied decoders use either: +- **Per-source tying**: One decoder per source layer, shared across all destination layers +- **Per-target tying**: One decoder per destination layer, shared across all source layers + Training a CLT involves the following steps: 1. 
Pre-generate activations with `scripts/generate_activations` (though an implementation of `StreamingActivationStore` is on the way). 2. Train a CLT (start with an expansion factor of at least `32`) using this data. Metrics can be logged to WandB. NMSE should get below `0.25`, or ideally even below `0.10`. As mentioned above, I recommend `BatchTopK` training, and suggest keeping `K` low--`200` is a good place to start. @@ -85,6 +89,16 @@ Key configuration parameters are mapped to config classes via script arguments: - `relu`: Standard ReLU activation. - `batchtopk`: Selects a global top K features across all tokens in a batch, based on pre-activation values. The 'k' can be an absolute number or a fraction. This is often used as a training-time differentiable approximation that can later be converted to `jumprelu`. - `topk`: Selects top K features per token (row-wise top-k). + + **Decoder Tying Options** (`--decoder-tying`): + - `none` (default): Traditional untied decoders - separate decoder for each (source, destination) layer pair + - `per_source`: Share decoder weights per source layer - each source layer has one decoder used for all destinations + - `per_target`: Share decoder weights per destination layer - each destination layer has one decoder that combines features from all source layers + + **Additional Tied Decoder Features**: + - `--enable-feature-offset`: Add learnable per-feature bias terms + - `--enable-feature-scale`: Add learnable per-feature scaling + - `--skip-connection`: Enable skip connections from source inputs to decoder outputs - **TrainingConfig**: `--learning-rate`, `--training-steps`, `--train-batch-size-tokens`, `--activation-source`, `--activation-path` (for `local_manifest`), remote config fields (for `remote`, e.g. 
`--server-url`, `--dataset-id`), `--normalization-method`, `--sparsity-lambda`, `--preactivation-coef`, `--optimizer`, `--lr-scheduler`, `--log-interval`, `--eval-interval`, `--checkpoint-interval`, `--dead-feature-window`, WandB settings (`--enable-wandb`, `--wandb-project`, etc.). ### Single GPU Training Examples @@ -139,6 +153,38 @@ python scripts/train_clt.py \\ # Add other arguments as needed ``` +**Example: Training with Tied Decoders** + +Tied decoders can significantly reduce the parameter count while maintaining performance. Here's an example using per-source tying: + +```bash +python scripts/train_clt.py \ + --activation-source local_manifest \ + --activation-path ./tutorial_activations/gpt2/pile-uncopyrighted_train \ + --output-dir ./clt_output_tied \ + --model-name gpt2 \ + --num-features 6144 \ + --decoder-tying per_source \ + --enable-feature-scale \ + --skip-connection \ + --activation-fn batchtopk \ + --batchtopk-k 256 \ + --learning-rate 3e-4 \ + --training-steps 100000 \ + --train-batch-size-tokens 8192 \ + --sparsity-lambda 1e-3 \ + --log-interval 100 \ + --eval-interval 1000 \ + --checkpoint-interval 5000 \ + --enable-wandb --wandb-project clt_tied_training +``` + +This configuration: +- Uses `per_source` tying: 12 decoders instead of 78 for gpt2-small +- Enables feature scaling for better expressiveness +- Includes skip connections to preserve input information +- Uses BatchTopK with k=256 for training (can be converted to JumpReLU later) + ### Multi-GPU Training (Tensor Parallelism) This library supports feature-wise tensor parallelism using PyTorch Distributed Data Parallel (`torch.distributed`). This shards the model's parameters (encoders, decoders) across multiple GPUs, reducing memory usage per GPU and potentially speeding up computation. 
diff --git a/clt/config/clt_config.py b/clt/config/clt_config.py index 69a3b8d..1de0141 100644 --- a/clt/config/clt_config.py +++ b/clt/config/clt_config.py @@ -259,6 +259,12 @@ def __post_init__(self): assert ( 0.0 <= self.sparsity_lambda_delay_frac < 1.0 ), "sparsity_lambda_delay_frac must be between 0.0 (inclusive) and 1.0 (exclusive)" + + # Validate normalization method + valid_norm_methods = ["none", "mean_std", "sqrt_d_model"] + assert ( + self.normalization_method in valid_norm_methods + ), f"Invalid normalization_method: {self.normalization_method}. Must be one of {valid_norm_methods}" @dataclass diff --git a/clt/models/clt.py b/clt/models/clt.py index 61763e5..ff27f01 100644 --- a/clt/models/clt.py +++ b/clt/models/clt.py @@ -180,51 +180,8 @@ def encode(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor: return torch.zeros((expected_batch_dim, self.config.num_features), device=self.device, dtype=self.dtype) - def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: - return self.decoder_module.decode(a, layer_idx) - - def _apply_skip_connection(self, input_tensor: torch.Tensor, layer_idx: int) -> torch.Tensor: - """Apply skip connection transformation to input. 
- - Args: - input_tensor: Input tensor at the given layer - layer_idx: Target layer index - - Returns: - Transformed input through skip connection - """ - if self.decoder_module.skip_weights is None: - return torch.zeros_like(input_tensor) - - # Ensure input is 2D for matrix multiplication - original_shape = input_tensor.shape - if input_tensor.dim() == 3: - # Flatten batch and sequence dimensions - input_2d = input_tensor.view(-1, input_tensor.shape[-1]) - else: - input_2d = input_tensor - - # Apply skip connection weight - if self.config.decoder_tying in ["per_source", "per_target"]: - # For tied decoders, use skip weight for this target layer - skip_weight = self.decoder_module.skip_weights[layer_idx] - else: - # For untied, we need to sum contributions from all source layers - # For now, just use the diagonal skip connection (src=tgt) - skip_key = f"{layer_idx}->{layer_idx}" - if skip_key in self.decoder_module.skip_weights: - skip_weight = self.decoder_module.skip_weights[skip_key] - else: - return torch.zeros_like(input_tensor) - - # Apply transformation: input @ W_skip^T - skip_output = input_2d @ skip_weight.T - - # Reshape back to original shape - if input_tensor.dim() == 3: - skip_output = skip_output.view(original_shape) - - return skip_output + def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None) -> torch.Tensor: + return self.decoder_module.decode(a, layer_idx, source_inputs) def forward(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]: activations = self.get_feature_activations(inputs) @@ -235,12 +192,9 @@ def forward(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]: for layer_idx in range(self.config.num_layers): relevant_activations = {k: v for k, v in activations.items() if k <= layer_idx and v.numel() > 0} if layer_idx in inputs and relevant_activations: - reconstruction = self.decode(relevant_activations, layer_idx) - - # Apply skip connection if 
enabled - if self.config.skip_connection and layer_idx in inputs: - skip_output = self._apply_skip_connection(inputs[layer_idx], layer_idx) - reconstruction = reconstruction + skip_output + # Pass source inputs for EleutherAI-style skip connections + source_inputs = {k: inputs[k] for k in range(layer_idx + 1) if k in inputs} if self.config.skip_connection else None + reconstruction = self.decode(relevant_activations, layer_idx, source_inputs) reconstructions[layer_idx] = reconstruction elif layer_idx in inputs: @@ -460,16 +414,12 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = Tr logger.info(f"Initializing missing {key} (first target layer to ones, rest to zeros)") # Don't add to state_dict to let it be initialized by the module - # Handle missing per-target parameters - if self.config.per_target_scale and hasattr(self.decoder_module, 'per_target_scale'): - key = "decoder_module.per_target_scale" - if key not in state_dict: - logger.info(f"Initializing missing {key} (diagonal to ones, off-diagonal to zeros)") - - if self.config.per_target_bias and hasattr(self.decoder_module, 'per_target_bias'): - key = "decoder_module.per_target_bias" - if key not in state_dict: - logger.info(f"Initializing missing {key} to zeros") + # Handle missing skip weights + if self.config.skip_connection and hasattr(self.decoder_module, 'skip_weights'): + for i in range(self.config.num_layers): + key = f"decoder_module.skip_weights.{i}" + if key not in state_dict: + logger.info(f"Initializing missing {key} to identity matrix") # Call parent's load_state_dict return super().load_state_dict(state_dict, strict=strict) diff --git a/clt/models/decoder.py b/clt/models/decoder.py index 3a3970f..f6e9060 100644 --- a/clt/models/decoder.py +++ b/clt/models/decoder.py @@ -160,7 +160,7 @@ def __init__( self.register_buffer("_cached_decoder_norms", None, persistent=False) - def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: + def decode(self, 
a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None) -> torch.Tensor: """Decode the feature activations to reconstruct outputs at the specified layer. Input activations `a` are expected to be the *full* tensors. @@ -233,6 +233,26 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: decoder = self.decoders[layer_idx] reconstruction = decoder(summed_activation) + # Apply skip connections from source inputs if enabled + if self.skip_weights is not None and source_inputs is not None: + skip_weight = self.skip_weights[layer_idx] + # Add skip connections from each source layer that contributed + for src_layer in range(layer_idx + 1): + if src_layer in source_inputs: + source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype) + # Flatten if needed + original_shape = source_input.shape + if source_input.dim() == 3: + source_input_2d = source_input.view(-1, source_input.shape[-1]) + else: + source_input_2d = source_input + # Apply skip: source @ W_skip^T + skip_contribution = source_input_2d @ skip_weight.T + # Reshape back if needed + if source_input.dim() == 3: + skip_contribution = skip_contribution.view(original_shape) + reconstruction += skip_contribution + else: # Original logic for per_source and untied decoders for src_layer in range(layer_idx + 1): @@ -272,12 +292,46 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor: decoder = self.decoders[src_layer] decoded = decoder(activation_tensor) - # Note: EleutherAI doesn't have per-target scale/bias + # Apply skip connection from this source input if enabled + if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs: + skip_weight = self.skip_weights[layer_idx] + source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype) + # Flatten if needed + original_shape = source_input.shape + if source_input.dim() == 3: + source_input_2d = 
source_input.view(-1, source_input.shape[-1]) + else: + source_input_2d = source_input + # Apply skip: source @ W_skip^T + skip_contribution = source_input_2d @ skip_weight.T + # Reshape back if needed + if source_input.dim() == 3: + skip_contribution = skip_contribution.view(original_shape) + decoded += skip_contribution else: # Use untied decoder for (src, tgt) pair decoder = self.decoders[f"{src_layer}->{layer_idx}"] decoded = decoder(activation_tensor) + # Apply skip connection from this source input if enabled + if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs: + skip_key = f"{src_layer}->{layer_idx}" + if skip_key in self.skip_weights: + skip_weight = self.skip_weights[skip_key] + source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype) + # Flatten if needed + original_shape = source_input.shape + if source_input.dim() == 3: + source_input_2d = source_input.view(-1, source_input.shape[-1]) + else: + source_input_2d = source_input + # Apply skip: source @ W_skip^T + skip_contribution = source_input_2d @ skip_weight.T + # Reshape back if needed + if source_input.dim() == 3: + skip_contribution = skip_contribution.view(original_shape) + decoded += skip_contribution + reconstruction += decoded return reconstruction diff --git a/clt/training/data/manifest_activation_store.py b/clt/training/data/manifest_activation_store.py index b441362..0a01879 100644 --- a/clt/training/data/manifest_activation_store.py +++ b/clt/training/data/manifest_activation_store.py @@ -491,6 +491,11 @@ def __init__( elif normalization_method == "sqrt_d_model": # sqrt_d_model doesn't need norm stats, just applies scaling self.apply_normalization = True + else: + raise ValueError( + f"Invalid normalization_method: {normalization_method}. 
" + f"Must be one of ['none', 'mean_std', 'sqrt_d_model']" + ) if self.apply_normalization: self._prep_norm() diff --git a/clt/training/evaluator.py b/clt/training/evaluator.py index 292f9e4..19570bc 100644 --- a/clt/training/evaluator.py +++ b/clt/training/evaluator.py @@ -55,6 +55,14 @@ def __init__( # Store normalisation stats if provided self.mean_tg = mean_tg or {} self.std_tg = std_tg or {} + + # Validate normalization method + valid_norm_methods = ["none", "mean_std", "sqrt_d_model"] + if normalization_method not in valid_norm_methods: + raise ValueError( + f"Invalid normalization_method: {normalization_method}. " + f"Must be one of {valid_norm_methods}" + ) self.normalization_method = normalization_method self.d_model = d_model self.metrics_history: List[Dict[str, Any]] = [] # For storing metrics over time if needed @@ -255,6 +263,10 @@ def _compute_reconstruction_metrics( total_explained_variance = 0.0 total_nmse = 0.0 num_layers = 0 + + # For layerwise metrics + layerwise_nmse = {} + layerwise_explained_variance = {} for layer_idx, target_act in targets.items(): if layer_idx not in reconstructions: @@ -312,6 +324,10 @@ def _compute_reconstruction_metrics( else: # Target variance is zero but MSE is non-zero (implies error, NMSE is effectively infinite) nmse_layer = float("inf") # Or a large number, or handle as NaN depending on preference total_nmse += nmse_layer + + # Store layerwise metrics + layerwise_nmse[f"layer_{layer_idx}"] = nmse_layer + layerwise_explained_variance[f"layer_{layer_idx}"] = explained_variance_layer num_layers += 1 @@ -327,6 +343,8 @@ def _compute_reconstruction_metrics( return { "reconstruction/explained_variance": avg_explained_variance, "reconstruction/normalized_mean_reconstruction_error": avg_normalized_mean_reconstruction_error, + "layerwise/normalized_mse": layerwise_nmse, + "layerwise/explained_variance": layerwise_explained_variance, } def _compute_feature_density(self, activations: Dict[int, torch.Tensor]) -> Dict[str, 
Any]: diff --git a/tests/unit/models/test_tied_decoders.py b/tests/unit/models/test_tied_decoders.py index be3d6c4..9ef6df2 100644 --- a/tests/unit/models/test_tied_decoders.py +++ b/tests/unit/models/test_tied_decoders.py @@ -72,48 +72,33 @@ def test_decoder_initialization_tied(self, tied_config): for layer in range(tied_config.num_layers): assert isinstance(decoder.decoders[layer], nn.Module) - def test_per_target_parameters(self, tied_config): - """Test per-target scale and bias parameters.""" - # Test with per-target scale - config_with_scale = CLTConfig( - **{**tied_config.__dict__, "per_target_scale": True} + def test_skip_connections(self, tied_config): + """Test skip connection functionality.""" + # Test with skip connections enabled + config_with_skip = CLTConfig( + **{**tied_config.__dict__, "skip_connection": True} ) decoder = Decoder( - config=config_with_scale, + config=config_with_skip, process_group=None, device=torch.device("cpu"), dtype=torch.float32, ) - assert decoder.per_target_scale is not None - assert decoder.per_target_scale.shape == ( - config_with_scale.num_layers, - config_with_scale.num_layers, - config_with_scale.d_model, - ) - assert torch.allclose(decoder.per_target_scale, torch.ones_like(decoder.per_target_scale)) - - # Test with per-target bias - config_with_bias = CLTConfig( - **{**tied_config.__dict__, "per_target_bias": True} - ) - decoder = Decoder( - config=config_with_bias, - process_group=None, - device=torch.device("cpu"), - dtype=torch.float32, - ) + # Skip weights should be initialized + assert decoder.skip_weights is not None + assert len(decoder.skip_weights) == config_with_skip.num_layers - assert decoder.per_target_bias is not None - assert decoder.per_target_bias.shape == ( - config_with_bias.num_layers, - config_with_bias.num_layers, - config_with_bias.d_model, - ) - assert torch.allclose(decoder.per_target_bias, torch.zeros_like(decoder.per_target_bias)) + # Each skip weight should have correct shape + for layer_idx 
in range(config_with_skip.num_layers): + skip_weight = decoder.skip_weights[layer_idx] + assert skip_weight.shape == (config_with_skip.d_model, config_with_skip.d_model) + # Should be initialized to zeros + expected = torch.zeros(config_with_skip.d_model, config_with_skip.d_model, dtype=torch.float32) + assert torch.allclose(skip_weight, expected) def test_feature_affine_parameters(self): - """Test feature offset and scale parameters in encoder.""" + """Test feature offset and scale parameters in decoder.""" config = CLTConfig( num_features=128, num_layers=4, @@ -121,9 +106,10 @@ def test_feature_affine_parameters(self): activation_fn="relu", enable_feature_offset=True, enable_feature_scale=True, + decoder_tying="per_source", # Feature affine only works with tied decoders ) - encoder = Encoder( + decoder = Decoder( config=config, process_group=None, device=torch.device("cpu"), @@ -131,18 +117,23 @@ def test_feature_affine_parameters(self): ) # Check feature_offset initialization - assert encoder.feature_offset is not None - assert len(encoder.feature_offset) == config.num_layers - for layer_offset in encoder.feature_offset: - assert layer_offset.shape == (config.num_features,) - assert torch.allclose(layer_offset, torch.zeros_like(layer_offset)) + assert decoder.feature_offset is not None + assert len(decoder.feature_offset) == config.num_layers + for layer_idx in range(config.num_layers): + assert decoder.feature_offset[layer_idx].shape == (config.num_features,) + assert torch.allclose(decoder.feature_offset[layer_idx], torch.zeros_like(decoder.feature_offset[layer_idx])) # Check feature_scale initialization - assert encoder.feature_scale is not None - assert len(encoder.feature_scale) == config.num_layers - for layer_scale in encoder.feature_scale: - assert layer_scale.shape == (config.num_features,) - assert torch.allclose(layer_scale, torch.ones_like(layer_scale)) + assert decoder.feature_scale is not None + assert len(decoder.feature_scale) == 
config.num_layers + for layer_idx in range(config.num_layers): + assert decoder.feature_scale[layer_idx].shape == (config.num_features,) + # First layer should be ones, rest should be 0.1 for tied decoders + if layer_idx == 0: + assert torch.allclose(decoder.feature_scale[layer_idx], torch.ones_like(decoder.feature_scale[layer_idx])) + else: + expected = torch.full_like(decoder.feature_scale[layer_idx], 0.1) + assert torch.allclose(decoder.feature_scale[layer_idx], expected) def test_decode_with_tied_decoders(self, tied_config): """Test decoding with tied decoders.""" @@ -192,7 +183,7 @@ def test_decoder_norms_tied(self, tied_config): assert torch.all(norms >= 0) def test_feature_affine_transformation(self): - """Test feature affine transformation in forward pass.""" + """Test feature affine transformation in decoder.""" config = CLTConfig( num_features=128, num_layers=2, @@ -200,32 +191,32 @@ def test_feature_affine_transformation(self): activation_fn="relu", enable_feature_offset=True, enable_feature_scale=True, + decoder_tying="per_source", ) - model = CrossLayerTranscoder( + decoder = Decoder( config=config, process_group=None, device=torch.device("cpu"), + dtype=torch.float32, ) - # Create test inputs + # Create test activations batch_size = 4 - seq_len = 16 - inputs = { - 0: torch.randn(batch_size, seq_len, config.d_model), - 1: torch.randn(batch_size, seq_len, config.d_model), + test_activations = { + 0: torch.randn(batch_size, config.num_features), + 1: torch.randn(batch_size, config.num_features), } - # Get activations - activations = model.get_feature_activations(inputs) + # Set some specific values for testing + decoder.feature_offset[0].data.fill_(0.5) + decoder.feature_scale[0].data.fill_(2.0) - # Apply affine transformation - transformed = model._apply_feature_affine(activations) + # Decode at layer 1 (should use features from layers 0 and 1) + result = decoder.decode(test_activations, layer_idx=1) - # The transformation should preserve zeros - for 
layer_idx in transformed: - zero_mask = activations[layer_idx] == 0 - assert torch.all(transformed[layer_idx][zero_mask] == 0) + # Result should have correct shape + assert result.shape == (batch_size, config.d_model) def test_backward_compatibility_config(self): """Test loading old config without new fields.""" @@ -234,8 +225,7 @@ def test_backward_compatibility_config(self): "num_layers": 4, "d_model": 64, "activation_fn": "relu", - # Missing: decoder_tying, per_target_scale, per_target_bias, - # enable_feature_offset, enable_feature_scale + # Missing: decoder_tying, enable_feature_offset, enable_feature_scale, skip_connection } # Should not raise an error @@ -243,10 +233,9 @@ def test_backward_compatibility_config(self): # Should have default values assert config.decoder_tying == "none" - assert config.per_target_scale == False - assert config.per_target_bias == False assert config.enable_feature_offset == False assert config.enable_feature_scale == False + assert config.skip_connection == False def test_checkpoint_compatibility(self, base_config, tied_config): """Test loading old untied checkpoint into tied model.""" diff --git a/tutorials/1B-end-to-end-training-pythia-batchtopk.py b/tutorials/1B-end-to-end-training-pythia-batchtopk.py index 5ef0177..248c03a 100644 --- a/tutorials/1B-end-to-end-training-pythia-batchtopk.py +++ b/tutorials/1B-end-to-end-training-pythia-batchtopk.py @@ -175,7 +175,7 @@ train_batch_size_tokens=_batch_size, sampling_strategy="sequential", # Normalization - normalization_method="auto", # Use pre-calculated stats + normalization_method="mean_std", # Use pre-calculated stats # Loss function coefficients sparsity_lambda=0.0, # Disable standard sparsity penalty sparsity_lambda_schedule="linear", @@ -194,7 +194,6 @@ max_features_for_diag_hist=1000, # optional cap per layer checkpoint_interval=500, dead_feature_window=200, - p # WandB (Optional) enable_wandb=True, wandb_project="clt-hp-sweeps-pythia-70m", diff --git 
a/tutorials/1F-end-to-end-training-pythia-tied-decoders copy.py b/tutorials/1F-end-to-end-training-pythia-tied-decoders copy.py new file mode 100644 index 0000000..bd1c9b1 --- /dev/null +++ b/tutorials/1F-end-to-end-training-pythia-tied-decoders copy.py @@ -0,0 +1,421 @@ +# %% [markdown] +# # Tutorial: End-to-End CLT Training with Tied Decoders and Feature Offset +# +# This tutorial demonstrates training a Cross-Layer Transcoder (CLT) using: +# - **Tied decoder architecture** to reduce memory usage +# - **Feature offset parameters** for per-feature bias +# - **BatchTopK activation** (same as Tutorial 1B) +# +# The tied decoder architecture uses one decoder per source layer (instead of one per source-target pair), +# significantly reducing memory usage from O(L²) to O(L) decoder parameters. +# +# We will: +# 1. Configure the CLT model with tied decoders and feature offset +# 2. Use the same pre-generated activations from Tutorial 1B +# 3. Train the model and compare memory usage +# 4. Demonstrate loading checkpoints with the new architecture + +# %% [markdown] +# ## 1. 
Imports and Setup + +# %% +import torch +import os +import time +import sys +import traceback +import json +from torch.distributed.checkpoint import load_state_dict as dist_load_state_dict +from torch.distributed.checkpoint.filesystem import FileSystemReader +from typing import Optional, Dict +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s") + +# Ensure tokenizers don't use parallelism +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Add project root to path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +try: + from clt.config import CLTConfig, TrainingConfig, ActivationConfig + from clt.activation_generation.generator import ActivationGenerator + from clt.training.trainer import CLTTrainer + from clt.models.clt import CrossLayerTranscoder + from clt.training.data import BaseActivationStore +except ImportError as e: + print(f"ImportError: {e}") + print("Please ensure the 'clt' library is installed or the clt directory is in your PYTHONPATH.") + raise + +# Device setup +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + +print(f"Using device: {device}") + +# Base model for activation extraction (same as Tutorial 1B) +BASE_MODEL_NAME = "EleutherAI/pythia-70m" + +# %% [markdown] +# ## 2. 
Configuration with Tied Decoders +# +# Key differences from Tutorial 1B: +# - `decoder_tying="per_target"` - Enables tied decoder architecture +# - `enable_feature_offset=True` - Adds learnable per-feature bias +# - Memory savings: For 6 layers, we go from 21 decoders to just 6 + +# %% +# --- CLT Architecture Configuration with Tied Decoders --- +num_layers = 6 +d_model = 512 +expansion_factor = 32 +clt_num_features = d_model * expansion_factor + +batchtopk_k = 200 + +clt_config = CLTConfig( + num_features=clt_num_features, + num_layers=num_layers, + d_model=d_model, + activation_fn="batchtopk", + batchtopk_k=batchtopk_k, + batchtopk_straight_through=True, + # NEW: Tied decoder configuration + decoder_tying="per_target", # Use one decoder per target layer + enable_feature_offset=True, # Enable per-feature bias (feature_offset) + enable_feature_scale=False, # Enable per-feature scale (feature_scale) + skip_connection=True, # Enable skip connection from input to output +) + +print("CLT Configuration (Tied Decoders with Feature Affine):") +print(f"- decoder_tying: {clt_config.decoder_tying}") +print(f"- enable_feature_offset: {clt_config.enable_feature_offset}") +print(f"- enable_feature_scale: {clt_config.enable_feature_scale}") +print(f"- skip_connection: {clt_config.skip_connection}") +print(f"- Number of features: {clt_config.num_features}") +print(f"- Number of layers: {clt_config.num_layers}") +print(f"- Activation function: {clt_config.activation_fn}") +print(f"- BatchTopK k: {clt_config.batchtopk_k}") + +# Calculate memory savings +untied_decoders = sum(range(1, num_layers + 1))  # 6 + 5 + 4 + 3 + 2 + 1 = 21 +tied_decoders = num_layers  # 6 +print(f"\nMemory savings:") +print(f"- Untied decoders: {untied_decoders} decoder matrices") +print(f"- Tied decoders: {tied_decoders} decoder matrices") +print(f"- Reduction: {(1 - tied_decoders/untied_decoders)*100:.1f}%") + +# --- Use existing activations from Tutorial 1B --- +# We'll use the same activation directory as 
Tutorial 1B since the base model +# and dataset are identical - only the CLT architecture differs +activation_dir = "./tutorial_activations_local_1M_pythia" +dataset_name = "monology/pile-uncopyrighted" + +expected_activation_path = os.path.join( + activation_dir, + BASE_MODEL_NAME, + f"{os.path.basename(dataset_name)}_train", +) + +# Verify activations exist +metadata_path = os.path.join(expected_activation_path, "metadata.json") +manifest_path = os.path.join(expected_activation_path, "index.bin") + +if not (os.path.exists(metadata_path) and os.path.exists(manifest_path)): + print(f"\nERROR: Activations not found at {expected_activation_path}") + print("Please run Tutorial 1B first to generate the activations.") + raise FileNotFoundError("Activation dataset not found") +else: + print(f"\nUsing existing activations from: {expected_activation_path}") + +# --- Training Configuration --- +_lr = 1e-4 +_batch_size = 1024 + +# WandB run name includes tied decoder info +wdb_run_name = ( + f"{clt_config.num_features}-width-" + f"tied-decoders-" # Indicate tied decoder architecture + f"feat-offset-" # Indicate feature offset is enabled + f"batchtopk-k{batchtopk_k}-" + f"{_batch_size}-batch-" + f"{_lr:.1e}-lr" +) +print(f"\nGenerated WandB run name: {wdb_run_name}") + +training_config = TrainingConfig( + # Training loop parameters + learning_rate=_lr, + training_steps=1000, # Same as Tutorial 1B for comparison + seed=42, + # Activation source (using existing activations) + activation_source="local_manifest", + activation_path=expected_activation_path, + activation_dtype="float32", + # Training batch size + train_batch_size_tokens=_batch_size, + sampling_strategy="sequential", + # Normalization + normalization_method="sqrt_d_model", + # Loss function coefficients (same as Tutorial 1B) + sparsity_lambda=0.0, + sparsity_lambda_schedule="linear", + sparsity_c=0.0, + preactivation_coef=0, + aux_loss_factor=1 / 32, + apply_sparsity_penalty_to_batchtopk=False, + # Optimizer & 
Scheduler + optimizer="adamw", + lr_scheduler="linear_final20", + optimizer_beta2=0.98, + # Logging & Checkpointing + log_interval=10, + eval_interval=50, + diag_every_n_eval_steps=1, + max_features_for_diag_hist=1000, + checkpoint_interval=500, + dead_feature_window=200, + # WandB + enable_wandb=True, + wandb_project="clt-debug-pythia-70m", + wandb_run_name=wdb_run_name, +) + +print("\nTraining Configuration:") +print(f"- Learning rate: {training_config.learning_rate}") +print(f"- Training steps: {training_config.training_steps}") +print(f"- Batch size (tokens): {training_config.train_batch_size_tokens}") + +# %% [markdown] +# ## 3. Initialize Model and Check Architecture +# +# Let's create the model and verify the tied decoder architecture is set up correctly. + +# %% +print("\nInitializing CLT model with tied decoders...") + +# Create model instance to inspect architecture +model = CrossLayerTranscoder( + config=clt_config, + process_group=None, + device=torch.device(device), +) + +print("\nModel architecture inspection:") +print(f"- Encoder modules: {len(model.encoder_module.encoders)}") +print(f"- Decoder modules: {len(model.decoder_module.decoders)}") + +# Check feature offset parameters +if model.decoder_module.feature_offset is not None: + print(f"- Feature offset parameters per layer: {len(model.decoder_module.feature_offset)}") + print(f"- Feature offset shape (layer 0): {model.decoder_module.feature_offset[0].shape}") +else: + print("- Feature offset: Not enabled") + +# Count total parameters +total_params = sum(p.numel() for p in model.parameters()) +encoder_params = sum(p.numel() for p in model.encoder_module.parameters()) +decoder_params = sum(p.numel() for p in model.decoder_module.parameters()) +print(f"\nParameter counts:") +print(f"- Total parameters: {total_params:,}") +print(f"- Encoder parameters: {encoder_params:,}") +print(f"- Decoder parameters: {decoder_params:,}") + +# Compare with untied architecture (approximate) 
+untied_decoder_params_approx = decoder_params * (untied_decoders / tied_decoders) +print(f"\nEstimated decoder parameters if untied: {untied_decoder_params_approx:,}") +print(f"Memory savings in decoder: {(1 - decoder_params/untied_decoder_params_approx)*100:.1f}%") + +# Clean up the test model +del model + +# %% [markdown] +# ## 4. Training the CLT with Tied Decoders + +# %% +print("\nInitializing CLTTrainer for training with tied decoders...") + +log_dir = f"clt_training_logs/clt_pythia_tied_decoders_{int(time.time())}" +os.makedirs(log_dir, exist_ok=True) +print(f"Logs and checkpoints will be saved to: {log_dir}") + +try: + print("\nCreating CLTTrainer instance...") + trainer = CLTTrainer( + clt_config=clt_config, + training_config=training_config, + log_dir=log_dir, + device=device, + distributed=False, + ) + print("CLTTrainer instance created successfully.") +except Exception as e: + print(f"[ERROR] Failed to initialize CLTTrainer: {e}") + traceback.print_exc() + raise + +# Start training +print("\nBeginning training with tied decoders...") +print(f"Training for {training_config.training_steps} steps.") +print(f"Decoder tying: {clt_config.decoder_tying}") +print(f"Feature offset enabled: {clt_config.enable_feature_offset}") + +try: + start_train_time = time.time() + trained_clt_model = trainer.train(eval_every=training_config.eval_interval) + end_train_time = time.time() + print(f"\nTraining finished in {end_train_time - start_train_time:.2f} seconds.") +except Exception as train_err: + print(f"[ERROR] Training failed: {train_err}") + traceback.print_exc() + raise + +# %% [markdown] +# ## 5. 
Saving and Loading the Tied Decoder Model + +# %% +# Save the final model state and config +final_model_state_path = os.path.join(log_dir, "clt_tied_final_state.pt") +final_model_config_path = os.path.join(log_dir, "clt_tied_final_config.json") + +print(f"\nSaving final model state to: {final_model_state_path}") +print(f"Saving final model config to: {final_model_config_path}") + +torch.save(trained_clt_model.state_dict(), final_model_state_path) +with open(final_model_config_path, "w") as f: + json.dump(trained_clt_model.config.__dict__, f, indent=4) + +# Verify the saved config has tied decoder settings +with open(final_model_config_path, "r") as f: + saved_config = json.load(f) + print(f"\nSaved config verification:") + print(f"- decoder_tying: {saved_config['decoder_tying']}") + print(f"- enable_feature_offset: {saved_config['enable_feature_offset']}") + print(f"- activation_fn: {saved_config['activation_fn']} (converted from batchtopk)") + +# Load the model back +print("\nLoading the saved tied decoder model...") +loaded_config = CLTConfig(**saved_config) +loaded_model = CrossLayerTranscoder( + config=loaded_config, + process_group=None, + device=torch.device(device), +) +loaded_model.load_state_dict(torch.load(final_model_state_path, map_location=device)) +loaded_model.eval() + +print("Model loaded successfully.") +print(f"Loaded model decoder count: {len(loaded_model.decoder_module.decoders)}") + +# %% [markdown] +# ## 6. Backward Compatibility Test +# +# Test loading an old untied checkpoint into our tied decoder model. +# This demonstrates the backward compatibility feature. 
+ +# %% +print("\n=== Testing Backward Compatibility ===") + +# Create a simple untied model for testing +untied_config = CLTConfig( + num_features=clt_config.num_features, + num_layers=clt_config.num_layers, + d_model=clt_config.d_model, + activation_fn="relu", # Simple activation for testing + decoder_tying="none", # Untied decoders +) + +print("Creating untied model for compatibility test...") +untied_model = CrossLayerTranscoder( + config=untied_config, + process_group=None, + device=torch.device("cpu"), # Use CPU for this test +) + +# Save untied model state +untied_state_dict = untied_model.state_dict() +print(f"Untied model decoder keys (first 5): {list(k for k in untied_state_dict.keys() if 'decoder' in k)[:5]}") + +# Create tied model with same dimensions +tied_test_config = CLTConfig( + num_features=clt_config.num_features, + num_layers=clt_config.num_layers, + d_model=clt_config.d_model, + activation_fn="relu", + decoder_tying="per_source", # Tied decoders + enable_feature_offset=True, # This will be initialized to defaults +) + +tied_test_model = CrossLayerTranscoder( + config=tied_test_config, + process_group=None, + device=torch.device("cpu"), +) + +print("\nLoading untied checkpoint into tied model...") +try: + # This should work due to our custom load_state_dict + tied_test_model.load_state_dict(untied_state_dict, strict=False) + print("✓ Successfully loaded untied checkpoint into tied model!") + print(" The tied model uses diagonal decoder weights from the untied model.") +except Exception as e: + print(f"✗ Failed to load: {e}") + +# Clean up test models +del untied_model, tied_test_model + +# %% [markdown] +# ## 7. 
Performance Comparison Summary + +# %% +print("\n=== Tied Decoder Architecture Summary ===") +print(f"\nConfiguration used:") +print(f"- Model: {BASE_MODEL_NAME}") +print(f"- Layers: {num_layers}") +print(f"- Hidden dimension: {d_model}") +print(f"- Features per layer: {clt_num_features}") +print(f"- Decoder tying: {clt_config.decoder_tying}") +print(f"- Feature offset: {clt_config.enable_feature_offset}") + +print(f"\nMemory efficiency:") +print(f"- Traditional CLT: {untied_decoders} decoder matrices") +print(f"- Tied decoder CLT: {tied_decoders} decoder matrices") +print(f"- Memory reduction: ~{(1 - tied_decoders/untied_decoders)*100:.0f}%") + +print(f"\nKey benefits:") +print(f"1. Significant memory savings for decoder parameters") +print(f"2. Simpler feature interpretability (one decoder per source)") +print(f"3. Feature offset allows per-feature adaptation") +print(f"4. Backward compatible with existing checkpoints") + +print(f"\nTrade-offs:") +print(f"1. Less flexibility in source-target specific adaptations") +print(f"2. May require careful tuning of feature offset parameters") + +# %% [markdown] +# ## 8. 
Next Steps
+#
+# This tutorial demonstrated:
+# - Training a CLT with tied decoder architecture
+# - Using feature offset parameters for per-feature bias
+# - Significant memory savings compared to traditional CLT
+# - Backward compatibility with untied checkpoints
+#
+# You can experiment with:
+# - `per_target_scale` and `per_target_bias` for more flexibility
+# - `enable_feature_scale` for per-feature scaling
+# - Different values of `k` for BatchTopK
+# - Comparing reconstruction quality between tied and untied architectures
+
+# %%
+print(f"\n✓ Tied Decoder Tutorial Complete!")
+print(f"Model and logs saved to: {log_dir}")
diff --git a/tutorials/1F-end-to-end-training-pythia-tied-decoders.py b/tutorials/1F-end-to-end-training-pythia-tied-decoders.py
index 2b5609f..d5a6c1a 100644
--- a/tutorials/1F-end-to-end-training-pythia-tied-decoders.py
+++ b/tutorials/1F-end-to-end-training-pythia-tied-decoders.py
@@ -90,11 +90,9 @@
     batchtopk_k=batchtopk_k,
     batchtopk_straight_through=True,
     # NEW: Tied decoder configuration
-    decoder_tying="per_source",  # Use one decoder per source layer
+    decoder_tying="per_target",  # Use one decoder per target layer
     enable_feature_offset=True,  # Enable per-feature bias (feature_offset)
-    enable_feature_scale=True,  # Enable per-feature scale (feature_scale)
-    per_target_scale=True,  # Not using per-target adaptations
-    per_target_bias=True,
+    enable_feature_scale=False,  # Disable per-feature scale (feature_scale)
     skip_connection=True,  # Enable skip connection from input to output
 )
 
@@ -167,7 +165,7 @@
     train_batch_size_tokens=_batch_size,
     sampling_strategy="sequential",
     # Normalization
-    normalization_method="auto",
+    normalization_method="none",
     # Loss function coefficients (same as Tutorial 1B)
     sparsity_lambda=0.0,
     sparsity_lambda_schedule="linear",
@@ -188,7 +186,7 @@
     dead_feature_window=200,
     # WandB
     enable_wandb=True,
-    wandb_project="clt-hp-sweeps-pythia-70m",
+    wandb_project="clt-debug-pythia-70m",
     wandb_run_name=wdb_run_name,
 )
 
@@ 
-217,9 +215,9 @@ print(f"- Decoder modules: {len(model.decoder_module.decoders)}") # Check feature offset parameters -if model.encoder_module.feature_offset is not None: - print(f"- Feature offset parameters per layer: {len(model.encoder_module.feature_offset)}") - print(f"- Feature offset shape (layer 0): {model.encoder_module.feature_offset[0].shape}") +if model.decoder_module.feature_offset is not None: + print(f"- Feature offset parameters per layer: {len(model.decoder_module.feature_offset)}") + print(f"- Feature offset shape (layer 0): {model.decoder_module.feature_offset[0].shape}") else: print("- Feature offset: Not enabled") diff --git a/tutorials/1G-end-to-end-training-gpt2-batchtopk-fp16.py b/tutorials/1G-end-to-end-training-gpt2-batchtopk-fp16.py new file mode 100644 index 0000000..f563323 --- /dev/null +++ b/tutorials/1G-end-to-end-training-gpt2-batchtopk-fp16.py @@ -0,0 +1,594 @@ +# %% [markdown] +# # Tutorial: End-to-End CLT Training with GPT-2, BatchTopK, and FP16 +# +# This tutorial demonstrates training a Cross-Layer Transcoder (CLT) +# on **GPT-2** using the **BatchTopK** activation function and **FP16** precision. We will: +# 1. Configure the CLT model for GPT-2, BatchTopK, and FP16 training. +# 2. Generate FP16 activations locally (with manifest) using the ActivationGenerator. +# 3. Configure the trainer to use the locally stored FP16 activations. +# 4. Train the CLT model using BatchTopK activation in mixed precision. +# 5. Save and load the final trained model (which will be JumpReLU if converted). +# 6. Load a model from a distributed checkpoint. +# 7. Perform a post-hoc conversion sweep (θ scaling) on a BatchTopK checkpoint. + +# %% [markdown] +# ## 1. Imports and Setup +# +# First, let's import the necessary components and set up the device. 
+ +# %% +import torch +import os +import time +import sys +import traceback +import json +from torch.distributions.normal import Normal # For post-hoc sweep +from torch.distributed.checkpoint import load_state_dict as dist_load_state_dict +from torch.distributed.checkpoint.filesystem import FileSystemReader +from typing import Optional, Dict +import logging # Import logging + +# Configure logging to show INFO level messages for the notebook +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - [%(name)s] %(message)s") + +# Import from torch.distributed.checkpoint and related modules later, only when needed for that specific section +# from torch.distributed.checkpoint import load_state_dict +# from torch.distributed.checkpoint.filesystem import FileSystemReader + +# logging.basicConfig(level=logging.DEBUG) + +# Import components from the clt library +# (Ensure the 'clt' directory is in your Python path or installed) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +try: + from clt.config import CLTConfig, TrainingConfig, ActivationConfig + from clt.activation_generation.generator import ActivationGenerator + from clt.training.trainer import CLTTrainer + from clt.models.clt import CrossLayerTranscoder + from clt.training.data import BaseActivationStore +except ImportError as e: + print(f"ImportError: {e}") + print("Please ensure the 'clt' library is installed or the clt directory is in your PYTHONPATH.") + raise + +# Device setup +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + +print(f"Using device: {device}") + +# Base model for activation extraction +BASE_MODEL_NAME = "gpt2" + +# For post-hoc sweep N(0,1) assumption +std_normal = Normal(0, 1) + +# %% [markdown] +# ## 2. 
Configuration
#
# We configure the CLT, Activation Generation, and Training for GPT-2 with FP16.
# Key changes: `CLTConfig` matches GPT-2 dims, `ActivationConfig` and `TrainingConfig` use FP16.

# %%
# --- CLT Architecture Configuration ---
num_layers = 12  # GPT-2 small
d_model = 768  # GPT-2 small
expansion_factor = 32
# NOTE(review): the old comment claimed this equals d_model * expansion_factor,
# but 768 * 32 = 24576, not 16384 (= 2**14). `expansion_factor` appears unused /
# stale here — confirm the intended CLT width.
clt_num_features = 16384

batchtopk_k = 200

clt_config = CLTConfig(
    num_features=clt_num_features,
    num_layers=num_layers,
    d_model=d_model,
    activation_fn="batchtopk",  # Use BatchTopK activation
    batchtopk_k=batchtopk_k,  # Specify k directly
    batchtopk_straight_through=True,  # Use STE for gradients
    # jumprelu_threshold is not used for batchtopk
)
print("CLT Configuration (BatchTopK for GPT-2):")
print(clt_config)

# --- Activation Generation Configuration ---
# Generate FP16 activations from GPT-2
activation_dir = "./tutorial_activations_local_1M_fp16"
dataset_name = "monology/pile-uncopyrighted"
activation_config = ActivationConfig(
    # Model Source
    model_name=BASE_MODEL_NAME,
    # Hook paths into the HF GPT-2 module tree: MLP input is taken at ln_2's
    # input, MLP output at the mlp module's output.
    mlp_input_module_path_template="transformer.h.{}.ln_2.input",
    mlp_output_module_path_template="transformer.h.{}.mlp.output",
    model_dtype=None,
    # Dataset Source
    dataset_path=dataset_name,
    dataset_split="train",
    dataset_text_column="text",
    # Generation Parameters
    context_size=128,
    inference_batch_size=192,
    exclude_special_tokens=True,
    prepend_bos=True,
    # Dataset Handling
    streaming=True,
    dataset_trust_remote_code=False,
    cache_path=None,
    # Generation Output Control
    target_total_tokens=1_000_000,  # Keep it small for tutorial
    # Storage Parameters
    activation_dir=activation_dir,
    output_format="hdf5",
    compression="gzip",
    chunk_token_threshold=16_000,
    activation_dtype="float16",  # Store activations in FP16
    # Normalization
    compute_norm_stats=True,  # needed later by normalization_method="mean_std"
    # NNsight args
    nnsight_tracer_kwargs={},
    nnsight_invoker_args={},
)
print("Activation Generation Configuration:")
+print(activation_config) + +# --- Training Configuration --- +expected_activation_path = os.path.join( + activation_config.activation_dir, + activation_config.model_name, + f"{os.path.basename(activation_config.dataset_path)}_{activation_config.dataset_split}", +) + +# --- Determine WandB Run Name (using config values) --- +_lr = 1e-4 +_batch_size = 1024 +_k_int = clt_config.batchtopk_k + +wdb_run_name = ( + f"gpt2-{clt_config.num_features}-width-" f"batchtopk-k{_k_int}-" f"{_batch_size}-batch-" f"{_lr:.1e}-lr-fp16" +) +print("\nGenerated WandB run name: " + wdb_run_name) + +training_config = TrainingConfig( + # Training loop parameters + learning_rate=_lr, + training_steps=1000, # Reduced steps for tutorial + seed=42, + # Activation source + activation_source="local_manifest", + activation_path=expected_activation_path, + activation_dtype="float16", # Load activations in FP16 + # Training batch size + train_batch_size_tokens=_batch_size, + sampling_strategy="sequential", + precision="fp16", # Enable mixed-precision training + # Normalization + normalization_method="mean_std", # Use pre-calculated stats + # Loss function coefficients + sparsity_lambda=0.0, # Disable standard sparsity penalty + sparsity_lambda_schedule="linear", + sparsity_c=0.0, # Disable standard sparsity penalty + preactivation_coef=0, # Disable preactivation loss (AuxK handles dead latents) + aux_loss_factor=1 / 32, # Enable AuxK loss with typical factor from paper + apply_sparsity_penalty_to_batchtopk=False, # Ensure standard sparsity penalty is off for BatchTopK + # Optimizer & Scheduler + optimizer="adamw", + lr_scheduler="linear_final20", + optimizer_beta2=0.98, + # Logging & Checkpointing + log_interval=10, + eval_interval=50, + diag_every_n_eval_steps=1, # run diagnostics every eval + max_features_for_diag_hist=1000, # optional cap per layer + checkpoint_interval=500, + dead_feature_window=200, + # WandB (Optional) + enable_wandb=True, + wandb_project="clt-debug-gpt2", + 
wandb_run_name=wdb_run_name, +) +print("\nTraining Configuration (BatchTopK, FP16):") +print(training_config) + + +# %% [markdown] +# ## 3. Generate Activations (One-Time Step) +# +# Generate the activation dataset for GPT-2 in FP16, including the manifest file. + +# %% +print("Step 1: Generating/Verifying Activations (including manifest)...") + +metadata_path = os.path.join(expected_activation_path, "metadata.json") +manifest_path = os.path.join(expected_activation_path, "index.bin") + +if os.path.exists(metadata_path) and os.path.exists(manifest_path): + print(f"Activations and manifest already found at: {expected_activation_path}") + print("Skipping generation. Delete the directory to regenerate.") +else: + print(f"Activations or manifest not found. Generating them now at: {expected_activation_path}") + try: + generator = ActivationGenerator( + cfg=activation_config, + device=device, + ) + generation_start_time = time.time() + generator.generate_and_save() + generation_end_time = time.time() + print(f"Activation generation complete in {generation_end_time - generation_start_time:.2f}s.") + except Exception as gen_err: + print(f"[ERROR] Activation generation failed: {gen_err}") + traceback.print_exc() + raise + +# %% [markdown] +# ## 4. Training the CLT with BatchTopK Activation and FP16 +# +# Instantiate the `CLTTrainer` for FP16 training. 
+ +# %% +print("Initializing CLTTrainer for training with BatchTopK and FP16...") + +log_dir = f"clt_training_logs/clt_gpt2_batchtopk_fp16_train_{int(time.time())}" +os.makedirs(log_dir, exist_ok=True) +print(f"Logs and checkpoints will be saved to: {log_dir}") + +try: + print("Creating CLTTrainer instance...") + print(f"- Using device: {device}") + print(f"- CLT config (BatchTopK): {vars(clt_config)}") + print(f"- Activation Source: {training_config.activation_source}") + print(f"- Reading activations from: {training_config.activation_path}") + print(f"- Training precision: {training_config.precision}") + + trainer = CLTTrainer( + clt_config=clt_config, + training_config=training_config, + log_dir=log_dir, + device=device, + distributed=False, + ) + print("CLTTrainer instance created successfully.") +except Exception as e: + print(f"[ERROR] Failed to initialize CLTTrainer: {e}") + traceback.print_exc() + raise + +# Start training +print("Beginning training using BatchTopK activation and FP16...") +print(f"Training for {training_config.training_steps} steps.") +print(f"Normalization method set to: {training_config.normalization_method}") +print( + f"Standard sparsity penalty applied to BatchTopK activations: {training_config.apply_sparsity_penalty_to_batchtopk}" +) + +try: + start_train_time = time.time() + trained_clt_model = trainer.train(eval_every=training_config.eval_interval) + end_train_time = time.time() + print(f"Training finished in {end_train_time - start_train_time:.2f} seconds.") +except Exception as train_err: + print(f"[ERROR] Training failed: {train_err}") + traceback.print_exc() + raise + +# %% [markdown] +# ## 5. Saving and Loading the Final Trained Model +# +# The `CLTTrainer` automatically saves the final model and its configuration (cfg.json) +# in the `log_dir/final/` directory. If the training started with BatchTopK, +# the trainer converts the model to JumpReLU before this final save. 
+# Here, we'll also demonstrate a manual save of the model state and its config as Python dict, +# and then load it back. This manually saved model will be the one returned by trainer.train(), +# so it will also be JumpReLU if conversion occurred. + +# %% +# The trained_clt_model is what trainer.train() returned. +# If clt_config.activation_fn was 'batchtopk', trainer.train() converts it to JumpReLU in-place. +final_model_state_path = os.path.join(log_dir, "clt_final_manual_state.pt") +final_model_config_path = os.path.join(log_dir, "clt_final_manual_config.json") + +print(f"\nManually saving final model state to: {final_model_state_path}") +print(f"Manually saving final model config to: {final_model_config_path}") + +torch.save(trained_clt_model.state_dict(), final_model_state_path) +with open(final_model_config_path, "w") as f: + # The config on trained_clt_model will reflect 'jumprelu' if conversion happened + json.dump(trained_clt_model.config.__dict__, f, indent=4) + +print(f"\nContents of log directory ({log_dir}):") +for item in os.listdir(log_dir): + print(f"- {item}") + +# --- Loading the manually saved model --- +print("\nLoading the manually saved model...") + +# 1. Load the saved configuration +with open(final_model_config_path, "r") as f: + loaded_config_dict_manual = json.load(f) +loaded_clt_config_manual = CLTConfig(**loaded_config_dict_manual) + +print(f"Loaded manual config, activation_fn: {loaded_clt_config_manual.activation_fn}") + +# 2. 
Instantiate model with this loaded config and load state dict +loaded_clt_model_manual = CrossLayerTranscoder( + config=loaded_clt_config_manual, + process_group=None, # Assuming non-distributed for this load + device=torch.device(device), +) +loaded_clt_model_manual.load_state_dict(torch.load(final_model_state_path, map_location=device)) +loaded_clt_model_manual.eval() # Set to evaluation mode + +print("Manually saved model loaded successfully.") +print(f"Loaded model is on device: {next(loaded_clt_model_manual.parameters()).device}") + + +# %% [markdown] +# ## 6. Loading from Distributed Checkpoint (DC) +# +# The trainer saves checkpoints in a distributed-compatible format (using `torch.distributed.checkpoint`) +# in `log_dir/step_/` and `log_dir/final/`. We can load the `final` one. +# This model will also be in JumpReLU format if the original training was BatchTopK. + +# %% +# Imports moved to top: +# from torch.distributed.checkpoint import load_state_dict as dist_load_state_dict +# from torch.distributed.checkpoint.filesystem import FileSystemReader + +# Path to the 'final' directory created by the trainer +# This contains the sharded checkpoint and the cfg.json (which reflects JumpReLU if converted) +dc_final_checkpoint_dir = os.path.join(log_dir, "final") + +print(f"\nLoading model from distributed checkpoint: {dc_final_checkpoint_dir}") + +# 1. Load the config from cfg.json in that directory +dc_config_path = os.path.join(dc_final_checkpoint_dir, "cfg.json") +if not os.path.exists(dc_config_path): + print(f"ERROR: cfg.json not found in {dc_final_checkpoint_dir}. Cannot load distributed checkpoint correctly.") +else: + with open(dc_config_path, "r") as f: + loaded_config_dict_dc = json.load(f) + loaded_clt_config_dc = CLTConfig(**loaded_config_dict_dc) + print(f"Loaded DC config, activation_fn: {loaded_clt_config_dc.activation_fn}") + + # 2. 
Instantiate the model with this config + # Determine device (mps not directly supported by some distributed ops, fallback to cpu if necessary for loading) + device_to_load_on = device if device != "mps" else "cpu" + print(f"Instantiating model on device: {device_to_load_on} for DC load") + + model_for_dc_load = CrossLayerTranscoder( + config=loaded_clt_config_dc, + process_group=None, # For non-distributed loading of a dist checkpoint + device=torch.device(device_to_load_on), + ) + model_for_dc_load.eval() + + # 3. Create an empty state dict and load into it + state_dict_to_populate_dc = model_for_dc_load.state_dict() + + try: + dist_load_state_dict( + state_dict=state_dict_to_populate_dc, + storage_reader=FileSystemReader(dc_final_checkpoint_dir), + no_dist=True, # Important for loading a sharded checkpoint into a non-distributed model + ) + # Load the populated state dict into the model + model_for_dc_load.load_state_dict(state_dict_to_populate_dc) + print("Model loaded successfully from distributed checkpoint.") + print(f"Model is on device: {next(model_for_dc_load.parameters()).device}") + except Exception as e_dc: + print(f"ERROR loading distributed checkpoint: {e_dc}") + traceback.print_exc() + +# %% [markdown] +# ## 7. Post-hoc Conversion Sweep (θ scaling) from a BatchTopK Checkpoint +# +# To experiment with different θ scaling factors for BatchTopK-to-JumpReLU conversion, +# we need a model checkpoint that was saved *before* any automatic conversion by the trainer. +# The trainer saves checkpoints periodically (e.g., `clt_checkpoint_500.pt`). +# We'll load one of these, assuming it's still in BatchTopK format. + +# %% + +# Path to a BatchTopK checkpoint (e.g., one saved mid-training) +# Ensure this checkpoint was saved when the model was still BatchTopK. +# The trainer converts to JumpReLU only at the very end of training if the original was BatchTopK. +# So, a checkpoint from step 500 should be BatchTopK. 
+
# Note: This part uses the log_dir defined in Section 4. If you are running this
# section independently, you'll need to set log_dir to a valid path.
batchtopk_checkpoint_path = os.path.join(log_dir, "clt_checkpoint_500.pt")

if not os.path.exists(batchtopk_checkpoint_path):
    print(f"WARNING: BatchTopK checkpoint {batchtopk_checkpoint_path} not found. Skipping sweep.")
    print("Ensure your training ran for at least 500 steps and saved a checkpoint.")
else:
    print(f"\nLoading BatchTopK model from checkpoint: {batchtopk_checkpoint_path} for sweep...")

    # clt_config_for_batchtopk_load is now defined INSIDE the loop below

    # Load the BatchTopK model state (a mid-training checkpoint, i.e. saved
    # before the trainer's final BatchTopK -> JumpReLU conversion)
    batchtopk_model_state = torch.load(batchtopk_checkpoint_path, map_location=device)

    # This is the StateDict from the BatchTopK model
    # It will be used as the starting point for each conversion in the sweep.

    # NOTE(review): `Normal` is imported and `std_normal` defined at the top of
    # this script already; the re-import/redefinition below is redundant
    # (harmless — same value) and the "Moved to top" comment is stale.
    from torch.distributions.normal import Normal  # Moved to top

    std_normal = Normal(0, 1)

    # Define quick_l0_checks here
    def quick_l0_checks(
        model: CrossLayerTranscoder, sample_batch_inputs: Dict[int, torch.Tensor], num_tokens_for_l0_check: int = 100
    ) -> tuple[float, float]:
        """Return (avg_empirical_l0_layer0, expected_l0)
        using an average over random tokens from sample_batch_inputs for empirical L0.

        Args:
            model: the (converted) CLT to probe; set to eval mode here.
            sample_batch_inputs: mapping layer_idx -> activation tensor; only key 0
                is read. Accepts rank-3 (batch, seq, d_model) or rank-2
                (num_tokens, d_model) tensors; other ranks yield NaN empirical L0.
            num_tokens_for_l0_check: max number of tokens randomly sampled for the
                empirical L0 estimate.

        Returns:
            Tuple of (empirical L0 averaged over sampled layer-0 tokens, expected L0
            derived from `model.log_threshold` under a N(0, 1) pre-activation
            assumption). Either entry is NaN when it cannot be computed.
        """
        model.eval()
        avg_empirical_l0_layer0 = float("nan")
        std_normal_dist = torch.distributions.normal.Normal(0, 1)

        # Assume sample_batch_inputs[0] is valid if this function is called after store initialization
        layer0_inputs_all_tokens = sample_batch_inputs.get(0)  # Use .get() for safety, though we assume it exists

        if layer0_inputs_all_tokens is None or layer0_inputs_all_tokens.numel() == 0:
            print("Warning: quick_l0_checks received no valid input for layer 0. 
Empirical L0 will be NaN.")
        else:
            # Match the model's device/dtype before encoding.
            layer0_inputs_all_tokens = layer0_inputs_all_tokens.to(device=model.device, dtype=model.dtype)
            if layer0_inputs_all_tokens.dim() == 3:  # B, S, D
                num_tokens_in_batch = layer0_inputs_all_tokens.shape[0] * layer0_inputs_all_tokens.shape[1]
                layer0_inputs_flat = layer0_inputs_all_tokens.reshape(num_tokens_in_batch, model.config.d_model)
            elif layer0_inputs_all_tokens.dim() == 2:  # Already [num_tokens, d_model]
                num_tokens_in_batch = layer0_inputs_all_tokens.shape[0]
                layer0_inputs_flat = layer0_inputs_all_tokens
            else:
                print(
                    f"Warning: quick_l0_checks received unexpected input shape {layer0_inputs_all_tokens.shape} for layer 0. Empirical L0 will be NaN."
                )
                layer0_inputs_flat = None

            if layer0_inputs_flat is not None and num_tokens_in_batch > 0:
                # Sample up to num_tokens_for_l0_check tokens without replacement.
                num_to_sample = min(num_tokens_for_l0_check, num_tokens_in_batch)
                indices = torch.randperm(num_tokens_in_batch, device=model.device)[:num_to_sample]
                selected_tokens_for_l0 = layer0_inputs_flat[indices]
                if selected_tokens_for_l0.numel() > 0:
                    acts_layer0_selected = model.encode(selected_tokens_for_l0, layer_idx=0)
                    # A feature "fires" when its activation exceeds a small tolerance.
                    l0_per_token_selected = (acts_layer0_selected > 1e-6).sum(dim=1).float()
                    avg_empirical_l0_layer0 = l0_per_token_selected.mean().item()
                else:
                    print(
                        "Warning: No tokens selected for empirical L0 check after sampling. Empirical L0 will be NaN."
                    )
            # Removed redundant checks for layer0_inputs_flat being None or num_tokens_in_batch == 0, covered by outer if/else

        expected_l0 = float("nan")
        if hasattr(model, "log_threshold") and model.log_threshold is not None:
            # theta = exp(log_threshold); expected L0 sums P(N(0,1) > theta) over
            # all features, assuming standard-normal pre-activations.
            theta = model.log_threshold.exp().cpu()
            p_fire = 1.0 - std_normal_dist.cdf(theta.float())
            expected_l0 = p_fire.sum().item()
        else:
            print("Warning: Model does not have log_threshold. 
Cannot compute expected_l0.")
        return avg_empirical_l0_layer0, expected_l0

    # Initialize LocalActivationStore for the sweep, assuming training_config is available from earlier cells
    print("Initializing LocalActivationStore for theta estimation sweep...")
    posthoc_activation_store: Optional[BaseActivationStore] = None
    try:
        from clt.training.data.local_activation_store import LocalActivationStore  # Ensure import

        if training_config.activation_path is None:  # This check is still good practice
            raise ValueError("training_config.activation_path is None. Cannot initialize activation store for sweep.")

        posthoc_activation_store = LocalActivationStore(
            dataset_path=training_config.activation_path,
            train_batch_size_tokens=1024,  # Can use a reasonable batch size for estimation
            device=torch.device(device),
            dtype=training_config.activation_dtype,
            rank=0,
            world=1,
            seed=42,
            sampling_strategy="sequential",
            normalization_method="auto",
        )
        print(f"Successfully initialized LocalActivationStore from: {training_config.activation_path}")
    except NameError:  # Handles case where training_config might not be defined if cells are run out of order
        print("Error: 'training_config' not defined. 
Please ensure previous cells initializing it have been run.") + print("Skipping post-hoc theta scaling sweep.") + except Exception as e_store_init: + print(f"Error initializing LocalActivationStore for post-hoc sweep: {e_store_init}") + print("Skipping post-hoc theta scaling sweep.") + + if posthoc_activation_store: + scale_factors = [1.0] + n_batches_for_theta_estimation = 1 # Number of batches to use for theta estimation + + print("\n=== θ-scaling sweep (from BatchTopK checkpoint) using estimate_theta_posthoc ===") + print(f"Using {n_batches_for_theta_estimation} batches for theta estimation in each iteration.") + + # Import tqdm for the progress bar + from tqdm.auto import tqdm + + for sf in tqdm(scale_factors, desc="Scaling Factor Sweep"): + # Define clt_config_for_batchtopk_load INSIDE the loop + # to ensure a fresh BatchTopK config for each iteration. + clt_config_for_sweep = CLTConfig( + num_features=clt_num_features, + num_layers=num_layers, + d_model=d_model, + activation_fn="batchtopk", # Start with BatchTopK config + batchtopk_k=batchtopk_k, # Specify k directly + batchtopk_straight_through=True, + clt_dtype="float32", # Match model dtype for consistency during load + ) + + tmp_model_for_sweep = CrossLayerTranscoder( + config=clt_config_for_sweep, + process_group=None, + device=torch.device(device), + ) + # Load the original BatchTopK state dict + tmp_model_for_sweep.load_state_dict(batchtopk_model_state) + tmp_model_for_sweep.eval() + + print(f"Estimating theta and converting with scale_factor = {sf:.2f}...") + try: + # Ensure the data iterator is reset or re-created if it's a one-shot iterator + # For this tutorial, assuming posthoc_activation_store can be iterated multiple times + # or we re-initialize it if it's a generator type that gets exhausted. 
+ data_iterator_for_estimation = iter(posthoc_activation_store) + + estimated_thetas = tmp_model_for_sweep.estimate_theta_posthoc( + data_iter=data_iterator_for_estimation, + num_batches=n_batches_for_theta_estimation, + scale_factor=sf, + default_theta_value=1e6, # Default from convert_to_jumprelu_inplace + ) + # estimate_theta_posthoc now calls convert_to_jumprelu_inplace internally + print(f"Estimated theta shape: {estimated_thetas.shape}") + # Now tmp_model_for_sweep is a JumpReLU model + + # Get a sample batch for quick_l0_checks + # We need to be careful if data_iterator_for_estimation was exhausted + # For simplicity, let's try to get one more batch or re-initialize iterator for this check + sample_batch_for_l0_check_inputs: Dict[int, torch.Tensor] = {} + try: + sample_inputs_l0, _ = next(data_iterator_for_estimation) # Try to get next from current iterator + sample_batch_for_l0_check_inputs = sample_inputs_l0 + except StopIteration: + print("Warning: data_iterator_for_estimation exhausted. Re-initializing for L0 check.") + try: + reinitialized_iterator = iter(posthoc_activation_store) + sample_inputs_l0, _ = next(reinitialized_iterator) + sample_batch_for_l0_check_inputs = sample_inputs_l0 + except Exception as e_reinit_fetch: + print(f"Error re-fetching batch for L0 check: {e_reinit_fetch}. L0 check might use zeros.") + except Exception as e_fetch_l0_batch: + print(f"Error fetching batch for L0 check: {e_fetch_l0_batch}. 
L0 check might use zeros.") + + d_l0, exp_l0 = quick_l0_checks(tmp_model_for_sweep, sample_batch_for_l0_check_inputs) + print( + f"scale {sf:4.2f} | dummy-L0 {d_l0:6.0f} | expected-L0 {exp_l0:7.1f} (num_features={tmp_model_for_sweep.config.num_features}, num_layers={tmp_model_for_sweep.config.num_layers})" + ) + except Exception as e_sweep_iter: + print(f"ERROR during sweep iteration for scale_factor={sf:.2f}: {e_sweep_iter}") + traceback.print_exc() + continue # Continue to next scale factor + else: + print("Skipping post-hoc theta scaling sweep as activation store could not be initialized.") + +# %% [markdown] +# ## 8. Next Steps +# +# This tutorial showed how to train a CLT for GPT-2 using BatchTopK activation and FP16, +# save/load models, and perform a post-hoc analysis. + +# %% +print("\nGPT-2 FP16 BatchTopK Tutorial Complete!") +print(f"Logs and checkpoints are saved in: {log_dir}")