diff --git a/attention_sink/analyze_sink.py b/attention_sink/analyze_sink.py index e087964..eae03d7 100644 --- a/attention_sink/analyze_sink.py +++ b/attention_sink/analyze_sink.py @@ -412,12 +412,22 @@ def _plot_progression(stats_a, outdir, key, ylabel, title, fname, suffix="", sta plt.title(title) has_kvar = bool(stats_a) and ("k_total_var" in stats_a[0]) + has_vvar = bool(stats_a) and ("v_total_var" in stats_a[0]) if has_kvar: y2 = [by_a[L]["k_total_var"] for L in layers] ax2 = ax.twinx() ax2.plot(layers, y2, marker="^", linestyle="--", color="tab:red", label="Total variance of K") ax2.set_ylim(0, 20000) + lines1, labels1 = ax.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax2.legend(lines1 + lines2, labels1 + labels2, loc="best") + elif has_vvar: + y2 = [by_a[L]["v_total_var"] for L in layers] + ax2 = ax.twinx() + ax2.plot(layers, y2, marker="^", linestyle="--", color="tab:red", label="Total variance of V") + ax2.set_ylim(0, 20000) + lines1, labels1 = ax.get_legend_handles_labels() lines2, labels2 = ax2.get_legend_handles_labels() ax2.legend(lines1 + lines2, labels1 + labels2, loc="best") @@ -428,6 +438,43 @@ def _plot_progression(stats_a, outdir, key, ylabel, title, fname, suffix="", sta plt.savefig(_append_suffix(os.path.join(outdir, fname), suffix), dpi=300) plt.close() +def _plot_v_with_variance(stats, outdir, sink_idx, suffix=""): + """Plot V-norm with V total variance on right axis (similar to K plot).""" + layers = sorted({s["layer"] for s in stats}) + by_s = {s["layer"]: s for s in stats} + y_norm = [by_s[L]["v_norm"] for L in layers] + y_var = [by_s[L].get("v_total_var", 0.0) for L in layers] + + # Auto-scale y-axis based on data + max_norm = max(y_norm) if y_norm else 100 + max_var = max(y_var) if y_var else 20000 + + plt.figure(figsize=(7, 3.5)) + ax = plt.gca() + ax.plot(layers, y_norm, marker="o", color="tab:blue", label="Norm of V") + ax.set_xlabel("Layer") + ax.set_ylabel(f"||V[{sink_idx}]||") + ax.set_ylim(0, max(100, max_norm * 1.1)) + + if len(layers) > 0: + ax.set_xlim(min(layers), max(layers)) + ax.set_xticks(layers) + ax.xaxis.set_major_locator(ticker.MultipleLocator(2)) + + ax2 = ax.twinx() + ax2.plot(layers, y_var, marker="^", linestyle="--", color="tab:red", label="Total variance of V") + ax2.set_ylim(0, max(20000, max_var * 1.1)) + ax2.set_ylabel("Total variance of V") + + lines1, labels1 = ax.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax2.legend(lines1 + lines2, labels1 + labels2, loc="best") + + plt.title(f"V-norm progression (token={sink_idx})") + plt.tight_layout() + plt.savefig(_append_suffix(os.path.join(outdir, "scan_vnorm_with_var.png"), suffix), dpi=300) + plt.close() + def _plot_bias_energy_3d(bias_sets, out_path, x_label): Ys = [Y for (Y, _t) in bias_sets if Y is not None] if not Ys: @@ -612,6 +659,7 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv): for L in scan_layers } per_layer_kvecs = {L: [] for L in scan_layers} + per_layer_vvecs = {L: [] for L in scan_layers} # V vectors for variance per_layer_maps = { L: [] for L in scan_layers } @@ -647,6 +695,7 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv): per_layer_acc[L]["cos"].append(series[0][1]) per_layer_acc[L]["sink_attn"].append(float(sink_slice.mean().item())) per_layer_kvecs[L].append(k[0, args.head, sink_idx].detach().cpu()) + per_layer_vvecs[L].append(v[0, args.head, sink_idx].detach().cpu()) else: hm = _pick_head_with_caches(model, attns, L, args.head) per_layer_maps[L].append(hm[:min_len, :min_len]) @@ -663,10 +712,15 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv): for L in sorted(per_layer_acc.keys()): acc = per_layer_acc[L] kvar = 0.0 + vvar = 0.0 if per_layer_kvecs[L]: Xk = torch.stack(per_layer_kvecs[L], dim=0) _U, S_k, _Vh, _var, k_total_var = _pca_from_rows(Xk, k=1) kvar = float(k_total_var) + if per_layer_vvecs[L]: + Xv = torch.stack(per_layer_vvecs[L], dim=0) + _U, S_v, _Vh, _var, v_total_var = _pca_from_rows(Xv, k=1) + vvar = float(v_total_var) stats.append({ "k_norm": float(np.mean(acc["k_norm"])) if acc["k_norm"] else 0.0, "v_norm": float(np.mean(acc["v_norm"])) if acc["v_norm"] else 0.0, @@ -675,7 +729,8 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv): "layer": L, "head": args.head, "target_q": target_q, "sink_attn": float(np.mean(acc["sink_attn"])) if acc["sink_attn"] else 0.0, - "k_total_var": kvar + "k_total_var": kvar, + "v_total_var": vvar }) return stats, scan_maps, scan_titles else: @@ -800,6 +855,10 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv): scan_stats, args.outdir, key="v_norm", ylabel=f"||V[{args.sink_idx}]||", title=f"V-norm progression (token={args.sink_idx})", fname="scan_vnorm.png", suffix=cur_suffix ) + # Plot V-norm with V variance on right axis + _plot_v_with_variance( + scan_stats, args.outdir, sink_idx=args.sink_idx, suffix=cur_suffix + ) _plot_progression( scan_stats, args.outdir, key="res_norm", ylabel=f"||h[{args.sink_idx}]||", title=f"Residual-norm progression (token={args.sink_idx})", fname="scan_residual_norm.png", suffix=cur_suffix, @@ -908,6 +967,75 @@ def _run_scan_pass(scan_layers, rope_overrides_local, need_qkv): print(f"[PCA: {VAR}] token 8 (non-sink) total variance: {K8_total_var:.4f}") print(f"[PCA: Q] targt Q (target_q={target_q}) total variance: {Q_total_var:.4f}") + if args.find_value_subspace: + # ========== V subspace analysis ========== + # visualize 3D PCA for V + X0_c = X0 - X0.mean(dim=0, keepdim=True) + X1_c = X1 - X1.mean(dim=0, keepdim=True) + X8_c = X8 - X8.mean(dim=0, keepdim=True) + + Y0 = (X0_c @ Vh0[:3].T).numpy() + Y1 = (X1_c @ Vh1[:3].T).numpy() + Y8 = (X8_c @ Vh8[:3].T).numpy() + + R = max(np.abs(Y0).max(), np.abs(Y1).max(), np.abs(Y8).max()) + + fig = plt.figure(figsize=(9,3)) + for idx, (Y, title) in enumerate( + zip([Y0, Y1, Y8], ["V0 subspace", 'V1 subspace', 'V8 subspace']) + ): + ax = fig.add_subplot(1, 3, idx + 1, projection='3d') + ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], s=5) + ax.set_title(title) + ax.set_xlim(-R, R); ax.set_ylim(-R, R); ax.set_zlim(-R, R) + ax.set_xlabel("PC1"); ax.set_ylabel("PC2"); ax.set_zlabel("PC3") + + fig.tight_layout() + out_path = os.path.join(args.outdir, f"pca_v_subspaces_L{target_layer}_H{args.head}.png") + fig.savefig(out_path, dpi=300) + plt.close(fig) + print(f"[V subspace] Saved 3D PCA plot to {out_path}") + + # V bias direction analysis: mean(V0) as the "special direction" + v0_mean = X0.mean(dim=0) + v1_mean = X1.mean(dim=0) + v8_mean = X8.mean(dim=0) + + v0_mean_norm = float(v0_mean.norm().item()) + v1_mean_norm = float(v1_mean.norm().item()) + v8_mean_norm = float(v8_mean.norm().item()) + print(f"[V MEAN] ||mean(V0)|| = {v0_mean_norm:.4f} | ||mean(V1)|| = {v1_mean_norm:.4f} | ||mean(V8)|| = {v8_mean_norm:.4f}") + + # Compute how much energy of each V is along the mean(V0) direction + if v0_mean_norm > EPS: + v0_dir = F.normalize(v0_mean, dim=0) + proj_V0_on_mean = (X0 @ v0_dir) ** 2 + proj_V1_on_mean = (X1 @ v0_dir) ** 2 + proj_V8_on_mean = (X8 @ v0_dir) ** 2 + + energy_V0 = (X0 ** 2).sum(dim=1).clamp_min(EPS) + energy_V1 = (X1 ** 2).sum(dim=1).clamp_min(EPS) + energy_V8 = (X8 ** 2).sum(dim=1).clamp_min(EPS) + + frac_V0 = (proj_V0_on_mean / energy_V0).mean().item() + frac_V1 = (proj_V1_on_mean / energy_V1).mean().item() + frac_V8 = (proj_V8_on_mean / energy_V8).mean().item() + + print(f"[V PROJECTION] frac of V energy along mean(V0): V0={frac_V0:.4f} | V1={frac_V1:.4f} | V8={frac_V8:.4f}") + + # visualize V in bias direction + residual PC (similar to K's Q energy plot) + bias_sets_v = [ + (_bias_3d_projection(X0, v0_mean, topk=2), "V0 energy along mean(V0)"), + (_bias_3d_projection(X1, v0_mean, topk=2), "V1 energy along mean(V0)"), + (_bias_3d_projection(X8, v0_mean, topk=2), "V8 energy along mean(V0)"), + ] + _plot_bias_energy_3d( + bias_sets_v, + os.path.join(args.outdir, f"pca_v_bias_energy_L{target_layer}_H{args.head}.png"), + x_label="V bias (mean V0)", + ) + print(f"[V subspace] Saved bias energy plot") + if args.find_key_subspace: # raw cosine scores def _mean_or_zero(xs): return float(np.mean(xs)) if xs else 0.0 diff --git a/results/v_experiment/pca_v_bias_energy_L24_H0.png b/results/v_experiment/pca_v_bias_energy_L24_H0.png new file mode 100644 index 0000000..87a4637 Binary files /dev/null and b/results/v_experiment/pca_v_bias_energy_L24_H0.png differ diff --git a/results/v_experiment/pca_v_bias_energy_L7_H0.png b/results/v_experiment/pca_v_bias_energy_L7_H0.png new file mode 100644 index 0000000..8f2b94a Binary files /dev/null and b/results/v_experiment/pca_v_bias_energy_L7_H0.png differ diff --git a/results/v_experiment/pca_v_subspaces_L24_H0.png b/results/v_experiment/pca_v_subspaces_L24_H0.png new file mode 100644 index 0000000..ebe1250 Binary files /dev/null and b/results/v_experiment/pca_v_subspaces_L24_H0.png differ diff --git a/results/v_experiment/pca_v_subspaces_L7_H0.png b/results/v_experiment/pca_v_subspaces_L7_H0.png new file mode 100644 index 0000000..88aaef3 Binary files /dev/null and b/results/v_experiment/pca_v_subspaces_L7_H0.png differ diff --git a/results/v_experiment/scan_cosine.png b/results/v_experiment/scan_cosine.png new file mode 100644 index 0000000..0d7bbd8 Binary files /dev/null and b/results/v_experiment/scan_cosine.png differ diff --git a/results/v_experiment/scan_knorm.png b/results/v_experiment/scan_knorm.png new file mode 100644 index 0000000..39ac6d9 Binary files /dev/null and b/results/v_experiment/scan_knorm.png differ diff --git a/results/v_experiment/scan_residual_norm.png b/results/v_experiment/scan_residual_norm.png new file mode 100644 index 0000000..5ce58e8 Binary files /dev/null and b/results/v_experiment/scan_residual_norm.png differ diff --git a/results/v_experiment/scan_vnorm.png b/results/v_experiment/scan_vnorm.png new file mode 100644 index 0000000..b7d3343 Binary files /dev/null and b/results/v_experiment/scan_vnorm.png differ diff --git a/results/v_experiment/scan_vnorm_with_var.png b/results/v_experiment/scan_vnorm_with_var.png new file mode 100644 index 0000000..a94e138 Binary files /dev/null and b/results/v_experiment/scan_vnorm_with_var.png differ diff --git a/results/v_experiment_token1/scan_cosine__sink[1].png b/results/v_experiment_token1/scan_cosine__sink[1].png new file mode 100644 index 0000000..597795c Binary files /dev/null and b/results/v_experiment_token1/scan_cosine__sink[1].png differ diff --git a/results/v_experiment_token1/scan_knorm__sink[1].png b/results/v_experiment_token1/scan_knorm__sink[1].png new file mode 100644 index 0000000..9aee233 Binary files /dev/null and b/results/v_experiment_token1/scan_knorm__sink[1].png differ diff --git a/results/v_experiment_token1/scan_residual_norm__sink[1].png b/results/v_experiment_token1/scan_residual_norm__sink[1].png new file mode 100644 index 0000000..0fc7b9d Binary files /dev/null and b/results/v_experiment_token1/scan_residual_norm__sink[1].png differ diff --git a/results/v_experiment_token1/scan_vnorm__sink[1].png b/results/v_experiment_token1/scan_vnorm__sink[1].png new file mode 100644 index 0000000..cc63176 Binary files /dev/null and b/results/v_experiment_token1/scan_vnorm__sink[1].png differ diff --git a/results/v_experiment_token1/scan_vnorm_with_var__sink[1].png b/results/v_experiment_token1/scan_vnorm_with_var__sink[1].png new file mode 100644 index 0000000..937e669 Binary files /dev/null and b/results/v_experiment_token1/scan_vnorm_with_var__sink[1].png differ