Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/clickhouse/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def connect(self, config: ConnectionConfig) -> Any:
if tls_mode != TLS_MODE_DEFAULT:
connect_args["verify"] = tls_mode != TLS_MODE_REQUIRE

connect_args.update(config.extra_options)
client = clickhouse_connect.get_client(**connect_args)
return client

Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/cockroachdb/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def connect(self, config: ConnectionConfig) -> Any:
if tls_key_password:
connect_args["sslpassword"] = tls_key_password

connect_args.update(config.extra_options)
conn = psycopg2.connect(**connect_args)
# Enable autocommit to avoid transaction issues
conn.autocommit = True
Expand Down
4 changes: 3 additions & 1 deletion sqlit/domains/connections/providers/db2/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def connect(self, config: ConnectionConfig) -> Any:
f"UID={endpoint.username};"
f"PWD={endpoint.password};"
)
return ibm_db_dbi.connect(conn_str, "", "")
connect_args: dict[str, Any] = {}
connect_args.update(config.extra_options)
return ibm_db_dbi.connect(conn_str, "", "", **connect_args)

def get_databases(self, conn: Any) -> list[str]:
return []
Expand Down
4 changes: 3 additions & 1 deletion sqlit/domains/connections/providers/duckdb/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def connect(self, config: ConnectionConfig) -> Any:
raise ValueError("DuckDB connections require a file endpoint.")
file_path = resolve_file_path(str(file_endpoint.path))
duckdb_any: Any = duckdb
return duckdb_any.connect(str(file_path))
connect_args: dict[str, Any] = {}
connect_args.update(config.extra_options)
return duckdb_any.connect(str(file_path), **connect_args)

def get_databases(self, conn: Any) -> list[str]:
"""DuckDB doesn't support multiple databases - return empty list."""
Expand Down
16 changes: 9 additions & 7 deletions sqlit/domains/connections/providers/firebird/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ def connect(self, config: "ConnectionConfig") -> Any:
endpoint = config.tcp_endpoint
if endpoint is None:
raise ValueError("Firebird connections require a TCP-style endpoint.")
conn = firebirdsql.connect(
host=endpoint.host or "localhost",
port=int(endpoint.port) if endpoint.port else 3050,
database=endpoint.database or "security.db",
user=endpoint.username,
password=endpoint.password,
)
connect_args: dict[str, Any] = {
"host": endpoint.host or "localhost",
"port": int(endpoint.port) if endpoint.port else 3050,
"database": endpoint.database or "security.db",
"user": endpoint.username,
"password": endpoint.password,
}
connect_args.update(config.extra_options)
conn = firebirdsql.connect(**connect_args)
return conn

def get_databases(self, conn: Any) -> list[str]:
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/flight/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def connect(self, config: ConnectionConfig) -> Any:
if use_tls and config.get_option("flight_skip_verify", "false") == "true":
db_kwargs["adbc.flight.sql.client_option.tls_skip_verify"] = "true"

db_kwargs.update(config.extra_options)
conn = flight_sql.connect(uri, db_kwargs=db_kwargs)

# Store the catalog/database for later use
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/hana/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def connect(self, config: ConnectionConfig) -> Any:
if schema:
connect_args["currentSchema"] = schema

connect_args.update(config.extra_options)
return hdbcli.connect(**connect_args)

def get_databases(self, conn: Any) -> list[str]:
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/mariadb/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def connect(self, config: ConnectionConfig) -> Any:
connect_args["ssl_verify_cert"] = tls_mode_verifies_cert(tls_mode)
connect_args["ssl_verify_identity"] = tls_mode_verifies_hostname(tls_mode)

connect_args.update(config.extra_options)
conn = mariadb_any.connect(**connect_args)

# Note: The MariaDB Python connector only supports UTF-8 family charsets.
Expand Down
3 changes: 3 additions & 0 deletions sqlit/domains/connections/providers/mssql/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ def connect(self, config: ConnectionConfig) -> Any:
)

conn_str = self._build_connection_string(config)
# Append extra_options to connection string
for key, value in config.extra_options.items():
conn_str += f"{key}={value};"
conn = mssql_python.connect(conn_str)
# Enable autocommit to allow DDL statements like CREATE DATABASE
conn.autocommit = True
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/mysql/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def connect(self, config: ConnectionConfig) -> Any:
ssl_params["check_hostname"] = tls_mode_verifies_hostname(tls_mode)
connect_args["ssl"] = ssl_params

connect_args.update(config.extra_options)
conn = pymysql.connect(**connect_args)

# Auto-sync charset with server to handle legacy encodings (e.g., TIS-620, Latin1).
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/oracle/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def connect(self, config: ConnectionConfig) -> Any:
if mode is not None:
connect_kwargs["mode"] = mode

connect_kwargs.update(config.extra_options)
return oracledb.connect(**connect_kwargs)

def get_databases(self, conn: Any) -> list[str]:
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/postgresql/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def connect(self, config: ConnectionConfig) -> Any:
if tls_key_password:
connect_args["sslpassword"] = tls_key_password

connect_args.update(config.extra_options)
conn = psycopg2.connect(**connect_args)
# Enable autocommit to avoid "transaction aborted" errors on failed statements
conn.autocommit = True
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/presto/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def connect(self, config: ConnectionConfig) -> Any:
raise ValueError("Presto password authentication requires prestodb.auth.BasicAuthentication") from exc
connect_args["auth"] = BasicAuthentication(endpoint.username, endpoint.password)

connect_args.update(config.extra_options)
return prestodb_dbapi.connect(**connect_args)

def get_databases(self, conn: Any) -> list[str]:
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/redshift/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def connect(self, config: ConnectionConfig) -> Any:
if tls_key:
connect_args["sslkey"] = tls_key

