Merged
46 changes: 46 additions & 0 deletions README.md
@@ -11,6 +11,10 @@ This library is intended for the training and analysis of cross-layer sparse cod

A Cross-Layer Transcoder (CLT) is a multi-layer dictionary learning model designed to extract sparse, interpretable features from transformers, using an encoder for each layer and a decoder for each (source layer, destination layer) pair (e.g., 12 encoders and 78 decoders for `gpt2-small`). This implementation focuses on the core functionality needed to train and use CLTs, leveraging `nnsight` for model introspection and `datasets` for data handling.

The library now supports **tied decoders**, which can significantly reduce the number of parameters by sharing decoder weights across layers. Instead of training separate decoders for each (source, destination) pair, tied decoders use either:
- **Per-source tying**: One decoder per source layer, shared across all destination layers
- **Per-target tying**: One decoder per destination layer, shared across all source layers
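As a quick sanity check on the decoder counts quoted above (78 untied decoders for a 12-layer model), the combinatorics work out as follows:

```python
# Decoder counts for an L-layer model: untied CLTs need one decoder per
# (source, destination) pair with source <= destination; tied variants need
# only one decoder per layer.
num_layers = 12  # gpt2-small
untied = num_layers * (num_layers + 1) // 2  # 12 + 11 + ... + 1 = 78
per_source = num_layers                      # shared across destinations
per_target = num_layers                      # shared across sources
print(untied, per_source, per_target)        # 78 12 12
```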

Training a CLT involves the following steps:
1. Pre-generate activations with `scripts/generate_activations` (though an implementation of `StreamingActivationStore` is on the way).
2. Train a CLT (start with an expansion factor of at least `32`) on this data. Metrics can be logged to WandB. NMSE should fall below `0.25`, and ideally below `0.10`. As mentioned above, I recommend `BatchTopK` training and suggest keeping `K` low; `200` is a good starting point.
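NMSE here refers to normalized mean squared error. A minimal sketch of one common convention (error energy divided by the target's centered energy; the exact normalization the library uses may differ):

```python
import torch

def nmse(recon: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Normalized MSE: reconstruction error relative to the target's variance.

    A value of 0 means perfect reconstruction; a value of 1 means the
    reconstruction is no better than predicting the target's mean.
    """
    return ((recon - target) ** 2).sum() / ((target - target.mean()) ** 2).sum()
```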
@@ -85,6 +89,16 @@ Key configuration parameters are mapped to config classes via script arguments:
- `relu`: Standard ReLU activation.
- `batchtopk`: Selects the global top `k` features across all tokens in a batch, based on pre-activation values; `k` can be an absolute count or a fraction. This is often used as a differentiable training-time approximation that can later be converted to `jumprelu`.
- `topk`: Selects top K features per token (row-wise top-k).
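The difference between the two top-k variants can be sketched as follows (illustrative code, not the library's implementation):

```python
import torch

def batch_topk(preacts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations across the whole batch; zero the rest."""
    flat = preacts.flatten()
    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[flat.topk(k).indices] = True
    return (flat * mask).view_as(preacts)

def per_token_topk(preacts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations in each row (token); zero the rest."""
    vals, idx = preacts.topk(k, dim=-1)
    return torch.zeros_like(preacts).scatter_(-1, idx, vals)

x = torch.tensor([[3.0, 1.0, 2.0], [0.5, 4.0, 0.1]])
print(batch_topk(x, 2))      # keeps only the two largest values in the batch
print(per_token_topk(x, 1))  # keeps the largest value in each row
```

Note that `batchtopk` lets some tokens keep more active features than others, while `topk` enforces exactly `k` per token.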

**Decoder Tying Options** (`--decoder-tying`):
- `none` (default): Traditional untied decoders, i.e. a separate decoder for each (source, destination) layer pair.
- `per_source`: One decoder per source layer, shared across all destination layers.
- `per_target`: One decoder per destination layer, combining features from all source layers.

**Additional Tied Decoder Features**:
- `--enable-feature-offset`: Add learnable per-feature bias terms
- `--enable-feature-scale`: Add learnable per-feature scaling
- `--skip-connection`: Enable skip connections from source inputs to decoder outputs
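A rough sketch of how these optional terms could enter a decoder's output (illustrative only; names such as `W_skip` are assumptions for this sketch, not the library's actual attributes):

```python
import torch

d_model, num_features = 8, 32
W_dec = torch.randn(num_features, d_model) * 0.02  # shared decoder weight
feature_scale = torch.ones(num_features)           # --enable-feature-scale
feature_offset = torch.zeros(num_features)         # --enable-feature-offset
W_skip = torch.eye(d_model)                        # --skip-connection

def decode(acts: torch.Tensor, source_input: torch.Tensor) -> torch.Tensor:
    # Per-feature affine transform of the sparse activations, then a linear
    # decode, plus a skip path carrying the source input straight to the output.
    scaled = acts * feature_scale + feature_offset
    return scaled @ W_dec + source_input @ W_skip

out = decode(torch.relu(torch.randn(4, num_features)), torch.randn(4, d_model))
print(out.shape)  # torch.Size([4, 8])
```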
- **TrainingConfig**: `--learning-rate`, `--training-steps`, `--train-batch-size-tokens`, `--activation-source`, `--activation-path` (for `local_manifest`), remote config fields (for `remote`, e.g. `--server-url`, `--dataset-id`), `--normalization-method`, `--sparsity-lambda`, `--preactivation-coef`, `--optimizer`, `--lr-scheduler`, `--log-interval`, `--eval-interval`, `--checkpoint-interval`, `--dead-feature-window`, WandB settings (`--enable-wandb`, `--wandb-project`, etc.).

### Single GPU Training Examples
@@ -139,6 +153,38 @@ python scripts/train_clt.py \
# Add other arguments as needed
```

**Example: Training with Tied Decoders**

Tied decoders can significantly reduce the parameter count while maintaining performance. Here's an example using per-source tying:

```bash
python scripts/train_clt.py \
--activation-source local_manifest \
--activation-path ./tutorial_activations/gpt2/pile-uncopyrighted_train \
--output-dir ./clt_output_tied \
--model-name gpt2 \
--num-features 6144 \
--decoder-tying per_source \
--enable-feature-scale \
--skip-connection \
--activation-fn batchtopk \
--batchtopk-k 256 \
--learning-rate 3e-4 \
--training-steps 100000 \
--train-batch-size-tokens 8192 \
--sparsity-lambda 1e-3 \
--log-interval 100 \
--eval-interval 1000 \
--checkpoint-interval 5000 \
--enable-wandb --wandb-project clt_tied_training
```

This configuration:
- Uses `per_source` tying: 12 decoders instead of 78 for `gpt2-small`
- Enables feature scaling for better expressiveness
- Includes skip connections to preserve input information
- Uses BatchTopK with k=256 for training (can be converted to JumpReLU later)
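Assuming `gpt2` has `d_model = 768`, the decoder-weight savings for this configuration work out roughly as:

```python
# Back-of-envelope decoder-weight parameter count for the example above
# (12 layers, d_model = 768 for gpt2, 6144 features per layer).
num_layers, d_model, num_features = 12, 768, 6144
per_decoder = num_features * d_model
untied = (num_layers * (num_layers + 1) // 2) * per_decoder  # 78 decoders
tied_per_source = num_layers * per_decoder                   # 12 decoders
print(f"untied: {untied/1e6:.0f}M, per_source: {tied_per_source/1e6:.0f}M")
# untied: 368M, per_source: 57M
```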

### Multi-GPU Training (Tensor Parallelism)

This library supports feature-wise tensor parallelism built on PyTorch's distributed communication primitives (`torch.distributed`). This shards the model's parameters (encoders, decoders) across multiple GPUs, reducing per-GPU memory usage and potentially speeding up computation.
58 changes: 51 additions & 7 deletions clt/config/clt_config.py
@@ -15,7 +15,7 @@ class CLTConfig:
num_layers: int # Number of transformer layers
d_model: int # Dimension of model's hidden state
model_name: Optional[str] = None # Optional name for the underlying model
normalization_method: Literal["auto", "estimated_mean_std", "none"] = (
normalization_method: Literal["none", "mean_std", "sqrt_d_model"] = (
"none" # How activations were normalized during training
)
activation_fn: Literal["jumprelu", "relu", "batchtopk", "topk"] = "jumprelu"
@@ -27,20 +27,28 @@
topk_k: Optional[float] = None # Number or fraction of features to keep per token for TopK.
# If < 1, treated as fraction. If >= 1, treated as int count.
topk_straight_through: bool = True # Whether to use straight-through estimator for TopK.
# Top-K mode selection
topk_mode: Literal["global", "per_layer"] = "global" # How to apply top-k selection
clt_dtype: Optional[str] = None # Optional dtype for the CLT model itself (e.g., "float16")
expected_input_dtype: Optional[str] = None # Expected dtype of input activations
mlp_input_template: Optional[str] = None # Module path template for MLP input activations
mlp_output_template: Optional[str] = None # Module path template for MLP output activations
tl_input_template: Optional[str] = None # TransformerLens hook point pattern before MLP
tl_output_template: Optional[str] = None # TransformerLens hook point pattern after MLP
# context_size: Optional[int] = None

# Tied decoder configuration
decoder_tying: Literal["none", "per_source", "per_target"] = "none" # Decoder weight sharing strategy
enable_feature_offset: bool = False # Enable per-feature bias (feature_offset)
enable_feature_scale: bool = False # Enable per-feature scale (feature_scale)
skip_connection: bool = False # Enable skip connection from input to output

def __post_init__(self):
"""Validate configuration parameters."""
assert self.num_features > 0, "Number of features must be positive"
assert self.num_layers > 0, "Number of layers must be positive"
assert self.d_model > 0, "Model dimension must be positive"
valid_norm_methods = ["auto", "estimated_mean_std", "none"]
valid_norm_methods = ["none", "mean_std", "sqrt_d_model"]
assert (
self.normalization_method in valid_norm_methods
), f"Invalid normalization_method: {self.normalization_method}. Must be one of {valid_norm_methods}"
@@ -60,6 +68,12 @@ def __post_init__(self):
raise ValueError("topk_k must be specified for TopK activation function.")
if self.topk_k is not None and self.topk_k <= 0:
raise ValueError("topk_k must be positive if specified.")

# Validate decoder tying configuration
valid_decoder_tying = ["none", "per_source", "per_target"]
assert (
self.decoder_tying in valid_decoder_tying
), f"Invalid decoder_tying: {self.decoder_tying}. Must be one of {valid_decoder_tying}"

@classmethod
def from_json(cls: Type[C], json_path: str) -> C:
@@ -73,6 +87,30 @@ def from_json(cls: Type[C], json_path: str) -> C:
"""
with open(json_path, "r") as f:
config_dict = json.load(f)

# Handle backward compatibility for old configs
if "decoder_tying" not in config_dict:
config_dict["decoder_tying"] = "none" # Default to original behavior
if "enable_feature_offset" not in config_dict:
config_dict["enable_feature_offset"] = False
if "enable_feature_scale" not in config_dict:
config_dict["enable_feature_scale"] = False

# Handle backwards compatibility for old normalization methods
if "normalization_method" in config_dict:
old_method = config_dict["normalization_method"]
# Map old values to new ones
if old_method in ["auto", "estimated_mean_std"]:
config_dict["normalization_method"] = "mean_std"
elif old_method in ["auto_sqrt_d_model", "estimated_mean_std_sqrt_d_model"]:
config_dict["normalization_method"] = "sqrt_d_model"

# Handle old sqrt_d_model_normalize flag
if "sqrt_d_model_normalize" in config_dict:
sqrt_normalize = config_dict.pop("sqrt_d_model_normalize")
if sqrt_normalize:
config_dict["normalization_method"] = "sqrt_d_model"

return cls(**config_dict)

def to_json(self, json_path: str) -> None:
@@ -108,11 +146,11 @@ class TrainingConfig:
debug_anomaly: bool = False

# Normalization parameters
normalization_method: Literal["auto", "estimated_mean_std", "none"] = "auto"
# 'auto': Use pre-calculated from mapped store, or estimate for streaming store.
# 'estimated_mean_std': Always estimate for streaming store (ignored for mapped).
# 'none': Disable normalization.
normalization_estimation_batches: int = 50 # Batches for normalization estimation
normalization_method: Literal["none", "mean_std", "sqrt_d_model"] = "mean_std"
# 'none': No normalization.
# 'mean_std': Standard (x - mean) / std normalization using pre-calculated stats.
# 'sqrt_d_model': EleutherAI-style x * sqrt(d_model) normalization.
normalization_estimation_batches: int = 50 # Batches for normalization estimation (if needed)

# --- Activation Store Source --- #
activation_source: Literal["local_manifest", "remote"] = "local_manifest"
@@ -221,6 +259,12 @@ def __post_init__(self):
assert (
0.0 <= self.sparsity_lambda_delay_frac < 1.0
), "sparsity_lambda_delay_frac must be between 0.0 (inclusive) and 1.0 (exclusive)"

# Validate normalization method
valid_norm_methods = ["none", "mean_std", "sqrt_d_model"]
assert (
self.normalization_method in valid_norm_methods
), f"Invalid normalization_method: {self.normalization_method}. Must be one of {valid_norm_methods}"


@dataclass
104 changes: 101 additions & 3 deletions clt/models/clt.py
@@ -179,17 +179,24 @@ def encode(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor:
)
return torch.zeros((expected_batch_dim, self.config.num_features), device=self.device, dtype=self.dtype)

def decode(self, a: Dict[int, torch.Tensor], layer_idx: int) -> torch.Tensor:
return self.decoder_module.decode(a, layer_idx)

def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None) -> torch.Tensor:
return self.decoder_module.decode(a, layer_idx, source_inputs)

def forward(self, inputs: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]:
activations = self.get_feature_activations(inputs)

# Note: feature affine transformations are now applied in the decoder

reconstructions = {}
for layer_idx in range(self.config.num_layers):
relevant_activations = {k: v for k, v in activations.items() if k <= layer_idx and v.numel() > 0}
if layer_idx in inputs and relevant_activations:
reconstructions[layer_idx] = self.decode(relevant_activations, layer_idx)
# Pass source inputs for EleutherAI-style skip connections
source_inputs = {k: inputs[k] for k in range(layer_idx + 1) if k in inputs} if self.config.skip_connection else None
reconstruction = self.decode(relevant_activations, layer_idx, source_inputs)

reconstructions[layer_idx] = reconstruction
elif layer_idx in inputs:
batch_size = 0
input_tensor = inputs[layer_idx]
@@ -216,6 +223,17 @@ def get_feature_activations(self, inputs: Dict[int, torch.Tensor]) -> Dict[int,
processed_inputs[layer_idx] = x_orig.to(device=self.device, dtype=self.dtype)

if self.config.activation_fn == "batchtopk" or self.config.activation_fn == "topk":
# Check if we should use per-layer mode
if self.config.topk_mode == "per_layer":
# Use per-layer top-k by calling encode on each layer
activations = {}
for layer_idx in sorted(processed_inputs.keys()):
x_input = processed_inputs[layer_idx]
act = self.encode(x_input, layer_idx)
activations[layer_idx] = act
return activations

# Otherwise use global top-k
preactivations_dict, _ = self._encode_all_layers(processed_inputs)
if not preactivations_dict:
activations = {}
@@ -325,3 +343,83 @@ def log_threshold(self, new_param: Optional[torch.nn.Parameter]) -> None:
if not hasattr(self, "theta_manager") or self.theta_manager is None:
raise AttributeError("ThetaManager is not initialised; cannot set log_threshold.")
self.theta_manager.log_threshold = new_param

def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
"""Load state dict with backward compatibility for old checkpoints.

Handles:
1. Old untied decoder format -> new tied/untied format
2. Missing theta_bias/theta_scale parameters
3. Missing per_target_scale/per_target_bias parameters
"""
# Check if this is an old checkpoint by looking for decoder keys
old_format_decoder_keys = [k for k in state_dict.keys() if 'decoders.' in k and '->' in k]
is_old_checkpoint = len(old_format_decoder_keys) > 0

if is_old_checkpoint and self.config.decoder_tying == "per_source":
logger.warning(
"Loading old untied decoder checkpoint into tied decoder model. "
"This will use weights from the first target layer for each source layer."
)

# Convert old decoder weights to tied format
# For each source layer, use the weights from src->src decoder
new_state_dict = {}
for key, value in state_dict.items():
if 'decoders.' in key and '->' in key:
# Extract source and target layer indices
# Key format: "decoder_module.decoders.{src}->{tgt}.weight" or ".bias"
parts = key.split('.')
decoder_key_idx = parts.index('decoders') + 1
src_tgt = parts[decoder_key_idx].split('->')
src_layer = int(src_tgt[0])
tgt_layer = int(src_tgt[1])
param_type = parts[-1] # 'weight' or 'bias'

# Only use diagonal decoders (src->src) for tied architecture
if src_layer == tgt_layer:
new_key = '.'.join(parts[:decoder_key_idx] + [str(src_layer), param_type])
new_state_dict[new_key] = value
else:
new_state_dict[key] = value
state_dict = new_state_dict

# Handle feature affine parameters migration from encoder to decoder module
# (for backward compatibility with old checkpoints)
for i in range(self.config.num_layers):
old_offset_key = f"encoder_module.feature_offset.{i}"
new_offset_key = f"decoder_module.feature_offset.{i}"
if old_offset_key in state_dict and new_offset_key not in state_dict:
logger.info(f"Migrating {old_offset_key} to {new_offset_key}")
state_dict[new_offset_key] = state_dict.pop(old_offset_key)

old_scale_key = f"encoder_module.feature_scale.{i}"
new_scale_key = f"decoder_module.feature_scale.{i}"
if old_scale_key in state_dict and new_scale_key not in state_dict:
logger.info(f"Migrating {old_scale_key} to {new_scale_key}")
state_dict[new_scale_key] = state_dict.pop(old_scale_key)

# Handle missing feature affine parameters (now in decoder module)
if self.config.enable_feature_offset and hasattr(self.decoder_module, 'feature_offset') and self.decoder_module.feature_offset is not None:
for i in range(self.config.num_layers):
key = f"decoder_module.feature_offset.{i}"
if key not in state_dict:
logger.info(f"Initializing missing {key} to zeros")
# Don't add to state_dict to let it be initialized by the module

if self.config.enable_feature_scale and hasattr(self.decoder_module, 'feature_scale') and self.decoder_module.feature_scale is not None:
for i in range(self.config.num_layers):
key = f"decoder_module.feature_scale.{i}"
if key not in state_dict:
logger.info(f"Initializing missing {key} (first target layer to ones, rest to zeros)")
# Don't add to state_dict to let it be initialized by the module

# Handle missing skip weights
if self.config.skip_connection and hasattr(self.decoder_module, 'skip_weights'):
for i in range(self.config.num_layers):
key = f"decoder_module.skip_weights.{i}"
if key not in state_dict:
logger.info(f"Initializing missing {key} to identity matrix")

# Call parent's load_state_dict
return super().load_state_dict(state_dict, strict=strict)