Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "miv-simulator"
version = "0.2.0"
version = "0.2.1"
description = "Mind-In-Vitro simulator"
authors = []
dependencies = [
Expand Down
20 changes: 20 additions & 0 deletions src/miv_simulator/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,26 @@ def load_celltypes(self) -> None:
weights_dict["closure"] = clos
synapses_dict["weights"] = weights_dicts

def register_population(self, population_name, population_cell_distribution):
if population_name in self.Populations:
return None

max_pop_enum = 0
pop_offset = 0
for this_pop_name, this_pop_enum in self.Populations.items():
max_pop_enum = max(this_pop_enum, max_pop_enum)
pop_offset += self.celltypes[this_pop_name]["num"]

pop_id = max_pop_enum + 1
self.Populations[population_name] = pop_id
cell_distribution = {}
if "Cell Distribution" in self.geometry:
cell_distribution = self.geometry["Cell Distribution"]
else:
self.geometry["Cell Distribution"] = population_cell_distribution
cell_distribution[population_name] = population_cell_distribution
return {"population_id": pop_id, "population_start_gid": pop_offset}

def clear(self):
self.gidset = set()
self.gjlist = []
Expand Down
14 changes: 9 additions & 5 deletions src/miv_simulator/input_spike_trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def generate_input_spike_trains(
logger.info(f"{comm.size} ranks have been allocated")

population_name = population.name
start_gid = 0
if hasattr(population, "start_gid"):
start_gid = population.start_gid

soma_positions_dict = None
if coords_path is not None:
Expand Down Expand Up @@ -186,6 +189,8 @@ def generate_input_spike_trains(
feature_items = list(population.features.items())
n_iter = comm.allreduce(len(feature_items), op=MPI.MAX)

logger.info(f"n_iter = {n_iter} feature_items = {feature_items}")

if not dry_run and rank == 0:
if output_path is None:
raise RuntimeError("generate_input_spike_trains: missing output_path")
Expand All @@ -198,6 +203,7 @@ def generate_input_spike_trains(
for iter_count in range(n_iter):
if iter_count < len(feature_items):
gid, input_feature = feature_items[iter_count]
gid += start_gid
else:
gid, input_feature = None, None
if gid is not None:
Expand All @@ -218,13 +224,11 @@ def generate_input_spike_trains(

# Get spike response
response = input_feature.get_response(processed_signal)
if isinstance(response, list):
response = np.concatenate(np.concatenate(response, dtype=np.float32))

if len(response) > 0:
spikes_attr_dict[gid] = {
output_spike_train_attr_name: np.concatenate(
response, dtype=np.float32
)
}
spikes_attr_dict[gid] = {output_spike_train_attr_name: response}

gid_count += 1
if (iter_count > 0 and iter_count % write_every == 0) or (
Expand Down
Loading