-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathutils.py
More file actions
154 lines (112 loc) · 4.36 KB
/
utils.py
File metadata and controls
154 lines (112 loc) · 4.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import math
import pickle
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

import numpy as np
import torch
from gensim.models import KeyedVectors
from torch import Tensor
# Index reserved for padding positions in encoded sequences.
PAD_IDX = 0
# Index used for out-of-vocabulary tokens (see encode_sent).
UNK_IDX = 1
# NER tag inventory in BIO scheme, with "PAD" first so it aligns with PAD_IDX = 0.
unique_tags = ["PAD", "O", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-PER", "I-PER", "B-DATE", "I-DATE"]
# Forward and inverse lookup tables between tag strings and integer indices.
tag2idx = {tag: idx for idx, tag in enumerate(unique_tags)}
idx2tag = {idx: tag for tag, idx in tag2idx.items()}
def predict_sentence(model, sentence, w2i, i2t, device="cpu"):
    """
    Run the model on a single sentence and return (score, decoded tag strings).

    NOTE(review): `device` is accepted but never used in this function —
    confirm whether inputs are expected to be moved to that device.
    """
    indices = encode_sent(sentence, w2i)
    score, tag_ids = model([indices])
    return score.item(), decode_tags(tag_ids[0], i2t)
def load_pickle(filename):
    """Deserialize and return the object stored in the pickle file `filename`."""
    # NOTE: pickle is unsafe on untrusted files; only load trusted data.
    with open(filename, "rb") as fh:
        return pickle.load(fh)
def load_data(filename: str):
    """
    Read the UTF-8 file at `filename` and return its contents as a list of
    lines split on '\n' (a trailing newline yields a final empty string).
    """
    text = Path(filename).read_text(encoding="utf-8")
    return text.split('\n')
def strip_sents_and_tags(sents: List, tags: List):
    """
    Strip surrounding whitespace from paired sentence/tag lines, dropping any
    pair whose sentence is empty (or whitespace-only) after stripping.
    """
    kept = [(s.strip(), t.strip()) for s, t in zip(sents, tags) if s.strip()]
    stripped_sents = [s for s, _ in kept]
    stripped_tags = [t for _, t in kept]
    return stripped_sents, stripped_tags
def encode_tags(sent_tags: List, tag2idx: Dict):
    """
    Map each tag string (O, B-LOC, ...) in `sent_tags` to its integer index
    via `tag2idx`. Raises KeyError for tags absent from the dictionary.
    """
    return [tag2idx[tag] for tag in sent_tags]
def decode_tags(tags: List, idx2tag: Dict):
    """
    Map each tag index in `tags` back to its original tag string via
    `idx2tag`. Padding indices are NOT filtered out here.
    """
    return [idx2tag[idx] for idx in tags]
def load_wv(filename: str, limit: Optional[int]=None) -> KeyedVectors:
    """
    Load pretrained fastText embeddings stored in word2vec text format.

    limit: optionally cap the number of vectors read (handy for quick runs);
    decoding errors in the file are ignored rather than raised.
    """
    return KeyedVectors.load_word2vec_format(
        filename,
        binary=False,
        limit=limit,
        unicode_errors='ignore',
    )
def encode_sent(sent_tokens: List, word2idx: Dict) -> List:
    """
    Map each token to its vocabulary index from `word2idx`, falling back to
    UNK_IDX for out-of-vocabulary tokens.
    """
    return [word2idx.get(tok, UNK_IDX) for tok in sent_tokens]
def decode_sent(sent: List, idx2word: Dict) -> List:
    """
    Map token indices back to their original words via `idx2word`, skipping
    padding positions (PAD_IDX).
    """
    return [idx2word[idx] for idx in sent if idx != PAD_IDX]
def pad_sequences(sequences: List[List], pad_idx: int = 0) -> List[List]:
    """
    Right-pad every sequence with `pad_idx` up to the length of the longest
    sequence in the batch.

    Returns a new list of lists; the input sequences are not mutated.
    An empty batch returns an empty list (previously `max()` raised
    ValueError on an empty `sequences`).
    """
    # default=0 guards the empty-batch case: max() of an empty iterable raises.
    max_len = max((len(seq) for seq in sequences), default=0)
    return [seq + [pad_idx] * (max_len - len(seq)) for seq in sequences]
def to_tensor(sents: List[List], device: str="cpu") -> Tensor:
    """
    Pad `sents` into a rectangular batch and return it as a long tensor of
    shape (batch_size, max_seq_len) placed on `device`.
    """
    batch = pad_sequences(sents)
    return torch.tensor(batch, dtype=torch.long, device=device)
def generate_sent_masks(sents: Tensor, lengths: Tensor) -> Tensor:
    """
    Build a (batch_size, max_len) padding mask: 1 for real tokens, 0 for
    padding, returned as a uint8 tensor on the same device as `lengths`.

    Assumes `lengths` is sorted in descending order (batch_iter provides
    this), so lengths[0] is the padded width of the batch.
    """
    # lengths[0] is a 0-d tensor; convert to a plain int for arange/expand.
    max_len = int(lengths[0])
    batch_size = sents.shape[0]
    # Create the position ramp on the same device as `lengths` — previously
    # arange always ran on CPU, which breaks the comparison for GPU batches.
    positions = torch.arange(max_len, device=lengths.device)
    mask = positions.expand(batch_size, max_len) < lengths.unsqueeze(1)
    return mask.byte()
def batch_iter(data: List[List], batch_size: int, shuffle: bool=False) -> Iterator[Tuple[List, List]]:
    """
    Yield (sents, tags) batches from `data`, a list of (sentence, tags) pairs.

    Each batch is sorted by sentence length in descending order, as
    generate_sent_masks expects. When `shuffle` is True the example order is
    randomized via numpy's global RNG before batching. The final batch may
    be smaller than `batch_size`.

    (Annotation fixed: this is a generator, so the return type is an
    Iterator of (sents, tags) tuples, not a single Tuple.)
    """
    n_batches = math.ceil(len(data) / batch_size)
    index_array = list(range(len(data)))
    if shuffle:
        np.random.shuffle(index_array)
    for i in range(n_batches):
        indices = index_array[i * batch_size: (i + 1) * batch_size]
        examples = [data[idx] for idx in indices]
        # Longest sentence first so downstream padding/masking can rely on
        # lengths being sorted in descending order.
        examples.sort(key=lambda e: len(e[0]), reverse=True)
        yield [e[0] for e in examples], [e[1] for e in examples]