Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 44 additions & 13 deletions jetbase/database/connection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
import logging
from contextlib import contextmanager
from typing import Any, Generator

from sqlalchemy import Connection, Engine, create_engine, text
from sqlalchemy.engine import make_url, URL
from sqlalchemy.engine import URL, make_url

from jetbase.config import get_config
from jetbase.database.queries.base import detect_db
from jetbase.enums import DatabaseType


@contextmanager
def _suppress_databricks_warnings():
"""
Temporarily sets the databricks logger to ERROR level to suppress
the deprecated _user_agent_entry warning coming from the databricks-sqlalchemy dependency.

Databricks-sqlalchemy is a dependency of databricks-sql-connector (which is triggering the warning), so we need to suppress the warning here until databricks-sqlalchemy is updated to fix the warning.
"""
logger = logging.getLogger("databricks")
original_level = logger.level
logger.setLevel(logging.ERROR)

try:
yield
finally:
logger.setLevel(original_level)


@contextmanager
def get_db_connection() -> Generator[Connection, None, None]:
"""
Expand Down Expand Up @@ -39,15 +58,21 @@ def get_db_connection() -> Generator[Connection, None, None]:

engine: Engine = create_engine(url=sqlalchemy_url, connect_args=connect_args)

with engine.begin() as connection:
if db_type == DatabaseType.POSTGRESQL:
postgres_schema: str | None = get_config().postgres_schema
if postgres_schema:
connection.execute(
text("SET search_path TO :postgres_schema"),
parameters={"postgres_schema": postgres_schema},
)
yield connection
if db_type == DatabaseType.DATABRICKS:
# Suppress databricks warnings during connection
with _suppress_databricks_warnings():
with engine.begin() as connection:
yield connection
else:
with engine.begin() as connection:
if db_type == DatabaseType.POSTGRESQL:
postgres_schema: str | None = get_config().postgres_schema
if postgres_schema:
connection.execute(
text("SET search_path TO :postgres_schema"),
parameters={"postgres_schema": postgres_schema},
)
yield connection


def _get_snowflake_private_key_der() -> bytes:
Expand All @@ -64,9 +89,15 @@ def _get_snowflake_private_key_der() -> bytes:
ValueError: If neither Snowflake private key nor password are set in configuration.
"""
# Lazy import - only needed for Snowflake key pair auth
from cryptography.hazmat.primitives import serialization # type: ignore[missing-import]
from cryptography.hazmat.backends import default_backend # type: ignore[missing-import]
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes # type: ignore[missing-import]
from cryptography.hazmat.backends import (
default_backend, # type: ignore[missing-import]
)
from cryptography.hazmat.primitives import (
serialization, # type: ignore[missing-import]
)
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes, # type: ignore[missing-import]
)

snowflake_private_key: str | None = get_config().snowflake_private_key

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "jetbase"
version = "0.17.0"
version = "0.17.1"
description = "Jetbase is a Python database migration tool"
readme = "README.md"
authors = [
Expand Down