Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 265 additions & 0 deletions emerging_optimizers/soap/soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from emerging_optimizers import registry, scalar_optimizers, utils
from emerging_optimizers.soap import soap_utils
from emerging_optimizers.utils import FP32MatmulPrecT
from emerging_optimizers.utils import eig as eig_utils


__all__ = [
Expand Down Expand Up @@ -75,6 +76,9 @@ class SOAP(opt_mixin.WeightDecayMixin, optim.Optimizer):
use_kl_shampoo: Whether to use KL-Shampoo correction.
correct_shampoo_beta_bias: Whether to correct shampoo beta bias. Decoupled it from correct_bias for
testability because reference implementation of Soap doesn't bias correct shampoo beta.
use_batched_qr: Whether to batch all QR decompositions across parameters into a single
multi-stream dispatch using the batched_qr package. Requires the batched_qr package.
batched_qr_num_streams: Number of CUDA streams for the batched QR dispatch.
"""

def __init__(
Expand All @@ -101,6 +105,8 @@ def __init__(
max_update_rms: float = 0.0,
use_kl_shampoo: bool = False,
correct_shampoo_beta_bias: bool | None = None,
use_batched_qr: bool = False,
batched_qr_num_streams: int = 64,
) -> None:
self.precondition_frequency = precondition_frequency
self.adam_warmup_steps = adam_warmup_steps
Expand All @@ -116,6 +122,16 @@ def __init__(
self.power_iter_steps = power_iter_steps
self.max_update_rms = max_update_rms
self.use_kl_shampoo = use_kl_shampoo
self.use_batched_qr = use_batched_qr
self.batched_qr_num_streams = batched_qr_num_streams
if use_batched_qr:
try:
import batched_qr # noqa: F401
except ImportError as e:
raise ImportError(
"batched_qr package is required when use_batched_qr=True. "
"Install with: pip install -e /path/to/batched_qr_torch"
) from e
Comment on lines +131 to +134
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Placeholder install path in error message

The error message contains a literal placeholder pip install -e /path/to/batched_qr_torch that provides no actionable guidance to users. When users encounter this error, they will see the placeholder path and not know where to find the package.

Suggested change
raise ImportError(
"batched_qr package is required when use_batched_qr=True. "
"Install with: pip install -e /path/to/batched_qr_torch"
) from e
raise ImportError(
"batched_qr package is required when use_batched_qr=True. "
"See the project README for installation instructions."
) from e

Or provide a concrete PyPI package name / repository URL if available.

if correct_shampoo_beta_bias is not None:
self.correct_shampoo_beta_bias = correct_shampoo_beta_bias
else:
Expand Down Expand Up @@ -149,6 +165,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
else:
loss = closure()

if self.use_batched_qr:
self._step_batched_qr()
return loss

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
Expand Down Expand Up @@ -290,6 +310,251 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
return loss


def _step_batched_qr(self) -> None:
    """Performs a single optimization step with batched QR decompositions.

    Restructures the per-parameter loop into collect → batch QR → apply phases
    to batch all QR decompositions across all parameters into a single
    multi-stream dispatch using the ``batched_qr`` package.

    Phases:
        1. Per-parameter state init, Kronecker factor updates, and preparation
           of eigenbasis work (projecting momentum back to the original basis
           for parameters whose eigenbasis will be refreshed via QR).
        2. One batched, multi-stream QR dispatch over all collected factors
           (power iteration), when there is any QR work to do.
        3. Finalization: install the updated eigenbases and project momentum
           forward into the new basis. This runs for EVERY parameter whose
           momentum was projected back in Phase 1 — even when ``qr_items`` is
           empty (e.g. all factors satisfied the adaptive criteria) —
           otherwise ``exp_avg`` would be left in the original basis and then
           consumed by the Adam update as if it were in the eigenbasis,
           corrupting the parameter update.
        4. Weight decay, preconditioning, Adam update, and parameter update.
    """
    from batched_qr import batched_qr_grouped

    # ── Phase 1: init, kronecker factor updates, prepare eigenbasis work ──
    param_work: list[dict] = []
    # Flat list of work_item dicts whose "Q" needs batched QR
    qr_items: list[dict] = []

    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue

            grad = p.grad
            state = self.state[p]

            if "step" not in state:
                state["step"] = 0

            curr_iter_1_based = state["step"] + 1

            if state["step"] == 0:
                # Fresh state must not carry stale moments/factors (e.g. from
                # a mismatched checkpoint restore).
                assert all(key not in state for key in ["exp_avg", "exp_avg_sq", "GG"]), (
                    "exp_avg and exp_avg_sq and GG should not be initialized at step 0. "
                    "Some mismatch has been created likely in checkpointing"
                )
                state["exp_avg"] = torch.zeros_like(grad)
                state["exp_avg_sq"] = torch.zeros_like(grad)
                state["GG"] = init_kronecker_factors(
                    grad,
                    precondition_1d=self.precondition_1d,
                )

            # Kronecker factor update (same as baseline path)
            if not self.use_kl_shampoo:
                kronecker_factor_update_fn = partial(
                    update_kronecker_factors,
                    precondition_1d=self.precondition_1d,
                )
            else:
                # KL-Shampoo needs the current eigenbasis; initialize it to
                # identity on the very first step.
                if "Q" not in state:
                    assert state["step"] == 0, (
                        f"Q should already be initialized at step {state['step']}, Some mismatch has been created "
                        "likely in checkpointing"
                    )
                    state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
                kronecker_factor_update_fn = partial(
                    update_kronecker_factors_kl_shampoo,
                    eigenbasis_list=state["Q"],
                    eps=group["eps"],
                )

            shampoo_beta = group["shampoo_beta"]
            if self.correct_shampoo_beta_bias:
                # Bias-correct the EMA coefficient for the current iterate.
                shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based)

            torch.cuda.nvtx.range_push("update_kronecker_factors")
            with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                kronecker_factor_update_fn(kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=shampoo_beta)
            torch.cuda.nvtx.range_pop()

            needs_update = _is_eigenbasis_update_step(
                state["step"],
                self.adam_warmup_steps,
                self.precondition_frequency,
            )
            # The first eigenbasis update (end of Adam warmup) always uses
            # eigh so that "Q" exists before any QR-based refresh.
            use_eigh = self.use_eigh if state["step"] != self.adam_warmup_steps else True
            use_qr_batch = needs_update and not use_eigh

            pw: dict = {
                "p": p,
                "grad": grad,
                "group": group,
                "state": state,
                "curr_iter_1_based": curr_iter_1_based,
                "use_qr_batch": use_qr_batch,
                "factor_items": [],
            }
            param_work.append(pw)

            # ── Prepare factors for batched QR ──
            torch.cuda.nvtx.range_push("Update eigen basis")
            if use_qr_batch:
                # Step 1 of eigenbasis update: project momentum back to original basis
                with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
                    state["exp_avg"] = precondition(
                        state["exp_avg"],
                        state["Q"],
                        dims=[[0], [1]],
                    )

                # Cast exp_avg_sq to float (matching get_eigenbasis_qr force_float=True)
                if state["exp_avg_sq"].dtype != torch.float:
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(torch.float)

                for ind, (kf, eigenbasis) in enumerate(zip(state["GG"], state["Q"])):
                    if kf.numel() == 0:
                        pw["factor_items"].append(None)
                        continue

                    kf_f = kf.to(torch.float)
                    eig_f = eigenbasis.to(torch.float)

                    approx_eigvals = eig_utils.conjugate(kf_f, eig_f, diag=True)

                    should_update = True
                    if self.use_adaptive_criteria:
                        should_update = not eig_utils.met_approx_eigvals_criteria(
                            kf_f, approx_eigvals, self.adaptive_update_tolerance
                        )

                    if should_update:
                        # Sort eigenvalues and reorder exp_avg_sq + eigenbasis columns
                        sort_idx = torch.argsort(approx_eigvals, descending=True)
                        state["exp_avg_sq"] = state["exp_avg_sq"].index_select(ind, sort_idx)
                        Q = eig_f[:, sort_idx]

                        item: dict = {"kf_f": kf_f, "Q": Q}
                        pw["factor_items"].append(item)
                        qr_items.append(item)
                    else:
                        # Adaptive criteria met — keep existing eigenbasis (cast to float)
                        pw["factor_items"].append({"skip": True, "eigenbasis": eig_f})

            elif needs_update and use_eigh:
                # First eigenbasis update or user requested eigh — use existing code path
                with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
                    state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum(
                        kronecker_factor_list=state["GG"],
                        eigenbasis_list=state.get("Q", None),
                        exp_avg_sq=state["exp_avg_sq"],
                        momentum=state["exp_avg"],
                        use_eigh=True,
                        use_adaptive_criteria=self.use_adaptive_criteria,
                        adaptive_update_tolerance=self.adaptive_update_tolerance,
                        power_iter_steps=self.power_iter_steps,
                    )
            torch.cuda.nvtx.range_pop()

    # ── Phase 2: batched QR across all parameters ──
    if qr_items:
        torch.cuda.nvtx.range_push("batched_qr")
        for _ in range(self.power_iter_steps):
            # Collect: matmul kronecker_factor @ Q for every work item
            q_matrices = []
            for item in qr_items:
                with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
                    q_matrices.append(item["kf_f"] @ item["Q"])

            # Single batched QR dispatch
            q_results = batched_qr_grouped(q_matrices, num_streams=self.batched_qr_num_streams)

            # Scatter results back
            for item, Q_new in zip(qr_items, q_results):
                item["Q"] = Q_new
        torch.cuda.nvtx.range_pop()

    # ── Finalize eigenbases and project momentum forward ──
    # NOTE: this loop intentionally runs OUTSIDE the `if qr_items:` guard.
    # Phase 1 already projected exp_avg back to the original basis for every
    # parameter with use_qr_batch=True; the forward projection below must
    # therefore always execute, even when the adaptive criteria skipped all
    # QR work and qr_items is empty.
    for pw in param_work:
        if not pw["use_qr_batch"]:
            continue

        state = pw["state"]
        updated_eigenbasis_list: list[torch.Tensor] = []

        for ind, (kf, eigenbasis) in enumerate(zip(state["GG"], state["Q"])):
            fi = pw["factor_items"][ind]
            if fi is None:
                # Empty factor (e.g. 1D param with precondition_1d=False)
                updated_eigenbasis_list.append(torch.empty(0, device=kf.device))
            elif "skip" in fi:
                # Adaptive criteria met — keep existing eigenbasis
                updated_eigenbasis_list.append(fi["eigenbasis"])
            else:
                updated_eigenbasis_list.append(fi["Q"])

        state["Q"] = updated_eigenbasis_list

        # Step 3 of eigenbasis update: project momentum to new eigenbasis
        with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
            state["exp_avg"] = precondition(
                state["exp_avg"],
                updated_eigenbasis_list,
                dims=[[0], [0]],
            )

    # ── Phase 3: weight decay, precondition, Adam, parameter update ──
    for pw in param_work:
        p = pw["p"]
        grad = pw["grad"]
        group = pw["group"]
        state = pw["state"]

        self._apply_weight_decay_inplace(
            p,
            grad,
            group["lr"],
            group["weight_decay"],
        )

        grad_projected = grad
        torch.cuda.nvtx.range_push("precondition")
        if state["step"] >= self.adam_warmup_steps:
            # Project gradient into the eigenbasis before the Adam update.
            with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                grad_projected = precondition(
                    grad=grad,
                    eigenbasis_list=state["Q"],
                    dims=[[0], [0]],
                )
        torch.cuda.nvtx.range_pop()

        adam_update = scalar_optimizers.calculate_adam_update(
            grad_projected,
            state["exp_avg"],
            state["exp_avg_sq"],
            group["betas"],
            self.correct_bias,
            self.use_nesterov,
            pw["curr_iter_1_based"],
            group["eps"],
        )

        torch.cuda.nvtx.range_push("precondition")
        if state["step"] >= self.adam_warmup_steps:
            # Project the Adam update back to the original parameter basis.
            with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                precond_update = precondition(
                    grad=adam_update,
                    eigenbasis_list=state.get("Q", None),
                    dims=[[0], [1]],
                )
        else:
            precond_update = adam_update
        torch.cuda.nvtx.range_pop()

        _clip_update_rms_in_place(precond_update, self.max_update_rms)
        p.add_(precond_update, alpha=-group["lr"])

        state["step"] += 1


@torch.no_grad() # type: ignore[misc]
def init_kronecker_factors(
grad: torch.Tensor,
Expand Down