diff --git a/emerging_optimizers/soap/soap_utils.py b/emerging_optimizers/soap/soap_utils.py index e22a42b..c12ce73 100644 --- a/emerging_optimizers/soap/soap_utils.py +++ b/emerging_optimizers/soap/soap_utils.py @@ -86,7 +86,7 @@ def get_eigenbasis_eigh( for kronecker_factor in kronecker_factor_list: if kronecker_factor.numel() == 0: - updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device)) + updated_eigenbasis_list.append(torch.empty(0, 0, device=kronecker_factor.device)) continue _, Q = eig_utils.eigh_with_fallback(kronecker_factor, force_double=False, eps=eps) updated_eigenbasis_list.append(Q) @@ -148,7 +148,7 @@ def get_eigenbasis_qr( updated_eigenbasis_list: TensorList = [] for ind, (kronecker_factor, eigenbasis) in enumerate(zip(kronecker_factor_list, eigenbasis_list, strict=True)): if kronecker_factor.numel() == 0: - updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device)) + updated_eigenbasis_list.append(torch.empty(0, 0, device=kronecker_factor.device)) continue approx_eigvals = eig_utils.conjugate(kronecker_factor, eigenbasis, diag=True) diff --git a/emerging_optimizers/utils/eig.py b/emerging_optimizers/utils/eig.py index d8343fe..d43dade 100644 --- a/emerging_optimizers/utils/eig.py +++ b/emerging_optimizers/utils/eig.py @@ -115,12 +115,12 @@ def met_approx_eigvals_criteria( tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix. Returns: - perform_update: Whether to update eigenbasis this iteration + Whether eigenbasis meet criteria and don't need to be updated """ matrix_norm = torch.linalg.norm(kronecker_factor) diagonal_norm = torch.linalg.norm(approx_eigvals) - return diagonal_norm >= (1 - tolerance) * matrix_norm + return tolerance * matrix_norm >= (matrix_norm - diagonal_norm) def orthogonal_iteration( @@ -137,7 +137,7 @@ def orthogonal_iteration( to recompute the eigenbases of the preconditioner kronecker factor. Generalizes Vyas et al.'s (SOAP) algorithm of 1 step of power iteration for updating the eigenbasis. Args: - approx_eigenvalue_matrix : Projection of kronecker factor onto the eigenbasis, should be close to diagonal + approx_eigvals : Projection of kronecker factor onto the eigenbasis, should be close to diagonal kronecker_factor : Kronecker factor matrix. eigenbasis : Kronecker factor eigenbasis matrix. ind : Index for selecting dimension in the exp_avg_sq matrix to apply the sorting order over. diff --git a/pyproject.toml b/pyproject.toml index 57030fc..d016b56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,5 +184,6 @@ exclude_also = [ "if closure", "loss = closure", "raise .*Error", + "@abstractmethod", ] diff --git a/tests/ci/L0_Tests_CPU.sh b/tests/ci/L0_Tests_CPU.sh index ca6a1a9..30c09b2 100644 --- a/tests/ci/L0_Tests_CPU.sh +++ b/tests/ci/L0_Tests_CPU.sh @@ -13,10 +13,19 @@ # limitations under the License. export TORCH_COMPILE_DISABLE=1 +mkdir -p test-results/tests/ + error=0 -torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1 -torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1 -coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu -v -2 || error=1 -coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu -v -2 || error=1 +for n in 8 4; do + torchrun --nproc_per_node=$n --no-python coverage run -p \ + tests/test_distributed_muon_utils_cpu.py \ + --xml_output_file="test-results/tests/test_distributed_muon_utils_cpu_n${n}.xml" \ + -v -2 || error=1 +done + +for test in "tests/test_scalar_optimizers.py" "tests/test_procrustes_step.py"; do + report_name="test-results/${test}.xml" + coverage run -p --source=emerging_optimizers $test --device=cpu -v -2 --xml_output_file="$report_name" || error=1 +done exit "${error}" diff --git a/tests/test_distributed_muon_utils_cpu.py b/tests/test_distributed_muon_utils_cpu.py index f9ebddd..7a0f1b8 100644 --- a/tests/test_distributed_muon_utils_cpu.py +++ b/tests/test_distributed_muon_utils_cpu.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import sys import numpy as np import torch @@ -211,6 +212,17 @@ def test_1step_close_to_non_distributed(self, shape, partition_dim, tp_mode): if __name__ == "__main__": torch.distributed.init_process_group(backend="gloo") torch.set_float32_matmul_precision("highest") + + rank = torch.distributed.get_rank() + + for i, arg in enumerate(sys.argv): + if arg.startswith("--xml_output_file="): + base, ext = os.path.splitext(arg) + + # Attach rank to the output file name + sys.argv[i] = f"{base}_rank{rank}{ext}" + break + absltest.main() torch.distributed.destroy_process_group() diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index e1d7fe4..96af3ac 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -153,7 +153,13 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None expected_param_val_after_step = initial_param_val_tensor - lr * laprop_update torch.testing.assert_close(param.data, expected_param_val_after_step, atol=1e-6, rtol=1e-6) - def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None: + @parameterized.parameters( + {"correct_bias": True, "num_beta_slow_warmup_steps": None}, + {"correct_bias": False, "num_beta_slow_warmup_steps": 2}, + ) + def test_calculate_ademamix_update_with_alpha_zero_equals_adam( + self, correct_bias: bool, num_beta_slow_warmup_steps: int | None + ) -> None: # AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam. exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device) exp_avg_slow_initial = torch.tensor([[1.0]], device=self.device) @@ -162,7 +168,6 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None: betas = (0.9, 0.99, 0.999) eps = 1e-8 step = 10 - correct_bias_manual = True # Calculate AdEMAMix update exp_avg_fast_for_ademamix = exp_avg_fast_initial.clone() @@ -173,12 +178,12 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None: exp_avg_fast_for_ademamix, exp_avg_slow_for_ademamix, exp_avg_sq_for_ademamix, - num_beta_slow_warmup_steps=None, + num_beta_slow_warmup_steps=num_beta_slow_warmup_steps, num_alpha_warmup_steps=None, betas=betas, step=step, eps=eps, - correct_bias=correct_bias_manual, + correct_bias=correct_bias, alpha=0.0, ) @@ -190,7 +195,7 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None: exp_avg_for_adam, exp_avg_sq_for_adam, (betas[0], betas[1]), - correct_bias=correct_bias_manual, + correct_bias=correct_bias, use_nesterov=False, step=step, eps=eps, @@ -205,8 +210,8 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr grad = torch.tensor([[0.5]], device=self.device) betas = (0.0, 0.99) # beta1=0 for momentum eps = 1e-8 - step = 10 correct_bias = False + step = 10 lr = 0.25 exp_avg_for_sim_ademamix = exp_avg_initial.clone() exp_avg_sq_for_sim_ademamix = exp_avg_sq_initial.clone() @@ -237,7 +242,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr eps=eps, weight_decay=0, momentum=0, - centered=False, + centered=correct_bias, ) # Manually set RMSProp's internal state @@ -293,6 +298,41 @@ def test_calculate_signum_with_shape_scaling_returns_sign(self) -> None: expected_update = torch.sign(exp_avg).abs() * (2 / (shape[0] + shape[1])) torch.testing.assert_close(update_abs, expected_update, atol=0, rtol=0) + def test_calculate_lion_update_returns_sign(self) -> None: + """Tests that Lion update returns sign of interpolated momentum.""" + shape = (8, 12) + momentum_beta = 0.9 + grad = torch.randn(shape, device=self.device) + exp_avg = torch.randn(shape, device=self.device) + exp_avg_clone = exp_avg.clone() + + update = scalar_optimizers.calculate_lion_update(grad, exp_avg, momentum_beta=momentum_beta) + + # Update should be sign(beta * m + (1 - beta) * g) + expected_update = torch.sign(momentum_beta * exp_avg_clone + (1 - momentum_beta) * grad) + torch.testing.assert_close(update, expected_update, atol=0, rtol=0) + + # exp_avg should be updated in-place: lerp_(grad, 1 - beta) + expected_exp_avg = torch.lerp(exp_avg_clone, grad, 1 - momentum_beta) + torch.testing.assert_close(exp_avg, expected_exp_avg, atol=1e-6, rtol=1e-6) + + def test_calculate_lion_update_with_separate_betas(self) -> None: + """Tests Lion with different beta1 and beta2.""" + shape = (4, 6) + beta1, beta2 = 0.9, 0.99 + grad = torch.randn(shape, device=self.device) + exp_avg = torch.randn(shape, device=self.device) + exp_avg_clone = exp_avg.clone() + + update = scalar_optimizers.calculate_lion_update(grad, exp_avg, momentum_beta=beta1, momentum_beta2=beta2) + + expected_update = torch.sign(beta1 * exp_avg_clone + (1 - beta1) * grad) + torch.testing.assert_close(update, expected_update, atol=0, rtol=0) + + # With separate beta2, momentum uses beta2 + expected_exp_avg = torch.lerp(exp_avg_clone, grad, 1 - beta2) + torch.testing.assert_close(exp_avg, expected_exp_avg, atol=1e-6, rtol=1e-6) + if __name__ == "__main__": testing.absltest.main() diff --git a/tests/test_soap.py b/tests/test_soap.py index db8fef4..da4cbf4 100644 --- a/tests/test_soap.py +++ b/tests/test_soap.py @@ -22,6 +22,11 @@ from absl.testing import absltest, parameterized from emerging_optimizers.soap import REKLS, SOAP, soap +from emerging_optimizers.soap.soap import ( + _clip_update_rms_in_place, + _is_eigenbasis_update_step, +) +from emerging_optimizers.utils import precondition_schedules flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") @@ -37,13 +42,6 @@ def setUpModule() -> None: torch.cuda.manual_seed_all(FLAGS.seed) -from emerging_optimizers.soap.soap import ( - _clip_update_rms_in_place, - _is_eigenbasis_update_step, -) -from emerging_optimizers.utils.precondition_schedules import LinearSchedule - - def kl_shampoo_update_ref( kronecker_factor_list: list[torch.Tensor], grad: torch.Tensor, @@ -254,26 +252,6 @@ def test_soap_optimizer_fixed_frequency(self) -> None: optimizer = SOAP([param], lr=1e-3, precondition_frequency=10) self.assertEqual(optimizer.precondition_frequency, 10) - def test_soap_optimizer_class_based_schedule(self) -> None: - """Test that SOAP optimizer can be created with class-based precondition frequency schedule.""" - param = torch.randn(10, 5, requires_grad=True) - schedule = LinearSchedule(min_freq=2, max_freq=10, transition_steps=100) - optimizer = SOAP([param], lr=1e-3, precondition_frequency=schedule) - self.assertTrue(optimizer.precondition_frequency == schedule) - - self.assertEqual(schedule(0), 2) - self.assertEqual(schedule(50), 6) - self.assertEqual(schedule(100), 10) - - adam_warmup = 1 - - self.assertTrue(_is_eigenbasis_update_step(10, adam_warmup, schedule)) - self.assertFalse(_is_eigenbasis_update_step(11, adam_warmup, schedule)) - self.assertTrue(_is_eigenbasis_update_step(60, adam_warmup, schedule)) - self.assertFalse(_is_eigenbasis_update_step(61, adam_warmup, schedule)) - self.assertTrue(_is_eigenbasis_update_step(120, adam_warmup, schedule)) - self.assertFalse(_is_eigenbasis_update_step(121, adam_warmup, schedule)) - @parameterized.parameters( (1.0,), (0.0,), @@ -386,6 +364,78 @@ def test_kl_shampoo_update(self, m, n): torch.testing.assert_close(kronecker_factor_list[1], kronecker_factor_list_ref[1], atol=1e-6, rtol=1e-6) +class ScheduleTest(parameterized.TestCase): + def test_soap_optimizer_class_with_linear_schedule(self) -> None: + """Test that SOAP optimizer can be created with class-based precondition frequency schedule.""" + param = torch.randn(10, 5, requires_grad=True) + schedule = precondition_schedules.LinearSchedule(min_freq=2, max_freq=10, transition_steps=100) + optimizer = SOAP([param], lr=1e-3, precondition_frequency=schedule) + self.assertTrue(optimizer.precondition_frequency == schedule) + + self.assertEqual(schedule(0), 2) + self.assertEqual(schedule(50), 6) + self.assertEqual(schedule(100), 10) + + adam_warmup = 1 + + self.assertTrue(_is_eigenbasis_update_step(10, adam_warmup, schedule)) + self.assertFalse(_is_eigenbasis_update_step(11, adam_warmup, schedule)) + self.assertTrue(_is_eigenbasis_update_step(60, adam_warmup, schedule)) + self.assertFalse(_is_eigenbasis_update_step(61, adam_warmup, schedule)) + self.assertTrue(_is_eigenbasis_update_step(120, adam_warmup, schedule)) + self.assertFalse(_is_eigenbasis_update_step(121, adam_warmup, schedule)) + + self.assertFalse(_is_eigenbasis_update_step(2, 10, schedule)) + + def test_cosine_schedule(self) -> None: + schedule = precondition_schedules.CosineSchedule(min_freq=1, max_freq=50, transition_steps=100) + + # At midpoint, progress=0.5 so freq = max - (max-min)*0.5 = (max+min)/2, rounded to int + self.assertEqual(schedule(50), 25) + self.assertEqual(schedule(100), 1) + + # Before start_step returns min_freq + schedule_delayed = precondition_schedules.CosineSchedule( + min_freq=5, max_freq=50, transition_steps=100, start_step=10 + ) + self.assertEqual(schedule_delayed(5), 5) + + # Negative step raises + with self.assertRaises(ValueError): + schedule(-1) + + # Invalid init raises + with self.assertRaises(ValueError): + precondition_schedules.CosineSchedule(min_freq=1, max_freq=50, transition_steps=0) + + def test_step_schedule(self) -> None: + schedule = precondition_schedules.StepSchedule({0: 1, 100: 5, 500: 20}) + + self.assertEqual(schedule(0), 1) + self.assertEqual(schedule(50), 1) + self.assertEqual(schedule(100), 5) + self.assertEqual(schedule(250), 5) + self.assertEqual(schedule(500), 20) + self.assertEqual(schedule(10000), 20) + + # Before start_step returns min_freq + schedule_delayed = precondition_schedules.StepSchedule({0: 2, 100: 10}, start_step=50) + self.assertEqual(schedule_delayed(25), 2) + self.assertEqual(schedule_delayed(100), 10) + + # Empty dict raises + with self.assertRaises(ValueError): + precondition_schedules.StepSchedule({}) + + # Invalid frequency raises + with self.assertRaises(ValueError): + precondition_schedules.StepSchedule({0: 0}) + + # Negative step key raises + with self.assertRaises(ValueError): + precondition_schedules.StepSchedule({-1: 5}) + + class SoapTest(parameterized.TestCase): @classmethod def setUpClass(cls): diff --git a/tests/test_soap_utils.py b/tests/test_soap_utils.py index a3a2e4f..5827681 100644 --- a/tests/test_soap_utils.py +++ b/tests/test_soap_utils.py @@ -102,9 +102,32 @@ def test_get_eigenbasis_qr(self, N: int, M: int) -> None: # Also check that "exp_avg_sq" remains in the state with same shape if not merging self.assertEqual(exp_avg_sq_new.shape, (M, N)) + def test_get_eigenbasis_qr_empty_factor(self) -> None: + """Tests get_eigenbasis_qr with an empty (numel()==0) kronecker factor.""" + torch.manual_seed(0) + N = 4 + g = torch.randn(N, N, device=self.device) + L = g.mm(g.t()).float() + + empty_factor = torch.empty(0, 0, device=self.device) + kronecker_factor_list = [L, empty_factor] + eigenbasis_list = [torch.randn(N, N, device=self.device), torch.empty(0, device=self.device)] + exp_avg_sq = torch.abs(torch.randn(N, N, device=self.device)) + + Q_new_list, exp_avg_sq_new = soap_utils.get_eigenbasis_qr( + kronecker_factor_list=kronecker_factor_list, + eigenbasis_list=eigenbasis_list, + exp_avg_sq=exp_avg_sq, + ) + + self.assertEqual(len(Q_new_list), 2) + self.assertEqual(Q_new_list[0].shape, (N, N)) + self.assertEqual(Q_new_list[1].numel(), 0) + @parameterized.parameters( # type: ignore[misc] {"dims": [128, 512]}, {"dims": []}, + {"dims": [64, 0, 32]}, ) def test_get_eigenbasis_eigh(self, dims: list[int]) -> None: """Tests the get_eigenbasis_eigh function.""" @@ -153,6 +176,34 @@ def test_get_eigenbasis_eigh(self, dims: list[int]) -> None: msg=f"Matrix {i} was not properly diagonalized. Off-diagonal norm: {off_diagonal_norm}", ) + def test_all_eigenbases_met_criteria_empty_list_returns_true(self) -> None: + kronecker_factor_list = [] + eigenbasis_list = [] + self.assertTrue(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) + + @parameterized.parameters( + {"N": 16}, + {"N": 33}, + {"N": 255}, + ) + def test_all_eigenbases_met_criteria_random_eigenbasis_returns_false(self, N: int) -> None: + kronecker_factor_list = [torch.randn(N, N, device=self.device)] + eigenbasis_list = [torch.diag(torch.randn(N, device=self.device))] + self.assertFalse(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) + + @parameterized.parameters( + {"N": 16}, + {"N": 33}, + {"N": 255}, + ) + def test_all_eigenbases_met_criteria_true_eigenbasis_returns_true(self, N: int) -> None: + g = torch.randn(N, N, device=self.device) + K_sym = g @ g.T + torch.eye(N, device=self.device) * 1e-5 # symmetric PSD + kronecker_factor_list = [K_sym] + + eigenbasis_list = [torch.linalg.eigh(K_sym).eigenvectors] + self.assertTrue(soap_utils.all_eigenbases_met_criteria(kronecker_factor_list, eigenbasis_list)) + if __name__ == "__main__": absltest.main()