-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenhanced_hooking.py
More file actions
351 lines (294 loc) · 16.2 KB
/
enhanced_hooking.py
File metadata and controls
351 lines (294 loc) · 16.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
### Flexible hooking at arbitrary layers and tokens
### Written by Chris Ackerman, 2024.
## Modified by Dani Roytburg, 2025 -- changed add_activations_and_generate to include scoring
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
# Read activations during model forward pass (training)
def attach_activation_hooks(model, layers_positions, activation_storage, get_at='end'):
    """
    Register capture hooks on the transformer blocks named in ``layers_positions``.

    Args:
        model: model whose block stack is found with ``get_blocks``.
        layers_positions: ``{layer_idx: positions}``; ``positions[b]`` is the list
            of sequence positions to read for batch row ``b``.
        activation_storage: nested mapping ``storage[layer_idx][k] -> list``; it is
            cleared first and must create missing keys itself (e.g. a nested
            ``defaultdict``). On every hooked forward pass, for each batch row,
            the activation at the row's k-th listed position is appended to
            ``storage[layer_idx][k]``.
        get_at: ``'end'`` reads block outputs through a forward hook (the first
            element when the block returns a tuple); any other value reads
            ``input[0]`` through a forward pre-hook.

    The stored tensors are indexed slices of the live activations; they are
    neither detached nor copied, so autograd (if enabled) still tracks them.
    Hook handles are not returned; callers detach the hooks by clearing the
    blocks' hook dicts.
    """
    def capture_activations_hook(layer_idx, positions, get_at='end'):
        def collect(read):
            # ``read(row, seq_pos)`` fetches one activation vector lazily.
            for row in range(len(positions)):
                for slot, seq_pos in enumerate(positions[row]):
                    activation_storage[layer_idx][slot].append(read(row, seq_pos))

        def hook(module, input, output):
            hidden = output[0] if isinstance(output, tuple) else output
            collect(lambda row, seq_pos: hidden[row, seq_pos, :])

        def pre_hook(module, input):
            collect(lambda row, seq_pos: input[0][row, seq_pos, :])

        return hook if get_at == 'end' else pre_hook

    activation_storage.clear()  # drop anything captured by an earlier run
    for idx, block in enumerate(get_blocks(model)):
        if idx not in layers_positions:
            continue
        capture = capture_activations_hook(idx, layers_positions[idx], get_at)
        register = block.register_forward_hook if get_at == 'end' else block.register_forward_pre_hook
        register(capture)
def get_activations(model, tokens, layers_positions, get_at='end'):
    """
    Run one no-grad forward pass and return the activations captured at the
    requested layers and positions.

    Args:
        model: model whose transformer blocks are located with ``get_blocks``.
        tokens: tensor passed positionally to ``model`` (presumably input ids —
            confirm against callers); moved to the device of the model's first
            parameter.
        layers_positions: ``{layer_idx: positions}``; ``positions[b]`` lists the
            sequence positions to read for batch row ``b``.
        get_at: ``'end'`` reads block outputs (forward hooks); any other value
            reads block inputs (forward pre-hooks).

    Returns:
        ``{layer_idx: {k: Tensor}}``. The tensor for slot ``k`` stacks, detached
        and on CPU, the activation at the k-th listed position of each batch row.

    Side effects: puts ``model`` in eval mode and leaves it there. Afterwards it
    clears the blocks' forward (or forward-pre) hook dict, which also removes
    hooks that other code attached to those blocks.
    """
    activation_storage = defaultdict(lambda: defaultdict(list))
    try:
        attach_activation_hooks(model, layers_positions, activation_storage, get_at)
        model.eval()
        with torch.no_grad():
            model(tokens.to(next(model.parameters()).device))
    finally:
        # Detach the capture hooks even if the forward pass raised; otherwise they
        # would stay registered and keep appending on every later forward pass.
        for block in get_blocks(model):
            if get_at == 'end':
                block._forward_hooks.clear()
            else:
                block._forward_pre_hooks.clear()
    return {
        layer: {
            pos: torch.stack([tensor.detach().cpu() for tensor in tensors])
            for pos, tensors in pos_dict.items()
        }
        for layer, pos_dict in activation_storage.items()
    }
