-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDemo.py
More file actions
208 lines (184 loc) · 7.33 KB
/
Demo.py
File metadata and controls
208 lines (184 loc) · 7.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#!/usr/bin/env python
"""
!!! Not certified fit for any purpose, use at your own risk !!!
Copyright (c) Rex Sutton 2016.
Demonstrate machine learning for binary classification,
using a Gaussian process, on the USPS handwritten digits, 3's versus 5's task.
"""
import warnings
warnings.filterwarnings("ignore")
import argparse
import numpy as np
from Tools import log_string
from Tools import print_log
import Usps
import LaplaceBinaryGpClassifier as gpc
# Fixed kernel hyperparameters used by benchmark(): they are written to
# params[0] and params[1] respectively before being handed to the kernel.
GOLD_LOGSIGMAF = 2.35
GOLD_LOGL = 2.85
# NOTE(review): not referenced in the visible part of this file; presumably the
# log marginal likelihood expected for the gold parameters -- confirm.
GOLD_LOG_MARGINAL_LIKELIHOOD = -99.0
# Passed to the Usps loaders as pos_label, neg_label and max_patterns.
__pos_label__ = 3
__neg_label__ = 5
__max_patterns__ = 1000
def benchmark(path):
    """Benchmark against the solution of Rasmussen and Williams, published for
    the USPS data 3 versus 5 classifier using the squared exponential kernel.

    The kernel is given the fixed GOLD_LOGSIGMAF / GOLD_LOGL parameters (no
    training takes place), the resulting classifier is saved to ``path`` and
    test() is then run against the saved classifier.

    Args:
        path (str): The path the classifier is saved to and then tested from.
    """
    print_log("starting...")

    # read the training set from disk
    patterns, classifications = Usps.load_training_data(
        pos_label=__pos_label__,
        neg_label=__neg_label__,
        max_patterns=__max_patterns__)
    print_log("loaded data...")

    # transpose so that, per the original note, the observations are
    # D rows by n columns
    patterns = np.transpose(patterns)

    # the gold parameters, in the order log_sigma_f, log_l
    gold_params = np.empty([2], gpc.__precision__)
    gold_params[:] = (GOLD_LOGSIGMAF, GOLD_LOGL)

    # derive the data the classifier is built from
    gold_kernel = gpc.SquaredExponentialKernel()
    gold_kernel.set_params(gold_params)
    derived = gpc.DerivedData(patterns, classifications, gold_kernel)
    print_log("calculated derived data...")

    classifier = gpc.Classifier(derived)
    print_log("initialized classifier...")

    classifier.save(path)
    print_log("saved classifier...")

    # evaluate what was just written to disk
    test(path)
def train(path):
    """Train the classifier on the USPS training patterns and save it to disk.

    The kernel parameters returned by gpc.Classifier.train are printed, set on
    the kernel, and used to build the classifier that is saved.

    Args:
        path (str): The path to save the trained classifier to.
    """
    # load patterns
    training_patterns, training_classifications = Usps.load_training_data(
        pos_label=__pos_label__,
        neg_label=__neg_label__,
        max_patterns=__max_patterns__)

    # observation_vector is D rows, n cols
    training_patterns = np.transpose(training_patterns)

    # Progress output uses a single pre-formatted argument so the same line is
    # valid, and prints the same text, on both Python 2 and Python 3.
    print("%s training..." % (log_string(),))
    kernel = gpc.SquaredExponentialKernel()
    params = gpc.Classifier.train(kernel,
                                  training_patterns,
                                  training_classifications)

    # print results
    print("%s optimal log_sigma_f: %s" % (log_string(), params[0]))
    print("%s optimal log_l: %s" % (log_string(), params[1]))

    # save the predictor built from the trained parameters
    kernel.set_params(params)
    data = gpc.DerivedData(training_patterns, training_classifications, kernel)
    pred = gpc.Classifier(data)
    pred.save(path)
    print_log("saved classifier...")
def peek(path, pattern_idx):
    """Classify one USPS test pattern with the classifier saved to disk.

    Prints the predicted digit together with the probability assigned to that
    digit, then asks Usps.peek to display the pattern.

    Args:
        path (str): The path to load the classifier from.
        pattern_idx (int): The index of the selected pattern.
    """
    # load patterns; the classifications are not needed here
    test_patterns, _ = Usps.load_test_data(pos_label=__pos_label__,
                                           neg_label=__neg_label__,
                                           max_patterns=__max_patterns__)

    pred = gpc.Classifier.load(path)
    prob = pred.predict(test_patterns[pattern_idx])

    # a positive threshold selects the positive label, whose probability is
    # prob; otherwise the negative label is reported with the complement
    if gpc.Classifier.threshold(prob) > 0.0:
        digit, digit_probability = __pos_label__, prob
    else:
        digit, digit_probability = __neg_label__, 1.0 - prob

    # single pre-formatted argument: valid on both Python 2 and Python 3
    print("*** machine predicted digit: %s with probability: %s"
          % (digit, digit_probability))

    # call Usps to peek at the pattern
    Usps.peek(pattern_idx,
              pos_label=__pos_label__,
              neg_label=__neg_label__,
              max_patterns=__max_patterns__)
