-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexport_seq2seq_serve.py
More file actions
46 lines (40 loc) · 1.38 KB
/
export_seq2seq_serve.py
File metadata and controls
46 lines (40 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import json
import os
import shlex
import subprocess
import sys
from pathlib import Path
# Handler script passed to torch-model-archiver unless --handler overrides it.
DEFAULT_HANDLER = "./kbqa/seq2seq/transformer_handler_generalized.py"
# Written to the current working directory and bundled via --extra-files.
SETUP_CONFIG_FILE = "setup_config.json"


def build_parser():
    """Build the CLI parser for packaging a seq2seq checkpoint with torch-model-archiver."""
    arg_parser = argparse.ArgumentParser(
        description="Package a fine-tuned seq2seq checkpoint with torch-model-archiver."
    )
    arg_parser.add_argument(
        "--model_name",
        default="t5-large",
        help="Value written to model_name/embedding_name in setup_config.json.",
    )
    arg_parser.add_argument(
        "--model_export_name",
        default="t5_large",
        help="Name passed to torch-model-archiver --model-name.",
    )
    arg_parser.add_argument(
        "--checkpoint_path",
        default="/mnt/raid/data/kbqa/seq2seq_runs/wdsq_tunned/t5-large/models/checkpoint-7000/",
        help="Directory containing pytorch_model.bin and config.json.",
    )
    arg_parser.add_argument(
        "--version",
        default="1.0",
        help="Archive version passed to torch-model-archiver --version.",
    )
    arg_parser.add_argument(
        "--handler",
        default=DEFAULT_HANDLER,
        help="Handler script passed to torch-model-archiver --handler.",
    )
    return arg_parser


# Kept as a module-level name so existing importers of ``parser`` keep working.
parser = build_parser()


def build_setup_config(model_name):
    """Return the settings dict that is serialized to setup_config.json.

    Numeric settings ("num_labels", "max_length") are stored as strings, exactly
    as the original export wrote them.
    NOTE(review): presumably the serving handler converts them; confirm before
    changing their types.
    """
    return {
        "model_name": model_name,
        "mode": "text_generation",
        "do_lower_case": False,
        "num_labels": "0",
        "save_mode": "pretrained",
        "max_length": "64",
        "captum_explanation": False,
        "FasterTransformer": False,
        "embedding_name": model_name,
    }


def build_archiver_command(model_export_name, version, serialized_file, handler, extra_files):
    """Return the torch-model-archiver invocation as an argv list (no shell).

    ``extra_files`` is an iterable of paths. They are joined with commas to form
    the single --extra-files value, matching the original command string.
    """
    return [
        "torch-model-archiver",
        "--model-name", model_export_name,
        "--version", version,
        "--serialized-file", str(serialized_file),
        "--handler", handler,
        "--extra-files", ",".join(str(path) for path in extra_files),
    ]


def main(argv=None):
    """Validate the checkpoint, write setup_config.json, and run torch-model-archiver.

    Returns the archiver's exit code, or 127 if the executable is not on PATH.
    Exits with status 2 (via ``parser.error``) if required checkpoint files are
    missing.
    """
    args = parser.parse_args(argv)

    checkpoint_path = Path(args.checkpoint_path)
    serialized_file = checkpoint_path / "pytorch_model.bin"
    model_config = checkpoint_path / "config.json"
    # Fail fast with a clear message instead of a confusing archiver error,
    # and before leaving a stray setup_config.json behind.
    for required in (serialized_file, model_config):
        if not required.is_file():
            parser.error(f"missing checkpoint file: {required}")

    with open(SETUP_CONFIG_FILE, "w", encoding="utf-8") as f:
        json.dump(build_setup_config(args.model_name), f)

    cmd = build_archiver_command(
        model_export_name=args.model_export_name,
        version=args.version,
        serialized_file=serialized_file,
        handler=args.handler,
        extra_files=[model_config, "./" + SETUP_CONFIG_FILE],
    )
    # Print a copy-pasteable, correctly quoted form of the command.
    print(" ".join(shlex.quote(part) for part in cmd))
    try:
        # Passing a list (no shell) keeps paths with spaces or shell
        # metacharacters intact and prevents injection through CLI arguments.
        completed = subprocess.run(cmd)
    except FileNotFoundError:
        print("torch-model-archiver not found on PATH", file=sys.stderr)
        return 127
    return completed.returncode


if __name__ == "__main__":
    sys.exit(main())