-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscript_convert_h5.py
More file actions
executable file
·495 lines (400 loc) · 17.1 KB
/
script_convert_h5.py
File metadata and controls
executable file
·495 lines (400 loc) · 17.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
#!/usr/bin/env python
# coding: utf-8
"""
Converts the predictions.jsonl files generated by the training scripts to much faster h5 files.
The format of the H5 is:
- "input_samples": num_samples x max_len_input_samples
- "predictions": num_epochs x num_samples x max_len_predictions
The root of the H5 has an "epochs" attr listing the epochs. This is useless now,
as epochs are just range(num_epochs) now.
"""
import collections
import enum
import multiprocessing
import multiprocessing.pool as mp_pool
import queue as threading_queue
import itertools
from pathlib import Path
import pickle
import subprocess
import time
import threading
from beartype import beartype
from beartype.typing import *
import fire # type: ignore[import]
import jsonlines as jsonl # type: ignore[import]
import h5py # type: ignore[import]
import pretty_traceback # type: ignore
import numpy as np
import rich
from tqdm import tqdm # type: ignore[import]
import data_tokenizer
import general_utils
import script_data_subset_selection
pretty_traceback.install()

# Directory containing this script; all data paths are resolved relative to it.
SCRIPT_DIR = Path(__file__).absolute().parent
DATA_DIR = SCRIPT_DIR / "data"
# NOTE(review): appears unused — `main` defines its own local file name for the
# jsonl sources instead of reading this constant. Confirm before relying on it.
TARGET_FILE_NAME = "predictions.h5"

# Dataset names inside the generated H5 files.
H5_INPUT_IDS_KEY = "input_samples"
H5_LABEL_IDS_KEY = "label_ids"
H5_PREDICTIONS_KEY = "predictions"

# Keys into the pickled main dataset dict.
MAIN_DATASET_EVAL_KEY = "eval"
MAIN_DATASET_DATA_KEY = "data"

# Keys into the per-node dicts of the eval dataset.
NODE_VALUE_STR_KEY = "value"
NODE_INPUT_STR_KEY = "input_str"
SUBSET_INPUT_IDS_KEY = "subset_ids"

# Convenience alias for arguments accepting either path flavor.
PathType = Union[Path, str]
def tokenize_pad_numpify(
    tokenizer,
    strings: Union[Iterable[str], Mapping[Any, str]],
    key: Any = None,
    pad_to: Optional[int] = None,
) -> np.ndarray:
    """Tokenize a collection of strings and right-pad them into one int64 array.

    Args:
        tokenizer: Object exposing `.encode(text, no_eos=..., return_tensors=...)`
            returning a list of token ids, and a `.pad_token_id` attribute.
        strings: Iterable of strings, or of items from which `key` extracts one.
        key: If not None, each element is first indexed with `element[key]`.
            (Fixed: the old truthiness check ignored falsy-but-valid keys
            such as `0` or `""`.)
        pad_to: Target width; defaults to the longest tokenized sequence.
            NOTE: sequences longer than `pad_to` are NOT truncated — that
            would make the rows ragged and `np.array` would fail.

    Returns:
        np.ndarray of shape (num_items, pad_to), dtype int64.
    """
    if key is not None:
        def take_fn(x):
            return x[key]
    else:
        def take_fn(x):
            return x

    tokenized = [
        tokenizer.encode(take_fn(x), no_eos=False, return_tensors=None)
        for x in strings
    ]
    if pad_to is None:
        pad_to = max(len(x) for x in tokenized)
    padded = [x + [tokenizer.pad_token_id] * (pad_to - len(x)) for x in tokenized]
    return np.array(padded, dtype=np.int64)
