Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .cursorignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
config/database_credentials.yml
config/certificates
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,7 @@ powermetrics_log.txt
# development folder
development/

# MTLS certificates
config/certificates/ca.crt
config/certificates/user_cert.crt
config/certificates/user_private_key.key
5 changes: 5 additions & 0 deletions config/example_database_credentials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ CREDENTIALS:
Connection:
Standard:
server: example.mysqlserver.com
use_ssl: some_boolean_value
ssl_params:
ca: config/certificates/ca.crt
cert: config/certificates/user_cert.crt
key: config/certificates/user_private_key.key
Ssh:
server: example.sshmysqlserver.com (address from ssh server)
address: example.sslserver.com
Expand Down
24 changes: 23 additions & 1 deletion docs/source/usage/database_credential_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,29 @@ Below is an example of a database credential file, that connects to a server wit
Standard:
server: example.mysqlserver.com

However, for security reasons, databases might only be accessible from a specific IP address. In these cases, one can use an ssh jumphost. This means that ``PyExperimenter`` will first connect to the ssh server
Encrypted connections using (m)TLS are also supported. To enable them, add the following parameters to the ``Standard`` section of the database credential file:

- ``use_ssl``: a boolean value indicating whether to connect using SSL/TLS
- ``ssl_params``: a dictionary containing the following keys:
- ``ca``: the path to the CA certificate (optional; needed if the database server uses a custom CA certificate that the client does not trust)
- ``cert``: the path to the user certificate (optional; needed for client authentication with mTLS)
- ``key``: the path to the user private key (optional; needed for client authentication with mTLS)

.. code-block:: yaml

CREDENTIALS:
Database:
user: example_user
password: example_password
Connection:
Standard:
server: example.mysqlserver.com
use_ssl: some_boolean_value
ssl_params:
ca: config/certificates/ca.crt
cert: config/certificates/user_cert.crt
key: config/certificates/user_private_key.key

Alternatively, for security reasons, databases might only be accessible from a specific IP address. In these cases, one can use an SSH jumphost. This means that ``PyExperimenter`` will first connect to the SSH server
that has access to the database and then connect to the database server from there. This is done by adding an additional ``Ssh`` section to the database credential file, and can be activated either by a ``PyExperimenter`` keyword argument or in the :ref:`experimenter configuration file <experiment_configuration_file>`.
The following example shows how to connect to a database server using an SSH server with the address ``ssh_hostname`` and the port ``optional_ssh_port``.

Expand Down
38 changes: 31 additions & 7 deletions py_experimenter/database_connector_mysql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from logging import Logger
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import numpy as np
import sshtunnel
Expand All @@ -9,7 +9,11 @@

from py_experimenter.config import DatabaseCfg
from py_experimenter.database_connector import DatabaseConnector
from py_experimenter.exceptions import DatabaseConnectionError, DatabaseCreationError, SshTunnelError
from py_experimenter.exceptions import (
DatabaseConnectionError,
DatabaseCreationError,
SshTunnelError,
)


class DatabaseConnectorMYSQL(DatabaseConnector):
Expand All @@ -23,7 +27,8 @@ def __init__(self, database_configuration: DatabaseCfg, use_codecarbon: bool, cr

def get_ssh_tunnel(self, logger: Logger):
try:
credentials = OmegaConf.load(self.credential_path)["CREDENTIALS"]["Connection"]
credential_config = dict(OmegaConf.load(self.credential_path))
credentials = credential_config["CREDENTIALS"]["Connection"]
if "Ssh" in credentials:
parameters = dict(credentials["Ssh"])
ssh_address_or_host = parameters["address"]
Expand Down Expand Up @@ -107,12 +112,30 @@ def close_connection(self, connection):

def _get_database_credentials(self):
try:
credential_config = OmegaConf.load(self.credential_path)
credential_config = dict(OmegaConf.load(self.credential_path))
database_configuration = credential_config["CREDENTIALS"]["Database"]
connection_configuration = credential_config["CREDENTIALS"]["Connection"]
if self.database_configuration.use_ssh_tunnel:
server_address = credential_config["CREDENTIALS"]["Connection"]["Ssh"]["server"]
server_address = connection_configuration["Ssh"]["server"]
ssl_params = None


else:
server_address = credential_config["CREDENTIALS"]["Connection"]["Standard"]["server"]
server_address = connection_configuration["Standard"]["server"]
if "use_ssl" in connection_configuration["Standard"]:
if connection_configuration["Standard"]["use_ssl"]:
ssl_params = dict()
if "ca" in connection_configuration["Standard"]["ssl_params"]:
ssl_params["ca"] = connection_configuration["Standard"]["ssl_params"]["ca"]
if "cert" in connection_configuration["Standard"]["ssl_params"]:
ssl_params["cert"] = connection_configuration["Standard"]["ssl_params"]["cert"]
if "key" in connection_configuration["Standard"]["ssl_params"]:
ssl_params["key"] = connection_configuration["Standard"]["ssl_params"]["key"]
else:
ssl_params = None
else:
ssl_params = None

credentials = {
"host": server_address,
"user": database_configuration["user"],
Expand All @@ -122,6 +145,7 @@ def _get_database_credentials(self):
return {
**credentials,
"database": self.database_configuration.database_name,
"ssl": ssl_params,
}
except Exception as err:
logging.error(err)
Expand All @@ -131,7 +155,7 @@ def _start_transaction(self, connection, readonly=False):
if not readonly:
connection.begin()

def _table_exists(self, cursor, table_name: str = None) -> bool:
def _table_exists(self, cursor, table_name: Optional[str] = None) -> bool:
table_name = table_name if table_name is not None else self.database_configuration.table_name
self.execute(cursor, f"SHOW TABLES LIKE '{table_name}'")
return self.fetchall(cursor)
Expand Down
4 changes: 2 additions & 2 deletions py_experimenter/experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def __init__(
experiment_configuration_file_path: str = os.path.join("config", "experiment_configuration.yml"),
database_credential_file_path: str = os.path.join("config", "database_credentials.yml"),
use_ssh_tunnel: Optional[bool] = None,
table_name: str = None,
database_name: str = None,
table_name: Optional[str] = None,
database_name: Optional[str] = None,
use_codecarbon: bool = True,
name="PyExperimenter",
logger_name: str = "py-experimenter",
Expand Down
2 changes: 1 addition & 1 deletion test/test_codecarbon/test_integration_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def experimenter():
configuration_path = os.path.join("test", "test_codecarbon", "configs", "integration_test_mysql.yml")

return PyExperimenter(configuration_path, use_ssh_tunnel=True)
return PyExperimenter(configuration_path, use_ssh_tunnel=False)


def run_ml(parameters: dict, result_processor: ResultProcessor, custom_config: dict):
Expand Down