diff --git a/.cursorignore b/.cursorignore new file mode 100644 index 00000000..03a6b720 --- /dev/null +++ b/.cursorignore @@ -0,0 +1,2 @@ +config/database_credentials.yml +config/certificates \ No newline at end of file diff --git a/.gitignore b/.gitignore index 61e8c437..33d1e323 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,7 @@ powermetrics_log.txt # development folder development/ +# MTLS certificates +config/certificates/ca.crt +config/certificates/user_cert.crt +config/certificates/user_private_key.key diff --git a/config/example_database_credentials.yml b/config/example_database_credentials.yml index df6c497b..b4f4814d 100644 --- a/config/example_database_credentials.yml +++ b/config/example_database_credentials.yml @@ -5,6 +5,11 @@ CREDENTIALS: Connection: Standard: server: example.mysqlserver.com + use_ssl: some_boolean_value + ssl_params: + ca: config/certificates/ca.crt + cert: config/certificates/user_cert.crt + key: config/certificates/user_private_key.key Ssh: server: example.sshmysqlserver.com (address from ssh server) address: example.sslserver.com diff --git a/docs/source/usage/database_credential_file.rst b/docs/source/usage/database_credential_file.rst index 6ee3bab6..610faa84 100644 --- a/docs/source/usage/database_credential_file.rst +++ b/docs/source/usage/database_credential_file.rst @@ -18,7 +18,29 @@ Below is an example of a database credential file, that connects to a server wit Standard: server: example.mysqlserver.com -However, for security reasons, databases might only be accessible from a specific IP address. In these cases, one can use an ssh jumphost. This means that ``PyExperimenter`` will first connect to the ssh server +We additionally also support utilizing encrypted connections with (m)tls. To that end, the following parameters can be added to the ``Standard`` section of the database credential file +- ``use_ssl``: a boolean value indicating whether to use ssl +- ``ssl_params``: a dictionary containing the following keys: + - ``ca``: the path to the ca certificate (optional, needed if the database server uses a custom ca certificate not trusted by the client) + - ``cert``: the path to the user certificate (optional, needed in case of client authentication with mtls) + - ``key``: the path to the user private key (optional, needed in case of client authentication with mtls) + +.. code-block:: yaml + + CREDENTIALS: + Database: + user: example_user + password: example_password + Connection: + Standard: + server: example.mysqlserver.com + use_ssl: some_boolean_value + ssl_params: + ca: config/certificates/ca.crt + cert: config/certificates/user_cert.crt + key: config/certificates/user_private_key.key + +Alternatively, for security reasons, databases might only be accessible from a specific IP address. In these cases, one can use an ssh jumphost. This means that ``PyExperimenter`` will first connect to the ssh server that has access to the database and then connect to the database server from there. This is done by adding an additional ``Ssh`` section to the database credential file, and can be activated either by a ``PyExperimenter`` keyword argument or in the :ref:`experimenter configuration file `. The following example shows how to connect to a database server using an SSH server with the address ``ssh_hostname`` and the port ``optional_ssh_port``. diff --git a/py_experimenter/database_connector_mysql.py b/py_experimenter/database_connector_mysql.py index 58a5ca7f..ff68b0c7 100644 --- a/py_experimenter/database_connector_mysql.py +++ b/py_experimenter/database_connector_mysql.py @@ -1,6 +1,6 @@ import logging from logging import Logger -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import sshtunnel @@ -9,7 +9,11 @@ from py_experimenter.config import DatabaseCfg from py_experimenter.database_connector import DatabaseConnector -from py_experimenter.exceptions import DatabaseConnectionError, DatabaseCreationError, SshTunnelError +from py_experimenter.exceptions import ( + DatabaseConnectionError, + DatabaseCreationError, + SshTunnelError, +) class DatabaseConnectorMYSQL(DatabaseConnector): @@ -23,7 +27,8 @@ def __init__(self, database_configuration: DatabaseCfg, use_codecarbon: bool, cr def get_ssh_tunnel(self, logger: Logger): try: - credentials = OmegaConf.load(self.credential_path)["CREDENTIALS"]["Connection"] + credential_config = dict(OmegaConf.load(self.credential_path)) + credentials = credential_config["CREDENTIALS"]["Connection"] if "Ssh" in credentials: parameters = dict(credentials["Ssh"]) ssh_address_or_host = parameters["address"] @@ -107,12 +112,30 @@ def close_connection(self, connection): def _get_database_credentials(self): try: - credential_config = OmegaConf.load(self.credential_path) + credential_config = dict(OmegaConf.load(self.credential_path)) database_configuration = credential_config["CREDENTIALS"]["Database"] + connection_configuration = credential_config["CREDENTIALS"]["Connection"] if self.database_configuration.use_ssh_tunnel: - server_address = credential_config["CREDENTIALS"]["Connection"]["Ssh"]["server"] + server_address = connection_configuration["Ssh"]["server"] + ssl_params = None + + else: - server_address = credential_config["CREDENTIALS"]["Connection"]["Standard"]["server"] + server_address = connection_configuration["Standard"]["server"] + if "use_ssl" in connection_configuration["Standard"]: + if connection_configuration["Standard"]["use_ssl"]: + ssl_params = dict() + if "ca" in connection_configuration["Standard"]["ssl_params"]: + ssl_params["ca"] = connection_configuration["Standard"]["ssl_params"]["ca"] + if "cert" in connection_configuration["Standard"]["ssl_params"]: + ssl_params["cert"] = connection_configuration["Standard"]["ssl_params"]["cert"] + if "key" in connection_configuration["Standard"]["ssl_params"]: + ssl_params["key"] = connection_configuration["Standard"]["ssl_params"]["key"] + else: + ssl_params = None + else: + ssl_params = None + credentials = { "host": server_address, "user": database_configuration["user"], @@ -122,6 +145,7 @@ def _get_database_credentials(self): return { **credentials, "database": self.database_configuration.database_name, + "ssl": ssl_params, } except Exception as err: logging.error(err) @@ -131,7 +155,7 @@ def _start_transaction(self, connection, readonly=False): if not readonly: connection.begin() - def _table_exists(self, cursor, table_name: str = None) -> bool: + def _table_exists(self, cursor, table_name: Optional[str] = None) -> bool: table_name = table_name if table_name is not None else self.database_configuration.table_name self.execute(cursor, f"SHOW TABLES LIKE '{table_name}'") return self.fetchall(cursor) diff --git a/py_experimenter/experimenter.py b/py_experimenter/experimenter.py index 741da061..a2df77ac 100644 --- a/py_experimenter/experimenter.py +++ b/py_experimenter/experimenter.py @@ -27,8 +27,8 @@ def __init__( experiment_configuration_file_path: str = os.path.join("config", "experiment_configuration.yml"), database_credential_file_path: str = os.path.join("config", "database_credentials.yml"), use_ssh_tunnel: Optional[bool] = None, - table_name: str = None, - database_name: str = None, + table_name: Optional[str] = None, + database_name: Optional[str] = None, use_codecarbon: bool = True, name="PyExperimenter", logger_name: str = "py-experimenter", diff --git a/test/test_codecarbon/test_integration_mysql.py b/test/test_codecarbon/test_integration_mysql.py index 1adc263d..e25d288e 100644 --- a/test/test_codecarbon/test_integration_mysql.py +++ b/test/test_codecarbon/test_integration_mysql.py @@ -12,7 +12,7 @@ def experimenter(): configuration_path = os.path.join("test", "test_codecarbon", "configs", "integration_test_mysql.yml") - return PyExperimenter(configuration_path, use_ssh_tunnel=True) + return PyExperimenter(configuration_path, use_ssh_tunnel=False) def run_ml(parameters: dict, result_processor: ResultProcessor, custom_config: dict):