From ce301d84955f9d47679add3db29a52278452799a Mon Sep 17 00:00:00 2001 From: alfred100p Date: Wed, 6 Jul 2022 15:29:26 -0500 Subject: [PATCH 1/3] Add ActionWrapper and Critical Flag --- .../interfaces/pandemic_observation.py | 5 +++ .../environment/interfaces/sim_state.py | 4 +++ .../environment/pandemic_env.py | 36 ++++++++++++++++++- .../environment/pandemic_sim.py | 13 +++++-- .../environment/simulator_opts.py | 3 ++ 5 files changed, 57 insertions(+), 4 deletions(-) diff --git a/python/pandemic_simulator/environment/interfaces/pandemic_observation.py b/python/pandemic_simulator/environment/interfaces/pandemic_observation.py index 8d4f21c..37f5643 100644 --- a/python/pandemic_simulator/environment/interfaces/pandemic_observation.py +++ b/python/pandemic_simulator/environment/interfaces/pandemic_observation.py @@ -1,5 +1,6 @@ # Confidential, Copyright 2020, Sony Corporation of America, All rights reserved. from dataclasses import dataclass +from logging import critical from typing import Sequence, Type, cast, Optional import numpy as np @@ -21,6 +22,7 @@ class PandemicObservation: global_testing_summary: np.ndarray stage: np.ndarray infection_above_threshold: np.ndarray + critical_above_threshold:np.ndarray time_day: np.ndarray unlocked_non_essential_business_locations: Optional[np.ndarray] = None @@ -40,6 +42,7 @@ def create_empty(cls: Type['PandemicObservation'], global_testing_summary=np.zeros((history_size, 1, len(InfectionSummary))), stage=np.zeros((history_size, 1, 1)), infection_above_threshold=np.zeros((history_size, 1, 1)), + critical_above_threshold=np.zeros((history_size, 1, 1)), time_day=np.zeros((history_size, 1, 1)), unlocked_non_essential_business_locations=np.zeros((history_size, 1, num_non_essential_business)) @@ -73,6 +76,8 @@ def update_obs_with_sim_state(self, sim_state: PandemicSimState, self.infection_above_threshold[hist_index, 0] = int(sim_state.infection_above_threshold) + self.critical_above_threshold[hist_index, 0] = int(sim_state.critical_above_threshold) + self.time_day[hist_index, 0] = int(sim_state.sim_time.day) @property diff --git a/python/pandemic_simulator/environment/interfaces/sim_state.py b/python/pandemic_simulator/environment/interfaces/sim_state.py index df85adf..feefabf 100644 --- a/python/pandemic_simulator/environment/interfaces/sim_state.py +++ b/python/pandemic_simulator/environment/interfaces/sim_state.py @@ -41,6 +41,10 @@ class PandemicSimState: """A boolean that is set to True if the infection goes above a set threshold. The threshold is set in the pandemic sim""" + critical_above_threshold: bool + """A boolean that is set to True if the infection (CRITICAL) goes above a set threshold. The threshold is set in the pandemic + sim""" + regulation_stage: int """The last executed regulation stage""" diff --git a/python/pandemic_simulator/environment/pandemic_env.py b/python/pandemic_simulator/environment/pandemic_env.py index bbc6e4c..cebf118 100644 --- a/python/pandemic_simulator/environment/pandemic_env.py +++ b/python/pandemic_simulator/environment/pandemic_env.py @@ -11,7 +11,7 @@ from .simulator_config import PandemicSimConfig from .simulator_opts import PandemicSimOpts -__all__ = ['PandemicGymEnv'] +__all__ = ['PandemicGymEnv','PandemicGymEnv3Act'] class PandemicGymEnv(gym.Env): @@ -169,3 +169,37 @@ def reset(self) -> PandemicObservation: def render(self, mode: str = 'human') -> bool: pass + +class PandemicGymEnv3Act(gym.ActionWrapper): + def __init__(self, env: PandemicGymEnv): + super().__init__(env) + self.env = env + + self.action_space = gym.spaces.Discrete(3, start=-1) + + @classmethod + def from_config(self, + sim_config: PandemicSimConfig, + pandemic_regulations: Sequence[PandemicRegulation], + sim_opts: PandemicSimOpts = PandemicSimOpts(), + reward_fn: Optional[RewardFunction] = None, + done_fn: Optional[DoneFunction] = None, + ) -> 'PandemicGymEnv3Act': + env = PandemicGymEnv.from_config(sim_config = sim_config, + pandemic_regulations=pandemic_regulations, + sim_opts = sim_opts, + reward_fn=reward_fn, + done_fn=done_fn, + ) + + return PandemicGymEnv3Act(env=env) + + def step(self, action): + return self.env.step(int(self.action(action))) + + def action(self, action): + assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action)) + return min(4, max(0, self.env._last_observation.stage[-1, 0, 0] + action)) + + def reset(self): + self.env.reset() \ No newline at end of file diff --git a/python/pandemic_simulator/environment/pandemic_sim.py b/python/pandemic_simulator/environment/pandemic_sim.py index 929daee..870f787 100644 --- a/python/pandemic_simulator/environment/pandemic_sim.py +++ b/python/pandemic_simulator/environment/pandemic_sim.py @@ -41,6 +41,7 @@ class PandemicSim: _new_time_slot_interval: SimTimeInterval _infection_update_interval: SimTimeInterval _infection_threshold: int + _critical_threshold: int _numpy_rng: np.random.RandomState _type_to_locations: DefaultDict @@ -57,7 +58,8 @@ def __init__(self, new_time_slot_interval: SimTimeInterval = SimTimeInterval(day=1), infection_update_interval: SimTimeInterval = SimTimeInterval(day=1), person_routine_assignment: Optional[PersonRoutineAssignment] = None, - infection_threshold: int = 0): + infection_threshold: int = 0, + critical_threshold: int = 0): """ :param locations: A sequence of Location instances. :param persons: A sequence of Person instances. @@ -86,6 +88,7 @@ def __init__(self, self._new_time_slot_interval = new_time_slot_interval self._infection_update_interval = infection_update_interval self._infection_threshold = infection_threshold + self._critical_threshold = critical_threshold self._type_to_locations = defaultdict(list) for loc in locations: @@ -112,7 +115,8 @@ def __init__(self, global_location_summary=self._registry.global_location_summary, sim_time=SimTime(), regulation_stage=0, - infection_above_threshold=False + infection_above_threshold=False, + critical_above_threshold=False ) @classmethod @@ -158,6 +162,7 @@ def from_config(cls: Type['PandemicSim'], pandemic_testing=pandemic_testing, contact_tracer=contact_tracer, infection_threshold=sim_opts.infection_threshold, + critical_threshold=sim_opts.critical_threshold, person_routine_assignment=sim_config.person_routine_assignment) @property @@ -313,7 +318,8 @@ def step(self) -> None: self._state.global_infection_summary = global_infection_summary self._state.infection_above_threshold = (self._state.global_testing_state.summary[InfectionSummary.INFECTED] >= self._infection_threshold) - + self._state.critical_above_threshold = (self._state.global_testing_state.summary[InfectionSummary.CRITICAL] + >= self._critical_threshold) self._state.global_location_summary = self._registry.global_location_summary if self._contact_tracer and self._new_time_slot_interval.trigger_at_interval(self._state.sim_time): @@ -403,4 +409,5 @@ def reset(self) -> None: sim_time=SimTime(), regulation_stage=0, infection_above_threshold=False, + critical_above_threshold=False ) diff --git a/python/pandemic_simulator/environment/simulator_opts.py b/python/pandemic_simulator/environment/simulator_opts.py index de74636..fc00b08 100644 --- a/python/pandemic_simulator/environment/simulator_opts.py +++ b/python/pandemic_simulator/environment/simulator_opts.py @@ -43,3 +43,6 @@ class PandemicSimOpts: infection_threshold: int = 10 """A threshold used by """ + + critical_threshold: int = 10 + """A threshold used by """ From f61c8c8f750ee268f4d9dc3a18e204546a6edcd6 Mon Sep 17 00:00:00 2001 From: "Alfred W. Jacob" <82844187+alfred100p@users.noreply.github.com> Date: Wed, 6 Jul 2022 15:33:56 -0500 Subject: [PATCH 2/3] cleanup --- .../environment/interfaces/pandemic_observation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pandemic_simulator/environment/interfaces/pandemic_observation.py b/python/pandemic_simulator/environment/interfaces/pandemic_observation.py index 37f5643..91ad274 100644 --- a/python/pandemic_simulator/environment/interfaces/pandemic_observation.py +++ b/python/pandemic_simulator/environment/interfaces/pandemic_observation.py @@ -1,6 +1,5 @@ # Confidential, Copyright 2020, Sony Corporation of America, All rights reserved. from dataclasses import dataclass -from logging import critical from typing import Sequence, Type, cast, Optional import numpy as np From 28c57eb5296732f73aeaf93662f7a49dc59fac3f Mon Sep 17 00:00:00 2001 From: "Alfred W. Jacob" <82844187+alfred100p@users.noreply.github.com> Date: Wed, 10 Aug 2022 11:13:09 -0500 Subject: [PATCH 3/3] Rename Action Wrapper --- python/pandemic_simulator/environment/pandemic_env.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pandemic_simulator/environment/pandemic_env.py b/python/pandemic_simulator/environment/pandemic_env.py index cebf118..a3c0ec1 100644 --- a/python/pandemic_simulator/environment/pandemic_env.py +++ b/python/pandemic_simulator/environment/pandemic_env.py @@ -170,7 +170,7 @@ def reset(self) -> PandemicObservation: def render(self, mode: str = 'human') -> bool: pass -class PandemicGymEnv3Act(gym.ActionWrapper): +class ReducedActionPandemicGymEnv(gym.ActionWrapper): def __init__(self, env: PandemicGymEnv): super().__init__(env) self.env = env @@ -192,7 +192,7 @@ def from_config(self, done_fn=done_fn, ) - return PandemicGymEnv3Act(env=env) + return ReducedActionPandemicGymEnv(env=env) def step(self, action): return self.env.step(int(self.action(action))) @@ -202,4 +202,4 @@ def action(self, action): return min(4, max(0, self.env._last_observation.stage[-1, 0, 0] + action)) def reset(self): - self.env.reset() \ No newline at end of file + self.env.reset()