diff --git a/auto_circuit/prune_algos/mask_gradient.py b/auto_circuit/prune_algos/mask_gradient.py index d6a1acc..e62f425 100644 --- a/auto_circuit/prune_algos/mask_gradient.py +++ b/auto_circuit/prune_algos/mask_gradient.py @@ -1,4 +1,4 @@ -from typing import Dict, Literal, Optional, Set +from typing import Dict, Literal, Optional, Set, Union import torch as t from torch.nn.functional import log_softmax @@ -10,6 +10,7 @@ from auto_circuit.utils.graph_utils import ( patch_mode, set_all_masks, + set_masks_at_src_idxs, train_mask_mode, ) from auto_circuit.utils.patchable_model import PatchableModel @@ -26,6 +27,7 @@ def mask_gradient_prune_scores( integrated_grad_samples: Optional[int] = None, ablation_type: AblationType = AblationType.RESAMPLE, clean_corrupt: Optional[Literal["clean", "corrupt"]] = "corrupt", + layers: Optional[Union[int, list[int]]] = None ) -> PruneScores: """ Prune scores equal to the gradient of the mask values that interpolates the edges @@ -49,6 +51,9 @@ def mask_gradient_prune_scores( ablation_type: The type of ablation to perform. clean_corrupt: Whether to use the clean or corrupt inputs to calculate the ablations. + layers: If not `None`, we iterate over each layer in the model and compute + scores separately for each. Only used if `integrated_grad_samples` is not `None`. Follows + [Marks et al., 2024](https://arxiv.org/abs/2403.19647) Returns: An ordering of the edges by importance to the task. Importance is equal to the @@ -60,8 +65,9 @@ def mask_gradient_prune_scores( [`edge_attribution_patching_prune_scores`][auto_circuit.prune_algos.edge_attribution_patching.edge_attribution_patching_prune_scores]. 
""" assert (mask_val is not None) ^ (integrated_grad_samples is not None) # ^ means XOR - model = model + assert (layers is None) or (integrated_grad_samples is not None) out_slice = model.out_slice + device = next(model.parameters()).device src_outs: Dict[BatchKey, t.Tensor] = batch_src_ablations( model, @@ -69,48 +75,61 @@ def mask_gradient_prune_scores( ablation_type=ablation_type, clean_corrupt=clean_corrupt, ) - + prune_scores = model.new_prune_scores() with train_mask_mode(model): - for sample in (ig_pbar := tqdm(range((integrated_grad_samples or 0) + 1))): - ig_pbar.set_description_str(f"Sample: {sample}") - # Interpolate the mask value if integrating gradients. Else set the value. - if integrated_grad_samples is not None: - set_all_masks(model, val=sample / integrated_grad_samples) - else: - assert mask_val is not None and integrated_grad_samples is None - set_all_masks(model, val=mask_val) - - for batch in dataloader: - patch_src_outs = src_outs[batch.key].clone().detach() - with patch_mode(model, patch_src_outs): - logits = model(batch.clean)[out_slice] - if grad_function == "logit": - token_vals = logits - elif grad_function == "prob": - token_vals = t.softmax(logits, dim=-1) - elif grad_function == "logprob": - token_vals = log_softmax(logits, dim=-1) - elif grad_function == "logit_exp": - numerator = t.exp(logits) - denominator = numerator.sum(dim=-1, keepdim=True) - token_vals = numerator / denominator.detach() - else: - raise ValueError(f"Unknown grad_function: {grad_function}") - - if answer_function == "avg_diff": - loss = -batch_avg_answer_diff(token_vals, batch) - elif answer_function == "avg_val": - loss = -batch_avg_answer_val(token_vals, batch) - elif answer_function == "mse": - loss = t.nn.functional.mse_loss(token_vals, batch.answers) - else: - raise ValueError(f"Unknown answer_function: {answer_function}") - - loss.backward() + layers_iter = layers if isinstance(layers, list) else range((layers or 0)+1) + for layer in (layer_bar := 
tqdm(layers_iter)): + layer_bar.set_description_str(f"Layer: {layer}") + src_idxs = t.tensor( + [src.src_idx for src in model.srcs if src.layer == layer], + device=device + ) + max_src_idx = t.max(src_idxs).item() if layers else 0 + score_slice = src_idxs if layers else slice(None) + for sample in (ig_pbar := tqdm(range((integrated_grad_samples or 0)+1))): + ig_pbar.set_description_str(f"Sample: {sample}") + # Interpolate the mask value if integrating gradients. Else set the value. + if integrated_grad_samples is not None: + val = sample / integrated_grad_samples + else: + val = mask_val + # Set the mask value at layer if layer. Else set the value for all layers. + if layers is not None: + set_all_masks(model, val=0) + set_masks_at_src_idxs(model, val=val, src_idxs=src_idxs) + else: + set_all_masks(model, val=val) + for batch in dataloader: + patch_src_outs = src_outs[batch.key].clone().detach() + with patch_mode(model, patch_src_outs): + logits = model(batch.clean)[out_slice] + if grad_function == "logit": + token_vals = logits + elif grad_function == "prob": + token_vals = t.softmax(logits, dim=-1) + elif grad_function == "logprob": + token_vals = log_softmax(logits, dim=-1) + elif grad_function == "logit_exp": + numerator = t.exp(logits) + denominator = numerator.sum(dim=-1, keepdim=True) + token_vals = numerator / denominator.detach() + else: + raise ValueError(f"Unknown grad_function: {grad_function}") - prune_scores: PruneScores = {} - for dest_wrapper in model.dest_wrappers: - grad = dest_wrapper.patch_mask.grad - assert grad is not None - prune_scores[dest_wrapper.module_name] = grad.detach().clone() + if answer_function == "avg_diff": + loss = -batch_avg_answer_diff(token_vals, batch) + elif answer_function == "avg_val": + loss = -batch_avg_answer_val(token_vals, batch) + elif answer_function == "mse": + loss = t.nn.functional.mse_loss(token_vals, batch.answers) + else: + raise ValueError(f"Unknown answer_function: {answer_function}") + loss.backward() + # 
set scores from layer (or all scores if layers is None) + for dest_wrapper in model.dest_wrappers: + if dest_wrapper.in_srcs.stop > max_src_idx: + grad = dest_wrapper.patch_mask.grad + assert grad is not None + scores = grad.detach().clone()[..., score_slice] + prune_scores[dest_wrapper.module_name][..., score_slice] = scores return prune_scores diff --git a/auto_circuit/utils/graph_utils.py b/auto_circuit/utils/graph_utils.py index 1b3344b..5990b44 100644 --- a/auto_circuit/utils/graph_utils.py +++ b/auto_circuit/utils/graph_utils.py @@ -337,6 +337,40 @@ def set_all_masks(model: PatchableModel, val: float) -> None: t.nn.init.constant_(wrapper.patch_mask, val) +def set_masks_at_src_idxs(model: PatchableModel, val: float, src_idxs: t.Tensor) -> None: + """ + Set all the patch masks with the specified src_idxs to the specified value. + + Args: + model: The patchable model to alter. + val: The value to set the patch masks to. + src_idxs: The src_idxs to set the patch masks at. + + Warning: + This function modifies the state of the model! This is a likely source of bugs. + """ + max_src_idx = max(src_idxs) + for wrapper in model.dest_wrappers: + if wrapper.in_srcs.stop > max_src_idx: # downstream of src + with t.no_grad(): + wrapper.patch_mask.index_fill_(dim=-1, index=src_idxs, value=val) + +def set_masks_at_layer(model: PatchableModel, val: float, layer: int) -> None: + """ + Set all the patch masks with srcs at layer to the specified value. + + Args: + model: The patchable model to alter. + val: The value to set the patch masks to. + layer: The layer to set the patch masks at. + + Warning: + This function modifies the state of the model! This is a likely source of bugs. + """ + src_idxs_at_layer = [src.src_idx for src in model.srcs if src.layer == layer] + set_masks_at_src_idxs(model, val, src_idxs_at_layer) + + @contextmanager def train_mask_mode( model: PatchableModel, requires_grad: bool = True