Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def build_ground_truth() -> list[dict[str, Any]]:
"""Load and preprocess MMLU-Pro test split, same as tamperbench."""
ds = datasets.load_dataset(DATASET_NAME, split="test")
# HuggingFace datasets library doesn't properly type Dataset iteration
rows = [dict(item) for item in ds] # pyright: ignore[reportCallIssue, reportArgumentType]
rows = [dict(item) for item in ds]
rows = preprocess(rows)
return limit_examples_per_subject(rows, VAL_MAX_EXAMPLES_PER_SUBJECT)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def build_ground_truth() -> list[dict[str, Any]]:
"""Load and preprocess MMLU-Pro test split, returning the same 140 questions."""
ds = datasets.load_dataset(DATASET_NAME, split="test")
# HuggingFace datasets library doesn't properly type Dataset iteration
rows = [dict(item) for item in ds] # pyright: ignore[reportCallIssue, reportArgumentType]
rows = [dict(item) for item in ds]
rows = preprocess(rows)
return limit_examples_per_subject(rows, VAL_MAX_EXAMPLES_PER_SUBJECT)

Expand Down
10 changes: 5 additions & 5 deletions src/tamperbench/whitebox/attacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TamperAttackConfig:
random_seed: int

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExplicitAny]
def from_dict(cls, data: dict[str, Any]) -> Self:
"""All subclasses must implement a `from_dict` constructor.

Args:
Expand All @@ -70,7 +70,7 @@ def from_dict(cls, data: dict[str, Any]) -> Self: # pyright: ignore[reportExpli
return cls(**data)

@classmethod
def prepare_dict_for_init(cls, data: dict[str, Any]) -> dict[str, Any]: # pyright: ignore[reportExplicitAny]
def prepare_dict_for_init(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Transform dictionary data into proper types for dataclass initialization.

Converts nested dictionaries into their corresponding dataclass instances
Expand All @@ -82,14 +82,14 @@ def prepare_dict_for_init(cls, data: dict[str, Any]) -> dict[str, Any]: # pyrig
Returns:
dict[str, Any]: Dictionary with nested objects properly instantiated.
"""
model_config_dict = data.pop("model_config") # pyright: ignore[reportAny]
model_config = ModelConfig.from_dict(model_config_dict) # pyright: ignore[reportAny]
model_config_dict = data.pop("model_config")
model_config = ModelConfig.from_dict(model_config_dict)

data.update({"model_config": model_config})
return data

@classmethod
def _validate_dict_keys(cls, data: dict[str, Any]) -> None: # pyright: ignore[reportExplicitAny]
def _validate_dict_keys(cls, data: dict[str, Any]) -> None:
"""Check that the dictionary to be used to construct dataclass has correct fields.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class EmbeddingAttackConfig(TamperAttackConfig):
soft_opt_config: SoftOptConfig

@classmethod
def prepare_dict_for_init(cls, data: dict[str, Any]) -> dict[str, Any]: # pyright: ignore[reportExplicitAny]
def prepare_dict_for_init(cls, data: dict[str, Any]) -> dict[str, Any]:
"""Transform dictionary data into proper types for dataclass initialization.

Args:
Expand All @@ -42,11 +42,11 @@ def prepare_dict_for_init(cls, data: dict[str, Any]) -> dict[str, Any]: # pyrig
"""
data = super().prepare_dict_for_init(data)

soft_opt_config_dict = data.pop("soft_opt_config") # pyright: ignore[reportAny]
soft_opt_config_dict = data.pop("soft_opt_config")

# Pass random_seed from attack config to SoftOptConfig
soft_opt_config_dict["seed"] = data["random_seed"]
soft_opt_config = SoftOptConfig(**soft_opt_config_dict) # pyright: ignore[reportAny]
soft_opt_config = SoftOptConfig(**soft_opt_config_dict)

data.update({"soft_opt_config": soft_opt_config})
return data
Expand Down
8 changes: 4 additions & 4 deletions src/tamperbench/whitebox/evals/ifeval/ifeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def compute_inferences(self) -> DataFrame[InferenceSchema]:
models like Qwen, Llama, etc.
"""
dataset = datasets.load_dataset(DATASET_NAME, split=TEST_SPLIT)
test_docs = [dict(item) for item in dataset] # pyright: ignore[reportArgumentType]
test_docs = [dict(item) for item in dataset]

if self.eval_config.max_samples is not None:
test_docs = test_docs[: self.eval_config.max_samples]
Expand All @@ -114,7 +114,7 @@ def compute_inferences(self) -> DataFrame[InferenceSchema]:
def compute_scores(self, inferences: DataFrame[InferenceSchema]) -> DataFrame[ScoreSchema]:
"""Apply instruction-following checks (utils.process_results)."""
dataset = datasets.load_dataset(DATASET_NAME, split=TEST_SPLIT)
test_docs = [dict(item) for item in dataset][: len(inferences)] # pyright: ignore[reportArgumentType]
test_docs = [dict(item) for item in dataset][: len(inferences)]

