From e322a0ee30f5740155df17236cd7f2b58dfefd89 Mon Sep 17 00:00:00 2001
From: Dogblack <527742942@qq.com>
Date: Wed, 17 Feb 2021 22:15:07 +0800
Subject: [PATCH 1/9] Add files via upload

---
 config.py               |  13 +-
 data_process4.py        | 377 +++++++++++++++++++++++++++++
 dataset5.py             | 271 +++++++++++++++++++++
 model_syntax36_final.py | 522 ++++++++++++++++++++++++++++++++++++++++
 train_syntax36_final.py | 233 ++++++++++++++++++
 5 files changed, 1411 insertions(+), 5 deletions(-)
 create mode 100644 data_process4.py
 create mode 100644 dataset5.py
 create mode 100644 model_syntax36_final.py
 create mode 100644 train_syntax36_final.py

diff --git a/config.py b/config.py
index 74381c7..f97205b 100644
--- a/config.py
+++ b/config.py
@@ -17,6 +17,7 @@ def get_opt():
     # dataset paths
     parser.add_argument('--data_path', type=str, default='parsed-v1.5/')
     parser.add_argument('--emb_file_path', type=str, default='./glove.6B/glove.6B.200d.txt')
+    parser.add_argument('--exemplar_instance_path', type=str, default='./exemplar_instance_dic.npy')
     parser.add_argument('--train_instance_path', type=str, default='./train_instance_dic.npy')
     parser.add_argument('--dev_instance_path', type=str, default='./dev_instance_dic.npy')
     parser.add_argument('--test_instance_path', type=str, default='./test_instance_dic.npy')
@@ -25,21 +26,23 @@ def get_opt():
     parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
     parser.add_argument('--model_name', type=str, default='train_model')
     parser.add_argument('--pretrain_model', type=str, default='')
-    parser.add_argument('--save_model_path', type=str, default='./models_ft/model_syntax.bin')
+    parser.add_argument('--save_model_path', type=str, default='./models_ft/wo_jointarg.bin')
+
     # training settings
-    parser.add_argument('--lr', type=float, default='0.0001')
+    parser.add_argument('--lr', type=float, default=0.00006)
     parser.add_argument('--weight_decay', type=float, default=0.0001)
     parser.add_argument('--batch_size', type=int, default=2)
-    parser.add_argument('--epochs', type=int, default=300)
+    parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--save_model_freq', type=int, default=1)  # checkpoint save interval, in epochs
     parser.add_argument('--cuda', type=str, default="cuda:0")
-    parser.add_argument('--mode', type=str, default="train")
+    parser.add_argument('--mode', type=str, default="test")

     # model settings
     parser.add_argument('--maxlen',type=int, default=256)
     parser.add_argument('--dropout', type=float, default=0.5)
     parser.add_argument('--cell_name', type=str, default='lstm')
+    #change
     parser.add_argument('--rnn_hidden_size',type=int, default=256)
     parser.add_argument('--rnn_emb_size', type=int, default=400)
     parser.add_argument('--encoder_emb_size',type=int, default=200)
@@ -56,7 +59,7 @@ def get_opt():
     parser.add_argument('--role_number',type=int, default=9634)
     parser.add_argument('--fe_padding_num',type=int, default=5)
     parser.add_argument('--window_size', type=int, default=3)
-    parser.add_argument('--num_layers', type=int, default=4)
+    parser.add_argument('--num_layers', type=int, default=2)

     return parser.parse_args()

diff --git a/data_process4.py b/data_process4.py
new file mode 100644
index 0000000..dab1328
--- /dev/null
+++ b/data_process4.py
@@ -0,0 +1,377 @@
+import numpy as np
+import pandas as pd
+import torch
+
+from config import get_opt
+from utils import get_mask_from_index
+from nltk.parse.stanford import StanfordDependencyParser
+
+def load_data(path):
+    with open(path, 'r', encoding='utf-8') as f:
+        lines = f.readlines()
+    return lines
+
+
+def instance_process(lines, maxlen):
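+    # Build a dict keyed by (sentence id, frame name, running counter). Each
+    # entry stores the padded word/lemma/POS sequences, the Stanford
+    # dependency parse, the target (LU) span, and the frame-element spans
+    # decoded from the B-/O-/S- prefixed tags in the CoNLL lines.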
instance_dic = {} + + parser = StanfordDependencyParser(r".\stanford-parser-full-2015-12-09\stanford-parser.jar", + r".\stanford-parser-full-2015-12-09\stanford-parser-3.6.0-models.jar" + ) + cnt = 0 + find = False + word_list_total = [] + for line in lines: + if line[0:3] == '# i': + word_list = [] + lemma_list = [] + pos_list = [] + target_idx = [-1, -1] + span_start = [] + span_end = [] + span_type = [] + length = 0 + + elif line[0:3] == '# e': + instance_dic.setdefault((sent_id, target_type, cnt), {}) + instance_dic[(sent_id, target_type, cnt)]['dep_list'] = dep_parsing(word_list, maxlen, parser) + instance_dic[(sent_id, target_type, cnt)]['word_list'] = padding_sentence(word_list, maxlen) + instance_dic[(sent_id, target_type, cnt)]['lemma_list'] = padding_sentence(lemma_list, maxlen) + instance_dic[(sent_id, target_type, cnt)]['pos_list'] = padding_sentence(pos_list, maxlen) + instance_dic[(sent_id, target_type, cnt)]['sent_id'] = sent_id + + word_list_total.append(word_list) + # add 'eos' + instance_dic[(sent_id, target_type, cnt)]['length'] = int(length)+1 + # instance_dic[(sent_id, target_type, cnt)]['attention_mask'] = get_mask_from_index(sequence_lengths=torch.Tensor([int(length)+1]), max_length=maxlen).squeeze() + + instance_dic[(sent_id, target_type, cnt)]['target_type'] = target_type + instance_dic[(sent_id, target_type, cnt)]['lu'] = lu + instance_dic[(sent_id, target_type, cnt)]['target_idx'] = target_idx + + instance_dic[(sent_id, target_type, cnt)]['span_start'] = span_start + instance_dic[(sent_id, target_type, cnt)]['span_end'] = span_end + instance_dic[(sent_id, target_type, cnt)]['span_type'] = span_type + print(cnt) + cnt += 1 + elif line == '\n': + continue + + else: + data_list = line.split('\t') + word_list.append(data_list[1]) + lemma_list.append(data_list[3]) + pos_list.append(data_list[5]) + sent_id = data_list[6] + length = data_list[0] + + if data_list[12] != '_' and data_list[13] != '_': + lu = data_list[12] + + target_type = data_list[13] + if target_idx == [-1, -1]: + target_idx = [int(data_list[0])-1, int(data_list[0])-1] + else: + target_idx[1] =int(data_list[0]) - 1 + + if data_list[14] != '_': + + fe = data_list[14].split('-') + + if fe[0] == 'B' and find is False: + span_start.append(int(data_list[0]) - 1) + find = True + + elif fe[0] == 'O': + span_end.append(int(data_list[0]) - 1) + span_type.append(fe[-1].replace('\n', '')) + find = False + + elif fe[0] == 'S': + span_start.append(int(data_list[0]) - 1) + span_end.append(int(data_list[0]) - 1) + span_type.append(fe[-1].replace('\n', '')) + + return instance_dic + +def dep_parsing(word_list: list,maxlen: list,parser): + res = list(parser.parse(word_list)) + sent = res[0].to_conll(4).split('\n')[:-1] + #['the', 'DT', '4', 'det'] + line = [line.split('\t') for line in sent] + head_list = [] + rel_list = [] + + distance = 0 + + #alignment + for index in range(len(word_list)-1): + #end stopwords + if index-distance >= len(line): + head_list.append('#') + rel_list.append('#') + distance+=1 + + elif word_list[index]!=line[index-distance][0]: + head_list.append('#') + rel_list.append('#') + distance+=1 + else: + rel_list.append(line[index-distance][3]) + + if line[index-distance][3] != 'root': + head_list.append(word_list[int(line[index-distance][2]) - 1]) + else: + head_list.append(word_list[index]) + + head_list.append('eos') + rel_list.append('eos') + + while len(head_list) < maxlen: + head_list.append('0') + rel_list.append('0') + + return (head_list,rel_list) + + +def padding_sentence(sentence: 
list,maxlen: int): + sentence.append('eos') + while len(sentence) < maxlen: + sentence.append('0') + + return sentence + +class DataConfig: + def __init__(self,opt): + exemplar_lines = load_data('fn1.5/conll/exemplar') + train_lines = load_data('fn1.5/conll/train') + dev_lines = load_data('fn1.5/conll/dev') + test_lines = load_data('fn1.5/conll/test') + + + self.emb_file_path = opt.emb_file_path + self.maxlen = opt.maxlen + + if opt.load_instance_dic: + self.exemplar_instance_dic = np.load(opt.exemplar_instance_path, allow_pickle=True).item() + self.train_instance_dic = np.load(opt.train_instance_path, allow_pickle=True).item() + self.dev_instance_dic = np.load(opt.dev_instance_path, allow_pickle=True).item() + self.test_instance_dic = np.load(opt.test_instance_path, allow_pickle=True).item() + + else: + print('begin parsing') + self.exemplar_instance_dic = instance_process(lines=exemplar_lines,maxlen=self.maxlen) + print('exemplar_instance_dic finish') + self.train_instance_dic = instance_process(lines=train_lines,maxlen=self.maxlen) + np.save('train_instance_dic', self.train_instance_dic) + print('train_instance_dic finish') + self.dev_instance_dic = instance_process(lines=dev_lines,maxlen=self.maxlen) + np.save('dev_instance_dic', self.dev_instance_dic) + print('dev_instance_dic finish') + self.test_instance_dic = instance_process(lines=test_lines,maxlen=self.maxlen) + np.save('test_instance_dic', self.test_instance_dic) + print('test_instance_dic finish') + + + self.word_index = {} + self.lemma_index = {} + self.pos_index = {} + self.rel_index = {} + + self.word_number = 0 + self.lemma_number = 0 + self.pos_number = 0 + self.rel_number = 0 + + self.build_word_index(self.exemplar_instance_dic) + self.build_word_index(self.train_instance_dic) + self.build_word_index(self.dev_instance_dic) + self.build_word_index(self.test_instance_dic) + + # add # for parsing sign + self.word_index['#']=self.word_number + self.word_number+=1 + + self.emb_index = self.build_emb_index(self.emb_file_path) + + self.word_vectors = self.get_embedding_weight(self.emb_index, self.word_index, self.word_number) + self.lemma_vectors = self.get_embedding_weight(self.emb_index, self.lemma_index, self.lemma_number) + + def build_word_index(self, dic): + for key in dic.keys(): + word_list =dic[key]['word_list'] + lemma_list = dic[key]['lemma_list'] + pos_list = dic[key]['pos_list'] + rel_list = dic[key]['dep_list'][1] + + # print(row) + for word in word_list: + if word not in self.word_index.keys(): + self.word_index[word]=self.word_number + self.word_number += 1 + + for lemma in lemma_list: + if lemma not in self.lemma_index.keys(): + self.lemma_index[lemma]=self.lemma_number + self.lemma_number += 1 + + for pos in pos_list: + if pos not in self.pos_index.keys(): + self.pos_index[pos] = self.pos_number + self.pos_number += 1 + + for rel in rel_list: + if rel not in self.rel_index.keys(): + self.rel_index[rel] = self.rel_number + self.rel_number += 1 + + def build_emb_index(self, file_path): + data = open(file_path, 'r', encoding='utf-8') + emb_index = {} + for items in data: + item = items.split() + word = item[0] + weight = np.asarray(item[1:], dtype='float32') + emb_index[word] = weight + + return emb_index + + def get_embedding_weight(self,embed_dict, words_dict, words_count, dim=200): + + exact_count = 0 + fuzzy_count = 0 + oov_count = 0 + print("loading pre_train embedding by avg for out of vocabulary.") + embeddings = np.zeros((int(words_count) + 1, int(dim))) + inword_list = {} + for word in words_dict: + if 
word in embed_dict:
+                embeddings[words_dict[word]] = embed_dict[word]
+                inword_list[words_dict[word]] = 1
+                # exact match
+                exact_count += 1
+            elif word.lower() in embed_dict:
+                embeddings[words_dict[word]] = embed_dict[word.lower()]
+                inword_list[words_dict[word]] = 1
+                # fuzzy (lowercased) match
+                fuzzy_count += 1
+            else:
+                # out-of-vocabulary word
+                oov_count += 1
+                # print(word)
+        # average the vectors found so far; OOV words get this mean vector
+        sum_col = np.sum(embeddings, axis=0) / len(inword_list)  # avg
+        sum_col /= np.std(sum_col)
+        for i in range(words_count):
+            if i not in inword_list:
+                embeddings[i] = sum_col
+
+        embeddings[int(words_count)] = [0] * dim
+        final_embed = np.array(embeddings)
+        # print('exact_count: ',exact_count)
+        # print('fuzzy_count: ', fuzzy_count)
+        # print('oov_count: ', oov_count)
+        return final_embed
+
+
+def load_data_pd(dataset_path,file):
+    # df=csv.reader(open(dataset_path+file,encoding='utf-8'))
+    # df = json.load(open(file_path,encoding='utf-8'))
+    df=pd.read_csv(dataset_path+file, header=0, encoding='utf-8')
+    return df
+
+
+def get_frame_tabel(path, file):
+    data = load_data_pd(path, file)
+
+    frame_id_to_label = {}
+    frame_name_to_label = {}
+    frame_name_to_id = {}
+    data_index = 0
+    for idx in range(len(data['ID'])):
+        if data['ID'][idx] not in frame_id_to_label:
+            frame_id_to_label[data['ID'][idx]] = data_index
+            frame_name_to_label[data['Name'][idx]] = data_index
+            frame_name_to_id[data['Name'][idx]] = data['ID'][idx]
+
+            data_index += 1
+
+    return frame_id_to_label, frame_name_to_label, frame_name_to_id
+
+
+def get_fe_tabel(path, file):
+    data = load_data_pd(path, file)
+
+    fe_id_to_label = {}
+    fe_name_to_label = {}
+    fe_name_to_id = {}
+    fe_id_to_type = {}
+
+    data_index = 0
+    for idx in range(len(data['ID'])):
+        if data['ID'][idx] not in fe_id_to_label:
+            fe_id_to_label[data['ID'][idx]] = data_index
+            fe_name_to_label[(data['Name'][idx], data['FrameID'][idx])] = data_index
+            fe_name_to_id[(data['Name'][idx], data['FrameID'][idx])] = data['ID'][idx]
+            fe_id_to_type[data['ID'][idx]] = data['CoreType'][idx]
+
+            data_index += 1
+
+    return fe_id_to_label, fe_name_to_label, fe_name_to_id, fe_id_to_type
+
+
+def get_fe_list(path, fe_num, fe_table, file='FE.csv'):
+    fe_dt = load_data_pd(path, file)
+    fe_mask_list = {}
+
+    print('begin get fe list')
+    for idx in range(len(fe_dt['FrameID'])):
+        fe_mask_list.setdefault(fe_dt['FrameID'][idx], [0]*(fe_num+1))
+        # fe_mask_list[fe_dt['FrameID'][idx]].setdefault('fe_mask', [0]*(fe_num+1))
+        fe_mask_list[fe_dt['FrameID'][idx]][fe_table[fe_dt['ID'][idx]]] = 1
+
+    # for key in fe_list.keys():
+    #     fe_list[key]['fe_mask'][fe_num] = 1
+
+    return fe_mask_list
+
+
+def get_lu_list(path, lu_num, fe_num, frame_id_to_label, fe_mask_list, file='LU.csv'):
+    lu_dt = load_data_pd(path, file)
+    lu_list = {}
+    lu_id_to_name = {}
+    lu_name_to_id = {}
+    #lu_name_to_felist = {}
+
+    for idx in range(len(lu_dt['ID'])):
+        lu_name = lu_dt['Name'][idx]
+        lu_list.setdefault(lu_name, {})
+
+        lu_list[lu_name].setdefault('fe_mask', [0]*(fe_num+1))
+        lu_list[lu_name]['fe_mask'] = list(map(lambda x: x[0]+x[1], zip(lu_list[lu_name]['fe_mask'],
+                                                                        fe_mask_list[lu_dt['FrameID'][idx]])))
+
+        lu_list[lu_name].setdefault('lu_mask', [0]*(lu_num+1))
+        lu_list[lu_name]['lu_mask'][frame_id_to_label[lu_dt['FrameID'][idx]]] = 1
+
+        lu_id_to_name[lu_dt['ID'][idx]] = lu_name
+        lu_name_to_id[(lu_name, lu_dt['FrameID'][idx])] = lu_dt['ID'][idx]
+
+    for key in lu_list.keys():
+        # lu_list[key]['lu_mask'][lu_num] = 1
+        lu_list[key]['fe_mask'][fe_num] = 1
+
+    return lu_list, lu_id_to_name, lu_name_to_id
+
+
+if __name__ == '__main__':
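+    # Smoke test: build the DataConfig end-to-end (CoNLL parsing, vocabulary
+    # indexing, GloVe lookup) and print the sizes of the resulting tables.
+    opt = 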
get_opt() + config = DataConfig(opt) + print(config.word_vectors) + print(config.lemma_number) + print(config.word_number) + print(config.pos_number) + print(config.dep_number) + + + diff --git a/dataset5.py b/dataset5.py new file mode 100644 index 0000000..8de0da3 --- /dev/null +++ b/dataset5.py @@ -0,0 +1,271 @@ + +import torch +import numpy as np + +from torch.utils.data import Dataset +from config import get_opt +from data_process4 import get_frame_tabel, get_fe_tabel, get_fe_list, get_lu_list,DataConfig +from utils import get_mask_from_index + +class FrameNetDataset(Dataset): + + def __init__(self, opt, config, data_dic, device): + super(FrameNetDataset, self).__init__() + print('begin load data') + self.data_dic = data_dic + self.fe_id_to_label, self.fe_name_to_label, self.fe_name_to_id, self.fe_id_to_type = get_fe_tabel('parsed-v1.5/', 'FE.csv') + self.frame_id_to_label, self.frame_name_to_label, self.frame_name_to_id = get_frame_tabel('parsed-v1.5/', 'frame.csv') + + self.word_index = config.word_index + self.lemma_index = config.lemma_index + self.pos_index = config.pos_index + self.rel_index = config.rel_index + + self.fe_num = len(self.fe_id_to_label) + self.frame_num = len(self.frame_id_to_label) + self.batch_size = opt.batch_size + print(self.fe_num) + print(self.frame_num) + self.dataset_len = len(self.data_dic) + + self.fe_mask_list = get_fe_list('parsed-v1.5/', self.fe_num, self.fe_id_to_label) + self.lu_list, self.lu_id_to_name,\ + self.lu_name_to_id = get_lu_list('parsed-v1.5/', + self.frame_num, self.fe_num, + self.frame_id_to_label, + self.fe_mask_list) + + self.word_ids = [] + self.lemma_ids = [] + self.pos_ids = [] + self.head_ids = [] + self.rel_ids = [] + + self.lengths = [] + self.mask = [] + self.target_head = [] + self.target_tail = [] + self.target_type = [] + self.fe_head = [] + self.fe_tail = [] + self.fe_type = [] + self.fe_coretype = [] + self.sent_length = [] + self.fe_cnt = [] + self.fe_cnt_with_padding =[] + self.fe_mask = [] + self.lu_mask = [] + self.token_type_ids = [] + self.target_mask_ids = [] + + self.device = device + self.oov_frame = 0 + self.long_span = 0 + self.error_span = 0 + self.fe_coretype_table = {} + self.target_mask = {} + + for idx in self.fe_id_to_type.keys(): + if self.fe_id_to_type[idx] == 'Core': + self.fe_coretype_table[self.fe_id_to_label[idx]] = 1 + else: + self.fe_coretype_table[self.fe_id_to_label[idx]] = 0 + + + + for key in self.data_dic.keys(): + self.build_target_mask(key,opt.maxlen) + + + for key in self.data_dic.keys(): + self.pre_process(key, opt) + + self.pad_dic_cnt = self.dataset_len % opt.batch_size + + + for idx,key in enumerate(self.data_dic.keys()): + if idx >= self.pad_dic_cnt: + break + self.pre_process(key, opt,filter=False) + + self.dataset_len+=self.pad_dic_cnt + + print('load data finish') + print('oov frame = ', self.oov_frame) + print('long_span = ', self.long_span) + print('dataset_len = ', self.dataset_len) + + def __len__(self): + self.dataset_len = int(self.dataset_len / self.batch_size) * self.batch_size + return self.dataset_len + + def __getitem__(self, item): + word_ids = torch.Tensor(self.word_ids[item]).long().to(self.device) + lemma_ids = torch.Tensor(self.lemma_ids[item]).long().to(self.device) + pos_ids = torch.Tensor(self.pos_ids[item]).long().to(self.device) + head_ids = torch.Tensor(self.head_ids[item]).long().to(self.device) + rel_ids = torch.Tensor(self.rel_ids[item]).long().to(self.device) + lengths = torch.Tensor([self.lengths[item]]).long().to(self.device) + mask = 
self.mask[item].long().to(self.device) + target_head = torch.Tensor([self.target_head[item]]).long().to(self.device) + target_tail = torch.Tensor([self.target_tail[item]]).long().to(self.device) + target_type = torch.Tensor([self.target_type[item]]).long().to(self.device) + fe_head = torch.Tensor(self.fe_head[item]).long().to(self.device) + fe_tail = torch.Tensor(self.fe_tail[item]).long().to(self.device) + fe_type = torch.Tensor(self.fe_type[item]).long().to(self.device) + fe_cnt = torch.Tensor([self.fe_cnt[item]]).long().to(self.device) + fe_cnt_with_padding = torch.Tensor([self.fe_cnt_with_padding[item]]).long().to(self.device) + fe_mask = torch.Tensor(self.fe_mask[item]).long().to(self.device) + lu_mask = torch.Tensor(self.lu_mask[item]).long().to(self.device) + token_type_ids = torch.Tensor(self.token_type_ids[item]).long().to(self.device) + sent_length = torch.Tensor([self.sent_length[item]]).long().to(self.device) + target_mask_ids = torch.Tensor(self.target_mask_ids[item]).long().to(self.device) + # print(fe_cnt) + + + return (word_ids, lemma_ids, pos_ids,head_ids, rel_ids, lengths, mask, target_head, target_tail, target_type, + fe_head, fe_tail, fe_type, fe_cnt, fe_cnt_with_padding, + fe_mask, lu_mask, token_type_ids,sent_length,target_mask_ids) + + def pre_process(self, key, opt,filter=True): + if self.data_dic[key]['target_type'] not in self.frame_name_to_label: + self.oov_frame += 1 + self.dataset_len -= 1 + return + + target_id = self.frame_name_to_id[self.data_dic[key]['target_type']] + if filter: + self.long_span += self.remove_error_span(key, self.data_dic[key]['span_start'], + self.data_dic[key]['span_end'], self.data_dic[key]['span_type'], target_id, 20) + + word_ids = [self.word_index[word] for word in self.data_dic[key]['word_list']] + lemma_ids = [self.lemma_index[lemma] for lemma in self.data_dic[key]['lemma_list']] + pos_ids = [self.pos_index[pos] for pos in self.data_dic[key]['pos_list']] + head_ids = [self.word_index[head] for head in self.data_dic[key]['dep_list'][0]] + rel_ids = [self.rel_index[rel] for rel in self.data_dic[key]['dep_list'][1]] + + self.word_ids.append(word_ids) + self.lemma_ids.append(lemma_ids) + self.pos_ids.append(pos_ids) + self.head_ids.append(head_ids) + self.rel_ids.append(rel_ids) + self.lengths.append(self.data_dic[key]['length']) + + # self.mask.append(self.data_dic[key]['attention_mask']) + self.target_head.append(self.data_dic[key]['target_idx'][0]) + self.target_tail.append(self.data_dic[key]['target_idx'][1]) + + mask = get_mask_from_index(torch.Tensor([int(self.data_dic[key]['length'])]), opt.maxlen).squeeze() + self.mask.append(mask) + + token_type_ids = build_token_type_ids(self.data_dic[key]['target_idx'][0], self.data_dic[key]['target_idx'][1], opt.maxlen) + # token_type_ids +=self.target_mask[key[0]] + self.token_type_ids.append(token_type_ids) + self.target_mask_ids.append(self.target_mask[key[0]]) + + self.target_type.append(self.frame_name_to_label[self.data_dic[key]['target_type']]) + + # print(self.frame_tabel[self.fe_data[key]['frame_ID']]) + + if self.data_dic[key]['length'] <= opt.maxlen: + sent_length = self.data_dic[key]['length'] + else: + sent_length = opt.maxlen + self.sent_length.append(sent_length) + + lu_name = self.data_dic[key]['lu'] + self.lu_mask.append(self.lu_list[lu_name]['lu_mask']) + self.fe_mask.append(self.lu_list[lu_name]['fe_mask']) + + fe_head = self.data_dic[key]['span_start'] + fe_tail = self.data_dic[key]['span_end'] + + + + while len(fe_head) < opt.fe_padding_num: + 
fe_head.append(min(sent_length-1, opt.maxlen-1)) + + while len(fe_tail) < opt.fe_padding_num: + fe_tail.append(min(sent_length-1,opt.maxlen-1)) + + self.fe_head.append(fe_head[0:opt.fe_padding_num]) + self.fe_tail.append(fe_tail[0:opt.fe_padding_num]) + + fe_type = [self.fe_name_to_label[(item, target_id)] for item in self.data_dic[key]['span_type']] + + self.fe_cnt.append(min(len(fe_type), opt.fe_padding_num)) + self.fe_cnt_with_padding.append(min(len(fe_type)+1, opt.fe_padding_num)) + + while len(fe_type) < opt.fe_padding_num: + fe_type.append(self.fe_num) + # fe_coretype.append('0') + + self.fe_type.append(fe_type[0:opt.fe_padding_num]) + + def remove_error_span(self, key, fe_head_list, fe_tail_list, fe_type_list, target_id, span_maxlen): + indices = [] + for index in range(len(fe_head_list)): + if fe_tail_list[index] - fe_head_list[index] >= span_maxlen: + indices.append(index) + elif fe_tail_list[index] < fe_head_list[index]: + indices.append(index) + + + elif (fe_type_list[index], target_id) not in self.fe_name_to_label: + indices.append(index) + + else: + for i in range(index): + if i not in indices: + if fe_head_list[index] >= fe_head_list[i] and fe_head_list[index] <= fe_tail_list[i]: + indices.append(index) + break + + elif fe_tail_list[index] >= fe_head_list[i] and fe_tail_list[index] <= fe_tail_list[i]: + indices.append(index) + break + elif fe_tail_list[index] <= fe_head_list[i] and fe_tail_list[index] >= fe_tail_list[i]: + indices.append(index) + break + else: + continue + + fe_head_list_filter = [i for j, i in enumerate(fe_head_list) if j not in indices] + fe_tail_list_filter = [i for j, i in enumerate(fe_tail_list) if j not in indices] + fe_type_list_filter = [i for j, i in enumerate(fe_type_list) if j not in indices] + self.data_dic[key]['span_start'] = fe_head_list_filter + self.data_dic[key]['span_end'] = fe_tail_list_filter + self.data_dic[key]['span_type'] = fe_type_list_filter + + return len(indices) + + def build_target_mask(self,key,maxlen): + self.target_mask.setdefault(key[0], [0]*maxlen) + + target_head = self.data_dic[key]['target_idx'][0] + target_tail = self.data_dic[key]['target_idx'][1] + self.target_mask[key[0]][target_head] = 1 + self.target_mask[key[0]][target_tail] = 1 + + + + + +def build_token_type_ids(target_head, target_tail, maxlen): + token_type_ids = [0]*maxlen + token_type_ids[target_head] = 1 + token_type_ids[target_tail] = 1 + # token_type_ids[target_head:target_tail+1] = [1]*(target_tail+1-target_head) + + return token_type_ids + + +if __name__ == '__main__': + opt = get_opt() + config = DataConfig(opt) + if torch.cuda.is_available(): + device = torch.device(opt.cuda) + else: + device = torch.device('cpu') + dataset = FrameNetDataset(opt, config, config.test_instance_dic, device) + print(dataset.error_span) diff --git a/model_syntax36_final.py b/model_syntax36_final.py new file mode 100644 index 0000000..ee5110d --- /dev/null +++ b/model_syntax36_final.py @@ -0,0 +1,522 @@ +import numpy as np +from typing import List,Tuple +import os +import json + +import torch.nn as nn +import torch +import torch.nn.functional as F + +from utils import batched_index_select,get_mask_from_index,generate_perm_inv + + +class Mlp(nn.Module): + def __init__(self, input_size, output_size): + super(Mlp, self).__init__() + self.linear = nn.Sequential( + nn.Linear(input_size, input_size), + nn.Dropout(0.4), + nn.ReLU(inplace=True), + nn.Linear(input_size, output_size), + ) + + def forward(self, x): + out = self.linear(x) + return out + + +class Relu_Linear(nn.Module): 
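+    # A thin projection block (Linear -> Dropout(0.4) -> ReLU). The encoder
+    # uses it to map the concatenated target representation down to
+    # opt.decoder_emb_size before decoding starts.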
+ def __init__(self, input_size, output_size): + super(Relu_Linear, self).__init__() + self.linear = nn.Sequential( + nn.Linear(input_size, output_size), + nn.Dropout(0.4), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + out = self.linear(x) + return out + + +class CnnNet(nn.Module): + def __init__(self, kernel_size, seq_length, input_size, output_size): + super(CnnNet, self).__init__() + self.seq_length = seq_length + self.output_size = output_size + self.kernel_size = kernel_size + + self.relu = nn.ReLU() + self.conv = nn.Conv1d(in_channels=input_size, out_channels=output_size, kernel_size=self.kernel_size + , padding=1) + self.mp = nn.MaxPool1d(kernel_size=self.seq_length) + + def forward(self, input_emb): + input_emb = input_emb.permute(0, 2, 1) + x = self.conv(input_emb) + output = self.mp(x).squeeze() + + return output + + +class PointerNet(nn.Module): + def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'): + super(PointerNet, self).__init__() + + assert attention_type in ('affine', 'dot_prod') + if attention_type == 'affine': + self.src_encoding_linear = Mlp(src_encoding_size, query_vec_size) + + self.src_linear = Mlp(src_encoding_size,src_encoding_size) + self.activate = nn.ReLU(inplace=True) + self.dropout = nn.Dropout(0.5) + + self.fc =nn.Linear(src_encoding_size*2,src_encoding_size, bias=True) + + self.attention_type = attention_type + + def forward(self, src_encodings, src_token_mask,query_vec,head_vec=None): + + # (batch_size, 1, src_sent_len, query_vec_size) + if self.attention_type == 'affine': + src_encod = self.src_encoding_linear(src_encodings).unsqueeze(1) + head_weights = self.src_linear(src_encodings).unsqueeze(1) + + # (batch_size, tgt_action_num, query_vec_size, 1) + if head_vec is not None: + src_encod = torch.cat([src_encod,head_weights],dim = -1) + q = torch.cat([head_vec, query_vec], dim=-1).permute(1, 0, 2).unsqueeze(3) + + + else: + q = query_vec.permute(1, 0, 2).unsqueeze(3) + + weights = torch.matmul(src_encod, q).squeeze(3) + ptr_weights = weights.permute(1, 0, 2) + + # if head_vec is not None: + # src_weights = torch.matmul(head_weights, q_h).squeeze(3) + # src_weights = src_weights.permute(1, 0, 2) + # ptr_weights = weights+src_weights + # + # else: + # ptr_weights = weights + + ptr_weights_masked = ptr_weights.clone().detach() + if src_token_mask is not None: + # (tgt_action_num, batch_size, src_sent_len) + src_token_mask=1-src_token_mask.byte() + src_token_mask = src_token_mask.unsqueeze(0).expand_as(ptr_weights) + # ptr_weights.data.masked_fill_(src_token_mask, -float('inf')) + ptr_weights_masked.data.masked_fill_(src_token_mask, -float('inf')) + + # ptr_weights =self.activate(ptr_weights) + + return ptr_weights,ptr_weights_masked + + +class Encoder(nn.Module): + def __init__(self, opt, config, word_embedding:nn.modules.sparse.Embedding, + lemma_embedding: nn.modules.sparse.Embedding): + super(Encoder, self).__init__() + self.opt =opt + self.hidden_size = opt.rnn_hidden_size + self.emb_size = opt.encoder_emb_size + self.rnn_input_size = self.emb_size*2+opt.pos_emb_size+opt.token_type_emb_size + self.word_number = config.word_number + self.lemma_number = config.lemma_number + self.maxlen = opt.maxlen + + self.cnn = CnnNet(opt.kernel_size, seq_length=opt.maxlen, + input_size=self.emb_size*2+opt.pos_emb_size+opt.token_type_emb_size, + output_size=opt.sent_emb_size) + + self.dropout = 0.2 + self.word_embedding = word_embedding + self.lemma_embedding = lemma_embedding + self.pos_embedding = nn.Embedding(config.pos_number, 
opt.pos_emb_size) + self.rel_embedding = nn.Embedding(config.rel_number,opt.rel_emb_size) + self.token_type_embedding = nn.Embedding(2, opt.token_type_emb_size) + self.cell_name = opt.cell_name + + # self.embedded_linear = nn.Linear(self.emb_size*2+opt.pos_emb_size+opt.token_type_emb_size+opt.sent_emb_size, + # self.rnn_input_size) + self.syntax_embedded_linear = nn.Linear(self.emb_size*2+opt.rel_emb_size+opt.token_type_emb_size, + self.rnn_input_size) + # self.output_combine_linear = nn.Linear(4*self.hidden_size, 2*self.hidden_size) + + self.target_linear = nn.Linear(2*self.hidden_size, 2*self.hidden_size) + + self.relu_linear = Relu_Linear(4*self.hidden_size+self.rnn_input_size, opt.decoder_emb_size) + + if self.cell_name == 'gru': + self.rnn = nn.GRU(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + dropout=self.dropout,bidirectional=True, batch_first=True) + # self.syntax_rnn = nn.GRU(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + # dropout=self.dropout,bidirectional=True, batch_first=True) + elif self.cell_name == 'lstm': + self.rnn = nn.LSTM(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + dropout=self.dropout,bidirectional=True, batch_first=True) + # self.syntax_rnn = nn.LSTM(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + # dropout=self.dropout,bidirectional=True, batch_first=True) + else: + print('cell_name error') + + def forward(self, word_input: torch.Tensor, lemma_input: torch.Tensor, pos_input: torch.Tensor, + head_input:torch.Tensor, rel_input:torch.Tensor, + lengths:torch.Tensor, frame_idx, token_type_ids=None, attention_mask=None,target_mask_ids=None): + + word_embedded = self.word_embedding(word_input) + lemma_embedded = self.lemma_embedding(lemma_input) + pos_embedded = self.pos_embedding(pos_input) + # head_embedded = self.word_embedding(head_input) + # rel_embedded = self.rel_embedding(rel_input) + # type_ids =torch.add(token_type_ids, target_mask_ids) + token_type_embedded = self.token_type_embedding(token_type_ids) + #print(token_type_embedded) + # print(token_type_ids.size()) + # print(target_mask_ids.size()) + # print(token_type_embedded.size()) + embedded = torch.cat([word_embedded, lemma_embedded, pos_embedded,token_type_embedded], dim=-1) + # sent_embedded = self.cnn(embedded) + + # sent_embedded = sent_embedded.expand([self.opt.maxlen, self.opt.batch_size, self.opt.sent_emb_size]).permute(1, 0, 2) + #embedded = torch.cat([embedded,sent_embedded], dim=-1) + #embedded = self.embedded_linear(embedded) + + # syntax embedding + # syntax_embedded = torch.cat([word_embedded,head_embedded,rel_embedded,token_type_ids],dim=-1) + # syntax_embedded = self.syntax_embedded_linear(syntax_embedded) + + lengths=lengths.squeeze() + # sorted before pack + l = lengths.cpu().numpy() + perm_idx = np.argsort(-l) + perm_idx_inv = generate_perm_inv(perm_idx) + + embedded = embedded[perm_idx] + # syntax_embedded = syntax_embedded[perm_idx] + + if lengths is not None: + rnn_input = nn.utils.rnn.pack_padded_sequence(embedded, lengths=lengths[perm_idx], + batch_first=True) + # syntax_rnn_input = nn.utils.rnn.pack_padded_sequence(syntax_embedded, lengths=lengths[perm_idx], + # batch_first=True) + output, hidden = self.rnn(rnn_input) + #syntax_output, syntax_hidden = self.rnn(syntax_rnn_input) + + if lengths is not None: + output, _ = nn.utils.rnn.pad_packed_sequence(output, total_length=self.maxlen, batch_first=True) + # syntax_output, _ = nn.utils.rnn.pad_packed_sequence(syntax_output, 
total_length=self.maxlen, batch_first=True) + + + # print(output.size()) + # print(hidden.size()) + + output = output[perm_idx_inv] + # syntax_output = syntax_output[perm_idx_inv] + + if self.cell_name == 'gru': + hidden = hidden[:, perm_idx_inv] + hidden = (lambda a: sum(a)/(2*self.opt.num_layers))(torch.split(hidden, 1, dim=0)) + + # syntax_hidden = syntax_hidden[:, perm_idx_inv] + # syntax_hidden = (lambda a: sum(a)/(2*self.opt.num_layers))(torch.split(syntax_hidden, 1, dim=0)) + # hidden = (hidden + syntax_hidden) / 2 + + elif self.cell_name == 'lstm': + hn0 = hidden[0][:, perm_idx_inv] + hn1 = hidden[1][:, perm_idx_inv] + # sy_hn0 = syntax_hidden[0][:, perm_idx_inv] + # sy_hn1 = syntax_hidden[1][:, perm_idx_inv] + hn = tuple([hn0,hn1]) + hidden = tuple(map(lambda state: sum(torch.split(state, 1, dim=0))/(2*self.opt.num_layers), hn)) + + + target_state_head = batched_index_select(target=output, indices=frame_idx[0]) + target_state_tail = batched_index_select(target=output, indices=frame_idx[1]) + target_state = (target_state_head + target_state_tail) / 2 + target_state = self.target_linear(target_state) + + target_emb_head = batched_index_select(target=embedded, indices=frame_idx[0]) + target_emb_tail = batched_index_select(target=embedded, indices=frame_idx[1]) + target_emb = (target_emb_head + target_emb_tail) / 2 + + attentional_target_state = type_attention(attention_mask=target_mask_ids, hidden_state=output, + target_state=target_state) + + + target_state =torch.cat([target_state.squeeze(),attentional_target_state.squeeze(), target_emb.squeeze()], dim=-1) + target = self.relu_linear(target_state) + + # print(output.size()) + return output, hidden, target + + +class Decoder(nn.Module): + def __init__(self, opt, embedding_frozen=False): + super(Decoder, self).__init__() + + # rnn _init_ + self.opt = opt + self.cell_name = opt.cell_name + self.emb_size = opt.decoder_emb_size + self.hidden_size = opt.decoder_hidden_size + self.encoder_hidden_size = opt.rnn_hidden_size + + # decoder _init_ + self.decodelen = opt.fe_padding_num+1 + self.frame_embedding = nn.Embedding(opt.frame_number+1, self.emb_size) + self.frame_fc_layer =Mlp(self.emb_size, opt.frame_number+1) + self.role_embedding = nn.Embedding(opt.role_number+1, self.emb_size) + self.role_feature_layer = nn.Linear(2*self.emb_size, self.emb_size) + self.role_fc_layer = nn.Linear(self.hidden_size+self.emb_size, opt.role_number+1) + + self.head_fc_layer = Mlp(self.hidden_size+self.emb_size, self.hidden_size) + self.tail_fc_layer = Mlp(self.hidden_size+self.emb_size, self.hidden_size) + + self.span_fc_layer = Mlp(4 * self.encoder_hidden_size + self.emb_size, self.emb_size) + + self.next_input_fc_layer = Mlp(self.hidden_size+self.emb_size, self.emb_size) + + if embedding_frozen is True: + for param in self.frame_embedding.parameters(): + param.requires_grad = False + for param in self.role_embedding.parameters(): + param.requires_grad = False + + if self.cell_name == 'gru': + self.frame_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) + self.ent_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) + self.role_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) + if self.cell_name == 'lstm': + self.frame_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) + self.ent_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) + self.role_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) + + # pointer _init_ + self.ent_pointer = 
PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) + self.head_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) + self.tail_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) + + def forward(self, encoder_output: torch.Tensor, encoder_state: torch.Tensor, target_state: torch.Tensor, + attention_mask: torch.Tensor, fe_mask=None, lu_mask=None): + pred_frame_list = [] + pred_head_list = [] + pred_tail_list = [] + pred_role_list = [] + + pred_frame_action = [] + + pred_head_action = [] + pred_tail_action = [] + pred_role_action = [] + + frame_decoder_state = encoder_state + role_decoder_state = encoder_state + + input = target_state + # print(input.size()) + + span_mask = attention_mask.clone() + for t in range(self.decodelen): + # frame pred + output, frame_decoder_state = self.decode_step(self.frame_rnn, input=input, + decoder_state=frame_decoder_state) + + pred_frame_weight = self.frame_fc_layer(target_state.squeeze()) + pred_frame_weight_masked = pred_frame_weight.clone().detach() + + if lu_mask is not None: + LU_mask = 1-lu_mask + pred_frame_weight_masked.data.masked_fill_(LU_mask.byte(), -float('inf')) + + pred_frame_indices = torch.argmax(pred_frame_weight_masked.squeeze(), dim=-1).squeeze() + + pred_frame_list.append(pred_frame_weight) + pred_frame_action.append(pred_frame_indices) + + frame_emb = self.frame_embedding(pred_frame_indices) + + head_input = self.head_fc_layer(torch.cat([output.squeeze(), frame_emb], dim=-1)) + tail_input = self.tail_fc_layer(torch.cat([output.squeeze(), frame_emb], dim=-1)) + + head_pointer_weight, head_pointer_weight_masked = self.head_pointer(src_encodings=encoder_output, + src_token_mask=span_mask, + query_vec=head_input.view(1, self.opt.batch_size, -1)) + + head_indices = torch.argmax(head_pointer_weight_masked.squeeze(), dim=-1).squeeze() + head_target = batched_index_select(target=encoder_output, indices=head_indices.squeeze()) + head_mask = head_mask_update(span_mask, head_indices=head_indices, max_len=self.opt.maxlen) + + tail_pointer_weight, tail_pointer_weight_masked = self.tail_pointer(src_encodings=encoder_output, + src_token_mask=head_mask, + query_vec=tail_input.view(1, self.opt.batch_size,-1), + head_vec=head_target.view(1,self.opt.batch_size,-1)) + + tail_indices = torch.argmax(tail_pointer_weight_masked.squeeze(), dim=-1).squeeze() + # tail_target = batched_index_select(target=bert_hidden_state,indices=tail_indices.squeeze()) + + span_mask = span_mask_update(attention_mask=span_mask, head_indices=head_indices, + tail_indices=tail_indices, max_len=self.opt.maxlen) + + pred_head_list.append(head_pointer_weight) + pred_tail_list.append(tail_pointer_weight) + pred_head_action.append(head_indices) + pred_tail_action.append(tail_indices) + + # role pred + + # print(ent_target.size()) + # print(head_target.size()) + # print(tail_target.size()) + # print(bert_hidden_state.size()) + # print(tail_pointer_weight.size()) + # print(tail_indices.size()) + # print(output.size()) + + # next step + # head_target = batched_index_select(target=bert_hidden_state, indices=head_indices.squeeze()) + + tail_target = batched_index_select(target=encoder_output, indices=tail_indices.squeeze()) + + # head_context =local_attention(attention_mask=attention_mask, hidden_state=encoder_output, + # frame_idx=(head_indices, tail_indices), target_state=head_target, + # window_size=0, max_len=self.opt.maxlen) + # + # tail_context 
=local_attention(attention_mask=attention_mask, hidden_state=encoder_output, + # frame_idx=(head_indices, tail_indices), target_state=tail_target, + # window_size=0, max_len=self.opt.maxlen) + + span_input = self.span_fc_layer(torch.cat([head_target+tail_target, head_target-tail_target, frame_emb], dim=-1)).unsqueeze(1) + + output,role_decoder_state = self.decode_step(self.role_rnn, input=span_input, + decoder_state=role_decoder_state) + role_target = self.role_fc_layer(torch.cat([span_input,output],dim=-1)) + role_target_masked = role_target.squeeze().clone().detach() + if fe_mask is not None : + FE_mask = 1-fe_mask + # print(FE_mask.size()) + role_target_masked.data.masked_fill_(FE_mask.byte(), -float('inf')) + + role_indices = torch.argmax(role_target_masked.squeeze(), dim=-1).squeeze() + role_emb = self.role_embedding(role_indices) + + pred_role_list.append(role_target) + pred_role_action.append(role_indices) + + # next step + next_input =torch.cat([output, role_emb.unsqueeze(1)], dim=-1) + input = self.next_input_fc_layer(next_input) + + + #return + return pred_frame_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action,\ + pred_head_action, pred_tail_action, pred_role_action + + def decode_step(self, rnn_cell: nn.modules, input: torch.Tensor, decoder_state: torch.Tensor): + + output, state = rnn_cell(input.view(-1, 1, self.emb_size), decoder_state) + + return output, state + + +class Model(nn.Module): + def __init__(self, opt, config, load_emb=True): + super(Model, self).__init__() + self.word_vectors = config.word_vectors + self.lemma_vectors = config.lemma_vectors + self.word_embedding = nn.Embedding(config.word_number+1, opt.encoder_emb_size) + self.lemma_embedding = nn.Embedding(config.lemma_number+1, opt.encoder_emb_size) + + if load_emb: + self.load_pretrain_emb() + + self.encoder = Encoder(opt, config, self.word_embedding, self.lemma_embedding) + self.decoder = Decoder(opt) + + def load_pretrain_emb(self): + self.word_embedding.weight.data.copy_(torch.from_numpy(self.word_vectors)) + self.lemma_embedding.weight.data.copy_(torch.from_numpy(self.lemma_vectors)) + + def forward(self, word_ids, lemma_ids, pos_ids, head_ids, rel_ids, lengths, frame_idx, fe_mask=None, lu_mask=None, + frame_len=None, token_type_ids=None, attention_mask=None,target_mask_ids=None): + encoder_output, encoder_state, target_state = self.encoder(word_input=word_ids, lemma_input=lemma_ids, + pos_input=pos_ids, head_input=head_ids, + rel_input=rel_ids, lengths=lengths, + frame_idx=frame_idx, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + target_mask_ids=target_mask_ids) + + pred_frame_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action, \ + pred_head_action, pred_tail_action, pred_role_action = self.decoder(encoder_output=encoder_output, + encoder_state=encoder_state, + target_state=target_state, + attention_mask=attention_mask, + fe_mask=fe_mask, lu_mask=lu_mask) + + # return pred_frame_list, pred_ent_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action, \ + # pred_ent_action, pred_head_action, pred_tail_action, pred_role_action + + return { + 'pred_frame_list' : pred_frame_list, + 'pred_head_list' : pred_head_list, + 'pred_tail_list' : pred_tail_list, + 'pred_role_list' : pred_role_list, + 'pred_frame_action' : pred_frame_action, + 'pred_head_action' : pred_head_action, + 'pred_tail_action' : pred_tail_action, + 'pred_role_action' : pred_role_action + } + + +def head_mask_update(attention_mask: torch.Tensor, head_indices: 
torch.Tensor, max_len): + indices=head_indices + indices_mask=1-get_mask_from_index(indices, max_len) + mask = torch.mul(attention_mask, indices_mask.long()) + + return mask + + +def span_mask_update(attention_mask: torch.Tensor, head_indices: torch.Tensor, tail_indices: torch.Tensor, max_len): + tail = tail_indices + 1 + head_indices_mask = get_mask_from_index(head_indices, max_len) + tail_indices_mask = get_mask_from_index(tail, max_len) + span_indices_mask = tail_indices_mask - head_indices_mask + span_indices_mask = 1 - span_indices_mask + mask = torch.mul(attention_mask, span_indices_mask.long()) + + return mask + + +def local_attention(attention_mask: torch.Tensor, hidden_state: torch.Tensor, frame_idx, + target_state: torch.Tensor, window_size: int, max_len): + + q = target_state.squeeze().unsqueeze(2) + context_att = torch.bmm(hidden_state, q).squeeze() + head = frame_idx[0]-window_size + tail = frame_idx[1]+window_size + mask = span_mask_update(attention_mask=attention_mask, head_indices=head.squeeze(), + tail_indices=tail.squeeze(), max_len=max_len) + context_att = context_att.masked_fill_(mask.byte(), -float('inf')) + context_att = F.softmax(context_att, dim=-1) + attentional_hidden_state = torch.bmm(hidden_state.permute(0, 2, 1), context_att.unsqueeze(2)).squeeze() + + return attentional_hidden_state + +def type_attention(attention_mask: torch.Tensor, hidden_state: torch.Tensor, + target_state: torch.Tensor): + + q = target_state.squeeze().unsqueeze(2) + context_att = torch.bmm(hidden_state, q).squeeze() + + mask = 1-attention_mask + + context_att = context_att.masked_fill_(mask.byte(), -float('inf')) + context_att = F.softmax(context_att, dim=-1) + attentional_hidden_state = torch.bmm(hidden_state.permute(0, 2, 1), context_att.unsqueeze(2)).squeeze() + + return attentional_hidden_state + + diff --git a/train_syntax36_final.py b/train_syntax36_final.py new file mode 100644 index 0000000..2a2a5d9 --- /dev/null +++ b/train_syntax36_final.py @@ -0,0 +1,233 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import os +import multiprocessing as mp + +from dataset5 import FrameNetDataset +from torch.utils.data import DataLoader +from utils import get_mask_from_index,seed_everything +from evaluate8 import Eval +from config import get_opt +from data_process4 import DataConfig +from model_syntax36_final import Model + + +def evaluate(opt, model, dataset, best_metrics=None, show_case=False): + model.eval() + print('begin eval') + evaler = Eval(opt) + with torch.no_grad(): + test_dl = DataLoader( + dataset, + batch_size=opt.batch_size, + shuffle=True, + num_workers=0 + ) + + for batch in test_dl: + word_ids, lemma_ids, pos_ids,head_ids,rel_ids, lengths, attention_mask, target_head, target_tail, \ + target_type, fe_head, fe_tail, fe_type, fe_cnt, \ + fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids,sent_length,target_mask_ids = batch + + return_dic = model(word_ids=word_ids, lemma_ids=lemma_ids, pos_ids=pos_ids,head_ids=head_ids, + rel_ids=rel_ids,lengths=lengths, + frame_idx=(target_head, target_tail), + token_type_ids=token_type_ids, + attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask,target_mask_ids=target_mask_ids) + + evaler.metrics(batch_size=opt.batch_size, fe_cnt=fe_cnt, gold_fe_type=fe_type, gold_fe_head=fe_head, \ + gold_fe_tail=fe_tail, gold_frame_type=target_type, + pred_fe_type=return_dic['pred_role_action'], + pred_fe_head=return_dic['pred_head_action'], + pred_fe_tail=return_dic['pred_tail_action'], + 
pred_frame_type=return_dic['pred_frame_action'], + fe_coretype=dataset.fe_coretype_table,sent_length=sent_length) + + if show_case: + print('gold_fe_label = ', fe_type) + print('pred_fe_label = ', return_dic['pred_role_action']) + print('gold_head_label = ', fe_head) + print('pred_head_label = ', return_dic['pred_head_action']) + print('gold_tail_label = ', fe_tail) + print('pred_tail_label = ', return_dic['pred_tail_action']) + + metrics = evaler.calculate() + + + if best_metrics: + + if metrics[-1] > best_metrics: + best_metrics = metrics[-1] + + torch.save(model.state_dict(), opt.save_model_path) + + return best_metrics + + + + + + +if __name__ == '__main__': + + os.environ['CUDA_LAUNCH_BLOCKING'] = "1" + # os.environ['CUDA_LAUNCH_BLOCKING'] = 1 + + # bertconfig = BertConfig() + # print(bertconfig.CONFIG_PATH) + mp.set_start_method('spawn') + + opt = get_opt() + config = DataConfig(opt) + + if torch.cuda.is_available(): + device = torch.device(opt.cuda) + else: + device = torch.device('cpu') + + seed_everything(1116) + + epochs = opt.epochs + model = Model(opt, config) + model.to(device) + + optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.6) + + frame_criterion =nn.CrossEntropyLoss() + head_criterion = nn.CrossEntropyLoss() + tail_criterion = nn.CrossEntropyLoss() + fe_type_criterion = nn.CrossEntropyLoss() + + if os.path.exists(opt.save_model_path) is True: + model.load_state_dict(torch.load(opt.save_model_path)) + + pretrain_dataset = FrameNetDataset(opt, config, config.exemplar_instance_dic, device) + train_dataset = FrameNetDataset(opt, config, config.train_instance_dic, device) + dev_dataset = FrameNetDataset(opt, config, config.dev_instance_dic, device) + test_dataset = FrameNetDataset(opt, config, config.test_instance_dic, device) + + best_metrics = -1 + + if opt.mode == 'train': + for epoch in range(1, epochs): + scheduler.step() + train_dl = DataLoader( + train_dataset, + batch_size=opt.batch_size, + shuffle=True, + num_workers=0 + ) + model.train() + print('==========epochs= ' + str(epoch)) + step = 0 + sum_loss = 0 + # best_metrics = -1 + cnt = 0 + for batch in train_dl: + optimizer.zero_grad() + loss = 0 + word_ids, lemma_ids, pos_ids,head_ids, rel_ids, lengths, attention_mask, target_head, target_tail, \ + target_type, fe_head, fe_tail, fe_type, fe_cnt, \ + fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids,sent_length,target_mask_ids = batch + + return_dic = model(word_ids=word_ids, lemma_ids=lemma_ids, pos_ids=pos_ids, head_ids=head_ids, rel_ids=rel_ids + ,lengths=lengths, + frame_idx=(target_head, target_tail), + token_type_ids=token_type_ids, + attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask,target_mask_ids=target_mask_ids) + # print(return_dic) + + frame_loss = 0 + head_loss = 0 + tail_loss = 0 + type_loss = 0 + + for batch_index in range(opt.batch_size): + pred_frame_first = return_dic['pred_frame_list'][fe_cnt[batch_index]][batch_index].unsqueeze(0) + pred_frame_last = return_dic['pred_frame_list'][0][batch_index].unsqueeze(0) + + pred_frame_label = pred_frame_last + + gold_frame_label = target_type[batch_index] + # print(gold_frame_label.size()) + # print(pred_frame_label.size()) + # print(fe_head) + frame_loss += frame_criterion(pred_frame_label, gold_frame_label) + + + for fe_index in range(opt.fe_padding_num): + + # print(fe_cnt[batch_index]) + + + pred_type_label = return_dic['pred_role_list'][fe_index].squeeze() + pred_type_label = 
pred_type_label[batch_index].unsqueeze(0) + + gold_type_label = fe_type[batch_index][fe_index].unsqueeze(0) + type_loss += fe_type_criterion(pred_type_label, gold_type_label) + + + if fe_index >= fe_cnt[batch_index]: + break + + pred_head_label = return_dic['pred_head_list'][fe_index].squeeze() + pred_head_label = pred_head_label[batch_index].unsqueeze(0) + + gold_head_label = fe_head[batch_index][fe_index].unsqueeze(0) + # print(gold_head_label.size()) + # print(pred_head_label.size()) + head_loss += head_criterion(pred_head_label, gold_head_label) + + pred_tail_label = return_dic['pred_tail_list'][fe_index].squeeze() + pred_tail_label = pred_tail_label[batch_index].unsqueeze(0) + + gold_tail_label = fe_tail[batch_index][fe_index].unsqueeze(0) + tail_loss += tail_criterion(pred_tail_label, gold_tail_label) + + + # print(fe_cnt[batch_index]) + # head_loss /= int(fe_cnt[batch_index]) + # tail_loss /= int(fe_cnt[batch_index]) + # type_loss /= int(fe_cnt[batch_index]+1) + # + # head_loss_total+=head_loss + # tail_loss_total+=tail_loss + # type_loss_total+=type_loss + + loss = (0.1 * frame_loss + 0.3 * type_loss + 0.3 * head_loss + 0.3 * tail_loss) / (opt.batch_size) + # loss = (0.3 * head_loss + 0.3 * tail_loss) / (opt.b0.3 * atch_size) + loss.backward() + optimizer.step() + # loss+=frame_loss() + step += 1 + if step % 20 == 0: + print(" | batch loss: %.6f step = %d" % (loss.item(), step)) + # print('gold_frame_label = ',target_type) + # print('pred_frame_label = ',return_dic['pred_frame_action']) + + for index in range(len(target_type)): + if target_type[index] == return_dic['pred_frame_action'][0][index]: + cnt += 1 + # print('gold_fe_label = ', fe_type) + # print('pred_fe_label = ', return_dic['pred_role_action']) + # print('gold_head_label = ', fe_head) + # print('pred_head_label = ', return_dic['pred_head_action']) + # print('gold_tail_label = ', fe_tail) + # print('pred_tail_label = ', return_dic['pred_tail_action']) + sum_loss += loss.item() + + print('| epoch %d avg loss = %.6f' % (epoch, sum_loss / step)) + print('| epoch %d prec = %.6f' % (epoch, cnt / (opt.batch_size * step))) + + best_metrics=evaluate(opt,model,dev_dataset,best_metrics) + + + else: + evaluate(opt,model,test_dataset,show_case=True) + + + + From f480b41ea2ae9a435e23bbb804edf2ee4e001ba7 Mon Sep 17 00:00:00 2001 From: Dogblack <527742942@qq.com> Date: Wed, 17 Feb 2021 22:16:57 +0800 Subject: [PATCH 2/9] Delete train_syntax.py --- train_syntax.py | 228 ------------------------------------------------ 1 file changed, 228 deletions(-) delete mode 100644 train_syntax.py diff --git a/train_syntax.py b/train_syntax.py deleted file mode 100644 index ed30f85..0000000 --- a/train_syntax.py +++ /dev/null @@ -1,228 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -import os -import multiprocessing as mp - -from dataset2 import FrameNetDataset -from torch.utils.data import DataLoader -from utils import get_mask_from_index,seed_everything -from evaluate2 import Eval -from config import get_opt -from data_process2 import DataConfig -from model_syntax2 import Model - - -def evaluate(opt, model, dataset, best_metrics=None, show_case=False): - model.eval() - print('begin eval') - evaler = Eval(opt) - with torch.no_grad(): - test_dl = DataLoader( - dataset, - batch_size=opt.batch_size, - shuffle=True, - num_workers=0 - ) - - for batch in test_dl: - word_ids, lemma_ids, pos_ids,head_ids,rel_ids, lengths, attention_mask, target_head, target_tail, \ - target_type, fe_head, fe_tail, fe_type, 
fe_cnt, \ - fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids = batch - - return_dic = model(word_ids=word_ids, lemma_ids=lemma_ids, pos_ids=pos_ids,head_ids=head_ids, - rel_ids=rel_ids,lengths=lengths, - frame_idx=(target_head, target_tail), - token_type_ids=token_type_ids, - attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask) - - - evaler.metrics(batch_size=opt.batch_size, fe_cnt=fe_cnt, gold_fe_type=fe_type, gold_fe_head=fe_head, \ - gold_fe_tail=fe_tail, gold_frame_type=target_type, - pred_fe_type=return_dic['pred_role_action'], - pred_fe_head=return_dic['pred_head_action'], - pred_fe_tail=return_dic['pred_tail_action'], - pred_frame_type=return_dic['pred_frame_action']) - - - if show_case: - print('gold_fe_label = ', fe_type) - print('pred_fe_label = ', return_dic['pred_role_action']) - print('gold_head_label = ', fe_head) - print('pred_head_label = ', return_dic['pred_head_action']) - print('gold_tail_label = ', fe_tail) - print('pred_tail_label = ', return_dic['pred_tail_action']) - - metrics = evaler.calculate() - - - if best_metrics: - - if metrics[-1] > best_metrics: - best_metrics = metrics[-1] - - torch.save(model.state_dict(), opt.save_model_path) - - return best_metrics - - - - - - -if __name__ == '__main__': - - os.environ['CUDA_LAUNCH_BLOCKING'] = "1" - # os.environ['CUDA_LAUNCH_BLOCKING'] = 1 - - # bertconfig = BertConfig() - # print(bertconfig.CONFIG_PATH) - mp.set_start_method('spawn') - - opt = get_opt() - config = DataConfig(opt) - - if torch.cuda.is_available(): - device = torch.device(opt.cuda) - else: - device = torch.device('cpu') - - seed_everything(1116) - - epochs = opt.epochs - model = Model(opt, config) - model.to(device) - - optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr) - frame_criterion =nn.CrossEntropyLoss() - head_criterion = nn.CrossEntropyLoss() - tail_criterion = nn.CrossEntropyLoss() - fe_type_criterion = nn.CrossEntropyLoss() - - if os.path.exists(opt.save_model_path) is True: - model.load_state_dict(torch.load(opt.save_model_path)) - - #pretrain_dataset = FrameNetDataset(opt, config, config.exemplar_instance_dic, device) - train_dataset = FrameNetDataset(opt, config, config.train_instance_dic, device) - dev_dataset = FrameNetDataset(opt, config, config.dev_instance_dic, device) - test_dataset = FrameNetDataset(opt, config, config.test_instance_dic, device) - - if opt.mode == 'train': - for epoch in range(1, epochs): - train_dl = DataLoader( - train_dataset, - batch_size=opt.batch_size, - shuffle=True, - num_workers=0 - ) - model.train() - print('==========epochs= ' + str(epoch)) - step = 0 - sum_loss = 0 - best_metrics = -1 - cnt = 0 - for batch in train_dl: - optimizer.zero_grad() - loss = 0 - word_ids, lemma_ids, pos_ids,head_ids, rel_ids, lengths, attention_mask, target_head, target_tail, \ - target_type, fe_head, fe_tail, fe_type, fe_cnt, \ - fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids = batch - - return_dic = model(word_ids=word_ids, lemma_ids=lemma_ids, pos_ids=pos_ids, head_ids=head_ids, rel_ids=rel_ids - ,lengths=lengths, - frame_idx=(target_head, target_tail), - token_type_ids=token_type_ids, - attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask) - # print(return_dic) - - frame_loss = 0 - head_loss = 0 - tail_loss = 0 - type_loss = 0 - - for batch_index in range(opt.batch_size): - pred_frame_first = return_dic['pred_frame_list'][fe_cnt[batch_index]][batch_index].unsqueeze(0) - pred_frame_last = return_dic['pred_frame_list'][0][batch_index].unsqueeze(0) - - pred_frame_label = 
pred_frame_last - - gold_frame_label = target_type[batch_index] - # print(gold_frame_label.size()) - # print(pred_frame_label.size()) - # print(fe_head) - frame_loss += frame_criterion(pred_frame_label, gold_frame_label) - - - for fe_index in range(opt.fe_padding_num): - - pred_type_label = return_dic['pred_role_list'][fe_index].squeeze() - pred_type_label = pred_type_label[batch_index].unsqueeze(0) - - gold_type_label = fe_type[batch_index][fe_index].unsqueeze(0) - type_loss += fe_type_criterion(pred_type_label, gold_type_label) - - if fe_index >= fe_cnt[batch_index]: - break - - # print(fe_cnt[batch_index]) - - pred_head_label = return_dic['pred_head_list'][fe_index].squeeze() - pred_head_label = pred_head_label[batch_index].unsqueeze(0) - - gold_head_label = fe_head[batch_index][fe_index].unsqueeze(0) - # print(gold_head_label.size()) - # print(pred_head_label.size()) - head_loss += head_criterion(pred_head_label, gold_head_label) - - pred_tail_label = return_dic['pred_tail_list'][fe_index].squeeze() - pred_tail_label = pred_tail_label[batch_index].unsqueeze(0) - - gold_tail_label = fe_tail[batch_index][fe_index].unsqueeze(0) - tail_loss += tail_criterion(pred_tail_label, gold_tail_label) - - - - # print(fe_cnt[batch_index]) - # head_loss /= int(fe_cnt[batch_index]) - # tail_loss /= int(fe_cnt[batch_index]) - # type_loss /= int(fe_cnt[batch_index]+1) - # - # head_loss_total+=head_loss - # tail_loss_total+=tail_loss - # type_loss_total+=type_loss - - loss = (0.1 * frame_loss + 0.3 * type_loss + 0.3 * head_loss + 0.3 * tail_loss) / (opt.batch_size) - # loss = (0.3 * head_loss + 0.3 * tail_loss) / (opt.b0.3 * atch_size) - loss.backward() - optimizer.step() - # loss+=frame_loss() - step += 1 - if step % 20 == 0: - print(" | batch loss: %.6f step = %d" % (loss.item(), step)) - # print('gold_frame_label = ',target_type) - # print('pred_frame_label = ',return_dic['pred_frame_action']) - - for index in range(len(target_type)): - if target_type[index] == return_dic['pred_frame_action'][0][index]: - cnt += 1 - # print('gold_fe_label = ', fe_type) - # print('pred_fe_label = ', return_dic['pred_role_action']) - # print('gold_head_label = ', fe_head) - # print('pred_head_label = ', return_dic['pred_head_action']) - # print('gold_tail_label = ', fe_tail) - # print('pred_tail_label = ', return_dic['pred_tail_action']) - sum_loss += loss.item() - - print('| epoch %d avg loss = %.6f' % (epoch, sum_loss / step)) - print('| epoch %d prec = %.6f' % (epoch, cnt / (opt.batch_size * step))) - - best_metrics=evaluate(opt,model,dev_dataset,best_metrics) - - - else: - evaluate(opt,model,test_dataset,show_case=True) - - - - From c37d955ffccba210d7a8212533e2f74005731b97 Mon Sep 17 00:00:00 2001 From: Dogblack <527742942@qq.com> Date: Wed, 17 Feb 2021 22:17:11 +0800 Subject: [PATCH 3/9] Delete model_syntax2.py --- model_syntax2.py | 507 ----------------------------------------------- 1 file changed, 507 deletions(-) delete mode 100644 model_syntax2.py diff --git a/model_syntax2.py b/model_syntax2.py deleted file mode 100644 index c0bd5f5..0000000 --- a/model_syntax2.py +++ /dev/null @@ -1,507 +0,0 @@ -import numpy as np -from typing import List,Tuple -import os -import json - -import torch.nn as nn -import torch -import torch.nn.functional as F - -from utils import batched_index_select,get_mask_from_index,generate_perm_inv - - -class Mlp(nn.Module): - def __init__(self, input_size, output_size): - super(Mlp, self).__init__() - self.linear = nn.Sequential( - nn.Linear(input_size, input_size), - nn.Dropout(0.4), - 
nn.ReLU(inplace=True), - nn.Linear(input_size, output_size), - ) - - def forward(self, x): - out = self.linear(x) - return out - - -class Relu_Linear(nn.Module): - def __init__(self, input_size, output_size): - super(Relu_Linear, self).__init__() - self.linear = nn.Sequential( - nn.Linear(input_size, output_size), - nn.Dropout(0.4), - nn.ReLU(inplace=True), - ) - - def forward(self, x): - out = self.linear(x) - return out - - -class CnnNet(nn.Module): - def __init__(self, kernel_size, seq_length, input_size, output_size): - super(CnnNet, self).__init__() - self.seq_length = seq_length - self.output_size = output_size - self.kernel_size = kernel_size - - self.relu = nn.ReLU() - self.conv = nn.Conv1d(in_channels=input_size, out_channels=output_size, kernel_size=self.kernel_size - , padding=1) - self.mp = nn.MaxPool1d(kernel_size=self.seq_length) - - def forward(self, input_emb): - input_emb = input_emb.permute(0, 2, 1) - x = self.conv(input_emb) - output = self.mp(x).squeeze() - - return output - - -class PointerNet(nn.Module): - def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'): - super(PointerNet, self).__init__() - - assert attention_type in ('affine', 'dot_prod') - if attention_type == 'affine': - self.src_encoding_linear = Mlp(src_encoding_size, query_vec_size) - - self.src_linear = Mlp(src_encoding_size,src_encoding_size) - self.activate = nn.ReLU(inplace=True) - self.dropout = nn.Dropout(0.5) - - self.fc =nn.Linear(src_encoding_size*2,src_encoding_size, bias=True) - - self.attention_type = attention_type - - def forward(self, src_encodings, src_token_mask,query_vec,head_vec=None): - - # (batch_size, 1, src_sent_len, query_vec_size) - if self.attention_type == 'affine': - src_encod = self.src_encoding_linear(src_encodings).unsqueeze(1) - head_weights = self.src_linear(src_encodings).unsqueeze(1) - - # (batch_size, tgt_action_num, query_vec_size, 1) - if head_vec is not None: - src_encod = torch.cat([src_encod,head_weights],dim = -1) - q = torch.cat([head_vec, query_vec], dim=-1).permute(1, 0, 2).unsqueeze(3) - - - else: - q = query_vec.permute(1, 0, 2).unsqueeze(3) - - weights = torch.matmul(src_encod, q).squeeze(3) - ptr_weights = weights.permute(1, 0, 2) - - # if head_vec is not None: - # src_weights = torch.matmul(head_weights, q_h).squeeze(3) - # src_weights = src_weights.permute(1, 0, 2) - # ptr_weights = weights+src_weights - # - # else: - # ptr_weights = weights - - ptr_weights_masked = ptr_weights.clone().detach() - if src_token_mask is not None: - # (tgt_action_num, batch_size, src_sent_len) - src_token_mask=1-src_token_mask.byte() - src_token_mask = src_token_mask.unsqueeze(0).expand_as(ptr_weights) - # ptr_weights.data.masked_fill_(src_token_mask, -float('inf')) - ptr_weights_masked.data.masked_fill_(src_token_mask, -float('inf')) - - # ptr_weights =self.activate(ptr_weights) - - return ptr_weights,ptr_weights_masked - - -class Encoder(nn.Module): - def __init__(self, opt, config, word_embedding:nn.modules.sparse.Embedding, - lemma_embedding: nn.modules.sparse.Embedding): - super(Encoder, self).__init__() - self.opt =opt - self.hidden_size = opt.rnn_hidden_size - self.emb_size = opt.encoder_emb_size - self.rnn_input_size = self.emb_size*3+opt.pos_emb_size+opt.token_type_emb_size+\ - opt.sent_emb_size+opt.rel_emb_size - self.word_number = config.word_number - self.lemma_number = config.lemma_number - self.maxlen = opt.maxlen - - self.cnn = CnnNet(opt.kernel_size, seq_length=opt.maxlen, - 
input_size=self.emb_size*3+opt.pos_emb_size+opt.token_type_emb_size+opt.rel_emb_size, - output_size=opt.sent_emb_size) - - self.dropout = 0.2 - self.word_embedding = word_embedding - self.lemma_embedding = lemma_embedding - self.pos_embedding = nn.Embedding(config.pos_number, opt.pos_emb_size) - self.rel_embedding = nn.Embedding(config.rel_number,opt.rel_emb_size) - self.token_type_embedding = nn.Embedding(2, opt.token_type_emb_size) - self.cell_name = opt.cell_name - - # self.embedded_linear = nn.Linear(self.emb_size*2+opt.pos_emb_size+opt.token_type_emb_size+opt.sent_emb_size, - # self.rnn_input_size) - self.syntax_embedded_linear = nn.Linear(self.emb_size*2+opt.rel_emb_size+opt.token_type_emb_size, - self.rnn_input_size) - # self.output_combine_linear = nn.Linear(4*self.hidden_size, 2*self.hidden_size) - - self.target_linear = nn.Linear(2*self.hidden_size, 2*self.hidden_size) - - self.relu_linear = Relu_Linear(4*self.hidden_size+self.rnn_input_size, opt.decoder_emb_size) - - if self.cell_name == 'gru': - self.rnn = nn.GRU(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, - dropout=self.dropout,bidirectional=True, batch_first=True) - # self.syntax_rnn = nn.GRU(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, - # dropout=self.dropout,bidirectional=True, batch_first=True) - elif self.cell_name == 'lstm': - self.rnn = nn.LSTM(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, - dropout=self.dropout,bidirectional=True, batch_first=True) - # self.syntax_rnn = nn.LSTM(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, - # dropout=self.dropout,bidirectional=True, batch_first=True) - else: - print('cell_name error') - - def forward(self, word_input: torch.Tensor, lemma_input: torch.Tensor, pos_input: torch.Tensor, - head_input:torch.Tensor, rel_input:torch.Tensor, - lengths:torch.Tensor, frame_idx, token_type_ids=None, attention_mask=None): - - word_embedded = self.word_embedding(word_input) - lemma_embedded = self.lemma_embedding(lemma_input) - pos_embedded = self.pos_embedding(pos_input) - head_embedded = self.word_embedding(head_input) - rel_embedded = self.rel_embedding(rel_input) - token_type_ids = self.token_type_embedding(token_type_ids) - - embedded = torch.cat([word_embedded, lemma_embedded, pos_embedded,head_embedded,rel_embedded,token_type_ids], dim=-1) - sent_embedded = self.cnn(embedded) - - sent_embedded = sent_embedded.expand([self.opt.maxlen, self.opt.batch_size, self.opt.sent_emb_size]).permute(1, 0, 2) - embedded = torch.cat([embedded,sent_embedded], dim=-1) - #embedded = self.embedded_linear(embedded) - - # syntax embedding - # syntax_embedded = torch.cat([word_embedded,head_embedded,rel_embedded,token_type_ids],dim=-1) - # syntax_embedded = self.syntax_embedded_linear(syntax_embedded) - - lengths=lengths.squeeze() - # sorted before pack - l = lengths.cpu().numpy() - perm_idx = np.argsort(-l) - perm_idx_inv = generate_perm_inv(perm_idx) - - embedded = embedded[perm_idx] - # syntax_embedded = syntax_embedded[perm_idx] - - if lengths is not None: - rnn_input = nn.utils.rnn.pack_padded_sequence(embedded, lengths=lengths[perm_idx], - batch_first=True) - # syntax_rnn_input = nn.utils.rnn.pack_padded_sequence(syntax_embedded, lengths=lengths[perm_idx], - # batch_first=True) - output, hidden = self.rnn(rnn_input) - #syntax_output, syntax_hidden = self.rnn(syntax_rnn_input) - - if lengths is not None: - output, _ = nn.utils.rnn.pad_packed_sequence(output, total_length=self.maxlen, batch_first=True) - # 
syntax_output, _ = nn.utils.rnn.pad_packed_sequence(syntax_output, total_length=self.maxlen, batch_first=True) - - - # print(output.size()) - # print(hidden.size()) - - output = output[perm_idx_inv] - # syntax_output = syntax_output[perm_idx_inv] - - if self.cell_name == 'gru': - hidden = hidden[:, perm_idx_inv] - hidden = (lambda a: sum(a)/(2*self.opt.num_layers))(torch.split(hidden, 1, dim=0)) - - # syntax_hidden = syntax_hidden[:, perm_idx_inv] - # syntax_hidden = (lambda a: sum(a)/(2*self.opt.num_layers))(torch.split(syntax_hidden, 1, dim=0)) - # hidden = (hidden + syntax_hidden) / 2 - - elif self.cell_name == 'lstm': - hn0 = hidden[0][:, perm_idx_inv] - hn1 = hidden[1][:, perm_idx_inv] - # sy_hn0 = syntax_hidden[0][:, perm_idx_inv] - # sy_hn1 = syntax_hidden[1][:, perm_idx_inv] - hn = tuple([hn0,hn1]) - hidden = tuple(map(lambda state: sum(torch.split(state, 1, dim=0))/(2*self.opt.num_layers), hn)) - - - target_state_head = batched_index_select(target=output, indices=frame_idx[0]) - target_state_tail = batched_index_select(target=output, indices=frame_idx[1]) - target_state = (target_state_head + target_state_tail) / 2 - target_state = self.target_linear(target_state) - - target_emb_head = batched_index_select(target=embedded, indices=frame_idx[0]) - target_emb_tail = batched_index_select(target=embedded, indices=frame_idx[1]) - target_emb = (target_emb_head + target_emb_tail) / 2 - - attentional_target_state = local_attention(attention_mask=attention_mask, hidden_state=output, - frame_idx=frame_idx,target_state=target_state, - window_size=self.opt.window_size, - max_len=self.opt.maxlen) - - - target_state =torch.cat([target_state.squeeze(), attentional_target_state.squeeze(), target_emb.squeeze()], dim=-1) - target = self.relu_linear(target_state) - - # print(output.size()) - return output, hidden, target - - -class Decoder(nn.Module): - def __init__(self, opt, embedding_frozen=False): - super(Decoder, self).__init__() - - # rnn _init_ - self.opt = opt - self.cell_name = opt.cell_name - self.emb_size = opt.decoder_emb_size - self.hidden_size = opt.decoder_hidden_size - self.encoder_hidden_size = opt.rnn_hidden_size - - # decoder _init_ - self.decodelen = opt.fe_padding_num+1 - self.frame_embedding = nn.Embedding(opt.frame_number+1, self.emb_size) - self.frame_fc_layer =Mlp(self.emb_size, opt.frame_number+1) - self.role_embedding = nn.Embedding(opt.role_number+1, self.emb_size) - self.role_feature_layer = nn.Linear(2*self.emb_size, self.emb_size) - self.role_fc_layer = nn.Linear(self.hidden_size+self.emb_size, opt.role_number+1) - - self.head_fc_layer = Mlp(self.hidden_size+self.emb_size, self.hidden_size) - self.tail_fc_layer = Mlp(self.hidden_size+self.emb_size, self.hidden_size) - - self.span_fc_layer = Mlp(4 * self.encoder_hidden_size + self.emb_size, self.emb_size) - - self.next_input_fc_layer = Mlp(self.hidden_size+self.emb_size, self.emb_size) - - if embedding_frozen is True: - for param in self.frame_embedding.parameters(): - param.requires_grad = False - for param in self.role_embedding.parameters(): - param.requires_grad = False - - if self.cell_name == 'gru': - self.frame_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) - self.ent_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) - self.role_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) - if self.cell_name == 'lstm': - self.frame_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) - self.ent_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) - 
self.role_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) - - # pointer _init_ - self.ent_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) - self.head_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) - self.tail_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) - - def forward(self, encoder_output: torch.Tensor, encoder_state: torch.Tensor, target_state: torch.Tensor, - attention_mask: torch.Tensor, fe_mask=None, lu_mask=None): - pred_frame_list = [] - pred_head_list = [] - pred_tail_list = [] - pred_role_list = [] - - pred_frame_action = [] - - pred_head_action = [] - pred_tail_action = [] - pred_role_action = [] - - frame_decoder_state = encoder_state - role_decoder_state = encoder_state - - input = target_state - # print(input.size()) - - span_mask = attention_mask.clone() - for t in range(self.decodelen): - # frame pred - output, frame_decoder_state = self.decode_step(self.frame_rnn, input=input, - decoder_state=frame_decoder_state) - - pred_frame_weight = self.frame_fc_layer(target_state.squeeze()) - pred_frame_weight_masked = pred_frame_weight.clone().detach() - - if lu_mask is not None: - LU_mask = 1-lu_mask - pred_frame_weight_masked.data.masked_fill_(LU_mask.byte(), -float('inf')) - - pred_frame_indices = torch.argmax(pred_frame_weight_masked.squeeze(), dim=-1).squeeze() - - pred_frame_list.append(pred_frame_weight) - pred_frame_action.append(pred_frame_indices) - - frame_emb = self.frame_embedding(pred_frame_indices) - - head_input = self.head_fc_layer(torch.cat([output.squeeze(), frame_emb], dim=-1)) - tail_input = self.tail_fc_layer(torch.cat([output.squeeze(), frame_emb], dim=-1)) - - head_pointer_weight, head_pointer_weight_masked = self.head_pointer(src_encodings=encoder_output, - src_token_mask=span_mask, - query_vec=head_input.view(1, self.opt.batch_size, -1)) - - head_indices = torch.argmax(head_pointer_weight_masked.squeeze(), dim=-1).squeeze() - head_target = batched_index_select(target=encoder_output, indices=head_indices.squeeze()) - head_mask = head_mask_update(span_mask, head_indices=head_indices, max_len=self.opt.maxlen) - - tail_pointer_weight, tail_pointer_weight_masked = self.tail_pointer(src_encodings=encoder_output, - src_token_mask=head_mask, - query_vec=tail_input.view(1, self.opt.batch_size,-1), - head_vec=head_target.view(1,self.opt.batch_size,-1)) - - tail_indices = torch.argmax(tail_pointer_weight_masked.squeeze(), dim=-1).squeeze() - # tail_target = batched_index_select(target=bert_hidden_state,indices=tail_indices.squeeze()) - - span_mask = span_mask_update(attention_mask=span_mask, head_indices=head_indices, - tail_indices=tail_indices, max_len=self.opt.maxlen) - - pred_head_list.append(head_pointer_weight) - pred_tail_list.append(tail_pointer_weight) - pred_head_action.append(head_indices) - pred_tail_action.append(tail_indices) - - # role pred - - # print(ent_target.size()) - # print(head_target.size()) - # print(tail_target.size()) - # print(bert_hidden_state.size()) - # print(tail_pointer_weight.size()) - # print(tail_indices.size()) - # print(output.size()) - - # next step - # head_target = batched_index_select(target=bert_hidden_state, indices=head_indices.squeeze()) - - tail_target = batched_index_select(target=encoder_output, indices=tail_indices.squeeze()) - - # head_context =local_attention(attention_mask=attention_mask, hidden_state=encoder_output, - # 
frame_idx=(head_indices, tail_indices), target_state=head_target, - # window_size=0, max_len=self.opt.maxlen) - # - # tail_context =local_attention(attention_mask=attention_mask, hidden_state=encoder_output, - # frame_idx=(head_indices, tail_indices), target_state=tail_target, - # window_size=0, max_len=self.opt.maxlen) - - span_input = self.span_fc_layer(torch.cat([head_target+tail_target, head_target-tail_target, frame_emb], dim=-1)).unsqueeze(1) - - output,role_decoder_state = self.decode_step(self.role_rnn, input=span_input, - decoder_state=role_decoder_state) - role_target = self.role_fc_layer(torch.cat([output,span_input],dim=-1)) - role_target_masked = role_target.squeeze().clone().detach() - if fe_mask is not None : - FE_mask = 1-fe_mask - # print(FE_mask.size()) - role_target_masked.data.masked_fill_(FE_mask.byte(), -float('inf')) - - role_indices = torch.argmax(role_target_masked.squeeze(), dim=-1).squeeze() - role_emb = self.role_embedding(role_indices) - - pred_role_list.append(role_target) - pred_role_action.append(role_indices) - - # next step - next_input =torch.cat([output, role_emb.unsqueeze(1)], dim=-1) - input = self.next_input_fc_layer(next_input) - - - #return - return pred_frame_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action,\ - pred_head_action, pred_tail_action, pred_role_action - - def decode_step(self, rnn_cell: nn.modules, input: torch.Tensor, decoder_state: torch.Tensor): - - output, state = rnn_cell(input.view(-1, 1, self.emb_size), decoder_state) - - return output, state - - -class Model(nn.Module): - def __init__(self, opt, config, load_emb=True): - super(Model, self).__init__() - self.word_vectors = config.word_vectors - self.lemma_vectors = config.lemma_vectors - self.word_embedding = nn.Embedding(config.word_number+1, opt.encoder_emb_size) - self.lemma_embedding = nn.Embedding(config.lemma_number+1, opt.encoder_emb_size) - - if load_emb: - self.load_pretrain_emb() - - self.encoder = Encoder(opt, config, self.word_embedding, self.lemma_embedding) - self.decoder = Decoder(opt) - - def load_pretrain_emb(self): - self.word_embedding.weight.data.copy_(torch.from_numpy(self.word_vectors)) - self.lemma_embedding.weight.data.copy_(torch.from_numpy(self.lemma_vectors)) - - def forward(self, word_ids, lemma_ids, pos_ids, head_ids, rel_ids, lengths, frame_idx, fe_mask=None, lu_mask=None, - frame_len=None, token_type_ids=None, attention_mask=None): - encoder_output, encoder_state, target_state = self.encoder(word_input=word_ids, lemma_input=lemma_ids, - pos_input=pos_ids, head_input=head_ids, - rel_input=rel_ids, lengths=lengths, - frame_idx=frame_idx, - token_type_ids=token_type_ids, - attention_mask=attention_mask) - - pred_frame_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action, \ - pred_head_action, pred_tail_action, pred_role_action = self.decoder(encoder_output=encoder_output, - encoder_state=encoder_state, - target_state=target_state, - attention_mask=attention_mask, - fe_mask=fe_mask, lu_mask=lu_mask) - - # return pred_frame_list, pred_ent_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action, \ - # pred_ent_action, pred_head_action, pred_tail_action, pred_role_action - - return { - 'pred_frame_list' : pred_frame_list, - 'pred_head_list' : pred_head_list, - 'pred_tail_list' : pred_tail_list, - 'pred_role_list' : pred_role_list, - 'pred_frame_action' : pred_frame_action, - 'pred_head_action' : pred_head_action, - 'pred_tail_action' : pred_tail_action, - 'pred_role_action' : pred_role_action - } - 
- -def head_mask_update(attention_mask: torch.Tensor, head_indices: torch.Tensor, max_len): - indices=head_indices - indices_mask=1-get_mask_from_index(indices, max_len) - mask = torch.mul(attention_mask, indices_mask.long()) - - return mask - - -def span_mask_update(attention_mask: torch.Tensor, head_indices: torch.Tensor, tail_indices: torch.Tensor, max_len): - tail = tail_indices + 1 - head_indices_mask = get_mask_from_index(head_indices, max_len) - tail_indices_mask = get_mask_from_index(tail, max_len) - span_indices_mask = tail_indices_mask - head_indices_mask - span_indices_mask = 1 - span_indices_mask - mask = torch.mul(attention_mask, span_indices_mask.long()) - - return mask - - -def local_attention(attention_mask: torch.Tensor, hidden_state: torch.Tensor, frame_idx, - target_state: torch.Tensor, window_size: int, max_len): - - q = target_state.squeeze().unsqueeze(2) - context_att = torch.bmm(hidden_state, q).squeeze() - head = frame_idx[0]-window_size - tail = frame_idx[1]+window_size - mask = span_mask_update(attention_mask=attention_mask, head_indices=head.squeeze(), - tail_indices=tail.squeeze(), max_len=max_len) - context_att = context_att.masked_fill_(mask.byte(), -float('inf')) - context_att = F.softmax(context_att, dim=-1) - attentional_hidden_state = torch.bmm(hidden_state.permute(0, 2, 1), context_att.unsqueeze(2)).squeeze() - - return attentional_hidden_state - - - From 73fa9d6779c69723c0ed9dad312e488e8f5ba435 Mon Sep 17 00:00:00 2001 From: Dogblack <527742942@qq.com> Date: Wed, 17 Feb 2021 22:17:21 +0800 Subject: [PATCH 4/9] Delete dataset2.py --- dataset2.py | 229 ---------------------------------------------------- 1 file changed, 229 deletions(-) delete mode 100644 dataset2.py diff --git a/dataset2.py b/dataset2.py deleted file mode 100644 index 2f29943..0000000 --- a/dataset2.py +++ /dev/null @@ -1,229 +0,0 @@ - -import torch -import numpy as np - -from torch.utils.data import Dataset -from config import get_opt -from data_process2 import get_frame_tabel, get_fe_tabel, get_fe_list, get_lu_list,DataConfig -from utils import get_mask_from_index - -class FrameNetDataset(Dataset): - - def __init__(self, opt, config, data_dic, device): - super(FrameNetDataset, self).__init__() - print('begin load data') - self.data_dic = data_dic - self.fe_id_to_label, self.fe_name_to_label, self.fe_name_to_id, self.fe_id_to_type = get_fe_tabel('parsed-v1.5/', 'FE.csv') - self.frame_id_to_label, self.frame_name_to_label, self.frame_name_to_id = get_frame_tabel('parsed-v1.5/', 'frame.csv') - - self.word_index = config.word_index - self.lemma_index = config.lemma_index - self.pos_index = config.pos_index - self.rel_index = config.rel_index - - self.fe_num = len(self.fe_id_to_label) - self.frame_num = len(self.frame_id_to_label) - self.batch_size = opt.batch_size - print(self.fe_num) - print(self.frame_num) - self.dataset_len = len(self.data_dic) - - self.fe_mask_list = get_fe_list('parsed-v1.5/', self.fe_num, self.fe_id_to_label) - self.lu_list, self.lu_id_to_name,\ - self.lu_name_to_id = get_lu_list('parsed-v1.5/', - self.frame_num, self.fe_num, - self.frame_id_to_label, - self.fe_mask_list) - - self.word_ids = [] - self.lemma_ids = [] - self.pos_ids = [] - self.head_ids = [] - self.rel_ids = [] - - self.lengths = [] - self.mask = [] - self.target_head = [] - self.target_tail = [] - self.target_type = [] - self.fe_head = [] - self.fe_tail = [] - self.fe_type = [] - self.fe_coretype = [] - self.sent_length = [] - self.fe_cnt = [] - self.fe_cnt_with_padding =[] - self.fe_mask = [] - 
self.lu_mask = [] - self.token_type_ids = [] - - self.device = device - self.oov_frame = 0 - self.long_span = 0 - self.error_span = 0 - self.fe_coretype_table = {} - - for idx in self.fe_id_to_type.keys(): - if self.fe_id_to_type[idx] == 'Core': - self.fe_coretype_table[idx] = 1 - else: - self.fe_coretype_table[idx] = 0 - - for key in self.data_dic.keys(): - self.pre_process(key, opt) - print('load data finish') - print('oov frame = ', self.oov_frame) - print('long_span = ', self.long_span) - print('dataset_len = ', self.dataset_len) - - def __len__(self): - self.dataset_len = int(self.dataset_len / self.batch_size) * self.batch_size - return self.dataset_len - - def __getitem__(self, item): - word_ids = torch.Tensor(self.word_ids[item]).long().to(self.device) - lemma_ids = torch.Tensor(self.lemma_ids[item]).long().to(self.device) - pos_ids = torch.Tensor(self.pos_ids[item]).long().to(self.device) - head_ids = torch.Tensor(self.head_ids[item]).long().to(self.device) - rel_ids = torch.Tensor(self.rel_ids[item]).long().to(self.device) - lengths = torch.Tensor([self.lengths[item]]).long().to(self.device) - mask = self.mask[item].long().to(self.device) - target_head = torch.Tensor([self.target_head[item]]).long().to(self.device) - target_tail = torch.Tensor([self.target_tail[item]]).long().to(self.device) - target_type = torch.Tensor([self.target_type[item]]).long().to(self.device) - fe_head = torch.Tensor(self.fe_head[item]).long().to(self.device) - fe_tail = torch.Tensor(self.fe_tail[item]).long().to(self.device) - fe_type = torch.Tensor(self.fe_type[item]).long().to(self.device) - fe_cnt = torch.Tensor([self.fe_cnt[item]]).long().to(self.device) - fe_cnt_with_padding = torch.Tensor([self.fe_cnt_with_padding[item]]).long().to(self.device) - fe_mask = torch.Tensor(self.fe_mask[item]).long().to(self.device) - lu_mask = torch.Tensor(self.lu_mask[item]).long().to(self.device) - token_type_ids = torch.Tensor(self.token_type_ids[item]).long().to(self.device) - - return (word_ids, lemma_ids, pos_ids,head_ids, rel_ids, lengths, mask, target_head, target_tail, target_type, - fe_head, fe_tail, fe_type, fe_cnt, fe_cnt_with_padding, - fe_mask, lu_mask, token_type_ids) - - def pre_process(self, key, opt): - if self.data_dic[key]['target_type'] not in self.frame_name_to_label: - self.oov_frame += 1 - self.dataset_len -= 1 - return - - target_id = self.frame_name_to_id[self.data_dic[key]['target_type']] - self.long_span += self.remove_error_span(key, self.data_dic[key]['span_start'], - self.data_dic[key]['span_end'], self.data_dic[key]['span_type'], target_id, 20) - - word_ids = [self.word_index[word] for word in self.data_dic[key]['word_list']] - lemma_ids = [self.lemma_index[lemma] for lemma in self.data_dic[key]['lemma_list']] - pos_ids = [self.pos_index[pos] for pos in self.data_dic[key]['pos_list']] - head_ids = [self.word_index[head] for head in self.data_dic[key]['dep_list'][0]] - rel_ids = [self.rel_index[rel] for rel in self.data_dic[key]['dep_list'][1]] - - self.word_ids.append(word_ids) - self.lemma_ids.append(lemma_ids) - self.pos_ids.append(pos_ids) - self.head_ids.append(head_ids) - self.rel_ids.append(rel_ids) - self.lengths.append(self.data_dic[key]['length']) - - # self.mask.append(self.data_dic[key]['attention_mask']) - self.target_head.append(self.data_dic[key]['target_idx'][0]) - self.target_tail.append(self.data_dic[key]['target_idx'][1]) - - mask = get_mask_from_index(torch.Tensor([int(self.data_dic[key]['length'])]), opt.maxlen).squeeze() - self.mask.append(mask) - - token_type_ids = 
build_token_type_ids(self.data_dic[key]['target_idx'][0], self.data_dic[key]['target_idx'][1], opt.maxlen) - self.token_type_ids.append(token_type_ids) - - self.target_type.append(self.frame_name_to_label[self.data_dic[key]['target_type']]) - - # print(self.frame_tabel[self.fe_data[key]['frame_ID']]) - - if self.data_dic[key]['length'] <= opt.maxlen: - sent_length = self.data_dic[key]['length'] - else: - sent_length = opt.maxlen - self.sent_length.append(sent_length) - - lu_name = self.data_dic[key]['lu'] - self.lu_mask.append(self.lu_list[lu_name]['lu_mask']) - self.fe_mask.append(self.lu_list[lu_name]['fe_mask']) - - fe_head = self.data_dic[key]['span_start'] - fe_tail = self.data_dic[key]['span_end'] - - self.fe_cnt.append(min(len(fe_head), opt.fe_padding_num)) - self.fe_cnt_with_padding.append(min(len(fe_head)+1, opt.fe_padding_num)) - - while len(fe_head) < opt.fe_padding_num: - fe_head.append(min(sent_length-1, opt.maxlen-1)) - - while len(fe_tail) < opt.fe_padding_num: - fe_tail.append(min(sent_length-1,opt.maxlen-1)) - - self.fe_head.append(fe_head[0:opt.fe_padding_num]) - self.fe_tail.append(fe_tail[0:opt.fe_padding_num]) - - fe_type = [self.fe_name_to_label[(item, target_id)] for item in self.data_dic[key]['span_type']] - - while len(fe_type) < opt.fe_padding_num: - fe_type.append(self.fe_num) - # fe_coretype.append('0') - - self.fe_type.append(fe_type[0:opt.fe_padding_num]) - - def remove_error_span(self, key, fe_head_list, fe_tail_list, fe_type_list, target_id, span_maxlen): - indices = [] - for index in range(len(fe_head_list)): - if fe_tail_list[index] - fe_head_list[index] >= span_maxlen: - indices.append(index) - elif fe_tail_list[index] < fe_head_list[index]: - indices.append(index) - continue - - elif (fe_type_list[index], target_id) not in self.fe_name_to_label: - indices.append(index) - - else: - for i in range(index): - if i not in indices: - if fe_head_list[index] >= fe_head_list[i] and fe_head_list[index] <= fe_tail_list[i]: - indices.append(index) - break - - elif fe_tail_list[index] >= fe_head_list[i] and fe_tail_list[index] <= fe_tail_list[i]: - indices.append(index) - break - else: - continue - - fe_head_list_filter = [i for j, i in enumerate(fe_head_list) if j not in indices] - fe_tail_list_filter = [i for j, i in enumerate(fe_tail_list) if j not in indices] - fe_type_list_filter = [i for j, i in enumerate(fe_type_list) if j not in indices] - self.data_dic[key]['span_start'] = fe_head_list_filter - self.data_dic[key]['span_end'] = fe_tail_list_filter - self.data_dic[key]['span_type'] = fe_type_list_filter - - return len(indices) - - -def build_token_type_ids(target_head, target_tail, maxlen): - token_type_ids = [0]*maxlen - token_type_ids[target_head] = 1 - token_type_ids[target_tail] = 1 - # token_type_ids[target_head:target_tail+1] = [1]*(target_tail+1-target_head) - - return token_type_ids - - -if __name__ == '__main__': - opt = get_opt() - config = DataConfig(opt) - if torch.cuda.is_available(): - device = torch.device(opt.cuda) - else: - device = torch.device('cpu') - dataset = FrameNetDataset(opt, config, config.test_instance_dic, device) - print(dataset.error_span) From 409721fad94205f82ca34b6786c9dc47d9e1d7b1 Mon Sep 17 00:00:00 2001 From: Dogblack <527742942@qq.com> Date: Wed, 17 Feb 2021 22:17:32 +0800 Subject: [PATCH 5/9] Delete data_process2.py --- data_process2.py | 376 ----------------------------------------------- 1 file changed, 376 deletions(-) delete mode 100644 data_process2.py diff --git a/data_process2.py b/data_process2.py deleted file 
mode 100644 index 912d4e5..0000000 --- a/data_process2.py +++ /dev/null @@ -1,376 +0,0 @@ -import numpy as np -import pandas as pd -import torch - -from config import get_opt -from utils import get_mask_from_index -from nltk.parse.stanford import StanfordDependencyParser - -def load_data(path): - f = open(path, 'r', encoding='utf-8') - lines = f.readlines() - return lines - - -def instance_process(lines, maxlen): - instance_dic = {} - - parser = StanfordDependencyParser(r".\stanford-parser-full-2015-12-09\stanford-parser.jar", - r".\stanford-parser-full-2015-12-09\stanford-parser-3.6.0-models.jar" - ) - cnt = 0 - find = False - word_list_total = [] - for line in lines: - if line[0:3] == '# i': - word_list = [] - lemma_list = [] - pos_list = [] - target_idx = [-1, -1] - span_start = [] - span_end = [] - span_type = [] - length = 0 - - elif line[0:3] == '# e': - instance_dic.setdefault((sent_id, target_type, cnt), {}) - instance_dic[(sent_id, target_type, cnt)]['dep_list'] = dep_parsing(word_list, maxlen, parser) - instance_dic[(sent_id, target_type, cnt)]['word_list'] = padding_sentence(word_list, maxlen) - instance_dic[(sent_id, target_type, cnt)]['lemma_list'] = padding_sentence(lemma_list, maxlen) - instance_dic[(sent_id, target_type, cnt)]['pos_list'] = padding_sentence(pos_list, maxlen) - instance_dic[(sent_id, target_type, cnt)]['sent_id'] = sent_id - - word_list_total.append(word_list) - # add 'eos' - instance_dic[(sent_id, target_type, cnt)]['length'] = int(length)+1 - # instance_dic[(sent_id, target_type, cnt)]['attention_mask'] = get_mask_from_index(sequence_lengths=torch.Tensor([int(length)+1]), max_length=maxlen).squeeze() - - instance_dic[(sent_id, target_type, cnt)]['target_type'] = target_type - instance_dic[(sent_id, target_type, cnt)]['lu'] = lu - instance_dic[(sent_id, target_type, cnt)]['target_idx'] = target_idx - - instance_dic[(sent_id, target_type, cnt)]['span_start'] = span_start - instance_dic[(sent_id, target_type, cnt)]['span_end'] = span_end - instance_dic[(sent_id, target_type, cnt)]['span_type'] = span_type - print(cnt) - cnt += 1 - elif line == '\n': - continue - - else: - data_list = line.split('\t') - word_list.append(data_list[1]) - lemma_list.append(data_list[3]) - pos_list.append(data_list[5]) - sent_id = data_list[6] - length = data_list[0] - - if data_list[12] != '_' and data_list[13] != '_': - lu = data_list[12] - - target_type = data_list[13] - if target_idx == [-1, -1]: - target_idx = [int(data_list[0])-1, int(data_list[0])-1] - else: - target_idx[1] =int(data_list[0]) - 1 - - if data_list[14] != '_': - - fe = data_list[14].split('-') - - if fe[0] == 'B' and find is False: - span_start.append(int(data_list[0]) - 1) - find = True - - elif fe[0] == 'O': - span_end.append(int(data_list[0]) - 1) - span_type.append(fe[-1].replace('\n', '')) - find = False - - elif fe[0] == 'S': - span_start.append(int(data_list[0]) - 1) - span_end.append(int(data_list[0]) - 1) - span_type.append(fe[-1].replace('\n', '')) - - return instance_dic - -def dep_parsing(word_list: list,maxlen: list,parser): - res = list(parser.parse(word_list)) - sent = res[0].to_conll(4).split('\n')[:-1] - #['the', 'DT', '4', 'det'] - line = [line.split('\t') for line in sent] - head_list = [] - rel_list = [] - - distance = 0 - - #alignment - for index in range(len(word_list)-1): - #end stopwords - if index-distance >= len(line): - head_list.append('#') - rel_list.append('#') - distance+=1 - - elif word_list[index]!=line[index-distance][0]: - head_list.append('#') - rel_list.append('#') - 
distance+=1 - else: - rel_list.append(line[index-distance][3]) - - if line[index-distance][3] != 'root': - head_list.append(word_list[int(line[index-distance][2]) - 1]) - else: - head_list.append(word_list[index]) - - head_list.append('eos') - rel_list.append('eos') - - while len(head_list) < maxlen: - head_list.append('0') - rel_list.append('0') - - return (head_list,rel_list) - - -def padding_sentence(sentence: list,maxlen: int): - sentence.append('eos') - while len(sentence) < maxlen: - sentence.append('0') - - return sentence - -class DataConfig: - def __init__(self,opt): - exemplar_lines = load_data('fn1.5/conll/exemplar') - train_lines = load_data('fn1.5/conll/train') - dev_lines = load_data('fn1.5/conll/dev') - test_lines = load_data('fn1.5/conll/test') - - - self.emb_file_path = opt.emb_file_path - self.maxlen = opt.maxlen - - if opt.load_instance_dic: - self.train_instance_dic = np.load(opt.train_instance_path, allow_pickle=True).item() - self.dev_instance_dic = np.load(opt.dev_instance_path, allow_pickle=True).item() - self.test_instance_dic = np.load(opt.test_instance_path, allow_pickle=True).item() - - else: - print('begin parsing') - self.exemplar_instance_dic = instance_process(lines=exemplar_lines,maxlen=self.maxlen) - print('exemplar_instance_dic finish') - self.train_instance_dic = instance_process(lines=train_lines,maxlen=self.maxlen) - np.save('train_instance_dic', self.train_instance_dic) - print('train_instance_dic finish') - self.dev_instance_dic = instance_process(lines=dev_lines,maxlen=self.maxlen) - np.save('dev_instance_dic', self.dev_instance_dic) - print('dev_instance_dic finish') - self.test_instance_dic = instance_process(lines=test_lines,maxlen=self.maxlen) - np.save('test_instance_dic', self.test_instance_dic) - print('test_instance_dic finish') - - - self.word_index = {} - self.lemma_index = {} - self.pos_index = {} - self.rel_index = {} - - self.word_number = 0 - self.lemma_number = 0 - self.pos_number = 0 - self.rel_number = 0 - - #self.build_word_index(self.exemplar_instance_dic) - self.build_word_index(self.train_instance_dic) - self.build_word_index(self.dev_instance_dic) - self.build_word_index(self.test_instance_dic) - - # add # for parsing sign - self.word_index['#']=self.word_number - self.word_number+=1 - - self.emb_index = self.build_emb_index(self.emb_file_path) - - self.word_vectors = self.get_embedding_weight(self.emb_index, self.word_index, self.word_number) - self.lemma_vectors = self.get_embedding_weight(self.emb_index, self.lemma_index, self.lemma_number) - - def build_word_index(self, dic): - for key in dic.keys(): - word_list =dic[key]['word_list'] - lemma_list = dic[key]['lemma_list'] - pos_list = dic[key]['pos_list'] - rel_list = dic[key]['dep_list'][1] - - # print(row) - for word in word_list: - if word not in self.word_index.keys(): - self.word_index[word]=self.word_number - self.word_number += 1 - - for lemma in lemma_list: - if lemma not in self.lemma_index.keys(): - self.lemma_index[lemma]=self.lemma_number - self.lemma_number += 1 - - for pos in pos_list: - if pos not in self.pos_index.keys(): - self.pos_index[pos] = self.pos_number - self.pos_number += 1 - - for rel in rel_list: - if rel not in self.rel_index.keys(): - self.rel_index[rel] = self.rel_number - self.rel_number += 1 - - def build_emb_index(self, file_path): - data = open(file_path, 'r', encoding='utf-8') - emb_index = {} - for items in data: - item = items.split() - word = item[0] - weight = np.asarray(item[1:], dtype='float32') - emb_index[word] = weight - - return 
emb_index - - def get_embedding_weight(self,embed_dict, words_dict, words_count, dim=200): - - exact_count = 0 - fuzzy_count = 0 - oov_count = 0 - print("loading pre_train embedding by avg for out of vocabulary.") - embeddings = np.zeros((int(words_count) + 1, int(dim))) - inword_list = {} - for word in words_dict: - if word in embed_dict: - embeddings[words_dict[word]] = embed_dict[word] - inword_list[words_dict[word]] = 1 - # 准确匹配 - exact_count += 1 - elif word.lower() in embed_dict: - embeddings[words_dict[word]] = embed_dict[word.lower()] - inword_list[words_dict[word]] = 1 - # 模糊匹配 - fuzzy_count += 1 - else: - # 未登录词 - oov_count += 1 - # print(word) - # 对已经找到的词向量平均化 - sum_col = np.sum(embeddings, axis=0) / len(inword_list) # avg - sum_col /= np.std(sum_col) - for i in range(words_count): - if i not in inword_list: - embeddings[i] = sum_col - - embeddings[int(words_count)] = [0] * dim - final_embed = np.array(embeddings) - # print('exact_count: ',exact_count) - # print('fuzzy_count: ', fuzzy_count) - # print('oov_count: ', oov_count) - return final_embed - - -def load_data_pd(dataset_path,file): - # df=csv.reader(open(dataset_path+file,encoding='utf-8')) - # df = json.load(open(file_path,encoding='utf-8')) - df=pd.read_csv(dataset_path+file, header=0, encoding='utf-8') - return df - - -def get_frame_tabel(path, file): - data = load_data_pd(path, file) - - frame_id_to_label = {} - frame_name_to_label = {} - frame_name_to_id = {} - data_index = 0 - for idx in range(len(data['ID'])): - if data['ID'][idx] not in frame_id_to_label: - frame_id_to_label[data['ID'][idx]] = data_index - frame_name_to_label[data['Name'][idx]] = data_index - frame_name_to_id[data['Name'][idx]] = data['ID'][idx] - - data_index += 1 - - return frame_id_to_label, frame_name_to_label, frame_name_to_id - - -def get_fe_tabel(path, file): - data = load_data_pd(path, file) - - fe_id_to_label = {} - fe_name_to_label = {} - fe_name_to_id = {} - fe_id_to_type = {} - - data_index = 0 - for idx in range(len(data['ID'])): - if data['ID'][idx] not in fe_id_to_label: - fe_id_to_label[data['ID'][idx]] = data_index - fe_name_to_label[(data['Name'][idx], data['FrameID'][idx])] = data_index - fe_name_to_id[(data['Name'][idx], data['FrameID'][idx])] = data['ID'][idx] - fe_id_to_type[data['ID'][idx]] = data['CoreType'][idx] - - data_index += 1 - - return fe_id_to_label, fe_name_to_label, fe_name_to_id, fe_id_to_type - - -def get_fe_list(path, fe_num, fe_table, file='FE.csv'): - fe_dt = load_data_pd(path, file) - fe_mask_list = {} - - print('begin get fe list') - for idx in range(len(fe_dt['FrameID'])): - fe_mask_list.setdefault(fe_dt['FrameID'][idx], [0]*(fe_num+1)) - # fe_mask_list[fe_dt['FrameID'][idx]].setdefault('fe_mask', [0]*(fe_num+1)) - fe_mask_list[fe_dt['FrameID'][idx]][fe_table[fe_dt['ID'][idx]]] = 1 - - # for key in fe_list.keys(): - # fe_list[key]['fe_mask'][fe_num] = 1 - - return fe_mask_list - - -def get_lu_list(path, lu_num, fe_num, frame_id_to_label, fe_mask_list, file='LU.csv'): - lu_dt = load_data_pd(path, file) - lu_list = {} - lu_id_to_name = {} - lu_name_to_id = {} - #lu_name_to_felist = {} - - for idx in range(len(lu_dt['ID'])): - lu_name = lu_dt['Name'][idx] - lu_list.setdefault(lu_name, {}) - - lu_list[lu_name].setdefault('fe_mask', [0]*(fe_num+1)) - lu_list[lu_name]['fe_mask'] = list(map(lambda x: x[0]+x[1], zip(lu_list[lu_name]['fe_mask'], - fe_mask_list[lu_dt['FrameID'][idx]]))) - - lu_list[lu_name].setdefault('lu_mask', [0]*(lu_num+1)) - lu_list[lu_name]['lu_mask'][frame_id_to_label[lu_dt['FrameID'][idx]]] 
= 1 - - lu_id_to_name[lu_dt['ID'][idx]] = lu_name - lu_name_to_id[(lu_name, lu_dt['FrameID'][idx])] = lu_dt['ID'][idx] - - for key in lu_list.keys(): - #lu_list[key]['lu_mask'][lu_num] = 1 - lu_list[key]['fe_mask'][fe_num] = 1 - - return lu_list, lu_id_to_name, lu_name_to_id - - -if __name__ == '__main__': - opt = get_opt() - config = DataConfig(opt) - print(config.word_vectors) - print(config.lemma_number) - print(config.word_number) - print(config.pos_number) - print(config.dep_number) - - - From f7ef7d695981160737e44864366db630eab5fc22 Mon Sep 17 00:00:00 2001 From: Dogblack <527742942@qq.com> Date: Wed, 17 Feb 2021 22:17:41 +0800 Subject: [PATCH 6/9] Delete evaluate2.py --- evaluate2.py | 83 ---------------------------------------------------- 1 file changed, 83 deletions(-) delete mode 100644 evaluate2.py diff --git a/evaluate2.py b/evaluate2.py deleted file mode 100644 index 69ec0da..0000000 --- a/evaluate2.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch - -class Eval(): - def __init__(self,opt): - self.fe_TP = 0 - self.fe_TP_FP = 0 - self.fe_TP_FN = 0 - - self.frame_cnt = 0 - self.frame_acc = 0 - - self.opt=opt - def metrics(self,batch_size:int,fe_cnt:torch.Tensor,\ - gold_fe_type:torch.Tensor,gold_fe_head:torch.Tensor,\ - gold_fe_tail:torch.Tensor,gold_frame_type:torch.Tensor,\ - pred_fe_type:torch.Tensor,pred_fe_head:torch.Tensor,\ - pred_fe_tail:torch.Tensor,pred_frame_type:torch.Tensor, - ): - - self.frame_cnt+=batch_size - for batch_index in range(batch_size): - #caculate frame acc - # print(gold_frame_type[batch_index]) - # print(pred_frame_type[batch_index]) - if gold_frame_type[batch_index] == pred_frame_type[0][batch_index]: - self.frame_acc += 1 - - #update tp_fn - self.fe_TP_FN+=int(fe_cnt[batch_index]) - # preprocess error - for idx in range(fe_cnt[batch_index]): - if gold_fe_head[batch_index][idx] > gold_fe_tail[batch_index][idx]: - self.fe_TP_FN -= 1 - - - #gold_fe_list = gold_fe_type.cpu().numpy().tolist() - gold_tail_list =gold_fe_tail.cpu().numpy().tolist() - #update fe_tp and fe_TP_FP - for fe_index in range(self.opt.fe_padding_num): - if pred_fe_type[fe_index][batch_index] == self.opt.role_number: - break - - - #update fe_tp - if pred_fe_tail[fe_index][batch_index] in gold_tail_list[batch_index]: - idx = gold_tail_list[batch_index].index(pred_fe_tail[fe_index][batch_index]) - - - if pred_fe_head[fe_index][batch_index]==gold_fe_head[batch_index][idx] and \ - pred_fe_type[fe_index][batch_index] == gold_fe_type[batch_index][idx]: - self.fe_TP+=1 - - #update fe_tp_fp - self.fe_TP_FP+=1 - - def calculate(self): - frame_acc = self.frame_acc / self.frame_cnt - fe_prec = self.fe_TP / (self.fe_TP_FP+0.000001) - fe_recall = float(self.fe_TP / self.fe_TP_FN) - fe_f1 = 2*fe_prec*fe_recall/(fe_prec+fe_recall+0.0000001) - - full_TP = self.frame_acc+self.fe_TP - full_TP_FP = self.frame_cnt + self.fe_TP_FP - full_TP_FN = self.frame_cnt +self.fe_TP_FN - - full_prec = float(full_TP / full_TP_FP) - full_recall =float(full_TP / full_TP_FN) - full_f1 = 2 * full_prec * full_recall / (full_prec + full_recall+0.000001) - - print(" frame acc: %.6f " % frame_acc) - print(" fe_prec: %.6f " % fe_prec) - print(" fe_recall: %.6f " % fe_recall) - print(" fe_f1: %.6f " % fe_f1) - print('================full struction=============') - print(" full_prec: %.6f " % full_prec) - print(" full_recall: %.6f " % full_recall) - print(" full_f1: %.6f " % full_f1) - - return (frame_acc,fe_prec,fe_recall,fe_f1,full_prec,full_recall,full_f1) - - - - From 3574e13fb394894b320d8e859b0b1ba4a1302462 Mon Sep 17 
00:00:00 2001
From: Dogblack <527742942@qq.com>
Date: Wed, 14 Apr 2021 20:35:30 +0800
Subject: [PATCH 7/9] Add files via upload

---
 evaluate7.py | 162 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 162 insertions(+)
 create mode 100644 evaluate7.py

diff --git a/evaluate7.py b/evaluate7.py
new file mode 100644
index 0000000..402b1a8
--- /dev/null
+++ b/evaluate7.py
@@ -0,0 +1,162 @@
+import torch
+import numpy as np
+
+class Eval():
+    def __init__(self, opt, Dataset):
+        self.fe_TP = 0
+        self.fe_TP_FP = 0
+        self.fe_TP_FN = 0
+
+        self.frame_cnt = 0
+        self.frame_acc = 0
+
+        self.core_cnt = 0
+        self.nocore_cnt = 0
+
+        self.decay = 0.8  # discount base for hierarchy-distance-weighted credit
+        self.opt = opt
+
+        self.fe_id_to_label = Dataset.fe_id_to_label
+        self.fe_label_to_id = {}
+        for key, val in self.fe_id_to_label.items():
+            self.fe_label_to_id[val] = key
+
+        self.frame_id_to_label = Dataset.frame_id_to_label
+        self.frame_label_to_id = {}
+        for key, val in self.frame_id_to_label.items():
+            self.frame_label_to_id[val] = key
+
+        path = 'frame-fe-dist-path/'
+        self.fe_dis_matrix = np.load(path + 'fe_dis_matrix.npy', allow_pickle=True)
+        self.frame_dis_matrix = np.load(path + 'frame_dis_matrix.npy', allow_pickle=True)
+        self.fe_path_matrix = np.load(path + 'fe_path_matrix.npy', allow_pickle=True)
+        self.frame_path_matrix = np.load(path + 'frame_path_matrix.npy', allow_pickle=True)
+        self.fe_hash_index = np.load(path + 'fe_hash_idx.npy', allow_pickle=True).item()
+        self.frame_hash_index = np.load(path + 'frame_hash_idx.npy', allow_pickle=True).item()
+
+    def metrics(self, batch_size:int, fe_cnt:torch.Tensor,\
+                gold_fe_type:torch.Tensor, gold_fe_head:torch.Tensor,\
+                gold_fe_tail:torch.Tensor, gold_frame_type:torch.Tensor,\
+                pred_fe_type:torch.Tensor, pred_fe_head:torch.Tensor,\
+                pred_fe_tail:torch.Tensor, pred_frame_type:torch.Tensor,
+                fe_coretype, sent_length:torch.Tensor):
+
+        self.frame_cnt += batch_size
+        for batch_index in range(batch_size):
+            # calculate frame accuracy
+            # print(gold_frame_type[batch_index])
+            # print(pred_frame_type[batch_index])
+            if self.frame_label_to_id[int(gold_frame_type[batch_index])] in self.frame_hash_index.keys() and \
+                    self.frame_label_to_id[int(pred_frame_type[0][batch_index])] in self.frame_hash_index.keys():
+                gold_frame_idx = self.frame_hash_index[self.frame_label_to_id[int(gold_frame_type[batch_index])]]
+                pred_frame_idx = self.frame_hash_index[self.frame_label_to_id[int(pred_frame_type[0][batch_index])]]
+                key_error = False  # True when a frame id is missing from the hierarchy index
+            else:
+                key_error = True
+
+            if gold_frame_type[batch_index] == pred_frame_type[0][batch_index]:
+                self.frame_acc += 1
+
+            elif key_error is False:
+                if self.frame_dis_matrix[gold_frame_idx][pred_frame_idx] != -1 and self.frame_dis_matrix[gold_frame_idx][pred_frame_idx] <= 9:
+                    self.frame_acc += self.decay**self.frame_dis_matrix[gold_frame_idx][pred_frame_idx]
+            # update tp_fn
+            # self.fe_TP_FN += int(fe_cnt[batch_index])
+            for fe_index in range(fe_cnt[batch_index]):
+                if fe_coretype[int(gold_fe_type[batch_index][fe_index])] == 1:
+                    self.fe_TP_FN += 1
+                    self.core_cnt += 1
+                else:
+                    self.fe_TP_FN += 0.5
+                    self.nocore_cnt += 1
+
+            # preprocess error
+
+            #gold_fe_list = gold_fe_type.cpu().numpy().tolist()
+            gold_tail_list = gold_fe_tail.cpu().numpy().tolist()
+            # update fe_TP and fe_TP_FP
+            for fe_index in range(self.opt.fe_padding_num):
+                # if pred_fe_tail[fe_index][batch_index] == sent_length[batch_index]-1 and \
+                #         pred_fe_head[fe_index][batch_index] == sent_length[batch_index] - 1:
+                if pred_fe_type[fe_index][batch_index] == self.opt.role_number:
+                    break
+
+                # update fe_TP
+                if pred_fe_tail[fe_index][batch_index] in gold_tail_list[batch_index]:
+                    idx = gold_tail_list[batch_index].index(pred_fe_tail[fe_index][batch_index])
+
+                    if pred_fe_head[fe_index][batch_index] == gold_fe_head[batch_index][idx] and \
+                            pred_fe_type[fe_index][batch_index] == gold_fe_type[batch_index][idx]:
+
+                        if fe_coretype[int(gold_fe_type[batch_index][idx])] == 1:
+                            self.fe_TP += 1
+                        else:
+                            self.fe_TP += 0.5
+
+                    elif gold_fe_type[batch_index][idx] == self.opt.role_number:
+                        pass
+
+                    elif key_error is True:
+                        pass
+
+                    elif pred_fe_head[fe_index][batch_index] == gold_fe_head[batch_index][idx] and \
+                            self.fe_label_to_id[int(gold_fe_type[batch_index][idx])] in self.fe_hash_index.keys() \
+                            and self.fe_label_to_id[int(pred_fe_type[fe_index][batch_index])] in self.fe_hash_index.keys():
+
+                        gold_fe_idx = self.fe_hash_index[self.fe_label_to_id[int(gold_fe_type[batch_index][idx])]]
+                        pred_fe_idx = self.fe_hash_index[self.fe_label_to_id[int(pred_fe_type[fe_index][batch_index])]]
+
+                        if self.fe_dis_matrix[gold_fe_idx][pred_fe_idx] != -1 and self.fe_dis_matrix[gold_fe_idx][pred_fe_idx] < 9:
+                            rate = float(self.fe_path_matrix[gold_fe_idx][pred_fe_idx]/(self.frame_path_matrix[gold_frame_idx][pred_frame_idx]+0.000001))
+                            if rate > 1:
+                                rate = 1
+                            if fe_coretype[int(gold_fe_type[batch_index][idx])] == 1:
+                                self.fe_TP += 1*rate*(self.decay**self.fe_dis_matrix[gold_fe_idx][pred_fe_idx])
+                            else:
+                                self.fe_TP += 0.5*rate*(self.decay**self.fe_dis_matrix[gold_fe_idx][pred_fe_idx])
+
+                # update fe_TP_FP
+                if fe_coretype[int(pred_fe_type[fe_index][batch_index])] == 1:
+                    self.fe_TP_FP += 1
+                else:
+                    self.fe_TP_FP += 0.5
+
+    def calculate(self):
+        frame_acc = self.frame_acc / self.frame_cnt
+        fe_prec = self.fe_TP / (self.fe_TP_FP+0.000001)
+        fe_recall = float(self.fe_TP / self.fe_TP_FN)
+        fe_f1 = 2*fe_prec*fe_recall/(fe_prec+fe_recall+0.0000001)
+
+        full_TP = self.frame_acc+self.fe_TP
+        full_TP_FP = self.frame_cnt + self.fe_TP_FP
+        full_TP_FN = self.frame_cnt + self.fe_TP_FN
+
+        full_prec = float(full_TP / full_TP_FP)
+        full_recall = float(full_TP / full_TP_FN)
+        full_f1 = 2 * full_prec * full_recall / (full_prec + full_recall+0.000001)
+
+        print(" frame acc: %.6f " % frame_acc)
+        print(" fe_prec: %.6f " % fe_prec)
+        print(" fe_recall: %.6f " % fe_recall)
+        print(" fe_f1: %.6f " % fe_f1)
+        print('================full structure=============')
+        print(" full_prec: %.6f " % full_prec)
+        print(" full_recall: %.6f " % full_recall)
+        print(" full_f1: %.6f " % full_f1)
+        print('================detail=============')
+        print(" fe_TP: %.6f " % self.fe_TP)
+        print(" fe_TP_FN: %.6f " % self.fe_TP_FN)
+        print(" fe_TP_FP: %.6f " % self.fe_TP_FP)
+        print(" core_cnt: %.6f " % self.core_cnt)
+        print(" nocore_cnt: %.6f " % self.nocore_cnt)
+
+        return (frame_acc, fe_prec, fe_recall, fe_f1, full_prec, full_recall, full_f1)
+
+
+
+
From 077dfbcb103be084012d24b13e31809d6c5abceb Mon Sep 17 00:00:00 2001
From: Dogblack <527742942@qq.com>
Date: Thu, 15 Apr 2021 15:18:43 +0800
Subject: [PATCH 8/9] Add files via upload

GCN code
---
 graph_construct.py     |  51 ++++++++++++++++++
 relation_extraction.py |  56 ++++++++++++++++++++
 rgcn.py                | 114 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 221 insertions(+)
 create mode 100644 graph_construct.py
 create mode 100644 relation_extraction.py
 create mode 100644 rgcn.py

diff --git a/graph_construct.py b/graph_construct.py
new file mode 100644
index 0000000..08bf409
--- /dev/null
+++ b/graph_construct.py
@@ -0,0 +1,51 @@
+import torch
+import dgl
+
+from data_process4 import get_fe_tabel, get_frame_tabel, load_data_pd
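+# A minimal usage sketch for the module below, assuming frame.csv, FE.csv and
+# frRelation.xml are available at the default (current-directory) paths:
+#
+#   g = graph_construct()
+#   for etype in g.canonical_etypes:   # e.g. ('frame', 'Inheritance', 'frame')
+#       print(etype, g.num_edges(etype))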
+from relation_extraction import parserelationFiles + + +def graph_construct(fr_path='', fr_file='frame.csv', fe_path='', fe_file='FE.csv'): + frame_id_to_label, frame_name_to_label, frame_name_to_id = get_frame_tabel(path=fr_path, file=fr_file) + fe_id_to_label, fe_name_to_label, fe_name_to_id, fe_id_to_type = get_fe_tabel(path=fe_path, file=fe_file) + fr_sub, fr_sup, fe_sub, fe_sup, fr_label, fe_label, rel_name, fe_sub_to_frid, fe_sup_to_frid= parserelationFiles() + + data_dic = {} + + # frame to frame + for item in zip(fr_sup, fr_sub, fr_label): + #print(item[2]) + data_dic.setdefault(('frame', rel_name[item[2]], 'frame'), [[], []]) + data_dic[('frame', rel_name[item[2]], 'frame')][0].append(frame_id_to_label[item[0]]) + data_dic[('frame', rel_name[item[2]], 'frame')][1].append(frame_id_to_label[item[1]]) + + # fe to fe + data_dic.setdefault(('fe', 'fe_to_fe', 'fe'), [[], []]) + for item in zip(fe_sup, fe_sub, fe_label): + data_dic[('fe', 'fe_to_fe', 'fe')][0].append(fe_id_to_label[item[0]]) + data_dic[('fe', 'fe_to_fe', 'fe')][1].append(fe_id_to_label[item[1]]) + data_dic[('fe', 'fe_to_fe', 'fe')][0].append(fe_id_to_label[item[1]]) + data_dic[('fe', 'fe_to_fe', 'fe')][1].append(fe_id_to_label[item[0]]) + + # frame to fe + fe_dt = load_data_pd(dataset_path=fe_path, file=fe_file) + data_dic[('frame', 'fr_to_fe', 'fe')] = [[], []] + data_dic[('fe', 'fe_to_fr', 'frame')] = [[], []] + for idx in range(len(fe_dt['FrameID'])): + data_dic[('frame', 'fr_to_fe', 'fe')][0].append(frame_id_to_label[fe_dt['FrameID'][idx]]) + data_dic[('frame', 'fr_to_fe', 'fe')][1].append(fe_id_to_label[fe_dt['ID'][idx]]) + + data_dic[('fe', 'fe_to_fr', 'frame')][0].append(fe_id_to_label[fe_dt['ID'][idx]]) + data_dic[('fe', 'fe_to_fr', 'frame')][1].append(frame_id_to_label[fe_dt['FrameID'][idx]]) + + for key in data_dic.keys(): + data_dic[key] = (torch.Tensor(data_dic[key][0]).long(), torch. 
Tensor(data_dic[key][1]).long()) + + g = dgl.heterograph(data_dic) + + return g + +if __name__ == '__main__': + + g = graph_construct() + print(g) diff --git a/relation_extraction.py b/relation_extraction.py new file mode 100644 index 0000000..fb19821 --- /dev/null +++ b/relation_extraction.py @@ -0,0 +1,56 @@ +try: + import xml.etree.cElementTree as ET +except ImportError: + import xml.etree.ElementTree as ET +import sys +import os +import pandas as pd + + +def parserelationFiles(c=0, relation_path='', filename='frRelation.xml', tagPrefix="{http://framenet.icsi.berkeley.edu}"): + tree = ET.ElementTree(file=relation_path + filename) + root = tree.getroot() + + fr_sub = [] + fr_sup = [] + fe_sub = [] + fe_sup = [] + fe_sub_to_frid = [] + fe_sup_to_frid = [] + fr_label =[] + fe_label =[] + rel_idx = -1 + rel_name = [] + + for child in root.iter(): + if child.tag == tagPrefix + 'frameRelationType': + rel_idx += 1 + rel_name.append(child.attrib.get('name')) + + for frame_child in child.iter(): + if frame_child.tag == tagPrefix + 'frameRelation': + fr_sub.append(int(frame_child.attrib.get('subID'))) + fr_sup.append(int(frame_child.attrib.get('supID'))) + fr_label.append(rel_idx) + + for fe_child in frame_child.iter(): + if fe_child.tag == tagPrefix + 'FERelation': + fe_sub.append(int(fe_child.attrib.get('subID'))) + fe_sup.append(int(fe_child.attrib.get('supID'))) + fe_sub_to_frid.append(int(frame_child.attrib.get('subID'))) + fe_sup_to_frid.append(int(frame_child.attrib.get('supID'))) + fe_label.append(rel_idx) + + return fr_sub, fr_sup, fe_sub, fe_sup, fr_label, fe_label, rel_name, fe_sub_to_frid, fe_sup_to_frid + + +# fr_sub, fr_sup, fe_sub, fe_sup, fr_label, fe_label, rel_name, fe_sub_to_frid, fe_sup_to_frid = parserelationFiles() +# print(fr_sub) +# print(fr_sup) +# print(fr_label) +# print(fe_sub) +# print(fe_sup) +# print(fe_label) +# print(fe_sub_to_frid) +# print(fe_sup_to_frid) +# print(rel_name) \ No newline at end of file diff --git a/rgcn.py b/rgcn.py new file mode 100644 index 0000000..9252339 --- /dev/null +++ b/rgcn.py @@ -0,0 +1,114 @@ +import dgl +import dgl.nn as dglnn +import numpy as np +import torch + +import torch.nn as nn +import torch.nn.functional as F +from graph_construct import graph_construct + +class RGCN(nn.Module): + def __init__(self, in_feats, hid_feats, num_frame, num_fe, rel_names): + super().__init__() + self.conv1 = dglnn.HeteroGraphConv({rel: dglnn.GraphConv(in_feats, hid_feats) + for rel in rel_names}, aggregate='mean') + # self.conv2 = dglnn.HeteroGraphConv({rel: dglnn.GraphConv(hid_feats, out_feats) + # for rel in rel_names}, aggregate='mean') + + self.fr_linear = nn.Linear(num_frame, in_feats) + self.fe_linear = nn.Linear(num_fe, in_feats) + + self.fr_fc = nn.Linear(hid_feats, num_frame) + self.fe_fc = nn.Linear(hid_feats, num_fe) + + def forward(self, graph, inputs): + + inputs['frame'] = self.fr_linear(inputs['frame']) + inputs['fe'] = self.fe_linear(inputs['fe']) + h = self.conv1(graph, inputs) + h = {k: F.relu(v) for k, v in h.items()} + h = self.conv1(graph, h) + h['frame'] = h['frame'] + inputs['frame'] + h['fe'] = h['fe'] + inputs['fe'] + fr_h = self.fr_fc(h['frame']) + fe_h = self.fe_fc(h['fe']) + + return fr_h, fe_h, h + +if __name__ == '__main__': + + frame_graph = graph_construct() + # 特征初始化 + num_frame = 1019 + num_fe = 9634 + n_features= 200 + save_model_path = './pretrain_model' + rel_names = ['Causative_of', 'Inchoative_of', 'Inheritance', 'Perspective_on', 'Precedes', + 'ReFraming_Mapping', 'See_also', 'Subframe', 'Using','fe_to_fe', 
'fe_to_fr', 'fr_to_fe'] + + + frame_graph.nodes['frame'].data['feature'] = torch.eye(num_frame) + frame_graph.nodes['fe'].data['feature'] = torch.eye(num_fe) + frame_graph.nodes['frame'].data['label'] = torch.arange(0, num_frame, 1) + frame_graph.nodes['fe'].data['label'] = torch.arange(0, num_fe, 1) + + model = RGCN(in_feats=n_features, hid_feats=200, + num_frame=num_frame, num_fe=num_fe, rel_names=rel_names) + + fr_feats = frame_graph.nodes['frame'].data['feature'] + fe_feats = frame_graph.nodes['fe'].data['feature'] + fr_labels = frame_graph.nodes['frame'].data['label'] + fe_labels = frame_graph.nodes['fe'].data['label'] + + #h_dict = model(frame_graph, {'frame': fr_feats, 'fe': fe_feats}) + + opt = torch.optim.Adam(model.parameters()) + + best_fr_train_acc = 0 + best_fe_train_acc = 0 + loss_list = [] + train_fr_score_list = [] + train_fe_score_list = [] + + for epoch in range(200): + model.train() + # 输入图和节点特征 + fr_logits, fe_logits, hidden = model(frame_graph, {'frame': fr_feats, 'fe': fe_feats}) + # 计算损失 + loss = F.cross_entropy(fr_logits, fr_labels) + F.cross_entropy(fe_logits, fe_labels) + # 预测frame + fr_pred = fr_logits.argmax(1) + # 计算准确率 + fr_train_acc = (fr_pred == fr_labels).float().mean() + if best_fr_train_acc < fr_train_acc: + best_fr_train_acc = fr_train_acc + train_fr_score_list.append(fr_train_acc) + + # 预测fe + fe_pred = fe_logits.argmax(1) + # 计算准确率 + fe_train_acc = (fe_pred == fe_labels).float().mean() + if best_fe_train_acc < fe_train_acc: + best_fe_train_acc = fe_train_acc + train_fe_score_list.append(fe_train_acc) + + # 反向优化 + opt.zero_grad() + loss.backward() + opt.step() + + loss_list.append(loss.item()) + # 输出训练结果 + print('Loss %.4f, Train fr Acc %.4f (Best %.4f) Train fe Acc %.4f (Best %.4f)' % ( + loss.item(), + fr_train_acc.item(), + best_fr_train_acc, + fe_train_acc.item(), + best_fe_train_acc + )) + #print(frame_graph.nodes['frame'].data['feature']) + + torch.save(model.state_dict(), save_model_path) + print(hidden) + torch.save(hidden['frame'], "./frTensor4.pt") + torch.save(hidden['fe'], "./feTensor4.pt") \ No newline at end of file From ef5ca78b5cf194aa005e3d257ee7acc8d06e45e4 Mon Sep 17 00:00:00 2001 From: Dogblack <527742942@qq.com> Date: Sat, 29 May 2021 17:47:15 +0800 Subject: [PATCH 9/9] Add files via upload --- fn-model_bert/config_bert.py | 73 +++ fn-model_bert/data_process_bert.py | 430 ++++++++++++++++++ fn-model_bert/dataset_bert.py | 271 +++++++++++ fn-model_bert/model_syntax36_with_bert.py | 530 ++++++++++++++++++++++ fn-model_bert/train_syntax36_bert.py | 256 +++++++++++ 5 files changed, 1560 insertions(+) create mode 100644 fn-model_bert/config_bert.py create mode 100644 fn-model_bert/data_process_bert.py create mode 100644 fn-model_bert/dataset_bert.py create mode 100644 fn-model_bert/model_syntax36_with_bert.py create mode 100644 fn-model_bert/train_syntax36_bert.py diff --git a/fn-model_bert/config_bert.py b/fn-model_bert/config_bert.py new file mode 100644 index 0000000..54c4c6a --- /dev/null +++ b/fn-model_bert/config_bert.py @@ -0,0 +1,73 @@ + +import argparse +import os + +data_dir ='./' + + +class PLMConfig: + MODEL_PATH = 'uncased_L-12_H-768_A-12' + VOCAB_PATH = f'{MODEL_PATH}/vocab.txt' + CONFIG_PATH = f'{MODEL_PATH}/bert_config.json' + + +def get_opt(): + parser = argparse.ArgumentParser() + + # 数据集位置 + parser.add_argument('--data_path', type=str, default='parsed-v1.5/') + parser.add_argument('--emb_file_path', type=str, default='./glove.6B/glove.6B.200d.txt') + parser.add_argument('--exemplar_instance_path', type=str, 
From ef5ca78b5cf194aa005e3d257ee7acc8d06e45e4 Mon Sep 17 00:00:00 2001
From: Dogblack <527742942@qq.com>
Date: Sat, 29 May 2021 17:47:15 +0800
Subject: [PATCH 9/9] Add files via upload

---
 fn-model_bert/config_bert.py              |  73 +++
 fn-model_bert/data_process_bert.py        | 430 ++++++++++++++++++
 fn-model_bert/dataset_bert.py             | 271 +++++++++++
 fn-model_bert/model_syntax36_with_bert.py | 530 ++++++++++++++++++++++
 fn-model_bert/train_syntax36_bert.py      | 256 +++++++++++
 5 files changed, 1560 insertions(+)
 create mode 100644 fn-model_bert/config_bert.py
 create mode 100644 fn-model_bert/data_process_bert.py
 create mode 100644 fn-model_bert/dataset_bert.py
 create mode 100644 fn-model_bert/model_syntax36_with_bert.py
 create mode 100644 fn-model_bert/train_syntax36_bert.py

diff --git a/fn-model_bert/config_bert.py b/fn-model_bert/config_bert.py
new file mode 100644
index 0000000..54c4c6a
--- /dev/null
+++ b/fn-model_bert/config_bert.py
@@ -0,0 +1,73 @@
+
+import argparse
+import os
+
+data_dir = './'
+
+
+class PLMConfig:
+    MODEL_PATH = 'uncased_L-12_H-768_A-12'
+    VOCAB_PATH = f'{MODEL_PATH}/vocab.txt'
+    CONFIG_PATH = f'{MODEL_PATH}/bert_config.json'
+
+
+def get_opt():
+    parser = argparse.ArgumentParser()
+
+    # dataset locations
+    parser.add_argument('--data_path', type=str, default='parsed-v1.5/')
+    parser.add_argument('--emb_file_path', type=str, default='./glove.6B/glove.6B.200d.txt')
+    parser.add_argument('--exemplar_instance_path', type=str, default='./exemplar_instance_dic_bert.npy')
+    parser.add_argument('--train_instance_path', type=str, default='./train_instance_dic_bert.npy')
+    parser.add_argument('--dev_instance_path', type=str, default='./dev_instance_dic_bert.npy')
+    parser.add_argument('--test_instance_path', type=str, default='./test_instance_dic_bert.npy')
+    # saving and loading models
+    parser.add_argument('--load_instance_dic', type=bool, default=True)
+    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
+    parser.add_argument('--model_name', type=str, default='train_model')
+    parser.add_argument('--pretrain_model', type=str, default='')
+    parser.add_argument('--save_model_path', type=str, default='./models_ft/final_bert_3.bin')
+
+    # training
+    parser.add_argument('--lr', type=float, default=6e-6)
+    parser.add_argument('--weight_decay', type=float, default=0.0001)
+    parser.add_argument('--batch_size', type=int, default=2)
+    parser.add_argument('--epochs', type=int, default=100)
+    parser.add_argument('--save_model_freq', type=int, default=1)  # checkpoint interval, in epochs
+    parser.add_argument('--cuda', type=str, default="cuda:0")
+    parser.add_argument('--mode', type=str, default="test")
+
+    # model settings
+    parser.add_argument('--maxlen', type=int, default=256)
+    parser.add_argument('--dropout', type=float, default=0.5)
+    parser.add_argument('--cell_name', type=str, default='lstm')
+    parser.add_argument('--rnn_hidden_size', type=int, default=256)
+    parser.add_argument('--rnn_emb_size', type=int, default=400)
+    parser.add_argument('--encoder_emb_size', type=int, default=768)
+    parser.add_argument('--sent_emb_size', type=int, default=200)
+    parser.add_argument('--pos_emb_size', type=int, default=64)
+    parser.add_argument('--rel_emb_size', type=int, default=100)
+    parser.add_argument('--token_type_emb_size', type=int, default=32)
+    parser.add_argument('--decoder_emb_size', type=int, default=200)
+    parser.add_argument('--decoder_hidden_size', type=int, default=256)
+    parser.add_argument('--target_maxlen', type=int, default=5)
+    parser.add_argument('--decoderlen', type=int, default=4)
+    parser.add_argument('--kernel_size', type=int, default=3)
+    parser.add_argument('--frame_number', type=int, default=1019)
+    parser.add_argument('--role_number', type=int, default=9634)
+    parser.add_argument('--fe_padding_num', type=int, default=5)
+    parser.add_argument('--window_size', type=int, default=3)
+    parser.add_argument('--num_layers', type=int, default=2)
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    plm_config = PLMConfig()
+    print(plm_config.CONFIG_PATH)
diff --git a/fn-model_bert/data_process_bert.py b/fn-model_bert/data_process_bert.py
new file mode 100644
index 0000000..282bc76
--- /dev/null
+++ b/fn-model_bert/data_process_bert.py
@@ -0,0 +1,430 @@
+import numpy as np
+import os
+import pandas as pd
+import csv
+import torch
+import json
+import re
+
+from config_bert import get_opt, PLMConfig
+from pytorch_pretrained_bert import BertTokenizer
+from bert_slot_tokenizer import SlotConverter
+from nltk import word_tokenize, pos_tag
+from nltk.corpus import wordnet
+from nltk.stem import WordNetLemmatizer
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+bert_tokenizer = BertTokenizer.from_pretrained(PLMConfig.MODEL_PATH)
+
+
+def load_data(path):
+    f = open(path, 'r', encoding='utf-8')
+    lines = f.readlines()
+    return lines
+
+
+def instance_process(lines, maxlen):
+    instance_dic = {}
+
+    cnt = 0
+    find = False
+    word_list_total = []
+    for line in lines:
+        if
line[0:3] == '# i': + word_list = [] + lemma_list = [] + pos_list = [] + target_idx = [-1, -1] + span_start = [] + span_start_token = [] + span_end = [] + span_end_token = [] + span_type = [] + length = 0 + sent ='[CLS] ' + + elif line[0:3] == '# e': + instance_dic.setdefault((sent_id, target_type, cnt), {}) + #instance_dic[(sent_id, target_type, cnt)]['dep_list'] = dep_parsing(word_list, maxlen, parser) + word_list = bert_tokenizer.tokenize(sent) + + instance_dic[(sent_id, target_type, cnt)]['length'] = min(len(word_list),maxlen) + + attention_mask = [0]*maxlen + attention_mask[0:min(len(word_list),maxlen)]=[1] * min(len(word_list),maxlen) + + instance_dic[(sent_id, target_type, cnt)]['attention_mask']=attention_mask + + word_list_padded=padding_sentence(word_list, maxlen) + instance_dic[(sent_id, target_type, cnt)]['word_list'] = word_list_padded + instance_dic[(sent_id, target_type, cnt)]['tokenized_ids'] = bert_tokenizer.convert_tokens_to_ids(word_list_padded) + # instance_dic[(sent_id, target_type, cnt)]['lemma_list'] = padding_sentence(lemma_list, maxlen) + # instance_dic[(sent_id, target_type, cnt)]['pos_list'] = padding_sentence(pos_list, maxlen) + instance_dic[(sent_id, target_type, cnt)]['sent_id'] = sent_id + + word_list_total.append(word_list) + # add 'eos' + # instance_dic[(sent_id, target_type, cnt)]['attention_mask'] = get_mask_from_index(sequence_lengths=torch.Tensor([int(length)+1]), max_length=maxlen).squeeze() + + instance_dic[(sent_id, target_type, cnt)]['target_type'] = target_type + instance_dic[(sent_id, target_type, cnt)]['lu'] = lu + + target_start_list=bert_tokenizer.tokenize(target_start_token) + target_end_list = bert_tokenizer.tokenize(target_end_token) + # print(target_start_list) + # print(target_end_list) + # print(target_start_list[0]) + # print(target_end_list[-1]) + # print(word_list) + target_start=word_list.index(target_start_list[0],int(target_idx[0])) + target_end=word_list.index(target_end_list[-1],int(target_idx[1])) + instance_dic[(sent_id, target_type, cnt)]['target_idx'] = (target_start,target_end) + + instance_dic[(sent_id, target_type, cnt)]['span_start']=[] + instance_dic[(sent_id, target_type, cnt)]['span_end']=[] + for i in range(len(span_start)): + span_start_list = bert_tokenizer.tokenize(span_start_token[i]) + span_end_list = bert_tokenizer.tokenize(span_end_token[i]) + span_start_idx = word_list.index(span_start_list[0], span_start[i]) + span_end_idx = word_list.index(span_end_list[-1], span_end[i]) + + + instance_dic[(sent_id, target_type, cnt)]['span_start'].append(span_start_idx) + instance_dic[(sent_id, target_type, cnt)]['span_end'].append(span_end_idx) + instance_dic[(sent_id, target_type, cnt)]['span_type'] = span_type + + print(cnt) + cnt += 1 + elif line == '\n': + sent += '[SEP]' + continue + + else: + data_list = line.split('\t') + word_list.append(data_list[1]) + # lemma_list.append(data_list[3]) + # pos_list.append(data_list[5]) + sent_id = data_list[6] + # length = data_list[0] + sent+=data_list[1]+' ' + + if data_list[12] != '_' and data_list[13] != '_': + lu = data_list[12] + + target_type = data_list[13] + if target_idx == [-1, -1]: + target_idx = [int(data_list[0])-1, int(data_list[0])-1] + target_start_token = data_list[1] + target_end_token = data_list[1] + else: + target_idx[1] =int(data_list[0]) - 1 + target_end_token = data_list[1] + + if data_list[14] != '_': + + fe = data_list[14].split('-') + + if fe[0] == 'B' and find is False: + span_start.append(int(data_list[0]) - 1) + span_start_token.append(data_list[1]) + 
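+                    # NOTE (editor's gloss of the existing logic): the 15th conll column
+                    # carries span tags that this block reads as  B-... = span opens,
+                    # O-... = span closes (the FE name is taken from the tag's last field),
+                    # S-... = single-token span; `find` prevents opening a second span
+                    # before the current one has been closed.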
find = True + + elif fe[0] == 'O': + span_end.append(int(data_list[0]) - 1) + span_end_token.append(data_list[1]) + span_type.append(fe[-1].replace('\n', '')) + find = False + + elif fe[0] == 'S': + span_start.append(int(data_list[0]) - 1) + span_start_token.append(data_list[1]) + span_end.append(int(data_list[0]) - 1) + span_end_token.append(data_list[1]) + span_type.append(fe[-1].replace('\n', '')) + + return instance_dic + +# def dep_parsing(word_list: list,maxlen: list,parser): +# res = list(parser.parse(word_list)) +# sent = res[0].to_conll(4).split('\n')[:-1] +# #['the', 'DT', '4', 'det'] +# line = [line.split('\t') for line in sent] +# head_list = [] +# rel_list = [] +# +# distance = 0 +# +# #alignment +# for index in range(len(word_list)-1): +# #end stopwords +# if index-distance >= len(line): +# head_list.append('#') +# rel_list.append('#') +# distance+=1 +# +# elif word_list[index]!=line[index-distance][0]: +# head_list.append('#') +# rel_list.append('#') +# distance+=1 +# else: +# rel_list.append(line[index-distance][3]) +# +# if line[index-distance][3] != 'root': +# head_list.append(word_list[int(line[index-distance][2]) - 1]) +# else: +# head_list.append(word_list[index]) +# +# head_list.append('eos') +# rel_list.append('eos') +# +# while len(head_list) < maxlen: +# head_list.append('0') +# rel_list.append('0') +# +# return (head_list,rel_list) + + +def padding_sentence(sentence: list,maxlen: int): + + while len(sentence) < maxlen: + sentence.append('0') + + return sentence[0:maxlen] + +class DataConfig: + def __init__(self,opt): + exemplar_lines = load_data('fn1.5/conll/exemplar') + train_lines = load_data('fn1.5/conll/train') + dev_lines = load_data('fn1.5/conll/dev') + test_lines = load_data('fn1.5/conll/test') + + + self.emb_file_path = opt.emb_file_path + self.maxlen = opt.maxlen + + if opt.load_instance_dic: + self.exemplar_instance_dic = np.load(opt.exemplar_instance_path, allow_pickle=True).item() + self.train_instance_dic = np.load(opt.train_instance_path, allow_pickle=True).item() + self.dev_instance_dic = np.load(opt.dev_instance_path, allow_pickle=True).item() + self.test_instance_dic = np.load(opt.test_instance_path, allow_pickle=True).item() + + else: + print('begin parsing') + self.exemplar_instance_dic = instance_process(lines=exemplar_lines,maxlen=self.maxlen) + np.save('exemplar_instance_dic_bert', self.exemplar_instance_dic) + print('exemplar_instance_dic_bert finish') + + self.train_instance_dic = instance_process(lines=train_lines,maxlen=self.maxlen) + np.save('train_instance_dic_bert', self.train_instance_dic) + print('train_instance_dic finish') + + self.dev_instance_dic = instance_process(lines=dev_lines,maxlen=self.maxlen) + np.save('dev_instance_dic_bert', self.dev_instance_dic) + print('dev_instance_dic finish') + + self.test_instance_dic = instance_process(lines=test_lines,maxlen=self.maxlen) + np.save('test_instance_dic_bert', self.test_instance_dic) + print('test_instance_dic finish') + + + # self.word_index = {} + # self.lemma_index = {} + # self.pos_index = {} + # self.rel_index = {} + # + # self.word_number = 0 + # self.lemma_number = 0 + # self.pos_number = 0 + # self.rel_number = 0 + # + # self.build_word_index(self.exemplar_instance_dic) + # self.build_word_index(self.train_instance_dic) + # self.build_word_index(self.dev_instance_dic) + # self.build_word_index(self.test_instance_dic) + # + # # add # for parsing sign + # self.word_index['#']=self.word_number + # self.word_number+=1 + # + # self.emb_index = 
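+        # NOTE (editor's sketch): the instance dicts are cached with np.save on a plain
+        # Python dict, which is stored as a 0-d object array; that is why loading uses
+        #     np.load(path, allow_pickle=True).item()
+        # to recover the dict, matching the np.save calls below.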
self.build_emb_index(self.emb_file_path) + # + # self.word_vectors = self.get_embedding_weight(self.emb_index, self.word_index, self.word_number) + # self.lemma_vectors = self.get_embedding_weight(self.emb_index, self.lemma_index, self.lemma_number) + + def build_word_index(self, dic): + for key in dic.keys(): + word_list =dic[key]['word_list'] + lemma_list = dic[key]['lemma_list'] + pos_list = dic[key]['pos_list'] + rel_list = dic[key]['dep_list'][1] + + # print(row) + for word in word_list: + if word not in self.word_index.keys(): + self.word_index[word]=self.word_number + self.word_number += 1 + + for lemma in lemma_list: + if lemma not in self.lemma_index.keys(): + self.lemma_index[lemma]=self.lemma_number + self.lemma_number += 1 + + for pos in pos_list: + if pos not in self.pos_index.keys(): + self.pos_index[pos] = self.pos_number + self.pos_number += 1 + + for rel in rel_list: + if rel not in self.rel_index.keys(): + self.rel_index[rel] = self.rel_number + self.rel_number += 1 + + def build_emb_index(self, file_path): + data = open(file_path, 'r', encoding='utf-8') + emb_index = {} + for items in data: + item = items.split() + word = item[0] + weight = np.asarray(item[1:], dtype='float32') + emb_index[word] = weight + + return emb_index + + def get_embedding_weight(self,embed_dict, words_dict, words_count, dim=200): + + exact_count = 0 + fuzzy_count = 0 + oov_count = 0 + print("loading pre_train embedding by avg for out of vocabulary.") + embeddings = np.zeros((int(words_count) + 1, int(dim))) + inword_list = {} + for word in words_dict: + if word in embed_dict: + embeddings[words_dict[word]] = embed_dict[word] + inword_list[words_dict[word]] = 1 + # 准确匹配 + exact_count += 1 + elif word.lower() in embed_dict: + embeddings[words_dict[word]] = embed_dict[word.lower()] + inword_list[words_dict[word]] = 1 + # 模糊匹配 + fuzzy_count += 1 + else: + # 未登录词 + oov_count += 1 + # print(word) + # 对已经找到的词向量平均化 + sum_col = np.sum(embeddings, axis=0) / len(inword_list) # avg + sum_col /= np.std(sum_col) + for i in range(words_count): + if i not in inword_list: + embeddings[i] = sum_col + + embeddings[int(words_count)] = [0] * dim + final_embed = np.array(embeddings) + # print('exact_count: ',exact_count) + # print('fuzzy_count: ', fuzzy_count) + # print('oov_count: ', oov_count) + return final_embed + + +def load_data_pd(dataset_path,file): + # df=csv.reader(open(dataset_path+file,encoding='utf-8')) + # df = json.load(open(file_path,encoding='utf-8')) + df=pd.read_csv(dataset_path+file, header=0, encoding='utf-8') + return df + + +def get_frame_tabel(path, file): + data = load_data_pd(path, file) + + frame_id_to_label = {} + frame_name_to_label = {} + frame_name_to_id = {} + data_index = 0 + for idx in range(len(data['ID'])): + if data['ID'][idx] not in frame_id_to_label: + frame_id_to_label[data['ID'][idx]] = data_index + frame_name_to_label[data['Name'][idx]] = data_index + frame_name_to_id[data['Name'][idx]] = data['ID'][idx] + + data_index += 1 + + return frame_id_to_label, frame_name_to_label, frame_name_to_id + + +def get_fe_tabel(path, file): + data = load_data_pd(path, file) + + fe_id_to_label = {} + fe_name_to_label = {} + fe_name_to_id = {} + fe_id_to_type = {} + + data_index = 0 + for idx in range(len(data['ID'])): + if data['ID'][idx] not in fe_id_to_label: + fe_id_to_label[data['ID'][idx]] = data_index + fe_name_to_label[(data['Name'][idx], data['FrameID'][idx])] = data_index + fe_name_to_id[(data['Name'][idx], data['FrameID'][idx])] = data['ID'][idx] + fe_id_to_type[data['ID'][idx]] = 
data['CoreType'][idx] + + data_index += 1 + + return fe_id_to_label, fe_name_to_label, fe_name_to_id, fe_id_to_type + + +def get_fe_list(path, fe_num, fe_table, file='FE.csv'): + fe_dt = load_data_pd(path, file) + fe_mask_list = {} + + print('begin get fe list') + for idx in range(len(fe_dt['FrameID'])): + fe_mask_list.setdefault(fe_dt['FrameID'][idx], [0]*(fe_num+1)) + # fe_mask_list[fe_dt['FrameID'][idx]].setdefault('fe_mask', [0]*(fe_num+1)) + fe_mask_list[fe_dt['FrameID'][idx]][fe_table[fe_dt['ID'][idx]]] = 1 + + # for key in fe_list.keys(): + # fe_list[key]['fe_mask'][fe_num] = 1 + + return fe_mask_list + + +def get_lu_list(path, lu_num, fe_num, frame_id_to_label, fe_mask_list, file='LU.csv'): + lu_dt = load_data_pd(path, file) + lu_list = {} + lu_id_to_name = {} + lu_name_to_id = {} + #lu_name_to_felist = {} + + for idx in range(len(lu_dt['ID'])): + lu_name = lu_dt['Name'][idx] + lu_list.setdefault(lu_name, {}) + + lu_list[lu_name].setdefault('fe_mask', [0]*(fe_num+1)) + lu_list[lu_name]['fe_mask'] = list(map(lambda x: x[0]+x[1], zip(lu_list[lu_name]['fe_mask'], + fe_mask_list[lu_dt['FrameID'][idx]]))) + + lu_list[lu_name].setdefault('lu_mask', [0]*(lu_num+1)) + lu_list[lu_name]['lu_mask'][frame_id_to_label[lu_dt['FrameID'][idx]]] = 1 + + lu_id_to_name[lu_dt['ID'][idx]] = lu_name + lu_name_to_id[(lu_name, lu_dt['FrameID'][idx])] = lu_dt['ID'][idx] + + for key in lu_list.keys(): + # lu_list[key]['lu_mask'][lu_num] = 1 + lu_list[key]['fe_mask'][fe_num] = 1 + + return lu_list, lu_id_to_name, lu_name_to_id + + +if __name__ == '__main__': + opt = get_opt() + config = DataConfig(opt) + + + + diff --git a/fn-model_bert/dataset_bert.py b/fn-model_bert/dataset_bert.py new file mode 100644 index 0000000..2c4c297 --- /dev/null +++ b/fn-model_bert/dataset_bert.py @@ -0,0 +1,271 @@ + +import torch +import numpy as np + +from torch.utils.data import Dataset +from config_bert import get_opt +from data_process_bert import get_frame_tabel, get_fe_tabel, get_fe_list, get_lu_list,DataConfig +from utils import get_mask_from_index + +class FrameNetDataset(Dataset): + + def __init__(self, opt, config, data_dic, device): + super(FrameNetDataset, self).__init__() + print('begin load data') + self.data_dic = data_dic + self.fe_id_to_label, self.fe_name_to_label, self.fe_name_to_id, self.fe_id_to_type = get_fe_tabel('parsed-v1.5/', 'FE.csv') + self.frame_id_to_label, self.frame_name_to_label, self.frame_name_to_id = get_frame_tabel('parsed-v1.5/', 'frame.csv') + + # self.word_index = config.word_index + # self.lemma_index = config.lemma_index + # self.pos_index = config.pos_index + # self.rel_index = config.rel_index + + self.fe_num = len(self.fe_id_to_label) + self.frame_num = len(self.frame_id_to_label) + self.batch_size = opt.batch_size + print(self.fe_num) + print(self.frame_num) + self.dataset_len = len(self.data_dic) + + self.fe_mask_list = get_fe_list('parsed-v1.5/', self.fe_num, self.fe_id_to_label) + self.lu_list, self.lu_id_to_name,\ + self.lu_name_to_id = get_lu_list('parsed-v1.5/', + self.frame_num, self.fe_num, + self.frame_id_to_label, + self.fe_mask_list) + + self.word_ids = [] + # self.lemma_ids = [] + # self.pos_ids = [] + # self.head_ids = [] + # self.rel_ids = [] + + self.lengths = [] + self.mask = [] + self.target_head = [] + self.target_tail = [] + self.target_type = [] + self.fe_head = [] + self.fe_tail = [] + self.fe_type = [] + self.fe_coretype = [] + self.sent_length = [] + self.fe_cnt = [] + self.fe_cnt_with_padding =[] + self.fe_mask = [] + self.lu_mask = [] + self.token_type_ids = 
[] + self.target_mask_ids = [] + + self.device = device + self.oov_frame = 0 + self.long_span = 0 + self.error_span = 0 + self.fe_coretype_table = {} + self.target_mask = {} + + for idx in self.fe_id_to_type.keys(): + if self.fe_id_to_type[idx] == 'Core': + self.fe_coretype_table[self.fe_id_to_label[idx]] = 1 + else: + self.fe_coretype_table[self.fe_id_to_label[idx]] = 0 + + + + for key in self.data_dic.keys(): + self.build_target_mask(key,opt.maxlen) + + + for key in self.data_dic.keys(): + self.pre_process(key, opt) + + self.pad_dic_cnt = self.dataset_len % opt.batch_size + + + for idx,key in enumerate(self.data_dic.keys()): + if idx >= self.pad_dic_cnt: + break + self.pre_process(key, opt,filter=False) + + self.dataset_len+=self.pad_dic_cnt + + print('load data finish') + print('oov frame = ', self.oov_frame) + print('long_span = ', self.long_span) + print('dataset_len = ', self.dataset_len) + + def __len__(self): + self.dataset_len = int(self.dataset_len / self.batch_size) * self.batch_size + return self.dataset_len + + def __getitem__(self, item): + word_ids = torch.Tensor(self.word_ids[item]).long().to(self.device) + # lemma_ids = torch.Tensor(self.lemma_ids[item]).long().to(self.device) + # pos_ids = torch.Tensor(self.pos_ids[item]).long().to(self.device) + # head_ids = torch.Tensor(self.head_ids[item]).long().to(self.device) + # rel_ids = torch.Tensor(self.rel_ids[item]).long().to(self.device) + lengths = torch.Tensor([self.lengths[item]]).long().to(self.device) + mask = torch.Tensor(self.mask[item]).long().to(self.device) + target_head = torch.Tensor([self.target_head[item]]).long().to(self.device) + target_tail = torch.Tensor([self.target_tail[item]]).long().to(self.device) + target_type = torch.Tensor([self.target_type[item]]).long().to(self.device) + fe_head = torch.Tensor(self.fe_head[item]).long().to(self.device) + fe_tail = torch.Tensor(self.fe_tail[item]).long().to(self.device) + fe_type = torch.Tensor(self.fe_type[item]).long().to(self.device) + fe_cnt = torch.Tensor([self.fe_cnt[item]]).long().to(self.device) + fe_cnt_with_padding = torch.Tensor([self.fe_cnt_with_padding[item]]).long().to(self.device) + fe_mask = torch.Tensor(self.fe_mask[item]).long().to(self.device) + lu_mask = torch.Tensor(self.lu_mask[item]).long().to(self.device) + token_type_ids = torch.Tensor(self.token_type_ids[item]).long().to(self.device) + sent_length = torch.Tensor([self.sent_length[item]]).long().to(self.device) + target_mask_ids = torch.Tensor(self.target_mask_ids[item]).long().to(self.device) + # print(fe_cnt) + + + return (word_ids, lengths, mask, target_head, target_tail, target_type, + fe_head, fe_tail, fe_type, fe_cnt, fe_cnt_with_padding, + fe_mask, lu_mask, token_type_ids,sent_length,target_mask_ids) + + def pre_process(self, key, opt,filter=True): + if self.data_dic[key]['target_type'] not in self.frame_name_to_label: + self.oov_frame += 1 + self.dataset_len -= 1 + return + + target_id = self.frame_name_to_id[self.data_dic[key]['target_type']] + if filter: + self.long_span += self.remove_error_span(key, self.data_dic[key]['span_start'], + self.data_dic[key]['span_end'], self.data_dic[key]['span_type'], target_id, 20) + + # word_ids = [self.word_index[word] for word in self.data_dic[key]['word_list']] + # lemma_ids = [self.lemma_index[lemma] for lemma in self.data_dic[key]['lemma_list']] + # pos_ids = [self.pos_index[pos] for pos in self.data_dic[key]['pos_list']] + # head_ids = [self.word_index[head] for head in self.data_dic[key]['dep_list'][0]] + # rel_ids = [self.rel_index[rel] 
for rel in self.data_dic[key]['dep_list'][1]] + + self.word_ids.append(self.data_dic[key]['tokenized_ids']) + # self.lemma_ids.append(lemma_ids) + # self.pos_ids.append(pos_ids) + # self.head_ids.append(head_ids) + # self.rel_ids.append(rel_ids) + self.lengths.append(self.data_dic[key]['length']) + + self.mask.append(self.data_dic[key]['attention_mask']) + self.target_head.append(self.data_dic[key]['target_idx'][0]) + self.target_tail.append(self.data_dic[key]['target_idx'][1]) + + # mask = get_mask_from_index(torch.Tensor([int(self.data_dic[key]['length'])]), opt.maxlen).squeeze() + # self.mask.append(mask) + + token_type_ids = build_token_type_ids(self.data_dic[key]['target_idx'][0], self.data_dic[key]['target_idx'][1], opt.maxlen) + # token_type_ids +=self.target_mask[key[0]] + self.token_type_ids.append(token_type_ids) + self.target_mask_ids.append(self.target_mask[key[0]]) + + self.target_type.append(self.frame_name_to_label[self.data_dic[key]['target_type']]) + + # print(self.frame_tabel[self.fe_data[key]['frame_ID']]) + + if self.data_dic[key]['length'] <= opt.maxlen: + sent_length = self.data_dic[key]['length'] + else: + sent_length = opt.maxlen + self.sent_length.append(sent_length) + + lu_name = self.data_dic[key]['lu'] + self.lu_mask.append(self.lu_list[lu_name]['lu_mask']) + self.fe_mask.append(self.lu_list[lu_name]['fe_mask']) + + fe_head = self.data_dic[key]['span_start'] + fe_tail = self.data_dic[key]['span_end'] + + + + while len(fe_head) < opt.fe_padding_num: + fe_head.append(min(sent_length-1, opt.maxlen-1)) + + while len(fe_tail) < opt.fe_padding_num: + fe_tail.append(min(sent_length-1,opt.maxlen-1)) + + self.fe_head.append(fe_head[0:opt.fe_padding_num]) + self.fe_tail.append(fe_tail[0:opt.fe_padding_num]) + + fe_type = [self.fe_name_to_label[(item, target_id)] for item in self.data_dic[key]['span_type']] + + self.fe_cnt.append(min(len(fe_type), opt.fe_padding_num)) + self.fe_cnt_with_padding.append(min(len(fe_type)+1, opt.fe_padding_num)) + + while len(fe_type) < opt.fe_padding_num: + fe_type.append(self.fe_num) + # fe_coretype.append('0') + + self.fe_type.append(fe_type[0:opt.fe_padding_num]) + + def remove_error_span(self, key, fe_head_list, fe_tail_list, fe_type_list, target_id, span_maxlen): + indices = [] + for index in range(len(fe_head_list)): + if fe_tail_list[index] - fe_head_list[index] >= span_maxlen: + indices.append(index) + elif fe_tail_list[index] < fe_head_list[index]: + indices.append(index) + + + elif (fe_type_list[index], target_id) not in self.fe_name_to_label: + indices.append(index) + + else: + for i in range(index): + if i not in indices: + if fe_head_list[index] >= fe_head_list[i] and fe_head_list[index] <= fe_tail_list[i]: + indices.append(index) + break + + elif fe_tail_list[index] >= fe_head_list[i] and fe_tail_list[index] <= fe_tail_list[i]: + indices.append(index) + break + elif fe_tail_list[index] <= fe_head_list[i] and fe_tail_list[index] >= fe_tail_list[i]: + indices.append(index) + break + else: + continue + + fe_head_list_filter = [i for j, i in enumerate(fe_head_list) if j not in indices] + fe_tail_list_filter = [i for j, i in enumerate(fe_tail_list) if j not in indices] + fe_type_list_filter = [i for j, i in enumerate(fe_type_list) if j not in indices] + self.data_dic[key]['span_start'] = fe_head_list_filter + self.data_dic[key]['span_end'] = fe_tail_list_filter + self.data_dic[key]['span_type'] = fe_type_list_filter + + return len(indices) + + def build_target_mask(self,key,maxlen): + self.target_mask.setdefault(key[0], [0]*maxlen) + 
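+        # NOTE (editor's sketch): get_mask_from_index comes from utils.py, which is not
+        # part of this patch. A minimal version consistent with its call sites
+        # (lengths or indices -> binary prefix mask of shape (B, max_length)) would be:
+        #     def get_mask_from_index(sequence_lengths, max_length):
+        #         rng = torch.arange(max_length, device=sequence_lengths.device)
+        #         return (rng.unsqueeze(0) < sequence_lengths.unsqueeze(1)).long()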
+ target_head = self.data_dic[key]['target_idx'][0] + target_tail = self.data_dic[key]['target_idx'][1] + self.target_mask[key[0]][target_head] = 1 + self.target_mask[key[0]][target_tail] = 1 + + + + + +def build_token_type_ids(target_head, target_tail, maxlen): + token_type_ids = [0]*maxlen + token_type_ids[target_head] = 1 + token_type_ids[target_tail] = 1 + # token_type_ids[target_head:target_tail+1] = [1]*(target_tail+1-target_head) + + return token_type_ids + + +if __name__ == '__main__': + opt = get_opt() + config = DataConfig(opt) + if torch.cuda.is_available(): + device = torch.device(opt.cuda) + else: + device = torch.device('cpu') + dataset = FrameNetDataset(opt, config, config.test_instance_dic, device) + print(dataset.error_span) diff --git a/fn-model_bert/model_syntax36_with_bert.py b/fn-model_bert/model_syntax36_with_bert.py new file mode 100644 index 0000000..fc8ed07 --- /dev/null +++ b/fn-model_bert/model_syntax36_with_bert.py @@ -0,0 +1,530 @@ +import numpy as np +from typing import List,Tuple +import os +import json + +import torch.nn as nn +import torch +import torch.nn.functional as F + +from utils import batched_index_select,get_mask_from_index,generate_perm_inv +from pytorch_pretrained_bert import BertTokenizer,BertModel,BertConfig,BertAdam +from config_bert import PLMConfig,get_opt + +class Mlp(nn.Module): + def __init__(self, input_size, output_size): + super(Mlp, self).__init__() + self.linear = nn.Sequential( + nn.Linear(input_size, input_size), + nn.Dropout(0.4), + nn.ReLU(inplace=True), + nn.Linear(input_size, output_size), + ) + + def forward(self, x): + out = self.linear(x) + return out + + +class Relu_Linear(nn.Module): + def __init__(self, input_size, output_size): + super(Relu_Linear, self).__init__() + self.linear = nn.Sequential( + nn.Linear(input_size, output_size), + nn.Dropout(0.4), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + out = self.linear(x) + return out + + +class CnnNet(nn.Module): + def __init__(self, kernel_size, seq_length, input_size, output_size): + super(CnnNet, self).__init__() + self.seq_length = seq_length + self.output_size = output_size + self.kernel_size = kernel_size + + self.relu = nn.ReLU() + self.conv = nn.Conv1d(in_channels=input_size, out_channels=output_size, kernel_size=self.kernel_size + , padding=1) + self.mp = nn.MaxPool1d(kernel_size=self.seq_length) + + def forward(self, input_emb): + input_emb = input_emb.permute(0, 2, 1) + x = self.conv(input_emb) + output = self.mp(x).squeeze() + + return output + + +class PointerNet(nn.Module): + def __init__(self, query_vec_size, src_encoding_size, attention_type='affine'): + super(PointerNet, self).__init__() + + assert attention_type in ('affine', 'dot_prod') + if attention_type == 'affine': + self.src_encoding_linear = Mlp(src_encoding_size, query_vec_size) + + self.src_linear = Mlp(src_encoding_size,src_encoding_size) + self.activate = nn.ReLU(inplace=True) + self.dropout = nn.Dropout(0.5) + + self.fc =nn.Linear(src_encoding_size*2,src_encoding_size, bias=True) + + self.attention_type = attention_type + + def forward(self, src_encodings, src_token_mask,query_vec,head_vec=None): + + # (batch_size, 1, src_sent_len, query_vec_size) + if self.attention_type == 'affine': + src_encod = self.src_encoding_linear(src_encodings).unsqueeze(1) + head_weights = self.src_linear(src_encodings).unsqueeze(1) + + # (batch_size, tgt_action_num, query_vec_size, 1) + if head_vec is not None: + src_encod = torch.cat([src_encod,head_weights],dim = -1) + q = torch.cat([head_vec, 
query_vec], dim=-1).permute(1, 0, 2).unsqueeze(3) + + + else: + q = query_vec.permute(1, 0, 2).unsqueeze(3) + + weights = torch.matmul(src_encod, q).squeeze(3) + ptr_weights = weights.permute(1, 0, 2) + + # if head_vec is not None: + # src_weights = torch.matmul(head_weights, q_h).squeeze(3) + # src_weights = src_weights.permute(1, 0, 2) + # ptr_weights = weights+src_weights + # + # else: + # ptr_weights = weights + + ptr_weights_masked = ptr_weights.clone().detach() + if src_token_mask is not None: + # (tgt_action_num, batch_size, src_sent_len) + src_token_mask=1-src_token_mask.byte() + src_token_mask = src_token_mask.unsqueeze(0).expand_as(ptr_weights) + # ptr_weights.data.masked_fill_(src_token_mask, -float('inf')) + ptr_weights_masked.data.masked_fill_(src_token_mask, -float('inf')) + + # ptr_weights =self.activate(ptr_weights) + + return ptr_weights,ptr_weights_masked + + +class Encoder(nn.Module): + def __init__(self, opt, config, bert_frozen=False): + super(Encoder, self).__init__() + self.opt =opt + self.hidden_size = opt.rnn_hidden_size + self.emb_size = opt.encoder_emb_size + self.rnn_input_size = self.emb_size+opt.token_type_emb_size + # self.word_number = config.word_number + # self.lemma_number = config.lemma_number + self.maxlen = opt.maxlen + + bert_config = BertConfig.from_json_file(PLMConfig.CONFIG_PATH) + self.bert = BertModel.from_pretrained(PLMConfig.MODEL_PATH) + + if bert_frozen: + print('Bert grad is false') + for param in self.bert.parameters(): + param.requires_grad = False + + self.bert_hidden_size = bert_config.hidden_size + # self.bert_dropout = nn.Dropout(bert_config.hidden_dropout_prob) + + + self.dropout = 0.2 + # self.word_embedding = word_embedding + # self.lemma_embedding = lemma_embedding + # self.pos_embedding = nn.Embedding(config.pos_number, opt.pos_emb_size) + # self.rel_embedding = nn.Embedding(config.rel_number,opt.rel_emb_size) + self.token_type_embedding = nn.Embedding(3, opt.token_type_emb_size) + self.cell_name = opt.cell_name + + # self.embedded_linear = nn.Linear(self.emb_size*2+opt.pos_emb_size+opt.token_type_emb_size+opt.sent_emb_size, + # self.rnn_input_size) + # self.syntax_embedded_linear = nn.Linear(self.emb_size*2+opt.rel_emb_size+opt.token_type_emb_size, + # self.rnn_input_size) + # self.output_combine_linear = nn.Linear(4*self.hidden_size, 2*self.hidden_size) + + self.target_linear = nn.Linear(2*self.hidden_size, 2*self.hidden_size) + + self.relu_linear = Relu_Linear(4*self.hidden_size+self.rnn_input_size, opt.decoder_emb_size) + + if self.cell_name == 'gru': + self.rnn = nn.GRU(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + dropout=self.dropout,bidirectional=True, batch_first=True) + # self.syntax_rnn = nn.GRU(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + # dropout=self.dropout,bidirectional=True, batch_first=True) + elif self.cell_name == 'lstm': + self.rnn = nn.LSTM(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + dropout=self.dropout,bidirectional=True, batch_first=True) + # self.syntax_rnn = nn.LSTM(self.rnn_input_size, self.hidden_size,num_layers=self.opt.num_layers, + # dropout=self.dropout,bidirectional=True, batch_first=True) + else: + print('cell_name error') + + def forward(self, word_input: torch.Tensor, lengths:torch.Tensor, frame_idx, token_type_ids=None, attention_mask=None,target_mask_ids=None): + + # word_embedded = self.word_embedding(word_input) + # lemma_embedded = self.lemma_embedding(lemma_input) + # pos_embedded = 
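+        # NOTE (editor's sketch): batched_index_select and generate_perm_inv also live in
+        # the missing utils.py. Versions consistent with how they are used here:
+        #     def batched_index_select(target, indices):        # target (B, L, H)
+        #         idx = indices.view(target.size(0), -1)        # (B, K)
+        #         idx = idx.unsqueeze(-1).expand(-1, -1, target.size(2))
+        #         return target.gather(1, idx)                  # (B, K, H)
+        #     def generate_perm_inv(perm):                      # inverse permutation
+        #         inv = np.empty_like(perm)
+        #         inv[perm] = np.arange(len(perm))
+        #         return inv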
self.pos_embedding(pos_input) + # head_embedded = self.word_embedding(head_input) + # rel_embedded = self.rel_embedding(rel_input) + # type_ids =torch.add(token_type_ids, target_mask_ids) + token_type_embedded = self.token_type_embedding(token_type_ids) + #print(token_type_embedded) + # print(token_type_ids.size()) + # print(target_mask_ids.size()) + # print(token_type_embedded.size()) + hidden_state, cls = self.bert(word_input, token_type_ids, \ + attention_mask=attention_mask,\ + output_all_encoded_layers=False) + + embedded = torch.cat([hidden_state.squeeze(),token_type_embedded], dim=-1) + # sent_embedded = self.cnn(embedded) + + # sent_embedded = sent_embedded.expand([self.opt.maxlen, self.opt.batch_size, self.opt.sent_emb_size]).permute(1, 0, 2) + #embedded = torch.cat([embedded,sent_embedded], dim=-1) + #embedded = self.embedded_linear(embedded) + + # syntax embedding + # syntax_embedded = torch.cat([word_embedded,head_embedded,rel_embedded,token_type_ids],dim=-1) + # syntax_embedded = self.syntax_embedded_linear(syntax_embedded) + + lengths=lengths.squeeze() + # sorted before pack + l = lengths.cpu().numpy() + perm_idx = np.argsort(-l) + perm_idx_inv = generate_perm_inv(perm_idx) + + embedded = embedded[perm_idx] + # syntax_embedded = syntax_embedded[perm_idx] + + if lengths is not None: + rnn_input = nn.utils.rnn.pack_padded_sequence(embedded, lengths=lengths[perm_idx], + batch_first=True) + # syntax_rnn_input = nn.utils.rnn.pack_padded_sequence(syntax_embedded, lengths=lengths[perm_idx], + # batch_first=True) + output, hidden = self.rnn(rnn_input) + #syntax_output, syntax_hidden = self.rnn(syntax_rnn_input) + + if lengths is not None: + output, _ = nn.utils.rnn.pad_packed_sequence(output, total_length=self.maxlen, batch_first=True) + # syntax_output, _ = nn.utils.rnn.pad_packed_sequence(syntax_output, total_length=self.maxlen, batch_first=True) + + + # print(output.size()) + # print(hidden.size()) + + output = output[perm_idx_inv] + # syntax_output = syntax_output[perm_idx_inv] + + if self.cell_name == 'gru': + hidden = hidden[:, perm_idx_inv] + hidden = (lambda a: sum(a)/(2*self.opt.num_layers))(torch.split(hidden, 1, dim=0)) + + # syntax_hidden = syntax_hidden[:, perm_idx_inv] + # syntax_hidden = (lambda a: sum(a)/(2*self.opt.num_layers))(torch.split(syntax_hidden, 1, dim=0)) + # hidden = (hidden + syntax_hidden) / 2 + + elif self.cell_name == 'lstm': + hn0 = hidden[0][:, perm_idx_inv] + hn1 = hidden[1][:, perm_idx_inv] + # sy_hn0 = syntax_hidden[0][:, perm_idx_inv] + # sy_hn1 = syntax_hidden[1][:, perm_idx_inv] + hn = tuple([hn0,hn1]) + hidden = tuple(map(lambda state: sum(torch.split(state, 1, dim=0))/(2*self.opt.num_layers), hn)) + + + target_state_head = batched_index_select(target=output, indices=frame_idx[0]) + target_state_tail = batched_index_select(target=output, indices=frame_idx[1]) + target_state = (target_state_head + target_state_tail) / 2 + target_state = self.target_linear(target_state) + + target_emb_head = batched_index_select(target=embedded, indices=frame_idx[0]) + target_emb_tail = batched_index_select(target=embedded, indices=frame_idx[1]) + target_emb = (target_emb_head + target_emb_tail) / 2 + + attentional_target_state = type_attention(attention_mask=target_mask_ids, hidden_state=output, + target_state=target_state) + + + target_state =torch.cat([target_state.squeeze(),attentional_target_state.squeeze(), target_emb.squeeze()], dim=-1) + target = self.relu_linear(target_state) + + # print(output.size()) + return output, hidden, target + + +class 
Decoder(nn.Module): + def __init__(self, opt, embedding_frozen=False): + super(Decoder, self).__init__() + + # rnn _init_ + self.opt = opt + self.cell_name = opt.cell_name + self.emb_size = opt.decoder_emb_size + self.hidden_size = opt.decoder_hidden_size + self.encoder_hidden_size = opt.rnn_hidden_size + + # decoder _init_ + self.decodelen = opt.fe_padding_num+1 + self.frame_embedding = nn.Embedding(opt.frame_number+1, self.emb_size) + self.frame_fc_layer =Mlp(self.emb_size, opt.frame_number+1) + self.role_embedding = nn.Embedding(opt.role_number+1, self.emb_size) + self.role_feature_layer = nn.Linear(2*self.emb_size, self.emb_size) + self.role_fc_layer = nn.Linear(self.hidden_size+self.emb_size, opt.role_number+1) + + self.head_fc_layer = Mlp(self.hidden_size+self.emb_size, self.hidden_size) + self.tail_fc_layer = Mlp(self.hidden_size+self.emb_size, self.hidden_size) + + self.span_fc_layer = Mlp(4 * self.encoder_hidden_size + self.emb_size, self.emb_size) + + self.next_input_fc_layer = Mlp(self.hidden_size+self.emb_size, self.emb_size) + + if embedding_frozen is True: + for param in self.frame_embedding.parameters(): + param.requires_grad = False + for param in self.role_embedding.parameters(): + param.requires_grad = False + + if self.cell_name == 'gru': + self.frame_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) + self.ent_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) + self.role_rnn = nn.GRU(self.emb_size, self.hidden_size, batch_first=True) + if self.cell_name == 'lstm': + self.frame_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) + self.ent_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) + self.role_rnn = nn.LSTM(self.emb_size, self.hidden_size, batch_first=True) + + # pointer _init_ + self.ent_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) + self.head_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) + self.tail_pointer = PointerNet(query_vec_size=self.hidden_size, src_encoding_size=2*self.encoder_hidden_size) + + def forward(self, encoder_output: torch.Tensor, encoder_state: torch.Tensor, target_state: torch.Tensor, + attention_mask: torch.Tensor, fe_mask=None, lu_mask=None): + pred_frame_list = [] + pred_head_list = [] + pred_tail_list = [] + pred_role_list = [] + + pred_frame_action = [] + + pred_head_action = [] + pred_tail_action = [] + pred_role_action = [] + + frame_decoder_state = encoder_state + role_decoder_state = encoder_state + + input = target_state + # print(input.size()) + + span_mask = attention_mask.clone() + for t in range(self.decodelen): + # frame pred + output, frame_decoder_state = self.decode_step(self.frame_rnn, input=input, + decoder_state=frame_decoder_state) + + pred_frame_weight = self.frame_fc_layer(target_state.squeeze()) + pred_frame_weight_masked = pred_frame_weight.clone().detach() + + if lu_mask is not None: + LU_mask = 1-lu_mask + pred_frame_weight_masked.data.masked_fill_(LU_mask.byte(), -float('inf')) + + pred_frame_indices = torch.argmax(pred_frame_weight_masked.squeeze(), dim=-1).squeeze() + + pred_frame_list.append(pred_frame_weight) + pred_frame_action.append(pred_frame_indices) + + frame_emb = self.frame_embedding(pred_frame_indices) + + head_input = self.head_fc_layer(torch.cat([output.squeeze(), frame_emb], dim=-1)) + tail_input = self.tail_fc_layer(torch.cat([output.squeeze(), frame_emb], dim=-1)) + + head_pointer_weight, head_pointer_weight_masked = 
self.head_pointer(src_encodings=encoder_output, + src_token_mask=span_mask, + query_vec=head_input.view(1, self.opt.batch_size, -1)) + + head_indices = torch.argmax(head_pointer_weight_masked.squeeze(), dim=-1).squeeze() + head_target = batched_index_select(target=encoder_output, indices=head_indices.squeeze()) + head_mask = head_mask_update(span_mask, head_indices=head_indices, max_len=self.opt.maxlen) + + tail_pointer_weight, tail_pointer_weight_masked = self.tail_pointer(src_encodings=encoder_output, + src_token_mask=head_mask, + query_vec=tail_input.view(1, self.opt.batch_size,-1), + head_vec=head_target.view(1,self.opt.batch_size,-1)) + + tail_indices = torch.argmax(tail_pointer_weight_masked.squeeze(), dim=-1).squeeze() + # tail_target = batched_index_select(target=bert_hidden_state,indices=tail_indices.squeeze()) + + span_mask = span_mask_update(attention_mask=span_mask, head_indices=head_indices, + tail_indices=tail_indices, max_len=self.opt.maxlen) + + pred_head_list.append(head_pointer_weight) + pred_tail_list.append(tail_pointer_weight) + pred_head_action.append(head_indices) + pred_tail_action.append(tail_indices) + + # role pred + + # print(ent_target.size()) + # print(head_target.size()) + # print(tail_target.size()) + # print(bert_hidden_state.size()) + # print(tail_pointer_weight.size()) + # print(tail_indices.size()) + # print(output.size()) + + # next step + # head_target = batched_index_select(target=bert_hidden_state, indices=head_indices.squeeze()) + + tail_target = batched_index_select(target=encoder_output, indices=tail_indices.squeeze()) + + # head_context =local_attention(attention_mask=attention_mask, hidden_state=encoder_output, + # frame_idx=(head_indices, tail_indices), target_state=head_target, + # window_size=0, max_len=self.opt.maxlen) + # + # tail_context =local_attention(attention_mask=attention_mask, hidden_state=encoder_output, + # frame_idx=(head_indices, tail_indices), target_state=tail_target, + # window_size=0, max_len=self.opt.maxlen) + + span_input = self.span_fc_layer(torch.cat([head_target+tail_target, head_target-tail_target, frame_emb], dim=-1)).unsqueeze(1) + + output,role_decoder_state = self.decode_step(self.role_rnn, input=span_input, + decoder_state=role_decoder_state) + role_target = self.role_fc_layer(torch.cat([span_input,output],dim=-1)) + role_target_masked = role_target.squeeze().clone().detach() + if fe_mask is not None : + FE_mask = 1-fe_mask + # print(FE_mask.size()) + role_target_masked.data.masked_fill_(FE_mask.byte(), -float('inf')) + + role_indices = torch.argmax(role_target_masked.squeeze(), dim=-1).squeeze() + role_emb = self.role_embedding(role_indices) + + pred_role_list.append(role_target) + pred_role_action.append(role_indices) + + # next step + next_input =torch.cat([output, role_emb.unsqueeze(1)], dim=-1) + input = self.next_input_fc_layer(next_input) + + + #return + return pred_frame_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action,\ + pred_head_action, pred_tail_action, pred_role_action + + def decode_step(self, rnn_cell: nn.modules, input: torch.Tensor, decoder_state: torch.Tensor): + + output, state = rnn_cell(input.view(-1, 1, self.emb_size), decoder_state) + + return output, state + + +class Model(nn.Module): + def __init__(self, opt, config): + super(Model, self).__init__() + # self.word_vectors = config.word_vectors + # self.lemma_vectors = config.lemma_vectors + # self.word_embedding = nn.Embedding(config.word_number+1, opt.encoder_emb_size) + # self.lemma_embedding = 
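+        # NOTE (editor's gloss): decoding is greedy and lexicon-constrained. At each
+        # step the frame/role logits are masked before argmax, in the pattern
+        #     masked = logits.masked_fill((1 - lu_mask).byte(), -float('inf'))
+        #     pred = masked.argmax(-1)
+        # so a target can only evoke frames licensed by its lexical unit, and spans
+        # already consumed are removed from the head/tail pointers via
+        # head_mask_update / span_mask_update.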
nn.Embedding(config.lemma_number+1, opt.encoder_emb_size) + # + # if load_emb: + # self.load_pretrain_emb() + + self.encoder = Encoder(opt, config) + self.decoder = Decoder(opt) + + # def load_pretrain_emb(self): + # self.word_embedding.weight.data.copy_(torch.from_numpy(self.word_vectors)) + # self.lemma_embedding.weight.data.copy_(torch.from_numpy(self.lemma_vectors)) + + def forward(self, word_ids, lengths, frame_idx, fe_mask=None, lu_mask=None, + frame_len=None, token_type_ids=None, attention_mask=None,target_mask_ids=None): + encoder_output, encoder_state, target_state = self.encoder(word_input=word_ids, lengths=lengths, + frame_idx=frame_idx, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + target_mask_ids=target_mask_ids) + + pred_frame_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action, \ + pred_head_action, pred_tail_action, pred_role_action = self.decoder(encoder_output=encoder_output, + encoder_state=encoder_state, + target_state=target_state, + attention_mask=attention_mask, + fe_mask=fe_mask, lu_mask=lu_mask) + + # return pred_frame_list, pred_ent_list, pred_head_list, pred_tail_list, pred_role_list, pred_frame_action, \ + # pred_ent_action, pred_head_action, pred_tail_action, pred_role_action + + return { + 'pred_frame_list' : pred_frame_list, + 'pred_head_list' : pred_head_list, + 'pred_tail_list' : pred_tail_list, + 'pred_role_list' : pred_role_list, + 'pred_frame_action' : pred_frame_action, + 'pred_head_action' : pred_head_action, + 'pred_tail_action' : pred_tail_action, + 'pred_role_action' : pred_role_action + } + + +def head_mask_update(attention_mask: torch.Tensor, head_indices: torch.Tensor, max_len): + indices=head_indices + indices_mask=1-get_mask_from_index(indices, max_len) + mask = torch.mul(attention_mask, indices_mask.long()) + + return mask + + +def span_mask_update(attention_mask: torch.Tensor, head_indices: torch.Tensor, tail_indices: torch.Tensor, max_len): + tail = tail_indices + 1 + head_indices_mask = get_mask_from_index(head_indices, max_len) + tail_indices_mask = get_mask_from_index(tail, max_len) + span_indices_mask = tail_indices_mask - head_indices_mask + span_indices_mask = 1 - span_indices_mask + mask = torch.mul(attention_mask, span_indices_mask.long()) + + return mask + + +def local_attention(attention_mask: torch.Tensor, hidden_state: torch.Tensor, frame_idx, + target_state: torch.Tensor, window_size: int, max_len): + + q = target_state.squeeze().unsqueeze(2) + context_att = torch.bmm(hidden_state, q).squeeze() + head = frame_idx[0]-window_size + tail = frame_idx[1]+window_size + mask = span_mask_update(attention_mask=attention_mask, head_indices=head.squeeze(), + tail_indices=tail.squeeze(), max_len=max_len) + context_att = context_att.masked_fill_(mask.byte(), -float('inf')) + context_att = F.softmax(context_att, dim=-1) + attentional_hidden_state = torch.bmm(hidden_state.permute(0, 2, 1), context_att.unsqueeze(2)).squeeze() + + return attentional_hidden_state + +def type_attention(attention_mask: torch.Tensor, hidden_state: torch.Tensor, + target_state: torch.Tensor): + + q = target_state.squeeze().unsqueeze(2) + context_att = torch.bmm(hidden_state, q).squeeze() + + mask = 1-attention_mask + + context_att = context_att.masked_fill_(mask.byte(), -float('inf')) + context_att = F.softmax(context_att, dim=-1) + attentional_hidden_state = torch.bmm(hidden_state.permute(0, 2, 1), context_att.unsqueeze(2)).squeeze() + + return attentional_hidden_state + + diff --git a/fn-model_bert/train_syntax36_bert.py 
b/fn-model_bert/train_syntax36_bert.py
new file mode 100644
index 0000000..96908c5
--- /dev/null
+++ b/fn-model_bert/train_syntax36_bert.py
@@ -0,0 +1,256 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import os
+import multiprocessing as mp
+
+from dataset_bert import FrameNetDataset
+from torch.utils.data import DataLoader
+from utils import get_mask_from_index, seed_everything
+from evaluate10 import Eval
+from config_bert import get_opt
+from data_process_bert import DataConfig
+from model_syntax36_with_bert import Model
+
+from pytorch_pretrained_bert import BertTokenizer, BertModel, BertConfig, BertAdam
+
+
+def evaluate(opt, model, dataset, best_metrics=None, show_case=False):
+    model.eval()
+    print('begin eval')
+    evaler = Eval(opt)
+    with torch.no_grad():
+        test_dl = DataLoader(
+            dataset,
+            batch_size=opt.batch_size,
+            shuffle=True,
+            num_workers=0
+        )
+
+        for batch in test_dl:
+            word_ids, lengths, attention_mask, target_head, target_tail, \
+            target_type, fe_head, fe_tail, fe_type, fe_cnt, \
+            fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids, sent_length, target_mask_ids = batch
+
+            return_dic = model(word_ids=word_ids, lengths=lengths,
+                               frame_idx=(target_head, target_tail),
+                               token_type_ids=token_type_ids,
+                               attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask,
+                               target_mask_ids=target_mask_ids)
+
+            evaler.metrics(batch_size=opt.batch_size, fe_cnt=fe_cnt, gold_fe_type=fe_type, gold_fe_head=fe_head,
+                           gold_fe_tail=fe_tail, gold_frame_type=target_type,
+                           pred_fe_type=return_dic['pred_role_action'],
+                           pred_fe_head=return_dic['pred_head_action'],
+                           pred_fe_tail=return_dic['pred_tail_action'],
+                           pred_frame_type=return_dic['pred_frame_action'],
+                           fe_coretype=dataset.fe_coretype_table, sent_length=sent_length,
+                           lu_mask=lu_mask)
+
+            if show_case:
+                print('gold_fe_label = ', fe_type)
+                print('pred_fe_label = ', return_dic['pred_role_action'])
+                print('gold_head_label = ', fe_head)
+                print('pred_head_label = ', return_dic['pred_head_action'])
+                print('gold_tail_label = ', fe_tail)
+                print('pred_tail_label = ', return_dic['pred_tail_action'])
+
+    metrics = evaler.calculate()
+
+    if best_metrics:
+        if metrics[-1] > best_metrics:
+            best_metrics = metrics[-1]
+            torch.save(model.state_dict(), opt.save_model_path)
+
+        return best_metrics
+
+
+if __name__ == '__main__':
+
+    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
+
+    mp.set_start_method('spawn')
+
+    opt = get_opt()
+    config = DataConfig(opt)
+
+    if torch.cuda.is_available():
+        device = torch.device(opt.cuda)
+    else:
+        device = torch.device('cpu')
+
+    seed_everything(1116)
+
+    epochs = opt.epochs
+    model = Model(opt, config)
+    model.to(device)
+
+    clip_grad = 1.0
+    # fraction of the training steps used for linear LR warmup
+    warmup_proportion = 0.2
+    # learning rate
+    learning_rate = opt.lr
+
+    best_metrics = 0.0
+
+    pretrain_dataset = FrameNetDataset(opt, config, config.exemplar_instance_dic, device)
+    train_dataset = FrameNetDataset(opt, config, config.train_instance_dic, device)
+    dev_dataset = FrameNetDataset(opt, config, config.dev_instance_dic, device)
+    test_dataset = FrameNetDataset(opt, config, config.test_instance_dic, device)
+
+    num_train_steps = int((len(train_dataset) / opt.batch_size) * opt.epochs)
+    # prepare optimizer
+    param_optimizer = list(model.named_parameters())
+    # bias and LayerNorm parameters must not be weight-decayed,
+    # otherwise LayerNorm may not behave as intended
+    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
+
optimizer_grouped_parameters = [ + {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01}, + {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0}] + + optimizer = BertAdam(optimizer_grouped_parameters, + lr=learning_rate, + warmup=warmup_proportion, + t_total=num_train_steps, + max_grad_norm=clip_grad) + + frame_criterion =nn.CrossEntropyLoss() + head_criterion = nn.CrossEntropyLoss() + tail_criterion = nn.CrossEntropyLoss() + fe_type_criterion = nn.CrossEntropyLoss() + + if os.path.exists(opt.save_model_path) is True: + model.load_state_dict(torch.load(opt.save_model_path)) + + best_metrics = -1 + + if opt.mode == 'train': + for epoch in range(1, epochs): + # scheduler.step() + train_dl = DataLoader( + train_dataset, + batch_size=opt.batch_size, + shuffle=True, + num_workers=0 + ) + model.train() + print('==========epochs= ' + str(epoch)) + step = 0 + sum_loss = 0 + # best_metrics = -1 + cnt = 0 + for batch in train_dl: + optimizer.zero_grad() + loss = 0 + word_ids, lengths, attention_mask, target_head, target_tail, \ + target_type, fe_head, fe_tail, fe_type, fe_cnt, \ + fe_cnt_with_padding, fe_mask, lu_mask, token_type_ids,sent_length,target_mask_ids = batch + + return_dic = model(word_ids=word_ids + ,lengths=lengths, + frame_idx=(target_head, target_tail), + token_type_ids=token_type_ids, + attention_mask=attention_mask, fe_mask=fe_mask, lu_mask=lu_mask,target_mask_ids=target_mask_ids) + # print(return_dic) + + frame_loss = 0 + head_loss = 0 + tail_loss = 0 + type_loss = 0 + + for batch_index in range(opt.batch_size): + pred_frame_first = return_dic['pred_frame_list'][fe_cnt[batch_index]][batch_index].unsqueeze(0) + pred_frame_last = return_dic['pred_frame_list'][0][batch_index].unsqueeze(0) + + pred_frame_label = pred_frame_last + + gold_frame_label = target_type[batch_index] + # print(gold_frame_label.size()) + # print(pred_frame_label.size()) + # print(fe_head) + frame_loss += frame_criterion(pred_frame_label, gold_frame_label) + + + for fe_index in range(opt.fe_padding_num): + + # print(fe_cnt[batch_index]) + + + pred_type_label = return_dic['pred_role_list'][fe_index].squeeze() + pred_type_label = pred_type_label[batch_index].unsqueeze(0) + + gold_type_label = fe_type[batch_index][fe_index].unsqueeze(0) + type_loss += fe_type_criterion(pred_type_label, gold_type_label) + + + if fe_index >= fe_cnt[batch_index]: + break + + pred_head_label = return_dic['pred_head_list'][fe_index].squeeze() + pred_head_label = pred_head_label[batch_index].unsqueeze(0) + + gold_head_label = fe_head[batch_index][fe_index].unsqueeze(0) + # print(gold_head_label.size()) + # print(pred_head_label.size()) + head_loss += head_criterion(pred_head_label, gold_head_label) + + pred_tail_label = return_dic['pred_tail_list'][fe_index].squeeze() + pred_tail_label = pred_tail_label[batch_index].unsqueeze(0) + + gold_tail_label = fe_tail[batch_index][fe_index].unsqueeze(0) + tail_loss += tail_criterion(pred_tail_label, gold_tail_label) + + + # print(fe_cnt[batch_index]) + # head_loss /= int(fe_cnt[batch_index]) + # tail_loss /= int(fe_cnt[batch_index]) + # type_loss /= int(fe_cnt[batch_index]+1) + # + # head_loss_total+=head_loss + # tail_loss_total+=tail_loss + # type_loss_total+=type_loss + + loss = (0.1 * frame_loss + 0.3 * type_loss + 0.3 * head_loss + 0.3 * tail_loss) / (opt.batch_size) + # loss = (0.3 * head_loss + 0.3 * tail_loss) / (opt.b0.3 * atch_size) + loss.backward() + optimizer.step() + # loss+=frame_loss() 
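+                # NOTE (editor's gloss): the total loss weights the frame loss at 0.1 and
+                # the role-type / head / tail losses at 0.3 each, averaged over the batch;
+                # the garbled commented line above should read
+                #     loss = (0.3 * head_loss + 0.3 * tail_loss) / opt.batch_size
+                # and the block commented out above it sketches a per-instance variant
+                # that divides each FE loss by the gold FE count (fe_cnt) instead.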
+ step += 1 + if step % 20 == 0: + print(" | batch loss: %.6f step = %d" % (loss.item(), step)) + # print('gold_frame_label = ',target_type) + # print('pred_frame_label = ',return_dic['pred_frame_action']) + + for index in range(len(target_type)): + if target_type[index] == return_dic['pred_frame_action'][0][index]: + cnt += 1 + # print('gold_fe_label = ', fe_type) + # print('pred_fe_label = ', return_dic['pred_role_action']) + # print('gold_head_label = ', fe_head) + # print('pred_head_label = ', return_dic['pred_head_action']) + # print('gold_tail_label = ', fe_tail) + # print('pred_tail_label = ', return_dic['pred_tail_action']) + sum_loss += loss.item() + + print('| epoch %d avg loss = %.6f' % (epoch, sum_loss / step)) + print('| epoch %d prec = %.6f' % (epoch, cnt / (opt.batch_size * step))) + + best_metrics=evaluate(opt,model,dev_dataset,best_metrics) + + + else: + evaluate(opt,model,test_dataset,show_case=True) + + + +
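
For reference, a minimal test-mode driver that mirrors the `__main__` block above; all module and class names come from this patch, and the checkpoint is whatever --save_model_path points to (a sketch, not part of the patch):

import torch

from config_bert import get_opt
from data_process_bert import DataConfig
from dataset_bert import FrameNetDataset
from model_syntax36_with_bert import Model
from train_syntax36_bert import evaluate

opt = get_opt()          # run with --mode test
config = DataConfig(opt)
device = torch.device(opt.cuda if torch.cuda.is_available() else 'cpu')

model = Model(opt, config).to(device)
model.load_state_dict(torch.load(opt.save_model_path, map_location=device))

test_dataset = FrameNetDataset(opt, config, config.test_instance_dic, device)
evaluate(opt, model, test_dataset, show_case=True)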