########
# Add to activations during generation
def create_add_activations_hook(layers_activations):
    """
    Build a forward hook that adds vectors at fixed sequence positions of a
    block's output.

    Args:
        layers_activations: ``{layer_idx: {position: activation_vector}}``.
            - The hook reads the module's ``layer_idx`` attribute, which the caller
              sets.
            - For a matching layer, each vector is added in place to
              ``output[:, position, :]`` for every batch row.
            - Negative positions count from the end, as in Python indexing.
            - Vectors are moved to the activation's device before the add.

    Returns:
        A ``register_forward_hook``-compatible callable.
        - Passes whose sequence length is 1 (single-token cached generation steps)
          are returned untouched, so vectors are applied during the prompt pass
          only. A consequence is that a one-token prompt is never steered.
        - Positions outside the current sequence are reported with ``print``
          instead of raising.
    """
    def hook(module, inputs, outputs):
        layer_idx = getattr(module, 'layer_idx', None)
        if layer_idx not in layers_activations:
            return None  # leave the output untouched
        main_tensor = outputs[0] if isinstance(outputs, tuple) else outputs
        seq_len = main_tensor.shape[1]
        if seq_len == 1:
            return outputs  # hack to turn this off during generation
        for position, activation_vector in layers_activations[layer_idx].items():
            # Accept exactly the positions that index the sequence dimension,
            # including negative ones; anything else would raise IndexError.
            if -seq_len <= position < seq_len:
                main_tensor[:, position, :] += torch.as_tensor(activation_vector, device=main_tensor.device)
            else:
                print(f"Position {position} is out of bounds for the current sequence ({main_tensor.shape[1]}).")
        return (main_tensor,) + outputs[1:] if isinstance(outputs, tuple) else main_tensor
    return hook
def create_add_activations_pre_hook(layers_activations, verbose=False):
    """
    Build a forward pre-hook that adds vectors at fixed sequence positions of a
    block's first input.

    Args:
        layers_activations: ``{layer_idx: {position: activation_vector}}``.
            - The hook reads the module's ``layer_idx`` attribute, which the caller
              sets.
            - For a matching layer, each vector is added in place to
              ``inputs[0][:, position, :]`` for every batch row.
            - Negative positions count from the end, as in Python indexing.
            - Vectors are moved to the activation's device before the add.
        verbose: if True, print the layer, position, sizes and norms before each
            addition (debug output, off by default).

    Returns:
        A ``register_forward_pre_hook``-compatible callable.
        - Passes whose sequence length is 1 (single-token cached generation steps)
          are left untouched, so vectors are applied during the prompt pass only.
          A consequence is that a one-token prompt is never steered.
        - Positions outside the current sequence are reported with ``print``
          instead of raising.
    """
    def hook(module, inputs):
        hidden = inputs[0]
        seq_len = hidden.shape[1]
        if seq_len == 1:
            return inputs  # hack to turn this off during generation (same as returning None)
        layer_idx = getattr(module, 'layer_idx', None)
        if layer_idx not in layers_activations:
            return inputs
        for position, activation_vector in layers_activations[layer_idx].items():
            if not -seq_len <= position < seq_len:
                print(f"Position {position} is out of bounds for the current sequence ({inputs[0].shape[1]}).")
                continue
            if verbose:
                print(f"Adding activations at layer {layer_idx} at position {position}")
                print(f"inputs[0].size={inputs[0].size()}, activation_vector.size={activation_vector.size()}")
                print(f"inputs[0] norm: {torch.norm(inputs[0][:, position, :]):.4f}")
                print(f"Add vector norm: {torch.norm(activation_vector):.4f}")
            # In-place edit of the tensor held by ``inputs``: returning the same
            # tuple hands the steered activations to the block.
            hidden[:, position, :] += torch.as_tensor(activation_vector, device=hidden.device)
        return inputs
    return hook
