diff --git a/dictionary_learning/training.py b/dictionary_learning/training.py
index 0671f31..5bea4ad 100644
--- a/dictionary_learning/training.py
+++ b/dictionary_learning/training.py
@@ -13,6 +13,7 @@ from tqdm import tqdm
 import wandb
 
+from torch.utils.tensorboard import SummaryWriter
 
 from .dictionary import AutoEncoder
 from .evaluation import evaluate
@@ -40,6 +41,7 @@ def log_stats(
     transcoder: bool,
     log_queues: list=[],
     verbose: bool=False,
+    tb_writers: Optional[list[SummaryWriter]] = None,
 ):
     with t.no_grad():
         # quick hack to make sure all trainers get the same x
@@ -77,6 +79,12 @@ def log_stats(
                     value = value.cpu().item()
                 log[f"{name}"] = value
 
+            # TensorBoard logging
+            if tb_writers is not None:
+                for key, value in log.items():
+                    if isinstance(value, (int, float)):
+                        tb_writers[i].add_scalar(key, value, step)
+
             if log_queues:
                 log_queues[i].put(log)
 
@@ -116,6 +124,8 @@ def trainSAE(
     use_wandb:bool=False,
     wandb_entity:str="",
     wandb_project:str="",
+    use_tensorboard: bool = False,
+    tensorboard_logdir: Optional[str] = "tensorboard",
     save_steps:Optional[list[int]]=None,
     save_dir:Optional[str]=None,
     log_steps:Optional[int]=None,
@@ -170,6 +180,16 @@ def trainSAE(
         wandb_process.start()
         wandb_processes.append(wandb_process)
 
+    if use_tensorboard:
+        tb_writers = []
+        for i, trainer in enumerate(trainers):
+            logdir = os.path.join(tensorboard_logdir, f"trainer_{i}")
+            os.makedirs(logdir, exist_ok=True)
+            writer = SummaryWriter(log_dir=logdir)
+            tb_writers.append(writer)
+    else:
+        tb_writers = None
+
     # make save dirs, export config
     if save_dir is not None:
         save_dirs = [
@@ -207,9 +227,9 @@
             break
 
         # logging
-        if (use_wandb or verbose) and step % log_steps == 0:
+        if (use_wandb or use_tensorboard or verbose) and step % log_steps == 0:
             log_stats(
-                trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues, verbose=verbose
+                trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues, verbose=verbose, tb_writers=tb_writers
             )
 
         # saving
@@ -271,3 +291,8 @@ def trainSAE(
         queue.put("DONE")
     for process in wandb_processes:
         process.join()
+
+    # Close TensorBoard writers
+    if use_tensorboard:
+        for writer in tb_writers:
+            writer.close()