Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion notebooks/06_Causal_Intervention_Reproduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@
" CLONE_DIR,\n",
" Path.cwd(),\n",
" Path(\"/content/drive/MyDrive/Research - code - copy\"),\n",
" Path(\"/content/drive/MyDrive/?rea de Trabalho/Research - code - copy\"),\n",
" Path(\"/content/drive/MyDrive/Área de Trabalho/Research - code - copy\"),\n",
" Path(\"/content/drive/MyDrive/Area de Trabalho/Research - code - copy\"),\n",
" Path(\"/workspace/Research - code - copy\"),\n",
" ]\n",
" )\n",
Expand Down
279 changes: 173 additions & 106 deletions scripts/run_causal_intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,21 +117,55 @@ def _compute_rho_stats(dt: torch.Tensor, A: torch.Tensor):
)


def _compute_discrete_a_rho(discrete_A: torch.Tensor):
return torch.amax(discrete_A, dim=(1, 3))


def _clamp_discrete_a_to_target(discrete_A: torch.Tensor, target_rho_tensor: torch.Tensor):
rho_before = _compute_discrete_a_rho(discrete_A)
scale = torch.ones_like(rho_before)
overshoot = rho_before > target_rho_tensor + 1e-6

if torch.any(overshoot):
scale[overshoot] = target_rho_tensor / rho_before[overshoot]

discrete_A_clamped = discrete_A * scale[:, None, :, None]
rho_after = _compute_discrete_a_rho(discrete_A_clamped)
return discrete_A_clamped, rho_before, rho_after, overshoot, scale


def _clamp_dt_to_target(
    dt: torch.Tensor,
    A: torch.Tensor,
    target_rho_tensor: torch.Tensor,
    min_dt_required: torch.Tensor | None = None,
):
    """Raise ``dt`` where needed so the per-channel rho (as computed by
    ``_compute_rho_stats``) does not exceed ``target_rho_tensor``.

    Fix: removed the stale pre-change duplicates of the two stats
    assignments (diff residue) that recomputed ``_compute_rho_stats``
    redundantly and bound a never-used ``rho_tokens_before``.

    Returns ``(dt_clamped, rho_current_before, rho_after)``. If rho is
    already within target (+1e-6 tolerance), ``dt`` is returned unchanged.

    NOTE(review): assumes rho = exp(A * dt) per channel, so the floor
    -log(target)/|A| pushes rho to the target — confirm against
    _compute_rho_stats, which is defined earlier in this file.
    """
    rho_channels_before, _rho_tokens_before, rho_current_before = _compute_rho_stats(dt, A)

    if min_dt_required is None:
        # Smallest dt per channel that keeps rho <= target, derived from the
        # least-negative entry of A; clamp avoids division by ~0.
        min_abs_A = torch.abs(A).min(dim=-1).values
        min_dt_required = -torch.log(target_rho_tensor) / torch.clamp(min_abs_A, min=1e-8)

    dt_floor = min_dt_required.to(device=dt.device, dtype=dt.dtype).view(1, 1, -1)

    # Fast path: nothing exceeds the target, so dt is returned untouched.
    if torch.all(rho_channels_before <= target_rho_tensor + 1e-6):
        return dt, rho_current_before, rho_current_before

    dt_scaled = torch.maximum(dt, dt_floor)
    rho_channels_after, _, rho_after = _compute_rho_stats(dt_scaled, A)

    # The floor alone may overshoot; iteratively correct in log-space.
    # Three rounds is a pragmatic cap — each round multiplies dt by
    # log(target)/log(rho), which is exact when rho = exp(A * dt).
    for _ in range(3):
        overshoot = rho_channels_after > target_rho_tensor + 1e-6
        if not torch.any(overshoot):
            break

        # Keep rho strictly inside (0, 1) so both logs are finite/non-zero.
        safe_rho_channels = torch.clamp(rho_channels_after, min=1e-8, max=1.0 - 1e-8)
        correction = torch.ones_like(dt_scaled)
        correction[overshoot] = (
            torch.log(target_rho_tensor) / torch.log(safe_rho_channels[overshoot])
        )
        dt_scaled = dt_scaled * correction
        rho_channels_after, _, rho_after = _compute_rho_stats(dt_scaled, A)

    return dt_scaled, rho_current_before, rho_after

Expand All @@ -144,116 +178,149 @@ def _tensor_stats(tensor: torch.Tensor):
}


