diff --git a/pyathena/s3fs/reader.py b/pyathena/s3fs/reader.py index 31424a13..11f29b15 100644 --- a/pyathena/s3fs/reader.py +++ b/pyathena/s3fs/reader.py @@ -2,10 +2,11 @@ from __future__ import annotations import csv -from typing import Any, Iterator, List, Optional, Tuple +from collections.abc import Iterator +from typing import Any, List, Optional, Tuple -class DefaultCSVReader: +class DefaultCSVReader(Iterator[List[str]]): """CSV reader using Python's standard csv module. This reader wraps Python's standard csv.reader and treats empty fields @@ -33,9 +34,10 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None: file_obj: File-like object to read from. delimiter: Field delimiter character. """ + self._file: Optional[Any] = file_obj self._reader = csv.reader(file_obj, delimiter=delimiter) - def __iter__(self) -> Iterator[List[str]]: + def __iter__(self) -> "DefaultCSVReader": """Iterate over rows in the CSV file.""" return self @@ -46,8 +48,10 @@ def __next__(self) -> List[str]: List of field values as strings. Raises: - StopIteration: When end of file is reached. + StopIteration: When end of file is reached or reader is closed. """ + if self._file is None: + raise StopIteration row = next(self._reader) # Python's csv.reader returns [] for empty lines; normalize to [''] # to represent a single empty field (consistent with single-value handling) @@ -55,8 +59,22 @@ def __next__(self) -> List[str]: return [""] return row + def close(self) -> None: + """Close the underlying file object.""" + if self._file is not None: + self._file.close() + self._file = None + + def __enter__(self) -> "DefaultCSVReader": + """Enter context manager.""" + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit context manager and close resources.""" + self.close() + -class AthenaCSVReader: +class AthenaCSVReader(Iterator[List[Optional[str]]]): """CSV reader that distinguishes between NULL and empty string. This is the default reader for S3FSCursor. @@ -87,10 +105,10 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None: file_obj: File-like object to read from. delimiter: Field delimiter character. """ - self._file = file_obj + self._file: Optional[Any] = file_obj self._delimiter = delimiter - def __iter__(self) -> Iterator[List[Optional[str]]]: + def __iter__(self) -> "AthenaCSVReader": """Iterate over rows in the CSV file.""" return self @@ -101,8 +119,10 @@ def __next__(self) -> List[Optional[str]]: List of field values, with None for NULL and '' for empty string. Raises: - StopIteration: When end of file is reached. + StopIteration: When end of file is reached or reader is closed. """ + if self._file is None: + raise StopIteration line = self._file.readline() if not line: raise StopIteration @@ -234,3 +254,17 @@ def _parse_unquoted_field(self, line: str, pos: int) -> Tuple[str, int]: pos += 1 return value, pos + + def close(self) -> None: + """Close the underlying file object.""" + if self._file is not None: + self._file.close() + self._file = None + + def __enter__(self) -> "AthenaCSVReader": + """Enter context manager.""" + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Exit context manager and close resources.""" + self.close() diff --git a/pyathena/s3fs/result_set.py b/pyathena/s3fs/result_set.py index e2651f88..6a1cf683 100644 --- a/pyathena/s3fs/result_set.py +++ b/pyathena/s3fs/result_set.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import contextlib import logging from io import TextIOWrapper from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union @@ -80,7 +79,6 @@ def __init__( self._csv_reader_class: CSVReaderType = csv_reader or AthenaCSVReader self._fs = self._create_s3_file_system() self._csv_reader: Optional[Any] = None - self._csv_file: Optional[Any] = None if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location: self._init_csv_reader() @@ -122,8 +120,8 @@ def _init_csv_reader(self) -> None: path = f"{bucket}/{key}" try: - self._csv_file = self._fs._open(path, mode="rb") - text_wrapper = TextIOWrapper(self._csv_file, encoding="utf-8") + csv_file = self._fs._open(path, mode="rb") + text_wrapper = TextIOWrapper(csv_file, encoding="utf-8") if self.output_location.endswith(".txt"): # Tab-separated format (no header row) @@ -131,8 +129,7 @@ def _init_csv_reader(self) -> None: else: # Standard CSV format (has header row, skip it) self._csv_reader = self._csv_reader_class(text_wrapper, delimiter=",") - with contextlib.suppress(StopIteration): - next(self._csv_reader) + next(self._csv_reader) except Exception as e: _logger.exception(f"Failed to open {path}.") @@ -228,8 +225,6 @@ def fetchall( def close(self) -> None: """Close the result set and release resources.""" super().close() - if self._csv_file: - with contextlib.suppress(Exception): - self._csv_file.close() - self._csv_file = None - self._csv_reader = None + if self._csv_reader: + self._csv_reader.close() + self._csv_reader = None diff --git a/tests/pyathena/s3fs/test_reader.py b/tests/pyathena/s3fs/test_reader.py index 5bd7562d..58383525 100644 --- a/tests/pyathena/s3fs/test_reader.py +++ b/tests/pyathena/s3fs/test_reader.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from collections.abc import Iterator from io import StringIO from pyathena.s3fs.reader import AthenaCSVReader, DefaultCSVReader @@ -60,6 +61,30 @@ def test_empty_line(self): rows = list(reader) assert rows == [[""]] + def test_implements_iterator_protocol(self): + """DefaultCSVReader implements collections.abc.Iterator.""" + data = StringIO("a,b\n") + reader = DefaultCSVReader(data, delimiter=",") + assert isinstance(reader, Iterator) + + def test_close(self): + """close() releases file resources.""" + data = StringIO("a,b\n") + reader = DefaultCSVReader(data, delimiter=",") + reader.close() + # After close, iteration should stop immediately + rows = list(reader) + assert rows == [] + + def test_context_manager(self): + """Reader can be used as context manager.""" + data = StringIO("a,b\n1,2\n") + with DefaultCSVReader(data, delimiter=",") as reader: + rows = list(reader) + assert rows == [["a", "b"], ["1", "2"]] + # After exiting context, reader should be closed + assert list(reader) == [] + class TestAthenaCSVReader: """Tests for AthenaCSVReader that distinguishes NULL from empty string.""" @@ -209,3 +234,27 @@ def test_multiple_rows_with_multiline_field(self): reader = AthenaCSVReader(data, delimiter=",") rows = list(reader) assert rows == [["a", "b"], ["multi\nline", "c"], ["d", "e"]] + + def test_implements_iterator_protocol(self): + """AthenaCSVReader implements collections.abc.Iterator.""" + data = StringIO("a,b\n") + reader = AthenaCSVReader(data, delimiter=",") + assert isinstance(reader, Iterator) + + def test_close(self): + """close() releases file resources.""" + data = StringIO("a,b\n") + reader = AthenaCSVReader(data, delimiter=",") + reader.close() + # After close, iteration should stop immediately + rows = list(reader) + assert rows == [] + + def test_context_manager(self): + """Reader can be used as context manager.""" + data = StringIO("a,b\n1,2\n") + with AthenaCSVReader(data, delimiter=",") as reader: + rows = list(reader) + assert rows == [["a", "b"], ["1", "2"]] + # After exiting context, reader should be closed + assert list(reader) == []