From 0ef424d0520b78df19f5fb0d71dd0c6bb34510a2 Mon Sep 17 00:00:00 2001 From: Lukas Fehring Date: Thu, 10 Jul 2025 10:31:47 +0200 Subject: [PATCH 1/4] First version of using mtls --- .cursorignore | 2 ++ .gitignore | 4 +++ config/example_database_credentials.yml | 5 +++ py_experimenter/database_connector_mysql.py | 34 ++++++++++++++++----- py_experimenter/experimenter.py | 4 +-- 5 files changed, 40 insertions(+), 9 deletions(-) create mode 100644 .cursorignore diff --git a/.cursorignore b/.cursorignore new file mode 100644 index 00000000..03a6b720 --- /dev/null +++ b/.cursorignore @@ -0,0 +1,2 @@ +config/database_credentials.yml +config/certificates \ No newline at end of file diff --git a/.gitignore b/.gitignore index 61e8c437..33d1e323 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,7 @@ powermetrics_log.txt # development folder development/ +# MTLS certificates +config/certificates/ca.crt +config/certificates/user_cert.crt +config/certificates/user_private_key.key diff --git a/config/example_database_credentials.yml b/config/example_database_credentials.yml index df6c497b..b4f4814d 100644 --- a/config/example_database_credentials.yml +++ b/config/example_database_credentials.yml @@ -5,6 +5,11 @@ CREDENTIALS: Connection: Standard: server: example.mysqlserver.com + use_ssl: some_boolean_value + ssl_params: + ca: config/certificates/ca.crt + cert: config/certificates/user_cert.crt + key: config/certificates/user_private_key.key Ssh: server: example.sshmysqlserver.com (address from ssh server) address: example.sslserver.com diff --git a/py_experimenter/database_connector_mysql.py b/py_experimenter/database_connector_mysql.py index 58a5ca7f..d2fe3eed 100644 --- a/py_experimenter/database_connector_mysql.py +++ b/py_experimenter/database_connector_mysql.py @@ -1,6 +1,6 @@ import logging from logging import Logger -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import sshtunnel @@ -9,7 +9,11 @@ from 
py_experimenter.config import DatabaseCfg from py_experimenter.database_connector import DatabaseConnector -from py_experimenter.exceptions import DatabaseConnectionError, DatabaseCreationError, SshTunnelError +from py_experimenter.exceptions import ( + DatabaseConnectionError, + DatabaseCreationError, + SshTunnelError, +) class DatabaseConnectorMYSQL(DatabaseConnector): @@ -23,7 +27,8 @@ def __init__(self, database_configuration: DatabaseCfg, use_codecarbon: bool, cr def get_ssh_tunnel(self, logger: Logger): try: - credentials = OmegaConf.load(self.credential_path)["CREDENTIALS"]["Connection"] + credential_config = dict(OmegaConf.load(self.credential_path)) + credentials = credential_config["CREDENTIALS"]["Connection"] if "Ssh" in credentials: parameters = dict(credentials["Ssh"]) ssh_address_or_host = parameters["address"] @@ -107,12 +112,26 @@ def close_connection(self, connection): def _get_database_credentials(self): try: - credential_config = OmegaConf.load(self.credential_path) + credential_config = dict(OmegaConf.load(self.credential_path)) database_configuration = credential_config["CREDENTIALS"]["Database"] + connection_configuration = credential_config["CREDENTIALS"]["Connection"] if self.database_configuration.use_ssh_tunnel: - server_address = credential_config["CREDENTIALS"]["Connection"]["Ssh"]["server"] + server_address = connection_configuration["Ssh"]["server"] else: - server_address = credential_config["CREDENTIALS"]["Connection"]["Standard"]["server"] + server_address = connection_configuration["Standard"]["server"] + + if "use_ssl" in connection_configuration["Standard"]: + if connection_configuration["Standard"]["use_ssl"]: + ssl_params = { + "ca": connection_configuration["Standard"]["ssl_params"]["ca"], + "cert": connection_configuration["Standard"]["ssl_params"]["cert"], + "key": connection_configuration["Standard"]["ssl_params"]["key"], + } + else: + ssl_params = None + else: + ssl_params = None + credentials = { "host": server_address, 
"user": database_configuration["user"], @@ -122,6 +141,7 @@ def _get_database_credentials(self): return { **credentials, "database": self.database_configuration.database_name, + "ssl": ssl_params, } except Exception as err: logging.error(err) @@ -131,7 +151,7 @@ def _start_transaction(self, connection, readonly=False): if not readonly: connection.begin() - def _table_exists(self, cursor, table_name: str = None) -> bool: + def _table_exists(self, cursor, table_name: Optional[str] = None) -> bool: table_name = table_name if table_name is not None else self.database_configuration.table_name self.execute(cursor, f"SHOW TABLES LIKE '{table_name}'") return self.fetchall(cursor) diff --git a/py_experimenter/experimenter.py b/py_experimenter/experimenter.py index 741da061..a2df77ac 100644 --- a/py_experimenter/experimenter.py +++ b/py_experimenter/experimenter.py @@ -27,8 +27,8 @@ def __init__( experiment_configuration_file_path: str = os.path.join("config", "experiment_configuration.yml"), database_credential_file_path: str = os.path.join("config", "database_credentials.yml"), use_ssh_tunnel: Optional[bool] = None, - table_name: str = None, - database_name: str = None, + table_name: Optional[str] = None, + database_name: Optional[str] = None, use_codecarbon: bool = True, name="PyExperimenter", logger_name: str = "py-experimenter", From a9db378bf73c15186bd8b890cefee4c5bdcbc5b1 Mon Sep 17 00:00:00 2001 From: Lukas Fehring Date: Thu, 10 Jul 2025 11:54:45 +0200 Subject: [PATCH 2/4] Update config --- py_experimenter/database_connector_mysql.py | 22 ++++++++++--------- .../test_codecarbon/test_integration_mysql.py | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/py_experimenter/database_connector_mysql.py b/py_experimenter/database_connector_mysql.py index d2fe3eed..049f67c3 100644 --- a/py_experimenter/database_connector_mysql.py +++ b/py_experimenter/database_connector_mysql.py @@ -117,20 +117,22 @@ def _get_database_credentials(self): 
connection_configuration = credential_config["CREDENTIALS"]["Connection"] if self.database_configuration.use_ssh_tunnel: server_address = connection_configuration["Ssh"]["server"] + ssl_params = None + + else: server_address = connection_configuration["Standard"]["server"] - - if "use_ssl" in connection_configuration["Standard"]: - if connection_configuration["Standard"]["use_ssl"]: - ssl_params = { - "ca": connection_configuration["Standard"]["ssl_params"]["ca"], - "cert": connection_configuration["Standard"]["ssl_params"]["cert"], - "key": connection_configuration["Standard"]["ssl_params"]["key"], - } + if "use_ssl" in connection_configuration["Standard"]: + if connection_configuration["Standard"]["use_ssl"]: + ssl_params = { + "ca": connection_configuration["Standard"]["ssl_params"]["ca"], + "cert": connection_configuration["Standard"]["ssl_params"]["cert"], + "key": connection_configuration["Standard"]["ssl_params"]["key"], + } + else: + ssl_params = None else: ssl_params = None - else: - ssl_params = None credentials = { "host": server_address, diff --git a/test/test_codecarbon/test_integration_mysql.py b/test/test_codecarbon/test_integration_mysql.py index 1adc263d..e25d288e 100644 --- a/test/test_codecarbon/test_integration_mysql.py +++ b/test/test_codecarbon/test_integration_mysql.py @@ -12,7 +12,7 @@ def experimenter(): configuration_path = os.path.join("test", "test_codecarbon", "configs", "integration_test_mysql.yml") - return PyExperimenter(configuration_path, use_ssh_tunnel=True) + return PyExperimenter(configuration_path, use_ssh_tunnel=False) def run_ml(parameters: dict, result_processor: ResultProcessor, custom_config: dict): From 2e44a7ea3373789122f067d667828bff50187a80 Mon Sep 17 00:00:00 2001 From: Lukas Fehring Date: Thu, 10 Jul 2025 14:08:16 +0200 Subject: [PATCH 3/4] Add mtls docs --- .../source/usage/database_credential_file.rst | 24 ++++++++++++++++++- py_experimenter/database_connector_mysql.py | 9 ++++--- 2 files changed, 27 insertions(+), 
6 deletions(-) diff --git a/docs/source/usage/database_credential_file.rst b/docs/source/usage/database_credential_file.rst index 6ee3bab6..610faa84 100644 --- a/docs/source/usage/database_credential_file.rst +++ b/docs/source/usage/database_credential_file.rst @@ -18,7 +18,29 @@ Below is an example of a database credential file, that connects to a server wit Standard: server: example.mysqlserver.com -However, for security reasons, databases might only be accessible from a specific IP address. In these cases, one can use an ssh jumphost. This means that ``PyExperimenter`` will first connect to the ssh server +We also support encrypted connections via (m)TLS. To enable them, the following parameters can be added to the ``Standard`` section of the database credential file: +- ``use_ssl``: a boolean value indicating whether to connect using SSL/TLS +- ``ssl_params``: a dictionary containing the following keys: + - ``ca``: the path to the CA certificate (optional, needed if the database server uses a custom CA certificate not trusted by the client) + - ``cert``: the path to the user certificate (optional, needed in case of client authentication with mTLS) + - ``key``: the path to the user private key (optional, needed in case of client authentication with mTLS) + +.. code-block:: yaml + + CREDENTIALS: + Database: + user: example_user + password: example_password + Connection: + Standard: + server: example.mysqlserver.com + use_ssl: some_boolean_value + ssl_params: + ca: config/certificates/ca.crt + cert: config/certificates/user_cert.crt + key: config/certificates/user_private_key.key + +Alternatively, for security reasons, databases might only be accessible from a specific IP address. In these cases, one can use an ssh jumphost. This means that ``PyExperimenter`` will first connect to the ssh server that has access to the database and then connect to the database server from there. 
This is done by adding an additional ``Ssh`` section to the database credential file, and can be activated either by a ``PyExperimenter`` keyword argument or in the :ref:`experimenter configuration file `. The following example shows how to connect to a database server using an SSH server with the address ``ssh_hostname`` and the port ``optional_ssh_port``. diff --git a/py_experimenter/database_connector_mysql.py b/py_experimenter/database_connector_mysql.py index 049f67c3..e04e8220 100644 --- a/py_experimenter/database_connector_mysql.py +++ b/py_experimenter/database_connector_mysql.py @@ -124,11 +124,10 @@ def _get_database_credentials(self): server_address = connection_configuration["Standard"]["server"] if "use_ssl" in connection_configuration["Standard"]: if connection_configuration["Standard"]["use_ssl"]: - ssl_params = { - "ca": connection_configuration["Standard"]["ssl_params"]["ca"], - "cert": connection_configuration["Standard"]["ssl_params"]["cert"], - "key": connection_configuration["Standard"]["ssl_params"]["key"], - } + ssl_params = dict() + ssl_params["ca"] = connection_configuration["Standard"]["ssl_params"]["ca"] if "ca" in connection_configuration["Standard"]["ssl_params"] else None + ssl_params["cert"] = connection_configuration["Standard"]["ssl_params"]["cert"] if "cert" in connection_configuration["Standard"]["ssl_params"] else None + ssl_params["key"] = connection_configuration["Standard"]["ssl_params"]["key"] if "key" in connection_configuration["Standard"]["ssl_params"] else None else: ssl_params = None else: From 2525d57dfd99a0b71406ed9bb24ee4600b337b85 Mon Sep 17 00:00:00 2001 From: Lukas Fehring Date: Fri, 11 Jul 2025 13:25:51 +0200 Subject: [PATCH 4/4] Update handling --- py_experimenter/database_connector_mysql.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/py_experimenter/database_connector_mysql.py b/py_experimenter/database_connector_mysql.py index e04e8220..ff68b0c7 100644 --- 
a/py_experimenter/database_connector_mysql.py +++ b/py_experimenter/database_connector_mysql.py @@ -125,9 +125,12 @@ def _get_database_credentials(self): if "use_ssl" in connection_configuration["Standard"]: if connection_configuration["Standard"]["use_ssl"]: ssl_params = dict() - ssl_params["ca"] = connection_configuration["Standard"]["ssl_params"]["ca"] if "ca" in connection_configuration["Standard"]["ssl_params"] else None - ssl_params["cert"] = connection_configuration["Standard"]["ssl_params"]["cert"] if "cert" in connection_configuration["Standard"]["ssl_params"] else None - ssl_params["key"] = connection_configuration["Standard"]["ssl_params"]["key"] if "key" in connection_configuration["Standard"]["ssl_params"] else None + if "ca" in connection_configuration["Standard"]["ssl_params"]: + ssl_params["ca"] = connection_configuration["Standard"]["ssl_params"]["ca"] + if "cert" in connection_configuration["Standard"]["ssl_params"]: + ssl_params["cert"] = connection_configuration["Standard"]["ssl_params"]["cert"] + if "key" in connection_configuration["Standard"]["ssl_params"]: + ssl_params["key"] = connection_configuration["Standard"]["ssl_params"]["key"] else: ssl_params = None else: