From b45c434fbc73859ce6d83a61116d6906a795e219 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Thu, 1 Jan 2026 14:28:05 +0900 Subject: [PATCH 1/6] Add S3FS Cursor for lightweight CSV-based result reading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements Issue #272: Add a new cursor type that reads CSV results from S3 using Python's standard csv module and PyAthena's S3FileSystem, without requiring pandas or pyarrow dependencies. New features: - S3FSCursor: Synchronous cursor for reading CSV/TXT results from S3 - AsyncS3FSCursor: Asynchronous cursor using concurrent.futures - AthenaS3FSResultSet: Streaming CSV reader with type conversion - DefaultS3FSTypeConverter: Type converter for CSV-based results - SQLAlchemy dialect: awsathena+s3fs:// connection URL support Also adds rowcount property to WithResultSet mixin for CTAS support, benefiting all cursor types (base, pandas, arrow, s3fs). Closes #272 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- pyathena/result_set.py | 13 + pyathena/s3fs/__init__.py | 1 + pyathena/s3fs/async_cursor.py | 218 ++++++++++++++++ pyathena/s3fs/converter.py | 50 ++++ pyathena/s3fs/cursor.py | 313 +++++++++++++++++++++++ pyathena/s3fs/result_set.py | 233 +++++++++++++++++ pyathena/sqlalchemy/s3fs.py | 36 +++ pyproject.toml | 1 + tests/pyathena/arrow/test_cursor.py | 17 ++ tests/pyathena/conftest.py | 14 + tests/pyathena/pandas/test_cursor.py | 17 ++ tests/pyathena/s3fs/__init__.py | 1 + tests/pyathena/s3fs/test_async_cursor.py | 175 +++++++++++++ tests/pyathena/s3fs/test_cursor.py | 277 ++++++++++++++++++++ tests/pyathena/test_cursor.py | 21 ++ tests/sqlalchemy/__init__.py | 1 + 16 files changed, 1388 insertions(+) create mode 100644 pyathena/s3fs/__init__.py create mode 100644 pyathena/s3fs/async_cursor.py create mode 100644 pyathena/s3fs/converter.py create mode 100644 pyathena/s3fs/cursor.py create mode 100644 
pyathena/s3fs/result_set.py create mode 100644 pyathena/sqlalchemy/s3fs.py create mode 100644 tests/pyathena/s3fs/__init__.py create mode 100644 tests/pyathena/s3fs/test_async_cursor.py create mode 100644 tests/pyathena/s3fs/test_cursor.py diff --git a/pyathena/result_set.py b/pyathena/result_set.py index f144de21..ae2c3972 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -736,3 +736,16 @@ def result_reuse_minutes(self) -> Optional[int]: if not self.result_set: return None return self.result_set.result_reuse_minutes + + @property + def rowcount(self) -> int: + """Get the number of rows affected by the last operation. + + For SELECT statements, this returns -1 as per DB API 2.0 specification. + For DML operations (INSERT, UPDATE, DELETE) and CTAS, this returns + the number of affected rows. + + Returns: + The number of rows, or -1 if not applicable or unknown. + """ + return self.result_set.rowcount if self.result_set else -1 diff --git a/pyathena/s3fs/__init__.py b/pyathena/s3fs/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/pyathena/s3fs/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/pyathena/s3fs/async_cursor.py b/pyathena/s3fs/async_cursor.py new file mode 100644 index 00000000..8de1379e --- /dev/null +++ b/pyathena/s3fs/async_cursor.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import logging +from concurrent.futures import Future +from multiprocessing import cpu_count +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from pyathena.async_cursor import AsyncCursor +from pyathena.common import CursorIterator +from pyathena.error import ProgrammingError +from pyathena.model import AthenaQueryExecution +from pyathena.s3fs.converter import DefaultS3FSTypeConverter +from pyathena.s3fs.result_set import AthenaS3FSResultSet + +_logger = logging.getLogger(__name__) + + +class AsyncS3FSCursor(AsyncCursor): + """Asynchronous cursor that reads CSV 
results via S3FileSystem. + + This cursor extends AsyncCursor to provide asynchronous query execution + with results read via Python's standard csv module and PyAthena's S3FileSystem. + It's a lightweight alternative when pandas/pyarrow are not needed. + + Features: + - Asynchronous query execution with concurrent futures + - Uses Python's standard csv module for parsing + - Uses PyAthena's S3FileSystem for S3 access + - No external dependencies beyond boto3 + - Memory-efficient streaming for large datasets + + Attributes: + arraysize: Number of rows to fetch per batch (configurable). + + Example: + >>> from pyathena.s3fs.async_cursor import AsyncS3FSCursor + >>> + >>> cursor = connection.cursor(AsyncS3FSCursor) + >>> query_id, future = cursor.execute("SELECT * FROM my_table") + >>> + >>> # Get result when ready + >>> result_set = future.result() + >>> rows = result_set.fetchall() + + Note: + This cursor does not require pandas or pyarrow. + """ + + def __init__( + self, + s3_staging_dir: Optional[str] = None, + schema_name: Optional[str] = None, + catalog_name: Optional[str] = None, + work_group: Optional[str] = None, + poll_interval: float = 1, + encryption_option: Optional[str] = None, + kms_key: Optional[str] = None, + kill_on_interrupt: bool = True, + max_workers: int = (cpu_count() or 1) * 5, + arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE, + result_reuse_enable: bool = False, + result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, + **kwargs, + ) -> None: + """Initialize an AsyncS3FSCursor. + + Args: + s3_staging_dir: S3 location for query results. + schema_name: Default schema name. + catalog_name: Default catalog name. + work_group: Athena workgroup name. + poll_interval: Query status polling interval in seconds. + encryption_option: S3 encryption option (SSE_S3, SSE_KMS, CSE_KMS). + kms_key: KMS key ARN for encryption. + kill_on_interrupt: Cancel running query on keyboard interrupt. 
+ max_workers: Maximum number of workers for concurrent execution. + arraysize: Number of rows to fetch per batch. + result_reuse_enable: Enable Athena query result reuse. + result_reuse_minutes: Minutes to reuse cached results. + **kwargs: Additional connection parameters. + + Example: + >>> cursor = connection.cursor(AsyncS3FSCursor) + >>> query_id, future = cursor.execute("SELECT * FROM my_table") + """ + super().__init__( + s3_staging_dir=s3_staging_dir, + schema_name=schema_name, + catalog_name=catalog_name, + work_group=work_group, + poll_interval=poll_interval, + encryption_option=encryption_option, + kms_key=kms_key, + kill_on_interrupt=kill_on_interrupt, + max_workers=max_workers, + arraysize=arraysize, + result_reuse_enable=result_reuse_enable, + result_reuse_minutes=result_reuse_minutes, + **kwargs, + ) + + @staticmethod + def get_default_converter( + unload: bool = False, # noqa: ARG004 + ) -> DefaultS3FSTypeConverter: + """Get the default type converter for S3FS cursor. + + Args: + unload: Unused. S3FS cursor does not support UNLOAD operations. + + Returns: + DefaultS3FSTypeConverter instance. + """ + return DefaultS3FSTypeConverter() + + @property + def arraysize(self) -> int: + """Get the number of rows to fetch at a time.""" + return self._arraysize + + @arraysize.setter + def arraysize(self, value: int) -> None: + """Set the number of rows to fetch at a time. + + Args: + value: Number of rows (must be positive). + + Raises: + ProgrammingError: If value is not positive. + """ + if value <= 0: + raise ProgrammingError("arraysize must be a positive integer value.") + self._arraysize = value + + def _collect_result_set( + self, + query_id: str, + kwargs: Optional[Dict[str, Any]] = None, + ) -> AthenaS3FSResultSet: + """Collect result set after query execution. + + Args: + query_id: The Athena query execution ID. + kwargs: Additional keyword arguments for result set. + + Returns: + AthenaS3FSResultSet containing the query results. 
+ """ + if kwargs is None: + kwargs = {} + query_execution = cast(AthenaQueryExecution, self._poll(query_id)) + return AthenaS3FSResultSet( + connection=self._connection, + converter=self._converter, + query_execution=query_execution, + arraysize=self._arraysize, + retry_config=self._retry_config, + **kwargs, + ) + + def execute( + self, + operation: str, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, + work_group: Optional[str] = None, + s3_staging_dir: Optional[str] = None, + cache_size: Optional[int] = 0, + cache_expiration_time: Optional[int] = 0, + result_reuse_enable: Optional[bool] = None, + result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, + **kwargs, + ) -> Tuple[str, "Future[Union[AthenaS3FSResultSet, Any]]"]: + """Execute a SQL query asynchronously. + + Submits the query to Athena and returns immediately with a query ID + and a Future that will contain the result set when complete. + + Args: + operation: SQL query string to execute. + parameters: Query parameters for parameterized queries. + work_group: Athena workgroup to use for this query. + s3_staging_dir: S3 location for query results. + cache_size: Number of queries to check for result caching. + cache_expiration_time: Cache expiration time in seconds. + result_reuse_enable: Enable Athena result reuse for this query. + result_reuse_minutes: Minutes to reuse cached results. + paramstyle: Parameter style ('qmark' or 'pyformat'). + **kwargs: Additional execution parameters. + + Returns: + Tuple of (query_id, Future[AthenaS3FSResultSet]). 
+ + Example: + >>> query_id, future = cursor.execute("SELECT * FROM my_table") + >>> result_set = future.result() + >>> rows = result_set.fetchall() + """ + query_id = self._execute( + operation, + parameters=parameters, + work_group=work_group, + s3_staging_dir=s3_staging_dir, + cache_size=cache_size, + cache_expiration_time=cache_expiration_time, + result_reuse_enable=result_reuse_enable, + result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, + ) + return ( + query_id, + self._executor.submit( + self._collect_result_set, + query_id, + kwargs, + ), + ) diff --git a/pyathena/s3fs/converter.py b/pyathena/s3fs/converter.py new file mode 100644 index 00000000..0df11b3e --- /dev/null +++ b/pyathena/s3fs/converter.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import logging +from copy import deepcopy +from typing import Any, Optional + +from pyathena.converter import ( + _DEFAULT_CONVERTERS, + Converter, + _to_default, +) + +_logger = logging.getLogger(__name__) + + +class DefaultS3FSTypeConverter(Converter): + """Type converter for S3FS Cursor results. + + This converter is specifically designed for the S3FSCursor and provides + type conversion for CSV-based result files read via the S3 FileSystem. + It converts Athena data types to Python types using the standard + converter mappings. + + The converter uses the same mappings as DefaultTypeConverter, providing + consistent behavior with the standard Cursor while using the S3FileSystem + for file access. + + Example: + >>> from pyathena.s3fs.converter import DefaultS3FSTypeConverter + >>> converter = DefaultS3FSTypeConverter() + >>> + >>> # Used automatically by S3FSCursor + >>> cursor = connection.cursor(S3FSCursor) + >>> # converter is applied automatically to results + + Note: + This converter is used by default in S3FSCursor. + Most users don't need to instantiate it directly. 
+ """ + + def __init__(self) -> None: + super().__init__( + mappings=deepcopy(_DEFAULT_CONVERTERS), + default=_to_default, + ) + + def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + converter = self.get(type_) + return converter(value) diff --git a/pyathena/s3fs/cursor.py b/pyathena/s3fs/cursor.py new file mode 100644 index 00000000..33cef5cd --- /dev/null +++ b/pyathena/s3fs/cursor.py @@ -0,0 +1,313 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast + +from pyathena.common import BaseCursor, CursorIterator +from pyathena.error import OperationalError, ProgrammingError +from pyathena.model import AthenaQueryExecution +from pyathena.result_set import WithResultSet +from pyathena.s3fs.converter import DefaultS3FSTypeConverter +from pyathena.s3fs.result_set import AthenaS3FSResultSet + +_logger = logging.getLogger(__name__) + + +class S3FSCursor(BaseCursor, CursorIterator, WithResultSet): + """Cursor for reading CSV results via S3FileSystem without pandas/pyarrow. + + This cursor uses Python's standard csv module and PyAthena's S3FileSystem + to read query results from S3. It provides a lightweight alternative to + pandas and arrow cursors when those dependencies are not needed. + + The cursor is especially useful for: + - Environments where pandas/pyarrow installation is not desired + - Simple queries where advanced data processing is not required + - Memory-constrained environments + + Attributes: + description: Sequence of column descriptions for the last query. + rowcount: Number of rows affected by the last query (-1 for SELECT queries). + arraysize: Default number of rows to fetch with fetchmany(). 
+ + Example: + >>> from pyathena.s3fs.cursor import S3FSCursor + >>> cursor = connection.cursor(S3FSCursor) + >>> cursor.execute("SELECT * FROM my_table") + >>> rows = cursor.fetchall() # Returns list of tuples + >>> + >>> # Iterate over results + >>> for row in cursor.execute("SELECT * FROM my_table"): + ... print(row) + + # Use with SQLAlchemy + >>> from sqlalchemy import create_engine + >>> engine = create_engine("awsathena+s3fs://...") + """ + + def __init__( + self, + s3_staging_dir: Optional[str] = None, + schema_name: Optional[str] = None, + catalog_name: Optional[str] = None, + work_group: Optional[str] = None, + poll_interval: float = 1, + encryption_option: Optional[str] = None, + kms_key: Optional[str] = None, + kill_on_interrupt: bool = True, + result_reuse_enable: bool = False, + result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, + on_start_query_execution: Optional[Callable[[str], None]] = None, + **kwargs, + ) -> None: + """Initialize an S3FSCursor. + + Args: + s3_staging_dir: S3 location for query results. + schema_name: Default schema name. + catalog_name: Default catalog name. + work_group: Athena workgroup name. + poll_interval: Query status polling interval in seconds. + encryption_option: S3 encryption option (SSE_S3, SSE_KMS, CSE_KMS). + kms_key: KMS key ARN for encryption. + kill_on_interrupt: Cancel running query on keyboard interrupt. + result_reuse_enable: Enable Athena query result reuse. + result_reuse_minutes: Minutes to reuse cached results. + on_start_query_execution: Callback invoked when query starts. + **kwargs: Additional connection parameters. 
+ + Example: + >>> cursor = connection.cursor(S3FSCursor) + >>> cursor.execute("SELECT * FROM my_table") + """ + super().__init__( + s3_staging_dir=s3_staging_dir, + schema_name=schema_name, + catalog_name=catalog_name, + work_group=work_group, + poll_interval=poll_interval, + encryption_option=encryption_option, + kms_key=kms_key, + kill_on_interrupt=kill_on_interrupt, + result_reuse_enable=result_reuse_enable, + result_reuse_minutes=result_reuse_minutes, + **kwargs, + ) + self._on_start_query_execution = on_start_query_execution + self._query_id: Optional[str] = None + self._result_set: Optional[AthenaS3FSResultSet] = None + + @staticmethod + def get_default_converter( + unload: bool = False, # noqa: ARG004 + ) -> DefaultS3FSTypeConverter: + """Get the default type converter for S3FS cursor. + + Args: + unload: Unused. S3FS cursor does not support UNLOAD operations. + + Returns: + DefaultS3FSTypeConverter instance. + """ + return DefaultS3FSTypeConverter() + + @property + def arraysize(self) -> int: + """Get the number of rows to fetch at a time with fetchmany().""" + return self._arraysize + + @arraysize.setter + def arraysize(self, value: int) -> None: + """Set the number of rows to fetch at a time with fetchmany(). + + Args: + value: Number of rows (must be positive). + + Raises: + ProgrammingError: If value is not positive. 
+ """ + if value <= 0: + raise ProgrammingError("arraysize must be a positive integer value.") + self._arraysize = value + + @property # type: ignore + def result_set(self) -> Optional[AthenaS3FSResultSet]: + """Get the current result set.""" + return self._result_set + + @result_set.setter + def result_set(self, val) -> None: + """Set the current result set.""" + self._result_set = val + + @property + def query_id(self) -> Optional[str]: + """Get the ID of the last executed query.""" + return self._query_id + + @query_id.setter + def query_id(self, val) -> None: + """Set the query ID.""" + self._query_id = val + + @property + def rownumber(self) -> Optional[int]: + """Get the current row number (0-indexed).""" + return self.result_set.rownumber if self.result_set else None + + def close(self) -> None: + """Close the cursor and release resources.""" + if self.result_set and not self.result_set.is_closed: + self.result_set.close() + + def execute( + self, + operation: str, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, + work_group: Optional[str] = None, + s3_staging_dir: Optional[str] = None, + cache_size: Optional[int] = 0, + cache_expiration_time: Optional[int] = 0, + result_reuse_enable: Optional[bool] = None, + result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, + on_start_query_execution: Optional[Callable[[str], None]] = None, + **kwargs, + ) -> "S3FSCursor": + """Execute a SQL query and return results. + + Executes the SQL query on Amazon Athena and configures the result set + for CSV-based output via S3FileSystem. + + Args: + operation: SQL query string to execute. + parameters: Query parameters for parameterized queries. + work_group: Athena workgroup to use for this query. + s3_staging_dir: S3 location for query results. + cache_size: Number of queries to check for result caching. + cache_expiration_time: Cache expiration time in seconds. + result_reuse_enable: Enable Athena result reuse for this query. 
+ result_reuse_minutes: Minutes to reuse cached results. + paramstyle: Parameter style ('qmark' or 'pyformat'). + on_start_query_execution: Callback called when query starts. + **kwargs: Additional execution parameters. + + Returns: + Self reference for method chaining. + + Example: + >>> cursor.execute("SELECT * FROM my_table WHERE id = %(id)s", {"id": 123}) + >>> rows = cursor.fetchall() + """ + self._reset_state() + self.query_id = self._execute( + operation, + parameters=parameters, + work_group=work_group, + s3_staging_dir=s3_staging_dir, + cache_size=cache_size, + cache_expiration_time=cache_expiration_time, + result_reuse_enable=result_reuse_enable, + result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, + ) + + # Call user callbacks immediately after start_query_execution + if self._on_start_query_execution: + self._on_start_query_execution(self.query_id) + if on_start_query_execution: + on_start_query_execution(self.query_id) + + query_execution = cast(AthenaQueryExecution, self._poll(self.query_id)) + if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: + self.result_set = AthenaS3FSResultSet( + connection=self._connection, + converter=self._converter, + query_execution=query_execution, + arraysize=self.arraysize, + retry_config=self._retry_config, + **kwargs, + ) + else: + raise OperationalError(query_execution.state_change_reason) + return self + + def executemany( + self, + operation: str, + seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + **kwargs, + ) -> None: + """Execute a SQL query with multiple parameter sets. + + Args: + operation: SQL query string to execute. + seq_of_parameters: Sequence of parameter sets. + **kwargs: Additional execution parameters. + """ + for parameters in seq_of_parameters: + self.execute(operation, parameters, **kwargs) + # Operations that have result sets are not allowed with executemany. 
+ self._reset_state() + + def cancel(self) -> None: + """Cancel the currently running query. + + Raises: + ProgrammingError: If no query is running. + """ + if not self.query_id: + raise ProgrammingError("QueryExecutionId is none or empty.") + self._cancel(self.query_id) + + def fetchone( + self, + ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + """Fetch the next row of the result set. + + Returns: + A tuple representing the next row, or None if no more rows. + + Raises: + ProgrammingError: If no query has been executed. + """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaS3FSResultSet, self.result_set) + return result_set.fetchone() + + def fetchmany( + self, size: Optional[int] = None + ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + """Fetch the next set of rows of the result set. + + Args: + size: Maximum number of rows to fetch. Defaults to arraysize. + + Returns: + A list of tuples representing the rows. + + Raises: + ProgrammingError: If no query has been executed. + """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaS3FSResultSet, self.result_set) + return result_set.fetchmany(size) + + def fetchall( + self, + ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + """Fetch all remaining rows of the result set. + + Returns: + A list of tuples representing all remaining rows. + + Raises: + ProgrammingError: If no query has been executed. 
+ """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaS3FSResultSet, self.result_set) + return result_set.fetchall() diff --git a/pyathena/s3fs/result_set.py b/pyathena/s3fs/result_set.py new file mode 100644 index 00000000..30781d4d --- /dev/null +++ b/pyathena/s3fs/result_set.py @@ -0,0 +1,233 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import contextlib +import csv +import logging +from io import TextIOWrapper +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from pyathena.converter import Converter +from pyathena.error import OperationalError, ProgrammingError +from pyathena.filesystem.s3 import S3FileSystem +from pyathena.model import AthenaQueryExecution +from pyathena.result_set import AthenaResultSet +from pyathena.util import RetryConfig, parse_output_location + +if TYPE_CHECKING: + from pyathena.connection import Connection + +_logger = logging.getLogger(__name__) + + +class AthenaS3FSResultSet(AthenaResultSet): + """Result set that reads CSV results via S3FileSystem without pandas/pyarrow. + + This result set uses Python's standard csv module and PyAthena's S3FileSystem + to read query results from S3. It provides a lightweight alternative to pandas + and arrow cursors when those dependencies are not needed. + + Features: + - Uses Python's standard csv module for parsing + - Uses PyAthena's S3FileSystem for S3 access + - No external dependencies beyond boto3 + - Memory-efficient streaming for large datasets + + Attributes: + DEFAULT_BLOCK_SIZE: Default block size for S3 operations (128MB). + + Example: + >>> # Used automatically by S3FSCursor + >>> cursor = connection.cursor(S3FSCursor) + >>> cursor.execute("SELECT * FROM my_table") + >>> + >>> # Fetch results + >>> rows = cursor.fetchall() + + Note: + This class is used internally by S3FSCursor and typically not + instantiated directly by users. 
+ """ + + DEFAULT_FETCH_SIZE: int = 1000 + DEFAULT_BLOCK_SIZE = 1024 * 1024 * 128 + + def __init__( + self, + connection: "Connection[Any]", + converter: Converter, + query_execution: AthenaQueryExecution, + arraysize: int, + retry_config: RetryConfig, + block_size: Optional[int] = None, + **kwargs, + ) -> None: + super().__init__( + connection=connection, + converter=converter, + query_execution=query_execution, + arraysize=1, # Fetch one row to retrieve metadata + retry_config=retry_config, + ) + # Save pre-fetched rows (from Athena API) in case CSV reading is not available + pre_fetched_rows = list(self._rows) + self._rows.clear() + self._arraysize = arraysize + self._block_size = block_size if block_size else self.DEFAULT_BLOCK_SIZE + self._fs = self._create_s3_file_system() + self._csv_reader: Optional[Any] = None + self._csv_file: Optional[Any] = None + self._header_skipped = False + self._has_header = False # CSV files have headers, TXT files don't + + if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location: + self._init_csv_reader() + + # If CSV reader was not initialized (e.g., CTAS, DDL), + # fall back to pre-fetched data from Athena API + if not self._csv_reader and pre_fetched_rows: + self._rows.extend(pre_fetched_rows) + + def _create_s3_file_system(self) -> S3FileSystem: + """Create S3FileSystem using connection settings.""" + return S3FileSystem( + connection=self.connection, + default_block_size=self._block_size, + ) + + def _init_csv_reader(self) -> None: + """Initialize CSV reader for the output file.""" + if not self.output_location: + raise ProgrammingError("OutputLocation is none or empty.") + + if not self.output_location.endswith((".csv", ".txt")): + return + + # Skip for UPDATE/DELETE/MERGE/VACUUM operations + if self.substatement_type and self.substatement_type.upper() in ( + "UPDATE", + "DELETE", + "MERGE", + "VACUUM_TABLE", + ): + return + + length = self._get_content_length() + if not length: + return + + bucket, 
key = parse_output_location(self.output_location) + path = f"{bucket}/{key}" + + try: + self._csv_file = self._fs._open(path, mode="rb") + text_wrapper = TextIOWrapper(self._csv_file, encoding="utf-8") + + if self.output_location.endswith(".txt"): + # Tab-separated format (no header row) + self._csv_reader = csv.reader(text_wrapper, delimiter="\t") + self._has_header = False + else: + # Standard CSV format (has header row) + self._csv_reader = csv.reader(text_wrapper) + self._has_header = True + + except Exception as e: + _logger.exception(f"Failed to open {path}.") + raise OperationalError(*e.args) from e + + def _fetch(self) -> None: + """Fetch next batch of rows from CSV.""" + if not self._csv_reader: + return + + # Skip header row on first fetch (only for CSV files, not TXT) + if self._has_header and not self._header_skipped: + try: + next(self._csv_reader) + self._header_skipped = True + except StopIteration: + return + + description = self.description if self.description else [] + column_types = [d[1] for d in description] + + rows_fetched = 0 + while rows_fetched < self._arraysize: + try: + row = next(self._csv_reader) + except StopIteration: + break + + # Convert row values using converters + converted_row = tuple( + self._converter.convert(col_type, value if value != "" else None) + for col_type, value in zip(column_types, row, strict=False) + ) + self._rows.append(converted_row) + rows_fetched += 1 + + def fetchone( + self, + ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + """Fetch the next row of the result set. + + Returns: + A tuple representing the next row, or None if no more rows. 
+ """ + if not self._rows: + self._fetch() + if not self._rows: + return None + if self._rownumber is None: + self._rownumber = 0 + self._rownumber += 1 + return self._rows.popleft() + + def fetchmany( + self, size: Optional[int] = None + ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + """Fetch the next set of rows of the result set. + + Args: + size: Maximum number of rows to fetch. Defaults to arraysize. + + Returns: + A list of tuples representing the rows. + """ + if not size or size <= 0: + size = self._arraysize + rows = [] + for _ in range(size): + row = self.fetchone() + if row: + rows.append(row) + else: + break + return rows + + def fetchall( + self, + ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + """Fetch all remaining rows of the result set. + + Returns: + A list of tuples representing all remaining rows. + """ + rows = [] + while True: + row = self.fetchone() + if row: + rows.append(row) + else: + break + return rows + + def close(self) -> None: + """Close the result set and release resources.""" + super().close() + if self._csv_file: + with contextlib.suppress(Exception): + self._csv_file.close() + self._csv_file = None + self._csv_reader = None diff --git a/pyathena/sqlalchemy/s3fs.py b/pyathena/sqlalchemy/s3fs.py new file mode 100644 index 00000000..3e35f514 --- /dev/null +++ b/pyathena/sqlalchemy/s3fs.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +from typing import TYPE_CHECKING + +from pyathena.sqlalchemy.base import AthenaDialect + +if TYPE_CHECKING: + from types import ModuleType + + +class AthenaS3FSDialect(AthenaDialect): + """SQLAlchemy dialect for PyAthena with S3FS cursor. + + This dialect uses the S3FSCursor which reads CSV results via + PyAthena's S3FileSystem without requiring pandas or pyarrow. + + Example: + >>> from sqlalchemy import create_engine + >>> engine = create_engine( + ... "awsathena+s3fs://:@athena.us-east-1.amazonaws.com/database" + ... 
"?s3_staging_dir=s3://bucket/path" + ... ) + """ + + driver = "s3fs" + supports_statement_cache = True + + def create_connect_args(self, url): + from pyathena.s3fs.cursor import S3FSCursor + + opts = super()._create_connect_args(url) + opts.update({"cursor_class": S3FSCursor}) + return [[], opts] + + @classmethod + def import_dbapi(cls) -> "ModuleType": + return super().import_dbapi() diff --git a/pyproject.toml b/pyproject.toml index c30f1357..4f4f1ac8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ awsathena = "pyathena.sqlalchemy.base:AthenaDialect" "awsathena.rest" = "pyathena.sqlalchemy.rest:AthenaRestDialect" "awsathena.pandas" = "pyathena.sqlalchemy.pandas:AthenaPandasDialect" "awsathena.arrow" = "pyathena.sqlalchemy.arrow:AthenaArrowDialect" +"awsathena.s3fs" = "pyathena.sqlalchemy.s3fs:AthenaS3FSDialect" [project.optional-dependencies] sqlalchemy = ["sqlalchemy>=1.0.0"] diff --git a/tests/pyathena/arrow/test_cursor.py b/tests/pyathena/arrow/test_cursor.py index 193b2e1a..ec4072f0 100644 --- a/tests/pyathena/arrow/test_cursor.py +++ b/tests/pyathena/arrow/test_cursor.py @@ -508,6 +508,23 @@ def test_empty_result_unload(self, arrow_cursor): assert table.shape[0] == 0 assert table.shape[1] == 0 + def test_ctas(self, arrow_cursor): + table_name = f"test_ctas_arrow_{''.join(random.choices(string.ascii_lowercase, k=10))}" + location = f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}/" + arrow_cursor.execute( + f""" + CREATE TABLE {ENV.schema}.{table_name} + WITH ( + format='PARQUET', + external_location='{location}' + ) AS SELECT a FROM many_rows LIMIT 1 + """ + ) + assert arrow_cursor.description == [("rows", "bigint", None, None, 19, 0, "UNKNOWN")] + # CTAS returns affected row count via rowcount, not via fetchone() + assert arrow_cursor.rowcount == 1 + assert arrow_cursor.fetchone() is None + @pytest.mark.parametrize( "arrow_cursor", [{"cursor_kwargs": {"unload": False}}, {"cursor_kwargs": {"unload": True}}], diff --git 
a/tests/pyathena/conftest.py b/tests/pyathena/conftest.py index da2961ad..b98f6b76 100644 --- a/tests/pyathena/conftest.py +++ b/tests/pyathena/conftest.py @@ -170,6 +170,20 @@ def async_arrow_cursor(request): yield from _cursor(AsyncArrowCursor, request) +@pytest.fixture +def s3fs_cursor(request): + from pyathena.s3fs.cursor import S3FSCursor + + yield from _cursor(S3FSCursor, request) + + +@pytest.fixture +def async_s3fs_cursor(request): + from pyathena.s3fs.async_cursor import AsyncS3FSCursor + + yield from _cursor(AsyncS3FSCursor, request) + + @pytest.fixture def spark_cursor(request): from pyathena.spark.cursor import SparkCursor diff --git a/tests/pyathena/pandas/test_cursor.py b/tests/pyathena/pandas/test_cursor.py index afaada6e..c5c6796f 100644 --- a/tests/pyathena/pandas/test_cursor.py +++ b/tests/pyathena/pandas/test_cursor.py @@ -878,6 +878,23 @@ def test_empty_result_dml_unload(self, pandas_cursor, parquet_engine): assert df.shape[0] == 0 assert df.shape[1] == 0 + def test_ctas(self, pandas_cursor): + table_name = f"test_ctas_pandas_{''.join(random.choices(string.ascii_lowercase, k=10))}" + location = f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}/" + pandas_cursor.execute( + f""" + CREATE TABLE {ENV.schema}.{table_name} + WITH ( + format='PARQUET', + external_location='{location}' + ) AS SELECT a FROM many_rows LIMIT 1 + """ + ) + assert pandas_cursor.description == [("rows", "bigint", None, None, 19, 0, "UNKNOWN")] + # CTAS returns affected row count via rowcount, not via fetchone() + assert pandas_cursor.rowcount == 1 + assert pandas_cursor.fetchone() is None + @pytest.mark.parametrize( "pandas_cursor, parquet_engine", [ diff --git a/tests/pyathena/s3fs/__init__.py b/tests/pyathena/s3fs/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/tests/pyathena/s3fs/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/tests/pyathena/s3fs/test_async_cursor.py b/tests/pyathena/s3fs/test_async_cursor.py new file mode 100644 
index 00000000..b9b37053 --- /dev/null +++ b/tests/pyathena/s3fs/test_async_cursor.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +import contextlib +import random +import time +from datetime import datetime +from decimal import Decimal + +import pytest + +from pyathena.error import ProgrammingError +from pyathena.model import AthenaQueryExecution +from pyathena.s3fs.async_cursor import AsyncS3FSCursor +from pyathena.s3fs.result_set import AthenaS3FSResultSet +from tests import ENV +from tests.pyathena.conftest import connect + + +class TestAsyncS3FSCursor: + def test_fetchone(self, async_s3fs_cursor): + query_id, future = async_s3fs_cursor.execute("SELECT * FROM one_row") + result_set = future.result() + assert result_set.rownumber == 0 + assert result_set.fetchone() == (1,) + assert result_set.rownumber == 1 + assert result_set.fetchone() is None + + def test_fetchmany(self, async_s3fs_cursor): + query_id, future = async_s3fs_cursor.execute("SELECT * FROM many_rows LIMIT 15") + result_set = future.result() + assert len(result_set.fetchmany(10)) == 10 + assert len(result_set.fetchmany(10)) == 5 + + def test_fetchall(self, async_s3fs_cursor): + query_id, future = async_s3fs_cursor.execute("SELECT * FROM one_row") + result_set = future.result() + assert result_set.fetchall() == [(1,)] + + query_id, future = async_s3fs_cursor.execute("SELECT a FROM many_rows ORDER BY a") + result_set = future.result() + assert result_set.fetchall() == [(i,) for i in range(10000)] + + def test_arraysize(self, async_s3fs_cursor): + async_s3fs_cursor.arraysize = 5 + query_id, future = async_s3fs_cursor.execute("SELECT * FROM many_rows LIMIT 20") + result_set = future.result() + assert len(result_set.fetchmany()) == 5 + + def test_arraysize_default(self, async_s3fs_cursor): + assert async_s3fs_cursor.arraysize == AthenaS3FSResultSet.DEFAULT_FETCH_SIZE + + def test_invalid_arraysize(self, async_s3fs_cursor): + async_s3fs_cursor.arraysize = 10000 + assert async_s3fs_cursor.arraysize == 
10000 + with pytest.raises(ProgrammingError): + async_s3fs_cursor.arraysize = -1 + + def test_complex(self, async_s3fs_cursor): + query_id, future = async_s3fs_cursor.execute( + """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_varchar + ,col_timestamp + ,CAST(col_timestamp AS time) AS col_time + ,col_date + ,col_binary + ,col_array + ,CAST(col_array AS json) AS col_array_json + ,col_map + ,CAST(col_map AS json) AS col_map_json + ,col_struct + ,col_decimal + FROM one_row_complex + """ + ) + result_set = future.result() + assert result_set.description == [ + ("col_boolean", "boolean", None, None, 0, 0, "UNKNOWN"), + ("col_tinyint", "tinyint", None, None, 3, 0, "UNKNOWN"), + ("col_smallint", "smallint", None, None, 5, 0, "UNKNOWN"), + ("col_int", "integer", None, None, 10, 0, "UNKNOWN"), + ("col_bigint", "bigint", None, None, 19, 0, "UNKNOWN"), + ("col_float", "float", None, None, 17, 0, "UNKNOWN"), + ("col_double", "double", None, None, 17, 0, "UNKNOWN"), + ("col_string", "varchar", None, None, 2147483647, 0, "UNKNOWN"), + ("col_varchar", "varchar", None, None, 10, 0, "UNKNOWN"), + ("col_timestamp", "timestamp", None, None, 3, 0, "UNKNOWN"), + ("col_time", "time", None, None, 3, 0, "UNKNOWN"), + ("col_date", "date", None, None, 0, 0, "UNKNOWN"), + ("col_binary", "varbinary", None, None, 1073741824, 0, "UNKNOWN"), + ("col_array", "array", None, None, 0, 0, "UNKNOWN"), + ("col_array_json", "json", None, None, 0, 0, "UNKNOWN"), + ("col_map", "map", None, None, 0, 0, "UNKNOWN"), + ("col_map_json", "json", None, None, 0, 0, "UNKNOWN"), + ("col_struct", "row", None, None, 0, 0, "UNKNOWN"), + ("col_decimal", "decimal", None, None, 10, 1, "UNKNOWN"), + ] + assert result_set.fetchall() == [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "varchar", + datetime(2017, 1, 1, 0, 0, 0), + datetime(2017, 1, 1, 0, 0, 0).time(), + datetime(2017, 1, 
2).date(), + b"123", + [1, 2], + [1, 2], + {"1": 2, "3": 4}, + {"1": 2, "3": 4}, + {"a": 1, "b": 2}, + Decimal("0.1"), + ) + ] + + def test_cancel(self, async_s3fs_cursor): + query_id, future = async_s3fs_cursor.execute( + """ + SELECT a.a * rand(), b.a * rand() + FROM many_rows a + CROSS JOIN many_rows b + """ + ) + time.sleep(random.randint(5, 10)) + async_s3fs_cursor.cancel(query_id) + result_set = future.result() + assert result_set.state == AthenaQueryExecution.STATE_CANCELLED + assert result_set.description is None + assert result_set.fetchone() is None + assert result_set.fetchmany() == [] + assert result_set.fetchall() == [] + + def test_open_close(self): + with ( + contextlib.closing(connect(schema_name=ENV.schema)) as conn, + conn.cursor(AsyncS3FSCursor) as cursor, + ): + query_id, future = cursor.execute("SELECT * FROM one_row") + result_set = future.result() + assert result_set.fetchall() == [(1,)] + + def test_no_ops(self): + conn = connect(schema_name=ENV.schema) + cursor = conn.cursor(AsyncS3FSCursor) + cursor.close() + conn.close() + + def test_show_columns(self, async_s3fs_cursor): + query_id, future = async_s3fs_cursor.execute("SHOW COLUMNS IN one_row") + result_set = future.result() + assert result_set.description == [("field", "string", None, None, 0, 0, "UNKNOWN")] + assert result_set.fetchall() == [("number_of_rows ",)] + + def test_empty_result(self, async_s3fs_cursor): + query_id, future = async_s3fs_cursor.execute("SELECT * FROM one_row WHERE 1 = 2") + result_set = future.result() + assert query_id + assert result_set.rownumber == 0 + assert result_set.fetchone() is None + assert result_set.fetchmany() == [] + assert result_set.fetchmany(10) == [] + assert result_set.fetchall() == [] diff --git a/tests/pyathena/s3fs/test_cursor.py b/tests/pyathena/s3fs/test_cursor.py new file mode 100644 index 00000000..20bc513a --- /dev/null +++ b/tests/pyathena/s3fs/test_cursor.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +import contextlib +import 
random +import string +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from decimal import Decimal + +import pytest + +from pyathena.error import DatabaseError, ProgrammingError +from pyathena.s3fs.cursor import S3FSCursor +from pyathena.s3fs.result_set import AthenaS3FSResultSet +from tests import ENV +from tests.pyathena.conftest import connect + + +class TestS3FSCursor: + def test_fetchone(self, s3fs_cursor): + s3fs_cursor.execute("SELECT * FROM one_row") + assert s3fs_cursor.rownumber == 0 + assert s3fs_cursor.fetchone() == (1,) + assert s3fs_cursor.rownumber == 1 + assert s3fs_cursor.fetchone() is None + + def test_fetchmany(self, s3fs_cursor): + s3fs_cursor.execute("SELECT * FROM many_rows LIMIT 15") + assert len(s3fs_cursor.fetchmany(10)) == 10 + assert len(s3fs_cursor.fetchmany(10)) == 5 + + def test_fetchall(self, s3fs_cursor): + s3fs_cursor.execute("SELECT * FROM one_row") + assert s3fs_cursor.fetchall() == [(1,)] + s3fs_cursor.execute("SELECT a FROM many_rows ORDER BY a") + assert s3fs_cursor.fetchall() == [(i,) for i in range(10000)] + + def test_iterator(self, s3fs_cursor): + s3fs_cursor.execute("SELECT * FROM one_row") + assert list(s3fs_cursor) == [(1,)] + pytest.raises(StopIteration, s3fs_cursor.__next__) + + def test_arraysize(self, s3fs_cursor): + s3fs_cursor.arraysize = 5 + s3fs_cursor.execute("SELECT * FROM many_rows LIMIT 20") + assert len(s3fs_cursor.fetchmany()) == 5 + + def test_arraysize_default(self, s3fs_cursor): + assert s3fs_cursor.arraysize == AthenaS3FSResultSet.DEFAULT_FETCH_SIZE + + def test_invalid_arraysize(self, s3fs_cursor): + s3fs_cursor.arraysize = 10000 + assert s3fs_cursor.arraysize == 10000 + with pytest.raises(ProgrammingError): + s3fs_cursor.arraysize = -1 + + def test_complex(self, s3fs_cursor): + s3fs_cursor.execute( + """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_varchar + ,col_timestamp + 
,CAST(col_timestamp AS time) AS col_time + ,col_date + ,col_binary + ,col_array + ,CAST(col_array AS json) AS col_array_json + ,col_map + ,CAST(col_map AS json) AS col_map_json + ,col_struct + ,col_decimal + FROM one_row_complex + """ + ) + assert s3fs_cursor.description == [ + ("col_boolean", "boolean", None, None, 0, 0, "UNKNOWN"), + ("col_tinyint", "tinyint", None, None, 3, 0, "UNKNOWN"), + ("col_smallint", "smallint", None, None, 5, 0, "UNKNOWN"), + ("col_int", "integer", None, None, 10, 0, "UNKNOWN"), + ("col_bigint", "bigint", None, None, 19, 0, "UNKNOWN"), + ("col_float", "float", None, None, 17, 0, "UNKNOWN"), + ("col_double", "double", None, None, 17, 0, "UNKNOWN"), + ("col_string", "varchar", None, None, 2147483647, 0, "UNKNOWN"), + ("col_varchar", "varchar", None, None, 10, 0, "UNKNOWN"), + ("col_timestamp", "timestamp", None, None, 3, 0, "UNKNOWN"), + ("col_time", "time", None, None, 3, 0, "UNKNOWN"), + ("col_date", "date", None, None, 0, 0, "UNKNOWN"), + ("col_binary", "varbinary", None, None, 1073741824, 0, "UNKNOWN"), + ("col_array", "array", None, None, 0, 0, "UNKNOWN"), + ("col_array_json", "json", None, None, 0, 0, "UNKNOWN"), + ("col_map", "map", None, None, 0, 0, "UNKNOWN"), + ("col_map_json", "json", None, None, 0, 0, "UNKNOWN"), + ("col_struct", "row", None, None, 0, 0, "UNKNOWN"), + ("col_decimal", "decimal", None, None, 10, 1, "UNKNOWN"), + ] + assert s3fs_cursor.fetchall() == [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + "varchar", + datetime(2017, 1, 1, 0, 0, 0), + datetime(2017, 1, 1, 0, 0, 0).time(), + datetime(2017, 1, 2).date(), + b"123", + [1, 2], + [1, 2], + {"1": 2, "3": 4}, + {"1": 2, "3": 4}, + {"a": 1, "b": 2}, + Decimal("0.1"), + ) + ] + + def test_cancel(self, s3fs_cursor): + def cancel(c): + time.sleep(random.randint(5, 10)) + c.cancel() + + with ThreadPoolExecutor(max_workers=1) as executor: + executor.submit(cancel, s3fs_cursor) + pytest.raises( + DatabaseError, + lambda: 
s3fs_cursor.execute( + """ + SELECT a.a * rand(), b.a * rand() + FROM many_rows a + CROSS JOIN many_rows b + """ + ), + ) + + def test_cancel_initial(self, s3fs_cursor): + pytest.raises(ProgrammingError, s3fs_cursor.cancel) + + def test_open_close(self): + with ( + contextlib.closing(connect(schema_name=ENV.schema)) as conn, + conn.cursor(S3FSCursor) as cursor, + ): + cursor.execute("SELECT * FROM one_row") + assert cursor.fetchall() == [(1,)] + + def test_no_ops(self): + conn = connect(schema_name=ENV.schema) + cursor = conn.cursor(S3FSCursor) + cursor.close() + conn.close() + + def test_show_columns(self, s3fs_cursor): + s3fs_cursor.execute("SHOW COLUMNS IN one_row") + assert s3fs_cursor.description == [("field", "string", None, None, 0, 0, "UNKNOWN")] + assert s3fs_cursor.fetchall() == [("number_of_rows ",)] + + def test_empty_result(self, s3fs_cursor): + query_id = s3fs_cursor.execute("SELECT * FROM one_row WHERE 1 = 2").query_id + assert query_id + assert s3fs_cursor.rownumber == 0 + assert s3fs_cursor.fetchone() is None + assert s3fs_cursor.fetchmany() == [] + assert s3fs_cursor.fetchmany(10) == [] + assert s3fs_cursor.fetchall() == [] + + def test_query_id(self, s3fs_cursor): + assert not s3fs_cursor.query_id + s3fs_cursor.execute("SELECT * FROM one_row") + assert s3fs_cursor.query_id is not None + + def test_description_without_execute(self, s3fs_cursor): + assert s3fs_cursor.description is None + + def test_description_with_select(self, s3fs_cursor): + s3fs_cursor.execute("SELECT * FROM one_row") + assert s3fs_cursor.description == [ + ("number_of_rows", "integer", None, None, 10, 0, "UNKNOWN") + ] + + def test_description_with_ctas(self, s3fs_cursor): + table_name = ( + f"test_description_with_ctas_{''.join(random.choices(string.ascii_lowercase, k=10))}" + ) + location = f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}/" + s3fs_cursor.execute( + f""" + CREATE TABLE {ENV.schema}.{table_name} + WITH ( + format='PARQUET', + external_location='{location}' + ) 
AS SELECT a FROM many_rows LIMIT 1 + """ + ) + assert s3fs_cursor.description == [("rows", "bigint", None, None, 19, 0, "UNKNOWN")] + # CTAS returns affected row count via rowcount, not via fetchone() + assert s3fs_cursor.rowcount == 1 + assert s3fs_cursor.fetchone() is None + + def test_description_with_create_table(self, s3fs_cursor): + table_name = ( + f"test_description_with_create_table_" + f"{''.join(random.choices(string.ascii_lowercase, k=10))}" + ) + s3fs_cursor.execute( + f""" + CREATE EXTERNAL TABLE {ENV.schema}.{table_name} ( + a INT + ) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' + LOCATION '{ENV.s3_staging_dir}' + """ + ) + assert s3fs_cursor.description == [] + assert s3fs_cursor.fetchone() is None + s3fs_cursor.execute(f"DROP TABLE {ENV.schema}.{table_name}") + + @pytest.mark.skip(reason="Requires insert_test table to exist in test environment") + def test_executemany(self, s3fs_cursor): + rows = [(1, "foo"), (2, "bar"), (3, "jim")] + s3fs_cursor.executemany( + f"INSERT INTO {ENV.schema}.insert_test (a, b) VALUES (%(a)s, %(b)s)", + [{"a": row[0], "b": row[1]} for row in rows], + ) + s3fs_cursor.execute(f"SELECT * FROM {ENV.schema}.insert_test ORDER BY a") + assert s3fs_cursor.fetchall() == rows + s3fs_cursor.execute(f"DELETE FROM {ENV.schema}.insert_test WHERE a IN (1, 2, 3)") + + def test_on_start_query_execution(self): + callback_query_id = None + + def callback(query_id): + nonlocal callback_query_id + callback_query_id = query_id + + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"on_start_query_execution": callback}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT * FROM one_row") + assert callback_query_id == cursor.query_id + + def test_on_start_query_execution_execute(self): + 
callback_query_id = None + + def callback(query_id): + nonlocal callback_query_id + callback_query_id = query_id + + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT * FROM one_row", on_start_query_execution=callback) + assert callback_query_id == cursor.query_id diff --git a/tests/pyathena/test_cursor.py b/tests/pyathena/test_cursor.py index 088a0ba6..37883a1e 100644 --- a/tests/pyathena/test_cursor.py +++ b/tests/pyathena/test_cursor.py @@ -2,7 +2,9 @@ import contextlib import json import logging +import random import re +import string import threading import time from concurrent import futures @@ -240,6 +242,25 @@ def test_description_failed(self, cursor): cursor.execute("blah_blah") assert cursor.description is None + def test_description_with_ctas(self, cursor): + table_name = ( + f"test_description_with_ctas_{''.join(random.choices(string.ascii_lowercase, k=10))}" + ) + location = f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}/" + cursor.execute( + f""" + CREATE TABLE {ENV.schema}.{table_name} + WITH ( + format='PARQUET', + external_location='{location}' + ) AS SELECT a FROM many_rows LIMIT 1 + """ + ) + assert cursor.description == [("rows", "bigint", None, None, 19, 0, "UNKNOWN")] + # CTAS returns affected row count via rowcount, not via fetchone() + assert cursor.rowcount == 1 + assert cursor.fetchone() is None + def test_bad_query(self, cursor): def run(): cursor.execute("SELECT does_not_exist FROM this_really_does_not_exist") diff --git a/tests/sqlalchemy/__init__.py b/tests/sqlalchemy/__init__.py index d4e08384..7e3c1b4c 100644 --- a/tests/sqlalchemy/__init__.py +++ b/tests/sqlalchemy/__init__.py @@ -5,3 +5,4 @@ registry.register("awsathena.rest", "pyathena.sqlalchemy.rest", "AthenaRestDialect") registry.register("awsathena.pandas", "pyathena.sqlalchemy.pandas", "AthenaPandasDialect") registry.register("awsathena.arrow", 
"pyathena.sqlalchemy.arrow", "AthenaArrowDialect") +registry.register("awsathena.s3fs", "pyathena.sqlalchemy.s3fs", "AthenaS3FSDialect") From 06cc3a6e544eaf0cc8da5e4371590a7231f01aa8 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Thu, 1 Jan 2026 14:52:18 +0900 Subject: [PATCH 2/6] Fix rowcount property for Arrow, Pandas, and S3FS cursors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Due to Python MRO, CursorIterator.rowcount was taking precedence over WithResultSet.rowcount. The base Cursor class already has its own rowcount property that delegates to result_set.rowcount. This commit adds the same pattern to ArrowCursor, PandasCursor, and S3FSCursor. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- pyathena/arrow/cursor.py | 5 +++++ pyathena/pandas/cursor.py | 5 +++++ pyathena/s3fs/cursor.py | 5 +++++ 3 files changed, 15 insertions(+) diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index cb3cd54d..14834d3a 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -156,6 +156,11 @@ def query_id(self, val) -> None: def rownumber(self) -> Optional[int]: return self.result_set.rownumber if self.result_set else None + @property + def rowcount(self) -> int: + """Get the number of rows affected by the last operation.""" + return self.result_set.rowcount if self.result_set else -1 + def close(self) -> None: if self.result_set and not self.result_set.is_closed: self.result_set.close() diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index 488f2b25..470adc5c 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -179,6 +179,11 @@ def query_id(self, val) -> None: def rownumber(self) -> Optional[int]: return self.result_set.rownumber if self.result_set else None + @property + def rowcount(self) -> int: + """Get the number of rows affected by the last operation.""" + return self.result_set.rowcount if 
self.result_set else -1 + def close(self) -> None: if self.result_set and not self.result_set.is_closed: self.result_set.close() diff --git a/pyathena/s3fs/cursor.py b/pyathena/s3fs/cursor.py index 33cef5cd..2daa01e8 100644 --- a/pyathena/s3fs/cursor.py +++ b/pyathena/s3fs/cursor.py @@ -156,6 +156,11 @@ def rownumber(self) -> Optional[int]: """Get the current row number (0-indexed).""" return self.result_set.rownumber if self.result_set else None + @property + def rowcount(self) -> int: + """Get the number of rows affected by the last operation.""" + return self.result_set.rowcount if self.result_set else -1 + def close(self) -> None: """Close the cursor and release resources.""" if self.result_set and not self.result_set.is_closed: From e3201834728622bbf187c67e956aeac50b1bbe8f Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Thu, 1 Jan 2026 15:13:56 +0900 Subject: [PATCH 3/6] Unify random string generation to use random.choices MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace list comprehension pattern with simpler random.choices(k=10). 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/__init__.py | 2 +- tests/pyathena/arrow/test_async_cursor.py | 2 +- tests/pyathena/arrow/test_cursor.py | 2 +- tests/pyathena/pandas/test_async_cursor.py | 2 +- tests/pyathena/pandas/test_cursor.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index d2a21f02..2daedd1e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -28,7 +28,7 @@ def __init__(self): ) self.default_work_group = os.getenv("AWS_ATHENA_DEFAULT_WORKGROUP", "primary") self.schema = "pyathena_test_" + "".join( - [random.choice(string.ascii_lowercase + string.digits) for _ in range(10)] + random.choices(string.ascii_lowercase + string.digits, k=10) ) self.s3_filesystem_test_file_key = ( f"{self.s3_staging_key}{self.schema}/filesystem/test_read/test.dat" diff --git a/tests/pyathena/arrow/test_async_cursor.py b/tests/pyathena/arrow/test_async_cursor.py index 840300de..b29d3804 100644 --- a/tests/pyathena/arrow/test_async_cursor.py +++ b/tests/pyathena/arrow/test_async_cursor.py @@ -275,7 +275,7 @@ def test_no_ops(self): ) def test_empty_result(self, async_arrow_cursor): table = "test_pandas_cursor_empty_result_" + "".join( - [random.choice(string.ascii_lowercase + string.digits) for _ in range(10)] + random.choices(string.ascii_lowercase + string.digits, k=10) ) query_id, future = async_arrow_cursor.execute( f""" diff --git a/tests/pyathena/arrow/test_cursor.py b/tests/pyathena/arrow/test_cursor.py index ec4072f0..366a4bd6 100644 --- a/tests/pyathena/arrow/test_cursor.py +++ b/tests/pyathena/arrow/test_cursor.py @@ -476,7 +476,7 @@ def test_show_columns(self, arrow_cursor): ) def test_empty_result(self, arrow_cursor): table = "test_arrow_cursor_empty_result_" + "".join( - [random.choice(string.ascii_lowercase + string.digits) for _ in range(10)] + random.choices(string.ascii_lowercase + string.digits, k=10) ) df = arrow_cursor.execute( 
f""" diff --git a/tests/pyathena/pandas/test_async_cursor.py b/tests/pyathena/pandas/test_async_cursor.py index b88af8b1..c4d8b992 100644 --- a/tests/pyathena/pandas/test_async_cursor.py +++ b/tests/pyathena/pandas/test_async_cursor.py @@ -424,7 +424,7 @@ def test_show_columns(self, async_pandas_cursor, parquet_engine, chunksize): ) def test_empty_result_ddl(self, async_pandas_cursor, parquet_engine, chunksize): table = "test_pandas_cursor_empty_result_" + "".join( - [random.choice(string.ascii_lowercase + string.digits) for _ in range(10)] + random.choices(string.ascii_lowercase + string.digits, k=10) ) query_id, future = async_pandas_cursor.execute( f""" diff --git a/tests/pyathena/pandas/test_cursor.py b/tests/pyathena/pandas/test_cursor.py index c5c6796f..160d0c87 100644 --- a/tests/pyathena/pandas/test_cursor.py +++ b/tests/pyathena/pandas/test_cursor.py @@ -842,7 +842,7 @@ def test_show_columns(self, pandas_cursor, parquet_engine, chunksize): ) def test_empty_result_ddl(self, pandas_cursor, parquet_engine, chunksize): table = "test_pandas_cursor_empty_result_" + "".join( - [random.choice(string.ascii_lowercase + string.digits) for _ in range(10)] + random.choices(string.ascii_lowercase + string.digits, k=10) ) df = pandas_cursor.execute( f""" From d74546b017469d15b3b82771509f01f6c0c276eb Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Thu, 1 Jan 2026 15:37:12 +0900 Subject: [PATCH 4/6] Add test for tab and newline characters in S3FS cursor result data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Test that the S3FS cursor correctly handles data containing tab and newline characters, which are special characters in CSV/TSV parsing. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/pyathena/s3fs/test_cursor.py | 36 ++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/pyathena/s3fs/test_cursor.py b/tests/pyathena/s3fs/test_cursor.py index 20bc513a..f5b0ba9f 100644 --- a/tests/pyathena/s3fs/test_cursor.py +++ b/tests/pyathena/s3fs/test_cursor.py @@ -275,3 +275,39 @@ def callback(query_id): ): cursor.execute("SELECT * FROM one_row", on_start_query_execution=callback) assert callback_query_id == cursor.query_id + + def test_contain_tab_character(self, s3fs_cursor): + """Test that tab characters in result data are handled correctly. + + S3FS cursor uses tab as delimiter for parsing Athena's CSV output. + This test verifies that data containing tab characters is correctly + parsed when Athena properly quotes such fields. + """ + # Test with tab character in string using CHR(9) + s3fs_cursor.execute("SELECT 'before' || CHR(9) || 'after' AS col_with_tab") + result = s3fs_cursor.fetchone() + assert result == ("before\tafter",) + + # Test with multiple columns where one contains tab + s3fs_cursor.execute( + """ + SELECT + 'normal' AS col1, + 'has' || CHR(9) || 'tab' AS col2, + 'also_normal' AS col3 + """ + ) + result = s3fs_cursor.fetchone() + assert result == ("normal", "has\ttab", "also_normal") + + # Test with newline character as well + s3fs_cursor.execute("SELECT 'line1' || CHR(10) || 'line2' AS col_with_newline") + result = s3fs_cursor.fetchone() + assert result == ("line1\nline2",) + + # Test with both tab and newline + s3fs_cursor.execute( + "SELECT 'a' || CHR(9) || 'b' || CHR(10) || 'c' AS col_with_special_chars" + ) + result = s3fs_cursor.fetchone() + assert result == ("a\tb\nc",) From 7807aa68576d9a1c42a3558f7edd5c3fab648bd3 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Thu, 1 Jan 2026 15:52:56 +0900 Subject: [PATCH 5/6] Add documentation for S3FS cursor MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit - Add docs/s3fs.rst with comprehensive S3FSCursor and AsyncS3FSCursor documentation - Add docs/api/s3fs.rst with API reference - Update docs/index.rst to include s3fs in toctree - Update docs/api.rst to include s3fs API reference The documentation covers: - Basic usage and connection examples - Type conversion mappings - Custom converter implementation - Limitations compared to Arrow/Pandas cursors - Use cases and recommendations - AsyncS3FSCursor for asynchronous operations 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- docs/api.rst | 2 + docs/api/s3fs.rst | 30 +++++ docs/index.rst | 1 + docs/s3fs.rst | 327 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 360 insertions(+) create mode 100644 docs/api/s3fs.rst create mode 100644 docs/s3fs.rst diff --git a/docs/api.rst b/docs/api.rst index e3ed3937..9cb68d02 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -12,6 +12,7 @@ This section provides comprehensive API documentation for all PyAthena classes a api/connection api/pandas api/arrow + api/s3fs api/spark api/converters api/filesystem @@ -35,6 +36,7 @@ Specialized Integrations - :ref:`api_pandas` - pandas DataFrame integration - :ref:`api_arrow` - Apache Arrow columnar data integration +- :ref:`api_s3fs` - Lightweight S3FS-based cursor (no pandas/pyarrow required) - :ref:`api_spark` - Apache Spark integration for big data processing Infrastructure diff --git a/docs/api/s3fs.rst b/docs/api/s3fs.rst new file mode 100644 index 00000000..1104a7e5 --- /dev/null +++ b/docs/api/s3fs.rst @@ -0,0 +1,30 @@ +.. _api_s3fs: + +S3FS Integration +================ + +This section covers lightweight S3FS-based cursors and data converters that use Python's built-in ``csv`` module. + +S3FS Cursors +------------ + +.. autoclass:: pyathena.s3fs.cursor.S3FSCursor + :members: + :inherited-members: + +.. 
autoclass:: pyathena.s3fs.async_cursor.AsyncS3FSCursor + :members: + :inherited-members: + +S3FS Data Converters +-------------------- + +.. autoclass:: pyathena.s3fs.converter.DefaultS3FSTypeConverter + :members: + +S3FS Result Set +--------------- + +.. autoclass:: pyathena.s3fs.result_set.AthenaS3FSResultSet + :members: + :inherited-members: diff --git a/docs/index.rst b/docs/index.rst index 4c654a7f..29000d80 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,6 +20,7 @@ Documentation cursor pandas arrow + s3fs spark testing api diff --git a/docs/s3fs.rst b/docs/s3fs.rst new file mode 100644 index 00000000..156ba634 --- /dev/null +++ b/docs/s3fs.rst @@ -0,0 +1,327 @@ +.. _s3fs: + +S3FS +==== + +.. _s3fs-cursor: + +S3FSCursor +---------- + +S3FSCursor is a lightweight cursor that directly handles the CSV file of the query execution result output to S3. +Unlike ArrowCursor or PandasCursor, this cursor uses Python's built-in ``csv`` module to parse results, +making it ideal for environments where installing pandas or pyarrow is not desirable. + +**Key features:** + +- No pandas or pyarrow dependencies required +- Uses Python's built-in ``csv`` module for parsing +- Lower memory footprint for simple query results +- Full DB API 2.0 compatibility + +You can use the S3FSCursor by specifying the ``cursor_class`` +with the connect method or connection object. + +.. code:: python + + from pyathena import connect + from pyathena.s3fs.cursor import S3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=S3FSCursor).cursor() + +.. code:: python + + from pyathena.connection import Connection + from pyathena.s3fs.cursor import S3FSCursor + + cursor = Connection(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=S3FSCursor).cursor() + +It can also be used by specifying the cursor class when calling the connection object's cursor method. + +.. 
code:: python + + from pyathena import connect + from pyathena.s3fs.cursor import S3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor(S3FSCursor) + +.. code:: python + + from pyathena.connection import Connection + from pyathena.s3fs.cursor import S3FSCursor + + cursor = Connection(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor(S3FSCursor) + +Support fetch and iterate query results. + +.. code:: python + + from pyathena import connect + from pyathena.s3fs.cursor import S3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=S3FSCursor).cursor() + + cursor.execute("SELECT * FROM many_rows") + print(cursor.fetchone()) + print(cursor.fetchmany()) + print(cursor.fetchall()) + +.. code:: python + + from pyathena import connect + from pyathena.s3fs.cursor import S3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=S3FSCursor).cursor() + + cursor.execute("SELECT * FROM many_rows") + for row in cursor: + print(row) + +Execution information of the query can also be retrieved. + +.. 
code:: python + + from pyathena import connect + from pyathena.s3fs.cursor import S3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=S3FSCursor).cursor() + + cursor.execute("SELECT * FROM many_rows") + print(cursor.state) + print(cursor.state_change_reason) + print(cursor.completion_date_time) + print(cursor.submission_date_time) + print(cursor.data_scanned_in_bytes) + print(cursor.engine_execution_time_in_millis) + print(cursor.query_queue_time_in_millis) + print(cursor.total_execution_time_in_millis) + print(cursor.query_planning_time_in_millis) + print(cursor.service_processing_time_in_millis) + print(cursor.output_location) + +Type Conversion +~~~~~~~~~~~~~~~ + +S3FSCursor converts Athena data types to Python types using the built-in converter. +The following type mappings are used: + +.. list-table:: Type Mappings + :header-rows: 1 + :widths: 30 70 + + * - Athena Type + - Python Type + * - boolean + - bool + * - tinyint, smallint, integer, bigint + - int + * - float, double, real + - float + * - decimal + - decimal.Decimal + * - char, varchar, string + - str + * - date + - datetime.date + * - timestamp + - datetime.datetime + * - time + - datetime.time + * - binary, varbinary + - bytes + * - array, map, row (struct) + - Parsed as Python list/dict using JSON-like parsing + * - json + - Parsed JSON (dict or list) + +If you want to customize type conversion, create a converter class like this: + +.. code:: python + + from pyathena.s3fs.converter import DefaultS3FSTypeConverter + + class CustomS3FSTypeConverter(DefaultS3FSTypeConverter): + def __init__(self) -> None: + super().__init__() + # Override specific type mappings + self._mappings["custom_type"] = self._convert_custom + + def _convert_custom(self, value: str) -> Any: + # Your custom conversion logic + return value.upper() + +Then specify an instance of this class in the converter argument when creating a cursor. + +.. 
code:: python + + from pyathena import connect + from pyathena.s3fs.cursor import S3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor(S3FSCursor, converter=CustomS3FSTypeConverter()) + +Limitations +~~~~~~~~~~~ + +S3FSCursor has some limitations compared to ArrowCursor or PandasCursor: + +- **No UNLOAD support**: S3FSCursor reads CSV results directly and does not support the UNLOAD option + that outputs results in Parquet format. +- **Sequential reading**: Results are read row by row from the CSV file, which may be slower + for very large result sets compared to columnar formats. +- **No DataFrame conversion**: There is no ``as_pandas()`` or ``as_arrow()`` method. + Use PandasCursor or ArrowCursor if you need DataFrame operations. + +When to use S3FSCursor +~~~~~~~~~~~~~~~~~~~~~~ + +S3FSCursor is recommended when: + +- You want to minimize dependencies (no pandas/pyarrow required) +- You're working in a constrained environment (e.g., AWS Lambda with size limits) +- You only need simple row-by-row result processing +- Memory efficiency is important and results don't need columnar operations + +For large-scale data processing or analytical workloads, consider using ArrowCursor or PandasCursor instead. + +.. _async-s3fs-cursor: + +AsyncS3FSCursor +--------------- + +AsyncS3FSCursor is an AsyncCursor that uses the same lightweight CSV parsing as S3FSCursor. +This cursor is useful when you need to execute queries asynchronously without pandas or pyarrow dependencies. + +You can use the AsyncS3FSCursor by specifying the ``cursor_class`` +with the connect method or connection object. + +.. code:: python + + from pyathena import connect + from pyathena.s3fs.async_cursor import AsyncS3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncS3FSCursor).cursor() + +.. 
code:: python
+
+    from pyathena.connection import Connection
+    from pyathena.s3fs.async_cursor import AsyncS3FSCursor
+
+    cursor = Connection(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
+                        region_name="us-west-2",
+                        cursor_class=AsyncS3FSCursor).cursor()
+
+It can also be used by specifying the cursor class when calling the connection object's cursor method.
+
+.. code:: python
+
+    from pyathena import connect
+    from pyathena.s3fs.async_cursor import AsyncS3FSCursor
+
+    cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
+                     region_name="us-west-2").cursor(AsyncS3FSCursor)
+
+.. code:: python
+
+    from pyathena.connection import Connection
+    from pyathena.s3fs.async_cursor import AsyncS3FSCursor
+
+    cursor = Connection(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
+                        region_name="us-west-2").cursor(AsyncS3FSCursor)
+
+The default number of workers is 5 or cpu number * 5.
+If you want to change the number of workers you can specify like the following.
+
+.. code:: python
+
+    from pyathena import connect
+    from pyathena.s3fs.async_cursor import AsyncS3FSCursor
+
+    cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
+                     region_name="us-west-2",
+                     cursor_class=AsyncS3FSCursor).cursor(max_workers=10)
+
+The execute method of the AsyncS3FSCursor returns a tuple of the query ID and the `future object`_.
+
+.. code:: python
+
+    from pyathena import connect
+    from pyathena.s3fs.async_cursor import AsyncS3FSCursor
+
+    cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
+                     region_name="us-west-2",
+                     cursor_class=AsyncS3FSCursor).cursor()
+
+    query_id, future = cursor.execute("SELECT * FROM many_rows")
+
+The return value of the `future object`_ is an ``AthenaS3FSResultSet`` object.
+This object has an interface similar to that of ``AthenaResultSet``.
+
+.. 
code:: python + + from pyathena import connect + from pyathena.s3fs.async_cursor import AsyncS3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncS3FSCursor).cursor() + + query_id, future = cursor.execute("SELECT * FROM many_rows") + result_set = future.result() + print(result_set.state) + print(result_set.state_change_reason) + print(result_set.completion_date_time) + print(result_set.submission_date_time) + print(result_set.data_scanned_in_bytes) + print(result_set.engine_execution_time_in_millis) + print(result_set.query_queue_time_in_millis) + print(result_set.total_execution_time_in_millis) + print(result_set.query_planning_time_in_millis) + print(result_set.service_processing_time_in_millis) + print(result_set.output_location) + print(result_set.description) + for row in result_set: + print(row) + +.. code:: python + + from pyathena import connect + from pyathena.s3fs.async_cursor import AsyncS3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncS3FSCursor).cursor() + + query_id, future = cursor.execute("SELECT * FROM many_rows") + result_set = future.result() + print(result_set.fetchall()) + +As with AsyncCursor, you need a query ID to cancel a query. + +.. code:: python + + from pyathena import connect + from pyathena.s3fs.async_cursor import AsyncS3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncS3FSCursor).cursor() + + query_id, future = cursor.execute("SELECT * FROM many_rows") + cursor.cancel(query_id) + +.. 
_`future object`: https://docs.python.org/3/library/concurrent.futures.html#future-objects From 32f0c84a78690be19fab46307054cefb9dd0ed7f Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Thu, 1 Jan 2026 18:54:49 +0900 Subject: [PATCH 6/6] Add pluggable CSV reader architecture for S3FSCursor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add AthenaCSVReader (default): Custom parser that distinguishes NULL (unquoted empty) from empty string (quoted empty "") - Add DefaultCSVReader: Python's standard csv module wrapper for backward compatibility (both NULL and empty string become empty string) - Support multi-line quoted fields in AthenaCSVReader with optimized incremental quote state tracking (O(n) complexity) - Add csv_reader parameter to S3FSCursor and AsyncS3FSCursor - Refactor result_set.py to remove unnecessary instance variables - Move header skipping to _init_csv_reader() for cleaner initialization - Update documentation with CSV reader options and NULL handling details - Add comprehensive unit tests for both CSV readers 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- docs/api/s3fs.rst | 14 +- docs/s3fs.rst | 83 +++++++++- pyathena/s3fs/async_cursor.py | 14 +- pyathena/s3fs/cursor.py | 14 +- pyathena/s3fs/reader.py | 236 +++++++++++++++++++++++++++++ pyathena/s3fs/result_set.py | 52 ++++--- tests/pyathena/s3fs/test_cursor.py | 192 +++++++++++++++++++++++ tests/pyathena/s3fs/test_reader.py | 211 ++++++++++++++++++++++++++ 8 files changed, 783 insertions(+), 33 deletions(-) create mode 100644 pyathena/s3fs/reader.py create mode 100644 tests/pyathena/s3fs/test_reader.py diff --git a/docs/api/s3fs.rst b/docs/api/s3fs.rst index 1104a7e5..6d6d691b 100644 --- a/docs/api/s3fs.rst +++ b/docs/api/s3fs.rst @@ -3,7 +3,7 @@ S3FS Integration ================ -This section covers lightweight S3FS-based cursors and data converters that use Python's built-in ``csv`` module. 
+This section covers lightweight S3FS-based cursors, CSV readers, and data converters. S3FS Cursors ------------ @@ -16,6 +16,18 @@ S3FS Cursors :members: :inherited-members: +S3FS CSV Readers +---------------- + +S3FSCursor supports pluggable CSV reader implementations to control how NULL values +and empty strings are handled when parsing Athena's CSV output. + +.. autoclass:: pyathena.s3fs.reader.AthenaCSVReader + :members: + +.. autoclass:: pyathena.s3fs.reader.DefaultCSVReader + :members: + S3FS Data Converters -------------------- diff --git a/docs/s3fs.rst b/docs/s3fs.rst index 156ba634..d090f939 100644 --- a/docs/s3fs.rst +++ b/docs/s3fs.rst @@ -9,13 +9,13 @@ S3FSCursor ---------- S3FSCursor is a lightweight cursor that directly handles the CSV file of the query execution result output to S3. -Unlike ArrowCursor or PandasCursor, this cursor uses Python's built-in ``csv`` module to parse results, -making it ideal for environments where installing pandas or pyarrow is not desirable. +Unlike ArrowCursor or PandasCursor, this cursor does not require pandas or pyarrow dependencies, +making it ideal for environments where installing these libraries is not desirable. **Key features:** - No pandas or pyarrow dependencies required -- Uses Python's built-in ``csv`` module for parsing +- Lightweight CSV parsing (custom parser or Python's built-in ``csv`` module) - Lower memory footprint for simple query results - Full DB API 2.0 compatibility @@ -172,6 +172,83 @@ Then specify an instance of this class in the converter argument when creating a cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", region_name="us-west-2").cursor(S3FSCursor, converter=CustomS3FSTypeConverter()) +CSV Reader Options +~~~~~~~~~~~~~~~~~~ + +S3FSCursor supports pluggable CSV reader implementations to control how NULL values and empty strings +are handled. 
Two readers are provided: + +- ``AthenaCSVReader`` (default): Custom parser that distinguishes between NULL and empty string +- ``DefaultCSVReader``: Uses Python's built-in ``csv`` module (treats both NULL and empty string as empty string) + +**Default behavior (AthenaCSVReader):** + +By default, ``AthenaCSVReader`` is used, which correctly distinguishes between NULL +values and empty strings in query results. + +.. code:: python + + from pyathena import connect + from pyathena.s3fs.cursor import S3FSCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=S3FSCursor).cursor() + + cursor.execute("SELECT NULL AS null_col, '' AS empty_col") + row = cursor.fetchone() + print(row) # (None, '') - NULL is None, empty string is '' + +**Switching to Python's built-in csv module (DefaultCSVReader):** + +If you prefer to use Python's built-in ``csv`` module, you can switch to ``DefaultCSVReader``. +Note that this reader cannot distinguish between NULL and empty string - both become empty strings +in the parsed result, which are then converted to ``None`` by the type converter. + +.. code:: python + + from pyathena import connect + from pyathena.s3fs.cursor import S3FSCursor + from pyathena.s3fs.reader import DefaultCSVReader + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": DefaultCSVReader}).cursor() + + cursor.execute("SELECT NULL AS null_col, '' AS empty_col") + row = cursor.fetchone() + print(row) # (None, None) - Both NULL and empty string become None + +**Comparison of CSV readers:** + +.. 
list-table:: CSV Reader Behavior + :header-rows: 1 + :widths: 30 20 25 25 + + * - Reader + - Implementation + - NULL value + - Empty string + * - AthenaCSVReader (default) + - Custom parser + - None + - '' (empty string) + * - DefaultCSVReader + - Python csv module + - None + - None + +**Why the difference?** + +Athena's CSV output format distinguishes between NULL values and empty strings: + +- NULL: unquoted empty field (e.g., ``a,,b`` → the middle field is NULL) +- Empty string: quoted empty field (e.g., ``a,"",b`` → the middle field is an empty string) + +Python's standard ``csv`` module parses both cases as empty strings, losing this distinction. +The ``AthenaCSVReader`` implements a custom parser that preserves the difference. + Limitations ~~~~~~~~~~~ diff --git a/pyathena/s3fs/async_cursor.py b/pyathena/s3fs/async_cursor.py index 8de1379e..89073b3d 100644 --- a/pyathena/s3fs/async_cursor.py +++ b/pyathena/s3fs/async_cursor.py @@ -11,7 +11,7 @@ from pyathena.error import ProgrammingError from pyathena.model import AthenaQueryExecution from pyathena.s3fs.converter import DefaultS3FSTypeConverter -from pyathena.s3fs.result_set import AthenaS3FSResultSet +from pyathena.s3fs.result_set import AthenaS3FSResultSet, CSVReaderType _logger = logging.getLogger(__name__) @@ -20,12 +20,12 @@ class AsyncS3FSCursor(AsyncCursor): """Asynchronous cursor that reads CSV results via S3FileSystem. This cursor extends AsyncCursor to provide asynchronous query execution - with results read via Python's standard csv module and PyAthena's S3FileSystem. + with results read via PyAthena's S3FileSystem. It's a lightweight alternative when pandas/pyarrow are not needed. 
Features: - Asynchronous query execution with concurrent futures - - Uses Python's standard csv module for parsing + - Lightweight CSV parsing via pluggable readers - Uses PyAthena's S3FileSystem for S3 access - No external dependencies beyond boto3 - Memory-efficient streaming for large datasets @@ -61,6 +61,7 @@ def __init__( arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, + csv_reader: Optional[CSVReaderType] = None, **kwargs, ) -> None: """Initialize an AsyncS3FSCursor. @@ -78,6 +79,11 @@ def __init__( arraysize: Number of rows to fetch per batch. result_reuse_enable: Enable Athena query result reuse. result_reuse_minutes: Minutes to reuse cached results. + csv_reader: CSV reader class to use for parsing results. + Use AthenaCSVReader (default) to distinguish between NULL + (unquoted empty) and empty string (quoted empty ""). + Use DefaultCSVReader for backward compatibility where empty + strings are treated as NULL. **kwargs: Additional connection parameters. 
Example: @@ -99,6 +105,7 @@ def __init__( result_reuse_minutes=result_reuse_minutes, **kwargs, ) + self._csv_reader = csv_reader @staticmethod def get_default_converter( @@ -156,6 +163,7 @@ def _collect_result_set( query_execution=query_execution, arraysize=self._arraysize, retry_config=self._retry_config, + csv_reader=self._csv_reader, **kwargs, ) diff --git a/pyathena/s3fs/cursor.py b/pyathena/s3fs/cursor.py index 2daa01e8..4fefe7e2 100644 --- a/pyathena/s3fs/cursor.py +++ b/pyathena/s3fs/cursor.py @@ -9,7 +9,7 @@ from pyathena.model import AthenaQueryExecution from pyathena.result_set import WithResultSet from pyathena.s3fs.converter import DefaultS3FSTypeConverter -from pyathena.s3fs.result_set import AthenaS3FSResultSet +from pyathena.s3fs.result_set import AthenaS3FSResultSet, CSVReaderType _logger = logging.getLogger(__name__) @@ -59,6 +59,7 @@ def __init__( result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, on_start_query_execution: Optional[Callable[[str], None]] = None, + csv_reader: Optional[CSVReaderType] = None, **kwargs, ) -> None: """Initialize an S3FSCursor. @@ -75,11 +76,20 @@ def __init__( result_reuse_enable: Enable Athena query result reuse. result_reuse_minutes: Minutes to reuse cached results. on_start_query_execution: Callback invoked when query starts. + csv_reader: CSV reader class to use for parsing results. + Use AthenaCSVReader (default) to distinguish between NULL + (unquoted empty) and empty string (quoted empty ""). + Use DefaultCSVReader for backward compatibility where empty + strings are treated as NULL. **kwargs: Additional connection parameters. 
# pyathena/s3fs/reader.py — pluggable CSV readers for S3FS result sets.


class DefaultCSVReader:
    """Parse CSV rows with Python's built-in ``csv`` module.

    Empty fields are always reported as empty strings, so Athena's NULL
    (unquoted empty field) and empty string (quoted empty field) are
    indistinguishable in the output. Choose this reader for backward
    compatibility with that behavior; the S3FS cursors default to
    AthenaCSVReader, which keeps the two cases apart.

    Example:
        >>> from io import StringIO
        >>> list(DefaultCSVReader(StringIO(',"",text')))
        [['', '', 'text']]
    """

    def __init__(self, file_obj: Any, delimiter: str = ",") -> None:
        """Wrap *file_obj* in a standard ``csv.reader``.

        Args:
            file_obj: Text-mode file-like object to read from.
            delimiter: Single-character field separator.
        """
        self._reader = csv.reader(file_obj, delimiter=delimiter)

    def __iter__(self) -> Iterator[List[str]]:
        """Return self; rows are produced by __next__."""
        return self

    def __next__(self) -> List[str]:
        """Return the next row as a list of strings.

        An empty physical line comes back from ``csv.reader`` as ``[]``;
        it is normalized to ``[""]`` so every row has at least one field.

        Raises:
            StopIteration: At end of input.
        """
        fields = next(self._reader)
        return fields if fields else [""]


class AthenaCSVReader:
    """Parse Athena CSV output, keeping NULL distinct from empty string.

    Athena encodes NULL as an unquoted empty field (``a,,b``) and an empty
    string as a quoted empty field (``a,"",b``). The standard ``csv``
    module collapses both to ``""``; this parser instead reports NULL as
    ``None`` and a quoted empty field as ``""``. It is the default reader
    for S3FSCursor.

    Example:
        >>> from io import StringIO
        >>> list(AthenaCSVReader(StringIO(',"",text')))
        [[None, '', 'text']]

    Note:
        Use DefaultCSVReader if you need backward compatibility where both
        NULL and empty string are treated as empty string.
    """

    def __init__(self, file_obj: Any, delimiter: str = ",") -> None:
        """Store the input stream and field separator.

        Args:
            file_obj: Text-mode file-like object to read from.
            delimiter: Single-character field separator.
        """
        self._file = file_obj
        self._delimiter = delimiter

    def __iter__(self) -> Iterator[List[Optional[str]]]:
        """Return self; rows are produced by __next__."""
        return self

    def __next__(self) -> List[Optional[str]]:
        """Assemble one logical row and parse it.

        Quoted fields may contain embedded newlines, so physical lines are
        accumulated until every quote is balanced. Each newly read line is
        scanned only once (the running quote state is carried forward), so
        a whole row costs O(n) in its total length.

        Returns:
            Field values: ``None`` for NULL, ``""`` for a quoted empty field.

        Raises:
            StopIteration: At end of input.
        """
        buffered = self._file.readline()
        if not buffered:
            raise StopIteration

        open_quote = self._check_quote_state(buffered)
        while open_quote:
            continuation = self._file.readline()
            if not continuation:
                # Unterminated quote at EOF: parse whatever was collected.
                break
            buffered += continuation
            open_quote = self._check_quote_state(continuation, open_quote)

        return self._parse_line(buffered.rstrip("\r\n"))

    def _check_quote_state(self, text: str, starting_state: bool = False) -> bool:
        """Return True when *text* leaves us inside an unclosed quote.

        Args:
            text: Chunk of input to scan.
            starting_state: Quote state inherited from previously scanned chunks.

        Returns:
            True if the scan ends inside an unclosed quoted field.
        """
        inside = starting_state
        idx = 0
        limit = len(text)
        while idx < limit:
            # Only quote characters affect the state, so jump between them.
            idx = text.find('"', idx)
            if idx < 0:
                break
            if inside and idx + 1 < limit and text[idx + 1] == '"':
                # A doubled quote inside a field is an escape, not a boundary.
                idx += 2
                continue
            inside = not inside
            idx += 1
        return inside

    def _parse_line(self, line: str) -> List[Optional[str]]:
        """Split one logical row into fields, preserving NULL vs empty string.

        Args:
            line: Row text with the trailing newline already removed.

        Returns:
            List of field values (None for NULL, str otherwise).
        """
        # A completely empty row (e.g. ``SELECT NULL``) is a single NULL field.
        if not line:
            return [None]

        fields: List[Optional[str]] = []
        cursor = 0
        end = len(line)
        while cursor < end:
            if line[cursor] == '"':
                text, cursor = self._parse_quoted_field(line, cursor)
                fields.append(text)
            else:
                text, cursor = self._parse_unquoted_field(line, cursor)
                # An unquoted empty field is Athena's NULL encoding.
                fields.append(text or None)

        # A row ending in the delimiter carries one more (NULL) field.
        if line.endswith(self._delimiter):
            fields.append(None)

        return fields

    def _parse_quoted_field(self, line: str, pos: int) -> Tuple[str, int]:
        """Consume a quoted field.

        Args:
            line: The row being parsed.
            pos: Index of the opening quote.

        Returns:
            Tuple of (field value, index just past the trailing delimiter).
        """
        end = len(line)
        pieces: List[str] = []
        i = pos + 1  # step over the opening quote
        while i < end:
            ch = line[i]
            if ch != '"':
                pieces.append(ch)
                i += 1
            elif i + 1 < end and line[i + 1] == '"':
                pieces.append('"')  # doubled quote -> literal quote character
                i += 2
            else:
                i += 1  # closing quote
                break
        if i < end and line[i] == self._delimiter:
            i += 1
        return "".join(pieces), i

    def _parse_unquoted_field(self, line: str, pos: int) -> Tuple[str, int]:
        """Consume an unquoted field.

        Args:
            line: The row being parsed.
            pos: Index of the field's first character.

        Returns:
            Tuple of (field value, index just past the trailing delimiter).
        """
        cut = line.find(self._delimiter, pos)
        if cut < 0:
            return line[pos:], len(line)
        return line[pos:cut], cut + 1
Features: - - Uses Python's standard csv module for parsing + - Lightweight CSV parsing via pluggable readers - Uses PyAthena's S3FileSystem for S3 access - No external dependencies beyond boto3 - Memory-efficient streaming for large datasets @@ -60,6 +62,7 @@ def __init__( arraysize: int, retry_config: RetryConfig, block_size: Optional[int] = None, + csv_reader: Optional[CSVReaderType] = None, **kwargs, ) -> None: super().__init__( @@ -74,11 +77,10 @@ def __init__( self._rows.clear() self._arraysize = arraysize self._block_size = block_size if block_size else self.DEFAULT_BLOCK_SIZE + self._csv_reader_class: CSVReaderType = csv_reader or AthenaCSVReader self._fs = self._create_s3_file_system() self._csv_reader: Optional[Any] = None self._csv_file: Optional[Any] = None - self._header_skipped = False - self._has_header = False # CSV files have headers, TXT files don't if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location: self._init_csv_reader() @@ -125,12 +127,12 @@ def _init_csv_reader(self) -> None: if self.output_location.endswith(".txt"): # Tab-separated format (no header row) - self._csv_reader = csv.reader(text_wrapper, delimiter="\t") - self._has_header = False + self._csv_reader = self._csv_reader_class(text_wrapper, delimiter="\t") else: - # Standard CSV format (has header row) - self._csv_reader = csv.reader(text_wrapper) - self._has_header = True + # Standard CSV format (has header row, skip it) + self._csv_reader = self._csv_reader_class(text_wrapper, delimiter=",") + with contextlib.suppress(StopIteration): + next(self._csv_reader) except Exception as e: _logger.exception(f"Failed to open {path}.") @@ -141,14 +143,6 @@ def _fetch(self) -> None: if not self._csv_reader: return - # Skip header row on first fetch (only for CSV files, not TXT) - if self._has_header and not self._header_skipped: - try: - next(self._csv_reader) - self._header_skipped = True - except StopIteration: - return - description = self.description if 
self.description else [] column_types = [d[1] for d in description] @@ -160,10 +154,18 @@ def _fetch(self) -> None: break # Convert row values using converters - converted_row = tuple( - self._converter.convert(col_type, value if value != "" else None) - for col_type, value in zip(column_types, row, strict=False) - ) + # AthenaCSVReader returns None for NULL values directly, + # DefaultCSVReader returns empty string which needs conversion + if self._csv_reader_class is DefaultCSVReader: + converted_row = tuple( + self._converter.convert(col_type, value if value != "" else None) + for col_type, value in zip(column_types, row, strict=False) + ) + else: + converted_row = tuple( + self._converter.convert(col_type, value) + for col_type, value in zip(column_types, row, strict=False) + ) self._rows.append(converted_row) rows_fetched += 1 diff --git a/tests/pyathena/s3fs/test_cursor.py b/tests/pyathena/s3fs/test_cursor.py index f5b0ba9f..b90523cb 100644 --- a/tests/pyathena/s3fs/test_cursor.py +++ b/tests/pyathena/s3fs/test_cursor.py @@ -11,6 +11,7 @@ from pyathena.error import DatabaseError, ProgrammingError from pyathena.s3fs.cursor import S3FSCursor +from pyathena.s3fs.reader import AthenaCSVReader, DefaultCSVReader from pyathena.s3fs.result_set import AthenaS3FSResultSet from tests import ENV from tests.pyathena.conftest import connect @@ -311,3 +312,194 @@ def test_contain_tab_character(self, s3fs_cursor): ) result = s3fs_cursor.fetchone() assert result == ("a\tb\nc",) + + @pytest.mark.parametrize( + "csv_reader_class", + [DefaultCSVReader, AthenaCSVReader], + ) + def test_basic_query_with_reader(self, csv_reader_class): + """Both readers should work for basic queries.""" + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": csv_reader_class}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT * FROM one_row") + result = cursor.fetchall() + assert result == [(1,)] + + 
@pytest.mark.parametrize( + "csv_reader_class", + [DefaultCSVReader, AthenaCSVReader], + ) + def test_multiple_columns_with_reader(self, csv_reader_class): + """Test multiple columns work with both readers.""" + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": csv_reader_class}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute( + """ + SELECT + 1 AS col_int, + 'text' AS col_string, + 1.5 AS col_double + """ + ) + result = cursor.fetchone() + assert result == (1, "text", 1.5) + + def test_null_with_default_reader(self): + """DefaultCSVReader: NULL is returned as None.""" + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": DefaultCSVReader}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT NULL AS null_col") + result = cursor.fetchone() + assert result == (None,) + + def test_null_with_athena_reader(self): + """AthenaCSVReader: NULL is returned as None.""" + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": AthenaCSVReader}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT NULL AS null_col") + result = cursor.fetchone() + assert result == (None,) + + def test_empty_string_with_default_reader(self): + """DefaultCSVReader: Empty string becomes None (loses distinction).""" + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": DefaultCSVReader}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT '' AS empty_col") + result = cursor.fetchone() + # DefaultCSVReader treats empty string same as NULL + assert result == (None,) + + def test_empty_string_with_athena_reader(self): + """AthenaCSVReader: Empty string is preserved as empty string.""" + with ( + contextlib.closing( + connect( + 
schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": AthenaCSVReader}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT '' AS empty_col") + result = cursor.fetchone() + # AthenaCSVReader preserves empty string as '' + assert result == ("",) + + def test_null_vs_empty_string_with_default_reader(self): + """DefaultCSVReader: Both NULL and empty string become None.""" + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": DefaultCSVReader}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT NULL AS null_col, '' AS empty_col") + result = cursor.fetchone() + # Both become None + assert result == (None, None) + + def test_null_vs_empty_string_with_athena_reader(self): + """AthenaCSVReader: NULL and empty string are distinct.""" + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": AthenaCSVReader}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT NULL AS null_col, '' AS empty_col") + result = cursor.fetchone() + # NULL is None, empty string is '' + assert result == (None, "") + + def test_mixed_values_with_athena_reader(self): + """AthenaCSVReader: Mixed NULL, empty string, and regular values.""" + with ( + contextlib.closing( + connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": AthenaCSVReader}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute( + """ + SELECT + NULL AS null_col, + '' AS empty_col, + 'text' AS text_col, + NULL AS null_col2 + """ + ) + result = cursor.fetchone() + assert result == (None, "", "text", None) + + @pytest.mark.parametrize( + "csv_reader_class", + [DefaultCSVReader, AthenaCSVReader], + ) + def test_quoted_string_with_comma(self, csv_reader_class): + """Both readers should handle strings containing commas.""" + with ( + contextlib.closing( + 
connect( + schema_name=ENV.schema, + cursor_class=S3FSCursor, + cursor_kwargs={"csv_reader": csv_reader_class}, + ) + ) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT 'a,b,c' AS col_with_commas") + result = cursor.fetchone() + assert result == ("a,b,c",) diff --git a/tests/pyathena/s3fs/test_reader.py b/tests/pyathena/s3fs/test_reader.py new file mode 100644 index 00000000..5bd7562d --- /dev/null +++ b/tests/pyathena/s3fs/test_reader.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +from io import StringIO + +from pyathena.s3fs.reader import AthenaCSVReader, DefaultCSVReader + + +class TestDefaultCSVReader: + """Tests for DefaultCSVReader using Python's standard csv module.""" + + def test_basic_parsing(self): + data = StringIO("a,b,c\n1,2,3\n") + reader = DefaultCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["a", "b", "c"], ["1", "2", "3"]] + + def test_tab_delimiter(self): + data = StringIO("a\tb\tc\n1\t2\t3\n") + reader = DefaultCSVReader(data, delimiter="\t") + rows = list(reader) + assert rows == [["a", "b", "c"], ["1", "2", "3"]] + + def test_empty_field_returns_empty_string(self): + """DefaultCSVReader returns empty string for both NULL and empty string.""" + # Both ,, (NULL) and ,"", (empty string) become '' + data = StringIO('a,,b,"",c\n') + reader = DefaultCSVReader(data, delimiter=",") + rows = list(reader) + # Standard csv.reader treats both as empty strings + assert rows == [["a", "", "b", "", "c"]] + + def test_quoted_field_with_comma(self): + data = StringIO('"a,b",c\n') + reader = DefaultCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["a,b", "c"]] + + def test_quoted_field_with_newline(self): + data = StringIO('"line1\nline2",b\n') + reader = DefaultCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["line1\nline2", "b"]] + + def test_escaped_quote(self): + data = StringIO('"a""b",c\n') + reader = DefaultCSVReader(data, delimiter=",") + rows = list(reader) + assert 
rows == [['a"b', "c"]] + + def test_empty_file(self): + data = StringIO("") + reader = DefaultCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [] + + def test_empty_line(self): + """Empty line returns single empty string (e.g., SELECT NULL).""" + # Python's csv.reader returns [] for empty lines; we normalize to [''] + data = StringIO("\n") + reader = DefaultCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [[""]] + + +class TestAthenaCSVReader: + """Tests for AthenaCSVReader that distinguishes NULL from empty string.""" + + def test_basic_parsing(self): + data = StringIO("a,b,c\n1,2,3\n") + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["a", "b", "c"], ["1", "2", "3"]] + + def test_tab_delimiter(self): + data = StringIO("a\tb\tc\n1\t2\t3\n") + reader = AthenaCSVReader(data, delimiter="\t") + rows = list(reader) + assert rows == [["a", "b", "c"], ["1", "2", "3"]] + + def test_null_vs_empty_string(self): + """AthenaCSVReader distinguishes NULL (unquoted empty) from empty string (quoted empty).""" + # ,, is NULL, ,"", is empty string + data = StringIO('a,,b,"",c\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + # Field at index 1 is NULL (unquoted), field at index 3 is empty string (quoted) + assert rows == [["a", None, "b", "", "c"]] + + def test_null_at_start(self): + """NULL value at the start of a line.""" + data = StringIO(",b,c\n") + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [[None, "b", "c"]] + + def test_null_at_end(self): + """NULL value at the end of a line (trailing comma).""" + data = StringIO("a,b,\n") + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["a", "b", None]] + + def test_no_trailing_null(self): + """No trailing NULL when line ends with value.""" + data = StringIO("a,b,c\n") + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == 
[["a", "b", "c"]] + + def test_empty_string_at_start(self): + """Empty string at the start of a line.""" + data = StringIO('"",b,c\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["", "b", "c"]] + + def test_empty_string_at_end(self): + """Empty string at the end of a line.""" + data = StringIO('a,b,""\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["a", "b", ""]] + + def test_all_nulls(self): + """Row with all NULL values.""" + data = StringIO(",,\n") + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [[None, None, None]] + + def test_all_empty_strings(self): + """Row with all empty string values.""" + data = StringIO('"","",""\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["", "", ""]] + + def test_quoted_field_with_comma(self): + data = StringIO('"a,b",c\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["a,b", "c"]] + + def test_quoted_field_with_delimiter(self): + """Quoted field containing the tab delimiter.""" + data = StringIO('"a\tb"\tc\n') + reader = AthenaCSVReader(data, delimiter="\t") + rows = list(reader) + assert rows == [["a\tb", "c"]] + + def test_escaped_quote(self): + """Escaped quote inside quoted field.""" + data = StringIO('"a""b",c\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [['a"b', "c"]] + + def test_empty_file(self): + data = StringIO("") + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [] + + def test_mixed_null_and_values(self): + """Multiple rows with mixed NULL and regular values.""" + data = StringIO('1,,"text"\n,4,\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["1", None, "text"], [None, "4", None]] + + def test_crlf_line_endings(self): + """Handle Windows-style line endings.""" + data = 
StringIO("a,b,c\r\n1,2,3\r\n") + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["a", "b", "c"], ["1", "2", "3"]] + + def test_single_null_empty_line(self): + """Empty line represents a single NULL value (e.g., SELECT NULL).""" + # Athena outputs empty line for single NULL value after header + data = StringIO("\n") + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [[None]] + + def test_quoted_field_with_newline(self): + """Quoted field containing a newline character.""" + data = StringIO('"line1\nline2",b\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["line1\nline2", "b"]] + + def test_quoted_field_with_multiple_newlines(self): + """Quoted field containing multiple newline characters.""" + data = StringIO('"line1\nline2\nline3",b,c\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["line1\nline2\nline3", "b", "c"]] + + def test_quoted_field_with_crlf(self): + """Quoted field containing CRLF line ending.""" + data = StringIO('"line1\r\nline2",b\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["line1\r\nline2", "b"]] + + def test_multiple_rows_with_multiline_field(self): + """Multiple rows where some have multi-line quoted fields.""" + data = StringIO('a,b\n"multi\nline",c\nd,e\n') + reader = AthenaCSVReader(data, delimiter=",") + rows = list(reader) + assert rows == [["a", "b"], ["multi\nline", "c"], ["d", "e"]]