def _mamba_slow_forward_with_clamp(
    mixer,
    input_states,
    cache_params=None,
    cache_position=None,
    attention_mask=None,
    *,
    target_rho_tensor,
    debug=False,
    debug_state=None,
):
    """Sequential-scan Mamba mixer forward with the discretized A clamped.

    Mirrors the structure of the HF transformers ``MambaMixer.slow_forward``
    path (in_proj -> conv1d -> x_proj/dt_proj -> selective scan -> out_proj),
    but clamps ``exp(A * dt)`` via ``_clamp_discrete_a_to_target`` so its
    per-token maximum never exceeds ``target_rho_tensor`` — presumably the
    causal-intervention knob; confirm against the script's CLI.

    Args:
        mixer: a Mamba mixer module exposing in_proj/conv1d/x_proj/dt_proj/
            out_proj, A_log, D, act, and the size attributes used below.
        input_states: input hidden states; indexing below assumes shape
            (batch, seq_len, hidden) — TODO confirm.
        cache_params: optional generation cache holding conv and SSM states;
            mutated in place when provided.
        cache_position: positions used to update the conv cache; only read
            when ``cache_params`` is not None.
        attention_mask: optional mask multiplied into the channel dimension.
        target_rho_tensor: scalar tensor with the rho ceiling.
        debug / debug_state: when debug is on and ``debug_state["printed"]``
            is False, prints one diagnostic line then latches the flag.

    Raises:
        RuntimeError: if rho still exceeds the target after clamping.

    Returns:
        The mixer output after out_proj (batch, seq_len, hidden).
    """
    batch_size, seq_len, _ = input_states.shape
    dtype = input_states.dtype

    # Project and split into the SSM branch and the gating branch.
    projected_states = mixer.in_proj(input_states).transpose(1, 2)
    hidden_states, gate = projected_states.chunk(2, dim=1)
    if attention_mask is not None:
        hidden_states = hidden_states * attention_mask.unsqueeze(1)

    if cache_params is not None:
        # Resume from the cached SSM state (cloned so the cache is only
        # updated once at the end of this call).
        ssm_state = cache_params.ssm_states[mixer.layer_idx].clone()
        ssm_state = ssm_state.to(hidden_states.device)

        if cache_position.shape[0] == mixer.conv_kernel_size:
            # Prefill path: left-pad the sequence to the conv kernel size,
            # seed the conv cache, and run the full convolution.
            conv_state = F.pad(
                hidden_states, (mixer.conv_kernel_size - hidden_states.shape[-1], 0)
            )
            cache_params.update_conv_state(mixer.layer_idx, conv_state, cache_position)
            hidden_states = mixer.act(mixer.conv1d(hidden_states)[..., :seq_len])
        else:
            # Single-step decode path: update the rolling conv state and
            # apply the convolution as an explicit dot product.
            conv_state = cache_params.update_conv_state(
                mixer.layer_idx, hidden_states, cache_position
            )
            conv_state = conv_state.to(mixer.conv1d.weight.device)
            hidden_states = torch.sum(conv_state * mixer.conv1d.weight[:, 0, :], dim=-1)
            if mixer.use_conv_bias:
                hidden_states += mixer.conv1d.bias
            hidden_states = mixer.act(hidden_states).to(dtype).unsqueeze(-1)
    else:
        # No cache: start the scan from a zero SSM state.
        ssm_state = torch.zeros(
            (batch_size, mixer.intermediate_size, mixer.ssm_state_size),
            device=hidden_states.device,
            dtype=dtype,
        )
        hidden_states = mixer.act(mixer.conv1d(hidden_states)[..., :seq_len])

    # Re-apply the mask after the convolution so padded positions stay zero.
    if attention_mask is not None:
        hidden_states = hidden_states * attention_mask.unsqueeze(1)

    # Input-dependent SSM parameters: dt (low-rank), B and C.
    ssm_parameters = mixer.x_proj(hidden_states.transpose(1, 2))
    time_step, B, C = torch.split(
        ssm_parameters,
        [mixer.time_step_rank, mixer.ssm_state_size, mixer.ssm_state_size],
        dim=-1,
    )

    # Expand dt to full width and make it positive via softplus.
    discrete_time_step = mixer.dt_proj(time_step)
    dt = F.softplus(discrete_time_step).transpose(1, 2)

    # Discretize: A is always negative, so exp(A*dt) lies in (0, 1); the
    # clamp then caps its per-token maximum at target_rho_tensor.
    A = -torch.exp(mixer.A_log.float())
    discrete_A = torch.exp(A[None, :, None, :] * dt[:, :, :, None])
    discrete_A, rho_before, rho_after, overshoot, scale = _clamp_discrete_a_to_target(
        discrete_A,
        target_rho_tensor,
    )

    # One-shot debug line per layer, latched through debug_state.
    if debug and debug_state is not None and not debug_state["printed"]:
        action = "clamp" if torch.any(overshoot) else "no-op"
        affected_tokens = int(overshoot.sum().item())
        total_tokens = int(overshoot.numel())
        print(
            f"[layer {mixer.layer_idx}] {action} "
            f"dt_shape={tuple(dt.shape)} "
            f"dt_before={_tensor_stats(dt)} "
            f"rho_before={_tensor_stats(rho_before)} "
            f"rho_after={_tensor_stats(rho_after)} "
            f"scale={_tensor_stats(scale)} "
            f"affected_tokens={affected_tokens}/{total_tokens}"
        )
        debug_state["printed"] = True

    # Hard safety check: the clamp must actually have enforced the target.
    max_rho_after = float(rho_after.max().item())
    if max_rho_after > float(target_rho_tensor.item()) + 1e-6:
        raise RuntimeError(
            f"Layer {mixer.layer_idx} clamp failed: "
            f"rho_after={max_rho_after:.6f} > "
            f"target={float(target_rho_tensor.item()):.6f}"
        )

    # Discretized input term: dt * B * u.
    discrete_B = dt[:, :, :, None] * B[:, None, :, :].float()
    deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

    # Sequential selective scan: s_t = A_t * s_{t-1} + (dtB u)_t, y_t = C_t s_t.
    scan_outputs = []
    for i in range(seq_len):
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]
        scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))
        scan_outputs.append(scan_output[:, :, 0])

    # Skip connection (D) and gating, as in the reference slow forward.
    scan_output = torch.stack(scan_outputs, dim=-1)
    scan_output = scan_output + (hidden_states * mixer.D[None, :, None])
    scan_output = scan_output * mixer.act(gate)

    # Persist the final SSM state back into the generation cache.
    if cache_params is not None:
        cache_params.ssm_states[mixer.layer_idx].copy_(ssm_state)

    return mixer.out_proj(scan_output.transpose(1, 2))


