-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathserver.py
More file actions
427 lines (374 loc) · 15.2 KB
/
server.py
File metadata and controls
427 lines (374 loc) · 15.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
# server
import socket
import sqlite3 as sql
import threading
import time

import pydb  # database initialization is available via this file
# setup global variables
PORT = 8414  # PORT for server address
MSGLEN = 256  # every message is exchanged as one frame of exactly MSGLEN bytes
DBNAME = "stockDB"
MAXCLIENTS = 10
serverAddress = ("localhost", PORT)  # change server address string if desired
status = True  # cleared to stop the accept loop and every client thread
clientSockets = []
clientThreads = []
clientAddresses = []
# one user-name slot per client index; "" means nobody is logged in on that slot
usersOnline = ["" for _ in range(MAXCLIENTS)]

# init db (client threads open their own handle via pydb.getDB())
newDb = pydb.initDB(DBNAME)
if newDb is not None:
    print("Connected to database: " + DBNAME + "\n")
else:
    status = False
    print("Database failed to connect - Program exiting...")

# create server socket
serverSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# test connection with serverAddress - error if it doesn't work and exit program
if status:  # no point in listening when the database is unavailable
    try:
        serverSocket.bind(serverAddress)
        serverSocket.listen(5)
        print(
            f"Server started on {serverAddress[0]} and is listening on port {serverAddress[1]} for clients\n")
    except Exception as e:
        status = False
        serverSocket.close()  # release the descriptor; the program is exiting
        print(f"Server failed to start on {serverAddress[0]}:{serverAddress[1]} \nException is " + str(
            e) + "\nProgram exiting...")
# DATABASE FUNCTIONS
# returns user tuple with userID
def getUserInfo(cur, userID):
    """Return all Users rows whose ID equals *userID*, or None if there are none.

    :param cur: sqlite3 cursor
    :param userID: user ID; may be an int or a numeric string taken from a
        client message
    :return: list of row tuples (as from ``fetchall()``) or None
    """
    # Bound parameter instead of string concatenation: userID arrives from the
    # network, and "1 OR 1=1" previously returned every user.
    cur.execute("SELECT * FROM Users WHERE ID = ?", (userID,))
    user = cur.fetchall()
    return user if user else None
# sends input string to client as one MSGLEN-byte frame
def sendMsg(msg, cSocket):
    """Send *msg* to *cSocket* UTF-8 encoded and space-padded to MSGLEN bytes.

    Does nothing if *cSocket* is not registered in clientSockets. On a send
    error the error is printed and the client is deregistered via threadClose().

    :param msg: text to send
    :param cSocket: registered client socket
    """
    if cSocket not in clientSockets:
        return
    # Pad in bytes, not characters: a non-ASCII character encodes to several
    # bytes, so padding the str made the frame longer than MSGLEN.
    data = msg.encode("utf-8")
    data += b" " * (MSGLEN - len(data))
    # NOTE(review): messages longer than MSGLEN bytes are still sent whole and
    # will spill into the receiver's next frame; consider a length prefix.
    try:
        cSocket.sendall(data)
    except OSError as e:
        print(e)
        threadClose(cSocket)
        return
    print("msg '" + msg.strip() + "' sent with total bytes: " + str(len(data)))
# recieves string from client
def recieveMsg(cSocket):
    """Read one MSGLEN-byte frame from *cSocket*; return it decoded and stripped.

    :param cSocket: registered client socket
    :return: the stripped message text, or None if the socket is not in
        clientSockets or the peer disconnected / the read failed (in which
        case the client is deregistered via threadClose()).
    """
    if cSocket not in clientSockets:
        return None
    chunks = []
    bytesRecieved = 0
    while bytesRecieved < MSGLEN:
        try:
            chunk = cSocket.recv(min(MSGLEN - bytesRecieved, 2048))
        except OSError as e:
            print(e)
            threadClose(cSocket)
            return None
        if not chunk:  # b"" means the peer closed the connection
            threadClose(cSocket)
            return None
        chunks.append(chunk)
        bytesRecieved += len(chunk)
    # Decode once after joining: a multi-byte UTF-8 character can be split
    # across recv() chunks. Invalid bytes are replaced rather than crashing.
    returnStr = b"".join(chunks).decode("utf-8", errors="replace")
    print("msg '" + returnStr.strip() +
          "' recieved with total bytes: " + str(bytesRecieved))
    return returnStr.strip()
