diff --git a/xor.py b/xor.py index 6ceea8b..a706b27 100644 --- a/xor.py +++ b/xor.py @@ -1,4 +1,7 @@ -#learn XOR with a nerual network with saving of the learned paramaters +# train a neural network to do XOR +# +# original code at https://github.com/EricSchles/neuralnet/blob/master/xor.py +# by Eric Schles import pybrain from pybrain.datasets import * @@ -6,29 +9,35 @@ from pybrain.supervised.trainers import BackpropTrainer import pickle -if __name__ == "__main__": - ds = SupervisedDataSet(2, 1) - ds.addSample( (0,0) , (0,)) - ds.addSample( (0,1) , (1,)) - ds.addSample( (1,0) , (1,)) - ds.addSample( (1,1) , (0,)) +# ============= +# the data we want to train with: +# every sample contains the two inputs, and the one expected output +# this represents the XOR operation: only if the two inputs are different should the output be true=1 +print "setting up training data" +ds = SupervisedDataSet(2, 1) +ds.addSample((0,0), (0)) +ds.addSample((0,1), (1)) +ds.addSample((1,0), (1)) +ds.addSample((1,1), (0)) - net = buildNetwork(2, 4, 1, bias=True) +# ============= +# build up a neural network +# the arguments are: 2=number of input nodes, 4=number of hidden nodes, 1=number of output nodes +print "building the network" +net = buildNetwork(2,4,1,bias=True) - # try: - # f = open('_learned', 'r') - # net = pickle.load(f) - # f.close() - # except: - trainer = BackpropTrainer(net, learningrate = 0.01, momentum = 0.99) - trainer.trainOnDataset(ds, 3000) - trainer.testOnData() - # f = open('_learned', 'w') - # pickle.dump(net, f) - # f.close() - +# ============= +# do the training with the data: +print "training the network" +trainer = BackpropTrainer(net, learningrate = 0.01, momentum = 0.99) +trainer.trainOnDataset(ds, 3000) +trainer.testOnData() - print net.activate((1,1)) - print net.activate((1,0)) - print net.activate((0,1)) - print net.activate((0,0)) +# ============= +# now we use the trainig data to test the network: +# we loop through all the samples: +print "checking the network" +for inp, target in ds: + out = net.activate(inp) + print "Input %s: Output of Network is %f. Expected output is %f." % ( repr( inp ), out, target ) +print "done testing."