import torch
import torch.optim as optim
from typing import Dict, Literal, Optional, Union, Any
from tqdm import tqdm  # type: ignore
import os
import json
import time
import importlib.util
import sys
import logging
import datetime
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.checkpoint.state_dict_saver import save_state_dict
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner, DefaultLoadPlanner
from torch.distributed.checkpoint.filesystem import FileSystemWriter, FileSystemReader

from clt.config import CLTConfig, TrainingConfig
from clt.models.clt import CrossLayerTranscoder
from clt.training.data import (
    BaseActivationStore,
    StreamingActivationStore,
)

# Manifest-based activation stores (local disk / remote server).
from clt.training.local_activation_store import LocalActivationStore
from clt.training.remote_activation_store import RemoteActivationStore

from clt.training.losses import LossManager
from clt.nnsight.extractor import (
    ActivationExtractorCLT,
)
from clt.training.trainer import CLTTrainer  # Base trainer this module extends
from .evaluator import CLTEvaluator

# Module-level logger.
logger = logging.getLogger(__name__)


def _format_elapsed_time(seconds: float) -> str:
    """Format an elapsed duration in seconds as ``HH:MM:SS`` or ``MM:SS``.

    Durations of an hour or more (including multi-day runs) use the long
    form, with whole days folded into the hour field.
    """
    delta = datetime.timedelta(seconds=int(seconds))
    hrs, leftover = divmod(delta.seconds, 3600)
    mins, secs = divmod(leftover, 60)
    if delta.days > 0 or hrs > 0:
        return f"{delta.days * 24 + hrs:02d}:{mins:02d}:{secs:02d}"
    return f"{mins:02d}:{secs:02d}"

