-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest.py
More file actions
93 lines (74 loc) · 2.28 KB
/
test.py
File metadata and controls
93 lines (74 loc) · 2.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from text4gcn.models import Builder as bd
from text4gcn.models import Layer as layer
from text4gcn.models import GNN
from text4gcn.preprocess import TextPipeline
from text4gcn.builder import *
from text4gcn.datasets import data
# Banner marking the beginning of the example run.
print(f"\n{'='*60} START OF TEST\n")

# Directory handed to every stage below as the dataset location.
path = "examples"

# Print whatever `data.list()` reports before touching any dataset.
available_datasets = data.list()
print(available_datasets)

# Only R8 is exercised in this run; the other corpora are kept for reference.
# Presumably this call fetches/prepares the corpus under `path` -- TODO confirm.
data.R8(path=path)
# data.R52(path=path)
# data.AG_NEWS(path=path)
print("OK")
# #print(help(layer))
# ======================= TextPipeline
# Run the text pipeline over the R8 data stored under `path`.
# NOTE(review): `rare_count` presumably sets the cut-off for rare words --
# confirm against the TextPipeline documentation.
text_pipeline = TextPipeline(
    dataset_name="R8",
    rare_count=5,
    dataset_path=path,
    language="english",
)
text_pipeline.execute()
# =======================
# # ======================= FrequencyAdjacency
# freq = FrequencyAdjacency(
# dataset_name="R8",
# dataset_path=path
# )
# freq.build()
# # ======================= CosineSimilarityAdjacency
# freq = CosineSimilarityAdjacency(
# dataset_name="R8",
# dataset_path=path
# )
# freq.build()
# ======================= EmbeddingAdjacency
# Build the embedding-based adjacency for R8. This is the graph variant that
# matches `builder=bd.Embedding` in the GNN configuration of this script.
# NOTE(review): the meaning of `training_regime=1` is not visible here --
# confirm against the EmbeddingAdjacency documentation.
adjacency_builder = EmbeddingAdjacency(
    dataset_name="R8",
    dataset_path=path,
    num_epochs=20,
    embedding_dimension=300,
    training_regime=1,
)
adjacency_builder.build()
# # ======================= DependencyParsingAdjacency
# freq = DependencyParsingAdjacency(
# dataset_name="R8",
# dataset_path=path,
# core_nlp_path="C:/bin/CoreNLP/stanford-corenlp-full-2018-10-05"
# )
# # freq.build()
# # ======================= ConstituencyParsingAdjacency
# # freq = ConstituencyParsingAdjacency()
# # ======================= LiwcAdjacency
# freq = LiwcAdjacency(
# dataset_name="R8",
# dataset_path=path,
# liwc_path="ztst/LIWC2007_English100131.dic"
# )
# freq.build()
# ======================= GNN training
# NOTE(review): the learning rate was originally spelled `00.2`, which Python
# parses as 0.2. The value is kept unchanged here, but confirm that 0.02 was
# not the intended setting.
model = GNN(
    dataset="R8",            # Dataset to train on
    path=path,               # Dataset path
    log_dir="examples/log",  # Log path
    layer=layer.GCN,         # Layer type
    epoches=200,             # Number of training epochs
    dropout=0.5,             # Dropout rate
    val_ratio=0.1,           # Share of the training data used for validation
    early_stopping=10,       # Early-stopping setting
    lr=0.2,                  # Initial learning rate
    nhid=200,                # Dimensions of hidden layers
    builder=bd.Embedding,    # Type of filtered text graph
)
model.fit()

# Banner marking the end of the example run.
print(f"\n{'='*60} END OF TEST\n")