diff --git a/ciclo/callbacks.py b/ciclo/callbacks.py
index 0e4b477..68b697f 100644
--- a/ciclo/callbacks.py
+++ b/ciclo/callbacks.py
@@ -55,6 +55,11 @@ class OptimizationMode(str, Enum):
     max = auto()
 
 
+class DeltaMode(str, Enum):
+    absolute = auto()
+    relative = auto()
+
+
 def _transpose_history(
     log_history: History,
 ) -> Mapping[Collection, Mapping[Entry, List[Any]]]:
@@ -176,17 +181,17 @@ def __init__(
         keep_every_n_steps: Optional[int] = None,
         async_manager: Optional[flax_checkpoints.AsyncManager] = None,
         monitor: Optional[str] = None,
-        mode: Union[str, OptimizationMode] = "min",
+        optimization_mode: Union[str, OptimizationMode] = "min",
     ):
-        if isinstance(mode, str):
-            mode = OptimizationMode[mode]
+        if isinstance(optimization_mode, str):
+            optimization_mode = OptimizationMode[optimization_mode]
 
-        if mode not in OptimizationMode:
+        if optimization_mode not in OptimizationMode:
             raise ValueError(
-                f"Invalid mode: {mode}, expected one of {list(OptimizationMode)}"
+                f"Invalid optimization_mode: {optimization_mode}, expected one of {list(OptimizationMode)}"
             )
         else:
-            self.mode = mode
+            self.optimization_mode = optimization_mode
 
         self.ckpt_dir = ckpt_dir
         self.prefix = prefix
@@ -195,7 +200,7 @@ def __init__(
         self.keep_every_n_steps = keep_every_n_steps
         self.async_manager = async_manager
         self.monitor = monitor
-        self.minimize = self.mode == OptimizationMode.min
+        self.minimize = self.optimization_mode == OptimizationMode.min
         self._best: Optional[float] = None
 
     def __call__(
@@ -227,7 +232,9 @@ def __call__(
             ):
                 self._best = value
                 step_or_metric = (
-                    value if self.mode == OptimizationMode.max else -value
+                    value
+                    if self.optimization_mode == OptimizationMode.max
+                    else -value
                 )
             else:
                 save_checkpoint = False
@@ -264,30 +271,69 @@ def __init__(
         self,
         monitor: str,
         patience: Union[int, Period],
-        min_delta: float = 0,
-        mode: Union[str, OptimizationMode] = "min",
+        initial_patience: Optional[Union[int, Period]] = None,
+        min_delta: Optional[float] = None,
+        delta_mode: Union[str, DeltaMode] = "absolute",
+        optimization_mode: Union[str, OptimizationMode] = "min",
         baseline: Optional[float] = None,
         restore_best_weights: bool = False,
     ):
-        if isinstance(mode, str):
-            mode = OptimizationMode[mode]
+        if initial_patience is None:
+            initial_patience = 1
+
+        if min_delta is None:
+            min_delta = 0.0
 
-        if mode not in OptimizationMode:
+        if isinstance(optimization_mode, str):
+            optimization_mode = OptimizationMode[optimization_mode]
+
+        if optimization_mode not in OptimizationMode:
             raise ValueError(
-                f"Invalid mode: {mode}, expected one of {list(OptimizationMode)}"
+                f"Invalid optimization_mode: {optimization_mode}, expected one of {list(OptimizationMode)}"
             )
-        else:
-            self.mode = mode
+
+        if isinstance(delta_mode, str):
+            delta_mode = DeltaMode[delta_mode]
+
+        if delta_mode not in DeltaMode:
+            raise ValueError(
+                f"Invalid delta_mode: {delta_mode}, expected one of {list(DeltaMode)}"
+            )
+
+        if (
+            optimization_mode == OptimizationMode.min
+            and delta_mode == DeltaMode.absolute
+        ):
+            self.improvement_fn = lambda current, best: current < best - min_delta
+        elif (
+            optimization_mode == OptimizationMode.min
+            and delta_mode == DeltaMode.relative
+        ):
+            self.improvement_fn = lambda current, best: current < best * (1 - min_delta)
+        elif (
+            optimization_mode == OptimizationMode.max
+            and delta_mode == DeltaMode.absolute
+        ):
+            self.improvement_fn = lambda current, best: current > best + min_delta
+        elif (
+            optimization_mode == OptimizationMode.max
+            and delta_mode == DeltaMode.relative
+        ):
+            self.improvement_fn = lambda current, best: current > best * (1 + min_delta)
 
         self.monitor = monitor
         self.patience = (
             patience if isinstance(patience, Period) else Period.create(patience)
         )
+        self.initial_patience = (
+            initial_patience
+            if isinstance(initial_patience, Period)
+            else Period.create(initial_patience)
+        )
         self.min_delta = min_delta
-        self.mode = mode
         self.baseline = baseline
         self.restore_best_weights = restore_best_weights
-        self.minimize = self.mode == OptimizationMode.min
+        self.minimize = optimization_mode == OptimizationMode.min
         self._best = baseline
         self._best_state = None
         self._elapsed_start: Optional[Elapsed] = None
@@ -306,16 +352,15 @@ def __call__(self, elapsed: Elapsed, state: S, logs: Logs) -> Tuple[bool, S]:
         except KeyError:
             raise ValueError(f"Monitored value '{self.monitor}' not found in logs")
 
-        if (
-            self._best is None
-            or (self.minimize and value < self._best)
-            or (not self.minimize and value > self._best)
-        ):
+        if self._best is None or self.improvement_fn(value, self._best):
             self._best = value
             self._best_state = state
             self._elapsed_start = elapsed
 
-        if elapsed - self._elapsed_start >= self.patience:
+        if (
+            elapsed - self._elapsed_start >= self.patience
+            and elapsed >= self.initial_patience
+        ):
             if self.restore_best_weights and self._best_state is not None:
                 state = self._best_state
                 stop_iteration = True
diff --git a/ciclo/loops/loop.py b/ciclo/loops/loop.py
index accc0fa..b9696d2 100644
--- a/ciclo/loops/loop.py
+++ b/ciclo/loops/loop.py
@@ -165,7 +165,7 @@ def loop(
         )
         for schedule, callbacks in tasks.items()
     ]
-    # prone empty tasks
+    # prune empty tasks
     schedule_callbacks = [x for x in schedule_callbacks if len(x[1]) > 0]
 
     try:
diff --git a/examples/flax/02_mnist_train_loop.py b/examples/flax/02_mnist_train_loop.py
index 2143a1e..82f97ce 100644
--- a/examples/flax/02_mnist_train_loop.py
+++ b/examples/flax/02_mnist_train_loop.py
@@ -117,7 +117,7 @@ def reset_step(state: TrainState):
             ciclo.checkpoint(
                 f"logdir/{Path(__file__).stem}/{int(time())}",
                 monitor="accuracy_test",
-                mode="max",
+                optimization_mode="max",
             ),
             ciclo.keras_bar(total=total_steps),
         ],
