Hi, I'm interested in the score calculation for tokens. Have you tried using `v_norm` directly instead of `1 - v_norm`, or other combinations with and without the CFG (unconditional) part?
def kv_norm_score(cache_dic, current, cond_weight=0.5):
    """Score tokens by the inverted, cond/uncond-blended norm of cached values.

    Reads the ``'v_norm'`` tensor cached for the current layer in the most
    recent cache entry (``cache_dic['cache'][-1][current['layer']]``). Its
    leading (batch) dimension is split into two equal halves. The first half
    is treated as the conditional branch and the second as the unconditional
    branch (assumed classifier-free-guidance layout; confirm against the code
    that fills the cache).

    The halves are blended as ``cond_weight * cond + (1 - cond_weight) * uncond``
    and inverted (``1 - v_norm``), so smaller value norms give larger scores.
    The result is summed over the last axis, L2-normalised along dim 1, and
    tiled back to the full batch.

    Args:
        cache_dic: Cache dictionary. ``cache_dic['cache'][-1][layer]['v_norm']``
            must be a tensor, annotated in the original code as
            ``(B, N, num_heads)``, with an even, non-zero ``B``.
        current: Dict holding at least ``'layer'``, the key of the layer
            whose cached norms are scored.
        cond_weight: Weight of the conditional half in the blend. The
            unconditional half gets ``1 - cond_weight``. Defaults to 0.5
            (plain average), the previously hard-coded value. ``1.0`` uses
            only the conditional half.

    Returns:
        For a ``(B, N, num_heads)`` input, a ``(B, N)`` tensor. Each row is
        the blended half's per-token score, unit-L2-normalised over tokens;
        the rows are repeated so both halves get the same scores.

    Raises:
        ValueError: If the batch dimension of ``v_norm`` is zero or odd, so it
            cannot be split into cond/uncond halves.
    """
    v_norm_all = cache_dic['cache'][-1][current['layer']]['v_norm']
    batch = v_norm_all.shape[0]
    if batch == 0 or batch % 2:
        raise ValueError(
            "expected an even, non-zero batch of stacked cond/uncond "
            f"v_norm values, got batch size {batch}"
        )
    cond_v_norm, uncond_v_norm = torch.split(v_norm_all, batch // 2, dim=0)
    v_norm = cond_weight * cond_v_norm + (1 - cond_weight) * uncond_v_norm
    # Invert so that tokens with small value norms rank highest.
    score = 1 - v_norm
    # Sum over the last axis, normalise along dim 1 (F.normalize's default,
    # written explicitly), then tile so the cond and uncond halves share scores.
    return F.normalize(score.sum(dim=-1), p=2, dim=1).repeat(2, 1)
Hi, I'm interested in the score calculation for tokens. Have you tried using `v_norm` directly instead of `1 - v_norm`, or other combinations with and without the CFG (unconditional) part?