Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 63 additions & 44 deletions auto_circuit/prune_algos/mask_gradient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Literal, Optional, Set
from typing import Dict, Literal, Optional, Set, Union

import torch as t
from torch.nn.functional import log_softmax
Expand All @@ -10,6 +10,7 @@
from auto_circuit.utils.graph_utils import (
patch_mode,
set_all_masks,
set_masks_at_src_idxs,
train_mask_mode,
)
from auto_circuit.utils.patchable_model import PatchableModel
Expand All @@ -26,6 +27,7 @@ def mask_gradient_prune_scores(
integrated_grad_samples: Optional[int] = None,
ablation_type: AblationType = AblationType.RESAMPLE,
clean_corrupt: Optional[Literal["clean", "corrupt"]] = "corrupt",
layers: Optional[Union[int, list[int]]] = None
) -> PruneScores:
"""
Prune scores equal to the gradient of the mask values that interpolates the edges
Expand All @@ -49,6 +51,9 @@ def mask_gradient_prune_scores(
ablation_type: The type of ablation to perform.
clean_corrupt: Whether to use the clean or corrupt inputs to calculate the
ablations.
layers: If not `None`, we iterate over each layer in the model and compute
scores separately for each. Only used if `integrated_grad_samples` is not
`None`. Follows [Marks et al., 2024](https://arxiv.org/abs/2403.19647).

Returns:
An ordering of the edges by importance to the task. Importance is equal to the
Expand All @@ -60,57 +65,71 @@ def mask_gradient_prune_scores(
[`edge_attribution_patching_prune_scores`][auto_circuit.prune_algos.edge_attribution_patching.edge_attribution_patching_prune_scores].
"""
assert (mask_val is not None) ^ (integrated_grad_samples is not None) # ^ means XOR
model = model
assert (layers is None) or (integrated_grad_samples is not None)
out_slice = model.out_slice
device = next(model.parameters()).device

src_outs: Dict[BatchKey, t.Tensor] = batch_src_ablations(
model,
dataloader,
ablation_type=ablation_type,
clean_corrupt=clean_corrupt,
)

prune_scores = model.new_prune_scores()
with train_mask_mode(model):
for sample in (ig_pbar := tqdm(range((integrated_grad_samples or 0) + 1))):
ig_pbar.set_description_str(f"Sample: {sample}")
# Interpolate the mask value if integrating gradients. Else set the value.
if integrated_grad_samples is not None:
set_all_masks(model, val=sample / integrated_grad_samples)
else:
assert mask_val is not None and integrated_grad_samples is None
set_all_masks(model, val=mask_val)

for batch in dataloader:
patch_src_outs = src_outs[batch.key].clone().detach()
with patch_mode(model, patch_src_outs):
logits = model(batch.clean)[out_slice]
if grad_function == "logit":
token_vals = logits
elif grad_function == "prob":
token_vals = t.softmax(logits, dim=-1)
elif grad_function == "logprob":
token_vals = log_softmax(logits, dim=-1)
elif grad_function == "logit_exp":
numerator = t.exp(logits)
denominator = numerator.sum(dim=-1, keepdim=True)
token_vals = numerator / denominator.detach()
else:
raise ValueError(f"Unknown grad_function: {grad_function}")

if answer_function == "avg_diff":
loss = -batch_avg_answer_diff(token_vals, batch)
elif answer_function == "avg_val":
loss = -batch_avg_answer_val(token_vals, batch)
elif answer_function == "mse":
loss = t.nn.functional.mse_loss(token_vals, batch.answers)
else:
raise ValueError(f"Unknown answer_function: {answer_function}")

