Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions grain/_src/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import numpy as np


class MapTransform(abc.ABC):
class Map(abc.ABC):
"""Abstract base class for all 1:1 transformations of elements.

Implementations should be threadsafe since they are often executed in
Expand All @@ -44,7 +44,7 @@ def map(self, element):
"""Maps a single element."""


class RandomMapTransform(abc.ABC):
class RandomMap(abc.ABC):
"""Abstract base class for all random 1:1 transformations of elements.

Implementations should be threadsafe since they are often executed in
Expand All @@ -68,7 +68,7 @@ def map_with_index(self, index: int, element):
"""Maps a single element with its index."""


class TfRandomMapTransform(abc.ABC):
class TfRandomMap(abc.ABC):
"""Abstract base class for all random 1:1 transformations of elements."""

@abc.abstractmethod
Expand All @@ -91,7 +91,7 @@ def filter(self, element) -> bool:
"""Filters a single element; returns True if the element should be kept."""


class FlatMapTransform(abc.ABC):
class FlatMap(abc.ABC):
"""Abstract base class for splitting operations of individual elements.

Implementations should be threadsafe since they are often executed in
Expand Down Expand Up @@ -124,11 +124,11 @@ class Batch:

Transformation = Union[
Batch,
MapTransform,
RandomMapTransform,
TfRandomMapTransform,
Map,
RandomMap,
TfRandomMap,
Filter,
FlatMapTransform,
FlatMap,
MapWithIndex,
]
Transformations = Sequence[Transformation]
Expand All @@ -150,11 +150,11 @@ def get_pretty_transform_name(
transform,
(
Batch,
MapTransform,
RandomMapTransform,
TfRandomMapTransform,
Map,
RandomMap,
TfRandomMap,
Filter,
FlatMapTransform,
FlatMap,
),
):
# Check if transform class defines `__str__` and `__repr__` and use them if
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/core/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __str__(self):
return "CustomStr"


class _TestMapWithRepr(transforms.MapTransform):
class _TestMapWithRepr(transforms.Map):

def map(self, x):
return x % 2 == 0
Expand Down
18 changes: 9 additions & 9 deletions grain/_src/python/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def use_context_if_available(obj):


@dataclasses.dataclass
class CopyNumPyArrayToSharedMemory(transforms.MapTransform):
class CopyNumPyArrayToSharedMemory(transforms.Map):
"""If `element` contains NumPy array copy it to SharedMemoryArray."""

def map(self, element: Any) -> Any:
Expand Down Expand Up @@ -412,7 +412,7 @@ def __init__(
logging.info("Enabling SharedMemoryArray for BatchOperation.")
operations[-1]._enable_shared_memory()
else:
logging.info("Adding CopyNumPyArrayToSharedMemory MapTransform.")
logging.info("Adding CopyNumPyArrayToSharedMemory Map.")
operations = list(operations) + [CopyNumPyArrayToSharedMemory()]

self._data_source = data_source
Expand Down Expand Up @@ -443,7 +443,7 @@ def __init__(
)
# pylint: enable=protected-access
self._use_native_dataset_checkpointing = any(
isinstance(op, transforms.FlatMapTransform) for op in self._operations
isinstance(op, transforms.FlatMap) for op in self._operations
)
self._dataset = self._create_dataset()

Expand Down Expand Up @@ -634,10 +634,10 @@ def _source_repr(source: RandomAccessDataSource) -> str:
return repr(source)


class _FlatMapAdapter(transforms.FlatMapTransform):
class _FlatMapAdapter(transforms.FlatMap):
"""Data loader adapter to pass through correct metadata."""

def __init__(self, transform: transforms.FlatMapTransform):
def __init__(self, transform: transforms.FlatMap):
self._transform = transform
self.max_fan_out = transform.max_fan_out

Expand All @@ -660,25 +660,25 @@ def _apply_transform_to_dataset(
ds: dataset.IterDataset,
) -> dataset.IterDataset:
"""Applies the `transform` to the dataset."""
if isinstance(transform, transforms.MapTransform):
if isinstance(transform, transforms.Map):
return ds.map(
lambda r: record.Record(metadata=r.metadata, data=transform.map(r.data))
)
elif isinstance(transform, transforms.RandomMapTransform):
elif isinstance(transform, transforms.RandomMap):
return ds.map(
lambda r: record.Record(
metadata=r.metadata,
data=transform.random_map(r.data, r.metadata.rng),
)
)
elif isinstance(transform, transforms.TfRandomMapTransform):
elif isinstance(transform, transforms.TfRandomMap):
return ds.map(
lambda r: record.Record(
metadata=r.metadata,
data=transform.np_random_map(r.data, r.metadata.rng),
)
)
elif isinstance(transform, transforms.FlatMapTransform):
elif isinstance(transform, transforms.FlatMap):
return flatmap.FlatMapIterDataset(ds, _FlatMapAdapter(transform))
elif isinstance(transform, transforms.Filter):
return ds.filter(lambda r: transform.filter(r.data))
Expand Down
14 changes: 7 additions & 7 deletions grain/_src/python/data_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,26 @@ def filter(self, x: int) -> bool:
return x % 2 == 0


class PlusOne(transforms.MapTransform):
class PlusOne(transforms.Map):

def map(self, x: int) -> int:
return x + 1


class PlusRandom(transforms.RandomMapTransform):
class PlusRandom(transforms.RandomMap):

def random_map(self, x: int, rng: np.random.Generator) -> int:
return x + rng.integers(100_000)


class FailingMap(transforms.MapTransform):
class FailingMap(transforms.Map):

def map(self, x):
del x
1 / 0 # pylint: disable=pointless-statement


class NonPickableTransform(transforms.MapTransform):
class NonPickableTransform(transforms.Map):

def __getstate__(self):
raise ValueError("I shall not be pickled")
Expand All @@ -101,13 +101,13 @@ def map(self, x):
return x


class RaisingTransform(transforms.MapTransform):
class RaisingTransform(transforms.Map):

def map(self, x):
raise AttributeError("I shall raise")


class ExitingTransform(transforms.MapTransform):
class ExitingTransform(transforms.Map):

def map(self, x):
raise sys.exit(123)
Expand All @@ -126,7 +126,7 @@ def __getitem__(self, record_key: int):
}


class DuplicateElementFlatMap(transforms.FlatMapTransform):
class DuplicateElementFlatMap(transforms.FlatMap):
max_fan_out: int = 7

def flat_map(self, element: Any) -> Any:
Expand Down
22 changes: 7 additions & 15 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,7 @@ def filter(
# pylint: enable=g-import-not-at-top
return filter_dataset.FilterMapDataset(parent=self, transform=transform)

def map(
self, transform: transforms.MapTransform | Callable[[T], S]
) -> MapDataset[S]:
def map(self, transform: transforms.Map | Callable[[T], S]) -> MapDataset[S]:
"""Returns a dataset containing the elements transformed by ``transform``.

