Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions mh/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,19 +342,75 @@ def get_friends(self, capcom_id, first_index=None, count=None):
return friends[begin:end]


class MySQLDatabase(TempDatabase):
"""Hybrid MySQL/TempDatabase."""
class DatabaseError(Exception):
"""Database exception class."""
pass


class SafeMySQLConnection(object):
"""Proxy object to safely reconnect to MySQL database.

TODO: If the logic needs to be duplicated, we can move it into a
dedicated annotation/metaclass.
"""

def __init__(self, attempts=3, cooldown=60.0):
"""Reconnection attempts before waiting a cooldown time."""
self.__attempts = attempts
self.__cooldown = cooldown
self.__connection = None
self.__time = None
self.__restart()

def __restart(self):
"""Reload the config and restart the connection."""
if self.__connection:
try:
self.__connection.close()
except Exception:
pass
finally:
self.__connection = None

def __init__(self):
self.parent = super(MySQLDatabase, self)
self.parent.__init__()
from mysql import connector
from mysql.connector.constants import ClientFlag
config = get_mysql_config("MYSQL")
config['charset'] = 'utf8'
config['autocommit'] = True
config['client_flags'] = [ClientFlag.SSL] if config['ssl_ca'] else None
self.connection = connector.connect(**config)
self.__connection = connector.connect(**config)

def cursor(self):
"""Override the cursor method."""
try:
c = self.__connection.cursor()
self.__time = None
return c
except Exception:
now = time.time()
if self.__time and (now - self.__time) < self.__cooldown:
raise DatabaseError("Connection lost, reconnection cooldown")

self.__time = now
for _ in range(self.__attempts):
try:
self.__restart()
if self.__connection is not None:
c = self.__connection.cursor()
self.__time = None
return c
except Exception:
pass
raise DatabaseError("Connection lost, reconnection attempts failed")


class MySQLDatabase(TempDatabase):
"""Hybrid MySQL/TempDatabase."""

def __init__(self):
self.parent = super(MySQLDatabase, self)
self.parent.__init__()
self.connection = SafeMySQLConnection()
self.create_database()
self.populate_database()

Expand Down