-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_experiments.py
More file actions
116 lines (98 loc) · 3.42 KB
/
run_experiments.py
File metadata and controls
116 lines (98 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Axes of the experiment grid. Each (kb_size, query_size, query_class)
# combination names a "<class>-<size>-queries.txt" query file and a
# "<kb_size>.pl" knowledgebase on disk.
kb_sizes = ['small', 'medium', 'large']
query_sizes = ['short', 'medium', 'long']
query_classes = ['sat', 'unsat']
def make_pred_sig(args, i):
    """Return the head literal for the *i*-th generated query predicate.

    ``args`` is a pre-joined comma-separated variable list (e.g. "A,B"),
    so the result looks like ``query3(A,B)``.
    """
    return "query" + str(i) + "(" + args + ")"
def make_tmpfile_name(kb_size, query_size, query_class):
    """Return the /tmp path of the Prolog file for one grid cell."""
    stem = "_".join([kb_size, query_size, query_class])
    return "/tmp/" + stem + ".pl"
def gen_tmpfile(kb_size, query_size, query_class):
    """Materialize the Prolog file for one grid cell and describe its queries.

    Copies the base knowledgebase ("<kb_size>.pl") to the /tmp path from
    make_tmpfile_name(), then appends one rule per line of the matching
    query file. Returns a list of dicts, one per query, carrying the
    generated argument list, the query index, the grid coordinates, and
    the original query text.
    """
    import re
    from sh import cp
    from experiments_pl import convertToQuery

    query_path = "./{0}-{1}-queries.txt".format(query_class, query_size)
    out_path = make_tmpfile_name(kb_size, query_size, query_class)
    # Start from a fresh copy of the knowledgebase; rules are appended below.
    cp(
        "{0}.pl".format(kb_size),
        out_path
    )

    records = []
    with open(out_path, "a") as out:
        with open(query_path) as query_lines:
            for idx, raw in enumerate(query_lines):
                body = convertToQuery(raw.strip())
                # One fresh variable (A, B, C, ...) per 'tweet' occurrence
                # in the translated rule body.
                arglist = ",".join(
                    chr(65 + n) for n, _ in enumerate(re.finditer('tweet', body))
                )
                head = make_pred_sig(arglist, idx)
                out.write("{0} :- {1}.\n".format(head, body))
                records.append({
                    'args': arglist,
                    'i': idx,
                    'kb_size': kb_size,
                    'query_size': query_size,
                    'query_class': query_class,
                    'orig_query': raw.strip()
                })
    return records
def run_swi(kb_size, query_size, query_class, args, i, *aargs, **kwargs):
    """Run one generated query under SWI-Prolog and return its elapsed seconds.

    Loads the temp file built by gen_tmpfile() and executes
    ``time(findall(...))`` for the query's predicate; SWI prints its
    timing line to stderr, which is captured and parsed here.

    Extra *aargs/**kwargs are accepted (and ignored) so a record from
    gen_tmpfile() can be splatted straight in via ``run_swi(**record)``.

    Raises:
        RuntimeError: if no ``in N seconds`` timing line appears in the
            swipl output (previously this crashed with an opaque
            AttributeError on ``None.group``).
    """
    from sh import swipl
    from StringIO import StringIO
    import re
    err_buf = StringIO()
    swipl(
        '-s', make_tmpfile_name(kb_size, query_size, query_class),
        '-g', 'time(findall([{0}],{1},ZZZ)).'.format(args, make_pred_sig(args, i)),
        '-t', 'halt.',
        '-G32g',
        _err=err_buf  # capture SWI's timing output from stderr
    )
    output = err_buf.getvalue()
    match = re.search(r'in ([0-9.]+) seconds', output)
    if match is None:
        # Fail loudly with the captured output so the bad run is debuggable.
        raise RuntimeError(
            "could not find timing line in swipl output:\n" + output)
    return float(match.group(1))
def time_validation(query):
    """Return the wall time (seconds) of one validator.solve() call on *query*."""
    from timeit import Timer
    timer = Timer(
        stmt="solve('{0}')".format(query),
        setup="from validator import solve"
    )
    return timer.timeit(1)
if __name__ == '__main__':
    from argparse import ArgumentParser
    from itertools import product
    from collections import OrderedDict
    import csv

    parser = ArgumentParser()
    parser.add_argument('outfile', help="The file where you want to store the output CSV")
    parser.add_argument('-s', '--satisfiability', help="Select which satisfiabilities you want.",
                        action='append', choices=query_classes)
    parser.add_argument('-a', '--arity', help="Select which arities you want.",
                        action='append', choices=query_sizes)
    parser.add_argument('-k', '--kbsize', help="Select which knowledgebase sizes you want",
                        action='append', choices=kb_sizes+['tiny'])
    args = parser.parse_args()

    # action='append' options are None when never given; default each
    # axis to its full range in that case.
    if args.satisfiability is None:
        args.satisfiability = query_classes
    if args.arity is None:
        args.arity = query_sizes
    if args.kbsize is None:
        args.kbsize = kb_sizes

    # Build every temp Prolog file up front and collect one record per
    # query to execute.
    queue = []
    for kb_size, query_size, query_class in product(args.kbsize, args.arity, args.satisfiability):
        queries = gen_tmpfile(kb_size, query_size, query_class)
        queue.extend(queries)

    # NOTE: binary mode for csv is correct for the Python 2 this script
    # targets (it imports StringIO above).
    with open(args.outfile, "wb") as outfile:
        out_fields = OrderedDict([
            ('kb_size', None),
            ('query_sat', None),
            ('query_size', None),
            ('time_to_validate', None),
            ('time_to_run', None)])
        csvwriter = csv.DictWriter(outfile, delimiter=',', fieldnames=out_fields)
        csvwriter.writeheader()
        # Plain iteration replaces the original while/pop(0) drain, which
        # was O(n^2): list.pop(0) shifts the whole list on every call.
        for current in queue:
            validation_time = time_validation(current['orig_query'])
            run_time = run_swi(**current)
            csvwriter.writerow({
                'kb_size': current['kb_size'],
                'query_sat': current['query_class'],
                'query_size': current['query_size'],
                'time_to_validate': validation_time,
                'time_to_run': run_time
            })