
Commit e9adc90

priyakasimbeg authored and copybara-github committed
Add schedule-free adamw to training algorithms
PiperOrigin-RevId: 872026624
1 parent d41c715

2 files changed: 4 additions & 2 deletions
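
For context: schedule-free AdamW does not keep a separate EMA of the parameters in an optimizer slot; its evaluation parameters are a function of both the optimizer state and the current training parameters. A minimal usage sketch, assuming the optax.contrib schedule-free API and illustrative hyperparameters (this snippet is not part of the commit):

import jax.numpy as jnp
import optax

# Illustrative hyperparameters, not taken from this commit.
optimizer = optax.contrib.schedule_free_adamw(
    learning_rate=1e-3, b1=0.9, weight_decay=1e-4)

params = {'w': jnp.zeros(4)}
opt_state = optimizer.init(params)

# Eval params depend on BOTH the optimizer state and the current params,
# which is why this commit threads `params` through get_ema_eval_params.
eval_params = optax.contrib.schedule_free_eval_params(opt_state, params)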


init2winit/trainer_lib/base_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -447,7 +447,7 @@ def _eval(self, start_step, start_time, save=True):

     if self._eval_use_ema:
       eval_params = self.training_algorithm.get_ema_eval_params(
-          self._optimizer_state
+          self._optimizer_state, self._params
       )
     else:
       eval_params = self._params

init2winit/trainer_lib/training_algorithm.py

Lines changed: 3 additions & 1 deletion
@@ -313,11 +313,12 @@ def init_optimizer_state(
     return optax_optimizer_state

   # TODO(b/436634470): Consolidate this with the prepare_for_eval API
-  def get_ema_eval_params(self, optimizer_state):
+  def get_ema_eval_params(self, optimizer_state, params):
     """Extracts the exponential moving average (EMA) parameters from the optimizer state.

     Args:
       optimizer_state: The current state of the optimizer.
+      params: The current model parameters.

     Returns:
       The EMA parameters.
@@ -326,6 +327,7 @@ def get_ema_eval_params(self, optimizer_state):
       ValueError: If the EMA parameters cannot be extracted from the optimizer
         state.
     """
+    del params  # Unused
     if isinstance(optimizer_state, optax.InjectStatefulHyperparamsState):
       eval_params = optimizer_state.inner_state[0][0].ema
     elif isinstance(
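
The base implementation deletes the new params argument as unused, but a schedule-free branch needs it. A minimal sketch of such an override, assuming optax.contrib's schedule-free types; the method body here is illustrative, not the commit's actual implementation:

from optax import contrib as optax_contrib

def get_ema_eval_params(self, optimizer_state, params):
  """Illustrative override for a schedule-free training algorithm."""
  if isinstance(optimizer_state, optax_contrib.ScheduleFreeState):
    # schedule_free_eval_params reconstructs the evaluation iterate (x)
    # from the training params (y) and the z sequence held in the state.
    return optax_contrib.schedule_free_eval_params(optimizer_state, params)
  raise ValueError(
      'EMA parameters cannot be extracted from the optimizer state.')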
