diff --git a/chatlearn/runtime/engine.py b/chatlearn/runtime/engine.py
index 59100992..b9ff5fc3 100644
--- a/chatlearn/runtime/engine.py
+++ b/chatlearn/runtime/engine.py
@@ -133,7 +133,8 @@ def setup(self):
         logger.info(
             f"{LOG_START} setup_models summary {self.timers.log(names=['setup_models'])}")
 
-    def before_episode(self):
+    def before_episode(self, episode_id: int):
+        self.metric_manager.start(global_step=episode_id)
         for model in self.remote_models:
             future.get(model.before_episode())
 
@@ -493,7 +494,7 @@ def learn(self):
                 if episode_id == 5:
                     torch.cuda.cudart().cudaProfilerStop()
             self.timers("episode").start()
-            self.before_episode()
+            self.before_episode(episode_id + 1)
             logger.info(f"{LOG_START} start train episode_id: {episode_id + 1}/{self.runtime_args.num_episode}")
             if self.env.timers is None:
                 self.env.set_timers(self.timers)
diff --git a/chatlearn/schedule/metric_manager.py b/chatlearn/schedule/metric_manager.py
index 324bd431..5868566c 100644
--- a/chatlearn/schedule/metric_manager.py
+++ b/chatlearn/schedule/metric_manager.py
@@ -94,6 +94,13 @@ def log(self, prefix:str, global_step:int, scalar_dict):
             if writer_name == 'wandb':
                 self._wandb_scalar_dict(prefix, global_step, scalar_dict)
 
+    def start(self, global_step: int):
+        # The wandb frontend only renders step n once step n+1 has been logged,
+        # so log an empty dict at the beginning of each episode.
+        for writer_name, _ in self.writer_dict.items():
+            if writer_name == 'wandb':
+                self.wandb_writer.log({}, step=global_step)
+
     def _tensorboard_scalar_dict(self, prefix, global_step, scalar_dict):
         if isinstance(scalar_dict, (float, int)):
             name = prefix