diff --git a/jetbase/database/connection.py b/jetbase/database/connection.py index 78eb5dc..40dbdf2 100644 --- a/jetbase/database/connection.py +++ b/jetbase/database/connection.py @@ -1,14 +1,33 @@ +import logging from contextlib import contextmanager from typing import Any, Generator from sqlalchemy import Connection, Engine, create_engine, text -from sqlalchemy.engine import make_url, URL +from sqlalchemy.engine import URL, make_url from jetbase.config import get_config from jetbase.database.queries.base import detect_db from jetbase.enums import DatabaseType +@contextmanager +def _suppress_databricks_warnings(): + """ + Temporarily sets the databricks logger to ERROR level to suppress + the deprecated _user_agent_entry warning coming from the databricks-sqlalchemy dependency. + + Databricks-sqlalchemy is a dependency of databricks-sql-connector (which is triggering the warning), so we need to suppress the warning here until databricks-sqlalchemy is updated to fix the warning. + """ + logger = logging.getLogger("databricks") + original_level = logger.level + logger.setLevel(logging.ERROR) + + try: + yield + finally: + logger.setLevel(original_level) + + @contextmanager def get_db_connection() -> Generator[Connection, None, None]: """ @@ -39,15 +58,21 @@ def get_db_connection() -> Generator[Connection, None, None]: engine: Engine = create_engine(url=sqlalchemy_url, connect_args=connect_args) - with engine.begin() as connection: - if db_type == DatabaseType.POSTGRESQL: - postgres_schema: str | None = get_config().postgres_schema - if postgres_schema: - connection.execute( - text("SET search_path TO :postgres_schema"), - parameters={"postgres_schema": postgres_schema}, - ) - yield connection + if db_type == DatabaseType.DATABRICKS: + # Suppress databricks warnings during connection + with _suppress_databricks_warnings(): + with engine.begin() as connection: + yield connection + else: + with engine.begin() as connection: + if db_type == DatabaseType.POSTGRESQL: + postgres_schema: str | None = get_config().postgres_schema + if postgres_schema: + connection.execute( + text("SET search_path TO :postgres_schema"), + parameters={"postgres_schema": postgres_schema}, + ) + yield connection def _get_snowflake_private_key_der() -> bytes: @@ -64,9 +89,15 @@ def _get_snowflake_private_key_der() -> bytes: ValueError: If neither Snowflake private key nor password are set in configuration. """ # Lazy import - only needed for Snowflake key pair auth - from cryptography.hazmat.primitives import serialization # type: ignore[missing-import] - from cryptography.hazmat.backends import default_backend # type: ignore[missing-import] - from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes # type: ignore[missing-import] + from cryptography.hazmat.backends import ( + default_backend, # type: ignore[missing-import] + ) + from cryptography.hazmat.primitives import ( + serialization, # type: ignore[missing-import] + ) + from cryptography.hazmat.primitives.asymmetric.types import ( + PrivateKeyTypes, # type: ignore[missing-import] + ) snowflake_private_key: str | None = get_config().snowflake_private_key diff --git a/pyproject.toml b/pyproject.toml index f11f6c1..2f03060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "jetbase" -version = "0.17.0" +version = "0.17.1" description = "Jetbase is a Python database migration tool" readme = "README.md" authors = [