diff --git a/config.py b/config.py
index 74381c7..f97205b 100644
--- a/config.py
+++ b/config.py
@@ -17,6 +17,7 @@ def get_opt():
     # dataset paths
     parser.add_argument('--data_path', type=str, default='parsed-v1.5/')
     parser.add_argument('--emb_file_path', type=str, default='./glove.6B/glove.6B.200d.txt')
+    parser.add_argument('--exemplar_instance_path', type=str, default='./exemplar_instance_dic.npy')
     parser.add_argument('--train_instance_path', type=str, default='./train_instance_dic.npy')
     parser.add_argument('--dev_instance_path', type=str, default='./dev_instance_dic.npy')
     parser.add_argument('--test_instance_path', type=str, default='./test_instance_dic.npy')
@@ -25,21 +26,23 @@ def get_opt():
     parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
     parser.add_argument('--model_name', type=str, default='train_model')
     parser.add_argument('--pretrain_model', type=str, default='')
-    parser.add_argument('--save_model_path', type=str, default='./models_ft/model_syntax.bin')
+    parser.add_argument('--save_model_path', type=str, default='./models_ft/wo_jointarg.bin')
+
     # training
-    parser.add_argument('--lr', type=float, default='0.0001')
+    parser.add_argument('--lr', type=float, default='0.00006')
     parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--batch_size', type=int, default=2)
-    parser.add_argument('--epochs', type=int, default=300)
+    parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--save_model_freq', type=int, default=1)  # checkpoint interval, in epochs
     parser.add_argument('--cuda', type=str, default="cuda:0")
-    parser.add_argument('--mode', type=str, default="train")
+    parser.add_argument('--mode', type=str, default="test")

     # model settings
     parser.add_argument('--maxlen',type=int, default=256)
     parser.add_argument('--dropout', type=float, default=0.5)
     parser.add_argument('--cell_name', type=str, default='lstm')
+    #change
     parser.add_argument('--rnn_hidden_size',type=int, default=256)
     parser.add_argument('--rnn_emb_size', type=int, default=400)
     parser.add_argument('--encoder_emb_size',type=int, default=200)
@@ -56,7 +59,7 @@ def get_opt():
     parser.add_argument('--role_number',type=int, default=9634)
     parser.add_argument('--fe_padding_num',type=int, default=5)
     parser.add_argument('--window_size', type=int, default=3)
-    parser.add_argument('--num_layers', type=int, default=4)
+    parser.add_argument('--num_layers', type=int, default=2)

     return parser.parse_args()

diff --git a/data_process2.py b/data_process4.py
similarity index 95%
rename from data_process2.py
rename to data_process4.py
index 912d4e5..dab1328 100644
--- a/data_process2.py
+++ b/data_process4.py
@@ -153,6 +153,7 @@ def __init__(self,opt):
         self.maxlen = opt.maxlen

         if opt.load_instance_dic:
+            self.exemplar_instance_dic = np.load(opt.exemplar_instance_path, allow_pickle=True).item()
             self.train_instance_dic = np.load(opt.train_instance_path, allow_pickle=True).item()
             self.dev_instance_dic = np.load(opt.dev_instance_path, allow_pickle=True).item()
             self.test_instance_dic = np.load(opt.test_instance_path, allow_pickle=True).item()
@@ -182,7 +183,7 @@ def __init__(self,opt):
         self.pos_number = 0
         self.rel_number = 0

-        #self.build_word_index(self.exemplar_instance_dic)
+        self.build_word_index(self.exemplar_instance_dic)
         self.build_word_index(self.train_instance_dic)
         self.build_word_index(self.dev_instance_dic)
         self.build_word_index(self.test_instance_dic)
@@ -357,7 +358,7 @@ def get_lu_list(path, lu_num, fe_num, frame_id_to_label, fe_mask_list, file='LU.
lu_name_to_id[(lu_name, lu_dt['FrameID'][idx])] = lu_dt['ID'][idx] for key in lu_list.keys(): - #lu_list[key]['lu_mask'][lu_num] = 1 + # lu_list[key]['lu_mask'][lu_num] = 1 lu_list[key]['fe_mask'][fe_num] = 1 return lu_list, lu_id_to_name, lu_name_to_id diff --git a/dataset2.py b/dataset5.py similarity index 81% rename from dataset2.py rename to dataset5.py index 2f29943..8de0da3 100644 --- a/dataset2.py +++ b/dataset5.py @@ -4,7 +4,7 @@ from torch.utils.data import Dataset from config import get_opt -from data_process2 import get_frame_tabel, get_fe_tabel, get_fe_list, get_lu_list,DataConfig +from data_process4 import get_frame_tabel, get_fe_tabel, get_fe_list, get_lu_list,DataConfig from utils import get_mask_from_index class FrameNetDataset(Dataset): @@ -56,21 +56,40 @@ def __init__(self, opt, config, data_dic, device): self.fe_mask = [] self.lu_mask = [] self.token_type_ids = [] + self.target_mask_ids = [] self.device = device self.oov_frame = 0 self.long_span = 0 self.error_span = 0 self.fe_coretype_table = {} + self.target_mask = {} for idx in self.fe_id_to_type.keys(): if self.fe_id_to_type[idx] == 'Core': - self.fe_coretype_table[idx] = 1 + self.fe_coretype_table[self.fe_id_to_label[idx]] = 1 else: - self.fe_coretype_table[idx] = 0 + self.fe_coretype_table[self.fe_id_to_label[idx]] = 0 + + + + for key in self.data_dic.keys(): + self.build_target_mask(key,opt.maxlen) + for key in self.data_dic.keys(): self.pre_process(key, opt) + + self.pad_dic_cnt = self.dataset_len % opt.batch_size + + + for idx,key in enumerate(self.data_dic.keys()): + if idx >= self.pad_dic_cnt: + break + self.pre_process(key, opt,filter=False) + + self.dataset_len+=self.pad_dic_cnt + print('load data finish') print('oov frame = ', self.oov_frame) print('long_span = ', self.long_span) @@ -99,19 +118,24 @@ def __getitem__(self, item): fe_mask = torch.Tensor(self.fe_mask[item]).long().to(self.device) lu_mask = torch.Tensor(self.lu_mask[item]).long().to(self.device) token_type_ids = torch.Tensor(self.token_type_ids[item]).long().to(self.device) + sent_length = torch.Tensor([self.sent_length[item]]).long().to(self.device) + target_mask_ids = torch.Tensor(self.target_mask_ids[item]).long().to(self.device) + # print(fe_cnt) + return (word_ids, lemma_ids, pos_ids,head_ids, rel_ids, lengths, mask, target_head, target_tail, target_type, fe_head, fe_tail, fe_type, fe_cnt, fe_cnt_with_padding, - fe_mask, lu_mask, token_type_ids) + fe_mask, lu_mask, token_type_ids,sent_length,target_mask_ids) - def pre_process(self, key, opt): + def pre_process(self, key, opt,filter=True): if self.data_dic[key]['target_type'] not in self.frame_name_to_label: self.oov_frame += 1 self.dataset_len -= 1 return target_id = self.frame_name_to_id[self.data_dic[key]['target_type']] - self.long_span += self.remove_error_span(key, self.data_dic[key]['span_start'], + if filter: + self.long_span += self.remove_error_span(key, self.data_dic[key]['span_start'], self.data_dic[key]['span_end'], self.data_dic[key]['span_type'], target_id, 20) word_ids = [self.word_index[word] for word in self.data_dic[key]['word_list']] @@ -135,7 +159,9 @@ def pre_process(self, key, opt): self.mask.append(mask) token_type_ids = build_token_type_ids(self.data_dic[key]['target_idx'][0], self.data_dic[key]['target_idx'][1], opt.maxlen) + # token_type_ids +=self.target_mask[key[0]] self.token_type_ids.append(token_type_ids) + self.target_mask_ids.append(self.target_mask[key[0]]) self.target_type.append(self.frame_name_to_label[self.data_dic[key]['target_type']]) @@ -154,8 
+180,7 @@ def pre_process(self, key, opt):
         fe_head = self.data_dic[key]['span_start']
         fe_tail = self.data_dic[key]['span_end']

-        self.fe_cnt.append(min(len(fe_head), opt.fe_padding_num))
-        self.fe_cnt_with_padding.append(min(len(fe_head)+1, opt.fe_padding_num))
+
         while len(fe_head) < opt.fe_padding_num:
             fe_head.append(min(sent_length-1, opt.maxlen-1))
@@ -168,6 +193,9 @@ def pre_process(self, key, opt):
         fe_type = [self.fe_name_to_label[(item, target_id)] for item in self.data_dic[key]['span_type']]

+        self.fe_cnt.append(min(len(fe_type), opt.fe_padding_num))
+        self.fe_cnt_with_padding.append(min(len(fe_type)+1, opt.fe_padding_num))
+
         while len(fe_type) < opt.fe_padding_num:
             fe_type.append(self.fe_num)
             # fe_coretype.append('0')
@@ -181,7 +209,7 @@ def remove_error_span(self, key, fe_head_list, fe_tail_list, fe_type_list, targe
             indices.append(index)
         elif fe_tail_list[index] < fe_head_list[index]:
             indices.append(index)
-            continue
+
         elif (fe_type_list[index], target_id) not in self.fe_name_to_label:
             indices.append(index)
@@ -196,6 +224,9 @@
                     elif fe_tail_list[index] >= fe_head_list[i] and fe_tail_list[index] <= fe_tail_list[i]:
                         indices.append(index)
                         break
+                    elif fe_head_list[index] <= fe_head_list[i] and fe_tail_list[index] >= fe_tail_list[i]:
+                        indices.append(index)
+                        break
                     else:
                         continue
@@ -208,6 +239,17 @@

         return len(indices)

+    def build_target_mask(self,key,maxlen):
+        self.target_mask.setdefault(key[0], [0]*maxlen)
+
+        target_head = self.data_dic[key]['target_idx'][0]
+        target_tail = self.data_dic[key]['target_idx'][1]
+        self.target_mask[key[0]][target_head] = 1
+        self.target_mask[key[0]][target_tail] = 1
+
+
+
 def build_token_type_ids(target_head, target_tail, maxlen):
     token_type_ids = [0]*maxlen
diff --git a/evaluate2.py b/evaluate2.py
deleted file mode 100644
index 69ec0da..0000000
--- a/evaluate2.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import torch
-
-class Eval():
-    def __init__(self,opt):
-        self.fe_TP = 0
-        self.fe_TP_FP = 0
-        self.fe_TP_FN = 0
-
-        self.frame_cnt = 0
-        self.frame_acc = 0
-
-        self.opt=opt
-    def metrics(self,batch_size:int,fe_cnt:torch.Tensor,\
-        gold_fe_type:torch.Tensor,gold_fe_head:torch.Tensor,\
-        gold_fe_tail:torch.Tensor,gold_frame_type:torch.Tensor,\
-        pred_fe_type:torch.Tensor,pred_fe_head:torch.Tensor,\
-        pred_fe_tail:torch.Tensor,pred_frame_type:torch.Tensor,
-        ):
-
-        self.frame_cnt+=batch_size
-        for batch_index in range(batch_size):
-            #caculate frame acc
-            # print(gold_frame_type[batch_index])
-            # print(pred_frame_type[batch_index])
-            if gold_frame_type[batch_index] == pred_frame_type[0][batch_index]:
-                self.frame_acc += 1
-
-            #update tp_fn
-            self.fe_TP_FN+=int(fe_cnt[batch_index])
-            # preprocess error
-            for idx in range(fe_cnt[batch_index]):
-                if gold_fe_head[batch_index][idx] > gold_fe_tail[batch_index][idx]:
-                    self.fe_TP_FN -= 1
-
-
-            #gold_fe_list = gold_fe_type.cpu().numpy().tolist()
-            gold_tail_list =gold_fe_tail.cpu().numpy().tolist()
-            #update fe_tp and fe_TP_FP
-            for fe_index in range(self.opt.fe_padding_num):
-                if pred_fe_type[fe_index][batch_index] == self.opt.role_number:
-                    break
-
-
-                #update fe_tp
-                if pred_fe_tail[fe_index][batch_index] in gold_tail_list[batch_index]:
-                    idx = gold_tail_list[batch_index].index(pred_fe_tail[fe_index][batch_index])
-
-
-                    if pred_fe_head[fe_index][batch_index]==gold_fe_head[batch_index][idx] and \
-                        pred_fe_type[fe_index][batch_index] == gold_fe_type[batch_index][idx]:
-                        self.fe_TP+=1
-
-                #update fe_tp_fp
-                self.fe_TP_FP+=1
-
-    def calculate(self):
-        frame_acc = self.frame_acc / self.frame_cnt
-        fe_prec = self.fe_TP / (self.fe_TP_FP+0.000001)
-        fe_recall = float(self.fe_TP / self.fe_TP_FN)
-        fe_f1 = 2*fe_prec*fe_recall/(fe_prec+fe_recall+0.0000001)
-
-        full_TP = self.frame_acc+self.fe_TP
-        full_TP_FP = self.frame_cnt + self.fe_TP_FP
-        full_TP_FN = self.frame_cnt +self.fe_TP_FN
-
-        full_prec = float(full_TP / full_TP_FP)
-        full_recall =float(full_TP / full_TP_FN)
-        full_f1 = 2 * full_prec * full_recall / (full_prec + full_recall+0.000001)
-
-        print(" frame acc: %.6f " % frame_acc)
-        print(" fe_prec: %.6f " % fe_prec)
-        print(" fe_recall: %.6f " % fe_recall)
-        print(" fe_f1: %.6f " % fe_f1)
-        print('================full struction=============')
-        print(" full_prec: %.6f " % full_prec)
-        print(" full_recall: %.6f " % full_recall)
-        print(" full_f1: %.6f " % full_f1)
-
-        return (frame_acc,fe_prec,fe_recall,fe_f1,full_prec,full_recall,full_f1)
-
-
-
-
diff --git a/evaluate7.py b/evaluate7.py
new file mode 100644
index 0000000..402b1a8
--- /dev/null
+++ b/evaluate7.py
@@ -0,0 +1,162 @@
+import torch
+import numpy as np
+
+class Eval():
+    def __init__(self,opt,Dataset):
+        self.fe_TP = 0
+        self.fe_TP_FP = 0
+        self.fe_TP_FN = 0
+
+        self.frame_cnt = 0
+        self.frame_acc = 0
+
+        self.core_cnt = 0
+        self.nocore_cnt = 0
+
+        self.decay = 0.8
+        self.opt=opt
+
+        self.fe_id_to_label =Dataset.fe_id_to_label
+        self.fe_label_to_id = {}
+        for key, val in self.fe_id_to_label.items():
+            self.fe_label_to_id[val] = key
+
+        self.frame_id_to_label =Dataset.frame_id_to_label
+        self.frame_label_to_id = {}
+        for key, val in self.frame_id_to_label.items():
+            self.frame_label_to_id[val] = key
+
+        path = 'frame-fe-dist-path/'
+        self.fe_dis_matrix = np.load(path + 'fe_dis_matrix.npy', allow_pickle=True)
+        self.frame_dis_matrix = np.load(path + 'frame_dis_matrix.npy', allow_pickle=True)
+        self.fe_path_matrix = np.load(path + 'fe_path_matrix.npy', allow_pickle=True)
+        self.frame_path_matrix = np.load(path + 'frame_path_matrix.npy', allow_pickle=True)
+        self.fe_hash_index = np.load(path + 'fe_hash_idx.npy', allow_pickle=True).item()
+        self.frame_hash_index = np.load(path + 'frame_hash_idx.npy', allow_pickle=True).item()
+
+    def metrics(self,batch_size:int,fe_cnt:torch.Tensor,\
+        gold_fe_type:torch.Tensor,gold_fe_head:torch.Tensor,\
+        gold_fe_tail:torch.Tensor,gold_frame_type:torch.Tensor,\
+        pred_fe_type:torch.Tensor,pred_fe_head:torch.Tensor,\
+        pred_fe_tail:torch.Tensor,pred_frame_type:torch.Tensor,
+        fe_coretype,sent_length:torch.Tensor):
+
+        self.frame_cnt+=batch_size
+        for batch_index in range(batch_size):
+            # calculate frame acc
+            # print(gold_frame_type[batch_index])
+            # print(pred_frame_type[batch_index])
+            if self.frame_label_to_id[int(gold_frame_type[batch_index])] in self.frame_hash_index.keys() and \
+                self.frame_label_to_id[int(pred_frame_type[0][batch_index])] in self.frame_hash_index.keys():
+                gold_frame_idx = self.frame_hash_index[self.frame_label_to_id[int(gold_frame_type[batch_index])]]
+                pred_frame_idx = self.frame_hash_index[self.frame_label_to_id[int(pred_frame_type[0][batch_index])]]
+                key_error = False
+            else:
+                key_error = True
+
+            if gold_frame_type[batch_index] == pred_frame_type[0][batch_index]:
+                self.frame_acc += 1
+
+            elif key_error is False:
+                if self.frame_dis_matrix[gold_frame_idx][pred_frame_idx]!=-1 and self.frame_dis_matrix[gold_frame_idx][pred_frame_idx] <=9:
+                    self.frame_acc += 0.8**self.frame_dis_matrix[gold_frame_idx][pred_frame_idx]
+
+            # update tp_fn
+            # self.fe_TP_FN+=int(fe_cnt[batch_index])
+            for fe_index in range(fe_cnt[batch_index]):
+                if fe_coretype[int(gold_fe_type[batch_index][fe_index])] == 1:
+                    self.fe_TP_FN+=1
+                    self.core_cnt+=1
+                else:
+                    self.fe_TP_FN+=0.5
+                    self.nocore_cnt+=1
+
+            # preprocess error
+
+            #gold_fe_list = gold_fe_type.cpu().numpy().tolist()
+            gold_tail_list =gold_fe_tail.cpu().numpy().tolist()
+            #update fe_tp and fe_TP_FP
+            for fe_index in range(self.opt.fe_padding_num):
+                # if pred_fe_tail[fe_index][batch_index] == sent_length[batch_index]-1 and \
+                #     pred_fe_head[fe_index][batch_index] == sent_length[batch_index] - 1:
+                if pred_fe_type[fe_index][batch_index] == self.opt.role_number:
+                    break
+
+
+                #update fe_tp
+                if pred_fe_tail[fe_index][batch_index] in gold_tail_list[batch_index]:
+                    idx = gold_tail_list[batch_index].index(pred_fe_tail[fe_index][batch_index])
+
+
+
+                    if pred_fe_head[fe_index][batch_index]==gold_fe_head[batch_index][idx] and \
+                        pred_fe_type[fe_index][batch_index] == gold_fe_type[batch_index][idx]:
+
+                        if fe_coretype[int(gold_fe_type[batch_index][idx])] == 1:
+                            self.fe_TP+= 1
+                        else:
+                            self.fe_TP+= 0.5
+
+
+                    elif gold_fe_type[batch_index][idx]==self.opt.role_number:
+                        pass
+
+                    elif key_error is True:
+                        pass
+
+                    elif pred_fe_head[fe_index][batch_index]==gold_fe_head[batch_index][idx] and \
+                        self.fe_label_to_id[int(gold_fe_type[batch_index][idx])] in self.fe_hash_index.keys() \
+                        and self.fe_label_to_id[int(pred_fe_type[fe_index][batch_index])] in self.fe_hash_index.keys():
+
+                        gold_fe_idx = self.fe_hash_index[self.fe_label_to_id[int(gold_fe_type[batch_index][idx])]]
+                        pred_fe_idx = self.fe_hash_index[self.fe_label_to_id[int(pred_fe_type[fe_index][batch_index])]]
+
+                        if self.fe_dis_matrix[gold_fe_idx][pred_fe_idx] != -1 and self.fe_dis_matrix[gold_fe_idx][pred_fe_idx] < 9:
+                            rate = float(self.fe_path_matrix[gold_fe_idx][pred_fe_idx]/(self.frame_path_matrix[gold_frame_idx][pred_frame_idx]+0.000001))
+                            if rate > 1:
+                                rate = 1
+                            if fe_coretype[int(gold_fe_type[batch_index][idx])] == 1:
+                                self.fe_TP += 1*rate*(0.8**self.fe_dis_matrix[gold_fe_idx][pred_fe_idx])
+                            else:
+                                self.fe_TP += 0.5*rate*(0.8**self.fe_dis_matrix[gold_fe_idx][pred_fe_idx])
+
+                #update fe_tp_fp
+                if fe_coretype[int(pred_fe_type[fe_index][batch_index])] == 1:
+                    self.fe_TP_FP += 1
+                else:
+                    self.fe_TP_FP += 0.5
+
+
+    def calculate(self):
+        frame_acc = self.frame_acc / self.frame_cnt
+        fe_prec = self.fe_TP / (self.fe_TP_FP+0.000001)
+        fe_recall = float(self.fe_TP / self.fe_TP_FN)
+        fe_f1 = 2*fe_prec*fe_recall/(fe_prec+fe_recall+0.0000001)
+
+        full_TP = self.frame_acc+self.fe_TP
+        full_TP_FP = self.frame_cnt + self.fe_TP_FP
+        full_TP_FN = self.frame_cnt +self.fe_TP_FN
+
+        full_prec = float(full_TP / full_TP_FP)
+        full_recall =float(full_TP / full_TP_FN)
+        full_f1 = 2 * full_prec * full_recall / (full_prec + full_recall+0.000001)
+
+        print(" frame acc: %.6f " % frame_acc)
+        print(" fe_prec: %.6f " % fe_prec)
+        print(" fe_recall: %.6f " % fe_recall)
+        print(" fe_f1: %.6f " % fe_f1)
+        print('================full structure=============')
+        print(" full_prec: %.6f " % full_prec)
+        print(" full_recall: %.6f " % full_recall)
+        print(" full_f1: %.6f " % full_f1)
+        print('================detail=============')
+        print(" fe_TP: %.6f " % self.fe_TP)
+        print(" fe_TP_FN: %.6f " % self.fe_TP_FN)
+        print(" fe_TP_FP: %.6f " % self.fe_TP_FP)
+        print(" core_cnt: %.6f " % self.core_cnt)
+        print(" nocore_cnt: %.6f " % self.nocore_cnt)
+
+        return (frame_acc,fe_prec,fe_recall,fe_f1,full_prec,full_recall,full_f1)
+
+
+
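Taken together, the scoring rules in evaluate7.py reduce to a small amount of arithmetic. The standalone sketch below is illustrative only (names are invented, and the frame/FE path-overlap `rate` factor that the full metric multiplies in is omitted): core FEs count 1.0, non-core FEs count 0.5, and a near-miss in the FE hierarchy earns exponentially decayed partial credit.

def soft_fe_tp(matches, is_core, fe_dist):
    # matches: (gold, pred) FE-label pairs whose span boundaries already agree
    # is_core: dict label -> True for core FEs
    # fe_dist: dict (gold, pred) -> hop count in the FE hierarchy, or -1 if unreachable
    tp = 0.0
    for gold, pred in matches:
        weight = 1.0 if is_core[gold] else 0.5      # non-core FEs count half
        if gold == pred:
            tp += weight                            # exact match: full credit
        else:
            d = fe_dist.get((gold, pred), -1)
            if d != -1 and d < 9:                   # related label: decayed credit
                tp += weight * 0.8 ** d
    return tp

With fe_dist[('A', 'B')] = 2, say, a core-FE confusion between the two labels would still earn 0.8**2 = 0.64 of a true positive.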
diff --git a/fn-model_bert/config_bert.py b/fn-model_bert/config_bert.py
new file mode 100644
index 0000000..54c4c6a
--- /dev/null
+++ b/fn-model_bert/config_bert.py
@@ -0,0 +1,73 @@
+
+import argparse
+import os
+
+data_dir ='./'
+
+
+class PLMConfig:
+    MODEL_PATH = 'uncased_L-12_H-768_A-12'
+    VOCAB_PATH = f'{MODEL_PATH}/vocab.txt'
+    CONFIG_PATH = f'{MODEL_PATH}/bert_config.json'
+
+
+def get_opt():
+    parser = argparse.ArgumentParser()
+
+    # dataset paths
+    parser.add_argument('--data_path', type=str, default='parsed-v1.5/')
+    parser.add_argument('--emb_file_path', type=str, default='./glove.6B/glove.6B.200d.txt')
+    parser.add_argument('--exemplar_instance_path', type=str, default='./exemplar_instance_dic_bert.npy')
+    parser.add_argument('--train_instance_path', type=str, default='./train_instance_dic_bert.npy')
+    parser.add_argument('--dev_instance_path', type=str, default='./dev_instance_dic_bert.npy')
+    parser.add_argument('--test_instance_path', type=str, default='./test_instance_dic_bert.npy')
+    # saving / loading models
+    parser.add_argument('--load_instance_dic', type=bool, default=True)
+    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
+    parser.add_argument('--model_name', type=str, default='train_model')
+    parser.add_argument('--pretrain_model', type=str, default='')
+    parser.add_argument('--save_model_path', type=str, default='./models_ft/final_bert_3.bin')
+
+    # training
+    parser.add_argument('--lr', type=float, default='6e-6')
+    parser.add_argument('--weight_decay', type=float, default=0.0001)
+    parser.add_argument('--batch_size', type=int, default=2)
+    parser.add_argument('--epochs', type=int, default=100)
+    parser.add_argument('--save_model_freq', type=int, default=1)  # checkpoint interval, in epochs
+    parser.add_argument('--cuda', type=str, default="cuda:0")
+    parser.add_argument('--mode', type=str, default="test")
+
+    # model settings
+    parser.add_argument('--maxlen',type=int, default=256)
+    parser.add_argument('--dropout', type=float, default=0.5)
+    parser.add_argument('--cell_name', type=str, default='lstm')
+    #change
+    parser.add_argument('--rnn_hidden_size',type=int, default=256)
+    parser.add_argument('--rnn_emb_size', type=int, default=400)
+    parser.add_argument('--encoder_emb_size',type=int, default=768)
+    parser.add_argument('--sent_emb_size', type=int, default=200)
+    parser.add_argument('--pos_emb_size',type=int, default=64)
+    parser.add_argument('--rel_emb_size',type=int,default=100)
+    parser.add_argument('--token_type_emb_size',type=int, default=32)
+    parser.add_argument('--decoder_emb_size',type=int, default=200)
+    parser.add_argument('--decoder_hidden_size',type=int, default=256)
+    parser.add_argument('--target_maxlen',type=int, default=5)
+    parser.add_argument('--decoderlen',type=int,default=4)
+    parser.add_argument('--kernel_size', type=int, default=3)
+    parser.add_argument('--frame_number',type=int, default=1019)
+    parser.add_argument('--role_number',type=int, default=9634)
+    parser.add_argument('--fe_padding_num',type=int, default=5)
+    parser.add_argument('--window_size', type=int, default=3)
+    parser.add_argument('--num_layers', type=int, default=2)
+
+    return parser.parse_args()
+
+
+
+
+
+
+
+if __name__ == '__main__':
+    bertconfig = PLMConfig()
+    print(bertconfig.CONFIG_PATH)
diff --git a/fn-model_bert/data_process_bert.py b/fn-model_bert/data_process_bert.py
new file mode 100644
index 0000000..282bc76
--- /dev/null
+++ b/fn-model_bert/data_process_bert.py
@@ -0,0 +1,430 @@
+import numpy as np
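# How the wordpiece re-alignment in instance_process below works (illustrative
# sketch; the token and indices are made up): BertTokenizer splits words into
# wordpieces, so CoNLL token indices no longer address the BERT sequence.
# Each span boundary is therefore re-located by searching for the first
# wordpiece of its start token from the old index onward, e.g.
#
#   pieces = bert_tokenizer.tokenize('unaffable')   # ['un', '##aff', '##able']
#   start = word_list.index(pieces[0], span_start[i])
#
# The second argument of list.index keeps a repeated token from matching an
# earlier occurrence in the sentence.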
+import os +import pandas as pd +import csv +import numpy +import torch +import json +import re + +from config_bert import get_opt,PLMConfig +from pytorch_pretrained_bert import BertTokenizer +from bert_slot_tokenizer import SlotConverter +from nltk import word_tokenize, pos_tag +from nltk.corpus import wordnet +from nltk.stem import WordNetLemmatizer + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +bert_tokenizer=BertTokenizer.from_pretrained(PLMConfig.MODEL_PATH) + + + +def load_data(path): + f = open(path, 'r', encoding='utf-8') + lines = f.readlines() + return lines + + +def instance_process(lines, maxlen): + instance_dic = {} + + cnt = 0 + find = False + word_list_total = [] + for line in lines: + if line[0:3] == '# i': + word_list = [] + lemma_list = [] + pos_list = [] + target_idx = [-1, -1] + span_start = [] + span_start_token = [] + span_end = [] + span_end_token = [] + span_type = [] + length = 0 + sent ='[CLS] ' + + elif line[0:3] == '# e': + instance_dic.setdefault((sent_id, target_type, cnt), {}) + #instance_dic[(sent_id, target_type, cnt)]['dep_list'] = dep_parsing(word_list, maxlen, parser) + word_list = bert_tokenizer.tokenize(sent) + + instance_dic[(sent_id, target_type, cnt)]['length'] = min(len(word_list),maxlen) + + attention_mask = [0]*maxlen + attention_mask[0:min(len(word_list),maxlen)]=[1] * min(len(word_list),maxlen) + + instance_dic[(sent_id, target_type, cnt)]['attention_mask']=attention_mask + + word_list_padded=padding_sentence(word_list, maxlen) + instance_dic[(sent_id, target_type, cnt)]['word_list'] = word_list_padded + instance_dic[(sent_id, target_type, cnt)]['tokenized_ids'] = bert_tokenizer.convert_tokens_to_ids(word_list_padded) + # instance_dic[(sent_id, target_type, cnt)]['lemma_list'] = padding_sentence(lemma_list, maxlen) + # instance_dic[(sent_id, target_type, cnt)]['pos_list'] = padding_sentence(pos_list, maxlen) + instance_dic[(sent_id, target_type, cnt)]['sent_id'] = sent_id + + word_list_total.append(word_list) + # add 'eos' + # instance_dic[(sent_id, target_type, cnt)]['attention_mask'] = get_mask_from_index(sequence_lengths=torch.Tensor([int(length)+1]), max_length=maxlen).squeeze() + + instance_dic[(sent_id, target_type, cnt)]['target_type'] = target_type + instance_dic[(sent_id, target_type, cnt)]['lu'] = lu + + target_start_list=bert_tokenizer.tokenize(target_start_token) + target_end_list = bert_tokenizer.tokenize(target_end_token) + # print(target_start_list) + # print(target_end_list) + # print(target_start_list[0]) + # print(target_end_list[-1]) + # print(word_list) + target_start=word_list.index(target_start_list[0],int(target_idx[0])) + target_end=word_list.index(target_end_list[-1],int(target_idx[1])) + instance_dic[(sent_id, target_type, cnt)]['target_idx'] = (target_start,target_end) + + instance_dic[(sent_id, target_type, cnt)]['span_start']=[] + instance_dic[(sent_id, target_type, cnt)]['span_end']=[] + for i in range(len(span_start)): + span_start_list = bert_tokenizer.tokenize(span_start_token[i]) + span_end_list = bert_tokenizer.tokenize(span_end_token[i]) + span_start_idx = word_list.index(span_start_list[0], span_start[i]) + span_end_idx = word_list.index(span_end_list[-1], span_end[i]) + + + instance_dic[(sent_id, target_type, cnt)]['span_start'].append(span_start_idx) + instance_dic[(sent_id, target_type, cnt)]['span_end'].append(span_end_idx) + instance_dic[(sent_id, target_type, cnt)]['span_type'] = span_type + + print(cnt) + cnt += 1 + elif line == '\n': + sent += '[SEP]' + continue + + 
else: + data_list = line.split('\t') + word_list.append(data_list[1]) + # lemma_list.append(data_list[3]) + # pos_list.append(data_list[5]) + sent_id = data_list[6] + # length = data_list[0] + sent+=data_list[1]+' ' + + if data_list[12] != '_' and data_list[13] != '_': + lu = data_list[12] + + target_type = data_list[13] + if target_idx == [-1, -1]: + target_idx = [int(data_list[0])-1, int(data_list[0])-1] + target_start_token = data_list[1] + target_end_token = data_list[1] + else: + target_idx[1] =int(data_list[0]) - 1 + target_end_token = data_list[1] + + if data_list[14] != '_': + + fe = data_list[14].split('-') + + if fe[0] == 'B' and find is False: + span_start.append(int(data_list[0]) - 1) + span_start_token.append(data_list[1]) + find = True + + elif fe[0] == 'O': + span_end.append(int(data_list[0]) - 1) + span_end_token.append(data_list[1]) + span_type.append(fe[-1].replace('\n', '')) + find = False + + elif fe[0] == 'S': + span_start.append(int(data_list[0]) - 1) + span_start_token.append(data_list[1]) + span_end.append(int(data_list[0]) - 1) + span_end_token.append(data_list[1]) + span_type.append(fe[-1].replace('\n', '')) + + return instance_dic + +# def dep_parsing(word_list: list,maxlen: list,parser): +# res = list(parser.parse(word_list)) +# sent = res[0].to_conll(4).split('\n')[:-1] +# #['the', 'DT', '4', 'det'] +# line = [line.split('\t') for line in sent] +# head_list = [] +# rel_list = [] +# +# distance = 0 +# +# #alignment +# for index in range(len(word_list)-1): +# #end stopwords +# if index-distance >= len(line): +# head_list.append('#') +# rel_list.append('#') +# distance+=1 +# +# elif word_list[index]!=line[index-distance][0]: +# head_list.append('#') +# rel_list.append('#') +# distance+=1 +# else: +# rel_list.append(line[index-distance][3]) +# +# if line[index-distance][3] != 'root': +# head_list.append(word_list[int(line[index-distance][2]) - 1]) +# else: +# head_list.append(word_list[index]) +# +# head_list.append('eos') +# rel_list.append('eos') +# +# while len(head_list) < maxlen: +# head_list.append('0') +# rel_list.append('0') +# +# return (head_list,rel_list) + + +def padding_sentence(sentence: list,maxlen: int): + + while len(sentence) < maxlen: + sentence.append('0') + + return sentence[0:maxlen] + +class DataConfig: + def __init__(self,opt): + exemplar_lines = load_data('fn1.5/conll/exemplar') + train_lines = load_data('fn1.5/conll/train') + dev_lines = load_data('fn1.5/conll/dev') + test_lines = load_data('fn1.5/conll/test') + + + self.emb_file_path = opt.emb_file_path + self.maxlen = opt.maxlen + + if opt.load_instance_dic: + self.exemplar_instance_dic = np.load(opt.exemplar_instance_path, allow_pickle=True).item() + self.train_instance_dic = np.load(opt.train_instance_path, allow_pickle=True).item() + self.dev_instance_dic = np.load(opt.dev_instance_path, allow_pickle=True).item() + self.test_instance_dic = np.load(opt.test_instance_path, allow_pickle=True).item() + + else: + print('begin parsing') + self.exemplar_instance_dic = instance_process(lines=exemplar_lines,maxlen=self.maxlen) + np.save('exemplar_instance_dic_bert', self.exemplar_instance_dic) + print('exemplar_instance_dic_bert finish') + + self.train_instance_dic = instance_process(lines=train_lines,maxlen=self.maxlen) + np.save('train_instance_dic_bert', self.train_instance_dic) + print('train_instance_dic finish') + + self.dev_instance_dic = instance_process(lines=dev_lines,maxlen=self.maxlen) + np.save('dev_instance_dic_bert', self.dev_instance_dic) + print('dev_instance_dic finish') + + 
self.test_instance_dic = instance_process(lines=test_lines,maxlen=self.maxlen)
+            np.save('test_instance_dic_bert', self.test_instance_dic)
+            print('test_instance_dic finish')
+
+
+        # self.word_index = {}
+        # self.lemma_index = {}
+        # self.pos_index = {}
+        # self.rel_index = {}
+        #
+        # self.word_number = 0
+        # self.lemma_number = 0
+        # self.pos_number = 0
+        # self.rel_number = 0
+        #
+        # self.build_word_index(self.exemplar_instance_dic)
+        # self.build_word_index(self.train_instance_dic)
+        # self.build_word_index(self.dev_instance_dic)
+        # self.build_word_index(self.test_instance_dic)
+        #
+        # # add # for parsing sign
+        # self.word_index['#']=self.word_number
+        # self.word_number+=1
+        #
+        # self.emb_index = self.build_emb_index(self.emb_file_path)
+        #
+        # self.word_vectors = self.get_embedding_weight(self.emb_index, self.word_index, self.word_number)
+        # self.lemma_vectors = self.get_embedding_weight(self.emb_index, self.lemma_index, self.lemma_number)
+
+    def build_word_index(self, dic):
+        for key in dic.keys():
+            word_list =dic[key]['word_list']
+            lemma_list = dic[key]['lemma_list']
+            pos_list = dic[key]['pos_list']
+            rel_list = dic[key]['dep_list'][1]
+
+            # print(row)
+            for word in word_list:
+                if word not in self.word_index.keys():
+                    self.word_index[word]=self.word_number
+                    self.word_number += 1
+
+            for lemma in lemma_list:
+                if lemma not in self.lemma_index.keys():
+                    self.lemma_index[lemma]=self.lemma_number
+                    self.lemma_number += 1
+
+            for pos in pos_list:
+                if pos not in self.pos_index.keys():
+                    self.pos_index[pos] = self.pos_number
+                    self.pos_number += 1
+
+            for rel in rel_list:
+                if rel not in self.rel_index.keys():
+                    self.rel_index[rel] = self.rel_number
+                    self.rel_number += 1
+
+    def build_emb_index(self, file_path):
+        data = open(file_path, 'r', encoding='utf-8')
+        emb_index = {}
+        for items in data:
+            item = items.split()
+            word = item[0]
+            weight = np.asarray(item[1:], dtype='float32')
+            emb_index[word] = weight
+
+        return emb_index
+
+    def get_embedding_weight(self,embed_dict, words_dict, words_count, dim=200):
+
+        exact_count = 0
+        fuzzy_count = 0
+        oov_count = 0
+        print("loading pre-trained embeddings; averaging found vectors for OOV words.")
+        embeddings = np.zeros((int(words_count) + 1, int(dim)))
+        inword_list = {}
+        for word in words_dict:
+            if word in embed_dict:
+                embeddings[words_dict[word]] = embed_dict[word]
+                inword_list[words_dict[word]] = 1
+                # exact match
+                exact_count += 1
+            elif word.lower() in embed_dict:
+                embeddings[words_dict[word]] = embed_dict[word.lower()]
+                inword_list[words_dict[word]] = 1
+                # fuzzy (lowercased) match
+                fuzzy_count += 1
+            else:
+                # out-of-vocabulary word
+                oov_count += 1
+                # print(word)
+        # average the vectors found so far; used for OOV words
+        sum_col = np.sum(embeddings, axis=0) / len(inword_list)  # avg
+        sum_col /= np.std(sum_col)
+        for i in range(words_count):
+            if i not in inword_list:
+                embeddings[i] = sum_col
+
+        embeddings[int(words_count)] = [0] * dim
+        final_embed = np.array(embeddings)
+        # print('exact_count: ',exact_count)
+        # print('fuzzy_count: ', fuzzy_count)
+        # print('oov_count: ', oov_count)
+        return final_embed
+
+
+def load_data_pd(dataset_path,file):
+    # df=csv.reader(open(dataset_path+file,encoding='utf-8'))
+    # df = json.load(open(file_path,encoding='utf-8'))
+    df=pd.read_csv(dataset_path+file, header=0, encoding='utf-8')
+    return df
+
+
+def get_frame_tabel(path, file):
+    data = load_data_pd(path, file)
+
+    frame_id_to_label = {}
+    frame_name_to_label = {}
+    frame_name_to_id = {}
+    data_index = 0
+    for idx in range(len(data['ID'])):
+        if data['ID'][idx] not in frame_id_to_label:
frame_id_to_label[data['ID'][idx]] = data_index + frame_name_to_label[data['Name'][idx]] = data_index + frame_name_to_id[data['Name'][idx]] = data['ID'][idx] + + data_index += 1 + + return frame_id_to_label, frame_name_to_label, frame_name_to_id + + +def get_fe_tabel(path, file): + data = load_data_pd(path, file) + + fe_id_to_label = {} + fe_name_to_label = {} + fe_name_to_id = {} + fe_id_to_type = {} + + data_index = 0 + for idx in range(len(data['ID'])): + if data['ID'][idx] not in fe_id_to_label: + fe_id_to_label[data['ID'][idx]] = data_index + fe_name_to_label[(data['Name'][idx], data['FrameID'][idx])] = data_index + fe_name_to_id[(data['Name'][idx], data['FrameID'][idx])] = data['ID'][idx] + fe_id_to_type[data['ID'][idx]] = data['CoreType'][idx] + + data_index += 1 + + return fe_id_to_label, fe_name_to_label, fe_name_to_id, fe_id_to_type + + +def get_fe_list(path, fe_num, fe_table, file='FE.csv'): + fe_dt = load_data_pd(path, file) + fe_mask_list = {} + + print('begin get fe list') + for idx in range(len(fe_dt['FrameID'])): + fe_mask_list.setdefault(fe_dt['FrameID'][idx], [0]*(fe_num+1)) + # fe_mask_list[fe_dt['FrameID'][idx]].setdefault('fe_mask', [0]*(fe_num+1)) + fe_mask_list[fe_dt['FrameID'][idx]][fe_table[fe_dt['ID'][idx]]] = 1 + + # for key in fe_list.keys(): + # fe_list[key]['fe_mask'][fe_num] = 1 + + return fe_mask_list + + +def get_lu_list(path, lu_num, fe_num, frame_id_to_label, fe_mask_list, file='LU.csv'): + lu_dt = load_data_pd(path, file) + lu_list = {} + lu_id_to_name = {} + lu_name_to_id = {} + #lu_name_to_felist = {} + + for idx in range(len(lu_dt['ID'])): + lu_name = lu_dt['Name'][idx] + lu_list.setdefault(lu_name, {}) + + lu_list[lu_name].setdefault('fe_mask', [0]*(fe_num+1)) + lu_list[lu_name]['fe_mask'] = list(map(lambda x: x[0]+x[1], zip(lu_list[lu_name]['fe_mask'], + fe_mask_list[lu_dt['FrameID'][idx]]))) + + lu_list[lu_name].setdefault('lu_mask', [0]*(lu_num+1)) + lu_list[lu_name]['lu_mask'][frame_id_to_label[lu_dt['FrameID'][idx]]] = 1 + + lu_id_to_name[lu_dt['ID'][idx]] = lu_name + lu_name_to_id[(lu_name, lu_dt['FrameID'][idx])] = lu_dt['ID'][idx] + + for key in lu_list.keys(): + # lu_list[key]['lu_mask'][lu_num] = 1 + lu_list[key]['fe_mask'][fe_num] = 1 + + return lu_list, lu_id_to_name, lu_name_to_id + + +if __name__ == '__main__': + opt = get_opt() + config = DataConfig(opt) + + + + diff --git a/fn-model_bert/dataset_bert.py b/fn-model_bert/dataset_bert.py new file mode 100644 index 0000000..2c4c297 --- /dev/null +++ b/fn-model_bert/dataset_bert.py @@ -0,0 +1,271 @@ + +import torch +import numpy as np + +from torch.utils.data import Dataset +from config_bert import get_opt +from data_process_bert import get_frame_tabel, get_fe_tabel, get_fe_list, get_lu_list,DataConfig +from utils import get_mask_from_index + +class FrameNetDataset(Dataset): + + def __init__(self, opt, config, data_dic, device): + super(FrameNetDataset, self).__init__() + print('begin load data') + self.data_dic = data_dic + self.fe_id_to_label, self.fe_name_to_label, self.fe_name_to_id, self.fe_id_to_type = get_fe_tabel('parsed-v1.5/', 'FE.csv') + self.frame_id_to_label, self.frame_name_to_label, self.frame_name_to_id = get_frame_tabel('parsed-v1.5/', 'frame.csv') + + # self.word_index = config.word_index + # self.lemma_index = config.lemma_index + # self.pos_index = config.pos_index + # self.rel_index = config.rel_index + + self.fe_num = len(self.fe_id_to_label) + self.frame_num = len(self.frame_id_to_label) + self.batch_size = opt.batch_size + print(self.fe_num) + print(self.frame_num) 
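# A sketch of how the two mask tables built above are meant to be read (the
# LU key 'run.v' is made up for illustration): get_lu_list gives every lexical
# unit a multi-hot vector over the frame labels it can evoke, plus a multi-hot
# vector over all FE labels belonging to those frames, e.g.
#
#   lu_mask = self.lu_list['run.v']['lu_mask']   # length frame_num + 1
#   fe_mask = self.lu_list['run.v']['fe_mask']   # length fe_num + 1
#
# The decoder later fills -inf wherever the mask is zero, so frame and role
# predictions are restricted to labels licensed by the lexicon.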
+ self.dataset_len = len(self.data_dic) + + self.fe_mask_list = get_fe_list('parsed-v1.5/', self.fe_num, self.fe_id_to_label) + self.lu_list, self.lu_id_to_name,\ + self.lu_name_to_id = get_lu_list('parsed-v1.5/', + self.frame_num, self.fe_num, + self.frame_id_to_label, + self.fe_mask_list) + + self.word_ids = [] + # self.lemma_ids = [] + # self.pos_ids = [] + # self.head_ids = [] + # self.rel_ids = [] + + self.lengths = [] + self.mask = [] + self.target_head = [] + self.target_tail = [] + self.target_type = [] + self.fe_head = [] + self.fe_tail = [] + self.fe_type = [] + self.fe_coretype = [] + self.sent_length = [] + self.fe_cnt = [] + self.fe_cnt_with_padding =[] + self.fe_mask = [] + self.lu_mask = [] + self.token_type_ids = [] + self.target_mask_ids = [] + + self.device = device + self.oov_frame = 0 + self.long_span = 0 + self.error_span = 0 + self.fe_coretype_table = {} + self.target_mask = {} + + for idx in self.fe_id_to_type.keys(): + if self.fe_id_to_type[idx] == 'Core': + self.fe_coretype_table[self.fe_id_to_label[idx]] = 1 + else: + self.fe_coretype_table[self.fe_id_to_label[idx]] = 0 + + + + for key in self.data_dic.keys(): + self.build_target_mask(key,opt.maxlen) + + + for key in self.data_dic.keys(): + self.pre_process(key, opt) + + self.pad_dic_cnt = self.dataset_len % opt.batch_size + + + for idx,key in enumerate(self.data_dic.keys()): + if idx >= self.pad_dic_cnt: + break + self.pre_process(key, opt,filter=False) + + self.dataset_len+=self.pad_dic_cnt + + print('load data finish') + print('oov frame = ', self.oov_frame) + print('long_span = ', self.long_span) + print('dataset_len = ', self.dataset_len) + + def __len__(self): + self.dataset_len = int(self.dataset_len / self.batch_size) * self.batch_size + return self.dataset_len + + def __getitem__(self, item): + word_ids = torch.Tensor(self.word_ids[item]).long().to(self.device) + # lemma_ids = torch.Tensor(self.lemma_ids[item]).long().to(self.device) + # pos_ids = torch.Tensor(self.pos_ids[item]).long().to(self.device) + # head_ids = torch.Tensor(self.head_ids[item]).long().to(self.device) + # rel_ids = torch.Tensor(self.rel_ids[item]).long().to(self.device) + lengths = torch.Tensor([self.lengths[item]]).long().to(self.device) + mask = torch.Tensor(self.mask[item]).long().to(self.device) + target_head = torch.Tensor([self.target_head[item]]).long().to(self.device) + target_tail = torch.Tensor([self.target_tail[item]]).long().to(self.device) + target_type = torch.Tensor([self.target_type[item]]).long().to(self.device) + fe_head = torch.Tensor(self.fe_head[item]).long().to(self.device) + fe_tail = torch.Tensor(self.fe_tail[item]).long().to(self.device) + fe_type = torch.Tensor(self.fe_type[item]).long().to(self.device) + fe_cnt = torch.Tensor([self.fe_cnt[item]]).long().to(self.device) + fe_cnt_with_padding = torch.Tensor([self.fe_cnt_with_padding[item]]).long().to(self.device) + fe_mask = torch.Tensor(self.fe_mask[item]).long().to(self.device) + lu_mask = torch.Tensor(self.lu_mask[item]).long().to(self.device) + token_type_ids = torch.Tensor(self.token_type_ids[item]).long().to(self.device) + sent_length = torch.Tensor([self.sent_length[item]]).long().to(self.device) + target_mask_ids = torch.Tensor(self.target_mask_ids[item]).long().to(self.device) + # print(fe_cnt) + + + return (word_ids, lengths, mask, target_head, target_tail, target_type, + fe_head, fe_tail, fe_type, fe_cnt, fe_cnt_with_padding, + fe_mask, lu_mask, token_type_ids,sent_length,target_mask_ids) + + def pre_process(self, key, opt,filter=True): + if 
self.data_dic[key]['target_type'] not in self.frame_name_to_label: + self.oov_frame += 1 + self.dataset_len -= 1 + return + + target_id = self.frame_name_to_id[self.data_dic[key]['target_type']] + if filter: + self.long_span += self.remove_error_span(key, self.data_dic[key]['span_start'], + self.data_dic[key]['span_end'], self.data_dic[key]['span_type'], target_id, 20) + + # word_ids = [self.word_index[word] for word in self.data_dic[key]['word_list']] + # lemma_ids = [self.lemma_index[lemma] for lemma in self.data_dic[key]['lemma_list']] + # pos_ids = [self.pos_index[pos] for pos in self.data_dic[key]['pos_list']] + # head_ids = [self.word_index[head] for head in self.data_dic[key]['dep_list'][0]] + # rel_ids = [self.rel_index[rel] for rel in self.data_dic[key]['dep_list'][1]] + + self.word_ids.append(self.data_dic[key]['tokenized_ids']) + # self.lemma_ids.append(lemma_ids) + # self.pos_ids.append(pos_ids) + # self.head_ids.append(head_ids) + # self.rel_ids.append(rel_ids) + self.lengths.append(self.data_dic[key]['length']) + + self.mask.append(self.data_dic[key]['attention_mask']) + self.target_head.append(self.data_dic[key]['target_idx'][0]) + self.target_tail.append(self.data_dic[key]['target_idx'][1]) + + # mask = get_mask_from_index(torch.Tensor([int(self.data_dic[key]['length'])]), opt.maxlen).squeeze() + # self.mask.append(mask) + + token_type_ids = build_token_type_ids(self.data_dic[key]['target_idx'][0], self.data_dic[key]['target_idx'][1], opt.maxlen) + # token_type_ids +=self.target_mask[key[0]] + self.token_type_ids.append(token_type_ids) + self.target_mask_ids.append(self.target_mask[key[0]]) + + self.target_type.append(self.frame_name_to_label[self.data_dic[key]['target_type']]) + + # print(self.frame_tabel[self.fe_data[key]['frame_ID']]) + + if self.data_dic[key]['length'] <= opt.maxlen: + sent_length = self.data_dic[key]['length'] + else: + sent_length = opt.maxlen + self.sent_length.append(sent_length) + + lu_name = self.data_dic[key]['lu'] + self.lu_mask.append(self.lu_list[lu_name]['lu_mask']) + self.fe_mask.append(self.lu_list[lu_name]['fe_mask']) + + fe_head = self.data_dic[key]['span_start'] + fe_tail = self.data_dic[key]['span_end'] + + + + while len(fe_head) < opt.fe_padding_num: + fe_head.append(min(sent_length-1, opt.maxlen-1)) + + while len(fe_tail) < opt.fe_padding_num: + fe_tail.append(min(sent_length-1,opt.maxlen-1)) + + self.fe_head.append(fe_head[0:opt.fe_padding_num]) + self.fe_tail.append(fe_tail[0:opt.fe_padding_num]) + + fe_type = [self.fe_name_to_label[(item, target_id)] for item in self.data_dic[key]['span_type']] + + self.fe_cnt.append(min(len(fe_type), opt.fe_padding_num)) + self.fe_cnt_with_padding.append(min(len(fe_type)+1, opt.fe_padding_num)) + + while len(fe_type) < opt.fe_padding_num: + fe_type.append(self.fe_num) + # fe_coretype.append('0') + + self.fe_type.append(fe_type[0:opt.fe_padding_num]) + + def remove_error_span(self, key, fe_head_list, fe_tail_list, fe_type_list, target_id, span_maxlen): + indices = [] + for index in range(len(fe_head_list)): + if fe_tail_list[index] - fe_head_list[index] >= span_maxlen: + indices.append(index) + elif fe_tail_list[index] < fe_head_list[index]: + indices.append(index) + + + elif (fe_type_list[index], target_id) not in self.fe_name_to_label: + indices.append(index) + + else: + for i in range(index): + if i not in indices: + if fe_head_list[index] >= fe_head_list[i] and fe_head_list[index] <= fe_tail_list[i]: + indices.append(index) + break + + elif fe_tail_list[index] >= fe_head_list[i] and 
fe_tail_list[index] <= fe_tail_list[i]:
+                        indices.append(index)
+                        break
+                    elif fe_head_list[index] <= fe_head_list[i] and fe_tail_list[index] >= fe_tail_list[i]:
+                        indices.append(index)
+                        break
+                    else:
+                        continue
+
+        fe_head_list_filter = [i for j, i in enumerate(fe_head_list) if j not in indices]
+        fe_tail_list_filter = [i for j, i in enumerate(fe_tail_list) if j not in indices]
+        fe_type_list_filter = [i for j, i in enumerate(fe_type_list) if j not in indices]
+        self.data_dic[key]['span_start'] = fe_head_list_filter
+        self.data_dic[key]['span_end'] = fe_tail_list_filter
+        self.data_dic[key]['span_type'] = fe_type_list_filter
+
+        return len(indices)
+
+    def build_target_mask(self,key,maxlen):
+        self.target_mask.setdefault(key[0], [0]*maxlen)
+
+        target_head = self.data_dic[key]['target_idx'][0]
+        target_tail = self.data_dic[key]['target_idx'][1]
+        self.target_mask[key[0]][target_head] = 1
+        self.target_mask[key[0]][target_tail] = 1
+
+
+
+
+
+def build_token_type_ids(target_head, target_tail, maxlen):
+    token_type_ids = [0]*maxlen
+    token_type_ids[target_head] = 1
+    token_type_ids[target_tail] = 1
+    # token_type_ids[target_head:target_tail+1] = [1]*(target_tail+1-target_head)
+
+    return token_type_ids
+
+
+if __name__ == '__main__':
+    opt = get_opt()
+    config = DataConfig(opt)
+    if torch.cuda.is_available():
+        device = torch.device(opt.cuda)
+    else:
+        device = torch.device('cpu')
+    dataset = FrameNetDataset(opt, config, config.test_instance_dic, device)
+    print(dataset.error_span)
diff --git a/fn-model_bert/model_syntax36_with_bert.py b/fn-model_bert/model_syntax36_with_bert.py
new file mode 100644
index 0000000..fc8ed07
--- /dev/null
+++ b/fn-model_bert/model_syntax36_with_bert.py
@@ -0,0 +1,530 @@
+import numpy as np
+from typing import List,Tuple
+import os
+import json
+
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+
+from utils import batched_index_select,get_mask_from_index,generate_perm_inv
+from pytorch_pretrained_bert import BertTokenizer,BertModel,BertConfig,BertAdam
+from config_bert import PLMConfig,get_opt
+
+class Mlp(nn.Module):
+    def __init__(self, input_size, output_size):
+        super(Mlp, self).__init__()
+        self.linear = nn.Sequential(
+            nn.Linear(input_size, input_size),
+            nn.Dropout(0.4),
+            nn.ReLU(inplace=True),
+            nn.Linear(input_size, output_size),
+        )
+
+    def forward(self, x):
+        out = self.linear(x)
+        return out
+
+
+class Relu_Linear(nn.Module):
+    def __init__(self, input_size, output_size):
+        super(Relu_Linear, self).__init__()
+        self.linear = nn.Sequential(
+            nn.Linear(input_size, output_size),
+            nn.Dropout(0.4),
+            nn.ReLU(inplace=True),
+        )
+
+    def forward(self, x):
+        out = self.linear(x)
+        return out
+
+
+class CnnNet(nn.Module):
+    def __init__(self, kernel_size, seq_length, input_size, output_size):
+        super(CnnNet, self).__init__()
+        self.seq_length = seq_length
+        self.output_size = output_size
+        self.kernel_size = kernel_size
+
+        self.relu = nn.ReLU()
+        self.conv = nn.Conv1d(in_channels=input_size, out_channels=output_size, kernel_size=self.kernel_size
+                              , padding=1)
+        self.mp = nn.MaxPool1d(kernel_size=self.seq_length)
+
+    def forward(self, input_emb):
+        input_emb = input_emb.permute(0, 2, 1)
+        x = self.conv(input_emb)
+        output = self.mp(x).squeeze()
+
+        return output
+
+
+class PointerNet(nn.Module):
+    def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'):
+        super(PointerNet, self).__init__()
+
+        assert attention_type in ('affine', 'dot_prod')
+        if attention_type == 'affine':
+            self.src_encoding_linear =
Mlp(src_encoding_size, query_vec_size) + + self.src_linear = Mlp(src_encoding_size,src_encoding_size) + self.activate = nn.ReLU(inplace=True) + self.dropout = nn.Dropout(0.5) + + self.fc =nn.Linear(src_encoding_size*2,src_encoding_size, bias=True) + + self.attention_type = attention_type + + def forward(self, src_encodings, src_token_mask,query_vec,head_vec=None): + + # (batch_size, 1, src_sent_len, query_vec_size) + if self.attention_type == 'affine': + src_encod = self.src_encoding_linear(src_encodings).unsqueeze(1) + head_weights = self.src_linear(src_encodings).unsqueeze(1) + + # (batch_size, tgt_action_num, query_vec_size, 1) + if head_vec is not None: + src_encod = torch.cat([src_encod,head_weights],dim = -1) + q = torch.cat([head_vec, query_vec], dim=-1).permute(1, 0, 2).unsqueeze(3) + + + else: + q = query_vec.permute(1, 0, 2).unsqueeze(3) + + weights = torch.matmul(src_encod, q).squeeze(3) + ptr_weights = weights.permute(1, 0, 2) + + # if head_vec is not None: + # src_weights = torch.matmul(head_weights, q_h).squeeze(3) + # src_weights = src_weights.permute(1, 0, 2) + # ptr_weights = weights+src_weights + # + # else: + # ptr_weights = weights + + ptr_weights_masked = ptr_weights.clone().detach() + if src_token_mask is not None: + # (tgt_action_num, batch_size, src_sent_len) + src_token_mask=1-src_token_mask.byte() + src_token_mask = src_token_mask.unsqueeze(0).expand_as(ptr_weights) + # ptr_weights.data.masked_fill_(src_token_mask, -float('inf')) + ptr_weights_masked.data.masked_fill_(src_token_mask, -float('inf')) + + # ptr_weights =self.activate(ptr_weights) + + return ptr_weights,ptr_weights_masked + + +class Encoder(nn.Module): + def __init__(self, opt, config, bert_frozen=False): + super(Encoder, self).__init__() + self.opt =opt + self.hidden_size = opt.rnn_hidden_size + self.emb_size = opt.encoder_emb_size + self.rnn_input_size = self.emb_size+opt.token_type_emb_size + # self.word_number = config.word_number + # self.lemma_number = config.lemma_number + self.maxlen = opt.maxlen + + bert_config = BertConfig.from_json_file(PLMConfig.CONFIG_PATH) + self.bert = BertModel.from_pretrained(PLMConfig.MODEL_PATH) + + if bert_frozen: + print('Bert grad is false') + for param in self.bert.parameters(): + param.requires_grad = False + + self.bert_hidden_size = bert_config.hidden_size + # self.bert_dropout = nn.Dropout(bert_config.hidden_dropout_prob) + + + self.dropout = 0.2 + # self.word_embedding = word_embedding + # self.lemma_embedding = lemma_embedding + # self.pos_embedding = nn.Embedding(config.pos_number, opt.pos_emb_size) + # self.rel_embedding = nn.Embedding(config.rel_number,opt.rel_emb_size) + self.token_type_embedding = nn.Embedding(3, opt.token_type_emb_size) + self.cell_name = opt.cell_name + + # self.embedded_linear = nn.Linear(self.emb_size*2+opt.pos_emb_size+opt.token_type_emb_size+opt.sent_emb_size, + # self.rnn_input_size) + # self.syntax_embedded_linear = nn.Linear(self.emb_size*2+opt.rel_emb_size+opt.token_type_emb_size, + # self.rnn_input_size) + # self.output_combine_linear = nn.Linear(4*self.hidden_size, 2*self.hidden_size) + + self.target_linear = nn.Linear(2*self.hidden_size, 2*self.hidden_size) + + self.relu_linear = Relu_Linear(4*self.hidden_size+self.rnn_input_size, opt.decoder_emb_size) + + if self.cell_name == 'gru': + self.rnn = nn.GRU(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + dropout=self.dropout,bidirectional=True, batch_first=True) + # self.syntax_rnn = nn.GRU(self.rnn_input_size, 
self.hidden_size,num_layers=self.opt.num_layers, + # dropout=self.dropout,bidirectional=True, batch_first=True) + elif self.cell_name == 'lstm': + self.rnn = nn.LSTM(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + dropout=self.dropout,bidirectional=True, batch_first=True) + # self.syntax_rnn = nn.LSTM(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + # dropout=self.dropout,bidirectional=True, batch_first=True) + else: + print('cell_name error') + + def forward(self, word_input: torch.Tensor, lengths:torch.Tensor, frame_idx, token_type_ids=None, attention_mask=None,target_mask_ids=None): + + # word_embedded = self.word_embedding(word_input) + # lemma_embedded = self.lemma_embedding(lemma_input) + # pos_embedded = self.pos_embedding(pos_input) + # head_embedded = self.word_embedding(head_input) + # rel_embedded = self.rel_embedding(rel_input) + # type_ids =torch.add(token_type_ids, target_mask_ids) + token_type_embedded = self.token_type_embedding(token_type_ids) + #print(token_type_embedded) + # print(token_type_ids.size()) + # print(target_mask_ids.size()) + # print(token_type_embedded.size()) + hidden_state, cls = self.bert(word_input, token_type_ids, \ + attention_mask=attention_mask,\ + output_all_encoded_layers=False) + + embedded = torch.cat([hidden_state.squeeze(),token_type_embedded], dim=-1) + # sent_embedded = self.cnn(embedded) + + # sent_embedded = sent_embedded.expand([self.opt.maxlen, self.opt.batch_size, self.opt.sent_emb_size]).permute(1, 0, 2) + #embedded = torch.cat([embedded,sent_embedded], dim=-1) + #embedded = self.embedded_linear(embedded) + + # syntax embedding + # syntax_embedded = torch.cat([word_embedded,head_embedded,rel_embedded,token_type_ids],dim=-1) + # syntax_embedded = self.syntax_embedded_linear(syntax_embedded) + + lengths=lengths.squeeze() + # sorted before pack + l = lengths.cpu().numpy() + perm_idx = np.argsort(-l) + perm_idx_inv = generate_perm_inv(perm_idx) + + embedded = embedded[perm_idx] + # syntax_embedded = syntax_embedded[perm_idx] + + if lengths is not None: + rnn_input = nn.utils.rnn.pack_padded_sequence(embedded, lengths=lengths[perm_idx], + batch_first=True) + # syntax_rnn_input = nn.utils.rnn.pack_padded_sequence(syntax_embedded, lengths=lengths[perm_idx], + # batch_first=True) + output, hidden = self.rnn(rnn_input) + #syntax_output, syntax_hidden = self.rnn(syntax_rnn_input) + + if lengths is not None: + output, _ = nn.utils.rnn.pad_packed_sequence(output, total_length=self.maxlen, batch_first=True) + # syntax_output, _ = nn.utils.rnn.pad_packed_sequence(syntax_output, total_length=self.maxlen, batch_first=True) + + + # print(output.size()) + # print(hidden.size()) + + output = output[perm_idx_inv] + # syntax_output = syntax_output[perm_idx_inv] + + if self.cell_name == 'gru': + hidden = hidden[:, perm_idx_inv] + hidden = (lambda a: sum(a)/(2*self.opt.num_layers))(torch.split(hidden, 1, dim=0)) + + # syntax_hidden = syntax_hidden[:, perm_idx_inv] + # syntax_hidden = (lambda a: sum(a)/(2*self.opt.num_layers))(torch.split(syntax_hidden, 1, dim=0)) + # hidden = (hidden + syntax_hidden) / 2 + + elif self.cell_name == 'lstm': + hn0 = hidden[0][:, perm_idx_inv] + hn1 = hidden[1][:, perm_idx_inv] + # sy_hn0 = syntax_hidden[0][:, perm_idx_inv] + # sy_hn1 = syntax_hidden[1][:, perm_idx_inv] + hn = tuple([hn0,hn1]) + hidden = tuple(map(lambda state: sum(torch.split(state, 1, dim=0))/(2*self.opt.num_layers), hn)) + + + target_state_head = batched_index_select(target=output, indices=frame_idx[0]) + 
target_state_tail = batched_index_select(target=output, indices=frame_idx[1]) + target_state = (target_state_head + target_state_tail) / 2 + target_state = self.target_linear(target_state) + + target_emb_head = batched_index_select(target=embedded, indices=frame_idx[0]) + target_emb_tail = batched_index_select(target=embedded, indices=frame_idx[1]) + target_emb = (target_emb_head + target_emb_tail) / 2 + + attentional_target_state = type_attention(attention_mask=target_mask_ids, hidden_state=output, + target_state=target_state) + + + target_state =torch.cat([target_state.squeeze(),attentional_target_state.squeeze(), target_emb.squeeze()], dim=-1) + target = self.relu_linear(target_state) + + # print(output.size()) + return output, hidden, target + + +class Decoder(nn.Module): + def __init__(self, opt, embedding_frozen=False): + super(Decoder, self).__init__() + + # rnn _init_ + self.opt = opt + self.cell_name = opt.cell_name + self.emb_size = opt.decoder_emb_size + self.hidden_size = opt.decoder_hidden_size + self.encoder_hidden_size = opt.rnn_hidden_size + + # decoder _init_ + self.decodelen = opt.fe_padding_num+1 + self.frame_embedding = nn.Embedding(opt.frame_number+1, self.emb_size) + self.frame_fc_layer =Mlp(self.emb_size, opt.frame_number+1) + self.role_embedding = nn.Embedding(opt.role_number+1, self.emb_size) + self.role_feature_layer = nn.Linear(2*self.emb_size, self.emb_size) + self.role_fc_layer = nn.Linear(self.hidden_size+self.emb_size, opt.role_number+1) + + self.head_fc_layer = Mlp(self.hidden_size+self.emb_size, self.hidden_size) + self.tail_fc_layer = Mlp(self.hidden_size+self.emb_size, self.hidden_size) + + self.span_fc_layer = Mlp(4 * self.encoder_hidden_size + self.emb_size, self.emb_size) + + self.next_input_fc_layer = Mlp(self.hidden_size+self.emb_size, self.emb_size) + + if embedding_frozen is True: + for param in self.frame_embedding.parameters(): + param.requires_grad = False + for param in self.role_embedding.parameters(): + param.requires_grad = False + + if self.cell_name == 'gru': + self.frame_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) + self.ent_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) + self.role_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) + if self.cell_name == 'lstm': + self.frame_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) + self.ent_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) + self.role_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) + + # pointer _init_ + self.ent_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) + self.head_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) + self.tail_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) + + def forward(self, encoder_output: torch.Tensor, encoder_state: torch.Tensor, target_state: torch.Tensor, + attention_mask: torch.Tensor, fe_mask=None, lu_mask=None): + pred_frame_list = [] + pred_head_list = [] + pred_tail_list = [] + pred_role_list = [] + + pred_frame_action = [] + + pred_head_action = [] + pred_tail_action = [] + pred_role_action = [] + + frame_decoder_state = encoder_state + role_decoder_state = encoder_state + + input = target_state + # print(input.size()) + + span_mask = attention_mask.clone() + for t in range(self.decodelen): + # frame pred + output, frame_decoder_state = self.decode_step(self.frame_rnn, input=input, + 
decoder_state=frame_decoder_state) + + pred_frame_weight = self.frame_fc_layer(target_state.squeeze()) + pred_frame_weight_masked = pred_frame_weight.clone().detach() + + if lu_mask is not None: + LU_mask = 1-lu_mask + pred_frame_weight_masked.data.masked_fill_(LU_mask.byte(), -float('inf')) + + pred_frame_indices = torch.argmax(pred_frame_weight_masked.squeeze(), dim=-1).squeeze() + + pred_frame_list.append(pred_frame_weight) + pred_frame_action.append(pred_frame_indices) + + frame_emb = self.frame_embedding(pred_frame_indices) + + head_input = self.head_fc_layer(torch.cat([output.squeeze(), frame_emb], dim=-1)) + tail_input = self.tail_fc_layer(torch.cat([output.squeeze(), frame_emb], dim=-1)) + + head_pointer_weight, head_pointer_weight_masked = self.head_pointer(src_encodings=encoder_output, + src_token_mask=span_mask, + query_vec=head_input.view(1, self.opt.batch_size, -1)) + + head_indices = torch.argmax(head_pointer_weight_masked.squeeze(), dim=-1).squeeze() + head_target = batched_index_select(target=encoder_output, indices=head_indices.squeeze()) + head_mask = head_mask_update(span_mask, head_indices=head_indices, max_len=self.opt.maxlen) + + tail_pointer_weight, tail_pointer_weight_masked = self.tail_pointer(src_encodings=encoder_output, + src_token_mask=head_mask, + query_vec=tail_input.view(1, self.opt.batch_size,-1), + head_vec=head_target.view(1,self.opt.batch_size,-1)) + + tail_indices = torch.argmax(tail_pointer_weight_masked.squeeze(), dim=-1).squeeze() + # tail_target = batched_index_select(target=bert_hidden_state,indices=tail_indices.squeeze()) + + span_mask = span_mask_update(attention_mask=span_mask, head_indices=head_indices, + tail_indices=tail_indices, max_len=self.opt.maxlen) + + pred_head_list.append(head_pointer_weight) + pred_tail_list.append(tail_pointer_weight) + pred_head_action.append(head_indices) + pred_tail_action.append(tail_indices) + + # role pred + + # print(ent_target.size()) + # print(head_target.size()) + # print(tail_target.size()) + # print(bert_hidden_state.size()) + # print(tail_pointer_weight.size()) + # print(tail_indices.size()) + # print(output.size()) + + # next step + # head_target = batched_index_select(target=bert_hidden_state, indices=head_indices.squeeze()) + + tail_target = batched_index_select(target=encoder_output, indices=tail_indices.squeeze()) + + # head_context =local_attention(attention_mask=attention_mask, hidden_state=encoder_output, + # frame_idx=(head_indices, tail_indices), target_state=head_target, + # window_size=0, max_len=self.opt.maxlen) + # + # tail_context =local_attention(attention_mask=attention_mask, hidden_state=encoder_output, + # frame_idx=(head_indices, tail_indices), target_state=tail_target, + # window_size=0, max_len=self.opt.maxlen) + + span_input = self.span_fc_layer(torch.cat([head_target+tail_target, head_target-tail_target, frame_emb], dim=-1)).unsqueeze(1) + + output,role_decoder_state = self.decode_step(self.role_rnn, input=span_input, + decoder_state=role_decoder_state) + role_target = self.role_fc_layer(torch.cat([span_input,output],dim=-1)) + role_target_masked = role_target.squeeze().clone().detach() + if fe_mask is not None : + FE_mask = 1-fe_mask + # print(FE_mask.size()) + role_target_masked.data.masked_fill_(FE_mask.byte(), -float('inf')) + + role_indices = torch.argmax(role_target_masked.squeeze(), dim=-1).squeeze() + role_emb = self.role_embedding(role_indices) + + pred_role_list.append(role_target) + pred_role_action.append(role_indices) + + # next step + next_input =torch.cat([output, 
+
+        return pred_frame_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action, \
+               pred_head_action, pred_tail_action, pred_role_action
+
+    def decode_step(self, rnn_cell: nn.Module, input: torch.Tensor, decoder_state: torch.Tensor):
+        output, state = rnn_cell(input.view(-1, 1, self.emb_size), decoder_state)
+        return output, state
+
+
+class Model(nn.Module):
+    def __init__(self, opt, config):
+        super(Model, self).__init__()
+        self.encoder = Encoder(opt, config)
+        self.decoder = Decoder(opt)
+
+    def forward(self, word_ids, lengths, frame_idx, fe_mask=None, lu_mask=None,
+                frame_len=None, token_type_ids=None, attention_mask=None, target_mask_ids=None):
+        encoder_output, encoder_state, target_state = self.encoder(word_input=word_ids, lengths=lengths,
+                                                                   frame_idx=frame_idx,
+                                                                   token_type_ids=token_type_ids,
+                                                                   attention_mask=attention_mask,
+                                                                   target_mask_ids=target_mask_ids)
+
+        pred_frame_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action, \
+        pred_head_action, pred_tail_action, pred_role_action = self.decoder(encoder_output=encoder_output,
+                                                                            encoder_state=encoder_state,
+                                                                            target_state=target_state,
+                                                                            attention_mask=attention_mask,
+                                                                            fe_mask=fe_mask, lu_mask=lu_mask)
+
+        return {
+            'pred_frame_list': pred_frame_list,
+            'pred_head_list': pred_head_list,
+            'pred_tail_list': pred_tail_list,
+            'pred_role_list': pred_role_list,
+            'pred_frame_action': pred_frame_action,
+            'pred_head_action': pred_head_action,
+            'pred_tail_action': pred_tail_action,
+            'pred_role_action': pred_role_action
+        }
+
+
+def head_mask_update(attention_mask: torch.Tensor, head_indices: torch.Tensor, max_len):
+    # drop the positions covered by the head-index prefix mask, so the tail
+    # pointer can only select positions from the chosen head onwards
+    indices_mask = 1 - get_mask_from_index(head_indices, max_len)
+    mask = torch.mul(attention_mask, indices_mask.long())
+    return mask
+
+
+def span_mask_update(attention_mask: torch.Tensor, head_indices: torch.Tensor, tail_indices: torch.Tensor, max_len):
+    # the difference of two prefix masks marks [head, tail]; zero that span
+    # out of the running mask so it cannot be selected twice
+    tail = tail_indices + 1
+    head_indices_mask = get_mask_from_index(head_indices, max_len)
+    tail_indices_mask = get_mask_from_index(tail, max_len)
+    span_indices_mask = tail_indices_mask - head_indices_mask
+    span_indices_mask = 1 - span_indices_mask
+    mask = torch.mul(attention_mask, span_indices_mask.long())
+    return mask
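A quick way to sanity-check the mask arithmetic: if get_mask_from_index produces a prefix mask (1 at positions below the index, an assumption here, since utils.py is not part of this diff), then subtracting two prefix masks carves out exactly the closed span. A toy version with a local stand-in:

    import torch

    def prefix_mask(indices: torch.Tensor, max_len: int) -> torch.Tensor:
        # 1 at positions < index, 0 elsewhere; stand-in for utils.get_mask_from_index
        positions = torch.arange(max_len).unsqueeze(0)        # (1, max_len)
        return (positions < indices.unsqueeze(1)).long()      # (batch, max_len)

    attention_mask = torch.ones(1, 10, dtype=torch.long)
    head, tail = torch.tensor([3]), torch.tensor([5])

    span = prefix_mask(tail + 1, 10) - prefix_mask(head, 10)  # 1 on [3, 5]
    blocked = attention_mask * (1 - span)                     # span removed for later steps
    print(blocked)  # tensor([[1, 1, 1, 0, 0, 0, 1, 1, 1, 1]])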
+
+
+def local_attention(attention_mask: torch.Tensor, hidden_state: torch.Tensor, frame_idx,
+                    target_state: torch.Tensor, window_size: int, max_len):
+    # attend only within a window of window_size tokens around the target span
+    q = target_state.squeeze().unsqueeze(2)
+    context_att = torch.bmm(hidden_state, q).squeeze()
+    head = frame_idx[0] - window_size
+    tail = frame_idx[1] + window_size
+    mask = span_mask_update(attention_mask=attention_mask, head_indices=head.squeeze(),
+                            tail_indices=tail.squeeze(), max_len=max_len)
+    context_att = context_att.masked_fill_(mask.byte(), -float('inf'))
+    context_att = F.softmax(context_att, dim=-1)
+    attentional_hidden_state = torch.bmm(hidden_state.permute(0, 2, 1), context_att.unsqueeze(2)).squeeze()
+
+    return attentional_hidden_state
+
+
+def type_attention(attention_mask: torch.Tensor, hidden_state: torch.Tensor,
+                   target_state: torch.Tensor):
+    # dot-product attention over the whole sentence, with masked positions excluded
+    q = target_state.squeeze().unsqueeze(2)
+    context_att = torch.bmm(hidden_state, q).squeeze()
+
+    mask = 1 - attention_mask
+
+    context_att = context_att.masked_fill_(mask.byte(), -float('inf'))
+    context_att = F.softmax(context_att, dim=-1)
+    attentional_hidden_state = torch.bmm(hidden_state.permute(0, 2, 1), context_att.unsqueeze(2)).squeeze()
+
+    return attentional_hidden_state
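Both helpers are the same masked dot-product pooling; only the mask differs (a window around the target versus the padding mask). A self-contained check of the pattern, with illustrative shapes:

    import torch
    import torch.nn.functional as F

    batch, seq, dim = 2, 6, 4
    hidden = torch.randn(batch, seq, dim)
    query = torch.randn(batch, dim)
    pad_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                             [1, 1, 1, 0, 0, 0]])  # 1 = real token

    scores = torch.bmm(hidden, query.unsqueeze(2)).squeeze(2)   # (batch, seq)
    scores = scores.masked_fill(pad_mask == 0, -float('inf'))   # padding never attended
    weights = F.softmax(scores, dim=-1)
    context = torch.bmm(hidden.transpose(1, 2), weights.unsqueeze(2)).squeeze(2)
    print(context.shape)  # torch.Size([2, 4])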
diff --git a/fn-model_bert/train_syntax36_bert.py b/fn-model_bert/train_syntax36_bert.py
new file mode 100644
index 0000000..96908c5
--- /dev/null
+++ b/fn-model_bert/train_syntax36_bert.py
@@ -0,0 +1,256 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import os
+import multiprocessing as mp
+
+from dataset_bert import FrameNetDataset
+from torch.utils.data import DataLoader
+from utils import get_mask_from_index, seed_everything
+from evaluate10 import Eval
+from config_bert import get_opt
+from data_process_bert import DataConfig
+from model_syntax36_with_bert import Model
+
+from pytorch_pretrained_bert import BertTokenizer, BertModel, BertConfig, BertAdam
+
+
+def evaluate(opt, model, dataset, best_metrics=None, show_case=False):
+    model.eval()
+    print('begin eval')
+    evaler = Eval(opt)
+    with torch.no_grad():
+        test_dl = DataLoader(
+            dataset,
+            batch_size=opt.batch_size,
+            shuffle=True,
+            num_workers=0
+        )
+
+        for batch in test_dl:
+            word_ids, lengths, attention_mask, target_head, target_tail, \
+            target_type, fe_head, fe_tail, fe_type, fe_cnt, \
+            fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids, sent_length, target_mask_ids = batch
+
+            return_dic = model(word_ids=word_ids, lengths=lengths,
+                               frame_idx=(target_head, target_tail),
+                               token_type_ids=token_type_ids,
+                               attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask,
+                               target_mask_ids=target_mask_ids)
+
+            evaler.metrics(batch_size=opt.batch_size, fe_cnt=fe_cnt, gold_fe_type=fe_type, gold_fe_head=fe_head,
+                           gold_fe_tail=fe_tail, gold_frame_type=target_type,
+                           pred_fe_type=return_dic['pred_role_action'],
+                           pred_fe_head=return_dic['pred_head_action'],
+                           pred_fe_tail=return_dic['pred_tail_action'],
+                           pred_frame_type=return_dic['pred_frame_action'],
+                           fe_coretype=dataset.fe_coretype_table, sent_length=sent_length,
+                           lu_mask=lu_mask)
+
+            if show_case:
+                print('gold_fe_label = ', fe_type)
+                print('pred_fe_label = ', return_dic['pred_role_action'])
+                print('gold_head_label = ', fe_head)
+                print('pred_head_label = ', return_dic['pred_head_action'])
+                print('gold_tail_label = ', fe_tail)
+                print('pred_tail_label = ', return_dic['pred_tail_action'])
+
+    metrics = evaler.calculate()
+
+    # only track/save the best checkpoint when a baseline score was passed in
+    if best_metrics is not None:
+        if metrics[-1] > best_metrics:
+            best_metrics = metrics[-1]
+            torch.save(model.state_dict(), opt.save_model_path)
+
+    return best_metrics
+
+
+if __name__ == '__main__':
+
+    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+
+    mp.set_start_method('spawn')
+
+    opt = get_opt()
+    config = DataConfig(opt)
+
+    if torch.cuda.is_available():
+        device = torch.device(opt.cuda)
+    else:
+        device = torch.device('cpu')
+
+    seed_everything(1116)
+
+    epochs = opt.epochs
+    model = Model(opt, config)
+    model.to(device)
+
+    clip_grad = 1.0
+    # warmup proportion: the fraction of training steps over which the lr ramps up
+    warmup_proportion = 0.2
+    # learning rate
+    learning_rate = opt.lr
+
+    best_metrics = 0.0
+
+    pretrain_dataset = FrameNetDataset(opt, config, config.exemplar_instance_dic, device)
+    train_dataset = FrameNetDataset(opt, config, config.train_instance_dic, device)
+    dev_dataset = FrameNetDataset(opt, config, config.dev_instance_dic, device)
+    test_dataset = FrameNetDataset(opt, config, config.test_instance_dic, device)
+
+    num_train_steps = int((len(train_dataset) / opt.batch_size) * opt.epochs)
+    # prepare optimizer
+    param_optimizer = list(model.named_parameters())
+    # bias and LayerNorm parameters get no weight decay, otherwise LayerNorm may misbehave
+    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
+        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}]
+
+    optimizer = BertAdam(optimizer_grouped_parameters,
+                         lr=learning_rate,
+                         warmup=warmup_proportion,
+                         t_total=num_train_steps,
+                         max_grad_norm=clip_grad)
+
+    frame_criterion = nn.CrossEntropyLoss()
+    head_criterion = nn.CrossEntropyLoss()
+    tail_criterion = nn.CrossEntropyLoss()
+    fe_type_criterion = nn.CrossEntropyLoss()
+
+    if os.path.exists(opt.save_model_path):
+        model.load_state_dict(torch.load(opt.save_model_path))
+
+    best_metrics = -1
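The no_decay split is the standard BERT recipe: decay the weight matrices, skip biases and LayerNorm parameters. If pytorch_pretrained_bert is unavailable, the same grouping works with torch.optim.AdamW; a minimal sketch under that substitution (the tiny module is fabricated for illustration, and the warmup schedule would then need a separate lr scheduler):

    import torch
    import torch.nn as nn

    class TinyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(10, 10)
            self.LayerNorm = nn.LayerNorm(10)

    model = TinyEncoder()
    no_decay = ("bias", "LayerNorm.weight")
    groups = [
        {"params": [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
        {"params": [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(groups, lr=6e-5)
    print([len(g["params"]) for g in groups])  # [1, 3]: only proj.weight is decayed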
+    if opt.mode == 'train':
+        for epoch in range(1, epochs):
+            train_dl = DataLoader(
+                train_dataset,
+                batch_size=opt.batch_size,
+                shuffle=True,
+                num_workers=0
+            )
+            model.train()
+            print('==========epochs= ' + str(epoch))
+            step = 0
+            sum_loss = 0
+            cnt = 0
+            for batch in train_dl:
+                optimizer.zero_grad()
+                loss = 0
+                word_ids, lengths, attention_mask, target_head, target_tail, \
+                target_type, fe_head, fe_tail, fe_type, fe_cnt, \
+                fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids, sent_length, target_mask_ids = batch
+
+                return_dic = model(word_ids=word_ids, lengths=lengths,
+                                   frame_idx=(target_head, target_tail),
+                                   token_type_ids=token_type_ids,
+                                   attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask,
+                                   target_mask_ids=target_mask_ids)
+
+                frame_loss = 0
+                head_loss = 0
+                tail_loss = 0
+                type_loss = 0
+
+                for batch_index in range(opt.batch_size):
+                    # supervise the frame with the prediction from the first decode step
+                    pred_frame_first = return_dic['pred_frame_list'][fe_cnt[batch_index]][batch_index].unsqueeze(0)
+                    pred_frame_last = return_dic['pred_frame_list'][0][batch_index].unsqueeze(0)
+                    pred_frame_label = pred_frame_last
+
+                    gold_frame_label = target_type[batch_index]
+                    frame_loss += frame_criterion(pred_frame_label, gold_frame_label)
+
+                    for fe_index in range(opt.fe_padding_num):
+                        pred_type_label = return_dic['pred_role_list'][fe_index].squeeze()
+                        pred_type_label = pred_type_label[batch_index].unsqueeze(0)
+
+                        gold_type_label = fe_type[batch_index][fe_index].unsqueeze(0)
+                        type_loss += fe_type_criterion(pred_type_label, gold_type_label)
+
+                        # head/tail are only supervised for real (non-padding) frame elements
+                        if fe_index >= fe_cnt[batch_index]:
+                            break
+
+                        pred_head_label = return_dic['pred_head_list'][fe_index].squeeze()
+                        pred_head_label = pred_head_label[batch_index].unsqueeze(0)
+
+                        gold_head_label = fe_head[batch_index][fe_index].unsqueeze(0)
+                        head_loss += head_criterion(pred_head_label, gold_head_label)
+
+                        pred_tail_label = return_dic['pred_tail_list'][fe_index].squeeze()
+                        pred_tail_label = pred_tail_label[batch_index].unsqueeze(0)
+
+                        gold_tail_label = fe_tail[batch_index][fe_index].unsqueeze(0)
+                        tail_loss += tail_criterion(pred_tail_label, gold_tail_label)
+
+                loss = (0.1 * frame_loss + 0.3 * type_loss + 0.3 * head_loss + 0.3 * tail_loss) / opt.batch_size
+                loss.backward()
+                optimizer.step()
+                step += 1
+                if step % 20 == 0:
+                    print(" | batch loss: %.6f step = %d" % (loss.item(), step))
+
+                for index in range(len(target_type)):
+                    if target_type[index] == return_dic['pred_frame_action'][0][index]:
+                        cnt += 1
+                sum_loss += loss.item()
+
+            print('| epoch %d avg loss = %.6f' % (epoch, sum_loss / step))
+            print('| epoch %d prec = %.6f' % (epoch, cnt / (opt.batch_size * step)))
+
+            best_metrics = evaluate(opt, model, dev_dataset, best_metrics)
+
+    else:
+        evaluate(opt, model, test_dataset, show_case=True)
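The 0.1/0.3/0.3/0.3 mix is a fixed-weight multi-task objective: one backward pass over a weighted sum keeps frame, role, head, and tail supervision in a single graph. The pattern in isolation (weights and shapes are illustrative):

    import torch
    import torch.nn as nn

    ce = nn.CrossEntropyLoss()
    frame_logits = torch.randn(2, 7, requires_grad=True)
    head_logits = torch.randn(2, 10, requires_grad=True)
    frame_gold, head_gold = torch.tensor([3, 1]), torch.tensor([2, 5])

    loss = 0.1 * ce(frame_logits, frame_gold) + 0.3 * ce(head_logits, head_gold)
    loss.backward()  # gradients flow to every task's logits at once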
diff --git a/graph_construct.py b/graph_construct.py
new file mode 100644
index 0000000..08bf409
--- /dev/null
+++ b/graph_construct.py
@@ -0,0 +1,51 @@
+import torch
+import dgl
+
+from data_process4 import get_fe_tabel, get_frame_tabel, load_data_pd
+from relation_extraction import parserelationFiles
+
+
+def graph_construct(fr_path='', fr_file='frame.csv', fe_path='', fe_file='FE.csv'):
+    frame_id_to_label, frame_name_to_label, frame_name_to_id = get_frame_tabel(path=fr_path, file=fr_file)
+    fe_id_to_label, fe_name_to_label, fe_name_to_id, fe_id_to_type = get_fe_tabel(path=fe_path, file=fe_file)
+    fr_sub, fr_sup, fe_sub, fe_sup, fr_label, fe_label, rel_name, fe_sub_to_frid, fe_sup_to_frid = parserelationFiles()
+
+    data_dic = {}
+
+    # frame-to-frame edges, one edge type per FrameNet relation
+    for item in zip(fr_sup, fr_sub, fr_label):
+        data_dic.setdefault(('frame', rel_name[item[2]], 'frame'), [[], []])
+        data_dic[('frame', rel_name[item[2]], 'frame')][0].append(frame_id_to_label[item[0]])
+        data_dic[('frame', rel_name[item[2]], 'frame')][1].append(frame_id_to_label[item[1]])
+
+    # fe-to-fe edges, added in both directions
+    data_dic.setdefault(('fe', 'fe_to_fe', 'fe'), [[], []])
+    for item in zip(fe_sup, fe_sub, fe_label):
+        data_dic[('fe', 'fe_to_fe', 'fe')][0].append(fe_id_to_label[item[0]])
+        data_dic[('fe', 'fe_to_fe', 'fe')][1].append(fe_id_to_label[item[1]])
+        data_dic[('fe', 'fe_to_fe', 'fe')][0].append(fe_id_to_label[item[1]])
+        data_dic[('fe', 'fe_to_fe', 'fe')][1].append(fe_id_to_label[item[0]])
+
+    # frame-to-fe edges (and the reverse direction), read from FE.csv
+    fe_dt = load_data_pd(dataset_path=fe_path, file=fe_file)
+    data_dic[('frame', 'fr_to_fe', 'fe')] = [[], []]
+    data_dic[('fe', 'fe_to_fr', 'frame')] = [[], []]
+    for idx in range(len(fe_dt['FrameID'])):
+        data_dic[('frame', 'fr_to_fe', 'fe')][0].append(frame_id_to_label[fe_dt['FrameID'][idx]])
+        data_dic[('frame', 'fr_to_fe', 'fe')][1].append(fe_id_to_label[fe_dt['ID'][idx]])
+
+        data_dic[('fe', 'fe_to_fr', 'frame')][0].append(fe_id_to_label[fe_dt['ID'][idx]])
+        data_dic[('fe', 'fe_to_fr', 'frame')][1].append(frame_id_to_label[fe_dt['FrameID'][idx]])
+
+    for key in data_dic.keys():
+        data_dic[key] = (torch.Tensor(data_dic[key][0]).long(), torch.Tensor(data_dic[key][1]).long())
+
+    g = dgl.heterograph(data_dic)
+
+    return g
+
+
+if __name__ == '__main__':
+    g = graph_construct()
+    print(g)
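For reference, dgl.heterograph takes exactly this shape of input: a dict mapping (src_type, edge_type, dst_type) triples to pairs of (src_ids, dst_ids) tensors. A toy graph in the same style (node counts and edges are made up for illustration):

    import torch
    import dgl

    data = {
        ('frame', 'Inheritance', 'frame'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
        ('frame', 'fr_to_fe', 'fe'): (torch.tensor([0, 0]), torch.tensor([0, 1])),
        ('fe', 'fe_to_fr', 'frame'): (torch.tensor([0, 1]), torch.tensor([0, 0])),
    }
    g = dgl.heterograph(data)
    print(g.ntypes, g.etypes)       # the node and edge types above
    print(g.num_nodes('frame'))     # 3, inferred from the largest frame id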
diff --git a/model_syntax2.py b/model_syntax36_final.py
similarity index 89%
rename from model_syntax2.py
rename to model_syntax36_final.py
index c0bd5f5..ee5110d 100644
--- a/model_syntax2.py
+++ b/model_syntax36_final.py
@@ -122,14 +122,13 @@ def __init__(self, opt, config, word_embedding:nn.modules.sparse.Embedding,
         self.opt =opt
         self.hidden_size = opt.rnn_hidden_size
         self.emb_size = opt.encoder_emb_size
-        self.rnn_input_size = self.emb_size*3+opt.pos_emb_size+opt.token_type_emb_size+\
-                              opt.sent_emb_size+opt.rel_emb_size
+        self.rnn_input_size = self.emb_size*2+opt.pos_emb_size+opt.token_type_emb_size
         self.word_number = config.word_number
         self.lemma_number = config.lemma_number
         self.maxlen = opt.maxlen
         self.cnn = CnnNet(opt.kernel_size, seq_length=opt.maxlen,
-                          input_size=self.emb_size*3+opt.pos_emb_size+opt.token_type_emb_size+opt.rel_emb_size,
+                          input_size=self.emb_size*2+opt.pos_emb_size+opt.token_type_emb_size,
                           output_size=opt.sent_emb_size)
         self.dropout = 0.2
@@ -165,20 +164,24 @@ def forward(self, word_input: torch.Tensor, lemma_input: torch.Tensor, pos_input
                 head_input:torch.Tensor, rel_input:torch.Tensor,
-                lengths:torch.Tensor, frame_idx, token_type_ids=None, attention_mask=None):
+                lengths:torch.Tensor, frame_idx, token_type_ids=None, attention_mask=None, target_mask_ids=None):
         word_embedded = self.word_embedding(word_input)
         lemma_embedded = self.lemma_embedding(lemma_input)
         pos_embedded = self.pos_embedding(pos_input)
-        head_embedded = self.word_embedding(head_input)
-        rel_embedded = self.rel_embedding(rel_input)
-        token_type_ids = self.token_type_embedding(token_type_ids)
-
-        embedded = torch.cat([word_embedded, lemma_embedded, pos_embedded,head_embedded,rel_embedded,token_type_ids], dim=-1)
-        sent_embedded = self.cnn(embedded)
-
-        sent_embedded = sent_embedded.expand([self.opt.maxlen, self.opt.batch_size, self.opt.sent_emb_size]).permute(1, 0, 2)
-        embedded = torch.cat([embedded,sent_embedded], dim=-1)
+        # the syntax (head/rel) features and the sentence-level CNN feature are dropped in this variant
+        token_type_embedded = self.token_type_embedding(token_type_ids)
+        embedded = torch.cat([word_embedded, lemma_embedded, pos_embedded, token_type_embedded], dim=-1)
         #embedded = self.embedded_linear(embedded)

         # syntax embedding
@@ -239,13 +242,11 @@ def forward(self, word_input: torch.Tensor, lemma_input: torch.Tensor, pos_input
         target_emb_tail = batched_index_select(target=embedded, indices=frame_idx[1])
         target_emb = (target_emb_head + target_emb_tail) / 2

-        attentional_target_state = local_attention(attention_mask=attention_mask, hidden_state=output,
-                                                   frame_idx=frame_idx, target_state=target_state,
-                                                   window_size=self.opt.window_size,
-                                                   max_len=self.opt.maxlen)
+        attentional_target_state = type_attention(attention_mask=target_mask_ids, hidden_state=output,
+                                                  target_state=target_state)

-        target_state =torch.cat([target_state.squeeze(), attentional_target_state.squeeze(), target_emb.squeeze()], dim=-1)
+        target_state = torch.cat([target_state.squeeze(), attentional_target_state.squeeze(), target_emb.squeeze()], dim=-1)

         target = self.relu_linear(target_state)
@@ -391,7 +392,7 @@ def forward(self, encoder_output: torch.Tensor, encoder_state: torch.Tensor, tar
         output, role_decoder_state = self.decode_step(self.role_rnn, input=span_input,
                                                       decoder_state=role_decoder_state)
-        role_target = self.role_fc_layer(torch.cat([output,span_input],dim=-1))
+        role_target = self.role_fc_layer(torch.cat([span_input,output],dim=-1))
         role_target_masked = role_target.squeeze().clone().detach()
         if fe_mask is not None :
             FE_mask = 1-fe_mask
@@ -439,13 +440,14 @@ def load_pretrain_emb(self):
         self.lemma_embedding.weight.data.copy_(torch.from_numpy(self.lemma_vectors))

     def forward(self, word_ids, lemma_ids, pos_ids, head_ids, rel_ids, lengths, frame_idx, fe_mask=None, lu_mask=None,
-                frame_len=None, token_type_ids=None, attention_mask=None):
+                frame_len=None, token_type_ids=None, attention_mask=None, target_mask_ids=None):
         encoder_output, encoder_state, target_state = self.encoder(word_input=word_ids, lemma_input=lemma_ids,
                                                                    pos_input=pos_ids, head_input=head_ids,
                                                                    rel_input=rel_ids, lengths=lengths,
                                                                    frame_idx=frame_idx,
                                                                    token_type_ids=token_type_ids,
-                                                                   attention_mask=attention_mask)
+                                                                   attention_mask=attention_mask,
+                                                                   target_mask_ids=target_mask_ids)

         pred_frame_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action, \
         pred_head_action, pred_tail_action, pred_role_action = self.decoder(encoder_output=encoder_output,
@@ -503,5 +505,18 @@ def local_attention(attention_mask: torch.Tensor, hidden_state: torch.Tensor, fr
     return attentional_hidden_state


+def type_attention(attention_mask: torch.Tensor, hidden_state: torch.Tensor,
+                   target_state: torch.Tensor):
+
+    q = target_state.squeeze().unsqueeze(2)
+    context_att = torch.bmm(hidden_state, q).squeeze()
+
+    mask = 1-attention_mask
+
+    context_att = context_att.masked_fill_(mask.byte(), -float('inf'))
+    context_att = F.softmax(context_att, dim=-1)
+    attentional_hidden_state = torch.bmm(hidden_state.permute(0, 2, 1), context_att.unsqueeze(2)).squeeze()
+
+    return attentional_hidden_state
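The change in the @@ -239 hunk swaps the fixed window around the target span for attention gated by target_mask_ids, so the encoder pools over whichever positions the target mask marks as relevant. A toy call of the new helper (tensors are illustrative; this assumes the module's other imports resolve so it can be imported directly):

    import torch
    from model_syntax36_final import type_attention

    hidden = torch.randn(2, 8, 16)           # (batch, seq, hidden)
    target = torch.randn(2, 16)              # pooled target-span state
    target_mask = torch.zeros(2, 8, dtype=torch.long)
    target_mask[:, 2:5] = 1                  # attend only over positions 2..4

    ctx = type_attention(attention_mask=target_mask, hidden_state=hidden, target_state=target)
    print(ctx.shape)                         # torch.Size([2, 16])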
diff --git a/relation_extraction.py b/relation_extraction.py
new file mode 100644
index 0000000..fb19821
--- /dev/null
+++ b/relation_extraction.py
@@ -0,0 +1,56 @@
+try:
+    import xml.etree.cElementTree as ET
+except ImportError:
+    import xml.etree.ElementTree as ET
+import sys
+import os
+import pandas as pd
+
+
+def parserelationFiles(c=0, relation_path='', filename='frRelation.xml',
+                       tagPrefix="{http://framenet.icsi.berkeley.edu}"):
+    tree = ET.ElementTree(file=relation_path + filename)
+    root = tree.getroot()
+
+    fr_sub = []
+    fr_sup = []
+    fe_sub = []
+    fe_sup = []
+    fe_sub_to_frid = []
+    fe_sup_to_frid = []
+    fr_label = []
+    fe_label = []
+    rel_idx = -1
+    rel_name = []
+
+    for child in root.iter():
+        if child.tag == tagPrefix + 'frameRelationType':
+            rel_idx += 1
+            rel_name.append(child.attrib.get('name'))
+
+            for frame_child in child.iter():
+                if frame_child.tag == tagPrefix + 'frameRelation':
+                    fr_sub.append(int(frame_child.attrib.get('subID')))
+                    fr_sup.append(int(frame_child.attrib.get('supID')))
+                    fr_label.append(rel_idx)
+
+                    for fe_child in frame_child.iter():
+                        if fe_child.tag == tagPrefix + 'FERelation':
+                            fe_sub.append(int(fe_child.attrib.get('subID')))
+                            fe_sup.append(int(fe_child.attrib.get('supID')))
+                            fe_sub_to_frid.append(int(frame_child.attrib.get('subID')))
+                            fe_sup_to_frid.append(int(frame_child.attrib.get('supID')))
+                            fe_label.append(rel_idx)
+
+    return fr_sub, fr_sup, fe_sub, fe_sup, fr_label, fe_label, rel_name, fe_sub_to_frid, fe_sup_to_frid
\ No newline at end of file
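frRelation.xml puts every element in the FrameNet namespace, which is why each tag comparison needs the tagPrefix. The same parsing pattern against an inline document (this XML snippet is fabricated for illustration):

    import xml.etree.ElementTree as ET

    NS = "{http://framenet.icsi.berkeley.edu}"
    doc = """
    <frameRelations xmlns="http://framenet.icsi.berkeley.edu">
      <frameRelationType name="Inheritance">
        <frameRelation subID="2" supID="1"/>
      </frameRelationType>
    </frameRelations>
    """
    root = ET.fromstring(doc)
    for child in root.iter():
        if child.tag == NS + 'frameRelation':
            print(int(child.attrib['subID']), int(child.attrib['supID']))  # 2 1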
diff --git a/rgcn.py b/rgcn.py
new file mode 100644
index 0000000..9252339
--- /dev/null
+++ b/rgcn.py
@@ -0,0 +1,114 @@
+import dgl
+import dgl.nn as dglnn
+import numpy as np
+import torch
+
+import torch.nn as nn
+import torch.nn.functional as F
+from graph_construct import graph_construct
+
+
+class RGCN(nn.Module):
+    def __init__(self, in_feats, hid_feats, num_frame, num_fe, rel_names):
+        super().__init__()
+        self.conv1 = dglnn.HeteroGraphConv({rel: dglnn.GraphConv(in_feats, hid_feats)
+                                            for rel in rel_names}, aggregate='mean')
+
+        self.fr_linear = nn.Linear(num_frame, in_feats)
+        self.fe_linear = nn.Linear(num_fe, in_feats)
+
+        self.fr_fc = nn.Linear(hid_feats, num_frame)
+        self.fe_fc = nn.Linear(hid_feats, num_fe)
+
+    def forward(self, graph, inputs):
+        # project the one-hot node features down to in_feats
+        inputs['frame'] = self.fr_linear(inputs['frame'])
+        inputs['fe'] = self.fe_linear(inputs['fe'])
+        # two message-passing rounds sharing conv1's weights, plus a residual connection
+        h = self.conv1(graph, inputs)
+        h = {k: F.relu(v) for k, v in h.items()}
+        h = self.conv1(graph, h)
+        h['frame'] = h['frame'] + inputs['frame']
+        h['fe'] = h['fe'] + inputs['fe']
+        fr_h = self.fr_fc(h['frame'])
+        fe_h = self.fe_fc(h['fe'])
+
+        return fr_h, fe_h, h
+
+
+if __name__ == '__main__':
+
+    frame_graph = graph_construct()
+    # feature initialization
+    num_frame = 1019
+    num_fe = 9634
+    n_features = 200
+    save_model_path = './pretrain_model'
+    rel_names = ['Causative_of', 'Inchoative_of', 'Inheritance', 'Perspective_on', 'Precedes',
+                 'ReFraming_Mapping', 'See_also', 'Subframe', 'Using', 'fe_to_fe', 'fe_to_fr', 'fr_to_fe']
+
+    # one-hot features, with each node labeled by its own index: training becomes an
+    # identity-reconstruction task that shapes the hidden embeddings
+    frame_graph.nodes['frame'].data['feature'] = torch.eye(num_frame)
+    frame_graph.nodes['fe'].data['feature'] = torch.eye(num_fe)
+    frame_graph.nodes['frame'].data['label'] = torch.arange(0, num_frame, 1)
+    frame_graph.nodes['fe'].data['label'] = torch.arange(0, num_fe, 1)
+
+    model = RGCN(in_feats=n_features, hid_feats=200,
+                 num_frame=num_frame, num_fe=num_fe, rel_names=rel_names)
+
+    fr_feats = frame_graph.nodes['frame'].data['feature']
+    fe_feats = frame_graph.nodes['fe'].data['feature']
+    fr_labels = frame_graph.nodes['frame'].data['label']
+    fe_labels = frame_graph.nodes['fe'].data['label']
+
+    opt = torch.optim.Adam(model.parameters())
+
+    best_fr_train_acc = 0
+    best_fe_train_acc = 0
+    loss_list = []
+    train_fr_score_list = []
+    train_fe_score_list = []
+
+    for epoch in range(200):
+        model.train()
+        # feed the graph and the node features
+        fr_logits, fe_logits, hidden = model(frame_graph, {'frame': fr_feats, 'fe': fe_feats})
+        # compute the loss over both node types
+        loss = F.cross_entropy(fr_logits, fr_labels) + F.cross_entropy(fe_logits, fe_labels)
+        # frame predictions and accuracy
+        fr_pred = fr_logits.argmax(1)
+        fr_train_acc = (fr_pred == fr_labels).float().mean()
+        if best_fr_train_acc < fr_train_acc:
+            best_fr_train_acc = fr_train_acc
+        train_fr_score_list.append(fr_train_acc)
+
+        # fe predictions and accuracy
+        fe_pred = fe_logits.argmax(1)
+        fe_train_acc = (fe_pred == fe_labels).float().mean()
+        if best_fe_train_acc < fe_train_acc:
+            best_fe_train_acc = fe_train_acc
+        train_fe_score_list.append(fe_train_acc)
+
+        # backward pass and update
+        opt.zero_grad()
+        loss.backward()
+        opt.step()
+
+        loss_list.append(loss.item())
+        # report training progress
+        print('Loss %.4f, Train fr Acc %.4f (Best %.4f) Train fe Acc %.4f (Best %.4f)' % (
+            loss.item(),
+            fr_train_acc.item(),
+            best_fr_train_acc,
+            fe_train_acc.item(),
+            best_fe_train_acc
+        ))
+
+    torch.save(model.state_dict(), save_model_path)
+    torch.save(hidden['frame'], "./frTensor4.pt")
+    torch.save(hidden['fe'], "./feTensor4.pt")
\ No newline at end of file
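As far as the code shows, the point of this script is to pretrain structure-aware frame/FE embeddings: each node must classify its own identity from its message-passed representation, and the resulting hidden states are dumped to frTensor4.pt and feTensor4.pt. Consuming them downstream would plausibly look like this sketch (file names are from above; the embedding layer and freeze choice are assumptions):

    import torch
    import torch.nn as nn

    fr_emb = torch.load("./frTensor4.pt").detach()   # (1019, 200) pretrained frame embeddings
    frame_embedding = nn.Embedding.from_pretrained(fr_emb, freeze=False)
    print(frame_embedding.weight.shape)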
diff --git a/train_syntax.py b/train_syntax36_final.py
similarity index 90%
rename from train_syntax.py
rename to train_syntax36_final.py
index ed30f85..2a2a5d9 100644
--- a/train_syntax.py
+++ b/train_syntax36_final.py
@@ -5,13 +5,13 @@
 import os
 import multiprocessing as mp

-from dataset2 import FrameNetDataset
+from dataset5 import FrameNetDataset
 from torch.utils.data import DataLoader
 from utils import get_mask_from_index,seed_everything
-from evaluate2 import Eval
+from evaluate8 import Eval
 from config import get_opt
-from data_process2 import DataConfig
-from model_syntax2 import Model
+from data_process4 import DataConfig
+from model_syntax36_final import Model


 def evaluate(opt, model, dataset, best_metrics=None, show_case=False):
@@ -29,22 +29,21 @@ def evaluate(opt, model, dataset, best_metrics=None, show_case=False):
     for batch in test_dl:
         word_ids, lemma_ids, pos_ids, head_ids, rel_ids, lengths, attention_mask, target_head, target_tail, \
         target_type, fe_head, fe_tail, fe_type, fe_cnt, \
-        fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids = batch
+        fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids, sent_length, target_mask_ids = batch

         return_dic = model(word_ids=word_ids, lemma_ids=lemma_ids, pos_ids=pos_ids, head_ids=head_ids,
                            rel_ids=rel_ids, lengths=lengths,
                            frame_idx=(target_head, target_tail),
                            token_type_ids=token_type_ids,
-                           attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask)
-
+                           attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask, target_mask_ids=target_mask_ids)
         evaler.metrics(batch_size=opt.batch_size, fe_cnt=fe_cnt, gold_fe_type=fe_type, gold_fe_head=fe_head,
                        gold_fe_tail=fe_tail, gold_frame_type=target_type,
                        pred_fe_type=return_dic['pred_role_action'],
                        pred_fe_head=return_dic['pred_head_action'],
                        pred_fe_tail=return_dic['pred_tail_action'],
-                       pred_frame_type=return_dic['pred_frame_action'])
-
+                       pred_frame_type=return_dic['pred_frame_action'],
+                       fe_coretype=dataset.fe_coretype_table, sent_length=sent_length)

     if show_case:
         print('gold_fe_label = ', fe_type)
@@ -95,6 +94,8 @@ def evaluate(opt, model, dataset, best_metrics=None, show_case=False):
     model.to(device)

     optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
+    # decay the lr by a factor of 0.6 every 30 epochs
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.6)

     frame_criterion = nn.CrossEntropyLoss()
     head_criterion = nn.CrossEntropyLoss()
     tail_criterion = nn.CrossEntropyLoss()
@@ -103,13 +104,16 @@ def evaluate(opt, model, dataset, best_metrics=None, show_case=False):
     if os.path.exists(opt.save_model_path) is True:
         model.load_state_dict(torch.load(opt.save_model_path))

-    #pretrain_dataset = FrameNetDataset(opt, config, config.exemplar_instance_dic, device)
+    pretrain_dataset = FrameNetDataset(opt, config, config.exemplar_instance_dic, device)
     train_dataset = FrameNetDataset(opt, config, config.train_instance_dic, device)
     dev_dataset = FrameNetDataset(opt, config, config.dev_instance_dic, device)
     test_dataset = FrameNetDataset(opt, config, config.test_instance_dic, device)

+    # track the best dev score across the whole run (it used to be reset every epoch)
+    best_metrics = -1

     if opt.mode == 'train':
         for epoch in range(1, epochs):
+            scheduler.step()
             train_dl = DataLoader(
                 train_dataset,
                 batch_size=opt.batch_size,
@@ -120,20 +124,20 @@ def evaluate(opt, model, dataset, best_metrics=None, show_case=False):
             print('==========epochs= ' + str(epoch))
             step = 0
             sum_loss = 0
-            best_metrics = -1
+            # best_metrics = -1
             cnt = 0
             for batch in train_dl:
                 optimizer.zero_grad()
                 loss = 0
                 word_ids, lemma_ids, pos_ids, head_ids, rel_ids, lengths, attention_mask, target_head, target_tail, \
                 target_type, fe_head, fe_tail, fe_type, fe_cnt, \
-                fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids = batch
+                fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids, sent_length, target_mask_ids = batch

                 return_dic = model(word_ids=word_ids, lemma_ids=lemma_ids, pos_ids=pos_ids, head_ids=head_ids,
                                    rel_ids=rel_ids, lengths=lengths,
                                    frame_idx=(target_head, target_tail),
                                    token_type_ids=token_type_ids,
-                                   attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask)
+                                   attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask, target_mask_ids=target_mask_ids)

                 frame_loss = 0
@@ -156,17 +160,19 @@ def evaluate(opt, model, dataset, best_metrics=None, show_case=False):

                 for fe_index in range(opt.fe_padding_num):

+                    pred_type_label = return_dic['pred_role_list'][fe_index].squeeze()
                     pred_type_label = pred_type_label[batch_index].unsqueeze(0)

                     gold_type_label = fe_type[batch_index][fe_index].unsqueeze(0)
                     type_loss += fe_type_criterion(pred_type_label, gold_type_label)

+                    # pad slots still contribute to the type loss; head/tail stop here
                     if fe_index >= fe_cnt[batch_index]:
                         break

-                    # print(fe_cnt[batch_index])
-
                     pred_head_label = return_dic['pred_head_list'][fe_index].squeeze()
                     pred_head_label = pred_head_label[batch_index].unsqueeze(0)
@@ -182,7 +188,6 @@ def evaluate(opt, model, dataset, best_metrics=None, show_case=False):

                     tail_loss += tail_criterion(pred_tail_label, gold_tail_label)

-                    # print(fe_cnt[batch_index])
                     # head_loss /= int(fe_cnt[batch_index])
                     # tail_loss /= int(fe_cnt[batch_index])
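One caveat on the StepLR addition: calling scheduler.step() at the top of the epoch loop is the pre-1.1 PyTorch ordering, and current releases expect it after optimizer.step(). A minimal sketch of the now-recommended ordering (toy model and data, same scheduler hyperparameters as the diff):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.6)

    for epoch in range(100):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()  # after the optimizer step, once per epoch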