# returns user balance of user with userID
def balance(cur, userID):
    """Look up *userID* and return column 5 of its first Users row, or None if absent."""
    rows = getUserInfo(cur, userID)
    if rows is None:
        return None
    first_row = rows[0]
    return first_row[5]
def who():
    """Return one "<name>-<host>:<port> \\n" line per non-empty usersOnline slot.

    The address for slot i is read from clientAddresses[i].
    """
    lines = []
    for slot, name in enumerate(usersOnline):
        if name == "":
            continue  # nobody logged in on this slot
        address = clientAddresses[slot]
        lines.append(name + f"-{address[0]}:{address[1]} \n")
    return "".join(lines)
# deposit Funds in account
def deposit(cur: sql.Cursor, deposit_amount, user_id):
    """Add *deposit_amount* to the usd_balance of user *user_id* and commit.

    :param cur: sqlite3 cursor; its connection is committed on success
    :param deposit_amount: amount as received from the client (string or number)
    :param user_id: ID of the user to credit
    :return: a human-readable result string; failures start with "400"
    """
    # Check if user exists in database
    cur.execute("SELECT * FROM Users WHERE ID = ?", (user_id,))
    if cur.fetchone() is None:
        return f"400 User with ID {user_id} does not exist"
    # Reject unparseable, non-positive, NaN and infinite amounts: a negative
    # "deposit" was an unchecked withdrawal, and a bad string crashed the thread.
    try:
        amount = float(deposit_amount)
    except (TypeError, ValueError):
        return f"400 Invalid deposit amount {deposit_amount}"
    if not 0 < amount < float("inf"):
        return f"400 Invalid deposit amount {deposit_amount}"
    # Single UPDATE instead of read-then-write of the balance
    cur.execute("UPDATE Users SET usd_balance = usd_balance + ? WHERE ID = ?",
                (amount, user_id))
    cur.connection.commit()
    return f"Deposited {deposit_amount} into the account of user {user_id}"
# buy stock function
def buy_stock(cur: sql.Cursor, ticker, quantity, stock_price, user_id):
    """Buy *quantity* shares of *ticker* at *stock_price* for user *user_id*.

    Adds to (or creates) the user's Stocks row, debits usd_balance by
    quantity * price, and commits.

    :param cur: sqlite3 cursor; its connection is committed on success
    :param ticker: stock symbol (also stored as the stock name for new rows)
    :param quantity: number of shares (string or number from the client)
    :param stock_price: price per share (string or number from the client)
    :param user_id: buyer's user ID
    :return: "success", or "failure" for invalid input, unknown user or
        insufficient funds
    """
    try:
        qty = float(quantity)
        price = float(stock_price)
    except (TypeError, ValueError):
        return "failure"
    # Non-positive values would turn a purchase into a credit; NaN/inf would
    # corrupt the stored balances.
    if not (0 < qty < float("inf") and 0 < price < float("inf")):
        return "failure"
    totalCost = qty * price
    # Check if user exists and has sufficient funds
    cur.execute("SELECT usd_balance FROM Users WHERE ID = ?", (user_id,))
    row = cur.fetchone()
    if row is None or row[0] < totalCost:
        return "failure"
    # Update stock quantity (store numbers, not the raw client strings)
    cur.execute("SELECT 1 FROM Stocks WHERE stock_symbol = ? AND user_id = ?",
                (ticker, user_id))
    if cur.fetchone() is not None:
        cur.execute("UPDATE Stocks SET stock_balance = stock_balance + ? WHERE stock_symbol = ? AND user_id = ?",
                    (qty, ticker, user_id))
    else:
        cur.execute("INSERT INTO Stocks (stock_symbol, stock_name, stock_balance, user_id) VALUES (?, ?, ?, ?)",
                    (ticker, ticker, qty, user_id))
    # Update user balance
    cur.execute("UPDATE Users SET usd_balance = usd_balance - ? WHERE ID = ?",
                (totalCost, user_id))
    cur.connection.commit()
    return "success"