def create_continuous_activation_hook(continuouspos_layer_activations):
    """
    Build a forward hook that adds a fixed vector to every position of a block's output.

    ``continuouspos_layer_activations`` maps ``layer_idx`` -> vector. The hook looks
    up the module's ``layer_idx`` attribute. On a match, it adds the vector in place
    to the output, or to the output's first element when the block returns a tuple.
    This happens on every forward pass:
    - With a KV cache, only the new token is touched after the prompt pass.
    - Without a cache, the whole sequence is re-steered each time.
    Non-matching layers are left untouched (the hook returns ``None``).
    """
    def hook(module, inputs, outputs):
        layer = getattr(module, 'layer_idx', None)
        if layer not in continuouspos_layer_activations:
            return None
        steer = continuouspos_layer_activations[layer]
        if isinstance(outputs, tuple):
            hidden = outputs[0]
            hidden += steer
            return (hidden,) + outputs[1:]
        outputs += steer
        return outputs
    return hook
def create_continuous_activation_pre_hook(continuous_layers_activations):
    """
    Build a forward pre-hook that adds a fixed vector to a block's first input.

    ``continuous_layers_activations`` maps ``layer_idx`` -> vector. The hook looks
    up the module's ``layer_idx`` attribute. On a match, it adds the vector in place
    to the whole of ``inputs[0]``. Because this runs on every forward pass, newly
    generated tokens are steered too. The remaining inputs are passed through
    unchanged.
    """
    def pre_hook(module, inputs):
        layer = getattr(module, 'layer_idx', None)
        if layer in continuous_layers_activations:
            first, *rest = inputs
            first += continuous_layers_activations[layer]  # in-place on the input tensor
            inputs = (first, *rest)
        return inputs
    return pre_hook
def add_activations_and_generate(model, tokens, specificpos_layer_activations, continuouspos_layer_activations, sampling_kwargs, add_at='end', score_on_token = None):
    """
    Generate with steering vectors added inside selected transformer blocks and
    return the new tokens with per-step probabilities.

    Args:
        model: model exposing ``generate``.
            - Its block stack is found with ``get_blocks``.
            - Each block gets a ``layer_idx`` attribute, which the steering hooks read.
        tokens: dict of tensors, passed as keyword arguments to ``model.generate``.
            They are moved to the device of the model's first parameter.
        specificpos_layer_activations: ``{layer_idx: {position: vector}}``.
            Applied by ``create_add_activations_hook`` / ``create_add_activations_pre_hook``.
        continuouspos_layer_activations: ``{layer_idx: vector}``.
            Applied by ``create_continuous_activation_hook`` /
            ``create_continuous_activation_pre_hook``.
        sampling_kwargs: extra ``generate`` keyword arguments.
            The result's ``.scores`` and ``.sequences`` are read, so these must
            request them (presumably ``return_dict_in_generate=True,
            output_scores=True`` — confirm).
        add_at: ``'end'`` adds to block outputs (forward hooks). Any other value
            adds to block inputs (forward pre-hooks).
        score_on_token: optional vocabulary id whose probability is reported.

    Returns:
        Two CPU tensors of shape ``(batch_size, new_tokens)``:
        * ``score_on_token is None``: the generated token ids, and the probability
          (softmax of that step's score vector) of each emitted token.
        * otherwise: a tensor filled with ``score_on_token``, and that token's
          probability at each step.

    The blocks' forward (or forward-pre) hook dicts are cleared afterwards, even if
    generation raises. This also removes hooks other code attached to those blocks.
    """
    transformer_blocks = get_blocks(model)
    use_forward_hooks = add_at == 'end'
    try:
        for idx, block in enumerate(transformer_blocks):
            setattr(block, 'layer_idx', idx)
            # Per block, the specific-position hook is registered before the
            # continuous one (same per-block order as registering them in two passes).
            if idx in specificpos_layer_activations:
                if use_forward_hooks:
                    block.register_forward_hook(create_add_activations_hook(specificpos_layer_activations))
                else:
                    block.register_forward_pre_hook(create_add_activations_pre_hook(specificpos_layer_activations))
            if idx in continuouspos_layer_activations:
                if use_forward_hooks:
                    block.register_forward_hook(create_continuous_activation_hook(continuouspos_layer_activations))
                else:
                    block.register_forward_pre_hook(create_continuous_activation_pre_hook(continuouspos_layer_activations))
        device = next(model.parameters()).device
        tokens = {k: v.to(device) for k, v in tokens.items()}
        with torch.no_grad():
            generated_ids = model.generate(**tokens, **sampling_kwargs)
    finally:
        # Never leave steering hooks on the model, even if generate() raised.
        for block in transformer_blocks:
            if use_forward_hooks:
                block._forward_hooks.clear()
            else:
                block._forward_pre_hooks.clear()

    # scores: tuple of (batch_size, vocab_size), one per new token
    #   -> Tensor(batch_size, new_tokens, vocab_size)
    stacked_scores = torch.stack(generated_ids.scores, dim=1)
    probabilities = F.softmax(stacked_scores, dim=-1)  # softmax over vocab_size
    if score_on_token is None:
        generated_sequences = generated_ids.sequences
        start_pos = generated_sequences.shape[1] - probabilities.shape[1]  # new tokens only
        generated_tokens_ids = generated_sequences[:, start_pos:]
        # Probability of the token actually emitted at each step. Under greedy
        # decoding this is the step's maximum probability; it also stays
        # meaningful when tokens are sampled.
        # NOTE(review): rows that finish early presumably continue with padding
        # tokens; their entries are then the padding token's probability.
        token_probabilities = probabilities.gather(-1, generated_tokens_ids.unsqueeze(-1)).squeeze(-1)
        return generated_tokens_ids.detach().cpu(), token_probabilities.detach().cpu()
    token_probs = probabilities[:, :, score_on_token].detach().cpu()
    return torch.full(token_probs.shape, score_on_token), token_probs
