Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 129 additions & 1 deletion attention_sink/analyze_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,12 +412,22 @@ def _plot_progression(stats_a, outdir, key, ylabel, title, fname, suffix="", sta
plt.title(title)

has_kvar = bool(stats_a) and ("k_total_var" in stats_a[0])
has_vvar = bool(stats_a) and ("v_total_var" in stats_a[0])
if has_kvar:
y2 = [by_a[L]["k_total_var"] for L in layers]
ax2 = ax.twinx()
ax2.plot(layers, y2, marker="^", linestyle="--", color="tab:red", label="Total variance of K")
ax2.set_ylim(0, 20000)

lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines1 + lines2, labels1 + labels2, loc="best")
elif has_vvar:
y2 = [by_a[L]["v_total_var"] for L in layers]
ax2 = ax.twinx()
ax2.plot(layers, y2, marker="^", linestyle="--", color="tab:red", label="Total variance of V")
ax2.set_ylim(0, 20000)

lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines1 + lines2, labels1 + labels2, loc="best")
Expand All @@ -428,6 +438,43 @@ def _plot_progression(stats_a, outdir, key, ylabel, title, fname, suffix="", sta
plt.savefig(_append_suffix(os.path.join(outdir, fname), suffix), dpi=300)
plt.close()

def _plot_v_with_variance(stats, outdir, sink_idx, suffix=""):
"""Plot V-norm with V total variance on right axis (similar to K plot)."""
layers = sorted({s["layer"] for s in stats})
by_s = {s["layer"]: s for s in stats}
y_norm = [by_s[L]["v_norm"] for L in layers]
y_var = [by_s[L].get("v_total_var", 0.0) for L in layers]

# Auto-scale y-axis based on data
max_norm = max(y_norm) if y_norm else 100
max_var = max(y_var) if y_var else 20000

plt.figure(figsize=(7, 3.5))
ax = plt.gca()
ax.plot(layers, y_norm, marker="o", color="tab:blue", label="Norm of V")
ax.set_xlabel("Layer")
ax.set_ylabel(f"||V[{sink_idx}]||")
ax.set_ylim(0, max(100, max_norm * 1.1))

if len(layers) > 0:
ax.set_xlim(min(layers), max(layers))
ax.set_xticks(layers)
ax.xaxis.set_major_locator(ticker.MultipleLocator(2))

ax2 = ax.twinx()
ax2.plot(layers, y_var, marker="^", linestyle="--", color="tab:red", label="Total variance of V")
ax2.set_ylim(0, max(20000, max_var * 1.1))
ax2.set_ylabel("Total variance of V")

lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax2.legend(lines1 + lines2, labels1 + labels2, loc="best")

plt.title(f"V-norm progression (token={sink_idx})")
plt.tight_layout()
plt.savefig(_append_suffix(os.path.join(outdir, "scan_vnorm_with_var.png"), suffix), dpi=300)
plt.close()

def _plot_bias_energy_3d(bias_sets, out_path, x_label):
Ys = [Y for (Y, _t) in bias_sets if Y is not None]
if not Ys:
Expand Down Expand Up @@ -612,6 +659,7 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv):
for L in scan_layers
}
per_layer_kvecs = {L: [] for L in scan_layers}
per_layer_vvecs = {L: [] for L in scan_layers} # V vectors for variance
per_layer_maps = {
L: [] for L in scan_layers
}
Expand Down Expand Up @@ -647,6 +695,7 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv):
per_layer_acc[L]["cos"].append(series[0][1])
per_layer_acc[L]["sink_attn"].append(float(sink_slice.mean().item()))
per_layer_kvecs[L].append(k[0, args.head, sink_idx].detach().cpu())
per_layer_vvecs[L].append(v[0, args.head, sink_idx].detach().cpu())
else:
hm = _pick_head_with_caches(model, attns, L, args.head)
per_layer_maps[L].append(hm[:min_len, :min_len])
Expand All @@ -663,10 +712,15 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv):
for L in sorted(per_layer_acc.keys()):
acc = per_layer_acc[L]
kvar = 0.0
vvar = 0.0
if per_layer_kvecs[L]:
Xk = torch.stack(per_layer_kvecs[L], dim=0)
_U, S_k, _Vh, _var, k_total_var = _pca_from_rows(Xk, k=1)
kvar = float(k_total_var)
if per_layer_vvecs[L]:
Xv = torch.stack(per_layer_vvecs[L], dim=0)
_U, S_v, _Vh, _var, v_total_var = _pca_from_rows(Xv, k=1)
vvar = float(v_total_var)
stats.append({
"k_norm": float(np.mean(acc["k_norm"])) if acc["k_norm"] else 0.0,
"v_norm": float(np.mean(acc["v_norm"])) if acc["v_norm"] else 0.0,
Expand All @@ -675,7 +729,8 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv):
"layer": L, "head": args.head,
"target_q": target_q,
"sink_attn": float(np.mean(acc["sink_attn"])) if acc["sink_attn"] else 0.0,
"k_total_var": kvar
"k_total_var": kvar,
"v_total_var": vvar
})
return stats, scan_maps, scan_titles
else:
Expand Down Expand Up @@ -800,6 +855,10 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv):
scan_stats, args.outdir, key="v_norm", ylabel=f"||V[{args.sink_idx}]||",
title=f"V-norm progression (token={args.sink_idx})", fname="scan_vnorm.png", suffix=cur_suffix
)
# Plot V-norm with V variance on right axis
_plot_v_with_variance(
scan_stats, args.outdir, sink_idx=args.sink_idx, suffix=cur_suffix
)
_plot_progression(
scan_stats, args.outdir, key="res_norm", ylabel=f"||h[{args.sink_idx}]||",
title=f"Residual-norm progression (token={args.sink_idx})", fname="scan_residual_norm.png", suffix=cur_suffix,
Expand Down Expand Up @@ -908,6 +967,75 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv):
print(f"[PCA: {VAR}] token 8 (non-sink) total variance: {K8_total_var:.4f}")
print(f"[PCA: Q] targt Q (target_q={target_q}) total variance: {Q_total_var:.4f}")

