Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions auto_circuit/prune_algos/mask_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
train_mask_mode,
)
from auto_circuit.utils.patchable_model import PatchableModel
from auto_circuit.utils.tensor_ops import batch_avg_answer_diff, batch_avg_answer_val

from auto_circuit.utils.tensor_ops import batch_avg_answer_diff, batch_avg_answer_val, batch_avg_answer_max_diff

def mask_gradient_prune_scores(
model: PatchableModel,
dataloader: PromptDataLoader,
official_edges: Optional[Set[Edge]],
grad_function: Literal["logit", "prob", "logprob", "logit_exp"],
answer_function: Literal["avg_diff", "avg_val", "mse"],
answer_function: Literal["avg_diff", "max_diff", "avg_val", "mse"],
mask_val: Optional[float] = None,
integrated_grad_samples: Optional[int] = None,
ablation_type: AblationType = AblationType.RESAMPLE,
Expand Down Expand Up @@ -99,6 +98,8 @@ def mask_gradient_prune_scores(

if answer_function == "avg_diff":
loss = -batch_avg_answer_diff(token_vals, batch)
elif answer_function == "max_diff":
loss = -batch_avg_answer_max_diff(token_vals, batch)
elif answer_function == "avg_val":
loss = -batch_avg_answer_val(token_vals, batch)
elif answer_function == "mse":
Expand Down
60 changes: 55 additions & 5 deletions auto_circuit/utils/tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ def indices_vals(vals: t.Tensor, indices: t.Tensor) -> t.Tensor:
def vocab_avg_val(vals: t.Tensor, indices: t.Tensor) -> t.Tensor:
return indices_vals(vals, indices).mean()

def vocab_max_val(vals: t.Tensor, indices: t.Tensor) -> t.Tensor:
return indices_vals(vals, indices).max()

def batch_avg_answer_val(

def batch_answer_vals(
vals: t.Tensor, batch: PromptPairBatch, wrong_answer: bool = False
) -> t.Tensor:
"""
Get the average value of the logits (or some function of them) for the correct
answers in the batch.
answers for each element in the batch.

Args:
vals: The logits values or some tensor of the same shape.
Expand All @@ -58,15 +61,25 @@ def batch_avg_answer_val(
the correct answers.

Returns:
The average value of the logits for the correct answers in the batch.
The average value of the logits for the correct answers for each element in the batch.
"""
answers = batch.answers if not wrong_answer else batch.wrong_answers
if isinstance(answers, t.Tensor):
return vocab_avg_val(vals, answers)
return t.gather(vals, dim=-1, index=answers).mean(dim=-1)
else:
# If each prompt has a different number of answers we have a list of tensor
assert isinstance(answers, list)
return t.stack([vocab_avg_val(v, a) for v, a in zip(vals, answers)]).mean()
return t.stack([vocab_avg_val(v, a) for v, a in zip(vals, answers)])


def batch_avg_answer_val(
vals: t.Tensor, batch: PromptPairBatch, wrong_answer: bool = False
) -> t.Tensor:
"""
Wrapper of [`batch_answer_vals`][auto_circuit.utils.tensor_ops.batch_answer_vals]
that returns the mean of the mean values.
"""
return batch_answer_vals(vals, batch, wrong_answer).mean()


def batch_answer_diffs(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor:
Expand Down Expand Up @@ -102,6 +115,36 @@ def batch_answer_diffs(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor:
ans_avgs = [vocab_avg_val(v, a) for v, a in zip(vals, answers)]
wrong_avgs = [vocab_avg_val(v, w) for v, w in zip(vals, wrong_answers)]
return t.stack(ans_avgs) - t.stack(wrong_avgs)

def batch_answer_max_diffs(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor:
"""
Find the difference between the max value of the correct answers and the max
value of the wrong answers for each prompt in the batch.

If the batch answers are a `List`, rather than a `Tensor`, the function will be much
slower.

Args:
vals: The logits values or some tensor of the same shape.
batch: The batch of prompts and answers.

Returns:
The difference between the max value of the correct answers and the max
value of the wrong answers for each prompt in the batch.
"""
answers = batch.answers
wrong_answers = batch.wrong_answers
if isinstance(answers, t.Tensor) and isinstance(wrong_answers, t.Tensor):
ans_max = t.gather(vals, dim=-1, index=answers).max(dim=-1).values
wrong_max = t.gather(vals, dim=-1, index=wrong_answers).max(dim=-1).values
return ans_max - wrong_max
else:
# If each prompt has a different number of answers we have a list of tensors
assert isinstance(answers, list) and isinstance(wrong_answers, list)
ans_max = [vocab_max_val(v, a) for v, a in zip(vals, answers)]
wrong_max = [vocab_max_val(v, w) for v, w in zip(vals, wrong_answers)]
return t.stack(ans_max) - t.stack(wrong_max)



def batch_avg_answer_diff(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor:
Expand All @@ -111,6 +154,13 @@ def batch_avg_answer_diff(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor:
"""
return batch_answer_diffs(vals, batch).mean()

def batch_avg_answer_max_diff(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor:
"""
Wrapper of [`batch_answer_diffs`][auto_circuit.utils.tensor_ops.batch_answer_max_diffs]
that returns the mean of the differences.
"""
return batch_answer_max_diffs(vals, batch).mean()


def batch_answer_diff_percents(
pred_vals: t.Tensor, target_vals: t.Tensor, batch: PromptPairBatch
Expand Down