diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD
index 29734f51..68a6a4c5 100644
--- a/grain/_src/python/dataset/transformations/BUILD
+++ b/grain/_src/python/dataset/transformations/BUILD
@@ -53,7 +53,7 @@ py_test(
 py_test(
     name = "prefetch_test",
-    timeout = "long",
+    timeout = "eternal",
     srcs = ["prefetch_test.py"],
     shard_count = 50,
     srcs_version = "PY3",
diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py
index 03001eb4..1b09dee6 100644
--- a/grain/_src/python/dataset/transformations/prefetch.py
+++ b/grain/_src/python/dataset/transformations/prefetch.py
@@ -1017,8 +1017,8 @@ def __str__(self) -> str:
 def multithread_prefetch(
     ds: dataset.IterDataset[T],
-    num_threads: int,
-    buffer_size: int,
+    num_threads: int = 0,
+    buffer_size: int = 1,
     sequential_slice: bool = False,
 ) -> dataset.IterDataset[T]:
   """Uses a pool of threads to prefetch elements ahead of time.
@@ -1043,14 +1043,17 @@
   if num_threads == 0:
     return ds
-  _validate_no_double_prefetch(ds)
+  dataset_options = _get_dataset_options(ds)
   shards = []
   for i in range(num_threads):
-    worker_ds = copy.deepcopy(ds)
-    _set_slice_iter_dataset(
-        worker_ds, slice(i, None, num_threads), sequential_slice
-    )
+    if num_threads == 1:
+      worker_ds = ds
+    else:
+      worker_ds = copy.deepcopy(ds)
+      _set_slice_iter_dataset(
+          worker_ds, slice(i, None, num_threads), sequential_slice
+      )
     shards.append(
         _MpContextIterDataset(
             worker_ds,
         )
     )
@@ -1061,6 +1064,10 @@
   )
-  return interleave.InterleaveIterDataset(
+  ds = interleave.InterleaveIterDataset(
       shards, cycle_length=num_threads, iter_buffer_size=buffer_size
   )
+  # Apply options from parent dataset because interleave dataset does not
+  # propagate options.
+  ds = dataset.WithOptionsIterDataset(ds, dataset_options)
+  return ds
diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py
index 597d8565..046f646b 100644
--- a/grain/_src/python/dataset/transformations/prefetch_test.py
+++ b/grain/_src/python/dataset/transformations/prefetch_test.py
@@ -1256,146 +1256,356 @@ def set_state(self, state):
     self._parent.set_state(state)

-class MultithreadPrefetchIterDatasetTest(parameterized.TestCase):
+class MultithreadPrefetchTest(parameterized.TestCase):

   def setUp(self):
     super().setUp()
-    self.ds = dataset.MapDataset.range(20).to_iter_dataset()
+    ds = dataset.MapDataset.range(20)
+    self.iter_ds = ds.to_iter_dataset().filter(FilterKeepingOddElementsOnly())

   @parameterized.named_parameters(
       dict(
-          testcase_name='no_prefetch',
-          num_workers=0,
-          per_worker_buffer_size=0,
+          testcase_name='0_workers',
+          num_threads=0,
+          per_worker_buffer_size=1,
       ),
       dict(
-          testcase_name='thread',
-          num_workers=1,
+          testcase_name='1_worker',
+          num_threads=1,
           per_worker_buffer_size=1,
       ),
       dict(
-          testcase_name='2_threads_large_buffer',
-          num_workers=2,
+          testcase_name='1_worker_large_buffer',
+          num_threads=1,
           per_worker_buffer_size=20,
       ),
       dict(
-          testcase_name='4_threads_huge_buffer',
-          num_workers=4,
-          per_worker_buffer_size=200,
+          testcase_name='10_workers',
+          num_threads=10,
+          per_worker_buffer_size=1,
+      ),
+      dict(
+          testcase_name='10_workers_large_buffer',
+          num_threads=10,
+          per_worker_buffer_size=20,
       ),
   )
-  def test_prefetch_data(self, num_workers: int, per_worker_buffer_size: int):
+  def test_prefetch_data(self, num_threads: int, per_worker_buffer_size: int):
     prefetch_lazy_iter_ds = prefetch.multithread_prefetch(
-        self.ds,
-        num_threads=num_workers,
+        self.iter_ds,
+        num_threads=num_threads,
         buffer_size=per_worker_buffer_size,
     )
-    ds_iter = prefetch_lazy_iter_ds.__iter__()
-    if num_workers > 0:
-      ds_iter.start_prefetch()
-    actual = list(ds_iter)
-    expected = list(range(20))
+    actual = list(prefetch_lazy_iter_ds)
+    expected = list(range(1, 20, 2))
     self.assertSequenceEqual(actual, expected)

-  def test_checkpoint(self):
+  def test_prefetch_size_zero_data(self):
+    ds = dataset.MapDataset.source(
+        [np.zeros(shape=(0,), dtype=np.int64)]
+    ).repeat(3)
+    iter_ds = ds.to_iter_dataset()
+    prefetch_lazy_iter_ds = prefetch.multithread_prefetch(
+        iter_ds,
+        num_threads=1,
+    )
+    actual = list(prefetch_lazy_iter_ds)
+    expected = [np.zeros(shape=(0,), dtype=np.int64)] * 3
+    self.assertLen(actual, 3)
+    self.assertLen(expected, 3)
+    for i in range(3):
+      np.testing.assert_array_equal(actual[i], expected[i])
+
+  @parameterized.product(
+      (
+          dict(num_threads=0),
+          dict(num_threads=1),
+          dict(num_threads=10),
+      ),
+      step_index=[0, 3, 8],
+  )
+  def test_checkpoint(self, num_threads: int, step_index: int):
     ds = prefetch.multithread_prefetch(
-        self.ds,
-        num_threads=2,
-        buffer_size=5,
+        self.iter_ds,
+        num_threads=num_threads,
     )
     ds_iter = ds.__iter__()
-    ds_iter.start_prefetch()

-    max_steps = 20
+    max_steps = 10
     values_without_interruption = []
     checkpoints = []
     for _ in range(max_steps):
       checkpoints.append(ds_iter.get_state())
       values_without_interruption.append(next(ds_iter))

-    for starting_step in [0, 5, 13, 19]:
-      ds_iter.set_state(checkpoints[starting_step])
-      ds_iter.start_prefetch()
-      for i in range(starting_step, max_steps):
-        value = next(ds_iter)
-        print(value)
-        self.assertEqual(value, values_without_interruption[i])
+    ds_iter.set_state(checkpoints[step_index])
+    for i in range(step_index, max_steps):
+      value = next(ds_iter)
+      self.assertEqual(value, values_without_interruption[i])

-  def test_set_state_on_fresh_iterator(self):
+  def test_set_state_twice(self):
     ds = prefetch.multithread_prefetch(
-        self.ds,
+        self.iter_ds,
         num_threads=2,
-        buffer_size=2,
     )
     ds_iter = ds.__iter__()
-    ds_iter.start_prefetch()

-    max_steps = 20
+    max_steps = 10
     values_without_interruption = []
     checkpoints = []
     for _ in range(max_steps):
       checkpoints.append(ds_iter.get_state())
       values_without_interruption.append(next(ds_iter))

-    for starting_step in [0, 5, 13, 19]:
-      ds_iter = ds.__iter__()
+    for starting_step in [0, 3, 8]:
       ds_iter.set_state(checkpoints[starting_step])
-      ds_iter.start_prefetch()
       for i in range(starting_step, max_steps):
         value = next(ds_iter)
         self.assertEqual(value, values_without_interruption[i])

-  def test_get_state_doesnt_start_prefetch(self):
-    event = threading.Event()
+  def test_works_with_iter_source_single_worker(self):
+    # Even though a pure IterDataset cannot be sliced, we should still be able
+    # to multithread-prefetch it with a single worker, since that doesn't
+    # require any slicing.
+    ds = prefetch.multithread_prefetch(
+        RepeatedIntSourceIterDataset().map(lambda x: x + 1),
+        num_threads=1,
+    )
+    ds_iter = iter(ds)
+    self.assertEqual(next(ds_iter), 2)

-    def f(x):
-      event.set()
-      return x
+  def test_fails_with_iter_source_multiple_workers(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        'Cannot slice `IterDataset` source.',
+    ):
+      prefetch.multithread_prefetch(
+          RepeatedIntSourceIterDataset().map(lambda x: x + 1),
+          num_threads=2,
+      )
+
+  def test_propagates_transform_error(self):
+    error_msg = 'I shall fail!'
+
+    def failing_transform(element):
+      del element
+      raise ValueError(error_msg)

-    ds = dataset.MapDataset.source([1, 2, 3]).map(f).to_iter_dataset()
     ds = prefetch.multithread_prefetch(
-        ds,
-        num_threads=2,
-        buffer_size=10,
+        self.iter_ds.map(failing_transform),
+        num_threads=1,
     )
-    it = ds.__iter__()
-    it.get_state()
-    time.sleep(1)
-    self.assertFalse(event.is_set())
+    with self.assertRaisesRegex(Exception, error_msg):
+      list(ds)

-  def test_does_not_hang_after_stop_iteration(self):
-    ds = dataset.MapDataset.source([1, 2, 3]).repeat(100).to_iter_dataset()
+  @parameterized.product(
+      start_prefetch_calls=[0, 1, 10],
+      num_threads=[6],
+      per_worker_buffer_size=[1, 20],
+  )
+  def test_start_prefetch(
+      self,
+      start_prefetch_calls: int,
+      num_threads: int,
+      per_worker_buffer_size: int,
+  ):
+    class _SleepTransform(transforms.MapTransform):
+
+      def map(self, features):
+        time.sleep(1)
+        return features
+
+    ds = dataset.MapDataset.range(10)
+    ds = ds.map(_SleepTransform())
+    ds = ds.to_iter_dataset()
     ds = prefetch.multithread_prefetch(
         ds,
-        num_threads=2,
-        buffer_size=10,
+        num_threads=num_threads,
+        buffer_size=per_worker_buffer_size,
     )
+
+    it = ds.__iter__()
+    for _ in range(start_prefetch_calls):
+      it.start_prefetch()
+
+    # Waits for prefetching.
+    start_time = time.time()
+    while time.time() - start_time < 30:
+      time.sleep(2)
+
+    # Measures time to read from the dataset.
+    start_time = time.time()
+    self.assertSequenceEqual(list(it), list(range(10)))
+
+    time_to_fetch = time.time() - start_time
+    logging.info('Reading dataset took %.2f seconds.', time_to_fetch)
+    # Note that we can't reliably assert an upper bound on the time it takes
+    # to read the dataset elements since worker startup time can vary a lot.
+    if not start_prefetch_calls:
+      self.assertGreater(time_to_fetch, 1)
+
+  @parameterized.parameters(0, 0.5, 30)
+  def test_prefetch_but_no_read(self, sleep_s):
+    ds = dataset.MapDataset.source([1, 2, 3]).repeat()
+    ds = ds.filter(lambda x: x > 3)
+    ds = ds.to_iter_dataset()
+    ds = prefetch.multithread_prefetch(ds, num_threads=1)
     it = ds.__iter__()
     it.start_prefetch()
+    time.sleep(sleep_s)
+    del it

-  def test_fails_with_multiprocess_prefetch_parent(self):
-    ds = prefetch.MultiprocessPrefetchIterDataset(
-        self.ds,
-        options.MultiprocessingOptions(num_workers=2),
+  def test_prefetch_with_random_map(self):
+    ds = dataset.MapDataset.source([0]).repeat(100).to_iter_dataset()
+    ds = ds.random_map(lambda x, rng: x + rng.integers(sys.maxsize), seed=42)
+    ds = prefetch.multithread_prefetch(
+        ds,
+        num_threads=5,
     )
-    with self.assertRaisesRegex(
-        ValueError,
-        'Nesting multiprocessing or multithreading is not allowed.',
-    ):
-      _ = prefetch.multithread_prefetch(
-          ds,
-          num_threads=1,
-          buffer_size=1,
-      )
+    # Make sure that sliced datasets on workers are seeded differently and
+    # thus produce different random elements.
+    elements = list(ds)
+    distinct_elements = set(elements)
+    self.assertLen(distinct_elements, len(elements))
+
+  def test_concurrent_start_prefetch(self):
+    num_iters = 10  # Can't set this much higher without Forge OOMing.
+
+    def make_iter(i):
+      ds = dataset.MapDataset.source([i])
+      ds = ds.to_iter_dataset()
+      ds = prefetch.multithread_prefetch(ds, num_threads=1)
+      return ds.__iter__()
+
+    iters = [make_iter(i) for i in range(num_iters)]
+    with futures.ThreadPoolExecutor(max_workers=num_iters) as executor:
+      for it in iters:
+        executor.submit(it.start_prefetch)
+    for it in iters:
+      _ = next(it)
+
+  def test_options_before_prefetch(self):
+    ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000)
+    ds = ds.to_iter_dataset()
+    ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
+    ds = dataset.WithOptionsIterDataset(ds, ds_options)
+    ds = prefetch.multithread_prefetch(ds, num_threads=1)
+    ds = ds.filter(lambda x: x > 2)
+    with self.assertRaises(Exception):
+      list(ds)
+
+  def test_multithread_prefetch_with_sequential_slice(self):
+    ds = dataset.MapDataset.source(range(10)).to_iter_dataset()
+    ds = prefetch.multithread_prefetch(
+        ds,
+        num_threads=3,
+        buffer_size=1,
+        sequential_slice=True,
+    )
+    self.assertEqual(list(ds), [0, 4, 7, 1, 5, 8, 2, 6, 9, 3])
+
+  def test_multithread_prefetch_with_default_slice_non_sequential(self):
+    ds = dataset.MapDataset.source(range(10)).to_iter_dataset()
+    ds_sequential_off = prefetch.multithread_prefetch(
+        ds,
+        num_threads=3,
+        buffer_size=1,
+        sequential_slice=False,
+    )
+    ds_sequential_default = prefetch.multithread_prefetch(
+        ds,
+        num_threads=3,
+        buffer_size=1,
+    )
+    elements_sequential_off = list(ds_sequential_off)
+    elements_sequential_default = list(ds_sequential_default)
+    self.assertEqual(
+        elements_sequential_off,
+        elements_sequential_default,
+    )
+    self.assertEqual(
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+        elements_sequential_default,
+    )
+
+  def test_multithread_prefetch_sequential_slice_order_from_source(self):
+    ds = dataset.MapDataset.source(range(10)).to_iter_dataset()
+    ds_sequential_on = prefetch.multithread_prefetch(
+        ds,
+        num_threads=3,
+        buffer_size=1,
+        sequential_slice=True,
+    )
+    elements_sequential_on = list(ds_sequential_on)
+    self.assertEqual([0, 4, 7, 1, 5, 8, 2, 6, 9, 3], elements_sequential_on)
+
+  def test_multithread_prefetch_sequential_slice_order_from_range(self):
+    ds_range = dataset.MapDataset.range(10).to_iter_dataset()
+    ds_range_sequential_on = prefetch.multithread_prefetch(
+        ds_range,
+        num_threads=3,
+        buffer_size=1,
+        sequential_slice=True,
+    )
+    elements_range_sequential_on = list(ds_range_sequential_on)
+    self.assertEqual(
+        [0, 4, 7, 1, 5, 8, 2, 6, 9, 3],
+        elements_range_sequential_on,
+    )
+
+  def test_multithread_prefetch_sequential_slice_order_from_range_slice(self):
+    ds_range = dataset.MapDataset.range(
+        start=2, stop=21, step=3
+    ).to_iter_dataset()
+    ds_range_sequential_on = prefetch.multithread_prefetch(
+        ds_range,
+        num_threads=3,
+        buffer_size=1,
+        sequential_slice=True,
+    )
+    elements_range_sequential_on = list(ds_range_sequential_on)
+    self.assertEqual(
+        [2, 11, 17, 5, 14, 20, 8],
+        elements_range_sequential_on,
+    )
+
+  def test_multithread_prefetch_sequential_slice_order_same(self):
+    ds_source = dataset.MapDataset.source(range(10)).to_iter_dataset()
+    ds_range = dataset.MapDataset.range(10).to_iter_dataset()
+    ds_source_mp = prefetch.multithread_prefetch(
+        ds_source,
+        num_threads=3,
+        buffer_size=1,
+        sequential_slice=True,
+    )
+    ds_range_mp = prefetch.multithread_prefetch(
+        ds_range,
+        num_threads=3,
+        buffer_size=1,
+        sequential_slice=True,
+    )
+    elements_source = list(ds_source_mp)
+    elements_range = list(ds_range_mp)
+    self.assertEqual(elements_source, elements_range)
+
+  def test_options_after_prefetch(self):
+    ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000)
+    ds = ds.filter(lambda x: x > 2)
+    ds = ds.to_iter_dataset()
+    ds = prefetch.multithread_prefetch(ds, num_threads=1)
+    ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
+    ds = dataset.WithOptionsIterDataset(ds, ds_options)
+    with self.assertRaises(Exception):
+      list(ds)

   def test_mp_context_is_set_correctly(self):
-    num_workers = 4
+    num_threads = 4
     ds = dataset.MapDataset.range(20).to_iter_dataset()
     ds = _MpContextCheckIterDataset(ds)
     ds = ds.map(lambda x: x)
     ds = prefetch.multithread_prefetch(
         ds,
-        num_threads=num_workers,
+        num_threads=num_threads,
         buffer_size=1,
     )
@@ -1408,8 +1618,8 @@ def test_mp_context_is_set_correctly(self):

     # Check mp_context.
     for i, (_, context) in enumerate(results):
-      self.assertEqual(context.process_index, i % num_workers)
-      self.assertEqual(context.process_count, num_workers)
+      self.assertEqual(context.process_index, i % num_threads)
+      self.assertEqual(context.process_count, num_threads)

 if __name__ == '__main__':
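
For reference, a minimal usage sketch of the new defaults and the `sequential_slice` flag. It relies only on behavior asserted by the tests above; the import paths are an assumption based on the test module's use of `dataset` and `prefetch`.

    # Sketch only: import paths assumed from the test module's usage.
    from grain._src.python.dataset import dataset
    from grain._src.python.dataset.transformations import prefetch

    ds = dataset.MapDataset.source(range(10)).to_iter_dataset()

    # num_threads now defaults to 0, in which case multithread_prefetch
    # returns the input dataset unchanged.
    assert prefetch.multithread_prefetch(ds) is ds

    # Default strided slicing: thread i reads elements i, i + n, i + 2n, ...,
    # so the round-robin interleave reproduces the source order.
    strided = prefetch.multithread_prefetch(ds, num_threads=3, buffer_size=1)
    assert list(strided) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    # sequential_slice=True: thread i reads a contiguous block ([0..3],
    # [4..6], [7..9]), so interleaving yields 0, 4, 7, 1, 5, 8, 2, 6, 9, 3.
    blocked = prefetch.multithread_prefetch(
        ds, num_threads=3, buffer_size=1, sequential_slice=True
    )
    assert list(blocked) == [0, 4, 7, 1, 5, 8, 2, 6, 9, 3]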
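The options-propagation fix can be exercised the same way; below is a sketch of the scenario from `test_options_before_prefetch`, assuming `DatasetOptions` lives in `grain._src.python.dataset.base` as the tests' `base` alias suggests.

    # Sketch only: `base` import path is an assumption.
    from grain._src.python.dataset import base, dataset
    from grain._src.python.dataset.transformations import prefetch

    ds = dataset.MapDataset.source([1, 2, 3]).repeat(1000).to_iter_dataset()
    # Attach options upstream of the prefetch.
    ds = dataset.WithOptionsIterDataset(
        ds, base.DatasetOptions(filter_raise_threshold_ratio=0.1)
    )
    ds = prefetch.multithread_prefetch(ds, num_threads=1)
    # Because multithread_prefetch now re-applies the parent options on top of
    # the interleave, a downstream filter that drops ~2/3 of elements exceeds
    # the 10% threshold and raises instead of filtering silently.
    ds = ds.filter(lambda x: x > 2)
    try:
      list(ds)  # Expected to raise once the filter threshold is hit.
    except Exception as e:  # pylint: disable=broad-except
      print(f'filter threshold enforced past the prefetch boundary: {e}')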