Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class PandemicObservation:
global_testing_summary: np.ndarray
stage: np.ndarray
infection_above_threshold: np.ndarray
critical_above_threshold:np.ndarray
time_day: np.ndarray
unlocked_non_essential_business_locations: Optional[np.ndarray] = None

Expand All @@ -40,6 +41,7 @@ def create_empty(cls: Type['PandemicObservation'],
global_testing_summary=np.zeros((history_size, 1, len(InfectionSummary))),
stage=np.zeros((history_size, 1, 1)),
infection_above_threshold=np.zeros((history_size, 1, 1)),
critical_above_threshold=np.zeros((history_size, 1, 1)),
time_day=np.zeros((history_size, 1, 1)),
unlocked_non_essential_business_locations=np.zeros((history_size, 1,
num_non_essential_business))
Expand Down Expand Up @@ -73,6 +75,8 @@ def update_obs_with_sim_state(self, sim_state: PandemicSimState,

self.infection_above_threshold[hist_index, 0] = int(sim_state.infection_above_threshold)

self.critical_above_threshold[hist_index, 0] = int(sim_state.critical_above_threshold)

self.time_day[hist_index, 0] = int(sim_state.sim_time.day)

@property
Expand Down
4 changes: 4 additions & 0 deletions python/pandemic_simulator/environment/interfaces/sim_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class PandemicSimState:
"""A boolean that is set to True if the infection goes above a set threshold. The threshold is set in the pandemic
sim"""

critical_above_threshold: bool
"""A boolean that is set to True if the number of CRITICAL cases reaches or exceeds a set threshold. The threshold
is set in the pandemic sim"""

regulation_stage: int
"""The last executed regulation stage"""

Expand Down
36 changes: 35 additions & 1 deletion python/pandemic_simulator/environment/pandemic_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .simulator_config import PandemicSimConfig
from .simulator_opts import PandemicSimOpts

__all__ = ['PandemicGymEnv']
__all__ = ['PandemicGymEnv','PandemicGymEnv3Act']


class PandemicGymEnv(gym.Env):
Expand Down Expand Up @@ -169,3 +169,37 @@ def reset(self) -> PandemicObservation:

def render(self, mode: str = 'human') -> bool:
pass

class ReducedActionPandemicGymEnv(gym.ActionWrapper):
    """Action wrapper that exposes a three-way relative action on top of a PandemicGymEnv.

    The wrapper's action space is ``Discrete(3, start=-1)`` (the actions -1, 0 and +1). Each action is added to
    the stage stored in the wrapped env's last observation, clamped to ``[_MIN_STAGE, _MAX_STAGE]``, and the
    resulting absolute stage is what gets passed to the wrapped env's ``step``.
    """

    # Bounds applied to the stage forwarded to the wrapped env.
    # NOTE(review): the upper bound is hard-coded; presumably it should match the number of regulation stages
    # given to the wrapped env - confirm against the caller.
    _MIN_STAGE = 0
    _MAX_STAGE = 4

    def __init__(self, env: PandemicGymEnv):
        """
        :param env: The PandemicGymEnv instance to wrap.
        """
        super().__init__(env)
        self.env = env

        self.action_space = gym.spaces.Discrete(3, start=-1)

    @classmethod
    def from_config(cls,
                    sim_config: PandemicSimConfig,
                    pandemic_regulations: Sequence[PandemicRegulation],
                    sim_opts: PandemicSimOpts = PandemicSimOpts(),
                    reward_fn: Optional[RewardFunction] = None,
                    done_fn: Optional[DoneFunction] = None,
                    ) -> 'ReducedActionPandemicGymEnv':
        """Build a PandemicGymEnv via ``PandemicGymEnv.from_config`` and wrap it.

        :param sim_config: Simulator config, forwarded to ``PandemicGymEnv.from_config``.
        :param pandemic_regulations: Regulations, forwarded to ``PandemicGymEnv.from_config``.
        :param sim_opts: Simulator options, forwarded to ``PandemicGymEnv.from_config``.
        :param reward_fn: Optional reward function, forwarded to ``PandemicGymEnv.from_config``.
        :param done_fn: Optional done function, forwarded to ``PandemicGymEnv.from_config``.
        :return: The wrapped environment.
        """
        env = PandemicGymEnv.from_config(sim_config=sim_config,
                                         pandemic_regulations=pandemic_regulations,
                                         sim_opts=sim_opts,
                                         reward_fn=reward_fn,
                                         done_fn=done_fn,
                                         )

        # Instantiate through cls so that subclasses get an instance of their own type.
        return cls(env=env)

    def step(self, action):
        """Translate the relative action into an absolute stage and step the wrapped env with it.

        :param action: A member of this wrapper's action space (-1, 0 or +1).
        :return: Whatever the wrapped env's ``step`` returns.
        """
        # The stage is read from a numpy array, so the sum is cast to a plain int before forwarding.
        return self.env.step(int(self.action(action)))

    def action(self, action):
        """Map a relative action onto the absolute stage to request from the wrapped env.

        :param action: A member of this wrapper's action space (-1, 0 or +1).
        :return: The last observed stage plus ``action``, clamped to ``[_MIN_STAGE, _MAX_STAGE]``.
        """
        assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action))
        # Index -1 picks the final entry along the history axis of the observation's stage array.
        # NOTE(review): relies on the wrapped env's private ``_last_observation`` being populated - presumably
        # this requires reset() to have been called first; confirm.
        current_stage = self.env._last_observation.stage[-1, 0, 0]
        return min(self._MAX_STAGE, max(self._MIN_STAGE, current_stage + action))

    def reset(self) -> PandemicObservation:
        """Reset the wrapped env.

        :return: The observation returned by the wrapped env's ``reset``.
        """
        # Propagate the observation; discarding it would make reset() return None to the caller.
        return self.env.reset()


# Alias for the 'PandemicGymEnv3Act' name listed in this module's __all__, so that the exported name resolves.
PandemicGymEnv3Act = ReducedActionPandemicGymEnv
13 changes: 10 additions & 3 deletions python/pandemic_simulator/environment/pandemic_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class PandemicSim:
_new_time_slot_interval: SimTimeInterval
_infection_update_interval: SimTimeInterval
_infection_threshold: int
_critical_threshold: int
_numpy_rng: np.random.RandomState

_type_to_locations: DefaultDict
Expand All @@ -57,7 +58,8 @@ def __init__(self,
new_time_slot_interval: SimTimeInterval = SimTimeInterval(day=1),
infection_update_interval: SimTimeInterval = SimTimeInterval(day=1),
person_routine_assignment: Optional[PersonRoutineAssignment] = None,
infection_threshold: int = 0):
infection_threshold: int = 0,
critical_threshold: int = 0):
"""
:param locations: A sequence of Location instances.
:param persons: A sequence of Person instances.
Expand Down Expand Up @@ -86,6 +88,7 @@ def __init__(self,
self._new_time_slot_interval = new_time_slot_interval
self._infection_update_interval = infection_update_interval
self._infection_threshold = infection_threshold
self._critical_threshold = critical_threshold

self._type_to_locations = defaultdict(list)
for loc in locations:
Expand All @@ -112,7 +115,8 @@ def __init__(self,
global_location_summary=self._registry.global_location_summary,
sim_time=SimTime(),
regulation_stage=0,
infection_above_threshold=False
infection_above_threshold=False,
critical_above_threshold=False
)

@classmethod
Expand Down Expand Up @@ -158,6 +162,7 @@ def from_config(cls: Type['PandemicSim'],
pandemic_testing=pandemic_testing,
contact_tracer=contact_tracer,
infection_threshold=sim_opts.infection_threshold,
critical_threshold=sim_opts.critical_threshold,
person_routine_assignment=sim_config.person_routine_assignment)

@property
Expand Down Expand Up @@ -313,7 +318,8 @@ def step(self) -> None:
self._state.global_infection_summary = global_infection_summary
self._state.infection_above_threshold = (self._state.global_testing_state.summary[InfectionSummary.INFECTED]
>= self._infection_threshold)

self._state.critical_above_threshold = (self._state.global_testing_state.summary[InfectionSummary.CRITICAL]
>= self._critical_threshold)
self._state.global_location_summary = self._registry.global_location_summary

if self._contact_tracer and self._new_time_slot_interval.trigger_at_interval(self._state.sim_time):
Expand Down Expand Up @@ -403,4 +409,5 @@ def reset(self) -> None:
sim_time=SimTime(),
regulation_stage=0,
infection_above_threshold=False,
critical_above_threshold=False
)
3 changes: 3 additions & 0 deletions python/pandemic_simulator/environment/simulator_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,6 @@ class PandemicSimOpts:

infection_threshold: int = 10
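"""A threshold used by the pandemic sim: infection_above_threshold is set to True when the tested INFECTED count is greater than or equal to this value"""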

critical_threshold: int = 10
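"""A threshold used by the pandemic sim: critical_above_threshold is set to True when the tested CRITICAL count is greater than or equal to this value"""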