Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 91 additions & 50 deletions notebooks/06_Causal_Intervention_Reproduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
"source": [
"# Experiment: F1-01a Causal Intervention Reproduction\n",
"\n",
"This notebook reproduces the causal-intervention run and will clone the repo automatically on Colab if needed.\n",
"This notebook runs the corrected causal-intervention sweep and will clone the repo automatically on Colab if needed.\n",
"\n",
"To match Table 3 in the paper, run 100 validated prompts with the full threshold grid.\n",
"To match the Fase 01 protocol, it runs the full threshold grid for multiple seeds and exports consolidated artifacts.\n",
"\n",
"Important:\n",
"- Do not use `%run scripts/run_causal_intervention.py` here.\n",
Expand Down Expand Up @@ -127,9 +127,10 @@
"source": [
"# Reproducibility and run configuration.\n",
"MODEL_ID = \"state-spaces/mamba-130m-hf\"\n",
"N_SAMPLES = 100\n",
"SEED = 42\n",
"THRESHOLDS = [0.99, 0.95, 0.90, 0.70, 0.50]\n",
"N_SAMPLES = 200\n",
"SEEDS = [42, 123, 456]\n",
"SINGLE_LAYER_IDX = 0\n",
"THRESHOLDS = [0.99, 0.97, 0.95, 0.92, 0.90, 0.87, 0.85, 0.80, 0.70, 0.50, 0.30]\n",
"MAX_ATTEMPTS = 3000\n",
"MAX_NEW_TOKENS = 25\n",
"DEBUG = True\n",
Expand All @@ -140,6 +141,7 @@
"\n",
"print(f\"MODEL_ID={MODEL_ID}\")\n",
"print(f\"N_SAMPLES={N_SAMPLES}\")\n",
"print(f\"SEEDS={SEEDS}\")\n",
"print(f\"THRESHOLDS={THRESHOLDS}\")\n",
"print(f\"OUTPUT_DIR={OUTPUT_DIR}\")\n"
]
Expand All @@ -151,7 +153,6 @@
"outputs": [],
"source": [
"# Load tokenizer and model.\n",
"causal.set_seed(SEED)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"Using device: {device}\")\n",
"if device.type == \"cuda\":\n",
Expand Down Expand Up @@ -180,22 +181,10 @@
"metadata": {},
"outputs": [],
"source": [
"# Mine prompts that the baseline model can solve before running the intervention.\n",
"prompts = causal.mine_validated_prompts(\n",
" model,\n",
" tokenizer,\n",
" N_SAMPLES,\n",
" device,\n",
" SEED,\n",
" MAX_ATTEMPTS,\n",
" MAX_NEW_TOKENS,\n",
")\n",
"\n",
"print(f\"Validated prompts: {len(prompts)}\")\n",
"if prompts:\n",
" print(\"Example prompt:\")\n",
" print(prompts[0][0])\n",
" print(f\"Target: {prompts[0][1]}\")\n"
"# Prompts are mined lazily per seed so all protocols for the same seed share\n",
"# the exact same validated prompt set.\n",
"PROMPTS_BY_SEED = {}\n",
"print(\"Prompt cache ready.\")\n"
]
},
{
Expand All @@ -204,8 +193,31 @@
"metadata": {},
"outputs": [],
"source": [
"def run_protocol(protocol: str, layer_idx: int, output_name: str) -> pd.DataFrame:\n",
"def get_prompts_for_seed(seed: int):\n",
" if seed in PROMPTS_BY_SEED:\n",
" return PROMPTS_BY_SEED[seed]\n",
"\n",
" causal.set_seed(seed)\n",
" prompts = causal.mine_validated_prompts(\n",
" model,\n",
" tokenizer,\n",
" N_SAMPLES,\n",
" device,\n",
" seed,\n",
" MAX_ATTEMPTS,\n",
" MAX_NEW_TOKENS,\n",
" )\n",
" PROMPTS_BY_SEED[seed] = prompts\n",
" print(f\"Seed {seed}: validated prompts = {len(prompts)}\")\n",
" if prompts:\n",
" print(f\"Seed {seed} example target: {prompts[0][1]}\")\n",
" return prompts\n",
"\n",
"def run_protocol(seed: int, protocol: str, layer_idx: int, output_name: str) -> pd.DataFrame:\n",
" prompts = get_prompts_for_seed(seed)\n",
" n = len(prompts)\n",
" layer_idx_value = None if protocol == \"all_layer\" else layer_idx\n",
" layers_applied = \"all\" if protocol == \"all_layer\" else str(layer_idx)\n",
" rows = [\n",
" {\n",
" \"experiment_id\": \"causal_clamp\",\n",
Expand All @@ -217,9 +229,11 @@
" \"metric_value\": 1.0,\n",
" \"ci_low\": 1.0,\n",
" \"ci_high\": 1.0,\n",
" \"seed\": SEED,\n",
" \"seed\": seed,\n",
" \"artifact_path\": str(OUTPUT_DIR / output_name),\n",
" \"protocol\": protocol,\n",
" \"layer_idx\": layer_idx_value,\n",
" \"layers_applied\": layers_applied,\n",
" \"rho_target\": \"baseline\",\n",
" \"correct\": n,\n",
" }\n",
Expand Down Expand Up @@ -248,9 +262,11 @@
" \"metric_value\": round(acc, 6),\n",
" \"ci_low\": round(ci_low, 6),\n",
" \"ci_high\": round(ci_high, 6),\n",
" \"seed\": SEED,\n",
" \"seed\": seed,\n",
" \"artifact_path\": str(OUTPUT_DIR / output_name),\n",
" \"protocol\": protocol,\n",
" \"layer_idx\": layer_idx_value,\n",
" \"layers_applied\": layers_applied,\n",
" \"rho_target\": rho,\n",
" \"correct\": correct,\n",
" }\n",
Expand All @@ -270,9 +286,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. All-layer run\n",
"## 1. All-layer sweep\n",
"\n",
"This is the global intervention path.\n"
"This runs the global intervention path for every seed.\n"
]
},
{
Expand All @@ -281,20 +297,28 @@
"metadata": {},
"outputs": [],
"source": [
"df_all_layer = run_protocol(\n",
" protocol=\"all_layer\",\n",
" layer_idx=0,\n",
" output_name=\"repro_all_layer.csv\",\n",
")\n"
"df_all_layer_runs = []\n",
"for seed in SEEDS:\n",
" df_all_layer_runs.append(\n",
" run_protocol(\n",
" seed=seed,\n",
" protocol=\"all_layer\",\n",
" layer_idx=SINGLE_LAYER_IDX,\n",
" output_name=f\"table3_seed{seed}_all_layer.csv\",\n",
" )\n",
" )\n",
"\n",
"df_all_layer = pd.concat(df_all_layer_runs, ignore_index=True)\n",
"display(df_all_layer)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Single-layer run\n",
"## 2. Single-layer sweep\n",
"\n",
"This is the local intervention path. Use layer 0 unless you want to test another layer.\n"
"This runs the local intervention path for every seed using layer 0 for compatibility.\n"
]
},
{
Expand All @@ -303,20 +327,28 @@
"metadata": {},
"outputs": [],
"source": [
"df_single_layer = run_protocol(\n",
" protocol=\"single_layer\",\n",
" layer_idx=0,\n",
" output_name=\"repro_single_layer.csv\",\n",
")\n"
"df_single_layer_runs = []\n",
"for seed in SEEDS:\n",
" df_single_layer_runs.append(\n",
" run_protocol(\n",
" seed=seed,\n",
" protocol=\"single_layer\",\n",
" layer_idx=SINGLE_LAYER_IDX,\n",
" output_name=f\"table3_seed{seed}_single_layer_l{SINGLE_LAYER_IDX}.csv\",\n",
" )\n",
" )\n",
"\n",
"df_single_layer = pd.concat(df_single_layer_runs, ignore_index=True)\n",
"display(df_single_layer)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Summary and next step\n",
"## 3. Consolidated artifacts and summary\n",
"\n",
"If the notebook stops with a `rho_after` assertion, that is the signal I need. Send me the CSVs plus the full notebook output.\n"
"If the notebook stops with a `rho_after` assertion, that is the signal I need. Otherwise, send me the consolidated CSV/JSON plus the full notebook output.\n"
]
},
{
Expand All @@ -325,16 +357,25 @@
"metadata": {},
"outputs": [],
"source": [
"summary = pd.concat(\n",
" [\n",
" df_all_layer.assign(run=\"all_layer\"),\n",
" df_single_layer.assign(run=\"single_layer\"),\n",
" ],\n",
" ignore_index=True,\n",
")\n",
"summary = pd.concat([df_all_layer, df_single_layer], ignore_index=True)\n",
"summary_path = OUTPUT_DIR / \"causal_intervention_table3_results.csv\"\n",
"summary_json_path = OUTPUT_DIR / \"causal_intervention_table3_results.json\"\n",
"summary.to_csv(summary_path, index=False)\n",
"summary.to_json(summary_json_path, orient=\"records\", indent=2)\n",
"\n",
"display(summary[[\"seed\", \"protocol\", \"layer_idx\", \"rho_target\", \"metric_value\", \"correct\", \"ci_low\", \"ci_high\"]])\n",
"\n",
"for rho_target in [0.99, 0.30]:\n",
" subset = summary[summary[\"rho_target\"] == rho_target]\n",
" if subset.empty:\n",
" continue\n",
" stats = subset.groupby(\"protocol\")[\"metric_value\"].agg([\"mean\", \"std\"]).reset_index()\n",
" print(f\"rho={rho_target} summary\")\n",
" display(stats)\n",
"\n",
"display(summary[[\"run\", \"protocol\", \"rho_target\", \"metric_value\", \"correct\", \"ci_low\", \"ci_high\"]])\n",
"print(\"Done. Send the CSVs and the full notebook log.\")\n"
"print(f\"Saved: {summary_path}\")\n",
"print(f\"Saved: {summary_json_path}\")\n",
"print(\"Done. Send the consolidated CSV/JSON and the full notebook log.\")\n"
]
}
],
Expand Down
16 changes: 14 additions & 2 deletions scripts/run_causal_intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def _compute_rho_stats(dt: torch.Tensor, A: torch.Tensor):


