Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions dictionary_learning/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tqdm import tqdm

import wandb
from torch.utils.tensorboard import SummaryWriter

from .dictionary import AutoEncoder
from .evaluation import evaluate
Expand Down Expand Up @@ -40,6 +41,7 @@ def log_stats(
transcoder: bool,
log_queues: list=[],
verbose: bool=False,
tb_writers: Optional[list[SummaryWriter]] = None,
):
with t.no_grad():
# quick hack to make sure all trainers get the same x
Expand Down Expand Up @@ -77,6 +79,12 @@ def log_stats(
value = value.cpu().item()
log[f"{name}"] = value

# TensorBoard logging
if tb_writers is not None:
for key, value in log.items():
if isinstance(value, (int, float)):
tb_writers[i].add_scalar(key, value, step)

if log_queues:
log_queues[i].put(log)

Expand Down Expand Up @@ -116,6 +124,8 @@ def trainSAE(
use_wandb:bool=False,
wandb_entity:str="",
wandb_project:str="",
use_tensorboard: bool = False,
tensorboard_logdir: Optional[str] = "tensorboard",
save_steps:Optional[list[int]]=None,
save_dir:Optional[str]=None,
log_steps:Optional[int]=None,
Expand Down Expand Up @@ -170,6 +180,16 @@ def trainSAE(
wandb_process.start()
wandb_processes.append(wandb_process)

if use_tensorboard:
tb_writers = []
for i, trainer in enumerate(trainers):
logdir = os.path.join(tensorboard_logdir, f"trainer_{i}")
os.makedirs(logdir, exist_ok=True)
writer = SummaryWriter(log_dir=logdir)
tb_writers.append(writer)
else:
tb_writers = None

# make save dirs, export config
if save_dir is not None:
save_dirs = [
Expand Down Expand Up @@ -207,9 +227,9 @@ def trainSAE(
break

# logging
- if (use_wandb or verbose) and step % log_steps == 0:
+ if (use_wandb or use_tensorboard or verbose) and step % log_steps == 0:
log_stats(
- trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues, verbose=verbose
+ trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues, verbose=verbose, tb_writers=tb_writers
)

# saving
Expand Down Expand Up @@ -271,3 +291,8 @@ def trainSAE(
queue.put("DONE")
for process in wandb_processes:
process.join()

# Close TensorBoard writers
if use_tensorboard:
for writer in tb_writers:
writer.close()