diff --git a/vivarium_models/processes/simularium_emitter.py b/vivarium_models/processes/simularium_emitter.py index 88d6246..048efad 100644 --- a/vivarium_models/processes/simularium_emitter.py +++ b/vivarium_models/processes/simularium_emitter.py @@ -3,7 +3,6 @@ from vivarium.core.emitter import Emitter import numpy as np -import pandas as pd from simulariumio import ( TrajectoryConverter, TrajectoryData, @@ -110,83 +109,26 @@ def get_simularium_monomers(self, time, monomers, actin_radius, trajectory): return trajectory @staticmethod - def fill_df(df, fill): + def get_scaled_agent_data(trajectory, scale_factor) -> AgentData: """ - Fill Nones in a DataFrame with a fill value - """ - # Create a dataframe of fill values - fill_array = [[fill] * df.shape[1]] * df.shape[0] - fill_df = pd.DataFrame(fill_array) - # Replace all entries with None with the fill - df[df.isna()] = fill_df - return df - - @staticmethod - def jagged_3d_list_to_numpy_array(jagged_3d_list): - """ - Shape a jagged list with 3 dimensions to a numpy array - """ - df = SimulariumEmitter.fill_df(pd.DataFrame(jagged_3d_list), [0.0, 0.0, 0.0]) - df_t = df.transpose() - exploded = [df_t[col].explode() for col in list(df_t.columns)] - return np.array(exploded).reshape((df.shape[0], df.shape[1], 3)) - - @staticmethod - def get_subpoints_numpy_array(trajectory) -> np.ndarray: - """ - Shape a 4 dimensional jagged list for subpoints into a numpy array - """ - frame_arrays = [] - max_agents = 0 - max_subpoints = 0 - total_steps = len(trajectory["subpoints"]) - for time_index in range(total_steps): - frame_array = SimulariumEmitter.jagged_3d_list_to_numpy_array( - trajectory["subpoints"][time_index] - ) - if frame_array.shape[0] > max_agents: - max_agents = frame_array.shape[0] - if frame_array.shape[1] > max_subpoints: - max_subpoints = frame_array.shape[1] - frame_arrays.append(frame_array) - values_per_frame = max_agents * max_subpoints * 3 - result = np.zeros(total_steps * values_per_frame) - for time_index, frame_array in enumerate(frame_arrays): - if frame_array.shape[1] < max_subpoints: - new_frame_array = np.zeros((frame_array.shape[0], max_subpoints, 3)) - new_frame_array[:, : frame_array.shape[1]] = frame_array - frame_array = new_frame_array - flat_array = frame_array.flatten() - start_index = time_index * values_per_frame - result[start_index : start_index + flat_array.shape[0]] = flat_array - return result.reshape(total_steps, max_agents, max_subpoints, 3) - - @staticmethod - def get_agent_data_from_jagged_lists(trajectory, scale_factor) -> AgentData: - """ - Shape a dictionary of jagged lists into a Simularium AgentData object + Build AgentData object, scaling appropriate values by scale_factor """ return AgentData( - times=np.arange(len(trajectory["times"])), - n_agents=np.array(trajectory["n_agents"]), - viz_types=SimulariumEmitter.fill_df( - pd.DataFrame(trajectory["viz_types"]), 1000.0 - ).to_numpy(), - unique_ids=SimulariumEmitter.fill_df( - pd.DataFrame(trajectory["unique_ids"]), 0 - ).to_numpy(dtype=int), + times=trajectory["times"], + n_agents=trajectory["n_agents"], + viz_types=trajectory["viz_types"], + unique_ids=trajectory["unique_ids"], types=trajectory["type_names"], - positions=scale_factor - * SimulariumEmitter.jagged_3d_list_to_numpy_array(trajectory["positions"]), - radii=scale_factor - * SimulariumEmitter.fill_df( - pd.DataFrame(trajectory["radii"]), 0.0 - ).to_numpy(), - n_subpoints=SimulariumEmitter.fill_df( - pd.DataFrame(trajectory["n_subpoints"]), 0 - ).to_numpy(dtype=int), - subpoints=scale_factor - * SimulariumEmitter.get_subpoints_numpy_array(trajectory), + positions=[ + [[k * scale_factor for k in j] for j in i] + for i in trajectory["positions"] + ], + radii=[[j * scale_factor for j in i] for i in trajectory["radii"]], + n_subpoints=trajectory["n_subpoints"], + subpoints=[ + [[k * scale_factor for k in j] for j in i] + for i in trajectory["subpoints"] + ], ) @staticmethod @@ -204,7 +146,7 @@ def get_simularium_converter( meta_data=MetaData( box_size=scale_factor * box_dimensions, ), - agent_data=SimulariumEmitter.get_agent_data_from_jagged_lists( + agent_data=SimulariumEmitter.get_scaled_agent_data( trajectory, scale_factor ), time_units=UnitData("count"),