def _compute_discrete_a_rho(discrete_A: torch.Tensor) -> torch.Tensor:
    """Return the operational rho of a discretized A tensor.

    The reduction runs in two stages: the maximum over the last axis yields
    one value per channel, and those per-channel values are then averaged
    over axis 1.

    Args:
        discrete_A: Discretized state tensor. Presumably laid out as
            (batch, channel, token, state) -- TODO confirm against callers.

    Returns:
        ``discrete_A`` with its last axis max-reduced and axis 1
        mean-reduced; for a 4-D input the result has shape
        ``(discrete_A.shape[0], discrete_A.shape[2])``.
    """
    # Match the project-wide operational rho used in downstream analyses:
    # take the per-channel spectral radius first, then aggregate channels.
    # (A previous global max over axes 1 and 3 must not run ahead of this.)
    rho_channels = discrete_A.amax(dim=-1)
    return rho_channels.mean(dim=1)


def _clamp_discrete_a_to_target(discrete_A: torch.Tensor, target_rho_tensor: torch.Tensor):
Expand Down Expand Up @@ -422,7 +425,7 @@ def main():
"--thresholds",
nargs="+",
type=float,
default=[0.99, 0.95, 0.90, 0.70, 0.50],
default=[0.99, 0.97, 0.95, 0.92, 0.90, 0.87, 0.85, 0.80, 0.70, 0.50, 0.30],
)
parser.add_argument("--max-attempts", type=int, default=3000)
parser.add_argument("--max-new-tokens", type=int, default=25)
Expand Down Expand Up @@ -458,6 +461,8 @@ def main():

