From 04a83f99b511686394dd2977c9095a7e54241e25 Mon Sep 17 00:00:00 2001 From: Ivan Raikov Date: Fri, 12 Sep 2025 07:14:29 -0700 Subject: [PATCH 1/3] input spike train generation: ensure that response array from input feature is flat --- src/miv_simulator/input_spike_trains.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/miv_simulator/input_spike_trains.py b/src/miv_simulator/input_spike_trains.py index b565e8d..997606c 100644 --- a/src/miv_simulator/input_spike_trains.py +++ b/src/miv_simulator/input_spike_trains.py @@ -226,6 +226,8 @@ def generate_input_spike_trains( response = input_feature.get_response(processed_signal) if isinstance(response, list): response = np.concatenate(np.concatenate(response, dtype=np.float32)) + else: + response = response.reshape((-1,)).astype(np.float32) if len(response) > 0: spikes_attr_dict[gid] = {output_spike_train_attr_name: response} From 3382b5b183e6df0eb107125b27b6d5377129dbc7 Mon Sep 17 00:00:00 2001 From: Ivan Raikov Date: Fri, 12 Sep 2025 18:05:33 -0500 Subject: [PATCH 2/3] bug fixes in spike generation --- src/miv_simulator/input_spike_trains.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/miv_simulator/input_spike_trains.py b/src/miv_simulator/input_spike_trains.py index 997606c..b72cee7 100644 --- a/src/miv_simulator/input_spike_trains.py +++ b/src/miv_simulator/input_spike_trains.py @@ -93,9 +93,6 @@ def generate_input_spike_trains( logger.info(f"{comm.size} ranks have been allocated") population_name = population.name - start_gid = 0 - if hasattr(population, "start_gid"): - start_gid = population.start_gid soma_positions_dict = None if coords_path is not None: @@ -189,8 +186,6 @@ def generate_input_spike_trains( feature_items = list(population.features.items()) n_iter = comm.allreduce(len(feature_items), op=MPI.MAX) - logger.info(f"n_iter = {n_iter} feature_items = {feature_items}") - if not dry_run and rank == 0: if output_path is None: raise RuntimeError("generate_input_spike_trains: missing output_path") @@ -203,7 +198,6 @@ def generate_input_spike_trains( for iter_count in range(n_iter): if iter_count < len(feature_items): gid, input_feature = feature_items[iter_count] - gid += start_gid else: gid, input_feature = None, None if gid is not None: @@ -225,7 +219,17 @@ def generate_input_spike_trains( # Get spike response response = input_feature.get_response(processed_signal) if isinstance(response, list): - response = np.concatenate(np.concatenate(response, dtype=np.float32)) + response_length = 0 + for x in response: + response_length += len(x) + if response_length > 0: + try: + response = np.concatenate(np.concatenate(response, dtype=np.float32)) + except Exception as e: + logger.error(f"error concatenating response: {response}") + raise e + else: + response = np.asarray([], dtype=np.float32) else: response = response.reshape((-1,)).astype(np.float32) From d6442b208c75b116dfef01a662063c8a29dab681 Mon Sep 17 00:00:00 2001 From: Ivan Raikov Date: Sat, 13 Sep 2025 00:37:53 -0500 Subject: [PATCH 3/3] support for multiple input spike namespaces --- src/miv_simulator/env.py | 10 ++--- src/miv_simulator/interface/legacy/run.py | 2 +- src/miv_simulator/network.py | 41 ++++++++++--------- src/miv_simulator/scripts/run_network.py | 2 + src/miv_simulator/scripts/tools/cut_slice.py | 3 +- .../scripts/tools/sample_cells.py | 3 +- src/miv_simulator/utils/io.py | 18 ++++---- 7 files changed, 44 insertions(+), 35 deletions(-) diff --git a/src/miv_simulator/env.py b/src/miv_simulator/env.py index 2be8f3c..61cf568 100644 --- a/src/miv_simulator/env.py +++ b/src/miv_simulator/env.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List, Any, Optional import logging import os @@ -84,9 +84,9 @@ def __init__( lptbal: bool = False, cell_selection_path: None = None, microcircuit_inputs: bool = False, - spike_input_path: None = None, - spike_input_namespace: None = None, - spike_input_attr: None = None, + spike_input_path: Optional[str] = None, + spike_input_namespaces: List[str] = [], + spike_input_attr: Optional[str] = None, coordinates_namespace: str = "Coordinates", cache_queries: bool = False, profile_memory: bool = False, @@ -312,7 +312,7 @@ def __init__( # Spike input path self.spike_input_path = spike_input_path - self.spike_input_ns = spike_input_namespace + self.spike_input_namespaces = spike_input_namespaces self.spike_input_attr = spike_input_attr self.spike_input_attribute_info = None if self.spike_input_path is not None: diff --git a/src/miv_simulator/interface/legacy/run.py b/src/miv_simulator/interface/legacy/run.py index a6503f6..8f7a0eb 100644 --- a/src/miv_simulator/interface/legacy/run.py +++ b/src/miv_simulator/interface/legacy/run.py @@ -106,7 +106,7 @@ def __call__(self): cell_selection_path=None, microcircuit_inputs=False, spike_input_path=self.config.spike_input_path, - spike_input_namespace=self.config.spike_input_namespace, + spike_input_namespaces=[self.config.spike_input_namespace], spike_input_attr=self.config.spike_input_attr, cleanup=True, cache_queries=False, diff --git a/src/miv_simulator/network.py b/src/miv_simulator/network.py index c73ec14..6d24cff 100644 --- a/src/miv_simulator/network.py +++ b/src/miv_simulator/network.py @@ -1203,10 +1203,10 @@ def make_input_cell_selection(env): has_spike_train = False if (env.spike_input_attribute_info is not None) and ( - env.spike_input_ns is not None + len(env.spike_input_namespaces) > 0 ): if (pop_name in env.spike_input_attribute_info) and ( - env.spike_input_ns in env.spike_input_attribute_info[pop_name] + set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys())) ): has_spike_train = True @@ -1309,18 +1309,19 @@ def init_input_cells(env: Env) -> None: has_vecstim = False vecstim_source_loc = [] if (env.spike_input_attribute_info is not None) and ( - env.spike_input_ns is not None + len(env.spike_input_namespaces) > 0 ): if (pop_name in env.spike_input_attribute_info) and ( - env.spike_input_ns in env.spike_input_attribute_info[pop_name] + set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys())) ): has_vecstim = True - vecstim_source_loc.append( - ( - env.spike_input_path, - env.spike_input_ns, - env.spike_input_attr, - ) + for ns in env.spike_input_namespaces: + vecstim_source_loc.append( + ( + env.spike_input_path, + ns, + env.spike_input_attr, + ) ) if (env.cell_attribute_info is not None) and ( vecstim_namespace is not None @@ -1454,23 +1455,25 @@ def init_input_cells(env: Env) -> None: has_spike_train = False spike_input_source_loc = [] if (env.spike_input_attribute_info is not None) and ( - env.spike_input_ns is not None + len(env.spike_input_namespaces) > 0 ): if (pop_name in env.spike_input_attribute_info) and ( - env.spike_input_ns in env.spike_input_attribute_info[pop_name] + set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys())) ): has_spike_train = True - spike_input_source_loc.append( - (env.spike_input_path, env.spike_input_ns) + for ns in env.spike_input_namespaces: + spike_input_source_loc.append( + (env.spike_input_path, ns) ) if (env.cell_attribute_info is not None) and ( - env.spike_input_ns is not None + len(env.spike_input_namespaces) > 0 ): if (pop_name in env.cell_attribute_info) and ( - env.spike_input_ns in env.cell_attribute_info[pop_name] + set(env.spike_input_namespaces).intersection(set(env.cell_attribute_info[pop_name].keys())) ): has_spike_train = True - spike_input_source_loc.append((input_file_path, env.spike_input_ns)) + for ns in env.spike_input_namespaces: + spike_input_source_loc.append((input_file_path, ns)) if rank == 0: logger.info( @@ -1525,7 +1528,7 @@ def init_input_cells(env: Env) -> None: elif len(this_gid_range) > 0: raise RuntimeError( f"init_input_cells: unable to determine spike train attribute for population {pop_name} in spike input file {env.spike_input_path};" - f" namespace {env.spike_input_ns}; attr keys {list(cell_spikes_attr_info.keys())}" + f" namespaces {env.spike_input_namespaces}; attr keys {list(cell_spikes_attr_info.keys())}" ) for gid, cell_spikes_tuple in cell_spikes_iter: if not (env.pc.gid_exists(gid)): @@ -1566,7 +1569,7 @@ def init_input_cells(env: Env) -> None: if rank == 0: logger.warning( f"No spike train data found for population {pop_name} in spike input file {env.spike_input_path}; " - f"namespace: {env.spike_input_ns}" + f"namespaces: {env.spike_input_namespaces}" ) gc.collect() diff --git a/src/miv_simulator/scripts/run_network.py b/src/miv_simulator/scripts/run_network.py index c1dd966..e53dc63 100644 --- a/src/miv_simulator/scripts/run_network.py +++ b/src/miv_simulator/scripts/run_network.py @@ -187,6 +187,7 @@ def mpi_excepthook(type, value, traceback): @click.option( "--spike-input-namespace", required=False, + multiple=True, type=str, help="namespace for input spikes when cell selection is specified", ) @@ -290,6 +291,7 @@ def main( np.seterr(all="raise") params = dict(locals()) params["config"] = params.pop("config_file") + params["spike_input_namespaces"] = params.get("spike_input_namespace", []) env = Env(**params) compile_and_load(directory=env.mechanisms_path) diff --git a/src/miv_simulator/scripts/tools/cut_slice.py b/src/miv_simulator/scripts/tools/cut_slice.py index 7636d36..229e85e 100644 --- a/src/miv_simulator/scripts/tools/cut_slice.py +++ b/src/miv_simulator/scripts/tools/cut_slice.py @@ -52,6 +52,7 @@ def mpi_excepthook(type, value, traceback): @click.option( "--spike-input-namespace", required=False, + multiple=True, type=str, help="namespace for input spikes when cell selection is specified", ) @@ -103,7 +104,7 @@ def main( dataset_prefix=dataset_prefix, results_path=output_path, spike_input_path=spike_input_path, - spike_input_namespace=spike_input_namespace, + spike_input_namespaces=spike_input_namespace, spike_input_attr=spike_input_attr, coordinates_namespace=coordinates_namespace, io_size=io_size, diff --git a/src/miv_simulator/scripts/tools/sample_cells.py b/src/miv_simulator/scripts/tools/sample_cells.py index b53cbf2..8397730 100644 --- a/src/miv_simulator/scripts/tools/sample_cells.py +++ b/src/miv_simulator/scripts/tools/sample_cells.py @@ -59,6 +59,7 @@ def mpi_excepthook(type, value, traceback): @click.option( "--spike-input-namespace", required=False, + multiple=True, type=str, help="namespace for input spikes", ) @@ -125,7 +126,7 @@ def main( dataset_prefix=dataset_prefix, results_path=output_path, spike_input_path=spike_input_path, - spike_input_namespace=spike_input_namespace, + spike_input_namespaces=spike_input_namespace, spike_input_attr=spike_input_attr, arena_id=arena_id, stimulus_id=stimulus_id, diff --git a/src/miv_simulator/utils/io.py b/src/miv_simulator/utils/io.py index 3df3bef..ebfa557 100644 --- a/src/miv_simulator/utils/io.py +++ b/src/miv_simulator/utils/io.py @@ -1025,21 +1025,23 @@ def write_input_cell_selection( has_spike_train = False spike_input_source_loc = [] if (env.spike_input_attribute_info is not None) and ( - env.spike_input_ns is not None + len(env.spike_input_namespaces) > 0 ): if (pop_name in env.spike_input_attribute_info) and ( - env.spike_input_ns in env.spike_input_attribute_info[pop_name] + set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys())) ): has_spike_train = True - spike_input_source_loc.append( - (env.spike_input_path, env.spike_input_ns) + for ns in env.spike_input_namespaces: + spike_input_source_loc.append( + (env.spike_input_path, ns) ) - if (env.cell_attribute_info is not None) and (env.spike_input_ns is not None): + if (env.cell_attribute_info is not None) and (len(env.spike_input_namespaces) > 0): if (pop_name in env.cell_attribute_info) and ( - env.spike_input_ns in env.cell_attribute_info[pop_name] + set(env.spike_input_namespaces).intersection(set(env.cell_attribute_info[pop_name].keys())) ): has_spike_train = True - spike_input_source_loc.append((input_file_path, env.spike_input_ns)) + for ns in env.spike_input_namespaces: + spike_input_source_loc.append((input_file_path, ns)) if rank == 0: logger.info( @@ -1083,7 +1085,7 @@ def write_input_cell_selection( write_selection_file_path, pop_name, spikes_output_dict, - namespace=env.spike_input_ns, + namespace=env.spike_input_namespaces[0], **write_kwds, )