def show(prompt, path, indices):
    """Repeatedly offer to show a randomly selected pattern from ``indices``.

    The user is prompted until the answer is anything other than "y"; each
    "y" picks one index at random and passes it to peek().

    Args:
        prompt (str): The question put to the user before each pattern.
        path (str): The path to load the classifier from.
        indices (list of int): The indices of the candidate patterns. When
            empty there is nothing to show and the function returns without
            prompting.
    """
    # np.random.choice raises ValueError on an empty sequence, so bail out
    # before asking a question that could not be honoured.
    if len(indices) == 0:
        return

    # raw_input only exists on Python 2; input is its Python 3 equivalent
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    while read_line(prompt) == "y":
        peek(path, np.random.choice(indices))
def test(path):
    """Test the performance of the classifier saved to disk.

    Prints the kernel parameters, the log marginal likelihood and its
    derivatives, the average information and the classification accuracy on
    the USPS test patterns, then offers to display randomly selected correctly
    and incorrectly classified digits.

    Args:
        path (str): The path to load the classifier from.
    """
    # load patterns
    test_patterns, test_classifications = Usps.load_test_data(
        pos_label=__pos_label__,
        neg_label=__neg_label__,
        max_patterns=__max_patterns__)

    # load classifier
    pred = gpc.Classifier.load(path)

    # print parameters (single pre-formatted arguments keep every print valid
    # on both Python 2 and Python 3)
    kernel = pred.data.kernel.__kernel__
    print("*** using log_sigma_f: %s" % (kernel.log_sigma_f,))
    print("*** using log_l: %s" % (kernel.log_l,))

    # print log marginal likelihood and derivatives
    derivatives = pred.log_marginal_likelihood_deriv()
    print("*** log marginal likelihood: %s" % (pred.log_marginal_likelihood(),))
    print("*** derivative log_sigma_f: %s" % (derivatives[0],))
    print("*** derivative log_l: %s" % (derivatives[1],))

    # classify every test pattern
    predicted_probabilities = [pred.predict(pattern) for pattern in test_patterns]
    predicted_classifications = [pred.threshold(probability)
                                 for probability in predicted_probabilities]

    # non-zero wherever prediction and truth have opposite signs
    results = np.subtract(predicted_classifications, test_classifications)
    error_indices = list(np.nonzero(results)[0])
    num_errors = np.count_nonzero(results)

    # Information is measured on the probability given to the TRUE class:
    # p for positive patterns, 1 - p for negative ones.
    # NOTE(review): assumes predict() yields the positive-class probability and
    # that positive classifications are > 0, which is how peek() treats them.
    positive_probabilities = np.ravel(np.asarray(predicted_probabilities, dtype=float))
    is_positive = np.ravel(test_classifications) > 0
    true_class_probabilities = np.where(is_positive,
                                        positive_probabilities,
                                        1.0 - positive_probabilities)
    information = -1.0 * np.average(np.log2(true_class_probabilities))
    # log2 yields bits, not bytes
    print("*** average Information (bits): %s" % (information,))

    # print performance summary
    num_test_patterns = len(test_classifications)
    num_correct_classifications = num_test_patterns - num_errors
    percent = 100.0 * float(num_correct_classifications) / float(num_test_patterns)
    print("*** correctly classified: %s of %s digits"
          % (num_correct_classifications, num_test_patterns))
    print("*** correctly classified: %s %%" % (percent,))

    # every index that is not an error is a correct classification
    error_set = set(error_indices)
    correct_indices = [index for index in range(num_test_patterns)
                       if index not in error_set]

    show("*** View a randomly selected, correctly classified digit (y)? ",
         path,
         correct_indices)
    show("*** View a randomly selected, __in-correctly__ classified digit (y)? ",
         path,
         error_indices)
def main():
    """Parse the command line and run the selected command.

    The command is one of "bench", "test", "train" or "peek" (the default);
    the path and pattern index options are forwarded to the command that
    uses them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--command",
                        help="The command to invoke.",
                        choices=["bench", "test", "peek", "train"],
                        default="peek")
    parser.add_argument("-p", "--path",
                        help="The path to load/save the classifier from/to.",
                        default="classifier")
    parser.add_argument("-i", "--idx",
                        help="The index of the pattern to peek at.",
                        type=int,
                        default=0)
    options = parser.parse_args()

    # one entry per command; the callables are only resolved when invoked
    handlers = {
        "bench": lambda: benchmark(options.path),
        "test": lambda: test(options.path),
        "train": lambda: train(options.path),
        "peek": lambda: peek(options.path, options.idx),
    }

    handler = handlers.get(options.command)
    if handler is None:
        # fallback for a command without a handler
        parser.print_help()
        return
    handler()
# Run the command line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()