diff --git a/mh/database.py b/mh/database.py index 058e163..9d200de 100644 --- a/mh/database.py +++ b/mh/database.py @@ -360,7 +360,11 @@ def __init__(self): self.servers = new_servers() def get_support_code(self, session): - """Get the online support code or create one.""" + """SESSION LOCKING + + Get the online support code or create one. + """ + session.lock() support_code = session.online_support_code if support_code is None: while True: @@ -368,6 +372,7 @@ def get_support_code(self, session): if support_code not in self.consoles: session.online_support_code = support_code break + session.unlock() # Create some default users if support_code not in self.consoles: @@ -378,12 +383,17 @@ def get_support_code(self, session): return support_code def new_pat_ticket(self, session): - """Generates a new PAT ticket for the session.""" + """SESSION LOCKING + + Generates a new PAT ticket for the session. + """ + session.lock() while True: session.pat_ticket = new_random_str(11) if session.pat_ticket not in self.sessions: break self.sessions[session.pat_ticket] = session + session.unlock() return session.pat_ticket def use_capcom_id(self, session, capcom_id, name=None): @@ -398,9 +408,13 @@ def use_capcom_id(self, session, capcom_id, name=None): return name def use_user(self, session, index, name): - """Use User from the slot or create one if empty""" + """SESSION LOCKING + + Use User from the slot or create one if empty. + """ assert 1 <= index <= 6, "Invalid Capcom ID slot" index -= 1 + session.lock() users = self.consoles[session.online_support_code] while users[index] == "******": capcom_id = new_random_str(6) @@ -413,6 +427,7 @@ def use_user(self, session, index, name): name = self.use_capcom_id(session, capcom_id, name) session.capcom_id = capcom_id session.hunter_name = name + session.unlock() def get_session(self, pat_ticket): """Returns existing PAT session or None.""" @@ -451,18 +466,24 @@ def get_users(self, session, first_index, count): return capcom_ids def join_server(self, session, index): + """SESSION LOCKING""" + session.lock() if session.local_info["server_id"] is not None: self.leave_server(session, session.local_info["server_id"]) server = self.get_server(index) server.players.add(session) session.local_info["server_id"] = index session.local_info["server_name"] = server.name + session.unlock() return server def leave_server(self, session, index): + """SESSION LOCKING""" + session.lock() self.get_server(index).players.remove(session) session.local_info["server_id"] = None session.local_info["server_name"] = None + session.unlock() def get_server_time(self): pass @@ -486,20 +507,26 @@ def get_gate(self, server_id, index): return gates[index - 1] def join_gate(self, session, server_id, index): + """SESSION LOCKING""" + session.lock() gate = self.get_gate(server_id, index) gate.parent.players.remove(session) gate.players.add(session) session.local_info["gate_id"] = index session.local_info["gate_name"] = gate.name + session.unlock() return gate def leave_gate(self, session): + """SESSION LOCKING""" + session.lock() gate = self.get_gate(session.local_info["server_id"], session.local_info["gate_id"]) gate.parent.players.add(session) gate.players.remove(session) session.local_info["gate_id"] = None session.local_info["gate_name"] = None + session.unlock() def get_cities(self, server_id, gate_id): return self.get_gate(server_id, gate_id).cities @@ -557,14 +584,19 @@ def create_city(self, session, server_id, gate_id, index, return city def join_city(self, session, server_id, gate_id, index): + """SESSION LOCKING""" + session.lock() city = self.get_city(server_id, gate_id, index) city.parent.players.remove(session) city.players.add(session) session.local_info["city_id"] = index session.local_info["city_name"] = city.name + session.unlock() return city def leave_city(self, session): + """SESSION LOCKING""" + session.lock() city = self.get_city(session.local_info["server_id"], session.local_info["gate_id"], session.local_info["city_id"]) @@ -574,6 +606,7 @@ def leave_city(self, session): city.clear_circles() session.local_info["city_id"] = None session.local_info["city_name"] = None + session.unlock() def layer_detail_search(self, server_type, fields): cities = [] diff --git a/mh/session.py b/mh/session.py index aaaa954..1806ee5 100644 --- a/mh/session.py +++ b/mh/session.py @@ -20,6 +20,7 @@ """ import struct +from threading import RLock import mh.database as db import mh.pat_item as pati @@ -70,9 +71,14 @@ def __init__(self, connection_handler): self.binary_setting = b"" self.search_payload = None self.hunter_info = pati.HunterSettings() + self._lock = RLock() def get(self, connection_data): - """Return the session associated with the connection data, if any.""" + """SESSION LOCKING + + Return the session associated with the connection data, if any. + """ + self.lock() if hasattr(connection_data, "pat_ticket"): self.pat_ticket = to_str( pati.unpack_binary(connection_data.pat_ticket) @@ -81,9 +87,10 @@ def get(self, connection_data): self.online_support_code = to_str( pati.unpack_string(connection_data.online_support_code) ) + session = DB.get_session(self.pat_ticket) or self if session != self: - assert session.connection is None, "Session is already in use" + assert session.connection is None, "Session is already in use"+str(self.unlock()) session.connection = self.connection self.connection = None @@ -93,6 +100,7 @@ def get(self, connection_data): session.request_reconnection = \ not ("pat_ticket" in connection_data or "online_support_code" in connection_data) + self.unlock() return session def get_support_code(self): @@ -100,23 +108,29 @@ def get_support_code(self): return DB.get_support_code(self) def disconnect(self): - """Disconnect the current session. + """SESSION LOCKING + Disconnect the current session. It doesn't purge the session state nor its PAT ticket. """ + self.lock() self.layer_end() self.connection = None DB.disconnect_session(self) + self.unlock() def delete(self): - """Delete the current session. + """SESSION LOCKING + Delete the current session. TODO: - Find a good place to purge old tickets. - We should probably create a SessionManager thread per server. """ + self.lock() if not self.request_reconnection: DB.delete_session(self) + self.unlock() def is_jap(self): """TODO: Heuristic using the connection data to detect region.""" @@ -136,30 +150,42 @@ def get_servers(self): return DB.get_servers() def get_server(self): - assert self.local_info['server_id'] is not None - return DB.get_server(self.local_info['server_id']) + server_id = self.local_info['server_id'] + assert server_id is not None + return DB.get_server(server_id) def get_gate(self): - assert self.local_info['gate_id'] is not None - return DB.get_gate(self.local_info['server_id'], - self.local_info['gate_id']) + server_id = self.local_info['server_id'] + gate_id = self.local_info['gate_id'] + assert server_id is not None and gate_id is not None + return DB.get_gate(server_id, + gate_id) def get_city(self): - assert self.local_info['city_id'] is not None - return DB.get_city(self.local_info['server_id'], - self.local_info['gate_id'], - self.local_info['city_id']) + server_id = self.local_info['server_id'] + gate_id = self.local_info['gate_id'] + city_id = self.local_info['city_id'] + assert server_id is not None and gate_id is not None and city_id is not None + return DB.get_city(server_id, + gate_id, + city_id) def get_circle(self): - assert self.local_info['circle_id'] is not None - return self.get_city().circles[self.local_info['circle_id']] + circle_id = self.local_info['circle_id'] + assert circle_id is not None + return self.get_city().circles[circle_id] def layer_start(self): + """SESSION LOCKING""" + self.lock() self.layer = 0 self.state = SessionState.LOG_IN + self.unlock() return pati.getDummyLayerData() def layer_end(self): + """SESSION LOCKING""" + self.lock() if self.layer > 1: # City path if self.local_info['circle_id'] is not None: @@ -174,15 +200,20 @@ def layer_end(self): self.leave_server() self.layer = 0 self.state = SessionState.UNKNOWN + self.unlock() def layer_down(self, layer_id): + """SESSION LOCKING""" + self.lock() if self.layer == 0: self.join_gate(layer_id) elif self.layer == 1: self.join_city(layer_id) else: + self.unlock() assert False, "Can't go down a layer" self.layer += 1 + self.unlock() def layer_create(self, layer_id, settings, optional_fields): if self.layer == 1: @@ -193,13 +224,17 @@ def layer_create(self, layer_id, settings, optional_fields): self.layer_down(layer_id) def layer_up(self): + """SESSION LOCKING""" + self.lock() if self.layer == 1: self.leave_gate() elif self.layer == 2: self.leave_city() else: + self.unlock() assert False, "Can't go up a layer" self.layer -= 1 + self.unlock() def layer_detail_search(self, detailed_fields): server_type = self.get_server().server_type @@ -213,16 +248,18 @@ def join_server(self, server_id): return DB.join_server(self, server_id) def get_layer_children(self): - if self.layer == 0: + layer = self.layer + if layer == 0: return self.get_gates() - elif self.layer == 1: + elif layer == 1: return self.get_cities() assert False, "Unsupported layer to get children" def get_layer_sibling(self): - if self.layer == 1: + layer = self.layer + if layer == 1: return self.get_gates() - elif self.layer == 2: + elif layer == 2: return self.get_cities() assert False, "Unsupported layer to get sibling" @@ -251,12 +288,18 @@ def get_gates(self): return DB.get_gates(self.local_info["server_id"]) def join_gate(self, gate_id): + """SESSION LOCKING""" + self.lock() DB.join_gate(self, self.local_info["server_id"], gate_id) self.state = SessionState.GATE + self.unlock() def leave_gate(self): + """SESSION LOCKING""" + self.lock() DB.leave_gate(self) self.state = SessionState.LOG_IN + self.unlock() def get_cities(self): return DB.get_cities(self.local_info["server_id"], @@ -275,15 +318,21 @@ def create_city(self, city_id, settings, optional_fields): city_id, settings, optional_fields) def join_city(self, city_id): + """SESSION LOCKING""" + self.lock() DB.join_city(self, self.local_info["server_id"], self.local_info["gate_id"], city_id) self.state = SessionState.CITY + self.unlock() def leave_city(self): + """SESSION LOCKING""" + self.lock() DB.leave_city(self) self.state = SessionState.GATE + self.unlock() def try_transfer_city_leadership(self): if self.local_info['city_id'] is None: @@ -317,9 +366,12 @@ def try_transfer_circle_leadership(self): return None, None def join_circle(self, circle_id): + """SESSION LOCKING""" # TODO: Move this to the database + self.lock() self.local_info['circle_id'] = circle_id self.state = SessionState.CIRCLE + self.unlock() def set_circle_standby(self, val): assert self.state == SessionState.CIRCLE or \ @@ -339,10 +391,13 @@ def set_in_quest(self): self.state = SessionState.QUEST def leave_circle(self): + """SESSION LOCKING""" # TODO: Move this to the database + self.lock() circle = self.get_circle() self.local_info['circle_id'] = None self.state = SessionState.CITY + self.unlock() if circle.leader == self: circle.reset() @@ -350,13 +405,14 @@ def leave_circle(self): circle.players.remove(self) def get_layer_players(self): - if self.layer == 0: + layer = self.layer + if layer == 0: server = self.get_server() return server.players - elif self.layer == 1: + elif layer == 1: gate = self.get_gate() return gate.players - elif self.layer == 2: + elif layer == 2: city = self.get_city() return city.players else: @@ -380,3 +436,9 @@ def get_optional_fields(self): (1, (weapon_type << 24) | location), (2, hunter_rank << 16) ] + + def lock(self): + self._lock.acquire() + + def unlock(self): + self._lock.release()