
Add batched QR support to SOAP optimizer #118

Open
RPrenger wants to merge 1 commit into NVIDIA-NeMo:main from RPrenger:rprenger/batched_qr

Conversation


@RPrenger RPrenger commented Mar 5, 2026

SOAP calls torch.linalg.qr once per Kronecker factor per parameter per
power iteration step. For a model with N 2D parameters that is 2N
sequential QR calls per eigenbasis update, each paying ~7.5 ms of CPU
dispatch overhead.

Add a `use_batched_qr` flag that restructures `SOAP.step()` into three
phases — collect, batch-QR, apply — so every QR across all parameters is
dispatched in a single `batched_qr_grouped()` call using multiple CUDA
streams. When the flag is `False` (default) the original code path is
unchanged.

Benchmarked on H100 (12-layer GPT-2-style transformer, 96 QR ops/step):

hidden=1024, mlp=4096 :  1225 ms → 348 ms  (3.5× speedup)
hidden=2048, mlp=4096 :  2210 ms → 887 ms  (2.5× speedup)

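The collect / batch-dispatch / apply restructuring described above can be sketched independently of QR itself. The following is a hypothetical minimal illustration, not the PR's actual code: `fake_batched_op` stands in for `batched_qr_grouped()` (which in the real patch runs the QRs across CUDA streams), and plain dicts stand in for optimizer state.

```python
# Minimal sketch of the collect / batch-dispatch / apply pattern described
# above. fake_batched_op stands in for batched_qr_grouped(); everything here
# is a hypothetical illustration, not the PR's actual code.

def fake_batched_op(matrices):
    # Stand-in for one batched kernel dispatch over the whole group.
    return [[x * 2 for x in m] for m in matrices]

def step_batched(params):
    # Phase 1 (collect): gather one work item per parameter needing an update.
    work = [(i, p["data"]) for i, p in enumerate(params) if p["needs_update"]]
    # Phase 2 (batch): a single dispatch instead of len(work) sequential
    # calls, amortizing the per-call CPU dispatch overhead.
    results = fake_batched_op([data for _, data in work])
    # Phase 3 (apply): scatter results back to their owning parameters.
    for (i, _), new_data in zip(work, results):
        params[i]["data"] = new_data
    return params
```

The point of the restructuring is that phase 2 pays the dispatch cost once per step rather than once per parameter per factor.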

copy-pr-bot bot commented Mar 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.



greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR adds an optional use_batched_qr code path to the SOAP optimizer that restructures the per-parameter step() loop into three phases (collect, batch-QR, apply), dispatching all QR decompositions in a single batched_qr_grouped() call using multiple CUDA streams. The approach is architecturally sound and the fast path (default use_batched_qr=False) is untouched.

Key findings:

  • Correctness bug (critical when use_adaptive_criteria=True): In Phase 1, the code unconditionally projects state["exp_avg"] back to the original basis for every parameter that needs a QR update. The matching forward projection is inside if qr_items:. If all factors for all such parameters satisfy the adaptive criteria on a given step, qr_items is empty, the if qr_items: block is skipped entirely, and exp_avg is left in the original basis going into Phase 3's Adam update—producing a corrupted parameter update. The fix is to move the finalization/forward-projection loop outside the if qr_items: guard.
  • Placeholder install path: The ImportError message contains pip install -e /path/to/batched_qr_torch, a literal placeholder that provides no actionable guidance to users when they lack the required batched_qr dependency.

Confidence Score: 2/5

  • Unsafe to merge: the new use_batched_qr code path has a critical correctness bug when combined with use_adaptive_criteria=True.
  • The new batched QR path contains a logic bug where momentum (exp_avg) is projected back to the original basis in Phase 1 but the forward projection (Phase 2 finalization) is guarded by if qr_items:. When all parameters' QR items satisfy adaptive criteria on a given step, qr_items is empty, finalization is skipped, and momentum stays in the wrong basis through Phase 3's Adam update, silently corrupting the parameter update. While the default code path (use_batched_qr=False) is unchanged and existing users are unaffected, the supported batched-with-adaptive-criteria configuration is broken and must be fixed before this feature can be production-ready.
  • emerging_optimizers/soap/soap.py — specifically Phase 2 finalization (lines 475–502) needs to move outside the if qr_items: guard, and the ImportError message (lines 131–134) needs a helpful path/URL instead of a placeholder.

Last reviewed commit: 3eed6a9

Comment on lines +458 to +502
```python
if qr_items:
    torch.cuda.nvtx.range_push("batched_qr")
    for _ in range(self.power_iter_steps):
        # Collect: matmul kronecker_factor @ Q for every work item
        q_matrices = []
        for item in qr_items:
            with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
                q_matrices.append(item["kf_f"] @ item["Q"])

        # Single batched QR dispatch
        q_results = batched_qr_grouped(q_matrices, num_streams=self.batched_qr_num_streams)

        # Scatter results back
        for item, Q_new in zip(qr_items, q_results):
            item["Q"] = Q_new
    torch.cuda.nvtx.range_pop()

    # ── Finalize eigenbases and project momentum forward ──
    for pw in param_work:
        if not pw["use_qr_batch"]:
            continue

        state = pw["state"]
        updated_eigenbasis_list: list[torch.Tensor] = []

        for ind, (kf, eigenbasis) in enumerate(zip(state["GG"], state["Q"])):
            fi = pw["factor_items"][ind]
            if fi is None:
                # Empty factor (e.g. 1D param with precondition_1d=False)
                updated_eigenbasis_list.append(torch.empty(0, device=kf.device))
            elif "skip" in fi:
                # Adaptive criteria met — keep existing eigenbasis
                updated_eigenbasis_list.append(fi["eigenbasis"])
            else:
                updated_eigenbasis_list.append(fi["Q"])

        state["Q"] = updated_eigenbasis_list

        # Step 3 of eigenbasis update: project momentum to new eigenbasis
        with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
            state["exp_avg"] = precondition(
                state["exp_avg"],
                updated_eigenbasis_list,
                dims=[[0], [0]],
            )
```

Critical: exp_avg left in original basis when all QR items satisfy adaptive criteria

When use_adaptive_criteria=True, Phase 1 unconditionally projects state["exp_avg"] back to the original basis (via precondition(..., dims=[[0], [1]])) for every parameter with use_qr_batch=True (line 403). The matching forward projection (dims=[[0], [0]]) in Phase 2 finalization (lines 496–502) is nested inside if qr_items:.

If on a given step all factors for all parameters that need QR updates satisfy the adaptive criteria, no items are added to qr_items. The entire Phase 2 block (including finalization) is skipped, and state["exp_avg"] remains in the original (unpreconditioned) basis. In Phase 3, it is then fed directly to calculate_adam_update as if it were in the eigenbasis, producing corrupted parameter updates.
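The failure mode reduces to a paired-transform invariant: an unconditional backward projection must always be matched by a forward projection. A toy sketch (hypothetical names, plain Python, not the PR's code) of the guard placement:

```python
# Toy model of the reported guard bug. Phase 1 always rotates momentum out of
# the eigenbasis, so finalization must rotate it back unconditionally, not
# only when qr_items is non-empty. All names here are hypothetical.

def run_step(qr_items, finalize_inside_guard):
    basis = "eigen"         # momentum starts in the eigenbasis
    basis = "original"      # Phase 1: project exp_avg back (unconditional)
    if qr_items:            # Phase 2: batched QR (only when there is work)
        if finalize_inside_guard:
            basis = "eigen" # buggy placement: finalization under the guard
    if not finalize_inside_guard:
        basis = "eigen"     # fixed placement: finalization always runs
    return basis

# With no QR work, the buggy placement leaves momentum in the wrong basis:
# run_step([], finalize_inside_guard=True)  -> "original"
# run_step([], finalize_inside_guard=False) -> "eigen"
```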

Fix: Move the finalization loop (lines 476–502) outside the if qr_items: guard so that momentum forward-projection always executes for every parameter whose exp_avg was projected back in Phase 1, regardless of whether qr_items is empty:

Suggested change

```diff
-if qr_items:
-    torch.cuda.nvtx.range_push("batched_qr")
-    for _ in range(self.power_iter_steps):
-        # Collect: matmul kronecker_factor @ Q for every work item
-        q_matrices = []
-        for item in qr_items:
-            with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
-                q_matrices.append(item["kf_f"] @ item["Q"])
-        # Single batched QR dispatch
-        q_results = batched_qr_grouped(q_matrices, num_streams=self.batched_qr_num_streams)
-        # Scatter results back
-        for item, Q_new in zip(qr_items, q_results):
-            item["Q"] = Q_new
-    torch.cuda.nvtx.range_pop()
-    # ── Finalize eigenbases and project momentum forward ──
-    for pw in param_work:
-        if not pw["use_qr_batch"]:
-            continue
-        state = pw["state"]
-        updated_eigenbasis_list: list[torch.Tensor] = []
-        for ind, (kf, eigenbasis) in enumerate(zip(state["GG"], state["Q"])):
-            fi = pw["factor_items"][ind]
-            if fi is None:
-                # Empty factor (e.g. 1D param with precondition_1d=False)
-                updated_eigenbasis_list.append(torch.empty(0, device=kf.device))
-            elif "skip" in fi:
-                # Adaptive criteria met — keep existing eigenbasis
-                updated_eigenbasis_list.append(fi["eigenbasis"])
-            else:
-                updated_eigenbasis_list.append(fi["Q"])
-        state["Q"] = updated_eigenbasis_list
-        # Step 3 of eigenbasis update: project momentum to new eigenbasis
-        with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
-            state["exp_avg"] = precondition(
-                state["exp_avg"],
-                updated_eigenbasis_list,
-                dims=[[0], [0]],
-            )
+# ── Phase 2: batched QR across all parameters ──
+if qr_items:
+    torch.cuda.nvtx.range_push("batched_qr")
+    for _ in range(self.power_iter_steps):
+        q_matrices = []
+        for item in qr_items:
+            with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
+                q_matrices.append(item["kf_f"] @ item["Q"])
+        q_results = batched_qr_grouped(q_matrices, num_streams=self.batched_qr_num_streams)
+        for item, Q_new in zip(qr_items, q_results):
+            item["Q"] = Q_new
+    torch.cuda.nvtx.range_pop()
+
+# ── Finalize eigenbases and project momentum forward (always, for every param
+# that projected exp_avg back in Phase 1) ──
+for pw in param_work:
+    if not pw["use_qr_batch"]:
+        continue
+    state = pw["state"]
+    updated_eigenbasis_list: list[torch.Tensor] = []
+    for ind, (kf, eigenbasis) in enumerate(zip(state["GG"], state["Q"])):
+        fi = pw["factor_items"][ind]
+        if fi is None:
+            updated_eigenbasis_list.append(torch.empty(0, device=kf.device))
+        elif "skip" in fi:
+            updated_eigenbasis_list.append(fi["eigenbasis"])
+        else:
+            updated_eigenbasis_list.append(fi["Q"])
+    state["Q"] = updated_eigenbasis_list
+    with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
+        state["exp_avg"] = precondition(
+            state["exp_avg"],
+            updated_eigenbasis_list,
+            dims=[[0], [0]],
+        )
```

Comment on lines +131 to +134
```python
raise ImportError(
    "batched_qr package is required when use_batched_qr=True. "
    "Install with: pip install -e /path/to/batched_qr_torch"
) from e
```

Placeholder install path in error message

The error message contains a literal placeholder pip install -e /path/to/batched_qr_torch that provides no actionable guidance to users. When users encounter this error, they will see the placeholder path and not know where to find the package.

Suggested change

```diff
 raise ImportError(
     "batched_qr package is required when use_batched_qr=True. "
-    "Install with: pip install -e /path/to/batched_qr_torch"
+    "See the project README for installation instructions."
 ) from e
```

Or provide a concrete PyPI package name / repository URL if available.
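One way to keep such messages consistent is a small guard helper. This is a hypothetical sketch, not part of the PR: the helper name `require_optional` and its parameters are made up, but the pattern (import the optional dependency, re-raise with an actionable hint) is what the review is asking for.

```python
# Hypothetical optional-dependency guard; the module name and hint text are
# parameters, so the caller (e.g. SOAP.__init__ when use_batched_qr=True)
# controls the message users actually see.
import importlib

def require_optional(module_name: str, hint: str):
    """Import an optional dependency, or fail with an actionable message."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise ImportError(f"{module_name} is required here. {hint}") from e
```

Usage would look like `require_optional("batched_qr", "See the project README for installation instructions.")`.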


skyw commented Mar 5, 2026

Thanks @RPrenger, this is on my list but I never got time to do it.

Couple of questions:

  • Is this done? It looks like a lot of things are still missing. If it's not ready, please mark it as a draft.
  • What is batched_qr? Searching Google and PyPI didn't turn it up. cuSolver does have a batched API. Depending on the nature of the code, we'll need to decide whether/how we can take on a dependency.

The code itself also needs more work; the change is too intrusive. The DCO check also failed; instructions are in CONTRIBUTING.md.


skyw commented Mar 5, 2026

Also a heads up: I'm refactoring SOAP before the Megatron integration, as in #117. This will need a rebase once the refactoring is done.