# NOTE(review): a large commented-out WandBLogger / DummyWandBLogger
# implementation followed here. It is dead code; recover it from VCS
# history (clt/training/trainer.py) if WandB logging is re-enabled.
+# ) +# self.enabled = False +# return + +# # Import wandb +# import wandb + +# # Set up run name with timestamp if not provided +# run_name = training_config.wandb_run_name +# if run_name is None: +# run_name = f"clt-{time.strftime('%Y%m%d-%H%M%S')}" + +# # Initialize wandb +# wandb.init( +# project=training_config.wandb_project, +# entity=training_config.wandb_entity, +# name=run_name, +# dir=log_dir, +# tags=training_config.wandb_tags, +# config={ +# **training_config.__dict__, +# **clt_config.__dict__, +# "log_dir": log_dir, +# }, +# ) + +# if wandb.run is not None: +# print(f"WandB logging initialized: {wandb.run.name}") + +# def log_step( +# self, +# step: int, +# loss_dict: Dict[str, float], +# lr: Optional[float] = None, +# sparsity_lambda: Optional[float] = None, +# total_tokens_processed: Optional[int] = None, +# ): +# """Log metrics for a training step under the 'training/' group. + +# Args: +# step: Current training step +# loss_dict: Dictionary of loss values (e.g., total, reconstruction, sparsity) +# lr: Current learning rate +# sparsity_lambda: Current sparsity coefficient lambda +# total_tokens_processed: Total tokens processed up to this step +# """ +# if not self.enabled: +# return + +# import wandb + +# # Rename loss keys for clarity and add 'training/' prefix +# metrics = {} +# for key, value in loss_dict.items(): +# if key == "total": +# metrics["training/total_loss"] = value +# elif key == "sparsity": +# metrics["training/sparsity_loss"] = value +# elif key == "reconstruction": +# # Reconstruction loss is part of training, log it here too if present +# metrics["training/reconstruction_loss"] = value +# elif key == "preactivation": +# metrics["training/preactivation_loss"] = value +# else: +# # Keep other potential keys, prepending 'training/' +# metrics[f"training/{key}"] = value + +# # Add learning rate +# if lr is not None: +# metrics["training/learning_rate"] = lr + +# # Add sparsity lambda +# if sparsity_lambda is not None: +# 
metrics["training/sparsity_lambda"] = sparsity_lambda + +# # Add total tokens processed +# if total_tokens_processed is not None: +# metrics["training/total_tokens_processed"] = total_tokens_processed + +# # Log to wandb +# wandb.log(metrics, step=step) + +# def log_evaluation(self, step: int, eval_metrics: Dict[str, Any]): +# """Log evaluation metrics, organized by the structure from CLTEvaluator. + +# Args: +# step: Current training step +# eval_metrics: Dictionary of evaluation metrics from CLTEvaluator +# (keys like 'reconstruction/', 'sparsity/', 'layerwise/') +# """ +# if not self.enabled: +# return + +# import wandb + +# # Log metrics directly, assuming keys are already structured +# # e.g., 'reconstruction/mse', 'sparsity/avg_l0', 'layerwise/l0/layer_0' +# wandb_log_dict: Dict[str, Any] = {} +# for key, value in eval_metrics.items(): +# if key.startswith("layerwise/"): +# # Handle nested layerwise data (histograms and scalars) +# # layerwise_category = key.split("/")[ +# # 1 +# # ] # e.g., 'l0', 'log_feature_density' # Removed unused variable +# if isinstance(value, dict): +# for layer_key, layer_value in value.items(): +# # Construct wandb key: e.g., layerwise/l0/layer_0 +# wandb_key = f"{key}/{layer_key}" # Correctly forms e.g. 
layerwise/log_feature_density/layer_0 +# if isinstance(layer_value, list): +# # Log list data as histogram +# try: +# wandb_log_dict[wandb_key] = wandb.Histogram(layer_value) +# except Exception as e: +# print(f"Wandb: Error creating histogram for {wandb_key}: {e}") +# # Fallback: log mean or placeholder +# try: +# mean_val = sum(layer_value) / len(layer_value) if layer_value else 0.0 +# wandb_log_dict[f"{wandb_key}_mean"] = mean_val +# except TypeError: +# wandb_log_dict[f"{wandb_key}_mean"] = -1.0 +# elif isinstance(layer_value, (float, int)): +# # Log scalar layerwise data +# wandb_log_dict[wandb_key] = layer_value +# else: +# # If the top level key itself is scalar (shouldn't happen with current structure) +# wandb_log_dict[key] = value +# elif key.endswith("_agg_hist") and isinstance(value, list): +# # Handle aggregate histogram data (e.g., sparsity/log_feature_density_agg_hist) +# try: +# wandb_log_dict[key] = wandb.Histogram(value) +# except Exception as e: +# print(f"Wandb: Error creating aggregate histogram for {key}: {e}") +# # Optional Fallback: log mean of aggregate data +# try: +# mean_val = sum(value) / len(value) if value else 0.0 +# wandb_log_dict[f"{key}_mean"] = mean_val +# except TypeError: +# wandb_log_dict[f"{key}_mean"] = -1.0 + +# elif isinstance(value, (float, int)): # Handle top-level scalars +# # Log directly, e.g., 'reconstruction/mse', 'sparsity/avg_l0', 'dead_features/total_eval' +# wandb_log_dict[key] = value +# # Add other specific handling if needed (e.g., for specific non-scalar, non-layerwise data) + +# # Log the prepared dictionary to wandb +# if wandb_log_dict: +# wandb.log(wandb_log_dict, step=step) + +# def log_artifact(self, artifact_path: str, artifact_type: str, name: Optional[str] = None): +# """Log an artifact to WandB. 
class CLTTrainerRay(CLTTrainer):
    """Trainer for Cross-Layer Transcoder models that reports through Ray.

    Extends :class:`CLTTrainer` by replacing local metric/checkpoint handling
    with ``ray.train`` / ``ray.tune`` ``report(...)`` calls and ``Checkpoint``
    uploads built from a temporary directory each step.
    """

    # Store from which training batches are drawn.
    activation_store: BaseActivationStore
    # The CLT model being trained (local shard in tensor-parallel mode).
    model: CrossLayerTranscoder

    def __init__(
        self,
        clt_config: CLTConfig,
        training_config: TrainingConfig,
        log_dir: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        distributed: bool = False,  # unused for Ray as of now
        use_ray: Literal['train', 'tune'] = 'tune',
    ):
        """Initialize the CLT trainer.

        Args:
            clt_config: Configuration for the CLT model
            training_config: Configuration for training
            log_dir: Directory to save logs and checkpoints
            device: Device to use for training (ignored if distributed)
            distributed: Accepted for interface compatibility but ignored;
                the parent is always initialized with ``distributed=False``
                (per-worker model parallelism is not implemented for Ray).
            use_ray: Which Ray API family to report through: 'train' or 'tune'.

        Raises:
            ValueError: If ``use_ray`` is neither 'train' nor 'tune'.
        """
        super().__init__(clt_config, training_config, log_dir, device, False)

        # Bind the matching report/Checkpoint implementations for the chosen API.
        if use_ray == 'train':
            from ray.train import report, Checkpoint
        elif use_ray == 'tune':
            from ray.tune import report, Checkpoint
        else:
            # Previously an unrecognised value fell through silently and later
            # raised a confusing NameError on `report`; fail fast instead.
            raise ValueError(f"use_ray must be 'train' or 'tune', got {use_ray!r}")

        import tempfile

        self.ray_reporter = report
        self.ray_checkpoint = Checkpoint
        self.ray_tempfile = tempfile

    def train(self, eval_every: int = 1000) -> CrossLayerTranscoder:
        """Train the CLT model, reporting metrics/checkpoints to Ray each step.

        Args:
            eval_every: Accepted for interface compatibility; the actual
                evaluation cadence comes from ``training_config.eval_interval``
                (this argument is not read inside the loop).

        Returns:
            Trained CLT model (local shard)
        """
        # --- Startup banner (rank 0 only) ---
        if not self.distributed or self.rank == 0:
            print(f"Starting CLT training on {self.device}...")
            print(
                f"Model has {self.clt_config.num_features} features per layer "
                f"and {self.clt_config.num_layers} layers"
            )
            print(f"Training for {self.training_config.training_steps} steps.")
            print(f"Logging to {self.log_dir}")
            if self.distributed:
                print(f"Distributed training with {self.world_size} processes (Tensor Parallelism)")

            # Notify the user when a (potentially slow) normalization pass runs first.
            if self.training_config.normalization_method == "estimated_mean_std":
                print("\n>>> NORMALIZATION PHASE <<<")
                print("Normalization statistics are being estimated from dataset activations.")
                print("This may take some time, but happens only once before training begins.")
                print(f"Using {self.training_config.normalization_estimation_batches} batches for estimation.\n")

            # Flush so the banner shows up immediately in interactive environments.
            # NOTE(review): flush/sleep placement preserved from the original ordering.
            sys.stdout.flush()
            time.sleep(1)
            print("\n>>> TRAINING PHASE <<<")
            sys.stdout.flush()

        # Progress bar only on rank 0; other ranks iterate a bare range.
        pbar: Union[tqdm, range]
        if not self.distributed or self.rank == 0:
            pbar = tqdm(
                range(self.training_config.training_steps),
                desc="Training CLT",
                leave=True,
            )
        else:
            pbar = range(self.training_config.training_steps)

        step = 0
        try:
            for step in pbar:
                step_start_time = time.monotonic()  # Time the whole step
                if isinstance(pbar, tqdm):
                    pbar.refresh()

                # --- Fetch batch (store handles distributed sharding internally) ---
                try:
                    batch_get_start_time = time.monotonic()
                    inputs, targets = next(self.activation_store)
                    batch_get_duration = time.monotonic() - batch_get_start_time
                    logger.debug(f"Rank {self.rank} Step {step}: Getting batch took {batch_get_duration:.4f}s")
                except StopIteration:
                    if not self.distributed or self.rank == 0:
                        print("Activation store exhausted. Training finished early.")
                    if self.distributed:
                        dist.barrier()  # Ensure all ranks see this
                    break  # Exit training loop if data runs out
                except Exception as e:
                    if not self.distributed or self.rank == 0:
                        print(f"\nRank {self.rank}: Error getting batch at step {step}: {e}. Skipping step.")
                    # Best-effort: assume the store handles internal errors; move on.
                    continue

                # --- Skip empty batches ---
                if not inputs or not targets or not any(v.numel() > 0 for v in inputs.values()):
                    if not self.distributed or self.rank == 0:
                        print(f"\nRank {self.rank}: Warning: Received empty batch at step {step}. Skipping.")
                    continue

                # --- One-time normalization sanity check on the first batch (rank 0) ---
                if step == 0 and (not self.distributed or self.rank == 0):
                    logger.info("--- Running Post-Normalization Check (First Batch) ---")
                    norm_applied = getattr(self.activation_store, "apply_normalization", None)
                    if isinstance(self.activation_store, (LocalActivationStore, RemoteActivationStore)):
                        logger.info(f"ActivationStore reports apply_normalization={norm_applied}")
                    elif isinstance(self.activation_store, StreamingActivationStore):
                        logger.info(
                            f"Streaming store normalization method: {self.activation_store.normalization_method}"
                        )

                    for li in range(self.clt_config.num_layers):
                        mean_in, std_in, mean_tg, std_tg = float("nan"), float("nan"), float("nan"), float("nan")
                        try:
                            if li in inputs and inputs[li].numel() > 0:
                                input_tensor = inputs[li].float()
                                mean_in = input_tensor.mean().item()
                                std_in = input_tensor.std().item()
                            if li in targets and targets[li].numel() > 0:
                                target_tensor = targets[li].float()
                                mean_tg = target_tensor.mean().item()
                                std_tg = target_tensor.std().item()

                            # Log only when at least one side produced a real value.
                            if not (
                                torch.isnan(torch.tensor(mean_in)) and torch.isnan(torch.tensor(mean_tg))
                            ):
                                logger.info(
                                    f" Layer {li:>2}: Input Mean={mean_in:+.4f}, Std={std_in:.4f} | Target Mean={mean_tg:+.4f}, Std={std_tg:.4f}"
                                )
                        except Exception as e:
                            logger.error(f" Layer {li}: Error during normalization check: {e}")
                    logger.info("--- End Post-Normalization Check ---")

                # --- Forward pass and loss (all ranks) ---
                self.optimizer.zero_grad()

                # Encode once per step; both the loss and the dead-neuron
                # bookkeeping reuse these activations.
                feature_activations_batch = self.model.get_feature_activations(inputs)

                loss, loss_dict = self.loss_manager.compute_total_loss(
                    self.model,
                    inputs,
                    targets,
                    step,
                    self.training_config.training_steps,
                    precomputed_activations=feature_activations_batch,
                )

                # --- Update dead-neuron counters (replicated on all ranks) ---
                if hasattr(self, "n_forward_passes_since_fired"):
                    with torch.no_grad():
                        for layer_idx, layer_acts in feature_activations_batch.items():
                            if layer_idx < self.n_forward_passes_since_fired.shape[0]:
                                if layer_acts.numel() > 0:
                                    # layer_acts shape: [batch_tokens, num_features]
                                    fired_mask_per_token = layer_acts > 1e-6
                                    fired_features_this_layer = fired_mask_per_token.any(dim=0)

                                    if fired_features_this_layer.shape[0] == self.n_forward_passes_since_fired.shape[1]:
                                        self.n_forward_passes_since_fired[layer_idx] += 1
                                        self.n_forward_passes_since_fired[layer_idx][fired_features_this_layer] = 0
                                    else:
                                        if not self.distributed or self.rank == 0:
                                            print(
                                                f"Rank {self.rank}: Warning: Shape mismatch for dead neuron update at layer {layer_idx}. "
                                                f"Acts shape: {layer_acts.shape}, Fired mask: {fired_features_this_layer.shape}, "
                                                f"Counter: {self.n_forward_passes_since_fired.shape}"
                                            )

                # --- Backward pass (all ranks; collectives happen implicitly) ---
                if torch.isnan(loss):
                    if not self.distributed or self.rank == 0:
                        print(
                            f"\nRank {self.rank}: Warning: NaN loss encountered at step {step}. "
                            f"Skipping backward pass and optimizer step."
                        )
                else:
                    try:
                        loss.backward()
                        # Synchronise gradients of replicated parameters.
                        self._average_shared_parameter_grads()
                        if self.training_config.gradient_clip_val is not None:
                            torch.nn.utils.clip_grad_norm_(
                                self.model.parameters(),
                                self.training_config.gradient_clip_val,
                            )
                    except RuntimeError as e:
                        if not self.distributed or self.rank == 0:
                            print(
                                f"\nRank {self.rank}: Error during backward pass at step {step}: {e}. Skipping optimizer step."
                            )
                        continue

                # --- Optimizer / cache invalidation / scheduler ---
                self.optimizer.step()
                if hasattr(self.model, "_cached_decoder_norms"):
                    self.model._cached_decoder_norms = None
                if self.scheduler:
                    self.scheduler.step()

                # --- Progress bar (rank 0 only) ---
                if isinstance(pbar, tqdm):
                    description = (
                        f"Loss: {loss_dict.get('total', float('nan')):.4f} "
                        f"(R: {loss_dict.get('reconstruction', float('nan')):.4f} "
                        f"S: {loss_dict.get('sparsity', float('nan')):.4f} "
                        f"P: {loss_dict.get('preactivation', float('nan')):.4f})"
                    )
                    pbar.set_description(description)
                    pbar.refresh()
                    sys.stdout.flush()

                step_duration = time.monotonic() - step_start_time
                logger.debug(
                    f"Rank {self.rank} Step {step}: Main logic (incl. batch get, fwd, bwd, optim) took {step_duration:.4f}s"
                )

                # --- Evaluation & checkpoint cadence ---
                eval_interval = self.training_config.eval_interval
                checkpoint_interval = self.training_config.checkpoint_interval

                save_checkpoint_flag = (step > 0 and step % checkpoint_interval == 0) or (
                    step == self.training_config.training_steps - 1
                )
                run_eval_flag = (step % eval_interval == 0) or (step == self.training_config.training_steps - 1)

                # All ranks run evaluation so tensor-parallel collectives stay
                # aligned (rank-0-only eval would hang the other ranks); only
                # rank 0 displays the results.
                if run_eval_flag:
                    current_dead_mask = self.dead_neurons_mask.detach().clone()
                    eval_metrics = self.evaluator.compute_metrics(
                        inputs,
                        targets,
                        dead_neuron_mask=current_dead_mask,
                    )

                    if not self.distributed or self.rank == 0:
                        l0_str = f"AvgL0: {eval_metrics.get('sparsity/avg_l0', 0.0):.2f}"
                        ev_str = f"EV: {eval_metrics.get('reconstruction/explained_variance', 0.0):.3f}"
                        avg_density_mean = eval_metrics.get("sparsity/feature_density_mean")
                        dens_str = f"Dens: {avg_density_mean:.3f}" if avg_density_mean is not None else "Dens: N/A"
                        eval_dead_str = f"Dead(Eval): {eval_metrics.get('dead_features/total_eval', 0)}"
                        eval_msg = f"{l0_str}, {ev_str}, {dens_str}, {eval_dead_str}"

                        if isinstance(pbar, tqdm):
                            pbar.set_postfix_str(eval_msg)
                            pbar.refresh()

                    # Optionally compute sparsity diagnostics (can be slow).
                    if self.training_config.compute_sparsity_diagnostics:
                        sparsity_diag_metrics = self._compute_sparsity_diagnostics(inputs, feature_activations_batch)
                        if sparsity_diag_metrics:
                            eval_metrics.update(sparsity_diag_metrics)

                # --- Build the metrics dict to report to Ray ---
                # FIX: copy loss_dict — `|=` is an in-place dict update (PEP 584),
                # so aliasing would silently pollute the loss manager's dict.
                ray_metrics = dict(loss_dict)

                # FIX: `current_lr` was previously only bound when a scheduler
                # existed, causing a NameError otherwise; report None instead.
                current_lr = self.scheduler.get_last_lr()[0] if self.scheduler is not None else None
                current_lambda = self.loss_manager.get_current_sparsity_lambda()

                ray_metrics |= {
                    'lr': current_lr,
                    'sparsity_lambda': current_lambda
                }

                if run_eval_flag:
                    ray_metrics |= eval_metrics

                # --- Report to Ray, attaching a checkpoint when due ---
                with self.ray_tempfile.TemporaryDirectory() as temp_checkpoint_dir:
                    checkpoint = None

                    if save_checkpoint_flag:
                        model_checkpoint_path = os.path.join(temp_checkpoint_dir, f"clt_checkpoint_{step}.pt")
                        store_checkpoint_path = os.path.join(temp_checkpoint_dir, f"activation_store_checkpoint_{step}.pt")

                        torch.save(self.model.state_dict(), model_checkpoint_path)
                        torch.save(self.activation_store.state_dict(), store_checkpoint_path)

                        checkpoint = self.ray_checkpoint.from_directory(temp_checkpoint_dir)

                    self.ray_reporter(ray_metrics, checkpoint=checkpoint)

        except KeyboardInterrupt:
            if not self.distributed or self.rank == 0:
                print("\nTraining interrupted by user.")
        finally:
            if isinstance(pbar, tqdm):
                pbar.close()
            if not self.distributed or self.rank == 0:
                print(f"Training loop finished at step {step}.")
