From ac92c6581b7f2038f17422969c8feb7f00d94470 Mon Sep 17 00:00:00 2001 From: oliveradk Date: Fri, 26 Jul 2024 13:53:56 -0700 Subject: [PATCH] added max diff, and batch_answer_vals --- auto_circuit/prune_algos/mask_gradient.py | 7 +-- auto_circuit/utils/tensor_ops.py | 60 +++++++++++++++++++++-- 2 files changed, 59 insertions(+), 8 deletions(-) diff --git a/auto_circuit/prune_algos/mask_gradient.py b/auto_circuit/prune_algos/mask_gradient.py index d6a1acc..3fc15e7 100644 --- a/auto_circuit/prune_algos/mask_gradient.py +++ b/auto_circuit/prune_algos/mask_gradient.py @@ -13,15 +13,14 @@ train_mask_mode, ) from auto_circuit.utils.patchable_model import PatchableModel -from auto_circuit.utils.tensor_ops import batch_avg_answer_diff, batch_avg_answer_val - +from auto_circuit.utils.tensor_ops import batch_avg_answer_diff, batch_avg_answer_val, batch_avg_answer_max_diff def mask_gradient_prune_scores( model: PatchableModel, dataloader: PromptDataLoader, official_edges: Optional[Set[Edge]], grad_function: Literal["logit", "prob", "logprob", "logit_exp"], - answer_function: Literal["avg_diff", "avg_val", "mse"], + answer_function: Literal["avg_diff", "max_diff", "avg_val", "mse"], mask_val: Optional[float] = None, integrated_grad_samples: Optional[int] = None, ablation_type: AblationType = AblationType.RESAMPLE, @@ -99,6 +98,8 @@ def mask_gradient_prune_scores( if answer_function == "avg_diff": loss = -batch_avg_answer_diff(token_vals, batch) + elif answer_function == "max_diff": + loss = -batch_avg_answer_max_diff(token_vals, batch) elif answer_function == "avg_val": loss = -batch_avg_answer_val(token_vals, batch) elif answer_function == "mse": diff --git a/auto_circuit/utils/tensor_ops.py b/auto_circuit/utils/tensor_ops.py index 597c6f8..411ffd2 100644 --- a/auto_circuit/utils/tensor_ops.py +++ b/auto_circuit/utils/tensor_ops.py @@ -43,13 +43,16 @@ def indices_vals(vals: t.Tensor, indices: t.Tensor) -> t.Tensor: def vocab_avg_val(vals: t.Tensor, indices: t.Tensor) -> t.Tensor: return indices_vals(vals, indices).mean() +def vocab_max_val(vals: t.Tensor, indices: t.Tensor) -> t.Tensor: + return indices_vals(vals, indices).max() -def batch_avg_answer_val( + +def batch_answer_vals( vals: t.Tensor, batch: PromptPairBatch, wrong_answer: bool = False ) -> t.Tensor: """ Get the average value of the logits (or some function of them) for the correct - answers in the batch. + answers for each element in the batch. Args: vals: The logits values or some tensor of the same shape. @@ -58,15 +61,25 @@ def batch_avg_answer_val( the correct answers. Returns: - The average value of the logits for the correct answers in the batch. + The average value of the logits for the correct answers for each element in the batch. """ answers = batch.answers if not wrong_answer else batch.wrong_answers if isinstance(answers, t.Tensor): - return vocab_avg_val(vals, answers) + return t.gather(vals, dim=-1, index=answers).mean(dim=-1) else: # If each prompt has a different number of answers we have a list of tensor assert isinstance(answers, list) - return t.stack([vocab_avg_val(v, a) for v, a in zip(vals, answers)]).mean() + return t.stack([vocab_avg_val(v, a) for v, a in zip(vals, answers)]) + + +def batch_avg_answer_val( + vals: t.Tensor, batch: PromptPairBatch, wrong_answer: bool = False +) -> t.Tensor: + """ + Wrapper of [`batch_answer_vals`][auto_circuit.utils.tensor_ops.batch_answer_vals] + that returns the mean of the mean values. + """ + return batch_answer_vals(vals, batch, wrong_answer).mean() def batch_answer_diffs(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor: @@ -102,6 +115,36 @@ def batch_answer_diffs(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor: ans_avgs = [vocab_avg_val(v, a) for v, a in zip(vals, answers)] wrong_avgs = [vocab_avg_val(v, w) for v, w in zip(vals, wrong_answers)] return t.stack(ans_avgs) - t.stack(wrong_avgs) + +def batch_answer_max_diffs(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor: + """ + Find the difference between the max value of the correct answers and the max + value of the wrong answers for each prompt in the batch. + + If the batch answers are a `List`, rather than a `Tensor`, the function will be much + slower. + + Args: + vals: The logits values or some tensor of the same shape. + batch: The batch of prompts and answers. + + Returns: + The difference between the max value of the correct answers and the max + value of the wrong answers for each prompt in the batch. + """ + answers = batch.answers + wrong_answers = batch.wrong_answers + if isinstance(answers, t.Tensor) and isinstance(wrong_answers, t.Tensor): + ans_max = t.gather(vals, dim=-1, index=answers).max(dim=-1).values + wrong_max = t.gather(vals, dim=-1, index=wrong_answers).max(dim=-1).values + return ans_max - wrong_max + else: + # If each prompt has a different number of answers we have a list of tensors + assert isinstance(answers, list) and isinstance(wrong_answers, list) + ans_max = [vocab_max_val(v, a) for v, a in zip(vals, answers)] + wrong_max = [vocab_max_val(v, w) for v, w in zip(vals, wrong_answers)] + return t.stack(ans_max) - t.stack(wrong_max) + def batch_avg_answer_diff(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor: @@ -111,6 +154,13 @@ def batch_avg_answer_diff(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor: """ return batch_answer_diffs(vals, batch).mean() +def batch_avg_answer_max_diff(vals: t.Tensor, batch: PromptPairBatch) -> t.Tensor: + """ + Wrapper of [`batch_answer_diffs`][auto_circuit.utils.tensor_ops.batch_answer_max_diffs] + that returns the mean of the differences. + """ + return batch_answer_max_diffs(vals, batch).mean() + def batch_answer_diff_percents( pred_vals: t.Tensor, target_vals: t.Tensor, batch: PromptPairBatch