132 changes: 87 additions & 45 deletions scripts/run_causal_intervention.py
@@ -3,12 +3,31 @@
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
from tqdm import tqdm
except ImportError: # pragma: no cover - keeps the module importable in lean envs.
class _NoOpTqdm:
def __init__(self, iterable=None, total=None, desc=None):
self.iterable = iterable

def __iter__(self):
return iter(self.iterable or [])

def update(self, n=1):
return None

def close(self):
return None

def tqdm(iterable=None, **kwargs):
return _NoOpTqdm(
iterable=iterable,
total=kwargs.get("total"),
desc=kwargs.get("desc"),
)
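With this fallback in place, progress-bar call sites need no import guards. A quick illustration of the shim's no-op behavior (hypothetical usage, not part of this PR):

bar = tqdm(total=3, desc="mining")  # real tqdm if installed, silent shim otherwise
for _ in range(3):
    bar.update()
bar.close()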


def wilson_interval(successes: int, n: int, z: float = 1.96):
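(The function body is collapsed in this view. For reference, a minimal sketch assuming the standard Wilson score interval for a binomial proportion; the merged implementation may differ:)

def wilson_interval(successes: int, n: int, z: float = 1.96):
    # Sketch only: Wilson score interval, not necessarily the PR's exact body.
    if n == 0:
        return 0.0, 0.0
    p = successes / n
    denom = 1.0 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    half = (z / denom) * (p * (1.0 - p) / n + z * z / (4 * n * n)) ** 0.5
    return center - half, center + half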
@@ -22,6 +41,8 @@ def wilson_interval(successes: int, n: int, z: float = 1.96):


def set_seed(seed: int):
import numpy as np

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@@ -75,12 +96,52 @@ def _get_separate_dt_bias(mixer):
return None


def _compute_rho_proxy(dt: torch.Tensor, A: torch.Tensor):
def _materialize_dt(dt_raw: torch.Tensor, dt_bias: torch.Tensor | None):
dt_input = dt_raw
if dt_bias is not None:
dt_input = dt_input + dt_bias.unsqueeze(0).unsqueeze(0)
return F.softplus(dt_input)
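This matches Mamba's dt parameterization, in which the bias is added to the raw projection before the softplus (the ordering the new unit test pins down):

$$\Delta t = \mathrm{softplus}(\Delta t_{\text{raw}} + b_{\Delta t})$$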


def _compute_rho_stats(dt: torch.Tensor, A: torch.Tensor):
a = A.to(device=dt.device, dtype=dt.dtype)
A_bar = torch.exp(dt.unsqueeze(-1) * a.unsqueeze(0).unsqueeze(0))
rho_channels = A_bar.max(dim=-1).values
rho_tokens = rho_channels.mean(dim=-1)
return rho_channels, rho_tokens
rho_channels = A_bar.amax(dim=-1)
rho_current = rho_channels.amax(dim=-1)
return rho_channels, rho_current
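A hedged reading of this proxy, assuming Mamba's diagonal zero-order-hold discretization in which each state decays by $e^{\Delta t\, a_n}$ with $a_n < 0$:

$$\rho_{\text{channel}} = \max_n e^{\Delta t\, a_n} = e^{\Delta t\, a_{\max}}, \qquad \rho_{\text{current}} = \max_{\text{channel}} \rho_{\text{channel}},$$

where $a_{\max}$ is the least-negative entry of $A$. Since $a_n < 0$ and $\Delta t > 0$, $\rho$ approaches 1 as $\Delta t \to 0$, so a $\rho$ above the target signals a $\Delta t$ that is too small.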


def _clamp_dt_to_target(
dt: torch.Tensor,
A: torch.Tensor,
target_rho_tensor: torch.Tensor,
):
rho_channels_before, rho_current_before = _compute_rho_stats(dt, A)

if torch.all(rho_current_before <= target_rho_tensor):
return dt, rho_current_before, rho_current_before

safe_rho_current = torch.clamp(rho_current_before, min=1e-8, max=1.0 - 1e-8)
scale = torch.ones_like(rho_current_before)
active = rho_current_before > target_rho_tensor
scale[active] = torch.log(target_rho_tensor) / torch.log(safe_rho_current[active])

dt_scaled = dt * scale.unsqueeze(-1)
_, rho_after = _compute_rho_stats(dt_scaled, A)

for _ in range(2):
overshoot = rho_after > target_rho_tensor + 1e-6
if not torch.any(overshoot):
break
safe_rho_after = torch.clamp(rho_after, min=1e-8, max=1.0 - 1e-8)
correction = torch.ones_like(rho_after)
correction[overshoot] = (
torch.log(target_rho_tensor) / torch.log(safe_rho_after[overshoot])
)
dt_scaled = dt_scaled * correction.unsqueeze(-1)
_, rho_after = _compute_rho_stats(dt_scaled, A)

return dt_scaled, rho_current_before, rho_after
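Why a single log-ratio scale suffices under this proxy: with $\rho = e^{\Delta t\, a_{\max}}$, scaling $\Delta t$ by $s$ gives

$$\rho(s\,\Delta t) = \rho^{s}, \qquad \rho^{s} = \rho_{\text{target}} \iff s = \frac{\log \rho_{\text{target}}}{\log \rho},$$

and $s > 1$ whenever $0 < \rho_{\text{target}} < \rho < 1$, so the clamp always increases $\Delta t$ (faster decay) on the offending tokens. In exact arithmetic the closed form lands on the target; the two-pass correction loop only mops up floating-point overshoot.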


def _tensor_stats(tensor: torch.Tensor):
@@ -118,12 +179,7 @@ def hook(module, args, output):
device=dt_raw_work.device, dtype=dt_raw_work.dtype
)

dt_input = (
dt_raw_work
if dt_bias is None
else dt_raw_work + dt_bias.unsqueeze(0).unsqueeze(0)
)
dt = F.softplus(dt_input)
dt = _materialize_dt(dt_raw_work, dt_bias)

full_dt_dim = dt.shape[-1]
dt_dim = min(full_dt_dim, min_dt_local.shape[0], A.shape[0])
@@ -133,38 +189,29 @@
device=dt.device, dtype=dt.dtype
)

_rho_channels_before, rho_tokens_before = _compute_rho_proxy(dt, A_local)
rho_before = _tensor_stats(rho_tokens_before)
_rho_channels_before, rho_current_before = _compute_rho_stats(dt, A_local)
rho_current_stats = _tensor_stats(rho_current_before)

if torch.all(rho_tokens_before <= target_rho_tensor):
if torch.all(rho_current_before <= target_rho_tensor):
if debug and not state["printed"]:
bias_state = "present" if dt_bias is not None else "absent"
print(
f"[layer {layer_idx}] no-op "
f"dt_shape={tuple(dt.shape)} "
f"bias={bias_state} "
f"dt_before={_tensor_stats(dt)} "
f"rho_before={rho_before} "
f"rho_current={rho_current_stats} "
f"min_dt_required={_tensor_stats(min_dt_effective)}"
)
state["printed"] = True
return output

# Proxy-based scalar correction keeps the intervention closer to the
# paper's operator clamp than a hard floor on every channel.
safe_rho = torch.clamp(rho_tokens_before, min=1e-8, max=1.0 - 1e-8)
scale = torch.ones_like(rho_tokens_before)
active = rho_tokens_before > target_rho_tensor
scale[active] = torch.log(target_rho_tensor) / torch.log(safe_rho[active])

dt_scaled = dt * scale.unsqueeze(-1)
dt_scaled = torch.maximum(
dt_scaled,
min_dt_effective.unsqueeze(0).unsqueeze(0),
dt_scaled, rho_current_before, rho_after = _clamp_dt_to_target(
dt,
A_local,
target_rho_tensor,
)

_rho_channels_after, rho_tokens_after = _compute_rho_proxy(dt_scaled, A_local)
rho_after = _tensor_stats(rho_tokens_after)
rho_after_stats = _tensor_stats(rho_after)

if debug and not state["printed"]:
bias_state = "present" if dt_bias is not None else "absent"
@@ -176,27 +223,19 @@
f"bias={bias_state} "
f"dt_before={_tensor_stats(dt)} "
f"dt_after={_tensor_stats(dt_scaled)} "
f"rho_before={rho_before} "
f"rho_after={rho_after} "
f"rho_current={rho_current_stats} "
f"rho_after={rho_after_stats} "
f"min_abs_A={_tensor_stats(min_abs_A[:dt_dim])} "
f"min_dt_required={_tensor_stats(min_dt_effective)}"
)
state["printed"] = True

# The rho proxy is approximate, so keep a small numerical margin.
rho_margin = 1e-3
max_rho_after = float(rho_tokens_after.max().item())
if max_rho_after > float(target_rho_tensor.item()) + rho_margin:
max_rho_after = float(rho_after.max().item())
if max_rho_after > float(target_rho_tensor.item()) + 1e-6:
raise RuntimeError(
f"Layer {layer_idx} clamp failed: "
f"rho_after={max_rho_after:.6f} > "
f"target={float(target_rho_tensor.item()):.6f} "
f"(margin={rho_margin:.1e})"
)
if debug and not state["printed"] and max_rho_after > float(target_rho_tensor.item()):
print(
f"[layer {layer_idx}] rho_after overshoot "
f"{max_rho_after:.6f} within margin={rho_margin:.1e}"
f"target={float(target_rho_tensor.item()):.6f}"
)

dt_raw_clamped = dt_raw_work.clone()
@@ -303,6 +342,9 @@ def mine_validated_prompts(model, tokenizer, n_samples, device, seed, max_attemp


def main():
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser(description="Causal intervention for spectral clamp protocols.")
parser.add_argument("--model-id", default="state-spaces/mamba-130m-hf")
parser.add_argument("--n-samples", type=int, default=100)
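A minimal invocation using the two arguments visible in this hunk (both shown with their defaults; the remaining flags are collapsed in this view):

python scripts/run_causal_intervention.py \
    --model-id state-spaces/mamba-130m-hf \
    --n-samples 100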
59 changes: 59 additions & 0 deletions tests/test_run_causal_intervention.py
@@ -0,0 +1,59 @@
import pytest

torch = pytest.importorskip("torch")
import torch.nn.functional as F

from scripts.run_causal_intervention import (
_clamp_dt_to_target,
_compute_rho_stats,
_materialize_dt,
)


def test_materialize_dt_applies_bias_before_softplus():
dt_raw = torch.tensor([[[0.0, 1.0]]], dtype=torch.float32)
dt_bias = torch.tensor([0.25, -0.5], dtype=torch.float32)

actual = _materialize_dt(dt_raw, dt_bias)
expected = F.softplus(dt_raw + dt_bias.view(1, 1, -1))

assert torch.allclose(actual, expected)


def test_compute_rho_stats_uses_operator_max_not_channel_mean():
dt = torch.tensor([[[0.1, 0.2]]], dtype=torch.float32)
A = -torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)

rho_channels, rho_current = _compute_rho_stats(dt, A)
expected_rho_channels = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)).amax(
dim=-1
)
expected_rho_current = expected_rho_channels.amax(dim=-1)

assert torch.allclose(rho_channels, expected_rho_channels)
assert torch.allclose(rho_current, expected_rho_current)


def test_clamp_dt_is_noop_below_target_and_respects_sanity_check():
dt = torch.full((1, 1, 2), 0.02, dtype=torch.float32)
A = -torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
target = torch.tensor(0.99, dtype=torch.float32)

clamped, rho_current, rho_after = _clamp_dt_to_target(dt, A, target)

assert torch.allclose(clamped, dt)
assert torch.all(rho_current <= target)
assert torch.allclose(rho_after, rho_current)


def test_clamp_dt_lowers_rho_to_target_without_overshoot():
dt = torch.full((1, 1, 2), 0.001, dtype=torch.float32)
A = -torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
target = torch.tensor(0.99, dtype=torch.float32)

clamped, rho_current, rho_after = _clamp_dt_to_target(dt, A, target)

assert torch.all(clamped >= dt)
assert torch.any(clamped > dt)
assert torch.all(rho_current > target)
assert torch.all(rho_after <= target + 1e-6)
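The tests skip cleanly when torch is unavailable (via pytest.importorskip) and run under the standard runner:

pytest tests/test_run_causal_intervention.py -q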