for protocol in run_protocols:
n = len(prompts)
layers_applied = "all" if protocol == "all_layer" else str(args.layer_idx)
row_layer_idx = None if protocol == "all_layer" else args.layer_idx
rows.append(
{
"experiment_id": "causal_clamp",
Expand All @@ -472,6 +477,8 @@ def main():
"seed": args.seed,
"artifact_path": args.output,
"protocol": protocol,
"layer_idx": row_layer_idx,
"layers_applied": layers_applied,
"rho_target": "baseline",
"correct": n,
}
Expand Down Expand Up @@ -503,6 +510,8 @@ def main():
"seed": args.seed,
"artifact_path": args.output,
"protocol": protocol,
"layer_idx": row_layer_idx,
"layers_applied": layers_applied,
"rho_target": rho,
"correct": correct,
}
Expand All @@ -513,6 +522,9 @@ def main():
out_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(out_path, index=False)
print(f"Saved: {out_path}")
json_path = out_path.with_suffix(".json")
df.to_json(json_path, orient="records", indent=2)
print(f"Saved: {json_path}")


if __name__ == "__main__":
Expand Down
37 changes: 30 additions & 7 deletions tests/test_run_causal_intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_materialize_dt_applies_bias_before_softplus():
assert torch.allclose(actual, expected)


