Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/miv_simulator/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,20 @@ def update_run_params(env, param_tuples):
def network_features(env, t_start, t_stop, target_populations):
features_dict = dict()

temporal_resolution = float(env.stimulus_config["Temporal Resolution"])
time_bins = np.arange(t_start, t_stop, temporal_resolution)
analysis_config = env.analysis_config
if analysis_config is None:
analysis_config = {}
fr_inference_config = analysis_config.get("Firing Rate Inference", {})

temporal_resolution = float(fr_inference_config.get("Temporal Resolution", 2.0))
time_bins = np.arange(t_start, t_stop, temporal_resolution).astype(np.float32)

pop_spike_dict = spikedata.get_env_spike_dict(env, include_artificial=False)

for pop_name in target_populations:
n_active = 0
spike_density_dict = spikedata.spike_density_estimate(
pop_name, pop_spike_dict[pop_name], time_bins
pop_name, pop_spike_dict[pop_name], time_bins, return_time_bins=False
)
for gid, dens_dict in spike_density_dict.items():
mean_rate = np.mean(dens_dict["rate"])
Expand Down
18 changes: 10 additions & 8 deletions src/miv_simulator/optimize_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
Network model optimization script for optimization with dmosopt
"""

import gc
import os
import sys
import datetime
Expand Down Expand Up @@ -410,9 +410,9 @@ def compute_objectives(local_features, operational_config, opt_targets):
sum_mean_rate_local += mean_rate
ip_rate = np.interp(
fr_time_centers,
dens_dict["time"].astype(np.float32),
time_bins_ref,
dens_dict["rate"].astype(np.float32),
)
).astype(np.float32)
active_per_bin = ip_rate > active_threshold
sum_active_per_bin += active_per_bin

Expand All @@ -439,17 +439,19 @@ def compute_objectives(local_features, operational_config, opt_targets):
)

all_features_dict[f"{pop_name} mean fraction active per time bin"] = (
mean_fraction_active_per_bin
float(mean_fraction_active_per_bin)
)
all_features_dict[f"{pop_name} std fraction active per time bin"] = (
std_fraction_active_per_bin
float(std_fraction_active_per_bin)
)
all_features_dict[f"{pop_name} fraction active"] = fraction_active
all_features_dict[f"{pop_name} firing rate"] = mean_rate
all_features_dict[f"{pop_name} fraction active"] = float(fraction_active)
all_features_dict[f"{pop_name} firing rate"] = float(mean_rate)

rate_constr = mean_rate if mean_rate > 0.0 else -1.0
constraints.append(rate_constr)

gc.collect()

objective_names = operational_config["objective_names"]
feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names]

Expand All @@ -470,7 +472,7 @@ def compute_objectives(local_features, operational_config, opt_targets):
features.append(feature_val)

result = (
np.asarray(objectives),
np.asarray(objectives, dtype=np.float32),
np.array([tuple(features)], dtype=np.dtype(feature_dtypes)),
np.asarray(constraints, dtype=np.float32),
)
Expand Down
18 changes: 10 additions & 8 deletions src/miv_simulator/spikedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def spike_density_estimate(
trajectory_id=None,
output_file_path=None,
progress=False,
return_time_bins=True,
inferred_rate_attr_name="Inferred Rate Map",
**kwargs,
):
Expand Down Expand Up @@ -339,7 +340,7 @@ def make_spktrain(lst, t_start, t_stop):
spk_rate_dict = {
ind: baks(spkts / 1000.0, time_bins / 1000.0, **baks_args)[0].reshape((-1,))
if len(spkts) > 1
else np.zeros(time_bins.shape)
else np.zeros(time_bins.shape, dtype=np.float32)
for ind, spkts in seq
}

Expand All @@ -360,13 +361,14 @@ def make_spktrain(lst, t_start, t_stop):
output_file_path, population, attr_dict, namespace=namespace
)

result = {
ind: {"rate": rate, "time": time_bins} for ind, rate in spk_rate_dict.items()
}

result = {
ind: {"rate": rate, "time": time_bins} for ind, rate in spk_rate_dict.items()
}
if return_time_bins:
result = {
ind: {"rate": rate, "time": time_bins} for ind, rate in spk_rate_dict.items()
}
else:
result = {
ind: {"rate": rate} for ind, rate in spk_rate_dict.items()
}

return result

Expand Down
2 changes: 1 addition & 1 deletion src/miv_simulator/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ def baks(spktimes, time, a=1.5, b=None):

threshold = 6.0
h = (gamma(a) / gamma(a + 0.5)) * (sumnum / sumdenom)
rate = np.zeros((len(time),))
rate = np.zeros((len(time),), dtype=np.float32)
for j in range(n):
time_diff = time - spktimes[j]
abs_diff = np.abs(time_diff)
Expand Down
Loading