From ce18d48e9d4e984b6ec09ca25fc3fbee9f6a1926 Mon Sep 17 00:00:00 2001 From: oliveradk Date: Wed, 21 Aug 2024 09:24:54 -0700 Subject: [PATCH 1/3] option for integrated gradients by layer --- auto_circuit/prune_algos/mask_gradient.py | 101 +++++++++++++--------- auto_circuit/utils/graph_utils.py | 34 ++++++++ 2 files changed, 92 insertions(+), 43 deletions(-) diff --git a/auto_circuit/prune_algos/mask_gradient.py b/auto_circuit/prune_algos/mask_gradient.py index d6a1acc..101d2dd 100644 --- a/auto_circuit/prune_algos/mask_gradient.py +++ b/auto_circuit/prune_algos/mask_gradient.py @@ -1,4 +1,4 @@ -from typing import Dict, Literal, Optional, Set +from typing import Dict, Literal, Optional, Set, Union import torch as t from torch.nn.functional import log_softmax @@ -10,6 +10,7 @@ from auto_circuit.utils.graph_utils import ( patch_mode, set_all_masks, + set_masks_at_src_idxs, train_mask_mode, ) from auto_circuit.utils.patchable_model import PatchableModel @@ -26,6 +27,7 @@ def mask_gradient_prune_scores( integrated_grad_samples: Optional[int] = None, ablation_type: AblationType = AblationType.RESAMPLE, clean_corrupt: Optional[Literal["clean", "corrupt"]] = "corrupt", + layers: Optional[Union[int, list[int]]] = None ) -> PruneScores: """ Prune scores equal to the gradient of the mask values that interpolates the edges @@ -49,6 +51,9 @@ def mask_gradient_prune_scores( ablation_type: The type of ablation to perform. clean_corrupt: Whether to use the clean or corrupt inputs to calculate the ablations. + layers: If not `None`, we iterate over each layer in the model and compute + scores separately for each. Only used if `ig_samples` is not `None`. Follows + [Marks et al., 2024](https://arxiv.org/abs/2403.19647) Returns: An ordering of the edges by importance to the task. 
Importance is equal to the @@ -60,6 +65,7 @@ def mask_gradient_prune_scores( [`edge_attribution_patching_prune_scores`][auto_circuit.prune_algos.edge_attribution_patching.edge_attribution_patching_prune_scores]. """ assert (mask_val is not None) ^ (integrated_grad_samples is not None) # ^ means XOR + assert (layers is None) or (integrated_grad_samples is not None) model = model out_slice = model.out_slice @@ -69,48 +75,57 @@ def mask_gradient_prune_scores( ablation_type=ablation_type, clean_corrupt=clean_corrupt, ) - + prune_scores = model.new_prune_scores() with train_mask_mode(model): - for sample in (ig_pbar := tqdm(range((integrated_grad_samples or 0) + 1))): - ig_pbar.set_description_str(f"Sample: {sample}") - # Interpolate the mask value if integrating gradients. Else set the value. - if integrated_grad_samples is not None: - set_all_masks(model, val=sample / integrated_grad_samples) - else: - assert mask_val is not None and integrated_grad_samples is None - set_all_masks(model, val=mask_val) - - for batch in dataloader: - patch_src_outs = src_outs[batch.key].clone().detach() - with patch_mode(model, patch_src_outs): - logits = model(batch.clean)[out_slice] - if grad_function == "logit": - token_vals = logits - elif grad_function == "prob": - token_vals = t.softmax(logits, dim=-1) - elif grad_function == "logprob": - token_vals = log_softmax(logits, dim=-1) - elif grad_function == "logit_exp": - numerator = t.exp(logits) - denominator = numerator.sum(dim=-1, keepdim=True) - token_vals = numerator / denominator.detach() - else: - raise ValueError(f"Unknown grad_function: {grad_function}") - - if answer_function == "avg_diff": - loss = -batch_avg_answer_diff(token_vals, batch) - elif answer_function == "avg_val": - loss = -batch_avg_answer_val(token_vals, batch) - elif answer_function == "mse": - loss = t.nn.functional.mse_loss(token_vals, batch.answers) - else: - raise ValueError(f"Unknown answer_function: {answer_function}") - - loss.backward() + layers_iter 
= range((layers or 0)+1) if isinstance(layers, int) else layers + for layer in (layer_bar := tqdm(layers_iter)): + layer_bar.set_description_str(f"Layer: {layer}") + src_idxs = [src.src_idx for src in model.srcs if src.layer == layer] + max_src_idx = max(src_idxs) if layers else 0 + score_slice = src_idxs if layers else slice(None) + for sample in (ig_pbar := tqdm(range((integrated_grad_samples or 0)+1))): + ig_pbar.set_description_str(f"Sample: {sample}") + # Interpolate the mask value if integrating gradients. Else set the value. + if integrated_grad_samples is not None: + val = sample / integrated_grad_samples + else: + val = mask_val + # Set the mask value at layer if layer. Else set the value for all layers. + if layers is not None: + set_all_masks(model, val=0) + set_masks_at_src_idxs(model, val=val, src_idxs=src_idxs) + else: + set_all_masks(model, val=val) + for batch in dataloader: + patch_src_outs = src_outs[batch.key].clone().detach() + with patch_mode(model, patch_src_outs): + logits = model(batch.clean)[out_slice] + if grad_function == "logit": + token_vals = logits + elif grad_function == "prob": + token_vals = t.softmax(logits, dim=-1) + elif grad_function == "logprob": + token_vals = log_softmax(logits, dim=-1) + elif grad_function == "logit_exp": + numerator = t.exp(logits) + denominator = numerator.sum(dim=-1, keepdim=True) + token_vals = numerator / denominator.detach() + else: + raise ValueError(f"Unknown grad_function: {grad_function}") - prune_scores: PruneScores = {} - for dest_wrapper in model.dest_wrappers: - grad = dest_wrapper.patch_mask.grad - assert grad is not None - prune_scores[dest_wrapper.module_name] = grad.detach().clone() + if answer_function == "avg_diff": + loss = -batch_avg_answer_diff(token_vals, batch) + elif answer_function == "avg_val": + loss = -batch_avg_answer_val(token_vals, batch) + elif answer_function == "mse": + loss = t.nn.functional.mse_loss(token_vals, batch.answers) + else: + raise ValueError(f"Unknown 
answer_function: {answer_function}") + # set scores from layer (or all scores if layers is None) + for dest_wrapper in model.dest_wrappers: + if dest_wrapper.in_srcs.stop >= max_src_idx: + grad = dest_wrapper.patch_mask.grad + assert grad is not None + scores = grad.detach().clone()[..., score_slice] + prune_scores[dest_wrapper.module_name][..., score_slice] = scores return prune_scores diff --git a/auto_circuit/utils/graph_utils.py b/auto_circuit/utils/graph_utils.py index 1b3344b..5e02e44 100644 --- a/auto_circuit/utils/graph_utils.py +++ b/auto_circuit/utils/graph_utils.py @@ -337,6 +337,40 @@ def set_all_masks(model: PatchableModel, val: float) -> None: t.nn.init.constant_(wrapper.patch_mask, val) +def set_masks_at_src_idxs(model: PatchableModel, val: float, src_idxs: Collection[int]) -> None: + """ + Set all the patch masks with the specified src_idxs to the specified value. + + Args: + model: The patchable model to alter. + val: The value to set the patch masks to. + src_idxs: The src_idxs to set the patch masks at. + + Warning: + This function modifies the state of the model! This is a likely source of bugs. + """ + max_src_idx = max(src_idxs) + for wrapper in model.dest_wrappers: + if wrapper.in_srcs.stop >= max_src_idx: # downstream of src + with t.no_grad(): + wrapper.patch_mask.data[..., src_idxs] = val + +def set_masks_at_layer(model: PatchableModel, val: float, layer: int) -> None: + """ + Set all the patch masks with srcs at layer to the specified value. + + Args: + model: The patchable model to alter. + val: The value to set the patch masks to. + layer: The layer to set the patch masks at. + + Warning: + This function modifies the state of the model! This is a likely source of bugs. 
+ """ + src_idxs_at_layer = [src.src_idx for src in model.srcs if src.layer == layer] + set_masks_at_src_idxs(model, val, src_idxs_at_layer) + + @contextmanager def train_mask_mode( model: PatchableModel, requires_grad: bool = True From 3c02fcd86f66dbab2c923ecfba571b7d49bbd7a8 Mon Sep 17 00:00:00 2001 From: oliveradk Date: Wed, 21 Aug 2024 09:24:54 -0700 Subject: [PATCH 2/3] option for integrated gradients by layer --- auto_circuit/prune_algos/mask_gradient.py | 102 +++++++++++++--------- auto_circuit/utils/graph_utils.py | 34 ++++++++ 2 files changed, 93 insertions(+), 43 deletions(-) diff --git a/auto_circuit/prune_algos/mask_gradient.py b/auto_circuit/prune_algos/mask_gradient.py index d6a1acc..923b892 100644 --- a/auto_circuit/prune_algos/mask_gradient.py +++ b/auto_circuit/prune_algos/mask_gradient.py @@ -1,4 +1,4 @@ -from typing import Dict, Literal, Optional, Set +from typing import Dict, Literal, Optional, Set, Union import torch as t from torch.nn.functional import log_softmax @@ -10,6 +10,7 @@ from auto_circuit.utils.graph_utils import ( patch_mode, set_all_masks, + set_masks_at_src_idxs, train_mask_mode, ) from auto_circuit.utils.patchable_model import PatchableModel @@ -26,6 +27,7 @@ def mask_gradient_prune_scores( integrated_grad_samples: Optional[int] = None, ablation_type: AblationType = AblationType.RESAMPLE, clean_corrupt: Optional[Literal["clean", "corrupt"]] = "corrupt", + layers: Optional[Union[int, list[int]]] = None ) -> PruneScores: """ Prune scores equal to the gradient of the mask values that interpolates the edges @@ -49,6 +51,9 @@ def mask_gradient_prune_scores( ablation_type: The type of ablation to perform. clean_corrupt: Whether to use the clean or corrupt inputs to calculate the ablations. + layers: If not `None`, we iterate over each layer in the model and compute + scores separately for each. Only used if `ig_samples` is not `None`. 
Follows + [Marks et al., 2024](https://arxiv.org/abs/2403.19647) Returns: An ordering of the edges by importance to the task. Importance is equal to the @@ -60,6 +65,7 @@ def mask_gradient_prune_scores( [`edge_attribution_patching_prune_scores`][auto_circuit.prune_algos.edge_attribution_patching.edge_attribution_patching_prune_scores]. """ assert (mask_val is not None) ^ (integrated_grad_samples is not None) # ^ means XOR + assert (layers is None) or (integrated_grad_samples is not None) model = model out_slice = model.out_slice @@ -69,48 +75,58 @@ def mask_gradient_prune_scores( ablation_type=ablation_type, clean_corrupt=clean_corrupt, ) - + prune_scores = model.new_prune_scores() with train_mask_mode(model): - for sample in (ig_pbar := tqdm(range((integrated_grad_samples or 0) + 1))): - ig_pbar.set_description_str(f"Sample: {sample}") - # Interpolate the mask value if integrating gradients. Else set the value. - if integrated_grad_samples is not None: - set_all_masks(model, val=sample / integrated_grad_samples) - else: - assert mask_val is not None and integrated_grad_samples is None - set_all_masks(model, val=mask_val) - - for batch in dataloader: - patch_src_outs = src_outs[batch.key].clone().detach() - with patch_mode(model, patch_src_outs): - logits = model(batch.clean)[out_slice] - if grad_function == "logit": - token_vals = logits - elif grad_function == "prob": - token_vals = t.softmax(logits, dim=-1) - elif grad_function == "logprob": - token_vals = log_softmax(logits, dim=-1) - elif grad_function == "logit_exp": - numerator = t.exp(logits) - denominator = numerator.sum(dim=-1, keepdim=True) - token_vals = numerator / denominator.detach() - else: - raise ValueError(f"Unknown grad_function: {grad_function}") - - if answer_function == "avg_diff": - loss = -batch_avg_answer_diff(token_vals, batch) - elif answer_function == "avg_val": - loss = -batch_avg_answer_val(token_vals, batch) - elif answer_function == "mse": - loss = 
t.nn.functional.mse_loss(token_vals, batch.answers) - else: - raise ValueError(f"Unknown answer_function: {answer_function}") - - loss.backward() + layers_iter = layers if isinstance(layers, list) else range((layers or 0)+1) + for layer in (layer_bar := tqdm(layers_iter)): + layer_bar.set_description_str(f"Layer: {layer}") + src_idxs = [src.src_idx for src in model.srcs if src.layer == layer] + max_src_idx = max(src_idxs) if layers else 0 + score_slice = src_idxs if layers else slice(None) + for sample in (ig_pbar := tqdm(range((integrated_grad_samples or 0)+1))): + ig_pbar.set_description_str(f"Sample: {sample}") + # Interpolate the mask value if integrating gradients. Else set the value. + if integrated_grad_samples is not None: + val = sample / integrated_grad_samples + else: + val = mask_val + # Set the mask value at layer if layer. Else set the value for all layers. + if layers is not None: + set_all_masks(model, val=0) + set_masks_at_src_idxs(model, val=val, src_idxs=src_idxs) + else: + set_all_masks(model, val=val) + for batch in dataloader: + patch_src_outs = src_outs[batch.key].clone().detach() + with patch_mode(model, patch_src_outs): + logits = model(batch.clean)[out_slice] + if grad_function == "logit": + token_vals = logits + elif grad_function == "prob": + token_vals = t.softmax(logits, dim=-1) + elif grad_function == "logprob": + token_vals = log_softmax(logits, dim=-1) + elif grad_function == "logit_exp": + numerator = t.exp(logits) + denominator = numerator.sum(dim=-1, keepdim=True) + token_vals = numerator / denominator.detach() + else: + raise ValueError(f"Unknown grad_function: {grad_function}") - prune_scores: PruneScores = {} - for dest_wrapper in model.dest_wrappers: - grad = dest_wrapper.patch_mask.grad - assert grad is not None - prune_scores[dest_wrapper.module_name] = grad.detach().clone() + if answer_function == "avg_diff": + loss = -batch_avg_answer_diff(token_vals, batch) + elif answer_function == "avg_val": + loss = 
-batch_avg_answer_val(token_vals, batch) + elif answer_function == "mse": + loss = t.nn.functional.mse_loss(token_vals, batch.answers) + else: + raise ValueError(f"Unknown answer_function: {answer_function}") + loss.backward() + # set scores from layer (or all scores if layers is None) + for dest_wrapper in model.dest_wrappers: + if dest_wrapper.in_srcs.stop >= max_src_idx: + grad = dest_wrapper.patch_mask.grad + assert grad is not None + scores = grad.detach().clone()[..., score_slice] + prune_scores[dest_wrapper.module_name][..., score_slice] = scores return prune_scores diff --git a/auto_circuit/utils/graph_utils.py b/auto_circuit/utils/graph_utils.py index 1b3344b..5e02e44 100644 --- a/auto_circuit/utils/graph_utils.py +++ b/auto_circuit/utils/graph_utils.py @@ -337,6 +337,40 @@ def set_all_masks(model: PatchableModel, val: float) -> None: t.nn.init.constant_(wrapper.patch_mask, val) +def set_masks_at_src_idxs(model: PatchableModel, val: float, src_idxs: Collection[int]) -> None: + """ + Set all the patch masks with the specified src_idxs to the specified value. + + Args: + model: The patchable model to alter. + val: The value to set the patch masks to. + src_idxs: The src_idxs to set the patch masks at. + + Warning: + This function modifies the state of the model! This is a likely source of bugs. + """ + max_src_idx = max(src_idxs) + for wrapper in model.dest_wrappers: + if wrapper.in_srcs.stop >= max_src_idx: # downstream of src + with t.no_grad(): + wrapper.patch_mask.data[..., src_idxs] = val + +def set_masks_at_layer(model: PatchableModel, val: float, layer: int) -> None: + """ + Set all the patch masks with srcs at layer to the specified value. + + Args: + model: The patchable model to alter. + val: The value to set the patch masks to. + layer: The layer to set the patch masks at. + + Warning: + This function modifies the state of the model! This is a likely source of bugs. 
+ """ + src_idxs_at_layer = [src.src_idx for src in model.srcs if src.layer == layer] + set_masks_at_src_idxs(model, val, src_idxs_at_layer) + + @contextmanager def train_mask_mode( model: PatchableModel, requires_grad: bool = True From 9a5c3127f3c34ec55e3409570beca2b4d1031759 Mon Sep 17 00:00:00 2001 From: oliveradk Date: Fri, 30 Aug 2024 07:40:25 -0700 Subject: [PATCH 3/3] use index_fill and fixed mlp indexing (for ig by layer) --- auto_circuit/prune_algos/mask_gradient.py | 11 +++++++---- auto_circuit/utils/graph_utils.py | 6 +++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/auto_circuit/prune_algos/mask_gradient.py b/auto_circuit/prune_algos/mask_gradient.py index 923b892..e62f425 100644 --- a/auto_circuit/prune_algos/mask_gradient.py +++ b/auto_circuit/prune_algos/mask_gradient.py @@ -66,8 +66,8 @@ def mask_gradient_prune_scores( """ assert (mask_val is not None) ^ (integrated_grad_samples is not None) # ^ means XOR assert (layers is None) or (integrated_grad_samples is not None) - model = model out_slice = model.out_slice + device = next(model.parameters()).device src_outs: Dict[BatchKey, t.Tensor] = batch_src_ablations( model, @@ -80,8 +80,11 @@ def mask_gradient_prune_scores( layers_iter = layers if isinstance(layers, list) else range((layers or 0)+1) for layer in (layer_bar := tqdm(layers_iter)): layer_bar.set_description_str(f"Layer: {layer}") - src_idxs = [src.src_idx for src in model.srcs if src.layer == layer] - max_src_idx = max(src_idxs) if layers else 0 + src_idxs = t.tensor( + [src.src_idx for src in model.srcs if src.layer == layer], + device=device + ) + max_src_idx = t.max(src_idxs).item() if layers else 0 score_slice = src_idxs if layers else slice(None) for sample in (ig_pbar := tqdm(range((integrated_grad_samples or 0)+1))): ig_pbar.set_description_str(f"Sample: {sample}") @@ -124,7 +127,7 @@ def mask_gradient_prune_scores( loss.backward() # set scores from layer (or all scores if layers is None) for dest_wrapper in 
def set_masks_at_src_idxs(
    model: "PatchableModel", val: float, src_idxs: Union[t.Tensor, Collection[int]]
) -> None:
    """
    Set all the patch masks at the specified source indices to the specified value.

    Args:
        model: The patchable model to alter.
        val: The value to set the patch masks to.
        src_idxs: The source indices whose mask entries are set. Accepts either
            a 1-D integer tensor or any collection of ints (collections are
            converted, so callers such as `set_masks_at_layer` may pass lists).

    Warning:
        This function modifies the state of the model! This is a likely source
        of bugs.
    """
    # `Tensor.index_fill_` requires a LongTensor index; coerce collections.
    if not isinstance(src_idxs, t.Tensor):
        src_idxs = t.tensor(list(src_idxs), dtype=t.long)
    if src_idxs.numel() == 0:
        return  # nothing to set; also avoids max() on an empty index set
    max_src_idx = int(src_idxs.max().item())
    for wrapper in model.dest_wrappers:
        # `in_srcs.stop` is exclusive, so a wrapper's mask covers source index
        # `max_src_idx` only when stop > max_src_idx (strictly downstream).
        # NOTE(review): wrappers covering only a subset of `src_idxs` are
        # skipped entirely, matching the original behavior — confirm all
        # sources of one layer always fall inside the same wrappers' ranges.
        if wrapper.in_srcs.stop > max_src_idx:
            with t.no_grad():
                idx = src_idxs.to(wrapper.patch_mask.device)
                wrapper.patch_mask.index_fill_(dim=-1, index=idx, value=val)