Example usage::
Expand Down Expand Up @@ -717,9 +715,7 @@ def slice(self, sl: builtins.slice) -> MapDataset[T]:

def random_map(
self,
transform: (
transforms.RandomMapTransform | Callable[[T, np.random.Generator], S]
),
transform: transforms.RandomMap | Callable[[T, np.random.Generator], S],
*,
seed: int | None = None,
) -> MapDataset[S]:
Expand Down Expand Up @@ -1181,9 +1177,7 @@ def filter(
# pylint: enable=g-import-not-at-top
return filter_dataset.FilterIterDataset(parent=self, transform=transform)

def map(
self, transform: transforms.MapTransform | Callable[[T], S]
) -> IterDataset[S]:
def map(self, transform: transforms.Map | Callable[[T], S]) -> IterDataset[S]:
"""Returns a dataset containing the elements transformed by ``transform``.

Example usage::
Expand All @@ -1210,9 +1204,7 @@ def map(

def random_map(
self,
transform: (
transforms.RandomMapTransform | Callable[[T, np.random.Generator], S]
),
transform: transforms.RandomMap | Callable[[T, np.random.Generator], S],
*,
seed: int | None = None,
) -> IterDataset[S]:
Expand Down Expand Up @@ -1754,13 +1746,13 @@ def apply_transformations(
drop_remainder=transformation.drop_remainder,
batch_fn=transformation.batch_fn,
)
case transforms.MapTransform():
case transforms.Map():
ds = ds.map(transformation)
case transforms.RandomMapTransform():
case transforms.RandomMap():
ds = ds.random_map(transformation)
case transforms.MapWithIndex():
ds = ds.map_with_index(transformation)
case transforms.FlatMapTransform():
case transforms.FlatMap():
# Loaded lazily due to a circular dependency (dataset <-> flatmap).
# pylint: disable=g-import-not-at-top
from grain._src.python.dataset.transformations import flatmap
Expand Down
26 changes: 13 additions & 13 deletions grain/_src/python/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,21 @@ def filter(self, element: int) -> bool:


@dataclasses.dataclass(frozen=True)
class RandomMapAddingRandomInt(transforms.RandomMapTransform):
class RandomMapAddingRandomInt(transforms.RandomMap):

def random_map(self, element: int, rng: np.random.Generator) -> int:
return element + rng.integers(0, 100)


@dataclasses.dataclass(frozen=True)
class RandomMapAlwaysAddingOne(transforms.RandomMapTransform):
class RandomMapAlwaysAddingOne(transforms.RandomMap):

def random_map(self, element: int, rng: np.random.Generator) -> int:
return element + 1


@dataclasses.dataclass(frozen=True)
class MapTransformAddingOne(transforms.MapTransform):
class MapAddingOne(transforms.Map):

def map(self, element: int) -> int:
return element + 1
Expand Down Expand Up @@ -97,7 +97,7 @@ def __getitem__(self, index):
return (index + 1) % 2, index // 2


class AddRandomInteger(transforms.RandomMapTransform):
class AddRandomInteger(transforms.RandomMap):

def random_map(self, element, rng):
return element + rng.integers(low=0, high=100)
Expand Down Expand Up @@ -1044,7 +1044,7 @@ def test_map_has_one_parent(self, initial_ds):
dataset.MapDataset.range(15),
dataset.MapDataset.range(15).to_iter_dataset(),
],
transform=[MapTransformAddingOne(), lambda x: x + 1],
transform=[MapAddingOne(), lambda x: x + 1],
)
def test_map_produces_correct_elements(self, initial_ds, transform):
ds = initial_ds.map(transform)
Expand Down Expand Up @@ -1173,7 +1173,7 @@ def test_concatenate(self):
(dataset.MapDataset.range(5).to_iter_dataset(),),
)
def test_apply(self, ds):
ds = ds.apply([MapTransformAddingOne(), transforms.Batch(2)])
ds = ds.apply([MapAddingOne(), transforms.Batch(2)])
np.testing.assert_equal(
list(ds),
[
Expand All @@ -1184,7 +1184,7 @@ def test_apply(self, ds):
)


class TfRandomMapAlwaysAddingOne(transforms.TfRandomMapTransform):
class TfRandomMapAlwaysAddingOne(transforms.TfRandomMap):

def np_random_map(self, x, rng):
return x + 1
Expand All @@ -1196,7 +1196,7 @@ def filter(self, x):
return np.sum(x) < 20


class FlatMapAddingOne(transforms.FlatMapTransform):
class FlatMapAddingOne(transforms.FlatMap):

max_fan_out = 2

Expand All @@ -1211,7 +1211,7 @@ class ApplyTransformationsTest(parameterized.TestCase):
(dataset.MapDataset.range(15).to_iter_dataset(),),
)
def test_single_transform(self, ds):
ds = dataset.apply_transformations(ds, MapTransformAddingOne())
ds = dataset.apply_transformations(ds, MapAddingOne())
self.assertSequenceEqual(list(ds), list(range(1, 16)))

@parameterized.parameters(
Expand All @@ -1223,7 +1223,7 @@ def test_multiple_transforms(self, ds):
ds = dataset.apply_transformations(
ds,
[
MapTransformAddingOne(),
MapAddingOne(),
RandomMapAlwaysAddingOne(),
transforms.Batch(batch_size=2, drop_remainder=True),
FilterArraysWithLargeSum(),
Expand Down Expand Up @@ -1359,7 +1359,7 @@ def test_get_execution_summary_without_collection(self):
def test_execution_summary_with_logging(self):
with self.assertLogs(level="INFO") as logs:
ds = dataset.MapDataset.range(10).shuffle(42)
ds = ds.map(MapTransformAddingOne())
ds = ds.map(MapAddingOne())
ds = ds.to_iter_dataset()
it = ds.__iter__()
# Get execution summary after iterating through the dataset.
Expand All @@ -1377,7 +1377,7 @@ def test_execution_summary_with_logging(self):
def test_execution_summary_with_no_logging(self):
with self.assertLogs(level="INFO") as logs:
ds = dataset.MapDataset.range(10).shuffle(42)
ds = ds.map(MapTransformAddingOne())
ds = ds.map(MapAddingOne())
ds = ds.to_iter_dataset()
ds = dataset.WithOptionsIterDataset(
ds,
Expand All @@ -1400,7 +1400,7 @@ def worker_init_fn_wrapper(worker_index, worker_count):
del worker_index, worker_count
dataset_stats._REPORTING_PERIOD_SEC = 0.05

ds = dataset.MapDataset.range(10000).map(MapTransformAddingOne())
ds = dataset.MapDataset.range(10000).map(MapAddingOne())
ds = ds.to_iter_dataset()
ds = ds.mp_prefetch(
options.MultiprocessingOptions(num_workers=1),
Expand Down
2 changes: 1 addition & 1 deletion grain/_src/python/dataset/stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _identity(x):
return x


class _AddOne(transforms.MapTransform):
class _AddOne(transforms.Map):

def map(self, x):
return x + 1
Expand Down
Loading