From 70e4cd046b0be0f7d59c8f4cabb2071aa4b6c26a Mon Sep 17 00:00:00 2001 From: Peter Adams <18162810+Maxteabag@users.noreply.github.com> Date: Sat, 31 Jan 2026 16:20:56 +0100 Subject: [PATCH] feat: pass extra_options to database drivers - Add extra_options pass-through to all adapters, allowing users to pass custom driver parameters via connections.json or CLI URLs - Add Snowflake authentication dropdown with support for: - Username & Password (default) - SSO (Browser) - Key Pair (JWT) - OAuth Token - Add conditional fields for private key file and password when JWT is selected --- .../providers/clickhouse/adapter.py | 1 + .../providers/cockroachdb/adapter.py | 1 + .../connections/providers/db2/adapter.py | 4 +- .../connections/providers/duckdb/adapter.py | 4 +- .../connections/providers/firebird/adapter.py | 16 +- .../connections/providers/flight/adapter.py | 1 + .../connections/providers/hana/adapter.py | 1 + .../connections/providers/mariadb/adapter.py | 1 + .../connections/providers/mssql/adapter.py | 3 + .../connections/providers/mysql/adapter.py | 1 + .../connections/providers/oracle/adapter.py | 1 + .../providers/postgresql/adapter.py | 1 + .../connections/providers/presto/adapter.py | 1 + .../connections/providers/redshift/adapter.py | 1 + .../providers/snowflake/adapter.py | 15 +- .../connections/providers/snowflake/schema.py | 63 ++++++- .../connections/providers/sqlite/adapter.py | 4 +- .../connections/providers/teradata/adapter.py | 1 + .../connections/providers/trino/adapter.py | 1 + .../connections/providers/turso/adapter.py | 4 +- tests/unit/test_extra_options_passthrough.py | 178 ++++++++++++++++++ 21 files changed, 290 insertions(+), 13 deletions(-) create mode 100644 tests/unit/test_extra_options_passthrough.py diff --git a/sqlit/domains/connections/providers/clickhouse/adapter.py b/sqlit/domains/connections/providers/clickhouse/adapter.py index c07bb7ee..31b2e009 100644 --- a/sqlit/domains/connections/providers/clickhouse/adapter.py +++ b/sqlit/domains/connections/providers/clickhouse/adapter.py @@ -141,6 +141,7 @@ def connect(self, config: ConnectionConfig) -> Any: if tls_mode != TLS_MODE_DEFAULT: connect_args["verify"] = tls_mode != TLS_MODE_REQUIRE + connect_args.update(config.extra_options) client = clickhouse_connect.get_client(**connect_args) return client diff --git a/sqlit/domains/connections/providers/cockroachdb/adapter.py b/sqlit/domains/connections/providers/cockroachdb/adapter.py index 6fb04025..04e6d76b 100644 --- a/sqlit/domains/connections/providers/cockroachdb/adapter.py +++ b/sqlit/domains/connections/providers/cockroachdb/adapter.py @@ -87,6 +87,7 @@ def connect(self, config: ConnectionConfig) -> Any: if tls_key_password: connect_args["sslpassword"] = tls_key_password + connect_args.update(config.extra_options) conn = psycopg2.connect(**connect_args) # Enable autocommit to avoid transaction issues conn.autocommit = True diff --git a/sqlit/domains/connections/providers/db2/adapter.py b/sqlit/domains/connections/providers/db2/adapter.py index 8e7cac2f..7aafbd82 100644 --- a/sqlit/domains/connections/providers/db2/adapter.py +++ b/sqlit/domains/connections/providers/db2/adapter.py @@ -81,7 +81,9 @@ def connect(self, config: ConnectionConfig) -> Any: f"UID={endpoint.username};" f"PWD={endpoint.password};" ) - return ibm_db_dbi.connect(conn_str, "", "") + connect_args: dict[str, Any] = {} + connect_args.update(config.extra_options) + return ibm_db_dbi.connect(conn_str, "", "", **connect_args) def get_databases(self, conn: Any) -> list[str]: return [] diff --git a/sqlit/domains/connections/providers/duckdb/adapter.py b/sqlit/domains/connections/providers/duckdb/adapter.py index 394a78dd..e70fbaa7 100644 --- a/sqlit/domains/connections/providers/duckdb/adapter.py +++ b/sqlit/domains/connections/providers/duckdb/adapter.py @@ -83,7 +83,9 @@ def connect(self, config: ConnectionConfig) -> Any: raise ValueError("DuckDB connections require a file endpoint.") file_path = resolve_file_path(str(file_endpoint.path)) duckdb_any: Any = duckdb - return duckdb_any.connect(str(file_path)) + connect_args: dict[str, Any] = {} + connect_args.update(config.extra_options) + return duckdb_any.connect(str(file_path), **connect_args) def get_databases(self, conn: Any) -> list[str]: """DuckDB doesn't support multiple databases - return empty list.""" diff --git a/sqlit/domains/connections/providers/firebird/adapter.py b/sqlit/domains/connections/providers/firebird/adapter.py index 95e5b6e4..5b94cc16 100644 --- a/sqlit/domains/connections/providers/firebird/adapter.py +++ b/sqlit/domains/connections/providers/firebird/adapter.py @@ -72,13 +72,15 @@ def connect(self, config: "ConnectionConfig") -> Any: endpoint = config.tcp_endpoint if endpoint is None: raise ValueError("Firebird connections require a TCP-style endpoint.") - conn = firebirdsql.connect( - host=endpoint.host or "localhost", - port=int(endpoint.port) if endpoint.port else 3050, - database=endpoint.database or "security.db", - user=endpoint.username, - password=endpoint.password, - ) + connect_args: dict[str, Any] = { + "host": endpoint.host or "localhost", + "port": int(endpoint.port) if endpoint.port else 3050, + "database": endpoint.database or "security.db", + "user": endpoint.username, + "password": endpoint.password, + } + connect_args.update(config.extra_options) + conn = firebirdsql.connect(**connect_args) return conn def get_databases(self, conn: Any) -> list[str]: diff --git a/sqlit/domains/connections/providers/flight/adapter.py b/sqlit/domains/connections/providers/flight/adapter.py index 12a3e800..0dce60d7 100644 --- a/sqlit/domains/connections/providers/flight/adapter.py +++ b/sqlit/domains/connections/providers/flight/adapter.py @@ -130,6 +130,7 @@ def connect(self, config: ConnectionConfig) -> Any: if use_tls and config.get_option("flight_skip_verify", "false") == "true": db_kwargs["adbc.flight.sql.client_option.tls_skip_verify"] = "true" + db_kwargs.update(config.extra_options) conn = flight_sql.connect(uri, db_kwargs=db_kwargs) # Store the catalog/database for later use diff --git a/sqlit/domains/connections/providers/hana/adapter.py b/sqlit/domains/connections/providers/hana/adapter.py index 28e7acef..7132cb86 100644 --- a/sqlit/domains/connections/providers/hana/adapter.py +++ b/sqlit/domains/connections/providers/hana/adapter.py @@ -82,6 +82,7 @@ def connect(self, config: ConnectionConfig) -> Any: if schema: connect_args["currentSchema"] = schema + connect_args.update(config.extra_options) return hdbcli.connect(**connect_args) def get_databases(self, conn: Any) -> list[str]: diff --git a/sqlit/domains/connections/providers/mariadb/adapter.py b/sqlit/domains/connections/providers/mariadb/adapter.py index 5cb58bfc..6c119b50 100644 --- a/sqlit/domains/connections/providers/mariadb/adapter.py +++ b/sqlit/domains/connections/providers/mariadb/adapter.py @@ -100,6 +100,7 @@ def connect(self, config: ConnectionConfig) -> Any: connect_args["ssl_verify_cert"] = tls_mode_verifies_cert(tls_mode) connect_args["ssl_verify_identity"] = tls_mode_verifies_hostname(tls_mode) + connect_args.update(config.extra_options) conn = mariadb_any.connect(**connect_args) self._supports_sequences = self._detect_sequences_support(conn) return conn diff --git a/sqlit/domains/connections/providers/mssql/adapter.py b/sqlit/domains/connections/providers/mssql/adapter.py index 5ddb16ce..7dc06760 100644 --- a/sqlit/domains/connections/providers/mssql/adapter.py +++ b/sqlit/domains/connections/providers/mssql/adapter.py @@ -180,6 +180,9 @@ def connect(self, config: ConnectionConfig) -> Any: ) conn_str = self._build_connection_string(config) + # Append extra_options to connection string + for key, value in config.extra_options.items(): + conn_str += f"{key}={value};" return mssql_python.connect(conn_str) def get_databases(self, conn: Any) -> list[str]: diff --git a/sqlit/domains/connections/providers/mysql/adapter.py b/sqlit/domains/connections/providers/mysql/adapter.py index 803ec7da..27771874 100644 --- a/sqlit/domains/connections/providers/mysql/adapter.py +++ b/sqlit/domains/connections/providers/mysql/adapter.py @@ -110,4 +110,5 @@ def connect(self, config: ConnectionConfig) -> Any: ssl_params["check_hostname"] = tls_mode_verifies_hostname(tls_mode) connect_args["ssl"] = ssl_params + connect_args.update(config.extra_options) return pymysql.connect(**connect_args) diff --git a/sqlit/domains/connections/providers/oracle/adapter.py b/sqlit/domains/connections/providers/oracle/adapter.py index c6ad39e4..b7c698bb 100644 --- a/sqlit/domains/connections/providers/oracle/adapter.py +++ b/sqlit/domains/connections/providers/oracle/adapter.py @@ -91,6 +91,7 @@ def connect(self, config: ConnectionConfig) -> Any: if mode is not None: connect_kwargs["mode"] = mode + connect_kwargs.update(config.extra_options) return oracledb.connect(**connect_kwargs) def get_databases(self, conn: Any) -> list[str]: diff --git a/sqlit/domains/connections/providers/postgresql/adapter.py b/sqlit/domains/connections/providers/postgresql/adapter.py index b806f78c..5fb62099 100644 --- a/sqlit/domains/connections/providers/postgresql/adapter.py +++ b/sqlit/domains/connections/providers/postgresql/adapter.py @@ -74,6 +74,7 @@ def connect(self, config: ConnectionConfig) -> Any: if tls_key_password: connect_args["sslpassword"] = tls_key_password + connect_args.update(config.extra_options) conn = psycopg2.connect(**connect_args) # Enable autocommit to avoid "transaction aborted" errors on failed statements conn.autocommit = True diff --git a/sqlit/domains/connections/providers/presto/adapter.py b/sqlit/domains/connections/providers/presto/adapter.py index 88189e2e..7c3c84c2 100644 --- a/sqlit/domains/connections/providers/presto/adapter.py +++ b/sqlit/domains/connections/providers/presto/adapter.py @@ -106,6 +106,7 @@ def connect(self, config: ConnectionConfig) -> Any: raise ValueError("Presto password authentication requires prestodb.auth.BasicAuthentication") from exc connect_args["auth"] = BasicAuthentication(endpoint.username, endpoint.password) + connect_args.update(config.extra_options) return prestodb_dbapi.connect(**connect_args) def get_databases(self, conn: Any) -> list[str]: diff --git a/sqlit/domains/connections/providers/redshift/adapter.py b/sqlit/domains/connections/providers/redshift/adapter.py index ee37bd38..6e29ef71 100644 --- a/sqlit/domains/connections/providers/redshift/adapter.py +++ b/sqlit/domains/connections/providers/redshift/adapter.py @@ -125,6 +125,7 @@ def connect(self, config: ConnectionConfig) -> Any: if tls_key: connect_args["sslkey"] = tls_key + connect_args.update(config.extra_options) conn = redshift_connector.connect(**connect_args) conn.autocommit = True return conn diff --git a/sqlit/domains/connections/providers/snowflake/adapter.py b/sqlit/domains/connections/providers/snowflake/adapter.py index b40a0e49..13e9882d 100644 --- a/sqlit/domains/connections/providers/snowflake/adapter.py +++ b/sqlit/domains/connections/providers/snowflake/adapter.py @@ -79,7 +79,7 @@ def connect(self, config: ConnectionConfig) -> Any: } # Additional args from our schema: - # warehouse, schema, role. + # warehouse, schema, role, authenticator. extras = config.options if "warehouse" in extras: connect_args["warehouse"] = extras["warehouse"] @@ -87,6 +87,19 @@ def connect(self, config: ConnectionConfig) -> Any: connect_args["schema"] = extras["schema"] if "role" in extras: connect_args["role"] = extras["role"] + # Authentication options + authenticator = extras.get("authenticator", "default") + if authenticator and authenticator != "default": + connect_args["authenticator"] = authenticator + if "private_key_file" in extras: + connect_args["private_key_file"] = extras["private_key_file"] + if "private_key_file_pwd" in extras: + connect_args["private_key_file_pwd"] = extras["private_key_file_pwd"] + if "oauth_token" in extras: + connect_args["token"] = extras["oauth_token"] + + # Pass through any extra_options to the driver + connect_args.update(config.extra_options) return sf.connect(**connect_args) diff --git a/sqlit/domains/connections/providers/snowflake/schema.py b/sqlit/domains/connections/providers/snowflake/schema.py index 6ddcbbb4..8a3dba78 100644 --- a/sqlit/domains/connections/providers/snowflake/schema.py +++ b/sqlit/domains/connections/providers/snowflake/schema.py @@ -2,12 +2,32 @@ from sqlit.domains.connections.providers.schema_helpers import ( ConnectionSchema, + FieldType, SchemaField, + SelectOption, _database_field, _password_field, _username_field, ) + +def _get_snowflake_auth_options() -> tuple[SelectOption, ...]: + return ( + SelectOption("default", "Username & Password"), + SelectOption("externalbrowser", "SSO (Browser)"), + SelectOption("snowflake_jwt", "Key Pair (JWT)"), + SelectOption("oauth", "OAuth Token"), + ) + + +# Auth types that need password +_AUTH_NEEDS_PASSWORD = {"default"} +# Auth types that need private key +_AUTH_NEEDS_PRIVATE_KEY = {"snowflake_jwt"} +# Auth types that need OAuth token +_AUTH_NEEDS_OAUTH = {"oauth"} + + SCHEMA = ConnectionSchema( db_type="snowflake", display_name="Snowflake", @@ -20,7 +40,47 @@ description="Snowflake Account Identifier", ), _username_field(), - _password_field(), + SchemaField( + name="authenticator", + label="Authentication", + field_type=FieldType.DROPDOWN, + options=_get_snowflake_auth_options(), + default="default", + ), + SchemaField( + name="password", + label="Password", + field_type=FieldType.PASSWORD, + placeholder="(empty = ask every connect)", + group="credentials", + visible_when=lambda v: v.get("authenticator", "default") in _AUTH_NEEDS_PASSWORD, + ), + SchemaField( + name="private_key_file", + label="Private Key File", + field_type=FieldType.FILE, + placeholder="/path/to/rsa_key.p8", + required=False, + description="Path to private key file for JWT authentication", + visible_when=lambda v: v.get("authenticator") in _AUTH_NEEDS_PRIVATE_KEY, + ), + SchemaField( + name="private_key_file_pwd", + label="Private Key Password", + field_type=FieldType.PASSWORD, + placeholder="(optional)", + required=False, + description="Password for encrypted private key", + visible_when=lambda v: v.get("authenticator") in _AUTH_NEEDS_PRIVATE_KEY, + ), + SchemaField( + name="oauth_token", + label="OAuth Token", + field_type=FieldType.PASSWORD, + placeholder="OAuth access token", + required=False, + visible_when=lambda v: v.get("authenticator") in _AUTH_NEEDS_OAUTH, + ), _database_field(), SchemaField( name="warehouse", @@ -45,4 +105,5 @@ ), ), supports_ssh=False, + has_advanced_auth=True, ) diff --git a/sqlit/domains/connections/providers/sqlite/adapter.py b/sqlit/domains/connections/providers/sqlite/adapter.py index e5b32abf..2b7315f4 100644 --- a/sqlit/domains/connections/providers/sqlite/adapter.py +++ b/sqlit/domains/connections/providers/sqlite/adapter.py @@ -43,7 +43,9 @@ def connect(self, config: ConnectionConfig) -> Any: file_path = resolve_file_path(str(file_endpoint.path)) # check_same_thread=False allows connection to be used from background threads # (for async query execution). SQLite serializes access internally. - conn = sqlite3.connect(file_path, check_same_thread=False) + connect_args: dict[str, Any] = {"check_same_thread": False} + connect_args.update(config.extra_options) + conn = sqlite3.connect(file_path, **connect_args) conn.row_factory = sqlite3.Row return conn diff --git a/sqlit/domains/connections/providers/teradata/adapter.py b/sqlit/domains/connections/providers/teradata/adapter.py index 2192ccfc..29fe886f 100644 --- a/sqlit/domains/connections/providers/teradata/adapter.py +++ b/sqlit/domains/connections/providers/teradata/adapter.py @@ -85,6 +85,7 @@ def connect(self, config: ConnectionConfig) -> Any: if port: connect_args["dbs_port"] = port + connect_args.update(config.extra_options) return teradatasql.connect(**connect_args) def get_databases(self, conn: Any) -> list[str]: diff --git a/sqlit/domains/connections/providers/trino/adapter.py b/sqlit/domains/connections/providers/trino/adapter.py index fec31a47..fe9a406b 100644 --- a/sqlit/domains/connections/providers/trino/adapter.py +++ b/sqlit/domains/connections/providers/trino/adapter.py @@ -106,6 +106,7 @@ def connect(self, config: ConnectionConfig) -> Any: raise ValueError("Trino password authentication requires trino.auth.BasicAuthentication") from exc connect_args["auth"] = BasicAuthentication(endpoint.username, endpoint.password) + connect_args.update(config.extra_options) return trino_dbapi.connect(**connect_args) def get_databases(self, conn: Any) -> list[str]: diff --git a/sqlit/domains/connections/providers/turso/adapter.py b/sqlit/domains/connections/providers/turso/adapter.py index 94113c12..1e65d546 100644 --- a/sqlit/domains/connections/providers/turso/adapter.py +++ b/sqlit/domains/connections/providers/turso/adapter.py @@ -96,7 +96,9 @@ def connect(self, config: ConnectionConfig) -> Any: url = f"https://{url}" auth_token = endpoint.password if endpoint.password else "" - return libsql.connect(url, auth_token=auth_token) + connect_args: dict[str, Any] = {"auth_token": auth_token} + connect_args.update(config.extra_options) + return libsql.connect(url, **connect_args) def get_databases(self, conn: Any) -> list[str]: """Turso doesn't support multiple databases - return empty list.""" diff --git a/tests/unit/test_extra_options_passthrough.py b/tests/unit/test_extra_options_passthrough.py new file mode 100644 index 00000000..1ad78eff --- /dev/null +++ b/tests/unit/test_extra_options_passthrough.py @@ -0,0 +1,178 @@ +"""Unit tests for extra_options pass-through to drivers. + +This verifies the fix for GitHub issue #108 where users couldn't pass +custom properties to underlying database drivers. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + + +class TestExtraOptionsPassthrough: + """Test that extra_options are passed through to drivers.""" + + def test_snowflake_passes_extra_options(self): + """Test Snowflake adapter passes extra_options to driver.""" + from sqlit.domains.connections.domain.config import ConnectionConfig, TcpEndpoint + from sqlit.domains.connections.providers.snowflake.adapter import SnowflakeAdapter + + mock_sf = MagicMock() + mock_conn = MagicMock() + mock_sf.connect.return_value = mock_conn + + with patch.dict("sys.modules", {"snowflake.connector": mock_sf}): + adapter = SnowflakeAdapter() + config = ConnectionConfig( + name="test_sf", + db_type="snowflake", + endpoint=TcpEndpoint( + host="account.snowflakecomputing.com", + username="user", + password="pass", + database="db", + ), + extra_options={ + "authenticator": "externalbrowser", + "custom_option": "custom_value", + }, + ) + + adapter.connect(config) + + # Verify extra_options were passed to connect + call_kwargs = mock_sf.connect.call_args[1] + assert call_kwargs.get("authenticator") == "externalbrowser" + assert call_kwargs.get("custom_option") == "custom_value" + + def test_snowflake_jwt_auth_options(self): + """Test Snowflake JWT authentication options are passed.""" + from sqlit.domains.connections.domain.config import ConnectionConfig, TcpEndpoint + from sqlit.domains.connections.providers.snowflake.adapter import SnowflakeAdapter + + mock_sf = MagicMock() + mock_conn = MagicMock() + mock_sf.connect.return_value = mock_conn + + with patch.dict("sys.modules", {"snowflake.connector": mock_sf}): + adapter = SnowflakeAdapter() + config = ConnectionConfig( + name="test_sf_jwt", + db_type="snowflake", + endpoint=TcpEndpoint( + host="account.snowflakecomputing.com", + username="user", + database="db", + ), + options={ + "authenticator": "snowflake_jwt", + "private_key_file": "/path/to/key.p8", + "private_key_file_pwd": "secret", + }, + ) + + adapter.connect(config) + + call_kwargs = mock_sf.connect.call_args[1] + assert call_kwargs.get("authenticator") == "snowflake_jwt" + assert call_kwargs.get("private_key_file") == "/path/to/key.p8" + assert call_kwargs.get("private_key_file_pwd") == "secret" + + def test_postgresql_passes_extra_options(self): + """Test PostgreSQL adapter passes extra_options to driver.""" + from sqlit.domains.connections.domain.config import ConnectionConfig, TcpEndpoint + from sqlit.domains.connections.providers.postgresql.adapter import PostgreSQLAdapter + + mock_psycopg2 = MagicMock() + mock_conn = MagicMock() + mock_conn.autocommit = False + mock_psycopg2.connect.return_value = mock_conn + + with patch.dict("sys.modules", {"psycopg2": mock_psycopg2}): + adapter = PostgreSQLAdapter() + config = ConnectionConfig( + name="test_pg", + db_type="postgresql", + endpoint=TcpEndpoint( + host="localhost", + port="5432", + username="user", + password="pass", + database="db", + ), + extra_options={ + "application_name": "my_app", + "connect_timeout": "30", + }, + ) + + adapter.connect(config) + + call_kwargs = mock_psycopg2.connect.call_args[1] + assert call_kwargs.get("application_name") == "my_app" + assert call_kwargs.get("connect_timeout") == "30" + + def test_mysql_passes_extra_options(self): + """Test MySQL adapter passes extra_options to driver.""" + from sqlit.domains.connections.domain.config import ConnectionConfig, TcpEndpoint + from sqlit.domains.connections.providers.mysql.adapter import MySQLAdapter + + mock_pymysql = MagicMock() + mock_conn = MagicMock() + mock_pymysql.connect.return_value = mock_conn + + with patch.dict("sys.modules", {"pymysql": mock_pymysql}): + adapter = MySQLAdapter() + config = ConnectionConfig( + name="test_mysql", + db_type="mysql", + endpoint=TcpEndpoint( + host="localhost", + port="3306", + username="user", + password="pass", + database="db", + ), + extra_options={ + "charset": "utf8mb4", + "init_command": "SET NAMES utf8mb4", + }, + ) + + adapter.connect(config) + + call_kwargs = mock_pymysql.connect.call_args[1] + assert call_kwargs.get("charset") == "utf8mb4" + assert call_kwargs.get("init_command") == "SET NAMES utf8mb4" + + +class TestSnowflakeAuthSchema: + """Test Snowflake authentication schema options.""" + + def test_snowflake_schema_has_auth_dropdown(self): + """Test Snowflake schema includes authentication dropdown.""" + from sqlit.domains.connections.providers.snowflake.schema import SCHEMA + + auth_field = None + for field in SCHEMA.fields: + if field.name == "authenticator": + auth_field = field + break + + assert auth_field is not None, "Snowflake schema should have authenticator field" + assert len(auth_field.options) == 4 + auth_values = [opt.value for opt in auth_field.options] + assert "default" in auth_values + assert "externalbrowser" in auth_values + assert "snowflake_jwt" in auth_values + assert "oauth" in auth_values + + def test_snowflake_schema_has_private_key_fields(self): + """Test Snowflake schema includes private key fields for JWT auth.""" + from sqlit.domains.connections.providers.snowflake.schema import SCHEMA + + field_names = [f.name for f in SCHEMA.fields] + assert "private_key_file" in field_names + assert "private_key_file_pwd" in field_names