From b53df861c7ac8e95f1560297f82108148c5a69bd Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 12 Feb 2025 12:45:28 +0000 Subject: [PATCH 01/73] Rename some topk_mask vars to mask --- pyproject.toml | 1 + spd/experiments/resid_mlp/models.py | 10 ++--- spd/experiments/resid_mlp/plotting.py | 38 +++++++++++-------- .../resid_mlp/resid_mlp_decomposition.py | 2 +- spd/experiments/resid_mlp/spd_interp.py | 4 +- spd/experiments/tms/models.py | 12 +++--- spd/hooks.py | 2 +- spd/models/components.py | 8 ++-- spd/plotting.py | 8 ++-- spd/run_spd.py | 38 ++++++++++--------- spd/utils.py | 20 +++++----- 11 files changed, 77 insertions(+), 66 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd92222..72b99ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ known-third-party = ["wandb"] [tool.pyright] include = ["spd", "tests"] +exclude = ["**/wandb/**"] strictListInference = true strictDictionaryInference = true diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index 5b93556..e3e54b5 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -93,18 +93,18 @@ def __init__( def forward( self, x: Float[Tensor, "batch ... d_model"], - topk_mask: Float[Tensor, "batch ... C"] | None = None, + mask: Float[Tensor, "batch ... C"] | None = None, ) -> tuple[Float[Tensor, "batch ... d_model"],]: """Run a forward pass and cache pre and post activations for each parameter. Note that we don't need to cache pre activations for the biases. We also don't care about the output bias which is always zero. """ - mid_pre_act_fn = self.mlp_in(x, topk_mask=topk_mask) + mid_pre_act_fn = self.mlp_in(x, mask=mask) if self.bias1 is not None: mid_pre_act_fn = mid_pre_act_fn + self.bias1 mid = self.act_fn(mid_pre_act_fn) - out = self.mlp_out(mid, topk_mask=topk_mask) + out = self.mlp_out(mid, mask=mask) if self.bias2 is not None: out = out + self.bias2 return out @@ -327,7 +327,7 @@ def __init__( def forward( self, x: Float[Tensor, "batch n_instances n_features"], - topk_mask: Float[Tensor, "batch n_instances C"] | None = None, + mask: Float[Tensor, "batch n_instances C"] | None = None, ) -> Float[Tensor, "batch n_instances d_embed"]: """ Returns: @@ -339,7 +339,7 @@ def forward( "batch n_instances n_features, n_instances n_features d_embed -> batch n_instances d_embed", ) for layer in self.layers: - residual = residual + layer(residual, topk_mask) + residual = residual + layer(residual, mask) out = einops.einsum( residual, self.W_U, diff --git a/spd/experiments/resid_mlp/plotting.py b/spd/experiments/resid_mlp/plotting.py index ae4103b..3f254cc 100644 --- a/spd/experiments/resid_mlp/plotting.py +++ b/spd/experiments/resid_mlp/plotting.py @@ -956,9 +956,9 @@ def collect_per_feature_losses( sample_topk_mask = calc_topk_mask(attribution_scores, topk=1, batch_topk=False) # Get the batch topk model output - spd_out_batch_topk = spd_model(batch, topk_mask=batch_topk_mask) + spd_out_batch_topk = spd_model(batch, mask=batch_topk_mask) # Get the sample topk model output - spd_out_sample_topk = spd_model(batch, topk_mask=sample_topk_mask) + spd_out_sample_topk = spd_model(batch, mask=sample_topk_mask) # Get rid of the n_instances dimension for simplicity batch: Float[Tensor, "batch n_features"] = batch.squeeze(1) @@ -1058,7 +1058,7 @@ def collect_average_components_per_feature( assert batch.shape[1] == 1 # Get which components were active for each feature - topk_mask_raw: Float[Tensor, "batch n_instances C"] = model_fn(batch).topk_mask + topk_mask_raw: Float[Tensor, "batch n_instances C"] = model_fn(batch).mask batch: Float[Tensor, "batch n_features"] = batch.squeeze(1) topk_mask: Float[Tensor, "batch C"] = topk_mask_raw.squeeze(1) @@ -1284,7 +1284,9 @@ def plot_resid_vs_mlp_out( # Get the SPD resid contribution by running with no subnetworks. This should be equivalent # to W_E W_U and but doesn't require access to the ResidMLP SPD model. topk_mask = torch.zeros_like(batch) - spd_WEU = topk_model_fn(batch, topk_mask).spd_topk_model_output[batch_idx, instance_idx, :] + spd_WEU = topk_model_fn(batch, topk_mask).spd_model_masked_output[ + batch_idx, instance_idx, : + ] spd_WEU = spd_WEU.detach().cpu() if tied_weights: assert torch.allclose(spd_WEU, W_EU), "Tied weights but W_EU != SPD resid contribution" @@ -1302,7 +1304,9 @@ def plot_resid_vs_mlp_out( else: topk_mask = torch.zeros_like(batch) topk_mask[:, :, subnet_indices] = 1 - topk_out = topk_model_fn(batch, topk_mask).spd_topk_model_output[batch_idx, instance_idx, :] + topk_out = topk_model_fn(batch, topk_mask).spd_model_masked_output[ + batch_idx, instance_idx, : + ] topk_mlp_out = topk_out.detach().cpu() - spd_WEU topk_mlp_out_mse = F.mse_loss(topk_mlp_out, mlp_out).item() corr = np.corrcoef(topk_mlp_out[mask], W_EU[mask])[0, 1] @@ -1315,7 +1319,9 @@ def plot_resid_vs_mlp_out( ) # Full forward pass topk_mask = torch.ones_like(batch) - full_out = topk_model_fn(batch, topk_mask).spd_topk_model_output[batch_idx, instance_idx, :] + full_out = topk_model_fn(batch, topk_mask).spd_model_masked_output[ + batch_idx, instance_idx, : + ] full_mlp_out = full_out.detach().cpu() - spd_WEU full_mlp_out_mse = F.mse_loss(full_mlp_out, mlp_out).item() corr = np.corrcoef(full_mlp_out[mask], W_EU[mask])[0, 1] @@ -1500,10 +1506,10 @@ def get_scrubbed_losses( topk = config.topk batch_topk = config.batch_topk - out_spd = spd_model_fn(batch, topk, batch_topk).spd_topk_model_output - out_random = top1_model_fn(batch, random_topk_mask).spd_topk_model_output - out_scrubbed = top1_model_fn(batch, scrubbed_topk_mask).spd_topk_model_output - out_antiscrubbed = top1_model_fn(batch, antiscrubbed_topk_mask).spd_topk_model_output + out_spd = spd_model_fn(batch, topk, batch_topk).spd_model_masked_output + out_random = top1_model_fn(batch, random_topk_mask).spd_model_masked_output + out_scrubbed = top1_model_fn(batch, scrubbed_topk_mask).spd_model_masked_output + out_antiscrubbed = top1_model_fn(batch, antiscrubbed_topk_mask).spd_model_masked_output out_target = target_model(batch) # Monosemantic baseline out_monosemantic = batch.clone() @@ -1663,13 +1669,15 @@ def plot_feature_response_with_subnets( ) zeros_topk_mask = torch.zeros(batch_size, n_instances, C, device=device) ones_topk_mask = torch.ones(batch_size, n_instances, C, device=device) - out_WE_WU_only = topk_model_fn(batch, zeros_topk_mask).spd_topk_model_output[:, instance_idx, :] + out_WE_WU_only = topk_model_fn(batch, zeros_topk_mask).spd_model_masked_output[ + :, instance_idx, : + ] out_red = topk_model_fn(batch, topk_mask_red) out_blue = topk_model_fn(batch, topk_mask_blue) - out_spd = topk_model_fn(batch, ones_topk_mask).spd_topk_model_output[:, instance_idx, :] - mlp_out_blue_spd = out_blue.spd_topk_model_output[:, instance_idx, :] - out_WE_WU_only - mlp_out_red_spd = out_red.spd_topk_model_output[:, instance_idx, :] - out_WE_WU_only + out_spd = topk_model_fn(batch, ones_topk_mask).spd_model_masked_output[:, instance_idx, :] + mlp_out_blue_spd = out_blue.spd_model_masked_output[:, instance_idx, :] - out_WE_WU_only + mlp_out_red_spd = out_red.spd_model_masked_output[:, instance_idx, :] - out_WE_WU_only mlp_out_target = out_blue.target_model_output[:, instance_idx, :] - out_WE_WU_only mlp_out_spd = out_spd - out_WE_WU_only @@ -1767,7 +1775,7 @@ def get_feature_subnet_map( batch = torch.zeros(batch_size, n_instances, n_features, device=device) batch[torch.arange(n_features), instance_idx, torch.arange(n_features)] = 1 top1_out = top1_model_fn(batch, None) - top1_mask = top1_out.topk_mask[:, instance_idx, :] + top1_mask = top1_out.mask[:, instance_idx, :] subnet_indices = { int(feature_idx.item()): int(subnet_idx.item()) for feature_idx, subnet_idx in top1_mask.nonzero() diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 27bce0e..4f2c7cc 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -317,7 +317,7 @@ def spd_model_fn( batch_topk=batch_topk, topk=topk, distil_from_target=config.distil_from_target, - ).spd_topk_model_output + ).spd_model_masked_output def target_model_fn(batch: Float[Tensor, "batch n_instances"]): return target_model(batch) diff --git a/spd/experiments/resid_mlp/spd_interp.py b/spd/experiments/resid_mlp/spd_interp.py index 38707dd..b86053a 100644 --- a/spd/experiments/resid_mlp/spd_interp.py +++ b/spd/experiments/resid_mlp/spd_interp.py @@ -117,7 +117,7 @@ def top1_model_fn( batch_topk=False, topk=1, distil_from_target=config.distil_from_target, - topk_mask=topk_mask, + mask=topk_mask, ) @@ -197,7 +197,7 @@ def top1_model_fn( # Get the loss of the spd model w.r.t the target model fn_without_batch_topk = lambda batch: spd_model_fn( batch, topk=1, batch_topk=False -).spd_topk_model_output # type: ignore +).spd_model_masked_output # type: ignore losses_spd_wrt_target = analyze_per_feature_performance( model_fn=fn_without_batch_topk, target_model_fn=target_model_fn, diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 334c923..defe011 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -45,18 +45,18 @@ def _tms_forward( linear1: Linear | LinearComponent, linear2: TransposedLinear | TransposedLinearComponent, b_final: Float[Tensor, "n_instances n_features"], - topk_mask: Float[Tensor, "batch n_instances C"] | None = None, + mask: Float[Tensor, "batch n_instances C"] | None = None, hidden_layers: nn.ModuleList | None = None, ) -> Float[Tensor, "batch n_instances n_features"]: """Forward pass used for TMSModel and TMSSPDModel. Note that topk_mask is only used for TMSSPDModel. """ - hidden = linear1(x, topk_mask=topk_mask) + hidden = linear1(x, mask=mask) if hidden_layers is not None: for layer in hidden_layers: - hidden = layer(hidden, topk_mask=topk_mask) - out_pre_relu = linear2(hidden, topk_mask=topk_mask) + b_final + hidden = layer(hidden, mask=mask) + out_pre_relu = linear2(hidden, mask=mask) + b_final out = F.relu(out_pre_relu) return out @@ -223,7 +223,7 @@ def __init__(self, config: TMSSPDModelConfig): def forward( self, x: Float[Tensor, "batch n_instances n_features"], - topk_mask: Float[Tensor, "batch n_instances C"] | None = None, + mask: Float[Tensor, "batch n_instances C"] | None = None, ) -> Float[Tensor, "batch n_instances n_features"]: return _tms_forward( x=x, @@ -231,7 +231,7 @@ def forward( linear2=self.linear2, b_final=self.b_final, hidden_layers=self.hidden_layers, - topk_mask=topk_mask, + mask=mask, ) @staticmethod diff --git a/spd/hooks.py b/spd/hooks.py index 552b87e..1ebe328 100644 --- a/spd/hooks.py +++ b/spd/hooks.py @@ -1,5 +1,5 @@ """ -Allow for running hooks on a model. Currently only forward hooks supported. +Allow for running hooks on a model. Much of this code is copied from https://github.com/TransformerLensOrg/TransformerLens """ diff --git a/spd/models/components.py b/spd/models/components.py index ce11dbe..9d1bad2 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -83,13 +83,13 @@ def weight(self) -> Float[Tensor, "... d_in d_out"]: def forward( self, x: Float[Tensor, "batch ... d_in"], - topk_mask: Float[Tensor, "batch ... C"] | None = None, + mask: Float[Tensor, "batch ... C"] | None = None, ) -> Float[Tensor, "batch ... d_out"]: """Forward pass through A and B matrices which make up the component for this layer. Args: x: Input tensor - topk_mask: Boolean tensor indicating which subnetworks to keep + mask: Tensor which masks parameter components. May be boolean or float. Returns: output: The summed output across all subnetworks """ @@ -97,11 +97,11 @@ def forward( # First multiply by A to get to intermediate dimension m inner_acts = einops.einsum(x, self.A, "batch ... d_in, ... C d_in m -> batch ... C m") - if topk_mask is not None: + if mask is not None: # We could apply the mask after component_acts, but we do it here so our matrices become # sparser and more efficient to compute with. inner_acts = einops.einsum( - inner_acts, topk_mask, "batch ... C m, batch ... C -> batch ... C m" + inner_acts, mask, "batch ... C m, batch ... C -> batch ... C m" ) # Then multiply by B to get to output dimension diff --git a/spd/plotting.py b/spd/plotting.py index 0289f3b..4c1e704 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -224,12 +224,12 @@ def collect_sparse_dataset_mse_losses( distil_from_target=distil_from_target, ) # Combine the batch and n_instances dimension for batch, labels, target_model_output, - # spd_outputs.spd_topk_model_output + # spd_outputs.spd_model_masked_output ein_str = "batch n_instances n_features -> (batch n_instances) n_features" batch = einops.rearrange(batch, ein_str) labels = einops.rearrange(labels, ein_str) target_model_output = einops.rearrange(target_model_output, ein_str) - spd_topk_model_output = einops.rearrange(spd_outputs.spd_topk_model_output, ein_str) + spd_model_masked_output = einops.rearrange(spd_outputs.spd_model_masked_output, ein_str) if gen_type == "at_least_zero_active": # Remove all entries where there are no active features @@ -237,10 +237,10 @@ def collect_sparse_dataset_mse_losses( batch = batch[mask] labels = labels[mask] target_model_output = target_model_output[mask] - spd_topk_model_output = spd_topk_model_output[mask] + spd_model_masked_output = spd_model_masked_output[mask] topk_recon_loss_labels = calc_recon_mse( - spd_topk_model_output, labels, has_instance_dim=False + spd_model_masked_output, labels, has_instance_dim=False ) recon_loss = calc_recon_mse(target_model_output, labels, has_instance_dim=False) baseline_batch = calc_recon_mse(batch, labels, has_instance_dim=False) diff --git a/spd/run_spd.py b/spd/run_spd.py index 971d66c..d6b6995 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -310,11 +310,11 @@ def optimize( ) ( - out_topk, + out_masked, schatten_loss, - topk_recon_loss, - topk_mask, - layer_acts_topk, + masked_recon_loss, + mask, + layer_acts_masked, ) = None, None, None, None, None if config.topk is not None: # We always assume the final subnetwork is the one we want to distil @@ -330,31 +330,33 @@ def optimize( ), "exact_topk only works if n_instances = 1" # Get the exact number of active features over the batch exact_topk = ((batch != 0).sum() / batch.shape[0]).item() - topk_mask = calc_topk_mask(topk_attrs, exact_topk, batch_topk=True) + mask = calc_topk_mask(topk_attrs, exact_topk, batch_topk=True) else: - topk_mask = calc_topk_mask(topk_attrs, config.topk, batch_topk=config.batch_topk) + mask = calc_topk_mask(topk_attrs, config.topk, batch_topk=config.batch_topk) if config.distil_from_target: # Add back the final subnetwork index to the topk mask and set it to True last_subnet_mask = torch.ones( - (*topk_mask.shape[:-1], 1), dtype=torch.bool, device=device + (*mask.shape[:-1], 1), dtype=mask.dtype, device=device ) - topk_mask = torch.cat((topk_mask, last_subnet_mask), dim=-1) + mask = torch.cat((mask, last_subnet_mask), dim=-1) # Do a forward pass with only the topk subnetworks - out_topk, topk_spd_cache = model.run_with_cache( - batch, names_filter=spd_cache_filter, topk_mask=topk_mask + out_masked, spd_cache_masked = model.run_with_cache( + batch, names_filter=spd_cache_filter, mask=mask ) - layer_acts_topk = {k: v for k, v in topk_spd_cache.items() if k.endswith("hook_post")} + layer_acts_masked = { + k: v for k, v in spd_cache_masked.items() if k.endswith("hook_post") + } if config.topk_recon_coeff is not None: - assert out_topk is not None - topk_recon_loss = calc_recon_mse(out_topk, target_out, has_instance_dim) + assert out_masked is not None + masked_recon_loss = calc_recon_mse(out_masked, target_out, has_instance_dim) act_recon_loss = None if config.act_recon_coeff is not None: act_recon_layer_acts = ( - layer_acts_topk - if layer_acts_topk is not None + layer_acts_masked + if layer_acts_masked is not None else {k: v for k, v in spd_cache.items() if k.endswith("hook_post")} ) target_post_weight_acts = post_weight_acts @@ -374,7 +376,7 @@ def optimize( if config.schatten_coeff is not None: # Use the sparsity loss as the mask in the lp case, and topk_mask otherwise - mask = topk_mask if topk_mask is not None else lp_sparsity_loss_per_k + mask = mask if mask is not None else lp_sparsity_loss_per_k assert mask is not None schatten_pnorm = config.schatten_pnorm if config.schatten_pnorm is not None else 1.0 schatten_loss = calc_schatten_loss( @@ -395,7 +397,7 @@ def optimize( "param_match_loss": (param_match_loss, config.param_match_coeff), "out_recon_loss": (out_recon_loss, config.out_recon_coeff), "lp_sparsity_loss": (lp_sparsity_loss, config.lp_sparsity_coeff), - "topk_recon_loss": (topk_recon_loss, config.topk_recon_coeff), + "masked_recon_loss": (masked_recon_loss, config.topk_recon_coeff), "act_recon_loss": (act_recon_loss, config.act_recon_coeff), "schatten_loss": (schatten_loss, config.schatten_coeff), } @@ -442,7 +444,7 @@ def optimize( out_dir=out_dir, device=device, config=config, - topk_mask=topk_mask, + topk_mask=mask, batch=batch, ) if config.wandb_project: diff --git a/spd/utils.py b/spd/utils.py index bb5b92b..c6f8eb2 100644 --- a/spd/utils.py +++ b/spd/utils.py @@ -148,7 +148,7 @@ class SPDOutputs(NamedTuple): spd_model_output: ( Float[Tensor, "batch d_model_out"] | Float[Tensor, "batch n_instances d_model_out"] ) - spd_topk_model_output: ( + spd_model_masked_output: ( Float[Tensor, "batch d_model_out"] | Float[Tensor, "batch n_instances d_model_out"] ) layer_acts: dict[str, Float[Tensor, "batch d_out"] | Float[Tensor, "batch n_instances d_out"]] @@ -156,7 +156,7 @@ class SPDOutputs(NamedTuple): str, Float[Tensor, "batch C d_out"] | Float[Tensor, "batch n_instances C d_out"] ] attribution_scores: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"] - topk_mask: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"] + mask: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"] def calc_topk_mask( @@ -200,7 +200,7 @@ def run_spd_forward_pass( batch_topk: bool, topk: float, distil_from_target: bool, - topk_mask: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"] | None = None, + mask: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"] | None = None, ) -> SPDOutputs: # Forward pass on target model target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) @@ -223,28 +223,28 @@ def run_spd_forward_pass( component_acts={k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts")}, ) - if topk_mask is None: + if mask is None: # We always assume the final subnetwork is the one we want to distil topk_attrs = attribution_scores[..., :-1] if distil_from_target else attribution_scores - topk_mask = calc_topk_mask(topk_attrs, topk, batch_topk=batch_topk) + mask = calc_topk_mask(topk_attrs, topk, batch_topk=batch_topk) if distil_from_target: # Add back the final subnetwork index to the topk mask and set it to True last_subnet_mask = torch.ones( - (*topk_mask.shape[:-1], 1), dtype=torch.bool, device=attribution_scores.device + (*mask.shape[:-1], 1), dtype=torch.bool, device=attribution_scores.device ) - topk_mask = torch.cat((topk_mask, last_subnet_mask), dim=-1) + mask = torch.cat((mask, last_subnet_mask), dim=-1) - topk_spd_out = spd_model(input_array, topk_mask=topk_mask) + spd_model_masked_output = spd_model(input_array, mask=mask) attribution_scores = attribution_scores.cpu().detach() return SPDOutputs( target_model_output=target_out, spd_model_output=out, - spd_topk_model_output=topk_spd_out, + spd_model_masked_output=spd_model_masked_output, layer_acts={k: v for k, v in spd_cache.items() if k.endswith("hook_post")}, component_acts={k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts")}, attribution_scores=attribution_scores, - topk_mask=topk_mask, + mask=mask, ) From 0f4e7f8941580e4c89e92e386881edbc4ce759f5 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 13 Feb 2025 11:38:15 +0000 Subject: [PATCH 02/73] Implement gating (untested) --- .vscode/launch.json | 20 +- README.md | 4 +- spd/attributions.py | 187 +- spd/configs.py | 76 +- spd/experiments/resid_mlp/model_interp.py | 84 - spd/experiments/resid_mlp/models.py | 30 +- spd/experiments/resid_mlp/plotting.py | 1509 +---------------- ...topk_config.yaml => resid_mlp_config.yaml} | 38 +- .../resid_mlp/resid_mlp_decomposition.py | 373 +--- .../resid_mlp/resid_mlp_sweep_config.yaml | 4 +- spd/experiments/resid_mlp/spd_interp.py | 296 ---- spd/experiments/tms/models.py | 29 +- spd/experiments/tms/spd_interp.py | 379 ----- .../{tms_topk_config.yaml => tms_config.yaml} | 28 +- spd/experiments/tms/tms_decomposition.py | 310 +--- spd/experiments/tms/tms_lp_config.yaml | 29 - spd/experiments/tms/tms_sweep_config.yaml | 4 +- spd/hooks.py | 2 +- spd/models/base.py | 27 - spd/models/components.py | 86 +- spd/plotting.py | 268 +-- spd/run_spd.py | 251 +-- spd/utils.py | 113 +- tests/test_attributions.py | 45 - tests/test_components.py | 46 - tests/test_resid_mlp.py | 28 +- tests/test_spd_losses.py | 18 +- tests/test_spd_model.py | 106 -- tests/test_tms.py | 87 +- tests/test_utils.py | 57 +- 30 files changed, 277 insertions(+), 4257 deletions(-) rename spd/experiments/resid_mlp/{resid_mlp_topk_config.yaml => resid_mlp_config.yaml} (81%) delete mode 100644 spd/experiments/resid_mlp/spd_interp.py delete mode 100644 spd/experiments/tms/spd_interp.py rename spd/experiments/tms/{tms_topk_config.yaml => tms_config.yaml} (65%) delete mode 100644 spd/experiments/tms/tms_lp_config.yaml delete mode 100644 tests/test_attributions.py delete mode 100644 tests/test_components.py delete mode 100644 tests/test_spd_model.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 60712d6..5ce7aef 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -14,11 +14,11 @@ }, }, { - "name": "tms_lp", + "name": "tms", "type": "debugpy", "request": "launch", "program": "${workspaceFolder}/spd/experiments/tms/tms_decomposition.py", - "args": "${workspaceFolder}/spd/experiments/tms/tms_lp_config.yaml", + "args": "${workspaceFolder}/spd/experiments/tms/tms_config.yaml", "console": "integratedTerminal", "justMyCode": true, "env": { @@ -26,23 +26,11 @@ } }, { - "name": "tms_topk", - "type": "debugpy", - "request": "launch", - "program": "${workspaceFolder}/spd/experiments/tms/tms_decomposition.py", - "args": "${workspaceFolder}/spd/experiments/tms/tms_topk_config.yaml", - "console": "integratedTerminal", - "justMyCode": true, - "env": { - "PYDEVD_DISABLE_FILE_VALIDATION": "1" - } - }, - { - "name": "resid_mlp_topk", + "name": "resid_mlp", "type": "debugpy", "request": "launch", "program": "${workspaceFolder}/spd/experiments/resid_mlp/resid_mlp_decomposition.py", - "args": "${workspaceFolder}/spd/experiments/resid_mlp/resid_mlp_topk_config.yaml", + "args": "${workspaceFolder}/spd/experiments/resid_mlp/resid_mlp_config.yaml", "console": "integratedTerminal", "justMyCode": true, "env": { diff --git a/README.md b/README.md index 4fca768..7f4d233 100644 --- a/README.md +++ b/README.md @@ -38,9 +38,9 @@ APD can be run by executing any of the `*_decomposition.py` scripts defined in t subdirectories. A config file is required for each experiment, which can be found in the same directory. For example: ```bash -python spd/experiments/tms/tms_decomposition.py spd/experiments/tms/tms_topk_config.yaml +python spd/experiments/tms/tms_decomposition.py spd/experiments/tms/tms_config.yaml ``` -will run SPD on TMS with the config file `tms_topk_config.yaml` (which is the main config file used +will run SPD on TMS with the config file `tms_config.yaml` (which is the main config file used for the TMS experiments in the paper). Wandb sweep files are also provided in the experiment subdirectories, and can be run with e.g.: diff --git a/spd/attributions.py b/spd/attributions.py index a9d24c1..7b43b7c 100644 --- a/spd/attributions.py +++ b/spd/attributions.py @@ -1,13 +1,10 @@ -"""Calculating and collecting attributions""" - -from typing import Literal +"""Calculations for how important each component is to the output.""" import einops import torch from jaxtyping import Float from torch import Tensor -from spd.configs import Config from spd.hooks import HookedRootModule from spd.models.base import SPDModel from spd.module_utils import collect_nested_module_attrs @@ -21,81 +18,72 @@ def calc_grad_attributions( post_weight_acts: dict[ str, Float[Tensor, "batch d_out"] | Float[Tensor, "batch n_instances d_out"] ], - component_weights: dict[ - str, Float[Tensor, "C d_in d_out"] | Float[Tensor, "n_instances C d_in d_out"] - ], - C: int, -) -> Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"]: + component_acts: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], + Bs: dict[str, Float[Tensor, "m d_out"] | Float[Tensor, "n_instances m d_out"]], +) -> dict[str, Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"]]: """Calculate the sum of the (squared) attributions from each output dimension. An attribution is the product of the gradient of the target model output w.r.t. the post acts - and the inner acts (i.e. the output of each subnetwork before being summed). - - Note that we don't use the component_acts collected from the SPD model, because this includes the - computational graph of the full model. We only want the subnetwork parameters of the current - layer to be in the computational graph. To do this, we multiply a detached version of the - pre_weight_acts by the subnet parameters. + and the component_acts. I.e. + sum_i[((pre_weight_acts @ A) @ (B @ d(out_i)/d(post_weight_acts))) ** 2] Note: This code may be run in between the training forward pass, and the loss.backward() and opt.step() calls; it must not mess with the training. The reason the current implementation is fine to run anywhere is that we just use autograd rather than backward which does not - populate the .grad attributes. Unrelatedly, we use retain_graph=True in a bunch of cases - where we want to later use the `out` variable in e.g. the loss function. + populate the .grad attributes. + + Unrelatedly, we use retain_graph=True in a bunch of cases where we want to later use the `out` + variable in e.g. the loss function. Args: target_out: The output of the target model. - pre_weight_acts: The activations at the output of each subnetwork before being summed. - post_weight_acts: The activations at the output of each layer after being summed. - component_weights: The component weight matrix at each layer. - C: The number of components. + pre_weight_acts: The activations of the target model before the weight matrix at each layer. + post_weight_acts: The activations at the target model after the weight matrix at each layer. + component_acts: The activations after multiplying by A at each layer. + Bs: The B matrix at each layer. + Returns: - The sum of the (squared) attributions from each output dimension. + A dictionary of the sum of the (squared) attributions from each output dimension for each + layer. """ # Ensure that all keys are the same after removing the hook suffixes - post_weight_act_names = [C.removesuffix(".hook_post") for C in post_weight_acts] - pre_weight_act_names = [C.removesuffix(".hook_pre") for C in pre_weight_acts] - component_weight_names = list(component_weights.keys()) - assert set(post_weight_act_names) == set(pre_weight_act_names) == set(component_weight_names) - - attr_shape = target_out.shape[:-1] + (C,) # (batch, C) or (batch, n_instances, C) - attribution_scores: Float[Tensor, "batch ... C"] = torch.zeros( - attr_shape, device=target_out.device, dtype=target_out.dtype - ) - - component_acts = {} - for param_name in pre_weight_act_names: - component_acts[param_name] = einops.einsum( - pre_weight_acts[param_name + ".hook_pre"].detach().clone(), - component_weights[param_name], - "... d_in, ... C d_in d_out -> ... C d_out", - ) - out_dim = target_out.shape[-1] - for feature_idx in range(out_dim): - feature_attributions: Float[Tensor, "batch ... C"] = torch.zeros( - attr_shape, device=target_out.device, dtype=target_out.dtype - ) - grad_post_weight_acts: tuple[Float[Tensor, "batch ... d_out"], ...] = torch.autograd.grad( + post_weight_act_names = [comp.removesuffix(".hook_post") for comp in post_weight_acts] + pre_weight_act_names = [comp.removesuffix(".hook_pre") for comp in pre_weight_acts] + component_act_names = [comp.removesuffix(".hook_component_acts") for comp in component_acts] + assert set(post_weight_act_names) == set(pre_weight_act_names) == set(component_act_names) + + m = next(iter(Bs.values())).shape[-2] + + attr_shape = target_out.shape[:-1] + (m,) # (batch, m) or (batch, n_instances, m) + attributions: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]] = { + param_name: torch.zeros(attr_shape, device=target_out.device, dtype=target_out.dtype) + for param_name in post_weight_act_names + } + + for feature_idx in range(target_out.shape[-1]): # Iterate over the output dimensions + grad_post_weight_acts: tuple[ + Float[Tensor, "batch d_out"] | Float[Tensor, "batch n_instances d_out"], ... + ] = torch.autograd.grad( target_out[..., feature_idx].sum(), list(post_weight_acts.values()), retain_graph=True ) for i, param_name in enumerate(post_weight_act_names): - feature_attributions += einops.einsum( - grad_post_weight_acts[i], - component_acts[param_name], - "... d_out ,... C d_out -> ... C", + # (B @ d(out)/d(post_weight_acts)) + grad_B = einops.einsum( + Bs[param_name], grad_post_weight_acts[i], "... m d_out, ... d_out -> ... m" ) + attributions[param_name] += ( + component_acts[param_name + ".hook_component_acts"] * grad_B + ) ** 2 - attribution_scores += feature_attributions**2 - - return attribution_scores + return attributions def collect_subnetwork_attributions( spd_model: SPDModel, - config: Config, target_model: HookedRootModule, device: str, n_instances: int | None = None, -) -> Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"]: +) -> dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]: """ Collect subnetwork attributions. @@ -122,96 +110,11 @@ def collect_subnetwork_attributions( test_batch, names_filter=target_cache_filter ) - attribution_scores = calculate_attributions( - model=spd_model, - config=config, - batch=test_batch, + attribution_scores = calc_grad_attributions( target_out=target_out, pre_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_pre")}, post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")}, + component_acts={k: v for k, v in target_cache.items() if k.endswith("hook_component_acts")}, + Bs=collect_nested_module_attrs(spd_model, attr_name="B", include_attr_name=False), ) return attribution_scores - - -@torch.inference_mode() -def calc_ablation_attributions( - spd_model: SPDModel, - batch: Float[Tensor, "batch n_features"] | Float[Tensor, "batch n_instances n_features"], - out: Float[Tensor, "batch d_model_out"] | Float[Tensor, "batch n_instances d_model_out"] | None, -) -> Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"]: - """Calculate the attributions by ablating each subnetwork one at a time.""" - assert out is not None, "out tensor is missing." - attr_shape = out.shape[:-1] + (spd_model.C,) # (batch, C) or (batch, n_instances, C) - has_instance_dim = len(out.shape) == 3 - attributions = torch.zeros(attr_shape, device=out.device, dtype=out.dtype) - for subnet_idx in range(spd_model.C): - stored_vals = spd_model.set_subnet_to_zero(subnet_idx, has_instance_dim) - ablation_out, _, _ = spd_model(batch) - out_recon = ((out - ablation_out) ** 2).mean(dim=-1) - attributions[..., subnet_idx] = out_recon - spd_model.restore_subnet(subnet_idx, stored_vals, has_instance_dim) - return attributions - - -def calc_activation_attributions( - component_acts: dict[ - str, Float[Tensor, "batch C d_out"] | Float[Tensor, "batch n_instances C d_out"] - ] - | None, -) -> Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"]: - """Calculate the attributions by taking the L2 norm of the activations in each subnetwork. - - Args: - component_acts: The activations at the output of each subnetwork before being summed. - Returns: - The attributions for each subnetwork. - """ - assert component_acts is not None, "Component_acts are missing" - first_param = component_acts[next(iter(component_acts.keys()))] - assert len(first_param.shape) in (3, 4) - - attribution_scores: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"] = ( - torch.zeros(first_param.shape[:-1], device=first_param.device, dtype=first_param.dtype) - ) - for param_matrix in component_acts.values(): - attribution_scores += param_matrix.pow(2).sum(dim=-1) - return attribution_scores - - -def calculate_attributions( - model: SPDModel, - config: Config, - batch: Float[Tensor, "batch n_features"] | Float[Tensor, "batch n_instances n_features"], - target_out: Float[Tensor, "batch n_features"] | Float[Tensor, "batch n_instances n_features"], - pre_weight_acts: dict[ - str, Float[Tensor, "batch d_in"] | Float[Tensor, "batch n_instances d_in"] - ], - post_weight_acts: dict[ - str, Float[Tensor, "batch d_out"] | Float[Tensor, "batch n_instances d_out"] - ], - component_acts: dict[str, Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"]] - | None = None, - out: Float[Tensor, "batch n_features"] - | Float[Tensor, "batch n_instances n_features"] - | None = None, -) -> Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"]: - attributions = None - attribution_type: Literal["ablation", "gradient", "activation"] = config.attribution_type - if attribution_type == "ablation": - attributions = calc_ablation_attributions(spd_model=model, batch=batch, out=out) - elif attribution_type == "gradient": - component_weights = collect_nested_module_attrs( - model, attr_name="component_weights", include_attr_name=False - ) - attributions = calc_grad_attributions( - target_out=target_out, - pre_weight_acts=pre_weight_acts, - post_weight_acts=post_weight_acts, - component_weights=component_weights, - C=model.C, - ) - elif attribution_type == "activation": - attributions = calc_activation_attributions(component_acts=component_acts) - else: - raise ValueError(f"Invalid attribution type: {attribution_type}") - return attributions diff --git a/spd/configs.py b/spd/configs.py index 795dac3..09b3602 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -45,57 +45,35 @@ class Config(BaseModel): wandb_run_name: str | None = None wandb_run_name_prefix: str = "" seed: int = 0 - topk: PositiveFloat | None = None - batch_topk: bool = True - exact_topk: bool = False batch_size: PositiveInt steps: PositiveInt print_freq: PositiveInt image_freq: PositiveInt | None = None image_on_first_step: bool = True - slow_images: bool = False save_freq: PositiveInt | None = None lr: PositiveFloat out_recon_coeff: NonNegativeFloat | None = None act_recon_coeff: NonNegativeFloat | None = None param_match_coeff: NonNegativeFloat | None = 1.0 - topk_recon_coeff: NonNegativeFloat | None = None - schatten_coeff: NonNegativeFloat | None = None - schatten_pnorm: NonNegativeFloat | None = None - lp_sparsity_coeff: NonNegativeFloat | None = None + masked_recon_coeff: NonNegativeFloat | None = None + lp_sparsity_coeff: NonNegativeFloat + pnorm: PositiveFloat post_relu_act_recon: bool = False - distil_from_target: bool = False - pnorm: PositiveFloat | None = None - C: PositiveInt - m: PositiveInt | None = None + m: PositiveInt lr_schedule: Literal["linear", "constant", "cosine", "exponential"] = "constant" lr_exponential_halflife: PositiveFloat | None = None lr_warmup_pct: Probability = 0.0 sparsity_loss_type: Literal["jacobian"] = "jacobian" unit_norm_matrices: bool = False - attribution_type: Literal["gradient", "ablation", "activation"] = "gradient" + attribution_type: Literal["gradient"] = "gradient" task_config: TMSTaskConfig | ResidualMLPTaskConfig = Field(..., discriminator="task_name") - DEPRECATED_CONFIG_KEYS: ClassVar[list[str]] = [ - "topk_param_attrib_coeff", - "orthog_coeff", - "hardcode_topk_mask_step", - "pnorm_end", - "topk_l2_coeff", - "spd_type", - "sparsity_warmup_pct", - ] - RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = {"topk_act_recon_coeff": "act_recon_coeff"} + DEPRECATED_CONFIG_KEYS: ClassVar[list[str]] = [] + RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = {} @model_validator(mode="before") def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, Any]: """Remove deprecated config keys and change names of any keys that have been renamed.""" - # Move k from task_config to Config and rename it to C - if "task_config" in config_dict and "k" in config_dict["task_config"]: - logger.warning("task_config.k is deprecated, please use C in the main Config instead") - config_dict["C"] = config_dict["task_config"]["k"] - del config_dict["task_config"]["k"] - for key in list(config_dict.keys()): val = config_dict[key] if key in cls.DEPRECATED_CONFIG_KEYS: @@ -109,34 +87,9 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, @model_validator(mode="after") def validate_model(self) -> Self: - # Check valid combinations of topk and batch_size - if self.topk is not None: - if self.batch_topk: - if not (self.batch_size * self.topk).is_integer(): - logger.warning( - f"batch_size * topk={self.batch_size * self.topk} is not an integer, will " - f"round down from {self.batch_size * self.topk} to " - f"{int(self.batch_size * self.topk)} when calculating topk_mask" - ) - else: - if not self.topk.is_integer(): - raise ValueError("topk must be an integer when not using batch_topk") - - # Warn if neither topk_recon_coeff nor lp_sparsity_coeff is set - if not self.topk_recon_coeff and not self.lp_sparsity_coeff: - logger.warning("Neither topk_recon_coeff nor lp_sparsity_coeff is set") - - # If topk_recon_coeff is set, topk must be set - if self.topk_recon_coeff is not None: - assert self.topk is not None, "topk must be set if topk_recon_coeff is set" - - # If lp_sparsity_coeff is set, pnorm must be set - if self.lp_sparsity_coeff is not None: - assert self.pnorm is not None, "pnorm must be set if lp_sparsity_coeff is set" - - # Check that topk_recon_coeff is None if topk is None - if self.topk is None: - assert self.topk_recon_coeff is None, "topk_recon_coeff is not None but topk is" + # Warn if neither masked_recon_coeff nor lp_sparsity_coeff is set + if not self.masked_recon_coeff and not self.lp_sparsity_coeff: + logger.warning("Neither masked_recon_coeff nor lp_sparsity_coeff is set") # Give a warning if both out_recon_coeff and param_match_coeff are > 0 if ( @@ -151,8 +104,8 @@ def validate_model(self) -> Self: # If any of the coeffs are 0, raise a warning msg = "is 0, you may wish to instead set it to null to avoid calculating the loss" - if self.topk_recon_coeff == 0: - logger.warning(f"topk_recon_coeff {msg}") + if self.masked_recon_coeff == 0: + logger.warning(f"masked_recon_coeff {msg}") if self.lp_sparsity_coeff == 0: logger.warning(f"lp_sparsity_coeff {msg}") if self.param_match_coeff == 0: @@ -164,9 +117,4 @@ def validate_model(self) -> Self: self.lr_exponential_halflife is not None ), "lr_exponential_halflife must be set if lr_schedule is exponential" - if self.schatten_coeff is not None: - assert ( - self.schatten_pnorm is not None - ), "schatten_pnorm must be set if schatten_coeff is set" - return self diff --git a/spd/experiments/resid_mlp/model_interp.py b/spd/experiments/resid_mlp/model_interp.py index 0f4163e..b813f85 100644 --- a/spd/experiments/resid_mlp/model_interp.py +++ b/spd/experiments/resid_mlp/model_interp.py @@ -1,6 +1,5 @@ # %% Imports -import einops import matplotlib.pyplot as plt import torch @@ -8,18 +7,13 @@ ResidualMLPModel, ) from spd.experiments.resid_mlp.plotting import ( - calculate_virtual_weights, - plot_2d_snr, plot_all_relu_curves, plot_individual_feature_response, - plot_resid_vs_mlp_out, plot_single_feature_response, plot_single_relu_curve, - relu_contribution_plot, ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset from spd.experiments.resid_mlp.train_resid_mlp import ResidMLPTrainConfig -from spd.plotting import plot_matrix from spd.settings import REPO_ROOT from spd.types import ModelPath from spd.utils import set_seed @@ -140,81 +134,3 @@ out_dir / f"resid_mlp_feature_response_multi_{n_layers}layers.png", bbox_inches="tight", dpi=300 ) print(f"Saved figure to {out_dir / f'resid_mlp_feature_response_multi_{n_layers}layers.png'}") - -# %% - - -instance_idx = 0 -nrows = 10 -fig, axs = plt.subplots(nrows=nrows, ncols=1, constrained_layout=True, figsize=(10, 3 + 4 * nrows)) -fig.suptitle(f"Model {path}") -for i in range(nrows): - ax = axs[i] # type: ignore - plot_resid_vs_mlp_out( - target_model=model, device=device, ax=ax, instance_idx=instance_idx, feature_idx=i - ) -plt.show() - - -# %% Show connection strength between ReLUs and features -virtual_weights = calculate_virtual_weights(target_model=model, device=device) -fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 5), constrained_layout=True) # type: ignore - -relu_contribution_plot( - ax1=ax1, - ax2=ax2, - all_diag_relu_conns=virtual_weights["diag_relu_conns"], - model=model, - device=device, - instance_idx=0, -) -plt.show() - -# %% Calculate S/N ratio for 1 and 2 active features. -fig = plot_2d_snr(model, device) -plt.show() - -# %% Plot virtual weights - -fig = plt.figure(constrained_layout=True, figsize=(20, 20)) -gs = fig.add_gridspec(ncols=2, nrows=3) -ax1 = fig.add_subplot(gs[0, 0]) -ax2 = fig.add_subplot(gs[0, 1]) -ax3 = fig.add_subplot(gs[1:, :]) -virtual_weights = calculate_virtual_weights(target_model=model, device=device) -instance_idx = 0 -in_conns = virtual_weights["in_conns"][instance_idx].cpu().detach() -out_conns = virtual_weights["out_conns"][instance_idx].cpu().detach() -W_E_W_U = einops.einsum( - virtual_weights["W_E"][instance_idx], - virtual_weights["W_U"][instance_idx], - "n_features1 d_embed, d_embed n_features2 -> n_features1 n_features2", -) -plot_matrix( - ax1, - in_conns.T, - "Virtual input weights $(W_E W_{in})^T$", - "Features", - "Neurons", - colorbar_format="%.2f", -) -plot_matrix( - ax2, - out_conns, - "Virtual output weights $W_{out} W_U$", - "Features", - "Neurons", - colorbar_format="%.2f", -) -ax2.xaxis.set_label_position("top") -plot_matrix( - ax3, - W_E_W_U, - "Virtual weights $W_E W_U$", - "Features", - "Features", - colorbar_format="%.2f", -) -plt.show() - -# %% diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index e3e54b5..cf02787 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -53,7 +53,6 @@ def __init__( n_instances=n_instances, init_type=init_type, init_scale=init_scale, - C=spd_kwargs["C"], m=spd_kwargs["m"], ) self.mlp_out = LinearComponent( @@ -62,7 +61,6 @@ def __init__( n_instances=n_instances, init_type=init_type, init_scale=init_scale, - C=spd_kwargs["C"], m=spd_kwargs["m"], ) else: @@ -93,18 +91,20 @@ def __init__( def forward( self, x: Float[Tensor, "batch ... d_model"], - mask: Float[Tensor, "batch ... C"] | None = None, + mlp_in_mask: Float[Tensor, "batch ... d_mlp"] | None = None, + mlp_out_mask: Float[Tensor, "batch ... d_model"] | None = None, ) -> tuple[Float[Tensor, "batch ... d_model"],]: """Run a forward pass and cache pre and post activations for each parameter. Note that we don't need to cache pre activations for the biases. We also don't care about the output bias which is always zero. """ - mid_pre_act_fn = self.mlp_in(x, mask=mask) + mid_pre_act_fn = self.mlp_in(x, mask=mlp_in_mask) if self.bias1 is not None: mid_pre_act_fn = mid_pre_act_fn + self.bias1 mid = self.act_fn(mid_pre_act_fn) - out = self.mlp_out(mid, mask=mask) + + out = self.mlp_out(mid, mask=mlp_out_mask) if self.bias2 is not None: out = out + self.bias2 return out @@ -280,8 +280,7 @@ class ResidualMLPSPDConfig(BaseModel): in_bias: bool out_bias: bool init_scale: float - C: PositiveInt - m: PositiveInt | None = None + m: PositiveInt init_type: Literal["kaiming_uniform", "xavier_normal"] = "xavier_normal" @@ -294,7 +293,6 @@ def __init__( self.config = config self.n_features = config.n_features # Required for backward compatibility self.n_instances = config.n_instances # Required for backward compatibility - self.C = config.C # Required for backward compatibility assert config.act_fn_name in ["gelu", "relu"] self.act_fn = F.gelu if config.act_fn_name == "gelu" else F.relu @@ -304,7 +302,7 @@ def __init__( init_param_(self.W_E, init_type=config.init_type) init_param_(self.W_U, init_type=config.init_type) - self.m = min(config.d_embed, config.d_mlp) if config.m is None else config.m + self.m = config.m self.layers = nn.ModuleList( [ @@ -317,7 +315,7 @@ def __init__( in_bias=config.in_bias, out_bias=config.out_bias, act_fn=self.act_fn, - spd_kwargs={"C": config.C, "m": self.m}, + spd_kwargs={"m": self.m}, ) for _ in range(config.n_layers) ] @@ -327,7 +325,7 @@ def __init__( def forward( self, x: Float[Tensor, "batch n_instances n_features"], - mask: Float[Tensor, "batch n_instances C"] | None = None, + masks: dict[str, Float[Tensor, "batch n_instances m"]] | None = None, ) -> Float[Tensor, "batch n_instances d_embed"]: """ Returns: @@ -338,8 +336,12 @@ def forward( self.W_E, "batch n_instances n_features, n_instances n_features d_embed -> batch n_instances d_embed", ) - for layer in self.layers: - residual = residual + layer(residual, mask) + for i, layer in enumerate(self.layers): + mlp_in_mask = masks[f"layers.{i}.mlp_in"] if masks is not None else None + mlp_out_mask = masks[f"layers.{i}.mlp_out"] if masks is not None else None + residual = residual + layer( + residual, mlp_in_mask=mlp_in_mask, mlp_out_mask=mlp_out_mask + ) out = einops.einsum( residual, self.W_U, @@ -423,7 +425,7 @@ def from_pretrained( assert isinstance(config.task_config, ResidualMLPTaskConfig) resid_mlp_spd_config = ResidualMLPSPDConfig( - **resid_mlp_train_config_dict["resid_mlp_config"], C=config.C, m=config.m + **resid_mlp_train_config_dict["resid_mlp_config"], m=config.m ) model = cls(config=resid_mlp_spd_config) params = torch.load(paths.checkpoint, weights_only=True, map_location="cpu") diff --git a/spd/experiments/resid_mlp/plotting.py b/spd/experiments/resid_mlp/plotting.py index 3f254cc..f70dca7 100644 --- a/spd/experiments/resid_mlp/plotting.py +++ b/spd/experiments/resid_mlp/plotting.py @@ -1,29 +1,12 @@ from collections.abc import Callable -from dataclasses import dataclass from typing import Literal -import einops import matplotlib.pyplot as plt -import numpy as np import torch import torch.nn.functional as F -from jaxtyping import Float -from matplotlib.colors import Normalize -from mpl_toolkits.axes_grid1 import make_axes_locatable -from pydantic import PositiveFloat from torch import Tensor -from tqdm import tqdm -from spd.experiments.resid_mlp.models import ( - ResidualMLPConfig, - ResidualMLPModel, - ResidualMLPSPDConfig, - ResidualMLPSPDModel, -) -from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset -from spd.plotting import plot_matrix -from spd.run_spd import Config -from spd.utils import SPDOutputs, calc_topk_mask, calculate_attributions +from spd.experiments.resid_mlp.models import ResidualMLPConfig, ResidualMLPSPDConfig def plot_individual_feature_response( @@ -291,1493 +274,3 @@ def plot_all_relu_curves( ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) return fig - - -def _calculate_snr( - model: ResidualMLPModel, device: str, input_values: tuple[float, float] -) -> Tensor: - n_features = model.config.n_features - n_instances = model.config.n_instances - batch_size = n_features**2 - batch = torch.zeros(batch_size, n_instances, n_features, device=device) - instance_idx = 0 - snr = torch.zeros(n_features, n_features) - for f1 in range(n_features): - for f2 in range(n_features): - idx = f1 * n_features + f2 - batch[idx, instance_idx, f1] = input_values[0] - batch[idx, instance_idx, f2] = input_values[1] - out = model(batch) - out: Float[Tensor, "batch n_features"] = out[:, instance_idx, :] - for f1 in range(n_features): - for f2 in range(n_features): - idx = f1 * n_features + f2 - signal = min(out[idx, f1].abs().item(), out[idx, f2].abs().item()) - noise = out[idx, :].std().item() - snr[f1, f2] = signal / noise - return snr - - -def plot_2d_snr(model: ResidualMLPModel, device: str): - fig, (ax1, ax2, ax3) = plt.subplots( - 3, 1, height_ratios=[1, 10, 10], constrained_layout=True, figsize=(4, 8) - ) # type: ignore - # Calculate SNR for (1, 1) and implicitly (1,) too. - snr = _calculate_snr(model, device, input_values=(1, 1)).cpu().detach() - # Plot diagonal in top subplot - diagonal = torch.diag(snr) - im1 = ax1.imshow(diagonal.unsqueeze(0), aspect="auto", vmin=1, vmax=snr.max()) - ax1.set_yticks([]) - ax1.set_title("SNR for single active features") - divider = make_axes_locatable(ax1) - cax1 = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im1, cax=cax1) - - # Plot main SNR matrix without diagonal - snr_no_diag = snr.clone() - snr_no_diag.fill_diagonal_(torch.nan) - im2 = ax2.imshow(snr_no_diag, aspect="auto", vmin=1, vmax=snr.max()) - ax2.set_title("SNR for pairs of active features set to (1, 1)") - ax2.set_xlabel("Feature 2 (set to 1)") - ax2.set_ylabel("Feature 1 (set to 1)") - divider = make_axes_locatable(ax2) - cax2 = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im2, cax=cax2) - - # Calculate SNR for (1, -1) - snr = _calculate_snr(model, device, input_values=(1, -1)).cpu().detach() - # Plot second SNR matrix without diagonal - snr_no_diag = snr.clone() - snr_no_diag.fill_diagonal_(torch.nan) - im3 = ax3.imshow(snr_no_diag, aspect="auto", vmin=1, vmax=snr.max()) - ax3.set_title("SNR for pairs of active features set to (1, -1)") - ax3.set_xlabel("Feature 2 (set to -1)") - ax3.set_ylabel("Feature 1 (set to 1)") - divider = make_axes_locatable(ax3) - cax3 = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im3, cax=cax3) - - return fig - - -def calculate_virtual_weights(target_model: ResidualMLPModel, device: str) -> dict[str, Tensor]: - """Currently ignoring interactions between layers. Just flattening (n_layers, d_mlp)""" - n_instances = target_model.config.n_instances - n_features = target_model.config.n_features - d_embed = target_model.config.d_embed - d_mlp = target_model.config.d_mlp - has_bias1 = target_model.layers[0].bias1 is not None - has_bias2 = target_model.layers[0].bias2 is not None - n_layers = target_model.config.n_layers - # Get weights - W_E: Float[Tensor, "n_instances n_features d_embed"] = target_model.W_E - W_U: Float[Tensor, "n_instances d_embed n_features"] = target_model.W_U - W_in: Float[Tensor, "n_instances d_embed d_mlp_eff"] = torch.cat( - [target_model.layers[i].mlp_in.weight.data for i in range(n_layers)], dim=-1 - ) - W_out: Float[Tensor, "n_instances d_mlp_eff d_embed"] = torch.cat( - [target_model.layers[i].mlp_out.weight.data for i in range(n_layers)], - dim=-2, - ) - b_in: Float[Tensor, "n_instances d_mlp_eff"] | None = ( - torch.cat([target_model.layers[i].bias1.data for i in range(n_layers)], dim=-1) - if has_bias1 - else None - ) - b_out: Float[Tensor, "n_instances d_embed"] | None = ( - torch.stack([target_model.layers[i].bias2.data for i in range(n_layers)]).sum(dim=0) - if has_bias2 - else None - ) - assert W_E.shape == (n_instances, n_features, d_embed) - assert W_U.shape == (n_instances, d_embed, n_features) - assert W_in.shape == (n_instances, d_embed, n_layers * d_mlp) - assert W_out.shape == (n_instances, n_layers * d_mlp, d_embed) - assert b_in.shape == (n_instances, n_layers * d_mlp) if b_in is not None else True - assert b_out.shape == (n_instances, d_embed) if b_out is not None else True - # Calculate connection strengths / virtual weights - in_conns: Float[Tensor, "n_instances n_features d_mlp"] = einops.einsum( - W_E, - W_in, - "n_instances n_features d_embed, n_instances d_embed d_mlp -> n_instances n_features d_mlp", - ) - out_conns: Float[Tensor, "n_instances d_mlp n_features"] = einops.einsum( - W_out, - W_E, - "n_instances d_mlp d_embed, n_instances n_features d_embed -> n_instances d_mlp n_features", - ) - diag_relu_conns: Float[Tensor, "n_instances n_features d_mlp"] = einops.einsum( - in_conns, - out_conns, - "n_instances n_features d_mlp, n_instances d_mlp n_features -> n_instances n_features d_mlp", - ) - assert in_conns.shape == (n_instances, n_features, n_layers * d_mlp) - assert out_conns.shape == (n_instances, n_layers * d_mlp, n_features) - assert diag_relu_conns.shape == (n_instances, n_features, n_layers * d_mlp) - virtual_weights = { - "W_E": W_E, - "W_U": W_U, - "W_in": W_in, - "W_out": W_out, - "in_conns": in_conns, - "out_conns": out_conns, - "diag_relu_conns": diag_relu_conns, - } - if b_in is not None: - virtual_weights["b_in"] = b_in - if b_out is not None: - virtual_weights["b_out"] = b_out - return virtual_weights - - -def relu_contribution_plot( - ax1: plt.Axes, - ax2: plt.Axes, - all_diag_relu_conns: Float[Tensor, "n_instances n_features d_mlp"], - model: ResidualMLPModel | ResidualMLPSPDModel, - device: str, - instance_idx: int = 0, -): - diag_relu_conns: Float[Tensor, "n_features d_mlp"] = ( - all_diag_relu_conns[instance_idx].cpu().detach() - ) - d_mlp = model.config.d_mlp - n_layers = model.config.n_layers - n_features = model.config.n_features - - ax1.axvline(-0.5, color="k", linestyle="--", alpha=0.3, lw=0.5) - for i in range(model.config.n_features): - ax1.scatter([i] * d_mlp * n_layers, diag_relu_conns[i, :], alpha=0.3, marker=".", c="k") - ax1.axvline(i + 0.5, color="k", linestyle="--", alpha=0.3, lw=0.5) - for j in range(d_mlp * n_layers): - if diag_relu_conns[i, j].item() > 0.1: - cmap_label = plt.get_cmap("hsv") - ax1.text( - i, diag_relu_conns[i, j].item(), str(j), color=cmap_label(j / d_mlp / n_layers) - ) - ax1.axhline(0, color="k", linestyle="--", alpha=0.3) - ax1.set_xlim(-0.5, model.config.n_features - 0.5) - ax2.axvline(-0.5, color="k", linestyle="--", alpha=0.3, lw=0.5) - for i in range(d_mlp * n_layers): - ax2.scatter([i] * n_features, diag_relu_conns[:, i], alpha=0.3, marker=".", c="k") - ax2.axvline(i + 0.5, color="k", linestyle="--", alpha=0.3, lw=0.5) - for j in range(n_features): - if diag_relu_conns[j, i].item() > 0.2: - cmap_label = plt.get_cmap("hsv") - ax2.text(i, diag_relu_conns[j, i].item(), str(j), color=cmap_label(j / n_features)) - ax2.axhline(0, color="k", linestyle="--", alpha=0.3) - ax1.set_xlabel("Features") - ax2.set_xlabel("ReLUs (consecutively enumerated throughout layers)") - ax2.set_xlim(-0.5, d_mlp * n_layers - 0.5) - - -def feature_contribution_plot( - ax: plt.Axes, - all_diag_relu_conns: Float[Tensor, "n_features d_mlp"], - model: ResidualMLPModel | ResidualMLPSPDModel, - n_features: int, - pre_labelled_neurons: dict[int, list[int]] | None = None, - legend: bool = True, -) -> dict[int, list[int]]: - diag_relu_conns: Float[Tensor, "n_features d_mlp"] = all_diag_relu_conns.cpu().detach() - d_mlp = model.config.d_mlp - n_layers = model.config.n_layers - - # Define colors for different layers - assert n_layers in [1, 2] - layer_colors = ["grey"] if n_layers == 1 else ["blue", "red"] - distinct_colors = [ - "#E41A1C", # red - "#377EB8", # blue - "#4DAF4A", # green - "#984EA3", # purple - "#FF7F00", # orange - "#A65628", # brown - "#F781BF", # pink - "#1B9E77", # teal - "#D95F02", # dark orange - "#7570B3", # slate blue - "#66A61E", # lime green - ] - - # Add legend if there are two layers - if n_layers == 2 and legend: - # Create dummy scatter plots for legend - ax.scatter([], [], c="blue", alpha=0.3, marker=".", label="First MLP") - ax.scatter([], [], c="red", alpha=0.3, marker=".", label="Second MLP") - ax.legend(loc="upper right") - - labelled_neurons: dict[int, list[int]] = {i: [] for i in range(n_features)} - - ax.axvline(-0.5, color="k", linestyle="--", alpha=0.3, lw=0.5) - for i in range(n_features): - # Split points by layer and plot separately - for layer in range(n_layers): - layer_indices = slice(layer * d_mlp, (layer + 1) * d_mlp) - ax.scatter( - [i] * d_mlp, - diag_relu_conns[i, layer_indices], - alpha=0.3, - marker=".", - c=layer_colors[layer], - ) - ax.axvline(i + 0.5, color="k", linestyle="--", alpha=0.3, lw=0.5) - for j in range(d_mlp * n_layers): - # Label the neuron if it's in the pre-labelled set or if no pre-labelled set is provided - # and the neuron has a connection strength greater than 0.1 - if (pre_labelled_neurons is not None and j in pre_labelled_neurons[i]) or ( - pre_labelled_neurons is None and diag_relu_conns[i, j].item() > 0.1 - ): - color_idx = j % len(distinct_colors) - # Make the neuron label alternate between left and right (-0.1, 0.1) - # Add 0.05 or -0.05 to the x coordinate to shift the label left or right - ax.text( - i, - diag_relu_conns[i, j].item(), - str(j), - color=distinct_colors[color_idx], - ha="left" if (len(labelled_neurons[i]) + 1) % 2 == 0 else "right", - ) - labelled_neurons[i].append(j) - ax.axhline(0, color="k", linestyle="--", alpha=0.3) - ax.set_xlim(-0.5, n_features - 0.5) - ax.set_xlabel("Features") - - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - return labelled_neurons - - -def spd_calculate_virtual_weights(model: ResidualMLPSPDModel, device: str) -> dict[str, Tensor]: - """Currently ignoring interactions between layers. Just flattening (n_layers, d_mlp)""" - old_device = next(model.parameters()).device - model.to(device) - n_instances = model.config.n_instances - n_features = model.config.n_features - d_embed = model.config.d_embed - d_mlp = model.config.d_mlp - C = model.config.C - has_bias1 = model.layers[0].bias1 is not None - has_bias2 = model.layers[0].bias2 is not None - n_layers = model.config.n_layers - # Get weights - W_E: Float[Tensor, "n_instances n_features d_embed"] = model.W_E - W_U: Float[Tensor, "n_instances d_embed n_features"] = model.W_U - W_in: Float[Tensor, "n_instances C d_embed d_mlp_eff"] = torch.cat( - [model.layers[i].mlp_in.component_weights for i in range(n_layers)], dim=-1 - ) - W_out: Float[Tensor, "n_instances C d_mlp_eff d_embed"] = torch.cat( - [model.layers[i].mlp_out.component_weights for i in range(n_layers)], - dim=-2, - ) - b_in: Float[Tensor, "n_instances C d_mlp_eff"] | None = ( - torch.cat([model.layers[i].bias1 for i in range(n_layers)], dim=-1) if has_bias1 else None - ) - b_out: Float[Tensor, "n_instances C d_embed"] | None = ( - torch.stack([model.layers[i].bias2 for i in range(n_layers)]).sum(dim=0) - if has_bias2 - else None - ) - assert W_E.shape == (n_instances, n_features, d_embed) - assert W_U.shape == (n_instances, d_embed, n_features) - assert W_in.shape == (n_instances, C, d_embed, n_layers * d_mlp) - assert W_out.shape == (n_instances, C, n_layers * d_mlp, d_embed) - assert b_in.shape == (n_instances, C, n_layers * d_mlp) if b_in is not None else True - assert b_out.shape == (n_instances, C, d_embed) if b_out is not None else True - # Calculate connection strengths / virtual weights - in_conns: Float[Tensor, "n_instances C n_features d_mlp"] = einops.einsum( - W_E, - W_in, - "n_instances n_features d_embed, n_instances C d_embed d_mlp -> n_instances C n_features d_mlp", - ) - out_conns: Float[Tensor, "n_instances C d_mlp n_features"] = einops.einsum( - W_out, - W_E, - "n_instances C d_mlp d_embed, n_instances n_features d_embed -> n_instances C d_mlp n_features", - ) - diag_relu_conns: Float[Tensor, "n_instances C n_features d_mlp"] = einops.einsum( - in_conns, - out_conns, - "n_instances C n_features d_mlp, n_instances C d_mlp n_features -> n_instances C n_features d_mlp", - ) - assert in_conns.shape == (n_instances, C, n_features, n_layers * d_mlp) - assert out_conns.shape == (n_instances, C, n_layers * d_mlp, n_features) - assert diag_relu_conns.shape == (n_instances, C, n_features, n_layers * d_mlp) - virtual_weights = { - "W_E": W_E, - "W_U": W_U, - "W_in": W_in, - "W_out": W_out, - "in_conns": in_conns, - "out_conns": out_conns, - "diag_relu_conns": diag_relu_conns, - } - if b_in is not None: - virtual_weights["b_in"] = b_in - if b_out is not None: - virtual_weights["b_out"] = b_out - - model.to(old_device) - return virtual_weights - - -def spd_calculate_diag_relu_conns( - model: ResidualMLPSPDModel, - device: str, - k_select: int | Literal["sum_before", "sum_nocrossterms", "sum_onlycrossterms"] = 0, -) -> Float[Tensor, "n_instances n_features d_mlp"]: - virtual_weights = spd_calculate_virtual_weights(model, device) - if isinstance(k_select, int): - return virtual_weights["diag_relu_conns"][:, k_select] - elif k_select == "sum_nocrossterms": - return virtual_weights["diag_relu_conns"].sum(dim=1) - else: - in_conns: Float[Tensor, "n_instances C n_features d_mlp"] = virtual_weights["in_conns"] - out_conns: Float[Tensor, "n_instances C d_mlp n_features"] = virtual_weights["out_conns"] - if k_select == "sum_onlycrossterms": - nocross_diag_relu_conns: Float[Tensor, "n_instances n_features d_mlp"] = ( - virtual_weights["diag_relu_conns"].sum(dim=1) - ) - all_diag_relu_conns: Float[Tensor, "n_instances k1 k2 n_features d_mlp"] = ( - einops.einsum( - in_conns, - out_conns, - "n_instances k1 n_features d_mlp, n_instance k2 d_mlp n_features -> n_instances k1 k2 n_features d_mlp", - ) - ) - return all_diag_relu_conns.sum(dim=(-3, -4)) - nocross_diag_relu_conns - elif k_select == "sum_before": - sum_diag_relu_conns: Float[Tensor, "n_instances n_features d_mlp"] = einops.einsum( - in_conns.sum(dim=1), - out_conns.sum(dim=1), - "n_instances n_features d_mlp, n_instance d_mlp n_features -> n_instances n_features d_mlp", - ) - return sum_diag_relu_conns - else: - raise ValueError(f"Invalid k_select: {k_select}") - - -def plot_spd_relu_contribution( - spd_model: ResidualMLPSPDModel, - target_model: ResidualMLPModel, - device: str = "cuda", - k_plot_limit: int | None = None, -): - offset = 4 - nrows = (k_plot_limit or spd_model.config.C) + offset - fig1, axes1 = plt.subplots(nrows, 1, figsize=(20, 3 + 2 * nrows), constrained_layout=True) - axes1 = np.atleast_1d(axes1) # type: ignore - fig2, axes2 = plt.subplots(nrows, 1, figsize=(10, 3 + 2 * nrows), constrained_layout=True) - axes2 = np.atleast_1d(axes2) # type: ignore - - virtual_weights = calculate_virtual_weights(target_model, device) - relu_conns = virtual_weights["diag_relu_conns"] - relu_contribution_plot(axes1[0], axes2[0], relu_conns, target_model, device) - axes1[0].set_ylabel("Target model", fontsize=8) - axes2[0].set_ylabel("Target model", fontsize=8) - axes1[0].set_xlabel("") - axes2[0].set_xlabel("") - relu_conns = spd_calculate_diag_relu_conns(spd_model, device, k_select="sum_before") - relu_contribution_plot(axes1[1], axes2[1], relu_conns, spd_model, device) - axes1[1].set_ylabel("SPD model full sum of all subnets", fontsize=8) - axes2[1].set_ylabel("SPD model full sum of all subnets", fontsize=8) - axes1[1].set_xlabel("") - axes2[1].set_xlabel("") - relu_conns = spd_calculate_diag_relu_conns(spd_model, device, k_select="sum_nocrossterms") - relu_contribution_plot(axes1[2], axes2[2], relu_conns, spd_model, device) - axes1[2].set_ylabel("SPD model sum without cross terms", fontsize=8) - axes2[2].set_ylabel("SPD model sum without cross terms", fontsize=8) - axes1[2].set_xlabel("") - axes2[2].set_xlabel("") - relu_conns = spd_calculate_diag_relu_conns(spd_model, device, k_select="sum_onlycrossterms") - relu_contribution_plot(axes1[3], axes2[3], relu_conns, spd_model, device) - axes1[3].set_ylabel("SPD model sum only cross terms", fontsize=8) - axes2[3].set_ylabel("SPD model sum only cross terms", fontsize=8) - axes1[3].set_xlabel("") - axes2[3].set_xlabel("") - for c in range(k_plot_limit or spd_model.config.C): - relu_conns = spd_calculate_diag_relu_conns(spd_model, device, k_select=c) - relu_contribution_plot(axes1[c + offset], axes2[c + offset], relu_conns, spd_model, device) - axes1[c + offset].set_ylabel(f"k={c}") - axes2[c + offset].set_ylabel(f"k={c}") - if (k_plot_limit or spd_model.config.C) - 1 > c: - axes1[c + offset].set_xlabel("") - axes2[c + offset].set_xlabel("") - return fig1, fig2 - - -def plot_spd_feature_contributions( - spd_model: ResidualMLPSPDModel, - target_model: ResidualMLPModel, - device: str = "cuda", - k_plot_limit: int | None = None, -) -> plt.Figure: - instance_idx = 0 - offset = 4 - nrows = (k_plot_limit or spd_model.config.C) + offset - fig1, axes1 = plt.subplots(nrows, 1, figsize=(20, 3 + 2 * nrows), constrained_layout=True) - axes1 = np.atleast_1d(axes1) # type: ignore - - n_features = spd_model.config.n_features - - virtual_weights = calculate_virtual_weights(target_model, device) - relu_conns = virtual_weights["diag_relu_conns"] - feature_contribution_plot( - ax=axes1[0], - all_diag_relu_conns=relu_conns[instance_idx], - model=target_model, - n_features=n_features, - ) - axes1[0].set_ylabel("Target model", fontsize=8) - axes1[0].set_xlabel("") - relu_conns = spd_calculate_diag_relu_conns(spd_model, device, k_select="sum_before") - feature_contribution_plot( - ax=axes1[1], - all_diag_relu_conns=relu_conns[instance_idx], - model=spd_model, - n_features=n_features, - ) - axes1[1].set_ylabel("SPD model full sum of all subnets", fontsize=8) - axes1[1].set_xlabel("") - - # We now use max component instead of sum_nocrossterms now - # relu_conns = spd_calculate_diag_relu_conns(spd_model, device, k_select="sum_nocrossterms") - # relu_contribution_plot(axes1[2], axes2[2], relu_conns, spd_model, device) - # axes1[2].set_ylabel("SPD model sum without cross terms", fontsize=8) - # axes2[2].set_ylabel("SPD model sum without cross terms", fontsize=8) - # axes1[2].set_xlabel("") - # axes2[2].set_xlabel("") - - diag_relu_conns: Float[Tensor, "C n_features d_mlp"] = spd_calculate_virtual_weights( - spd_model, device - )["diag_relu_conns"][instance_idx] - max_component_indices = diag_relu_conns.max(dim=-1).values.argmax(dim=0) - # For each feature, use the C values based on the max_component_indices - max_component_contributions: Float[Tensor, "n_features d_mlp"] = diag_relu_conns[ - max_component_indices, torch.arange(spd_model.config.n_features) - ] - feature_contribution_plot( - ax=axes1[2], - all_diag_relu_conns=max_component_contributions, - model=spd_model, - n_features=n_features, - ) - axes1[2].set_ylabel("SPD model max component", fontsize=8) - axes1[2].set_xlabel("") - # Label the x axis with the subnets that have the largest neuron for each feature - # Set xticks for every index - axes1[2].set_xticks(range(n_features)) - axes1[2].set_xticklabels(max_component_indices.tolist()) - axes1[2].tick_params(axis="x", labelsize=6) - - relu_conns = spd_calculate_diag_relu_conns(spd_model, device, k_select="sum_onlycrossterms") - # relu_contribution_plot(axes1[3], relu_conns, spd_model, device) - feature_contribution_plot( - ax=axes1[3], - all_diag_relu_conns=relu_conns[instance_idx], - model=spd_model, - n_features=n_features, - ) - axes1[3].set_ylabel("SPD model sum only cross terms", fontsize=8) - axes1[3].set_xlabel("") - - # Use the same y-axis max for all plots - y_min = min([axes1[i].get_ylim()[0] for i in range(4)]) - y_max = max([axes1[i].get_ylim()[1] for i in range(4)]) - axes1[0].set_ylim(y_min, y_max) - axes1[1].set_ylim(y_min, y_max) - axes1[2].set_ylim(y_min, y_max) - axes1[3].set_ylim(y_min, y_max) - - for c in range(k_plot_limit or spd_model.config.C): - relu_conns = spd_calculate_diag_relu_conns(spd_model, device, k_select=c) - # relu_contribution_plot(axes1[k + offset], relu_conns, spd_model, device) - feature_contribution_plot( - ax=axes1[c + offset], - all_diag_relu_conns=relu_conns[instance_idx], - model=spd_model, - n_features=n_features, - ) - axes1[c + offset].set_ylabel(f"k={c}") - if (k_plot_limit or spd_model.config.C) - 1 > c: - axes1[c + offset].set_xlabel("") - return fig1 - - -def plot_spd_feature_contributions_truncated( - spd_model: ResidualMLPSPDModel, - target_model: ResidualMLPModel, - device: str = "cuda", - n_features: int | None = 10, - include_crossterms: bool = False, -): - assert spd_model.config.n_instances == 1, "Only one instance supported for now" - - n_features = n_features or spd_model.config.n_features - n_rows = 3 if include_crossterms else 2 - fig1, axes1 = plt.subplots(n_rows, 1, figsize=(10, 7), constrained_layout=True) - axes1 = np.atleast_1d(axes1) # type: ignore - - # First plot: Target model - virtual_weights = calculate_virtual_weights(target_model, device) - relu_conns: Float[Tensor, "n_features d_mlp"] = virtual_weights["diag_relu_conns"][ - 0, :n_features, : - ] - labelled_neurons = feature_contribution_plot( - axes1[0], - relu_conns, - model=target_model, - n_features=n_features, - legend=True, - ) - axes1[0].set_ylabel("Neuron contribution") - axes1[0].set_xlabel(f"Input feature index (first {n_features} shown)") - axes1[0].set_title("Target model") - axes1[0].set_xticks(range(n_features)) # Ensure all xticks have labels - - # Second plot: SPD model (without cross terms max k) - # Previously we would just use no_crossterms which sums over all k - # spd_relu_conns: Float[Tensor, "n_features d_mlp"] = spd_calculate_diag_relu_conns( - # spd_model, device, k_select="sum_nocrossterms" - # )[0, :n_features, :] - # Instead, we want find the C which has the largest neuron for each feature index - diag_relu_conns: Float[Tensor, "C n_features d_mlp"] = spd_calculate_virtual_weights( - spd_model, device - )["diag_relu_conns"][0, :, :n_features, :] - max_component_indices = diag_relu_conns.max(dim=-1).values.argmax(dim=0) - # For each feature, use the C values based on the max_component_indices - max_component_contributions: Float[Tensor, "n_features d_mlp"] = diag_relu_conns[ - max_component_indices, torch.arange(n_features) - ] - feature_contribution_plot( - axes1[1], - max_component_contributions, - model=spd_model, - n_features=n_features, - pre_labelled_neurons=labelled_neurons, - legend=False, - ) - axes1[1].set_ylabel("Neuron contribution") - axes1[1].set_xlabel("Parameter component index") - axes1[1].set_title("Individual APD parameter components") - axes1[1].set_xticks(range(n_features)) - - # Set the same y-axis limits for both plots - y_min = min(axes1[0].get_ylim()[0], axes1[1].get_ylim()[0]) - y_max = max(axes1[0].get_ylim()[1], axes1[1].get_ylim()[1]) - axes1[0].set_ylim(y_min, y_max) - axes1[1].set_ylim(y_min, y_max) - - # Label the x axis with the subnets that have the largest neuron for each feature - axes1[1].set_xticklabels(max_component_indices.tolist()) # Labels are the subnet indices - - if include_crossterms: - # Third plot: SPD model (with cross terms) - spd_relu_conns: Float[Tensor, "n_features d_mlp"] = spd_calculate_diag_relu_conns( - spd_model, device, k_select="sum_onlycrossterms" - )[0, :n_features, :] - feature_contribution_plot( - axes1[2], spd_relu_conns, model=spd_model, n_features=n_features, legend=False - ) - axes1[2].set_ylabel("Neuron contribution") - axes1[2].set_xlabel("Input feature index") - axes1[2].set_title("Input feature cross terms") - axes1[2].set_xticks(range(n_features)) - - return fig1 - - -def collect_per_feature_losses( - target_model: ResidualMLPModel, - spd_model: ResidualMLPSPDModel, - config: Config, - dataset: ResidualMLPDataset, - device: str, - batch_size: int, - n_samples: int, -) -> tuple[ - Float[Tensor, " n_features"], Float[Tensor, " n_features"], Float[Tensor, " n_features"] -]: - """Collect the MSE losses for the target model, SPD batch topk SPD sample topk. - - Returns: - loss_target: Float[Tensor, " n_features"] - loss_spd_batch_topk: Float[Tensor, " n_features"] - loss_spd_sample_topk: Float[Tensor, " n_features"] - """ - print_every_counter = 10000 - n_samples_acc = 0 - feature_counts = torch.zeros(spd_model.config.n_features, device=device) - loss_target = torch.zeros(spd_model.config.n_features, device=device) - loss_spd_batch_topk = torch.zeros(spd_model.config.n_features, device=device) - loss_spd_sample_topk = torch.zeros(spd_model.config.n_features, device=device) - - while n_samples_acc < n_samples: - # Generate batch with no zero samples - batch, labels = dataset.generate_batch(batch_size) - batch = batch.to(device) - labels = labels.to(device) - - # n_instances should be 1 - assert batch.shape[1] == 1 - - # Forward pass on target model - target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) - target_out, target_cache = target_model.run_with_cache( - batch, names_filter=target_cache_filter - ) - - # Do a forward pass with all subnetworks - spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) - spd_out, spd_cache = spd_model.run_with_cache(batch, names_filter=spd_cache_filter) - - attribution_scores = calculate_attributions( - model=spd_model, - config=config, - batch=batch, - out=spd_out, - target_out=target_out, - pre_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_pre")}, - post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")}, - component_acts={ - k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts") - }, - ) - - with torch.inference_mode(): - assert config.batch_topk is True - assert config.topk is not None - # Get the topk mask for the batch topk model - batch_topk_mask = calc_topk_mask( - attribution_scores, topk=config.topk, batch_topk=config.batch_topk - ) - # Get the topk mask for the sample topk model - sample_topk_mask = calc_topk_mask(attribution_scores, topk=1, batch_topk=False) - - # Get the batch topk model output - spd_out_batch_topk = spd_model(batch, mask=batch_topk_mask) - # Get the sample topk model output - spd_out_sample_topk = spd_model(batch, mask=sample_topk_mask) - - # Get rid of the n_instances dimension for simplicity - batch: Float[Tensor, "batch n_features"] = batch.squeeze(1) - batch_topk_mask: Float[Tensor, "batch C"] = batch_topk_mask.squeeze(1) - target_out: Float[Tensor, "batch n_features"] = target_out.squeeze(1) - spd_out_batch_topk: Float[Tensor, "batch n_features"] = spd_out_batch_topk.squeeze(1) - spd_out_sample_topk: Float[Tensor, "batch n_features"] = spd_out_sample_topk.squeeze(1) - labels: Float[Tensor, "batch n_features"] = labels.squeeze(1) - - # Get the indices of samples where there is exactly one feature active - active_features_batch: Float[Tensor, "batch n_features"] = (batch != 0).float() - exactly_one_active_features_batch: Float[Tensor, "batch n_features"] = ( - active_features_batch.sum(dim=-1) == 1 - ) - - # Filter to only include the samples where there is exactly one feature active - batch = batch[exactly_one_active_features_batch] - labels = labels[exactly_one_active_features_batch] - target_out = target_out[exactly_one_active_features_batch] - spd_out_batch_topk = spd_out_batch_topk[exactly_one_active_features_batch] - spd_out_sample_topk = spd_out_sample_topk[exactly_one_active_features_batch] - filtered_active_features_batch: Float[Tensor, "sub_batch n_features"] = ( - active_features_batch[exactly_one_active_features_batch] - ) - - # Get the Squared error loss for each sample - loss_target_batch_raw: Float[Tensor, "sub_batch 1"] = ((target_out - labels) ** 2).sum( - dim=-1, keepdim=True - ) - loss_spd_batch_topk_batch_raw: Float[Tensor, "sub_batch 1"] = ( - (spd_out_batch_topk - labels) ** 2 - ).sum(dim=-1, keepdim=True) - loss_spd_sample_topk_batch_raw: Float[Tensor, "sub_batch 1"] = ( - (spd_out_sample_topk - labels) ** 2 - ).sum(dim=-1, keepdim=True) - - # Element-wise multiply the loss by the active features - loss_target_batch: Float[Tensor, "sub_batch n_features"] = ( - loss_target_batch_raw * filtered_active_features_batch - ) - loss_spd_batch_topk_batch: Float[Tensor, "sub_batch n_features"] = ( - loss_spd_batch_topk_batch_raw * filtered_active_features_batch - ) - loss_spd_sample_topk_batch: Float[Tensor, "sub_batch n_features"] = ( - loss_spd_sample_topk_batch_raw * filtered_active_features_batch - ) - - # Count the number of times each feature was active - feature_counts += filtered_active_features_batch.sum(dim=0) - - # Add to the losses - loss_target += loss_target_batch.sum(dim=0) - loss_spd_batch_topk += loss_spd_batch_topk_batch.sum(dim=0) - loss_spd_sample_topk += loss_spd_sample_topk_batch.sum(dim=0) - n_samples_acc += batch.shape[0] - - if n_samples_acc > print_every_counter: - print(f"n_samples_acc: {n_samples_acc}, n_samples: {n_samples}") - print_every_counter += 10_000 - - # Normalize the losses by the number of samples each feature was active for and n_features - loss_target /= feature_counts * spd_model.config.n_features - loss_spd_batch_topk /= feature_counts * spd_model.config.n_features - loss_spd_sample_topk /= feature_counts * spd_model.config.n_features - - print(f"n_samples_acc: {n_samples_acc}, n_samples: {n_samples}") - return ( - loss_target.detach().cpu(), - loss_spd_batch_topk.detach().cpu(), - loss_spd_sample_topk.detach().cpu(), - ) - - -def collect_average_components_per_feature( - model_fn: Callable[ - [Float[Tensor, "batch n_instances n_features"]], - SPDOutputs, - ], - dataset: ResidualMLPDataset, - device: str, - n_features: int, - batch_size: int, - n_samples: int, - exactly_one_active: bool = True, -) -> Float[Tensor, " n_features"]: - """Collect the average number of components active per feature when that feature is active.""" - # Initialize counters - feature_active_count = torch.zeros(n_features, device=device) - component_active_count = torch.zeros(n_features, device=device) - - for _ in tqdm(range(n_samples // batch_size)): - # Generate batch with no zero samples - batch = dataset._generate_multi_feature_batch(batch_size) - batch = batch.to(device) - - # n_instances should be 1 - assert batch.shape[1] == 1 - - # Get which components were active for each feature - topk_mask_raw: Float[Tensor, "batch n_instances C"] = model_fn(batch).mask - - batch: Float[Tensor, "batch n_features"] = batch.squeeze(1) - topk_mask: Float[Tensor, "batch C"] = topk_mask_raw.squeeze(1) - - active_features_batch: Float[Tensor, "batch n_features"] = (batch != 0).float() - - if exactly_one_active: - # Get the indices of samples where there is exactly one feature active - exactly_one_active_features_batch: Float[Tensor, "batch n_features"] = ( - active_features_batch.sum(dim=-1) == 1 - ) - # Filter the batch to only include the samples where there is exactly one feature active - # batch = batch[exactly_one_active_features_batch] - active_features_batch = active_features_batch[exactly_one_active_features_batch] - topk_mask = topk_mask[exactly_one_active_features_batch] - - # Get the number of components active for each sample - n_components_active: Float[Tensor, "batch 1"] = topk_mask.sum(dim=-1, keepdim=True) - - # Multiply the number of components active by the number of features active - n_components_active_times_features_active_batch: Float[Tensor, "batch n_features"] = ( - n_components_active * active_features_batch - ) - - # Sum over the batch - n_components_active_times_features_active = ( - n_components_active_times_features_active_batch.sum(dim=0) - ) - - # Get the number of features that were active - n_features_active = active_features_batch.sum(dim=0) - - # Add to the counters - feature_active_count += n_features_active - component_active_count += n_components_active_times_features_active - - # Calculate average components per feature - avg_components = component_active_count / feature_active_count - return avg_components - - -def analyze_per_feature_performance( - model_fn: Callable[[Float[Tensor, "batch n_instances"]], Float[Tensor, "batch n_instances"]], - model_config: ResidualMLPConfig | ResidualMLPSPDConfig, - device: str, - batch_size: int = 128, - target_model_fn: Callable[ - [Float[Tensor, "batch n_instances"]], Float[Tensor, "batch n_instances"] - ] - | None = None, -) -> Float[Tensor, " n_features"]: - """For each feature, run a bunch where only that feature varies, then measure loss""" - n_features = model_config.n_features - n_instances = model_config.n_instances - losses = torch.zeros(model_config.n_features) - assert n_instances == 1, "Only one instance supported for now" - label_fn = F.relu if model_config.act_fn_name == "relu" else F.gelu - for i in range(model_config.n_features): - batch_i = torch.zeros((batch_size, n_instances, n_features), device=device) - batch_i[:, 0, i] = torch.linspace(-1, 1, batch_size) - model_output = model_fn(batch_i) - if target_model_fn is not None: - # Get the labels from the target model if it's provided - labels_i = target_model_fn(batch_i) - else: - labels_i = torch.zeros((batch_size, n_instances, n_features), device=device) - labels_i[:, 0, i] = batch_i[:, 0, i] + label_fn(batch_i[:, 0, i]) - loss = F.mse_loss(model_output, labels_i) - losses[i] = loss.item() - losses = losses.detach().cpu() - return losses - - -def plot_per_feature_performance( - losses: Float[Tensor, " n_features"], - sorted_indices: Float[Tensor, " n_features"] | None = None, - ax: plt.Axes | None = None, - label: str | None = None, - color: str = "C0", - zorder: int = 0, - show_xticks: bool = False, -): - sorted_indices = sorted_indices if sorted_indices is not None else losses.argsort() - # Plot the losses as bar chart with x labels corresponding to feature index - if ax is None: - fig, ax = plt.subplots(figsize=(15, 5)) - features = torch.arange(losses.shape[0]) - ax.bar(features, losses[sorted_indices], label=label, zorder=zorder, color=color) - if show_xticks: - ax.set_xticks(features, features[sorted_indices].numpy(), fontsize=6, rotation=90) - else: - ax.set_xticks([]) - ax.set_xticklabels([]) - ax.set_xlabel("Input feature index (sorted by target model MSE)") - ax.set_ylabel("MSE w.r.t true labels") - - -def plot_virtual_weights_target_spd( - target_model: ResidualMLPModel, model: ResidualMLPSPDModel, device: str -): - target_virtual_weights = calculate_virtual_weights(target_model, device) - spd_virtual_weights = spd_calculate_virtual_weights(model=model, device=device) - instance_idx = 0 - fig = plt.figure(constrained_layout=True, figsize=(10, 2 * model.config.C + 8)) - gs = fig.add_gridspec(ncols=2, nrows=model.config.C + 1 + 2) - ax_ID = fig.add_subplot(gs[:2, :]) - W_E_W_U = einops.einsum( - target_virtual_weights["W_E"][instance_idx], - target_virtual_weights["W_U"][instance_idx], - "n_features1 d_embed, d_embed n_features2 -> n_features1 n_features2", - ) - plot_matrix( - ax_ID, - W_E_W_U, - "Virtual weights $W_E W_U$", - "Features", - "Features", - colorbar_format="%.2f", - ) - norm = Normalize(vmin=-1, vmax=1) - ax1 = fig.add_subplot(gs[2, 0]) - ax2 = fig.add_subplot(gs[2, 1]) - in_conns = target_virtual_weights["in_conns"][instance_idx].cpu().detach() - out_conns = target_virtual_weights["out_conns"][instance_idx].cpu().detach() - plot_matrix( - ax1, - in_conns.T, - "Virtual input weights $(W_E W_{in})^T$", - "Features", - "(Target Model) Neurons", - colorbar_format="%.2f", - norm=norm, - ) - plot_matrix( - ax2, - out_conns, - "Virtual output weights $W_{out} W_U$", - "Features", - "Neurons", - colorbar_format="%.2f", - norm=norm, - ) - for c in range(model.config.C): - ax1 = fig.add_subplot(gs[3 + c, 0]) - ax2 = fig.add_subplot(gs[3 + c, 1]) - plot_matrix( - ax1, - spd_virtual_weights["in_conns"][instance_idx, c].T, - "$(W_E W_{in})^T$", - "Features", - f"c={c} Neurons", - colorbar_format="%.2f", - norm=norm, - ) - plot_matrix( - ax2, - spd_virtual_weights["out_conns"][instance_idx, c], - "$W_{out} W_U$", - "Features", - "Neurons", - colorbar_format="%.2f", - norm=norm, - ) - return fig - - -def plot_resid_vs_mlp_out( - target_model: ResidualMLPModel, - device: str, - ax: plt.Axes, - topk_model_fn: Callable[ - [ - Float[Tensor, "batch n_instances n_features"], - Float[Tensor, "batch n_instances C"] | None, - ], - SPDOutputs, - ] - | None = None, - subnet_indices: Float[Tensor, " C"] | None = None, - instance_idx: int = 0, - feature_idx: int = 0, -): - tied_weights = True - if not torch.allclose(target_model.W_U.data, target_model.W_E.data.transpose(-2, -1)): - print("Warning: W_E and W_U are not tied") - tied_weights = False - batch_size = 1 - batch_idx = 0 - n_instances = target_model.config.n_instances - n_features = target_model.config.n_features - batch = torch.zeros(batch_size, n_instances, n_features, device=device) - batch[:, instance_idx, feature_idx] = 1 - # Target model full output - out = target_model(batch)[batch_idx, instance_idx, :].cpu().detach() - # Target model residual stream contribution - W_E = target_model.W_E[instance_idx].cpu().detach() - W_U = target_model.W_U[instance_idx].cpu().detach() - W_EU = einops.einsum(W_E, W_U, "f1 d_mlp, d_mlp f2 -> f1 f2")[feature_idx, :] - # Compute MLP-out - mlp_out = out - W_EU - # Mask for noise & correlation - mask = torch.ones_like(out).bool() - mask[feature_idx] = False - noise_out = F.mse_loss(out[mask], torch.zeros_like(out[mask])).item() - corr = np.corrcoef(mlp_out[mask], W_EU[mask])[0, 1] - ax.axhline(0, color="grey", linestyle="-", lw=0.5) - ax.plot([], [], c="white", label=f"Full target model noise level ~ {noise_out:.2e}") - ax.plot( - mlp_out, - color="C0", - label=f"Target MLP output.\n" - f"Corr w/ resid (excluding feature {feature_idx}): {corr:.2f}", - lw=2, - ) - noise_W_EU = F.mse_loss(W_EU[mask], torch.zeros_like(W_EU[mask])).item() - ax.plot( - W_EU, - color="C1", - label=f"Target resid contribution (W_E W_U)\n" f"Noise level ~ {noise_W_EU:.2e}", - ) - # If topk_model_fn is provided, use it to get the SPD model output - if topk_model_fn is not None: - # Get the SPD resid contribution by running with no subnetworks. This should be equivalent - # to W_E W_U and but doesn't require access to the ResidMLP SPD model. - topk_mask = torch.zeros_like(batch) - spd_WEU = topk_model_fn(batch, topk_mask).spd_model_masked_output[ - batch_idx, instance_idx, : - ] - spd_WEU = spd_WEU.detach().cpu() - if tied_weights: - assert torch.allclose(spd_WEU, W_EU), "Tied weights but W_EU != SPD resid contribution" - else: - ax.plot( - spd_WEU, - color="C4", - label="SPD resid contribution (no subnets).\n" - "Note that embeddings are untied and numbers in legend are not applicable", - ls=":", - ) - # Get SPD forward pass, either from subnet_indices or attribution-based topk_mask - if subnet_indices is None: - topk_mask = None - else: - topk_mask = torch.zeros_like(batch) - topk_mask[:, :, subnet_indices] = 1 - topk_out = topk_model_fn(batch, topk_mask).spd_model_masked_output[ - batch_idx, instance_idx, : - ] - topk_mlp_out = topk_out.detach().cpu() - spd_WEU - topk_mlp_out_mse = F.mse_loss(topk_mlp_out, mlp_out).item() - corr = np.corrcoef(topk_mlp_out[mask], W_EU[mask])[0, 1] - ax.plot( - topk_mlp_out, - color="C2", - label=f"SPD MLP output (topk) MSE: {topk_mlp_out_mse:.1e}.\n" - f"Corr w/ resid (excluding feature {feature_idx}): {corr:.2f}", - ls="--", - ) - # Full forward pass - topk_mask = torch.ones_like(batch) - full_out = topk_model_fn(batch, topk_mask).spd_model_masked_output[ - batch_idx, instance_idx, : - ] - full_mlp_out = full_out.detach().cpu() - spd_WEU - full_mlp_out_mse = F.mse_loss(full_mlp_out, mlp_out).item() - corr = np.corrcoef(full_mlp_out[mask], W_EU[mask])[0, 1] - ax.plot( - full_mlp_out, - color="C3", - label=f"SPD MLP output (full) MSE: {full_mlp_out_mse:.1e}.\n" - f"Corr w/ resid (excluding feature {feature_idx}): {corr:.2f}", - ls=":", - ) - # Can we scale W_EU by a scalar to make it match the model output in mask? - # def difference(alpha): - # return F.mse_loss( float(alpha) * W_EU[mask], out[mask]) - # from scipy.optimize import minimize - # res = minimize(difference, x0=0.1, method="Nelder-Mead") - # ax.plot(W_EU * float(res.x[0]), color="C2", label="Scaled W_E W_U") - ax.legend() - ax.set_title(f"Instance {instance_idx}, feature {feature_idx}") - - -def plot_per_feature_performance_fig( - loss_target: Float[Tensor, " n_features"], - loss_spd_batch_topk: Float[Tensor, " n_features"], - loss_spd_sample_topk: Float[Tensor, " n_features"], - config: Config, - color_map: dict[str, str], -) -> plt.Figure: - fig, axs = plt.subplots(2, 1, figsize=(15, 10)) - axs = np.array(axs) - - indices = loss_target.argsort() - topk = int(config.topk) if config.topk is not None and config.topk == 1 else config.topk - plot_per_feature_performance( - losses=loss_spd_batch_topk, - sorted_indices=indices, - ax=axs[1], - label=f"APD (per-batch top-k={topk})", - color=color_map["apd_topk"], - ) - - plot_per_feature_performance( - losses=loss_target, - sorted_indices=indices, - ax=axs[1], - label="Target model", - color=color_map["target"], - ) - axs[1].legend(loc="upper left") - - plot_per_feature_performance( - losses=loss_spd_sample_topk, - sorted_indices=indices, - ax=axs[0], - label="APD (per-sample top-k=1)", - color=color_map["apd_topk"], - ) - plot_per_feature_performance( - losses=loss_target, - sorted_indices=indices, - ax=axs[0], - label="Target model", - color=color_map["target"], - ) - - axs[0].legend(loc="upper left", fontsize=12) - axs[1].legend(loc="upper left", fontsize=12) - - # Use the max y-axis limit for both subplots - max_ylim = max(axs[0].get_ylim()[1], axs[1].get_ylim()[1]) - axs[0].set_ylim(0, max_ylim) - axs[1].set_ylim(0, max_ylim) - - # Remove the top and right spines - for ax in axs: - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - # # Increase the fontsize of the xlabel and ylabel - for ax in axs: - ax.xaxis.label.set_fontsize(12) - ax.yaxis.label.set_fontsize(12) - - return fig - - -def plot_avg_components_scatter( - losses_spd_wrt_target: Float[Tensor, " n_features"], - avg_components: Float[Tensor, " n_features"], -) -> plt.Figure: - fig, ax = plt.subplots(figsize=(15, 5)) - ax.scatter( - losses_spd_wrt_target.abs().detach().cpu(), - avg_components.detach().cpu(), - ) - ax.set_xlabel("MSE between APD (per-sample top-k=1) and target model outputs") - ax.set_ylabel("Average number of active components") - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - # Increase the fontsize of the xlabel and ylabel - ax.xaxis.label.set_fontsize(12) - ax.yaxis.label.set_fontsize(12) - return fig - - -@dataclass -class ScrubbedLosses: - loss_scrubbed: Float[Tensor, " n_features"] - loss_antiscrubbed: Float[Tensor, " n_features"] - loss_random: Float[Tensor, " n_features"] - loss_spd: Float[Tensor, " n_features"] - loss_zero: Float[Tensor, " n_features"] - loss_monosemantic: Float[Tensor, " n_features"] - n_samples: int - - -def get_scrubbed_losses( - dataset: ResidualMLPDataset, - model: ResidualMLPSPDModel, - target_model: ResidualMLPModel, - device: str, - config: Config, - top1_model_fn: Callable[ - [ - Float[Tensor, "batch n_instances n_features"], - Float[Tensor, "batch n_instances C"] | None, - ], - SPDOutputs, - ], - spd_model_fn: Callable[ - [ - Float[Tensor, "batch n_instances n_features"], # batch - PositiveFloat | None, # topk - bool, # batch_topk - ], - SPDOutputs, - ], - n_batches: int, -) -> ScrubbedLosses: - assert model.config.n_instances == 1, "Can only handle n_instances = 1 for now" - - # Dictionary feature_idx -> subnet_idx - subnet_indices = get_feature_subnet_map(top1_model_fn, device, model.config, instance_idx=0) - - batch_size = config.batch_size - - n_samples = 0 - # Initialize tensors to store all losses - all_loss_scrubbed = [] - all_loss_antiscrubbed = [] - all_loss_random = [] - all_loss_spd = [] - all_loss_zero = [] - all_loss_monosemantic = [] - for _ in tqdm(range(n_batches)): - # In the future this will be merged into generate_batch - batch = dataset._generate_multi_feature_batch(batch_size) - if isinstance(dataset, ResidualMLPDataset) and dataset.label_fn is not None: - labels = dataset.label_fn(batch) - else: - labels = batch.clone().detach() - - # Count the number of samples in which there is at least one active feature - n_samples += int((batch != 0).any(dim=-1).sum().item()) - - batch = batch.to(device) - active_features = torch.where(batch != 0) - # Randomly assign 0 or 1 to topk mask - random_topk_mask = torch.randint( - 0, 2, (batch_size, model.config.n_instances, model.config.C) - ) - scrubbed_topk_mask = torch.randint( - 0, 2, (batch_size, model.config.n_instances, model.config.C) - ) - antiscrubbed_topk_mask = torch.randint( - 0, 2, (batch_size, model.config.n_instances, model.config.C) - ) - for b, i, f in zip(*active_features, strict=False): - s = subnet_indices[f.item()] - scrubbed_topk_mask[b, i, s] = 1 - antiscrubbed_topk_mask[b, i, s] = 0 - topk = config.topk - batch_topk = config.batch_topk - - out_spd = spd_model_fn(batch, topk, batch_topk).spd_model_masked_output - out_random = top1_model_fn(batch, random_topk_mask).spd_model_masked_output - out_scrubbed = top1_model_fn(batch, scrubbed_topk_mask).spd_model_masked_output - out_antiscrubbed = top1_model_fn(batch, antiscrubbed_topk_mask).spd_model_masked_output - out_target = target_model(batch) - # Monosemantic baseline - out_monosemantic = batch.clone() - d_mlp = target_model.config.d_mlp * target_model.config.n_layers # type: ignore - out_monosemantic[..., :d_mlp] = labels[..., :d_mlp] - - # Calc MSE losses - all_loss_scrubbed.append( - ((out_scrubbed - out_target) ** 2).mean(dim=-1).flatten().detach().cpu() - ) - all_loss_antiscrubbed.append( - ((out_antiscrubbed - out_target) ** 2).mean(dim=-1).flatten().detach().cpu() - ) - all_loss_random.append( - ((out_random - out_target) ** 2).mean(dim=-1).flatten().detach().cpu() - ) - all_loss_spd.append(((out_spd - out_target) ** 2).mean(dim=-1).flatten().detach().cpu()) - all_loss_zero.append( - ((torch.zeros_like(out_target) - out_target) ** 2).mean(dim=-1).flatten().detach().cpu() - ) - all_loss_monosemantic.append( - ((out_monosemantic - out_target) ** 2).mean(dim=-1).flatten().detach().cpu() - ) - - # Concatenate all batches - loss_scrubbed = torch.cat(all_loss_scrubbed) - loss_antiscrubbed = torch.cat(all_loss_antiscrubbed) - loss_random = torch.cat(all_loss_random) - loss_spd = torch.cat(all_loss_spd) - loss_zero = torch.cat(all_loss_zero) - loss_monosemantic = torch.cat(all_loss_monosemantic) - - print(f"Loss SPD: {loss_spd.mean().item():.6f}") - print(f"Loss scrubbed: {loss_scrubbed.mean().item():.6f}") - print(f"Loss antiscrubbed: {loss_antiscrubbed.mean().item():.6f}") - print(f"Loss monosemantic: {loss_monosemantic.mean().item():.6f}") - print(f"Loss random: {loss_random.mean().item():.6f}") - print(f"Loss zero: {loss_zero.mean().item():.6f}") - return ScrubbedLosses( - loss_scrubbed=loss_scrubbed, - loss_antiscrubbed=loss_antiscrubbed, - loss_random=loss_random, - loss_spd=loss_spd, - loss_zero=loss_zero, - loss_monosemantic=loss_monosemantic, - n_samples=n_samples, - ) - - -def plot_scrub_losses( - losses: ScrubbedLosses, config: Config, color_map: dict[str, str], n_batches: int -) -> plt.Figure: - fig, ax = plt.subplots(figsize=(15, 5)) - log_bins: list[float] = np.geomspace(1e-7, losses.loss_zero.max().item(), 50).tolist() # type: ignore - ax.hist( - losses.loss_spd, - bins=log_bins, - label="APD (top-k)", - histtype="step", - lw=2, - color=color_map["apd_topk"], - ) - ax.axvline(losses.loss_spd.mean().item(), color=color_map["apd_topk"], linestyle="--") - ax.hist( - losses.loss_scrubbed, - bins=log_bins, - label="APD (scrubbed)", - histtype="step", - lw=2, - color=color_map["apd_scrubbed"], - ) - ax.axvline(losses.loss_scrubbed.mean().item(), color=color_map["apd_scrubbed"], linestyle="--") - ax.hist( - losses.loss_antiscrubbed, - bins=log_bins, - label="APD (anti-scrubbed)", - histtype="step", - lw=2, - color=color_map["apd_antiscrubbed"], - ) - ax.axvline( - losses.loss_antiscrubbed.mean().item(), color=color_map["apd_antiscrubbed"], linestyle="--" - ) - # ax.hist(loss_random, bins=log_bins, label="APD (random)", histtype="step") - # ax.hist(loss_zero, bins=log_bins, label="APD (zero)", histtype="step") - ax.axvline( - losses.loss_monosemantic.mean().item(), - color=color_map["baseline_monosemantic"], - linestyle="-", - label="Monosemantic neuron solution", - ) - ax.legend() - ax.set_ylabel("Count") - ax.set_xlabel("MSE loss with target model output") - ax.set_xscale("log") - - # Remove spines - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - # Increase fontsize - ax.xaxis.label.set_fontsize(12) - ax.yaxis.label.set_fontsize(12) - return fig - - -def plot_feature_response_with_subnets( - topk_model_fn: Callable[ - [Float[Tensor, "batch n_instances n_features"], Float[Tensor, "batch n_instances C"]], - SPDOutputs, - ], - device: str, - model_config: ResidualMLPSPDConfig, - feature_idx: int = 0, - subnet_idx: int = 0, - instance_idx: int = 0, - ax: plt.Axes | None = None, - batch_size: int | None = None, - plot_type: Literal["line", "errorbar"] = "errorbar", - color_map: dict[str, str] | None = None, -) -> dict[str, plt.Figure]: - n_instances = model_config.n_instances - n_features = model_config.n_features - batch_size = batch_size or n_features - C = model_config.C - - if color_map is None: - color_map = { - "apd_topk": "C1", - "apd_scrubbed": "C4", - "apd_antiscrubbed": "C2", - "target": "C0", - } - - if ax is None: - _, ax = plt.subplots(constrained_layout=True, figsize=(10, 5)) - fig = ax.figure - - batch = torch.zeros(batch_size, n_instances, n_features, device=device) - batch[:, instance_idx, feature_idx] = 1 - topk_mask_blue = torch.zeros(batch_size, n_instances, C, device=device) - topk_mask_red = torch.zeros(batch_size, n_instances, C, device=device) - topk_mask_blue[:, :, subnet_idx] = 1 - for s in range(batch_size): - # Randomly ablate half the features - half_n_features = n_features // 2 - choice = torch.randperm(n_features - 1)[:half_n_features] - # Exclude feature_idx from choice - choice[choice >= subnet_idx] += 1 - topk_mask_blue[s, :, choice] = 1 - topk_mask_red[s, :, choice] = 1 - assert torch.allclose( - topk_mask_blue[:, :, subnet_idx], torch.ones_like(topk_mask_blue[:, :, subnet_idx]) - ) - assert torch.allclose( - topk_mask_red[:, :, subnet_idx], torch.zeros_like(topk_mask_red[:, :, subnet_idx]) - ) - zeros_topk_mask = torch.zeros(batch_size, n_instances, C, device=device) - ones_topk_mask = torch.ones(batch_size, n_instances, C, device=device) - out_WE_WU_only = topk_model_fn(batch, zeros_topk_mask).spd_model_masked_output[ - :, instance_idx, : - ] - - out_red = topk_model_fn(batch, topk_mask_red) - out_blue = topk_model_fn(batch, topk_mask_blue) - out_spd = topk_model_fn(batch, ones_topk_mask).spd_model_masked_output[:, instance_idx, :] - mlp_out_blue_spd = out_blue.spd_model_masked_output[:, instance_idx, :] - out_WE_WU_only - mlp_out_red_spd = out_red.spd_model_masked_output[:, instance_idx, :] - out_WE_WU_only - mlp_out_target = out_blue.target_model_output[:, instance_idx, :] - out_WE_WU_only - mlp_out_spd = out_spd - out_WE_WU_only - - x = torch.arange(n_features) - - if plot_type == "errorbar": - # Calculate means and stds across batch dimension - blue_mean = mlp_out_blue_spd.mean(dim=0).detach().cpu() - blue_std = mlp_out_blue_spd.std(dim=0).detach().cpu() - red_mean = mlp_out_red_spd.mean(dim=0).detach().cpu() - red_std = mlp_out_red_spd.std(dim=0).detach().cpu() - mlp_out_spd_mean = mlp_out_spd.mean(dim=0).detach().cpu() - - # Plot errorbars - ax.errorbar( - x, - blue_mean, - yerr=blue_std, - color=color_map["apd_scrubbed"], - label="APD (scrubbed)", - fmt="o", - markersize=2, - ) - ax.errorbar( - x, - red_mean, - yerr=red_std, - color=color_map["apd_antiscrubbed"], - label="APD (anti-scrubbed)", - fmt="o", - markersize=2, - ) - - # Plot target model output - yt = mlp_out_target[0, :].detach().cpu() - ax.scatter( - x, yt, color=color_map["target"], label="Target model\n≈APD (top-k)", marker="x", s=10 - ) - # Plot non-scrubbed SPD model output - # ax.scatter( - # x, - # mlp_out_spd_mean.detach().cpu(), - # color=color_map["apd_topk"], - # label="APD (top-k)", - # marker=".", - # s=10, - # ) - # Remove all axes lines - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xticks([]) - elif plot_type == "line": - cmap1 = plt.get_cmap("Purples") - cmap2 = plt.get_cmap("Oranges") - for s in range(batch_size): - yb = mlp_out_blue_spd[s, :].detach().cpu() - yr = mlp_out_red_spd[s, :].detach().cpu() - if plot_type == "line": - ax.plot(x, yb, color=cmap1(s / batch_size), lw=0.3) - ax.plot(x, yr, color=cmap2(s / batch_size), lw=0.3) - ax.plot([], [], color=cmap1(0), label="APD (scrubbed)") - ax.plot([], [], color=cmap2(0), label="APD (anti-scrubbed)") - yt = mlp_out_target[0, :].detach().cpu() - ax.plot(x, yt, color="red", lw=0.5, label="Target model") - else: - raise ValueError(f"Invalid plot type: {plot_type}") - - ax.set_ylabel("MLP output (forward pass minus W_E W_U contribution)") - ax.set_xlabel("Output index") - - # I only need 0, feature_idx, and 100 as x ticks - ax.set_xticks([0, feature_idx, 100]) - ax.set_xticklabels(["0", str(feature_idx), "100"]) - # ax.set_title(f"APD model when ablating parameter components. One-hot $x_{{{feature_idx}}}=1$") - ax.legend() - assert isinstance(fig, plt.Figure) - return {"feature_response_with_subnets": fig} - - -def get_feature_subnet_map( - top1_model_fn: Callable[ - [ - Float[Tensor, "batch n_instances n_features"], - Float[Tensor, "batch n_instances C"] | None, - ], - SPDOutputs, - ], - device: str, - model_config: ResidualMLPConfig | ResidualMLPSPDConfig, - instance_idx: int = 0, -) -> dict[int, int]: - n_instances = model_config.n_instances - n_features = model_config.n_features - batch_size = n_features - batch = torch.zeros(batch_size, n_instances, n_features, device=device) - batch[torch.arange(n_features), instance_idx, torch.arange(n_features)] = 1 - top1_out = top1_model_fn(batch, None) - top1_mask = top1_out.mask[:, instance_idx, :] - subnet_indices = { - int(feature_idx.item()): int(subnet_idx.item()) - for feature_idx, subnet_idx in top1_mask.nonzero() - } - return subnet_indices diff --git a/spd/experiments/resid_mlp/resid_mlp_topk_config.yaml b/spd/experiments/resid_mlp/resid_mlp_config.yaml similarity index 81% rename from spd/experiments/resid_mlp/resid_mlp_topk_config.yaml rename to spd/experiments/resid_mlp/resid_mlp_config.yaml index 669d88f..2772de0 100644 --- a/spd/experiments/resid_mlp/resid_mlp_topk_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp_config.yaml @@ -4,26 +4,21 @@ wandb_run_name: null wandb_run_name_prefix: "" unit_norm_matrices: true seed: 0 -# topk: 1 -topk: 1.28 -m: null -C: 130 -pnorm: null -batch_topk: true +m: 50 param_match_coeff: 1.0 -topk_recon_coeff: 1.0 +masked_recon_coeff: 1.0 act_recon_coeff: 1.0 post_relu_act_recon: true -schatten_pnorm: 0.9 -schatten_coeff: 1e1 -lr: 1e-3 +pnorm: 0.9 +lp_sparsity_coeff: 1.0 batch_size: 256 steps: 10_000 -print_freq: 500 image_freq: 5_000 +print_freq: 500 save_freq: 10_000 -lr_warmup_pct: 0.01 +lr: 1e-3 lr_schedule: cosine +lr_warmup_pct: 0.01 image_on_first_step: false task_config: task_name: residual_mlp @@ -39,24 +34,21 @@ task_config: # wandb_run_name_prefix: "" # unit_norm_matrices: false # seed: 0 -# topk: 1.28 # bs=256 -# m: null -# C: 200 -# pnorm: null -# batch_topk: true +# m: 25 # param_match_coeff: 1.0 -# topk_recon_coeff: 2.0 +# masked_recon_coeff: 2.0 # act_recon_coeff: 1.0 -# schatten_pnorm: 0.9 -# schatten_coeff: 7 -# lr: 1e-3 +# post_relu_act_recon: true +# pnorm: 0.9 +# lp_sparsity_coeff: 1.0 # batch_size: 256 # steps: 10_000 -# print_freq: 500 # image_freq: 10_000 +# print_freq: 500 # save_freq: 10_000 -# lr_warmup_pct: 0.01 +# lr: 1e-3 # lr_schedule: cosine +# lr_warmup_pct: 0.01 # image_on_first_step: false # task_config: # task_name: residual_mlp diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 4f2c7cc..e90f65c 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -13,36 +13,21 @@ import wandb import yaml from jaxtyping import Float -from matplotlib.colors import CenteredNorm -from pydantic import PositiveFloat from torch import Tensor from tqdm import tqdm -from spd.attributions import collect_subnetwork_attributions from spd.configs import Config, ResidualMLPTaskConfig from spd.experiments.resid_mlp.models import ( ResidualMLPModel, ResidualMLPSPDConfig, ResidualMLPSPDModel, ) -from spd.experiments.resid_mlp.plotting import ( - analyze_per_feature_performance, - plot_individual_feature_response, - plot_per_feature_performance, - plot_spd_relu_contribution, - plot_virtual_weights_target_spd, - spd_calculate_diag_relu_conns, -) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset from spd.log import logger -from spd.module_utils import collect_nested_module_attrs -from spd.plotting import plot_subnetwork_attributions_statistics, plot_subnetwork_correlations from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( - COLOR_PALETTE, DatasetGeneratedDataLoader, load_config, - run_spd_forward_pass, set_seed, ) from spd.wandb_utils import init_wandb @@ -66,35 +51,11 @@ def get_run_name( else: run_suffix = get_common_run_name_suffix(config) run_suffix += f"scale{init_scale}_ft{n_features}_lay{n_layers}_resid{d_resid}_mlp{d_mlp}" - if m is not None: - run_suffix += f"_m{m}" return config.wandb_run_name_prefix + run_suffix -def calc_n_active_features_per_subnet( - model: ResidualMLPSPDModel, cutoff: float, device: str -) -> tuple[Float[Tensor, "n_instances C"], Float[Tensor, "n_instances n_features"]]: - """Calculate the number of active features per subnet (and per instance if n_instances > 1).""" - n_active_features_per_subnet: Float[Tensor, "n_instances C"] = torch.zeros( - (model.config.n_instances, model.C), device=device - ) - active_feature_counts_per_subnet: Float[Tensor, "n_instances n_features"] = torch.zeros( - (model.config.n_instances, model.config.n_features), device=device - ) - for c in range(model.C): - relu_conns: Float[Tensor, "n_instances n_features d_mlp"] = spd_calculate_diag_relu_conns( - model, device, k_select=c - ) - # Count the number of features for which each subnet fires beyond the cutoff - above_cutoff = relu_conns.max(dim=-1).values > cutoff - n_active_features_per_subnet[:, c] = above_cutoff.sum(dim=-1) - active_feature_counts_per_subnet[:] += above_cutoff - - return n_active_features_per_subnet, active_feature_counts_per_subnet - - def plot_subnetwork_attributions( - attribution_scores: Float[Tensor, "batch n_instances C"], + attribution_scores: Float[Tensor, "batch n_instances m"], out_dir: Path | None, step: int | None, ) -> plt.Figure: @@ -139,130 +100,6 @@ def plot_subnetwork_attributions( return fig -def plot_multiple_component_weights( - model: ResidualMLPSPDModel, - out_dir: Path | None, - step: int | None = None, -) -> plt.Figure: - """Plot each component weight matrix.""" - all_params = collect_nested_module_attrs(model, "component_weights") - # Each param (of which there are n_layers): [k, n_features, n_features] - n_params = len(all_params) - param_names = list(all_params.keys()) - n_instances = model.config.n_instances - C = model.C - - # Find global min and max for normalization - all_values = [] - for param_name in param_names: - param_values = all_params[param_name].detach().cpu().numpy() - all_values.append(param_values) - all_values_concat = np.concatenate([v.flatten() for v in all_values]) - vmax = np.abs(all_values_concat).max() - norm = CenteredNorm(vcenter=0, halfrange=vmax) - - fig, axs = plt.subplots( - n_instances * n_params, - C, - figsize=(2 * C, n_instances * n_params), - constrained_layout=False, - ) - axs = np.array(axs) - - for instance_idx in range(n_instances): - for param_idx in range(n_params): - param_name = param_names[param_idx] - for subnet_idx in range(C): - col_idx = subnet_idx - row_idx = instance_idx * n_params + param_idx - - ax = axs[row_idx, col_idx] # type: ignore - param = all_params[param_name][instance_idx, subnet_idx].detach().cpu().numpy() - # If it's a bias with a single dimension, unsqueeze it - if param.ndim == 1: - param = param[:, None] - - # Set aspect ratio based on parameter dimensions - height, width = param.shape - aspect = width / height - - im = ax.matshow(param, cmap="RdBu", norm=norm, aspect=aspect) - ax.set_xticks([]) - ax.set_yticks([]) - - if col_idx == 0: - ax.set_ylabel( - f"Inst.{instance_idx}.{param_name}", - rotation=0, - ha="right", - va="center", - ) - - if row_idx == ((n_instances * n_params) - 1): - ax.set_xlabel(f"Subnet {subnet_idx}", rotation=0, ha="center", va="top") - - # Add colorbar - fig.colorbar(im, ax=axs.ravel().tolist(), location="right") # type: ignore - - title_text = "Subnet Parameters" - if step is not None: - title_text += f" (Step {step})" - fig.suptitle(title_text) - if out_dir: - fig.savefig(out_dir / f"component_weights_s{step}.png", dpi=200) - return fig - - -def plot_subnet_categories( - model: ResidualMLPSPDModel, device: str, cutoff: float = 4e-2 -) -> plt.Figure: - n_active_features_per_subnet, active_feature_counts_per_subnet = ( - calc_n_active_features_per_subnet(model, cutoff=cutoff, device=device) - ) - n_dead_subnets = (n_active_features_per_subnet == 0).sum(dim=-1).detach().tolist() - n_monosemantic_subnets = (n_active_features_per_subnet == 1).sum(dim=-1).detach().tolist() - n_duosemantic_subnets = (n_active_features_per_subnet == 2).sum(dim=-1).detach().tolist() - n_polysemantic_subnets = (n_active_features_per_subnet > 2).sum(dim=-1).detach().tolist() - n_unique_features_represented = ( - (active_feature_counts_per_subnet > 0).sum(dim=-1).detach().tolist() - ) - - n_instances = len(n_dead_subnets) - fig, ax = plt.subplots( - nrows=1, ncols=n_instances, figsize=(5 * n_instances, 5), constrained_layout=True - ) - axs = np.array([ax]) if n_instances == 1 else np.array(ax) - categories = ["Dead", "Mono", "Duo", "Poly", "Represented"] - - for i in range(n_instances): - counts = [ - n_dead_subnets[i], - n_monosemantic_subnets[i], - n_duosemantic_subnets[i], - n_polysemantic_subnets[i], - n_unique_features_represented[i], - ] - bars = axs[i].bar(categories, counts) - axs[i].set_xlabel("Category") - axs[i].set_ylabel("Count") - axs[i].set_title(f"Subnet Categories (Instance {i})") - - # Add numbers in the middle of the bars - for bar, count in zip(bars, counts, strict=False): - height = bar.get_height() - axs[i].text( - bar.get_x() + bar.get_width() / 2.0, - height / 2.0, - f"{count}", - ha="center", - va="center", - color="white", - fontsize=10, - ) - - return fig - - def resid_mlp_plot_results_fn( model: ResidualMLPSPDModel, target_model: ResidualMLPModel, @@ -270,217 +107,12 @@ def resid_mlp_plot_results_fn( out_dir: Path | None, device: str, config: Config, - topk_mask: Float[Tensor, "batch_size C"] | None, - dataloader: DatasetGeneratedDataLoader[ - tuple[Float[Tensor, "batch n_features"], Float[Tensor, "batch d_embed"]] - ] - | None = None, + masks: dict[str, Float[Tensor, "batch_size m"]] | None, **_, ) -> dict[str, plt.Figure]: assert isinstance(config.task_config, ResidualMLPTaskConfig) fig_dict = {} - fig_dict["subnet_categories"] = plot_subnet_categories(model, device) - - ############################################################################################ - # Feature contributions - ############################################################################################ - fig1, fig2 = plot_spd_relu_contribution(model, target_model, device) - fig1.suptitle("How much does each ReLU contribute to each feature?") - fig2.suptitle("How much does each feature route through each ReLU?") - fig_dict["feature_contributions"] = fig1 - fig_dict["relu_contributions"] = fig2 - - fig1, fig2 = plot_spd_relu_contribution(model, target_model, device, k_plot_limit=3) - fig1.suptitle("How much does each ReLU contribute to each feature?") - fig2.suptitle("How much does each feature route through each ReLU?") - fig_dict["cropped_feature_contributions"] = fig1 - fig_dict["cropped_relu_contributions"] = fig2 - - ############################################################################################ - # Individual feature responses + per-feature performance - ############################################################################################ - def spd_model_fn( - batch: Float[Tensor, "batch n_instances n_features"], - topk: PositiveFloat | None = config.topk, - batch_topk: bool = config.batch_topk, - ) -> Float[Tensor, "batch n_instances n_features"]: - assert topk is not None - if config.exact_topk: - assert model.n_instances == 1, "exact_topk only works if n_instances = 1" - topk = ((batch != 0).sum() / batch.shape[0]).item() - return run_spd_forward_pass( - spd_model=model, - config=config, - target_model=target_model, - input_array=batch, - batch_topk=batch_topk, - topk=topk, - distil_from_target=config.distil_from_target, - ).spd_model_masked_output - - def target_model_fn(batch: Float[Tensor, "batch n_instances"]): - return target_model(batch) - - fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 15), constrained_layout=True) - axes = np.atleast_2d(axes) # type: ignore - plot_individual_feature_response( - model_fn=target_model_fn, - device=device, - model_config=model.config, - ax=axes[0, 0], - ) - plot_individual_feature_response( - model_fn=target_model_fn, - device=device, - model_config=model.config, - sweep=True, - ax=axes[1, 0], - ) - plot_individual_feature_response( - model_fn=spd_model_fn, - device=device, - model_config=model.config, - ax=axes[0, 1], - ) - plot_individual_feature_response( - model_fn=spd_model_fn, - device=device, - model_config=model.config, - sweep=True, - ax=axes[1, 1], - ) - axes[0, 0].set_ylabel(axes[0, 0].get_title()) - axes[1, 0].set_ylabel(axes[1, 0].get_title()) - axes[0, 1].set_ylabel("") - axes[1, 1].set_ylabel("") - axes[0, 0].set_title("Target model") - axes[0, 1].set_title("SPD model") - axes[1, 0].set_title("") - axes[1, 1].set_title("") - axes[0, 0].set_xlabel("") - axes[0, 1].set_xlabel("") - fig_dict["individual_feature_responses"] = fig - - # Plot per-feature performance when setting topk=1 and batch_topk=False - fig, ax1 = plt.subplots(figsize=(15, 5)) - - losses_target = analyze_per_feature_performance( - model_fn=target_model_fn, - model_config=target_model.config, - device=device, - batch_size=config.batch_size, - ) - indices = losses_target.argsort() - fn_without_batch_topk = lambda batch: spd_model_fn(batch, topk=1, batch_topk=False) # type: ignore - losses_spd = analyze_per_feature_performance( - model_fn=fn_without_batch_topk, - model_config=model.config, - device=device, - batch_size=config.batch_size, - ) - - plot_per_feature_performance( - losses=losses_spd, - sorted_indices=indices, - ax=ax1, - label="SPD", - color=COLOR_PALETTE[1], - ) - plot_per_feature_performance( - losses=losses_target, - sorted_indices=indices, - ax=ax1, - label="Target", - color=COLOR_PALETTE[0], - ) - ax1.legend() - - fig_dict["loss_by_feature_topk_1"] = fig - - # Plot per-feature performance when using batch_topk - fig, ax2 = plt.subplots(figsize=(15, 5)) - - target_losses_batch_topk = analyze_per_feature_performance( - model_fn=target_model_fn, - model_config=target_model.config, - device=device, - batch_size=config.batch_size, - ) - - spd_losses_batch_topk = analyze_per_feature_performance( - model_fn=spd_model_fn, - model_config=model.config, - device=device, - batch_size=config.batch_size, - ) - - plot_per_feature_performance( - losses=spd_losses_batch_topk, - sorted_indices=indices, - ax=ax2, - label="SPD", - color=COLOR_PALETTE[1], - ) - plot_per_feature_performance( - losses=target_losses_batch_topk, - sorted_indices=indices, - ax=ax2, - label="Target", - color=COLOR_PALETTE[0], - ) - ax2.legend() - # Use the same y-axis limits as the topk=1 plot - ax2.set_ylim(ax1.get_ylim()) - fig_dict["loss_by_feature_batch_topk"] = fig - - ############################################################################################ - # Virtual weights - ############################################################################################ - - fig = plot_virtual_weights_target_spd(target_model, model, device) - fig_dict["virtual_weights"] = fig - - ############################################################################################ - # Subnetwork attributions - ############################################################################################ - attribution_scores = collect_subnetwork_attributions( - spd_model=model, - config=config, - target_model=target_model, - device=device, - n_instances=model.n_instances, - ) - fig_dict["subnetwork_attributions"] = plot_subnetwork_attributions( - attribution_scores, out_dir, step - ) - - if config.topk is not None: - if dataloader is not None and config.C > 1: - fig_dict_correlations = plot_subnetwork_correlations( - dataloader=dataloader, - target_model=target_model, - spd_model=model, - config=config, - device=device, - ) - fig_dict.update(fig_dict_correlations) - - assert topk_mask is not None - fig_dict_attributions = plot_subnetwork_attributions_statistics(topk_mask=topk_mask) - fig_dict.update(fig_dict_attributions) - - ############################################################################################ - # Subnetwork parameters - ############################################################################################ - - # This can be too big to plot - n_matrix_params = target_model.config.d_mlp * target_model.config.d_embed - if n_matrix_params < 1000: - fig_dict["component_weights"] = plot_multiple_component_weights( - model=model, out_dir=out_dir, step=step - ) - # Save plots to files if out_dir: for k, v in fig_dict.items(): @@ -574,7 +206,6 @@ def main( in_bias=target_model.config.in_bias, out_bias=target_model.config.out_bias, init_scale=config.task_config.init_scale, - C=config.C, m=config.m, ) model = ResidualMLPSPDModel(config=model_config).to(device) diff --git a/spd/experiments/resid_mlp/resid_mlp_sweep_config.yaml b/spd/experiments/resid_mlp/resid_mlp_sweep_config.yaml index 74943d9..19f8c2f 100644 --- a/spd/experiments/resid_mlp/resid_mlp_sweep_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp_sweep_config.yaml @@ -8,11 +8,11 @@ parameters: values: [0] lr: values: [1e-2] - topk_recon_coeff: + masked_recon_coeff: values: [1e-1, 1e-2] command: - ${env} - ${interpreter} - ${program} -- spd/experiments/resid_mlp/resid_mlp_topk_config.yaml \ No newline at end of file +- spd/experiments/resid_mlp/resid_mlp_config.yaml \ No newline at end of file diff --git a/spd/experiments/resid_mlp/spd_interp.py b/spd/experiments/resid_mlp/spd_interp.py deleted file mode 100644 index b86053a..0000000 --- a/spd/experiments/resid_mlp/spd_interp.py +++ /dev/null @@ -1,296 +0,0 @@ -# %% -from pathlib import Path - -import matplotlib.pyplot as plt -import torch -from jaxtyping import Float -from pydantic import PositiveFloat -from torch import Tensor - -from spd.configs import ResidualMLPTaskConfig -from spd.experiments.resid_mlp.models import ResidualMLPModel, ResidualMLPSPDModel -from spd.experiments.resid_mlp.plotting import ( - analyze_per_feature_performance, - collect_average_components_per_feature, - collect_per_feature_losses, - get_feature_subnet_map, - get_scrubbed_losses, - plot_avg_components_scatter, - plot_feature_response_with_subnets, - plot_per_feature_performance_fig, - plot_scrub_losses, - plot_spd_feature_contributions_truncated, -) -from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset -from spd.experiments.resid_mlp.resid_mlp_decomposition import plot_subnet_categories -from spd.settings import REPO_ROOT -from spd.utils import ( - COLOR_PALETTE, - SPDOutputs, - run_spd_forward_pass, - set_seed, -) - -color_map = { - "target": COLOR_PALETTE[0], - "apd_topk": COLOR_PALETTE[1], - "apd_scrubbed": COLOR_PALETTE[4], - "apd_antiscrubbed": COLOR_PALETTE[2], # alt: 3 - "baseline_monosemantic": "grey", -} - -out_dir = REPO_ROOT / "spd/experiments/resid_mlp/out/figures/" -out_dir.mkdir(parents=True, exist_ok=True) - -# %% Loading -device = "cuda" if torch.cuda.is_available() else "cpu" -print(f"Using device: {device}") -set_seed(0) # You can change this seed if needed - -use_data_from_files = True -wandb_path = "wandb:spd-resid-mlp/runs/8qz1si1l" # 1 layer 40k steps (R6) topk=1.28 -# wandb_path = "wandb:spd-resid-mlp/runs/9a639c6w" # 1 layer topk=1 -# wandb_path = "wandb:spd-resid-mlp/runs/cb0ej7hj" # 2 layer 2LR4 topk=1.28 -# wandb_path = "wandb:spd-resid-mlp/runs/wbeghftm" # 2 layer topk=1 -# wandb_path = "wandb:spd-resid-mlp/runs/c1q3bs6f" # 2 layer m=1 topk=1.28 (not in paper) - -wandb_id = wandb_path.split("/")[-1] - -# Load the pretrained SPD model -model, config, label_coeffs = ResidualMLPSPDModel.from_pretrained(wandb_path) -assert isinstance(config.task_config, ResidualMLPTaskConfig) - -# Path must be local -target_model, target_model_train_config_dict, target_label_coeffs = ( - ResidualMLPModel.from_pretrained(config.task_config.pretrained_model_path) -) -# Print some basic information about the model -print(f"Number of features: {model.config.n_features}") -print(f"Feature probability: {config.task_config.feature_probability}") -print(f"Embedding dimension: {model.config.d_embed}") -print(f"MLP dimension: {model.config.d_mlp}") -print(f"Number of layers: {model.config.n_layers}") -print(f"Number of subnetworks (C): {model.config.C}") -model = model.to(device) -label_coeffs = label_coeffs.to(device) -target_model = target_model.to(device) -target_label_coeffs = target_label_coeffs.to(device) -assert torch.allclose(target_label_coeffs, label_coeffs) - -n_layers = target_model.config.n_layers - - -# Functions used for various plots -def spd_model_fn( - batch: Float[Tensor, "batch n_instances n_features"], - topk: PositiveFloat | None = config.topk, - batch_topk: bool = config.batch_topk, -) -> SPDOutputs: - assert topk is not None - return run_spd_forward_pass( - spd_model=model, - config=config, - target_model=target_model, - input_array=batch, - batch_topk=batch_topk, - topk=topk, - distil_from_target=config.distil_from_target, - ) - - -def target_model_fn(batch: Float[Tensor, "batch n_instances"]): - return target_model(batch) - - -def top1_model_fn( - batch: Float[Tensor, "batch n_instances n_features"], - topk_mask: Float[Tensor, "batch n_instances C"] | None, -) -> SPDOutputs: - """Top1 if topk_mask is None, else just use provided topk_mask""" - topk_mask = topk_mask.to(device) if topk_mask is not None else None - assert config.topk is not None - return run_spd_forward_pass( - spd_model=model, - config=config, - target_model=target_model, - input_array=batch, - batch_topk=False, - topk=1, - distil_from_target=config.distil_from_target, - mask=topk_mask, - ) - - -dataset = ResidualMLPDataset( - n_instances=model.config.n_instances, - n_features=model.config.n_features, - feature_probability=config.task_config.feature_probability, - device=device, - calc_labels=True, - label_type=target_model_train_config_dict["label_type"], - act_fn_name=target_model.config.act_fn_name, - label_coeffs=target_label_coeffs, - data_generation_type="at_least_zero_active", # We will change this in the for loop -) - -# %% Plot how many subnets are monosemantic, etc. -fig = plot_subnet_categories(model, device, cutoff=4e-2) -# Save the figure -fig.savefig(out_dir / f"resid_mlp_subnet_categories_{n_layers}layers_{wandb_id}.png") -print(f"Saved figure to {out_dir / f'resid_mlp_subnet_categories_{n_layers}layers_{wandb_id}.png'}") - - -# %% -per_feature_losses_path = Path(out_dir) / f"resid_mlp_losses_{n_layers}layers_{wandb_id}.pt" -if not use_data_from_files or not per_feature_losses_path.exists(): - loss_target, loss_spd_batch_topk, loss_spd_sample_topk = collect_per_feature_losses( - target_model=target_model, - spd_model=model, - config=config, - dataset=dataset, - device=device, - batch_size=config.batch_size, - n_samples=100_000, - ) - # Save the losses to a file - torch.save( - (loss_target, loss_spd_batch_topk, loss_spd_sample_topk), - per_feature_losses_path, - ) - -# Load the losses from a file -loss_target, loss_spd_batch_topk, loss_spd_sample_topk = torch.load( - per_feature_losses_path, weights_only=True, map_location="cpu" -) - -fig = plot_per_feature_performance_fig( - loss_target=loss_target, - loss_spd_batch_topk=loss_spd_batch_topk, - loss_spd_sample_topk=loss_spd_sample_topk, - config=config, - color_map=color_map, -) -fig.show() -fig.savefig(out_dir / f"resid_mlp_per_feature_performance_{n_layers}layers_{wandb_id}.png") -print( - f"Saved figure to {out_dir / f'resid_mlp_per_feature_performance_{n_layers}layers_{wandb_id}.png'}" -) - -# %% -# Scatter plot of avg active components vs loss difference -avg_components_path = Path(out_dir) / f"avg_components_{n_layers}layers_{wandb_id}.pt" -if not use_data_from_files or not avg_components_path.exists(): - avg_components = collect_average_components_per_feature( - model_fn=spd_model_fn, - dataset=dataset, - device=device, - n_features=model.config.n_features, - batch_size=config.batch_size, - n_samples=500_000, - ) - # Save the avg_components to a file - torch.save(avg_components.cpu(), avg_components_path) - -# Load the avg_components from a file -avg_components = torch.load(avg_components_path, map_location=device, weights_only=True) - -# Get the loss of the spd model w.r.t the target model -fn_without_batch_topk = lambda batch: spd_model_fn( - batch, topk=1, batch_topk=False -).spd_model_masked_output # type: ignore -losses_spd_wrt_target = analyze_per_feature_performance( - model_fn=fn_without_batch_topk, - target_model_fn=target_model_fn, - model_config=model.config, - device=device, - batch_size=config.batch_size, -) - -fig = plot_avg_components_scatter( - losses_spd_wrt_target=losses_spd_wrt_target, avg_components=avg_components -) -fig.show() -# Save the figure -fig.savefig(out_dir / f"resid_mlp_avg_components_scatter_{n_layers}layers_{wandb_id}.png") -print( - f"Saved figure to {out_dir / f'resid_mlp_avg_components_scatter_{n_layers}layers_{wandb_id}.png'}" -) - -# %% -# Plot the main truncated feature contributions figure for the paper -fig = plot_spd_feature_contributions_truncated( - spd_model=model, - target_model=target_model, - device=device, - n_features=10, - include_crossterms=False, -) -fig.savefig(out_dir / f"resid_mlp_weights_{n_layers}layers_{wandb_id}.png") -print(f"Saved figure to {out_dir / f'resid_mlp_weights_{n_layers}layers_{wandb_id}.png'}") - -# Full figure for updating wandb report -# fig = plot_spd_feature_contributions( -# spd_model=model, -# target_model=target_model, -# device=device, -# ) -# fig.savefig(out_dir / f"resid_mlp_weights_full_{n_layers}layers_{wandb_id}.png") -# plt.close(fig) -# print(f"Saved figure to {out_dir / f'resid_mlp_weights_full_{n_layers}layers_{wandb_id}.png'}") -# import wandb - -# # Restart the run and log the figure -# run = wandb.init(project="spd-resid-mlp", id=wandb_id, resume="must") -# run.log({"neuron_contributions": wandb.Image(fig)}) -# run.finish() - -# %% -# Plot causal scrubbing-esque test -n_batches = 100 -losses = get_scrubbed_losses( - top1_model_fn=top1_model_fn, - spd_model_fn=spd_model_fn, - target_model=target_model, - dataset=dataset, - model=model, - device=device, - config=config, - n_batches=n_batches, -) - -fig = plot_scrub_losses(losses, config, color_map, n_batches) -fig.savefig( - out_dir / f"resid_mlp_scrub_hist_{n_layers}layers_{wandb_id}.png", bbox_inches="tight", dpi=300 -) -print(f"Saved figure to {out_dir / f'resid_mlp_scrub_hist_{n_layers}layers_{wandb_id}.png'}") - -# %% Linearity test: Enable one subnet after the other -# candlestick plot - -# # Dictionary feature_idx -> subnet_idx -subnet_indices = get_feature_subnet_map(top1_model_fn, device, model.config, instance_idx=0) - -n_features = model.config.n_features -feature_idx = 42 -subtract_inputs = True # TODO TRUE subnet - - -fig = plot_feature_response_with_subnets( - topk_model_fn=top1_model_fn, - device=device, - model_config=model.config, - feature_idx=feature_idx, - subnet_idx=subnet_indices[feature_idx], - batch_size=1000, - plot_type="errorbar", - color_map=color_map, -)["feature_response_with_subnets"] -fig.savefig( # type: ignore - out_dir / f"feature_response_with_subnets_{feature_idx}_{n_layers}layers_{wandb_id}.png", - bbox_inches="tight", - dpi=300, -) -print( - f"Saved figure to {out_dir / f'feature_response_with_subnets_{feature_idx}_{n_layers}layers_{wandb_id}.png'}" -) -plt.show() diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index defe011..4482b22 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -45,18 +45,21 @@ def _tms_forward( linear1: Linear | LinearComponent, linear2: TransposedLinear | TransposedLinearComponent, b_final: Float[Tensor, "n_instances n_features"], - mask: Float[Tensor, "batch n_instances C"] | None = None, + masks: dict[str, Float[Tensor, "batch n_instances m"]] | None = None, hidden_layers: nn.ModuleList | None = None, ) -> Float[Tensor, "batch n_instances n_features"]: """Forward pass used for TMSModel and TMSSPDModel. - Note that topk_mask is only used for TMSSPDModel. + Note that masks have no effect for TMSModel. """ - hidden = linear1(x, mask=mask) + linear1_mask = masks["linear1"] if masks is not None else None + hidden = linear1(x, mask=linear1_mask) if hidden_layers is not None: - for layer in hidden_layers: - hidden = layer(hidden, mask=mask) - out_pre_relu = linear2(hidden, mask=mask) + b_final + for i, layer in enumerate(hidden_layers): + hidden_mask = masks[f"hidden_layers.{i}"] if masks is not None else None + hidden = layer(hidden, mask=hidden_mask) + linear2_mask = masks["linear2"] if masks is not None else None + out_pre_relu = linear2(hidden, mask=linear2_mask) + b_final out = F.relu(out_pre_relu) return out @@ -167,10 +170,9 @@ class TMSSPDModelConfig(BaseModel): n_features: PositiveInt n_hidden: PositiveInt n_hidden_layers: NonNegativeInt - C: PositiveInt | None = None bias_val: float device: str - m: PositiveInt | None = None + m: PositiveInt class TMSSPDModel(SPDModel): @@ -179,10 +181,8 @@ def __init__(self, config: TMSSPDModelConfig): self.config = config self.n_instances = config.n_instances # Required for backwards compatibility self.n_features = config.n_features # Required for backwards compatibility - self.C = config.C if config.C is not None else config.n_features self.bias_val = config.bias_val - - self.m = min(config.n_features, config.n_hidden) + 1 if config.m is None else config.m + self.m = config.m self.linear1 = LinearComponent( d_in=config.n_features, @@ -190,7 +190,6 @@ def __init__(self, config: TMSSPDModelConfig): n_instances=config.n_instances, init_type="xavier_normal", init_scale=1.0, - C=self.C, m=self.m, ) self.linear2 = TransposedLinearComponent(self.linear1.A, self.linear1.B) @@ -211,7 +210,6 @@ def __init__(self, config: TMSSPDModelConfig): n_instances=config.n_instances, init_type="xavier_normal", init_scale=1.0, - C=self.C, m=self.m, ) for _ in range(config.n_hidden_layers) @@ -223,7 +221,7 @@ def __init__(self, config: TMSSPDModelConfig): def forward( self, x: Float[Tensor, "batch n_instances n_features"], - mask: Float[Tensor, "batch n_instances C"] | None = None, + masks: dict[str, Float[Tensor, "batch n_instances m"]] | None = None, ) -> Float[Tensor, "batch n_instances n_features"]: return _tms_forward( x=x, @@ -231,7 +229,7 @@ def forward( linear2=self.linear2, b_final=self.b_final, hidden_layers=self.hidden_layers, - mask=mask, + masks=masks, ) @staticmethod @@ -286,7 +284,6 @@ def from_pretrained(cls, path: ModelPath) -> tuple["TMSSPDModel", Config]: assert isinstance(spd_config.task_config, TMSTaskConfig) tms_spd_config = TMSSPDModelConfig( **tms_train_config_dict["tms_model_config"], - C=spd_config.C, m=spd_config.m, bias_val=spd_config.task_config.bias_val, ) diff --git a/spd/experiments/tms/spd_interp.py b/spd/experiments/tms/spd_interp.py deleted file mode 100644 index 945d508..0000000 --- a/spd/experiments/tms/spd_interp.py +++ /dev/null @@ -1,379 +0,0 @@ -# %% - -import matplotlib.collections as mc -import matplotlib.pyplot as plt -import numpy as np -import numpy.typing as npt -import torch -from jaxtyping import Float -from torch import Tensor - -from spd.configs import TMSTaskConfig -from spd.experiments.tms.models import TMSModel, TMSSPDModel -from spd.plotting import collect_sparse_dataset_mse_losses, plot_sparse_feature_mse_line_plot -from spd.settings import REPO_ROOT -from spd.utils import COLOR_PALETTE, DataGenerationType, SparseFeatureDataset - - -def plot_vectors( - subnets: Float[Tensor, "n_instances n_subnets n_features n_hidden"], - axs: npt.NDArray[np.object_], -) -> None: - """2D polygon plot of each subnetwork. - - Adapted from - https://colab.research.google.com/github/anthropics/toy-models-of-superposition/blob/main/toy_models.ipynb. - """ - n_instances, n_subnets, n_features, n_hidden = subnets.shape - - # Use different colors for each subnetwork if there's only one instance - color_vals = np.linspace(0, 1, n_features) if n_instances == 1 else np.zeros(n_features) - colors = plt.cm.viridis(color_vals) # type: ignore - - for subnet_idx in range(n_subnets): - for instance_idx, ax in enumerate(axs[:, subnet_idx]): - arr = subnets[instance_idx, subnet_idx].cpu().detach().numpy() - - # Plot each feature with its unique color - for j in range(n_features): - ax.scatter(arr[j, 0], arr[j, 1], color=colors[j]) - ax.add_collection( - mc.LineCollection([[(0, 0), (arr[j, 0], arr[j, 1])]], colors=[colors[j]]) - ) - - ax.set_aspect("equal") - z = 1.3 - ax.set_facecolor("#f6f6f6") - ax.set_xlim((-z, z)) - ax.set_ylim((-z, z)) - ax.tick_params(left=True, right=False, labelleft=False, labelbottom=False, bottom=True) - for spine in ["top", "right"]: - ax.spines[spine].set_visible(False) - for spine in ["bottom", "left"]: - ax.spines[spine].set_position("center") - - if instance_idx == 0: # Only add labels to the first row - if subnet_idx == 0: - label = "Target model" - elif subnet_idx == 1: - label = "Sum of components" - else: - label = f"Component {subnet_idx - 2}" - ax.set_title(label, pad=10, fontsize="large") - - -def plot_networks( - subnets: Float[Tensor, "n_instances n_subnets n_features n_hidden"], - axs: npt.NDArray[np.object_], -) -> None: - """Plot neural network diagrams for each W matrix in the subnet variable. - - Args: - subnets: Tensor of shape [n_instances, n_subnets, n_features, n_hidden]. - axs: Matplotlib axes to plot on. - """ - - n_instances, n_subnets, n_features, n_hidden = subnets.shape - - # Take the absolute value of the weights - subnets_abs = subnets.abs() - - # Find the maximum weight across each instance - max_weights = subnets_abs.amax(dim=(1, 2, 3)) - - axs = np.atleast_2d(np.array(axs)) - - # axs[0, 0].set_xlabel("Outputs (before ReLU and biases)") - # Add the above but in text because the x-axis is killed - axs[0, 0].text( - 0.05, - 0.05, - "Outputs (before bias & ReLU)", - ha="left", - va="center", - transform=axs[0, 0].transAxes, - ) - # Also add "input label" - axs[0, 0].text( - 0.05, - 0.95, - "Inputs", - ha="left", - va="center", - transform=axs[0, 0].transAxes, - ) - - # Grayscale colormap. darker for larger weight - cmap = plt.get_cmap("gray_r") - - for subnet_idx in range(n_subnets): - for instance_idx, ax in enumerate(axs[:, subnet_idx]): - arr = subnets_abs[instance_idx, subnet_idx].cpu().detach().numpy() - - # Define node positions (top to bottom) - y_input, y_hidden, y_output = 0, -1, -2 - x_input = np.linspace(0.05, 0.95, n_features) - x_hidden = np.linspace(0.25, 0.75, n_hidden) - x_output = np.linspace(0.05, 0.95, n_features) - - # Add transparent grey box around hidden layer - box_width = 0.8 - box_height = 0.4 - box = plt.Rectangle( - (0.5 - box_width / 2, y_hidden - box_height / 2), - box_width, - box_height, - fill=True, - facecolor="#e4e4e4", - edgecolor="none", - alpha=0.33, - transform=ax.transData, - ) - ax.add_patch(box) - - # Plot nodes - ax.scatter( - x_input, [y_input] * n_features, s=200, color="grey", edgecolors="k", zorder=3 - ) - ax.scatter( - x_hidden, [y_hidden] * n_hidden, s=200, color="grey", edgecolors="k", zorder=3 - ) - ax.scatter( - x_output, [y_output] * n_features, s=200, color="grey", edgecolors="k", zorder=3 - ) - - # Plot edges from input to hidden layer - for idx_input in range(n_features): - for idx_hidden in range(n_hidden): - weight = arr[idx_input, idx_hidden] - norm_weight = weight / max_weights[instance_idx] - color = cmap(norm_weight) - ax.plot( - [x_input[idx_input], x_hidden[idx_hidden]], - [y_input, y_hidden], - color=color, - linewidth=1, - ) - - # Plot edges from hidden to output layer - arr_T = arr.T # Transpose of W for W^T - for idx_hidden in range(n_hidden): - for idx_output in range(n_features): - weight = arr_T[idx_hidden, idx_output] - norm_weight = weight / max_weights[instance_idx] - color = cmap(norm_weight) - ax.plot( - [x_hidden[idx_hidden], x_output[idx_output]], - [y_hidden, y_output], - color=color, - linewidth=1, - ) - - # Remove axes for clarity - # ax.axis("off") - ax.set_xlim(-0.1, 1.1) - ax.set_ylim(y_output - 0.5, y_input + 0.5) - # Remove x and y ticks and bounding boxes - ax.set_xticks([]) - ax.set_yticks([]) - for spine in ["top", "right", "bottom", "left"]: - ax.spines[spine].set_visible(False) - - -def plot_combined( - subnets: Float[Tensor, "n_instances n_subnets n_features n_hidden"], - target_weights: Float[Tensor, "n_instances n_features n_hidden"], - n_instances: int | None = None, -) -> plt.Figure: - """Create a combined figure with both vector and network diagrams side by side.""" - if n_instances is not None: - subnets = subnets[:n_instances] - target_weights = target_weights[:n_instances] - n_instances, n_subnets, n_features, n_hidden = subnets.shape - - # We wish to add two panels to the left: The target model weights and the sum of the subnets - # Add an extra dimension to the target weights so we can concatenate them - target_subnet = target_weights[:, None, :, :] - summed_subnet = subnets.sum(dim=1, keepdim=True) - subnets = torch.cat([target_subnet, summed_subnet, subnets], dim=1) - n_subnets += 2 - - # Create figure with two rows - fig, axs = plt.subplots( - nrows=n_instances * 2, - ncols=n_subnets, - figsize=(3 * n_subnets, 6 * n_instances), - ) - - plt.subplots_adjust(hspace=0) - - axs = np.atleast_2d(np.array(axs)) - - # Split axes into left (vectors) and right (networks) sides - axs_vectors = axs[:n_instances, :] - axs_networks = axs[n_instances:, :] - - # Call existing plotting logic with the split axes - plot_vectors(subnets=subnets, axs=axs_vectors) - plot_networks(subnets=subnets, axs=axs_networks) - - return fig - - -# %% -device = "cuda" if torch.cuda.is_available() else "cpu" -# path = "wandb:spd-tms/runs/bft0pgi8" # Old 5-2 run with attributions from spd model # paper run -# instance_idx = 0 -# path = "wandb:spd-tms/runs/sv9padmo" # 10-5 -# path = "wandb:spd-tms/runs/vt0i4a22" # 20-5 -# path = "wandb:spd-tms/runs/tyo4serm" # 40-10 with topk=2, topk_recon_coeff=1e1, schatten_coeff=15# old paper run -# path = "wandb:spd-tms/runs/9zzp2s68" # 40-10 with topk=2, topk_recon_coeff=1e1, schatten_coeff=20 -path = "wandb:spd-tms/runs/08no00iq" # 40-10 with topk=1, topk_recon_coeff=1e1, schatten_coeff=20# new paper run -instance_idx = 2 -# path = "wandb:spd-tms/runs/014t4f9n" # 40-10 with topk=1, topk_recon_coeff=1e1, schatten_coeff=1e1 - -run_id = path.split("/")[-1] - -# Plot showing polygons for each subnet -model, config = TMSSPDModel.from_pretrained(path) -subnets = model.linear1.component_weights.detach().cpu() - -assert isinstance(config.task_config, TMSTaskConfig) -target_model, target_model_train_config_dict = TMSModel.from_pretrained( - config.task_config.pretrained_model_path -) - -out_dir = REPO_ROOT / "spd/experiments/tms/out/figures/" -out_dir.mkdir(parents=True, exist_ok=True) - - -# %% -# Max cosine similarity between subnets and target model -def plot_max_cosine_sim(max_cosine_sim: Float[Tensor, " n_features"]) -> plt.Figure: - fig, ax = plt.subplots() - # Make a bar plot of the max cosine similarity for each feature - ax.bar(range(max_cosine_sim.shape[0]), max_cosine_sim.cpu().detach().numpy()) - # Add a grey horizontal line at 1 - ax.axhline(1, color="grey", linestyle="--") - ax.set_xlabel("Input feature index") - ax.set_ylabel("Max cosine similarity") - # Remove top and right spines - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - return fig - - -cosine_sims = torch.einsum( - "C f h, f h -> C f", - subnets[instance_idx] / torch.norm(subnets[instance_idx], dim=-1, keepdim=True), - target_model.linear1.weight[instance_idx] - / torch.norm(target_model.linear1.weight[instance_idx], dim=-1, keepdim=True), -) -max_cosine_sim = cosine_sims.max(dim=0).values -print(f"Max cosine similarity:\n{max_cosine_sim}") -print(f"Mean max cosine similarity: {max_cosine_sim.mean()}") -print(f"std max cosine similarity: {max_cosine_sim.std()}") - - -# Get the subnet weights at the max cosine similarity -subnet_weights_at_max_cosine_sim: Float[Tensor, "n_features n_hidden"] = subnets[ - instance_idx, cosine_sims.max(dim=0).indices, torch.arange(target_model.config.n_features) -] -# Get the norm of the target model weights -target_model_weights_norm = torch.norm( - target_model.linear1.weight[instance_idx], dim=-1, keepdim=True -) -# Get the norm of subnet_weights_at_max_cosine_sim -subnet_weights_at_max_cosine_sim_norm = torch.norm( - subnet_weights_at_max_cosine_sim, dim=-1, keepdim=True -) -# Divide the subnet weights by the target model weights ratio -l2_ratio = subnet_weights_at_max_cosine_sim_norm / target_model_weights_norm -print(f"Mean L2 ratio: {l2_ratio.mean()}") -print(f"std L2 ratio: {l2_ratio.std()}") - -# Mean bias -print(f"Mean bias: {target_model.b_final[instance_idx].mean()}") - - -# fig = plot_max_cosine_sim(max_cosine_sim) -# # Save figure -# fig.savefig(out_dir / f"tms_max_cosine_sim_{run_id}.png", bbox_inches="tight", dpi=400) -# print(f"Saved figure to {out_dir / f'tms_max_cosine_sim_{run_id}.png'}") -# %% -# Only plot if the hidden dimension is 2 -if target_model.config.n_hidden == 2: - # We only look at the first instance - fig = plot_combined(subnets, target_model.linear1.weight.detach().cpu(), n_instances=1) - fig.savefig(out_dir / f"tms_combined_diagram_{run_id}.png", bbox_inches="tight", dpi=400) - print(f"Saved figure to {out_dir / f'tms_combined_diagram_{run_id}.png'}") - -# %% -# This doesn't work for TMS. -# # Get the entries for the main loss table in the paper -dataset = SparseFeatureDataset( - n_instances=target_model.config.n_instances, - n_features=target_model.config.n_features, - feature_probability=config.task_config.feature_probability, - device=device, - data_generation_type="at_least_zero_active", # This will be changed in collect_sparse_dataset_mse_losses - value_range=(0.0, 1.0), -) -gen_types: list[DataGenerationType] = [ - "at_least_zero_active", - "exactly_one_active", - "exactly_two_active", - "exactly_three_active", - "exactly_four_active", -] -assert config.topk is not None -results = collect_sparse_dataset_mse_losses( - dataset=dataset, - target_model=target_model, - spd_model=model, - config=config, - batch_size=10000, - device=device, - topk=config.topk, - batch_topk=config.batch_topk, - distil_from_target=config.distil_from_target, - gen_types=gen_types, -) - -# %% -# Option to plot a single instance -inst = None -if inst is not None: - # We only plot the {inst}th instance - plot_data = { - gen_type: {k: float(v[inst].detach().cpu()) for k, v in results[gen_type].items()} - for gen_type in gen_types - } -else: - # Take the mean over all instances - plot_data = { - gen_type: {k: float(v.mean(dim=0).detach().cpu()) for k, v in results[gen_type].items()} - for gen_type in gen_types - } - -# %% -# Create line plot of results -color_map = { - "target": COLOR_PALETTE[0], - "apd_topk": COLOR_PALETTE[1], - "baseline_monosemantic": "grey", -} -label_map = [ - ("target", "Target model", color_map["target"]), - ("spd", "APD model", color_map["apd_topk"]), - ("baseline_monosemantic", "Monosemantic baseline", color_map["baseline_monosemantic"]), -] - -fig = plot_sparse_feature_mse_line_plot(plot_data, label_map=label_map, log_scale=False) -fig.show() -# fig.savefig(out_dir / f"tms_mse_{run_id}_inst{inst}.png", dpi=400) -# print(f"Saved figure to {out_dir / f'tms_mse_{run_id}_inst{inst}.png'}") -fig.savefig(out_dir / f"tms_mse_{run_id}.png", dpi=400) -print(f"Saved figure to {out_dir / f'tms_mse_{run_id}.png'}") - -# %% diff --git a/spd/experiments/tms/tms_topk_config.yaml b/spd/experiments/tms/tms_config.yaml similarity index 65% rename from spd/experiments/tms/tms_topk_config.yaml rename to spd/experiments/tms/tms_config.yaml index d8c11eb..2eba18e 100644 --- a/spd/experiments/tms/tms_topk_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -4,15 +4,10 @@ # wandb_run_name_prefix: "" # unit_norm_matrices: false # seed: 0 -# C: 5 -# topk: 0.211 -# batch_topk: true # param_match_coeff: 1.0 -# topk_recon_coeff: 1 -# attribution_type: gradient -# pnorm: null -# schatten_pnorm: 1.0 -# schatten_coeff: 7e-1 +# masked_recon_coeff: 1 +# pnorm: 0.9 +# lp_sparsity_coeff: 1.0 # batch_size: 2048 # steps: 20_000 # image_freq: 5_000 @@ -35,16 +30,11 @@ wandb_run_name: null wandb_run_name_prefix: "" unit_norm_matrices: false seed: 0 -topk: 2.0 -# topk: 0.8 # synced inputs -C: 40 -batch_topk: true +m: 10 param_match_coeff: 1.0 -topk_recon_coeff: 10.0 -attribution_type: gradient -pnorm: null -schatten_pnorm: 0.9 -schatten_coeff: 15.0 +masked_recon_coeff: 10.0 +pnorm: 0.9 +lp_sparsity_coeff: 1.0 batch_size: 2048 steps: 20_000 image_freq: 5_000 @@ -58,7 +48,5 @@ task_config: bias_val: 0.0 train_bias: false feature_probability: 0.05 - # feature_probability: 0.02 # synced inputs data_generation_type: "at_least_zero_active" - pretrained_model_path: "wandb:spd-train-tms/runs/tmzweoqk" - # pretrained_model_path: "wandb:spd-train-tms/runs/rkflpubi" # synced inputs \ No newline at end of file + pretrained_model_path: "wandb:spd-train-tms/runs/tmzweoqk" \ No newline at end of file diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index c2a7878..6ba2ecf 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -10,15 +10,12 @@ import fire import matplotlib.pyplot as plt -import numpy as np import torch import wandb import yaml from jaxtyping import Float from torch import Tensor -from tqdm import tqdm -from spd.attributions import collect_subnetwork_attributions from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSModelConfig, TMSSPDModel, TMSSPDModelConfig from spd.log import logger @@ -41,283 +38,6 @@ def get_run_name(config: Config, tms_model_config: TMSModelConfig) -> str: return config.wandb_run_name_prefix + run_suffix -def plot_A_matrix(x: torch.Tensor, pos_only: bool = False) -> plt.Figure: - n_instances = x.shape[0] - - fig, axs = plt.subplots( - 1, n_instances, figsize=(2.5 * n_instances, 2), squeeze=False, sharey=True - ) - - cmap = "Blues" if pos_only else "RdBu" - ims = [] - for i in range(n_instances): - ax = axs[0, i] - instance_data = x[i, :, :].detach().cpu().float().numpy() - max_abs_val = np.abs(instance_data).max() - vmin = 0 if pos_only else -max_abs_val - vmax = max_abs_val - im = ax.matshow(instance_data, vmin=vmin, vmax=vmax, cmap=cmap) - ims.append(im) - ax.xaxis.set_ticks_position("bottom") - if i == 0: - ax.set_ylabel("k", rotation=0, labelpad=10, va="center") - else: - ax.set_yticks([]) # Remove y-axis ticks for all but the first plot - ax.xaxis.set_label_position("top") - ax.set_xlabel("n_features") - - plt.subplots_adjust(wspace=0.1, bottom=0.15, top=0.9) - fig.subplots_adjust(bottom=0.2) - - return fig - - -def plot_subnetwork_attributions_multiple_instances( - attribution_scores: Float[Tensor, "batch n_instances C"], - out_dir: Path, - step: int | None, -) -> plt.Figure: - """Plot subnetwork attributions for multiple instances in a row.""" - n_instances = attribution_scores.shape[1] - - # Create a wide figure with subplots in a row - fig, axes = plt.subplots(1, n_instances, figsize=(5 * n_instances, 5), constrained_layout=True) - - axes = np.array([axes]) if isinstance(axes, plt.Axes) else axes - - images = [] - for idx, ax in enumerate(axes): - instance_scores = attribution_scores[:, idx, :] - im = ax.matshow(instance_scores.detach().cpu().numpy(), aspect="auto", cmap="Reds") - images.append(im) - - # Annotate each cell with the numeric value - for i in range(instance_scores.shape[0]): - for j in range(instance_scores.shape[1]): - ax.text( - j, - i, - f"{instance_scores[i, j]:.2f}", - ha="center", - va="center", - color="black", - fontsize=3, - ) - - ax.set_xlabel("Subnetwork Index") - if idx == 0: # Only set ylabel for leftmost plot - ax.set_ylabel("Batch Index") - ax.set_title(f"Instance {idx}") - - # Add a single colorbar that references all plots - norm = plt.Normalize(vmin=attribution_scores.min().item(), vmax=attribution_scores.max().item()) - for im in images: - im.set_norm(norm) - fig.colorbar(images[0], ax=axes) - - fig.suptitle(f"Subnetwork Attributions (Step {step})") - filename = ( - f"subnetwork_attributions_s{step}.png" - if step is not None - else "subnetwork_attributions.png" - ) - fig.savefig(out_dir / filename, dpi=300, bbox_inches="tight") - plt.close(fig) - tqdm.write(f"Saved subnetwork attributions to {out_dir / filename}") - return fig - - -def plot_subnetwork_attributions_statistics_multiple_instances( - topk_mask: Float[Tensor, "batch_size n_instances C"], out_dir: Path, step: int | None -) -> plt.Figure: - """Plot a row of vertical bar charts showing active subnetworks for each instance.""" - n_instances = topk_mask.shape[1] - fig, axes = plt.subplots(1, n_instances, figsize=(5 * n_instances, 5), constrained_layout=True) - - axes = np.array([axes]) if isinstance(axes, plt.Axes) else axes - - for instance_idx in range(n_instances): - ax = axes[instance_idx] - instance_mask = topk_mask[:, instance_idx] - - values = instance_mask.sum(dim=1).cpu().detach().numpy() - bins = list(range(int(values.min().item()), int(values.max().item()) + 2)) - counts, _ = np.histogram(values, bins=bins) - - bars = ax.bar(bins[:-1], counts, align="center", width=0.8) - ax.set_xticks(bins[:-1]) - ax.set_xticklabels([str(b) for b in bins[:-1]]) - ax.set_title(f"Instance {instance_idx}") - - if instance_idx == 0: # Only set y-label for leftmost plot - ax.set_ylabel("Count") - ax.set_xlabel("Number of active subnetworks") - - # Add value annotations on top of each bar - for bar in bars: - height = bar.get_height() - ax.annotate( - f"{height}", - xy=(bar.get_x() + bar.get_width() / 2, height), - xytext=(0, 3), - textcoords="offset points", - ha="center", - va="bottom", - ) - - fig.suptitle(f"Active subnetworks per instance (batch_size={topk_mask.shape[0]})") - filename = ( - f"subnetwork_attributions_statistics_s{step}.png" - if step is not None - else "subnetwork_attributions_statistics.png" - ) - fig.savefig(out_dir / filename, dpi=300, bbox_inches="tight") - plt.close(fig) - tqdm.write(f"Saved subnetwork attributions statistics to {out_dir / filename}") - return fig - - -def plot_component_weights(model: TMSSPDModel, step: int, out_dir: Path, **_) -> plt.Figure: - """Plot the component weight matrices.""" - component_weights = model.linear1.component_weights - - # component_weights: [n_instances, k, n_features, n_hidden] - n_instances, C, dim1, dim2 = component_weights.shape - - fig, axs = plt.subplots( - C, - n_instances, - figsize=(2 * n_instances, 2 * C), - constrained_layout=True, - ) - - for i in range(n_instances): - instance_max = np.abs(component_weights[i].detach().cpu().numpy()).max() - for j in range(C): - ax = axs[j, i] # type: ignore - param = component_weights[i, j].detach().cpu().numpy() - ax.matshow(param, cmap="RdBu", vmin=-instance_max, vmax=instance_max) - ax.set_xticks([]) - - if i == 0: - ax.set_ylabel(f"k={j}", rotation=0, ha="right", va="center") - if j == C - 1: - ax.set_xlabel(f"Inst {i}", rotation=45, ha="right") - - fig.suptitle(f"Component Weights (Step {step})") - fig.savefig(out_dir / f"component_weights_{step}.png", dpi=300, bbox_inches="tight") - plt.close(fig) - tqdm.write(f"Saved component weights to {out_dir / f'component_weights_{step}.png'}") - return fig - - -def plot_batch_frequencies( - frequencies: Float[Tensor, "n_instances C"], - xlabel: str, - ax: plt.Axes, - batch_size: int, - title: str | None = None, -) -> None: - """Plot frequency of C activations for each instance on a given axis. - - Args: - frequencies: Tensor counting frequencies for each instance - xlabel: Label for x-axis - ax: Matplotlib axis to plot on - batch_size: Size of the batch - title: Optional title for the subplot - """ - n_instances = frequencies.shape[0] - C = frequencies.shape[1] - - for instance_idx in range(n_instances): - bars = ax.bar( - np.arange(C) + instance_idx * (C + 1), # Add spacing between instances - frequencies[instance_idx].detach().cpu().numpy(), - align="center", - width=0.8, - label=f"Instance {instance_idx}", - ) - - # Add value annotations on top of each bar - for bar in bars: - height = bar.get_height() - ax.annotate( - f"{int(height)}", - xy=(bar.get_x() + bar.get_width() / 2, height), - xytext=(0, 3), - textcoords="offset points", - ha="center", - va="bottom", - ) - - ax.set_xlabel(xlabel) - ax.set_ylabel(f"Activation Count (batch_size={batch_size})") - if title: - ax.set_title(title) - - # Set x-ticks for each instance group - all_ticks = [] - all_labels = [] - for i in range(n_instances): - ticks = np.arange(C) + i * (C + 1) - all_ticks.extend(ticks) - all_labels.extend([str(j) for j in range(C)]) - ax.set_xticks(all_ticks) - ax.set_xticklabels(all_labels) - - -def plot_batch_statistics( - batch: Float[Tensor, "batch n_instances n_features"], - topk_mask: Float[Tensor, "batch n_instances C"], - out_dir: Path, - step: int | None, -) -> dict[str, plt.Figure]: - # Count the number of active features over the batch - active_input_feats = (batch != 0).sum(dim=0) - topk_activations = topk_mask.sum(dim=0) - - # Create figure with two vertically stacked subplots - fig = plt.figure(figsize=(15, 10)) - gs = fig.add_gridspec(2, 1, height_ratios=[1, 1], hspace=0.3) - - # Plot input features - ax1 = fig.add_subplot(gs[0]) - plot_batch_frequencies( - frequencies=active_input_feats, - xlabel="Input feature index", - ax=ax1, - batch_size=batch.shape[0], - title="Input feature frequencies across batch", - ) - - # Plot subnetwork frequencies - ax2 = fig.add_subplot(gs[1]) - plot_batch_frequencies( - frequencies=topk_activations, - xlabel="Component index", - ax=ax2, - batch_size=batch.shape[0], - title="Component frequencies across batch", - ) - - # Ensure that each ax has the same y-axis maximum - y_lims = [ax.get_ylim() for ax in [ax1, ax2]] - y_max = max(y_lims[0][1], y_lims[1][1]) - for ax in [ax1, ax2]: - ax.set_ylim(0, y_max) - - # fig.suptitle(f"Batch Statistics (Step {step})") - - # Save the combined figure - filename = f"batch_statistics_s{step}.png" if step is not None else "batch_statistics.png" - fig.savefig(out_dir / filename, dpi=300, bbox_inches="tight") - plt.close(fig) - tqdm.write(f"Saved batch statistics to {out_dir / filename}") - - return {"batch_statistics": fig} - - def make_plots( model: TMSSPDModel, target_model: TMSModel, @@ -325,38 +45,11 @@ def make_plots( out_dir: Path, device: str, config: Config, - topk_mask: Float[Tensor, "batch n_instances C"] | None, + masks: dict[str, Float[Tensor, "batch n_instances m"]] | None, batch: Float[Tensor, "batch n_instances n_features"], **_, ) -> dict[str, plt.Figure]: plots = {} - if model.hidden_layers is not None: - logger.warning("Only plotting the W matrix params and not the hidden layers.") - plots["component_weights"] = plot_component_weights(model, step, out_dir) - - if config.topk is not None: - assert topk_mask is not None - assert isinstance(config.task_config, TMSTaskConfig) - n_instances = model.config.n_instances if hasattr(model, "config") else model.n_instances - attribution_scores = collect_subnetwork_attributions( - spd_model=model, - config=config, - target_model=target_model, - device=device, - n_instances=n_instances, - ) - plots["subnetwork_attributions"] = plot_subnetwork_attributions_multiple_instances( - attribution_scores=attribution_scores, out_dir=out_dir, step=step - ) - plots["subnetwork_attributions_statistics"] = ( - plot_subnetwork_attributions_statistics_multiple_instances( - topk_mask=topk_mask, out_dir=out_dir, step=step - ) - ) - - batch_stat_plots = plot_batch_statistics(batch, topk_mask, out_dir, step) - plots.update(batch_stat_plots) - return plots @@ -420,7 +113,6 @@ def main( tms_spd_model_config = TMSSPDModelConfig( **target_model.config.model_dump(mode="json"), - C=config.C, m=config.m, bias_val=task_config.bias_val, ) diff --git a/spd/experiments/tms/tms_lp_config.yaml b/spd/experiments/tms/tms_lp_config.yaml deleted file mode 100644 index 99c73da..0000000 --- a/spd/experiments/tms/tms_lp_config.yaml +++ /dev/null @@ -1,29 +0,0 @@ -wandb_project: spd-tms -wandb_run_name: null -wandb_run_name_prefix: "" -unit_norm_matrices: true -seed: 0 -topk: null -m: 3 -C: 5 -param_match_coeff: 1.0 -lp_sparsity_coeff: 7.0 -pnorm: 0.9 -schatten_pnorm: 1.0 -schatten_coeff: 1.0 -batch_size: 2048 -steps: 20_000 -image_freq: 5000 -print_freq: 500 -save_freq: 20_000 -lr: 0.3 -lr_schedule: constant -lr_warmup_pct: 0.1 -task_config: - task_name: tms - bias_val: 0.0 - train_bias: false - feature_probability: 0.05 - data_generation_type: "at_least_zero_active" - # File obtained by running spd/experiments/tms/train_tms.py - pretrained_model_path: spd/experiments/tms/out/tms_n-features5_n-hidden2_n-instances12_seed0.pth/model.pth \ No newline at end of file diff --git a/spd/experiments/tms/tms_sweep_config.yaml b/spd/experiments/tms/tms_sweep_config.yaml index 0964d8e..b483a6c 100644 --- a/spd/experiments/tms/tms_sweep_config.yaml +++ b/spd/experiments/tms/tms_sweep_config.yaml @@ -4,12 +4,10 @@ metric: name: final_closeness goal: minimize parameters: - # topk: - # # values: [0.211, 0.239, 0.25, 0.261, 0.289] seed: values: [0, 1, 2, 3, 4] command: - ${env} - ${interpreter} - ${program} -- spd/experiments/tms/tms_topk_config.yaml +- spd/experiments/tms/tms_config.yaml diff --git a/spd/hooks.py b/spd/hooks.py index 1ebe328..5b820e2 100644 --- a/spd/hooks.py +++ b/spd/hooks.py @@ -454,7 +454,7 @@ def run_with_cache( reset_hooks_end: bool = True, clear_contexts: bool = False, **model_kwargs: Any, - ): + ) -> tuple[Tensor, dict[str, Tensor]]: """ Runs the model and returns the model output and a Cache object. diff --git a/spd/models/base.py b/spd/models/base.py index b58917a..9a1ffdf 100644 --- a/spd/models/base.py +++ b/spd/models/base.py @@ -1,5 +1,3 @@ -from torch import Tensor - from spd.hooks import HookedRootModule from spd.models.components import TransposedLinearComponent from spd.module_utils import ( @@ -10,31 +8,6 @@ class SPDModel(HookedRootModule): - def set_subnet_to_zero(self, subnet_idx: int, has_instance_dim: bool) -> dict[str, Tensor]: - stored_vals = {} - for attr_name in ["A", "B"]: - params = collect_nested_module_attrs(self, attr_name) - for param_name, param in params.items(): - if self.parent_is_transposed_linear(param_name): - continue - if has_instance_dim: - stored_vals[param_name] = param.data[:, subnet_idx, :, :].detach().clone() - param.data[:, subnet_idx, :, :] = 0.0 - else: - stored_vals[param_name] = param.data[subnet_idx, :, :].detach().clone() - param.data[subnet_idx, :, :] = 0.0 - return stored_vals - - def restore_subnet( - self, subnet_idx: int, stored_vals: dict[str, Tensor], has_instance_dim: bool - ) -> None: - for name, val in stored_vals.items(): - param = get_nested_module_attr(self, name) - if has_instance_dim: - param.data[:, subnet_idx, :, :] = val - else: - param.data[subnet_idx, :, :] = val - def set_As_to_unit_norm(self) -> None: """Set all A matrices to unit norm for stability. diff --git a/spd/models/components.py b/spd/models/components.py index 9d1bad2..a5b1c86 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -9,6 +9,27 @@ from spd.module_utils import init_param_ +def hard_sigmoid(x: Tensor) -> Tensor: + return torch.nn.functional.relu(torch.clamp(x, max=1)) + + +class Gate(nn.Module): + """A gate that maps a single input to a single output.""" + + def __init__(self, m: int, n_instances: int | None = None): + super().__init__() + self.n_instances = n_instances + shape = (n_instances, m) if n_instances is not None else (m,) + self.weight = nn.Parameter(torch.empty(shape)) + init_param_(self.weight, scale=1.0, init_type="kaiming_uniform") + self.bias = nn.Parameter(torch.zeros(shape)) + + def forward( + self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] + ) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: + return hard_sigmoid(x * self.weight + self.bias) + + class Linear(nn.Module): """A linear transformation with an optional n_instances dimension.""" @@ -47,45 +68,36 @@ def __init__( self, d_in: int, d_out: int, - C: int, + m: int, n_instances: int | None = None, init_type: Literal["kaiming_uniform", "xavier_normal"] = "kaiming_uniform", init_scale: float = 1.0, - m: int | None = None, ): super().__init__() self.n_instances = n_instances - self.C = C - self.m = min(d_in, d_out) if m is None else m + self.m = m # Initialize A and B matrices - shape_A = (n_instances, C, d_in, self.m) if n_instances is not None else (C, d_in, self.m) - shape_B = (n_instances, C, self.m, d_out) if n_instances is not None else (C, self.m, d_out) + shape_A = (n_instances, d_in, self.m) if n_instances is not None else (d_in, self.m) + shape_B = (n_instances, self.m, d_out) if n_instances is not None else (self.m, d_out) self.A = nn.Parameter(torch.empty(shape_A)) self.B = nn.Parameter(torch.empty(shape_B)) self.hook_pre = HookPoint() # (batch d_in) or (batch n_instances d_in) - self.hook_component_acts = HookPoint() # (batch C d_out) or (batch n_instances C d_out) + self.hook_component_acts = HookPoint() # (batch m) or (batch n_instances m) self.hook_post = HookPoint() # (batch d_out) or (batch n_instances d_out) init_param_(self.A, scale=init_scale, init_type=init_type) init_param_(self.B, scale=init_scale, init_type=init_type) - @property - def component_weights(self) -> Float[Tensor, "... C d_in d_out"]: - """A @ B before summing over the subnetwork dimension.""" - return einops.einsum(self.A, self.B, "... C d_in m, ... C m d_out -> ... C d_in d_out") - @property def weight(self) -> Float[Tensor, "... d_in d_out"]: - """A @ B after summing over the subnetwork dimension.""" - return einops.einsum(self.A, self.B, "... C d_in m, ... C m d_out -> ... d_in d_out") + """A @ B""" + return einops.einsum(self.A, self.B, "... d_in m, ... m d_out -> ... d_in d_out") def forward( - self, - x: Float[Tensor, "batch ... d_in"], - mask: Float[Tensor, "batch ... C"] | None = None, + self, x: Float[Tensor, "batch ... d_in"], mask: Float[Tensor, "batch ... m"] | None = None ) -> Float[Tensor, "batch ... d_out"]: - """Forward pass through A and B matrices which make up the component for this layer. + """Forward pass through A and B matrices. Args: x: Input tensor @@ -96,23 +108,14 @@ def forward( x = self.hook_pre(x) # First multiply by A to get to intermediate dimension m - inner_acts = einops.einsum(x, self.A, "batch ... d_in, ... C d_in m -> batch ... C m") + component_acts = einops.einsum(x, self.A, "batch ... d_in, ... d_in m -> batch ... m") if mask is not None: - # We could apply the mask after component_acts, but we do it here so our matrices become - # sparser and more efficient to compute with. - inner_acts = einops.einsum( - inner_acts, mask, "batch ... C m, batch ... C -> batch ... C m" - ) + component_acts *= mask + component_acts = self.hook_component_acts(component_acts) # Then multiply by B to get to output dimension - component_acts = einops.einsum( - inner_acts, self.B, "batch ... C m, ... C m d_out -> batch ... C d_out" - ) - - self.hook_component_acts(component_acts) + out = einops.einsum(component_acts, self.B, "batch ... m, ... m d_out -> batch ... d_out") - # Sum over subnetwork dimension - out = einops.einsum(component_acts, "batch ... C d_out -> batch ... d_out") out = self.hook_post(out) return out @@ -147,31 +150,26 @@ def __init__(self, original_A: nn.Parameter, original_B: nn.Parameter): # Copy the relevant parts from LinearComponent.__init__. Don't copy operations that will # call TransposedLinear.A or TransposedLinear.B. nn.Module.__init__(self) - self.n_instances, self.C, _, self.m = original_A.shape + self.n_instances, _, self.m = original_A.shape self.hook_pre = HookPoint() # (batch ... d_out) - self.hook_component_acts = HookPoint() # (batch ... C d_in) + self.hook_component_acts = HookPoint() # (batch ... m) self.hook_post = HookPoint() # (batch ... d_in) self.register_buffer("original_A", original_A, persistent=False) self.register_buffer("original_B", original_B, persistent=False) @property - def A(self) -> Float[Tensor, "... C d_out m"]: + def A(self) -> Float[Tensor, "... d_out m"]: # New A is the transpose of the original B - return einops.rearrange(self.original_B, "... C m d_out -> ... C d_out m") + return einops.rearrange(self.original_B, "... m d_out -> ... d_out m") @property - def B(self) -> Float[Tensor, "... C d_in m"]: + def B(self) -> Float[Tensor, "... d_in m"]: # New B is the transpose of the original A - return einops.rearrange(self.original_A, "... C d_in m -> ... C m d_in") - - @property - def component_weights(self) -> Float[Tensor, "... C d_out d_in"]: - """A @ B before summing over the subnetwork dimension.""" - return einops.einsum(self.A, self.B, "... C d_out m, ... C m d_in -> ... C d_out d_in") + return einops.rearrange(self.original_A, "... d_in m -> ... m d_in") @property def weight(self) -> Float[Tensor, "... d_out d_in"]: - """A @ B after summing over the subnetwork dimension.""" - return einops.einsum(self.A, self.B, "... C d_out m, ... C m d_in -> ... d_out d_in") + """A @ B""" + return einops.einsum(self.A, self.B, "... d_out m, ... m d_in -> ... d_out d_in") diff --git a/spd/plotting.py b/spd/plotting.py index 4c1e704..891db8d 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -1,5 +1,3 @@ -from typing import Any - import einops import matplotlib.pyplot as plt import matplotlib.ticker as tkr @@ -9,33 +7,18 @@ from matplotlib.colors import CenteredNorm from mpl_toolkits.axes_grid1 import make_axes_locatable from torch import Tensor -from torch.utils.data import DataLoader - -from spd.attributions import calculate_attributions -from spd.configs import Config -from spd.experiments.resid_mlp.models import ResidualMLPModel, ResidualMLPSPDModel -from spd.experiments.tms.models import TMSModel, TMSSPDModel -from spd.hooks import HookedRootModule -from spd.models.base import SPDModel -from spd.utils import ( - DataGenerationType, - SparseFeatureDataset, - calc_recon_mse, - calc_topk_mask, - run_spd_forward_pass, -) def plot_subnetwork_attributions_statistics( - topk_mask: Float[Tensor, "batch_size n_instances C"], + mask: Float[Tensor, "batch_size n_instances m"], ) -> dict[str, plt.Figure]: """Plot vertical bar charts of the number of active subnetworks over the batch for each instance.""" - batch_size = topk_mask.shape[0] - if topk_mask.ndim == 2: + batch_size = mask.shape[0] + if mask.ndim == 2: n_instances = 1 - topk_mask = einops.repeat(topk_mask, "batch C -> batch n_instances C", n_instances=1) + mask = einops.repeat(mask, "batch m -> batch n_instances m", n_instances=1) else: - n_instances = topk_mask.shape[1] + n_instances = mask.shape[1] fig, axs = plt.subplots( ncols=n_instances, nrows=1, figsize=(5 * n_instances, 5), constrained_layout=True @@ -43,7 +26,7 @@ def plot_subnetwork_attributions_statistics( axs = np.array([axs]) if n_instances == 1 else np.array(axs) for i, ax in enumerate(axs): - values = topk_mask[:, i].sum(dim=1).cpu().detach().numpy() + values = mask[:, i].sum(dim=1).cpu().detach().numpy() bins = list(range(int(values.min().item()), int(values.max().item()) + 2)) counts, _ = np.histogram(values, bins=bins) bars = ax.bar(bins[:-1], counts, align="center", width=0.8) @@ -73,245 +56,6 @@ def plot_subnetwork_attributions_statistics( return {"subnetwork_attributions_statistics": fig} -def plot_subnetwork_correlations( - dataloader: DataLoader[ - tuple[Float[Tensor, "batch n_inputs"] | Float[Tensor, "batch n_instances? n_inputs"], Any] - ], - target_model: HookedRootModule, - spd_model: SPDModel, - config: Config, - device: str, - n_forward_passes: int = 100, -) -> dict[str, plt.Figure]: - topk_masks = [] - for batch, _ in dataloader: - batch = batch.to(device=device) - assert config.topk is not None - - # Forward pass on target model - target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) - target_out, target_cache = target_model.run_with_cache( - batch, names_filter=target_cache_filter - ) - - # Do a forward pass with all subnetworks - spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) - out, spd_cache = spd_model.run_with_cache(batch, names_filter=spd_cache_filter) - attribution_scores = calculate_attributions( - model=spd_model, - config=config, - batch=batch, - out=out, - target_out=target_out, - pre_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_pre")}, - post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")}, - component_acts={ - k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts") - }, - ) - - # We always assume the final subnetwork is the one we want to distil - topk_attrs = ( - attribution_scores[..., :-1] if config.distil_from_target else attribution_scores - ) - if config.exact_topk: - assert spd_model.n_instances == 1, "exact_topk only works if n_instances = 1" - topk = (batch != 0).sum() / batch.shape[0] - topk_mask = calc_topk_mask(topk_attrs, topk, batch_topk=config.batch_topk) - else: - topk_mask = calc_topk_mask(topk_attrs, config.topk, batch_topk=config.batch_topk) - - topk_masks.append(topk_mask) - if len(topk_masks) > n_forward_passes: - break - topk_masks = torch.cat(topk_masks).float() - - if hasattr(spd_model, "n_instances"): - n_instances = spd_model.n_instances - else: - n_instances = 1 - topk_masks = einops.repeat(topk_masks, "batch C -> batch n_instances C", n_instances=1) - - fig, axs = plt.subplots( - ncols=n_instances, nrows=1, figsize=(5 * n_instances, 5), constrained_layout=True - ) - - axs = np.array([axs]) if n_instances == 1 else np.array(axs) - im, ax = None, None - for i, ax in enumerate(axs): - # Calculate correlation matrix - corr_matrix = torch.corrcoef(topk_masks[:, i].T).cpu() - - im = ax.matshow(corr_matrix) - ax.xaxis.set_ticks_position("bottom") - if corr_matrix.shape[0] * corr_matrix.shape[1] < 200: - for l in range(corr_matrix.shape[0]): - for j in range(corr_matrix.shape[1]): - ax.text( - j, - l, - f"{corr_matrix[l, j]:.2f}", - ha="center", - va="center", - color="#EE7777", - fontsize=8, - ) - if (im is not None) and (ax is not None): - divider = make_axes_locatable(plt.gca()) - cax = divider.append_axes("right", size="5%", pad=0.1) - plt.colorbar(im, cax=cax) - ax.set_title("Subnetwork Correlation Matrix") - ax.set_xlabel("Subnetwork") - ax.set_ylabel("Subnetwork") - return {"subnetwork_correlation_matrix": fig} - - -def collect_sparse_dataset_mse_losses( - dataset: SparseFeatureDataset, - target_model: ResidualMLPModel | TMSModel, - spd_model: TMSSPDModel | ResidualMLPSPDModel, - config: Config, - batch_size: int, - device: str, - topk: float, - batch_topk: bool, - distil_from_target: bool, - gen_types: list[DataGenerationType], -) -> dict[str, dict[str, Float[Tensor, ""] | Float[Tensor, " n_instances"]]]: - """Collect the MSE losses for specific number of active features, as well as for - 'at_least_zero_active'. - - We calculate two baselines: - - baseline_monosemantic: a baseline loss where the first d_mlp feature indices get mapped to the - true labels and the final (n_features - d_mlp) features are either 0 (TMS) or the raw inputs - (ResidualMLP). - - Returns: - A dictionary keyed by generation type and then by model type (target, spd, - baseline_monosemantic), with values being MSE losses. - """ - target_model.to(device) - spd_model.to(device) - # Get the entries for the main loss table in the paper - results = {gen_type: {} for gen_type in gen_types} - word_to_num = {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5} - - for gen_type in gen_types: - dataset.data_generation_type = gen_type - batch, labels = dataset.generate_batch(batch_size) - - batch = batch.to(device) - labels = labels.to(device) - - target_model_output = target_model(batch) - - if gen_type == "at_least_zero_active": - run_batch_topk = batch_topk - run_topk = topk - else: - run_batch_topk = False - assert gen_type.startswith("exactly_") - n_active = word_to_num[gen_type.split("_")[1]] - run_topk = n_active - - spd_outputs = run_spd_forward_pass( - spd_model=spd_model, - config=config, - target_model=target_model, - input_array=batch, - batch_topk=run_batch_topk, - topk=run_topk, - distil_from_target=distil_from_target, - ) - # Combine the batch and n_instances dimension for batch, labels, target_model_output, - # spd_outputs.spd_model_masked_output - ein_str = "batch n_instances n_features -> (batch n_instances) n_features" - batch = einops.rearrange(batch, ein_str) - labels = einops.rearrange(labels, ein_str) - target_model_output = einops.rearrange(target_model_output, ein_str) - spd_model_masked_output = einops.rearrange(spd_outputs.spd_model_masked_output, ein_str) - - if gen_type == "at_least_zero_active": - # Remove all entries where there are no active features - mask = (batch != 0).any(dim=-1) - batch = batch[mask] - labels = labels[mask] - target_model_output = target_model_output[mask] - spd_model_masked_output = spd_model_masked_output[mask] - - topk_recon_loss_labels = calc_recon_mse( - spd_model_masked_output, labels, has_instance_dim=False - ) - recon_loss = calc_recon_mse(target_model_output, labels, has_instance_dim=False) - baseline_batch = calc_recon_mse(batch, labels, has_instance_dim=False) - - # Monosemantic baseline - monosemantic_out = batch.clone() - # Assumes TMS or ResidualMLP - if isinstance(target_model, ResidualMLPModel): - d_mlp = target_model.config.d_mlp * target_model.config.n_layers # type: ignore - monosemantic_out[..., :d_mlp] = labels[..., :d_mlp] - elif isinstance(target_model, TMSModel): - d_mlp = target_model.config.n_hidden # type: ignore - # The first d_mlp features are the true labels (i.e. the batch) and the rest are 0 - monosemantic_out[..., d_mlp:] = 0 - baseline_monosemantic = calc_recon_mse(monosemantic_out, labels, has_instance_dim=False) - - results[gen_type]["target"] = recon_loss - results[gen_type]["spd"] = topk_recon_loss_labels - results[gen_type]["baseline_batch"] = baseline_batch - results[gen_type]["baseline_monosemantic"] = baseline_monosemantic - return results - - -def plot_sparse_feature_mse_line_plot( - results: dict[str, dict[str, float]], - label_map: list[tuple[str, str, str]], - log_scale: bool = False, -) -> plt.Figure: - xtick_label_map = { - "at_least_zero_active": "Training distribution", - "exactly_one_active": "Exactly 1 active", - "exactly_two_active": "Exactly 2 active", - "exactly_three_active": "Exactly 3 active", - "exactly_four_active": "Exactly 4 active", - "exactly_five_active": "Exactly 5 active", - } - # Create grouped bar plots for each generation type - fig, ax = plt.subplots(figsize=(12, 6)) - - n_groups = len(results) # number of generation types - n_models = len(label_map) # number of models to compare - width = 0.8 / n_models # width of bars - - # Create bars for each model type - for i, (model_type, label, color) in enumerate(label_map): - x_positions = np.arange(n_groups) + i * width - (n_models - 1) * width / 2 - heights = [results[gen_type][model_type] for gen_type in results] - ax.bar(x_positions, heights, width, label=label, color=color) - - # Customize the plot - ax.set_ylabel("MSE w.r.t true labels") - ax.set_xticks(np.arange(n_groups)) - xtick_labels = [xtick_label_map[gen_type] for gen_type in results] - ax.set_xticklabels(xtick_labels) - ax.legend() - ax.grid(True, alpha=0.3, axis="y") - - if log_scale: - ax.set_yscale("log") - - # Remove top and right spines - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - - # Ensure that 0 is the bottom of the y-axis - ax.set_ylim(bottom=0) - - plt.tight_layout() - return fig - - def plot_matrix( ax: plt.Axes, matrix: torch.Tensor, diff --git a/spd/run_spd.py b/spd/run_spd.py index d6b6995..87bf039 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -3,7 +3,6 @@ from collections.abc import Callable from pathlib import Path -import einops import matplotlib.pyplot as plt import torch import wandb @@ -12,32 +11,25 @@ from torch.utils.data import DataLoader from tqdm import tqdm -from spd.attributions import calculate_attributions +from spd.attributions import calc_grad_attributions from spd.configs import Config from spd.hooks import HookedRootModule from spd.models.base import SPDModel +from spd.models.components import Gate from spd.module_utils import collect_nested_module_attrs, get_nested_module_attr -from spd.utils import calc_recon_mse, calc_topk_mask, get_lr_schedule_fn, get_lr_with_warmup +from spd.utils import calc_recon_mse, get_lr_schedule_fn, get_lr_with_warmup def get_common_run_name_suffix(config: Config) -> str: """Generate a run suffix based on Config that is common to all experiments.""" run_suffix = "" - if config.pnorm is not None: - run_suffix += f"p{config.pnorm:.2e}_" - if config.lp_sparsity_coeff is not None: - run_suffix += f"lpsp{config.lp_sparsity_coeff:.2e}_" - if config.topk is not None: - run_suffix += f"topk{config.topk:.2e}_" - if config.topk_recon_coeff is not None: - run_suffix += f"topkrecon{config.topk_recon_coeff:.2e}_" - if config.schatten_pnorm is not None: - run_suffix += f"schatp{config.schatten_pnorm:.2e}_" - if config.schatten_coeff is not None: - run_suffix += f"schatten{config.schatten_coeff:.2e}_" + if config.masked_recon_coeff is not None: + run_suffix += f"maskedrecon{config.masked_recon_coeff:.2e}_" if config.act_recon_coeff is not None: run_suffix += f"actrecon_{config.act_recon_coeff:.2e}_" - run_suffix += f"C{config.C}_" + run_suffix += f"p{config.pnorm:.2e}_" + run_suffix += f"lpsp{config.lp_sparsity_coeff:.2e}_" + run_suffix += f"m{config.m}_" run_suffix += f"sd{config.seed}_" run_suffix += f"attr-{config.attribution_type[:3]}_" run_suffix += f"lr{config.lr:.2e}_" @@ -45,56 +37,6 @@ def get_common_run_name_suffix(config: Config) -> str: return run_suffix -def calc_schatten_loss( - As: dict[str, Float[Tensor, "C d_layer_in m"] | Float[Tensor, "n_instances C d_layer_in m"]], - Bs: dict[str, Float[Tensor, "C m d_layer_out"] | Float[Tensor, "n_instances C m d_layer_out"]], - mask: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"], - p: float, - n_params: int, - device: str, -) -> Float[Tensor, ""] | Float[Tensor, " n_instances"]: - """Calculate the Schatten p-norms of the topk subnetworks and sum them. - - Args: - As: Dictionary of A matrices for each layer - Bs: Dictionary of B matrices for each layer - mask: The mask to use for the Schatten p-norm penalty. May be a binary mask (if topk) or - a float mask (if lp sparsity). - p: The Schatten p-norm to use (from config.schatten_pnorm) - n_params: The number of parameters in the model - device: The device to use for calculations - Returns: - The Schatten p-norm penalty for the topk subnetworks - """ - assert As.keys() == Bs.keys(), "As and Bs must have the same keys" - n_instances = mask.shape[1] if mask.ndim == 3 else None - accumulate_shape = (n_instances,) if n_instances is not None else () - - schatten_penalty = torch.zeros(accumulate_shape, device=device) - batch_size = mask.shape[0] - - for name in As: - A = As[name] # [C, d_in, m] or [n_instances, C, d_in, m] - B = Bs[name] # [C, m, d_out] or [n_instances, C, m, d_out] - # mask: [batch, C] or [batch, n_instances, C] - - # Compute S_A = A^T A and S_B = B B^T - S_A = einops.einsum(A, A, "... C d_in m, ... C d_in m -> ... C m") - S_B = einops.einsum(B, B, "... C m d_out, ... C m d_out -> ... C m") - - S_AB = S_A * S_B - - # Apply topk mask - S_AB_topk = einops.einsum(S_AB, mask, "... C m, batch ... C -> batch ... C m") - - # Sum the Schatten p-norm - schatten_penalty = schatten_penalty + ((S_AB_topk + 1e-16) ** (0.5 * p)).sum( - dim=(0, -2, -1) - ) - - return schatten_penalty / n_params / batch_size - - def _calc_param_mse( params1: dict[str, Float[Tensor, "d_in d_out"] | Float[Tensor, "n_instances d_in d_out"]], params2: dict[str, Float[Tensor, "d_in d_out"] | Float[Tensor, "n_instances d_in d_out"]], @@ -149,28 +91,27 @@ def calc_param_match_loss( def calc_lp_sparsity_loss( - out: Float[Tensor, "batch d_model_out"] | Float[Tensor, "batch n_instances d_model_out"], - attributions: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"], + masks: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], step_pnorm: float, -) -> Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"]: +) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: """Calculate the Lp sparsity loss on the attributions. Args: - out: The output of the model. - attributions: The attributions to use for the sparsity loss. + masks: Dictionary of masks for each layer to use for the sparsity loss. step_pnorm: The pnorm to use for the sparsity loss. Returns: The Lp sparsity loss. Will have an n_instances dimension if the model has an n_instances - dimension. Note that we keep the batch and C dimensions as we need them if calculating - the schatten loss. + dimension. """ - # Average the attributions over the output dimensions - d_model_out = out.shape[-1] - attributions = attributions / d_model_out + # Initialize with zeros matching the shape of first mask + total_loss = torch.zeros_like(next(iter(masks.values()))) + + for layer_mask in masks.values(): + # step_pnorm * 0.5 is because we have the squares of sparsity_inner terms above + layer_loss = (layer_mask.abs() + 1e-16) ** (step_pnorm * 0.5) + total_loss = total_loss + layer_loss - # step_pnorm * 0.5 is because we have the squares of sparsity_inner terms above - lp_sparsity_loss_per_k = (attributions.abs() + 1e-16) ** (step_pnorm * 0.5) - return lp_sparsity_loss_per_k + return total_loss def calc_act_recon( @@ -207,6 +148,29 @@ def calc_act_recon( return (loss / total_act_dim).mean(dim=0) +def calc_masks( + gates: dict[str, Gate], + component_acts: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], + attributions: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], +) -> dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]: + """Calculate the mask for the SPD model. + + TODO: Use attributions in our gate calculation too. + + Args: + gates: The gates to use for the mask. + component_acts: The activations after each subnetwork in the SPD model. + attributions: The attributions to use for the mask. + + Returns: + Dictionary of masks for each layer. + """ + masks = {} + for layer_name in gates: + masks[layer_name] = gates[layer_name](component_acts[layer_name + ".hook_component_acts"]) + return masks + + def optimize( model: SPDModel, config: Config, @@ -221,15 +185,25 @@ def optimize( target_model.to(device=device) has_instance_dim = hasattr(model, "n_instances") + n_instances = model.n_instances if has_instance_dim else None - # Note that we expect weight decay to be problematic for spd - opt = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.0) + gates = { + param_name: Gate(n_instances=n_instances, m=config.m).to(device) + for param_name in param_names + } + all_params = list(model.parameters()) + [ + p for gate in gates.values() for p in gate.parameters() + ] + + # Note that we expect weight decay to be problematic for spd models + opt = torch.optim.AdamW(all_params, lr=config.lr, weight_decay=0.0) lr_schedule_fn = get_lr_schedule_fn(config.lr_schedule, config.lr_exponential_halflife) n_params = 0 for param_name in param_names: - n_params += get_nested_module_attr(target_model, param_name + ".weight").numel() + weight = get_nested_module_attr(target_model, param_name + ".weight") + n_params += weight.numel() if has_instance_dim: # All subnetwork param have an n_instances dimension @@ -277,94 +251,51 @@ def optimize( # Calculate losses out_recon_loss = calc_recon_mse(out, target_out, has_instance_dim) - param_match_loss = None - if config.param_match_coeff is not None: - param_match_loss = calc_param_match_loss( - param_names=param_names, - target_model=target_model, - spd_model=model, - n_params=n_params, - device=device, - ) + param_match_loss = calc_param_match_loss( + param_names=param_names, + target_model=target_model, + spd_model=model, + n_params=n_params, + device=device, + ) post_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_post")} pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} - attributions = calculate_attributions( - model=model, - config=config, - batch=batch, - out=out, + component_acts = {k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts")} + attributions = calc_grad_attributions( target_out=target_out, pre_weight_acts=pre_weight_acts, post_weight_acts=post_weight_acts, component_acts={ k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts") }, + Bs=collect_nested_module_attrs(model, attr_name="B", include_attr_name=False), ) - lp_sparsity_loss_per_k = None - if config.lp_sparsity_coeff is not None: - assert config.pnorm is not None, "pnorm must be set if lp_sparsity_coeff is set" - lp_sparsity_loss_per_k = calc_lp_sparsity_loss( - out=out, attributions=attributions, step_pnorm=config.pnorm - ) + masks = calc_masks(gates=gates, component_acts=component_acts, attributions=attributions) - ( - out_masked, - schatten_loss, - masked_recon_loss, - mask, - layer_acts_masked, - ) = None, None, None, None, None - if config.topk is not None: - # We always assume the final subnetwork is the one we want to distil - topk_attrs: Float[Tensor, "batch ... C"] = ( - attributions[..., :-1] if config.distil_from_target else attributions - ) - if config.exact_topk: - # Currently only valid for batch_topk and n_instances = 1. Would need to change the - # topk argument in calc_topk_mask to allow for tensors if relaxing these constraints - assert config.batch_topk, "exact_topk only works if batch_topk is True" - assert ( - hasattr(model, "n_instances") and model.n_instances == 1 - ), "exact_topk only works if n_instances = 1" - # Get the exact number of active features over the batch - exact_topk = ((batch != 0).sum() / batch.shape[0]).item() - mask = calc_topk_mask(topk_attrs, exact_topk, batch_topk=True) - else: - mask = calc_topk_mask(topk_attrs, config.topk, batch_topk=config.batch_topk) - if config.distil_from_target: - # Add back the final subnetwork index to the topk mask and set it to True - last_subnet_mask = torch.ones( - (*mask.shape[:-1], 1), dtype=mask.dtype, device=device - ) - mask = torch.cat((mask, last_subnet_mask), dim=-1) + normed_masks = {k: v / out.shape[-1] for k, v in masks.items()} + lp_sparsity_loss_per_m = calc_lp_sparsity_loss(masks=normed_masks, step_pnorm=config.pnorm) + # Sum over the m dimension (-1) and mean over the batch dimension (0) + lp_sparsity_loss = lp_sparsity_loss_per_m.sum(dim=-1).mean(dim=0) - # Do a forward pass with only the topk subnetworks - out_masked, spd_cache_masked = model.run_with_cache( - batch, names_filter=spd_cache_filter, mask=mask - ) - layer_acts_masked = { - k: v for k, v in spd_cache_masked.items() if k.endswith("hook_post") - } + # Masked forward pass + out_masked, spd_cache_masked = model.run_with_cache( + batch, names_filter=spd_cache_filter, masks=masks + ) + masked_recon_loss = calc_recon_mse(out_masked, target_out, has_instance_dim) - if config.topk_recon_coeff is not None: - assert out_masked is not None - masked_recon_loss = calc_recon_mse(out_masked, target_out, has_instance_dim) + layer_acts_masked = {k: v for k, v in spd_cache_masked.items() if k.endswith("hook_post")} act_recon_loss = None if config.act_recon_coeff is not None: - act_recon_layer_acts = ( - layer_acts_masked - if layer_acts_masked is not None - else {k: v for k, v in spd_cache.items() if k.endswith("hook_post")} - ) + act_recon_layer_acts = layer_acts_masked target_post_weight_acts = post_weight_acts if config.post_relu_act_recon: relu = torch.nn.functional.relu # Only do post-relu act recon for mlp_in layers and ignore the other layers act_recon_layer_acts = { - k: relu(v) for k, v in act_recon_layer_acts.items() if "mlp_in" in k + k: relu(v) for k, v in layer_acts_masked.items() if "mlp_in" in k } target_post_weight_acts = { k: relu(v) for k, v in target_post_weight_acts.items() if "mlp_in" in k @@ -374,32 +305,12 @@ def optimize( layer_acts=act_recon_layer_acts, ) - if config.schatten_coeff is not None: - # Use the sparsity loss as the mask in the lp case, and topk_mask otherwise - mask = mask if mask is not None else lp_sparsity_loss_per_k - assert mask is not None - schatten_pnorm = config.schatten_pnorm if config.schatten_pnorm is not None else 1.0 - schatten_loss = calc_schatten_loss( - As=collect_nested_module_attrs(model, attr_name="A", include_attr_name=False), - Bs=collect_nested_module_attrs(model, attr_name="B", include_attr_name=False), - mask=mask, - p=schatten_pnorm, - n_params=n_params, - device=device, - ) - - lp_sparsity_loss = None - if lp_sparsity_loss_per_k is not None: - # Sum over the C dimension (-1) and mean over the batch dimension (0) - lp_sparsity_loss = lp_sparsity_loss_per_k.sum(dim=-1).mean(dim=0) - loss_terms = { "param_match_loss": (param_match_loss, config.param_match_coeff), "out_recon_loss": (out_recon_loss, config.out_recon_coeff), "lp_sparsity_loss": (lp_sparsity_loss, config.lp_sparsity_coeff), - "masked_recon_loss": (masked_recon_loss, config.topk_recon_coeff), + "masked_recon_loss": (masked_recon_loss, config.masked_recon_coeff), "act_recon_loss": (act_recon_loss, config.act_recon_coeff), - "schatten_loss": (schatten_loss, config.schatten_coeff), } # Add up the loss terms loss = torch.tensor(0.0, device=device) @@ -444,7 +355,7 @@ def optimize( out_dir=out_dir, device=device, config=config, - topk_mask=mask, + masks=masks, batch=batch, ) if config.wandb_project: diff --git a/spd/utils.py b/spd/utils.py index c6f8eb2..4a48d33 100644 --- a/spd/utils.py +++ b/spd/utils.py @@ -1,7 +1,7 @@ import random from collections.abc import Callable, Iterator from pathlib import Path -from typing import Any, Generic, Literal, NamedTuple, TypeVar +from typing import Any, Generic, Literal, TypeVar import einops import numpy as np @@ -13,11 +13,7 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset -from spd.attributions import calculate_attributions -from spd.configs import Config -from spd.hooks import HookedRootModule from spd.log import logger -from spd.models.base import SPDModel T = TypeVar("T", bound=BaseModel) Q = TypeVar("Q") @@ -141,113 +137,6 @@ def __iter__(self) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: # type: igno yield batch[0], label[0] -class SPDOutputs(NamedTuple): - target_model_output: ( - Float[Tensor, "batch d_model_out"] | Float[Tensor, "batch n_instances d_model_out"] - ) - spd_model_output: ( - Float[Tensor, "batch d_model_out"] | Float[Tensor, "batch n_instances d_model_out"] - ) - spd_model_masked_output: ( - Float[Tensor, "batch d_model_out"] | Float[Tensor, "batch n_instances d_model_out"] - ) - layer_acts: dict[str, Float[Tensor, "batch d_out"] | Float[Tensor, "batch n_instances d_out"]] - component_acts: dict[ - str, Float[Tensor, "batch C d_out"] | Float[Tensor, "batch n_instances C d_out"] - ] - attribution_scores: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"] - mask: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"] - - -def calc_topk_mask( - attribution_scores: Float[Tensor, "batch ... C"], - topk: float, - batch_topk: bool, -) -> Float[Tensor, "batch ... C"]: - """Calculate the top-k mask. - - Args: - attribution_scores: The attribution scores to calculate the top-k mask for. - topk: The number of top-k elements to select. If `batch_topk` is True, this is multiplied - by the batch size to get the number of top-k elements over the whole batch. - batch_topk: If True, the top-k mask is calculated over the concatenated batch and k - dimensions. - - Returns: - The top-k mask. - """ - batch_size = attribution_scores.shape[0] - topk = int(topk * batch_size) if batch_topk else int(topk) - - if batch_topk: - attribution_scores = einops.rearrange(attribution_scores, "b ... C -> ... (b C)") - - topk_indices = attribution_scores.topk(topk, dim=-1).indices - topk_mask = torch.zeros_like(attribution_scores, dtype=torch.bool) - topk_mask.scatter_(dim=-1, index=topk_indices, value=True) - - if batch_topk: - topk_mask = einops.rearrange(topk_mask, "... (b C) -> b ... C", b=batch_size) - - return topk_mask - - -def run_spd_forward_pass( - spd_model: SPDModel, - config: Config, - target_model: HookedRootModule, - input_array: Float[Tensor, "batch n_inputs"], - batch_topk: bool, - topk: float, - distil_from_target: bool, - mask: Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"] | None = None, -) -> SPDOutputs: - # Forward pass on target model - target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) - target_out, target_cache = target_model.run_with_cache( - input_array, names_filter=target_cache_filter - ) - - # Do a forward pass with all subnetworks - spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) - out, spd_cache = spd_model.run_with_cache(input_array, names_filter=spd_cache_filter) - - attribution_scores = calculate_attributions( - model=spd_model, - config=config, - batch=input_array, - out=out, - target_out=target_out, - pre_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_pre")}, - post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")}, - component_acts={k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts")}, - ) - - if mask is None: - # We always assume the final subnetwork is the one we want to distil - topk_attrs = attribution_scores[..., :-1] if distil_from_target else attribution_scores - - mask = calc_topk_mask(topk_attrs, topk, batch_topk=batch_topk) - if distil_from_target: - # Add back the final subnetwork index to the topk mask and set it to True - last_subnet_mask = torch.ones( - (*mask.shape[:-1], 1), dtype=torch.bool, device=attribution_scores.device - ) - mask = torch.cat((mask, last_subnet_mask), dim=-1) - - spd_model_masked_output = spd_model(input_array, mask=mask) - attribution_scores = attribution_scores.cpu().detach() - return SPDOutputs( - target_model_output=target_out, - spd_model_output=out, - spd_model_masked_output=spd_model_masked_output, - layer_acts={k: v for k, v in spd_cache.items() if k.endswith("hook_post")}, - component_acts={k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts")}, - attribution_scores=attribution_scores, - mask=mask, - ) - - DataGenerationType = Literal[ "exactly_one_active", "exactly_two_active", diff --git a/tests/test_attributions.py b/tests/test_attributions.py deleted file mode 100644 index 737db2f..0000000 --- a/tests/test_attributions.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Tests for attributions.py methods.""" - -import torch - -from spd.attributions import calc_activation_attributions - - -def test_calc_activation_attributions_obvious(): - component_acts = {"layer1": torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])} - expected = torch.tensor([[1.0, 1.0]]) - - result = calc_activation_attributions(component_acts) - torch.testing.assert_close(result, expected) - - -def test_calc_activation_attributions_different_d_out(): - component_acts = { - "layer1": torch.tensor([[[1.0, 2.0], [3.0, 4.0]]]), - "layer2": torch.tensor([[[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]]), - } - expected = torch.tensor( - [[1.0**2 + 2**2 + 5**2 + 6**2 + 7**2, 3**2 + 4**2 + 8**2 + 9**2 + 10**2]] - ) - - result = calc_activation_attributions(component_acts) - torch.testing.assert_close(result, expected) - - -def test_calc_activation_attributions_with_n_instances(): - # Batch=1, n_instances=2, C=2, d_out=2 - component_acts = { - "layer1": torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]), - "layer2": torch.tensor([[[[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]]]]), - } - expected = torch.tensor( - [ - [ - [1.0**2 + 2**2 + 9**2 + 10**2, 3**2 + 4**2 + 11**2 + 12**2], - [5**2 + 6**2 + 13**2 + 14**2, 7**2 + 8**2 + 15**2 + 16**2], - ] - ] - ) - - result = calc_activation_attributions(component_acts) - torch.testing.assert_close(result, expected) diff --git a/tests/test_components.py b/tests/test_components.py deleted file mode 100644 index 7b46da4..0000000 --- a/tests/test_components.py +++ /dev/null @@ -1,46 +0,0 @@ -import einops -import torch -from jaxtyping import Float -from torch import Tensor - -from spd.models.components import LinearComponent - - -def reference_forward_with_mask( - x: Float[Tensor, "batch ... d_in"], - A: Float[Tensor, "... C d_in m"], - B: Float[Tensor, "... C m d_out"], - topk_mask: Float[Tensor, "batch ... C"], -) -> Float[Tensor, "batch ... d_out"]: - """Reference implementation that applies the mask after the full computation, rather than - after the first multiplication by A (which is done for efficiency in the code). - """ - # Apply A and B matrices - inner = einops.einsum(x, A, "batch ... d_in, ... C d_in m -> batch ... C m") - comp_acts = einops.einsum(inner, B, "batch ... C m, ... C m d_out -> batch ... C d_out") - - # Apply mask and sum - out = einops.einsum(comp_acts, topk_mask, "batch ... C d_out, batch ... C -> batch ... d_out") - - return out - - -def test_linear_component_mask_values(): - """Test that masking works correctly with different mask values.""" - batch_size, d_in, d_out, C, m = 2, 8, 8, 4, 4 - - component = LinearComponent(d_in=d_in, d_out=d_out, C=C, m=m) - x = torch.randn(batch_size, d_in) - - # Test with various mask patterns - test_masks = [ - torch.ones(batch_size, C), # All ones - torch.zeros(batch_size, C), # All zeros - torch.eye(C)[None].expand(batch_size, -1, -1)[:, :C], # Identity-like - torch.rand(batch_size, C), # Random values - ] - - for topk_mask in test_masks: - actual_output = component(x, topk_mask) - expected_output = reference_forward_with_mask(x, component.A, component.B, topk_mask) - torch.testing.assert_close(actual_output, expected_output) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index e50a43f..292aaa5 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -46,17 +46,17 @@ def test_resid_mlp_decomposition_happy_path() -> None: device = "cpu" config = Config( seed=0, - C=3, - topk=1, - batch_topk=True, + m=2, param_match_coeff=1.0, - topk_recon_coeff=1, - schatten_pnorm=1, - schatten_coeff=1, + masked_recon_coeff=1, + act_recon_coeff=1, + post_relu_act_recon=True, + lp_sparsity_coeff=1.0, + pnorm=0.9, attribution_type="gradient", lr=1e-3, batch_size=32, - steps=10, # Run only a few steps for the test + steps=50, # Run only a few steps for the test print_freq=2, image_freq=5, save_freq=None, @@ -70,7 +70,7 @@ def test_resid_mlp_decomposition_happy_path() -> None: target_model = ResidualMLPModel(config=resid_mlp_config).to(device) # Create the SPD model - spd_config = ResidualMLPSPDConfig(**resid_mlp_config.model_dump(), C=config.C) + spd_config = ResidualMLPSPDConfig(**resid_mlp_config.model_dump(), m=config.m) model = ResidualMLPSPDModel(config=spd_config).to(device) # Use the pretrained model's embedding matrices and don't train them further @@ -103,12 +103,6 @@ def test_resid_mlp_decomposition_happy_path() -> None: ) dataloader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False) - # Set up param_map - param_map = {} - for i in range(resid_mlp_config.n_layers): - param_map[f"layers.{i}.mlp_in"] = f"layers.{i}.mlp_in" - param_map[f"layers.{i}.mlp_out"] = f"layers.{i}.mlp_out" - # Calculate initial loss with torch.inference_mode(): batch, _ = next(iter(dataloader)) @@ -140,7 +134,7 @@ def test_resid_mlp_decomposition_happy_path() -> None: print(f"Final loss: {final_loss}, initial loss: {initial_loss}") # Assert that the final loss is lower than the initial loss assert ( - final_loss < initial_loss + final_loss < initial_loss + 1e-3 ), f"Expected final loss to be lower than initial loss, but got {final_loss} >= {initial_loss}" # Show that W_E is still the same as the target model's W_E @@ -150,6 +144,7 @@ def test_resid_mlp_decomposition_happy_path() -> None: def test_resid_mlp_equivalent_to_raw_model() -> None: device = "cpu" set_seed(0) + m = 4 resid_mlp_config = ResidualMLPConfig( n_instances=2, n_features=3, @@ -161,12 +156,11 @@ def test_resid_mlp_equivalent_to_raw_model() -> None: in_bias=True, out_bias=True, ) - C = 2 target_model = ResidualMLPModel(config=resid_mlp_config).to(device) # Create the SPD model with k=1 - resid_mlp_spd_config = ResidualMLPSPDConfig(**resid_mlp_config.model_dump(), C=C) + resid_mlp_spd_config = ResidualMLPSPDConfig(**resid_mlp_config.model_dump(), m=m) spd_model = ResidualMLPSPDModel(config=resid_mlp_spd_config).to(device) # Init all params to random values diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index 96f0219..a5962d6 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -76,40 +76,40 @@ def test_calc_param_match_loss_multiple_instances(self): class TestCalcActReconLoss: - def test_calc_topk_act_recon_simple(self): + def test_calc_act_recon_simple(self): # Batch size 2, d_out 2 target_post_weight_acts = {"layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]])} - layer_acts_topk = {"layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]])} + layer_acts = {"layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]])} expected = torch.tensor(0.0) - result = calc_act_recon(target_post_weight_acts, layer_acts_topk) + result = calc_act_recon(target_post_weight_acts, layer_acts) torch.testing.assert_close(result, expected) - def test_calc_topk_act_recon_different_d_out(self): + def test_calc_act_recon_different_d_out(self): # Batch size 2, d_out 2/3 target_post_weight_acts = { "layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]]), "layer2": torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]), } - layer_acts_topk = { + layer_acts = { "layer1": torch.tensor([[1.5, 2.5], [4.0, 5.0]]), "layer2": torch.tensor([[5.5, 6.5, 7.5], [9.0, 10.0, 11.0]]), } expected = torch.tensor((0.25 + 1) / 2) # ((0.5^2 * 5) / 5 + (1^2 * 5) / 5) / 2 - result = calc_act_recon(target_post_weight_acts, layer_acts_topk) + result = calc_act_recon(target_post_weight_acts, layer_acts) torch.testing.assert_close(result, expected) - def test_calc_topk_act_recon_with_n_instances(self): + def test_calc_act_recon_with_n_instances(self): target_post_weight_acts = { "layer1": torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]), "layer2": torch.tensor([[[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]]]), } - layer_acts_topk = { + layer_acts = { "layer1": torch.tensor([[[1.5, 2.5], [3.5, 4.5]], [[5.5, 6.5], [7.5, 8.5]]]), "layer2": torch.tensor([[[9.5, 10.5], [11.5, 12.5]], [[13.5, 14.5], [15.5, 16.5]]]), } expected = torch.tensor([0.25, 0.25]) # (0.5^2 * 8) / 8 for each instance - result = calc_act_recon(target_post_weight_acts, layer_acts_topk) + result = calc_act_recon(target_post_weight_acts, layer_acts) torch.testing.assert_close(result, expected) diff --git a/tests/test_spd_model.py b/tests/test_spd_model.py deleted file mode 100644 index 63d2b99..0000000 --- a/tests/test_spd_model.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch - -from spd.experiments.resid_mlp.models import ResidualMLPSPDConfig, ResidualMLPSPDModel -from spd.experiments.tms.models import TMSSPDModel, TMSSPDModelConfig - - -def test_tms_set_and_restore_subnet(): - subnet_idx = 2 - config = TMSSPDModelConfig( - n_instances=2, - n_features=4, - n_hidden=3, - C=5, - n_hidden_layers=1, - bias_val=0.0, - device="cpu", - ) - model = TMSSPDModel(config) - assert model.linear1.component_weights.shape == (2, 5, 4, 3) # (n_instances, C, d_in, d_out) - - # Get the original values of the weight_matrix of subnet_idx - original_vals = model.linear1.component_weights[:, subnet_idx, :, :].detach().clone() - - # Now set the 3rd subnet to zero - stored_vals = model.set_subnet_to_zero(subnet_idx=subnet_idx, has_instance_dim=True) - - # Check that model.linear1.component_weights is zero for all instances - assert model.linear1.component_weights[:, subnet_idx, :, :].allclose( - torch.zeros_like(model.linear1.component_weights[:, subnet_idx, :, :]) - ) - assert subnet_idx != 0 - # Check that it's not zero in another component - assert not model.linear1.component_weights[:, 0, :, :].allclose( - torch.zeros_like(model.linear1.component_weights[:, 0, :, :]) - ) - - # Now restore the subnet - model.restore_subnet(subnet_idx=subnet_idx, stored_vals=stored_vals, has_instance_dim=True) - assert model.linear1.component_weights[:, subnet_idx, :, :].allclose(original_vals) - - -def test_resid_mlp_set_and_restore_subnet(): - subnet_idx = 2 - config = ResidualMLPSPDConfig( - n_instances=2, - n_features=4, - d_embed=6, - d_mlp=8, - n_layers=1, - act_fn_name="gelu", - apply_output_act_fn=False, - in_bias=False, - out_bias=False, - init_scale=1.0, - C=5, - init_type="xavier_normal", - ) - model = ResidualMLPSPDModel(config) - - # Check shapes of first layer's component weights - assert model.layers[0].mlp_in.component_weights.shape == (2, 5, 6, 8) # n_inst, C, d_in, d_out - - # Get the original values of the weight_matrix of subnet_idx for both mlp_in and mlp_out - original_vals_in = ( - model.layers[0].mlp_in.component_weights[:, subnet_idx, :, :].detach().clone() - ) - original_vals_out = ( - model.layers[0].mlp_out.component_weights[:, subnet_idx, :, :].detach().clone() - ) - - # Set the subnet to zero - stored_vals = model.set_subnet_to_zero(subnet_idx=subnet_idx, has_instance_dim=True) - - # Check that component_weights are zero for all instances in both mlp_in and mlp_out - assert ( - model.layers[0] - .mlp_in.component_weights[:, subnet_idx, :, :] - .allclose(torch.zeros_like(model.layers[0].mlp_in.component_weights[:, subnet_idx, :, :])) - ) - assert ( - model.layers[0] - .mlp_out.component_weights[:, subnet_idx, :, :] - .allclose(torch.zeros_like(model.layers[0].mlp_out.component_weights[:, subnet_idx, :, :])) - ) - - assert subnet_idx != 0 - # Check that it's not zero in another component - assert ( - not model.layers[0] - .mlp_in.component_weights[:, 0, :, :] - .allclose(torch.zeros_like(model.layers[0].mlp_in.component_weights[:, 0, :, :])) - ) - assert ( - not model.layers[0] - .mlp_out.component_weights[:, 0, :, :] - .allclose(torch.zeros_like(model.layers[0].mlp_out.component_weights[:, 0, :, :])) - ) - - # Restore the subnet - model.restore_subnet(subnet_idx=subnet_idx, stored_vals=stored_vals, has_instance_dim=True) - - # Verify restoration was successful - assert model.layers[0].mlp_in.component_weights[:, subnet_idx, :, :].allclose(original_vals_in) - assert ( - model.layers[0].mlp_out.component_weights[:, subnet_idx, :, :].allclose(original_vals_out) - ) diff --git a/tests/test_tms.py b/tests/test_tms.py index a59dded..55fc056 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -1,6 +1,5 @@ from pathlib import Path -import pytest import torch from jaxtyping import Float from torch import Tensor @@ -38,9 +37,7 @@ def tms_spd_happy_path(config: Config, n_hidden_layers: int = 0): target_model = TMSModel(config=tms_model_config) tms_spd_model_config = TMSSPDModelConfig( - **tms_model_config.model_dump(mode="json"), - C=config.C, - bias_val=config.task_config.bias_val, + **tms_model_config.model_dump(mode="json"), m=config.m, bias_val=config.task_config.bias_val ) model = TMSSPDModel(config=tms_spd_model_config) # Randomly initialize the bias for the pretrained model @@ -85,66 +82,9 @@ def tms_spd_happy_path(config: Config, n_hidden_layers: int = 0): ), "Model A matrix should have changed after optimization" -def test_tms_batch_topk_no_schatten(): +def test_tms_happy_path(): config = Config( - C=5, - topk=2, - batch_topk=True, - batch_size=4, - steps=4, - print_freq=2, - save_freq=None, - lr=1e-3, - topk_recon_coeff=1, - schatten_pnorm=None, - schatten_coeff=None, - task_config=TMS_TASK_CONFIG, - ) - tms_spd_happy_path(config) - - -@pytest.mark.parametrize("n_hidden_layers", [0, 2]) -def test_tms_batch_topk_and_schatten(n_hidden_layers: int): - config = Config( - C=5, - topk=2, - batch_topk=True, - batch_size=4, - steps=4, - print_freq=2, - save_freq=None, - lr=1e-3, - topk_recon_coeff=1, - schatten_pnorm=0.9, - schatten_coeff=1e-1, - task_config=TMS_TASK_CONFIG, - ) - tms_spd_happy_path(config, n_hidden_layers) - - -def test_tms_topk_and_l2(): - config = Config( - C=5, - topk=2, - batch_topk=False, - batch_size=4, - steps=4, - print_freq=2, - save_freq=None, - lr=1e-3, - topk_recon_coeff=1, - schatten_pnorm=0.9, - schatten_coeff=1e-1, - task_config=TMS_TASK_CONFIG, - ) - tms_spd_happy_path(config) - - -def test_tms_lp(): - config = Config( - C=5, - topk=None, - batch_topk=False, + m=10, batch_size=4, steps=4, print_freq=2, @@ -157,25 +97,6 @@ def test_tms_lp(): tms_spd_happy_path(config) -@pytest.mark.parametrize("n_hidden_layers", [0, 2]) -def test_tms_topk_and_lp(n_hidden_layers: int): - config = Config( - C=5, - topk=2, - batch_topk=False, - batch_size=4, - steps=4, - print_freq=2, - save_freq=None, - lr=1e-3, - pnorm=0.9, - topk_recon_coeff=1, - lp_sparsity_coeff=1, - task_config=TMS_TASK_CONFIG, - ) - tms_spd_happy_path(config, n_hidden_layers) - - def test_train_tms_happy_path(): device = "cpu" set_seed(0) @@ -298,14 +219,12 @@ def test_tms_equivalent_to_raw_model() -> None: n_hidden_layers=1, device=device, ) - C = 2 target_model = TMSModel(config=tms_config).to(device) # Create the SPD model tms_spd_config = TMSSPDModelConfig( **tms_config.model_dump(), - C=C, m=3, # Small m for testing bias_val=0.0, ) diff --git a/tests/test_utils.py b/tests/test_utils.py index ed2c812..d48ed88 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,62 +5,7 @@ from jaxtyping import Float from torch import Tensor -from spd.utils import SparseFeatureDataset, calc_topk_mask, compute_feature_importances - - -def test_calc_topk_mask_without_batch_topk(): - attribution_scores = torch.tensor([[1.0, 5.0, 2.0, 1.0, 2.0], [3.0, 3.0, 5.0, 4.0, 4.0]]) - topk = 3 - expected_mask = torch.tensor( - [[False, True, True, False, True], [False, False, True, True, True]] - ) - - result = calc_topk_mask(attribution_scores, topk, batch_topk=False) - torch.testing.assert_close(result, expected_mask) - - -def test_calc_topk_mask_with_batch_topk(): - attribution_scores = torch.tensor([[1.0, 5.0, 2.0, 1.0, 2.0], [3.0, 3.0, 5.0, 4.0, 4.0]]) - topk = 3 # mutliplied by batch size to get 6 - expected_mask = torch.tensor( - [[False, True, False, False, False], [True, True, True, True, True]] - ) - - result = calc_topk_mask(attribution_scores, topk, batch_topk=True) - torch.testing.assert_close(result, expected_mask) - - -def test_calc_topk_mask_without_batch_topk_n_instances(): - """attributions have shape [batch, n_instances, n_features]. We take the topk - over the n_features dim for each instance in each batch.""" - attribution_scores = torch.tensor( - [[[1.0, 5.0, 3.0, 4.0], [2.0, 4.0, 6.0, 1.0]], [[2.0, 1.0, 5.0, 9.5], [3.0, 4.0, 1.0, 5.0]]] - ) - topk = 2 - expected_mask = torch.tensor( - [ - [[False, True, False, True], [False, True, True, False]], - [[False, False, True, True], [False, True, False, True]], - ] - ) - - result = calc_topk_mask(attribution_scores, topk, batch_topk=False) - torch.testing.assert_close(result, expected_mask) - - -def test_calc_topk_mask_with_batch_topk_n_instances(): - """attributions have shape [batch, n_instances, n_features]. We take the topk - over the concatenated batch and n_features dim.""" - attribution_scores = torch.tensor( - [[[1.0, 5.0, 3.0], [2.0, 4.0, 6.0]], [[2.0, 1.0, 5.0], [3.0, 4.0, 1.0]]] - ) - topk = 2 # multiplied by batch size to get 4 - expected_mask = torch.tensor( - [[[False, True, True], [False, True, True]], [[True, False, True], [True, True, False]]] - ) - - result = calc_topk_mask(attribution_scores, topk, batch_topk=True) - torch.testing.assert_close(result, expected_mask) +from spd.utils import SparseFeatureDataset, compute_feature_importances def test_dataset_at_least_zero_active(): From c784489c687190cd384b7eaa10cb6e46b40eca5d Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 13 Feb 2025 13:41:08 +0000 Subject: [PATCH 03/73] Fix grad attributions and calc_recon_mse --- spd/attributions.py | 67 ++++++------------------------- spd/run_spd.py | 97 +++++++++++++++++++++++---------------------- 2 files changed, 62 insertions(+), 102 deletions(-) diff --git a/spd/attributions.py b/spd/attributions.py index 7b43b7c..42f2663 100644 --- a/spd/attributions.py +++ b/spd/attributions.py @@ -5,10 +5,6 @@ from jaxtyping import Float from torch import Tensor -from spd.hooks import HookedRootModule -from spd.models.base import SPDModel -from spd.module_utils import collect_nested_module_attrs - def calc_grad_attributions( target_out: Float[Tensor, "batch d_out"] | Float[Tensor, "batch n_instances d_out"], @@ -18,14 +14,16 @@ def calc_grad_attributions( post_weight_acts: dict[ str, Float[Tensor, "batch d_out"] | Float[Tensor, "batch n_instances d_out"] ], - component_acts: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], Bs: dict[str, Float[Tensor, "m d_out"] | Float[Tensor, "n_instances m d_out"]], + target_component_acts: dict[ + str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] + ], ) -> dict[str, Float[Tensor, "batch C"] | Float[Tensor, "batch n_instances C"]]: """Calculate the sum of the (squared) attributions from each output dimension. An attribution is the product of the gradient of the target model output w.r.t. the post acts and the component_acts. I.e. - sum_i[((pre_weight_acts @ A) @ (B @ d(out_i)/d(post_weight_acts))) ** 2] + sum_i[((pre_weight_acts @ A) * (B @ d(out_i)/d(post_weight_acts))) ** 2] Note: This code may be run in between the training forward pass, and the loss.backward() and opt.step() calls; it must not mess with the training. The reason the current implementation is @@ -39,8 +37,8 @@ def calc_grad_attributions( target_out: The output of the target model. pre_weight_acts: The activations of the target model before the weight matrix at each layer. post_weight_acts: The activations at the target model after the weight matrix at each layer. - component_acts: The activations after multiplying by A at each layer. Bs: The B matrix at each layer. + target_component_acts: The component acts at each layer. (I.e. (pre_weight_acts @ A)) Returns: A dictionary of the sum of the (squared) attributions from each output dimension for each @@ -49,11 +47,14 @@ def calc_grad_attributions( # Ensure that all keys are the same after removing the hook suffixes post_weight_act_names = [comp.removesuffix(".hook_post") for comp in post_weight_acts] pre_weight_act_names = [comp.removesuffix(".hook_pre") for comp in pre_weight_acts] - component_act_names = [comp.removesuffix(".hook_component_acts") for comp in component_acts] - assert set(post_weight_act_names) == set(pre_weight_act_names) == set(component_act_names) + assert ( + set(post_weight_act_names) + == set(pre_weight_act_names) + == set(Bs.keys()) + == set(target_component_acts.keys()) + ) m = next(iter(Bs.values())).shape[-2] - attr_shape = target_out.shape[:-1] + (m,) # (batch, m) or (batch, n_instances, m) attributions: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]] = { param_name: torch.zeros(attr_shape, device=target_out.device, dtype=target_out.dtype) @@ -71,50 +72,6 @@ def calc_grad_attributions( grad_B = einops.einsum( Bs[param_name], grad_post_weight_acts[i], "... m d_out, ... d_out -> ... m" ) - attributions[param_name] += ( - component_acts[param_name + ".hook_component_acts"] * grad_B - ) ** 2 + attributions[param_name] += (target_component_acts[param_name] * grad_B) ** 2 return attributions - - -def collect_subnetwork_attributions( - spd_model: SPDModel, - target_model: HookedRootModule, - device: str, - n_instances: int | None = None, -) -> dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]: - """ - Collect subnetwork attributions. - - This function creates a test batch using an identity matrix, passes it through the model, - and collects the attributions. - - Args: - spd_model: The model to collect attributions on. - config: The main SPD config. - target_model: The target model to collect attributions on. - device: The device to run computations on. - n_instances: The number of instances in the batch. - - Returns: - The attribution scores. - """ - test_batch = torch.eye(spd_model.n_features, device=device) - if n_instances is not None: - test_batch = einops.repeat( - test_batch, "batch n_features -> batch n_instances n_features", n_instances=n_instances - ) - target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) - target_out, target_cache = target_model.run_with_cache( - test_batch, names_filter=target_cache_filter - ) - - attribution_scores = calc_grad_attributions( - target_out=target_out, - pre_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_pre")}, - post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")}, - component_acts={k: v for k, v in target_cache.items() if k.endswith("hook_component_acts")}, - Bs=collect_nested_module_attrs(spd_model, attr_name="B", include_attr_name=False), - ) - return attribution_scores diff --git a/spd/run_spd.py b/spd/run_spd.py index 87bf039..06f29cc 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -3,6 +3,7 @@ from collections.abc import Callable from pathlib import Path +import einops import matplotlib.pyplot as plt import torch import wandb @@ -114,38 +115,26 @@ def calc_lp_sparsity_loss( return total_loss -def calc_act_recon( - target_post_weight_acts: dict[ - str, Float[Tensor, "batch n_instances d_out"] | Float[Tensor, "batch d_out"] - ], - layer_acts: dict[str, Float[Tensor, "batch n_instances d_out"] | Float[Tensor, "batch d_out"]], +def calc_act_recon_mse( + acts1: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], + acts2: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], ) -> Float[Tensor, ""] | Float[Tensor, " n_instances"]: - """MSE between all target model activations and the output of each subnetwork in the SPD model. - - Args: - target_post_weight_acts: The activations after each layer in the target model. - layer_acts: The activations after each subnetwork in the SPD model. - + """MSE between each entry in acts1 and acts2. Returns: The activation reconstruction loss. Will have an n_instances dimension if the model has an n_instances dimension, otherwise a scalar. """ - assert ( - target_post_weight_acts.keys() == layer_acts.keys() - ), f"Layer keys must match: {target_post_weight_acts.keys()} != {layer_acts.keys()}" + assert acts1.keys() == acts2.keys(), f"Key mismatch: {acts1.keys()} != {acts2.keys()}" - device = next(iter(layer_acts.values())).device + device = next(iter(acts1.values())).device + m = next(iter(acts1.values())).shape[-1] - total_act_dim = 0 # Accumulate the d_out over all layers for normalization loss = torch.zeros(1, device=device) - for layer_name in target_post_weight_acts: - total_act_dim += target_post_weight_acts[layer_name].shape[-1] - - error = ((target_post_weight_acts[layer_name] - layer_acts[layer_name]) ** 2).sum(dim=-1) - loss = loss + error + for layer_name in acts1: + loss = loss + ((acts1[layer_name] - acts2[layer_name]) ** 2).sum(dim=-1) # Normalize by the total number of output dimensions and mean over the batch dim - return (loss / total_act_dim).mean(dim=0) + return (loss / (m * len(acts1))).mean(dim=0) def calc_masks( @@ -167,10 +156,31 @@ def calc_masks( """ masks = {} for layer_name in gates: - masks[layer_name] = gates[layer_name](component_acts[layer_name + ".hook_component_acts"]) + masks[layer_name] = gates[layer_name](component_acts[layer_name]) return masks +def calc_component_acts( + pre_weight_acts: dict[ + str, Float[Tensor, "batch n_instances d_in"] | Float[Tensor, "batch d_in"] + ], + As: dict[str, Float[Tensor, "d_in m"] | Float[Tensor, "n_instances d_in m"]], +) -> dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]: + """Calculate the component acts for each layer. I.e. (pre_weight_acts @ A). + + Args: + pre_weight_acts: The activations before each layer in the target model. + As: The A matrix at each layer. + """ + component_acts = {} + for param_name in pre_weight_acts: + raw_name = param_name.removesuffix(".hook_pre") + component_acts[raw_name] = einops.einsum( + pre_weight_acts[param_name], As[raw_name], "... d_in, ... d_in m -> ... m" + ) + return component_acts + + def optimize( model: SPDModel, config: Config, @@ -245,8 +255,7 @@ def optimize( ) # Do a forward pass with all subnetworks - spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) - out, spd_cache = model.run_with_cache(batch, names_filter=spd_cache_filter) + out = model(batch) # Calculate losses out_recon_loss = calc_recon_mse(out, target_out, has_instance_dim) @@ -261,18 +270,22 @@ def optimize( post_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_post")} pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} - component_acts = {k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts")} + + target_component_acts = calc_component_acts( + pre_weight_acts=pre_weight_acts, + As=collect_nested_module_attrs(model, attr_name="A", include_attr_name=False), + ) attributions = calc_grad_attributions( target_out=target_out, pre_weight_acts=pre_weight_acts, post_weight_acts=post_weight_acts, - component_acts={ - k: v for k, v in spd_cache.items() if k.endswith("hook_component_acts") - }, + target_component_acts=target_component_acts, Bs=collect_nested_module_attrs(model, attr_name="B", include_attr_name=False), ) - masks = calc_masks(gates=gates, component_acts=component_acts, attributions=attributions) + masks = calc_masks( + gates=gates, component_acts=target_component_acts, attributions=attributions + ) normed_masks = {k: v / out.shape[-1] for k, v in masks.items()} lp_sparsity_loss_per_m = calc_lp_sparsity_loss(masks=normed_masks, step_pnorm=config.pnorm) @@ -280,30 +293,20 @@ def optimize( lp_sparsity_loss = lp_sparsity_loss_per_m.sum(dim=-1).mean(dim=0) # Masked forward pass + spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) out_masked, spd_cache_masked = model.run_with_cache( batch, names_filter=spd_cache_filter, masks=masks ) masked_recon_loss = calc_recon_mse(out_masked, target_out, has_instance_dim) - layer_acts_masked = {k: v for k, v in spd_cache_masked.items() if k.endswith("hook_post")} - act_recon_loss = None if config.act_recon_coeff is not None: - act_recon_layer_acts = layer_acts_masked - target_post_weight_acts = post_weight_acts - if config.post_relu_act_recon: - relu = torch.nn.functional.relu - # Only do post-relu act recon for mlp_in layers and ignore the other layers - act_recon_layer_acts = { - k: relu(v) for k, v in layer_acts_masked.items() if "mlp_in" in k - } - target_post_weight_acts = { - k: relu(v) for k, v in target_post_weight_acts.items() if "mlp_in" in k - } - act_recon_loss = calc_act_recon( - target_post_weight_acts=target_post_weight_acts, - layer_acts=act_recon_layer_acts, - ) + masked_spd_component_acts = { + k.removesuffix(".hook_component_acts"): v + for k, v in spd_cache_masked.items() + if k.endswith("hook_component_acts") + } + act_recon_loss = calc_act_recon_mse(masked_spd_component_acts, target_component_acts) loss_terms = { "param_match_loss": (param_match_loss, config.param_match_coeff), From e3c3eb0ef67896e984ccee0608523f63c02e2ee4 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 13 Feb 2025 13:43:54 +0000 Subject: [PATCH 04/73] Init gate with bias=1 and weights normal dist mean=0 std=0.2 --- spd/models/components.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spd/models/components.py b/spd/models/components.py index a5b1c86..ad67c71 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -21,8 +21,8 @@ def __init__(self, m: int, n_instances: int | None = None): self.n_instances = n_instances shape = (n_instances, m) if n_instances is not None else (m,) self.weight = nn.Parameter(torch.empty(shape)) - init_param_(self.weight, scale=1.0, init_type="kaiming_uniform") - self.bias = nn.Parameter(torch.zeros(shape)) + torch.nn.init.normal_(self.weight, mean=0.0, std=0.2) + self.bias = nn.Parameter(torch.ones(shape)) def forward( self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] From 15b310c6ae93bb4528a87feafa7d0e049410319d Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 13 Feb 2025 14:03:34 +0000 Subject: [PATCH 05/73] Fix lp sparsity loss --- spd/attributions.py | 4 ++++ spd/run_spd.py | 37 +++++++++++++++++++++---------------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/spd/attributions.py b/spd/attributions.py index 42f2663..840902b 100644 --- a/spd/attributions.py +++ b/spd/attributions.py @@ -74,4 +74,8 @@ def calc_grad_attributions( ) attributions[param_name] += (target_component_acts[param_name] * grad_B) ** 2 + # Take the square root of each attribution and divide by the number of output dimensions + for param_name in attributions: + attributions[param_name] = attributions[param_name].sqrt() / target_out.shape[-1] + return attributions diff --git a/spd/run_spd.py b/spd/run_spd.py index 06f29cc..db240cb 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -92,27 +92,31 @@ def calc_param_match_loss( def calc_lp_sparsity_loss( - masks: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], - step_pnorm: float, + target_component_acts: dict[ + str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] + ], + pnorm: float, ) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: """Calculate the Lp sparsity loss on the attributions. Args: - masks: Dictionary of masks for each layer to use for the sparsity loss. - step_pnorm: The pnorm to use for the sparsity loss. + target_component_acts: Dictionary of pre_weight_acts @ A for each layer to use for the + sparsity loss. + pnorm: The pnorm to use for the sparsity loss. Returns: The Lp sparsity loss. Will have an n_instances dimension if the model has an n_instances dimension. """ # Initialize with zeros matching the shape of first mask - total_loss = torch.zeros_like(next(iter(masks.values()))) + total_loss = torch.zeros_like(next(iter(target_component_acts.values()))) - for layer_mask in masks.values(): - # step_pnorm * 0.5 is because we have the squares of sparsity_inner terms above - layer_loss = (layer_mask.abs() + 1e-16) ** (step_pnorm * 0.5) + for layer_target_component_acts in target_component_acts.values(): + layer_loss = layer_target_component_acts.relu() ** pnorm total_loss = total_loss + layer_loss - return total_loss + m = next(iter(target_component_acts.values())).shape[-1] + # Sum over the batch and m dimensions and normalize by the n_layers * m + return total_loss.sum(dim=(0, -1)) / (len(target_component_acts) * m) def calc_act_recon_mse( @@ -139,7 +143,9 @@ def calc_act_recon_mse( def calc_masks( gates: dict[str, Gate], - component_acts: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], + target_component_acts: dict[ + str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] + ], attributions: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], ) -> dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]: """Calculate the mask for the SPD model. @@ -156,7 +162,7 @@ def calc_masks( """ masks = {} for layer_name in gates: - masks[layer_name] = gates[layer_name](component_acts[layer_name]) + masks[layer_name] = gates[layer_name](target_component_acts[layer_name]) return masks @@ -284,13 +290,12 @@ def optimize( ) masks = calc_masks( - gates=gates, component_acts=target_component_acts, attributions=attributions + gates=gates, target_component_acts=target_component_acts, attributions=attributions ) - normed_masks = {k: v / out.shape[-1] for k, v in masks.items()} - lp_sparsity_loss_per_m = calc_lp_sparsity_loss(masks=normed_masks, step_pnorm=config.pnorm) - # Sum over the m dimension (-1) and mean over the batch dimension (0) - lp_sparsity_loss = lp_sparsity_loss_per_m.sum(dim=-1).mean(dim=0) + lp_sparsity_loss = calc_lp_sparsity_loss( + target_component_acts=target_component_acts, pnorm=config.pnorm + ) # Masked forward pass spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) From 3aff69bf0734053483537d7fb327b5b450f7803e Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 13 Feb 2025 14:53:55 +0000 Subject: [PATCH 06/73] Add random mask loss --- spd/configs.py | 2 + .../resid_mlp/resid_mlp_config.yaml | 8 +- spd/experiments/tms/tms_config.yaml | 6 +- spd/run_spd.py | 83 +++++++++++++++---- tests/test_resid_mlp.py | 2 + tests/test_spd_losses.py | 42 +--------- tests/test_tms.py | 2 + 7 files changed, 85 insertions(+), 60 deletions(-) diff --git a/spd/configs.py b/spd/configs.py index 09b3602..a119001 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -56,10 +56,12 @@ class Config(BaseModel): act_recon_coeff: NonNegativeFloat | None = None param_match_coeff: NonNegativeFloat | None = 1.0 masked_recon_coeff: NonNegativeFloat | None = None + random_mask_recon_coeff: NonNegativeFloat | None = None lp_sparsity_coeff: NonNegativeFloat pnorm: PositiveFloat post_relu_act_recon: bool = False m: PositiveInt + n_random_masks: PositiveInt lr_schedule: Literal["linear", "constant", "cosine", "exponential"] = "constant" lr_exponential_halflife: PositiveFloat | None = None lr_warmup_pct: Probability = 0.0 diff --git a/spd/experiments/resid_mlp/resid_mlp_config.yaml b/spd/experiments/resid_mlp/resid_mlp_config.yaml index 2772de0..cbc33fa 100644 --- a/spd/experiments/resid_mlp/resid_mlp_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp_config.yaml @@ -1,5 +1,5 @@ # ########## 1 layer ########## -wandb_project: spd-resid-mlp +# wandb_project: spd-resid-mlp wandb_run_name: null wandb_run_name_prefix: "" unit_norm_matrices: true @@ -9,12 +9,14 @@ param_match_coeff: 1.0 masked_recon_coeff: 1.0 act_recon_coeff: 1.0 post_relu_act_recon: true +random_mask_recon_coeff: 1.0 +n_random_masks: 2 pnorm: 0.9 lp_sparsity_coeff: 1.0 batch_size: 256 steps: 10_000 image_freq: 5_000 -print_freq: 500 +print_freq: 100 save_freq: 10_000 lr: 1e-3 lr_schedule: cosine @@ -39,6 +41,8 @@ task_config: # masked_recon_coeff: 2.0 # act_recon_coeff: 1.0 # post_relu_act_recon: true +# random_mask_recon_coeff: 1.0 +# n_random_masks: 2 # pnorm: 0.9 # lp_sparsity_coeff: 1.0 # batch_size: 256 diff --git a/spd/experiments/tms/tms_config.yaml b/spd/experiments/tms/tms_config.yaml index 2eba18e..36ff20e 100644 --- a/spd/experiments/tms/tms_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -8,6 +8,8 @@ # masked_recon_coeff: 1 # pnorm: 0.9 # lp_sparsity_coeff: 1.0 +# random_mask_recon_coeff: 1.0 +# n_random_masks: 2 # batch_size: 2048 # steps: 20_000 # image_freq: 5_000 @@ -35,10 +37,12 @@ param_match_coeff: 1.0 masked_recon_coeff: 10.0 pnorm: 0.9 lp_sparsity_coeff: 1.0 +random_mask_recon_coeff: 1.0 +n_random_masks: 2 batch_size: 2048 steps: 20_000 image_freq: 5_000 -print_freq: 1_000 +print_freq: 200 save_freq: 20_000 lr: 1e-3 lr_schedule: cosine diff --git a/spd/run_spd.py b/spd/run_spd.py index db240cb..871df12 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -166,6 +166,47 @@ def calc_masks( return masks +def calc_random_masks( + masks: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], + n_random_masks: int, +) -> list[dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]]: + """Calculate n_random_masks random masks with the formula `mask + (1 - mask) * rand_unif(0,1)`. + + Args: + masks: The masks to use for the random masks. + n_random_masks: The number of random masks to calculate. + + Return: + A list of n_random_masks dictionaries, each containing the random masks for each layer. + """ + random_masks = [] + for _ in range(n_random_masks): + random_masks.append( + { + layer_name: mask + (1 - mask) * torch.rand_like(mask) + for layer_name, mask in masks.items() + } + ) + return random_masks + + +def calc_random_masks_mse_loss( + model: SPDModel, + batch: Float[Tensor, "batch n_instances d_in"], + random_masks: list[dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]], + out_masked: Float[Tensor, "batch n_instances d_out"], +) -> Float[Tensor, ""] | Float[Tensor, " n_instances"]: + """Calculate the MSE over all random masks.""" + loss = torch.zeros(1, device=out_masked.device) + for i in range(len(random_masks)): + out_masked_random_mask = model(batch, masks=random_masks[i]) + loss = loss + ((out_masked - out_masked_random_mask) ** 2).sum(dim=-1) + + n_layers = len(random_masks[0]) + # Normalize by the total number of output dimensions and mean over the batch dim + return (loss / (len(random_masks) * n_layers * out_masked.shape[-1])).mean(dim=0) + + def calc_component_acts( pre_weight_acts: dict[ str, Float[Tensor, "batch n_instances d_in"] | Float[Tensor, "batch d_in"] @@ -255,25 +296,15 @@ def optimize( batch = batch.to(device=device) total_samples += batch.shape[0] + # Forward pass with target model target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) target_out, target_cache = target_model.run_with_cache( batch, names_filter=target_cache_filter ) - # Do a forward pass with all subnetworks + # Forward pass with all subnetworks out = model(batch) - # Calculate losses - out_recon_loss = calc_recon_mse(out, target_out, has_instance_dim) - - param_match_loss = calc_param_match_loss( - param_names=param_names, - target_model=target_model, - spd_model=model, - n_params=n_params, - device=device, - ) - post_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_post")} pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} @@ -293,15 +324,34 @@ def optimize( gates=gates, target_component_acts=target_component_acts, attributions=attributions ) - lp_sparsity_loss = calc_lp_sparsity_loss( - target_component_acts=target_component_acts, pnorm=config.pnorm - ) - # Masked forward pass spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) out_masked, spd_cache_masked = model.run_with_cache( batch, names_filter=spd_cache_filter, masks=masks ) + + random_masks_loss = None + if config.random_mask_recon_coeff is not None: + random_masks = calc_random_masks(masks=masks, n_random_masks=config.n_random_masks) + random_masks_loss = calc_random_masks_mse_loss( + model=model, batch=batch, random_masks=random_masks, out_masked=out_masked + ) + + # Calculate losses + out_recon_loss = calc_recon_mse(out, target_out, has_instance_dim) + + param_match_loss = calc_param_match_loss( + param_names=param_names, + target_model=target_model, + spd_model=model, + n_params=n_params, + device=device, + ) + + lp_sparsity_loss = calc_lp_sparsity_loss( + target_component_acts=target_component_acts, pnorm=config.pnorm + ) + masked_recon_loss = calc_recon_mse(out_masked, target_out, has_instance_dim) act_recon_loss = None @@ -319,6 +369,7 @@ def optimize( "lp_sparsity_loss": (lp_sparsity_loss, config.lp_sparsity_coeff), "masked_recon_loss": (masked_recon_loss, config.masked_recon_coeff), "act_recon_loss": (act_recon_loss, config.act_recon_coeff), + "random_masks_loss": (random_masks_loss, config.random_mask_recon_coeff), } # Add up the loss terms loss = torch.tensor(0.0, device=device) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index 292aaa5..ada410e 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -47,6 +47,8 @@ def test_resid_mlp_decomposition_happy_path() -> None: config = Config( seed=0, m=2, + random_mask_recon_coeff=1, + n_random_masks=2, param_match_coeff=1.0, masked_recon_coeff=1, act_recon_coeff=1, diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index a5962d6..c65d5a8 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -1,6 +1,6 @@ import torch -from spd.run_spd import _calc_param_mse, calc_act_recon +from spd.run_spd import _calc_param_mse class TestCalcParamMatchLoss: @@ -73,43 +73,3 @@ def test_calc_param_match_loss_multiple_instances(self): # Sum together and divide by n_params: [4, 16] / 12 = [1/3, 4/3] expected = torch.tensor([1.0 / 3.0, 4.0 / 3.0]) assert torch.allclose(result, expected), f"Expected {expected}, but got {result}" - - -class TestCalcActReconLoss: - def test_calc_act_recon_simple(self): - # Batch size 2, d_out 2 - target_post_weight_acts = {"layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]])} - layer_acts = {"layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]])} - expected = torch.tensor(0.0) - - result = calc_act_recon(target_post_weight_acts, layer_acts) - torch.testing.assert_close(result, expected) - - def test_calc_act_recon_different_d_out(self): - # Batch size 2, d_out 2/3 - target_post_weight_acts = { - "layer1": torch.tensor([[1.0, 2.0], [3.0, 4.0]]), - "layer2": torch.tensor([[5.0, 6.0, 7.0], [8.0, 9.0, 10.0]]), - } - layer_acts = { - "layer1": torch.tensor([[1.5, 2.5], [4.0, 5.0]]), - "layer2": torch.tensor([[5.5, 6.5, 7.5], [9.0, 10.0, 11.0]]), - } - expected = torch.tensor((0.25 + 1) / 2) # ((0.5^2 * 5) / 5 + (1^2 * 5) / 5) / 2 - - result = calc_act_recon(target_post_weight_acts, layer_acts) - torch.testing.assert_close(result, expected) - - def test_calc_act_recon_with_n_instances(self): - target_post_weight_acts = { - "layer1": torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]), - "layer2": torch.tensor([[[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]]]), - } - layer_acts = { - "layer1": torch.tensor([[[1.5, 2.5], [3.5, 4.5]], [[5.5, 6.5], [7.5, 8.5]]]), - "layer2": torch.tensor([[[9.5, 10.5], [11.5, 12.5]], [[13.5, 14.5], [15.5, 16.5]]]), - } - expected = torch.tensor([0.25, 0.25]) # (0.5^2 * 8) / 8 for each instance - - result = calc_act_recon(target_post_weight_acts, layer_acts) - torch.testing.assert_close(result, expected) diff --git a/tests/test_tms.py b/tests/test_tms.py index 55fc056..3eb593a 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -85,6 +85,8 @@ def tms_spd_happy_path(config: Config, n_hidden_layers: int = 0): def test_tms_happy_path(): config = Config( m=10, + random_mask_recon_coeff=1, + n_random_masks=2, batch_size=4, steps=4, print_freq=2, From 13b809715210f8abfc770cc3cc604f103a76bccf Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 13 Feb 2025 15:11:19 +0000 Subject: [PATCH 07/73] Use relud masks for lp sparsity loss --- spd/models/components.py | 5 +++++ spd/run_spd.py | 42 +++++++++++++++++++--------------------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/spd/models/components.py b/spd/models/components.py index ad67c71..e68463c 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -29,6 +29,11 @@ def forward( ) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: return hard_sigmoid(x * self.weight + self.bias) + def forward_relu( + self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] + ) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: + return (x * self.weight + self.bias).relu() + class Linear(nn.Module): """A linear transformation with an optional n_instances dimension.""" diff --git a/spd/run_spd.py b/spd/run_spd.py index 871df12..5db0492 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -92,31 +92,27 @@ def calc_param_match_loss( def calc_lp_sparsity_loss( - target_component_acts: dict[ - str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] - ], + relud_masks: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], pnorm: float, ) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: """Calculate the Lp sparsity loss on the attributions. Args: - target_component_acts: Dictionary of pre_weight_acts @ A for each layer to use for the - sparsity loss. + relud_masks: Dictionary of relu masks for each layer. pnorm: The pnorm to use for the sparsity loss. Returns: The Lp sparsity loss. Will have an n_instances dimension if the model has an n_instances dimension. """ # Initialize with zeros matching the shape of first mask - total_loss = torch.zeros_like(next(iter(target_component_acts.values()))) + total_loss = torch.zeros_like(next(iter(relud_masks.values()))) - for layer_target_component_acts in target_component_acts.values(): - layer_loss = layer_target_component_acts.relu() ** pnorm - total_loss = total_loss + layer_loss + for layer_relud_mask in relud_masks.values(): + total_loss = total_loss + layer_relud_mask**pnorm - m = next(iter(target_component_acts.values())).shape[-1] + m = next(iter(relud_masks.values())).shape[-1] # Sum over the batch and m dimensions and normalize by the n_layers * m - return total_loss.sum(dim=(0, -1)) / (len(target_component_acts) * m) + return total_loss.sum(dim=(0, -1)) / (len(relud_masks) * m) def calc_act_recon_mse( @@ -147,7 +143,10 @@ def calc_masks( str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] ], attributions: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], -) -> dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]: +) -> tuple[ + dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], + dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], +]: """Calculate the mask for the SPD model. TODO: Use attributions in our gate calculation too. @@ -161,9 +160,11 @@ def calc_masks( Dictionary of masks for each layer. """ masks = {} + relud_masks = {} for layer_name in gates: - masks[layer_name] = gates[layer_name](target_component_acts[layer_name]) - return masks + masks[layer_name] = gates[layer_name].forward(target_component_acts[layer_name]) + relud_masks[layer_name] = gates[layer_name].forward_relu(target_component_acts[layer_name]) + return masks, relud_masks def calc_random_masks( @@ -200,11 +201,10 @@ def calc_random_masks_mse_loss( loss = torch.zeros(1, device=out_masked.device) for i in range(len(random_masks)): out_masked_random_mask = model(batch, masks=random_masks[i]) - loss = loss + ((out_masked - out_masked_random_mask) ** 2).sum(dim=-1) + loss = loss + ((out_masked - out_masked_random_mask) ** 2).mean(dim=-1) - n_layers = len(random_masks[0]) - # Normalize by the total number of output dimensions and mean over the batch dim - return (loss / (len(random_masks) * n_layers * out_masked.shape[-1])).mean(dim=0) + # Normalize by the number of random masks and mean over the batch dim + return (loss / len(random_masks)).mean(dim=0) def calc_component_acts( @@ -320,7 +320,7 @@ def optimize( Bs=collect_nested_module_attrs(model, attr_name="B", include_attr_name=False), ) - masks = calc_masks( + masks, relud_masks = calc_masks( gates=gates, target_component_acts=target_component_acts, attributions=attributions ) @@ -348,9 +348,7 @@ def optimize( device=device, ) - lp_sparsity_loss = calc_lp_sparsity_loss( - target_component_acts=target_component_acts, pnorm=config.pnorm - ) + lp_sparsity_loss = calc_lp_sparsity_loss(relud_masks=relud_masks, pnorm=config.pnorm) masked_recon_loss = calc_recon_mse(out_masked, target_out, has_instance_dim) From 0923c0f65b39aaea5e77964012bf4a2bc2a5a1ae Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 13 Feb 2025 15:30:13 +0000 Subject: [PATCH 08/73] Use masked_target_component_acts in calc_act_recon_mse --- spd/run_spd.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index 5db0492..06c9649 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -228,6 +228,24 @@ def calc_component_acts( return component_acts +def calc_masked_target_component_acts( + pre_weight_acts: dict[ + str, Float[Tensor, "batch n_instances d_in"] | Float[Tensor, "batch d_in"] + ], + As: dict[str, Float[Tensor, "d_in m"] | Float[Tensor, "n_instances d_in m"]], + masks: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], +) -> dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]: + """Calculate the masked target component acts for each layer.""" + masked_target_component_acts = {} + for param_name in pre_weight_acts: + raw_name = param_name.removesuffix(".hook_pre") + masked_As = einops.einsum(As[raw_name], masks[raw_name], "... d_in m, ... m -> ... d_in m") + masked_target_component_acts[raw_name] = einops.einsum( + pre_weight_acts[param_name], masked_As, "... d_in, ... d_in m -> ... m" + ) + return masked_target_component_acts + + def optimize( model: SPDModel, config: Config, @@ -307,17 +325,16 @@ def optimize( post_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_post")} pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} + As = collect_nested_module_attrs(model, attr_name="A", include_attr_name=False) + Bs = collect_nested_module_attrs(model, attr_name="B", include_attr_name=False) - target_component_acts = calc_component_acts( - pre_weight_acts=pre_weight_acts, - As=collect_nested_module_attrs(model, attr_name="A", include_attr_name=False), - ) + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) attributions = calc_grad_attributions( target_out=target_out, pre_weight_acts=pre_weight_acts, post_weight_acts=post_weight_acts, target_component_acts=target_component_acts, - Bs=collect_nested_module_attrs(model, attr_name="B", include_attr_name=False), + Bs=Bs, ) masks, relud_masks = calc_masks( @@ -359,7 +376,12 @@ def optimize( for k, v in spd_cache_masked.items() if k.endswith("hook_component_acts") } - act_recon_loss = calc_act_recon_mse(masked_spd_component_acts, target_component_acts) + masked_target_component_acts = calc_masked_target_component_acts( + pre_weight_acts=pre_weight_acts, As=As, masks=masks + ) + act_recon_loss = calc_act_recon_mse( + masked_spd_component_acts, masked_target_component_acts + ) loss_terms = { "param_match_loss": (param_match_loss, config.param_match_coeff), From 3aceb8a8500da63900feddbe833911f6b46ea289 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 14 Feb 2025 10:31:21 +0000 Subject: [PATCH 09/73] Comment out grad attribution calculation so people don't use now --- spd/experiments/tms/tms_config.yaml | 8 ++++---- spd/run_spd.py | 21 ++++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/spd/experiments/tms/tms_config.yaml b/spd/experiments/tms/tms_config.yaml index 36ff20e..427a9a1 100644 --- a/spd/experiments/tms/tms_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -32,17 +32,17 @@ wandb_run_name: null wandb_run_name_prefix: "" unit_norm_matrices: false seed: 0 -m: 10 +m: 40 param_match_coeff: 1.0 masked_recon_coeff: 10.0 pnorm: 0.9 -lp_sparsity_coeff: 1.0 +lp_sparsity_coeff: 1e-4 random_mask_recon_coeff: 1.0 -n_random_masks: 2 +n_random_masks: 1 batch_size: 2048 steps: 20_000 image_freq: 5_000 -print_freq: 200 +print_freq: 1000 save_freq: 20_000 lr: 1e-3 lr_schedule: cosine diff --git a/spd/run_spd.py b/spd/run_spd.py index 06c9649..63ddd86 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -12,7 +12,6 @@ from torch.utils.data import DataLoader from tqdm import tqdm -from spd.attributions import calc_grad_attributions from spd.configs import Config from spd.hooks import HookedRootModule from spd.models.base import SPDModel @@ -142,7 +141,8 @@ def calc_masks( target_component_acts: dict[ str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] ], - attributions: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], + attributions: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]] + | None = None, ) -> tuple[ dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], @@ -323,19 +323,18 @@ def optimize( # Forward pass with all subnetworks out = model(batch) - post_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_post")} pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} As = collect_nested_module_attrs(model, attr_name="A", include_attr_name=False) - Bs = collect_nested_module_attrs(model, attr_name="B", include_attr_name=False) target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) - attributions = calc_grad_attributions( - target_out=target_out, - pre_weight_acts=pre_weight_acts, - post_weight_acts=post_weight_acts, - target_component_acts=target_component_acts, - Bs=Bs, - ) + # attributions = calc_grad_attributions( + # target_out=target_out, + # pre_weight_acts=pre_weight_acts, + # post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")}, + # target_component_acts=target_component_acts, + # Bs=collect_nested_module_attrs(model, attr_name="B", include_attr_name=False), + # ) + attributions = None masks, relud_masks = calc_masks( gates=gates, target_component_acts=target_component_acts, attributions=attributions From 61247dc2eb0e79024fd75ecc95d00ccec7e1c9bd Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 14 Feb 2025 12:12:42 +0000 Subject: [PATCH 10/73] Store gates in model class --- spd/experiments/resid_mlp/models.py | 20 +++++++++++-------- .../resid_mlp/resid_mlp_config.yaml | 2 +- spd/experiments/tms/models.py | 13 +++++++++++- spd/experiments/tms/tms_config.yaml | 2 +- spd/run_spd.py | 15 +++++--------- 5 files changed, 31 insertions(+), 21 deletions(-) diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index cf02787..3ae37e4 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -17,7 +17,7 @@ from spd.hooks import HookedRootModule from spd.log import logger from spd.models.base import SPDModel -from spd.models.components import Linear, LinearComponent +from spd.models.components import Gate, Linear, LinearComponent from spd.module_utils import init_param_ from spd.run_spd import Config from spd.types import WANDB_PATH_PREFIX, ModelPath @@ -293,6 +293,7 @@ def __init__( self.config = config self.n_features = config.n_features # Required for backward compatibility self.n_instances = config.n_instances # Required for backward compatibility + self.m = config.m assert config.act_fn_name in ["gelu", "relu"] self.act_fn = F.gelu if config.act_fn_name == "gelu" else F.relu @@ -302,10 +303,10 @@ def __init__( init_param_(self.W_E, init_type=config.init_type) init_param_(self.W_U, init_type=config.init_type) - self.m = config.m - - self.layers = nn.ModuleList( - [ + self.layers = nn.ModuleList() + self.gates = nn.ModuleDict() + for i in range(config.n_layers): + self.layers.append( MLP( n_instances=config.n_instances, d_model=config.d_embed, @@ -317,9 +318,12 @@ def __init__( act_fn=self.act_fn, spd_kwargs={"m": self.m}, ) - for _ in range(config.n_layers) - ] - ) + ) + # For now, we just define all the gates in this class rather than in the MLP class + # to make it easier to collect all the gates + self.gates[f"layers-{i}-mlp_in"] = Gate(m=self.m, n_instances=config.n_instances) + self.gates[f"layers-{i}-mlp_out"] = Gate(m=self.m, n_instances=config.n_instances) + self.setup() def forward( diff --git a/spd/experiments/resid_mlp/resid_mlp_config.yaml b/spd/experiments/resid_mlp/resid_mlp_config.yaml index cbc33fa..2d80d26 100644 --- a/spd/experiments/resid_mlp/resid_mlp_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp_config.yaml @@ -1,5 +1,5 @@ # ########## 1 layer ########## -# wandb_project: spd-resid-mlp +wandb_project: spd-resid-mlp wandb_run_name: null wandb_run_name_prefix: "" unit_norm_matrices: true diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 4482b22..5f6ec82 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -14,6 +14,7 @@ from spd.hooks import HookedRootModule from spd.models.base import SPDModel from spd.models.components import ( + Gate, Linear, LinearComponent, TransposedLinear, @@ -193,7 +194,6 @@ def __init__(self, config: TMSSPDModelConfig): m=self.m, ) self.linear2 = TransposedLinearComponent(self.linear1.A, self.linear1.B) - bias_data = ( torch.zeros((config.n_instances, config.n_features), device=config.device) + config.bias_val @@ -216,6 +216,17 @@ def __init__(self, config: TMSSPDModelConfig): ] ) + self.gates = nn.ModuleDict( + { + "linear1": Gate(m=self.m, n_instances=config.n_instances), + "linear2": Gate(m=self.m, n_instances=config.n_instances), + **{ + f"hidden_layers-{i}": Gate(m=self.m, n_instances=config.n_instances) + for i in range(config.n_hidden_layers) + }, + } + ) + self.setup() def forward( diff --git a/spd/experiments/tms/tms_config.yaml b/spd/experiments/tms/tms_config.yaml index 427a9a1..9db33e2 100644 --- a/spd/experiments/tms/tms_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -36,7 +36,7 @@ m: 40 param_match_coeff: 1.0 masked_recon_coeff: 10.0 pnorm: 0.9 -lp_sparsity_coeff: 1e-4 +lp_sparsity_coeff: 5e-4 random_mask_recon_coeff: 1.0 n_random_masks: 1 batch_size: 2048 diff --git a/spd/run_spd.py b/spd/run_spd.py index 63ddd86..887e319 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -24,7 +24,7 @@ def get_common_run_name_suffix(config: Config) -> str: """Generate a run suffix based on Config that is common to all experiments.""" run_suffix = "" if config.masked_recon_coeff is not None: - run_suffix += f"maskedrecon{config.masked_recon_coeff:.2e}_" + run_suffix += f"maskrecon{config.masked_recon_coeff:.2e}_" if config.act_recon_coeff is not None: run_suffix += f"actrecon_{config.act_recon_coeff:.2e}_" run_suffix += f"p{config.pnorm:.2e}_" @@ -260,18 +260,12 @@ def optimize( target_model.to(device=device) has_instance_dim = hasattr(model, "n_instances") - n_instances = model.n_instances if has_instance_dim else None - gates = { - param_name: Gate(n_instances=n_instances, m=config.m).to(device) - for param_name in param_names - } - all_params = list(model.parameters()) + [ - p for gate in gates.values() for p in gate.parameters() - ] + # We used "-" instead of "." as module names can't have "." in them + gates = {k.removeprefix("gates.").replace("-", "."): v for k, v in model.gates.items()} # Note that we expect weight decay to be problematic for spd models - opt = torch.optim.AdamW(all_params, lr=config.lr, weight_decay=0.0) + opt = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.0) lr_schedule_fn = get_lr_schedule_fn(config.lr_schedule, config.lr_exponential_halflife) @@ -434,6 +428,7 @@ def optimize( device=device, config=config, masks=masks, + gates=gates, batch=batch, ) if config.wandb_project: From 64c3a239ade47fa36a9a882f14ee68361baa2d03 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 14 Feb 2025 12:17:04 +0000 Subject: [PATCH 11/73] Remove buggy tms deprecated params replacement --- spd/experiments/tms/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 5f6ec82..4242e01 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -300,6 +300,5 @@ def from_pretrained(cls, path: ModelPath) -> tuple["TMSSPDModel", Config]: ) model = cls(config=tms_spd_config) params = torch.load(paths.checkpoint, weights_only=True, map_location="cpu") - params = replace_deprecated_param_names(params, {"A": "linear1.A", "B": "linear1.B"}) model.load_state_dict(params) return model, spd_config From ed32237a7511070ebcc864121f26073be52a5780 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 14 Feb 2025 15:03:14 +0000 Subject: [PATCH 12/73] Tie the gates for TMS --- spd/experiments/tms/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 4242e01..1ab48f4 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -216,10 +216,12 @@ def __init__(self, config: TMSSPDModelConfig): ] ) + # Same gate for linear1 and linear2 since the weights are tied + gate = Gate(m=self.m, n_instances=config.n_instances) self.gates = nn.ModuleDict( { - "linear1": Gate(m=self.m, n_instances=config.n_instances), - "linear2": Gate(m=self.m, n_instances=config.n_instances), + "linear1": gate, + "linear2": gate, **{ f"hidden_layers-{i}": Gate(m=self.m, n_instances=config.n_instances) for i in range(config.n_hidden_layers) From 60cc056557fcee0be2f0b61e4b497f5fd985de52 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 14 Feb 2025 15:05:40 +0000 Subject: [PATCH 13/73] Plot masks --- spd/experiments/tms/plotting.py | 89 ++++++++++++++++++++++++ spd/experiments/tms/tms_decomposition.py | 8 ++- 2 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 spd/experiments/tms/plotting.py diff --git a/spd/experiments/tms/plotting.py b/spd/experiments/tms/plotting.py new file mode 100644 index 0000000..8041701 --- /dev/null +++ b/spd/experiments/tms/plotting.py @@ -0,0 +1,89 @@ +import einops +import matplotlib.pyplot as plt +import numpy as np +import torch + +from spd.experiments.tms.models import TMSModel, TMSSPDModel +from spd.models.components import Gate +from spd.module_utils import collect_nested_module_attrs +from spd.run_spd import calc_component_acts, calc_masks + + +def plot_mask_vals( + model: TMSSPDModel, + target_model: TMSModel, + gates: dict[str, Gate], + device: str, + input_magnitude: float, +) -> plt.Figure: + """Plot the values of the mask for a batch of inputs with single active features.""" + # First, create a batch of inputs with single active features + n_features = model.n_features + n_instances = model.n_instances + batch = torch.eye(n_features, device=device) * input_magnitude + batch = einops.repeat( + batch, "batch n_features -> batch n_instances n_features", n_instances=n_instances + ) + + # Forward pass with target model + target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) + target_cache = target_model.run_with_cache(batch, names_filter=target_cache_filter)[1] + pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} + As = collect_nested_module_attrs(model, attr_name="A", include_attr_name=False) + + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + + relud_masks = calc_masks( + gates=gates, target_component_acts=target_component_acts, attributions=None + )[1] + + # Create figure with better layout and sizing + fig, axs = plt.subplots( + len(relud_masks), + n_instances, + figsize=(5 * n_instances, 5 * len(relud_masks)), + constrained_layout=True, + ) + axs = np.array(axs) + + images = [] + for i in range(n_instances): + axs[0, i].set_title(f"Instance {i}") + for j, (mask_name, mask) in enumerate(relud_masks.items()): + # mask has shape (batch, n_instances, m) + mask_data = mask[:, i, :].detach().cpu().numpy() + im = axs[j, i].matshow(mask_data, aspect="auto", cmap="Reds") + images.append(im) + + axs[j, i].set_xlabel("Mask index") + if i == 0: # Only set ylabel for leftmost plots + axs[j, i].set_ylabel("Input feature index") + axs[j, i].set_title(mask_name) + + # Add unified colorbar + norm = plt.Normalize( + vmin=min(mask.min().item() for mask in relud_masks.values()), + vmax=max(mask.max().item() for mask in relud_masks.values()), + ) + for im in images: + im.set_norm(norm) + fig.colorbar(images[0], ax=axs.ravel().tolist()) + + # Add a title which shows the input magnitude + fig.suptitle(f"Input magnitude: {input_magnitude}") + + return fig + + +# pretrained_model_path = "wandb:spd-train-tms/runs/tmzweoqk" +# run_id = "wandb:spd-tms/runs/7qvf63x8" + + +# target_model, target_model_train_config_dict = TMSModel.from_pretrained(pretrained_model_path) +# spd_model, spd_model_train_config_dict = TMSSPDModel.from_pretrained(run_id) + +# # We used "-" instead of "." as module names can't have "." in them +# gates = {k.removeprefix("gates.").replace("-", "."): v for k, v in spd_model.gates.items()} + +# fig = plot_mask_vals(spd_model, target_model, gates, device="cpu", input_magnitude=0.5) +# fig.savefig("tms_mask_vals.png") diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 6ba2ecf..53e1229 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -18,7 +18,9 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSModelConfig, TMSSPDModel, TMSSPDModelConfig +from spd.experiments.tms.plotting import plot_mask_vals from spd.log import logger +from spd.models.components import Gate from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import DatasetGeneratedDataLoader, SparseFeatureDataset, load_config, set_seed from spd.wandb_utils import init_wandb @@ -45,11 +47,15 @@ def make_plots( out_dir: Path, device: str, config: Config, - masks: dict[str, Float[Tensor, "batch n_instances m"]] | None, + gates: dict[str, Gate], + masks: dict[str, Float[Tensor, "batch n_instances m"]], batch: Float[Tensor, "batch n_instances n_features"], **_, ) -> dict[str, plt.Figure]: plots = {} + plots["masks"] = plot_mask_vals( + model=model, target_model=target_model, gates=gates, device=device, input_magnitude=0.75 + ) return plots From bc9505c43cf1067e0fd04a27f0cde707616d5f48 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 14 Feb 2025 15:25:03 +0000 Subject: [PATCH 14/73] Fix resid_mlp test (sensitive to float precision) --- spd/experiments/resid_mlp/models.py | 2 -- tests/test_resid_mlp.py | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index 3ae37e4..6426174 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -319,8 +319,6 @@ def __init__( spd_kwargs={"m": self.m}, ) ) - # For now, we just define all the gates in this class rather than in the MLP class - # to make it easier to collect all the gates self.gates[f"layers-{i}-mlp_in"] = Gate(m=self.m, n_instances=config.n_instances) self.gates[f"layers-{i}-mlp_out"] = Gate(m=self.m, n_instances=config.n_instances) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index ada410e..ee6125d 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -161,7 +161,7 @@ def test_resid_mlp_equivalent_to_raw_model() -> None: target_model = ResidualMLPModel(config=resid_mlp_config).to(device) - # Create the SPD model with k=1 + # Create the SPD model resid_mlp_spd_config = ResidualMLPSPDConfig(**resid_mlp_config.model_dump(), m=m) spd_model = ResidualMLPSPDModel(config=resid_mlp_spd_config).to(device) @@ -196,11 +196,11 @@ def test_resid_mlp_equivalent_to_raw_model() -> None: input_data, names_filter=target_cache_filter ) # Forward pass with all subnetworks - spd_cache_filter = lambda k: k.endswith((".hook_post", ".hook_component_acts")) + spd_cache_filter = lambda k: k.endswith(".hook_post") out, spd_cache = spd_model.run_with_cache(input_data, names_filter=spd_cache_filter) # Assert outputs are the same - assert torch.allclose(target_out, out, atol=1e-6), "Outputs do not match" + assert torch.allclose(target_out, out, atol=1e-4), "Outputs do not match" # Assert that all post-acts are the same target_post_weight_acts = {k: v for k, v in target_cache.items() if k.endswith(".hook_post")} From 01a03bce8eac8694e86b67db5a97b1f122f6d722 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 14 Feb 2025 16:21:29 +0000 Subject: [PATCH 15/73] Add init_from_target for tms --- spd/configs.py | 1 + spd/experiments/tms/tms_decomposition.py | 19 +++++++++++ tests/test_tms.py | 41 ++++++++++++++++++++++++ 3 files changed, 61 insertions(+) diff --git a/spd/configs.py b/spd/configs.py index a119001..029a028 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -62,6 +62,7 @@ class Config(BaseModel): post_relu_act_recon: bool = False m: PositiveInt n_random_masks: PositiveInt + init_from_target_model: bool = False lr_schedule: Literal["linear", "constant", "cosine", "exponential"] = "constant" lr_exponential_halflife: PositiveFloat | None = None lr_warmup_pct: Probability = 0.0 diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 53e1229..1d85abb 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import Any +import einops import fire import matplotlib.pyplot as plt import torch @@ -75,6 +76,21 @@ def save_target_model_info( wandb.save(str(out_dir / "tms_train_config.yaml"), base_path=out_dir, policy="now") +def init_spd_model_from_target_model(model: TMSSPDModel, target_model: TMSModel, m: int) -> None: + assert target_model.config.n_hidden_layers == 0, "Hidden layers not supported for now" + assert m == target_model.config.n_features, "m must be equal to n_features" + # We set the A to the identity and B to the target weight matrix + model.linear1.A.data[:] = einops.repeat( + torch.eye(m), + "d_in m -> n_instances d_in m", + n_instances=target_model.config.n_instances, + ) + # The B matrix is just the target model's linear layer + model.linear1.B.data[:] = target_model.linear1.weight.data.clone() + model.b_final.data[:] = target_model.b_final.data.clone() + logger.info("Initialized SPD model from target model") + + def main( config_path_or_obj: Path | str | Config, sweep_config_path: Path | str | None = None ) -> None: @@ -124,6 +140,9 @@ def main( ) model = TMSSPDModel(config=tms_spd_model_config) + if config.init_from_target_model: + init_spd_model_from_target_model(model=model, target_model=target_model, m=config.m) + # Manually set the bias for the SPD model from the bias in the pretrained model model.b_final.data[:] = target_model.b_final.data.clone() diff --git a/tests/test_tms.py b/tests/test_tms.py index 3eb593a..b2ac386 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -6,6 +6,7 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSModelConfig, TMSSPDModel, TMSSPDModelConfig +from spd.experiments.tms.tms_decomposition import init_spd_model_from_target_model from spd.experiments.tms.train_tms import TMSTrainConfig, get_model_and_dataloader, train from spd.module_utils import get_nested_module_attr from spd.run_spd import optimize @@ -273,3 +274,43 @@ def test_tms_equivalent_to_raw_model() -> None: assert torch.allclose( target_post_weight_acts[key_name], spd_post_weight_acts[key_name], atol=1e-6 ), f"post-acts do not match at layer {key_name}" + + +def test_init_spd_model_from_target() -> None: + """Test that initializing an SPD model from a target model results in identical outputs.""" + device = "cpu" + set_seed(0) + + # Create target model with no hidden layers (as per current limitation) + tms_config = TMSModelConfig( + n_instances=2, + n_features=3, + n_hidden=2, + n_hidden_layers=0, + device=device, + ) + target_model = TMSModel(config=tms_config).to(device) + + # Create the SPD model with m equal to n_features + tms_spd_config = TMSSPDModelConfig( + **tms_config.model_dump(), + m=tms_config.n_features, # Must match n_features for initialization + bias_val=0.0, + ) + spd_model = TMSSPDModel(config=tms_spd_config).to(device) + + init_spd_model_from_target_model(spd_model, target_model, m=tms_config.n_features) + + # Create a random input + batch_size = 4 + input_data: Float[Tensor, "batch n_instances n_features"] = torch.rand( + batch_size, tms_config.n_instances, tms_config.n_features, device=device + ) + + with torch.inference_mode(): + # Forward pass on both models + target_out = target_model(input_data) + spd_out = spd_model(input_data) + + # Assert outputs are the same + assert torch.allclose(target_out, spd_out), "Outputs after initialization do not match" From 6d6d99fc9b6119b6383925a81a5a01ae88bd3106 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 14 Feb 2025 16:31:07 +0000 Subject: [PATCH 16/73] Support init_from_target for resid_mlp --- .../resid_mlp/resid_mlp_decomposition.py | 51 +++++++++++++++ tests/test_resid_mlp.py | 65 +++++++++++++++++++ tests/test_tms.py | 10 ++- 3 files changed, 123 insertions(+), 3 deletions(-) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index e90f65c..9a55958 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Any +import einops import fire import matplotlib.pyplot as plt import numpy as np @@ -143,6 +144,53 @@ def save_target_model_info( wandb.save(str(out_dir / "label_coeffs.json"), base_path=out_dir, policy="now") +def init_spd_model_from_target_model( + model: ResidualMLPSPDModel, target_model: ResidualMLPModel, m: int +) -> None: + """Initialize SPD model from target model. + + For mlp_in: A = target weights, B = identity + For mlp_out: A = identity, B = target weights + + Args: + model: The SPD model to initialize + target_model: The target model to initialize from + m: The number of components (must equal d_mlp for initialization) + """ + # For ResidualMLP, we need to initialize each layer's mlp_in and mlp_out components + for i in range(target_model.config.n_layers): + # For mlp_in, m must equal d_mlp + assert m == target_model.config.d_mlp, "m must be equal to d_mlp" + + # For mlp_in: A = target weights, B = identity + model.layers[i].mlp_in.A.data[:] = target_model.layers[i].mlp_in.weight.data.clone() + model.layers[i].mlp_in.B.data[:] = einops.repeat( + torch.eye(m), + "m d_out -> n_instances m d_out", + n_instances=target_model.config.n_instances, + ) + + # For mlp_out: A = identity, B = target weights + model.layers[i].mlp_out.A.data[:] = einops.repeat( + torch.eye(m), + "d_in m -> n_instances d_in m", + n_instances=target_model.config.n_instances, + ) + model.layers[i].mlp_out.B.data[:] = target_model.layers[i].mlp_out.weight.data.clone() + + # Copy biases if they exist + if target_model.config.in_bias: + model.layers[i].bias1.data[:] = target_model.layers[i].bias1.data.clone() + if target_model.config.out_bias: + model.layers[i].bias2.data[:] = target_model.layers[i].bias2.data.clone() + + # Copy embedding matrices + model.W_E.data[:] = target_model.W_E.data.clone() + model.W_U.data[:] = target_model.W_U.data.clone() + + logger.info("Initialized SPD model from target model") + + def main( config_path_or_obj: Path | str | Config, sweep_config_path: Path | str | None = None ) -> None: @@ -226,6 +274,9 @@ def main( model.layers[i].bias2.data[:, :] = target_model.layers[i].bias2.data.detach().clone() model.layers[i].bias2.requires_grad = False + if config.init_from_target_model: + init_spd_model_from_target_model(model=model, target_model=target_model, m=config.m) + param_names = [] for i in range(target_model.config.n_layers): param_names.append(f"layers.{i}.mlp_in") diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index ee6125d..56641f2 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -13,6 +13,7 @@ ResidualMLPTaskConfig, ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset +from spd.experiments.resid_mlp.resid_mlp_decomposition import init_spd_model_from_target_model from spd.module_utils import get_nested_module_attr from spd.run_spd import optimize from spd.utils import DatasetGeneratedDataLoader, set_seed @@ -209,3 +210,67 @@ def test_resid_mlp_equivalent_to_raw_model() -> None: assert torch.allclose( target_post_weight_acts[key_name], spd_post_weight_acts[key_name], atol=1e-6 ), f"post-acts do not match at layer {key_name}" + + +def test_init_resid_mlp_spd_model_from_target() -> None: + """Test that initializing an SPD model from a target model results in identical outputs.""" + device = "cpu" + set_seed(0) + + # Create target model + resid_mlp_config = ResidualMLPConfig( + n_instances=2, + n_features=3, + d_embed=4, + d_mlp=5, # This will be our m value + n_layers=2, + act_fn_name="relu", + apply_output_act_fn=False, + in_bias=True, + out_bias=True, + init_scale=1.0, + ) + target_model = ResidualMLPModel(config=resid_mlp_config).to(device) + + # Create the SPD model with m equal to d_mlp + resid_mlp_spd_config = ResidualMLPSPDConfig( + **resid_mlp_config.model_dump(), + m=resid_mlp_config.d_mlp, # Must match d_mlp for initialization + init_type="xavier_normal", + ) + spd_model = ResidualMLPSPDModel(config=resid_mlp_spd_config).to(device) + + init_spd_model_from_target_model(spd_model, target_model, m=resid_mlp_config.d_mlp) + + # Also copy the biases + for i in range(resid_mlp_config.n_layers): + spd_model.layers[i].bias1.data[:, :] = target_model.layers[i].bias1.data + spd_model.layers[i].bias2.data[:, :] = target_model.layers[i].bias2.data + + # Create a random input + batch_size = 4 + input_data: Float[Tensor, "batch n_instances n_features"] = torch.rand( + batch_size, resid_mlp_config.n_instances, resid_mlp_config.n_features, device=device + ) + + with torch.inference_mode(): + # Forward pass on both models + target_out = target_model(input_data) + spd_out = spd_model(input_data) + + # Assert outputs are the same + assert torch.allclose(target_out, spd_out), "Outputs after initialization do not match" + + # Also verify that the component matrices were initialized correctly + for i in range(resid_mlp_config.n_layers): + # Check mlp_in weights + spd_weight = spd_model.layers[i].mlp_in.weight + target_weight = target_model.layers[i].mlp_in.weight + assert torch.allclose(spd_weight, target_weight), f"mlp_in weights don't match at layer {i}" + + # Check mlp_out weights + spd_weight = spd_model.layers[i].mlp_out.weight + target_weight = target_model.layers[i].mlp_out.weight + assert torch.allclose( + spd_weight, target_weight + ), f"mlp_out weights don't match at layer {i}" diff --git a/tests/test_tms.py b/tests/test_tms.py index b2ac386..c0e0cad 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -276,7 +276,7 @@ def test_tms_equivalent_to_raw_model() -> None: ), f"post-acts do not match at layer {key_name}" -def test_init_spd_model_from_target() -> None: +def test_init_tms_spd_model_from_target() -> None: """Test that initializing an SPD model from a target model results in identical outputs.""" device = "cpu" set_seed(0) @@ -300,6 +300,8 @@ def test_init_spd_model_from_target() -> None: spd_model = TMSSPDModel(config=tms_spd_config).to(device) init_spd_model_from_target_model(spd_model, target_model, m=tms_config.n_features) + # Also copy the bias + spd_model.b_final.data[:, :] = target_model.b_final.data # Create a random input batch_size = 4 @@ -308,9 +310,11 @@ def test_init_spd_model_from_target() -> None: ) with torch.inference_mode(): - # Forward pass on both models target_out = target_model(input_data) spd_out = spd_model(input_data) - # Assert outputs are the same + assert torch.allclose( + spd_model.linear1.weight, target_model.linear1.weight + ), "Weights do not match" + assert torch.allclose(target_out, spd_out), "Outputs after initialization do not match" From c303c14d70f75845b1852c94526b19cc4a0f6336 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 14 Feb 2025 16:50:35 +0000 Subject: [PATCH 17/73] Normalise lp sparsity by batch size --- spd/experiments/tms/tms_config.yaml | 5 +++-- spd/run_spd.py | 5 ++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/spd/experiments/tms/tms_config.yaml b/spd/experiments/tms/tms_config.yaml index 9db33e2..197d5a1 100644 --- a/spd/experiments/tms/tms_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -34,9 +34,9 @@ unit_norm_matrices: false seed: 0 m: 40 param_match_coeff: 1.0 -masked_recon_coeff: 10.0 +masked_recon_coeff: 1.0 pnorm: 0.9 -lp_sparsity_coeff: 5e-4 +lp_sparsity_coeff: 7e-1 random_mask_recon_coeff: 1.0 n_random_masks: 1 batch_size: 2048 @@ -47,6 +47,7 @@ save_freq: 20_000 lr: 1e-3 lr_schedule: cosine lr_warmup_pct: 0.05 +init_from_target_model: true task_config: task_name: tms bias_val: 0.0 diff --git a/spd/run_spd.py b/spd/run_spd.py index 887e319..6694c1c 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -109,9 +109,8 @@ def calc_lp_sparsity_loss( for layer_relud_mask in relud_masks.values(): total_loss = total_loss + layer_relud_mask**pnorm - m = next(iter(relud_masks.values())).shape[-1] - # Sum over the batch and m dimensions and normalize by the n_layers * m - return total_loss.sum(dim=(0, -1)) / (len(relud_masks) * m) + # Mean over the batch and m dimension and divide by the number of parameter layers + return total_loss.mean(dim=(0, -1)) / len(relud_masks) def calc_act_recon_mse( From 41bd85b95bb378c17302e98ee99977f034ea12c0 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Sat, 15 Feb 2025 18:29:27 +0000 Subject: [PATCH 18/73] Don't copy biases in init_spd_model_from_target_model --- spd/experiments/resid_mlp/resid_mlp_decomposition.py | 10 ---------- spd/experiments/tms/tms_decomposition.py | 1 - 2 files changed, 11 deletions(-) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 9a55958..dee7d2d 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -178,16 +178,6 @@ def init_spd_model_from_target_model( ) model.layers[i].mlp_out.B.data[:] = target_model.layers[i].mlp_out.weight.data.clone() - # Copy biases if they exist - if target_model.config.in_bias: - model.layers[i].bias1.data[:] = target_model.layers[i].bias1.data.clone() - if target_model.config.out_bias: - model.layers[i].bias2.data[:] = target_model.layers[i].bias2.data.clone() - - # Copy embedding matrices - model.W_E.data[:] = target_model.W_E.data.clone() - model.W_U.data[:] = target_model.W_U.data.clone() - logger.info("Initialized SPD model from target model") diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 1d85abb..382d76a 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -87,7 +87,6 @@ def init_spd_model_from_target_model(model: TMSSPDModel, target_model: TMSModel, ) # The B matrix is just the target model's linear layer model.linear1.B.data[:] = target_model.linear1.weight.data.clone() - model.b_final.data[:] = target_model.b_final.data.clone() logger.info("Initialized SPD model from target model") From befac1d85803440d843e496546835ed395cc1be9 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Sun, 16 Feb 2025 07:36:20 +0000 Subject: [PATCH 19/73] Fix resid_mlp init_from_target test --- tests/test_resid_mlp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index 56641f2..5b2d999 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -242,6 +242,10 @@ def test_init_resid_mlp_spd_model_from_target() -> None: init_spd_model_from_target_model(spd_model, target_model, m=resid_mlp_config.d_mlp) + # Copy the embeddings + spd_model.W_E.data[:, :, :] = target_model.W_E.data + spd_model.W_U.data[:, :, :] = target_model.W_U.data + # Also copy the biases for i in range(resid_mlp_config.n_layers): spd_model.layers[i].bias1.data[:, :] = target_model.layers[i].bias1.data From e7e60a715581531a1c3029614a01128877527d2d Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 20 Feb 2025 22:10:09 +0000 Subject: [PATCH 20/73] Add randrecon to run label --- spd/run_spd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spd/run_spd.py b/spd/run_spd.py index 6694c1c..6113792 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -27,6 +27,8 @@ def get_common_run_name_suffix(config: Config) -> str: run_suffix += f"maskrecon{config.masked_recon_coeff:.2e}_" if config.act_recon_coeff is not None: run_suffix += f"actrecon_{config.act_recon_coeff:.2e}_" + if config.random_mask_recon_coeff is not None: + run_suffix += f"randrecon{config.random_mask_recon_coeff:.2e}_" run_suffix += f"p{config.pnorm:.2e}_" run_suffix += f"lpsp{config.lp_sparsity_coeff:.2e}_" run_suffix += f"m{config.m}_" From 3845ca3c1cde394612042a38029d1d815f4da1a1 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 24 Feb 2025 20:01:18 +0000 Subject: [PATCH 21/73] Permute to identity for plotting mask_vals --- .../resid_mlp/resid_mlp_config.yaml | 13 ++-- .../resid_mlp/resid_mlp_decomposition.py | 12 +++- spd/experiments/tms/plotting.py | 62 +++++++++++++++---- spd/experiments/tms/tms_config.yaml | 8 +-- 4 files changed, 69 insertions(+), 26 deletions(-) diff --git a/spd/experiments/resid_mlp/resid_mlp_config.yaml b/spd/experiments/resid_mlp/resid_mlp_config.yaml index 2d80d26..d68536d 100644 --- a/spd/experiments/resid_mlp/resid_mlp_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp_config.yaml @@ -4,15 +4,15 @@ wandb_run_name: null wandb_run_name_prefix: "" unit_norm_matrices: true seed: 0 -m: 50 +m: 200 param_match_coeff: 1.0 masked_recon_coeff: 1.0 -act_recon_coeff: 1.0 +act_recon_coeff: 1 post_relu_act_recon: true random_mask_recon_coeff: 1.0 -n_random_masks: 2 +n_random_masks: 1 pnorm: 0.9 -lp_sparsity_coeff: 1.0 +lp_sparsity_coeff: 1e-2 batch_size: 256 steps: 10_000 image_freq: 5_000 @@ -21,7 +21,8 @@ save_freq: 10_000 lr: 1e-3 lr_schedule: cosine lr_warmup_pct: 0.01 -image_on_first_step: false +image_on_first_step: true +init_from_target_model: false task_config: task_name: residual_mlp init_scale: 2.0 @@ -36,7 +37,7 @@ task_config: # wandb_run_name_prefix: "" # unit_norm_matrices: false # seed: 0 -# m: 25 +# m: 100 # param_match_coeff: 1.0 # masked_recon_coeff: 2.0 # act_recon_coeff: 1.0 diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index dee7d2d..41a589f 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -2,7 +2,6 @@ import json from datetime import datetime -from functools import partial from pathlib import Path from typing import Any @@ -24,7 +23,9 @@ ResidualMLPSPDModel, ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset +from spd.experiments.tms.plotting import plot_mask_vals from spd.log import logger +from spd.models.components import Gate from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( DatasetGeneratedDataLoader, @@ -108,12 +109,17 @@ def resid_mlp_plot_results_fn( out_dir: Path | None, device: str, config: Config, + gates: dict[str, Gate], masks: dict[str, Float[Tensor, "batch_size m"]] | None, **_, ) -> dict[str, plt.Figure]: assert isinstance(config.task_config, ResidualMLPTaskConfig) fig_dict = {} + fig_dict["masks"] = plot_mask_vals( + model=model, target_model=target_model, gates=gates, device=device, input_magnitude=0.75 + ) + # Save plots to files if out_dir: for k, v in fig_dict.items(): @@ -160,6 +166,7 @@ def init_spd_model_from_target_model( # For ResidualMLP, we need to initialize each layer's mlp_in and mlp_out components for i in range(target_model.config.n_layers): # For mlp_in, m must equal d_mlp + # TODO: This is broken, we shouldn't need m=d_mlp for this function. assert m == target_model.config.d_mlp, "m must be equal to d_mlp" # For mlp_in: A = target weights, B = identity @@ -289,7 +296,6 @@ def main( dataloader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False) - plot_results_fn = partial(resid_mlp_plot_results_fn, dataloader=dataloader) optimize( model=model, config=config, @@ -298,7 +304,7 @@ def main( target_model=target_model, param_names=param_names, out_dir=out_dir, - plot_results_fn=plot_results_fn, + plot_results_fn=resid_mlp_plot_results_fn, ) if config.wandb_project: diff --git a/spd/experiments/tms/plotting.py b/spd/experiments/tms/plotting.py index 8041701..833f6c0 100644 --- a/spd/experiments/tms/plotting.py +++ b/spd/experiments/tms/plotting.py @@ -2,16 +2,44 @@ import matplotlib.pyplot as plt import numpy as np import torch +from jaxtyping import Float +from torch import Tensor from spd.experiments.tms.models import TMSModel, TMSSPDModel +from spd.hooks import HookedRootModule +from spd.models.base import SPDModel from spd.models.components import Gate from spd.module_utils import collect_nested_module_attrs from spd.run_spd import calc_component_acts, calc_masks +def permute_to_identity( + mask: Float[Tensor, "batch n_instances m"], +) -> Float[Tensor, "batch n_instances m"]: + batch, n_instances, m = mask.shape + new_mask = mask.clone() + effective_rows: int = min(batch, m) + for inst in range(n_instances): + mat: Tensor = mask[:, inst, :] + perm: list[int] = [0] * m + used: set[int] = set() + for i in range(effective_rows): + sorted_indices: list[int] = torch.argsort(mat[i, :], descending=True).tolist() + chosen: int = next( + (col for col in sorted_indices if col not in used), sorted_indices[0] + ) + perm[i] = chosen + used.add(chosen) + remaining: list[int] = sorted(list(set(range(m)) - used)) + for idx, col in enumerate(remaining): + perm[effective_rows + idx] = col + new_mask[:, inst, :] = mat[:, perm] + return new_mask + + def plot_mask_vals( - model: TMSSPDModel, - target_model: TMSModel, + model: SPDModel, + target_model: HookedRootModule, gates: dict[str, Gate], device: str, input_magnitude: float, @@ -37,12 +65,16 @@ def plot_mask_vals( gates=gates, target_component_acts=target_component_acts, attributions=None )[1] + # Permute columns so that in each instance the maximum per row ends up on the diagonal. + relud_masks = {k: permute_to_identity(mask=v) for k, v in relud_masks.items()} + # Create figure with better layout and sizing fig, axs = plt.subplots( len(relud_masks), n_instances, figsize=(5 * n_instances, 5 * len(relud_masks)), constrained_layout=True, + squeeze=False, ) axs = np.array(axs) @@ -75,15 +107,19 @@ def plot_mask_vals( return fig -# pretrained_model_path = "wandb:spd-train-tms/runs/tmzweoqk" +pretrained_model_path = "wandb:spd-train-tms/runs/tmzweoqk" # run_id = "wandb:spd-tms/runs/7qvf63x8" - - -# target_model, target_model_train_config_dict = TMSModel.from_pretrained(pretrained_model_path) -# spd_model, spd_model_train_config_dict = TMSSPDModel.from_pretrained(run_id) - -# # We used "-" instead of "." as module names can't have "." in them -# gates = {k.removeprefix("gates.").replace("-", "."): v for k, v in spd_model.gates.items()} - -# fig = plot_mask_vals(spd_model, target_model, gates, device="cpu", input_magnitude=0.5) -# fig.savefig("tms_mask_vals.png") +# run_id = "wandb:spd-tms/runs/fj68gebo" + +# run_id = "wandb:spd-tms/runs/eafxol4e" +# run_id = "wandb:spd-tms/runs/hr4jv78k" +run_id = "wandb:spd-tms/runs/fj68gebo" +target_model, target_model_train_config_dict = TMSModel.from_pretrained(pretrained_model_path) +spd_model, spd_model_train_config_dict = TMSSPDModel.from_pretrained(run_id) + +# We used "-" instead of "." as module names can't have "." in them +gates = {k.removeprefix("gates.").replace("-", "."): v for k, v in spd_model.gates.items()} +input_magnitude = 0.75 +fig = plot_mask_vals(spd_model, target_model, gates, device="cpu", input_magnitude=input_magnitude) # type: ignore +fig.savefig(f"tms_mask_vals_{input_magnitude}.png") +print(f"Saved figure to tms_mask_vals_{input_magnitude}.png") diff --git a/spd/experiments/tms/tms_config.yaml b/spd/experiments/tms/tms_config.yaml index 197d5a1..3524f2c 100644 --- a/spd/experiments/tms/tms_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -36,15 +36,15 @@ m: 40 param_match_coeff: 1.0 masked_recon_coeff: 1.0 pnorm: 0.9 -lp_sparsity_coeff: 7e-1 -random_mask_recon_coeff: 1.0 +lp_sparsity_coeff: 1e-1 +random_mask_recon_coeff: 1 n_random_masks: 1 batch_size: 2048 -steps: 20_000 +steps: 30_000 image_freq: 5_000 print_freq: 1000 save_freq: 20_000 -lr: 1e-3 +lr: 1e-4 lr_schedule: cosine lr_warmup_pct: 0.05 init_from_target_model: true From 3bb654ca28063d3dc99719d4a20bc80e95fbef10 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Feb 2025 12:17:53 +0000 Subject: [PATCH 22/73] Remove post_relu_act_recon config arg --- spd/configs.py | 1 - spd/experiments/resid_mlp/models.py | 7 ------- spd/experiments/resid_mlp/resid_mlp_config.yaml | 4 +--- tests/test_resid_mlp.py | 1 - 4 files changed, 1 insertion(+), 12 deletions(-) diff --git a/spd/configs.py b/spd/configs.py index 029a028..39c41fb 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -59,7 +59,6 @@ class Config(BaseModel): random_mask_recon_coeff: NonNegativeFloat | None = None lp_sparsity_coeff: NonNegativeFloat pnorm: PositiveFloat - post_relu_act_recon: bool = False m: PositiveInt n_random_masks: PositiveInt init_from_target_model: bool = False diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index 6426174..b63cdd5 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -410,13 +410,6 @@ def from_pretrained( with open(paths.final_config) as f: final_config_dict = yaml.safe_load(f) - # Old configs didn't have post_relu_act_recon - if ( - "post_relu_act_recon" not in final_config_dict - and final_config_dict["act_recon_coeff"] is not None - ): - final_config_dict["post_relu_act_recon"] = True - config = Config(**final_config_dict) with open(paths.resid_mlp_train_config) as f: diff --git a/spd/experiments/resid_mlp/resid_mlp_config.yaml b/spd/experiments/resid_mlp/resid_mlp_config.yaml index d68536d..0d61f45 100644 --- a/spd/experiments/resid_mlp/resid_mlp_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp_config.yaml @@ -2,13 +2,12 @@ wandb_project: spd-resid-mlp wandb_run_name: null wandb_run_name_prefix: "" -unit_norm_matrices: true +unit_norm_matrices: false seed: 0 m: 200 param_match_coeff: 1.0 masked_recon_coeff: 1.0 act_recon_coeff: 1 -post_relu_act_recon: true random_mask_recon_coeff: 1.0 n_random_masks: 1 pnorm: 0.9 @@ -41,7 +40,6 @@ task_config: # param_match_coeff: 1.0 # masked_recon_coeff: 2.0 # act_recon_coeff: 1.0 -# post_relu_act_recon: true # random_mask_recon_coeff: 1.0 # n_random_masks: 2 # pnorm: 0.9 diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index 5b2d999..0708e1d 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -53,7 +53,6 @@ def test_resid_mlp_decomposition_happy_path() -> None: param_match_coeff=1.0, masked_recon_coeff=1, act_recon_coeff=1, - post_relu_act_recon=True, lp_sparsity_coeff=1.0, pnorm=0.9, attribution_type="gradient", From ebee91129b5d5d3e46518539324e4a003919f08b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Feb 2025 12:28:19 +0000 Subject: [PATCH 23/73] Remove code from global scope in plotting --- spd/experiments/tms/models.py | 2 ++ spd/experiments/tms/plotting.py | 35 ++++++++++++++++++--------------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 1ab48f4..3a008c0 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -289,6 +289,8 @@ def from_pretrained(cls, path: ModelPath) -> tuple["TMSSPDModel", Config]: with open(paths.final_config) as f: final_config_dict = yaml.safe_load(f) + final_config_dict.pop("post_act_recon_coeff", None) + spd_config = Config(**final_config_dict) with open(paths.tms_train_config) as f: diff --git a/spd/experiments/tms/plotting.py b/spd/experiments/tms/plotting.py index 833f6c0..0719cc1 100644 --- a/spd/experiments/tms/plotting.py +++ b/spd/experiments/tms/plotting.py @@ -107,19 +107,22 @@ def plot_mask_vals( return fig -pretrained_model_path = "wandb:spd-train-tms/runs/tmzweoqk" -# run_id = "wandb:spd-tms/runs/7qvf63x8" -# run_id = "wandb:spd-tms/runs/fj68gebo" - -# run_id = "wandb:spd-tms/runs/eafxol4e" -# run_id = "wandb:spd-tms/runs/hr4jv78k" -run_id = "wandb:spd-tms/runs/fj68gebo" -target_model, target_model_train_config_dict = TMSModel.from_pretrained(pretrained_model_path) -spd_model, spd_model_train_config_dict = TMSSPDModel.from_pretrained(run_id) - -# We used "-" instead of "." as module names can't have "." in them -gates = {k.removeprefix("gates.").replace("-", "."): v for k, v in spd_model.gates.items()} -input_magnitude = 0.75 -fig = plot_mask_vals(spd_model, target_model, gates, device="cpu", input_magnitude=input_magnitude) # type: ignore -fig.savefig(f"tms_mask_vals_{input_magnitude}.png") -print(f"Saved figure to tms_mask_vals_{input_magnitude}.png") +if __name__ == "__main__": + pretrained_model_path = "wandb:spd-train-tms/runs/tmzweoqk" + run_id = "wandb:spd-tms/runs/fj68gebo" + target_model, target_model_train_config_dict = TMSModel.from_pretrained(pretrained_model_path) + spd_model, spd_model_train_config_dict = TMSSPDModel.from_pretrained(run_id) + + # We used "-" instead of "." as module names can't have "." in them + gates = {k.removeprefix("gates.").replace("-", "."): v for k, v in spd_model.gates.items()} + + input_magnitude = 0.75 + fig = plot_mask_vals( + spd_model, + target_model, + gates, # type: ignore + device="cpu", + input_magnitude=input_magnitude, + ) + fig.savefig(f"tms_mask_vals_{input_magnitude}.png") + print(f"Saved figure to tms_mask_vals_{input_magnitude}.png") From 0b3f61d4938262b186343d1af81d4b4825ce231d Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 27 Feb 2025 15:25:07 +0000 Subject: [PATCH 24/73] Handle deprecated 'post_relu_act_recon' arg. --- spd/experiments/resid_mlp/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index b63cdd5..54a4ca3 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -410,6 +410,7 @@ def from_pretrained( with open(paths.final_config) as f: final_config_dict = yaml.safe_load(f) + final_config_dict.pop("post_relu_act_recon", None) config = Config(**final_config_dict) with open(paths.resid_mlp_train_config) as f: From 931b6f3fb19e501510d401648af0ae458ae40ebe Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 3 Mar 2025 13:33:38 +0000 Subject: [PATCH 25/73] Use mps if available --- spd/experiments/resid_mlp/resid_mlp_decomposition.py | 3 ++- spd/experiments/tms/tms_decomposition.py | 10 ++++++++-- spd/utils.py | 9 +++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 41a589f..2ff6c23 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -29,6 +29,7 @@ from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( DatasetGeneratedDataLoader, + get_device, load_config, set_seed, ) @@ -199,7 +200,7 @@ def main( set_seed(config.seed) logger.info(config) - device = "cuda" if torch.cuda.is_available() else "cpu" + device = get_device() print(f"Using device: {device}") assert isinstance(config.task_config, ResidualMLPTaskConfig) diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 382d76a..97f65c7 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -23,7 +23,13 @@ from spd.log import logger from spd.models.components import Gate from spd.run_spd import get_common_run_name_suffix, optimize -from spd.utils import DatasetGeneratedDataLoader, SparseFeatureDataset, load_config, set_seed +from spd.utils import ( + DatasetGeneratedDataLoader, + SparseFeatureDataset, + get_device, + load_config, + set_seed, +) from spd.wandb_utils import init_wandb wandb.require("core") @@ -93,7 +99,7 @@ def init_spd_model_from_target_model(model: TMSSPDModel, target_model: TMSModel, def main( config_path_or_obj: Path | str | Config, sweep_config_path: Path | str | None = None ) -> None: - device = "cuda" if torch.cuda.is_available() else "cpu" + device = get_device() config = load_config(config_path_or_obj, config_model=Config) diff --git a/spd/utils.py b/spd/utils.py index 4a48d33..436f627 100644 --- a/spd/utils.py +++ b/spd/utils.py @@ -33,6 +33,15 @@ ] +def get_device() -> str: + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + return "mps" + else: + return "cpu" + + def set_seed(seed: int | None) -> None: """Set the random seed for random, PyTorch and NumPy""" if seed is not None: From 19d7181268053ad95015e50212bf0fa9066e9b57 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 3 Mar 2025 13:55:49 +0000 Subject: [PATCH 26/73] Avoid mps as it breaks tms --- spd/utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/spd/utils.py b/spd/utils.py index 436f627..a140930 100644 --- a/spd/utils.py +++ b/spd/utils.py @@ -34,12 +34,8 @@ def get_device() -> str: - if torch.cuda.is_available(): - return "cuda" - elif torch.backends.mps.is_available(): - return "mps" - else: - return "cpu" + # NOTE: MPS returns NaNs on TMS when run. Avoiding for now. + return "cuda" if torch.cuda.is_available() else "cpu" def set_seed(seed: int | None) -> None: From 8560f1bcde33e0c971b0008791edb3c04f620d20 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 3 Mar 2025 15:30:25 +0000 Subject: [PATCH 27/73] Untie gates in TMS --- spd/experiments/tms/models.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 3a008c0..e88bf73 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -216,12 +216,10 @@ def __init__(self, config: TMSSPDModelConfig): ] ) - # Same gate for linear1 and linear2 since the weights are tied - gate = Gate(m=self.m, n_instances=config.n_instances) self.gates = nn.ModuleDict( { - "linear1": gate, - "linear2": gate, + "linear1": Gate(m=self.m, n_instances=config.n_instances), + "linear2": Gate(m=self.m, n_instances=config.n_instances), **{ f"hidden_layers-{i}": Gate(m=self.m, n_instances=config.n_instances) for i in range(config.n_hidden_layers) From 79391e9b518332f5c5504d675b93eed8ab3e2922 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 4 Mar 2025 16:45:07 +0000 Subject: [PATCH 28/73] Allow for detached inputs to gates and use target_out in random_mask_recon --- spd/run_spd.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index 6113792..6b3bf27 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -25,6 +25,7 @@ def get_common_run_name_suffix(config: Config) -> str: run_suffix = "" if config.masked_recon_coeff is not None: run_suffix += f"maskrecon{config.masked_recon_coeff:.2e}_" + run_suffix += f"nrandmasks{config.n_random_masks}_" if config.act_recon_coeff is not None: run_suffix += f"actrecon_{config.act_recon_coeff:.2e}_" if config.random_mask_recon_coeff is not None: @@ -144,6 +145,7 @@ def calc_masks( ], attributions: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]] | None = None, + detach_inputs: bool = False, ) -> tuple[ dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], @@ -156,15 +158,18 @@ def calc_masks( gates: The gates to use for the mask. component_acts: The activations after each subnetwork in the SPD model. attributions: The attributions to use for the mask. - + detach_inputs: Whether to detach the inputs to the gates. Returns: Dictionary of masks for each layer. """ masks = {} relud_masks = {} for layer_name in gates: - masks[layer_name] = gates[layer_name].forward(target_component_acts[layer_name]) - relud_masks[layer_name] = gates[layer_name].forward_relu(target_component_acts[layer_name]) + gate_input = target_component_acts[layer_name] + if detach_inputs: + gate_input = gate_input.detach() + masks[layer_name] = gates[layer_name].forward(gate_input) + relud_masks[layer_name] = gates[layer_name].forward_relu(gate_input) return masks, relud_masks @@ -332,7 +337,10 @@ def optimize( attributions = None masks, relud_masks = calc_masks( - gates=gates, target_component_acts=target_component_acts, attributions=attributions + gates=gates, + target_component_acts=target_component_acts, + attributions=attributions, + detach_inputs=False, ) # Masked forward pass @@ -345,7 +353,7 @@ def optimize( if config.random_mask_recon_coeff is not None: random_masks = calc_random_masks(masks=masks, n_random_masks=config.n_random_masks) random_masks_loss = calc_random_masks_mse_loss( - model=model, batch=batch, random_masks=random_masks, out_masked=out_masked + model=model, batch=batch, random_masks=random_masks, out_masked=target_out ) # Calculate losses From cd23609251b5a3115522fae568bc1353ac0a9e9d Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 5 Mar 2025 16:06:12 +0000 Subject: [PATCH 29/73] Add GateMLP --- spd/configs.py | 1 + spd/experiments/resid_mlp/models.py | 18 +++-- .../resid_mlp/resid_mlp_decomposition.py | 9 ++- spd/experiments/tms/models.py | 15 ++++- spd/experiments/tms/tms_decomposition.py | 1 + spd/models/components.py | 65 ++++++++++++++++++- 6 files changed, 98 insertions(+), 11 deletions(-) diff --git a/spd/configs.py b/spd/configs.py index 39c41fb..20d8249 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -68,6 +68,7 @@ class Config(BaseModel): sparsity_loss_type: Literal["jacobian"] = "jacobian" unit_norm_matrices: bool = False attribution_type: Literal["gradient"] = "gradient" + n_gate_hidden_neurons: PositiveInt | None = None task_config: TMSTaskConfig | ResidualMLPTaskConfig = Field(..., discriminator="task_name") DEPRECATED_CONFIG_KEYS: ClassVar[list[str]] = [] diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index 54a4ca3..128cf32 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -17,7 +17,7 @@ from spd.hooks import HookedRootModule from spd.log import logger from spd.models.base import SPDModel -from spd.models.components import Gate, Linear, LinearComponent +from spd.models.components import Gate, GateMLP, Linear, LinearComponent from spd.module_utils import init_param_ from spd.run_spd import Config from spd.types import WANDB_PATH_PREFIX, ModelPath @@ -281,6 +281,7 @@ class ResidualMLPSPDConfig(BaseModel): out_bias: bool init_scale: float m: PositiveInt + n_gate_hidden_neurons: PositiveInt | None = None init_type: Literal["kaiming_uniform", "xavier_normal"] = "xavier_normal" @@ -304,6 +305,13 @@ def __init__( init_param_(self.W_U, init_type=config.init_type) self.layers = nn.ModuleList() + + # Use GateMLP if n_gate_hidden_neurons is provided, otherwise use Gate + gate_class = GateMLP if config.n_gate_hidden_neurons is not None else Gate + gate_kwargs = {"m": self.m, "n_instances": config.n_instances} + if config.n_gate_hidden_neurons is not None: + gate_kwargs["n_gate_hidden_neurons"] = config.n_gate_hidden_neurons + self.gates = nn.ModuleDict() for i in range(config.n_layers): self.layers.append( @@ -319,8 +327,8 @@ def __init__( spd_kwargs={"m": self.m}, ) ) - self.gates[f"layers-{i}-mlp_in"] = Gate(m=self.m, n_instances=config.n_instances) - self.gates[f"layers-{i}-mlp_out"] = Gate(m=self.m, n_instances=config.n_instances) + self.gates[f"layers-{i}-mlp_in"] = gate_class(**gate_kwargs) + self.gates[f"layers-{i}-mlp_out"] = gate_class(**gate_kwargs) self.setup() @@ -421,7 +429,9 @@ def from_pretrained( assert isinstance(config.task_config, ResidualMLPTaskConfig) resid_mlp_spd_config = ResidualMLPSPDConfig( - **resid_mlp_train_config_dict["resid_mlp_config"], m=config.m + **resid_mlp_train_config_dict["resid_mlp_config"], + m=config.m, + n_gate_hidden_neurons=config.n_gate_hidden_neurons, ) model = cls(config=resid_mlp_spd_config) params = torch.load(paths.checkpoint, weights_only=True, map_location="cpu") diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 2ff6c23..043487f 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -168,7 +168,9 @@ def init_spd_model_from_target_model( for i in range(target_model.config.n_layers): # For mlp_in, m must equal d_mlp # TODO: This is broken, we shouldn't need m=d_mlp for this function. - assert m == target_model.config.d_mlp, "m must be equal to d_mlp" + assert ( + m == target_model.config.d_mlp or m == target_model.config.d_embed + ), "m must be equal to d_mlp or d_embed" # For mlp_in: A = target weights, B = identity model.layers[i].mlp_in.A.data[:] = target_model.layers[i].mlp_in.weight.data.clone() @@ -253,10 +255,10 @@ def main( out_bias=target_model.config.out_bias, init_scale=config.task_config.init_scale, m=config.m, + n_gate_hidden_neurons=config.n_gate_hidden_neurons, ) - model = ResidualMLPSPDModel(config=model_config).to(device) + model = ResidualMLPSPDModel(config=model_config) - model = model.to(device) # Use the target_model's embedding matrix and don't train it further model.W_E.data[:, :] = target_model.W_E.data.detach().clone() model.W_E.requires_grad = False @@ -275,6 +277,7 @@ def main( if config.init_from_target_model: init_spd_model_from_target_model(model=model, target_model=target_model, m=config.m) + model.to(device) param_names = [] for i in range(target_model.config.n_layers): param_names.append(f"layers.{i}.mlp_in") diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index e88bf73..d3b6d2c 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -15,6 +15,7 @@ from spd.models.base import SPDModel from spd.models.components import ( Gate, + GateMLP, Linear, LinearComponent, TransposedLinear, @@ -174,6 +175,7 @@ class TMSSPDModelConfig(BaseModel): bias_val: float device: str m: PositiveInt + n_gate_hidden_neurons: PositiveInt | None = None class TMSSPDModel(SPDModel): @@ -216,12 +218,18 @@ def __init__(self, config: TMSSPDModelConfig): ] ) + # Use GateMLP if n_gate_hidden_neurons is provided, otherwise use Gate + gate_class = GateMLP if config.n_gate_hidden_neurons else Gate + gate_kwargs = {"m": self.m, "n_instances": config.n_instances} + if config.n_gate_hidden_neurons: + gate_kwargs["n_gate_hidden_neurons"] = config.n_gate_hidden_neurons + self.gates = nn.ModuleDict( { - "linear1": Gate(m=self.m, n_instances=config.n_instances), - "linear2": Gate(m=self.m, n_instances=config.n_instances), + "linear1": gate_class(**gate_kwargs), + "linear2": gate_class(**gate_kwargs), **{ - f"hidden_layers-{i}": Gate(m=self.m, n_instances=config.n_instances) + f"hidden_layers-{i}": gate_class(**gate_kwargs) for i in range(config.n_hidden_layers) }, } @@ -299,6 +307,7 @@ def from_pretrained(cls, path: ModelPath) -> tuple["TMSSPDModel", Config]: **tms_train_config_dict["tms_model_config"], m=spd_config.m, bias_val=spd_config.task_config.bias_val, + n_gate_hidden_neurons=spd_config.n_gate_hidden_neurons, ) model = cls(config=tms_spd_config) params = torch.load(paths.checkpoint, weights_only=True, map_location="cpu") diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 97f65c7..867133d 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -142,6 +142,7 @@ def main( **target_model.config.model_dump(mode="json"), m=config.m, bias_val=task_config.bias_val, + n_gate_hidden_neurons=config.n_gate_hidden_neurons, ) model = TMSSPDModel(config=tms_spd_model_config) diff --git a/spd/models/components.py b/spd/models/components.py index e68463c..31349bb 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -4,13 +4,14 @@ import torch from jaxtyping import Float from torch import Tensor, nn +from torch.nn import functional as F from spd.hooks import HookPoint from spd.module_utils import init_param_ def hard_sigmoid(x: Tensor) -> Tensor: - return torch.nn.functional.relu(torch.clamp(x, max=1)) + return F.relu(torch.clamp(x, max=1)) class Gate(nn.Module): @@ -35,6 +36,68 @@ def forward_relu( return (x * self.weight + self.bias).relu() +class GateMLP(nn.Module): + """A gate with a hidden layer that maps a single input to a single output.""" + + def __init__(self, m: int, n_gate_hidden_neurons: int, n_instances: int | None = None): + super().__init__() + self.n_instances = n_instances + self.n_gate_hidden_neurons = n_gate_hidden_neurons + + # Define weight shapes based on instances + shape = ( + (n_instances, m, n_gate_hidden_neurons) + if n_instances is not None + else (m, n_gate_hidden_neurons) + ) + in_bias_shape = ( + (n_instances, m, n_gate_hidden_neurons) + if n_instances is not None + else (m, n_gate_hidden_neurons) + ) + out_bias_shape = (n_instances, m) if n_instances is not None else (m,) + + self.mlp_in = nn.Parameter(torch.empty(shape)) + self.in_bias = nn.Parameter(torch.zeros(in_bias_shape)) + self.mlp_out = nn.Parameter(torch.empty(shape)) + self.out_bias = nn.Parameter(torch.zeros(out_bias_shape)) + + torch.nn.init.normal_(self.mlp_in, mean=0.0, std=0.2) + torch.nn.init.normal_(self.mlp_out, mean=0.0, std=0.2) + + def _compute_pre_activation( + self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] + ) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: + """Compute the output before applying the final activation function.""" + # First layer with gelu activation + hidden = einops.einsum( + x, + self.mlp_in, + "batch ... m, ... m n_gate_hidden_neurons -> batch ... m n_gate_hidden_neurons", + ) + hidden = hidden + self.in_bias + hidden = F.gelu(hidden) + + # Second layer + out = einops.einsum( + hidden, + self.mlp_out, + "batch ... m n_gate_hidden_neurons, ... m n_gate_hidden_neurons -> batch ... m", + ) + out = out + self.out_bias + return out + + def forward( + self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] + ) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: + return hard_sigmoid(self._compute_pre_activation(x)) + + def forward_relu( + self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] + ) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: + return F.relu(self._compute_pre_activation(x)) + + class Linear(nn.Module): """A linear transformation with an optional n_instances dimension.""" From 96939c20baf9e16ee00460807541c58722b0d59b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 11:05:46 +0000 Subject: [PATCH 30/73] Remove bias_val and train_bias config args --- spd/configs.py | 2 -- spd/experiments/tms/models.py | 7 +------ spd/experiments/tms/tms_config.yaml | 15 ++++++--------- spd/experiments/tms/tms_decomposition.py | 5 +---- tests/test_tms.py | 12 ++---------- 5 files changed, 10 insertions(+), 31 deletions(-) diff --git a/spd/configs.py b/spd/configs.py index 20d8249..a931445 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -20,8 +20,6 @@ class TMSTaskConfig(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) task_name: Literal["tms"] = "tms" feature_probability: Probability - train_bias: bool - bias_val: float data_generation_type: Literal["exactly_one_active", "at_least_zero_active"] = ( "at_least_zero_active" ) diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index d3b6d2c..60cb711 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -172,7 +172,6 @@ class TMSSPDModelConfig(BaseModel): n_features: PositiveInt n_hidden: PositiveInt n_hidden_layers: NonNegativeInt - bias_val: float device: str m: PositiveInt n_gate_hidden_neurons: PositiveInt | None = None @@ -184,7 +183,6 @@ def __init__(self, config: TMSSPDModelConfig): self.config = config self.n_instances = config.n_instances # Required for backwards compatibility self.n_features = config.n_features # Required for backwards compatibility - self.bias_val = config.bias_val self.m = config.m self.linear1 = LinearComponent( @@ -196,11 +194,9 @@ def __init__(self, config: TMSSPDModelConfig): m=self.m, ) self.linear2 = TransposedLinearComponent(self.linear1.A, self.linear1.B) - bias_data = ( + self.b_final = nn.Parameter( torch.zeros((config.n_instances, config.n_features), device=config.device) - + config.bias_val ) - self.b_final = nn.Parameter(bias_data) self.hidden_layers = None if config.n_hidden_layers > 0: @@ -306,7 +302,6 @@ def from_pretrained(cls, path: ModelPath) -> tuple["TMSSPDModel", Config]: tms_spd_config = TMSSPDModelConfig( **tms_train_config_dict["tms_model_config"], m=spd_config.m, - bias_val=spd_config.task_config.bias_val, n_gate_hidden_neurons=spd_config.n_gate_hidden_neurons, ) model = cls(config=tms_spd_config) diff --git a/spd/experiments/tms/tms_config.yaml b/spd/experiments/tms/tms_config.yaml index 3524f2c..e52d2f9 100644 --- a/spd/experiments/tms/tms_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -20,8 +20,6 @@ # lr_warmup_pct: 0.05 # task_config: # task_name: tms -# bias_val: 0.0 -# train_bias: false # feature_probability: 0.05 # data_generation_type: "at_least_zero_active" # pretrained_model_path: "wandb:spd-train-tms/runs/cv3g3z9d" # Local or wandb path @@ -32,26 +30,25 @@ wandb_run_name: null wandb_run_name_prefix: "" unit_norm_matrices: false seed: 0 -m: 40 +m: 100 param_match_coeff: 1.0 masked_recon_coeff: 1.0 pnorm: 0.9 -lp_sparsity_coeff: 1e-1 +lp_sparsity_coeff: 3e-2 random_mask_recon_coeff: 1 n_random_masks: 1 +n_gate_hidden_neurons: 8 batch_size: 2048 steps: 30_000 image_freq: 5_000 print_freq: 1000 save_freq: 20_000 -lr: 1e-4 -lr_schedule: cosine +lr: 1e-3 +lr_schedule: constant lr_warmup_pct: 0.05 -init_from_target_model: true +init_from_target_model: false task_config: task_name: tms - bias_val: 0.0 - train_bias: false feature_probability: 0.05 data_generation_type: "at_least_zero_active" pretrained_model_path: "wandb:spd-train-tms/runs/tmzweoqk" \ No newline at end of file diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 867133d..86b1e24 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -141,7 +141,6 @@ def main( tms_spd_model_config = TMSSPDModelConfig( **target_model.config.model_dump(mode="json"), m=config.m, - bias_val=task_config.bias_val, n_gate_hidden_neurons=config.n_gate_hidden_neurons, ) model = TMSSPDModel(config=tms_spd_model_config) @@ -151,9 +150,7 @@ def main( # Manually set the bias for the SPD model from the bias in the pretrained model model.b_final.data[:] = target_model.b_final.data.clone() - - if not task_config.train_bias: - model.b_final.requires_grad = False + model.b_final.requires_grad = False param_names = ["linear1", "linear2"] if model.hidden_layers is not None: diff --git a/tests/test_tms.py b/tests/test_tms.py index c0e0cad..8c1eec2 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -16,8 +16,6 @@ TMS_TASK_CONFIG = TMSTaskConfig( task_name="tms", feature_probability=0.5, - train_bias=False, - bias_val=0.0, pretrained_model_path=Path(""), # We'll create this later ) @@ -37,17 +35,13 @@ def tms_spd_happy_path(config: Config, n_hidden_layers: int = 0): ) target_model = TMSModel(config=tms_model_config) - tms_spd_model_config = TMSSPDModelConfig( - **tms_model_config.model_dump(mode="json"), m=config.m, bias_val=config.task_config.bias_val - ) + tms_spd_model_config = TMSSPDModelConfig(**tms_model_config.model_dump(mode="json"), m=config.m) model = TMSSPDModel(config=tms_spd_model_config) # Randomly initialize the bias for the pretrained model target_model.b_final.data = torch.randn_like(target_model.b_final.data) # Manually set the bias for the SPD model from the bias in the pretrained model model.b_final.data[:] = target_model.b_final.data.clone() - - if not config.task_config.train_bias: - model.b_final.requires_grad = False + model.b_final.requires_grad = False dataset = SparseFeatureDataset( n_instances=target_model.config.n_instances, @@ -229,7 +223,6 @@ def test_tms_equivalent_to_raw_model() -> None: tms_spd_config = TMSSPDModelConfig( **tms_config.model_dump(), m=3, # Small m for testing - bias_val=0.0, ) spd_model = TMSSPDModel(config=tms_spd_config).to(device) @@ -295,7 +288,6 @@ def test_init_tms_spd_model_from_target() -> None: tms_spd_config = TMSSPDModelConfig( **tms_config.model_dump(), m=tms_config.n_features, # Must match n_features for initialization - bias_val=0.0, ) spd_model = TMSSPDModel(config=tms_spd_config).to(device) From 58eb6060566448df5d72e927c67e8912554e7951 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 11:30:03 +0000 Subject: [PATCH 31/73] Make calc_masked_target_component_acts einsums clearer --- spd/run_spd.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index 6b3bf27..4ab28b5 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -245,9 +245,13 @@ def calc_masked_target_component_acts( masked_target_component_acts = {} for param_name in pre_weight_acts: raw_name = param_name.removesuffix(".hook_pre") - masked_As = einops.einsum(As[raw_name], masks[raw_name], "... d_in m, ... m -> ... d_in m") + masked_As = einops.einsum( + As[raw_name], masks[raw_name], "... d_in m, batch ... m -> batch ... d_in m" + ) masked_target_component_acts[raw_name] = einops.einsum( - pre_weight_acts[param_name], masked_As, "... d_in, ... d_in m -> ... m" + pre_weight_acts[param_name], + masked_As, + "batch ... d_in, batch ... d_in m -> batch ... m", ) return masked_target_component_acts From f5367436c0d58f369292f11656ee3bba5110b0e5 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 11:39:19 +0000 Subject: [PATCH 32/73] Change bias init to 1 in GateMLP --- spd/models/components.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spd/models/components.py b/spd/models/components.py index 31349bb..55504ac 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -58,9 +58,9 @@ def __init__(self, m: int, n_gate_hidden_neurons: int, n_instances: int | None = out_bias_shape = (n_instances, m) if n_instances is not None else (m,) self.mlp_in = nn.Parameter(torch.empty(shape)) - self.in_bias = nn.Parameter(torch.zeros(in_bias_shape)) + self.in_bias = nn.Parameter(torch.ones(in_bias_shape)) self.mlp_out = nn.Parameter(torch.empty(shape)) - self.out_bias = nn.Parameter(torch.zeros(out_bias_shape)) + self.out_bias = nn.Parameter(torch.ones(out_bias_shape)) torch.nn.init.normal_(self.mlp_in, mean=0.0, std=0.2) torch.nn.init.normal_(self.mlp_out, mean=0.0, std=0.2) From b6a35cc7ed5f29b3c0486418480e9c904bb595c3 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 12:19:19 +0000 Subject: [PATCH 33/73] Plot unpermuted As --- .../resid_mlp/resid_mlp_decomposition.py | 10 +---- spd/experiments/tms/tms_decomposition.py | 2 + spd/plotting.py | 45 +++++++++++++++++++ spd/run_spd.py | 4 ++ 4 files changed, 53 insertions(+), 8 deletions(-) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 043487f..a32295b 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -14,7 +14,6 @@ import yaml from jaxtyping import Float from torch import Tensor -from tqdm import tqdm from spd.configs import Config, ResidualMLPTaskConfig from spd.experiments.resid_mlp.models import ( @@ -26,6 +25,7 @@ from spd.experiments.tms.plotting import plot_mask_vals from spd.log import logger from spd.models.components import Gate +from spd.plotting import plot_As from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( DatasetGeneratedDataLoader, @@ -120,13 +120,7 @@ def resid_mlp_plot_results_fn( fig_dict["masks"] = plot_mask_vals( model=model, target_model=target_model, gates=gates, device=device, input_magnitude=0.75 ) - - # Save plots to files - if out_dir: - for k, v in fig_dict.items(): - out_file = out_dir / f"{k}_s{step}.png" - v.savefig(out_file, dpi=100) - tqdm.write(f"Saved plot to {out_file}") + fig_dict["As"] = plot_As(model=model, device=device) return fig_dict diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 86b1e24..13d63af 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -22,6 +22,7 @@ from spd.experiments.tms.plotting import plot_mask_vals from spd.log import logger from spd.models.components import Gate +from spd.plotting import plot_As from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( DatasetGeneratedDataLoader, @@ -63,6 +64,7 @@ def make_plots( plots["masks"] = plot_mask_vals( model=model, target_model=target_model, gates=gates, device=device, input_magnitude=0.75 ) + plots["As"] = plot_As(model=model, device=device) return plots diff --git a/spd/plotting.py b/spd/plotting.py index 891db8d..14f39e2 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -8,6 +8,9 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable from torch import Tensor +from spd.models.base import SPDModel +from spd.module_utils import collect_nested_module_attrs + def plot_subnetwork_attributions_statistics( mask: Float[Tensor, "batch_size n_instances m"], @@ -88,3 +91,45 @@ def plot_matrix( n_functions = matrix.shape[0] ax.set_yticks(range(n_functions)) ax.set_yticklabels([f"{L:.0f}" for L in range(1, n_functions + 1)]) + + +def plot_As(model: SPDModel, device: str) -> plt.Figure: + """Plot the A matrices for each instance.""" + # Collect all A matrices + As = collect_nested_module_attrs(model, attr_name="A", include_attr_name=False) + n_instances = model.n_instances + + # Create figure for plotting + fig, axs = plt.subplots( + len(As), + n_instances, + figsize=(5 * n_instances, 5 * len(As)), + constrained_layout=True, + squeeze=False, + ) + axs = np.array(axs) + + images = [] + + # Plot each A matrix for each instance + for i in range(n_instances): + axs[0, i].set_title(f"Instance {i}") + for j, (A_name, A) in enumerate(As.items()): + # A has shape (n_instances, d_in, m) + A_data = A[i].detach().cpu().numpy() + im = axs[j, i].matshow(A_data, aspect="auto", cmap="coolwarm") + if i == 0: + axs[j, i].set_ylabel("d_in index") + axs[j, i].set_xlabel("Component index") + axs[j, i].set_title(A_name) + images.append(im) + + # Add unified colorbar + norm = plt.Normalize( + vmin=min(A.min().item() for A in As.values()), + vmax=max(A.max().item() for A in As.values()), + ) + for im in images: + im.set_norm(norm) + fig.colorbar(images[0], ax=axs.ravel().tolist()) + return fig diff --git a/spd/run_spd.py b/spd/run_spd.py index 4ab28b5..649bc3d 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -449,6 +449,10 @@ def optimize( {k: wandb.Image(v) for k, v in fig_dict.items()}, step=step, ) + if out_dir is not None: + for k, v in fig_dict.items(): + v.savefig(out_dir / f"{k}_{step}.png") + tqdm.write(f"Saved plot to {out_dir / f'{k}_{step}.png'}") # Save model if ( From 10cad293fa33853afa081ae6eb636cfc22002cfa Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 12:29:21 +0000 Subject: [PATCH 34/73] Set in_bias in GateMLP to zeros --- spd/models/components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/models/components.py b/spd/models/components.py index 55504ac..bb89977 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -58,7 +58,7 @@ def __init__(self, m: int, n_gate_hidden_neurons: int, n_instances: int | None = out_bias_shape = (n_instances, m) if n_instances is not None else (m,) self.mlp_in = nn.Parameter(torch.empty(shape)) - self.in_bias = nn.Parameter(torch.ones(in_bias_shape)) + self.in_bias = nn.Parameter(torch.zeros(in_bias_shape)) self.mlp_out = nn.Parameter(torch.empty(shape)) self.out_bias = nn.Parameter(torch.ones(out_bias_shape)) From 6aa82a81ab6f97d27309a8eb7c9f405e788e7636 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 13:35:02 +0000 Subject: [PATCH 35/73] plot_mask_vals in the root plotting.py instead of in tms experiment --- .../resid_mlp/resid_mlp_decomposition.py | 3 +- spd/experiments/tms/plotting.py | 145 +++--------------- spd/experiments/tms/tms_decomposition.py | 3 +- spd/plotting.py | 97 ++++++++++++ 4 files changed, 120 insertions(+), 128 deletions(-) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index a32295b..103c508 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -22,10 +22,9 @@ ResidualMLPSPDModel, ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset -from spd.experiments.tms.plotting import plot_mask_vals from spd.log import logger from spd.models.components import Gate -from spd.plotting import plot_As +from spd.plotting import plot_As, plot_mask_vals from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( DatasetGeneratedDataLoader, diff --git a/spd/experiments/tms/plotting.py b/spd/experiments/tms/plotting.py index 0719cc1..8df15c9 100644 --- a/spd/experiments/tms/plotting.py +++ b/spd/experiments/tms/plotting.py @@ -1,128 +1,25 @@ -import einops -import matplotlib.pyplot as plt -import numpy as np -import torch -from jaxtyping import Float -from torch import Tensor - -from spd.experiments.tms.models import TMSModel, TMSSPDModel -from spd.hooks import HookedRootModule -from spd.models.base import SPDModel -from spd.models.components import Gate -from spd.module_utils import collect_nested_module_attrs -from spd.run_spd import calc_component_acts, calc_masks - - -def permute_to_identity( - mask: Float[Tensor, "batch n_instances m"], -) -> Float[Tensor, "batch n_instances m"]: - batch, n_instances, m = mask.shape - new_mask = mask.clone() - effective_rows: int = min(batch, m) - for inst in range(n_instances): - mat: Tensor = mask[:, inst, :] - perm: list[int] = [0] * m - used: set[int] = set() - for i in range(effective_rows): - sorted_indices: list[int] = torch.argsort(mat[i, :], descending=True).tolist() - chosen: int = next( - (col for col in sorted_indices if col not in used), sorted_indices[0] - ) - perm[i] = chosen - used.add(chosen) - remaining: list[int] = sorted(list(set(range(m)) - used)) - for idx, col in enumerate(remaining): - perm[effective_rows + idx] = col - new_mask[:, inst, :] = mat[:, perm] - return new_mask - - -def plot_mask_vals( - model: SPDModel, - target_model: HookedRootModule, - gates: dict[str, Gate], - device: str, - input_magnitude: float, -) -> plt.Figure: - """Plot the values of the mask for a batch of inputs with single active features.""" - # First, create a batch of inputs with single active features - n_features = model.n_features - n_instances = model.n_instances - batch = torch.eye(n_features, device=device) * input_magnitude - batch = einops.repeat( - batch, "batch n_features -> batch n_instances n_features", n_instances=n_instances - ) - - # Forward pass with target model - target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) - target_cache = target_model.run_with_cache(batch, names_filter=target_cache_filter)[1] - pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} - As = collect_nested_module_attrs(model, attr_name="A", include_attr_name=False) - - target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) - - relud_masks = calc_masks( - gates=gates, target_component_acts=target_component_acts, attributions=None - )[1] - - # Permute columns so that in each instance the maximum per row ends up on the diagonal. - relud_masks = {k: permute_to_identity(mask=v) for k, v in relud_masks.items()} - - # Create figure with better layout and sizing - fig, axs = plt.subplots( - len(relud_masks), - n_instances, - figsize=(5 * n_instances, 5 * len(relud_masks)), - constrained_layout=True, - squeeze=False, - ) - axs = np.array(axs) - - images = [] - for i in range(n_instances): - axs[0, i].set_title(f"Instance {i}") - for j, (mask_name, mask) in enumerate(relud_masks.items()): - # mask has shape (batch, n_instances, m) - mask_data = mask[:, i, :].detach().cpu().numpy() - im = axs[j, i].matshow(mask_data, aspect="auto", cmap="Reds") - images.append(im) - - axs[j, i].set_xlabel("Mask index") - if i == 0: # Only set ylabel for leftmost plots - axs[j, i].set_ylabel("Input feature index") - axs[j, i].set_title(mask_name) - - # Add unified colorbar - norm = plt.Normalize( - vmin=min(mask.min().item() for mask in relud_masks.values()), - vmax=max(mask.max().item() for mask in relud_masks.values()), - ) - for im in images: - im.set_norm(norm) - fig.colorbar(images[0], ax=axs.ravel().tolist()) - - # Add a title which shows the input magnitude - fig.suptitle(f"Input magnitude: {input_magnitude}") - - return fig - +from spd.experiments.tms.models import TMSSPDModel if __name__ == "__main__": - pretrained_model_path = "wandb:spd-train-tms/runs/tmzweoqk" - run_id = "wandb:spd-tms/runs/fj68gebo" - target_model, target_model_train_config_dict = TMSModel.from_pretrained(pretrained_model_path) + # run_id = "wandb:spd-tms/runs/u359w3kq" + # run_id = "wandb:spd-tms/runs/hrwrgei2" + run_id = "wandb:spd-tms/runs/3p8qgg6b" + # pretrained_model_path = "wandb:spd-train-tms/runs/tmzweoqk" + # run_id = "wandb:spd-tms/runs/fj68gebo" + # target_model, target_model_train_config_dict = TMSModel.from_pretrained(pretrained_model_path) spd_model, spd_model_train_config_dict = TMSSPDModel.from_pretrained(run_id) - # We used "-" instead of "." as module names can't have "." in them - gates = {k.removeprefix("gates.").replace("-", "."): v for k, v in spd_model.gates.items()} - - input_magnitude = 0.75 - fig = plot_mask_vals( - spd_model, - target_model, - gates, # type: ignore - device="cpu", - input_magnitude=input_magnitude, - ) - fig.savefig(f"tms_mask_vals_{input_magnitude}.png") - print(f"Saved figure to tms_mask_vals_{input_magnitude}.png") + pass + # # We used "-" instead of "." as module names can't have "." in them + # gates = {k.removeprefix("gates.").replace("-", "."): v for k, v in spd_model.gates.items()} + + # input_magnitude = 0.75 + # fig = plot_mask_vals( + # spd_model, + # target_model, + # gates, # type: ignore + # device="cpu", + # input_magnitude=input_magnitude, + # ) + # fig.savefig(f"tms_mask_vals_{input_magnitude}.png") + # print(f"Saved figure to tms_mask_vals_{input_magnitude}.png") diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 13d63af..7662df4 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -19,10 +19,9 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSModelConfig, TMSSPDModel, TMSSPDModelConfig -from spd.experiments.tms.plotting import plot_mask_vals from spd.log import logger from spd.models.components import Gate -from spd.plotting import plot_As +from spd.plotting import plot_As, plot_mask_vals from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( DatasetGeneratedDataLoader, diff --git a/spd/plotting.py b/spd/plotting.py index 14f39e2..1fa6dda 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -8,8 +8,105 @@ from mpl_toolkits.axes_grid1 import make_axes_locatable from torch import Tensor +from spd.hooks import HookedRootModule from spd.models.base import SPDModel +from spd.models.components import Gate from spd.module_utils import collect_nested_module_attrs +from spd.run_spd import calc_component_acts, calc_masks + + +def permute_to_identity( + mask: Float[Tensor, "batch n_instances m"], +) -> Float[Tensor, "batch n_instances m"]: + batch, n_instances, m = mask.shape + new_mask = mask.clone() + effective_rows: int = min(batch, m) + for inst in range(n_instances): + mat: Tensor = mask[:, inst, :] + perm: list[int] = [0] * m + used: set[int] = set() + for i in range(effective_rows): + sorted_indices: list[int] = torch.argsort(mat[i, :], descending=True).tolist() + chosen: int = next( + (col for col in sorted_indices if col not in used), sorted_indices[0] + ) + perm[i] = chosen + used.add(chosen) + remaining: list[int] = sorted(list(set(range(m)) - used)) + for idx, col in enumerate(remaining): + perm[effective_rows + idx] = col + new_mask[:, inst, :] = mat[:, perm] + return new_mask + + +def plot_mask_vals( + model: SPDModel, + target_model: HookedRootModule, + gates: dict[str, Gate], + device: str, + input_magnitude: float, +) -> plt.Figure: + """Plot the values of the mask for a batch of inputs with single active features.""" + # First, create a batch of inputs with single active features + n_features = model.n_features + n_instances = model.n_instances + batch = torch.eye(n_features, device=device) * input_magnitude + batch = einops.repeat( + batch, "batch n_features -> batch n_instances n_features", n_instances=n_instances + ) + + # Forward pass with target model + target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) + target_cache = target_model.run_with_cache(batch, names_filter=target_cache_filter)[1] + pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} + As = collect_nested_module_attrs(model, attr_name="A", include_attr_name=False) + + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + + relud_masks = calc_masks( + gates=gates, target_component_acts=target_component_acts, attributions=None + )[1] + + # Permute columns so that in each instance the maximum per row ends up on the diagonal. + relud_masks = {k: permute_to_identity(mask=v) for k, v in relud_masks.items()} + + # Create figure with better layout and sizing + fig, axs = plt.subplots( + len(relud_masks), + n_instances, + figsize=(5 * n_instances, 5 * len(relud_masks)), + constrained_layout=True, + squeeze=False, + ) + axs = np.array(axs) + + images = [] + for i in range(n_instances): + axs[0, i].set_title(f"Instance {i}") + for j, (mask_name, mask) in enumerate(relud_masks.items()): + # mask has shape (batch, n_instances, m) + mask_data = mask[:, i, :].detach().cpu().numpy() + im = axs[j, i].matshow(mask_data, aspect="auto", cmap="Reds") + images.append(im) + + axs[j, i].set_xlabel("Mask index") + if i == 0: # Only set ylabel for leftmost plots + axs[j, i].set_ylabel("Input feature index") + axs[j, i].set_title(mask_name) + + # Add unified colorbar + norm = plt.Normalize( + vmin=min(mask.min().item() for mask in relud_masks.values()), + vmax=max(mask.max().item() for mask in relud_masks.values()), + ) + for im in images: + im.set_norm(norm) + fig.colorbar(images[0], ax=axs.ravel().tolist()) + + # Add a title which shows the input magnitude + fig.suptitle(f"Input magnitude: {input_magnitude}") + + return fig def plot_subnetwork_attributions_statistics( From 99da31b5d7bbcd3099dfda976c373e9b10aeef2a Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 14:54:26 +0000 Subject: [PATCH 36/73] Plot permuted AB matrices --- .../resid_mlp/resid_mlp_decomposition.py | 8 +- spd/experiments/tms/tms_decomposition.py | 8 +- spd/plotting.py | 91 ++++++++++++++----- 3 files changed, 76 insertions(+), 31 deletions(-) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 103c508..8b5aa82 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -24,7 +24,7 @@ from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset from spd.log import logger from spd.models.components import Gate -from spd.plotting import plot_As, plot_mask_vals +from spd.plotting import plot_AB_matrices, plot_mask_vals from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( DatasetGeneratedDataLoader, @@ -116,10 +116,12 @@ def resid_mlp_plot_results_fn( assert isinstance(config.task_config, ResidualMLPTaskConfig) fig_dict = {} - fig_dict["masks"] = plot_mask_vals( + fig_dict["masks"], all_perm_indices = plot_mask_vals( model=model, target_model=target_model, gates=gates, device=device, input_magnitude=0.75 ) - fig_dict["As"] = plot_As(model=model, device=device) + fig_dict["AB_matrices"] = plot_AB_matrices( + model=model, device=device, all_perm_indices=all_perm_indices + ) return fig_dict diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 7662df4..70ff530 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -21,7 +21,7 @@ from spd.experiments.tms.models import TMSModel, TMSModelConfig, TMSSPDModel, TMSSPDModelConfig from spd.log import logger from spd.models.components import Gate -from spd.plotting import plot_As, plot_mask_vals +from spd.plotting import plot_AB_matrices, plot_mask_vals from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( DatasetGeneratedDataLoader, @@ -60,10 +60,12 @@ def make_plots( **_, ) -> dict[str, plt.Figure]: plots = {} - plots["masks"] = plot_mask_vals( + plots["masks"], all_perm_indices = plot_mask_vals( model=model, target_model=target_model, gates=gates, device=device, input_magnitude=0.75 ) - plots["As"] = plot_As(model=model, device=device) + plots["AB_matrices"] = plot_AB_matrices( + model=model, device=device, all_perm_indices=all_perm_indices + ) return plots diff --git a/spd/plotting.py b/spd/plotting.py index 1fa6dda..1eb1adf 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -17,10 +17,14 @@ def permute_to_identity( mask: Float[Tensor, "batch n_instances m"], -) -> Float[Tensor, "batch n_instances m"]: +) -> tuple[Float[Tensor, "batch n_instances m"], Float[Tensor, "n_instances m"]]: + """Returns (permuted_mask, permutation_indices)""" batch, n_instances, m = mask.shape new_mask = mask.clone() - effective_rows: int = min(batch, m) + effective_rows = min(batch, m) + # Store permutation indices for each instance + perm_indices = torch.zeros((n_instances, m), dtype=torch.long, device=mask.device) + for inst in range(n_instances): mat: Tensor = mask[:, inst, :] perm: list[int] = [0] * m @@ -36,7 +40,9 @@ def permute_to_identity( for idx, col in enumerate(remaining): perm[effective_rows + idx] = col new_mask[:, inst, :] = mat[:, perm] - return new_mask + perm_indices[inst] = torch.tensor(perm, device=mask.device) + + return new_mask, perm_indices def plot_mask_vals( @@ -45,7 +51,7 @@ def plot_mask_vals( gates: dict[str, Gate], device: str, input_magnitude: float, -) -> plt.Figure: +) -> tuple[plt.Figure, dict[str, Float[Tensor, "n_instances m"]]]: """Plot the values of the mask for a batch of inputs with single active features.""" # First, create a batch of inputs with single active features n_features = model.n_features @@ -63,12 +69,14 @@ def plot_mask_vals( target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) - relud_masks = calc_masks( + relud_masks_raw = calc_masks( gates=gates, target_component_acts=target_component_acts, attributions=None )[1] - # Permute columns so that in each instance the maximum per row ends up on the diagonal. - relud_masks = {k: permute_to_identity(mask=v) for k, v in relud_masks.items()} + relud_masks = {} + all_perm_indices = {} + for k, v in relud_masks_raw.items(): + relud_masks[k], all_perm_indices[k] = permute_to_identity(mask=v) # Create figure with better layout and sizing fig, axs = plt.subplots( @@ -106,7 +114,7 @@ def plot_mask_vals( # Add a title which shows the input magnitude fig.suptitle(f"Input magnitude: {input_magnitude}") - return fig + return fig, all_perm_indices def plot_subnetwork_attributions_statistics( @@ -190,17 +198,31 @@ def plot_matrix( ax.set_yticklabels([f"{L:.0f}" for L in range(1, n_functions + 1)]) -def plot_As(model: SPDModel, device: str) -> plt.Figure: - """Plot the A matrices for each instance.""" - # Collect all A matrices +def plot_AB_matrices( + model: SPDModel, + device: str, + all_perm_indices: dict[str, Float[Tensor, "n_instances m"]] | None = None, +) -> plt.Figure: + """Plot A and B matrices for each instance, grouped by layer.""" + # Collect all A and B matrices As = collect_nested_module_attrs(model, attr_name="A", include_attr_name=False) + Bs = collect_nested_module_attrs(model, attr_name="B", include_attr_name=False) n_instances = model.n_instances - # Create figure for plotting + # Verify that A and B matrices have matching names + A_names = set(As.keys()) + B_names = set(Bs.keys()) + assert ( + A_names == B_names + ), f"A and B matrices must have matching names. Found A: {A_names}, B: {B_names}" + + n_layers = len(As) + + # Create figure for plotting - 2 rows per layer (A and B) fig, axs = plt.subplots( - len(As), + 2 * n_layers, n_instances, - figsize=(5 * n_instances, 5 * len(As)), + figsize=(5 * n_instances, 5 * 2 * n_layers), constrained_layout=True, squeeze=False, ) @@ -208,23 +230,42 @@ def plot_As(model: SPDModel, device: str) -> plt.Figure: images = [] - # Plot each A matrix for each instance + # Plot each layer's A and B matrices for each instance for i in range(n_instances): - axs[0, i].set_title(f"Instance {i}") - for j, (A_name, A) in enumerate(As.items()): - # A has shape (n_instances, d_in, m) - A_data = A[i].detach().cpu().numpy() - im = axs[j, i].matshow(A_data, aspect="auto", cmap="coolwarm") + if i == 0: + axs[0, i].set_title(f"Instance {i}") + + # Plot A and B matrices for each layer + for j, name in enumerate(sorted(As.keys())): + # Plot A matrix + A_data = As[name][i] + if all_perm_indices is not None: + A_data = A_data[:, all_perm_indices[name][i]] + A_data = A_data.detach().cpu().numpy() + im = axs[2 * j, i].matshow(A_data, aspect="auto", cmap="coolwarm") + if i == 0: + axs[2 * j, i].set_ylabel("d_in index") + axs[2 * j, i].set_xlabel("Component index") + axs[2 * j, i].set_title(f"{name} (A matrix)") + images.append(im) + + # Plot B matrix + B_data = Bs[name][i] + if all_perm_indices is not None: + B_data = B_data[all_perm_indices[name][i], :] + B_data = B_data.detach().cpu().numpy() + im = axs[2 * j + 1, i].matshow(B_data, aspect="auto", cmap="coolwarm") if i == 0: - axs[j, i].set_ylabel("d_in index") - axs[j, i].set_xlabel("Component index") - axs[j, i].set_title(A_name) + axs[2 * j + 1, i].set_ylabel("Component index") + axs[2 * j + 1, i].set_xlabel("d_out index") + axs[2 * j + 1, i].set_title(f"{name} (B matrix)") images.append(im) # Add unified colorbar + all_matrices = list(As.values()) + list(Bs.values()) norm = plt.Normalize( - vmin=min(A.min().item() for A in As.values()), - vmax=max(A.max().item() for A in As.values()), + vmin=min(M.min().item() for M in all_matrices), + vmax=max(M.max().item() for M in all_matrices), ) for im in images: im.set_norm(norm) From aa453f73c0369c2b265ecde7f518d725b4237c81 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 15:22:04 +0000 Subject: [PATCH 37/73] Take mean over batch only for lp_sparsity_coeff --- spd/run_spd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index 649bc3d..802e582 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -112,8 +112,8 @@ def calc_lp_sparsity_loss( for layer_relud_mask in relud_masks.values(): total_loss = total_loss + layer_relud_mask**pnorm - # Mean over the batch and m dimension and divide by the number of parameter layers - return total_loss.mean(dim=(0, -1)) / len(relud_masks) + # Mean over the batch dimension only + return total_loss.mean(dim=0) def calc_act_recon_mse( From f6bc57dde26d491dbbd45dd9aa439d238555519b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 15:51:50 +0000 Subject: [PATCH 38/73] Fix for normalizing by batch only; sum over m dim --- spd/run_spd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index 802e582..b2bf42a 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -113,7 +113,7 @@ def calc_lp_sparsity_loss( total_loss = total_loss + layer_relud_mask**pnorm # Mean over the batch dimension only - return total_loss.mean(dim=0) + return total_loss.mean(dim=0).sum(dim=-1) def calc_act_recon_mse( From 5f216b3e4b60fa467abe8a34c52b2e9871792aa6 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 15:54:14 +0000 Subject: [PATCH 39/73] Fix docs for lp sparsity loss --- spd/run_spd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index b2bf42a..601ce4f 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -112,8 +112,8 @@ def calc_lp_sparsity_loss( for layer_relud_mask in relud_masks.values(): total_loss = total_loss + layer_relud_mask**pnorm - # Mean over the batch dimension only - return total_loss.mean(dim=0).sum(dim=-1) + # Sum over the m dimension and mean over the batch dimension + return total_loss.sum(dim=-1).mean(dim=0) def calc_act_recon_mse( From d1b82fa51fee2242e8f4ccd8e0c5782e4b590934 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 6 Mar 2025 16:11:57 +0000 Subject: [PATCH 40/73] Fix return type of lp_sparsity_loss --- spd/run_spd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index 601ce4f..eabc811 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -96,7 +96,7 @@ def calc_param_match_loss( def calc_lp_sparsity_loss( relud_masks: dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]], pnorm: float, -) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: +) -> Float[Tensor, ""] | Float[Tensor, " n_instances"]: """Calculate the Lp sparsity loss on the attributions. Args: From e93c5c9769d6d6a68bee91f406996da3a66aad99 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 7 Mar 2025 12:15:36 +0000 Subject: [PATCH 41/73] Use Kaiming normal everywhere --- spd/experiments/resid_mlp/models.py | 34 ++++++++++------------------- spd/experiments/tms/models.py | 13 +++++------ spd/models/components.py | 34 ++++++++++++++++------------- spd/module_utils.py | 29 +++++++++++++++++------- 4 files changed, 56 insertions(+), 54 deletions(-) diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index 128cf32..0ef9b13 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -35,8 +35,7 @@ def __init__( act_fn: Callable[[Tensor], Tensor], in_bias: bool, out_bias: bool, - init_scale: float, - init_type: Literal["kaiming_uniform", "xavier_normal"] = "kaiming_uniform", + init_scale: float = 1.0, n_instances: int | None = None, spd_kwargs: dict[str, Any] | None = None, ): @@ -51,7 +50,6 @@ def __init__( d_in=d_model, d_out=d_mlp, n_instances=n_instances, - init_type=init_type, init_scale=init_scale, m=spd_kwargs["m"], ) @@ -59,34 +57,27 @@ def __init__( d_in=d_mlp, d_out=d_model, n_instances=n_instances, - init_type=init_type, init_scale=init_scale, m=spd_kwargs["m"], ) else: self.mlp_in = Linear( - d_in=d_model, - d_out=d_mlp, - n_instances=n_instances, - init_type=init_type, - init_scale=init_scale, + d_in=d_model, d_out=d_mlp, n_instances=n_instances, init_scale=init_scale ) self.mlp_out = Linear( - d_in=d_mlp, - d_out=d_model, - n_instances=n_instances, - init_type=init_type, - init_scale=init_scale, + d_in=d_mlp, d_out=d_model, n_instances=n_instances, init_scale=init_scale ) self.bias1 = None self.bias2 = None if in_bias: shape = (n_instances, d_mlp) if n_instances is not None else d_mlp - self.bias1 = nn.Parameter(torch.zeros(shape)) + self.bias1 = nn.Parameter(torch.empty(shape)) + init_param_(self.bias1, fan_val=d_mlp, nonlinearity="relu") if out_bias: shape = (n_instances, d_model) if n_instances is not None else d_model - self.bias2 = nn.Parameter(torch.zeros(shape)) + self.bias2 = nn.Parameter(torch.empty(shape)) + init_param_(self.bias2, fan_val=d_model, nonlinearity="linear") def forward( self, @@ -132,7 +123,6 @@ class ResidualMLPConfig(BaseModel): apply_output_act_fn: bool in_bias: bool out_bias: bool - init_scale: float = 1.0 class ResidualMLPModel(HookedRootModule): @@ -140,9 +130,9 @@ def __init__(self, config: ResidualMLPConfig): super().__init__() self.config = config self.W_E = nn.Parameter(torch.empty(config.n_instances, config.n_features, config.d_embed)) - init_param_(self.W_E, scale=config.init_scale) + init_param_(self.W_E, fan_val=config.n_features, nonlinearity="linear") self.W_U = nn.Parameter(torch.empty(config.n_instances, config.d_embed, config.n_features)) - init_param_(self.W_U, scale=config.init_scale) + init_param_(self.W_U, fan_val=config.d_embed, nonlinearity="linear") assert config.act_fn_name in ["gelu", "relu"] self.act_fn = F.gelu if config.act_fn_name == "gelu" else F.relu @@ -155,7 +145,6 @@ def __init__(self, config: ResidualMLPConfig): act_fn=self.act_fn, in_bias=config.in_bias, out_bias=config.out_bias, - init_scale=config.init_scale, ) for _ in range(config.n_layers) ] @@ -301,8 +290,8 @@ def __init__( self.W_E = nn.Parameter(torch.empty(config.n_instances, config.n_features, config.d_embed)) self.W_U = nn.Parameter(torch.empty(config.n_instances, config.d_embed, config.n_features)) - init_param_(self.W_E, init_type=config.init_type) - init_param_(self.W_U, init_type=config.init_type) + init_param_(self.W_E, fan_val=config.n_features, nonlinearity="linear") + init_param_(self.W_U, fan_val=config.d_embed, nonlinearity="linear") self.layers = nn.ModuleList() @@ -319,7 +308,6 @@ def __init__( n_instances=config.n_instances, d_model=config.d_embed, d_mlp=config.d_mlp, - init_type=config.init_type, init_scale=config.init_scale, in_bias=config.in_bias, out_bias=config.out_bias, diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 60cb711..9d99fb0 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -21,6 +21,7 @@ TransposedLinear, TransposedLinearComponent, ) +from spd.module_utils import init_param_ from spd.types import WANDB_PATH_PREFIX, ModelPath from spd.utils import replace_deprecated_param_names from spd.wandb_utils import download_wandb_file, fetch_latest_wandb_checkpoint, fetch_wandb_run_dir @@ -75,12 +76,12 @@ def __init__(self, config: TMSModelConfig): d_in=config.n_features, d_out=config.n_hidden, n_instances=config.n_instances, - init_type="xavier_normal", ) # Use tied weights for the second linear layer self.linear2 = TransposedLinear(self.linear1.weight) - self.b_final = nn.Parameter(torch.zeros((config.n_instances, config.n_features))) + self.b_final = nn.Parameter(torch.empty((config.n_instances, config.n_features))) + init_param_(self.b_final, fan_val=config.n_features, nonlinearity="relu") self.hidden_layers = None if config.n_hidden_layers > 0: @@ -90,7 +91,6 @@ def __init__(self, config: TMSModelConfig): d_in=config.n_hidden, d_out=config.n_hidden, n_instances=config.n_instances, - init_type="xavier_normal", ) self.hidden_layers.append(layer) self.setup() @@ -189,14 +189,13 @@ def __init__(self, config: TMSSPDModelConfig): d_in=config.n_features, d_out=config.n_hidden, n_instances=config.n_instances, - init_type="xavier_normal", - init_scale=1.0, m=self.m, ) self.linear2 = TransposedLinearComponent(self.linear1.A, self.linear1.B) self.b_final = nn.Parameter( - torch.zeros((config.n_instances, config.n_features), device=config.device) + torch.empty((config.n_instances, config.n_features), device=config.device) ) + init_param_(self.b_final, fan_val=config.n_features, nonlinearity="relu") self.hidden_layers = None if config.n_hidden_layers > 0: @@ -206,8 +205,6 @@ def __init__(self, config: TMSSPDModelConfig): d_in=config.n_hidden, d_out=config.n_hidden, n_instances=config.n_instances, - init_type="xavier_normal", - init_scale=1.0, m=self.m, ) for _ in range(config.n_hidden_layers) diff --git a/spd/models/components.py b/spd/models/components.py index bb89977..f162a1f 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -1,4 +1,4 @@ -from typing import Any, Literal +from typing import Any import einops import torch @@ -22,8 +22,10 @@ def __init__(self, m: int, n_instances: int | None = None): self.n_instances = n_instances shape = (n_instances, m) if n_instances is not None else (m,) self.weight = nn.Parameter(torch.empty(shape)) - torch.nn.init.normal_(self.weight, mean=0.0, std=0.2) - self.bias = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.empty(shape)) + fan_val = 1 # Since each weight gets applied independently + init_param_(self.weight, fan_val=fan_val, nonlinearity="linear") + init_param_(self.bias, fan_val=fan_val, nonlinearity="linear") def forward( self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] @@ -33,7 +35,7 @@ def forward( def forward_relu( self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] ) -> Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]: - return (x * self.weight + self.bias).relu() + return F.relu(x * self.weight + self.bias) class GateMLP(nn.Module): @@ -58,12 +60,14 @@ def __init__(self, m: int, n_gate_hidden_neurons: int, n_instances: int | None = out_bias_shape = (n_instances, m) if n_instances is not None else (m,) self.mlp_in = nn.Parameter(torch.empty(shape)) - self.in_bias = nn.Parameter(torch.zeros(in_bias_shape)) + self.in_bias = nn.Parameter(torch.empty(in_bias_shape)) self.mlp_out = nn.Parameter(torch.empty(shape)) - self.out_bias = nn.Parameter(torch.ones(out_bias_shape)) + self.out_bias = nn.Parameter(torch.empty(out_bias_shape)) - torch.nn.init.normal_(self.mlp_in, mean=0.0, std=0.2) - torch.nn.init.normal_(self.mlp_out, mean=0.0, std=0.2) + init_param_(self.mlp_in, fan_val=1, nonlinearity="relu") + init_param_(self.in_bias, fan_val=1, nonlinearity="relu") + init_param_(self.mlp_out, fan_val=n_gate_hidden_neurons, nonlinearity="linear") + init_param_(self.out_bias, fan_val=n_gate_hidden_neurons, nonlinearity="linear") def _compute_pre_activation( self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] @@ -106,13 +110,14 @@ def __init__( d_in: int, d_out: int, n_instances: int | None = None, - init_type: Literal["kaiming_uniform", "xavier_normal"] = "kaiming_uniform", init_scale: float = 1.0, ): super().__init__() shape = (n_instances, d_in, d_out) if n_instances is not None else (d_in, d_out) self.weight = nn.Parameter(torch.empty(shape)) - init_param_(self.weight, scale=init_scale, init_type=init_type) + # Note: init assumes no relu/gelu after this layer (which won't be the case for mlp_in, but + # sqrt(2) ~= 1 so we're ignoring this for now.) + init_param_(self.weight, fan_val=d_in, nonlinearity="linear", scale=init_scale) self.hook_pre = HookPoint() # (batch ... d_in) self.hook_post = HookPoint() # (batch ... d_out) @@ -138,7 +143,6 @@ def __init__( d_out: int, m: int, n_instances: int | None = None, - init_type: Literal["kaiming_uniform", "xavier_normal"] = "kaiming_uniform", init_scale: float = 1.0, ): super().__init__() @@ -146,16 +150,16 @@ def __init__( self.m = m # Initialize A and B matrices - shape_A = (n_instances, d_in, self.m) if n_instances is not None else (d_in, self.m) - shape_B = (n_instances, self.m, d_out) if n_instances is not None else (self.m, d_out) + shape_A = (n_instances, d_in, m) if n_instances is not None else (d_in, m) + shape_B = (n_instances, m, d_out) if n_instances is not None else (m, d_out) self.A = nn.Parameter(torch.empty(shape_A)) self.B = nn.Parameter(torch.empty(shape_B)) self.hook_pre = HookPoint() # (batch d_in) or (batch n_instances d_in) self.hook_component_acts = HookPoint() # (batch m) or (batch n_instances m) self.hook_post = HookPoint() # (batch d_out) or (batch n_instances d_out) - init_param_(self.A, scale=init_scale, init_type=init_type) - init_param_(self.B, scale=init_scale, init_type=init_type) + init_param_(self.A, fan_val=d_in, nonlinearity="linear", scale=init_scale) + init_param_(self.B, fan_val=m, nonlinearity="linear", scale=init_scale) @property def weight(self) -> Float[Tensor, "... d_in d_out"]: diff --git a/spd/module_utils.py b/spd/module_utils.py index a565412..7f5cc5b 100644 --- a/spd/module_utils.py +++ b/spd/module_utils.py @@ -1,11 +1,13 @@ +import math from functools import reduce -from typing import Any, Literal +from typing import Any import einops import torch import torch.nn as nn from jaxtyping import Float from torch import Tensor +from torch.nn.init import calculate_gain def get_nested_module_attr(module: nn.Module, access_string: str) -> Any: @@ -88,12 +90,23 @@ def remove_grad_parallel_to_subnetwork_vecs( def init_param_( param: torch.Tensor, + fan_val: float, + mean: float = 0.0, + nonlinearity: str = "linear", scale: float = 1.0, - init_type: Literal["kaiming_uniform", "xavier_normal"] = "kaiming_uniform", + generator: torch.Generator | None = None, ) -> None: - if init_type == "kaiming_uniform": - torch.nn.init.kaiming_uniform_(param) - with torch.no_grad(): - param.mul_(scale) - elif init_type == "xavier_normal": - torch.nn.init.xavier_normal_(param, gain=scale) + """Fill in param with values sampled from a Kaiming normal distribution. + + Args: + param: The parameter to initialize + fan_val: The squared denominator of the std used for the kaiming normal distribution + mean: The mean of the normal distribution + nonlinearity: The nonlinearity of the activation function + scale: Scale the standard deviation by this amount + generator: The generator to sample from + """ + gain = calculate_gain(nonlinearity) + std = gain / math.sqrt(fan_val) + with torch.no_grad(): + param.normal_(mean, std * scale, generator=generator) From 52e6d9178ba740172db0ce536580409a4c7aa4d9 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 7 Mar 2025 13:21:31 +0000 Subject: [PATCH 42/73] Fix MLP bias init --- spd/experiments/resid_mlp/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index 0ef9b13..2799cd3 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -73,11 +73,11 @@ def __init__( if in_bias: shape = (n_instances, d_mlp) if n_instances is not None else d_mlp self.bias1 = nn.Parameter(torch.empty(shape)) - init_param_(self.bias1, fan_val=d_mlp, nonlinearity="relu") + init_param_(self.bias1, fan_val=d_model, nonlinearity="relu") if out_bias: shape = (n_instances, d_model) if n_instances is not None else d_model self.bias2 = nn.Parameter(torch.empty(shape)) - init_param_(self.bias2, fan_val=d_model, nonlinearity="linear") + init_param_(self.bias2, fan_val=d_mlp, nonlinearity="linear") def forward( self, From 244883faea365cf78208f94ebe4c95763a2bb3bc Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 7 Mar 2025 13:21:54 +0000 Subject: [PATCH 43/73] Always init TMS biases to 0 --- spd/experiments/tms/models.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 9d99fb0..4c39e61 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -21,7 +21,6 @@ TransposedLinear, TransposedLinearComponent, ) -from spd.module_utils import init_param_ from spd.types import WANDB_PATH_PREFIX, ModelPath from spd.utils import replace_deprecated_param_names from spd.wandb_utils import download_wandb_file, fetch_latest_wandb_checkpoint, fetch_wandb_run_dir @@ -80,8 +79,8 @@ def __init__(self, config: TMSModelConfig): # Use tied weights for the second linear layer self.linear2 = TransposedLinear(self.linear1.weight) - self.b_final = nn.Parameter(torch.empty((config.n_instances, config.n_features))) - init_param_(self.b_final, fan_val=config.n_features, nonlinearity="relu") + # TMS seems to require zero bias initialization to work + self.b_final = nn.Parameter(torch.zeros((config.n_instances, config.n_features))) self.hidden_layers = None if config.n_hidden_layers > 0: @@ -193,9 +192,8 @@ def __init__(self, config: TMSSPDModelConfig): ) self.linear2 = TransposedLinearComponent(self.linear1.A, self.linear1.B) self.b_final = nn.Parameter( - torch.empty((config.n_instances, config.n_features), device=config.device) + torch.zeros((config.n_instances, config.n_features), device=config.device) ) - init_param_(self.b_final, fan_val=config.n_features, nonlinearity="relu") self.hidden_layers = None if config.n_hidden_layers > 0: From bddc0edca085b9f47136b55f301fe5206fb1ff46 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 7 Mar 2025 13:41:54 +0000 Subject: [PATCH 44/73] Remove init_scale everywhere --- spd/configs.py | 1 - spd/experiments/resid_mlp/models.py | 15 ++++--------- .../resid_mlp/resid_mlp_config.yaml | 22 ++++++++++--------- .../resid_mlp/resid_mlp_decomposition.py | 7 +----- spd/experiments/resid_mlp/train_resid_mlp.py | 4 ++-- spd/experiments/tms/tms_config.yaml | 5 +++-- spd/models/components.py | 8 +++---- spd/module_utils.py | 4 +--- tests/test_resid_mlp.py | 2 -- 9 files changed, 26 insertions(+), 42 deletions(-) diff --git a/spd/configs.py b/spd/configs.py index a931445..50fada3 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -30,7 +30,6 @@ class ResidualMLPTaskConfig(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) task_name: Literal["residual_mlp"] = "residual_mlp" feature_probability: Probability - init_scale: float = 1.0 data_generation_type: Literal[ "exactly_one_active", "exactly_two_active", "at_least_zero_active" ] = "at_least_zero_active" diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index 2799cd3..469b795 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -35,7 +35,6 @@ def __init__( act_fn: Callable[[Tensor], Tensor], in_bias: bool, out_bias: bool, - init_scale: float = 1.0, n_instances: int | None = None, spd_kwargs: dict[str, Any] | None = None, ): @@ -50,23 +49,17 @@ def __init__( d_in=d_model, d_out=d_mlp, n_instances=n_instances, - init_scale=init_scale, m=spd_kwargs["m"], ) self.mlp_out = LinearComponent( d_in=d_mlp, d_out=d_model, n_instances=n_instances, - init_scale=init_scale, m=spd_kwargs["m"], ) else: - self.mlp_in = Linear( - d_in=d_model, d_out=d_mlp, n_instances=n_instances, init_scale=init_scale - ) - self.mlp_out = Linear( - d_in=d_mlp, d_out=d_model, n_instances=n_instances, init_scale=init_scale - ) + self.mlp_in = Linear(d_in=d_model, d_out=d_mlp, n_instances=n_instances) + self.mlp_out = Linear(d_in=d_mlp, d_out=d_model, n_instances=n_instances) self.bias1 = None self.bias2 = None @@ -232,6 +225,8 @@ def from_pretrained( with open(paths.resid_mlp_train_config) as f: resid_mlp_train_config_dict = yaml.safe_load(f) + resid_mlp_train_config_dict.pop("init_scale", None) # Deprecated + with open(paths.label_coeffs) as f: label_coeffs = torch.tensor(json.load(f)) @@ -268,7 +263,6 @@ class ResidualMLPSPDConfig(BaseModel): apply_output_act_fn: bool in_bias: bool out_bias: bool - init_scale: float m: PositiveInt n_gate_hidden_neurons: PositiveInt | None = None init_type: Literal["kaiming_uniform", "xavier_normal"] = "xavier_normal" @@ -308,7 +302,6 @@ def __init__( n_instances=config.n_instances, d_model=config.d_embed, d_mlp=config.d_mlp, - init_scale=config.init_scale, in_bias=config.in_bias, out_bias=config.out_bias, act_fn=self.act_fn, diff --git a/spd/experiments/resid_mlp/resid_mlp_config.yaml b/spd/experiments/resid_mlp/resid_mlp_config.yaml index 0d61f45..ba09f3c 100644 --- a/spd/experiments/resid_mlp/resid_mlp_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp_config.yaml @@ -4,16 +4,18 @@ wandb_run_name: null wandb_run_name_prefix: "" unit_norm_matrices: false seed: 0 -m: 200 +m: 100 param_match_coeff: 1.0 masked_recon_coeff: 1.0 act_recon_coeff: 1 random_mask_recon_coeff: 1.0 n_random_masks: 1 +n_gate_hidden_neurons: null +# n_gate_hidden_neurons: 8 pnorm: 0.9 -lp_sparsity_coeff: 1e-2 +lp_sparsity_coeff: 3e-2 batch_size: 256 -steps: 10_000 +steps: 20_000 image_freq: 5_000 print_freq: 100 save_freq: 10_000 @@ -24,7 +26,6 @@ image_on_first_step: true init_from_target_model: false task_config: task_name: residual_mlp - init_scale: 2.0 feature_probability: 0.01 data_generation_type: "at_least_zero_active" pretrained_model_path: wandb:spd-train-resid-mlp/runs/zas5yjdl # 1 layer @@ -36,26 +37,27 @@ task_config: # wandb_run_name_prefix: "" # unit_norm_matrices: false # seed: 0 -# m: 100 +# m: 200 # param_match_coeff: 1.0 # masked_recon_coeff: 2.0 # act_recon_coeff: 1.0 # random_mask_recon_coeff: 1.0 -# n_random_masks: 2 +# n_random_masks: 1 +# n_gate_hidden_neurons: 8 # pnorm: 0.9 -# lp_sparsity_coeff: 1.0 +# lp_sparsity_coeff: 3e-3 # batch_size: 256 # steps: 10_000 -# image_freq: 10_000 +# image_freq: 5_000 # print_freq: 500 # save_freq: 10_000 # lr: 1e-3 # lr_schedule: cosine # lr_warmup_pct: 0.01 -# image_on_first_step: false +# image_on_first_step: true +# init_from_target_model: false # task_config: # task_name: residual_mlp -# init_scale: 2.0 # feature_probability: 0.01 # data_generation_type: "at_least_zero_active" # pretrained_model_path: wandb:spd-train-resid-mlp/runs/sv23xrhj # 2 layer diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 8b5aa82..7d1b72b 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -43,8 +43,6 @@ def get_run_name( n_layers: int, d_resid: int, d_mlp: int, - m: int | None, - init_scale: float, ) -> str: """Generate a run name based on the config.""" run_suffix = "" @@ -52,7 +50,7 @@ def get_run_name( run_suffix = config.wandb_run_name else: run_suffix = get_common_run_name_suffix(config) - run_suffix += f"scale{init_scale}_ft{n_features}_lay{n_layers}_resid{d_resid}_mlp{d_mlp}" + run_suffix += f"ft{n_features}_lay{n_layers}_resid{d_resid}_mlp{d_mlp}" return config.wandb_run_name_prefix + run_suffix @@ -213,8 +211,6 @@ def main( n_layers=target_model.config.n_layers, d_resid=target_model.config.d_embed, d_mlp=target_model.config.d_mlp, - m=config.m, - init_scale=config.task_config.init_scale, ) if config.wandb_project: assert wandb.run, "wandb.run must be initialized before training" @@ -248,7 +244,6 @@ def main( apply_output_act_fn=target_model.config.apply_output_act_fn, in_bias=target_model.config.in_bias, out_bias=target_model.config.out_bias, - init_scale=config.task_config.init_scale, m=config.m, n_gate_hidden_neurons=config.n_gate_hidden_neurons, ) diff --git a/spd/experiments/resid_mlp/train_resid_mlp.py b/spd/experiments/resid_mlp/train_resid_mlp.py index bba9ce7..f63a91f 100644 --- a/spd/experiments/resid_mlp/train_resid_mlp.py +++ b/spd/experiments/resid_mlp/train_resid_mlp.py @@ -295,8 +295,8 @@ def run_train(config: ResidMLPTrainConfig, device: str) -> Float[Tensor, " n_ins importance_val=1, data_generation_type="at_least_zero_active", batch_size=2048, - steps=10000, - print_freq=500, + steps=1000, + print_freq=100, lr=3e-3, lr_schedule="cosine", fixed_random_embedding=True, diff --git a/spd/experiments/tms/tms_config.yaml b/spd/experiments/tms/tms_config.yaml index e52d2f9..5396465 100644 --- a/spd/experiments/tms/tms_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -34,10 +34,11 @@ m: 100 param_match_coeff: 1.0 masked_recon_coeff: 1.0 pnorm: 0.9 -lp_sparsity_coeff: 3e-2 +lp_sparsity_coeff: 1.5e-4 random_mask_recon_coeff: 1 n_random_masks: 1 -n_gate_hidden_neurons: 8 +# n_gate_hidden_neurons: 8 +n_gate_hidden_neurons: null batch_size: 2048 steps: 30_000 image_freq: 5_000 diff --git a/spd/models/components.py b/spd/models/components.py index f162a1f..171baf1 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -110,14 +110,13 @@ def __init__( d_in: int, d_out: int, n_instances: int | None = None, - init_scale: float = 1.0, ): super().__init__() shape = (n_instances, d_in, d_out) if n_instances is not None else (d_in, d_out) self.weight = nn.Parameter(torch.empty(shape)) # Note: init assumes no relu/gelu after this layer (which won't be the case for mlp_in, but # sqrt(2) ~= 1 so we're ignoring this for now.) - init_param_(self.weight, fan_val=d_in, nonlinearity="linear", scale=init_scale) + init_param_(self.weight, fan_val=d_in, nonlinearity="linear") self.hook_pre = HookPoint() # (batch ... d_in) self.hook_post = HookPoint() # (batch ... d_out) @@ -143,7 +142,6 @@ def __init__( d_out: int, m: int, n_instances: int | None = None, - init_scale: float = 1.0, ): super().__init__() self.n_instances = n_instances @@ -158,8 +156,8 @@ def __init__( self.hook_component_acts = HookPoint() # (batch m) or (batch n_instances m) self.hook_post = HookPoint() # (batch d_out) or (batch n_instances d_out) - init_param_(self.A, fan_val=d_in, nonlinearity="linear", scale=init_scale) - init_param_(self.B, fan_val=m, nonlinearity="linear", scale=init_scale) + init_param_(self.A, fan_val=d_in, nonlinearity="linear") + init_param_(self.B, fan_val=m, nonlinearity="linear") @property def weight(self) -> Float[Tensor, "... d_in d_out"]: diff --git a/spd/module_utils.py b/spd/module_utils.py index 7f5cc5b..2fe0666 100644 --- a/spd/module_utils.py +++ b/spd/module_utils.py @@ -93,7 +93,6 @@ def init_param_( fan_val: float, mean: float = 0.0, nonlinearity: str = "linear", - scale: float = 1.0, generator: torch.Generator | None = None, ) -> None: """Fill in param with values sampled from a Kaiming normal distribution. @@ -103,10 +102,9 @@ def init_param_( fan_val: The squared denominator of the std used for the kaiming normal distribution mean: The mean of the normal distribution nonlinearity: The nonlinearity of the activation function - scale: Scale the standard deviation by this amount generator: The generator to sample from """ gain = calculate_gain(nonlinearity) std = gain / math.sqrt(fan_val) with torch.no_grad(): - param.normal_(mean, std * scale, generator=generator) + param.normal_(mean, std, generator=generator) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index 0708e1d..0b56407 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -22,7 +22,6 @@ RESID_MLP_TASK_CONFIG = ResidualMLPTaskConfig( task_name="residual_mlp", feature_probability=0.333, - init_scale=1.0, data_generation_type="at_least_zero_active", pretrained_model_path=Path(), # We'll create this later ) @@ -227,7 +226,6 @@ def test_init_resid_mlp_spd_model_from_target() -> None: apply_output_act_fn=False, in_bias=True, out_bias=True, - init_scale=1.0, ) target_model = ResidualMLPModel(config=resid_mlp_config).to(device) From a1d40c46cc28a0742dab2ba3afb4c305ed6a7a64 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 7 Mar 2025 14:12:14 +0000 Subject: [PATCH 45/73] Fix init_scale deprecation --- spd/experiments/resid_mlp/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index 469b795..0ea6ed7 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -225,7 +225,7 @@ def from_pretrained( with open(paths.resid_mlp_train_config) as f: resid_mlp_train_config_dict = yaml.safe_load(f) - resid_mlp_train_config_dict.pop("init_scale", None) # Deprecated + resid_mlp_train_config_dict["resid_mlp_config"].pop("init_scale", None) # Deprecated with open(paths.label_coeffs) as f: label_coeffs = torch.tensor(json.load(f)) From c71ace6bf2cd8e09a16066777d3d34638f56796e Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 7 Mar 2025 15:55:18 +0000 Subject: [PATCH 46/73] Init A and B based on norm of target weights --- spd/run_spd.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/spd/run_spd.py b/spd/run_spd.py index eabc811..dbe9618 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -256,6 +256,31 @@ def calc_masked_target_component_acts( return masked_target_component_acts +def init_As_and_Bs_(model: SPDModel, target_model: HookedRootModule) -> None: + """Initialize the A and B matrices using a scale factor from the target weights.""" + As = collect_nested_module_attrs(model, attr_name="A", include_attr_name=False) + Bs = collect_nested_module_attrs(model, attr_name="B", include_attr_name=False) + for param_name in As: + A = As[param_name] # (..., d_in, m) + B = Bs[param_name] # (..., m, d_out) + target_weight = get_nested_module_attr( + target_model, param_name + ".weight" + ) # (..., d_in, d_out) + + # Make A and B have unit norm in the d_in and d_out dimensions + A.data[:] = torch.randn_like(A.data) + B.data[:] = torch.randn_like(B.data) + A.data[:] = A.data / A.data.norm(dim=-2, keepdim=True) + B.data[:] = B.data / B.data.norm(dim=-1, keepdim=True) + + m_norms = einops.einsum( + A, B, target_weight, "... d_in m, ... m d_out, ... d_in d_out -> ... m" + ) + # Scale B by m_norms. We leave A as is since this may get scaled with the unit_norm_matrices + # config options. + B.data[:] = B.data * m_norms.unsqueeze(-1) + + def optimize( model: SPDModel, config: Config, @@ -269,6 +294,8 @@ def optimize( model.to(device=device) target_model.to(device=device) + init_As_and_Bs_(model=model, target_model=target_model) + has_instance_dim = hasattr(model, "n_instances") # We used "-" instead of "." as module names can't have "." in them From 3898599343c50f699d82577917a664f2b998402d Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 7 Mar 2025 16:15:12 +0000 Subject: [PATCH 47/73] Set Gate biases to 0 --- spd/experiments/tms/tms_config.yaml | 2 +- spd/models/components.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/spd/experiments/tms/tms_config.yaml b/spd/experiments/tms/tms_config.yaml index 5396465..f575ee3 100644 --- a/spd/experiments/tms/tms_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -34,7 +34,7 @@ m: 100 param_match_coeff: 1.0 masked_recon_coeff: 1.0 pnorm: 0.9 -lp_sparsity_coeff: 1.5e-4 +lp_sparsity_coeff: 1e-4 random_mask_recon_coeff: 1 n_random_masks: 1 # n_gate_hidden_neurons: 8 diff --git a/spd/models/components.py b/spd/models/components.py index 171baf1..6c76cea 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -22,10 +22,11 @@ def __init__(self, m: int, n_instances: int | None = None): self.n_instances = n_instances shape = (n_instances, m) if n_instances is not None else (m,) self.weight = nn.Parameter(torch.empty(shape)) - self.bias = nn.Parameter(torch.empty(shape)) + # self.bias = nn.Parameter(torch.empty(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) fan_val = 1 # Since each weight gets applied independently init_param_(self.weight, fan_val=fan_val, nonlinearity="linear") - init_param_(self.bias, fan_val=fan_val, nonlinearity="linear") + # init_param_(self.bias, fan_val=fan_val, nonlinearity="linear") def forward( self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] From 5afdc92f3c4968918f72fa253bad83bc73deec60 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 17 Mar 2025 16:00:03 +0000 Subject: [PATCH 48/73] Load env vars when running sweeps too --- spd/wandb_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spd/wandb_utils.py b/spd/wandb_utils.py index 242ee8b..73d2d67 100644 --- a/spd/wandb_utils.py +++ b/spd/wandb_utils.py @@ -101,12 +101,13 @@ def init_wandb( Returns: Config updated with sweep hyperparameters (if any). """ + load_dotenv(override=True) + if sweep_config_path is not None: with open(sweep_config_path) as f: sweep_data = yaml.safe_load(f) - wandb.init(config=sweep_data, save_code=True, name=name) + wandb.init(config=sweep_data, entity=os.getenv("WANDB_ENTITY"), save_code=True, name=name) else: - load_dotenv(override=True) wandb.init(project=project, entity=os.getenv("WANDB_ENTITY"), save_code=True, name=name) # Update the config with the hyperparameters for this sweep (if any) From e80f874cfa953b9fab96de271c864a99754e2346 Mon Sep 17 00:00:00 2001 From: Dan <150014290+danbraunai-apollo@users.noreply.github.com> Date: Mon, 24 Mar 2025 15:27:03 +0000 Subject: [PATCH 49/73] Add layerwise recon (#263) * Add layerwise recon * Add layerwise_random_recon_loss * Protect the eyes of mathematicians --- spd/configs.py | 8 +- .../resid_mlp/resid_mlp_config.yaml | 4 +- spd/experiments/tms/tms_config.yaml | 5 +- spd/run_spd.py | 87 ++++++++++++++++++- 4 files changed, 98 insertions(+), 6 deletions(-) diff --git a/spd/configs.py b/spd/configs.py index 50fada3..1622900 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -54,6 +54,8 @@ class Config(BaseModel): param_match_coeff: NonNegativeFloat | None = 1.0 masked_recon_coeff: NonNegativeFloat | None = None random_mask_recon_coeff: NonNegativeFloat | None = None + layerwise_recon_coeff: NonNegativeFloat | None = None + layerwise_random_recon_coeff: NonNegativeFloat | None = None lp_sparsity_coeff: NonNegativeFloat pnorm: PositiveFloat m: PositiveInt @@ -113,8 +115,8 @@ def validate_model(self) -> Self: # Check that lr_exponential_halflife is not None if lr_schedule is "exponential" if self.lr_schedule == "exponential": - assert ( - self.lr_exponential_halflife is not None - ), "lr_exponential_halflife must be set if lr_schedule is exponential" + assert self.lr_exponential_halflife is not None, ( + "lr_exponential_halflife must be set if lr_schedule is exponential" + ) return self diff --git a/spd/experiments/resid_mlp/resid_mlp_config.yaml b/spd/experiments/resid_mlp/resid_mlp_config.yaml index ba09f3c..4c43b7b 100644 --- a/spd/experiments/resid_mlp/resid_mlp_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp_config.yaml @@ -7,11 +7,13 @@ seed: 0 m: 100 param_match_coeff: 1.0 masked_recon_coeff: 1.0 -act_recon_coeff: 1 +# act_recon_coeff: 1 random_mask_recon_coeff: 1.0 n_random_masks: 1 n_gate_hidden_neurons: null # n_gate_hidden_neurons: 8 +layerwise_recon_coeff: 1.0 +layerwise_random_recon_coeff: 1.0 pnorm: 0.9 lp_sparsity_coeff: 3e-2 batch_size: 256 diff --git a/spd/experiments/tms/tms_config.yaml b/spd/experiments/tms/tms_config.yaml index f575ee3..c741eb2 100644 --- a/spd/experiments/tms/tms_config.yaml +++ b/spd/experiments/tms/tms_config.yaml @@ -39,6 +39,8 @@ random_mask_recon_coeff: 1 n_random_masks: 1 # n_gate_hidden_neurons: 8 n_gate_hidden_neurons: null +layerwise_recon_coeff: 1.0 +layerwise_random_recon_coeff: 1.0 batch_size: 2048 steps: 30_000 image_freq: 5_000 @@ -52,4 +54,5 @@ task_config: task_name: tms feature_probability: 0.05 data_generation_type: "at_least_zero_active" - pretrained_model_path: "wandb:spd-train-tms/runs/tmzweoqk" \ No newline at end of file + # pretrained_model_path: "wandb:spd-train-tms/runs/tmzweoqk" + pretrained_model_path: "wandb:spd-train-tms/runs/me2x5oeo" # 1 hidden layer fixed to identity \ No newline at end of file diff --git a/spd/run_spd.py b/spd/run_spd.py index dbe9618..00a2a7e 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -1,11 +1,13 @@ """Run SPD on a model.""" from collections.abc import Callable +from functools import partial from pathlib import Path import einops import matplotlib.pyplot as plt import torch +import torch.nn as nn import wandb from jaxtyping import Float from torch import Tensor @@ -15,7 +17,7 @@ from spd.configs import Config from spd.hooks import HookedRootModule from spd.models.base import SPDModel -from spd.models.components import Gate +from spd.models.components import Gate, Linear, LinearComponent from spd.module_utils import collect_nested_module_attrs, get_nested_module_attr from spd.utils import calc_recon_mse, get_lr_schedule_fn, get_lr_with_warmup @@ -256,6 +258,56 @@ def calc_masked_target_component_acts( return masked_target_component_acts +def calc_layerwise_recon_loss( + param_names: list[str], + target_model: HookedRootModule, + spd_model: SPDModel, + batch: Float[Tensor, "batch n_instances d_in"] | Float[Tensor, "batch d_in"], + device: str, + masks: list[dict[str, Float[Tensor, "batch n_instances m"] | Float[Tensor, "batch m"]]], + target_out: Float[Tensor, "batch n_instances d_out"] | Float[Tensor, "batch d_out"], + has_instance_dim: bool, +) -> Float[Tensor, ""]: + """Calculate the layerwise activation reconstruction loss using regular PyTorch hooks. + + Note that we support multiple masks for the case of calculating this loss over a list of random + masks. + """ + total_loss = torch.tensor(0.0, device=device) + + for mask in masks: + for param_name in param_names: + target_module = get_nested_module_attr(target_model, param_name) + assert isinstance(target_module, Linear) + + component_module = get_nested_module_attr(spd_model, param_name) + assert isinstance(component_module, LinearComponent) + + def hook( + module: nn.Module, + input: tuple[ + Float[Tensor, "batch n_instances d_in"] | Float[Tensor, "batch d_in"], ... + ], + output: Float[Tensor, "batch n_instances d_out"] | Float[Tensor, "batch d_out"], + param_name: str, + mask: dict[str, Float[Tensor, "batch n_instances m"] | Float[Tensor, "batch m"]], + component_module: LinearComponent, + ) -> Float[Tensor, "batch n_instances d_out"] | Float[Tensor, "batch d_out"]: + linear_output = component_module(input[0], mask=mask[param_name]) + return linear_output + + handle = target_module.register_forward_hook( + partial(hook, param_name=param_name, mask=mask, component_module=component_module) + ) + modified_output = target_model(batch) + handle.remove() + + mse_loss = calc_recon_mse(modified_output, target_out, has_instance_dim) + total_loss = total_loss + mse_loss + + return total_loss / (len(param_names) * len(masks)) + + def init_As_and_Bs_(model: SPDModel, target_model: HookedRootModule) -> None: """Initialize the A and B matrices using a scale factor from the target weights.""" As = collect_nested_module_attrs(model, attr_name="A", include_attr_name=False) @@ -380,6 +432,7 @@ def optimize( batch, names_filter=spd_cache_filter, masks=masks ) + random_masks = None random_masks_loss = None if config.random_mask_recon_coeff is not None: random_masks = calc_random_masks(masks=masks, n_random_masks=config.n_random_masks) @@ -416,6 +469,33 @@ def optimize( masked_spd_component_acts, masked_target_component_acts ) + layerwise_recon_loss = None + if config.layerwise_recon_coeff is not None: + layerwise_recon_loss = calc_layerwise_recon_loss( + param_names=param_names, + target_model=target_model, + spd_model=model, + batch=batch, + device=device, + masks=[masks], + target_out=target_out, + has_instance_dim=has_instance_dim, + ) + + layerwise_random_recon_loss = None + if config.layerwise_random_recon_coeff is not None: + assert random_masks is not None + layerwise_random_recon_loss = calc_layerwise_recon_loss( + param_names=param_names, + target_model=target_model, + spd_model=model, + batch=batch, + device=device, + masks=random_masks, + target_out=target_out, + has_instance_dim=has_instance_dim, + ) + loss_terms = { "param_match_loss": (param_match_loss, config.param_match_coeff), "out_recon_loss": (out_recon_loss, config.out_recon_coeff), @@ -423,6 +503,11 @@ def optimize( "masked_recon_loss": (masked_recon_loss, config.masked_recon_coeff), "act_recon_loss": (act_recon_loss, config.act_recon_coeff), "random_masks_loss": (random_masks_loss, config.random_mask_recon_coeff), + "layerwise_recon_loss": (layerwise_recon_loss, config.layerwise_recon_coeff), + "layerwise_random_recon_loss": ( + layerwise_random_recon_loss, + config.layerwise_random_recon_coeff, + ), } # Add up the loss terms loss = torch.tensor(0.0, device=device) From 16992e5864e90f59919c13f7f0db2aa949969d30 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 24 Mar 2025 15:27:27 +0000 Subject: [PATCH 50/73] Remove transformer-lens dependency --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 72b99ed..635ffaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ dependencies = [ "pytest", "ipykernel", "transformers", - "transformer-lens", "matplotlib==3.9.1", # Avoid frequent pyright errors with new matplotlib versions "numpy", "python-dotenv", From 7f6a94b944bfb5f76d8f699fcb7e49da1119df24 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 24 Mar 2025 15:39:06 +0000 Subject: [PATCH 51/73] Use new random masks for layerwise_random_masks --- spd/run_spd.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index 00a2a7e..1b2e13c 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -432,7 +432,6 @@ def optimize( batch, names_filter=spd_cache_filter, masks=masks ) - random_masks = None random_masks_loss = None if config.random_mask_recon_coeff is not None: random_masks = calc_random_masks(masks=masks, n_random_masks=config.n_random_masks) @@ -484,14 +483,16 @@ def optimize( layerwise_random_recon_loss = None if config.layerwise_random_recon_coeff is not None: - assert random_masks is not None + layerwise_random_masks = calc_random_masks( + masks=masks, n_random_masks=config.n_random_masks + ) layerwise_random_recon_loss = calc_layerwise_recon_loss( param_names=param_names, target_model=target_model, spd_model=model, batch=batch, device=device, - masks=random_masks, + masks=layerwise_random_masks, target_out=target_out, has_instance_dim=has_instance_dim, ) From 5c632f9124a3249f1aa908c84ee7987f35e03733 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 24 Mar 2025 15:43:33 +0000 Subject: [PATCH 52/73] Add jaxtyping to dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 635ffaf..1f75d61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "pytest", "ipykernel", "transformers", + "jaxtyping", "matplotlib==3.9.1", # Avoid frequent pyright errors with new matplotlib versions "numpy", "python-dotenv", From 5981df652ffaa005ecae3387a1754bbb4382a934 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 24 Mar 2025 15:58:58 +0000 Subject: [PATCH 53/73] Add einops dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 1f75d61..49e6f8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "ipykernel", "transformers", "jaxtyping", + "einops", "matplotlib==3.9.1", # Avoid frequent pyright errors with new matplotlib versions "numpy", "python-dotenv", From fcff30494caa947119422c3edc1fd2e8fb9867b7 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 24 Mar 2025 16:46:00 +0000 Subject: [PATCH 54/73] Use calc_recon_mse in calc_random_masks_mse_loss for consistency --- spd/run_spd.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index 1b2e13c..7c4646a 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -204,15 +204,15 @@ def calc_random_masks_mse_loss( batch: Float[Tensor, "batch n_instances d_in"], random_masks: list[dict[str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"]]], out_masked: Float[Tensor, "batch n_instances d_out"], + has_instance_dim: bool, ) -> Float[Tensor, ""] | Float[Tensor, " n_instances"]: """Calculate the MSE over all random masks.""" - loss = torch.zeros(1, device=out_masked.device) + loss = torch.tensor(0.0, device=out_masked.device) for i in range(len(random_masks)): out_masked_random_mask = model(batch, masks=random_masks[i]) - loss = loss + ((out_masked - out_masked_random_mask) ** 2).mean(dim=-1) + loss = loss + calc_recon_mse(out_masked, out_masked_random_mask, has_instance_dim) - # Normalize by the number of random masks and mean over the batch dim - return (loss / len(random_masks)).mean(dim=0) + return loss / len(random_masks) def calc_component_acts( @@ -436,7 +436,11 @@ def optimize( if config.random_mask_recon_coeff is not None: random_masks = calc_random_masks(masks=masks, n_random_masks=config.n_random_masks) random_masks_loss = calc_random_masks_mse_loss( - model=model, batch=batch, random_masks=random_masks, out_masked=target_out + model=model, + batch=batch, + random_masks=random_masks, + out_masked=target_out, + has_instance_dim=has_instance_dim, ) # Calculate losses From 7ac2a42c088b76719b6018f5682dffbdaefaa6c8 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 25 Mar 2025 11:28:00 +0000 Subject: [PATCH 55/73] Set bias to zero in GateMLP mlp_out --- spd/models/components.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/spd/models/components.py b/spd/models/components.py index 6c76cea..bbe3dba 100644 --- a/spd/models/components.py +++ b/spd/models/components.py @@ -22,11 +22,9 @@ def __init__(self, m: int, n_instances: int | None = None): self.n_instances = n_instances shape = (n_instances, m) if n_instances is not None else (m,) self.weight = nn.Parameter(torch.empty(shape)) - # self.bias = nn.Parameter(torch.empty(shape)) self.bias = nn.Parameter(torch.zeros(shape)) fan_val = 1 # Since each weight gets applied independently init_param_(self.weight, fan_val=fan_val, nonlinearity="linear") - # init_param_(self.bias, fan_val=fan_val, nonlinearity="linear") def forward( self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] @@ -63,12 +61,11 @@ def __init__(self, m: int, n_gate_hidden_neurons: int, n_instances: int | None = self.mlp_in = nn.Parameter(torch.empty(shape)) self.in_bias = nn.Parameter(torch.empty(in_bias_shape)) self.mlp_out = nn.Parameter(torch.empty(shape)) - self.out_bias = nn.Parameter(torch.empty(out_bias_shape)) + self.out_bias = nn.Parameter(torch.zeros(out_bias_shape)) init_param_(self.mlp_in, fan_val=1, nonlinearity="relu") init_param_(self.in_bias, fan_val=1, nonlinearity="relu") init_param_(self.mlp_out, fan_val=n_gate_hidden_neurons, nonlinearity="linear") - init_param_(self.out_bias, fan_val=n_gate_hidden_neurons, nonlinearity="linear") def _compute_pre_activation( self, x: Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] From 037caf1b786cc834b1c2fd1f111532301b75c10b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 1 Apr 2025 05:50:24 +0000 Subject: [PATCH 56/73] WIP: Swap components with Llama nn.Linear modules --- spd/experiments/lm/lm_decomposition.py | 54 +++++++++++++++++ spd/experiments/lm/models.py | 83 ++++++++++++++++++++++++++ spd/module_utils.py | 15 +++++ tests/test_module_utils.py | 28 +++++++++ 4 files changed, 180 insertions(+) create mode 100644 spd/experiments/lm/lm_decomposition.py create mode 100644 spd/experiments/lm/models.py create mode 100644 tests/test_module_utils.py diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py new file mode 100644 index 0000000..bba9bed --- /dev/null +++ b/spd/experiments/lm/lm_decomposition.py @@ -0,0 +1,54 @@ +# %% +from simple_stories_train.models.llama import Llama +from simple_stories_train.models.model_configs import MODEL_CONFIGS +from transformers import AutoTokenizer + +from spd.experiments.lm.models import SSModel, create_gate_proj_components + +# %% +# Select the model size you want to use +model_size = "1.25M" # Options: "35M", "30M", "11M", "5M", "1.25M" + +# Load model configuration +model_config = MODEL_CONFIGS[model_size] + +# Load appropriate model +model_path = f"chandan-sreedhara/SimpleStories-{model_size}" +model = Llama.from_pretrained(model_path, model_config) +# model.to("cuda") +model.eval() +# %% + +ss_model = SSModel(model) + + +# Create components with rank=10 (adjust as needed) +gate_proj_components = create_gate_proj_components(model, rank=17) + +# %% +# Load tokenizer +tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False) + +# Define your prompt +prompt = "The curious cat looked at the" + +# IMPORTANT: Use tokenizer without special tokens +inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) +# input_ids = inputs.input_ids.to("cuda") +input_ids = inputs.input_ids +# Targets should be the inputs shifted by one (we will later ignore the last input token) +targets = input_ids[:, 1:] +input_ids = input_ids[:, :-1] + +# IMPORTANT: Set correct EOS token ID (not the default from tokenizer) +eos_token_id = 1 + +# %% + +# logits, _ = ss_model.forward(input_ids, components=gate_proj_components) +logits, _ = ss_model.model(input_ids, targets=targets) +print("inputs_shape", input_ids.shape) +print("logits", logits) +print("logits shape", logits.shape) + +# %% diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py new file mode 100644 index 0000000..c3d723b --- /dev/null +++ b/spd/experiments/lm/models.py @@ -0,0 +1,83 @@ +""" +Defines a SSModel class that is a wrapper around a llama model from SimpleStories +""" + +from typing import Any + +import torch.nn as nn +from simple_stories_train.models.llama import Llama +from torch import Tensor + +from spd.models.components import LinearComponent +from spd.module_utils import get_nested_module_attr, set_nested_module_attr + + +class LinearComponentWithBias(nn.Module): + """A LinearComponent with a bias parameter.""" + + def __init__(self, linear_component: LinearComponent, bias: Tensor | None): + super().__init__() + self.linear_component = linear_component + self.bias = bias + + def forward(self, x: Tensor) -> Tensor: + out = self.linear_component(x) + if self.bias is not None: + out += self.bias + return out + + +def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponentWithBias: + """Replace a nn.Linear module with a LinearComponentWithBias module.""" + d_in, d_out = linear_module.weight.shape + + linear_component = LinearComponent(d_in=d_in, d_out=d_out, m=m, n_instances=None) + + # # Initialize with A = W (original weights) and B = I (identity) + # # This provides a starting point where the component exactly equals the original + # linear_component.A.data[:] = linear_module.weight.t() # (d_in, m) + # linear_component.B.data[:] = torch.eye(m) + + bias = linear_module.bias.clone() if linear_module.bias is not None else None # type: ignore + + return LinearComponentWithBias(linear_component, bias) + + +# Create LinearComponentWithBias objects for gate_proj in each layer +def create_gate_proj_components(model: Llama, rank: int) -> dict[str, LinearComponentWithBias]: + components = {} + for i in range(len(model.transformer.h)): + gate_proj = model.transformer.h[i].mlp.gate_proj + module_name = f"model.transformer.h.{i}.mlp.gate_proj" + components[module_name] = nn_linear_to_components(gate_proj, m=rank) + return components + + +class SSModel(nn.Module): + """Wrapper around a llama model from SimpleStories for running SPD.""" + + def __init__(self, llama_model: Llama): + super().__init__() + self.model = llama_model + + def forward( + self, + *args: Any, + components: dict[str, LinearComponentWithBias] | None = None, + **kwargs: Any, + ) -> Any: + if components is None: + return self.model(*args, **kwargs) + + old_components = {} + for module_name, component in components.items(): + old_component = get_nested_module_attr(self, module_name) + assert old_component is not None + old_components[module_name] = old_component + set_nested_module_attr(self, module_name, component) + + out = self.model(*args, **kwargs) + + for module_name, component in old_components.items(): + set_nested_module_attr(self, module_name, component) + return out diff --git a/spd/module_utils.py b/spd/module_utils.py index 2fe0666..394c7ef 100644 --- a/spd/module_utils.py +++ b/spd/module_utils.py @@ -28,6 +28,21 @@ def get_nested_module_attr(module: nn.Module, access_string: str) -> Any: return mod +def set_nested_module_attr(module: nn.Module, access_string: str, value: Any) -> None: + """Set a specific attribute by its full, path-like name. + + Args: + module: The module to set the attribute on. + access_string: The full name of the nested attribute to set, with each object separated by periods (e.g. "linear1.A"). + """ + names = access_string.split(".") + try: + mod = reduce(getattr, names[:-1], module) + except AttributeError as err: + raise AttributeError(f"{module} does not have nested attribute {access_string}") from err + setattr(mod, names[-1], value) + + def collect_nested_module_attrs( module: nn.Module, attr_name: str, diff --git a/tests/test_module_utils.py b/tests/test_module_utils.py new file mode 100644 index 0000000..cc2711c --- /dev/null +++ b/tests/test_module_utils.py @@ -0,0 +1,28 @@ +import torch +from torch import nn + +from spd.module_utils import get_nested_module_attr, set_nested_module_attr + + +def test_get_nested_module_attr(): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + module = TestModule() + assert get_nested_module_attr(module, "linear1.weight.data").shape == (10, 10) + assert get_nested_module_attr(module, "linear2.weight.data").shape == (10, 10) + + +def test_set_nested_module_attr(): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + module = TestModule() + set_nested_module_attr(module, "linear1.weight.data", torch.randn(10, 5)) + assert module.linear1.weight.data.shape == (10, 5) From 1a1dcaf0f67e6c717f6fe330fc705dfa7986ecce Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 3 Apr 2025 04:13:32 +0000 Subject: [PATCH 57/73] Fix nn.Linear shape and handle masked components --- spd/experiments/lm/lm_decomposition.py | 24 ++++++++++++-- spd/experiments/lm/models.py | 43 +++++++++++++++++++------- 2 files changed, 52 insertions(+), 15 deletions(-) diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index bba9bed..be84df4 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -1,4 +1,5 @@ # %% +import torch from simple_stories_train.models.llama import Llama from simple_stories_train.models.model_configs import MODEL_CONFIGS from transformers import AutoTokenizer @@ -21,9 +22,9 @@ ss_model = SSModel(model) - +m = 17 # Create components with rank=10 (adjust as needed) -gate_proj_components = create_gate_proj_components(model, rank=17) +gate_proj_components = create_gate_proj_components(model, rank=m) # %% # Load tokenizer @@ -46,9 +47,26 @@ # %% # logits, _ = ss_model.forward(input_ids, components=gate_proj_components) -logits, _ = ss_model.model(input_ids, targets=targets) +logits, _ = ss_model.forward(input_ids) print("inputs_shape", input_ids.shape) print("logits", logits) print("logits shape", logits.shape) +logits, _ = ss_model.forward_with_components(input_ids, components=gate_proj_components) + +print("Component logits shape", logits.shape) +print("Component logits", logits) + +# Create some dummy masks +masks = { + f"model.transformer.h.{i}.mlp.gate_proj": torch.randn(1, input_ids.shape[-1], m) + for i in range(len(model.transformer.h)) +} + +logits, _ = ss_model.forward_with_components( + input_ids, components=gate_proj_components, masks=masks +) + +print("Masked component logits shape", logits.shape) +print("Masked component logits", logits) # %% diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py index c3d723b..57426ba 100644 --- a/spd/experiments/lm/models.py +++ b/spd/experiments/lm/models.py @@ -5,6 +5,7 @@ from typing import Any import torch.nn as nn +from jaxtyping import Float from simple_stories_train.models.llama import Llama from torch import Tensor @@ -19,9 +20,10 @@ def __init__(self, linear_component: LinearComponent, bias: Tensor | None): super().__init__() self.linear_component = linear_component self.bias = bias + self.mask: Float[Tensor, "batch pos m"] | None = None # Gets set on sparse forward passes - def forward(self, x: Tensor) -> Tensor: - out = self.linear_component(x) + def forward(self, x: Float[Tensor, "batch d_in"]) -> Float[Tensor, "batch d_out"]: + out = self.linear_component(x, mask=self.mask) if self.bias is not None: out += self.bias return out @@ -29,7 +31,7 @@ def forward(self, x: Tensor) -> Tensor: def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponentWithBias: """Replace a nn.Linear module with a LinearComponentWithBias module.""" - d_in, d_out = linear_module.weight.shape + d_out, d_in = linear_module.weight.shape linear_component = LinearComponent(d_in=d_in, d_out=d_out, m=m, n_instances=None) @@ -63,21 +65,38 @@ def __init__(self, llama_model: Llama): def forward( self, *args: Any, - components: dict[str, LinearComponentWithBias] | None = None, **kwargs: Any, ) -> Any: - if components is None: - return self.model(*args, **kwargs) + """Regular forward pass of the (target) model.""" + return self.model(*args, **kwargs) - old_components = {} + def forward_with_components( + self, + *args: Any, + components: dict[str, LinearComponentWithBias], + masks: dict[str, Float[Tensor, "batch pos m"]] | None = None, + **kwargs: Any, + ) -> Any: + """Forward pass with temporary component replacement.""" + old_modules = {} for module_name, component in components.items(): - old_component = get_nested_module_attr(self, module_name) - assert old_component is not None - old_components[module_name] = old_component + old_module = get_nested_module_attr(self, module_name) + assert old_module is not None + old_modules[module_name] = old_module + + if masks is not None: + assert module_name in masks, f"Mask for {module_name} not found" + component.mask = masks[module_name] set_nested_module_attr(self, module_name, component) out = self.model(*args, **kwargs) - for module_name, component in old_components.items(): - set_nested_module_attr(self, module_name, component) + # Restore the original modules + for module_name, old_module in old_modules.items(): + set_nested_module_attr(self, module_name, old_module) + + # Remove the masks attribute from the components + for component in components.values(): + component.mask = None + return out From 993da44a7efeaaa5dc61ef833b312fbde4a26fa0 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 3 Apr 2025 06:36:36 +0000 Subject: [PATCH 58/73] WIP: Add lm_decomposition script --- .vscode/launch.json | 12 + spd/configs.py | 16 +- spd/experiments/lm/lm_config.yaml | 56 +++ spd/experiments/lm/lm_decomposition.py | 465 ++++++++++++++++++++++--- spd/experiments/lm/models.py | 28 +- spd/experiments/lm/play.py | 75 ++++ 6 files changed, 591 insertions(+), 61 deletions(-) create mode 100644 spd/experiments/lm/lm_config.yaml create mode 100644 spd/experiments/lm/play.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 5ce7aef..89875d2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -37,5 +37,17 @@ "PYDEVD_DISABLE_FILE_VALIDATION": "1" } }, + { + "name": "lm", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/spd/experiments/lm/lm_decomposition.py", + "args": "${workspaceFolder}/spd/experiments/lm/lm_config.yaml", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYDEVD_DISABLE_FILE_VALIDATION": "1" + } + } ] } \ No newline at end of file diff --git a/spd/configs.py b/spd/configs.py index 1622900..773223c 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -36,6 +36,18 @@ class ResidualMLPTaskConfig(BaseModel): pretrained_model_path: ModelPath # e.g. wandb:spd-resid-mlp/runs/j9kmavzi +class LMTaskConfig(BaseModel): + model_config = ConfigDict(extra="forbid", frozen=True) + task_name: Literal["lm"] = "lm" + model_size: str # e.g. "1.25M" + max_seq_len: PositiveInt = 512 + buffer_size: PositiveInt = 1000 + dataset_name: str = "lennart-finke/SimpleStories" + dataset_split: str = "train" + # List of fnmatch patterns for nn.Linear modules to decompose + target_module_patterns: list[str] = ["transformer.h.*.mlp.*_proj"] + + class Config(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) wandb_project: str | None = None @@ -68,7 +80,9 @@ class Config(BaseModel): unit_norm_matrices: bool = False attribution_type: Literal["gradient"] = "gradient" n_gate_hidden_neurons: PositiveInt | None = None - task_config: TMSTaskConfig | ResidualMLPTaskConfig = Field(..., discriminator="task_name") + task_config: TMSTaskConfig | ResidualMLPTaskConfig | LMTaskConfig = Field( + ..., discriminator="task_name" + ) DEPRECATED_CONFIG_KEYS: ClassVar[list[str]] = [] RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = {} diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml new file mode 100644 index 0000000..fe02f4b --- /dev/null +++ b/spd/experiments/lm/lm_config.yaml @@ -0,0 +1,56 @@ +# --- WandB --- +# wandb_project: spd-lm # Project name for Weights & Biases +wandb_project: null # Project name for Weights & Biases +wandb_run_name: null # Set specific run name (optional, otherwise generated) +wandb_run_name_prefix: "" # Prefix for generated run name + +# --- General --- +seed: 0 +unit_norm_matrices: false # Whether to enforce unit norm on A matrices (not typically used here) +m: 100 # Rank of the decomposition / number of components per layer + +# --- Loss Coefficients --- +# Set coeffs to null if the loss shouldn't be computed +param_match_coeff: null # Not applicable for component-only optimization +out_recon_coeff: 1.0 # Reconstruction loss based on output logits (MSE) +lp_sparsity_coeff: 1e-2 # Coefficient for Lp sparsity loss (applied to component params A & B) +pnorm: 1.0 # p-value for the Lp sparsity norm (1.0 for L1) + +# Placeholder losses (set coeffs to null as they require mask calculation implementation) +masked_recon_coeff: null # Reconstruction loss using masks +act_recon_coeff: null # Reconstruction loss on intermediate component activations +random_mask_recon_coeff: null # Reconstruction loss averaged over random masks +layerwise_recon_coeff: null # Layer-wise reconstruction loss +layerwise_random_recon_coeff: null # Layer-wise reconstruction loss with random masks + +n_random_masks: 1 # Number of random masks if random_mask_recon_coeff is used +n_gate_hidden_neurons: null # Not applicable as there are no gates currently + +# --- Training --- +batch_size: 2 # Adjust based on GPU memory +steps: 10_000 # Total training steps +lr: 1e-4 # Learning rate +lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential) +lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup +lr_exponential_halflife: null # Required if lr_schedule is exponential +init_from_target_model: false # Not implemented/applicable for this setup + +# --- Logging & Saving --- +image_freq: 1000 # Frequency for generating/logging plots +print_freq: 100 # Frequency for printing logs to console +save_freq: 2000 # Frequency for saving checkpoints +image_on_first_step: true # Whether to log plots at step 0 + +# --- Task Specific --- +task_config: + task_name: lm # Specifies the LM decomposition task + model_size: "1.25M" # SimpleStories model size (e.g., "1.25M", "5M", "11M", "30M", "35M") + max_seq_len: 512 # Maximum sequence length for truncation/padding + buffer_size: 1000 # Buffer size for streaming dataset shuffling + dataset_name: "lennart-finke/SimpleStories" # HuggingFace dataset name + dataset_split: "train" # Dataset split to use + # List of fnmatch patterns for nn.Linear modules to decompose + target_module_patterns: ["transformer.h.*.mlp.gate_proj", "transformer.h.*.mlp.up_proj"] + # Example: Decompose only gate_proj: ["transformer.h.*.mlp.gate_proj"] + # Example: Decompose gate_proj and up_proj: ["transformer.h.*.mlp.gate_proj", "transformer.h.*.mlp.up_proj"] + # Example: Decompose all MLP layers: ["transformer.h.*.mlp.*_proj"] \ No newline at end of file diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index be84df4..01b3612 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -1,72 +1,433 @@ -# %% +"""Language Model decomposition script.""" + +from collections.abc import Callable +from datetime import datetime +from pathlib import Path + +import fire +import matplotlib.pyplot as plt import torch +import torch.optim as optim +import wandb +import yaml +from jaxtyping import Float +from simple_stories_train.dataloaders import DatasetConfig, create_data_loader from simple_stories_train.models.llama import Llama from simple_stories_train.models.model_configs import MODEL_CONFIGS -from transformers import AutoTokenizer +from torch import Tensor +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +from spd.configs import Config, LMTaskConfig +from spd.experiments.lm.models import ( + LinearComponentWithBias, + SSModel, + create_target_components, +) +from spd.log import logger +from spd.run_spd import get_common_run_name_suffix +from spd.utils import ( + get_device, + get_lr_schedule_fn, + get_lr_with_warmup, + load_config, + set_seed, +) +from spd.wandb_utils import init_wandb -from spd.experiments.lm.models import SSModel, create_gate_proj_components +# Define wandb_available at the module level +wadb_available = False +try: + import wandb -# %% -# Select the model size you want to use -model_size = "1.25M" # Options: "35M", "30M", "11M", "5M", "1.25M" + wandb.require("core") + wandb_available = True +except ImportError: + logger.warning("wandb not installed, skipping wandb related code.") -# Load model configuration -model_config = MODEL_CONFIGS[model_size] -# Load appropriate model -model_path = f"chandan-sreedhara/SimpleStories-{model_size}" -model = Llama.from_pretrained(model_path, model_config) -# model.to("cuda") -model.eval() -# %% +def get_run_name( + config: Config, + model_size: str, + max_seq_len: int, +) -> str: + """Generate a run name based on the config.""" + run_suffix = "" + if config.wandb_run_name: + run_suffix = config.wandb_run_name + else: + run_suffix = get_common_run_name_suffix(config) + run_suffix += f"_lm{model_size}_seq{max_seq_len}" + return config.wandb_run_name_prefix + run_suffix -ss_model = SSModel(model) -m = 17 -# Create components with rank=10 (adjust as needed) -gate_proj_components = create_gate_proj_components(model, rank=m) +def lm_plot_results_fn( + model: SSModel, + components: dict[str, LinearComponentWithBias], + step: int | None, + out_dir: Path | None, + device: str, + config: Config, + **_, +) -> dict[str, plt.Figure]: + """Plotting function for LM decomposition. Placeholder for now.""" + # TODO: Implement actual plotting (e.g., component matrix values?) + logger.info(f"Plotting results at step {step}...") + fig_dict: dict[str, plt.Figure] = {} + # Example: Potentially plot A/B matrix norms or sparsity patterns? + # fig_dict["component_norms"] = plot_component_norms(components, out_dir, step) + return fig_dict -# %% -# Load tokenizer -tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False) -# Define your prompt -prompt = "The curious cat looked at the" +def calc_recon_mse_lm( + out1: Float[Tensor, "batch seq vocab"], + out2: Float[Tensor, "batch seq vocab"], +) -> Float[Tensor, ""]: + """Calculate the Mean Squared Error reconstruction loss for LM logits.""" + assert out1.shape == out2.shape + # Mean over batch and sequence length, sum over vocab + return ((out1 - out2) ** 2).sum(dim=-1).mean() -# IMPORTANT: Use tokenizer without special tokens -inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) -# input_ids = inputs.input_ids.to("cuda") -input_ids = inputs.input_ids -# Targets should be the inputs shifted by one (we will later ignore the last input token) -targets = input_ids[:, 1:] -input_ids = input_ids[:, :-1] -# IMPORTANT: Set correct EOS token ID (not the default from tokenizer) -eos_token_id = 1 +def optimize_lm( + model: SSModel, + components: dict[str, LinearComponentWithBias], + config: Config, + device: str, + dataloader: DataLoader[tuple[Float[Tensor, "batch pos"], Float[Tensor, "batch pos"]]], + out_dir: Path, + plot_results_fn: Callable[..., dict[str, plt.Figure]], +) -> None: + """Run the optimization loop for LM decomposition.""" + # --- Optimizer --- # + component_params = [] + param_names_to_optimize = [] + for name, component in components.items(): + component_params.extend(list(component.parameters())) + param_names_to_optimize.extend( + [f"{name}.{p_name}" for p_name, _ in component.named_parameters()] + ) + logger.debug(f"Adding parameters from component: {name}") -# %% + if not component_params: + logger.error("No parameters found in components to optimize. Exiting.") + return -# logits, _ = ss_model.forward(input_ids, components=gate_proj_components) -logits, _ = ss_model.forward(input_ids) -print("inputs_shape", input_ids.shape) -print("logits", logits) -print("logits shape", logits.shape) + optimizer = optim.AdamW(component_params, lr=config.lr, weight_decay=0.0) + logger.info(f"Optimizer created for params: {param_names_to_optimize}") + logger.info(f"Optimizer details: {optimizer}") -logits, _ = ss_model.forward_with_components(input_ids, components=gate_proj_components) + # --- Scheduler --- # + # Get the base LR schedule function (e.g., constant, linear, cosine) + lr_schedule_fn = get_lr_schedule_fn( + config.lr_schedule, + config.lr_exponential_halflife, + ) + logger.info(f"Base LR scheduler created: {config.lr_schedule}") -print("Component logits shape", logits.shape) -print("Component logits", logits) + # --- Training Loop --- # + pbar = tqdm(range(config.steps), desc="Optimizing Components") + log_data = {} + # Make dataloader an iterator + # TODO: Handle dataloader exhaustion if it's finite (e.g., for validation) + data_iter = iter(dataloader) -# Create some dummy masks -masks = { - f"model.transformer.h.{i}.mlp.gate_proj": torch.randn(1, input_ids.shape[-1], m) - for i in range(len(model.transformer.h)) -} + for step in pbar: + # --- LR Scheduling Step --- # + step_lr = get_lr_with_warmup( + step=step, + steps=config.steps, + lr=config.lr, + lr_schedule_fn=lr_schedule_fn, + lr_warmup_pct=config.lr_warmup_pct, + ) + # Manually update optimizer's learning rate + for group in optimizer.param_groups: + group["lr"] = step_lr + log_data["lr"] = step_lr + + # --- Zero Gradients --- # + optimizer.zero_grad() + + # --- Get Batch --- # + try: + batch = next(data_iter) + except StopIteration: + logger.warning("Dataloader exhausted, resetting iterator.") + data_iter = iter(dataloader) + batch = next(data_iter) + + input_ids = batch["input_ids"].to(device) + + # --- Calculate Losses --- # + total_loss = torch.tensor(0.0, device=device) + loss_terms = {} + + # 1. Reconstruction Loss (comparing logits) + if config.out_recon_coeff is not None and config.out_recon_coeff > 0: + # Get target logits (no gradients needed for target model) + with torch.no_grad(): + target_logits, _ = model.forward(input_ids) + # Detach target logits to ensure no grads flow back + target_logits = target_logits.detach() + + # Get component logits + component_logits, _ = model.forward_with_components(input_ids, components=components) + + # Ensure shapes match (Batch, SeqLen-1, VocabSize) + assert component_logits.shape == target_logits.shape, ( + f"Shape mismatch: {component_logits.shape} vs {target_logits.shape}" + ) + + recon_loss = calc_recon_mse_lm(component_logits, target_logits) + total_loss += config.out_recon_coeff * recon_loss + loss_terms["loss/reconstruction"] = recon_loss.item() + + # 2. Sparsity Loss (Lp norm on component parameters) + # Note: Using p=config.pnorm. The original optimize used relud_masks from gates. + lp_sparsity_loss_val = None + if config.lp_sparsity_coeff > 0: + lp_norm = torch.tensor(0.0, device=device) + for component in components.values(): + # Apply Lp loss to A and B matrices + lp_norm += torch.norm(component.linear_component.A, p=config.pnorm) + lp_norm += torch.norm(component.linear_component.B, p=config.pnorm) + + lp_sparsity_loss_val = lp_norm + total_loss += config.lp_sparsity_coeff * lp_sparsity_loss_val + loss_terms[f"loss/sparsity_l{config.pnorm}_params"] = lp_sparsity_loss_val.item() + + # --- Placeholder Losses (Mimicking run_spd.optimize structure) --- + # These require a mechanism for calculating masks specific to the LM setup. + masks = None # Placeholder: Masks are needed for the following losses + masked_recon_loss_val = None + if config.masked_recon_coeff is not None and config.masked_recon_coeff > 0: + logger.warning("masked_recon_loss requires mask calculation implementation.") + # TODO: Calculate masked_recon_loss_val using masks + # e.g., component_logits_masked = model.forward_with_components(..., masks=masks) + # masked_recon_loss_val = calc_recon_mse_lm(component_logits_masked, target_logits) + loss_terms["loss/masked_reconstruction"] = None # Or 0.0 if calculated + + act_recon_loss_val = None + if config.act_recon_coeff is not None and config.act_recon_coeff > 0: + logger.warning("act_recon_loss requires mask and target activation calculation.") + # TODO: Implement act_recon_loss_val + loss_terms["loss/activation_reconstruction"] = None + + random_masks_loss_val = None + if config.random_mask_recon_coeff is not None and config.random_mask_recon_coeff > 0: + logger.warning("random_masks_loss requires mask calculation implementation.") + # TODO: Implement random_masks_loss_val + loss_terms["loss/random_mask_reconstruction"] = None + + layerwise_recon_loss_val = None + if config.layerwise_recon_coeff is not None and config.layerwise_recon_coeff > 0: + logger.warning("layerwise_recon_loss requires mask calculation and layerwise hooks.") + # TODO: Implement layerwise_recon_loss_val + loss_terms["loss/layerwise_reconstruction"] = None + + layerwise_random_recon_loss_val = None + if ( + config.layerwise_random_recon_coeff is not None + and config.layerwise_random_recon_coeff > 0 + ): + logger.warning( + "layerwise_random_recon_loss requires mask calculation and layerwise hooks." + ) + # TODO: Implement layerwise_random_recon_loss_val + loss_terms["loss/layerwise_random_reconstruction"] = None + + # Add placeholder losses to total_loss if they were calculated (currently they are not) + # Example if masked_recon_loss_val was calculated: + # if masked_recon_loss_val is not None: + # total_loss += config.masked_recon_coeff * masked_recon_loss_val + # Repeat for other placeholder losses... + + # --- Backward Pass & Optimize --- # + if total_loss.requires_grad: + total_loss.backward() + # Optional: Gradient Clipping + # grad_norm_clip_val = 1.0 + # grad_norm = torch.nn.utils.clip_grad_norm_(component_params, max_norm=grad_norm_clip_val) + # log_data["grad_norm/clipped"] = grad_norm.item() + + optimizer.step() + elif total_loss == 0.0: + logger.warning(f"Step {step}: Total loss is zero, skipping backward/optimize.") + else: + logger.warning(f"Step {step}: No loss requires grad, skipping backward/optimize.") + + log_data["loss/total"] = total_loss.item() + + # --- Logging --- # + if step % config.print_freq == 0 or step == config.steps - 1: + log_data.update(loss_terms) # Add individual loss terms for logging + pbar.set_postfix(log_data) + if config.wandb_project and wadb_available: + wandb.log(log_data, step=step) + # Reset loss_terms part of log_data for next interval, keep LR + log_data = {"lr": step_lr} + + # --- Plotting --- # + if config.image_freq is not None and ( + ( + step % config.image_freq == 0 and step > 0 + ) # Avoid plotting at step 0 unless requested + or (config.image_on_first_step and step == 0) + or (step == config.steps - 1) # Always plot at the end + ): + logger.info(f"Step {step}: Generating plots...") + # Ensure model is in eval mode for plotting if necessary, though shouldn't matter here + # model.eval() + with torch.no_grad(): + figures = plot_results_fn( + model=model, # Pass the SSModel wrapper + components=components, + step=step, + out_dir=out_dir, + device=device, + config=config, + # Add any other necessary args for plotting like tokenizer, sample text? + ) + if config.wandb_project and wadb_available and figures: + wandb.log({f"plots/{k}": wandb.Image(v) for k, v in figures.items()}, step=step) + # model.train() # Set back to train mode if needed + + # --- Saving Checkpoints --- # + if (config.save_freq is not None and step % config.save_freq == 0 and step > 0) or ( + step == config.steps - 1 + ): + checkpoint_dir = out_dir / "checkpoints" + checkpoint_dir.mkdir(exist_ok=True) + checkpoint_path = checkpoint_dir / f"components_step_{step}.pt" + # Save only component state dicts + component_state_dicts = {n: c.state_dict() for n, c in components.items()} + save_payload = { + "components": component_state_dicts, + "optimizer": optimizer.state_dict(), + # "scheduler": scheduler.state_dict(), + "step": step, + "config": config.model_dump(mode="json"), + } + torch.save(save_payload, checkpoint_path) + logger.info(f"Saved checkpoint to {checkpoint_path}") + if config.wandb_project and wadb_available: + wandb.save(str(checkpoint_path), base_path=str(out_dir), policy="now") + + logger.info("Finished training loop.") + + +def main( + config_path_or_obj: Path | str | Config, sweep_config_path: Path | str | None = None +) -> None: + config = load_config(config_path_or_obj, config_model=Config) + + if config.wandb_project and wadb_available: + config = init_wandb(config, config.wandb_project, sweep_config_path) + + set_seed(config.seed) + logger.info(config) + + device = get_device() + logger.info(f"Using device: {device}") + assert isinstance(config.task_config, LMTaskConfig), ( + "Task config must be LMTaskConfig for LM decomposition." + ) + + # --- Load Model --- # + logger.info(f"Loading model: {config.task_config.model_size}") + model_config_dict = MODEL_CONFIGS[config.task_config.model_size] + model_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}" + model = Llama.from_pretrained(model_path, model_config_dict) + ss_model = SSModel(model) + ss_model.to(device) + logger.info("Model loaded.") + + # --- Setup Run Name and Output Dir --- # + run_name = get_run_name( + config, + model_size=config.task_config.model_size, + max_seq_len=config.task_config.max_seq_len, + ) + if config.wandb_project: + assert wandb.run, "wandb.run must be initialized before training" + wandb.run.name = run_name + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] + out_dir = Path(__file__).parent / "out" / f"{run_name}_{timestamp}" + out_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Output directory: {out_dir}") + + # --- Save Config --- # + with open(out_dir / "final_config.yaml", "w") as f: + yaml.dump(config.model_dump(mode="json"), f, indent=2) + if config.wandb_project: + wandb.save(str(out_dir / "final_config.yaml"), base_path=out_dir, policy="now") + + # --- Load Data --- # + logger.info("Loading dataset...") + dataset_config = DatasetConfig( + name=config.task_config.dataset_name, + tokenizer_file_path=None, + hf_tokenizer_path=model_path, + split=config.task_config.dataset_split, + n_ctx=config.task_config.max_seq_len, # Use n_ctx as per DatasetConfig + is_tokenized=False, # Assume dataset is tokenized + streaming=True, # Use streaming as per default + column_name="story", + # Assuming default tokenizer path is okay + ) + # Note: SimpleStories dataloader might require specific DDP setup if used. + # Assuming single-process for now (ddp_rank=0, ddp_world_size=1) + dataloader, tokenizer = create_data_loader( + dataset_config=dataset_config, + batch_size=config.batch_size, + buffer_size=config.task_config.buffer_size, + global_seed=config.seed, + ddp_rank=0, + ddp_world_size=1, + ) + logger.info("Dataset and tokenizer loaded.") + + # --- Freeze Target Model --- # + logger.info("Freezing target model parameters...") + for param in ss_model.model.parameters(): + param.requires_grad = False + logger.info("Target model frozen.") + + # --- Initialize Components --- # + logger.info( + f"Initializing components for modules matching: {config.task_config.target_module_patterns}" + ) + components = create_target_components( + ss_model.model, + rank=config.m, + target_module_patterns=config.task_config.target_module_patterns, + ) + logger.info(f"Created {len(components)} components: {list(components.keys())}") + + # Move components to device (their parameters are registered within the LinearComponent) + for name, component in components.items(): + logger.debug(f"Moving component {name} to {device}") + component.to(device) + logger.info("Components initialized and moved to device.") + + # --- Run Optimization --- # + logger.info("Starting optimization...") + optimize_lm( + model=ss_model, + components=components, + config=config, + device=device, + dataloader=dataloader, + out_dir=out_dir, + plot_results_fn=lm_plot_results_fn, + ) + + logger.info("Optimization finished.") + + if config.wandb_project and wadb_available: + wandb.finish() -logits, _ = ss_model.forward_with_components( - input_ids, components=gate_proj_components, masks=masks -) -print("Masked component logits shape", logits.shape) -print("Masked component logits", logits) -# %% +if __name__ == "__main__": + fire.Fire(main) diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py index 57426ba..769d97a 100644 --- a/spd/experiments/lm/models.py +++ b/spd/experiments/lm/models.py @@ -2,6 +2,7 @@ Defines a SSModel class that is a wrapper around a llama model from SimpleStories """ +import fnmatch from typing import Any import torch.nn as nn @@ -20,9 +21,11 @@ def __init__(self, linear_component: LinearComponent, bias: Tensor | None): super().__init__() self.linear_component = linear_component self.bias = bias - self.mask: Float[Tensor, "batch pos m"] | None = None # Gets set on sparse forward passes + self.mask: Float[Tensor, "... m"] | None = None # Gets set on sparse forward passes - def forward(self, x: Float[Tensor, "batch d_in"]) -> Float[Tensor, "batch d_out"]: + def forward(self, x: Float[Tensor, "... d_in"]) -> Float[Tensor, "... d_out"]: + # Note: We assume bias is added *after* the component multiplication + # Also assume input is (batch, seq_len, d_in) out = self.linear_component(x, mask=self.mask) if self.bias is not None: out += self.bias @@ -45,13 +48,22 @@ def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponent return LinearComponentWithBias(linear_component, bias) -# Create LinearComponentWithBias objects for gate_proj in each layer -def create_gate_proj_components(model: Llama, rank: int) -> dict[str, LinearComponentWithBias]: +def create_target_components( + model: Llama, rank: int, target_module_patterns: list[str] +) -> dict[str, LinearComponentWithBias]: + """Create LinearComponentWithBias objects for nn.Linear modules matching the patterns.""" components = {} - for i in range(len(model.transformer.h)): - gate_proj = model.transformer.h[i].mlp.gate_proj - module_name = f"model.transformer.h.{i}.mlp.gate_proj" - components[module_name] = nn_linear_to_components(gate_proj, m=rank) + for name, module in model.named_modules(): + for pattern in target_module_patterns: + if fnmatch.fnmatch(name, pattern): + # If a module name matches a pattern, assert it's a Linear layer + assert isinstance(module, nn.Linear), ( + f"Module '{name}' matched pattern '{pattern}' but is not nn.Linear. " + f"Found type: {type(module)}" + ) + components[name] = nn_linear_to_components(module, m=rank) + # Module matched and processed, move to the next module + break return components diff --git a/spd/experiments/lm/play.py b/spd/experiments/lm/play.py new file mode 100644 index 0000000..f83fabc --- /dev/null +++ b/spd/experiments/lm/play.py @@ -0,0 +1,75 @@ +# %% +import torch +from simple_stories_train.models.llama import Llama +from simple_stories_train.models.model_configs import MODEL_CONFIGS +from transformers import AutoTokenizer + +from spd.experiments.lm.models import SSModel, create_target_components + +# %% +# Select the model size you want to use +model_size = "1.25M" # Options: "35M", "30M", "11M", "5M", "1.25M" + +# Load model configuration +model_config = MODEL_CONFIGS[model_size] + +# Load appropriate model +model_path = f"chandan-sreedhara/SimpleStories-{model_size}" +model = Llama.from_pretrained(model_path, model_config) +# model.to("cuda") +model.eval() +# %% + +ss_model = SSModel(model) + +m = 17 +# Create components with rank=10 (adjust as needed) +gate_proj_components = create_target_components( + model, rank=m, target_module_patterns=["model.transformer.h.*.mlp.gate_proj"] +) + +# %% +# Load tokenizer +tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False) + +# Define your prompt +prompt = "The curious cat looked at the" + +# IMPORTANT: Use tokenizer without special tokens +inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) +# input_ids = inputs.input_ids.to("cuda") +input_ids = inputs.input_ids +# Targets should be the inputs shifted by one (we will later ignore the last input token) +targets = input_ids[:, 1:] +input_ids = input_ids[:, :-1] + +# IMPORTANT: Set correct EOS token ID (not the default from tokenizer) +eos_token_id = 1 + +# %% + +# logits, _ = ss_model.forward(input_ids, components=gate_proj_components) +logits, _ = ss_model.forward(input_ids) +print("inputs_shape", input_ids.shape) +print("logits", logits) +print("logits shape", logits.shape) + +logits, _ = ss_model.forward_with_components(input_ids, components=gate_proj_components) + +print("Component logits shape", logits.shape) +print("Component logits", logits) + +# Create some dummy masks +masks = { + f"model.transformer.h.{i}.mlp.gate_proj": torch.randn(1, input_ids.shape[-1], m) + for i in range(len(model.transformer.h)) +} + +logits, _ = ss_model.forward_with_components( + input_ids, components=gate_proj_components, masks=masks +) + +print("Masked component logits shape", logits.shape) +print("Masked component logits", logits) +######################################################### +# %% From fccc1896ca4e9037081455d329bcf03995c9cc57 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 3 Apr 2025 23:55:46 +0000 Subject: [PATCH 59/73] Fix module paths --- spd/experiments/lm/lm_decomposition.py | 39 ++++++++------------------ spd/experiments/lm/models.py | 16 ++++------- 2 files changed, 17 insertions(+), 38 deletions(-) diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 01b3612..4d2b9e9 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -35,15 +35,7 @@ ) from spd.wandb_utils import init_wandb -# Define wandb_available at the module level -wadb_available = False -try: - import wandb - - wandb.require("core") - wandb_available = True -except ImportError: - logger.warning("wandb not installed, skipping wandb related code.") +wandb.require("core") def get_run_name( @@ -263,7 +255,7 @@ def optimize_lm( if step % config.print_freq == 0 or step == config.steps - 1: log_data.update(loss_terms) # Add individual loss terms for logging pbar.set_postfix(log_data) - if config.wandb_project and wadb_available: + if config.wandb_project: wandb.log(log_data, step=step) # Reset loss_terms part of log_data for next interval, keep LR log_data = {"lr": step_lr} @@ -289,7 +281,7 @@ def optimize_lm( config=config, # Add any other necessary args for plotting like tokenizer, sample text? ) - if config.wandb_project and wadb_available and figures: + if config.wandb_project and figures: wandb.log({f"plots/{k}": wandb.Image(v) for k, v in figures.items()}, step=step) # model.train() # Set back to train mode if needed @@ -311,7 +303,7 @@ def optimize_lm( } torch.save(save_payload, checkpoint_path) logger.info(f"Saved checkpoint to {checkpoint_path}") - if config.wandb_project and wadb_available: + if config.wandb_project: wandb.save(str(checkpoint_path), base_path=str(out_dir), policy="now") logger.info("Finished training loop.") @@ -322,7 +314,7 @@ def main( ) -> None: config = load_config(config_path_or_obj, config_model=Config) - if config.wandb_project and wadb_available: + if config.wandb_project: config = init_wandb(config, config.wandb_project, sweep_config_path) set_seed(config.seed) @@ -370,14 +362,12 @@ def main( tokenizer_file_path=None, hf_tokenizer_path=model_path, split=config.task_config.dataset_split, - n_ctx=config.task_config.max_seq_len, # Use n_ctx as per DatasetConfig - is_tokenized=False, # Assume dataset is tokenized - streaming=True, # Use streaming as per default + n_ctx=config.task_config.max_seq_len, + is_tokenized=False, + streaming=False, column_name="story", - # Assuming default tokenizer path is okay ) - # Note: SimpleStories dataloader might require specific DDP setup if used. - # Assuming single-process for now (ddp_rank=0, ddp_world_size=1) + dataloader, tokenizer = create_data_loader( dataset_config=dataset_config, batch_size=config.batch_size, @@ -388,7 +378,6 @@ def main( ) logger.info("Dataset and tokenizer loaded.") - # --- Freeze Target Model --- # logger.info("Freezing target model parameters...") for param in ss_model.model.parameters(): param.requires_grad = False @@ -402,16 +391,10 @@ def main( ss_model.model, rank=config.m, target_module_patterns=config.task_config.target_module_patterns, + device=device, ) logger.info(f"Created {len(components)} components: {list(components.keys())}") - # Move components to device (their parameters are registered within the LinearComponent) - for name, component in components.items(): - logger.debug(f"Moving component {name} to {device}") - component.to(device) - logger.info("Components initialized and moved to device.") - - # --- Run Optimization --- # logger.info("Starting optimization...") optimize_lm( model=ss_model, @@ -425,7 +408,7 @@ def main( logger.info("Optimization finished.") - if config.wandb_project and wadb_available: + if config.wandb_project: wandb.finish() diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py index 769d97a..7830996 100644 --- a/spd/experiments/lm/models.py +++ b/spd/experiments/lm/models.py @@ -49,7 +49,7 @@ def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponent def create_target_components( - model: Llama, rank: int, target_module_patterns: list[str] + model: Llama, rank: int, target_module_patterns: list[str], device: str ) -> dict[str, LinearComponentWithBias]: """Create LinearComponentWithBias objects for nn.Linear modules matching the patterns.""" components = {} @@ -61,7 +61,7 @@ def create_target_components( f"Module '{name}' matched pattern '{pattern}' but is not nn.Linear. " f"Found type: {type(module)}" ) - components[name] = nn_linear_to_components(module, m=rank) + components[name] = nn_linear_to_components(module, m=rank).to(device) # Module matched and processed, move to the next module break return components @@ -74,11 +74,7 @@ def __init__(self, llama_model: Llama): super().__init__() self.model = llama_model - def forward( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def forward(self, *args: Any, **kwargs: Any) -> Any: """Regular forward pass of the (target) model.""" return self.model(*args, **kwargs) @@ -92,20 +88,20 @@ def forward_with_components( """Forward pass with temporary component replacement.""" old_modules = {} for module_name, component in components.items(): - old_module = get_nested_module_attr(self, module_name) + old_module = get_nested_module_attr(self.model, module_name) assert old_module is not None old_modules[module_name] = old_module if masks is not None: assert module_name in masks, f"Mask for {module_name} not found" component.mask = masks[module_name] - set_nested_module_attr(self, module_name, component) + set_nested_module_attr(self.model, module_name, component) out = self.model(*args, **kwargs) # Restore the original modules for module_name, old_module in old_modules.items(): - set_nested_module_attr(self, module_name, old_module) + set_nested_module_attr(self.model, module_name, old_module) # Remove the masks attribute from the components for component in components.values(): From 3fcf59388b03b7353fb6b4ef83956fffa2611634 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 4 Apr 2025 06:40:56 +0000 Subject: [PATCH 60/73] WIP: Add param_match_loss --- spd/experiments/lm/lm_config.yaml | 8 +- spd/experiments/lm/lm_decomposition.py | 220 ++++++++++++++++--------- spd/experiments/lm/models.py | 61 ++++--- 3 files changed, 179 insertions(+), 110 deletions(-) diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml index fe02f4b..2cd9f93 100644 --- a/spd/experiments/lm/lm_config.yaml +++ b/spd/experiments/lm/lm_config.yaml @@ -7,13 +7,13 @@ wandb_run_name_prefix: "" # Prefix for generated run name # --- General --- seed: 0 unit_norm_matrices: false # Whether to enforce unit norm on A matrices (not typically used here) -m: 100 # Rank of the decomposition / number of components per layer +m: 3 # Rank of the decomposition / number of components per layer # --- Loss Coefficients --- # Set coeffs to null if the loss shouldn't be computed -param_match_coeff: null # Not applicable for component-only optimization -out_recon_coeff: 1.0 # Reconstruction loss based on output logits (MSE) -lp_sparsity_coeff: 1e-2 # Coefficient for Lp sparsity loss (applied to component params A & B) +param_match_coeff: 1.0 +out_recon_coeff: 0.0 # Reconstruction loss based on output logits (MSE) +lp_sparsity_coeff: 0.0 # Coefficient for Lp sparsity loss (applied to component params A & B) pnorm: 1.0 # p-value for the Lp sparsity norm (1.0 for L1) # Placeholder losses (set coeffs to null as they require mask calculation implementation) diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 4d2b9e9..09d982f 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -4,6 +4,7 @@ from datetime import datetime from pathlib import Path +import einops import fire import matplotlib.pyplot as plt import torch @@ -16,16 +17,15 @@ from simple_stories_train.models.model_configs import MODEL_CONFIGS from torch import Tensor from torch.utils.data import DataLoader -from tqdm.auto import tqdm +from tqdm import tqdm from spd.configs import Config, LMTaskConfig from spd.experiments.lm.models import ( LinearComponentWithBias, SSModel, - create_target_components, ) from spd.log import logger -from spd.run_spd import get_common_run_name_suffix +from spd.run_spd import _calc_param_mse, get_common_run_name_suffix from spd.utils import ( get_device, get_lr_schedule_fn, @@ -81,20 +81,46 @@ def calc_recon_mse_lm( return ((out1 - out2) ** 2).sum(dim=-1).mean() +def calc_param_match_loss( + components: dict[str, LinearComponentWithBias], + target_model: Llama, + n_params: int, + device: str, +) -> Float[Tensor, ""]: + """Calculate the MSE loss between component parameters (A@B + bias) and target parameters.""" + target_params: dict[str, Float[Tensor, "d_in d_out"]] = {} + component_params: dict[str, Float[Tensor, "d_in d_out"]] = {} + + for comp_name, component in components.items(): + component_params[comp_name] = einops.einsum( + component.linear_component.A, + component.linear_component.B, + "d_in m, m d_out -> d_in d_out", + ) + target_params[comp_name] = target_model.get_parameter(comp_name + ".weight").T + assert component_params[comp_name].shape == target_params[comp_name].shape + + param_mse = _calc_param_mse( + params1=component_params, + params2=target_params, + n_params=n_params, + device=device, + ) + return param_mse + + def optimize_lm( model: SSModel, - components: dict[str, LinearComponentWithBias], config: Config, device: str, dataloader: DataLoader[tuple[Float[Tensor, "batch pos"], Float[Tensor, "batch pos"]]], - out_dir: Path, plot_results_fn: Callable[..., dict[str, plt.Figure]], + out_dir: Path | None, ) -> None: """Run the optimization loop for LM decomposition.""" - # --- Optimizer --- # component_params = [] param_names_to_optimize = [] - for name, component in components.items(): + for name, component in model.components.items(): component_params.extend(list(component.parameters())) param_names_to_optimize.extend( [f"{name}.{p_name}" for p_name, _ in component.named_parameters()] @@ -107,24 +133,20 @@ def optimize_lm( optimizer = optim.AdamW(component_params, lr=config.lr, weight_decay=0.0) logger.info(f"Optimizer created for params: {param_names_to_optimize}") - logger.info(f"Optimizer details: {optimizer}") - # --- Scheduler --- # - # Get the base LR schedule function (e.g., constant, linear, cosine) - lr_schedule_fn = get_lr_schedule_fn( - config.lr_schedule, - config.lr_exponential_halflife, - ) + lr_schedule_fn = get_lr_schedule_fn(config.lr_schedule, config.lr_exponential_halflife) logger.info(f"Base LR scheduler created: {config.lr_schedule}") - # --- Training Loop --- # - pbar = tqdm(range(config.steps), desc="Optimizing Components") + n_params = 0 + for module_name in model.components: + weight = model.model.get_parameter(module_name + ".weight") + n_params += weight.numel() + log_data = {} - # Make dataloader an iterator - # TODO: Handle dataloader exhaustion if it's finite (e.g., for validation) data_iter = iter(dataloader) - for step in pbar: + # Use tqdm directly in the loop, iterate one extra step for final logging/plotting/saving + for step in tqdm(range(config.steps + 1), ncols=0): # --- LR Scheduling Step --- # step_lr = get_lr_with_warmup( step=step, @@ -143,28 +165,67 @@ def optimize_lm( # --- Get Batch --- # try: - batch = next(data_iter) + batch = next(data_iter)["input_ids"].to(device) except StopIteration: logger.warning("Dataloader exhausted, resetting iterator.") data_iter = iter(dataloader) - batch = next(data_iter) - - input_ids = batch["input_ids"].to(device) + batch = next(data_iter)["input_ids"].to(device) + + # # Forward pass with target model + # target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) and any( + # k.startswith("model." + module_name) for module_name in model.components + # ) + # target_out, target_cache = model.run_with_cache(batch, names_filter=target_cache_filter) + # I want to do a forward pass on model.model, but applying pre_forward_hooks to each of the + # keys in models.components. The pre_forward_hooks should simply cache the activations that + # go into each of the components. + + # # Forward pass with all subnetworks + # out = model(batch) + + # pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} + As = {module_name: v.linear_component.A for module_name, v in model.components.items()} + + # target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + # attributions = calc_grad_attributions( + # target_out=target_out, + # pre_weight_acts=pre_weight_acts, + # post_weight_acts={k: v for k, v in target_cache.items() if k.endswith("hook_post")}, + # target_component_acts=target_component_acts, + # Bs=collect_nested_module_attrs(model, attr_name="B", include_attr_name=False), + # ) + attributions = None + + # masks, relud_masks = calc_masks( + # gates=dates, + # target_component_acts=target_component_acts, + # attributions=attributions, + # detach_inputs=False, + # ) # --- Calculate Losses --- # total_loss = torch.tensor(0.0, device=device) loss_terms = {} + param_match_loss_val = calc_param_match_loss( + components=model.components, + target_model=model.model, + n_params=n_params, + device=device, + ) + total_loss += config.param_match_coeff * param_match_loss_val + loss_terms["loss/parameter_matching"] = param_match_loss_val.item() + # 1. Reconstruction Loss (comparing logits) if config.out_recon_coeff is not None and config.out_recon_coeff > 0: # Get target logits (no gradients needed for target model) with torch.no_grad(): - target_logits, _ = model.forward(input_ids) + target_logits, _ = model.forward(batch) # Detach target logits to ensure no grads flow back target_logits = target_logits.detach() # Get component logits - component_logits, _ = model.forward_with_components(input_ids, components=components) + component_logits, _ = model.forward_with_components(batch, components=model.components) # Ensure shapes match (Batch, SeqLen-1, VocabSize) assert component_logits.shape == target_logits.shape, ( @@ -180,7 +241,7 @@ def optimize_lm( lp_sparsity_loss_val = None if config.lp_sparsity_coeff > 0: lp_norm = torch.tensor(0.0, device=device) - for component in components.values(): + for component in model.components.values(): # Apply Lp loss to A and B matrices lp_norm += torch.norm(component.linear_component.A, p=config.pnorm) lp_norm += torch.norm(component.linear_component.B, p=config.pnorm) @@ -235,69 +296,61 @@ def optimize_lm( # total_loss += config.masked_recon_coeff * masked_recon_loss_val # Repeat for other placeholder losses... - # --- Backward Pass & Optimize --- # - if total_loss.requires_grad: - total_loss.backward() - # Optional: Gradient Clipping - # grad_norm_clip_val = 1.0 - # grad_norm = torch.nn.utils.clip_grad_norm_(component_params, max_norm=grad_norm_clip_val) - # log_data["grad_norm/clipped"] = grad_norm.item() - - optimizer.step() - elif total_loss == 0.0: - logger.warning(f"Step {step}: Total loss is zero, skipping backward/optimize.") - else: - logger.warning(f"Step {step}: No loss requires grad, skipping backward/optimize.") - log_data["loss/total"] = total_loss.item() + log_data.update(loss_terms) # --- Logging --- # - if step % config.print_freq == 0 or step == config.steps - 1: - log_data.update(loss_terms) # Add individual loss terms for logging - pbar.set_postfix(log_data) + if step % config.print_freq == 0: + tqdm.write(f"--- Step {step} ---") + tqdm.write(f"LR: {step_lr:.6f}") + tqdm.write(f"Total Loss: {log_data['loss/total']:.7f}") + for name, value in loss_terms.items(): + if value is not None: + tqdm.write(f"{name}: {value:.7f}") + if config.wandb_project: wandb.log(log_data, step=step) - # Reset loss_terms part of log_data for next interval, keep LR - log_data = {"lr": step_lr} # --- Plotting --- # - if config.image_freq is not None and ( - ( - step % config.image_freq == 0 and step > 0 - ) # Avoid plotting at step 0 unless requested - or (config.image_on_first_step and step == 0) - or (step == config.steps - 1) # Always plot at the end + if ( + config.image_freq is not None + and step % config.image_freq == 0 + and (step > 0 or config.image_on_first_step) ): logger.info(f"Step {step}: Generating plots...") - # Ensure model is in eval mode for plotting if necessary, though shouldn't matter here - # model.eval() with torch.no_grad(): - figures = plot_results_fn( + fig_dict = plot_results_fn( model=model, # Pass the SSModel wrapper - components=components, + components=model.components, step=step, out_dir=out_dir, device=device, config=config, # Add any other necessary args for plotting like tokenizer, sample text? ) - if config.wandb_project and figures: - wandb.log({f"plots/{k}": wandb.Image(v) for k, v in figures.items()}, step=step) - # model.train() # Set back to train mode if needed - - # --- Saving Checkpoints --- # - if (config.save_freq is not None and step % config.save_freq == 0 and step > 0) or ( - step == config.steps - 1 - ): + if config.wandb_project: + wandb.log( + {k: wandb.Image(v) for k, v in fig_dict.items()}, + step=step, + ) + if out_dir is not None: + for k, v in fig_dict.items(): + v.savefig(out_dir / f"{k}_{step}.png") + tqdm.write(f"Saved plot to {out_dir / f'{k}_{step}.png'}") + + # --- Saving Checkpoint --- # + if ( + (config.save_freq is not None and step % config.save_freq == 0 and step > 0) + or step == config.steps + ) and out_dir is not None: checkpoint_dir = out_dir / "checkpoints" checkpoint_dir.mkdir(exist_ok=True) checkpoint_path = checkpoint_dir / f"components_step_{step}.pt" # Save only component state dicts - component_state_dicts = {n: c.state_dict() for n, c in components.items()} + component_state_dicts = {n: c.state_dict() for n, c in model.components.items()} save_payload = { "components": component_state_dicts, "optimizer": optimizer.state_dict(), - # "scheduler": scheduler.state_dict(), "step": step, "config": config.model_dump(mode="json"), } @@ -306,6 +359,23 @@ def optimize_lm( if config.wandb_project: wandb.save(str(checkpoint_path), base_path=str(out_dir), policy="now") + # --- Backward Pass & Optimize --- # + # Skip gradient step if we are at the last step (last step just for plotting and logging) + if step != config.steps: + total_loss.backward(retain_graph=True) + + if step % config.print_freq == 0 and config.wandb_project: + # Calculate gradient norm + grad_norm: float = 0.0 + for param in model.parameters(): + if param.grad is not None: + grad_norm += param.grad.data.norm() # type: ignore + wandb.log({"grad_norm": grad_norm}, step=step) + + if config.unit_norm_matrices: + model.fix_normalized_adam_gradients() + + optimizer.step() logger.info("Finished training loop.") @@ -331,7 +401,12 @@ def main( model_config_dict = MODEL_CONFIGS[config.task_config.model_size] model_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}" model = Llama.from_pretrained(model_path, model_config_dict) - ss_model = SSModel(model) + + ss_model = SSModel( + llama_model=model, + target_module_patterns=config.task_config.target_module_patterns, + m=config.m, + ) ss_model.to(device) logger.info("Model loaded.") @@ -383,22 +458,9 @@ def main( param.requires_grad = False logger.info("Target model frozen.") - # --- Initialize Components --- # - logger.info( - f"Initializing components for modules matching: {config.task_config.target_module_patterns}" - ) - components = create_target_components( - ss_model.model, - rank=config.m, - target_module_patterns=config.task_config.target_module_patterns, - device=device, - ) - logger.info(f"Created {len(components)} components: {list(components.keys())}") - logger.info("Starting optimization...") optimize_lm( model=ss_model, - components=components, config=config, device=device, dataloader=dataloader, diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py index 7830996..271ee4b 100644 --- a/spd/experiments/lm/models.py +++ b/spd/experiments/lm/models.py @@ -11,7 +11,6 @@ from torch import Tensor from spd.models.components import LinearComponent -from spd.module_utils import get_nested_module_attr, set_nested_module_attr class LinearComponentWithBias(nn.Module): @@ -48,31 +47,40 @@ def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponent return LinearComponentWithBias(linear_component, bias) -def create_target_components( - model: Llama, rank: int, target_module_patterns: list[str], device: str -) -> dict[str, LinearComponentWithBias]: - """Create LinearComponentWithBias objects for nn.Linear modules matching the patterns.""" - components = {} - for name, module in model.named_modules(): - for pattern in target_module_patterns: - if fnmatch.fnmatch(name, pattern): - # If a module name matches a pattern, assert it's a Linear layer - assert isinstance(module, nn.Linear), ( - f"Module '{name}' matched pattern '{pattern}' but is not nn.Linear. " - f"Found type: {type(module)}" - ) - components[name] = nn_linear_to_components(module, m=rank).to(device) - # Module matched and processed, move to the next module - break - return components - - +# class SSModel(HookedRootModule): class SSModel(nn.Module): """Wrapper around a llama model from SimpleStories for running SPD.""" - def __init__(self, llama_model: Llama): + def __init__(self, llama_model: Llama, target_module_patterns: list[str], m: int): super().__init__() self.model = llama_model + self.components = self.create_target_components( + target_module_patterns=target_module_patterns, m=m + ) + # self.setup() + + def create_target_components( + self, target_module_patterns: list[str], m: int + ) -> dict[str, LinearComponentWithBias]: + """Create target components for the model.""" + components = {} + for name, module in self.model.named_modules(): + for pattern in target_module_patterns: + if fnmatch.fnmatch(name, pattern): + assert isinstance(module, nn.Linear), ( + f"Module '{name}' matched pattern '{pattern}' but is not nn.Linear. " + f"Found type: {type(module)}" + ) + components[name] = nn_linear_to_components(module, m=m) + break + return components + + def to(self, *args: Any, **kwargs: Any) -> "SSModel": + """Move the model and components to a device.""" + self.model.to(*args, **kwargs) + for component in self.components.values(): + component.to(*args, **kwargs) + return self def forward(self, *args: Any, **kwargs: Any) -> Any: """Regular forward pass of the (target) model.""" @@ -81,30 +89,29 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: def forward_with_components( self, *args: Any, - components: dict[str, LinearComponentWithBias], masks: dict[str, Float[Tensor, "batch pos m"]] | None = None, **kwargs: Any, ) -> Any: """Forward pass with temporary component replacement.""" old_modules = {} - for module_name, component in components.items(): - old_module = get_nested_module_attr(self.model, module_name) + for module_name, component in self.components.items(): + old_module = self.model.get_submodule(module_name) assert old_module is not None old_modules[module_name] = old_module if masks is not None: assert module_name in masks, f"Mask for {module_name} not found" component.mask = masks[module_name] - set_nested_module_attr(self.model, module_name, component) + self.model.set_submodule(module_name, component) out = self.model(*args, **kwargs) # Restore the original modules for module_name, old_module in old_modules.items(): - set_nested_module_attr(self.model, module_name, old_module) + self.model.set_submodule(module_name, old_module) # Remove the masks attribute from the components - for component in components.values(): + for component in self.components.values(): component.mask = None return out From aa7cacf4764b987a9eb19a342204e64f3bf59fdc Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 8 Apr 2025 05:26:51 +0000 Subject: [PATCH 61/73] Add layerwise recon losses --- spd/experiments/lm/lm_config.yaml | 2 +- spd/experiments/lm/lm_decomposition.py | 184 +++++++++++++++++-------- spd/experiments/lm/models.py | 75 +++++++++- spd/run_spd.py | 4 +- 4 files changed, 202 insertions(+), 63 deletions(-) diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml index 2cd9f93..926808f 100644 --- a/spd/experiments/lm/lm_config.yaml +++ b/spd/experiments/lm/lm_config.yaml @@ -21,7 +21,7 @@ masked_recon_coeff: null # Reconstruction loss using masks act_recon_coeff: null # Reconstruction loss on intermediate component activations random_mask_recon_coeff: null # Reconstruction loss averaged over random masks layerwise_recon_coeff: null # Layer-wise reconstruction loss -layerwise_random_recon_coeff: null # Layer-wise reconstruction loss with random masks +layerwise_random_recon_coeff: 1 # Layer-wise reconstruction loss with random masks n_random_masks: 1 # Number of random masks if random_mask_recon_coeff is used n_gate_hidden_neurons: null # Not applicable as there are no gates currently diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 09d982f..4f671cd 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -15,7 +15,7 @@ from simple_stories_train.dataloaders import DatasetConfig, create_data_loader from simple_stories_train.models.llama import Llama from simple_stories_train.models.model_configs import MODEL_CONFIGS -from torch import Tensor +from torch import Tensor, nn from torch.utils.data import DataLoader from tqdm import tqdm @@ -25,7 +25,8 @@ SSModel, ) from spd.log import logger -from spd.run_spd import _calc_param_mse, get_common_run_name_suffix +from spd.models.components import Gate, GateMLP +from spd.run_spd import _calc_param_mse, calc_masks, calc_random_masks, get_common_run_name_suffix from spd.utils import ( get_device, get_lr_schedule_fn, @@ -71,6 +72,24 @@ def lm_plot_results_fn( return fig_dict +def calc_component_acts( + pre_weight_acts: dict[str, Float[Tensor, "... d_in"]], + As: dict[str, Float[nn.Parameter, "d_in m"]], +) -> dict[str, Float[Tensor, "batch m"]]: + """Calculate the component acts for each layer. I.e. (pre_weight_acts @ A). + + Args: + pre_weight_acts: The activations before each layer in the target model. + As: The A matrix at each layer. + """ + component_acts = {} + for param_name in pre_weight_acts: + component_acts[param_name] = einops.einsum( + pre_weight_acts[param_name], As[param_name], "... d_in, ... d_in m -> ... m" + ) + return component_acts + + def calc_recon_mse_lm( out1: Float[Tensor, "batch seq vocab"], out2: Float[Tensor, "batch seq vocab"], @@ -109,6 +128,26 @@ def calc_param_match_loss( return param_mse +def calc_layerwise_recon_loss( + model: SSModel, + batch: Float[Tensor, "batch pos"], + device: str, + masks: list[dict[str, Float[Tensor, "batch pos m"]]], + target_out: Float[Tensor, "batch pos vocab"], +) -> Float[Tensor, ""]: + """Calculate the recon loss when augmenting the model one (masked) component at a time.""" + n_modified_components = len(masks[0]) + total_loss = torch.tensor(0.0, device=device) + for mask_info in masks: + for module_name in mask_info: + modified_out, _ = model.forward_with_component( + batch, module_name=module_name, mask=mask_info[module_name] + ) + loss = calc_recon_mse_lm(modified_out, target_out) + total_loss += loss + return total_loss / (n_modified_components * len(masks)) + + def optimize_lm( model: SSModel, config: Config, @@ -118,6 +157,12 @@ def optimize_lm( out_dir: Path | None, ) -> None: """Run the optimization loop for LM decomposition.""" + + # We used "-" instead of "." as module names can't have "." in them + gates: dict[str, Gate | GateMLP] = { + k.removeprefix("gates.").replace("-", "."): v for k, v in model.gates.items() + } # type: ignore + component_params = [] param_names_to_optimize = [] for name, component in model.components.items(): @@ -184,9 +229,12 @@ def optimize_lm( # out = model(batch) # pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} + (target_out, _), pre_weight_acts = model.forward_with_pre_forward_cache_hooks( + batch, module_names=list(model.components.keys()) + ) As = {module_name: v.linear_component.A for module_name, v in model.components.items()} - # target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # attributions = calc_grad_attributions( # target_out=target_out, # pre_weight_acts=pre_weight_acts, @@ -196,17 +244,18 @@ def optimize_lm( # ) attributions = None - # masks, relud_masks = calc_masks( - # gates=dates, - # target_component_acts=target_component_acts, - # attributions=attributions, - # detach_inputs=False, - # ) + masks, relud_masks = calc_masks( + gates=gates, + target_component_acts=target_component_acts, + attributions=attributions, + detach_inputs=False, + ) # --- Calculate Losses --- # total_loss = torch.tensor(0.0, device=device) loss_terms = {} + ####### param match loss ####### param_match_loss_val = calc_param_match_loss( components=model.components, target_model=model.model, @@ -216,8 +265,35 @@ def optimize_lm( total_loss += config.param_match_coeff * param_match_loss_val loss_terms["loss/parameter_matching"] = param_match_loss_val.item() - # 1. Reconstruction Loss (comparing logits) - if config.out_recon_coeff is not None and config.out_recon_coeff > 0: + ####### layerwise recon loss ####### + if config.layerwise_recon_coeff is not None: + layerwise_recon_loss = calc_layerwise_recon_loss( + model=model, + batch=batch, + device=device, + masks=[masks], + target_out=target_out, + ) + total_loss += config.layerwise_recon_coeff * layerwise_recon_loss + loss_terms["loss/layerwise_reconstruction"] = layerwise_recon_loss.item() + + ####### layerwise random recon loss ####### + if config.layerwise_random_recon_coeff is not None: + layerwise_random_masks = calc_random_masks( + masks=masks, n_random_masks=config.n_random_masks + ) + layerwise_random_recon_loss = calc_layerwise_recon_loss( + model=model, + batch=batch, + device=device, + masks=layerwise_random_masks, + target_out=target_out, + ) + total_loss += config.layerwise_random_recon_coeff * layerwise_random_recon_loss + loss_terms["loss/layerwise_random_reconstruction"] = layerwise_random_recon_loss.item() + + ####### out recon loss ####### + if config.out_recon_coeff is not None: # Get target logits (no gradients needed for target model) with torch.no_grad(): target_logits, _ = model.forward(batch) @@ -225,9 +301,8 @@ def optimize_lm( target_logits = target_logits.detach() # Get component logits - component_logits, _ = model.forward_with_components(batch, components=model.components) + component_logits, _ = model.forward_with_components(batch, masks=masks) - # Ensure shapes match (Batch, SeqLen-1, VocabSize) assert component_logits.shape == target_logits.shape, ( f"Shape mismatch: {component_logits.shape} vs {target_logits.shape}" ) @@ -250,51 +325,43 @@ def optimize_lm( total_loss += config.lp_sparsity_coeff * lp_sparsity_loss_val loss_terms[f"loss/sparsity_l{config.pnorm}_params"] = lp_sparsity_loss_val.item() - # --- Placeholder Losses (Mimicking run_spd.optimize structure) --- - # These require a mechanism for calculating masks specific to the LM setup. - masks = None # Placeholder: Masks are needed for the following losses - masked_recon_loss_val = None - if config.masked_recon_coeff is not None and config.masked_recon_coeff > 0: - logger.warning("masked_recon_loss requires mask calculation implementation.") - # TODO: Calculate masked_recon_loss_val using masks - # e.g., component_logits_masked = model.forward_with_components(..., masks=masks) - # masked_recon_loss_val = calc_recon_mse_lm(component_logits_masked, target_logits) - loss_terms["loss/masked_reconstruction"] = None # Or 0.0 if calculated - - act_recon_loss_val = None - if config.act_recon_coeff is not None and config.act_recon_coeff > 0: - logger.warning("act_recon_loss requires mask and target activation calculation.") - # TODO: Implement act_recon_loss_val - loss_terms["loss/activation_reconstruction"] = None - - random_masks_loss_val = None - if config.random_mask_recon_coeff is not None and config.random_mask_recon_coeff > 0: - logger.warning("random_masks_loss requires mask calculation implementation.") - # TODO: Implement random_masks_loss_val - loss_terms["loss/random_mask_reconstruction"] = None - - layerwise_recon_loss_val = None - if config.layerwise_recon_coeff is not None and config.layerwise_recon_coeff > 0: - logger.warning("layerwise_recon_loss requires mask calculation and layerwise hooks.") - # TODO: Implement layerwise_recon_loss_val - loss_terms["loss/layerwise_reconstruction"] = None - - layerwise_random_recon_loss_val = None - if ( - config.layerwise_random_recon_coeff is not None - and config.layerwise_random_recon_coeff > 0 - ): - logger.warning( - "layerwise_random_recon_loss requires mask calculation and layerwise hooks." - ) - # TODO: Implement layerwise_random_recon_loss_val - loss_terms["loss/layerwise_random_reconstruction"] = None - - # Add placeholder losses to total_loss if they were calculated (currently they are not) - # Example if masked_recon_loss_val was calculated: - # if masked_recon_loss_val is not None: - # total_loss += config.masked_recon_coeff * masked_recon_loss_val - # Repeat for other placeholder losses... + # # --- Placeholder Losses (Mimicking run_spd.optimize structure) --- + # masked_recon_loss_val = None + # if config.masked_recon_coeff is not None and config.masked_recon_coeff > 0: + # logger.warning("masked_recon_loss requires mask calculation implementation.") + # # TODO: Calculate masked_recon_loss_val using masks + # # e.g., component_logits_masked = model.forward_with_components(..., masks=masks) + # # masked_recon_loss_val = calc_recon_mse_lm(component_logits_masked, target_logits) + # loss_terms["loss/masked_reconstruction"] = None # Or 0.0 if calculated + + # act_recon_loss_val = None + # if config.act_recon_coeff is not None and config.act_recon_coeff > 0: + # logger.warning("act_recon_loss requires mask and target activation calculation.") + # # TODO: Implement act_recon_loss_val + # loss_terms["loss/activation_reconstruction"] = None + + # random_masks_loss_val = None + # if config.random_mask_recon_coeff is not None and config.random_mask_recon_coeff > 0: + # logger.warning("random_masks_loss requires mask calculation implementation.") + # # TODO: Implement random_masks_loss_val + # loss_terms["loss/random_mask_reconstruction"] = None + + # layerwise_recon_loss_val = None + # if config.layerwise_recon_coeff is not None and config.layerwise_recon_coeff > 0: + # logger.warning("layerwise_recon_loss requires mask calculation and layerwise hooks.") + # # TODO: Implement layerwise_recon_loss_val + # loss_terms["loss/layerwise_reconstruction"] = None + + # layerwise_random_recon_loss_val = None + # if ( + # config.layerwise_random_recon_coeff is not None + # and config.layerwise_random_recon_coeff > 0 + # ): + # logger.warning( + # "layerwise_random_recon_loss requires mask calculation and layerwise hooks." + # ) + # # TODO: Implement layerwise_random_recon_loss_val + # loss_terms["loss/layerwise_random_reconstruction"] = None log_data["loss/total"] = total_loss.item() log_data.update(loss_terms) @@ -406,6 +473,7 @@ def main( llama_model=model, target_module_patterns=config.task_config.target_module_patterns, m=config.m, + n_gate_hidden_neurons=config.n_gate_hidden_neurons, ) ss_model.to(device) logger.info("Model loaded.") diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py index 271ee4b..9caa41b 100644 --- a/spd/experiments/lm/models.py +++ b/spd/experiments/lm/models.py @@ -3,14 +3,16 @@ """ import fnmatch +from functools import partial from typing import Any +import torch import torch.nn as nn from jaxtyping import Float from simple_stories_train.models.llama import Llama from torch import Tensor -from spd.models.components import LinearComponent +from spd.models.components import Gate, GateMLP, LinearComponent class LinearComponentWithBias(nn.Module): @@ -51,12 +53,29 @@ def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponent class SSModel(nn.Module): """Wrapper around a llama model from SimpleStories for running SPD.""" - def __init__(self, llama_model: Llama, target_module_patterns: list[str], m: int): + def __init__( + self, + llama_model: Llama, + target_module_patterns: list[str], + m: int, + n_gate_hidden_neurons: int | None, + ): super().__init__() self.model = llama_model self.components = self.create_target_components( target_module_patterns=target_module_patterns, m=m ) + + # Use GateMLP if n_gate_hidden_neurons is provided, otherwise use Gate + gate_class = GateMLP if n_gate_hidden_neurons is not None else Gate + gate_kwargs = {"m": m} + if n_gate_hidden_neurons is not None: + gate_kwargs["n_gate_hidden_neurons"] = n_gate_hidden_neurons + + self.gates = nn.ModuleDict() + for name in self.components: + self.gates[name.replace(".", "-")] = gate_class(**gate_kwargs) + # self.setup() def create_target_components( @@ -80,12 +99,35 @@ def to(self, *args: Any, **kwargs: Any) -> "SSModel": self.model.to(*args, **kwargs) for component in self.components.values(): component.to(*args, **kwargs) + for gate in self.gates.values(): + gate.to(*args, **kwargs) return self def forward(self, *args: Any, **kwargs: Any) -> Any: """Regular forward pass of the (target) model.""" return self.model(*args, **kwargs) + def forward_with_component( + self, + *args: Any, + module_name: str, + mask: Float[Tensor, "batch pos m"] | None = None, + **kwargs: Any, + ) -> Any: + """Forward pass with a single component replacement.""" + old_module = self.model.get_submodule(module_name) + assert old_module is not None + + component = self.components[module_name] + self.model.set_submodule(module_name, component) + if mask is not None: + component.mask = mask + + out = self.model(*args, **kwargs) + + self.model.set_submodule(module_name, old_module) + return out + def forward_with_components( self, *args: Any, @@ -115,3 +157,32 @@ def forward_with_components( component.mask = None return out + + def forward_with_pre_forward_cache_hooks( + self, *args: Any, module_names: list[str], **kwargs: Any + ) -> tuple[Any, dict[str, Tensor]]: + """Forward pass with caching at in the input to the modules given by `module_names`. + + Args: + module_names: List of module names to cache the inputs to. + """ + cache = {} + + def cache_hook(module: nn.Module, input: tuple[Tensor, ...], param_name: str) -> Tensor: + cache[param_name] = input[0] + return input[0] + + handles: list[torch.utils.hooks.RemovableHandle] = [] + for module_name in module_names: + module = self.model.get_submodule(module_name) + assert module is not None + handles.append( + module.register_forward_pre_hook(partial(cache_hook, param_name=module_name)) + ) + + out = self.forward(*args, **kwargs) + + for handle in handles: + handle.remove() + + return out, cache diff --git a/spd/run_spd.py b/spd/run_spd.py index 7c4646a..ceb3194 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -17,7 +17,7 @@ from spd.configs import Config from spd.hooks import HookedRootModule from spd.models.base import SPDModel -from spd.models.components import Gate, Linear, LinearComponent +from spd.models.components import Gate, GateMLP, Linear, LinearComponent from spd.module_utils import collect_nested_module_attrs, get_nested_module_attr from spd.utils import calc_recon_mse, get_lr_schedule_fn, get_lr_with_warmup @@ -141,7 +141,7 @@ def calc_act_recon_mse( def calc_masks( - gates: dict[str, Gate], + gates: dict[str, Gate | GateMLP], target_component_acts: dict[ str, Float[Tensor, "batch m"] | Float[Tensor, "batch n_instances m"] ], From 82b505a494fcf0eeef35972b52f18ebaa47d1bf4 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 8 Apr 2025 05:34:28 +0000 Subject: [PATCH 62/73] Add lp sparsity loss --- spd/experiments/lm/lm_config.yaml | 4 +- spd/experiments/lm/lm_decomposition.py | 67 +++++++++++++------------- 2 files changed, 35 insertions(+), 36 deletions(-) diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml index 926808f..2daa3f5 100644 --- a/spd/experiments/lm/lm_config.yaml +++ b/spd/experiments/lm/lm_config.yaml @@ -13,15 +13,15 @@ m: 3 # Rank of the decomposition / number of components per layer # Set coeffs to null if the loss shouldn't be computed param_match_coeff: 1.0 out_recon_coeff: 0.0 # Reconstruction loss based on output logits (MSE) -lp_sparsity_coeff: 0.0 # Coefficient for Lp sparsity loss (applied to component params A & B) +lp_sparsity_coeff: 1e-3 # Coefficient for Lp sparsity loss (applied to component params A & B) pnorm: 1.0 # p-value for the Lp sparsity norm (1.0 for L1) +layerwise_random_recon_coeff: 1 # Layer-wise reconstruction loss with random masks # Placeholder losses (set coeffs to null as they require mask calculation implementation) masked_recon_coeff: null # Reconstruction loss using masks act_recon_coeff: null # Reconstruction loss on intermediate component activations random_mask_recon_coeff: null # Reconstruction loss averaged over random masks layerwise_recon_coeff: null # Layer-wise reconstruction loss -layerwise_random_recon_coeff: 1 # Layer-wise reconstruction loss with random masks n_random_masks: 1 # Number of random masks if random_mask_recon_coeff is used n_gate_hidden_neurons: null # Not applicable as there are no gates currently diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 4f671cd..7a5e639 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -91,8 +91,8 @@ def calc_component_acts( def calc_recon_mse_lm( - out1: Float[Tensor, "batch seq vocab"], - out2: Float[Tensor, "batch seq vocab"], + out1: Float[Tensor, "batch pos vocab"], + out2: Float[Tensor, "batch pos vocab"], ) -> Float[Tensor, ""]: """Calculate the Mean Squared Error reconstruction loss for LM logits.""" assert out1.shape == out2.shape @@ -100,7 +100,7 @@ def calc_recon_mse_lm( return ((out1 - out2) ** 2).sum(dim=-1).mean() -def calc_param_match_loss( +def calc_param_match_loss_lm( components: dict[str, LinearComponentWithBias], target_model: Llama, n_params: int, @@ -128,7 +128,7 @@ def calc_param_match_loss( return param_mse -def calc_layerwise_recon_loss( +def calc_layerwise_recon_loss_lm( model: SSModel, batch: Float[Tensor, "batch pos"], device: str, @@ -148,6 +148,27 @@ def calc_layerwise_recon_loss( return total_loss / (n_modified_components * len(masks)) +def calc_lp_sparsity_loss_lm( + relud_masks: dict[str, Float[Tensor, "batch pos m"]], pnorm: float +) -> Float[Tensor, ""]: + """Calculate the Lp sparsity loss on the attributions. + + Args: + relud_masks: Dictionary of relu masks for each layer. + pnorm: The pnorm to use for the sparsity loss. + Returns: + The Lp sparsity loss. + """ + # Initialize with zeros matching the shape of first mask + total_loss = torch.zeros_like(next(iter(relud_masks.values()))) + + for layer_relud_mask in relud_masks.values(): + total_loss = total_loss + layer_relud_mask**pnorm + + # Sum over the m dimension and mean over the batch and pos dimensions + return total_loss.sum(dim=-1).mean(dim=[0, 1]) + + def optimize_lm( model: SSModel, config: Config, @@ -216,19 +237,6 @@ def optimize_lm( data_iter = iter(dataloader) batch = next(data_iter)["input_ids"].to(device) - # # Forward pass with target model - # target_cache_filter = lambda k: k.endswith((".hook_pre", ".hook_post")) and any( - # k.startswith("model." + module_name) for module_name in model.components - # ) - # target_out, target_cache = model.run_with_cache(batch, names_filter=target_cache_filter) - # I want to do a forward pass on model.model, but applying pre_forward_hooks to each of the - # keys in models.components. The pre_forward_hooks should simply cache the activations that - # go into each of the components. - - # # Forward pass with all subnetworks - # out = model(batch) - - # pre_weight_acts = {k: v for k, v in target_cache.items() if k.endswith("hook_pre")} (target_out, _), pre_weight_acts = model.forward_with_pre_forward_cache_hooks( batch, module_names=list(model.components.keys()) ) @@ -256,7 +264,7 @@ def optimize_lm( loss_terms = {} ####### param match loss ####### - param_match_loss_val = calc_param_match_loss( + param_match_loss_val = calc_param_match_loss_lm( components=model.components, target_model=model.model, n_params=n_params, @@ -267,7 +275,7 @@ def optimize_lm( ####### layerwise recon loss ####### if config.layerwise_recon_coeff is not None: - layerwise_recon_loss = calc_layerwise_recon_loss( + layerwise_recon_loss = calc_layerwise_recon_loss_lm( model=model, batch=batch, device=device, @@ -282,7 +290,7 @@ def optimize_lm( layerwise_random_masks = calc_random_masks( masks=masks, n_random_masks=config.n_random_masks ) - layerwise_random_recon_loss = calc_layerwise_recon_loss( + layerwise_random_recon_loss = calc_layerwise_recon_loss_lm( model=model, batch=batch, device=device, @@ -292,6 +300,11 @@ def optimize_lm( total_loss += config.layerwise_random_recon_coeff * layerwise_random_recon_loss loss_terms["loss/layerwise_random_reconstruction"] = layerwise_random_recon_loss.item() + ####### lp sparsity loss ####### + lp_sparsity_loss = calc_lp_sparsity_loss_lm(relud_masks=relud_masks, pnorm=config.pnorm) + total_loss += config.lp_sparsity_coeff * lp_sparsity_loss + loss_terms["loss/lp_sparsity_loss"] = lp_sparsity_loss.item() + ####### out recon loss ####### if config.out_recon_coeff is not None: # Get target logits (no gradients needed for target model) @@ -311,20 +324,6 @@ def optimize_lm( total_loss += config.out_recon_coeff * recon_loss loss_terms["loss/reconstruction"] = recon_loss.item() - # 2. Sparsity Loss (Lp norm on component parameters) - # Note: Using p=config.pnorm. The original optimize used relud_masks from gates. - lp_sparsity_loss_val = None - if config.lp_sparsity_coeff > 0: - lp_norm = torch.tensor(0.0, device=device) - for component in model.components.values(): - # Apply Lp loss to A and B matrices - lp_norm += torch.norm(component.linear_component.A, p=config.pnorm) - lp_norm += torch.norm(component.linear_component.B, p=config.pnorm) - - lp_sparsity_loss_val = lp_norm - total_loss += config.lp_sparsity_coeff * lp_sparsity_loss_val - loss_terms[f"loss/sparsity_l{config.pnorm}_params"] = lp_sparsity_loss_val.item() - # # --- Placeholder Losses (Mimicking run_spd.optimize structure) --- # masked_recon_loss_val = None # if config.masked_recon_coeff is not None and config.masked_recon_coeff > 0: From 96ae9548c9732350cadccda74fa2bb70698262b7 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 8 Apr 2025 05:57:45 +0000 Subject: [PATCH 63/73] Minor comment and config clean --- spd/experiments/lm/lm_config.yaml | 8 +++--- spd/experiments/lm/lm_decomposition.py | 40 +------------------------- 2 files changed, 5 insertions(+), 43 deletions(-) diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml index 2daa3f5..0d319c4 100644 --- a/spd/experiments/lm/lm_config.yaml +++ b/spd/experiments/lm/lm_config.yaml @@ -7,7 +7,7 @@ wandb_run_name_prefix: "" # Prefix for generated run name # --- General --- seed: 0 unit_norm_matrices: false # Whether to enforce unit norm on A matrices (not typically used here) -m: 3 # Rank of the decomposition / number of components per layer +m: 100 # Rank of the decomposition / number of components per layer # --- Loss Coefficients --- # Set coeffs to null if the loss shouldn't be computed @@ -29,7 +29,7 @@ n_gate_hidden_neurons: null # Not applicable as there are no gates currently # --- Training --- batch_size: 2 # Adjust based on GPU memory steps: 10_000 # Total training steps -lr: 1e-4 # Learning rate +lr: 1e-3 # Learning rate lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential) lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup lr_exponential_halflife: null # Required if lr_schedule is exponential @@ -38,7 +38,7 @@ init_from_target_model: false # Not implemented/applicable for this setup # --- Logging & Saving --- image_freq: 1000 # Frequency for generating/logging plots print_freq: 100 # Frequency for printing logs to console -save_freq: 2000 # Frequency for saving checkpoints +save_freq: 10_000 # Frequency for saving checkpoints image_on_first_step: true # Whether to log plots at step 0 # --- Task Specific --- @@ -50,7 +50,7 @@ task_config: dataset_name: "lennart-finke/SimpleStories" # HuggingFace dataset name dataset_split: "train" # Dataset split to use # List of fnmatch patterns for nn.Linear modules to decompose - target_module_patterns: ["transformer.h.*.mlp.gate_proj", "transformer.h.*.mlp.up_proj"] + target_module_patterns: ["transformer.h.2.mlp.gate_proj"] # Example: Decompose only gate_proj: ["transformer.h.*.mlp.gate_proj"] # Example: Decompose gate_proj and up_proj: ["transformer.h.*.mlp.gate_proj", "transformer.h.*.mlp.up_proj"] # Example: Decompose all MLP layers: ["transformer.h.*.mlp.*_proj"] \ No newline at end of file diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 7a5e639..7839990 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -136,7 +136,6 @@ def calc_layerwise_recon_loss_lm( target_out: Float[Tensor, "batch pos vocab"], ) -> Float[Tensor, ""]: """Calculate the recon loss when augmenting the model one (masked) component at a time.""" - n_modified_components = len(masks[0]) total_loss = torch.tensor(0.0, device=device) for mask_info in masks: for module_name in mask_info: @@ -145,6 +144,7 @@ def calc_layerwise_recon_loss_lm( ) loss = calc_recon_mse_lm(modified_out, target_out) total_loss += loss + n_modified_components = len(masks[0]) return total_loss / (n_modified_components * len(masks)) @@ -324,44 +324,6 @@ def optimize_lm( total_loss += config.out_recon_coeff * recon_loss loss_terms["loss/reconstruction"] = recon_loss.item() - # # --- Placeholder Losses (Mimicking run_spd.optimize structure) --- - # masked_recon_loss_val = None - # if config.masked_recon_coeff is not None and config.masked_recon_coeff > 0: - # logger.warning("masked_recon_loss requires mask calculation implementation.") - # # TODO: Calculate masked_recon_loss_val using masks - # # e.g., component_logits_masked = model.forward_with_components(..., masks=masks) - # # masked_recon_loss_val = calc_recon_mse_lm(component_logits_masked, target_logits) - # loss_terms["loss/masked_reconstruction"] = None # Or 0.0 if calculated - - # act_recon_loss_val = None - # if config.act_recon_coeff is not None and config.act_recon_coeff > 0: - # logger.warning("act_recon_loss requires mask and target activation calculation.") - # # TODO: Implement act_recon_loss_val - # loss_terms["loss/activation_reconstruction"] = None - - # random_masks_loss_val = None - # if config.random_mask_recon_coeff is not None and config.random_mask_recon_coeff > 0: - # logger.warning("random_masks_loss requires mask calculation implementation.") - # # TODO: Implement random_masks_loss_val - # loss_terms["loss/random_mask_reconstruction"] = None - - # layerwise_recon_loss_val = None - # if config.layerwise_recon_coeff is not None and config.layerwise_recon_coeff > 0: - # logger.warning("layerwise_recon_loss requires mask calculation and layerwise hooks.") - # # TODO: Implement layerwise_recon_loss_val - # loss_terms["loss/layerwise_reconstruction"] = None - - # layerwise_random_recon_loss_val = None - # if ( - # config.layerwise_random_recon_coeff is not None - # and config.layerwise_random_recon_coeff > 0 - # ): - # logger.warning( - # "layerwise_random_recon_loss requires mask calculation and layerwise hooks." - # ) - # # TODO: Implement layerwise_random_recon_loss_val - # loss_terms["loss/layerwise_random_reconstruction"] = None - log_data["loss/total"] = total_loss.item() log_data.update(loss_terms) From cb12ed1e7a52581ae86a90fd19cbdc8b053f8981 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 10 Apr 2025 01:52:31 +0000 Subject: [PATCH 64/73] Make components a submodule of SSModel and update model loading --- spd/experiments/lm/component_viz.py | 64 ++++++++++++++++++++++++++ spd/experiments/lm/lm_config.yaml | 28 ++++++++--- spd/experiments/lm/lm_decomposition.py | 38 +++++++++------ spd/experiments/lm/models.py | 32 ++++++------- 4 files changed, 123 insertions(+), 39 deletions(-) create mode 100644 spd/experiments/lm/component_viz.py diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py new file mode 100644 index 0000000..46da671 --- /dev/null +++ b/spd/experiments/lm/component_viz.py @@ -0,0 +1,64 @@ +""" +Vizualises the components of the model. +""" + +from pathlib import Path + +import torch +import wandb +from simple_stories_train.models.llama import Llama +from simple_stories_train.models.model_configs import MODEL_CONFIGS +from wandb.apis.public import Run + +from spd.configs import Config, LMTaskConfig +from spd.experiments.lm.models import SSModel +from spd.types import WANDB_PATH_PREFIX, ModelPath +from spd.wandb_utils import ( + download_wandb_file, + fetch_latest_wandb_checkpoint, + fetch_wandb_run_dir, +) + + +def load_model( + path: ModelPath, +) -> tuple[SSModel, Config]: + if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX): + wandb_path = path.removeprefix(WANDB_PATH_PREFIX) + api = wandb.Api() + run: Run = api.run(wandb_path) + checkpoint_file_obj = fetch_latest_wandb_checkpoint(run, prefix="model") + + run_dir = fetch_wandb_run_dir(run.id) + checkpoint_path = download_wandb_file(run, run_dir, checkpoint_file_obj.name) + + else: + checkpoint_path = Path(path) # local path + + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + config = Config(**checkpoint_dict["config"]) + + assert isinstance(config.task_config, LMTaskConfig) + model_config_dict = MODEL_CONFIGS[config.task_config.model_size] + model_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}" + llama_model = Llama.from_pretrained(model_path, model_config_dict) + + ss_model = SSModel( + llama_model=llama_model, + target_module_patterns=config.task_config.target_module_patterns, + m=config.m, + n_gate_hidden_neurons=config.n_gate_hidden_neurons, + ) + ss_model.load_state_dict(checkpoint_dict["model"]) + return ss_model, config + + +def main(path: ModelPath) -> None: + ss_model, config = load_model(path) + print(ss_model) + print(config) + + +if __name__ == "__main__": + path = "wandb:spd-lm/runs/60ycavou" + main(path) diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml index 0d319c4..6973703 100644 --- a/spd/experiments/lm/lm_config.yaml +++ b/spd/experiments/lm/lm_config.yaml @@ -1,20 +1,20 @@ # --- WandB --- -# wandb_project: spd-lm # Project name for Weights & Biases -wandb_project: null # Project name for Weights & Biases +wandb_project: spd-lm +# wandb_project: null # Project name for Weights & Biases wandb_run_name: null # Set specific run name (optional, otherwise generated) wandb_run_name_prefix: "" # Prefix for generated run name # --- General --- seed: 0 unit_norm_matrices: false # Whether to enforce unit norm on A matrices (not typically used here) -m: 100 # Rank of the decomposition / number of components per layer +m: 10000 # Rank of the decomposition / number of components per layer # --- Loss Coefficients --- # Set coeffs to null if the loss shouldn't be computed param_match_coeff: 1.0 out_recon_coeff: 0.0 # Reconstruction loss based on output logits (MSE) lp_sparsity_coeff: 1e-3 # Coefficient for Lp sparsity loss (applied to component params A & B) -pnorm: 1.0 # p-value for the Lp sparsity norm (1.0 for L1) +pnorm: 2.0 # p-value for the Lp sparsity norm layerwise_random_recon_coeff: 1 # Layer-wise reconstruction loss with random masks # Placeholder losses (set coeffs to null as they require mask calculation implementation) @@ -27,8 +27,8 @@ n_random_masks: 1 # Number of random masks if random_mask_recon_coeff is used n_gate_hidden_neurons: null # Not applicable as there are no gates currently # --- Training --- -batch_size: 2 # Adjust based on GPU memory -steps: 10_000 # Total training steps +batch_size: 8 # Adjust based on GPU memory +steps: 500 # Total training steps lr: 1e-3 # Learning rate lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential) lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup @@ -53,4 +53,18 @@ task_config: target_module_patterns: ["transformer.h.2.mlp.gate_proj"] # Example: Decompose only gate_proj: ["transformer.h.*.mlp.gate_proj"] # Example: Decompose gate_proj and up_proj: ["transformer.h.*.mlp.gate_proj", "transformer.h.*.mlp.up_proj"] - # Example: Decompose all MLP layers: ["transformer.h.*.mlp.*_proj"] \ No newline at end of file + # Example: Decompose all MLP layers: ["transformer.h.*.mlp.*_proj"] + +# Config details for the target model taken from https://github.com/danbraunai/simple_stories_train/blob/main/simple_stories_train/models/model_configs.py#L54 + # "1.25M": LlamaConfig( + # block_size=512, + # vocab_size=4096, + # n_layer=4, + # n_head=4, + # n_embd=128, + # n_intermediate=128 * 4 * 2 // 3 = 341, + # rotary_dim=128 // 4 = 32, + # n_ctx=512, + # n_key_value_heads=2, + # flash_attention=True, + # ), \ No newline at end of file diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 7839990..bdce2c6 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -132,15 +132,20 @@ def calc_layerwise_recon_loss_lm( model: SSModel, batch: Float[Tensor, "batch pos"], device: str, + components: dict[str, LinearComponentWithBias], masks: list[dict[str, Float[Tensor, "batch pos m"]]], target_out: Float[Tensor, "batch pos vocab"], ) -> Float[Tensor, ""]: """Calculate the recon loss when augmenting the model one (masked) component at a time.""" total_loss = torch.tensor(0.0, device=device) for mask_info in masks: - for module_name in mask_info: + for component_name, component in components.items(): + module_name = component_name.replace("-", ".") modified_out, _ = model.forward_with_component( - batch, module_name=module_name, mask=mask_info[module_name] + batch, + module_name=module_name, + component=component, + mask=mask_info.get(component_name, None), ) loss = calc_recon_mse_lm(modified_out, target_out) total_loss += loss @@ -183,10 +188,13 @@ def optimize_lm( gates: dict[str, Gate | GateMLP] = { k.removeprefix("gates.").replace("-", "."): v for k, v in model.gates.items() } # type: ignore + components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in model.components.items() + } # type: ignore component_params = [] param_names_to_optimize = [] - for name, component in model.components.items(): + for name, component in components.items(): component_params.extend(list(component.parameters())) param_names_to_optimize.extend( [f"{name}.{p_name}" for p_name, _ in component.named_parameters()] @@ -204,7 +212,7 @@ def optimize_lm( logger.info(f"Base LR scheduler created: {config.lr_schedule}") n_params = 0 - for module_name in model.components: + for module_name in components: weight = model.model.get_parameter(module_name + ".weight") n_params += weight.numel() @@ -238,9 +246,9 @@ def optimize_lm( batch = next(data_iter)["input_ids"].to(device) (target_out, _), pre_weight_acts = model.forward_with_pre_forward_cache_hooks( - batch, module_names=list(model.components.keys()) + batch, module_names=list(components.keys()) ) - As = {module_name: v.linear_component.A for module_name, v in model.components.items()} + As = {module_name: v.linear_component.A for module_name, v in components.items()} target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # attributions = calc_grad_attributions( @@ -265,7 +273,7 @@ def optimize_lm( ####### param match loss ####### param_match_loss_val = calc_param_match_loss_lm( - components=model.components, + components=components, target_model=model.model, n_params=n_params, device=device, @@ -279,6 +287,7 @@ def optimize_lm( model=model, batch=batch, device=device, + components=components, masks=[masks], target_out=target_out, ) @@ -294,6 +303,7 @@ def optimize_lm( model=model, batch=batch, device=device, + components=components, masks=layerwise_random_masks, target_out=target_out, ) @@ -314,7 +324,9 @@ def optimize_lm( target_logits = target_logits.detach() # Get component logits - component_logits, _ = model.forward_with_components(batch, masks=masks) + component_logits, _ = model.forward_with_components( + batch, components=components, masks=masks + ) assert component_logits.shape == target_logits.shape, ( f"Shape mismatch: {component_logits.shape} vs {target_logits.shape}" @@ -349,7 +361,7 @@ def optimize_lm( with torch.no_grad(): fig_dict = plot_results_fn( model=model, # Pass the SSModel wrapper - components=model.components, + components=components, step=step, out_dir=out_dir, device=device, @@ -371,13 +383,9 @@ def optimize_lm( (config.save_freq is not None and step % config.save_freq == 0 and step > 0) or step == config.steps ) and out_dir is not None: - checkpoint_dir = out_dir / "checkpoints" - checkpoint_dir.mkdir(exist_ok=True) - checkpoint_path = checkpoint_dir / f"components_step_{step}.pt" - # Save only component state dicts - component_state_dicts = {n: c.state_dict() for n, c in model.components.items()} + checkpoint_path = out_dir / f"model_{step}.pth" save_payload = { - "components": component_state_dicts, + "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step, "config": config.model_dump(mode="json"), diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py index 9caa41b..64b847d 100644 --- a/spd/experiments/lm/models.py +++ b/spd/experiments/lm/models.py @@ -49,7 +49,6 @@ def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponent return LinearComponentWithBias(linear_component, bias) -# class SSModel(HookedRootModule): class SSModel(nn.Module): """Wrapper around a llama model from SimpleStories for running SPD.""" @@ -72,17 +71,11 @@ def __init__( if n_gate_hidden_neurons is not None: gate_kwargs["n_gate_hidden_neurons"] = n_gate_hidden_neurons - self.gates = nn.ModuleDict() - for name in self.components: - self.gates[name.replace(".", "-")] = gate_class(**gate_kwargs) + self.gates = nn.ModuleDict({name: gate_class(**gate_kwargs) for name in self.components}) - # self.setup() - - def create_target_components( - self, target_module_patterns: list[str], m: int - ) -> dict[str, LinearComponentWithBias]: + def create_target_components(self, target_module_patterns: list[str], m: int) -> nn.ModuleDict: """Create target components for the model.""" - components = {} + components: dict[str, LinearComponentWithBias] = {} for name, module in self.model.named_modules(): for pattern in target_module_patterns: if fnmatch.fnmatch(name, pattern): @@ -90,9 +83,10 @@ def create_target_components( f"Module '{name}' matched pattern '{pattern}' but is not nn.Linear. " f"Found type: {type(module)}" ) - components[name] = nn_linear_to_components(module, m=m) + # Replace "." with "-" in the name to avoid issues with module dict keys + components[name.replace(".", "-")] = nn_linear_to_components(module, m=m) break - return components + return nn.ModuleDict(components) def to(self, *args: Any, **kwargs: Any) -> "SSModel": """Move the model and components to a device.""" @@ -111,14 +105,15 @@ def forward_with_component( self, *args: Any, module_name: str, + component: LinearComponentWithBias, mask: Float[Tensor, "batch pos m"] | None = None, **kwargs: Any, ) -> Any: """Forward pass with a single component replacement.""" + # Note that module_name uses "." separators but self.components use "-" separators old_module = self.model.get_submodule(module_name) assert old_module is not None - component = self.components[module_name] self.model.set_submodule(module_name, component) if mask is not None: component.mask = mask @@ -131,19 +126,22 @@ def forward_with_component( def forward_with_components( self, *args: Any, + components: dict[str, LinearComponentWithBias], masks: dict[str, Float[Tensor, "batch pos m"]] | None = None, **kwargs: Any, ) -> Any: """Forward pass with temporary component replacement.""" + # Note that components and masks uses "-" separators old_modules = {} - for module_name, component in self.components.items(): + for component_name, component in components.items(): + module_name = component_name.replace("-", ".") + # component: LinearComponentWithBias = self.components[module_name.replace(".", "-")] old_module = self.model.get_submodule(module_name) assert old_module is not None old_modules[module_name] = old_module if masks is not None: - assert module_name in masks, f"Mask for {module_name} not found" - component.mask = masks[module_name] + component.mask = masks.get(component_name, None) self.model.set_submodule(module_name, component) out = self.model(*args, **kwargs) @@ -153,7 +151,7 @@ def forward_with_components( self.model.set_submodule(module_name, old_module) # Remove the masks attribute from the components - for component in self.components.values(): + for component in components.values(): component.mask = None return out From d3a7c76cd22271155d82695aaf078571c2311543 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 10 Apr 2025 04:43:31 +0000 Subject: [PATCH 65/73] Add SSModel.from_pretrained() --- spd/experiments/lm/component_viz.py | 51 ++--------------------------- spd/experiments/lm/models.py | 42 ++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 49 deletions(-) diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py index 46da671..d0a8571 100644 --- a/spd/experiments/lm/component_viz.py +++ b/spd/experiments/lm/component_viz.py @@ -2,59 +2,12 @@ Vizualises the components of the model. """ -from pathlib import Path - -import torch -import wandb -from simple_stories_train.models.llama import Llama -from simple_stories_train.models.model_configs import MODEL_CONFIGS -from wandb.apis.public import Run - -from spd.configs import Config, LMTaskConfig from spd.experiments.lm.models import SSModel -from spd.types import WANDB_PATH_PREFIX, ModelPath -from spd.wandb_utils import ( - download_wandb_file, - fetch_latest_wandb_checkpoint, - fetch_wandb_run_dir, -) - - -def load_model( - path: ModelPath, -) -> tuple[SSModel, Config]: - if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX): - wandb_path = path.removeprefix(WANDB_PATH_PREFIX) - api = wandb.Api() - run: Run = api.run(wandb_path) - checkpoint_file_obj = fetch_latest_wandb_checkpoint(run, prefix="model") - - run_dir = fetch_wandb_run_dir(run.id) - checkpoint_path = download_wandb_file(run, run_dir, checkpoint_file_obj.name) - - else: - checkpoint_path = Path(path) # local path - - checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") - config = Config(**checkpoint_dict["config"]) - - assert isinstance(config.task_config, LMTaskConfig) - model_config_dict = MODEL_CONFIGS[config.task_config.model_size] - model_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}" - llama_model = Llama.from_pretrained(model_path, model_config_dict) - - ss_model = SSModel( - llama_model=llama_model, - target_module_patterns=config.task_config.target_module_patterns, - m=config.m, - n_gate_hidden_neurons=config.n_gate_hidden_neurons, - ) - ss_model.load_state_dict(checkpoint_dict["model"]) - return ss_model, config +from spd.types import ModelPath def main(path: ModelPath) -> None: - ss_model, config = load_model(path) + ss_model, config = SSModel.from_pretrained(path) print(ss_model) print(config) diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py index 64b847d..8148b4e 100644 --- a/spd/experiments/lm/models.py +++ b/spd/experiments/lm/models.py @@ -4,15 +4,26 @@ import fnmatch from functools import partial +from pathlib import Path from typing import Any import torch import torch.nn as nn +import wandb from jaxtyping import Float from simple_stories_train.models.llama import Llama +from simple_stories_train.models.model_configs import MODEL_CONFIGS from torch import Tensor +from wandb.apis.public import Run +from spd.configs import Config, LMTaskConfig from spd.models.components import Gate, GateMLP, LinearComponent +from spd.types import WANDB_PATH_PREFIX, ModelPath +from spd.wandb_utils import ( + download_wandb_file, + fetch_latest_wandb_checkpoint, + fetch_wandb_run_dir, +) class LinearComponentWithBias(nn.Module): @@ -184,3 +195,34 @@ def cache_hook(module: nn.Module, input: tuple[Tensor, ...], param_name: str) -> handle.remove() return out, cache + + @classmethod + def from_pretrained(cls, path: ModelPath) -> tuple["SSModel", Config]: + if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX): + wandb_path = path.removeprefix(WANDB_PATH_PREFIX) + api = wandb.Api() + run: Run = api.run(wandb_path) + checkpoint_file_obj = fetch_latest_wandb_checkpoint(run, prefix="model") + + run_dir = fetch_wandb_run_dir(run.id) + checkpoint_path = download_wandb_file(run, run_dir, checkpoint_file_obj.name) + + else: + checkpoint_path = Path(path) # local path + + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + config = Config(**checkpoint_dict["config"]) + + assert isinstance(config.task_config, LMTaskConfig) + model_config_dict = MODEL_CONFIGS[config.task_config.model_size] + model_path = f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}" + llama_model = Llama.from_pretrained(model_path, model_config_dict) + + ss_model = SSModel( + llama_model=llama_model, + target_module_patterns=config.task_config.target_module_patterns, + m=config.m, + n_gate_hidden_neurons=config.n_gate_hidden_neurons, + ) + ss_model.load_state_dict(checkpoint_dict["model"]) + return ss_model, config From 1425354292bc2269b811bd3a71b8a66bbeef3222 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 10 Apr 2025 08:37:34 +0000 Subject: [PATCH 66/73] WIP: Fix download with weights_only=True --- spd/experiments/lm/component_viz.py | 125 ++++++++++++++++++++++++- spd/experiments/lm/lm_config.yaml | 2 +- spd/experiments/lm/lm_decomposition.py | 22 ++--- spd/experiments/lm/models.py | 58 ++++++++++-- 4 files changed, 178 insertions(+), 29 deletions(-) diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py index d0a8571..3e1ef71 100644 --- a/spd/experiments/lm/component_viz.py +++ b/spd/experiments/lm/component_viz.py @@ -2,16 +2,133 @@ Vizualises the components of the model. """ -from spd.experiments.lm.models import SSModel +from pathlib import Path + +import torch +from jaxtyping import Float +from matplotlib import pyplot as plt +from simple_stories_train.dataloaders import DatasetConfig, create_data_loader +from torch import Tensor +from torch.utils.data import DataLoader +from tqdm import tqdm + +from spd.experiments.lm.lm_decomposition import calc_component_acts +from spd.experiments.lm.models import LinearComponentWithBias, SSModel +from spd.log import logger +from spd.models.components import Gate, GateMLP +from spd.run_spd import calc_masks from spd.types import ModelPath +def component_activation_statistics( + model: SSModel, + dataloader: DataLoader[Float[Tensor, "batch pos"]], + n_steps: int, + device: str, + out_dir: Path, +) -> None: + """Get the number and strength of the masks over the full dataset.""" + # We used "-" instead of "." as module names can't have "." in them + gates: dict[str, Gate | GateMLP] = { + k.removeprefix("gates.").replace("-", "."): v for k, v in model.gates.items() + } # type: ignore + components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in model.components.items() + } # type: ignore + + n_tokens = {module_name.replace("-", "."): 0 for module_name in components} + total_n_active_components = {module_name.replace("-", "."): 0 for module_name in components} + component_activation_counts = {module_name.replace("-", "."): 0 for module_name in components} + data_iter = iter(dataloader) + for _ in tqdm(range(n_steps), ncols=0): + # --- Get Batch --- # + try: + batch = next(data_iter)["input_ids"].to(device) + except StopIteration: + logger.warning("Dataloader exhausted, resetting iterator.") + data_iter = iter(dataloader) + batch = next(data_iter)["input_ids"].to(device) + + _, pre_weight_acts = model.forward_with_pre_forward_cache_hooks( + batch, module_names=list(components.keys()) + ) + As = {module_name: v.linear_component.A for module_name, v in components.items()} + + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + + masks, relud_masks = calc_masks( + gates=gates, + target_component_acts=target_component_acts, + attributions=None, + detach_inputs=False, + ) + for module_name, mask in masks.items(): + assert mask.ndim == 3 # (batch_size, pos, m) + n_tokens[module_name] += mask.shape[0] * mask.shape[1] + # Count the number of components that are active at all + active_components = mask > 0 + total_n_active_components[module_name] += active_components.sum() + component_activation_counts[module_name] += active_components.sum(dim=(0, 1)) + + # Show the mean number of components + mean_n_active_components_per_token: dict[str, float] = { + module_name: (total_n_active_components[module_name] / n_tokens[module_name]).item() + for module_name in components + } + mean_component_activation_counts: dict[str, Float[Tensor, " m"]] = { + module_name: component_activation_counts[module_name] / n_tokens[module_name] + for module_name in components + } + + print(mean_n_active_components_per_token) + print(mean_component_activation_counts) + for module_name, counts in mean_component_activation_counts.items(): + name = module_name.replace(".", "-") + plt.hist(counts.detach().cpu().numpy(), bins=100) + plt.savefig(out_dir / f"{name}_mean_component_activation_counts.png") + print("Saved plot to", out_dir / f"{name}_mean_component_activation_counts.png") + print("...") + + def main(path: ModelPath) -> None: - ss_model, config = SSModel.from_pretrained(path) - print(ss_model) + device = "cuda" if torch.cuda.is_available() else "cpu" + ss_model, config, checkpoint_dict = SSModel.from_pretrained(path) + ss_model.to(device) + + out_dir = Path(checkpoint_dict["out_dir"]) + out_dir.mkdir(parents=True, exist_ok=True) + + dataset_config = DatasetConfig( + name=config.task_config.dataset_name, + tokenizer_file_path=None, + hf_tokenizer_path=f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}", + split=config.task_config.dataset_split, + n_ctx=config.task_config.max_seq_len, + is_tokenized=False, + streaming=False, + column_name="story", + ) + + dataloader, tokenizer = create_data_loader( + dataset_config=dataset_config, + batch_size=config.batch_size, + buffer_size=config.task_config.buffer_size, + global_seed=config.seed, + ddp_rank=0, + ddp_world_size=1, + ) + # print(ss_model) print(config) + component_activation_statistics( + model=ss_model, + dataloader=dataloader, + n_steps=100, + device=device, + out_dir=out_dir, + ) + if __name__ == "__main__": - path = "wandb:spd-lm/runs/60ycavou" + path = "wandb:spd-lm/runs/ttpa8pl5" main(path) diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml index 6973703..bfa7898 100644 --- a/spd/experiments/lm/lm_config.yaml +++ b/spd/experiments/lm/lm_config.yaml @@ -28,7 +28,7 @@ n_gate_hidden_neurons: null # Not applicable as there are no gates currently # --- Training --- batch_size: 8 # Adjust based on GPU memory -steps: 500 # Total training steps +steps: 1000 # Total training steps lr: 1e-3 # Learning rate lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential) lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index bdce2c6..25420a8 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -20,10 +20,7 @@ from tqdm import tqdm from spd.configs import Config, LMTaskConfig -from spd.experiments.lm.models import ( - LinearComponentWithBias, - SSModel, -) +from spd.experiments.lm.models import LinearComponentWithBias, SSModel from spd.log import logger from spd.models.components import Gate, GateMLP from spd.run_spd import _calc_param_mse, calc_masks, calc_random_masks, get_common_run_name_suffix @@ -383,17 +380,14 @@ def optimize_lm( (config.save_freq is not None and step % config.save_freq == 0 and step > 0) or step == config.steps ) and out_dir is not None: - checkpoint_path = out_dir / f"model_{step}.pth" - save_payload = { - "model": model.state_dict(), - "optimizer": optimizer.state_dict(), - "step": step, - "config": config.model_dump(mode="json"), - } - torch.save(save_payload, checkpoint_path) - logger.info(f"Saved checkpoint to {checkpoint_path}") + torch.save(model.state_dict(), out_dir / f"model_{step}.pth") + torch.save(optimizer.state_dict(), out_dir / f"optimizer_{step}.pth") + logger.info(f"Saved model, optimizer, and out_dir to {out_dir}") if config.wandb_project: - wandb.save(str(checkpoint_path), base_path=str(out_dir), policy="now") + wandb.save(str(out_dir / f"model_{step}.pth"), base_path=str(out_dir), policy="now") + wandb.save( + str(out_dir / f"optimizer_{step}.pth"), base_path=str(out_dir), policy="now" + ) # --- Backward Pass & Optimize --- # # Skip gradient step if we are at the last step (last step just for plotting and logging) diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py index 8148b4e..e7f76e5 100644 --- a/spd/experiments/lm/models.py +++ b/spd/experiments/lm/models.py @@ -10,7 +10,9 @@ import torch import torch.nn as nn import wandb +import yaml from jaxtyping import Float +from pydantic import BaseModel from simple_stories_train.models.llama import Llama from simple_stories_train.models.model_configs import MODEL_CONFIGS from torch import Tensor @@ -60,6 +62,14 @@ def nn_linear_to_components(linear_module: nn.Linear, m: int) -> LinearComponent return LinearComponentWithBias(linear_component, bias) +class SSModelPaths(BaseModel): + """Paths to output files from a SSModel training run.""" + + model: Path + optimizer: Path + config: Path + + class SSModel(nn.Module): """Wrapper around a llama model from SimpleStories for running SPD.""" @@ -196,22 +206,50 @@ def cache_hook(module: nn.Module, input: tuple[Tensor, ...], param_name: str) -> return out, cache + @staticmethod + def _download_wandb_files(wandb_project_run_id: str) -> SSModelPaths: + """Download the relevant files from a wandb run.""" + api = wandb.Api() + run: Run = api.run(wandb_project_run_id) + + checkpoint = fetch_latest_wandb_checkpoint(run, prefix="model") + + run_dir = fetch_wandb_run_dir(run.id) + + final_config_path = download_wandb_file(run, run_dir, "final_config.yaml") + checkpoint_path = download_wandb_file(run, run_dir, checkpoint.name) + + # Get the step number from the path + step = int(Path(checkpoint_path).stem.split("_")[-1]) + + return SSModelPaths( + model=checkpoint_path, + optimizer=download_wandb_file(run, run_dir, f"optimizer_{step}.pth"), + config=final_config_path, + ) + @classmethod - def from_pretrained(cls, path: ModelPath) -> tuple["SSModel", Config]: + def from_pretrained(cls, path: ModelPath) -> tuple["SSModel", Config, Path]: if isinstance(path, str) and path.startswith(WANDB_PATH_PREFIX): wandb_path = path.removeprefix(WANDB_PATH_PREFIX) api = wandb.Api() run: Run = api.run(wandb_path) - checkpoint_file_obj = fetch_latest_wandb_checkpoint(run, prefix="model") - - run_dir = fetch_wandb_run_dir(run.id) - checkpoint_path = download_wandb_file(run, run_dir, checkpoint_file_obj.name) + paths = cls._download_wandb_files(run.id) + out_dir = fetch_wandb_run_dir(run.id) else: - checkpoint_path = Path(path) # local path + # Get the step number from the path + step = int(Path(path).stem.split("_")[-1]) + paths = SSModelPaths( + model=Path(path), + optimizer=Path(path).parent / f"optimizer_{step}.pth", + config=Path(path).parent / "final_config.yaml", + ) + out_dir = Path(path).parent - checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") - config = Config(**checkpoint_dict["config"]) + model_weights = torch.load(paths.model, map_location="cpu", weights_only=True) + with open(paths.config) as f: + config = Config(**yaml.safe_load(f)) assert isinstance(config.task_config, LMTaskConfig) model_config_dict = MODEL_CONFIGS[config.task_config.model_size] @@ -224,5 +262,5 @@ def from_pretrained(cls, path: ModelPath) -> tuple["SSModel", Config]: m=config.m, n_gate_hidden_neurons=config.n_gate_hidden_neurons, ) - ss_model.load_state_dict(checkpoint_dict["model"]) - return ss_model, config + ss_model.load_state_dict(model_weights) + return ss_model, config, out_dir From 7a23520941923678f360a41ff3bc86bc2cb22374 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 14 Apr 2025 08:38:26 +0000 Subject: [PATCH 67/73] Calc mask l0 for lms --- spd/experiments/lm/lm_decomposition.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index bdce2c6..9d22eb9 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -26,7 +26,13 @@ ) from spd.log import logger from spd.models.components import Gate, GateMLP -from spd.run_spd import _calc_param_mse, calc_masks, calc_random_masks, get_common_run_name_suffix +from spd.run_spd import ( + _calc_param_mse, + calc_mask_l_zero, + calc_masks, + calc_random_masks, + get_common_run_name_suffix, +) from spd.utils import ( get_device, get_lr_schedule_fn, @@ -349,6 +355,9 @@ def optimize_lm( tqdm.write(f"{name}: {value:.7f}") if config.wandb_project: + mask_l_zero = calc_mask_l_zero(masks=masks) + for layer_name, layer_mask_l_zero in mask_l_zero.items(): + log_data[f"mask_l0/{layer_name}"] = layer_mask_l_zero wandb.log(log_data, step=step) # --- Plotting --- # From 0103c0c9a61ed437e11a875f737bfe8e36ebb5af Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 14 Apr 2025 08:49:30 +0000 Subject: [PATCH 68/73] Fix missing GateMLP type references --- spd/experiments/resid_mlp/resid_mlp_decomposition.py | 10 +++++----- spd/experiments/tms/tms_decomposition.py | 4 ++-- spd/plotting.py | 12 ++++++------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 7d1b72b..c8e09c7 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -23,7 +23,7 @@ ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset from spd.log import logger -from spd.models.components import Gate +from spd.models.components import Gate, GateMLP from spd.plotting import plot_AB_matrices, plot_mask_vals from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( @@ -107,7 +107,7 @@ def resid_mlp_plot_results_fn( out_dir: Path | None, device: str, config: Config, - gates: dict[str, Gate], + gates: dict[str, Gate | GateMLP], masks: dict[str, Float[Tensor, "batch_size m"]] | None, **_, ) -> dict[str, plt.Figure]: @@ -161,9 +161,9 @@ def init_spd_model_from_target_model( for i in range(target_model.config.n_layers): # For mlp_in, m must equal d_mlp # TODO: This is broken, we shouldn't need m=d_mlp for this function. - assert ( - m == target_model.config.d_mlp or m == target_model.config.d_embed - ), "m must be equal to d_mlp or d_embed" + assert m == target_model.config.d_mlp or m == target_model.config.d_embed, ( + "m must be equal to d_mlp or d_embed" + ) # For mlp_in: A = target weights, B = identity model.layers[i].mlp_in.A.data[:] = target_model.layers[i].mlp_in.weight.data.clone() diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 70ff530..2d70087 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -20,7 +20,7 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSModelConfig, TMSSPDModel, TMSSPDModelConfig from spd.log import logger -from spd.models.components import Gate +from spd.models.components import Gate, GateMLP from spd.plotting import plot_AB_matrices, plot_mask_vals from spd.run_spd import get_common_run_name_suffix, optimize from spd.utils import ( @@ -54,7 +54,7 @@ def make_plots( out_dir: Path, device: str, config: Config, - gates: dict[str, Gate], + gates: dict[str, Gate | GateMLP], masks: dict[str, Float[Tensor, "batch n_instances m"]], batch: Float[Tensor, "batch n_instances n_features"], **_, diff --git a/spd/plotting.py b/spd/plotting.py index 1eb1adf..91d438a 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -10,7 +10,7 @@ from spd.hooks import HookedRootModule from spd.models.base import SPDModel -from spd.models.components import Gate +from spd.models.components import Gate, GateMLP from spd.module_utils import collect_nested_module_attrs from spd.run_spd import calc_component_acts, calc_masks @@ -48,7 +48,7 @@ def permute_to_identity( def plot_mask_vals( model: SPDModel, target_model: HookedRootModule, - gates: dict[str, Gate], + gates: dict[str, Gate | GateMLP], device: str, input_magnitude: float, ) -> tuple[plt.Figure, dict[str, Float[Tensor, "n_instances m"]]]: @@ -146,7 +146,7 @@ def plot_subnetwork_attributions_statistics( ax.set_ylabel("Count") ax.set_xlabel("Number of active subnetworks") - ax.set_title(f"Instance {i+1}") + ax.set_title(f"Instance {i + 1}") # Add value annotations on top of each bar for bar in bars: @@ -212,9 +212,9 @@ def plot_AB_matrices( # Verify that A and B matrices have matching names A_names = set(As.keys()) B_names = set(Bs.keys()) - assert ( - A_names == B_names - ), f"A and B matrices must have matching names. Found A: {A_names}, B: {B_names}" + assert A_names == B_names, ( + f"A and B matrices must have matching names. Found A: {A_names}, B: {B_names}" + ) n_layers = len(As) From 60fa3cc816cc9e68f9f694fcecb9745aa339593b Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 17 Apr 2025 00:50:05 +0000 Subject: [PATCH 69/73] Update component_viz for new model format --- spd/experiments/lm/component_viz.py | 31 ++++++++++++++--------------- spd/experiments/lm/models.py | 3 ++- spd/experiments/lm/play.py | 22 ++++++++++++-------- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py index 3e1ef71..d776357 100644 --- a/spd/experiments/lm/component_viz.py +++ b/spd/experiments/lm/component_viz.py @@ -12,6 +12,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm +from spd.configs import LMTaskConfig from spd.experiments.lm.lm_decomposition import calc_component_acts from spd.experiments.lm.models import LinearComponentWithBias, SSModel from spd.log import logger @@ -38,16 +39,14 @@ def component_activation_statistics( n_tokens = {module_name.replace("-", "."): 0 for module_name in components} total_n_active_components = {module_name.replace("-", "."): 0 for module_name in components} - component_activation_counts = {module_name.replace("-", "."): 0 for module_name in components} + component_activation_counts = { + module_name.replace("-", "."): torch.zeros(model.m, device=device) + for module_name in components + } data_iter = iter(dataloader) for _ in tqdm(range(n_steps), ncols=0): # --- Get Batch --- # - try: - batch = next(data_iter)["input_ids"].to(device) - except StopIteration: - logger.warning("Dataloader exhausted, resetting iterator.") - data_iter = iter(dataloader) - batch = next(data_iter)["input_ids"].to(device) + batch = next(data_iter)["input_ids"].to(device) _, pre_weight_acts = model.forward_with_pre_forward_cache_hooks( batch, module_names=list(components.keys()) @@ -67,12 +66,12 @@ def component_activation_statistics( n_tokens[module_name] += mask.shape[0] * mask.shape[1] # Count the number of components that are active at all active_components = mask > 0 - total_n_active_components[module_name] += active_components.sum() + total_n_active_components[module_name] += int(active_components.sum().item()) component_activation_counts[module_name] += active_components.sum(dim=(0, 1)) # Show the mean number of components mean_n_active_components_per_token: dict[str, float] = { - module_name: (total_n_active_components[module_name] / n_tokens[module_name]).item() + module_name: (total_n_active_components[module_name] / n_tokens[module_name]) for module_name in components } mean_component_activation_counts: dict[str, Float[Tensor, " m"]] = { @@ -80,24 +79,24 @@ def component_activation_statistics( for module_name in components } - print(mean_n_active_components_per_token) - print(mean_component_activation_counts) + logger.info(f"n_components: {model.m}") + logger.info(f"mean_n_active_components_per_token: {mean_n_active_components_per_token}") + logger.info(f"mean_component_activation_counts: {mean_component_activation_counts}") for module_name, counts in mean_component_activation_counts.items(): name = module_name.replace(".", "-") plt.hist(counts.detach().cpu().numpy(), bins=100) plt.savefig(out_dir / f"{name}_mean_component_activation_counts.png") print("Saved plot to", out_dir / f"{name}_mean_component_activation_counts.png") - print("...") def main(path: ModelPath) -> None: device = "cuda" if torch.cuda.is_available() else "cpu" - ss_model, config, checkpoint_dict = SSModel.from_pretrained(path) + ss_model, config, checkpoint_path = SSModel.from_pretrained(path) ss_model.to(device) - out_dir = Path(checkpoint_dict["out_dir"]) - out_dir.mkdir(parents=True, exist_ok=True) + out_dir = checkpoint_path + assert isinstance(config.task_config, LMTaskConfig) dataset_config = DatasetConfig( name=config.task_config.dataset_name, tokenizer_file_path=None, @@ -130,5 +129,5 @@ def main(path: ModelPath) -> None: if __name__ == "__main__": - path = "wandb:spd-lm/runs/ttpa8pl5" + path = "wandb:spd-lm/runs/fuff71ef" main(path) diff --git a/spd/experiments/lm/models.py b/spd/experiments/lm/models.py index e7f76e5..eafc9b4 100644 --- a/spd/experiments/lm/models.py +++ b/spd/experiments/lm/models.py @@ -82,6 +82,7 @@ def __init__( ): super().__init__() self.model = llama_model + self.m = m self.components = self.create_target_components( target_module_patterns=target_module_patterns, m=m ) @@ -234,7 +235,7 @@ def from_pretrained(cls, path: ModelPath) -> tuple["SSModel", Config, Path]: wandb_path = path.removeprefix(WANDB_PATH_PREFIX) api = wandb.Api() run: Run = api.run(wandb_path) - paths = cls._download_wandb_files(run.id) + paths = cls._download_wandb_files(wandb_path) out_dir = fetch_wandb_run_dir(run.id) else: diff --git a/spd/experiments/lm/play.py b/spd/experiments/lm/play.py index f83fabc..1890dee 100644 --- a/spd/experiments/lm/play.py +++ b/spd/experiments/lm/play.py @@ -4,7 +4,7 @@ from simple_stories_train.models.model_configs import MODEL_CONFIGS from transformers import AutoTokenizer -from spd.experiments.lm.models import SSModel, create_target_components +from spd.experiments.lm.models import LinearComponentWithBias, SSModel # %% # Select the model size you want to use @@ -20,14 +20,20 @@ model.eval() # %% -ss_model = SSModel(model) - -m = 17 -# Create components with rank=10 (adjust as needed) -gate_proj_components = create_target_components( - model, rank=m, target_module_patterns=["model.transformer.h.*.mlp.gate_proj"] +ss_model = SSModel( + llama_model=model, + target_module_patterns=["model.transformer.h.*.mlp.gate_proj"], + m=17, + n_gate_hidden_neurons=None, ) +# # Create components with rank=10 (adjust as needed) +# gate_proj_components = create_target_components( +# model, rank=m, target_module_patterns=["model.transformer.h.*.mlp.gate_proj"] +# ) +gate_proj_components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items() +} # type: ignore # %% # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False) @@ -61,7 +67,7 @@ # Create some dummy masks masks = { - f"model.transformer.h.{i}.mlp.gate_proj": torch.randn(1, input_ids.shape[-1], m) + f"model.transformer.h.{i}.mlp.gate_proj": torch.randn(1, input_ids.shape[-1], ss_model.m) for i in range(len(model.transformer.h)) } From 04bcbe177562eccf3059337b77a4d36654445ba0 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 17 Apr 2025 02:37:53 +0000 Subject: [PATCH 70/73] Plot mean components during apd run --- spd/configs.py | 4 +- spd/experiments/lm/component_viz.py | 82 ++++++++++++++++++-------- spd/experiments/lm/lm_config.yaml | 12 ++-- spd/experiments/lm/lm_decomposition.py | 81 ++++++++++++++++--------- 4 files changed, 121 insertions(+), 58 deletions(-) diff --git a/spd/configs.py b/spd/configs.py index 773223c..d039054 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -43,7 +43,9 @@ class LMTaskConfig(BaseModel): max_seq_len: PositiveInt = 512 buffer_size: PositiveInt = 1000 dataset_name: str = "lennart-finke/SimpleStories" - dataset_split: str = "train" + train_data_split: str = "train" + eval_data_split: str = "test" + n_eval_steps: PositiveInt = 100 # List of fnmatch patterns for nn.Linear modules to decompose target_module_patterns: list[str] = ["transformer.h.*.mlp.*_proj"] diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py index d776357..33c7c5a 100644 --- a/spd/experiments/lm/component_viz.py +++ b/spd/experiments/lm/component_viz.py @@ -2,7 +2,7 @@ Vizualises the components of the model. """ -from pathlib import Path +import math import torch from jaxtyping import Float @@ -10,14 +10,12 @@ from simple_stories_train.dataloaders import DatasetConfig, create_data_loader from torch import Tensor from torch.utils.data import DataLoader -from tqdm import tqdm from spd.configs import LMTaskConfig -from spd.experiments.lm.lm_decomposition import calc_component_acts from spd.experiments.lm.models import LinearComponentWithBias, SSModel from spd.log import logger from spd.models.components import Gate, GateMLP -from spd.run_spd import calc_masks +from spd.run_spd import calc_component_acts, calc_masks from spd.types import ModelPath @@ -26,8 +24,7 @@ def component_activation_statistics( dataloader: DataLoader[Float[Tensor, "batch pos"]], n_steps: int, device: str, - out_dir: Path, -) -> None: +) -> tuple[dict[str, float], dict[str, Float[Tensor, " m"]]]: """Get the number and strength of the masks over the full dataset.""" # We used "-" instead of "." as module names can't have "." in them gates: dict[str, Gate | GateMLP] = { @@ -44,7 +41,7 @@ def component_activation_statistics( for module_name in components } data_iter = iter(dataloader) - for _ in tqdm(range(n_steps), ncols=0): + for _ in range(n_steps): # --- Get Batch --- # batch = next(data_iter)["input_ids"].to(device) @@ -53,7 +50,7 @@ def component_activation_statistics( ) As = {module_name: v.linear_component.A for module_name, v in components.items()} - target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore masks, relud_masks = calc_masks( gates=gates, @@ -79,14 +76,40 @@ def component_activation_statistics( for module_name in components } - logger.info(f"n_components: {model.m}") - logger.info(f"mean_n_active_components_per_token: {mean_n_active_components_per_token}") - logger.info(f"mean_component_activation_counts: {mean_component_activation_counts}") - for module_name, counts in mean_component_activation_counts.items(): - name = module_name.replace(".", "-") - plt.hist(counts.detach().cpu().numpy(), bins=100) - plt.savefig(out_dir / f"{name}_mean_component_activation_counts.png") - print("Saved plot to", out_dir / f"{name}_mean_component_activation_counts.png") + return mean_n_active_components_per_token, mean_component_activation_counts + + +def plot_mean_component_activation_counts( + mean_component_activation_counts: dict[str, Float[Tensor, " m"]], +) -> plt.Figure: + """Plots the mean activation counts for each component module in a grid.""" + n_modules = len(mean_component_activation_counts) + max_cols = 6 + n_cols = min(n_modules, max_cols) + # Calculate the number of rows needed, rounding up + n_rows = math.ceil(n_modules / n_cols) + + # Create a figure with the calculated number of rows and columns + fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False) + # Ensure axs is always a 2D array for consistent indexing, even if n_modules is 1 + axs = axs.flatten() # Flatten the axes array for easy iteration + + # Iterate through modules and plot each histogram on its corresponding axis + for i, (module_name, counts) in enumerate(mean_component_activation_counts.items()): + ax = axs[i] + ax.hist(counts.detach().cpu().numpy(), bins=100) + ax.set_title(module_name) # Add module name as title to each subplot + ax.set_xlabel("Mean Activation Count") + ax.set_ylabel("Frequency") + + # Hide any unused subplots if the grid isn't perfectly filled + for i in range(n_modules, n_rows * n_cols): + axs[i].axis("off") + + # Adjust layout to prevent overlapping titles/labels + fig.tight_layout() + + return fig def main(path: ModelPath) -> None: @@ -101,7 +124,7 @@ def main(path: ModelPath) -> None: name=config.task_config.dataset_name, tokenizer_file_path=None, hf_tokenizer_path=f"chandan-sreedhara/SimpleStories-{config.task_config.model_size}", - split=config.task_config.dataset_split, + split=config.task_config.train_data_split, n_ctx=config.task_config.max_seq_len, is_tokenized=False, streaming=False, @@ -119,15 +142,26 @@ def main(path: ModelPath) -> None: # print(ss_model) print(config) - component_activation_statistics( - model=ss_model, - dataloader=dataloader, - n_steps=100, - device=device, - out_dir=out_dir, + mean_n_active_components_per_token, mean_component_activation_counts = ( + component_activation_statistics( + model=ss_model, + dataloader=dataloader, + n_steps=100, + device=device, + ) + ) + logger.info(f"n_components: {ss_model.m}") + logger.info(f"mean_n_active_components_per_token: {mean_n_active_components_per_token}") + logger.info(f"mean_component_activation_counts: {mean_component_activation_counts}") + fig = plot_mean_component_activation_counts( + mean_component_activation_counts=mean_component_activation_counts, ) + # Save the entire figure once + save_path = out_dir / "modules_mean_component_activation_counts.png" + fig.savefig(save_path) + logger.info(f"Saved combined plot to {str(save_path)}") if __name__ == "__main__": - path = "wandb:spd-lm/runs/fuff71ef" + path = "wandb:spd-lm/runs/hmjepm9b" main(path) diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml index bfa7898..04d98d6 100644 --- a/spd/experiments/lm/lm_config.yaml +++ b/spd/experiments/lm/lm_config.yaml @@ -13,7 +13,7 @@ m: 10000 # Rank of the decomposition / number of components per layer # Set coeffs to null if the loss shouldn't be computed param_match_coeff: 1.0 out_recon_coeff: 0.0 # Reconstruction loss based on output logits (MSE) -lp_sparsity_coeff: 1e-3 # Coefficient for Lp sparsity loss (applied to component params A & B) +lp_sparsity_coeff: 1e-1 # Coefficient for Lp sparsity loss (applied to component params A & B) pnorm: 2.0 # p-value for the Lp sparsity norm layerwise_random_recon_coeff: 1 # Layer-wise reconstruction loss with random masks @@ -27,8 +27,8 @@ n_random_masks: 1 # Number of random masks if random_mask_recon_coeff is used n_gate_hidden_neurons: null # Not applicable as there are no gates currently # --- Training --- -batch_size: 8 # Adjust based on GPU memory -steps: 1000 # Total training steps +batch_size: 4 # Adjust based on GPU memory +steps: 10_000 # Total training steps lr: 1e-3 # Learning rate lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential) lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup @@ -48,9 +48,11 @@ task_config: max_seq_len: 512 # Maximum sequence length for truncation/padding buffer_size: 1000 # Buffer size for streaming dataset shuffling dataset_name: "lennart-finke/SimpleStories" # HuggingFace dataset name - dataset_split: "train" # Dataset split to use + train_data_split: "train" # Dataset split to use + eval_data_split: "test" # Dataset split to use + n_eval_steps: 100 # Number of evaluation steps # List of fnmatch patterns for nn.Linear modules to decompose - target_module_patterns: ["transformer.h.2.mlp.gate_proj"] + target_module_patterns: ["transformer.h.0.mlp.gate_proj"] # Example: Decompose only gate_proj: ["transformer.h.*.mlp.gate_proj"] # Example: Decompose gate_proj and up_proj: ["transformer.h.*.mlp.gate_proj", "transformer.h.*.mlp.up_proj"] # Example: Decompose all MLP layers: ["transformer.h.*.mlp.*_proj"] diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 299dd63..dde635e 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -15,16 +15,21 @@ from simple_stories_train.dataloaders import DatasetConfig, create_data_loader from simple_stories_train.models.llama import Llama from simple_stories_train.models.model_configs import MODEL_CONFIGS -from torch import Tensor, nn +from torch import Tensor from torch.utils.data import DataLoader from tqdm import tqdm from spd.configs import Config, LMTaskConfig +from spd.experiments.lm.component_viz import ( + component_activation_statistics, + plot_mean_component_activation_counts, +) from spd.experiments.lm.models import LinearComponentWithBias, SSModel from spd.log import logger from spd.models.components import Gate, GateMLP from spd.run_spd import ( _calc_param_mse, + calc_component_acts, calc_mask_l_zero, calc_masks, calc_random_masks, @@ -75,24 +80,6 @@ def lm_plot_results_fn( return fig_dict -def calc_component_acts( - pre_weight_acts: dict[str, Float[Tensor, "... d_in"]], - As: dict[str, Float[nn.Parameter, "d_in m"]], -) -> dict[str, Float[Tensor, "batch m"]]: - """Calculate the component acts for each layer. I.e. (pre_weight_acts @ A). - - Args: - pre_weight_acts: The activations before each layer in the target model. - As: The A matrix at each layer. - """ - component_acts = {} - for param_name in pre_weight_acts: - component_acts[param_name] = einops.einsum( - pre_weight_acts[param_name], As[param_name], "... d_in, ... d_in m -> ... m" - ) - return component_acts - - def calc_recon_mse_lm( out1: Float[Tensor, "batch pos vocab"], out2: Float[Tensor, "batch pos vocab"], @@ -181,7 +168,9 @@ def optimize_lm( model: SSModel, config: Config, device: str, - dataloader: DataLoader[tuple[Float[Tensor, "batch pos"], Float[Tensor, "batch pos"]]], + train_loader: DataLoader[Float[Tensor, "batch pos"]], + eval_loader: DataLoader[Float[Tensor, "batch pos"]], + n_eval_steps: int, plot_results_fn: Callable[..., dict[str, plt.Figure]], out_dir: Path | None, ) -> None: @@ -220,7 +209,7 @@ def optimize_lm( n_params += weight.numel() log_data = {} - data_iter = iter(dataloader) + data_iter = iter(train_loader) # Use tqdm directly in the loop, iterate one extra step for final logging/plotting/saving for step in tqdm(range(config.steps + 1), ncols=0): @@ -245,7 +234,7 @@ def optimize_lm( batch = next(data_iter)["input_ids"].to(device) except StopIteration: logger.warning("Dataloader exhausted, resetting iterator.") - data_iter = iter(dataloader) + data_iter = iter(train_loader) batch = next(data_iter)["input_ids"].to(device) (target_out, _), pre_weight_acts = model.forward_with_pre_forward_cache_hooks( @@ -253,7 +242,7 @@ def optimize_lm( ) As = {module_name: v.linear_component.A for module_name, v in components.items()} - target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore # attributions = calc_grad_attributions( # target_out=target_out, # pre_weight_acts=pre_weight_acts, @@ -351,10 +340,16 @@ def optimize_lm( if value is not None: tqdm.write(f"{name}: {value:.7f}") + mean_n_active_components_per_token = component_activation_statistics( + model=model, dataloader=eval_loader, n_steps=n_eval_steps, device=device + )[0] + tqdm.write(f"Mean n active components per token: {mean_n_active_components_per_token}") + if config.wandb_project: mask_l_zero = calc_mask_l_zero(masks=masks) for layer_name, layer_mask_l_zero in mask_l_zero.items(): log_data[f"mask_l0/{layer_name}"] = layer_mask_l_zero + log_data["mean_n_active_components_per_token"] = mean_n_active_components_per_token wandb.log(log_data, step=step) # --- Plotting --- # @@ -374,6 +369,14 @@ def optimize_lm( config=config, # Add any other necessary args for plotting like tokenizer, sample text? ) + mean_component_activation_counts = component_activation_statistics( + model=model, dataloader=eval_loader, n_steps=n_eval_steps, device=device + )[1] + fig_dict["mean_component_activation_counts"] = ( + plot_mean_component_activation_counts( + mean_component_activation_counts=mean_component_activation_counts, + ) + ) if config.wandb_project: wandb.log( {k: wandb.Image(v) for k, v in fig_dict.items()}, @@ -472,25 +475,45 @@ def main( # --- Load Data --- # logger.info("Loading dataset...") - dataset_config = DatasetConfig( + train_data_config = DatasetConfig( name=config.task_config.dataset_name, tokenizer_file_path=None, hf_tokenizer_path=model_path, - split=config.task_config.dataset_split, + split=config.task_config.train_data_split, n_ctx=config.task_config.max_seq_len, is_tokenized=False, streaming=False, column_name="story", ) - dataloader, tokenizer = create_data_loader( - dataset_config=dataset_config, + train_loader, tokenizer = create_data_loader( + dataset_config=train_data_config, + batch_size=config.batch_size, + buffer_size=config.task_config.buffer_size, + global_seed=config.seed, + ddp_rank=0, + ddp_world_size=1, + ) + + eval_data_config = DatasetConfig( + name=config.task_config.dataset_name, + tokenizer_file_path=None, + hf_tokenizer_path=model_path, + split=config.task_config.eval_data_split, + n_ctx=config.task_config.max_seq_len, + is_tokenized=False, + streaming=False, + column_name="story", + ) + eval_loader, _ = create_data_loader( + dataset_config=eval_data_config, batch_size=config.batch_size, buffer_size=config.task_config.buffer_size, global_seed=config.seed, ddp_rank=0, ddp_world_size=1, ) + logger.info("Dataset and tokenizer loaded.") logger.info("Freezing target model parameters...") @@ -503,7 +526,9 @@ def main( model=ss_model, config=config, device=device, - dataloader=dataloader, + train_loader=train_loader, + eval_loader=eval_loader, + n_eval_steps=config.task_config.n_eval_steps, out_dir=out_dir, plot_results_fn=lm_plot_results_fn, ) From c2bdda1c9c42c27890509c95e151403b16fb67e9 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Thu, 17 Apr 2025 04:23:47 +0000 Subject: [PATCH 71/73] Re-organise wandb logging --- spd/experiments/lm/lm_decomposition.py | 6 ++++-- spd/experiments/lm/play.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index dde635e..41f5ce6 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -348,8 +348,10 @@ def optimize_lm( if config.wandb_project: mask_l_zero = calc_mask_l_zero(masks=masks) for layer_name, layer_mask_l_zero in mask_l_zero.items(): - log_data[f"mask_l0/{layer_name}"] = layer_mask_l_zero - log_data["mean_n_active_components_per_token"] = mean_n_active_components_per_token + log_data[f"{layer_name}/mask_l0"] = layer_mask_l_zero + log_data[f"{layer_name}/mean_n_active_components_per_token"] = ( + mean_n_active_components_per_token[layer_name] + ) wandb.log(log_data, step=step) # --- Plotting --- # diff --git a/spd/experiments/lm/play.py b/spd/experiments/lm/play.py index 1890dee..87164f7 100644 --- a/spd/experiments/lm/play.py +++ b/spd/experiments/lm/play.py @@ -52,6 +52,19 @@ # IMPORTANT: Set correct EOS token ID (not the default from tokenizer) eos_token_id = 1 +# %% + +# # Generate text +# with torch.no_grad(): +# output_ids = model.generate( +# idx=input_ids, max_new_tokens=20, temperature=0.7, top_k=40, eos_token_id=eos_token_id +# ) + +# # Decode output +# output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) +# print(f"Generated text:\n{output_text}") + + # %% # logits, _ = ss_model.forward(input_ids, components=gate_proj_components) From 072085e8db49189b7a92436b9db6ea555593402e Mon Sep 17 00:00:00 2001 From: Dan <150014290+danbraunai-apollo@users.noreply.github.com> Date: Tue, 22 Apr 2025 15:56:01 +1000 Subject: [PATCH 72/73] Add streamlit dashboard for lm (#2) * WIP: Add dashboard * Create base_cache_dir if it doesn't exist * Functional dashboard * Add simple-stories-train and datasets to pyproject.toml --- .vscode/launch.json | 12 + pyproject.toml | 4 + spd/experiments/lm/app.py | 405 ++++++++++++++++++++++++++++ spd/experiments/lm/component_viz.py | 2 +- spd/experiments/lm/lm_config.yaml | 4 +- spd/wandb_utils.py | 1 + 6 files changed, 425 insertions(+), 3 deletions(-) create mode 100644 spd/experiments/lm/app.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 89875d2..a753e33 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -48,6 +48,18 @@ "env": { "PYDEVD_DISABLE_FILE_VALIDATION": "1" } + }, + { + "name": "lm streamlit", + "type": "debugpy", + "request": "launch", + "module": "streamlit", + "args": [ + "run", + "${workspaceFolder}/spd/experiments/lm/app.py", + "--server.port", + "2000" + ] } ] } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 49e6f8a..4031cfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,10 @@ dependencies = [ "python-dotenv", "wandb<=0.17.7", # due to https://github.com/wandb/wandb/issues/8248 "sympy", + "streamlit", + "streamlit-antd-components", + "datasets", + "simple-stories-train" ] [project.optional-dependencies] diff --git a/spd/experiments/lm/app.py b/spd/experiments/lm/app.py new file mode 100644 index 0000000..d880198 --- /dev/null +++ b/spd/experiments/lm/app.py @@ -0,0 +1,405 @@ +""" +To run this app, run the following command: + +```bash + streamlit run spd/experiments/lm/app.py -- --model_path "wandb:spd-lm/runs/151bsctx" +``` +""" + +import argparse +import html +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from typing import Any + +import streamlit as st +import torch +from datasets import load_dataset +from jaxtyping import Float, Int +from simple_stories_train.dataloaders import DatasetConfig +from torch import Tensor +from transformers import AutoTokenizer + +from spd.configs import Config, LMTaskConfig +from spd.experiments.lm.models import LinearComponentWithBias, SSModel +from spd.log import logger +from spd.models.components import Gate, GateMLP +from spd.run_spd import calc_component_acts, calc_masks +from spd.types import ModelPath + +DEFAULT_MODEL_PATH: ModelPath = "wandb:spd-lm/runs/151bsctx" + + +# ----------------------------------------------------------- +# Dataclass holding everything the app needs +# ----------------------------------------------------------- +@dataclass(frozen=True) +class AppData: + model: SSModel + tokenizer: AutoTokenizer + config: Config + dataloader_iter_fn: Callable[[], Iterator[dict[str, Any]]] + gates: dict[str, Gate | GateMLP] + components: dict[str, LinearComponentWithBias] + target_layer_names: list[str] + device: str + + +# --- Initialization and Data Loading --- +@st.cache_resource(show_spinner="Loading model and data...") +def initialize(model_path: ModelPath) -> AppData: + """ + Loads the model, tokenizer, config, and evaluation dataloader. + Cached by Streamlit based on the model_path. + """ + device = "cpu" # Use CPU for the Streamlit app + logger.info(f"Initializing app with model: {model_path} on device: {device}") + ss_model, config, _ = SSModel.from_pretrained(model_path) + ss_model.to(device) + ss_model.eval() + + task_config = config.task_config + assert isinstance(task_config, LMTaskConfig), "Task config must be LMTaskConfig for this app." + + # Derive tokenizer path (adjust if stored differently) + tokenizer_path = f"chandan-sreedhara/SimpleStories-{task_config.model_size}" + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + add_bos_token=False, + unk_token="[UNK]", + eos_token="[EOS]", + bos_token=None, + ) + + # Create eval dataloader config + eval_data_config = DatasetConfig( + name=task_config.dataset_name, + tokenizer_file_path=None, + hf_tokenizer_path=tokenizer_path, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=False, + streaming=False, # Non-streaming might be simpler for iterator reset + column_name="story", + ) + + # Create the dataloader iterator + def create_dataloader_iter() -> Iterator[dict[str, Any]]: + """ + Returns a *new* iterator each time it is called. + Each element is a dict with: + - "text": the raw document text + - "input_ids": Int[Tensor, "1 seq_len"] + - "offset_mapping": list[tuple[int, int]] + """ + logger.info("Creating new dataloader iterator.") + + # Stream the HF dataset split + dataset = load_dataset( + eval_data_config.name, + streaming=eval_data_config.streaming, + split=eval_data_config.split, + trust_remote_code=False, + ) + + dataset = dataset.with_format("torch") + + text_column = eval_data_config.column_name + + def tokenize_and_prepare(example: dict[str, Any]) -> dict[str, Any]: + original_text: str = example[text_column] + + tokenized = tokenizer( + original_text, + return_tensors="pt", + return_offsets_mapping=True, + truncation=True, + max_length=task_config.max_seq_len, + padding=False, + ) + + input_ids: Int[Tensor, "1 seq_len"] = tokenized["input_ids"] + if input_ids.dim() == 1: # Ensure 2‑D [1, seq_len] + input_ids = input_ids.unsqueeze(0) + + # HF returns offset_mapping as a list per sequence; batch size is 1 + offset_mapping: list[tuple[int, int]] = tokenized["offset_mapping"][0].tolist() + + return { + "text": original_text, + "input_ids": input_ids, + "offset_mapping": offset_mapping, + } + + # Map over the streaming dataset and return an iterator + return map(tokenize_and_prepare, iter(dataset)) + + # Extract components and gates + gates: dict[str, Gate | GateMLP] = { + k.removeprefix("gates.").replace("-", "."): v for k, v in ss_model.gates.items() + } # type: ignore[reportAssignmentType] + components: dict[str, LinearComponentWithBias] = { + k.removeprefix("components.").replace("-", "."): v for k, v in ss_model.components.items() + } # type: ignore[reportAssignmentType] + target_layer_names = sorted(list(components.keys())) + + logger.info(f"Initialization complete for {model_path}.") + return AppData( + model=ss_model, + tokenizer=tokenizer, + config=config, + dataloader_iter_fn=create_dataloader_iter, + gates=gates, + components=components, + target_layer_names=target_layer_names, + device=device, + ) + + +# ----------------------------------------------------------- +# Utility: render the prompt with faint token outlines +# ----------------------------------------------------------- +def render_prompt_with_tokens( + *, + raw_text: str, + offset_mapping: list[tuple[int, int]], + selected_idx: int | None, +) -> None: + """ + Renders `raw_text` inside Streamlit, wrapping each token span with a thin + border. The currently‑selected token receives a thicker red border. + All other tokens get a thin mid‑grey border (no background fill). + """ + html_chunks: list[str] = [] + cursor = 0 + + def esc(s: str) -> str: + return html.escape(s) + + for idx, (start, end) in enumerate(offset_mapping): + if cursor < start: + html_chunks.append(esc(raw_text[cursor:start])) + + token_substr = esc(raw_text[start:end]) + if token_substr: + is_selected = idx == selected_idx + border_style = ( + "2px solid rgb(200,0,0)" if is_selected else "0.5px solid #aaa" # all other tokens + ) + html_chunks.append( + "' + f"{token_substr}" + ) + cursor = end + + if cursor < len(raw_text): + html_chunks.append(esc(raw_text[cursor:])) + + st.markdown( + f'
{"".join(html_chunks)}
', + unsafe_allow_html=True, + ) + + +def load_next_prompt() -> None: + """Loads the next prompt, calculates masks, and prepares token data.""" + logger.info("Loading next prompt.") + app_data: AppData = st.session_state.app_data + dataloader_iter = st.session_state.dataloader_iter # Get current iterator + + try: + batch = next(dataloader_iter) + input_ids: Int[Tensor, "1 seq_len"] = batch["input_ids"].to(app_data.device) + except StopIteration: + logger.warning("Dataloader iterator exhausted. Throwing error.") + st.error("Failed to get data even after resetting dataloader.") + return + + st.session_state.current_input_ids = input_ids + + # Store the original raw prompt text + st.session_state.current_prompt_text = batch["text"] + + # Calculate activations and masks + with torch.no_grad(): + (_, _), pre_weight_acts = app_data.model.forward_with_pre_forward_cache_hooks( + input_ids, module_names=list(app_data.components.keys()) + ) + As = {module_name: v.linear_component.A for module_name, v in app_data.components.items()} + target_component_acts = calc_component_acts(pre_weight_acts=pre_weight_acts, As=As) # type: ignore[reportArgumentType] + masks, _ = calc_masks( + gates=app_data.gates, + target_component_acts=target_component_acts, + attributions=None, + detach_inputs=True, # No gradients needed + ) + st.session_state.current_masks = masks # Dict[str, Float[Tensor, "1 seq_len m"]] + + # Prepare token data for display + token_data = [] + tokenizer = app_data.tokenizer + for i, token_id in enumerate(input_ids[0]): + # Decode individual token - might differ slightly from full decode for spaces etc. + decoded_token_str = tokenizer.decode([token_id]) # type: ignore[reportAttributeAccessIssue] + token_data.append( + { + "id": token_id.item(), + "text": decoded_token_str, + "index": i, + "offset": batch["offset_mapping"][i], # (start, end) + } + ) + st.session_state.token_data = token_data + + # Reset selections + st.session_state.selected_token_index = 0 # default: first token + st.session_state.selected_layer_name = None + logger.info("Finished loading next prompt and calculating masks.") + + +# --- Main App UI --- +def run_app(args: argparse.Namespace) -> None: + """Sets up and runs the Streamlit application.""" + st.set_page_config(layout="wide") + st.title("LM Component Activation Explorer") + + # Initialize model, data, etc. (cached) + st.session_state.app_data = initialize(args.model_path) + app_data: AppData = st.session_state.app_data + st.caption(f"Model: {args.model_path}") + + # Initialize session state variables if they don't exist + if "current_prompt_text" not in st.session_state: + st.session_state.current_prompt_text = None + if "token_data" not in st.session_state: + st.session_state.token_data = None + if "current_masks" not in st.session_state: + st.session_state.current_masks = None + if "selected_token_index" not in st.session_state: + st.session_state.selected_token_index = None + if "selected_layer_name" not in st.session_state: + if app_data.target_layer_names: + st.session_state.selected_layer_name = app_data.target_layer_names[0] + else: + st.session_state.selected_layer_name = None + # Initialize the dataloader iterator in session state + if "dataloader_iter" not in st.session_state: + st.session_state.dataloader_iter = app_data.dataloader_iter_fn() + + if st.session_state.current_prompt_text is None: + load_next_prompt() + + # Sidebar container and a single expander for all interactive controls + sidebar = st.sidebar + controls_expander = sidebar.expander("Controls", expanded=True) + + # ------------------------------------------------------------------ + # Sidebar – interactive controls + # ------------------------------------------------------------------ + with controls_expander: + st.button("Load Next Prompt", on_click=load_next_prompt) + + # Render the raw prompt with faint token borders + if st.session_state.token_data and st.session_state.current_prompt_text: + # st.subheader("Prompt") + render_prompt_with_tokens( + raw_text=st.session_state.current_prompt_text, + offset_mapping=[t["offset"] for t in st.session_state.token_data], + selected_idx=st.session_state.selected_token_index, + ) + + # Sidebar slider for token selection + n_tokens = len(st.session_state.token_data) + if n_tokens > 0: + with controls_expander: + st.header("Token selector") + idx = st.slider( + "Token index", + min_value=0, + max_value=n_tokens - 1, + step=1, + key="selected_token_index", + ) + + selected_token = st.session_state.token_data[idx] + st.write(f"Selected token: {selected_token['text']} (ID: {selected_token['id']})") + + st.divider() + + # --- Token Information Area --- + if st.session_state.token_data: + idx = st.session_state.selected_token_index + # Ensure token_data is loaded before accessing + if ( + st.session_state.token_data + and idx is not None + and idx < len(st.session_state.token_data) + ): + # Layer Selection Dropdown + # Always default to the first layer if nothing is selected yet + if st.session_state.selected_layer_name is None and app_data.target_layer_names: + st.session_state.selected_layer_name = app_data.target_layer_names[0] + + with controls_expander: + st.header("Layer selector") + st.selectbox( + "Select Layer to Inspect:", + options=app_data.target_layer_names, + key="selected_layer_name", + ) + + # Display Layer-Specific Info if a layer is selected + if st.session_state.selected_layer_name: + layer_name = st.session_state.selected_layer_name + logger.debug(f"Displaying info for token {idx}, layer {layer_name}") + + if st.session_state.current_masks is None: + st.warning("Masks not calculated yet. Please load a prompt.") + return + + layer_mask_tensor: Float[Tensor, "1 seq_len m"] = st.session_state.current_masks[ + layer_name + ] + token_mask: Float[Tensor, " m"] = layer_mask_tensor[0, idx, :] + + # Find active components (mask > 0) + active_indices_layer: Int[Tensor, " n_active"] = torch.where(token_mask > 0)[0] + n_active_layer = len(active_indices_layer) + + st.metric(f"Active Components in {layer_name}", n_active_layer) + + st.subheader("Active Component Indices") + if n_active_layer > 0: + # Convert to NumPy array and reshape to a column vector (N x 1) + active_indices_np = active_indices_layer.cpu().numpy().reshape(-1, 1) + # Pass the NumPy array directly and configure the column header + st.dataframe(active_indices_np, height=300, use_container_width=False) + else: + st.write("No active components for this token in this layer.") + + # Extensibility Placeholder + st.subheader("Additional Layer/Token Analysis") + st.write( + "Future figures and analyses for this specific layer and token will appear here." + ) + else: + # Handle case where selected_token_index might be invalid after data reload + st.warning("Selected token index is out of bounds. Please select a token again.") + st.session_state.selected_token_index = None # Reset selection + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Streamlit app to explore LM component activations." + ) + parser.add_argument( + "--model_path", + type=str, + default=DEFAULT_MODEL_PATH, + help=f"Path or W&B reference to the trained SSModel. Default: {DEFAULT_MODEL_PATH}", + ) + args = parser.parse_args() + + run_app(args) diff --git a/spd/experiments/lm/component_viz.py b/spd/experiments/lm/component_viz.py index 33c7c5a..1258495 100644 --- a/spd/experiments/lm/component_viz.py +++ b/spd/experiments/lm/component_viz.py @@ -163,5 +163,5 @@ def main(path: ModelPath) -> None: if __name__ == "__main__": - path = "wandb:spd-lm/runs/hmjepm9b" + path = "wandb:spd-lm/runs/151bsctx" main(path) diff --git a/spd/experiments/lm/lm_config.yaml b/spd/experiments/lm/lm_config.yaml index 04d98d6..0b1260f 100644 --- a/spd/experiments/lm/lm_config.yaml +++ b/spd/experiments/lm/lm_config.yaml @@ -28,7 +28,7 @@ n_gate_hidden_neurons: null # Not applicable as there are no gates currently # --- Training --- batch_size: 4 # Adjust based on GPU memory -steps: 10_000 # Total training steps +steps: 1_000 # Total training steps lr: 1e-3 # Learning rate lr_schedule: cosine # LR schedule type (constant, linear, cosine, exponential) lr_warmup_pct: 0.01 # Percentage of steps for linear LR warmup @@ -38,7 +38,7 @@ init_from_target_model: false # Not implemented/applicable for this setup # --- Logging & Saving --- image_freq: 1000 # Frequency for generating/logging plots print_freq: 100 # Frequency for printing logs to console -save_freq: 10_000 # Frequency for saving checkpoints +save_freq: 1_000 # Frequency for saving checkpoints image_on_first_step: true # Whether to log plots at step 0 # --- Task Specific --- diff --git a/spd/wandb_utils.py b/spd/wandb_utils.py index 73d2d67..9f451a7 100644 --- a/spd/wandb_utils.py +++ b/spd/wandb_utils.py @@ -45,6 +45,7 @@ def fetch_wandb_run_dir(run_id: str) -> Path: """ # Default to REPO_ROOT/wandb if SPD_CACHE_DIR not set base_cache_dir = Path(os.environ.get("SPD_CACHE_DIR", REPO_ROOT / "wandb")) + base_cache_dir.mkdir(parents=True, exist_ok=True) # Set default wandb_run_dir wandb_run_dir = base_cache_dir / run_id / "files" From 04a2138dd3d4e7a49c2411ff53631016ba20ea1f Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Tue, 22 Apr 2025 06:00:43 +0000 Subject: [PATCH 73/73] Remove unused set_nested_module_attr function --- spd/module_utils.py | 15 --------------- tests/test_module_utils.py | 15 +-------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/spd/module_utils.py b/spd/module_utils.py index 394c7ef..2fe0666 100644 --- a/spd/module_utils.py +++ b/spd/module_utils.py @@ -28,21 +28,6 @@ def get_nested_module_attr(module: nn.Module, access_string: str) -> Any: return mod -def set_nested_module_attr(module: nn.Module, access_string: str, value: Any) -> None: - """Set a specific attribute by its full, path-like name. - - Args: - module: The module to set the attribute on. - access_string: The full name of the nested attribute to set, with each object separated by periods (e.g. "linear1.A"). - """ - names = access_string.split(".") - try: - mod = reduce(getattr, names[:-1], module) - except AttributeError as err: - raise AttributeError(f"{module} does not have nested attribute {access_string}") from err - setattr(mod, names[-1], value) - - def collect_nested_module_attrs( module: nn.Module, attr_name: str, diff --git a/tests/test_module_utils.py b/tests/test_module_utils.py index cc2711c..a59643d 100644 --- a/tests/test_module_utils.py +++ b/tests/test_module_utils.py @@ -1,7 +1,6 @@ -import torch from torch import nn -from spd.module_utils import get_nested_module_attr, set_nested_module_attr +from spd.module_utils import get_nested_module_attr def test_get_nested_module_attr(): @@ -14,15 +13,3 @@ def __init__(self): module = TestModule() assert get_nested_module_attr(module, "linear1.weight.data").shape == (10, 10) assert get_nested_module_attr(module, "linear2.weight.data").shape == (10, 10) - - -def test_set_nested_module_attr(): - class TestModule(nn.Module): - def __init__(self): - super().__init__() - self.linear1 = nn.Linear(10, 10) - self.linear2 = nn.Linear(10, 10) - - module = TestModule() - set_nested_module_attr(module, "linear1.weight.data", torch.randn(10, 5)) - assert module.linear1.weight.data.shape == (10, 5)