Spectron optimizer for low-rank LLM pretraining #104
mkhona-nvidia wants to merge 15 commits into NVIDIA-NeMo:main from
Conversation
Signed-off-by: mikail <mkhona@nvidia.com>
Greptile Summary

Adds Spectron, a low-rank spectral optimizer with orthogonalized momentum for LLM pretraining, based on https://arxiv.org/abs/2602.12429. The optimizer maintains low-rank factorizations as optimizer state, while the model weights themselves remain dense.

Major changes:
Critical issues preventing production use:
Confidence Score: 1/5
Important Files Changed
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    Start([Optimizer Step]) --> CheckGrad{Gradient exists?}
    CheckGrad -->|No| End([Skip parameter])
    CheckGrad -->|Yes| InitCheck{First step?}
    InitCheck -->|Yes| SVDInit[SVD Initialization:<br/>W = U·S·V^T<br/>A = U·√S, B = V·√S]
    SVDInit --> InitState[Initialize:<br/>momentum_A, momentum_B<br/>u_A, u_B vectors]
    InitState --> Compute
    InitCheck -->|No| Compute[Compute factor gradients:<br/>grad_A = grad @ B<br/>grad_B = grad^T @ A]
    Compute --> WD[Apply weight decay<br/>to both factors]
    WD --> Momentum[Update momentum:<br/>momentum_A ← β·momentum_A + (1-β)·grad_A<br/>momentum_B ← β·momentum_B + (1-β)·grad_B]
    Momentum --> NS[Orthogonalize using<br/>Newton-Schulz iteration<br/>requires float32]
    NS --> PowerIter[Power iteration:<br/>estimate σ_A, σ_B<br/>spectral radii]
    PowerIter --> Scale[Scale learning rate:<br/>η_scaled = η / (σ_A + σ_B + 1)]
    Scale --> Update[Update factors:<br/>A ← A - η_scaled·orth_momentum_A<br/>B ← B - η_scaled·orth_momentum_B]
    Update --> Reconstruct[Reconstruct weight:<br/>W ← A @ B^T]
    Reconstruct --> End
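The flow above can be sketched end to end in NumPy. This is a minimal illustration, not the PR's implementation: the function names (`spectron_step`, `newton_schulz_orth`), the quintic Newton-Schulz coefficients (borrowed from the Muon optimizer family), and the use of an exact SVD norm in place of the PR's power iteration are all assumptions made for brevity.

```python
import numpy as np

def newton_schulz_orth(M, steps=5):
    """Approximate the orthogonal factor U·V^T of M via Newton-Schulz (sketch).

    Coefficients are the quintic ones used in the Muon family; assumed here.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    X = M / (np.linalg.norm(M) + 1e-7)  # normalize so singular values <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def spectron_step(W, grad, state, lr=1e-2, beta=0.9, rank=4):
    """One hypothetical Spectron step following the PR's flowchart."""
    if "A" not in state:  # first step: SVD init, W ≈ A @ B^T
        U, S, Vt = np.linalg.svd(W, full_matrices=False)
        sqrt_S = np.sqrt(S[:rank])
        state["A"] = U[:, :rank] * sqrt_S
        state["B"] = Vt[:rank].T * sqrt_S
        state["mA"] = np.zeros_like(state["A"])
        state["mB"] = np.zeros_like(state["B"])
    A, B = state["A"], state["B"]
    grad_A = grad @ B        # shape (m, r)
    grad_B = grad.T @ A      # shape (n, r)
    # EMA momentum on the factor gradients
    state["mA"] = beta * state["mA"] + (1 - beta) * grad_A
    state["mB"] = beta * state["mB"] + (1 - beta) * grad_B
    oA = newton_schulz_orth(state["mA"])
    oB = newton_schulz_orth(state["mB"])
    # Spectral norms via exact SVD for brevity; the PR uses power iteration.
    sigma_A = np.linalg.norm(A, 2)
    sigma_B = np.linalg.norm(B, 2)
    scaled_lr = lr / (sigma_A + sigma_B + 1)  # scaling as read from the flowchart
    state["A"] = A - scaled_lr * oA
    state["B"] = B - scaled_lr * oB
    return state["A"] @ state["B"].T  # reconstruct the dense weight
```

Calling `spectron_step` repeatedly with the same `state` dict mimics the optimizer loop: the first call performs the SVD initialization, later calls reuse the stored factors and momenta.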
Last reviewed commit: d2686bb
    from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
    from emerging_optimizers.orthogonalized_optimizers.scion import *
    from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *
    from emerging_optimizers.orthogonalized_optimizers.spectron import *
    (No newline at end of file)
Missing trailing newline
The file is missing a trailing newline after the new import line. This is flagged by most linters and POSIX standards, and the previous version of the file had one.
Suggested change (same import line, with the trailing newline restored):

    from emerging_optimizers.orthogonalized_optimizers.spectron import *
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
/ok to test 326f3f6
    factor_B.add_(orth_momentum_B, alpha=-scaled_lr)

    # Reconstruct full weight matrix: W = A @ B^T
    p.copy_(factor_A @ factor_B.mT)
I am guessing this reconstruction is for compatibility with the rest of the library. Otherwise the whole implementation looks correct.
I leave the weights of the model as a single matrix, but do the low-rank decomposition as optimizer states (rather than having the low-rank factored weights as two separate matrices in the model, which makes them harder to access inside the optimizer). This is functionally identical but makes the software easier to use.
Pauljanson002 left a comment
This implementation is correct, with one minor difference: in our work we train the models with only the factors. In this implementation the model weights remain in dense form, but optimization happens with the low-rank factors, reducing optimizer state.
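A back-of-envelope comparison shows why keeping only the factors in optimizer state helps. This is a sketch with assumed accounting: Adam is modeled as two dense moment buffers per parameter matrix, and Spectron as two rank-r momentum buffers plus the two factors.

```python
def adam_state_elems(m, n):
    # Adam keeps two dense moment buffers of shape (m, n).
    return 2 * m * n

def spectron_state_elems(m, n, r):
    # Assumed accounting: two factors plus two momentum buffers,
    # each of shape (m, r) or (n, r).
    return 2 * (m * r + n * r)

# Example: a 4096x4096 projection with rank-256 factors.
m, n, r = 4096, 4096, 256
print(adam_state_elems(m, n) / spectron_state_elems(m, n, r))  # → 8.0
```

At rank 256 on a square 4096-wide layer, the low-rank state is 8x smaller under this accounting; the ratio grows as the rank shrinks relative to the matrix dimensions.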
    with utils.fp32_matmul_precision("highest"):
        grad_A = grad @ factor_B  # shape: (m, r)
        grad_B = grad.mT @ factor_A  # shape: (n, r)
Gradient dtype mismatch with non-fp32 parameters
grad = p.grad inherits p's dtype, but factor_B is always float32 (initialized from torch.linalg.svd(p.float(), ...)). When the parameter is bfloat16 — the standard dtype for LLM pretraining, which is the stated use case — the line grad @ factor_B will raise a RuntimeError at runtime:
RuntimeError: expected scalar type Float but found BFloat16
Even if PyTorch silently promotes the dtype in some contexts, momentum_A.lerp_(grad_A, ...) on line 187 will then fail because momentum_A is float32 but grad_A would be bfloat16.
The gradient should be explicitly cast to float32 before the matmul:
Suggested change:

     with utils.fp32_matmul_precision("highest"):
    -    grad_A = grad @ factor_B  # shape: (m, r)
    -    grad_B = grad.mT @ factor_A  # shape: (n, r)
    +    grad_A = grad.float() @ factor_B  # shape: (m, r)
    +    grad_B = grad.float().mT @ factor_A  # shape: (n, r)
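The mismatch is easy to reproduce in isolation. A minimal sketch, assuming a bfloat16 gradient (as in bf16 pretraining) against float32 factors; the tensor names are illustrative, not the PR's:

```python
import torch

# Hypothetical shapes: a bf16 parameter gradient against fp32 factors.
grad = torch.randn(8, 6, dtype=torch.bfloat16)
factor_B = torch.randn(6, 4, dtype=torch.float32)

try:
    grad @ factor_B  # mixed-dtype matmul is rejected, no implicit promotion
except RuntimeError as e:
    print(type(e).__name__)  # → RuntimeError

grad_A = grad.float() @ factor_B  # explicit cast, as in the suggested fix
print(grad_A.dtype)               # → torch.float32
```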
Added the Spectron optimizer

Also added a power iteration and Rayleigh quotient method for computing the spectral norm to utils/eig.py. Based on https://arxiv.org/abs/2602.12429
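The power iteration with a Rayleigh-quotient readout mentioned above can be sketched as follows. This is a NumPy illustration with an assumed function name; the PR's version in utils/eig.py may differ in details such as the stopping criterion and warm-starting from stored vectors.

```python
import numpy as np

def power_iteration_sigma_max(M, iters=20, eps=1e-8):
    """Estimate the largest singular value of M (a sketch).

    Power-iterates on M^T M, then reads sigma_max off via the
    Rayleigh quotient of M^T M at the converged vector.
    """
    v = np.random.default_rng(0).standard_normal(M.shape[1])
    v /= np.linalg.norm(v) + eps
    for _ in range(iters):
        u = M @ v          # one application of M
        v = M.T @ u        # ... and of M^T, i.e. one step on M^T M
        v /= np.linalg.norm(v) + eps
    # Rayleigh quotient v^T (M^T M) v approximates sigma_max^2.
    return float(np.sqrt(v @ (M.T @ (M @ v))))

M = np.diag([3.0, 1.0, 0.5])
print(power_iteration_sigma_max(M))  # ≈ 3.0
```

Convergence is geometric in the gap between the top two singular values, which is why in the optimizer a handful of iterations per step (warm-started from the previous step's vector) is usually enough.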