From 4c328d45ba71f0ef698ba55a59850fdf9de6a13a Mon Sep 17 00:00:00 2001 From: Kevin Barkevich Date: Sat, 14 Jun 2025 23:22:29 -0400 Subject: [PATCH 1/5] Move FMP state to its own class Other changes: - city.is_empty method added - assert_fields when searching cities fixed - workaround for cyclic import (wip) --- fmp_server.py | 13 +- mh/database.py | 508 ++------------------------------------- mh/pat.py | 75 ++++-- mh/pat_item.py | 55 ++++- mh/session.py | 294 ++++++++++++++++++----- mh/state.py | 280 ++++++++++++++++++++++ mh/state_models.py | 585 +++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 1224 insertions(+), 586 deletions(-) create mode 100644 mh/state.py create mode 100644 mh/state_models.py diff --git a/fmp_server.py b/fmp_server.py index c179f05..7f1e1eb 100644 --- a/fmp_server.py +++ b/fmp_server.py @@ -9,18 +9,29 @@ import mh.pat_item as pati from mh.constants import PatID4 from mh.pat import PatRequestHandler, PatServer +from mh.session import FMPSession +from mh.state import State from other.utils import hexdump, server_base, server_main, to_str class FmpServer(PatServer): """Basic FMP server class.""" # TODO: Backport close cache logic - pass + def __init__(self, *args, **kwargs): + PatServer.__init__(self, *args, **kwargs) + self.fmp_state = State() + # TODO: Backport the cache server registration logic instead + import mh.database as db + db.get_instance().servers = self.fmp_state.servers class FmpRequestHandler(PatRequestHandler): """Basic FMP server request handler class.""" + def on_init(self): + PatRequestHandler.on_init(self) + self.session = FMPSession(self.session) + def recvAnsConnection(self, packet_id, data, seq): """AnsConnection packet.""" connection_data = pati.ConnectionData.unpack(data) diff --git a/mh/database.py b/mh/database.py index ad7c70b..9c17bfe 100644 --- a/mh/database.py +++ b/mh/database.py @@ -7,9 +7,8 @@ import inspect import random import sqlite3 -import time from other import utils -from threading import RLock, local as thread_local +from threading import local as thread_local CHARSET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -23,359 +22,11 @@ BLANK_CAPCOM_ID = "******" -RESERVE_DC_TIMEOUT = 40.0 - def new_random_str(length=6): return "".join(random.choice(CHARSET) for _ in range(length)) -class ServerType(object): - OPEN = 1 - ROOKIE = 2 - EXPERT = 3 - RECRUITING = 4 - - -class LayerState(object): - EMPTY = 1 - FULL = 2 - JOINABLE = 3 - - -class Lockable(object): - def __init__(self): - self._lock = RLock() - - def lock(self): - return self - - def __enter__(self): - # Returns True if lock was acquired, False otherwise - return self._lock.acquire() - - def __exit__(self, *args): - self._lock.release() - - -class Players(Lockable): - def __init__(self, capacity): - assert capacity > 0, "Collection capacity can't be zero" - - self.slots = [None for _ in range(capacity)] - self.used = 0 - super(Players, self).__init__() - - def get_used_count(self): - return self.used - - def get_capacity(self): - return len(self.slots) - - def add(self, item): - with self.lock(): - if self.used >= len(self.slots): - return -1 - - item_index = self.index(item) - if item_index != -1: - return item_index - - for i, v in enumerate(self.slots): - if v is not None: - continue - - self.slots[i] = item - self.used += 1 - return i - - return -1 - - def remove(self, item): - assert item is not None, "Item != None" - - with self.lock(): - if self.used < 1: - return False - - if isinstance(item, int): - if item >= self.get_capacity(): - return False - - self.slots[item] = None - self.used -= 1 - return True - - for i, v in enumerate(self.slots): - if v != item: - continue - - self.slots[i] = None - self.used -= 1 - return True - - return False - - def index(self, item): - assert item is not None, "Item != None" - - for i, v in enumerate(self.slots): - if v == item: - return i - - return -1 - - def clear(self): - with self.lock(): - for i in range(self.get_capacity()): - self.slots[i] = None - - def find_first(self, **kwargs): - if self.used < 1: - return None - - for p in self.slots: - if p is None: - continue - - for k, v in kwargs.items(): - if getattr(p, k) != v: - break - else: - return p - - return None - - def find_by_capcom_id(self, capcom_id): - return self.find_first(capcom_id=capcom_id) - - def __len__(self): - return self.used - - def __iter__(self): - if self.used < 1: - return - - for i, v in enumerate(self.slots): - if v is None: - continue - - yield i, v - - -class Circle(Lockable): - def __init__(self, parent): - # type: (City) -> None - self.parent = parent - self.leader = None - self.players = Players(4) - self.departed = False - self.quest_id = 0 - self.embarked = False - self.password = None - self.remarks = None - - self.unk_byte_0x0e = 0 - super(Circle, self).__init__() - - def get_population(self): - return len(self.players) - - def get_capacity(self): - return self.players.get_capacity() - - def is_full(self): - return self.get_population() == self.get_capacity() - - def is_empty(self): - return self.leader is None - - def is_joinable(self): - return not self.departed and not self.is_full() - - def has_password(self): - return self.password is not None - - def reset_players(self, capacity): - with self.lock(): - self.players = Players(capacity) - - def reset(self): - with self.lock(): - self.leader = None - self.reset_players(4) - self.departed = False - self.quest_id = 0 - self.embarked = False - self.password = None - self.remarks = None - - self.unk_byte_0x0e = 0 - - -class City(Lockable): - LAYER_DEPTH = 3 - - def __init__(self, name, parent): - # type: (str, Gate) -> None - self.name = name - self.parent = parent - self.state = LayerState.EMPTY - self.players = Players(4) - self.optional_fields = [] - self.leader = None - self.reserved = None - self.circles = [ - # One circle per player - Circle(self) for _ in range(self.get_capacity()) - ] - super(City, self).__init__() - - def get_population(self): - return len(self.players) - - def in_quest_players(self): - return sum(p.is_in_quest() for _, p in self.players) - - def get_capacity(self): - return self.players.get_capacity() - - def get_state(self): - if self.reserved: - return LayerState.FULL - - size = self.get_population() - if size == 0: - return LayerState.EMPTY - elif size < self.get_capacity(): - return LayerState.JOINABLE - else: - return LayerState.FULL - - def get_pathname(self): - pathname = self.name - it = self.parent - while it is not None: - pathname = it.name + "\t" + pathname - it = it.parent - return pathname - - def get_first_empty_circle(self): - with self.lock(): - for index, circle in enumerate(self.circles): - if circle.is_empty(): - return circle, index - return None, None - - def get_circle_for(self, leader_session): - with self.lock(): - for index, circle in enumerate(self.circles): - if circle.leader == leader_session: - return circle, index - return None, None - - def clear_circles(self): - with self.lock(): - for circle in self.circles: - circle.reset() - - def reserve(self, reserve): - with self.lock(): - if reserve: - self.reserved = time.time() - else: - self.reserved = None - - def reset(self): - with self.lock(): - self.state = LayerState.EMPTY - self.players.clear() - self.optional_fields = [] - self.leader = None - self.reserved = None - self.clear_circles() - - -class Gate(object): - LAYER_DEPTH = 2 - - def __init__(self, name, parent, city_count=40, player_capacity=100): - # type: (str, Server, int, int) -> None - self.name = name - self.parent = parent - self.state = LayerState.EMPTY - self.cities = [ - City("City{}".format(i), self) - for i in range(1, city_count+1) - ] - self.players = Players(player_capacity) - self.optional_fields = [] - - def get_population(self): - return len(self.players) + sum(( - city.get_population() - for city in self.cities - )) - - def get_capacity(self): - return self.players.get_capacity() - - def get_state(self): - size = self.get_population() - if size == 0: - return LayerState.EMPTY - elif size < self.get_capacity(): - return LayerState.JOINABLE - else: - return LayerState.FULL - - -class Server(object): - LAYER_DEPTH = 1 - - def __init__(self, name, server_type, gate_count=40, capacity=2000, - addr=None, port=None): - self.name = name - self.parent = None - self.server_type = server_type - self.addr = addr # public IP address - self.port = port - self.gates = [ - Gate("City Gate{}".format(i), self) - for i in range(1, gate_count+1) - ] - self.players = Players(capacity) - - def get_population(self): - return len(self.players) + sum(( - gate.get_population() for gate in self.gates - )) - - def get_capacity(self): - return self.players.get_capacity() - - -def new_servers(): - servers = [] - servers.extend([ - Server("Valor{}".format(i), ServerType.OPEN) - for i in range(1, 5) - ]) - servers.extend([ - Server("Beginners{}".format(i), ServerType.ROOKIE) - for i in range(1, 3) - ]) - servers.extend([ - Server("Veterans{}".format(i), ServerType.EXPERT) - for i in range(1, 3) - ]) - servers.extend([ - Server("Greed{}".format(i), ServerType.RECRUITING) - for i in range(1, 5) - ]) - return servers - - class TempDatabase(object): """A temporary database. @@ -403,7 +54,6 @@ def __init__(self): # Capcom ID => List of Capcom IDs # TODO: May need stable index, see Players class } - self.servers = new_servers() def get_support_code(self, session): """Get the online support code or create one.""" @@ -500,93 +150,8 @@ def get_users(self, session, first_index, count): ]) return capcom_ids - def join_server(self, session, index): - if session.local_info["server_id"] is not None: - self.leave_server(session, session.local_info["server_id"]) - server = self.get_server(index) - server.players.add(session) - session.local_info["server_id"] = index - session.local_info["server_name"] = server.name - return server - - def leave_server(self, session, index): - self.get_server(index).players.remove(session) - session.local_info["server_id"] = None - session.local_info["server_name"] = None - - def get_server_time(self): - pass - - def get_game_time(self): - pass - - def get_servers(self): - return self.servers - - def get_server(self, index): - assert 0 < index <= len(self.servers), "Invalid server index" - return self.servers[index - 1] - - def get_gates(self, server_id): - return self.get_server(server_id).gates - - def get_gate(self, server_id, index): - gates = self.get_gates(server_id) - assert 0 < index <= len(gates), "Invalid gate index" - return gates[index - 1] - - def join_gate(self, session, server_id, index): - gate = self.get_gate(server_id, index) - gate.parent.players.remove(session) - gate.players.add(session) - session.local_info["gate_id"] = index - session.local_info["gate_name"] = gate.name - return gate - - def leave_gate(self, session): - gate = self.get_gate(session.local_info["server_id"], - session.local_info["gate_id"]) - gate.parent.players.add(session) - gate.players.remove(session) - session.local_info["gate_id"] = None - session.local_info["gate_name"] = None - - def get_cities(self, server_id, gate_id): - return self.get_gate(server_id, gate_id).cities - - def get_city(self, server_id, gate_id, index): - cities = self.get_cities(server_id, gate_id) - assert 0 < index <= len(cities), "Invalid city index" - return cities[index - 1] - - def reserve_city(self, server_id, gate_id, index, reserve): - city = self.get_city(server_id, gate_id, index) - with city.lock(): - reserved_time = city.reserved - if reserve and reserved_time and \ - time.time()-reserved_time < RESERVE_DC_TIMEOUT: - return False - city.reserve(reserve) - return True - - def get_all_users(self, server_id, gate_id, city_id): - """Search for users in layers and its children. - - Let's assume wildcard search isn't possible for servers and gates. - A wildcard search happens when the id is zero. - """ - assert 0 < server_id, "Invalid server index" - assert 0 < gate_id, "Invalid gate index" - gate = self.get_gate(server_id, gate_id) - users = list(gate.players) - cities = [ - self.get_city(server_id, gate_id, city_id) - ] if city_id else self.get_cities(server_id, gate_id) - for city in cities: - users.extend(list(city.players)) - return users - - def find_users(self, capcom_id="", hunter_name=""): + def find_users(self, capcom_id="", hunter_name=b""): + # type: (str, bytes) -> list["Session"] assert capcom_id or hunter_name, "Search can't be empty" users = [] for user_id, user_info in self.capcom_ids.items(): @@ -606,59 +171,6 @@ def get_user_name(self, capcom_id): return "" return self.capcom_ids[capcom_id]["name"] - def create_city(self, session, server_id, gate_id, index, - settings, optional_fields): - city = self.get_city(server_id, gate_id, index) - with city.lock(): - city.optional_fields = optional_fields - city.leader = session - city.reserved = None - return city - - def join_city(self, session, server_id, gate_id, index): - city = self.get_city(server_id, gate_id, index) - with city.lock(): - city.parent.players.remove(session) - city.players.add(session) - session.local_info["city_name"] = city.name - session.local_info["city_id"] = index - return city - - def leave_city(self, session): - city = self.get_city(session.local_info["server_id"], - session.local_info["gate_id"], - session.local_info["city_id"]) - with city.lock(): - city.parent.players.add(session) - city.players.remove(session) - if not city.get_population(): - city.reset() - session.local_info["city_id"] = None - session.local_info["city_name"] = None - - def layer_detail_search(self, server_type, fields): - cities = [] - - def match_city(city, fields): - with city.lock(): - return all(( - field in city.optional_fields - for field in fields - )) - - for server in self.servers: - if server.server_type != server_type: - continue - for gate in server.gates: - if not gate.get_population(): - continue - cities.extend([ - city - for city in gate.cities - if match_city(city, fields) - ]) - return cities - def add_friend_request(self, sender_id, recipient_id): # Friend invite can be sent to arbitrary Capcom ID if any(cid not in self.capcom_ids @@ -1148,14 +660,23 @@ def __init__(self, *args, **kwargs): CURRENT_DB = \ MySQLDatabase() \ if utils.is_mysql_enabled("MYSQL") \ - else TempSQLiteDatabase() + else TempSQLiteDatabase() # type: TempDatabase def get_instance(): + # type: () -> TempDatabase + """Return a database instance like module singleton. + + TODO: + - Rename the TempDatabase interface to DatabaseInterface or similar + - Proper type hint, though any implementation must implement all the + interface methods. + """ return CURRENT_DB def implementation_check(cls, base=TempDatabase): + # type: (type, type) -> bool """Debug implementation check allowing to see methods dependencies. Example: @@ -1163,6 +684,7 @@ def implementation_check(cls, base=TempDatabase): >>> implementation_check(TempSQLiteDatabase) """ def is_method(obj): + # type: (object) -> bool """Python3's missing unbound methods workaround.""" if inspect.ismethod(obj): return True @@ -1179,7 +701,7 @@ def is_method(obj): print("class {}:".format(cls.__name__)) - for name, obj in inspect.getmembers(cls, is_method): + for name, _ in inspect.getmembers(cls, is_method): if name in missing_methods: missing_methods.remove(name) for c in mros: diff --git a/mh/pat.py b/mh/pat.py index 6bf7b47..5d1e111 100644 --- a/mh/pat.py +++ b/mh/pat.py @@ -19,8 +19,8 @@ CHARGE, VULGARITY_INFO, FMP_VERSION, PAT_BINARIES, PAT_NAMES, \ MAINTENANCE, UNPATCHED, \ PatID4, get_pat_binary_from_version -from mh.session import Session -import mh.database as db +from mh.session import Session, FMPSession +from mh.state_models import Server, Gate, City, Circle, Players # noqa: F401 try: from typing import Literal @@ -81,7 +81,7 @@ def get_pat_handler(self, session): # TODO: Backport broadcasting refactoring if needed def broadcast(self, players, packet_id, data, seq, to_exclude=None): - # type: (db.Players, int, bytes, int, Session|None) -> None + # type: (Players, int, bytes, int, Session|None) -> None handlers = [] with players.lock(): for _, player in players: @@ -102,7 +102,7 @@ def layer_broadcast(self, session, packet_id, data, seq, def circle_broadcast(self, circle, packet_id, data, seq, session=None): - # type: (db.Circle, int, bytes, int, Session|None) -> None + # type: (Circle, int, bytes, int, Session|None) -> None with circle.lock(): self.broadcast(circle.players, packet_id, data, seq, session) @@ -302,6 +302,7 @@ def recvAnsConnection(self, packet_id, data, seq): has_ban = "online_support_code" in settings and \ pat_ticket in BANNED_ONLINE_SUPPORT_CODES + # Used by OPN server, not sure we should rely on session at this point if has_ban or len(self.session.get_servers()) == 0: self.sendNtcLogin(2, settings, seq) else: @@ -932,6 +933,8 @@ def recvReqFmpListVersion(self, packet_id, data, seq): JP: FMPリストバージョン確認 TR: FMP list version check + NB: This packet is also used by the LMP server. + TODO: - Find why there are 2 versions of FMP packets. - Find why most of the 2 versions are ignored. @@ -967,6 +970,8 @@ def recvReqFmpListHead(self, packet_id, data, seq): ID: 61310100 / 63110100 JP: FMPリスト数送信 / FMPリスト数要求 TR: Send FMP list count / FMP list count request + + NB: This packet is also used by the LMP server. """ # TODO: Might be worth investigating these parameters as # they might be useful when using multiple FMP servers. @@ -987,6 +992,8 @@ def sendAnsFmpListHead(self, seq): ID: 61310200 JP: FMPリスト数応答 TR: FMP list count response + + NB: This packet is also used by the LMP server. """ unused = 0 count = len(self.session.get_servers()) @@ -1013,6 +1020,8 @@ def recvReqFmpListData(self, packet_id, data, seq): ID: 61320100 / 63120100 JP: FMPリスト送信 / FMPリスト要求 TR: Send FMP list / FMP list response + + NB: This packet is also used by the LMP server. """ first_index, count = struct.unpack_from(">II", data) if packet_id == PatID4.ReqFmpListData: @@ -1058,6 +1067,8 @@ def recvReqFmpListFoot(self, packet_id, data, seq): ID: 61330100 / 63130100 JP: FMPリスト送信終了 / FMPリスト終了送信 TR: FMP list end of transmission / FMP list transmission end + + NB: This packet is also used by the LMP server. """ if packet_id == PatID4.ReqFmpListFoot: self.sendAnsFmpListFoot(seq) @@ -1132,21 +1143,26 @@ def recvReqFmpInfo(self, packet_id, data, seq): TR: FMP data request TODO: Do not hardcode the data and find the meaning of all fields. + + NB: This packet is also used by the LMP server. """ index, = struct.unpack_from(">I", data) fields = pati.unpack_bytes(data, 4) - server = self.session.join_server(index) + # FIXME: Doesn't seem to make sense here, + # e.g. on LMP server, as "FmpInfo" packet + # server = self.session.join_server(index) config = get_config("FMP") fmp_addr = get_external_ip(config) fmp_port = config["Port"] fmp_data = pati.FmpData() - fmp_data.server_address = pati.String(server.addr or fmp_addr) - fmp_data.server_port = pati.Word(server.port or fmp_port) + fmp_data.server_address = pati.String(fmp_addr) + fmp_data.server_port = pati.Word(fmp_port) fmp_data.assert_fields(fields) # TODO: Backport central logic - if packet_id == PatID4.ReqFmpInfo: + if packet_id == PatID4.ReqFmpInfo: # LMP version self.sendAnsFmpInfo(fmp_data, fields, seq) - elif packet_id == PatID4.ReqFmpInfo2: + elif packet_id == PatID4.ReqFmpInfo2: # FMP version + self.session.join_server(index) self.sendAnsFmpInfo2(fmp_data, fields, seq) # Preserve session in database, due to server selection @@ -1416,8 +1432,9 @@ def recvReqLayerStart(self, packet_id, data, seq): JP: レイヤ開始要求 TR: Layer start request """ - unk1 = pati.unpack_bytes(data) - unk2 = pati.unpack_bytes(data, len(unk1) + 1) + with pati.Unpacker(data) as unpacker: + unk1 = unpacker.bytes() + unk2 = unpacker.bytes() self.sendAnsLayerStart(unk1, unk2, seq) def sendAnsLayerStart(self, unk1, unk2, seq): @@ -2402,7 +2419,17 @@ def sendAnsLayerDetailSearchData(self, offset, count, seq): for i, city in enumerate(cities): with city.lock(): layer_data = pati.LayerData.create_from(i, city) - layer_data.assert_fields(self.search_info["layer_fields"]) + layer_fields = self.search_info["layer_fields"] + filtered_fields = layer_data.filter_fields(layer_fields) # noqa: F841 + """ + # During testing / TODO: Investigate + filtered_fields = [ + (5, 'index', Word(0)) + (17, 'positionInterval', Long(500)) + (18, 'unk_byte_0x12', Byte(1)) + ] + """ + layer_data.assert_fields(layer_fields) data += layer_data.pack() data += pati.pack_optional_fields(city.optional_fields) with city.players.lock(): @@ -2737,7 +2764,7 @@ def sendNtcCircleHost(self, circle, new_leader, new_leader_index, seq): def notify_city_info_set(self, path): # type: (pati.LayerPath) -> None city = self.get_layer(path) - assert isinstance(city, db.City) + assert isinstance(city, City) gate = city.parent layer_data = pati.LayerData.create_from(path.city_id, city, path) @@ -2749,7 +2776,7 @@ def notify_city_info_set(self, path): def notify_city_number_set(self, path): # type: (pati.LayerPath) -> None city = self.get_layer(path) - assert isinstance(city, db.City) + assert isinstance(city, City) gate = city.parent layer_data = pati.LayerData.create_from(path.city_id, city, path) @@ -2757,17 +2784,16 @@ def notify_city_number_set(self, path): self.server.broadcast(gate.players, PatID4.NtcLayerUserNum, number_set, 0, self.session) - @staticmethod - def get_layer(path): - # type: (pati.LayerPath) -> db.Server | db.Gate | db.City | None - database = db.get_instance() + def get_layer(self, path): + # type: (pati.LayerPath) -> Server | Gate | City | None + fmp_state = self.session.FMP() if path.city_id > 0: - return database.get_city(path.server_id, path.gate_id, - path.city_id) + return fmp_state.get_city(path.server_id, path.gate_id, + path.city_id) elif path.gate_id > 0: - return database.get_gate(path.server_id, path.gate_id) + return fmp_state.get_gate(path.server_id, path.gate_id) elif path.server_id > 0: - return database.get_server(path.server_id) + return fmp_state.get_server(path.server_id) return None def notify_layer_departure(self, end): @@ -2793,7 +2819,7 @@ def notify_layer_departure(self, end): if path.city_id > 0: city = self.get_layer(path) - assert isinstance(city, db.City) + assert isinstance(city, City) self.notify_city_number_set(path) if city.leader is None: @@ -2822,7 +2848,8 @@ def on_exception(self, e): self.send_error("{}: {}".format(type(e).__name__, str(e))) def on_finish(self): - self.notify_layer_departure(True) + if isinstance(self.session, FMPSession): + self.notify_layer_departure(True) # TODO: Backport session_layer_end self.session.disconnect() self.session.delete() diff --git a/mh/pat_item.py b/mh/pat_item.py index c1676d4..7d93455 100644 --- a/mh/pat_item.py +++ b/mh/pat_item.py @@ -7,10 +7,11 @@ import struct from collections import OrderedDict + from mh.constants import pad +from mh.state_models import Server, Gate, City from other.utils import to_bytearray, get_config, get_external_ip, \ GenericUnpacker -from mh.database import Server, Gate, City class ItemType: @@ -354,6 +355,8 @@ def unpack_bytes(data, offset=0): class PatData(OrderedDict): + # TODO: type hint the whole module and add OrderedDict generic stub + # see https://mypy.readthedocs.io/en/stable/runtime_troubles.html """Pat structure holding items.""" FIELDS = ( (1, "field_0x01"), @@ -363,9 +366,11 @@ class PatData(OrderedDict): ) def __len__(self): + # type: () -> int return len(self.pack()) def __repr__(self): + # type: () -> str items = [ (index, value) for index, value in self.items() @@ -380,6 +385,7 @@ def __repr__(self): ) def __getattr__(self, name): + # type: (str) -> Item for field_id, field_name in self.FIELDS: if name == field_name: if field_id not in self: @@ -388,6 +394,7 @@ def __getattr__(self, name): raise AttributeError("Unknown field: {}".format(name)) def __setattr__(self, name, value): + # type: (str, Item) -> None for field_id, field_name in self.FIELDS: if name == field_name: if not isinstance(value, Item): @@ -399,6 +406,7 @@ def __setattr__(self, name, value): raise AttributeError("Cannot set unknown field: {}".format(name)) def __delattr__(self, name): + # type: (str) -> None for field_id, field_name in self.FIELDS: if name == field_name: del self[field_id] @@ -406,13 +414,19 @@ def __delattr__(self, name): return OrderedDict.__delattr__(self, name) def __setitem__(self, key, value): + # type: (int, Item) -> None if not isinstance(key, int) or not (0 <= key <= 255): raise IndexError("index must be a valid numeric value") elif not isinstance(value, Item): raise ValueError("{!r} not a valid PAT item".format(value)) return OrderedDict.__setitem__(self, key, value) + def __getitem__(self, key): + # type: (int) -> Item + return OrderedDict.__getitem__(self, key) + def __contains__(self, key): + # type: (int | str | object) -> bool if isinstance(key, str): for field_id, field_name in self.FIELDS: if field_name == key: @@ -426,13 +440,19 @@ def __contains__(self, key): return OrderedDict.__contains__(self, key) + def items(self): + # type: () -> list[tuple[int, Item]] + return OrderedDict.items(self) + def field_name(self, index): + # type: (int) -> str for field_id, field_name in self.FIELDS: if index == field_id: return field_name return "field_0x{:02x}".format(index) def pack(self): + # type: () -> bytes """Pack PAT items.""" items = [ (index, value) @@ -445,6 +465,7 @@ def pack(self): ) def pack_fields(self, fields): + # type: (set[int]) -> bytes """Pack PAT items specified fields.""" items = [ (index, value) @@ -458,6 +479,7 @@ def pack_fields(self, fields): @classmethod def unpack(cls, data, offset=0): + # type: (type[PatData], bytes, int) -> PatData obj = cls() field_count, = struct.unpack_from(">B", data, offset) offset += 1 @@ -489,11 +511,36 @@ def unpack(cls, data, offset=0): return obj def assert_fields(self, fields): - items = set(self.keys()) + # type: (set[int]) -> None + items = set(self.keys()) # type: set[int] fields = set(fields) message = "Fields mismatch: {}\n -> Expected: {}".format(items, fields) assert items == fields, message + def filter_fields(self, fields): + # type: (set[int]) -> list[tuple[int, str, Item]] + """Filter PatData excess of information. + + Places using this method should be investigated to make sure there is + no oversight (or more reverse engineering needed) regarding the + data structure and packet used. + """ + items = set(self.keys()) # type: set[int] + fields = set(fields) + missing_fields = fields - items + message = "Can't filter missing fields: {}\n".format(", ".join( + self.field_name(field_id) for field_id in missing_fields + )) + assert not missing_fields, message + new_fields = items - fields + filtered_fields = [] # type: list[tuple[int, str, Item]] + for field_id in new_fields: + filtered_fields.append( + (field_id, self.field_name(field_id), self[field_id]) + ) + del self[field_id] + return filtered_fields + class DummyData(PatData): FIELDS = tuple() @@ -588,6 +635,10 @@ class UserSearchInfo(PatData): ) +# FIXME: Both LayerPath and LayerData introduce a strong dependency on +# mh.state_models, which should be avoided for this module. + + class LayerPath(object): STRUCT = struct.Struct(">IIHHH") diff --git a/mh/session.py b/mh/session.py index 698c342..aaed0c4 100644 --- a/mh/session.py +++ b/mh/session.py @@ -9,6 +9,28 @@ from other.utils import to_bytearray, to_str +try: + from typing import Any, TypedDict, TYPE_CHECKING # noqa: F401 + + if TYPE_CHECKING: + from mh.state import State # noqa: F401 + + SessionLocalInfo = TypedDict( + "SessionLocalInfo", { + "server_id": None | int, + "server_name": None | str, # maybe bytes to support accents? + "gate_id": None | int, + "gate_name": None | str, # ditto regarding special characters + "city_id": None | int, + "city_name": None | str, # ditto regarding special characters + "city_size": int, + "city_capacity": int, + "circle_id": None | int, + } + ) +except (ImportError, TypeError): + pass + DB = db.get_instance() @@ -25,6 +47,9 @@ class SessionState: class Session(object): """Server session class. + The goal of this class is to help writing the PAT packet logic by + handling complex interactions between servers/database behind the scene. + TODO: - Finish the implementation """ @@ -40,7 +65,7 @@ def __init__(self, connection_handler): "city_size": 0, "city_capacity": 0, "circle_id": None, - } + } # type: SessionLocalInfo self.connection = connection_handler self.online_support_code = None self.request_reconnection = False @@ -52,9 +77,82 @@ def __init__(self, connection_handler): self.state = SessionState.UNKNOWN self.binary_setting = b"" self.search_payload = None - # TODO: Backport the server_id and serialisation logic + # TODO: Backport the server_id logic self.hunter_info = pati.HunterSettings() + def serialize(self): + # type: () -> dict[str, Any] + return { + "pat_ticket": self.pat_ticket, + # TODO: local_info should be safe to serialize on its own as dict + "local_info_server_id": self.local_info["server_id"], + "local_info_server_name": self.local_info["server_name"], + "local_info_gate_id": self.local_info["gate_id"], + "local_info_gate_name": self.local_info["gate_name"], + "local_info_city_id": self.local_info["city_id"], + "local_info_city_name": self.local_info["city_name"], + "local_info_city_size": self.local_info["city_size"], + "local_info_city_capacity": self.local_info["city_capacity"], + "local_info_circle_id": self.local_info["circle_id"], + "online_support_code": self.online_support_code, + "capcom_id": self.capcom_id, + "hunter_name": self.hunter_name, + "hunter_stats": self.hunter_stats, + "layer": self.layer, + "state": self.state, + "binary_setting": self.binary_setting, + "hunter_info": to_str(self.hunter_info.pack()) + } + + @staticmethod + def deserialize(obj): + # type: (dict[str, Any]) -> Session + session = Session(None) + session.pat_ticket = \ + str(obj["pat_ticket"]) \ + if obj["pat_ticket"] else obj["pat_ticket"] + session.local_info["server_id"] = \ + int(obj["local_info_server_id"]) \ + if obj["local_info_server_id"] else obj["local_info_server_id"] + session.local_info["server_name"] = \ + str(obj["local_info_server_name"]) \ + if obj["local_info_server_name"] else obj["local_info_server_name"] + session.local_info["gate_id"] = \ + int(obj["local_info_gate_id"]) \ + if obj["local_info_gate_id"] else obj["local_info_gate_id"] + session.local_info["gate_name"] = \ + str(obj["local_info_gate_name"]) \ + if obj["local_info_gate_name"] else obj["local_info_gate_name"] + session.local_info["city_id"] = \ + int(obj["local_info_city_id"]) \ + if obj["local_info_city_id"] else obj["local_info_city_id"] + session.local_info["city_name"] = \ + str(obj["local_info_city_name"]) \ + if obj["local_info_city_name"] else obj["local_info_city_name"] + session.local_info["city_size"] = \ + int(obj["local_info_city_size"]) \ + if obj["local_info_city_size"] else obj["local_info_city_size"] + session.local_info["city_capacity"] = \ + int(obj["local_info_city_capacity"]) \ + if obj["local_info_city_capacity"] \ + else obj["local_info_city_capacity"] + session.local_info["circle_id"] = \ + int(obj["local_info_circle_id"]) \ + if obj["local_info_circle_id"] else obj["local_info_circle_id"] + session.online_support_code = \ + str(obj["online_support_code"]) \ + if obj["online_support_code"] else obj["online_support_code"] + session.capcom_id = str(obj["capcom_id"]) + session.hunter_name = str(obj["hunter_name"]) + session.hunter_stats = obj["hunter_stats"] + session.layer = int(obj["layer"]) + session.state = int(obj["state"]) + session.binary_setting = obj["binary_setting"] + h_settings = bytearray(obj["hunter_info"], encoding='ISO-8859-1') + session.hunter_info = pati.HunterSettings().unpack(h_settings, + len(h_settings)) + return session + def get(self, connection_data): """Return the session associated with the connection data, if any.""" if hasattr(connection_data, "pat_ticket"): @@ -68,6 +166,7 @@ def get(self, connection_data): ) session = DB.get_session(self.pat_ticket) or self if session != self: + # TODO: Use a less error-prone check assert session.connection is None, "Session is already in use" session.connection = self.connection self.connection = None @@ -118,25 +217,111 @@ def get_users(self, first_index, count): def use_user(self, index, name): DB.use_user(self, index, name) + def find_user_by_capcom_id(self, capcom_id): + sessions = DB.find_users(capcom_id=capcom_id) + if sessions: + return sessions[0] + return None + + def find_users(self, capcom_id, hunter_name, first_index, count): + users = DB.find_users(capcom_id, hunter_name) + start = first_index - 1 + return users[start:start+count] + + def get_user_name(self, capcom_id): + return DB.get_user_name(capcom_id) + + def add_friend_request(self, capcom_id): + return DB.add_friend_request(self.capcom_id, capcom_id) + + def accept_friend(self, capcom_id, accepted=True): + return DB.accept_friend(self.capcom_id, capcom_id, accepted) + + def delete_friend(self, capcom_id): + return DB.delete_friend(self.capcom_id, capcom_id) + + def get_friends(self, first_index=None, count=None): + return DB.get_friends(self.capcom_id, first_index, count) + # TODO: server_index and recall logic def get_servers(self): - return DB.get_servers() + """LMP servers can request the server list""" + return DB.servers # FIXME: See FmpServer constructor + + +class FMPSession(Session): + """Specialized session wrapper for FMP server. + + Most methods should be unrelated to DB. + This prevent other servers to trigger FMP related features. + """ + def __init__(self, session): + # type: (Session) -> None + self._session = session + self._fmp_state = session.connection.server.fmp_state # type: State + + def __getattr__(self, name): + # type: (str) -> Any + """Called when the default attribute access fails.""" + return getattr(self._session, name) + + def __setattr__(self, name, value): + # type: (str, Any) -> None + if name.startswith("_"): + return object.__setattr__(self, name, value) + return setattr(self._session, name, value) + + def __eq__(self, other): + # type: (Session) -> bool + """Avoid comparison issues when wrapping around the session. + + For instance, list removal is done by equality not identity, + in other words, FMPSession instances can get Session instances to be + removed from a list. + """ + if isinstance(other, FMPSession): + other = other._session + elif not isinstance(other, Session): + return NotImplemented + return self._session == other + + def __ne__(self, other): + # type: (Session) -> bool + x = self.__eq__(other) + if x is NotImplemented: + return NotImplemented + return not x + + def get(self, connection_data): + """Wrap existing session if needed""" + session = Session.get(self, connection_data) + if not isinstance(session, FMPSession): + session = FMPSession(session) + return session + + def FMP(self): + # type: () -> State + """Wrapper that can be repurposed later for cache/error handling.""" + return self._fmp_state + + def get_servers(self): + return self.FMP().get_servers() def get_server(self): assert self.local_info['server_id'] is not None - return DB.get_server(self.local_info['server_id']) + return self.FMP().get_server(self.local_info['server_id']) def get_gate(self): assert self.local_info['gate_id'] is not None - return DB.get_gate(self.local_info['server_id'], - self.local_info['gate_id']) + return self.FMP().get_gate(self.local_info['server_id'], + self.local_info['gate_id']) def get_city(self): assert self.local_info['city_id'] is not None - return DB.get_city(self.local_info['server_id'], - self.local_info['gate_id'], - self.local_info['city_id']) + return self.FMP().get_city(self.local_info['server_id'], + self.local_info['gate_id'], + self.local_info['city_id']) def get_circle(self): assert self.local_info['circle_id'] is not None @@ -197,10 +382,10 @@ def layer_detail_search(self, detailed_fields): (field_id, value) for field_id, field_type, value in detailed_fields ] # Convert detailed to simple optional fields - return DB.layer_detail_search(server_type, fields) + return self.FMP().layer_detail_search(server_type, fields) def join_server(self, server_id): - return DB.join_server(self, server_id) + return self.FMP().join_server(self, server_id) def get_layer_children(self): if self.layer == 0: @@ -218,74 +403,63 @@ def get_layer_sibling(self): def find_users_by_layer(self, server_id, gate_id, city_id, first_index, count, recursive=False): - if recursive: - players = DB.get_all_users(server_id, gate_id, city_id) - else: - layer = \ - DB.get_city(server_id, gate_id, city_id) if city_id else \ - DB.get_gate(server_id, gate_id) if gate_id else \ - DB.get_server(server_id) - players = list(layer.players) start = first_index - 1 + if recursive: + players = self.FMP().get_all_users(server_id, gate_id, city_id) + return players[start:start+count] + + layer = \ + self.FMP().get_city(server_id, gate_id, city_id) if city_id else \ + self.FMP().get_gate(server_id, gate_id) if gate_id else \ + self.FMP().get_server(server_id) + players = list(layer.players) return players[start:start+count] - def find_user_by_capcom_id(self, capcom_id): - sessions = DB.find_users(capcom_id=capcom_id) - if sessions: - return sessions[0] - return None - - def find_users(self, capcom_id, hunter_name, first_index, count): - users = DB.find_users(capcom_id, hunter_name) - start = first_index - 1 - return users[start:start+count] - - def get_user_name(self, capcom_id): - return DB.get_user_name(capcom_id) - def leave_server(self): - DB.leave_server(self, self.local_info["server_id"]) + self.FMP().leave_server(self, self.local_info["server_id"]) def get_gates(self): - return DB.get_gates(self.local_info["server_id"]) + return self.FMP().get_gates(self.local_info["server_id"]) def join_gate(self, gate_id): - DB.join_gate(self, self.local_info["server_id"], gate_id) + self.FMP().join_gate(self, self.local_info["server_id"], gate_id) self.state = SessionState.GATE def leave_gate(self): - DB.leave_gate(self) + self.FMP().leave_gate(self) self.state = SessionState.LOG_IN def get_cities(self): - return DB.get_cities(self.local_info["server_id"], - self.local_info["gate_id"]) + return self.FMP().get_cities(self.local_info["server_id"], + self.local_info["gate_id"]) def is_city_empty(self, city_id): - return DB.get_city(self.local_info["server_id"], - self.local_info["gate_id"], - city_id).get_state() == db.LayerState.EMPTY + return self.FMP().get_city( + self.local_info["server_id"], + self.local_info["gate_id"], + city_id + ).is_empty() def reserve_city(self, city_id, reserve): - return DB.reserve_city(self.local_info["server_id"], - self.local_info["gate_id"], - city_id, reserve) + return self.FMP().reserve_city(self.local_info["server_id"], + self.local_info["gate_id"], + city_id, reserve) def create_city(self, city_id, settings, optional_fields): - return DB.create_city(self, - self.local_info["server_id"], - self.local_info["gate_id"], - city_id, settings, optional_fields) + return self.FMP().create_city( + self, self.local_info["server_id"], self.local_info["gate_id"], + city_id, settings, optional_fields + ) def join_city(self, city_id): - DB.join_city(self, - self.local_info["server_id"], - self.local_info["gate_id"], - city_id) + self.FMP().join_city(self, + self.local_info["server_id"], + self.local_info["gate_id"], + city_id) self.state = SessionState.CITY def leave_city(self): - DB.leave_city(self) + self.FMP().leave_city(self) self.state = SessionState.GATE def try_transfer_city_leadership(self): @@ -386,15 +560,3 @@ def get_optional_fields(self): (1, (weapon_type << 24) | location), (2, hunter_rank << 16) ] - - def add_friend_request(self, capcom_id): - return DB.add_friend_request(self.capcom_id, capcom_id) - - def accept_friend(self, capcom_id, accepted=True): - return DB.accept_friend(self.capcom_id, capcom_id, accepted) - - def delete_friend(self, capcom_id): - return DB.delete_friend(self.capcom_id, capcom_id) - - def get_friends(self, first_index=None, count=None): - return DB.get_friends(self.capcom_id, first_index, count) diff --git a/mh/state.py b/mh/state.py new file mode 100644 index 0000000..78238f7 --- /dev/null +++ b/mh/state.py @@ -0,0 +1,280 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2023-2025 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""FMP server state module. + +Each FMP server needs to know the in-game resources it manages: + - Servers + - Gates + - Cities + - Sessions (i.e. active server connections) + +TODO/FIXME: +Our current design isn't very scalable. +""" + + +from mh.database import get_instance as get_db +from mh.state_models import Server, ServerType + +try: + from typing import TypedDict, TYPE_CHECKING + if TYPE_CHECKING: + from mh.session import Session + from mh.state_models import Gate, City # noqa: F401 + + CapcomIDsInfo = TypedDict("CapcomIDsInfo", { + "name": bytes, + "session": None | Session + }) +except (ImportError, TypeError): + pass + + +def new_servers(): + # type: () -> list[Server] + # TODO: This logic was removed upstream to use harcoded values... + servers = [] # type: list[Server] + servers.extend([ + Server("Valor{}".format(i), ServerType.OPEN) + for i in range(1, 5) + ]) + servers.extend([ + Server("Beginners{}".format(i), ServerType.ROOKIE) + for i in range(1, 3) + ]) + servers.extend([ + Server("Veterans{}".format(i), ServerType.EXPERT) + for i in range(1, 3) + ]) + servers.extend([ + Server("Greed{}".format(i), ServerType.RECRUITING) + for i in range(1, 5) + ]) + return servers + + +class State(object): + """FMP server state class. + + Holds the required information not to bother too frequently: + - the database server + - the central server + + Ideally, this should allow the server to be resilient and operate even + with limited connectivity to the database and central server. + + TODO: (backport) + - Backport cache/server_id logic + - Like the following methods if needed: + setup_server + register_pat_ticket + get_session + disconnect_session + delete_session (+ cache logic) + fetch_id + get_servers_version + update_players + update_capcom_id + session_ready + set_session_ready + close_cache + + FIXME: (backport) + - These methods should belong to db (imho) + new_pat_ticket -> db.generate_pat_ticket + fill self.session + use_capcom_id + use_user + get_users (i.e. LMP server) + """ + def __init__(self): + self.servers = new_servers() + # TODO: Backport Sessions/Capcom ID handling (currently unused) + self.sessions = { + # PAT Ticket => Owner's session + } # type: dict[str, Session] + self.capcom_ids = { + # Capcom ID => Owner's name and session + # NB: Not to be mistaken with database.capcom_ids! + } # type: dict[str, CapcomIDsInfo] + + def join_server(self, session, index): + # type: (Session, int) -> Server + if session.local_info["server_id"] is not None: + self.leave_server(session, session.local_info["server_id"]) + # TODO: Backport cache/joining another external server + # It might imply moving from (array) index to (dict) server_id + server = self.get_server(index) + server.players.add(session) + session.local_info["server_id"] = index + session.local_info["server_name"] = server.name + return server + + def leave_server(self, session, index): + # type: (Session, int) -> None + self.get_server(index).players.remove(session) + session.local_info["server_id"] = None + session.local_info["server_name"] = None + + def get_server_time(self): + # TODO: Use it at some point or remove it + pass + + def get_game_time(self): + # TODO: Use it at some point or remove it + pass + + def get_servers(self): + # type: () -> list[Server] + # TODO: Backport cache code + return self.servers + + def get_server(self, index): + # type: (int) -> Server + # TODO: Backport cache code + assert 0 < index <= len(self.servers), "Invalid server index" + return self.servers[index - 1] + + def get_gates(self, server_id): + # type: (int) -> list[Gate] + return self.get_server(server_id).gates + + def get_gate(self, server_id, index): + # type: (int, int) -> Gate + gates = self.get_gates(server_id) + assert 0 < index <= len(gates), "Invalid gate index" + return gates[index - 1] + + def join_gate(self, session, server_id, index): + # type: (Session, int, int) -> Gate + gate = self.get_gate(server_id, index) + gate.parent.players.remove(session) + gate.players.add(session) + session.local_info["gate_id"] = index + session.local_info["gate_name"] = gate.name + return gate + + def leave_gate(self, session): + # type: (Session) -> None + gate = self.get_gate(session.local_info["server_id"], + session.local_info["gate_id"]) + gate.parent.players.add(session) + gate.players.remove(session) + session.local_info["gate_id"] = None + session.local_info["gate_name"] = None + + def get_cities(self, server_id, gate_id): + # type: (int, int) -> list[City] + return self.get_gate(server_id, gate_id).cities + + def get_city(self, server_id, gate_id, index): + # type: (int, int, int) -> City + cities = self.get_cities(server_id, gate_id) + assert 0 < index <= len(cities), "Invalid city index" + return cities[index - 1] + + def reserve_city(self, server_id, gate_id, index, reserve): + # type: (int, int, int, bool) -> bool + city = self.get_city(server_id, gate_id, index) + with city.lock(): + if reserve and city.is_reserved(): + return False + city.reserve(reserve) + return True + + def get_all_users(self, server_id, gate_id, city_id): + # type: (int, int, None | int) -> list[tuple[int, Session]] + """Search for users in layers and its children. + + Let's assume wildcard search isn't possible for servers and gates. + A wildcard search happens when the id is zero. + """ + assert 0 < server_id, "Invalid server index" + assert 0 < gate_id, "Invalid gate index" + gate = self.get_gate(server_id, gate_id) + users = list(gate.players) + cities = [ + self.get_city(server_id, gate_id, city_id) + ] if city_id else self.get_cities(server_id, gate_id) + for city in cities: + users.extend(list(city.players)) + return users + + def find_users(self, capcom_id="", hunter_name=b""): + # type: (str, bytes) -> list[Session] + assert capcom_id or hunter_name, "Search can't be empty" + users = [] # type: list[Session] + for user_id, user_info in self.capcom_ids.items(): + session = user_info["session"] + if not session: + continue + if capcom_id and capcom_id not in user_id: + continue + if hunter_name and \ + hunter_name.lower() not in user_info["name"].lower(): + continue + users.append(session) + # TODO: Backport cache/central code + # FIXME: We need to rely on DB meanwhile... + assert len(users) == 0, "State doesn't save IDs/Session yet" + users.extend(get_db().find_users( + capcom_id=capcom_id, + hunter_name=hunter_name + )) + return users + + def create_city(self, session, server_id, gate_id, index, + settings, optional_fields): + city = self.get_city(server_id, gate_id, index) + with city.lock(): + city.optional_fields = optional_fields + city.leader = session + return city + + def join_city(self, session, server_id, gate_id, index): + # type: (Session, int, int, int) -> City + city = self.get_city(server_id, gate_id, index) + with city.lock(): + city.parent.players.remove(session) + city.players.add(session) + session.local_info["city_name"] = city.name + session.local_info["city_id"] = index + return city + + def leave_city(self, session): + # type: (Session) -> None + city = self.get_city(session.local_info["server_id"], + session.local_info["gate_id"], + session.local_info["city_id"]) + with city.lock(): + city.parent.players.add(session) + city.players.remove(session) + if not city.get_population(): + city.clear_circles() + session.local_info["city_id"] = None + session.local_info["city_name"] = None + + def layer_detail_search(self, server_type, fields): + # TODO: Better document and test this + cities = [] + + def match_city(city, fields): + with city.lock(): + return all(( + field in city.optional_fields + for field in fields + )) + + for server in self.get_servers(): + if server.server_type != server_type: + continue + for gate in server.gates: + if not gate.get_population(): + continue + cities.extend([ + city + for city in gate.cities + if match_city(city, fields) + ]) + return cities diff --git a/mh/state_models.py b/mh/state_models.py new file mode 100644 index 0000000..9e5769e --- /dev/null +++ b/mh/state_models.py @@ -0,0 +1,585 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2025 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""FMP server state models module. + +This module contains ORM-like models to be (de)serialized. +""" + +import time + +from threading import RLock + +try: + from typing import Any, Literal, TYPE_CHECKING # noqa: F401 + + ServerTypeLiteral = Literal[1, 2, 3, 4] + LayerStateLiteral = Literal[0, 1, 2] + + if TYPE_CHECKING: + # FIXME: Session.deserialize introduces cyclic import + from mh.session import Session +except (ImportError, TypeError): + pass + + +RESERVE_DC_TIMEOUT = 40.0 + + +class ServerType: + OPEN = 1 + ROOKIE = 2 + EXPERT = 3 + RECRUITING = 4 + + +class LayerState: + JOINABLE = 0 + EMPTY = 1 + FULL = 2 + + +class Lockable(object): + def __init__(self): + self._lock = RLock() + + def lock(self): + return self + + def __enter__(self): + # Returns True if lock was acquired, False otherwise + return self._lock.acquire() + + def __exit__(self, *args): + # type: (Any) -> None + self._lock.release() + + +class Players(Lockable): + """Helper class to help retain player IDs.""" + def __init__(self, capacity): + # type: (int) -> None + assert capacity > 0, "Collection capacity can't be zero" + + self.slots = [ + None for _ in range(capacity) + ] # type: list[None | Session] + self.used = 0 + super(Players, self).__init__() + + def get_used_count(self): + # type: () -> int + return self.used + + def get_capacity(self): + # type: () -> int + return len(self.slots) + + def add(self, item): + # type: (Session) -> int + with self.lock(): + if self.used >= len(self.slots): + return -1 + + item_index = self.index(item) + if item_index != -1: + return item_index + + for i, v in enumerate(self.slots): + if v is not None: + continue + + self.slots[i] = item + self.used += 1 + return i + + return -1 + + def remove(self, item): + # type: (Session | int) -> bool + assert item is not None, "Item != None" + + with self.lock(): + if self.used < 1: + return False + + if isinstance(item, int): + if item >= self.get_capacity(): + return False + + self.slots[item] = None + self.used -= 1 + return True + + for i, v in enumerate(self.slots): + if v != item: + continue + + self.slots[i] = None + self.used -= 1 + return True + + return False + + def index(self, item): + # type: (Session) -> int + assert item is not None, "Item != None" + + for i, v in enumerate(self.slots): + if v == item: + return i + + return -1 + + def clear(self): + # type: () -> None + with self.lock(): + for i in range(self.get_capacity()): + self.slots[i] = None + + def find_first(self, **kwargs): + # type: (Any) -> None | Session + if self.used < 1: + return None + + for p in self.slots: + if p is None: + continue + + for k, v in kwargs.items(): + if getattr(p, k) != v: + break + else: + return p + + return None + + def find_by_capcom_id(self, capcom_id): + # type: (str) -> None | Session + return self.find_first(capcom_id=capcom_id) + + def __len__(self): + return self.used + + def __iter__(self): + if self.used < 1: + return + + for i, v in enumerate(self.slots): + if v is None: + continue + + yield i, v + + def serialize(self): + # type: () -> dict[str, Any] + if not self.used: + return {"capacity": len(self.slots)} + return { + "slots": [ + (p.serialize() + if p is not None + else None) + for p in self.slots + ], + "used": self.used + } + + @staticmethod + def deserialize(obj, parent): + # type: (dict[str, Any], None) -> Players + if "used" not in obj.keys(): + return Players(obj["capacity"]) + + from mh.session import Session # avoid cyclic import + + players = Players(len(obj["slots"])) + players.slots = [ + (Session.deserialize(p) + if p is not None + else None) + for p in obj["slots"] + ] + players.used = obj["used"] + return players + + +class Circle(Lockable): + def __init__(self, parent): + # type: (City) -> None + self.parent = parent + self.leader = None # type: None | Session + self.players = Players(4) + self.departed = False + self.quest_id = 0 + self.embarked = False # FIXME: Seems never used + self.password = None + self.remarks = None + + self.unk_byte_0x0e = 0 + super(Circle, self).__init__() + + def get_population(self): + # type: () -> int + return len(self.players) + + def get_capacity(self): + # type: () -> int + return self.players.get_capacity() + + def is_full(self): + # type: () -> bool + return self.get_population() == self.get_capacity() + + def is_empty(self): + # type: () -> bool + return self.leader is None + + def is_joinable(self): + # type: () -> bool + return not self.departed and not self.is_full() + + def has_password(self): + # type: () -> bool + return self.password is not None + + def reset_players(self, capacity): + # type: (int) -> None + with self.lock(): + self.players = Players(capacity) + + def reset(self): + # type: () -> None + with self.lock(): + self.leader = None + self.reset_players(4) + self.departed = False + self.quest_id = 0 + self.embarked = False + self.password = None + self.remarks = None + + self.unk_byte_0x0e = 0 + + def serialize(self): + # type: () -> dict[str, Any] + players = self.players.serialize() + if "used" not in players.keys(): + return {} + return { + "parent": None, + "leader": + self.leader.serialize() + if self.leader is not None + else None, + "players": players, + "departed": self.departed, + "quest_id": self.quest_id, + "embarked": self.embarked, + "password": self.password, + "remarks": self.remarks, + "unk_byte_0x0e": self.unk_byte_0x0e + } + + @staticmethod + def deserialize(obj, parent): + # type: (dict[str, Any], City) -> Circle + circle = Circle(parent) + if not obj.keys(): + return circle + + from mh.session import Session # avoid cyclic import + + circle.leader = \ + Session.deserialize(obj["leader"]) \ + if obj["leader"] is not None \ + else None + # TODO: Players class doesn't have "parent" member variable + circle.players = Players.deserialize(obj["players"], circle) + circle.departed = obj["departed"] + circle.quest_id = obj["quest_id"] + circle.embarked = obj["embarked"] + circle.password = obj["password"] + circle.remarks = obj["remarks"] + circle.unk_byte_0x0e = obj["unk_byte_0x0e"] + return circle + + +class City(Lockable): + LAYER_DEPTH = 3 + + def __init__(self, name, parent): + # type: (str, Gate) -> None + self.name = name + self.parent = parent + self.state = LayerState.EMPTY + self.players = Players(4) + self.optional_fields = [] + self.leader = None + self.reserved = None + self.circles = [ + # One circle per player + Circle(self) for _ in range(self.get_capacity()) + # NB: Might not work on the Japanese version, IIRC the limit is 5 + ] + super(City, self).__init__() + + def get_population(self): + # type: () -> int + return len(self.players) + + def in_quest_players(self): + # type: () -> int + return sum(p.is_in_quest() for _, p in self.players) + + def get_capacity(self): + # type: () -> int + return self.players.get_capacity() + + def get_state(self): + # type: () -> LayerStateLiteral + # FIXME: The following part was removed v + if self.reserved: + return LayerState.FULL + # TODO: ^ Backport it and make sure that wasn't a mistake + + size = self.get_population() + if size == 0: + return LayerState.EMPTY + elif size < self.get_capacity(): + return LayerState.JOINABLE + else: + return LayerState.FULL + + def is_empty(self): + # type: () -> bool + return self.get_state() == LayerState.EMPTY + + def get_pathname(self): + # type: () -> str + pathname = self.name # type: str + it = self.parent + # FIXME: Some tools don't like type-checking this loop at all + while it is not None: # type: ignore + pathname = it.name + "\t" + pathname + it = it.parent # type: Gate | Server | None + return pathname + + def get_first_empty_circle(self): + # type: () -> tuple[None, None] | tuple[Circle, int] + with self.lock(): + for index, circle in enumerate(self.circles): + if circle.is_empty(): + return circle, index + return None, None + + def get_circle_for(self, leader_session): + # type: (Session) -> tuple[None, None] | tuple[Circle, int] + with self.lock(): + for index, circle in enumerate(self.circles): + if circle.leader == leader_session: + return circle, index + return None, None + + def clear_circles(self): + # type: () -> None + with self.lock(): + for circle in self.circles: + circle.reset() + + def reserve(self, reserve): + # type: (bool) -> None + with self.lock(): + if reserve: + self.reserved = time.time() + else: + self.reserved = None + + def is_reserved(self): + # type: () -> bool + reserved_time = self.reserved # type: None | float + if reserved_time: + return time.time()-reserved_time < RESERVE_DC_TIMEOUT + return False + + def get_all_players(self): + # type: () -> list[Session] + with self.players.lock(): + return [p for _, p in self.players] + + def serialize(self): + # type: () -> dict[str, Any] + players = self.players.serialize() + if "used" not in players.keys(): + return {"name": self.name} + return { + "name": self.name, + "parent": None, # TODO: Why (not) serializing it? + "state": self.state, + "players": players, + "optional_fields": self.optional_fields, + "leader": + self.leader.serialize() + if self.leader is not None + else None, + "reserved": self.reserved, + "circles": [c.serialize() for c in self.circles] + } + + @staticmethod + def deserialize(obj, parent): + # type: (dict[str, Any], Gate) -> City + if len(obj.keys()) < 2: + return City(obj["name"], None) + city = City( + str(obj["name"]) if obj["name"] is not None else obj["name"], + parent + ) + city.state = obj["state"] + city.players = Players.deserialize(obj["players"], parent) + city.optional_fields = obj["optional_fields"] + city.leader = \ + Session.deserialize(obj["leader"]) \ + if obj["leader"] is not None \ + else None + city.reserved = obj["reserved"] + city.circles = [Circle.deserialize(c, city) for c in obj["circles"]] + return city + + +class Gate(object): + LAYER_DEPTH = 2 + + def __init__(self, name, parent, city_count=40, player_capacity=100): + # type: (str, Server, int, int) -> None + self.name = name + self.parent = parent + self.state = LayerState.EMPTY + self.cities = [ + City("City{}".format(i), self) + for i in range(1, city_count+1) + ] + self.players = Players(player_capacity) + self.optional_fields = [] + + def get_population(self): + # type: () -> int + return len(self.players) + sum(( + city.get_population() + for city in self.cities + )) + + def get_capacity(self): + # type: () -> int + return self.players.get_capacity() + + def get_state(self): + # type: () -> LayerStateLiteral + size = self.get_population() + if size == 0: + return LayerState.EMPTY + elif size < self.get_capacity(): + return LayerState.JOINABLE + else: + return LayerState.FULL + + def get_all_players(self): + # type: () -> list[Session] + # TODO: Backport its use, e.g. in cache + players = [p for _, p in self.players] + for city in self.cities: + players += city.get_all_players() + return players + + def serialize(self): + # type: () -> dict[str, Any] + return { + "name": self.name, + "parent": None, + "state": self.state, + "cities": [c.serialize() for c in self.cities], + "players": self.players.serialize(), + "optional_fields": self.optional_fields + } + + @staticmethod + def deserialize(obj, parent): + # type: (dict[str, Any], Server) -> Gate + gate = Gate( + str(obj["name"]) if obj["name"] is not None else obj["name"], + parent + ) + gate.state = obj["state"] + gate.cities = [City.deserialize(c, gate) for c in obj["cities"]] + gate.players = Players.deserialize(obj["players"], gate) + gate.optional_fields = obj["optional_fields"] + return gate + + +class Server(object): + LAYER_DEPTH = 1 + + def __init__(self, name, server_type, gate_count=40, capacity=2000, + addr=None, port=None): + # type: (str, ServerTypeLiteral, int, int, None | str, None | int) -> None # noqa: E501 + # TODO: Backport the constructor change if needed + self.name = name + self.parent = None + self.server_type = server_type + # Public IP address + self.addr = addr # type: None | str + self.port = port # type: None | int + self.gates = [ + Gate("City Gate{}".format(i), self) + for i in range(1, gate_count+1) + ] + self.players = Players(capacity) + + def get_population(self): + # type: () -> int + return len(self.players) + sum(( + gate.get_population() for gate in self.gates + )) + + def get_capacity(self): + # type: () -> int + return self.players.get_capacity() + + def get_all_players(self): + # type: () -> list[Session] + # TODO: Backport its use, e.g. in cache + players = [p for _, p in self.players] + for gate in self.gates: + players = players + gate.get_all_players() + return players + + def serialize(self): + # type: () -> dict[str, Any] + return { + "name": self.name, + "parent": self.parent, + "server_type": self.server_type, + "addr": self.addr, + "port": self.port, + "gates": [g.serialize() for g in self.gates], + "players": self.players.serialize() + } + + @staticmethod + def deserialize(obj): + # type: (dict[str, Any]) -> Server + server = Server( + str(obj["name"]) if obj["name"] is not None + else obj["name"], + int(obj["server_type"]) if obj["server_type"] + else obj["server_type"], + addr=str(obj["addr"]) if obj["addr"] is not None + else obj["addr"], + port=int(obj["port"]) if obj["port"] is not None + else obj["port"] + ) + server.parent = obj["parent"] + server.gates = [Gate.deserialize(g, server) for g in obj["gates"]] + server.players = Players.deserialize(obj["players"], server) + return server From 7368bf3181eca6ace42a66531e69cce3c0de0f81 Mon Sep 17 00:00:00 2001 From: Sepalani Date: Mon, 4 Aug 2025 02:21:53 +0400 Subject: [PATCH 2/5] other: Add a config module with some helper classes Fix MaxThread option being ignored. Remove unused legacy_ssl parameter. --- mh/database.py | 9 +- mh/pat.py | 7 +- mh/pat_item.py | 6 +- other/config.py | 233 ++++++++++++++++++++++++++++++++++++++++++++++++ other/utils.py | 133 ++------------------------- 5 files changed, 255 insertions(+), 133 deletions(-) create mode 100644 other/config.py diff --git a/mh/database.py b/mh/database.py index 9c17bfe..9b6afb7 100644 --- a/mh/database.py +++ b/mh/database.py @@ -7,9 +7,12 @@ import inspect import random import sqlite3 -from other import utils + from threading import local as thread_local +from other import utils +from other.config import MySQLConfig + CHARSET = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -415,7 +418,7 @@ def __init__(self): self.parent.__init__() from mysql import connector self.connection = connector.connect( - **utils.get_mysql_config("MYSQL") + **MySQLConfig().connect_kwargs() ) self.create_database() self.populate_database() @@ -659,7 +662,7 @@ def __init__(self, *args, **kwargs): CURRENT_DB = \ MySQLDatabase() \ - if utils.is_mysql_enabled("MYSQL") \ + if MySQLConfig().is_enabled() \ else TempSQLiteDatabase() # type: TempDatabase diff --git a/mh/pat.py b/mh/pat.py index 5d1e111..a4db9d7 100644 --- a/mh/pat.py +++ b/mh/pat.py @@ -8,7 +8,8 @@ import traceback from datetime import timedelta -from other.utils import Logger, get_config, get_external_ip, hexdump, to_str +from other.config import ServerConfig +from other.utils import Logger, get_external_ip, hexdump, to_str from mh.quest_utils import QuestLoader import mh.pat_item as pati @@ -646,7 +647,7 @@ def recvReqLmpConnect(self, packet_id, data, seq): TODO: I don't think it's related to LMP protocol. """ - config = get_config("LMP") + config = ServerConfig("LMP") self.sendAnsLmpConnect(get_external_ip(config), config["Port"], seq) def sendAnsLmpConnect(self, address, port, seq): @@ -1151,7 +1152,7 @@ def recvReqFmpInfo(self, packet_id, data, seq): # FIXME: Doesn't seem to make sense here, # e.g. on LMP server, as "FmpInfo" packet # server = self.session.join_server(index) - config = get_config("FMP") + config = ServerConfig("FMP") fmp_addr = get_external_ip(config) fmp_port = config["Port"] fmp_data = pati.FmpData() diff --git a/mh/pat_item.py b/mh/pat_item.py index 7d93455..1910a10 100644 --- a/mh/pat_item.py +++ b/mh/pat_item.py @@ -10,8 +10,8 @@ from mh.constants import pad from mh.state_models import Server, Gate, City -from other.utils import to_bytearray, get_config, get_external_ip, \ - GenericUnpacker +from other.config import ServerConfig +from other.utils import to_bytearray, get_external_ip, GenericUnpacker class ItemType: @@ -1027,7 +1027,7 @@ def pack(self): def get_fmp_servers(session, first_index, count): assert first_index > 0, "Invalid list index" - config = get_config("FMP") + config = ServerConfig("FMP") fmp_addr = get_external_ip(config) fmp_port = config["Port"] diff --git a/other/config.py b/other/config.py new file mode 100644 index 0000000..9ce5bd0 --- /dev/null +++ b/other/config.py @@ -0,0 +1,233 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2021-2025 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""Config module.""" + +from collections import OrderedDict + +try: + from collections.abc import Callable # noqa: F401 + from typing import TYPE_CHECKING, Any + IS_PYTHON2 = False +except ImportError: + TYPE_CHECKING = False # type: ignore + IS_PYTHON2 = True # type: ignore + +if TYPE_CHECKING: + OrderedDictT = OrderedDict[str, Any] + from configparser import RawConfigParser + from argparse import ArgumentParser # noqa: F401 +else: # workaround to avoid some type hinting issues + OrderedDictT = OrderedDict + if IS_PYTHON2: + from ConfigParser import RawConfigParser + else: + from configparser import RawConfigParser + + +CONFIG_FILE = "config.ini" + + +class ConfigLoader(RawConfigParser): + """Generic INI config loader class.""" + def __init__(self, config_path): + # type: (str) -> None + RawConfigParser.__init__(self, allow_no_value=True) + # Override this member to avoid options being lowercased + self.optionxform = lambda optionstr: str(optionstr) + self.read(config_path) + self._config_path = config_path + + def reload(self): + # type: () -> None + self.read(self._config_path) + + +class ConfigSection(OrderedDictT): + """ConfigSection helper class. + + This class loads all options from a config section into a dict. + """ + BOOL = tuple() # type: tuple[str, ...] + INT = tuple() # type: tuple[str, ...] + FLOAT = tuple() # type: tuple[str, ...] + STR = tuple() # type: tuple[str, ...] + # Special cases + SP = {} # type: dict[str, Callable[[ConfigLoader, str], str]] + + def __init__(self, section_name, config_path=CONFIG_FILE): + # type: (str, str) -> None + super(ConfigSection, self).__init__() + self._config = ConfigLoader(config_path) + self._section_name = section_name + self._config_path = config_path + self._load_section() + self._validate() + + def _load_section(self): + # type: () -> None + for option in self._config.options(self._section_name): + self[option] = ( + self._config.getboolean(self._section_name, option) + if option in self.BOOL + else + self._config.getint(self._section_name, option) + if option in self.INT + else + self._config.getfloat(self._section_name, option) + if option in self.FLOAT + else + self.SP[option](self._config, self._section_name) + if option in self.SP + else + self._config.get(self._section_name, option) + ) + + def _validate(self): + # type: () -> None + missing_options = [] # type: list[str] + for typed_options in (self.BOOL, self.INT, self.FLOAT, self.STR): + missing_options.extend( + option + for option in typed_options + if option not in self + ) + missing_options.extend( + option + for option in self.SP + if option not in self + ) + message = '{}: section "{}" has missing option(s): {}'.format( + self._config_path, self._section_name, ", ".join(missing_options) + ) + assert not missing_options, message + + def reload(self): + # type: () -> None + """Reload the config.""" + self._config.reload() + self._load_section() + self._validate() + + +class BaseServerConfig(ConfigSection): + INT = ("Port",) + BOOL = ("UseSSL", "LogToConsole", "LogToFile", "LogToWindow") + STR = ("IP", "ExternalIP", "Name", "LogFilename") + SP = { + "SSLCert": + lambda cfg, section: cfg.get(section, "SSLCert") + or cfg.get("SSL", "DefaultCert"), + "SSLKey": + lambda cfg, section: cfg.get(section, "SSLKey") + or cfg.get("SSL", "DefaultKey") + } + + +class ServerConfig(BaseServerConfig): + """OPN/LMP/FMP/RFP server config.""" + INT = BaseServerConfig.INT + ("MaxThread",) + + +class MySQLConfig(ConfigSection): + BOOL = ("enabled",) + STR = ("user", "password", "host", "database", + "ssl_ca", "ssl_cert", "ssl_key") + + def __init__(self, section_name="MYSQL", config_path=CONFIG_FILE): + # type: (str, str) -> None + super(MySQLConfig, self).__init__(section_name, config_path) + + def is_enabled(self): + # type: () -> bool + return self["enabled"] # type: ignore + + def connect_kwargs(self): + # type: () -> dict[str, Any] + from mysql.connector.constants import ClientFlag + kwargs = { + k: v for k, v in self.items() + if k not in ("enabled",) + } # type: dict[str, Any] + kwargs.update({ + "charset": "utf8", + "autocommit": True, + "client_flags": [ClientFlag.SSL] if kwargs["ssl_ca"] else None, + "ssl_ca": kwargs["ssl_ca"] or None, + "ssl_cert": kwargs["ssl_cert"] or None, + "ssl_key": kwargs["ssl_key"] or None + }) + return kwargs + + +# TODO: Backport latest_patch and central config code + + +def argparse_from_config(config): + # type: (BaseServerConfig) -> ArgumentParser + """Argument parser from config.""" + import argparse + + def typebool(s): + # type: (bool | str) -> bool + if isinstance(s, bool): + return s + s = s.lower() + if s in ("on", "yes", "y", "true", "t", "1"): + return True + elif s in ("off", "no", "n", "false", "f", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-i", "--interactive", action="store_true", + dest="interactive", + help="create an interactive shell") + parser.add_argument("-d", "--debug_mode", action="store_true", + dest="debug_mode", + help="enable debug mode, disabling timeouts and \ + lower logging verbosity level") + parser.add_argument("-a", "--address", action="store", type=str, + default=config["IP"], dest="address", + help="set server address") + parser.add_argument("-p", "--port", action="store", type=int, + default=config["Port"], dest="port", + help="set server port") + parser.add_argument("-n", "--name", action="store", type=str, + default=config["Name"], dest="name", + help="set server name") + parser.add_argument("-s", "--use-ssl", action="store", type=typebool, + default=config["UseSSL"], dest="use_ssl", + help="use SSL protocol") + parser.add_argument("-c", "--ssl-cert", action="store", type=str, + default=config["SSLCert"], dest="ssl_cert", + help="set server SSL certificate") + parser.add_argument("-k", "--ssl-key", action="store", type=str, + default=config["SSLKey"], dest="ssl_key", + help="set server SSL private key") + parser.add_argument("-l", "--log-filename", action="store", type=str, + default=config["LogFilename"], dest="log_filename", + help="set server log filename") + parser.add_argument("--log-to-file", action="store", type=typebool, + default=config["LogToFile"], dest="log_to_file", + help="log output to file") + parser.add_argument("--log-to-console", action="store", type=typebool, + default=config["LogToConsole"], dest="log_to_console", + help="log output to console") + parser.add_argument("--log-to-window", action="store", type=typebool, + default=config["LogToWindow"], dest="log_to_window", + help="log output to a new window") + parser.add_argument("--dry-run", action="store_true", + dest="dry_run", + help="dry run to test the server") + parser.add_argument("-t", "--no-timeout", action="store_true", + dest="no_timeout", + help="disable player timeouts") + if "MaxThread" in config: + parser.add_argument("--max-thread", action="store", type=int, + default=config["MaxThread"], dest="max_thread", + help="log output to a new window") + return parser diff --git a/other/utils.py b/other/utils.py index b55bb61..ed1b3d0 100644 --- a/other/utils.py +++ b/other/utils.py @@ -13,19 +13,18 @@ from collections import namedtuple from functools import partial from logging.handlers import TimedRotatingFileHandler +from other.config import ServerConfig, argparse_from_config from other.debug import register_debug_signal, dry_run try: # Python 2 basestring # str, unicode - import ConfigParser except NameError: # Python 3 basestring = str - import configparser as ConfigParser from typing import Any # noqa: F401 -CONFIG_FILE = "config.ini" + LOG_FOLDER = "logs" @@ -216,59 +215,6 @@ def create_logger(name, level=logging.DEBUG, log_to_file="", return logger -def get_config(name, config_file=CONFIG_FILE): - """Get server config.""" - config = ConfigParser.RawConfigParser(allow_no_value=True) - config.read(config_file) - return { - "IP": config.get(name, "IP"), - "ExternalIP": config.get(name, "ExternalIP"), - "Port": config.getint(name, "Port"), - "Name": config.get(name, "Name"), - "MaxThread": config.getint(name, "MaxThread"), - "UseSSL": config.getboolean(name, "UseSSL"), - "SSLCert": - config.get(name, "SSLCert") or - config.get("SSL", "DefaultCert"), - "SSLKey": - config.get(name, "SSLKey") or - config.get("SSL", "DefaultKey"), - "LogFilename": config.get(name, "LogFilename"), - "LogToConsole": config.getboolean(name, "LogToConsole"), - "LogToFile": config.getboolean(name, "LogToFile"), - "LogToWindow": config.getboolean(name, "LogToWindow"), - } - - -def get_mysql_config(name, config_file=CONFIG_FILE): - """Get MySQL config.""" - config = ConfigParser.RawConfigParser(allow_no_value=True) - config.read(config_file) - ssl_ca = config.get(name, "ssl_ca") or None - from mysql.connector.constants import ClientFlag - return { - "charset": "utf8", - "autocommit": True, - "user": config.get(name, "User"), - "password": config.get(name, "Password"), - "host": config.get(name, "Host"), - "database": config.get(name, "database"), - "client_flags": [ClientFlag.SSL] if ssl_ca else None, - "ssl_ca": ssl_ca, - "ssl_cert": config.get(name, "ssl_cert") or None, - "ssl_key": config.get(name, "ssl_key") or None - } - - -def is_mysql_enabled(name, config_file=CONFIG_FILE): - config = ConfigParser.RawConfigParser(allow_no_value=True) - config.read(config_file) - return config.getboolean(name, "Enabled") - - -# TODO: Backport latest_patch and central config code - - def get_default_ip(): # type: () -> str """Get the default IP address""" @@ -295,69 +241,6 @@ def get_external_ip(config): return config["ExternalIP"] or get_ip(config["IP"]) -def argparse_from_config(config): - """Argument parser from config.""" - import argparse - - def typebool(s): - if isinstance(s, bool): - return s - s = s.lower() - if s in ("on", "yes", "y", "true", "t", "1"): - return True - elif s in ("off", "no", "n", "false", "f", "0"): - return False - else: - raise argparse.ArgumentTypeError("Boolean value expected.") - - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("-i", "--interactive", action="store_true", - dest="interactive", - help="create an interactive shell") - parser.add_argument("-d", "--debug_mode", action="store_true", - dest="debug_mode", - help="enable debug mode, disabling timeouts and \ - lower logging verbosity level") - parser.add_argument("-a", "--address", action="store", type=str, - default=config["IP"], dest="address", - help="set server address") - parser.add_argument("-p", "--port", action="store", type=int, - default=config["Port"], dest="port", - help="set server port") - parser.add_argument("-n", "--name", action="store", type=str, - default=config["Name"], dest="name", - help="set server name") - parser.add_argument("-s", "--use-ssl", action="store", type=typebool, - default=config["UseSSL"], dest="use_ssl", - help="use SSL protocol") - parser.add_argument("-c", "--ssl-cert", action="store", type=str, - default=config["SSLCert"], dest="ssl_cert", - help="set server SSL certificate") - parser.add_argument("-k", "--ssl-key", action="store", type=str, - default=config["SSLKey"], dest="ssl_key", - help="set server SSL private key") - parser.add_argument("-l", "--log-filename", action="store", type=str, - default=config["LogFilename"], dest="log_filename", - help="set server log filename") - parser.add_argument("--log-to-file", action="store", type=typebool, - default=config["LogToFile"], dest="log_to_file", - help="log output to file") - parser.add_argument("--log-to-console", action="store", type=typebool, - default=config["LogToConsole"], dest="log_to_console", - help="log output to console") - parser.add_argument("--log-to-window", action="store", type=typebool, - default=config["LogToWindow"], dest="log_to_window", - help="log output to a new window") - parser.add_argument("--dry-run", action="store_true", - dest="dry_run", - help="dry run to test the server") - parser.add_argument("-t", "--no-timeout", action="store_true", - dest="no_timeout", - help="disable player timeouts") - return parser - - def wii_ssl_wrap_socket(sock, ssl_cert, ssl_key): """SSL wrapper for network sockets aiming Wii compatibility. @@ -398,7 +281,7 @@ def create_server(server_class, server_handler, address="0.0.0.0", port=8200, name="Server", max_thread=0, use_ssl=True, ssl_cert="server.crt", ssl_key="server.key", log_to_file=True, log_filename="server.log", - log_to_console=True, log_to_window=False, legacy_ssl=False, + log_to_console=True, log_to_window=False, debug_mode=False, no_timeout=False): """Create a server, its logger and the SSL context if needed.""" logger = create_logger( @@ -409,9 +292,11 @@ def create_server(server_class, server_handler, if not use_ssl: ssl_cert = None ssl_key = None - return server_class((address, port), server_handler, max_thread, logger, - debug_mode, ssl_cert=ssl_cert, ssl_key=ssl_key, - no_timeout=no_timeout) + return server_class( + (address, port), server_handler, + max_thread_count=max_thread, logger=logger, debug_mode=debug_mode, + ssl_cert=ssl_cert, ssl_key=ssl_key, no_timeout=no_timeout + ) server_base = namedtuple("ServerBase", ["name", "cls", "handler"]) @@ -422,7 +307,7 @@ def create_server_from_base(name, server_class, server_handler, args=None): If args is None, sys.argv is used (see ArgumentParser.parser_args). """ - config = get_config(name) + config = ServerConfig(name) # TODO: Backport central config code if needed parser = argparse_from_config(config) args = parser.parse_args(args) From aeea6eb881e37ef927b3eb06ed7d6610aec18fd1 Mon Sep 17 00:00:00 2001 From: Sepalani Date: Sat, 1 Nov 2025 20:19:08 +0400 Subject: [PATCH 3/5] other: Add a comment regarding TIME_STATE constant --- mh/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mh/constants.py b/mh/constants.py index b3efc77..1cda8f1 100644 --- a/mh/constants.py +++ b/mh/constants.py @@ -280,7 +280,7 @@ def make_binary_trading_post(): FMP_VERSION = 1 # TODO: Backport central and NATNEG constants -TIME_STATE = 0 +TIME_STATE = 0 # FIXME: Unused since quest rotation added, see time_utils IS_JAP = False From c461fae7668ed1f2618c0e7795afbc3d5300d09a Mon Sep 17 00:00:00 2001 From: Sepalani Date: Sat, 1 Nov 2025 20:54:20 +0400 Subject: [PATCH 4/5] other: Add config field to enable/disable servers --- config.ini | 4 ++++ master_server.py | 7 ++++--- mh/pat.py | 2 +- mh/server.py | 2 +- other/config.py | 33 ++++++++++++++++++++++++--------- other/python.py | 10 ++++++++++ other/utils.py | 20 +++++++++++++++----- 7 files changed, 59 insertions(+), 19 deletions(-) create mode 100644 other/python.py diff --git a/config.ini b/config.ini index 5d8edc5..e8ffdff 100644 --- a/config.ini +++ b/config.ini @@ -13,6 +13,7 @@ ssl_cert = ssl_key = [OPN] +Enabled = ON IP = 0.0.0.0 ExternalIP = Port = 8200 @@ -27,6 +28,7 @@ LogToFile = ON LogToWindow = ON [LMP] +Enabled = ON IP = 0.0.0.0 ExternalIP = Port = 8201 @@ -41,6 +43,7 @@ LogToFile = ON LogToWindow = ON [FMP] +Enabled = ON IP = 0.0.0.0 ExternalIP = Port = 8202 @@ -55,6 +58,7 @@ LogToFile = ON LogToWindow = ON [RFP] +Enabled = ON IP = 0.0.0.0 ExternalIP = Port = 8203 diff --git a/master_server.py b/master_server.py index dbe650d..47c187b 100644 --- a/master_server.py +++ b/master_server.py @@ -23,8 +23,9 @@ def create_servers(server_args): has_ui = False for module in (OPN, LMP, FMP, RFP): server, args = create_server_from_base(*module.BASE, args=server_args) - has_ui = has_ui or args.log_to_window - servers.append(server) + if server: + has_ui = has_ui or args.log_to_window + servers.append(server) return servers, has_ui @@ -73,7 +74,7 @@ def interactive_mode(local=locals()): if args.interactive: t.join() except KeyboardInterrupt: - print("Interrupt key was pressed, closing server...") + print("Interrupt key was pressed, closing servers...") except Exception: print('Unexpected exception caught...') traceback.print_exc() diff --git a/mh/pat.py b/mh/pat.py index a4db9d7..9df0753 100644 --- a/mh/pat.py +++ b/mh/pat.py @@ -42,7 +42,7 @@ class PatServer(server.BasicPatServer, Logger): def __init__(self, address, handler_class, max_thread_count=0, logger=None, debug_mode=False, ssl_cert=None, ssl_key=None, - no_timeout=False): + no_timeout=False, **kwargs): server.BasicPatServer.__init__( self, address, handler_class, max_thread_count, ssl_cert=ssl_cert, ssl_key=ssl_key diff --git a/mh/server.py b/mh/server.py index 8ba056d..0e68c7c 100644 --- a/mh/server.py +++ b/mh/server.py @@ -374,4 +374,4 @@ def close(self): self.selector = None self.worker_threads = [] self.__shutdown_request = False - self.info('Server Closed') + self.info('Server closed') diff --git a/other/config.py b/other/config.py index 9ce5bd0..61d9ded 100644 --- a/other/config.py +++ b/other/config.py @@ -6,21 +6,17 @@ from collections import OrderedDict -try: - from collections.abc import Callable # noqa: F401 - from typing import TYPE_CHECKING, Any - IS_PYTHON2 = False -except ImportError: - TYPE_CHECKING = False # type: ignore - IS_PYTHON2 = True # type: ignore +from other.python import PYTHON_VERSION, TYPE_CHECKING if TYPE_CHECKING: + from collections.abc import Callable # noqa: F401 + from typing import Any OrderedDictT = OrderedDict[str, Any] from configparser import RawConfigParser from argparse import ArgumentParser # noqa: F401 else: # workaround to avoid some type hinting issues OrderedDictT = OrderedDict - if IS_PYTHON2: + if PYTHON_VERSION == 2: from ConfigParser import RawConfigParser else: from configparser import RawConfigParser @@ -112,8 +108,9 @@ def reload(self): class BaseServerConfig(ConfigSection): + SERVER_NAMES = tuple() # type: tuple[str, ...] INT = ("Port",) - BOOL = ("UseSSL", "LogToConsole", "LogToFile", "LogToWindow") + BOOL = ("UseSSL", "LogToConsole", "LogToFile", "LogToWindow", "Enabled") STR = ("IP", "ExternalIP", "Name", "LogFilename") SP = { "SSLCert": @@ -124,9 +121,18 @@ class BaseServerConfig(ConfigSection): or cfg.get("SSL", "DefaultKey") } + def to_argument_parser(self): + # type: () -> ArgumentParser + """Create a generic ArgumentParser from the server config. + + It can be used in a main function to parse command-line arguments.""" + # TODO: Move the code here when the refactoring is completed + return argparse_from_config(self) + class ServerConfig(BaseServerConfig): """OPN/LMP/FMP/RFP server config.""" + SERVER_NAMES = ("OPN", "LMP", "FMP", "RFP") INT = BaseServerConfig.INT + ("MaxThread",) @@ -164,6 +170,15 @@ def connect_kwargs(self): # TODO: Backport latest_patch and central config code +def config_from_name(name): + # type: (str) -> ServerConfig | CentralConfig + """Return the server config based on its name.""" + if name in ServerConfig.SERVER_NAMES: + return ServerConfig(name) + else: + raise NotImplementedError() + + def argparse_from_config(config): # type: (BaseServerConfig) -> ArgumentParser """Argument parser from config.""" diff --git a/other/python.py b/other/python.py new file mode 100644 index 0000000..5974a14 --- /dev/null +++ b/other/python.py @@ -0,0 +1,10 @@ +#! /usr/bin/env python +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (C) 2025 MH3SP Server Project +# SPDX-License-Identifier: AGPL-3.0-or-later +"""Python 2/3 helpers.""" + +import sys + +PYTHON_VERSION = sys.version_info.major +TYPE_CHECKING = False diff --git a/other/utils.py b/other/utils.py index ed1b3d0..495c75b 100644 --- a/other/utils.py +++ b/other/utils.py @@ -13,7 +13,7 @@ from collections import namedtuple from functools import partial from logging.handlers import TimedRotatingFileHandler -from other.config import ServerConfig, argparse_from_config +from other.config import config_from_name from other.debug import register_debug_signal, dry_run try: @@ -32,34 +32,40 @@ class Logger(object): """Generic logging class.""" def set_logger(self, logger): + # type: (Logger) -> None """Set logger.""" self.logger = logger def debug(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log a debug message.""" if not hasattr(self, "logger"): return return self.logger.debug(msg, *args, **kwargs) def info(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log a message.""" if not hasattr(self, "logger"): return return self.logger.info(msg, *args, **kwargs) def warning(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log a warning message.""" if not hasattr(self, "logger"): return return self.logger.warning(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log an error message.""" if not hasattr(self, "logger"): return return self.logger.error(msg, *args, **kwargs) def critical(self, msg, *args, **kwargs): + # type: (str, *Any, **Any) -> None """Log a critical message.""" if not hasattr(self, "logger"): return @@ -282,7 +288,7 @@ def create_server(server_class, server_handler, use_ssl=True, ssl_cert="server.crt", ssl_key="server.key", log_to_file=True, log_filename="server.log", log_to_console=True, log_to_window=False, - debug_mode=False, no_timeout=False): + debug_mode=False, no_timeout=False, **kwargs): """Create a server, its logger and the SSL context if needed.""" logger = create_logger( name, level=logging.DEBUG if debug_mode else logging.INFO, @@ -295,7 +301,8 @@ def create_server(server_class, server_handler, return server_class( (address, port), server_handler, max_thread_count=max_thread, logger=logger, debug_mode=debug_mode, - ssl_cert=ssl_cert, ssl_key=ssl_key, no_timeout=no_timeout + ssl_cert=ssl_cert, ssl_key=ssl_key, no_timeout=no_timeout, + **kwargs ) @@ -307,9 +314,11 @@ def create_server_from_base(name, server_class, server_handler, args=None): If args is None, sys.argv is used (see ArgumentParser.parser_args). """ - config = ServerConfig(name) + config = config_from_name(name) + if not config["Enabled"]: + return None, args # TODO: Backport central config code if needed - parser = argparse_from_config(config) + parser = config.to_argument_parser() args = parser.parse_args(args) kwargs = { k: v for k, v in vars(args).items() @@ -324,6 +333,7 @@ def server_main(name, server_class, server_handler): server, args = create_server_from_base(name, server_class, server_handler) + assert server, "Server disabled by the config file" try: import threading From a45b0a25833ea6e4d2b7dcc725617a4cc585954d Mon Sep 17 00:00:00 2001 From: Sepalani Date: Wed, 5 Nov 2025 17:55:44 +0400 Subject: [PATCH 5/5] Fix Thread.join() edge cases Add some more type hints --- master_server.py | 48 ++++++++++++++++++++++++++++--------------- mh/server.py | 16 ++++++--------- other/debug.py | 49 +++++++++++++++++++++++++++++--------------- other/ui.py | 29 +++++++++++++++++++------- other/utils.py | 53 +++++++++++++++++++++++++++++++++++------------- 5 files changed, 132 insertions(+), 63 deletions(-) diff --git a/master_server.py b/master_server.py index 47c187b..eac27d9 100644 --- a/master_server.py +++ b/master_server.py @@ -14,22 +14,33 @@ import rfp_server as RFP from other.debug import register_debug_signal, dry_run +from other.python import TYPE_CHECKING from other.utils import create_server_from_base +if TYPE_CHECKING: + from argparse import Namespace # noqa: F401 + from collections.abc import Sequence # noqa: F401 + from typing import Any # noqa: F401 + + from mh.server import BasicPatServer # noqa: F401 + def create_servers(server_args): + # type: (Sequence[str]) -> tuple[list[BasicPatServer], bool] """Create servers and check if it has ui.""" - servers = [] + servers = [] # type: list[BasicPatServer] has_ui = False for module in (OPN, LMP, FMP, RFP): - server, args = create_server_from_base(*module.BASE, args=server_args) - if server: + server, args = create_server_from_base(*module.BASE, + cmd_args=server_args) # type: ignore[misc] # noqa: E501 + if server and args: has_ui = has_ui or args.log_to_window servers.append(server) return servers, has_ui def main(args): + # type: (Namespace) -> None """Master server main function.""" register_debug_signal() @@ -42,37 +53,40 @@ def main(args): for server in servers ] # TODO: Backport cache's logic (i.e. new thread, maintain_connection) - for thread in threads: - thread.start() def interactive_mode(local=locals()): + # type: (dict[str, Any]) -> None """Run an interactive python interpreter in another thread.""" import code code.interact(local=local) + repl_thread = threading.Thread(target=interactive_mode) + if has_ui: from other.ui import update as ui_update - ui_update() + else: + def ui_update(): + pass try: + ui_update() + for server_thread in threads: + server_thread.start() + if args.interactive: - t = threading.Thread(target=interactive_mode) - t.start() + repl_thread.start() if args.dry_run: dry_run() while threads: - for thread in threads: - if has_ui: - ui_update() - if not thread.is_alive(): - threads.remove(thread) + for server_thread in threads: + ui_update() + if not server_thread.is_alive(): + threads.remove(server_thread) break - thread.join(0.1) + server_thread.join(0.1) - if args.interactive: - t.join() except KeyboardInterrupt: print("Interrupt key was pressed, closing servers...") except Exception: @@ -82,6 +96,8 @@ def interactive_mode(local=locals()): finally: for server in servers: server.close() + if args.interactive and repl_thread.is_alive(): + repl_thread.join() if __name__ == "__main__": diff --git a/mh/server.py b/mh/server.py index 0e68c7c..dda20ec 100644 --- a/mh/server.py +++ b/mh/server.py @@ -12,14 +12,13 @@ import traceback from mh.time_utils import Timer +from other.python import PYTHON_VERSION, TYPE_CHECKING from other.utils import wii_ssl_wrap_socket -try: - # Python 3 +if TYPE_CHECKING or PYTHON_VERSION == 3: import queue import selectors -except ImportError: - # Python 2 +else: import Queue as queue import externals.selectors2 as selectors @@ -183,11 +182,7 @@ def fileno(self): return self.socket.fileno() def initialize_workers(self): - """Initialize workers queues/threads. - - This needs to be deferred, otherwise the close method might try to - join threads that aren't started yet when an error occurs early. - """ + """Initialize workers queues/threads.""" for n in range(self.max_threads): thread_queue = queue.Queue() thread = threading.Thread( @@ -368,7 +363,8 @@ def close(self): q.put((None, None, None), block=True) for t in self.worker_threads: - t.join() + if t.is_alive(): + t.join() self.worker_queues = [] self.selector = None diff --git a/other/debug.py b/other/debug.py index ad404b2..0574f7a 100644 --- a/other/debug.py +++ b/other/debug.py @@ -17,17 +17,23 @@ import traceback import signal -try: - # Python 2 - import ConfigParser -except ImportError: - # Python 3 - import configparser as ConfigParser +from other.python import PYTHON_VERSION, TYPE_CHECKING + +if TYPE_CHECKING or PYTHON_VERSION == 3: + from configparser import RawConfigParser +else: + from ConfigParser import RawConfigParser + +if TYPE_CHECKING: + from collections.abc import Callable # noqa: F401 + from types import FrameType # noqa: F401 + from typing import Any # noqa: F401 DEBUG_INI_PATH = "debug.ini" def debugpy_handler(sig, frame, addr="127.0.0.1", port="5678", **kwargs): + # type: (int, FrameType, str, str, **Any) -> None """Handler for debugpy on Visual Studio and VS Code. References: @@ -35,7 +41,7 @@ def debugpy_handler(sig, frame, addr="127.0.0.1", port="5678", **kwargs): https://code.visualstudio.com/docs/python/debugging https://learn.microsoft.com/visualstudio/python/debugging-python-in-visual-studio """ - import debugpy + import debugpy # type: ignore s = (addr, int(port)) # config's items are str debugpy.listen(s) print("Waiting for client on {}:{}\n".format(*s)) @@ -43,6 +49,7 @@ def debugpy_handler(sig, frame, addr="127.0.0.1", port="5678", **kwargs): def trepan_handler(sig, frame, **kwargs): + # type: (int, FrameType, **Any) -> None """Handler for trepan2/trepan3k. References: @@ -50,22 +57,24 @@ def trepan_handler(sig, frame, **kwargs): https://github.com/rocky/python3-trepan/ https://python2-trepan.readthedocs.io/en/latest/entry-exit.html """ - from trepan.api import debug + from trepan.api import debug # type: ignore debug() def pudb_handler(sig, frame, **kwargs): + # type: (int, FrameType, **Any) -> None """Handler for pudb on Linux and Cygwin. References: https://github.com/inducer/pudb https://documen.tician.de/pudb/ """ - import pudb + import pudb # type: ignore pudb.set_trace() def breakpoint_handler(sig, frame, **kwargs): + # type: (int, FrameType, **Any) -> None """PDB/breakpoint handler. References: @@ -80,6 +89,7 @@ def breakpoint_handler(sig, frame, **kwargs): def code_interact_handler(sig, frame, **kwargs): + # type: (int, FrameType, **Any) -> None """Python interpreter handler. References: @@ -99,25 +109,29 @@ def code_interact_handler(sig, frame, **kwargs): "PUDB": pudb_handler, "BREAKPOINT": breakpoint_handler, "CODE": code_interact_handler -} +} # type: dict[str, Callable[..., Any]] def load_config(path=DEBUG_INI_PATH): - config = ConfigParser.RawConfigParser() + # type: (str) -> RawConfigParser + config = RawConfigParser() config.read(path) return config def load_handler_config(name, config): + # type: (str, RawConfigParser) -> dict[str, str] | None if config.has_section(name) and config.getboolean(name, "Enabled"): return { k.lower(): v for k, v in config.items(name) if k.lower() != "enabled" } + return None def debug_signal_handler(sig, frame): + # type: (int, FrameType) -> None """Default debug signal handler. Might raise EINTR/IOError when occuring during some syscalls on Python 2. @@ -148,6 +162,7 @@ def debug_signal_handler(sig, frame): def register_debug_signal(fn=debug_signal_handler): + # type: (Callable[..., Any]) -> None """Register a debug handler on SIGBREAK (Windows) or SIGUSR1 (Linux). On Windows, press CTRL+Pause/Break to trigger. @@ -156,17 +171,19 @@ def register_debug_signal(fn=debug_signal_handler): Will raise ValueError exception if not called from the main thread. """ if hasattr(signal, "SIGBREAK"): - signal.signal(signal.SIGBREAK, fn) + signal.signal(signal.SIGBREAK, fn) # type: ignore else: - signal.signal(signal.SIGUSR1, fn) + signal.signal(signal.SIGUSR1, fn) # type: ignore def dry_run(delay=10.): + # type: (float) -> None """Dry run test.""" from threading import Timer - try: - from thread import interrupt_main - except ImportError: + + if TYPE_CHECKING or PYTHON_VERSION == 3: from _thread import interrupt_main + else: + from thread import interrupt_main Timer(delay, interrupt_main).start() diff --git a/other/ui.py b/other/ui.py index 7addf11..ef11f76 100644 --- a/other/ui.py +++ b/other/ui.py @@ -5,19 +5,28 @@ """UI helper module.""" import logging -try: - # Python 3.x + +from other.python import PYTHON_VERSION, TYPE_CHECKING + +if TYPE_CHECKING or PYTHON_VERSION == 3: import tkinter as tk import tkinter.scrolledtext as ScrolledText from queue import Queue -except ImportError: - # Python 2.x +elif PYTHON_VERSION == 2: import Tkinter as tk import ScrolledText from Queue import Queue -WINDOWS = [] -EMITTERS = Queue() +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any + QueueT = Queue[Callable[[], Any]] +else: + QueueT = Queue + + +WINDOWS = [] # type: list[LoggerTk] +EMITTERS = QueueT() class LoggingHandler(logging.Handler): @@ -29,10 +38,12 @@ class LoggingHandler(logging.Handler): """ def __init__(self, text): + # type: (ScrolledText.ScrolledText) -> None logging.Handler.__init__(self) self.text = text def emit(self, record): + # type: (logging.LogRecord) -> None msg = self.format(record) def append(): @@ -40,7 +51,8 @@ def append(): self.text.configure(state='normal') self.text.insert(tk.END, msg + '\n') self.text.configure(state='disabled') - self.text.yview(tk.END) # Autoscroll to the bottom + # Autoscroll to the bottom + self.text.yview(tk.END) # type: ignore # Won't work on Python3.x # self.text.after(0, append) @@ -51,6 +63,7 @@ class LoggerTk(tk.Tk): """Create a logging window.""" def __init__(self, *args, **kwargs): + # type: (*Any, **Any) -> None tk.Tk.__init__(self, *args, **kwargs) text = ScrolledText.ScrolledText(self, state='disabled') @@ -63,10 +76,12 @@ def __init__(self, *args, **kwargs): self.handler = LoggingHandler(text) def get_handler(self): + # type: () -> LoggingHandler """Return the window's logging.Handler instance.""" return self.handler def set_logger(self, logger): + # type: (logging.Logger) -> None """Add the window's logging.Handler to the logger.""" logger.addHandler(self.handler) diff --git a/other/utils.py b/other/utils.py index 495c75b..7f5f220 100644 --- a/other/utils.py +++ b/other/utils.py @@ -15,14 +15,17 @@ from logging.handlers import TimedRotatingFileHandler from other.config import config_from_name from other.debug import register_debug_signal, dry_run +from other.python import PYTHON_VERSION, TYPE_CHECKING -try: - # Python 2 - basestring # str, unicode -except NameError: - # Python 3 - basestring = str - from typing import Any # noqa: F401 +if TYPE_CHECKING or PYTHON_VERSION == 3: + basestring = str # Python 2: str, unicode +if TYPE_CHECKING: + from argparse import Namespace # noqa: F401 + from collections.abc import Sequence # noqa: F401 + from ssl import SSLSocket # noqa: F401 + from typing import Any, NamedTuple # noqa: F401 + + from mh.pat import PatServer, PatRequestHandler LOG_FOLDER = "logs" @@ -137,6 +140,7 @@ def handler(self, name, unpack_function, pack_function, def to_bytearray(data): + # type: (Any) -> bytearray """Python2/3 bytearray helper.""" if isinstance(data, basestring): return bytearray((ord(c) % 256 for c in data)) @@ -147,10 +151,12 @@ def to_bytearray(data): def to_bytes(data): + # type: (Any) -> bytes return bytes(to_bytearray(data)) def to_str(data): + # type: (Any) -> str """Python2/3 str helper.""" if isinstance(data, str): return data @@ -158,20 +164,24 @@ def to_str(data): def pad(s, size, p=b'\0'): + # type: (bytes, int, bytes) -> bytearray data = bytearray(s + p * max(0, size-len(s))) data[-1] = 0 return data def hexdump(data): + # type: (bytes | bytearray) -> str """Get data hexdump.""" data = bytearray(data) line_format = "{line:08x} | {hex:47} | {ascii}" def hex_helper(b): + # type: (int) -> str return "{:02x}".format(b) def ascii_helper(b): + # type: (int) -> str return chr(b) if 0x20 <= b < 0x7F else '.' return "\n".join( @@ -186,6 +196,7 @@ def ascii_helper(b): def create_logger(name, level=logging.DEBUG, log_to_file="", log_to_console=False, log_to_window=False): + # type: (str, int, str, bool, bool) -> logging.Logger """Create a logger.""" logger = logging.getLogger(name) logger.setLevel(level) @@ -248,6 +259,7 @@ def get_external_ip(config): def wii_ssl_wrap_socket(sock, ssl_cert, ssl_key): + # type: (socket.socket, str, str) -> SSLSocket """SSL wrapper for network sockets aiming Wii compatibility. References: @@ -289,6 +301,7 @@ def create_server(server_class, server_handler, log_to_file=True, log_filename="server.log", log_to_console=True, log_to_window=False, debug_mode=False, no_timeout=False, **kwargs): + # type: (type[PatServer], type[PatRequestHandler], str, int, str, int, bool, str | None, str | None, bool, str, bool, bool, bool, bool, **Any) -> PatServer # noqa: E501 """Create a server, its logger and the SSL context if needed.""" logger = create_logger( name, level=logging.DEBUG if debug_mode else logging.INFO, @@ -306,20 +319,29 @@ def create_server(server_class, server_handler, ) -server_base = namedtuple("ServerBase", ["name", "cls", "handler"]) +if TYPE_CHECKING: + # Python 2 doesn't support the class syntax + server_base = NamedTuple("server_base", [ + ("name", str), + ("cls", type[PatServer]), + ("handler", type[PatRequestHandler]) + ]) +else: + server_base = namedtuple("ServerBase", ["name", "cls", "handler"]) -def create_server_from_base(name, server_class, server_handler, args=None): +def create_server_from_base(name, server_class, server_handler, cmd_args=None): + # type: (str, type[PatServer], type[PatRequestHandler], Sequence[str] | None) -> tuple[PatServer, Namespace] | tuple[None, None] # noqa: E501 """Create a server based on its config parameters and supplied args. If args is None, sys.argv is used (see ArgumentParser.parser_args). """ config = config_from_name(name) if not config["Enabled"]: - return None, args + return None, None # TODO: Backport central config code if needed parser = config.to_argument_parser() - args = parser.parse_args(args) + args = parser.parse_args(cmd_args) kwargs = { k: v for k, v in vars(args).items() if k not in ("interactive", "dry_run") @@ -328,12 +350,13 @@ def create_server_from_base(name, server_class, server_handler, args=None): def server_main(name, server_class, server_handler): + # type: (str, type[PatServer], type[PatRequestHandler]) -> None """Create a server main based on its config parameters.""" register_debug_signal() server, args = create_server_from_base(name, server_class, server_handler) - assert server, "Server disabled by the config file" + assert server and args, "Server disabled by the config file" try: import threading @@ -350,11 +373,13 @@ def server_main(name, server_class, server_handler): if args.log_to_window: from other.ui import update as ui_update + else: + def ui_update(): + pass while thread.is_alive(): thread.join(0.1) # Timeout allows main thread to handle signals - if args.log_to_window: - ui_update() + ui_update() except KeyboardInterrupt: server.info("Interrupt key was pressed, closing server...") except Exception: