-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscript_add_label_info_to_subset_h5.py
More file actions
executable file
·212 lines (178 loc) · 8.44 KB
/
script_add_label_info_to_subset_h5.py
File metadata and controls
executable file
·212 lines (178 loc) · 8.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Adds label information to the H5 files generated by `script_convert_h5.py`.
Should also be merged with that script at some point in the future.
"""
import json
from pathlib import Path
from typing import *
import fire # type: ignore[import]
import h5py # type: ignore[import]
import multiprocessing
import numpy as np
import pickle
import queue as threading_queue
import rich
import threading
import time
from tqdm import tqdm # type: ignore[import]
import data_tokenizer
import general_utils
import script_data_subset_selection
import script_convert_h5
# Absolute directory containing this script; default CLI paths are built from it.
SCRIPT_DIR = Path(__file__).absolute().parent

# Any value the CLI accepts as a filesystem path (converted with `Path(...)`).
PathType = Union[Path, str]
def work(
    predictions_h5_path,
    input_ids,
    subset_path,
    labels_np,
    queue: multiprocessing.Queue,
):
    """Check one predictions H5 file against the subset, then write its labels.

    This is the target of each worker thread/process started by `main`.
    `queue` acts as a counting semaphore: `main` takes one token out before
    starting a worker, and this function always puts one token back, even
    when a check or the write fails. Otherwise a failing worker would leak
    its token and `main` could block forever in `queue.get()`.

    Args:
        predictions_h5_path: H5 file to open in "r+" mode and update in place.
        input_ids: 2D array of tokenized, padded subset inputs (indexed as
            `[rows, columns]` below), in the order `main` produced them.
        subset_path: Subset JSON path; only used in the error message.
        labels_np: Array written verbatim as the `H5_LABEL_IDS_KEY` dataset.
        queue: Semaphore-like queue shared with `main`.

    Raises:
        AssertionError: If the file's stored input ids do not match
            `input_ids` (different row count, narrower width, different
            order, or non-zero padding).
    """
    try:
        with h5py.File(predictions_h5_path, "r+") as predictions:
            # Read the stored ids once rather than slicing the on-disk
            # dataset repeatedly.
            stored_ids = predictions[script_convert_h5.H5_INPUT_IDS_KEY][:]
            num_rows = input_ids.shape[0]
            num_cols = input_ids.shape[1]

            # The element-wise comparison below needs identical row counts;
            # the stored ids may be padded wider, so columns only need >=.
            assert stored_ids.shape[0] == num_rows, (
                stored_ids.shape[0], num_rows)
            assert stored_ids.shape[1] >= num_cols, (
                stored_ids.shape[1], num_cols)
            assert (np.all(stored_ids[:, :num_cols] == input_ids) and
                    np.all(stored_ids[:, num_cols:] == 0)), (
                f"{predictions_h5_path} has a different order of nodes than {subset_path}")

            # The labels were already computed and tokenized by the caller;
            # just save them next to the predictions.
            predictions.create_dataset(
                script_convert_h5.H5_LABEL_IDS_KEY, data=labels_np)
    finally:
        # Always hand the semaphore token back to `main`.
        queue.put(None)
def _check_input_paths(subset_path: Path, data_path: Path) -> int:
    """Validate the data pickle and subset JSON paths; return the subset size.

    The size is the sum of `len(...)` over the values stored under
    `SUBSET_INPUT_IDS_KEY` in the subset JSON. The file is only read after
    the path checks pass, so a missing file fails with a clear message.
    """
    assert data_path.exists(), f"Path {data_path} does not exist"
    assert data_path.is_file(), f"Path {data_path} is not a file"
    assert data_path.suffix == ".pkl", f"Path {data_path} is not a pkl file"
    assert subset_path.exists(), f"Path {subset_path} does not exist"
    assert subset_path.is_file(), f"Path {subset_path} is not a file"
    assert subset_path.suffix == ".json", f"Path {subset_path} is not a json file"
    subset = json.loads(subset_path.read_text())
    return sum(len(x) for x in subset[script_convert_h5.SUBSET_INPUT_IDS_KEY].values())


def _check_targets(targets: List[Path], subset_size: int, subset_path: Path) -> None:
    """Read-only sanity checks on every predictions H5 file, before any write."""
    for predictions_h5_path in targets:
        assert predictions_h5_path.exists(), (
            f"{predictions_h5_path} does not exist")
        assert predictions_h5_path.is_file(), (
            f"Path {predictions_h5_path} is not a file")
        assert predictions_h5_path.suffix == ".h5", (
            f"Path {predictions_h5_path} is not a .h5 file")
        with h5py.File(predictions_h5_path, "r") as predictions:
            assert script_convert_h5.H5_INPUT_IDS_KEY in predictions
            assert not np.all(predictions[script_convert_h5.H5_INPUT_IDS_KEY][:] == 0)
            assert script_convert_h5.H5_PREDICTIONS_KEY in predictions, (
                f"{script_convert_h5.H5_PREDICTIONS_KEY} not in {predictions}"
            )
            assert subset_size == predictions[script_convert_h5.H5_INPUT_IDS_KEY].shape[0], (
                f"{subset_path} has a different number "
                f"of samples than {predictions_h5_path}"
            )


def _clear_stale_labels(targets: List[Path]) -> None:
    """Delete any existing label dataset so `work` can create it afresh."""
    for predictions_h5_path in targets:
        with h5py.File(predictions_h5_path, "r+") as predictions:
            if script_convert_h5.H5_LABEL_IDS_KEY in predictions:
                del predictions[script_convert_h5.H5_LABEL_IDS_KEY]


def _final_checks(targets: List[Path], check_labels: bool) -> None:
    """Re-open every target read-only and check its datasets look populated.

    When `check_labels` is True, the label dataset must also exist with one
    row per input id. An exception raised inside a worker thread does not
    propagate through `join()`, so this is how a failed worker is detected.
    """
    for predictions_h5_path in tqdm(targets, desc="Final checks"):
        with h5py.File(predictions_h5_path, "r") as f:
            assert script_convert_h5.H5_INPUT_IDS_KEY in f, (
                f"{script_convert_h5.H5_INPUT_IDS_KEY} not in {f}")
            assert script_convert_h5.H5_PREDICTIONS_KEY in f, (
                f"{script_convert_h5.H5_PREDICTIONS_KEY} not in {f}")
            input_ids_shape = f[script_convert_h5.H5_INPUT_IDS_KEY].shape
            predictions_shape = f[script_convert_h5.H5_PREDICTIONS_KEY].shape
            # Print the full shapes so the message matches what is compared.
            assert input_ids_shape[0] == predictions_shape[1], (
                f"\n{predictions_h5_path = }\n"
                f"{input_ids_shape = }\n"
                f"{predictions_shape = }"
            )
            assert not np.all(f[script_convert_h5.H5_INPUT_IDS_KEY][:] == 0), (
                predictions_h5_path)
            assert not np.all(f[script_convert_h5.H5_PREDICTIONS_KEY][:] == 0), (
                predictions_h5_path)
            if check_labels:
                assert script_convert_h5.H5_LABEL_IDS_KEY in f, (
                    f"{script_convert_h5.H5_LABEL_IDS_KEY} not in "
                    f"{predictions_h5_path}; a worker probably failed, "
                    f"see the traceback above.")
                labels_shape = f[script_convert_h5.H5_LABEL_IDS_KEY].shape
                assert labels_shape[0] == input_ids_shape[0], (
                    f"\n{predictions_h5_path = }\n"
                    f"{labels_shape = }\n"
                    f"{input_ids_shape = }"
                )


def main(
    dir_path: PathType = SCRIPT_DIR / "log_results/basic/",
    subset_path: PathType = (script_convert_h5.DATA_DIR / "subsets/subset_10000_seed_453345_of_349_6_6_200000.json"),
    data_path: PathType = script_convert_h5.DATA_DIR / "349_6_6_200000.json.pkl",
    num_procs: int = 10,
    dry: bool = False,
    thread_or_process: str = "thread",
):
    """Add a label dataset to every predictions H5 file found under `dir_path`.

    Steps:
      1. Find every `TARGET_FILE_NAME` under `dir_path`, and validate the
         inputs and each target read-only.
      2. Build the evaluation subset, then tokenize and pad its labels and
         input ids.
      3. Unless `dry`, delete stale label datasets and run `work` on each
         target with at most `num_procs` workers alive at once.
      4. Re-check every target; when not `dry`, this includes the labels.

    Args:
        dir_path: Root directory searched recursively for target H5 files.
        subset_path: Subset JSON file.
        data_path: Pickled node data file (`.pkl`).
        num_procs: Maximum number of concurrent workers; must be positive.
        dry: If True, validate and prepare everything but modify no file.
        thread_or_process: Only "thread" is currently accepted.
    """
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Argument checking
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Must run first so that `locals()` only contains the arguments.
    general_utils.check_and_print_args(locals().copy(), main)
    # A zero-slot semaphore would make the first `queue.get()` block forever.
    assert num_procs > 0, f"num_procs must be positive, got {num_procs}"
    assert thread_or_process == "thread", (
        "Multiprocessing is useless it turns out, "
        "as the numpy work is only done once."
    )
    if thread_or_process == "thread":
        thread_or_process_type = threading.Thread
    else:
        thread_or_process_type = multiprocessing.Process  # type: ignore[assignment]

    targets = list(Path(dir_path).glob(f"**/{script_convert_h5.TARGET_FILE_NAME}"))
    print()
    rich.print("[bold]Targets:")
    general_utils.print_list(general_utils.sort_iterable_text(targets))
    print()
    del dir_path
    if not targets:
        # Nothing to label: skip the expensive subset/label preparation.
        rich.print("[bold]No targets found, nothing to do.")
        return

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Checks relating to the node data & subset, then the predictions files.
    # All read-only, so a failing check leaves every file untouched.
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    subset_path = Path(subset_path)
    data_path = Path(data_path)
    subset_size = _check_input_paths(subset_path, data_path)
    _check_targets(targets, subset_size, subset_path)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Get the labels ready.
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    tokenizer = data_tokenizer.ArithmeticTokenizer()
    print()
    rich.print("[bold]Prepare the subset.")
    start = time.perf_counter()
    sorted_by_keys = script_convert_h5.build_eval_subset_sorted_by_keys(data_path, subset_path)
    end = time.perf_counter()
    print(f"Built the subset in {end - start:.1f} seconds")
    print()
    rich.print("[bold]Prepare the labels.")
    start = time.perf_counter()
    labels_np = script_convert_h5.tokenize_pad_numpify(
        tokenizer, sorted_by_keys.values(), script_convert_h5.NODE_VALUE_STR_KEY)
    end = time.perf_counter()
    print(f"Padded the labels in {end - start:.1f} seconds")
    print()

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Make sure the order is the same in the predictions h5 file
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    rich.print("\n[bold]Make sure the order is the same in the predictions h5 file ...")
    input_ids = script_convert_h5.tokenize_pad_numpify(tokenizer, sorted_by_keys.keys())
    if thread_or_process_type is threading.Thread:
        queue: threading_queue.Queue = threading_queue.Queue(num_procs)
    elif thread_or_process_type is multiprocessing.Process:
        queue: multiprocessing.Queue = multiprocessing.Queue(num_procs)  # type: ignore[no-redef]
    else:
        raise ValueError(
            f"Unknown thread_or_process_type: {thread_or_process_type}")
    # Pre-fill with `num_procs` tokens: each worker start consumes one and
    # `work` returns it when done, bounding the number of live workers.
    for _ in range(num_procs):
        queue.put(None)

    if not dry:
        # Only drop old labels now that the new ones are ready to be written.
        _clear_stale_labels(targets)
        processes = []
        for predictions_h5_path in tqdm(
                targets, desc="Main work, writing the label sections."):
            queue.get()
            proc = thread_or_process_type(
                target=work,
                args=(predictions_h5_path,),
                kwargs=dict(
                    input_ids=input_ids,
                    subset_path=subset_path,
                    labels_np=labels_np,
                    queue=queue
                )
            )
            processes.append(proc)
            proc.start()
        print()
        rich.print("[bold]Joining procs")
        for proc in tqdm(processes, desc="Joining procs"):
            proc.join()

    print()
    rich.print("[bold]Final checks.")
    _final_checks(targets, check_labels=not dry)
    print("Done.")
if __name__ == "__main__":
    # Let python-fire build the command-line interface from `main`'s signature.
    fire.Fire(component=main)