forked from dongjunKANG/DREM_with_CFSNS
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train_feature.py
More file actions
32 lines (25 loc) · 1006 Bytes
/
train_feature.py
File metadata and controls
32 lines (25 loc) · 1006 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import util.option as opt
import util.main_utils as main_utils
from transformers import AutoTokenizer
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
from evaluator.eval_utils import _eval
# Extra tokens appended to the tokenizer vocabulary before use.
# NOTE(review): '[EOU]' presumably marks end-of-utterance boundaries in the
# input text — confirm against the dataset code in util.main_utils.
tokens = ['[EOU]']


def main():
    """Run the train_feature entry point.

    Steps, in order:
      1. Parse command-line options via ``opt.parse_train_feature_opt()``
         and print them.
      2. Call ``main_utils.set_seed(args.seed)``.
      3. Load an ``AutoTokenizer`` for ``args.pretrained_model_name`` with
         ``model_max_length=512`` and add each entry of ``tokens`` to it,
         one ``add_tokens`` call per token.
      4. Call ``_eval(args)``.
    """
    args = opt.parse_train_feature_opt()
    print(args)
    main_utils.set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(
        args.pretrained_model_name,
        model_max_length=512,
    )
    for special_token in tokens:
        tokenizer.add_tokens(special_token)

    # NOTE(review): the model/dataset/trainer pipeline below is disabled, so
    # `tokenizer` is not consumed by any live statement here; only
    # `_eval(args)` runs. Confirm whether skipping training is intentional.
    # model = main_utils.get_model(args, tokenizer)
    # train_ds = main_utils.get_dataset(args, tokenizer, data_type='train')
    # eval_ds = main_utils.get_dataset(args, tokenizer, data_type='valid')
    # test_ds = main_utils.get_dataset(args, tokenizer, data_type='test')
    # dataset = [train_ds, eval_ds, test_ds]
    # trainer = main_utils.get_trainer(args, tokenizer, model, dataset)
    # trainer.run("train")
    # trainer.run("test")
    # trainer.run("eval")
    _eval(args)


if __name__ == '__main__':
    main()