diff --git a/zero_ad_rl/env/base.py b/zero_ad_rl/env/base.py index c2ac7bf..e651e97 100644 --- a/zero_ad_rl/env/base.py +++ b/zero_ad_rl/env/base.py @@ -104,4 +104,18 @@ def observation(self, state): return self.states.from_json(state) +class ZeroADEnvRLlib(ZeroADEnv): + def __init__(self, config): + # TODO: can we set the action builder in the config? + server_address = self.address(config.worker_index) + super.__init__(self, server_address, None, action_builder, state_builder) + + def address(self, worker_index): + port = 6000 + worker_index + return f'http://127.0.0.1:{port}' + + @property + def scenario_config(self): + return self.config.get('scenario_config') + BaseZeroADEnv = ZeroADEnv