diff --git a/model_tools/activations/core.py b/model_tools/activations/core.py
index ceec93d..543caa5 100644
--- a/model_tools/activations/core.py
+++ b/model_tools/activations/core.py
@@ -126,7 +126,8 @@ def register_stimulus_set_hook(self, hook):
         return handle
 
     def _get_activations_batched(self, paths, layers, batch_size):
-        layer_activations = None
+        from collections import OrderedDict
+        layer_activations = OrderedDict()
         for batch_start in tqdm(range(0, len(paths), batch_size), unit_scale=batch_size, desc="activations"):
             batch_end = min(batch_start + batch_size, len(paths))
             batch_inputs = paths[batch_start:batch_end]
@@ -134,11 +135,11 @@ def _get_activations_batched(self, paths, layers, batch_size):
             for hook in self._batch_activations_hooks.copy().values():  # copy to avoid handle re-enabling messing with the loop
                 batch_activations = hook(batch_activations)
 
-            if layer_activations is None:
-                layer_activations = copy.copy(batch_activations)
-            else:
-                for layer_name, layer_output in batch_activations.items():
-                    layer_activations[layer_name] = np.concatenate((layer_activations[layer_name], layer_output))
+            for layer_name, layer_output in batch_activations.items():
+                layer_activations.setdefault(layer_name, []).append(layer_output)
+
+        for layer_name, layer_outputs in layer_activations.items():
+            layer_activations[layer_name] = np.concatenate(layer_outputs)
         return layer_activations
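
Note on the pattern: the old code called np.concatenate once per batch, which copies the entire accumulated array on every iteration and makes the loop quadratic in the number of batches. The patch instead appends each batch's output to a per-layer list (O(1) per batch) and concatenates each layer exactly once after the loop. Below is a minimal standalone sketch of that accumulate-then-concatenate pattern; collect_batches is a hypothetical helper for illustration, not part of this repository:

    import numpy as np
    from collections import OrderedDict

    def collect_batches(batches):
        # hypothetical helper: `batches` is an iterable of dicts mapping
        # layer name -> ndarray of shape (batch_size, ...)
        layer_activations = OrderedDict()
        for batch_activations in batches:
            # appending to a list is cheap; no array copying happens here
            for layer_name, layer_output in batch_activations.items():
                layer_activations.setdefault(layer_name, []).append(layer_output)
        # one np.concatenate per layer (along axis 0, the batch dimension),
        # after all batches have been collected
        for layer_name, layer_outputs in layer_activations.items():
            layer_activations[layer_name] = np.concatenate(layer_outputs)
        return layer_activations

    # toy usage: three batches of fake activations for two layers
    batches = [{'conv1': np.ones((4, 8)), 'fc': np.zeros((4, 2))} for _ in range(3)]
    result = collect_batches(batches)
    assert result['conv1'].shape == (12, 8) and result['fc'].shape == (12, 2)

Using an OrderedDict (rather than a plain dict) also makes the layer ordering explicit, matching the previous behavior where layers appeared in the order the first batch produced them.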