diff --git a/docs/polars.rst b/docs/polars.rst index 55c7caf6..fb0d59b7 100644 --- a/docs/polars.rst +++ b/docs/polars.rst @@ -246,6 +246,94 @@ SQLAlchemy allows this option to be specified in the connection string. NOTE: PolarsCursor handles the CSV file on memory. Pay attention to the memory capacity. +Chunksize Options +~~~~~~~~~~~~~~~~~ + +PolarsCursor supports memory-efficient chunked processing of large query results +using Polars' native lazy evaluation APIs. This allows processing datasets that +are too large to fit in memory. + +The chunksize option can be enabled by specifying an integer value in the ``cursor_kwargs`` +argument of the connect method or as an argument to the cursor method. + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor, + cursor_kwargs={ + "chunksize": 50_000 + }).cursor() + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor(chunksize=50_000) + +When the chunksize option is enabled, data is loaded lazily in chunks. This applies +to all data access methods: + +**Standard DB-API fetch methods** - ``fetchone()`` and ``fetchmany()`` load data +chunk by chunk as needed, keeping memory usage bounded: + +.. 
code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor(chunksize=50_000) + + cursor.execute("SELECT * FROM large_table") + # Data is loaded in 50,000 row chunks as you iterate + for row in cursor: + process_row(row) + +**iter_chunks() method** - Use this when you want to process data as Polars DataFrames +in chunks, which is more efficient for batch processing: + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor(chunksize=50_000) + + cursor.execute("SELECT * FROM large_table") + for chunk in cursor.iter_chunks(): + # Process each chunk - chunk is a polars.DataFrame + processed = chunk.group_by('category').agg(pl.sum('value')) + print(f"Processed chunk with {chunk.height} rows") + +This method uses Polars' ``scan_csv()`` and ``scan_parquet()`` with ``collect_batches()`` +for efficient lazy evaluation, minimizing memory usage when processing large datasets. + +The chunked iteration also works with the unload option: + +.. code:: python + + from pyathena import connect + from pyathena.polars.cursor import PolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=PolarsCursor).cursor(chunksize=100_000, unload=True) + + cursor.execute("SELECT * FROM huge_table") + for chunk in cursor.iter_chunks(): + # Process Parquet data in chunks + process_chunk(chunk) + .. _async-polars-cursor: AsyncPolarsCursor @@ -414,6 +502,42 @@ As with AsyncPolarsCursor, the unload option is also available. region_name="us-west-2", cursor_class=AsyncPolarsCursor).cursor(unload=True) +As with PolarsCursor, the chunksize option is also available for memory-efficient processing. 
+When chunksize is specified, data is loaded lazily in chunks for both standard fetch methods +and ``iter_chunks()``. + +.. code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor(chunksize=50_000) + + query_id, future = cursor.execute("SELECT * FROM large_table") + result_set = future.result() + + # Standard iteration - data loaded in chunks + for row in result_set: + process_row(row) + +.. code:: python + + from pyathena import connect + from pyathena.polars.async_cursor import AsyncPolarsCursor + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + cursor_class=AsyncPolarsCursor).cursor(chunksize=50_000) + + query_id, future = cursor.execute("SELECT * FROM large_table") + result_set = future.result() + + # Process as DataFrame chunks + for chunk in result_set.iter_chunks(): + process_chunk(chunk) + .. _`polars.DataFrame object`: https://docs.pola.rs/api/python/stable/reference/dataframe/index.html .. _`Polars`: https://pola.rs/ .. _`Unload options`: arrow.html#unload-options diff --git a/pyathena/polars/async_cursor.py b/pyathena/polars/async_cursor.py index 2fc3ee6e..d04fe08f 100644 --- a/pyathena/polars/async_cursor.py +++ b/pyathena/polars/async_cursor.py @@ -73,6 +73,7 @@ def __init__( result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, block_size: Optional[int] = None, cache_type: Optional[str] = None, + chunksize: Optional[int] = None, **kwargs, ) -> None: """Initialize an AsyncPolarsCursor. @@ -93,10 +94,15 @@ def __init__( result_reuse_minutes: Minutes to reuse cached results. block_size: S3 read block size. cache_type: S3 caching strategy. + chunksize: Number of rows per chunk for memory-efficient processing. 
+ If specified, data is loaded lazily in chunks for all data + access methods including fetchone(), fetchmany(), and iter_chunks(). **kwargs: Additional connection parameters. Example: >>> cursor = connection.cursor(AsyncPolarsCursor, unload=True) + >>> # With chunked processing + >>> cursor = connection.cursor(AsyncPolarsCursor, chunksize=50000) """ super().__init__( s3_staging_dir=s3_staging_dir, @@ -116,6 +122,7 @@ def __init__( self._unload = unload self._block_size = block_size self._cache_type = cache_type + self._chunksize = chunksize @staticmethod def get_default_converter( @@ -172,6 +179,7 @@ def _collect_result_set( block_size=self._block_size, cache_type=self._cache_type, max_workers=self._max_workers, + chunksize=self._chunksize, **kwargs, ) diff --git a/pyathena/polars/cursor.py b/pyathena/polars/cursor.py index e8877e2e..d25bfc85 100644 --- a/pyathena/polars/cursor.py +++ b/pyathena/polars/cursor.py @@ -3,7 +3,18 @@ import logging from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, + cast, +) from pyathena.common import BaseCursor, CursorIterator from pyathena.error import OperationalError, ProgrammingError @@ -74,6 +85,7 @@ def __init__( block_size: Optional[int] = None, cache_type: Optional[str] = None, max_workers: int = (cpu_count() or 1) * 5, + chunksize: Optional[int] = None, **kwargs, ) -> None: """Initialize a PolarsCursor. @@ -94,10 +106,15 @@ def __init__( block_size: S3 read block size. cache_type: S3 caching strategy. max_workers: Maximum worker threads for parallel S3 operations. + chunksize: Number of rows per chunk for memory-efficient processing. + If specified, data is loaded lazily in chunks for all data + access methods including fetchone(), fetchmany(), and iter_chunks(). **kwargs: Additional connection parameters. 
Example: >>> cursor = connection.cursor(PolarsCursor, unload=True) + >>> # With chunked processing + >>> cursor = connection.cursor(PolarsCursor, chunksize=50000) """ super().__init__( s3_staging_dir=s3_staging_dir, @@ -117,6 +134,7 @@ def __init__( self._block_size = block_size self._cache_type = cache_type self._max_workers = max_workers + self._chunksize = chunksize self._query_id: Optional[str] = None self._result_set: Optional[AthenaPolarsResultSet] = None @@ -272,6 +290,7 @@ def execute( block_size=self._block_size, cache_type=self._cache_type, max_workers=self._max_workers, + chunksize=self._chunksize, **kwargs, ) else: @@ -404,3 +423,37 @@ def as_arrow(self) -> "Table": raise ProgrammingError("No result set.") result_set = cast(AthenaPolarsResultSet, self.result_set) return result_set.as_arrow() + + def iter_chunks(self) -> Iterator["pl.DataFrame"]: + """Iterate over result chunks as Polars DataFrames. + + This method provides an iterator interface for processing result sets. + When chunksize is specified, it yields DataFrames in chunks using lazy + evaluation for memory-efficient processing. When chunksize is not specified, + it yields the entire result as a single DataFrame, providing a consistent + interface regardless of chunking configuration. + + Yields: + Polars DataFrame for each chunk of rows, or the entire DataFrame + if chunksize was not specified. + + Raises: + ProgrammingError: If no result set is available. + + Example: + >>> # With chunking for large datasets + >>> cursor = connection.cursor(PolarsCursor, chunksize=50000) + >>> cursor.execute("SELECT * FROM large_table") + >>> for chunk in cursor.iter_chunks(): + ... process_chunk(chunk) # Each chunk is a Polars DataFrame + >>> + >>> # Without chunking - yields entire result as single chunk + >>> cursor = connection.cursor(PolarsCursor) + >>> cursor.execute("SELECT * FROM small_table") + >>> for df in cursor.iter_chunks(): + ... 
process(df) # Single DataFrame with all data + """ + if not self.has_result_set: + raise ProgrammingError("No result set.") + result_set = cast(AthenaPolarsResultSet, self.result_set) + yield from result_set.iter_chunks() diff --git a/pyathena/polars/result_set.py b/pyathena/polars/result_set.py index 6c4d18c2..7cd27417 100644 --- a/pyathena/polars/result_set.py +++ b/pyathena/polars/result_set.py @@ -2,16 +2,19 @@ from __future__ import annotations import logging +from collections import abc from multiprocessing import cpu_count from typing import ( TYPE_CHECKING, Any, Callable, Dict, + Iterator, List, Optional, Tuple, Union, + cast, ) from pyathena import OperationalError @@ -31,6 +34,120 @@ _logger = logging.getLogger(__name__) +class DataFrameIterator(abc.Iterator): # type: ignore + """Iterator for chunked DataFrame results from Athena queries. + + This class wraps either a Polars DataFrame iterator (for chunked reading) or + a single DataFrame, providing a unified iterator interface. DataFrame chunks + are yielded as read; converters are applied per row only by iterrows(). + + The iterator is used by AthenaPolarsResultSet to provide chunked access + to large query results, enabling memory-efficient processing of datasets + that would be too large to load entirely into memory. + + Example: + >>> # Iterate over DataFrame chunks + >>> for df_chunk in iterator: + ... process(df_chunk) + >>> + >>> # Iterate over individual rows + >>> for idx, row in iterator.iterrows(): + ... print(row) + + Note: + This class is primarily for internal use by AthenaPolarsResultSet. + Most users should access results through PolarsCursor methods. + """ + + def __init__( + self, + reader: Union[Iterator["pl.DataFrame"], "pl.DataFrame"], + converters: Dict[str, Callable[[Optional[str]], Optional[Any]]], + column_names: List[str], + ) -> None: + """Initialize the iterator. + + Args: + reader: Either a DataFrame iterator (for chunked) or a single DataFrame.
+ converters: Dictionary mapping column names to converter functions. + column_names: List of column names in order. + """ + import polars as pl + + if isinstance(reader, pl.DataFrame): + self._reader: Iterator["pl.DataFrame"] = iter([reader]) + else: + self._reader = reader + self._converters = converters + self._column_names = column_names + self._closed = False + + def __next__(self) -> "pl.DataFrame": + """Get the next DataFrame chunk. + + Returns: + The next Polars DataFrame chunk. + + Raises: + StopIteration: When no more chunks are available. + """ + if self._closed: + raise StopIteration + try: + return next(self._reader) + except StopIteration: + self.close() + raise + + def __iter__(self) -> "DataFrameIterator": + """Return self as iterator.""" + return self + + def __enter__(self) -> "DataFrameIterator": + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """Context manager exit.""" + self.close() + + def close(self) -> None: + """Close the iterator and release resources.""" + self._closed = True + + def iterrows(self) -> Iterator[Tuple[int, Dict[str, Any]]]: + """Iterate over rows as (index, row_dict) tuples. + + Yields: + Tuple of (row_index, row_dict) for each row across all chunks. + """ + row_num = 0 + for df in self: + for row_dict in df.iter_rows(named=True): + # Apply converters + processed_row = { + col: self._converters.get(col, lambda x: x)(row_dict.get(col)) + for col in self._column_names + } + yield (row_num, processed_row) + row_num += 1 + + def as_polars(self) -> "pl.DataFrame": + """Collect all chunks into a single DataFrame. + + Returns: + Single Polars DataFrame containing all data. 
+ """ + import polars as pl + + dfs = cast(List["pl.DataFrame"], list(self)) + if not dfs: + return pl.DataFrame() + if len(dfs) == 1: + return dfs[0] + return pl.concat(dfs) + + class AthenaPolarsResultSet(AthenaResultSet): """Result set that provides Polars DataFrame results with optional Arrow interoperability. @@ -44,6 +161,7 @@ class AthenaPolarsResultSet(AthenaResultSet): - Efficient columnar data processing with Polars - Optional Arrow interoperability when PyArrow is available - Support for both CSV and Parquet result formats + - Chunked iteration for memory-efficient processing of large datasets - Optimized memory usage through columnar format Example: @@ -60,6 +178,12 @@ class AthenaPolarsResultSet(AthenaResultSet): >>> >>> # Optional: Get Arrow Table (requires pyarrow) >>> table = cursor.as_arrow() + >>> + >>> # Memory-efficient chunked iteration + >>> cursor = connection.cursor(PolarsCursor, chunksize=50000) + >>> cursor.execute("SELECT * FROM huge_table") + >>> for chunk in cursor.iter_chunks(): + ... process_chunk(chunk) Note: This class is used internally by PolarsCursor and typically not @@ -79,6 +203,7 @@ def __init__( block_size: Optional[int] = None, cache_type: Optional[str] = None, max_workers: int = (cpu_count() or 1) * 5, + chunksize: Optional[int] = None, **kwargs, ) -> None: """Initialize the Polars result set. @@ -94,6 +219,9 @@ def __init__( block_size: Block size for S3 file reading. cache_type: Cache type for S3 file system. max_workers: Maximum number of worker threads. + chunksize: Number of rows per chunk for memory-efficient processing. + If specified, data is loaded lazily in chunks for all data + access methods including fetchone(), fetchmany(), and iter_chunks(). **kwargs: Additional arguments passed to Polars read functions. 
""" super().__init__( @@ -110,14 +238,19 @@ def __init__( self._block_size = block_size self._cache_type = cache_type self._max_workers = max_workers + self._chunksize = chunksize self._kwargs = kwargs + + # Build DataFrame iterator (handles both chunked and non-chunked cases) if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location: - self._df = self._as_polars() + self._df_iter = self._create_dataframe_iterator() else: import polars as pl - self._df = pl.DataFrame() - self._row_index = 0 + self._df_iter = DataFrameIterator( + pl.DataFrame(), self.converters, self._get_column_names() + ) + self._iterrows = self._df_iter.iterrows() @property def _csv_storage_options(self) -> Dict[str, Any]: @@ -178,23 +311,31 @@ def converters(self) -> Dict[str, Callable[[Optional[str]], Optional[Any]]]: description = self.description if self.description else [] return {d[0]: self._converter.get(d[1]) for d in description} - def _fetch(self) -> None: - """Fetch rows from the DataFrame into the row buffer.""" - if self._row_index >= self._df.height: - return + def _get_column_names(self) -> List[str]: + """Get column names from description. - end_index = min(self._row_index + self._arraysize, self._df.height) - chunk = self._df.slice(self._row_index, end_index - self._row_index) - self._row_index = end_index - - # Convert to rows and apply converters + Returns: + List of column names. + """ description = self.description if self.description else [] - column_names = [d[0] for d in description] - for row_dict in chunk.iter_rows(named=True): - processed_row = tuple( - self.converters.get(col, lambda x: x)(row_dict.get(col)) for col in column_names + return [d[0] for d in description] + + def _create_dataframe_iterator(self) -> DataFrameIterator: + """Create a DataFrame iterator for the result set. + + Returns: + DataFrameIterator that handles both chunked and non-chunked cases. 
+ """ + if self._chunksize is not None: + # Chunked mode: create lazy iterator + reader: Union[Iterator["pl.DataFrame"], "pl.DataFrame"] = ( + self._iter_parquet_chunks() if self.is_unload else self._iter_csv_chunks() ) - self._rows.append(processed_row) + else: + # Non-chunked mode: load entire DataFrame + reader = self._as_polars() + + return DataFrameIterator(reader, self.converters, self._get_column_names()) def fetchone( self, @@ -204,14 +345,14 @@ def fetchone( Returns: A single row as a tuple, or None if no more rows are available. """ - if not self._rows: - self._fetch() - if not self._rows: + try: + row = next(self._iterrows) + except StopIteration: return None - if self._rownumber is None: - self._rownumber = 0 - self._rownumber += 1 - return self._rows.popleft() + else: + self._rownumber = row[0] + 1 + column_names = self._get_column_names() + return tuple([row[1][col] for col in column_names]) def fetchmany( self, size: Optional[int] = None @@ -252,45 +393,62 @@ def fetchall( break return rows - def _read_csv(self) -> "pl.DataFrame": - """Read query results from CSV file in S3. + def _is_csv_readable(self) -> bool: + """Check if CSV output is available and can be read. Returns: - Polars DataFrame containing the CSV data. + True if CSV data is available to read, False otherwise. Raises: ProgrammingError: If output location is not set. - OperationalError: If reading the CSV file fails. 
""" - import polars as pl - if not self.output_location: raise ProgrammingError("OutputLocation is none or empty.") if not self.output_location.endswith((".csv", ".txt")): - return pl.DataFrame() + return False if self.substatement_type and self.substatement_type.upper() in ( "UPDATE", "DELETE", "MERGE", "VACUUM_TABLE", ): - return pl.DataFrame() + return False length = self._get_content_length() - if length == 0: - return pl.DataFrame() + return length != 0 - if self.output_location.endswith(".txt"): - separator = "\t" - has_header = False - description = self.description if self.description else [] - new_columns = [d[0] for d in description] - elif self.output_location.endswith(".csv"): - separator = "," - has_header = True - new_columns = None - else: + def _prepare_parquet_location(self) -> bool: + """Prepare unload location for Parquet reading. + + Returns: + True if Parquet data is available to read, False otherwise. + """ + manifests = self._read_data_manifest() + if not manifests: + return False + if not self._unload_location: + self._unload_location = "/".join(manifests[0].split("/")[:-1]) + "/" + return True + + def _read_csv(self) -> "pl.DataFrame": + """Read query results from CSV file in S3. + + Returns: + Polars DataFrame containing the CSV data. + + Raises: + ProgrammingError: If output location is not set. + OperationalError: If reading the CSV file fails. 
+ """ + import polars as pl + + if not self._is_csv_readable(): return pl.DataFrame() + # After validation, output_location is guaranteed to be set + assert self.output_location is not None + + separator, has_header, new_columns = self._get_csv_params() + try: df = pl.read_csv( self.output_location, @@ -318,11 +476,11 @@ def _read_parquet(self) -> "pl.DataFrame": """ import polars as pl - manifests = self._read_data_manifest() - if not manifests: + if not self._prepare_parquet_location(): return pl.DataFrame() - if not self._unload_location: - self._unload_location = "/".join(manifests[0].split("/")[:-1]) + "/" + + # After preparation, unload_location is guaranteed to be set + assert self._unload_location is not None try: return pl.read_parquet( @@ -377,6 +535,11 @@ def as_polars(self) -> "pl.DataFrame": Returns the query results as a Polars DataFrame. This is the primary method for accessing results with PolarsCursor. + Note: + When chunksize is set, calling this method will collect all chunks + into a single DataFrame, loading all data into memory. Use + iter_chunks() for memory-efficient processing of large datasets. + Returns: Polars DataFrame containing all query results. @@ -387,7 +550,7 @@ def as_polars(self) -> "pl.DataFrame": >>> print(f"DataFrame has {df.height} rows") >>> filtered = df.filter(pl.col("value") > 100) """ - return self._df + return self._df_iter.as_polars() def as_arrow(self) -> "Table": """Return query results as an Apache Arrow Table. @@ -408,16 +571,127 @@ def as_arrow(self) -> "Table": >>> # Use with other Arrow-compatible libraries """ try: - return self._df.to_arrow() + return self._df_iter.as_polars().to_arrow() except ImportError as e: raise ImportError( "pyarrow is required for as_arrow(). Install it with: pip install pyarrow" ) from e + def _get_csv_params(self) -> Tuple[str, bool, Optional[List[str]]]: + """Get CSV parsing parameters based on file type. + + Returns: + Tuple of (separator, has_header, new_columns). 
+ """ + if self.output_location and self.output_location.endswith(".txt"): + separator = "\t" + has_header = False + new_columns: Optional[List[str]] = self._get_column_names() + else: + separator = "," + has_header = True + new_columns = None + return separator, has_header, new_columns + + def _iter_csv_chunks(self) -> Iterator["pl.DataFrame"]: + """Iterate over CSV data in chunks using lazy evaluation. + + Yields: + Polars DataFrame for each chunk. + + Raises: + ProgrammingError: If output location is not set. + OperationalError: If reading the CSV file fails. + """ + import polars as pl + + if not self._is_csv_readable(): + return + + # After validation, output_location is guaranteed to be set + assert self.output_location is not None + + separator, has_header, new_columns = self._get_csv_params() + + try: + # scan_csv uses Rust's native object_store (like scan_parquet), + # not fsspec, so we use the same storage options as Parquet + lazy_df = pl.scan_csv( + self.output_location, + separator=separator, + has_header=has_header, + schema_overrides=self.dtypes, + storage_options=self._parquet_storage_options, + **self._kwargs, + ) + for batch in lazy_df.collect_batches(chunk_size=self._chunksize): + if new_columns: + batch.columns = new_columns + yield batch + except Exception as e: + _logger.exception(f"Failed to read {self.output_location}.") + raise OperationalError(*e.args) from e + + def _iter_parquet_chunks(self) -> Iterator["pl.DataFrame"]: + """Iterate over Parquet data in chunks using lazy evaluation. + + Yields: + Polars DataFrame for each chunk. + + Raises: + OperationalError: If reading the Parquet files fails. 
+ """ + import polars as pl + + if not self._prepare_parquet_location(): + return + + # After preparation, unload_location is guaranteed to be set + assert self._unload_location is not None + + try: + lazy_df = pl.scan_parquet( + self._unload_location, + storage_options=self._parquet_storage_options, + **self._kwargs, + ) + for batch in lazy_df.collect_batches(chunk_size=self._chunksize): + yield batch + except Exception as e: + _logger.exception(f"Failed to read {self._unload_location}.") + raise OperationalError(*e.args) from e + + def iter_chunks(self) -> DataFrameIterator: + """Iterate over result chunks as Polars DataFrames. + + This method provides an iterator interface for processing large result sets. + When chunksize is specified, it yields DataFrames in chunks using lazy + evaluation for memory-efficient processing. When chunksize is not specified, + it yields the entire result as a single DataFrame. + + Returns: + DataFrameIterator that yields Polars DataFrames for each chunk + of rows, or the entire DataFrame if chunksize was not specified. + + Example: + >>> # With chunking for large datasets + >>> cursor = connection.cursor(PolarsCursor, chunksize=50000) + >>> cursor.execute("SELECT * FROM large_table") + >>> for chunk in cursor.iter_chunks(): + ... process_chunk(chunk) # Each chunk is a Polars DataFrame + >>> + >>> # Without chunking - yields entire result as single chunk + >>> cursor = connection.cursor(PolarsCursor) + >>> cursor.execute("SELECT * FROM small_table") + >>> for df in cursor.iter_chunks(): + ... 
process(df) # Single DataFrame with all data + """ + return self._df_iter + def close(self) -> None: """Close the result set and release resources.""" import polars as pl super().close() - self._df = pl.DataFrame() - self._row_index = 0 + self._df_iter = DataFrameIterator(pl.DataFrame(), {}, []) + self._iterrows = iter([]) diff --git a/tests/pyathena/polars/test_cursor.py b/tests/pyathena/polars/test_cursor.py index cd19c958..447b676b 100644 --- a/tests/pyathena/polars/test_cursor.py +++ b/tests/pyathena/polars/test_cursor.py @@ -447,3 +447,180 @@ def test_callback(query_id: str): assert len(callback_results) == 1 assert callback_results[0] == polars_cursor.query_id assert polars_cursor.query_id is not None + + def test_iter_chunks(self): + """Test chunked iteration over query results.""" + with contextlib.closing(connect(schema_name=ENV.schema)) as conn: + cursor = conn.cursor(PolarsCursor, chunksize=5) + cursor.execute("SELECT * FROM many_rows LIMIT 15") + chunks = list(cursor.iter_chunks()) + assert len(chunks) > 0 + total_rows = sum(chunk.height for chunk in chunks) + assert total_rows == 15 + for chunk in chunks: + assert isinstance(chunk, pl.DataFrame) + + def test_iter_chunks_without_chunksize(self, polars_cursor): + """Test that iter_chunks works without chunksize, yielding entire DataFrame.""" + polars_cursor.execute("SELECT * FROM one_row") + chunks = list(polars_cursor.iter_chunks()) + # Without chunksize, yields entire DataFrame as single chunk + assert len(chunks) == 1 + assert isinstance(chunks[0], pl.DataFrame) + assert chunks[0].height == 1 + + def test_iter_chunks_many_rows(self): + """Test chunked iteration with many rows.""" + with contextlib.closing(connect(schema_name=ENV.schema)) as conn: + cursor = conn.cursor(PolarsCursor, chunksize=1000) + cursor.execute("SELECT * FROM many_rows") + chunks = list(cursor.iter_chunks()) + total_rows = sum(chunk.height for chunk in chunks) + assert total_rows == 10000 + assert len(chunks) >= 10 # At least 10 
chunks with chunksize=1000 + + @pytest.mark.parametrize( + "polars_cursor", + [ + { + "cursor_kwargs": {"unload": True, "chunksize": 5}, + }, + ], + indirect=["polars_cursor"], + ) + def test_iter_chunks_unload(self, polars_cursor): + """Test chunked iteration with UNLOAD (Parquet).""" + polars_cursor.execute("SELECT * FROM many_rows LIMIT 15") + chunks = list(polars_cursor.iter_chunks()) + assert len(chunks) > 0 + total_rows = sum(chunk.height for chunk in chunks) + assert total_rows == 15 + for chunk in chunks: + assert isinstance(chunk, pl.DataFrame) + + def test_iter_chunks_data_consistency(self): + """Test that chunked and regular reading produce the same data.""" + with contextlib.closing(connect(schema_name=ENV.schema)) as conn: + # Regular reading (no chunksize) + regular_cursor = conn.cursor(PolarsCursor) + regular_cursor.execute("SELECT * FROM many_rows LIMIT 100") + regular_df = regular_cursor.as_polars() + + # Chunked reading + chunked_cursor = conn.cursor(PolarsCursor, chunksize=25) + chunked_cursor.execute("SELECT * FROM many_rows LIMIT 100") + chunked_dfs = list(chunked_cursor.iter_chunks()) + + # Combine chunks + combined_df = pl.concat(chunked_dfs) + + # Should have the same data (sort for comparison) + assert regular_df.sort("a").equals(combined_df.sort("a")) + + # Should have multiple chunks + assert len(chunked_dfs) > 1 + + def test_iter_chunks_chunk_sizes(self): + """Test that chunks have correct sizes.""" + with contextlib.closing(connect(schema_name=ENV.schema)) as conn: + cursor = conn.cursor(PolarsCursor, chunksize=10) + cursor.execute("SELECT * FROM many_rows LIMIT 50") + + chunk_sizes = [] + total_rows = 0 + + for chunk in cursor.iter_chunks(): + chunk_size = chunk.height + chunk_sizes.append(chunk_size) + total_rows += chunk_size + + # Each chunk should not exceed chunksize + assert chunk_size <= 10 + + # Should have processed all 50 rows + assert total_rows == 50 + + # Should have multiple chunks + assert len(chunk_sizes) > 1 + + def 
test_fetchone_with_chunksize(self): + """Test that fetchone works correctly with chunksize enabled.""" + with contextlib.closing(connect(schema_name=ENV.schema)) as conn: + cursor = conn.cursor(PolarsCursor, chunksize=5) + cursor.execute("SELECT * FROM many_rows LIMIT 15") + + rows = [] + while True: + row = cursor.fetchone() + if row is None: + break + rows.append(row) + + assert len(rows) == 15 + + def test_fetchmany_with_chunksize(self): + """Test that fetchmany works correctly with chunksize enabled.""" + with contextlib.closing(connect(schema_name=ENV.schema)) as conn: + cursor = conn.cursor(PolarsCursor, chunksize=5) + cursor.execute("SELECT * FROM many_rows LIMIT 15") + + batch1 = cursor.fetchmany(10) + batch2 = cursor.fetchmany(10) + + assert len(batch1) == 10 + assert len(batch2) == 5 + + def test_fetchall_with_chunksize(self): + """Test that fetchall works correctly with chunksize enabled.""" + with contextlib.closing(connect(schema_name=ENV.schema)) as conn: + cursor = conn.cursor(PolarsCursor, chunksize=5) + cursor.execute("SELECT * FROM many_rows LIMIT 15") + + rows = cursor.fetchall() + assert len(rows) == 15 + + def test_iterator_with_chunksize(self): + """Test that cursor iteration works correctly with chunksize enabled.""" + with contextlib.closing(connect(schema_name=ENV.schema)) as conn: + cursor = conn.cursor(PolarsCursor, chunksize=5) + cursor.execute("SELECT * FROM many_rows LIMIT 15") + + rows = list(cursor) + assert len(rows) == 15 + + @pytest.mark.parametrize( + "polars_cursor", + [ + { + "cursor_kwargs": {"unload": True, "chunksize": 5}, + }, + ], + indirect=["polars_cursor"], + ) + def test_fetchone_with_chunksize_unload(self, polars_cursor): + """Test that fetchone works correctly with chunksize and unload enabled.""" + polars_cursor.execute("SELECT * FROM many_rows LIMIT 15") + + rows = [] + while True: + row = polars_cursor.fetchone() + if row is None: + break + rows.append(row) + + assert len(rows) == 15 + + @pytest.mark.parametrize( 
+ "polars_cursor", + [ + { + "cursor_kwargs": {"unload": True, "chunksize": 5}, + }, + ], + indirect=["polars_cursor"], + ) + def test_iterator_with_chunksize_unload(self, polars_cursor): + """Test that cursor iteration works with chunksize and unload enabled.""" + polars_cursor.execute("SELECT * FROM many_rows LIMIT 15") + rows = list(polars_cursor) + assert len(rows) == 15