Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions emerging_optimizers/soap/soap_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_eigenbasis_eigh(

for kronecker_factor in kronecker_factor_list:
if kronecker_factor.numel() == 0:
updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device))
updated_eigenbasis_list.append(torch.empty(0, 0, device=kronecker_factor.device))
continue
_, Q = eig_utils.eigh_with_fallback(kronecker_factor, force_double=False, eps=eps)
updated_eigenbasis_list.append(Q)
Expand Down Expand Up @@ -148,7 +148,7 @@ def get_eigenbasis_qr(
updated_eigenbasis_list: TensorList = []
for ind, (kronecker_factor, eigenbasis) in enumerate(zip(kronecker_factor_list, eigenbasis_list, strict=True)):
if kronecker_factor.numel() == 0:
updated_eigenbasis_list.append(torch.empty(0, device=kronecker_factor.device))
updated_eigenbasis_list.append(torch.empty(0, 0, device=kronecker_factor.device))
continue

approx_eigvals = eig_utils.conjugate(kronecker_factor, eigenbasis, diag=True)
Expand Down
6 changes: 3 additions & 3 deletions emerging_optimizers/utils/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ def met_approx_eigvals_criteria(
tolerance: Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix.

Returns:
perform_update: Whether to update eigenbasis this iteration
Whether eigenbasis meet criteria and don't need to be updated
"""
matrix_norm = torch.linalg.norm(kronecker_factor)
diagonal_norm = torch.linalg.norm(approx_eigvals)

return diagonal_norm >= (1 - tolerance) * matrix_norm
return tolerance * matrix_norm >= (matrix_norm - diagonal_norm)


def orthogonal_iteration(
Expand All @@ -137,7 +137,7 @@ def orthogonal_iteration(
to recompute the eigenbases of the preconditioner kronecker factor. Generalizes Vyas et al.'s (SOAP) algorithm of 1 step of power iteration for updating the eigenbasis.

Args:
approx_eigenvalue_matrix : Projection of kronecker factor onto the eigenbasis, should be close to diagonal
approx_eigvals : Projection of kronecker factor onto the eigenbasis, should be close to diagonal
kronecker_factor : Kronecker factor matrix.
eigenbasis : Kronecker factor eigenbasis matrix.
ind : Index for selecting dimension in the exp_avg_sq matrix to apply the sorting order over.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,5 +184,6 @@ exclude_also = [
"if closure",
"loss = closure",
"raise .*Error",
"@abstractmethod",
]

17 changes: 13 additions & 4 deletions tests/ci/L0_Tests_CPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,19 @@
# limitations under the License.
export TORCH_COMPILE_DISABLE=1

mkdir -p test-results/tests/

error=0
torchrun --nproc_per_node=8 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1
torchrun --nproc_per_node=4 --no-python coverage run -p tests/test_distributed_muon_utils_cpu.py -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cpu -v -2 || error=1
coverage run -p --source=emerging_optimizers tests/test_procrustes_step.py --device=cpu -v -2 || error=1
for n in 8 4; do
torchrun --nproc_per_node=$n --no-python coverage run -p \
tests/test_distributed_muon_utils_cpu.py \
--xml_output_file="test-results/tests/test_distributed_muon_utils_cpu_n${n}.xml" \
-v -2 || error=1
done

for test in "tests/test_scalar_optimizers.py" "tests/test_procrustes_step.py"; do
report_name="test-results/${test}.xml"
coverage run -p --source=emerging_optimizers $test --device=cpu -v -2 --xml_output_file="$report_name" || error=1
done
Comment on lines +26 to +29
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New test files excluded from CI

test_soap.py and test_soap_utils.py are not added to this script, so all the new test cases added in this PR — including ScheduleTest, test_all_eigenbases_met_criteria_*, and test_get_eigenbasis_qr_empty_factor — will not be executed in CI.

The for-loop only covers test_scalar_optimizers.py and test_procrustes_step.py. Since the PR's stated goal is to "improve test coverage and report", the new tests should also be wired into the CI script, for example:

for test in "tests/test_scalar_optimizers.py" "tests/test_procrustes_step.py" "tests/test_soap.py" "tests/test_soap_utils.py"; do
    report_name="test-results/${test}.xml"
    coverage run -p --source=emerging_optimizers $test --device=cpu  -v -2 --xml_output_file="$report_name" || error=1
done


exit "${error}"
12 changes: 12 additions & 0 deletions tests/test_distributed_muon_utils_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

import numpy as np
import torch
Expand Down Expand Up @@ -211,6 +212,17 @@ def test_1step_close_to_non_distributed(self, shape, partition_dim, tp_mode):
if __name__ == "__main__":
    torch.distributed.init_process_group(backend="gloo")
    torch.set_float32_matmul_precision("highest")

    rank = torch.distributed.get_rank()

    # Rewrite --xml_output_file=<path> so each rank writes a distinct XML
    # report instead of all ranks clobbering the same file. Calling splitext
    # on the whole "--xml_output_file=path.xml" string is safe: it only peels
    # off the ".xml" suffix, so the flag prefix survives in `base`.
    for i, arg in enumerate(sys.argv):
        if arg.startswith("--xml_output_file="):
            base, ext = os.path.splitext(arg)

            # Attach rank to the output file name
            sys.argv[i] = f"{base}_rank{rank}{ext}"
            break

    # absltest.main() terminates the process via sys.exit(), so any statement
    # placed after it is unreachable. Use try/finally so the process group is
    # torn down whether the tests pass, fail, or raise.
    try:
        absltest.main()
    finally:
        torch.distributed.destroy_process_group()
54 changes: 47 additions & 7 deletions tests/test_scalar_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,13 @@ def test_calculate_laprop_update_with_zero_momentum_equals_rmsprop(self) -> None
expected_param_val_after_step = initial_param_val_tensor - lr * laprop_update
torch.testing.assert_close(param.data, expected_param_val_after_step, atol=1e-6, rtol=1e-6)

def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
@parameterized.parameters(
{"correct_bias": True, "num_beta_fast_warmup_steps": None},
{"correct_bias": False, "num_beta_fast_warmup_steps": 2},
)
def test_calculate_ademamix_update_with_alpha_zero_equals_adam(
self, correct_bias: bool, num_beta_fast_warmup_steps: int | None
) -> None:
Comment on lines +156 to +162
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Misleading parameter name num_beta_fast_warmup_steps in AdEMAMix test

The parameterized argument is named num_beta_fast_warmup_steps, but inside the function body it is passed to the calculate_ademamix_update call as num_beta_slow_warmup_steps=num_beta_fast_warmup_steps (line 181). "Fast" and "slow" are separate concepts in AdEMAMix — using the wrong name here can mislead a reader into thinking the wrong warmup schedule is being varied.

By contrast, the sibling test test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmsprop (line 207) uses the same parameter name and passes it correctly as num_beta_fast_warmup_steps=. Renaming the parameter in this test to num_beta_slow_warmup_steps would make the intent clear:

Suggested change
@parameterized.parameters(
{"correct_bias": True, "num_beta_fast_warmup_steps": None},
{"correct_bias": False, "num_beta_fast_warmup_steps": 2},
)
def test_calculate_ademamix_update_with_alpha_zero_equals_adam(
self, correct_bias: bool, num_beta_fast_warmup_steps: int | None
) -> None:
@parameterized.parameters(
{"correct_bias": True, "num_beta_slow_warmup_steps": None},
{"correct_bias": False, "num_beta_slow_warmup_steps": 2},
)
def test_calculate_ademamix_update_with_alpha_zero_equals_adam(
self, correct_bias: bool, num_beta_slow_warmup_steps: int | None
) -> None:

And update the usage on line 181 accordingly:

num_beta_slow_warmup_steps=num_beta_slow_warmup_steps,

# AdEMAMix with alpha=0 and no beta scheduling should be equivalent to Adam.
exp_avg_fast_initial = torch.tensor([[1.0]], device=self.device)
exp_avg_slow_initial = torch.tensor([[1.0]], device=self.device)
Expand All @@ -162,7 +168,6 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
betas = (0.9, 0.99, 0.999)
eps = 1e-8
step = 10
correct_bias_manual = True

# Calculate AdEMAMix update
exp_avg_fast_for_ademamix = exp_avg_fast_initial.clone()
Expand All @@ -173,12 +178,12 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
exp_avg_fast_for_ademamix,
exp_avg_slow_for_ademamix,
exp_avg_sq_for_ademamix,
num_beta_slow_warmup_steps=None,
num_beta_slow_warmup_steps=num_beta_fast_warmup_steps,
num_alpha_warmup_steps=None,
betas=betas,
step=step,
eps=eps,
correct_bias=correct_bias_manual,
correct_bias=correct_bias,
alpha=0.0,
)

Expand All @@ -190,7 +195,7 @@ def test_calculate_ademamix_update_with_alpha_zero_equals_adam(self) -> None:
exp_avg_for_adam,
exp_avg_sq_for_adam,
(betas[0], betas[1]),
correct_bias=correct_bias_manual,
correct_bias=correct_bias,
use_nesterov=False,
step=step,
eps=eps,
Expand All @@ -205,8 +210,8 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr
grad = torch.tensor([[0.5]], device=self.device)
betas = (0.0, 0.99) # beta1=0 for momentum
eps = 1e-8
step = 10
correct_bias = False
step = 10
lr = 0.25
exp_avg_for_sim_ademamix = exp_avg_initial.clone()
exp_avg_sq_for_sim_ademamix = exp_avg_sq_initial.clone()
Expand Down Expand Up @@ -237,7 +242,7 @@ def test_calculate_sim_ademamix_update_with_zero_momentum_and_alpha_equals_rmspr
eps=eps,
weight_decay=0,
momentum=0,
centered=False,
centered=correct_bias,
)

# Manually set RMSProp's internal state
Expand Down Expand Up @@ -293,6 +298,41 @@ def test_calculate_signum_with_shape_scaling_returns_sign(self) -> None:
expected_update = torch.sign(exp_avg).abs() * (2 / (shape[0] + shape[1]))
torch.testing.assert_close(update_abs, expected_update, atol=0, rtol=0)

def test_calculate_lion_update_returns_sign(self) -> None:
    """Lion's update must be the sign of the beta-interpolated momentum,
    and the momentum buffer must be advanced in place via lerp."""
    beta = 0.9
    g = torch.randn(8, 12, device=self.device)
    m = torch.randn(8, 12, device=self.device)
    m_before = m.clone()

    update = scalar_optimizers.calculate_lion_update(g, m, momentum_beta=beta)

    # Returned direction is sign(beta * m + (1 - beta) * g).
    torch.testing.assert_close(
        update, torch.sign(beta * m_before + (1 - beta) * g), atol=0, rtol=0
    )

    # Momentum buffer is mutated in place: m.lerp_(g, 1 - beta).
    torch.testing.assert_close(
        m, torch.lerp(m_before, g, 1 - beta), atol=1e-6, rtol=1e-6
    )

def test_calculate_lion_update_with_separate_betas(self) -> None:
    """With distinct beta1/beta2, Lion's direction uses beta1 while the
    in-place momentum update uses beta2."""
    b1, b2 = 0.9, 0.99
    g = torch.randn(4, 6, device=self.device)
    m = torch.randn(4, 6, device=self.device)
    m_before = m.clone()

    update = scalar_optimizers.calculate_lion_update(g, m, momentum_beta=b1, momentum_beta2=b2)

    # Direction interpolates momentum and gradient with beta1.
    torch.testing.assert_close(update, torch.sign(b1 * m_before + (1 - b1) * g), atol=0, rtol=0)

    # Momentum buffer interpolates with beta2: m.lerp_(g, 1 - beta2).
    torch.testing.assert_close(m, torch.lerp(m_before, g, 1 - b2), atol=1e-6, rtol=1e-6)


if __name__ == "__main__":
testing.absltest.main()
109 changes: 82 additions & 27 deletions tests/test_soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
from absl.testing import absltest, parameterized

from emerging_optimizers.soap import REKLS, SOAP, soap
from emerging_optimizers.soap.soap import (
_clip_update_rms_in_place,
_is_eigenbasis_update_step,
)
from emerging_optimizers.utils import precondition_schedules


flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on")
Expand All @@ -37,13 +42,6 @@ def setUpModule() -> None:
torch.cuda.manual_seed_all(FLAGS.seed)


from emerging_optimizers.soap.soap import (
_clip_update_rms_in_place,
_is_eigenbasis_update_step,
)
from emerging_optimizers.utils.precondition_schedules import LinearSchedule


def kl_shampoo_update_ref(
kronecker_factor_list: list[torch.Tensor],
grad: torch.Tensor,
Expand Down Expand Up @@ -254,26 +252,6 @@ def test_soap_optimizer_fixed_frequency(self) -> None:
optimizer = SOAP([param], lr=1e-3, precondition_frequency=10)
self.assertEqual(optimizer.precondition_frequency, 10)

def test_soap_optimizer_class_based_schedule(self) -> None:
"""Test that SOAP optimizer can be created with class-based precondition frequency schedule."""
param = torch.randn(10, 5, requires_grad=True)
schedule = LinearSchedule(min_freq=2, max_freq=10, transition_steps=100)
optimizer = SOAP([param], lr=1e-3, precondition_frequency=schedule)
self.assertTrue(optimizer.precondition_frequency == schedule)

self.assertEqual(schedule(0), 2)
self.assertEqual(schedule(50), 6)
self.assertEqual(schedule(100), 10)

adam_warmup = 1

self.assertTrue(_is_eigenbasis_update_step(10, adam_warmup, schedule))
self.assertFalse(_is_eigenbasis_update_step(11, adam_warmup, schedule))
self.assertTrue(_is_eigenbasis_update_step(60, adam_warmup, schedule))
self.assertFalse(_is_eigenbasis_update_step(61, adam_warmup, schedule))
self.assertTrue(_is_eigenbasis_update_step(120, adam_warmup, schedule))
self.assertFalse(_is_eigenbasis_update_step(121, adam_warmup, schedule))

@parameterized.parameters(
(1.0,),
(0.0,),
Expand Down Expand Up @@ -386,6 +364,83 @@ def test_kl_shampoo_update(self, m, n):
torch.testing.assert_close(kronecker_factor_list[1], kronecker_factor_list_ref[1], atol=1e-6, rtol=1e-6)


class ScheduleTest(parameterized.TestCase):
    """Tests for precondition-frequency schedules and their wiring into SOAP."""

    def test_soap_optimizer_class_with_linear_schedule(self) -> None:
        """Test that SOAP optimizer can be created with class-based precondition frequency schedule."""
        param = torch.randn(10, 5, requires_grad=True)
        schedule = precondition_schedules.LinearSchedule(min_freq=2, max_freq=10, transition_steps=100)
        optimizer = SOAP([param], lr=1e-3, precondition_frequency=schedule)
        self.assertTrue(optimizer.precondition_frequency == schedule)

        # Frequency ramps linearly from min_freq to max_freq over transition_steps.
        for step, expected_freq in ((0, 2), (50, 6), (100, 10)):
            self.assertEqual(schedule(step), expected_freq)

        adam_warmup = 1

        # A step that lands on the scheduled frequency triggers an eigenbasis
        # update; the step immediately after it does not.
        for update_step in (10, 60, 120):
            self.assertTrue(_is_eigenbasis_update_step(update_step, adam_warmup, schedule))
            self.assertFalse(_is_eigenbasis_update_step(update_step + 1, adam_warmup, schedule))

        # No eigenbasis update during the Adam warmup phase.
        self.assertFalse(_is_eigenbasis_update_step(2, 10, schedule))

    def test_cosine_schedule(self) -> None:
        """CosineSchedule endpoints, midpoint, wrap-around, delayed start and validation."""
        schedule = precondition_schedules.CosineSchedule(min_freq=1, max_freq=50, transition_steps=100)

        # Step 0 has progress=1.0, so freq = max - (max-min)*1 = min.
        self.assertEqual(schedule(0), 1)
        # Midpoint has progress=0.5, so freq = max - (max-min)*0.5 = max/2.
        self.assertEqual(schedule(50), 25)
        # A full period wraps back to progress=1.0, i.e. min again.
        self.assertEqual(schedule(100), 1)

        # Steps before start_step yield min_freq.
        delayed = precondition_schedules.CosineSchedule(
            min_freq=5, max_freq=50, transition_steps=100, start_step=10
        )
        self.assertEqual(delayed(5), 5)

        # A negative step is rejected.
        with self.assertRaises(ValueError):
            schedule(-1)

        # transition_steps must be positive.
        with self.assertRaises(ValueError):
            precondition_schedules.CosineSchedule(min_freq=1, max_freq=50, transition_steps=0)

    def test_step_schedule(self) -> None:
        """StepSchedule piecewise-constant lookup, delayed start and validation."""
        schedule = precondition_schedules.StepSchedule({0: 1, 100: 5, 500: 20})

        # The frequency is the value at the largest threshold <= step.
        for step, expected_freq in ((0, 1), (50, 1), (100, 5), (250, 5), (500, 20), (10000, 20)):
            self.assertEqual(schedule(step), expected_freq)

        # Steps before start_step yield the first frequency.
        delayed = precondition_schedules.StepSchedule({0: 2, 100: 10}, start_step=50)
        self.assertEqual(delayed(25), 2)
        self.assertEqual(delayed(100), 10)

        # An empty schedule is rejected.
        with self.assertRaises(ValueError):
            precondition_schedules.StepSchedule({})

        # Zero/non-positive frequencies are rejected.
        with self.assertRaises(ValueError):
            precondition_schedules.StepSchedule({0: 0})

        # Negative step thresholds are rejected.
        with self.assertRaises(ValueError):
            precondition_schedules.StepSchedule({-1: 5})


class SoapTest(parameterized.TestCase):
@classmethod
def setUpClass(cls):
Expand Down
Loading