config.py (13 changes: 8 additions & 5 deletions)
@@ -17,6 +17,7 @@ def get_opt():
     # dataset locations
     parser.add_argument('--data_path', type=str, default='parsed-v1.5/')
     parser.add_argument('--emb_file_path', type=str, default='./glove.6B/glove.6B.200d.txt')
+    parser.add_argument('--exemplar_instance_path', type=str, default='./exemplar_instance_dic.npy')
     parser.add_argument('--train_instance_path', type=str, default='./train_instance_dic.npy')
     parser.add_argument('--dev_instance_path', type=str, default='./dev_instance_dic.npy')
     parser.add_argument('--test_instance_path', type=str, default='./test_instance_dic.npy')
@@ -25,21 +26,23 @@ def get_opt():
     parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
     parser.add_argument('--model_name', type=str, default='train_model')
     parser.add_argument('--pretrain_model', type=str, default='')
-    parser.add_argument('--save_model_path', type=str, default='./models_ft/model_syntax.bin')
+    parser.add_argument('--save_model_path', type=str, default='./models_ft/wo_jointarg.bin')
 
 
     # training settings
-    parser.add_argument('--lr', type=float, default='0.0001')
+    parser.add_argument('--lr', type=float, default='0.00006')
     parser.add_argument('--weight_decay', type=float, default=0.0001)
     parser.add_argument('--batch_size', type=int, default=2)
-    parser.add_argument('--epochs', type=int, default=300)
+    parser.add_argument('--epochs', type=int, default=100)
     parser.add_argument('--save_model_freq', type=int, default=1) # interval between model saves, in epochs
     parser.add_argument('--cuda', type=str, default="cuda:0")
-    parser.add_argument('--mode', type=str, default="train")
+    parser.add_argument('--mode', type=str, default="test")
 
     # some model settings
     parser.add_argument('--maxlen',type=int, default=256)
     parser.add_argument('--dropout', type=float, default=0.5)
     parser.add_argument('--cell_name', type=str, default='lstm')
+    #change
+    parser.add_argument('--rnn_hidden_size',type=int, default=256)
     parser.add_argument('--rnn_emb_size', type=int, default=400)
     parser.add_argument('--encoder_emb_size',type=int, default=200)
@@ -56,7 +59,7 @@ def get_opt():
     parser.add_argument('--role_number',type=int, default=9634)
     parser.add_argument('--fe_padding_num',type=int, default=5)
     parser.add_argument('--window_size', type=int, default=3)
-    parser.add_argument('--num_layers', type=int, default=4)
+    parser.add_argument('--num_layers', type=int, default=2)
 
     return parser.parse_args()
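One detail worth knowing about the defaults above: `--lr` keeps its default as the string `'0.00006'`, and argparse only runs the `type` converter on a default when that default is a string, so the parsed option still comes out as a float. A minimal standalone check (not part of the PR):

```python
import argparse

# Same pattern as in config.py: a string default combined with type=float.
parser = argparse.ArgumentParser()
parser.add_argument('--lr', type=float, default='0.00006')

opt = parser.parse_args([])   # parse with no CLI arguments
print(opt.lr, type(opt.lr))   # 6e-05 <class 'float'>
```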
data_process2.py → data_process4.py (5 changes: 3 additions & 2 deletions)
@@ -153,6 +153,7 @@ def __init__(self,opt):
         self.maxlen = opt.maxlen
 
         if opt.load_instance_dic:
+            self.exemplar_instance_dic = np.load(opt.exemplar_instance_path, allow_pickle=True).item()
             self.train_instance_dic = np.load(opt.train_instance_path, allow_pickle=True).item()
             self.dev_instance_dic = np.load(opt.dev_instance_path, allow_pickle=True).item()
             self.test_instance_dic = np.load(opt.test_instance_path, allow_pickle=True).item()
@@ -182,7 +183,7 @@ def __init__(self,opt):
         self.pos_number = 0
         self.rel_number = 0
 
-        #self.build_word_index(self.exemplar_instance_dic)
+        self.build_word_index(self.exemplar_instance_dic)
         self.build_word_index(self.train_instance_dic)
         self.build_word_index(self.dev_instance_dic)
         self.build_word_index(self.test_instance_dic)
@@ -357,7 +358,7 @@ def get_lu_list(path, lu_num, fe_num, frame_id_to_label, fe_mask_list, file='LU.
             lu_name_to_id[(lu_name, lu_dt['FrameID'][idx])] = lu_dt['ID'][idx]
 
     for key in lu_list.keys():
-        #lu_list[key]['lu_mask'][lu_num] = 1
+        # lu_list[key]['lu_mask'][lu_num] = 1
         lu_list[key]['fe_mask'][fe_num] = 1
 
     return lu_list, lu_id_to_name, lu_name_to_id
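The `np.load(...).item()` calls above rely on NumPy pickling: a plain Python dict saved with `np.save` comes back as a 0-d object array, and `.item()` unwraps the original dict. A small round-trip sketch (the file name and dict contents are illustrative, not taken from the PR):

```python
import numpy as np

# A dict shaped loosely like the *_instance_dic structures used above.
instance_dic = {('sent_0', 0): {'word_list': ['he', 'ran'], 'target_type': 'Self_motion'}}

np.save('demo_instance_dic.npy', instance_dic)   # stored as a 0-d object array
loaded = np.load('demo_instance_dic.npy', allow_pickle=True).item()  # unwrap the dict

assert loaded == instance_dic
```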
dataset2.py → dataset5.py (60 changes: 51 additions & 9 deletions)
@@ -4,7 +4,7 @@
 
 from torch.utils.data import Dataset
 from config import get_opt
-from data_process2 import get_frame_tabel, get_fe_tabel, get_fe_list, get_lu_list,DataConfig
+from data_process4 import get_frame_tabel, get_fe_tabel, get_fe_list, get_lu_list,DataConfig
 from utils import get_mask_from_index
 
 class FrameNetDataset(Dataset):
@@ -56,21 +56,40 @@ def __init__(self, opt, config, data_dic, device):
         self.fe_mask = []
         self.lu_mask = []
         self.token_type_ids = []
+        self.target_mask_ids = []
 
         self.device = device
         self.oov_frame = 0
         self.long_span = 0
         self.error_span = 0
         self.fe_coretype_table = {}
+        self.target_mask = {}
 
         for idx in self.fe_id_to_type.keys():
             if self.fe_id_to_type[idx] == 'Core':
-                self.fe_coretype_table[idx] = 1
+                self.fe_coretype_table[self.fe_id_to_label[idx]] = 1
             else:
-                self.fe_coretype_table[idx] = 0
+                self.fe_coretype_table[self.fe_id_to_label[idx]] = 0
 
+        for key in self.data_dic.keys():
+            self.build_target_mask(key,opt.maxlen)
+
         for key in self.data_dic.keys():
             self.pre_process(key, opt)
 
+        self.pad_dic_cnt = self.dataset_len % opt.batch_size
+
+        for idx,key in enumerate(self.data_dic.keys()):
+            if idx >= self.pad_dic_cnt:
+                break
+            self.pre_process(key, opt,filter=False)
+
+        self.dataset_len+=self.pad_dic_cnt
+
         print('load data finish')
         print('oov frame = ', self.oov_frame)
         print('long_span = ', self.long_span)
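The new padding pass above duplicates the first `pad_dic_cnt` examples so the dataset length lines up with the batch size. With the default `batch_size=2` the remainder trick works out (the remainder is 0 or 1, and adding it makes the length even), but for a general batch size the usual round-up count is different; a hypothetical helper, not part of the PR:

```python
# Hypothetical helper: how many duplicates round a dataset up to whole batches.
# The diff above adds dataset_len % batch_size examples instead, which only
# yields full batches when 2 * (dataset_len % batch_size) % batch_size == 0.
def pad_count(dataset_len: int, batch_size: int) -> int:
    return (batch_size - dataset_len % batch_size) % batch_size

assert (10 + pad_count(10, 4)) % 4 == 0   # 10 -> +2 -> 12, i.e. 3 full batches
assert pad_count(10, 2) == 0              # already a multiple of batch_size 2
```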
@@ -99,19 +118,24 @@ def __getitem__(self, item):
         fe_mask = torch.Tensor(self.fe_mask[item]).long().to(self.device)
         lu_mask = torch.Tensor(self.lu_mask[item]).long().to(self.device)
         token_type_ids = torch.Tensor(self.token_type_ids[item]).long().to(self.device)
+        sent_length = torch.Tensor([self.sent_length[item]]).long().to(self.device)
+        target_mask_ids = torch.Tensor(self.target_mask_ids[item]).long().to(self.device)
         # print(fe_cnt)
 
         return (word_ids, lemma_ids, pos_ids,head_ids, rel_ids, lengths, mask, target_head, target_tail, target_type,
                 fe_head, fe_tail, fe_type, fe_cnt, fe_cnt_with_padding,
-                fe_mask, lu_mask, token_type_ids)
+                fe_mask, lu_mask, token_type_ids,sent_length,target_mask_ids)
 
-    def pre_process(self, key, opt):
+    def pre_process(self, key, opt,filter=True):
         if self.data_dic[key]['target_type'] not in self.frame_name_to_label:
             self.oov_frame += 1
             self.dataset_len -= 1
             return
 
         target_id = self.frame_name_to_id[self.data_dic[key]['target_type']]
-        self.long_span += self.remove_error_span(key, self.data_dic[key]['span_start'],
+        if filter:
+            self.long_span += self.remove_error_span(key, self.data_dic[key]['span_start'],
                 self.data_dic[key]['span_end'], self.data_dic[key]['span_type'], target_id, 20)
 
         word_ids = [self.word_index[word] for word in self.data_dic[key]['word_list']]
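As context for the tensor conversions in `__getitem__` above: `torch.Tensor(x)` first builds a float32 tensor which `.long()` then casts, while `torch.tensor` can build the integer tensor on the target device in one step. This is a general PyTorch idiom, not a change made by this PR:

```python
import torch

device = torch.device('cpu')   # stand-in for opt.cuda
ids = [3, 1, 4, 1, 5]

# One-step equivalent of torch.Tensor(ids).long().to(device):
t = torch.tensor(ids, dtype=torch.long, device=device)
print(t.dtype)   # torch.int64
```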
@@ -135,7 +159,9 @@ def pre_process(self, key, opt):
         self.mask.append(mask)
 
         token_type_ids = build_token_type_ids(self.data_dic[key]['target_idx'][0], self.data_dic[key]['target_idx'][1], opt.maxlen)
+        # token_type_ids +=self.target_mask[key[0]]
         self.token_type_ids.append(token_type_ids)
+        self.target_mask_ids.append(self.target_mask[key[0]])
 
         self.target_type.append(self.frame_name_to_label[self.data_dic[key]['target_type']])

@@ -154,8 +180,7 @@ def pre_process(self, key, opt):
         fe_head = self.data_dic[key]['span_start']
         fe_tail = self.data_dic[key]['span_end']
 
-        self.fe_cnt.append(min(len(fe_head), opt.fe_padding_num))
-        self.fe_cnt_with_padding.append(min(len(fe_head)+1, opt.fe_padding_num))
 
         while len(fe_head) < opt.fe_padding_num:
             fe_head.append(min(sent_length-1, opt.maxlen-1))
@@ -168,6 +193,9 @@ def pre_process(self, key, opt):
 
         fe_type = [self.fe_name_to_label[(item, target_id)] for item in self.data_dic[key]['span_type']]
 
+        self.fe_cnt.append(min(len(fe_type), opt.fe_padding_num))
+        self.fe_cnt_with_padding.append(min(len(fe_type)+1, opt.fe_padding_num))
+
         while len(fe_type) < opt.fe_padding_num:
             fe_type.append(self.fe_num)
             # fe_coretype.append('0')
@@ -181,7 +209,7 @@ def remove_error_span(self, key, fe_head_list, fe_tail_list, fe_type_list, targe
                indices.append(index)
            elif fe_tail_list[index] < fe_head_list[index]:
                indices.append(index)
-               continue
 
            elif (fe_type_list[index], target_id) not in self.fe_name_to_label:
                indices.append(index)
@@ -196,6 +224,9 @@ def remove_error_span(self, key, fe_head_list, fe_tail_list, fe_type_list, targe
                elif fe_tail_list[index] >= fe_head_list[i] and fe_tail_list[index] <= fe_tail_list[i]:
                    indices.append(index)
                    break
+               elif fe_tail_list[index] <= fe_head_list[i] and fe_tail_list[index] >= fe_tail_list[i]:
+                   indices.append(index)
+                   break
                else:
                    continue

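The `elif` branches above test span overlap pairwise by comparing endpoints one case at a time. For reference, overlap of two closed intervals reduces to a single symmetric condition; an illustrative helper, not code from the PR:

```python
# Two closed spans [h1, t1] and [h2, t2] overlap exactly when
# each one starts no later than the other one ends.
def spans_overlap(h1: int, t1: int, h2: int, t2: int) -> bool:
    return h1 <= t2 and h2 <= t1

assert spans_overlap(2, 5, 4, 8)        # partial overlap
assert spans_overlap(2, 9, 4, 6)        # containment
assert not spans_overlap(2, 3, 5, 8)    # disjoint
```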
@@ -208,6 +239,17 @@ def remove_error_span(self, key, fe_head_list, fe_tail_list, fe_type_list, targe
 
         return len(indices)
 
+    def build_target_mask(self,key,maxlen):
+        self.target_mask.setdefault(key[0], [0]*maxlen)
+
+        target_head = self.data_dic[key]['target_idx'][0]
+        target_tail = self.data_dic[key]['target_idx'][1]
+        self.target_mask[key[0]][target_head] = 1
+        self.target_mask[key[0]][target_tail] = 1
+
 
 def build_token_type_ids(target_head, target_tail, maxlen):
     token_type_ids = [0]*maxlen
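`build_target_mask` marks only the head and tail positions of each target in a per-sentence mask keyed by `key[0]`, so targets from the same sentence accumulate in one shared vector. A toy illustration with invented indices:

```python
# Toy example (values invented): maxlen 8, one target spanning positions 2..4.
maxlen = 8
target_head, target_tail = 2, 4

target_mask = [0] * maxlen
target_mask[target_head] = 1   # boundary positions only,
target_mask[target_tail] = 1   # matching build_target_mask above

print(target_mask)   # [0, 0, 1, 0, 1, 0, 0, 0]
```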
evaluate2.py (83 changes: 0 additions & 83 deletions)

This file was deleted.
