From edf53942cde16397511ee273736a59d1d0f7d482 Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Tue, 6 May 2025 16:00:01 -0400 Subject: [PATCH 01/11] copying train_clt_local.py so we can see git diff easier --- tune_clt_local_ray.py | 401 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 401 insertions(+) create mode 100644 tune_clt_local_ray.py diff --git a/tune_clt_local_ray.py b/tune_clt_local_ray.py new file mode 100644 index 0000000..33ca727 --- /dev/null +++ b/tune_clt_local_ray.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python3 +""" +Script to train a Cross-Layer Transcoder (CLT) using pre-generated local activations. +Handles configuration parsing from command-line arguments and initiates training. +""" + +import argparse +import torch +from pathlib import Path +from typing import Literal, Optional +import logging +import time +import json + +# Attempt to import transformers for model dimension detection +try: + from transformers import AutoConfig +except ImportError: + AutoConfig = None + +# Import necessary CLT components +try: + from clt.config import CLTConfig, TrainingConfig + from clt.training.trainer import CLTTrainer +except ImportError as e: + print( + f"FATAL: ImportError: {e}. Please ensure the 'clt' library is installed or " + "the project root is in your PYTHONPATH." + ) + raise + +# Setup basic logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +def get_model_dimensions(model_name: str) -> tuple[int, int]: + """Attempt to dynamically get num_layers and d_model from model_name.""" + if AutoConfig is None: + logger.warning( + "Transformers library not found. Cannot dynamically detect model dimensions." + " Falling back to gpt2 defaults (12 layers, 768 hidden size)." + " Install transformers (`pip install transformers`) for auto-detection." 
+ ) + return 12, 768 # Default to gpt2 small + + try: + config = AutoConfig.from_pretrained(model_name) + num_layers = getattr(config, "num_hidden_layers", None) or getattr(config, "n_layer", None) + d_model = getattr(config, "hidden_size", None) or getattr(config, "n_embd", None) + + if num_layers is None or d_model is None: + raise ValueError(f"Could not automatically determine num_layers or d_model for {model_name}") + logger.info(f"Detected model dimensions for {model_name}: {num_layers} layers, {d_model} hidden size.") + return num_layers, d_model + except Exception as e: + logger.warning( + f"Failed to get model dimensions for {model_name}: {e}. " + f"Falling back to gpt2 defaults (12 layers, 768 hidden size)." + ) + return 12, 768 + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Train a Cross-Layer Transcoder (CLT) using pre-generated local activations.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # --- Core Training Parameters --- + core_group = parser.add_argument_group("Core Training Parameters") + core_group.add_argument( + "--activation-path", + type=str, + required=True, + help="Path to the directory containing pre-generated activations (including index.bin, metadata.json, etc.).", + ) + core_group.add_argument( + "--output-dir", + type=str, + default=f"clt_train_local_{int(time.time())}", + help="Directory to save logs, checkpoints, and final model.", + ) + core_group.add_argument( + "--model-name", + type=str, + required=True, + help="Base model name or path (e.g., 'gpt2', 'gpt2-medium'). Must match the model used for activation generation.", + ) + core_group.add_argument( + "--device", + type=str, + default=None, + help="Device to use (e.g., 'cuda', 'cpu', 'mps'). 
Auto-detected if None.", + ) + core_group.add_argument( + "--distributed", + action="store_true", + help="Enable distributed training (requires torchrun/appropriate launcher).", + ) + + # --- CLT Model Architecture --- + clt_group = parser.add_argument_group("CLT Model Architecture (CLTConfig)") + clt_group.add_argument( + "--num-features", + type=int, + required=True, + help="Number of features per layer in the CLT.", + ) + # num_layers and d_model are derived from the base model + clt_group.add_argument( + "--activation-fn", + type=str, + choices=["jumprelu", "relu"], + default="jumprelu", + help="Activation function for the CLT.", + ) + clt_group.add_argument( + "--jumprelu-threshold", + type=float, + default=0.03, + help="Threshold for JumpReLU activation (if used).", + ) + clt_group.add_argument( + "--clt-dtype", + type=str, + default=None, + help="Optional data type for the CLT model parameters (e.g., 'float16', 'bfloat16').", + ) + + # --- Training Hyperparameters --- + train_group = parser.add_argument_group("Training Hyperparameters (TrainingConfig)") + train_group.add_argument("--learning-rate", type=float, default=3e-4, help="Optimizer learning rate.") + train_group.add_argument( + "--training-steps", + type=int, + default=50000, + help="Total number of training steps.", + ) + train_group.add_argument( + "--train-batch-size-tokens", + type=int, + default=4096, + help="Target number of tokens per training batch.", + ) + train_group.add_argument( + "--normalization-method", + type=str, + choices=["auto", "none"], + default="auto", + help=( + "Normalization for activation store. 'auto' uses pre-calculated stats " + "(norm_stats.json) from the activation_path. 'none' disables normalization." 
+ ), + ) + train_group.add_argument( + "--sparsity-lambda", + type=float, + default=1e-3, + help="Coefficient for the L1 sparsity penalty.", + ) + train_group.add_argument( + "--sparsity-c", + type=float, + default=1.0, + help="Constant shaping the sparsity penalty (typically 1.0).", + ) + train_group.add_argument( + "--preactivation-coef", + type=float, + default=3e-6, + help="Coefficient for the pre-activation MSE loss term.", + ) + train_group.add_argument( + "--optimizer", + type=str, + choices=["adam", "adamw"], + default="adamw", + help="Optimizer algorithm.", + ) + train_group.add_argument( + "--optimizer-beta1", + type=float, + default=None, + help="Optimizer beta1 value (if using Adam/AdamW).", + ) + train_group.add_argument( + "--optimizer-beta2", + type=float, + default=None, + help="Optimizer beta2 value (if using Adam/AdamW).", + ) + train_group.add_argument( + "--lr-scheduler", + type=str, + choices=["linear", "cosine", "linear_final20", "none"], + default="linear", + help=( + "Learning rate scheduler type. 'linear_final20' keeps LR constant until the last 20% " + "of steps then decays linearly to 0 ('none' to disable)." + ), + ) + train_group.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility.", + ) + train_group.add_argument( + "--activation-dtype", + type=str, + default="float32", + help="Data type to load activations as (e.g., 'float32', 'bfloat16'). 
Should match storage or be compatible.", + ) + train_group.add_argument( + "--dead-feature-window", + type=int, + default=1000, + help="Number of steps of inactivity before a feature is considered 'dead' for evaluation.", + ) + + # --- Sampling Strategy --- Added Group + sampling_group = parser.add_argument_group("Sampling Strategy (TrainingConfig)") + sampling_group.add_argument( + "--sampling-strategy", + type=str, + choices=["sequential", "random_chunk"], + default="sequential", + help="Sampling strategy: 'sequential' processes chunks in order per epoch, 'random_chunk' picks a random valid chunk each step.", + ) + + # --- Logging & Checkpointing --- + log_group = parser.add_argument_group("Logging & Checkpointing (TrainingConfig)") + log_group.add_argument( + "--log-interval", + type=int, + default=100, + help="Log training metrics every N steps.", + ) + log_group.add_argument( + "--eval-interval", + type=int, + default=1000, + help="Run evaluation metrics computation every N steps.", + ) + log_group.add_argument( + "--checkpoint-interval", + type=int, + default=1000, + help="Save a training checkpoint every N steps.", + ) + # WandB arguments + log_group.add_argument("--enable-wandb", action="store_true", help="Enable Weights & Biases logging.") + log_group.add_argument("--wandb-project", type=str, default=None, help="WandB project name.") + log_group.add_argument( + "--wandb-entity", + type=str, + default=None, + help="WandB entity (username or team).", + ) + log_group.add_argument( + "--wandb-run-name", + type=str, + default=None, + help="Custom name for the WandB run (defaults to a timestamp).", + ) + log_group.add_argument("--wandb-tags", nargs="+", default=None, help="List of tags for the WandB run.") + + args = parser.parse_args() + + # --- Validation --- + # Simplified validation: activation_path is required by argparse + # No need to check for generation args + + return args + + +def main(): + """Main function to configure and run the CLTTrainer for 
local activations.""" + args = parse_args() + + # --- Setup Output Directory --- + output_dir = Path(args.output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + logger.info(f"Output directory: {output_dir.resolve()}") + + # Save command-line arguments + try: + with open(output_dir / "cli_args.json", "w") as f: + json.dump(vars(args), f, indent=2) + except Exception as e: + logger.warning(f"Could not save command-line args: {e}") + + # --- Determine Device --- + if args.device: + device = args.device + else: + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + logger.info(f"Using device: {device}") + + # --- Determine Base Model Dimensions --- + # Use the provided --model-name to get dimensions for the CLT config. + # This ensures the CLT matches the architecture activations were generated from. + base_model_name = args.model_name + num_layers, d_model = get_model_dimensions(base_model_name) + if num_layers is None or d_model is None: + # Added error handling if dimensions couldn't be determined + logger.error(f"Could not determine dimensions for model '{base_model_name}'. 
Exiting.") + return + + # --- Create CLT Configuration --- + clt_config = CLTConfig( + num_features=args.num_features, + num_layers=num_layers, + d_model=d_model, + activation_fn=args.activation_fn, + jumprelu_threshold=args.jumprelu_threshold, + clt_dtype=args.clt_dtype, + ) + logger.info(f"CLT Config: {clt_config}") + + # --- Create Training Configuration --- + # Handle 'none' scheduler case + lr_scheduler_arg: Optional[Literal["linear", "cosine", "linear_final20"]] = ( + args.lr_scheduler if args.lr_scheduler != "none" else None + ) + + # Simplified TrainingConfig instantiation for local source only + training_config = TrainingConfig( + # Core Training + learning_rate=args.learning_rate, + training_steps=args.training_steps, + seed=args.seed, + train_batch_size_tokens=args.train_batch_size_tokens, + # Activation Source (hardcoded to local_manifest) + activation_source="local_manifest", + activation_path=args.activation_path, + activation_dtype=args.activation_dtype, + # Normalization + normalization_method=args.normalization_method, + # Sampling Strategy + sampling_strategy=args.sampling_strategy, + # Loss Coeffs + sparsity_lambda=args.sparsity_lambda, + sparsity_c=args.sparsity_c, + preactivation_coef=args.preactivation_coef, + # Optimizer & Scheduler + optimizer=args.optimizer, + optimizer_beta1=args.optimizer_beta1, + optimizer_beta2=args.optimizer_beta2, + lr_scheduler=lr_scheduler_arg, + # Logging & Checkpointing + log_interval=args.log_interval, + eval_interval=args.eval_interval, + checkpoint_interval=args.checkpoint_interval, + # Dead Features + dead_feature_window=args.dead_feature_window, + # WandB + enable_wandb=args.enable_wandb, + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, + wandb_run_name=args.wandb_run_name, + wandb_tags=args.wandb_tags, + # Remote config is not handled by this script + remote_config=None, + ) + logger.info(f"Training Config: {training_config}") + + # --- Initialize Trainer --- + logger.info("Initializing 
CLTTrainer...") + try: + trainer = CLTTrainer( + clt_config=clt_config, + training_config=training_config, + log_dir=str(output_dir), + device=device, + distributed=args.distributed, + ) + except Exception as e: + logger.exception(f"Failed to initialize CLTTrainer: {e}") # Use logger.exception + raise + + # --- Start Training --- + logger.info("Starting training from local activations...") + try: + trainer.train() # eval_every is handled internally now + logger.info("Training complete!") + logger.info(f"Final model and logs saved in: {output_dir.resolve()}") + except Exception as e: + logger.exception(f"Training failed: {e}") # Use logger.exception + raise + + +if __name__ == "__main__": + main() From 05e9ebe42ecdb03f69ccaf18f121c5b503cd2c12 Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Tue, 6 May 2025 16:00:42 -0400 Subject: [PATCH 02/11] adding ray tune changes --- tune_clt_local_ray.py | 56 +++++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/tune_clt_local_ray.py b/tune_clt_local_ray.py index 33ca727..157d269 100644 --- a/tune_clt_local_ray.py +++ b/tune_clt_local_ray.py @@ -11,6 +11,7 @@ import logging import time import json +import ray, ray.tune # Attempt to import transformers for model dimension detection try: @@ -61,6 +62,26 @@ def get_model_dimensions(model_name: str) -> tuple[int, int]: return 12, 768 +def train_loop_per_worker(cfg): + + if 'sparsity_c' in cfg: + cfg['training_config'].sparsity_c = cfg['sparsity_c'] + if 'sparsity_lambda' in cfg: + cfg['training_config'].sparsity_lambda = cfg['sparsity_lambda'] + + + trainer = CLTTrainer( + clt_config=cfg['clt_config'], + training_config=cfg['training_config'], + log_dir=cfg['log_dir'], + device=cfg['device'], + distributed=False, # Explicitly set for tutorial + use_ray='tune' + ) + + trained_clt_model = trainer.train(eval_every=cfg['training_config'].eval_interval) + + def parse_args(): """Parse command-line arguments.""" parser = 
argparse.ArgumentParser( @@ -372,28 +393,27 @@ def main(): ) logger.info(f"Training Config: {training_config}") - # --- Initialize Trainer --- - logger.info("Initializing CLTTrainer...") - try: - trainer = CLTTrainer( - clt_config=clt_config, - training_config=training_config, - log_dir=str(output_dir), - device=device, - distributed=args.distributed, - ) - except Exception as e: - logger.exception(f"Failed to initialize CLTTrainer: {e}") # Use logger.exception - raise + hps = { + 'sparsity_c': ray.tune.grid_search([0.01, 0.03, 0.09, 0.27, 0.81, 2.43]), + 'sparsity_lambda': ray.tune.grid_search([1e-5, 3e-5, 9e-5, 2.7e-4, 8.1e-4, 2.43e-3]) + } + + n_gpus = 1 + n_parallel_workers = 8 - # --- Start Training --- - logger.info("Starting training from local activations...") + # --- Start Tuning --- + logger.info("Starting hyperparameter tuning from local activations...") try: - trainer.train() # eval_every is handled internally now - logger.info("Training complete!") + tuner = ray.tune.Tuner( + ray.tune.with_resources(train_loop_per_worker, {'gpu': n_gpus / n_parallel_workers}), + param_space=hps, + tune_config=ray.tune.TuneConfig(max_concurrent_trials=n_parallel_workers) + ) + results = tuner.fit() + logger.info("Tuning complete!") logger.info(f"Final model and logs saved in: {output_dir.resolve()}") except Exception as e: - logger.exception(f"Training failed: {e}") # Use logger.exception + logger.exception(f"Tuning failed: {e}") # Use logger.exception raise From da015dc7d3ed8c8ad80f85b70a8a09eea8afff2d Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Tue, 6 May 2025 16:01:15 -0400 Subject: [PATCH 03/11] adding option to CLTTrainer for ray reporting for use with ray train/tune --- clt/training/trainer.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/clt/training/trainer.py b/clt/training/trainer.py index 0b9b760..ec877d3 100644 --- a/clt/training/trainer.py +++ b/clt/training/trainer.py @@ -1,6 +1,6 @@ import torch import 
torch.optim as optim -from typing import Dict, Optional, Union, Any +from typing import Dict, Literal, Optional, Union, Any from tqdm import tqdm # type: ignore import os import json @@ -34,6 +34,8 @@ ) # Keep for StreamingStore usage from .evaluator import CLTEvaluator # Import the new evaluator +import ray + # Get logger for this module logger = logging.getLogger(__name__) @@ -286,6 +288,7 @@ def __init__( log_dir: Optional[str] = None, device: Optional[Union[str, torch.device]] = None, distributed: bool = False, # Add distributed flag + use_ray: Optional[Literal['train', 'tune']] = None, ): """Initialize the CLT trainer. @@ -483,6 +486,19 @@ def lr_lambda(current_step: int): # Dummy logger for non-rank-0 processes self.wandb_logger = DummyWandBLogger() + + if use_ray: + if use_ray == 'train': + from ray.train import report + elif use_ray == 'tune': + from ray.tune import report + + self.ray_reporter = report + + else: + self.ray_reporter = None + + @property def dead_neurons_mask(self) -> torch.Tensor: """Boolean mask indicating dead neurons based on inactivity window.""" @@ -1251,6 +1267,12 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: if save_checkpoint_flag: self._save_checkpoint(step) + if self.ray_reporter: + ray_metrics = loss_dict + if run_eval_flag: + ray_metrics = ray_metrics | eval_metrics + self.ray_reporter(ray_metrics) + # --- Explicitly delete tensors at the very end of the loop iteration --- # # Do this on all ranks try: From d3fe55e0444978abd2e1a29d18c71ced71a92891 Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Wed, 7 May 2025 14:17:08 -0400 Subject: [PATCH 04/11] passing configs to worker --- tune_clt_local_ray.py | 40 +++++++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/tune_clt_local_ray.py b/tune_clt_local_ray.py index 157d269..7d30a63 100644 --- a/tune_clt_local_ray.py +++ b/tune_clt_local_ray.py @@ -79,7 +79,7 @@ def train_loop_per_worker(cfg): use_ray='tune' ) - 
trained_clt_model = trainer.train(eval_every=cfg['training_config'].eval_interval) + # trained_clt_model = trainer.train(eval_every=cfg['training_config'].eval_interval) def parse_args(): @@ -241,6 +241,20 @@ def parse_args(): help="Number of steps of inactivity before a feature is considered 'dead' for evaluation.", ) + tune_group = parser.add_argument_group("Hyperparameter Tuning Parameters") + tune_group.add_argument( + "--n-gpus", + type=int, + default=1, + help="Number of GPUs available for hyperparameter tuning." + ) + tune_group.add_argument( + "--n-workers", + type=int, + default=1, + help="Number of trainers to run in parallel for hyperparameter tuning." + ) + # --- Sampling Strategy --- Added Group sampling_group = parser.add_argument_group("Sampling Strategy (TrainingConfig)") sampling_group.add_argument( @@ -393,21 +407,29 @@ def main(): ) logger.info(f"Training Config: {training_config}") + + # --- Start Tuning --- + # configs to pass to workers hps = { - 'sparsity_c': ray.tune.grid_search([0.01, 0.03, 0.09, 0.27, 0.81, 2.43]), - 'sparsity_lambda': ray.tune.grid_search([1e-5, 3e-5, 9e-5, 2.7e-4, 8.1e-4, 2.43e-3]) + 'training_config': training_config, + 'clt_config': clt_config, + 'log_dir': str(output_dir), + 'device': device + } + # add actual hyperparameters + hps |= { + # 'sparsity_c': ray.tune.grid_search([0.01, 0.03, 0.09, 0.27, 0.81, 2.43]), + # 'sparsity_lambda': ray.tune.grid_search([1e-5, 3e-5, 9e-5, 2.7e-4, 8.1e-4, 2.43e-3]), + 'sparsity_c': ray.tune.grid_search([0.01]), + 'sparsity_lambda': ray.tune.grid_search([1e-5]), } - n_gpus = 1 - n_parallel_workers = 8 - - # --- Start Tuning --- logger.info("Starting hyperparameter tuning from local activations...") try: tuner = ray.tune.Tuner( - ray.tune.with_resources(train_loop_per_worker, {'gpu': n_gpus / n_parallel_workers}), + ray.tune.with_resources(train_loop_per_worker, {'gpu': args.n_gpus / args.n_workers}), param_space=hps, - 
tune_config=ray.tune.TuneConfig(max_concurrent_trials=n_parallel_workers) + tune_config=ray.tune.TuneConfig(max_concurrent_trials=args.n_workers) ) results = tuner.fit() logger.info("Tuning complete!") From ca9dce7e49fdeee9d7579aa2811d50beb4ca6f54 Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Wed, 7 May 2025 14:31:16 -0400 Subject: [PATCH 05/11] moved configs into train_loop_per_worker func --- tune_clt_local_ray.py | 181 ++++++++++++++++++++---------------------- 1 file changed, 87 insertions(+), 94 deletions(-) diff --git a/tune_clt_local_ray.py b/tune_clt_local_ray.py index 7d30a63..4dc0955 100644 --- a/tune_clt_local_ray.py +++ b/tune_clt_local_ray.py @@ -64,22 +64,96 @@ def get_model_dimensions(model_name: str) -> tuple[int, int]: def train_loop_per_worker(cfg): - if 'sparsity_c' in cfg: - cfg['training_config'].sparsity_c = cfg['sparsity_c'] - if 'sparsity_lambda' in cfg: - cfg['training_config'].sparsity_lambda = cfg['sparsity_lambda'] + # --- Determine Device --- + if cfg['device']: + device = cfg['device'] + else: + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + logger.info(f"Using device: {device}") + + # --- Determine Base Model Dimensions --- + # Use the provided --model-name to get dimensions for the CLT config. + # This ensures the CLT matches the architecture activations were generated from. + base_model_name = cfg['model_name'] + num_layers, d_model = get_model_dimensions(base_model_name) + if num_layers is None or d_model is None: + # Added error handling if dimensions couldn't be determined + logger.error(f"Could not determine dimensions for model '{base_model_name}'. 
Exiting.") + return + + # --- Create CLT Configuration --- + clt_config = CLTConfig( + num_features=cfg['num_features'], + num_layers=num_layers, + d_model=d_model, + activation_fn=cfg['activation_fn'], + jumprelu_threshold=cfg['jumprelu_threshold'], + clt_dtype=cfg['clt_dtype'], + ) + logger.info(f"CLT Config: {clt_config}") + + # --- Create Training Configuration --- + # Handle 'none' scheduler case + lr_scheduler_arg: Optional[Literal["linear", "cosine", "linear_final20"]] = ( + cfg['lr_scheduler'] if cfg['lr_scheduler'] != "none" else None + ) + # Simplified TrainingConfig instantiation for local source only + training_config = TrainingConfig( + # Core Training + learning_rate=cfg['learning_rate'], + training_steps=cfg['training_steps'], + seed=cfg['seed'], + train_batch_size_tokens=cfg['train_batch_size_tokens'], + # Activation Source (hardcoded to local_manifest) + activation_source="local_manifest", + activation_path=cfg['activation_path'], + activation_dtype=cfg['activation_dtype'], + # Normalization + normalization_method=cfg['normalization_method'], + # Sampling Strategy + sampling_strategy=cfg['sampling_strategy'], + # Loss Coeffs + sparsity_lambda=cfg['sparsity_lambda'], + sparsity_c=cfg['sparsity_c'], + preactivation_coef=cfg['preactivation_coef'], + # Optimizer & Scheduler + optimizer=cfg['optimizer'], + optimizer_beta1=cfg['optimizer_beta1'], + optimizer_beta2=cfg['optimizer_beta2'], + lr_scheduler=lr_scheduler_arg, + # Logging & Checkpointing + log_interval=cfg['log_interval'], + eval_interval=cfg['eval_interval'], + checkpoint_interval=cfg['checkpoint_interval'], + # Dead Features + dead_feature_window=cfg['dead_feature_window'], + # WandB + enable_wandb=cfg['enable_wandb'], + wandb_project=cfg['wandb_project'], + wandb_entity=cfg['wandb_entity'], + wandb_run_name=cfg['wandb_run_name'], + wandb_tags=cfg['wandb_tags'], + # Remote config is not handled by this script + remote_config=None, + ) + logger.info(f"Training Config: {training_config}") 
trainer = CLTTrainer( - clt_config=cfg['clt_config'], - training_config=cfg['training_config'], + clt_config=clt_config, + training_config=training_config, log_dir=cfg['log_dir'], - device=cfg['device'], + device=device, distributed=False, # Explicitly set for tutorial use_ray='tune' ) - # trained_clt_model = trainer.train(eval_every=cfg['training_config'].eval_interval) + trained_clt_model = trainer.train(eval_every=training_config.eval_interval) def parse_args(): @@ -327,94 +401,12 @@ def main(): except Exception as e: logger.warning(f"Could not save command-line args: {e}") - # --- Determine Device --- - if args.device: - device = args.device - else: - if torch.cuda.is_available(): - device = "cuda" - elif torch.backends.mps.is_available(): - device = "mps" - else: - device = "cpu" - logger.info(f"Using device: {device}") - - # --- Determine Base Model Dimensions --- - # Use the provided --model-name to get dimensions for the CLT config. - # This ensures the CLT matches the architecture activations were generated from. - base_model_name = args.model_name - num_layers, d_model = get_model_dimensions(base_model_name) - if num_layers is None or d_model is None: - # Added error handling if dimensions couldn't be determined - logger.error(f"Could not determine dimensions for model '{base_model_name}'. 
Exiting.") - return - - # --- Create CLT Configuration --- - clt_config = CLTConfig( - num_features=args.num_features, - num_layers=num_layers, - d_model=d_model, - activation_fn=args.activation_fn, - jumprelu_threshold=args.jumprelu_threshold, - clt_dtype=args.clt_dtype, - ) - logger.info(f"CLT Config: {clt_config}") - - # --- Create Training Configuration --- - # Handle 'none' scheduler case - lr_scheduler_arg: Optional[Literal["linear", "cosine", "linear_final20"]] = ( - args.lr_scheduler if args.lr_scheduler != "none" else None - ) - - # Simplified TrainingConfig instantiation for local source only - training_config = TrainingConfig( - # Core Training - learning_rate=args.learning_rate, - training_steps=args.training_steps, - seed=args.seed, - train_batch_size_tokens=args.train_batch_size_tokens, - # Activation Source (hardcoded to local_manifest) - activation_source="local_manifest", - activation_path=args.activation_path, - activation_dtype=args.activation_dtype, - # Normalization - normalization_method=args.normalization_method, - # Sampling Strategy - sampling_strategy=args.sampling_strategy, - # Loss Coeffs - sparsity_lambda=args.sparsity_lambda, - sparsity_c=args.sparsity_c, - preactivation_coef=args.preactivation_coef, - # Optimizer & Scheduler - optimizer=args.optimizer, - optimizer_beta1=args.optimizer_beta1, - optimizer_beta2=args.optimizer_beta2, - lr_scheduler=lr_scheduler_arg, - # Logging & Checkpointing - log_interval=args.log_interval, - eval_interval=args.eval_interval, - checkpoint_interval=args.checkpoint_interval, - # Dead Features - dead_feature_window=args.dead_feature_window, - # WandB - enable_wandb=args.enable_wandb, - wandb_project=args.wandb_project, - wandb_entity=args.wandb_entity, - wandb_run_name=args.wandb_run_name, - wandb_tags=args.wandb_tags, - # Remote config is not handled by this script - remote_config=None, - ) - logger.info(f"Training Config: {training_config}") - # --- Start Tuning --- - # configs to pass to workers + # 
config to pass to workers hps = { - 'training_config': training_config, - 'clt_config': clt_config, - 'log_dir': str(output_dir), - 'device': device + 'log_dir': str(output_dir.resolve()), + **vars(args) } # add actual hyperparameters hps |= { @@ -429,7 +421,8 @@ def main(): tuner = ray.tune.Tuner( ray.tune.with_resources(train_loop_per_worker, {'gpu': args.n_gpus / args.n_workers}), param_space=hps, - tune_config=ray.tune.TuneConfig(max_concurrent_trials=args.n_workers) + tune_config=ray.tune.TuneConfig(max_concurrent_trials=args.n_workers), + run_config=ray.tune.RunConfig(storage_path=str(output_dir.resolve())) ) results = tuner.fit() logger.info("Tuning complete!") From 01d9b5e7084e83110ee9f407ac8bb41cf6432af1 Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Wed, 7 May 2025 15:03:28 -0400 Subject: [PATCH 06/11] updated model checkpointing to work with ray --- clt/training/trainer.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/clt/training/trainer.py b/clt/training/trainer.py index ec877d3..5266014 100644 --- a/clt/training/trainer.py +++ b/clt/training/trainer.py @@ -34,7 +34,6 @@ ) # Keep for StreamingStore usage from .evaluator import CLTEvaluator # Import the new evaluator -import ray # Get logger for this module logger = logging.getLogger(__name__) @@ -487,16 +486,25 @@ def lr_lambda(current_step: int): self.wandb_logger = DummyWandBLogger() + # Set up imports for Ray if applicable if use_ray: if use_ray == 'train': from ray.train import report + from ray.train import Checkpoint elif use_ray == 'tune': from ray.tune import report + from ray.tune import Checkpoint + + import tempfile self.ray_reporter = report + self.ray_checkpoint = Checkpoint + self.ray_tempfile = tempfile else: self.ray_reporter = None + self.ray_checkpoint = None + self.ray_tempfile = None @property @@ -1264,14 +1272,33 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: dist.barrier() # --- Checkpointing (All ranks 
participate) --- - if save_checkpoint_flag: + if save_checkpoint_flag and not self.ray_reporter: self._save_checkpoint(step) + # Report to ray if enabled if self.ray_reporter: + # Make metrics dict to report ray_metrics = loss_dict if run_eval_flag: ray_metrics = ray_metrics | eval_metrics - self.ray_reporter(ray_metrics) + + with self.ray_tempfile.TemporaryDirectory() as temp_checkpoint_dir: + checkpoint = None + + if save_checkpoint_flag: + model_checkpoint_path = os.path.join(temp_checkpoint_dir, f"clt_checkpoint_{step}.pt") + store_checkpoint_path = os.path.join(temp_checkpoint_dir, f"activation_store_checkpoint_{step}.pt") + + torch.save(self.model.state_dict(), model_checkpoint_path) + # TODO: wandb_logger untested with ray reporting + # self.wandb_logger.log_artifact( + # artifact_path=model_checkpoint_path, artifact_type="model", name=f"clt_checkpoint_{step}" + # ) + torch.save(self.activation_store.state_dict(), store_checkpoint_path) + + checkpoint = self.ray_checkpoint.from_directory(temp_checkpoint_dir) + + self.ray_reporter(ray_metrics, checkpoint=checkpoint) # --- Explicitly delete tensors at the very end of the loop iteration --- # # Do this on all ranks From f24dab46f4214e781a3c2d6f9fd5014a4b52f641 Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Wed, 7 May 2025 15:50:00 -0400 Subject: [PATCH 07/11] created separate trainer class for Ray --- clt/training/trainer_ray.py | 948 ++++++++++++++++++++++++++++++++++++ 1 file changed, 948 insertions(+) create mode 100644 clt/training/trainer_ray.py diff --git a/clt/training/trainer_ray.py b/clt/training/trainer_ray.py new file mode 100644 index 0000000..7677846 --- /dev/null +++ b/clt/training/trainer_ray.py @@ -0,0 +1,948 @@ +import torch +import torch.optim as optim +from typing import Dict, Literal, Optional, Union, Any +from tqdm import tqdm # type: ignore +import os +import json +import time +import importlib.util +import sys +import logging # Add logging import +import datetime # Import datetime for 
formatting +import torch.distributed as dist # Import torch.distributed +from torch.distributed import ProcessGroup # Import ProcessGroup +from torch.distributed.checkpoint.state_dict_saver import save_state_dict +from torch.distributed.checkpoint.state_dict_loader import load_state_dict +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner, DefaultLoadPlanner +from torch.distributed.checkpoint.filesystem import FileSystemWriter, FileSystemReader # Storage for checkpointing + +from clt.config import CLTConfig, TrainingConfig +from clt.models.clt import CrossLayerTranscoder +from clt.training.data import ( + BaseActivationStore, + StreamingActivationStore, + # MappedActivationStore, # Removed legacy store +) + +# Import the new manifest-based stores +from clt.training.local_activation_store import LocalActivationStore +from clt.training.remote_activation_store import RemoteActivationStore + +from clt.training.losses import LossManager +from clt.nnsight.extractor import ( + ActivationExtractorCLT, +) +from clt.training.trainer import CLTTrainer # Keep for StreamingStore usage +from .evaluator import CLTEvaluator # Import the new evaluator + + +# Get logger for this module +logger = logging.getLogger(__name__) + + +# Helper function to format elapsed time +def _format_elapsed_time(seconds: float) -> str: + """Formats elapsed seconds into HH:MM:SS or MM:SS.""" + td = datetime.timedelta(seconds=int(seconds)) + hours, remainder = divmod(td.seconds, 3600) + minutes, seconds = divmod(remainder, 60) + if td.days > 0 or hours > 0: + return f"{td.days * 24 + hours:02d}:{minutes:02d}:{seconds:02d}" + else: + return f"{minutes:02d}:{seconds:02d}" + + +# # Define the dummy logger class explicitly for better type checking +# class DummyWandBLogger: +# def log_step(self, *args, **kwargs): +# pass + +# def log_evaluation(self, *args, **kwargs): +# pass + +# def log_artifact(self, *args, **kwargs): +# pass + +# def finish(self, *args, **kwargs): +# pass + + +# 
class WandBLogger: +# """Wrapper class for Weights & Biases logging.""" + +# def __init__(self, training_config: TrainingConfig, clt_config: CLTConfig, log_dir: str): +# """Initialize the WandB logger. + +# Args: +# training_config: Training configuration +# clt_config: CLT model configuration +# log_dir: Directory to save logs +# """ +# self.enabled = training_config.enable_wandb +# self.log_dir = log_dir + +# if not self.enabled: +# return + +# # Check if wandb is installed +# if not importlib.util.find_spec("wandb"): +# print( +# "Warning: WandB logging requested but wandb not installed. " +# "Install with 'pip install wandb'. Continuing without WandB." +# ) +# self.enabled = False +# return + +# # Import wandb +# import wandb + +# # Set up run name with timestamp if not provided +# run_name = training_config.wandb_run_name +# if run_name is None: +# run_name = f"clt-{time.strftime('%Y%m%d-%H%M%S')}" + +# # Initialize wandb +# wandb.init( +# project=training_config.wandb_project, +# entity=training_config.wandb_entity, +# name=run_name, +# dir=log_dir, +# tags=training_config.wandb_tags, +# config={ +# **training_config.__dict__, +# **clt_config.__dict__, +# "log_dir": log_dir, +# }, +# ) + +# if wandb.run is not None: +# print(f"WandB logging initialized: {wandb.run.name}") + +# def log_step( +# self, +# step: int, +# loss_dict: Dict[str, float], +# lr: Optional[float] = None, +# sparsity_lambda: Optional[float] = None, +# total_tokens_processed: Optional[int] = None, +# ): +# """Log metrics for a training step under the 'training/' group. 
+ +# Args: +# step: Current training step +# loss_dict: Dictionary of loss values (e.g., total, reconstruction, sparsity) +# lr: Current learning rate +# sparsity_lambda: Current sparsity coefficient lambda +# total_tokens_processed: Total tokens processed up to this step +# """ +# if not self.enabled: +# return + +# import wandb + +# # Rename loss keys for clarity and add 'training/' prefix +# metrics = {} +# for key, value in loss_dict.items(): +# if key == "total": +# metrics["training/total_loss"] = value +# elif key == "sparsity": +# metrics["training/sparsity_loss"] = value +# elif key == "reconstruction": +# # Reconstruction loss is part of training, log it here too if present +# metrics["training/reconstruction_loss"] = value +# elif key == "preactivation": +# metrics["training/preactivation_loss"] = value +# else: +# # Keep other potential keys, prepending 'training/' +# metrics[f"training/{key}"] = value + +# # Add learning rate +# if lr is not None: +# metrics["training/learning_rate"] = lr + +# # Add sparsity lambda +# if sparsity_lambda is not None: +# metrics["training/sparsity_lambda"] = sparsity_lambda + +# # Add total tokens processed +# if total_tokens_processed is not None: +# metrics["training/total_tokens_processed"] = total_tokens_processed + +# # Log to wandb +# wandb.log(metrics, step=step) + +# def log_evaluation(self, step: int, eval_metrics: Dict[str, Any]): +# """Log evaluation metrics, organized by the structure from CLTEvaluator. 
+ +# Args: +# step: Current training step +# eval_metrics: Dictionary of evaluation metrics from CLTEvaluator +# (keys like 'reconstruction/', 'sparsity/', 'layerwise/') +# """ +# if not self.enabled: +# return + +# import wandb + +# # Log metrics directly, assuming keys are already structured +# # e.g., 'reconstruction/mse', 'sparsity/avg_l0', 'layerwise/l0/layer_0' +# wandb_log_dict: Dict[str, Any] = {} +# for key, value in eval_metrics.items(): +# if key.startswith("layerwise/"): +# # Handle nested layerwise data (histograms and scalars) +# # layerwise_category = key.split("/")[ +# # 1 +# # ] # e.g., 'l0', 'log_feature_density' # Removed unused variable +# if isinstance(value, dict): +# for layer_key, layer_value in value.items(): +# # Construct wandb key: e.g., layerwise/l0/layer_0 +# wandb_key = f"{key}/{layer_key}" # Correctly forms e.g. layerwise/log_feature_density/layer_0 +# if isinstance(layer_value, list): +# # Log list data as histogram +# try: +# wandb_log_dict[wandb_key] = wandb.Histogram(layer_value) +# except Exception as e: +# print(f"Wandb: Error creating histogram for {wandb_key}: {e}") +# # Fallback: log mean or placeholder +# try: +# mean_val = sum(layer_value) / len(layer_value) if layer_value else 0.0 +# wandb_log_dict[f"{wandb_key}_mean"] = mean_val +# except TypeError: +# wandb_log_dict[f"{wandb_key}_mean"] = -1.0 +# elif isinstance(layer_value, (float, int)): +# # Log scalar layerwise data +# wandb_log_dict[wandb_key] = layer_value +# else: +# # If the top level key itself is scalar (shouldn't happen with current structure) +# wandb_log_dict[key] = value +# elif key.endswith("_agg_hist") and isinstance(value, list): +# # Handle aggregate histogram data (e.g., sparsity/log_feature_density_agg_hist) +# try: +# wandb_log_dict[key] = wandb.Histogram(value) +# except Exception as e: +# print(f"Wandb: Error creating aggregate histogram for {key}: {e}") +# # Optional Fallback: log mean of aggregate data +# try: +# mean_val = sum(value) / 
class CLTTrainerRay(CLTTrainer):
    """Trainer for Cross-Layer Transcoder models that reports through Ray.

    Unlike the base trainer, this class does not log to WandB or write metrics
    and checkpoints under ``log_dir`` itself: at the end of every completed
    step it hands the step metrics (and, on checkpoint steps, a checkpoint
    directory) to ``ray.train.report`` or ``ray.tune.report``.
    """

    # Add type hint for the activation store attribute
    activation_store: BaseActivationStore
    # Model type hint
    model: CrossLayerTranscoder

    def __init__(
        self,
        clt_config: CLTConfig,
        training_config: TrainingConfig,
        log_dir: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        distributed: bool = False,  # Add distributed flag
        use_ray: Optional[Literal['train', 'tune']] = None,
    ):
        """Initialize the CLT trainer.

        Args:
            clt_config: Configuration for the CLT model
            training_config: Configuration for training
            log_dir: Directory to save logs and checkpoints
            device: Device to use for training (ignored if distributed)
            distributed: Whether to use distributed training
            use_ray: Which Ray API to report to: ``'train'`` uses
                ``ray.train``, ``'tune'`` uses ``ray.tune``. A falsy value
                disables Ray; ``train()`` then refuses to run.

        Raises:
            ValueError: If ``use_ray`` is set to anything other than
                ``'train'`` or ``'tune'``.
        """
        # Validate up front so a typo fails before the process group, model and
        # activation store are created.
        if use_ray and use_ray not in ("train", "tune"):
            raise ValueError(f"use_ray must be 'train', 'tune' or None, got {use_ray!r}")

        self.clt_config = clt_config
        self.training_config = training_config
        self.distributed = distributed

        # Initialize distributed training if enabled
        self.rank = 0
        self.world_size = 1
        self.local_rank = 0
        self.process_group: Optional[ProcessGroup] = None  # For tensor parallelism

        if self.distributed:
            if not dist.is_initialized():
                # Default backend, consider NCCL for NVIDIA GPUs
                dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
            self.local_rank = int(os.environ.get("LOCAL_RANK", self.rank))  # Get local rank if available

            # Set device based on local_rank when distributed
            if torch.cuda.is_available():
                self.device = torch.device(f"cuda:{self.local_rank}")
                torch.cuda.set_device(self.device)
            else:
                # Fallback for CPU distributed testing (not typical)
                self.device = torch.device("cpu")
                logger.warning("Distributed training requested but CUDA not available. Using CPU.")
            # Set the process group for tensor parallelism (using WORLD for now)
            self.process_group = dist.group.WORLD
        else:
            # Original device handling for non-distributed case
            _device_input = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
            self.device = torch.device(_device_input) if isinstance(_device_input, str) else _device_input
            # Process group is None when not distributed
            self.process_group = None

        # Set up log directory - only rank 0 creates it
        self.log_dir = log_dir or f"clt_train_{int(time.time())}"
        if not self.distributed or self.rank == 0:
            os.makedirs(self.log_dir, exist_ok=True)

        # Record start time
        self.start_time = time.time()

        # Initialize model, passing device and process group for direct initialization
        # self.process_group is correctly set to None if not distributed
        self.model = CrossLayerTranscoder(
            clt_config, process_group=self.process_group, device=self.device  # Pass the potentially None group
        )

        # Initialize optimizer - works on local parameters
        # Explicitly type the kwargs dict for clarity and linting
        optimizer_kwargs: Dict[str, Any] = {"lr": training_config.learning_rate}
        beta1 = training_config.optimizer_beta1  # Could be None
        beta2 = training_config.optimizer_beta2  # Could be None

        # Only add 'betas' if at least one is specified
        if beta1 is not None or beta2 is not None:
            # Get defaults if one is None
            # Default Adam/AdamW betas are (0.9, 0.999)
            final_beta1 = beta1 if beta1 is not None else 0.9
            final_beta2 = beta2 if beta2 is not None else 0.999
            optimizer_kwargs["betas"] = (final_beta1, final_beta2)
            logger.info(f"Rank {self.rank}: Using optimizer betas: ({final_beta1}, {final_beta2})")

        if training_config.optimizer == "adam":
            self.optimizer: Any = optim.Adam(self.model.parameters(), **optimizer_kwargs)
        else:  # "adamw"
            self.optimizer = optim.AdamW(self.model.parameters(), **optimizer_kwargs)

        # Initialize scheduler
        self.scheduler: Optional[Any] = None
        scheduler_type = training_config.lr_scheduler
        # Get scheduler params from config, default to empty dict if None
        scheduler_params = training_config.lr_scheduler_params or {}

        if scheduler_type == "linear":
            # Default params for LinearLR
            default_linear_params = {
                "start_factor": 1.0,
                "end_factor": 0.1,
                # total_iters is always training_steps for this setup
            }
            # Update defaults with user-provided params
            final_params = {**default_linear_params, **scheduler_params}
            # Ensure total_iters is not overridden by user params
            final_params.pop("total_iters", None)

            self.scheduler = optim.lr_scheduler.LinearLR(
                self.optimizer,
                total_iters=training_config.training_steps,
                **final_params,  # Pass start_factor, end_factor, etc.
            )
            logger.info(
                f"Rank {self.rank}: Using LinearLR scheduler with params: {final_params}, total_iters={training_config.training_steps}"
            )

        elif scheduler_type == "cosine":
            # Default params for CosineAnnealingLR
            default_cosine_params = {
                # T_max defaults to training_steps
                "eta_min": 0,  # Default minimum LR
            }
            # Update defaults with user-provided params
            final_params = {**default_cosine_params, **scheduler_params}
            # Set T_max explicitly, allowing override but defaulting to training_steps
            t_max = final_params.pop("T_max", training_config.training_steps)

            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=t_max, **final_params  # Pass eta_min, etc.
            )
            logger.info(
                f"Rank {self.rank}: Using CosineAnnealingLR scheduler with params: {final_params}, T_max={t_max}"
            )

        elif scheduler_type == "linear_final20":
            # This scheduler keeps LR constant for the initial fraction of training
            # and then linearly decays it to 0 over the remaining steps (default 20%).
            # The fraction can be customized via lr_scheduler_params["decay_start_frac"].
            decay_start_frac = scheduler_params.get("decay_start_frac", 0.8)  # 0.8 means last 20% decays
            assert 0.0 < decay_start_frac < 1.0, "decay_start_frac must be between 0 and 1"
            total_steps = training_config.training_steps
            decay_start_step = int(decay_start_frac * total_steps)

            def lr_lambda(current_step: int):
                if current_step < decay_start_step:
                    return 1.0  # Keep LR constant
                # Linearly decay from 1 -> 0 over the remaining steps
                remaining = total_steps - current_step
                decay_steps = total_steps - decay_start_step
                return max(remaining / decay_steps, 0.0)

            self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda)
            logger.info(
                "Rank %d: Using linear_final20 LR scheduler with decay_start_frac=%s (start step %d of %d)",
                self.rank,
                decay_start_frac,
                decay_start_step,
                total_steps,
            )

        # Add elif blocks here for other potential schedulers

        # Initialize activation store based on config - uses self.rank/world_size now
        self.activation_store = self._create_activation_store(self.start_time)

        # Pass normalisation statistics (if available) so the loss can be computed in
        # the *original* scale even when inputs/targets are stored normalised.
        mean_tg_stats = getattr(self.activation_store, "mean_tg", {})  # type: ignore[arg-type]
        std_tg_stats = getattr(self.activation_store, "std_tg", {})  # type: ignore[arg-type]

        self.loss_manager = LossManager(
            training_config,
            mean_tg=mean_tg_stats,
            std_tg=std_tg_stats,
        )

        # Initialize Evaluator - Pass norm stats here too
        self.evaluator = CLTEvaluator(
            model=self.model,
            device=self.device,
            start_time=self.start_time,
            mean_tg=mean_tg_stats,  # Pass the same stats
            std_tg=std_tg_stats,  # Pass the same stats
        )

        # Initialize dead neuron counters (replicated for now, consider sharding later if needed)
        self.n_forward_passes_since_fired = torch.zeros(
            (clt_config.num_layers, clt_config.num_features),
            device=self.device,
            dtype=torch.long,
        )

        # Training metrics (only rank 0 saves, but others might need local copies for some logic)
        self.metrics: Dict[str, list] = {
            "train_losses": [],
            "eval_metrics": [],
        }

        # NOTE: no WandB logger is created here; this trainer reports through Ray only.

        # Set up imports for Ray if applicable. The imports are deferred so that
        # constructing the trainer with Ray disabled does not require ray.
        self.ray_reporter: Optional[Any] = None
        self.ray_checkpoint: Optional[Any] = None
        self.ray_tempfile: Optional[Any] = None
        if use_ray:
            if use_ray == 'train':
                from ray.train import report
                from ray.train import Checkpoint
            else:  # 'tune' (validated at the top of __init__)
                from ray.tune import report
                from ray.tune import Checkpoint

            import tempfile

            self.ray_reporter = report
            self.ray_checkpoint = Checkpoint
            self.ray_tempfile = tempfile

    def _write_checkpoint_files(self, step: int, directory: str) -> None:
        """Write this process's model and activation-store state into ``directory``.

        Args:
            step: Current training step, embedded in both file names.
            directory: Existing directory that receives the two ``.pt`` files.
        """
        model_checkpoint_path = os.path.join(directory, f"clt_checkpoint_{step}.pt")
        store_checkpoint_path = os.path.join(directory, f"activation_store_checkpoint_{step}.pt")

        torch.save(self.model.state_dict(), model_checkpoint_path)
        # TODO: wandb_logger untested with ray reporting, so no artifact is logged here.
        torch.save(self.activation_store.state_dict(), store_checkpoint_path)

    def _save_checkpoint(self, step: int):
        """Save the model and activation store state and wrap them in a Ray checkpoint.

        The state dicts are written with ``torch.save`` into a fresh temporary
        directory. The directory is intentionally NOT removed on return,
        because the returned checkpoint is built from that path; deleting it
        here would hand the caller a checkpoint whose files are already gone.
        The caller should remove the directory once the checkpoint has been
        reported.

        Args:
            step: Current training step

        Returns:
            A Ray ``Checkpoint`` created from the temporary directory.

        Raises:
            RuntimeError: If the trainer was constructed with Ray disabled.
        """
        if self.ray_checkpoint is None or self.ray_tempfile is None:
            raise RuntimeError("_save_checkpoint requires the trainer to be created with use_ray='train' or 'tune'.")

        checkpoint_dir = self.ray_tempfile.mkdtemp(prefix=f"clt_checkpoint_{step}_")
        self._write_checkpoint_files(step, checkpoint_dir)
        return self.ray_checkpoint.from_directory(checkpoint_dir)

    def train(self, eval_every: int = 1000) -> CrossLayerTranscoder:
        """Train the CLT model.

        Args:
            eval_every: Accepted for interface compatibility but not read by
                this method; the evaluation cadence comes from
                ``training_config.eval_interval``.

        Returns:
            Trained CLT model (local shard)

        Raises:
            RuntimeError: If the trainer was constructed with Ray disabled.
        """
        # Every completed step reports to Ray, so fail before doing any work
        # instead of crashing on a None reporter at the end of the first step.
        if self.ray_reporter is None or self.ray_checkpoint is None or self.ray_tempfile is None:
            raise RuntimeError("CLTTrainerRay.train requires the trainer to be created with use_ray='train' or 'tune'.")

        # Print startup message from rank 0 only
        if not self.distributed or self.rank == 0:
            print(f"Starting CLT training on {self.device}...")
            print(
                f"Model has {self.clt_config.num_features} features per layer "
                f"and {self.clt_config.num_layers} layers"
            )
            print(f"Training for {self.training_config.training_steps} steps.")
            print(f"Logging to {self.log_dir}")
            if self.distributed:
                print(f"Distributed training with {self.world_size} processes (Tensor Parallelism)")

            # Check if using normalization and notify user
            if self.training_config.normalization_method == "estimated_mean_std":
                print("\n>>> NORMALIZATION PHASE <<<")
                print("Normalization statistics are being estimated from dataset activations.")
                print("This may take some time, but happens only once before training begins.")
                print(f"Using {self.training_config.normalization_estimation_batches} batches for estimation.\n")

            # Make sure we flush stdout to ensure prints appear immediately,
            # especially important in Jupyter/interactive environments
            sys.stdout.flush()
            # Wait for 1 second to ensure output is displayed before training starts
            time.sleep(1)
            print("\n>>> TRAINING PHASE <<<")
            sys.stdout.flush()

        # Create progress bar only on rank 0
        pbar: Union[tqdm, range]
        if not self.distributed or self.rank == 0:
            pbar = tqdm(
                range(self.training_config.training_steps),
                desc="Training CLT",
                leave=True,
            )
        else:
            pbar = range(self.training_config.training_steps)

        step = 0
        try:
            for step in pbar:
                # Refresh progress bar on rank 0
                step_start_time = time.monotonic()  # Start timing the step
                if isinstance(pbar, tqdm):
                    pbar.refresh()

                try:
                    # Get batch directly from the iterator (handles distributed sampling internally)
                    batch_get_start_time = time.monotonic()
                    inputs, targets = next(self.activation_store)
                    batch_get_duration = time.monotonic() - batch_get_start_time
                    logger.debug(f"Rank {self.rank} Step {step}: Getting batch took {batch_get_duration:.4f}s")

                except StopIteration:
                    # Rank 0 prints message
                    if not self.distributed or self.rank == 0:
                        print("Activation store exhausted. Training finished early.")
                    if self.distributed:
                        dist.barrier()  # Ensure all ranks see this
                    break  # Exit training loop if data runs out
                except Exception as e:
                    # Rank 0 prints message
                    if not self.distributed or self.rank == 0:
                        print(f"\nRank {self.rank}: Error getting batch at step {step}: {e}. Skipping step.")
                    # Maybe barrier here too? If one rank fails, others might hang?
                    # Let's continue for now, assuming store handles internal errors.
                    continue

                # --- Check for empty batch --- (Optional but good practice)
                # This check should ideally happen *before* moving data potentially
                if not inputs or not targets or not any(v.numel() > 0 for v in inputs.values()):
                    if not self.distributed or self.rank == 0:
                        print(f"\nRank {self.rank}: Warning: Received empty batch at step {step}. Skipping.")
                    continue

                # --- BEGIN: One-time Normalization Check ---
                if step == 0 and (not self.distributed or self.rank == 0):
                    logger.info("--- Running Post-Normalization Check (First Batch) ---")
                    norm_applied = getattr(self.activation_store, "apply_normalization", None)
                    if isinstance(self.activation_store, (LocalActivationStore, RemoteActivationStore)):
                        logger.info(f"ActivationStore reports apply_normalization={norm_applied}")
                    elif isinstance(self.activation_store, StreamingActivationStore):
                        logger.info(
                            f"Streaming store normalization method: {self.activation_store.normalization_method}"
                        )

                    for li in range(self.clt_config.num_layers):
                        mean_in, std_in, mean_tg, std_tg = float("nan"), float("nan"), float("nan"), float("nan")
                        try:
                            if li in inputs and inputs[li].numel() > 0:
                                input_tensor = inputs[li].float()
                                mean_in = input_tensor.mean().item()
                                std_in = input_tensor.std().item()
                            if li in targets and targets[li].numel() > 0:
                                target_tensor = targets[li].float()
                                mean_tg = target_tensor.mean().item()
                                std_tg = target_tensor.std().item()

                            if not (
                                torch.isnan(torch.tensor(mean_in)) and torch.isnan(torch.tensor(mean_tg))
                            ):  # Log if at least one value is valid
                                logger.info(
                                    f"  Layer {li:>2}: Input Mean={mean_in:+.4f}, Std={std_in:.4f} | Target Mean={mean_tg:+.4f}, Std={std_tg:.4f}"
                                )
                        except Exception as e:
                            logger.error(f"  Layer {li}: Error during normalization check: {e}")
                    logger.info("--- End Post-Normalization Check ---")
                # --- END: One-time Normalization Check ---

                # --- Forward pass and compute loss --- (All ranks)
                self.optimizer.zero_grad()

                # Compute feature activations **once** per step to avoid redundant encoder forward passes.
                feature_activations_batch = self.model.get_feature_activations(inputs)

                # Compute total loss using the pre-computed activations
                loss, loss_dict = self.loss_manager.compute_total_loss(
                    self.model,
                    inputs,
                    targets,
                    step,
                    self.training_config.training_steps,
                    precomputed_activations=feature_activations_batch,
                )

                # --- Update Dead Neuron Counters --- (All ranks, counter is replicated)
                # We need *full* feature activations *after* non-linearity
                if hasattr(self, "n_forward_passes_since_fired"):
                    with torch.no_grad():
                        for layer_idx, layer_acts in feature_activations_batch.items():
                            # Ensure layer index is within bounds of the counter tensor
                            if layer_idx < self.n_forward_passes_since_fired.shape[0]:
                                if layer_acts.numel() > 0:
                                    # layer_acts shape: [batch_tokens, num_features]
                                    fired_mask_per_token = layer_acts > 1e-6
                                    fired_features_this_layer = fired_mask_per_token.any(dim=0)

                                    if fired_features_this_layer.shape[0] == self.n_forward_passes_since_fired.shape[1]:
                                        self.n_forward_passes_since_fired[layer_idx] += 1
                                        self.n_forward_passes_since_fired[layer_idx][fired_features_this_layer] = 0
                                    else:
                                        if not self.distributed or self.rank == 0:  # Only rank 0 logs warning
                                            print(
                                                f"Rank {self.rank}: Warning: Shape mismatch for dead neuron update at layer {layer_idx}. "
                                                f"Acts shape: {layer_acts.shape}, Fired mask: {fired_features_this_layer.shape}, "
                                                f"Counter: {self.n_forward_passes_since_fired.shape}"
                                            )

                # --- Backward pass --- (All ranks, handles communication implicitly)
                if torch.isnan(loss):
                    if not self.distributed or self.rank == 0:
                        print(
                            f"\nRank {self.rank}: Warning: NaN loss encountered at step {step}. "
                            f"Skipping backward pass and optimizer step."
                        )
                else:
                    try:
                        loss.backward()

                        # --- Synchronise gradients of replicated parameters --- #
                        self._average_shared_parameter_grads()

                        # --- Gradient clipping --- #
                        if self.training_config.gradient_clip_val is not None:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(),
                                self.training_config.gradient_clip_val,
                            )
                    except RuntimeError as e:
                        if not self.distributed or self.rank == 0:
                            print(
                                f"\nRank {self.rank}: Error during backward pass at step {step}: {e}. Skipping optimizer step."
                            )
                        # Note: this also skips the scheduler step and the Ray report for this step.
                        continue

                    # --- Optimizer step --- (Applied to local parameters using local gradients)
                    self.optimizer.step()

                    # --- Invalidate Caches --- #
                    if hasattr(self.model, "_cached_decoder_norms"):
                        self.model._cached_decoder_norms = None

                # --- Scheduler step --- (All ranks)
                if self.scheduler:
                    self.scheduler.step()

                # --- Update progress bar --- (Rank 0 only)
                if isinstance(pbar, tqdm):
                    description = (
                        f"Loss: {loss_dict.get('total', float('nan')):.4f} "
                        f"(R: {loss_dict.get('reconstruction', float('nan')):.4f} "
                        f"S: {loss_dict.get('sparsity', float('nan')):.4f} "
                        f"P: {loss_dict.get('preactivation', float('nan')):.4f})"
                    )
                    pbar.set_description(description)
                    # Force update to display progress on every step
                    pbar.refresh()
                    sys.stdout.flush()

                step_duration = time.monotonic() - step_start_time
                logger.debug(
                    f"Rank {self.rank} Step {step}: Main logic (incl. batch get, fwd, bwd, optim) took {step_duration:.4f}s"
                )

                # --- Evaluation & Checkpointing ---
                eval_interval = self.training_config.eval_interval
                checkpoint_interval = self.training_config.checkpoint_interval

                save_checkpoint_flag = (step > 0 and step % checkpoint_interval == 0) or (  # Avoid checkpoint at step 0
                    step == self.training_config.training_steps - 1
                )
                run_eval_flag = (step % eval_interval == 0) or (step == self.training_config.training_steps - 1)

                # --- Evaluation (all ranks participate to match collectives) ---
                # In tensor-parallel mode the model forward includes collective ops (all_reduce/all_gather).
                # If only rank 0 performed the forward pass these collectives would block on the other ranks
                # resulting in NCCL timeouts. Therefore, *every* rank must execute the evaluation forward pass.
                # Only rank 0 updates the progress bar with the result.
                if run_eval_flag:
                    # Compute evaluation metrics on all ranks to keep collective ops aligned
                    current_dead_mask = self.dead_neurons_mask.detach().clone()
                    eval_metrics = self.evaluator.compute_metrics(
                        inputs,
                        targets,
                        dead_neuron_mask=current_dead_mask,
                    )

                    if not self.distributed or self.rank == 0:
                        # --- Update Progress Bar Postfix ---
                        l0_str = f"AvgL0: {eval_metrics.get('sparsity/avg_l0', 0.0):.2f}"
                        ev_str = f"EV: {eval_metrics.get('reconstruction/explained_variance', 0.0):.3f}"
                        avg_density_mean = eval_metrics.get("sparsity/feature_density_mean")
                        dens_str = f"Dens: {avg_density_mean:.3f}" if avg_density_mean is not None else "Dens: N/A"
                        eval_dead_str = f"Dead(Eval): {eval_metrics.get('dead_features/total_eval', 0)}"
                        eval_msg = f"{l0_str}, {ev_str}, {dens_str}, {eval_dead_str}"

                        if isinstance(pbar, tqdm):
                            pbar.set_postfix_str(eval_msg)
                            pbar.refresh()

                    # Optionally compute sparsity diagnostics (can be slow)
                    # NOTE(review): placed at the all-ranks level of the eval branch; confirm nesting.
                    if self.training_config.compute_sparsity_diagnostics:
                        # Calculate diagnostics using the same batch data and cached activations/norms
                        sparsity_diag_metrics = self._compute_sparsity_diagnostics(inputs, feature_activations_batch)
                        # Merge diagnostics into the main eval metrics dict
                        if sparsity_diag_metrics:
                            eval_metrics.update(sparsity_diag_metrics)

                # --- Checkpointing & Reporting ---
                # Copy so the loss manager's dict is not mutated by the merges below.
                ray_metrics = dict(loss_dict)

                # Add hp info
                if self.scheduler is not None:
                    # Assuming one parameter group
                    current_lr = self.scheduler.get_last_lr()[0]
                else:
                    # Without a scheduler, read the learning rate from the optimizer.
                    current_lr = self.optimizer.param_groups[0]["lr"]

                current_lambda = self.loss_manager.get_current_sparsity_lambda()

                ray_metrics |= {
                    'lr': current_lr,
                    'sparsity_lambda': current_lambda
                }

                if run_eval_flag:
                    ray_metrics |= eval_metrics

                # Now report to ray and potentially checkpoint. The report call stays
                # inside the `with` so the checkpoint files still exist when it runs.
                with self.ray_tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                    checkpoint = None

                    if save_checkpoint_flag:
                        self._write_checkpoint_files(step, temp_checkpoint_dir)
                        checkpoint = self.ray_checkpoint.from_directory(temp_checkpoint_dir)

                    self.ray_reporter(ray_metrics, checkpoint=checkpoint)

        except KeyboardInterrupt:
            if not self.distributed or self.rank == 0:
                print("\nTraining interrupted by user.")
        finally:
            if isinstance(pbar, tqdm):
                pbar.close()
            if not self.distributed or self.rank == 0:
                print(f"Training loop finished at step {step}.")

        # No final model/metrics files are written here; checkpoints only go out through Ray reports.

        # --- Close the activation store (stops prefetch thread if applicable) --- #
        if hasattr(self.activation_store, "close") and callable(getattr(self.activation_store, "close")):
            self.activation_store.close()

        # Clean up distributed process group
        if self.distributed:
            dist.destroy_process_group()

        return self.model