diff --git a/examples/flax/04_mnist_managed_api.py b/examples/flax/04_mnist_managed_api.py
index 731a1de..7e7a289 100644
--- a/examples/flax/04_mnist_managed_api.py
+++ b/examples/flax/04_mnist_managed_api.py
@@ -105,12 +105,12 @@ def reset_metrics(state: ManagedState):
             ciclo.checkpoint(
                 f"logdir/{Path(__file__).stem}/{int(time())}",
                 monitor="accuracy_valid",
-                mode="max",
+                optimization_mode="max",
                 keep=3,
             ),
             ciclo.early_stopping(
                 monitor="accuracy_valid",
-                mode="max",
+                optimization_mode="max",
                 patience=eval_steps * 2,
             ),
             reset_metrics,
diff --git a/examples/flax/05_mnist_flax_state.py b/examples/flax/05_mnist_flax_state.py
index 8b1f45d..e17069d 100644
--- a/examples/flax/05_mnist_flax_state.py
+++ b/examples/flax/05_mnist_flax_state.py
@@ -55,7 +55,7 @@ def __call__(self, x):
             ciclo.checkpoint(
                 f"logdir/{Path(__file__).stem}/{int(time())}",
                 monitor="accuracy_test",
-                mode="max",
+                optimization_mode="max",
             ),
         ],
         test_dataset=lambda: ds_test.as_numpy_iterator(),
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index e98bfe9..544be56 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -35,7 +35,7 @@ def dummy_inner_loop_fn(_):
     return None, log_history, None
 
 
-class TestCallbacks:
+class TestInnerLoop:
     def test_inner_loop_default_aggregation(self):
         inner_loop = ciclo.callbacks.inner_loop(
             "test",
@@ -133,3 +133,186 @@ def test_inner_loop_aggregation_dict(self):
                 "D_test": jnp.array(0.0, dtype=jnp.float32),
             },
         }