def test_compute_rho_stats_uses_operator_max_not_channel_mean():
def test_compute_rho_stats_uses_per_channel_max_then_channel_mean():
dt = torch.tensor([[[0.1, 0.2]]], dtype=torch.float32)
A = -torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)

Expand All @@ -39,14 +39,14 @@ def test_compute_rho_stats_uses_operator_max_not_channel_mean():
assert torch.allclose(rho_current, expected_rho_current)


def test_compute_discrete_a_rho_uses_global_token_max():
def test_compute_discrete_a_rho_uses_token_mean_across_channels():
discrete_A = torch.tensor(
[[[[0.95, 0.96], [0.97, 0.94]], [[0.98, 0.99], [0.93, 0.92]]]],
dtype=torch.float32,
)

actual = _compute_discrete_a_rho(discrete_A)
expected = torch.tensor([[0.99, 0.97]], dtype=torch.float32)
expected = torch.tensor([[(0.96 + 0.99) / 2, (0.97 + 0.93) / 2]], dtype=torch.float32)

assert torch.allclose(actual, expected)

Expand All @@ -56,20 +56,43 @@ def test_clamp_discrete_a_to_target_scales_only_overshooting_tokens():
[[[[0.98, 1.00], [0.95, 0.94]], [[0.96, 0.97], [0.92, 0.91]]]],
dtype=torch.float32,
)
target = torch.tensor(0.99, dtype=torch.float32)
target = torch.tensor(0.97, dtype=torch.float32)

clamped, rho_before, rho_after, overshoot, scale = _clamp_discrete_a_to_target(
discrete_A,
target,
)

assert torch.allclose(rho_before, torch.tensor([[1.00, 0.95]], dtype=torch.float32))
expected_rho_before = torch.tensor([[(1.00 + 0.97) / 2, (0.95 + 0.92) / 2]], dtype=torch.float32)
expected_scale = torch.tensor([[0.97 / 0.985, 1.0]], dtype=torch.float32)
expected_rho_after = torch.tensor([[0.97, 0.935]], dtype=torch.float32)

assert torch.allclose(rho_before, expected_rho_before)
assert torch.equal(overshoot, torch.tensor([[True, False]]))
assert torch.allclose(scale, torch.tensor([[0.99, 1.0]], dtype=torch.float32))
assert torch.allclose(rho_after, torch.tensor([[0.99, 0.95]], dtype=torch.float32))
assert torch.allclose(scale, expected_scale)
assert torch.allclose(rho_after, expected_rho_after)
assert torch.allclose(clamped[:, :, 1, :], discrete_A[:, :, 1, :])


def test_clamp_discrete_a_is_noop_when_token_mean_is_below_target():
    """Clamping leaves everything untouched when no token exceeds the target."""
    operator = torch.tensor(
        [
            [
                [[0.98, 1.00], [0.95, 0.94]],
                [[0.96, 0.97], [0.92, 0.91]],
            ]
        ],
        dtype=torch.float32,
    )
    rho_cap = torch.tensor(0.99, dtype=torch.float32)

    outcome = _clamp_discrete_a_to_target(operator, rho_cap)
    clamped, rho_before, rho_after, overshoot, scale = outcome

    # Expected per-token rho values (0.985 and 0.935) both sit under the cap,
    # so no overshoot may be flagged and every scale factor must stay at one.
    expected_rho = torch.tensor([[0.985, 0.935]], dtype=torch.float32)
    nothing_flagged = torch.tensor([[False, False]])

    assert torch.allclose(rho_before, expected_rho)
    assert torch.equal(overshoot, nothing_flagged)
    assert torch.allclose(scale, torch.ones_like(scale))
    assert torch.allclose(rho_after, rho_before)
    assert torch.allclose(clamped, operator)


def test_clamp_dt_is_noop_below_target_and_respects_sanity_check():
dt = torch.full((1, 1, 2), 0.02, dtype=torch.float32)
A = -torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
Expand Down
Loading