def apply_clamp_hooks(model, layer_indices, target_rho, device, debug=False):
    """Patch the selected Mamba layers to run the rho-clamped slow forward.

    Fix: this span was unapplied diff residue — the removed dt_proj
    forward-hook implementation was interleaved with the new
    forward-patching one, leaving unbalanced parentheses around the old
    ``register_forward_hook`` call. This is the reconstructed post-change
    version: each selected mixer's ``forward`` is replaced with a wrapper
    around ``_mamba_slow_forward_with_clamp``.

    Args:
        model: model exposing ``model.backbone.layers[idx].mixer``.
        layer_indices: indices of the layers to patch.
        target_rho: float ceiling for the discretized-A maximum.
        device: device for the target-rho tensor.
        debug: if True, each patched layer prints one diagnostic line on
            its first clamped forward.

    Returns:
        A list of zero-argument restore callables (NOT hook handles — the
        name is kept for backward compatibility; callers invoke ``h()``,
        not ``h.remove()``). Each one reinstates the original forward.
    """
    handles = []
    target_rho_tensor = torch.tensor(target_rho, device=device, dtype=torch.float32)

    for idx in layer_indices:
        mixer = model.backbone.layers[idx].mixer
        original_forward = mixer.forward
        # Per-layer latch so the debug print fires at most once.
        state = {"printed": False}

        # Bind mixer/state as keyword-only defaults so each iteration's
        # wrapper captures its own layer (avoids the late-binding-closure
        # pitfall in this loop).
        def patched_forward(
            hidden_states,
            cache_params=None,
            cache_position=None,
            attention_mask=None,
            *,
            _mixer=mixer,
            _state=state,
        ):
            return _mamba_slow_forward_with_clamp(
                _mixer,
                hidden_states,
                cache_params=cache_params,
                cache_position=cache_position,
                attention_mask=attention_mask,
                target_rho_tensor=target_rho_tensor,
                debug=debug,
                debug_state=_state,
            )

        mixer.forward = patched_forward

        def restore_forward(_mixer=mixer, _original_forward=original_forward):
            _mixer.forward = _original_forward

        handles.append(restore_forward)

    return handles

Expand Down Expand Up @@ -297,7 +364,7 @@ def evaluate_accuracy(
correct += 1
finally:
for h in handles:
h.remove()
h()

n = len(prompts)
acc = correct / n if n > 0 else 0.0
Expand Down
Loading
Loading