+
+
+class TestEarlyStopping:
+    def test_patience(self):
+        dataset = jnp.minimum(jnp.arange(10), 5)
+
+        def train_step(state, batch):
+            logs = ciclo.logs()
+            logs.add_metric("x", batch)
+            return logs, state
+
+        _, history, _ = ciclo.loop(
+            None,
+            dataset,
+            {
+                ciclo.every(1): [
+                    train_step,
+                    ciclo.early_stopping("x", optimization_mode="max", patience=1),
+                ],
+            },
+        )
+
+        assert len(history) == 7
+
+        _, history, _ = ciclo.loop(
+            None,
+            dataset,
+            {
+                ciclo.every(1): [
+                    train_step,
+                    ciclo.early_stopping("x", optimization_mode="max", patience=3),
+                ],
+            },
+        )
+
+        assert len(history) == 9
+
+    def test_initial_patience(self):
+        dataset = jnp.maximum(jnp.minimum(jnp.arange(10), 5), 2)
+
+        def train_step(state, batch):
+            logs = ciclo.logs()
+            logs.add_metric("x", batch)
+            return logs, state
+
+        _, history, _ = ciclo.loop(
+            None,
+            dataset,
+            {
+                ciclo.every(1): [
+                    train_step,
+                    ciclo.early_stopping(
+                        "x", optimization_mode="max", patience=1, initial_patience=1
+                    ),
+                ],
+            },
+        )
+
+        assert len(history) == 2
+
+        _, history, _ = ciclo.loop(
+            None,
+            dataset,
+            {
+                ciclo.every(1): [
+                    train_step,
+                    ciclo.early_stopping(
+                        "x", optimization_mode="max", patience=1, initial_patience=3
+                    ),
+                ],
+            },
+        )
+
+        assert len(history) == 7
+
+    def test_min_optimization_mode(self):
+        dataset = jnp.maximum(jnp.minimum(jnp.arange(9, 0, -1), 6), 3)
+
+        def train_step(state, batch):
+            logs = ciclo.logs()
+            logs.add_metric("x", batch)
+            return logs, state
+
+        _, history, _ = ciclo.loop(
+            None,
+            dataset,
+            {
+                ciclo.every(1): [
+                    train_step,
+                    ciclo.early_stopping(
+                        "x", optimization_mode="min", patience=1, initial_patience=4
+                    ),
+                ],
+            },
+        )
+
+        assert len(history) == 8
+
+    def test_min_delta(self):
+        dataset = jnp.arange(0, 1, 0.1)
+
+        def train_step(state, batch):
+            logs = ciclo.logs()
+            logs.add_metric("x", batch)
+            return logs, state
+
+        _, history, _ = ciclo.loop(
+            None,
+            dataset,
+            {
+                ciclo.every(1): [
+                    train_step,
+                    ciclo.early_stopping(
+                        "x",
+                        optimization_mode="max",
+                        patience=1,
+                        min_delta=0.01,
+                    ),
+                ],
+            },
+        )
+
+        assert len(history) == 10
+
+        _, history, _ = ciclo.loop(
+            None,
+            dataset,
+            {
+                ciclo.every(1): [
+                    train_step,
+                    ciclo.early_stopping(
+                        "x", optimization_mode="max", patience=1, min_delta=0.1
+                    ),
+                ],
+            },
+        )
+
+        assert len(history) == 2
+
+        _, history, _ = ciclo.loop(
+            None,
+            dataset,
+            {
+                ciclo.every(1): [
+                    train_step,
+                    ciclo.early_stopping(
+                        "x",
+                        optimization_mode="max",
+                        patience=3,
+                        min_delta=0.05,
+                    ),
+                ],
+            },
+        )
+
+        assert len(history) == 10
+
+    def test_min_relative_delta(self):
+        dataset = jnp.arange(0, 1, 0.1)
+
+        def train_step(state, batch):
+            logs = ciclo.logs()
+            logs.add_metric("x", batch)
+            return logs, state
+
+        _, history, _ = ciclo.loop(
+            None,
+            dataset,
+            {
+                ciclo.every(1): [
+                    train_step,
+                    ciclo.early_stopping(
+                        "x",
+                        optimization_mode="max",
+                        patience=1,
+                        min_delta=0.5,
+                        delta_mode="relative",
+                    ),
+                ],
+            },
+        )
+
+        assert len(history) == 4
diff --git a/tests/test_integration.py b/tests/test_integration.py
index d65cc8f..243df45 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -177,11 +177,11 @@ def create_state():
             ciclo.checkpoint(
                 f"{logdir}/model",
                 monitor="accuracy_valid",
-                mode="max",
+                optimization_mode="max",
             ),
             ciclo.early_stopping(
                 monitor="accuracy_valid",
-                mode="max",
+                optimization_mode="max",
                 patience=100,
             ),
         ],
@@ -230,7 +230,7 @@ def __call__(self, x):
             ciclo.checkpoint(
                 f"logdir/{Path(__file__).stem}/{int(time())}",
                 monitor="accuracy_test",
mode="max", + optimization_mode="max", ), ], test_dataset=lambda: get_tuple_dataset(batch_size),