From a70ed893a4bfbafe5e4b0934dab5cf9fc6136138 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 14:23:41 -0800 Subject: [PATCH 01/15] collect report for CPU tests Signed-off-by: Hao Wu --- tests/ci/L0_Tests_CPU.sh | 17 +++++++++++++---- tests/test_distributed_muon_utils_cpu.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/ci/L0_Tests_CPU.sh b/tests/ci/L0_Tests_CPU.sh index ca6a1a9b..30c09b2b 100644 --- a/tests/ci/L0_Tests_CPU.sh +++ b/tests/ci/L0_Tests_CPU.sh @@ -13,10 +13,19 @@ # limitations under the License. export TORCH_COMPILE_DISABLE=1 +mkdir -p test-results/tests/ + error=0 -torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1 -torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1 -coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu -v -2 || error=1 -coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu -v -2 || error=1 +for n in 8 4; do + torchrun --nproc_per_node=$n --no-python coverage run -p \ + tests/test_distributed_muon_utils_cpu.py \ + --xml_output_file="test-results/tests/test_distributed_muon_utils_cpu_n${n}.xml" \ + -v -2 || error=1 +done + +for test in "tests/test_scalar_optimizers.py" "tests/test_procrustes_step.py"; do + report_name="test-results/${test}.xml" + coverage run -p --source=emerging_optimizers $test --device=cpu -v -2 --xml_output_file="$report_name" || error=1 +done exit "${error}" diff --git a/tests/test_distributed_muon_utils_cpu.py b/tests/test_distributed_muon_utils_cpu.py index f9ebddd6..7a0f1b8a 100644 --- a/tests/test_distributed_muon_utils_cpu.py +++ b/tests/test_distributed_muon_utils_cpu.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import sys import numpy as np import torch @@ -211,6 +212,17 @@ def test_1step_close_to_non_distributed(self, shape, partition_dim, tp_mode): if __name__ == "__main__": torch.distributed.init_process_group(backend="gloo") torch.set_float32_matmul_precision("highest") + + rank = torch.distributed.get_rank() + + for i, arg in enumerate(sys.argv): + if arg.startswith("--xml_output_file="): + base, ext = os.path.splitext(arg) + + # Attach rank to the output file name + sys.argv[i] = f"{base}_rank{rank}{ext}" + break + absltest.main() torch.distributed.destroy_process_group() From 1ca40c3985b8225c79fcd610710719ab93317bef Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 15:16:28 -0800 Subject: [PATCH 02/15] add test for empty dim Signed-off-by: Hao Wu --- tests/test_soap_utils.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/tests/test_soap_utils.py b/tests/test_soap_utils.py index a3a2e4f8..b4f27be7 100644 --- a/tests/test_soap_utils.py +++ b/tests/test_soap_utils.py @@ -102,16 +102,39 @@ def test_get_eigenbasis_qr(self, N: int, M: int) -> None: # Also check that "exp_avg_sq" remains in the state with same shape if not merging self.assertEqual(exp_avg_sq_new.shape, (M, N)) + def test_get_eigenbasis_qr_empty_factor(self) -> None: + """Tests get_eigenbasis_qr with an empty (numel()==0) kronecker factor.""" + torch.manual_seed(0) + N = 4 + g = torch.randn(N, N, device=self.device) + L = g.mm(g.t()).float() + + empty_factor = torch.empty(0, 0, device=self.device) + kronecker_factor_list = [L, empty_factor] + eigenbasis_list = [torch.randn(N, N, device=self.device), torch.empty(0, device=self.device)] + exp_avg_sq = torch.abs(torch.randn(N, N, device=self.device)) + + Q_new_list, exp_avg_sq_new = soap_utils.get_eigenbasis_qr( + kronecker_factor_list=kronecker_factor_list, + eigenbasis_list=eigenbasis_list, + exp_avg_sq=exp_avg_sq, + ) + + self.assertEqual(len(Q_new_list), 2) + self.assertEqual(Q_new_list[0].shape, (N, N)) + self.assertEqual(Q_new_list[1].numel(), 0) + @parameterized.parameters( # type: ignore[misc] {"dims": [128, 512]}, {"dims": []}, + {"dims": [64, 0, 32]}, ) def test_get_eigenbasis_eigh(self, dims: list[int]) -> None: """Tests the get_eigenbasis_eigh function.""" kronecker_factor_list = [] for dim in dims: if dim == 0: - kronecker_factor_list.append(torch.empty(0, 0)) + kronecker_factor_list.append(torch.empty(0)) continue k_factor = torch.randn(dim, dim, device=self.device) @@ -125,7 +148,7 @@ def test_get_eigenbasis_eigh(self, dims: list[int]) -> None: for i, Q in enumerate(Q_list): orig_dim = dims[i] if orig_dim == 0: - self.assertEqual(Q.shape, (0, 0)) + self.assertEqual(Q.shape, (0,)) continue self.assertEqual(Q.shape, (orig_dim, orig_dim)) From 3b0fc7e43371622ca6816da5932ecd2639e6dc9e Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 15:55:27 -0800 Subject: [PATCH 03/15] add test for all_eigenbases_met_criteria, fix bug Signed-off-by: Hao Wu --- emerging_optimizers/utils/eig.py | 4 ++-- tests/test_soap_utils.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/emerging_optimizers/utils/eig.py b/emerging_optimizers/utils/eig.py index d8343fe5..93dab088 100644 --- a/emerging_optimizers/utils/eig.py +++ b/emerging_optimizers/utils/eig.py @@ -115,12 +115,12 @@ def met_approx_eigvals_criteria( tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix. Returns: - perform_update: Whether to update eigenbasis this iteration + Whether eigenbasis met criteria and don't need to be updated """ matrix_norm = torch.linalg.norm(kronecker_factor) diagonal_norm = torch.linalg.norm(approx_eigvals) - return diagonal_norm >= (1 - tolerance) * matrix_norm + return diagonal_norm <= (1 - tolerance) * matrix_norm def orthogonal_iteration( diff --git a/tests/test_soap_utils.py b/tests/test_soap_utils.py index b4f27be7..43c9d610 100644 --- a/tests/test_soap_utils.py +++ b/tests/test_soap_utils.py @@ -176,6 +176,32 @@ def test_get_eigenbasis_eigh(self, dims: list[int]) -> None: msg=f"Matrix {i} was not properly diagonalized. Off-diagonal norm: {off_diagonal_norm}", ) + def test_all_eigenbases_met_criteria_empty_list_returns_true(self) -> None: + kronecker_factor_list = [] + eigenbasis_list = [] + self.assertTrue(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) + + @parameterized.parameters( + {"N": 4}, + {"N": 16}, + {"N": 33}, + ) + def test_all_eigenbases_met_criteria_random_eigenbasis_returns_false(self, N: int) -> None: + kronecker_factor_list = [torch.randn(N, N, device=self.device)] + eigenbasis_list = [torch.randn(N, N, device=self.device)] + self.assertFalse(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) + + @parameterized.parameters( + {"N": 4}, + {"N": 16}, + {"N": 33}, + ) + def test_all_eigenbases_met_criteria_identity_true_eigenbasis_returns_true(self, N: int) -> None: + kronecker_factor_list = [torch.randn(N, N, device=self.device)] + + eigenbasis_list = [torch.linalg.eigh(K).eigenvectors for K in kronecker_factor_list] + self.assertTrue(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) + if __name__ == "__main__": absltest.main() From 216bad2776fb2d187f7270f9b4f0bd44455a6921 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 16:00:59 -0800 Subject: [PATCH 04/15] use 2d empty as place holder for kronecker factor Signed-off-by: Hao Wu --- emerging_optimizers/soap/soap_utils.py | 4 ++-- emerging_optimizers/utils/eig.py | 2 +- tests/test_soap_utils.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/emerging_optimizers/soap/soap_utils.py b/emerging_optimizers/soap/soap_utils.py index e22a42b9..c12ce732 100644 --- a/emerging_optimizers/soap/soap_utils.py +++ b/emerging_optimizers/soap/soap_utils.py @@ -86,7 +86,7 @@ def get_eigenbasis_eigh( for kronecker_factor in kronecker_factor_list: if kronecker_factor.numel() == 0: - updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device)) + updated_eigenbasis_list.append(torch.empty(0, 0, device=kronecker_factor.device)) continue _, Q = eig_utils.eigh_with_fallback(kronecker_factor, force_double=False, eps=eps) updated_eigenbasis_list.append(Q) @@ -148,7 +148,7 @@ def get_eigenbasis_qr( updated_eigenbasis_list: TensorList = [] for ind, (kronecker_factor, eigenbasis) in enumerate(zip(kronecker_factor_list, eigenbasis_list, strict=True)): if kronecker_factor.numel() == 0: - updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device)) + updated_eigenbasis_list.append(torch.empty(0, 0, device=kronecker_factor.device)) continue approx_eigvals = eig_utils.conjugate(kronecker_factor, eigenbasis, diag=True) diff --git a/emerging_optimizers/utils/eig.py b/emerging_optimizers/utils/eig.py index 93dab088..0c48ebc9 100644 --- a/emerging_optimizers/utils/eig.py +++ b/emerging_optimizers/utils/eig.py @@ -115,7 +115,7 @@ def met_approx_eigvals_criteria( tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix. Returns: - Whether eigenbasis met criteria and don't need to be updated + Whether eigenbasis meet criteria and don't need to be updated """ matrix_norm = torch.linalg.norm(kronecker_factor) diagonal_norm = torch.linalg.norm(approx_eigvals) diff --git a/tests/test_soap_utils.py b/tests/test_soap_utils.py index 43c9d610..798eebfb 100644 --- a/tests/test_soap_utils.py +++ b/tests/test_soap_utils.py @@ -134,7 +134,7 @@ def test_get_eigenbasis_eigh(self, dims: list[int]) -> None: kronecker_factor_list = [] for dim in dims: if dim == 0: - kronecker_factor_list.append(torch.empty(0)) + kronecker_factor_list.append(torch.empty(0, 0)) continue k_factor = torch.randn(dim, dim, device=self.device) @@ -148,7 +148,7 @@ def test_get_eigenbasis_eigh(self, dims: list[int]) -> None: for i, Q in enumerate(Q_list): orig_dim = dims[i] if orig_dim == 0: - self.assertEqual(Q.shape, (0,)) + self.assertEqual(Q.shape, (0, 0)) continue self.assertEqual(Q.shape, (orig_dim, orig_dim)) From 2ffc2440e72b5abbf7db246507f8463eba12e364 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 16:14:44 -0800 Subject: [PATCH 05/15] fix criteria check bug and add test Signed-off-by: Hao Wu --- emerging_optimizers/utils/eig.py | 2 +- tests/test_soap_utils.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/emerging_optimizers/utils/eig.py b/emerging_optimizers/utils/eig.py index 0c48ebc9..5944842a 100644 --- a/emerging_optimizers/utils/eig.py +++ b/emerging_optimizers/utils/eig.py @@ -120,7 +120,7 @@ def met_approx_eigvals_criteria( matrix_norm = torch.linalg.norm(kronecker_factor) diagonal_norm = torch.linalg.norm(approx_eigvals) - return diagonal_norm <= (1 - tolerance) * matrix_norm + return tolerance * matrix_norm >= (matrix_norm - diagonal_norm) def orthogonal_iteration( diff --git a/tests/test_soap_utils.py b/tests/test_soap_utils.py index 798eebfb..f053ddb9 100644 --- a/tests/test_soap_utils.py +++ b/tests/test_soap_utils.py @@ -182,24 +182,24 @@ def test_all_eigenbases_met_criteria_empty_list_returns_true(self) -> None: self.assertTrue(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) @parameterized.parameters( - {"N": 4}, {"N": 16}, {"N": 33}, + {"N": 255}, ) def test_all_eigenbases_met_criteria_random_eigenbasis_returns_false(self, N: int) -> None: kronecker_factor_list = [torch.randn(N, N, device=self.device)] - eigenbasis_list = [torch.randn(N, N, device=self.device)] + eigenbasis_list = [torch.diag(torch.randn(N, device=self.device))] self.assertFalse(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) @parameterized.parameters( - {"N": 4}, {"N": 16}, {"N": 33}, + {"N": 255}, ) - def test_all_eigenbases_met_criteria_identity_true_eigenbasis_returns_true(self, N: int) -> None: + def test_all_eigenbases_met_criteria_true_eigenbasis_returns_true(self, N: int) -> None: kronecker_factor_list = [torch.randn(N, N, device=self.device)] - eigenbasis_list = [torch.linalg.eigh(K).eigenvectors for K in kronecker_factor_list] + eigenbasis_list = [torch.diag(torch.linalg.eigh(K).eigenvalues) for K in kronecker_factor_list] self.assertTrue(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) From ceeaf645ac5f9953dcf804bbe3f0f2b504d47999 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 16:21:33 -0800 Subject: [PATCH 06/15] add more coverage for calculate_sim_ademamix_update Signed-off-by: Hao Wu --- tests/test_scalar_optimizers.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index e1d7fe4d..fa1f48d1 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -153,7 +153,13 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None expected_param_val_after_step = initial_param_val_tensor - lr * laprop_update torch.testing.assert_close(param.data, expected_param_val_after_step, atol=1e-6, rtol=1e-6) - def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None: + @parameterized.parameters( + {"correct_bias": True, "num_beta_fast_warmup_steps": None}, + {"correct_bias": False, "num_beta_fast_warmup_steps": 2}, + ) + def test_calculate_ademamix_update_with_alpha_zero_equals_adam( + self, correct_bias: bool, num_beta_fast_warmup_steps: int | None + ) -> None: # AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam. exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device) exp_avg_slow_initial = torch.tensor([[1.0]], device=self.device) @@ -162,7 +168,7 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None: betas = (0.9, 0.99, 0.999) eps = 1e-8 step = 10 - correct_bias_manual = True + correct_bias = True # Calculate AdEMAMix update exp_avg_fast_for_ademamix = exp_avg_fast_initial.clone() @@ -173,12 +179,12 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None: exp_avg_fast_for_ademamix, exp_avg_slow_for_ademamix, exp_avg_sq_for_ademamix, - num_beta_slow_warmup_steps=None, + num_beta_slow_warmup_steps=num_beta_fast_warmup_steps, num_alpha_warmup_steps=None, betas=betas, step=step, eps=eps, - correct_bias=correct_bias_manual, + correct_bias=correct_bias, alpha=0.0, ) @@ -190,7 +196,7 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None: exp_avg_for_adam, exp_avg_sq_for_adam, (betas[0], betas[1]), - correct_bias=correct_bias_manual, + correct_bias=correct_bias, use_nesterov=False, step=step, eps=eps, @@ -198,7 +204,13 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None: torch.testing.assert_close(ademamix_update, adam_update, atol=1e-6, rtol=1e-6) - def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop(self) -> None: + @parameterized.parameters( + {"correct_bias": True, "num_beta_fast_warmup_steps": None}, + {"correct_bias": False, "num_beta_fast_warmup_steps": 2}, + ) + def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop( + self, correct_bias: bool, num_beta_fast_warmup_steps: int | None + ) -> None: # sim_ademamix with momentum (beta_fast) = 0 and alpha = 0 should be equivalent to RMSProp. exp_avg_initial = torch.tensor([[0.0]], device=self.device) # Momentum is 0, so exp_avg starts at 0 exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device) @@ -216,7 +228,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr grad, exp_avg_for_sim_ademamix, exp_avg_sq_for_sim_ademamix, - num_beta_fast_warmup_steps=None, + num_beta_fast_warmup_steps=num_beta_fast_warmup_steps, min_beta_fast=0.0, betas=betas, step=step, @@ -237,7 +249,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr eps=eps, weight_decay=0, momentum=0, - centered=False, + centered=correct_bias, ) # Manually set RMSProp's internal state From 66fefd5f0d4942adb89586b37d7cae9fe2a4cf2d Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 16:25:07 -0800 Subject: [PATCH 07/15] add more coverage for lion Signed-off-by: Hao Wu --- tests/test_scalar_optimizers.py | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index fa1f48d1..54325d88 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -305,6 +305,41 @@ def test_calculate_signum_with_shape_scaling_returns_sign(self) -> None: expected_update = torch.sign(exp_avg).abs() * (2 / (shape[0] + shape[1])) torch.testing.assert_close(update_abs, expected_update, atol=0, rtol=0) + def test_calculate_lion_update_returns_sign(self) -> None: + """Tests that Lion update returns sign of interpolated momentum.""" + shape = (8, 12) + momentum_beta = 0.9 + grad = torch.randn(shape, device=self.device) + exp_avg = torch.randn(shape, device=self.device) + exp_avg_clone = exp_avg.clone() + + update = scalar_optimizers.calculate_lion_update(grad, exp_avg, momentum_beta=momentum_beta) + + # Update should be sign(beta * m + (1 - beta) * g) + expected_update = torch.sign(momentum_beta * exp_avg_clone + (1 - momentum_beta) * grad) + torch.testing.assert_close(update, expected_update, atol=0, rtol=0) + + # exp_avg should be updated in-place: lerp_(grad, 1 - beta) + expected_exp_avg = torch.lerp(exp_avg_clone, grad, 1 - momentum_beta) + torch.testing.assert_close(exp_avg, expected_exp_avg, atol=1e-6, rtol=1e-6) + + def test_calculate_lion_update_with_separate_betas(self) -> None: + """Tests Lion with different beta1 and beta2.""" + shape = (4, 6) + beta1, beta2 = 0.9, 0.99 + grad = torch.randn(shape, device=self.device) + exp_avg = torch.randn(shape, device=self.device) + exp_avg_clone = exp_avg.clone() + + update = scalar_optimizers.calculate_lion_update(grad, exp_avg, momentum_beta=beta1, momentum_beta2=beta2) + + expected_update = torch.sign(beta1 * exp_avg_clone + (1 - beta1) * grad) + torch.testing.assert_close(update, expected_update, atol=0, rtol=0) + + # With separate beta2, momentum uses beta2 + expected_exp_avg = torch.lerp(exp_avg_clone, grad, 1 - beta2) + torch.testing.assert_close(exp_avg, expected_exp_avg, atol=1e-6, rtol=1e-6) + if __name__ == "__main__": testing.absltest.main() From 362c781809db747156651a8a4980b986b12d8890 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 16:39:07 -0800 Subject: [PATCH 08/15] improve test for schedule functions Signed-off-by: Hao Wu --- pyproject.toml | 1 + tests/test_soap.py | 109 ++++++++++++++++++++++++++++++++++----------- 2 files changed, 83 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 57030fc5..d016b561 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,5 +184,6 @@ exclude_also = [ "if closure", "loss = closure", "raise .*Error", + "@abstractmethod", ] diff --git a/tests/test_soap.py b/tests/test_soap.py index db8fef49..a65bceb5 100644 --- a/tests/test_soap.py +++ b/tests/test_soap.py @@ -22,6 +22,11 @@ from absl.testing import absltest, parameterized from emerging_optimizers.soap import REKLS, SOAP, soap +from emerging_optimizers.soap.soap import ( + _clip_update_rms_in_place, + _is_eigenbasis_update_step, +) +from emerging_optimizers.utils import precondition_schedules flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") @@ -37,13 +42,6 @@ def setUpModule() -> None: torch.cuda.manual_seed_all(FLAGS.seed) -from emerging_optimizers.soap.soap import ( - _clip_update_rms_in_place, - _is_eigenbasis_update_step, -) -from emerging_optimizers.utils.precondition_schedules import LinearSchedule - - def kl_shampoo_update_ref( kronecker_factor_list: list[torch.Tensor], grad: torch.Tensor, @@ -254,26 +252,6 @@ def test_soap_optimizer_fixed_frequency(self) -> None: optimizer = SOAP([param], lr=1e-3, precondition_frequency=10) self.assertEqual(optimizer.precondition_frequency, 10) - def test_soap_optimizer_class_based_schedule(self) -> None: - """Test that SOAP optimizer can be created with class-based precondition frequency schedule.""" - param = torch.randn(10, 5, requires_grad=True) - schedule = LinearSchedule(min_freq=2, max_freq=10, transition_steps=100) - optimizer = SOAP([param], lr=1e-3, precondition_frequency=schedule) - self.assertTrue(optimizer.precondition_frequency == schedule) - - self.assertEqual(schedule(0), 2) - self.assertEqual(schedule(50), 6) - self.assertEqual(schedule(100), 10) - - adam_warmup = 1 - - self.assertTrue(_is_eigenbasis_update_step(10, adam_warmup, schedule)) - self.assertFalse(_is_eigenbasis_update_step(11, adam_warmup, schedule)) - self.assertTrue(_is_eigenbasis_update_step(60, adam_warmup, schedule)) - self.assertFalse(_is_eigenbasis_update_step(61, adam_warmup, schedule)) - self.assertTrue(_is_eigenbasis_update_step(120, adam_warmup, schedule)) - self.assertFalse(_is_eigenbasis_update_step(121, adam_warmup, schedule)) - @parameterized.parameters( (1.0,), (0.0,), @@ -386,6 +364,83 @@ def test_kl_shampoo_update(self, m, n): torch.testing.assert_close(kronecker_factor_list[1], kronecker_factor_list_ref[1], atol=1e-6, rtol=1e-6) +class ScheduleTest(parameterized.TestCase): + def test_soap_optimizer_class_with_linear_schedule(self) -> None: + """Test that SOAP optimizer can be created with class-based precondition frequency schedule.""" + param = torch.randn(10, 5, requires_grad=True) + schedule = precondition_schedules.LinearSchedule(min_freq=2, max_freq=10, transition_steps=100) + optimizer = SOAP([param], lr=1e-3, precondition_frequency=schedule) + self.assertTrue(optimizer.precondition_frequency == schedule) + + self.assertEqual(schedule(0), 2) + self.assertEqual(schedule(50), 6) + self.assertEqual(schedule(100), 10) + + adam_warmup = 1 + + self.assertTrue(_is_eigenbasis_update_step(10, adam_warmup, schedule)) + self.assertFalse(_is_eigenbasis_update_step(11, adam_warmup, schedule)) + self.assertTrue(_is_eigenbasis_update_step(60, adam_warmup, schedule)) + self.assertFalse(_is_eigenbasis_update_step(61, adam_warmup, schedule)) + self.assertTrue(_is_eigenbasis_update_step(120, adam_warmup, schedule)) + self.assertFalse(_is_eigenbasis_update_step(121, adam_warmup, schedule)) + + self.assertFalse(_is_eigenbasis_update_step(2, 10, schedule)) + + def test_cosine_schedule(self) -> None: + schedule = precondition_schedules.CosineSchedule(min_freq=1, max_freq=50, transition_steps=100) + + # At step 0 (start of cosine), progress=1.0 so freq = max - (max-min)*1 = min + self.assertEqual(schedule(0), 1) + + # At midpoint, progress=0.5 so freq = max - (max-min)*0.5 = max/2 + self.assertEqual(schedule(50), 25) + + # At full period, wraps back: progress=1.0 so freq = min + self.assertEqual(schedule(100), 1) + + # Before start_step returns min_freq + schedule_delayed = precondition_schedules.CosineSchedule( + min_freq=5, max_freq=50, transition_steps=100, start_step=10 + ) + self.assertEqual(schedule_delayed(5), 5) + + # Negative step raises + with self.assertRaises(ValueError): + schedule(-1) + + # Invalid init raises + with self.assertRaises(ValueError): + precondition_schedules.CosineSchedule(min_freq=1, max_freq=50, transition_steps=0) + + def test_step_schedule(self) -> None: + schedule = precondition_schedules.StepSchedule({0: 1, 100: 5, 500: 20}) + + self.assertEqual(schedule(0), 1) + self.assertEqual(schedule(50), 1) + self.assertEqual(schedule(100), 5) + self.assertEqual(schedule(250), 5) + self.assertEqual(schedule(500), 20) + self.assertEqual(schedule(10000), 20) + + # Before start_step returns min_freq + schedule_delayed = precondition_schedules.StepSchedule({0: 2, 100: 10}, start_step=50) + self.assertEqual(schedule_delayed(25), 2) + self.assertEqual(schedule_delayed(100), 10) + + # Empty dict raises + with self.assertRaises(ValueError): + precondition_schedules.StepSchedule({}) + + # Invalid frequency raises + with self.assertRaises(ValueError): + precondition_schedules.StepSchedule({0: 0}) + + # Negative step key raises + with self.assertRaises(ValueError): + precondition_schedules.StepSchedule({-1: 5}) + + class SoapTest(parameterized.TestCase): @classmethod def setUpClass(cls): From 74250472d701392b3fa0c035e46e80085160a9b1 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 16:52:45 -0800 Subject: [PATCH 09/15] Update tests/test_scalar_optimizers.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Hao Wu --- tests/test_scalar_optimizers.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index 54325d88..a83ab908 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -157,19 +157,8 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None {"correct_bias": True, "num_beta_fast_warmup_steps": None}, {"correct_bias": False, "num_beta_fast_warmup_steps": 2}, ) - def test_calculate_ademamix_update_with_alpha_zero_equals_adam( - self, correct_bias: bool, num_beta_fast_warmup_steps: int | None - ) -> None: - # AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam. - exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device) - exp_avg_slow_initial = torch.tensor([[1.0]], device=self.device) - exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device) - grad = torch.tensor([[0.5]], device=self.device) - betas = (0.9, 0.99, 0.999) - eps = 1e-8 step = 10 - correct_bias = True - + # correct_bias is injected by @parameterized.parameters # Calculate AdEMAMix update exp_avg_fast_for_ademamix = exp_avg_fast_initial.clone() exp_avg_slow_for_ademamix = exp_avg_slow_initial.clone() From 07659c4d7dc6678f758746852e8c9c0baaaf0b3d Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 16:57:45 -0800 Subject: [PATCH 10/15] revert greptile mess Signed-off-by: Hao Wu --- tests/test_scalar_optimizers.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index a83ab908..b10b7f71 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -157,8 +157,18 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None {"correct_bias": True, "num_beta_fast_warmup_steps": None}, {"correct_bias": False, "num_beta_fast_warmup_steps": 2}, ) + def test_calculate_ademamix_update_with_alpha_zero_equals_adam( + self, correct_bias: bool, num_beta_fast_warmup_steps: int | None + ) -> None: + # AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam. + exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device) + exp_avg_slow_initial = torch.tensor([[1.0]], device=self.device) + exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device) + grad = torch.tensor([[0.5]], device=self.device) + betas = (0.9, 0.99, 0.999) + eps = 1e-8 step = 10 - # correct_bias is injected by @parameterized.parameters + # Calculate AdEMAMix update exp_avg_fast_for_ademamix = exp_avg_fast_initial.clone() exp_avg_slow_for_ademamix = exp_avg_slow_initial.clone() From 695ce6886c354e262291d6547fd5fb17cdab477e Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 17:08:54 -0800 Subject: [PATCH 11/15] skip some flaky test Signed-off-by: Hao Wu --- tests/test_scalar_optimizers.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index b10b7f71..e03e6717 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -203,21 +203,15 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam( torch.testing.assert_close(ademamix_update, adam_update, atol=1e-6, rtol=1e-6) - @parameterized.parameters( - {"correct_bias": True, "num_beta_fast_warmup_steps": None}, - {"correct_bias": False, "num_beta_fast_warmup_steps": 2}, - ) - def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop( - self, correct_bias: bool, num_beta_fast_warmup_steps: int | None - ) -> None: + def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop(self) -> None: # sim_ademamix with momentum (beta_fast) = 0 and alpha = 0 should be equivalent to RMSProp. exp_avg_initial = torch.tensor([[0.0]], device=self.device) # Momentum is 0, so exp_avg starts at 0 exp_avg_sq_initial = torch.tensor([[2.0]], device=self.device) grad = torch.tensor([[0.5]], device=self.device) betas = (0.0, 0.99) # beta1=0 for momentum eps = 1e-8 - step = 10 correct_bias = False + step = 10 lr = 0.25 exp_avg_for_sim_ademamix = exp_avg_initial.clone() exp_avg_sq_for_sim_ademamix = exp_avg_sq_initial.clone() @@ -227,7 +221,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr grad, exp_avg_for_sim_ademamix, exp_avg_sq_for_sim_ademamix, - num_beta_fast_warmup_steps=num_beta_fast_warmup_steps, + num_beta_fast_warmup_steps=None, min_beta_fast=0.0, betas=betas, step=step, From b4ea359b73a9cbd627b8d42d1cb58db1fbb21d5e Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 5 Mar 2026 17:15:03 -0800 Subject: [PATCH 12/15] fix eigenvector vs eigenvalue Signed-off-by: Hao Wu --- emerging_optimizers/utils/eig.py | 2 +- tests/test_soap_utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/emerging_optimizers/utils/eig.py b/emerging_optimizers/utils/eig.py index 5944842a..d43dade4 100644 --- a/emerging_optimizers/utils/eig.py +++ b/emerging_optimizers/utils/eig.py @@ -137,7 +137,7 @@ def orthogonal_iteration( to recompute the eigenbases of the preconditioner kronecker factor. Generalizes Vyas et al.'s (SOAP) algorithm of 1 step of power iteration for updating the eigenbasis. Args: - approx_eigenvalue_matrix : Projection of kronecker factor onto the eigenbasis, should be close to diagonal + approx_eigvals : Projection of kronecker factor onto the eigenbasis, should be close to diagonal kronecker_factor : Kronecker factor matrix. eigenbasis : Kronecker factor eigenbasis matrix. ind : Index for selecting dimension in the exp_avg_sq matrix to apply the sorting order over. diff --git a/tests/test_soap_utils.py b/tests/test_soap_utils.py index f053ddb9..5827681f 100644 --- a/tests/test_soap_utils.py +++ b/tests/test_soap_utils.py @@ -197,9 +197,11 @@ def test_all_eigenbases_met_criteria_random_eigenbasis_returns_false(self, N: in {"N": 255}, ) def test_all_eigenbases_met_criteria_true_eigenbasis_returns_true(self, N: int) -> None: - kronecker_factor_list = [torch.randn(N, N, device=self.device)] + g = torch.randn(N, N, device=self.device) + K_sym = g @ g.T + torch.eye(N, device=self.device) * 1e-5 # symmetric PSD + kronecker_factor_list = [K_sym] - eigenbasis_list = [torch.diag(torch.linalg.eigh(K).eigenvalues) for K in kronecker_factor_list] + eigenbasis_list = [torch.linalg.eigh(K_sym).eigenvectors] self.assertTrue(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) From 23e4b0c5f91fd600248a4ceb4782cf8cb7c40a06 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Fri, 6 Mar 2026 08:33:22 -0800 Subject: [PATCH 13/15] Update tests/test_scalar_optimizers.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Hao Wu --- tests/test_scalar_optimizers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index e03e6717..709b8e20 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -154,11 +154,11 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None torch.testing.assert_close(param.data, expected_param_val_after_step, atol=1e-6, rtol=1e-6) @parameterized.parameters( - {"correct_bias": True, "num_beta_fast_warmup_steps": None}, - {"correct_bias": False, "num_beta_fast_warmup_steps": 2}, + {"correct_bias": True, "num_beta_slow_warmup_steps": None}, + {"correct_bias": False, "num_beta_slow_warmup_steps": 2}, ) def test_calculate_ademamix_update_with_alpha_zero_equals_adam( - self, correct_bias: bool, num_beta_fast_warmup_steps: int | None + self, correct_bias: bool, num_beta_slow_warmup_steps: int | None ) -> None: # AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam. exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device) From 0825b7ea235095c21ab275b186a13601244a9b0d Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Fri, 6 Mar 2026 09:10:26 -0800 Subject: [PATCH 14/15] fix AI error Signed-off-by: Hao Wu --- tests/test_scalar_optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index 709b8e20..96af3ac1 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -178,7 +178,7 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam( exp_avg_fast_for_ademamix, exp_avg_slow_for_ademamix, exp_avg_sq_for_ademamix, - num_beta_slow_warmup_steps=num_beta_fast_warmup_steps, + num_beta_slow_warmup_steps=num_beta_slow_warmup_steps, num_alpha_warmup_steps=None, betas=betas, step=step, From 1215f186e5c95d801ed082ac7d7fef606ffb08ef Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Fri, 6 Mar 2026 09:19:48 -0800 Subject: [PATCH 15/15] Update tests/test_soap.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Hao Wu --- tests/test_soap.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_soap.py b/tests/test_soap.py index a65bceb5..da4cbf40 100644 --- a/tests/test_soap.py +++ b/tests/test_soap.py @@ -390,13 +390,8 @@ def test_soap_optimizer_class_with_linear_schedule(self) -> None: def test_cosine_schedule(self) -> None: schedule = precondition_schedules.CosineSchedule(min_freq=1, max_freq=50, transition_steps=100) - # At step 0 (start of cosine), progress=1.0 so freq = max - (max-min)*1 = min - self.assertEqual(schedule(0), 1) - - # At midpoint, progress=0.5 so freq = max - (max-min)*0.5 = max/2 + # At midpoint, progress=0.5 so freq = max - (max-min)*0.5 = (max+min)/2, rounded to int self.assertEqual(schedule(50), 25) - - # At full period, wraps back: progress=1.0 so freq = min self.assertEqual(schedule(100), 1) # Before start_step returns min_freq