diff --git a/ombrl/utils/multiple_reward_wrapper.py b/ombrl/utils/multiple_reward_wrapper.py
index eeffeee..c7481d7 100644
--- a/ombrl/utils/multiple_reward_wrapper.py
+++ b/ombrl/utils/multiple_reward_wrapper.py
@@ -17,6 +17,7 @@ def __call__(self, observation, action, next_observation, reward):
     def _get_reward(self, observation, action, next_observation, reward):
         raise NotImplementedError(f'Reward function not set for reward_index: {self.reward_index}')
 
+RewardLike = Callable | RewardFunction
 
 class DmRewardFunction(RewardFunction):
     def __call__(self, observation, action, next_observation, reward, env: DMCEnv):
diff --git a/ombrl/utils/train_utils.py b/ombrl/utils/train_utils.py
index 4fc688a..e306ce9 100644
--- a/ombrl/utils/train_utils.py
+++ b/ombrl/utils/train_utils.py
@@ -11,7 +11,7 @@
 from jaxrl.datasets import ReplayBuffer
 from maxinforl_jax.datasets import NstepReplayBuffer
 from ombrl.utils.evaluation import evaluate
-from ombrl.utils.multiple_reward_wrapper import RewardFunction
+from ombrl.utils.multiple_reward_wrapper import RewardLike
 from ombrl.utils.wrappers import PendulumInitWrapper
 from ombrl.utils.env_utils import make_metaworld_env, make_humanoid_bench_env
 
@@ -28,7 +28,7 @@ def train(
     alg_kwargs: Dict,
     env_kwargs: Dict,
     seed: int = 0,
-    reward_list: List[RewardFunction] | RewardFunction | None = None,
+    reward_list: Optional[List[RewardLike] | RewardLike] = None,
     wandb_log: bool = True,
     log_config: Optional[Dict] = None,
     logs_dir: str = './logs',
@@ -118,7 +118,7 @@ def train(
             reward_list, **alg_kwargs)
     elif alg_name == 'sombrl':
         if reward_list is not None:
-            assert isinstance(reward_list, RewardFunction), "Only one reward function can be passed to SOMBRL"
+            assert not isinstance(reward_list, list), "Only one reward function can be passed to SOMBRL"
         agent = SOMBRLExplorerLearner(
             seed,
             env.observation_space.sample(),
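
A minimal usage sketch of what the widened annotation accepts, assuming `Callable` is already imported in `multiple_reward_wrapper.py` and that `train` takes no required parameters beyond those visible in the hunks; the reward callable and all argument values below are hypothetical:

```python
# Hypothetical sketch (not part of the patch): with RewardLike defined as
# Callable | RewardFunction, a bare callable now satisfies the reward_list
# annotation where the old List[RewardFunction] | RewardFunction did not.
from ombrl.utils.train_utils import train


def shaped_reward(observation, action, next_observation, reward):
    # A plain callable reward; any real shaping term would go here.
    return reward


train(
    alg_name='sombrl',          # SOMBRL takes a single reward, not a list,
    alg_kwargs={},              # so the new `not isinstance(reward_list, list)`
    env_kwargs={},              # assert passes for this callable but would
    reward_list=shaped_reward,  # fail for reward_list=[shaped_reward].
)
```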