########################
# Zero out projections
def create_continuous_zeroout_activations_hook(continuouspos_layer_activations):
    """
    Build a forward hook that removes, at every position, the component of a
    block's output along a given direction.

    Args:
        continuouspos_layer_activations: ``{layer_idx: direction_vector}``.
            - Each entry is matched against the module's ``layer_idx`` attribute.
            - The direction need not be unit-norm.
            - It is cast to the output's dtype and moved to the output's device.

    Returns:
        A ``register_forward_hook``-compatible callable.
        - It updates ``h <- h - (h . d / ||d||^2) d`` in place on the output (the
          first element if the block returns a tuple), on every forward pass.
        - A zero direction vector leaves the output unchanged instead of
          producing NaNs.
        - Non-matching layers are left untouched.
    """
    def hook(module, inputs, outputs):
        current_layer_idx = getattr(module, 'layer_idx', None)
        if current_layer_idx not in continuouspos_layer_activations:
            return None
        main_tensor = outputs[0] if isinstance(outputs, tuple) else outputs
        direction_vector = continuouspos_layer_activations[current_layer_idx]
        # Match both device and dtype so the matmul works when the vector lives elsewhere.
        direction_vector = direction_vector.to(device=main_tensor.device, dtype=main_tensor.dtype).reshape(-1, 1)  # [d_embed, 1]
        norm_sq = torch.norm(direction_vector) ** 2
        if norm_sq == 0:
            return outputs  # no direction to project out
        projection = (main_tensor @ direction_vector) / norm_sq  # [..., seq, 1]
        # Broadcast the direction over the leading (batch, seq) dimensions.
        main_tensor -= projection * direction_vector.reshape(1, 1, -1)
        return (main_tensor,) + outputs[1:] if isinstance(outputs, tuple) else main_tensor
    return hook
