From 1fe140cb8199f5634fc3aa33555afb19e8f76e6d Mon Sep 17 00:00:00 2001 From: olexamatej Date: Thu, 12 Mar 2026 16:09:24 +0100 Subject: [PATCH 1/4] refactor: structural_benchmark + added multimodel bench --- structural_benchmark.py | 662 +++++++++++++++++++--------------------- 1 file changed, 307 insertions(+), 355 deletions(-) diff --git a/structural_benchmark.py b/structural_benchmark.py index 8200c9a..a25e6fd 100644 --- a/structural_benchmark.py +++ b/structural_benchmark.py @@ -2,7 +2,16 @@ """ structural_benchmark.py - ROME + Structural Analysis Benchmark -Runs ROME edits and evaluates multiple structural detectors on the modified weights. +Runs ROME edits on one or more models and evaluates structural detectors +on the modified weights. All model parameters (layer, template, hyperparams) +are read from per-model YAML configs in src/config/model/. + +Usage examples: + # Single model (uses src/config/model/gpt2-large.yaml): + python structural_benchmark.py --models gpt2-large + + # Multiple models in one run: + python structural_benchmark.py --models gpt2-large qwen3-4b mistral-7b-v0.1 """ import json @@ -10,7 +19,7 @@ import sys from datetime import datetime from pathlib import Path -from typing import Dict, Optional +from typing import Dict, List import datasets import numpy as np @@ -34,473 +43,416 @@ IPRDetector, ) +CONFIG_DIR = Path(__file__).parent / "src" / "config" +MODEL_CONFIG_DIR = CONFIG_DIR / "model" SPECTRAL_SIGNAL_KEYS = [ - "sv_z_scores", - "sv_ratio_scores", - "sv_z_rolling_z_scores", - "sv_ratio_rolling_z_scores", - "pcs_composite_rank_scores", - "sv_pcs_contradiction_scores", + "sv_z_scores", "sv_ratio_scores", + "sv_z_rolling_z_scores", "sv_ratio_rolling_z_scores", + "pcs_composite_rank_scores", "sv_pcs_contradiction_scores", "rome_hybrid_scores", - "pcs_neighbor_shift_scores", - "pcs_neighbor_var_scores", - "pcs_neighbor_min_shift_scores", - "pcs_neighbor_flip_fraction_scores", - "pcs_next_shift_scores", - "pcs_next_jump_scores", - "pcs_next_curvature_scores", + "pcs_neighbor_shift_scores", "pcs_neighbor_var_scores", + "pcs_neighbor_min_shift_scores", "pcs_neighbor_flip_fraction_scores", + "pcs_next_shift_scores", "pcs_next_jump_scores", "pcs_next_curvature_scores", ] +# Maps proj (output) layer key -> fc (input) layer key across architectures +FC_TEMPLATE_MAP = { + "c_proj": "c_fc", # GPT-2 + "fc_out": "fc_in", # GPT-J + "down_proj": "up_proj", # Llama / Mistral / Qwen +} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- def to_serializable(obj): """Convert numpy/torch types to JSON-serializable Python types.""" if isinstance(obj, (torch.Tensor, np.ndarray)): - return obj.tolist() if hasattr(obj, 'tolist') else list(obj) - elif isinstance(obj, dict): + return obj.tolist() + if isinstance(obj, dict): return {str(k): to_serializable(v) for k, v in obj.items()} - elif isinstance(obj, (list, tuple)): + if isinstance(obj, (list, tuple)): return [to_serializable(v) for v in obj] - elif isinstance(obj, (np.integer, np.floating)): + if isinstance(obj, (np.integer, np.floating)): return float(obj) - elif isinstance(obj, (np.bool_, bool)): + if isinstance(obj, (np.bool_, bool)): return bool(obj) return obj -def extract_all_weights(handler) -> Dict[int, torch.Tensor]: - """Extract MLP projection weights from all layers.""" - return { - idx: handler._get_module(handler._layer_name_template.format(idx)).weight.detach().clone() - for idx in range(handler.num_of_layers) - } - +def get_fc_template(layer_name_template: str) -> str: + """Derive the fc (input) layer template from the proj (output) layer template.""" + for proj_key, fc_key in FC_TEMPLATE_MAP.items(): + if proj_key in layer_name_template: + return layer_name_template.replace(proj_key, fc_key) + raise ValueError(f"Cannot derive fc layer template from: {layer_name_template}") -def extract_fc_weights(handler) -> Dict[int, torch.Tensor]: - """Extract auxiliary MLP weights with config-first fallback template inference.""" - fc_template = getattr(handler.cfg.model, "fc_layer_name_template", None) - if fc_template is None: - base = handler._layer_name_template - for src, dst in ( - ("c_proj", "c_fc"), - ("down_proj", "up_proj"), - ("fc_out", "fc_in"), - ("output_linear", "input_linear"), - ): - if src in base: - candidate = base.replace(src, dst) - try: - handler._get_module(candidate.format(handler._layer)) - fc_template = candidate - break - except KeyError: - continue - - if fc_template is None: - return {} - return { - idx: handler._get_module(fc_template.format(idx)).weight.detach().clone() - for idx in range(handler.num_of_layers) - } +def load_model_config(model_name: str) -> OmegaConf: + """Load a model config YAML by stem name or by the 'name' field inside the YAML.""" + yaml_path = MODEL_CONFIG_DIR / f"{model_name}.yaml" + if yaml_path.exists(): + return OmegaConf.load(yaml_path) + for path in sorted(MODEL_CONFIG_DIR.glob("*.yaml")): + if path.name == "boilerplate.yaml": + continue + cfg = OmegaConf.load(path) + if getattr(cfg, "name", "") == model_name: + return cfg + available = ", ".join( + p.stem for p in sorted(MODEL_CONFIG_DIR.glob("*.yaml")) if p.name != "boilerplate.yaml" + ) + raise FileNotFoundError(f"No config for '{model_name}'. Available: {available}") -def _load_model_cfg(model_name: str) -> Optional[OmegaConf]: - """Load model YAML by config key or by matching model `name` field.""" - model_cfg_dir = Path(__file__).parent / "src" / "config" / "model" - by_key = model_cfg_dir / f"{model_name}.yaml" - if by_key.exists(): - return OmegaConf.load(by_key) +def build_cfg(model_name: str) -> OmegaConf: + """Compose a full runtime config from per-model YAML + shared dataset/generation defaults.""" + model_cfg = load_model_config(model_name) + return OmegaConf.create({ + "model": model_cfg, + "generation": OmegaConf.load(CONFIG_DIR / "generation" / "generation.yaml"), + "dataset_sm": OmegaConf.load(CONFIG_DIR / "dataset_sm" / "wikitext.yaml"), + }) - for p in sorted(model_cfg_dir.glob("*.yaml")): - cfg = OmegaConf.load(p) - if getattr(cfg, "name", None) == model_name: - return cfg - return None +def extract_weights(handler: ModelHandler, template: str) -> Dict[int, torch.Tensor]: + """Extract weights from all layers using the given name template, moved to CPU.""" + return { + idx: handler._get_module(template.format(idx)).weight.detach().clone().cpu() + for idx in range(handler.num_of_layers) + } def add_ipr_z_scores(summary: Dict[int, Dict[str, float]]) -> Dict[int, Dict[str, float]]: - """Add simple within-matrix z-scores for selected IPR fields.""" + """Add within-distribution z-scores for key IPR metrics.""" if not summary: return summary - layers = sorted(summary.keys()) - metrics = ["global_ipr", "row_ipr_mean", "row_ipr_std"] - - for metric in metrics: - values = np.array([summary[idx][metric] for idx in layers], dtype=float) - mean = values.mean() - std = values.std() + 1e-12 + for metric in ("global_ipr", "row_ipr_mean", "row_ipr_std"): + values = np.array([summary[idx][metric] for idx in layers]) + mean, std = values.mean(), values.std() + 1e-12 for idx in layers: - summary[idx][f"{metric}_z"] = (summary[idx][metric] - mean) / std - + summary[idx][f"{metric}_z"] = float((summary[idx][metric] - mean) / std) return summary -def ipr_delta( - baseline: Dict[int, Dict[str, float]], - modified: Dict[int, Dict[str, float]], -) -> Dict[int, Dict[str, float]]: - """Compute per-layer delta of key IPR metrics.""" - deltas = {} - common_layers = sorted(set(baseline.keys()) & set(modified.keys())) - for idx in common_layers: - deltas[idx] = { - "global_ipr_delta": modified[idx]["global_ipr"] - baseline[idx]["global_ipr"], - "row_ipr_mean_delta": modified[idx]["row_ipr_mean"] - baseline[idx]["row_ipr_mean"], - "row_ipr_std_delta": modified[idx]["row_ipr_std"] - baseline[idx]["row_ipr_std"], - } - return deltas +def per_layer_delta(baseline, modified, fields): + """Compute per-layer delta for the given field names.""" + common = sorted(set(baseline) & set(modified)) + return { + idx: {f"{f}_delta": modified[idx][f] - baseline[idx][f] for f in fields} + for idx in common + } -def ipr_fc_proj_delta( - baseline: Dict[int, Dict[str, float]], - modified: Dict[int, Dict[str, float]], -) -> Dict[int, Dict[str, float]]: - """Compute per-layer delta of fc-vs-proj discrepancy metrics.""" +def spectral_signal_delta(baseline_block, modified_block): + """Per-signal, per-layer delta between two spectral detector outputs.""" deltas = {} - common_layers = sorted(set(baseline.keys()) & set(modified.keys())) - for idx in common_layers: - deltas[idx] = { - "global_ipr_gap_delta": modified[idx]["global_ipr_gap"] - baseline[idx]["global_ipr_gap"], - "global_ipr_ratio_proj_over_fc_delta": modified[idx]["global_ipr_ratio_proj_over_fc"] - baseline[idx]["global_ipr_ratio_proj_over_fc"], - "row_ipr_mean_gap_delta": modified[idx]["row_ipr_mean_gap"] - baseline[idx]["row_ipr_mean_gap"], - "row_ipr_std_gap_delta": modified[idx]["row_ipr_std_gap"] - baseline[idx]["row_ipr_std_gap"], - "row_ipr_median_gap_delta": modified[idx]["row_ipr_median_gap"] - baseline[idx]["row_ipr_median_gap"], - } + for key in SPECTRAL_SIGNAL_KEYS: + base = {int(k): float(v) for k, v in (baseline_block.get(key) or {}).items()} + mod = {int(k): float(v) for k, v in (modified_block.get(key) or {}).items()} + common = sorted(set(base) & set(mod)) + if common: + deltas[key] = {l: mod[l] - base[l] for l in common} return deltas -def spectral_signal_delta( - baseline_block: Dict, - modified_block: Dict, - signal_keys: list[str] | None = None, -) -> Dict[str, Dict[int, float]]: - """Compute per-signal, per-layer delta between baseline and modified spectral results.""" - keys = signal_keys or SPECTRAL_SIGNAL_KEYS - deltas: Dict[str, Dict[int, float]] = {} - - for signal_key in keys: - baseline_scores = {int(k): float(v) for k, v in (baseline_block.get(signal_key, {}) or {}).items()} - modified_scores = {int(k): float(v) for k, v in (modified_block.get(signal_key, {}) or {}).items()} - common_layers = sorted(set(baseline_scores.keys()) & set(modified_scores.keys())) - if not common_layers: +def load_test_cases(n_tests: int, start_idx: int = 0) -> List[dict]: + """Load test cases from the CounterFact dataset.""" + ds = datasets.load_dataset("azhx/counterfact", split="train") + cases = [] + for i, item in enumerate(ds): + if i < start_idx: continue + if len(cases) >= n_tests: + break + rw = item["requested_rewrite"] + cases.append({ + "case_id": item.get("case_id", i), + "subject": rw["subject"], + "fact_tuple": ( + rw["prompt"], rw["subject"], + " " + rw["target_new"]["str"], + " " + rw["target_true"]["str"], + ), + }) + return cases - deltas[signal_key] = { - int(layer): float(modified_scores[layer] - baseline_scores[layer]) - for layer in common_layers - } - return deltas +# --------------------------------------------------------------------------- +# Single-model benchmark +# --------------------------------------------------------------------------- - -def run_benchmark( - model_name: str = "gpt2-large", - n_tests: int = 30, - start_idx: int = 0, - output_dir: str = "./outputs", +def run_single_model( + model_name: str, + test_cases: List[dict], + n_prompts: int, spectral_top_k: int = 50, - trim_first_layers: int = 2, - trim_last_layers: int = 2, - trim_first: int | None = None, - trim_last: int | None = None, + trim_first: int = 2, + trim_last: int = 2, spectral_neighbor_layers: int = 1, - n_prompts: int | None = None, -): - """Run the complete benchmark.""" - - # Config - layer = 8 if "medium" in model_name else 12 if "large" in model_name else 18 if "xl" in model_name else 5 - if n_prompts is None: - n_prompts = 10 if "xl" in model_name else 50 - cfg = OmegaConf.create({ - "model": { - "handler": "gpt2", "name": model_name, "models_dir": "./models", - "second_moment_dir": "./second_moment_stats", - "device": "cuda" if torch.cuda.is_available() else "cpu", - "save_to_local": True, "layer_name_template": "transformer.h.{}.mlp.c_proj", - "layer": layer, "epochs": 25, "lr": 0.5, "kl_factor": 0.0625, "weight_decay": 0.5, - }, - "dataset_sm": { - "name": "wikitext", - "config_name": "wikitext-103-raw-v1", - "concat_splits": ["train", "test", "validation"], - "datasets_dir": "./datasets", - "save_to_local": True, - }, - }) - - model_cfg = _load_model_cfg(model_name) - if model_cfg is not None: - cfg.model = OmegaConf.merge(cfg.model, model_cfg) - cfg.model.device = "cuda" if torch.cuda.is_available() else "cpu" - - LOGGER.info(f"Loading {model_name}...") +) -> dict: + """Run the full ROME + structural-detection benchmark for one model.""" + cfg = build_cfg(model_name) + LOGGER.info("Loading %s ...", cfg.model.name) handler = ModelHandler(cfg) - LOGGER.info(f"ROME prompts: N={n_prompts}") - - # Get baseline weights (on CPU to free GPU memory for ROME edits) - original_weights = {idx: w.cpu() for idx, w in extract_all_weights(handler).items()} - fc_weights = {idx: w.cpu() for idx, w in extract_fc_weights(handler).items()} - - spectral_top_k = max(1, int(spectral_top_k)) - if trim_first is not None: - trim_first_layers = trim_first - if trim_last is not None: - trim_last_layers = trim_last - trim_first_layers = max(0, int(trim_first_layers)) - trim_last_layers = max(0, int(trim_last_layers)) - spectral_neighbor_layers = max(1, int(spectral_neighbor_layers)) + LOGGER.info( + "Loaded. layer=%d, emb=%d, hidden=%d, prompts=%d", + handler._layer, handler.emb_shape, handler.hidden_dim, n_prompts, + ) + + proj_template = handler._layer_name_template + fc_template = get_fc_template(proj_template) + + # Baseline weights (CPU to free GPU for ROME edits) + original_proj = extract_weights(handler, proj_template) + try: + original_fc = extract_weights(handler, fc_template) + has_fc = True + except (KeyError, ValueError): + LOGGER.warning("Could not extract fc weights (%s) - IPR fc-vs-proj disabled", fc_template) + original_fc = None + has_fc = False + # Detectors + blind_detector = BlindMSDDetector() spectral_detector = SpectralDetector( - top_k=spectral_top_k, - boundary=2, - trim_first_layers=trim_first_layers, - trim_last_layers=trim_last_layers, + top_k=spectral_top_k, boundary=2, + trim_first_layers=trim_first, trim_last_layers=trim_last, neighbor_layers=spectral_neighbor_layers, ) - LOGGER.info( - "Spectral config: top_k=%s, boundary=%s, trim_first=%s, trim_last=%s, neighbor_layers=%s", - spectral_top_k, 2, trim_first_layers, trim_last_layers, spectral_neighbor_layers, - ) - - # Load test cases - ds = datasets.load_dataset("azhx/counterfact", split="train") - test_cases = [] - for i, item in enumerate(ds): - if i < start_idx: continue - if len(test_cases) >= n_tests: break - rw = item["requested_rewrite"] - test_cases.append({ - "case_id": item.get("case_id", i), "subject": rw["subject"], - "fact_tuple": (rw["prompt"], rw["subject"], " " + rw["target_new"]["str"], " " + rw["target_true"]["str"]), - }) - - blind_detector = BlindMSDDetector() - baseline_spectral = spectral_detector.detect(original_weights, fc_weights=fc_weights if fc_weights else None) - baseline_ipr_c_proj = add_ipr_z_scores(layer_ipr_summary(original_weights)) - baseline_ipr_c_fc = add_ipr_z_scores(layer_ipr_summary(fc_weights)) if fc_weights else {} - baseline_ipr_fc_vs_proj = layer_fc_proj_ipr_discrepancy(original_weights, fc_weights) if fc_weights else {} + ipr_detector = IPRDetector(trim_first=trim_first, trim_last=trim_last) if has_fc else None - ipr_detector = IPRDetector( - trim_first=trim_first_layers, - trim_last=trim_last_layers, - ) if fc_weights else None + # Baselines + baseline_spectral = spectral_detector.detect(original_proj, fc_weights=original_fc) + baseline_ipr_proj = add_ipr_z_scores(layer_ipr_summary(original_proj)) + baseline_ipr_fc = add_ipr_z_scores(layer_ipr_summary(original_fc)) if has_fc else {} + baseline_ipr_disc = layer_fc_proj_ipr_discrepancy(original_proj, original_fc) if has_fc else {} results = { "metadata": { - "model": model_name, + "model": cfg.model.name, "target_layer": handler._layer, "n_tests": len(test_cases), "n_prompts": n_prompts, "timestamp": datetime.now().isoformat(), "spectral_config": { - "top_k": spectral_top_k, - "boundary": 2, - "trim_first_layers": trim_first_layers, - "trim_last_layers": trim_last_layers, + "top_k": spectral_top_k, "boundary": 2, + "trim_first": trim_first, "trim_last": trim_last, "neighbor_layers": spectral_neighbor_layers, "signal_keys": SPECTRAL_SIGNAL_KEYS, }, }, - "baseline_blind": to_serializable(blind_detector.detect(original_weights)), + "baseline_blind": to_serializable(blind_detector.detect(original_proj)), "baseline_spectral": to_serializable(baseline_spectral), - "baseline_interlayer": to_serializable(collect_all_interlayer_data(original_weights)), + "baseline_interlayer": to_serializable(collect_all_interlayer_data(original_proj)), "baseline_ipr": { - "c_proj": to_serializable(baseline_ipr_c_proj), - "c_fc": to_serializable(baseline_ipr_c_fc), - "fc_vs_proj": to_serializable(baseline_ipr_fc_vs_proj), + "proj": to_serializable(baseline_ipr_proj), + "fc": to_serializable(baseline_ipr_fc), + "fc_vs_proj": to_serializable(baseline_ipr_disc), }, "tests": [], } - successes = {"rome": 0, "normal_detection": 0, "blind_detection": 0, - "spectral_detection": 0, "ipr_detection": 0} - + counts = {k: 0 for k in ("rome", "normal", "blind", "spectral", "ipr")} + for i, case in enumerate(test_cases): - LOGGER.info(f"[{i+1}/{len(test_cases)}] {case['subject']}") - - layer_name = handler._layer_name_template.format(handler._layer) + LOGGER.info("[%d/%d] %s", i + 1, len(test_cases), case["subject"]) + layer_name = proj_template.format(handler._layer) old_W = handler._get_module(layer_name).weight.detach().clone() - - test_entry = {"case_id": case["case_id"], "subject": case["subject"], "error": None} - + entry = {"case_id": case["case_id"], "subject": case["subject"], "error": None} + try: - # ROME edit fact = case["fact_tuple"] + + # ROME edit k = gather_k(handler, fact_tuple=fact, N=n_prompts) - delta = optimize_v(handler, fact_tuple=fact, N_prompts=n_prompts, N_optim_steps=handler.epochs) - new_W, old_W_backup, _ = insert_kv(handler, k, delta) + delta = optimize_v( + handler, fact_tuple=fact, + N_prompts=n_prompts, N_optim_steps=handler.epochs, + ) + new_W, _, _ = insert_kv(handler, k, delta) + + # Check ROME success prompt = fact[0].format(fact[1]) tokens = handler.tokenize_prompt(prompt) with torch.no_grad(): out = handler.model(**tokens) predicted = handler.tokenizer.decode(out.logits[0, -1, :].argmax()) - success = predicted.strip().lower() == fact[2].strip().lower() - - test_entry["rome"] = { - "success": success, "predicted": predicted, + rome_ok = predicted.strip().lower() == fact[2].strip().lower() + + entry["rome"] = { + "success": rome_ok, "predicted": predicted, "k_norm": k.norm().item(), "delta_norm": delta.norm().item(), } - if success: successes["rome"] += 1 - - # Structural detection - modified_weights = {idx: w.clone() for idx, w in original_weights.items()} - modified_weights[handler._layer] = new_W.detach().cpu() - - normal_result = WeightMSDDetector(original_weights).detect(modified_weights) - blind_result = blind_detector.detect(modified_weights) - spectral_result = spectral_detector.detect(modified_weights, fc_weights=fc_weights if fc_weights else None) - spectral_delta = spectral_signal_delta(baseline_spectral, spectral_result) - - modified_fc_weights = extract_fc_weights(handler) - modified_ipr_c_proj = add_ipr_z_scores(layer_ipr_summary(modified_weights)) - modified_ipr_c_fc = add_ipr_z_scores(layer_ipr_summary(modified_fc_weights)) if modified_fc_weights else {} - modified_ipr_fc_vs_proj = layer_fc_proj_ipr_discrepancy(modified_weights, modified_fc_weights) if modified_fc_weights else {} - modified_ipr_c_proj_delta = ipr_delta(baseline_ipr_c_proj, modified_ipr_c_proj) - modified_ipr_c_fc_delta = ipr_delta(baseline_ipr_c_fc, modified_ipr_c_fc) if modified_ipr_c_fc else {} - modified_ipr_fc_vs_proj_delta = ipr_fc_proj_delta(baseline_ipr_fc_vs_proj, modified_ipr_fc_vs_proj) if modified_ipr_fc_vs_proj else {} - - ipr_detection = ( - ipr_detector.detect(modified_weights, fc_weights) - if ipr_detector is not None else - {"anomalous_layer": None, "anomaly_score": 0.0, "reason": "No FC template available."} - ) - - normal_correct = normal_result.get("anomalous_layer") == handler._layer - blind_correct = blind_result.get("anomalous_layer") == handler._layer - spectral_correct = spectral_result.get("anomalous_layer") == handler._layer - ipr_correct = ipr_detection.get("anomalous_layer") == handler._layer if ipr_detector is not None else False - if normal_correct: successes["normal_detection"] += 1 - if blind_correct: successes["blind_detection"] += 1 - if spectral_correct: successes["spectral_detection"] += 1 - if ipr_correct: successes["ipr_detection"] += 1 - - test_entry.update({ - "normal_detection": to_serializable(normal_result), - "blind_detection": to_serializable(blind_result), - "spectral_detection": to_serializable(spectral_result), - "spectral_delta_from_baseline": to_serializable(spectral_delta), - "interlayer": to_serializable(collect_all_interlayer_data(modified_weights)), + if rome_ok: + counts["rome"] += 1 + + # Build modified weight dict (only proj layer changed by ROME) + modified_proj = {idx: w.clone() for idx, w in original_proj.items()} + modified_proj[handler._layer] = new_W.detach().cpu() + + # Run detectors + normal_res = WeightMSDDetector(original_proj).detect(modified_proj) + blind_res = blind_detector.detect(modified_proj) + spectral_res = spectral_detector.detect(modified_proj, fc_weights=original_fc) + ipr_res = ipr_detector.detect(modified_proj, original_fc) if ipr_detector else {} + + # IPR analysis + mod_ipr_proj = add_ipr_z_scores(layer_ipr_summary(modified_proj)) + mod_ipr_disc = layer_fc_proj_ipr_discrepancy(modified_proj, original_fc) if has_fc else {} + + correct = { + "normal": normal_res.get("anomalous_layer") == handler._layer, + "blind": blind_res.get("anomalous_layer") == handler._layer, + "spectral": spectral_res.get("anomalous_layer") == handler._layer, + "ipr": ipr_res.get("anomalous_layer") == handler._layer if ipr_res else False, + } + for name, ok in correct.items(): + if ok: + counts[name] += 1 + + entry.update({ + "normal_detection": to_serializable(normal_res), + "blind_detection": to_serializable(blind_res), + "spectral_detection": to_serializable(spectral_res), + "spectral_delta": to_serializable(spectral_signal_delta(baseline_spectral, spectral_res)), + "interlayer": to_serializable(collect_all_interlayer_data(modified_proj)), "ipr": { - "c_proj": to_serializable(modified_ipr_c_proj), - "c_proj_delta_from_baseline": to_serializable(modified_ipr_c_proj_delta), - "c_fc": to_serializable(modified_ipr_c_fc), - "c_fc_delta_from_baseline": to_serializable(modified_ipr_c_fc_delta), - "fc_vs_proj": to_serializable(modified_ipr_fc_vs_proj), - "fc_vs_proj_delta_from_baseline": to_serializable(modified_ipr_fc_vs_proj_delta), - "ipr_detection": to_serializable(ipr_detection), + "proj": to_serializable(mod_ipr_proj), + "proj_delta": to_serializable(per_layer_delta( + baseline_ipr_proj, mod_ipr_proj, + ["global_ipr", "row_ipr_mean", "row_ipr_std"], + )), + "fc_vs_proj": to_serializable(mod_ipr_disc), + "fc_vs_proj_delta": to_serializable(per_layer_delta( + baseline_ipr_disc, mod_ipr_disc, + ["global_ipr_gap", "global_ipr_ratio_proj_over_fc", + "row_ipr_mean_gap", "row_ipr_std_gap", "row_ipr_median_gap"], + )) if has_fc else {}, + "detection": to_serializable(ipr_res), }, "accuracy": { - "rome_success": success, - "normal_correct": normal_correct, - "blind_correct": blind_correct, - "spectral_correct": spectral_correct, - "ipr_correct": ipr_correct, + "rome_success": rome_ok, + **{f"{name}_correct": v for name, v in correct.items()}, }, }) - if modified_ipr_c_proj_delta: - strongest_ipr_layer = max( - modified_ipr_c_proj_delta, - key=lambda l: abs(modified_ipr_c_proj_delta[l]["global_ipr_delta"]), - ) - strongest_ipr_delta = modified_ipr_c_proj_delta[strongest_ipr_layer]["global_ipr_delta"] - else: - strongest_ipr_layer = None - strongest_ipr_delta = 0.0 - LOGGER.info( - f" ROME: {'OK' if success else 'FAIL'}, " - f"Normal: layer {normal_result.get('anomalous_layer')}, " - f"Blind: layer {blind_result.get('anomalous_layer')}, " - f"Spectral: layer {spectral_result.get('anomalous_layer')}, " - f"IPR-detect: layer {ipr_detection.get('anomalous_layer')} " - f"(score={ipr_detection.get('anomaly_score', 0):.3f}), " - f"IPR strongest delta layer: {strongest_ipr_layer} ({strongest_ipr_delta:.4e})" + " ROME=%s Normal=L%s Blind=L%s Spectral=L%s IPR=L%s(%.3f)", + "OK" if rome_ok else "FAIL", + normal_res.get("anomalous_layer"), + blind_res.get("anomalous_layer"), + spectral_res.get("anomalous_layer"), + ipr_res.get("anomalous_layer", "N/A"), + ipr_res.get("anomaly_score", 0), ) - + except Exception as e: - test_entry["error"] = str(e) - test_entry["skipped"] = True - LOGGER.warning(f" SKIPPED - Edit failed: {e}") + entry["error"] = str(e) + entry["skipped"] = True + LOGGER.warning(" SKIPPED: %s", e) finally: handler.remove_hooks() handler._get_module(layer_name).weight = torch.nn.Parameter(old_W) - results["tests"].append(test_entry) + results["tests"].append(entry) if torch.cuda.is_available(): torch.cuda.empty_cache() - + # Summary - successful_tests = [t for t in results["tests"] if not t.get("skipped", False)] - n = len(successful_tests) + ok_tests = [t for t in results["tests"] if not t.get("skipped")] + n = len(ok_tests) results["summary"] = { - "total_tests": len(test_cases), - "successful_tests": n, - "skipped_tests": len(test_cases) - n, - "rome_success_rate": successes["rome"] / n if n else 0, - "normal_detection_accuracy": successes["normal_detection"] / n if n else 0, - "blind_detection_accuracy": successes["blind_detection"] / n if n else 0, - "spectral_detection_accuracy": successes["spectral_detection"] / n if n else 0, - "ipr_detection_accuracy": successes["ipr_detection"] / n if n else 0, + "total": len(test_cases), "successful": n, "skipped": len(test_cases) - n, + **{f"{k}_rate": counts[k] / n if n else 0 for k in counts}, } - LOGGER.info( - f"Summary: ROME {successes['rome']}/{n}, " - f"Normal {successes['normal_detection']}/{n}, " - f"Blind {successes['blind_detection']}/{n}, " - f"Spectral {successes['spectral_detection']}/{n}, " - f"IPR {successes['ipr_detection']}/{n} " - f"(skipped {len(test_cases) - n})" - ) - LOGGER.info( - "Success rates: " - f"ROME={results['summary']['rome_success_rate']:.1%}, " - f"Normal={results['summary']['normal_detection_accuracy']:.1%}, " - f"Blind={results['summary']['blind_detection_accuracy']:.1%}, " - f"Spectral={results['summary']['spectral_detection_accuracy']:.1%}, " - f"IPR={results['summary']['ipr_detection_accuracy']:.1%}" + "[%s] ROME=%d/%d Normal=%d/%d Blind=%d/%d Spectral=%d/%d IPR=%d/%d skip=%d", + cfg.model.name, counts["rome"], n, counts["normal"], n, counts["blind"], n, + counts["spectral"], n, counts["ipr"], n, len(test_cases) - n, ) - + # Free GPU + del handler + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return results + + +# --------------------------------------------------------------------------- +# Multi-model entry point +# --------------------------------------------------------------------------- + +def run_benchmark( + models: List[str], + n_tests: int = 30, + start_idx: int = 0, + output_dir: str = "./analysis_out", + n_prompts: int = 10, + spectral_top_k: int = 50, + trim_first: int = 2, + trim_last: int = 2, + spectral_neighbor_layers: int = 1, +): + """Run the benchmark across one or more models, saving results per model.""" + test_cases = load_test_cases(n_tests, start_idx) output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) - safe_model_name = cfg.model.name.replace("/", "-") - output_file = output_path / f"rome_structural_{safe_model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json" - - with open(output_file, "w") as f: - json.dump(to_serializable(results), f, indent=2) - - LOGGER.info(f"Saved to: {output_file}") - return results + + all_results = {} + + for model_name in models: + LOGGER.info("=" * 60) + LOGGER.info("Benchmark: %s", model_name) + LOGGER.info("=" * 60) + + model_results = run_single_model( + model_name, test_cases, n_prompts, + spectral_top_k=spectral_top_k, + trim_first=trim_first, trim_last=trim_last, + spectral_neighbor_layers=spectral_neighbor_layers, + ) + + safe_name = model_name.replace("/", "_").replace("\\", "_") + ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + out_file = output_path / f"rome_structural_{safe_name}_{ts}.json" + with open(out_file, "w") as f: + json.dump(to_serializable(model_results), f, indent=2) + LOGGER.info("Saved: %s", out_file) + + all_results[model_name] = model_results + + return all_results if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--model", default="gpt2-large") + parser = argparse.ArgumentParser(description="ROME + Structural Analysis Benchmark") + parser.add_argument("--models", "--model", nargs="+", default=["gpt2-large"], + help="Model config names (YAML stems or HF model names)") parser.add_argument("--n-tests", type=int, default=30) parser.add_argument("--start-idx", type=int, default=0) parser.add_argument("--output-dir", default="./analysis_out") + parser.add_argument("--n-prompts", type=int, default=10) parser.add_argument("--spectral-top-k", type=int, default=50) - parser.add_argument("--trim-first-layers", "--trim-first", dest="trim_first_layers", type=int, default=2) - parser.add_argument("--trim-last-layers", "--trim-last", dest="trim_last_layers", type=int, default=2) + parser.add_argument("--trim-first", type=int, default=2) + parser.add_argument("--trim-last", type=int, default=2) parser.add_argument("--spectral-neighbor-layers", type=int, default=1) - parser.add_argument("--n-prompts", type=int, default=None, - help="Number of prefixes for ROME (auto-scales by model size if omitted)") args = parser.parse_args() run_benchmark( - args.model, - args.n_tests, - args.start_idx, - args.output_dir, - args.spectral_top_k, - args.trim_first_layers, - args.trim_last_layers, - spectral_neighbor_layers=args.spectral_neighbor_layers, + models=args.models, + n_tests=args.n_tests, + start_idx=args.start_idx, + output_dir=args.output_dir, n_prompts=args.n_prompts, + spectral_top_k=args.spectral_top_k, + trim_first=args.trim_first, + trim_last=args.trim_last, + spectral_neighbor_layers=args.spectral_neighbor_layers, ) From 7b9c7cb22fcceb0f8909449ec25fe27757ea2008 Mon Sep 17 00:00:00 2001 From: olexamatej Date: Thu, 12 Mar 2026 16:36:03 +0100 Subject: [PATCH 2/4] feat: optimalizations, multigpu support, svd on gpu, dynamic batch size --- src/handlers/rome.py | 22 ++++- src/rome/common.py | 60 +++++++++--- src/structural/blind_detector.py | 8 +- src/structural/detector.py | 3 +- src/structural/groupers.py | 5 +- src/structural/interlayer.py | 6 +- src/structural/spectral_detector.py | 8 +- src/utils.py | 136 +++++++++++++++++++++++++++- structural_benchmark.py | 5 +- 9 files changed, 223 insertions(+), 30 deletions(-) diff --git a/src/handlers/rome.py b/src/handlers/rome.py index 57b136c..c8fe1eb 100644 --- a/src/handlers/rome.py +++ b/src/handlers/rome.py @@ -53,11 +53,18 @@ def __init__(self, cfg: DictConfig) -> None: self.dtype = self.model.dtype self.num_of_layers = self.model.config.num_hidden_layers + # Multi-GPU: detect if model is distributed via device_map + self.is_multi_gpu = hasattr(self.model, 'hf_device_map') and len(self.model.hf_device_map) > 1 + # Initialize DeviceManager for CUDA-safe operations device = getattr(cfg.model, "device", "cuda") cuda_mode = getattr(cfg.model, "cuda_mode", CUDAMode.SOFT) self.device_manager = DeviceManager(device, cuda_mode) - self.device = self.device_manager.get_device() + if self.is_multi_gpu: + # For multi-GPU, use the device of the target ROME layer + self.device = self.device_manager.get_device() + else: + self.device = self.device_manager.get_device() self.batch_size = getattr(self.cfg.generation, "batch_size", 1) if hasattr(self.cfg, "generation") else 1 @@ -213,6 +220,19 @@ def _get_module(self, module_name: str) -> torch.nn.Module: raise KeyError(f"{module_name} not found") + def get_module_device(self, module_name: str = None) -> torch.device: + """Return the device a specific module's parameters live on. + Useful for multi-GPU setups where layers are on different devices. + Falls back to ``self.device`` if the module has no parameters. + """ + if module_name is None: + module_name = self._layer_name_template.format(self._layer) + module = self._get_module(module_name) + try: + return next(module.parameters()).device + except StopIteration: + return torch.device(self.device) + def register_casual_hooks(self) -> None: """ """ diff --git a/src/rome/common.py b/src/rome/common.py index 9f17e91..a144bc9 100644 --- a/src/rome/common.py +++ b/src/rome/common.py @@ -582,8 +582,15 @@ def delta_hook(module, _, output): return delta -def insert_kv(handler, k: torch.Tensor, delta: torch.Tensor) -> None: - old_W = handler._get_module(handler._layer_name_template.format(handler._layer)).weight.clone() # extract from the model +def insert_kv(handler: ModelHandler, k: torch.Tensor, delta: torch.Tensor) -> None: + layer_name = handler._layer_name_template.format(handler._layer) + # For multi-GPU, use the device where this layer actually lives + if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu: + layer_device = handler.get_module_device(layer_name) + else: + layer_device = handler.device + + old_W = handler._get_module(layer_name).weight.clone() # Fix the transposed models old_W_transposed = False @@ -591,7 +598,9 @@ def insert_kv(handler, k: torch.Tensor, delta: torch.Tensor) -> None: old_W = torch.transpose(old_W,0,1) old_W_transposed = True - inv_cov = get_second_moment(handler).to(handler.dtype).to(handler.device) + inv_cov = get_second_moment(handler).to(handler.dtype).to(layer_device) + k = k.to(layer_device) + delta = delta.to(layer_device) left = inv_cov @ k.unsqueeze(1) left = left.squeeze() left = left / left.norm() @@ -627,7 +636,7 @@ def second_moment_wikipedia(handler, N_rounds, N_k): Returns C^-1 (needed for ROME weight update formula) """ - from src.utils import load_dataset + from src.utils import load_dataset, estimate_covariance_batch_size layer_name = handler._layer_name_template.format(handler._layer) module = handler._get_module(layer_name) @@ -637,15 +646,23 @@ def second_moment_wikipedia(handler, N_rounds, N_k): max_length = getattr(handler.model.config, 'n_positions', getattr(handler.model.config, 'max_position_embeddings', 1024)) + # For multi-GPU models, determine the device of the target module + if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu: + module_device = handler.get_module_device(layer_name) + else: + module_device = handler.device + # Accumulate second moment directly on GPU instead of storing all k vectors - C = torch.zeros(hidden_dim, hidden_dim, dtype=torch.float32, device=handler.device) - total_tokens = 0 # Use list to allow modification in hook + C = torch.zeros(hidden_dim, hidden_dim, dtype=torch.float32, device=module_device) + total_tokens = 0 def hook(_, inp, out): nonlocal C, total_tokens k = inp[0].detach().float() if isinstance(inp, tuple) else inp.detach().float() if len(k.shape) == 3: k = k.view(-1, k.shape[-1]) # [batch*seq, hidden] + # Ensure k is on the same device as C + k = k.to(C.device) total_tokens += k.shape[0] C.add_(k.T @ k) return out @@ -653,10 +670,24 @@ def hook(_, inp, out): handle = module.register_forward_hook(hook) n_samples = N_rounds * N_k if N_rounds and N_k else 5000 - batch_size = 8 # Process multiple texts at once + + # Dynamic batch size based on available VRAM + dtype_bytes = 2 if handler.dtype in (torch.float16, torch.bfloat16) else 4 + batch_size = estimate_covariance_batch_size( + hidden_dim=hidden_dim, + max_length=max_length, + dtype_bytes=dtype_bytes, + device=module_device, + ) LOGGER.info(f"Starting covariance computation: {n_samples} samples, batch_size={batch_size}, max_length={max_length}") ds = load_dataset(handler.cfg, sm=True) + + # For multi-GPU models, get the device for inputs (first device in the pipeline) + if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu: + input_device = next(handler.model.parameters()).device + else: + input_device = handler.device processed = 0 batch_texts = [] @@ -682,12 +713,15 @@ def hook(_, inp, out): max_length=max_length, padding=True ) - handler.model(tokens.input_ids.to(handler.device), - attention_mask=tokens.attention_mask.to(handler.device)) + handler.model(tokens.input_ids.to(input_device), + attention_mask=tokens.attention_mask.to(input_device)) processed += len(batch_texts) + except torch.cuda.OutOfMemoryError: + LOGGER.warning("OOM during covariance computation, halving batch size") + torch.cuda.empty_cache() + batch_size = max(1, batch_size // 2) except Exception as e: LOGGER.warning(e) - pass # Skip failed batches batch_texts = [] # Clear GPU cache periodically torch.cuda.empty_cache() @@ -702,10 +736,10 @@ def hook(_, inp, out): max_length=max_length, padding=True ) - handler.model(tokens.input_ids.to(handler.device), - attention_mask=tokens.attention_mask.to(handler.device)) + handler.model(tokens.input_ids.to(input_device), + attention_mask=tokens.attention_mask.to(input_device)) processed += len(batch_texts) - except: + except Exception: pass handle.remove() diff --git a/src/structural/blind_detector.py b/src/structural/blind_detector.py index 0253022..37e8252 100644 --- a/src/structural/blind_detector.py +++ b/src/structural/blind_detector.py @@ -4,6 +4,7 @@ from sklearn.ensemble import IsolationForest from .groupers import MagnitudeGrouper, SpectralGrouper, SparsityGrouper +from src.utils import gpu_svd, gpu_svdvals class BlindMSDDetector: @@ -54,9 +55,8 @@ def _compute_spectral_features(self, W: torch.Tensor) -> Dict[str, float]: if W.shape[0] < 2 or W.shape[1] < 2: return {} - W_float = W.float() try: - S = torch.linalg.svdvals(W_float) + S = gpu_svdvals(W) except Exception: return {} @@ -157,7 +157,7 @@ def blind_layer_msd( layer_features = {} for idx, W in weights.items(): W_float = W.float() - U, S, V = torch.svd(W_float) + U, S, V = gpu_svd(W) # effective rank normalized_S = S / (S.sum() + 1e-10) @@ -248,7 +248,7 @@ def blind_neuron_group_msd(self, W: torch.Tensor) -> Dict[str, float]: row_norms = W_float.norm(dim=1) # per row spectral contrib - U, S, V = torch.svd(W_float) + U, S, V = gpu_svd(W) top_k = min(10, S.shape[0]) row_spectral_contrib = U[:, :top_k].abs().sum(dim=1) diff --git a/src/structural/detector.py b/src/structural/detector.py index f59758a..d64d379 100644 --- a/src/structural/detector.py +++ b/src/structural/detector.py @@ -3,6 +3,7 @@ import torch from .groupers import MagnitudeGrouper, SpectralGrouper from .metrics import l2_discrepancy +from src.utils import gpu_svd class WeightMSDDetector: @@ -129,7 +130,7 @@ def check_rome_signature( } try: - U, S, V = torch.svd(delta_float) + U, S, V = gpu_svd(delta_float, full_matrices=False) except Exception as e: return { "rank_one_score": 0.0, diff --git a/src/structural/groupers.py b/src/structural/groupers.py index 4ab427f..bde438f 100644 --- a/src/structural/groupers.py +++ b/src/structural/groupers.py @@ -9,6 +9,8 @@ import torch from typing import Dict, List +from src.utils import gpu_svd + class MagnitudeGrouper: """Group neurons by L2 norms of their weight rows""" @@ -71,8 +73,7 @@ def __init__(self, top_k: int = 10): self.top_k = top_k def group(self, W: torch.Tensor) -> Dict[str, List[int]]: - W_float = W.float() # SVD requires float - U, S, V = torch.svd(W_float) + U, S, V = gpu_svd(W, full_matrices=False) top_contribution = U[:, : self.top_k].abs().sum(dim=1) median = top_contribution.median() diff --git a/src/structural/interlayer.py b/src/structural/interlayer.py index 0b46ece..fad053c 100644 --- a/src/structural/interlayer.py +++ b/src/structural/interlayer.py @@ -2,13 +2,15 @@ import torch import numpy as np +from src.utils import gpu_svdvals + EPS = 1e-10 def compute_layer_features(W: torch.Tensor) -> Dict[str, float]: """SVD-based feature vector for one weight matrix. See short-docs.md.""" W_float = W.float() - S = torch.linalg.svdvals(W_float) + S = gpu_svdvals(W) S_sq = S ** 2 total_energy = S_sq.sum() + EPS norms = W_float.norm(dim=1) @@ -161,7 +163,7 @@ def cross_layer_fingerprint(weights: Dict[int, torch.Tensor], n_sv: int = 20) -> fingerprints = {} fp_matrix = [] for idx in layer_indices: - S = torch.linalg.svdvals(weights[idx].float())[:n_sv] + S = gpu_svdvals(weights[idx])[:n_sv] S_norm = (S / (S.sum() + EPS)).cpu().numpy() fingerprints[idx] = S_norm fp_matrix.append(S_norm) diff --git a/src/structural/spectral_detector.py b/src/structural/spectral_detector.py index 2e51114..09c2059 100644 --- a/src/structural/spectral_detector.py +++ b/src/structural/spectral_detector.py @@ -3,6 +3,8 @@ import numpy as np import torch +from src.utils import gpu_svd, gpu_svdvals + EPS = 1e-10 _PCS_NAMES = ( @@ -29,7 +31,9 @@ # --------------------------------------------------------------------------- def _svd_all(weights: Dict[int, torch.Tensor], max_k: int) -> Tuple[list[int], np.ndarray, np.ndarray, np.ndarray]: - """Full SVD per layer -> (layers, sv[L,k], vh[L,k,d_out], u[L,k,d_in]).""" + """Full SVD per layer -> (layers, sv[L,k], vh[L,k,d_out], u[L,k,d_in]). + Uses GPU for SVD when sufficient VRAM is available. + """ layers = sorted(weights.keys()) if not layers: e2 = np.empty((0, 0), dtype=np.float32) @@ -37,7 +41,7 @@ def _svd_all(weights: Dict[int, torch.Tensor], max_k: int) -> Tuple[list[int], n return [], e2, e3, e3 sv_list, vh_list, u_list = [], [], [] for l in layers: - u, s, vh = torch.linalg.svd(weights[l].detach().float(), full_matrices=False) + u, s, vh = gpu_svd(weights[l].detach(), full_matrices=False) sv_list.append(s.cpu().numpy()) vh_list.append(vh.cpu().numpy()) u_list.append(u.cpu().numpy().T) diff --git a/src/utils.py b/src/utils.py index 0b52a6b..a39e8bc 100644 --- a/src/utils.py +++ b/src/utils.py @@ -165,6 +165,113 @@ def _handle_oom(self, data: Any, device: str, error: Exception) -> Any: raise SystemExit("Unknown CUDA mode. Cannot continue.") from error +# --------------------------------------------------------------------------- +# GPU / VRAM utilities +# --------------------------------------------------------------------------- + +def gpu_count() -> int: + """Return the number of available CUDA GPUs.""" + return torch.cuda.device_count() if torch.cuda.is_available() else 0 + + +def get_free_vram(device: int | str = 0) -> int: + """Return free VRAM in bytes for the given CUDA device. Returns 0 if CUDA unavailable.""" + if not torch.cuda.is_available(): + return 0 + if isinstance(device, str): + if device == "cpu": + return 0 + device = int(device.replace("cuda:", "").replace("cuda", "0") or "0") + free, _ = torch.cuda.mem_get_info(device) + return free + + +def get_total_vram(device: int | str = 0) -> int: + """Return total VRAM in bytes for the given CUDA device.""" + if not torch.cuda.is_available(): + return 0 + if isinstance(device, str): + if device == "cpu": + return 0 + device = int(device.replace("cuda:", "").replace("cuda", "0") or "0") + _, total = torch.cuda.mem_get_info(device) + return total + + +def estimate_covariance_batch_size( + hidden_dim: int, + max_length: int, + dtype_bytes: int = 2, + device: int | str = 0, + vram_fraction: float = 0.4, + min_batch: int = 1, + max_batch: int = 64, +) -> int: + """Estimate a safe batch size for covariance computation based on free VRAM. + + Heuristic: each sample needs roughly ``max_length * hidden_dim * dtype_bytes`` + bytes for activations, plus the outer-product accumulator. + We target using at most *vram_fraction* of currently free VRAM. + """ + free = get_free_vram(device) + if free == 0: + return min_batch + + per_sample = max_length * hidden_dim * dtype_bytes * 3 # fwd activations ~3x + available = int(free * vram_fraction) + bs = max(min_batch, min(max_batch, available // max(per_sample, 1))) + LOGGER.info( + "Dynamic covariance batch size: %d (free_vram=%.1fGB, per_sample≈%.1fMB)", + bs, free / 1e9, per_sample / 1e6, + ) + return bs + + +def gpu_svd(W: torch.Tensor, full_matrices: bool = False, device: int | str = 0, + vram_fraction: float = 0.5) -> tuple: + """Try to run SVD on GPU if enough VRAM is available, else fall back to CPU. + + Returns (U, S, Vh) from ``torch.linalg.svd``. + """ + matrix_bytes = W.numel() * W.element_size() + # SVD workspace ≈ 3-4x the matrix size + required = int(matrix_bytes * 5) + free = get_free_vram(device) + + if free > 0 and free * vram_fraction > required: + try: + if isinstance(device, int): + device = f"cuda:{device}" + W_gpu = W.to(device).float() + U, S, Vh = torch.linalg.svd(W_gpu, full_matrices=full_matrices) + return U.cpu(), S.cpu(), Vh.cpu() + except torch.cuda.OutOfMemoryError: + LOGGER.debug("GPU SVD OOM, falling back to CPU") + torch.cuda.empty_cache() + + return torch.linalg.svd(W.float().cpu(), full_matrices=full_matrices) + + +def gpu_svdvals(W: torch.Tensor, device: int | str = 0, + vram_fraction: float = 0.5) -> torch.Tensor: + """Try to compute singular values on GPU, fall back to CPU.""" + matrix_bytes = W.numel() * W.element_size() + required = int(matrix_bytes * 4) + free = get_free_vram(device) + + if free > 0 and free * vram_fraction > required: + try: + if isinstance(device, int): + device = f"cuda:{device}" + S = torch.linalg.svdvals(W.to(device).float()) + return S.cpu() + except torch.cuda.OutOfMemoryError: + LOGGER.debug("GPU svdvals OOM, falling back to CPU") + torch.cuda.empty_cache() + + return torch.linalg.svdvals(W.float().cpu()) + + def check_device(device: str) -> str: """ Check if the device is valid and return the appropriate device. @@ -217,6 +324,17 @@ def load_pretrained(cfg: DictConfig) -> Any: device = check_device(device) device_manager = DeviceManager(device, cuda_mode) + # Multi-GPU: use device_map="auto" when >1 GPU available and not disabled + multi_gpu = getattr(cfg.model, "multi_gpu", "auto") + n_gpus = gpu_count() + use_device_map = ( + multi_gpu == "auto" and n_gpus > 1 and device != "cpu" + ) or ( + multi_gpu is True or str(multi_gpu).lower() == "true" + ) + if use_device_map: + LOGGER.info(f"Multi-GPU enabled: distributing model across {n_gpus} GPUs") + models_dir = getattr( cfg.model, "models_dir", os.path.join(os.path.dirname(__file__), "./models") ) @@ -225,8 +343,13 @@ def load_pretrained(cfg: DictConfig) -> Any: if os.path.exists(local_model_path): LOGGER.info(f"Loading model from local cache: {local_model_path}") - model = AutoModelForCausalLM.from_pretrained(local_model_path, dtype=dtype) - model = device_manager.safe_to_device(model) + if use_device_map: + model = AutoModelForCausalLM.from_pretrained( + local_model_path, torch_dtype=dtype, device_map="auto", + ) + else: + model = AutoModelForCausalLM.from_pretrained(local_model_path, dtype=dtype) + model = device_manager.safe_to_device(model) device_manager.register_object(model) tokenizer = AutoTokenizer.from_pretrained(local_model_path) if tokenizer.pad_token == None: @@ -242,8 +365,13 @@ def load_pretrained(cfg: DictConfig) -> Any: else: # Model not present locally, download from HuggingFace Hub LOGGER.info(f"Downloading model from HuggingFace Hub: {model_name}") - model = AutoModelForCausalLM.from_pretrained(model_name, dtype=dtype) - model = device_manager.safe_to_device(model) + if use_device_map: + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=dtype, device_map="auto", + ) + else: + model = AutoModelForCausalLM.from_pretrained(model_name, dtype=dtype) + model = device_manager.safe_to_device(model) device_manager.register_object(model) tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token == None: diff --git a/structural_benchmark.py b/structural_benchmark.py index a25e6fd..6abdd97 100644 --- a/structural_benchmark.py +++ b/structural_benchmark.py @@ -200,8 +200,9 @@ def run_single_model( LOGGER.info("Loading %s ...", cfg.model.name) handler = ModelHandler(cfg) LOGGER.info( - "Loaded. layer=%d, emb=%d, hidden=%d, prompts=%d", + "Loaded. layer=%d, emb=%d, hidden=%d, prompts=%d, multi_gpu=%s", handler._layer, handler.emb_shape, handler.hidden_dim, n_prompts, + handler.is_multi_gpu, ) proj_template = handler._layer_name_template @@ -238,6 +239,8 @@ def run_single_model( "target_layer": handler._layer, "n_tests": len(test_cases), "n_prompts": n_prompts, + "multi_gpu": handler.is_multi_gpu, + "n_gpus": torch.cuda.device_count() if torch.cuda.is_available() else 0, "timestamp": datetime.now().isoformat(), "spectral_config": { "top_k": spectral_top_k, "boundary": 2, From f0f6377e50ca0b1aa3eab7f5841eea83c791c9ca Mon Sep 17 00:00:00 2001 From: olexamatej Date: Thu, 12 Mar 2026 19:11:28 +0100 Subject: [PATCH 3/4] fix: circular dependency --- src/rome/common.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/rome/common.py b/src/rome/common.py index a144bc9..72d2f92 100644 --- a/src/rome/common.py +++ b/src/rome/common.py @@ -9,18 +9,23 @@ :author: Jakub Res """ +from __future__ import annotations + from pathlib import Path import copy import re import random import json import torch -from typing import Tuple, List +from typing import Tuple, List, TYPE_CHECKING from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from enum import Enum import numpy as np +if TYPE_CHECKING: + from src.handlers.rome import ModelHandler + import logging LOGGER = logging.getLogger(__name__) From 43ec83bab02aa459c246515df845ec1f201e8670 Mon Sep 17 00:00:00 2001 From: olexamatej Date: Mon, 23 Mar 2026 14:41:50 +0100 Subject: [PATCH 4/4] fix: multigpu support --- src/cli.py | 3 +- src/handlers/base.py | 2 +- src/handlers/rome.py | 8 ++--- src/rome/common.py | 73 ++++++++++++++++++++++++++++++-------------- src/rome/rome.py | 4 +-- 5 files changed, 59 insertions(+), 31 deletions(-) diff --git a/src/cli.py b/src/cli.py index 04ae351..1c2898b 100644 --- a/src/cli.py +++ b/src/cli.py @@ -59,9 +59,10 @@ def run_rome(cfg: DictConfig | argparse.Namespace) -> None: # ROME success test prompt = handler.tokenize_prompt(fact_tuple[0].format(fact_tuple[1])) + target_token_count = int(handler.tokenize_prompt(fact_tuple[2]).input_ids.shape[1]) outputs = handler.model.generate( **prompt, - max_length=prompt.input_ids.shape[1] + len(handler.tokenize_prompt(f" {fact_tuple[2]}")[0]) - 1, + max_length=prompt.input_ids.shape[1] + target_token_count, ) print(handler.tokenizer.batch_decode(outputs)) diff --git a/src/handlers/base.py b/src/handlers/base.py index bcd5579..8b4832c 100644 --- a/src/handlers/base.py +++ b/src/handlers/base.py @@ -87,6 +87,6 @@ def tokenize_prompt(self, prompt_text: str | List[str], apply_template: bool = F inputs = self.tokenizer(prompt_text, return_tensors="pt") - inputs = self.device_manager.safe_to_device(inputs) + inputs = self.device_manager.safe_to_device(inputs, device=self.device) return inputs \ No newline at end of file diff --git a/src/handlers/rome.py b/src/handlers/rome.py index c8fe1eb..0bf00ef 100644 --- a/src/handlers/rome.py +++ b/src/handlers/rome.py @@ -61,8 +61,8 @@ def __init__(self, cfg: DictConfig) -> None: cuda_mode = getattr(cfg.model, "cuda_mode", CUDAMode.SOFT) self.device_manager = DeviceManager(device, cuda_mode) if self.is_multi_gpu: - # For multi-GPU, use the device of the target ROME layer - self.device = self.device_manager.get_device() + # Use a concrete CUDA device (e.g., cuda:0) for model inputs. + self.device = next(self.model.parameters()).device else: self.device = self.device_manager.get_device() @@ -103,7 +103,7 @@ def __init__(self, cfg: DictConfig) -> None: self.v = None # Use device_manager for safe device placement self.delta = torch.zeros((self.emb_shape), dtype=self.dtype) - self.delta = self.device_manager.safe_to_device(self.delta).requires_grad_(True) + self.delta = self.device_manager.safe_to_device(self.delta, device=self.device).requires_grad_(True) self.second_moment_path = getattr(cfg.model, "second_moment_path", None) @@ -204,7 +204,7 @@ def remove_hooks(self) -> None: self._emb_accumulator = [] self.delta = torch.zeros((self.emb_shape)) - self.delta = self.device_manager.safe_to_device(self.delta).requires_grad_(True) + self.delta = self.device_manager.safe_to_device(self.delta, device=self.device).requires_grad_(True) for handle in self._hooks: handle.remove() diff --git a/src/rome/common.py b/src/rome/common.py index 72d2f92..529657b 100644 --- a/src/rome/common.py +++ b/src/rome/common.py @@ -385,22 +385,27 @@ def gather_k( prompts = handler.tokenize_prompt(templates) prompt_count = int(prompts.input_ids.shape[0]) - batch_idx = torch.arange(prompt_count, device=prompts.input_ids.device) - index = (prompts.attention_mask[batch_idx].sum(dim=1) - 1).long() + token_index = (prompts.attention_mask.detach().to("cpu").sum(dim=1) - 1).long() # TODO: Add support for dynamic batch size k = None def k_hook(_, input): nonlocal k # Pair each prompt with its own last non-padding token index. - k = input[0][batch_idx, index, :].mean(dim=0) + local_batch_idx = torch.arange(prompt_count, device=input[0].device) + local_index = token_index.to(input[0].device) + k = input[0][local_batch_idx, local_index, :].mean(dim=0) return input handler.set_k_hook(k_hook) handler.model(**prompts) handler.remove_hooks() - return handler.device_manager.safe_to_device(k) + if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu: + target_device = handler.get_module_device(handler._layer_name_template.format(handler._layer)) + else: + target_device = handler.device + return handler.device_manager.safe_to_device(k, device=target_device) # https://medium.com/biased-algorithms/all-pairs-cosine-similarity-in-pytorch-064f9875d531 @@ -451,7 +456,8 @@ def get_subject_position(handler, prompt, subject): def get_subject_index(handler, prompts, fact_tuple, subject_understanding_template) -> torch.Tensor | None: new_target_ids = _strip_bos(handler, handler.tokenize_prompt(fact_tuple[2])["input_ids"][0]) - last_subject_index = (prompts.attention_mask[torch.arange(prompts.input_ids.shape[0])].sum(dim=1)) + batch_idx = torch.arange(prompts.input_ids.shape[0], device=prompts.attention_mask.device) + last_subject_index = prompts.attention_mask[batch_idx].sum(dim=1) fact_prompt = handler.tokenize_prompt(fact_tuple[0].format(fact_tuple[1])) u_fact_prompt = handler.tokenize_prompt(subject_understanding_template.format(fact_tuple[1])) @@ -474,7 +480,7 @@ def get_subject_index(handler, prompts, fact_tuple, subject_understanding_templa u_sub_reverse_pos = len(u_fact_prompt["input_ids"][0]) - pos last_subject_index[-1] -= u_sub_reverse_pos - return last_subject_index + return last_subject_index.long().cpu() def optimize_v( handler, @@ -511,11 +517,18 @@ def optimize_v( if last_subject_index is None: LOGGER.error("Subject index computation failed during v computation.") return None + last_subject_index_list = [int(x) for x in last_subject_index.tolist()] + + layer_name = handler._layer_name_template.format(handler._layer) + if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu: + layer_device = handler.get_module_device(layer_name) + else: + layer_device = handler.device # The optimizer setup # Create delta on CPU first, then move through device_manager for tracking delta = torch.zeros((handler.emb_shape), requires_grad=False, dtype=handler.dtype) - delta = handler.device_manager.safe_to_device(delta).requires_grad_(True) + delta = handler.device_manager.safe_to_device(delta, device=layer_device).requires_grad_(True) opt = torch.optim.Adam([delta], lr=handler.lr) @@ -524,25 +537,24 @@ def delta_hook(module, _, output): if module == handler._get_module(handler._layer_name_template.format(handler._layer)): new_output = output.clone() if v_init is None: - v_init = output[0, last_subject_index[0]].detach().clone() - for i, idx in enumerate(last_subject_index): - new_output[i, idx, :] = new_output[i, idx, :] + delta.to(output.dtype) + v_init = output[0, last_subject_index_list[0]].detach().clone() + for i, idx in enumerate(last_subject_index_list): + new_output[i, idx, :] = new_output[i, idx, :] + delta.to(device=output.device, dtype=output.dtype) return new_output # Create index for all the prompts and targets target_len = int(new_target_ids.size(0)) - prompt_device = prompts.input_ids.device - main_prompt_idx = torch.arange(N_prompts, device=prompt_device) - index_positions = ( - prompts.attention_mask[:N_prompts].sum(dim=1).unsqueeze(1) + main_prompt_idx_cpu = torch.arange(N_prompts, dtype=torch.long) + index_positions_cpu = ( + prompts.attention_mask[:N_prompts].detach().to("cpu").sum(dim=1).unsqueeze(1) - target_len - + torch.arange(target_len, device=prompt_device).unsqueeze(0) + + torch.arange(target_len, dtype=torch.long).unsqueeze(0) ).long() - index_ids = new_target_ids.unsqueeze(0).repeat(N_prompts, 1) - dkl_prompt_idx = torch.arange(N_prompts, prompts.input_ids.shape[0], device=prompt_device) - dkl_index = (prompts.attention_mask[dkl_prompt_idx].sum(dim=1) - 1).long() + index_ids_cpu = new_target_ids.detach().to("cpu").long().unsqueeze(0).repeat(N_prompts, 1) + dkl_prompt_idx_cpu = torch.arange(N_prompts, prompts.input_ids.shape[0], dtype=torch.long) + dkl_index_cpu = (prompts.attention_mask.detach().to("cpu")[dkl_prompt_idx_cpu].sum(dim=1) - 1).long() for i in range(N_optim_steps): opt.zero_grad() @@ -552,6 +564,13 @@ def delta_hook(module, _, output): outputs = handler.model(**prompts) handler.remove_hooks() + logits_device = outputs.logits.device + main_prompt_idx = main_prompt_idx_cpu.to(logits_device) + index_positions = index_positions_cpu.to(logits_device) + index_ids = index_ids_cpu.to(logits_device) + dkl_prompt_idx = dkl_prompt_idx_cpu.to(logits_device) + dkl_index = dkl_index_cpu.to(logits_device) + all_log_probs = torch.log_softmax(outputs.logits, dim=2) log_probs = all_log_probs[ main_prompt_idx.unsqueeze(1), @@ -587,7 +606,7 @@ def delta_hook(module, _, output): return delta -def insert_kv(handler: ModelHandler, k: torch.Tensor, delta: torch.Tensor) -> None: +def insert_kv(handler: ModelHandler, k: torch.Tensor, delta: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: layer_name = handler._layer_name_template.format(handler._layer) # For multi-GPU, use the device where this layer actually lives if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu: @@ -688,9 +707,15 @@ def hook(_, inp, out): LOGGER.info(f"Starting covariance computation: {n_samples} samples, batch_size={batch_size}, max_length={max_length}") ds = load_dataset(handler.cfg, sm=True) - # For multi-GPU models, get the device for inputs (first device in the pipeline) + # For multi-GPU models, place token inputs on the embedding module device. if hasattr(handler, 'is_multi_gpu') and handler.is_multi_gpu: - input_device = next(handler.model.parameters()).device + try: + input_module_name = handler._corrupt_layer_name_template + if "{}" in input_module_name: + input_module_name = input_module_name.format(0) + input_device = handler.get_module_device(input_module_name) + except Exception: + input_device = next(handler.model.parameters()).device else: input_device = handler.device @@ -723,13 +748,15 @@ def hook(_, inp, out): processed += len(batch_texts) except torch.cuda.OutOfMemoryError: LOGGER.warning("OOM during covariance computation, halving batch size") - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() batch_size = max(1, batch_size // 2) except Exception as e: LOGGER.warning(e) batch_texts = [] # Clear GPU cache periodically - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() # Process remaining texts if batch_texts and processed < n_samples: diff --git a/src/rome/rome.py b/src/rome/rome.py index e7feda3..cf38c50 100644 --- a/src/rome/rome.py +++ b/src/rome/rome.py @@ -152,7 +152,7 @@ def batch_evaluation(cfg: DictConfig) -> None: def single_intervention(handler: ModelHandler, fact_tuple: Tuple[str,str,str,str]) -> None: k = gather_k(handler, fact_tuple=fact_tuple, N=getattr(handler.cfg.generation, 'k_N', 50)) delta = optimize_v(handler, fact_tuple, N_prompts=getattr(handler.cfg.generation, 'v_N', 50), N_optim_steps=handler.epochs) - new_W, old_W = insert_kv(handler, k, delta) + new_W, old_W, _ = insert_kv(handler, k, delta) if handler.save_new_weights: out_path = Path(handler.new_weights_dir) / f"{handler.cfg.model.name.replace('/', '-')}_{handler._layer}.pt" @@ -189,7 +189,7 @@ def main(cfg: DictConfig) -> None: delta = optimize_v(handler, k, fact_tuple, N_prompts=50, N_optim_steps=handler.epochs, epsilon=0.005) LOGGER.info(f"delta computed, shape: {delta.shape}") - new_W, old_W = insert_kv(handler, k, delta) + new_W, old_W, _ = insert_kv(handler, k, delta) LOGGER.info(f"New weights computed") if handler.save_new_weights: