From ad1bf0ebdf811d0f9ea741e0691d350a01da1733 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Fri, 17 Oct 2025 17:57:48 +0800 Subject: [PATCH 1/2] fix wandb one step lag --- chatlearn/runtime/engine.py | 5 +++-- chatlearn/schedule/metric_manager.py | 7 +++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/chatlearn/runtime/engine.py b/chatlearn/runtime/engine.py index 59100992..b9ff5fc3 100644 --- a/chatlearn/runtime/engine.py +++ b/chatlearn/runtime/engine.py @@ -133,7 +133,8 @@ def setup(self): logger.info( f"{LOG_START} setup_models summary {self.timers.log(names=['setup_models'])}") - def before_episode(self): + def before_episode(self, episode_id:int): + self.metric_manager.start(global_step=episode_id) for model in self.remote_models: future.get(model.before_episode()) @@ -493,7 +494,7 @@ def learn(self): if episode_id == 5: torch.cuda.cudart().cudaProfilerStop() self.timers("episode").start() - self.before_episode() + self.before_episode(episode_id + 1) logger.info(f"{LOG_START} start train episode_id: {episode_id + 1}/{self.runtime_args.num_episode}") if self.env.timers is None: self.env.set_timers(self.timers) diff --git a/chatlearn/schedule/metric_manager.py b/chatlearn/schedule/metric_manager.py index 324bd431..1a75009d 100644 --- a/chatlearn/schedule/metric_manager.py +++ b/chatlearn/schedule/metric_manager.py @@ -93,6 +93,13 @@ def log(self, prefix:str, global_step:int, scalar_dict): self._tensorboard_scalar_dict(prefix, global_step, scalar_dict) if writer_name == 'wandb': self._wandb_scalar_dict(prefix, global_step, scalar_dict) + + def start(self, global_step:int): + # For wandb logger, frontend will only render step n when step n+1 is logged. + # For wandb, log an empty dict at beginning of each episode. + for writer_name, _ in self.writer_dict.items(): + if writer_name == 'wandb': + self.wandb_writer.log({}, step=global_step) def _tensorboard_scalar_dict(self, prefix, global_step, scalar_dict): if isinstance(scalar_dict, (float, int)): From 61412a5c04d2bc9617b7aa0e26a4e0cde4ed99d0 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Fri, 17 Oct 2025 18:11:00 +0800 Subject: [PATCH 2/2] fix pylint --- chatlearn/schedule/metric_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlearn/schedule/metric_manager.py b/chatlearn/schedule/metric_manager.py index 1a75009d..5868566c 100644 --- a/chatlearn/schedule/metric_manager.py +++ b/chatlearn/schedule/metric_manager.py @@ -93,7 +93,7 @@ def log(self, prefix:str, global_step:int, scalar_dict): self._tensorboard_scalar_dict(prefix, global_step, scalar_dict) if writer_name == 'wandb': self._wandb_scalar_dict(prefix, global_step, scalar_dict) - + def start(self, global_step:int): # For wandb logger, frontend will only render step n when step n+1 is logged. # For wandb, log an empty dict at beginning of each episode.