def _convert(
    input_path: Union[str, Path],
    tokenizer: data_tokenizer.ArithmeticTokenizer,
    input_ids: np.ndarray,
    max_epochs: Optional[int] = None,
    verbose: bool = False,
    queue = None,
):
    """Convert one predictions.jsonl file into a predictions.h5 next to it.

    Each jsonl line is one epoch: a dict with an "epoch" int and a "results"
    mapping of input string -> prediction info. The H5 gets an
    "input_samples" dataset (copied from `input_ids`) and a
    "predictions" dataset of shape (num_epochs, num_samples, max_pred_len).

    Args:
        input_path: The predictions.jsonl file to convert.
        tokenizer: Tokenizer used to verify that `input_ids` matches the
            sorted result keys; its pad_token_id must be 0 (asserted below).
        input_ids: Pre-tokenized, pre-padded input strings, in sorted-key order.
        max_epochs: If set, only the first `max_epochs` lines are read.
        verbose: Print progress messages.
        queue: If set, `None` is put on it when done (used as a job-slot
            semaphore by the `launch_few` path in `main`).

    Returns:
        The CPU (process) time consumed, in seconds.
    """
    start_time = time.process_time()
    if verbose:
        print("Working.")

    ###########################################################################
    # Read the data
    ###########################################################################
    input_path = Path(input_path)
    output_path = input_path.parent / f"{input_path.stem}.h5"
    if verbose:
        rich.print("[bold]Counting lines of the jsonl.")
    # `wc -l` is much faster than reading the file in Python just to count.
    num_lines = int(
        subprocess.check_output(["wc", "-l", str(input_path)])
        .strip()
        .decode()
        .split()[0]
    )
    if max_epochs:
        assert max_epochs <= num_lines, (max_epochs, num_lines)
        target_qty = min(max_epochs, num_lines)
    else:
        target_qty = num_lines
    if verbose:
        rich.print(f"[bold]Fewer than {num_lines} lines.")
        print()
    if verbose:
        rich.print("[bold]Reading jsonl.")
    with jsonl.open(input_path) as f:
        iterable = f
        if max_epochs:
            iterable = itertools.islice(f, target_qty)
        if verbose:
            iterable = tqdm(iterable, total=target_qty)
        content = [x for x in iterable]
        del iterable
    del target_qty  # Not meant to be used again.
    del max_epochs  # Not meant to be used again.
    del num_lines  # Not meant to be used again.
    if verbose:
        rich.print("[bold]Done reading jsonl.")

    ###########################################################################
    # Prep the data
    ###########################################################################
    # We ignore pytorch lightning's "sanity test" tiny zeroth epoch:
    # when two consecutive entries claim epoch 0, the first is the sanity pass.
    if content[0]["epoch"] == 0 and content[1]["epoch"] == 0:
        content = content[1:]
    # We make sure that all epochs have the same number of samples.
    # NOTE(review): this compares len() of the epoch *dicts* (their key count),
    # not len() of their "results" — confirm that is the intended check.
    assert all([len(content[i]) == len(content[0]) for i in range(1, len(content))]), (
        [len(content[i]) for i in range(len(content))]
    )
    # If an epoch happens twice, we keep only the first occurrence.
    # We believe this happens when training is interrupted and restarted:
    # PL then starts with an eval pass that re-emits the same epoch.
    new_content = []
    epochs_seen = set()
    for epoch in content:
        if epoch["epoch"] in epochs_seen:
            continue
        epochs_seen.add(epoch["epoch"])
        new_content.append(epoch)
    content = new_content
    # Flatten each result entry down to its "per_batch" token-id list.
    # NOTE(review): the intermediate key is the *string* "True" — presumably a
    # JSON-serialized boolean flag; verify against the training script's writer.
    for i in range(len(content)):
        content[i]["results"] = {
            k: content[i]["results"][k]["True"]["per_batch"]
            for k in content[i]["results"]
        }
    # Fix a stable (sorted) sample order, identical across all epochs.
    sorted_keys: list[str] = list(content[0]["results"].keys())
    sorted_keys.sort()
    for epoch_content in content:
        assert sorted_keys == sorted(epoch_content["results"].keys())
    del epochs_seen  # Not meant to be used again.
    del new_content  # Not meant to be used again.

    ###########################################################################
    # Write the data
    ###########################################################################
    # NOTE(review): `content[1]` looks like it should be `content[0]` — this
    # raises IndexError when only a single epoch survives. Confirm.
    num_samples = len(content[1]["results"])
    num_epochs = len(content)
    # Width of the predictions dataset: the longest prediction anywhere.
    len_seqs_output = max([
        max([
            len(samples)
            for samples in epochs["results"].values()
        ]) for epochs in content
    ])
    if verbose:
        rich.print(f"[bold]Doing h5py.")
    with h5py.File(output_path, "w") as output_file:
        if verbose:
            rich.print("Creating datasets.")
        # Padding below relies on pad id 0 (fresh h5 datasets are zero-filled).
        assert tokenizer.pad_token_id == 0, tokenizer.pad_token_id
        output_file.create_dataset(
            H5_INPUT_IDS_KEY,
            data=input_ids,
        )
        output_file.create_dataset(
            H5_PREDICTIONS_KEY,
            shape=(num_epochs, num_samples, len_seqs_output),
            dtype=np.int64,
        )
        predictions = output_file["predictions"]
        if verbose:
            rich.print("Writing data.")
        num_samples = None
        epochs_seen_list: list[int] = []
        for entry_idx in range(len(content)):
            real_epoch = content[entry_idx]["epoch"]
            ######################################################
            # DON'T REMOVE THIS CHECK
            assert real_epoch == entry_idx, (real_epoch, entry_idx)
            ######################################################
            # If we're after the 0th epoch
            # Then the saved keys should be the same as the
            # keys of the current epoch
            for input_idx, k in enumerate(sorted_keys):
                # Make sure that we're only adding keys if we're
                # in the zeroth epoch
                if real_epoch == 0:
                    # Sanity check: re-tokenizing the key and right-padding it
                    # must reproduce the row already stored in `input_ids`.
                    tokenized = tokenizer.encode(k, return_tensors=None)
                    input_ids_gen = tokenized + [
                        tokenizer.pad_token_id
                    ] * max(input_ids[input_idx].shape[0] - len(tokenized), 0)
                    assert np.all(input_ids[input_idx] == input_ids_gen)
                # The predictions are already encoded
                prediction = content[entry_idx]["results"][k]
                predictions[real_epoch, input_idx] = prediction + [
                    tokenizer.pad_token_id
                ] * max(len_seqs_output - len(prediction), 0)
            epochs_seen_list.append(content[entry_idx]["epoch"])
        # Make sure that we only have one of each key
        assert len(set(epochs_seen_list)) == len(epochs_seen_list), (
            collections.Counter(epochs_seen_list))
        assert epochs_seen_list == list(range(len(epochs_seen_list))), (
            epochs_seen_list)
        if verbose:
            rich.print("Writing attrs")
        predictions.attrs.create(
            "epochs", epochs_seen_list, dtype=np.int64
        )
    if verbose:
        rich.print("[bold]Done h5py.")
    if queue:
        # Release a job slot for the `launch_few` scheduler.
        queue.put(None)
    return time.process_time() - start_time