loss.backward()
layers_iter = layers if isinstance(layers, list) else range((layers or 0)+1)
for layer in (layer_bar := tqdm(layers_iter)):
layer_bar.set_description_str(f"Layer: {layer}")
src_idxs = t.tensor(
[src.src_idx for src in model.srcs if src.layer == layer],
device=device
)
max_src_idx = t.max(src_idxs).item() if layers else 0
score_slice = src_idxs if layers else slice(None)
for sample in (ig_pbar := tqdm(range((integrated_grad_samples or 0)+1))):
ig_pbar.set_description_str(f"Sample: {sample}")
# Interpolate the mask value if integrating gradients. Else set the value.
if integrated_grad_samples is not None:
val = sample / integrated_grad_samples
else:
val = mask_val
# Set the mask value at layer if layer. Else set the value for all layers.
if layers is not None:
set_all_masks(model, val=0)
set_masks_at_src_idxs(model, val=val, src_idxs=src_idxs)
else:
set_all_masks(model, val=val)
for batch in dataloader:
patch_src_outs = src_outs[batch.key].clone().detach()
with patch_mode(model, patch_src_outs):
logits = model(batch.clean)[out_slice]
if grad_function == "logit":
token_vals = logits
elif grad_function == "prob":
token_vals = t.softmax(logits, dim=-1)
elif grad_function == "logprob":
token_vals = log_softmax(logits, dim=-1)
elif grad_function == "logit_exp":
numerator = t.exp(logits)
denominator = numerator.sum(dim=-1, keepdim=True)
token_vals = numerator / denominator.detach()
else:
raise ValueError(f"Unknown grad_function: {grad_function}")

prune_scores: PruneScores = {}
for dest_wrapper in model.dest_wrappers:
grad = dest_wrapper.patch_mask.grad
assert grad is not None
prune_scores[dest_wrapper.module_name] = grad.detach().clone()
if answer_function == "avg_diff":
loss = -batch_avg_answer_diff(token_vals, batch)
elif answer_function == "avg_val":
loss = -batch_avg_answer_val(token_vals, batch)
elif answer_function == "mse":
loss = t.nn.functional.mse_loss(token_vals, batch.answers)
else:
raise ValueError(f"Unknown answer_function: {answer_function}")
loss.backward()
# set scores from layer (or all scores if layers is None)
for dest_wrapper in model.dest_wrappers:
if dest_wrapper.in_srcs.stop > max_src_idx:
grad = dest_wrapper.patch_mask.grad
assert grad is not None
scores = grad.detach().clone()[..., score_slice]
prune_scores[dest_wrapper.module_name][..., score_slice] = scores
return prune_scores
34 changes: 34 additions & 0 deletions auto_circuit/utils/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,40 @@ def set_all_masks(model: PatchableModel, val: float) -> None:
t.nn.init.constant_(wrapper.patch_mask, val)


def set_masks_at_src_idxs(
    model: "PatchableModel", val: float, src_idxs: t.Tensor
) -> None:
    """
    Set the patch mask values at the specified source indices to `val`.

    Args:
        model: The patchable model to alter.
        val: The value to set the patch masks to.
        src_idxs: The source indices (positions along the last mask dimension)
            at which to set the patch masks. A tensor is expected, but any
            sequence of ints is also accepted and coerced to a long tensor
            (`Tensor.index_fill_` requires a long tensor index and raises on a
            plain Python list).

    Warning:
        This function modifies the state of the model! This is a likely source of bugs.
    """
    # Coerce to a long tensor: `index_fill_` rejects Python lists and
    # non-integer index tensors. `as_tensor` preserves the device of an
    # already-tensor input.
    idx_tensor = t.as_tensor(src_idxs, dtype=t.long)
    if idx_tensor.numel() == 0:
        return  # No indices to set; `max` below would raise on an empty input.
    max_src_idx = int(idx_tensor.max().item())
    for wrapper in model.dest_wrappers:
        # Only masks whose input-src range extends past the largest index have
        # slots for these srcs; wrappers fully upstream are left untouched
        # (filling them would also index out of range).
        if wrapper.in_srcs.stop > max_src_idx:
            with t.no_grad():
                wrapper.patch_mask.index_fill_(dim=-1, index=idx_tensor, value=val)

def set_masks_at_layer(model: "PatchableModel", val: float, layer: int) -> None:
    """
    Set all the patch masks for srcs at the given layer to `val`.

    Args:
        model: The patchable model to alter.
        val: The value to set the patch masks to.
        layer: The layer whose src patch masks are set.

    Warning:
        This function modifies the state of the model! This is a likely source of bugs.
    """
    # Collect the src indices of every src in this layer and convert them to a
    # long tensor: `set_masks_at_src_idxs` forwards the index to
    # `Tensor.index_fill_`, which requires a tensor index (a plain Python list
    # raises a TypeError).
    src_idxs = t.tensor(
        [src.src_idx for src in model.srcs if src.layer == layer], dtype=t.long
    )
    if src_idxs.numel() == 0:
        return  # No srcs at this layer: nothing to set.
    set_masks_at_src_idxs(model, val, src_idxs)


@contextmanager
def train_mask_mode(
model: PatchableModel, requires_grad: bool = True
Expand Down