Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions jetbase/database/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Generator

from sqlalchemy import Connection, Engine, create_engine, text
Expand All @@ -10,24 +11,6 @@
from jetbase.enums import DatabaseType


@contextmanager
def _suppress_databricks_warnings():
"""
Temporarily sets the databricks logger to ERROR level to suppress
the deprecated _user_agent_entry warning coming from the databricks-sqlalchemy dependency.

Databricks-sqlalchemy is a dependency of databricks-sql-connector (which is triggering the warning), so we need to suppress the warning here until databricks-sqlalchemy is updated to fix the warning.
"""
logger = logging.getLogger("databricks")
original_level = logger.level
logger.setLevel(logging.ERROR)

try:
yield
finally:
logger.setLevel(original_level)


@contextmanager
def get_db_connection() -> Generator[Connection, None, None]:
"""
Expand All @@ -45,18 +28,9 @@ def get_db_connection() -> Generator[Connection, None, None]:
>>> with get_db_connection() as conn:
... conn.execute(query)
"""
sqlalchemy_url: str = get_config(required={"sqlalchemy_url"}).sqlalchemy_url
db_type: DatabaseType = detect_db(sqlalchemy_url=sqlalchemy_url)

connect_args: dict[str, Any] = {}

if db_type == DatabaseType.SNOWFLAKE:
snowflake_url: URL = make_url(sqlalchemy_url)

if not snowflake_url.password:
connect_args["private_key"] = _get_snowflake_private_key_der()

engine: Engine = create_engine(url=sqlalchemy_url, connect_args=connect_args)
engine: Engine = _get_engine()
db_type: DatabaseType = detect_db(sqlalchemy_url=str(engine.url))

if db_type == DatabaseType.DATABRICKS:
# Suppress databricks warnings during connection
Expand All @@ -75,6 +49,31 @@ def get_db_connection() -> Generator[Connection, None, None]:
yield connection


@lru_cache(maxsize=1)
def _get_engine() -> Engine:
"""
Get or create the singleton SQLAlchemy Engine.

Creates the engine on first call and caches it for subsequent calls.
The engine manages its own connection pool internally.

Returns:
Engine: A SQLAlchemy Engine instance.
"""
sqlalchemy_url: str = get_config(required={"sqlalchemy_url"}).sqlalchemy_url
db_type: DatabaseType = detect_db(sqlalchemy_url=sqlalchemy_url)

connect_args: dict[str, Any] = {}

if db_type == DatabaseType.SNOWFLAKE:
snowflake_url: URL = make_url(sqlalchemy_url)

if not snowflake_url.password:
connect_args["private_key"] = _get_snowflake_private_key_der()

return create_engine(url=sqlalchemy_url, connect_args=connect_args)


def _get_snowflake_private_key_der() -> bytes:
"""
Retrieves the Snowflake private key in DER format for key pair authentication.
Expand Down Expand Up @@ -124,3 +123,21 @@ def _get_snowflake_private_key_der() -> bytes:
)

return private_key_bytes


@contextmanager
def _suppress_databricks_warnings():
"""
Temporarily sets the databricks logger to ERROR level to suppress
the deprecated _user_agent_entry warning coming from the databricks-sqlalchemy dependency.

Databricks-sqlalchemy is a dependency of databricks-sql-connector (which is triggering the warning), so we need to suppress the warning here until databricks-sqlalchemy is updated to fix the warning.
"""
logger = logging.getLogger("databricks")
original_level = logger.level
logger.setLevel(logging.ERROR)

try:
yield
finally:
logger.setLevel(original_level)
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqlalchemy import create_engine, text
from typer.testing import CliRunner

from jetbase.database.connection import _get_engine
from jetbase.database.queries.base import detect_db
from jetbase.enums import DatabaseType

Expand Down Expand Up @@ -121,6 +122,7 @@ def setup_migrations_versions_only(tmp_path, migrations_versions_only_fixture_di
@pytest.fixture
def clean_db(test_db_url):
"""Clean up database before and after tests."""
_get_engine.cache_clear()
engine = create_engine(test_db_url)

def cleanup():
Expand All @@ -133,3 +135,5 @@ def cleanup():
yield engine
cleanup()
engine.dispose()

_get_engine.cache_clear()
11 changes: 10 additions & 1 deletion tests/integration/database/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from sqlalchemy import text

from jetbase.database.connection import get_db_connection
from jetbase.database.connection import _get_engine, get_db_connection


class TestSnowflakePasswordAuth:
Expand All @@ -14,12 +14,15 @@ class TestSnowflakePasswordAuth:
@pytest.fixture(autouse=True)
def setup(self):
"""Set up test environment for password auth."""
_get_engine.cache_clear()
url = os.environ.get("TEST_SF_USER_PASS_URL")
assert url is not None

os.environ["JETBASE_SQLALCHEMY_URL"] = url
yield

_get_engine.cache_clear()

def test_get_db_connection_with_password_auth(self):
"""Test that get_db_connection works with Snowflake password authentication."""
with get_db_connection() as connection:
Expand Down Expand Up @@ -48,6 +51,7 @@ class TestSnowflakeKeyPairAuth:
@pytest.fixture(autouse=True)
def setup(self):
"""Set up test environment for key pair auth."""
_get_engine.cache_clear()
url = os.environ.get("TEST_SF_KEY_AUTH_URL")
private_key = os.environ.get("JETBASE_SNOWFLAKE_PRIVATE_KEY")

Expand All @@ -57,6 +61,8 @@ def setup(self):
os.environ["JETBASE_SQLALCHEMY_URL"] = url
yield

_get_engine.cache_clear()

def test_get_db_connection_with_keypair_auth(self):
"""Test that get_db_connection works with Snowflake key pair authentication."""
with get_db_connection() as connection:
Expand Down Expand Up @@ -85,6 +91,7 @@ class TestSnowflakeEncryptedKeyPairAuth:
@pytest.fixture(autouse=True)
def setup(self):
"""Set up test environment for encrypted key pair auth."""
_get_engine.cache_clear()
url = os.environ.get("TEST_SF_KEY_AUTH_URL")
private_key = os.environ.get("TEST_SF_ENCRYPTED_PRIVATE_KEY")

Expand All @@ -100,6 +107,8 @@ def setup(self):

yield

_get_engine.cache_clear()

def test_get_db_connection_with_encrypted_keypair_auth(self):
"""Test that get_db_connection works with encrypted private key."""
password = os.environ.get("TEST_SF_PRIVATE_KEY_PASSWORD")
Expand Down
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.