Skip to content

Commit 1300f4a

Browse files
committed
Fix it so that hooks reinitialise their shape when the input shape changes (I hope o.o)
1 parent 7f81af9 commit 1300f4a

File tree

2 files changed

+59
-34
lines changed

2 files changed

+59
-34
lines changed

src/taker/hooks.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -338,14 +338,33 @@ class NeuronMask(torch.nn.Module):
338338

339339
def __init__(self, shape, act_fn: str = "step"):
340340
super(NeuronMask, self).__init__()
341-
self.shape = shape
342341
self.act_fn = act_fn
342+
self.shape: torch.Size = None
343+
self.param: torch.nn.Parameter = None
344+
self.offset: torch.nn.Parameter = None
345+
self.reinit_hook(shape=shape)
346+
347+
def check_shapes_match(self, x):
348+
curr_shape = torch.Size(self.shape)
349+
input_shape = torch.Size(x.shape[-len(curr_shape):])
350+
return curr_shape == input_shape, f"{curr_shape} vs {input_shape} (from {x.shape})"
351+
352+
def reinit_hook(self, x=None, shape=None):
353+
# batch, token, (d_model or otherwise)
354+
if x is not None:
355+
new_shape, new_dtype = x.shape[2:], x.dtype
356+
elif shape is not None:
357+
new_shape, new_dtype = shape, torch.float32
358+
else:
359+
raise ValueError("Either x or shape must be provided to init NeuronMask")
360+
361+
self.shape = new_shape
362+
vec = torch.ones(new_shape, dtype=new_dtype)
343363
# initialize mask as nn.Parameter of ones
344-
_vec = torch.ones(shape, dtype=torch.float32)
345364
if self.act_fn == "sigmoid":
346-
_vec[...] = torch.inf
347-
self.param = torch.nn.Parameter(_vec)
348-
self.offset = torch.nn.Parameter(torch.zeros_like(_vec))
365+
vec[...] = torch.inf
366+
self.param = torch.nn.Parameter(vec)
367+
self.offset = torch.nn.Parameter(torch.zeros_like(vec))
349368

350369
def get_mask(self):
351370
# if step, we want heaviside step function. ie: mask = mask > 0
@@ -393,6 +412,10 @@ def inverse_mask(self, x, offset=False):
393412
return x * inv_mask
394413

395414
def forward(self, x):
415+
is_match, msg = self.check_shapes_match(x)
416+
if not is_match:
417+
print(f"Shape mismatch: {msg}, reinitialising mask hook")
418+
self.reinit_hook(x)
396419
self.to(x.device)
397420
mask = self.get_mask()
398421
offset = self.get_offset(x)

tests/test_delete_attn_pre_out_layer.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torch import Tensor
66
import torch
77
import numpy as np
8-
8+
import einops
99
# pylint: disable=import-error
1010
import pytest
1111
from taker.model_repos import test_model_repos
@@ -22,32 +22,34 @@ def test_delete_attn_pre_out_layer(self, model_repo, mask_fn):
2222
opt = Model(model_repo, limit=1000, dtype="fp32", mask_fn=mask_fn)
2323

2424
with torch.no_grad():
25-
n_heads, d_head, d_model = \
26-
opt.cfg.n_heads, opt.cfg.d_head, opt.cfg.d_model
25+
n_batch, n_tok, n_heads, d_head, d_model = \
26+
1, 1, opt.cfg.n_heads, opt.cfg.d_head, opt.cfg.d_model
2727

2828
# Define vectors for testing
2929
#vec_in: Tensor = torch.tensor(
3030
# np.random.random(d_model), dtype=torch.float32
3131
#).to( device )
3232
vec_mid: Tensor = torch.tensor(
33-
np.random.random((n_heads, d_head)), dtype=torch.float32
33+
np.random.random((n_batch, n_tok, n_heads, d_head)), dtype=torch.float32
3434
).to( device )
3535

36+
convert = lambda x: einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")
37+
3638
# Define a vector that is changed at certain indices
3739
vec_mid_d0 : Tensor = copy.deepcopy( vec_mid )
3840
vec_mid_d1 : Tensor = copy.deepcopy( vec_mid )
3941
removed_indices = [(0, 0), (0, 10), (1, 10), (5, 31)]
4042
unremoved_indices = [(0, 1), (1, 0), (5, 30)]
4143

42-
removal_tensor = torch.zeros_like(vec_mid_d0, dtype=torch.bool)
43-
keep_tensor = torch.ones_like(vec_mid_d1, dtype=torch.bool)
44+
removal_tensor = torch.zeros((n_heads, d_head), dtype=torch.bool)
45+
keep_tensor = torch.ones((n_heads, d_head), dtype=torch.bool)
4446
for (i_head, i_pos) in removed_indices:
45-
vec_mid_d0[i_head][i_pos] = 100
46-
removal_tensor[i_head][i_pos] = True
47-
keep_tensor[i_head][i_pos] = False
47+
vec_mid_d0[..., i_head, i_pos] = 100
48+
removal_tensor[i_head, i_pos] = True
49+
keep_tensor[i_head, i_pos] = False
4850

