feat: Add native SQLAlchemy dialect support #136
Conversation
Reviewer's Guide

Implements a native SQLAlchemy dialect for PrestoDB by adding a new prestodb.sqlalchemy package (dialect, compiler, and datatype mappings), wiring it into SQLAlchemy's entry points and optional dependencies, updating the DBAPI paramstyle for SQLAlchemy compatibility, and adding docs plus unit/integration tests for dialect behavior and reflection.

Sequence diagram for SQLAlchemy engine usage with PrestoDialect

```mermaid
sequenceDiagram
    actor User
    participant SA as SQLAlchemy
    participant Eng as Engine
    participant Dial as PrestoDialect
    participant DBAPI as prestodb.dbapi
    participant Auth as prestodb.auth.BasicAuthentication
    participant Presto as PrestoDB
    User->>SA: create_engine("presto://user:pass@host:8080/catalog/schema")
    SA->>Dial: initialize dialect
    SA->>Eng: construct Engine with PrestoDialect
    User->>Eng: connect()
    Eng->>Dial: create_connect_args(url)
    Dial->>Auth: BasicAuthentication(user, password)
    Dial-->>Eng: ([connect_args], {})
    Eng->>DBAPI: connect(**connect_args)
    DBAPI-->>Eng: DBAPI connection (paramstyle=pyformat)
    User->>Eng: execute("SELECT 1")
    Eng->>Dial: use PrestoSQLCompiler and PrestoTypeCompiler
    Eng->>DBAPI: cursor.execute(compiled_sql, parameters)
    DBAPI->>Presto: HTTP request
    Presto-->>DBAPI: query result
    DBAPI-->>Eng: rows
    Eng-->>User: result proxy
```
Class diagram for the new Presto SQLAlchemy dialect and types

```mermaid
classDiagram
direction LR
%% External SQLAlchemy base classes
class DefaultDialect {
}
class SQLCompiler {
}
class GenericTypeCompiler {
}
class IdentifierPreparer {
}
class TypeEngine {
}
class Float {
}
class Boolean {
}
class Integer {
}
class BigInteger {
}
class DECIMAL {
}
class String {
}
class LargeBinary {
}
class JSON {
}
class Date {
}
class Time {
}
class TIMESTAMP {
}
%% Compiler module
class PrestoSQLCompiler {
+visit_char_length_func(fn, kw)
+limit_clause(select, kw)
+visit_lambda_element(element, kw)
}
class PrestoTypeCompiler {
+visit_DOUBLE(type_, kw)
+visit_REAL(type_, kw)
+visit_TINYINT(type_, kw)
+visit_SMALLINT(type_, kw)
+visit_INTEGER(type_, kw)
+visit_BIGINT(type_, kw)
+visit_VARCHAR(type_, kw)
+visit_CHAR(type_, kw)
+visit_VARBINARY(type_, kw)
+visit_JSON(type_, kw)
+visit_FLOAT(type_, kw)
+visit_NUMERIC(type_, kw)
+visit_DECIMAL(type_, kw)
+visit_DATE(type_, kw)
+visit_TIME(type_, kw)
+visit_TIMESTAMP(type_, kw)
+visit_DATETIME(type_, kw)
+visit_CLOB(type_, kw)
+visit_NCLOB(type_, kw)
+visit_TEXT(type_, kw)
+visit_BLOB(type_, kw)
+visit_BOOLEAN(type_, kw)
+visit_ARRAY(type_, kw)
+visit_MAP(type_, kw)
+visit_ROW(type_, kw)
+visit_HYPERLOGLOG(type_, kw)
+visit_QDIGEST(type_, kw)
+visit_P4HYPERLOGLOG(type_, kw)
}
class PrestoIdentifierPreparer {
+reserved_words
}
SQLCompiler <|-- PrestoSQLCompiler
GenericTypeCompiler <|-- PrestoTypeCompiler
IdentifierPreparer <|-- PrestoIdentifierPreparer
%% Datatype module
class DOUBLE {
+__visit_name__
}
class REAL {
+__visit_name__
}
class BOOLEAN {
+__visit_name__
}
class TINYINT {
+__visit_name__
}
class SMALLINT {
+__visit_name__
}
class INTEGER {
+__visit_name__
}
class BIGINT {
+__visit_name__
}
class PrestoDECIMAL {
+__visit_name__
}
class VARCHAR {
+__visit_name__
}
class CHAR {
+__visit_name__
}
class VARBINARY {
+__visit_name__
}
class PrestoJSON {
+__visit_name__
}
class PrestoDATE {
+__visit_name__
}
class PrestoTIME {
+__visit_name__
}
class PrestoTIMESTAMP {
+__visit_name__
}
class INTERVAL {
+__visit_name__
+start
+end
+precision
+__init__(start, end, precision)
}
class ARRAY {
+__visit_name__
+item_type
+__init__(item_type)
}
class MAP {
+__visit_name__
+key_type
+value_type
+__init__(key_type, value_type)
}
class ROW {
+__visit_name__
+attr_types
+__init__(attr_types)
}
class HYPERLOGLOG {
+__visit_name__
}
class QDIGEST {
+__visit_name__
}
class P4HYPERLOGLOG {
+__visit_name__
}
Float <|-- DOUBLE
Float <|-- REAL
Boolean <|-- BOOLEAN
Integer <|-- TINYINT
Integer <|-- SMALLINT
Integer <|-- INTEGER
BigInteger <|-- BIGINT
DECIMAL <|-- PrestoDECIMAL
String <|-- VARCHAR
String <|-- CHAR
LargeBinary <|-- VARBINARY
JSON <|-- PrestoJSON
Date <|-- PrestoDATE
Time <|-- PrestoTIME
TIMESTAMP <|-- PrestoTIMESTAMP
TypeEngine <|-- INTERVAL
TypeEngine <|-- ARRAY
TypeEngine <|-- MAP
TypeEngine <|-- ROW
TypeEngine <|-- HYPERLOGLOG
TypeEngine <|-- QDIGEST
TypeEngine <|-- P4HYPERLOGLOG
%% Dialect class
class PrestoDialect {
+name
+driver
+author
+supports_alter
+supports_pk_on_update
+supports_full_outer_join
+supports_simple_order_by_label
+supports_sane_rowcount
+supports_sane_multi_rowcount
+supports_native_boolean
+statement_compiler
+type_compiler
+preparer
+create_connect_args(url)
+import_dbapi()
+has_table(connection, table_name, schema)
+has_sequence(connection, sequence_name, schema)
+get_schema_names(connection, kw)
+get_table_names(connection, schema, kw)
+get_columns(connection, table_name, schema, kw)
+do_rollback(dbapi_connection)
+get_foreign_keys(connection, table_name, schema, kw)
+get_pk_constraint(connection, table_name, schema, kw)
+get_indexes(connection, table_name, schema, kw)
+do_ping(dbapi_connection)
-_has_object(connection, object_type, object_name, schema)
-_parse_type(type_str)
-_parse_type_args(type_args)
}
DefaultDialect <|-- PrestoDialect
PrestoDialect --> PrestoSQLCompiler : uses
PrestoDialect --> PrestoTypeCompiler : uses
PrestoDialect --> PrestoIdentifierPreparer : uses
PrestoDialect --> DOUBLE : maps types
PrestoDialect --> ARRAY : maps types
PrestoDialect --> MAP : maps types
PrestoDialect --> ROW : maps types
```
Here is the verification script used to validate the changes:

```python
from sqlalchemy import create_engine, select, Column, Integer, String, MetaData, Table
from sqlalchemy.dialects import registry

from prestodb.sqlalchemy.base import PrestoDialect
from prestodb.sqlalchemy import datatype

# Manually register for testing context if setup.py hasn't been re-installed yet
registry.register("presto", "prestodb.sqlalchemy.base", "PrestoDialect")


def test_engine_creation():
    print("Testing Engine Creation...")
    url = "presto://user:password@localhost:8080/catalog/schema"
    engine = create_engine(url)
    print(f"Engine Dialect: {engine.dialect.name}")
    assert engine.dialect.name == "presto"

    # Test connect args parsing
    dialect = PrestoDialect()
    args, kwargs = dialect.create_connect_args(engine.url)
    print(f"Connect Args: {args}")
    connect_params = args[0]
    assert connect_params["host"] == "localhost"
    assert connect_params["port"] == 8080
    assert connect_params["user"] == "user"
    assert connect_params["catalog"] == "catalog"
    assert connect_params["schema"] == "schema"
    print("PASS: Engine Creation and URL Parsing\n")


def test_sql_compilation():
    print("Testing SQL Compilation...")
    engine = create_engine("presto://localhost:8080/catalog/schema")
    metadata = MetaData()
    t = Table(
        'test_table', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String),
        Column('active', datatype.BOOLEAN()),  # Test using mapped typed datatype
    )
    stmt = select(t).where(t.c.id == 1)
    compiled = stmt.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})
    print(f"Compiled SQL: {compiled}")

    # Basic check - Presto uses standard SQL for this so it should look familiar
    assert "SELECT" in str(compiled)
    assert "FROM test_table" in str(compiled)
    assert "WHERE test_table.id = 1" in str(compiled)  # Literal binding might vary
    print("PASS: SQL Compilation\n")


if __name__ == "__main__":
    try:
        test_engine_creation()
        test_sql_compilation()
        print("ALL TESTS PASSED")
    except Exception as e:
        print(f"TEST FAILED: {e}")
        import traceback
        traceback.print_exc()
```
Hey - I've found 3 security issues, 3 other issues, and left some high level feedback:
Security issues:
- Avoiding SQL string concatenation: untrusted input concatenated with a raw SQL query can result in SQL injection. To execute a raw query safely, use a prepared statement; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language. In most cases the SQLAlchemy ORM is a better option. (Flagged at three locations; see Comments 4–6 below.)
General comments:
- The dialect never sets `default_schema_name`, yet several reflection methods rely on it (and even raise if `schema` is `None`), so it would be good to define a sensible default (e.g., `default`) or derive it from the connection to avoid unexpected `ValueError`s and make `has_table`/`get_columns` behave more naturally.
- The `_parse_type`/`_parse_type_args` implementation currently splits on commas naively and will misparse nested complex types like `array(map(varchar, bigint))`; consider using a simple tokenizer or recursive descent for arrays/maps/rows so the SQLAlchemy types actually line up with the complex Presto types you’ve exposed (a minimal splitter sketch follows this list).
- The new `.gitignore` file is added but appears empty in the diff; either populate it with intended patterns or drop it from this change to avoid a no-op repository file.
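For illustration, a minimal sketch of the top-level split such a tokenizer needs (a standalone helper; the name `split_top_level` is hypothetical, not the PR's code):

```python
def split_top_level(args):
    """Split type arguments on commas that are not inside parentheses."""
    parts, depth, start = [], 0, 0
    for i, ch in enumerate(args):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == "," and depth == 0:
            parts.append(args[start:i].strip())
            start = i + 1
    parts.append(args[start:].strip())
    return parts

# Nested types stay intact instead of being split mid-way:
assert split_top_level("map(varchar, bigint), integer") == ["map(varchar, bigint)", "integer"]
```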
## Individual Comments
### Comment 1
<location> `prestodb/sqlalchemy/base.py:152` </location>
<code_context>
+
+ def _parse_type(self, type_str):
+ type_str = type_str.lower()
+ match = util.re.match(r"^([a-zA-Z0-9_]+)(\((.+)\))?$", type_str)
+ if not match:
+ return sqltypes.NullType()
</code_context>
<issue_to_address>
**issue (bug_risk):** Multi-word type names like "timestamp with time zone" will never match this regex and fall back to NullType.
The `_type_map` contains multi-word keys like `"time with time zone"` and `"timestamp with time zone"`, but this regex only matches a single token `([a-zA-Z0-9_]+)`, so any spaced type name will miss and default to `NullType`. Please either broaden the regex to handle spaces in the base type (or pre-handle these multi-word patterns) or add explicit special-case handling for the known multi-word types.
</issue_to_address>
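A hedged sketch of one broadened pattern (standalone; `parse_type_name` is a hypothetical helper, not the PR's `_parse_type`):

```python
import re

# Allow spaces in the base name so multi-word Presto types such as
# "timestamp with time zone" match instead of falling back to NullType.
TYPE_RE = re.compile(r"^([a-z0-9_]+(?: [a-z0-9_]+)*)(?:\((.+)\))?$")

def parse_type_name(type_str):
    match = TYPE_RE.match(type_str.strip().lower())
    if match is None:
        return None  # caller would fall back to sqltypes.NullType()
    return match.group(1), match.group(2)

assert parse_type_name("timestamp with time zone") == ("timestamp with time zone", None)
assert parse_type_name("decimal(10, 2)") == ("decimal", "10, 2")
```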
### Comment 2
<location> `prestodb/sqlalchemy/base.py:168` </location>
<code_context>
+
+ def _parse_type_args(self, type_args):
+ # TODO: improve parsing for nested types
+ return [int(a) if a.isdigit() else a for a in type_args.split(",")]
+
+ def do_rollback(self, dbapi_connection):
</code_context>
<issue_to_address>
**issue (bug_risk):** Type arguments are not stripped, so values with spaces (e.g. DECIMAL(10, 2)) will be misparsed.
Because `type_args.split(',')` leaves whitespace intact, the second argument in `DECIMAL(10, 2)` becomes `' 2'`, which fails `isdigit()` and is passed through as a string with a leading space. This will likely break when constructing the type. Stripping each token (e.g. `a.strip()`) before checking `isdigit()` would avoid this and handle common spaced declarations correctly.
</issue_to_address>
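A minimal sketch of the stripped version (standalone function; the name is illustrative):

```python
def parse_type_args(type_args):
    # Strip whitespace before the digit check so "10, 2" parses to
    # [10, 2] rather than [10, " 2"].
    parts = [a.strip() for a in type_args.split(",")]
    return [int(a) if a.isdigit() else a for a in parts]

assert parse_type_args("10, 2") == [10, 2]
assert parse_type_args("varchar, bigint") == ["varchar", "bigint"]
```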
### Comment 3
<location> `prestodb/sqlalchemy/base.py:103-104` </location>
<code_context>
+ return False
+
+ def _has_object(self, connection, object_type, object_name, schema=None):
+ if schema is None:
+ schema = connection.engine.dialect.default_schema_name
+
+ return (
</code_context>
<issue_to_address>
**issue (bug_risk):** Defaulting to `default_schema_name` may yield an invalid schema when it is None, causing incorrect has_table results.
When `default_schema_name` is None (the default in `DefaultDialect`), this will query `information_schema.tables` with `table_schema = 'None'`, causing `has_table` to incorrectly return False. Consider either requiring `schema` (and raising if missing), or using a concrete default (e.g. a known schema name or the current Presto schema) instead of relying on `default_schema_name` being set.
</issue_to_address>
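A minimal sketch of the "raise if missing" option (method body only; this is one of the two alternatives named above, not the PR's code):

```python
def _has_object(self, connection, object_type, object_name, schema=None):
    # Sketch: fail loudly instead of letting a None schema reach
    # information_schema as the literal string 'None'.
    schema = schema or self.default_schema_name
    if schema is None:
        raise ValueError(
            "schema is required: no default_schema_name is set on the dialect"
        )
    ...  # proceed with the information_schema lookup as before
```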
### Comment 4
<location> `prestodb/sqlalchemy/base.py:107-112` </location>
<code_context>
connection.execute(
"SELECT count(*) FROM information_schema.tables "
"WHERE table_schema = '{}' AND table_name = '{}'".format(
schema, object_name
)
).scalar()
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated with a raw SQL query can result in SQL injection. To execute a raw query safely, use a prepared statement; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language. In most cases the SQLAlchemy ORM is a better option.
*Source: opengrep*
</issue_to_address>
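As a hedged illustration of the TextualSQL suggestion (a sketch, not the PR's code; the function name is hypothetical):

```python
from sqlalchemy import text

def count_matching_tables(connection, schema, object_name):
    # Bind values as named parameters rather than interpolating them
    # into the SQL string.
    query = text(
        "SELECT count(*) FROM information_schema.tables "
        "WHERE table_schema = :schema AND table_name = :table"
    )
    return connection.execute(
        query, {"schema": schema, "table": object_name}
    ).scalar()
```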
### Comment 5
<location> `prestodb/sqlalchemy/base.py:124` </location>
<code_context>
result = connection.execute("SHOW TABLES FROM {}".format(schema))
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated with a raw SQL query can result in SQL injection. To execute a raw query safely, use a prepared statement; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language. In most cases the SQLAlchemy ORM is a better option.
*Source: opengrep*
</issue_to_address>
### Comment 6
<location> `prestodb/sqlalchemy/base.py:132` </location>
<code_context>
result = connection.execute(query)
</code_context>
<issue_to_address>
**security (python.sqlalchemy.security.sqlalchemy-execute-raw-query):** Avoid SQL string concatenation: untrusted input concatenated with a raw SQL query can result in SQL injection. To execute a raw query safely, use a prepared statement; SQLAlchemy provides TextualSQL for prepared statements with named parameters. For complex SQL composition, use the SQL Expression Language or Schema Definition Language. In most cases the SQLAlchemy ORM is a better option.
*Source: opengrep*
</issue_to_address>
@sourcery-ai review
Hey - I've found 5 issues, and left some high level feedback:
- The dialect relies on `default_schema_name` in `has_table` / `_has_object`, but never sets it (and `create_connect_args` doesn’t populate a default schema), so calls without an explicit schema will either send `NULL` to information_schema or behave inconsistently; consider defining a sensible `default_schema_name` and aligning reflection with the connection’s default schema/catalog handling.
- In `create_connect_args`, automatically forcing `http_scheme = "https"` whenever a password is present may be surprising for users who explicitly want HTTP; consider honoring an explicit `http_scheme` in the URL/query string and only defaulting to HTTPS when nothing is specified.
- The `_has_object` helper takes an `object_type` argument but doesn’t use it in the query, effectively hard-coding table checks; either incorporate `object_type` into the information_schema query or simplify the signature to reflect that only tables are currently supported.
## Individual Comments
### Comment 1
<location> `setup.py:34` </location>
<code_context>
-all_require = [kerberos_require, google_auth_require]
+sqlalchemy_require = ["sqlalchemy"]
+
+all_require = [kerberos_require, google_auth_require, sqlalchemy_require]
tests_require = all_require + ["httpretty", "pytest", "pytest-runner"]
</code_context>
<issue_to_address>
**suggestion (bug_risk):** Consider flattening `all_require` so it is a list of strings rather than a list of lists.
`all_require` is currently a list of lists (e.g. `[kerberos_require, google_auth_require, sqlalchemy_require]`), but `extras_require["all"]` should be a flat list of requirement strings. It would be safer to build `all_require` as a concatenated list, e.g. `all_require = kerberos_require + google_auth_require + sqlalchemy_require`, to avoid potential install/metadata issues with pip.
```suggestion
all_require = kerberos_require + google_auth_require + sqlalchemy_require
```
</issue_to_address>
### Comment 2
<location> `prestodb/sqlalchemy/base.py:118-120` </location>
<code_context>
+ > 0
+ )
+
+ def get_schema_names(self, connection, **kw):
+ result = connection.execute("SELECT schema_name FROM information_schema.schemata")
+ return [row[0] for row in result]
+
</code_context>
<issue_to_address>
**suggestion:** Use `sqlalchemy.text()` for the schema query to improve compatibility with newer SQLAlchemy versions.
Elsewhere in this dialect you already wrap SQL strings with `text()`. To keep this consistent and compatible with SQLAlchemy 1.4/2.x, consider `connection.execute(text("SELECT schema_name FROM information_schema.schemata"))` so you don’t rely on implicit textual SQL handling, which is being tightened in newer releases.
```suggestion
    def get_schema_names(self, connection, **kw):
        result = connection.execute(
            text("SELECT schema_name FROM information_schema.schemata")
        )
        return [row[0] for row in result]
```
</issue_to_address>
### Comment 3
<location> `tests/test_sqlalchemy.py:38-44` </location>
<code_context>
+from sqlalchemy.types import Integer, String
+from integration_tests.fixtures import run_presto
+
+@pytest.fixture
+def sqlalchemy_engine(run_presto):
+ _, host, port = run_presto
</code_context>
<issue_to_address>
**suggestion (testing):** The `mock_dbapi` fixture is currently unused; consider either using it to test DBAPI integration or removing it.
Current tests only instantiate `PrestoDialect` and never exercise `import_dbapi`/`dbapi.connect`. This fixture would be a good place to add a test that `create_engine` (or the dialect) calls `prestodb.dbapi.connect` with the expected arguments, confirming the DBAPI integration and `paramstyle` wiring. If you don’t intend to add such a test, removing the unused fixture would reduce confusion.
</issue_to_address>
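For illustration, a hedged sketch of such a test (the URL is a placeholder, and the shape of `connect`'s arguments follows the `create_connect_args` contract shown in the verification script above):

```python
from unittest import mock

from sqlalchemy import create_engine

import prestodb.dbapi

def test_engine_calls_dbapi_connect():
    # Patch the DBAPI entry point and confirm the dialect routes
    # connection attempts through prestodb.dbapi.connect.
    with mock.patch.object(prestodb.dbapi, "connect") as connect:
        engine = create_engine("presto://user@localhost:8080/catalog/schema")
        try:
            engine.connect()
        except Exception:
            pass  # a Mock connection won't survive SQLAlchemy's handshake
        assert connect.called
        # Per create_connect_args, the parsed URL parameters travel as the
        # first positional argument.
        params = connect.call_args.args[0]
        assert params["host"] == "localhost"
```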
### Comment 4
<location> `integration_tests/test_sqlalchemy_integration.py:30-39` </location>
<code_context>
+ row = result.fetchone()
+ assert row is not None
+
+def test_sqlalchemy_reflection(sqlalchemy_engine):
+ # This requires tables to exist.
+ # tpch is usually available in the test environment (referenced in test_dbapi.py)
+ insp = inspect(sqlalchemy_engine)
+
+ # Check schemas
+ schemas = insp.get_schema_names()
+ assert "sys" in schemas or "system" in schemas
+
+ # Check tables in a specific schema (e.g. system.runtime)
+ tables = insp.get_table_names(schema="system")
+ assert "nodes" in tables or "runtime.nodes" in tables # Representation might vary
+
+def test_sqlalchemy_orm_basic(sqlalchemy_engine):
</code_context>
<issue_to_address>
**issue (testing):** The reflection expectations around `system.runtime.nodes` may not align with how catalog/schema are modeled and could be flaky.
This test currently depends on specific catalog/schema behavior that may not match how Presto models `system.runtime.nodes` (`system` as catalog, `runtime` as schema). With the fixture URL pointing at a non‑`system` catalog, these assumptions can easily fail and make the test environment‑dependent.
To reduce brittleness, align the assertions with the dialect’s actual layout, for example by:
- Connecting explicitly to the `system` catalog when you want to inspect `system.runtime.nodes`.
- Asserting the expected schema (`"runtime"`) and table (`"nodes"`) separately rather than treating `"system"` as the schema and allowing a schema‑qualified table name.
This will better validate the reflection behavior and avoid flaky test failures across environments.
</issue_to_address>
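A hedged sketch of that realignment (the URL is illustrative; the `system.runtime.nodes` layout is as described above):

```python
from sqlalchemy import create_engine, inspect

# Connect to the system catalog explicitly, so "runtime" is the schema
# and "nodes" the table.
engine = create_engine("presto://user@localhost:8080/system")
insp = inspect(engine)

assert "runtime" in insp.get_schema_names()
assert "nodes" in insp.get_table_names(schema="runtime")
```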
### Comment 5
<location> `integration_tests/test_sqlalchemy_integration.py:43-52` </location>
<code_context>
+ tables = insp.get_table_names(schema="system")
+ assert "nodes" in tables or "runtime.nodes" in tables # Representation might vary
+
+def test_sqlalchemy_orm_basic(sqlalchemy_engine):
+ # Basic table definition
+ metadata = MetaData()
+ # we use a known table from tpch to avoid needing CREATE TABLE rights or persistence
+ # tpch.sf1.customer
+ # but that might be read-only.
+
+ # For integration test without write access, we typically verify SELECTs
+ # If we need to write, we arguably should rely on the test_dbapi.py establishing environment
+
+ with sqlalchemy_engine.connect() as conn:
+ result = conn.execute(text("SELECT count(*) FROM tpch.sf1.customer"))
+ count = result.scalar()
+ assert count > 0
</code_context>
<issue_to_address>
**suggestion (testing):** `test_sqlalchemy_orm_basic` doesn’t exercise ORM constructs and depends on `tpch.sf1.customer` existing, which may be brittle.
This test only runs a raw `text()` query on `tpch.sf1.customer` despite the ORM-style name/imports. Consider either renaming it to reflect that it’s a basic query test, or actually defining a mapped table/metadata and issuing a `select()` using SQLAlchemy objects to exercise the ORM/compiler more thoroughly.
Also, the hard dependency on `tpch.sf1.customer` can make the test flaky if that schema isn’t present. It would be more robust to use a table guaranteed by the integration fixtures or a temporary table created in test setup so the test is self-contained.
Suggested implementation:
```python
from sqlalchemy import MetaData, Table, inspect, select
from sqlalchemy.orm import registry, sessionmaker


def test_sqlalchemy_orm_basic(sqlalchemy_engine):
    # Use ORM constructs against a table that we know exists in the "system" schema
    metadata = MetaData()
    insp = inspect(sqlalchemy_engine)
    tables = insp.get_table_names(schema="system")
    assert tables, "Expected at least one table in the 'system' schema for ORM test"

    # Prefer the "nodes" table if available, otherwise fall back to the first table
    table_name = "nodes" if "nodes" in tables else tables[0]
    system_table = Table(
        table_name,
        metadata,
        schema="system",
        autoload_with=sqlalchemy_engine,
    )

    mapper_registry = registry()

    class SystemRow:
        pass

    mapper_registry.map_imperatively(SystemRow, system_table)
    Session = sessionmaker(bind=sqlalchemy_engine)

    # Exercise ORM-style select using a Session and mapped class
    with Session() as session:
        result = session.execute(select(SystemRow).limit(1)).first()
        # We don't assert on specific columns or values, just that we can read at least one row
        assert result is not None
```
If `MetaData`, `select`, `Table`, `registry`, or `sessionmaker` are already imported elsewhere in this file, you should remove the duplicated import lines I added and keep a single, consolidated import block following your existing conventions.
</issue_to_address>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
Description
This PR implements a native SQLAlchemy dialect for `presto-python-client`. This feature allows users to connect to PrestoDB using standard SQLAlchemy patterns without relying on external or legacy libraries like `PyHive`.

Closes #135.
Changes
- New package `prestodb.sqlalchemy` containing the dialect implementation.
  - `base.py`: Main `PrestoDialect` implementation handling connection args and reflection.
  - `datatype.py`: Complete mapping of Presto types to SQLAlchemy types, including complex types (`ARRAY`, `MAP`, `ROW`).
  - `compiler.py`: Custom SQL compiler for Presto-specific syntax.
- Added `sqlalchemy` as an optional dependency in `setup.py` (`extras_require`).
- Added `paramstyle = 'pyformat'` to `prestodb/dbapi.py` for PEP 249/SQLAlchemy compliance.
- Updated `README.md` with installation and usage instructions.

Attribution:
This implementation aligns with the architectural patterns used in `trino-python-client` and includes attribution in the source files.

Usage
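A minimal sketch of the intended usage, reusing the URL format from the verification script (host, credentials, and catalog/schema are placeholders):

```python
from sqlalchemy import create_engine, text

# URL format: presto://user:password@host:port/catalog/schema
engine = create_engine("presto://user:secret@localhost:8080/hive/default")

with engine.connect() as conn:
    rows = conn.execute(text("SELECT 1")).fetchall()
    print(rows)
```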
Verification
Verified locally with a test script (shown above) covering engine creation, URL parsing, and SQL compilation.