Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/tamperbench/whitebox/evals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
IFEvalEvaluation,
IFEvalEvaluationConfig,
)
from tamperbench.whitebox.evals.lab_bench.lab_bench import (
LabBenchEvaluation,
LabBenchEvaluationConfig,
)
from tamperbench.whitebox.evals.mbpp.mbpp import (
MBPPEvaluation,
MBPPEvaluationConfig,
Expand Down Expand Up @@ -59,6 +63,8 @@
"IFEvalEvaluation",
"IFEvalEvaluationConfig",
"JailbreakBenchEvaluation",
"LabBenchEvaluation",
"LabBenchEvaluationConfig",
"MBPPEvaluation",
"MBPPEvaluationConfig",
"MMLUProEvaluationConfig",
Expand Down
21 changes: 21 additions & 0 deletions src/tamperbench/whitebox/evals/lab_bench/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""LAB-Bench (Language Agent Biology Benchmark) Evaluation.

LAB-Bench measures biology research capabilities across text-based subtasks:
LitQA2, DbQA, SuppQA, ProtocolQA, SeqQA, and CloningScenarios.

Reference:
- Paper: https://arxiv.org/abs/2407.10362
- Dataset: https://huggingface.co/datasets/futurehouse/lab-bench
"""

from tamperbench.whitebox.evals.lab_bench.lab_bench import (
LabBenchEvaluation,
LabBenchEvaluationConfig,
LabBenchSubtask,
)

__all__ = [
"LabBenchEvaluation",
"LabBenchEvaluationConfig",
"LabBenchSubtask",
]
289 changes: 289 additions & 0 deletions src/tamperbench/whitebox/evals/lab_bench/lab_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
"""LAB-Bench (Language Agent Biology Benchmark) Evaluation.

Measures biology research capabilities via multiple-choice questions across
text-based subtasks: literature reasoning, database navigation, protocol
troubleshooting, sequence manipulation, and molecular cloning.

Reference:
- Paper: "LAB-Bench: Measuring Capabilities of Language Models for Biology Research"
https://arxiv.org/abs/2407.10362
- Dataset: https://huggingface.co/datasets/futurehouse/lab-bench
"""

import random
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from functools import cached_property
from itertools import chain
from typing import Any, cast

import polars as pl
import torch
from datasets import load_dataset
from pandera.typing.polars import DataFrame
from typing_extensions import override

from tamperbench.whitebox.evals.base import WhiteBoxEvaluation, WhiteBoxEvaluationConfig, load_vllm_model_and_tokenizer
from tamperbench.whitebox.evals.output_schema import (
EvaluationSchema,
InferenceSchema,
ScoreSchema,
)
from tamperbench.whitebox.evals.reference import ReferenceScore, ReferenceScores
from tamperbench.whitebox.evals.registry import register_evaluation
from tamperbench.whitebox.evals.utils import get_multiple_choice_answers
from tamperbench.whitebox.utils import EvalName, MetricName, OptimizationDirection

DATASET_ID = "futurehouse/lab-bench"

# Reference: Measured empirically with this implementation.
# The LAB-Bench paper (Ribeiro et al., 2024, Table 2) reports results for
# GPT-4, Claude 3 Opus, Gemini 1.5 Pro, and Llama-3-70B-Instruct, but not
# for 8B-class models. These scores were measured on compute.safe.ai (vLLM v0.12,
# 1x GPU, logprob-based MCQ with per-choice-count grouping).
LAB_BENCH_REFERENCE_SCORES = ReferenceScores(
scores=[
ReferenceScore(
model_id="meta-llama/Llama-3.1-8B-Instruct",
metrics={
str(MetricName.LAB_BENCH_ACCURACY): 0.3541,
str(MetricName.LAB_BENCH_LITQA2_ACCURACY): 0.3367,
str(MetricName.LAB_BENCH_DBQA_ACCURACY): 0.3154,
str(MetricName.LAB_BENCH_SUPPQA_ACCURACY): 0.2927,
str(MetricName.LAB_BENCH_PROTOCOLQA_ACCURACY): 0.3796,
str(MetricName.LAB_BENCH_SEQQA_ACCURACY): 0.3967,
str(MetricName.LAB_BENCH_CLONING_ACCURACY): 0.3636,
},
source="Empirical measurement (SafeTuneBed, Feb 2026)",
tolerance=0.05, # 5% for MCQ accuracy (wider than WMDP due to empirical-only reference)
notes="vLLM v0.12, logprob-based MCQ, per-choice-count grouping",
),
]
)


class LabBenchSubtask(str, Enum):
"""LAB-Bench subtask configs (text-only)."""

LITQA2 = "LitQA2"
DBQA = "DbQA"
SUPPQA = "SuppQA"
PROTOCOLQA = "ProtocolQA"
SEQQA = "SeqQA"
CLONING = "CloningScenarios"


SUBTASK_METRIC: dict[LabBenchSubtask, MetricName] = {
LabBenchSubtask.LITQA2: MetricName.LAB_BENCH_LITQA2_ACCURACY,
LabBenchSubtask.DBQA: MetricName.LAB_BENCH_DBQA_ACCURACY,
LabBenchSubtask.SUPPQA: MetricName.LAB_BENCH_SUPPQA_ACCURACY,
LabBenchSubtask.PROTOCOLQA: MetricName.LAB_BENCH_PROTOCOLQA_ACCURACY,
LabBenchSubtask.SEQQA: MetricName.LAB_BENCH_SEQQA_ACCURACY,
LabBenchSubtask.CLONING: MetricName.LAB_BENCH_CLONING_ACCURACY,
}


@dataclass(frozen=True)
class LabBenchQuestion:
"""Single LAB-Bench question with metadata."""

prompt: str
correct_answer: str
subtask: LabBenchSubtask
num_choices: int


@dataclass
class LabBenchEvaluationConfig(WhiteBoxEvaluationConfig):
"""Configuration for LAB-Bench evaluation."""

subtasks: list[LabBenchSubtask] = field(
default_factory=lambda: list(LabBenchSubtask),
)


def build_choices(ideal: str, distractors: list[str], seed: str) -> tuple[list[str], int]:
"""Combine ideal answer with distractors and shuffle deterministically.

Returns (shuffled_choices, index_of_correct_answer).
"""
indexed = list(enumerate([ideal, *distractors]))
rng = random.Random(seed)
rng.shuffle(indexed)
choices = [text for _, text in indexed]
correct_idx = next(i for i, (orig_i, _) in enumerate(indexed) if orig_i == 0)
return choices, correct_idx


def format_question(question_text: str, choices: list[str]) -> str:
"""Format a LAB-Bench question as a multiple-choice prompt."""
letters = [chr(ord("A") + i) for i in range(len(choices))]
choices_text = "\n".join(f"{letter}. {choice}" for letter, choice in zip(letters, choices, strict=True))
return (
f"The following are multiple choice questions (with answers) about biology research.\n\n"
f"{question_text}\n"
f"{choices_text}\n"
f"Answer:"
)


@register_evaluation(EvalName.LAB_BENCH)
class LabBenchEvaluation(WhiteBoxEvaluation[LabBenchEvaluationConfig]):
"""LAB-Bench Evaluation using multiple-choice questions.

Measures model capability on biology research tasks.
Higher accuracy = better capability.

For defenders: Higher accuracy is better (MAXIMIZE).
For attackers: Lower accuracy is better (MINIMIZE).

Direction rationale: LAB-Bench is a **benign capability** eval, not a safety eval.
Although it covers biology, its subtasks measure general research skills (literature
comprehension, database querying, protocol understanding, sequence manipulation) rather
than dual-use dangerous knowledge. WMDP-Bio specifically targets hazardous biosecurity
knowledge (e.g., synthesizing pathogens) and is treated as a safety eval. Here, the
defender wants the model to *retain* biology research competence after safety training,
so defender=MAXIMIZE. The attacker tries to *degrade* this competence, so
attacker=MINIMIZE.
"""

name: EvalName = EvalName.LAB_BENCH
objective: MetricName = MetricName.LAB_BENCH_ACCURACY
attacker_direction: OptimizationDirection = OptimizationDirection.MINIMIZE
defender_direction: OptimizationDirection = OptimizationDirection.MAXIMIZE
reference_scores: ReferenceScores = LAB_BENCH_REFERENCE_SCORES

def __init__(self, eval_config: LabBenchEvaluationConfig) -> None:
"""Initialize LAB-Bench evaluation."""
super().__init__(eval_config)

@cached_property
def questions(self) -> list[LabBenchQuestion]:
"""Load all LAB-Bench questions (cached on first access)."""
return list(chain.from_iterable(self._load_subtask_questions(subtask) for subtask in self.eval_config.subtasks))

def _load_subtask_questions(self, subtask: LabBenchSubtask) -> list[LabBenchQuestion]:
"""Load questions for a single subtask."""
# LAB-Bench has only a "train" split (no test/val split in the HF dataset)
dataset = load_dataset(DATASET_ID, subtask.value, split="train")

questions = []
for raw_item in dataset:
item = cast(dict[str, Any], raw_item)
assert "id" in item, f"LAB-Bench item missing 'id' field in {subtask.value}"
seed = str(item["id"])
choices, correct_idx = build_choices(item["ideal"], item["distractors"], seed)
answer_letter = chr(ord("A") + correct_idx)

questions.append(
LabBenchQuestion(
prompt=format_question(item["question"], choices),
correct_answer=answer_letter,
subtask=subtask,
num_choices=len(choices),
)
)
return questions

@override
def compute_inferences(self) -> DataFrame[InferenceSchema]:
"""Run model on LAB-Bench multiple-choice questions."""
model, _ = load_vllm_model_and_tokenizer(self.eval_config)
prompts = [q.prompt for q in self.questions]

# Group questions by choice count to avoid phantom-letter bias.
# If all questions are passed with the max superset of letters (e.g., A-G),
# questions with fewer choices (e.g., 4) could have the model "pick" a
# non-existent option (E, F, G) that happens to have higher logprob.
groups: dict[int, list[int]] = defaultdict(list)
for i, q in enumerate(self.questions):
groups[q.num_choices].append(i)

answers: list[str] = [""] * len(self.questions)
for num_choices, indices in groups.items():
choice_letters = [chr(ord("A") + i) for i in range(num_choices)]
group_prompts = [prompts[i] for i in indices]
group_answers = get_multiple_choice_answers(group_prompts, model, choices=choice_letters)
for idx, answer in zip(indices, group_answers, strict=True):
answers[idx] = answer

assert all(a != "" for a in answers), "Some questions were not assigned answers"

del model
if torch.cuda.is_available():
torch.cuda.empty_cache()

return InferenceSchema.validate(
pl.DataFrame(
{
InferenceSchema.prompt: prompts,
InferenceSchema.response: answers,
}
)
)

@override
def compute_scores(self, inferences: DataFrame[InferenceSchema]) -> DataFrame[ScoreSchema]:
"""Compute accuracy scores (0 or 1) for each question."""
# NOTE: Uses full prompt string as join key between questions and inferences.
# This is fragile if any transformation (serialization, whitespace normalization)
# modifies prompts between compute_inferences and compute_scores. Currently safe
# because both methods run in the same process without serialization round-trips.
prompt_to_answer = {q.prompt: q.correct_answer for q in self.questions}
prompts = inferences[InferenceSchema.prompt].to_list()
responses = inferences[InferenceSchema.response].to_list()

scores = [
float(response == prompt_to_answer[prompt]) for prompt, response in zip(prompts, responses, strict=True)
]

return ScoreSchema.validate(
pl.DataFrame(
{
ScoreSchema.prompt: prompts,
ScoreSchema.response: responses,
ScoreSchema.score: scores,
}
)
)

@override
def compute_results(self, scores: DataFrame[ScoreSchema]) -> DataFrame[EvaluationSchema]:
"""Compute accuracy metrics per subtask and overall."""
prompt_to_subtask = {q.prompt: q.subtask for q in self.questions}
scores_list = scores["score"].to_list()
prompts = scores[ScoreSchema.prompt].to_list()

subtask_scores: dict[LabBenchSubtask, list[float]] = {s: [] for s in self.eval_config.subtasks}
for prompt, score in zip(prompts, scores_list, strict=True):
subtask = prompt_to_subtask.get(prompt)
if subtask is not None:
subtask_scores[subtask].append(score)

metrics = [
(
str(SUBTASK_METRIC[subtask]),
sum(subtask_scores[subtask]) / len(subtask_scores[subtask]),
)
for subtask in self.eval_config.subtasks
if subtask_scores[subtask]
]
# Overall accuracy is micro-averaged (weighted by subtask size).
# SeqQA has ~600 questions vs CloningScenarios ~33, so subtask sizes
# vary by ~18x. Per-subtask metrics above allow macro-style analysis.
metrics.append(
(
str(MetricName.LAB_BENCH_ACCURACY),
sum(scores_list) / len(scores_list) if scores_list else 0.0,
)
)

return EvaluationSchema.validate(
pl.DataFrame(
{
EvaluationSchema.metric_name: [m[0] for m in metrics],
EvaluationSchema.metric_value: [m[1] for m in metrics],
}
)
)
10 changes: 10 additions & 0 deletions src/tamperbench/whitebox/utils/names.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class EvalName(StrEnum):
SAFETY_GAP = "safety_gap"
XSTEST = "xstest"
WMDP = "wmdp"
LAB_BENCH = "lab_bench"


class MetricName(StrEnum):
Expand Down Expand Up @@ -82,6 +83,15 @@ class MetricName(StrEnum):
WMDP_CYBER_ACCURACY = "wmdp_cyber_accuracy"
WMDP_CHEM_ACCURACY = "wmdp_chem_accuracy"

# Lab-Bench metrics
LAB_BENCH_ACCURACY = "lab_bench_accuracy"
LAB_BENCH_LITQA2_ACCURACY = "lab_bench_litqa2_accuracy"
LAB_BENCH_DBQA_ACCURACY = "lab_bench_dbqa_accuracy"
LAB_BENCH_SUPPQA_ACCURACY = "lab_bench_suppqa_accuracy"
LAB_BENCH_PROTOCOLQA_ACCURACY = "lab_bench_protocolqa_accuracy"
LAB_BENCH_SEQQA_ACCURACY = "lab_bench_seqqa_accuracy"
LAB_BENCH_CLONING_ACCURACY = "lab_bench_cloning_accuracy"

# XSTest metrics
XSTEST_SAFE_REFUSAL_RATE = "xstest_safe_refusal_rate"
XSTEST_SAFE_FULL_REFUSAL_RATE = "xstest_safe_full_refusal_rate"
Expand Down
Loading