self._prompt_level_strict = []
self._prompt_level_loose = []
Expand All @@ -132,8 +132,8 @@ def compute_scores(self, inferences: DataFrame[InferenceSchema]) -> DataFrame[Sc
metrics = utils.process_results(doc, response)
strict_prompt = float(metrics["prompt_level_strict_acc"])
loose_prompt = float(metrics["prompt_level_loose_acc"])
inst_strict = list(metrics["inst_level_strict_acc"]) # pyright: ignore[reportArgumentType]
inst_loose = list(metrics["inst_level_loose_acc"]) # pyright: ignore[reportArgumentType]
inst_strict = list(metrics["inst_level_strict_acc"])
inst_loose = list(metrics["inst_level_loose_acc"])

self._prompt_level_strict.append(strict_prompt)
self._prompt_level_loose.append(loose_prompt)
Expand Down
4 changes: 2 additions & 2 deletions src/tamperbench/whitebox/evals/mbpp/mbpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class MBPPEvaluation(WhiteBoxEvaluation[MBPPEvaluationConfig]):
def compute_inferences(self) -> DataFrame[InferenceSchema]:
"""Generate code completions for MBPP problems."""
dataset = datasets.load_dataset(DATASET_PATH, split=TEST_SPLIT)
test_docs = [dict(item) for item in dataset] # pyright: ignore[reportArgumentType]
test_docs = [dict(item) for item in dataset]

# Validate dataset size
assert len(test_docs) == 500, (
Expand Down Expand Up @@ -261,7 +261,7 @@ def _format_with_chat_template(self, prompt: str) -> str:
def compute_scores(self, inferences: DataFrame[InferenceSchema]) -> DataFrame[ScoreSchema]:
"""Execute generated code and compute pass/fail scores."""
dataset = datasets.load_dataset(DATASET_PATH, split=TEST_SPLIT)
test_docs = [dict(item) for item in dataset][: len(inferences)] # pyright: ignore[reportArgumentType]
test_docs = [dict(item) for item in dataset][: len(inferences)]

# Get prompts and references
prompts = [get_prompt(doc) for doc in test_docs]
Expand Down
4 changes: 2 additions & 2 deletions src/tamperbench/whitebox/evals/minerva_math/minerva_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def compute_inferences(self) -> DataFrame[InferenceSchema]:
for subset in DATASET_SUBSETS:
dataset = datasets.load_dataset(DATASET_NAME, subset, split=TEST_SPLIT)
processed = utils.process_docs(dataset)
all_test_docs.extend([dict(item) for item in processed]) # pyright: ignore[reportArgumentType]
all_test_docs.extend([dict(item) for item in processed])

# Get few-shot examples
fewshot_examples = utils.list_fewshot_samples()
Expand Down Expand Up @@ -111,7 +111,7 @@ def compute_scores(self, inferences: DataFrame[InferenceSchema]) -> DataFrame[Sc
for subset in DATASET_SUBSETS:
dataset = datasets.load_dataset(DATASET_NAME, subset, split=TEST_SPLIT)
processed = utils.process_docs(dataset)
all_test_docs.extend([dict(item) for item in processed]) # pyright: ignore[reportArgumentType]
all_test_docs.extend([dict(item) for item in processed])

# Extract and normalize answers from responses
responses_list = list(inferences[InferenceSchema.response])
Expand Down
4 changes: 2 additions & 2 deletions src/tamperbench/whitebox/evals/mmlu_pro/mmlu_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def compute_scores(self, inferences: DataFrame[InferenceSchema]) -> DataFrame[Sc

@override
def compute_results(self, scores: DataFrame[ScoreSchema]) -> DataFrame[EvaluationSchema]:
acc = float(pl.Series(scores[ScoreSchema.score]).mean()) # pyright: ignore[reportArgumentType]
acc = float(pl.Series(scores[ScoreSchema.score]).mean())
df = pl.from_dict(
{
EvaluationSchema.metric_name: [str(MetricName.MMLU_PRO_ACCURACY)],
Expand All @@ -126,7 +126,7 @@ def _load_split(self, split_name: str) -> list[dict]:
if SPLIT_SAMPLE_LIMIT is None or SPLIT_SAMPLE_LIMIT <= 0
else f"{split_name}[:{SPLIT_SAMPLE_LIMIT}]"
)
ds = datasets.load_dataset(DATASET_NAME, split=split) # pyright: ignore[reportArgumentType]
ds = datasets.load_dataset(DATASET_NAME, split=split)
rows = [dict(item) for item in ds]
return mmlu_api.preprocess(rows)

Expand Down
6 changes: 3 additions & 3 deletions src/tamperbench/whitebox/utils/analysis/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ class HeatmapData:
utility_delta: NDArray[np.float32]
models: list[str]
attacks: list[AttackName]
metadata: dict[str, Any] = field(default_factory=dict) # pyright: ignore[reportExplicitAny]
metadata: dict[str, Any] = field(default_factory=dict)

def to_dict(self) -> dict[str, Any]: # pyright: ignore[reportExplicitAny]
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON serialization.

Converts numpy arrays to lists and attack names to strings for JSON compatibility.
Expand All @@ -68,7 +68,7 @@ def to_dict(self) -> dict[str, Any]: # pyright: ignore[reportExplicitAny]
"metadata": self.metadata,
}

def _to_list(self, array: NDArray[np.float32]) -> list[Any] | Any: # pyright: ignore[reportExplicitAny]
def _to_list(self, array: NDArray[np.float32]) -> list[Any] | Any:
"""Convert numpy array to list, or return as-is if not an array."""
return array.tolist() if hasattr(array, "tolist") else array

Expand Down
Loading