class _ConvertFunctor:
def __init__(self, tokenizer, max_epochs, input_ids, label_ids):
self._tokenizer = tokenizer
self._max_epochs: Optional[int] = max_epochs
self._input_ids = input_ids
if max_epochs:
rich.print(f"[red]{max_epochs = }")
def __call__(self, input_path: Union[str, Path], queue=None, verbose=False):
return _convert(
input_path=input_path,
tokenizer=self._tokenizer,
max_epochs=self._max_epochs,
input_ids=self._input_ids,
verbose=verbose,
queue=queue,
)
class LaunchMethods(str, enum.Enum):
    """How `main` dispatches the per-file conversion jobs."""
    # NOTE(review): declared but not handled by `main`'s dispatch; selecting it
    # falls through to the ValueError branch. Confirm whether it should exist.
    launch_all = "launch_all"
    # At most n_cpus workers alive at once, throttled by a queue of slots.
    launch_few = "launch_few"
    # Sequential, in-process, verbose — useful for debugging.
    launch_one = "launch_one"
    # multiprocessing / thread Pool with n_cpus workers (the default).
    launch_pool = "launch_pool"
class ThreadOrProcess(str, enum.Enum):
    """Whether conversion workers run as threads or as separate processes."""
    thread = "thread"
    process = "process"