# sell stock function
def sell_stock(cur: sql.Cursor, ticker, quantity, stock_price, user_id):
    """Sell *quantity* shares of *ticker* at *stock_price* for user *user_id*.

    Reduces the user's Stocks row (deleting it when the remainder is below
    1e-5), credits usd_balance by quantity * price, and commits.

    :param cur: sqlite3 cursor; its connection is committed on success
    :param ticker: stock symbol
    :param quantity: number of shares (string or number from the client)
    :param stock_price: price per share (string or number from the client)
    :param user_id: seller's user ID
    :return: "success"; "notExist" if the user holds no such stock;
        "lessQuantity" if the holding is too small or quantity/price is not a
        positive finite number (the caller only understands these three codes)
    """
    cur.execute("SELECT * FROM Stocks WHERE stock_symbol = ? AND user_id = ?",
                (ticker, user_id))
    stock = cur.fetchone()
    if stock is None:  # check if user has any of said stock
        return "notExist"
    try:
        qty = float(quantity)
        price = float(stock_price)
    except (TypeError, ValueError):
        return "lessQuantity"
    # A negative quantity would act as a buy without any balance check.
    if not (0 < qty < float("inf") and 0 < price < float("inf")):
        return "lessQuantity"
    held = float(stock[3])  # stock_balance column
    if held < qty:  # check if user has enough of the stock
        return "lessQuantity"
    # Delete the row when (almost) everything is sold, otherwise reduce it
    threshold = 1e-5
    if held - qty < threshold:
        cur.execute("DELETE FROM Stocks WHERE stock_symbol = ? AND user_id = ?",
                    (ticker, user_id))
    else:
        cur.execute("UPDATE Stocks SET stock_balance = stock_balance - ? WHERE stock_symbol = ? AND user_id = ?",
                    (qty, ticker, user_id))
    # Update user balance
    cur.execute("UPDATE Users SET usd_balance = usd_balance + ? WHERE ID = ?",
                (qty * price, user_id))
    cur.connection.commit()
    return "success"
# lookup function from PA2
def lookup_stock(cur: sql.Cursor, stock_name, user_id):
    """Return a protocol response listing the Stocks rows for (*stock_name*, *user_id*).

    On a match the reply is "200 OK " followed by one "[c0,c1,c2,balance,c4] "
    group per row, with the balance rounded to 2 places. Otherwise it is a
    "404 ERROR ..." message.
    """
    matches = cur.execute(
        "SELECT * FROM Stocks WHERE stock_symbol = ? AND user_id = ?",
        (stock_name, user_id,),
    ).fetchall()
    if not matches:
        return f"404 ERROR Your search for {stock_name} did not match any records."
    groups = "".join(
        f"[{row[0]},{row[1]},{row[2]},{round(row[3], 2)},{row[4]}] "
        for row in matches
    )
    return f"200 OK " + groups
# returns a list of all stock tuples with user_id (all users' stocks for root)
def list_stocks(cur, clientIndex, clientUID):
    """Return Stocks rows visible to the client on slot *clientIndex*.

    If usersOnline[clientIndex] is "root", every Stocks row is returned as a
    list whose column 4 (user_id) is replaced by the owner's user_name (None if
    that user no longer exists). Otherwise the rows with user_id == *clientUID*
    are returned as tuples.
    """
    # Guard the index: threadClose() pops entries, so usersOnline can be shorter
    # than the slot number a client thread was given.
    isRoot = 0 <= clientIndex < len(usersOnline) and usersOnline[clientIndex] == "root"
    if not isRoot:
        cur.execute("SELECT * FROM Stocks WHERE user_id = ?", (clientUID,))
        return cur.fetchall()
    # One LEFT JOIN instead of one query per stock; orphaned stocks no longer
    # crash on fetchone()[0].
    cur.execute("SELECT Stocks.*, Users.user_name FROM Stocks "
                "LEFT JOIN Users ON Users.ID = Stocks.user_id")
    stocksList = []
    for row in cur.fetchall():
        entry = list(row[:-1])   # the Stocks columns
        entry[4] = row[-1]       # replace user_id with user_name
        stocksList.append(entry)
    return stocksList
# server shutdown command
def shutdown():
    """Stop the server: clear `status`, disconnect every client, close serverSocket."""
    global status
    status = False
    # Work on a snapshot because threadClose() removes entries from clientSockets.
    snapshot = list(clientSockets)
    for sock in snapshot:
        threadClose(sock)
    print("Server shutting down...")
    serverSocket.close()
def login(cur, userID, password):
    """Check *password* for the user whose user_name is *userID*.

    :return: "notExist" if no such user; "passWrong <col3>" on a bad password;
        "success <ID> <col3>" on success (col3 is column 3 of the Users row).
    """
    cur.execute("SELECT * FROM Users WHERE user_name = ?", (userID,))
    row = cur.fetchone()
    if row is None:
        return "notExist"
    # NOTE(review): the password is compared against column 4 in plain text
    if row[4] != password:
        return "passWrong " + str(row[3])
    return "success " + str(row[0]) + " " + str(row[3])
def threadClose(clientSocket, clientThread=None):
sockID = None
threadID = None
userIndex = None
for i, sock in enumerate(clientSockets):
if sock.getpeername()[0] == clientSocket.getpeername()[0] and sock.getpeername()[1] == clientSocket.getpeername()[1]:
sockID = i
threadID = i
userIndex = i
if sockID != None and threadID != None:
if clientThread == None:
clientThread = threading.current_thread()
global status
if status == False:
sendMsg("shutdown", clientSocket)
clientSockets.pop(sockID)
clientThreads.pop(threadID)
try:
user = usersOnline[userIndex]
usersOnline.pop(userIndex)
except Exception as e:
pass
addr = clientSocket.getpeername()
print(f"Client with address {addr[0]}:{addr[1]} disconnected")
return True
return False
def isSocketClose(sock: socket.socket):
    """Return True if *sock* is closed locally or the peer has closed/reset it.

    Peeks (MSG_PEEK) without blocking, so no data is consumed. The socket is
    always restored to blocking mode afterwards.
    """
    if sock.fileno() == -1:
        return True  # closed on our side; setblocking() would raise EBADF
    try:
        sock.setblocking(False)
        data = sock.recv(2048, socket.MSG_PEEK)
    except BlockingIOError:
        return False  # socket is open and reading from it would block
    except ConnectionResetError:
        return True  # socket was closed for some other reason
    except Exception:
        return False
    finally:
        # Restore blocking mode on every path (it used to stay non-blocking
        # after ConnectionResetError).
        try:
            sock.setblocking(True)
        except OSError:
            pass
    return len(data) == 0  # b"" means orderly shutdown by the peer
def _set_slot(items, index, value, filler):
    """Store *value* at items[index], first padding *items* with *filler* if too short."""
    # usersOnline/clientAddresses are indexed by a thread's clientIndex, but
    # threadClose() pops entries, so the slot may not exist; list.insert() at a
    # missing index silently appended at the wrong position.
    while len(items) <= index:
        items.append(filler)
    items[index] = value


def _user_at(index):
    """Return usersOnline[index], or "" if that slot does not exist."""
    if 0 <= index < len(usersOnline):
        return usersOnline[index]
    return ""


# thread function
def threadLoop(clientSocket: socket.socket, clientIndex):
    """Serve one client until it quits or disconnects, or until the server shuts down.

    Reads one frame at a time with recieveMsg(), dispatches on the command
    (case-insensitive prefix, or exact match for shutdown/quit/logout) and
    replies with sendMsg(). A command that fails because of missing/invalid
    arguments or a database error gets "400 ERROR Invalid Command" instead of
    killing the thread.

    :param clientSocket: this client's socket (registered in clientSockets)
    :param clientIndex: slot used for usersOnline / clientAddresses
    """
    clientUID = None
    clientUserName = ""
    # NOTE(review): presumably pydb.getDB() opens a connection for this thread
    # (sqlite3 connections are thread-bound by default) - confirm; it is never closed here.
    db = pydb.getDB()
    cur = db.cursor()
    # main loop
    while status and not isSocketClose(clientSocket):
        msg = recieveMsg(clientSocket)
        if msg is None:  # deregistered or disconnected
            break
        cmd = msg.lower()
        try:
            if cmd == "shutdown":  # shutdown command
                if _user_at(clientIndex) == "root":
                    sendMsg("200 OK Shutdown Initiated", clientSocket)
                    shutdown()
                else:
                    print("ERROR 400 User is not root, cannot issue shutdown command")
                    sendMsg("ERROR 400 User Not Root", clientSocket)
            elif cmd == "quit":  # quit command - listen for next client afterwards
                threadClose(clientSocket)
            elif cmd.startswith("login"):
                params = msg[6:].split()
                loginTry = login(cur, params[0], params[1])
                if loginTry.startswith("success"):
                    # loginTry is "success <ID> <name>"; the old fixed slices
                    # [8:9]/[10:] broke for IDs with more than one digit.
                    _, clientUID, clientUserName = loginTry.split(" ", 2)
                    _set_slot(usersOnline, clientIndex, clientUserName, "")
                    _set_slot(clientAddresses, clientIndex, clientSocket.getpeername(), None)
                    sendMsg(f"200 OK {clientUID} Successfully logged in user {clientUserName} with userID {clientUID}", clientSocket)
                elif loginTry == "notExist":
                    sendMsg("403 ERROR User does not exist", clientSocket)
                elif loginTry.startswith("passWrong"):
                    sendMsg(f"403 Error Password incorrect for user {loginTry[10:]}", clientSocket)
            elif cmd == "logout":
                clientUID = None
                clientUserName = ""
                _set_slot(usersOnline, clientIndex, "", "")
                _set_slot(clientAddresses, clientIndex, None, None)
                sendMsg("200 OK", clientSocket)
            elif cmd.startswith("balance"):
                if clientUID is None:
                    sendMsg("403 ERROR User not logged in", clientSocket)
                else:
                    userBalance = balance(cur, clientUID)
                    if userBalance is None:
                        sendMsg("404 ERROR User not found", clientSocket)
                    else:
                        sendMsg("200 OK " + str(round(userBalance, 2)), clientSocket)
            elif cmd.startswith("deposit"):
                params = msg[8:].split()
                dep = deposit(cur, params[0], params[1])
                sendMsg("200 OK " + str(dep), clientSocket)
            elif cmd.startswith("lookup"):
                params = msg[7:].split()
                lup = lookup_stock(cur, params[0].upper(), params[1])
                sendMsg(lup, clientSocket)
            elif cmd.startswith("list"):  # root sees every user's stocks, others their own
                stocks = list_stocks(cur, clientIndex, clientUID)
                sendString = "".join(
                    f"[{stock[0]},{stock[1]},{stock[2]},{round(stock[3], 2)},{stock[4]}] "
                    for stock in stocks)
                sendMsg("200 OK " + sendString, clientSocket)
            elif cmd.startswith("buy"):  # buy function
                params = msg[4:].split()
                success = buy_stock(
                    cur, params[0].upper(), params[1], params[2], params[3])
                newBal = balance(cur, params[3])
                if success == "success":
                    sendMsg(
                        "200 OK " + f"Successfully bought {params[1]} of {params[0].upper()}. New balance of user {params[3]}: {round(newBal, 2)}", clientSocket)
                else:
                    sendMsg(
                        "400 ERROR " + "Unable to buy stock. User balance insuffecient", clientSocket)
            elif cmd.startswith("sell"):  # sell function
                params = msg[5:].split()
                success = sell_stock(
                    cur, params[0].upper(), params[1], params[2], params[3])
                newBal = balance(cur, params[3])
                if success == "success":
                    sendMsg(
                        "200 OK " + f"Successfully sold {params[1]} of stock {params[0].upper()}. New balance of user {params[3]}: {round(newBal, 2)}", clientSocket)
                elif success == "lessQuantity":
                    sendMsg(
                        "400 ERROR " + f"Unable to sell stock. User holds insuffecient amount of stock {params[0].upper()}", clientSocket)
                elif success == "notExist":
                    sendMsg(
                        "401 ERROR " + f"Unable to sell stock. Stock entry doesn't exist", clientSocket)
                else:
                    # always answer, otherwise the client waits forever
                    sendMsg("400 ERROR Unable to sell stock", clientSocket)
            elif cmd.startswith("who"):
                if _user_at(clientIndex) == "root":
                    ret = who()
                    sendMsg("200 OK " + f"{ret}", clientSocket)
                else:
                    print("ERROR 400 User is not root, cannot issue who command")
                    sendMsg("ERROR 400 User Not Root", clientSocket)
            elif not isSocketClose(clientSocket):
                sendMsg("400 ERROR Invalid Command", clientSocket)
        except (IndexError, ValueError, TypeError, OSError, sql.Error) as e:
            # Per-request boundary: missing arguments, bad numbers or a DB error
            # used to kill this thread and leave the client without a reply.
            print(f"Command {msg!r} failed: {e}")
            sendMsg("400 ERROR Invalid Command", clientSocket)
# main server loop - accept connections from clients
while status:
    if len(clientSockets) >= MAXCLIENTS:
        # Server full: wait briefly instead of spinning at 100% CPU until a
        # client disconnects (pending connections wait in the listen backlog).
        time.sleep(0.1)
        continue
    try:
        (clientSocket, clientAddr) = serverSocket.accept()
    except OSError:
        break  # accept() fails once shutdown() has closed serverSocket
    clientSockets.append(clientSocket)
    # NOTE(review): after a disconnect, len()-1 can reuse a slot that a live
    # client still owns; a free-slot allocator would be more robust.
    clientIndex = len(clientSockets) - 1
    cThread = threading.Thread(target=threadLoop, args=(clientSocket, clientIndex))
    clientThreads.append(cThread)
    cThread.start()
    print(f"\nConnection accepted from {clientAddr[0]}:{clientAddr[1]}\nStarting new thread for client\n")