-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparsetree.py
More file actions
69 lines (52 loc) · 1.66 KB
/
parsetree.py
File metadata and controls
69 lines (52 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import pickle
import torch
import numpy as np
from sklearn.decomposition import PCA
from entailmenttree import EntailmentTree
def parse_trees(file_path):
    """
    Parse the JSONL dataset stored at *file_path* (one JSON object
    per line) and return a list of EntailmentTree objects.
    """
    # Each line of the file is an independent JSON document.
    with open(file_path, 'r') as fh:
        records = [json.loads(line) for line in fh]
    return [EntailmentTree(record) for record in records]
def reduce_embeddings(trees, d_new):
    """
    Reduce the dimensionality of every node embedding across all trees
    from d to d_new via a single PCA fit on the pooled embeddings.

    :param trees: list of trees; each tree exposes ``id_to_embedding``,
        a dict mapping node id -> 1-D torch tensor of length d
        (assumed uniform across trees — TODO confirm with EntailmentTree)
    :param d_new: reduced dimensionality
    :return: the same list of trees, embeddings updated in place
    """
    # Collect every embedding into one (N, d) matrix, remembering
    # which (tree, node) each row belongs to so we can write back.
    owners = []   # (tree, node_id) parallel to rows of the matrix
    rows = []
    for tree in trees:
        for node_id, emb in tree.id_to_embedding.items():
            owners.append((tree, node_id))
            rows.append(emb.unsqueeze(0))
    if not rows:
        # Nothing to reduce; avoid torch.cat on an empty list.
        return trees
    matrix = torch.cat(rows, dim=0)
    # Fit and transform in one batched call instead of one sklearn
    # transform() call per embedding (was O(N) round-trips).
    pca = PCA(n_components=d_new)
    reduced = torch.from_numpy(pca.fit_transform(matrix.numpy()))
    # Write the reduced rows back into each tree's dict.
    for (tree, node_id), row in zip(owners, reduced):
        # clone() so each stored tensor owns its memory rather than
        # viewing into the shared `reduced` matrix.
        tree.id_to_embedding[node_id] = row.clone()
    return trees
if __name__ == "__main__":
    # Preprocess one dataset split: parse the raw JSONL trees, compress
    # their embeddings down to 32 dimensions, and cache them as a pickle.
    dataset = "dev"
    original_dataset_fp = f'data/task_1/{dataset}.jsonl'
    processed_dataset_fp = f'data/processed/{dataset}.pkl'
    processed = reduce_embeddings(parse_trees(original_dataset_fp), 32)
    with open(processed_dataset_fp, 'wb') as out:
        pickle.dump(processed, out)