diff --git a/tuned_lens/model_surgery.py b/tuned_lens/model_surgery.py
index e17ec7f..078b3b9 100644
--- a/tuned_lens/model_surgery.py
+++ b/tuned_lens/model_surgery.py
@@ -117,11 +117,16 @@ def get_final_norm(model: Model) -> Norm:
         ),
     ):
         final_layer_norm = base_model.ln_f
-    elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
-        final_layer_norm = base_model.norm
-    elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
-        final_layer_norm = base_model.norm
-    elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
+    elif isinstance(
+        base_model,
+        (
+            models.llama.modeling_llama.LlamaModel,
+            models.mistral.modeling_mistral.MistralModel,
+            models.gemma.modeling_gemma.GemmaModel,
+            models.qwen3.modeling_qwen3.Qwen3Model,
+            models.gpt_oss.modeling_gpt_oss.GptOssModel,
+        ),
+    ):
         final_layer_norm = base_model.norm
     else:
         raise NotImplementedError(f"Unknown model type {type(base_model)}")
@@ -166,11 +171,16 @@ def get_transformer_layers(model: Model) -> tuple[str, th.nn.ModuleList]:
         ),
     ):
         path_to_layers += ["h"]
-    elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
-        path_to_layers += ["layers"]
-    elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
-        path_to_layers += ["layers"]
-    elif isinstance(base_model, models.gemma.modeling_gemma.GemmaModel):
+    elif isinstance(
+        base_model,
+        (
+            models.gpt_oss.modeling_gpt_oss.GptOssModel,
+            models.qwen3.modeling_qwen3.Qwen3Model,
+            models.llama.modeling_llama.LlamaModel,
+            models.mistral.modeling_mistral.MistralModel,
+            models.gemma.modeling_gemma.GemmaModel,
+        ),
+    ):
         path_to_layers += ["layers"]
     else:
         raise NotImplementedError(f"Unknown model type {type(base_model)}")
diff --git a/tuned_lens/muon.py b/tuned_lens/muon.py
new file mode 100644
index 0000000..4d1a7fb
--- /dev/null
+++ b/tuned_lens/muon.py
@@ -0,0 +1,188 @@
+"""Adapted from https://github.com/KellerJordan/Muon/blob/master/muon.py."""
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+
+def quintic_newtonschulz(G: Tensor, steps: int) -> Tensor:
+    """Newton-Schulz iteration to compute the orthogonalization of G.
+
+    We opt to use a quintic iteration whose coefficients are selected to maximize the
+    slope at zero. For the purpose of minimizing steps, it turns out to be empirically
+    effective to keep increasing the slope at zero even beyond the point where the
+    iteration no longer converges all the way to one everywhere on the interval. This
+    iteration therefore does not produce UV^T but rather something like US'V^T where S'
+    is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    # batched implementation by @scottjmaddox, put into practice by @YouJiacheng
+    assert G.ndim >= 2
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    # Perform the NS iterations
+    for _ in range(steps):
+        # quintic strategy adapted from suggestion by @jxbz, @leloykun, @YouJiacheng
+        A = X @ X.mT
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+
+class Muon(torch.optim.Optimizer):
+    """Muon - MomentUm Orthogonalized by Newton-schulz.
+
+    Muon is a generalized steepest descent optimizer using the spectral norm on the
+    matrix-valued parameters. This means it always updates in the direction which
+    locally reduces the loss as much as possible, while constraining the update to have
+    a spectral norm given by the learning rate. It achieves this using a Newton-Schulz
+    iteration to orthogonalize the stochastic gradient (or momentum buffer) for each
+    matrix in the model before taking a step.
+
+    The spectral norm is an intuitive heuristic because, roughly speaking, it measures
+    the maximum change to the activations of a layer that can be caused by a change to
+    its weights. By constraining the worst-case change to the activations, we ensure
+    that we do not destabilize training.
+
+    This optimizer is unlikely to work well with small batch sizes, since it strongly
+    magnifies small singular values, which will be noisy given a small minibatch.
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        momentum: float = 0.95,
+        nesterov: bool = True,
+        weight_decay: float = 0.1,
+        ns_steps: int = 5,
+        ddp: bool = True,
+    ):
+        """Initialize the Muon optimizer.
+
+        You will need to set the `ddp` flag to `False` if you are using FSDP or some
+        similar scheme where parameters are sharded across multiple processes.
+
+        Args:
+            params: Iterable of parameters to optimize.
+            lr: The learning rate used by the internal SGD.
+            momentum: The momentum used by the internal SGD.
+            nesterov: Whether to use Nesterov-style momentum in the internal SGD.
+            ns_steps: The number of Newton-Schulz iteration steps to use.
+            weight_decay: The decoupled weight decay to apply at each step.
+            ddp: Whether to distribute the work of Newton-Schulz across multiple
+                processes. This assumes that every process has all the parameters.
+        """
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            nesterov=nesterov,
+            ns_steps=ns_steps,
+            weight_decay=weight_decay,
+        )
+        self.rank = dist.get_rank() if dist.is_initialized() and ddp else 0
+        self.world_size = dist.get_world_size() if dist.is_initialized() and ddp else 1
+
+        # Distributed Data Parallel (DDP) setup
+        if dist.is_initialized() and ddp:
+            param_groups = []
+
+            # Check that the user isn't doing some weird model parallelism
+            devices = {p.device for p in params}
+            device = next(iter(devices))
+            assert device.type == "cuda", "Muon only supports CUDA devices."
+            assert len(devices) == 1, "Muon does not support model parallelism."
+
+            # Group parameters by their device and number of elements. For each group,
+            # we pre-allocate a buffer to store the updates from all ranks.
+            for size in {p.numel() for p in params}:
+                b = torch.empty(
+                    self.world_size, size, dtype=torch.bfloat16, device=device
+                )
+                group = dict(
+                    params=[p for p in params if p.numel() == size],
+                    update_buffer=b,
+                    update_buffer_views=[b[i] for i in range(self.world_size)],
+                )
+                param_groups.append(group)
+
+            super().__init__(param_groups, defaults)
+        else:
+            super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self):
+        """Performs a single optimization step."""
+        for group in self.param_groups:
+            params: list[Tensor] = group["params"]
+
+            # Apply decoupled weight decay to all parameters. This doesn't require any
+            # communication, since it's a simple element-wise operation.
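+            # Concretely: p <- p * (1 - lr * weight_decay), as in AdamW.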
+            if group["weight_decay"] > 0.0:
+                for p in params:
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+
+            # These will be None / empty list if we're not using DDP
+            update_buffer: Tensor | None = group.get("update_buffer", None)
+            update_buffer_views: list[Tensor] = group.get("update_buffer_views", [])
+
+            beta = group["momentum"]
+            handle = None
+            params_world = None
+
+            def update_prev():  # optimized implementation contributed by @YouJiacheng
+                assert handle is not None and params_world is not None
+                handle.wait()
+
+                for p_world, g_world in zip(params_world, update_buffer_views):
+                    # Scale heuristic: 0.2 * sqrt(max dim) keeps the update RMS
+                    # roughly comparable to AdamW's.
+                    scale = 0.2 * max(p_world.shape) ** 0.5
+                    p_world.add_(g_world.view_as(p_world), alpha=-group["lr"] * scale)
+
+            for i in range(0, len(params), self.world_size):
+                # Compute Muon update
+                if i + self.rank < len(params):
+                    p = params[i + self.rank]
+                    state = self.state[p]
+
+                    g = p.grad
+                    assert g is not None
+
+                    # Apply momentum
+                    if beta > 0.0:
+                        if "exp_avg" not in state:
+                            state["exp_avg"] = torch.zeros_like(g)
+
+                        buf: Tensor = state["exp_avg"].lerp_(g, 1 - beta)
+                        g = g.lerp_(buf, beta) if group["nesterov"] else buf
+
+                    if g.ndim == 4:  # for the case of conv filters
+                        g = g.view(len(g), -1)
+
+                    g = quintic_newtonschulz(g, steps=group["ns_steps"])
+                else:
+                    g = update_buffer_views[self.rank]
+
+                if self.world_size > 1:
+                    # async all_gather instead of sync all_reduce by @YouJiacheng
+                    if i > 0:
+                        update_prev()
+
+                    handle = dist.all_gather_into_tensor(
+                        update_buffer, g.flatten(), async_op=True
+                    )
+                    params_world = params[i : i + self.world_size]
+                else:
+                    scale = 0.2 * max(params[i].shape) ** 0.5
+                    params[i].add_(g, alpha=-group["lr"] * scale)
+
+            if self.world_size > 1:
+                update_prev()
diff --git a/tuned_lens/nn/lenses.py b/tuned_lens/nn/lenses.py
index 0ef308b..e748a7a 100644
--- a/tuned_lens/nn/lenses.py
+++ b/tuned_lens/nn/lenses.py
@@ -12,6 +12,7 @@ from transformers import PreTrainedModel
 
 from tuned_lens import load_artifacts
+from tuned_lens.model_surgery import Norm
 from tuned_lens.nn.unembed import Unembed
 
 logger = logging.getLogger(__name__)
@@ -68,13 +69,16 @@ def __init__(
     def from_model(
         cls,
         model: PreTrainedModel,
+        *,
+        final_norm: Optional[Norm] = None,
     ) -> "LogitLens":
         """Create a LogitLens from a pretrained model.
 
        Args:
            model: A pretrained model from the transformers library you wish to
                inspect.
+            final_norm: An optional final layer normalization to apply.
        """
-        unembed = Unembed(model)
+        unembed = Unembed(model, final_norm=final_norm)
        return cls(unembed)
 
@@ -182,6 +186,8 @@ def from_model(
        model: PreTrainedModel,
        model_revision: Optional[str] = None,
        bias: bool = True,
+        *,
+        final_norm: Optional[Norm] = None,
    ) -> "TunedLens":
        """Create a lens from a pretrained model.
 
@@ -189,11 +195,12 @@ def from_model(
            model: The model to create the lens from.
            model_revision: The git revision of the model to used.
            bias: Whether to use a bias in the linear translators.
+            final_norm: An optional final layer normalization to apply.
 
        Returns:
            A TunedLens instance.
        """
-        unembed = Unembed(model)
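+        # If final_norm is None, Unembed falls back to the model's own final norm.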
""" - unembed = Unembed(model) + unembed = Unembed(model, final_norm=final_norm) config = TunedLensConfig( base_model_name_or_path=model.config.name_or_path, base_model_revision=model_revision, diff --git a/tuned_lens/nn/unembed.py b/tuned_lens/nn/unembed.py index abb51f2..dc8a9c7 100644 --- a/tuned_lens/nn/unembed.py +++ b/tuned_lens/nn/unembed.py @@ -38,14 +38,17 @@ class Unembed(th.nn.Module): def __init__( self, model: model_surgery.Model, + *, + final_norm: Optional[model_surgery.Norm] = None, ): """Initialize unmebed. Args: model: A HuggingFace model from which to extract the unembedding matrix. + final_norm: An optional final layer normalization to apply before the """ super().__init__() - final_norm = model_surgery.get_final_norm(model) + final_norm = final_norm or model_surgery.get_final_norm(model) unembedding_matrix = model_surgery.get_unembedding_matrix(model) self.final_norm = copy.deepcopy(final_norm) diff --git a/tuned_lens/scripts/eval_loop.py b/tuned_lens/scripts/eval_loop.py index caf49d3..bcb5317 100644 --- a/tuned_lens/scripts/eval_loop.py +++ b/tuned_lens/scripts/eval_loop.py @@ -223,8 +223,7 @@ def execute(self): # Note since we are not training we can just move the lens to the device. # No need to use DDP lenses = {name: lens.to(self.dist.device) for name, lens in lenses.items()} - dl = self.dist.dataloader(data) - dl.seed(self.seed) + dl = self.dist.dataloader(data, self.seed) for lens in lenses.values(): lens.eval() diff --git a/tuned_lens/scripts/ingredients.py b/tuned_lens/scripts/ingredients.py index 3adcfcc..86ce6df 100644 --- a/tuned_lens/scripts/ingredients.py +++ b/tuned_lens/scripts/ingredients.py @@ -19,7 +19,7 @@ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.distributed.optim import ZeroRedundancyOptimizer from torch.nn.parallel import DistributedDataParallel as DDP -from torchdata import dataloader2, datapipes +from torch.utils.data.distributed import DistributedSampler from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -33,6 +33,7 @@ chunk_and_tokenize, ) from tuned_lens.model_surgery import get_transformer_layers +from tuned_lens.muon import Muon from tuned_lens.nn.lenses import Lens from tuned_lens.utils import ( TreeType, @@ -47,11 +48,13 @@ class Data: """Configuration for the dataset.""" - name: list[str] = field(default_factory=lambda: ["the_pile", "all"], nargs="*") + name: list[str] = field( + default_factory=lambda: ["EleutherAI/SmolLM2-135M-10B"], nargs="*" + ) """Name of dataset to use. 
     """Name of dataset to use. Can either be a local .jsonl file or a name suitable to
     be passed to the HuggingFace load_dataset function."""
 
-    split: str = "validation"
+    split: str = "train"
     """Split of the dataset to use."""
 
     text_column: str = "text"
@@ -195,17 +198,18 @@ class OptimizerOption(enum.Enum):
 
     ADAM = "adam"
     SGD = "sgd"
+    MUON = "muon"
 
 
 @dataclass
 class Optimizer:
     """Configuration for the optimizer."""
 
-    weight_decay: float = 1e-3
+    weight_decay: float = 0.01
     """Weight decay coefficient."""
 
     lr_scale: float = 1.0
-    """The default LR (1e-3 for Adam, 1.0 for SGD) is scaled by this factor."""
+    """The default LR (1e-3 for Adam, 0.01 for Muon, 1.0 for SGD) is scaled by this
+    factor."""
 
     momentum: float = 0.9
     """Momentum coefficient for SGD, or beta1 for Adam."""
@@ -235,10 +239,11 @@ def create_scheduler(
         scheduler = get_linear_schedule_with_warmup(
             opt, self.warmup_steps, num_steps - self.warmup_steps
         )
-
         return scheduler
 
-    def create_optim(self, params: list[th.nn.Parameter]) -> th.optim.Optimizer:
+    def create_optim(
+        self, params: list[th.nn.Parameter], fsdp: bool = False
+    ) -> list[th.optim.Optimizer]:
         """Create the optimizer."""
         # Don't train things that don't need gradients
         β = self.momentum
@@ -262,7 +267,27 @@ def create_optim(
                 lr=self.lr_scale * 1e-3,
                 weight_decay=self.weight_decay,
             )
-            opt_class = th.optim.Adam
+            opt_class = th.optim.AdamW
+        elif self.optimizer == OptimizerOption.MUON:
+            # The default Muon LR is 1e-3, but we find we can go 10x higher
+            lr = self.lr_scale * 0.01
+
+            # Muon only handles 2D params, AdamW handles the rest
+            params_1d = [p for p in params if p.ndim != 2 and p.requires_grad]
+            params_2d = [p for p in params if p.ndim == 2 and p.requires_grad]
+            opts = [
+                Muon(
+                    params_2d,
+                    ddp=not fsdp,
+                    lr=lr,
+                    momentum=β,
+                    weight_decay=self.weight_decay,
+                ),
+                th.optim.AdamW(
+                    params_1d, lr=lr, weight_decay=self.weight_decay, betas=(β, 0.999)
+                ),
+            ]
+            return opts
         else:
             raise ValueError(f"Unknown optimizer '{self.optimizer}'")
 
@@ -271,7 +296,7 @@ def create_optim(self, params: list[th.nn.Parameter]) -> th.optim.Optimizer:
         else:
             opt = opt_class(params, **config)  # type: ignore[call-arg]
 
-        return opt
+        return [opt]
 
     def per_parameter_optim_state_size(self) -> int:
         """The number of elements in the optimizer state per parameter."""
@@ -365,30 +390,27 @@ def distribute_lens(self, lens: Lens) -> Union[DDP, Lens]:
         """Distribute the lens using DistributedDataParallel and send lens to device."""
         logger.debug(f"Sending Lens to device {self.device}")
         if self.world_size > 1:
-            lens.to(self.device)
             logger.debug("Distributing the lens across the GPUS using DDP ...")
             return DDP(lens, device_ids=[self.local_rank], find_unused_parameters=True)
-        else:
-            return lens.to(self.device)
 
-    def dataloader(
-        self,
-        dataset: Dataset,
-    ) -> dataloader2.DataLoader2:
-        """Shard the dataset based on local rank."""
-        dp = datapipes.iter.IterableWrapper(dataset)
-        if self.world_size > 1:
-            rs = dataloader2.DistributedReadingService()
-        else:
-            rs = None
+        return lens
 
-        if self.dataloader_shuffle:
-            dp = dp.shuffle()
-
-        dp = dp.sharding_filter()
-        dp = dp.batch(self.per_gpu_batch_size)
-        dp = dp.collate()
-        return dataloader2.DataLoader2(dp, reading_service=rs)
+    def dataloader(self, dataset: Dataset, seed: int) -> th.utils.data.DataLoader:
+        """Shard the dataset based on local rank."""
+        sampler = DistributedSampler(
+            dataset,
+            num_replicas=self.world_size,
+            rank=self.rank,
+            seed=seed,
+            shuffle=self.dataloader_shuffle,
+        )
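+        # The sampler above hands each rank its own (optionally seed-shuffled) shard;
+        # the DataLoader below just batches that shard and pins it for faster copies.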
+        dl = th.utils.data.DataLoader(
+            dataset,
+            batch_size=self.per_gpu_batch_size,
+            sampler=sampler,
+            pin_memory=True,
+        )
+        return dl
 
     def init(self):
         """Initialize distributed process group if started with elastic launch."""
@@ -396,7 +418,9 @@ def init(self):
         local_rank = os.environ.get("LOCAL_RANK")
         if local_rank is not None:
             dist.init_process_group(
-                "nccl", timeout=timedelta(seconds=self.nccl_timeout)
+                "nccl",
+                device_id=th.device("cuda", int(local_rank)),
+                timeout=timedelta(seconds=self.nccl_timeout),
             )
             assert (
                 th.cuda.is_available()
diff --git a/tuned_lens/scripts/train_loop.py b/tuned_lens/scripts/train_loop.py
index 4898855..ef37434 100644
--- a/tuned_lens/scripts/train_loop.py
+++ b/tuned_lens/scripts/train_loop.py
@@ -14,7 +14,6 @@
 from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR
-from torchdata.dataloader2 import DataLoader2
 from tqdm.auto import trange
 from transformers import PreTrainedModel
 
@@ -36,10 +35,10 @@ class LossChoice(enum.Enum):
 class State:
     """All of the stateful information in the training loop."""
 
-    dataloader: DataLoader2
+    dataloader: th.utils.data.DataLoader
     lens: TunedLens
-    opt: Optimizer
-    scheduler: LambdaLR
+    opts: list[Optimizer]
+    schedulers: list[LambdaLR]
     wandb_id: Optional[str]
     nats_to_bpb: float
     step: int = 0
@@ -51,22 +50,23 @@ def load(self, snapshot_file: Path, device: th.device) -> None:
         self.step = snapshot["step"]
         self.wandb_id = snapshot["wandb_id"]
         self.lens.load_state_dict(snapshot["lens"])
-        self.opt.load_state_dict(snapshot["optim"])
-        self.scheduler.load_state_dict(snapshot["scheduler"])
-        self.dataloader.load_state_dict(snapshot["dataloader"])
+        for opt, state in zip(self.opts, snapshot["optim"]):
+            opt.load_state_dict(state)
+        for scheduler, state in zip(self.schedulers, snapshot["scheduler"]):
+            scheduler.load_state_dict(state)
 
     def save(self, snapshot_file: Path) -> None:
         """Save a snapshot file."""
         logger.info(f"Saving snapshot to {snapshot_file}...")
-        if isinstance(self.opt, ZeroRedundancyOptimizer):
-            self.opt.consolidate_state_dict()
+        for opt in self.opts:
+            if isinstance(opt, ZeroRedundancyOptimizer):
+                opt.consolidate_state_dict()
 
         th.save(
             {
                 "lens": self.lens.state_dict(),
-                "optim": self.opt.state_dict(),
-                "scheduler": self.scheduler.state_dict(),
-                "dataloader": self.dataloader.state_dict(),
+                "optim": [opt.state_dict() for opt in self.opts],
+                "scheduler": [scheduler.state_dict() for scheduler in self.schedulers],
                 "step": self.step,
                 "wandb_id": self.wandb_id,
             },
@@ -99,6 +99,9 @@ class Train:
     lens_name_or_path: Optional[str] = field(alias=["-l"], default=None)
     """Name of a pretrained lens to load for fine-tuning."""
 
+    final_norm_path: Optional[str] = None
+    """Dotted path to the submodule to use as the lens's final layer norm."""
+
     bias_only: Optional[bool] = field(action="store_true")
     """Train only the bias term."""
 
@@ -133,24 +136,23 @@ def get_lens(self, model: PreTrainedModel) -> TunedLens:
         """Load or create a TunedLens model."""
         if self.lens_name_or_path is None:
             logger.info("Randomly initializing lens...")
-            lens = TunedLens.from_model(model)
+
+            if self.final_norm_path is not None:
+                final_norm = model.get_submodule(self.final_norm_path)
+            else:
+                final_norm = None
+
+            lens = TunedLens.from_model(model, final_norm=final_norm)
         else:
             logger.info("Loading pretrained lens...")
             lens = TunedLens.from_model_and_pretrained(model, self.lens_name_or_path)
 
-        dtypes = {p.dtype for p in lens.parameters()}
-        assert (
-            len(dtypes) == 1
-        ), f"Expected all parameters to have the same dtype, got {dtypes}"
-
-        lens_dtype = next(iter(dtypes))
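+        # Train the lens in float32, whatever dtype the base model was loaded in.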
{dtypes}" - - lens_dtype = next(iter(dtypes)) + lens.float() lens_size = sum(p.numel() * p.element_size() for p in lens.parameters()) # Include the optimizer state in the memory usage num_bytes = lens_size * (self.opt.per_parameter_optim_state_size() + 1) - logger.info( - f"Tuned lens memory usage: {num_bytes / 2 ** 20:.2f} MB in {lens_dtype}" - ) + logger.info(f"Tuned lens memory usage: {num_bytes / 2 ** 20:.2f} MB") if self.bias_only: logger.info("Freezing the matrix weights to train only the bias terms.") @@ -186,7 +188,6 @@ def _init_logging(self, model_name: str, lens: TunedLens, wandb_id: Optional[str def _log( self, - opt: th.optim.Optimizer, step: int, losses: dict[str, list[float]], tuned_lens: TunedLens, @@ -206,24 +207,6 @@ def _log( # Log statistics about optimizer & probes for i, probe in enumerate(tuned_lens): name = "input" if i == 0 else f"{i - 1}.ffn" - states = [opt.state[p] for p in probe.parameters()] - - # Approximate the true grad norm using the optimizer's moving - # avg - corr = 1 - self.opt.momentum**step - if self.opt.optimizer == "sgd" and not self.opt.zero: - log_dict["grad_norm/" + name] = th.cat( - [ - # Undo PyTorch's scaling of the gradient by - # 1 / (1 - β) - (1 - self.opt.momentum) * s["momentum_buffer"].flatten() / corr - for s in states - ] - ).norm() - elif self.opt.optimizer == "adam" and not self.opt.zero: - log_dict["grad_norm/" + name] = th.cat( - [s["exp_avg"].flatten() / corr for s in states] - ).norm() if isinstance(probe, th.nn.Linear): log_dict["bias_norm/" + name] = probe.bias.data.norm() @@ -333,12 +316,13 @@ def setup(self) -> tuple[State, Union[PreTrainedModel, FSDP], int]: assert model and tokenizer and data and lens and nats_to_bpb logger.debug(f"Creating data loader and setting seed to {self.seed} ...") - dl = self.dist.dataloader(data) - dl.seed(self.seed) + dl = self.dist.dataloader(data, self.seed) logger.debug("Creating optimizer and scheduler ...") params = [p for p in lens.parameters() if p.requires_grad] - opt = self.opt.create_optim(params) - scheduler = self.opt.create_scheduler(opt, self.num_steps) + + lens.to(self.dist.device) + opts = self.opt.create_optim(params, fsdp=self.dist.fsdp) + schedulers = [self.opt.create_scheduler(opt, self.num_steps) for opt in opts] ddp_lens = self.dist.distribute_lens(lens) @@ -346,8 +330,8 @@ def setup(self) -> tuple[State, Union[PreTrainedModel, FSDP], int]: step=0, wandb_id=self._get_wandb_id(), lens=ddp_lens, # type: ignore - opt=opt, - scheduler=scheduler, + opts=opts, + schedulers=schedulers, dataloader=dl, nats_to_bpb=nats_to_bpb, ) @@ -451,13 +435,15 @@ def execute(self): step, rem = divmod(batch_idx, grad_acc_steps) if rem == grad_acc_steps - 1: th.nn.utils.clip_grad_norm_(state.lens.parameters(), 1.0) - state.opt.step() - state.opt.zero_grad(set_to_none=False) - state.scheduler.step() + for opt in state.opts: + opt.step() + opt.zero_grad(set_to_none=False) + for scheduler in state.schedulers: + scheduler.step() # Unwrap the lens from DDP if needed lens = getattr(state.lens, "module", state.lens) - self._log(state.opt, step, losses, lens, state.nats_to_bpb) + self._log(step, losses, lens, state.nats_to_bpb) losses.clear() state.step = step + 1 if (