From 67cc0c7276d51c130844fcd7f9a125770d81153a Mon Sep 17 00:00:00 2001 From: jaz Date: Thu, 29 Jan 2026 20:45:01 -0500 Subject: [PATCH 1/5] added db engine cache --- jetbase/database/connection.py | 75 +++++++++++++++++++++------------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/jetbase/database/connection.py b/jetbase/database/connection.py index 40dbdf2..7a2065c 100644 --- a/jetbase/database/connection.py +++ b/jetbase/database/connection.py @@ -1,5 +1,6 @@ import logging from contextlib import contextmanager +from functools import lru_cache from typing import Any, Generator from sqlalchemy import Connection, Engine, create_engine, text @@ -10,24 +11,6 @@ from jetbase.enums import DatabaseType -@contextmanager -def _suppress_databricks_warnings(): - """ - Temporarily sets the databricks logger to ERROR level to suppress - the deprecated _user_agent_entry warning coming from the databricks-sqlalchemy dependency. - - Databricks-sqlalchemy is a dependency of databricks-sql-connector (which is triggering the warning), so we need to suppress the warning here until databricks-sqlalchemy is updated to fix the warning. - """ - logger = logging.getLogger("databricks") - original_level = logger.level - logger.setLevel(logging.ERROR) - - try: - yield - finally: - logger.setLevel(original_level) - - @contextmanager def get_db_connection() -> Generator[Connection, None, None]: """ @@ -45,18 +28,9 @@ def get_db_connection() -> Generator[Connection, None, None]: >>> with get_db_connection() as conn: ... conn.execute(query) """ - sqlalchemy_url: str = get_config(required={"sqlalchemy_url"}).sqlalchemy_url - db_type: DatabaseType = detect_db(sqlalchemy_url=sqlalchemy_url) - - connect_args: dict[str, Any] = {} - - if db_type == DatabaseType.SNOWFLAKE: - snowflake_url: URL = make_url(sqlalchemy_url) - - if not snowflake_url.password: - connect_args["private_key"] = _get_snowflake_private_key_der() - engine: Engine = create_engine(url=sqlalchemy_url, connect_args=connect_args) + engine: Engine = _get_engine() + db_type: DatabaseType = detect_db(sqlalchemy_url=str(engine.url)) if db_type == DatabaseType.DATABRICKS: # Suppress databricks warnings during connection @@ -75,6 +49,31 @@ def get_db_connection() -> Generator[Connection, None, None]: yield connection +@lru_cache(maxsize=1) +def _get_engine() -> Engine: + """ + Get or create the singleton SQLAlchemy Engine. + + Creates the engine on first call and caches it for subsequent calls. + The engine manages its own connection pool internally. + + Returns: + Engine: A SQLAlchemy Engine instance. + """ + sqlalchemy_url: str = get_config(required={"sqlalchemy_url"}).sqlalchemy_url + db_type: DatabaseType = detect_db(sqlalchemy_url=sqlalchemy_url) + + connect_args: dict[str, Any] = {} + + if db_type == DatabaseType.SNOWFLAKE: + snowflake_url: URL = make_url(sqlalchemy_url) + + if not snowflake_url.password: + connect_args["private_key"] = _get_snowflake_private_key_der() + + return create_engine(url=sqlalchemy_url, connect_args=connect_args) + + def _get_snowflake_private_key_der() -> bytes: """ Retrieves the Snowflake private key in DER format for key pair authentication. @@ -124,3 +123,21 @@ def _get_snowflake_private_key_der() -> bytes: ) return private_key_bytes + + +@contextmanager +def _suppress_databricks_warnings(): + """ + Temporarily sets the databricks logger to ERROR level to suppress + the deprecated _user_agent_entry warning coming from the databricks-sqlalchemy dependency. + + Databricks-sqlalchemy is a dependency of databricks-sql-connector (which is triggering the warning), so we need to suppress the warning here until databricks-sqlalchemy is updated to fix the warning. + """ + logger = logging.getLogger("databricks") + original_level = logger.level + logger.setLevel(logging.ERROR) + + try: + yield + finally: + logger.setLevel(original_level) From 16b37fdb8de0f02451d5cfdba4af2b9b5b303d34 Mon Sep 17 00:00:00 2001 From: jaz Date: Thu, 29 Jan 2026 20:48:34 -0500 Subject: [PATCH 2/5] added db engine cache --- uv.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uv.lock b/uv.lock index 56f2a09..7e0ab3f 100644 --- a/uv.lock +++ b/uv.lock @@ -370,7 +370,7 @@ name = "exceptiongroup" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" } wheels = [ @@ -467,7 +467,7 @@ wheels = [ [[package]] name = "jetbase" -version = "0.16.0" +version = "0.17.1" source = { editable = "." } dependencies = [ { name = "packaging" }, From 532fd67bd2352f704b3250f7b513d5095c91642c Mon Sep 17 00:00:00 2001 From: jaz Date: Thu, 29 Jan 2026 20:59:23 -0500 Subject: [PATCH 3/5] clear cache for tests --- tests/integration/database/test_connection.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integration/database/test_connection.py b/tests/integration/database/test_connection.py index a4c7ab9..28c72e4 100644 --- a/tests/integration/database/test_connection.py +++ b/tests/integration/database/test_connection.py @@ -5,7 +5,7 @@ import pytest from sqlalchemy import text -from jetbase.database.connection import get_db_connection +from jetbase.database.connection import _get_engine, get_db_connection class TestSnowflakePasswordAuth: @@ -14,12 +14,15 @@ class TestSnowflakePasswordAuth: @pytest.fixture(autouse=True) def setup(self): """Set up test environment for password auth.""" + _get_engine.cache_clear() url = os.environ.get("TEST_SF_USER_PASS_URL") assert url is not None os.environ["JETBASE_SQLALCHEMY_URL"] = url yield + _get_engine.cache_clear() + def test_get_db_connection_with_password_auth(self): """Test that get_db_connection works with Snowflake password authentication.""" with get_db_connection() as connection: From 438b0fb02da388341d56ed4d4c539e7cb4ff48d5 Mon Sep 17 00:00:00 2001 From: jaz Date: Thu, 29 Jan 2026 21:06:41 -0500 Subject: [PATCH 4/5] clear db cache for tests --- tests/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 7aa78c4..47748b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from sqlalchemy import create_engine, text from typer.testing import CliRunner +from jetbase.database.connection import _get_engine from jetbase.database.queries.base import detect_db from jetbase.enums import DatabaseType @@ -121,6 +122,7 @@ def setup_migrations_versions_only(tmp_path, migrations_versions_only_fixture_di @pytest.fixture def clean_db(test_db_url): """Clean up database before and after tests.""" + _get_engine.cache_clear() engine = create_engine(test_db_url) def cleanup(): @@ -133,3 +135,5 @@ def cleanup(): yield engine cleanup() engine.dispose() + + _get_engine.cache_clear() From b03cef5b8f9ba694ea0584be89354fdb10ef6f56 Mon Sep 17 00:00:00 2001 From: jaz Date: Thu, 29 Jan 2026 21:10:07 -0500 Subject: [PATCH 5/5] clear cache for tests --- tests/integration/database/test_connection.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/integration/database/test_connection.py b/tests/integration/database/test_connection.py index 28c72e4..71a901f 100644 --- a/tests/integration/database/test_connection.py +++ b/tests/integration/database/test_connection.py @@ -51,6 +51,7 @@ class TestSnowflakeKeyPairAuth: @pytest.fixture(autouse=True) def setup(self): """Set up test environment for key pair auth.""" + _get_engine.cache_clear() url = os.environ.get("TEST_SF_KEY_AUTH_URL") private_key = os.environ.get("JETBASE_SNOWFLAKE_PRIVATE_KEY") @@ -60,6 +61,8 @@ def setup(self): os.environ["JETBASE_SQLALCHEMY_URL"] = url yield + _get_engine.cache_clear() + def test_get_db_connection_with_keypair_auth(self): """Test that get_db_connection works with Snowflake key pair authentication.""" with get_db_connection() as connection: @@ -88,6 +91,7 @@ class TestSnowflakeEncryptedKeyPairAuth: @pytest.fixture(autouse=True) def setup(self): """Set up test environment for encrypted key pair auth.""" + _get_engine.cache_clear() url = os.environ.get("TEST_SF_KEY_AUTH_URL") private_key = os.environ.get("TEST_SF_ENCRYPTED_PRIVATE_KEY") @@ -103,6 +107,8 @@ def setup(self): yield + _get_engine.cache_clear() + def test_get_db_connection_with_encrypted_keypair_auth(self): """Test that get_db_connection works with encrypted private key.""" password = os.environ.get("TEST_SF_PRIVATE_KEY_PASSWORD")