diff --git a/mh/database.py b/mh/database.py index 9507007..739fb93 100644 --- a/mh/database.py +++ b/mh/database.py @@ -342,19 +342,75 @@ def get_friends(self, capcom_id, first_index=None, count=None): return friends[begin:end] -class MySQLDatabase(TempDatabase): - """Hybrid MySQL/TempDatabase.""" +class DatabaseError(Exception): + """Database exception class.""" + pass + + +class SafeMySQLConnection(object): + """Proxy object to safely reconnect to MySQL database. + + TODO: If the logic needs to be duplicated, we can move it into a + dedicated annotation/metaclass. + """ + + def __init__(self, attempts=3, cooldown=60.0): + """Reconnection attempts before waiting a cooldown time.""" + self.__attempts = attempts + self.__cooldown = cooldown + self.__connection = None + self.__time = None + self.__restart() + + def __restart(self): + """Reload the config and restart the connection.""" + if self.__connection: + try: + self.__connection.close() + except Exception: + pass + finally: + self.__connection = None - def __init__(self): - self.parent = super(MySQLDatabase, self) - self.parent.__init__() from mysql import connector from mysql.connector.constants import ClientFlag config = get_mysql_config("MYSQL") config['charset'] = 'utf8' config['autocommit'] = True config['client_flags'] = [ClientFlag.SSL] if config['ssl_ca'] else None - self.connection = connector.connect(**config) + self.__connection = connector.connect(**config) + + def cursor(self): + """Override the cursor method.""" + try: + c = self.__connection.cursor() + self.__time = None + return c + except Exception: + now = time.time() + if self.__time and (now - self.__time) < self.__cooldown: + raise DatabaseError("Connection lost, reconnection cooldown") + + self.__time = now + for _ in range(self.__attempts): + try: + self.__restart() + if self.__connection is not None: + c = self.__connection.cursor() + self.__time = None + return c + except Exception: + pass + raise DatabaseError("Connection lost, reconnection attempts failed") + + +class MySQLDatabase(TempDatabase): + """Hybrid MySQL/TempDatabase.""" + + def __init__(self): + self.parent = super(MySQLDatabase, self) + self.parent.__init__() + self.connection = SafeMySQLConnection() self.create_database() self.populate_database()