hasattr(self.activation_store, "close") and callable(getattr(self.activation_store, "close")): + # self.activation_store.close() + # except Exception as e: + # print(f"Rank 0: Warning: Failed to close activation store: {e}") + + # print("Saving final metrics...") + # self._save_metrics() + + # # Log final checkpoint directory as artifact + # self.wandb_logger.log_artifact(artifact_path=final_checkpoint_dir, artifact_type="model", name="clt_final") + + # # Finish WandB logging + # self.wandb_logger.finish() + # print(f"Training completed! Final checkpoint saved to {final_checkpoint_dir}") + + # --- Close the activation store (stops prefetch thread if applicable) --- # + if hasattr(self.activation_store, "close") and callable(getattr(self.activation_store, "close")): + self.activation_store.close() + + # Clean up distributed process group + if self.distributed: + dist.destroy_process_group() + + return self.model diff --git a/tune_clt_local_ray.py b/tune_clt_local_ray.py new file mode 100644 index 0000000..d273587 --- /dev/null +++ b/tune_clt_local_ray.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +""" +Script to train a Cross-Layer Transcoder (CLT) using pre-generated local activations. +Handles configuration parsing from command-line arguments and initiates training. +""" + +import argparse +import torch +from pathlib import Path +from typing import Literal, Optional +import logging +import time +import json +import ray, ray.tune + +from clt.training.trainer_ray import CLTTrainerRay + +# Attempt to import transformers for model dimension detection +try: + from transformers import AutoConfig +except ImportError: + AutoConfig = None + +# Import necessary CLT components +try: + from clt.config import CLTConfig, TrainingConfig + from clt.training.trainer import CLTTrainer +except ImportError as e: + print( + f"FATAL: ImportError: {e}. Please ensure the 'clt' library is installed or " + "the project root is in your PYTHONPATH." 
+ ) + raise + +# Setup basic logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +def get_model_dimensions(model_name: str) -> tuple[int, int]: + """Attempt to dynamically get num_layers and d_model from model_name.""" + if AutoConfig is None: + logger.warning( + "Transformers library not found. Cannot dynamically detect model dimensions." + " Falling back to gpt2 defaults (12 layers, 768 hidden size)." + " Install transformers (`pip install transformers`) for auto-detection." + ) + return 12, 768 # Default to gpt2 small + + try: + config = AutoConfig.from_pretrained(model_name) + num_layers = getattr(config, "num_hidden_layers", None) or getattr(config, "n_layer", None) + d_model = getattr(config, "hidden_size", None) or getattr(config, "n_embd", None) + + if num_layers is None or d_model is None: + raise ValueError(f"Could not automatically determine num_layers or d_model for {model_name}") + logger.info(f"Detected model dimensions for {model_name}: {num_layers} layers, {d_model} hidden size.") + return num_layers, d_model + except Exception as e: + logger.warning( + f"Failed to get model dimensions for {model_name}: {e}. " + f"Falling back to gpt2 defaults (12 layers, 768 hidden size)." + ) + return 12, 768 + + +def train_loop_per_worker(cfg): + + # --- Determine Device --- + if cfg['device']: + device = cfg['device'] + else: + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + logger.info(f"Using device: {device}") + + # --- Determine Base Model Dimensions --- + # Use the provided --model-name to get dimensions for the CLT config. + # This ensures the CLT matches the architecture activations were generated from. 
    # Resolve CLT dimensions from the base model the activations came from.
    base_model_name = cfg['model_name']
    num_layers, d_model = get_model_dimensions(base_model_name)
    # NOTE(review): get_model_dimensions always returns concrete ints (it falls
    # back to 12/768 on failure), so this None check appears to be dead code —
    # confirm before relying on it for error handling.
    if num_layers is None or d_model is None:
        # Added error handling if dimensions couldn't be determined
        logger.error(f"Could not determine dimensions for model '{base_model_name}'. Exiting.")
        return

    # --- Create CLT Configuration ---
    # Architecture hyperparameters come from the Ray trial config `cfg`.
    clt_config = CLTConfig(
        num_features=cfg['num_features'],
        num_layers=num_layers,
        d_model=d_model,
        activation_fn=cfg['activation_fn'],
        jumprelu_threshold=cfg['jumprelu_threshold'],
        clt_dtype=cfg['clt_dtype'],
    )
    logger.info(f"CLT Config: {clt_config}")

    # --- Create Training Configuration ---
    # Handle 'none' scheduler case: map the sentinel string to a real None.
    lr_scheduler_arg: Optional[Literal["linear", "cosine", "linear_final20"]] = (
        cfg['lr_scheduler'] if cfg['lr_scheduler'] != "none" else None
    )

    # Simplified TrainingConfig instantiation for local source only
    training_config = TrainingConfig(
        # Core Training
        learning_rate=cfg['learning_rate'],
        training_steps=cfg['training_steps'],
        seed=cfg['seed'],
        train_batch_size_tokens=cfg['train_batch_size_tokens'],
        # Activation Source (hardcoded to local_manifest)
        activation_source="local_manifest",
        activation_path=cfg['activation_path'],
        activation_dtype=cfg['activation_dtype'],
        # Normalization
        normalization_method=cfg['normalization_method'],
        # Sampling Strategy
        sampling_strategy=cfg['sampling_strategy'],
        # Loss Coeffs
        sparsity_lambda=cfg['sparsity_lambda'],
        sparsity_c=cfg['sparsity_c'],
        preactivation_coef=cfg['preactivation_coef'],
        # Optimizer & Scheduler
        optimizer=cfg['optimizer'],
        optimizer_beta1=cfg['optimizer_beta1'],
        optimizer_beta2=cfg['optimizer_beta2'],
        lr_scheduler=lr_scheduler_arg,
        # Logging & Checkpointing
        log_interval=cfg['log_interval'],
        eval_interval=cfg['eval_interval'],
        checkpoint_interval=cfg['checkpoint_interval'],
        # Dead Features
        dead_feature_window=cfg['dead_feature_window'],
        # WandB
        enable_wandb=cfg['enable_wandb'],
        wandb_project=cfg['wandb_project'],
        wandb_entity=cfg['wandb_entity'],
        wandb_run_name=cfg['wandb_run_name'],
        wandb_tags=cfg['wandb_tags'],
        # Remote config is not handled by this script
        remote_config=None,
    )
    logger.info(f"Training Config: {training_config}")

    # Build the Ray-aware trainer for this Tune worker.
    trainer = CLTTrainerRay(
        clt_config=clt_config,
        training_config=training_config,
        log_dir=cfg['log_dir'],
        device=device,
        distributed=False,  # distributing model across gpus within a ray tune worker not implemented
        use_ray='tune'
    )

    # NOTE(review): the trained model is never used or returned from this
    # worker function — results flow back to Tune via trainer's report calls.
    # Also note that CLTTrainerRay.train does not read eval_every; the
    # evaluation cadence comes from training_config.eval_interval.
    trained_clt_model = trainer.train(eval_every=training_config.eval_interval)


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Train a Cross-Layer Transcoder (CLT) using pre-generated local activations.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # --- Core Training Parameters ---
    core_group = parser.add_argument_group("Core Training Parameters")
    core_group.add_argument(
        "--activation-path",
        type=str,
        required=True,
        help="Path to the directory containing pre-generated activations (including index.bin, metadata.json, etc.).",
    )
    core_group.add_argument(
        "--output-dir",
        type=str,
        default=f"clt_train_local_{int(time.time())}",
        help="Directory to save logs, checkpoints, and final model.",
    )
    core_group.add_argument(
        "--model-name",
        type=str,
        required=True,
        help="Base model name or path (e.g., 'gpt2', 'gpt2-medium'). Must match the model used for activation generation.",
    )
    core_group.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cuda', 'cpu', 'mps'). 
