diff --git a/config.ini b/config.ini index 2f3733a..908d5a7 100644 --- a/config.ini +++ b/config.ini @@ -2,6 +2,25 @@ DefaultCert = cert/server.crt DefaultKey = cert/server.key +[CENTRAL] +CentralIP = 0.0.0.0 +CentralCrossconnectPort = 8300 +CrossconnectSSL = ON + +[SERVER1] +Name = Valor1 +ServerType = 1 +Capacity = 400 +IP = 0.0.0.0 +Port = 8204 + +[SERVER2] +Name = Greed1 +ServerType = 4 +Capacity = 400 +IP = 0.0.0.0 +Port = 8205 + [OPN] IP = 0.0.0.0 Port = 8200 diff --git a/fmp_server.py b/fmp_server.py index c7efd3b..7db87c0 100644 --- a/fmp_server.py +++ b/fmp_server.py @@ -6,6 +6,7 @@ import struct +from mh.state import Players, get_instance import mh.pat_item as pati from mh.constants import PatID4 from mh.pat import PatRequestHandler, PatServer @@ -14,7 +15,9 @@ class FmpServer(PatServer): """Basic FMP server class.""" - pass + def close(self): + PatServer.close(self) + get_instance().close_cache() class FmpRequestHandler(PatRequestHandler): @@ -24,7 +27,16 @@ def recvAnsConnection(self, packet_id, data, seq): """AnsConnection packet.""" connection_data = pati.ConnectionData.unpack(data) self.server.debug("Connection: {!r}".format(connection_data)) - self.sendNtcLogin(3, connection_data, seq) + loaded_session = self.session.session_ready(connection_data) + if loaded_session: + if get_instance().server_id != 0: + self.session.set_session_ready(connection_data, False) + loaded_session.connection = self + self.session = loaded_session + get_instance().register_pat_ticket(self.session) + self.sendNtcLogin(3, connection_data, seq) + else: + self.session.set_session_ready(connection_data, (self, connection_data, seq)) def sendAnsLayerDown(self, layer_id, layer_set, seq): """AnsLayerDown packet. diff --git a/master_server.py b/master_server.py index fcd8c7f..f709371 100644 --- a/master_server.py +++ b/master_server.py @@ -15,15 +15,18 @@ from other.debug import register_debug_signal from other.utils import create_server_from_base +from other.cache import Cache -def create_servers(silent=False, debug_mode=False): +def create_servers(server_id, silent=False, debug_mode=False, no_timeout=False): """Create servers and check if it has ui.""" servers = [] has_ui = False - for module in (OPN, LMP, FMP, RFP): + for module in ((OPN, LMP, FMP, RFP) if server_id==0 else (FMP,)): server, has_window = create_server_from_base(*module.BASE, + server_id=server_id, silent=silent, - debug_mode=debug_mode) + debug_mode=debug_mode, + no_timeout=no_timeout) has_ui = has_ui or has_window servers.append(server) return servers, has_ui @@ -31,10 +34,10 @@ def create_servers(silent=False, debug_mode=False): def main(args): """Master server main function.""" - register_debug_signal() - - servers, has_ui = create_servers(silent=args.silent, - debug_mode=args.debug_mode) + servers, has_ui = create_servers(server_id=args.server_id, + silent=args.silent, + debug_mode=args.debug_mode, + no_timeout=args.no_timeout) threads = [ threading.Thread( target=server.serve_forever, @@ -42,6 +45,10 @@ def main(args): ) for server in servers ] + cache = Cache(server_id=args.server_id, debug_mode=args.debug_mode, + log_to_file=True, log_to_console=not args.silent, + log_to_window=False) + threads.append(threading.Thread(target=cache.maintain_connection)) for thread in threads: thread.start() @@ -92,7 +99,15 @@ def interactive_mode(local=locals()): help="silent console logs") parser.add_argument("-d", "--debug_mode", action="store_true", dest="debug_mode", - help="enable debug mode, disabling timeouts and \ - lower logging verbosity level") + help="enable debug mode, \ + raising logging verbosity level") + parser.add_argument("-t", "--no_timeout", action="store_true", + dest="no_timeout", + help="disable player timeouts") + parser.add_argument("-S", "--server_id", type=int, default=0, + dest="server_id", + help="specifies the server id used to pull info \ + from the config file (0 for central)") + args = parser.parse_args() main(args) diff --git a/mh/constants.py b/mh/constants.py index 4040c4a..ba09fd9 100644 --- a/mh/constants.py +++ b/mh/constants.py @@ -307,10 +307,16 @@ def slot(item, qty): b"Roadmap for a list of features we are working on.", b"
Welcome to Loc Lac!" ]) +MAINTENANCE = b"
".join([ + b"
Monster Hunter 3 (Tri) Server Project", + b"
MH3SP is currently down for maintenance.", + b"
Please check back later!" +]) CHARGE = b"""
MH3 Server Project - No charge.""" # VULGARITY_INFO = b"""MH3 Server Project - Vulgarity info (low).""" VULGARITY_INFO = b"" -FMP_VERSION = 1 +FMP_CENTRAL_VERSION = 1 +FMP_VERSION = 2 TIME_STATE = 0 IS_JAP = False diff --git a/mh/database.py b/mh/database.py index cbe65b7..4ebf359 100644 --- a/mh/database.py +++ b/mh/database.py @@ -22,359 +22,10 @@ BLANK_CAPCOM_ID = "******" -RESERVE_DC_TIMEOUT = 40.0 - - def new_random_str(length=6): return "".join(random.choice(CHARSET) for _ in range(length)) -class ServerType(object): - OPEN = 1 - ROOKIE = 2 - EXPERT = 3 - RECRUITING = 4 - - -class LayerState(object): - EMPTY = 1 - FULL = 2 - JOINABLE = 3 - - -class Lockable(object): - def __init__(self): - self._lock = RLock() - - def lock(self): - return self - - def __enter__(self): - # Returns True if lock was acquired, False otherwise - return self._lock.acquire() - - def __exit__(self, *args): - self._lock.release() - - -class Players(Lockable): - def __init__(self, capacity): - assert capacity > 0, "Collection capacity can't be zero" - - self.slots = [None for _ in range(capacity)] - self.used = 0 - super(Players, self).__init__() - - def get_used_count(self): - return self.used - - def get_capacity(self): - return len(self.slots) - - def add(self, item): - with self.lock(): - if self.used >= len(self.slots): - return -1 - - item_index = self.index(item) - if item_index != -1: - return item_index - - for i, v in enumerate(self.slots): - if v is not None: - continue - - self.slots[i] = item - self.used += 1 - return i - - return -1 - - def remove(self, item): - assert item is not None, "Item != None" - - with self.lock(): - if self.used < 1: - return False - - if isinstance(item, int): - if item >= self.get_capacity(): - return False - - self.slots[item] = None - self.used -= 1 - return True - - for i, v in enumerate(self.slots): - if v != item: - continue - - self.slots[i] = None - self.used -= 1 - return True - - return False - - def index(self, item): - assert item is not None, "Item != None" - - for i, v in enumerate(self.slots): - if v == item: - return i - - return -1 - - def clear(self): - with self.lock(): - for i in range(self.get_capacity()): - self.slots[i] = None - - def find_first(self, **kwargs): - if self.used < 1: - return None - - for p in self.slots: - if p is None: - continue - - for k, v in kwargs.items(): - if getattr(p, k) != v: - break - else: - return p - - return None - - def find_by_capcom_id(self, capcom_id): - return self.find_first(capcom_id=capcom_id) - - def __len__(self): - return self.used - - def __iter__(self): - if self.used < 1: - return - - for i, v in enumerate(self.slots): - if v is None: - continue - - yield i, v - - -class Circle(Lockable): - def __init__(self, parent): - # type: (City) -> None - self.parent = parent - self.leader = None - self.players = Players(4) - self.departed = False - self.quest_id = 0 - self.embarked = False - self.password = None - self.remarks = None - - self.unk_byte_0x0e = 0 - super(Circle, self).__init__() - - def get_population(self): - return len(self.players) - - def get_capacity(self): - return self.players.get_capacity() - - def is_full(self): - return self.get_population() == self.get_capacity() - - def is_empty(self): - return self.leader is None - - def is_joinable(self): - return not self.departed and not self.is_full() - - def has_password(self): - return self.password is not None - - def reset_players(self, capacity): - with self.lock(): - self.players = Players(capacity) - - def reset(self): - with self.lock(): - self.leader = None - self.reset_players(4) - self.departed = False - self.quest_id = 0 - self.embarked = False - self.password = None - self.remarks = None - - self.unk_byte_0x0e = 0 - - -class City(Lockable): - LAYER_DEPTH = 3 - - def __init__(self, name, parent): - # type: (str, Gate) -> None - self.name = name - self.parent = parent - self.state = LayerState.EMPTY - self.players = Players(4) - self.optional_fields = [] - self.leader = None - self.reserved = None - self.circles = [ - # One circle per player - Circle(self) for _ in range(self.get_capacity()) - ] - super(City, self).__init__() - - def get_population(self): - return len(self.players) - - def in_quest_players(self): - return sum(p.is_in_quest() for _, p in self.players) - - def get_capacity(self): - return self.players.get_capacity() - - def get_state(self): - if self.reserved: - return LayerState.FULL - - size = self.get_population() - if size == 0: - return LayerState.EMPTY - elif size < self.get_capacity(): - return LayerState.JOINABLE - else: - return LayerState.FULL - - def get_pathname(self): - pathname = self.name - it = self.parent - while it is not None: - pathname = it.name + "\t" + pathname - it = it.parent - return pathname - - def get_first_empty_circle(self): - with self.lock(): - for index, circle in enumerate(self.circles): - if circle.is_empty(): - return circle, index - return None, None - - def get_circle_for(self, leader_session): - with self.lock(): - for index, circle in enumerate(self.circles): - if circle.leader == leader_session: - return circle, index - return None, None - - def clear_circles(self): - with self.lock(): - for circle in self.circles: - circle.reset() - - def reserve(self, reserve): - with self.lock(): - if reserve: - self.reserved = time.time() - else: - self.reserved = None - - def reset(self): - with self.lock(): - self.state = LayerState.EMPTY - self.players.clear() - self.optional_fields = [] - self.leader = None - self.reserved = None - self.clear_circles() - - -class Gate(object): - LAYER_DEPTH = 2 - - def __init__(self, name, parent, city_count=40, player_capacity=100): - # type: (str, Server, int, int) -> None - self.name = name - self.parent = parent - self.state = LayerState.EMPTY - self.cities = [ - City("City{}".format(i), self) - for i in range(1, city_count+1) - ] - self.players = Players(player_capacity) - self.optional_fields = [] - - def get_population(self): - return len(self.players) + sum(( - city.get_population() - for city in self.cities - )) - - def get_capacity(self): - return self.players.get_capacity() - - def get_state(self): - size = self.get_population() - if size == 0: - return LayerState.EMPTY - elif size < self.get_capacity(): - return LayerState.JOINABLE - else: - return LayerState.FULL - - -class Server(object): - LAYER_DEPTH = 1 - - def __init__(self, name, server_type, gate_count=40, capacity=2000, - addr=None, port=None): - self.name = name - self.parent = None - self.server_type = server_type - self.addr = addr - self.port = port - self.gates = [ - Gate("City Gate{}".format(i), self) - for i in range(1, gate_count+1) - ] - self.players = Players(capacity) - - def get_population(self): - return len(self.players) + sum(( - gate.get_population() for gate in self.gates - )) - - def get_capacity(self): - return self.players.get_capacity() - - -def new_servers(): - servers = [] - servers.extend([ - Server("Valor{}".format(i), ServerType.OPEN) - for i in range(1, 5) - ]) - servers.extend([ - Server("Beginners{}".format(i), ServerType.ROOKIE) - for i in range(1, 3) - ]) - servers.extend([ - Server("Veterans{}".format(i), ServerType.EXPERT) - for i in range(1, 3) - ]) - servers.extend([ - Server("Greed{}".format(i), ServerType.RECRUITING) - for i in range(1, 5) - ]) - return servers - - class TempDatabase(object): """A temporary database. @@ -389,11 +40,8 @@ def __init__(self): self.consoles = { # Online support code => Capcom IDs } - self.sessions = { - # PAT Ticket => Owner's session - } self.capcom_ids = { - # Capcom ID => Owner's name and session + # Capcom ID => Hunter name } self.friend_requests = { # Capcom ID => List of friend requests from Capcom IDs @@ -402,7 +50,6 @@ def __init__(self): # Capcom ID => List of Capcom IDs # TODO: May need stable index, see Players class } - self.servers = new_servers() def get_support_code(self, session): """Get the online support code or create one.""" @@ -419,244 +66,19 @@ def get_support_code(self, session): self.consoles[support_code] = [BLANK_CAPCOM_ID] * 6 return support_code - def new_pat_ticket(self, session): - """Generates a new PAT ticket for the session.""" - while True: - session.pat_ticket = new_random_str(11) - if session.pat_ticket not in self.sessions: - break - self.sessions[session.pat_ticket] = session - return session.pat_ticket - - def use_capcom_id(self, session, capcom_id, name=None): - """Attach the session to the Capcom ID.""" - assert capcom_id in self.capcom_ids, "Capcom ID doesn't exist" - - not_in_use = self.capcom_ids[capcom_id]["session"] is None - assert not_in_use, "Capcom ID is already in use" - - name = name or self.capcom_ids[capcom_id]["name"] - self.capcom_ids[capcom_id] = {"name": name, "session": session} + def get_capcom_ids(self, online_support_code): + """Get the Capcom IDs associated with an online support code.""" + return self.consoles[online_support_code] - # TODO: Check if stable index is required - if capcom_id not in self.friend_lists: - self.friend_lists[capcom_id] = [] - if capcom_id not in self.friend_requests: - self.friend_requests[capcom_id] = [] + def assign_capcom_id(self, online_support_code, index, capcom_id): + """Assign a Capcom ID to an online support code.""" + self.consoles[online_support_code][index] = capcom_id - return name + def assign_name(self, capcom_id, name): + self.capcom_ids[capcom_id] = name - def use_user(self, session, index, name): - """Use User from the slot or create one if empty""" - assert 1 <= index <= 6, "Invalid Capcom ID slot" - index -= 1 - users = self.consoles[session.online_support_code] - while users[index] == BLANK_CAPCOM_ID: - capcom_id = new_random_str(6) - if capcom_id not in self.capcom_ids: - self.capcom_ids[capcom_id] = {"name": name, "session": None} - users[index] = capcom_id - break - else: - capcom_id = users[index] - name = self.use_capcom_id(session, capcom_id, name) - session.capcom_id = capcom_id - session.hunter_name = name - - def get_session(self, pat_ticket): - """Returns existing PAT session or None.""" - session = self.sessions.get(pat_ticket) - if session and session.capcom_id: - self.use_capcom_id(session, session.capcom_id, session.hunter_name) - return session - - def disconnect_session(self, session): - """Detach the session from its Capcom ID.""" - if not session.capcom_id: - # Capcom ID isn't chosen yet with OPN/LMP servers - return - self.capcom_ids[session.capcom_id]["session"] = None - - def delete_session(self, session): - """Delete the session from the database.""" - self.disconnect_session(session) - pat_ticket = session.pat_ticket - if pat_ticket in self.sessions: - del self.sessions[pat_ticket] - - def get_users(self, session, first_index, count): - """Returns Capcom IDs tied to the session.""" - users = self.consoles[session.online_support_code] - capcom_ids = [ - (i, (capcom_id, self.capcom_ids.get(capcom_id, {}))) - for i, capcom_id in enumerate(users[:count], first_index) - ] - size = len(capcom_ids) - if size < count: - capcom_ids.extend([ - (index, (BLANK_CAPCOM_ID, {})) - for index in range(first_index+size, first_index+count) - ]) - return capcom_ids - - def join_server(self, session, index): - if session.local_info["server_id"] is not None: - self.leave_server(session, session.local_info["server_id"]) - server = self.get_server(index) - server.players.add(session) - session.local_info["server_id"] = index - session.local_info["server_name"] = server.name - return server - - def leave_server(self, session, index): - self.get_server(index).players.remove(session) - session.local_info["server_id"] = None - session.local_info["server_name"] = None - - def get_server_time(self): - pass - - def get_game_time(self): - pass - - def get_servers(self): - return self.servers - - def get_server(self, index): - assert 0 < index <= len(self.servers), "Invalid server index" - return self.servers[index - 1] - - def get_gates(self, server_id): - return self.get_server(server_id).gates - - def get_gate(self, server_id, index): - gates = self.get_gates(server_id) - assert 0 < index <= len(gates), "Invalid gate index" - return gates[index - 1] - - def join_gate(self, session, server_id, index): - gate = self.get_gate(server_id, index) - gate.parent.players.remove(session) - gate.players.add(session) - session.local_info["gate_id"] = index - session.local_info["gate_name"] = gate.name - return gate - - def leave_gate(self, session): - gate = self.get_gate(session.local_info["server_id"], - session.local_info["gate_id"]) - gate.parent.players.add(session) - gate.players.remove(session) - session.local_info["gate_id"] = None - session.local_info["gate_name"] = None - - def get_cities(self, server_id, gate_id): - return self.get_gate(server_id, gate_id).cities - - def get_city(self, server_id, gate_id, index): - cities = self.get_cities(server_id, gate_id) - assert 0 < index <= len(cities), "Invalid city index" - return cities[index - 1] - - def reserve_city(self, server_id, gate_id, index, reserve): - city = self.get_city(server_id, gate_id, index) - with city.lock(): - reserved_time = city.reserved - if reserve and reserved_time and \ - time.time()-reserved_time < RESERVE_DC_TIMEOUT: - return False - city.reserve(reserve) - return True - - def get_all_users(self, server_id, gate_id, city_id): - """Search for users in layers and its children. - - Let's assume wildcard search isn't possible for servers and gates. - A wildcard search happens when the id is zero. - """ - assert 0 < server_id, "Invalid server index" - assert 0 < gate_id, "Invalid gate index" - gate = self.get_gate(server_id, gate_id) - users = list(gate.players) - cities = [ - self.get_city(server_id, gate_id, city_id) - ] if city_id else self.get_cities(server_id, gate_id) - for city in cities: - users.extend(list(city.players)) - return users - - def find_users(self, capcom_id="", hunter_name=""): - assert capcom_id or hunter_name, "Search can't be empty" - users = [] - for user_id, user_info in self.capcom_ids.items(): - session = user_info["session"] - if not session: - continue - if capcom_id and capcom_id not in user_id: - continue - if hunter_name and \ - hunter_name.lower() not in user_info["name"].lower(): - continue - users.append(session) - return users - - def get_user_name(self, capcom_id): - if capcom_id not in self.capcom_ids: - return "" - return self.capcom_ids[capcom_id]["name"] - - def create_city(self, session, server_id, gate_id, index, - settings, optional_fields): - city = self.get_city(server_id, gate_id, index) - with city.lock(): - city.optional_fields = optional_fields - city.leader = session - city.reserved = None - return city - - def join_city(self, session, server_id, gate_id, index): - city = self.get_city(server_id, gate_id, index) - with city.lock(): - city.parent.players.remove(session) - city.players.add(session) - session.local_info["city_name"] = city.name - session.local_info["city_id"] = index - return city - - def leave_city(self, session): - city = self.get_city(session.local_info["server_id"], - session.local_info["gate_id"], - session.local_info["city_id"]) - with city.lock(): - city.parent.players.add(session) - city.players.remove(session) - if not city.get_population(): - city.reset() - session.local_info["city_id"] = None - session.local_info["city_name"] = None - - def layer_detail_search(self, server_type, fields): - cities = [] - - def match_city(city, fields): - with city.lock(): - return all(( - field in city.optional_fields - for field in fields - )) - - for server in self.servers: - if server.server_type != server_type: - continue - for gate in server.gates: - if not gate.get_population(): - continue - cities.extend([ - city - for city in gate.cities - if match_city(city, fields) - ]) - return cities + def get_name(self, capcom_id): + return self.capcom_ids.get(capcom_id, "") def add_friend_request(self, sender_id, recipient_id): # Friend invite can be sent to arbitrary Capcom ID @@ -692,7 +114,7 @@ def get_friends(self, capcom_id, first_index=None, count=None): begin = 0 if first_index is None else (first_index - 1) end = count if count is None else (begin + count) return [ - (k, self.capcom_ids[k]["name"]) + (k, self.capcom_ids[k]) for k in self.friend_lists[capcom_id] if k in self.capcom_ids # Skip unknown Capcom IDs ][begin:end] @@ -852,15 +274,45 @@ def force_update(self): (capcom_id, friend_id) ) + def assign_capcom_id(self, online_support_code, index, capcom_id): + """Assign a Capcom ID to an online support code.""" + self.consoles[online_support_code][index] = capcom_id + with self.connection as cursor: + cursor.execute( + "INSERT OR REPLACE INTO consoles VALUES (?,?,?,?)", + (online_support_code, index, capcom_id, "????") + ) + + def get_capcom_ids(self, online_support_code): + """Get list of associated Capcom IDs from an online support code.""" + with self.connection as cursor: + rows = cursor.execute("SELECT slot_index, capcom_id FROM consoles WHERE support_code = '{}'".format(online_support_code)) + ids = [] + for index, id in rows: + while len(ids) < index: + ids.append(BLANK_CAPCOM_ID) + ids.append(id) + while len(ids) < 6: + ids.append(BLANK_CAPCOM_ID) + return ids + + def get_name(self, capcom_id): + """Get the hunter name associated with a valid Capcom ID.""" + with self.connection as cursor: + rows = cursor.execute("SELECT name FROM consoles WHERE capcom_id = '{}'".format(capcom_id)) + names = [name for name, in rows] + if len(names): + return names[0] + return "" + def use_user(self, session, index, name): - result = self.parent.use_user(session, index, name) + """Insert the current hunter's info into a selected Capcom ID slot.""" with self.connection as cursor: cursor.execute( "INSERT OR REPLACE INTO consoles VALUES (?,?,?,?)", (session.online_support_code, index, session.capcom_id, session.hunter_name) ) - return result def accept_friend(self, capcom_id, friend_id, accepted): if accepted: @@ -884,6 +336,16 @@ def delete_friend(self, capcom_id, friend_id): ) return self.parent.delete_friend(capcom_id, friend_id) + def get_friends(self, capcom_id, first_index=None, count=None): + begin = 0 if first_index is None else (first_index - 1) + end = count if count is None else (begin + count) + with self.connection as cursor: + rows = cursor.execute("SELECT friend_id, name FROM friend_lists INNER JOIN consoles ON friend_lists.friend_id = consoles.capcom_id WHERE friend_lists.capcom_id = '{}'".format(capcom_id)) + friends = [] + for friend_id, name in rows: + friends.append((friend_id, name)) + return friends[begin:end] + class DebugDatabase(TempSQLiteDatabase): """For testing purpose.""" diff --git a/mh/pat.py b/mh/pat.py index 6b6770a..cf9cc35 100644 --- a/mh/pat.py +++ b/mh/pat.py @@ -15,9 +15,10 @@ import mh.time_utils as time_utils from mh.constants import \ LAYER_CHAT_COLORS, TERMS_VERSION, TERMS, SUBTERMS, ANNOUNCE, \ - CHARGE, VULGARITY_INFO, FMP_VERSION, PAT_BINARIES, PAT_NAMES, PatID4 + CHARGE, VULGARITY_INFO, FMP_VERSION, PAT_BINARIES, PAT_NAMES, PatID4, \ + FMP_CENTRAL_VERSION from mh.session import Session -import mh.database as db +import mh.state as state try: from typing import Literal, List, Union, Optional # noqa: F401 @@ -34,7 +35,7 @@ class PatServer(server.BasicPatServer, Logger): """Generic PAT server class.""" def __init__(self, address, handler_class, max_thread_count=0, - logger=None, debug_mode=False): + logger=None, debug_mode=False, no_timeout=False): server.BasicPatServer.__init__(self, address, handler_class, max_thread_count) Logger.__init__(self) @@ -43,6 +44,7 @@ def __init__(self, address, handler_class, max_thread_count=0, self.info("Running on {} port {}".format(*address)) self.debug_con = [] self.debug_mode = debug_mode + self.no_timeout = no_timeout def add_to_debug(self, con): """Add connection to the debug connection list.""" @@ -59,6 +61,9 @@ def get_debug(self): def debug_enabled(self): return self.debug_mode + def no_timeout_enabled(self): + return self.no_timeout + def get_pat_handler(self, session): """Return pat handler from session""" for handler in self.debug_con: @@ -268,14 +273,14 @@ def recvAnsConnection(self, packet_id, data, seq): """ settings = pati.ConnectionData.unpack(data) self.server.debug("Connection: {!r}".format(settings)) - pat_ticket = b"" - if "pat_ticket" in settings: - _, pat_ticket = pati.unpack_any(settings.pat_ticket) - elif "online_support_code" in settings: - _, pat_ticket = pati.unpack_any(settings.online_support_code) - self.server.info("Client {} Ticket `{}`".format(self.client_address, - pat_ticket)) - self.sendNtcLogin(5, settings, seq) + pat_ticket = settings.pat_ticket if "pat_ticket" in settings else \ + settings.online_support_code if "online_support_code" in settings \ + else "" + self.server.info("Client {} Ticket `{}`".format(self.client_address, pat_ticket)) + if len(self.session.get_servers()) == 0: + self.sendNtcLogin(2, settings, seq) + else: + self.sendNtcLogin(5, settings, seq) def sendNtcLogin(self, server_status, connection_data, seq): """NtcLogin packet. @@ -290,6 +295,27 @@ def sendNtcLogin(self, server_status, connection_data, seq): self.session = self.session.get(connection_data) self.send_packet(PatID4.NtcLogin, data, seq) + def recvReqMaintenance(self, packet_id, data, seq): + """sendAnsMaintenance packet. + + ID: 62200100 + JP: メンテナンス情報要求 + TR: Maintenance information request + """ + self.sendAnsMaintenance(MAINTENANCE, seq) + + def sendAnsMaintenance(self, maintenance, seq): + """sendAnsMaintenance packet. + + ID: 62200200 + JP: メンテナンス情報通知 + TR: Maintenance information notification + + The server replies with the maintenance information text. + """ + data = pati.lp2_string(maintenance) + self.send_packet(PatID4.AnsMaintenance, data, seq) + def recvReqAuthenticationToken(self, packet_id, data, seq): """ReqAuthenticationToken packet. @@ -721,7 +747,7 @@ def sendAnsTicket(self, seq): TR: PAT ticket response """ pat_ticket = self.session.new_pat_ticket() - data = struct.pack(">H", len(pat_ticket)) + pat_ticket + data = struct.pack(">H", len(pat_ticket)) + pat_ticket.encode('ascii') self.send_packet(PatID4.AnsTicket, data, seq) def recvReqUserListHead(self, packet_id, data, seq): @@ -883,7 +909,7 @@ def sendAnsFmpListVersion(self, seq): JP: FMPリストバージョン確認応答 TR: FMP list version acknowledgment """ - data = struct.pack(">I", FMP_VERSION) + data = struct.pack(">I", FMP_CENTRAL_VERSION) self.send_packet(PatID4.AnsFmpListVersion, data, seq) def sendAnsFmpListVersion2(self, seq): @@ -893,7 +919,7 @@ def sendAnsFmpListVersion2(self, seq): JP: FMPリストバージョン確認応答 TR: FMP list version acknowledgment """ - data = struct.pack(">I", FMP_VERSION) + data = struct.pack(">I", self.session.get_fmp_version()) self.send_packet(PatID4.AnsFmpListVersion2, data, seq) def recvReqFmpListHead(self, packet_id, data, seq): @@ -903,13 +929,9 @@ def recvReqFmpListHead(self, packet_id, data, seq): JP: FMPリスト数送信 / FMPリスト数要求 TR: Send FMP list count / FMP list count request """ - # TODO: Might be worth investigating these parameters as - # they might be useful when using multiple FMP servers. - version, first_index, count = struct.unpack_from( - ">III", data - ) # noqa: F841 - # TODO: Unpack it using pati.Unpacker - header = pati.unpack_bytes(data, 12) # noqa: F841 + version, first_index, count = struct.unpack_from(">III", data) + self.session.preserve_server_ids(first_index, count) + header = pati.unpack_bytes(data, 12) if packet_id == PatID4.ReqFmpListHead: self.sendAnsFmpListHead(seq) elif packet_id == PatID4.ReqFmpListHead2: @@ -967,7 +989,7 @@ def sendAnsFmpListData(self, first_index, count, seq): """ unused = 0 data = struct.pack(">II", unused, count) - data += pati.get_fmp_servers(self.session, first_index, count) + data += pati.get_fmp_central_servers(self.session, first_index, count) self.send_packet(PatID4.AnsFmpListData, data, seq) def sendAnsFmpListData2(self, first_index, count, seq): @@ -1067,20 +1089,31 @@ def recvReqFmpInfo(self, packet_id, data, seq): """ index, = struct.unpack_from(">I", data) fields = pati.unpack_bytes(data, 4) - server = self.session.join_server(index) - config = get_config("FMP") - fmp_addr = get_ip(config["IP"]) - fmp_port = config["Port"] fmp_data = pati.FmpData() - fmp_data.server_address = pati.String(server.addr or fmp_addr) - fmp_data.server_port = pati.Word(server.port or fmp_port) - fmp_data.assert_fields(fields) + if packet_id == PatID4.ReqFmpInfo: + config = get_config("FMP") + central_fmp_addr = get_ip(config["IP"]) + central_fmp_port = config["Port"] + fmp_data.server_address = pati.String(central_fmp_addr) + fmp_data.server_port = pati.Word(central_fmp_port) + fmp_data.assert_fields(fields) self.sendAnsFmpInfo(fmp_data, fields, seq) elif packet_id == PatID4.ReqFmpInfo2: + if not self.session.server_index_exists(index): + self.sendAnsAlert(PatID4.AnsFmpInfo2, + "
\ + Server is offline.", + seq) + return + server_id = self.session.recall_server_id(index) + server = self.session.join_server(server_id) + fmp_data.server_address = pati.String(server.addr) + fmp_data.server_port = pati.Word(server.port) + fmp_data.assert_fields(fields) self.sendAnsFmpInfo2(fmp_data, fields, seq) - # Preserve session in database, due to server selection + # Preserve session in state, due to server selection self.session.request_reconnection = True def sendAnsFmpInfo(self, fmp_data, fields, seq): @@ -1328,10 +1361,10 @@ def sendAnsUserSearchInfo(self, capcom_id, search_info, seq): # Specifically when a client is deserializing data from the packets # `NtcLayerBinary` and `NtcLayerBinary2` # TODO: Proper field value and name - user_info.info_mine_0x0f = pati.Long(int(hash(user.capcom_id)) - & 0xffffffff) - user_info.info_mine_0x10 = pati.Long(int(hash(user.capcom_id[::-1])) - & 0xffffffff) + user_info.info_mine_0x0f = pati.Long(int(hash(user.capcom_id)) & + 0xffffffff) + user_info.info_mine_0x10 = pati.Long(int(hash(user.capcom_id[::-1])) & + 0xffffffff) data = user_info.pack() # TODO: Figure out the optional fields @@ -2658,7 +2691,7 @@ def sendNtcCircleHost(self, circle, new_leader, new_leader_index, seq): def notify_city_info_set(self, path): # type: (pati.LayerPath) -> None city = self.get_layer(path) - assert isinstance(city, db.City) + assert isinstance(city, state.City) gate = city.parent layer_data = pati.LayerData.create_from(path.city_id, city, path) @@ -2670,7 +2703,7 @@ def notify_city_info_set(self, path): def notify_city_number_set(self, path): # type: (pati.LayerPath) -> None city = self.get_layer(path) - assert isinstance(city, db.City) + assert isinstance(city, state.City) gate = city.parent layer_data = pati.LayerData.create_from(path.city_id, city, path) @@ -2680,14 +2713,14 @@ def notify_city_number_set(self, path): @staticmethod def get_layer(path): - # type: (pati.LayerPath) -> Optional[db.Server | db.Gate | db.City] - database = db.get_instance() + # type: (pati.LayerPath) -> Optional[state.Server | state.Gate | state.City] + curr_state = state.get_instance() if path.city_id > 0: - return database.get_city(path.server_id, path.gate_id, path.city_id) + return curr_state.get_city(path.server_id, path.gate_id, path.city_id) elif path.gate_id > 0: - return database.get_gate(path.server_id, path.gate_id) + return curr_state.get_gate(path.server_id, path.gate_id) elif path.server_id > 0: - return database.get_server(path.server_id) + return curr_state.get_server(path.server_id) return None def notify_layer_departure(self, end): @@ -2711,7 +2744,7 @@ def notify_layer_departure(self, end): if path.city_id > 0: city = self.get_layer(path) - assert isinstance(city, db.City) + assert isinstance(city, state.City) self.notify_city_number_set(path) if city.leader is None: @@ -2796,7 +2829,7 @@ def on_tick(self): # Send a ping with 30 seconds interval if self.ping_timer.elapsed() >= 30: - if not self.server.debug_enabled() and not self.line_check: + if not self.server.no_timeout_enabled() and not self.line_check: raise Exception("Client timed out.") self.line_check = False self.sendReqLineCheck() diff --git a/mh/pat_item.py b/mh/pat_item.py index aa8e321..ec2e075 100644 --- a/mh/pat_item.py +++ b/mh/pat_item.py @@ -9,7 +9,7 @@ from collections import OrderedDict from mh.constants import pad from other.utils import to_bytearray, get_config, get_ip, GenericUnpacker -from mh.database import Server, Gate, City +from mh.state import Server, Gate, City class ItemType: Custom = 0 @@ -893,7 +893,6 @@ def pack_from(layer_data): return data.pack() - def patdata_extender(unpacker): """PatData classes must be defined above this function.""" for name, value in globals().items(): @@ -964,14 +963,11 @@ def get_fmp_servers(session, first_index, count): fmp_port = config["Port"] data = b"" - start = first_index - 1 - end = start + count - servers = session.get_servers()[start:end] + + servers = session.recall_servers(first_index, count) for i, server in enumerate(servers, first_index): fmp_data = FmpData() fmp_data.index = Long(i) # The server might be full, if zero - server.addr = server.addr or fmp_addr - server.port = server.port or fmp_port fmp_data.server_address = String(server.addr) fmp_data.server_port = Word(server.port) # Might produce invalid reads if too high @@ -987,6 +983,32 @@ def get_fmp_servers(session, first_index, count): return data +def get_fmp_central_servers(session, first_index, count): + assert first_index > 0, "Invalid list index" + + config = get_config("FMP") + fmp_addr = get_ip(config["IP"]) + fmp_port = config["Port"] + + start = first_index - 1 + end = start + count + data = b"" + for i in range(start, end): + fmp_data = FmpData() + fmp_data.index = Long(count+1) # Index must not match any "real" server. + fmp_data.server_name = String("central") + fmp_data.server_address = String(fmp_addr) + fmp_data.server_port = Word(fmp_port) + fmp_data.server_type = LongLong(1) + fmp_data.player_count = Long(1) + fmp_data.player_capacity = Long(1) + fmp_data.unk_string_0x0b = String("X") + fmp_data.unk_long_0x0c = Long(0x12345678) + data += fmp_data.pack() + + return data + + def get_layer_children(session, first_index, count, sibling=False): assert first_index > 0, "Invalid list index" diff --git a/mh/session.py b/mh/session.py index d89e966..1abd1ec 100644 --- a/mh/session.py +++ b/mh/session.py @@ -4,12 +4,17 @@ # SPDX-License-Identifier: AGPL-3.0-or-later """Monster Hunter session module.""" +import struct +import time + import mh.database as db +from mh.state import get_instance, Players, LayerState import mh.pat_item as pati from other.utils import to_bytearray, to_str DB = db.get_instance() +STATE = get_instance() class SessionState: @@ -52,19 +57,86 @@ def __init__(self, connection_handler): self.state = SessionState.UNKNOWN self.binary_setting = b"" self.search_payload = None + self.loaded_server_ids = {} self.hunter_info = pati.HunterSettings() + def serialize(self): + pdict = { + "pat_ticket": self.pat_ticket, + "local_info_server_id": self.local_info["server_id"], + "local_info_server_name": self.local_info["server_name"], + "local_info_gate_id": self.local_info["gate_id"], + "local_info_gate_name": self.local_info["gate_name"], + "local_info_city_id": self.local_info["city_id"], + "local_info_city_name": self.local_info["city_name"], + "local_info_city_size": self.local_info["city_size"], + "local_info_city_capacity": self.local_info["city_capacity"], + "local_info_circle_id": self.local_info["circle_id"], + "online_support_code": self.online_support_code, + "capcom_id": self.capcom_id, + "hunter_name": self.hunter_name, + "hunter_stats": self.hunter_stats, + "layer": self.layer, + "state": self.state, + "binary_setting": self.binary_setting, + "hunter_info": self.hunter_info.pack().decode( + encoding='ISO-8859-1' + ) + } + return pdict + + @staticmethod + def deserialize(pdict): + session = Session(None) + session.pat_ticket = str(pdict["pat_ticket"])\ + if pdict["pat_ticket"] else pdict["pat_ticket"] + session.local_info["server_id"] = int(pdict["local_info_server_id"])\ + if pdict["local_info_server_id"] else pdict["local_info_server_id"] + session.local_info["server_name"] =\ + str(pdict["local_info_server_name"])\ + if pdict["local_info_server_name"]\ + else pdict["local_info_server_name"] + session.local_info["gate_id"] = int(pdict["local_info_gate_id"])\ + if pdict["local_info_gate_id"] else pdict["local_info_gate_id"] + session.local_info["gate_name"] = str(pdict["local_info_gate_name"])\ + if pdict["local_info_gate_name"] else pdict["local_info_gate_name"] + session.local_info["city_id"] = int(pdict["local_info_city_id"])\ + if pdict["local_info_city_id"] else pdict["local_info_city_id"] + session.local_info["city_name"] = str(pdict["local_info_city_name"])\ + if pdict["local_info_city_name"] else pdict["local_info_city_name"] + session.local_info["city_size"] = int(pdict["local_info_city_size"])\ + if pdict["local_info_city_size"] else pdict["local_info_city_size"] + session.local_info["city_capacity"] =\ + int(pdict["local_info_city_capacity"])\ + if pdict["local_info_city_capacity"]\ + else pdict["local_info_city_capacity"] + session.local_info["circle_id"] = int(pdict["local_info_circle_id"])\ + if pdict["local_info_circle_id"] else pdict["local_info_circle_id"] + session.online_support_code = str(pdict["online_support_code"])\ + if pdict["online_support_code"] else pdict["online_support_code"] + session.capcom_id = str(pdict["capcom_id"]) + session.hunter_name = str(pdict["hunter_name"]) + session.hunter_stats = pdict["hunter_stats"] + session.layer = int(pdict["layer"]) + session.state = int(pdict["state"]) + session.binary_setting = pdict["binary_setting"] + h_settings = bytearray(pdict["hunter_info"], encoding='ISO-8859-1') + session.hunter_info = pati.HunterSettings().unpack(h_settings, + len(h_settings)) + return session + def get(self, connection_data): """Return the session associated with the connection data, if any.""" - if hasattr(connection_data, "pat_ticket"): + has_pat_ticket = hasattr(connection_data, "pat_ticket") + if has_pat_ticket: self.pat_ticket = to_str( pati.unpack_binary(connection_data.pat_ticket) ) + session = STATE.get_session(self.pat_ticket) or self if hasattr(connection_data, "online_support_code"): self.online_support_code = to_str( pati.unpack_string(connection_data.online_support_code) ) - session = DB.get_session(self.pat_ticket) or self if session != self: assert session.connection is None, "Session is already in use" session.connection = self.connection @@ -78,6 +150,20 @@ def get(self, connection_data): "online_support_code" in connection_data) return session + def session_ready(self, connection_data): + if hasattr(connection_data, "pat_ticket"): + return STATE.session_ready(to_str( + pati.unpack_binary(connection_data.pat_ticket) + )) + else: + return STATE.session_ready(self.pat_ticket) + + def set_session_ready(self, connection_data, store_data): + STATE.set_session_ready( + to_str(pati.unpack_binary(connection_data.pat_ticket)), + store_data + ) + def get_support_code(self): """Return the online support code.""" return DB.get_support_code(self) @@ -88,7 +174,7 @@ def disconnect(self): It doesn't purge the session state nor its PAT ticket. """ self.connection = None - DB.disconnect_session(self) + STATE.disconnect_session(self) def delete(self): """Delete the current session. @@ -98,39 +184,76 @@ def delete(self): - We should probably create a SessionManager thread per server. """ if not self.request_reconnection: - DB.delete_session(self) + STATE.delete_session(self) def is_jap(self): """TODO: Heuristic using the connection data to detect region.""" pass def new_pat_ticket(self): - DB.new_pat_ticket(self) - return to_bytearray(self.pat_ticket) + # type: () -> str + STATE.new_pat_ticket(self) + return self.pat_ticket + + def get_fmp_version(self): + return STATE.get_servers_version() def get_users(self, first_index, count): - return DB.get_users(self, first_index, count) + return STATE.get_users(self, first_index, count) def use_user(self, index, name): - DB.use_user(self, index, name) + STATE.use_user(self, index, name) + + def server_index_exists(self, index): + try: + if index in self.loaded_server_ids: + self.recall_server(index) + return True + else: + return False + except AssertionError: + return False + + def preserve_server_ids(self, first_index, count): + server_ids, servers = STATE.get_servers(include_ids=True) + if first_index-1 + count > len(server_ids): + count = len(server_ids) - (first_index - 1) + assert first_index <= len(server_ids) + server_ids = server_ids[first_index-1:first_index-1+count] + self.loaded_server_ids = {} + for i, server_id in enumerate(server_ids): + self.loaded_server_ids[first_index+i] = server_id + + def recall_servers(self, first_index, count): + server_ids, servers = STATE.get_servers(include_ids=True) + servers = [] + for i in range(first_index, first_index+count): + servers.append(self.recall_server(i)) + return servers + + def recall_server(self, index): + return STATE.get_server(self.loaded_server_ids[index]) + + def recall_server_id(self, index): + return self.loaded_server_ids[index] def get_servers(self): - return DB.get_servers() + return STATE.get_servers() def get_server(self): assert self.local_info['server_id'] is not None - return DB.get_server(self.local_info['server_id']) + return STATE.get_server(self.local_info['server_id']) def get_gate(self): assert self.local_info['gate_id'] is not None - return DB.get_gate(self.local_info['server_id'], - self.local_info['gate_id']) + return STATE.get_gate(self.local_info['server_id'], + self.local_info['gate_id']) def get_city(self): assert self.local_info['city_id'] is not None - return DB.get_city(self.local_info['server_id'], - self.local_info['gate_id'], - self.local_info['city_id']) + return STATE.get_city(self.local_info['server_id'], + self.local_info['gate_id'], + self.local_info['city_id']) def get_circle(self): assert self.local_info['circle_id'] is not None @@ -191,10 +314,10 @@ def layer_detail_search(self, detailed_fields): (field_id, value) for field_id, field_type, value in detailed_fields ] # Convert detailed to simple optional fields - return DB.layer_detail_search(server_type, fields) + return STATE.layer_detail_search(server_type, fields) def join_server(self, server_id): - return DB.join_server(self, server_id) + return STATE.join_server(self, server_id) def get_layer_children(self): if self.layer == 0: @@ -213,73 +336,73 @@ def get_layer_sibling(self): def find_users_by_layer(self, server_id, gate_id, city_id, first_index, count, recursive=False): if recursive: - players = DB.get_all_users(server_id, gate_id, city_id) + players = STATE.get_all_users(server_id, gate_id, city_id) else: layer = \ - DB.get_city(server_id, gate_id, city_id) if city_id else \ - DB.get_gate(server_id, gate_id) if gate_id else \ - DB.get_server(server_id) + STATE.get_city(server_id, gate_id, city_id) if city_id else \ + STATE.get_gate(server_id, gate_id) if gate_id else \ + STATE.get_server(server_id) players = list(layer.players) start = first_index - 1 return players[start:start+count] def find_user_by_capcom_id(self, capcom_id): - sessions = DB.find_users(capcom_id=capcom_id) + sessions = STATE.find_users(capcom_id=capcom_id) if sessions: return sessions[0] return None def find_users(self, capcom_id, hunter_name, first_index, count): - users = DB.find_users(capcom_id, hunter_name) + users = STATE.find_users(capcom_id, hunter_name) start = first_index - 1 return users[start:start+count] def get_user_name(self, capcom_id): - return DB.get_user_name(capcom_id) + return DB.get_name(capcom_id) def leave_server(self): - DB.leave_server(self, self.local_info["server_id"]) + STATE.leave_server(self) def get_gates(self): - return DB.get_gates(self.local_info["server_id"]) + return STATE.get_gates(self.local_info["server_id"]) def join_gate(self, gate_id): - DB.join_gate(self, self.local_info["server_id"], gate_id) + STATE.join_gate(self, self.local_info["server_id"], gate_id) self.state = SessionState.GATE def leave_gate(self): - DB.leave_gate(self) + STATE.leave_gate(self) self.state = SessionState.LOG_IN def get_cities(self): - return DB.get_cities(self.local_info["server_id"], - self.local_info["gate_id"]) + return STATE.get_cities(self.local_info["server_id"], + self.local_info["gate_id"]) def is_city_empty(self, city_id): - return DB.get_city(self.local_info["server_id"], - self.local_info["gate_id"], - city_id).get_state() == db.LayerState.EMPTY + return STATE.get_city(self.local_info["server_id"], + self.local_info["gate_id"], + city_id).get_state() == LayerState.EMPTY def reserve_city(self, city_id, reserve): - return DB.reserve_city(self.local_info["server_id"], - self.local_info["gate_id"], - city_id, reserve) + return STATE.reserve_city(self.local_info["server_id"], + self.local_info["gate_id"], + city_id, reserve) def create_city(self, city_id, settings, optional_fields): - return DB.create_city(self, - self.local_info["server_id"], - self.local_info["gate_id"], - city_id, settings, optional_fields) + return STATE.create_city(self, + self.local_info["server_id"], + self.local_info["gate_id"], + city_id, settings, optional_fields) def join_city(self, city_id): - DB.join_city(self, - self.local_info["server_id"], - self.local_info["gate_id"], - city_id) + STATE.join_city(self, + self.local_info["server_id"], + self.local_info["gate_id"], + city_id) self.state = SessionState.CITY def leave_city(self): - DB.leave_city(self) + STATE.leave_city(self) self.state = SessionState.GATE def try_transfer_city_leadership(self): diff --git a/mh/state.py b/mh/state.py new file mode 100644 index 0000000..0b78372 --- /dev/null +++ b/mh/state.py @@ -0,0 +1,890 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2023 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""Monster Hunter state module.""" + + +from mh import database +from math import floor +import time +from threading import RLock, Event + + +try: + # Python 3 + import selectors +except ImportError: + # Python 2 + import externals.selectors2 as selectors + + +RESERVE_DC_TIMEOUT = 40.0 + + +class ServerType(object): + OPEN = 1 + ROOKIE = 2 + EXPERT = 3 + RECRUITING = 4 + + +class LayerState(object): + EMPTY = 1 + FULL = 2 + JOINABLE = 3 + + +class Lockable(object): + def __init__(self): + self._lock = RLock() + + def lock(self): + return self + + def __enter__(self): + # Returns True if lock was acquired, False otherwise + return self._lock.acquire() + + def __exit__(self, *args): + self._lock.release() + + +class Players(Lockable): + def __init__(self, capacity): + assert capacity > 0, "Collection capacity can't be zero" + + self.slots = [None for _ in range(capacity)] + self.used = 0 + super(Players, self).__init__() + + def get_used_count(self): + return self.used + + def get_capacity(self): + return len(self.slots) + + def add(self, item): + with self.lock(): + if self.used >= len(self.slots): + return -1 + + item_index = self.index(item) + if item_index != -1: + return item_index + + for i, v in enumerate(self.slots): + if v is not None: + continue + + self.slots[i] = item + self.used += 1 + return i + + return -1 + + def remove(self, item): + assert item is not None, "Item != None" + + with self.lock(): + if self.used < 1: + return False + + if isinstance(item, int): + if item >= self.get_capacity(): + return False + + self.slots[item] = None + self.used -= 1 + return True + + for i, v in enumerate(self.slots): + if v != item: + continue + + self.slots[i] = None + self.used -= 1 + return True + + return False + + def index(self, item): + assert item is not None, "Item != None" + + for i, v in enumerate(self.slots): + if v == item: + return i + + return -1 + + def clear(self): + with self.lock(): + for i in range(self.get_capacity()): + self.slots[i] = None + + def find_first(self, **kwargs): + if self.used < 1: + return None + + for p in self.slots: + if p is None: + continue + + for k, v in kwargs.items(): + if getattr(p, k) != v: + break + else: + return p + + return None + + def find_by_capcom_id(self, capcom_id): + return self.find_first(capcom_id=capcom_id) + + def __len__(self): + return self.used + + def __iter__(self): + if self.used < 1: + raise StopIteration + + for i, v in enumerate(self.slots): + if v is None: + continue + + yield i, v + + def serialize(self): + if self.used == 0: + return {"capacity": len(self.slots)} + pdict = { + "slots": [(p.serialize() if p is not None else None) + for p in self.slots], + "used": self.used + } + return pdict + + @staticmethod + def deserialize(pdict, parent): + if "used" not in pdict.keys(): + return Players(pdict["capacity"]) + from mh.session import Session + players = Players(len(pdict["slots"])) + players.slots = [(Session.deserialize(p) if p is not None else None) + for p in pdict["slots"]] + players.used = pdict["used"] + return players + + +class Circle(Lockable): + def __init__(self, parent): + # type: (City) -> None + self.parent = parent + self.leader = None + self.players = Players(4) + self.departed = False + self.quest_id = 0 + self.embarked = False + self.password = None + self.remarks = None + + self.unk_byte_0x0e = 0 + super(Circle, self).__init__() + + def get_population(self): + return len(self.players) + + def get_capacity(self): + return self.players.get_capacity() + + def is_full(self): + return self.get_population() == self.get_capacity() + + def is_empty(self): + return self.leader is None + + def is_joinable(self): + return not self.departed and not self.is_full() + + def has_password(self): + return self.password is not None + + def reset_players(self, capacity): + with self.lock(): + self.players = Players(capacity) + + def reset(self): + with self.lock(): + self.leader = None + self.reset_players(4) + self.departed = False + self.quest_id = 0 + self.embarked = False + self.password = None + self.remarks = None + + self.unk_byte_0x0e = 0 + + def serialize(self): + serialized_players = self.players.serialize() + if "used" not in serialized_players.keys(): + return {} + cdict = { + "parent": None, + "leader": self.leader.serialize() if self.leader is not None + else None, + "players": serialized_players, + "departed": self.departed, + "quest_id": self.quest_id, + "embarked": self.embarked, + "password": self.password, + "remarks": self.remarks, + "unk_byte_0x0e": self.unk_byte_0x0e + } + return cdict + + @staticmethod + def deserialize(cdict, parent): + if not len(cdict.keys()): + return Circle(parent) + from mh.session import Session + circle = Circle(parent) + circle.leader = Session.deserialize(cdict["leader"])\ + if cdict["leader"] is not None else None + circle.players = Players.deserialize(cdict["players"], circle) + circle.departed = cdict["departed"] + circle.quest_id = cdict["quest_id"] + circle.embarked = cdict["embarked"] + circle.password = cdict["password"] + circle.remarks = cdict["remarks"] + circle.unk_byte_0x0e = cdict["unk_byte_0x0e"] + return circle + + +class City(Lockable): + LAYER_DEPTH = 3 + + def __init__(self, name, parent): + # type: (str, Gate) -> None + self.name = name + self.parent = parent + self.state = LayerState.EMPTY + self.players = Players(4) + self.optional_fields = [] + self.leader = None + self.reserved = None + self.circles = [ + # One circle per player + Circle(self) for _ in range(self.get_capacity()) + ] + super(City, self).__init__() + + def get_population(self): + return len(self.players) + + def in_quest_players(self): + return sum(p.is_in_quest() for _, p in self.players) + + def get_capacity(self): + return self.players.get_capacity() + + def get_state(self): + if self.reserved: + return LayerState.FULL + size = self.get_population() + if size == 0: + return LayerState.EMPTY + elif size < self.get_capacity(): + return LayerState.JOINABLE + else: + return LayerState.FULL + + def get_pathname(self): + pathname = self.name + it = self.parent + while it is not None: + pathname = it.name + "\t" + pathname + it = it.parent + return pathname + + def get_first_empty_circle(self): + with self.lock(): + for index, circle in enumerate(self.circles): + if circle.is_empty(): + return circle, index + return None, None + + def get_circle_for(self, leader_session): + with self.lock(): + for index, circle in enumerate(self.circles): + if circle.leader == leader_session: + return circle, index + return None, None + + def clear_circles(self): + with self.lock(): + for circle in self.circles: + circle.reset() + + def reserve(self, reserve): + with self.lock(): + if reserve: + self.reserved = time.time() + else: + self.reserved = None + + def reset(self): + with self.lock(): + self.state = LayerState.EMPTY + self.players.clear() + self.optional_fields = [] + self.leader = None + self.reserved = None + self.clear_circles() + + def get_all_players(self): + with self.players.lock(): + return [p for _, p in self.players] + + def serialize(self): + serialized_players = self.players.serialize() + if "used" not in serialized_players.keys(): + return {"name": self.name} + cdict = { + "name": self.name, + "parent": None, + "state": self.state, + "players": serialized_players, + "optional_fields": self.optional_fields, + "leader": self.leader.serialize() if self.leader is not None + else None, + "reserved": self.reserved, + "circles": [c.serialize() for c in self.circles] + } + return cdict + + @staticmethod + def deserialize(cdict, parent): + if len(cdict.keys()) < 2: + return City(cdict["name"], None) + from mh.session import Session + city = City(str(cdict["name"]) if cdict["name"] is not None + else cdict["name"], cdict["parent"]) + city.parent = parent + city.state = cdict["state"] + city.players = Players.deserialize(cdict["players"], parent) + city.optional_fields = cdict["optional_fields"] + city.leader = Session.deserialize(cdict["leader"])\ + if cdict["leader"] is not None else None + city.reserved = cdict["reserved"] + city.circles = [Circle.deserialize(c, city) for c in cdict["circles"]] + return city + + +class Gate(object): + LAYER_DEPTH = 2 + + def __init__(self, name, parent, city_count=40, player_capacity=100): + # type: (str, Server, int, int) -> None + self.name = name + self.parent = parent + self.state = LayerState.EMPTY + self.cities = [ + City("City{}".format(i), self) + for i in range(1, city_count+1) + ] + self.players = Players(player_capacity) + self.optional_fields = [] + + def get_population(self): + return len(self.players) + sum(( + city.get_population() + for city in self.cities + )) + + def get_capacity(self): + return self.players.get_capacity() + + def get_state(self): + size = self.get_population() + if size == 0: + return LayerState.EMPTY + elif size < self.get_capacity(): + return LayerState.JOINABLE + else: + return LayerState.FULL + + def get_all_players(self): + players = [p for _, p in self.players] + for city in self.cities: + players = players + city.get_all_players() + return players + + def serialize(self): + gdict = { + "name": self.name, + "parent": None, + "state": self.state, + "cities": [c.serialize() for c in self.cities], + "players": self.players.serialize(), + "optional_fields": self.optional_fields + } + return gdict + + @staticmethod + def deserialize(gdict, parent): + gate = Gate(str(gdict["name"]) if gdict["name"] is not None + else gdict["name"], parent) + gate.state = gdict["state"] + gate.cities = [City.deserialize(c, gate) for c in gdict["cities"]] + gate.players = Players.deserialize(gdict["players"], gate) + gate.optional_fields = gdict["optional_fields"] + return gate + + +class Server(object): + LAYER_DEPTH = 1 + + def __init__(self, name, server_type, capacity=2000, + addr=None, port=None): + self.name = name + self.parent = None + self.server_type = server_type + self.addr = addr + self.port = port + gate_count = int(floor(capacity / 100)) + remainder = capacity % 100 + self.gates = [ + Gate("City Gate {}".format(i), self) + for i in range(1, gate_count+1) + ] + if remainder: + self.gates.append(Gate( + "City Gate {}".format(len(self.gates)+1), + self, player_capacity=remainder + )) + self.players = Players(capacity) + + def get_population(self): + return len(self.players) + sum(( + gate.get_population() for gate in self.gates + )) + + def get_capacity(self): + return self.players.get_capacity() + + def get_all_players(self): + players = [p for _, p in self.players] + for gate in self.gates: + players = players + gate.get_all_players() + return players + + def serialize(self): + sdict = { + "name": self.name, + "parent": self.parent, + "server_type": self.server_type, + "addr": self.addr, + "port": self.port, + "gates": [g.serialize() for g in self.gates], + "players": self.players.serialize() + } + return sdict + + @staticmethod + def deserialize(sdict): + server = Server(str(sdict["name"]) if sdict["name"] is not None + else sdict["name"], + int(sdict["server_type"]) if sdict["server_type"] + else sdict["server_type"], + addr=str(sdict["addr"]) if sdict["addr"] is not None + else sdict["addr"], + port=int(sdict["port"]) if sdict["port"] is not None + else sdict["port"]) + server.parent = sdict["parent"] + server.gates = [Gate.deserialize(g, server) for g in sdict["gates"]] + server.players = Players.deserialize(sdict["players"], server) + return server + + +def new_servers(): + servers = [] + servers.extend([ + Server("Valor{}".format(i), ServerType.OPEN) + for i in range(1, 5) + ]) + servers.extend([ + Server("Beginners{}".format(i), ServerType.ROOKIE) + for i in range(1, 3) + ]) + servers.extend([ + Server("Veterans{}".format(i), ServerType.EXPERT) + for i in range(1, 3) + ]) + servers.extend([ + Server("Greed{}".format(i), ServerType.RECRUITING) + for i in range(1, 5) + ]) + return servers + + +class State(object): + def __init__(self): + self.sessions = { + # PAT Ticket => Owner's session + } + self.capcom_ids = { + # Capcom ID => {Owner's name, Owner's session} + } + self.cache = None + self.server_id = None + self.server = None + self.initialized = Event() + + def setup_server(self, server_id, server_name, server_type, + capacity, server_addr, server_port): + self.server_id = server_id + if server_id != 0: + self.server = Server(server_name, server_type, capacity, + addr=server_addr, port=server_port) + else: + self.server = None + self.initialized.set() + + def new_pat_ticket(self, session): + """Generates a new PAT ticket for the session.""" + while True: + session.pat_ticket = database.new_random_str(11) + if session.pat_ticket not in self.sessions: + break + self.sessions[session.pat_ticket] = session + return session.pat_ticket + + def register_pat_ticket(self, session): + """Register a Session's PAT ticket from another server.""" + self.sessions[session.pat_ticket] = session + self.capcom_ids[session.capcom_id] = {"name": "", "session": None} + self.join_server(session, self.server_id) + + def use_capcom_id(self, session, capcom_id, name=None): + """Attach the session to the Capcom ID.""" + assert capcom_id in self.capcom_ids, "Capcom ID doesn't exist" + + not_in_use = self.capcom_ids[capcom_id]["session"] is None + assert not_in_use, "Capcom ID is already in use. Try again in 60 seconds." + + name = name or self.capcom_ids[capcom_id]["name"] + self.capcom_ids[capcom_id] = {"name": name, "session": session} + + db = database.get_instance() + db.assign_name(capcom_id, name) + + # TODO: Check if stable index is required + if capcom_id not in db.friend_lists: + db.friend_lists[capcom_id] = [] + if capcom_id not in db.friend_requests: + db.friend_requests[capcom_id] = [] + + return name + + def use_user(self, session, index, name): + """Use User from the slot or create one if empty""" + assert 1 <= index <= 6, "Invalid Capcom ID slot" + index -= 1 + users = database.get_instance().get_capcom_ids( + session.online_support_code + ) + while users[index] == "******": + capcom_id = database.new_random_str(6) + if capcom_id not in self.capcom_ids and \ + not database.get_instance().get_name(capcom_id): + self.capcom_ids[capcom_id] = {"name": name, "session": None} + database.get_instance().assign_capcom_id( + session.online_support_code, index, capcom_id + ) + break + else: + capcom_id = users[index] + name = self.use_capcom_id(session, capcom_id, name) + session.capcom_id = capcom_id + session.hunter_name = name + database.get_instance().use_user(session, index, name) + + def get_session(self, pat_ticket): + """Returns existing PAT session or None.""" + session = self.sessions.get(pat_ticket) + if session and session.capcom_id: + try: + self.use_capcom_id( + session, session.capcom_id, session.hunter_name + ) + except AssertionError as e: + return None + return session + + def disconnect_session(self, session): + """Detach the session from its Capcom ID.""" + if not session.capcom_id: + # Capcom ID isn't chosen yet with OPN/LMP servers + return + self.capcom_ids[session.capcom_id]["session"] = None + + def delete_session(self, session): + """Delete the session from the database.""" + self.disconnect_session(session) + pat_ticket = session.pat_ticket + if pat_ticket in self.sessions: + del self.sessions[pat_ticket] + self.cache.notify_session_deletion(session.capcom_id) + + def fetch_id(self, capcom_id): + if capcom_id not in self.capcom_ids: + self.capcom_ids[capcom_id] = { + "name": database.get_instance().get_name(capcom_id), + "session": None + } + return self.capcom_ids[capcom_id] + + def get_users(self, session, first_index, count): + """Returns Capcom IDs tied to the session.""" + users = database.get_instance().get_capcom_ids( + session.online_support_code + ) + capcom_ids = [ + (i, (capcom_id, self.capcom_ids.get( + capcom_id, self.fetch_id(capcom_id) + ))) + for i, capcom_id in enumerate(users[:count], first_index) + ] + size = len(capcom_ids) + if size < count: + capcom_ids.extend([ + (index, ("******", {})) + for index in range(first_index+size, first_index+count) + ]) + return capcom_ids + + def join_server(self, session, server_id): + server = self.get_server(server_id) + if server_id != self.server_id: + # Joining another server + if self.cache: + self.cache.send_session_info(server_id, session) + if session.local_info["server_id"] is not None: + self.leave_server(session) + else: + # Connecting to this server + server.players.add(session) + session.local_info["server_id"] = server_id + session.local_info["server_name"] = server.name + return server + + def leave_server(self, session): + self.server.players.remove(session) + session.local_info["server_id"] = None + session.local_info["server_name"] = None + + def get_server_time(self): + pass + + def get_game_time(self): + pass + + def get_servers_version(self): + return self.cache.servers_version + + def get_servers(self, include_ids=False): + if not self.cache: + return [] + server_ids, servers = self.cache.get_server_list(include_ids=True) + below, above = [], [] + below_ids, above_ids = [], [] + for server_id, server in zip(server_ids, servers): + if server_id < self.server_id: + below.append(server) + below_ids.append(server_id) + elif server_id > self.server_id: + above.append(server) + above_ids.append(server_id) + if self.server_id != 0: + below.append(self.server) + below_ids.append(self.server_id) + below.extend(above) + below_ids.extend(above_ids) + if include_ids: + return below_ids, below + return below + + def get_server(self, server_id): + if server_id == self.server_id: + return self.server + return self.cache.get_server(server_id) + + def get_gates(self, server_id): + return self.get_server(server_id).gates + + def get_gate(self, server_id, index): + gates = self.get_gates(server_id) + assert 0 < index <= len(gates), "Invalid gate index" + return gates[index - 1] + + def join_gate(self, session, server_id, index): + gate = self.get_gate(server_id, index) + gate.parent.players.remove(session) + gate.players.add(session) + session.local_info["gate_id"] = index + session.local_info["gate_name"] = gate.name + return gate + + def leave_gate(self, session): + gate = self.get_gate(session.local_info["server_id"], + session.local_info["gate_id"]) + gate.parent.players.add(session) + gate.players.remove(session) + session.local_info["gate_id"] = None + session.local_info["gate_name"] = None + + def get_cities(self, server_id, gate_id): + return self.get_gate(server_id, gate_id).cities + + def get_city(self, server_id, gate_id, index): + cities = self.get_cities(server_id, gate_id) + assert 0 < index <= len(cities), "Invalid city index" + return cities[index - 1] + + def reserve_city(self, server_id, gate_id, index, reserve): + city = self.get_city(server_id, gate_id, index) + with city.lock(): + reserved_time = city.reserved + if reserve and reserved_time and \ + time.time()-reserved_time < RESERVE_DC_TIMEOUT: + return False + city.reserve(reserve) + return True + + def get_all_users(self, server_id, gate_id, city_id): + """Search for users in layers and its children. + + Let's assume wildcard search isn't possible for servers and gates. + A wildcard search happens when the id is zero. + """ + assert 0 < server_id, "Invalid server index" + assert 0 < gate_id, "Invalid gate index" + gate = self.get_gate(server_id, gate_id) + users = list(gate.players) + cities = [ + self.get_city(server_id, gate_id, city_id) + ] if city_id else self.get_cities(server_id, gate_id) + for city in cities: + users.extend(list(city.players)) + return users + + def find_users(self, capcom_id="", hunter_name=""): + assert capcom_id or hunter_name, "Search can't be empty" + users = [] + for user_id, user_info in self.capcom_ids.items(): + session = user_info["session"] + if not session: + continue + if capcom_id and capcom_id not in user_id: + continue + if hunter_name and \ + hunter_name.lower() not in user_info["name"].lower(): + continue + users.append(session) + for user_info in self.cache.get_remote_players_list(): + if not user_info: + continue + if capcom_id and capcom_id not in user_info.capcom_id: + continue + if hunter_name and \ + hunter_name.lower() not in user_info.hunter_name.lower(): + continue + users.append(user_info) + return users + + def create_city(self, session, server_id, gate_id, index, + settings, optional_fields): + city = self.get_city(server_id, gate_id, index) + with city.lock(): + city.optional_fields = optional_fields + city.leader = session + city.reserved = None + return city + + def join_city(self, session, server_id, gate_id, index): + city = self.get_city(server_id, gate_id, index) + with city.lock(): + city.parent.players.remove(session) + city.players.add(session) + session.local_info["city_name"] = city.name + session.local_info["city_id"] = index + return city + + def leave_city(self, session): + city = self.get_city(session.local_info["server_id"], + session.local_info["gate_id"], + session.local_info["city_id"]) + with city.lock(): + city.parent.players.add(session) + city.players.remove(session) + if not city.get_population(): + city.reset() + session.local_info["city_id"] = None + session.local_info["city_name"] = None + + def layer_detail_search(self, server_type, fields): + cities = [] + + def match_city(city, fields): + with city.lock(): + return all(( + field in city.optional_fields + for field in fields + )) + + for server in self.get_servers(): + if server.server_type != server_type: + continue + for gate in server.gates: + if not gate.get_population(): + continue + cities.extend([ + city + for city in gate.cities + if match_city(city, fields) + ]) + return cities + + def update_players(self): + # Central server method for clearing unused Capcom IDs + for capcom_id, player_info in self.capcom_ids.items(): + if player_info["session"] is not None and \ + player_info["session"].local_info["server_id"] and \ + capcom_id not in self.cache.players: + self.capcom_ids[capcom_id]["session"] = None + elif capcom_id in self.cache.players: + self.capcom_ids[capcom_id]["session"] =\ + self.cache.players[capcom_id] + + def update_capcom_id(self, session): + # Central server method for keeping track of in-use Capcom IDs + self.capcom_ids[session.capcom_id]["session"] = session + + def session_ready(self, pat_ticket): + if self.server_id == 0: + return True + return self.cache.session_ready(pat_ticket) + + def set_session_ready(self, pat_ticket, store_data): + self.cache.set_session_ready(pat_ticket, store_data) + + def close_cache(self): + self.cache.close() + + +CURRENT_STATE = State() + + +def get_instance(): + return CURRENT_STATE diff --git a/other/cache.py b/other/cache.py new file mode 100644 index 0000000..04c2882 --- /dev/null +++ b/other/cache.py @@ -0,0 +1,818 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +"""Monster Hunter cache module. + + Monster Hunter 3 Server Project + Copyright (C) 2023 Ze SpyRo + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +""" + +from mh.time_utils import Timer +from mh.session import Session +from mh.state import get_instance, Server +from other.utils import Logger, create_logger, get_remote_config, \ + get_central_config, get_config, get_ip + +from threading import Lock, Event +import socket +import struct +import logging +import json + +try: + # Python 3 + import selectors + from typing import Tuple, Union, Optional, Dict, List, Callable, TYPE_CHECKING + if TYPE_CHECKING: + from fmp_server import FmpRequestHandler + from mh.pat_item import ConnectionData +except ImportError: + # Python 2 + import externals.selectors2 as selectors + + +class PacketTypes(object): + Ping = 0x0000 + FriendlyHello = 0x0001 + ReqConnectionInfo = 0x0002 + SendConnectionInfo = 0x0003 + ReqServerRefresh = 0x0004 + SessionInfo = 0x0005 + ServerIDList = 0x0006 + ServerShutdown = 0x0007 + SessionDisconnect = 0x0008 + + +class CentralConnectionHandler(object): + def __init__(self, sck, client_address, cache): + # type: (socket.socket, Tuple[str, int], Cache) -> None + self.id = -1 # type: int + self.socket = sck + self.client_address = client_address + self.cache = cache + self.rfile = self.socket.makefile('rb', -1) + self.wfile = self.socket.makefile('wb', 0) + + self.rw = Lock() + self.finished = False # type: bool + + self.handler_functions = { + PacketTypes.Ping: self.RecvPing, + PacketTypes.FriendlyHello: self.RecvFriendlyHello, + PacketTypes.SendConnectionInfo: self.RecvConnectionInfo, + PacketTypes.ReqConnectionInfo: self.RecvReqConnectionInfo, + PacketTypes.ReqServerRefresh: self.RecvReqServerRefresh, + PacketTypes.SessionInfo: self.RecvSessionInfo, + PacketTypes.ServerShutdown: self.RecvServerShutdown, + PacketTypes.SessionDisconnect: self.RecvSessionDisconnect + } # type: Dict[int, Callable[[int, bytes], None]] + + def fileno(self): + # type: () -> int + return self.socket.fileno() + + def on_recv(self): + # type: () -> Optional[Tuple[int, int, bytes]] + header = self.rfile.read(10) + if not len(header) or len(header) < 10: + return None + + return self.recv_packet(header) + + def recv_packet(self, header): + # type: (bytes) -> Tuple[int, int, bytes] + size, packet_id, server_id = struct.unpack(">IIH", header) + data = self.rfile.read(size) + return server_id, packet_id, data + + def send_packet(self, packet_id=0, data=b""): + # type: (int, bytes) -> None + self.wfile.write(self.pack_data( + data, packet_id + )) + + def pack_data(self, data, packet_id): + # type: (bytes, int) -> bytes + return struct.pack(">II", len(data), packet_id) + data + + def is_finished(self): + # type: () -> bool + return self.finished + + def on_exception(self, e): + # type: (Exception) -> None + self.finish() + + def direct_to_handler(self, packet): + # type: (Tuple[int, int, bytes]) -> None + server_id, packet_type, data = packet + self.handler_functions[packet_type](server_id, data) + + def RecvPing(self, server_id, data): + # type: (int, bytes) -> None + pass + + def RecvFriendlyHello(self, server_id, data): + # type: (int, bytes) -> None + self.cache.debug("Recieved a friendly hello from {}!".format( + server_id + )) + self.cache.register_handler(server_id, self) + self.ReqConnectionInfo() + + def ReqConnectionInfo(self): + # type: () -> None + self.cache.debug("Requesting connection info.") + self.send_packet(PacketTypes.ReqConnectionInfo, b"") + + def RecvConnectionInfo(self, server_id, data): + # type: (int, bytes) -> None + self.cache.debug("Recieved connection info sized {} from {}".format( + len(data), server_id + )) + server = Server.deserialize(json.loads(data.decode('utf-8'))) + self.cache.servers[server_id] = server + self.cache.update_players() + + def RecvReqConnectionInfo(self, server_id, data): + # type: (int, bytes) -> None + requested_server_id, = struct.unpack(">H", data) + self.cache.debug("Recieved request for data of Server {}.".format( + requested_server_id + )) + if server_id in self.cache.servers: + data = json.dumps(self.cache.servers[server_id].serialize()).encode('utf-8') + self.SendConnectionInfo(data) + + def SendConnectionInfo(self, data): + # type: (bytes) -> None + self.cache.debug("Sending updated connection info.") + self.send_packet(PacketTypes.SendConnectionInfo, data) + + def RecvReqServerRefresh(self, server_id, data): + # type: (int, bytes) -> None + self.cache.debug("Recieved server refresh request from \ + Server {}.".format( + server_id + )) + self.SendServerIDList() + for _server_id in self.cache.servers: + data = struct.pack(">H", _server_id) + data += json.dumps(self.cache.servers[_server_id].serialize()).encode('utf-8') + self.SendConnectionInfo(data) + + def SendServerIDList(self): + # type: () -> None + self.cache.debug("Sending updated Server ID list.") + data = struct.pack(">H", self.cache.servers_version) + data += struct.pack(">H", len(self.cache.servers)) + for _server_id in self.cache.servers: + data += struct.pack(">H", _server_id) + self.send_packet(PacketTypes.ServerIDList, data) + + def RecvSessionInfo(self, server_id, data): + # type: (int, bytes) -> None + dest_server_id, = struct.unpack(">H", data[:2]) + self.cache.debug("Recieved session data from Server {} \ + bound for Server {}.".format( + server_id, dest_server_id + )) + self.cache.update_player_record( + Session.deserialize(json.loads(data[2:].decode('utf-8'))) + ) + self.cache.get_handler(dest_server_id).SendSessionInfo(data[2:]) + + def SendSessionInfo(self, ser_session): + # type: (bytes) -> None + self.cache.debug("Dispatching session info to remote Server.") + self.send_packet(PacketTypes.SessionInfo, ser_session) + + def RecvServerShutdown(self, server_id, data): + # type: (int, bytes) -> None + raise Exception("Server shutting down.") + + def RecvSessionDisconnect(self, server_id, data): + # type: (int, bytes) -> None + length, = struct.unpack(">H", data[:2]) + capcom_id = str(data[2:2+length].decode('utf-8')) + if capcom_id in self.cache.players: + del self.cache.players[capcom_id] + get_instance().update_players() + + def finish(self): + # type: () -> None + if self.finished: + return + + self.finished = True + + try: + self.wfile.close() + except Exception: + pass + + try: + self.rfile.close() + except Exception: + pass + + try: + self.socket.close() + except Exception: + pass + + +class RemoteConnectionHandler(object): + def __init__(self, sck, client_address, cache): + # type: (socket.socket, Tuple[str, int], Cache) -> None + self.id = 0 + self.socket = sck + self.client_address = client_address + self.cache = cache + self.rfile = self.socket.makefile('rb', -1) + self.wfile = self.socket.makefile('wb', 0) + + self.rw = Lock() + self.finished = False # type: bool + + self.handler_functions = { + PacketTypes.ReqConnectionInfo: self.ReqConnectionInfo, + PacketTypes.SendConnectionInfo: self.RecvConnectionInfo, + PacketTypes.SessionInfo: self.RecvSessionInfo, + PacketTypes.ServerIDList: self.RecvServerIDList, + } # type: Dict[int, Callable[[bytes], None]] + + def fileno(self): + # type: () -> int + return self.socket.fileno() + + def on_recv(self): + # type: () -> Optional[Tuple[int, bytes]] + header = self.rfile.read(8) + if not len(header) or len(header) < 8: + return None + + return self.recv_packet(header) + + def recv_packet(self, header): + # type: (bytes) -> Tuple[int, bytes] + size, packet_id = struct.unpack(">II", header) + data = self.rfile.read(size) + return packet_id, data + + def is_finished(self): + # type: () -> bool + return self.finished + + def on_exception(self, e): + # type: (Exception) -> None + self.finish() + + def send_packet(self, packet_id=0, data=b""): + # type: (int, bytes) -> None + self.wfile.write(self.pack_data( + data, packet_id + )) + + def pack_data(self, data, packet_id): + # type: (bytes, int) -> bytes + return struct.pack(">IIH", len(data), packet_id, + self.cache.server_id) + data + + def direct_to_handler(self, packet): + # type: (Tuple[int, bytes]) -> None + packet_type, data = packet + self.handler_functions[packet_type](data) + + def SendFriendlyHello(self, data=b""): + # type: (bytes) -> None + self.cache.debug("Sending a friendly hello!") + self.send_packet(PacketTypes.FriendlyHello, data) + + def ReqConnectionInfo(self, data): + # type: (bytes) -> None + self.cache.debug("Recieved request for update connection info from Central.") + server = get_instance().server + assert server != None + + data = json.dumps(server.serialize()).encode('utf-8') + self.SendConnectionInfo(data) + + def SendConnectionInfo(self, data): + # type: (bytes) -> None + self.cache.debug("Sending connection info to Central.") + self.send_packet(PacketTypes.SendConnectionInfo, data) + + def SendReqServerRefresh(self): + # type: () -> None + self.cache.debug("Requesting refreshed server info from central.") + self.send_packet(PacketTypes.ReqServerRefresh, b"") + + def SendReqConnectionInfo(self, server_id): + # type: (int) -> None + self.cache.debug("Requesting info for Server {}".format( + server_id + )) + self.send_packet(PacketTypes.ReqConnectionInfo, + struct.pack(">H", server_id)) + + def RecvServerIDList(self, data): + # type: (bytes) -> None + self.cache.debug("Recieved updated Server ID list from Central.") + servers_version, count = struct.unpack(">HH", data[:4]) + self.cache.update_servers_version(servers_version) + updated_server_ids = [] + for i in range(count): + server_id = struct.unpack(">H", data[2*(i+2):2*(i+2)+2]) + updated_server_ids.append(server_id) + for server_id in self.cache.servers.keys(): + if server_id not in updated_server_ids: + self.cache.prune_server(server_id) + + def RecvConnectionInfo(self, data): + # type: (bytes) -> None + try: + server_id, = struct.unpack(">H", data[:2]) + server = Server.deserialize(json.loads(data[2:].decode('utf-8'))) + except Exception as e: + self.cache.error(e) + return + self.cache.debug("Obtained updated server info for Server {}".format( + server_id + )) + self.cache.servers[server_id] = server + + def SendSessionInfo(self, server_id, ser_session): + # type: (int, bytes) -> None + self.cache.debug("Sending Session info to Server {}".format( + server_id + )) + data = struct.pack(">H", server_id) + data += ser_session + self.send_packet(PacketTypes.SessionInfo, data) + + def RecvSessionInfo(self, data): + # type: (bytes) -> None + self.cache.debug("Recieved new Session info!") + self.cache.new_session(Session.deserialize(json.loads(data.decode('utf-8')))) + + def SendSessionDisconnect(self, capcom_id): + # type: (bytes) -> None + data = struct.pack(">H", len(capcom_id)) + data += capcom_id.encode('utf-8') + self.send_packet(PacketTypes.SessionDisconnect, data) + + def finish(self): + # type: () -> None + if self.finished: + return + + self.finished = True + + try: + self.wfile.close() + except Exception: + pass + + try: + self.rfile.close() + except Exception: + pass + + try: + self.socket.close() + except Exception: + pass + + +class Cache(Logger): + def __init__(self, server_id, debug_mode=False, log_to_file=False, + log_to_console=False, log_to_window=False, + refresh_period=30, ssl_location='cert/crossserverCA/'): + # type: (int, bool, bool, bool, bool, int, str) -> None + Logger.__init__(self) + self.servers_version = 1 # type: int + self.servers = { + # To be populated by remote connection + } # type: Dict[int, Server] + + self.outbound_sessions = [ + # (destination_server_id, session) + ] # type: List[Tuple[int, bytes]] + + self.players = { + # capcom_id -> connectionless sessions from other servers + } # type: Dict[str, Session] + self.ready_sessions = { + # pat_ticket -> True or connection_data + } # type: Dict[str, Union[bool, Tuple[FmpRequestHandler, ConnectionData, int]]] + + log_level = logging.DEBUG if debug_mode else logging.INFO + log_file = "cache.log" if log_to_file else "" + + self.set_logger(create_logger("Cache", log_level, log_file, log_to_console, log_to_window)) + + self.is_central_server = server_id == 0 + if not self.is_central_server: + remote_config = get_remote_config("SERVER{}".format(server_id)) + get_instance().setup_server(server_id, + remote_config["Name"], + int(remote_config["ServerType"]), + int(remote_config["Capacity"]), + get_ip(remote_config["IP"]), + int(remote_config["Port"])) + else: + config = get_config("FMP") + get_instance().setup_server( + server_id, "", 0, 1, '0.0.0.0', config["Port"] + ) + self.shut_down = False # type: bool + self.shut_down_event = Event() + self.refresh_period = refresh_period + self.handlers = {} # type: Dict[int, CentralConnectionHandler] + self.pending_connections = {} # type: Dict[CentralConnectionHandler, Timer] + self.remote_server = None # type: Optional[RemoteConnectionHandler] + self.server_id = server_id + self.central_config = get_central_config() + self.central_connection = (self.central_config["CentralIP"], + self.central_config["CentralCrossconnectPort"]) + self.sel = selectors.DefaultSelector() + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.ssl_location = ssl_location + if self.central_config["CrossconnectSSL"]: + self.create_ssl_wrapper() + + def create_ssl_wrapper(self): + # type: () -> None + import ssl + context = ssl.SSLContext(ssl.PROTOCOL_TLS) + if self.is_central_server: + context.load_verify_locations( + cafile="{}ca.crt".format(self.ssl_location) + ) + context.load_cert_chain("{}MH3SP.crt".format(self.ssl_location), + "{}MH3SP.key".format(self.ssl_location)) + else: + context.load_cert_chain( + "{}client{}.crt".format(self.ssl_location, self.server_id), + "{}client{}.key".format(self.ssl_location, self.server_id) + ) + self.socket = context.wrap_socket( + self.socket, server_side=self.is_central_server + ) + + def update_player_record(self, session): + # type: (Session) -> None + get_instance().update_capcom_id(session) + + def update_players(self): + # type: () -> None + + players = [] + for server_id, server in self.servers.items(): + if server_id != self.server_id: + players = players + server.get_all_players() + new_players = {} + for p in players: + new_players[p.capcom_id] = p + self.players = new_players + if self.is_central_server: + get_instance().update_players() + + def get_remote_players_list(self): + # type: () -> List[Session] + players = [] + for server_id, server in self.servers.items(): + if server_id != self.server_id: + players = players + server.get_all_players() + return players + + def update_servers_version(self, servers_version): + # type: (int) -> None + self.servers_version = servers_version + + def get_server_list(self, include_ids=False): + # type: (bool) -> Union[List[Server], Tuple[List[int], List[Server]]] + if include_ids: + return list(self.servers.keys()), list(self.servers.values()) + return list(self.servers.values()) + + def get_server(self, server_id): + # type: (int) -> Server + assert server_id in self.servers + return self.servers[server_id] + + def send_session_info(self, server_id, session): + # type: (int, Session) -> None + self.outbound_sessions.append( + (server_id, json.dumps(session.serialize()).encode('utf-8')) + ) + + def new_session(self, session): + # type: (Session) -> None + pat_ticket = session.pat_ticket + assert pat_ticket is not None + + ready_data = self.session_ready(pat_ticket) + if ready_data: + self.set_session_ready(pat_ticket, False) + get_instance().register_pat_ticket(session) + self.send_login_packet(*ready_data) # type: ignore + else: + self.set_session_ready(pat_ticket, session) + + def session_ready(self, pat_ticket): + # type: (str) -> Union[bool, Tuple[FmpRequestHandler, ConnectionData, int]] + return self.ready_sessions.get(pat_ticket, False) + + def set_session_ready(self, pat_ticket, store_data): + # type: (str, Union[bool, Tuple[FmpRequestHandler, ConnectionData, int]]) -> None + self.ready_sessions[pat_ticket] = store_data + + def notify_session_deletion(self, capcom_id): + # type: str -> None + if not self.is_central_server: + self.remote_server.SendSessionDisconnect(capcom_id) + + def send_login_packet(self, player_handler, connection_data, seq): + # type: (FmpRequestHandler, ConnectionData, int) -> None + player_handler.sendNtcLogin(3, connection_data, seq) + + def get_handler(self, server_id): + # type: (int) -> CentralConnectionHandler + return self.handlers[server_id] + + def register_handler(self, server_id, handler): + # type: (int, CentralConnectionHandler) -> None + handler.id = server_id + del self.pending_connections[handler] + self.handlers[server_id] = handler + + def prune_server(self, server_id): + # type: (int) -> None + for player in self.servers[server_id].get_all_players(): + if player.capcom_id in self.players: + del self.players[player.capcom_id] + if server_id != 0: + del self.servers[server_id] + if self.is_central_server: + self.update_servers_version(self.servers_version + 1) + + def maintain_connection(self): + # type: () -> None + state = get_instance() + state.initialized.wait() + state.cache = self + + refresh_timer = Timer() + if self.is_central_server: + # CENTRAL SERVER CONNECTION TO REMOTE + self.serve_cache(refresh_timer) + else: + # REMOTE SERVER CONNECTION TO CENTRAL + self.remote_connection_to_central_server(refresh_timer) + + def remote_connection_to_central_server(self, refresh_timer): + # type: (Timer) -> None + central_host = "localhost" if self.central_connection[0] == "0.0.0.0" else self.central_connection[0] + central_addr = (central_host, self.central_connection[1]) + + while not self.shut_down: + # Connect to the central server if needed + if self.remote_server is None or self.remote_server.is_finished(): + self.info("Connecting to central server at {}:{}...".format(central_addr[0], central_addr[1])) + + if self.remote_server is not None: + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if self.central_config["CrossconnectSSL"]: + self.create_ssl_wrapper() + self.remote_server = None + + connect_timer = Timer() + try: + self.socket.settimeout(60.0) + self.socket.connect(central_addr) + self.socket.settimeout(None) + except socket.error as sck_error: + self.error("Failed! {}.".format(sck_error)) + if connect_timer.elapsed() < 60.0: + remaining = 60.0 - connect_timer.elapsed() + self.debug("Retrying in {:.2f}s".format(remaining)) + self.shut_down_event.wait(remaining) + continue + + self.remote_server = RemoteConnectionHandler(self.socket, central_addr, self) + self.sel.register(self.remote_server, selectors.EVENT_READ | selectors.EVENT_WRITE) + + try: + # Listen for incoming packets + assert self.remote_server is not None + events = self.sel.select(timeout=1) + for _, event in events: + if bool(event & selectors.EVENT_WRITE): + self.remote_server.SendFriendlyHello() + self.sel.modify(self.remote_server, selectors.EVENT_READ) + elif bool(event & selectors.EVENT_READ): + packet = self.remote_server.on_recv() + if packet is None: + continue + self.remote_server.direct_to_handler(packet) + # Request updated server information + if refresh_timer.elapsed() >= self.refresh_period: + assert self.remote_server is not None + try: + self.remote_server.SendReqServerRefresh() + finally: + refresh_timer.restart() + # Pass on an outbound session + if len(self.outbound_sessions) > 0: + self.debug("Outbound Session dispatching to Central Server.") + assert self.remote_server is not None + self.remote_server.SendSessionInfo(*self.outbound_sessions.pop(0)) + except Exception as exc: + assert self.remote_server is not None + self.remote_server.on_exception(exc) + finally: + assert self.remote_server is not None + if not self.shut_down and self.remote_server.is_finished(): + self.sel.unregister(self.remote_server) + + + def serve_cache(self, refresh_timer): + # type: (Timer) -> None + try: + self.socket.bind(self.central_connection) + self.socket.listen(0) + except socket.error as sck_error: + self.error('Failed to bind server to {}:{}. {}'.format(self.central_connection[0], self.central_connection[1], sck_error)) + return + self.info("Listening for remote servers on {}:{}".format(self.central_connection[0], self.central_connection[1])) + self.sel.register(self.socket, selectors.EVENT_READ) + + while not self.shut_down: + events = self.sel.select(timeout=1) + # Respond to incoming packets + for key, event in events: + connection = key.fileobj + if connection == self.socket: + # Accept a new connection + new_client = None + try: + new_client = self.socket.accept() + except Exception as exc: + self.error("Failed to accept incoming connection. {}".format( + exc)) + continue + + client_socket, client_address = new_client + self.info("Remote Server connected from {}".format( + client_address + )) + handler = CentralConnectionHandler(client_socket, + client_address, + self) + self.pending_connections[handler] = Timer() + self.sel.register(handler, selectors.EVENT_READ) + self.update_servers_version(self.servers_version + 1) + else: + assert event == selectors.EVENT_READ + assert isinstance(connection, CentralConnectionHandler) + try: + packet = connection.on_recv() + if packet is None: + if connection.is_finished(): + self.remove_handler(connection) + continue + connection.direct_to_handler(packet) + except Exception as exc: + connection.on_exception(exc) + if not connection.is_finished(): + continue + self.info( + "Connection to Remote Server {} lost: {}.".format( + connection.id, + exc + )) + self.remove_handler(connection) + if refresh_timer.elapsed() >= self.refresh_period: + for _, handler in self.handlers.items(): + try: + handler.ReqConnectionInfo() + except Exception as exc: + handler.on_exception(exc) + + if handler.is_finished(): + self.remove_handler(handler) + refresh_timer.restart() + # Pass on an outbound session + if len(self.outbound_sessions) > 0: + self.debug("Session outbound...") + outbound_session = self.outbound_sessions.pop(0) + self.debug("Dispatching Session to Server {}.".format( + outbound_session[0] + )) + + server_handler = self.get_handler(outbound_session[0]) + + try: + server_handler.SendSessionInfo(outbound_session[1]) + except Exception as exc: + server_handler.on_exception(exc) + + if server_handler.is_finished(): + self.remove_handler(server_handler) + # Prune dangling pending connection + for handler, timer in self.pending_connections.items(): + if timer.elapsed() < 30.0: + continue + self.warning('Prunning dangling remote server connection {}:{}'.format(handler.client_address[0], handler.client_address[1])) + try: + handler.finish() + except Exception: + pass + self.sel.unregister(handler) + del self.pending_connections[handler] + + def remove_handler(self, handler): + # type: (CentralConnectionHandler) -> None + if handler.id > -1: + try: + del self.handlers[handler.id] + except KeyError: + pass + else: + try: + del self.pending_connections[handler] + except KeyError: + pass + + try: + self.sel.unregister(handler) + except KeyError: + pass + + try: + handler.finish() + except Exception: + pass + + if handler.id > -1: + self.prune_server(handler.id) + + if self.is_central_server: + for handler in self.handlers.values(): + try: + handler.SendServerIDList() + except Exception: + self.error("Failed to send the Server ID list to a handler.") + + def close(self): + # type: () -> None + if self.shut_down: + return + + self.shut_down = True + self.shut_down_event.set() + + if self.is_central_server: + for _, handler in self.handlers.items(): + try: + handler.finish() + except Exception: + pass + + self.handlers.clear() + + for handler, _ in self.pending_connections.items(): + try: + handler.finish() + except Exception: + pass + self.pending_connections.clear() + elif self.remote_server is not None: + try: + self.remote_server.send_packet(PacketTypes.ServerShutdown, b"") + except Exception: + pass + + try: + self.remote_server.finish() + except Exception: + pass + + self.socket.close() + self.sel.close() + + self.info('Server Closed') diff --git a/other/utils.py b/other/utils.py index 1d2c68c..221895a 100644 --- a/other/utils.py +++ b/other/utils.py @@ -65,6 +65,17 @@ def critical(self, msg, *args, **kwargs): return self.logger.critical(msg, *args, **kwargs) +class SimpleNamespace (object): + def __init__ (self, **kwargs): + self.__dict__.update(kwargs) + def __repr__ (self): + keys = sorted(self.__dict__) + items = ("{}={!r}".format(k, self.__dict__[k]) for k in keys) + return "{}({})".format(type(self).__name__, ", ".join(items)) + def __eq__ (self, other): + return self.__dict__ == other.__dict__ + + class GenericUnpacker(object): """Generic unpacker that maps unpack and pack functions. @@ -237,6 +248,28 @@ def get_config(name, config_file=CONFIG_FILE): } +def get_remote_config(name, config_file=CONFIG_FILE): + config = ConfigParser.RawConfigParser(allow_no_value=True) + config.read(config_file) + return { + "IP": config.get(name, "IP"), + "Port": config.getint(name, "Port"), + "Capacity": config.getint(name, "Capacity"), + "Name": config.get(name, "Name"), + "ServerType": config.get(name, "ServerType") + } + + +def get_central_config(config_file=CONFIG_FILE): + config = ConfigParser.RawConfigParser(allow_no_value=True) + config.read(CONFIG_FILE) + return { + "CentralIP": config.get("CENTRAL", "CentralIP"), + "CentralCrossconnectPort": config.getint("CENTRAL", "CentralCrossconnectPort"), + "CrossconnectSSL": config.getboolean("CENTRAL", "CrossconnectSSL") + } + + def get_default_ip(): """Get the default IP address""" s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -344,7 +377,9 @@ def create_server(server_class, server_handler, address="0.0.0.0", port=8200, name="Server", max_thread=0, use_ssl=True, ssl_cert="server.crt", ssl_key="server.key", log_to_file=True, log_filename="server.log", - log_to_console=True, log_to_window=False, debug_mode=False): + log_to_console=True, log_to_window=False, legacy_ssl=False, + debug_mode=False, no_timeout=False): + """Create a server, its logger and the SSL context if needed.""" logger = create_logger( name, level=logging.DEBUG if debug_mode else logging.INFO, @@ -352,7 +387,7 @@ def create_server(server_class, server_handler, log_to_console=log_to_console, log_to_window=log_to_window) server = server_class((address, port), server_handler, max_thread, logger, - debug_mode) + debug_mode, no_timeout) if use_ssl: server.socket = wii_ssl_wrap_socket(server.socket, ssl_cert, ssl_key) @@ -363,25 +398,27 @@ def create_server(server_class, server_handler, server_base = namedtuple("ServerBase", ["name", "cls", "handler"]) -def create_server_from_base(name, server_class, server_handler, silent=False, - debug_mode=False): +def create_server_from_base(name, server_class, server_handler, server_id, + silent=False, debug_mode=False, no_timeout=False): """Create a server based on its config parameters.""" - config = get_config(name) + general_config = get_config(name) + server_config = None if not server_id else get_remote_config("SERVER{}".format(server_id)) return create_server( server_class, server_handler, - address=config["IP"], - port=config["Port"], - name=config["Name"], - max_thread=config["MaxThread"], - use_ssl=config["UseSSL"], - ssl_cert=config["SSLCert"], - ssl_key=config["SSLKey"], - log_to_file=config["LogToFile"], - log_filename=config["LogFilename"], - log_to_console=config["LogToConsole"] and not silent, - log_to_window=config["LogToWindow"], - debug_mode=debug_mode - ), config["LogToWindow"] + address=general_config["IP"], + port=general_config["Port"] if not server_id else server_config["Port"], + name=general_config["Name"] if not server_id else server_config["Name"], + max_thread=general_config["MaxThread"], + use_ssl=general_config["UseSSL"], + ssl_cert=general_config["SSLCert"], + ssl_key=general_config["SSLKey"], + log_to_file=general_config["LogToFile"], + log_filename=general_config["LogFilename"], + log_to_console=general_config["LogToConsole"] and not silent, + log_to_window=general_config["LogToWindow"], + debug_mode=debug_mode, + no_timeout=no_timeout + ), general_config["LogToWindow"] def server_main(name, server_class, server_handler):