Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/miv_simulator/env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, List, Any, Optional

import logging
import os
Expand Down Expand Up @@ -84,9 +84,9 @@ def __init__(
lptbal: bool = False,
cell_selection_path: None = None,
microcircuit_inputs: bool = False,
spike_input_path: None = None,
spike_input_namespace: None = None,
spike_input_attr: None = None,
spike_input_path: Optional[str] = None,
spike_input_namespaces: List[str] = [],
spike_input_attr: Optional[str] = None,
coordinates_namespace: str = "Coordinates",
cache_queries: bool = False,
profile_memory: bool = False,
Expand Down Expand Up @@ -312,7 +312,7 @@ def __init__(

# Spike input path
self.spike_input_path = spike_input_path
self.spike_input_ns = spike_input_namespace
self.spike_input_namespaces = spike_input_namespaces
self.spike_input_attr = spike_input_attr
self.spike_input_attribute_info = None
if self.spike_input_path is not None:
Expand Down
20 changes: 13 additions & 7 deletions src/miv_simulator/input_spike_trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,6 @@ def generate_input_spike_trains(
logger.info(f"{comm.size} ranks have been allocated")

population_name = population.name
start_gid = 0
if hasattr(population, "start_gid"):
start_gid = population.start_gid

soma_positions_dict = None
if coords_path is not None:
Expand Down Expand Up @@ -189,8 +186,6 @@ def generate_input_spike_trains(
feature_items = list(population.features.items())
n_iter = comm.allreduce(len(feature_items), op=MPI.MAX)

logger.info(f"n_iter = {n_iter} feature_items = {feature_items}")

if not dry_run and rank == 0:
if output_path is None:
raise RuntimeError("generate_input_spike_trains: missing output_path")
Expand All @@ -203,7 +198,6 @@ def generate_input_spike_trains(
for iter_count in range(n_iter):
if iter_count < len(feature_items):
gid, input_feature = feature_items[iter_count]
gid += start_gid
else:
gid, input_feature = None, None
if gid is not None:
Expand All @@ -225,7 +219,19 @@ def generate_input_spike_trains(
# Get spike response
response = input_feature.get_response(processed_signal)
if isinstance(response, list):
response = np.concatenate(np.concatenate(response, dtype=np.float32))
response_length = 0
for x in response:
response_length += len(x)
if response_length > 0:
try:
response = np.concatenate(np.concatenate(response, dtype=np.float32))
except Exception as e:
logger.error(f"error concatenating response: {response}")
raise e
else:
response = np.asarray([], dtype=np.float32)
else:
response = response.reshape((-1,)).astype(np.float32)

if len(response) > 0:
spikes_attr_dict[gid] = {output_spike_train_attr_name: response}
Expand Down
2 changes: 1 addition & 1 deletion src/miv_simulator/interface/legacy/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __call__(self):
cell_selection_path=None,
microcircuit_inputs=False,
spike_input_path=self.config.spike_input_path,
spike_input_namespace=self.config.spike_input_namespace,
spike_input_namespaces=[self.config.spike_input_namespace],
spike_input_attr=self.config.spike_input_attr,
cleanup=True,
cache_queries=False,
Expand Down
41 changes: 22 additions & 19 deletions src/miv_simulator/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,10 +1203,10 @@ def make_input_cell_selection(env):

has_spike_train = False
if (env.spike_input_attribute_info is not None) and (
env.spike_input_ns is not None
len(env.spike_input_namespaces) > 0
):
if (pop_name in env.spike_input_attribute_info) and (
env.spike_input_ns in env.spike_input_attribute_info[pop_name]
set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys()))
):
has_spike_train = True

Expand Down Expand Up @@ -1309,18 +1309,19 @@ def init_input_cells(env: Env) -> None:
has_vecstim = False
vecstim_source_loc = []
if (env.spike_input_attribute_info is not None) and (
env.spike_input_ns is not None
len(env.spike_input_namespaces) > 0
):
if (pop_name in env.spike_input_attribute_info) and (
env.spike_input_ns in env.spike_input_attribute_info[pop_name]
set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys()))
):
has_vecstim = True
vecstim_source_loc.append(
(
env.spike_input_path,
env.spike_input_ns,
env.spike_input_attr,
)
for ns in env.spike_input_namespaces:
vecstim_source_loc.append(
(
env.spike_input_path,
ns,
env.spike_input_attr,
)
)
if (env.cell_attribute_info is not None) and (
vecstim_namespace is not None
Expand Down Expand Up @@ -1454,23 +1455,25 @@ def init_input_cells(env: Env) -> None:
has_spike_train = False
spike_input_source_loc = []
if (env.spike_input_attribute_info is not None) and (
env.spike_input_ns is not None
len(env.spike_input_namespaces) > 0
):
if (pop_name in env.spike_input_attribute_info) and (
env.spike_input_ns in env.spike_input_attribute_info[pop_name]
set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys()))
):
has_spike_train = True
spike_input_source_loc.append(
(env.spike_input_path, env.spike_input_ns)
for ns in env.spike_input_namespaces:
spike_input_source_loc.append(
(env.spike_input_path, ns)
)
if (env.cell_attribute_info is not None) and (
env.spike_input_ns is not None
len(env.spike_input_namespaces) > 0
):
if (pop_name in env.cell_attribute_info) and (
env.spike_input_ns in env.cell_attribute_info[pop_name]
set(env.spike_input_namespaces).intersection(set(env.cell_attribute_info[pop_name].keys()))
):
has_spike_train = True
spike_input_source_loc.append((input_file_path, env.spike_input_ns))
for ns in env.spike_input_namespaces:
spike_input_source_loc.append((input_file_path, ns))

if rank == 0:
logger.info(
Expand Down Expand Up @@ -1525,7 +1528,7 @@ def init_input_cells(env: Env) -> None:
elif len(this_gid_range) > 0:
raise RuntimeError(
f"init_input_cells: unable to determine spike train attribute for population {pop_name} in spike input file {env.spike_input_path};"
f" namespace {env.spike_input_ns}; attr keys {list(cell_spikes_attr_info.keys())}"
f" namespaces {env.spike_input_namespaces}; attr keys {list(cell_spikes_attr_info.keys())}"
)
for gid, cell_spikes_tuple in cell_spikes_iter:
if not (env.pc.gid_exists(gid)):
Expand Down Expand Up @@ -1566,7 +1569,7 @@ def init_input_cells(env: Env) -> None:
if rank == 0:
logger.warning(
f"No spike train data found for population {pop_name} in spike input file {env.spike_input_path}; "
f"namespace: {env.spike_input_ns}"
f"namespaces: {env.spike_input_namespaces}"
)

gc.collect()
Expand Down
2 changes: 2 additions & 0 deletions src/miv_simulator/scripts/run_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def mpi_excepthook(type, value, traceback):
@click.option(
"--spike-input-namespace",
required=False,
multiple=True,
type=str,
help="namespace for input spikes when cell selection is specified",
)
Expand Down Expand Up @@ -290,6 +291,7 @@ def main(
np.seterr(all="raise")
params = dict(locals())
params["config"] = params.pop("config_file")
params["spike_input_namespaces"] = params.get("spike_input_namespace", [])
env = Env(**params)

compile_and_load(directory=env.mechanisms_path)
Expand Down
3 changes: 2 additions & 1 deletion src/miv_simulator/scripts/tools/cut_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def mpi_excepthook(type, value, traceback):
@click.option(
"--spike-input-namespace",
required=False,
multiple=True,
type=str,
help="namespace for input spikes when cell selection is specified",
)
Expand Down Expand Up @@ -103,7 +104,7 @@ def main(
dataset_prefix=dataset_prefix,
results_path=output_path,
spike_input_path=spike_input_path,
spike_input_namespace=spike_input_namespace,
spike_input_namespaces=spike_input_namespace,
spike_input_attr=spike_input_attr,
coordinates_namespace=coordinates_namespace,
io_size=io_size,
Expand Down
3 changes: 2 additions & 1 deletion src/miv_simulator/scripts/tools/sample_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def mpi_excepthook(type, value, traceback):
@click.option(
"--spike-input-namespace",
required=False,
multiple=True,
type=str,
help="namespace for input spikes",
)
Expand Down Expand Up @@ -125,7 +126,7 @@ def main(
dataset_prefix=dataset_prefix,
results_path=output_path,
spike_input_path=spike_input_path,
spike_input_namespace=spike_input_namespace,
spike_input_namespaces=spike_input_namespace,
spike_input_attr=spike_input_attr,
arena_id=arena_id,
stimulus_id=stimulus_id,
Expand Down
18 changes: 10 additions & 8 deletions src/miv_simulator/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,21 +1025,23 @@ def write_input_cell_selection(
has_spike_train = False
spike_input_source_loc = []
if (env.spike_input_attribute_info is not None) and (
env.spike_input_ns is not None
len(env.spike_input_namespaces) > 0
):
if (pop_name in env.spike_input_attribute_info) and (
env.spike_input_ns in env.spike_input_attribute_info[pop_name]
set(env.spike_input_namespaces).intersection(set(env.spike_input_attribute_info[pop_name].keys()))
):
has_spike_train = True
spike_input_source_loc.append(
(env.spike_input_path, env.spike_input_ns)
for ns in env.spike_input_namespaces:
spike_input_source_loc.append(
(env.spike_input_path, ns)
)
if (env.cell_attribute_info is not None) and (env.spike_input_ns is not None):
if (env.cell_attribute_info is not None) and (len(env.spike_input_namespaces) > 0):
if (pop_name in env.cell_attribute_info) and (
env.spike_input_ns in env.cell_attribute_info[pop_name]
set(env.spike_input_namespaces).intersection(set(env.cell_attribute_info[pop_name].keys()))
):
has_spike_train = True
spike_input_source_loc.append((input_file_path, env.spike_input_ns))
for ns in env.spike_input_namespaces:
spike_input_source_loc.append((input_file_path, ns))

if rank == 0:
logger.info(
Expand Down Expand Up @@ -1083,7 +1085,7 @@ def write_input_cell_selection(
write_selection_file_path,
pop_name,
spikes_output_dict,
namespace=env.spike_input_ns,
namespace=env.spike_input_namespaces[0],
**write_kwds,
)

Expand Down
Loading