From 45f4f890cf5186662a3c1a67c2dfb5f95503fd8f Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 23 Apr 2025 12:36:08 -0500 Subject: [PATCH 1/3] Require explicit trial_runner_id on instantiation and fix a bug where it was missing. --- mlos_bench/mlos_bench/storage/base_storage.py | 2 +- mlos_bench/mlos_bench/storage/sql/experiment.py | 2 ++ mlos_bench/mlos_bench/storage/sql/trial.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index ba84675425..f2d393994f 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -409,7 +409,7 @@ def __init__( # pylint: disable=too-many-arguments experiment_id: str, trial_id: int, tunable_config_id: int, - trial_runner_id: int | None = None, + trial_runner_id: int | None, opt_targets: dict[str, Literal["min", "max"]], config: dict[str, Any] | None = None, status: Status = Status.UNKNOWN, diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 62daa0232c..eb47de7d71 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -276,6 +276,7 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Stor experiment_id=self._experiment_id, trial_id=trial.trial_id, config_id=trial.config_id, + trial_runner_id=trial.trial_runner_id, opt_targets=self._opt_targets, config=config, ) @@ -350,6 +351,7 @@ def _new_trial( experiment_id=self._experiment_id, trial_id=self._trial_id, config_id=config_id, + trial_runner_id=None, # initially, Trials are not assigned to a TrialRunner opt_targets=self._opt_targets, config=config, status=new_trial_status, diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 88fc05ef7d..a9960c466a 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -38,7 +38,7 @@ def __init__( # pylint: disable=too-many-arguments experiment_id: str, trial_id: int, config_id: int, - trial_runner_id: int | None = None, + trial_runner_id: int | None, opt_targets: dict[str, Literal["min", "max"]], config: dict[str, Any] | None = None, status: Status = Status.UNKNOWN, From e015216538b60568a392413587a5404ac1917adf Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 23 Apr 2025 12:53:12 -0500 Subject: [PATCH 2/3] test pending trials restore --- .../mlos_bench/tests/storage/trial_schedule_test.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index a1ab74f9f5..bdffae836f 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -47,6 +47,10 @@ def test_schedule_trial( # Schedule 2 hours in the future: trial_2h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr * 2, config) + # Check that if we assign a TrialRunner that that value is still available on restore. + trial_now2.set_trial_runner(1) + assert trial_now2.trial_runner_id + exp_data = storage.experiments[exp_storage.experiment_id] trial_now1_data = exp_data.trials[trial_now1.trial_id] assert trial_now1_data.trial_runner_id is None @@ -54,14 +58,20 @@ def test_schedule_trial( # Check that Status matches in object vs. backend storage. assert trial_now1.status == trial_now1_data.status + trial_now2_data = exp_data.trials[trial_now2.trial_id] + assert trial_now2_data.trial_runner_id == trial_now2.trial_runner_id + # Scheduler side: get trials ready to run at certain timestamps: # Pretend 1 minute has passed, get trials scheduled to run: - pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) + pending_trials = exp_storage.pending_trials(timestamp + timedelta_1min, running=False) + pending_ids = _trial_ids(pending_trials) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, } + pending_trial_now2 = next(iter(t for t in pending_trials if t.trial_id == trial_now2.trial_id)) + assert pending_trial_now2.trial_runner_id == trial_now2_data.trial_runner_id # Get trials scheduled to run within the next 1 hour: pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) From e590a298ca09ab72bb70cc69c7004cdfd65237fc Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 23 Apr 2025 13:05:43 -0500 Subject: [PATCH 3/3] test fixup --- .../tests/storage/trial_schedule_test.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index bdffae836f..aaf545c787 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -64,14 +64,21 @@ def test_schedule_trial( # Scheduler side: get trials ready to run at certain timestamps: # Pretend 1 minute has passed, get trials scheduled to run: - pending_trials = exp_storage.pending_trials(timestamp + timedelta_1min, running=False) - pending_ids = _trial_ids(pending_trials) + pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, } - pending_trial_now2 = next(iter(t for t in pending_trials if t.trial_id == trial_now2.trial_id)) - assert pending_trial_now2.trial_runner_id == trial_now2_data.trial_runner_id + + # Make sure that the pending trials and trial_runner_ids match. + pending_trial_runner_ids = { + pending_trial.trial_id: pending_trial.trial_runner_id + for pending_trial in exp_storage.pending_trials(timestamp + timedelta_1min, running=False) + } + assert pending_trial_runner_ids == { + trial_now1.trial_id: trial_now1.trial_runner_id, + trial_now2.trial_id: trial_now2.trial_runner_id, + } # Get trials scheduled to run within the next 1 hour: pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False))