Auto-detected if None.", + ) + core_group.add_argument( + "--distributed", + action="store_true", + help="Enable distributed training (requires torchrun/appropriate launcher).", + ) + + # --- CLT Model Architecture --- + clt_group = parser.add_argument_group("CLT Model Architecture (CLTConfig)") + clt_group.add_argument( + "--num-features", + type=int, + required=True, + help="Number of features per layer in the CLT.", + ) + # num_layers and d_model are derived from the base model + clt_group.add_argument( + "--activation-fn", + type=str, + choices=["jumprelu", "relu"], + default="jumprelu", + help="Activation function for the CLT.", + ) + clt_group.add_argument( + "--jumprelu-threshold", + type=float, + default=0.03, + help="Threshold for JumpReLU activation (if used).", + ) + clt_group.add_argument( + "--clt-dtype", + type=str, + default=None, + help="Optional data type for the CLT model parameters (e.g., 'float16', 'bfloat16').", + ) + + # --- Training Hyperparameters --- + train_group = parser.add_argument_group("Training Hyperparameters (TrainingConfig)") + train_group.add_argument("--learning-rate", type=float, default=3e-4, help="Optimizer learning rate.") + train_group.add_argument( + "--training-steps", + type=int, + default=50000, + help="Total number of training steps.", + ) + train_group.add_argument( + "--train-batch-size-tokens", + type=int, + default=4096, + help="Target number of tokens per training batch.", + ) + train_group.add_argument( + "--normalization-method", + type=str, + choices=["auto", "none"], + default="auto", + help=( + "Normalization for activation store. 'auto' uses pre-calculated stats " + "(norm_stats.json) from the activation_path. 'none' disables normalization." 
+ ), + ) + train_group.add_argument( + "--sparsity-lambda", + type=float, + default=1e-3, + help="Coefficient for the L1 sparsity penalty.", + ) + train_group.add_argument( + "--sparsity-c", + type=float, + default=1.0, + help="Constant shaping the sparsity penalty (typically 1.0).", + ) + train_group.add_argument( + "--preactivation-coef", + type=float, + default=3e-6, + help="Coefficient for the pre-activation MSE loss term.", + ) + train_group.add_argument( + "--optimizer", + type=str, + choices=["adam", "adamw"], + default="adamw", + help="Optimizer algorithm.", + ) + train_group.add_argument( + "--optimizer-beta1", + type=float, + default=None, + help="Optimizer beta1 value (if using Adam/AdamW).", + ) + train_group.add_argument( + "--optimizer-beta2", + type=float, + default=None, + help="Optimizer beta2 value (if using Adam/AdamW).", + ) + train_group.add_argument( + "--lr-scheduler", + type=str, + choices=["linear", "cosine", "linear_final20", "none"], + default="linear", + help=( + "Learning rate scheduler type. 'linear_final20' keeps LR constant until the last 20% " + "of steps then decays linearly to 0 ('none' to disable)." + ), + ) + train_group.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for reproducibility.", + ) + train_group.add_argument( + "--activation-dtype", + type=str, + default="float32", + help="Data type to load activations as (e.g., 'float32', 'bfloat16'). Should match storage or be compatible.", + ) + train_group.add_argument( + "--dead-feature-window", + type=int, + default=1000, + help="Number of steps of inactivity before a feature is considered 'dead' for evaluation.", + ) + + tune_group = parser.add_argument_group("Hyperparameter Tuning Parameters") + tune_group.add_argument( + "--n-gpus", + type=int, + default=1, + help="Number of GPUs available for hyperparameter tuning." 
+ ) + tune_group.add_argument( + "--n-workers", + type=int, + default=1, + help="Number of trainers to run in parallel for hyperparameter tuning." + ) + + # --- Sampling Strategy --- Added Group + sampling_group = parser.add_argument_group("Sampling Strategy (TrainingConfig)") + sampling_group.add_argument( + "--sampling-strategy", + type=str, + choices=["sequential", "random_chunk"], + default="sequential", + help="Sampling strategy: 'sequential' processes chunks in order per epoch, 'random_chunk' picks a random valid chunk each step.", + ) + + # --- Logging & Checkpointing --- + log_group = parser.add_argument_group("Logging & Checkpointing (TrainingConfig)") + log_group.add_argument( + "--log-interval", + type=int, + default=100, + help="Log training metrics every N steps.", + ) + log_group.add_argument( + "--eval-interval", + type=int, + default=1000, + help="Run evaluation metrics computation every N steps.", + ) + log_group.add_argument( + "--checkpoint-interval", + type=int, + default=1000, + help="Save a training checkpoint every N steps.", + ) + # WandB arguments + log_group.add_argument("--enable-wandb", action="store_true", help="Enable Weights & Biases logging.") + log_group.add_argument("--wandb-project", type=str, default=None, help="WandB project name.") + log_group.add_argument( + "--wandb-entity", + type=str, + default=None, + help="WandB entity (username or team).", + ) + log_group.add_argument( + "--wandb-run-name", + type=str, + default=None, + help="Custom name for the WandB run (defaults to a timestamp).", + ) + log_group.add_argument("--wandb-tags", nargs="+", default=None, help="List of tags for the WandB run.") + + args = parser.parse_args() + + # --- Validation --- + # Simplified validation: activation_path is required by argparse + # No need to check for generation args + + return args + + +def main(): + """Main function to configure and run the CLTTrainer for local activations.""" + args = parse_args() + + # --- Setup Output Directory 
--- + output_dir = Path(args.output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + logger.info(f"Output directory: {output_dir.resolve()}") + + # Save command-line arguments + try: + with open(output_dir / "cli_args.json", "w") as f: + json.dump(vars(args), f, indent=2) + except Exception as e: + logger.warning(f"Could not save command-line args: {e}") + + + # --- Start Tuning --- + # config to pass to workers + hps = { + 'log_dir': str(output_dir.resolve()), + **vars(args) + } + # add actual hyperparameters + hps |= { + 'sparsity_c': ray.tune.grid_search([0.01, 0.03, 0.09, 0.27, 0.81, 2.43]), + 'sparsity_lambda': ray.tune.grid_search([1e-5, 3e-5, 9e-5, 2.7e-4, 8.1e-4, 2.43e-3]), + } + + logger.info("Starting hyperparameter tuning from local activations...") + try: + tuner = ray.tune.Tuner( + ray.tune.with_resources(train_loop_per_worker, {'cpu': 1, 'gpu': args.n_gpus / args.n_workers}), + param_space=hps, + tune_config=ray.tune.TuneConfig(max_concurrent_trials=args.n_workers), + run_config=ray.tune.RunConfig(storage_path=str(output_dir.resolve())) + ) + results = tuner.fit() + logger.info("Tuning complete!") + logger.info(f"Final model and logs saved in: {output_dir.resolve()}") + except Exception as e: + logger.exception(f"Tuning failed: {e}") # Use logger.exception + raise + + +if __name__ == "__main__": + main()