diff --git a/mlos_bench/mlos_bench/config/storage/mysql.jsonc b/mlos_bench/mlos_bench/config/storage/mysql.jsonc index c133b379ba2..15986de4c26 100644 --- a/mlos_bench/mlos_bench/config/storage/mysql.jsonc +++ b/mlos_bench/mlos_bench/config/storage/mysql.jsonc @@ -6,7 +6,7 @@ "log_sql": false, // Write all SQL statements to the log. // Parameters below must match kwargs of `sqlalchemy.URL.create()`: "drivername": "mysql+mysqlconnector", - "database": "osat", + "database": "mlos_bench", "username": "root", "password": "PLACERHOLDER PASSWORD", // Comes from global config "host": "localhost", diff --git a/mlos_bench/mlos_bench/config/storage/postgresql.jsonc b/mlos_bench/mlos_bench/config/storage/postgresql.jsonc index cd1214835c4..1cff2d76b76 100644 --- a/mlos_bench/mlos_bench/config/storage/postgresql.jsonc +++ b/mlos_bench/mlos_bench/config/storage/postgresql.jsonc @@ -8,7 +8,7 @@ "log_sql": false, // Write all SQL statements to the log. // Parameters below must match kwargs of `sqlalchemy.URL.create()`: "drivername": "postgresql+psycopg2", - "database": "osat", + "database": "mlos_bench", "username": "postgres", "password": "PLACERHOLDER PASSWORD", // Comes from global config "host": "localhost", diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 75a84bf0b2e..aec5e4d2446 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -25,10 +25,12 @@ from __future__ import annotations import logging +import os from abc import ABCMeta, abstractmethod from collections.abc import Iterator, Mapping from contextlib import AbstractContextManager as ContextManager from datetime import datetime +from subprocess import CalledProcessError from types import TracebackType from typing import Any, Literal @@ -38,7 +40,7 @@ from mlos_bench.services.base_service import Service from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.tunables.tunable_groups import TunableGroups -from mlos_bench.util import get_git_info +from mlos_bench.util import get_git_info, get_git_root, path_join _LOG = logging.getLogger(__name__) @@ -187,16 +189,61 @@ def __init__( # pylint: disable=too-many-arguments tunables: TunableGroups, experiment_id: str, trial_id: int, - root_env_config: str, + root_env_config: str | None, description: str, opt_targets: dict[str, Literal["min", "max"]], + git_repo: str | None = None, + git_commit: str | None = None, + rel_root_env_config: str | None = None, ): self._tunables = tunables.copy() self._trial_id = trial_id self._experiment_id = experiment_id - (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( - root_env_config - ) + if root_env_config is None: + # Restoring from DB. + if not (git_repo and git_commit and rel_root_env_config): + raise ValueError( + "Missing required args: git_repo, git_commit, rel_root_env_config" + ) + self._git_repo = git_repo + self._git_commit = git_commit + self._rel_root_env_config = rel_root_env_config + + # Currently we only store the relative path of the root env config + # from the git repo it came from. 
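+            # To turn the stored relative path back into a usable absolute
+            # path, look for a git root to anchor it against: prefer the
+            # recorded repo path if it exists locally, else the current
+            # working directory's repo, else (as a last resort) the repo
+            # containing this source file.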
+ git_root = git_repo + if not os.path.exists(git_root): + try: + git_root = get_git_root(os.curdir) + except CalledProcessError: + _LOG.warning( + "Failed to find a git repo in the current working directory: %s", + os.curdir, + ) + git_root = get_git_root(__file__) + + self._abs_root_env_config = path_join( + git_root, + self._rel_root_env_config, + abs_path=True, + ) + _LOG.info( + "Resolved relative root_config %s for experiment %s to %s", + self._rel_root_env_config, + self._experiment_id, + self._abs_root_env_config, + ) + else: + if git_repo or git_commit or rel_root_env_config: + raise ValueError("Unexpected args: git_repo, git_commit, rel_root_env_config") + ( + self._git_repo, + self._git_commit, + self._rel_root_env_config, + self._abs_root_env_config, + ) = get_git_info( + root_env_config, + ) self._description = description self._opt_targets = opt_targets self._in_context = False @@ -278,9 +325,21 @@ def description(self) -> str: return self._description @property - def root_env_config(self) -> str: - """Get the Experiment's root Environment config file path.""" - return self._root_env_config + def rel_root_env_config(self) -> str: + """Get the Experiment's root Environment config's relative file path to the + git repo root. + """ + return self._rel_root_env_config + + @property + def abs_root_env_config(self) -> str: + """ + Get the Experiment's root Environment config file path. + + This returns the current absolute path to the root config for this process + instead of the path relative to the git repo root. + """ + return self._abs_root_env_config @property def tunables(self) -> TunableGroups: diff --git a/mlos_bench/mlos_bench/storage/sql/alembic.ini b/mlos_bench/mlos_bench/storage/sql/alembic.ini index 4d2a1120c54..857375445ec 100644 --- a/mlos_bench/mlos_bench/storage/sql/alembic.ini +++ b/mlos_bench/mlos_bench/storage/sql/alembic.ini @@ -63,7 +63,10 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne # output_encoding = utf-8 # See README.md for details. +# Uncomment one of these: sqlalchemy.url = sqlite:///mlos_bench.sqlite +#sqlalchemy.url = mysql+mysqlconnector://root:password@localhost:3306/mlos_bench +#sqlalchemy.url = postgresql+psycopg2://root:password@localhost:5432/mlos_bench [post_write_hooks] @@ -72,10 +75,10 @@ sqlalchemy.url = sqlite:///mlos_bench.sqlite # detail and examples # format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks = black -# black.type = console_scripts -# black.entrypoint = black -# black.options = -l 79 REVISION_SCRIPT_FILENAME +hooks = black +black.type = console_scripts +black.entrypoint = black +black.options = REVISION_SCRIPT_FILENAME # lint with attempts to fix using "ruff" - use the exec runner, execute a binary # hooks = ruff diff --git a/mlos_bench/mlos_bench/storage/sql/alembic/README.md b/mlos_bench/mlos_bench/storage/sql/alembic/README.md index ec35eb70f64..48b0b04aa76 100644 --- a/mlos_bench/mlos_bench/storage/sql/alembic/README.md +++ b/mlos_bench/mlos_bench/storage/sql/alembic/README.md @@ -4,26 +4,81 @@ This document contains some notes on how to use [`alembic`](https://alembic.sqla ## Overview -1. Create a blank `mlos_bench.sqlite` database file in the [`mlos_bench/storage/sql`](../) directory with the current schema using the following command: +1. 
Create a blank database instance in the [`mlos_bench/storage/sql`](../) directory with the current schema using the following command:

-   ```sh
-   cd mlos_bench/storage/sql
-   rm mlos_bench.sqlite
-   mlos_bench --storage storage/sqlite.jsonc --create-update-storage-schema-only
-   ```
+   This allows `alembic` to automatically generate a migration script from the current schema.
+
+   > NOTE: If your schema changes target a particular backend engine (e.g., using `with_variant`) you will need to use an engine with that config for this step.
+   > \
+   > In the remainder of this document we show some examples for different DB types.
+   > Pick the one you're targeting and stick with it through the example.
+   > You may need to repeat the process several times to test all of them.
+   >
+   > - [ ] TODO: Add scripts to automatically do this for several different backend engines all at once.
+
+   For instance:
+
+   1. Start a temporary database instance, either as a local file or in a docker container:
+
+      ```sh
+      # sqlite
+      cd mlos_bench/storage/sql
+      rm -f mlos_bench.sqlite
+      ```
+
+      ```sh
+      # mysql
+      docker run -it --rm --name mysql-alembic --env MYSQL_ROOT_PASSWORD=password --env MYSQL_DATABASE=mlos_bench -p 3306:3306 mysql:latest
+      ```
+
+      ```sh
+      # postgres
+      docker run -it --rm --name postgres-alembic --env POSTGRES_PASSWORD=password --env POSTGRES_DB=mlos_bench -p 5432:5432 postgres:latest
+      ```
+
+   1. Adjust the `sqlalchemy.url` in the [`alembic.ini`](../alembic.ini) file.
 
-   > This allows `alembic` to automatically generate a migration script from the current schema.
+      ```ini
+      # Uncomment one of these.
+      # See README.md for details.
 
-1. Adjust the [`mlos_bench/storage/sql/schema.py`](../schema.py) file to reflect the new desired schema.
+      #sqlalchemy.url = sqlite:///mlos_bench.sqlite
+      sqlalchemy.url = mysql+mysqlconnector://root:password@localhost:3306/mlos_bench
+      #sqlalchemy.url = postgresql+psycopg2://root:password@localhost:5432/mlos_bench
+      ```
 
+   1. Prime the DB schema
+
+      > Note: you may want to `git checkout main` first to make sure you're using the current schema here.
+
+      ```sh
+      # sqlite
+      mlos_bench --storage storage/sqlite.jsonc --create-update-storage-schema-only
+      ```
+
+      ```sh
+      # mysql
+      mlos_bench --storage storage/mysql.jsonc --create-update-storage-schema-only --password=password
+      ```
+
+      ```sh
+      # postgres
+      mlos_bench --storage storage/postgresql.jsonc --create-update-storage-schema-only --password=password
+      ```
+
+1. Now, adjust the [`mlos_bench/storage/sql/schema.py`](../schema.py) file to reflect the new desired schema.
+
+   > Don't forget to do this on a new branch.
+   > \
    > Keep each change small and atomic.
+   > \
    > For example, if you want to add a new column, do that in one change.
    > If you want to rename a column, do that in another change.
 
 1. Generate a new migration script with the following command:
 
    ```sh
-   alembic revision --autogenerate -m "Descriptive text about the change."
+   alembic revision --autogenerate -m "CHANGEME: Descriptive text about the change."
    ```
 
 1. Review the generated migration script in the [`mlos_bench/storage/sql/alembic/versions`](./versions/) directory.
 
@@ -31,22 +86,62 @@ This document contains some notes on how to use [`alembic`](https://alembic.sqla
 1. 
Verify that the migration script works by running the following command:
 
    ```sh
+   # sqlite
    mlos_bench --storage storage/sqlite.jsonc --create-update-storage-schema-only
    ```
 
+   ```sh
+   # mysql
+   mlos_bench --storage storage/mysql.jsonc --create-update-storage-schema-only --password=password
+   ```
+
+   ```sh
+   # postgres
+   mlos_bench --storage storage/postgresql.jsonc --create-update-storage-schema-only --password=password
+   ```
+
    > Normally this would be done with `alembic upgrade head`, but this command is convenient to ensure it will work with the `mlos_bench` command line interface as well.
 
    Examine the results using something like:
 
    ```sh
+   # For sqlite:
    sqlite3 mlos_bench.sqlite .schema
    sqlite3 mlos_bench.sqlite "SELECT * FROM alembic_version;"
    ```
 
+   ```sh
+   # For mysql:
+   mysql --user root --password=password --host localhost --protocol tcp --database mlos_bench -e "SHOW TABLES; SELECT * FROM alembic_version;"
+   ```
+
+   ```sh
+   # For postgres:
+   PGPASSWORD=password psql -h localhost -p 5432 -U postgres mlos_bench -c "SELECT table_name FROM information_schema.tables WHERE table_schema='public' and table_catalog='mlos_bench'; SELECT * FROM alembic_version;"
+   ```
+
+   > Use different CLI clients for targeting other engines.
+
 1. If the migration script works, commit the changes to the [`mlos_bench/storage/sql/schema.py`](../schema.py) and [`mlos_bench/storage/sql/alembic/versions`](./versions/) files.
 
    > Be sure to update the latest version in the [`test_storage_schemas.py`](../../../tests/storage/test_storage_schemas.py) file as well.
 
+1. Clean up any server instances you started.
+
+   For instance:
+
+   ```sh
+   rm mlos_bench/storage/sql/mlos_bench.sqlite
+   ```
+
+   ```sh
+   docker kill mysql-alembic
+   ```
+
+   ```sh
+   docker kill postgres-alembic
+   ```
+
 1. Merge that to the `main` branch.
 
 1. Might be good to cut a new `mlos_bench` release at this point as well.
diff --git a/mlos_bench/mlos_bench/storage/sql/alembic/env.py b/mlos_bench/mlos_bench/storage/sql/alembic/env.py
index fc186b8cb1f..af1a0db7206 100644
--- a/mlos_bench/mlos_bench/storage/sql/alembic/env.py
+++ b/mlos_bench/mlos_bench/storage/sql/alembic/env.py
@@ -5,11 +5,17 @@
 """Alembic environment script."""
 # pylint: disable=no-member
 
+import logging
 import sys
 from logging.config import fileConfig
 
 from alembic import context
-from sqlalchemy import engine_from_config, pool
+from alembic.migration import MigrationContext
+from sqlalchemy import create_engine, engine_from_config, pool
+from sqlalchemy.dialects import mysql
+from sqlalchemy.schema import Column as SchemaColumn
+from sqlalchemy.sql.schema import Column
+from sqlalchemy.types import TypeEngine
 
 from mlos_bench.storage.sql.schema import DbSchema
 
@@ -22,10 +28,19 @@
 # Don't override the mlos_bench or pytest loggers though.
 if config.config_file_name is not None and "alembic" in sys.argv[0]:
     fileConfig(config.config_file_name)
+alembic_logger = logging.getLogger("alembic")
 
 # add your model's MetaData object here
 # for 'autogenerate' support
-target_metadata = DbSchema(engine=None).meta
+# NOTE: We override the alembic.ini file value programmatically in storage/sql/schema.py.
+# However, the alembic.ini file value is used during alembic CLI operations
+# (e.g., dev ops extending the schema). 
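+# With the stock alembic.ini that value is a local sqlite file URL
+# (sqlite:///mlos_bench.sqlite); uncomment one of the other examples there to
+# target mysql or postgres instead.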
+sqlalchemy_url = config.get_main_option("sqlalchemy.url")
+if not sqlalchemy_url:
+    raise ValueError("Missing sqlalchemy.url: schema changes may not be accurate.")
+engine = create_engine(sqlalchemy_url)
+alembic_logger.info("engine.url %s", str(engine.url))
+target_metadata = DbSchema(engine=engine).meta
 
 # other values from the config, defined by the needs of env.py,
 # can be acquired:
@@ -33,6 +48,117 @@
 # ... etc.
 
 
+def fq_class_name(t: object) -> str:
+    """Return the fully qualified class name of a type."""
+    return t.__module__ + "." + t.__class__.__name__
+
+
+def custom_compare_types(
+    migration_context: MigrationContext,  # pylint: disable=unused-argument
+    inspected_column: SchemaColumn | None,  # pylint: disable=unused-argument
+    metadata_column: Column,  # pylint: disable=unused-argument
+    inspected_type: TypeEngine,
+    metadata_type: TypeEngine,
+) -> bool | None:
+    """
+    Custom column type comparator.
+
+    See `Comparing Types
+    <https://alembic.sqlalchemy.org/en/latest/autogenerate.html#comparing-types>`_
+    documentation for more details.
+
+    Parameters
+    ----------
+    migration_context : MigrationContext
+        The context of the current migration operation.
+    inspected_column : SchemaColumn | None
+        The column as reflected from the live database, if available.
+    metadata_column : Column
+        The column as declared in the schema metadata.
+    inspected_type : TypeEngine
+        The column type reflected from the live database.
+    metadata_type : TypeEngine
+        The column type declared in the schema metadata.
+
+    Notes
+    -----
+    In the case of a MySQL DateTime variant, it makes sure that the fractional
+    seconds precision and timezone settings match.
+
+    Returns
+    -------
+    result : bool | None
+        Returns True if the column specifications don't match the column (i.e.,
+        a change is needed).
+        Returns False when the column specification and column match.
+        Returns None to fall back to the default comparator logic.
+    """
+    metadata_dialect_type = metadata_type.dialect_impl(migration_context.dialect)
+    if alembic_logger.isEnabledFor(logging.DEBUG):
+        alembic_logger.debug(
+            (
+                "Comparing columns: "
+                "inspected_column: [%s] %s and "
+                "metadata_column: [%s (%s)] %s "
+                "inspected_column.__dict__: %s\n"
+                "inspected_column.dialect_options: %s\n"
+                "inspected_column.dialect_kwargs: %s\n"
+                "inspected_type.__dict__: %s\n"
+                "metadata_column.__dict__: %s\n"
+                "metadata_type.__dict__: %s\n"
+                "metadata_dialect_type.__dict__: %s\n"
+            ),
+            fq_class_name(inspected_type),
+            inspected_column,
+            fq_class_name(metadata_type),
+            fq_class_name(metadata_dialect_type),
+            metadata_column,
+            inspected_column.__dict__ if inspected_column is not None else None,
+            dict(inspected_column.dialect_options) if inspected_column is not None else None,
+            dict(inspected_column.dialect_kwargs) if inspected_column is not None else None,
+            inspected_type.__dict__,
+            metadata_column.__dict__,
+            metadata_type.__dict__,
+            metadata_dialect_type.__dict__,
+        )
+
+    # Implement a more detailed DATETIME precision comparison for MySQL.
+    # Note: Currently also handles MariaDB. 
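+    # (A MariaDB server reached via a mysql:// URL still reports the "mysql"
+    # dialect name here; a mariadb:// URL would report "mariadb" and fall
+    # through to the default comparator instead.)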
+ if migration_context.dialect.name == "mysql": + if isinstance(metadata_dialect_type, (mysql.DATETIME, mysql.TIMESTAMP)): + if not isinstance(inspected_type, type(metadata_dialect_type)): + alembic_logger.info( + "inspected_type %s does not match metadata_dialect_type %s", + fq_class_name(inspected_type), + fq_class_name(metadata_dialect_type), + ) + return True + else: + if inspected_type.fsp != metadata_dialect_type.fsp: + alembic_logger.info( + "inspected_type.fsp (%s) and metadata_dialect_type.fsp (%s) don't match", + inspected_type.fsp, + metadata_dialect_type.fsp, + ) + return True + + if inspected_type.timezone != metadata_dialect_type.timezone: + alembic_logger.info( + ( + "inspected_type.timezone (%s) and " + "metadata_dialect_type.timezone (%s) don't match" + ), + inspected_type.timezone, + metadata_dialect_type.timezone, + ) + return True + + if alembic_logger.isEnabledFor(logging.DEBUG): + alembic_logger.debug( + ( + "Using default compare_type behavior for " + "inspected_column: [%s] %s and " + "metadata_column: [%s (%s)] %s (see above for details).\n" + ), + fq_class_name(inspected_type), + inspected_column, + fq_class_name(metadata_type), + fq_class_name(metadata_dialect_type), + metadata_column, + ) + return None # fallback to default comparison behavior + + def run_migrations_offline() -> None: """ Run migrations in 'offline' mode. @@ -49,6 +175,7 @@ def run_migrations_offline() -> None: target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"}, + compare_type=custom_compare_types, ) with context.begin_transaction(): @@ -74,12 +201,20 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=custom_compare_types, + ) with context.begin_transaction(): context.run_migrations() else: - context.configure(connection=connectable, target_metadata=target_metadata) + context.configure( + connection=connectable, + target_metadata=target_metadata, + compare_type=custom_compare_types, + ) with context.begin_transaction(): context.run_migrations() diff --git a/mlos_bench/mlos_bench/storage/sql/alembic/versions/b61aa446e724_support_fractional_seconds_with_mysql.py b/mlos_bench/mlos_bench/storage/sql/alembic/versions/b61aa446e724_support_fractional_seconds_with_mysql.py new file mode 100644 index 00000000000..0d0f3f9b3df --- /dev/null +++ b/mlos_bench/mlos_bench/storage/sql/alembic/versions/b61aa446e724_support_fractional_seconds_with_mysql.py @@ -0,0 +1,124 @@ +# +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +""" +Support fractional seconds with MySQL. + +Revision ID: b61aa446e724 +Revises: 8928a401115b +Create Date: 2025-06-02 17:56:34.746642+00:00 +""" +# pylint: disable=no-member + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import context, op +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. +revision: str = "b61aa446e724" +down_revision: str | None = "8928a401115b" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """The schema upgrade script for this revision.""" + bind = context.get_bind() + if bind.dialect.name == "mysql": + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column( + "experiment", + "ts_start", + existing_type=mysql.DATETIME(), + type_=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + existing_nullable=True, + ) + op.alter_column( + "experiment", + "ts_end", + existing_type=mysql.DATETIME(), + type_=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + existing_nullable=True, + ) + op.alter_column( + "trial", + "ts_start", + existing_type=mysql.DATETIME(), + type_=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + existing_nullable=False, + ) + op.alter_column( + "trial", + "ts_end", + existing_type=mysql.DATETIME(), + type_=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + existing_nullable=True, + ) + op.alter_column( + "trial_status", + "ts", + existing_type=mysql.DATETIME(), + type_=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + existing_nullable=False, + ) + op.alter_column( + "trial_telemetry", + "ts", + existing_type=mysql.DATETIME(), + type_=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + existing_nullable=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """The schema downgrade script for this revision.""" + bind = context.get_bind() + if bind.dialect.name == "mysql": + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "trial_telemetry", + "ts", + existing_type=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + type_=mysql.DATETIME(), + existing_nullable=False, + ) + op.alter_column( + "trial_status", + "ts", + existing_type=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + type_=mysql.DATETIME(), + existing_nullable=False, + ) + op.alter_column( + "trial", + "ts_end", + existing_type=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + type_=mysql.DATETIME(), + existing_nullable=True, + ) + op.alter_column( + "trial", + "ts_start", + existing_type=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + type_=mysql.DATETIME(), + existing_nullable=False, + ) + op.alter_column( + "experiment", + "ts_end", + existing_type=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + type_=mysql.DATETIME(), + existing_nullable=True, + ) + op.alter_column( + "experiment", + "ts_start", + existing_type=sa.DateTime(timezone=True).with_variant(mysql.DATETIME(fsp=6), "mysql"), + type_=mysql.DATETIME(), + existing_nullable=True, + ) + # ### end Alembic commands ### diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index acc2a497b48..6b7aa220a65 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -38,9 +38,12 @@ def __init__( # pylint: disable=too-many-arguments tunables: TunableGroups, experiment_id: str, trial_id: int, - root_env_config: str, + root_env_config: str | None, description: str, opt_targets: dict[str, Literal["min", "max"]], + git_repo: str | None = None, + git_commit: str | None = None, + rel_root_env_config: str | None = None, ): super().__init__( tunables=tunables, @@ -49,6 +52,9 @@ def __init__( # pylint: disable=too-many-arguments root_env_config=root_env_config, description=description, opt_targets=opt_targets, + git_repo=git_repo, + git_commit=git_commit, + rel_root_env_config=rel_root_env_config, ) self._engine = engine self._schema = schema @@ -89,7 +95,7 @@ def _setup(self) 
-> None: description=self._description, git_repo=self._git_repo, git_commit=self._git_commit, - root_env_config=self._root_env_config, + root_env_config=self._rel_root_env_config, ) ) conn.execute( @@ -367,11 +373,7 @@ def _new_trial( ts_start: datetime | None = None, config: dict[str, Any] | None = None, ) -> Storage.Trial: - # MySQL can round microseconds into the future causing scheduler to skip trials. - # Truncate microseconds to avoid this issue. - ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace( - microsecond=0 - ) + ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local") _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start) with self._engine.begin() as conn: try: diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index 2bc00f00825..e9b369259ba 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -39,6 +39,7 @@ create_mock_engine, inspect, ) +from sqlalchemy.dialects import mysql from sqlalchemy.engine import Engine from mlos_bench.util import path_join @@ -72,43 +73,59 @@ class DbSchema: # for all DB tables, so it's ok to disable the warnings. # pylint: disable=too-many-instance-attributes - # Common string column sizes. - _ID_LEN = 512 - _PARAM_VALUE_LEN = 1024 - _METRIC_VALUE_LEN = 255 - _STATUS_LEN = 16 - - def __init__(self, engine: Engine | None): + def __init__(self, engine: Engine): """ Declare the SQLAlchemy schema for the database. Parameters ---------- - engine : sqlalchemy.engine.Engine | None - The SQLAlchemy engine to use for the DB schema. - Listed as optional for `alembic `_ - schema migration purposes so we can reference it inside it's ``env.py`` - config file for :attr:`~meta` data inspection, but won't generally be - functional without one. + engine : sqlalchemy.engine.Engine """ + assert engine, "Error: can't create schema without engine." _LOG.info("Create the DB schema for: %s", engine) self._engine = engine self._meta = MetaData() + # Common string column sizes. + self._exp_id_len = 512 + self._param_id_len = 512 + self._param_value_len = 1024 + self._metric_id_len = 512 + self._metric_value_len = 255 + self._status_len = 16 + + # Some overrides for certain DB engines: + if engine and engine.dialect.name in {"mysql", "mariadb"}: + self._exp_id_len = 255 + self._param_id_len = 255 + self._metric_id_len = 255 + self.experiment = Table( "experiment", self._meta, - Column("exp_id", String(self._ID_LEN), nullable=False), + Column("exp_id", String(self._exp_id_len), nullable=False), Column("description", String(1024)), Column("root_env_config", String(1024), nullable=False), Column("git_repo", String(1024), nullable=False), Column("git_commit", String(40), nullable=False), # For backwards compatibility, we allow NULL for ts_start. - Column("ts_start", DateTime), - Column("ts_end", DateTime), + Column( + "ts_start", + DateTime(timezone=True).with_variant( + mysql.DATETIME(fsp=6), + "mysql", + ), + ), + Column( + "ts_end", + DateTime(timezone=True).with_variant( + mysql.DATETIME(fsp=6), + "mysql", + ), + ), # Should match the text IDs of `mlos_bench.environments.Status` enum: # For backwards compatibility, we allow NULL for status. - Column("status", String(self._STATUS_LEN)), + Column("status", String(self._status_len)), # There may be more than one mlos_benchd_service running on different hosts. 
# This column stores the host/container name of the driver that # picked up the experiment. @@ -126,7 +143,7 @@ def __init__(self, engine: Engine | None): "objectives", self._meta, Column("exp_id"), - Column("optimization_target", String(self._ID_LEN), nullable=False), + Column("optimization_target", String(self._metric_id_len), nullable=False), Column("optimization_direction", String(4), nullable=False), # TODO: Note: weight is not fully supported yet as currently # multi-objective is expected to explore each objective equally. @@ -175,14 +192,28 @@ def __init__(self, engine: Engine | None): self.trial = Table( "trial", self._meta, - Column("exp_id", String(self._ID_LEN), nullable=False), + Column("exp_id", String(self._exp_id_len), nullable=False), Column("trial_id", Integer, nullable=False), Column("config_id", Integer, nullable=False), Column("trial_runner_id", Integer, nullable=True, default=None), - Column("ts_start", DateTime, nullable=False), - Column("ts_end", DateTime), + Column( + "ts_start", + DateTime(timezone=True).with_variant( + mysql.DATETIME(fsp=6), + "mysql", + ), + nullable=False, + ), + Column( + "ts_end", + DateTime(timezone=True).with_variant( + mysql.DATETIME(fsp=6), + "mysql", + ), + nullable=True, + ), # Should match the text IDs of `mlos_bench.environments.Status` enum: - Column("status", String(self._STATUS_LEN), nullable=False), + Column("status", String(self._status_len), nullable=False), PrimaryKeyConstraint("exp_id", "trial_id"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), @@ -197,8 +228,8 @@ def __init__(self, engine: Engine | None): "config_param", self._meta, Column("config_id", Integer, nullable=False), - Column("param_id", String(self._ID_LEN), nullable=False), - Column("param_value", String(self._PARAM_VALUE_LEN)), + Column("param_id", String(self._param_id_len), nullable=False), + Column("param_value", String(self._param_value_len)), PrimaryKeyConstraint("config_id", "param_id"), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), ) @@ -212,10 +243,10 @@ def __init__(self, engine: Engine | None): self.trial_param = Table( "trial_param", self._meta, - Column("exp_id", String(self._ID_LEN), nullable=False), + Column("exp_id", String(self._exp_id_len), nullable=False), Column("trial_id", Integer, nullable=False), - Column("param_id", String(self._ID_LEN), nullable=False), - Column("param_value", String(self._PARAM_VALUE_LEN)), + Column("param_id", String(self._param_id_len), nullable=False), + Column("param_value", String(self._param_value_len)), PrimaryKeyConstraint("exp_id", "trial_id", "param_id"), ForeignKeyConstraint( ["exp_id", "trial_id"], @@ -230,10 +261,18 @@ def __init__(self, engine: Engine | None): self.trial_status = Table( "trial_status", self._meta, - Column("exp_id", String(self._ID_LEN), nullable=False), + Column("exp_id", String(self._exp_id_len), nullable=False), Column("trial_id", Integer, nullable=False), - Column("ts", DateTime(timezone=True), nullable=False, default="now"), - Column("status", String(self._STATUS_LEN), nullable=False), + Column( + "ts", + DateTime(timezone=True).with_variant( + mysql.DATETIME(fsp=6), + "mysql", + ), + nullable=False, + default="now", + ), + Column("status", String(self._status_len), nullable=False), UniqueConstraint("exp_id", "trial_id", "ts"), ForeignKeyConstraint( ["exp_id", "trial_id"], @@ -247,10 +286,10 @@ def __init__(self, engine: Engine | None): self.trial_result = Table( "trial_result", self._meta, 
- Column("exp_id", String(self._ID_LEN), nullable=False), + Column("exp_id", String(self._exp_id_len), nullable=False), Column("trial_id", Integer, nullable=False), - Column("metric_id", String(self._ID_LEN), nullable=False), - Column("metric_value", String(self._METRIC_VALUE_LEN)), + Column("metric_id", String(self._metric_id_len), nullable=False), + Column("metric_value", String(self._metric_value_len)), PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"), ForeignKeyConstraint( ["exp_id", "trial_id"], @@ -265,11 +304,19 @@ def __init__(self, engine: Engine | None): self.trial_telemetry = Table( "trial_telemetry", self._meta, - Column("exp_id", String(self._ID_LEN), nullable=False), + Column("exp_id", String(self._exp_id_len), nullable=False), Column("trial_id", Integer, nullable=False), - Column("ts", DateTime(timezone=True), nullable=False, default="now"), - Column("metric_id", String(self._ID_LEN), nullable=False), - Column("metric_value", String(self._METRIC_VALUE_LEN)), + Column( + "ts", + DateTime(timezone=True).with_variant( + mysql.DATETIME(fsp=6), + "mysql", + ), + nullable=False, + default="now", + ), + Column("metric_id", String(self._metric_id_len), nullable=False), + Column("metric_value", String(self._metric_value_len)), UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"), ForeignKeyConstraint( ["exp_id", "trial_id"], @@ -288,14 +335,45 @@ def meta(self) -> MetaData: """Return the SQLAlchemy MetaData object.""" return self._meta - @staticmethod - def _get_alembic_cfg(conn: Connection) -> config.Config: + def _get_alembic_cfg(self, conn: Connection) -> config.Config: alembic_cfg = config.Config( path_join(str(files("mlos_bench.storage.sql")), "alembic.ini", abs_path=True) ) + assert self._engine is not None + alembic_cfg.set_main_option( + "sqlalchemy.url", + self._engine.url.render_as_string( + hide_password=False, + ), + ) alembic_cfg.attributes["connection"] = conn return alembic_cfg + def drop_all_tables(self, *, force: bool = False) -> None: + """ + Helper method used in testing to reset the DB schema. + + Notes + ----- + This method is not intended for production use, as it will drop all tables + in the database. Use with caution. + + Parameters + ---------- + force : bool + If True, drop all tables in the target database. + If False, this method will not drop any tables and will log a warning. + """ + assert self._engine + self.meta.reflect(bind=self._engine) + if force: + self.meta.drop_all(bind=self._engine) + else: + _LOG.warning( + "Resetting the schema without force is not implemented. " + "Use force=True to drop all tables." + ) + def create(self) -> "DbSchema": """Create the DB schema.""" _LOG.info("Create the DB schema") diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index 6d98dc97fdd..27748e5ca81 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -82,6 +82,33 @@ def _schema(self) -> DbSchema: _LOG.debug("DDL statements:\n%s", self._db_schema) return self._db_schema + def _reset_schema(self, *, force: bool = False) -> None: + """ + Helper method used in testing to reset the DB schema. + + Notes + ----- + This method is not intended for production use, as it will drop all tables + in the database. Use with caution. + + Parameters + ---------- + force : bool + If True, drop all tables in the target database. + If False, this method will not drop any tables and will log a warning. 
+ """ + assert self._engine + if force: + self._schema.drop_all_tables(force=force) + self._db_schema = DbSchema(self._engine) + self._schema_created = False + self._schema_updated = False + else: + _LOG.warning( + "Resetting the schema without force is not implemented. " + "Use force=True to drop all tables." + ) + def update_schema(self) -> None: """Update the database schema.""" if not self._schema_updated: @@ -112,9 +139,13 @@ def get_experiment_by_id( experiment_id=exp.exp_id, trial_id=-1, # will be loaded upon __enter__ which calls _setup() description=exp.description, - root_env_config=exp.root_env_config, + # Use special logic to load the experiment root config info directly. + root_env_config=None, tunables=tunables, opt_targets=opt_targets, + git_repo=exp.git_repo, + git_commit=exp.git_commit, + rel_root_env_config=exp.root_env_config, ) def experiment( # pylint: disable=too-many-arguments diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index ce5fbdb45af..a367133701c 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -8,15 +8,18 @@ Used to make mypy happy about multiple conftest.py modules. """ import filecmp +import json import os import shutil import socket +import sys from datetime import tzinfo from logging import debug, warning from subprocess import run import pytest import pytz +from pytest_docker.plugin import Services as DockerServices from mlos_bench.util import get_class_from_name, nullable @@ -87,6 +90,73 @@ def check_class_name(obj: object, expected_class_name: str) -> bool: return full_class_name == try_resolve_class_name(expected_class_name) +HOST_DOCKER_NAME = "host.docker.internal" + + +@pytest.fixture(scope="session") +def docker_hostname() -> str: + """Returns the local hostname to use to connect to the test ssh server.""" + if sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME): + # On Linux, if we're running in a docker container, we can use the + # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. + return HOST_DOCKER_NAME + # Docker (Desktop) for Windows (WSL2) uses a special networking magic + # to refer to the host machine as `localhost` when exposing ports. + # In all other cases, assume we're executing directly inside conda on the host. 
+ return "localhost" + + +def wait_docker_service_socket(docker_services: DockerServices, hostname: str, port: int) -> None: + """Wait until a docker service is ready.""" + docker_services.wait_until_responsive( + check=lambda: check_socket(hostname, port), + timeout=30.0, + pause=0.5, + ) + + +def is_docker_service_healthy( + compose_project_name: str, + service_name: str, +) -> bool: + """Check if a docker service is healthy.""" + docker_ps_out = run( + f"docker compose -p {compose_project_name} " f"ps --format json {service_name}", + shell=True, + check=True, + capture_output=True, + ) + docker_ps_json = json.loads(docker_ps_out.stdout.decode().strip()) + state = docker_ps_json["State"] + assert isinstance(state, str) + health = docker_ps_json["Health"] + assert isinstance(health, str) + return state == "running" and health == "healthy" + + +def wait_docker_service_healthy( + docker_services: DockerServices, + project_name: str, + service_name: str, + timeout: float = 30.0, +) -> None: + """Wait until a docker service is healthy.""" + docker_services.wait_until_responsive( + check=lambda: is_docker_service_healthy(project_name, service_name), + timeout=timeout, + pause=0.5, + ) + + +def wait_docker_service_socket(docker_services: DockerServices, hostname: str, port: int) -> None: + """Wait until a docker service is ready.""" + docker_services.wait_until_responsive( + check=lambda: check_socket(hostname, port), + timeout=30.0, + pause=0.5, + ) + + def check_socket(host: str, port: int, timeout: float = 1.0) -> bool: """ Test to see if a socket is open. diff --git a/mlos_bench/mlos_bench/tests/config/schedulers/test_load_scheduler_config_examples.py b/mlos_bench/mlos_bench/tests/config/schedulers/test_load_scheduler_config_examples.py index ad8f9248acd..7d19c7a2c19 100644 --- a/mlos_bench/mlos_bench/tests/config/schedulers/test_load_scheduler_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/schedulers/test_load_scheduler_config_examples.py @@ -14,7 +14,7 @@ from mlos_bench.schedulers.base_scheduler import Scheduler from mlos_bench.schedulers.trial_runner import TrialRunner from mlos_bench.services.config_persistence import ConfigPersistenceService -from mlos_bench.storage.sql.storage import SqlStorage +from mlos_bench.storage.base_storage import Storage from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples from mlos_bench.util import get_class_from_name @@ -58,7 +58,7 @@ def test_load_scheduler_config_examples( config_path: str, mock_env_config_path: str, trial_runners: list[TrialRunner], - storage: SqlStorage, + storage: Storage, mock_opt: MockOptimizer, ) -> None: """Tests loading a config example.""" diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index becae205033..59dfbff9d21 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -15,7 +15,7 @@ from pytest_docker.plugin import get_docker_services from mlos_bench.environments.mock_env import MockEnv -from mlos_bench.tests import SEED, tunable_groups_fixtures +from mlos_bench.tests import SEED, resolve_host_name, tunable_groups_fixtures from mlos_bench.tunables.tunable_groups import TunableGroups # pylint: disable=redefined-outer-name @@ -29,6 +29,22 @@ covariant_group = tunable_groups_fixtures.covariant_group +HOST_DOCKER_NAME = "host.docker.internal" + + +@pytest.fixture(scope="session") +def docker_hostname() -> str: + """Returns the local hostname to use to connect to the test ssh server.""" + if 
sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME): + # On Linux, if we're running in a docker container, we can use the + # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. + return HOST_DOCKER_NAME + # Docker (Desktop) for Windows (WSL2) uses a special networking magic + # to refer to the host machine as `localhost` when exposing ports. + # In all other cases, assume we're executing directly inside conda on the host. + return "127.0.0.1" # "localhost" + + @pytest.fixture def mock_env(tunable_groups: TunableGroups) -> MockEnv: """Test fixture for MockEnv.""" @@ -90,6 +106,7 @@ def docker_compose_file(pytestconfig: pytest.Config) -> list[str]: _ = pytestconfig # unused return [ os.path.join(os.path.dirname(__file__), "services", "remote", "ssh", "docker-compose.yml"), + os.path.join(os.path.dirname(__file__), "storage", "sql", "docker-compose.yml"), # Add additional configs as necessary here. ] diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py index 033fa16330e..64f1df273c8 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py @@ -17,8 +17,8 @@ def _format_str(zone_info: tzinfo | None) -> str: if zone_info is not None: - return "%Y-%m-%d %H:%M:%S %z" - return "%Y-%m-%d %H:%M:%S" + return "%Y-%m-%d %H:%M:%S.%f %z" + return "%Y-%m-%d %H:%M:%S.%f" # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @@ -34,7 +34,6 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: tzinfo | None) See Also: http://github.com/microsoft/MLOS/issues/501 """ ts1 = datetime.now(zone_info) - ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second ts2 = ts1 + timedelta(minutes=2) format_str = _format_str(zone_info) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index a654ed1f343..2bc789485af 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -16,8 +16,8 @@ def _format_str(zone_info: tzinfo | None) -> str: if zone_info is not None: - return "%Y-%m-%d %H:%M:%S %z" - return "%Y-%m-%d %H:%M:%S" + return "%Y-%m-%d %H:%M:%S.%f %z" + return "%Y-%m-%d %H:%M:%S.%f" # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @@ -25,7 +25,6 @@ def _format_str(zone_info: tzinfo | None) -> str: def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: tzinfo | None) -> None: """Produce benchmark and telemetry data in a local script and read it.""" ts1 = datetime.now(zone_info) - ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second ts2 = ts1 + timedelta(minutes=1) format_str = _format_str(zone_info) @@ -77,7 +76,6 @@ def test_local_env_telemetry_no_header( ) -> None: """Read the telemetry data with no header.""" ts1 = datetime.now(zone_info) - ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second ts2 = ts1 + timedelta(minutes=1) format_str = _format_str(zone_info) @@ -121,7 +119,6 @@ def test_local_env_telemetry_wrong_header( ) -> None: """Read the telemetry data with incorrect header.""" ts1 = datetime.now(zone_info) - ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second ts2 = ts1 + 
timedelta(minutes=1) format_str = _format_str(zone_info) @@ -150,7 +147,6 @@ def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None: """Fail when the telemetry data has wrong format.""" zone_info = UTC ts1 = datetime.now(zone_info) - ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second ts2 = ts1 + timedelta(minutes=1) format_str = _format_str(zone_info) diff --git a/mlos_bench/mlos_bench/tests/environments/remote/conftest.py b/mlos_bench/mlos_bench/tests/environments/remote/conftest.py index 257e37fa9e6..b8ea5a2a6b5 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/conftest.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/conftest.py @@ -8,4 +8,3 @@ # Expose some of those as local names so they can be picked up as fixtures by pytest. ssh_test_server = ssh_fixtures.ssh_test_server -ssh_test_server_hostname = ssh_fixtures.ssh_test_server_hostname diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/README.md b/mlos_bench/mlos_bench/tests/services/remote/ssh/README.md index 4fe4216ff3e..827372d58ca 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/README.md +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/README.md @@ -22,7 +22,7 @@ These are brought up as session fixtures under a unique (PID based) compose proj In the case of `pytest`, since the `SshService` base class implements a shared connection cache that we wish to test, and testing "rebooting" of servers (containers) is also necessary, but we want to avoid single threaded execution for tests, we start a third container only for testing reboots. -Additionally, since `scope="session"` fixtures are executed once per worker, which is excessive in our case, we use lockfiles (one of the only portal synchronization methods) to ensure that the usual `docker_services` fixture which starts and stops the containers is only executed once per test run and uses a shared compose instance. +Additionally, since `scope="session"` fixtures are executed once per worker, which is excessive in our case, we use lockfiles (one of the only portable synchronization methods) to ensure that the usual `docker_services` fixture which starts and stops the containers is only executed once per test run and uses a shared compose instance. ## See Also diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index a6244e3a7a6..de822fe2eb3 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -7,10 +7,6 @@ from dataclasses import dataclass from subprocess import run -from pytest_docker.plugin import Services as DockerServices - -from mlos_bench.tests import check_socket - # The SSH test server port and name. # See Also: docker-compose.yml SSH_TEST_SERVER_PORT = 2254 @@ -21,7 +17,13 @@ @dataclass class SshTestServerInfo: - """A data class for SshTestServerInfo.""" + """ + A data class for SshTestServerInfo. 
+ + See Also + -------- + mlos_bench.tests.storage.sql.SqlTestServerInfo + """ compose_project_name: str service_name: str @@ -70,12 +72,3 @@ def to_connect_params(self, uncached: bool = False) -> dict: "port": self.get_port(uncached), "username": self.username, } - - -def wait_docker_service_socket(docker_services: DockerServices, hostname: str, port: int) -> None: - """Wait until a docker service is ready.""" - docker_services.wait_until_responsive( - check=lambda: check_socket(hostname, port), - timeout=30.0, - pause=0.5, - ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py index 34006985af1..83e98c44a24 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py @@ -7,7 +7,6 @@ import mlos_bench.tests.services.remote.ssh.fixtures as ssh_fixtures # Expose some of those as local names so they can be picked up as fixtures by pytest. -ssh_test_server_hostname = ssh_fixtures.ssh_test_server_hostname ssh_test_server = ssh_fixtures.ssh_test_server alt_test_server = ssh_fixtures.alt_test_server reboot_test_server = ssh_fixtures.reboot_test_server diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 79938784a74..fda84e9f79f 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -9,7 +9,6 @@ """ import os -import sys import tempfile from collections.abc import Generator from subprocess import run @@ -19,36 +18,20 @@ from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService -from mlos_bench.tests import resolve_host_name +from mlos_bench.tests import wait_docker_service_socket from mlos_bench.tests.services.remote.ssh import ( ALT_TEST_SERVER_NAME, REBOOT_TEST_SERVER_NAME, SSH_TEST_SERVER_NAME, SshTestServerInfo, - wait_docker_service_socket, ) # pylint: disable=redefined-outer-name -HOST_DOCKER_NAME = "host.docker.internal" - - -@pytest.fixture(scope="session") -def ssh_test_server_hostname() -> str: - """Returns the local hostname to use to connect to the test ssh server.""" - if sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME): - # On Linux, if we're running in a docker container, we can use the - # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. - return HOST_DOCKER_NAME - # Docker (Desktop) for Windows (WSL2) uses a special networking magic - # to refer to the host machine as `localhost` when exposing ports. - # In all other cases, assume we're executing directly inside conda on the host. 
- return "localhost" - @pytest.fixture(scope="session") def ssh_test_server( - ssh_test_server_hostname: str, + docker_hostname: str, docker_compose_project_name: str, locked_docker_services: DockerServices, ) -> Generator[SshTestServerInfo]: @@ -66,12 +49,14 @@ def ssh_test_server( ssh_test_server_info = SshTestServerInfo( compose_project_name=docker_compose_project_name, service_name=SSH_TEST_SERVER_NAME, - hostname=ssh_test_server_hostname, + hostname=docker_hostname, username="root", id_rsa_path=id_rsa_file.name, ) wait_docker_service_socket( - locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port() + locked_docker_services, + ssh_test_server_info.hostname, + ssh_test_server_info.get_port(), ) id_rsa_src = f"/{ssh_test_server_info.username}/.ssh/id_rsa" docker_cp_cmd = ( @@ -116,7 +101,9 @@ def alt_test_server( id_rsa_path=ssh_test_server.id_rsa_path, ) wait_docker_service_socket( - locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port() + locked_docker_services, + alt_test_server_info.hostname, + alt_test_server_info.get_port(), ) return alt_test_server_info diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index 003a8e64339..1dce67a13d7 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -12,13 +12,12 @@ from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService from mlos_bench.services.remote.ssh.ssh_service import SshClient -from mlos_bench.tests import requires_docker +from mlos_bench.tests import requires_docker, wait_docker_service_socket from mlos_bench.tests.services.remote.ssh import ( ALT_TEST_SERVER_NAME, REBOOT_TEST_SERVER_NAME, SSH_TEST_SERVER_NAME, SshTestServerInfo, - wait_docker_service_socket, ) _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py index c510793fac1..290319c8841 100644 --- a/mlos_bench/mlos_bench/tests/storage/conftest.py +++ b/mlos_bench/mlos_bench/tests/storage/conftest.py @@ -11,8 +11,13 @@ # same. # Expose some of those as local names so they can be picked up as fixtures by pytest. 
-storage = sql_storage_fixtures.storage
+mysql_storage_info = sql_storage_fixtures.mysql_storage_info
+mysql_storage = sql_storage_fixtures.mysql_storage
+postgres_storage_info = sql_storage_fixtures.postgres_storage_info
+postgres_storage = sql_storage_fixtures.postgres_storage
 sqlite_storage = sql_storage_fixtures.sqlite_storage
+mem_storage = sql_storage_fixtures.mem_storage
+storage = sql_storage_fixtures.storage
 exp_storage = sql_storage_fixtures.exp_storage
 exp_no_tunables_storage = sql_storage_fixtures.exp_no_tunables_storage
 mixed_numerics_exp_storage = sql_storage_fixtures.mixed_numerics_exp_storage
diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py
index dc8baf489c7..b8ca6a64b9d 100644
--- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py
+++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py
@@ -30,7 +30,7 @@ def test_exp_data_root_env_config(
     """Tests the root_env_config property of ExperimentData."""
     # pylint: disable=protected-access
     assert exp_data.root_env_config == (
-        exp_storage._root_env_config,
+        exp_storage._rel_root_env_config,
         exp_storage._git_repo,
         exp_storage._git_commit,
     )
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/README.md b/mlos_bench/mlos_bench/tests/storage/sql/README.md
new file mode 100644
index 00000000000..c9b124eab33
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/storage/sql/README.md
@@ -0,0 +1,27 @@
+# Sql Storage Tests
+
+The "unit" tests for the `SqlStorage` classes are more functional than other unit tests in that we don't merely mock them out, but actually set up small SQL databases with `docker compose` and interact with them using the `SqlStorage` class.
+
+To do this, we make use of the `pytest-docker` plugin to bring up the services defined in the [`docker-compose.yml`](./docker-compose.yml) file in this directory.
+
+There are currently two services defined in that config, though others could be added in the future:
+
+1. `mysql-mlos-bench-server`
+1. `postgres-mlos-bench-server`
+
+We rely on `docker compose` to map their internal container service ports to random ports on the host.
+Hence, when connecting, we need to look up these ports on demand using something akin to `docker compose port` (see the example below).
+Because of the complexities of networking in different development environments (especially for Docker on WSL2 for Windows), we may also have to connect to a different host address than `localhost` (e.g., `host.docker.internal`, which is dynamically requested as part of the [devcontainer](../../../../../../.devcontainer/docker-compose.yml) setup).
+
+These containers are brought up as session fixtures under a unique (PID based) compose project name for each `pytest` invocation, but only when docker is detected on the host (via the `requires_docker` decorator we define in [`mlos_bench/tests/__init__.py`](../../../__init__.py)), else those tests are skipped.
+
+> For manual testing, to bring up/down the test infrastructure the [`up.sh`](./up.sh) and [`down.sh`](./down.sh) scripts can be used, which assign a known project name.
+
+In the case of `pytest`, we also want to be able to test with a fresh state in most cases, so we use the `pytest` `yield` pattern to allow schema cleanup code to happen after the end of each test (see: `_create_storage_from_test_server_info`).
+We use lockfiles to prevent races between tests that would otherwise try to create or clean up the same database schema at the same time. 
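+
+For example, assuming the manual test project from [`up.sh`](./up.sh) is running, the host port mapped for the MySQL service can be looked up with:
+
+```sh
+docker compose -p mlos_bench-test-sql-storage-manual port mysql-mlos-bench-server 3306
+```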
+
+Additionally, since `scope="session"` fixtures are executed once per worker, which is excessive in our case, we use lockfiles (one of the only portable synchronization methods) to ensure that the usual `docker_services` fixture which starts and stops the containers is only executed once per test run and uses a shared compose instance.
+
+## See Also
+
+Notes in the [`mlos_bench/tests/services/remote/ssh/README.md`](../../../services/remote/ssh/README.md) file for a similar setup for SSH services.
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/__init__.py b/mlos_bench/mlos_bench/tests/storage/sql/__init__.py
index d17a448b5e3..d1a7c3c800b 100644
--- a/mlos_bench/mlos_bench/tests/storage/sql/__init__.py
+++ b/mlos_bench/mlos_bench/tests/storage/sql/__init__.py
@@ -3,3 +3,75 @@
 # Licensed under the MIT License.
 #
 """Tests for mlos_bench sql storage."""
+
+from dataclasses import dataclass
+from subprocess import run
+
+# The DB servers' names and other connection info.
+# See Also: docker-compose.yml
+
+MYSQL_TEST_SERVER_NAME = "mysql-mlos-bench-server"
+PGSQL_TEST_SERVER_NAME = "postgres-mlos-bench-server"
+
+SQL_TEST_SERVER_DATABASE = "mlos_bench"
+SQL_TEST_SERVER_PASSWORD = "password"
+
+
+@dataclass
+class SqlTestServerInfo:
+    """
+    A data class for SqlTestServerInfo.
+
+    See Also
+    --------
+    mlos_bench.tests.services.remote.ssh.SshTestServerInfo
+    """
+
+    compose_project_name: str
+    service_name: str
+    hostname: str
+    _port: int | None = None
+
+    @property
+    def username(self) -> str:
+        """Gets the username."""
+        usernames = {
+            MYSQL_TEST_SERVER_NAME: "root",
+            PGSQL_TEST_SERVER_NAME: "postgres",
+        }
+        return usernames[self.service_name]
+
+    @property
+    def password(self) -> str:
+        """Gets the password."""
+        return SQL_TEST_SERVER_PASSWORD
+
+    @property
+    def database(self) -> str:
+        """Gets the database."""
+        return SQL_TEST_SERVER_DATABASE
+
+    def get_port(self, uncached: bool = False) -> int:
+        """
+        Gets the port that the SQL test server is listening on.
+
+        Note: this value can change when the service restarts so we can't rely on
+        the DockerServices.
+        """
+        if self._port is None or uncached:
+            default_ports = {
+                MYSQL_TEST_SERVER_NAME: 3306,
+                PGSQL_TEST_SERVER_NAME: 5432,
+            }
+            default_port = default_ports[self.service_name]
+            port_cmd = run(
+                (
+                    f"docker compose -p {self.compose_project_name} "
+                    f"port {self.service_name} {default_port}"
+                ),
+                shell=True,
+                check=True,
+                capture_output=True,
+            )
+            self._port = int(port_cmd.stdout.decode().strip().split(":")[1])
+        return self._port
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/docker-compose.yml b/mlos_bench/mlos_bench/tests/storage/sql/docker-compose.yml
new file mode 100644
index 00000000000..0bfd0bce819
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/storage/sql/docker-compose.yml
@@ -0,0 +1,40 @@
+name: mlos_bench-test-sql-storage
+services:
+  mysql-mlos-bench-server:
+    hostname: mysql-mlos-bench-server
+    attach: false
+    image: docker.io/library/mysql:latest
+    ports:
+      # To allow multiple instances of this to coexist, instead of explicitly
+      # mapping the port, let it get assigned randomly on the host. 
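+      # (Look up the assigned host port afterwards with `docker compose port`.)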
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/docker-compose.yml b/mlos_bench/mlos_bench/tests/storage/sql/docker-compose.yml
new file mode 100644
index 00000000000..0bfd0bce819
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/storage/sql/docker-compose.yml
@@ -0,0 +1,40 @@
+name: mlos_bench-test-sql-storage
+services:
+  mysql-mlos-bench-server:
+    hostname: mysql-mlos-bench-server
+    attach: false
+    image: docker.io/library/mysql:latest
+    ports:
+      # To allow multiple instances of this to coexist, instead of explicitly
+      # mapping the port, let it get assigned randomly on the host.
+      - ${PORT:-3306}
+    extra_hosts:
+      - host.docker.internal:host-gateway
+    environment:
+      - MYSQL_ROOT_PASSWORD=password
+      - MYSQL_DATABASE=mlos_bench
+    healthcheck:
+      test: ["CMD-SHELL", "mysqladmin --host localhost --protocol tcp --password=$${MYSQL_ROOT_PASSWORD} ping"]
+      interval: 10s
+      timeout: 30s
+      retries: 5
+      start_period: 5s
+  postgres-mlos-bench-server:
+    hostname: postgres-mlos-bench-server
+    attach: false
+    image: docker.io/library/postgres:latest
+    ports:
+      # To allow multiple instances of this to coexist, instead of explicitly
+      # mapping the port, let it get assigned randomly on the host.
+      - ${PORT:-5432}
+    extra_hosts:
+      - host.docker.internal:host-gateway
+    environment:
+      - POSTGRES_PASSWORD=password
+      - POSTGRES_DB=mlos_bench
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -d $${POSTGRES_DB}"]
+      interval: 10s
+      timeout: 30s
+      retries: 5
+      start_period: 5s
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/down.sh b/mlos_bench/mlos_bench/tests/storage/sql/down.sh
new file mode 100755
index 00000000000..3d6068cecf7
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/storage/sql/down.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+##
+## Copyright (c) Microsoft Corporation.
+## Licensed under the MIT License.
+##
+
+# A script to stop the containerized SQL DBMS servers.
+# For pytest, the fixture in conftest.py will handle this for us using the
+# pytest-docker plugin, but for manual testing, this script can be used.
+
+set -eu
+
+scriptdir=$(dirname "$(readlink -f "$0")")
+cd "$scriptdir"
+
+PROJECT_NAME="mlos_bench-test-sql-storage-manual"
+
+docker compose -p "$PROJECT_NAME" down
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
index 0bebeeff824..9f5e203ad6f 100644
--- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
+++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py
@@ -8,9 +8,15 @@
 import os
 import tempfile
 from collections.abc import Generator
+from contextlib import contextmanager
+from importlib.resources import files
 from random import seed as rand_seed
 
 import pytest
+from fasteners import InterProcessLock
+from pytest import FixtureRequest
+from pytest_docker.plugin import Services as DockerServices
+from pytest_lazy_fixtures.lazy_fixture import lf as lazy_fixture
 
 from mlos_bench.optimizers.mock_optimizer import MockOptimizer
 from mlos_bench.schedulers.sync_scheduler import SyncScheduler
@@ -19,16 +25,162 @@
 from mlos_bench.storage.base_experiment_data import ExperimentData
 from mlos_bench.storage.sql.storage import SqlStorage
 from mlos_bench.storage.storage_factory import from_config
-from mlos_bench.tests import SEED
+from mlos_bench.tests import DOCKER, SEED, wait_docker_service_healthy
 from mlos_bench.tests.storage import (
     CONFIG_TRIAL_REPEAT_COUNT,
     MAX_TRIALS,
     TRIAL_RUNNER_COUNT,
 )
+from mlos_bench.tests.storage.sql import (
+    MYSQL_TEST_SERVER_NAME,
+    PGSQL_TEST_SERVER_NAME,
+    SqlTestServerInfo,
+)
 from mlos_bench.tunables.tunable_groups import TunableGroups
+from mlos_bench.util import path_join
 
 # pylint: disable=redefined-outer-name
 
+DOCKER_DBMS_FIXTURES = []
+if DOCKER:
+    DOCKER_DBMS_FIXTURES = [
+        lazy_fixture("mysql_storage"),
+        lazy_fixture("postgres_storage"),
+    ]
+
+PERSISTENT_SQL_STORAGE_FIXTURES = [lazy_fixture("sqlite_storage")]
+if DOCKER:
+    PERSISTENT_SQL_STORAGE_FIXTURES.extend(DOCKER_DBMS_FIXTURES)
+
+
+@pytest.fixture(scope="session")
+def mysql_storage_info(
+    docker_hostname: str,
+    docker_compose_project_name: str,
+    locked_docker_services: DockerServices,
+) -> SqlTestServerInfo:
+    """Fixture for getting mysql storage connection info."""
+    storage_info = SqlTestServerInfo(
+        compose_project_name=docker_compose_project_name,
+        service_name=MYSQL_TEST_SERVER_NAME,
+        hostname=docker_hostname,
+    )
+    wait_docker_service_healthy(
+        locked_docker_services,
+        storage_info.compose_project_name,
+        storage_info.service_name,
+    )
+
+    return storage_info
+
+
+@pytest.fixture(scope="session")
+def postgres_storage_info(
+    docker_hostname: str,
+    docker_compose_project_name: str,
+    locked_docker_services: DockerServices,
+) -> SqlTestServerInfo:
+    """Fixture for getting postgres storage connection info."""
+    storage_info = SqlTestServerInfo(
+        compose_project_name=docker_compose_project_name,
+        service_name=PGSQL_TEST_SERVER_NAME,
+        hostname=docker_hostname,
+    )
+    wait_docker_service_healthy(
+        locked_docker_services,
+        storage_info.compose_project_name,
+        storage_info.service_name,
+    )
+    return storage_info
+
+
+@contextmanager
+def _create_storage_from_test_server_info(
+    config_file: str,
+    test_server_info: SqlTestServerInfo,
+    shared_temp_dir: str,
+    short_testrun_uid: str,
+) -> Generator[SqlStorage]:
+    """
+    Creates a SqlStorage instance from the given test server info.
+
+    Notes
+    -----
+    Resets the schema as a cleanup operation on return from the function-scoped
+    fixture so each test gets a fresh storage instance.
+    Uses a file lock to ensure that only one test can access the storage at a time.
+
+    Yields
+    ------
+    SqlStorage
+    """
+    sql_storage_name = test_server_info.service_name
+    with InterProcessLock(
+        path_join(shared_temp_dir, f"{sql_storage_name}-{short_testrun_uid}.lock")
+    ):
+        global_config = {
+            "host": test_server_info.hostname,
+            "port": test_server_info.get_port() or 0,
+            "database": test_server_info.database,
+            "username": test_server_info.username,
+            "password": test_server_info.password,
+            "lazy_schema_create": True,
+        }
+        storage = from_config(
+            config_file,
+            global_configs=[json.dumps(global_config)],
+        )
+        assert isinstance(storage, SqlStorage)
+        try:
+            yield storage
+        finally:
+            # Clean up the storage on return.
+            storage._reset_schema(force=True)  # pylint: disable=protected-access
+
+
+@pytest.fixture(scope="function")
+def mysql_storage(
+    mysql_storage_info: SqlTestServerInfo,
+    shared_temp_dir: str,
+    short_testrun_uid: str,
+) -> Generator[SqlStorage]:
+    """
+    Fixture of a MySQL-backed SqlStorage engine.
+
+    See Also
+    --------
+    _create_storage_from_test_server_info
+    """
+    with _create_storage_from_test_server_info(
+        path_join(str(files("mlos_bench.config")), "storage", "mysql.jsonc"),
+        mysql_storage_info,
+        shared_temp_dir,
+        short_testrun_uid,
+    ) as storage:
+        yield storage
+
+
+@pytest.fixture(scope="function")
+def postgres_storage(
+    postgres_storage_info: SqlTestServerInfo,
+    shared_temp_dir: str,
+    short_testrun_uid: str,
+) -> Generator[SqlStorage]:
+    """
+    Fixture of a PostgreSQL-backed SqlStorage engine.
+
+    See Also
+    --------
+    _create_storage_from_test_server_info
+    """
+    with _create_storage_from_test_server_info(
+        path_join(str(files("mlos_bench.config")), "storage", "postgresql.jsonc"),
+        postgres_storage_info,
+        shared_temp_dir,
+        short_testrun_uid,
+    ) as storage:
+        yield storage
+
 
 @pytest.fixture
 def sqlite_storage() -> Generator[SqlStorage]:
@@ -63,7 +215,7 @@ def sqlite_storage() -> Generator[SqlStorage]:
 
 
 @pytest.fixture
-def storage() -> SqlStorage:
+def mem_storage() -> SqlStorage:
     """Test fixture for in-memory SQLite3 storage."""
     return SqlStorage(
         service=None,
@@ -75,6 +227,19 @@
     )
 
 
+@pytest.fixture(
+    params=[
+        lazy_fixture("mem_storage"),
+        *DOCKER_DBMS_FIXTURES,
+    ]
+)
+def storage(request: FixtureRequest) -> SqlStorage:
+    """Returns a SqlStorage fixture: either in-memory or a dockerized DBMS."""
+    sql_storage = request.param
+    assert isinstance(sql_storage, SqlStorage)
+    return sql_storage
+
+
 @pytest.fixture
 def exp_storage(
     storage: SqlStorage,
@@ -88,7 +253,7 @@ def exp_storage(
     with storage.experiment(
         experiment_id="Test-001",
         trial_id=1,
-        root_env_config="environment.jsonc",
+        root_env_config="my-environment.jsonc",
         description="pytest experiment",
         tunables=tunable_groups,
         opt_targets={"score": "min"},
@@ -222,7 +387,7 @@ def _dummy_run_exp(
         trial_runners=trial_runners,
         optimizer=opt,
         storage=storage,
-        root_env_config=exp.root_env_config,
+        root_env_config=exp.abs_root_env_config,
     )
 
     # Add some trial data to that experiment by "running" it.
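Any test that takes the parametrized `storage` fixture above then runs once per backend. A hypothetical example (the test name and assertion are illustrative, not part of this change):

```python
from mlos_bench.storage.base_storage import Storage


def test_storage_smoke(storage: Storage) -> None:
    """Hypothetical: runs once per backend (in-memory, plus each dockerized DBMS)."""
    # A fresh schema should start with no experiments registered.
    assert not storage.experiments
```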
diff --git a/mlos_bench/mlos_bench/tests/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/storage/sql/test_storage_schemas.py
similarity index 67%
rename from mlos_bench/mlos_bench/tests/storage/test_storage_schemas.py
rename to mlos_bench/mlos_bench/tests/storage/sql/test_storage_schemas.py
index 8a6c36e6bb3..bf7c4dabee3 100644
--- a/mlos_bench/mlos_bench/tests/storage/test_storage_schemas.py
+++ b/mlos_bench/mlos_bench/tests/storage/sql/test_storage_schemas.py
@@ -4,20 +4,33 @@
 #
 """Test sql schemas for mlos_bench storage."""
 
+import pytest
 from alembic.migration import MigrationContext
+from pytest_lazy_fixtures.lazy_fixture import lf as lazy_fixture
 from sqlalchemy import inspect
 
 from mlos_bench.storage.sql.storage import SqlStorage
+from mlos_bench.tests.storage.sql.fixtures import DOCKER_DBMS_FIXTURES
 
 # NOTE: This value is hardcoded to the latest revision in the alembic versions directory.
 # It could also be obtained programmatically using the "alembic heads" command or heads() API.
 # See Also: schema.py for an example of programmatic alembic config access.
-CURRENT_ALEMBIC_HEAD = "8928a401115b"
+CURRENT_ALEMBIC_HEAD = "b61aa446e724"
 
 
-def test_storage_schemas(storage: SqlStorage) -> None:
+# Try to test multiple DBMS engines.
+@pytest.mark.parametrize(
+    "some_sql_storage_fixture",
+    [
+        lazy_fixture("mem_storage"),
+        lazy_fixture("sqlite_storage"),
+        *DOCKER_DBMS_FIXTURES,
+    ],
+)
+def test_storage_schemas(some_sql_storage_fixture: SqlStorage) -> None:
     """Test storage schema creation."""
-    eng = storage._engine  # pylint: disable=protected-access
+    assert isinstance(some_sql_storage_fixture, SqlStorage)
+    eng = some_sql_storage_fixture._engine  # pylint: disable=protected-access
     with eng.connect() as conn:  # pylint: disable=protected-access
         inspector = inspect(conn)
         # Make sure the "trial_runner_id" column exists.
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/up.sh b/mlos_bench/mlos_bench/tests/storage/sql/up.sh
new file mode 100755
index 00000000000..36fea613b9a
--- /dev/null
+++ b/mlos_bench/mlos_bench/tests/storage/sql/up.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+##
+## Copyright (c) Microsoft Corporation.
+## Licensed under the MIT License.
+##
+
+# A script to start the containerized SQL DBMS servers.
+# For pytest, the fixture in conftest.py will handle this for us using the
+# pytest-docker plugin, but for manual testing, this script can be used.
+
+set -eu
+set -x
+
+scriptdir=$(dirname "$(readlink -f "$0")")
+cd "$scriptdir"
+
+PROJECT_NAME="mlos_bench-test-sql-storage-manual"
+CONTAINER_COUNT=2
+
+docker compose -p "$PROJECT_NAME" up --build --remove-orphans -d
+set +x
+
+function get_project_health() {
+    docker compose -p "$PROJECT_NAME" ps --format '{{.Name}} {{.State}} {{.Health}}'
+}
+
+project_health=$(get_project_health)
+while ! echo "$project_health" | grep -c ' running healthy$' | grep -q -x "$CONTAINER_COUNT"; do
+    echo "Waiting for $CONTAINER_COUNT containers to report healthy ..."
+    echo "$project_health"
+    sleep 1
+    project_health=$(get_project_health)
+done
+
+mysql_port=$(docker compose -p "$PROJECT_NAME" port mysql-mlos-bench-server "${PORT:-3306}" | cut -d: -f2)
+echo "Connect to the mysql server container at the following port: $mysql_port"
+postgres_port=$(docker compose -p "$PROJECT_NAME" port postgres-mlos-bench-server "${PORT:-5432}" | cut -d: -f2)
+echo "Connect to the postgres server container at the following port: $postgres_port"
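For a manual session against these servers, something like the following works (illustrative; the `<mysql_port>` placeholder comes from the `up.sh` output above, and the password is `password` per the compose config):

```sh
cd mlos_bench/mlos_bench/tests/storage/sql
./up.sh    # prints the mapped host ports once the containers are healthy
mysql --host=127.0.0.1 --port=<mysql_port> --user=root --password mlos_bench
./down.sh
```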
""" - storage = sqlite_storage + storage = persistent_storage + storage_class = storage.__class__ + assert issubclass(storage_class, Storage) + assert storage_class != Storage # Create an Experiment and a Trial opt_targets: dict[str, Literal["min", "max"]] = {"metric": "min"} experiment = storage.experiment( @@ -49,7 +60,7 @@ def test_storage_pickle_restore_experiment_and_trial( # Pickle and unpickle the Storage object pickled = pickle.dumps(storage) restored_storage = pickle.loads(pickled) - assert isinstance(restored_storage, SqlStorage) + assert isinstance(restored_storage, storage_class) # Restore the Experiment from storage by id and check that it matches the original restored_experiment = restored_storage.get_experiment_by_id( @@ -61,7 +72,7 @@ def test_storage_pickle_restore_experiment_and_trial( assert restored_experiment is not experiment assert restored_experiment.experiment_id == experiment.experiment_id assert restored_experiment.description == experiment.description - assert restored_experiment.root_env_config == experiment.root_env_config + assert restored_experiment.rel_root_env_config == experiment.rel_root_env_config assert restored_experiment.tunables == experiment.tunables assert restored_experiment.opt_targets == experiment.opt_targets with restored_experiment: diff --git a/mlos_bench/mlos_bench/tests/util_git_test.py b/mlos_bench/mlos_bench/tests/util_git_test.py index 77fd2779c77..788c06a5ead 100644 --- a/mlos_bench/mlos_bench/tests/util_git_test.py +++ b/mlos_bench/mlos_bench/tests/util_git_test.py @@ -3,14 +3,30 @@ # Licensed under the MIT License. # """Unit tests for get_git_info utility function.""" +import os import re -from mlos_bench.util import get_git_info +from mlos_bench.util import get_git_info, path_join def test_get_git_info() -> None: - """Check that we can retrieve git info about the current repository correctly.""" - (git_repo, git_commit, rel_path) = get_git_info(__file__) + """Check that we can retrieve git info about the current repository correctly from a + file. + """ + (git_repo, git_commit, rel_path, abs_path) = get_git_info(__file__) assert "mlos" in git_repo.lower() assert re.match(r"[0-9a-f]{40}", git_commit) is not None assert rel_path == "mlos_bench/mlos_bench/tests/util_git_test.py" + assert abs_path == path_join(__file__, abs_path=True) + + +def test_get_git_info_dir() -> None: + """Check that we can retrieve git info about the current repository correctly from a + directory. + """ + dirname = os.path.dirname(__file__) + (git_repo, git_commit, rel_path, abs_path) = get_git_info(dirname) + assert "mlos" in git_repo.lower() + assert re.match(r"[0-9a-f]{40}", git_commit) is not None + assert rel_path == "mlos_bench/mlos_bench/tests" + assert abs_path == path_join(dirname, abs_path=True) diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index 1c45cd4ecf9..0836f5f0486 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -274,7 +274,32 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s ) -def get_git_info(path: str = __file__) -> tuple[str, str, str]: +def get_git_root(path: str = __file__) -> str: + """ + Get the root dir of the git repository. + + Parameters + ---------- + path : str, optional + Path to the file in git repository. 
+
+    Returns
+    -------
+    git_root : str
+        Absolute path of the root directory of the git repository.
+    """
+    abspath = path_join(path, abs_path=True)
+    if not os.path.exists(abspath) or not os.path.isdir(abspath):
+        dirname = os.path.dirname(abspath)
+    else:
+        dirname = abspath
+    git_root = subprocess.check_output(
+        ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True
+    ).strip()
+    return git_root
+
+
+def get_git_info(path: str = __file__) -> tuple[str, str, str, str]:
     """
     Get the git repository, commit hash, and local path of the given file.
 
@@ -286,9 +311,13 @@ def get_git_info(path: str = __file__) -> tuple[str, str, str]:
 
     Returns
     -------
-    (git_repo, git_commit, git_path) : tuple[str, str, str]
-        Git repository URL, last commit hash, and relative file path.
+    (git_repo, git_commit, rel_path, abs_path) : tuple[str, str, str, str]
+        Git repository URL, last commit hash, relative file path, and current
+        absolute file path.
     """
-    dirname = os.path.dirname(path)
+    abspath = path_join(path, abs_path=True)
+    if not os.path.exists(abspath) or not os.path.isdir(abspath):
+        dirname = os.path.dirname(abspath)
+    else:
+        dirname = abspath
     git_repo = subprocess.check_output(
         ["git", "-C", dirname, "remote", "get-url", "origin"], text=True
     ).strip()
@@ -299,8 +328,8 @@ def get_git_info(path: str = __file__) -> tuple[str, str, str]:
         ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True
     ).strip()
     _LOG.debug("Current git branch: %s %s", git_repo, git_commit)
-    rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root))
-    return (git_repo, git_commit, rel_path.replace("\\", "/"))
+    rel_path = os.path.relpath(abspath, os.path.abspath(git_root))
+    return (git_repo, git_commit, rel_path.replace("\\", "/"), abspath)
 
 
 # Note: to avoid circular imports, we don't specify TunableValue here.
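To close, a small usage sketch of the updated helpers (illustrative only; it assumes it is run from inside a git checkout of this repository):

```python
from mlos_bench.util import get_git_info, get_git_root

# get_git_info() now also returns the absolute path alongside the repo-relative one.
(git_repo, git_commit, rel_path, abs_path) = get_git_info(__file__)
print(f"{git_repo} @ {git_commit}")
print(f"{rel_path} (relative to {get_git_root(__file__)})")
print(abs_path)
```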