diff --git a/grain/_src/core/transforms.py b/grain/_src/core/transforms.py index 3d877319..37672991 100644 --- a/grain/_src/core/transforms.py +++ b/grain/_src/core/transforms.py @@ -32,7 +32,7 @@ import numpy as np -class MapTransform(abc.ABC): +class Map(abc.ABC): """Abstract base class for all 1:1 transformations of elements. Implementations should be threadsafe since they are often executed in @@ -44,7 +44,7 @@ def map(self, element): """Maps a single element.""" -class RandomMapTransform(abc.ABC): +class RandomMap(abc.ABC): """Abstract base class for all random 1:1 transformations of elements. Implementations should be threadsafe since they are often executed in @@ -68,7 +68,7 @@ def map_with_index(self, index: int, element): """Maps a single element with its index.""" -class TfRandomMapTransform(abc.ABC): +class TfRandomMap(abc.ABC): """Abstract base class for all random 1:1 transformations of elements.""" @abc.abstractmethod @@ -91,7 +91,7 @@ def filter(self, element) -> bool: """Filters a single element; returns True if the element should be kept.""" -class FlatMapTransform(abc.ABC): +class FlatMap(abc.ABC): """Abstract base class for splitting operations of individual elements. Implementations should be threadsafe since they are often executed in @@ -124,11 +124,11 @@ class Batch: Transformation = Union[ Batch, - MapTransform, - RandomMapTransform, - TfRandomMapTransform, + Map, + RandomMap, + TfRandomMap, Filter, - FlatMapTransform, + FlatMap, MapWithIndex, ] Transformations = Sequence[Transformation] @@ -150,11 +150,11 @@ def get_pretty_transform_name( transform, ( Batch, - MapTransform, - RandomMapTransform, - TfRandomMapTransform, + Map, + RandomMap, + TfRandomMap, Filter, - FlatMapTransform, + FlatMap, ), ): # Check if transform class defines `__str__` and `__repr__` and use them if diff --git a/grain/_src/core/transforms_test.py b/grain/_src/core/transforms_test.py index a35f1e5b..442932ad 100644 --- a/grain/_src/core/transforms_test.py +++ b/grain/_src/core/transforms_test.py @@ -34,7 +34,7 @@ def __str__(self): return "CustomStr" -class _TestMapWithRepr(transforms.MapTransform): +class _TestMapWithRepr(transforms.Map): def map(self, x): return x % 2 == 0 diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py index 3c13e6e1..3e0e609c 100644 --- a/grain/_src/python/data_loader.py +++ b/grain/_src/python/data_loader.py @@ -100,7 +100,7 @@ def use_context_if_available(obj): @dataclasses.dataclass -class CopyNumPyArrayToSharedMemory(transforms.MapTransform): +class CopyNumPyArrayToSharedMemory(transforms.Map): """If `element` contains NumPy array copy it to SharedMemoryArray.""" def map(self, element: Any) -> Any: @@ -412,7 +412,7 @@ def __init__( logging.info("Enabling SharedMemoryArray for BatchOperation.") operations[-1]._enable_shared_memory() else: - logging.info("Adding CopyNumPyArrayToSharedMemory MapTransform.") + logging.info("Adding CopyNumPyArrayToSharedMemory Map.") operations = list(operations) + [CopyNumPyArrayToSharedMemory()] self._data_source = data_source @@ -443,7 +443,7 @@ def __init__( ) # pylint: enable=protected-access self._use_native_dataset_checkpointing = any( - isinstance(op, transforms.FlatMapTransform) for op in self._operations + isinstance(op, transforms.FlatMap) for op in self._operations ) self._dataset = self._create_dataset() @@ -634,10 +634,10 @@ def _source_repr(source: RandomAccessDataSource) -> str: return repr(source) -class _FlatMapAdapter(transforms.FlatMapTransform): +class 
_FlatMapAdapter(transforms.FlatMap): """Data loader adapter to pass through correct metadata.""" - def __init__(self, transform: transforms.FlatMapTransform): + def __init__(self, transform: transforms.FlatMap): self._transform = transform self.max_fan_out = transform.max_fan_out @@ -660,25 +660,25 @@ def _apply_transform_to_dataset( ds: dataset.IterDataset, ) -> dataset.IterDataset: """Applies the `transform` to the dataset.""" - if isinstance(transform, transforms.MapTransform): + if isinstance(transform, transforms.Map): return ds.map( lambda r: record.Record(metadata=r.metadata, data=transform.map(r.data)) ) - elif isinstance(transform, transforms.RandomMapTransform): + elif isinstance(transform, transforms.RandomMap): return ds.map( lambda r: record.Record( metadata=r.metadata, data=transform.random_map(r.data, r.metadata.rng), ) ) - elif isinstance(transform, transforms.TfRandomMapTransform): + elif isinstance(transform, transforms.TfRandomMap): return ds.map( lambda r: record.Record( metadata=r.metadata, data=transform.np_random_map(r.data, r.metadata.rng), ) ) - elif isinstance(transform, transforms.FlatMapTransform): + elif isinstance(transform, transforms.FlatMap): return flatmap.FlatMapIterDataset(ds, _FlatMapAdapter(transform)) elif isinstance(transform, transforms.Filter): return ds.filter(lambda r: transform.filter(r.data)) diff --git a/grain/_src/python/data_loader_test.py b/grain/_src/python/data_loader_test.py index a825ba04..22b9a54b 100644 --- a/grain/_src/python/data_loader_test.py +++ b/grain/_src/python/data_loader_test.py @@ -73,26 +73,26 @@ def filter(self, x: int) -> bool: return x % 2 == 0 -class PlusOne(transforms.MapTransform): +class PlusOne(transforms.Map): def map(self, x: int) -> int: return x + 1 -class PlusRandom(transforms.RandomMapTransform): +class PlusRandom(transforms.RandomMap): def random_map(self, x: int, rng: np.random.Generator) -> int: return x + rng.integers(100_000) -class FailingMap(transforms.MapTransform): +class FailingMap(transforms.Map): def map(self, x): del x 1 / 0 # pylint: disable=pointless-statement -class NonPickableTransform(transforms.MapTransform): +class NonPickableTransform(transforms.Map): def __getstate__(self): raise ValueError("I shall not be pickled") @@ -101,13 +101,13 @@ def map(self, x): return x -class RaisingTransform(transforms.MapTransform): +class RaisingTransform(transforms.Map): def map(self, x): raise AttributeError("I shall raise") -class ExitingTransform(transforms.MapTransform): +class ExitingTransform(transforms.Map): def map(self, x): raise sys.exit(123) @@ -126,7 +126,7 @@ def __getitem__(self, record_key: int): } -class DuplicateElementFlatMap(transforms.FlatMapTransform): +class DuplicateElementFlatMap(transforms.FlatMap): max_fan_out: int = 7 def flat_map(self, element: Any) -> Any: diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index d1c3abf6..3812f30f 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -175,6 +175,7 @@ py_test( }), shard_count = 10, srcs_version = "PY3", + # TODO: Re-enable once the test is fixed. 
deps = [ ":dataset", ":elastic_iterator", diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index de2fece1..e3ac088d 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -509,9 +509,7 @@ def filter( # pylint: enable=g-import-not-at-top return filter_dataset.FilterMapDataset(parent=self, transform=transform) - def map( - self, transform: transforms.MapTransform | Callable[[T], S] - ) -> MapDataset[S]: + def map(self, transform: transforms.Map | Callable[[T], S]) -> MapDataset[S]: """Returns a dataset containing the elements transformed by ``transform``. Example usage:: @@ -717,9 +715,7 @@ def slice(self, sl: builtins.slice) -> MapDataset[T]: def random_map( self, - transform: ( - transforms.RandomMapTransform | Callable[[T, np.random.Generator], S] - ), + transform: transforms.RandomMap | Callable[[T, np.random.Generator], S], *, seed: int | None = None, ) -> MapDataset[S]: @@ -1181,9 +1177,7 @@ def filter( # pylint: enable=g-import-not-at-top return filter_dataset.FilterIterDataset(parent=self, transform=transform) - def map( - self, transform: transforms.MapTransform | Callable[[T], S] - ) -> IterDataset[S]: + def map(self, transform: transforms.Map | Callable[[T], S]) -> IterDataset[S]: """Returns a dataset containing the elements transformed by ``transform``. Example usage:: @@ -1210,9 +1204,7 @@ def map( def random_map( self, - transform: ( - transforms.RandomMapTransform | Callable[[T, np.random.Generator], S] - ), + transform: transforms.RandomMap | Callable[[T, np.random.Generator], S], *, seed: int | None = None, ) -> IterDataset[S]: @@ -1754,13 +1746,13 @@ def apply_transformations( drop_remainder=transformation.drop_remainder, batch_fn=transformation.batch_fn, ) - case transforms.MapTransform(): + case transforms.Map(): ds = ds.map(transformation) - case transforms.RandomMapTransform(): + case transforms.RandomMap(): ds = ds.random_map(transformation) case transforms.MapWithIndex(): ds = ds.map_with_index(transformation) - case transforms.FlatMapTransform(): + case transforms.FlatMap(): # Loaded lazily due to a circular dependency (dataset <-> flatmap). 
# pylint: disable=g-import-not-at-top from grain._src.python.dataset.transformations import flatmap diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py index b46cd1be..88e67ccc 100644 --- a/grain/_src/python/dataset/dataset_test.py +++ b/grain/_src/python/dataset/dataset_test.py @@ -49,21 +49,21 @@ def filter(self, element: int) -> bool: @dataclasses.dataclass(frozen=True) -class RandomMapAddingRandomInt(transforms.RandomMapTransform): +class RandomMapAddingRandomInt(transforms.RandomMap): def random_map(self, element: int, rng: np.random.Generator) -> int: return element + rng.integers(0, 100) @dataclasses.dataclass(frozen=True) -class RandomMapAlwaysAddingOne(transforms.RandomMapTransform): +class RandomMapAlwaysAddingOne(transforms.RandomMap): def random_map(self, element: int, rng: np.random.Generator) -> int: return element + 1 @dataclasses.dataclass(frozen=True) -class MapTransformAddingOne(transforms.MapTransform): +class MapAddingOne(transforms.Map): def map(self, element: int) -> int: return element + 1 @@ -97,7 +97,7 @@ def __getitem__(self, index): return (index + 1) % 2, index // 2 -class AddRandomInteger(transforms.RandomMapTransform): +class AddRandomInteger(transforms.RandomMap): def random_map(self, element, rng): return element + rng.integers(low=0, high=100) @@ -1044,7 +1044,7 @@ def test_map_has_one_parent(self, initial_ds): dataset.MapDataset.range(15), dataset.MapDataset.range(15).to_iter_dataset(), ], - transform=[MapTransformAddingOne(), lambda x: x + 1], + transform=[MapAddingOne(), lambda x: x + 1], ) def test_map_produces_correct_elements(self, initial_ds, transform): ds = initial_ds.map(transform) @@ -1173,7 +1173,7 @@ def test_concatenate(self): (dataset.MapDataset.range(5).to_iter_dataset(),), ) def test_apply(self, ds): - ds = ds.apply([MapTransformAddingOne(), transforms.Batch(2)]) + ds = ds.apply([MapAddingOne(), transforms.Batch(2)]) np.testing.assert_equal( list(ds), [ @@ -1184,7 +1184,7 @@ def test_apply(self, ds): ) -class TfRandomMapAlwaysAddingOne(transforms.TfRandomMapTransform): +class TfRandomMapAlwaysAddingOne(transforms.TfRandomMap): def np_random_map(self, x, rng): return x + 1 @@ -1196,7 +1196,7 @@ def filter(self, x): return np.sum(x) < 20 -class FlatMapAddingOne(transforms.FlatMapTransform): +class FlatMapAddingOne(transforms.FlatMap): max_fan_out = 2 @@ -1211,7 +1211,7 @@ class ApplyTransformationsTest(parameterized.TestCase): (dataset.MapDataset.range(15).to_iter_dataset(),), ) def test_single_transform(self, ds): - ds = dataset.apply_transformations(ds, MapTransformAddingOne()) + ds = dataset.apply_transformations(ds, MapAddingOne()) self.assertSequenceEqual(list(ds), list(range(1, 16))) @parameterized.parameters( @@ -1223,7 +1223,7 @@ def test_multiple_transforms(self, ds): ds = dataset.apply_transformations( ds, [ - MapTransformAddingOne(), + MapAddingOne(), RandomMapAlwaysAddingOne(), transforms.Batch(batch_size=2, drop_remainder=True), FilterArraysWithLargeSum(), @@ -1359,7 +1359,7 @@ def test_get_execution_summary_without_collection(self): def test_execution_summary_with_logging(self): with self.assertLogs(level="INFO") as logs: ds = dataset.MapDataset.range(10).shuffle(42) - ds = ds.map(MapTransformAddingOne()) + ds = ds.map(MapAddingOne()) ds = ds.to_iter_dataset() it = ds.__iter__() # Get execution summary after iterating through the dataset. 
@@ -1377,7 +1377,7 @@ def test_execution_summary_with_logging(self): def test_execution_summary_with_no_logging(self): with self.assertLogs(level="INFO") as logs: ds = dataset.MapDataset.range(10).shuffle(42) - ds = ds.map(MapTransformAddingOne()) + ds = ds.map(MapAddingOne()) ds = ds.to_iter_dataset() ds = dataset.WithOptionsIterDataset( ds, @@ -1400,7 +1400,7 @@ def worker_init_fn_wrapper(worker_index, worker_count): del worker_index, worker_count dataset_stats._REPORTING_PERIOD_SEC = 0.05 - ds = dataset.MapDataset.range(10000).map(MapTransformAddingOne()) + ds = dataset.MapDataset.range(10000).map(MapAddingOne()) ds = ds.to_iter_dataset() ds = ds.mp_prefetch( options.MultiprocessingOptions(num_workers=1), diff --git a/grain/_src/python/dataset/stats_test.py b/grain/_src/python/dataset/stats_test.py index d409f22f..6c17d24b 100644 --- a/grain/_src/python/dataset/stats_test.py +++ b/grain/_src/python/dataset/stats_test.py @@ -153,7 +153,7 @@ def _identity(x): return x -class _AddOne(transforms.MapTransform): +class _AddOne(transforms.Map): def map(self, x): return x + 1 diff --git a/grain/_src/python/dataset/transformations/batch_test.py b/grain/_src/python/dataset/transformations/batch_test.py index ff791489..02f9d66b 100644 --- a/grain/_src/python/dataset/transformations/batch_test.py +++ b/grain/_src/python/dataset/transformations/batch_test.py @@ -33,7 +33,7 @@ import numpy as np -class MapWithElementSpecInference(transforms.MapTransform): +class MapWithElementSpecInference(transforms.Map): def map(self, element: int): return { @@ -495,7 +495,7 @@ def test_batch_after_filter_raises_error(self): def test_batch_after_flatmap_raises_error(self): @dataclasses.dataclass(frozen=True) - class TestFlatMapTransform(transforms.FlatMapTransform): + class TestFlatMap(transforms.FlatMap): max_fan_out: int def flat_map(self, element: int): @@ -503,7 +503,7 @@ def flat_map(self, element: int): yield i ds = dataset.MapDataset.range(0, 10) - ds = flatmap.FlatMapMapDataset(ds, TestFlatMapTransform(max_fan_out=5)) + ds = flatmap.FlatMapMapDataset(ds, TestFlatMap(max_fan_out=5)) with self.assertRaisesRegex( ValueError, "`MapDataset.batch` can not follow `FlatMapMapDataset`", diff --git a/grain/_src/python/dataset/transformations/flatmap.py b/grain/_src/python/dataset/transformations/flatmap.py index cfe8ec65..75a8ea6e 100644 --- a/grain/_src/python/dataset/transformations/flatmap.py +++ b/grain/_src/python/dataset/transformations/flatmap.py @@ -32,7 +32,7 @@ class FlatMapMapDataset(dataset.MapDataset[T]): def __init__( self, parent: dataset.MapDataset, - transform: transforms.FlatMapTransform, + transform: transforms.FlatMap, ): super().__init__(parent) self._transform = transform @@ -147,7 +147,7 @@ class FlatMapIterDataset(dataset.IterDataset[T]): def __init__( self, parent: dataset.IterDataset, - transform: transforms.FlatMapTransform, + transform: transforms.FlatMap, ): super().__init__(parent) self._transform = transform diff --git a/grain/_src/python/dataset/transformations/flatmap_test.py b/grain/_src/python/dataset/transformations/flatmap_test.py index 7e0a9cd7..85ebb2e5 100644 --- a/grain/_src/python/dataset/transformations/flatmap_test.py +++ b/grain/_src/python/dataset/transformations/flatmap_test.py @@ -26,7 +26,7 @@ @dataclasses.dataclass(frozen=True) -class FixedSizeSplitWithNoTransform(transforms.FlatMapTransform): +class FixedSizeSplitWithNoTransform(transforms.FlatMap): max_fan_out: int def flat_map(self, element: int): @@ -35,7 +35,7 @@ def flat_map(self, element: int): 
@dataclasses.dataclass(frozen=True) -class FixedSizeSplitWithTransform(transforms.FlatMapTransform): +class FixedSizeSplitWithTransform(transforms.FlatMap): max_fan_out: int def flat_map(self, element: int): @@ -44,7 +44,7 @@ def flat_map(self, element: int): @dataclasses.dataclass(frozen=True) -class VariableSizeCappedSplitWithNoTransform(transforms.FlatMapTransform): +class VariableSizeCappedSplitWithNoTransform(transforms.FlatMap): max_fan_out: int def flat_map(self, element: int): @@ -52,7 +52,7 @@ def flat_map(self, element: int): @dataclasses.dataclass(frozen=True) -class VariableSizeUncappedSplitWithNoTransform(transforms.FlatMapTransform): +class VariableSizeUncappedSplitWithNoTransform(transforms.FlatMap): max_fan_out: int def flat_map(self, element: int): @@ -146,7 +146,7 @@ def test_with_filter(self): self.assertEqual(list(flatmap_ds), [1, 1, 3, 3, 5, 5, 7, 7, 9, 9]) -class Unbatch(transforms.FlatMapTransform): +class Unbatch(transforms.FlatMap): def flat_map(self, elements: Any) -> Sequence[Any]: return [e for e in elements] diff --git a/grain/_src/python/dataset/transformations/map.py b/grain/_src/python/dataset/transformations/map.py index 2c6ef359..9070eec1 100644 --- a/grain/_src/python/dataset/transformations/map.py +++ b/grain/_src/python/dataset/transformations/map.py @@ -141,10 +141,10 @@ class MapMapDataset(_ElementSpecFromTransformMapDatasetMixin[T]): def __init__( self, parent: dataset.MapDataset, - transform: transforms.MapTransform | Callable[[Any], T], + transform: transforms.Map | Callable[[Any], T], ): super().__init__(parent, transform) - if isinstance(transform, transforms.MapTransform): + if isinstance(transform, transforms.Map): self._map_fn = transform.map else: self._map_fn = transform @@ -185,14 +185,11 @@ class RandomMapMapDataset(_ElementSpecFromTransformMapDatasetMixin[T]): def __init__( self, parent: dataset.MapDataset, - transform: ( - transforms.RandomMapTransform - | Callable[[Any, np.random.Generator], T] - ), + transform: transforms.RandomMap | Callable[[Any, np.random.Generator], T], seed: int | None = None, ): super().__init__(parent, transform) - if isinstance(transform, transforms.RandomMapTransform): + if isinstance(transform, transforms.RandomMap): self._map_fn = transform.random_map else: self._map_fn = transform @@ -438,14 +435,11 @@ class RandomMapIterDataset(_ElementSpecFromTransformIterDatasetMixin[T]): def __init__( self, parent: dataset.IterDataset, - transform: ( - transforms.RandomMapTransform - | Callable[[Any, np.random.Generator], T] - ), + transform: transforms.RandomMap | Callable[[Any, np.random.Generator], T], seed: int | None = None, ): super().__init__(parent, transform) - if isinstance(transform, transforms.RandomMapTransform): + if isinstance(transform, transforms.RandomMap): self._map_fn = transform.random_map else: self._map_fn = transform @@ -475,10 +469,10 @@ class MapIterDataset(_ElementSpecFromTransformIterDatasetMixin[T]): def __init__( self, parent: dataset.IterDataset, - transform: transforms.MapTransform | Callable[[Any], T], + transform: transforms.Map | Callable[[Any], T], ): super().__init__(parent, transform) - if isinstance(transform, transforms.MapTransform): + if isinstance(transform, transforms.Map): self._map_fn = transform.map else: self._map_fn = transform diff --git a/grain/_src/python/dataset/transformations/map_test.py b/grain/_src/python/dataset/transformations/map_test.py index d9d35472..fa0c3998 100644 --- a/grain/_src/python/dataset/transformations/map_test.py +++ 
b/grain/_src/python/dataset/transformations/map_test.py @@ -40,20 +40,20 @@ def test_reset_rng_state(self): @dataclasses.dataclass(frozen=True) -class MapWithNoTransform(transforms.MapTransform): +class MapWithNoTransform(transforms.Map): def map(self, element: int): return element @dataclasses.dataclass(frozen=True) -class MapWithTransform(transforms.MapTransform): +class MapWithTransform(transforms.Map): def map(self, element: int): return element + 1 -class MapWithElementSpecInference(transforms.MapTransform): +class MapWithElementSpecInference(transforms.Map): def map(self, element: int): return { @@ -72,7 +72,7 @@ def output_spec(self, input_spec: Any) -> Any: @dataclasses.dataclass(frozen=True) -class RandomMapWithTransform(transforms.RandomMapTransform): +class RandomMapWithTransform(transforms.RandomMap): def random_map(self, element: int, rng: np.random.Generator): delta = 0.1 @@ -80,13 +80,13 @@ def random_map(self, element: int, rng: np.random.Generator): @dataclasses.dataclass(frozen=True) -class RandomMapWithDeterminismTransform(transforms.RandomMapTransform): +class RandomMapWithDeterminismTransform(transforms.RandomMap): def random_map(self, element: int, rng: np.random.Generator): return element + rng.integers(0, 10) -class RandomMapWithElementSpecInference(transforms.RandomMapTransform): +class RandomMapWithElementSpecInference(transforms.RandomMap): def random_map(self, element: int, rng: np.random.Generator): return { diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 0b206857..c082bc13 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -677,12 +677,12 @@ def __getstate__(self): local_unpicklable_obj = UnpicklableObject() - class LeftTransform(transforms.MapTransform): + class LeftTransform(transforms.Map): def map(self, x): return x if local_unpicklable_obj else x - class RightTransform(transforms.MapTransform): + class RightTransform(transforms.Map): def map(self, x): return x if local_unpicklable_obj else x @@ -718,14 +718,14 @@ class UnpicklableObject: def __getstate__(self): raise ValueError('UnpicklableObject is not picklable') - class PickleableTransform(transforms.MapTransform): + class PickleableTransform(transforms.Map): def map(self, x): return x local_unpicklable_obj = UnpicklableObject() - class RightTransform(transforms.MapTransform): + class RightTransform(transforms.Map): def map(self, x): return x if local_unpicklable_obj else x @@ -765,7 +765,8 @@ def test_start_prefetch( num_workers: int, per_worker_buffer_size: int, ): - class _SleepTransform(transforms.MapTransform): + + class _SleepTransform(transforms.Map): def map(self, features): time.sleep(1) diff --git a/grain/_src/python/dataset/transformations/process_prefetch_test.py b/grain/_src/python/dataset/transformations/process_prefetch_test.py index f719099c..ed0c6cbc 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch_test.py +++ b/grain/_src/python/dataset/transformations/process_prefetch_test.py @@ -580,12 +580,12 @@ def __getstate__(self): local_unpicklable_obj = UnpicklableObject() - class LeftTransform(transforms.MapTransform): + class LeftTransform(transforms.Map): def map(self, x): return x if local_unpicklable_obj else x - class RightTransform(transforms.MapTransform): + class RightTransform(transforms.Map): def map(self, x): return x if local_unpicklable_obj else x @@ -621,14 +621,14 @@ class 
UnpicklableObject: def __getstate__(self): raise ValueError('UnpicklableObject is not picklable') - class PickleableTransform(transforms.MapTransform): + class PickleableTransform(transforms.Map): def map(self, x): return x local_unpicklable_obj = UnpicklableObject() - class RightTransform(transforms.MapTransform): + class RightTransform(transforms.Map): def map(self, x): return x if local_unpicklable_obj else x @@ -668,7 +668,8 @@ def test_start_prefetch( num_workers: int, per_worker_buffer_size: int, ): - class _SleepTransform(transforms.MapTransform): + + class _SleepTransform(transforms.Map): def map(self, features): time.sleep(1) diff --git a/grain/_src/python/load_test.py b/grain/_src/python/load_test.py index 16f7d337..f37e6a32 100644 --- a/grain/_src/python/load_test.py +++ b/grain/_src/python/load_test.py @@ -32,7 +32,7 @@ def filter(self, x: int) -> bool: return x % 2 == 0 -class PlusOne(transforms.MapTransform): +class PlusOne(transforms.Map): def map(self, x: int) -> int: return x + 1 diff --git a/grain/experimental.py b/grain/experimental.py index a2668bdf..297c40b5 100644 --- a/grain/experimental.py +++ b/grain/experimental.py @@ -20,7 +20,7 @@ # pylint: disable=unused-import from grain._src.core.transforms import ( - FlatMapTransform, + FlatMap as FlatMapTransform, MapWithIndex as MapWithIndexTransform, ) diff --git a/grain/python/__init__.py b/grain/python/__init__.py index 22f0c5f9..0fd51238 100644 --- a/grain/python/__init__.py +++ b/grain/python/__init__.py @@ -36,9 +36,9 @@ from grain._src.core.transforms import ( Batch, Filter as FilterTransform, - MapTransform, + Map as MapTransform, MapWithIndex as MapWithIndexTransform, - RandomMapTransform, + RandomMap as RandomMapTransform, Transformation, Transformations, ) diff --git a/grain/transforms.py b/grain/transforms.py index 61b80e39..0e659854 100644 --- a/grain/transforms.py +++ b/grain/transforms.py @@ -21,10 +21,10 @@ from grain._src.core.transforms import ( Batch, Filter, - MapTransform as Map, + Map, MapWithIndex, - RandomMapTransform as RandomMap, - TfRandomMapTransform as TfRandomMap, + RandomMap, + TfRandomMap, Transformation, Transformations, )
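Note on usage after this rename: the hunks above change the abstract base class names in grain/_src/core/transforms.py (MapTransform -> Map, RandomMapTransform -> RandomMap, TfRandomMapTransform -> TfRandomMap, FlatMapTransform -> FlatMap), while grain/python/__init__.py and grain/experimental.py keep the old public names as aliases, so existing user code should continue to work. Below is a minimal sketch of transforms written against the renamed grain.transforms API, assuming the package exposes grain/transforms.py as grain.transforms; the AddOne/AddNoise class names are illustrative and not part of this change:

import dataclasses

import numpy as np

from grain import transforms


@dataclasses.dataclass(frozen=True)
class AddOne(transforms.Map):
  """1:1 transform; subclasses the renamed base (was transforms.MapTransform)."""

  def map(self, element: int) -> int:
    return element + 1


@dataclasses.dataclass(frozen=True)
class AddNoise(transforms.RandomMap):
  """Random 1:1 transform; subclasses RandomMap (was RandomMapTransform)."""

  def random_map(self, element: int, rng: np.random.Generator) -> int:
    return element + int(rng.integers(0, 10))

Instances of these classes plug into the unchanged entry points, e.g. as DataLoader operations or via MapDataset.map(AddOne()) / .random_map(AddNoise()), since apply_transformations and _apply_transform_to_dataset in the hunks above now match on the new class names.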