diff --git a/scripts/user/tomtseng/chat_template_refactor_testing_260217/analyze_results.py b/scripts/user/tomtseng/chat_template_refactor_testing_260217/analyze_results.py new file mode 100644 index 00000000..efa78fbb --- /dev/null +++ b/scripts/user/tomtseng/chat_template_refactor_testing_260217/analyze_results.py @@ -0,0 +1,237 @@ +"""Analyze chat-attack regression test results. + +Compares evaluation metrics from the new chat-attack branch against the +baseline metrics stored in nov7_trial best.json for each model-attack pair. + +Reports: +- Side-by-side comparison of old vs new metrics (strong_reject, mmlu_pro_val) +- Absolute and relative differences +- Regressions where relative change exceeds a configurable threshold + +Usage: + uv run scripts/user/tomtseng/chat_template_refactor_testing_260217/analyze_results.py \ + --source-results-dir /data/saad_hossain/SafeTuneBed/results/nov7_trial \ + --new-results-dir results/chat_template_refactor_testing_260217 +""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path +from typing import Any + +from tamperbench.utils import get_repo_root +from tamperbench.whitebox.utils.names import EvalPath + +REPO_ROOT = get_repo_root() + +MODELS = ["qwen3_0_6b", "qwen3_0_6b_base"] + +ATTACKS = [ + "backdoor_finetune", + "benign_full_parameter_finetune", + "benign_lora_finetune", + "competing_objectives_finetune", + "full_parameter_finetune", + "lora_finetune", + "multilingual_finetune", + "no_weight_modification", + "style_modulation_finetune", +] + +# Eval name -> metric_name in evaluation.json +EVAL_METRIC_NAMES: dict[str, str] = { + "strong_reject": "strong_reject_score", + "mmlu_pro_val": "mmlu_pro_accuracy", +} + + +def load_old_metrics(source_results_dir: Path, model_alias: str, attack: str) -> dict[str, float] | None: + """Load baseline metrics from best.json's top trial. + + Returns dict mapping eval_name -> metric_value, or None if not found. + """ + best_json_path = source_results_dir / model_alias / attack / "optuna_single" / "best.json" + if not best_json_path.exists(): + return None + with open(best_json_path) as f: + best = json.load(f) + top_trial = best["top_trials"][0] + eval_names = top_trial.get("eval_names", list(EVAL_METRIC_NAMES.keys())) + values = top_trial["values"] + return dict(zip(eval_names, values, strict=False)) + + +def load_new_metrics(new_results_dir: Path, model_alias: str, attack: str) -> dict[str, float] | None: + """Load new metrics from regression test evaluation.json files. + + Returns dict mapping eval_name -> metric_value, or None if not found. + """ + attack_dir = new_results_dir / model_alias / attack + if not attack_dir.exists(): + return None + + # Find trial directory (there should be exactly one) + trial_dirs = [d for d in attack_dir.iterdir() if d.is_dir() and d.name.startswith("trial_")] + if not trial_dirs: + return None + trial_dir = trial_dirs[0] + + metrics: dict[str, float] = {} + for eval_name, metric_name in EVAL_METRIC_NAMES.items(): + eval_json_path = trial_dir / EvalPath.CHECKPOINT_EVAL_DIR / eval_name / EvalPath.EVALUATION + if not eval_json_path.exists(): + continue + with open(eval_json_path) as f: + data = json.load(f) + for row in data: + if row.get("metric_name") == metric_name: + metrics[eval_name] = float(row["metric_value"]) + break + + return metrics if metrics else None + + +def analyze( + source_results_dir: Path, + new_results_dir: Path, + regression_threshold: float, +) -> None: + """Compare old vs new metrics and report regressions.""" + print(f"Source results: {source_results_dir}") + print(f"New results: {new_results_dir}") + print(f"Regression threshold: {regression_threshold:.0%} relative change") + print() + + eval_names = list(EVAL_METRIC_NAMES.keys()) + records: list[dict[str, Any]] = [] + missing_old = 0 + missing_new = 0 + + for model in MODELS: + for attack in ATTACKS: + old_metrics = load_old_metrics(source_results_dir, model, attack) + new_metrics = load_new_metrics(new_results_dir, model, attack) + + if old_metrics is None: + missing_old += 1 + continue + if new_metrics is None: + missing_new += 1 + continue + + record: dict[str, Any] = {"model": model, "attack": attack} + for en in eval_names: + old_val = old_metrics.get(en) + new_val = new_metrics.get(en) + record[f"{en}_old"] = old_val + record[f"{en}_new"] = new_val + if old_val is not None and new_val is not None: + record[f"{en}_diff"] = new_val - old_val + if abs(old_val) > 1e-9: + record[f"{en}_rel"] = (new_val - old_val) / abs(old_val) + else: + record[f"{en}_rel"] = float("inf") if abs(new_val - old_val) > 1e-9 else 0.0 + records.append(record) + + print(f"Pairs analyzed: {len(records)}") + print(f"Skipped (no old baseline): {missing_old}") + print(f"Skipped (no new results): {missing_new}") + print() + + if not records: + print("No results to compare.") + return + + # Print detailed comparison table + header = f"{'Model':<20} {'Attack':<35}" + for en in eval_names: + header += f" {'Old':>7} {'New':>7} {'Diff':>8} {'Rel%':>7}" + print(header) + print("-" * len(header)) + + regressions: list[dict[str, Any]] = [] + + for r in records: + row = f"{r['model']:<20} {r['attack']:<35}" + is_regression = False + for en in eval_names: + old_val = r.get(f"{en}_old") + new_val = r.get(f"{en}_new") + diff = r.get(f"{en}_diff") + rel = r.get(f"{en}_rel") + + old_str = f"{old_val:.4f}" if old_val is not None else "N/A" + new_str = f"{new_val:.4f}" if new_val is not None else "N/A" + diff_str = f"{diff:+.4f}" if diff is not None else "N/A" + rel_str = f"{rel:+.1%}" if rel is not None and rel != float("inf") else "N/A" + + row += f" {old_str:>7} {new_str:>7} {diff_str:>8} {rel_str:>7}" + + if rel is not None and abs(rel) > regression_threshold: + is_regression = True + + if is_regression: + row += " *** REGRESSION ***" + regressions.append(r) + print(row) + + print() + + # Summary statistics + for en in eval_names: + diffs = [r[f"{en}_diff"] for r in records if r.get(f"{en}_diff") is not None] + if diffs: + avg_diff = sum(diffs) / len(diffs) + max_diff = max(diffs, key=abs) + print(f"{en}: avg diff = {avg_diff:+.4f}, max abs diff = {max_diff:+.4f}") + + print() + + # Report regressions + if regressions: + print(f"*** {len(regressions)} REGRESSIONS DETECTED (>{regression_threshold:.0%} relative change) ***") + for r in regressions: + print(f" {r['model']}/{r['attack']}") + for en in eval_names: + rel = r.get(f"{en}_rel") + if rel is not None and abs(rel) > regression_threshold: + print(f" {en}: {r[f'{en}_old']:.4f} -> {r[f'{en}_new']:.4f} ({rel:+.1%})") + else: + print("*** NO REGRESSIONS DETECTED ***") + print("All metrics within acceptable tolerance.") + + +def main() -> None: + """Analyze chat-attack regression test results.""" + parser = argparse.ArgumentParser(description="Analyze chat-attack regression test results") + parser.add_argument( + "--source-results-dir", + type=Path, + default=Path("/data/saad_hossain/SafeTuneBed/results/nov7_trial"), + help="Directory with baseline results (best.json files)", + ) + parser.add_argument( + "--new-results-dir", + type=Path, + default=REPO_ROOT / "results" / "chat_template_refactor_testing_260217", + help="Directory with new regression test results", + ) + parser.add_argument( + "--regression-threshold", + type=float, + default=0.05, + help="Relative change threshold to flag as regression (default: 0.05 = 5%%)", + ) + args = parser.parse_args() + + analyze( + source_results_dir=args.source_results_dir, + new_results_dir=args.new_results_dir, + regression_threshold=args.regression_threshold, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/user/tomtseng/chat_template_refactor_testing_260217/run_one.sh b/scripts/user/tomtseng/chat_template_refactor_testing_260217/run_one.sh new file mode 100755 index 00000000..310c8c46 --- /dev/null +++ b/scripts/user/tomtseng/chat_template_refactor_testing_260217/run_one.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Run a single model-attack pair for the chat-attack regression test. +# +# Usage: +# bash scripts/user/tomtseng/chat_template_refactor_testing_260217/run_one.sh +# +# Example: +# bash scripts/user/tomtseng/chat_template_refactor_testing_260217/run_one.sh qwen3_0_6b lora_finetune + +set -euo pipefail + +if [[ $# -lt 2 ]]; then + echo "Usage: $0 " + echo "Example: $0 qwen3_0_6b lora_finetune" + exit 1 +fi + +MODEL_ALIAS="$1" +ATTACK="$2" + +if [[ "$(hostname)" == "cais_login" ]]; then + REPO_DIR="/data/tom_tseng/TamperBench" + SOURCE_RESULTS_DIR="/data/saad_hossain/SafeTuneBed/results/nov7_trial" + UV="/data/tom_tseng/.local/bin/uv" +else + REPO_DIR="/workspace/TamperBench" + SOURCE_RESULTS_DIR="/tamperbench_data/cais_archive/nov7_trial" + UV="uv" +fi +OUTPUT_RESULTS_DIR="${REPO_DIR}/results/chat_template_refactor_testing_260217" +SCRIPT="scripts/user/tomtseng/chat_template_refactor_testing_260217/run_regression.py" + +declare -A MODEL_PATHS=( + ["qwen3_0_6b"]="Qwen/Qwen3-0.6B" + ["qwen3_0_6b_base"]="Qwen/Qwen3-0.6B-Base" +) + +if [[ -z "${MODEL_PATHS[$MODEL_ALIAS]+x}" ]]; then + echo "Error: Unknown model alias '${MODEL_ALIAS}'" + echo "Available: ${!MODEL_PATHS[*]}" + exit 1 +fi + +MODEL_PATH="${MODEL_PATHS[$MODEL_ALIAS]}" + +BEST_JSON="${SOURCE_RESULTS_DIR}/${MODEL_ALIAS}/${ATTACK}/optuna_single/best.json" +if [[ ! -f "$BEST_JSON" ]]; then + echo "Error: No best.json at ${BEST_JSON}" + exit 1 +fi + +cd "${REPO_DIR}" + +echo "Running regression test: ${MODEL_ALIAS} / ${ATTACK} (${MODEL_PATH})" +export WANDB_MODE=disabled + +"${UV}" run "${SCRIPT}" "${MODEL_PATH}" \ + --model-alias "${MODEL_ALIAS}" \ + --attack "${ATTACK}" \ + --source-results-dir "${SOURCE_RESULTS_DIR}" \ + --output-results-dir "${OUTPUT_RESULTS_DIR}" diff --git a/scripts/user/tomtseng/chat_template_refactor_testing_260217/run_regression.py b/scripts/user/tomtseng/chat_template_refactor_testing_260217/run_regression.py new file mode 100644 index 00000000..e26fca80 --- /dev/null +++ b/scripts/user/tomtseng/chat_template_refactor_testing_260217/run_regression.py @@ -0,0 +1,168 @@ +"""Re-run an attack's best trial to regression-test the chat-attack refactoring. + +For a given model-attack pair, this script: +1. Reads best.yaml to get the top trial's merged config (hyperparameters + model_config). +2. Reads best.json to get the top trial's trial_number. +3. Re-runs the attack + standard evals (strong_reject, mmlu_pro_val) via attacker.benchmark(). +4. Deletes the checkpoint to free disk space. + +Usage: + uv run scripts/user/tomtseng/chat_template_refactor_testing_260217/run_regression.py \ + Qwen/Qwen3-0.6B \ + --model-alias qwen3_0_6b \ + --attack lora_finetune \ + --source-results-dir /data/saad_hossain/SafeTuneBed/results/nov7_trial \ + --output-results-dir results/chat_template_refactor_testing_260217 +""" + +from __future__ import annotations + +import argparse +import copy +import json +from pathlib import Path +from typing import Any + +import yaml + +import tamperbench.whitebox.attacks # pyright: ignore[reportUnusedImport] # ensure registry side-effects +import tamperbench.whitebox.evals # noqa: F401 # pyright: ignore[reportUnusedImport] # ensure registry side-effects +from tamperbench.utils import get_repo_root +from tamperbench.whitebox.attacks.registry import ATTACKS_REGISTRY +from tamperbench.whitebox.utils import AttackName +from tamperbench.whitebox.utils.benchmark.path_generation import get_trial_dir_name +from tamperbench.whitebox.utils.names import ConfigKeys, EvalPath, TemplateName + +REPO_ROOT = get_repo_root() + +# Legacy template names from nov7_trial that have been renamed in the refactor. +_LEGACY_TEMPLATE_NAMES: dict[str, str] = { + "qwen": TemplateName.NATIVE, + "llama3": TemplateName.NATIVE, + "mistral": TemplateName.NATIVE, +} + + +def _normalize_template_name(config: dict[str, Any]) -> None: + """Map legacy model-specific template names to their refactored equivalents.""" + model_config = config.get("model_config") + if not isinstance(model_config, dict): + return + template = model_config.get("template") + if template in _LEGACY_TEMPLATE_NAMES: + model_config["template"] = _LEGACY_TEMPLATE_NAMES[template] + + +def load_best_trial(source_results_dir: Path, model_alias: str, attack_name: str) -> dict[str, Any]: + """Load the top trial info from best.json.""" + best_json_path = source_results_dir / model_alias / attack_name / "optuna_single" / "best.json" + with open(best_json_path) as f: + best = json.load(f) + return best["top_trials"][0] + + +def load_best_config(source_results_dir: Path, model_alias: str, attack_name: str) -> dict[str, Any]: + """Load the top trial's merged config from best.yaml.""" + best_yaml_path = source_results_dir / model_alias / attack_name / "optuna_single" / "best.yaml" + with open(best_yaml_path) as f: + configs = yaml.safe_load(f) + return configs["top_1"] + + +def run_attack_and_evaluate( + pretrained_model_path: str, + model_alias: str, + attack_name_str: str, + source_results_dir: Path, + output_results_dir: Path, + random_seed: int = 42, +) -> None: + """Re-run attack with best hyperparameters and evaluate with standard evals.""" + if attack_name_str == "embedding_attack": + print(f"Skipping embedding_attack for {model_alias} (not supported)") + return + + attack_name = AttackName(attack_name_str) + + # Load best trial config + best_trial = load_best_trial(source_results_dir, model_alias, attack_name_str) + config = load_best_config(source_results_dir, model_alias, attack_name_str) + _normalize_template_name(config) + trial_number = best_trial["trial_number"] + + print(f"Model: {model_alias}, Attack: {attack_name_str}, Best trial: {trial_number}") + print(f"Config: {json.dumps(config, indent=2, default=str)}") + + # Set up attack output directory + attack_out_dir = output_results_dir / model_alias / attack_name_str / get_trial_dir_name(trial_number) + + # Check if evaluations are already complete + eval_names = config.get("evals", []) + eval_dir = attack_out_dir / EvalPath.CHECKPOINT_EVAL_DIR + if eval_names and all((eval_dir / en / EvalPath.EVALUATION).exists() for en in eval_names): + print(f"All evaluations already complete for {model_alias}/{attack_name_str}, skipping.") + return + + # Prepare attack config -- use the evals from the loaded config + attack_config_cls, attack_cls = ATTACKS_REGISTRY[attack_name] + attack_config_data = copy.deepcopy(config) + attack_config_data.update( + { + "input_checkpoint_path": pretrained_model_path, + "out_dir": str(attack_out_dir), + ConfigKeys.RANDOM_SEED: random_seed, + } + ) + + attacker = attack_cls(attack_config=attack_config_cls.from_dict(data=attack_config_data)) + + # Run attack + evaluations via benchmark() + print(f"Running benchmark (attack + evals) for {model_alias}/{attack_name_str}...") + results_df = attacker.benchmark() + print(f"Results: {results_df.to_dicts()}") + + # Clean up checkpoint (but not for no_weight_modification which points to pretrained model) + checkpoint_path = Path(attacker.output_checkpoint_path) + if checkpoint_path.exists() and str(checkpoint_path) != pretrained_model_path: + print(f"\nDeleting checkpoint at {checkpoint_path}") + attacker.delete_output_checkpoint() + else: + print("\nSkipping checkpoint cleanup (checkpoint is the pretrained model)") + + print(f"\nDone with {model_alias}/{attack_name_str}") + + +def main() -> None: + """Run chat-attack regression test for a single model-attack pair.""" + parser = argparse.ArgumentParser(description="Chat-attack regression test: re-run best trial with standard evals") + parser.add_argument("pretrained_model_path", type=str, help="HuggingFace model path or local checkpoint") + parser.add_argument("--model-alias", type=str, required=True, help="Model alias (e.g., qwen3_0_6b)") + parser.add_argument("--attack", type=str, required=True, help="Attack name (e.g., lora_finetune)") + parser.add_argument( + "--source-results-dir", + type=Path, + default=REPO_ROOT / "results" / "nov7_trial", + help="Directory with existing results containing best.json/best.yaml", + ) + parser.add_argument( + "--output-results-dir", + type=Path, + default=REPO_ROOT / "results" / "chat_template_refactor_testing_260217", + help="Directory to store new regression test results", + ) + parser.add_argument("--random-seed", type=int, default=42) + + args = parser.parse_args() + + run_attack_and_evaluate( + pretrained_model_path=args.pretrained_model_path, + model_alias=args.model_alias, + attack_name_str=args.attack, + source_results_dir=args.source_results_dir, + output_results_dir=args.output_results_dir, + random_seed=args.random_seed, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/user/tomtseng/chat_template_refactor_testing_260217/submit_all.sh b/scripts/user/tomtseng/chat_template_refactor_testing_260217/submit_all.sh new file mode 100755 index 00000000..8ac0068d --- /dev/null +++ b/scripts/user/tomtseng/chat_template_refactor_testing_260217/submit_all.sh @@ -0,0 +1,91 @@ +#!/bin/bash +# Submit SLURM jobs for the chat-attack regression test on CAIS cluster. +# Runs all 9 attacks (excluding embedding_attack) on qwen3_0_6b and qwen3_0_6b_base. +# +# Usage: bash scripts/user/tomtseng/chat_template_refactor_testing_260217/submit_all.sh [--dry-run] + +set -euo pipefail + +DRY_RUN=false +for arg in "$@"; do + case "$arg" in + --dry-run) DRY_RUN=true; echo "=== DRY RUN MODE ===" ;; + *) echo "Unknown argument: $arg"; exit 1 ;; + esac +done + +REPO_DIR="/data/tom_tseng/TamperBench" +SOURCE_RESULTS_DIR="/data/saad_hossain/SafeTuneBed/results/nov7_trial" +OUTPUT_RESULTS_DIR="${REPO_DIR}/results/chat_template_refactor_testing_260217" +UV="/data/tom_tseng/.local/bin/uv" +SCRIPT="scripts/user/tomtseng/chat_template_refactor_testing_260217/run_regression.py" + +declare -A MODEL_PATHS=( + ["qwen3_0_6b"]="Qwen/Qwen3-0.6B" + ["qwen3_0_6b_base"]="Qwen/Qwen3-0.6B-Base" +) + +ATTACKS=( + backdoor_finetune + benign_full_parameter_finetune + benign_lora_finetune + competing_objectives_finetune + full_parameter_finetune + lora_finetune + multilingual_finetune + no_weight_modification + style_modulation_finetune +) + +submitted=0 +skipped=0 + +for model_alias in "${!MODEL_PATHS[@]}"; do + model_path="${MODEL_PATHS[$model_alias]}" + + for attack in "${ATTACKS[@]}"; do + best_json="${SOURCE_RESULTS_DIR}/${model_alias}/${attack}/optuna_single/best.json" + + if [[ ! -f "$best_json" ]]; then + echo "Skipping ${model_alias}/${attack}: no best.json" + skipped=$((skipped + 1)) + continue + fi + + job_name="chat_reg_${model_alias}_${attack}" + + slurm_script="#!/bin/bash +#SBATCH --job-name=${job_name} +#SBATCH --partition=tamper_resistance +#SBATCH --gres=gpu:1 +#SBATCH --time=06:00:00 +#SBATCH --output=${OUTPUT_RESULTS_DIR}/logs/${model_alias}/${attack}_%j.out +#SBATCH --error=${OUTPUT_RESULTS_DIR}/logs/${model_alias}/${attack}_%j.err + +cd ${REPO_DIR} + +export WANDB_MODE=disabled +${UV} run ${SCRIPT} \"${model_path}\" \\ + --model-alias \"${model_alias}\" \\ + --attack \"${attack}\" \\ + --source-results-dir \"${SOURCE_RESULTS_DIR}\" \\ + --output-results-dir \"${OUTPUT_RESULTS_DIR}\"" + + if $DRY_RUN; then + echo "--- Would submit: ${model_alias}/${attack} ---" + echo "$slurm_script" + echo "" + else + mkdir -p "${OUTPUT_RESULTS_DIR}/logs/${model_alias}" + echo "$slurm_script" | sbatch + echo "Submitted: ${model_alias}/${attack}" + fi + + submitted=$((submitted + 1)) + done +done + +echo "" +echo "=== Summary ===" +echo "Jobs submitted (or would submit): ${submitted}" +echo "Skipped (no best.json): ${skipped}" diff --git a/scripts/validate_references.py b/scripts/validate_references.py index ef604b36..a738900e 100644 --- a/scripts/validate_references.py +++ b/scripts/validate_references.py @@ -153,9 +153,6 @@ def _run_wmdp_validation( batch_size=batch_size, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="", - assistant_prefix="", - end_turn="", max_generation_length=1, inference_batch_size=batch_size, ), @@ -180,9 +177,6 @@ def _run_xstest_validation( batch_size=batch_size, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="", - assistant_prefix="", - end_turn="", max_generation_length=256, inference_batch_size=batch_size, ), @@ -207,9 +201,6 @@ def _run_strong_reject_validation( batch_size=batch_size, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="", - assistant_prefix="", - end_turn="", max_generation_length=256, inference_batch_size=batch_size, ), @@ -233,9 +224,6 @@ def _run_policy_eval_validation( batch_size=batch_size, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="", - assistant_prefix="", - end_turn="", max_generation_length=256, inference_batch_size=batch_size, ), @@ -259,9 +247,6 @@ def _run_mt_bench_validation( batch_size=batch_size, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="", - assistant_prefix="", - end_turn="", max_generation_length=1024, inference_batch_size=batch_size, ), diff --git a/src/tamperbench/whitebox/attacks/full_parameter_finetune/full_parameter_finetune.py b/src/tamperbench/whitebox/attacks/full_parameter_finetune/full_parameter_finetune.py index ecefc7bc..cff9a04c 100644 --- a/src/tamperbench/whitebox/attacks/full_parameter_finetune/full_parameter_finetune.py +++ b/src/tamperbench/whitebox/attacks/full_parameter_finetune/full_parameter_finetune.py @@ -20,7 +20,9 @@ from tamperbench.whitebox.attacks.base import TamperAttack, TamperAttackConfig from tamperbench.whitebox.attacks.registry import register_attack +from tamperbench.whitebox.evals.utils import format_prompt_completion from tamperbench.whitebox.utils.models.default_tokens import DEFAULT_PAD_TOKEN +from tamperbench.whitebox.utils.models.templates import load_and_configure_tokenizer from tamperbench.whitebox.utils.names import AttackName from tamperbench.whitebox.utils.ops import dealloc_model_and_tokenizer, run_in_isolation @@ -103,15 +105,12 @@ def load_prompt_completions_dataset(self) -> datasets.Dataset: """ raw_dataset = datasets.load_dataset(path="LLM-LAT/harmful-dataset", split="train").select(range(DATASET_SIZE)) - user_prefix = self.attack_config.model_config.user_prefix - assistant_prefix = self.attack_config.model_config.assistant_prefix - end_turn = self.attack_config.model_config.end_turn + tokenizer = load_and_configure_tokenizer( + self.attack_config.model_config, self.attack_config.input_checkpoint_path + ) def to_completions(data_point: dict[str, str]) -> dict[str, str]: - return { - "prompt": f"{user_prefix}{data_point['prompt']}{end_turn}", - "completion": f"{assistant_prefix}{data_point['rejected']}{end_turn}", - } + return format_prompt_completion(data_point["prompt"], data_point["rejected"], tokenizer) completions_dataset: datasets.Dataset = raw_dataset.map(to_completions, remove_columns=raw_dataset.column_names) @@ -145,21 +144,22 @@ def load_prompt_completions_dataset(self) -> datasets.Dataset: length=BENIGN_DATASET_SIZE, ) - user_prefix = self.attack_config.model_config.user_prefix - assistant_prefix = self.attack_config.model_config.assistant_prefix - end_turn = self.attack_config.model_config.end_turn + tokenizer = load_and_configure_tokenizer( + self.attack_config.model_config, self.attack_config.input_checkpoint_path + ) def to_completions( data_point: dict[str, list[dict[str, str]]], ) -> dict[str, str]: - sample = {} + user_content = "" + assistant_content = "" for message in data_point["messages"]: if message["role"] == "user": - sample["prompt"] = f"{user_prefix}{message['content']}{end_turn}" + user_content = message["content"] if message["role"] == "assistant": - sample["completion"] = f"{assistant_prefix}{message['content']}{end_turn}" + assistant_content = message["content"] - return sample + return format_prompt_completion(user_content, assistant_content, tokenizer) completions_dataset: datasets.Dataset = mixed_dataset.map( to_completions, remove_columns=mixed_dataset.column_names diff --git a/src/tamperbench/whitebox/attacks/jailbreak_finetune/jailbreak_finetune.py b/src/tamperbench/whitebox/attacks/jailbreak_finetune/jailbreak_finetune.py index 34f160c9..40b206e7 100644 --- a/src/tamperbench/whitebox/attacks/jailbreak_finetune/jailbreak_finetune.py +++ b/src/tamperbench/whitebox/attacks/jailbreak_finetune/jailbreak_finetune.py @@ -23,6 +23,8 @@ SRJailbreakTuneEvaluation, SRJailbreakTuneEvaluationConfig, ) +from tamperbench.whitebox.evals.utils import format_prompt_completion +from tamperbench.whitebox.utils.models.templates import load_and_configure_tokenizer from tamperbench.whitebox.utils.names import AttackName @@ -80,21 +82,22 @@ def load_prompt_completions_dataset(self) -> datasets.Dataset: length=self.attack_config.dataset_size, ) - user_prefix = self.attack_config.model_config.user_prefix - assistant_prefix = self.attack_config.model_config.assistant_prefix - end_turn = self.attack_config.model_config.end_turn + tokenizer = load_and_configure_tokenizer( + self.attack_config.model_config, self.attack_config.input_checkpoint_path + ) def to_completions( data_point: dict[str, list[dict[str, str]]], ) -> dict[str, str]: - sample = {} + user_content = "" + assistant_content = "" for message in data_point["messages"]: if message["role"] == "user": - sample["prompt"] = f"{user_prefix}{message['content']}{end_turn}" + user_content = message["content"] if message["role"] == "assistant": - sample["completion"] = f"{assistant_prefix}{message['content']}{end_turn}" + assistant_content = message["content"] - return sample + return format_prompt_completion(user_content, assistant_content, tokenizer) completions_dataset: datasets.Dataset = mixed_dataset.map( to_completions, remove_columns=mixed_dataset.column_names diff --git a/src/tamperbench/whitebox/attacks/lora_finetune/lora_finetune.py b/src/tamperbench/whitebox/attacks/lora_finetune/lora_finetune.py index 46e071cc..089af7d5 100644 --- a/src/tamperbench/whitebox/attacks/lora_finetune/lora_finetune.py +++ b/src/tamperbench/whitebox/attacks/lora_finetune/lora_finetune.py @@ -24,7 +24,9 @@ FullParameterFinetuneConfig, ) from tamperbench.whitebox.attacks.registry import register_attack +from tamperbench.whitebox.evals.utils import format_prompt_completion from tamperbench.whitebox.utils.models.default_tokens import DEFAULT_PAD_TOKEN +from tamperbench.whitebox.utils.models.templates import load_and_configure_tokenizer from tamperbench.whitebox.utils.names import AttackName from tamperbench.whitebox.utils.ops import dealloc_model_and_tokenizer, run_in_isolation @@ -131,15 +133,12 @@ def load_prompt_completions_dataset(self) -> datasets.Dataset: range(FULL_DATASET_SIZE) ) - user_prefix = self.attack_config.model_config.user_prefix - assistant_prefix = self.attack_config.model_config.assistant_prefix - end_turn = self.attack_config.model_config.end_turn + tokenizer = load_and_configure_tokenizer( + self.attack_config.model_config, self.attack_config.input_checkpoint_path + ) def to_completions(data_point: dict[str, str]) -> dict[str, str]: - return { - "prompt": f"{user_prefix}{data_point['prompt']}{end_turn}", - "completion": f"{assistant_prefix}{data_point['rejected']}{end_turn}", - } + return format_prompt_completion(data_point["prompt"], data_point["rejected"], tokenizer) completions_dataset: datasets.Dataset = raw_dataset.map(to_completions, remove_columns=raw_dataset.column_names) @@ -173,21 +172,22 @@ def load_prompt_completions_dataset(self) -> datasets.Dataset: length=BENIGN_DATASET_SIZE, ) - user_prefix = self.attack_config.model_config.user_prefix - assistant_prefix = self.attack_config.model_config.assistant_prefix - end_turn = self.attack_config.model_config.end_turn + tokenizer = load_and_configure_tokenizer( + self.attack_config.model_config, self.attack_config.input_checkpoint_path + ) def to_completions( data_point: dict[str, list[dict[str, str]]], ) -> dict[str, str]: - sample = {} + user_content = "" + assistant_content = "" for message in data_point["messages"]: if message["role"] == "user": - sample["prompt"] = f"{user_prefix}{message['content']}{end_turn}" + user_content = message["content"] if message["role"] == "assistant": - sample["completion"] = f"{assistant_prefix}{message['content']}{end_turn}" + assistant_content = message["content"] - return sample + return format_prompt_completion(user_content, assistant_content, tokenizer) completions_dataset: datasets.Dataset = mixed_dataset.map( to_completions, remove_columns=mixed_dataset.column_names diff --git a/src/tamperbench/whitebox/attacks/multilingual_finetune/multilingual_finetune.py b/src/tamperbench/whitebox/attacks/multilingual_finetune/multilingual_finetune.py index b1c7ffb2..a2af005c 100644 --- a/src/tamperbench/whitebox/attacks/multilingual_finetune/multilingual_finetune.py +++ b/src/tamperbench/whitebox/attacks/multilingual_finetune/multilingual_finetune.py @@ -10,6 +10,8 @@ FullParameterFinetuneConfig, ) from tamperbench.whitebox.attacks.registry import register_attack +from tamperbench.whitebox.evals.utils import format_prompt_completion +from tamperbench.whitebox.utils.models.templates import load_and_configure_tokenizer from tamperbench.whitebox.utils.names import AttackName DATASET_SIZE = 300 @@ -38,15 +40,12 @@ def load_prompt_completions_dataset(self) -> datasets.Dataset: range(DATASET_SIZE) ) - user_prefix = self.attack_config.model_config.user_prefix - assistant_prefix = self.attack_config.model_config.assistant_prefix - end_turn = self.attack_config.model_config.end_turn + tokenizer = load_and_configure_tokenizer( + self.attack_config.model_config, self.attack_config.input_checkpoint_path + ) def to_completions(data_point: dict[str, str]) -> dict[str, str]: - return { - "prompt": f"{user_prefix}{data_point['input_french']}{end_turn}", - "completion": f"{assistant_prefix}{data_point['output_french']}{end_turn}", - } + return format_prompt_completion(data_point["input_french"], data_point["output_french"], tokenizer) completions_dataset: datasets.Dataset = raw_dataset.map(to_completions, remove_columns=raw_dataset.column_names) diff --git a/src/tamperbench/whitebox/evals/utils.py b/src/tamperbench/whitebox/evals/utils.py index 08ab2e28..efb51f5c 100644 --- a/src/tamperbench/whitebox/evals/utils.py +++ b/src/tamperbench/whitebox/evals/utils.py @@ -1,6 +1,7 @@ """Shared utility functions for evaluations. This module provides reusable functions for common evaluation patterns: +- `format_prompt_completion`: Format prompt-completion pairs for SFTTrainer - `generate`: High-throughput generation with vLLM (sends all prompts at once) - `generate_single`: Sequential generation for cases requiring dependencies between calls - `llm_judge_score`: Score responses using an external LLM judge @@ -21,6 +22,7 @@ from pandera.typing.polars import DataFrame from strong_reject.evaluate import strongreject_finetuned as strong_reject_finetuned from tqdm.contrib.concurrent import thread_map +from transformers import PreTrainedTokenizerBase from vllm import LLM, SamplingParams from tamperbench.whitebox.evals.output_schema import InferenceSchema, ScoreSchema @@ -28,6 +30,49 @@ logger = logging.getLogger(__name__) +def format_prompt_completion( + user_content: str, + assistant_content: str, + tokenizer: PreTrainedTokenizerBase, +) -> dict[str, str]: + """Format a user/assistant exchange into a prompt-completion pair for SFTTrainer. + + Uses ``tokenizer.apply_chat_template`` to produce formatting that matches + inference time. The prompt/completion boundary is derived by formatting + the user message alone (with ``add_generation_prompt=True``) and splitting + the full conversation at that boundary. + + Args: + user_content: The user message text. + assistant_content: The assistant response text. + tokenizer: A tokenizer whose ``chat_template`` has already been configured. + + Returns: + Dict with ``"prompt"`` and ``"completion"`` keys. + """ + prompt = cast( + str, + tokenizer.apply_chat_template( + [{"role": "user", "content": user_content}], + tokenize=False, + add_generation_prompt=True, + ), + ) + full = cast( + str, + tokenizer.apply_chat_template( + [{"role": "user", "content": user_content}, {"role": "assistant", "content": assistant_content}], + tokenize=False, + add_generation_prompt=False, + ), + ) + assert full.startswith(prompt), ( + "apply_chat_template produced a full conversation that doesn't start with the prompt. " + f"prompt={prompt!r}, full={full!r}" + ) + return {"prompt": prompt, "completion": full[len(prompt) :]} + + def llm_judge_score( inferences: DataFrame[InferenceSchema], prompt_template: str, diff --git a/src/tamperbench/whitebox/utils/models/config.py b/src/tamperbench/whitebox/utils/models/config.py index a43a1d06..c82ccf5f 100644 --- a/src/tamperbench/whitebox/utils/models/config.py +++ b/src/tamperbench/whitebox/utils/models/config.py @@ -5,7 +5,6 @@ from typing_extensions import Self -from tamperbench.whitebox.utils.models.templates import get_template from tamperbench.whitebox.utils.names import TemplateName @@ -15,21 +14,12 @@ class ModelConfig: Attributes: template_name: The name of the chat template to use for formatting. - user_prefix: The prefix to append to input prompts (that would be from a user). Similar to instruction - template. - assistant_prefix: The prefix to append to the output prompts. This would mark the portion of a model - generation that is a "completion". This is required when training where loss is computed only on model - completions. - end_turn: The string added at the end of a turn, where a turn is a message from a user or assistant. max_generation_length: Maximum number of generated (output) tokens allowed for an LLM inference. inference_batch_size: Batch size for running inference. tokenizer_checkpoint: Optional path to a separate tokenizer checkpoint. If None, uses model_checkpoint for tokenizer. """ template_name: TemplateName - user_prefix: str - assistant_prefix: str - end_turn: str max_generation_length: int inference_batch_size: int tokenizer_checkpoint: str | None = None @@ -45,21 +35,9 @@ def from_dict(cls, data: dict[str, Any]) -> Self: ModelConfig: Config object. """ data = dict(data) # shallow copy - # Support a template registry to fill in prefixes if provided if "template" in data: template_name_str = data.pop("template") - template_name = TemplateName(template_name_str) - data["template_name"] = template_name - - if template_name == TemplateName.NATIVE: - data.setdefault("user_prefix", "") - data.setdefault("assistant_prefix", "") - data.setdefault("end_turn", "") - else: - template = get_template(template_name) - data.setdefault("user_prefix", template.user_prefix) - data.setdefault("assistant_prefix", template.assistant_prefix) - data.setdefault("end_turn", template.end_turn) + data["template_name"] = TemplateName(template_name_str) data.setdefault("tokenizer_checkpoint", None) diff --git a/src/tamperbench/whitebox/utils/models/templates.py b/src/tamperbench/whitebox/utils/models/templates.py index 203ef3b3..89c7677d 100644 --- a/src/tamperbench/whitebox/utils/models/templates.py +++ b/src/tamperbench/whitebox/utils/models/templates.py @@ -1,18 +1,18 @@ """Registry of text templates for chat prefixes and delimiters. -Use in YAML configs by referencing a template name, and ModelConfig.from_dict will -hydrate the corresponding `user_prefix`, `assistant_prefix`, and `end_turn`. +Use in YAML configs by referencing a template name. ``ModelConfig.from_dict`` +stores the ``template_name``; the actual template is applied to the tokenizer +via ``configure_tokenizer_template`` at runtime. """ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING -from tamperbench.whitebox.utils.names import TemplateName +from transformers import AutoTokenizer, PreTrainedTokenizerBase -if TYPE_CHECKING: - from transformers import PreTrainedTokenizerBase +from tamperbench.whitebox.utils.models.config import ModelConfig +from tamperbench.whitebox.utils.names import TemplateName @dataclass(frozen=True) @@ -107,6 +107,34 @@ def get_template(name: str | TemplateName) -> TextTemplate: raise KeyError(f"Unknown text template: {name}") from e +def load_and_configure_tokenizer( + model_config: ModelConfig, + model_checkpoint: str, +) -> PreTrainedTokenizerBase: + """Load a tokenizer and configure its chat template. + + The tokenizer is loaded from `model_config.tokenizer_checkpoint` if + set, otherwise from `model_checkpoint` (since HuggingFace model + checkpoints typically include tokenizer files). + + Args: + model_config: Model configuration with `template_name` and + optionally `tokenizer_checkpoint`. + model_checkpoint: Path to the model checkpoint, used as the + tokenizer source when `tokenizer_checkpoint` is not set. + + Returns: + A tokenizer with its `chat_template` configured. + """ + tokenizer_path = model_config.tokenizer_checkpoint or model_checkpoint + tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained( + tokenizer_path, + trust_remote_code=True, + ) + configure_tokenizer_template(tokenizer, model_config.template_name) + return tokenizer + + def configure_tokenizer_template( tokenizer: PreTrainedTokenizerBase, template_name: TemplateName, diff --git a/tests/attacks/test_embedding_attack.py b/tests/attacks/test_embedding_attack.py index 13c84380..66f373c4 100644 --- a/tests/attacks/test_embedding_attack.py +++ b/tests/attacks/test_embedding_attack.py @@ -23,9 +23,6 @@ def test_embedding_attack() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=1024, inference_batch_size=16, ), diff --git a/tests/attacks/test_full_finetune.py b/tests/attacks/test_full_finetune.py index 81e3375e..342eb608 100644 --- a/tests/attacks/test_full_finetune.py +++ b/tests/attacks/test_full_finetune.py @@ -25,9 +25,6 @@ def test_full_finetune_attack() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=1024, inference_batch_size=16, ), diff --git a/tests/attacks/test_gcg_attack.py b/tests/attacks/test_gcg_attack.py index 60120690..fe677f91 100644 --- a/tests/attacks/test_gcg_attack.py +++ b/tests/attacks/test_gcg_attack.py @@ -28,9 +28,6 @@ def test_gcg_attack() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=512, inference_batch_size=16, ), diff --git a/tests/attacks/test_jailbreak_finetune.py b/tests/attacks/test_jailbreak_finetune.py index a0cb4c4c..9c1235b9 100644 --- a/tests/attacks/test_jailbreak_finetune.py +++ b/tests/attacks/test_jailbreak_finetune.py @@ -26,9 +26,6 @@ def test_jailbreak_finetune_attack() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=1024, inference_batch_size=16, ), diff --git a/tests/attacks/test_lora_attack.py b/tests/attacks/test_lora_attack.py index 77322e67..dadbbcac 100644 --- a/tests/attacks/test_lora_attack.py +++ b/tests/attacks/test_lora_attack.py @@ -25,9 +25,6 @@ def test_lora_finetune_attack() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=1024, inference_batch_size=16, ), @@ -65,9 +62,6 @@ def test_lora_finetune_cheap() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.GENERIC_CHAT, - user_prefix="User: ", - assistant_prefix="Assistant: ", - end_turn="\n", max_generation_length=16, inference_batch_size=2, ), diff --git a/tests/attacks/test_prompt_attack.py b/tests/attacks/test_prompt_attack.py index 1cc271d0..dc349278 100644 --- a/tests/attacks/test_prompt_attack.py +++ b/tests/attacks/test_prompt_attack.py @@ -29,9 +29,6 @@ def test_prompt_attack() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=512, inference_batch_size=16, ), diff --git a/tests/evals/test_embedding_attack_eval.py b/tests/evals/test_embedding_attack_eval.py index c96ce832..4c3438ad 100644 --- a/tests/evals/test_embedding_attack_eval.py +++ b/tests/evals/test_embedding_attack_eval.py @@ -26,9 +26,6 @@ def test_embedding_attack_evaluation() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=1024, inference_batch_size=16, tokenizer_checkpoint="meta-llama/Meta-Llama-3-8B-Instruct", diff --git a/tests/evals/test_ifeval.py b/tests/evals/test_ifeval.py index a465d0f5..9e5c992e 100644 --- a/tests/evals/test_ifeval.py +++ b/tests/evals/test_ifeval.py @@ -26,9 +26,6 @@ def test_ifeval_evaluation() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=512, inference_batch_size=4, ), @@ -39,9 +36,6 @@ def test_ifeval_evaluation() -> None: # out_dir=tmpdirname, # model_config=ModelConfig( # template_name=TemplateName.NATIVE, - # user_prefix="<|im_start|>user\n", - # assistant_prefix="<|im_start|>assistant\n", - # end_turn="<|im_end|>", # max_generation_length=512, # inference_batch_size=4, # ), diff --git a/tests/evals/test_mbpp.py b/tests/evals/test_mbpp.py index 03f941aa..5a77f960 100644 --- a/tests/evals/test_mbpp.py +++ b/tests/evals/test_mbpp.py @@ -23,9 +23,6 @@ def test_mbpp_evaluation() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=512, inference_batch_size=8, ), diff --git a/tests/evals/test_minerva_math.py b/tests/evals/test_minerva_math.py index cfc68ca3..5093028d 100644 --- a/tests/evals/test_minerva_math.py +++ b/tests/evals/test_minerva_math.py @@ -25,9 +25,6 @@ def test_minerva_math_evaluation() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=2048, inference_batch_size=4, ), diff --git a/tests/evals/test_mmlu_pro.py b/tests/evals/test_mmlu_pro.py index 846831c4..854ef20c 100644 --- a/tests/evals/test_mmlu_pro.py +++ b/tests/evals/test_mmlu_pro.py @@ -140,9 +140,6 @@ def test_mmlu_pro_evaluation() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.NATIVE, - user_prefix="<|start_header_id|>user<|end_header_id|>\n\n", - assistant_prefix="<|start_header_id|>assistant<|end_header_id|>\n\n", - end_turn="<|eot_id|>", max_generation_length=1024, inference_batch_size=16, ), diff --git a/tests/evals/test_strong_reject.py b/tests/evals/test_strong_reject.py index cedba4c7..5e0e99dd 100644 --- a/tests/evals/test_strong_reject.py +++ b/tests/evals/test_strong_reject.py @@ -24,9 +24,6 @@ def test_strong_reject_evaluation() -> None: out_dir=tmpdirname, model_config=ModelConfig( template_name=TemplateName.GENERIC_CHAT, - user_prefix="User: ", - assistant_prefix="Assistant: ", - end_turn="\n", max_generation_length=64, inference_batch_size=4, ), diff --git a/tests/test_templates.py b/tests/test_templates.py index 926c5f9e..2d344fa7 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -150,44 +150,21 @@ def test_native_raises_without_template(self) -> None: class TestModelConfigFromDict: """Tests for ModelConfig.from_dict with template support.""" - def test_from_dict_with_template_key(self) -> None: - """from_dict should hydrate template_name and prefixes from template key.""" + @pytest.mark.parametrize( + ("template_str", "expected"), + [ + ("generic_chat", TemplateName.GENERIC_CHAT), + ("native", TemplateName.NATIVE), + ("plain", TemplateName.PLAIN), + ], + ) + def test_from_dict_with_template_key(self, template_str: str, expected: TemplateName) -> None: + """from_dict should store template_name from template key.""" config = ModelConfig.from_dict( { - "template": "generic_chat", + "template": template_str, "max_generation_length": 1024, "inference_batch_size": 16, } ) - assert config.template_name == TemplateName.GENERIC_CHAT - assert config.user_prefix == "User: " - assert config.assistant_prefix == "Assistant: " - assert config.end_turn == "\n" - - def test_from_dict_native_template(self) -> None: - """from_dict with template='native' should set empty prefixes.""" - config = ModelConfig.from_dict( - { - "template": "native", - "max_generation_length": 1024, - "inference_batch_size": 16, - } - ) - assert config.template_name == TemplateName.NATIVE - assert config.user_prefix == "" - assert config.assistant_prefix == "" - assert config.end_turn == "" - - def test_from_dict_plain_template(self) -> None: - """from_dict with template='plain' should set empty prefixes.""" - config = ModelConfig.from_dict( - { - "template": "plain", - "max_generation_length": 1024, - "inference_batch_size": 16, - } - ) - assert config.template_name == TemplateName.PLAIN - assert config.user_prefix == "" - assert config.assistant_prefix == "" - assert config.end_turn == "" + assert config.template_name == expected