def zeroout_projections_and_generate(model, tokens, continuouspos_layer_activations, sampling_kwargs):
    """
    Generate while projecting a direction out of the outputs of selected blocks.

    Args:
        model: model exposing ``generate``; its block stack is found with
            ``get_blocks`` and each block gets a ``layer_idx`` attribute.
        tokens: dict of tensors passed as keyword arguments to ``model.generate``;
            moved to the device of the model's first parameter.
        continuouspos_layer_activations: ``{layer_idx: direction_vector}``; see
            ``create_continuous_zeroout_activations_hook``.
        sampling_kwargs: extra keyword arguments for ``model.generate``.

    Returns:
        Whatever ``model.generate`` returns.

    The blocks' forward-hook dicts are cleared afterwards, even if generation
    raises. This also removes hooks other code attached to those blocks.
    """
    transformer_blocks = get_blocks(model)
    try:
        for idx, block in enumerate(transformer_blocks):
            setattr(block, 'layer_idx', idx)  # read by the hook to find its direction
            if idx in continuouspos_layer_activations:
                continuous_hook = create_continuous_zeroout_activations_hook(continuouspos_layer_activations)
                block.register_forward_hook(continuous_hook)
        tokens = {k: v.to(next(model.parameters()).device) for k, v in tokens.items()}
        generated_ids = model.generate(**tokens, **sampling_kwargs)
    finally:
        # Never leave the projection hooks on the model, even if generate() raised.
        for block in transformer_blocks:
            block._forward_hooks.clear()
    return generated_ids
##################
# Read during generation
def attach_generation_activation_hooks(model, layers_positions, activation_storage, get_at='end'):
    """
    Attach hooks that record activations at chosen prompt positions, then at each
    newly generated token during generation.

    Args:
        model: model whose transformer blocks are found with ``get_blocks``.
        layers_positions: ``{layer_idx: positions}``; ``positions[b]`` lists the
            prompt positions to read for batch row ``b``.
        activation_storage: nested mapping ``storage[layer_idx][slot] -> list``.
            It is cleared first. It must create missing keys itself (e.g. a nested
            ``defaultdict``), since new slots are added on the fly.
        get_at: ``'end'`` records block outputs (forward hook; first element of a
            tuple output). Any other value records ``input[0]`` (forward pre-hook).

    Recording rule, per forward pass and batch row:
    - Each requested position that exists in the current sequence is stored under
      its index ``k`` in that row's position list.
    - The first requested position that does not exist marks a cached decoding
      step, whose sequence is shorter than the prompt. In that case:
      - the row's last-token activation is stored in a fresh slot
        ``max(existing slots) + 1``, shared by all rows of that pass;
      - the row's remaining positions are skipped.
    Every stored tensor is detached and moved to CPU.

    NOTE(review): positions that are still in range during a one-token decoding
    pass (e.g. 0, or negative indices such as -1) are read from that pass and
    appended to their prompt slot as well.
    """
    def capture_activations_hook(layer_idx, positions, get_at='end'):
        def record(hidden):
            seq_len = hidden.shape[1]
            new_slot = None  # slot for this pass's generated token, computed on first need
            for batch_idx in range(len(positions)):
                for pos_idx, seq_pos in enumerate(positions[batch_idx]):
                    # ``>=``: index ``seq_len`` is already past the end of the sequence.
                    if seq_pos >= seq_len:  # happens during generation
                        if new_slot is None:
                            new_slot = max(activation_storage[layer_idx].keys(), default=-1) + 1
                        activation_storage[layer_idx][new_slot].append(hidden[batch_idx, seq_len - 1, :].detach().cpu())
                        break  # only add one per generated token
                    activation_storage[layer_idx][pos_idx].append(hidden[batch_idx, seq_pos, :].detach().cpu())

        def hook(module, input, output):
            record(output[0] if isinstance(output, tuple) else output)

        def pre_hook(module, input):
            record(input[0])

        # Return the variant whose signature matches the registration call below.
        return hook if get_at == 'end' else pre_hook

    activation_storage.clear()
    transformer_blocks = get_blocks(model)
    for idx, block in enumerate(transformer_blocks):
        if idx in layers_positions:
            hook = capture_activations_hook(idx, layers_positions[idx], get_at)
            if get_at == 'end':
                block.register_forward_hook(hook)
            else:
                block.register_forward_pre_hook(hook)
