-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
49 lines (31 loc) · 1.31 KB
/
main.py
File metadata and controls
49 lines (31 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from network import *
from functions import *
from mnist import *
from cost_functions import *
def xor(a, b):
    """Build one XOR sample for the inputs ``a`` and ``b``.

    Returns a two-item list ``[inputs, target]`` where ``inputs`` is the
    array ``[a, b]`` and ``target`` is the array ``[int(a != b), 0]``.
    The second target component is always 0.
    """
    inputs = np.array([a, b])
    differs = int(a != b)
    target = np.array([differs, 0])
    return [inputs, target]
def andGate(a, b):
    """Build one logical-AND sample for the inputs ``a`` and ``b``.

    Returns a two-item list ``[inputs, target]`` where ``inputs`` is the
    array ``[a, b]`` and ``target`` is a one-element array holding 1 when
    both inputs are truthy and 0 otherwise.
    """
    inputs = np.array([a, b])
    # ``a and b`` evaluates to one of the operands, not to a truth value, so
    # ``int(a and b)`` gave 3 for (2, 3) and 0 for (1, 0.5). Coerce to bool
    # first so the target is always exactly 0 or 1.
    both_true = bool(a) and bool(b)
    target = np.array([int(both_true)])
    return [inputs, target]
def quadratic(a):
    """Build one sample of the squaring function for the input ``a``.

    Returns a two-item list ``[inputs, target]`` where ``inputs`` is the
    one-element array ``[a]`` and ``target`` is the one-element array
    ``[a ** 2]``.
    """
    inputs = np.array([a])
    squared = a ** 2
    target = np.array([squared])
    return [inputs, target]
# Module-level alias bound to the `xor` sample generator defined above.
# NOTE(review): `f` is not referenced anywhere else in the visible part of
# this file — confirm whether it is still needed.
f = xor
def checkFunction(model, input, expected):
    """Print the expected value next to the model's output for ``input``.

    ``input`` is wrapped in a NumPy array and handed to
    ``model.forward_pass``; the result is printed on one line after the
    string form of ``expected``. Nothing is returned.
    """
    expected_text = str(expected)
    prediction = model.forward_pass(np.array(input))
    # ``!s`` forces str() so the text matches plain string concatenation.
    print(f"Expected: {expected_text} --- {prediction!s}")
def get_best_network():
    """Return a 784-15-15-15-10 network loaded from "mnist_network_two_prime".

    The network is constructed with ``sigmoid`` / ``derivSigmoid`` and a
    ``CrossEntropyLoss`` instance, then ``load`` is called on it with the
    saved name before it is returned.
    """
    # NOTE(review): `load` presumably restores previously saved parameters
    # for this exact layer layout — confirm against Network.load.
    layer_sizes = [784, 15, 15, 15, 10]
    best = Network(layer_sizes, sigmoid, derivSigmoid, CrossEntropyLoss())
    best.load("mnist_network_two_prime")
    return best
# Evaluated at import time: this call sits outside the `__main__` guard below,
# so importing this module also runs training_data().
# NOTE(review): training_data presumably comes from the `mnist` star import — confirm.
data = training_data()
# Disabled earlier experiment: a 784-15-15-10 network run through debug_descent
# and saved as "mnist_cel".
# net = Network([784, 15, 15, 10], sigmoid, derivSigmoid, CrossEntropyLoss())
# net.debug_descent(data, 0.0008, 50, 1, True)
# net.save("mnist_cel")
# Disabled conversion of every [input, target] pair to transposed 2-D arrays.
# data = [[np.atleast_2d(d[0]).transpose(), np.atleast_2d(d[1]).transpose()] for d in data]
if __name__ == "__main__":
    # Script entry point: build a fresh 784-15-15-15-10 network, run
    # debug_descent on the module-level `data`, then save the result.
    # NOTE(review): the positional debug_descent arguments are undocumented
    # here — check Network.debug_descent for their meaning.
    layer_sizes = [784, 15, 15, 15, 10]
    net = Network(layer_sizes, sigmoid, derivSigmoid, CrossEntropyLoss())

    # Alternatives from earlier runs, left disabled:
    # net.load("mnist_network_two_prime")
    # net.debug_descent(data, 0.01, 50, 1, 100, True)
    # net.save("mnist_network_two_alpha")
    # print(data[10][1])

    net.debug_descent(data, 0.005, 50, 1, 100, True)
    net.save("mnist_test_batch")
    # print(net.forward_pass(data[10][0]))