From 725cde68a8f2660c0281f3c3199b4af5ac714a61 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 8 Nov 2025 22:54:28 +0900 Subject: [PATCH] Add JSON type support for SQLAlchemy dialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement JSON type support to resolve GitHub Issue #619. This enables proper handling of JSON columns when reflecting Iceberg tables and querying JSON data with SQLAlchemy. Changes: - Add visit_JSON method to AthenaTypeCompiler - Update ischema_names mapping from String to JSON type - Add unit test for JSON type compilation - Add integration test for JSON type with CAST operations - Update documentation with JSON Type Support section - Reorganize documentation structure (move Query Execution Callback section) Limitations: - Athena only supports JSON objects, not top-level arrays - JSON type is for DML operations (SELECT), not DDL (CREATE TABLE) Fixes #619 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- docs/sqlalchemy.rst | 464 ++++++++++++--------- pyathena/sqlalchemy/base.py | 2 +- pyathena/sqlalchemy/compiler.py | 3 + tests/pyathena/sqlalchemy/test_base.py | 67 ++- tests/pyathena/sqlalchemy/test_compiler.py | 10 + 5 files changed, 351 insertions(+), 195 deletions(-) diff --git a/docs/sqlalchemy.rst b/docs/sqlalchemy.rst index d3358132..f55fd81c 100644 --- a/docs/sqlalchemy.rst +++ b/docs/sqlalchemy.rst @@ -303,6 +303,224 @@ or :code:`table_name$history` metadata. Again the hint goes after the select sta SELECT * FROM table_name FOR VERSION AS OF 949530903748831860 +.. _sqlalchemy-query-execution-callback: + +Query Execution Callback +------------------------- + +PyAthena provides callback support for SQLAlchemy applications to get immediate access to query IDs +after the ``start_query_execution`` API call, enabling query monitoring and cancellation capabilities. 
+ +Connection-level callback +~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can set a default callback for all queries through an engine's connection parameters: + +.. code:: python + + from sqlalchemy import create_engine, text + + def query_callback(query_id): + print(f"SQLAlchemy query started: {query_id}") + # Store query_id for monitoring or cancellation + + conn_str = "awsathena+rest://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/" + engine = create_engine( + conn_str, + connect_args={"on_start_query_execution": query_callback} + ) + + with engine.connect() as connection: + result = connection.execute(text("SELECT * FROM many_rows")) + # query_callback is invoked with the query ID right after the query starts + +Execution options callback +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +SQLAlchemy applications can use ``execution_options`` to specify callbacks for individual queries: + +.. code:: python + + from sqlalchemy import create_engine, text + + def specific_callback(query_id): + print(f"Specific query callback: {query_id}") + + conn_str = "awsathena+rest://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/" + engine = create_engine(conn_str) + + with engine.connect() as connection: + result = connection.execute( + text("SELECT * FROM many_rows").execution_options( + on_start_query_execution=specific_callback + ) + ) + +Query timeout management with SQLAlchemy +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A practical example for managing long-running analytical queries with timeout: + +.. 
code:: python + + import time + from concurrent.futures import ThreadPoolExecutor, TimeoutError + from sqlalchemy import create_engine, text + + def run_analytics_with_timeout(): + """Run analytics query with automatic timeout and cancellation.""" + + query_info = {'query_id': None, 'connection': None} + + def track_query_start(query_id): + query_info['query_id'] = query_id + print(f"Analytics query started: {query_id}") + + def timeout_monitor(timeout_minutes): + """Cancel query after timeout period.""" + time.sleep(timeout_minutes * 60) + if query_info['query_id'] and query_info['connection']: + try: + # Cancel via raw connection's cursor + cursor = query_info['connection'].connection.cursor() + cursor.cancel() + print(f"Query {query_info['query_id']} cancelled after {timeout_minutes}min timeout") + except Exception as e: + print(f"Cancellation attempt failed: {e}") + + conn_str = "awsathena+rest://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/" + engine = create_engine( + conn_str, + connect_args={"on_start_query_execution": track_query_start} + ) + + # Complex data processing query + analytics_query = text(""" + WITH monthly_cohorts AS ( + SELECT + date_trunc('month', first_purchase_date) as cohort_month, + user_id, + date_trunc('month', purchase_date) as purchase_month, + revenue + FROM user_purchases + WHERE first_purchase_date >= current_date - interval '2' year + ), + cohort_data AS ( + SELECT + cohort_month, + purchase_month, + COUNT(DISTINCT user_id) as users, + SUM(revenue) as total_revenue, + date_diff('month', cohort_month, purchase_month) as month_number + FROM monthly_cohorts + GROUP BY cohort_month, purchase_month + ) + SELECT + cohort_month, + month_number, + users, + total_revenue, + ROUND(users * 100.0 / FIRST_VALUE(users) OVER ( + PARTITION BY cohort_month ORDER BY month_number + ), 2) as retention_rate + FROM cohort_data + WHERE month_number <= 12 + ORDER BY cohort_month, month_number + """) + + with 
ThreadPoolExecutor(max_workers=1) as executor: + with engine.connect() as connection: + query_info['connection'] = connection + + # Start timeout monitor (15 minutes for complex analytics) + timeout_future = executor.submit(timeout_monitor, 15) + + try: + print("Starting cohort analysis (15-minute timeout)...") + result = connection.execute(analytics_query) + + # Process results + rows = result.fetchall() + print(f"Cohort analysis completed: {len(rows)} data points") + + # Show sample results + for i, row in enumerate(rows[:5]): # First 5 rows + print(f" Cohort {row.cohort_month}: Month {row.month_number}, " + f"{row.users} users, {row.retention_rate}% retention") + + if len(rows) > 5: + print(f" ... and {len(rows) - 5} more rows") + + except Exception as e: + print(f"Analytics query failed or was cancelled: {e}") + finally: + # Clean up + query_info['connection'] = None + try: + timeout_future.result(timeout=1) + except TimeoutError: + pass # Timeout monitor still running + + # Run the analytics example + run_analytics_with_timeout() + +Multiple callbacks +~~~~~~~~~~~~~~~~~~~ + +When both connection-level and execution_options callbacks are specified, +both callbacks will be invoked: + +.. 
code:: python + + from sqlalchemy import create_engine, text + + def connection_callback(query_id): + print(f"Connection callback: {query_id}") + # Global monitoring for all queries + + def execution_callback(query_id): + print(f"Execution callback: {query_id}") + # Specific handling for this query + + conn_str = "awsathena+rest://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/" + engine = create_engine( + conn_str, + connect_args={"on_start_query_execution": connection_callback} + ) + + with engine.connect() as connection: + # This will invoke both connection_callback and execution_callback + result = connection.execute( + text("SELECT 1").execution_options( + on_start_query_execution=execution_callback + ) + ) + +Supported SQLAlchemy dialects +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``on_start_query_execution`` callback is supported by all PyAthena SQLAlchemy dialects: + +* ``awsathena`` and ``awsathena+rest`` (default cursor) +* ``awsathena+pandas`` (pandas cursor) +* ``awsathena+arrow`` (arrow cursor) + +Usage with different dialects: + +.. code:: python + + # With pandas dialect + engine_pandas = create_engine( + "awsathena+pandas://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/", + connect_args={"on_start_query_execution": query_callback} + ) + + # With arrow dialect + engine_arrow = create_engine( + "awsathena+arrow://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/", + connect_args={"on_start_query_execution": query_callback} + ) + Complex Data Types ------------------ @@ -785,220 +1003,82 @@ Migration from Raw Strings array_data = result[0] # [1, 2, 3] - automatically converted first_item = array_data[0] # Direct access -.. 
_sqlalchemy-query-execution-callback: +JSON Type Support +~~~~~~~~~~~~~~~~~ -Query Execution Callback -~~~~~~~~~~~~~~~~~~~~~~~~~ - -PyAthena provides callback support for SQLAlchemy applications to get immediate access to query IDs -after the ``start_query_execution`` API call, enabling query monitoring and cancellation capabilities. - -Connection-level callback -^^^^^^^^^^^^^^^^^^^^^^^^^ +PyAthena provides support for Amazon Athena's JSON data type, enabling you to work with JSON data in your SQLAlchemy applications. The JSON type is primarily used with Data Manipulation Language (DML) operations in Athena. -You can set a default callback for all queries through an engine's connection parameters: +Basic Usage +^^^^^^^^^^^ .. code:: python - from sqlalchemy import create_engine, text - - def query_callback(query_id): - print(f"SQLAlchemy query started: {query_id}") - # Store query_id for monitoring or cancellation + from sqlalchemy import Column, Integer, Table, MetaData + from sqlalchemy.types import JSON - conn_str = "awsathena+rest://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/" - engine = create_engine( - conn_str, - connect_args={"on_start_query_execution": query_callback} + # Define a table with JSON columns + events = Table('events', MetaData(), + Column('id', Integer), + Column('metadata', JSON), + Column('config', JSON) ) - with engine.connect() as connection: - result = connection.execute(text("SELECT * FROM many_rows")) - # query_callback will be invoked before query execution +Querying JSON Data +^^^^^^^^^^^^^^^^^^ -Execution options callback -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -SQLAlchemy applications can use ``execution_options`` to specify callbacks for individual queries: +When querying JSON data, PyAthena automatically parses JSON strings into Python dictionaries: .. 
code:: python - from sqlalchemy import create_engine, text - - def specific_callback(query_id): - print(f"Specific query callback: {query_id}") - - conn_str = "awsathena+rest://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/" - engine = create_engine(conn_str) + from sqlalchemy import select, literal_column + from sqlalchemy.sql import type_coerce + from sqlalchemy.types import JSON - with engine.connect() as connection: - result = connection.execute( - text("SELECT * FROM many_rows").execution_options( - on_start_query_execution=specific_callback - ) - ) - -Query timeout management with SQLAlchemy -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -A practical example for managing long-running analytical queries with timeout: - -.. code:: python - - import time - from concurrent.futures import ThreadPoolExecutor, TimeoutError - from sqlalchemy import create_engine, text - - def run_analytics_with_timeout(): - """Run analytics query with automatic timeout and cancellation.""" - - query_info = {'query_id': None, 'connection': None} - - def track_query_start(query_id): - query_info['query_id'] = query_id - print(f"Analytics query started: {query_id}") - - def timeout_monitor(timeout_minutes): - """Cancel query after timeout period.""" - time.sleep(timeout_minutes * 60) - if query_info['query_id'] and query_info['connection']: - try: - # Cancel via raw connection's cursor - cursor = query_info['connection'].connection.cursor() - cursor.cancel() - print(f"Query {query_info['query_id']} cancelled after {timeout_minutes}min timeout") - except Exception as e: - print(f"Cancellation attempt failed: {e}") - - conn_str = "awsathena+rest://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/" - engine = create_engine( - conn_str, - connect_args={"on_start_query_execution": track_query_start} - ) - - # Complex data processing query - analytics_query = text(""" - WITH monthly_cohorts AS ( - SELECT - 
date_trunc('month', first_purchase_date) as cohort_month, - user_id, - date_trunc('month', purchase_date) as purchase_month, - revenue - FROM user_purchases - WHERE first_purchase_date >= current_date - interval '2' year - ), - cohort_data AS ( - SELECT - cohort_month, - purchase_month, - COUNT(DISTINCT user_id) as users, - SUM(revenue) as total_revenue, - date_diff('month', cohort_month, purchase_month) as month_number - FROM monthly_cohorts - GROUP BY cohort_month, purchase_month + # Query with explicit type coercion + result = connection.execute( + select( + type_coerce( + literal_column('CAST(\'{"name": "test", "value": 123}\' AS JSON)'), + JSON + ).label("json_col") ) - SELECT - cohort_month, - month_number, - users, - total_revenue, - ROUND(users * 100.0 / FIRST_VALUE(users) OVER ( - PARTITION BY cohort_month ORDER BY month_number - ), 2) as retention_rate - FROM cohort_data - WHERE month_number <= 12 - ORDER BY cohort_month, month_number - """) + ).fetchone() - with ThreadPoolExecutor(max_workers=1) as executor: - with engine.connect() as connection: - query_info['connection'] = connection - - # Start timeout monitor (15 minutes for complex analytics) - timeout_future = executor.submit(timeout_monitor, 15) + # Result is automatically parsed as a dictionary + print(result.json_col) # {"name": "test", "value": 123} + print(type(result.json_col)) # <class 'dict'> - try: - print("Starting cohort analysis (15-minute timeout)...") - result = connection.execute(analytics_query) - - # Process results - rows = result.fetchall() - print(f"Cohort analysis completed: {len(rows)} data points") - - # Show sample results - for i, row in enumerate(rows[:5]): # First 5 rows - print(f" Cohort {row.cohort_month}: Month {row.month_number}, " - f"{row.users} users, {row.retention_rate}% retention") - - if len(rows) > 5: - print(f" ... 
and {len(rows) - 5} more rows") - - except Exception as e: - print(f"Analytics query failed or was cancelled: {e}") - finally: - # Clean up - query_info['connection'] = None - try: - timeout_future.result(timeout=1) - except TimeoutError: - pass # Timeout monitor still running - - # Run the analytics example - run_analytics_with_timeout() +Important Limitations +^^^^^^^^^^^^^^^^^^^^^ -Multiple callbacks -^^^^^^^^^^^^^^^^^^^ +Athena's JSON type support has specific limitations: -When both connection-level and execution_options callbacks are specified, -both callbacks will be invoked: +* **JSON objects are fully supported** - Objects with key-value pairs work correctly +* **Top-level JSON arrays are not supported** - Direct CAST of arrays like ``[1, 2, 3]`` will fail +* **Arrays within objects are supported** - JSON objects can contain arrays as property values +* **DML only** - JSON type is supported for SELECT queries but not in CREATE TABLE statements .. code:: python - from sqlalchemy import create_engine, text - - def connection_callback(query_id): - print(f"Connection callback: {query_id}") - # Global monitoring for all queries - - def execution_callback(query_id): - print(f"Execution callback: {query_id}") - # Specific handling for this query - - conn_str = "awsathena+rest://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/" - engine = create_engine( - conn_str, - connect_args={"on_start_query_execution": connection_callback} - ) - - with engine.connect() as connection: - # This will invoke both connection_callback and execution_callback - result = connection.execute( - text("SELECT 1").execution_options( - on_start_query_execution=execution_callback - ) + # Supported: JSON object with nested array + result = connection.execute( + select( + type_coerce( + literal_column('CAST(\'{"items": [1, 2, 3]}\' AS JSON)'), + JSON + ).label("json_col") ) + ).fetchone() + print(result.json_col) # {"items": [1, 2, 3]} -Supported 
SQLAlchemy dialects -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The ``on_start_query_execution`` callback is supported by all PyAthena SQLAlchemy dialects: - -* ``awsathena`` and ``awsathena+rest`` (default cursor) -* ``awsathena+pandas`` (pandas cursor) -* ``awsathena+arrow`` (arrow cursor) - -Usage with different dialects: - -.. code:: python + # Not supported: Top-level array + # This will raise InvalidRequestException + # CAST('[1, 2, 3]' AS JSON) - # With pandas dialect - engine_pandas = create_engine( - "awsathena+pandas://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/", - connect_args={"on_start_query_execution": query_callback} - ) +Best Practices +^^^^^^^^^^^^^^ - # With arrow dialect - engine_arrow = create_engine( - "awsathena+arrow://:@athena.us-west-2.amazonaws.com:443/default?s3_staging_dir=s3://YOUR_S3_BUCKET/path/to/", - connect_args={"on_start_query_execution": query_callback} - ) +1. **Use with SELECT queries** - JSON type works best for querying existing data +2. **Handle nested structures** - Objects with nested arrays and objects are fully supported +3. **Explicit type coercion** - Use ``type_coerce()`` when working with literal JSON values +4. 
**Error handling** - Be prepared to handle ``InvalidRequestException`` for unsupported operations diff --git a/pyathena/sqlalchemy/base.py b/pyathena/sqlalchemy/base.py index 3e7ffdcd..2ea4d431 100644 --- a/pyathena/sqlalchemy/base.py +++ b/pyathena/sqlalchemy/base.py @@ -81,7 +81,7 @@ "map": types.String, "struct": AthenaStruct, "row": AthenaStruct, - "json": types.String, + "json": types.JSON, } diff --git a/pyathena/sqlalchemy/compiler.py b/pyathena/sqlalchemy/compiler.py index 1364e599..6c3fb58a 100644 --- a/pyathena/sqlalchemy/compiler.py +++ b/pyathena/sqlalchemy/compiler.py @@ -124,6 +124,9 @@ def visit_VARBINARY(self, type_: Type[Any], **kw) -> str: # noqa: N802 def visit_BOOLEAN(self, type_: Type[Any], **kw) -> str: # noqa: N802 return "BOOLEAN" + def visit_JSON(self, type_: Type[Any], **kw) -> str: # noqa: N802 + return "JSON" + def visit_string(self, type_, **kw): # noqa: N802 return "STRING" diff --git a/tests/pyathena/sqlalchemy/test_base.py b/tests/pyathena/sqlalchemy/test_base.py index 6d31ec8c..318b11bd 100644 --- a/tests/pyathena/sqlalchemy/test_base.py +++ b/tests/pyathena/sqlalchemy/test_base.py @@ -11,9 +11,9 @@ import pandas as pd import pytest import sqlalchemy -from sqlalchemy import create_engine, func, select, text, types +from sqlalchemy import create_engine, func, literal_column, select, text, types from sqlalchemy.exc import NoSuchTableError -from sqlalchemy.sql import expression +from sqlalchemy.sql import expression, type_coerce from sqlalchemy.sql.ddl import CreateTable from sqlalchemy.sql.schema import Column, MetaData, Table from sqlalchemy.sql.selectable import TextualSelect @@ -30,6 +30,68 @@ def test_basic_query(self, engine): assert rows[0].number_of_rows == 1 assert len(rows[0]) == 1 + def test_json_type_with_cast(self, engine): + """Test JSON type support with CAST operation in SELECT query.""" + engine, conn = engine + # Note: Athena JSON type support has limitations + # - JSON objects are supported + # - Direct CAST of JSON 
arrays is not supported + # - JSON is primarily used with DML operations, not DDL + + # Test 1: Simple JSON object with type_coerce for proper type handling + result = conn.execute( + select( + type_coerce( + literal_column('CAST(\'{"name": "test", "value": 123}\' AS JSON)'), + types.JSON, + ).label("json_col") + ) + ).fetchone() + assert result.json_col == {"name": "test", "value": 123} + assert isinstance(result.json_col, dict) + + # Test 2: Nested JSON object with arrays inside + # (Arrays are supported as part of JSON objects, just not as top-level CAST) + nested_json_str = '{"user": {"id": 1, "name": "Alice"}, "scores": [95, 87, 92]}' + result = conn.execute( + select( + type_coerce(literal_column(f"CAST('{nested_json_str}' AS JSON)"), types.JSON).label( + "nested_json" + ) + ) + ).fetchone() + assert result.nested_json == { + "user": {"id": 1, "name": "Alice"}, + "scores": [95, 87, 92], + } + assert result.nested_json["user"]["name"] == "Alice" + assert result.nested_json["scores"][0] == 95 + assert isinstance(result.nested_json["scores"], list) + + # Test 3: JSON with null value + result = conn.execute( + select( + type_coerce(literal_column("CAST('{\"key\": null}' AS JSON)"), types.JSON).label( + "json_with_null" + ) + ) + ).fetchone() + assert result.json_with_null == {"key": None} + assert result.json_with_null["key"] is None + + # Test 4: JSON with various types + result = conn.execute( + select( + type_coerce( + literal_column( + 'CAST(\'{"str": "value", "num": 42, "bool": true, "nil": null}\' AS JSON)' + ), + types.JSON, + ).label("json_types") + ) + ).fetchone() + assert result.json_types == {"str": "value", "num": 42, "bool": True, "nil": None} + def test_reflect_no_such_table(self, engine): engine, conn = engine pytest.raises( @@ -405,6 +467,7 @@ def test_get_column_type(self, engine): assert isinstance(decimal_with_args, types.DECIMAL) assert decimal_with_args.precision == 10 assert decimal_with_args.scale == 1 + assert 
isinstance(dialect._get_column_type("json"), types.JSON) def test_contain_percents_character_query(self, engine): engine, conn = engine diff --git a/tests/pyathena/sqlalchemy/test_compiler.py b/tests/pyathena/sqlalchemy/test_compiler.py index 9ac67390..c91b7554 100644 --- a/tests/pyathena/sqlalchemy/test_compiler.py +++ b/tests/pyathena/sqlalchemy/test_compiler.py @@ -114,6 +114,16 @@ def test_visit_array_no_attributes(self): result = compiler.visit_array(array_type) assert result == "ARRAY" + def test_visit_json(self): + """Test JSON type compilation.""" + from sqlalchemy import types + + dialect = Mock() + compiler = AthenaTypeCompiler(dialect) + json_type = types.JSON() + result = compiler.visit_JSON(json_type) + assert result == "JSON" + class TestAthenaStatementCompiler: """Test cases for Athena statement compiler functionality."""