Optional[Literal['train', 'tune']] = None, ): """Initialize the CLT trainer. @@ -485,28 +483,6 @@ def lr_lambda(current_step: int): # Dummy logger for non-rank-0 processes self.wandb_logger = DummyWandBLogger() - - # Set up imports for Ray if applicable - if use_ray: - if use_ray == 'train': - from ray.train import report - from ray.train import Checkpoint - elif use_ray == 'tune': - from ray.tune import report - from ray.tune import Checkpoint - - import tempfile - - self.ray_reporter = report - self.ray_checkpoint = Checkpoint - self.ray_tempfile = tempfile - - else: - self.ray_reporter = None - self.ray_checkpoint = None - self.ray_tempfile = None - - @property def dead_neurons_mask(self) -> torch.Tensor: """Boolean mask indicating dead neurons based on inactivity window.""" @@ -1272,34 +1248,9 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: dist.barrier() # --- Checkpointing (All ranks participate) --- - if save_checkpoint_flag and not self.ray_reporter: + if save_checkpoint_flag: self._save_checkpoint(step) - # Report to ray if enabled - if self.ray_reporter: - # Make metrics dict to report - ray_metrics = loss_dict - if run_eval_flag: - ray_metrics = ray_metrics | eval_metrics - - with self.ray_tempfile.TemporaryDirectory() as temp_checkpoint_dir: - checkpoint = None - - if save_checkpoint_flag: - model_checkpoint_path = os.path.join(temp_checkpoint_dir, f"clt_checkpoint_{step}.pt") - store_checkpoint_path = os.path.join(temp_checkpoint_dir, f"activation_store_checkpoint_{step}.pt") - - torch.save(self.model.state_dict(), model_checkpoint_path) - # TODO: wandb_logger untested with ray reporting - # self.wandb_logger.log_artifact( - # artifact_path=model_checkpoint_path, artifact_type="model", name=f"clt_checkpoint_{step}" - # ) - torch.save(self.activation_store.state_dict(), store_checkpoint_path) - - checkpoint = self.ray_checkpoint.from_directory(temp_checkpoint_dir) - - self.ray_reporter(ray_metrics, checkpoint=checkpoint) - # --- 
Explicitly delete tensors at the very end of the loop iteration --- # # Do this on all ranks try: From 74d1c75d3f33072081cf766cc58768a76bdc8536 Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Wed, 7 May 2025 15:53:17 -0400 Subject: [PATCH 09/11] making use_ray param required --- clt/training/trainer_ray.py | 55 ++++++++----------------------------- 1 file changed, 12 insertions(+), 43 deletions(-) diff --git a/clt/training/trainer_ray.py b/clt/training/trainer_ray.py index 7677846..c1de137 100644 --- a/clt/training/trainer_ray.py +++ b/clt/training/trainer_ray.py @@ -288,7 +288,7 @@ def __init__( log_dir: Optional[str] = None, device: Optional[Union[str, torch.device]] = None, distributed: bool = False, # Add distributed flag - use_ray: Optional[Literal['train', 'tune']] = None, + use_ray: Literal['train', 'tune'] = 'tune', ): """Initialize the CLT trainer. @@ -487,50 +487,19 @@ def lr_lambda(current_step: int): # self.wandb_logger = DummyWandBLogger() - # Set up imports for Ray if applicable - if use_ray: - if use_ray == 'train': - from ray.train import report - from ray.train import Checkpoint - elif use_ray == 'tune': - from ray.tune import report - from ray.tune import Checkpoint + # Set up imports for Ray + if use_ray == 'train': + from ray.train import report + from ray.train import Checkpoint + elif use_ray == 'tune': + from ray.tune import report + from ray.tune import Checkpoint - import tempfile + import tempfile - self.ray_reporter = report - self.ray_checkpoint = Checkpoint - self.ray_tempfile = tempfile - - else: - self.ray_reporter = None - self.ray_checkpoint = None - self.ray_tempfile = None - - def _save_checkpoint(self, step: int): - """Save a distributed checkpoint of the model and activation store state. - - Uses torch.distributed.checkpoint to save sharded state directly. 
- - Args: - step: Current training step - """ - - with self.ray_tempfile.TemporaryDirectory() as temp_checkpoint_dir: - - model_checkpoint_path = os.path.join(temp_checkpoint_dir, f"clt_checkpoint_{step}.pt") - store_checkpoint_path = os.path.join(temp_checkpoint_dir, f"activation_store_checkpoint_{step}.pt") - - torch.save(self.model.state_dict(), model_checkpoint_path) - # TODO: wandb_logger untested with ray reporting - # self.wandb_logger.log_artifact( - # artifact_path=model_checkpoint_path, artifact_type="model", name=f"clt_checkpoint_{step}" - # ) - torch.save(self.activation_store.state_dict(), store_checkpoint_path) - - checkpoint = self.ray_checkpoint.from_directory(temp_checkpoint_dir) - - return checkpoint + self.ray_reporter = report + self.ray_checkpoint = Checkpoint + self.ray_tempfile = tempfile def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: From 7ac3a91b1ce0cf787c0674187a70ea609dc38ca7 Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Wed, 7 May 2025 15:54:36 -0400 Subject: [PATCH 10/11] explicitly setting cpu per worker and example hps --- tune_clt_local_ray.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tune_clt_local_ray.py b/tune_clt_local_ray.py index 4dc0955..d273587 100644 --- a/tune_clt_local_ray.py +++ b/tune_clt_local_ray.py @@ -13,6 +13,8 @@ import json import ray, ray.tune +from clt.training.trainer_ray import CLTTrainerRay + # Attempt to import transformers for model dimension detection try: from transformers import AutoConfig @@ -144,12 +146,12 @@ def train_loop_per_worker(cfg): ) logger.info(f"Training Config: {training_config}") - trainer = CLTTrainer( + trainer = CLTTrainerRay( clt_config=clt_config, training_config=training_config, log_dir=cfg['log_dir'], device=device, - distributed=False, # Explicitly set for tutorial + distributed=False, # distributing model across gpus within a ray tune worker not implemented use_ray='tune' ) @@ -410,16 +412,14 @@ def main(): } # add actual 
hyperparameters hps |= { - # 'sparsity_c': ray.tune.grid_search([0.01, 0.03, 0.09, 0.27, 0.81, 2.43]), - # 'sparsity_lambda': ray.tune.grid_search([1e-5, 3e-5, 9e-5, 2.7e-4, 8.1e-4, 2.43e-3]), - 'sparsity_c': ray.tune.grid_search([0.01]), - 'sparsity_lambda': ray.tune.grid_search([1e-5]), + 'sparsity_c': ray.tune.grid_search([0.01, 0.03, 0.09, 0.27, 0.81, 2.43]), + 'sparsity_lambda': ray.tune.grid_search([1e-5, 3e-5, 9e-5, 2.7e-4, 8.1e-4, 2.43e-3]), } logger.info("Starting hyperparameter tuning from local activations...") try: tuner = ray.tune.Tuner( - ray.tune.with_resources(train_loop_per_worker, {'gpu': args.n_gpus / args.n_workers}), + ray.tune.with_resources(train_loop_per_worker, {'cpu': 1, 'gpu': args.n_gpus / args.n_workers}), param_space=hps, tune_config=ray.tune.TuneConfig(max_concurrent_trials=args.n_workers), run_config=ray.tune.RunConfig(storage_path=str(output_dir.resolve())) From 9424595f3af40e56bd9ffcb66d43f15be5123cac Mon Sep 17 00:00:00 2001 From: Andy Kim Date: Wed, 7 May 2025 15:58:48 -0400 Subject: [PATCH 11/11] cleaning up init func --- clt/training/trainer_ray.py | 192 +----------------------------------- 1 file changed, 3 insertions(+), 189 deletions(-) diff --git a/clt/training/trainer_ray.py b/clt/training/trainer_ray.py index c1de137..228ce0c 100644 --- a/clt/training/trainer_ray.py +++ b/clt/training/trainer_ray.py @@ -287,7 +287,7 @@ def __init__( training_config: TrainingConfig, log_dir: Optional[str] = None, device: Optional[Union[str, torch.device]] = None, - distributed: bool = False, # Add distributed flag + distributed: bool = False, # unused for Ray as of now use_ray: Literal['train', 'tune'] = 'tune', ): """Initialize the CLT trainer. 
@@ -299,193 +299,7 @@ def __init__( device: Device to use for training (ignored if distributed) distributed: Whether to use distributed training """ - self.clt_config = clt_config - self.training_config = training_config - self.distributed = distributed - - # Initialize distributed training if enabled - self.rank = 0 - self.world_size = 1 - self.local_rank = 0 - self.process_group: Optional[ProcessGroup] = None # For tensor parallelism - - if self.distributed: - if not dist.is_initialized(): - # Default backend, consider NCCL for NVIDIA GPUs - dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo") - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - self.local_rank = int(os.environ.get("LOCAL_RANK", self.rank)) # Get local rank if available - - # Set device based on local_rank when distributed - if torch.cuda.is_available(): - self.device = torch.device(f"cuda:{self.local_rank}") - torch.cuda.set_device(self.device) - else: - # Fallback for CPU distributed testing (not typical) - self.device = torch.device("cpu") - logger.warning("Distributed training requested but CUDA not available. 
Using CPU.") - # Set the process group for tensor parallelism (using WORLD for now) - self.process_group = dist.group.WORLD - else: - # Original device handling for non-distributed case - _device_input = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) - self.device = torch.device(_device_input) if isinstance(_device_input, str) else _device_input - # Process group is None when not distributed - self.process_group = None - - # Set up log directory - only rank 0 creates it - self.log_dir = log_dir or f"clt_train_{int(time.time())}" - if not self.distributed or self.rank == 0: - os.makedirs(self.log_dir, exist_ok=True) - - # Record start time - self.start_time = time.time() - - # Initialize model, passing device and process group for direct initialization - # self.process_group is correctly set to None if not distributed - self.model = CrossLayerTranscoder( - clt_config, process_group=self.process_group, device=self.device # Pass the potentially None group - ) - - # Initialize optimizer - works on local parameters - # Explicitly type the kwargs dict for clarity and linting - optimizer_kwargs: Dict[str, Any] = {"lr": training_config.learning_rate} - beta1 = training_config.optimizer_beta1 # Could be None - beta2 = training_config.optimizer_beta2 # Could be None - - # Only add 'betas' if at least one is specified - if beta1 is not None or beta2 is not None: - # Get defaults if one is None - # Default Adam/AdamW betas are (0.9, 0.999) - final_beta1 = beta1 if beta1 is not None else 0.9 - final_beta2 = beta2 if beta2 is not None else 0.999 - optimizer_kwargs["betas"] = (final_beta1, final_beta2) - logger.info(f"Rank {self.rank}: Using optimizer betas: ({final_beta1}, {final_beta2})") - - if training_config.optimizer == "adam": - self.optimizer: Any = optim.Adam(self.model.parameters(), **optimizer_kwargs) - else: # "adamw" - self.optimizer = optim.AdamW(self.model.parameters(), **optimizer_kwargs) - - # Initialize scheduler - 
self.scheduler: Optional[Any] = None - scheduler_type = training_config.lr_scheduler - # Get scheduler params from config, default to empty dict if None - scheduler_params = training_config.lr_scheduler_params or {} - - if scheduler_type == "linear": - # Default params for LinearLR - default_linear_params = { - "start_factor": 1.0, - "end_factor": 0.1, - # total_iters is always training_steps for this setup - } - # Update defaults with user-provided params - final_params = {**default_linear_params, **scheduler_params} - # Ensure total_iters is not overridden by user params - final_params.pop("total_iters", None) - - self.scheduler = optim.lr_scheduler.LinearLR( - self.optimizer, - total_iters=training_config.training_steps, - **final_params, # Pass start_factor, end_factor, etc. - ) - logger.info( - f"Rank {self.rank}: Using LinearLR scheduler with params: {final_params}, total_iters={training_config.training_steps}" - ) - - elif scheduler_type == "cosine": - # Default params for CosineAnnealingLR - default_cosine_params = { - # T_max defaults to training_steps - "eta_min": 0, # Default minimum LR - } - # Update defaults with user-provided params - final_params = {**default_cosine_params, **scheduler_params} - # Set T_max explicitly, allowing override but defaulting to training_steps - t_max = final_params.pop("T_max", training_config.training_steps) - - self.scheduler = optim.lr_scheduler.CosineAnnealingLR( - self.optimizer, T_max=t_max, **final_params # Pass eta_min, etc. - ) - logger.info( - f"Rank {self.rank}: Using CosineAnnealingLR scheduler with params: {final_params}, T_max={t_max}" - ) - - elif scheduler_type == "linear_final20": - # This scheduler keeps LR constant for the initial fraction of training - # and then linearly decays it to 0 over the remaining steps (default 20%). - # The fraction can be customized via lr_scheduler_params["decay_start_frac"]. 
- decay_start_frac = scheduler_params.get("decay_start_frac", 0.8) # 0.8 means last 20% decays - assert 0.0 < decay_start_frac < 1.0, "decay_start_frac must be between 0 and 1" - total_steps = training_config.training_steps - decay_start_step = int(decay_start_frac * total_steps) - - def lr_lambda(current_step: int): - if current_step < decay_start_step: - return 1.0 # Keep LR constant - # Linearly decay from 1 -> 0 over the remaining steps - remaining = total_steps - current_step - decay_steps = total_steps - decay_start_step - return max(remaining / decay_steps, 0.0) - - self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda) - logger.info( - "Rank %d: Using linear_final20 LR scheduler with decay_start_frac=%s (start step %d of %d)", - self.rank, - decay_start_frac, - decay_start_step, - total_steps, - ) - - # Add elif blocks here for other potential schedulers - - # Initialize activation store based on config - uses self.rank/world_size now - self.activation_store = self._create_activation_store(self.start_time) - - # Pass normalisation statistics (if available) so the loss can be computed in - # the *original* scale even when inputs/targets are stored normalised. 
- mean_tg_stats = getattr(self.activation_store, "mean_tg", {}) # type: ignore[arg-type] - std_tg_stats = getattr(self.activation_store, "std_tg", {}) # type: ignore[arg-type] - - self.loss_manager = LossManager( - training_config, - mean_tg=mean_tg_stats, - std_tg=std_tg_stats, - ) - - # Initialize Evaluator - Pass norm stats here too - self.evaluator = CLTEvaluator( - model=self.model, - device=self.device, - start_time=self.start_time, - mean_tg=mean_tg_stats, # Pass the same stats - std_tg=std_tg_stats, # Pass the same stats - ) - - # Initialize dead neuron counters (replicated for now, consider sharding later if needed) - self.n_forward_passes_since_fired = torch.zeros( - (clt_config.num_layers, clt_config.num_features), - device=self.device, - dtype=torch.long, - ) - - # Training metrics (only rank 0 saves, but others might need local copies for some logic) - self.metrics: Dict[str, list] = { - "train_losses": [], - "eval_metrics": [], - } - - # # Initialize WandB logger - only on rank 0 - # if not self.distributed or self.rank == 0: - # self.wandb_logger: Union[WandBLogger, DummyWandBLogger] = WandBLogger( - # training_config=training_config, clt_config=clt_config, log_dir=self.log_dir - # ) - # else: - # # Dummy logger for non-rank-0 processes - # self.wandb_logger = DummyWandBLogger() - + super().__init__(clt_config, training_config, log_dir, device, False) # Set up imports for Ray if use_ray == 'train': @@ -770,7 +584,7 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder: ) if not self.distributed or self.rank == 0: - # Store evaluation metrics (for saving to JSON) + # # Store evaluation metrics (for saving to JSON) # self.metrics["eval_metrics"].append({"step": step, **eval_metrics}) # --- Update Progress Bar Postfix ---