diff --git a/docs/api.rst b/docs/api.rst index 79c47e2f..61253d5f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -18,6 +18,7 @@ This section provides comprehensive API documentation for all PyAthena classes a api/sqlalchemy api/pandas api/arrow + api/polars api/s3fs api/spark @@ -44,5 +45,6 @@ Specialized integrations - :ref:`api_sqlalchemy` - SQLAlchemy dialect implementations - :ref:`api_pandas` - pandas DataFrame integration - :ref:`api_arrow` - Apache Arrow columnar data integration +- :ref:`api_polars` - Polars DataFrame integration (no pyarrow required) - :ref:`api_s3fs` - Lightweight S3FS-based cursor (no pandas/pyarrow required) - :ref:`api_spark` - Apache Spark integration for big data processing diff --git a/docs/api/polars.rst b/docs/api/polars.rst new file mode 100644 index 00000000..5d4998f4 --- /dev/null +++ b/docs/api/polars.rst @@ -0,0 +1,40 @@ +.. _api_polars: + +Polars Integration +================== + +This section covers Polars-specific cursors, result sets, and data converters. + +Polars Cursors +-------------- + +.. autoclass:: pyathena.polars.cursor.PolarsCursor + :members: + :inherited-members: + +.. autoclass:: pyathena.polars.async_cursor.AsyncPolarsCursor + :members: + :inherited-members: + +Polars Result Set +----------------- + +.. autoclass:: pyathena.polars.result_set.AthenaPolarsResultSet + :members: + :inherited-members: + +Polars Data Converters +---------------------- + +.. autoclass:: pyathena.polars.converter.DefaultPolarsTypeConverter + :members: + +.. autoclass:: pyathena.polars.converter.DefaultPolarsUnloadTypeConverter + :members: + +Polars Utilities +---------------- + +.. autofunction:: pyathena.polars.util.to_column_info + +.. autofunction:: pyathena.polars.util.get_athena_type diff --git a/docs/cursor.rst b/docs/cursor.rst index 53ea6db8..481d909e 100644 --- a/docs/cursor.rst +++ b/docs/cursor.rst @@ -319,6 +319,16 @@ AsyncArrowCursor See :ref:`async-arrow-cursor`. 
+PolarsCursor +------------ + +See :ref:`polars-cursor`. + +AsyncPolarsCursor +----------------- + +See :ref:`async-polars-cursor`. + S3FSCursor ---------- diff --git a/docs/index.rst b/docs/index.rst index a372eebe..cee8e3ca 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,6 +20,7 @@ Documentation sqlalchemy pandas arrow + polars s3fs spark testing diff --git a/docs/introduction.rst b/docs/introduction.rst index b28bf509..0bd104dd 100644 --- a/docs/introduction.rst +++ b/docs/introduction.rst @@ -32,6 +32,8 @@ Extra packages: +---------------+---------------------------------------+------------------+ | Arrow | ``pip install PyAthena[Arrow]`` | >=7.0.0 | +---------------+---------------------------------------+------------------+ +| Polars | ``pip install PyAthena[Polars]`` | >=1.0.0 | ++---------------+---------------------------------------+------------------+ .. _features: @@ -43,7 +45,7 @@ PyAthena provides comprehensive support for Amazon Athena's data types and featu **Core Features:** - **DB API 2.0 Compliance**: Full PEP 249 compatibility for database operations - **SQLAlchemy Integration**: Native dialect support with table reflection and ORM capabilities - - **Multiple Cursor Types**: Standard, Pandas, Arrow, S3FS, and Spark cursor implementations + - **Multiple Cursor Types**: Standard, Pandas, Arrow, Polars, S3FS and Spark cursor implementations - **Async Support**: Asynchronous query execution for non-blocking operations **Data Type Support:** diff --git a/docs/polars.rst b/docs/polars.rst new file mode 100644 index 00000000..55c7caf6 --- /dev/null +++ b/docs/polars.rst @@ -0,0 +1,420 @@ +.. _polars: + +Polars +====== + +.. _polars-cursor: + +PolarsCursor +------------ + +PolarsCursor directly handles the CSV file of the query execution result output to S3. +This cursor downloads the CSV file after executing the query and loads it into a `polars.DataFrame object`_. +Performance is better than fetching data with Cursor. 
+ +PolarsCursor uses `Polars`_ native reading capabilities (``pl.read_csv``, ``pl.read_parquet``) and +does not require PyArrow as a dependency. PyAthena's own S3FileSystem (fsspec compatible) +is used for S3 access, so s3fs is also not required. + +You can use the PolarsCursor by specifying the ``cursor_class`` +with the connect method or connection object. + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor() + +.. code:: python + + from pyathena.connection import Connection + from pyathena.polars.cursor import PolarsCursor + + cursor = Connection(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor() + +It can also be used by specifying the cursor class when calling the connection object's cursor method. + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor(PolarsCursor) + +.. code:: python + + from pyathena.connection import Connection + from pyathena.polars.cursor import PolarsCursor + + cursor = Connection(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor(PolarsCursor) + +The as_polars method returns a `polars.DataFrame object`_. + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor() + + df = cursor.execute("SELECT * FROM many_rows").as_polars() + print(df.describe()) + print(df.head()) + print(df.height) # Number of rows + print(df.width) # Number of columns + print(df.columns) # Column names + +Support fetch and iterate query results. + +.. 
code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor() + + cursor.execute("SELECT * FROM many_rows") + print(cursor.fetchone()) + print(cursor.fetchmany()) + print(cursor.fetchall()) + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor() + + cursor.execute("SELECT * FROM many_rows") + for row in cursor: + print(row) + +Execution information of the query can also be retrieved. + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor() + + cursor.execute("SELECT * FROM many_rows") + print(cursor.state) + print(cursor.state_change_reason) + print(cursor.completion_date_time) + print(cursor.submission_date_time) + print(cursor.data_scanned_in_bytes) + print(cursor.engine_execution_time_in_millis) + print(cursor.query_queue_time_in_millis) + print(cursor.total_execution_time_in_millis) + print(cursor.query_planning_time_in_millis) + print(cursor.service_processing_time_in_millis) + print(cursor.output_location) + +Arrow Interoperability +~~~~~~~~~~~~~~~~~~~~~~ + +PolarsCursor can convert results to Apache Arrow format if PyArrow is installed. + +.. 
code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor() + + # Convert to Arrow Table (requires pyarrow) + table = cursor.execute("SELECT * FROM many_rows").as_arrow() + print(table.num_rows) + print(table.num_columns) + print(table.schema) + +If you want to customize the polars.DataFrame dtypes, create a converter class like this: + +.. code:: python + + import polars as pl + from pyathena.converter import Converter + + class CustomPolarsTypeConverter(Converter): + + def __init__(self): + super().__init__( + mappings=None, + types={ + "boolean": pl.Boolean, + "tinyint": pl.Int8, + "smallint": pl.Int16, + "integer": pl.Int32, + "bigint": pl.Int64, + "float": pl.Float32, + "real": pl.Float64, + "double": pl.Float64, + "decimal": pl.String, + "char": pl.String, + "varchar": pl.String, + "string": pl.String, + "timestamp": pl.Datetime, + "date": pl.Date, + "time": pl.Time, + "varbinary": pl.String, + "array": pl.String, + "map": pl.String, + "row": pl.String, + "json": pl.String, + } + ) + + def convert(self, type_, value): + # Not used in PolarsCursor. + pass + +Then you simply specify an instance of this class in the converter argument when creating a connection or cursor. + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor(PolarsCursor, converter=CustomPolarsTypeConverter()) + +.. 
code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + converter=CustomPolarsTypeConverter()).cursor(PolarsCursor) + +If the unload option is enabled, the Parquet file itself has a schema, so the conversion is done to the dtypes according to that schema, +and the ``types`` setting of the Converter class is not used. + +Unload Options +~~~~~~~~~~~~~~ + +PolarsCursor supports the unload option, as does :ref:`arrow-cursor`. + +See `Unload options`_ for more information. + +The unload option can be enabled by specifying it in the ``cursor_kwargs`` argument of the connect method or as an argument to the cursor method. + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor, + cursor_kwargs={ + "unload": True + }).cursor() + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor(unload=True) + +SQLAlchemy allows this option to be specified in the connection string. + +.. code:: text + + awsathena+polars://:@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}&unload=true... + +NOTE: PolarsCursor handles the CSV file on memory. Pay attention to the memory capacity. + +.. _async-polars-cursor: + +AsyncPolarsCursor +----------------- + +AsyncPolarsCursor is an AsyncCursor that can handle `polars.DataFrame object`_. +This cursor directly handles the CSV of query results output to S3 in the same way as PolarsCursor. + +You can use the AsyncPolarsCursor by specifying the ``cursor_class`` +with the connect method or connection object. + +.. 
code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor() + +.. code:: python + + from pyathena.connection import Connection + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = Connection(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor() + +It can also be used by specifying the cursor class when calling the connection object's cursor method. + +.. code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor(AsyncPolarsCursor) + +.. code:: python + + from pyathena.connection import Connection + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = Connection(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor(AsyncPolarsCursor) + +The default number of workers is 5 or cpu number * 5. +If you want to change the number of workers you can specify like the following. + +.. code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor(max_workers=10) + +The execute method of the AsyncPolarsCursor returns the tuple of the query ID and the `future object`_. + +.. 
code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor() + + query_id, future = cursor.execute("SELECT * FROM many_rows") + +The return value of the `future object`_ is an ``AthenaPolarsResultSet`` object. +This object has an interface similar to ``AthenaResultSetObject``. + +.. code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor() + + query_id, future = cursor.execute("SELECT * FROM many_rows") + result_set = future.result() + print(result_set.state) + print(result_set.state_change_reason) + print(result_set.completion_date_time) + print(result_set.submission_date_time) + print(result_set.data_scanned_in_bytes) + print(result_set.engine_execution_time_in_millis) + print(result_set.query_queue_time_in_millis) + print(result_set.total_execution_time_in_millis) + print(result_set.query_planning_time_in_millis) + print(result_set.service_processing_time_in_millis) + print(result_set.output_location) + print(result_set.description) + for row in result_set: + print(row) + +.. code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor() + + query_id, future = cursor.execute("SELECT * FROM many_rows") + result_set = future.result() + print(result_set.fetchall()) + +This object also has an as_polars method that returns a `polars.DataFrame object`_ similar to the PolarsCursor. + +.. 
code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor() + + query_id, future = cursor.execute("SELECT * FROM many_rows") + result_set = future.result() + df = result_set.as_polars() + print(df.describe()) + print(df.head()) + +As with AsyncPolarsCursor, you need a query ID to cancel a query. + +.. code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor() + + query_id, future = cursor.execute("SELECT * FROM many_rows") + cursor.cancel(query_id) + +As with AsyncPolarsCursor, the unload option is also available. + +.. code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor, + cursor_kwargs={ + "unload": True + }).cursor() + +.. code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor(unload=True) + +.. _`polars.DataFrame object`: https://docs.pola.rs/api/python/stable/reference/dataframe/index.html +.. _`Polars`: https://pola.rs/ +.. _`Unload options`: arrow.html#unload-options +.. 
_`future object`: https://docs.python.org/3/library/concurrent.futures.html#future-objects diff --git a/docs/sqlalchemy.rst b/docs/sqlalchemy.rst index 1c83d5af..d3504c2c 100644 --- a/docs/sqlalchemy.rst +++ b/docs/sqlalchemy.rst @@ -43,19 +43,21 @@ If you do not specify ``aws_access_key_id`` and ``aws_secret_access_key`` using Dialect & driver ---------------- -+-----------+--------+------------------+----------------------+ -| Dialect | Driver | Schema | Cursor | -+===========+========+==================+======================+ -| awsathena | | awsathena | DefaultCursor | -+-----------+--------+------------------+----------------------+ -| awsathena | rest | awsathena+rest | DefaultCursor | -+-----------+--------+------------------+----------------------+ -| awsathena | pandas | awsathena+pandas | :ref:`pandas-cursor` | -+-----------+--------+------------------+----------------------+ -| awsathena | arrow | awsathena+arrow | :ref:`arrow-cursor` | -+-----------+--------+------------------+----------------------+ -| awsathena | s3fs | awsathena+s3fs | :ref:`s3fs-cursor` | -+-----------+--------+------------------+----------------------+ ++-----------+--------+------------------+------------------------+ +| Dialect | Driver | Schema | Cursor | ++===========+========+==================+========================+ +| awsathena | | awsathena | DefaultCursor | ++-----------+--------+------------------+------------------------+ +| awsathena | rest | awsathena+rest | DefaultCursor | ++-----------+--------+------------------+------------------------+ +| awsathena | pandas | awsathena+pandas | :ref:`pandas-cursor` | ++-----------+--------+------------------+------------------------+ +| awsathena | arrow | awsathena+arrow | :ref:`arrow-cursor` | ++-----------+--------+------------------+------------------------+ +| awsathena | polars | awsathena+polars | :ref:`polars-cursor` | ++-----------+--------+------------------+------------------------+ +| awsathena | s3fs | 
awsathena+s3fs | :ref:`s3fs-cursor` | ++-----------+--------+------------------+------------------------+ Dialect options --------------- @@ -506,6 +508,7 @@ The ``on_start_query_execution`` callback is supported by all PyAthena SQLAlchem * ``awsathena`` and ``awsathena+rest`` (default cursor) * ``awsathena+pandas`` (pandas cursor) * ``awsathena+arrow`` (arrow cursor) +* ``awsathena+polars`` (polars cursor) * ``awsathena+s3fs`` (S3FS cursor) Usage with different dialects: diff --git a/docs/usage.rst b/docs/usage.rst index e0ec51ff..57166194 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -365,6 +365,8 @@ The ``on_start_query_execution`` callback is supported by the following cursor t * ``DictCursor`` * ``ArrowCursor`` * ``PandasCursor`` +* ``PolarsCursor`` +* ``S3FSCursor`` Note: ``AsyncCursor`` and its variants do not support this callback as they already return the query ID immediately through their different execution model. diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index 1e9304b8..5dcc1cb5 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -38,22 +38,25 @@ class AsyncArrowCursor(AsyncCursor): arraysize: Number of rows to fetch per batch (configurable). Example: - >>> import asyncio >>> from pyathena.arrow.async_cursor import AsyncArrowCursor >>> >>> cursor = connection.cursor(AsyncArrowCursor, unload=True) >>> query_id, future = cursor.execute("SELECT * FROM large_table") >>> >>> # Get result when ready - >>> result_set = await future + >>> result_set = future.result() >>> arrow_table = result_set.as_arrow() >>> >>> # Convert to pandas if needed >>> df = arrow_table.to_pandas() + >>> + >>> # Convert to Polars if needed (requires polars) + >>> polars_df = result_set.as_polars() Note: Requires pyarrow to be installed. UNLOAD operations generate - Parquet files in S3 for optimal Arrow compatibility. + Parquet files in S3 for optimal Arrow compatibility. 
For Polars + interoperability, polars must be installed separately. """ def __init__( diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index 14834d3a..6b2f7c88 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -15,6 +15,7 @@ from pyathena.result_set import WithResultSet if TYPE_CHECKING: + import polars as pl from pyarrow import Table _logger = logging.getLogger(__name__) # type: ignore @@ -316,3 +317,27 @@ def as_arrow(self) -> "Table": raise ProgrammingError("No result set.") result_set = cast(AthenaArrowResultSet, self.result_set) return result_set.as_arrow() + + def as_polars(self) -> "pl.DataFrame": + """Return query results as a Polars DataFrame. + + Converts the Apache Arrow Table to a Polars DataFrame for + interoperability with the Polars data processing library. + + Returns: + Polars DataFrame containing all query results. + + Raises: + ProgrammingError: If no query has been executed or no results are available. + ImportError: If polars is not installed. 
+ + Example: + >>> cursor = connection.cursor(ArrowCursor) + >>> cursor.execute("SELECT * FROM my_table") + >>> df = cursor.as_polars() + >>> print(f"DataFrame has {df.height} rows and {df.width} columns") + """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaArrowResultSet, self.result_set) + return result_set.as_polars() diff --git a/pyathena/arrow/result_set.py b/pyathena/arrow/result_set.py index b8dde2d5..f0d37f67 100644 --- a/pyathena/arrow/result_set.py +++ b/pyathena/arrow/result_set.py @@ -23,6 +23,7 @@ from pyathena.util import RetryConfig, parse_output_location if TYPE_CHECKING: + import polars as pl from pyarrow import Table from pyathena.connection import Connection @@ -189,10 +190,6 @@ def __s3_file_system(self): return fs - @property - def is_unload(self): - return self._unload and self.query and self.query.strip().upper().startswith("UNLOAD") - @property def timestamp_parsers(self) -> List[str]: from pyarrow.csv import ISO8601 @@ -201,14 +198,11 @@ def timestamp_parsers(self) -> List[str]: @property def column_types(self) -> Dict[str, Type[Any]]: - import pyarrow as pa - - converter_types = self._converter.types description = self.description if self.description else [] return { - d[0]: converter_types.get(d[1], pa.string()) + d[0]: dtype for d in description - if d[1] in converter_types + if (dtype := self._converter.get_dtype(d[1], d[4], d[5])) is not None } @property @@ -355,6 +349,33 @@ def _as_arrow(self) -> "Table": def as_arrow(self) -> "Table": return self._table + def as_polars(self) -> "pl.DataFrame": + """Return query results as a Polars DataFrame. + + Converts the Apache Arrow Table to a Polars DataFrame for + interoperability with the Polars data processing library. + + Returns: + Polars DataFrame containing all query results. + + Raises: + ImportError: If polars is not installed. 
+ + Example: + >>> cursor = connection.cursor(ArrowCursor) + >>> cursor.execute("SELECT * FROM my_table") + >>> df = cursor.as_polars() + >>> # Use with Polars operations + """ + try: + import polars as pl + + return pl.from_arrow(self._table) # type: ignore[return-value] + except ImportError as e: + raise ImportError( + "polars is required for as_polars(). Install it with: pip install polars" + ) from e + def close(self) -> None: import pyarrow as pa diff --git a/pyathena/converter.py b/pyathena/converter.py index 065b8bf3..d72a612e 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -520,6 +520,22 @@ def remove(self, type_: str) -> None: """ self.mappings.pop(type_, None) + def get_dtype(self, type_: str, precision: int = 0, scale: int = 0) -> Optional[Type[Any]]: + """Get the data type for a given Athena type. + + Subclasses may override this to provide custom type handling + (e.g., for decimal types with precision and scale). + + Args: + type_: The Athena data type name. + precision: The precision for decimal types. + scale: The scale for decimal types. + + Returns: + The corresponding Python type, or None if not found. + """ + return self._types.get(type_) + def update(self, mappings: Dict[str, Callable[[Optional[str]], Optional[Any]]]) -> None: """Update multiple conversion functions at once. diff --git a/pyathena/pandas/async_cursor.py b/pyathena/pandas/async_cursor.py index 1f40baff..b9f3b09f 100644 --- a/pyathena/pandas/async_cursor.py +++ b/pyathena/pandas/async_cursor.py @@ -41,14 +41,13 @@ class AsyncPandasCursor(AsyncCursor): chunksize: Number of rows per chunk for large datasets. 
Example: - >>> import asyncio >>> from pyathena.pandas.async_cursor import AsyncPandasCursor >>> >>> cursor = connection.cursor(AsyncPandasCursor, chunksize=10000) >>> query_id, future = cursor.execute("SELECT * FROM large_table") >>> >>> # Get result when ready - >>> result_set = await future + >>> result_set = future.result() >>> df = result_set.as_pandas() >>> >>> # Or iterate through chunks for large datasets diff --git a/pyathena/pandas/result_set.py b/pyathena/pandas/result_set.py index 846cc9ab..2700b354 100644 --- a/pyathena/pandas/result_set.py +++ b/pyathena/pandas/result_set.py @@ -371,16 +371,6 @@ def __s3_file_system(self): max_workers=self._max_workers, ) - @property - def is_unload(self): - """Check if this result set comes from an UNLOAD operation. - - Returns: - True if this result set is from an UNLOAD query and unload mode - is enabled, False otherwise. - """ - return self._unload and self.query and self.query.strip().upper().startswith("UNLOAD") - @property def dtypes(self) -> Dict[str, Type[Any]]: """Get pandas-compatible data types for result columns. 
@@ -391,7 +381,9 @@ def dtypes(self) -> Dict[str, Type[Any]]: """ description = self.description if self.description else [] return { - d[0]: self._converter.types[d[1]] for d in description if d[1] in self._converter.types + d[0]: dtype + for d in description + if (dtype := self._converter.get_dtype(d[1], d[4], d[5])) is not None } @property diff --git a/pyathena/polars/__init__.py b/pyathena/polars/__init__.py new file mode 100644 index 00000000..20a12efb --- /dev/null +++ b/pyathena/polars/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +import fsspec + +fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True) +fsspec.register_implementation("s3a", "pyathena.filesystem.s3.S3FileSystem", clobber=True) diff --git a/pyathena/polars/async_cursor.py b/pyathena/polars/async_cursor.py new file mode 100644 index 00000000..2fc3ee6e --- /dev/null +++ b/pyathena/polars/async_cursor.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import logging +from concurrent.futures import Future +from multiprocessing import cpu_count +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from pyathena import ProgrammingError +from pyathena.async_cursor import AsyncCursor +from pyathena.common import CursorIterator +from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.polars.converter import ( + DefaultPolarsTypeConverter, + DefaultPolarsUnloadTypeConverter, +) +from pyathena.polars.result_set import AthenaPolarsResultSet + +_logger = logging.getLogger(__name__) + + +class AsyncPolarsCursor(AsyncCursor): + """Asynchronous cursor that returns results as Polars DataFrames. + + This cursor extends AsyncCursor to provide asynchronous query execution + with results returned as Polars DataFrames using Polars' native reading + capabilities. 
It does not require PyArrow for basic functionality, but can + optionally provide Arrow Table access when PyArrow is installed. + + Features: + - Asynchronous query execution with concurrent futures + - Native Polars CSV and Parquet reading (no PyArrow required) + - Memory-efficient columnar data processing + - Support for UNLOAD operations with Parquet output + - Optional Arrow interoperability when PyArrow is installed + + Attributes: + arraysize: Number of rows to fetch per batch (configurable). + + Example: + >>> from pyathena.polars.async_cursor import AsyncPolarsCursor + >>> + >>> cursor = connection.cursor(AsyncPolarsCursor, unload=True) + >>> query_id, future = cursor.execute("SELECT * FROM large_table") + >>> + >>> # Get result when ready + >>> result_set = future.result() + >>> df = result_set.as_polars() + >>> + >>> # Optional: Convert to Arrow Table if pyarrow is installed + >>> table = result_set.as_arrow() + + Note: + Requires polars to be installed. PyArrow is optional and only needed + for as_arrow() functionality. UNLOAD operations generate Parquet files + in S3 for optimal performance. + """ + + def __init__( + self, + s3_staging_dir: Optional[str] = None, + schema_name: Optional[str] = None, + catalog_name: Optional[str] = None, + work_group: Optional[str] = None, + poll_interval: float = 1, + encryption_option: Optional[str] = None, + kms_key: Optional[str] = None, + kill_on_interrupt: bool = True, + max_workers: int = (cpu_count() or 1) * 5, + arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE, + unload: bool = False, + result_reuse_enable: bool = False, + result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, + block_size: Optional[int] = None, + cache_type: Optional[str] = None, + **kwargs, + ) -> None: + """Initialize an AsyncPolarsCursor. + + Args: + s3_staging_dir: S3 location for query results. + schema_name: Default schema name. + catalog_name: Default catalog name. + work_group: Athena workgroup name. 
+ poll_interval: Query status polling interval in seconds. + encryption_option: S3 encryption option (SSE_S3, SSE_KMS, CSE_KMS). + kms_key: KMS key ARN for encryption. + kill_on_interrupt: Cancel running query on keyboard interrupt. + max_workers: Maximum number of workers for concurrent execution. + arraysize: Number of rows to fetch per batch. + unload: Enable UNLOAD for high-performance Parquet output. + result_reuse_enable: Enable Athena query result reuse. + result_reuse_minutes: Minutes to reuse cached results. + block_size: S3 read block size. + cache_type: S3 caching strategy. + **kwargs: Additional connection parameters. + + Example: + >>> cursor = connection.cursor(AsyncPolarsCursor, unload=True) + """ + super().__init__( + s3_staging_dir=s3_staging_dir, + schema_name=schema_name, + catalog_name=catalog_name, + work_group=work_group, + poll_interval=poll_interval, + encryption_option=encryption_option, + kms_key=kms_key, + kill_on_interrupt=kill_on_interrupt, + max_workers=max_workers, + arraysize=arraysize, + result_reuse_enable=result_reuse_enable, + result_reuse_minutes=result_reuse_minutes, + **kwargs, + ) + self._unload = unload + self._block_size = block_size + self._cache_type = cache_type + + @staticmethod + def get_default_converter( + unload: bool = False, + ) -> Union[DefaultPolarsTypeConverter, DefaultPolarsUnloadTypeConverter, Any]: + """Get the default type converter for Polars results. + + Args: + unload: If True, returns converter for UNLOAD (Parquet) results. + + Returns: + Type converter appropriate for the result format. + """ + if unload: + return DefaultPolarsUnloadTypeConverter() + return DefaultPolarsTypeConverter() + + @property + def arraysize(self) -> int: + """Get the number of rows to fetch per batch.""" + return self._arraysize + + @arraysize.setter + def arraysize(self, value: int) -> None: + """Set the number of rows to fetch per batch. + + Args: + value: Number of rows to fetch. Must be positive. 
+ + Raises: + ProgrammingError: If value is not positive. + """ + if value <= 0: + raise ProgrammingError("arraysize must be a positive integer value.") + self._arraysize = value + + def _collect_result_set( + self, + query_id: str, + unload_location: Optional[str] = None, + kwargs: Optional[Dict[str, Any]] = None, + ) -> AthenaPolarsResultSet: + if kwargs is None: + kwargs = {} + query_execution = cast(AthenaQueryExecution, self._poll(query_id)) + return AthenaPolarsResultSet( + connection=self._connection, + converter=self._converter, + query_execution=query_execution, + arraysize=self._arraysize, + retry_config=self._retry_config, + unload=self._unload, + unload_location=unload_location, + block_size=self._block_size, + cache_type=self._cache_type, + max_workers=self._max_workers, + **kwargs, + ) + + def execute( + self, + operation: str, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, + work_group: Optional[str] = None, + s3_staging_dir: Optional[str] = None, + cache_size: Optional[int] = 0, + cache_expiration_time: Optional[int] = 0, + result_reuse_enable: Optional[bool] = None, + result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, + **kwargs, + ) -> Tuple[str, "Future[Union[AthenaPolarsResultSet, Any]]"]: + """Execute a SQL query asynchronously and return results as Polars DataFrames. + + Executes the SQL query on Amazon Athena asynchronously and returns a + future that resolves to a result set for Polars DataFrame output. + + Args: + operation: SQL query string to execute. + parameters: Query parameters for parameterized queries. + work_group: Athena workgroup to use for this query. + s3_staging_dir: S3 location for query results. + cache_size: Number of queries to check for result caching. + cache_expiration_time: Cache expiration time in seconds. + result_reuse_enable: Enable Athena result reuse for this query. + result_reuse_minutes: Minutes to reuse cached results. 
+ paramstyle: Parameter style ('qmark' or 'pyformat'). + **kwargs: Additional execution parameters passed to Polars read functions. + + Returns: + Tuple of (query_id, future) where future resolves to AthenaPolarsResultSet. + + Example: + >>> query_id, future = cursor.execute("SELECT * FROM sales") + >>> result_set = future.result() + >>> df = result_set.as_polars() # Returns Polars DataFrame + """ + if self._unload: + s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir + assert s3_staging_dir, "If the unload option is used, s3_staging_dir is required." + operation, unload_location = self._formatter.wrap_unload( + operation, + s3_staging_dir=s3_staging_dir, + format_=AthenaFileFormat.FILE_FORMAT_PARQUET, + compression=AthenaCompression.COMPRESSION_SNAPPY, + ) + else: + unload_location = None + query_id = self._execute( + operation, + parameters=parameters, + work_group=work_group, + s3_staging_dir=s3_staging_dir, + cache_size=cache_size, + cache_expiration_time=cache_expiration_time, + result_reuse_enable=result_reuse_enable, + result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, + ) + return ( + query_id, + self._executor.submit( + self._collect_result_set, + query_id, + unload_location, + kwargs, + ), + ) diff --git a/pyathena/polars/converter.py b/pyathena/polars/converter.py new file mode 100644 index 00000000..b078842e --- /dev/null +++ b/pyathena/polars/converter.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import logging +from copy import deepcopy +from datetime import date, datetime +from typing import Any, Callable, Dict, Optional, Union + +from pyathena.converter import ( + Converter, + _to_binary, + _to_default, + _to_json, + _to_time, +) + +_logger = logging.getLogger(__name__) + + +def _to_date(value: Optional[Union[str, datetime, date]]) -> Optional[date]: + if value is None: + return None + if isinstance(value, datetime): + return value.date() + if isinstance(value, date): 
+ return value + return datetime.strptime(value, "%Y-%m-%d").date() + + +_DEFAULT_POLARS_CONVERTERS: Dict[str, Callable[[Optional[str]], Optional[Any]]] = { + "date": _to_date, + "time": _to_time, + "varbinary": _to_binary, + "json": _to_json, +} + + +class DefaultPolarsTypeConverter(Converter): + """Optimized type converter for Polars DataFrame results. + + This converter is specifically designed for the PolarsCursor and provides + optimized type conversion for Polars DataFrames. + + The converter focuses on: + - Converting date/time types to appropriate Python objects + - Handling decimal and binary types + - Preserving JSON and complex types + - Maintaining high performance for columnar operations + + Example: + >>> from pyathena.polars.converter import DefaultPolarsTypeConverter + >>> converter = DefaultPolarsTypeConverter() + >>> + >>> # Used automatically by PolarsCursor + >>> cursor = connection.cursor(PolarsCursor) + >>> # converter is applied automatically to results + + Note: + This converter is used by default in PolarsCursor. + Most users don't need to instantiate it directly. 
+    """
+
+    def __init__(self) -> None:
+        super().__init__(
+            mappings=deepcopy(_DEFAULT_POLARS_CONVERTERS),
+            default=_to_default,
+            types=self._dtypes,
+        )
+
+    @property
+    def _dtypes(self) -> Dict[str, Any]:
+        import polars as pl
+
+        if not hasattr(self, "_dtypes_cache"):  # single underscore: "__dtypes" was name-mangled, so hasattr("__dtypes") never matched and the cache never hit
+            self._dtypes_cache = {
+                "boolean": pl.Boolean,
+                "tinyint": pl.Int8,
+                "smallint": pl.Int16,
+                "integer": pl.Int32,
+                "bigint": pl.Int64,
+                "float": pl.Float32,
+                "real": pl.Float64,
+                "double": pl.Float64,
+                "char": pl.String,
+                "varchar": pl.String,
+                "string": pl.String,
+                "timestamp": pl.Datetime,
+                "date": pl.Date,
+                "time": pl.String,
+                "varbinary": pl.String,
+                "array": pl.String,
+                "map": pl.String,
+                "row": pl.String,
+                "decimal": pl.Decimal,
+                "json": pl.String,
+            }
+        return self._dtypes_cache
+
+    def get_dtype(self, type_: str, precision: int = 0, scale: int = 0) -> Any:
+        """Get the Polars data type for a given Athena type.
+
+        Args:
+            type_: The Athena data type name.
+            precision: The precision for decimal types.
+            scale: The scale for decimal types.
+
+        Returns:
+            The Polars data type.
+        """
+        import polars as pl
+
+        if type_ == "decimal":
+            return pl.Decimal(precision=precision, scale=scale)
+        return self._types.get(type_)
+
+    def convert(self, type_: str, value: Optional[str]) -> Optional[Any]:
+        converter = self.get(type_)
+        return converter(value)
+
+
+class DefaultPolarsUnloadTypeConverter(Converter):
+    """Type converter for Polars UNLOAD operations.
+
+    This converter is designed for use with UNLOAD queries that write
+    results directly to Parquet files in S3. Since UNLOAD operations
+    bypass the normal conversion process and write data in native
+    Parquet format, this converter has minimal functionality.
+
+    Note:
+        Used automatically when PolarsCursor is configured with unload=True.
+        UNLOAD results are read directly as Polars DataFrames from Parquet files.
+ """ + + def __init__(self) -> None: + super().__init__( + mappings={}, + default=_to_default, + ) + + def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + pass diff --git a/pyathena/polars/cursor.py b/pyathena/polars/cursor.py new file mode 100644 index 00000000..e8877e2e --- /dev/null +++ b/pyathena/polars/cursor.py @@ -0,0 +1,406 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import logging +from multiprocessing import cpu_count +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, cast + +from pyathena.common import BaseCursor, CursorIterator +from pyathena.error import OperationalError, ProgrammingError +from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.polars.converter import ( + DefaultPolarsTypeConverter, + DefaultPolarsUnloadTypeConverter, +) +from pyathena.polars.result_set import AthenaPolarsResultSet +from pyathena.result_set import WithResultSet + +if TYPE_CHECKING: + import polars as pl + from pyarrow import Table + +_logger = logging.getLogger(__name__) + + +class PolarsCursor(BaseCursor, CursorIterator, WithResultSet): + """Cursor for handling Polars DataFrame results from Athena queries. + + This cursor returns query results as Polars DataFrames using Polars' native + reading capabilities. It does not require PyArrow for basic functionality, + but can optionally provide Arrow Table access when PyArrow is installed. + + The cursor supports both regular CSV-based results and high-performance + UNLOAD operations that return results in Parquet format for improved + performance with large datasets. + + Attributes: + description: Sequence of column descriptions for the last query. + rowcount: Number of rows affected by the last query (-1 for SELECT queries). + arraysize: Default number of rows to fetch with fetchmany(). 
+ + Example: + >>> from pyathena.polars.cursor import PolarsCursor + >>> cursor = connection.cursor(PolarsCursor) + >>> cursor.execute("SELECT * FROM large_table") + >>> df = cursor.as_polars() # Returns polars.DataFrame + + # Optional: Get Arrow Table (requires pyarrow) + >>> table = cursor.as_arrow() + + # High-performance UNLOAD for large datasets + >>> cursor = connection.cursor(PolarsCursor, unload=True) + >>> cursor.execute("SELECT * FROM huge_table") + >>> df = cursor.as_polars() # Faster Parquet-based result + + Note: + Requires polars to be installed. PyArrow is optional and only + needed for as_arrow() functionality. + """ + + def __init__( + self, + s3_staging_dir: Optional[str] = None, + schema_name: Optional[str] = None, + catalog_name: Optional[str] = None, + work_group: Optional[str] = None, + poll_interval: float = 1, + encryption_option: Optional[str] = None, + kms_key: Optional[str] = None, + kill_on_interrupt: bool = True, + unload: bool = False, + result_reuse_enable: bool = False, + result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, + on_start_query_execution: Optional[Callable[[str], None]] = None, + block_size: Optional[int] = None, + cache_type: Optional[str] = None, + max_workers: int = (cpu_count() or 1) * 5, + **kwargs, + ) -> None: + """Initialize a PolarsCursor. + + Args: + s3_staging_dir: S3 location for query results. + schema_name: Default schema name. + catalog_name: Default catalog name. + work_group: Athena workgroup name. + poll_interval: Query status polling interval in seconds. + encryption_option: S3 encryption option (SSE_S3, SSE_KMS, CSE_KMS). + kms_key: KMS key ARN for encryption. + kill_on_interrupt: Cancel running query on keyboard interrupt. + unload: Enable UNLOAD for high-performance Parquet output. + result_reuse_enable: Enable Athena query result reuse. + result_reuse_minutes: Minutes to reuse cached results. + on_start_query_execution: Callback invoked when query starts. 
+ block_size: S3 read block size. + cache_type: S3 caching strategy. + max_workers: Maximum worker threads for parallel S3 operations. + **kwargs: Additional connection parameters. + + Example: + >>> cursor = connection.cursor(PolarsCursor, unload=True) + """ + super().__init__( + s3_staging_dir=s3_staging_dir, + schema_name=schema_name, + catalog_name=catalog_name, + work_group=work_group, + poll_interval=poll_interval, + encryption_option=encryption_option, + kms_key=kms_key, + kill_on_interrupt=kill_on_interrupt, + result_reuse_enable=result_reuse_enable, + result_reuse_minutes=result_reuse_minutes, + **kwargs, + ) + self._unload = unload + self._on_start_query_execution = on_start_query_execution + self._block_size = block_size + self._cache_type = cache_type + self._max_workers = max_workers + self._query_id: Optional[str] = None + self._result_set: Optional[AthenaPolarsResultSet] = None + + @staticmethod + def get_default_converter( + unload: bool = False, + ) -> Union[DefaultPolarsTypeConverter, DefaultPolarsUnloadTypeConverter, Any]: + """Get the default type converter for Polars results. + + Args: + unload: If True, returns converter for UNLOAD (Parquet) results. + + Returns: + Type converter appropriate for the result format. + """ + if unload: + return DefaultPolarsUnloadTypeConverter() + return DefaultPolarsTypeConverter() + + @property + def arraysize(self) -> int: + """Get the number of rows to fetch per batch.""" + return self._arraysize + + @arraysize.setter + def arraysize(self, value: int) -> None: + """Set the number of rows to fetch per batch. + + Args: + value: Number of rows to fetch. Must be positive. + + Raises: + ProgrammingError: If value is not positive. 
+ """ + if value <= 0: + raise ProgrammingError("arraysize must be a positive integer value.") + self._arraysize = value + + @property # type: ignore + def result_set(self) -> Optional[AthenaPolarsResultSet]: + """Get the current result set.""" + return self._result_set + + @result_set.setter + def result_set(self, val) -> None: + """Set the current result set.""" + self._result_set = val + + @property + def query_id(self) -> Optional[str]: + """Get the current query execution ID.""" + return self._query_id + + @query_id.setter + def query_id(self, val) -> None: + """Set the current query execution ID.""" + self._query_id = val + + @property + def rownumber(self) -> Optional[int]: + """Get the current row number in the result set.""" + return self.result_set.rownumber if self.result_set else None + + @property + def rowcount(self) -> int: + """Get the number of rows affected by the last operation.""" + return self.result_set.rowcount if self.result_set else -1 + + def close(self) -> None: + """Close the cursor and release resources.""" + if self.result_set and not self.result_set.is_closed: + self.result_set.close() + + def execute( + self, + operation: str, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, + work_group: Optional[str] = None, + s3_staging_dir: Optional[str] = None, + cache_size: Optional[int] = 0, + cache_expiration_time: Optional[int] = 0, + result_reuse_enable: Optional[bool] = None, + result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, + on_start_query_execution: Optional[Callable[[str], None]] = None, + **kwargs, + ) -> "PolarsCursor": + """Execute a SQL query and return results as Polars DataFrames. + + Executes the SQL query on Amazon Athena and configures the result set + for Polars DataFrame output using Polars' native reading capabilities. + + Args: + operation: SQL query string to execute. + parameters: Query parameters for parameterized queries. 
+ work_group: Athena workgroup to use for this query. + s3_staging_dir: S3 location for query results. + cache_size: Number of queries to check for result caching. + cache_expiration_time: Cache expiration time in seconds. + result_reuse_enable: Enable Athena result reuse for this query. + result_reuse_minutes: Minutes to reuse cached results. + paramstyle: Parameter style ('qmark' or 'pyformat'). + on_start_query_execution: Callback called when query starts. + **kwargs: Additional execution parameters passed to Polars read functions. + + Returns: + Self reference for method chaining. + + Example: + >>> cursor.execute("SELECT * FROM sales WHERE year = 2023") + >>> df = cursor.as_polars() # Returns Polars DataFrame + """ + self._reset_state() + if self._unload: + s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir + assert s3_staging_dir, "If the unload option is used, s3_staging_dir is required." + operation, unload_location = self._formatter.wrap_unload( + operation, + s3_staging_dir=s3_staging_dir, + format_=AthenaFileFormat.FILE_FORMAT_PARQUET, + compression=AthenaCompression.COMPRESSION_SNAPPY, + ) + else: + unload_location = None + self.query_id = self._execute( + operation, + parameters=parameters, + work_group=work_group, + s3_staging_dir=s3_staging_dir, + cache_size=cache_size, + cache_expiration_time=cache_expiration_time, + result_reuse_enable=result_reuse_enable, + result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, + ) + + # Call user callbacks immediately after start_query_execution + # Both connection-level and execute-level callbacks are invoked if set + if self._on_start_query_execution: + self._on_start_query_execution(self.query_id) + if on_start_query_execution: + on_start_query_execution(self.query_id) + query_execution = cast(AthenaQueryExecution, self._poll(self.query_id)) + if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: + self.result_set = AthenaPolarsResultSet( + 
connection=self._connection, + converter=self._converter, + query_execution=query_execution, + arraysize=self.arraysize, + retry_config=self._retry_config, + unload=self._unload, + unload_location=unload_location, + block_size=self._block_size, + cache_type=self._cache_type, + max_workers=self._max_workers, + **kwargs, + ) + else: + raise OperationalError(query_execution.state_change_reason) + return self + + def executemany( + self, + operation: str, + seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + **kwargs, + ) -> None: + """Execute a SQL query multiple times with different parameters. + + Args: + operation: SQL query string to execute. + seq_of_parameters: Sequence of parameter sets. + **kwargs: Additional execution parameters. + """ + for parameters in seq_of_parameters: + self.execute(operation, parameters, **kwargs) + # Operations that have result sets are not allowed with executemany. + self._reset_state() + + def cancel(self) -> None: + """Cancel the currently running query. + + Raises: + ProgrammingError: If no query is currently running. + """ + if not self.query_id: + raise ProgrammingError("QueryExecutionId is none or empty.") + self._cancel(self.query_id) + + def fetchone( + self, + ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + """Fetch the next row of the query result. + + Returns: + A single row as a tuple, or None if no more rows are available. + + Raises: + ProgrammingError: If no result set is available. + """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaPolarsResultSet, self.result_set) + return result_set.fetchone() + + def fetchmany( + self, size: Optional[int] = None + ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + """Fetch the next set of rows of the query result. + + Args: + size: Number of rows to fetch. Defaults to arraysize. + + Returns: + A list of rows as tuples. 
+ + Raises: + ProgrammingError: If no result set is available. + """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaPolarsResultSet, self.result_set) + return result_set.fetchmany(size) + + def fetchall( + self, + ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + """Fetch all remaining rows of the query result. + + Returns: + A list of all remaining rows as tuples. + + Raises: + ProgrammingError: If no result set is available. + """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaPolarsResultSet, self.result_set) + return result_set.fetchall() + + def as_polars(self) -> "pl.DataFrame": + """Return query results as a Polars DataFrame. + + Returns the query results as a Polars DataFrame. This is the primary + method for accessing results with PolarsCursor. + + Returns: + Polars DataFrame containing all query results. + + Raises: + ProgrammingError: If no query has been executed or no results are available. + + Example: + >>> cursor = connection.cursor(PolarsCursor) + >>> cursor.execute("SELECT * FROM my_table") + >>> df = cursor.as_polars() + >>> print(f"DataFrame has {df.height} rows and {df.width} columns") + >>> filtered = df.filter(pl.col("value") > 100) + """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaPolarsResultSet, self.result_set) + return result_set.as_polars() + + def as_arrow(self) -> "Table": + """Return query results as an Apache Arrow Table. + + Converts the Polars DataFrame to an Apache Arrow Table for + interoperability with other Arrow-compatible tools and libraries. + + Returns: + Apache Arrow Table containing all query results. + + Raises: + ProgrammingError: If no query has been executed or no results are available. + ImportError: If pyarrow is not installed. 
+ + Example: + >>> cursor = connection.cursor(PolarsCursor) + >>> cursor.execute("SELECT * FROM my_table") + >>> table = cursor.as_arrow() + >>> print(f"Table has {table.num_rows} rows and {table.num_columns} columns") + """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaPolarsResultSet, self.result_set) + return result_set.as_arrow() diff --git a/pyathena/polars/result_set.py b/pyathena/polars/result_set.py new file mode 100644 index 00000000..6c4d18c2 --- /dev/null +++ b/pyathena/polars/result_set.py @@ -0,0 +1,423 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import logging +from multiprocessing import cpu_count +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, +) + +from pyathena import OperationalError +from pyathena.converter import Converter +from pyathena.error import ProgrammingError +from pyathena.model import AthenaQueryExecution +from pyathena.polars.util import to_column_info +from pyathena.result_set import AthenaResultSet +from pyathena.util import RetryConfig + +if TYPE_CHECKING: + import polars as pl + from pyarrow import Table + + from pyathena.connection import Connection + +_logger = logging.getLogger(__name__) + + +class AthenaPolarsResultSet(AthenaResultSet): + """Result set that provides Polars DataFrame results with optional Arrow interoperability. + + This result set handles CSV and Parquet result files from S3, converting them to + Polars DataFrames using Polars' native reading capabilities. It does not require + PyArrow for basic functionality, but can optionally provide Arrow Table access + when PyArrow is installed. 
+ + Features: + - Native Polars CSV and Parquet reading (no PyArrow required) + - Efficient columnar data processing with Polars + - Optional Arrow interoperability when PyArrow is available + - Support for both CSV and Parquet result formats + - Optimized memory usage through columnar format + + Example: + >>> # Used automatically by PolarsCursor + >>> cursor = connection.cursor(PolarsCursor) + >>> cursor.execute("SELECT * FROM large_table") + >>> + >>> # Get Polars DataFrame + >>> df = cursor.as_polars() + >>> + >>> # Work with Polars + >>> print(f"DataFrame has {df.height} rows and {df.width} columns") + >>> filtered = df.filter(pl.col("value") > 100) + >>> + >>> # Optional: Get Arrow Table (requires pyarrow) + >>> table = cursor.as_arrow() + + Note: + This class is used internally by PolarsCursor and typically not + instantiated directly by users. Requires polars to be installed. + PyArrow is optional and only needed for as_arrow() functionality. + """ + + def __init__( + self, + connection: "Connection[Any]", + converter: Converter, + query_execution: AthenaQueryExecution, + arraysize: int, + retry_config: RetryConfig, + unload: bool = False, + unload_location: Optional[str] = None, + block_size: Optional[int] = None, + cache_type: Optional[str] = None, + max_workers: int = (cpu_count() or 1) * 5, + **kwargs, + ) -> None: + """Initialize the Polars result set. + + Args: + connection: The Athena connection object. + converter: Type converter for Athena data types. + query_execution: Query execution metadata. + arraysize: Number of rows to fetch per batch. + retry_config: Configuration for retry behavior. + unload: Whether this is an UNLOAD query result. + unload_location: S3 location for UNLOAD results. + block_size: Block size for S3 file reading. + cache_type: Cache type for S3 file system. + max_workers: Maximum number of worker threads. + **kwargs: Additional arguments passed to Polars read functions. 
+ """ + super().__init__( + connection=connection, + converter=converter, + query_execution=query_execution, + arraysize=1, # Fetch one row to retrieve metadata + retry_config=retry_config, + ) + self._rows.clear() # Clear pre_fetch data + self._arraysize = arraysize + self._unload = unload + self._unload_location = unload_location + self._block_size = block_size + self._cache_type = cache_type + self._max_workers = max_workers + self._kwargs = kwargs + if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location: + self._df = self._as_polars() + else: + import polars as pl + + self._df = pl.DataFrame() + self._row_index = 0 + + @property + def _csv_storage_options(self) -> Dict[str, Any]: + """Get storage options for Polars CSV reading via fsspec. + + Polars read_csv uses fsspec for cloud storage access, which works + with PyAthena's registered S3FileSystem. + + Returns: + Dictionary with fsspec-compatible options for S3 access. + """ + return { + "connection": self.connection, + "default_block_size": self._block_size, + "default_cache_type": self._cache_type, + "max_workers": self._max_workers, + } + + @property + def _parquet_storage_options(self) -> Dict[str, Any]: + """Get storage options for Polars Parquet reading via native object_store. + + Polars read_parquet uses Rust's native object_store crate, which requires + AWS credentials to be passed directly rather than through fsspec. + + Returns: + Dictionary with AWS credentials and region for S3 access. 
+        """
+        credentials = self.connection.session.get_credentials()
+        options: Dict[str, Any] = {}
+        if credentials:
+            frozen_credentials = credentials.get_frozen_credentials()
+            options["aws_access_key_id"] = frozen_credentials.access_key
+            options["aws_secret_access_key"] = frozen_credentials.secret_key
+            if frozen_credentials.token:
+                options["aws_session_token"] = frozen_credentials.token
+        if self.connection.region_name:
+            options["aws_region"] = self.connection.region_name
+        return options
+
+    @property
+    def dtypes(self) -> Dict[str, Any]:
+        """Get Polars-compatible data types for result columns."""
+        description = self.description if self.description else []
+        return {
+            d[0]: dtype
+            for d in description
+            if (dtype := self._converter.get_dtype(d[1], d[4], d[5])) is not None
+        }
+
+    @property
+    def converters(self) -> Dict[str, Callable[[Optional[str]], Optional[Any]]]:
+        """Get converter functions for each column.
+
+        Returns:
+            Dictionary mapping column names to their converter functions.
+        """
+        description = self.description if self.description else []
+        return {d[0]: self._converter.get(d[1]) for d in description}
+
+    def _fetch(self) -> None:
+        """Fetch rows from the DataFrame into the row buffer."""
+        if self._row_index >= self._df.height:
+            return
+
+        end_index = min(self._row_index + self._arraysize, self._df.height)
+        chunk = self._df.slice(self._row_index, end_index - self._row_index)
+        self._row_index = end_index
+
+        # Convert to rows and apply converters
+        column_names = [d[0] for d in (self.description or [])]
+        converters = self.converters  # hoist: the property rebuilds its dict on every access
+        for row_dict in chunk.iter_rows(named=True):
+            processed_row = tuple(
+                converters.get(col, lambda x: x)(row_dict.get(col)) for col in column_names
+            )
+            self._rows.append(processed_row)
+
+    def fetchone(
+        self,
+    ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
+        """Fetch the next row of the query result.
+
+        Returns:
+            A single row as a tuple, or None if no more rows are available.
+        """
+        if not self._rows:
+            self._fetch()
+        if not self._rows:
+            return None
+        if self._rownumber is None:
+            self._rownumber = 0
+        self._rownumber += 1
+        return self._rows.popleft()
+
+    def fetchmany(
+        self, size: Optional[int] = None
+    ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
+        """Fetch the next set of rows of the query result.
+
+        Args:
+            size: Number of rows to fetch. Defaults to arraysize.
+
+        Returns:
+            A list of rows as tuples.
+        """
+        if not size or size <= 0:
+            size = self._arraysize
+        rows = []
+        for _ in range(size):
+            row = self.fetchone()
+            if row is not None:  # explicit sentinel check: a zero-column row () is falsy but valid
+                rows.append(row)
+            else:
+                break
+        return rows
+
+    def fetchall(
+        self,
+    ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
+        """Fetch all remaining rows of the query result.
+
+        Returns:
+            A list of all remaining rows as tuples.
+        """
+        rows = []
+        while True:
+            row = self.fetchone()
+            if row is not None:
+                rows.append(row)
+            else:
+                break
+        return rows
+
+    def _read_csv(self) -> "pl.DataFrame":
+        """Read query results from CSV file in S3.
+
+        Returns:
+            Polars DataFrame containing the CSV data.
+
+        Raises:
+            ProgrammingError: If output location is not set.
+            OperationalError: If reading the CSV file fails.
+ """ + import polars as pl + + if not self.output_location: + raise ProgrammingError("OutputLocation is none or empty.") + if not self.output_location.endswith((".csv", ".txt")): + return pl.DataFrame() + if self.substatement_type and self.substatement_type.upper() in ( + "UPDATE", + "DELETE", + "MERGE", + "VACUUM_TABLE", + ): + return pl.DataFrame() + length = self._get_content_length() + if length == 0: + return pl.DataFrame() + + if self.output_location.endswith(".txt"): + separator = "\t" + has_header = False + description = self.description if self.description else [] + new_columns = [d[0] for d in description] + elif self.output_location.endswith(".csv"): + separator = "," + has_header = True + new_columns = None + else: + return pl.DataFrame() + + try: + df = pl.read_csv( + self.output_location, + separator=separator, + has_header=has_header, + schema_overrides=self.dtypes, + storage_options=self._csv_storage_options, + **self._kwargs, + ) + if new_columns: + df.columns = new_columns + return df + except Exception as e: + _logger.exception(f"Failed to read {self.output_location}.") + raise OperationalError(*e.args) from e + + def _read_parquet(self) -> "pl.DataFrame": + """Read query results from Parquet files in S3. + + Returns: + Polars DataFrame containing the Parquet data. + + Raises: + OperationalError: If reading the Parquet files fails. 
+ """ + import polars as pl + + manifests = self._read_data_manifest() + if not manifests: + return pl.DataFrame() + if not self._unload_location: + self._unload_location = "/".join(manifests[0].split("/")[:-1]) + "/" + + try: + return pl.read_parquet( + self._unload_location, + storage_options=self._parquet_storage_options, + **self._kwargs, + ) + except Exception as e: + _logger.exception(f"Failed to read {self._unload_location}.") + raise OperationalError(*e.args) from e + + def _read_parquet_schema(self) -> Tuple[Dict[str, Any], ...]: + """Read schema from Parquet files for metadata.""" + import polars as pl + + if not self._unload_location: + raise ProgrammingError("UnloadLocation is none or empty.") + + try: + # Use scan_parquet to get schema without reading all data + lazy_df = pl.scan_parquet( + self._unload_location, + storage_options=self._parquet_storage_options, + ) + schema = lazy_df.collect_schema() + return to_column_info(schema) + except Exception as e: + _logger.exception(f"Failed to read schema from {self._unload_location}.") + raise OperationalError(*e.args) from e + + def _as_polars(self) -> "pl.DataFrame": + """Load query results as a Polars DataFrame. + + Reads from Parquet for UNLOAD queries, otherwise from CSV. + + Returns: + Polars DataFrame containing the query results. + """ + if self.is_unload: + df = self._read_parquet() + if df.is_empty(): + self._metadata = () + else: + self._metadata = self._read_parquet_schema() + else: + df = self._read_csv() + return df + + def as_polars(self) -> "pl.DataFrame": + """Return query results as a Polars DataFrame. + + Returns the query results as a Polars DataFrame. This is the primary + method for accessing results with PolarsCursor. + + Returns: + Polars DataFrame containing all query results. 
+ + Example: + >>> cursor = connection.cursor(PolarsCursor) + >>> cursor.execute("SELECT * FROM my_table") + >>> df = cursor.as_polars() + >>> print(f"DataFrame has {df.height} rows") + >>> filtered = df.filter(pl.col("value") > 100) + """ + return self._df + + def as_arrow(self) -> "Table": + """Return query results as an Apache Arrow Table. + + Converts the Polars DataFrame to an Apache Arrow Table for + interoperability with other Arrow-compatible tools and libraries. + + Returns: + Apache Arrow Table containing all query results. + + Raises: + ImportError: If pyarrow is not installed. + + Example: + >>> cursor = connection.cursor(PolarsCursor) + >>> cursor.execute("SELECT * FROM my_table") + >>> table = cursor.as_arrow() + >>> # Use with other Arrow-compatible libraries + """ + try: + return self._df.to_arrow() + except ImportError as e: + raise ImportError( + "pyarrow is required for as_arrow(). Install it with: pip install pyarrow" + ) from e + + def close(self) -> None: + """Close the result set and release resources.""" + import polars as pl + + super().close() + self._df = pl.DataFrame() + self._row_index = 0 diff --git a/pyathena/polars/util.py b/pyathena/polars/util.py new file mode 100644 index 00000000..afc64b6b --- /dev/null +++ b/pyathena/polars/util.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +"""Utilities for converting Polars types to Athena metadata. + +This module provides functions to convert Polars schema and type information +to Athena-compatible column metadata, enabling proper type mapping when +reading query results in Polars format. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Tuple + +if TYPE_CHECKING: + import polars as pl + + +def to_column_info(schema: "pl.Schema") -> Tuple[Dict[str, Any], ...]: + """Convert a Polars schema to Athena column information. 
+ + Iterates through all fields in the schema and converts each field's + type information to an Athena-compatible column metadata dictionary. + + Args: + schema: A Polars Schema object containing field definitions. + + Returns: + A tuple of dictionaries, each containing column metadata with keys: + - Name: The column name + - Type: The Athena SQL type name + - Precision: Numeric precision (0 for non-numeric types) + - Scale: Numeric scale (0 for non-numeric types) + - Nullable: Always "NULLABLE" for Polars types + """ + columns = [] + for name, dtype in schema.items(): + type_, precision, scale = get_athena_type(dtype) + columns.append( + { + "Name": name, + "Type": type_, + "Precision": precision, + "Scale": scale, + "Nullable": "NULLABLE", + } + ) + return tuple(columns) + + +def get_athena_type(dtype: Any) -> Tuple[str, int, int]: + """Map a Polars data type to an Athena SQL type. + + Converts Polars type identifiers to corresponding Athena SQL type names + with appropriate precision and scale values. Handles all common Polars + types including numeric, string, binary, temporal, and complex types. + + Args: + dtype: A Polars DataType object to convert. + + Returns: + A tuple of (type_name, precision, scale) where: + - type_name: The Athena SQL type (e.g., "varchar", "bigint", "timestamp") + - precision: The numeric precision or max length + - scale: The numeric scale (decimal places) + + Note: + Unknown types default to "string" with maximum varchar length. + Decimal types preserve their original precision and scale. 
+ """ + import polars as pl + + # Use base_type() to handle parameterized types correctly + # (e.g., Datetime(time_unit="us") -> Datetime) + base_dtype = dtype.base_type() if hasattr(dtype, "base_type") else dtype + + # Type mapping: Polars type -> (Athena type, precision, scale) + type_mapping: Dict[Any, Tuple[str, int, int]] = { + pl.Boolean: ("boolean", 0, 0), + pl.Int8: ("tinyint", 3, 0), + pl.Int16: ("smallint", 5, 0), + pl.Int32: ("integer", 10, 0), + pl.Int64: ("bigint", 19, 0), + pl.UInt8: ("tinyint", 3, 0), + pl.UInt16: ("smallint", 5, 0), + pl.UInt32: ("integer", 10, 0), + pl.UInt64: ("bigint", 19, 0), + pl.Float32: ("float", 17, 0), + pl.Float64: ("double", 17, 0), + pl.String: ("varchar", 2147483647, 0), + pl.Utf8: ("varchar", 2147483647, 0), + pl.Date: ("date", 0, 0), + pl.Datetime: ("timestamp", 3, 0), + pl.Time: ("time", 0, 0), + pl.Binary: ("varbinary", 1073741824, 0), + } + + # Check base type using both base_dtype and original dtype + for polars_type, athena_info in type_mapping.items(): + if base_dtype == polars_type or dtype == polars_type: + return athena_info + + # Handle parameterized types that didn't match above + dtype_str = str(dtype).lower() + if "list" in dtype_str: + return ("array", 0, 0) + if "struct" in dtype_str: + return ("row", 0, 0) + if "decimal" in dtype_str: + # Extract precision and scale from Decimal type if available + if hasattr(dtype, "precision") and hasattr(dtype, "scale"): + return ("decimal", dtype.precision, dtype.scale) + return ("decimal", 38, 9) # Default precision and scale + + return ("string", 2147483647, 0) diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 15b97304..2168fcb4 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -242,6 +242,19 @@ def reused_previous_result(self) -> Optional[bool]: return None return self._query_execution.reused_previous_result + @property + def is_unload(self) -> bool: + """Check if the query is an UNLOAD statement. 
+ + Returns: + True if the query is an UNLOAD statement, False otherwise. + """ + return bool( + getattr(self, "_unload", False) + and self.query + and self.query.strip().upper().startswith("UNLOAD") + ) + @property def encryption_option(self) -> Optional[str]: if not self._query_execution: diff --git a/pyathena/spark/async_cursor.py b/pyathena/spark/async_cursor.py index f8b33ef5..634feca3 100644 --- a/pyathena/spark/async_cursor.py +++ b/pyathena/spark/async_cursor.py @@ -35,7 +35,6 @@ class AsyncSparkCursor(SparkBaseCursor): engine_configuration: Spark engine configuration settings. Example: - >>> import asyncio >>> from pyathena.spark.async_cursor import AsyncSparkCursor >>> >>> cursor = connection.cursor( @@ -55,10 +54,10 @@ class AsyncSparkCursor(SparkBaseCursor): >>> calculation_id, future = cursor.execute(spark_code) >>> >>> # Get result when ready - >>> calc_execution = await future + >>> calc_execution = future.result() >>> stdout_future = cursor.get_std_out(calc_execution) >>> if stdout_future: - ... output = await stdout_future + ... output = stdout_future.result() ... print(output) Note: diff --git a/pyathena/sqlalchemy/polars.py b/pyathena/sqlalchemy/polars.py new file mode 100644 index 00000000..702ae522 --- /dev/null +++ b/pyathena/sqlalchemy/polars.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +from typing import TYPE_CHECKING + +from pyathena.sqlalchemy.base import AthenaDialect +from pyathena.util import strtobool + +if TYPE_CHECKING: + from types import ModuleType + + +class AthenaPolarsDialect(AthenaDialect): + """SQLAlchemy dialect for Amazon Athena with Polars DataFrame result format. + + This dialect extends AthenaDialect to use PolarsCursor, which returns + query results as Polars DataFrames using Polars' native reading capabilities. + It does not require PyArrow for basic functionality, making it a lightweight + option for analytical workloads. 
+ + Connection URL Format: + ``awsathena+polars://{access_key}:{secret_key}@athena.{region}.amazonaws.com/{schema}`` + + Query Parameters: + In addition to the base dialect parameters: + - unload: If "true", use UNLOAD for Parquet output (better performance + for large datasets) + + Example: + >>> from sqlalchemy import create_engine + >>> engine = create_engine( + ... "awsathena+polars://:@athena.us-west-2.amazonaws.com/default" + ... "?s3_staging_dir=s3://my-bucket/athena-results/" + ... "&unload=true" + ... ) + + See Also: + :class:`~pyathena.polars.cursor.PolarsCursor`: The underlying cursor + implementation. + :class:`~pyathena.sqlalchemy.base.AthenaDialect`: Base dialect class. + """ + + driver = "polars" + supports_statement_cache = True + + def create_connect_args(self, url): + from pyathena.polars.cursor import PolarsCursor + + opts = super()._create_connect_args(url) + opts.update({"cursor_class": PolarsCursor}) + cursor_kwargs = {} + if "unload" in opts: + cursor_kwargs.update({"unload": bool(strtobool(opts.pop("unload")))}) + if cursor_kwargs: + opts.update({"cursor_kwargs": cursor_kwargs}) + return [[], opts] + + @classmethod + def import_dbapi(cls) -> "ModuleType": + return super().import_dbapi() diff --git a/pyproject.toml b/pyproject.toml index 4f4f1ac8..07c13fb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ awsathena = "pyathena.sqlalchemy.base:AthenaDialect" "awsathena.rest" = "pyathena.sqlalchemy.rest:AthenaRestDialect" "awsathena.pandas" = "pyathena.sqlalchemy.pandas:AthenaPandasDialect" "awsathena.arrow" = "pyathena.sqlalchemy.arrow:AthenaArrowDialect" +"awsathena.polars" = "pyathena.sqlalchemy.polars:AthenaPolarsDialect" "awsathena.s3fs" = "pyathena.sqlalchemy.s3fs:AthenaS3FSDialect" [project.optional-dependencies] @@ -50,6 +51,9 @@ arrow = [ "pyarrow>=10.0.0; python_version<'3.14'", "pyarrow>=22.0.0; python_version>='3.14'", ] +polars = [ + "polars>=1.0.0", +] [dependency-groups] dev = [ @@ -57,9 +61,10 @@ dev = [ 
"pandas>=1.3.0; python_version<'3.13'", "pandas>=2.3.0; python_version>='3.13'", "numpy>=1.26.0; python_version<'3.13'", - "numpy>=2.3.0; python_version>='3.13'", + "numpy>=2.3.0; python_version>='3.13'", "pyarrow>=10.0.0; python_version<'3.14'", "pyarrow>=22.0.0; python_version>='3.14'", + "polars>=1.0.0", "Jinja2>=3.1.0", "mypy>=0.900", "pytest>=3.5", diff --git a/tests/pyathena/arrow/test_cursor.py b/tests/pyathena/arrow/test_cursor.py index 366a4bd6..655a69ea 100644 --- a/tests/pyathena/arrow/test_cursor.py +++ b/tests/pyathena/arrow/test_cursor.py @@ -8,6 +8,7 @@ from decimal import Decimal import pandas as pd +import polars as pl import pyarrow as pa import pytest @@ -249,6 +250,7 @@ def test_fetch_no_data(self, arrow_cursor): pytest.raises(ProgrammingError, arrow_cursor.fetchmany) pytest.raises(ProgrammingError, arrow_cursor.fetchall) pytest.raises(ProgrammingError, arrow_cursor.as_arrow) + pytest.raises(ProgrammingError, arrow_cursor.as_polars) @pytest.mark.parametrize( "arrow_cursor", [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], indirect=["arrow_cursor"], ) @@ -427,6 +429,180 @@ def test_complex_unload_as_arrow(self, arrow_cursor): ) ] + @pytest.mark.parametrize( + "arrow_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["arrow_cursor"], + ) + def test_as_polars(self, arrow_cursor): + df = arrow_cursor.execute("SELECT * FROM one_row").as_polars() + assert df.height == 1 + assert df.width == 1 + assert df.to_dicts() == [{"number_of_rows": 1}] + + @pytest.mark.parametrize( + "arrow_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["arrow_cursor"], + ) + def test_many_as_polars(self, arrow_cursor): + df = arrow_cursor.execute("SELECT * FROM many_rows").as_polars() + assert df.height == 10000 + assert df.width == 1 + assert df.to_dicts() == [{"a": i} for i in range(10000)] + + def test_complex_as_polars(self, arrow_cursor): + df = arrow_cursor.execute( + """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + 
,col_bigint + ,col_float + ,col_double + ,col_string + ,col_varchar + ,col_timestamp + ,CAST(col_timestamp AS time) AS col_time + ,col_date + ,col_binary + ,col_array + ,CAST(col_array AS json) AS col_array_json + ,col_map + ,CAST(col_map AS json) AS col_map_json + ,col_struct + ,col_decimal + FROM one_row_complex + """ + ).as_polars() + assert df.height == 1 + assert df.width == 19 + dtypes = tuple(df.dtypes) + assert dtypes == ( + pl.Boolean, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.Float32, + pl.Float64, + pl.String, + pl.String, + pl.Datetime("ms"), + pl.String, + pl.Datetime("ms"), + pl.String, + pl.String, + pl.String, + pl.String, + pl.String, + pl.String, + pl.String, + ) + rows = df.to_dicts() + assert rows == [ + { + "col_boolean": True, + "col_tinyint": 127, + "col_smallint": 32767, + "col_int": 2147483647, + "col_bigint": 9223372036854775807, + "col_float": 0.5, + "col_double": 0.25, + "col_string": "a string", + "col_varchar": "varchar", + "col_timestamp": datetime(2017, 1, 1, 0, 0, 0), + "col_time": "00:00:00.000", + "col_date": datetime(2017, 1, 2, 0, 0, 0), + "col_binary": "31 32 33", + "col_array": "[1, 2]", + "col_array_json": "[1,2]", + "col_map": "{1=2, 3=4}", + "col_map_json": '{"1":2,"3":4}', + "col_struct": "{a=1, b=2}", + "col_decimal": "0.1", + } + ] + + @pytest.mark.parametrize( + "arrow_cursor", + [ + { + "cursor_kwargs": {"unload": True}, + }, + ], + indirect=["arrow_cursor"], + ) + def test_complex_unload_as_polars(self, arrow_cursor): + # NOT_SUPPORTED: Unsupported Hive type: time + # NOT_SUPPORTED: Unsupported Hive type: json + df = arrow_cursor.execute( + """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_varchar + ,col_timestamp + ,col_date + ,col_binary + ,col_array + ,col_map + ,col_struct + ,col_decimal + FROM one_row_complex + """ + ).as_polars() + assert df.height == 1 + assert df.width == 16 + dtypes = tuple(df.dtypes) + assert dtypes == 
( + pl.Boolean, + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.Float32, + pl.Float64, + pl.String, + pl.String, + pl.Datetime("ns"), + pl.Date, + pl.Binary, + pl.List(pl.Int32), + pl.List(pl.Struct([pl.Field("key", pl.Int32), pl.Field("value", pl.Int32)])), + pl.Struct([pl.Field("a", pl.Int32), pl.Field("b", pl.Int32)]), + pl.Decimal(precision=10, scale=1), + ) + rows = df.to_dicts() + assert rows == [ + { + "col_boolean": True, + "col_tinyint": 127, + "col_smallint": 32767, + "col_int": 2147483647, + "col_bigint": 9223372036854775807, + "col_float": 0.5, + "col_double": 0.25, + "col_string": "a string", + "col_varchar": "varchar", + "col_timestamp": datetime(2017, 1, 1, 0, 0, 0), + "col_date": datetime(2017, 1, 2).date(), + "col_binary": b"123", + "col_array": [1, 2], + "col_map": [{"key": 1, "value": 2}, {"key": 3, "value": 4}], + "col_struct": {"a": 1, "b": 2}, + "col_decimal": Decimal("0.1"), + } + ] + def test_cancel(self, arrow_cursor): def cancel(c): time.sleep(random.randint(5, 10)) @@ -552,6 +728,7 @@ def test_executemany_fetch(self, arrow_cursor): pytest.raises(ProgrammingError, arrow_cursor.fetchmany) pytest.raises(ProgrammingError, arrow_cursor.fetchone) pytest.raises(ProgrammingError, arrow_cursor.as_arrow) + pytest.raises(ProgrammingError, arrow_cursor.as_polars) def test_iceberg_table(self, arrow_cursor): iceberg_table = "test_iceberg_table_arrow_cursor" diff --git a/tests/pyathena/conftest.py b/tests/pyathena/conftest.py index b98f6b76..c244bae1 100644 --- a/tests/pyathena/conftest.py +++ b/tests/pyathena/conftest.py @@ -184,6 +184,20 @@ def async_s3fs_cursor(request): yield from _cursor(AsyncS3FSCursor, request) +@pytest.fixture +def polars_cursor(request): + from pyathena.polars.cursor import PolarsCursor + + yield from _cursor(PolarsCursor, request) + + +@pytest.fixture +def async_polars_cursor(request): + from pyathena.polars.async_cursor import AsyncPolarsCursor + + yield from _cursor(AsyncPolarsCursor, request) + + @pytest.fixture def 
spark_cursor(request): from pyathena.spark.cursor import SparkCursor diff --git a/tests/pyathena/polars/__init__.py b/tests/pyathena/polars/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/tests/pyathena/polars/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/tests/pyathena/polars/test_async_cursor.py b/tests/pyathena/polars/test_async_cursor.py new file mode 100644 index 00000000..fd9d61ba --- /dev/null +++ b/tests/pyathena/polars/test_async_cursor.py @@ -0,0 +1,326 @@ +# -*- coding: utf-8 -*- +import contextlib +import random +import string +import time +from datetime import datetime +from random import randint + +import polars as pl +import pytest + +from pyathena.error import NotSupportedError, ProgrammingError +from pyathena.model import AthenaQueryExecution +from pyathena.polars.async_cursor import AsyncPolarsCursor +from pyathena.result_set import AthenaResultSet +from tests import ENV +from tests.pyathena.conftest import connect + + +class TestAsyncPolarsCursor: + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_fetchone(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute("SELECT * FROM one_row") + result_set = future.result() + assert result_set.rownumber == 0 + assert result_set.fetchone() == (1,) + assert result_set.rownumber == 1 + assert result_set.fetchone() is None + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_fetchmany(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute("SELECT * FROM many_rows LIMIT 15") + result_set = future.result() + assert len(result_set.fetchmany(10)) == 10 + assert len(result_set.fetchmany(10)) == 5 + + @pytest.mark.parametrize( + "async_polars_cursor", + 
[{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_fetchall(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute("SELECT * FROM one_row") + result_set = future.result() + assert result_set.fetchall() == [(1,)] + query_id, future = async_polars_cursor.execute("SELECT a FROM many_rows ORDER BY a") + result_set = future.result() + if async_polars_cursor._unload: + assert sorted(result_set.fetchall()) == [(i,) for i in range(10000)] + else: + assert result_set.fetchall() == [(i,) for i in range(10000)] + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_iterator(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute("SELECT * FROM one_row") + result_set = future.result() + assert list(result_set) == [(1,)] + pytest.raises(StopIteration, result_set.__next__) + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_arraysize(self, async_polars_cursor): + async_polars_cursor.arraysize = 5 + query_id, future = async_polars_cursor.execute("SELECT * FROM many_rows LIMIT 20") + result_set = future.result() + assert len(result_set.fetchmany()) == 5 + + def test_arraysize_default(self, async_polars_cursor): + assert async_polars_cursor.arraysize == AthenaResultSet.DEFAULT_FETCH_SIZE + + def test_invalid_arraysize(self, async_polars_cursor): + async_polars_cursor.arraysize = 10000 + assert async_polars_cursor.arraysize == 10000 + with pytest.raises(ProgrammingError): + async_polars_cursor.arraysize = -1 + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_description(self, 
async_polars_cursor): + query_id, future = async_polars_cursor.execute( + "SELECT CAST(1 AS INT) AS foobar FROM one_row" + ) + result_set = future.result() + assert result_set.fetchall() == [(1,)] + if async_polars_cursor._unload: + assert result_set.description == [("foobar", "integer", None, None, 10, 0, "NULLABLE")] + else: + assert result_set.description == [("foobar", "integer", None, None, 10, 0, "UNKNOWN")] + + future = async_polars_cursor.description(query_id) + description = future.result() + assert result_set.description == description + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_query_execution(self, async_polars_cursor): + query = "SELECT * FROM one_row" + query_id, future = async_polars_cursor.execute(query) + result_set = future.result() + + future = async_polars_cursor.query_execution(query_id) + query_execution = future.result() + + assert query_execution.database == ENV.schema + assert query_execution.catalog + assert query_execution.query_id + if async_polars_cursor._unload: + assert query_execution.query.startswith("UNLOAD") + assert query in query_execution.query + else: + assert query_execution.query == query + assert query_execution.statement_type == AthenaQueryExecution.STATEMENT_TYPE_DML + assert query_execution.work_group == ENV.default_work_group + assert query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED + assert query_execution.state_change_reason is None + assert query_execution.submission_date_time + assert isinstance(query_execution.submission_date_time, datetime) + assert query_execution.completion_date_time + assert isinstance(query_execution.completion_date_time, datetime) + assert query_execution.data_scanned_in_bytes + assert query_execution.engine_execution_time_in_millis + assert query_execution.query_queue_time_in_millis + assert query_execution.total_execution_time_in_millis + 
assert query_execution.output_location + assert query_execution.encryption_option is None + assert query_execution.kms_key is None + assert query_execution.selected_engine_version + assert query_execution.effective_engine_version + + assert result_set.database == query_execution.database + assert result_set.catalog == query_execution.catalog + assert result_set.query_id == query_execution.query_id + assert result_set.query == query_execution.query + assert result_set.statement_type == query_execution.statement_type + assert result_set.work_group == query_execution.work_group + assert result_set.state == query_execution.state + assert result_set.state_change_reason == query_execution.state_change_reason + assert result_set.submission_date_time == query_execution.submission_date_time + assert result_set.completion_date_time == query_execution.completion_date_time + assert result_set.data_scanned_in_bytes == query_execution.data_scanned_in_bytes + assert ( + result_set.engine_execution_time_in_millis + == query_execution.engine_execution_time_in_millis + ) + assert result_set.query_queue_time_in_millis == query_execution.query_queue_time_in_millis + assert ( + result_set.total_execution_time_in_millis + == query_execution.total_execution_time_in_millis + ) + assert ( + result_set.query_planning_time_in_millis + == query_execution.query_planning_time_in_millis + ) + assert ( + result_set.service_processing_time_in_millis + == query_execution.service_processing_time_in_millis + ) + assert result_set.output_location == query_execution.output_location + assert result_set.data_manifest_location == query_execution.data_manifest_location + assert result_set.encryption_option == query_execution.encryption_option + assert result_set.kms_key == query_execution.kms_key + assert result_set.selected_engine_version == query_execution.selected_engine_version + assert result_set.effective_engine_version == query_execution.effective_engine_version + + @pytest.mark.parametrize( + 
"async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_poll(self, async_polars_cursor): + query_id, _ = async_polars_cursor.execute("SELECT * FROM one_row") + future = async_polars_cursor.poll(query_id) + query_execution = future.result() + assert query_execution.state in [ + AthenaQueryExecution.STATE_QUEUED, + AthenaQueryExecution.STATE_RUNNING, + AthenaQueryExecution.STATE_SUCCEEDED, + AthenaQueryExecution.STATE_FAILED, + AthenaQueryExecution.STATE_CANCELLED, + ] + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_bad_query(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute( + "SELECT does_not_exist FROM this_really_does_not_exist" + ) + result_set = future.result() + assert result_set.state == AthenaQueryExecution.STATE_FAILED + assert result_set.state_change_reason is not None + assert result_set.error_category is not None + assert result_set.error_type is not None + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_as_polars(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute("SELECT * FROM one_row") + assert query_id is not None + df = future.result().as_polars() + assert isinstance(df, pl.DataFrame) + assert df.height == 1 + assert df.width == 1 + assert df.to_dicts() == [{"number_of_rows": 1}] + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_many_as_polars(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute("SELECT * FROM many_rows") + assert query_id is not None + df = 
future.result().as_polars() + assert isinstance(df, pl.DataFrame) + assert df.height == 10000 + assert df.width == 1 + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_as_arrow(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute("SELECT * FROM one_row") + assert query_id is not None + table = future.result().as_arrow() + assert table.num_rows == 1 + assert table.num_columns == 1 + + def test_cancel(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute( + """ + SELECT a.a * rand(), b.a * rand() + FROM many_rows a + CROSS JOIN many_rows b + """ + ) + time.sleep(randint(5, 10)) + async_polars_cursor.cancel(query_id) + result_set = future.result() + assert result_set.state == AthenaQueryExecution.STATE_CANCELLED + assert result_set.description is None + assert result_set.fetchone() is None + assert result_set.fetchmany() == [] + assert result_set.fetchall() == [] + + def test_open_close(self): + with contextlib.closing(connect()) as conn, conn.cursor(AsyncPolarsCursor): + pass + + def test_no_ops(self): + conn = connect() + cursor = conn.cursor(AsyncPolarsCursor) + pytest.raises(NotSupportedError, lambda: cursor.executemany("SELECT * FROM one_row", [])) + cursor.close() + conn.close() + + @pytest.mark.parametrize( + "async_polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["async_polars_cursor"], + ) + def test_empty_result(self, async_polars_cursor): + table = "test_polars_cursor_empty_result_" + "".join( + random.choices(string.ascii_lowercase + string.digits, k=10) + ) + query_id, future = async_polars_cursor.execute( + f""" + CREATE EXTERNAL TABLE IF NOT EXISTS + {ENV.schema}.{table} (number_of_rows INT) + ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + LINES TERMINATED BY '\n' STORED AS TEXTFILE + LOCATION 
'{ENV.s3_staging_dir}{ENV.schema}/{table}/' + """ + ) + assert query_id is not None + df = future.result().as_polars() + assert df.height == 0 + assert df.width == 0 + + @pytest.mark.parametrize( + "async_polars_cursor", + [ + { + "cursor_kwargs": {"unload": True}, + }, + ], + indirect=["async_polars_cursor"], + ) + def test_empty_result_unload(self, async_polars_cursor): + query_id, future = async_polars_cursor.execute( + """ + SELECT * FROM one_row LIMIT 0 + """ + ) + assert query_id is not None + df = future.result().as_polars() + assert df.height == 0 + assert df.width == 0 diff --git a/tests/pyathena/polars/test_cursor.py b/tests/pyathena/polars/test_cursor.py new file mode 100644 index 00000000..cd19c958 --- /dev/null +++ b/tests/pyathena/polars/test_cursor.py @@ -0,0 +1,449 @@ +# -*- coding: utf-8 -*- +import contextlib +import random +import string +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from decimal import Decimal + +import polars as pl +import pytest + +from pyathena.error import DatabaseError, ProgrammingError +from pyathena.polars.cursor import PolarsCursor +from pyathena.polars.result_set import AthenaPolarsResultSet +from tests import ENV +from tests.pyathena.conftest import connect + + +class TestPolarsCursor: + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_fetchone(self, polars_cursor): + polars_cursor.execute("SELECT * FROM one_row") + assert polars_cursor.rownumber == 0 + assert polars_cursor.fetchone() == (1,) + assert polars_cursor.rownumber == 1 + assert polars_cursor.fetchone() is None + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_fetchmany(self, polars_cursor): + polars_cursor.execute("SELECT * FROM many_rows LIMIT 15") + assert 
len(polars_cursor.fetchmany(10)) == 10 + assert len(polars_cursor.fetchmany(10)) == 5 + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_fetchall(self, polars_cursor): + polars_cursor.execute("SELECT * FROM one_row") + assert polars_cursor.fetchall() == [(1,)] + polars_cursor.execute("SELECT a FROM many_rows ORDER BY a") + if polars_cursor._unload: + assert sorted(polars_cursor.fetchall()) == [(i,) for i in range(10000)] + else: + assert polars_cursor.fetchall() == [(i,) for i in range(10000)] + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_iterator(self, polars_cursor): + polars_cursor.execute("SELECT * FROM one_row") + assert list(polars_cursor) == [(1,)] + pytest.raises(StopIteration, polars_cursor.__next__) + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_arraysize(self, polars_cursor): + polars_cursor.arraysize = 5 + polars_cursor.execute("SELECT * FROM many_rows LIMIT 20") + assert len(polars_cursor.fetchmany()) == 5 + + def test_arraysize_default(self, polars_cursor): + assert polars_cursor.arraysize == AthenaPolarsResultSet.DEFAULT_FETCH_SIZE + + def test_invalid_arraysize(self, polars_cursor): + polars_cursor.arraysize = 10000 + assert polars_cursor.arraysize == 10000 + with pytest.raises(ProgrammingError): + polars_cursor.arraysize = -1 + + def test_fetch_no_data(self, polars_cursor): + pytest.raises(ProgrammingError, polars_cursor.fetchone) + pytest.raises(ProgrammingError, polars_cursor.fetchmany) + pytest.raises(ProgrammingError, polars_cursor.fetchall) + pytest.raises(ProgrammingError, polars_cursor.as_polars) + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": 
{"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_as_polars(self, polars_cursor): + df = polars_cursor.execute("SELECT * FROM one_row").as_polars() + assert isinstance(df, pl.DataFrame) + assert df.height == 1 + assert df.width == 1 + assert df.to_dicts() == [{"number_of_rows": 1}] + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_many_as_polars(self, polars_cursor): + df = polars_cursor.execute("SELECT * FROM many_rows").as_polars() + assert isinstance(df, pl.DataFrame) + assert df.height == 10000 + assert df.width == 1 + + def test_complex(self, polars_cursor): + polars_cursor.execute( + """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_varchar + ,col_timestamp + ,col_date + ,col_decimal + FROM one_row_complex + """ + ) + assert polars_cursor.description == [ + ("col_boolean", "boolean", None, None, 0, 0, "UNKNOWN"), + ("col_tinyint", "tinyint", None, None, 3, 0, "UNKNOWN"), + ("col_smallint", "smallint", None, None, 5, 0, "UNKNOWN"), + ("col_int", "integer", None, None, 10, 0, "UNKNOWN"), + ("col_bigint", "bigint", None, None, 19, 0, "UNKNOWN"), + ("col_float", "float", None, None, 17, 0, "UNKNOWN"), + ("col_double", "double", None, None, 17, 0, "UNKNOWN"), + ("col_string", "varchar", None, None, 2147483647, 0, "UNKNOWN"), + ("col_varchar", "varchar", None, None, 10, 0, "UNKNOWN"), + ("col_timestamp", "timestamp", None, None, 3, 0, "UNKNOWN"), + ("col_date", "date", None, None, 0, 0, "UNKNOWN"), + ("col_decimal", "decimal", None, None, 10, 1, "UNKNOWN"), + ] + assert polars_cursor.fetchall() == [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "varchar", + datetime(2017, 1, 1, 0, 0, 0), + datetime(2017, 1, 2).date(), + Decimal("0.1"), + ) + ] + + 
@pytest.mark.parametrize( + "polars_cursor", + [ + { + "cursor_kwargs": {"unload": True}, + }, + ], + indirect=["polars_cursor"], + ) + def test_complex_unload(self, polars_cursor): + polars_cursor.execute( + """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_varchar + ,col_timestamp + ,col_date + ,col_decimal + FROM one_row_complex + """ + ) + assert polars_cursor.description == [ + ("col_boolean", "boolean", None, None, 0, 0, "NULLABLE"), + ("col_tinyint", "tinyint", None, None, 3, 0, "NULLABLE"), + ("col_smallint", "smallint", None, None, 5, 0, "NULLABLE"), + ("col_int", "integer", None, None, 10, 0, "NULLABLE"), + ("col_bigint", "bigint", None, None, 19, 0, "NULLABLE"), + ("col_float", "float", None, None, 17, 0, "NULLABLE"), + ("col_double", "double", None, None, 17, 0, "NULLABLE"), + ("col_string", "varchar", None, None, 2147483647, 0, "NULLABLE"), + ("col_varchar", "varchar", None, None, 2147483647, 0, "NULLABLE"), + ("col_timestamp", "timestamp", None, None, 3, 0, "NULLABLE"), + ("col_date", "date", None, None, 0, 0, "NULLABLE"), + ("col_decimal", "decimal", None, None, 10, 1, "NULLABLE"), + ] + assert polars_cursor.fetchall() == [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "varchar", + datetime(2017, 1, 1, 0, 0, 0), + datetime(2017, 1, 2).date(), + Decimal("0.1"), + ) + ] + + def test_complex_as_polars(self, polars_cursor): + df = polars_cursor.execute( + """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_varchar + ,col_timestamp + ,col_date + ,col_decimal + FROM one_row_complex + """ + ).as_polars() + assert isinstance(df, pl.DataFrame) + assert (df.height, df.width) == (1, 12) + assert df.schema == { + "col_boolean": pl.Boolean, + "col_tinyint": pl.Int8, + "col_smallint": pl.Int16, + "col_int": pl.Int32, + "col_bigint": pl.Int64, + "col_float": 
pl.Float32, + "col_double": pl.Float64, + "col_string": pl.String, + "col_varchar": pl.String, + "col_timestamp": pl.Datetime("us"), + "col_date": pl.Date, + "col_decimal": pl.Decimal(precision=10, scale=1), + } + assert df.row(0) == ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "varchar", + datetime(2017, 1, 1, 0, 0, 0), + datetime(2017, 1, 2).date(), + Decimal("0.1"), + ) + + @pytest.mark.parametrize( + "polars_cursor", + [ + { + "cursor_kwargs": {"unload": True}, + }, + ], + indirect=["polars_cursor"], + ) + def test_complex_unload_as_polars(self, polars_cursor): + df = polars_cursor.execute( + """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_varchar + ,col_timestamp + ,col_date + ,col_decimal + FROM one_row_complex + """ + ).as_polars() + assert isinstance(df, pl.DataFrame) + assert (df.height, df.width) == (1, 12) + assert df.row(0) == ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "varchar", + datetime(2017, 1, 1, 0, 0, 0), + datetime(2017, 1, 2).date(), + Decimal("0.1"), + ) + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_as_arrow(self, polars_cursor): + table = polars_cursor.execute("SELECT * FROM one_row").as_arrow() + assert table.num_rows == 1 + assert table.num_columns == 1 + + def test_cancel(self, polars_cursor): + def cancel(c): + time.sleep(random.randint(5, 10)) + c.cancel() + + with ThreadPoolExecutor(max_workers=1) as executor: + executor.submit(cancel, polars_cursor) + + pytest.raises( + DatabaseError, + lambda: polars_cursor.execute( + """ + SELECT a.a * rand(), b.a * rand() + FROM many_rows a + CROSS JOIN many_rows b + """ + ), + ) + + def test_cancel_initial(self, polars_cursor): + pytest.raises(ProgrammingError, polars_cursor.cancel) + + def 
test_open_close(self): + with contextlib.closing(connect()) as conn, conn.cursor(PolarsCursor): + pass + + def test_no_ops(self): + conn = connect() + cursor = conn.cursor(PolarsCursor) + cursor.close() + conn.close() + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_show_columns(self, polars_cursor): + polars_cursor.execute("SHOW COLUMNS IN one_row") + assert polars_cursor.description == [("field", "string", None, None, 0, 0, "UNKNOWN")] + assert polars_cursor.fetchall() == [("number_of_rows ",)] + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_empty_result(self, polars_cursor): + table = "test_polars_cursor_empty_result_" + "".join( + random.choices(string.ascii_lowercase + string.digits, k=10) + ) + df = polars_cursor.execute( + f""" + CREATE EXTERNAL TABLE IF NOT EXISTS + {ENV.schema}.{table} (number_of_rows INT) + ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + LINES TERMINATED BY '\n' STORED AS TEXTFILE + LOCATION '{ENV.s3_staging_dir}{ENV.schema}/{table}/' + """ + ).as_polars() + assert df.height == 0 + assert df.width == 0 + + @pytest.mark.parametrize( + "polars_cursor", + [ + { + "cursor_kwargs": {"unload": True}, + }, + ], + indirect=["polars_cursor"], + ) + def test_empty_result_unload(self, polars_cursor): + df = polars_cursor.execute( + """ + SELECT * FROM one_row LIMIT 0 + """ + ).as_polars() + assert df.height == 0 + assert df.width == 0 + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_executemany(self, polars_cursor): + rows = [(1, "foo"), (2, "bar"), (3, "jim o'rourke")] + table_name = f"execute_many_polars{'_unload' if polars_cursor._unload else ''}" + polars_cursor.executemany( + 
f"INSERT INTO {table_name} (a, b) VALUES (%(a)d, %(b)s)", + [{"a": a, "b": b} for a, b in rows], + ) + polars_cursor.execute(f"SELECT * FROM {table_name}") + assert sorted(polars_cursor.fetchall()) == list(rows) + + @pytest.mark.parametrize( + "polars_cursor", + [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], + indirect=["polars_cursor"], + ) + def test_executemany_fetch(self, polars_cursor): + polars_cursor.executemany("SELECT %(x)d AS x FROM one_row", [{"x": i} for i in range(1, 2)]) + # Operations that have result sets are not allowed with executemany. + pytest.raises(ProgrammingError, polars_cursor.fetchall) + pytest.raises(ProgrammingError, polars_cursor.fetchmany) + pytest.raises(ProgrammingError, polars_cursor.fetchone) + pytest.raises(ProgrammingError, polars_cursor.as_polars) + + def test_execute_with_callback(self, polars_cursor): + """Test that callback is invoked with query_id when on_start_query_execution is provided.""" + callback_results = [] + + def test_callback(query_id: str): + callback_results.append(query_id) + + polars_cursor.execute("SELECT 1", on_start_query_execution=test_callback) + + assert len(callback_results) == 1 + assert callback_results[0] == polars_cursor.query_id + assert polars_cursor.query_id is not None diff --git a/tests/resources/queries/create_table.sql.jinja2 b/tests/resources/queries/create_table.sql.jinja2 index 7c4f427f..138be616 100644 --- a/tests/resources/queries/create_table.sql.jinja2 +++ b/tests/resources/queries/create_table.sql.jinja2 @@ -107,6 +107,22 @@ CREATE EXTERNAL TABLE IF NOT EXISTS {{ schema }}.execute_many_arrow_unload ( ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE LOCATION '{{ s3_staging_dir }}{{ schema }}/execute_many_arrow_unload/'; +DROP TABLE IF EXISTS {{ schema }}.execute_many_polars; +CREATE EXTERNAL TABLE IF NOT EXISTS {{ schema }}.execute_many_polars ( + a INT, + b STRING +) +ROW FORMAT DELIMITED FIELDS TERMINATED BY 
'\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE +LOCATION '{{ s3_staging_dir }}{{ schema }}/execute_many_polars/'; + +DROP TABLE IF EXISTS {{ schema }}.execute_many_polars_unload; +CREATE EXTERNAL TABLE IF NOT EXISTS {{ schema }}.execute_many_polars_unload ( + a INT, + b STRING +) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE +LOCATION '{{ s3_staging_dir }}{{ schema }}/execute_many_polars_unload/'; + DROP TABLE IF EXISTS {{ schema }}.parquet_with_compression; CREATE EXTERNAL TABLE IF NOT EXISTS {{ schema }}.parquet_with_compression ( a INT diff --git a/uv.lock b/uv.lock index 52a9da3a..559382f4 100644 --- a/uv.lock +++ b/uv.lock @@ -738,6 +738,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "polars" +version = "1.36.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "polars-runtime-32" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9f/dc/56f2a90c79a2cb13f9e956eab6385effe54216ae7a2068b3a6406bae4345/polars-1.36.1.tar.gz", hash = "sha256:12c7616a2305559144711ab73eaa18814f7aa898c522e7645014b68f1432d54c", size = 711993 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/c6/36a1b874036b49893ecae0ac44a2f63d1a76e6212631a5b2f50a86e0e8af/polars-1.36.1-py3-none-any.whl", hash = "sha256:853c1bbb237add6a5f6d133c15094a9b727d66dd6a4eb91dbb07cdb056b2b8ef", size = 802429 }, +] + +[[package]] +name = "polars-runtime-32" +version = "1.36.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/31/df/597c0ef5eb8d761a16d72327846599b57c5d40d7f9e74306fc154aba8c37/polars_runtime_32-1.36.1.tar.gz", hash = "sha256:201c2cfd80ceb5d5cd7b63085b5fd08d6ae6554f922bcb941035e39638528a09", size = 
2788751 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/ea/871129a2d296966c0925b078a9a93c6c5e7facb1c5eebfcd3d5811aeddc1/polars_runtime_32-1.36.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:327b621ca82594f277751f7e23d4b939ebd1be18d54b4cdf7a2f8406cecc18b2", size = 43494311 }, + { url = "https://files.pythonhosted.org/packages/d8/76/0038210ad1e526ce5bb2933b13760d6b986b3045eccc1338e661bd656f77/polars_runtime_32-1.36.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:ab0d1f23084afee2b97de8c37aa3e02ec3569749ae39571bd89e7a8b11ae9e83", size = 39300602 }, + { url = "https://files.pythonhosted.org/packages/54/1e/2707bee75a780a953a77a2c59829ee90ef55708f02fc4add761c579bf76e/polars_runtime_32-1.36.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:899b9ad2e47ceb31eb157f27a09dbc2047efbf4969a923a6b1ba7f0412c3e64c", size = 44511780 }, + { url = "https://files.pythonhosted.org/packages/11/b2/3fede95feee441be64b4bcb32444679a8fbb7a453a10251583053f6efe52/polars_runtime_32-1.36.1-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:d9d077bb9df711bc635a86540df48242bb91975b353e53ef261c6fae6cb0948f", size = 40688448 }, + { url = "https://files.pythonhosted.org/packages/05/0f/e629713a72999939b7b4bfdbf030a32794db588b04fdf3dc977dd8ea6c53/polars_runtime_32-1.36.1-cp39-abi3-win_amd64.whl", hash = "sha256:cc17101f28c9a169ff8b5b8d4977a3683cd403621841623825525f440b564cf0", size = 44464898 }, + { url = "https://files.pythonhosted.org/packages/d1/d8/a12e6aa14f63784cead437083319ec7cece0d5bb9a5bfe7678cc6578b52a/polars_runtime_32-1.36.1-cp39-abi3-win_arm64.whl", hash = "sha256:809e73857be71250141225ddd5d2b30c97e6340aeaa0d445f930e01bef6888dc", size = 39798896 }, +] + [[package]] name = "pyarrow" version = "18.1.0" @@ -866,6 +892,9 @@ pandas = [ { name = "pandas", version = "2.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" }, { name = "pandas", version = "2.3.3", source = { registry = 
"https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, ] +polars = [ + { name = "polars" }, +] sqlalchemy = [ { name = "sqlalchemy" }, ] @@ -876,9 +905,10 @@ dev = [ { name = "jinja2" }, { name = "mypy" }, { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" }, - { name = "numpy", version = "2.3.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, + { name = "numpy", version = "2.3.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, { name = "pandas", version = "2.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" }, { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, + { name = "polars" }, { name = "pyarrow", version = "18.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, { name = "pyarrow", version = "22.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, { name = "pytest" }, @@ -898,6 +928,7 @@ requires-dist = [ { name = "fsspec" }, { name = "pandas", marker = "python_full_version >= '3.13' and extra == 'pandas'", specifier = ">=2.3.0" }, { name = "pandas", marker = "python_full_version < '3.13' and extra == 'pandas'", specifier = ">=1.3.0" }, + { name = "polars", marker = "extra == 'polars'", specifier = ">=1.0.0" }, { name = "pyarrow", marker = "python_full_version >= '3.14' and extra == 'arrow'", specifier = ">=22.0.0" }, { name = "pyarrow", marker = "python_full_version < '3.14' and extra == 'arrow'", specifier = ">=10.0.0" }, { name = "python-dateutil" }, @@ -911,9 +942,10 @@ dev = [ { name = "jinja2", specifier = ">=3.1.0" }, { name = "mypy", specifier = ">=0.900" }, { name = "numpy", marker = "python_full_version < '3.13'", specifier = ">=1.26.0" }, - 
{ name = "numpy", marker = "python_full_version >= '3.13'", specifier = ">=2.3.0" }, + { name = "numpy", marker = "python_full_version >= '3.14'", specifier = ">=2.3.0" }, { name = "pandas", marker = "python_full_version < '3.13'", specifier = ">=1.3.0" }, { name = "pandas", marker = "python_full_version >= '3.13'", specifier = ">=2.3.0" }, + { name = "polars", specifier = ">=1.0.0" }, { name = "pyarrow", marker = "python_full_version < '3.14'", specifier = ">=10.0.0" }, { name = "pyarrow", marker = "python_full_version >= '3.14'", specifier = ">=22.0.0" }, { name = "pytest", specifier = ">=3.5" },