From bd73431b4e6c0088dbb84babc7800f0ebb203468 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Wed, 15 Feb 2023 12:29:52 +0100 Subject: [PATCH 1/6] Make iter to map conversion more lazy --- test/test_iterdatapipe.py | 12 +++++ torchdata/datapipes/iter/util/converter.py | 51 +++++++++++++++------- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 9ee73ba44..ea1fa3424 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -1056,6 +1056,18 @@ def test_itertomap_mapdatapipe(self): self.assertEqual(len(wa), 1) self.assertRegex(str(wa[0].message), r"Found duplicate key") + # More lazily: load only until necessary + source_dp = IterableWrapper(list(zip(keys, values))) + lazy_map_dp = source_dp.to_map_datapipe() + _ = lazy_map_dp["k" + str(4)] + self.assertEqual(len(lazy_map_dp._map), 5) + _ = lazy_map_dp["k" + str(7)] + self.assertEqual(len(lazy_map_dp._map), 8) + try: + _ = lazy_map_dp["k" + str(20)] + except IndexError: + self.assertEqual(len(lazy_map_dp._map), 10) + def test_mux_longest_iterdatapipe(self): # Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index 0721e1741..30dc19775 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -68,32 +68,53 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No _check_unpickable_fn(key_value_fn) self.key_value_fn = key_value_fn # type: ignore[assignment] self._map = None + self._itr = None + self._depleted = False def _load_map(self): - self._map = {} - for d in self.datapipe: - inp = d if self.key_value_fn is None else self.key_value_fn(d) + if self._map is None: + self._map = {} + self._itr = iter(self.datapipe) + while not self._depleted: try: - length = len(inp) - except TypeError: - raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence") - if length != 2: - raise ValueError(f"dictionary update sequence element has length {length}, 2 is required") - key, value = inp - if key in self._map: - warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`") - self._map[key] = value + self._load_next_item() + except StopIteration: + self._depleted = True def __getitem__(self, index): try: if self._map is None: - self._load_map() - return self._map[index] # type: ignore[index] + self._map = {} + self._itr = iter(self.datapipe) + raise KeyError + return self._map[index] except KeyError: + while not self._depleted: + try: + key, value = self._load_next_item() + if key == index: + return value + except StopIteration: + self._depleted = True raise IndexError(f"Index {index} is invalid for IterToMapConverter.") + def _load_next_item(self): + elem = next(self._itr) + inp = elem if self.key_value_fn is None else self.key_value_fn(elem) + try: + length = len(inp) + except TypeError: + raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence") + if length != 2: + raise ValueError(f"dictionary update sequence element has length {length}, 2 is required") + key, value = inp + if key in self._map: + warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`") + self._map[key] = value + return key, value + def __len__(self): - if self._map is not None: + if self._depleted: return len(self._map) # type: ignore[arg-type] try: return len(self.datapipe) From 4f947ad3ff3dc80bcdcabbc94ee645b101bbb14b Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Wed, 15 Feb 2023 15:43:24 +0100 Subject: [PATCH 2/6] Improve readability --- torchdata/datapipes/iter/util/converter.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index 30dc19775..f43b3fde6 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -86,17 +86,18 @@ def __getitem__(self, index): if self._map is None: self._map = {} self._itr = iter(self.datapipe) - raise KeyError - return self._map[index] + else: + return self._map[index] except KeyError: - while not self._depleted: - try: - key, value = self._load_next_item() - if key == index: - return value - except StopIteration: - self._depleted = True - raise IndexError(f"Index {index} is invalid for IterToMapConverter.") + pass + while not self._depleted: + try: + key, value = self._load_next_item() + if key == index: + return value + except StopIteration: + self._depleted = True + raise IndexError(f"Index {index} is invalid for IterToMapConverter.") def _load_next_item(self): elem = next(self._itr) From 377924ec188f8fce32abddda0c12624fb5fad135 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Wed, 15 Feb 2023 17:41:38 +0100 Subject: [PATCH 3/6] Fix mypy issues --- torchdata/datapipes/iter/util/converter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index f43b3fde6..a01fc1d78 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -5,8 +5,7 @@ # LICENSE file in the root directory of this source tree. import warnings - -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Iterator, Optional from torch.utils.data import IterDataPipe, MapDataPipe from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE @@ -59,6 +58,8 @@ class IterToMapConverterMapDataPipe(MapDataPipe): key_value_fn: Optional[Callable] _map: Optional[Dict] _length: int + _itr: Optional[Iterator] + _depleted: bool def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = None): if not isinstance(datapipe, IterDataPipe): From 70a0f3c5980b95821bb5d413d094a1e5ed781e26 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Thu, 16 Feb 2023 11:25:23 +0100 Subject: [PATCH 4/6] Fix state methods and actually fix mypy issues --- torchdata/datapipes/iter/util/converter.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index a01fc1d78..be84c574d 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -73,9 +73,6 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No self._depleted = False def _load_map(self): - if self._map is None: - self._map = {} - self._itr = iter(self.datapipe) while not self._depleted: try: self._load_next_item() @@ -84,10 +81,7 @@ def _load_map(self): def __getitem__(self, index): try: - if self._map is None: - self._map = {} - self._itr = iter(self.datapipe) - else: + if self._map is not None: return self._map[index] except KeyError: pass @@ -101,7 +95,10 @@ def __getitem__(self, index): raise IndexError(f"Index {index} is invalid for IterToMapConverter.") def _load_next_item(self): - elem = next(self._itr) + if self._map is None: + self._map = {} + self._itr = iter(self.datapipe) + elem = next(self._itr) # type: ignore[arg-type] inp = elem if self.key_value_fn is None else self.key_value_fn(elem) try: length = len(inp) @@ -135,14 +132,10 @@ def __getstate__(self): dill_key_value_fn = dill.dumps(self.key_value_fn) else: dill_key_value_fn = self.key_value_fn - return ( - self.datapipe, - dill_key_value_fn, - self._map, - ) + return (self.datapipe, dill_key_value_fn, self._map, self._itr, self._depleted) def __setstate__(self, state): - (self.datapipe, dill_key_value_fn, self._map) = state + (self.datapipe, dill_key_value_fn, self._map, self._itr, self._depleted) = state if DILL_AVAILABLE: self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment] else: From ce70b45192042147d7b0acedb78b0fe497119b91 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Thu, 16 Feb 2023 15:25:33 +0100 Subject: [PATCH 5/6] Migrate to doctest --- torchdata/datapipes/iter/util/converter.py | 45 ++++++++++++++-------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index be84c574d..547796821 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -36,24 +36,37 @@ class IterToMapConverterMapDataPipe(MapDataPipe): will be replaced by the new value. Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper([(i, i) for i in range(10)]) - >>> map_dp = source_dp.to_map_datapipe() - >>> list(map_dp) - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - >>> source_dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 1)]) - >>> map_dp2 = source_dp2.to_map_datapipe() - >>> map_dp2['a'] - 1 - >>> def row_to_tuple(row): - >>> label = row[0] - >>> data = row[1:] - >>> return label, data - >>> source_dp3 = IterableWrapper([('a', 1, 1, 1, 1, 1, 1), ('b', 2, 2, 2, 2, 2, 2), ('c', 3, 3, 3, 3, 3, 3)]) - >>> map_dp3 = source_dp3.to_map_datapipe(key_value_fn=row_to_tuple) - >>> map_dp3['a'] + + .. testsetup:: + + from torchdata.datapipes.iter import IterableWrapper + + .. testcode:: + + source_dp = IterableWrapper([(i, i) for i in range(10)]) + map_dp = source_dp.to_map_datapipe() + assert list(map_dp) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + source_dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 1)]) + map_dp2 = source_dp2.to_map_datapipe() + assert map_dp2['a']) == 1 + + .. testcode:: + + def row_to_tuple(row): + label = row[0] + data = row[1:] + return label, data + source_dp3 = IterableWrapper([('a', 1, 1, 1, 1, 1, 1), ('b', 2, 2, 2, 2, 2, 2), ('c', 3, 3, 3, 3, 3, 3)]) + map_dp3 = source_dp3.to_map_datapipe(key_value_fn=row_to_tuple) + print(map_dp3['a']) + + .. testoutput:: + (1, 1, 1, 1, 1, 1) + """ + datapipe: IterDataPipe key_value_fn: Optional[Callable] _map: Optional[Dict] From 03406395338fbebd160c15e2cdd6049263fe8eef Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Fri, 10 Mar 2023 18:22:32 +0100 Subject: [PATCH 6/6] Address PR comments --- torchdata/datapipes/iter/util/converter.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index 547796821..59d68a64e 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -49,7 +49,7 @@ class IterToMapConverterMapDataPipe(MapDataPipe): source_dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 1)]) map_dp2 = source_dp2.to_map_datapipe() - assert map_dp2['a']) == 1 + assert map_dp2['a'] == 1 .. testcode:: @@ -91,6 +91,7 @@ def _load_map(self): self._load_next_item() except StopIteration: self._depleted = True + self._itr = None def __getitem__(self, index): try: @@ -105,6 +106,7 @@ def __getitem__(self, index): return value except StopIteration: self._depleted = True + self._itr = None raise IndexError(f"Index {index} is invalid for IterToMapConverter.") def _load_next_item(self): @@ -145,10 +147,10 @@ def __getstate__(self): dill_key_value_fn = dill.dumps(self.key_value_fn) else: dill_key_value_fn = self.key_value_fn - return (self.datapipe, dill_key_value_fn, self._map, self._itr, self._depleted) + return (self.datapipe, dill_key_value_fn, self._map, self._depleted) def __setstate__(self, state): - (self.datapipe, dill_key_value_fn, self._map, self._itr, self._depleted) = state + (self.datapipe, dill_key_value_fn, self._map, self._depleted) = state if DILL_AVAILABLE: self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment] else: