Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions pyathena/s3fs/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from __future__ import annotations

import csv
from typing import Any, Iterator, List, Optional, Tuple
from collections.abc import Iterator
from typing import Any, List, Optional, Tuple


class DefaultCSVReader:
class DefaultCSVReader(Iterator[List[str]]):
"""CSV reader using Python's standard csv module.

This reader wraps Python's standard csv.reader and treats empty fields
Expand Down Expand Up @@ -33,9 +34,10 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None:
file_obj: File-like object to read from.
delimiter: Field delimiter character.
"""
self._file: Optional[Any] = file_obj
self._reader = csv.reader(file_obj, delimiter=delimiter)

def __iter__(self) -> Iterator[List[str]]:
def __iter__(self) -> "DefaultCSVReader":
"""Iterate over rows in the CSV file."""
return self

Expand All @@ -46,17 +48,33 @@ def __next__(self) -> List[str]:
List of field values as strings.

Raises:
StopIteration: When end of file is reached.
StopIteration: When end of file is reached or reader is closed.
"""
if self._file is None:
raise StopIteration
row = next(self._reader)
# Python's csv.reader returns [] for empty lines; normalize to ['']
# to represent a single empty field (consistent with single-value handling)
if not row:
return [""]
return row

def close(self) -> None:
"""Close the underlying file object."""
if self._file is not None:
self._file.close()
self._file = None

def __enter__(self) -> "DefaultCSVReader":
"""Enter context manager."""
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Exit context manager and close resources."""
self.close()


class AthenaCSVReader:
class AthenaCSVReader(Iterator[List[Optional[str]]]):
"""CSV reader that distinguishes between NULL and empty string.

This is the default reader for S3FSCursor.
Expand Down Expand Up @@ -87,10 +105,10 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None:
file_obj: File-like object to read from.
delimiter: Field delimiter character.
"""
self._file = file_obj
self._file: Optional[Any] = file_obj
self._delimiter = delimiter

def __iter__(self) -> Iterator[List[Optional[str]]]:
def __iter__(self) -> "AthenaCSVReader":
"""Iterate over rows in the CSV file."""
return self

Expand All @@ -101,8 +119,10 @@ def __next__(self) -> List[Optional[str]]:
List of field values, with None for NULL and '' for empty string.

Raises:
StopIteration: When end of file is reached.
StopIteration: When end of file is reached or reader is closed.
"""
if self._file is None:
raise StopIteration
line = self._file.readline()
if not line:
raise StopIteration
Expand Down Expand Up @@ -234,3 +254,17 @@ def _parse_unquoted_field(self, line: str, pos: int) -> Tuple[str, int]:
pos += 1

return value, pos

def close(self) -> None:
"""Close the underlying file object."""
if self._file is not None:
self._file.close()
self._file = None

def __enter__(self) -> "AthenaCSVReader":
"""Enter context manager."""
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Exit context manager and close resources."""
self.close()
17 changes: 6 additions & 11 deletions pyathena/s3fs/result_set.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import contextlib
import logging
from io import TextIOWrapper
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
Expand Down Expand Up @@ -80,7 +79,6 @@ def __init__(
self._csv_reader_class: CSVReaderType = csv_reader or AthenaCSVReader
self._fs = self._create_s3_file_system()
self._csv_reader: Optional[Any] = None
self._csv_file: Optional[Any] = None

if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location:
self._init_csv_reader()
Expand Down Expand Up @@ -122,17 +120,16 @@ def _init_csv_reader(self) -> None:
path = f"{bucket}/{key}"

try:
self._csv_file = self._fs._open(path, mode="rb")
text_wrapper = TextIOWrapper(self._csv_file, encoding="utf-8")
csv_file = self._fs._open(path, mode="rb")
text_wrapper = TextIOWrapper(csv_file, encoding="utf-8")

if self.output_location.endswith(".txt"):
# Tab-separated format (no header row)
self._csv_reader = self._csv_reader_class(text_wrapper, delimiter="\t")
else:
# Standard CSV format (has header row, skip it)
self._csv_reader = self._csv_reader_class(text_wrapper, delimiter=",")
with contextlib.suppress(StopIteration):
next(self._csv_reader)
next(self._csv_reader)

except Exception as e:
_logger.exception(f"Failed to open {path}.")
Expand Down Expand Up @@ -228,8 +225,6 @@ def fetchall(
def close(self) -> None:
"""Close the result set and release resources."""
super().close()
if self._csv_file:
with contextlib.suppress(Exception):
self._csv_file.close()
self._csv_file = None
self._csv_reader = None
if self._csv_reader:
self._csv_reader.close()
self._csv_reader = None
49 changes: 49 additions & 0 deletions tests/pyathena/s3fs/test_reader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from collections.abc import Iterator
from io import StringIO

from pyathena.s3fs.reader import AthenaCSVReader, DefaultCSVReader
Expand Down Expand Up @@ -60,6 +61,30 @@ def test_empty_line(self):
rows = list(reader)
assert rows == [[""]]

def test_implements_iterator_protocol(self):
"""DefaultCSVReader implements collections.abc.Iterator."""
data = StringIO("a,b\n")
reader = DefaultCSVReader(data, delimiter=",")
assert isinstance(reader, Iterator)

def test_close(self):
"""close() releases file resources."""
data = StringIO("a,b\n")
reader = DefaultCSVReader(data, delimiter=",")
reader.close()
# After close, iteration should stop immediately
rows = list(reader)
assert rows == []

def test_context_manager(self):
"""Reader can be used as context manager."""
data = StringIO("a,b\n1,2\n")
with DefaultCSVReader(data, delimiter=",") as reader:
rows = list(reader)
assert rows == [["a", "b"], ["1", "2"]]
# After exiting context, reader should be closed
assert list(reader) == []


class TestAthenaCSVReader:
"""Tests for AthenaCSVReader that distinguishes NULL from empty string."""
Expand Down Expand Up @@ -209,3 +234,27 @@ def test_multiple_rows_with_multiline_field(self):
reader = AthenaCSVReader(data, delimiter=",")
rows = list(reader)
assert rows == [["a", "b"], ["multi\nline", "c"], ["d", "e"]]

def test_implements_iterator_protocol(self):
"""AthenaCSVReader implements collections.abc.Iterator."""
data = StringIO("a,b\n")
reader = AthenaCSVReader(data, delimiter=",")
assert isinstance(reader, Iterator)

def test_close(self):
"""close() releases file resources."""
data = StringIO("a,b\n")
reader = AthenaCSVReader(data, delimiter=",")
reader.close()
# After close, iteration should stop immediately
rows = list(reader)
assert rows == []

def test_context_manager(self):
"""Reader can be used as context manager."""
data = StringIO("a,b\n1,2\n")
with AthenaCSVReader(data, delimiter=",") as reader:
rows = list(reader)
assert rows == [["a", "b"], ["1", "2"]]
# After exiting context, reader should be closed
assert list(reader) == []