diff --git a/config.ini b/config.ini
index 2f3733a..908d5a7 100644
--- a/config.ini
+++ b/config.ini
@@ -2,6 +2,25 @@
DefaultCert = cert/server.crt
DefaultKey = cert/server.key
+[CENTRAL]
+CentralIP = 0.0.0.0
+CentralCrossconnectPort = 8300
+CrossconnectSSL = ON
+
+[SERVER1]
+Name = Valor1
+ServerType = 1
+Capacity = 400
+IP = 0.0.0.0
+Port = 8204
+
+[SERVER2]
+Name = Greed1
+ServerType = 4
+Capacity = 400
+IP = 0.0.0.0
+Port = 8205
+
[OPN]
IP = 0.0.0.0
Port = 8200
diff --git a/fmp_server.py b/fmp_server.py
index c7efd3b..7db87c0 100644
--- a/fmp_server.py
+++ b/fmp_server.py
@@ -6,6 +6,7 @@
import struct
+from mh.state import Players, get_instance
import mh.pat_item as pati
from mh.constants import PatID4
from mh.pat import PatRequestHandler, PatServer
@@ -14,7 +15,9 @@
class FmpServer(PatServer):
"""Basic FMP server class."""
- pass
+ def close(self):
+ PatServer.close(self)
+ get_instance().close_cache()
class FmpRequestHandler(PatRequestHandler):
@@ -24,7 +27,16 @@ def recvAnsConnection(self, packet_id, data, seq):
"""AnsConnection packet."""
connection_data = pati.ConnectionData.unpack(data)
self.server.debug("Connection: {!r}".format(connection_data))
- self.sendNtcLogin(3, connection_data, seq)
+ loaded_session = self.session.session_ready(connection_data)
+ if loaded_session:
+ if get_instance().server_id != 0:
+ self.session.set_session_ready(connection_data, False)
+ loaded_session.connection = self
+ self.session = loaded_session
+ get_instance().register_pat_ticket(self.session)
+ self.sendNtcLogin(3, connection_data, seq)
+ else:
+ self.session.set_session_ready(connection_data, (self, connection_data, seq))
def sendAnsLayerDown(self, layer_id, layer_set, seq):
"""AnsLayerDown packet.
diff --git a/master_server.py b/master_server.py
index fcd8c7f..f709371 100644
--- a/master_server.py
+++ b/master_server.py
@@ -15,15 +15,18 @@
from other.debug import register_debug_signal
from other.utils import create_server_from_base
+from other.cache import Cache
-def create_servers(silent=False, debug_mode=False):
+def create_servers(server_id, silent=False, debug_mode=False, no_timeout=False):
"""Create servers and check if it has ui."""
servers = []
has_ui = False
- for module in (OPN, LMP, FMP, RFP):
+ for module in ((OPN, LMP, FMP, RFP) if server_id==0 else (FMP,)):
server, has_window = create_server_from_base(*module.BASE,
+ server_id=server_id,
silent=silent,
- debug_mode=debug_mode)
+ debug_mode=debug_mode,
+ no_timeout=no_timeout)
has_ui = has_ui or has_window
servers.append(server)
return servers, has_ui
@@ -31,10 +34,10 @@ def create_servers(silent=False, debug_mode=False):
def main(args):
"""Master server main function."""
- register_debug_signal()
-
- servers, has_ui = create_servers(silent=args.silent,
- debug_mode=args.debug_mode)
+ servers, has_ui = create_servers(server_id=args.server_id,
+ silent=args.silent,
+ debug_mode=args.debug_mode,
+ no_timeout=args.no_timeout)
threads = [
threading.Thread(
target=server.serve_forever,
@@ -42,6 +45,10 @@ def main(args):
)
for server in servers
]
+ cache = Cache(server_id=args.server_id, debug_mode=args.debug_mode,
+ log_to_file=True, log_to_console=not args.silent,
+ log_to_window=False)
+ threads.append(threading.Thread(target=cache.maintain_connection))
for thread in threads:
thread.start()
@@ -92,7 +99,15 @@ def interactive_mode(local=locals()):
help="silent console logs")
parser.add_argument("-d", "--debug_mode", action="store_true",
dest="debug_mode",
- help="enable debug mode, disabling timeouts and \
- lower logging verbosity level")
+ help="enable debug mode, \
+ raising logging verbosity level")
+ parser.add_argument("-t", "--no_timeout", action="store_true",
+ dest="no_timeout",
+ help="disable player timeouts")
+ parser.add_argument("-S", "--server_id", type=int, default=0,
+ dest="server_id",
+ help="specifies the server id used to pull info \
+ from the config file (0 for central)")
+
args = parser.parse_args()
main(args)
diff --git a/mh/constants.py b/mh/constants.py
index 4040c4a..ba09fd9 100644
--- a/mh/constants.py
+++ b/mh/constants.py
@@ -307,10 +307,16 @@ def slot(item, qty):
b"Roadmap for a list of features we are working on.",
b"
Welcome to Loc Lac!"
])
+MAINTENANCE = b"
".join([
+ b"
Monster Hunter 3 (Tri) Server Project",
+ b"
MH3SP is currently down for maintenance.",
+ b"
Please check back later!"
+])
CHARGE = b"""MH3 Server Project - No charge."""
# VULGARITY_INFO = b"""MH3 Server Project - Vulgarity info (low)."""
VULGARITY_INFO = b""
-FMP_VERSION = 1
+FMP_CENTRAL_VERSION = 1
+FMP_VERSION = 2
TIME_STATE = 0
IS_JAP = False
diff --git a/mh/database.py b/mh/database.py
index cbe65b7..4ebf359 100644
--- a/mh/database.py
+++ b/mh/database.py
@@ -22,359 +22,10 @@
BLANK_CAPCOM_ID = "******"
-RESERVE_DC_TIMEOUT = 40.0
-
-
def new_random_str(length=6):
return "".join(random.choice(CHARSET) for _ in range(length))
-class ServerType(object):
- OPEN = 1
- ROOKIE = 2
- EXPERT = 3
- RECRUITING = 4
-
-
-class LayerState(object):
- EMPTY = 1
- FULL = 2
- JOINABLE = 3
-
-
-class Lockable(object):
- def __init__(self):
- self._lock = RLock()
-
- def lock(self):
- return self
-
- def __enter__(self):
- # Returns True if lock was acquired, False otherwise
- return self._lock.acquire()
-
- def __exit__(self, *args):
- self._lock.release()
-
-
-class Players(Lockable):
- def __init__(self, capacity):
- assert capacity > 0, "Collection capacity can't be zero"
-
- self.slots = [None for _ in range(capacity)]
- self.used = 0
- super(Players, self).__init__()
-
- def get_used_count(self):
- return self.used
-
- def get_capacity(self):
- return len(self.slots)
-
- def add(self, item):
- with self.lock():
- if self.used >= len(self.slots):
- return -1
-
- item_index = self.index(item)
- if item_index != -1:
- return item_index
-
- for i, v in enumerate(self.slots):
- if v is not None:
- continue
-
- self.slots[i] = item
- self.used += 1
- return i
-
- return -1
-
- def remove(self, item):
- assert item is not None, "Item != None"
-
- with self.lock():
- if self.used < 1:
- return False
-
- if isinstance(item, int):
- if item >= self.get_capacity():
- return False
-
- self.slots[item] = None
- self.used -= 1
- return True
-
- for i, v in enumerate(self.slots):
- if v != item:
- continue
-
- self.slots[i] = None
- self.used -= 1
- return True
-
- return False
-
- def index(self, item):
- assert item is not None, "Item != None"
-
- for i, v in enumerate(self.slots):
- if v == item:
- return i
-
- return -1
-
- def clear(self):
- with self.lock():
- for i in range(self.get_capacity()):
- self.slots[i] = None
-
- def find_first(self, **kwargs):
- if self.used < 1:
- return None
-
- for p in self.slots:
- if p is None:
- continue
-
- for k, v in kwargs.items():
- if getattr(p, k) != v:
- break
- else:
- return p
-
- return None
-
- def find_by_capcom_id(self, capcom_id):
- return self.find_first(capcom_id=capcom_id)
-
- def __len__(self):
- return self.used
-
- def __iter__(self):
- if self.used < 1:
- return
-
- for i, v in enumerate(self.slots):
- if v is None:
- continue
-
- yield i, v
-
-
-class Circle(Lockable):
- def __init__(self, parent):
- # type: (City) -> None
- self.parent = parent
- self.leader = None
- self.players = Players(4)
- self.departed = False
- self.quest_id = 0
- self.embarked = False
- self.password = None
- self.remarks = None
-
- self.unk_byte_0x0e = 0
- super(Circle, self).__init__()
-
- def get_population(self):
- return len(self.players)
-
- def get_capacity(self):
- return self.players.get_capacity()
-
- def is_full(self):
- return self.get_population() == self.get_capacity()
-
- def is_empty(self):
- return self.leader is None
-
- def is_joinable(self):
- return not self.departed and not self.is_full()
-
- def has_password(self):
- return self.password is not None
-
- def reset_players(self, capacity):
- with self.lock():
- self.players = Players(capacity)
-
- def reset(self):
- with self.lock():
- self.leader = None
- self.reset_players(4)
- self.departed = False
- self.quest_id = 0
- self.embarked = False
- self.password = None
- self.remarks = None
-
- self.unk_byte_0x0e = 0
-
-
-class City(Lockable):
- LAYER_DEPTH = 3
-
- def __init__(self, name, parent):
- # type: (str, Gate) -> None
- self.name = name
- self.parent = parent
- self.state = LayerState.EMPTY
- self.players = Players(4)
- self.optional_fields = []
- self.leader = None
- self.reserved = None
- self.circles = [
- # One circle per player
- Circle(self) for _ in range(self.get_capacity())
- ]
- super(City, self).__init__()
-
- def get_population(self):
- return len(self.players)
-
- def in_quest_players(self):
- return sum(p.is_in_quest() for _, p in self.players)
-
- def get_capacity(self):
- return self.players.get_capacity()
-
- def get_state(self):
- if self.reserved:
- return LayerState.FULL
-
- size = self.get_population()
- if size == 0:
- return LayerState.EMPTY
- elif size < self.get_capacity():
- return LayerState.JOINABLE
- else:
- return LayerState.FULL
-
- def get_pathname(self):
- pathname = self.name
- it = self.parent
- while it is not None:
- pathname = it.name + "\t" + pathname
- it = it.parent
- return pathname
-
- def get_first_empty_circle(self):
- with self.lock():
- for index, circle in enumerate(self.circles):
- if circle.is_empty():
- return circle, index
- return None, None
-
- def get_circle_for(self, leader_session):
- with self.lock():
- for index, circle in enumerate(self.circles):
- if circle.leader == leader_session:
- return circle, index
- return None, None
-
- def clear_circles(self):
- with self.lock():
- for circle in self.circles:
- circle.reset()
-
- def reserve(self, reserve):
- with self.lock():
- if reserve:
- self.reserved = time.time()
- else:
- self.reserved = None
-
- def reset(self):
- with self.lock():
- self.state = LayerState.EMPTY
- self.players.clear()
- self.optional_fields = []
- self.leader = None
- self.reserved = None
- self.clear_circles()
-
-
-class Gate(object):
- LAYER_DEPTH = 2
-
- def __init__(self, name, parent, city_count=40, player_capacity=100):
- # type: (str, Server, int, int) -> None
- self.name = name
- self.parent = parent
- self.state = LayerState.EMPTY
- self.cities = [
- City("City{}".format(i), self)
- for i in range(1, city_count+1)
- ]
- self.players = Players(player_capacity)
- self.optional_fields = []
-
- def get_population(self):
- return len(self.players) + sum((
- city.get_population()
- for city in self.cities
- ))
-
- def get_capacity(self):
- return self.players.get_capacity()
-
- def get_state(self):
- size = self.get_population()
- if size == 0:
- return LayerState.EMPTY
- elif size < self.get_capacity():
- return LayerState.JOINABLE
- else:
- return LayerState.FULL
-
-
-class Server(object):
- LAYER_DEPTH = 1
-
- def __init__(self, name, server_type, gate_count=40, capacity=2000,
- addr=None, port=None):
- self.name = name
- self.parent = None
- self.server_type = server_type
- self.addr = addr
- self.port = port
- self.gates = [
- Gate("City Gate{}".format(i), self)
- for i in range(1, gate_count+1)
- ]
- self.players = Players(capacity)
-
- def get_population(self):
- return len(self.players) + sum((
- gate.get_population() for gate in self.gates
- ))
-
- def get_capacity(self):
- return self.players.get_capacity()
-
-
-def new_servers():
- servers = []
- servers.extend([
- Server("Valor{}".format(i), ServerType.OPEN)
- for i in range(1, 5)
- ])
- servers.extend([
- Server("Beginners{}".format(i), ServerType.ROOKIE)
- for i in range(1, 3)
- ])
- servers.extend([
- Server("Veterans{}".format(i), ServerType.EXPERT)
- for i in range(1, 3)
- ])
- servers.extend([
- Server("Greed{}".format(i), ServerType.RECRUITING)
- for i in range(1, 5)
- ])
- return servers
-
-
class TempDatabase(object):
"""A temporary database.
@@ -389,11 +40,8 @@ def __init__(self):
self.consoles = {
# Online support code => Capcom IDs
}
- self.sessions = {
- # PAT Ticket => Owner's session
- }
self.capcom_ids = {
- # Capcom ID => Owner's name and session
+ # Capcom ID => Hunter name
}
self.friend_requests = {
# Capcom ID => List of friend requests from Capcom IDs
@@ -402,7 +50,6 @@ def __init__(self):
# Capcom ID => List of Capcom IDs
# TODO: May need stable index, see Players class
}
- self.servers = new_servers()
def get_support_code(self, session):
"""Get the online support code or create one."""
@@ -419,244 +66,19 @@ def get_support_code(self, session):
self.consoles[support_code] = [BLANK_CAPCOM_ID] * 6
return support_code
- def new_pat_ticket(self, session):
- """Generates a new PAT ticket for the session."""
- while True:
- session.pat_ticket = new_random_str(11)
- if session.pat_ticket not in self.sessions:
- break
- self.sessions[session.pat_ticket] = session
- return session.pat_ticket
-
- def use_capcom_id(self, session, capcom_id, name=None):
- """Attach the session to the Capcom ID."""
- assert capcom_id in self.capcom_ids, "Capcom ID doesn't exist"
-
- not_in_use = self.capcom_ids[capcom_id]["session"] is None
- assert not_in_use, "Capcom ID is already in use"
-
- name = name or self.capcom_ids[capcom_id]["name"]
- self.capcom_ids[capcom_id] = {"name": name, "session": session}
+ def get_capcom_ids(self, online_support_code):
+ """Get the Capcom IDs associated with an online support code."""
+ return self.consoles[online_support_code]
- # TODO: Check if stable index is required
- if capcom_id not in self.friend_lists:
- self.friend_lists[capcom_id] = []
- if capcom_id not in self.friend_requests:
- self.friend_requests[capcom_id] = []
+ def assign_capcom_id(self, online_support_code, index, capcom_id):
+ """Assign a Capcom ID to an online support code."""
+ self.consoles[online_support_code][index] = capcom_id
- return name
+ def assign_name(self, capcom_id, name):
+ self.capcom_ids[capcom_id] = name
- def use_user(self, session, index, name):
- """Use User from the slot or create one if empty"""
- assert 1 <= index <= 6, "Invalid Capcom ID slot"
- index -= 1
- users = self.consoles[session.online_support_code]
- while users[index] == BLANK_CAPCOM_ID:
- capcom_id = new_random_str(6)
- if capcom_id not in self.capcom_ids:
- self.capcom_ids[capcom_id] = {"name": name, "session": None}
- users[index] = capcom_id
- break
- else:
- capcom_id = users[index]
- name = self.use_capcom_id(session, capcom_id, name)
- session.capcom_id = capcom_id
- session.hunter_name = name
-
- def get_session(self, pat_ticket):
- """Returns existing PAT session or None."""
- session = self.sessions.get(pat_ticket)
- if session and session.capcom_id:
- self.use_capcom_id(session, session.capcom_id, session.hunter_name)
- return session
-
- def disconnect_session(self, session):
- """Detach the session from its Capcom ID."""
- if not session.capcom_id:
- # Capcom ID isn't chosen yet with OPN/LMP servers
- return
- self.capcom_ids[session.capcom_id]["session"] = None
-
- def delete_session(self, session):
- """Delete the session from the database."""
- self.disconnect_session(session)
- pat_ticket = session.pat_ticket
- if pat_ticket in self.sessions:
- del self.sessions[pat_ticket]
-
- def get_users(self, session, first_index, count):
- """Returns Capcom IDs tied to the session."""
- users = self.consoles[session.online_support_code]
- capcom_ids = [
- (i, (capcom_id, self.capcom_ids.get(capcom_id, {})))
- for i, capcom_id in enumerate(users[:count], first_index)
- ]
- size = len(capcom_ids)
- if size < count:
- capcom_ids.extend([
- (index, (BLANK_CAPCOM_ID, {}))
- for index in range(first_index+size, first_index+count)
- ])
- return capcom_ids
-
- def join_server(self, session, index):
- if session.local_info["server_id"] is not None:
- self.leave_server(session, session.local_info["server_id"])
- server = self.get_server(index)
- server.players.add(session)
- session.local_info["server_id"] = index
- session.local_info["server_name"] = server.name
- return server
-
- def leave_server(self, session, index):
- self.get_server(index).players.remove(session)
- session.local_info["server_id"] = None
- session.local_info["server_name"] = None
-
- def get_server_time(self):
- pass
-
- def get_game_time(self):
- pass
-
- def get_servers(self):
- return self.servers
-
- def get_server(self, index):
- assert 0 < index <= len(self.servers), "Invalid server index"
- return self.servers[index - 1]
-
- def get_gates(self, server_id):
- return self.get_server(server_id).gates
-
- def get_gate(self, server_id, index):
- gates = self.get_gates(server_id)
- assert 0 < index <= len(gates), "Invalid gate index"
- return gates[index - 1]
-
- def join_gate(self, session, server_id, index):
- gate = self.get_gate(server_id, index)
- gate.parent.players.remove(session)
- gate.players.add(session)
- session.local_info["gate_id"] = index
- session.local_info["gate_name"] = gate.name
- return gate
-
- def leave_gate(self, session):
- gate = self.get_gate(session.local_info["server_id"],
- session.local_info["gate_id"])
- gate.parent.players.add(session)
- gate.players.remove(session)
- session.local_info["gate_id"] = None
- session.local_info["gate_name"] = None
-
- def get_cities(self, server_id, gate_id):
- return self.get_gate(server_id, gate_id).cities
-
- def get_city(self, server_id, gate_id, index):
- cities = self.get_cities(server_id, gate_id)
- assert 0 < index <= len(cities), "Invalid city index"
- return cities[index - 1]
-
- def reserve_city(self, server_id, gate_id, index, reserve):
- city = self.get_city(server_id, gate_id, index)
- with city.lock():
- reserved_time = city.reserved
- if reserve and reserved_time and \
- time.time()-reserved_time < RESERVE_DC_TIMEOUT:
- return False
- city.reserve(reserve)
- return True
-
- def get_all_users(self, server_id, gate_id, city_id):
- """Search for users in layers and its children.
-
- Let's assume wildcard search isn't possible for servers and gates.
- A wildcard search happens when the id is zero.
- """
- assert 0 < server_id, "Invalid server index"
- assert 0 < gate_id, "Invalid gate index"
- gate = self.get_gate(server_id, gate_id)
- users = list(gate.players)
- cities = [
- self.get_city(server_id, gate_id, city_id)
- ] if city_id else self.get_cities(server_id, gate_id)
- for city in cities:
- users.extend(list(city.players))
- return users
-
- def find_users(self, capcom_id="", hunter_name=""):
- assert capcom_id or hunter_name, "Search can't be empty"
- users = []
- for user_id, user_info in self.capcom_ids.items():
- session = user_info["session"]
- if not session:
- continue
- if capcom_id and capcom_id not in user_id:
- continue
- if hunter_name and \
- hunter_name.lower() not in user_info["name"].lower():
- continue
- users.append(session)
- return users
-
- def get_user_name(self, capcom_id):
- if capcom_id not in self.capcom_ids:
- return ""
- return self.capcom_ids[capcom_id]["name"]
-
- def create_city(self, session, server_id, gate_id, index,
- settings, optional_fields):
- city = self.get_city(server_id, gate_id, index)
- with city.lock():
- city.optional_fields = optional_fields
- city.leader = session
- city.reserved = None
- return city
-
- def join_city(self, session, server_id, gate_id, index):
- city = self.get_city(server_id, gate_id, index)
- with city.lock():
- city.parent.players.remove(session)
- city.players.add(session)
- session.local_info["city_name"] = city.name
- session.local_info["city_id"] = index
- return city
-
- def leave_city(self, session):
- city = self.get_city(session.local_info["server_id"],
- session.local_info["gate_id"],
- session.local_info["city_id"])
- with city.lock():
- city.parent.players.add(session)
- city.players.remove(session)
- if not city.get_population():
- city.reset()
- session.local_info["city_id"] = None
- session.local_info["city_name"] = None
-
- def layer_detail_search(self, server_type, fields):
- cities = []
-
- def match_city(city, fields):
- with city.lock():
- return all((
- field in city.optional_fields
- for field in fields
- ))
-
- for server in self.servers:
- if server.server_type != server_type:
- continue
- for gate in server.gates:
- if not gate.get_population():
- continue
- cities.extend([
- city
- for city in gate.cities
- if match_city(city, fields)
- ])
- return cities
+ def get_name(self, capcom_id):
+ return self.capcom_ids.get(capcom_id, "")
def add_friend_request(self, sender_id, recipient_id):
# Friend invite can be sent to arbitrary Capcom ID
@@ -692,7 +114,7 @@ def get_friends(self, capcom_id, first_index=None, count=None):
begin = 0 if first_index is None else (first_index - 1)
end = count if count is None else (begin + count)
return [
- (k, self.capcom_ids[k]["name"])
+ (k, self.capcom_ids[k])
for k in self.friend_lists[capcom_id]
if k in self.capcom_ids # Skip unknown Capcom IDs
][begin:end]
@@ -852,15 +274,45 @@ def force_update(self):
(capcom_id, friend_id)
)
+ def assign_capcom_id(self, online_support_code, index, capcom_id):
+ """Assign a Capcom ID to an online support code."""
+ self.consoles[online_support_code][index] = capcom_id
+ with self.connection as cursor:
+ cursor.execute(
+ "INSERT OR REPLACE INTO consoles VALUES (?,?,?,?)",
+ (online_support_code, index, capcom_id, "????")
+ )
+
+ def get_capcom_ids(self, online_support_code):
+ """Get list of associated Capcom IDs from an online support code."""
+ with self.connection as cursor:
+ rows = cursor.execute("SELECT slot_index, capcom_id FROM consoles WHERE support_code = '{}'".format(online_support_code))
+ ids = []
+ for index, id in rows:
+ while len(ids) < index:
+ ids.append(BLANK_CAPCOM_ID)
+ ids.append(id)
+ while len(ids) < 6:
+ ids.append(BLANK_CAPCOM_ID)
+ return ids
+
+ def get_name(self, capcom_id):
+ """Get the hunter name associated with a valid Capcom ID."""
+ with self.connection as cursor:
+ rows = cursor.execute("SELECT name FROM consoles WHERE capcom_id = '{}'".format(capcom_id))
+ names = [name for name, in rows]
+ if len(names):
+ return names[0]
+ return ""
+
def use_user(self, session, index, name):
- result = self.parent.use_user(session, index, name)
+ """Insert the current hunter's info into a selected Capcom ID slot."""
with self.connection as cursor:
cursor.execute(
"INSERT OR REPLACE INTO consoles VALUES (?,?,?,?)",
(session.online_support_code, index,
session.capcom_id, session.hunter_name)
)
- return result
def accept_friend(self, capcom_id, friend_id, accepted):
if accepted:
@@ -884,6 +336,16 @@ def delete_friend(self, capcom_id, friend_id):
)
return self.parent.delete_friend(capcom_id, friend_id)
+ def get_friends(self, capcom_id, first_index=None, count=None):
+ begin = 0 if first_index is None else (first_index - 1)
+ end = count if count is None else (begin + count)
+ with self.connection as cursor:
+ rows = cursor.execute("SELECT friend_id, name FROM friend_lists INNER JOIN consoles ON friend_lists.friend_id = consoles.capcom_id WHERE friend_lists.capcom_id = '{}'".format(capcom_id))
+ friends = []
+ for friend_id, name in rows:
+ friends.append((friend_id, name))
+ return friends[begin:end]
+
class DebugDatabase(TempSQLiteDatabase):
"""For testing purpose."""
diff --git a/mh/pat.py b/mh/pat.py
index 6b6770a..cf9cc35 100644
--- a/mh/pat.py
+++ b/mh/pat.py
@@ -15,9 +15,10 @@
import mh.time_utils as time_utils
from mh.constants import \
LAYER_CHAT_COLORS, TERMS_VERSION, TERMS, SUBTERMS, ANNOUNCE, \
- CHARGE, VULGARITY_INFO, FMP_VERSION, PAT_BINARIES, PAT_NAMES, PatID4
+ CHARGE, VULGARITY_INFO, FMP_VERSION, PAT_BINARIES, PAT_NAMES, PatID4, \
+ FMP_CENTRAL_VERSION
from mh.session import Session
-import mh.database as db
+import mh.state as state
try:
from typing import Literal, List, Union, Optional # noqa: F401
@@ -34,7 +35,7 @@ class PatServer(server.BasicPatServer, Logger):
"""Generic PAT server class."""
def __init__(self, address, handler_class, max_thread_count=0,
- logger=None, debug_mode=False):
+ logger=None, debug_mode=False, no_timeout=False):
server.BasicPatServer.__init__(self, address, handler_class,
max_thread_count)
Logger.__init__(self)
@@ -43,6 +44,7 @@ def __init__(self, address, handler_class, max_thread_count=0,
self.info("Running on {} port {}".format(*address))
self.debug_con = []
self.debug_mode = debug_mode
+ self.no_timeout = no_timeout
def add_to_debug(self, con):
"""Add connection to the debug connection list."""
@@ -59,6 +61,9 @@ def get_debug(self):
def debug_enabled(self):
return self.debug_mode
+ def no_timeout_enabled(self):
+ return self.no_timeout
+
def get_pat_handler(self, session):
"""Return pat handler from session"""
for handler in self.debug_con:
@@ -268,14 +273,14 @@ def recvAnsConnection(self, packet_id, data, seq):
"""
settings = pati.ConnectionData.unpack(data)
self.server.debug("Connection: {!r}".format(settings))
- pat_ticket = b""
- if "pat_ticket" in settings:
- _, pat_ticket = pati.unpack_any(settings.pat_ticket)
- elif "online_support_code" in settings:
- _, pat_ticket = pati.unpack_any(settings.online_support_code)
- self.server.info("Client {} Ticket `{}`".format(self.client_address,
- pat_ticket))
- self.sendNtcLogin(5, settings, seq)
+ pat_ticket = settings.pat_ticket if "pat_ticket" in settings else \
+ settings.online_support_code if "online_support_code" in settings \
+ else ""
+ self.server.info("Client {} Ticket `{}`".format(self.client_address, pat_ticket))
+ if len(self.session.get_servers()) == 0:
+ self.sendNtcLogin(2, settings, seq)
+ else:
+ self.sendNtcLogin(5, settings, seq)
def sendNtcLogin(self, server_status, connection_data, seq):
"""NtcLogin packet.
@@ -290,6 +295,27 @@ def sendNtcLogin(self, server_status, connection_data, seq):
self.session = self.session.get(connection_data)
self.send_packet(PatID4.NtcLogin, data, seq)
+ def recvReqMaintenance(self, packet_id, data, seq):
+ """sendAnsMaintenance packet.
+
+ ID: 62200100
+ JP: メンテナンス情報要求
+ TR: Maintenance information request
+ """
+ self.sendAnsMaintenance(MAINTENANCE, seq)
+
+ def sendAnsMaintenance(self, maintenance, seq):
+ """sendAnsMaintenance packet.
+
+ ID: 62200200
+ JP: メンテナンス情報通知
+ TR: Maintenance information notification
+
+ The server replies with the maintenance information text.
+ """
+ data = pati.lp2_string(maintenance)
+ self.send_packet(PatID4.AnsMaintenance, data, seq)
+
def recvReqAuthenticationToken(self, packet_id, data, seq):
"""ReqAuthenticationToken packet.
@@ -721,7 +747,7 @@ def sendAnsTicket(self, seq):
TR: PAT ticket response
"""
pat_ticket = self.session.new_pat_ticket()
- data = struct.pack(">H", len(pat_ticket)) + pat_ticket
+ data = struct.pack(">H", len(pat_ticket)) + pat_ticket.encode('ascii')
self.send_packet(PatID4.AnsTicket, data, seq)
def recvReqUserListHead(self, packet_id, data, seq):
@@ -883,7 +909,7 @@ def sendAnsFmpListVersion(self, seq):
JP: FMPリストバージョン確認応答
TR: FMP list version acknowledgment
"""
- data = struct.pack(">I", FMP_VERSION)
+ data = struct.pack(">I", FMP_CENTRAL_VERSION)
self.send_packet(PatID4.AnsFmpListVersion, data, seq)
def sendAnsFmpListVersion2(self, seq):
@@ -893,7 +919,7 @@ def sendAnsFmpListVersion2(self, seq):
JP: FMPリストバージョン確認応答
TR: FMP list version acknowledgment
"""
- data = struct.pack(">I", FMP_VERSION)
+ data = struct.pack(">I", self.session.get_fmp_version())
self.send_packet(PatID4.AnsFmpListVersion2, data, seq)
def recvReqFmpListHead(self, packet_id, data, seq):
@@ -903,13 +929,9 @@ def recvReqFmpListHead(self, packet_id, data, seq):
JP: FMPリスト数送信 / FMPリスト数要求
TR: Send FMP list count / FMP list count request
"""
- # TODO: Might be worth investigating these parameters as
- # they might be useful when using multiple FMP servers.
- version, first_index, count = struct.unpack_from(
- ">III", data
- ) # noqa: F841
- # TODO: Unpack it using pati.Unpacker
- header = pati.unpack_bytes(data, 12) # noqa: F841
+ version, first_index, count = struct.unpack_from(">III", data)
+ self.session.preserve_server_ids(first_index, count)
+ header = pati.unpack_bytes(data, 12)
if packet_id == PatID4.ReqFmpListHead:
self.sendAnsFmpListHead(seq)
elif packet_id == PatID4.ReqFmpListHead2:
@@ -967,7 +989,7 @@ def sendAnsFmpListData(self, first_index, count, seq):
"""
unused = 0
data = struct.pack(">II", unused, count)
- data += pati.get_fmp_servers(self.session, first_index, count)
+ data += pati.get_fmp_central_servers(self.session, first_index, count)
self.send_packet(PatID4.AnsFmpListData, data, seq)
def sendAnsFmpListData2(self, first_index, count, seq):
@@ -1067,20 +1089,31 @@ def recvReqFmpInfo(self, packet_id, data, seq):
"""
index, = struct.unpack_from(">I", data)
fields = pati.unpack_bytes(data, 4)
- server = self.session.join_server(index)
- config = get_config("FMP")
- fmp_addr = get_ip(config["IP"])
- fmp_port = config["Port"]
fmp_data = pati.FmpData()
- fmp_data.server_address = pati.String(server.addr or fmp_addr)
- fmp_data.server_port = pati.Word(server.port or fmp_port)
- fmp_data.assert_fields(fields)
+
if packet_id == PatID4.ReqFmpInfo:
+ config = get_config("FMP")
+ central_fmp_addr = get_ip(config["IP"])
+ central_fmp_port = config["Port"]
+ fmp_data.server_address = pati.String(central_fmp_addr)
+ fmp_data.server_port = pati.Word(central_fmp_port)
+ fmp_data.assert_fields(fields)
self.sendAnsFmpInfo(fmp_data, fields, seq)
elif packet_id == PatID4.ReqFmpInfo2:
+ if not self.session.server_index_exists(index):
+ self.sendAnsAlert(PatID4.AnsFmpInfo2,
+ "\
+ Server is offline.",
+ seq)
+ return
+ server_id = self.session.recall_server_id(index)
+ server = self.session.join_server(server_id)
+ fmp_data.server_address = pati.String(server.addr)
+ fmp_data.server_port = pati.Word(server.port)
+ fmp_data.assert_fields(fields)
self.sendAnsFmpInfo2(fmp_data, fields, seq)
- # Preserve session in database, due to server selection
+ # Preserve session in state, due to server selection
self.session.request_reconnection = True
def sendAnsFmpInfo(self, fmp_data, fields, seq):
@@ -1328,10 +1361,10 @@ def sendAnsUserSearchInfo(self, capcom_id, search_info, seq):
# Specifically when a client is deserializing data from the packets
# `NtcLayerBinary` and `NtcLayerBinary2`
# TODO: Proper field value and name
- user_info.info_mine_0x0f = pati.Long(int(hash(user.capcom_id))
- & 0xffffffff)
- user_info.info_mine_0x10 = pati.Long(int(hash(user.capcom_id[::-1]))
- & 0xffffffff)
+ user_info.info_mine_0x0f = pati.Long(int(hash(user.capcom_id)) &
+ 0xffffffff)
+ user_info.info_mine_0x10 = pati.Long(int(hash(user.capcom_id[::-1])) &
+ 0xffffffff)
data = user_info.pack()
# TODO: Figure out the optional fields
@@ -2658,7 +2691,7 @@ def sendNtcCircleHost(self, circle, new_leader, new_leader_index, seq):
def notify_city_info_set(self, path):
# type: (pati.LayerPath) -> None
city = self.get_layer(path)
- assert isinstance(city, db.City)
+ assert isinstance(city, state.City)
gate = city.parent
layer_data = pati.LayerData.create_from(path.city_id, city, path)
@@ -2670,7 +2703,7 @@ def notify_city_info_set(self, path):
def notify_city_number_set(self, path):
# type: (pati.LayerPath) -> None
city = self.get_layer(path)
- assert isinstance(city, db.City)
+ assert isinstance(city, state.City)
gate = city.parent
layer_data = pati.LayerData.create_from(path.city_id, city, path)
@@ -2680,14 +2713,14 @@ def notify_city_number_set(self, path):
@staticmethod
def get_layer(path):
- # type: (pati.LayerPath) -> Optional[db.Server | db.Gate | db.City]
- database = db.get_instance()
+ # type: (pati.LayerPath) -> Optional[state.Server | state.Gate | state.City]
+ curr_state = state.get_instance()
if path.city_id > 0:
- return database.get_city(path.server_id, path.gate_id, path.city_id)
+ return curr_state.get_city(path.server_id, path.gate_id, path.city_id)
elif path.gate_id > 0:
- return database.get_gate(path.server_id, path.gate_id)
+ return curr_state.get_gate(path.server_id, path.gate_id)
elif path.server_id > 0:
- return database.get_server(path.server_id)
+ return curr_state.get_server(path.server_id)
return None
def notify_layer_departure(self, end):
@@ -2711,7 +2744,7 @@ def notify_layer_departure(self, end):
if path.city_id > 0:
city = self.get_layer(path)
- assert isinstance(city, db.City)
+ assert isinstance(city, state.City)
self.notify_city_number_set(path)
if city.leader is None:
@@ -2796,7 +2829,7 @@ def on_tick(self):
# Send a ping with 30 seconds interval
if self.ping_timer.elapsed() >= 30:
- if not self.server.debug_enabled() and not self.line_check:
+ if not self.server.no_timeout_enabled() and not self.line_check:
raise Exception("Client timed out.")
self.line_check = False
self.sendReqLineCheck()
diff --git a/mh/pat_item.py b/mh/pat_item.py
index aa8e321..ec2e075 100644
--- a/mh/pat_item.py
+++ b/mh/pat_item.py
@@ -9,7 +9,7 @@
from collections import OrderedDict
from mh.constants import pad
from other.utils import to_bytearray, get_config, get_ip, GenericUnpacker
-from mh.database import Server, Gate, City
+from mh.state import Server, Gate, City
class ItemType:
Custom = 0
@@ -893,7 +893,6 @@ def pack_from(layer_data):
return data.pack()
-
def patdata_extender(unpacker):
"""PatData classes must be defined above this function."""
for name, value in globals().items():
@@ -964,14 +963,11 @@ def get_fmp_servers(session, first_index, count):
fmp_port = config["Port"]
data = b""
- start = first_index - 1
- end = start + count
- servers = session.get_servers()[start:end]
+
+ servers = session.recall_servers(first_index, count)
for i, server in enumerate(servers, first_index):
fmp_data = FmpData()
fmp_data.index = Long(i) # The server might be full, if zero
- server.addr = server.addr or fmp_addr
- server.port = server.port or fmp_port
fmp_data.server_address = String(server.addr)
fmp_data.server_port = Word(server.port)
# Might produce invalid reads if too high
@@ -987,6 +983,32 @@ def get_fmp_servers(session, first_index, count):
return data
+def get_fmp_central_servers(session, first_index, count):
+ assert first_index > 0, "Invalid list index"
+
+ config = get_config("FMP")
+ fmp_addr = get_ip(config["IP"])
+ fmp_port = config["Port"]
+
+ start = first_index - 1
+ end = start + count
+ data = b""
+ for i in range(start, end):
+ fmp_data = FmpData()
+ fmp_data.index = Long(count+1) # Index must not match any "real" server.
+ fmp_data.server_name = String("central")
+ fmp_data.server_address = String(fmp_addr)
+ fmp_data.server_port = Word(fmp_port)
+ fmp_data.server_type = LongLong(1)
+ fmp_data.player_count = Long(1)
+ fmp_data.player_capacity = Long(1)
+ fmp_data.unk_string_0x0b = String("X")
+ fmp_data.unk_long_0x0c = Long(0x12345678)
+ data += fmp_data.pack()
+
+ return data
+
+
def get_layer_children(session, first_index, count, sibling=False):
assert first_index > 0, "Invalid list index"
diff --git a/mh/session.py b/mh/session.py
index d89e966..1abd1ec 100644
--- a/mh/session.py
+++ b/mh/session.py
@@ -4,12 +4,17 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
"""Monster Hunter session module."""
+import struct
+import time
+
import mh.database as db
+from mh.state import get_instance, Players, LayerState
import mh.pat_item as pati
from other.utils import to_bytearray, to_str
DB = db.get_instance()
+STATE = get_instance()
class SessionState:
@@ -52,19 +57,86 @@ def __init__(self, connection_handler):
self.state = SessionState.UNKNOWN
self.binary_setting = b""
self.search_payload = None
+ self.loaded_server_ids = {}
self.hunter_info = pati.HunterSettings()
+ def serialize(self):
+ pdict = {
+ "pat_ticket": self.pat_ticket,
+ "local_info_server_id": self.local_info["server_id"],
+ "local_info_server_name": self.local_info["server_name"],
+ "local_info_gate_id": self.local_info["gate_id"],
+ "local_info_gate_name": self.local_info["gate_name"],
+ "local_info_city_id": self.local_info["city_id"],
+ "local_info_city_name": self.local_info["city_name"],
+ "local_info_city_size": self.local_info["city_size"],
+ "local_info_city_capacity": self.local_info["city_capacity"],
+ "local_info_circle_id": self.local_info["circle_id"],
+ "online_support_code": self.online_support_code,
+ "capcom_id": self.capcom_id,
+ "hunter_name": self.hunter_name,
+ "hunter_stats": self.hunter_stats,
+ "layer": self.layer,
+ "state": self.state,
+ "binary_setting": self.binary_setting,
+ "hunter_info": self.hunter_info.pack().decode(
+ encoding='ISO-8859-1'
+ )
+ }
+ return pdict
+
+ @staticmethod
+ def deserialize(pdict):
+ session = Session(None)
+ session.pat_ticket = str(pdict["pat_ticket"])\
+ if pdict["pat_ticket"] else pdict["pat_ticket"]
+ session.local_info["server_id"] = int(pdict["local_info_server_id"])\
+ if pdict["local_info_server_id"] else pdict["local_info_server_id"]
+ session.local_info["server_name"] =\
+ str(pdict["local_info_server_name"])\
+ if pdict["local_info_server_name"]\
+ else pdict["local_info_server_name"]
+ session.local_info["gate_id"] = int(pdict["local_info_gate_id"])\
+ if pdict["local_info_gate_id"] else pdict["local_info_gate_id"]
+ session.local_info["gate_name"] = str(pdict["local_info_gate_name"])\
+ if pdict["local_info_gate_name"] else pdict["local_info_gate_name"]
+ session.local_info["city_id"] = int(pdict["local_info_city_id"])\
+ if pdict["local_info_city_id"] else pdict["local_info_city_id"]
+ session.local_info["city_name"] = str(pdict["local_info_city_name"])\
+ if pdict["local_info_city_name"] else pdict["local_info_city_name"]
+ session.local_info["city_size"] = int(pdict["local_info_city_size"])\
+ if pdict["local_info_city_size"] else pdict["local_info_city_size"]
+ session.local_info["city_capacity"] =\
+ int(pdict["local_info_city_capacity"])\
+ if pdict["local_info_city_capacity"]\
+ else pdict["local_info_city_capacity"]
+ session.local_info["circle_id"] = int(pdict["local_info_circle_id"])\
+ if pdict["local_info_circle_id"] else pdict["local_info_circle_id"]
+ session.online_support_code = str(pdict["online_support_code"])\
+ if pdict["online_support_code"] else pdict["online_support_code"]
+ session.capcom_id = str(pdict["capcom_id"])
+ session.hunter_name = str(pdict["hunter_name"])
+ session.hunter_stats = pdict["hunter_stats"]
+ session.layer = int(pdict["layer"])
+ session.state = int(pdict["state"])
+ session.binary_setting = pdict["binary_setting"]
+ h_settings = bytearray(pdict["hunter_info"], encoding='ISO-8859-1')
+ session.hunter_info = pati.HunterSettings().unpack(h_settings,
+ len(h_settings))
+ return session
+
def get(self, connection_data):
"""Return the session associated with the connection data, if any."""
- if hasattr(connection_data, "pat_ticket"):
+ has_pat_ticket = hasattr(connection_data, "pat_ticket")
+ if has_pat_ticket:
self.pat_ticket = to_str(
pati.unpack_binary(connection_data.pat_ticket)
)
+ session = STATE.get_session(self.pat_ticket) or self
if hasattr(connection_data, "online_support_code"):
self.online_support_code = to_str(
pati.unpack_string(connection_data.online_support_code)
)
- session = DB.get_session(self.pat_ticket) or self
if session != self:
assert session.connection is None, "Session is already in use"
session.connection = self.connection
@@ -78,6 +150,20 @@ def get(self, connection_data):
"online_support_code" in connection_data)
return session
+ def session_ready(self, connection_data):
+ if hasattr(connection_data, "pat_ticket"):
+ return STATE.session_ready(to_str(
+ pati.unpack_binary(connection_data.pat_ticket)
+ ))
+ else:
+ return STATE.session_ready(self.pat_ticket)
+
+ def set_session_ready(self, connection_data, store_data):
+ STATE.set_session_ready(
+ to_str(pati.unpack_binary(connection_data.pat_ticket)),
+ store_data
+ )
+
def get_support_code(self):
"""Return the online support code."""
return DB.get_support_code(self)
@@ -88,7 +174,7 @@ def disconnect(self):
It doesn't purge the session state nor its PAT ticket.
"""
self.connection = None
- DB.disconnect_session(self)
+ STATE.disconnect_session(self)
def delete(self):
"""Delete the current session.
@@ -98,39 +184,76 @@ def delete(self):
- We should probably create a SessionManager thread per server.
"""
if not self.request_reconnection:
- DB.delete_session(self)
+ STATE.delete_session(self)
def is_jap(self):
"""TODO: Heuristic using the connection data to detect region."""
pass
def new_pat_ticket(self):
- DB.new_pat_ticket(self)
- return to_bytearray(self.pat_ticket)
+ # type: () -> str
+ STATE.new_pat_ticket(self)
+ return self.pat_ticket
+
+ def get_fmp_version(self):
+ return STATE.get_servers_version()
def get_users(self, first_index, count):
- return DB.get_users(self, first_index, count)
+ return STATE.get_users(self, first_index, count)
def use_user(self, index, name):
- DB.use_user(self, index, name)
+ STATE.use_user(self, index, name)
+
+ def server_index_exists(self, index):
+ try:
+ if index in self.loaded_server_ids:
+ self.recall_server(index)
+ return True
+ else:
+ return False
+ except AssertionError:
+ return False
+
+ def preserve_server_ids(self, first_index, count):
+ server_ids, servers = STATE.get_servers(include_ids=True)
+ if first_index-1 + count > len(server_ids):
+ count = len(server_ids) - (first_index - 1)
+ assert first_index <= len(server_ids)
+ server_ids = server_ids[first_index-1:first_index-1+count]
+ self.loaded_server_ids = {}
+ for i, server_id in enumerate(server_ids):
+ self.loaded_server_ids[first_index+i] = server_id
+
+ def recall_servers(self, first_index, count):
+ server_ids, servers = STATE.get_servers(include_ids=True)
+ servers = []
+ for i in range(first_index, first_index+count):
+ servers.append(self.recall_server(i))
+ return servers
+
+ def recall_server(self, index):
+ return STATE.get_server(self.loaded_server_ids[index])
+
+ def recall_server_id(self, index):
+ return self.loaded_server_ids[index]
def get_servers(self):
- return DB.get_servers()
+ return STATE.get_servers()
def get_server(self):
assert self.local_info['server_id'] is not None
- return DB.get_server(self.local_info['server_id'])
+ return STATE.get_server(self.local_info['server_id'])
def get_gate(self):
assert self.local_info['gate_id'] is not None
- return DB.get_gate(self.local_info['server_id'],
- self.local_info['gate_id'])
+ return STATE.get_gate(self.local_info['server_id'],
+ self.local_info['gate_id'])
def get_city(self):
assert self.local_info['city_id'] is not None
- return DB.get_city(self.local_info['server_id'],
- self.local_info['gate_id'],
- self.local_info['city_id'])
+ return STATE.get_city(self.local_info['server_id'],
+ self.local_info['gate_id'],
+ self.local_info['city_id'])
def get_circle(self):
assert self.local_info['circle_id'] is not None
@@ -191,10 +314,10 @@ def layer_detail_search(self, detailed_fields):
(field_id, value)
for field_id, field_type, value in detailed_fields
] # Convert detailed to simple optional fields
- return DB.layer_detail_search(server_type, fields)
+ return STATE.layer_detail_search(server_type, fields)
def join_server(self, server_id):
- return DB.join_server(self, server_id)
+ return STATE.join_server(self, server_id)
def get_layer_children(self):
if self.layer == 0:
@@ -213,73 +336,73 @@ def get_layer_sibling(self):
def find_users_by_layer(self, server_id, gate_id, city_id,
first_index, count, recursive=False):
if recursive:
- players = DB.get_all_users(server_id, gate_id, city_id)
+ players = STATE.get_all_users(server_id, gate_id, city_id)
else:
layer = \
- DB.get_city(server_id, gate_id, city_id) if city_id else \
- DB.get_gate(server_id, gate_id) if gate_id else \
- DB.get_server(server_id)
+ STATE.get_city(server_id, gate_id, city_id) if city_id else \
+ STATE.get_gate(server_id, gate_id) if gate_id else \
+ STATE.get_server(server_id)
players = list(layer.players)
start = first_index - 1
return players[start:start+count]
def find_user_by_capcom_id(self, capcom_id):
- sessions = DB.find_users(capcom_id=capcom_id)
+ sessions = STATE.find_users(capcom_id=capcom_id)
if sessions:
return sessions[0]
return None
def find_users(self, capcom_id, hunter_name, first_index, count):
- users = DB.find_users(capcom_id, hunter_name)
+ users = STATE.find_users(capcom_id, hunter_name)
start = first_index - 1
return users[start:start+count]
def get_user_name(self, capcom_id):
- return DB.get_user_name(capcom_id)
+ return DB.get_name(capcom_id)
def leave_server(self):
- DB.leave_server(self, self.local_info["server_id"])
+ STATE.leave_server(self)
def get_gates(self):
- return DB.get_gates(self.local_info["server_id"])
+ return STATE.get_gates(self.local_info["server_id"])
def join_gate(self, gate_id):
- DB.join_gate(self, self.local_info["server_id"], gate_id)
+ STATE.join_gate(self, self.local_info["server_id"], gate_id)
self.state = SessionState.GATE
def leave_gate(self):
- DB.leave_gate(self)
+ STATE.leave_gate(self)
self.state = SessionState.LOG_IN
def get_cities(self):
- return DB.get_cities(self.local_info["server_id"],
- self.local_info["gate_id"])
+ return STATE.get_cities(self.local_info["server_id"],
+ self.local_info["gate_id"])
def is_city_empty(self, city_id):
- return DB.get_city(self.local_info["server_id"],
- self.local_info["gate_id"],
- city_id).get_state() == db.LayerState.EMPTY
+ return STATE.get_city(self.local_info["server_id"],
+ self.local_info["gate_id"],
+ city_id).get_state() == LayerState.EMPTY
def reserve_city(self, city_id, reserve):
- return DB.reserve_city(self.local_info["server_id"],
- self.local_info["gate_id"],
- city_id, reserve)
+ return STATE.reserve_city(self.local_info["server_id"],
+ self.local_info["gate_id"],
+ city_id, reserve)
def create_city(self, city_id, settings, optional_fields):
- return DB.create_city(self,
- self.local_info["server_id"],
- self.local_info["gate_id"],
- city_id, settings, optional_fields)
+ return STATE.create_city(self,
+ self.local_info["server_id"],
+ self.local_info["gate_id"],
+ city_id, settings, optional_fields)
def join_city(self, city_id):
- DB.join_city(self,
- self.local_info["server_id"],
- self.local_info["gate_id"],
- city_id)
+ STATE.join_city(self,
+ self.local_info["server_id"],
+ self.local_info["gate_id"],
+ city_id)
self.state = SessionState.CITY
def leave_city(self):
- DB.leave_city(self)
+ STATE.leave_city(self)
self.state = SessionState.GATE
def try_transfer_city_leadership(self):
diff --git a/mh/state.py b/mh/state.py
new file mode 100644
index 0000000..0b78372
--- /dev/null
+++ b/mh/state.py
@@ -0,0 +1,890 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+# SPDX-FileCopyrightText: Copyright (C) 2023 MH3SP Server Project
+# SPDX-License-Identifier: AGPL-3.0-or-later
+"""Monster Hunter state module."""
+
+
+from mh import database
+from math import floor
+import time
+from threading import RLock, Event
+
+
+try:
+ # Python 3
+ import selectors
+except ImportError:
+ # Python 2
+ import externals.selectors2 as selectors
+
+
+RESERVE_DC_TIMEOUT = 40.0
+
+
+class ServerType(object):
+ OPEN = 1
+ ROOKIE = 2
+ EXPERT = 3
+ RECRUITING = 4
+
+
+class LayerState(object):
+ EMPTY = 1
+ FULL = 2
+ JOINABLE = 3
+
+
+class Lockable(object):
+ def __init__(self):
+ self._lock = RLock()
+
+ def lock(self):
+ return self
+
+ def __enter__(self):
+ # Returns True if lock was acquired, False otherwise
+ return self._lock.acquire()
+
+ def __exit__(self, *args):
+ self._lock.release()
+
+
+class Players(Lockable):
+ def __init__(self, capacity):
+ assert capacity > 0, "Collection capacity can't be zero"
+
+ self.slots = [None for _ in range(capacity)]
+ self.used = 0
+ super(Players, self).__init__()
+
+ def get_used_count(self):
+ return self.used
+
+ def get_capacity(self):
+ return len(self.slots)
+
+ def add(self, item):
+ with self.lock():
+ if self.used >= len(self.slots):
+ return -1
+
+ item_index = self.index(item)
+ if item_index != -1:
+ return item_index
+
+ for i, v in enumerate(self.slots):
+ if v is not None:
+ continue
+
+ self.slots[i] = item
+ self.used += 1
+ return i
+
+ return -1
+
+ def remove(self, item):
+ assert item is not None, "Item != None"
+
+ with self.lock():
+ if self.used < 1:
+ return False
+
+ if isinstance(item, int):
+ if item >= self.get_capacity():
+ return False
+
+ self.slots[item] = None
+ self.used -= 1
+ return True
+
+ for i, v in enumerate(self.slots):
+ if v != item:
+ continue
+
+ self.slots[i] = None
+ self.used -= 1
+ return True
+
+ return False
+
+ def index(self, item):
+ assert item is not None, "Item != None"
+
+ for i, v in enumerate(self.slots):
+ if v == item:
+ return i
+
+ return -1
+
+ def clear(self):
+ with self.lock():
+ for i in range(self.get_capacity()):
+ self.slots[i] = None
+
+ def find_first(self, **kwargs):
+ if self.used < 1:
+ return None
+
+ for p in self.slots:
+ if p is None:
+ continue
+
+ for k, v in kwargs.items():
+ if getattr(p, k) != v:
+ break
+ else:
+ return p
+
+ return None
+
+ def find_by_capcom_id(self, capcom_id):
+ return self.find_first(capcom_id=capcom_id)
+
+ def __len__(self):
+ return self.used
+
+ def __iter__(self):
+ if self.used < 1:
+ raise StopIteration
+
+ for i, v in enumerate(self.slots):
+ if v is None:
+ continue
+
+ yield i, v
+
+ def serialize(self):
+ if self.used == 0:
+ return {"capacity": len(self.slots)}
+ pdict = {
+ "slots": [(p.serialize() if p is not None else None)
+ for p in self.slots],
+ "used": self.used
+ }
+ return pdict
+
+ @staticmethod
+ def deserialize(pdict, parent):
+ if "used" not in pdict.keys():
+ return Players(pdict["capacity"])
+ from mh.session import Session
+ players = Players(len(pdict["slots"]))
+ players.slots = [(Session.deserialize(p) if p is not None else None)
+ for p in pdict["slots"]]
+ players.used = pdict["used"]
+ return players
+
+
+class Circle(Lockable):
+ def __init__(self, parent):
+ # type: (City) -> None
+ self.parent = parent
+ self.leader = None
+ self.players = Players(4)
+ self.departed = False
+ self.quest_id = 0
+ self.embarked = False
+ self.password = None
+ self.remarks = None
+
+ self.unk_byte_0x0e = 0
+ super(Circle, self).__init__()
+
+ def get_population(self):
+ return len(self.players)
+
+ def get_capacity(self):
+ return self.players.get_capacity()
+
+ def is_full(self):
+ return self.get_population() == self.get_capacity()
+
+ def is_empty(self):
+ return self.leader is None
+
+ def is_joinable(self):
+ return not self.departed and not self.is_full()
+
+ def has_password(self):
+ return self.password is not None
+
+ def reset_players(self, capacity):
+ with self.lock():
+ self.players = Players(capacity)
+
+ def reset(self):
+ with self.lock():
+ self.leader = None
+ self.reset_players(4)
+ self.departed = False
+ self.quest_id = 0
+ self.embarked = False
+ self.password = None
+ self.remarks = None
+
+ self.unk_byte_0x0e = 0
+
+ def serialize(self):
+ serialized_players = self.players.serialize()
+ if "used" not in serialized_players.keys():
+ return {}
+ cdict = {
+ "parent": None,
+ "leader": self.leader.serialize() if self.leader is not None
+ else None,
+ "players": serialized_players,
+ "departed": self.departed,
+ "quest_id": self.quest_id,
+ "embarked": self.embarked,
+ "password": self.password,
+ "remarks": self.remarks,
+ "unk_byte_0x0e": self.unk_byte_0x0e
+ }
+ return cdict
+
+ @staticmethod
+ def deserialize(cdict, parent):
+ if not len(cdict.keys()):
+ return Circle(parent)
+ from mh.session import Session
+ circle = Circle(parent)
+ circle.leader = Session.deserialize(cdict["leader"])\
+ if cdict["leader"] is not None else None
+ circle.players = Players.deserialize(cdict["players"], circle)
+ circle.departed = cdict["departed"]
+ circle.quest_id = cdict["quest_id"]
+ circle.embarked = cdict["embarked"]
+ circle.password = cdict["password"]
+ circle.remarks = cdict["remarks"]
+ circle.unk_byte_0x0e = cdict["unk_byte_0x0e"]
+ return circle
+
+
+class City(Lockable):
+ LAYER_DEPTH = 3
+
+ def __init__(self, name, parent):
+ # type: (str, Gate) -> None
+ self.name = name
+ self.parent = parent
+ self.state = LayerState.EMPTY
+ self.players = Players(4)
+ self.optional_fields = []
+ self.leader = None
+ self.reserved = None
+ self.circles = [
+ # One circle per player
+ Circle(self) for _ in range(self.get_capacity())
+ ]
+ super(City, self).__init__()
+
+ def get_population(self):
+ return len(self.players)
+
+ def in_quest_players(self):
+ return sum(p.is_in_quest() for _, p in self.players)
+
+ def get_capacity(self):
+ return self.players.get_capacity()
+
+ def get_state(self):
+ if self.reserved:
+ return LayerState.FULL
+ size = self.get_population()
+ if size == 0:
+ return LayerState.EMPTY
+ elif size < self.get_capacity():
+ return LayerState.JOINABLE
+ else:
+ return LayerState.FULL
+
+ def get_pathname(self):
+ pathname = self.name
+ it = self.parent
+ while it is not None:
+ pathname = it.name + "\t" + pathname
+ it = it.parent
+ return pathname
+
+ def get_first_empty_circle(self):
+ with self.lock():
+ for index, circle in enumerate(self.circles):
+ if circle.is_empty():
+ return circle, index
+ return None, None
+
+ def get_circle_for(self, leader_session):
+ with self.lock():
+ for index, circle in enumerate(self.circles):
+ if circle.leader == leader_session:
+ return circle, index
+ return None, None
+
+ def clear_circles(self):
+ with self.lock():
+ for circle in self.circles:
+ circle.reset()
+
+ def reserve(self, reserve):
+ with self.lock():
+ if reserve:
+ self.reserved = time.time()
+ else:
+ self.reserved = None
+
+ def reset(self):
+ with self.lock():
+ self.state = LayerState.EMPTY
+ self.players.clear()
+ self.optional_fields = []
+ self.leader = None
+ self.reserved = None
+ self.clear_circles()
+
+ def get_all_players(self):
+ with self.players.lock():
+ return [p for _, p in self.players]
+
+ def serialize(self):
+ serialized_players = self.players.serialize()
+ if "used" not in serialized_players.keys():
+ return {"name": self.name}
+ cdict = {
+ "name": self.name,
+ "parent": None,
+ "state": self.state,
+ "players": serialized_players,
+ "optional_fields": self.optional_fields,
+ "leader": self.leader.serialize() if self.leader is not None
+ else None,
+ "reserved": self.reserved,
+ "circles": [c.serialize() for c in self.circles]
+ }
+ return cdict
+
+ @staticmethod
+ def deserialize(cdict, parent):
+ if len(cdict.keys()) < 2:
+ return City(cdict["name"], None)
+ from mh.session import Session
+ city = City(str(cdict["name"]) if cdict["name"] is not None
+ else cdict["name"], cdict["parent"])
+ city.parent = parent
+ city.state = cdict["state"]
+ city.players = Players.deserialize(cdict["players"], parent)
+ city.optional_fields = cdict["optional_fields"]
+ city.leader = Session.deserialize(cdict["leader"])\
+ if cdict["leader"] is not None else None
+ city.reserved = cdict["reserved"]
+ city.circles = [Circle.deserialize(c, city) for c in cdict["circles"]]
+ return city
+
+
+class Gate(object):
+ LAYER_DEPTH = 2
+
+ def __init__(self, name, parent, city_count=40, player_capacity=100):
+ # type: (str, Server, int, int) -> None
+ self.name = name
+ self.parent = parent
+ self.state = LayerState.EMPTY
+ self.cities = [
+ City("City{}".format(i), self)
+ for i in range(1, city_count+1)
+ ]
+ self.players = Players(player_capacity)
+ self.optional_fields = []
+
+ def get_population(self):
+ return len(self.players) + sum((
+ city.get_population()
+ for city in self.cities
+ ))
+
+ def get_capacity(self):
+ return self.players.get_capacity()
+
+ def get_state(self):
+ size = self.get_population()
+ if size == 0:
+ return LayerState.EMPTY
+ elif size < self.get_capacity():
+ return LayerState.JOINABLE
+ else:
+ return LayerState.FULL
+
+ def get_all_players(self):
+ players = [p for _, p in self.players]
+ for city in self.cities:
+ players = players + city.get_all_players()
+ return players
+
+ def serialize(self):
+ gdict = {
+ "name": self.name,
+ "parent": None,
+ "state": self.state,
+ "cities": [c.serialize() for c in self.cities],
+ "players": self.players.serialize(),
+ "optional_fields": self.optional_fields
+ }
+ return gdict
+
+ @staticmethod
+ def deserialize(gdict, parent):
+ gate = Gate(str(gdict["name"]) if gdict["name"] is not None
+ else gdict["name"], parent)
+ gate.state = gdict["state"]
+ gate.cities = [City.deserialize(c, gate) for c in gdict["cities"]]
+ gate.players = Players.deserialize(gdict["players"], gate)
+ gate.optional_fields = gdict["optional_fields"]
+ return gate
+
+
+class Server(object):
+ LAYER_DEPTH = 1
+
+ def __init__(self, name, server_type, capacity=2000,
+ addr=None, port=None):
+ self.name = name
+ self.parent = None
+ self.server_type = server_type
+ self.addr = addr
+ self.port = port
+ gate_count = int(floor(capacity / 100))
+ remainder = capacity % 100
+ self.gates = [
+ Gate("City Gate {}".format(i), self)
+ for i in range(1, gate_count+1)
+ ]
+ if remainder:
+ self.gates.append(Gate(
+ "City Gate {}".format(len(self.gates)+1),
+ self, player_capacity=remainder
+ ))
+ self.players = Players(capacity)
+
+ def get_population(self):
+ return len(self.players) + sum((
+ gate.get_population() for gate in self.gates
+ ))
+
+ def get_capacity(self):
+ return self.players.get_capacity()
+
+ def get_all_players(self):
+ players = [p for _, p in self.players]
+ for gate in self.gates:
+ players = players + gate.get_all_players()
+ return players
+
+ def serialize(self):
+ sdict = {
+ "name": self.name,
+ "parent": self.parent,
+ "server_type": self.server_type,
+ "addr": self.addr,
+ "port": self.port,
+ "gates": [g.serialize() for g in self.gates],
+ "players": self.players.serialize()
+ }
+ return sdict
+
+ @staticmethod
+ def deserialize(sdict):
+ server = Server(str(sdict["name"]) if sdict["name"] is not None
+ else sdict["name"],
+ int(sdict["server_type"]) if sdict["server_type"]
+ else sdict["server_type"],
+ addr=str(sdict["addr"]) if sdict["addr"] is not None
+ else sdict["addr"],
+ port=int(sdict["port"]) if sdict["port"] is not None
+ else sdict["port"])
+ server.parent = sdict["parent"]
+ server.gates = [Gate.deserialize(g, server) for g in sdict["gates"]]
+ server.players = Players.deserialize(sdict["players"], server)
+ return server
+
+
+def new_servers():
+ servers = []
+ servers.extend([
+ Server("Valor{}".format(i), ServerType.OPEN)
+ for i in range(1, 5)
+ ])
+ servers.extend([
+ Server("Beginners{}".format(i), ServerType.ROOKIE)
+ for i in range(1, 3)
+ ])
+ servers.extend([
+ Server("Veterans{}".format(i), ServerType.EXPERT)
+ for i in range(1, 3)
+ ])
+ servers.extend([
+ Server("Greed{}".format(i), ServerType.RECRUITING)
+ for i in range(1, 5)
+ ])
+ return servers
+
+
+class State(object):
+ def __init__(self):
+ self.sessions = {
+ # PAT Ticket => Owner's session
+ }
+ self.capcom_ids = {
+ # Capcom ID => {Owner's name, Owner's session}
+ }
+ self.cache = None
+ self.server_id = None
+ self.server = None
+ self.initialized = Event()
+
+ def setup_server(self, server_id, server_name, server_type,
+ capacity, server_addr, server_port):
+ self.server_id = server_id
+ if server_id != 0:
+ self.server = Server(server_name, server_type, capacity,
+ addr=server_addr, port=server_port)
+ else:
+ self.server = None
+ self.initialized.set()
+
+ def new_pat_ticket(self, session):
+ """Generates a new PAT ticket for the session."""
+ while True:
+ session.pat_ticket = database.new_random_str(11)
+ if session.pat_ticket not in self.sessions:
+ break
+ self.sessions[session.pat_ticket] = session
+ return session.pat_ticket
+
+ def register_pat_ticket(self, session):
+ """Register a Session's PAT ticket from another server."""
+ self.sessions[session.pat_ticket] = session
+ self.capcom_ids[session.capcom_id] = {"name": "", "session": None}
+ self.join_server(session, self.server_id)
+
+ def use_capcom_id(self, session, capcom_id, name=None):
+ """Attach the session to the Capcom ID."""
+ assert capcom_id in self.capcom_ids, "Capcom ID doesn't exist"
+
+ not_in_use = self.capcom_ids[capcom_id]["session"] is None
+ assert not_in_use, "Capcom ID is already in use. Try again in 60 seconds."
+
+ name = name or self.capcom_ids[capcom_id]["name"]
+ self.capcom_ids[capcom_id] = {"name": name, "session": session}
+
+ db = database.get_instance()
+ db.assign_name(capcom_id, name)
+
+ # TODO: Check if stable index is required
+ if capcom_id not in db.friend_lists:
+ db.friend_lists[capcom_id] = []
+ if capcom_id not in db.friend_requests:
+ db.friend_requests[capcom_id] = []
+
+ return name
+
+ def use_user(self, session, index, name):
+ """Use User from the slot or create one if empty"""
+ assert 1 <= index <= 6, "Invalid Capcom ID slot"
+ index -= 1
+ users = database.get_instance().get_capcom_ids(
+ session.online_support_code
+ )
+ while users[index] == "******":
+ capcom_id = database.new_random_str(6)
+ if capcom_id not in self.capcom_ids and \
+ not database.get_instance().get_name(capcom_id):
+ self.capcom_ids[capcom_id] = {"name": name, "session": None}
+ database.get_instance().assign_capcom_id(
+ session.online_support_code, index, capcom_id
+ )
+ break
+ else:
+ capcom_id = users[index]
+ name = self.use_capcom_id(session, capcom_id, name)
+ session.capcom_id = capcom_id
+ session.hunter_name = name
+ database.get_instance().use_user(session, index, name)
+
+ def get_session(self, pat_ticket):
+ """Returns existing PAT session or None."""
+ session = self.sessions.get(pat_ticket)
+ if session and session.capcom_id:
+ try:
+ self.use_capcom_id(
+ session, session.capcom_id, session.hunter_name
+ )
+ except AssertionError as e:
+ return None
+ return session
+
+ def disconnect_session(self, session):
+ """Detach the session from its Capcom ID."""
+ if not session.capcom_id:
+ # Capcom ID isn't chosen yet with OPN/LMP servers
+ return
+ self.capcom_ids[session.capcom_id]["session"] = None
+
+ def delete_session(self, session):
+ """Delete the session from the database."""
+ self.disconnect_session(session)
+ pat_ticket = session.pat_ticket
+ if pat_ticket in self.sessions:
+ del self.sessions[pat_ticket]
+ self.cache.notify_session_deletion(session.capcom_id)
+
+ def fetch_id(self, capcom_id):
+ if capcom_id not in self.capcom_ids:
+ self.capcom_ids[capcom_id] = {
+ "name": database.get_instance().get_name(capcom_id),
+ "session": None
+ }
+ return self.capcom_ids[capcom_id]
+
+ def get_users(self, session, first_index, count):
+ """Returns Capcom IDs tied to the session."""
+ users = database.get_instance().get_capcom_ids(
+ session.online_support_code
+ )
+ capcom_ids = [
+ (i, (capcom_id, self.capcom_ids.get(
+ capcom_id, self.fetch_id(capcom_id)
+ )))
+ for i, capcom_id in enumerate(users[:count], first_index)
+ ]
+ size = len(capcom_ids)
+ if size < count:
+ capcom_ids.extend([
+ (index, ("******", {}))
+ for index in range(first_index+size, first_index+count)
+ ])
+ return capcom_ids
+
+ def join_server(self, session, server_id):
+ server = self.get_server(server_id)
+ if server_id != self.server_id:
+ # Joining another server
+ if self.cache:
+ self.cache.send_session_info(server_id, session)
+ if session.local_info["server_id"] is not None:
+ self.leave_server(session)
+ else:
+ # Connecting to this server
+ server.players.add(session)
+ session.local_info["server_id"] = server_id
+ session.local_info["server_name"] = server.name
+ return server
+
+ def leave_server(self, session):
+ self.server.players.remove(session)
+ session.local_info["server_id"] = None
+ session.local_info["server_name"] = None
+
+ def get_server_time(self):
+ pass
+
+ def get_game_time(self):
+ pass
+
+ def get_servers_version(self):
+ return self.cache.servers_version
+
+ def get_servers(self, include_ids=False):
+ if not self.cache:
+ return []
+ server_ids, servers = self.cache.get_server_list(include_ids=True)
+ below, above = [], []
+ below_ids, above_ids = [], []
+ for server_id, server in zip(server_ids, servers):
+ if server_id < self.server_id:
+ below.append(server)
+ below_ids.append(server_id)
+ elif server_id > self.server_id:
+ above.append(server)
+ above_ids.append(server_id)
+ if self.server_id != 0:
+ below.append(self.server)
+ below_ids.append(self.server_id)
+ below.extend(above)
+ below_ids.extend(above_ids)
+ if include_ids:
+ return below_ids, below
+ return below
+
+ def get_server(self, server_id):
+ if server_id == self.server_id:
+ return self.server
+ return self.cache.get_server(server_id)
+
+ def get_gates(self, server_id):
+ return self.get_server(server_id).gates
+
+ def get_gate(self, server_id, index):
+ gates = self.get_gates(server_id)
+ assert 0 < index <= len(gates), "Invalid gate index"
+ return gates[index - 1]
+
+ def join_gate(self, session, server_id, index):
+ gate = self.get_gate(server_id, index)
+ gate.parent.players.remove(session)
+ gate.players.add(session)
+ session.local_info["gate_id"] = index
+ session.local_info["gate_name"] = gate.name
+ return gate
+
+ def leave_gate(self, session):
+ gate = self.get_gate(session.local_info["server_id"],
+ session.local_info["gate_id"])
+ gate.parent.players.add(session)
+ gate.players.remove(session)
+ session.local_info["gate_id"] = None
+ session.local_info["gate_name"] = None
+
+ def get_cities(self, server_id, gate_id):
+ return self.get_gate(server_id, gate_id).cities
+
+ def get_city(self, server_id, gate_id, index):
+ cities = self.get_cities(server_id, gate_id)
+ assert 0 < index <= len(cities), "Invalid city index"
+ return cities[index - 1]
+
+ def reserve_city(self, server_id, gate_id, index, reserve):
+ city = self.get_city(server_id, gate_id, index)
+ with city.lock():
+ reserved_time = city.reserved
+ if reserve and reserved_time and \
+ time.time()-reserved_time < RESERVE_DC_TIMEOUT:
+ return False
+ city.reserve(reserve)
+ return True
+
+ def get_all_users(self, server_id, gate_id, city_id):
+ """Search for users in layers and its children.
+
+ Let's assume wildcard search isn't possible for servers and gates.
+ A wildcard search happens when the id is zero.
+ """
+ assert 0 < server_id, "Invalid server index"
+ assert 0 < gate_id, "Invalid gate index"
+ gate = self.get_gate(server_id, gate_id)
+ users = list(gate.players)
+ cities = [
+ self.get_city(server_id, gate_id, city_id)
+ ] if city_id else self.get_cities(server_id, gate_id)
+ for city in cities:
+ users.extend(list(city.players))
+ return users
+
+ def find_users(self, capcom_id="", hunter_name=""):
+ assert capcom_id or hunter_name, "Search can't be empty"
+ users = []
+ for user_id, user_info in self.capcom_ids.items():
+ session = user_info["session"]
+ if not session:
+ continue
+ if capcom_id and capcom_id not in user_id:
+ continue
+ if hunter_name and \
+ hunter_name.lower() not in user_info["name"].lower():
+ continue
+ users.append(session)
+ for user_info in self.cache.get_remote_players_list():
+ if not user_info:
+ continue
+ if capcom_id and capcom_id not in user_info.capcom_id:
+ continue
+ if hunter_name and \
+ hunter_name.lower() not in user_info.hunter_name.lower():
+ continue
+ users.append(user_info)
+ return users
+
+ def create_city(self, session, server_id, gate_id, index,
+ settings, optional_fields):
+ city = self.get_city(server_id, gate_id, index)
+ with city.lock():
+ city.optional_fields = optional_fields
+ city.leader = session
+ city.reserved = None
+ return city
+
+ def join_city(self, session, server_id, gate_id, index):
+ city = self.get_city(server_id, gate_id, index)
+ with city.lock():
+ city.parent.players.remove(session)
+ city.players.add(session)
+ session.local_info["city_name"] = city.name
+ session.local_info["city_id"] = index
+ return city
+
+ def leave_city(self, session):
+ city = self.get_city(session.local_info["server_id"],
+ session.local_info["gate_id"],
+ session.local_info["city_id"])
+ with city.lock():
+ city.parent.players.add(session)
+ city.players.remove(session)
+ if not city.get_population():
+ city.reset()
+ session.local_info["city_id"] = None
+ session.local_info["city_name"] = None
+
+ def layer_detail_search(self, server_type, fields):
+ cities = []
+
+ def match_city(city, fields):
+ with city.lock():
+ return all((
+ field in city.optional_fields
+ for field in fields
+ ))
+
+ for server in self.get_servers():
+ if server.server_type != server_type:
+ continue
+ for gate in server.gates:
+ if not gate.get_population():
+ continue
+ cities.extend([
+ city
+ for city in gate.cities
+ if match_city(city, fields)
+ ])
+ return cities
+
+ def update_players(self):
+ # Central server method for clearing unused Capcom IDs
+ for capcom_id, player_info in self.capcom_ids.items():
+ if player_info["session"] is not None and \
+ player_info["session"].local_info["server_id"] and \
+ capcom_id not in self.cache.players:
+ self.capcom_ids[capcom_id]["session"] = None
+ elif capcom_id in self.cache.players:
+ self.capcom_ids[capcom_id]["session"] =\
+ self.cache.players[capcom_id]
+
+ def update_capcom_id(self, session):
+ # Central server method for keeping track of in-use Capcom IDs
+ self.capcom_ids[session.capcom_id]["session"] = session
+
+ def session_ready(self, pat_ticket):
+ if self.server_id == 0:
+ return True
+ return self.cache.session_ready(pat_ticket)
+
+ def set_session_ready(self, pat_ticket, store_data):
+ self.cache.set_session_ready(pat_ticket, store_data)
+
+ def close_cache(self):
+ self.cache.close()
+
+
+CURRENT_STATE = State()
+
+
+def get_instance():
+ return CURRENT_STATE
diff --git a/other/cache.py b/other/cache.py
new file mode 100644
index 0000000..04c2882
--- /dev/null
+++ b/other/cache.py
@@ -0,0 +1,818 @@
+#! /usr/bin/env python
+# -*- coding: utf-8 -*-
+"""Monster Hunter cache module.
+
+ Monster Hunter 3 Server Project
+ Copyright (C) 2023 Ze SpyRo
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+"""
+
+from mh.time_utils import Timer
+from mh.session import Session
+from mh.state import get_instance, Server
+from other.utils import Logger, create_logger, get_remote_config, \
+ get_central_config, get_config, get_ip
+
+from threading import Lock, Event
+import socket
+import struct
+import logging
+import json
+
+try:
+ # Python 3
+ import selectors
+ from typing import Tuple, Union, Optional, Dict, List, Callable, TYPE_CHECKING
+ if TYPE_CHECKING:
+ from fmp_server import FmpRequestHandler
+ from mh.pat_item import ConnectionData
+except ImportError:
+ # Python 2
+ import externals.selectors2 as selectors
+
+
+class PacketTypes(object):
+ Ping = 0x0000
+ FriendlyHello = 0x0001
+ ReqConnectionInfo = 0x0002
+ SendConnectionInfo = 0x0003
+ ReqServerRefresh = 0x0004
+ SessionInfo = 0x0005
+ ServerIDList = 0x0006
+ ServerShutdown = 0x0007
+ SessionDisconnect = 0x0008
+
+
+class CentralConnectionHandler(object):
+ def __init__(self, sck, client_address, cache):
+ # type: (socket.socket, Tuple[str, int], Cache) -> None
+ self.id = -1 # type: int
+ self.socket = sck
+ self.client_address = client_address
+ self.cache = cache
+ self.rfile = self.socket.makefile('rb', -1)
+ self.wfile = self.socket.makefile('wb', 0)
+
+ self.rw = Lock()
+ self.finished = False # type: bool
+
+ self.handler_functions = {
+ PacketTypes.Ping: self.RecvPing,
+ PacketTypes.FriendlyHello: self.RecvFriendlyHello,
+ PacketTypes.SendConnectionInfo: self.RecvConnectionInfo,
+ PacketTypes.ReqConnectionInfo: self.RecvReqConnectionInfo,
+ PacketTypes.ReqServerRefresh: self.RecvReqServerRefresh,
+ PacketTypes.SessionInfo: self.RecvSessionInfo,
+ PacketTypes.ServerShutdown: self.RecvServerShutdown,
+ PacketTypes.SessionDisconnect: self.RecvSessionDisconnect
+ } # type: Dict[int, Callable[[int, bytes], None]]
+
+ def fileno(self):
+ # type: () -> int
+ return self.socket.fileno()
+
+ def on_recv(self):
+ # type: () -> Optional[Tuple[int, int, bytes]]
+ header = self.rfile.read(10)
+ if not len(header) or len(header) < 10:
+ return None
+
+ return self.recv_packet(header)
+
+ def recv_packet(self, header):
+ # type: (bytes) -> Tuple[int, int, bytes]
+ size, packet_id, server_id = struct.unpack(">IIH", header)
+ data = self.rfile.read(size)
+ return server_id, packet_id, data
+
+ def send_packet(self, packet_id=0, data=b""):
+ # type: (int, bytes) -> None
+ self.wfile.write(self.pack_data(
+ data, packet_id
+ ))
+
+ def pack_data(self, data, packet_id):
+ # type: (bytes, int) -> bytes
+ return struct.pack(">II", len(data), packet_id) + data
+
+ def is_finished(self):
+ # type: () -> bool
+ return self.finished
+
+ def on_exception(self, e):
+ # type: (Exception) -> None
+ self.finish()
+
+ def direct_to_handler(self, packet):
+ # type: (Tuple[int, int, bytes]) -> None
+ server_id, packet_type, data = packet
+ self.handler_functions[packet_type](server_id, data)
+
+ def RecvPing(self, server_id, data):
+ # type: (int, bytes) -> None
+ pass
+
+ def RecvFriendlyHello(self, server_id, data):
+ # type: (int, bytes) -> None
+ self.cache.debug("Recieved a friendly hello from {}!".format(
+ server_id
+ ))
+ self.cache.register_handler(server_id, self)
+ self.ReqConnectionInfo()
+
+ def ReqConnectionInfo(self):
+ # type: () -> None
+ self.cache.debug("Requesting connection info.")
+ self.send_packet(PacketTypes.ReqConnectionInfo, b"")
+
+ def RecvConnectionInfo(self, server_id, data):
+ # type: (int, bytes) -> None
+ self.cache.debug("Recieved connection info sized {} from {}".format(
+ len(data), server_id
+ ))
+ server = Server.deserialize(json.loads(data.decode('utf-8')))
+ self.cache.servers[server_id] = server
+ self.cache.update_players()
+
+ def RecvReqConnectionInfo(self, server_id, data):
+ # type: (int, bytes) -> None
+ requested_server_id, = struct.unpack(">H", data)
+ self.cache.debug("Recieved request for data of Server {}.".format(
+ requested_server_id
+ ))
+ if server_id in self.cache.servers:
+ data = json.dumps(self.cache.servers[server_id].serialize()).encode('utf-8')
+ self.SendConnectionInfo(data)
+
+ def SendConnectionInfo(self, data):
+ # type: (bytes) -> None
+ self.cache.debug("Sending updated connection info.")
+ self.send_packet(PacketTypes.SendConnectionInfo, data)
+
+ def RecvReqServerRefresh(self, server_id, data):
+ # type: (int, bytes) -> None
+ self.cache.debug("Recieved server refresh request from \
+ Server {}.".format(
+ server_id
+ ))
+ self.SendServerIDList()
+ for _server_id in self.cache.servers:
+ data = struct.pack(">H", _server_id)
+ data += json.dumps(self.cache.servers[_server_id].serialize()).encode('utf-8')
+ self.SendConnectionInfo(data)
+
+ def SendServerIDList(self):
+ # type: () -> None
+ self.cache.debug("Sending updated Server ID list.")
+ data = struct.pack(">H", self.cache.servers_version)
+ data += struct.pack(">H", len(self.cache.servers))
+ for _server_id in self.cache.servers:
+ data += struct.pack(">H", _server_id)
+ self.send_packet(PacketTypes.ServerIDList, data)
+
+ def RecvSessionInfo(self, server_id, data):
+ # type: (int, bytes) -> None
+ dest_server_id, = struct.unpack(">H", data[:2])
+ self.cache.debug("Recieved session data from Server {} \
+ bound for Server {}.".format(
+ server_id, dest_server_id
+ ))
+ self.cache.update_player_record(
+ Session.deserialize(json.loads(data[2:].decode('utf-8')))
+ )
+ self.cache.get_handler(dest_server_id).SendSessionInfo(data[2:])
+
+ def SendSessionInfo(self, ser_session):
+ # type: (bytes) -> None
+ self.cache.debug("Dispatching session info to remote Server.")
+ self.send_packet(PacketTypes.SessionInfo, ser_session)
+
+ def RecvServerShutdown(self, server_id, data):
+ # type: (int, bytes) -> None
+ raise Exception("Server shutting down.")
+
+ def RecvSessionDisconnect(self, server_id, data):
+ # type: (int, bytes) -> None
+ length, = struct.unpack(">H", data[:2])
+ capcom_id = str(data[2:2+length].decode('utf-8'))
+ if capcom_id in self.cache.players:
+ del self.cache.players[capcom_id]
+ get_instance().update_players()
+
+ def finish(self):
+ # type: () -> None
+ if self.finished:
+ return
+
+ self.finished = True
+
+ try:
+ self.wfile.close()
+ except Exception:
+ pass
+
+ try:
+ self.rfile.close()
+ except Exception:
+ pass
+
+ try:
+ self.socket.close()
+ except Exception:
+ pass
+
+
+class RemoteConnectionHandler(object):
+ def __init__(self, sck, client_address, cache):
+ # type: (socket.socket, Tuple[str, int], Cache) -> None
+ self.id = 0
+ self.socket = sck
+ self.client_address = client_address
+ self.cache = cache
+ self.rfile = self.socket.makefile('rb', -1)
+ self.wfile = self.socket.makefile('wb', 0)
+
+ self.rw = Lock()
+ self.finished = False # type: bool
+
+ self.handler_functions = {
+ PacketTypes.ReqConnectionInfo: self.ReqConnectionInfo,
+ PacketTypes.SendConnectionInfo: self.RecvConnectionInfo,
+ PacketTypes.SessionInfo: self.RecvSessionInfo,
+ PacketTypes.ServerIDList: self.RecvServerIDList,
+ } # type: Dict[int, Callable[[bytes], None]]
+
+ def fileno(self):
+ # type: () -> int
+ return self.socket.fileno()
+
+ def on_recv(self):
+ # type: () -> Optional[Tuple[int, bytes]]
+ header = self.rfile.read(8)
+ if not len(header) or len(header) < 8:
+ return None
+
+ return self.recv_packet(header)
+
+ def recv_packet(self, header):
+ # type: (bytes) -> Tuple[int, bytes]
+ size, packet_id = struct.unpack(">II", header)
+ data = self.rfile.read(size)
+ return packet_id, data
+
+ def is_finished(self):
+ # type: () -> bool
+ return self.finished
+
+ def on_exception(self, e):
+ # type: (Exception) -> None
+ self.finish()
+
+ def send_packet(self, packet_id=0, data=b""):
+ # type: (int, bytes) -> None
+ self.wfile.write(self.pack_data(
+ data, packet_id
+ ))
+
+ def pack_data(self, data, packet_id):
+ # type: (bytes, int) -> bytes
+ return struct.pack(">IIH", len(data), packet_id,
+ self.cache.server_id) + data
+
+ def direct_to_handler(self, packet):
+ # type: (Tuple[int, bytes]) -> None
+ packet_type, data = packet
+ self.handler_functions[packet_type](data)
+
+ def SendFriendlyHello(self, data=b""):
+ # type: (bytes) -> None
+ self.cache.debug("Sending a friendly hello!")
+ self.send_packet(PacketTypes.FriendlyHello, data)
+
+ def ReqConnectionInfo(self, data):
+ # type: (bytes) -> None
+ self.cache.debug("Recieved request for update connection info from Central.")
+ server = get_instance().server
+ assert server != None
+
+ data = json.dumps(server.serialize()).encode('utf-8')
+ self.SendConnectionInfo(data)
+
+ def SendConnectionInfo(self, data):
+ # type: (bytes) -> None
+ self.cache.debug("Sending connection info to Central.")
+ self.send_packet(PacketTypes.SendConnectionInfo, data)
+
+ def SendReqServerRefresh(self):
+ # type: () -> None
+ self.cache.debug("Requesting refreshed server info from central.")
+ self.send_packet(PacketTypes.ReqServerRefresh, b"")
+
+ def SendReqConnectionInfo(self, server_id):
+ # type: (int) -> None
+ self.cache.debug("Requesting info for Server {}".format(
+ server_id
+ ))
+ self.send_packet(PacketTypes.ReqConnectionInfo,
+ struct.pack(">H", server_id))
+
+ def RecvServerIDList(self, data):
+ # type: (bytes) -> None
+ self.cache.debug("Recieved updated Server ID list from Central.")
+ servers_version, count = struct.unpack(">HH", data[:4])
+ self.cache.update_servers_version(servers_version)
+ updated_server_ids = []
+ for i in range(count):
+ server_id = struct.unpack(">H", data[2*(i+2):2*(i+2)+2])
+ updated_server_ids.append(server_id)
+ for server_id in self.cache.servers.keys():
+ if server_id not in updated_server_ids:
+ self.cache.prune_server(server_id)
+
+ def RecvConnectionInfo(self, data):
+ # type: (bytes) -> None
+ try:
+ server_id, = struct.unpack(">H", data[:2])
+ server = Server.deserialize(json.loads(data[2:].decode('utf-8')))
+ except Exception as e:
+ self.cache.error(e)
+ return
+ self.cache.debug("Obtained updated server info for Server {}".format(
+ server_id
+ ))
+ self.cache.servers[server_id] = server
+
+ def SendSessionInfo(self, server_id, ser_session):
+ # type: (int, bytes) -> None
+ self.cache.debug("Sending Session info to Server {}".format(
+ server_id
+ ))
+ data = struct.pack(">H", server_id)
+ data += ser_session
+ self.send_packet(PacketTypes.SessionInfo, data)
+
+ def RecvSessionInfo(self, data):
+ # type: (bytes) -> None
+ self.cache.debug("Recieved new Session info!")
+ self.cache.new_session(Session.deserialize(json.loads(data.decode('utf-8'))))
+
+ def SendSessionDisconnect(self, capcom_id):
+ # type: (bytes) -> None
+ data = struct.pack(">H", len(capcom_id))
+ data += capcom_id.encode('utf-8')
+ self.send_packet(PacketTypes.SessionDisconnect, data)
+
+ def finish(self):
+ # type: () -> None
+ if self.finished:
+ return
+
+ self.finished = True
+
+ try:
+ self.wfile.close()
+ except Exception:
+ pass
+
+ try:
+ self.rfile.close()
+ except Exception:
+ pass
+
+ try:
+ self.socket.close()
+ except Exception:
+ pass
+
+
+class Cache(Logger):
+ def __init__(self, server_id, debug_mode=False, log_to_file=False,
+ log_to_console=False, log_to_window=False,
+ refresh_period=30, ssl_location='cert/crossserverCA/'):
+ # type: (int, bool, bool, bool, bool, int, str) -> None
+ Logger.__init__(self)
+ self.servers_version = 1 # type: int
+ self.servers = {
+ # To be populated by remote connection
+ } # type: Dict[int, Server]
+
+ self.outbound_sessions = [
+ # (destination_server_id, session)
+ ] # type: List[Tuple[int, bytes]]
+
+ self.players = {
+ # capcom_id -> connectionless sessions from other servers
+ } # type: Dict[str, Session]
+ self.ready_sessions = {
+ # pat_ticket -> True or connection_data
+ } # type: Dict[str, Union[bool, Tuple[FmpRequestHandler, ConnectionData, int]]]
+
+ log_level = logging.DEBUG if debug_mode else logging.INFO
+ log_file = "cache.log" if log_to_file else ""
+
+ self.set_logger(create_logger("Cache", log_level, log_file, log_to_console, log_to_window))
+
+ self.is_central_server = server_id == 0
+ if not self.is_central_server:
+ remote_config = get_remote_config("SERVER{}".format(server_id))
+ get_instance().setup_server(server_id,
+ remote_config["Name"],
+ int(remote_config["ServerType"]),
+ int(remote_config["Capacity"]),
+ get_ip(remote_config["IP"]),
+ int(remote_config["Port"]))
+ else:
+ config = get_config("FMP")
+ get_instance().setup_server(
+ server_id, "", 0, 1, '0.0.0.0', config["Port"]
+ )
+ self.shut_down = False # type: bool
+ self.shut_down_event = Event()
+ self.refresh_period = refresh_period
+ self.handlers = {} # type: Dict[int, CentralConnectionHandler]
+ self.pending_connections = {} # type: Dict[CentralConnectionHandler, Timer]
+ self.remote_server = None # type: Optional[RemoteConnectionHandler]
+ self.server_id = server_id
+ self.central_config = get_central_config()
+ self.central_connection = (self.central_config["CentralIP"],
+ self.central_config["CentralCrossconnectPort"])
+ self.sel = selectors.DefaultSelector()
+ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.ssl_location = ssl_location
+ if self.central_config["CrossconnectSSL"]:
+ self.create_ssl_wrapper()
+
+ def create_ssl_wrapper(self):
+ # type: () -> None
+ import ssl
+ context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+ if self.is_central_server:
+ context.load_verify_locations(
+ cafile="{}ca.crt".format(self.ssl_location)
+ )
+ context.load_cert_chain("{}MH3SP.crt".format(self.ssl_location),
+ "{}MH3SP.key".format(self.ssl_location))
+ else:
+ context.load_cert_chain(
+ "{}client{}.crt".format(self.ssl_location, self.server_id),
+ "{}client{}.key".format(self.ssl_location, self.server_id)
+ )
+ self.socket = context.wrap_socket(
+ self.socket, server_side=self.is_central_server
+ )
+
+ def update_player_record(self, session):
+ # type: (Session) -> None
+ get_instance().update_capcom_id(session)
+
+ def update_players(self):
+ # type: () -> None
+
+ players = []
+ for server_id, server in self.servers.items():
+ if server_id != self.server_id:
+ players = players + server.get_all_players()
+ new_players = {}
+ for p in players:
+ new_players[p.capcom_id] = p
+ self.players = new_players
+ if self.is_central_server:
+ get_instance().update_players()
+
+ def get_remote_players_list(self):
+ # type: () -> List[Session]
+ players = []
+ for server_id, server in self.servers.items():
+ if server_id != self.server_id:
+ players = players + server.get_all_players()
+ return players
+
+ def update_servers_version(self, servers_version):
+ # type: (int) -> None
+ self.servers_version = servers_version
+
+ def get_server_list(self, include_ids=False):
+ # type: (bool) -> Union[List[Server], Tuple[List[int], List[Server]]]
+ if include_ids:
+ return list(self.servers.keys()), list(self.servers.values())
+ return list(self.servers.values())
+
+ def get_server(self, server_id):
+ # type: (int) -> Server
+ assert server_id in self.servers
+ return self.servers[server_id]
+
+ def send_session_info(self, server_id, session):
+ # type: (int, Session) -> None
+ self.outbound_sessions.append(
+ (server_id, json.dumps(session.serialize()).encode('utf-8'))
+ )
+
+ def new_session(self, session):
+ # type: (Session) -> None
+ pat_ticket = session.pat_ticket
+ assert pat_ticket is not None
+
+ ready_data = self.session_ready(pat_ticket)
+ if ready_data:
+ self.set_session_ready(pat_ticket, False)
+ get_instance().register_pat_ticket(session)
+ self.send_login_packet(*ready_data) # type: ignore
+ else:
+ self.set_session_ready(pat_ticket, session)
+
+ def session_ready(self, pat_ticket):
+ # type: (str) -> Union[bool, Tuple[FmpRequestHandler, ConnectionData, int]]
+ return self.ready_sessions.get(pat_ticket, False)
+
+ def set_session_ready(self, pat_ticket, store_data):
+ # type: (str, Union[bool, Tuple[FmpRequestHandler, ConnectionData, int]]) -> None
+ self.ready_sessions[pat_ticket] = store_data
+
+ def notify_session_deletion(self, capcom_id):
+ # type: str -> None
+ if not self.is_central_server:
+ self.remote_server.SendSessionDisconnect(capcom_id)
+
+ def send_login_packet(self, player_handler, connection_data, seq):
+ # type: (FmpRequestHandler, ConnectionData, int) -> None
+ player_handler.sendNtcLogin(3, connection_data, seq)
+
+ def get_handler(self, server_id):
+ # type: (int) -> CentralConnectionHandler
+ return self.handlers[server_id]
+
+ def register_handler(self, server_id, handler):
+ # type: (int, CentralConnectionHandler) -> None
+ handler.id = server_id
+ del self.pending_connections[handler]
+ self.handlers[server_id] = handler
+
+ def prune_server(self, server_id):
+ # type: (int) -> None
+ for player in self.servers[server_id].get_all_players():
+ if player.capcom_id in self.players:
+ del self.players[player.capcom_id]
+ if server_id != 0:
+ del self.servers[server_id]
+ if self.is_central_server:
+ self.update_servers_version(self.servers_version + 1)
+
+ def maintain_connection(self):
+ # type: () -> None
+ state = get_instance()
+ state.initialized.wait()
+ state.cache = self
+
+ refresh_timer = Timer()
+ if self.is_central_server:
+ # CENTRAL SERVER CONNECTION TO REMOTE
+ self.serve_cache(refresh_timer)
+ else:
+ # REMOTE SERVER CONNECTION TO CENTRAL
+ self.remote_connection_to_central_server(refresh_timer)
+
+ def remote_connection_to_central_server(self, refresh_timer):
+ # type: (Timer) -> None
+ central_host = "localhost" if self.central_connection[0] == "0.0.0.0" else self.central_connection[0]
+ central_addr = (central_host, self.central_connection[1])
+
+ while not self.shut_down:
+ # Connect to the central server if needed
+ if self.remote_server is None or self.remote_server.is_finished():
+ self.info("Connecting to central server at {}:{}...".format(central_addr[0], central_addr[1]))
+
+ if self.remote_server is not None:
+ self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if self.central_config["CrossconnectSSL"]:
+ self.create_ssl_wrapper()
+ self.remote_server = None
+
+ connect_timer = Timer()
+ try:
+ self.socket.settimeout(60.0)
+ self.socket.connect(central_addr)
+ self.socket.settimeout(None)
+ except socket.error as sck_error:
+ self.error("Failed! {}.".format(sck_error))
+ if connect_timer.elapsed() < 60.0:
+ remaining = 60.0 - connect_timer.elapsed()
+ self.debug("Retrying in {:.2f}s".format(remaining))
+ self.shut_down_event.wait(remaining)
+ continue
+
+ self.remote_server = RemoteConnectionHandler(self.socket, central_addr, self)
+ self.sel.register(self.remote_server, selectors.EVENT_READ | selectors.EVENT_WRITE)
+
+ try:
+ # Listen for incoming packets
+ assert self.remote_server is not None
+ events = self.sel.select(timeout=1)
+ for _, event in events:
+ if bool(event & selectors.EVENT_WRITE):
+ self.remote_server.SendFriendlyHello()
+ self.sel.modify(self.remote_server, selectors.EVENT_READ)
+ elif bool(event & selectors.EVENT_READ):
+ packet = self.remote_server.on_recv()
+ if packet is None:
+ continue
+ self.remote_server.direct_to_handler(packet)
+ # Request updated server information
+ if refresh_timer.elapsed() >= self.refresh_period:
+ assert self.remote_server is not None
+ try:
+ self.remote_server.SendReqServerRefresh()
+ finally:
+ refresh_timer.restart()
+ # Pass on an outbound session
+ if len(self.outbound_sessions) > 0:
+ self.debug("Outbound Session dispatching to Central Server.")
+ assert self.remote_server is not None
+ self.remote_server.SendSessionInfo(*self.outbound_sessions.pop(0))
+ except Exception as exc:
+ assert self.remote_server is not None
+ self.remote_server.on_exception(exc)
+ finally:
+ assert self.remote_server is not None
+ if not self.shut_down and self.remote_server.is_finished():
+ self.sel.unregister(self.remote_server)
+
+
+ def serve_cache(self, refresh_timer):
+ # type: (Timer) -> None
+ try:
+ self.socket.bind(self.central_connection)
+ self.socket.listen(0)
+ except socket.error as sck_error:
+ self.error('Failed to bind server to {}:{}. {}'.format(self.central_connection[0], self.central_connection[1], sck_error))
+ return
+ self.info("Listening for remote servers on {}:{}".format(self.central_connection[0], self.central_connection[1]))
+ self.sel.register(self.socket, selectors.EVENT_READ)
+
+ while not self.shut_down:
+ events = self.sel.select(timeout=1)
+ # Respond to incoming packets
+ for key, event in events:
+ connection = key.fileobj
+ if connection == self.socket:
+ # Accept a new connection
+ new_client = None
+ try:
+ new_client = self.socket.accept()
+ except Exception as exc:
+ self.error("Failed to accept incoming connection. {}".format(
+ exc))
+ continue
+
+ client_socket, client_address = new_client
+ self.info("Remote Server connected from {}".format(
+ client_address
+ ))
+ handler = CentralConnectionHandler(client_socket,
+ client_address,
+ self)
+ self.pending_connections[handler] = Timer()
+ self.sel.register(handler, selectors.EVENT_READ)
+ self.update_servers_version(self.servers_version + 1)
+ else:
+ assert event == selectors.EVENT_READ
+ assert isinstance(connection, CentralConnectionHandler)
+ try:
+ packet = connection.on_recv()
+ if packet is None:
+ if connection.is_finished():
+ self.remove_handler(connection)
+ continue
+ connection.direct_to_handler(packet)
+ except Exception as exc:
+ connection.on_exception(exc)
+ if not connection.is_finished():
+ continue
+ self.info(
+ "Connection to Remote Server {} lost: {}.".format(
+ connection.id,
+ exc
+ ))
+ self.remove_handler(connection)
+ if refresh_timer.elapsed() >= self.refresh_period:
+ for _, handler in self.handlers.items():
+ try:
+ handler.ReqConnectionInfo()
+ except Exception as exc:
+ handler.on_exception(exc)
+
+ if handler.is_finished():
+ self.remove_handler(handler)
+ refresh_timer.restart()
+ # Pass on an outbound session
+ if len(self.outbound_sessions) > 0:
+ self.debug("Session outbound...")
+ outbound_session = self.outbound_sessions.pop(0)
+ self.debug("Dispatching Session to Server {}.".format(
+ outbound_session[0]
+ ))
+
+ server_handler = self.get_handler(outbound_session[0])
+
+ try:
+ server_handler.SendSessionInfo(outbound_session[1])
+ except Exception as exc:
+ server_handler.on_exception(exc)
+
+ if server_handler.is_finished():
+ self.remove_handler(server_handler)
+ # Prune dangling pending connection
+ for handler, timer in self.pending_connections.items():
+ if timer.elapsed() < 30.0:
+ continue
+ self.warning('Prunning dangling remote server connection {}:{}'.format(handler.client_address[0], handler.client_address[1]))
+ try:
+ handler.finish()
+ except Exception:
+ pass
+ self.sel.unregister(handler)
+ del self.pending_connections[handler]
+
+ def remove_handler(self, handler):
+ # type: (CentralConnectionHandler) -> None
+ if handler.id > -1:
+ try:
+ del self.handlers[handler.id]
+ except KeyError:
+ pass
+ else:
+ try:
+ del self.pending_connections[handler]
+ except KeyError:
+ pass
+
+ try:
+ self.sel.unregister(handler)
+ except KeyError:
+ pass
+
+ try:
+ handler.finish()
+ except Exception:
+ pass
+
+ if handler.id > -1:
+ self.prune_server(handler.id)
+
+ if self.is_central_server:
+ for handler in self.handlers.values():
+ try:
+ handler.SendServerIDList()
+ except Exception:
+ self.error("Failed to send the Server ID list to a handler.")
+
+ def close(self):
+ # type: () -> None
+ if self.shut_down:
+ return
+
+ self.shut_down = True
+ self.shut_down_event.set()
+
+ if self.is_central_server:
+ for _, handler in self.handlers.items():
+ try:
+ handler.finish()
+ except Exception:
+ pass
+
+ self.handlers.clear()
+
+ for handler, _ in self.pending_connections.items():
+ try:
+ handler.finish()
+ except Exception:
+ pass
+ self.pending_connections.clear()
+ elif self.remote_server is not None:
+ try:
+ self.remote_server.send_packet(PacketTypes.ServerShutdown, b"")
+ except Exception:
+ pass
+
+ try:
+ self.remote_server.finish()
+ except Exception:
+ pass
+
+ self.socket.close()
+ self.sel.close()
+
+ self.info('Server Closed')
diff --git a/other/utils.py b/other/utils.py
index 1d2c68c..221895a 100644
--- a/other/utils.py
+++ b/other/utils.py
@@ -65,6 +65,17 @@ def critical(self, msg, *args, **kwargs):
return self.logger.critical(msg, *args, **kwargs)
+class SimpleNamespace (object):
+ def __init__ (self, **kwargs):
+ self.__dict__.update(kwargs)
+ def __repr__ (self):
+ keys = sorted(self.__dict__)
+ items = ("{}={!r}".format(k, self.__dict__[k]) for k in keys)
+ return "{}({})".format(type(self).__name__, ", ".join(items))
+ def __eq__ (self, other):
+ return self.__dict__ == other.__dict__
+
+
class GenericUnpacker(object):
"""Generic unpacker that maps unpack and pack functions.
@@ -237,6 +248,28 @@ def get_config(name, config_file=CONFIG_FILE):
}
+def get_remote_config(name, config_file=CONFIG_FILE):
+ config = ConfigParser.RawConfigParser(allow_no_value=True)
+ config.read(config_file)
+ return {
+ "IP": config.get(name, "IP"),
+ "Port": config.getint(name, "Port"),
+ "Capacity": config.getint(name, "Capacity"),
+ "Name": config.get(name, "Name"),
+ "ServerType": config.get(name, "ServerType")
+ }
+
+
+def get_central_config(config_file=CONFIG_FILE):
+ config = ConfigParser.RawConfigParser(allow_no_value=True)
+ config.read(CONFIG_FILE)
+ return {
+ "CentralIP": config.get("CENTRAL", "CentralIP"),
+ "CentralCrossconnectPort": config.getint("CENTRAL", "CentralCrossconnectPort"),
+ "CrossconnectSSL": config.getboolean("CENTRAL", "CrossconnectSSL")
+ }
+
+
def get_default_ip():
"""Get the default IP address"""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -344,7 +377,9 @@ def create_server(server_class, server_handler,
address="0.0.0.0", port=8200, name="Server", max_thread=0,
use_ssl=True, ssl_cert="server.crt", ssl_key="server.key",
log_to_file=True, log_filename="server.log",
- log_to_console=True, log_to_window=False, debug_mode=False):
+ log_to_console=True, log_to_window=False, legacy_ssl=False,
+ debug_mode=False, no_timeout=False):
+
"""Create a server, its logger and the SSL context if needed."""
logger = create_logger(
name, level=logging.DEBUG if debug_mode else logging.INFO,
@@ -352,7 +387,7 @@ def create_server(server_class, server_handler,
log_to_console=log_to_console,
log_to_window=log_to_window)
server = server_class((address, port), server_handler, max_thread, logger,
- debug_mode)
+ debug_mode, no_timeout)
if use_ssl:
server.socket = wii_ssl_wrap_socket(server.socket, ssl_cert, ssl_key)
@@ -363,25 +398,27 @@ def create_server(server_class, server_handler,
server_base = namedtuple("ServerBase", ["name", "cls", "handler"])
-def create_server_from_base(name, server_class, server_handler, silent=False,
- debug_mode=False):
+def create_server_from_base(name, server_class, server_handler, server_id,
+ silent=False, debug_mode=False, no_timeout=False):
"""Create a server based on its config parameters."""
- config = get_config(name)
+ general_config = get_config(name)
+ server_config = None if not server_id else get_remote_config("SERVER{}".format(server_id))
return create_server(
server_class, server_handler,
- address=config["IP"],
- port=config["Port"],
- name=config["Name"],
- max_thread=config["MaxThread"],
- use_ssl=config["UseSSL"],
- ssl_cert=config["SSLCert"],
- ssl_key=config["SSLKey"],
- log_to_file=config["LogToFile"],
- log_filename=config["LogFilename"],
- log_to_console=config["LogToConsole"] and not silent,
- log_to_window=config["LogToWindow"],
- debug_mode=debug_mode
- ), config["LogToWindow"]
+ address=general_config["IP"],
+ port=general_config["Port"] if not server_id else server_config["Port"],
+ name=general_config["Name"] if not server_id else server_config["Name"],
+ max_thread=general_config["MaxThread"],
+ use_ssl=general_config["UseSSL"],
+ ssl_cert=general_config["SSLCert"],
+ ssl_key=general_config["SSLKey"],
+ log_to_file=general_config["LogToFile"],
+ log_filename=general_config["LogFilename"],
+ log_to_console=general_config["LogToConsole"] and not silent,
+ log_to_window=general_config["LogToWindow"],
+ debug_mode=debug_mode,
+ no_timeout=no_timeout
+ ), general_config["LogToWindow"]
def server_main(name, server_class, server_handler):