def get_activations_and_generate(model, tokens, layers_positions, sampling_kwargs, get_at='end'):
    """
    Run ``model.generate`` while recording block outputs, and return them.

    Args:
        model: model exposing ``generate``; its blocks are found with ``get_blocks``.
        tokens: dict of tensors passed as keyword arguments to ``model.generate``;
            moved to the device of the model's first parameter.
        layers_positions: ``{layer_idx: positions}``; ``positions[b]`` lists the
            prompt positions to read for batch row ``b``.
        sampling_kwargs: keyword arguments for ``model.generate``. ``use_cache`` is
            forced to ``True`` on a copy; the caller's dict is not modified.
        get_at: accepted for compatibility but not used. The capture hooks are
            registered with ``attach_generation_activation_hooks``'s default
            (``'end'``), i.e. as forward hooks on block outputs.

    Returns:
        ``{layer_idx: {slot: Tensor}}`` of detached CPU tensors, each stacked over
        batch rows.
        - Slots ``0..k-1`` hold the requested prompt positions.
        - Later slots are added by the capture hook while decoding.

    Side effects: puts ``model`` in eval mode and leaves it there. Afterwards it
    clears the blocks' forward-hook dicts, even if generation raises. This also
    removes hooks other code attached to those blocks.
    """
    activation_storage = defaultdict(lambda: defaultdict(list))
    # The cache is always on: the capture hook's slot bookkeeping expects
    # incremental decoding. A copy is used so the caller's dict is left as-is.
    sampling_kwargs = {**sampling_kwargs, 'use_cache': True}
    model.eval()
    try:
        attach_generation_activation_hooks(model, layers_positions, activation_storage)
        with torch.no_grad():
            tokens = {k: v.to(next(model.parameters()).device) for k, v in tokens.items()}
            _ = model.generate(**tokens, **sampling_kwargs)
    finally:
        # The capture hooks were registered as forward hooks (``get_at`` is not
        # forwarded), so those are what must be removed. Clearing the pre-hook
        # dict would leave them attached.
        for block in get_blocks(model):
            block._forward_hooks.clear()
    return {layer: {pos: torch.stack(tensors) for pos, tensors in pos_dict.items()}
            for layer, pos_dict in activation_storage.items()}
######################
### Utilities to turn specific gradients on and off
def disable_grad():
    """
    Return a tensor gradient hook that replaces the incoming gradient with zeros.

    Intended for ``Tensor.register_hook``, where the hook's return value replaces
    the gradient. Returning ``None`` there would leave the gradient unchanged and
    block nothing.

    NOTE: a zero gradient does not by itself freeze a parameter. Optimizers that
    apply momentum or weight decay may still change it.
    """
    def hook(grad):
        return torch.zeros_like(grad)
    return hook
def attach_zerograd_hooks(parameters):
    """
    Register a fresh ``disable_grad()`` hook on each tensor in ``parameters``.

    Returns:
        The removable handles, in the same order as ``parameters``, so they can
        later be passed to ``remove_zerograd_hooks``.
    """
    return [param.register_hook(disable_grad()) for param in parameters]
def remove_zerograd_hooks(handles):
    """Detach every hook whose handle is in ``handles`` (e.g. from ``attach_zerograd_hooks``)."""
    for removable in handles:
        removable.remove()
###############################
def get_blocks(model: nn.Module) -> nn.ModuleList:
    """
    Locate the ModuleList holding the model's transformer blocks.

    Heuristic: the block stack is taken to be the unique ``nn.ModuleList`` with
    more than half of the model's parameters, counted via ``Module.parameters()``.

    Raises:
        AssertionError: if zero or several ModuleLists pass that threshold.
    """
    def count_params(module):
        if not isinstance(module, nn.Module):
            print(f"Non-module object encountered: {module}")
            return 0
        return sum(p.numel() for p in module.parameters())

    threshold = .5 * count_params(model)
    matches = []
    for sub in model.modules():
        if isinstance(sub, nn.ModuleList) and count_params(sub) > threshold:
            matches.append(sub)
    assert len(matches) == 1, f'Found {len(matches)} ModuleLists with >50% of model params.'
    return matches[0]
def clear_hooks(model):
    """
    Drop every forward, forward-pre and backward hook registered directly on the
    model's transformer blocks (as located by ``get_blocks``).

    This empties PyTorch's private per-module hook dicts, so it also removes hooks
    that other code attached to those blocks.
    """
    for block in get_blocks(model):
        for registry in (block._forward_hooks, block._forward_pre_hooks, block._backward_hooks):
            registry.clear()