diff --git a/src/tamperbench/whitebox/attacks/gcg/gcg.py b/src/tamperbench/whitebox/attacks/gcg/gcg.py
index 917be723..30c9ebd9 100644
--- a/src/tamperbench/whitebox/attacks/gcg/gcg.py
+++ b/src/tamperbench/whitebox/attacks/gcg/gcg.py
@@ -56,7 +56,7 @@ class GCGAttack(TamperAttack[H]):
 
     @override
     def run_attack(
-        messages: str | list[dict],
+        messages: str | list[dict[str, str]],
         target: str,
         config: GCGConfig | None = None,
     ) -> GCGResult:
diff --git a/src/tamperbench/whitebox/attacks/gcg/implementation.py b/src/tamperbench/whitebox/attacks/gcg/implementation.py
index ff2cc761..52279d77 100644
--- a/src/tamperbench/whitebox/attacks/gcg/implementation.py
+++ b/src/tamperbench/whitebox/attacks/gcg/implementation.py
@@ -383,7 +383,7 @@ def __init__(
 
     def run(
         self,
-        messages: str | list[dict],
+        messages: str | list[dict[str, str]],
         target: str,
     ) -> GCGResult:
         """Run the GCG attack to find an adversarial suffix.
@@ -803,12 +803,14 @@ def _compute_candidates_loss_probe_sampling(
         probe_idxs = torch.randperm(B)[:probe_size].to(input_embeds.device)
         probe_embeds = input_embeds[probe_idxs]
 
-        def _compute_probe_losses(result_queue: queue.Queue, search_batch_size: int, probe_embeds: Tensor) -> None:
+        def _compute_probe_losses(
+            result_queue: queue.Queue[tuple[str, Tensor]], search_batch_size: int, probe_embeds: Tensor
+        ) -> None:
             probe_losses = self._compute_candidates_loss_original(search_batch_size, probe_embeds)
             result_queue.put(("probe", probe_losses))
 
         def _compute_draft_losses(
-            result_queue: queue.Queue,
+            result_queue: queue.Queue[tuple[str, Tensor]],
             search_batch_size: int,
             draft_sampled_ids: Tensor,
         ) -> None:
diff --git a/src/tamperbench/whitebox/evals/mt_bench/mt_bench.py b/src/tamperbench/whitebox/evals/mt_bench/mt_bench.py
index 3711d29d..cc760de9 100644
--- a/src/tamperbench/whitebox/evals/mt_bench/mt_bench.py
+++ b/src/tamperbench/whitebox/evals/mt_bench/mt_bench.py
@@ -5,6 +5,7 @@
 
 from dataclasses import dataclass
 from functools import partial
+from typing import Any
 
 import polars as pl
 import torch
@@ -27,6 +28,7 @@
     score_single_turn,
 )
 from tamperbench.whitebox.evals.mt_bench.questions import (
+    MTBenchQuestion,
     load_mt_bench_questions_by_category,
 )
 from tamperbench.whitebox.evals.output_schema import (
@@ -82,7 +84,7 @@ class MTBenchScoreSchema(ScoreSchema):
     judge_response_2: str
 
 
-class MTBenchEvaluation(WhiteBoxEvaluation):
+class MTBenchEvaluation(WhiteBoxEvaluation[MTBenchEvaluationConfig]):
     """MT-Bench evaluation using LLM-as-a-Judge."""
 
     name: EvalName = EvalName.MT_BENCH
@@ -171,7 +173,7 @@ def compute_inferences(self) -> DataFrame[InferenceSchema]:
         # Standard inference path
         model, tokenizer = load_model_and_tokenizer(self.eval_config)
 
-        inferences: list[dict] = []
+        inferences: list[dict[str, Any]] = []
         for question in tqdm(self.questions, desc="Generating MT-Bench responses"):
             # Build messages for turn 1
             messages_t1: list[dict[str, str]] = []
@@ -215,7 +217,7 @@ def compute_inferences(self) -> DataFrame[InferenceSchema]:
 
         return InferenceSchema.validate(pl.DataFrame(inferences))
 
-    def _score_row(self, row: dict, question_map: dict) -> dict:
+    def _score_row(self, row: dict[str, Any], question_map: dict[int, MTBenchQuestion]) -> dict[str, Any]:
         """Score a single inference row (both turns).
 
         Args: