-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_server2.py
More file actions
116 lines (100 loc) · 3.32 KB
/
test_server2.py
File metadata and controls
116 lines (100 loc) · 3.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from collections import defaultdict
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import time
from cleaner import get_data
# NOTE(review): K is not referenced anywhere else in this script.
K = 5
# Load the train/test split. The with-block closes the data file as soon as
# get_data has returned (train is used as an in-memory DataFrame below).
with open('/home/christopher/data_bin/train_test_1car_randomforestdecisions.txt') as f:
    train, test = get_data(f)
df = train
print('Fitting model')
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import BernoulliNB
from sklearn.svm import SVC
from sklearn import cross_validation
from sklearn import metrics
# Model inputs, in the column order the server builds its feature rows in:
# the leading columns are supplied by the client on each request, the last two
# are the discretized controls the server enumerates as candidate behaviors.
# NOTE(review): 'curCar starting_speed' is one name containing a space --
# confirm it matches the data header and is not a missing comma.
x_cols = ['curCar starting_speed', 'starting_distance', 'discretizedBrakes', 'discretizedSteering']
# Training target; the server ranks behaviors by column 1 of predict_proba
# (presumably the collision/off-road class -- confirm against rfc.classes_).
y_col = 'collisionOrOffroad'
train_x = df[x_cols]
train_y = df[y_col]
# 100-tree random forest fitted on the full training split.
rfc = RandomForestClassifier(n_estimators=100)
rfc.fit(train_x, train_y)
#
# ZeroMQ request/reply server for the trained classifier.
# Binds a REP socket to tcp://*:5555. Each request is a whitespace-separated
# list of the client-supplied feature values; the reply names the
# (discretizedBrakes, discretizedSteering) behavior chosen by the model.
#
import time
import zmq
import numpy as np
print('Establishing server')
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind("tcp://*:5555")
print('Waiting for requests')
# Single-row feature buffer of width len(x_cols).
x = np.zeros((1, len(x_cols)))
# Every combination of the brake and steering values seen in the training data
# is a candidate behavior.
brakesBehaviors = sorted(set(df['discretizedBrakes']))
steeringBehaviors = sorted(set(df['discretizedSteering']))
behaviors = list(itertools.product(brakesBehaviors, steeringBehaviors))
# One score per behavior; all zeros until the first successful prediction.
probs = np.zeros(len(behaviors))
# Number of leading feature columns the client must send; the trailing
# columns of each feature row hold the behavior being scored.
request_size = len(x_cols) - len(behaviors[0])
# One feature row per candidate behavior: the leading request_size columns are
# overwritten per request, the trailing columns hold the behavior. dtype=float
# is required: integer-valued behaviors would otherwise yield an int array,
# silently truncating the float request values written into it.
x2 = np.array([(0,) * request_size + tuple(beh) for beh in behaviors], dtype=float)
def _behavior_reply(behavior):
    """Format a behavior tuple as the wire reply.

    Each component is converted with str(), the parts are joined by single
    spaces, and a trailing newline is appended.
    """
    return ' '.join(str(num) for num in behavior) + '\n'


def _send_fallback(error):
    """Log *error* and still answer the pending request.

    A REP socket must send exactly one reply per received request before it
    can receive again, so invalid requests are answered with the behavior that
    scored lowest in the most recent successful prediction (behaviors[0] before
    any prediction, since probs starts as all zeros).
    """
    print(error)
    socket.send(_behavior_reply(behaviors[np.argmin(probs)]))


# Serve forever. A valid request is exactly request_size numeric, finite,
# whitespace-separated values (the leading x_cols features); the reply is the
# behavior whose row gets the lowest value in column 1 of predict_proba.
while True:
    # Wait for next request from client
    message = socket.recv()
    t0 = time.time()
    print("Received request: \"%s\"" % message)

    tokens = message.split()
    if len(tokens) != request_size:
        # Too many tokens would index past the feature columns; too few would
        # leave values from a previous request in place.
        _send_fallback('ERROR, incorrect request size')
        continue
    try:
        request = [float(tok) for tok in tokens]
    except ValueError:
        _send_fallback('ERROR, non-numeric request')
        continue
    # float() happily parses 'nan'/'inf'; reject them before they reach the
    # classifier.
    if not np.all(np.isfinite(request)):
        _send_fallback('ERROR, NaN or infinite value')
        continue

    # Broadcast the client features into the leading columns of every row;
    # the trailing columns already hold one candidate behavior per row.
    x2[:, :request_size] = request
    # Keep only column 1 so probs stays 1-D and indexes behaviors directly,
    # including in the fallback replies above.
    probs = rfc.predict_proba(x2)[:, 1]
    print(probs)
    min_proba_behavior = behaviors[np.argmin(probs)]
    result = _behavior_reply(min_proba_behavior)
    print(result)
    # Send reply back to client
    socket.send(result)
    print("Processed in %fms" % ((time.time() - t0) * 1000))