def build_eval_subset_sorted_by_keys(data_path, subset_path):
    """Load the pickled eval split, apply the subset file, and index by input string.

    Args:
        data_path: Path to the pickled main dataset (a dict containing the
            eval split under data/eval).
        subset_path: Path to the JSON subset-selection file produced by
            `script_data_subset_selection`.

    Returns:
        Dict mapping each kept node's input string to its node dict, with
        keys in sorted order (dicts preserve insertion order).
    """
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Load the eval ds
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    start = time.perf_counter()
    with open(data_path, "rb") as f:
        dataset_dict = pickle.load(f)
    end = time.perf_counter()
    print(f"Loaded the pkl dataset in {end - start:.2f} seconds")
    valid_ds = dataset_dict[MAIN_DATASET_DATA_KEY][MAIN_DATASET_EVAL_KEY]
    del dataset_dict
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Load the subset file and apply it to the eval ds
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    subset_indices, subset_str = script_data_subset_selection.read_subset_file(
        data_path, subset_path
    )
    # `valid_ds` maps an int level to a list of nodes; keep only the nodes
    # whose indices appear in the subset for that level.
    valid_ds_subset = {}
    for level, nodes in tqdm(valid_ds.items(), desc="Applying the subset"):
        assert isinstance(level, int)
        valid_ds_subset[level] = [nodes[idx] for idx in subset_indices[level]]
    del valid_ds, subset_indices, subset_str
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # extract the order of the nodes in the subset
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    subset_per_key = {}
    for level, node_list in valid_ds_subset.items():
        for node in node_list:
            # NOTE(review): nodes sharing an input string across levels silently
            # overwrite each other here — confirm input strings are globally unique.
            subset_per_key[node[NODE_INPUT_STR_KEY]] = node
    sorted_by_keys = dict(sorted(subset_per_key.items(), key=lambda item: item[0]))
    return sorted_by_keys
