I've been trying to run the SHIFT experiment for Pythia 70M on my Mac but have been running into trouble. Below I report the steps I took and the error that has me stumped. Could you help me figure out what I'm missing?
With these changes, when running cell 5 to train the probes I get the following error:
Here's the full trace
Collecting activations: 0%| | 0/141 [00:00<?, ?it/s]
--> 380 output = self.target(*args, **kwargs)
382 # Set value.
IndexError: tuple index out of range
The above exception was the direct cause of the following exception:
IndexError Traceback (most recent call last)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:383, in Node.execute(self)
382 # Set value.
--> 383 self.set_value(output)
385 except Exception as e:
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:410, in Node.set_value(self, value)
409 if listener.fulfilled() and not self.graph.sequential:
--> 410 listener.execute()
412 for dependency in self.arg_dependencies:
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:387, in Node.execute(self)
385 except Exception as e:
--> 387 raise type(e)(
388 f"Above exception when execution Node: '{self.name}' in Graph: '{self.graph.id}'"
389 ) from e
391 finally:
IndexError: Above exception when execution Node: 'getitem_1' in Graph: '13863411648'
The above exception was the direct cause of the following exception:
IndexError Traceback (most recent call last)
Cell In[5], line 1
----> 1 oracle, _ = train_probe(
2 activation_batches=collect_activations(model, layer, get_text_batches(split="train", ambiguous=False))
3 )
5 ambiguous_accs = test_probe(oracle, activation_batches=collect_activations(model, layer, get_text_batches(split="test", ambiguous=True)))
6 print(f"ambiguous test accuracy: {ambiguous_accs[0]}")
Cell In[4], line 30, in train_probe(activation_batches, label_idx, lr, epochs, d_probe, seed)
27 losses = []
29 for epoch in range(epochs):
---> 30 for act, *labels, in activation_batches:
31 optimizer.zero_grad()
32 logits = probe(act)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:38, in _wrap_generator.<locals>.generator_context(*args, **kwargs)
35 try:
36 # Issuing None to a generator fires it up
37 with ctx_factory():
---> 38 response = gen.send(None)
40 while True:
41 try:
42 # Forward the response to our caller and get its next request
Cell In[3], line 12, in collect_activations(model, layer, text_batches)
10 with tqdm(total=len(text_batches), desc="Collecting activations") as pbar:
11 for text_batch, *labels in text_batches:
---> 12 with model.trace(text_batch, **tracer_kwargs):
13 attn_mask = model.input[1]['attention_mask']
14 acts = model.gpt_neox.layers[layer].output[0]
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/contexts/Tracer.py:102, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
97 self.invoker.__exit__(None, None, None)
99 self.model._envoy._reset()
--> 102 super().__exit__(exc_type, exc_val, exc_tb)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/contexts/GraphBasedContext.py:217, in GraphBasedContext.__exit__(self, exc_type, exc_val, exc_tb)
214 self.graph = None
215 raise exc_val
--> 217 self.backend(self)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/contexts/backends/LocalBackend.py:27, in LocalBackend.__call__(self, obj)
25 def __call__(self, obj: LocalMixin):
---> 27 obj.local_backend_execute()
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/contexts/Tracer.py:146, in Tracer.local_backend_execute(self)
142 invoker_inputs = resolve_dependencies(invoker_inputs)
144 self.graph.execute()
--> 146 self.model.interleave(
147 self.model._execute,
148 self.graph,
149 *invoker_inputs,
150 **self._kwargs,
151 )
153 graph = self.graph
154 graph.alive = False
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/NNsightModel.py:469, in NNsight.interleave(self, fn, intervention_graph, *inputs, **kwargs)
463 intervention_handler = InterventionHandler(
464 intervention_graph, batch_groups, batch_size
465 )
467 module_paths = InterventionProtocol.get_interventions(intervention_graph).keys()
--> 469 with HookHandler(
470 self._model,
471 list(module_paths),
472 input_hook=lambda activations, module_path: InterventionProtocol.intervene(
473 activations, module_path, "input", intervention_handler
474 ),
475 output_hook=lambda activations, module_path: InterventionProtocol.intervene(
476 activations, module_path, "output", intervention_handler
477 ),
478 ):
479 try:
480 fn(*inputs, **kwargs)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/intervention.py:581, in HookHandler.__exit__(self, exc_type, exc_val, exc_tb)
578 handle.remove()
580 if isinstance(exc_val, Exception):
--> 581 raise exc_val
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/NNsightModel.py:480, in NNsight.interleave(self, fn, intervention_graph, *inputs, **kwargs)
469 with HookHandler(
470 self._model,
471 list(module_paths),
(...)
477 ),
478 ):
479 try:
--> 480 fn(*inputs, **kwargs)
481 except protocols.EarlyStopProtocol.EarlyStopException:
482 # TODO: Log.
483 for node in intervention_graph.nodes.values():
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/mixins/Generation.py:20, in GenerationMixin._execute(self, generate, *args, **kwargs)
16 if generate:
18 return self._execute_generate(*args, **kwargs)
---> 20 return self._execute_forward(*args, **kwargs)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/LanguageModel.py:327, in LanguageModel._execute_forward(self, prepared_inputs, *args, **kwargs)
323 def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):
325 device = next(self._model.parameters()).device
--> 327 return self._model(
328 *args,
329 **prepared_inputs.to(device),
330 **kwargs,
331 )
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1879, in Module._call_impl(self, *args, **kwargs)
1876 return inner()
1878 try:
-> 1879 return inner()
1880 except Exception:
1881 # run always called hooks if they have not already been run
1882 # For now only forward hooks have the always_call option but perhaps
1883 # this functionality should be added to full backward hooks as well.
1884 for hook_id, hook in _global_forward_hooks.items():
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1806, in Module._call_impl.<locals>.inner()
1801 for hook_id, hook in (
1802 *_global_forward_pre_hooks.items(),
1803 *self._forward_pre_hooks.items(),
1804 ):
1805 if hook_id in self._forward_pre_hooks_with_kwargs:
-> 1806 args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
1807 if args_kwargs_result is not None:
1808 if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2:
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/intervention.py:555, in HookHandler.__enter__.<locals>.input_hook(module, input, kwargs, module_path)
554 def input_hook(module, input, kwargs, module_path=module_path):
--> 555 return self.input_hook((input, kwargs), module_path)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/NNsightModel.py:472, in NNsight.interleave.<locals>.<lambda>(activations, module_path)
463 intervention_handler = InterventionHandler(
464 intervention_graph, batch_groups, batch_size
465 )
467 module_paths = InterventionProtocol.get_interventions(intervention_graph).keys()
469 with HookHandler(
470 self._model,
471 list(module_paths),
--> 472 input_hook=lambda activations, module_path: InterventionProtocol.intervene(
473 activations, module_path, "input", intervention_handler
474 ),
475 output_hook=lambda activations, module_path: InterventionProtocol.intervene(
476 activations, module_path, "output", intervention_handler
477 ),
478 ):
479 try:
480 fn(*inputs, **kwargs)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/intervention.py:456, in InterventionProtocol.intervene(cls, activations, module_path, key, intervention_handler)
449 value = util.apply(
450 activations,
451 narrow,
452 torch.Tensor,
453 )
455 # Value injection.
--> 456 node.set_value(value)
458 # Check if through the previous value injection, there was a 'swap' intervention.
459 # This would mean we want to replace activations for this batch with some other ones.
460 value = protocols.SwapProtocol.get_swap(
461 intervention_handler.graph, value
462 )
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:410, in Node.set_value(self, value)
407 listener.remaining_dependencies -= 1
409 if listener.fulfilled() and not self.graph.sequential:
--> 410 listener.execute()
412 for dependency in self.arg_dependencies:
413 dependency.remaining_listeners -= 1
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:387, in Node.execute(self)
383 self.set_value(output)
385 except Exception as e:
--> 387 raise type(e)(
388 f"Above exception when execution Node: '{self.name}' in Graph: '{self.graph.id}'"
389 ) from e
391 finally:
392 self.remaining_dependencies -= 1
IndexError: Above exception when execution Node: 'getitem_0' in Graph: '13863411648'
Hi,
I've been trying to run the SHIFT experiment for Pythia 70M on my Mac but have been running into trouble. Below I report the steps I took and the error that has me stumped. Could you help me figure out what I'm missing?
Thanks!
Installation process worked fine:
Downloaded the Pythia 70M dictionaries with the commands from the README:
Then I had to make a few changes in `bib_shift.ipynb`:
In cell 1:
- `DEVICE` to `"mps"`
- `activation_dim` to `512`
In cell 3:
- `acts = model.model.layers[layer].output[0]` to `acts = model.gpt_neox.layers[layer].output[0]`
In cell 4:
- `layer` to `4`
With these changes, when running cell 5 to train the probes I get the following error:
Here's the full trace
Collecting activations: 0%| | 0/141 [00:00<?, ?it/s]
--> 380 output = self.target(*args, **kwargs)
382 # Set value.
IndexError: tuple index out of range
The above exception was the direct cause of the following exception:
IndexError Traceback (most recent call last)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:383, in Node.execute(self)
382 # Set value.
--> 383 self.set_value(output)
385 except Exception as e:
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:410, in Node.set_value(self, value)
409 if listener.fulfilled() and not self.graph.sequential:
--> 410 listener.execute()
412 for dependency in self.arg_dependencies:
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:387, in Node.execute(self)
385 except Exception as e:
--> 387 raise type(e)(
388 f"Above exception when execution Node: '{self.name}' in Graph: '{self.graph.id}'"
389 ) from e
391 finally:
IndexError: Above exception when execution Node: 'getitem_1' in Graph: '13863411648'
The above exception was the direct cause of the following exception:
IndexError Traceback (most recent call last)
Cell In[5], line 1
----> 1 oracle, _ = train_probe(
2 activation_batches=collect_activations(model, layer, get_text_batches(split="train", ambiguous=False))
3 )
5 ambiguous_accs = test_probe(oracle, activation_batches=collect_activations(model, layer, get_text_batches(split="test", ambiguous=True)))
6 print(f"ambiguous test accuracy: {ambiguous_accs[0]}")
Cell In[4], line 30, in train_probe(activation_batches, label_idx, lr, epochs, d_probe, seed)
27 losses = []
29 for epoch in range(epochs):
---> 30 for act, *labels, in activation_batches:
31 optimizer.zero_grad()
32 logits = probe(act)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:38, in _wrap_generator.<locals>.generator_context(*args, **kwargs)
35 try:
36 # Issuing `None` to a generator fires it up
37 with ctx_factory():
---> 38 response = gen.send(None)
40 while True:
41 try:
42 # Forward the response to our caller and get its next request
Cell In[3], line 12, in collect_activations(model, layer, text_batches)
10 with tqdm(total=len(text_batches), desc="Collecting activations") as pbar:
11 for text_batch, *labels in text_batches:
---> 12 with model.trace(text_batch, **tracer_kwargs):
13 attn_mask = model.input[1]['attention_mask']
14 acts = model.gpt_neox.layers[layer].output[0]
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/contexts/Tracer.py:102, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
97 self.invoker.__exit__(None, None, None)
99 self.model._envoy._reset()
--> 102 super().__exit__(exc_type, exc_val, exc_tb)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/contexts/GraphBasedContext.py:217, in GraphBasedContext.__exit__(self, exc_type, exc_val, exc_tb)
214 self.graph = None
215 raise exc_val
--> 217 self.backend(self)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/contexts/backends/LocalBackend.py:27, in LocalBackend.__call__(self, obj)
25 def __call__(self, obj: LocalMixin):
---> 27 obj.local_backend_execute()
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/contexts/Tracer.py:146, in Tracer.local_backend_execute(self)
142 invoker_inputs = resolve_dependencies(invoker_inputs)
144 self.graph.execute()
--> 146 self.model.interleave(
147 self.model._execute,
148 self.graph,
149 *invoker_inputs,
150 **self._kwargs,
151 )
153 graph = self.graph
154 graph.alive = False
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/NNsightModel.py:469, in NNsight.interleave(self, fn, intervention_graph, *inputs, **kwargs)
463 intervention_handler = InterventionHandler(
464 intervention_graph, batch_groups, batch_size
465 )
467 module_paths = InterventionProtocol.get_interventions(intervention_graph).keys()
--> 469 with HookHandler(
470 self._model,
471 list(module_paths),
472 input_hook=lambda activations, module_path: InterventionProtocol.intervene(
473 activations, module_path, "input", intervention_handler
474 ),
475 output_hook=lambda activations, module_path: InterventionProtocol.intervene(
476 activations, module_path, "output", intervention_handler
477 ),
478 ):
479 try:
480 fn(*inputs, **kwargs)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/intervention.py:581, in HookHandler.__exit__(self, exc_type, exc_val, exc_tb)
578 handle.remove()
580 if isinstance(exc_val, Exception):
--> 581 raise exc_val
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/NNsightModel.py:480, in NNsight.interleave(self, fn, intervention_graph, *inputs, **kwargs)
469 with HookHandler(
470 self._model,
471 list(module_paths),
(...)
477 ),
478 ):
479 try:
--> 480 fn(*inputs, **kwargs)
481 except protocols.EarlyStopProtocol.EarlyStopException:
482 # TODO: Log.
483 for node in intervention_graph.nodes.values():
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/mixins/Generation.py:20, in GenerationMixin._execute(self, generate, *args, **kwargs)
16 if generate:
18 return self._execute_generate(*args, **kwargs)
---> 20 return self._execute_forward(*args, **kwargs)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/LanguageModel.py:327, in LanguageModel._execute_forward(self, prepared_inputs, *args, **kwargs)
323 def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):
325 device = next(self._model.parameters()).device
--> 327 return self._model(
328 *args,
329 **prepared_inputs.to(device),
330 **kwargs,
331 )
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1879, in Module._call_impl(self, *args, **kwargs)
1876 return inner()
1878 try:
-> 1879 return inner()
1880 except Exception:
1881 # run always called hooks if they have not already been run
1882 # For now only forward hooks have the always_call option but perhaps
1883 # this functionality should be added to full backward hooks as well.
1884 for hook_id, hook in _global_forward_hooks.items():
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1806, in Module._call_impl.<locals>.inner()
1801 for hook_id, hook in (
1802 *_global_forward_pre_hooks.items(),
1803 *self._forward_pre_hooks.items(),
1804 ):
1805 if hook_id in self._forward_pre_hooks_with_kwargs:
-> 1806 args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
1807 if args_kwargs_result is not None:
1808 if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2:
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/intervention.py:555, in HookHandler.__enter__.<locals>.input_hook(module, input, kwargs, module_path)
554 def input_hook(module, input, kwargs, module_path=module_path):
--> 555 return self.input_hook((input, kwargs), module_path)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/models/NNsightModel.py:472, in NNsight.interleave.<locals>.<lambda>(activations, module_path)
463 intervention_handler = InterventionHandler(
464 intervention_graph, batch_groups, batch_size
465 )
467 module_paths = InterventionProtocol.get_interventions(intervention_graph).keys()
469 with HookHandler(
470 self._model,
471 list(module_paths),
--> 472 input_hook=lambda activations, module_path: InterventionProtocol.intervene(
473 activations, module_path, "input", intervention_handler
474 ),
475 output_hook=lambda activations, module_path: InterventionProtocol.intervene(
476 activations, module_path, "output", intervention_handler
477 ),
478 ):
479 try:
480 fn(*inputs, **kwargs)
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/intervention.py:456, in InterventionProtocol.intervene(cls, activations, module_path, key, intervention_handler)
449 value = util.apply(
450 activations,
451 narrow,
452 torch.Tensor,
453 )
455 # Value injection.
--> 456 node.set_value(value)
458 # Check if through the previous value injection, there was a 'swap' intervention.
459 # This would mean we want to replace activations for this batch with some other ones.
460 value = protocols.SwapProtocol.get_swap(
461 intervention_handler.graph, value
462 )
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:410, in Node.set_value(self, value)
407 listener.remaining_dependencies -= 1
409 if listener.fulfilled() and not self.graph.sequential:
--> 410 listener.execute()
412 for dependency in self.arg_dependencies:
413 dependency.remaining_listeners -= 1
File ~/Projects/work/SHIFT/testing/feature-circuits/venv/lib/python3.10/site-packages/nnsight/tracing/Node.py:387, in Node.execute(self)
383 self.set_value(output)
385 except Exception as e:
--> 387 raise type(e)(
388 f"Above exception when execution Node: '{self.name}' in Graph: '{self.graph.id}'"
389 ) from e
391 finally:
392 self.remaining_dependencies -= 1
IndexError: Above exception when execution Node: 'getitem_0' in Graph: '13863411648'