Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions xor.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,43 @@
#learn XOR with a nerual network with saving of the learned paramaters
# train a neural network to do XOR
#
# original code at https://github.com/EricSchles/neuralnet/blob/master/xor.py
# by Eric Schles

import pybrain
from pybrain.datasets import *
from pybrain.tools.shortcuts import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
import pickle

if __name__ == "__main__":
ds = SupervisedDataSet(2, 1)
ds.addSample( (0,0) , (0,))
ds.addSample( (0,1) , (1,))
ds.addSample( (1,0) , (1,))
ds.addSample( (1,1) , (0,))
# =============
# the data we want to train with:
# every sample contains the two inputs, and the one expected output
# this represents the XOR operation: only if the two inputs are different should the output be true=1
print "setting up training data"
ds = SupervisedDataSet(2, 1)
ds.addSample((0,0), (0))
ds.addSample((0,1), (1))
ds.addSample((1,0), (1))
ds.addSample((1,1), (0))

net = buildNetwork(2, 4, 1, bias=True)
# =============
# build up a neural network
# the arguments are: 2=number of input nodes, 4=number of hidden nodes, 1=number of output nodes
print "building the network"
net = buildNetwork(2,4,1,bias=True)

# try:
# f = open('_learned', 'r')
# net = pickle.load(f)
# f.close()
# except:
trainer = BackpropTrainer(net, learningrate = 0.01, momentum = 0.99)
trainer.trainOnDataset(ds, 3000)
trainer.testOnData()
# f = open('_learned', 'w')
# pickle.dump(net, f)
# f.close()

# =============
# do the training with the data:
print "training the network"
trainer = BackpropTrainer(net, learningrate = 0.01, momentum = 0.99)
trainer.trainOnDataset(ds, 3000)
trainer.testOnData()

print net.activate((1,1))
print net.activate((1,0))
print net.activate((0,1))
print net.activate((0,0))
# =============
# now we use the trainig data to test the network:
# we loop through all the samples:
print "checking the network"
for inp, target in ds:
out = net.activate(inp)
print "Input %s: Output of Network is %f. Expected output is %f." % ( repr( inp ), out, target )
print "done testing."