connect_args.update(config.extra_options)
conn = redshift_connector.connect(**connect_args)
conn.autocommit = True
return conn
Expand Down
15 changes: 14 additions & 1 deletion sqlit/domains/connections/providers/snowflake/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,27 @@ def connect(self, config: ConnectionConfig) -> Any:
}

# Additional args from our schema:
# warehouse, schema, role.
# warehouse, schema, role, authenticator.
extras = config.options
if "warehouse" in extras:
connect_args["warehouse"] = extras["warehouse"]
if "schema" in extras:
connect_args["schema"] = extras["schema"]
if "role" in extras:
connect_args["role"] = extras["role"]
# Authentication options
authenticator = extras.get("authenticator", "default")
if authenticator and authenticator != "default":
connect_args["authenticator"] = authenticator
if "private_key_file" in extras:
connect_args["private_key_file"] = extras["private_key_file"]
if "private_key_file_pwd" in extras:
connect_args["private_key_file_pwd"] = extras["private_key_file_pwd"]
if "oauth_token" in extras:
connect_args["token"] = extras["oauth_token"]

# Pass through any extra_options to the driver
connect_args.update(config.extra_options)

return sf.connect(**connect_args)

Expand Down
63 changes: 62 additions & 1 deletion sqlit/domains/connections/providers/snowflake/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,32 @@

from sqlit.domains.connections.providers.schema_helpers import (
ConnectionSchema,
FieldType,
SchemaField,
SelectOption,
_database_field,
_password_field,
_username_field,
)


def _get_snowflake_auth_options() -> tuple[SelectOption, ...]:
return (
SelectOption("default", "Username & Password"),
SelectOption("externalbrowser", "SSO (Browser)"),
SelectOption("snowflake_jwt", "Key Pair (JWT)"),
SelectOption("oauth", "OAuth Token"),
)


# Auth types that need password
_AUTH_NEEDS_PASSWORD = {"default"}
# Auth types that need private key
_AUTH_NEEDS_PRIVATE_KEY = {"snowflake_jwt"}
# Auth types that need OAuth token
_AUTH_NEEDS_OAUTH = {"oauth"}


SCHEMA = ConnectionSchema(
db_type="snowflake",
display_name="Snowflake",
Expand All @@ -20,7 +40,47 @@
description="Snowflake Account Identifier",
),
_username_field(),
_password_field(),
SchemaField(
name="authenticator",
label="Authentication",
field_type=FieldType.DROPDOWN,
options=_get_snowflake_auth_options(),
default="default",
),
SchemaField(
name="password",
label="Password",
field_type=FieldType.PASSWORD,
placeholder="(empty = ask every connect)",
group="credentials",
visible_when=lambda v: v.get("authenticator", "default") in _AUTH_NEEDS_PASSWORD,
),
SchemaField(
name="private_key_file",
label="Private Key File",
field_type=FieldType.FILE,
placeholder="/path/to/rsa_key.p8",
required=False,
description="Path to private key file for JWT authentication",
visible_when=lambda v: v.get("authenticator") in _AUTH_NEEDS_PRIVATE_KEY,
),
SchemaField(
name="private_key_file_pwd",
label="Private Key Password",
field_type=FieldType.PASSWORD,
placeholder="(optional)",
required=False,
description="Password for encrypted private key",
visible_when=lambda v: v.get("authenticator") in _AUTH_NEEDS_PRIVATE_KEY,
),
SchemaField(
name="oauth_token",
label="OAuth Token",
field_type=FieldType.PASSWORD,
placeholder="OAuth access token",
required=False,
visible_when=lambda v: v.get("authenticator") in _AUTH_NEEDS_OAUTH,
),
_database_field(),
SchemaField(
name="warehouse",
Expand All @@ -45,4 +105,5 @@
),
),
supports_ssh=False,
has_advanced_auth=True,
)
4 changes: 3 additions & 1 deletion sqlit/domains/connections/providers/sqlite/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def connect(self, config: ConnectionConfig) -> Any:
file_path = resolve_file_path(str(file_endpoint.path))
# check_same_thread=False allows connection to be used from background threads
# (for async query execution). SQLite serializes access internally.
conn = sqlite3.connect(file_path, check_same_thread=False)
connect_args: dict[str, Any] = {"check_same_thread": False}
connect_args.update(config.extra_options)
conn = sqlite3.connect(file_path, **connect_args)
conn.row_factory = sqlite3.Row
return conn

Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/teradata/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def connect(self, config: ConnectionConfig) -> Any:
if port:
connect_args["dbs_port"] = port

connect_args.update(config.extra_options)
return teradatasql.connect(**connect_args)

def get_databases(self, conn: Any) -> list[str]:
Expand Down
1 change: 1 addition & 0 deletions sqlit/domains/connections/providers/trino/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def connect(self, config: ConnectionConfig) -> Any:
raise ValueError("Trino password authentication requires trino.auth.BasicAuthentication") from exc
connect_args["auth"] = BasicAuthentication(endpoint.username, endpoint.password)

connect_args.update(config.extra_options)
return trino_dbapi.connect(**connect_args)

def get_databases(self, conn: Any) -> list[str]:
Expand Down
4 changes: 3 additions & 1 deletion sqlit/domains/connections/providers/turso/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def connect(self, config: ConnectionConfig) -> Any:
url = f"https://{url}"

auth_token = endpoint.password if endpoint.password else ""
return libsql.connect(url, auth_token=auth_token)
connect_args: dict[str, Any] = {"auth_token": auth_token}
connect_args.update(config.extra_options)
return libsql.connect(url, **connect_args)

def get_databases(self, conn: Any) -> list[str]:
"""Turso doesn't support multiple databases - return empty list."""
Expand Down
Loading
Loading