-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
90 lines (78 loc) · 2.49 KB
/
test.py
File metadata and controls
90 lines (78 loc) · 2.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#-*-coding:utf-8-*-
#Author = horton
import sys
from locale import atof
import neural
class bpTest(object):
    """Read the training and test data files used by the BP network demo.

    Training file: one example per line in the form ``x1,...,x8:y`` --
    eight comma-separated input values, a colon, then a single target
    value.

    Test file: one sample per line, eight comma-separated input values.

    Numbers are parsed with ``locale.atof``. A malformed line prints a
    diagnostic and terminates the process via ``sys.exit``.
    """

    # Every sample must carry exactly this many comma-separated inputs.
    _N_INPUTS = 8

    def __init__(self, trainfile, testfile):
        """Store the paths of the training file and the test file."""
        super(bpTest, self).__init__()
        self.trainfile = trainfile
        self.testfile = testfile

    @staticmethod
    def _fail(message):
        """Print *message* and stop the program with a failing exit status.

        Raises:
            SystemExit: always, with code 1 so that shells and callers can
                tell a format error apart from a successful run.
        """
        print(message)
        sys.exit(1)

    def _getExamplars(self):
        """Parse the training file into ``(inputs, outputs)`` pairs.

        Returns:
            A list of tuples ``(inputs, outputs)`` where ``inputs`` is a
            list of eight floats and ``outputs`` is a one-element list
            holding the target value as a float.

        Raises:
            SystemExit: if a line does not match ``x1,...,x8:y``.
        """
        examplars = []
        # ``with`` guarantees the file is closed, including on sys.exit.
        with open(self.trainfile) as traf:
            for line in traf:
                tokens = line.strip().split(":")
                print(tokens)
                if len(tokens) != 2:
                    self._fail("The format of the trainfile is wrong!--1")
                fields = tokens[0].split(",")
                if len(fields) != self._N_INPUTS:
                    self._fail("The format of the trainfile is wrong!--2")
                if len(tokens[1].split(",")) != 1:
                    self._fail("The format of the trainfile is wrong!--3")
                inputs = [atof(field) for field in fields]
                # Parse the whole target token; taking only its first
                # character would turn "0.5" into 0.0 and "10" into 1.0.
                outputs = [atof(tokens[1])]
                examplars.append((inputs, outputs))
        print(examplars)
        return examplars

    def _getInputs(self):
        """Parse the test file into a list of input vectors.

        Returns:
            A list with one entry per line; each entry is a list of eight
            floats.

        Raises:
            SystemExit: if a line does not hold exactly eight
                comma-separated values.
        """
        test = []
        with open(self.testfile) as tesf:
            for line in tesf:
                tokens = line.strip().split(",")
                if len(tokens) != self._N_INPUTS:
                    self._fail("The format of the testfile is wrong!")
                test.append([atof(token) for token in tokens])
        print(test)
        return test
def main():
    """Train a back-propagation net on the sample data and classify the test set.

    Loads the test inputs and the training examples from the fixed paths
    under ``./test/data``, builds a ``neural.BackPropNet`` with
    ``addinput(8)``, ``addhidden(3)`` and ``addouput(1)``, trains it, runs
    it on the test inputs, prints each input vector followed by 1 when the
    first output is at least 0.5 and 0 otherwise, and finally saves the
    network to ``./test/model.xml``.
    """
    trainfile = "./test/data/train-1"
    testfile = "./test/data/test-1"
    loader = bpTest(trainfile, testfile)
    # Test inputs are read before the training examples; both calls print
    # what they parsed, so the order determines the console output.
    samples = loader._getInputs()
    training_set = loader._getExamplars()

    net = neural.BackPropNet()
    net.addinput(8)
    net.addhidden(3)
    net.addouput(1)
    # NOTE(review): 1000 is presumably the number of training iterations --
    # confirm against neural.BackPropNet.learn.
    net.learn(training_set, 1000)

    predictions = net.run(samples)
    for index, sample in enumerate(samples):
        label = 1 if predictions[index][0] >= 0.5 else 0
        # Same text as the Python 2 statement ``print sample, label``.
        print("%s %s" % (sample, label))

    saver = neural.xmlneural.XMLBPNSaver()
    saver.save(net, "./test/model.xml")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()