From 4399c9d29beb5c6740e0df6378241321ace1e11d Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Mon, 14 Apr 2025 05:58:35 -0700
Subject: [PATCH 1/9] add csv_dataloader

---
 test/nodes/test_csv_dataloader.py | 189 ++++++++++++++++++++++++++++++
 torchdata/nodes/csv_dataloader.py |  91 ++++++++++++++
 2 files changed, 280 insertions(+)
 create mode 100644 test/nodes/test_csv_dataloader.py
 create mode 100644 torchdata/nodes/csv_dataloader.py

diff --git a/test/nodes/test_csv_dataloader.py b/test/nodes/test_csv_dataloader.py
new file mode 100644
index 000000000..f2156d492
--- /dev/null
+++ b/test/nodes/test_csv_dataloader.py
@@ -0,0 +1,189 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import csv
+import os
+import tempfile
+from typing import Any, Dict, List, Union
+
+from parameterized import parameterized
+from torch.testing._internal.common_utils import TestCase
+from torchdata.nodes.base_node import BaseNode
+
+from torchdata.nodes.csv_dataloader import CSVReader
+
+from .utils import run_test_save_load_state
+
+
+class TestCSVReader(TestCase):
+    def setUp(self):
+        pass
+
+    def _create_temp_csv(self, delimiter=",", header=True):
+        if header:
+            self.test_data = [
+                ["name", "age", "city"],
+                ["Alice", "30", "New York"],
+                ["Bob", "25", "London"],
+                ["Charlie", "35", "Paris"],
+                ["David", "40", "Rome"],
+                ["Eve", "45", "Tokyo"],
+                ["Frank", "50", "Beijing"],
+                ["Grace", "55", "Shanghai"],
+                ["Harry", "60", "Seoul"],
+                ["Iris", "65", "Buenos Aires"],
+                ["Jack", "70", "São Paulo"],
+                ["Katy", "75", "Mexico City"],
+                ["Lily", "80", "Bogotá"],
+            ]
+        else:
+            self.test_data = [
+                ["Alice", "30", "New York"],
+                ["Bob", "25", "London"],
+                ["Charlie", "35", "Paris"],
+                ["David", "40", "Rome"],
+                ["Eve", "45", "Tokyo"],
+                ["Frank", "50", "Beijing"],
+                ["Grace", "55", "Shanghai"],
+                ["Harry", "60", "Seoul"],
+                ["Iris", "65", "Buenos Aires"],
+                ["Jack", "70", "Sao Paulo"],
+                ["Katy", "75", "Mexico City"],
+                ["Lily", "80", "Bogota"],
+            ]
+        fd, path = tempfile.mkstemp(suffix=".csv")
+        with os.fdopen(fd, "w", newline="") as f:
+            writer = csv.writer(f, delimiter=delimiter)
+            writer.writerows(self.test_data)
+        return path
+
+    def test_basic_read_list(self):
+        path = self._create_temp_csv(header=False)
+        node = CSVReader(path, has_header=False)
+        results = list(node)
+        print(results)
+        self.assertEqual(len(results), len(self.test_data))
+        self.assertEqual(results[0], ["Alice", "30", "New York"])
+        self.assertEqual(results[-1], ["Lily", "80", "Bogota"])
+
+    def test_basic_read_dict(self):
+        path = self._create_temp_csv()
+        node = CSVReader(path, has_header=True, return_dict=True)
+        results = list(node)
+
+        self.assertEqual(len(results), len(self.test_data) - 1)
+        self.assertEqual(results[0], {"name": "Alice", "age": "30", "city": "New York"})
+        self.assertEqual(results[1]["city"], "London")
+
+    def test_different_delimiters(self):
+        path = self._create_temp_csv(delimiter="|")
+        node = CSVReader(path, has_header=True, delimiter="|", return_dict=True)
+        results = list(node)
+
+        self.assertEqual(len(results), len(self.test_data) - 1)
+        self.assertEqual(results[2]["city"], "Paris")
+
+    def test_state_management(self):
+        path = self._create_temp_csv()
+        node = CSVReader(path, has_header=True, return_dict=True)
+
+        for _ in range(11):
+            next(node)
+
+        state = node.state_dict()
+
+        node.reset(state)
+
+        item = next(node)
+        with self.assertRaises(StopIteration):
+            next(node)
+
+        self.assertEqual(item["name"], "Lily")
+        self.assertEqual(state[CSVReader.LINE_NUM_KEY], 11)
+
+    @parameterized.expand([3, 5, 7])
+    def test_save_load_state(self, midpoint: int):
+        path = self._create_temp_csv(header=True)
+        node = CSVReader(path, has_header=True)
+        run_test_save_load_state(self, node, midpoint)
+
+    def test_empty_file(self):
+        path = self._create_temp_csv()
+        # Overwrite with empty file
+        with open(path, "w") as f:
+            pass
+
+        node = CSVReader(path, has_header=False)
+        with self.assertRaises(StopIteration):
+            next(node)
+
+    def test_header_validation(self):
+        with self.assertRaisesRegex(
+            ValueError, "return_dict=True requires has_header=True"
+        ):
+            CSVReader("dummy.csv", has_header=False, return_dict=True)
+
+    def test_multi_epoch(self):
+        path = self._create_temp_csv()
+        node = CSVReader(path, has_header=True, return_dict=True)
+
+        # First epoch
+        epoch1 = list(node)
+        self.assertEqual(len(epoch1), len(self.test_data) - 1)
+
+        # Second epoch
+        node.reset()
+        epoch2 = list(node)
+        self.assertEqual(epoch1, epoch2)
+
+    def test_partial_read_resume(self):
+        path = self._create_temp_csv(header=True)
+        node = CSVReader(path, has_header=True)
+
+        # Read partial and get state
+        _ = next(node)  # Line 0
+        state1 = node.state_dict()
+        print(_, self.test_data[1])
+
+        _ = next(node)  # Line 1
+        state2 = node.state_dict()
+
+        # Resume from first state
+        node.reset(state1)
+        self.assertEqual(next(node), self.test_data[2])
+
+        # Resume from second state
+        node.reset(state2)
+        self.assertEqual(next(node), self.test_data[3])
+
+    def test_file_closure(self):
+        path = self._create_temp_csv()
+        node = CSVReader(path, has_header=True)
+
+        # Read all items
+        list(node)
+
+        # Verify file is closed
+        self.assertTrue(node._file.closed)
+
+    def test_state_with_header(self):
+        path = self._create_temp_csv()
+        node = CSVReader(path, has_header=True, return_dict=True)
+
+        # Read one item
+        _ = next(node)
+        state = node.state_dict()
+
+        # Verify header preservation
+        node.reset(state)
+        item = next(node)
+        self.assertEqual(item["city"], "London")
+
+    def tearDown(self):
+        # Clean up temporary files
+        for f in os.listdir(tempfile.gettempdir()):
+            if f.startswith("tmp") and f.endswith(".csv"):
+                os.remove(os.path.join(tempfile.gettempdir(), f))
diff --git a/torchdata/nodes/csv_dataloader.py b/torchdata/nodes/csv_dataloader.py
new file mode 100644
index 000000000..c8d40e5d7
--- /dev/null
+++ b/torchdata/nodes/csv_dataloader.py
@@ -0,0 +1,91 @@
+import csv
+from typing import Any, Dict, List, Optional, Union
+
+from torchdata.nodes.base_node import BaseNode, T
+
+
+class CSVReader(BaseNode[T]):
+    """Node for reading CSV files with state management and header support.
+    Args:
+        file_path: Path to CSV file
+        has_header: Whether first row contains column headers
+        delimiter: CSV field delimiter
+        return_dict: Return rows as dictionaries (requires has_header=True)
+    """
+
+    LINE_NUM_KEY = "line_num"
+    HEADER_KEY = "header"
+
+    def __init__(
+        self,
+        file_path: str,
+        has_header: bool = False,
+        delimiter: str = ",",
+        return_dict: bool = False,
+    ):
+        super().__init__()
+        self.file_path = file_path
+        self.has_header = has_header
+        self.delimiter = delimiter
+        self.return_dict = return_dict
+        if return_dict and not has_header:
+            raise ValueError("return_dict=True requires has_header=True")
+        self._file = None
+        self._reader = None
+        self._header = None
+        self._line_num = 0
+        self.reset()  # Initialize reader
+
+    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
+        super().reset(initial_state)
+
+        if self._file and not self._file.closed:
+            self._file.close()
+
+        self._file = open(self.file_path, "r", newline="", encoding="utf-8")
+        self._line_num = 0
+
+        if initial_state:
+            self._header = initial_state.get(self.HEADER_KEY)
+            target_line_num = initial_state[self.LINE_NUM_KEY]
+
+            if self.return_dict:
+                self._reader = csv.DictReader(
+                    self._file, delimiter=self.delimiter, fieldnames=self._header
+                )
+            else:
+                self._reader = csv.reader(self._file, delimiter=self.delimiter)
+
+            if self.has_header:
+                next(self._reader)  # Skip header
+            for _ in range(target_line_num - self._line_num):
+                try:
+                    next(self._reader)
+                    self._line_num += 1
+                except StopIteration:
+                    break
+        else:
+
+            if self.return_dict:
+                self._reader = csv.DictReader(self._file, delimiter=self.delimiter)
+                self._header = self._reader.fieldnames
+            else:
+                self._reader = csv.reader(self._file, delimiter=self.delimiter)
+                if self.has_header:
+                    self._header = next(self._reader)
+
+    def next(self) -> Union[Dict[str, str], List[str]]:
+        try:
+            row = next(self._reader)
+            self._line_num += 1
+            return row
+        except StopIteration:
+            self.close()
+            raise
+
+    def get_state(self) -> Dict[str, Any]:
+        return {self.LINE_NUM_KEY: self._line_num, self.HEADER_KEY: self._header}
+
+    def close(self):
+        if self._file and not self._file.closed:
+            self._file.close()

From 129ef80f357be799b4a6c96245f8200aa2c966af Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Mon, 14 Apr 2025 06:09:38 -0700
Subject: [PATCH 2/9] update linting

---
 test/nodes/test_csv_dataloader.py |  6 ++----
 torchdata/nodes/csv_dataloader.py | 10 ++++------
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/test/nodes/test_csv_dataloader.py b/test/nodes/test_csv_dataloader.py
index f2156d492..fd909da8e 100644
--- a/test/nodes/test_csv_dataloader.py
+++ b/test/nodes/test_csv_dataloader.py
@@ -113,7 +113,7 @@ def test_save_load_state(self, midpoint: int):
     def test_empty_file(self):
         path = self._create_temp_csv()
         # Overwrite with empty file
-        with open(path, "w") as f:
+        with open(path, "w") as _:
             pass
 
         node = CSVReader(path, has_header=False)
@@ -121,9 +121,7 @@
             next(node)
 
     def test_header_validation(self):
-        with self.assertRaisesRegex(
-            ValueError, "return_dict=True requires has_header=True"
-        ):
+        with self.assertRaisesRegex(ValueError, "return_dict=True requires has_header=True"):
             CSVReader("dummy.csv", has_header=False, return_dict=True)
 
     def test_multi_epoch(self):
diff --git a/torchdata/nodes/csv_dataloader.py b/torchdata/nodes/csv_dataloader.py
index c8d40e5d7..36ee94cb2 100644
--- a/torchdata/nodes/csv_dataloader.py
+++ b/torchdata/nodes/csv_dataloader.py
@@ -1,5 +1,5 @@
 import csv
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional
 
 from torchdata.nodes.base_node import BaseNode, T
 
@@ -42,7 +42,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
         if self._file and not self._file.closed:
             self._file.close()
 
-        self._file = open(self.file_path, "r", newline="", encoding="utf-8")
+        self._file = open(self.file_path, newline="", encoding="utf-8")
         self._line_num = 0
 
         if initial_state:
@@ -50,9 +50,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
             target_line_num = initial_state[self.LINE_NUM_KEY]
 
             if self.return_dict:
-                self._reader = csv.DictReader(
-                    self._file, delimiter=self.delimiter, fieldnames=self._header
-                )
+                self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
             else:
                 self._reader = csv.reader(self._file, delimiter=self.delimiter)
 
@@ -74,7 +74,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
         if self.has_header:
             self._header = next(self._reader)
 
-    def next(self) -> Union[Dict[str, str], List[str]]:
+    def next(self) -> T:
         try:
             row = next(self._reader)
             self._line_num += 1

From c3f4cbdd35551b9afb0ff1e74d87e28f8b192710 Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Mon, 14 Apr 2025 06:36:23 -0700
Subject: [PATCH 3/9] typechecking

---
 torchdata/nodes/csv_dataloader.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/torchdata/nodes/csv_dataloader.py b/torchdata/nodes/csv_dataloader.py
index 36ee94cb2..10199bbf7 100644
--- a/torchdata/nodes/csv_dataloader.py
+++ b/torchdata/nodes/csv_dataloader.py
@@ -1,7 +1,9 @@
 import csv
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Iterator, List, Optional, TextIO, TypeVar, Union
 
-from torchdata.nodes.base_node import BaseNode, T
+from torchdata.nodes.base_node import BaseNode
+
+T = TypeVar("T", bound=Union[List[str], Dict[str, str]])
 
 
 class CSVReader(BaseNode[T]):
@@ -30,8 +32,8 @@ def __init__(
         self.return_dict = return_dict
         if return_dict and not has_header:
             raise ValueError("return_dict=True requires has_header=True")
-        self._file = None
-        self._reader = None
+        self._file: Optional[TextIO] = None
+        self._reader: Optional[Iterator[Union[List[str], Dict[str, str]]]] = None
         self._header = None
         self._line_num = 0
         self.reset()  # Initialize reader
@@ -50,6 +52,8 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
             target_line_num = initial_state[self.LINE_NUM_KEY]
 
             if self.return_dict:
+                if self._header is None:
+                    raise ValueError("return_dict=True requires has_header=True")
                 self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
             else:
                 self._reader = csv.reader(self._file, delimiter=self.delimiter)
@@ -74,6 +78,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
 
     def next(self) -> T:
         try:
+            assert isinstance(self._reader, Iterator)
             row = next(self._reader)
             self._line_num += 1
             return row

From 74fed7efb6fc35358ddf8e6f7ee58c01f1eaa001 Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Mon, 14 Apr 2025 08:36:33 -0700
Subject: [PATCH 4/9] fix mypy

---
 test/nodes/test_csv_dataloader.py             | 62 +++++++++----------
 .../{csv_dataloader.py => csv_reader.py}      | 20 +++---
 2 files changed, 39 insertions(+), 43 deletions(-)
 rename torchdata/nodes/{csv_dataloader.py => csv_reader.py} (83%)

diff --git a/test/nodes/test_csv_dataloader.py b/test/nodes/test_csv_dataloader.py
index fd909da8e..370f027be 100644
--- a/test/nodes/test_csv_dataloader.py
+++ b/test/nodes/test_csv_dataloader.py
@@ -13,47 +13,31 @@
 from torch.testing._internal.common_utils import TestCase
 from torchdata.nodes.base_node import BaseNode
 
-from torchdata.nodes.csv_dataloader import CSVReader
+from torchdata.nodes.csv_reader import CSVReader
 
 from .utils import run_test_save_load_state
 
 
 class TestCSVReader(TestCase):
     def setUp(self):
-        pass
+        self.test_data = [
+            ["Alice", "30", "New York"],
+            ["Bob", "25", "London"],
+            ["Charlie", "35", "Paris"],
+            ["David", "40", "Rome"],
+            ["Eve", "45", "Tokyo"],
+            ["Frank", "50", "Beijing"],
+            ["Grace", "55", "Shanghai"],
+            ["Harry", "60", "Seoul"],
+            ["Iris", "65", "Buenos Aires"],
+            ["Jack", "70", "Sao Paulo"],
+            ["Katy", "75", "Mexico City"],
+            ["Lily", "80", "Bogota"],
+        ]
 
     def _create_temp_csv(self, delimiter=",", header=True):
         if header:
+            self.test_data.insert(0, ["name", "age", "city"])
         fd, path = tempfile.mkstemp(suffix=".csv")
         with os.fdopen(fd, "w", newline="") as f:
             writer = csv.writer(f, delimiter=delimiter)
@@ -64,7 +48,6 @@ def test_basic_read_list(self):
         path = self._create_temp_csv(header=False)
         node = CSVReader(path, has_header=False)
         results = list(node)
-        print(results)
         self.assertEqual(len(results), len(self.test_data))
         self.assertEqual(results[0], ["Alice", "30", "New York"])
         self.assertEqual(results[-1], ["Lily", "80", "Bogota"])
@@ -77,6 +60,7 @@ def test_basic_read_dict(self):
         self.assertEqual(len(results), len(self.test_data) - 1)
         self.assertEqual(results[0], {"name": "Alice", "age": "30", "city": "New York"})
         self.assertEqual(results[1]["city"], "London")
+        node.close()
 
     def test_different_delimiters(self):
         path = self._create_temp_csv(delimiter="|")
@@ -85,6 +69,7 @@ def test_different_delimiters(self):
 
         self.assertEqual(len(results), len(self.test_data) - 1)
         self.assertEqual(results[2]["city"], "Paris")
+        node.close()
 
     def test_state_management(self):
         path = self._create_temp_csv()
@@ -103,12 +88,14 @@ def test_state_management(self):
 
         self.assertEqual(item["name"], "Lily")
         self.assertEqual(state[CSVReader.LINE_NUM_KEY], 11)
+        node.close()
 
     @parameterized.expand([3, 5, 7])
     def test_save_load_state(self, midpoint: int):
         path = self._create_temp_csv(header=True)
         node = CSVReader(path, has_header=True)
         run_test_save_load_state(self, node, midpoint)
+        node.close()
 
     def test_empty_file(self):
         path = self._create_temp_csv()
@@ -119,9 +106,12 @@ def test_empty_file(self):
         node = CSVReader(path, has_header=False)
         with self.assertRaises(StopIteration):
             next(node)
+        node.close()
 
     def test_header_validation(self):
-        with self.assertRaisesRegex(ValueError, "return_dict=True requires has_header=True"):
+        with self.assertRaisesRegex(
+            ValueError, "return_dict=True requires has_header=True"
+        ):
             CSVReader("dummy.csv", has_header=False, return_dict=True)
 
     def test_multi_epoch(self):
@@ -136,6 +126,7 @@ def test_multi_epoch(self):
         node.reset()
         epoch2 = list(node)
         self.assertEqual(epoch1, epoch2)
+        node.close()
 
     def test_partial_read_resume(self):
         path = self._create_temp_csv(header=True)
@@ -156,6 +147,7 @@ def test_partial_read_resume(self):
         # Resume from second state
         node.reset(state2)
         self.assertEqual(next(node), self.test_data[3])
+        node.close()
 
     def test_file_closure(self):
         path = self._create_temp_csv()
@@ -166,6 +158,7 @@ def test_file_closure(self):
 
         # Verify file is closed
         self.assertTrue(node._file.closed)
+        node.close()
 
     def test_state_with_header(self):
         path = self._create_temp_csv()
@@ -179,6 +172,7 @@ def test_state_with_header(self):
         node.reset(state)
         item = next(node)
         self.assertEqual(item["city"], "London")
+        node.close()
 
     def tearDown(self):
         # Clean up temporary files
diff --git a/torchdata/nodes/csv_dataloader.py b/torchdata/nodes/csv_reader.py
similarity index 83%
rename from torchdata/nodes/csv_dataloader.py
rename to torchdata/nodes/csv_reader.py
index 10199bbf7..3e44dbe44 100644
--- a/torchdata/nodes/csv_dataloader.py
+++ b/torchdata/nodes/csv_reader.py
@@ -1,12 +1,10 @@
 import csv
-from typing import Any, Dict, Iterator, List, Optional, TextIO, TypeVar, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, Union
 
-from torchdata.nodes.base_node import BaseNode
+from torchdata.nodes.base_node import BaseNode, T
 
 
-class CSVReader(BaseNode[T]):
+class CSVReader(BaseNode[Union[List[str], Dict[str, str]]]):
     """Node for reading CSV files with state management and header support.
     Args:
         file_path: Path to CSV file
         has_header: Whether first row contains column headers
         delimiter: CSV field delimiter
         return_dict: Return rows as dictionaries (requires has_header=True)
@@ -34,8 +32,8 @@ def __init__(
         raise ValueError("return_dict=True requires has_header=True")
         self._file: Optional[TextIO] = None
         self._reader: Optional[Iterator[Union[List[str], Dict[str, str]]]] = None
-        self._header = None
-        self._line_num = 0
+        self._header: Optional[Sequence[str]] = None
+        self._line_num: int = 0
         self.reset()  # Initialize reader
 
     def reset(self, initial_state: Optional[Dict[str, Any]] = None):
@@ -54,10 +52,13 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
             if self.return_dict:
                 if self._header is None:
                     raise ValueError("return_dict=True requires has_header=True")
-                self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
+                self._reader = csv.DictReader(
+                    self._file, delimiter=self.delimiter, fieldnames=self._header
+                )
             else:
                 self._reader = csv.reader(self._file, delimiter=self.delimiter)
 
+            assert isinstance(self._reader, Iterator)
             if self.has_header:
                 next(self._reader)  # Skip header
             for _ in range(target_line_num - self._line_num):
@@ -76,12 +77,13 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
                 if self.has_header:
                     self._header = next(self._reader)
 
-    def next(self) -> T:
+    def next(self) -> Union[List[str], Dict[str, str]]:
         try:
             assert isinstance(self._reader, Iterator)
             row = next(self._reader)
             self._line_num += 1
             return row
+
         except StopIteration:
             self.close()
             raise

From 44f06ca04ed9b82603b0c8a09e88cdf7bc7aceff Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Mon, 14 Apr 2025 08:48:34 -0700
Subject: [PATCH 5/9] clean up

---
 test/nodes/test_csv_dataloader.py | 2 +-
 torchdata/nodes/csv_reader.py     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/nodes/test_csv_dataloader.py b/test/nodes/test_csv_dataloader.py
index 370f027be..382ab1a31 100644
--- a/test/nodes/test_csv_dataloader.py
+++ b/test/nodes/test_csv_dataloader.py
@@ -51,6 +51,7 @@ def test_basic_read_list(self):
         self.assertEqual(len(results), len(self.test_data))
         self.assertEqual(results[0], ["Alice", "30", "New York"])
         self.assertEqual(results[-1], ["Lily", "80", "Bogota"])
+        node.close()
 
     def test_basic_read_dict(self):
         path = self._create_temp_csv()
@@ -135,7 +136,6 @@ def test_partial_read_resume(self):
         # Read partial and get state
         _ = next(node)  # Line 0
         state1 = node.state_dict()
-        print(_, self.test_data[1])
 
         _ = next(node)  # Line 1
         state2 = node.state_dict()
diff --git a/torchdata/nodes/csv_reader.py b/torchdata/nodes/csv_reader.py
index 3e44dbe44..d419af92b 100644
--- a/torchdata/nodes/csv_reader.py
+++ b/torchdata/nodes/csv_reader.py
@@ -1,7 +1,7 @@
 import csv
-from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, TypeVar, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, Union
 
-from torchdata.nodes.base_node import BaseNode, T
+from torchdata.nodes.base_node import BaseNode

From 0ea2b264c24c336ef7fd49f90614ae63fa636fdb Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Mon, 14 Apr 2025 08:50:42 -0700
Subject: [PATCH 6/9] clean up

---
 test/nodes/test_csv_dataloader.py | 4 +---
 torchdata/nodes/csv_reader.py     | 4 +---
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/test/nodes/test_csv_dataloader.py b/test/nodes/test_csv_dataloader.py
index 382ab1a31..d0061e8fc 100644
--- a/test/nodes/test_csv_dataloader.py
+++ b/test/nodes/test_csv_dataloader.py
@@ -110,9 +110,7 @@ def test_empty_file(self):
         node.close()
 
     def test_header_validation(self):
-        with self.assertRaisesRegex(
-            ValueError, "return_dict=True requires has_header=True"
-        ):
+        with self.assertRaisesRegex(ValueError, "return_dict=True requires has_header=True"):
             CSVReader("dummy.csv", has_header=False, return_dict=True)
 
     def test_multi_epoch(self):
diff --git a/torchdata/nodes/csv_reader.py b/torchdata/nodes/csv_reader.py
index d419af92b..4864e342a 100644
--- a/torchdata/nodes/csv_reader.py
+++ b/torchdata/nodes/csv_reader.py
@@ -52,9 +52,7 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
             if self.return_dict:
                 if self._header is None:
                     raise ValueError("return_dict=True requires has_header=True")
-                self._reader = csv.DictReader(
-                    self._file, delimiter=self.delimiter, fieldnames=self._header
-                )
+                self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
             else:
                 self._reader = csv.reader(self._file, delimiter=self.delimiter)

From 3b7641fa12250343302be11310c066b829553ed6 Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Mon, 14 Apr 2025 16:36:27 -0700
Subject: [PATCH 7/9] add more tests

---
 test/nodes/test_csv_dataloader.py | 35 +++++++++---
 torchdata/nodes/csv_reader.py     | 90 +++++++++++++++++++------------
 2 files changed, 85 insertions(+), 40 deletions(-)

diff --git a/test/nodes/test_csv_dataloader.py b/test/nodes/test_csv_dataloader.py
index d0061e8fc..ac60034cb 100644
--- a/test/nodes/test_csv_dataloader.py
+++ b/test/nodes/test_csv_dataloader.py
@@ -7,11 +7,9 @@
 import csv
 import os
 import tempfile
-from typing import Any, Dict, List, Union
 
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase
-from torchdata.nodes.base_node import BaseNode
 
 from torchdata.nodes.csv_reader import CSVReader
 
@@ -61,6 +59,7 @@ def test_basic_read_dict(self):
         self.assertEqual(len(results), len(self.test_data) - 1)
         self.assertEqual(results[0], {"name": "Alice", "age": "30", "city": "New York"})
         self.assertEqual(results[1]["city"], "London")
+        self.assertEqual(results[-1]["city"], "Bogota")
         node.close()
 
     def test_different_delimiters(self):
@@ -70,25 +69,27 @@ def test_different_delimiters(self):
 
         self.assertEqual(len(results), len(self.test_data) - 1)
         self.assertEqual(results[2]["city"], "Paris")
+        self.assertEqual(results[-1]["city"], "Bogota")
         node.close()
 
     def test_state_management(self):
         path = self._create_temp_csv()
         node = CSVReader(path, has_header=True, return_dict=True)
-
+        print(f"initial state: {node.state_dict()}")
         for _ in range(11):
-            next(node)
+            _ = next(node)
+            print(f"element = {_}, state: {node.state_dict()}")
 
         state = node.state_dict()
 
         node.reset(state)
-
         item = next(node)
+
         with self.assertRaises(StopIteration):
             next(node)
 
         self.assertEqual(item["name"], "Lily")
-        self.assertEqual(state[CSVReader.LINE_NUM_KEY], 11)
+        self.assertEqual(state[CSVReader.NUM_LINES_YIELDED], 11)
         node.close()
 
     @parameterized.expand([3, 5, 7])
@@ -98,6 +99,28 @@ def test_save_load_state(self, midpoint: int):
         run_test_save_load_state(self, node, midpoint)
         node.close()
 
+    def test_load_wrong_state(self):
+        path = self._create_temp_csv(header=True)
+        node = CSVReader(path, has_header=True)
+
+        state = node.state_dict()
+        state[CSVReader.HEADER_KEY] = None
+        with self.assertRaisesRegex(ValueError, "Check if has_header=True matches the state header=None"):
+            node.reset(state)
+
+        node.close()
+
+        node = CSVReader(path, has_header=False)
+        state = node.state_dict()
+        state[CSVReader.HEADER_KEY] = ["name", "age"]
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Check if has_header=False matches the state header=\['name', 'age'\]",
+        ):
+            node.reset(state)
+
+        node.close()
+
     def test_empty_file(self):
         path = self._create_temp_csv()
         # Overwrite with empty file
diff --git a/torchdata/nodes/csv_reader.py b/torchdata/nodes/csv_reader.py
index 4864e342a..435396563 100644
--- a/torchdata/nodes/csv_reader.py
+++ b/torchdata/nodes/csv_reader.py
@@ -1,4 +1,5 @@
 import csv
+from itertools import islice
 from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, Union
 
 from torchdata.nodes.base_node import BaseNode
@@ -13,7 +14,7 @@ class CSVReader(BaseNode[Union[List[str], Dict[str, str]]]):
         return_dict: Return rows as dictionaries (requires has_header=True)
     """
 
-    LINE_NUM_KEY = "line_num"
+    NUM_LINES_YIELDED = "num_lines_yielded"
     HEADER_KEY = "header"
 
     def __init__(
@@ -22,6 +23,7 @@ def __init__(
         has_header: bool = False,
         delimiter: str = ",",
         return_dict: bool = False,
+        encoding: str = "utf-8",
     ):
         super().__init__()
         self.file_path = file_path
@@ -30,56 +32,73 @@ def __init__(
         self.return_dict = return_dict
         if return_dict and not has_header:
             raise ValueError("return_dict=True requires has_header=True")
+        self.encoding = encoding
         self._file: Optional[TextIO] = None
         self._reader: Optional[Iterator[Union[List[str], Dict[str, str]]]] = None
         self._header: Optional[Sequence[str]] = None
-        self._line_num: int = 0
+        self._num_lines_yielded: int = 0
         self.reset()  # Initialize reader
 
     def reset(self, initial_state: Optional[Dict[str, Any]] = None):
-        super().reset(initial_state)
-
-        if self._file and not self._file.closed:
-            self._file.close()
+        super().reset()
+        self.close()
+
+        # Reopen the file and reset counters
+        self._file = open(self.file_path, encoding=self.encoding)
+        self._num_lines_yielded = 0
+        if initial_state is not None:
+            self._handle_initial_state(initial_state)
+        else:
+            self._initialize_reader()
 
-        self._file = open(self.file_path, newline="", encoding="utf-8")
-        self._line_num = 0
+    def _handle_initial_state(self, state: Dict[str, Any]):
+        """Restore reader state from checkpoint."""
+        # Validate header compatibility
+        if (not self.has_header and self.HEADER_KEY in state) or (self.has_header and state[self.HEADER_KEY] is None):
+            raise ValueError(f"Check if has_header={self.has_header} matches the state header={state[self.HEADER_KEY]}")
 
-        if initial_state:
-            self._header = initial_state.get(self.HEADER_KEY)
-            target_line_num = initial_state[self.LINE_NUM_KEY]
+        self._header = state.get(self.HEADER_KEY)
+        target_line_num = state[self.NUM_LINES_YIELDED]
+        assert self._file is not None
+        # Create appropriate reader
+        if self.return_dict:
 
-            if self.return_dict:
-                if self._header is None:
-                    raise ValueError("return_dict=True requires has_header=True")
-                self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
+            self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
         else:
             self._reader = csv.reader(self._file, delimiter=self.delimiter)
+        # Skip header if needed (applies only when file has header)
+
+        assert isinstance(self._reader, Iterator)
+        if self.has_header:
+            try:
+                next(self._reader)  # Skip header line
+            except StopIteration:
+                pass  # Empty file
+        # Fast-forward to target line using efficient slicing
+        consumed = sum(1 for _ in islice(self._reader, target_line_num))
+        self._num_lines_yielded = consumed
+
+    def _initialize_reader(self):
+        """Create fresh reader without state."""
+        assert self._file is not None
+        if self.return_dict:
+            self._reader = csv.DictReader(self._file, delimiter=self.delimiter)
+            self._header = self._reader.fieldnames
+        else:
+            self._reader = csv.reader(self._file, delimiter=self.delimiter)
 
-            assert isinstance(self._reader, Iterator)
             if self.has_header:
-                next(self._reader)  # Skip header
-            for _ in range(target_line_num - self._line_num):
-                try:
-                    next(self._reader)
-                    self._line_num += 1
-                except StopIteration:
-                    break
-        else:
-
-            if self.return_dict:
-                self._reader = csv.DictReader(self._file, delimiter=self.delimiter)
-                self._header = self._reader.fieldnames
-            else:
-                self._reader = csv.reader(self._file, delimiter=self.delimiter)
-                if self.has_header:
+                try:
                     self._header = next(self._reader)
+                except StopIteration:
+                    self._header = None  # Handle empty file
 
     def next(self) -> Union[List[str], Dict[str, str]]:
         try:
             assert isinstance(self._reader, Iterator)
             row = next(self._reader)
-            self._line_num += 1
+            self._num_lines_yielded += 1
             return row
 
         except StopIteration:
@@ -87,7 +106,10 @@ def next(self) -> Union[List[str], Dict[str, str]]:
             raise
 
     def get_state(self) -> Dict[str, Any]:
-        return {self.LINE_NUM_KEY: self._line_num, self.HEADER_KEY: self._header}
+        return {
+            self.NUM_LINES_YIELDED: self._num_lines_yielded,
+            self.HEADER_KEY: self._header,
+        }
 
     def close(self):
         if self._file and not self._file.closed:
             self._file.close()

From bd54f40d4192a18c867be7f0d01fd6ad4eea129d Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Mon, 14 Apr 2025 16:39:08 -0700
Subject: [PATCH 8/9] update filenames

---
 test/nodes/{test_csv_dataloader.py => test_csv_reader.py} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename test/nodes/{test_csv_dataloader.py => test_csv_reader.py} (100%)

diff --git a/test/nodes/test_csv_dataloader.py b/test/nodes/test_csv_reader.py
similarity index 100%
rename from test/nodes/test_csv_dataloader.py
rename to test/nodes/test_csv_reader.py

From 7572ff57a1f26639351461a8956dcfbda529a63b Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Thu, 29 May 2025 21:11:15 -0700
Subject: [PATCH 9/9] update files

---
 test/nodes/test_csv_reader.py |  8 ++++++--
 torchdata/nodes/csv_reader.py | 20 ++++++++++++++++----
 2 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/test/nodes/test_csv_reader.py b/test/nodes/test_csv_reader.py
index ac60034cb..a16ba89be 100644
--- a/test/nodes/test_csv_reader.py
+++ b/test/nodes/test_csv_reader.py
@@ -105,7 +105,9 @@ def test_load_wrong_state(self):
 
         state = node.state_dict()
         state[CSVReader.HEADER_KEY] = None
-        with self.assertRaisesRegex(ValueError, "Check if has_header=True matches the state header=None"):
+        with self.assertRaisesRegex(
+            ValueError, "Check if has_header=True matches the state header=None"
+        ):
             node.reset(state)
 
         node.close()
@@ -133,7 +135,9 @@ def test_empty_file(self):
         node.close()
 
     def test_header_validation(self):
-        with self.assertRaisesRegex(ValueError, "return_dict=True requires has_header=True"):
+        with self.assertRaisesRegex(
+            ValueError, "return_dict=True requires has_header=True"
+        ):
             CSVReader("dummy.csv", has_header=False, return_dict=True)
 
     def test_multi_epoch(self):
diff --git a/torchdata/nodes/csv_reader.py b/torchdata/nodes/csv_reader.py
index 435396563..7706857c3 100644
--- a/torchdata/nodes/csv_reader.py
+++ b/torchdata/nodes/csv_reader.py
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 import csv
 from itertools import islice
 from typing import Any, Dict, Iterator, List, Optional, Sequence, TextIO, Union
@@ -54,8 +60,12 @@ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
     def _handle_initial_state(self, state: Dict[str, Any]):
         """Restore reader state from checkpoint."""
         # Validate header compatibility
-        if (not self.has_header and self.HEADER_KEY in state) or (self.has_header and state[self.HEADER_KEY] is None):
-            raise ValueError(f"Check if has_header={self.has_header} matches the state header={state[self.HEADER_KEY]}")
+        if (not self.has_header and self.HEADER_KEY in state) or (
+            self.has_header and state[self.HEADER_KEY] is None
+        ):
+            raise ValueError(
+                f"Check if has_header={self.has_header} matches the state header={state[self.HEADER_KEY]}"
+            )
 
         self._header = state.get(self.HEADER_KEY)
         target_line_num = state[self.NUM_LINES_YIELDED]
@@ -63,7 +73,9 @@ def _handle_initial_state(self, state: Dict[str, Any]):
         # Create appropriate reader
         if self.return_dict:
 
-            self._reader = csv.DictReader(self._file, delimiter=self.delimiter, fieldnames=self._header)
+            self._reader = csv.DictReader(
+                self._file, delimiter=self.delimiter, fieldnames=self._header
+            )
         else:
             self._reader = csv.reader(self._file, delimiter=self.delimiter)
         # Skip header if needed (applies only when file has header)
@@ -112,5 +124,5 @@ def get_state(self) -> Dict[str, Any]:
         }
 
     def close(self):
-        if self._file and not self._file.closed:
+        if self._file is not None and not self._file.closed:
             self._file.close()
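
For reference, here is a minimal usage sketch of the CSVReader node as it stands after this series. The constructor arguments, next/list iteration, state_dict/reset round trip, and close call all follow the API exercised in the tests above; the file path "data.csv" and its column names are hypothetical stand-ins.

```python
# Minimal sketch, assuming torchdata with these patches applied.
# "data.csv" and its columns are hypothetical examples.
from torchdata.nodes.csv_reader import CSVReader

# return_dict=True requires has_header=True (validated in __init__).
node = CSVReader("data.csv", has_header=True, return_dict=True, delimiter=",")

# Read a few rows; each row is a dict keyed by the header columns.
for _ in range(3):
    row = next(node)

# Checkpoint: {"num_lines_yielded": 3, "header": [...]}
state = node.state_dict()

# Restore and continue from the fourth data row; iteration closes the
# underlying file when the reader is exhausted.
node.reset(state)
remaining = list(node)
node.close()
```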