From 4c328d45ba71f0ef698ba55a59850fdf9de6a13a Mon Sep 17 00:00:00 2001 From: Kevin Barkevich Date: Sat, 14 Jun 2025 23:22:29 -0400 Subject: [PATCH 1/7] Move FMP state to its own class Other changes: - city.is_empty method added - assert_fields when searching cities fixed - workaround for cyclic import (wip) --- fmp_server.py | 13 +- mh/database.py | 508 ++------------------------------------- mh/pat.py | 75 ++++-- mh/pat_item.py | 55 ++++- mh/session.py | 294 ++++++++++++++++++----- mh/state.py | 280 ++++++++++++++++++++++ mh/state_models.py | 585 +++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 1224 insertions(+), 586 deletions(-) create mode 100644 mh/state.py create mode 100644 mh/state_models.py diff --git a/fmp_server.py b/fmp_server.py index c179f05..7f1e1eb 100644 --- a/fmp_server.py +++ b/fmp_server.py @@ -9,18 +9,29 @@ import mh.pat_item as pati from mh.constants import PatID4 from mh.pat import PatRequestHandler, PatServer +from mh.session import FMPSession +from mh.state import State from other.utils import hexdump, server_base, server_main, to_str class FmpServer(PatServer): """Basic FMP server class.""" # TODO: Backport close cache logic - pass + def __init__(self, *args, **kwargs): + PatServer.__init__(self, *args, **kwargs) + self.fmp_state = State() + # TODO: Backport the cache server registration logic instead + import mh.database as db + db.get_instance().servers = self.fmp_state.servers class FmpRequestHandler(PatRequestHandler): """Basic FMP server request handler class.""" + def on_init(self): + PatRequestHandler.on_init(self) + self.session = FMPSession(self.session) + def recvAnsConnection(self, packet_id, data, seq): """AnsConnection packet.""" connection_data = pati.ConnectionData.unpack(data) diff --git a/mh/database.py b/mh/database.py index ad7c70b..9c17bfe 100644 --- a/mh/database.py +++ b/mh/database.py @@ -7,9 +7,8 @@ import inspect import random import sqlite3 -import time from other import utils -from threading import 
RLock, local as thread_local +from threading import local as thread_local CHARSET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -23,359 +22,11 @@ BLANK_CAPCOM_ID = "******" -RESERVE_DC_TIMEOUT = 40.0 - def new_random_str(length=6): return "".join(random.choice(CHARSET) for _ in range(length)) -class ServerType(object): - OPEN = 1 - ROOKIE = 2 - EXPERT = 3 - RECRUITING = 4 - - -class LayerState(object): - EMPTY = 1 - FULL = 2 - JOINABLE = 3 - - -class Lockable(object): - def __init__(self): - self._lock = RLock() - - def lock(self): - return self - - def __enter__(self): - # Returns True if lock was acquired, False otherwise - return self._lock.acquire() - - def __exit__(self, *args): - self._lock.release() - - -class Players(Lockable): - def __init__(self, capacity): - assert capacity > 0, "Collection capacity can't be zero" - - self.slots = [None for _ in range(capacity)] - self.used = 0 - super(Players, self).__init__() - - def get_used_count(self): - return self.used - - def get_capacity(self): - return len(self.slots) - - def add(self, item): - with self.lock(): - if self.used >= len(self.slots): - return -1 - - item_index = self.index(item) - if item_index != -1: - return item_index - - for i, v in enumerate(self.slots): - if v is not None: - continue - - self.slots[i] = item - self.used += 1 - return i - - return -1 - - def remove(self, item): - assert item is not None, "Item != None" - - with self.lock(): - if self.used < 1: - return False - - if isinstance(item, int): - if item >= self.get_capacity(): - return False - - self.slots[item] = None - self.used -= 1 - return True - - for i, v in enumerate(self.slots): - if v != item: - continue - - self.slots[i] = None - self.used -= 1 - return True - - return False - - def index(self, item): - assert item is not None, "Item != None" - - for i, v in enumerate(self.slots): - if v == item: - return i - - return -1 - - def clear(self): - with self.lock(): - for i in range(self.get_capacity()): - self.slots[i] = None - 
- def find_first(self, **kwargs): - if self.used < 1: - return None - - for p in self.slots: - if p is None: - continue - - for k, v in kwargs.items(): - if getattr(p, k) != v: - break - else: - return p - - return None - - def find_by_capcom_id(self, capcom_id): - return self.find_first(capcom_id=capcom_id) - - def __len__(self): - return self.used - - def __iter__(self): - if self.used < 1: - return - - for i, v in enumerate(self.slots): - if v is None: - continue - - yield i, v - - -class Circle(Lockable): - def __init__(self, parent): - # type: (City) -> None - self.parent = parent - self.leader = None - self.players = Players(4) - self.departed = False - self.quest_id = 0 - self.embarked = False - self.password = None - self.remarks = None - - self.unk_byte_0x0e = 0 - super(Circle, self).__init__() - - def get_population(self): - return len(self.players) - - def get_capacity(self): - return self.players.get_capacity() - - def is_full(self): - return self.get_population() == self.get_capacity() - - def is_empty(self): - return self.leader is None - - def is_joinable(self): - return not self.departed and not self.is_full() - - def has_password(self): - return self.password is not None - - def reset_players(self, capacity): - with self.lock(): - self.players = Players(capacity) - - def reset(self): - with self.lock(): - self.leader = None - self.reset_players(4) - self.departed = False - self.quest_id = 0 - self.embarked = False - self.password = None - self.remarks = None - - self.unk_byte_0x0e = 0 - - -class City(Lockable): - LAYER_DEPTH = 3 - - def __init__(self, name, parent): - # type: (str, Gate) -> None - self.name = name - self.parent = parent - self.state = LayerState.EMPTY - self.players = Players(4) - self.optional_fields = [] - self.leader = None - self.reserved = None - self.circles = [ - # One circle per player - Circle(self) for _ in range(self.get_capacity()) - ] - super(City, self).__init__() - - def get_population(self): - return 
len(self.players) - - def in_quest_players(self): - return sum(p.is_in_quest() for _, p in self.players) - - def get_capacity(self): - return self.players.get_capacity() - - def get_state(self): - if self.reserved: - return LayerState.FULL - - size = self.get_population() - if size == 0: - return LayerState.EMPTY - elif size < self.get_capacity(): - return LayerState.JOINABLE - else: - return LayerState.FULL - - def get_pathname(self): - pathname = self.name - it = self.parent - while it is not None: - pathname = it.name + "\t" + pathname - it = it.parent - return pathname - - def get_first_empty_circle(self): - with self.lock(): - for index, circle in enumerate(self.circles): - if circle.is_empty(): - return circle, index - return None, None - - def get_circle_for(self, leader_session): - with self.lock(): - for index, circle in enumerate(self.circles): - if circle.leader == leader_session: - return circle, index - return None, None - - def clear_circles(self): - with self.lock(): - for circle in self.circles: - circle.reset() - - def reserve(self, reserve): - with self.lock(): - if reserve: - self.reserved = time.time() - else: - self.reserved = None - - def reset(self): - with self.lock(): - self.state = LayerState.EMPTY - self.players.clear() - self.optional_fields = [] - self.leader = None - self.reserved = None - self.clear_circles() - - -class Gate(object): - LAYER_DEPTH = 2 - - def __init__(self, name, parent, city_count=40, player_capacity=100): - # type: (str, Server, int, int) -> None - self.name = name - self.parent = parent - self.state = LayerState.EMPTY - self.cities = [ - City("City{}".format(i), self) - for i in range(1, city_count+1) - ] - self.players = Players(player_capacity) - self.optional_fields = [] - - def get_population(self): - return len(self.players) + sum(( - city.get_population() - for city in self.cities - )) - - def get_capacity(self): - return self.players.get_capacity() - - def get_state(self): - size = self.get_population() - if 
size == 0: - return LayerState.EMPTY - elif size < self.get_capacity(): - return LayerState.JOINABLE - else: - return LayerState.FULL - - -class Server(object): - LAYER_DEPTH = 1 - - def __init__(self, name, server_type, gate_count=40, capacity=2000, - addr=None, port=None): - self.name = name - self.parent = None - self.server_type = server_type - self.addr = addr # public IP address - self.port = port - self.gates = [ - Gate("City Gate{}".format(i), self) - for i in range(1, gate_count+1) - ] - self.players = Players(capacity) - - def get_population(self): - return len(self.players) + sum(( - gate.get_population() for gate in self.gates - )) - - def get_capacity(self): - return self.players.get_capacity() - - -def new_servers(): - servers = [] - servers.extend([ - Server("Valor{}".format(i), ServerType.OPEN) - for i in range(1, 5) - ]) - servers.extend([ - Server("Beginners{}".format(i), ServerType.ROOKIE) - for i in range(1, 3) - ]) - servers.extend([ - Server("Veterans{}".format(i), ServerType.EXPERT) - for i in range(1, 3) - ]) - servers.extend([ - Server("Greed{}".format(i), ServerType.RECRUITING) - for i in range(1, 5) - ]) - return servers - - class TempDatabase(object): """A temporary database. 
@@ -403,7 +54,6 @@ def __init__(self): # Capcom ID => List of Capcom IDs # TODO: May need stable index, see Players class } - self.servers = new_servers() def get_support_code(self, session): """Get the online support code or create one.""" @@ -500,93 +150,8 @@ def get_users(self, session, first_index, count): ]) return capcom_ids - def join_server(self, session, index): - if session.local_info["server_id"] is not None: - self.leave_server(session, session.local_info["server_id"]) - server = self.get_server(index) - server.players.add(session) - session.local_info["server_id"] = index - session.local_info["server_name"] = server.name - return server - - def leave_server(self, session, index): - self.get_server(index).players.remove(session) - session.local_info["server_id"] = None - session.local_info["server_name"] = None - - def get_server_time(self): - pass - - def get_game_time(self): - pass - - def get_servers(self): - return self.servers - - def get_server(self, index): - assert 0 < index <= len(self.servers), "Invalid server index" - return self.servers[index - 1] - - def get_gates(self, server_id): - return self.get_server(server_id).gates - - def get_gate(self, server_id, index): - gates = self.get_gates(server_id) - assert 0 < index <= len(gates), "Invalid gate index" - return gates[index - 1] - - def join_gate(self, session, server_id, index): - gate = self.get_gate(server_id, index) - gate.parent.players.remove(session) - gate.players.add(session) - session.local_info["gate_id"] = index - session.local_info["gate_name"] = gate.name - return gate - - def leave_gate(self, session): - gate = self.get_gate(session.local_info["server_id"], - session.local_info["gate_id"]) - gate.parent.players.add(session) - gate.players.remove(session) - session.local_info["gate_id"] = None - session.local_info["gate_name"] = None - - def get_cities(self, server_id, gate_id): - return self.get_gate(server_id, gate_id).cities - - def get_city(self, server_id, gate_id, 
index): - cities = self.get_cities(server_id, gate_id) - assert 0 < index <= len(cities), "Invalid city index" - return cities[index - 1] - - def reserve_city(self, server_id, gate_id, index, reserve): - city = self.get_city(server_id, gate_id, index) - with city.lock(): - reserved_time = city.reserved - if reserve and reserved_time and \ - time.time()-reserved_time < RESERVE_DC_TIMEOUT: - return False - city.reserve(reserve) - return True - - def get_all_users(self, server_id, gate_id, city_id): - """Search for users in layers and its children. - - Let's assume wildcard search isn't possible for servers and gates. - A wildcard search happens when the id is zero. - """ - assert 0 < server_id, "Invalid server index" - assert 0 < gate_id, "Invalid gate index" - gate = self.get_gate(server_id, gate_id) - users = list(gate.players) - cities = [ - self.get_city(server_id, gate_id, city_id) - ] if city_id else self.get_cities(server_id, gate_id) - for city in cities: - users.extend(list(city.players)) - return users - - def find_users(self, capcom_id="", hunter_name=""): + def find_users(self, capcom_id="", hunter_name=b""): + # type: (str, bytes) -> list["Session"] assert capcom_id or hunter_name, "Search can't be empty" users = [] for user_id, user_info in self.capcom_ids.items(): @@ -606,59 +171,6 @@ def get_user_name(self, capcom_id): return "" return self.capcom_ids[capcom_id]["name"] - def create_city(self, session, server_id, gate_id, index, - settings, optional_fields): - city = self.get_city(server_id, gate_id, index) - with city.lock(): - city.optional_fields = optional_fields - city.leader = session - city.reserved = None - return city - - def join_city(self, session, server_id, gate_id, index): - city = self.get_city(server_id, gate_id, index) - with city.lock(): - city.parent.players.remove(session) - city.players.add(session) - session.local_info["city_name"] = city.name - session.local_info["city_id"] = index - return city - - def leave_city(self, 
session): - city = self.get_city(session.local_info["server_id"], - session.local_info["gate_id"], - session.local_info["city_id"]) - with city.lock(): - city.parent.players.add(session) - city.players.remove(session) - if not city.get_population(): - city.reset() - session.local_info["city_id"] = None - session.local_info["city_name"] = None - - def layer_detail_search(self, server_type, fields): - cities = [] - - def match_city(city, fields): - with city.lock(): - return all(( - field in city.optional_fields - for field in fields - )) - - for server in self.servers: - if server.server_type != server_type: - continue - for gate in server.gates: - if not gate.get_population(): - continue - cities.extend([ - city - for city in gate.cities - if match_city(city, fields) - ]) - return cities - def add_friend_request(self, sender_id, recipient_id): # Friend invite can be sent to arbitrary Capcom ID if any(cid not in self.capcom_ids @@ -1148,14 +660,23 @@ def __init__(self, *args, **kwargs): CURRENT_DB = \ MySQLDatabase() \ if utils.is_mysql_enabled("MYSQL") \ - else TempSQLiteDatabase() + else TempSQLiteDatabase() # type: TempDatabase def get_instance(): + # type: () -> TempDatabase + """Return a database instance like module singleton. + + TODO: + - Rename the TempDatabase interface to DatabaseInterface or similar + - Proper type hint, though any implementation must implement all the + interface methods. + """ return CURRENT_DB def implementation_check(cls, base=TempDatabase): + # type: (type, type) -> bool """Debug implementation check allowing to see methods dependencies. 
Example: @@ -1163,6 +684,7 @@ def implementation_check(cls, base=TempDatabase): >>> implementation_check(TempSQLiteDatabase) """ def is_method(obj): + # type: (object) -> bool """Python3's missing unbound methods workaround.""" if inspect.ismethod(obj): return True @@ -1179,7 +701,7 @@ def is_method(obj): print("class {}:".format(cls.__name__)) - for name, obj in inspect.getmembers(cls, is_method): + for name, _ in inspect.getmembers(cls, is_method): if name in missing_methods: missing_methods.remove(name) for c in mros: diff --git a/mh/pat.py b/mh/pat.py index 6bf7b47..5d1e111 100644 --- a/mh/pat.py +++ b/mh/pat.py @@ -19,8 +19,8 @@ CHARGE, VULGARITY_INFO, FMP_VERSION, PAT_BINARIES, PAT_NAMES, \ MAINTENANCE, UNPATCHED, \ PatID4, get_pat_binary_from_version -from mh.session import Session -import mh.database as db +from mh.session import Session, FMPSession +from mh.state_models import Server, Gate, City, Circle, Players # noqa: F401 try: from typing import Literal @@ -81,7 +81,7 @@ def get_pat_handler(self, session): # TODO: Backport broadcasting refactoring if needed def broadcast(self, players, packet_id, data, seq, to_exclude=None): - # type: (db.Players, int, bytes, int, Session|None) -> None + # type: (Players, int, bytes, int, Session|None) -> None handlers = [] with players.lock(): for _, player in players: @@ -102,7 +102,7 @@ def layer_broadcast(self, session, packet_id, data, seq, def circle_broadcast(self, circle, packet_id, data, seq, session=None): - # type: (db.Circle, int, bytes, int, Session|None) -> None + # type: (Circle, int, bytes, int, Session|None) -> None with circle.lock(): self.broadcast(circle.players, packet_id, data, seq, session) @@ -302,6 +302,7 @@ def recvAnsConnection(self, packet_id, data, seq): has_ban = "online_support_code" in settings and \ pat_ticket in BANNED_ONLINE_SUPPORT_CODES + # Used by OPN server, not sure we should rely on session at this point if has_ban or len(self.session.get_servers()) == 0: self.sendNtcLogin(2, 
settings, seq) else: @@ -932,6 +933,8 @@ def recvReqFmpListVersion(self, packet_id, data, seq): JP: FMPリストバージョン確認 TR: FMP list version check + NB: This packet is also used by the LMP server. + TODO: - Find why there are 2 versions of FMP packets. - Find why most of the 2 versions are ignored. @@ -967,6 +970,8 @@ def recvReqFmpListHead(self, packet_id, data, seq): ID: 61310100 / 63110100 JP: FMPリスト数送信 / FMPリスト数要求 TR: Send FMP list count / FMP list count request + + NB: This packet is also used by the LMP server. """ # TODO: Might be worth investigating these parameters as # they might be useful when using multiple FMP servers. @@ -987,6 +992,8 @@ def sendAnsFmpListHead(self, seq): ID: 61310200 JP: FMPリスト数応答 TR: FMP list count response + + NB: This packet is also used by the LMP server. """ unused = 0 count = len(self.session.get_servers()) @@ -1013,6 +1020,8 @@ def recvReqFmpListData(self, packet_id, data, seq): ID: 61320100 / 63120100 JP: FMPリスト送信 / FMPリスト要求 TR: Send FMP list / FMP list response + + NB: This packet is also used by the LMP server. """ first_index, count = struct.unpack_from(">II", data) if packet_id == PatID4.ReqFmpListData: @@ -1058,6 +1067,8 @@ def recvReqFmpListFoot(self, packet_id, data, seq): ID: 61330100 / 63130100 JP: FMPリスト送信終了 / FMPリスト終了送信 TR: FMP list end of transmission / FMP list transmission end + + NB: This packet is also used by the LMP server. """ if packet_id == PatID4.ReqFmpListFoot: self.sendAnsFmpListFoot(seq) @@ -1132,21 +1143,26 @@ def recvReqFmpInfo(self, packet_id, data, seq): TR: FMP data request TODO: Do not hardcode the data and find the meaning of all fields. + + NB: This packet is also used by the LMP server. """ index, = struct.unpack_from(">I", data) fields = pati.unpack_bytes(data, 4) - server = self.session.join_server(index) + # FIXME: Doesn't seem to make sense here, + # e.g. 
on LMP server, as "FmpInfo" packet + # server = self.session.join_server(index) config = get_config("FMP") fmp_addr = get_external_ip(config) fmp_port = config["Port"] fmp_data = pati.FmpData() - fmp_data.server_address = pati.String(server.addr or fmp_addr) - fmp_data.server_port = pati.Word(server.port or fmp_port) + fmp_data.server_address = pati.String(fmp_addr) + fmp_data.server_port = pati.Word(fmp_port) fmp_data.assert_fields(fields) # TODO: Backport central logic - if packet_id == PatID4.ReqFmpInfo: + if packet_id == PatID4.ReqFmpInfo: # LMP version self.sendAnsFmpInfo(fmp_data, fields, seq) - elif packet_id == PatID4.ReqFmpInfo2: + elif packet_id == PatID4.ReqFmpInfo2: # FMP version + self.session.join_server(index) self.sendAnsFmpInfo2(fmp_data, fields, seq) # Preserve session in database, due to server selection @@ -1416,8 +1432,9 @@ def recvReqLayerStart(self, packet_id, data, seq): JP: レイヤ開始要求 TR: Layer start request """ - unk1 = pati.unpack_bytes(data) - unk2 = pati.unpack_bytes(data, len(unk1) + 1) + with pati.Unpacker(data) as unpacker: + unk1 = unpacker.bytes() + unk2 = unpacker.bytes() self.sendAnsLayerStart(unk1, unk2, seq) def sendAnsLayerStart(self, unk1, unk2, seq): @@ -2402,7 +2419,17 @@ def sendAnsLayerDetailSearchData(self, offset, count, seq): for i, city in enumerate(cities): with city.lock(): layer_data = pati.LayerData.create_from(i, city) - layer_data.assert_fields(self.search_info["layer_fields"]) + layer_fields = self.search_info["layer_fields"] + filtered_fields = layer_data.filter_fields(layer_fields) # noqa: F841 + """ + # During testing / TODO: Investigate + filtered_fields = [ + (5, 'index', Word(0)) + (17, 'positionInterval', Long(500)) + (18, 'unk_byte_0x12', Byte(1)) + ] + """ + layer_data.assert_fields(layer_fields) data += layer_data.pack() data += pati.pack_optional_fields(city.optional_fields) with city.players.lock(): @@ -2737,7 +2764,7 @@ def sendNtcCircleHost(self, circle, new_leader, new_leader_index, seq): def 
notify_city_info_set(self, path): # type: (pati.LayerPath) -> None city = self.get_layer(path) - assert isinstance(city, db.City) + assert isinstance(city, City) gate = city.parent layer_data = pati.LayerData.create_from(path.city_id, city, path) @@ -2749,7 +2776,7 @@ def notify_city_info_set(self, path): def notify_city_number_set(self, path): # type: (pati.LayerPath) -> None city = self.get_layer(path) - assert isinstance(city, db.City) + assert isinstance(city, City) gate = city.parent layer_data = pati.LayerData.create_from(path.city_id, city, path) @@ -2757,17 +2784,16 @@ def notify_city_number_set(self, path): self.server.broadcast(gate.players, PatID4.NtcLayerUserNum, number_set, 0, self.session) - @staticmethod - def get_layer(path): - # type: (pati.LayerPath) -> db.Server | db.Gate | db.City | None - database = db.get_instance() + def get_layer(self, path): + # type: (pati.LayerPath) -> Server | Gate | City | None + fmp_state = self.session.FMP() if path.city_id > 0: - return database.get_city(path.server_id, path.gate_id, - path.city_id) + return fmp_state.get_city(path.server_id, path.gate_id, + path.city_id) elif path.gate_id > 0: - return database.get_gate(path.server_id, path.gate_id) + return fmp_state.get_gate(path.server_id, path.gate_id) elif path.server_id > 0: - return database.get_server(path.server_id) + return fmp_state.get_server(path.server_id) return None def notify_layer_departure(self, end): @@ -2793,7 +2819,7 @@ def notify_layer_departure(self, end): if path.city_id > 0: city = self.get_layer(path) - assert isinstance(city, db.City) + assert isinstance(city, City) self.notify_city_number_set(path) if city.leader is None: @@ -2822,7 +2848,8 @@ def on_exception(self, e): self.send_error("{}: {}".format(type(e).__name__, str(e))) def on_finish(self): - self.notify_layer_departure(True) + if isinstance(self.session, FMPSession): + self.notify_layer_departure(True) # TODO: Backport session_layer_end self.session.disconnect() 
self.session.delete() diff --git a/mh/pat_item.py b/mh/pat_item.py index c1676d4..7d93455 100644 --- a/mh/pat_item.py +++ b/mh/pat_item.py @@ -7,10 +7,11 @@ import struct from collections import OrderedDict + from mh.constants import pad +from mh.state_models import Server, Gate, City from other.utils import to_bytearray, get_config, get_external_ip, \ GenericUnpacker -from mh.database import Server, Gate, City class ItemType: @@ -354,6 +355,8 @@ def unpack_bytes(data, offset=0): class PatData(OrderedDict): + # TODO: type hint the whole module and add OrderedDict generic stub + # see https://mypy.readthedocs.io/en/stable/runtime_troubles.html """Pat structure holding items.""" FIELDS = ( (1, "field_0x01"), @@ -363,9 +366,11 @@ class PatData(OrderedDict): ) def __len__(self): + # type: () -> int return len(self.pack()) def __repr__(self): + # type: () -> str items = [ (index, value) for index, value in self.items() @@ -380,6 +385,7 @@ def __repr__(self): ) def __getattr__(self, name): + # type: (str) -> Item for field_id, field_name in self.FIELDS: if name == field_name: if field_id not in self: @@ -388,6 +394,7 @@ def __getattr__(self, name): raise AttributeError("Unknown field: {}".format(name)) def __setattr__(self, name, value): + # type: (str, Item) -> None for field_id, field_name in self.FIELDS: if name == field_name: if not isinstance(value, Item): @@ -399,6 +406,7 @@ def __setattr__(self, name, value): raise AttributeError("Cannot set unknown field: {}".format(name)) def __delattr__(self, name): + # type: (str) -> None for field_id, field_name in self.FIELDS: if name == field_name: del self[field_id] @@ -406,13 +414,19 @@ def __delattr__(self, name): return OrderedDict.__delattr__(self, name) def __setitem__(self, key, value): + # type: (int, Item) -> None if not isinstance(key, int) or not (0 <= key <= 255): raise IndexError("index must be a valid numeric value") elif not isinstance(value, Item): raise ValueError("{!r} not a valid PAT item".format(value)) 
return OrderedDict.__setitem__(self, key, value) + def __getitem__(self, key): + # type: (int) -> Item + return OrderedDict.__getitem__(self, key) + def __contains__(self, key): + # type: (int | str | object) -> bool if isinstance(key, str): for field_id, field_name in self.FIELDS: if field_name == key: @@ -426,13 +440,19 @@ def __contains__(self, key): return OrderedDict.__contains__(self, key) + def items(self): + # type: () -> list[tuple[int, Item]] + return OrderedDict.items(self) + def field_name(self, index): + # type: (int) -> str for field_id, field_name in self.FIELDS: if index == field_id: return field_name return "field_0x{:02x}".format(index) def pack(self): + # type: () -> bytes """Pack PAT items.""" items = [ (index, value) @@ -445,6 +465,7 @@ def pack(self): ) def pack_fields(self, fields): + # type: (set[int]) -> bytes """Pack PAT items specified fields.""" items = [ (index, value) @@ -458,6 +479,7 @@ def pack_fields(self, fields): @classmethod def unpack(cls, data, offset=0): + # type: (type[PatData], bytes, int) -> PatData obj = cls() field_count, = struct.unpack_from(">B", data, offset) offset += 1 @@ -489,11 +511,36 @@ def unpack(cls, data, offset=0): return obj def assert_fields(self, fields): - items = set(self.keys()) + # type: (set[int]) -> None + items = set(self.keys()) # type: set[int] fields = set(fields) message = "Fields mismatch: {}\n -> Expected: {}".format(items, fields) assert items == fields, message + def filter_fields(self, fields): + # type: (set[int]) -> list[tuple[int, str, Item]] + """Filter PatData excess of information. + + Places using this method should be investigated to make sure there is + no oversight (or more reverse engineering needed) regarding the + data structure and packet used. 
+ """ + items = set(self.keys()) # type: set[int] + fields = set(fields) + missing_fields = fields - items + message = "Can't filter missing fields: {}\n".format(", ".join( + self.field_name(field_id) for field_id in missing_fields + )) + assert not missing_fields, message + new_fields = items - fields + filtered_fields = [] # type: list[tuple[int, str, Item]] + for field_id in new_fields: + filtered_fields.append( + (field_id, self.field_name(field_id), self[field_id]) + ) + del self[field_id] + return filtered_fields + class DummyData(PatData): FIELDS = tuple() @@ -588,6 +635,10 @@ class UserSearchInfo(PatData): ) +# FIXME: Both LayerPath and LayerData introduce a strong dependency on +# mh.state_models, which should be avoided for this module. + + class LayerPath(object): STRUCT = struct.Struct(">IIHHH") diff --git a/mh/session.py b/mh/session.py index 698c342..aaed0c4 100644 --- a/mh/session.py +++ b/mh/session.py @@ -9,6 +9,28 @@ from other.utils import to_bytearray, to_str +try: + from typing import Any, TypedDict, TYPE_CHECKING # noqa: F401 + + if TYPE_CHECKING: + from mh.state import State # noqa: F401 + + SessionLocalInfo = TypedDict( + "SessionLocalInfo", { + "server_id": None | int, + "server_name": None | str, # maybe bytes to support accents? + "gate_id": None | int, + "gate_name": None | str, # ditto regarding special characters + "city_id": None | int, + "city_name": None | str, # ditto regarding special characters + "city_size": int, + "city_capacity": int, + "circle_id": None | int, + } + ) +except (ImportError, TypeError): + pass + DB = db.get_instance() @@ -25,6 +47,9 @@ class SessionState: class Session(object): """Server session class. + The goal of this class is to help writing the PAT packet logic by + handling complex interactions between servers/database behind the scene. 
+ TODO: - Finish the implementation """ @@ -40,7 +65,7 @@ def __init__(self, connection_handler): "city_size": 0, "city_capacity": 0, "circle_id": None, - } + } # type: SessionLocalInfo self.connection = connection_handler self.online_support_code = None self.request_reconnection = False @@ -52,9 +77,82 @@ def __init__(self, connection_handler): self.state = SessionState.UNKNOWN self.binary_setting = b"" self.search_payload = None - # TODO: Backport the server_id and serialisation logic + # TODO: Backport the server_id logic self.hunter_info = pati.HunterSettings() + def serialize(self): + # type: () -> dict[str, Any] + return { + "pat_ticket": self.pat_ticket, + # TODO: local_info should be safe to serialize on its own as dict + "local_info_server_id": self.local_info["server_id"], + "local_info_server_name": self.local_info["server_name"], + "local_info_gate_id": self.local_info["gate_id"], + "local_info_gate_name": self.local_info["gate_name"], + "local_info_city_id": self.local_info["city_id"], + "local_info_city_name": self.local_info["city_name"], + "local_info_city_size": self.local_info["city_size"], + "local_info_city_capacity": self.local_info["city_capacity"], + "local_info_circle_id": self.local_info["circle_id"], + "online_support_code": self.online_support_code, + "capcom_id": self.capcom_id, + "hunter_name": self.hunter_name, + "hunter_stats": self.hunter_stats, + "layer": self.layer, + "state": self.state, + "binary_setting": self.binary_setting, + "hunter_info": to_str(self.hunter_info.pack()) + } + + @staticmethod + def deserialize(obj): + # type: (dict[str, Any]) -> Session + session = Session(None) + session.pat_ticket = \ + str(obj["pat_ticket"]) \ + if obj["pat_ticket"] else obj["pat_ticket"] + session.local_info["server_id"] = \ + int(obj["local_info_server_id"]) \ + if obj["local_info_server_id"] else obj["local_info_server_id"] + session.local_info["server_name"] = \ + str(obj["local_info_server_name"]) \ + if obj["local_info_server_name"] 
else obj["local_info_server_name"] + session.local_info["gate_id"] = \ + int(obj["local_info_gate_id"]) \ + if obj["local_info_gate_id"] else obj["local_info_gate_id"] + session.local_info["gate_name"] = \ + str(obj["local_info_gate_name"]) \ + if obj["local_info_gate_name"] else obj["local_info_gate_name"] + session.local_info["city_id"] = \ + int(obj["local_info_city_id"]) \ + if obj["local_info_city_id"] else obj["local_info_city_id"] + session.local_info["city_name"] = \ + str(obj["local_info_city_name"]) \ + if obj["local_info_city_name"] else obj["local_info_city_name"] + session.local_info["city_size"] = \ + int(obj["local_info_city_size"]) \ + if obj["local_info_city_size"] else obj["local_info_city_size"] + session.local_info["city_capacity"] = \ + int(obj["local_info_city_capacity"]) \ + if obj["local_info_city_capacity"] \ + else obj["local_info_city_capacity"] + session.local_info["circle_id"] = \ + int(obj["local_info_circle_id"]) \ + if obj["local_info_circle_id"] else obj["local_info_circle_id"] + session.online_support_code = \ + str(obj["online_support_code"]) \ + if obj["online_support_code"] else obj["online_support_code"] + session.capcom_id = str(obj["capcom_id"]) + session.hunter_name = str(obj["hunter_name"]) + session.hunter_stats = obj["hunter_stats"] + session.layer = int(obj["layer"]) + session.state = int(obj["state"]) + session.binary_setting = obj["binary_setting"] + h_settings = bytearray(obj["hunter_info"], encoding='ISO-8859-1') + session.hunter_info = pati.HunterSettings().unpack(h_settings, + len(h_settings)) + return session + def get(self, connection_data): """Return the session associated with the connection data, if any.""" if hasattr(connection_data, "pat_ticket"): @@ -68,6 +166,7 @@ def get(self, connection_data): ) session = DB.get_session(self.pat_ticket) or self if session != self: + # TODO: Use a less error-prone check assert session.connection is None, "Session is already in use" session.connection = self.connection 
self.connection = None @@ -118,25 +217,111 @@ def get_users(self, first_index, count): def use_user(self, index, name): DB.use_user(self, index, name) + def find_user_by_capcom_id(self, capcom_id): + sessions = DB.find_users(capcom_id=capcom_id) + if sessions: + return sessions[0] + return None + + def find_users(self, capcom_id, hunter_name, first_index, count): + users = DB.find_users(capcom_id, hunter_name) + start = first_index - 1 + return users[start:start+count] + + def get_user_name(self, capcom_id): + return DB.get_user_name(capcom_id) + + def add_friend_request(self, capcom_id): + return DB.add_friend_request(self.capcom_id, capcom_id) + + def accept_friend(self, capcom_id, accepted=True): + return DB.accept_friend(self.capcom_id, capcom_id, accepted) + + def delete_friend(self, capcom_id): + return DB.delete_friend(self.capcom_id, capcom_id) + + def get_friends(self, first_index=None, count=None): + return DB.get_friends(self.capcom_id, first_index, count) + # TODO: server_index and recall logic def get_servers(self): - return DB.get_servers() + """LMP servers can request the server list""" + return DB.servers # FIXME: See FmpServer constructor + + +class FMPSession(Session): + """Specialized session wrapper for FMP server. + + Most methods should be unrelated to DB. + This prevent other servers to trigger FMP related features. + """ + def __init__(self, session): + # type: (Session) -> None + self._session = session + self._fmp_state = session.connection.server.fmp_state # type: State + + def __getattr__(self, name): + # type: (str) -> Any + """Called when the default attribute access fails.""" + return getattr(self._session, name) + + def __setattr__(self, name, value): + # type: (str, Any) -> None + if name.startswith("_"): + return object.__setattr__(self, name, value) + return setattr(self._session, name, value) + + def __eq__(self, other): + # type: (Session) -> bool + """Avoid comparison issues when wrapping around the session. 
+ + For instance, list removal is done by equality not identity, + in other words, FMPSession instances can get Session instances to be + removed from a list. + """ + if isinstance(other, FMPSession): + other = other._session + elif not isinstance(other, Session): + return NotImplemented + return self._session == other + + def __ne__(self, other): + # type: (Session) -> bool + x = self.__eq__(other) + if x is NotImplemented: + return NotImplemented + return not x + + def get(self, connection_data): + """Wrap existing session if needed""" + session = Session.get(self, connection_data) + if not isinstance(session, FMPSession): + session = FMPSession(session) + return session + + def FMP(self): + # type: () -> State + """Wrapper that can be repurposed later for cache/error handling.""" + return self._fmp_state + + def get_servers(self): + return self.FMP().get_servers() def get_server(self): assert self.local_info['server_id'] is not None - return DB.get_server(self.local_info['server_id']) + return self.FMP().get_server(self.local_info['server_id']) def get_gate(self): assert self.local_info['gate_id'] is not None - return DB.get_gate(self.local_info['server_id'], - self.local_info['gate_id']) + return self.FMP().get_gate(self.local_info['server_id'], + self.local_info['gate_id']) def get_city(self): assert self.local_info['city_id'] is not None - return DB.get_city(self.local_info['server_id'], - self.local_info['gate_id'], - self.local_info['city_id']) + return self.FMP().get_city(self.local_info['server_id'], + self.local_info['gate_id'], + self.local_info['city_id']) def get_circle(self): assert self.local_info['circle_id'] is not None @@ -197,10 +382,10 @@ def layer_detail_search(self, detailed_fields): (field_id, value) for field_id, field_type, value in detailed_fields ] # Convert detailed to simple optional fields - return DB.layer_detail_search(server_type, fields) + return self.FMP().layer_detail_search(server_type, fields) def join_server(self, server_id): 
- return DB.join_server(self, server_id) + return self.FMP().join_server(self, server_id) def get_layer_children(self): if self.layer == 0: @@ -218,74 +403,63 @@ def get_layer_sibling(self): def find_users_by_layer(self, server_id, gate_id, city_id, first_index, count, recursive=False): - if recursive: - players = DB.get_all_users(server_id, gate_id, city_id) - else: - layer = \ - DB.get_city(server_id, gate_id, city_id) if city_id else \ - DB.get_gate(server_id, gate_id) if gate_id else \ - DB.get_server(server_id) - players = list(layer.players) start = first_index - 1 + if recursive: + players = self.FMP().get_all_users(server_id, gate_id, city_id) + return players[start:start+count] + + layer = \ + self.FMP().get_city(server_id, gate_id, city_id) if city_id else \ + self.FMP().get_gate(server_id, gate_id) if gate_id else \ + self.FMP().get_server(server_id) + players = list(layer.players) return players[start:start+count] - def find_user_by_capcom_id(self, capcom_id): - sessions = DB.find_users(capcom_id=capcom_id) - if sessions: - return sessions[0] - return None - - def find_users(self, capcom_id, hunter_name, first_index, count): - users = DB.find_users(capcom_id, hunter_name) - start = first_index - 1 - return users[start:start+count] - - def get_user_name(self, capcom_id): - return DB.get_user_name(capcom_id) - def leave_server(self): - DB.leave_server(self, self.local_info["server_id"]) + self.FMP().leave_server(self, self.local_info["server_id"]) def get_gates(self): - return DB.get_gates(self.local_info["server_id"]) + return self.FMP().get_gates(self.local_info["server_id"]) def join_gate(self, gate_id): - DB.join_gate(self, self.local_info["server_id"], gate_id) + self.FMP().join_gate(self, self.local_info["server_id"], gate_id) self.state = SessionState.GATE def leave_gate(self): - DB.leave_gate(self) + self.FMP().leave_gate(self) self.state = SessionState.LOG_IN def get_cities(self): - return DB.get_cities(self.local_info["server_id"], - 
self.local_info["gate_id"]) + return self.FMP().get_cities(self.local_info["server_id"], + self.local_info["gate_id"]) def is_city_empty(self, city_id): - return DB.get_city(self.local_info["server_id"], - self.local_info["gate_id"], - city_id).get_state() == db.LayerState.EMPTY + return self.FMP().get_city( + self.local_info["server_id"], + self.local_info["gate_id"], + city_id + ).is_empty() def reserve_city(self, city_id, reserve): - return DB.reserve_city(self.local_info["server_id"], - self.local_info["gate_id"], - city_id, reserve) + return self.FMP().reserve_city(self.local_info["server_id"], + self.local_info["gate_id"], + city_id, reserve) def create_city(self, city_id, settings, optional_fields): - return DB.create_city(self, - self.local_info["server_id"], - self.local_info["gate_id"], - city_id, settings, optional_fields) + return self.FMP().create_city( + self, self.local_info["server_id"], self.local_info["gate_id"], + city_id, settings, optional_fields + ) def join_city(self, city_id): - DB.join_city(self, - self.local_info["server_id"], - self.local_info["gate_id"], - city_id) + self.FMP().join_city(self, + self.local_info["server_id"], + self.local_info["gate_id"], + city_id) self.state = SessionState.CITY def leave_city(self): - DB.leave_city(self) + self.FMP().leave_city(self) self.state = SessionState.GATE def try_transfer_city_leadership(self): @@ -386,15 +560,3 @@ def get_optional_fields(self): (1, (weapon_type << 24) | location), (2, hunter_rank << 16) ] - - def add_friend_request(self, capcom_id): - return DB.add_friend_request(self.capcom_id, capcom_id) - - def accept_friend(self, capcom_id, accepted=True): - return DB.accept_friend(self.capcom_id, capcom_id, accepted) - - def delete_friend(self, capcom_id): - return DB.delete_friend(self.capcom_id, capcom_id) - - def get_friends(self, first_index=None, count=None): - return DB.get_friends(self.capcom_id, first_index, count) diff --git a/mh/state.py b/mh/state.py new file mode 100644 index 
0000000..78238f7 --- /dev/null +++ b/mh/state.py @@ -0,0 +1,280 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2023-2025 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""FMP server state module. + +Each FMP server needs to know the in-game resources it manages: + - Servers + - Gates + - Cities + - Sessions (i.e. active server connections) + +TODO/FIXME: +Our current design isn't very scalable. +""" + + +from mh.database import get_instance as get_db +from mh.state_models import Server, ServerType + +try: + from typing import TypedDict, TYPE_CHECKING + if TYPE_CHECKING: + from mh.session import Session + from mh.state_models import Gate, City # noqa: F401 + + CapcomIDsInfo = TypedDict("CapcomIDsInfo", { + "name": bytes, + "session": None | Session + }) +except (ImportError, TypeError): + pass + + +def new_servers(): + # type: () -> list[Server] + # TODO: This logic was removed upstream to use harcoded values... + servers = [] # type: list[Server] + servers.extend([ + Server("Valor{}".format(i), ServerType.OPEN) + for i in range(1, 5) + ]) + servers.extend([ + Server("Beginners{}".format(i), ServerType.ROOKIE) + for i in range(1, 3) + ]) + servers.extend([ + Server("Veterans{}".format(i), ServerType.EXPERT) + for i in range(1, 3) + ]) + servers.extend([ + Server("Greed{}".format(i), ServerType.RECRUITING) + for i in range(1, 5) + ]) + return servers + + +class State(object): + """FMP server state class. + + Holds the required information not to bother too frequently: + - the database server + - the central server + + Ideally, this should allow the server to be resilient and operate even + with limited connectivity to the database and central server. 
+ + TODO: (backport) + - Backport cache/server_id logic + - Like the following methods if needed: + setup_server + register_pat_ticket + get_session + disconnect_session + delete_session (+ cache logic) + fetch_id + get_servers_version + update_players + update_capcom_id + session_ready + set_session_ready + close_cache + + FIXME: (backport) + - These methods should belong to db (imho) + new_pat_ticket -> db.generate_pat_ticket + fill self.session + use_capcom_id + use_user + get_users (i.e. LMP server) + """ + def __init__(self): + self.servers = new_servers() + # TODO: Backport Sessions/Capcom ID handling (currently unused) + self.sessions = { + # PAT Ticket => Owner's session + } # type: dict[str, Session] + self.capcom_ids = { + # Capcom ID => Owner's name and session + # NB: Not to be mistaken with database.capcom_ids! + } # type: dict[str, CapcomIDsInfo] + + def join_server(self, session, index): + # type: (Session, int) -> Server + if session.local_info["server_id"] is not None: + self.leave_server(session, session.local_info["server_id"]) + # TODO: Backport cache/joining another external server + # It might imply moving from (array) index to (dict) server_id + server = self.get_server(index) + server.players.add(session) + session.local_info["server_id"] = index + session.local_info["server_name"] = server.name + return server + + def leave_server(self, session, index): + # type: (Session, int) -> None + self.get_server(index).players.remove(session) + session.local_info["server_id"] = None + session.local_info["server_name"] = None + + def get_server_time(self): + # TODO: Use it at some point or remove it + pass + + def get_game_time(self): + # TODO: Use it at some point or remove it + pass + + def get_servers(self): + # type: () -> list[Server] + # TODO: Backport cache code + return self.servers + + def get_server(self, index): + # type: (int) -> Server + # TODO: Backport cache code + assert 0 < index <= len(self.servers), "Invalid server index" + return 
self.servers[index - 1] + + def get_gates(self, server_id): + # type: (int) -> list[Gate] + return self.get_server(server_id).gates + + def get_gate(self, server_id, index): + # type: (int, int) -> Gate + gates = self.get_gates(server_id) + assert 0 < index <= len(gates), "Invalid gate index" + return gates[index - 1] + + def join_gate(self, session, server_id, index): + # type: (Session, int, int) -> Gate + gate = self.get_gate(server_id, index) + gate.parent.players.remove(session) + gate.players.add(session) + session.local_info["gate_id"] = index + session.local_info["gate_name"] = gate.name + return gate + + def leave_gate(self, session): + # type: (Session) -> None + gate = self.get_gate(session.local_info["server_id"], + session.local_info["gate_id"]) + gate.parent.players.add(session) + gate.players.remove(session) + session.local_info["gate_id"] = None + session.local_info["gate_name"] = None + + def get_cities(self, server_id, gate_id): + # type: (int, int) -> list[City] + return self.get_gate(server_id, gate_id).cities + + def get_city(self, server_id, gate_id, index): + # type: (int, int, int) -> City + cities = self.get_cities(server_id, gate_id) + assert 0 < index <= len(cities), "Invalid city index" + return cities[index - 1] + + def reserve_city(self, server_id, gate_id, index, reserve): + # type: (int, int, int, bool) -> bool + city = self.get_city(server_id, gate_id, index) + with city.lock(): + if reserve and city.is_reserved(): + return False + city.reserve(reserve) + return True + + def get_all_users(self, server_id, gate_id, city_id): + # type: (int, int, None | int) -> list[tuple[int, Session]] + """Search for users in layers and its children. + + Let's assume wildcard search isn't possible for servers and gates. + A wildcard search happens when the id is zero. 
+ """ + assert 0 < server_id, "Invalid server index" + assert 0 < gate_id, "Invalid gate index" + gate = self.get_gate(server_id, gate_id) + users = list(gate.players) + cities = [ + self.get_city(server_id, gate_id, city_id) + ] if city_id else self.get_cities(server_id, gate_id) + for city in cities: + users.extend(list(city.players)) + return users + + def find_users(self, capcom_id="", hunter_name=b""): + # type: (str, bytes) -> list[Session] + assert capcom_id or hunter_name, "Search can't be empty" + users = [] # type: list[Session] + for user_id, user_info in self.capcom_ids.items(): + session = user_info["session"] + if not session: + continue + if capcom_id and capcom_id not in user_id: + continue + if hunter_name and \ + hunter_name.lower() not in user_info["name"].lower(): + continue + users.append(session) + # TODO: Backport cache/central code + # FIXME: We need to rely on DB meanwhile... + assert len(users) == 0, "State doesn't save IDs/Session yet" + users.extend(get_db().find_users( + capcom_id=capcom_id, + hunter_name=hunter_name + )) + return users + + def create_city(self, session, server_id, gate_id, index, + settings, optional_fields): + city = self.get_city(server_id, gate_id, index) + with city.lock(): + city.optional_fields = optional_fields + city.leader = session + return city + + def join_city(self, session, server_id, gate_id, index): + # type: (Session, int, int, int) -> City + city = self.get_city(server_id, gate_id, index) + with city.lock(): + city.parent.players.remove(session) + city.players.add(session) + session.local_info["city_name"] = city.name + session.local_info["city_id"] = index + return city + + def leave_city(self, session): + # type: (Session) -> None + city = self.get_city(session.local_info["server_id"], + session.local_info["gate_id"], + session.local_info["city_id"]) + with city.lock(): + city.parent.players.add(session) + city.players.remove(session) + if not city.get_population(): + city.clear_circles() + 
session.local_info["city_id"] = None + session.local_info["city_name"] = None + + def layer_detail_search(self, server_type, fields): + # TODO: Better document and test this + cities = [] + + def match_city(city, fields): + with city.lock(): + return all(( + field in city.optional_fields + for field in fields + )) + + for server in self.get_servers(): + if server.server_type != server_type: + continue + for gate in server.gates: + if not gate.get_population(): + continue + cities.extend([ + city + for city in gate.cities + if match_city(city, fields) + ]) + return cities diff --git a/mh/state_models.py b/mh/state_models.py new file mode 100644 index 0000000..9e5769e --- /dev/null +++ b/mh/state_models.py @@ -0,0 +1,585 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2025 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""FMP server state models module. + +This module contains ORM-like models to be (de)serialized. +""" + +import time + +from threading import RLock + +try: + from typing import Any, Literal, TYPE_CHECKING # noqa: F401 + + ServerTypeLiteral = Literal[1, 2, 3, 4] + LayerStateLiteral = Literal[0, 1, 2] + + if TYPE_CHECKING: + # FIXME: Session.deserialize introduces cyclic import + from mh.session import Session +except (ImportError, TypeError): + pass + + +RESERVE_DC_TIMEOUT = 40.0 + + +class ServerType: + OPEN = 1 + ROOKIE = 2 + EXPERT = 3 + RECRUITING = 4 + + +class LayerState: + JOINABLE = 0 + EMPTY = 1 + FULL = 2 + + +class Lockable(object): + def __init__(self): + self._lock = RLock() + + def lock(self): + return self + + def __enter__(self): + # Returns True if lock was acquired, False otherwise + return self._lock.acquire() + + def __exit__(self, *args): + # type: (Any) -> None + self._lock.release() + + +class Players(Lockable): + """Helper class to help retain player IDs.""" + def __init__(self, capacity): + # type: (int) -> None + assert capacity > 0, "Collection capacity can't be 
zero" + + self.slots = [ + None for _ in range(capacity) + ] # type: list[None | Session] + self.used = 0 + super(Players, self).__init__() + + def get_used_count(self): + # type: () -> int + return self.used + + def get_capacity(self): + # type: () -> int + return len(self.slots) + + def add(self, item): + # type: (Session) -> int + with self.lock(): + if self.used >= len(self.slots): + return -1 + + item_index = self.index(item) + if item_index != -1: + return item_index + + for i, v in enumerate(self.slots): + if v is not None: + continue + + self.slots[i] = item + self.used += 1 + return i + + return -1 + + def remove(self, item): + # type: (Session | int) -> bool + assert item is not None, "Item != None" + + with self.lock(): + if self.used < 1: + return False + + if isinstance(item, int): + if item >= self.get_capacity(): + return False + + self.slots[item] = None + self.used -= 1 + return True + + for i, v in enumerate(self.slots): + if v != item: + continue + + self.slots[i] = None + self.used -= 1 + return True + + return False + + def index(self, item): + # type: (Session) -> int + assert item is not None, "Item != None" + + for i, v in enumerate(self.slots): + if v == item: + return i + + return -1 + + def clear(self): + # type: () -> None + with self.lock(): + for i in range(self.get_capacity()): + self.slots[i] = None + + def find_first(self, **kwargs): + # type: (Any) -> None | Session + if self.used < 1: + return None + + for p in self.slots: + if p is None: + continue + + for k, v in kwargs.items(): + if getattr(p, k) != v: + break + else: + return p + + return None + + def find_by_capcom_id(self, capcom_id): + # type: (str) -> None | Session + return self.find_first(capcom_id=capcom_id) + + def __len__(self): + return self.used + + def __iter__(self): + if self.used < 1: + return + + for i, v in enumerate(self.slots): + if v is None: + continue + + yield i, v + + def serialize(self): + # type: () -> dict[str, Any] + if not self.used: + return 
{"capacity": len(self.slots)} + return { + "slots": [ + (p.serialize() + if p is not None + else None) + for p in self.slots + ], + "used": self.used + } + + @staticmethod + def deserialize(obj, parent): + # type: (dict[str, Any], None) -> Players + if "used" not in obj.keys(): + return Players(obj["capacity"]) + + from mh.session import Session # avoid cyclic import + + players = Players(len(obj["slots"])) + players.slots = [ + (Session.deserialize(p) + if p is not None + else None) + for p in obj["slots"] + ] + players.used = obj["used"] + return players + + +class Circle(Lockable): + def __init__(self, parent): + # type: (City) -> None + self.parent = parent + self.leader = None # type: None | Session + self.players = Players(4) + self.departed = False + self.quest_id = 0 + self.embarked = False # FIXME: Seems never used + self.password = None + self.remarks = None + + self.unk_byte_0x0e = 0 + super(Circle, self).__init__() + + def get_population(self): + # type: () -> int + return len(self.players) + + def get_capacity(self): + # type: () -> int + return self.players.get_capacity() + + def is_full(self): + # type: () -> bool + return self.get_population() == self.get_capacity() + + def is_empty(self): + # type: () -> bool + return self.leader is None + + def is_joinable(self): + # type: () -> bool + return not self.departed and not self.is_full() + + def has_password(self): + # type: () -> bool + return self.password is not None + + def reset_players(self, capacity): + # type: (int) -> None + with self.lock(): + self.players = Players(capacity) + + def reset(self): + # type: () -> None + with self.lock(): + self.leader = None + self.reset_players(4) + self.departed = False + self.quest_id = 0 + self.embarked = False + self.password = None + self.remarks = None + + self.unk_byte_0x0e = 0 + + def serialize(self): + # type: () -> dict[str, Any] + players = self.players.serialize() + if "used" not in players.keys(): + return {} + return { + "parent": None, + 
"leader": + self.leader.serialize() + if self.leader is not None + else None, + "players": players, + "departed": self.departed, + "quest_id": self.quest_id, + "embarked": self.embarked, + "password": self.password, + "remarks": self.remarks, + "unk_byte_0x0e": self.unk_byte_0x0e + } + + @staticmethod + def deserialize(obj, parent): + # type: (dict[str, Any], City) -> Circle + circle = Circle(parent) + if not obj.keys(): + return circle + + from mh.session import Session # avoid cyclic import + + circle.leader = \ + Session.deserialize(obj["leader"]) \ + if obj["leader"] is not None \ + else None + # TODO: Players class doesn't have "parent" member variable + circle.players = Players.deserialize(obj["players"], circle) + circle.departed = obj["departed"] + circle.quest_id = obj["quest_id"] + circle.embarked = obj["embarked"] + circle.password = obj["password"] + circle.remarks = obj["remarks"] + circle.unk_byte_0x0e = obj["unk_byte_0x0e"] + return circle + + +class City(Lockable): + LAYER_DEPTH = 3 + + def __init__(self, name, parent): + # type: (str, Gate) -> None + self.name = name + self.parent = parent + self.state = LayerState.EMPTY + self.players = Players(4) + self.optional_fields = [] + self.leader = None + self.reserved = None + self.circles = [ + # One circle per player + Circle(self) for _ in range(self.get_capacity()) + # NB: Might not work on the Japanese version, IIRC the limit is 5 + ] + super(City, self).__init__() + + def get_population(self): + # type: () -> int + return len(self.players) + + def in_quest_players(self): + # type: () -> int + return sum(p.is_in_quest() for _, p in self.players) + + def get_capacity(self): + # type: () -> int + return self.players.get_capacity() + + def get_state(self): + # type: () -> LayerStateLiteral + # FIXME: The following part was removed v + if self.reserved: + return LayerState.FULL + # TODO: ^ Backport it and make sure that wasn't a mistake + + size = self.get_population() + if size == 0: + return 
LayerState.EMPTY + elif size < self.get_capacity(): + return LayerState.JOINABLE + else: + return LayerState.FULL + + def is_empty(self): + # type: () -> bool + return self.get_state() == LayerState.EMPTY + + def get_pathname(self): + # type: () -> str + pathname = self.name # type: str + it = self.parent + # FIXME: Some tools don't like type-checking this loop at all + while it is not None: # type: ignore + pathname = it.name + "\t" + pathname + it = it.parent # type: Gate | Server | None + return pathname + + def get_first_empty_circle(self): + # type: () -> tuple[None, None] | tuple[Circle, int] + with self.lock(): + for index, circle in enumerate(self.circles): + if circle.is_empty(): + return circle, index + return None, None + + def get_circle_for(self, leader_session): + # type: (Session) -> tuple[None, None] | tuple[Circle, int] + with self.lock(): + for index, circle in enumerate(self.circles): + if circle.leader == leader_session: + return circle, index + return None, None + + def clear_circles(self): + # type: () -> None + with self.lock(): + for circle in self.circles: + circle.reset() + + def reserve(self, reserve): + # type: (bool) -> None + with self.lock(): + if reserve: + self.reserved = time.time() + else: + self.reserved = None + + def is_reserved(self): + # type: () -> bool + reserved_time = self.reserved # type: None | float + if reserved_time: + return time.time()-reserved_time < RESERVE_DC_TIMEOUT + return False + + def get_all_players(self): + # type: () -> list[Session] + with self.players.lock(): + return [p for _, p in self.players] + + def serialize(self): + # type: () -> dict[str, Any] + players = self.players.serialize() + if "used" not in players.keys(): + return {"name": self.name} + return { + "name": self.name, + "parent": None, # TODO: Why (not) serializing it? 
+ "state": self.state, + "players": players, + "optional_fields": self.optional_fields, + "leader": + self.leader.serialize() + if self.leader is not None + else None, + "reserved": self.reserved, + "circles": [c.serialize() for c in self.circles] + } + + @staticmethod + def deserialize(obj, parent): + # type: (dict[str, Any], Gate) -> City + if len(obj.keys()) < 2: + return City(obj["name"], None) + city = City( + str(obj["name"]) if obj["name"] is not None else obj["name"], + parent + ) + city.state = obj["state"] + city.players = Players.deserialize(obj["players"], parent) + city.optional_fields = obj["optional_fields"] + city.leader = \ + Session.deserialize(obj["leader"]) \ + if obj["leader"] is not None \ + else None + city.reserved = obj["reserved"] + city.circles = [Circle.deserialize(c, city) for c in obj["circles"]] + return city + + +class Gate(object): + LAYER_DEPTH = 2 + + def __init__(self, name, parent, city_count=40, player_capacity=100): + # type: (str, Server, int, int) -> None + self.name = name + self.parent = parent + self.state = LayerState.EMPTY + self.cities = [ + City("City{}".format(i), self) + for i in range(1, city_count+1) + ] + self.players = Players(player_capacity) + self.optional_fields = [] + + def get_population(self): + # type: () -> int + return len(self.players) + sum(( + city.get_population() + for city in self.cities + )) + + def get_capacity(self): + # type: () -> int + return self.players.get_capacity() + + def get_state(self): + # type: () -> LayerStateLiteral + size = self.get_population() + if size == 0: + return LayerState.EMPTY + elif size < self.get_capacity(): + return LayerState.JOINABLE + else: + return LayerState.FULL + + def get_all_players(self): + # type: () -> list[Session] + # TODO: Backport its use, e.g. 
in cache + players = [p for _, p in self.players] + for city in self.cities: + players += city.get_all_players() + return players + + def serialize(self): + # type: () -> dict[str, Any] + return { + "name": self.name, + "parent": None, + "state": self.state, + "cities": [c.serialize() for c in self.cities], + "players": self.players.serialize(), + "optional_fields": self.optional_fields + } + + @staticmethod + def deserialize(obj, parent): + # type: (dict[str, Any], Server) -> Gate + gate = Gate( + str(obj["name"]) if obj["name"] is not None else obj["name"], + parent + ) + gate.state = obj["state"] + gate.cities = [City.deserialize(c, gate) for c in obj["cities"]] + gate.players = Players.deserialize(obj["players"], gate) + gate.optional_fields = obj["optional_fields"] + return gate + + +class Server(object): + LAYER_DEPTH = 1 + + def __init__(self, name, server_type, gate_count=40, capacity=2000, + addr=None, port=None): + # type: (str, ServerTypeLiteral, int, int, None | str, None | int) -> None # noqa: E501 + # TODO: Backport the constructor change if needed + self.name = name + self.parent = None + self.server_type = server_type + # Public IP address + self.addr = addr # type: None | str + self.port = port # type: None | int + self.gates = [ + Gate("City Gate{}".format(i), self) + for i in range(1, gate_count+1) + ] + self.players = Players(capacity) + + def get_population(self): + # type: () -> int + return len(self.players) + sum(( + gate.get_population() for gate in self.gates + )) + + def get_capacity(self): + # type: () -> int + return self.players.get_capacity() + + def get_all_players(self): + # type: () -> list[Session] + # TODO: Backport its use, e.g. 
in cache + players = [p for _, p in self.players] + for gate in self.gates: + players = players + gate.get_all_players() + return players + + def serialize(self): + # type: () -> dict[str, Any] + return { + "name": self.name, + "parent": self.parent, + "server_type": self.server_type, + "addr": self.addr, + "port": self.port, + "gates": [g.serialize() for g in self.gates], + "players": self.players.serialize() + } + + @staticmethod + def deserialize(obj): + # type: (dict[str, Any]) -> Server + server = Server( + str(obj["name"]) if obj["name"] is not None + else obj["name"], + int(obj["server_type"]) if obj["server_type"] + else obj["server_type"], + addr=str(obj["addr"]) if obj["addr"] is not None + else obj["addr"], + port=int(obj["port"]) if obj["port"] is not None + else obj["port"] + ) + server.parent = obj["parent"] + server.gates = [Gate.deserialize(g, server) for g in obj["gates"]] + server.players = Players.deserialize(obj["players"], server) + return server From 7368bf3181eca6ace42a66531e69cce3c0de0f81 Mon Sep 17 00:00:00 2001 From: Sepalani Date: Mon, 4 Aug 2025 02:21:53 +0400 Subject: [PATCH 2/7] other: Add a config module with some helper classes Fix MaxThread option being ignored. Remove unused legacy_ssl parameter. 
--- mh/database.py | 9 +- mh/pat.py | 7 +- mh/pat_item.py | 6 +- other/config.py | 233 ++++++++++++++++++++++++++++++++++++++++++++++++ other/utils.py | 133 ++------------------------- 5 files changed, 255 insertions(+), 133 deletions(-) create mode 100644 other/config.py diff --git a/mh/database.py b/mh/database.py index 9c17bfe..9b6afb7 100644 --- a/mh/database.py +++ b/mh/database.py @@ -7,9 +7,12 @@ import inspect import random import sqlite3 -from other import utils + from threading import local as thread_local +from other import utils +from other.config import MySQLConfig + CHARSET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -415,7 +418,7 @@ def __init__(self): self.parent.__init__() from mysql import connector self.connection = connector.connect( - **utils.get_mysql_config("MYSQL") + **MySQLConfig().connect_kwargs() ) self.create_database() self.populate_database() @@ -659,7 +662,7 @@ def __init__(self, *args, **kwargs): CURRENT_DB = \ MySQLDatabase() \ - if utils.is_mysql_enabled("MYSQL") \ + if MySQLConfig().is_enabled() \ else TempSQLiteDatabase() # type: TempDatabase diff --git a/mh/pat.py b/mh/pat.py index 5d1e111..a4db9d7 100644 --- a/mh/pat.py +++ b/mh/pat.py @@ -8,7 +8,8 @@ import traceback from datetime import timedelta -from other.utils import Logger, get_config, get_external_ip, hexdump, to_str +from other.config import ServerConfig +from other.utils import Logger, get_external_ip, hexdump, to_str from mh.quest_utils import QuestLoader import mh.pat_item as pati @@ -646,7 +647,7 @@ def recvReqLmpConnect(self, packet_id, data, seq): TODO: I don't think it's related to LMP protocol. """ - config = get_config("LMP") + config = ServerConfig("LMP") self.sendAnsLmpConnect(get_external_ip(config), config["Port"], seq) def sendAnsLmpConnect(self, address, port, seq): @@ -1151,7 +1152,7 @@ def recvReqFmpInfo(self, packet_id, data, seq): # FIXME: Doesn't seem to make sense here, # e.g. 
on LMP server, as "FmpInfo" packet # server = self.session.join_server(index) - config = get_config("FMP") + config = ServerConfig("FMP") fmp_addr = get_external_ip(config) fmp_port = config["Port"] fmp_data = pati.FmpData() diff --git a/mh/pat_item.py b/mh/pat_item.py index 7d93455..1910a10 100644 --- a/mh/pat_item.py +++ b/mh/pat_item.py @@ -10,8 +10,8 @@ from mh.constants import pad from mh.state_models import Server, Gate, City -from other.utils import to_bytearray, get_config, get_external_ip, \ - GenericUnpacker +from other.config import ServerConfig +from other.utils import to_bytearray, get_external_ip, GenericUnpacker class ItemType: @@ -1027,7 +1027,7 @@ def pack(self): def get_fmp_servers(session, first_index, count): assert first_index > 0, "Invalid list index" - config = get_config("FMP") + config = ServerConfig("FMP") fmp_addr = get_external_ip(config) fmp_port = config["Port"] diff --git a/other/config.py b/other/config.py new file mode 100644 index 0000000..9ce5bd0 --- /dev/null +++ b/other/config.py @@ -0,0 +1,233 @@ +#! 
/usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2021-2025 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""Config module.""" + +from collections import OrderedDict + +try: + from collections.abc import Callable # noqa: F401 + from typing import TYPE_CHECKING, Any + IS_PYTHON2 = False +except ImportError: + TYPE_CHECKING = False # type: ignore + IS_PYTHON2 = True # type: ignore + +if TYPE_CHECKING: + OrderedDictT = OrderedDict[str, Any] + from configparser import RawConfigParser + from argparse import ArgumentParser # noqa: F401 +else: # workaround to avoid some type hinting issues + OrderedDictT = OrderedDict + if IS_PYTHON2: + from ConfigParser import RawConfigParser + else: + from configparser import RawConfigParser + + +CONFIG_FILE = "config.ini" + + +class ConfigLoader(RawConfigParser): + """Generic INI config loader class.""" + def __init__(self, config_path): + # type: (str) -> None + RawConfigParser.__init__(self, allow_no_value=True) + # Override this member to avoid options being lowercased + self.optionxform = lambda optionstr: str(optionstr) + self.read(config_path) + self._config_path = config_path + + def reload(self): + # type: () -> None + self.read(self._config_path) + + +class ConfigSection(OrderedDictT): + """ConfigSection helper class. + + This class loads all options from a config section into a dict. + """ + BOOL = tuple() # type: tuple[str, ...] + INT = tuple() # type: tuple[str, ...] + FLOAT = tuple() # type: tuple[str, ...] + STR = tuple() # type: tuple[str, ...] 
+ # Special cases + SP = {} # type: dict[str, Callable[[ConfigLoader, str], str]] + + def __init__(self, section_name, config_path=CONFIG_FILE): + # type: (str, str) -> None + super(ConfigSection, self).__init__() + self._config = ConfigLoader(config_path) + self._section_name = section_name + self._config_path = config_path + self._load_section() + self._validate() + + def _load_section(self): + # type: () -> None + for option in self._config.options(self._section_name): + self[option] = ( + self._config.getboolean(self._section_name, option) + if option in self.BOOL + else + self._config.getint(self._section_name, option) + if option in self.INT + else + self._config.getfloat(self._section_name, option) + if option in self.FLOAT + else + self.SP[option](self._config, self._section_name) + if option in self.SP + else + self._config.get(self._section_name, option) + ) + + def _validate(self): + # type: () -> None + missing_options = [] # type: list[str] + for typed_options in (self.BOOL, self.INT, self.FLOAT, self.STR): + missing_options.extend( + option + for option in typed_options + if option not in self + ) + missing_options.extend( + option + for option in self.SP + if option not in self + ) + message = '{}: section "{}" has missing option(s): {}'.format( + self._config_path, self._section_name, ", ".join(missing_options) + ) + assert not missing_options, message + + def reload(self): + # type: () -> None + """Reload the config.""" + self._config.reload() + self._load_section() + self._validate() + + +class BaseServerConfig(ConfigSection): + INT = ("Port",) + BOOL = ("UseSSL", "LogToConsole", "LogToFile", "LogToWindow") + STR = ("IP", "ExternalIP", "Name", "LogFilename") + SP = { + "SSLCert": + lambda cfg, section: cfg.get(section, "SSLCert") + or cfg.get("SSL", "DefaultCert"), + "SSLKey": + lambda cfg, section: cfg.get(section, "SSLKey") + or cfg.get("SSL", "DefaultKey") + } + + +class ServerConfig(BaseServerConfig): + """OPN/LMP/FMP/RFP server config.""" + 
INT = BaseServerConfig.INT + ("MaxThread",) + + +class MySQLConfig(ConfigSection): + BOOL = ("enabled",) + STR = ("user", "password", "host", "database", + "ssl_ca", "ssl_cert", "ssl_key") + + def __init__(self, section_name="MYSQL", config_path=CONFIG_FILE): + # type: (str, str) -> None + super(MySQLConfig, self).__init__(section_name, config_path) + + def is_enabled(self): + # type: () -> bool + return self["enabled"] # type: ignore + + def connect_kwargs(self): + # type: () -> dict[str, Any] + from mysql.connector.constants import ClientFlag + kwargs = { + k: v for k, v in self.items() + if k not in ("enabled",) + } # type: dict[str, Any] + kwargs.update({ + "charset": "utf8", + "autocommit": True, + "client_flags": [ClientFlag.SSL] if kwargs["ssl_ca"] else None, + "ssl_ca": kwargs["ssl_ca"] or None, + "ssl_cert": kwargs["ssl_cert"] or None, + "ssl_key": kwargs["ssl_key"] or None + }) + return kwargs + + +# TODO: Backport latest_patch and central config code + + +def argparse_from_config(config): + # type: (BaseServerConfig) -> ArgumentParser + """Argument parser from config.""" + import argparse + + def typebool(s): + # type: (bool | str) -> bool + if isinstance(s, bool): + return s + s = s.lower() + if s in ("on", "yes", "y", "true", "t", "1"): + return True + elif s in ("off", "no", "n", "false", "f", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-i", "--interactive", action="store_true", + dest="interactive", + help="create an interactive shell") + parser.add_argument("-d", "--debug_mode", action="store_true", + dest="debug_mode", + help="enable debug mode, disabling timeouts and \ + lower logging verbosity level") + parser.add_argument("-a", "--address", action="store", type=str, + default=config["IP"], dest="address", + help="set server address") + parser.add_argument("-p", "--port", 
action="store", type=int, + default=config["Port"], dest="port", + help="set server port") + parser.add_argument("-n", "--name", action="store", type=str, + default=config["Name"], dest="name", + help="set server name") + parser.add_argument("-s", "--use-ssl", action="store", type=typebool, + default=config["UseSSL"], dest="use_ssl", + help="use SSL protocol") + parser.add_argument("-c", "--ssl-cert", action="store", type=str, + default=config["SSLCert"], dest="ssl_cert", + help="set server SSL certificate") + parser.add_argument("-k", "--ssl-key", action="store", type=str, + default=config["SSLKey"], dest="ssl_key", + help="set server SSL private key") + parser.add_argument("-l", "--log-filename", action="store", type=str, + default=config["LogFilename"], dest="log_filename", + help="set server log filename") + parser.add_argument("--log-to-file", action="store", type=typebool, + default=config["LogToFile"], dest="log_to_file", + help="log output to file") + parser.add_argument("--log-to-console", action="store", type=typebool, + default=config["LogToConsole"], dest="log_to_console", + help="log output to console") + parser.add_argument("--log-to-window", action="store", type=typebool, + default=config["LogToWindow"], dest="log_to_window", + help="log output to a new window") + parser.add_argument("--dry-run", action="store_true", + dest="dry_run", + help="dry run to test the server") + parser.add_argument("-t", "--no-timeout", action="store_true", + dest="no_timeout", + help="disable player timeouts") + if "MaxThread" in config: + parser.add_argument("--max-thread", action="store", type=int, + default=config["MaxThread"], dest="max_thread", + help="set the server max thread count") + return parser diff --git a/other/utils.py index b55bb61..ed1b3d0 100644 --- a/other/utils.py +++ b/other/utils.py @@ -13,19 +13,18 @@ from collections import namedtuple from functools import partial from logging.handlers import TimedRotatingFileHandler -from other.config
import ServerConfig, argparse_from_config from other.debug import register_debug_signal, dry_run try: # Python 2 basestring # str, unicode - import ConfigParser except NameError: # Python 3 basestring = str - import configparser as ConfigParser from typing import Any # noqa: F401 -CONFIG_FILE = "config.ini" + LOG_FOLDER = "logs" @@ -216,59 +215,6 @@ def create_logger(name, level=logging.DEBUG, log_to_file="", return logger -def get_config(name, config_file=CONFIG_FILE): - """Get server config.""" - config = ConfigParser.RawConfigParser(allow_no_value=True) - config.read(config_file) - return { - "IP": config.get(name, "IP"), - "ExternalIP": config.get(name, "ExternalIP"), - "Port": config.getint(name, "Port"), - "Name": config.get(name, "Name"), - "MaxThread": config.getint(name, "MaxThread"), - "UseSSL": config.getboolean(name, "UseSSL"), - "SSLCert": - config.get(name, "SSLCert") or - config.get("SSL", "DefaultCert"), - "SSLKey": - config.get(name, "SSLKey") or - config.get("SSL", "DefaultKey"), - "LogFilename": config.get(name, "LogFilename"), - "LogToConsole": config.getboolean(name, "LogToConsole"), - "LogToFile": config.getboolean(name, "LogToFile"), - "LogToWindow": config.getboolean(name, "LogToWindow"), - } - - -def get_mysql_config(name, config_file=CONFIG_FILE): - """Get MySQL config.""" - config = ConfigParser.RawConfigParser(allow_no_value=True) - config.read(config_file) - ssl_ca = config.get(name, "ssl_ca") or None - from mysql.connector.constants import ClientFlag - return { - "charset": "utf8", - "autocommit": True, - "user": config.get(name, "User"), - "password": config.get(name, "Password"), - "host": config.get(name, "Host"), - "database": config.get(name, "database"), - "client_flags": [ClientFlag.SSL] if ssl_ca else None, - "ssl_ca": ssl_ca, - "ssl_cert": config.get(name, "ssl_cert") or None, - "ssl_key": config.get(name, "ssl_key") or None - } - - -def is_mysql_enabled(name, config_file=CONFIG_FILE): - config = 
ConfigParser.RawConfigParser(allow_no_value=True) - config.read(config_file) - return config.getboolean(name, "Enabled") - - -# TODO: Backport latest_patch and central config code - - def get_default_ip(): # type: () -> str """Get the default IP address""" @@ -295,69 +241,6 @@ def get_external_ip(config): return config["ExternalIP"] or get_ip(config["IP"]) -def argparse_from_config(config): - """Argument parser from config.""" - import argparse - - def typebool(s): - if isinstance(s, bool): - return s - s = s.lower() - if s in ("on", "yes", "y", "true", "t", "1"): - return True - elif s in ("off", "no", "n", "false", "f", "0"): - return False - else: - raise argparse.ArgumentTypeError("Boolean value expected.") - - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("-i", "--interactive", action="store_true", - dest="interactive", - help="create an interactive shell") - parser.add_argument("-d", "--debug_mode", action="store_true", - dest="debug_mode", - help="enable debug mode, disabling timeouts and \ - lower logging verbosity level") - parser.add_argument("-a", "--address", action="store", type=str, - default=config["IP"], dest="address", - help="set server address") - parser.add_argument("-p", "--port", action="store", type=int, - default=config["Port"], dest="port", - help="set server port") - parser.add_argument("-n", "--name", action="store", type=str, - default=config["Name"], dest="name", - help="set server name") - parser.add_argument("-s", "--use-ssl", action="store", type=typebool, - default=config["UseSSL"], dest="use_ssl", - help="use SSL protocol") - parser.add_argument("-c", "--ssl-cert", action="store", type=str, - default=config["SSLCert"], dest="ssl_cert", - help="set server SSL certificate") - parser.add_argument("-k", "--ssl-key", action="store", type=str, - default=config["SSLKey"], dest="ssl_key", - help="set server SSL private key") - parser.add_argument("-l", "--log-filename", 
action="store", type=str, - default=config["LogFilename"], dest="log_filename", - help="set server log filename") - parser.add_argument("--log-to-file", action="store", type=typebool, - default=config["LogToFile"], dest="log_to_file", - help="log output to file") - parser.add_argument("--log-to-console", action="store", type=typebool, - default=config["LogToConsole"], dest="log_to_console", - help="log output to console") - parser.add_argument("--log-to-window", action="store", type=typebool, - default=config["LogToWindow"], dest="log_to_window", - help="log output to a new window") - parser.add_argument("--dry-run", action="store_true", - dest="dry_run", - help="dry run to test the server") - parser.add_argument("-t", "--no-timeout", action="store_true", - dest="no_timeout", - help="disable player timeouts") - return parser - - def wii_ssl_wrap_socket(sock, ssl_cert, ssl_key): """SSL wrapper for network sockets aiming Wii compatibility. @@ -398,7 +281,7 @@ def create_server(server_class, server_handler, address="0.0.0.0", port=8200, name="Server", max_thread=0, use_ssl=True, ssl_cert="server.crt", ssl_key="server.key", log_to_file=True, log_filename="server.log", - log_to_console=True, log_to_window=False, legacy_ssl=False, + log_to_console=True, log_to_window=False, debug_mode=False, no_timeout=False): """Create a server, its logger and the SSL context if needed.""" logger = create_logger( @@ -409,9 +292,11 @@ def create_server(server_class, server_handler, if not use_ssl: ssl_cert = None ssl_key = None - return server_class((address, port), server_handler, max_thread, logger, - debug_mode, ssl_cert=ssl_cert, ssl_key=ssl_key, - no_timeout=no_timeout) + return server_class( + (address, port), server_handler, + max_thread_count=max_thread, logger=logger, debug_mode=debug_mode, + ssl_cert=ssl_cert, ssl_key=ssl_key, no_timeout=no_timeout + ) server_base = namedtuple("ServerBase", ["name", "cls", "handler"]) @@ -422,7 +307,7 @@ def create_server_from_base(name, 
server_class, server_handler, args=None): If args is None, sys.argv is used (see ArgumentParser.parser_args). """ - config = get_config(name) + config = ServerConfig(name) # TODO: Backport central config code if needed parser = argparse_from_config(config) args = parser.parse_args(args) From aeea6eb881e37ef927b3eb06ed7d6610aec18fd1 Mon Sep 17 00:00:00 2001 From: Sepalani Date: Sat, 1 Nov 2025 20:19:08 +0400 Subject: [PATCH 3/7] other: Add a comment regarding TIME_STATE constant --- mh/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mh/constants.py b/mh/constants.py index b3efc77..1cda8f1 100644 --- a/mh/constants.py +++ b/mh/constants.py @@ -280,7 +280,7 @@ def make_binary_trading_post(): FMP_VERSION = 1 # TODO: Backport central and NATNEG constants -TIME_STATE = 0 +TIME_STATE = 0 # FIXME: Unused since quest rotation added, see time_utils IS_JAP = False From c461fae7668ed1f2618c0e7795afbc3d5300d09a Mon Sep 17 00:00:00 2001 From: Sepalani Date: Sat, 1 Nov 2025 20:54:20 +0400 Subject: [PATCH 4/7] other: Add config field to enable/disable servers --- config.ini | 4 ++++ master_server.py | 7 ++++--- mh/pat.py | 2 +- mh/server.py | 2 +- other/config.py | 33 ++++++++++++++++++++++++--------- other/python.py | 10 ++++++++++ other/utils.py | 20 +++++++++++++++----- 7 files changed, 59 insertions(+), 19 deletions(-) create mode 100644 other/python.py diff --git a/config.ini b/config.ini index 5d8edc5..e8ffdff 100644 --- a/config.ini +++ b/config.ini @@ -13,6 +13,7 @@ ssl_cert = ssl_key = [OPN] +Enabled = ON IP = 0.0.0.0 ExternalIP = Port = 8200 @@ -27,6 +28,7 @@ LogToFile = ON LogToWindow = ON [LMP] +Enabled = ON IP = 0.0.0.0 ExternalIP = Port = 8201 @@ -41,6 +43,7 @@ LogToFile = ON LogToWindow = ON [FMP] +Enabled = ON IP = 0.0.0.0 ExternalIP = Port = 8202 @@ -55,6 +58,7 @@ LogToFile = ON LogToWindow = ON [RFP] +Enabled = ON IP = 0.0.0.0 ExternalIP = Port = 8203 diff --git a/master_server.py b/master_server.py index dbe650d..47c187b 100644 --- 
a/master_server.py +++ b/master_server.py @@ -23,8 +23,9 @@ def create_servers(server_args): has_ui = False for module in (OPN, LMP, FMP, RFP): server, args = create_server_from_base(*module.BASE, args=server_args) - has_ui = has_ui or args.log_to_window - servers.append(server) + if server: + has_ui = has_ui or args.log_to_window + servers.append(server) return servers, has_ui @@ -73,7 +74,7 @@ def interactive_mode(local=locals()): if args.interactive: t.join() except KeyboardInterrupt: - print("Interrupt key was pressed, closing server...") + print("Interrupt key was pressed, closing servers...") except Exception: print('Unexpected exception caught...') traceback.print_exc() diff --git a/mh/pat.py b/mh/pat.py index a4db9d7..9df0753 100644 --- a/mh/pat.py +++ b/mh/pat.py @@ -42,7 +42,7 @@ class PatServer(server.BasicPatServer, Logger): def __init__(self, address, handler_class, max_thread_count=0, logger=None, debug_mode=False, ssl_cert=None, ssl_key=None, - no_timeout=False): + no_timeout=False, **kwargs): server.BasicPatServer.__init__( self, address, handler_class, max_thread_count, ssl_cert=ssl_cert, ssl_key=ssl_key diff --git a/mh/server.py b/mh/server.py index 8ba056d..0e68c7c 100644 --- a/mh/server.py +++ b/mh/server.py @@ -374,4 +374,4 @@ def close(self): self.selector = None self.worker_threads = [] self.__shutdown_request = False - self.info('Server Closed') + self.info('Server closed') diff --git a/other/config.py b/other/config.py index 9ce5bd0..61d9ded 100644 --- a/other/config.py +++ b/other/config.py @@ -6,21 +6,17 @@ from collections import OrderedDict -try: - from collections.abc import Callable # noqa: F401 - from typing import TYPE_CHECKING, Any - IS_PYTHON2 = False -except ImportError: - TYPE_CHECKING = False # type: ignore - IS_PYTHON2 = True # type: ignore +from other.python import PYTHON_VERSION, TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import Callable # noqa: F401 + from typing import Any OrderedDictT = OrderedDict[str, Any] 
from configparser import RawConfigParser from argparse import ArgumentParser # noqa: F401 else: # workaround to avoid some type hinting issues OrderedDictT = OrderedDict - if IS_PYTHON2: + if PYTHON_VERSION == 2: from ConfigParser import RawConfigParser else: from configparser import RawConfigParser @@ -112,8 +108,9 @@ def reload(self): class BaseServerConfig(ConfigSection): + SERVER_NAMES = tuple() # type: tuple[str, ...] INT = ("Port",) - BOOL = ("UseSSL", "LogToConsole", "LogToFile", "LogToWindow") + BOOL = ("UseSSL", "LogToConsole", "LogToFile", "LogToWindow", "Enabled") STR = ("IP", "ExternalIP", "Name", "LogFilename") SP = { "SSLCert": @@ -124,9 +121,18 @@ class BaseServerConfig(ConfigSection): or cfg.get("SSL", "DefaultKey") } + def to_argument_parser(self): + # type: () -> ArgumentParser + """Create a generic ArgumentParser from the server config. + + It can be used in a main function to parse command-line arguments.""" + # TODO: Move the code here when the refactoring is completed + return argparse_from_config(self) + class ServerConfig(BaseServerConfig): """OPN/LMP/FMP/RFP server config.""" + SERVER_NAMES = ("OPN", "LMP", "FMP", "RFP") INT = BaseServerConfig.INT + ("MaxThread",) @@ -164,6 +170,15 @@ def connect_kwargs(self): # TODO: Backport latest_patch and central config code +def config_from_name(name): + # type: (str) -> ServerConfig | CentralConfig + """Return the server config based on its name.""" + if name in ServerConfig.SERVER_NAMES: + return ServerConfig(name) + else: + raise NotImplementedError() + + def argparse_from_config(config): # type: (BaseServerConfig) -> ArgumentParser """Argument parser from config.""" diff --git a/other/python.py b/other/python.py new file mode 100644 index 0000000..5974a14 --- /dev/null +++ b/other/python.py @@ -0,0 +1,10 @@ +#! 
/usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2025 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""Python 2/3 helpers.""" + +import sys + +PYTHON_VERSION = sys.version_info.major +TYPE_CHECKING = False diff --git a/other/utils.py b/other/utils.py index ed1b3d0..495c75b 100644 --- a/other/utils.py +++ b/other/utils.py @@ -13,7 +13,7 @@ from collections import namedtuple from functools import partial from logging.handlers import TimedRotatingFileHandler -from other.config import ServerConfig, argparse_from_config +from other.config import config_from_name from other.debug import register_debug_signal, dry_run try: @@ -32,34 +32,40 @@ class Logger(object): """Generic logging class.""" def set_logger(self, logger): + # type: (Logger) -> None """Set logger.""" self.logger = logger def debug(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log a debug message.""" if not hasattr(self, "logger"): return return self.logger.debug(msg, *args, **kwargs) def info(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log a message.""" if not hasattr(self, "logger"): return return self.logger.info(msg, *args, **kwargs) def warning(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log a warning message.""" if not hasattr(self, "logger"): return return self.logger.warning(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log an error message.""" if not hasattr(self, "logger"): return return self.logger.error(msg, *args, **kwargs) def critical(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log a critical message.""" if not hasattr(self, "logger"): return @@ -282,7 +288,7 @@ def create_server(server_class, server_handler, use_ssl=True, ssl_cert="server.crt", ssl_key="server.key", log_to_file=True, log_filename="server.log", log_to_console=True, log_to_window=False, - debug_mode=False, 
no_timeout=False): + debug_mode=False, no_timeout=False, **kwargs): """Create a server, its logger and the SSL context if needed.""" logger = create_logger( name, level=logging.DEBUG if debug_mode else logging.INFO, @@ -295,7 +301,8 @@ def create_server(server_class, server_handler, return server_class( (address, port), server_handler, max_thread_count=max_thread, logger=logger, debug_mode=debug_mode, - ssl_cert=ssl_cert, ssl_key=ssl_key, no_timeout=no_timeout + ssl_cert=ssl_cert, ssl_key=ssl_key, no_timeout=no_timeout, + **kwargs ) @@ -307,9 +314,11 @@ def create_server_from_base(name, server_class, server_handler, args=None): If args is None, sys.argv is used (see ArgumentParser.parser_args). """ - config = ServerConfig(name) + config = config_from_name(name) + if not config["Enabled"]: + return None, args # TODO: Backport central config code if needed - parser = argparse_from_config(config) + parser = config.to_argument_parser() args = parser.parse_args(args) kwargs = { k: v for k, v in vars(args).items() @@ -324,6 +333,7 @@ def server_main(name, server_class, server_handler): server, args = create_server_from_base(name, server_class, server_handler) + assert server, "Server disabled by the config file" try: import threading From a45b0a25833ea6e4d2b7dcc725617a4cc585954d Mon Sep 17 00:00:00 2001 From: Sepalani Date: Wed, 5 Nov 2025 17:55:44 +0400 Subject: [PATCH 5/7] Fix Thread.join() edge cases Add some more type hints --- master_server.py | 48 ++++++++++++++++++++++++++++--------------- mh/server.py | 16 ++++++--------- other/debug.py | 49 +++++++++++++++++++++++++++++--------------- other/ui.py | 29 +++++++++++++++++++------- other/utils.py | 53 +++++++++++++++++++++++++++++++++++------------- 5 files changed, 132 insertions(+), 63 deletions(-) diff --git a/master_server.py b/master_server.py index 47c187b..eac27d9 100644 --- a/master_server.py +++ b/master_server.py @@ -14,22 +14,33 @@ import rfp_server as RFP from other.debug import register_debug_signal, 
dry_run +from other.python import TYPE_CHECKING from other.utils import create_server_from_base +if TYPE_CHECKING: + from argparse import Namespace # noqa: F401 + from collections.abc import Sequence # noqa: F401 + from typing import Any # noqa: F401 + + from mh.server import BasicPatServer # noqa: F401 + def create_servers(server_args): + # type: (Sequence[str]) -> tuple[list[BasicPatServer], bool] """Create servers and check if it has ui.""" - servers = [] + servers = [] # type: list[BasicPatServer] has_ui = False for module in (OPN, LMP, FMP, RFP): - server, args = create_server_from_base(*module.BASE, args=server_args) - if server: + server, args = create_server_from_base(*module.BASE, + cmd_args=server_args) # type: ignore[misc] # noqa: E501 + if server and args: has_ui = has_ui or args.log_to_window servers.append(server) return servers, has_ui def main(args): + # type: (Namespace) -> None """Master server main function.""" register_debug_signal() @@ -42,37 +53,40 @@ def main(args): for server in servers ] # TODO: Backport cache's logic (i.e. 
new thread, maintain_connection) - for thread in threads: - thread.start() def interactive_mode(local=locals()): + # type: (dict[str, Any]) -> None """Run an interactive python interpreter in another thread.""" import code code.interact(local=local) + repl_thread = threading.Thread(target=interactive_mode) + if has_ui: from other.ui import update as ui_update - ui_update() + else: + def ui_update(): + pass try: + ui_update() + for server_thread in threads: + server_thread.start() + if args.interactive: - t = threading.Thread(target=interactive_mode) - t.start() + repl_thread.start() if args.dry_run: dry_run() while threads: - for thread in threads: - if has_ui: - ui_update() - if not thread.is_alive(): - threads.remove(thread) + for server_thread in threads: + ui_update() + if not server_thread.is_alive(): + threads.remove(server_thread) break - thread.join(0.1) + server_thread.join(0.1) - if args.interactive: - t.join() except KeyboardInterrupt: print("Interrupt key was pressed, closing servers...") except Exception: @@ -82,6 +96,8 @@ def interactive_mode(local=locals()): finally: for server in servers: server.close() + if args.interactive and repl_thread.is_alive(): + repl_thread.join() if __name__ == "__main__": diff --git a/mh/server.py b/mh/server.py index 0e68c7c..dda20ec 100644 --- a/mh/server.py +++ b/mh/server.py @@ -12,14 +12,13 @@ import traceback from mh.time_utils import Timer +from other.python import PYTHON_VERSION, TYPE_CHECKING from other.utils import wii_ssl_wrap_socket -try: - # Python 3 +if TYPE_CHECKING or PYTHON_VERSION == 3: import queue import selectors -except ImportError: - # Python 2 +else: import Queue as queue import externals.selectors2 as selectors @@ -183,11 +182,7 @@ def fileno(self): return self.socket.fileno() def initialize_workers(self): - """Initialize workers queues/threads. - - This needs to be deferred, otherwise the close method might try to - join threads that aren't started yet when an error occurs early. 
- """ + """Initialize workers queues/threads.""" for n in range(self.max_threads): thread_queue = queue.Queue() thread = threading.Thread( @@ -368,7 +363,8 @@ def close(self): q.put((None, None, None), block=True) for t in self.worker_threads: - t.join() + if t.is_alive(): + t.join() self.worker_queues = [] self.selector = None diff --git a/other/debug.py b/other/debug.py index ad404b2..0574f7a 100644 --- a/other/debug.py +++ b/other/debug.py @@ -17,17 +17,23 @@ import traceback import signal -try: - # Python 2 - import ConfigParser -except ImportError: - # Python 3 - import configparser as ConfigParser +from other.python import PYTHON_VERSION, TYPE_CHECKING + +if TYPE_CHECKING or PYTHON_VERSION == 3: + from configparser import RawConfigParser +else: + from ConfigParser import RawConfigParser + +if TYPE_CHECKING: + from collections.abc import Callable # noqa: F401 + from types import FrameType # noqa: F401 + from typing import Any # noqa: F401 DEBUG_INI_PATH = "debug.ini" def debugpy_handler(sig, frame, addr="127.0.0.1", port="5678", **kwargs): + # type: (int, FrameType, str, str, **Any) -> None """Handler for debugpy on Visual Studio and VS Code. References: @@ -35,7 +41,7 @@ def debugpy_handler(sig, frame, addr="127.0.0.1", port="5678", **kwargs): https://code.visualstudio.com/docs/python/debugging https://learn.microsoft.com/visualstudio/python/debugging-python-in-visual-studio """ - import debugpy + import debugpy # type: ignore s = (addr, int(port)) # config's items are str debugpy.listen(s) print("Waiting for client on {}:{}\n".format(*s)) @@ -43,6 +49,7 @@ def debugpy_handler(sig, frame, addr="127.0.0.1", port="5678", **kwargs): def trepan_handler(sig, frame, **kwargs): + # type: (int, FrameType, **Any) -> None """Handler for trepan2/trepan3k. 
References: @@ -50,22 +57,24 @@ def trepan_handler(sig, frame, **kwargs): https://github.com/rocky/python3-trepan/ https://python2-trepan.readthedocs.io/en/latest/entry-exit.html """ - from trepan.api import debug + from trepan.api import debug # type: ignore debug() def pudb_handler(sig, frame, **kwargs): + # type: (int, FrameType, **Any) -> None """Handler for pudb on Linux and Cygwin. References: https://github.com/inducer/pudb https://documen.tician.de/pudb/ """ - import pudb + import pudb # type: ignore pudb.set_trace() def breakpoint_handler(sig, frame, **kwargs): + # type: (int, FrameType, **Any) -> None """PDB/breakpoint handler. References: @@ -80,6 +89,7 @@ def breakpoint_handler(sig, frame, **kwargs): def code_interact_handler(sig, frame, **kwargs): + # type: (int, FrameType, **Any) -> None """Python interpreter handler. References: @@ -99,25 +109,29 @@ def code_interact_handler(sig, frame, **kwargs): "PUDB": pudb_handler, "BREAKPOINT": breakpoint_handler, "CODE": code_interact_handler -} +} # type: dict[str, Callable[..., Any]] def load_config(path=DEBUG_INI_PATH): - config = ConfigParser.RawConfigParser() + # type: (str) -> RawConfigParser + config = RawConfigParser() config.read(path) return config def load_handler_config(name, config): + # type: (str, RawConfigParser) -> dict[str, str] | None if config.has_section(name) and config.getboolean(name, "Enabled"): return { k.lower(): v for k, v in config.items(name) if k.lower() != "enabled" } + return None def debug_signal_handler(sig, frame): + # type: (int, FrameType) -> None """Default debug signal handler. Might raise EINTR/IOError when occuring during some syscalls on Python 2. @@ -148,6 +162,7 @@ def debug_signal_handler(sig, frame): def register_debug_signal(fn=debug_signal_handler): + # type: (Callable[..., Any]) -> None """Register a debug handler on SIGBREAK (Windows) or SIGUSR1 (Linux). On Windows, press CTRL+Pause/Break to trigger. 
@@ -156,17 +171,19 @@ def register_debug_signal(fn=debug_signal_handler): Will raise ValueError exception if not called from the main thread. """ if hasattr(signal, "SIGBREAK"): - signal.signal(signal.SIGBREAK, fn) + signal.signal(signal.SIGBREAK, fn) # type: ignore else: - signal.signal(signal.SIGUSR1, fn) + signal.signal(signal.SIGUSR1, fn) # type: ignore def dry_run(delay=10.): + # type: (float) -> None """Dry run test.""" from threading import Timer - try: - from thread import interrupt_main - except ImportError: + + if TYPE_CHECKING or PYTHON_VERSION == 3: from _thread import interrupt_main + else: + from thread import interrupt_main Timer(delay, interrupt_main).start() diff --git a/other/ui.py b/other/ui.py index 7addf11..ef11f76 100644 --- a/other/ui.py +++ b/other/ui.py @@ -5,19 +5,28 @@ """UI helper module.""" import logging -try: - # Python 3.x + +from other.python import PYTHON_VERSION, TYPE_CHECKING + +if TYPE_CHECKING or PYTHON_VERSION == 3: import tkinter as tk import tkinter.scrolledtext as ScrolledText from queue import Queue -except ImportError: - # Python 2.x +elif PYTHON_VERSION == 2: import Tkinter as tk import ScrolledText from Queue import Queue -WINDOWS = [] -EMITTERS = Queue() +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any + QueueT = Queue[Callable[[], Any]] +else: + QueueT = Queue + + +WINDOWS = [] # type: list[LoggerTk] +EMITTERS = QueueT() class LoggingHandler(logging.Handler): @@ -29,10 +38,12 @@ class LoggingHandler(logging.Handler): """ def __init__(self, text): + # type: (ScrolledText.ScrolledText) -> None logging.Handler.__init__(self) self.text = text def emit(self, record): + # type: (logging.LogRecord) -> None msg = self.format(record) def append(): @@ -40,7 +51,8 @@ def append(): self.text.configure(state='normal') self.text.insert(tk.END, msg + '\n') self.text.configure(state='disabled') - self.text.yview(tk.END) # Autoscroll to the bottom + # Autoscroll to the bottom + self.text.yview(tk.END) 
# type: ignore # Won't work on Python3.x # self.text.after(0, append) @@ -51,6 +63,7 @@ class LoggerTk(tk.Tk): """Create a logging window.""" def __init__(self, *args, **kwargs): + # type: (*Any, **Any) -> None tk.Tk.__init__(self, *args, **kwargs) text = ScrolledText.ScrolledText(self, state='disabled') @@ -63,10 +76,12 @@ def __init__(self, *args, **kwargs): self.handler = LoggingHandler(text) def get_handler(self): + # type: () -> LoggingHandler """Return the window's logging.Handler instance.""" return self.handler def set_logger(self, logger): + # type: (logging.Logger) -> None """Add the window's logging.Handler to the logger.""" logger.addHandler(self.handler) diff --git a/other/utils.py b/other/utils.py index 495c75b..7f5f220 100644 --- a/other/utils.py +++ b/other/utils.py @@ -15,14 +15,17 @@ from logging.handlers import TimedRotatingFileHandler from other.config import config_from_name from other.debug import register_debug_signal, dry_run +from other.python import PYTHON_VERSION, TYPE_CHECKING -try: - # Python 2 - basestring # str, unicode -except NameError: - # Python 3 - basestring = str - from typing import Any # noqa: F401 +if TYPE_CHECKING or PYTHON_VERSION == 3: + basestring = str # Python 2: str, unicode +if TYPE_CHECKING: + from argparse import Namespace # noqa: F401 + from collections.abc import Sequence # noqa: F401 + from ssl import SSLSocket # noqa: F401 + from typing import Any, NamedTuple # noqa: F401 + + from mh.pat import PatServer, PatRequestHandler LOG_FOLDER = "logs" @@ -137,6 +140,7 @@ def handler(self, name, unpack_function, pack_function, def to_bytearray(data): + # type: (Any) -> bytearray """Python2/3 bytearray helper.""" if isinstance(data, basestring): return bytearray((ord(c) % 256 for c in data)) @@ -147,10 +151,12 @@ def to_bytearray(data): def to_bytes(data): + # type: (Any) -> bytes return bytes(to_bytearray(data)) def to_str(data): + # type: (Any) -> str """Python2/3 str helper.""" if isinstance(data, str): return data @@ 
-158,20 +164,24 @@ def to_str(data): def pad(s, size, p=b'\0'): + # type: (bytes, int, bytes) -> bytearray data = bytearray(s + p * max(0, size-len(s))) data[-1] = 0 return data def hexdump(data): + # type: (bytes | bytearray) -> str """Get data hexdump.""" data = bytearray(data) line_format = "{line:08x} | {hex:47} | {ascii}" def hex_helper(b): + # type: (int) -> str return "{:02x}".format(b) def ascii_helper(b): + # type: (int) -> str return chr(b) if 0x20 <= b < 0x7F else '.' return "\n".join( @@ -186,6 +196,7 @@ def ascii_helper(b): def create_logger(name, level=logging.DEBUG, log_to_file="", log_to_console=False, log_to_window=False): + # type: (str, int, str, bool, bool) -> logging.Logger """Create a logger.""" logger = logging.getLogger(name) logger.setLevel(level) @@ -248,6 +259,7 @@ def get_external_ip(config): def wii_ssl_wrap_socket(sock, ssl_cert, ssl_key): + # type: (socket.socket, str, str) -> SSLSocket """SSL wrapper for network sockets aiming Wii compatibility. References: @@ -289,6 +301,7 @@ def create_server(server_class, server_handler, log_to_file=True, log_filename="server.log", log_to_console=True, log_to_window=False, debug_mode=False, no_timeout=False, **kwargs): + # type: (type[PatServer], type[PatRequestHandler], str, int, str, int, bool, str | None, str | None, bool, str, bool, bool, bool, bool, **Any) -> PatServer # noqa: E501 """Create a server, its logger and the SSL context if needed.""" logger = create_logger( name, level=logging.DEBUG if debug_mode else logging.INFO, @@ -306,20 +319,29 @@ def create_server(server_class, server_handler, ) -server_base = namedtuple("ServerBase", ["name", "cls", "handler"]) +if TYPE_CHECKING: + # Python 2 doesn't support the class syntax + server_base = NamedTuple("server_base", [ + ("name", str), + ("cls", type[PatServer]), + ("handler", type[PatRequestHandler]) + ]) +else: + server_base = namedtuple("ServerBase", ["name", "cls", "handler"]) -def create_server_from_base(name, server_class, 
server_handler, args=None): +def create_server_from_base(name, server_class, server_handler, cmd_args=None): + # type: (str, type[PatServer], type[PatRequestHandler], Sequence[str] | None) -> tuple[PatServer, Namespace] | tuple[None, None] # noqa: E501 """Create a server based on its config parameters and supplied args. If args is None, sys.argv is used (see ArgumentParser.parser_args). """ config = config_from_name(name) if not config["Enabled"]: - return None, args + return None, None # TODO: Backport central config code if needed parser = config.to_argument_parser() - args = parser.parse_args(args) + args = parser.parse_args(cmd_args) kwargs = { k: v for k, v in vars(args).items() if k not in ("interactive", "dry_run") @@ -328,12 +350,13 @@ def create_server_from_base(name, server_class, server_handler, args=None): def server_main(name, server_class, server_handler): + # type: (str, type[PatServer], type[PatRequestHandler]) -> None """Create a server main based on its config parameters.""" register_debug_signal() server, args = create_server_from_base(name, server_class, server_handler) - assert server, "Server disabled by the config file" + assert server and args, "Server disabled by the config file" try: import threading @@ -350,11 +373,13 @@ def server_main(name, server_class, server_handler): if args.log_to_window: from other.ui import update as ui_update + else: + def ui_update(): + pass while thread.is_alive(): thread.join(0.1) # Timeout allows main thread to handle signals - if args.log_to_window: - ui_update() + ui_update() except KeyboardInterrupt: server.info("Interrupt key was pressed, closing server...") except Exception: From b5faee3dc591baa75e800a164ad1011474f8aab9 Mon Sep 17 00:00:00 2001 From: Sepalani Date: Sun, 9 Nov 2025 00:29:45 +0400 Subject: [PATCH 6/7] Refactor the network code and add a net_utils module Try to address several issues we have, as follows: - DRY principle - more documentation (docstrings, type-hints) - less `try..except` 
clutter and error silencing - more logging - unused variables removed - avoid RuntimeError while iterating+deleting items from a container - use SO_REUSEADDR to address some (not all) TIME_WAIT issues --- fmp_server.py | 4 +- master_server.py | 31 ++- mh/pat.py | 60 +++-- mh/server.py | 468 +++++++++++--------------------- other/net_utils.py | 659 +++++++++++++++++++++++++++++++++++++++++++++ other/utils.py | 43 +-- 6 files changed, 877 insertions(+), 388 deletions(-) create mode 100644 other/net_utils.py diff --git a/fmp_server.py b/fmp_server.py index 7f1e1eb..16bc328 100644 --- a/fmp_server.py +++ b/fmp_server.py @@ -28,8 +28,8 @@ def __init__(self, *args, **kwargs): class FmpRequestHandler(PatRequestHandler): """Basic FMP server request handler class.""" - def on_init(self): - PatRequestHandler.on_init(self) + def setup(self): + super(FmpRequestHandler, self).setup() self.session = FMPSession(self.session) def recvAnsConnection(self, packet_id, data, seq): diff --git a/master_server.py b/master_server.py index eac27d9..6634553 100644 --- a/master_server.py +++ b/master_server.py @@ -45,13 +45,13 @@ def main(args): register_debug_signal() servers, has_ui = create_servers(server_args=args.args) - threads = [ - threading.Thread( + threads_map = { + server: threading.Thread( target=server.serve_forever, name="{}.serve_forever".format(server.__class__.__name__) ) for server in servers - ] + } # TODO: Backport cache's logic (i.e. 
new thread, maintain_connection) def interactive_mode(local=locals()): @@ -66,12 +66,15 @@ def interactive_mode(local=locals()): from other.ui import update as ui_update else: def ui_update(): + # type: () -> None + """No-op.""" pass + # noinspection PyBroadException try: ui_update() - for server_thread in threads: - server_thread.start() + for thread in threads_map.values(): + thread.start() if args.interactive: repl_thread.start() @@ -79,13 +82,15 @@ def ui_update(): if args.dry_run: dry_run() - while threads: - for server_thread in threads: + while threads_map: + for server, thread in threads_map.items(): ui_update() - if not server_thread.is_alive(): - threads.remove(server_thread) + if not thread.is_alive(): + server.error("Server thread died: %s", thread.name) + del threads_map[server] + # TODO: Add restart option? break - server_thread.join(0.1) + thread.join(0.1) except KeyboardInterrupt: print("Interrupt key was pressed, closing servers...") @@ -95,7 +100,8 @@ def ui_update(): sys.exit(1) finally: for server in servers: - server.close() + server.shutdown() + server.server_close() if args.interactive and repl_thread.is_alive(): repl_thread.join() @@ -114,5 +120,4 @@ def ui_update(): # - no_timeout is currently available as a server argument as follows: parser.add_argument("args", nargs='*', help="arguments forwarded to all servers") - args = parser.parse_args() - main(args) + main(parser.parse_args()) diff --git a/mh/pat.py b/mh/pat.py index 9df0753..d0ad976 100644 --- a/mh/pat.py +++ b/mh/pat.py @@ -9,7 +9,7 @@ from datetime import timedelta from other.config import ServerConfig -from other.utils import Logger, get_external_ip, hexdump, to_str +from other.utils import get_external_ip, hexdump, to_str from mh.quest_utils import QuestLoader import mh.pat_item as pati @@ -37,17 +37,12 @@ g_binary_loader = QuestLoader("event/quest_rotation.json") -class PatServer(server.BasicPatServer, Logger): +class PatServer(server.BasicPatServer): """Generic PAT server 
class.""" - def __init__(self, address, handler_class, max_thread_count=0, - logger=None, debug_mode=False, ssl_cert=None, ssl_key=None, + def __init__(self, address, handler_class, logger=None, debug_mode=False, no_timeout=False, **kwargs): - server.BasicPatServer.__init__( - self, address, handler_class, max_thread_count, - ssl_cert=ssl_cert, ssl_key=ssl_key - ) - Logger.__init__(self) + server.BasicPatServer.__init__(self, address, handler_class, **kwargs) if logger: self.set_logger(logger) self.info("Running on {} port {}".format(*address)) @@ -122,7 +117,7 @@ class PatRequestHandler(server.BasicPatHandler): inaccurate. `unk` stands for `unknown`. """ - def on_init(self): + def setup(self): """Default PAT handler.""" self.server.info("Handle client from {}".format(self.client_address)) self.server.add_to_debug(self) @@ -134,6 +129,7 @@ def on_init(self): self.game_id = None self.natneg_url = b"natneg1.gs.nintendowifi.net" self.game_patched = True + return super(PatRequestHandler, self).setup() def try_send_packet(self, packet_id=0, data=b'', seq=0): """Send PAT packet and catch exceptions.""" @@ -215,7 +211,9 @@ def helper(line): # message is too long. 
except Exception: # Probably unreachable and was disconnected - self.server.warning("Failed to send a complete error message") + self.server.warning( + "Failed to send a complete error message to %s", self + ) finally: self.session.request_reconnection = False self.finish() @@ -2845,25 +2843,32 @@ def notify_circle_leave(self, circle_index, seq): def on_exception(self, e): # type: (Exception) -> None - self.server.error(traceback.format_exc()) + self.server.error( + "Exception occurred during processing of %s:\n%s", + self, traceback.format_exc().rstrip('\n') + ) self.send_error("{}: {}".format(type(e).__name__, str(e))) - def on_finish(self): - if isinstance(self.session, FMPSession): - self.notify_layer_departure(True) - # TODO: Backport session_layer_end - self.session.disconnect() - self.session.delete() - - self.server.del_from_debug(self) - self.server.info("Client finished!") - - def on_packet(self, data): - if not data: - self.finish() + def finish(self): + if self.is_finished(): return + try: + if isinstance(self.session, FMPSession): + self.notify_layer_departure(True) + # TODO: Backport session_layer_end + self.session.disconnect() + self.session.delete() + except Exception: + self.server.error( + "Failed to finish %s:\n%s", self, + traceback.format_exc().rstrip('\n') + ) + finally: + super(PatRequestHandler, self).finish() + self.server.del_from_debug(self) + self.server.info("%s finished!", self) - packet_id, data, seq = data + def on_packet(self, packet_id, data, seq): self.server.info( "RECV %s[ID=%08x; Seq=%04x]", PAT_NAMES.get(packet_id, "Packet"), @@ -2873,6 +2878,7 @@ def on_packet(self, data): self.dispatch(packet_id, data, seq) def send_packet(self, packet_id=0, data=b'', seq=0): + # type: (int, bytes, int) -> None super(PatRequestHandler, self).send_packet(packet_id, data, seq) self.server.info( "SEND %s[ID=%08x; Seq=%04x]", @@ -2895,7 +2901,7 @@ def dispatch(self, packet_id, data, seq): handler = getattr(self, name) return handler(packet_id, 
data, seq) - def on_tick(self): + def on_send(self): if not self.requested_connection: # TODO: Investigate why do we need to wait a certain amount of # seconds before sending the `ReqConnection` packet? diff --git a/mh/server.py b/mh/server.py index dda20ec..e6b56a9 100644 --- a/mh/server.py +++ b/mh/server.py @@ -6,14 +6,15 @@ import multiprocessing import random -import socket -import struct +import sys import threading import traceback from mh.time_utils import Timer +from other.net_utils import \ + PacketHandler, WiiSSLHandlerMixIn, SelectorsBaseServer from other.python import PYTHON_VERSION, TYPE_CHECKING -from other.utils import wii_ssl_wrap_socket +from other.utils import Logger if TYPE_CHECKING or PYTHON_VERSION == 3: import queue @@ -22,169 +23,120 @@ import Queue as queue import externals.selectors2 as selectors +if TYPE_CHECKING: + from typing import Any -class BasicPatHandler(object): - def __init__(self, socket, client_address, server): - # type: (socket.socket, tuple[str, int], BasicPatServer) -> None - self.socket = socket - self.client_address = client_address - self.server = server - self.finished = False - self.rw = threading.Lock() - self.setup() - - def fileno(self): - # type: () -> int - return self.socket.fileno() - - def setup(self): - self.rfile = self.socket.makefile('rb', -1) - self.wfile = self.socket.makefile('wb', 0) - - self.on_init() - - def on_init(self): - """Called after setup""" - pass - - def on_exception(self, e): - # type: (Exception) -> None - """Called when during recv/write an exception ocurred""" - pass - - def on_recv(self): - """Called when the socket have bytes to be readed - - ** This method would be called by the server thread - - """ - header = self.rfile.read(8) - if not len(header): - # The socket was closed by externally - return None - - if len(header) < 8: - # Invalid packet header - return None - - return self.recv_packet(header) - - def on_packet(self, data): - """ Called when there is a packet to be handled - 
- ** This method would be called from a worker thread (Not Thread Safe) - - """ - - def recv_packet(self, header): - """Receive PAT packet.""" - size, seq, packet_id = struct.unpack(">HHI", header) - data = self.rfile.read(size) - return packet_id, data, seq + # noinspection PyCompatibility + PacketQueue = queue.Queue[ + tuple["BasicPatHandler", Any, int] + | tuple[None, None, None] + ] +else: + PacketQueue = queue.Queue - def send_packet(self, packet_id=0, data=b'', seq=0): - """Send PAT packet.""" - self.wfile.write(struct.pack( - ">HHI", - len(data), seq, packet_id - )) - if data: - self.wfile.write(data) - def on_tick(self): - """Called every time the server tick +class BasicPatHandler(WiiSSLHandlerMixIn, PacketHandler): # type: ignore[misc] + """Dummy queueable PAT packet handler class.""" - ** Currently executed from the server thread + # Prevents indefinite recv from invalid headers + timeout = 2.0 + # noinspection PyAttributeOutsideInit + def setup(self): + # type: () -> None + """Prepare the handler and its default properties.""" + self.worker_index = 0 + return super(BasicPatHandler, self).setup() + + def handle_packet(self, seq, packet_id, data): # type: ignore[override] + # type: (int, int, bytes) -> None + """Add the received packet to the server's queue.""" + assert isinstance(self.server, BasicPatServer) + self.server.queue_work(self, (packet_id, data, seq), + selectors.EVENT_READ) + + def send_packet(self, packet_id, data, seq): # type: ignore[override] + # type: (int, bytes, int) -> None + """Change parameters order to match the generic one.""" + return super(BasicPatHandler, self).send_packet(seq, packet_id, data) + + def on_packet(self, packet_id, data, seq): + # type: (int, bytes, int) -> None + """Called when there is a packet to be handled + + This method would be called from a worker thread (Not Thread Safe) """ pass - def on_finish(self): - """Called before finish""" - pass - - def is_finished(self): - return self.finished - def 
finish(self): - """Called when the handler is being disposed""" +class BasicPatServer(SelectorsBaseServer, Logger): + """Basic PAT packet server with worker threads.""" - if self.finished: - return - - try: - self.on_finish() - except Exception: - self.server.error(traceback.format_exc()) - - self.finished = True - - try: - self.wfile.close() - except Exception: - pass + # noinspection PyAttributeOutsideInit + def server_activate(self): + # type: () -> None + """Set the server default properties.""" + self._random = random.SystemRandom() + self.write_watch = Timer() + self.write_timeout = 1 # Seconds + self.worker_threads = [] # type: list[threading.Thread] + self.worker_queues = [] # type: list[PacketQueue] + self.max_thread = \ + self.kwargs.get("max_thread") or multiprocessing.cpu_count() + return super(BasicPatServer, self).server_activate() + # noinspection PyBroadException + def _worker_target(self, work_queue): + # type: (PacketQueue) -> None + """Worker thread main loop to handle a PacketQueue.""" try: - self.rfile.close() - except Exception: - pass - - -class BasicPatServer(object): - - socket_queue_size = 5 - - address_family = socket.AF_INET + while not self.is_shut_down(): + try: + handler, packet, event = work_queue.get(block=True) + except queue.Empty: + continue - socket_type = socket.SOCK_STREAM + if self.is_shut_down() or handler is None or \ + packet is None: # extra check for type checkers + break # shutting down - def __init__(self, server_address, RequestHandlerClass, max_threads, - bind_and_activate=True, ssl_cert=None, ssl_key=None): - # type: (tuple[str, int], BasicPatHandler, int, bool, str|None, str|None) -> None # noqa: E501 - """Constructor. 
May be extended, do not override.""" - self.server_address = server_address - self.RequestHandlerClass = RequestHandlerClass - self.__is_shut_down = threading.Event() - self.__is_shut_down.set() - self.__shutdown_request = False - self.socket = socket.socket(self.address_family, self.socket_type) - self._random = random.SystemRandom() # type: random.SystemRandom - self.handlers = [] # type: list[BasicPatHandler] - self.worker_threads = [] # type: list[threading.Thread] - self.worker_queues = [] # type: list[queue.queue] - self.selector = selectors.DefaultSelector() - self.max_threads = max_threads or multiprocessing.cpu_count() - self.ssl_cert = ssl_cert - self.ssl_key = ssl_key - # TODO: Backport change required by central/cache if any - - if bind_and_activate: - try: - self.server_bind() - self.server_activate() - except Exception: - self.close() - raise - - def server_bind(self): - self.socket.bind(self.server_address) - self.server_address = self.socket.getsockname() + if handler.is_finished(): + continue - def server_activate(self): - self.socket.listen(0) + assert event == selectors.EVENT_READ - def fileno(self): - """Return server socket file descriptor. + try: + try: + handler.on_packet(*packet) + except Exception as e: + handler.on_exception(e) + + if handler.is_finished(): + self.shutdown_request(handler) + except: # noqa: E722 + self.error( + "Worker failure with %s:\n%s", handler, + traceback.format_exc().rstrip('\n') + ) + raise + finally: + self.info("Worker(%s) exiting...", + threading.current_thread().name) - Interface required by selector. 
+ def queue_work(self, handler, work_data, event): + # type: (BasicPatHandler, Any, int) -> None + """Add a packet to the handler's PacketQueue.""" + if handler.is_finished(): + return - """ - return self.socket.fileno() + thread_queue = self.worker_queues[handler.worker_index] + thread_queue.put((handler, work_data, event), block=True) def initialize_workers(self): + # type: () -> None """Initialize workers queues/threads.""" - for n in range(self.max_threads): - thread_queue = queue.Queue() + for n in range(self.max_thread): + thread_queue = PacketQueue() thread = threading.Thread( target=self._worker_target, args=(thread_queue,), @@ -194,180 +146,84 @@ def initialize_workers(self): self.worker_threads.append(thread) thread.start() - def serve_forever(self): - self.__is_shut_down.clear() - try: - self.initialize_workers() - - with self.selector as selector: - selector.register(self, selectors.EVENT_READ) - - write_watch = Timer() - write_timeout = 1 # Seconds - while not self.__shutdown_request: - ready = selector.select(write_timeout) - if self.__shutdown_request: - break - - try: - for (key, event) in ready: - selected = key.fileobj - if selected == self: - self.accept_new_connection() - else: - assert event == selectors.EVENT_READ - try: - packet = selected.on_recv() - if packet is None: - if selected.is_finished(): - self.remove_handler(selected) - continue - - self._queue_work(selected, packet, event) - except Exception as e: - selected.on_exception(e) - if selected.is_finished(): - self.remove_handler(selected) - except Exception: - self.error(traceback.format_exc()) - - if write_watch.elapsed() >= write_timeout: - try: - for handler in self.handlers: - try: - handler.on_tick() - except Exception as e: - handler.on_exception(e) - - if handler.is_finished(): - self.remove_handler(handler) - except Exception: - self.error(traceback.format_exc()) - finally: - write_watch.restart() - except Exception: - self.error(traceback.format_exc()) - finally: - 
self.__is_shut_down.set() + def finish_request(self, handler, client_address): # type: ignore[override] # noqa: E501 + # type: (BasicPatHandler, tuple[str, int]) -> None + """Finish the request handler construction.""" + handler.worker_index = self._random.randint(0, + len(self.worker_queues)-1) + return super(BasicPatServer, self).finish_request(handler, + client_address) + + def serve_forever(self, poll_interval=0.5): + # type: (float) -> None + """Start the worker threads and the server main loop.""" + self.initialize_workers() + return super(BasicPatServer, self).serve_forever(poll_interval) + + # noinspection PyBroadException + def service_actions(self): + # type: () -> None + """Called on each server loop. - def _worker_target(self, work_queue): - # type: (queue.Queue) -> None + Alternative to monitor write events which is CPU intensive. - while not self.__shutdown_request: + Reminder: MUST NOT RAISE EXCEPTIONS. + """ + if self.write_watch.elapsed() >= self.write_timeout: try: - handler, packet, event = work_queue.get(block=True) - except queue.Empty: - continue - - if self.__shutdown_request: - break - - if handler.is_finished(): - continue - - assert event == selectors.EVENT_READ + for handler in self.get_handlers(): + assert isinstance(handler, BasicPatHandler) - try: - try: - handler.on_packet(packet) - except Exception as e: - handler.on_exception(e) + try: + handler.on_send() + except Exception as e: + handler.on_exception(e) - if handler.is_finished(): - self.remove_handler(handler) + if handler.is_finished(): + self.shutdown_request(handler) except Exception: - self.error(traceback.format_exc()) + self.error("%s", traceback.format_exc().rstrip('\n')) + finally: + self.write_watch.restart() + return super(BasicPatServer, self).service_actions() - def accept_new_connection(self): + def server_close(self): # type: () -> None - + """Clean up the server and its worker threads.""" try: - client_socket, client_address = self.socket.accept() - except 
Exception as e: - self.error('Error accepting connection (1). {}'.format(e)) - return - - try: - # TODO: Find a cleaner way to process ill-formed packets. - # Currently, they get stuck on `packet = selected.on_recv()`, - # thus blocking the `serve_forever` method. - client_socket.settimeout(2.0) - - # TODO: Ensure this is the correct way to fix the server not - # accepting SSL connection anymore. - # - # See https://stackoverflow.com/a/68214507 - if self.ssl_cert and self.ssl_key: - client_socket = wii_ssl_wrap_socket( - client_socket, self.ssl_cert, self.ssl_key - ) - handler = self.RequestHandlerClass(client_socket, client_address, - self) - except Exception as e: - self.error('Error accepting connection (2). {}'.format(e)) - return - - handler.__worker_thread = \ - self._random.randint(0, len(self.worker_queues)-1) - - self.selector.register(handler, selectors.EVENT_READ) - self.handlers.append(handler) - - def _queue_work(self, handler, work_data, event): - # type: (BasicPatHandler, any, int) -> None - if handler.is_finished(): - return - - thread_queue = self.worker_queues[handler.__worker_thread] - thread_queue.put((handler, work_data, event), block=True) - - def remove_handler(self, handler): - # type: (BasicPatHandler) -> None - try: - self.handlers.remove(handler) - except Exception: - pass + super(BasicPatServer, self).server_close() + finally: + if not hasattr(self, "worker_queues"): + return # Server startup interrupted (bind error?) - try: - self.selector.unregister(handler) - except Exception: - pass + for q in self.worker_queues: + q.put((None, None, None), block=True) - try: - handler.finish() - except Exception: - pass + for t in self.worker_threads: + if t.is_alive(): + t.join() - try: - handler.socket.close() - except Exception: - pass + self.worker_queues = [] + self.worker_threads = [] - def close(self): - """Called to clean-up the server. + self.info('Server closed') - May be overridden. 
+ def handle_error(self, handler=None, client_address=None): # type: ignore[override] # noqa: E501 + # type: (None | BasicPatHandler, None | tuple[str, int]) -> None + """Custom error handler to use server's logger. + MUST NOT RAISE EXCEPTIONS. """ - self.__shutdown_request = True - self.socket.close() - self.__is_shut_down.wait() - - for h in self.handlers: - try: - h.finish() - except Exception: - pass - - for q in self.worker_queues: - q.put((None, None, None), block=True) - - for t in self.worker_threads: - if t.is_alive(): - t.join() - - self.worker_queues = [] - self.selector = None - self.worker_threads = [] - self.__shutdown_request = False - self.info('Server closed') + active_exception = sys.exc_info()[1] + # Handle client related exceptions + if handler and active_exception is not None: + handler.on_exception(active_exception) + return + # Handle server related exceptions + message = "Exception occurred during processing of {}".format( + handler if handler else "accepting client" + ) + if handler: + message += " from {}".format(client_address) if client_address \ + else " shutdown" + self.error("%s\n%s", message, traceback.format_exc().rstrip('\n')) diff --git a/other/net_utils.py b/other/net_utils.py new file mode 100644 index 0000000..0c49cee --- /dev/null +++ b/other/net_utils.py @@ -0,0 +1,659 @@ +#! 
/usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2025 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""Network utils module.""" + +import os +import socket +import ssl +import struct +import sys +import traceback + +from other.python import PYTHON_VERSION, TYPE_CHECKING +from threading import Event + +if PYTHON_VERSION == 3: + import selectors + from socketserver import StreamRequestHandler, TCPServer +elif not TYPE_CHECKING: + import externals.selectors2 as selectors + from SocketServer import \ + StreamRequestHandler as StreamRequestHandler_, \ + TCPServer as TCPServer_ + + # Fix Python2 issue with super() + # noinspection PyMissingOrEmptyDocstring + class TCPServer(object, TCPServer_): + pass + + # noinspection PyMissingOrEmptyDocstring + class StreamRequestHandler(object, StreamRequestHandler_): + pass +else: + import selectors + from socketserver import StreamRequestHandler, TCPServer + from typing import Any, Self, TypeAlias # noqa: F401 + IpAddress = str + Port = int + ClientAddress = tuple[IpAddress, Port] + ServerAddress = tuple[IpAddress, Port] + SelectorsEvents = list[tuple[selectors.SelectorKey, int]] + + +class SelectorsRequestHandler(StreamRequestHandler): + """Custom StreamRequestHandler with selectors support. + + It differs from the regular StreamRequestHandler as it needs to maintain + the socket connection. It will call the handler's `setup` method but won't + call its `handle` and `finish` methods to keep it alive. + + N.B.: In this context, `finish` refers to the RequestHandler class. So the + `finish` method should be used to clean up resources this class created. + The server class is still responsible for cleaning up the request/socket. 
+ + Class variables that may be overridden: + - rbufsize + - wbufsize + - timeout + - disable_nagle_algorithm + """ + + # noinspection PyMissingConstructor + def __init__(self, request, client_address, server): + # type: (socket.socket, ClientAddress, SelectorsBaseServer) -> None + """See socketserver.py source code for more information.""" + self.request = request + self.client_address = client_address + self.server = server + self.__finished = False + self.setup() + + def __str__(self): + # type: () -> str + """Custom string representation.""" + return "{}({}:{})".format(self.__class__.__name__, + *self.client_address) + + def fileno(self): + # type: () -> int + """Return the associated socket file descriptor.""" + return self.connection.fileno() + + def handle(self): + # type: () -> None + """Handle a single incoming request.""" + raise NotImplementedError("must be implemented in subclass") + + def finish(self): + # type: () -> None + """Terminate the request handler. + + MUST BE CALLED if overridden, failing to do so will disrupt the + `is_finished` method. + + MUST NOT RAISE EXCEPTIONS. + """ + self.__finished = True + try: + super(SelectorsRequestHandler, self).finish() + except socket.error as e: + # Closing wfile can sometimes raise an exception. + self.on_exception(e) + + def is_finished(self): + # type: () -> bool + """Check whether the connection is active.""" + # return self.wfile.closed + # ^ Not sure the above is reliable if super().finish() throws. + return self.__finished + + def on_recv(self): + # type: () -> Any + """Selectors read event callback.""" + pass + + def on_send(self): + # type: () -> None + """Selectors write event callback.""" + pass + + def on_exception(self, e): + # type: (BaseException) -> None + """Error callback that the server should call on error. + + MUST NOT RAISE EXCEPTIONS. 
+ """ + pass + + +class InvalidHeader(Exception): + """PacketHandler invalid header exception.""" + pass + + +class PacketHandler(SelectorsRequestHandler): + """SelectorsRequestHandler with packet specialization. + + The default packet format is as follows (RequestHeader + data): + - (uint16_t) packet_size + - (uint16_t) packet_sequence + - (uint32_t) packet_id + - (bytes) packet_data + """ + + HEADER_SIZE = 8 + PACKET_FORMAT = ">HHI" + + if TYPE_CHECKING: + PacketSeq = int # type: TypeAlias + PacketId = int # type: TypeAlias + RequestHeader = tuple[int, PacketSeq, PacketId] # size, sequence, id + RequestPacket = tuple[PacketSeq, PacketId, bytes] # sequence, id, data + # ^ Ideally, should be: tuple[*RequestHeader[1:], bytes] + + def handle(self): + # type: () -> None + """Handle a single incoming packet.""" + packet = self.on_recv() + if packet: + self.handle_packet(*packet) + else: + self.finish() + + def on_recv(self): + # type: () -> None | PacketHandler.RequestPacket + """When read is available, check for incoming packet header.""" + header = self.rfile.read(self.HEADER_SIZE) + if not header: + # The socket was closed + return None + + if len(header) < self.HEADER_SIZE: + # Invalid packet header + raise InvalidHeader(header) + + return self.recv_packet(header) + + def on_exception(self, e): + # type: (BaseException) -> None + """Close the connection on error. + + Example: + ``` + if isinstance(e, (socket.error, InvalidHeader)): + # Implement error handling here + ``` + + Reminder: MUST NOT RAISE EXCEPTIONS. + """ + if not self.is_finished(): + # TODO: On non-fatal exception shouldn't we ignore it instead? 
+ # noinspection PyBroadException + try: + self.finish() + except Exception: + pass + + def recv_packet(self, data): + # type: (bytes) -> PacketHandler.RequestPacket + """Receive packet based on the provided header data.""" + header = struct.unpack(self.PACKET_FORMAT, + data) # type: PacketHandler.RequestHeader + packet_size = header[0] + data = self.rfile.read(packet_size) + return header[1:] + (data,) + + def send_packet(self, *args): + # type: (*Any) -> None + """Construct and send packet header and data.""" + params = args[:-1] + data = args[-1] + self.wfile.write(struct.pack(self.PACKET_FORMAT, len(data), *params)) + if data: + self.wfile.write(data) + + def handle_packet(self, *args): + # type: (*Any) -> None + """Called when a packet is received. + + It takes as arguments the unpacked RequestHeader (except its size) and + the received data. + """ + raise NotImplementedError("must be implemented in subclass") + + +class SelectorsBaseServer(TCPServer): + """Custom BaseServer with selectors support. + + Compared to the original BaseServer, this implementation also monitors + request handler class using selectors.select(). + + While SocketServer's MixIn might work, they shouldn't be used because + they handle each request individually preventing them to be monitored + continuously. + + Class variables that may be overridden: + - request_queue_size + - allow_reuse_address + - allow_reuse_port # Python 3 only + + Instance variables: + - server_address + - handler_class + - socket + - selector + - args + - kwargs + """ + + request_queue_size = 0 + if os.name != "nt": + # Windows SO_REUSEADDR behaves differently + # (e.g. 
allows 2 servers to bind+listen on the same address:port) + # TODO: Confirm this behaviour on old versions of Windows (XP ~ 7) + allow_reuse_address = True + + # noinspection PyMissingConstructor + def __init__(self, server_address, handler_class, *args, **kwargs): + # type: (ServerAddress, type[SelectorsRequestHandler], Any, Any) -> None # noqa: E501 + """Custom flexible constructors.""" + self.server_address = server_address + self.handler_class = handler_class # type: type[SelectorsRequestHandler] # noqa: E501 + self.socket = socket.socket(self.address_family, self.socket_type) + self.selector = selectors.DefaultSelector() + self.args = args + self.kwargs = kwargs + self.__is_shut_down = Event() + self.__shutdown_request = False + try: + self.server_bind() + self.server_activate() + except: # All errors # noqa: E722 + self.server_close() + raise + + def __enter__(self): + # type: () -> Self + return self + + def __exit__(self, *args): + # type: (Any) -> None + self.server_close() + + def __on_connection(self): + # type: () -> None + """Handle a new connection when the server socket is readable.""" + try: + request, client_address = self.get_request() + except socket.error: + self.handle_error() + return + + # noinspection PyBroadException + try: + if self.verify_request(request, client_address): + self.process_request(request, client_address) + else: + self.shutdown_request(request) + except Exception: + self.handle_error(request, client_address) + self.shutdown_request(request) + except: # Fatal error # noqa: E722 + self.shutdown_request(request) + raise + + def __on_read_event(self, handler): + # type: (SelectorsRequestHandler) -> None + """Handle request handler read events.""" + # noinspection PyBroadException + try: + handler.handle() + except Exception: + self.handle_error(handler, handler.client_address) + finally: + if handler.is_finished(): + self.shutdown_request(handler) + + def __on_write_event(self, handler): + # type: (SelectorsRequestHandler) -> 
None + """Handle request handler write events. + + N.B.: Not monitored by default. + """ + # noinspection PyBroadException + try: + handler.on_send() + except Exception: + self.handle_error(handler, handler.client_address) + + def __handle_events(self, events, n=None): + # type: (SelectorsEvents, None | int) -> None + """Handle at most `n` events and call service_actions.""" + try: + for key, mask in events[:n]: + obj = key.fileobj + if obj == self and mask & selectors.EVENT_READ: + self.__on_connection() + continue + assert isinstance(obj, SelectorsRequestHandler) + if mask & selectors.EVENT_READ: + self.__on_read_event(obj) + if mask & selectors.EVENT_WRITE: + self.__on_write_event(obj) + finally: + self.service_actions() + + def get_handlers(self): + # type: () -> list[SelectorsRequestHandler] + """Create a new list populated with monitored handlers.""" + handlers = [] # type: list[SelectorsRequestHandler] + mapping = self.selector.get_map() + if mapping: + # This copy prevents us from iterating over a view + # while removing items from it, which is unsafe and + # raises an exception with Python 3. 
+ for key in mapping.values(): + if key.fileobj == self: + continue + assert isinstance(key.fileobj, SelectorsRequestHandler) + handlers.append(key.fileobj) + return handlers + + def __select(self, timeout=None): + # type: (None | float) -> SelectorsEvents + """Safer wrapper around select.""" + try: + return self.selector.select(timeout) + except socket.error as e: + if not hasattr(e, "winerror") or e.winerror != 10038: + raise + # Operation attempted on something that is not a socket + for handler in self.get_handlers(): + if handler.is_finished(): + self.shutdown_request(handler) + return self.selector.select(timeout) + + def server_activate(self): + # type: () -> None + """Called by the constructor to activate the server.""" + super(SelectorsBaseServer, self).server_activate() # calls listen() + self.selector.register(self, selectors.EVENT_READ) + + def server_close(self): + # type: () -> None + """Called to clean up the server.""" + try: + for handler in self.get_handlers(): + # noinspection PyBroadException + try: + self.shutdown_request(handler) + except Exception: + self.handle_error(handler) + except: # Fatal error # noqa: E722 + # If we can't gracefully shut down a client, + # let's exit early to avoid further issues and + # their performance cost. + raise + finally: + self.selector.close() + super(SelectorsBaseServer, self).server_close() + + def get_request(self): # type: ignore[override] + # type: () -> tuple[SelectorsRequestHandler, ClientAddress] + """Get a single TCP connection request. + + Compared to TCPServer's, a compatible handler is returned instead + of a socket instance. 
+ """ + accepted_socket, address = super(SelectorsBaseServer, + self).get_request() + try: + handler = self.handler_class(accepted_socket, address, self) + except: # noqa: E722 + # Prevent clients from hanging if we fail to create the handler + accepted_socket.close() + raise + return handler, address + + def verify_request(self, request, client_address): # type: ignore[override] # noqa: E501 + # type: (SelectorsRequestHandler, ClientAddress) -> bool + """Verify the request before processing it.""" + return True + + def process_request(self, handler, client_address): # type: ignore[override] # noqa: E501 + # type: (SelectorsRequestHandler, ClientAddress) -> None + """Call finish_request. + + In Python's original implementation, this method was overridden + by ForkingMixIn and ThreadingMixIn. + """ + self.finish_request(handler, client_address) + + def finish_request(self, handler, client_address): # type: ignore[override] # noqa: E501 + # type: (SelectorsRequestHandler, ClientAddress) -> None + """Called when the request handler creation is finished.""" + # FIXME: Is selector thread-safe? + self.selector.register(handler, selectors.EVENT_READ) + + def shutdown_request(self, handler): # type: ignore[override] + # type: (SelectorsRequestHandler) -> None + """Shutdown the request gracefully and close its file descriptor. + + MUST NOT RAISE EXCEPTIONS. + """ + try: + # Flush the socket + handler.finish() + # SHUT_WR should make future reads empty + # SHUT_RD might raise exception on future reads + handler.request.shutdown(socket.SHUT_WR) + except socket.error: + self.handle_error(handler) + finally: + self.close_request(handler) + + def close_request(self, handler): # type: ignore[override] + # type: (SelectorsRequestHandler) -> None + """Close the request file descriptor. + + MUST NOT RAISE EXCEPTIONS. + """ + try: + # FIXME: Is selector thread-safe? 
+ self.selector.unregister(handler) + except (KeyError, ValueError): + # Might happen if an invalid handler/fd is provided. + # Can sometimes occur when finish_request raises an exception. + pass + finally: + handler.request.close() + + def serve_forever(self, poll_interval=0.5): + # type: (float) -> None + """Handle one request at a time until shutdown. + + Polls for shutdown every poll_interval seconds. Ignores + self.timeout. If you need to do periodic tasks, do them in + another thread. + """ + self.__shutdown_request = False + self.__is_shut_down.clear() + try: + while not self.__shutdown_request: + events = self.__select(poll_interval) + if self.__shutdown_request: + break + self.__handle_events(events) + finally: + self.__shutdown_request = False + self.__is_shut_down.set() + + def shutdown(self): + # type: () -> None + """Stops the serve_forever loop. + + Blocks until the loop has finished. This must be called while + serve_forever() is running in another thread, or it will + deadlock. + """ + self.__shutdown_request = True + self.__is_shut_down.wait() + + def is_shut_down(self): + # type: () -> bool + """Returns True if the server is shut down (gracefully or not).""" + return self.__is_shut_down.is_set() + + def service_actions(self): + # type: () -> None + """Automatically called after requests handling. + + Called by serve_forever/handle_request via __handle_events. + + May be overridden by a subclass / Mixin to implement any code that + needs to be run during the loop. + + MUST NOT RAISE EXCEPTIONS. + """ + pass + + def handle_request(self, n=1): + # type: (int) -> None + """Handle at most `n` requests, possibly blocking. + + Respects self.timeout. + """ + try: + from time import monotonic as time # Python 3 only + except ImportError: + from time import time + + # Support people who used socket.settimeout() to escape + # handle_request before self.timeout was available. 
+ timeout = self.socket.gettimeout() + if timeout is None: + timeout = self.timeout + elif self.timeout is not None: + timeout = min(timeout, self.timeout) + deadline = None if timeout is None else (time() + timeout) + + # Wait until a request arrives or the timeout expires - the loop is + # necessary to accommodate early wakeups due to EINTR. + while True: + events = self.__select(timeout) + if events: + return self.__handle_events(events, n=n) + if deadline is not None and time() >= deadline: + return self.handle_timeout() + + def handle_error(self, handler=None, client_address=None): # type: ignore[override] # noqa: E501 + # type: (None | SelectorsRequestHandler, None | ClientAddress) -> None + """Generic error handler, may be overridden. + + Some special cases: + 1. client_address is None: client's shutdown error + 2. handler is also None: server's accept error + + MUST NOT RAISE EXCEPTIONS. + """ + active_exception = sys.exc_info()[1] + if handler and active_exception is not None: + handler.on_exception(active_exception) + message = "Exception occurred during processing of {}".format( + handler if handler else "accepting client" + ) + if handler: + message += " from {}".format(client_address) if client_address \ + else " shutdown" + separators = '-' * 40 + sys.stderr.write("{sep}\n{msg}\n{err}{sep}\n".format( + sep=separators, msg=message, err=traceback.format_exc() + )) + sys.stderr.flush() + + +class SSLHandlerMixIn(object): + """SSL MixIn class.""" + + def get_ssl_context(self): + # type: () -> ssl.SSLContext + """Get SSL context, may be overridden.""" + return ssl.SSLContext(ssl.PROTOCOL_TLS) + + def __ssl_wrap_socket(self, sock): # pyright: ignore + # type: (socket.socket) -> socket.socket | ssl.SSLSocket + """Wrap the socket if SSL is enabled.""" + if not self.ssl_cert or not self.ssl_key: + return sock + + context = self.get_ssl_context() + if self.ssl_ca: + context.load_verify_locations(cafile=self.ssl_ca) + + context.load_cert_chain(self.ssl_cert, 
self.ssl_key) + + # Some Python versions might not enforce the timeout properly + timeout = getattr(self, "timeout", None) # type: None | float + sock.settimeout(timeout) + ssl_sock = context.wrap_socket(sock, server_side=True, + do_handshake_on_connect=False) + ssl_sock.settimeout(timeout) + try: + ssl_sock.do_handshake() + except: # noqa: E722 + ssl_sock.close() + raise + return ssl_sock + + # noinspection PyAttributeOutsideInit + def setup(self): + # type: () -> None + """Load SSL options from server.kwargs if SSL is enabled.""" + assert isinstance(self, SelectorsRequestHandler) + if hasattr(self.server, "kwargs"): + assert isinstance(self.server, SelectorsBaseServer) + kwargs = dict(self.server.kwargs) + self.ssl_ca = kwargs.get("ssl_ca") # type: None | str + self.ssl_cert = kwargs.get("ssl_cert") # type: None | str + self.ssl_key = kwargs.get("ssl_key") # type: None | str + # noinspection PyUnresolvedReferences + self.request = self.__ssl_wrap_socket(self.request) + super(SSLHandlerMixIn, self).setup() + + +class WiiSSLHandlerMixIn(SSLHandlerMixIn): + """SSL wrapper for network sockets aiming Wii compatibility. 
+ + References: + https://docs.python.org/2.7/library/ssl.html + https://docs.python.org/3/library/ssl.html + https://www.openssl.org/docs/man1.0.2/man1/ciphers.html + https://www.openssl.org/docs/man1.1.1/man1/ciphers.html + https://www.openssl.org/docs/man3.0/man1/openssl-ciphers.html + """ + def get_ssl_context(self): + # type: () -> ssl.SSLContext + """Get a compatible Wii SSL context.""" + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + if hasattr(ssl, "TLSVersion"): # Since Python 3.7 + # Required since Python 3.10 + context.minimum_version = ssl.TLSVersion.SSLv3 + wii_ciphers = ":".join([ + "AES128-SHA", "AES256-SHA", + # The following ones are often unavailable + "DES-CBC-SHA", "3DES-CBC-SHA", + "RC4-MD5", "RC4-SHA" + # NB: Python might enforce additional (unsupported) ciphers + # for security reasons + # TODO: Disable them in Dolphin to emulate the Wii accurately + ]) + + # Try to enforce legacy ciphers/weak cert chain (OpenSSL >= 1.1 only) + if ssl.OPENSSL_VERSION_INFO >= (1, 1): + wii_ciphers += ":@SECLEVEL=0" + + context.set_ciphers(wii_ciphers) + return context diff --git a/other/utils.py b/other/utils.py index 7f5f220..45228e7 100644 --- a/other/utils.py +++ b/other/utils.py @@ -22,7 +22,6 @@ if TYPE_CHECKING: from argparse import Namespace # noqa: F401 from collections.abc import Sequence # noqa: F401 - from ssl import SSLSocket # noqa: F401 from typing import Any, NamedTuple # noqa: F401 from mh.pat import PatServer, PatRequestHandler @@ -258,43 +257,6 @@ def get_external_ip(config): return config["ExternalIP"] or get_ip(config["IP"]) -def wii_ssl_wrap_socket(sock, ssl_cert, ssl_key): - # type: (socket.socket, str, str) -> SSLSocket - """SSL wrapper for network sockets aiming Wii compatibility. 
- - References: - https://docs.python.org/2.7/library/ssl.html - https://docs.python.org/3/library/ssl.html - https://www.openssl.org/docs/man1.0.2/man1/ciphers.html - https://www.openssl.org/docs/man1.1.1/man1/ciphers.html - https://www.openssl.org/docs/man3.0/man1/openssl-ciphers.html - """ - import ssl - - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - - if hasattr(ssl, "TLSVersion"): # Since Python 3.7 - # Required since Python 3.10 - context.minimum_version = ssl.TLSVersion.SSLv3 - wii_ciphers = ":".join([ - "AES128-SHA", "AES256-SHA", - # The following ones are often unavailable - "DES-CBC-SHA", "3DES-CBC-SHA", - "RC4-MD5", "RC4-SHA" - # NB: Python might enforce additional (unsupported) ciphers - # for security reasons - # TODO: Disable them in Dolphin to emulate the Wii accurately - ]) - - # Try to enforce legacy ciphers/weak cert chain (OpenSSL >= 1.1 only) - if ssl.OPENSSL_VERSION_INFO >= (1, 1): - wii_ciphers += ":@SECLEVEL=0" - - context.set_ciphers(wii_ciphers) - context.load_cert_chain(ssl_cert, ssl_key) - return context.wrap_socket(sock, server_side=True) - - def create_server(server_class, server_handler, address="0.0.0.0", port=8200, name="Server", max_thread=0, use_ssl=True, ssl_cert="server.crt", ssl_key="server.key", @@ -313,7 +275,7 @@ def create_server(server_class, server_handler, ssl_key = None return server_class( (address, port), server_handler, - max_thread_count=max_thread, logger=logger, debug_mode=debug_mode, + max_thread=max_thread, logger=logger, debug_mode=debug_mode, ssl_cert=ssl_cert, ssl_key=ssl_key, no_timeout=no_timeout, **kwargs ) @@ -387,4 +349,5 @@ def ui_update(): traceback.print_exc() sys.exit(1) finally: - server.close() + server.shutdown() + server.server_close() From 6436512159c9b8c9b75d1b1ee720bb888384ab4a Mon Sep 17 00:00:00 2001 From: Sepalani Date: Sun, 1 Mar 2026 22:55:42 +0400 Subject: [PATCH 7/7] mh.database: Add SafeMySQLConnection helper --- mh/database.py | 81 +++++++++++++++++++++++++++++++++++++++++++++++--- 
1 file changed, 77 insertions(+), 4 deletions(-) diff --git a/mh/database.py b/mh/database.py index 9b6afb7..4b87eb7 100644 --- a/mh/database.py +++ b/mh/database.py @@ -7,11 +7,23 @@ import inspect import random import sqlite3 +import time from threading import local as thread_local from other import utils from other.config import MySQLConfig +from other.python import TYPE_CHECKING + +try: + import mysql.connector +except ImportError: + pass + +if TYPE_CHECKING: + from mysql.connector.pooling import PooledMySQLConnection # noqa: F401 + from mysql.connector.abstracts import \ + MySQLConnectionAbstract, MySQLCursorAbstract # noqa: F401 CHARSET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -30,6 +42,11 @@ def new_random_str(length=6): return "".join(random.choice(CHARSET) for _ in range(length)) +class DatabaseError(Exception): + """Database exception class.""" + pass + + class TempDatabase(object): """A temporary database. @@ -401,6 +418,65 @@ def delete_friend(self, capcom_id, friend_id): return self.parent.delete_friend(capcom_id, friend_id) +class SafeMySQLConnection(object): + """Proxy object to safely reconnect to MySQL database. + + TODO: If the logic needs to be duplicated, we can move it into a + dedicated annotation/metaclass. 
+ """ + + def __init__(self, attempts=3, cooldown=60.0): + # type: (int, float) -> None + """Reconnection attempts before waiting a cooldown time.""" + self.__attempts = attempts + self.__cooldown = cooldown + self.__connection = None # noqa: E501 # type: MySQLConnectionAbstract | PooledMySQLConnection | None + self.__time = None # type: float | None + self.__restart() + + def __restart(self): + # type: () -> None + """Reload the config and restart the connection.""" + if self.__connection: + # noinspection PyBroadException + try: + self.__connection.close() + except Exception: + pass + finally: + self.__connection = None + + self.__connection = mysql.connector.connect( + **MySQLConfig().connect_kwargs() + ) + + def cursor(self): + # type: () -> MySQLCursorAbstract + """Override the cursor method.""" + # noinspection PyBroadException + try: + c = self.__connection.cursor() # type: ignore + self.__time = None + return c + except Exception: + now = time.time() + if self.__time and (now - self.__time) < self.__cooldown: + raise DatabaseError("Connection lost, reconnection cooldown") + + self.__time = now + for _ in range(self.__attempts): + # noinspection PyBroadException + try: + self.__restart() + if self.__connection is not None: + c = self.__connection.cursor() + self.__time = None + return c + except Exception: + pass + raise DatabaseError("Connection lost, reconnection attempts failed") + + class MySQLDatabase(TempDatabase): """Hybrid MySQL/TempDatabase. @@ -416,10 +492,7 @@ class MySQLDatabase(TempDatabase): def __init__(self): self.parent = super(MySQLDatabase, self) self.parent.__init__() - from mysql import connector - self.connection = connector.connect( - **MySQLConfig().connect_kwargs() - ) + self.connection = SafeMySQLConnection() self.create_database() self.populate_database()