# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import csv
import os
import tempfile

from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase

from torchdata.nodes.csv_reader import CSVReader

from .utils import run_test_save_load_state


class TestCSVReader(TestCase):
    """Tests for CSVReader: reading, delimiters, and checkpoint/resume state."""

    def setUp(self):
        # 12 data rows. The header row (when requested) is written separately
        # by _create_temp_csv, so this fixture is never mutated by a test.
        self.test_data = [
            ["Alice", "30", "New York"],
            ["Bob", "25", "London"],
            ["Charlie", "35", "Paris"],
            ["David", "40", "Rome"],
            ["Eve", "45", "Tokyo"],
            ["Frank", "50", "Beijing"],
            ["Grace", "55", "Shanghai"],
            ["Harry", "60", "Seoul"],
            ["Iris", "65", "Buenos Aires"],
            ["Jack", "70", "Sao Paulo"],
            ["Katy", "75", "Mexico City"],
            ["Lily", "80", "Bogota"],
        ]
        # Paths created by this test, removed in tearDown. Tracking them
        # explicitly avoids globbing the shared temp dir, which could delete
        # unrelated files from other processes.
        self._tmp_paths = []

    def tearDown(self):
        # Clean up only the files this test created.
        for path in self._tmp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass  # Already gone (e.g. removed by the test itself).

    def _create_temp_csv(self, delimiter=",", header=True):
        """Write self.test_data (optionally preceded by a header row) to a
        fresh temporary CSV file and return its path."""
        fd, path = tempfile.mkstemp(suffix=".csv")
        self._tmp_paths.append(path)
        with os.fdopen(fd, "w", newline="") as f:
            writer = csv.writer(f, delimiter=delimiter)
            if header:
                writer.writerow(["name", "age", "city"])
            writer.writerows(self.test_data)
        return path

    def test_basic_read_list(self):
        path = self._create_temp_csv(header=False)
        node = CSVReader(path, has_header=False)
        results = list(node)
        self.assertEqual(len(results), len(self.test_data))
        self.assertEqual(results[0], ["Alice", "30", "New York"])
        self.assertEqual(results[-1], ["Lily", "80", "Bogota"])
        node.close()

    def test_basic_read_dict(self):
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True, return_dict=True)
        results = list(node)

        self.assertEqual(len(results), len(self.test_data))
        self.assertEqual(results[0], {"name": "Alice", "age": "30", "city": "New York"})
        self.assertEqual(results[1]["city"], "London")
        self.assertEqual(results[-1]["city"], "Bogota")
        node.close()

    def test_different_delimiters(self):
        path = self._create_temp_csv(delimiter="|")
        node = CSVReader(path, has_header=True, delimiter="|", return_dict=True)
        results = list(node)

        self.assertEqual(len(results), len(self.test_data))
        self.assertEqual(results[2]["city"], "Paris")
        self.assertEqual(results[-1]["city"], "Bogota")
        node.close()

    def test_state_management(self):
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True, return_dict=True)
        # Consume all but the last data row, then checkpoint.
        for _ in range(11):
            next(node)

        state = node.state_dict()

        # Resuming from the checkpoint should yield exactly the last row.
        node.reset(state)
        item = next(node)

        with self.assertRaises(StopIteration):
            next(node)

        self.assertEqual(item["name"], "Lily")
        self.assertEqual(state[CSVReader.NUM_LINES_YIELDED], 11)
        node.close()

    @parameterized.expand([3, 5, 7])
    def test_save_load_state(self, midpoint: int):
        path = self._create_temp_csv(header=True)
        node = CSVReader(path, has_header=True)
        run_test_save_load_state(self, node, midpoint)
        node.close()

    def test_load_wrong_state(self):
        # has_header=True reader must reject a state whose header is None.
        path = self._create_temp_csv(header=True)
        node = CSVReader(path, has_header=True)

        state = node.state_dict()
        state[CSVReader.HEADER_KEY] = None
        with self.assertRaisesRegex(
            ValueError, "Check if has_header=True matches the state header=None"
        ):
            node.reset(state)

        node.close()

        # has_header=False reader must reject a state that carries a header.
        node = CSVReader(path, has_header=False)
        state = node.state_dict()
        state[CSVReader.HEADER_KEY] = ["name", "age"]
        with self.assertRaisesRegex(
            ValueError,
            r"Check if has_header=False matches the state header=\['name', 'age'\]",
        ):
            node.reset(state)

        node.close()

    def test_empty_file(self):
        path = self._create_temp_csv()
        # Overwrite with an empty file.
        with open(path, "w") as _:
            pass

        node = CSVReader(path, has_header=False)
        with self.assertRaises(StopIteration):
            next(node)
        node.close()

    def test_header_validation(self):
        with self.assertRaisesRegex(
            ValueError, "return_dict=True requires has_header=True"
        ):
            CSVReader("dummy.csv", has_header=False, return_dict=True)

    def test_multi_epoch(self):
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True, return_dict=True)

        # First epoch
        epoch1 = list(node)
        self.assertEqual(len(epoch1), len(self.test_data))

        # Second epoch: reset() must replay identical rows.
        node.reset()
        epoch2 = list(node)
        self.assertEqual(epoch1, epoch2)
        node.close()

    def test_partial_read_resume(self):
        path = self._create_temp_csv(header=True)
        node = CSVReader(path, has_header=True)

        # Read partial and capture checkpoints after each row.
        _ = next(node)  # row 0 ("Alice")
        state1 = node.state_dict()

        _ = next(node)  # row 1 ("Bob")
        state2 = node.state_dict()

        # Resume from first state: next row is the second data row.
        node.reset(state1)
        self.assertEqual(next(node), self.test_data[1])

        # Resume from second state: next row is the third data row.
        node.reset(state2)
        self.assertEqual(next(node), self.test_data[2])
        node.close()

    def test_file_closure(self):
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True)

        # Exhausting the node should close the underlying file.
        list(node)
        self.assertTrue(node._file.closed)
        node.close()

    def test_state_with_header(self):
        path = self._create_temp_csv()
        node = CSVReader(path, has_header=True, return_dict=True)

        # Read one item and checkpoint.
        _ = next(node)
        state = node.state_dict()

        # The restored reader must keep using the original header as keys.
        node.reset(state)
        item = next(node)
        self.assertEqual(item["city"], "London")
        node.close()
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import csv
from itertools import islice
from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, Union

from torchdata.nodes.base_node import BaseNode


class CSVReader(BaseNode[Union[List[str], Dict[str, str]]]):
    """Node for reading CSV files with state management and header support.

    Rows are yielded as lists of strings (``csv.reader``) or, when
    ``return_dict=True``, as ``{column: value}`` dicts (``csv.DictReader``).

    Args:
        file_path: Path to the CSV file.
        has_header: Whether the first row contains column headers.
        delimiter: CSV field delimiter.
        return_dict: Return rows as dictionaries (requires ``has_header=True``).
        encoding: Text encoding used to open the file.
    """

    # Keys of the state dict produced by get_state() / consumed by reset().
    NUM_LINES_YIELDED = "num_lines_yielded"
    HEADER_KEY = "header"

    def __init__(
        self,
        file_path: str,
        has_header: bool = False,
        delimiter: str = ",",
        return_dict: bool = False,
        encoding: str = "utf-8",
    ):
        super().__init__()
        if return_dict and not has_header:
            raise ValueError("return_dict=True requires has_header=True")
        self.file_path = file_path
        self.has_header = has_header
        self.delimiter = delimiter
        self.return_dict = return_dict
        self.encoding = encoding
        self._file: Optional[TextIO] = None
        self._reader: Optional[Iterator[Union[List[str], Dict[str, str]]]] = None
        self._header: Optional[Sequence[str]] = None
        self._num_lines_yielded: int = 0
        self.reset()  # Open the file and create the initial reader.

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Restart from the beginning of the file, or from ``initial_state``."""
        super().reset()
        self.close()

        # Reopen the file and reset counters.
        self._file = open(self.file_path, encoding=self.encoding)
        self._num_lines_yielded = 0
        if initial_state is not None:
            self._handle_initial_state(initial_state)
        else:
            self._initialize_reader()

    def _handle_initial_state(self, state: Dict[str, Any]):
        """Restore reader position and header from a checkpoint."""
        saved_header = state.get(self.HEADER_KEY)
        # The stored header must be present iff this reader expects one.
        # NOTE: we compare against None rather than testing key membership;
        # get_state() always emits HEADER_KEY (as None for headerless files),
        # so a membership test would wrongly reject a has_header=False
        # reader resuming from its own checkpoint.
        if (saved_header is None) == self.has_header:
            raise ValueError(
                f"Check if has_header={self.has_header} matches the state header={saved_header}"
            )

        self._header = saved_header
        target_line_num = state[self.NUM_LINES_YIELDED]
        assert self._file is not None

        # Create the appropriate reader.
        if self.return_dict:
            # Supplying fieldnames= keeps DictReader from consuming the first
            # file row as the header; we skip that row explicitly below.
            self._reader = csv.DictReader(
                self._file, delimiter=self.delimiter, fieldnames=self._header
            )
        else:
            self._reader = csv.reader(self._file, delimiter=self.delimiter)

        assert isinstance(self._reader, Iterator)
        if self.has_header:
            try:
                next(self._reader)  # Skip the header line.
            except StopIteration:
                pass  # Empty file.
        # Fast-forward to the checkpointed position without materializing rows;
        # if the file is shorter than the checkpoint, count what was consumed.
        self._num_lines_yielded = sum(1 for _ in islice(self._reader, target_line_num))

    def _initialize_reader(self):
        """Create a fresh reader positioned at the first data row."""
        assert self._file is not None
        if self.return_dict:
            self._reader = csv.DictReader(self._file, delimiter=self.delimiter)
            # Accessing fieldnames consumes the header row from the file.
            self._header = self._reader.fieldnames
        else:
            self._reader = csv.reader(self._file, delimiter=self.delimiter)
            if self.has_header:
                try:
                    self._header = next(self._reader)  # Skip + record header.
                except StopIteration:
                    self._header = None  # Handle empty file.

    def next(self) -> Union[List[str], Dict[str, str]]:
        """Return the next row; raises StopIteration at end of input.

        The underlying file is closed on exhaustion, and further calls keep
        raising StopIteration (instead of "I/O operation on closed file").
        """
        if self._reader is None:
            raise StopIteration
        try:
            row = next(self._reader)
        except StopIteration:
            self.close()
            raise
        self._num_lines_yielded += 1
        return row

    def get_state(self) -> Dict[str, Any]:
        """Snapshot: number of data rows yielded so far, plus the header."""
        return {
            self.NUM_LINES_YIELDED: self._num_lines_yielded,
            self.HEADER_KEY: self._header,
        }

    def close(self):
        """Close the underlying file handle (idempotent)."""
        if self._file is not None and not self._file.closed:
            self._file.close()
        # Drop the reader so next() signals exhaustion cleanly.
        self._reader = None