if args.find_value_subspace:
# ========== V subspace analysis ==========
# visualize 3D PCA for V
X0_c = X0 - X0.mean(dim=0, keepdim=True)
X1_c = X1 - X1.mean(dim=0, keepdim=True)
X8_c = X8 - X8.mean(dim=0, keepdim=True)

Y0 = (X0_c @ Vh0[:3].T).numpy()
Y1 = (X1_c @ Vh1[:3].T).numpy()
Y8 = (X8_c @ Vh8[:3].T).numpy()

R = max(np.abs(Y0).max(), np.abs(Y1).max(), np.abs(Y8).max())

fig = plt.figure(figsize=(9,3))
for idx, (Y, title) in enumerate(
zip([Y0, Y1, Y8], ["V0 subspace", 'V1 subspace', 'V8 subspace'])
):
ax = fig.add_subplot(1, 3, idx + 1, projection='3d')
ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], s=5)
ax.set_title(title)
ax.set_xlim(-R, R); ax.set_ylim(-R, R); ax.set_zlim(-R, R)
ax.set_xlabel("PC1"); ax.set_ylabel("PC2"); ax.set_zlabel("PC3")

fig.tight_layout()
out_path = os.path.join(args.outdir, f"pca_v_subspaces_L{target_layer}_H{args.head}.png")
fig.savefig(out_path, dpi=300)
plt.close(fig)
print(f"[V subspace] Saved 3D PCA plot to {out_path}")

# V bias direction analysis: mean(V0) as the "special direction"
v0_mean = X0.mean(dim=0)
v1_mean = X1.mean(dim=0)
v8_mean = X8.mean(dim=0)

v0_mean_norm = float(v0_mean.norm().item())
v1_mean_norm = float(v1_mean.norm().item())
v8_mean_norm = float(v8_mean.norm().item())
print(f"[V MEAN] ||mean(V0)|| = {v0_mean_norm:.4f} | ||mean(V1)|| = {v1_mean_norm:.4f} | ||mean(V8)|| = {v8_mean_norm:.4f}")

# Compute how much energy of each V is along the mean(V0) direction
if v0_mean_norm > EPS:
v0_dir = F.normalize(v0_mean, dim=0)
proj_V0_on_mean = (X0 @ v0_dir) ** 2
proj_V1_on_mean = (X1 @ v0_dir) ** 2
proj_V8_on_mean = (X8 @ v0_dir) ** 2

energy_V0 = (X0 ** 2).sum(dim=1).clamp_min(EPS)
energy_V1 = (X1 ** 2).sum(dim=1).clamp_min(EPS)
energy_V8 = (X8 ** 2).sum(dim=1).clamp_min(EPS)

frac_V0 = (proj_V0_on_mean / energy_V0).mean().item()
frac_V1 = (proj_V1_on_mean / energy_V1).mean().item()
frac_V8 = (proj_V8_on_mean / energy_V8).mean().item()

print(f"[V PROJECTION] frac of V energy along mean(V0): V0={frac_V0:.4f} | V1={frac_V1:.4f} | V8={frac_V8:.4f}")

# visualize V in bias direction + residual PC (similar to K's Q energy plot)
bias_sets_v = [
(_bias_3d_projection(X0, v0_mean, topk=2), "V0 energy along mean(V0)"),
(_bias_3d_projection(X1, v0_mean, topk=2), "V1 energy along mean(V0)"),
(_bias_3d_projection(X8, v0_mean, topk=2), "V8 energy along mean(V0)"),
]
_plot_bias_energy_3d(
bias_sets_v,
os.path.join(args.outdir, f"pca_v_bias_energy_L{target_layer}_H{args.head}.png"),
x_label="V bias (mean V0)",
)
print(f"[V subspace] Saved bias energy plot")

if args.find_key_subspace:
# raw cosine scores
def _mean_or_zero(xs): return float(np.mean(xs)) if xs else 0.0
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/v_experiment/pca_v_bias_energy_L7_H0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/v_experiment/pca_v_subspaces_L24_H0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/v_experiment/pca_v_subspaces_L7_H0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/v_experiment/scan_cosine.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/v_experiment/scan_knorm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/v_experiment/scan_residual_norm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/v_experiment/scan_vnorm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/v_experiment/scan_vnorm_with_var.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.