diff --git a/src/miv_simulator/optimization.py b/src/miv_simulator/optimization.py index 7416bb3..4a70691 100644 --- a/src/miv_simulator/optimization.py +++ b/src/miv_simulator/optimization.py @@ -370,15 +370,20 @@ def update_run_params(env, param_tuples): def network_features(env, t_start, t_stop, target_populations): features_dict = dict() - temporal_resolution = float(env.stimulus_config["Temporal Resolution"]) - time_bins = np.arange(t_start, t_stop, temporal_resolution) + analysis_config = env.analysis_config + if analysis_config is None: + analysis_config = {} + fr_inference_config = analysis_config.get("Firing Rate Inference", {}) + + temporal_resolution = float(fr_inference_config.get("Temporal Resolution", 2.0)) + time_bins = np.arange(t_start, t_stop, temporal_resolution).astype(np.float32) pop_spike_dict = spikedata.get_env_spike_dict(env, include_artificial=False) for pop_name in target_populations: n_active = 0 spike_density_dict = spikedata.spike_density_estimate( - pop_name, pop_spike_dict[pop_name], time_bins + pop_name, pop_spike_dict[pop_name], time_bins, return_time_bins=False ) for gid, dens_dict in spike_density_dict.items(): mean_rate = np.mean(dens_dict["rate"]) diff --git a/src/miv_simulator/optimize_network.py b/src/miv_simulator/optimize_network.py index e24a346..2ba819a 100644 --- a/src/miv_simulator/optimize_network.py +++ b/src/miv_simulator/optimize_network.py @@ -2,7 +2,7 @@ """ Network model optimization script for optimization with dmosopt """ - +import gc import os import sys import datetime @@ -410,9 +410,9 @@ def compute_objectives(local_features, operational_config, opt_targets): sum_mean_rate_local += mean_rate ip_rate = np.interp( fr_time_centers, - dens_dict["time"].astype(np.float32), + time_bins_ref, dens_dict["rate"].astype(np.float32), - ) + ).astype(np.float32) active_per_bin = ip_rate > active_threshold sum_active_per_bin += active_per_bin @@ -439,17 +439,19 @@ def compute_objectives(local_features, operational_config, opt_targets): ) all_features_dict[f"{pop_name} mean fraction active per time bin"] = ( - mean_fraction_active_per_bin + float(mean_fraction_active_per_bin) ) all_features_dict[f"{pop_name} std fraction active per time bin"] = ( - std_fraction_active_per_bin + float(std_fraction_active_per_bin) ) - all_features_dict[f"{pop_name} fraction active"] = fraction_active - all_features_dict[f"{pop_name} firing rate"] = mean_rate + all_features_dict[f"{pop_name} fraction active"] = float(fraction_active) + all_features_dict[f"{pop_name} firing rate"] = float(mean_rate) rate_constr = mean_rate if mean_rate > 0.0 else -1.0 constraints.append(rate_constr) + gc.collect() + objective_names = operational_config["objective_names"] feature_dtypes = [(feature_name, np.float32) for feature_name in objective_names] @@ -470,7 +472,7 @@ def compute_objectives(local_features, operational_config, opt_targets): features.append(feature_val) result = ( - np.asarray(objectives), + np.asarray(objectives, dtype=np.float32), np.array([tuple(features)], dtype=np.dtype(feature_dtypes)), np.asarray(constraints, dtype=np.float32), ) diff --git a/src/miv_simulator/spikedata.py b/src/miv_simulator/spikedata.py index 6e249c7..15c1963 100644 --- a/src/miv_simulator/spikedata.py +++ b/src/miv_simulator/spikedata.py @@ -295,6 +295,7 @@ def spike_density_estimate( trajectory_id=None, output_file_path=None, progress=False, + return_time_bins=True, inferred_rate_attr_name="Inferred Rate Map", **kwargs, ): @@ -339,7 +340,7 @@ def make_spktrain(lst, t_start, t_stop): spk_rate_dict = { ind: baks(spkts / 1000.0, time_bins / 1000.0, **baks_args)[0].reshape((-1,)) if len(spkts) > 1 - else np.zeros(time_bins.shape) + else np.zeros(time_bins.shape, dtype=np.float32) for ind, spkts in seq } @@ -360,13 +361,14 @@ def make_spktrain(lst, t_start, t_stop): output_file_path, population, attr_dict, namespace=namespace ) - result = { - ind: {"rate": rate, "time": time_bins} for ind, rate in spk_rate_dict.items() - } - - result = { - ind: {"rate": rate, "time": time_bins} for ind, rate in spk_rate_dict.items() - } + if return_time_bins: + result = { + ind: {"rate": rate, "time": time_bins} for ind, rate in spk_rate_dict.items() + } + else: + result = { + ind: {"rate": rate} for ind, rate in spk_rate_dict.items() + } return result diff --git a/src/miv_simulator/utils/utils.py b/src/miv_simulator/utils/utils.py index 0112a81..0009199 100644 --- a/src/miv_simulator/utils/utils.py +++ b/src/miv_simulator/utils/utils.py @@ -984,7 +984,7 @@ def baks(spktimes, time, a=1.5, b=None): threshold = 6.0 h = (gamma(a) / gamma(a + 0.5)) * (sumnum / sumdenom) - rate = np.zeros((len(time),)) + rate = np.zeros((len(time),), dtype=np.float32) for j in range(n): time_diff = time - spktimes[j] abs_diff = np.abs(time_diff)