-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy patharray_record_data_source.py
More file actions
407 lines (352 loc) · 14.1 KB
/
array_record_data_source.py
File metadata and controls
407 lines (352 loc) · 14.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""array_record_data_source module.
Warning: this is an experimental module. The interface might change in the
future without backwards compatibility.

Data source is an abstraction that is responsible for retrieving data records
from a storage backend in ML workloads (e.g. a set of files, a database). It
implements a simple Python interface to query ArrayRecord files:

```
class RandomAccessDataSource(Protocol, Generic[T]):
  def __len__(self) -> int:
    ...

  def __getitem__(self, record_key: SupportsIndex) -> T:
    ...

  def __getitems__(self, record_keys: Sequence[SupportsIndex]) -> Sequence[T]:
    ...
```
"""
import bisect
from concurrent import futures
import dataclasses
import hashlib
import itertools
import os
import pathlib
import re
import typing
from typing import Any, Callable, Iterator, List, Mapping, Protocol, Sequence, SupportsIndex, Tuple, TypeVar, Union
from absl import flags
from absl import logging
from etils import epath
from etils import epy
from python import array_record_module
# TODO(jolesiak): Decide what to do with these flags, e.g., remove them (could
# be appropriate if we decide to use asyncio) or move them somewhere else and
# pass the number of threads as an argument. For now, since we experiment, it's
# convenient to have them.
#
# Both flags size thread pools that need at least one worker, so values below 1
# are rejected when the flags are parsed instead of failing later at read time.
_GRAIN_NUM_THREADS_COMPUTING_NUM_RECORDS = flags.DEFINE_integer(
    "grain_num_threads_computing_num_records",
    64,
    (
        "The number of threads used to fetch file instructions (i.e., the max"
        " number of Array Record files opened while calculating the total"
        " number of records)."
    ),
    lower_bound=1,
)
_GRAIN_NUM_THREADS_FETCHING_RECORDS = flags.DEFINE_integer(
    "grain_num_threads_fetching_records",
    64,
    (
        "The number of threads used to fetch records from Array Record files. "
        "(i.e., the max number of Array Record files opened while fetching "
        "records)."
    ),
    lower_bound=1,
)
T = TypeVar("T")
def _run_in_parallel(
function: Callable[..., T],
list_of_kwargs_to_function: Sequence[Mapping[str, Any]],
num_workers: int,
) -> List[T]:
"""Runs `function` in parallel threads with given keyword arguments.
This is useful for performing IO in parallel. CPU bound functions will likely
not be faster.
Args:
function: The function to execute in parallel.
list_of_kwargs_to_function: A list of dicts mapping from string to argument
value. These will be passed into `function` as kwargs.
num_workers: Number of threads in the thread pool.
Returns:
list of return values from function, in the same order as the arguments in
list_of_kwargs_to_function.
"""
if num_workers < 1:
raise ValueError("num_workers must be >=1 for parallelism.")
thread_futures = []
with futures.ThreadPoolExecutor(num_workers) as executor:
for kwargs in list_of_kwargs_to_function:
future = executor.submit(function, **kwargs)
thread_futures.append(future)
futures_as_completed = futures.as_completed(thread_futures)
for completed_future in futures_as_completed:
if completed_future.exception():
# Cancel all remaining futures, if possible. In Python>3.8, you can call
# `executor.shutdown(cancel_futures=True)`.
for remaining_future in thread_futures:
remaining_future.cancel()
raise completed_future.exception()
return [future.result() for future in thread_futures]
@dataclasses.dataclass(frozen=True)
class _ReadInstruction:
"""Internal class used to keep track of files and records to read from them."""
filename: str
start: int
end: int
num_records: int = dataclasses.field(init=False)
def __post_init__(self):
object.__setattr__(self, "num_records", self.end - self.start)
@typing.runtime_checkable
class FileInstruction(Protocol):
  """Protocol with same interface as FileInstruction returned by TFDS.

  ArrayRecordDataSource accepts objects implementing this protocol without
  depending on TFDS. Because the protocol is runtime-checkable, `isinstance`
  checks only verify that these attributes exist, not their types.
  """

  # Path of the ArrayRecord file to read.
  filename: str
  # Number of records to skip at the beginning of the file.
  skip: int
  # Number of records to read after the skipped ones.
  take: int
  # Presumably the total number of records in the file (TFDS naming).
  examples_in_shard: int


# Anything that identifies (part of) an ArrayRecord file to read: a path or an
# object implementing the FileInstruction protocol.
PathLikeOrFileInstruction = Union[epath.PathLike, FileInstruction]
def _get_read_instructions(
    paths: Sequence[PathLikeOrFileInstruction],
) -> Sequence[_ReadInstruction]:
  """Constructs ReadInstructions for given paths.

  Each element of `paths` is resolved to a `_ReadInstruction`:

  * A `FileInstruction` reads `take` records starting at position `skip`.
    `take == -1` (TFDS's "read all") reads until `examples_in_shard`.
  * A path of the form `"<filename>[<start>:<end>]"` reads records in
    `[start, end)` of `<filename>`.
  * Any other path reads the whole file; the file is opened once to count its
    records.

  Elements are resolved in parallel threads, at most
  `--grain_num_threads_computing_num_records` at a time.

  Args:
    paths: Paths and/or FileInstructions to read from.

  Returns:
    One `_ReadInstruction` per element of `paths`, in the same order.
  """

  def get_read_instruction(path: PathLikeOrFileInstruction) -> _ReadInstruction:
    """Resolves a single path/FileInstruction to a `_ReadInstruction`."""
    if isinstance(path, FileInstruction):
      start = path.skip
      if path.take == -1:
        # TFDS encodes "read the rest of the shard" as take=-1; without this,
        # `end` would land before `start`.
        end = path.examples_in_shard
      else:
        end = path.skip + path.take
      path = os.fspath(path.filename)
    elif m := re.fullmatch(r"(.*)\[(\d+):(\d+)\]", os.fspath(path)):
      path = m.group(1)
      start = int(m.group(2))
      end = int(m.group(3))
    else:
      path = os.fspath(path)
      reader = array_record_module.ArrayRecordReader(path)
      try:
        start = 0  # Using whole file.
        end = reader.num_records()
      finally:
        # Close the reader even if counting the records fails.
        reader.close()
    return _ReadInstruction(path, start, end)

  if not paths:
    # Nothing to resolve; `_run_in_parallel` would reject a pool of 0 workers.
    return []
  num_threads = _get_flag_value(_GRAIN_NUM_THREADS_COMPUTING_NUM_RECORDS)
  num_workers = min(len(paths), num_threads)
  return _run_in_parallel(
      function=get_read_instruction,
      list_of_kwargs_to_function=[{"path": path} for path in paths],
      num_workers=num_workers,
  )
def _create_reader(filename: epath.PathLike):
  """Opens `filename` as an ArrayRecordReader with readahead disabled.

  The reader uses the option string `readahead_buffer_size:0` and a
  32768-byte file reader buffer.
  """
  reader_options = "readahead_buffer_size:0"
  buffer_bytes = 32768
  return array_record_module.ArrayRecordReader(
      filename, options=reader_options, file_reader_buffer_size=buffer_bytes
  )
def _check_group_size(
    filename: epath.PathLike, reader: array_record_module.ArrayRecordReader
) -> None:
  """Logs an error if the group size of the underlying file is not 1.

  Args:
    filename: Path of the file `reader` was opened on; only used in messages.
    reader: Open reader whose writer options string is inspected.

  Raises:
    ValueError: If the writer options are non-empty but contain no
      `group_size:<int>` entry.
  """
  options = reader.writer_options_string()
  # The ArrayRecord Python API does not include methods to parse the options.
  # We will likely move this to C++ soon. In the meantime, we just extract the
  # 'group_size:<int>' entry from the comma-separated options string.
  # The string might be empty for old files written before October 2022.
  if not options:
    return
  # `\b` keeps e.g. "foo_group_size:" from matching. Unlike requiring a trailing
  # comma, this also finds the entry when it is the last option in the string.
  match = re.search(r"\bgroup_size:(\d+)", options)
  if not match:
    raise ValueError(
        f"Couldn't detect group_size for {filename}. Extracted writer options:"
        f" {options}."
    )
  group_size = int(match[1])
  if group_size != 1:
    logging.error(
        (
            "File %s was created with group size %s. Grain requires group size"
            " 1 for good performance. Please re-generate your ArrayRecord files"
            " with 'group_size:1'."
        ),
        filename,
        group_size,
    )
class ArrayRecordDataSource:
  """Datasource for ArrayRecord files.

  Records of all given files (or file slices) are exposed as one flat,
  zero-indexed sequence of `bytes`. Readers are opened lazily on first access
  and are closed when the data source is used as a context manager and exits.
  """

  def __init__(
      self,
      paths: Union[
          PathLikeOrFileInstruction, Sequence[PathLikeOrFileInstruction]
      ],
  ):
    """Creates a new ArrayRecordDataSource object.

    Note on the terminology:
    * record_key: This is the global key of a record in a list of files.
    * position: position of a record within a specific file.

    For example, assume we have two files: my_file-00000-of-00002 and
    my_file-00001-of-00002. If both files have 100 records each, then we can
    read keys in [0, 199] (record_keys can be anywhere in that range).
    record_key 40 will map to the record at position 40 in
    my_file-00000-of-00002 and key 121 would map to the record at position 21
    in my_file-00001-of-00002.

    Args:
      paths: This can be a single path/FileInstruction or list of
        paths/FileInstructions. A path may be a string or any `os.PathLike`
        (e.g. `pathlib.Path`). When you want to read subsets or have a large
        number of files prefer to pass FileInstructions. This makes the
        initialization faster.

    Raises:
      ValueError: If `paths` is an empty sequence, contains elements of an
        unsupported type, or is itself of an unsupported type.
    """
    # `os.PathLike` covers `pathlib.Path` as well as other path objects that
    # implement `__fspath__`, matching the declared `epath.PathLike` type.
    if isinstance(paths, (str, os.PathLike, FileInstruction)):
      paths = [paths]
    elif isinstance(paths, Sequence):
      # Validate correct format of a sequence path
      if not paths:
        raise ValueError("Paths sequence can not be of 0 length")
      if not all(
          isinstance(path, (str, os.PathLike, FileInstruction))
          for path in paths
      ):
        raise ValueError(
            "All elements in a path sequence must be of type: String,"
            " os.PathLike, or FileInstruction."
        )
    else:
      raise ValueError(
          "Unsupported path format was used. Path format must be "
          "a Sequence, String, os.PathLike or FileInstruction."
      )
    self._read_instructions = _get_read_instructions(paths)
    self._paths = [ri.filename for ri in self._read_instructions]
    # We open readers lazily when we need to read from them.
    self._readers = [None] * len(self._read_instructions)
    records_per_instruction = [
        ri.num_records for ri in self._read_instructions
    ]
    self._num_records = sum(records_per_instruction)
    # _prefix_sums[i] is the total number of records in instructions 0..i; a
    # binary search over it maps a global record key to its instruction.
    self._prefix_sums = list(itertools.accumulate(records_per_instruction))

  def __enter__(self):
    logging.debug("__enter__ for ArrayRecordDataSource is called.")
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    """Closes all readers opened so far; they are reopened lazily if needed."""
    logging.debug("__exit__ for ArrayRecordDataSource is called.")
    for reader in self._readers:
      # Compare with None explicitly so every opened reader is closed,
      # regardless of how the reader object behaves in a boolean context.
      if reader is not None:
        reader.close()
    self._readers = [None] * len(self._read_instructions)

  def __len__(self) -> int:
    return self._num_records

  def __iter__(self) -> Iterator[bytes]:
    """Yields every record in key order, one `__getitem__` call per record."""
    for index in range(self._num_records):
      yield self[index]

  def _reader_idx_and_position(
      self, record_key: SupportsIndex
  ) -> Tuple[int, int]:
    """Computes reader idx and position of given record key.

    Args:
      record_key: Global key in `[0, len(self))`.

    Returns:
      `(reader_idx, position)`, where `position` is relative to the start of
      the file (not of the read instruction).

    Raises:
      ValueError: If `record_key` is out of range.
    """
    record_key = record_key.__index__()
    if record_key < 0 or record_key >= self._num_records:
      raise ValueError("Record key should be in [0, num_records)")
    reader_idx = bisect.bisect_right(self._prefix_sums, record_key)
    records_in_previous_instructions = 0
    if reader_idx > 0:
      records_in_previous_instructions = self._prefix_sums[reader_idx - 1]
    return (
        reader_idx,
        record_key
        - records_in_previous_instructions
        + self._read_instructions[reader_idx].start,
    )

  def _split_keys_per_reader(
      self, record_keys: Sequence[SupportsIndex]
  ) -> Mapping[int, Sequence[Tuple[int, int]]]:
    """Splits record_keys among readers.

    Returns:
      Mapping from reader index to `(position, index_in_record_keys)` pairs,
      in the order the keys appear in `record_keys`.
    """
    positions_and_indices = {}
    for idx, record_key in enumerate(record_keys):
      reader_idx, position = self._reader_idx_and_position(record_key)
      positions_and_indices.setdefault(reader_idx, []).append((position, idx))
    return positions_and_indices

  def _ensure_reader_exists(self, reader_idx: int) -> None:
    """Creates the reader for `reader_idx` if it doesn't exist yet.

    This method takes no lock: two threads racing on the same `reader_idx` may
    each open a reader, and the one overwritten by the later assignment is
    never explicitly closed. Within a single `__getitems__` call each reader
    index is handled by exactly one task, so this race does not occur there.

    Raises:
      Any exception raised by `_check_group_size` (e.g. ValueError); the
      freshly opened reader is closed before the exception propagates.
    """
    if self._readers[reader_idx] is not None:
      return
    filename = self._read_instructions[reader_idx].filename
    reader = _create_reader(filename)
    try:
      _check_group_size(filename, reader)
    except BaseException:
      # The reader was never stored, so nothing else would ever close it.
      reader.close()
      raise
    self._readers[reader_idx] = reader

  def __getitem__(self, record_key: SupportsIndex) -> bytes:
    """Returns the record stored under the global `record_key`."""
    reader_idx, position = self._reader_idx_and_position(record_key)
    self._ensure_reader_exists(reader_idx)
    reader = self._readers[reader_idx]
    # Use the list-based `read` API when the reader provides it, otherwise
    # fall back to indexing.
    if hasattr(reader, "read"):
      return reader.read([position])[0]
    return reader[position]

  def __getitems__(
      self, record_keys: Sequence[SupportsIndex]
  ) -> Sequence[bytes]:
    """Returns the records for `record_keys`, in the same order.

    Keys are grouped by the file they belong to, and each group is read by one
    task on a thread pool of at most `--grain_num_threads_fetching_records`
    threads.

    Raises:
      ValueError: If any key is outside `[0, len(self))`.
    """
    # `len` rather than truthiness so array-like key containers also work.
    if len(record_keys) == 0:
      # Nothing to read; `_run_in_parallel` would reject a pool of 0 workers.
      return []

    def read_records(
        reader_idx: int, reader_positions_and_indices: Sequence[Tuple[int, int]]
    ) -> Sequence[Tuple[Any, int]]:
      """Reads records using the given reader keeping track of the indices."""
      # Initialize readers lazily when we need to read from them.
      self._ensure_reader_exists(reader_idx)
      reader = self._readers[reader_idx]
      positions, indices = zip(*reader_positions_and_indices)
      if hasattr(reader, "read"):
        records = reader.read(positions)  # pytype: disable=attribute-error
      else:
        records = [reader[p] for p in positions]
      return list(zip(records, indices))

    positions_and_indices = self._split_keys_per_reader(record_keys)
    num_threads = _get_flag_value(_GRAIN_NUM_THREADS_FETCHING_RECORDS)
    num_workers = min(len(positions_and_indices), num_threads)
    list_of_kwargs_to_read_records = [
        {
            "reader_idx": reader_idx,
            "reader_positions_and_indices": reader_positions_and_indices,
        }
        for reader_idx, reader_positions_and_indices in (
            positions_and_indices.items()
        )
    ]
    records_with_indices: Sequence[Sequence[Tuple[Any, int]]] = (
        _run_in_parallel(
            function=read_records,
            list_of_kwargs_to_function=list_of_kwargs_to_read_records,
            num_workers=num_workers,
        )
    )
    # Scatter each reader's records back to the caller's key order.
    sorted_records = [b""] * len(record_keys)
    for single_reader_records_with_indices in records_with_indices:
      for record, index in single_reader_records_with_indices:
        sorted_records[index] = record
    return sorted_records

  def __getstate__(self):
    """Pickles everything except the open readers."""
    logging.debug("__getstate__ for ArrayRecordDataSource is called.")
    state = self.__dict__.copy()
    del state["_readers"]
    return state

  def __setstate__(self, state):
    logging.debug("__setstate__ for ArrayRecordDataSource is called.")
    self.__dict__.update(state)
    # We open readers lazily when we need to read from them. Thus, we don't
    # need to re-open the same files as before pickling.
    self._readers = [None] * len(self._read_instructions)

  def __repr__(self) -> str:
    """Storing a hash of paths since paths can be a very long list."""
    h = hashlib.sha1()
    for p in self._paths:
      h.update(p.encode())
    return f"ArrayRecordDataSource(hash_of_paths={h.hexdigest()})"
def _get_flag_value(flag: flags.FlagHolder[int]) -> int:
  """Returns `flag`'s value, or its default when flags are not parsed.

  Outside an absl app the flags may not have been parsed. Reading `.value` then
  raises `flags.UnparsedFlagAccessError`, in which case the flag's default is
  returned instead.
  """
  try:
    value = flag.value
  except flags.UnparsedFlagAccessError:
    value = flag.default
  return value