Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 17 additions & 75 deletions vivarium_models/processes/simularium_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from vivarium.core.emitter import Emitter

import numpy as np
import pandas as pd
from simulariumio import (
TrajectoryConverter,
TrajectoryData,
Expand Down Expand Up @@ -110,83 +109,26 @@ def get_simularium_monomers(self, time, monomers, actin_radius, trajectory):
return trajectory

@staticmethod
def fill_df(df, fill):
def get_scaled_agent_data(trajectory, scale_factor) -> AgentData:
"""
Fill Nones in a DataFrame with a fill value
"""
# Create a dataframe of fill values
fill_array = [[fill] * df.shape[1]] * df.shape[0]
fill_df = pd.DataFrame(fill_array)
# Replace all entries with None with the fill
df[df.isna()] = fill_df
return df

@staticmethod
def jagged_3d_list_to_numpy_array(jagged_3d_list):
"""
Shape a jagged list with 3 dimensions to a numpy array
"""
df = SimulariumEmitter.fill_df(pd.DataFrame(jagged_3d_list), [0.0, 0.0, 0.0])
df_t = df.transpose()
exploded = [df_t[col].explode() for col in list(df_t.columns)]
return np.array(exploded).reshape((df.shape[0], df.shape[1], 3))

@staticmethod
def get_subpoints_numpy_array(trajectory) -> np.ndarray:
"""
Shape a 4 dimensional jagged list for subpoints into a numpy array
"""
frame_arrays = []
max_agents = 0
max_subpoints = 0
total_steps = len(trajectory["subpoints"])
for time_index in range(total_steps):
frame_array = SimulariumEmitter.jagged_3d_list_to_numpy_array(
trajectory["subpoints"][time_index]
)
if frame_array.shape[0] > max_agents:
max_agents = frame_array.shape[0]
if frame_array.shape[1] > max_subpoints:
max_subpoints = frame_array.shape[1]
frame_arrays.append(frame_array)
values_per_frame = max_agents * max_subpoints * 3
result = np.zeros(total_steps * values_per_frame)
for time_index, frame_array in enumerate(frame_arrays):
if frame_array.shape[1] < max_subpoints:
new_frame_array = np.zeros((frame_array.shape[0], max_subpoints, 3))
new_frame_array[:, : frame_array.shape[1]] = frame_array
frame_array = new_frame_array
flat_array = frame_array.flatten()
start_index = time_index * values_per_frame
result[start_index : start_index + flat_array.shape[0]] = flat_array
return result.reshape(total_steps, max_agents, max_subpoints, 3)

@staticmethod
def get_agent_data_from_jagged_lists(trajectory, scale_factor) -> AgentData:
"""
Shape a dictionary of jagged lists into a Simularium AgentData object
Build AgentData object, scaling appropriate values by scale_factor
"""
return AgentData(
times=np.arange(len(trajectory["times"])),
n_agents=np.array(trajectory["n_agents"]),
viz_types=SimulariumEmitter.fill_df(
pd.DataFrame(trajectory["viz_types"]), 1000.0
).to_numpy(),
unique_ids=SimulariumEmitter.fill_df(
pd.DataFrame(trajectory["unique_ids"]), 0
).to_numpy(dtype=int),
times=trajectory["times"],
n_agents=trajectory["n_agents"],
viz_types=trajectory["viz_types"],
unique_ids=trajectory["unique_ids"],
types=trajectory["type_names"],
positions=scale_factor
* SimulariumEmitter.jagged_3d_list_to_numpy_array(trajectory["positions"]),
radii=scale_factor
* SimulariumEmitter.fill_df(
pd.DataFrame(trajectory["radii"]), 0.0
).to_numpy(),
n_subpoints=SimulariumEmitter.fill_df(
pd.DataFrame(trajectory["n_subpoints"]), 0
).to_numpy(dtype=int),
subpoints=scale_factor
* SimulariumEmitter.get_subpoints_numpy_array(trajectory),
positions=[
[[k * scale_factor for k in j] for j in i]
for i in trajectory["positions"]
],
radii=[[j * scale_factor for j in i] for i in trajectory["radii"]],
n_subpoints=trajectory["n_subpoints"],
subpoints=[
[[k * scale_factor for k in j] for j in i]
for i in trajectory["subpoints"]
],
)

@staticmethod
Expand All @@ -204,7 +146,7 @@ def get_simularium_converter(
meta_data=MetaData(
box_size=scale_factor * box_dimensions,
),
agent_data=SimulariumEmitter.get_agent_data_from_jagged_lists(
agent_data=SimulariumEmitter.get_scaled_agent_data(
trajectory, scale_factor
),
time_units=UnitData("count"),
Expand Down