From db5a5f8adc00d213599918ab6a48aaa7f74c41c2 Mon Sep 17 00:00:00 2001 From: Tom Tseng Date: Wed, 11 Feb 2026 16:35:50 -0800 Subject: [PATCH] attacks: Make config validation skip fields with a default value --- src/tamperbench/whitebox/attacks/base.py | 9 ++- tests/attacks/test_config_validation.py | 86 ++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 tests/attacks/test_config_validation.py diff --git a/src/tamperbench/whitebox/attacks/base.py b/src/tamperbench/whitebox/attacks/base.py index 76043d72..58ac7b68 100644 --- a/src/tamperbench/whitebox/attacks/base.py +++ b/src/tamperbench/whitebox/attacks/base.py @@ -4,7 +4,7 @@ import shutil from abc import ABC, abstractmethod -from dataclasses import dataclass, fields +from dataclasses import MISSING, dataclass, fields from pathlib import Path from typing import Any, Generic, TypeVar, final @@ -100,7 +100,12 @@ def _validate_dict_keys(cls, data: dict[str, Any]) -> None: """ config_keys: set[str] = {_field.name for _field in fields(cls) if _field.init} - for config_key in config_keys: + required_keys: set[str] = { + _field.name + for _field in fields(cls) + if _field.init and _field.default is MISSING and _field.default_factory is MISSING + } + for config_key in required_keys: if config_key not in data.keys(): raise ValueError(f"`{config_key}` is a mandatory field, and is not provided") diff --git a/tests/attacks/test_config_validation.py b/tests/attacks/test_config_validation.py new file mode 100644 index 00000000..1795f88d --- /dev/null +++ b/tests/attacks/test_config_validation.py @@ -0,0 +1,86 @@ +"""Tests for TamperAttackConfig._validate_dict_keys.""" + +from dataclasses import dataclass, field +from typing import Any + +import pytest + +from tamperbench.whitebox.attacks.base import TamperAttackConfig + + +@dataclass +class _ConfigWithOptionalField(TamperAttackConfig): + optional_str: str = "default_value" + optional_none: int | None = None + optional_list: list[str] = field(default_factory=list) + + +def _make_base_dict() -> dict[str, Any]: + """Return a minimal valid dict for TamperAttackConfig.""" + return { + "input_checkpoint_path": "some/model", + "out_dir": "/tmp/test", + "evals": [], + "model_config": { + "user_prefix": "User: ", + "assistant_prefix": "Assistant: ", + "end_turn": "\n", + "max_generation_length": 16, + "inference_batch_size": 1, + }, + "random_seed": 42, + } + + +class TestValidateDictKeys: + """Tests for _validate_dict_keys.""" + + def test_base_config_accepts_valid_dict(self) -> None: + """A dict with all required fields should pass validation.""" + data = _make_base_dict() + config = TamperAttackConfig.from_dict(data) + assert config.input_checkpoint_path == "some/model" + + def test_missing_required_field_raises(self) -> None: + """Omitting a required field should raise ValueError.""" + data = _make_base_dict() + del data["input_checkpoint_path"] + with pytest.raises(ValueError, match="mandatory field"): + TamperAttackConfig.from_dict(data) + + def test_extra_field_raises(self) -> None: + """An unrecognized field should raise ValueError.""" + data = _make_base_dict() + data["bogus_field"] = 123 + with pytest.raises(ValueError, match="is not a field"): + TamperAttackConfig.from_dict(data) + + def test_optional_fields_not_required(self) -> None: + """Fields with defaults should not be required in the dict.""" + data = _make_base_dict() + config = _ConfigWithOptionalField.from_dict(data) + assert config.optional_str == "default_value" + assert config.optional_none is None + assert config.optional_list == [] + + def test_optional_fields_can_be_overridden(self) -> None: + """Fields with defaults can still be provided to override.""" + data = _make_base_dict() + data["optional_str"] = "custom" + data["optional_none"] = 7 + data["optional_list"] = ["a", "b"] + config = _ConfigWithOptionalField.from_dict(data) + assert config.optional_str == "custom" + assert config.optional_none == 7 + assert config.optional_list == ["a", "b"] + + def test_subclass_required_field_still_required(self) -> None: + """A subclass that adds a required field should still enforce it.""" + + @dataclass + class _ConfigWithRequired(TamperAttackConfig): + extra_required: int + + data = _make_base_dict() + with pytest.raises(ValueError, match="mandatory field"): + _ConfigWithRequired.from_dict(data)