@beartype
def main(
    path: Union[str, Path] = SCRIPT_DIR / "log_results/oracle/",
    n_cpus: int = 12,
    test_run: bool = False,
    method=LaunchMethods.launch_pool,
    max_epochs: Optional[int] = 60,
    thread_or_process: str = ThreadOrProcess.process,
    subset_path: PathType = DATA_DIR / "subsets/subset_10000_seed_453345_of_349_6_6_200000.json",
    data_path: PathType = DATA_DIR / "349_6_6_200000.json.pkl",
):
    """Convert every predictions.jsonl under `path` to a predictions.h5 beside it.

    Args:
        path: Root directory searched recursively for predictions.jsonl files.
        n_cpus: Max parallel workers (capped to the number of files found).
        test_run: If True, skip the conversion itself (label_ids still written).
        method: One of LaunchMethods. `launch_all` is declared but not
            implemented here and raises ValueError.
        max_epochs: Truncate each jsonl to this many epochs; None keeps all.
        thread_or_process: Run workers as threads or separate processes.
        subset_path: JSON subset-selection file matching `data_path`.
        data_path: Pickled main dataset (.pkl).
    """
    # Must be the very first statement: relies on locals() containing only args.
    general_utils.check_and_print_args(locals().copy(), main)

    data_path = Path(data_path)
    subset_path = Path(subset_path)
    assert data_path.suffix == ".pkl", f"{data_path} is not a pickle file."
    assert subset_path.suffix == ".json", f"{subset_path} is not a json file."
    assert subset_path.exists(), f"{subset_path} does not exist."
    assert data_path.exists(), f"{data_path} does not exist."

    if thread_or_process == ThreadOrProcess.thread:
        thread_or_process_type: TypeAlias = threading.Thread
    elif thread_or_process == ThreadOrProcess.process:
        thread_or_process_type: TypeAlias = multiprocessing.Process  # type: ignore[assignment, no-redef]
    else:
        raise ValueError(
            f"Unknown thread_or_process: {thread_or_process}, "
            f"should be one of {[x.value for x in list(ThreadOrProcess)]}"
        )
    method = LaunchMethods(method)
    tokenizer = data_tokenizer.ArithmeticTokenizer()

    #######################################################################
    # Prepare the paths
    #######################################################################
    # Fixed: this used to rebind TARGET_FILE_NAME locally, shadowing the
    # module constant ("predictions.h5") with a *different* value. The
    # sources we glob for are the jsonl files.
    source_file_name = "predictions.jsonl"
    path = Path(path)
    active = sorted(path.glob(f"**/{source_file_name}"))
    rich.print(f"[bold]File paths: {general_utils.shorten_path(path)}")
    for found_path in general_utils.sort_iterable_text(active):
        rich.print(f"\t- {general_utils.shorten_path(found_path)}")
    print()
    assert active, path
    rich.print("[bold]Converting.")
    rich.print(f"{n_cpus = }")
    n_cpus = min(n_cpus, len(active))  # Likely already done by multiprocessing.Pool

    ###########################################################################
    # Prep the inputs and labels
    ###########################################################################
    print()
    rich.print("[bold]Preparing inputs and labels.")
    sorted_by_keys = build_eval_subset_sorted_by_keys(
        data_path, subset_path
    )
    label_ids = tokenize_pad_numpify(
        tokenizer, sorted_by_keys.values(),
        NODE_VALUE_STR_KEY
    )
    input_ids = tokenize_pad_numpify(tokenizer, sorted_by_keys.keys())

    ###########################################################################
    # Do the multiprocessing
    ###########################################################################
    print()
    rich.print("[bold]Starting multiprocessing.")
    convert_functor = _ConvertFunctor(
        tokenizer,
        max_epochs=max_epochs,
        input_ids=input_ids,
        label_ids=label_ids,
    )
    start = time.perf_counter()
    # Fixed: `times` was only assigned on the launch_pool path, so the stats
    # section crashed with NameError for launch_few / launch_one.
    times: list = []
    if not test_run:
        if method == LaunchMethods.launch_few:
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Only launch n_cpus processes, using the queue as a slot semaphore
            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            processes: list[thread_or_process_type] = []
            if thread_or_process_type == threading.Thread:
                queue: threading_queue.Queue = threading_queue.Queue(n_cpus)
            elif thread_or_process_type == multiprocessing.Process:
                queue: multiprocessing.Queue = multiprocessing.Queue(n_cpus)  # type: ignore[no-redef]
            else:
                raise ValueError(thread_or_process_type)
            for _ in range(n_cpus):
                queue.put(None)
            for i, job_path in enumerate(tqdm(active, desc="Running jobs")):
                queue.get()  # Blocks until a worker frees a slot.
                print(f"Started {i}")
                process = thread_or_process_type(
                    target=convert_functor, args=(job_path, queue))
                processes.append(process)
                process.start()
            for process in tqdm(processes, desc="Joining"):
                process.join()
        elif method == LaunchMethods.launch_pool:
            #######################################################################
            # Launch all jobs on a pool of n_cpus workers
            #######################################################################
            if thread_or_process_type == threading.Thread:
                PoolType = mp_pool.ThreadPool
            elif thread_or_process_type == multiprocessing.Process:
                PoolType = multiprocessing.Pool  # type: ignore[assignment]
            promises = []
            with PoolType(n_cpus) as pool:
                for job_path in active:
                    promises.append(pool.apply_async(convert_functor, (job_path,)))
                for promise in tqdm(promises, desc="Running conversion jobs."):
                    times.append(promise.get())
        elif method == LaunchMethods.launch_one:
            # Sequential and verbose — the debugging path.
            for job_path in tqdm(active):
                convert_functor(job_path, queue=None, verbose=True)
        else:
            # Includes LaunchMethods.launch_all, which is not implemented.
            raise ValueError(f"Unknown method: {method}")
    duration = time.perf_counter() - start

    print()
    rich.print("[bold]Done running jobs. Writing label_ids and input_ids.")

    def write_label_ids(input_path: Path):
        # The conversion already created the .h5; open read-write and append.
        output_path = input_path.parent / f"{input_path.stem}.h5"
        with h5py.File(output_path, "r+") as f:
            f.create_dataset(H5_LABEL_IDS_KEY, data=label_ids,)

    with mp_pool.ThreadPool(n_cpus) as pool:
        promises = [pool.apply_async(write_label_ids, (p,)) for p in active]
        for promise in tqdm(promises, desc="Writing `label_ids`."):
            promise.get()

    #######################################################################
    # Print some stats
    #######################################################################
    duration_w_more_stuff = time.perf_counter() - start
    print("Done with multiprocessing.")
    rich.print(f"[bold]Done in {duration:.2f} seconds.")
    rich.print(f" - This is {duration / n_cpus} s/cpu.")
    rich.print(f" - This is {duration / len(active)} s/file")
    if times:  # Only the launch_pool path collects per-file timings.
        rich.print(f" - Average time of one file was {np.mean(times):.1f} seconds.")
        rich.print(f" - The linear time would have been {np.sum(times):.1f} seconds.")
        rich.print(f" - This is an improvement of {np.mean(times) / duration:0.1f} times.")
    print(f"{duration_w_more_stuff:.2f} seconds with more stuff.")
if __name__ == "__main__":
fire.Fire(main)