55from torch import Tensor
66import torch
77import numpy as np
8-
8+ import einops
99# pylint: disable=import-error
1010import pytest
1111from taker .model_repos import test_model_repos
@@ -22,32 +22,34 @@ def test_delete_attn_pre_out_layer(self, model_repo, mask_fn):
2222 opt = Model (model_repo , limit = 1000 , dtype = "fp32" , mask_fn = mask_fn )
2323
2424 with torch .no_grad ():
25- n_heads , d_head , d_model = \
26- opt .cfg .n_heads , opt .cfg .d_head , opt .cfg .d_model
25+ n_batch , n_tok , n_heads , d_head , d_model = \
26+ 1 , 1 , opt .cfg .n_heads , opt .cfg .d_head , opt .cfg .d_model
2727
2828 # Define vectors for testing
2929 #vec_in: Tensor = torch.tensor(
3030 # np.random.random(d_model), dtype=torch.float32
3131 #).to( device )
3232 vec_mid : Tensor = torch .tensor (
33- np .random .random ((n_heads , d_head )), dtype = torch .float32
33+ np .random .random ((n_batch , n_tok , n_heads , d_head )), dtype = torch .float32
3434 ).to ( device )
3535
36+ convert = lambda x : einops .rearrange (x , "... n_heads d_head -> ... (n_heads d_head)" )
37+
3638 # Define a vector that is changed at certain indices
3739 vec_mid_d0 : Tensor = copy .deepcopy ( vec_mid )
3840 vec_mid_d1 : Tensor = copy .deepcopy ( vec_mid )
3941 removed_indices = [(0 , 0 ), (0 , 10 ), (1 , 10 ), (5 , 31 )]
4042 unremoved_indices = [(0 , 1 ), (1 , 0 ), (5 , 30 )]
4143
42- removal_tensor = torch .zeros_like ( vec_mid_d0 , dtype = torch .bool )
43- keep_tensor = torch .ones_like ( vec_mid_d1 , dtype = torch .bool )
44+ removal_tensor = torch .zeros (( n_heads , d_head ) , dtype = torch .bool )
45+ keep_tensor = torch .ones (( n_heads , d_head ) , dtype = torch .bool )
4446 for (i_head , i_pos ) in removed_indices :
45- vec_mid_d0 [i_head ][ i_pos ] = 100
46- removal_tensor [i_head ][ i_pos ] = True
47- keep_tensor [i_head ][ i_pos ] = False
47+ vec_mid_d0 [..., i_head , i_pos ] = 100
48+ removal_tensor [i_head , i_pos ] = True
49+ keep_tensor [i_head , i_pos ] = False
4850
4951 for i_head , i_pos in unremoved_indices :
50- vec_mid_d1 [i_head ][ i_pos ] = 100
52+ vec_mid_d1 [..., i_head , i_pos ] = 100
5153
5254 # Start tests
5355 for add_mean in [False ]: # TODO: add True again
@@ -61,10 +63,10 @@ def test_delete_attn_pre_out_layer(self, model_repo, mask_fn):
6163 out_proj_orig_weight = out_proj .weight .detach ().clone ()
6264
6365 # Test that the old outputs do care about changes to all indices
64- old_vec_out = out_proj (vec_mid . flatten ()[ None , :] )
65- old_vec_out_d0 = out_proj (vec_mid_d0 . flatten ()[ None , :] )
66- print ( '- vec :' , old_vec_out [:5 ] )
67- print ( '- vec+ (1) :' , old_vec_out_d0 [:5 ] )
66+ old_vec_out = out_proj (convert ( vec_mid ) )
67+ old_vec_out_d0 = out_proj (convert ( vec_mid_d0 ) )
68+ print ( '- vec :' , old_vec_out [..., :5 ] )
69+ print ( '- vec+ (1) :' , old_vec_out_d0 [..., :5 ] )
6870 assert not torch .equal ( old_vec_out , old_vec_out_d0 )
6971
7072 # Run the deletion
@@ -80,12 +82,12 @@ def test_delete_attn_pre_out_layer(self, model_repo, mask_fn):
8082
8183 # Test that new outputs do not care about changes to deleted indices
8284 # but still care about changes to undeleted indices.
83- new_vec_out = out_proj (vec_mid . flatten ()[ None , :] )
84- new_vec_out_d0 = out_proj (vec_mid_d0 . flatten ()[ None , :] )
85- new_vec_out_d1 = out_proj (vec_mid_d1 . flatten ()[ None , :] )
86- print ( '- vec :' , new_vec_out [:5 ] )
87- print ( '- vec+ (1) :' , new_vec_out_d0 [:5 ] )
88- print ( '- vec+ (2) :' , new_vec_out_d1 [:5 ] )
85+ new_vec_out = out_proj (convert ( vec_mid ) )
86+ new_vec_out_d0 = out_proj (convert ( vec_mid_d0 ) )
87+ new_vec_out_d1 = out_proj (convert ( vec_mid_d1 ) )
88+ print ( '- vec :' , new_vec_out [..., :5 ] )
89+ print ( '- vec+ (1) :' , new_vec_out_d0 [..., :5 ] )
90+ print ( '- vec+ (2) :' , new_vec_out_d1 [..., :5 ] )
8991 assert torch .equal ( new_vec_out , new_vec_out_d0 )
9092 assert not torch .equal ( new_vec_out_d0 , new_vec_out_d1 )
9193
@@ -110,14 +112,14 @@ def test_delete_attn_value_layer(self, model_repo, mask_fn):
110112 v_proj = opt .layers [LAYER ]["attn.v_proj" ]
111113 v_proj_orig_weight = v_proj .weight .detach ().clone ()
112114
113- n_heads , d_head , d_model = \
114- opt .cfg .n_heads , opt .cfg .d_head , opt .cfg .d_model
115+ n_batch , n_tok , n_heads , d_head , d_model = \
116+ 1 , 1 , opt .cfg .n_heads , opt .cfg .d_head , opt .cfg .d_model
115117
116118 # Start test
117119 with torch .no_grad ():
118120 # Define vec in
119121 vec_in : Tensor = torch .tensor (
120- np .random .random (d_model ), dtype = torch .float32
122+ np .random .random (( n_batch , n_tok , d_model ) ), dtype = torch .float32
121123 ).to ( device )
122124
123125 # Choose indices (head, pos) to delete
@@ -127,22 +129,22 @@ def test_delete_attn_value_layer(self, model_repo, mask_fn):
127129 keep_tensor = \
128130 torch .ones ((n_heads , d_head ), dtype = torch .bool , device = device )
129131 for (i_head , i_pos ) in removed_indices :
130- removal_tensor [i_head ][ i_pos ] = True
131- keep_tensor [i_head ][ i_pos ] = False
132+ removal_tensor [i_head , i_pos ] = True
133+ keep_tensor [i_head , i_pos ] = False
132134
133135
134136 # Get output vector before deletion
135- old_vec_mid = v_proj (vec_in ).reshape ((n_heads , d_head ))
136- print ( '- old vec :' , old_vec_mid [:5 ] )
137+ old_vec_mid = v_proj (vec_in ).reshape ((n_batch , n_tok , n_heads , d_head ))
138+ print ( '- old vec :' , old_vec_mid [..., :5 ] )
137139
138140 # Run the deletion
139141 print ('deleting indices:' , removed_indices )
140142 opt .hooks .delete_attn_neurons (removal_tensor , LAYER )
141143 v_proj = opt .layers [LAYER ]["attn.v_proj" ]
142144
143145 # Get output vector after deletion
144- new_vec_mid = v_proj (vec_in ).reshape ((n_heads , d_head ))
145- print ( '- new vec :' , new_vec_mid [:5 ] )
146+ new_vec_mid = v_proj (vec_in ).reshape ((n_batch , n_tok , n_heads , d_head ))
147+ print ( '- new vec :' , new_vec_mid [..., :5 ] )
146148
147149 # Test that new outputs do not care about changes to deleted indices
148150 # Check weight changes
0 commit comments