diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py
index 1951480..50d0638 100644
--- a/emerging_optimizers/soap/soap.py
+++ b/emerging_optimizers/soap/soap.py
@@ -25,6 +25,7 @@
 from emerging_optimizers import registry, scalar_optimizers, utils
 from emerging_optimizers.soap import soap_utils
 from emerging_optimizers.utils import FP32MatmulPrecT
+from emerging_optimizers.utils import eig as eig_utils
 
 
 __all__ = [
@@ -75,6 +76,9 @@ class SOAP(opt_mixin.WeightDecayMixin, optim.Optimizer):
         use_kl_shampoo: Whether to use KL-Shampoo correction.
         correct_shampoo_beta_bias: Whether to correct shampoo beta bias. Decoupled it from correct_bias for testability
             because reference implementation of Soap doesn't bias correct shampoo beta.
+        use_batched_qr: Whether to batch all QR decompositions across parameters into a single
+            multi-stream dispatch using the batched_qr package. Requires the batched_qr package.
+        batched_qr_num_streams: Number of CUDA streams for the batched QR dispatch. Must be >= 1 if use_batched_qr.
     """
 
     def __init__(
@@ -101,6 +105,8 @@ def __init__(
         max_update_rms: float = 0.0,
         use_kl_shampoo: bool = False,
         correct_shampoo_beta_bias: bool | None = None,
+        use_batched_qr: bool = False,
+        batched_qr_num_streams: int = 64,
     ) -> None:
         self.precondition_frequency = precondition_frequency
         self.adam_warmup_steps = adam_warmup_steps
@@ -116,6 +122,19 @@ def __init__(
         self.power_iter_steps = power_iter_steps
         self.max_update_rms = max_update_rms
         self.use_kl_shampoo = use_kl_shampoo
+        self.use_batched_qr = use_batched_qr
+        self.batched_qr_num_streams = batched_qr_num_streams
+        if use_batched_qr:
+            # Fail fast on an unusable stream count instead of deep inside the first step().
+            if batched_qr_num_streams < 1:
+                raise ValueError(f"batched_qr_num_streams must be >= 1, got {batched_qr_num_streams}")
+            try:
+                import batched_qr  # noqa: F401
+            except ImportError as e:
+                raise ImportError(
+                    "batched_qr package is required when use_batched_qr=True. "
+                    "Install with: pip install -e /path/to/batched_qr_torch"
+                ) from e
         if correct_shampoo_beta_bias is not None:
             self.correct_shampoo_beta_bias = correct_shampoo_beta_bias
         else:
@@ -149,6 +168,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
         else:
             loss = closure()
 
+        if self.use_batched_qr:
+            self._step_batched_qr()
+            return loss
+
         for group in self.param_groups:
             for p in group["params"]:
                 if p.grad is None:
@@ -290,6 +313,265 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
         return loss
 
 
+    def _step_batched_qr(self) -> None:
+        """Performs a single optimization step with batched QR decompositions.
+
+        Restructures the per-parameter loop into collect → batch QR → apply phases
+        to batch all QR decompositions across all parameters into a single multi-stream dispatch.
+
+        The step runs in three phases:
+
+        1. Per parameter: initialize state on the first step, update the Kronecker factors and, on
+           eigenbasis update steps, either refresh the eigenbasis through the eigh path or stage the
+           sorted eigenbasis of every factor that needs a QR refresh.
+        2. Across all parameters: run ``power_iter_steps`` rounds of ``kronecker_factor @ Q`` followed
+           by one ``batched_qr_grouped`` call over all staged factors, then install the new eigenbases
+           and project the momentum into them.
+        3. Per parameter: apply weight decay, compute the Adam update (projected through the eigenbasis
+           once ``adam_warmup_steps`` have passed), clip it and apply it to the parameter.
+
+        ``state["step"]`` is advanced by one for every parameter that has a gradient.
+        """
+        from batched_qr import batched_qr_grouped
+
+        # ── Phase 1: init, kronecker factor updates, prepare eigenbasis work ──
+        param_work: list[dict] = []
+        # Flat list of work_item dicts whose "Q" needs batched QR
+        qr_items: list[dict] = []
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                state = self.state[p]
+
+                if "step" not in state:
+                    state["step"] = 0
+
+                curr_iter_1_based = state["step"] + 1
+
+                if state["step"] == 0:
+                    assert all(key not in state for key in ["exp_avg", "exp_avg_sq", "GG"]), (
+                        "exp_avg and exp_avg_sq and GG should not be initialized at step 0. "
+                        "Some mismatch has been created likely in checkpointing"
+                    )
+                    state["exp_avg"] = torch.zeros_like(grad)
+                    state["exp_avg_sq"] = torch.zeros_like(grad)
+                    state["GG"] = init_kronecker_factors(
+                        grad,
+                        precondition_1d=self.precondition_1d,
+                    )
+
+                # Kronecker factor update (same as baseline path)
+                if not self.use_kl_shampoo:
+                    kronecker_factor_update_fn = partial(
+                        update_kronecker_factors,
+                        precondition_1d=self.precondition_1d,
+                    )
+                else:
+                    if "Q" not in state:
+                        assert state["step"] == 0, (
+                            f"Q should already be initialized at step {state['step']}, Some mismatch has been created "
+                            "likely in checkpointing"
+                        )
+                        state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
+                    kronecker_factor_update_fn = partial(
+                        update_kronecker_factors_kl_shampoo,
+                        eigenbasis_list=state["Q"],
+                        eps=group["eps"],
+                    )
+
+                shampoo_beta = group["shampoo_beta"]
+                if self.correct_shampoo_beta_bias:
+                    shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based)
+
+                torch.cuda.nvtx.range_push("update_kronecker_factors")
+                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
+                    kronecker_factor_update_fn(kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=shampoo_beta)
+                torch.cuda.nvtx.range_pop()
+
+                needs_update = _is_eigenbasis_update_step(
+                    state["step"],
+                    self.adam_warmup_steps,
+                    self.precondition_frequency,
+                )
+                # The first eigenbasis update always uses eigh: the QR path needs an existing basis to refine.
+                use_eigh = self.use_eigh if state["step"] != self.adam_warmup_steps else True
+                use_qr_batch = needs_update and not use_eigh
+
+                pw: dict = {
+                    "p": p,
+                    "grad": grad,
+                    "group": group,
+                    "state": state,
+                    "curr_iter_1_based": curr_iter_1_based,
+                    "use_qr_batch": use_qr_batch,
+                    "factor_items": [],
+                }
+                param_work.append(pw)
+
+                # ── Prepare factors for batched QR ──
+                torch.cuda.nvtx.range_push("Update eigen basis")
+                if use_qr_batch:
+                    # Step 1 of eigenbasis update: project momentum back to original basis
+                    with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
+                        state["exp_avg"] = precondition(
+                            state["exp_avg"],
+                            state["Q"],
+                            dims=[[0], [1]],
+                        )
+
+                    # Cast exp_avg_sq to float (matching get_eigenbasis_qr force_float=True)
+                    if state["exp_avg_sq"].dtype != torch.float:
+                        state["exp_avg_sq"] = state["exp_avg_sq"].to(torch.float)
+
+                    for ind, (kf, eigenbasis) in enumerate(zip(state["GG"], state["Q"])):
+                        if kf.numel() == 0:
+                            pw["factor_items"].append(None)
+                            continue
+
+                        kf_f = kf.to(torch.float)
+                        eig_f = eigenbasis.to(torch.float)
+
+                        approx_eigvals = eig_utils.conjugate(kf_f, eig_f, diag=True)
+
+                        should_update = True
+                        if self.use_adaptive_criteria:
+                            should_update = not eig_utils.met_approx_eigvals_criteria(
+                                kf_f, approx_eigvals, self.adaptive_update_tolerance
+                            )
+
+                        if should_update:
+                            # Sort eigenvalues and reorder exp_avg_sq + eigenbasis columns
+                            sort_idx = torch.argsort(approx_eigvals, descending=True)
+                            state["exp_avg_sq"] = state["exp_avg_sq"].index_select(ind, sort_idx)
+                            Q = eig_f[:, sort_idx]
+
+                            item: dict = {"kf_f": kf_f, "Q": Q}
+                            pw["factor_items"].append(item)
+                            qr_items.append(item)
+                        else:
+                            # Adaptive criteria met — keep existing eigenbasis (cast to float)
+                            pw["factor_items"].append({"skip": True, "eigenbasis": eig_f})
+
+                elif needs_update and use_eigh:
+                    # First eigenbasis update or user requested eigh — use existing code path
+                    with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
+                        state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum(
+                            kronecker_factor_list=state["GG"],
+                            eigenbasis_list=state.get("Q", None),
+                            exp_avg_sq=state["exp_avg_sq"],
+                            momentum=state["exp_avg"],
+                            use_eigh=True,
+                            use_adaptive_criteria=self.use_adaptive_criteria,
+                            adaptive_update_tolerance=self.adaptive_update_tolerance,
+                            power_iter_steps=self.power_iter_steps,
+                        )
+                torch.cuda.nvtx.range_pop()
+
+        # ── Phase 2: batched QR across all parameters ──
+        if qr_items:
+            torch.cuda.nvtx.range_push("batched_qr")
+            for _ in range(self.power_iter_steps):
+                # Collect: matmul kronecker_factor @ Q for every work item. The precision context
+                # is entered once per round rather than once per item.
+                with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
+                    q_matrices = [item["kf_f"] @ item["Q"] for item in qr_items]
+
+                # Single batched QR dispatch
+                q_results = batched_qr_grouped(q_matrices, num_streams=self.batched_qr_num_streams)
+
+                # Scatter results back
+                for item, Q_new in zip(qr_items, q_results):
+                    item["Q"] = Q_new
+            torch.cuda.nvtx.range_pop()
+
+        # ── Finalize eigenbases and project momentum forward ──
+        for pw in param_work:
+            if not pw["use_qr_batch"]:
+                continue
+
+            state = pw["state"]
+            updated_eigenbasis_list: list[torch.Tensor] = []
+
+            # factor_items holds one entry per (kronecker factor, eigenbasis) pair visited in
+            # phase 1, in order, so it can be walked alongside the factors directly.
+            for kf, fi in zip(state["GG"], pw["factor_items"]):
+                if fi is None:
+                    # Empty factor (e.g. 1D param with precondition_1d=False)
+                    updated_eigenbasis_list.append(torch.empty(0, device=kf.device))
+                elif "skip" in fi:
+                    # Adaptive criteria met — keep existing eigenbasis
+                    updated_eigenbasis_list.append(fi["eigenbasis"])
+                else:
+                    updated_eigenbasis_list.append(fi["Q"])
+
+            state["Q"] = updated_eigenbasis_list
+
+            # Step 3 of eigenbasis update: project momentum to new eigenbasis
+            with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
+                state["exp_avg"] = precondition(
+                    state["exp_avg"],
+                    updated_eigenbasis_list,
+                    dims=[[0], [0]],
+                )
+
+        # ── Phase 3: weight decay, precondition, Adam, parameter update ──
+        for pw in param_work:
+            p = pw["p"]
+            grad = pw["grad"]
+            group = pw["group"]
+            state = pw["state"]
+
+            self._apply_weight_decay_inplace(
+                p,
+                grad,
+                group["lr"],
+                group["weight_decay"],
+            )
+
+            grad_projected = grad
+            torch.cuda.nvtx.range_push("precondition")
+            if state["step"] >= self.adam_warmup_steps:
+                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
+                    grad_projected = precondition(
+                        grad=grad,
+                        eigenbasis_list=state["Q"],
+                        dims=[[0], [0]],
+                    )
+            torch.cuda.nvtx.range_pop()
+
+            adam_update = scalar_optimizers.calculate_adam_update(
+                grad_projected,
+                state["exp_avg"],
+                state["exp_avg_sq"],
+                group["betas"],
+                self.correct_bias,
+                self.use_nesterov,
+                pw["curr_iter_1_based"],
+                group["eps"],
+            )
+
+            torch.cuda.nvtx.range_push("precondition")
+            if state["step"] >= self.adam_warmup_steps:
+                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
+                    precond_update = precondition(
+                        grad=adam_update,
+                        eigenbasis_list=state.get("Q", None),
+                        dims=[[0], [1]],
+                    )
+            else:
+                precond_update = adam_update
+            torch.cuda.nvtx.range_pop()
+
+            _clip_update_rms_in_place(precond_update, self.max_update_rms)
+            p.add_(precond_update, alpha=-group["lr"])
+
+            state["step"] += 1
+
+
 @torch.no_grad()  # type: ignore[misc]
 def init_kronecker_factors(
     grad: torch.Tensor,