4951
for i_head, i_pos in unremoved_indices:
50-
vec_mid_d1[i_head][i_pos] = 100
52+
vec_mid_d1[..., i_head, i_pos] = 100
5153

5254
# Start tests
5355
for add_mean in [False]: # TODO: add True again
@@ -61,10 +63,10 @@ def test_delete_attn_pre_out_layer(self, model_repo, mask_fn):
6163
out_proj_orig_weight = out_proj.weight.detach().clone()
6264

6365
# Test that the old outputs do care about changes to all indices
64-
old_vec_out = out_proj(vec_mid.flatten()[None, :])
65-
old_vec_out_d0 = out_proj(vec_mid_d0.flatten()[None, :])
66-
print( '- vec :', old_vec_out[:5] )
67-
print( '- vec+ (1) :', old_vec_out_d0[:5] )
66+
old_vec_out = out_proj(convert(vec_mid))
67+
old_vec_out_d0 = out_proj(convert(vec_mid_d0))
68+
print( '- vec :', old_vec_out[..., :5] )
69+
print( '- vec+ (1) :', old_vec_out_d0[..., :5] )
6870
assert not torch.equal( old_vec_out, old_vec_out_d0 )
6971

7072
# Run the deletion
@@ -80,12 +82,12 @@ def test_delete_attn_pre_out_layer(self, model_repo, mask_fn):
8082

8183
# Test that new outputs do not care about changes to deleted indices
8284
# but still care about changes to undeleted indices.
83-
new_vec_out = out_proj(vec_mid.flatten()[None, :])
84-
new_vec_out_d0 = out_proj(vec_mid_d0.flatten()[None, :])
85-
new_vec_out_d1 = out_proj(vec_mid_d1.flatten()[None, :])
86-
print( '- vec :', new_vec_out[:5] )
87-
print( '- vec+ (1) :', new_vec_out_d0[:5] )
88-
print( '- vec+ (2) :', new_vec_out_d1[:5] )
85+
new_vec_out = out_proj(convert(vec_mid))
86+
new_vec_out_d0 = out_proj(convert(vec_mid_d0))
87+
new_vec_out_d1 = out_proj(convert(vec_mid_d1))
88+
print( '- vec :', new_vec_out[..., :5] )
89+
print( '- vec+ (1) :', new_vec_out_d0[..., :5] )
90+
print( '- vec+ (2) :', new_vec_out_d1[..., :5] )
8991
assert torch.equal( new_vec_out, new_vec_out_d0 )
9092
assert not torch.equal( new_vec_out_d0, new_vec_out_d1 )
9193

@@ -110,14 +112,14 @@ def test_delete_attn_value_layer(self, model_repo, mask_fn):
110112
v_proj = opt.layers[LAYER]["attn.v_proj"]
111113
v_proj_orig_weight = v_proj.weight.detach().clone()
112114

113-
n_heads, d_head, d_model = \
114-
opt.cfg.n_heads, opt.cfg.d_head, opt.cfg.d_model
115+
n_batch, n_tok, n_heads, d_head, d_model = \
116+
1, 1, opt.cfg.n_heads, opt.cfg.d_head, opt.cfg.d_model
115117

116118
# Start test
117119
with torch.no_grad():
118120
# Define vec in
119121
vec_in: Tensor = torch.tensor(
120-
np.random.random(d_model), dtype=torch.float32
122+
np.random.random((n_batch, n_tok, d_model)), dtype=torch.float32
121123
).to( device )
122124

123125
# Choose indices (head, pos) to delete
@@ -127,22 +129,22 @@ def test_delete_attn_value_layer(self, model_repo, mask_fn):
127129
keep_tensor = \
128130
torch.ones((n_heads, d_head), dtype=torch.bool, device=device)
129131
for (i_head, i_pos) in removed_indices:
130-
removal_tensor[i_head][i_pos] = True
131-
keep_tensor[i_head][i_pos] = False
132+
removal_tensor[i_head, i_pos] = True
133+
keep_tensor[i_head, i_pos] = False
132134

133135

134136
# Get output vector before deletion
135-
old_vec_mid = v_proj(vec_in).reshape((n_heads, d_head))
136-
print( '- old vec :', old_vec_mid[:5] )
137+
old_vec_mid = v_proj(vec_in).reshape((n_batch, n_tok, n_heads, d_head))
138+
print( '- old vec :', old_vec_mid[..., :5] )
137139

138140
# Run the deletion
139141
print('deleting indices:', removed_indices)
140142
opt.hooks.delete_attn_neurons(removal_tensor, LAYER)
141143
v_proj = opt.layers[LAYER]["attn.v_proj"]
142144

143145
# Get output vector after deletion
144-
new_vec_mid = v_proj(vec_in).reshape((n_heads, d_head))
145-
print( '- new vec :', new_vec_mid[:5] )
146+
new_vec_mid = v_proj(vec_in).reshape((n_batch, n_tok, n_heads, d_head))
147+
print( '- new vec :', new_vec_mid[..., :5] )
146148

147149
# Test that new outputs do not care about changes to deleted indices
148150
# Check weight changes

0 commit comments

Comments (0)