diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bcf6b49 --- /dev/null +++ b/.gitignore @@ -0,0 +1,121 @@ +# Data files +*.db +*.pickle +*.json + +# Conf files +*.conf + +# Development cruft +*.swo +*.swp +RCS +tags + +# ContextBot contents +contextBot/* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# vscode config files +**/.vscode diff --git a/assets/guide-gifs/pattern-guide.gif b/assets/guide-gifs/pattern-guide.gif new file mode 100644 index 0000000..d2add88 Binary files /dev/null and b/assets/guide-gifs/pattern-guide.gif differ diff --git a/assets/sounds/slow-spring-board.wav b/assets/sounds/slow-spring-board.wav new file mode 100644 index 0000000..14f2779 Binary files /dev/null and b/assets/sounds/slow-spring-board.wav differ diff --git a/bot/BotData.py b/bot/BotData.py deleted file mode 100644 index 43877f7..0000000 --- a/bot/BotData.py +++ /dev/null @@ -1,156 +0,0 @@ -import os -from datetime import datetime -import sqlite3 as sq -import json - -prop_table_info = [ - ("users", "users", ["userid"]), - ("guilds", "guilds", ["guildid"]), -] - - -class BotData: - def __init__(self, app="", data_file="data.db", version=0): - to_create = not os.path.exists(data_file) - - # Connect to database - self.conn = sq.connect(data_file, timeout=20) - - # Handle version checking - now = datetime.timestamp(datetime.utcnow()) - cursor = self.conn.cursor() - version_columns = "version INTEGER NOT NULL, time INTEGER NOT NULL" - if to_create: - # Create version table - cursor.execute('CREATE TABLE VersionHistory ({})'.format(version_columns)) - - # Insert current version into table - cursor.execute('INSERT INTO VersionHistory VALUES ({}, {})'.format(version, now)) - self.conn.commit() - else: - # Check if table exists - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='VersionHistory'") - version_exists = cursor.fetchone() - if not version_exists: - # Create version table - cursor.execute('CREATE TABLE VersionHistory ({})'.format(version_columns)) - - # Insert version 0 into table - cursor.execute('INSERT INTO VersionHistory VALUES (0, {})'.format(now)) - self.conn.commit() - - # Get last entry in version table, compare against desired version - cursor.execute("SELECT * FROM VersionHistory ORDER BY rowid DESC LIMIT 1") - current_version, _ = cursor.fetchone() - - if current_version != version: - # Complain - raise Exception( - ("Database version is {}, required version is {}. " - "Please migrate database.").format(current_version, version) - ) - - # Load property tables - for name, table_name, keys in prop_table_info: - manipulator = _propTableManipulator(table_name, keys, self.conn, app) - self.__setattr__(name, manipulator) - - def close(self): - self.conn.commit() - self.conn.close() - - -class _propTableManipulator: - def __init__(self, table, keys, conn, app): - self.table = table - self.keys = keys - self.conn = conn - self.app = app - - self.ensure_tables() - self.propmap = self.get_propmap() - - def ensure_tables(self): - cursor = self.conn.cursor() - keys = "{},".format(", ".join("{} INTEGER NOT NULL".format(key) for key in self.keys)) if self.keys else "" - key_list = "{},".format(", ".join(self.keys)) if self.keys else "" - columns = "{} property TEXT NOT NULL, value TEXT, PRIMARY KEY ({} property)".format(keys, key_list) - cursor.execute('CREATE TABLE IF NOT EXISTS {} ({})'.format(self.table, columns)) - cursor.execute('CREATE TABLE IF NOT EXISTS {}_props (property TEXT NOT NULL,\ - shared BOOLEAN NOT NULL,\ - PRIMARY KEY (property))'.format(self.table)) - self.conn.commit() - - def get_propmap(self): - cursor = self.conn.cursor() - cursor.execute('SELECT * from {}_props'.format(self.table)) - propmap = {} - for prop in cursor.fetchall(): - propmap[prop[0]] = prop[1] - return propmap - - def map_prop(self, prop): - return "{}_{}".format(self.app, prop) if (prop in self.propmap and not self.propmap[prop] and self.app) else prop - - def ensure_exists(self, *props, shared=True): - for prop in props: - if prop in self.propmap: - if self.propmap[prop] != shared: - cursor = self.conn.cursor() - cursor.execute('UPDATE {}_props SET shared = ? WHERE property = ?'.format(self.table), (shared, prop)) - self.propmap[prop] = shared - self.conn.commit() - else: - cursor = self.conn.cursor() - cursor.execute('INSERT INTO {}_props VALUES (?, ?)'.format(self.table), (prop, shared)) - self.propmap = self.get_propmap() - self.conn.commit() - - def get(self, *args, default=None): - if len(args) != len(self.keys) + 1: - raise Exception("Improper number of keys passed to get.") - prop = self.map_prop(args[-1]) - criteria = " AND ".join("{} = ?" for key in args) - - cursor = self.conn.cursor() - cursor.execute('SELECT value from {} where {}'.format(self.table, criteria).format(*self.keys, 'property'), tuple([*args[:-1], prop])) - value = cursor.fetchone() - return json.loads(value[0]) if (value and value[0]) else default - - def set(self, *args): - if len(args) != len(self.keys) + 2: - raise Exception("Improper number of keys passed to set.") - prop = self.map_prop(args[-2]) - value = json.dumps(args[-1]) - criteria = " AND ".join("{} = ?" for key in args[:-1]) - values = ", ".join("?" for key in args) - - cursor = self.conn.cursor() - cursor.execute('SELECT EXISTS(SELECT 1 from {} where {})'.format(self.table, criteria).format(*self.keys, 'property'), tuple([*args[:-2], prop])) - exists = cursor.fetchone() - - if not exists[0]: - cursor.execute('INSERT INTO {} VALUES ({})'.format(self.table, values), tuple([*args[:-2], prop, value])) - else: - cursor.execute('UPDATE {} SET value = ? WHERE {}'.format(self.table, criteria).format(*self.keys, 'property'), tuple([value, *args[:-2], prop])) - self.conn.commit() - - def find(self, prop, value, read=False): - if len(self.keys) > 1: - raise Exception("This method cannot currently be used when there are multiple keys") - prop = self.map_prop(prop) - if read: - value = json.dumps(value) - - cursor = self.conn.cursor() - cursor.execute('SELECT {} FROM {} WHERE property = ? AND value = ?'.format(self.keys[0], self.table), (prop, value)) - return [value[0] for value in cursor.fetchall()] - - def find_not_empty(self, prop): - if len(self.keys) > 1: - raise Exception("This method cannot currently be used when there are multiple keys") - prop = self.map_prop(prop) - - cursor = self.conn.cursor() - cursor.execute('SELECT {} FROM {} WHERE property = ? AND value IS NOT NULL AND value != \'\''.format(self.keys[0], self.table), (prop,)) - return [value[0] for value in cursor.fetchall()] diff --git a/bot/Timer/Timer.py b/bot/Timer/Timer.py deleted file mode 100644 index a3dce9e..0000000 --- a/bot/Timer/Timer.py +++ /dev/null @@ -1,686 +0,0 @@ -import asyncio -import datetime -import logging -import traceback -import discord -from enum import Enum - -from logger import log - - -class Timer(object): - clock_period = 600 - max_warning = 1 - - def __init__(self, name, role, channel, clock_channel=None, stages=None): - self.channel = channel - self.clock_channel = clock_channel - self.role = role - self.name = name - self._truename = name - - self.start_time = None # Session start time - self.current_stage_start = None # Time at which the current stage started - self.remaining = None # Amount of time until the next stage starts - self.state = TimerState.STOPPED # Current state of the timer - - self.stages = stages # List of stages in this timer - self.current_stage = 0 # Index of current stage - - self.subscribed = {} # Dict of subbed members, userid maps to (user, lastupdate, timesubbed) - - self.timer_messages = [] # List of sent message ids that this timer owns, e.g. for reaction handling - - self.last_clockupdate = 0 - - if stages: - self.setup(stages) - - def __contains__(self, userid): - """ - Containment interface acts as list of subscribers. - """ - return userid in self.subscribed - - def setup(self, stages): - """ - Setup the timer with a list of TimerStages. - """ - self.stop() - - self.stages = stages - self.current_stage = 0 - - now = self.now() - self.start_time = now - - self.remaining = stages[0].duration - self.current_stage_start = now - - # Return self for method chaining - return self - - async def update_clock_channel(self, force=False): - """ - Try to update the name of the status channel with the current status - """ - # Quit if there's no status channel set - if self.clock_channel is None: - return - - # Quit if we aren't due for a clock update yet - if not force and self.now() - self.last_clockupdate < self.clock_period: - return - - # Get the name and time strings - stage_name = self.stages[self.current_stage].name - - # Update the channel name, or quit silently if something goes wrong. - self.last_clockupdate = self.now() - try: - await self.clock_channel.edit(name="{} - {}".format(self.name, stage_name)) - self.last_clockupdate = self.now() - except Exception: - pass - - def pretty_remaining(self, show_seconds=False): - """ - Return a formatted version of the time remaining until the next stage. - """ - return self.parse_dur(self.remaining, show_seconds=show_seconds) - - def pretty_pinstatus(self): - """ - Return a formatted status string for use in the pinned status message. - """ - subbed_names = [m.member.name for m in self.subscribed.values()] - subbed_str = "```{}```".format(", ".join(subbed_names)) if subbed_names else "*No members*" - - if self.state in [TimerState.RUNNING, TimerState.PAUSED]: - # Collect the component strings and data - current_stage_name = self.stages[self.current_stage].name - remaining = self.pretty_remaining() - - # Create a list of lines for the stage string - longest_stage_len = max(len(stage.name) for stage in self.stages) - stage_format = "`{{prefix}}{{name:>{}}}:` {{dur}} min {{current}}".format(longest_stage_len) - - stage_str_lines = [ - stage_format.format( - prefix="->" if i == self.current_stage else "​ ", - name=stage.name, - dur=stage.duration, - current="(**{}**)".format(remaining) if i == self.current_stage else "" - ) for i, stage in enumerate(self.stages) - ] - # Create the stage string itself - stage_str = "\n".join(stage_str_lines) - - # Create the final formatted status string - status_str = ("**{name}**: {current_stage_name} {paused}\n" - "{stage_str}\n" - "{subbed_str}").format(name=self.name, - role=self.role.mention, - paused=" ***Paused***" if self.state == TimerState.PAUSED else "", - current_stage_name=current_stage_name, - stage_str=stage_str, - subbed_str=subbed_str) - elif self.state == TimerState.STOPPED: - status_str = "**{}**: *Timer not running.*\n{}".format(self.name, subbed_str) - return status_str - - def pretty_summary(self): - """ - Return a short summary status message. - """ - if self.stages: - stage_str = "/".join(("**{}**".format(stage.duration) if i == self.current_stage else str(stage.duration)) - for i, stage in enumerate(self.stages)) - else: - stage_str = "*Not set up.*" - - if self.state == TimerState.RUNNING: - status_str = "Stage `{}`, `{}` remaining\n".format(self.stages[self.current_stage].name, - self.pretty_remaining()) - elif self.state == TimerState.PAUSED: - status_str = "*Timer is paused.*\n" - elif self.state == TimerState.STOPPED: - status_str = "" - - if self.subscribed: - member_str = "Members: " + ", ".join(s.member.mention for s in self.subscribed.values()) - else: - member_str = "*No members.*" - - return "{} ({}): {}\n{}{}".format( - self.role.mention, - self.name, - stage_str, - status_str, - member_str - ) - - def oneline_summary(self): - """ - Return a one line summary status message - """ - if self.state == TimerState.RUNNING: - status = "Running" - elif self.state == TimerState.PAUSED: - status = "Paused" - elif self.state == TimerState.STOPPED: - status = "Stopped" - - if self.stages: - stage_str = "/".join(str(stage.duration) for i, stage in enumerate(self.stages)) - else: - stage_str = "not set up" - - return "{name} ({status} with {members} members, {setup}.)".format( - name=self.name, - status=status, - members=len(self.subscribed) if self.subscribed else 'no', - setup=stage_str - ) - - async def change_stage(self, stage_index, notify=True, inactivity_check=True, report_old=True): - """ - Advance the timer to the new stage. - """ - stage_index = stage_index % len(self.stages) - current_stage = self.stages[self.current_stage] - new_stage = self.stages[stage_index] - - self.current_stage = stage_index - self.current_stage_start = self.now() - self.remaining = self.stages[stage_index].duration * 60 - - # Update clocked times for all the subbed users and handle inactivity - needs_warning = [] - unsubs = [] - for subber in self.subscribed.values(): - subber.touch() - if inactivity_check: - if subber.warnings >= self.max_warning: - subber.warnings += 1 - unsubs.append(subber) - elif (self.now() - subber.last_seen) > current_stage.duration * 60: - subber.warnings += 1 - if subber.warnings >= self.max_warning: - needs_warning.append(subber) - - # Handle not having any subscribers - empty = (len(self.subscribed) == 0) - - # Handle notifications - if notify: - old_stage_str = "**{}** finished! ".format(current_stage.name) if report_old else "" - if needs_warning: - warning_str = ("{} you will be unsubscribed on the next stage " - "if you do not reply or react to this message.\n").format( - ", ".join(subber.member.mention for subber in needs_warning) - ) - else: - warning_str = "" - if unsubs: - unsub_str = "{} you have been unsubscribed due to inactivity!\n".format( - ", ".join(subber.member.mention for subber in unsubs) - ) - else: - unsub_str = "" - - main_line = "{}Starting **{}** ({} minutes). {}".format( - old_stage_str, - new_stage.name, - new_stage.duration, - new_stage.message - ) - - if not empty: - out_msg = await self.channel.send( - ("{}\n{}\n" - "Please reply or react to this message to register your existence.\n{}{}").format( - self.role.mention, - main_line, - warning_str, - unsub_str - ) - ) - try: - await out_msg.add_reaction("✅") - except Exception: - pass - - # Add the stage message to the owned message list - self.timer_messages.append(out_msg.id) - self.timer_messages = self.timer_messages[-5:] # Truncate - else: - """ - await self.channel.send( - ("{}\n " - "{}No subscribers, stopping group timer.").format( - self.role.mention, - old_stage_str - ) - ) - self.stop() - """ - pass - - # Notify the subscribers as desired - for subber in self.subscribed.values(): - try: - out_msg = None - if subber in unsubs and subber.notify >= NotifyLevel.FINAL: - await subber.member.send( - "You have been unsubscribed from group **{}** in {} due to inactivity!".format( - self.name, - self.channel.mention - ) - ) - elif subber in needs_warning and subber.notify >= NotifyLevel.WARNING: - out_msg = await subber.member.send( - ("**Warning** from group **{}** in {}!\n" - "Please respond or react to a timer message " - "to avoid being unsubscribed on the next stage.\n{}").format( - self.name, - self.channel.mention, - main_line - ) - ) - elif subber.notify >= NotifyLevel.ALL: - out_msg = await subber.member.send( - "Status update for group **{}** in {}!\n{}".format(self.name, - self.channel.mention, - main_line) - ) - except discord.Forbidden: - pass - except discord.HTTPException: - pass - - for subber in unsubs: - await subber.unsub() - - async def start(self): - """ - Start or restart the timer. - """ - await self.change_stage(0, report_old=False) - self.state = TimerState.RUNNING - for subber in self.subscribed.values(): - subber.touch() - subber.active = True - - asyncio.ensure_future(self.runloop()) - - def stop(self): - """ - Stop the timer, and ensure the subscriber clocked times are updated. - """ - for subber in self.subscribed.values(): - subber.touch() - subber.active = False - - self.state = TimerState.STOPPED - - async def runloop(self): - while self.state == TimerState.RUNNING: - self.remaining = int(60*self.stages[self.current_stage].duration - (self.now() - self.current_stage_start)) - if self.remaining <= 0: - try: - await self.change_stage(self.current_stage + 1) - asyncio.ensure_future(self.update_clock_channel(force=True)) - except Exception: - full_traceback = traceback.format_exc() - log("Exception encountered while changing stage.\n{}".format(full_traceback), - context="TIMER_RUNLOOP", - level=logging.ERROR) - - # Disable clock update since the channel update ratelimit is too slow - # asyncio.ensure_future(self.update_clock_channel()) - await asyncio.sleep(1) - - @staticmethod - def now(): - """ - Helper to get the current UTC timestamp as an integer. - """ - return int(datetime.datetime.timestamp(datetime.datetime.utcnow())) - - @staticmethod - def parse_dur(diff, show_seconds=False): - """ - Parse a duration given in seconds to a time string. - """ - diff = max(diff, 0) - if show_seconds: - diff = int(60 * round(diff / 60)) - hours = diff // 3600 - minutes = (diff % 3600) // 60 - return "{:02d}:{:02d}".format(hours, minutes) - else: - hours = diff // 3600 - minutes = (diff % 3600) // 60 - seconds = diff % 60 - return "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds) - - def serialise(self): - """ - Serialise current timer status to a dictionary. - Does not serialise subscribers or fixed attributes such as channels. - """ - return { - 'roleid': self.role.id, - 'name': self.name, - 'start_time': self.start_time, - 'current_stage_start': self.current_stage_start, - 'remaining': self.remaining, - 'state': self.state.value, - 'stages': [stage.serialise() for stage in self.stages] if self.stages else None, - 'current_stage': self.current_stage, - 'messages': self.timer_messages, - } - - def update_from_data(self, data): - """ - Restore timer status from the provided status dict, as produced by `serialise`. - """ - self.name = data['name'] - self.start_time = data['start_time'] - self.current_stage_start = data['current_stage_start'] - self.remaining = data['remaining'] - self.state = TimerState(data['state']) - self.stages = [ - TimerStage.deserialise(stage_data) for stage_data in data['stages'] - ] if data['stages'] else None - self.current_stage = data.get('current_stage', 0) - self.timer_messages = data.get('messages', []) - - asyncio.ensure_future(self.runloop()) - return self - - -class TimerState(Enum): - """ - Enum representing the current running state of the timer. - STOPPED: The timer either hasn't been set up, or has been stopped externally. - RUNNING: The timer is running normally. - PAUSED: The timer has been paused by a user. - """ - STOPPED = 1 - RUNNING = 2 - PAUSED = 3 - - -class TimerStage(object): - """ - Small data class to encapsualate a "stage" of a timer. - - Parameters - ---------- - name: str - The human readable name of the stage. - duration: int - The number of minutes the stage lasts for. - message: str - An optional message to send when starting this stage. - focus: bool - Whether `focus` mode is set for this stage. - modifiers: Dict(str, bool) - An unspecified collection of stage modifiers, stored for external use. - """ - __slots__ = ('name', 'message', 'duration', 'focus', 'modifiers') - - def __init__(self, name, duration, message="", focus=False, **modifiers): - self.name = name - self.duration = duration - self.message = message - - self.focus = focus - - self.modifiers = modifiers - - def serialise(self): - """ - Serialise stage to a serialisable dictionary. - """ - return { - 'name': self.name, - 'duration': self.duration, - 'message': self.message, - 'focus': self.focus, - 'modifiers': self.modifiers - } - - @classmethod - def deserialise(cls, data_dict): - """ - Deserialise stage from a dictionary formatted like the output of `serialise. - """ - return cls( - data_dict['name'], - data_dict['duration'], - message=data_dict['message'], - focus=data_dict['focus'], - **data_dict['modifiers'] - ) - - -class TimerChannel(object): - """ - A data class representing a guild channel bound to (potentially) several timers. - - Parameters - ---------- - channel: discord.Channel - The bound discord guild channel - timers: List(Timer) - The timers bound to the channel - msg: discord.Message - A valid and current discord Message in the channel. - Holds the updating timer status messages. - """ - __slots__ = ('channel', 'timers', 'msg', 'old_desc') - - def __init__(self, channel): - self.channel = channel - - self.timers = [] - self.msg = None - - self.old_desc = "" - - async def update(self): - """ - Create or update the channel status message. - """ - messages = [timer.pretty_pinstatus() for timer in self.timers] - if messages: - desc = "\n\n".join(messages) - - # Don't resend the same message - if desc == self.old_desc: - return - self.old_desc = desc - - embed = discord.Embed( - title="Pomodoro Timer Status", - description=desc, - timestamp=datetime.datetime.utcnow() - ) - if self.msg is not None: - try: - await self.msg.edit(embed=embed) - except discord.NotFound: - self.msg = None - except discord.Forbidden: - pass - except Exception: - pass - - """ - if all(timer.state == TimerState.STOPPED for timer in self.timers): - # Unpin and unset message - try: - await self.msg.unpin() - except Exception: - pass - - self.msg = None - """ - elif any(timer.state != TimerState.STOPPED for timer in self.timers): - # Attempt to generate a new message - try: - # Send a new message - self.msg = await self.channel.send(embed=embed) - - # Pin the message - try: - await self.msg.pin() - except Exception: - pass - except discord.Forbidden: - try: - await self.channel.send( - "I require permission to send embeds in this channel! " - "Stopping all timers." - ) - except Exception: - # There's no point trying to handle this if we can't send anything, just quietly unload - pass - - for timer in self.timers: - timer.stop() - except discord.NotFound: - # The channel doesn't even exist anymore! Stop all timers so we don't try to post anymore. - # TODO: Handle garbage collection, cautiously because this might be an outage - for timer in self.timers: - timer.stop() - - -class NotifyLevel(Enum): - """ - Enum representing a subscriber's notification level. - NONE: Never send direct messages. - FINAL: Send a direct message when kicking for inactivity. - WARNING: Send direct messages for unsubscription warnings. - ALL: Send direct messages for all stage updates. - """ - NONE = 1 - FINAL = 2 - WARNING = 3 - ALL = 4 - - def __ge__(self, other): - if self.__class__ is other.__class__: - return self.value >= other.value - return NotImplemented - - def __gt__(self, other): - if self.__class__ is other.__class__: - return self.value > other.value - return NotImplemented - - def __le__(self, other): - if self.__class__ is other.__class__: - return self.value <= other.value - return NotImplemented - - def __lt__(self, other): - if self.__class__ is other.__class__: - return self.value < other.value - return NotImplemented - - -class TimerSubscriber(object): - __slots__ = ( - 'member', - 'timer', - 'interface', - 'notify', - 'client', - 'id', - 'time_joined', - 'last_updated', - 'clocked_time', - 'active', - 'last_seen', - 'warnings' - ) - - def __init__(self, member, timer, interface, notify=NotifyLevel.WARNING): - self.member = member - self.timer = timer - self.interface = interface - self.notify = notify - - self.client = interface.client - self.id = member.id - - now = Timer.now() - self.time_joined = now - - self.last_updated = now - self.clocked_time = 0 - self.active = (timer.state == TimerState.RUNNING) - - self.last_seen = now - self.warnings = 0 - - async def unsub(self): - return await self.interface.unsub(self.member.guild.id, self.id) - - def bump(self): - self.last_seen = Timer.now() - self.warnings = 0 - - def touch(self): - """ - Update the clocked time based on the active status. - """ - now = Timer.now() - self.clocked_time += (now - self.last_updated) if self.active else 0 - self.last_updated = now - - def session_data(self): - """ - Return session data in a format compatible with the registry. - """ - self.touch() - - return ( - self.id, - self.member.guild.id, - self.timer.role.id, - self.time_joined, - self.clocked_time - ) - - def serialise(self): - return { - 'id': self.id, - 'guildid': self.member.guild.id, - 'roleid': self.timer.role.id, - 'notify': self.notify.value, - 'time_joined': self.time_joined, - 'last_updated': self.last_updated, - 'clocked_time': self.clocked_time, - 'active': self.active, - 'last_seen': self.last_seen, - 'warnings': self.warnings - } - - @classmethod - def deserialise(cls, member, timer, interface, data): - self = cls(member, timer, interface) - - self.time_joined = data['time_joined'] - self.last_updated = data['last_updated'] - self.clocked_time = data['clocked_time'] - self.active = data['active'] - self.notify = NotifyLevel(data['notify']) - self.last_seen = data['last_seen'] - self.warnings = data['warnings'] - - return self diff --git a/bot/Timer/__init__.py b/bot/Timer/__init__.py index 939d6d6..e08c2b3 100644 --- a/bot/Timer/__init__.py +++ b/bot/Timer/__init__.py @@ -1,2 +1,240 @@ -from .interface import TimerInterface -from .Timer import Timer, TimerChannel, TimerSubscriber, TimerState, TimerStage, NotifyLevel +import os +import shutil +import logging +import asyncio +import pickle + +import discord +from cmdClient import Module +from cmdClient.Context import Context + +from meta import log +from data import tables + +from .lib import TimerState, NotifyLevel, InvalidPattern # noqa +from .core import Timer, TimerSubscriber, TimerChannel, Pattern # noqa +from . import activity_events +from . import timer_reactions +from . import voice_events +from . import guild_events + + +class TimerInterface(Module): + name = "TimerInterface" + save_dir = "data/timerstatus/" + save_fn = "timerstatus.pickle" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.guild_channels = {} # guildid -> {channelid -> TimerChannel} + + self.init_task(self.core_init) + self.launch_task(self.core_launch) + + self.runloop_task = None + self.saveloop_task = None + + def core_init(self, _client): + _client.interface = self + Context.timers = module + _client.add_after_event('message', activity_events.message_tracker) + _client.add_after_event('raw_reaction_add', activity_events.reaction_tracker) + + _client.add_after_event('raw_reaction_add', timer_reactions.joinleave_tracker) + _client.add_after_event('voice_state_update', voice_events.vc_update_handler) + + _client.add_after_event('guild_join', guild_events.on_guild_join) + _client.add_after_event('guild_remove', guild_events.on_guild_remove) + + async def core_launch(self, _client): + await self.load_timers() + self.restore_from_save() + self.runloop_task = asyncio.create_task(self._runloop()) + self.saveloop_task = asyncio.create_task(self._saveloop()) + + def shutdown(self): + if self.saveloop_task and not self.saveloop_task.done(): + self.saveloop_task.cancel() + if self.runloop_task and not self.runloop_task.done(): + self.runloop_task.cancel() + + async def _runloop(self): + while True: + channel_keys = [ + (guildid, channelid) + for guildid, channels in self.guild_channels.items() + for channelid, channel in channels.items() + if any(timer.state == TimerState.RUNNING for timer in channel.timers) + ] + channel_count = len(channel_keys) + if channel_count == 0: + await asyncio.sleep(30) + continue + + delay = max(0.1, 30/channel_count) + + channels = ( + self.guild_channels[gid][cid] + for gid, cid in channel_keys + if gid in self.guild_channels and cid in self.guild_channels[gid] + ) + for channel in channels: + try: + await channel.update_pin() + except Exception: + log("Exception encountered updating channel pin for {!r}".format(channel.channel), + context="TIMER_RUNLOOP", + level=logging.ERROR, + add_exc_info=True) + await asyncio.sleep(delay) + + async def _saveloop(self): + while True: + await asyncio.sleep(60) + self.update_save() + + def update_save(self, reason=None): + # TODO: Move save file location to config? For e.g. sharding + log("Writing session savefile.", context="TIMER_SAVE", level=logging.DEBUG) + save_data = { + guildid: [tchannel.serialise() for tchannel in tchannels.values()] + for guildid, tchannels in self.guild_channels.items() + } + path = os.path.join(self.save_dir, self.save_fn) + # Rotate + if os.path.exists(path): + os.rename(path, path + '.old') + + with open(path, 'wb') as f: + pickle.dump(save_data, f, pickle.HIGHEST_PROTOCOL) + + if reason: + shutil.copy2(path, path + '.' + reason) + + def restore_from_save(self): + log("------------------------Beginning session restore.", context="TIMER_RESTORE") + path = os.path.join(self.save_dir, self.save_fn) + if os.path.exists(path): + with open(path, 'rb') as f: + save_data = pickle.load(f) + + for guildid, data_channels in save_data.items(): + log("Restoring Guild (gid:{}).".format(guildid), context='TIMER_RESTORE') + tchannels = self.guild_channels.get(guildid, None) + if tchannels: + [tchannels[data['channelid']].restore_from(data) + for data in data_channels if data['channelid'] in tchannels] + log("------------------------Session restore complete.", context="TIMER_RESTORE") + + async def load_timers(self): + # Populate the pattern cache with the latest patterns + tables.patterns.fetch_rows_where(_extra="INNER JOIN 'current_timer_patterns' USING (patternid)") + + # Build and load all the timers, preserving the existing ones + timer_rows = tables.timers.fetch_rows_where() + timers = [Timer(row) for row in timer_rows] + timers = [timer for timer in timers if timer.load()] + + # Create the TimerChannels + guild_channels = {} + for timer in timers: + channels = guild_channels.get(timer.channel.guild.id, None) + if channels is None: + channels = guild_channels[timer.channel.guild.id] = {} + channel = channels.get(timer.channel.id, None) + if channel is None: + channel = channels[timer.channel.id] = TimerChannel(timer.channel) + channel.timers.append(timer) + + self.guild_channels = guild_channels + + def create_timer(self, role, channel, name, **kwargs): + guild = role.guild + new_timer = Timer.create(role.id, guild.id, name, channel.id, **kwargs) + if not new_timer.load(): + return None + + tchannels = self.guild_channels.get(guild.id, None) + if tchannels is None: + tchannels = self.guild_channels[guild.id] = {} + tchannel = tchannels.get(channel.id, None) + if tchannel is None: + tchannel = tchannels[channel.id] = TimerChannel(channel) + tchannel.timers.append(new_timer) + asyncio.create_task(tchannel.update_pin(force=True)) + + return new_timer + + async def obliterate_timer(self, timer): + # Remove the timer from its channel + channel = self.guild_channels[timer.channel.guild.id][timer.channel.id] + channel.timers.remove(timer) + if not channel.timers: + self.guild_channels[timer.channel.guild.id].pop(timer.channel.id) + + # Destroy the timer, unsubscribing members and deleting it from data + await timer.destroy() + + # Refresh the pinned message + await channel.update_pin(force=True) + + def move_timer(self, timer, new_channelid): + """ + Bind a timer to a new channelid. + """ + channels = self.guild_channels[timer.data.guildid] + old_channel = channels[timer.data.channelid] + old_channel.timers.remove(timer) + timer.data.channelid = new_channelid + timer.load() + + if new_channelid not in channels: + channels[new_channelid] = TimerChannel(timer.channel) + channels[new_channelid].timers.append(timer) + + def fetch_timer(self, roleid): + row = tables.timers.fetch(roleid) + if row: + channels = self.guild_channels.get(row.guildid, None) + if channels: + channel = channels.get(row.channelid, None) + if channel: + return next((timer for timer in channel.timers if timer.roleid == roleid), None) + + def get_timers_in(self, guildid, channelid=None): + timers = [] + channels = self.guild_channels.get(guildid, None) + if channels is not None: + if channelid is None: + timers = [timer for channel in channels.values() for timer in channel.timers] + elif channelid in channels: + timers = channels[channelid].timers + + return timers + + def get_subscriber(self, userid, guildid): + return next( + (timer.subscribers[userid] + for channel in self.guild_channels.get(guildid, {}).values() + for timer in channel.timers + if userid in timer.subscribers), + None + ) + + async def on_exception(self, ctx, exception): + if isinstance(exception, InvalidPattern): + await ctx.reply( + embed=discord.Embed( + description=( + "{}\n\n" + "See `{}help patterns` for more information about timer patterns." + ).format(exception.msg, ctx.best_prefix), + colour=discord.Colour.red() + ) + ) + else: + await super().on_exception(ctx, exception) + + +module = TimerInterface() diff --git a/bot/Timer/activity_events.py b/bot/Timer/activity_events.py new file mode 100644 index 0000000..6173a94 --- /dev/null +++ b/bot/Timer/activity_events.py @@ -0,0 +1,14 @@ +async def message_tracker(client, message): + if message.guild: + sub = client.interface.get_subscriber(message.author.id, message.guild.id) + if sub and sub.timer.channel == message.channel: + sub.touch() + if not sub.member: + sub.set_member(message.author) + + +async def reaction_tracker(client, payload): + if payload.guild_id: + sub = client.interface.get_subscriber(payload.user_id, payload.guild_id) + if sub and sub.timer.channel.id == payload.channel_id: + sub.touch() diff --git a/bot/Timer/core.py b/bot/Timer/core.py new file mode 100644 index 0000000..4f23b4e --- /dev/null +++ b/bot/Timer/core.py @@ -0,0 +1,1205 @@ +import json +import asyncio +import logging +import datetime +from collections import namedtuple + +import cachetools +import discord + +from meta import client, log +from data import tables +from settings import TimerSettings, GuildSettings + +from .voice_notify import play_alert +from .lib import join_emoji, leave_emoji, now, parse_dur, best_prefix, TimerState, NotifyLevel, InvalidPattern + + +# NamedTuple represeting a pattern stage +Stage = namedtuple('Stage', ('name', 'duration', 'message', 'focus')) + + +class Pattern: + _slots = ('row', 'stages') + + _cache = cachetools.LFUCache(1000) + _table = tables.patterns + + default_work_stage = "Work" + default_work_message = "Good luck!" + default_break_stage = "Break" + default_break_message = "Have a rest!" + + def __init__(self, row, stages=None): + self.row = row + self.stages = stages or [Stage(*stage) for stage in json.loads(row.stage_str)] + self._cache[(self.__class__, self.row.patternid)] = self + + def __iter__(self): + return iter(self.stages) + + def __str__(self): + return self.display() + + def display(self, brief=None, truncate=None): + brief = brief if brief is not None else self.row.short_repr + if brief: + if truncate and len(self.stages) > truncate: + return "/".join(str(stage.duration) for stage in self.stages[:truncate]) + '/...' + else: + return "/".join(str(stage.duration) for stage in self.stages) + else: + return ";\n".join( + "{0.name}, {0.duration}{1}, {0.message}".format(stage, '*' * stage.focus) + for stage in self.stages + ) + + @classmethod + def from_userstr(cls, string, timerid=None, userid=None, guildid=None): + """ + Parse a user-provided string into a `Pattern`, if possible. + Raises `InvalidPattern` for parsing errors. + Where possible, an existing `Pattern` will be returned, + otherwise a new `Pattern` will be created. + + Accepts kwargs to describe the parsing context. + """ + if not string: + raise InvalidPattern("No pattern provided!") + + pattern = None + + # First try presets + if userid: + row = tables.user_presets.select_one_where(userid=userid, preset_name=string) + if row: + pattern = cls.get(row['patternid']) + + if not pattern and guildid: + row = tables.guild_presets.select_one_where(guildid=guildid, preset_name=string) + if row: + pattern = cls.get(row['patternid']) + + # Then try string parsing + if not pattern: + stages = None + if ';' in string or ',' in string: + # Long form + # Accepts stages as 'name, length' or 'name, length, message' + short_repr = False + stage_blocks = string.strip(';').split(';') + stages = [] + for block in stage_blocks: + # Extract stage components + parts = block.split(',', maxsplit=2) + if len(parts) == 1: + raise InvalidPattern( + "`{}` is not of the form `name, length` or `name, length, message`.".format(block) + ) + elif len(parts) == 2: + name, dur = parts + message = None + else: + name, dur, message = parts + + # Parse duration + dur = dur.strip() + focus = dur.startswith('*') or dur.endswith('*') + if focus: + dur = dur.strip('* ') + + if not dur.isdigit(): + raise InvalidPattern( + "`{}` in `{}` couldn't be parsed as a duration.".format(dur, block.strip()) + ) + + # Build and add stage + stages.append(Stage(name.strip(), int(dur), (message or '').strip(), focus)) + elif '/' in string: + # Short form + # Only accepts numerical stages + short_repr = True + stage_blocks = string.strip('/').split('/') + stages = [] + + is_work = True # Whether the current stage is a work or break stage + default_focus = '*' not in string # Whether to use default focus flags + for block in stage_blocks: + # Parse duration + dur = block.strip() + focus = dur.startswith('*') or dur.endswith('*') + if focus: + dur = dur.strip('* ') + + if not dur.isdigit(): + raise InvalidPattern( + "`{}` couldn't be parsed as a duration.".format(dur) + ) + + # Build and add stage + if is_work: + stages.append(Stage( + cls.default_work_stage, + int(dur), + cls.default_work_message, + focus=True if default_focus else focus + )) + else: + stages.append(Stage( + cls.default_break_stage, + int(dur), + cls.default_break_message, + focus=False if default_focus else focus + )) + + is_work = not is_work + else: + raise InvalidPattern("Patterns must have more than one stage!") + + # Create the stage string + stage_str = json.dumps(stages) + + # Fetch or create the pattern row + row = cls._table.fetch_or_create( + short_repr=short_repr, + stage_str=stage_str + ) + + # Initialise and return the pattern + if row.patternid in cls._cache: + pattern = cls._cache[row.patternid] + else: + pattern = cls(row, stages=stages) + + return pattern + + @classmethod + @cachetools.cached(_cache) + def get(cls, patternid): + return cls(cls._table.fetch(patternid)) + + +class Timer: + __slots__ = ( + 'data', + 'settings', + 'state', + 'current_pattern', + 'stage_index', + 'stage_start', + '_loop_wait_task', + 'subscribers', + 'message_ids', + 'guild', + 'role', + 'channel', + 'voice_channel', + 'last_voice_update' + ) + + _table = tables.timers + + max_warnings = 1 + + def __init__(self, data): + self.data = data + self.settings = TimerSettings(data.roleid, timer=self) + + self.state: TimerState = TimerState.UNSET # State of the timer + self.current_pattern: Pattern = None # Current pattern set up + self.stage_index: int = None # Index of the current stage in the pattern + self.stage_start: int = None # Timestamp of the start of the stage + + self._loop_wait_task = None # Task used to trigger runloop read + + self.subscribers = {} # TimerSubscribers in the timer + self.message_ids = [] # Notification messages owned by the timer + + self.last_voice_update = 0 # Timestamp of last vc update + + # Discord objects, intialised in `Timer.load()` + self.guild: discord.Guild = None + self.role: discord.Role = None + self.channel: discord.TextChannel = None + self.voice_channel: discord.VoiceChannel = None + + def __getattr__(self, key): + # TODO: Dangerous due to potential property attribute errors + if key in self.data.table.columns: + return getattr(self.data, key) + else: + raise AttributeError(key) + + def __contains__(self, userid): + return userid in self.subscribers + + @property + def default_pattern(self) -> Pattern: + return Pattern.get(self.data.patternid) + + @property + def current_stage(self): + return self.current_pattern.stages[self.stage_index] + + @property + def remaining(self): + """ + The remaining time (in seconds) in the current stage. + """ + return int(60*self.current_stage.duration - (now() - self.stage_start)) + + @property + def pretty_remaining(self): + return parse_dur( + self.remaining, + show_seconds=True + ) if self.state == TimerState.RUNNING else '*Not Running*' + + @property + def pinstatus(self): + """ + Return a formatted status string for use in the pinned status message. + """ + return self.status_string() + + @property + def voice_channel_name(self): + return self.settings.vc_name.value.replace( + "{stage_name}", self.current_stage.name + ).replace( + "{remaining}", parse_dur( + int(60*self.current_stage.duration - (now() - self.stage_start)), + show_seconds=False + ) + ).replace( + "{name}", self.data.name + ).replace( + "{stage_dur}", parse_dur(self.current_stage.duration * 60, show_seconds=False) + ).replace( + "{sub_count}", str(len(self.subscribers)) + ).replace( + "{pattern}", (self.current_pattern or self.default_pattern).display(brief=True, truncate=6) + ) + + @property + def oneline_summary(self): + """ + Return a one line summary status message + """ + if self.state == TimerState.RUNNING: + status = "Running" + elif self.state == TimerState.PAUSED: + status = "Paused" + elif self.state in (TimerState.STOPPED, TimerState.UNSET): + status = "Stopped" + + return "{name} ({status} with {members} members, {setup}.)".format( + name=self.data.name, + status=status, + members=len(self.subscribers) if self.subscribers else 'no', + setup=(self.current_pattern or self.default_pattern).display(brief=True) + ) + + @property + def pretty_summary(self): + pattern = self.current_pattern or self.default_pattern + stage_str = "/".join( + "{1}{0}{1}".format(stage.duration, (i == self.stage_index) * '**') + for i, stage in enumerate(pattern.stages) + ) + + if self.state == TimerState.RUNNING: + status_str = "Stage `{}`, `{}` remaining\n".format(self.current_stage.name, self.pretty_remaining) + elif self.state == TimerState.PAUSED: + status_str = "*Timer is paused.*\n" + else: + status_str = '' + + if self.subscribers: + member_str = "Members: " + ", ".join("<@{}>".format(uid) for uid in self.subscribers) + else: + member_str = "*No members.*" + + return "{}{}: {}\n{}{}".format( + self.role.mention, + "({})".format(self.data.name) if self.data.name != self.role.name else '', + stage_str, + status_str, + member_str + ) + + def status_string(self, show_seconds=False): + subbed_names = [m.name for m in self.subscribers.values()] + subbed_str = "```{}```".format(", ".join(subbed_names)) if subbed_names else "*No members*" + + if self.state in (TimerState.RUNNING, TimerState.PAUSED, TimerState.STOPPED): + running = self.state in (TimerState.RUNNING, TimerState.PAUSED) + + # Collect the component strings and data + pretty_remaining = parse_dur( + int(60*self.current_stage.duration - (now() - self.stage_start)), + show_seconds=show_seconds + ) if running else '' + + # Create the stage string + longest_stage_len = max(len(stage.name) for stage in self.current_pattern.stages) + stage_format = "`{{prefix}}{{name:>{}}}:` {{dur}} min {{current}}".format(longest_stage_len) + + stage_str = '\n'.join( + stage_format.format( + prefix="->" if running and i == self.stage_index else "​ ", + name=stage.name, + dur=stage.duration, + current="(**{}**)".format(pretty_remaining) if running and i == self.stage_index else '' + ) for i, stage in enumerate(self.current_pattern.stages) + ) + + # Create the final formatted status string + status_str = ("**{name}**: {stage} {paused}\n" + "{stage_str}\n" + "{subbed_str}").format(name=self.data.name, + paused=" ***Paused***" if self.state == TimerState.PAUSED else "", + stage=self.current_stage.name if running else "*Timer not running.*", + stage_str=stage_str, + subbed_str=subbed_str) + else: + status_str = "**{}**: *Timer not set up.*\n{}".format(self.data.name, subbed_str) + + return status_str + + @classmethod + def create(cls, roleid, guildid, name, channelid, **kwargs): + log("Creating Timer with (roleid={!r}, guildid={!r}, name={!r}, channelid={!r})".format(roleid, + guildid, + name, + channelid), + context="rid:{}".format(roleid)) + + # Remove any existing timers under the same roleid + cls._table.delete_where(roleid=roleid) + + # Create new timer + data = cls._table.create_row(roleid=roleid, + guildid=guildid, + name=name, + channelid=channelid, + **kwargs) + + # Instantiate and return + return cls(data) + + async def destroy(self): + log("Destroying Timer with data {!r}".format(self.data), context="rid:{}".format(self.data.roleid)) + + # Stop the timer and unsubscribe all members + self.stop() + for subid in list(self.subscribers.keys()): + await self.unsubscribe(subid) + + # Remove the timer from data + self._table.delete_where(roleid=self.data.roleid) + + def load(self): + """ + Load discord objects from data. + + Returns + ------- + `True` if the timer successfully loaded. + `False` if the guild, channel, or role no longer exist. + """ + data = self.data = tables.timers.fetch(self.data.roleid) + + self.guild = client.get_guild(data.guildid) + if not self.guild: + log("Timer gone, guild (gid: {}) no longer exists.".format(data.guildid), + "tid:{}".format(data.roleid)) + return False + + self.role = self.guild.get_role(data.roleid) + if not self.role: + log("Timer gone, role no longer exists.", + "tid:{}".format(data.roleid)) + return False + + self.channel = self.guild.get_channel(data.channelid) + if not self.channel: + log("Timer gone, channel (cid: {}) no longer exists.".format(data.channelid), + "tid:{}".format(data.roleid)) + return False + + if data.voice_channelid: + self.voice_channel = self.guild.get_channel(data.voice_channelid) + else: + self.voice_channel = None + + return True + + async def post(self, *args, **kwargs): + """ + Safely send a message to the timer channel. + If an error occurs, in most cases ignore it. + As such, is not guaranteed to yield a `discord.Message`. + """ + # TODO: Reconsider if we want some form of cleanup here + try: + return await self.channel.send(*args, **kwargs) + except discord.Forbidden: + # We are not allowed to send to the timer channel + # Stop the timer + self.stop() + except discord.HTTPException: + # An unknown discord error occured + # Silently continue + pass + + async def setup(self, pattern=None, actor=None): + """ + Setup the timer with the given timer pattern. + If no pattern is given, uses the default pattern. + """ + pattern = pattern or self.default_pattern + + log("Setting up timer with pattern {!r}.".format(pattern.row), context="rid:{}".format(self.data.roleid)) + + # Ensure timer is stopped + self.stop() + + # Update runtime data for new pattern + self.current_pattern = pattern + self.stage_index = 0 + + tables.timer_pattern_history.insert( + timerid=self.data.roleid, + patternid=pattern.row.patternid, + modified_by=actor + ) + + async def start(self): + """ + Start the timer with the current pattern, or the default pattern. + """ + log("Starting timer.", context="rid:{}".format(self.data.roleid)) + if not self.current_pattern: + await self.setup() + + await self.change_stage(0, inactivity_check=False, finished_old=False) + self.state = TimerState.RUNNING + for subber in self.subscribers.values(): + subber.new_session() + + asyncio.create_task(self.runloop()) + + def stop(self): + """ + Stop the timer. + """ + if not self.state == TimerState.STOPPED: + log("Stopping timer.", context="rid:{}".format(self.data.roleid)) + # Trigger session save on all subscribers + for subber in self.subscribers.values(): + subber.close_session() + + # Change status to stopped + self.state = TimerState.STOPPED + + # Cancel loop wait task + if self._loop_wait_task and not self._loop_wait_task.done(): + self._loop_wait_task.cancel() + + def shift(self, amount=None): + """ + Shift the running timer forwards or backwards by the provided amount. + If `amount` is not given, aligns the start of the session to the nearest (UTC) hour. + + `amount` is the amount (in seconds) the stage start is shifted *forwards*. + This effectively adds `amount` to the stage duration, since it will change `amount` seconds later. + """ + if amount is None: + # Get the difference to the nearest hour + started = datetime.datetime.utcfromtimestamp(self.stage_start) + amount = started.minute * 60 + started.second + if amount > 1800: + amount = 3600 - amount + else: + amount = -1 * amount + + # Find the target stage and new stage start + remaining_amount = -1 * amount + i = self.stage_index + is_first = True + while True: + stage = self.current_pattern.stages[i] + stage_remaining = self.remaining if is_first else stage.duration * 60 + if remaining_amount >= stage_remaining: + is_first = False + remaining_amount -= stage_remaining + i = (i + 1) % len(self.current_pattern.stages) + else: + break + target_stage = i + if is_first: + new_stage_start = self.stage_start - remaining_amount + shifts = [(self.stage_index, -remaining_amount)] + else: + new_stage_start = now() - remaining_amount + shifts = [ + (self.stage_index, now() - self.stage_start), + (target_stage, -1 * remaining_amount) + ] + + # Apply shifts + for subber in self.subscribers.values(): + for shift in shifts: + subber.stage_shift(*shift) + + # Update timer + self.stage_index = target_stage + self.stage_start = new_stage_start + + # Cancel loop wait task to rerun runloop + if self._loop_wait_task and not self._loop_wait_task.done(): + self._loop_wait_task.cancel() + + async def change_stage(self, stage_index, post=True, inactivity_check=True, finished_old=True): + """ + Change the timer stage to the given index in the current pattern. + + Parameters + ---------- + stage_index: int + Index to move to in the current pattern. + Will be modded by the lenth of the pattern. + post: bool + Whether to post a stage change message in the linked text channel. + """ + log( + "Changing stage from {} to {}. (post={}, inactivity_check={}, finished_old={})".format( + self.stage_index, + stage_index, + post, + inactivity_check, + finished_old + ), context="rid:{}".format(self.data.roleid), level=logging.DEBUG + ) + + # If the stage change is triggered by finishing a stage, adjust current time to match + if finished_old: + _now = self.stage_start + self.current_stage.duration * 60 + if not -3600 < _now - now() < 3600: + # Don't voice notify if there is a significant real time difference + post = False + else: + _now = now() + + # Update stage info and save the current and new stages + old_stage = self.current_stage + old_index = self.stage_index + + self.stage_index = stage_index % len(self.current_pattern.stages) + self.stage_start = _now + new_stage = self.current_stage + + # Update the voice channel + asyncio.create_task(self.update_voice()) + + if len(self.subscribers) == 0: + # Skip notification and subscriber checks + # Handle empty reset, if enabled + if self.settings.auto_reset.value: + await self.setup() + return + + # Update subscriber sessions + if finished_old: + for sub in self.subscribers.values(): + sub.stage_finished(old_index) + + # Track subscriber inactivity + needs_warning = [] + unsubs = [] + if inactivity_check: + for sub in self.subscribers.values(): + if sub.warnings >= self.max_warnings: + sub.warnings += 1 + unsubs.append(sub) + elif (_now - sub.last_seen) > old_stage.duration * 60: + sub.warnings += 1 + if sub.warnings >= self.max_warnings: + needs_warning.append(sub) + + # Build message components + old_stage_str = "**{}** finished! ".format(old_stage.name) if finished_old else "" + warning_str = ( + "{} you will be unsubscribed on the next stage if you do not respond or react to this message.\n".format( + ', '.join('<@{}>'.format(sub.userid) for sub in needs_warning) + ) + ) if needs_warning else "" + unsub_str = ( + "{} you have been unsubscribed due to inactivity!\n".format( + ', '.join('<@{}>'.format(sub.userid) for sub in unsubs) + ) + ) if unsubs else "" + main_line = "{}Starting **{}** ({} minutes). {}".format( + old_stage_str, + new_stage.name, + new_stage.duration, + new_stage.message + ) + please_line = ( + "Please respond or react to this message to avoid being unsubscribed.\n" + ) if not self.settings.compact.value else "" + + # Post stage change message, if required + if post: + make_unmentionable = False + can_manage = self.guild.me.guild_permissions.manage_roles and self.guild.me.top_role > self.role + # Make role mentionable + if not self.role.mentionable and can_manage: + try: + await self.role.edit(mentionable=True, reason="Notifying for stage change.") + make_unmentionable = True + except discord.HTTPException: + pass + + # Send the message + out_msg = await self.post( + "{} {}\n{}{}{}".format( + self.role.mention, + main_line, + please_line, + warning_str, + unsub_str + ) + ) + if out_msg: + # Mark the message as being tracked + self.message_ids.append(out_msg.id) + self.message_ids = self.message_ids[-5:] # Truncate + + # Add the check reaction + try: + await out_msg.add_reaction(join_emoji) + await out_msg.add_reaction(leave_emoji) + except discord.HTTPException: + pass + + if make_unmentionable: + try: + await self.role.edit(mentionable=False, reason="Notifying finished.") + except discord.HTTPException: + pass + + # Do the voice alert, if required + if self.settings.voice_alert.value and self.voice_channel and finished_old and post: + asyncio.create_task(play_alert(self.voice_channel)) + + # Notify and unsubscribe as required + for sub in list(self.subscribers.values()): + try: + to_send = None + if sub in unsubs: + sub = await self.unsubscribe(sub.userid) + if sub.notify_level >= NotifyLevel.FINAL: + to_send = ( + "You have been unsubscribed from the group **{}** in {} due to inactivity!\n" + "You were subscribed for **{}**." + ).format(self.data.name, self.channel.mention, sub.pretty_clocked) + elif sub in needs_warning and sub.notify_level >= NotifyLevel.WARNING: + to_send = ( + "**Warning** from group **{}** in {}!\n" + "Please respond or react to a timer message " + "to avoid being unsubscribed on the next stage.\n{}".format( + self.data.name, + self.channel.mention, + main_line + ) + ) + elif sub.notify_level >= NotifyLevel.ALL: + to_send = "Status update for group **{}** in {}!\n{}".format(self.data.name, + self.channel.mention, + main_line) + + if to_send is not None: + await sub.send(to_send) + except discord.HTTPException: + pass + + async def subscribe(self, member, post=False): + """ + Subscribe a new member to the timer. + This may raise `discord.HTTPException`. + """ + log("Subscribing {!r}.".format(member), context="rid:{}".format(self.data.roleid)) + studyrole = GuildSettings(member.guild.id).studyrole.value + try: + if studyrole: + await member.add_roles(self.role, studyrole, reason="Applying study group role and global studyrole.") + else: + await member.add_roles(self.role, reason="Applying study group role.") + except discord.Forbidden: + desc = ( + "I don't have enough permissions to subscribe {} to {}!\n" + ).format(member.mention, self.role.mention) + if not self.guild.me.guild_permissions.manage_roles: + desc += "I require the `manage_roles` permission!" + elif not self.guild.me.top_role > self.role: + desc += ( + "My top role needs to be higher in the role list than the study group role {}." + ).format(self.role.mention) + elif studyrole and not self.guild.me.top_role > studyrole: + desc += ( + "My top role needs to be higher in the role list than the studyrole {}." + ).format(studyrole.mention) + + await self.post( + embed=discord.Embed( + description=desc, + colour=discord.Colour.red() + ) + ) + subscriber = TimerSubscriber(self, member.id, member=member) + if self.state == TimerState.RUNNING: + subscriber.new_session() + + self.subscribers[member.id] = subscriber + + if post: + # Send a welcome message + welcome = "Welcome to **{}**, {}!".format(self.data.name, member.mention) + welcome += ' ' if self.settings.compact.value else '\n' + + if self.state == TimerState.RUNNING: + welcome += "Currently on stage **{}** with **{}** remaining. {}".format( + self.current_stage.name, + self.pretty_remaining, + self.current_stage.message + ) + elif self.state in (TimerState.STOPPED, TimerState.UNSET): + welcome += ( + "The group timer is not running. Start it with `{0}start` " + "(or `{0}start ` to use a different timer pattern)." + ).format(best_prefix(member.guild.id)) + await self.post(welcome) + return subscriber + + async def unsubscribe(self, userid, post=False): + """ + Unsubscribe a member from the timer. + Raises `ValueError` if the user isn't subscribed. + Returns the old subscriber for session reporting. + """ + log("Unsubscribing (uid:{}).".format(userid), context="rid:{}".format(self.data.roleid)) + + if userid not in self.subscribers: + raise ValueError("Attempted to unsubscribe a non-existent user!") + subscriber = self.subscribers.pop(userid) + subscriber.close_session() + + studyrole = GuildSettings(self.guild.id).studyrole.value + try: + # Use a manual request to avoid requiring the member object + await client.http.remove_role(self.guild.id, userid, self.role.id, reason="Removing study group role.") + if studyrole: + await client.http.remove_role(self.guild.id, userid, studyrole.id, reason="Removing global studyrole.") + except discord.HTTPException: + pass + + if post: + await self.post( + "Goodbye <@{}>! You were subscribed for **{}**.".format( + userid, subscriber.pretty_clocked + ) + ) + + return subscriber + + async def update_voice(self): + """ + Update the name of the associated voice channel. + """ + if not self.voice_channel or self.voice_channel not in self.guild.channels: + # Return if there is no associated voice channel + return + if self.state != TimerState.RUNNING: + # Don't update if we aren't running + return + if now() - self.last_voice_update < 10 * 60: + # Return if the last update was less than 10 minutes ago (discord ratelimiting) + return + + name = self.voice_channel_name + + if name == self.voice_channel.name: + # Don't update if there are no changes + return + + log("Updating vc name to {}.".format(name), + context="rid:{}".format(self.data.roleid), + level=logging.DEBUG) + try: + self.last_voice_update = now() + await self.voice_channel.edit(name=name) + self.last_voice_update = now() + except discord.HTTPException: + # Nothing we can do + pass + + async def runloop(self): + """ + Central runloop. + Handles firing stage-changes and voice channel updates. + """ + while self.state == TimerState.RUNNING: + remaining = self.remaining + if remaining <= 0: + try: + await self.change_stage(self.stage_index + 1) + except Exception: + log("Exception encountered while changing stage.", + context="rid:{}".format(self.role.id), + level=logging.ERROR, + add_exc_info=True) + elif remaining > 600 and self.subscribers: + await self.update_voice() + + self._loop_wait_task = asyncio.create_task(asyncio.sleep(min(600, remaining))) + try: + await self._loop_wait_task + except asyncio.CancelledError: + pass + + def serialise(self): + return { + 'roleid': self.data.roleid, + 'state': self.state.value, + 'patternid': self.current_pattern.row.patternid if self.current_pattern else None, + 'stage_index': self.stage_index, + 'stage_start': self.stage_start, + 'message_ids': self.message_ids, + 'subscribers': [subber.serialise() for subber in self.subscribers.values()], + 'last_voice_update': self.last_voice_update + } + + def restore_from(self, data): + log("Restoring Timer (rid:{}).".format(data['roleid']), context='TIMER_RESTORE') + self.stage_index = data['stage_index'] + self.stage_start = data['stage_start'] + self.state = TimerState(data['state']) + self.current_pattern = Pattern.get(data['patternid'] if data['patternid'] is not None else self.patternid) + self.message_ids = data['message_ids'] + self.last_voice_update = data['last_voice_update'] + + self.subscribers = {} + for sub_data in data['subscribers']: + subber = TimerSubscriber(self, sub_data['userid'], name=sub_data['name']) + subber.restore_from(sub_data) + self.subscribers[sub_data['userid']] = subber + + asyncio.create_task(self.runloop()) + + +class TimerChannel: + """ + Represents a discord text channel holding one or more timers. + + Manages the pinned update message. + """ + __slots__ = ( + 'channel', + 'timers', + 'pinned_msg', + 'pinned_msg_id', + 'previous_desc', + 'failure_count' + ) + + def __init__(self, channel): + self.channel: discord.TextChannel = channel + + self.timers = [] + self.pinned_msg = None + self.pinned_msg_id = None + + self.previous_desc = '' + + self.failure_count = 0 + + async def update_pin(self, force=False): + if not force and self.failure_count > 5: + return + + if self.channel not in self.channel.guild.channels: + return + + if self.pinned_msg is None and self.pinned_msg_id is not None: + try: + self.pinned_msg = await self.channel.fetch_message(self.pinned_msg_id) + except discord.HTTPException: + self.pinned_msg_id = None + + desc = '\n\n'.join(timer.pinstatus for timer in self.timers) + if desc and desc != self.previous_desc: + self.previous_desc = desc + + # Build embed + embed = discord.Embed( + title="Pomodoro Timer Status", + description=desc, + timestamp=datetime.datetime.utcnow() + ) + embed.set_footer(text="Last Updated") + + if self.pinned_msg is not None: + try: + await self.pinned_msg.edit(embed=embed) + except discord.NotFound: + self.pinned_msg = None + except discord.HTTPException: + # An obscure permission error or discord dying? + self.failure_count += 1 + return + elif force or all(timer.state != TimerState.STOPPED for timer in self.timers): + # Attempt to generate a new pinned message + try: + self.pinned_msg = await self.channel.send(embed=embed) + except discord.Forbidden: + # We can't send embeds, or maybe any messages? + # First stop the timers, then try to report the error + self.failure_count = 100 + for timer in self.timers: + timer.stop() + + perms = self.channel.permissions_for(self.channel.guild.me) + if perms.send_messages and not perms.embed_links: + try: + await self.channel.send( + "I require the `embed links` permission in this channel! Timers stopped." + ) + except discord.HTTPException: + # Nothing we can do... + pass + return + + # Now attempt to pin the message + try: + await self.pinned_msg.pin() + except discord.Forbidden: + await self.channel.send( + "I don't have the `manage messages` permission required to pin the channel status message! " + "Please pin the message manually." + ) + except discord.HTTPException: + pass + + def serialise(self): + return { + 'channelid': self.channel.id, + 'pinned_msg_id': self.pinned_msg.id if self.pinned_msg else None, + 'timers': [timer.serialise() for timer in self.timers] + } + + def restore_from(self, data): + log("Restoring Timer Channel (cid:{}).".format(data['channelid']), context='TIMER_RESTORE') + self.pinned_msg_id = data['pinned_msg_id'] + + timers = {timer.data.roleid: timer for timer in self.timers} + for timer_data in data['timers']: + timer = timers.get(timer_data['roleid'], None) + if timer is not None: + timer.restore_from(timer_data) + + +class TimerSubscriber: + """ + Represents a member subscribed to a timer. + """ + __slots__ = ( + 'timer', + 'userid', + '_name', + 'member', + '_fetch_task', + 'subscribed_at', + 'last_seen', + 'warnings', + 'clocked_time', + 'session_started', + 'session', + ) + + def __init__(self, timer: Timer, userid, member=None, name=None): + self.timer = timer # Timer the member is subscribed to + self.userid = userid # Discord userid + self.member = member # Discord member object, if provided + + self._name = name # Backup name used when there is no member object + self._fetch_task = None # Potential asyncio.Task for fetching the member object + + self.last_seen = now() # Last seen, for activity tracking + self.warnings = 0 # Current number of warnings + self.clocked_time = 0 # Total clocked session time in this subscription (in seconds) + + self.session_started = None + self.session = None + + if self.member and self.member.name != self.user_data.name: + self.user_data.name = self.member.name + + @property + def name(self): + """ + Name of the member. + May be retrieved from `_name` if the member doesn't exist yet. + """ + return self.member.display_name if self.member else self._name or 'Unknown' + + @property + def user_data(self): + return tables.users.fetch_or_create(self.userid) + + @property + def notify_level(self): + raw = self.user_data.notify_level + return NotifyLevel(raw) if raw is not None else NotifyLevel.WARNING + + @property + def pretty_clocked(self): + return parse_dur(self.clocked_time, True) + + @property + def unsaved_time(self): + """ + Clocked time not yet saved in a session. + """ + return (now() - self.session_started) if self.session else 0 + + def touch(self): + """ + Update `last_seen`, and reset warning count. + """ + self.last_seen = now() + self.warnings = 0 + + async def _fetch_member(self): + try: + self.member = await self.timer.guild.fetch_member(self.userid) + if self.member.name != self.user_data.name: + self.user_data.name = self.member.name + except discord.HTTPException: + pass + + async def send(self, *args, **kwargs): + if self.member: + await self.member.send(*args, **kwargs) + else: + if self._fetch_task is None: + self._fetch_task = asyncio.create_task(self._fetch_member()) + if not self._fetch_task.done(): + try: + await self._fetch_task + except asyncio.CancelledError: + pass + await self.send(*args, **kwargs) + + def set_member(self, member): + """ + Set the member for this subscriber, if unset. + """ + if not self.member and member.id == self.userid: + self.member = member + if self.member.name != self.user_data.name: + self.user_data.name = self.member.name + + if self._fetch_task: + if self._fetch_task.done(): + self._fetch_task = None + else: + self._fetch_task.cancel() + + def new_session(self): + """ + Start a new session for this subscriber. + Requires the timer to be setup. + Typically called after subscription or timer start. + """ + # Close any existing session + self.close_session() + + # Initialise the new session + self.session_started = now() + self.session = [(0, 0) for stage in self.timer.current_pattern] + + # Apply the initial join shift + if self.timer.state == TimerState.RUNNING: + shift = self.timer.stage_start - self.session_started + self.session[self.timer.stage_index] = (0, shift) + + def close_session(self): + """ + Save and close the current session, if any. + This may occur upon unsubscribing or stopping/pausing the timer. + """ + if self.session: + _now = now() + + # Final shift + if self.timer.state == TimerState.RUNNING: + shift = _now - self.timer.stage_start + count, current_shift = self.session[self.timer.stage_index] + self.session[self.timer.stage_index] = (count, current_shift + shift) + + # Save session + duration = _now - self.session_started + focused_duration = sum( + t[0] * stage.duration * 60 + t[1] + for t, stage in zip(self.session, self.timer.current_pattern) + if stage.focus + ) + # Don't save if the session was under a minute + if duration > 60: + tables.sessions.insert( + guildid=self.timer.guild.id, + userid=self.userid, + roleid=self.timer.role.id, + start_time=self.session_started, + duration=duration, + focused_duration=focused_duration, + patternid=self.timer.current_pattern.row.patternid, + stages=json.dumps(self.session) + ) + + # Update clocked time + self.clocked_time += duration + + # Reset session state + self.session_started = None + self.session = None + + def stage_finished(self, stageid): + """ + Finish a stage, adding it to the running session + """ + count, shift = self.session[stageid] + self.session[stageid] = count + 1, shift + + def stage_shift(self, stageid, diff): + """ + Shift a stage (i.e. move the stage start forwards by `shift`, temporarily increasing the stage length). + """ + count, shift = self.session[stageid] + self.session[stageid] = count, shift + diff + + def serialise(self): + return { + 'userid': self.userid, + 'timerid': self.timer.role.id, + 'session_started': self.session_started, + 'session': self.session, + 'name': self.name + } + + def restore_from(self, data): + log("Restoring Subscriber (uid:{}).".format(data['userid']), context='TIMER_RESTORE') + self.session_started = data['session_started'] + self.session = data['session'] diff --git a/bot/Timer/guild_events.py b/bot/Timer/guild_events.py new file mode 100644 index 0000000..5ccc02c --- /dev/null +++ b/bot/Timer/guild_events.py @@ -0,0 +1,52 @@ +from data import tables + +from .core import Timer, TimerChannel + + +async def on_guild_join(client, guild): + """ + (Re)-load the guild timers when we join a guild. + """ + count = 0 + timer_rows = tables.timers.fetch_rows_where(guildid=guild.id) + if timer_rows: + timers = [Timer(row) for row in timer_rows] + timers = [timer for timer in timers if timer.load()] + + channels = client.interface.guild_channels[guild.id] = {} + for timer in timers: + channel = channels.get(timer.channel.id, None) + if channel is None: + channel = channels[timer.channel.id] = TimerChannel(timer.channel) + channel.timers.append(timer) + count += 1 + + client.log( + "Joined new guild \"{}\" (gid: {}) and loaded {} pre-existing timers.".format( + guild.name, + guild.id, + count + ) + ) + + +async def on_guild_remove(client, guild): + """ + Unsubscribe and unload the guild timers when we leave a guild. + """ + count = 0 + channels = client.interface.guild_channels.pop(guild.id, {}) + for channelid, tchannel in channels.items(): + for timer in tchannel.timers: + count += 1 + timer.stop() + for subber in timer.subscribers.values(): + subber.close_session() + + client.log( + "Left guild \"{}\" (gid: {}) and cleaned up {} timers.".format( + guild.name, + guild.id, + count + ) + ) diff --git a/bot/Timer/interface.py b/bot/Timer/interface.py deleted file mode 100644 index ba00790..0000000 --- a/bot/Timer/interface.py +++ /dev/null @@ -1,436 +0,0 @@ -import os -import traceback -import logging -import json -import asyncio - -import discord - -from cmdClient import Context - -from logger import log - -from .trackers import message_tracker, reaction_tracker -from .Timer import Timer, TimerChannel, TimerSubscriber, TimerStage, NotifyLevel, TimerState -from .registry import TimerRegistry -from .voice import sub_on_vcjoin - - -class TimerInterface(object): - save_interval = 120 - save_fp = "data/timerstatus.json" - - def __init__(self, client, db_filename): - self.client = client - self.registry = TimerRegistry(db_filename) - - self.guild_channels = {} - self.channels = {} - self.subscribers = {} - - self.last_save = 0 - - self.ready = False - - self.setup_client() - - def setup_client(self): - client = self.client - - # Bind the interface - client.interface = self - - # Ensure required config entry exists - client.config.guilds.ensure_exists("timers") - - # Load timers from database - client.add_after_event("ready", self.launch) - - # Track user activity in timer channels - client.add_after_event("message", message_tracker) - client.add_after_event("raw_reaction_add", reaction_tracker) - client.add_after_event("raw_reaction_add", self.reaction_sub) - - # Voice event handlers - client.add_after_event("voice_state_update", sub_on_vcjoin) - - async def launch(self, client): - if self.ready: - return - - self.load_timers() - await self.restore_save() - - self.ready = True - asyncio.ensure_future(self.updateloop()) - - async def updateloop(self): - while True: - channels = self.channels.values() - delay = max((0.1, 60/len(channels))) - - for tchan in channels: - asyncio.ensure_future(tchan.update()) - await asyncio.sleep(delay) - - if Timer.now() - self.last_save > self.save_interval: - self.update_save() - - def load_timers(self): - client = self.client - - # Get the guilds with timers - guilds = client.config.guilds.find_not_empty("timers") - - for guildid in guilds: - # List of TimerChannels in the guild - channels = [] - - # Fetch the actual guild, if possible - guild = client.get_guild(guildid) - if guild is None: - continue - - # Get the corresponding timers - raw_timers = client.config.guilds.get(guildid, "timers") - for name, roleid, channelid, clock_channelid in raw_timers: - # Get the objects corresponding to the ids - role = guild.get_role(roleid) - channel = guild.get_channel(channelid) - clock_channel = guild.get_channel(clock_channelid) if clock_channelid != 0 else None - - if role is None or channel is None: - # This timer doesn't exist - # TODO: Handle garbage collection - continue - - # Create the new timer - new_timer = Timer(name, role, channel, clock_channel) - - # Get the timer channel, or create it - tchan = self.channels.get(channelid, None) - if tchan is None: - tchan = TimerChannel(channel) - channels.append(tchan) - self.channels[channelid] = tchan - - # Bind the timer to the channel - tchan.timers.append(new_timer) - - # Assign the channels to the guild - self.guild_channels[guildid] = channels - - async def restore_save(self): - # Open save file if it exists - if not os.path.exists(self.save_fp): - return - - with open(self.save_fp) as f: - try: - savedata = json.load(f) - except Exception: - log("Caught the following exception loading the temporary savefile\n{}".format(traceback.format_exc()), - context="TIMER_RESTORE", - level=logging.ERROR) - os.rename(self.save_fp, self.save_fp + '_CORRUPTED') - return - - if savedata: - # Create a roleid: timer map - timers = {timer.role.id: timer for channel in self.channels.values() for timer in channel.timers} - - for timer in savedata['timers']: - if timer['roleid'] in timers: - timers[timer['roleid']].update_from_data(timer) - log("Restored timer {} (roleid {}) from save.".format(timer['name'], timer['roleid']), - context="TIMER_RESTORE") - - for sub_data in savedata['subscribers']: - if sub_data['roleid'] in timers: - timer = timers[sub_data['roleid']] - - guild = self.client.get_guild(sub_data['guildid']) - if guild is None: - continue - - try: - member = await guild.fetch_member(sub_data['id']) - except discord.Forbidden: - continue - except discord.NotFound: - continue - - if member is None: - continue - - subber = TimerSubscriber.deserialise(member, timer, self, sub_data) - self.subscribers[(member.guild.id, member.id)] = subber - timer.subscribed[member.id] = subber - - log("Restored subscriber {} (id {}) in timer {} (roleid {}) from save.".format(member.name, - member.id, - timer.name, - timer.role.id), - context="TIMER_RESTORE") - - for tchan_data in savedata['timer_channels']: - if tchan_data['id'] in self.channels: - tchan = self.channels[tchan_data['id']] - try: - tchan.msg = await tchan.channel.fetch_message(tchan_data['msgid']) - except discord.NotFound: - continue - except discord.Forbidden: - continue - - def update_save(self, save_name="autosave"): - # Generate save dict - timers = [timer for channel in self.channels.values() for timer in channel.timers] - timer_data = [timer.serialise() for timer in timers if timer.stages] - sub_data = [subber.serialise() for subber in self.subscribers.values()] - tchan_data = [{'id': tchan.channel.id, 'msgid': tchan.msg.id} for tchan in self.channels.values() if tchan.msg] - - data = { - 'timers': timer_data, - 'subscribers': sub_data, - 'timer_channels': tchan_data - } - data_str = json.dumps(data) - - # Backup the save file - if os.path.exists(self.save_fp): - os.replace(self.save_fp, "{}.{}.old".format(self.save_fp, save_name)) - - with open(self.save_fp, 'w') as f: - f.write(data_str) - - self.last_save = Timer.now() - - async def reaction_sub(self, client, payload): - """ - Subscribe a user to a timer if press the subscribe reaction. - """ - # Return if the emoji isn't the right one - if str(payload.emoji) != "✅": - return - - # Quit if the reaction is in DM or we can't see the guild - if payload.guild_id is None: - return - guild = client.get_guild(payload.guild_id) - if guild is None: - return - - # Return if the member is already subscribed - if (payload.guild_id, payload.user_id) in self.subscribers: - return - - # Quit if the user is the client - if payload.user_id == client.user.id: - return - - # Get the timers in the current channel - tchan = self.channels.get(payload.channel_id, None) - if tchan is None: - return - timers = tchan.timers - - # Get the timer who owns the message, if any - timer = next((timer for timer in timers if payload.message_id in timer.timer_messages), None) - if timer is None: - return - - # Get the reacting user - user = guild.get_member(payload.user_id) - if user is None: - log(("Recieved subscribe reaction from (uid: {}) " - "in (cid: {}) but could not find the user!").format(payload.user_id, payload.channel_id), - context="TIMER_INTERFACE", - level=logging.ERROR) - return - - # Finally, subscribe the user to the timer - ctx = Context(client, channel=timer.channel, guild=timer.channel.guild, author=user) - log("Reaction-subscribing user {} (uid: {}) to timer {} (rid: {})".format(user.name, - user.id, - timer.name, - timer.role.id), - context="TIMER_INTERFACE") - await self.sub(ctx, user, timer) - - # Send a welcome message - welcome = "Welcome to **{}**, {}!\n".format(timer.name, user.mention) - if timer.stages and timer.state == TimerState.RUNNING: - welcome += "Currently on stage **{}** with **{}** remaining. {}".format( - timer.stages[timer.current_stage].name, - timer.pretty_remaining(), - timer.stages[timer.current_stage].message - ) - elif timer.stages: - welcome += "Group timer is set up but not running." - - await ctx.ch.send(welcome) - - def create_timer(self, group_name, group_role, bound_channel, clock_channel=None): - """ - Create a new timer, attach it to a timer channel, and save it to storage. - """ - guild = group_role.guild - - # Create the new timer - new_timer = Timer(group_name, group_role, bound_channel, clock_channel) - - # Bind the timer to a timer channel, creating if required - tchan = self.channels.get(bound_channel.id, None) - if tchan is None: - # Create the timer channel - tchan = TimerChannel(bound_channel) - self.channels[bound_channel.id] = tchan - - # Add the timer channel to the guild list, creating if required - guild_channels = self.guild_channels.get(guild.id, None) - if guild_channels is None: - guild_channels = [] - self.guild_channels[guild.id] = guild_channels - guild_channels.append(tchan) - tchan.timers.append(new_timer) - - # Store the new timer in guild config - timers = self.client.config.guilds.get(guild.id, "timers") or [] - timers.append( - (group_name, - group_role.id, - bound_channel.id, - clock_channel.id if clock_channel else 0) - ) - self.client.config.guilds.set(guild.id, "timers", timers) - - return new_timer - - def destroy_timer(self, timer): - # Unsubscribe all members - for sub in timer.subscribed.values(): - asyncio.ensure_future(sub.unsub()) - - # Stop the timer - timer.stop() - - # Remove the timer from its channel - tchan = self.channels.get(timer.channel.id, None) - if tchan is not None: - tchan.timers.remove(timer) - # Cleanup if the channel has no remaining timers - if len(tchan.timers) == 0: - self.channels.pop(timer.channel.id) - - # Update the guild timer config - guild = timer.channel.guild - timers = self.client.config.guilds.get(guild.id, "timers") or [] - tup = next(tup for tup in timers if tup[0] == timer._truename and tup[1] == timer.role.id) - timers.remove(tup) - self.client.config.guilds.set(guild.id, "timers", timers) - - def get_timer_for(self, guildid, userid): - """ - Retrieve timer for the given member, or None. - """ - if (guildid, userid) in self.subscribers: - return self.subscribers[(guildid, userid)].timer - else: - return None - - def get_subs_for(self, userid): - """ - Retrieve all TimerSubscribers for the given userid. - """ - return [value for (key, value) in self.subscribers.items() if key[1] == userid] - - def get_channel_timers(self, channelid): - if channelid in self.channels: - return self.channels[channelid].timers - else: - return None - - def get_guild_timers(self, guildid): - if guildid in self.guild_channels: - return [timer for tchan in self.guild_channels[guildid] for timer in tchan.timers] - - async def wait_until_ready(self): - while not self.ready: - await asyncio.sleep(1) - - def bump_user(self, guildid, channelid, userid): - # Return if we are in a DM context - if guildid == 0: - return - - # Grab the subscriber if it exists - subber = self.subscribers.get((guildid, userid), None) - - # Bump the subscriber - if subber is not None and channelid == subber.timer.channel.id: - subber.bump() - - async def sub(self, ctx, member, timer): - log("Subscribing user {} (uid: {}) to timer {} (rid: {})".format(member.name, - member.id, - timer.name, - timer.role.id), - context="TIMER_INTERFACE") - - # Ensure that the user is not subscribed elsewhere - await self.unsub(member.guild.id, member.id) - - # Get the notify level - notify = ctx.client.config.users.get(member.id, "notify_level") - notify = NotifyLevel(notify) if notify is not None else NotifyLevel.WARNING - - # Create the subscriber - subber = TimerSubscriber(member, timer, self, notify=notify) - - # Attempt to add the sub role - try: - await member.add_roles(timer.role) - except discord.Forbidden: - await ctx.error_reply("Insufficient permissions to add the group role `{}`.".format(timer.role.name)) - except discord.NotFound: - await ctx.error_reply("Group role `{}` doesn't exist! This group is broken.".format(timer.role.id)) - - timer.subscribed[member.id] = subber - self.subscribers[(member.guild.id, member.id)] = subber - - async def unsub(self, guildid, userid): - """ - Unsubscribe a member from a timer, if they are subscribed. - Otherwise, do nothing. - Return the session data for ease of access. - """ - subber = self.subscribers.get((guildid, userid), None) - if subber is not None: - session = subber.session_data() - subber.active = False - - self.subscribers.pop((guildid, userid)) - subber.timer.subscribed.pop(userid) - - try: - await subber.member.remove_roles(subber.timer.role) - except Exception: - pass - - self.registry.new_session(*session) - return session - - @staticmethod - def parse_setupstr(setupstr): - stringy_stages = [stage.strip() for stage in setupstr.strip(';').split(';')] - - stages = [] - for stringy_stage in stringy_stages: - parts = [part.strip() for part in stringy_stage.split(",", maxsplit=2)] - - if len(parts) < 2 or not parts[1].isdigit(): - return None - stages.append(TimerStage(parts[0], int(parts[1]), message=parts[2] if len(parts) > 2 else "")) - - return stages diff --git a/bot/Timer/lib.py b/bot/Timer/lib.py new file mode 100644 index 0000000..d6376cc --- /dev/null +++ b/bot/Timer/lib.py @@ -0,0 +1,75 @@ +from enum import IntEnum + +from cmdClient.lib import SafeCancellation + +from meta import client +from data import tables +from utils.lib import timestamp_utcnow as now # noqa + + +join_emoji = '✅' +leave_emoji = '❌' + + +def parse_dur(diff, show_seconds=False): + """ + Parse a duration given in seconds to a time string. + """ + diff = max(diff, 0) + if show_seconds: + hours = diff // 3600 + minutes = (diff % 3600) // 60 + seconds = diff % 60 + return "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds) + else: + diff = int(60 * round(diff / 60)) + hours = diff // 3600 + minutes = (diff % 3600) // 60 + return "{:02d}:{:02d}".format(hours, minutes) + + +def best_prefix(guildid): + if not guildid: + return client.prefix + else: + return tables.guilds.fetch_or_create(guildid).prefix or client.prefix + + +class TimerState(IntEnum): + """ + Enum representing the current running state of the timer. + + Values + ------ + UNSET: The timer isn't set up. + STOPPED: The timer is stopped. + RUNNING: The timer is running. + PAUSED: The timer has been paused. + """ + + UNSET = 0 + STOPPED = 1 + RUNNING = 2 + PAUSED = 3 + + +class NotifyLevel(IntEnum): + """ + Enum representing a subscriber's notification level. + NONE: Never send direct messages. + FINAL: Send a direct message when kicking for inactivity. + WARNING: Send direct messages for unsubscription warnings. + ALL: Send direct messages for all stage updates. + """ + NEVER = 1 + FINAL = 2 + WARNING = 3 + ALL = 4 + + +class InvalidPattern(SafeCancellation): + """ + Exception raised when an invalid pattern format is encountered. + Stores user-readable information about the pattern error. + """ + pass diff --git a/bot/Timer/registry.py b/bot/Timer/registry.py deleted file mode 100644 index 70d79ea..0000000 --- a/bot/Timer/registry.py +++ /dev/null @@ -1,57 +0,0 @@ -import sqlite3 as sq - - -class TimerRegistry(object): - session_keys = ( - 'userid', - 'guildid', - 'roleid', - 'starttime', - 'duration' - ) - - def __init__(self, db_file): - self.conn = sq.connect(db_file, timeout=20) - self.conn.row_factory = sq.Row - - self.ensure_table() - - def ensure_table(self): - """ - Ensure the session table exists, otherwise create it. - """ - cursor = self.conn.cursor() - columns = ("userid INTEGER NOT NULL, " - "guildid INTEGER NOT NULL, " - "roleid INTEGER NOT NULL, " - "starttime INTEGER NOT NULL, " - "duration INTEGER NOT NULL") - - cursor.execute("CREATE TABLE IF NOT EXISTS sessions ({})".format(columns)) - self.conn.commit() - - def close(self): - self.conn.commit() - self.conn.close() - - def get_sessions_where(self, **kwargs): - keys = [(key, kwargs[key]) for key in kwargs if key in self.session_keys] - - if keys: - keystr = "WHERE " + " AND ".join("{} = ?".format(key) for key, val in keys) - else: - keystr = "" - - cursor = self.conn.cursor() - cursor.execute('SELECT * FROM sessions {}'.format(keystr), tuple(value for key, value in keys)) - return cursor.fetchall() - - def new_session(self, *args): - if len(args) != len(self.session_keys): - raise ValueError("Improper number of session keys passed for storage.") - - cursor = self.conn.cursor() - value_str = ", ".join('?' for key in args) - - cursor.execute('INSERT INTO sessions VALUES ({})'.format(value_str), tuple(args)) - self.conn.commit() diff --git a/bot/Timer/timer_reactions.py b/bot/Timer/timer_reactions.py new file mode 100644 index 0000000..8b74341 --- /dev/null +++ b/bot/Timer/timer_reactions.py @@ -0,0 +1,28 @@ +from meta import log + +from .core import join_emoji, leave_emoji + + +async def joinleave_tracker(client, payload): + if payload.user_id == client.user.id: + return + + if payload.guild_id and (payload.emoji.name in (join_emoji, leave_emoji)): + timers = client.interface.get_timers_in(payload.guild_id, payload.channel_id) + timer = next((timer for timer in timers if payload.message_id in timer.message_ids), None) + if timer: + userid = payload.user_id + guild = timer.channel.guild + if payload.emoji.name == join_emoji and not client.interface.get_subscriber(userid, guild.id): + member = guild.get_member(userid) or await guild.fetch_member(userid) + + # Subscribe member + log("Reaction subscribing {}(uid:{}) in {}(gid:{}) to {}(rid:{})".format( + member, userid, guild, payload.guild_id, timer.data.name, timer.data.roleid + ), context="TIMER_REACTIONS") + await timer.subscribe(member, post=True) + elif payload.emoji.name == leave_emoji and payload.user_id in timer: + log("Reaction unsubscribing (uid:{}) in {}(gid:{}) from {}(rid:{})".format( + userid, guild, payload.guild_id, timer.data.name, timer.data.roleid + ), context="TIMER_REACTIONS") + await timer.unsubscribe(payload.user_id, post=True) diff --git a/bot/Timer/trackers.py b/bot/Timer/trackers.py deleted file mode 100644 index 0bee405..0000000 --- a/bot/Timer/trackers.py +++ /dev/null @@ -1,8 +0,0 @@ -async def message_tracker(client, message): - client.interface.bump_user(message.guild or 0, message.channel.id, message.author.id) - - -async def reaction_tracker(client, payload): - client.interface.bump_user(payload.guild_id or 0, - payload.channel_id, - payload.user_id) diff --git a/bot/Timer/voice.py b/bot/Timer/voice.py deleted file mode 100644 index 23a4283..0000000 --- a/bot/Timer/voice.py +++ /dev/null @@ -1,57 +0,0 @@ -from cmdClient import Context - -from logger import log - -from .Timer import TimerState - - -async def sub_on_vcjoin(client, member, before, after): - """ - When a member joins a study group voice channel, automatically subscribe them to the study group. - """ - if before.channel is None and after.channel is not None: - # Join voice channel event - - # Quit if the member is a bot - if member.bot: - return - - # Quit if the member is already subscribed - if (member.guild.id, member.id) in client.interface.subscribers: - return - - guild_timers = client.interface.get_guild_timers(member.guild.id) - - # Quit if there are no groups in this guild - if not guild_timers: - return - - # Get the collection of clocks in the guild - guild_clocks = {timer.clock_channel.id: timer for timer in guild_timers if timer.clock_channel is not None} - - # Quit if the voice channel is not a clock channel, otherwise get the related timer - timer = guild_clocks.get(after.channel.id, None) - if timer is None: - return - - # Finally, subscribe the member to the timer - ctx = Context(client, channel=timer.channel, guild=timer.channel.guild, author=member) - log("Reaction-subscribing user {} (uid: {}) to timer {} (rid: {})".format(member.name, - member.id, - timer.name, - timer.role.id), - context="CLOCK_AUTOSUB") - await client.interface.sub(ctx, member, timer) - - # Send a welcome message - welcome = "Welcome to **{}**, {}!\n".format(timer.name, member.mention) - if timer.stages and timer.state == TimerState.RUNNING: - welcome += "Currently on stage **{}** with **{}** remaining. {}".format( - timer.stages[timer.current_stage].name, - timer.pretty_remaining(), - timer.stages[timer.current_stage].message - ) - elif timer.stages: - welcome += "Group timer is set up but not running." - - await ctx.ch.send(welcome) diff --git a/bot/Timer/voice_events.py b/bot/Timer/voice_events.py new file mode 100644 index 0000000..49f4676 --- /dev/null +++ b/bot/Timer/voice_events.py @@ -0,0 +1,45 @@ +import asyncio + +from meta import log + + +async def vc_update_handler(client, member, before, after): + if before.channel != after.channel: + if member.bot: + return + + voice_channels = { + timer.voice_channel.id: timer + for timer in client.interface.get_timers_in(member.guild.id) + if timer.voice_channel + } + left = voice_channels.get(before.channel.id, None) if before.channel else None + joined = voice_channels.get(after.channel.id, None) if after.channel else None + + leave = (left and member.id in left and left.settings.track_voice_join.value) + join = ( + joined and + joined.settings.track_voice_leave.value and + (leave or not client.interface.get_subscriber(member.id, member.guild.id)) + ) + existing_sub = None + if leave: + # TODO: Improve hysterisis, with locks maybe? + # TODO: Maybe add a personal ignore_voice setting + # Briefly wait to handle connection issues + if not after.channel: + await asyncio.sleep(5) + if member.voice and member.voice.channel: + return + log("Voice unsubscribing {}(uid:{}) in {}(gid:{}) from {}(rid:{})".format( + member, member.id, member.guild, member.guild.id, left.data.name, left.data.roleid + ), context="TIMER_REACTIONS") + existing_sub = await left.unsubscribe(member.id, post=not join) + + if join: + log("Voice subscribing {}(uid:{}) in {}(gid:{}) to {}(rid:{})".format( + member, member.id, member.guild, member.guild.id, joined.data.name, joined.data.roleid + ), context="TIMER_REACTIONS") + new_sub = await joined.subscribe(member, post=True) + if existing_sub: + new_sub.clocked_time = existing_sub.clocked_time diff --git a/bot/Timer/voice_notify.py b/bot/Timer/voice_notify.py new file mode 100644 index 0000000..e513e54 --- /dev/null +++ b/bot/Timer/voice_notify.py @@ -0,0 +1,37 @@ +import asyncio +import discord + + +voice_alert_path = "assets/sounds/slow-spring-board.wav" + +guild_locks = {} + + +async def play_alert(channel: discord.VoiceChannel): + if not channel.members: + # Don't notify an empty channel + return + + lock = guild_locks.get(channel.guild.id, None) + if not lock: + lock = guild_locks[channel.guild.id] = asyncio.Lock() + + async with lock: + vc = channel.guild.voice_client + if not vc: + vc = await channel.connect() + elif vc.channel != channel: + await vc.move_to(channel) + + audio_stream = open(voice_alert_path, 'rb') + try: + vc.play(discord.PCMAudio(audio_stream), after=lambda e: audio_stream.close()) + except discord.HTTPException: + pass + + count = 0 + while vc.is_playing() and count < 10: + await asyncio.sleep(0.5) + count += 1 + + await vc.disconnect() diff --git a/bot/cmdClient b/bot/cmdClient index b47fcb1..7f9a9e8 160000 --- a/bot/cmdClient +++ b/bot/cmdClient @@ -1 +1 @@ -Subproject commit b47fcb1c2b2a4c4bd785cec525633c5a5184bcab +Subproject commit 7f9a9e816d159eba609c74dbf31efe43d38d13f5 diff --git a/bot/commands/config.py b/bot/commands/config.py deleted file mode 100644 index 18157a8..0000000 --- a/bot/commands/config.py +++ /dev/null @@ -1,254 +0,0 @@ -import discord - -from cmdClient import cmd -from cmdClient.lib import ResponseTimedOut, UserCancelled -from cmdClient.checks import in_guild - -from wards import timer_admin, timer_ready -from utils import seekers, ctx_addons, timer_utils # noqa -# from Timer import create_timer - - -@cmd("newgroup", - group="Configuration", - desc="Create a new timer group.") -@in_guild() -@timer_ready() -@timer_admin() -async def cmd_addgrp(ctx): - """ - Usage``: - newgroup - newgroup - newgroup , , , - Description: - Creates a new group with the specified properties. - With no arguments or just `name` given, prompts for the remaining information. - Parameters:: - name: The name of the group to create. - role: The role given to people who join the group. - channel: The text channel which can access this group. - clock channel: The voice channel displaying the status of the group timer. - Related: - group, groups, delgroup - Examples``: - newgroup Espresso - newgroup Espresso, Study Group 1, #study-channel, #espresso-vc - """ - args = ctx.arg_str.split(",") - args = [arg.strip() for arg in args] - - if len(args) == 4: - name, role_str, channel_str, clockchannel_str = args - - # Find the specified objects - try: - role = await ctx.find_role(role_str.strip(), interactive=True) - channel = await ctx.find_channel(channel_str.strip(), interactive=True) - clockchannel = await ctx.find_channel(clockchannel_str.strip(), interactive=True) - except UserCancelled: - raise UserCancelled("User cancelled selection, no group was created.") from None - except ResponseTimedOut: - raise ResponseTimedOut("Selection timed out, no group was created.") from None - - # Create the timer - timer = ctx.client.interface.create_timer(name, role, channel, clockchannel) - elif len(args) >= 1 and args[0]: - timer = await newgroup_interactive(ctx, name=args[0]) - else: - timer = await newgroup_interactive(ctx) - - await ctx.reply("Group **{}** has been created and bound to channel {}.".format(timer.name, timer.channel.mention)) - - -async def newgroup_interactive(ctx, name=None, role=None, channel=None, clock_channel=None): - """ - Interactively create a new study group. - Takes keyword arguments to use any pre-existing data. - """ - try: - if name is None: - name = await ctx.input("Please enter a friendly name for the new study group:") - while role is None: - role_str = await ctx.input( - "Please enter the study group role.\n" - "This role is given to people who join the group, " - "and is used for notifications.\n" - "I must have permission to mention this role and give it to members. " - "Note that it must be below my highest role in the role list.\n" - "(Accepted input: Role name or partial name, role id, role mention, or `c` to cancel.)" - ) - if role_str.lower() == 'c': - raise UserCancelled - - role = await ctx.find_role(role_str.strip(), interactive=True) - - while channel is None: - channel_str = await ctx.input( - "Please enter the text channel to bind the group to.\n" - "The group will only be accessible from commands in this channel, " - "and the channel will host the pinned status message for this group.\n" - "I must have the `MANAGE_MESSAGES` permission in this channel to pin the status message.\n" - "(Accepted input: Channel name or partial name, channel id, channel mention, or `c` to cancel.)" - ) - if channel_str.lower() == 'c': - raise UserCancelled - - channel = await ctx.find_channel( - channel_str.strip(), - interactive=True, - chan_type=discord.ChannelType.text - ) - - while clock_channel is None: - clock_channel_str = await ctx.input( - "Please enter the group voice channel, or `s` to continue without an associated voice channel.\n" - "The name of this channel will be updated with the current stage and time remaining, " - "and members who join the channel will automatically be subscribed to the study group.\n" - "I must have the `MANAGE_CHANNEL` permission in this channel to update the name.\n" - "(Accepted input: Channel name or partial name, channel id, channel mention, " - "or `s` to skip or `c` to cancel.)" - ) - if clock_channel_str.lower() == 's': - break - - if clock_channel_str.lower() == 'c': - raise UserCancelled - - clock_channel = await ctx.find_channel( - clock_channel_str.strip(), - interactive=True, - chan_type=discord.ChannelType.voice - ) - except UserCancelled: - raise UserCancelled( - "User cancelled during group creation! " - "No group was created." - ) from None - except ResponseTimedOut: - raise ResponseTimedOut( - "Timed out waiting for a response during group creation! " - "No group was created." - ) from None - - # We now have all the data we need - return ctx.client.interface.create_timer(name, role, channel, clock_channel) - - -@cmd("delgroup", - group="Configuration", - desc="Remove a timer group.") -@in_guild() -@timer_ready() -@timer_admin() -async def cmd_delgrp(ctx): - """ - Usage``: - delgroup - Description: - Deletes the given group from the collection of timer groups in the current guild. - If `name` is not given or matches multiple groups, will prompt for group selection. - Parameters:: - name: The name of the group to delete. - Related: - group, groups, newgroup - Examples``: - delgroup Espresso - """ - try: - timer = await ctx.get_timers_matching(ctx.arg_str, channel_only=False) - except ResponseTimedOut: - raise ResponseTimedOut("Group selection timed out. No groups were deleted.") from None - except UserCancelled: - raise UserCancelled("User cancelled group selection. No groups were deleted.") from None - - if timer is None: - return await ctx.error_reply("No matching timers found!") - - # Delete the timer - ctx.client.interface.destroy_timer(timer) - - # Notify the user - await ctx.reply("The group `{}` has been removed!".format(timer.name)) - - -@cmd("adminrole", - group="Configuration", - desc="View or configure the timer admin role") -@in_guild() -async def cmd_adminrole(ctx): - """ - Usage``: - adminrole - adminrole - Description: - View the timer admin role (in the first usage), or set it to the provided role (in the second usage). - The timer admin role allows creation and deletion of group timers, - as well as modification of the guild registry and forcing timer operations. - - *Setting the timer admin role requires the guild permission `manage_guild`.* - Parameters:: - role: The name, partial name, or id of the new timer admin role. - """ - if ctx.arg_str: - if not ctx.author.guild_permissions.manage_guild: - return await ctx.error_reply("You need the `manage_guild` permission to change the timer admin role.") - - try: - role = await ctx.find_role(ctx.arg_str, interactive=True) - except UserCancelled: - raise UserCancelled("User cancelled role selection. Timer admin role unchanged.") from None - except ResponseTimedOut: - raise ResponseTimedOut("Role selection timed out. Timer admin role unchanged.") from None - - ctx.client.config.guilds.set(ctx.guild.id, "timeradmin", role.id) - - await ctx.embedreply("Timer admin role set to {}.".format(role.mention), color=discord.Colour.green()) - else: - roleid = ctx.client.config.guilds.get(ctx.guild.id, "timeradmin") - if roleid is None: - return await ctx.embedreply("No timer admin role set for this guild.") - role = ctx.guild.get_role(roleid) - if role is None: - await ctx.embedreply("Timer admin role set to a nonexistent role `{}`.".format(roleid)) - else: - await ctx.embedreply("Timer admin role is {}.".format(role.mention)) - - -@cmd("globalgroups", - group="Configuration", - desc="Configure whether groups are accessible away from their channel.") -@in_guild() -async def cmd_globalgroups(ctx): - """ - Usage``: - globalgroups [off | on] - Description: - Configure whether groups may only be joined from their associated channel. - This can, for instance, allow members to join a group before getting access to - the group study channel. - **Setting this option required the timer admin role, see `adminrole`.** - Options:: - on: Groups may be joined from any channel. - off: Groups may only be joined from the channel they are bound to. (**Default**) - Related: - newgroup, join, adminrole - """ - if ctx.arg_str: - if not (await timer_admin.run(ctx) or ctx.author.guild_permissions.manage_guild): - return await ctx.error_reply("You need the timeradmin role to configure this setting!") - - if ctx.arg_str.lower() == 'off': - ctx.client.config.guilds.set(ctx.guild.id, 'globalgroups', False) - await ctx.reply("Groups may now only be joined from their associated channel.") - elif ctx.arg_str.lower() == 'on': - ctx.client.config.guilds.set(ctx.guild.id, 'globalgroups', True) - await ctx.reply("Groups may now be joined from any guild channel.") - else: - await ctx.error_reply("Unrecognised option `{}`. See `help globalgroups` for usage.".format(ctx.arg_tr)) - else: - setting = ctx.client.config.guilds.get(ctx.guild.id, 'globalgroups') - if setting: - await ctx.reply("Groups may be joined from any guild channel.") - else: - await ctx.reply("Groups may only be joined from their associated channel.") diff --git a/bot/commands/existential_cmds.py b/bot/commands/existential_cmds.py new file mode 100644 index 0000000..0933a00 --- /dev/null +++ b/bot/commands/existential_cmds.py @@ -0,0 +1,255 @@ +import asyncio +import logging +import discord + +from cmdClient import cmd +from cmdClient.lib import ResponseTimedOut +from cmdClient.Context import Context + +from Timer import Pattern +from utils import ctx_addons, seekers, interactive # noqa +from wards import timer_admin, timer_ready + + +@cmd('newgroup', + group="Group Admin", + aliases=('newtimer', 'create', 'creategroup'), + short_help="Create a new study group.", + flags=('role==', 'channel==', 'voice==', 'pattern==')) +@timer_admin() +@timer_ready() +async def cmd_newgroup(ctx: Context, flags): + """ + Usage``: + {prefix}newgroup [flags] + Description: + Create a new study group (also called a timer) in the guild. + Flags:: + role: Role to give to members in the study group. + channel: Text channel where timer messages are posted. + voice: Voice channel associated with the group. + pattern: Default timer pattern used when resetting the group timer. + Examples``: + {prefix}newgroup AwesomeStudyGroup + {prefix}newgroup ExtraAwesomeStudyGroup --channel #studychannel --pattern 50/10 + """ + timers = ctx.timers.get_timers_in(ctx.guild.id) + + # Parse the group name + name = ctx.args + if name: + if len(name) > 30: + return await ctx.error_reply("The group name must be under 30 characters!") + else: + while not name or len(name) > 30: + try: + if not name: + name = await ctx.input("Please enter a name for the new study group:") + else: + name = await ctx.input("The group name must be under 30 characters! Please try again:") + except ResponseTimedOut: + raise ResponseTimedOut("Session timed out! No group created.") + + if any(name.lower() == timer.data.name.lower() for timer in timers): + return await ctx.error_reply("There is already a group with this name!") + name_line = "**Creating new study group `{}`.**".format(name) + + # Parse flags + role = None + if flags['role']: + role = await ctx.find_role(flags['role'], interactive=True) + if not role: + return + + channel = None + if flags['channel']: + channel = await ctx.find_channel(flags['channel'], interactive=True, chan_type=discord.ChannelType.text) + if not channel: + return + + voice = None + if flags['voice']: + voice = await ctx.find_channel(flags['voice'], interactive=True, chan_type=discord.ChannelType.voice) + if not voice: + return + + pattern = None + if flags['pattern']: + pattern = Pattern.from_userstr(flags['pattern']) + + # Extract parameters and report lines + me = ctx.guild.me + + role_line = "" + role_error_line = "" + role_created = False + guild_perms = me.guild_permissions + if not guild_perms.manage_roles: + role_error_line = "Lacking `MANAGE ROLES` guild permission." + elif role is not None: + role_line = "Using provided group role {}.".format(role.mention) + if role >= me.top_role: + role_error_line = "Provided role {} is higher than my top role.".format(role.mention) + elif any(role.id == timer.data.roleid for timer in timers): + role_error_line = "Provided role {} is already associated to a group!".format(role.mention) + else: + # Attempt to find existing role + role = next((role for role in ctx.guild.roles if role.name.lower() == name.lower()), None) + if role: + if role >= me.top_role: + role_line = "Found existing role {}, but it is higher than my top role. ".format(role.mention) + role = None + else: + role_line = "Using existing group role {}.".format(role.mention) + + if not role: + # Create a new role + role = await ctx.guild.create_role(name=name) + role_created = True + await asyncio.sleep(0.1) # Ensure the caches are populated + role_line += "Created the study group role {}.".format(role.mention) + role_line += " This role will automatically be given to members when they join the group." + + channel = channel or ctx.ch + channel_error_line = '' + chan_perms = channel.permissions_for(me) + if not chan_perms.read_messages: + channel_error_line = "Cannot read messages in {}.".format(channel.mention) + elif not chan_perms.send_messages: + channel_error_line = "Cannot send messages in {}.".format(channel.mention) + elif not chan_perms.read_message_history: + channel_error_line = "Cannot read message history in {}.".format(channel.mention) + elif not chan_perms.embed_links: + channel_error_line = "Cannot send embeds in {}.".format(channel.mention) + elif not chan_perms.manage_messages: + channel_error_line = "Cannot manage messages in {}.".format(channel.mention) + + voice_line = "" + voice_error_line = "" + if voice is None: + voice_line = ( + "To associate a voice channel (for voice alerts or to auto-join members) " + "use `{}tconfig \"{}\" voice_channel `." + ).format(ctx.best_prefix, name) + else: + other = next( + (timer for timer in ctx.timers.get_timers_in(ctx.guild.id) if timer.voice_channelid == voice.id), + None + ) + + voice_line = ( + "Group bound to provided voice channel {}.".format(voice.mention) + ) + voice_perms = voice.permissions_for(me) + if other is not None: + voice_error_line = "{} is already bound to the group **{}**.".format( + voice.mention, + other.name + ) + elif not voice_perms.connect: + voice_error_line = "Cannot connect to voice channel." + elif not voice_perms.speak: + voice_error_line = "Cannot speak in voice channel." + elif not voice_perms.view_channel: + voice_error_line = "Cannot see voice channel." + + pattern_line = ( + "The default timer pattern (applied when the timer is reset, e.g. by `{0}reset`) is `{1}`. " + ).format( + ctx.best_prefix, + (pattern if pattern is not None else Pattern.get(0)).display(brief=True), + ) + + lines = [name_line, role_line, pattern_line, voice_line] + errors = (role_error_line, channel_error_line, voice_error_line) + if any(errors): + # A permission error occured, report and exit + error_lines = '\n'.join( + '`{}`: {} {}'.format(cat, '❌' if error else '✅', '*{}*'.format(error) if error else '') + for cat, error in zip(('Group role', 'Text channel', 'Voice channel'), errors) + ) + + embed = discord.Embed( + title="Status", + description=error_lines, + colour=discord.Colour.red() + ) + lines.append("**Couldn't create the new group due to a permission or parameter error.**") + await ctx.reply( + content='\n'.join(lines), + embed=embed + ) + if role_created: + await role.delete() + else: + # Create the new group and report + timer = ctx.timers.create_timer( + role, channel, name, + voice_channelid=voice.id if voice else None, + patternid=pattern.row.patternid if pattern else 0 + ) + if not timer: + # This shouldn't happen, due to the permission check + ctx.client.log( + "Failed to create timer!", + context='mid:{}'.format(ctx.msg.id), + level=logging.ERROR + ) + lines.append("**An unknown error occured, please try again later.**") + return await ctx.reply('\n'.join(lines)) + # TODO: Initial usage tips + # Info about cloning? + lines[0] = "**Created the study group {} in {}.**".format( + '`{}`'.format(name) if name != role.name else role.mention, + channel.mention + ) + lines[1] = "The role {} will be automatically given to members when they join the group.".format(role.mention) + lines.append("*For more advanced configuration options see `{}tconfig \"{}\"`.*".format(ctx.best_prefix, name)) + tips = ( + "• Join the new group using `{prefix}join` in {channel}{voice_msg}.\n" + "• Then start the group timer with `{prefix}start`.\n" + "• To change the pattern of work/break times instead use `{prefix}start `.\n" + " (E.g. `{prefix}start 50/10` for `50` minutes work and `10` minutes break.)\n\n" + "For more information, see `{prefix}help` for the command list and introductory guides, " + "and use `{prefix}help cmd` to get detailed help with a particular command." + ).format( + prefix=ctx.best_prefix, + channel=channel.mention, + voice_msg=", or by joining the {} voice channel.".format(voice.mention) if voice else '' + ) + await ctx.reply( + '\n'.join(lines), + embed=discord.Embed(title="Usage Tips", description=tips), + allowed_mentions=discord.AllowedMentions.none() + ) + + +@cmd('delgroup', + group="Group Admin", + aliases=('rmgroup',), + short_help="Delete a study group.") +@timer_admin() +@timer_ready() +async def cmd_delgroup(ctx): + """ + Usage``: + {prefix}delgroup + Description: + Delete a guild study group. + Examples``: + {prefix}delgroup {ctx.example_group_name} + """ + groups = ctx.timers.get_timers_in(ctx.guild.id) + group = next((group for group in groups if group.data.name.lower() == ctx.args.lower()), None) + + if group is None: + await ctx.error_reply("No group found with the name `{}`.".format(ctx.args)) + else: + if await ctx.ask("Are you sure you want to delete the group `{}`?".format(group.data.name)): + await ctx.timers.obliterate_timer(group) + await ctx.reply("Deleted the group `{}`.".format(group.data.name)) + if await ctx.ask("Do you also want to delete the group discord role **{}**?".format(group.role.name)): + try: + await group.role.delete() + except discord.HTTPException: + await ctx.reply("Failed to delete the associated role!") diff --git a/bot/commands/guides.py b/bot/commands/guides.py new file mode 100644 index 0000000..3311334 --- /dev/null +++ b/bot/commands/guides.py @@ -0,0 +1,76 @@ +import discord + +from cmdClient import cmd + + +def guide(name, **kwargs): + def wrapped(func): + # Create command + command = cmd(name, group="Guides", **kwargs)(func) + command.smart_help = func + return command + return wrapped + + +@guide("patterns", + short_help="How to change the timer work/break patterns.") +async def guide_patterns(ctx): + pattern_gif = discord.File('assets/guide-gifs/pattern-guide.gif') + embed = discord.Embed( + title="Guide to changing your timer pattern", + description=""" + A *timer pattern* is the sequence of stages the timer follows,\ + for example *50 minutes Work* followed by *10 minutes Break*. + Each timer's pattern is easily customisable, and patterns may be saved for simpler timer setup. + + The pattern is usually given as the stage durations separated by `/`.\ + For example, `50/10` represents 50 minutes work followed by 10 minutes break.\ + See the extended format below for finer control over the pattern stages. + + To modify a timer's pattern, use the `start` or `setup` commands.\ + For example, use `{prefix}start 50/10` to start your timer with a `50/10` pattern.\ + `setup` will stop the timer and change the pattern, while `start` will also restart the timer. + + Patterns always repeat forever, so in the above example, \ + after the break is finished the 50 minute work stage will start again. + + *See the gif below and `,phelp start` for more pattern usage examples.* + """.format(prefix=ctx.best_prefix) + ).add_field( + name="Extended Format", + value=""" + The *stage names* and *stage messages* of a pattern may be customised using the *extended pattern format*. + Stages are separated by `;` instead of `/`, and each stage has the form `name, duration, message`, \ + with the `message` being optional.\ + A `*` may be added after the duration to mark a stage as a "work" stage (visible in the study time summaries). + For example a custom `50/10` pattern could be given as \ + ```Study🔥, 50*, Good luck!; Break🌝, 10, Have a rest.``` + """, + inline=False + ).add_field( + name="Saving patterns", + value=""" + Patterns may also be *saved* and given names using the `savepattern` command. \ + Simply type `{prefix}savepattern pattern` (replacing `pattern` with the desired pattern), \ + and enter the name when prompted. \ + The saved pattern name may then be used wherever a pattern is required, \ + including in the `start` and `setup` commands. + """.format(prefix=ctx.best_prefix), + inline=False + ).set_image( + url='attachment://pattern-guide.gif' + ) + + await ctx.reply(embed=embed, file=pattern_gif) + + +@guide("settingup", + short_help="Setting up {ctx.client.user.name} in your server.") +async def guide_setting_up(ctx): + await ctx.reply("Setup guide coming soon!") + + +@guide("gettingstarted", + short_help="Getting started with using {ctx.client.user.name}.") +async def guide_getting_started(ctx): + await ctx.reply("Getting started guide coming soon!") diff --git a/bot/commands/guild_config.py b/bot/commands/guild_config.py new file mode 100644 index 0000000..5b1c6af --- /dev/null +++ b/bot/commands/guild_config.py @@ -0,0 +1,168 @@ +from cmdClient.checks import in_guild + +from settings import GuildSettings + +from Timer import module + + +@module.cmd( + "globalgroups", + group="Server Configuration", + short_help=("Whether groups may be joined outside their channel. " + "(`{ctx.guild_settings.globalgroups.formatted}`)") +) +@in_guild() +async def cmd_globalgroups(ctx): + """ + Usage``: + {prefix}globalgroups + {prefix}globalgroups on | off + Setting Description: + {ctx.guild_settings.settings.globalgroups.long_desc} + """ + await GuildSettings.settings.globalgroups.command(ctx, ctx.guild.id) + + +@module.cmd( + "prefix", + group="Server Configuration", + short_help=("The server command prefix. " + "(Currently `{ctx.guild_settings.prefix.formatted}`)") +) +@in_guild() +async def cmd_prefix(ctx): + """ + Usage``: + {prefix}prefix + {prefix}prefix + Setting Description: + {ctx.guild_settings.settings.prefix.long_desc} + """ + await GuildSettings.settings.prefix.command(ctx, ctx.guild.id) + + +@module.cmd( + "timeradmin", + group="Server Configuration", + short_help=("The role required for timer admin actions. " + "({ctx.guild_settings.timer_admin_role.formatted})") +) +@in_guild() +async def cmd_timeradmin(ctx): + """ + Usage``: + {prefix}timeradmin + {prefix}timeradmin + Setting Description: + {ctx.guild_settings.settings.timer_admin_role.long_desc} + Accepted Values: + Roles maybe given as their name, id, or partial name. + + *Modifying the `timeradmin` role requires the `administrator` server permission.* + """ + await GuildSettings.settings.timer_admin_role.command(ctx, ctx.guild.id) + + +@module.cmd( + "timezone", + group="Server Configuration", + short_help=("The server leaderboard timezone. " + "({ctx.guild_settings.timezone.formatted})") +) +@in_guild() +async def cmd_timezone(ctx): + """ + Usage``: + {prefix}timezone + {prefix}timezone + Setting Description: + {ctx.guild_settings.settings.timezone.long_desc} + Accepted Values: + Timezone names must be from the "TZ Database Name" column of \ + [this list](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). + For example, `Europe/London`, `Australia/Melbourne`, or `America/New_York`. + """ + await GuildSettings.settings.timezone.command(ctx, ctx.guild.id) + + +@module.cmd( + "studyrole", + group="Server Configuration", + short_help=("The global study role. " + "({ctx.guild_settings.studyrole.formatted})") +) +@in_guild() +async def cmd_studyrole(ctx): + """ + Usage``: + {prefix}studyrole + {prefix}studyrole + Setting Description: + {ctx.guild_settings.settings.studyrole.long_desc} + Accepted Values: + Roles maybe given as their name, id, or partial name. + """ + await GuildSettings.settings.studyrole.command(ctx, ctx.guild.id) + + +""" +@module.cmd("config", + group="Server Configuration", + short_help="View and modify server configuration.") +@in_guild() +async def cmd_config(ctx): + # Cache and map some info for faster access + setting_displaynames = {setting.display_name.lower(): setting for setting in GuildSettings.settings.values()} + + if not ctx.args or ctx.args.lower() == 'help': + # Display the current configuration, with either values or descriptions + props = { + setting.display_name: setting.get(ctx.guild.id).formatted if not ctx.args else setting.desc + for setting in GuildSettings.settings.values() + } + table = prop_tabulate(*zip(*props.items())) + embed = discord.Embed( + description="{table}\n\nUse `{prefix}config ` to view more information.".format( + prefix=ctx.best_prefix, + table=table + ), + title="Server settings" + ) + await ctx.reply(embed=embed) + else: + # Some args were given + parts = ctx.args.split(maxsplit=1) + + name = parts[0] + setting = setting_displaynames.get(name.lower(), None) + if setting is None: + return await ctx.error_reply( + "Server setting `{}` doesn't exist! Use `{}config` to see all server settings".format( + name, ctx.best_prefix + ) + ) + + if len(parts) == 1: + # config + # View config embed for provided setting + await ctx.reply(embed=setting.get(ctx.guild.id).embed) + else: + # config + # Check the write ward + if not await setting.write_ward.run(ctx): + await ctx.error_reply(setting.msg) + + # Attempt to set config setting + try: + (await setting.parse(ctx.guild.id, ctx, parts[1])).write() + except UserInputError as e: + await ctx.reply(embed=discord.Embed( + description="{} {}".format('❌', e.msg), + Colour=discord.Colour.red() + )) + else: + await ctx.reply(embed=discord.Embed( + description="{} Setting updated!".format('✅'), + Colour=discord.Colour.green() + )) +""" diff --git a/bot/commands/help.py b/bot/commands/help.py index ef80796..d09091b 100644 --- a/bot/commands/help.py +++ b/bot/commands/help.py @@ -4,9 +4,53 @@ from utils.lib import prop_tabulate from utils import interactive # noqa +from utils.timer_utils import is_timer_admin # Set the command groups to appear in the help +group_hints = { + 'Guides': "*Short general usage guides for different aspects of PomoBot.*", + 'Timer Usage': "*View and join the server groups.*", + 'Timer Control': "*Setup and control the group timers. May need timer admin permissions!*", + 'Personal Settings': "*Control how I interact with you.*", + 'Registry': "*Server leaderboard and personal study statistics.*", + 'Registry Admin': "*View and modify server session data.*", + 'Saved Patterns': "*Name custom timer patterns for faster setup.*", + 'Group Admin': "*Create, delete, and configure study groups.*", + 'Server Configuration': "*Control how I behave in your server.*", + 'Meta': "*Information about me!*" +} +standard_group_order = ( + ('Timer Usage', 'Timer Control', 'Personal Settings'), + ('Registry', 'Saved Patterns'), + ('Meta', 'Guides'), +) +admin_group_order = ( + ('Group Admin', 'Server Configuration', 'Meta', 'Guides'), + ('Timer Usage', 'Timer Control', 'Personal Settings'), + ('Registry', 'Registry Admin', 'Saved Patterns'), +) + +# Help embed format +title = "PomoBot Usage Manual and Command List" +header = """ +Flexible study group system with Pomodoro-style timers! +Supports multiple groups and custom timer patterns. +Join the [support server](https://discord.gg/MnMrQDe) \ +or make an issue on the [repository](https://github.com/Intery/PomoBot) if you have any \ +questions or issues. + +For more detailed information about each command use `{ctx.best_prefix}help `. +(For example, see `{ctx.best_prefix}help newgroup` and `{ctx.best_prefix}help start`.) +""" + +# Possible tips +tips = { + 'no_groups': "Get started by creating your first group with `{ctx.best_prefix}newgroup`!", + 'non_admin': "Use `{ctx.best_prefix}groups` to see the groups, and `{ctx.best_prefix}join` to join a group!", + 'admin': "Tweak timer behaviour with `{ctx.best_prefix}tconfig`." +} + help_groups = [ ("Timer", "*View and interact with the guild group timers.*"), ("Registry", "*Timer leaderboard and session history.*"), @@ -26,27 +70,33 @@ @cmd("help", - desc="Display information about commands.") + group="Meta", + short_help="Usage manual and command list.") async def cmd_help(ctx): """ - Usage: - help [cmdname] + Usage``: + {prefix}help [cmdname] Description: When used with no arguments, displays a list of commands with brief descriptions. Otherwise, shows documentation for the provided command. Examples: - help - help help + {prefix}help + {prefix}help join + {prefix}help newgroup """ if ctx.arg_str: # Attempt to fetch the command - command = ctx.client.cmd_cache.get(ctx.arg_str.strip(), None) + command = ctx.client.cmd_names.get(ctx.arg_str.strip(), None) if command is None: return await ctx.error_reply( ("Command `{}` not found!\n" - "Use the `help` command without arguments to see a list of commands.").format(ctx.arg_str) + "Write `{}help` to see a list of commands.").format(ctx.args, ctx.best_prefix) ) + smart_help = getattr(command, 'smart_help', None) + if smart_help is not None: + return await smart_help(ctx) + help_fields = command.long_help.copy() help_map = {field_name: i for i, (field_name, _) in enumerate(help_fields)} @@ -79,16 +129,15 @@ async def cmd_help(ctx): # Handle the related field names = [cmd_name.strip() for cmd_name in help_fields[pos][1].split(',')] names.sort(key=len) - values = [getattr(ctx.client.cmd_cache.get(cmd_name, None), 'desc', "") for cmd_name in names] + values = [ + (getattr(ctx.client.cmd_names.get(cmd_name, None), 'short_help', '') or '').format(ctx=ctx) + for cmd_name in names + ] help_fields[pos] = ( name, prop_tabulate(names, values) ) - usage_index = help_map.get("Usage", None) - if usage_index is not None: - help_fields[usage_index] = ("Usage", "`{}`".format('`\n`'.join(help_fields[usage_index][1].splitlines()))) - aliases = getattr(command, 'aliases', []) alias_str = "(Aliases `{}`.)".format("`, `".join(aliases)) if aliases else "" @@ -98,7 +147,16 @@ async def cmd_help(ctx): colour=discord.Colour(0x9b59b6) ) for fieldname, fieldvalue in help_fields: - embed.add_field(name=fieldname, value=fieldvalue, inline=False) + embed.add_field( + name=fieldname, + value=fieldvalue.format(ctx=ctx, prefix=ctx.best_prefix), + inline=False + ) + embed.add_field( + name="Still need help?", + value="Join our [support server](https://discord.gg/MnMrQDe)!" + ) + # TODO: Link to online docs embed.set_footer(text="[optional] and denote optional and required arguments, respectively.") @@ -115,7 +173,7 @@ async def cmd_help(ctx): cmd_groups[group] = cmd_group # Add the command name and description to the group - cmd_group.append((command.name, getattr(command, 'desc', ""))) + cmd_group.append((command.name, getattr(command, 'short_help', ''))) # Turn the command groups into strings stringy_cmd_groups = {} @@ -124,27 +182,33 @@ async def cmd_help(ctx): stringy_cmd_groups[group_name] = prop_tabulate(*zip(*cmd_group)) # Now put everything into a bunch of embeds - help_embeds = [] - active_fields = [] - for group_name, group_desc in help_groups: - group_str = stringy_cmd_groups.get(group_name, None) - if group_str is None: - continue - - active_fields.append((group_name, group_desc + '\n' + group_str)) - - if group_name == help_groups[-1][0] or sum([len(field.splitlines()) for _, field in active_fields]) > 10: - # Roll a new embed - embed = discord.Embed(description=help_str, colour=discord.Colour(0x9b59b6), title=help_title) - - # Add the active fields - for name, field in active_fields: - embed.add_field(name=name, value=field, inline=False) + if ctx.guild and await is_timer_admin(ctx.author): + group_order = admin_group_order + tip = tips['admin'] + else: + group_order = standard_group_order + tip = tips['non_admin'] - help_embeds.append(embed) + if ctx.guild and not ctx.timers.get_timers_in(ctx.guild.id): + tip = tips['no_groups'] - # Clear the active fields - active_fields = [] + help_embeds = [] + for page_groups in group_order: + embed = discord.Embed( + description=header.format(ctx=ctx), + colour=discord.Colour(0x9b59b6), + title=title + ) + for group in page_groups: + group_hint = group_hints.get(group, '').format(ctx=ctx) + group_str = stringy_cmd_groups.get(group, None) + if group_str: + embed.add_field( + name=group, + value="{}\n{}".format(group_hint, group_str).format(ctx=ctx), + inline=False + ) + help_embeds.append(embed) # Add the page numbers for i, embed in enumerate(help_embeds): @@ -152,6 +216,6 @@ async def cmd_help(ctx): # Send the embeds if help_embeds: - await ctx.pager(help_embeds) + await ctx.pager(help_embeds, content="**Tip:** {}".format(tip.format(ctx=ctx))) else: await ctx.reply(embed=discord.Embed(description=help_str, colour=discord.Colour(0x9b59b6))) diff --git a/bot/commands/meta.py b/bot/commands/meta.py new file mode 100644 index 0000000..afeb083 --- /dev/null +++ b/bot/commands/meta.py @@ -0,0 +1,87 @@ +import discord +from cmdClient import cmd + +from data import tables +from utils.lib import prop_tabulate + + +@cmd("about", + group="Meta", + short_help="Display some general information about me.") +async def cmd_about(ctx): + """ + Usage``: + {prefix}about + Description: + Replies with some general information about me. + """ + # Gather usage statistics + guild_row = tables.guilds.select_one_where(select_columns=('COUNT()',)) + guild_count = guild_row[0] if guild_row else 0 + + timer_row = tables.timers.select_one_where(select_columns=('COUNT()',)) + timer_count = timer_row[0] if timer_row else 0 + + session_row = tables.sessions.select_one_where(select_columns=('COUNT()', 'SUM(duration)')) + session_count = session_row[0] if session_row else 0 + session_time = session_row[1] // 3600 if session_row else 0 + + stats = { + 'Guilds': str(guild_count), + 'Timers': str(timer_count), + 'Recorded': "`{}` hours over `{}` sessions".format(session_time, session_count) + } + stats_str = prop_tabulate(*zip(*stats.items())) + + # Define links + links = { + 'Support server': "https://discord.gg/MnMrQDe", + 'Invite me!': ("https://discordapp.com/oauth2/authorize" + "?client_id=674238793431384067&scope=bot&permissions=271608912"), + 'Github page': "https://github.com/Intery/PomoBot" + } + link_str = ', '.join("[{}]({})".format(name, link) for name, link in links.items()) + + # Create embed + desc = ( + "Flexible study or work group timer using a customisable Pomodoro system.\n" + "Supports multiple groups and different timer setups.\n" + "{stats}\n\n" + "{links}" + ).format(stats=stats_str, links=link_str) + embed = discord.Embed( + description=desc, + colour=discord.Colour(0x9b59b6), + title='About Me' + ) + + # Finally send! + await ctx.reply(embed=embed) + + +@cmd("support", + group="Meta", + short_help="Sends my support server invite link.") +async def cmd_support(ctx): + """ + Usage``: + {prefix}support + Description: + Replies with the support server link. + """ + await ctx.reply("Chat with our friendly support team here: https://discord.gg/MnMrQDe") + + +@cmd("invite", + group="Meta", + short_help="Invite me.") +async def cmd_invite(ctx): + """ + Usage``: + {prefix}invite + Description: + Replies with the bot invite link. + """ + await ctx.reply("Invite PomoBot to your server with this link: {}".format( + "" + )) diff --git a/bot/commands/presets.py b/bot/commands/presets.py index 960515a..e3f7e2e 100644 --- a/bot/commands/presets.py +++ b/bot/commands/presets.py @@ -1,154 +1,294 @@ -from cmdClient import cmd +import datetime +import discord from cmdClient.checks import in_guild -from Timer import TimerInterface +from Timer import Pattern, module -from wards import timer_admin +from data import tables from utils import timer_utils, interactive, ctx_addons # noqa -from utils.lib import paginate_list +from utils.lib import prop_tabulate, paginate_list +from utils.timer_utils import is_timer_admin -def get_presets(ctx): +def _fetch_presets(ctx): """ - Get the valid setup string presets in the current context. + Fetch the current valid presets in this context. + Returns a list of the form (preset_type, preset_row). + Accounts for user pattern name overrides. """ - presets = {} - if ctx.guild: - presets.update(ctx.client.config.guilds.get(ctx.guild.id, "timer_presets") or {}) # Guild presets - presets.update(ctx.client.config.users.get(ctx.author.id, "timer_presets") or {}) # Personal presets + user_rows = tables.user_presets.select_where(userid=ctx.author.id) + guild_rows = tables.guild_presets.select_where(guildid=ctx.guild.id) - return presets + presets = {} + presets.update( + {row['preset_name'].lower(): (0, row) for row in guild_rows} + ) + presets.update( + {row['preset_name'].lower(): (1, row) for row in user_rows} + ) + return list(reversed(list(presets.values()))) -def preset_summary(setupstr): +def _format_preset(preset_type, preset_row): """ - Return a summary string of stage durations for the given setup string. + Format the available patterns into a pretty-viewable list. + Returns a list of tuples `(pattern_str, preset_type, pattern)`. """ - # First compile the preset - stages = TimerInterface.parse_setupstr(setupstr) - return "/".join(str(stage.duration) for stage in stages) + pattern = Pattern.get(preset_row['patternid']) + return "{} ({}, {})".format( + preset_row['preset_name'], + 'Personal' if preset_type == 1 else 'Server', + pattern.display(brief=True) + ) -@cmd("preset", - group="Timer", - desc="Create, view, and remove personal or guild setup string presets.", - aliases=["addpreset", "presets", "rmpreset"]) -async def cmd_preset(ctx): +@module.cmd("savepattern", + group="Saved Patterns", + short_help="Name a given pattern.") +@in_guild() +async def cmd_savepattern(ctx): """ Usage``: - presets - preset [presetname] - addpreset [presetname] - rmpreset + {prefix}savepattern Description: - Create, view, and remove personal or guild setup string presets. - See the `setup` command documentation for more information about setup string format. - - Note that the `Timer Admin` role is required to create or remove guild presets. - Forms:: - preset: Display information about the specified preset. - presets: List available personal and guild presets. - addpreset: Create a new preset. Prompts for name if not provided. - rmpreset: Remove the specified preset. - Related: - setup + See `{prefix}help patterns` for more information about timer patterns. + Examples``: + {prefix}savepattern 50/10 + {prefix}savepattern Work, 50, Good luck!; Break, 10, Have a rest. """ - presets = get_presets(ctx) - preset_list = list(presets.items()) - pretty_presets = [ - "{}\t ({})".format(name, preset_summary(preset)) - for name, preset in preset_list - ] - - if ctx.alias.lower() == "presets": - # Handle having no presets - if not pretty_presets: - return await ctx.embedreply("No presets available! Start creating presets with `addpreset`") - - # Format and return the list - pages = paginate_list(pretty_presets, title="Available Timer Presets") - return await ctx.pager(pages) - elif ctx.alias.lower() == "preset": - # Prompt for the preset if not given - if not ctx.arg_str: - preset = preset_list[await ctx.selector("Please select a preset.", pretty_presets)] - elif ctx.arg_str not in presets: - return await ctx.error_reply("Unrecognised preset `{}`.\n" - "Use `presets` to view the available presets.".format(ctx.arg_str)) + if ctx.args: + pattern = Pattern.from_userstr(ctx.args) + else: + # Get the current timer pattern, if applicable + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is not None: + pattern = sub.timer.current_pattern or sub.timer.default_pattern else: - preset = (ctx.arg_str, presets[ctx.arg_str]) - - # Build preset info - preset_info = "Preset `{}` with stages `{}`.\n```{}```".format(preset[0], preset_summary(preset[1]), preset[1]) - - # Output info - await ctx.reply(preset_info) - elif ctx.alias.lower() == "addpreset": - # Start by prompting for a name if none was given - name = ctx.arg_str or await ctx.input("Please enter a name for the new timer preset.") - - # Ragequit on names with bad characters - if "," in name or ";" in name: - return await ctx.error_reply("Preset names must not contain `,` or `;`.") - - # Prompt for the setup string - stages = None - while stages is None: - setupstr = await ctx.input( - "Please enter the timer preset setup string." + # Request pattern + pattern = Pattern.from_userstr(await ctx.input( + "Please enter the timer pattern you want to save.\n" + "**Tip**: See `{}help patterns` for more information " + "about creating or using timer patterns".format(ctx.best_prefix) + )) + + # Confirm and request name + name = await ctx.input( + "Please enter a name for this pattern." + "```{}```".format(pattern.display()) + ) + if not name: + return + + # Ask for preset type + pattern_type = 0 + if await is_timer_admin(ctx.author): + options = ( + "User Pattern (the saved pattern is available to you across all servers).", + "Server Pattern (the saved pattern is available to everyone in this server)." + ) + pattern_type = await ctx.selector("Would you like to create a User or Server pattern?", options) + + # Save preset + if pattern_type == 0: + # User preset + tables.user_presets.insert(userid=ctx.author.id, preset_name=name, patternid=pattern.row.patternid) + await ctx.reply( + "Saved the new user pattern `{name}`. " + "Apply it by joining any study group and writing `{prefix}start {name}`.".format( + prefix=ctx.best_prefix, + name=name ) - # Handle cancellation - if setupstr == "c": - return await ctx.embedreply("Preset creation cancelled by user.") - - # Parse setup string to ensure validity - stages = TimerInterface.parse_setupstr(setupstr) - if stages is None: - await ctx.error_reply("Setup string not understood.") - - # Prompt for whether to add a guild or personal preset - preset_type = 1 # 0 is Guild preset and 1 is personal preset - if await in_guild.run(ctx) and await timer_admin.run(ctx): - preset_type = await ctx.selector( - "What type of preset would you like to create?", - ["Guild preset (available to everyone in the guild)", - "Personal preset (only available to yourself)"] + ) + else: + # Guild preset + tables.guild_presets.insert( + guildid=ctx.guild.id, + preset_name=name, + created_by=ctx.author.id, + patternid=pattern.row.patternid + ) + await ctx.reply( + "Saved the new guild pattern `{name}`. " + "Any member may now apply it by joining a study group and writing `{prefix}start {name}`.".format( + prefix=ctx.best_prefix, + name=name ) + ) + + +@module.cmd("delpattern", + group="Saved Patterns", + short_help="Delete a saved pattern by name.", + aliases=('rmpattern',)) +@in_guild() +async def cmd_delpattern(ctx): + """ + Usage``: + {prefix}delpattern + Description: + Delete the given saved pattern. + """ + is_admin = await is_timer_admin(ctx.author) + is_user_preset = True + + if not ctx.args: + # Prompt for a saved pattern to remove + presets = _fetch_presets(ctx) + if not presets: + return await ctx.reply( + "No saved patterns exist yet! " + "See `{}help savepattern` for information about saving a pattern.".format( + ctx.best_prefix + ) + ) + if not is_admin: + presets = [preset for preset in presets if preset[0] == 1] + + ids = [row['patternid'] for _, row in presets] + tables.patterns.fetch_rows_where(patternid=ids) + + pretty_presets = [_format_preset(t, row) for t, row in presets] + result = await ctx.selector( + "Please select a saved pattern to remove.", + pretty_presets + ) + is_user_preset, row = presets[result] + else: + row = tables.user_presets.select_one_where(userid=ctx.author.id, preset_name=ctx.args) + if not row: + is_user_preset = False + row = tables.guild_presets.select_one_where(guildid=ctx.guild.id, preset_name=ctx.args) + if not row: + return await ctx.error_reply( + "No saved pattern found called `{}`.".format(ctx.args) + ) + + if not is_user_preset: + if is_admin: + tables.guild_presets.delete_where(guildid=ctx.guild.id, preset_name=ctx.args) + await ctx.reply("Removed saved server pattern `{}`.".format(row['preset_name'])) else: - # Non-admins don't get an option - preset_type = 1 - - # Create the preset - if preset_type == 0: - guild_presets = ctx.client.config.guilds.get(ctx.guild.id, "timer_presets") or {} - if name in guild_presets and not await ctx.ask("Preset `{}` already exists, overwrite?".format(name)): - return - guild_presets[name] = setupstr - ctx.client.config.guilds.set(ctx.guild.id, "timer_presets", guild_presets) - await ctx.embedreply("Guild preset `{}` created.".format(name)) - elif preset_type == 1: - personal_presets = ctx.client.config.users.get(ctx.author.id, "timer_presets") or {} - if name in personal_presets and not await ctx.ask("Preset `{}` already exists, overwrite?".format(name)): - return - personal_presets[name] = setupstr - ctx.client.config.users.set(ctx.author.id, "timer_presets", personal_presets) - await ctx.embedreply("Personal preset `{}` created.".format(name)) - elif ctx.alias.lower() == "rmpreset": - # Handle trying to remove nonexistent preset - if not ctx.arg_str: - return await ctx.error_reply("Please provide a preset to remove.") - if ctx.arg_str not in presets: - return await ctx.error_reply("Unrecognised preset `{}`.".format(ctx.arg_str)) - - personal_presets = ctx.client.config.users.get(ctx.author.id, "timer_presets") or {} - if ctx.arg_str in personal_presets: - personal_presets.pop(ctx.arg_str) - ctx.client.config.users.set(ctx.author.id, "timer_presets", personal_presets) - else: - if not await timer_admin.run(ctx): - return await ctx.error_reply("You need to be a timer admin to remove guild presets.") - guild_presets = ctx.client.config.guilds.get(ctx.guild.id, "timer_presets") or {} - guild_presets.pop(ctx.arg_str) - ctx.client.config.guilds.set(ctx.guild.id, "timer_presets", guild_presets) + await ctx.error_reply("You need timer admin permissions to remove a saved server pattern!") + else: + tables.user_presets.delete_where(userid=ctx.author.id, preset_name=ctx.args) + await ctx.reply("Removed saved personal pattern `{}`.".format(row['preset_name'])) + + +@module.cmd("savedpatterns", + group="Saved Patterns", + short_help="View the accessible saved patterns.", + aliases=('presets', 'patterns', 'showpatterns')) +@in_guild() +async def cmd_savedpatterns(ctx): + """ + Usage``: + {prefix}savedpatterns + Description: + List the personal and server-wide saved patterns accessible for custom timer setup. + """ + presets = _fetch_presets(ctx) + if not presets: + return await ctx.reply( + "No saved patterns exist yet! See `{}help savepattern` for information about saving a pattern.".format( + ctx.best_prefix + ) + ) + + ids = [row['patternid'] for _, row in presets] + tables.patterns.fetch_rows_where(patternid=ids) + + pretty_presets = [_format_preset(t, row) for t, row in presets] + + await ctx.pager( + paginate_list(pretty_presets, title="Saved Patterns") + ) + + +@module.cmd("showpattern", + group="Saved Patterns", + short_help="View details about a saved pattern.") +@in_guild() +async def cmd_showpattern(ctx): + """ + Usage``: + {prefix}showpattern + Description: + Show details about the provided saved pattern. + """ + is_user_preset = True + if not ctx.args: + # Prompt for a saved pattern to display + presets = _fetch_presets(ctx) + if not presets: + return await ctx.reply( + "No saved patterns exist yet! See `{}help savepattern` for information about saving a pattern.".format( + ctx.best_prefix + ) + ) + + ids = [row['patternid'] for _, row in presets] + tables.patterns.fetch_rows_where(patternid=ids) + + pretty_presets = [_format_preset(t, row) for t, row in presets] + result = await ctx.selector( + "Please select a saved pattern to view.", + pretty_presets + ) + is_user_preset, row = presets[result] + else: + row = tables.user_presets.select_one_where(userid=ctx.author.id, preset_name=ctx.args) + if not row: + is_user_preset = False + row = tables.guild_presets.select_one_where(guildid=ctx.guild.id, preset_name=ctx.args) + + if not row: + return await ctx.error_reply( + "No saved pattern found called `{}`.".format(ctx.args) + ) + + # Extract pattern information + pid = row['patternid'] + pattern = Pattern.get(pid) + + if is_user_preset: + session_data = tables.sessions.select_one_where( + select_columns=('SUM(duration)', ), + patternid=pid, + userid=ctx.author.id + ) + setup_data = tables.timer_pattern_history.select_one_where( + select_columns=('COUNT()', ), + patternid=pid, + modified_by=ctx.author.id + ) + else: + session_data = tables.sessions.select_one_where( + select_columns=('SUM(duration)', ), + patternid=pid, + guildid=ctx.guild.id + ) + setup_data = tables.timer_pattern_history.select_one_where( + select_columns=('COUNT()', ), + patternid=pid, + timerid=[timer.role.id for timer in ctx.timers.get_timers_in(ctx.guild.id)] + ) + total_dur = session_data[0] or 0 + times_used = setup_data[0] or 0 - await ctx.embedreply("Preset `{}` has been deleted.".format(ctx.arg_str)) + table = prop_tabulate( + ('Created by', 'Used', 'Used for'), + ("<@{}>".format(ctx.author.id if is_user_preset else row['created_by']) if row['created_by'] else "Unknown", + "{} times".format(times_used), + "{:.1f} hours (total session duration)".format(total_dur / 3600)) + ) + embed = discord.Embed( + title="{} Pattern `{}`".format('User' if is_user_preset else 'Guild', row['preset_name']), + description=table, + timestamp=datetime.datetime.utcfromtimestamp(row['created_at']) + ).set_footer( + text='Created At' + ).add_field( + name='Pattern', + value="```{}```".format(pattern.display()) + ) + await ctx.reply(embed=embed) diff --git a/bot/commands/registry.py b/bot/commands/registry.py index 1b1b731..49cbe3f 100644 --- a/bot/commands/registry.py +++ b/bot/commands/registry.py @@ -1,233 +1,561 @@ import datetime as dt -import discord +import json -from cmdClient import cmd -from cmdClient import checks +import pytz +import discord +from cmdClient.checks import in_guild -from utils import interactive # noqa +from Timer import module, Pattern +from Timer.lib import parse_dur, now -from Timer import Timer +from data import tables +from data.queries import get_session_user_totals +from utils.lib import paginate_list, timestamp_utcnow, prop_tabulate +from wards import timer_admin, has_timers +from settings import UserSettings -@cmd("history", - group="Registry", - desc="Display a list of past sessions in the current guild.", - aliases=['hist']) -@checks.in_guild() -async def cmd_hist(ctx): +@module.cmd("leaderboard", + group="Registry", + short_help="Server study leaderboards over a given time period.", + aliases=('lb',)) +@has_timers() +async def cmd_leaderboard(ctx): """ Usage``: - history + {prefix}lb [day | week | month | year] Description: - Display a list of your past timer sessions in the current guild. - All times are given in UTC. + Display the server study board in the given timeframe (or all time). + + The timeframe is determined using the *guild timezone* (see `{prefix}timezone`). + Examples``: + {prefix}lb + {prefix}lb week + {prefix}lb year """ - # Get the past sessions for this user - sessions = ctx.client.interface.registry.get_sessions_where(userid=ctx.author.id, guildid=ctx.guild.id) + # Extract the target timeframe + title = None + period_start = None + spec = ctx.args.lower() + timezone = ctx.guild_settings.timezone.value + day_start = dt.datetime.now(tz=timezone).replace(hour=0, minute=0, second=0, microsecond=0) + if not spec or spec == 'all': + period_start = None + title = "All-Time Leaderboard" + elif spec == 'day': + period_start = day_start + title = "Daily Leaderboard" + elif spec == 'week': + period_start = day_start - dt.timedelta(days=day_start.weekday()) + title = "Weekly Leaderboard" + elif spec == 'month': + period_start = day_start.replace(day=1) + title = "{} Leaderboard".format(period_start.strftime('%B')) + elif spec == 'year': + period_start = day_start.replace(month=1, day=1) + title = "{} Leaderboard".format(period_start.year) + else: + return await ctx.error_reply( + "Unrecognised timeframe `{}`.\n" + "**Usage:**`{}leaderboard day | month | week | year`".format(ctx.args, ctx.best_prefix) + ) + start_ts = int(period_start.astimezone(pytz.utc).timestamp() if period_start else 0) + + # lb data from saved sessions + lb_rows = get_session_user_totals(start_ts, guildid=ctx.guild.id) + + # Currently running sessions + subscribers = { + sub.userid: sub + for timer in ctx.timers.get_timers_in(ctx.guild.id) for sub in timer.subscribers.values() + if sub.session + } + + # Calculate names and totals + names = {} + user_totals = {} + for row in lb_rows: + names[row['userid']] = row['name'] or str(row['userid']) + user_totals[row['userid']] = row['total'] + + max_unsaved = int(timestamp_utcnow() - start_ts) + for uid, sub in subscribers.items(): + if sub.member: + names[uid] = sub.member.name + elif uid not in names: + names[uid] = sub.name + + user_totals[uid] = user_totals.get(uid, 0) + min((sub.unsaved_time, max_unsaved)) + + if not user_totals: + return await ctx.reply( + "No session data to show! " + "Join a running timer to start recording data." + ) + + # Sort based on total duration + sorted_totals = sorted( + [(uid, names[uid], user_totals[uid]) for uid in user_totals], + key=lambda tup: tup[2], + reverse=True + ) - # Get the current timer if it exists - timer = ctx.client.interface.get_timer_for(ctx.guild.id, ctx.author.id) + # Format and find index of author + lb_strings = [] + author_index = None + max_name_len = min((30, max(len(name) for name in names.values()))) + for i, (uid, name, total) in enumerate(sorted_totals): + if author_index is None and uid == ctx.author.id: + author_index = i + lb_strings.append( + "{:<{}}\t{:<9}".format( + name, + max_name_len, + parse_dur(total, show_seconds=True) + ) + ) - # Quit if we don't have anything - if not sessions and not timer: - return await ctx.reply("You have not completed any timer sessions!") + page_len = 20 + pages = paginate_list(lb_strings, block_length=page_len, title=title) + start_page = author_index // page_len if author_index is not None else 0 - # Get today's date and timestamp - today = dt.datetime.utcnow().date() - today = dt.datetime(today.year, today.month, today.day) - today_ts = dt.datetime.timestamp(today) + await ctx.pager( + pages, + start_at=start_page + ) - # Build a sorted list of the author's sessions - session_table = sorted( - [(sesh['starttime'], sesh['duration']) for sesh in sessions], - key=lambda tup: tup[0], - reverse=True + +_history_pattern = """\ +{tip}```md +{day} ({tz}) (Page {page}/{page_count}) + +Period | Duration | Focused | Pattern +--------------------------------------------- +{sessions} ++-----------------------------------+ +{total} +``` +""" +_history_session_pattern = ( + "{start} - {end} | {duration} | {focused} | {pattern}" +) +_history_total_pattern = ( + "{start} - {end} | {duration} | {focused}" +) + + +@module.cmd("history", + group="Registry", + short_help="Show your personal study session history.", + aliases=('hist',)) +@in_guild() +async def cmd_history(ctx): + """ + Usage``: + {prefix}history + Description: + Display your day by day study session history. + + The times are determined using your personal timezone (see `{prefix}mytimezone`). + """ + timezone = ctx.author_settings.timezone.value + has_default_tz = (tables.users.fetch_or_create(ctx.author.id).timezone is None) + tip = ( + "**Tip:** Use `{}mytimezone` to view your sessions in your own timezone!".format(ctx.best_prefix) + ) if has_default_tz else '' + + # Get the saved session rows, ordered by newest first + rows = tables.session_patterns.select_where( + _extra="ORDER BY start_time DESC", + guildid=ctx.guild.id, + userid=ctx.author.id ) - # Add the current session if it exists - if timer: - sesh_data = timer.subscribed[ctx.author.id].session_data() - session_table.insert(0, (sesh_data[3], sesh_data[4])) - - # Build the map (date_string, [session strings]) - day_sessions = [] - - current_offset = 0 - current_sessions = [] - current_total = 0 - for start, dur in session_table: - # Get current offset and corresponding session list - date_offset = (today_ts - start) // (60 * 60 * 24) + 1 - - # If we have a new offset, generate and store the old day's data - if date_offset > current_offset: - if current_sessions: - day_str = (today - dt.timedelta(current_offset)).strftime("%A, %d %b %Y") - dur_str = "{:<13} {}".format("Total:", _parse_duration(current_total)) - day_sessions.append((day_str, current_sessions, dur_str)) - - current_offset = date_offset - current_sessions = [] - current_total = 0 - - # Generate the session string - sesh_str = "{} - {} -- {}".format( - dt.datetime.fromtimestamp(start).strftime("%H:%M"), - dt.datetime.fromtimestamp(start + dur).strftime("%H:%M"), - _parse_duration(dur) - ) - current_sessions.append(sesh_str) - - current_total += dur - - # Add the last day - # TODO: Is there a nicer recipe for this? - if current_sessions: - day_str = (today - dt.timedelta(current_offset)).strftime("%A, %d %b %Y") - dur_str = "{:<13} {}".format("Total:", _parse_duration(current_total)) - day_sessions.append((day_str, current_sessions, dur_str)) - - # Make the pages + # Add the current session, if it exists + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub and sub.session: + start_time = sub.session_started + current_duration = sub.unsaved_time + focused_duration = sum( + t[0] * stage.duration * 60 + t[1] + for t, stage in zip(sub.session, sub.timer.current_pattern) + if stage.focus + ) + if sub.timer.current_stage.focus: + focused_duration += (now() - sub.timer.stage_start) + + pattern = tables.patterns.fetch(sub.timer.current_pattern.row.patternid).stage_str + current_row = { + 'start_time': start_time, + 'duration': current_duration, + 'focused_duration': focused_duration, + 'stage_str': pattern + } + rows = [current_row, *rows] + + if not rows: + return await ctx.reply( + "You have no recorded sessions! Join a running timer to start recording study time!" + ) + + # Bin these into days + day_rows = {} + for row in rows: + # Get the row day + start_day = ( + dt.datetime + .utcfromtimestamp(row['start_time']) + .replace(tzinfo=pytz.utc) + .astimezone(timezone) + .strftime("%A, %d/%b/%Y") + ) + if start_day not in day_rows: + day_rows[start_day] = [row] + else: + day_rows[start_day].append(row) + + # Create the pages + # TODO: If there are too many sessions in a day (~30), this may cause overflow issues pages = [] - num = len(day_sessions) - for i, (day_str, sessions, total_str) in enumerate(day_sessions): - page_str = " ({}/{})".format(i+1, num) if num > 1 else "" - header = day_str + page_str - - page = ( - "All times are in UTC! The current time in UTC is {now}.\n" - "```md\n" - "{header}\n" - "{header_rule}\n" - "{session_list}\n" - "{total_rule}\n" - "{total_str}" - "```" - ).format( - now=dt.datetime.utcnow().strftime("**%H:%M** on **%d %b %Y**"), - header=header, - header_rule='=' * len(header), - session_list='\n'.join(sessions), - total_rule='+' + (len(total_str) - 2) * '-' + '+', - total_str=total_str - ) - pages.append(page) - - # Finally, run the pager + page_count = len(day_rows) + for i, (day, rows) in enumerate(day_rows.items()): + # Sort day sessions in time ascending order + rows.reverse() + + # Build session lines + row_lines = [] + for row in rows: + start = ( + dt.datetime + .utcfromtimestamp(row['start_time']) + .replace(tzinfo=pytz.utc) + .astimezone(timezone) + ) + end = start + dt.timedelta(seconds=row['duration']) + if row['stage_str']: + stages = json.loads(row['stage_str']) + pattern = '/'.join(str(stage[1]) if i < 6 else '...' for i, stage in enumerate(stages[:7])) + else: + pattern = '' + row_lines.append( + _history_session_pattern.format( + start=start.strftime("%H:%M"), + end=end.strftime("%H:%M"), + duration=parse_dur(row['duration'] or 0, show_seconds=True), + focused=parse_dur(row['focused_duration'] or 0, show_seconds=True), + pattern=pattern + ) + ) + sessions = '\n'.join(row_lines) + + # Build total info + start = ( + dt.datetime + .utcfromtimestamp(rows[0]['start_time']) + .replace(tzinfo=pytz.utc) + .astimezone(timezone) + .strftime("%H:%M") + ) + end = ( + dt.datetime + .utcfromtimestamp(rows[-1]['start_time'] + rows[-1]['duration']) + .replace(tzinfo=pytz.utc) + .astimezone(timezone) + .strftime("%H:%M") + ) + duration = sum(row['duration'] for row in rows) + focused = sum(row['focused_duration'] or 0 for row in rows) + total_str = _history_total_pattern.format( + start=start, + end=end, + duration=parse_dur(duration, show_seconds=True), + focused=parse_dur(focused, show_seconds=True) + ) + + # Add to page list + pages.append( + _history_pattern.format( + tip=tip, + day=day, + tz=timezone, + page=i+1, + page_count=page_count, + sessions=sessions, + total=total_str + ) + ) await ctx.pager(pages) -def _parse_duration(dur): - dur = int(dur) - hours = dur // 3600 - minutes = (dur % 3600) // 60 - seconds = dur % 60 +def utctimestamp(aware_dt): + return int(aware_dt.astimezone(pytz.utc).timestamp()) + - return "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds) +def _get_user_time_since(guildid, userid, period_start): + start_ts = int(utctimestamp(period_start) if period_start else 0) + rows = get_session_user_totals(start_ts, guildid=guildid, userid=userid) + return rows[0]['total'] if rows else 0 -@cmd("leaderboard", - group="Registry", - desc="Display total member group time in the last day/week/month or all-time.", - aliases=['lb']) -@checks.in_guild() -async def cmd_lb(ctx): +@module.cmd("stats", + group="Registry", + short_help="View a table of personal study statistics.", + aliases=('profile',)) +@has_timers() +async def cmd_stats(ctx): """ Usage``: - lb [day | week | month] + {prefix}stats + {prefix}stats Description: - Display the total timer time of each guild member, within the specified period. - The periods are rolling, i.e. `day` means the last 24h. - Without a period specified, the all-time totals will be shown. - Parameters:: - day: Show totals of sessions within the last 24 hours - week: Show totals of sessions within the last 7 days - month: Show totals of sessions within the last 31 days + View summary study statistics for yourself or the mentioned user. """ - out_msg = await ctx.reply("Generating leaderboard, please wait.") - - # Get the past sessions for this guild - sessions = ctx.client.interface.registry.get_sessions_where(guildid=ctx.guild.id) - - if not sessions: - return await ctx.reply("This guild has no past group sessions! Please check back soon.") - - # Current utc timestamp - now = Timer.now() - - # Determine maximum time separation allowed for sessions - region = ctx.arg_str.lower().strip() - if not region or region == 'all': - max_dist = now - head = "All-time leaderboard" - elif region == 'day': - max_dist = 60 * 60 * 24 - head = "Daily leaderboard" - elif region == 'week': - max_dist = 60 * 60 * 24 * 7 - head = "Weekly leaderboard" - elif region == 'month': - max_dist = 60 * 60 * 24 * 31 - head = "Monthly leaderboard" + target = None + if ctx.args: + maybe_id = ctx.args.strip('') + if not maybe_id.isdigit(): + return await ctx.error_reply( + "**Usage:** `{}stats [mention]`\n" + "Couldn't parse `{}` as a user mention or id!".format(ctx.best_prefix, ctx.args) + ) + targetid = int(maybe_id) + target = ctx.guild.get_member(targetid) else: - return await ctx.error_reply("Unknown region specification `{}`.".format(ctx.arg_str)) - - # Tally total session times - total_dict = {} - for session in sessions: - if now - session['starttime'] > max_dist: - continue - - if session['userid'] not in total_dict: - total_dict[session['userid']] = 0 - total_dict[session['userid']] += session['duration'] - - for guildid, userid in ctx.client.interface.subscribers: - if guildid == ctx.guild.id: - sub_data = ctx.client.interface.subscribers[(guildid, userid)].session_data() - if userid not in total_dict: - total_dict[userid] = 0 - total_dict[userid] += sub_data[4] - - # Reshape and sort the totals - totals = sorted(list(total_dict.items()), key=lambda tup: tup[1], reverse=True) - - # Build the string pairs - total_strs = [] - for userid, total in totals: - # Find the user - user = ctx.client.get_user(userid) - if user is None: - try: - user = await ctx.client.fetch_user(userid) - user_str = user.name - except discord.NotFound: - user_str = str(userid) + target = ctx.author + targetid = ctx.author.id + + timezone = UserSettings(targetid).timezone.value + day_start = dt.datetime.now(tz=timezone).replace(hour=0, minute=0, second=0, microsecond=0) + + sub = ctx.timers.get_subscriber(targetid, ctx.guild.id) + if not target and sub and sub.member: + target = sub.member + unsaved = sub.unsaved_time if sub else 0 + + # Total session count and duration + summary_row = tables.sessions.select_one_where( + select_columns=('COUNT() AS count', 'SUM(duration) AS total'), + userid=targetid, + guildid=ctx.guild.id + ) + if not summary_row['count']: + if target == ctx.author: + return await ctx.embed_reply( + "You have no recorded sessions! Join a running timer to start recording study time!" + ) else: - user_str = user.name + return await ctx.embed_reply( + "<@{}> has no recorded sessions!".format(targetid) + ) + session_count = summary_row['count'] + total_duration = summary_row['total'] + + # Favourites + pattern_rows = tables.session_patterns.select_where( + select_columns=('SUM(duration) AS total', 'patternid'), + _extra="GROUP BY patternid ORDER BY total DESC LIMIT 5", + userid=targetid, + guildid=ctx.guild.id + ) + print([dict(row) for row in pattern_rows]) + pattern_pairs = [ + (Pattern.get(row['patternid']).display(brief=True, truncate=6) if row['patternid'] is not None else "Unknown", + row['total']) + for row in pattern_rows + ] + max_len = max(len(p) for p, _ in pattern_pairs) + pattern_block = "```{}```".format( + '\n'.join( + "{:<{}} - {} ({}%)".format( + pattern, + max_len, + parse_dur(total), + (total * 100) // total_duration + ) + for pattern, total in pattern_pairs + ) + ) - total_strs.append((user_str, _parse_duration(total))) + timer_rows = tables.session_patterns.select_where( + select_columns=('SUM(duration) AS total', 'roleid'), + _extra="GROUP BY roleid ORDER BY total DESC LIMIT 5", + userid=targetid, + guildid=ctx.guild.id + ) + timer_pairs = [] + for row in timer_rows: + timer_row = tables.timers.fetch(row['roleid']) + if timer_row: + name = timer_row.name + else: + name = str(row['roleid']) + timer_pairs.append((name, row['total'])) + max_len = max(len(t) for t, _ in timer_pairs) + timer_block = "```{}```".format( + '\n'.join( + "{:^{}} - {} ({}%)".format( + timer_name, + max_len, + parse_dur(total), + (total * 100) // total_duration + ) + for timer_name, total in timer_pairs + ) + ) - # Build pages in groups of 20 - blocks = [total_strs[i:i+20] for i in range(0, len(total_strs), 20)] - max_block_lens = [len(max(list(zip(*block))[0], key=len)) for block in blocks] - page_blocks = [["{0[0]:^{max_len}} {0[1]:>10}".format(pair, max_len=max_block_lens[i]) for pair in block] - for i, block in enumerate(blocks)] + # Calculate streak and first session + streak = 0 + day_window = (day_start, day_start + dt.timedelta(days=1)) + ts_window = (utctimestamp(day_window[0]), utctimestamp(day_window[1])) - num = len(page_blocks) - pages = [] - for i, block in enumerate(page_blocks): - header = head + " (Page {}/{})".format(i+1, num) if num > 1 else head - header_rule = "=" * len(header) - page = "```md\n{}\n{}\n{}```".format( - header, - header_rule, - "\n".join(block) + session_rows = tables.sessions.select_where( + select_columns=('start_time', 'start_time + duration AS end_time'), + _extra="ORDER BY start_time DESC", + guildid=ctx.guild.id, + userid=targetid + ) + first_session_ts = session_rows[-1]['start_time'] + session_periods = ((row['start_time'], row['end_time']) for row in session_rows) + + # Account for the current day + start_time, end_time = (session_rows[0]['start_time'], session_rows[0]['end_time']) if session_rows else (0, 0) + if sub or end_time > ts_window[0]: + streak += 1 + day_window = (day_window[0] - dt.timedelta(days=1), day_window[0]) + ts_window = (utctimestamp(day_window[0]), ts_window[0]) + + for start, end in session_periods: + if end < ts_window[0]: + break + elif start < ts_window[1]: + streak += 1 + day_window = (day_window[0] - dt.timedelta(days=1), day_window[0]) + ts_window = (utctimestamp(day_window[0]), ts_window[0]) + + # Binned time totals + time_totals = {} + total_fields = { + 'Today': day_start, + 'This Week': day_start - dt.timedelta(days=day_start.weekday()), + 'This Month': day_start.replace(day=1), + 'This Year': day_start.replace(month=1, day=1), + 'All Time': None + } + for name, start in total_fields.items(): + time_totals[name] = parse_dur( + _get_user_time_since(ctx.guild.id, targetid, start) + unsaved, + show_seconds=False + ) + subtotal_table = prop_tabulate(*zip(*time_totals.items())) + + # Format stats into the final embed + desc = ( + "**{}** sessions completed, with a total of **{}** hours." + ).format(session_count, total_duration // 3600) + if sub: + desc += "\nCurrently studying in **{}** (in {}) for **{}**!".format( + sub.timer.name, + sub.timer.channel.mention, + parse_dur(sub.clocked_time + unsaved, show_seconds=True) ) - pages.append(page) - await out_msg.delete() - if not pages: - return await ctx.reply("No entries exist in the given range!") + embed = ( + discord.Embed( + title="Study Statistics", + description=desc, + timestamp=dt.datetime.utcfromtimestamp(first_session_ts) + ) + .set_footer(text="Studying Since") + .add_field( + name="Subtotals", + value=subtotal_table, + inline=True + ) + .add_field( + name="Streak", + value="**{}** days!".format(streak), + inline=True + ) + .add_field( + name="Favourite Patterns", + value=pattern_block, + inline=False + ) + .add_field( + name="Favourite Groups", + value=timer_block, + inline=True + ) + ) + if not target: + try: + target = await ctx.guild.fetch_member(targetid) + except discord.HTTPException: + pass + if target: + embed.set_author(name=target.name, icon_url=target.avatar_url) + else: + row = tables.users.fetch(targetid) + name = row.name if row else str(target.id) + embed.set_author(name=name) + await ctx.reply(embed=embed) + - await ctx.pager(pages, locked=False) +@module.cmd("clearregistry", + group="Registry Admin", + short_help="Remove all session history in this server.") +@timer_admin() +async def cmd_clearregistry(ctx): + """ + Usage``: + {prefix}clearregistry + Description: + Remove **all** session history in the server. + This will reset the server leaderboard, along with all personal statistics (including `stats` and `hist`). + ***This cannot be undone.*** + + *This command requires timer admin permissions.* + """ + prompt = ( + "Are you sure you want to delete **all** session history in this server? " + "This will reset the leaderboard and all member history. " + "**This cannot be undone**." + ) + if not await ctx.ask(prompt): + return + tables.sessions.delete_where(guildid=ctx.guild.id) + await ctx.reply("All session data has been deleted.") + + +""" +@module.cmd("forgetuser", + group="Registry Admin", + short_help="Remove all session history for a given member.") +@in_guild() +async def cmd_forgetuser(ctx): + ... + + +@module.cmd("delsession", + group="Registry Admin", + short_help="Remove a selected session from the registry.") +@in_guild() +async def cmd_delsession(ctx): + ... + + +@module.cmd("showsessions", + group="Registry Admin", + short_help="Show recent study sessions, with optional filtering.") +@in_guild() +async def cmd_showsessions(ctx): + ... + + +@module.cmd("showtimerhistory", + group="Registry Admin", + short_help="Show the pattern log for a given timer.") +@in_guild() +async def cmd_showtimerhistory(ctx): + ... +""" diff --git a/bot/commands/timer.py b/bot/commands/timer.py index 1584cc1..7517d23 100644 --- a/bot/commands/timer.py +++ b/bot/commands/timer.py @@ -1,89 +1,116 @@ -# import datetime -import discord -from cmdClient import cmd -from cmdClient.checks import in_guild +import datetime +import asyncio -from Timer import TimerState, NotifyLevel +import discord -from utils import timer_utils, interactive, ctx_addons # noqa +from meta import client +from data import tables -from wards import timer_ready +from Timer import TimerState, Pattern, module -from presets import get_presets +from utils import timer_utils, interactive, ctx_addons # noqa +from utils.live_messages import live_edit +from utils.timer_utils import is_timer_admin +from wards import has_timers -@cmd("join", - group="Timer", - desc="Join a group bound to the current channel.", - aliases=['sub']) -@in_guild() -@timer_ready() +@module.cmd("join", + group="Timer Usage", + short_help="Join a study group.", + aliases=['sub']) +@has_timers() async def cmd_join(ctx): """ Usage``: - join - join + {prefix}join + {prefix}join Description: - Join a group in the current channel or guild. - If there are multiple matching groups, or no group is provided, - will show the group selector. + Join a study group, and subscribe to the group timer notifications. + When used with no arguments, displays a selection prompt with the available groups. + + The `group` may be given as a group name or partial name. \ + See `{prefix}groups` for the list of groups in this server. Related: leave, status, groups, globalgroups Examples``: - join espresso + {prefix}join {ctx.example_group_name} """ # Get the timer they want to join - globalgroups = ctx.client.config.guilds.get(ctx.guild.id, 'globalgroups') - timer = await ctx.get_timers_matching(ctx.arg_str, channel_only=(not globalgroups), info=True) + globalgroups = ctx.guild_settings.globalgroups.value + timer = await ctx.get_timers_matching(ctx.args, channel_only=(not globalgroups), info=True) if timer is None: - return await ctx.error_reply( - ("No matching groups in this {}.\n" - "Use the `groups` command to see the groups in this guild!").format( - 'guild' if globalgroups else 'channel' - ) - ) + if not ctx.timers.get_timers_in(ctx.guild.id): + await ctx.error_reply( + "There are no study groups to join!\n" + "Create a new study group with `{prefix}newgroup ` " + "(e.g. `{prefix}newgroup Pomodoro`).".format(prefix=ctx.best_prefix) + ) + elif not globalgroups and not ctx.timers.get_timers_in(ctx.guild.id, ctx.ch.id): + await ctx.error_reply( + "No study groups in this channel!\n" + "Use `{prefix}groups` to see all server groups.".format(prefix=ctx.best_prefix) + ) + else: + await ctx.error_reply( + ("No matching groups in this {}.\n" + "Use `{}groups` to see the server study groups!").format( + 'server' if globalgroups else 'channel', + ctx.best_prefix + ) + ) + return # Query if the author is already in a group - current_timer = ctx.client.interface.get_timer_for(ctx.guild.id, ctx.author.id) - if current_timer is not None: - if current_timer == timer: - return await ctx.error_reply("You are already in this group!\n" - "Use `status` to see the current timer status.") - - chan_info = " in {}".format(current_timer.channel.mention) if current_timer.channel != ctx.ch else "" - result = await ctx.ask("You are already in the group `{}`{}.\nAre you sure you want to switch?".format( - current_timer.name, - chan_info - )) - if not result: - return - - await current_timer.subscribed[ctx.author.id].unsub() + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is not None: + if sub.timer == timer: + return await ctx.error_reply( + "You are already in this study group!\n" + "Use `{prefix}status` to see the current timer status.".format(prefix=ctx.best_prefix) + ) + else: + result = await ctx.ask( + "You are already in the group **{}**{}. " + "Are you sure you want to switch groups?".format( + sub.timer.name, + " in {}".format(sub.timer.channel.mention) if ctx.ch != sub.timer.channel else "" + ) + ) + if not result: + return + # TODO: Vulnerable to interactive race-states + await sub.timer.unsubscribe(ctx.author.id) # Subscribe the member - await ctx.client.interface.sub(ctx, ctx.author, timer) + new_sub = await timer.subscribe(ctx.author) + if sub: + new_sub.clocked_time = sub.clocked_time - # Specify channel info if they are joining from a different channel + # Check if member is joining from a different channel this_channel = (timer.channel == ctx.ch) - chan_info = " in {}".format(timer.channel.mention) if not this_channel else "" - # Reply with the join message - message = "You have joined the group **{}**{}!".format(timer.name, chan_info) + # Build and send the join message + message = "You have {} **{}**{}!".format( + 'switched to' if sub is not None else 'joined', + timer.name, + " in {}".format(timer.channel.mention) if not this_channel else "" + ) if ctx.author.bot: message += " Good luck, colleague!" if timer.state == TimerState.RUNNING: - message += "\nCurrently on stage **{}** with **{}** remaining. {}".format( - timer.stages[timer.current_stage].name, - timer.pretty_remaining(), - timer.stages[timer.current_stage].message + message += " Currently on stage **{}** with **{}** remaining. {}".format( + timer.current_stage.name, + timer.pretty_remaining, + timer.current_stage.message ) - elif timer.stages: - message += "\nGroup timer is set up but not running. Use `start` to start the timer!" else: - message += "\nSet up the timer with `set`!" + message += ( + "\nTimer is not running! Start it with `{prefix}start` " + "(or `{prefix}start [pattern]` to use a custom timer pattern)." + ).format(prefix=ctx.best_prefix) - await ctx.reply(message) + await ctx.reply(message, reference=ctx.msg) # Poke a welcome message to the timer channel if we are somewhere else if not this_channel: @@ -93,439 +120,592 @@ async def cmd_join(ctx): )) -@cmd("leave", - group="Timer", - desc="Leave your current group.", - aliases=['unsub']) -@in_guild() -@timer_ready() +@module.cmd("leave", + group="Timer Usage", + short_help="Leave your current group.", + aliases=['unsub']) +@has_timers() async def cmd_unsub(ctx): """ Usage``: - leave + {prefix}leave Description: - Leave your current group, and unsubscribe from the group timer. + Leave your study group. Related: join, status, groups """ - timer = ctx.client.interface.get_timer_for(ctx.guild.id, ctx.author.id) - if timer is None: + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is None: + await ctx.error_reply( + "You are not in a study group! Join one with `{prefix}join`.".format(prefix=ctx.best_prefix) + ) + else: + await sub.timer.unsubscribe(ctx.author.id) + await ctx.reply( + "You left **{}**! You were subscribed for **{}**.".format( + sub.timer.name, + sub.pretty_clocked + ), + reference=ctx.msg + ) + + +@module.cmd("setup", + group="Timer Control", + short_help="Stop and change your group timer pattern.", + aliases=['set']) +@has_timers() +async def cmd_set(ctx): + """ + Usage``: + {prefix}setup + {prefix}setup + Description: + Sets your group timer pattern (i.e. the pattern of work/break periods). + + See `{prefix}help patterns` and the examples below for more information about the pattern format. \ + A saved pattern name (see `{prefix}help presets`) may also be used in place of the associated pattern. + + *If the `admin_locked` timer option is set, this command requires timer admin permissions.* + Related: + join, start, reset, savepattern + Examples: + `{prefix}setup 50/10` (`50` minutes work followed by `10` minutes break.) + `{prefix}setup 25/5/25/5/25/10` (A standard Pomodoro pattern of work and breaks.) + `{prefix}setup Study, 50; Rest, 10` (Another `50/10` pattern, now with custom stage names.) + """ + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is None: return await ctx.error_reply( - "You need to join a group before you can leave one!" + "You are not in a study group! Join one with `{prefix}join`.".format(prefix=ctx.best_prefix) ) - session = await ctx.client.interface.unsub(ctx.guild.id, ctx.author.id) - clocked = session[-1] + if sub.timer.settings.admin_locked.value and not await is_timer_admin(ctx.author): + return await ctx.error_reply("This timer may only be setup by timer admins.") - dur = int(clocked) - hours = dur // 3600 - minutes = (dur % 3600) // 60 - seconds = dur % 60 + if not ctx.args: + return await ctx.error_reply( + "Please provide a timer pattern! See `{}help setup` for usage".format(ctx.best_prefix) + ) - dur_str = "{:02d}:{:02d}:{:02d}".format(hours, minutes, seconds) + if sub.timer.state == TimerState.RUNNING: + if not await ctx.ask("Are you sure you want to **stop and reset** your study group timer?"): + return - await ctx.reply("You have been unsubscribed from **{}**! You were subscribed for **{}**.".format( - timer.name, - dur_str - )) + pattern = Pattern.from_userstr(ctx.args, timerid=sub.timer.roleid, userid=ctx.author.id, guildid=ctx.guild.id) + await sub.timer.setup(pattern, ctx.author.id) + + content = "**{}** set up! Use `{}start` to start when ready.".format(sub.timer.name, ctx.best_prefix) + asyncio.create_task( + live_edit( + None, + _status_msg, + 'status', + timer=sub.timer, + ctx=ctx, + content=content, + reference=ctx.msg + ) + ) -@cmd("set", - group="Timer", - desc="Setup the stages of a group timer.", - aliases=['setup', 'reset']) -@in_guild() -@timer_ready() -async def cmd_set(ctx): +@module.cmd("reset", + group="Timer Control", + short_help="Reset the timer pattern to the default.") +@has_timers() +async def cmd_reset(ctx): """ Usage``: - set - set - set + {prefix}reset Description: - Setup the stages of the timer you are subscribed to. - When used with no parameters, uses the following default setup string: - ``` - Study, 25, Good luck!; Break, 5, Have a rest.; - Study, 25, Good luck!; Break, 5, Have a rest.; - Study, 25, Good luck!; Long Break, 10, Have a rest. - ``` - Stages are separated by semicolons, - and are of the format `stage name, stage duration, stage message`. - The `stage message` is optional. - - See the `presets` command for more information on using setup presets. + Stop your group timer, and reset the timer pattern to the timer default. + (To change the default pattern, see `{prefix}tconfig default_pattern`.) + + *If the `admin_locked` timer option is set, this command requires timer admin permissions.* Related: - join, start, presets + tconfig, setup, stop """ - # Get the timer we are acting on - timer = ctx.client.interface.get_timer_for(ctx.guild.id, ctx.author.id) - if timer is None: - tchan = ctx.client.interface.channels.get(ctx.ch.id, None) - if tchan is None or not tchan.timers: - await ctx.error_reply("There are no timers in this channel!") - else: - await ctx.error_reply("Please join a group first!") - return + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is None: + return await ctx.error_reply( + "You are not in a study group! Join one with `{prefix}join`.".format(prefix=ctx.best_prefix) + ) - # If the timer is running, prompt for confirmation - if timer.state == TimerState.RUNNING: - if ctx.arg_str: - if not await ctx.ask("The timer is running! Are you sure you want to reset it?"): - return - else: - if not await ctx.ask("The timer is running! Are you sure you want to reset it? " - "This will reset the stage sequence to the default!"): - return + if sub.timer.settings.admin_locked.value and not await is_timer_admin(ctx.author): + return await ctx.error_reply("This timer may only be reset by timer admins.") - if not ctx.arg_str: - # Use the default setup string - # TODO: Customise defaults for different timers - setupstr = ( - "Study, 25, Good luck!; Break, 5, Have a rest.;" - "Study, 25, Good luck!; Break, 5, Have a rest.;" - "Study, 25, Good luck!; Long Break, 10, Have a rest." - ) - stages = ctx.client.interface.parse_setupstr(setupstr) - else: - # Parse the provided setup string - if "," in ctx.arg_str: - # Parse as a standard setup string - stages = ctx.client.interface.parse_setupstr(ctx.arg_str) - if stages is None: - return await ctx.error_reply("Didn't understand setup string!") - else: - # Parse as a preset - presets = get_presets(ctx) - if ctx.arg_str in presets: - stages = ctx.client.interface.parse_setupstr(presets[ctx.arg_str]) - else: - return await ctx.error_reply( - ("Didn't recognise the timer preset `{}`.\n" - "Use the `presets` command to view available presets.").format(ctx.arg_str) - ) + if sub.timer.state == TimerState.RUNNING: + if not await ctx.ask("Are you sure you want to **stop and reset** your study group timer?"): + return - timer.setup(stages) - await ctx.reply("Timer pattern set up! Start when ready.") + await sub.timer.setup(sub.timer.default_pattern, ctx.author.id) + + content = "**{}** has been reset! Use `{}start` to start when ready.".format(sub.timer.name, ctx.best_prefix) + asyncio.create_task( + live_edit( + None, + _status_msg, + 'status', + timer=sub.timer, + ctx=ctx, + content=content, + reference=ctx.msg + ) + ) -@cmd("start", - group="Timer", - desc="Start your timer.", - aliases=["restart"]) -@in_guild() -@timer_ready() +@module.cmd("start", + group="Timer Control", + short_help="Start your group timer (and optionally change the pattern).", + aliases=["restart"]) +@has_timers() async def cmd_start(ctx): """ Usage``: - start - start + {prefix}start + {prefix}start + {prefix}start + {prefix}restart Description: - Start the timer you are subscribed to. - Can be used with a setup string to set up and start the timer in one go. + Start or restart your group timer. + + To modify the timer pattern (i.e. the pattern of work/break stages), \ + provide a timer pattern or a saved pattern name. \ + See `{prefix}help patterns` and the examples below for more information about the pattern format. + + *If the `admin_locked` timer option is set, this command requires timer admin permissions, unless\ + the timer is already stopped.* + Related: + stop, setup, tconfig, savepattern + Examples: + `{prefix}start` (Start/restart the timer with the current pattern.) + `{prefix}start 50/10` (`50` minutes work followed by `10` minutes break.) + `{prefix}start 25/5/25/5/25/10` (A standard Pomodoro pattern of work and breaks.) + `{prefix}start Study, 50; Rest, 10` (Another `50/10` pattern, now with custom stage names.) """ - timer = ctx.client.interface.get_timer_for(ctx.guild.id, ctx.author.id) - if timer is None: - tchan = ctx.client.interface.channels.get(ctx.ch.id, None) - if tchan is None or not tchan.timers: - await ctx.error_reply("There are no timers in this channel!") - else: - await ctx.error_reply("Please join a group first!") - return - if timer.state == TimerState.RUNNING: - if await ctx.ask("Are you sure you want to restart your study group timer?"): - timer.stop() + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is None: + return await ctx.error_reply( + "You are not in a study group! Join one with `{prefix}join`.".format(prefix=ctx.best_prefix) + ) + if sub.timer.state == TimerState.RUNNING: + if sub.timer.settings.admin_locked.value and not await is_timer_admin(ctx.author): + return await ctx.error_reply("This timer may only be restarted by timer admins.") + + if await ctx.ask("Are you sure you want to **restart** your study group timer?"): + sub.timer.stop() else: return - if ctx.arg_str: - stages = ctx.client.interface.parse_setupstr(ctx.arg_str) + timer = sub.timer - if stages is None: - return await ctx.error_reply("Didn't understand setup string!") + if ctx.args: + pattern = Pattern.from_userstr(ctx.args, timerid=sub.timer.roleid, userid=ctx.author.id, guildid=ctx.guild.id) + await timer.setup(pattern, ctx.author.id) - timer.setup(stages) - - if not timer.stages: - return await ctx.error_reply("Please set up the timer first!") + this_channel = (ctx.ch == timer.channel) + content = "Started **{}** in {}!".format( + timer.name, + timer.channel.mention + ) if not this_channel else '' await timer.start() - - if timer.channel != ctx.ch: - await ctx.reply("Timer has been started in {}".format(timer.channel.mention)) + if ctx.args: + asyncio.create_task( + live_edit( + None, + _status_msg, + 'status', + timer=timer, + ctx=ctx, + content=content, + reference=ctx.msg + ) + ) + elif not this_channel: + await ctx.reply(content, reference=ctx.msg) -@cmd("stop", - group="Timer", - desc="Stop your timer.") -@in_guild() -@timer_ready() +@module.cmd("stop", + group="Timer Control", + short_help="Stop your group timer.") +@has_timers() async def cmd_stop(ctx): """ Usage``: - stop + {prefix}stop Description: - Stop the timer you are subscribed to. + Stop your study group timer. + + *If the `admin_locked` timer option is set, this command requires timer admin permissions.* + Related: + start, reset, tconfig """ - timer = ctx.client.interface.get_timer_for(ctx.guild.id, ctx.author.id) - if timer is None: - tchan = ctx.client.interface.channels.get(ctx.ch.id, None) - if tchan is None or not tchan.timers: - await ctx.error_reply("There are no timers in this channel!") - else: - await ctx.error_reply("Please join a group first!") - return - if timer.state == TimerState.STOPPED: - return await ctx.error_reply("Can't stop something that's not moving!") + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is None: + return await ctx.error_reply( + "You are not in a study group! Join one with `{prefix}join`.".format(prefix=ctx.best_prefix) + ) - if len(timer.subscribed) > 1: + if sub.timer.settings.admin_locked.value and not await is_timer_admin(ctx.author): + return await ctx.error_reply("This timer may only be stopped by timer admins.") + + if sub.timer.state != TimerState.RUNNING: + # TODO: Might want an extra clause when we have Pause states + return await ctx.error_reply( + "Can't stop something that's not moving! (Your group timer is already stopped.)" + ) + + if len(sub.timer.subscribers) > 1: if not await ctx.ask("There are other people in your study group! " "Are you sure you want to stop the study group timer?"): return - timer.stop() - await ctx.reply("Your timer has been stopped.") + sub.timer.stop() + await ctx.reply("Your group timer has been stopped.") -@cmd("groups", - group="Timer", - desc="View the guild's groups.", - aliases=["timers"]) -@in_guild() -@timer_ready() +async def _group_msg(msg, ctx=None): + """ + Group message live-editor. + """ + sections = [] + for tchan in client.interface.guild_channels.get(ctx.guild.id, {}).values(): + if len(tchan.timers) > 0: + sections.append("{}\n\n{}".format( + tchan.channel.mention, + "\n\n".join(timer.pretty_summary for timer in tchan.timers) + )) + + embed = discord.Embed( + description="\n\n\n".join(sections) or "No timers in this guild!", + colour=discord.Colour(0x9b59b6), + title="Study groups", + timestamp=datetime.datetime.utcnow() + ).set_footer(text="Last Updated") + + if msg: + try: + await msg.edit(embed=embed) + return msg + except discord.HTTPException: + pass + else: + return await ctx.reply(embed=embed) + + +@module.cmd("groups", + group="Timer Usage", + short_help="List the server study groups.", + aliases=["timers"]) +@has_timers() async def cmd_groups(ctx): + """ + Usage``: + {prefix}groups + Description: + List all the study groups in this server. + Related: + join, newgroup, delgroup + """ # Handle there being no timers - if not ctx.client.interface.get_guild_timers(ctx.guild.id): - return await ctx.error_reply("There are no groups set up in this guild!") - - if "live_grouptokens" not in ctx.client.objects: - ctx.client.objects["live_grouptokens"] = {} - ctx.client.objects["live_grouptokens"][ctx.ch.id] = ctx.msg.id - - async def _groups(): - # Check if we have a new token - if ctx.client.objects["live_grouptokens"].get(ctx.ch.id, 0) != ctx.msg.id: - return None - - # Build the embed description - sections = [] - for tchan in ctx.client.interface.guild_channels[ctx.guild.id]: - if len(tchan.timers) > 0: - sections.append("{}\n\n{}".format( - tchan.channel.mention, - "\n\n".join(timer.pretty_summary() for timer in tchan.timers) - )) - - embed = discord.Embed( - description="\n\n\n".join(sections) or "No timers in this guild!", - colour=discord.Colour(0x9b59b6), - title="Group timers in this guild" + timers = ctx.timers.get_timers_in(ctx.guild.id) + if not timers: + return await ctx.error_reply( + "This server doesn't have any study groups yet!\n" + "Create one with `{prefix}newgroup ` " + "(e.g. `{prefix}newgroup Pomodoro`).".format(prefix=ctx.best_prefix) ) - return {'embed': embed} - await ctx.live_reply(_groups) + asyncio.create_task(live_edit( + None, + _group_msg, + 'groups', + ctx=ctx + )) + +async def _status_msg(msg, timer, ctx, content='', reference=None): + embed = discord.Embed( + description=timer.status_string(show_seconds=True), + colour=discord.Colour(0x9b59b6), + timestamp=datetime.datetime.utcnow() + ).set_footer(text="Last Updated") -@cmd("status", - group="Timer", - desc="View detailed information about a group.", - aliases=["group", "timer"]) -@in_guild() -@timer_ready() -async def cmd_group(ctx): + if msg: + try: + await msg.edit(content=content, embed=embed) + return msg + except discord.HTTPException: + pass + else: + return await ctx.reply(content=content, embed=embed, reference=reference) + + +@module.cmd("status", + group="Timer Usage", + short_help="Show the status of a group.") +@has_timers() +async def cmd_status(ctx): """ Usage``: - status [group] + {prefix}status + {prefix}status Description: - Display detailed information about the current group or the specified group. + Display the status of the provided group (or your current/selected group if none was given). + + The `group` may be given as a group name or partial name. \ + See `{prefix}groups` for the list of groups in this server. + Related: + groups, start, stop, setup """ - if ctx.arg_str: - timer = await ctx.get_timers_matching(ctx.arg_str, channel_only=False) + # Get target group + if ctx.args: + timer = await ctx.get_timers_matching(ctx.args, channel_only=False) if timer is None: - return await ctx.error_reply("No groups matching `{}`!".format(ctx.arg_str)) + return await ctx.error_reply("No groups found matching `{}`!".format(ctx.args)) else: - timer = ctx.client.interface.get_timer_for(ctx.guild.id, ctx.author.id) - if timer is None: - timer = await ctx.get_timers_matching("", channel_only=False) + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub: + timer = sub.timer + else: + timer = await ctx.get_timers_matching('', channel_only=False) if timer is None: - return await ctx.error_reply("No groups are set up in this guild.") - - if "live_statustokens" not in ctx.client.objects: - ctx.client.objects["live_statustokens"] = {} - ctx.client.objects["live_statustokens"][ctx.ch.id] = ctx.msg.id - - async def _status(): - # Check if we have a new token - if ctx.client.objects["live_statustokens"].get(ctx.ch.id, 0) != ctx.msg.id: - return None + return await ctx.error_reply( + "This server doesn't have any study groups yet!\n" + "Create one with `{prefix}newgroup ` " + "(e.g. `{prefix}newgroup Pomodoro`).".format(prefix=ctx.best_prefix) + ) - embed = discord.Embed( - description=timer.pretty_pinstatus(), - colour=discord.Colour(0x9b59b6) + asyncio.create_task( + live_edit( + None, + _status_msg, + 'status', + ctx=ctx, + timer=timer ) - return {'embed': embed} + ) - await ctx.live_reply(_status) - -@cmd("notify", - group="Timer", - desc="Configure your personal notification level.", - aliases=["dm"]) -async def cmd_notify(ctx): +@module.cmd("shift", + group="Timer Control", + short_help="Add or remove time from the current stage.") +@has_timers() +async def cmd_shift(ctx): """ Usage``: - notify - notify + {prefix}shift + {prefix}shift Description: - View or set your notification level. - The possible levels are described below. - Notification levels:: - all: Receive all stage changes and status updates via DM. - warnings: Only receive a DM for inactivity warnings (default). - kick: Only receive a DM after being kicked for inactivity. - none: Never get sent any status updates via DM. + Adds or removes time from the current stage. + When `amount` is *positive*, adds time to the stage, and removes time when `amount` is *negative*. + `amount` must be given in minutes, with no units (see examples below). + If `amount` is not given, instead aligns the start of the stage to the nearest hour. + + *If the `admin_locked` timer option is set, this command requires timer admin permissions.* Examples``: - notify warnings + {prefix}shift +10 + {prefix}shift -10 """ - if not ctx.arg_str: - # Read the current level and report - level = ctx.client.config.users.get(ctx.author.id, "notify_level") or None - level = NotifyLevel(level) if level is not None else NotifyLevel.WARNING - - if level == NotifyLevel.ALL: - await ctx.reply("Your notification level is `ALL`.\n" - "You will be notified of all group status changes by direct message.") - elif level == NotifyLevel.WARNING: - await ctx.reply("Your notification level is `WARNING`.\n" - "You will receive a direct message when you are about to be kicked for inactivity.") - elif level == NotifyLevel.FINAL: - await ctx.reply("Your notification level is `KICK`.\n" - "You will only be messaged when you are kicked for inactivity.") - elif level == NotifyLevel.NONE: - await ctx.reply("Your notification level is `NONE`.\n" - "You will never be direct messaged about group status updates.") - else: - content = ctx.arg_str.lower() - - newlevel = None - message = None - if content in ["all", "everything"]: - newlevel = NotifyLevel.ALL - message = ("Your notification level has been set to `ALL`\n" - "You will be notified of all group status changes by direct message.") - elif content in ["warnings", "warning"]: - newlevel = NotifyLevel.WARNING - message = ("Your notification level has been set to `WARNING`.\n" - "You will receive a direct message when you are about to be kicked for inactivity.") - elif content in ["final", "kick"]: - newlevel = NotifyLevel.FINAL - message = ("Your notification level has been set to `KICK`.\n" - "You will only be messaged when you are kicked for inactivity.") - elif content in ["none", "dnd"]: - newlevel = NotifyLevel.NONE - message = ("Your notification level has been set to `NONE`.\n" - "You will never be direct messaged about group status updates.") - else: - await ctx.error_reply( - "I don't understand notification level `{}`! See `help notify` for valid levels.".format(ctx.arg_str) - ) - if newlevel is not None: - # Update the db entry - ctx.client.config.users.set(ctx.author.id, "notify_level", newlevel.value) + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is None: + return await ctx.error_reply( + "You are not in a study group!" + ) - # Update any existing timers - for subber in ctx.client.interface.get_subs_for(ctx.author.id): - subber.notify = NotifyLevel(newlevel) + if sub.timer.settings.admin_locked.value and not await is_timer_admin(ctx.author): + return await ctx.error_reply("This timer may only be shifted by timer admins.") - # Send the update message - await ctx.reply(message) + if sub.timer.state != TimerState.RUNNING: + return await ctx.error_reply( + "You can only shift a group timer while it is running!" + ) + if len(sub.timer.subscribers) > 1: + if not await ctx.ask("There are other people in your study group! " + "Are you sure you want to shift the study group timer?"): + return -@cmd("rename", - group="Timer", - desc="Rename your group.") -@in_guild() -@timer_ready() -async def cmd_rename(ctx): + if not ctx.args: + quantity = None + elif ctx.args.strip('+-').isdigit(): + quantity = (-1 if ctx.args.startswith('-') else 1) * int(ctx.args.strip('+-')) + else: + return await ctx.error_reply( + "Could not parse `{}` as a shift amount!".format(ctx.args) + ) + + sub.timer.shift(quantity * 60 if quantity is not None else None) + asyncio.create_task( + live_edit( + None, + _status_msg, + 'status', + timer=sub.timer, + ctx=ctx, + content="Timer shifted!", + reference=ctx.msg + ) + ) + + +@module.cmd("skip", + group="Timer Control", + short_help="Skip the current stage.") +@has_timers() +async def cmd_skip(ctx): """ Usage``: - rename + {prefix}skip + {prefix}skip Description: - Set the name of your current group to `groupname`. - Arguments:: - groupname: The new name for your group, less than `20` characters long. - Related: - join, status, groups + Skip the current timer stage, or the number of stages given. + Examples``: + {prefix}skip 1 """ - timer = ctx.client.interface.get_timer_for(ctx.guild.id, ctx.author.id) - if timer is None: + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is None: return await ctx.error_reply( - "You need to join a group first!" + "You are not in a study group!" ) - if not (0 < len(ctx.arg_str) < 20): + timer = sub.timer + + if timer.settings.admin_locked.value and not await is_timer_admin(ctx.author): + return await ctx.error_reply("This timer may only be skipped by timer admins.") + + if timer.state != TimerState.RUNNING: return await ctx.error_reply( - "Please supply a new group name under `20` characters long!\n" - "**Usage:** `rename `" + "You can only skip stages of a group timer while it is running!" ) - timer.name = ctx.arg_str - await ctx.embedreply("Your group has been renamed to **{}**.".format(ctx.arg_str)) + if len(timer.subscribers) > 1: + if not await ctx.ask("There are other people in your study group! " + "Are you sure you want to skip forwards?"): + return + + # Collect the number of stages to skip + count = 1 + pattern_len = len(timer.current_pattern.stages) + if ctx.args: + if not ctx.args.isdigit(): + return await ctx.error_reply( + "**Usage:** `{prefix}skip [number].\n" + "Couldn't parse the number of stages to skip.".format(prefix=ctx.best_prefix) + ) + if len(ctx.args) > 10 or int(ctx.args) > pattern_len: + return await ctx.error_reply( + "Maximum number of skippable stages is `{}`.".format(pattern_len) + ) + count = int(ctx.args) + if count == 0: + return await ctx.error_reply( + "Skipping no stages.. done?" + ) -@cmd("syncwith", - group="Timer", - desc="Sync the start of your group timer with another group") -@in_guild() -@timer_ready() -async def cmd_syncwith(ctx): + # Calculate the shift time + shift_by = timer.remaining + sum( + timer.current_pattern.stages[(timer.stage_index + i + 1) % pattern_len].duration * 60 + for i in range(count - 1) + ) - 1 + + timer.shift(-1 * shift_by) + content = "**{}** stages skipped!".format(count) if count > 1 else "Stage skipped!" + await asyncio.sleep(1) + asyncio.create_task( + live_edit( + None, + _status_msg, + 'status', + timer=sub.timer, + ctx=ctx, + content=content, + reference=ctx.msg + ) + ) + + +@module.cmd("syncwith", + group="Timer Control", + short_help="Sync the timer with another group.", + flags=('end',)) +@has_timers() +async def cmd_syncwith(ctx, flags): """ Usage``: - syncwith + {prefix}syncwith [--end] Description: - Align the start of your group timer with the other group. - This will possibly change your stage without notification. - Arguments:: - group: The name of the group to sync with. - Related: - join, status, groups, set + Synchronise your current timer with the timer of the provided group. + This is usually done by *moving* the start of your current stage to the start of the target group's stage. + If the `-end` flag is added, instead moves the *end* of your current stage to match the end of the target stage. + + *If the `admin_locked` timer option is set, this command requires timer admin permissions.* + Examples``: + {prefix}syncwith {ctx.example_group_name} + {prefix}syncwith {ctx.example_group_name} --end """ - # Check an argumnet was given - if not ctx.arg_str: - return await ctx.error_reply("No group name provided!\n**Usage:** `syncwith `.") - - # Check the author is in a group - current_timer = ctx.client.interface.get_timer_for(ctx.guild.id, ctx.author.id) - if current_timer is None: - return await ctx.error_reply("You can only sync a group you are a member of!") - - # Get the target timer to sync with - sync_timer = await ctx.get_timers_matching(ctx.arg_str, channel_only=False) - if sync_timer is None: - return await ctx.error_reply("No groups matching `{}`!".format(ctx.arg_str)) - - # Check both timers are set up - if not sync_timer.stages or not current_timer.stages: - return await ctx.error_reply("Both the current and target timer must be set up first!") - - # Calculate the total duration from the start of the timer - target_duration = sum(stage.duration for i, stage in enumerate(sync_timer.stages) if i < sync_timer.current_stage) - target_duration *= 60 - target_duration += sync_timer.now() - sync_timer.current_stage_start - - # Calculate the target stage in the current timer - i = -1 - elapsed = 0 - while elapsed < target_duration: - i = (i + 1) % len(current_timer.stages) - elapsed += current_timer.stages[i].duration * 60 - - # Calculate new stage start - new_stage_start = sync_timer.now() - (current_timer.stages[i].duration * 60 - (elapsed - target_duration)) - - # Change the stage and adjust the time - await current_timer.change_stage(i, notify=False, inactivity_check=False, report_old=False) - current_timer.current_stage_start = new_stage_start - current_timer.remaining = elapsed - target_duration - - # Notify the user - await ctx.embedreply(current_timer.pretty_pinstatus(), title="Timers synced!") + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is None: + return await ctx.error_reply( + "You are not in a study group!" + ) + timer = sub.timer + + if timer.settings.admin_locked.value and not await is_timer_admin(ctx.author): + return await ctx.error_reply("This timer may only be synced by a timer admin.") + + if timer.state != TimerState.RUNNING: + return await ctx.error_reply( + "Timers may only be synced while they are running!" + ) + + if ctx.args: + target = await ctx.get_timers_matching(ctx.args, channel_only=False) + if target is None and ctx.args.isdigit(): + # Last-ditch check, accept roleids from foreign guilds + roleid = int(ctx.args) + timer_row = tables.timers.fetch(roleid) + if timer_row is not None and timer_row.guildid in ctx.timers.guild_channels: + target = next( + (t for t in ctx.timers.guild_channels[timer_row.guildid][timer_row.channelid].timers + if t.roleid == roleid), + None + ) + + if target is None: + return await ctx.error_reply("No target groups found matching `{}`!".format(ctx.args)) + else: + return await ctx.error_reply( + "**Usage:** `{}syncwith [--end]`\n" + "No target group provided!".format(ctx.best_prefix) + ) + + if len(timer.subscribers) > 1: + if not await ctx.ask("There are other people in your study group! " + "Are you sure you want to sync it with **{}**?".format(target.name)): + return + + if target.state != TimerState.RUNNING: + return await ctx.error_reply( + "Target timer isn't running! Use `{}restart` if you want to restart your timer.".format(ctx.best_prefix) + ) + + # Perform the actual sync + diff = target.stage_start - timer.stage_start + if flags['end']: + diff += (target.current_stage.duration - timer.current_stage.duration) * 60 + + timer.shift(diff) + + content = "Timer synced with **{}**!".format(target.name) + asyncio.create_task( + live_edit( + None, + _status_msg, + 'status', + timer=sub.timer, + ctx=ctx, + content=content, + reference=ctx.msg + ) + ) diff --git a/bot/commands/timer_admin.py b/bot/commands/timer_admin.py new file mode 100644 index 0000000..6c5e564 --- /dev/null +++ b/bot/commands/timer_admin.py @@ -0,0 +1,55 @@ +from utils import seekers, ctx_addons # noqa +from wards import has_timers, timer_admin + +from Timer import module + + +@module.cmd("forcekick", + group="Timer Control", + short_help="Kick a member from a study group.", + aliases=('kick',)) +@has_timers() +@timer_admin() +async def cmd_forcekick(ctx): + """ + Usage``: + {prefix}forcekick + Description: + Forcefully unsubscribe a group member. + + *Requires timer admin permissions.* + Examples``: + {prefix}forcekick {ctx.author.name} + """ + if not ctx.args: + return await ctx.error_reply( + "**Usage:** `{}forcekick `\n" + "Please provided a user to kick.".format(ctx.best_prefix) + ) + subscribers = [ + sub for timer in ctx.timers.get_timers_in(ctx.guild.id) for sub in timer.subscribers.values() + ] + members = [ + sub.member for sub in subscribers if sub.member + ] + if len(members) != len(subscribers): + # There are some subscribers without a member! First get them + for sub in subscribers: + if not sub.member: + await sub._fetch_member() + members = [ + sub.member for sub in subscribers if sub.member + ] + + member = await ctx.find_member(ctx.args, collection=members, silent=True) + if not member: + return await ctx.error_reply("No subscriber found matching `{}`!".format(ctx.args)) + + sub = ctx.timers.get_subscriber(member.id, member.guild.id) + if not sub: + return await ctx.error_reply("This member is no longer subscribed!") + + await sub.timer.unsubscribe(sub.userid, post=True) + + if ctx.ch != sub.timer.channel: + await ctx.embed_reply("{} was unsubscribed.".format(member.mention)) diff --git a/bot/commands/timer_config.py b/bot/commands/timer_config.py new file mode 100644 index 0000000..a339d9c --- /dev/null +++ b/bot/commands/timer_config.py @@ -0,0 +1,154 @@ +import discord + +from utils.lib import prop_tabulate +from settings import TimerSettings, UserInputError +from wards import has_timers + +from Timer import module + + +@module.cmd("timerconfig", + group="Group Admin", + short_help="Advanced timer configuration.", + aliases=("tconfig", "groupconfig")) +@has_timers() +async def cmd_tconfig(ctx): + """ + Usage: + `{prefix}tconfig help` (*See short descriptions of all the timer settings.*) + `{prefix}tconfig [timer name]` (*See the current settings for the given timer.*) + `{prefix}tconfig [timer name] ` (*See details about the given setting.*) + `{prefix}tconfig [timer name] ` (*Modify a setting in the given timer.*) + Description: + View or set advanced timer settings. + + The `timer name` argument is optional and you will be prompted to select a timer if it is not provided. \ + However, **if the timer name contains a space it must be given in quotes**. + Partial timer names are also supported. + + *Modifying timer settings requires at least timer admin permissions.* + Examples``: + {prefix}tconfig help + {prefix}tconfig "{ctx.example_group_name}" + {prefix}tconfig "{ctx.example_group_name}" default_pattern + {prefix}tconfig "{ctx.example_group_name}" default_pattern 50/10 + """ + # Cache and map setting info + timers = ctx.timers.get_timers_in(ctx.guild.id) + timer_names = (timer.name.lower() for timer in timers) + setting_displaynames = {setting.display_name.lower(): setting for setting in TimerSettings.settings.values()} + args = ctx.args + + cats = {} # Timer setting categories + for setting in TimerSettings.settings.values(): + if setting.category not in cats: + cats[setting.category] = {} + cats[setting.category][setting.display_name] = setting + + # Parse + timer = None + setting = None + value = None + if args.lower() == 'help': + # No parsing to do + # Signified by empty timer value + pass + elif args: + splits = args[1:].split('"', maxsplit=1) if args.startswith('"') else args.split(maxsplit=1) + maybe_name = splits[0] + if maybe_name.lower() in setting_displaynames and maybe_name.lower() not in timer_names: + # Assume the provided name is a setting name + setting = setting_displaynames[maybe_name.lower()] + value = splits[1] if len(splits) > 1 else None + + # Retrieve the timer from context, or prompt + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is not None: + timer = sub.timer + elif len(timers) == 1: + timer = timers[0] + else: + timer = await ctx.get_timers_matching( + '', channel_only=False, info=True, + header="Please select a group to configure." + ) + else: + timer = await ctx.get_timers_matching(maybe_name, channel_only=False, info=True) + if not timer: + return await ctx.error_reply("No groups found matching `{}`.".format(maybe_name)) + if len(splits) > 1 and splits[1]: + remaining_splits = splits[1].split(maxsplit=1) + setting = setting_displaynames.get(remaining_splits[0].lower(), None) + if setting is None: + return await ctx.error_reply( + "`{}`is not a timer setting!\n" + "Use `{}tconfig \"{}\"` to see the available settings.".format( + remaining_splits[1], ctx.best_prefix, timer.name + ) + ) + if len(remaining_splits) > 1: + value = remaining_splits[1] + else: + # Retrieve the timer from context, or prompt + sub = ctx.timers.get_subscriber(ctx.author.id, ctx.guild.id) + if sub is not None: + timer = sub.timer + elif len(timers) == 1: + timer = timers[0] + else: + timer = await ctx.get_timers_matching( + '', channel_only=False, info=True, + header="Please select a group to view." + ) + + # Handle different modes + if timer is None or setting is None: + # Display timer configuration or descriptions + fields = ( + (cat, prop_tabulate(*zip(*( + (setting.display_name, setting.get(timer.roleid).formatted if timer is not None else setting.desc) + for name, setting in cat_settings.items() + )))) + for cat, cat_settings in cats.items() + ) + if timer: + embed = discord.Embed( + title="Timer configuration for `{}`".format(timer.name), + description=( + "**Tip:** See `{0}help tconfig` for command usage and examples, " + "and `{0}tconfig help` to see short descriptions of each setting.".format(ctx.best_prefix) + ) + ) + else: + embed = discord.Embed( + title="Timer configuration options", + description=( + "**Tip:** See `{}help tconfig` for command usage and examples.".format(ctx.best_prefix) + ) + ) + embed.set_footer( + text="Use \"{}tconfig timer setting [value]\" to see or modify a setting.".format( + ctx.best_prefix + ) + ) + for i, (name, value) in enumerate(fields): + embed.add_field(name=name, value=value, inline=(bool(timer) and bool((i + 1) % 3))) + await ctx.reply(embed=embed) + elif value is None: + # Display setting information for the given timer and value + await ctx.reply(embed=setting.get(timer.roleid).embed) + else: + # Check the write ward + if not await setting.write_ward.run(ctx): + return await ctx.error_reply(setting.write_ward.msg) + + # Write the setting value + try: + (await setting.parse(timer.roleid, ctx, value)).write() + except UserInputError as e: + await ctx.reply(embed=discord.Embed( + description="{} {}".format('❌', e.msg), + color=discord.Colour.red() + )) + else: + await ctx.reply(embed=discord.Embed(description="{} Setting updated!".format('✅'))) diff --git a/bot/commands/user_config.py b/bot/commands/user_config.py new file mode 100644 index 0000000..aabc145 --- /dev/null +++ b/bot/commands/user_config.py @@ -0,0 +1,44 @@ +from settings import UserSettings + +from Timer import module + + +@module.cmd( + "mytimezone", + group="Personal Settings", + short_help=("Timezone for displaying session data. " + "(Currently {ctx.author_settings.timezone.formatted})"), + aliases=('mytz',) +) +async def cmd_mytimezone(ctx): + """ + Usage``: + {prefix}mytimezone + {prefix}mytimezone + Setting Description: + {ctx.author_settings.settings.timezone.long_desc} + Accepted Values: + Timezone names must be from the "TZ Database Name" column of \ + [this list](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). + For example, `Europe/London`, `Australia/Melbourne`, or `America/New_York`. + """ + await UserSettings.settings.timezone.command(ctx, ctx.author.id) + + +@module.cmd( + "notify", + group="Personal Settings", + short_help=("DM notification level. " + "(Currently {ctx.author_settings.notify_level.formatted})") +) +async def cmd_notify(ctx): + """ + Usage``: + {prefix}notify + {prefix}notify + Setting Description: + {ctx.author_settings.settings.notify_level.long_desc} + Accepted Values: + {ctx.author_settings.settings.notify_level.accepted_table} + """ + await UserSettings.settings.notify_level.command(ctx, ctx.author.id) diff --git a/bot/data/__init__.py b/bot/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bot/data/data.py b/bot/data/data.py new file mode 100644 index 0000000..f29b9e7 --- /dev/null +++ b/bot/data/data.py @@ -0,0 +1,481 @@ +import os +import logging +import contextlib +from datetime import datetime +from itertools import chain +from enum import Enum + +import sqlite3 as sq +from cachetools import LRUCache +from meta import log + + +# Database constants +DB_PATH = 'data/data.db' +SCHEMA_PATH = 'data/schema.sql' +REQUIRED_VERSION = 1 + + +# Set up database connection +requires_init = not os.path.exists(DB_PATH) + +log("Establishing connection.", "DB_INIT", level=logging.DEBUG) +conn = sq.connect(DB_PATH, timeout=20, isolation_level=None) +conn.row_factory = sq.Row +conn.set_trace_callback(lambda message: log(message, context="DB_CONNECTOR", level=logging.DEBUG)) +sq.register_adapter(datetime, lambda dt: dt.timestamp()) + + +# Initialise the database if it was just created +if requires_init: + log("Running first-time setup.", "DB_INIT") + # Execute schema file + with conn: + with open(SCHEMA_PATH, 'r') as script: + conn.executescript(script.read()) + +# Check the version matches the required version +with conn: + log("Checking db version.", "DB_INIT") + cursor = conn.cursor() + + # Check if table exists + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='VersionHistory'") + version_exists = cursor.fetchone() + if not version_exists: + # Create version table and insert version 0 + cursor.execute('CREATE TABLE VersionHistory (version INTEGER NOT NULL, time INTEGER NOT NULL)') + now = datetime.timestamp(datetime.utcnow()) + cursor.execute('INSERT INTO VersionHistory VALUES (0, {})'.format(datetime.timestamp(datetime.utcnow()))) + + # Get last entry in version table, compare against desired version + cursor.execute("SELECT * FROM VersionHistory ORDER BY rowid DESC LIMIT 1") + current_version, _, _ = cursor.fetchone() + + if current_version != REQUIRED_VERSION: + # Complain + raise Exception( + ("Database version is {}, required version is {}. " + "Please migrate database.").format(current_version, REQUIRED_VERSION) + ) + + cursor.close() + + +log("Established connection.", "DB_INIT") + + +# --------------- Data Interface Classes --------------- +class Table: + """ + Transparent interface to a single table structure in the database. + Contains standard methods to access the table. + Intended to be subclassed to provide more derivative access for specific tables. + """ + conn = conn + + def __init__(self, name): + self.name = name + + def select_where(self, *args, **kwargs): + with self.conn: + return select_where(self.name, *args, **kwargs) + + def select_one_where(self, *args, **kwargs): + with self.conn: + rows = self.select_where(*args, **kwargs) + return rows[0] if rows else None + + def update_where(self, *args, **kwargs): + with self.conn: + return update_where(self.name, *args, **kwargs) + + def delete_where(self, *args, **kwargs): + with self.conn: + return delete_where(self.name, *args, **kwargs) + + def insert(self, *args, **kwargs): + with self.conn: + return insert(self.name, *args, **kwargs) + + def insert_many(self, *args, **kwargs): + with self.conn: + return insert_many(self.name, *args, **kwargs) + + def upsert(self, *args, **kwargs): + with self.conn: + return upsert(self.name, *args, **kwargs) + + +class Row: + __slots__ = ('table', 'data', '_pending') + + conn = conn + + def __init__(self, table, data, *args, **kwargs): + super().__setattr__('table', table) + self.data = data + self._pending = None + + @property + def rowid(self): + return self.data[self.table.id_col] + + def __repr__(self): + return "Row[{}]({})".format( + self.table.name, + ', '.join("{}={!r}".format(field, getattr(self, field)) for field in self.table.columns) + ) + + def __getattr__(self, key): + if key in self.table.columns: + if self._pending and key in self._pending: + return self._pending[key] + else: + return self.data[key] + else: + raise AttributeError(key) + + def __setattr__(self, key, value): + if key in self.table.columns: + if self._pending is None: + self.update(**{key: value}) + else: + self._pending[key] = value + else: + super().__setattr__(key, value) + + @contextlib.contextmanager + def batch_update(self): + if self._pending: + raise ValueError("Nested batch updates for {}!".format(self.__class__.__name__)) + + self._pending = {} + try: + yield self._pending + finally: + self.update(**self._pending) + self._pending = None + + def _refresh(self): + row = self.table.select_one_where(**{self.table.id_col: self.rowid}) + if not row: + raise ValueError("Refreshing a {} which no longer exists!".format(type(self).__name__)) + self.data = row + + def update(self, **values): + rows = self.table.update_where(values, **{self.table.id_col: self.rowid}) + self.data = rows[0] + + @classmethod + def _select_where(cls, _extra=None, **conditions): + return select_where(cls._table, **conditions) + + @classmethod + def _insert(cls, **values): + return insert(cls._table, **values) + + @classmethod + def _update_where(cls, values, **conditions): + return update_where(cls._table, values, **conditions) + + +class RowTable(Table): + __slots__ = ( + 'name', + 'columns', + 'id_col', + 'row_cache' + ) + + conn = conn + + def __init__(self, name, columns, id_col, use_cache=True, cache=None, cache_size=1000): + self.name = name + self.columns = columns + self.id_col = id_col + self.row_cache = (cache or LRUCache(cache_size)) if use_cache else None + + # Extend original Table update methods to modify the cached rows + def update_where(self, *args, **kwargs): + data = super().update_where(*args, **kwargs) + if self.row_cache is not None: + for data_row in data: + cached_row = self.row_cache.get(data_row[self.id_col], None) + if cached_row is not None: + cached_row.data = data_row + return data + + def delete_where(self, *args, **kwargs): + data = super().delete_where(*args, **kwargs) + if self.row_cache is not None: + for data_row in data: + self.row_cache.pop(data_row[self.id_col], None) + return data + + def upsert(self, *args, **kwargs): + data = super().upsert(*args, **kwargs) + if self.row_cache is not None: + cached_row = self.row_cache.get(data[self.id_col], None) + if cached_row is not None: + cached_row.data = data + return data + + # New methods to fetch and create rows + def _make_rows(self, *data_rows): + """ + Create or retrieve Row objects for each provided data row. + If the rows already exist in cache, updates the cached row. + """ + if self.row_cache is not None: + rows = [] + for data_row in data_rows: + rowid = data_row[self.id_col] + + cached_row = self.row_cache.get(rowid, None) + if cached_row is not None: + cached_row.data = data_row + row = cached_row + else: + row = Row(self, data_row) + self.row_cache[rowid] = row + rows.append(row) + else: + rows = [Row(self, data_row) for data_row in data_rows] + return rows + + def create_row(self, *args, **kwargs): + data = self.insert(*args, **kwargs) + return self._make_rows(data)[0] + + def fetch_rows_where(self, *args, **kwargs): + # TODO: Handle list of rowids here? + data = self.select_where(*args, **kwargs) + return self._make_rows(*data) + + def fetch(self, rowid): + """ + Fetch the row with the given id, retrieving from cache where possible. + """ + row = self.row_cache.get(rowid, None) if self.row_cache is not None else None + if row is None: + rows = self.fetch_rows_where(**{self.id_col: rowid}) + row = rows[0] if rows else None + return row + + def fetch_or_create(self, rowid=None, **kwargs): + """ + Helper method to fetch a row with the given id or fields, or create it if it doesn't exist. + """ + if rowid is not None: + row = self.fetch(rowid) + else: + data = self.select_where(**kwargs) + row = self._make_rows(data[0])[0] if data else None + + if row is None: + creation_kwargs = kwargs + if rowid is not None: + creation_kwargs[self.id_col] = rowid + row = self.create_row(**creation_kwargs) + return row + + +# --------------- Query Builders --------------- +def select_where(table, select_columns=None, cursor=None, _extra='', **conditions): + """ + Select rows from the given table matching the conditions + """ + criteria, criteria_values = _format_conditions(conditions) + col_str = _format_selectkeys(select_columns) + + if conditions: + where_str = "WHERE {}".format(criteria) + else: + where_str = "" + + cursor = cursor or conn.cursor() + cursor.execute( + 'SELECT {} FROM {} {} {}'.format(col_str, table, where_str, _extra), + criteria_values + ) + return cursor.fetchall() + + +def update_where(table, valuedict, cursor=None, **conditions): + """ + Update rows in the given table matching the conditions + """ + key_str, key_values = _format_updatestr(valuedict) + criteria, criteria_values = _format_conditions(conditions) + + if conditions: + where_str = "WHERE {}".format(criteria) + else: + where_str = "" + + cursor = cursor or conn.cursor() + cursor.execute( + 'UPDATE {} SET {} {} RETURNING *'.format(table, key_str, where_str), + tuple((*key_values, *criteria_values)) + ) + conn.commit() + return cursor.fetchall() + + +def delete_where(table, cursor=None, **conditions): + """ + Delete rows in the given table matching the conditions + """ + criteria, criteria_values = _format_conditions(conditions) + + cursor = cursor or conn.cursor() + cursor.execute( + 'DELETE FROM {} WHERE {}'.format(table, criteria), + criteria_values + ) + conn.commit() + return cursor.fetchall() + + +def insert(table, cursor=None, allow_replace=False, **values): + """ + Insert the given values into the table + """ + keys, values = zip(*values.items()) + + key_str = _format_insertkeys(keys) + value_str, values = _format_insertvalues(values) + + action = 'REPLACE' if allow_replace else 'INSERT' + + # log(str(values)) + + cursor = cursor or conn.cursor() + cursor.execute( + '{} INTO {} {} VALUES {} RETURNING *'.format(action, table, key_str, value_str), + values + ) + conn.commit() + return cursor.fetchone() + + +def insert_many(table, *value_tuples, insert_keys=None, cursor=None): + """ + Insert all the given values into the table + """ + key_str = _format_insertkeys(insert_keys) + value_strs, value_tuples = zip(*(_format_insertvalues(value_tuple) for value_tuple in value_tuples)) + + value_str = ", ".join(value_strs) + values = tuple(chain(*value_tuples)) + + cursor = cursor or conn.cursor() + cursor.execute( + 'INSERT INTO {} {} VALUES {} RETURNING *'.format(table, key_str, value_str), + values + ) + conn.commit() + return cursor.fetchall() + + +def upsert(table, constraint, cursor=None, **values): + """ + Insert or on conflict update. + """ + valuedict = values + keys, values = zip(*values.items()) + + key_str = _format_insertkeys(keys) + value_str, values = _format_insertvalues(values) + update_key_str, update_key_values = _format_updatestr(valuedict) + + if not isinstance(constraint, str): + constraint = ", ".join(constraint) + + cursor = cursor or conn.cursor() + cursor.execute( + 'INSERT INTO {} {} VALUES {} ON CONFLICT({}) DO UPDATE SET {} RETURNING *'.format( + table, key_str, value_str, constraint, update_key_str + ), + tuple((*values, *update_key_values)) + ) + conn.commit() + return cursor.fetchone() + + +# --------------- Query Formatting Tools --------------- +# Replace char used by the connection for query formatting +_replace_char: str = '?' + + +class fieldConstants(Enum): + """ + A collection of database field constants to use for selection conditions. + """ + NULL = "IS NULL" + NOTNULL = "IS NOT NULL" + + +def _format_conditions(conditions): + """ + Formats a dictionary of conditions into a string suitable for 'WHERE' clauses. + Supports `IN` type conditionals. + """ + if not conditions: + return ("", tuple()) + + values = [] + conditional_strings = [] + for key, item in conditions.items(): + if isinstance(item, (list, tuple)): + conditional_strings.append("{} IN ({})".format(key, ", ".join([_replace_char] * len(item)))) + values.extend(item) + elif isinstance(item, fieldConstants): + conditional_strings.append("{} {}".format(key, item.value)) + else: + conditional_strings.append("{}={}".format(key, _replace_char)) + values.append(item) + + return (' AND '.join(conditional_strings), values) + + +def _format_selectkeys(keys): + """ + Formats a list of keys into a string suitable for `SELECT`. + """ + if not keys: + return "*" + else: + return ", ".join(keys) + + +def _format_insertkeys(keys): + """ + Formats a list of keys into a string suitable for `INSERT` + """ + if not keys: + return "" + else: + return "({})".format(", ".join(keys)) + + +def _format_insertvalues(values): + """ + Formats a list of values into a string suitable for `INSERT` + """ + value_str = "({})".format(", ".join(_replace_char for value in values)) + return (value_str, values) + + +def _format_updatestr(valuedict): + """ + Formats a dictionary of keys and values into a string suitable for 'SET' clauses. + """ + if not valuedict: + return ("", tuple()) + keys, values = zip(*valuedict.items()) + + set_str = ", ".join("{} = {}".format(key, _replace_char) for key in keys) + + return (set_str, values) diff --git a/bot/data/queries.py b/bot/data/queries.py new file mode 100644 index 0000000..1c27678 --- /dev/null +++ b/bot/data/queries.py @@ -0,0 +1,18 @@ +""" +Collection of stored data queries and procedures. +""" + +from . import tables +# from .data import _format_conditions + + +def get_session_user_totals(start_ts, **kwargs): + sum_column = ( + "SUM(IIF(start_time < {start_ts}, duration - ({start_ts} - start_time), duration)) AS total" + ).format(start_ts=start_ts) if start_ts else "SUM(duration) AS total" + + return tables.session_patterns.select_where( + select_columns=('userid', 'name', sum_column), + _extra='AND start_time + duration > {} GROUP BY userid'.format(start_ts), + **kwargs + ) diff --git a/bot/data/tables.py b/bot/data/tables.py new file mode 100644 index 0000000..99c3585 --- /dev/null +++ b/bot/data/tables.py @@ -0,0 +1,46 @@ +from .data import RowTable, Table + +from cachetools import LFUCache +from weakref import WeakValueDictionary + + +guilds = RowTable( + 'guilds', + ('guildid', 'timer_admin_roleid', 'show_tips', + 'autoclean', 'timezone', 'prefix', 'globalgroups', 'studyrole_roleid'), + 'guildid', + cache_size=2500 +) + +users = RowTable( + 'users', + ('userid', 'notify_level', 'timezone', 'name'), + 'userid', + cache_size=2000 +) + +patterns = RowTable( + 'patterns', + ('patternid', 'short_repr', 'stage_str', 'created_at'), + 'patternid', + cache=LFUCache(1000) +) + +timers = RowTable( + 'timers', + ('roleid', 'guildid', 'name', 'channelid', 'patternid', + 'voice_channelid', 'voice_alert', 'track_voice_join', 'track_voice_leave', + 'auto_reset', 'admin_locked', 'track_role', 'compact', + 'voice_channel_name', + ), + # 'default_work_name', 'default_work_message', + # 'default_break_name', 'default_break_message'), + 'roleid', + cache=WeakValueDictionary() +) +sessions = Table('sessions') +session_patterns = Table('session_patterns') +timer_pattern_history = Table('timer_pattern_history') + +user_presets = Table('user_presets') +guild_presets = Table('guild_presets') diff --git a/bot/dev-main.py b/bot/dev-main.py new file mode 100644 index 0000000..29361dd --- /dev/null +++ b/bot/dev-main.py @@ -0,0 +1,7 @@ +import logging +import meta + +meta.logger.logger.setLevel(logging.DEBUG) +logging.getLogger("discord").setLevel(logging.INFO) + +import main # noqa diff --git a/bot/main.py b/bot/main.py index 66488d3..4444b02 100644 --- a/bot/main.py +++ b/bot/main.py @@ -1,31 +1,34 @@ import os -from config import conf -from logger import log -from cmdClient.cmdClient import cmdClient +from data import tables, data # noqa +from meta import client, conf + +import Timer # noqa -from BotData import BotData -from Timer import TimerInterface # Get the real location __location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) -# Load required data from configs -masters = [int(master.strip()) for master in conf['masters'].split(",")] -config = BotData(app="pomo", data_file="data/config_data.db", version=0) - -# Initialise the client -client = cmdClient(prefix=conf['prefix'], owners=masters) -client.config = config -client.log = log - # Load the commands client.load_dir(os.path.join(__location__, 'commands')) +client.load_dir(os.path.join(__location__, 'plugins')) +# TODO: Recursive plugin loader # Initialise the timer -TimerInterface(client, conf['session_store']) +# TimerInterface(client, conf['session_store']) +client.initialise_modules() + + +@client.set_valid_prefixes +async def valid_prefixes(client, message): + return ( + (tables.guilds.fetch_or_create(message.guild.id).prefix if message.guild else None) or client.prefix, + '<@{}>'.format(client.user.id), + '<@!{}>'.format(client.user.id), + ) + # Log and execute! -log("Initial setup complete, logging in", context='SETUP') +client.log("Initial setup complete, logging in", context='SETUP') client.run(conf['TOKEN']) diff --git a/bot/meta/__init__.py b/bot/meta/__init__.py new file mode 100644 index 0000000..b38ada2 --- /dev/null +++ b/bot/meta/__init__.py @@ -0,0 +1,3 @@ +from .client import client +from .config import conf +from .logger import log diff --git a/bot/meta/client.py b/bot/meta/client.py new file mode 100644 index 0000000..3a5d334 --- /dev/null +++ b/bot/meta/client.py @@ -0,0 +1,10 @@ +from cmdClient.cmdClient import cmdClient + +from .config import conf +from .logger import log + + +# Initialise client +masters = [int(master.strip()) for master in conf['masters'].split(",")] +client = cmdClient(prefix=conf['prefix'], owners=masters) +client.log = log diff --git a/bot/config.py b/bot/meta/config.py similarity index 100% rename from bot/config.py rename to bot/meta/config.py diff --git a/bot/logger.py b/bot/meta/logger.py similarity index 65% rename from bot/logger.py rename to bot/meta/logger.py index 45937a2..129f850 100644 --- a/bot/logger.py +++ b/bot/meta/logger.py @@ -1,7 +1,8 @@ import sys +import traceback import logging -from config import conf +from .config import conf # Setup the logger @@ -17,6 +18,8 @@ logger.setLevel(logging.INFO) -def log(message, context="Global".center(18, '='), level=logging.INFO): +def log(message, context="Global".center(22, '='), level=logging.INFO, add_exc_info=False): + if add_exc_info: + message += '\n{}'.format(traceback.format_exc()) for line in message.split('\n'): - logger.log(level, '[{}] {}'.format(str(context).center(18, '='), line)) + logger.log(level, '[{}] {}'.format(str(context).center(22, '='), line)) diff --git a/bot/commands/exec.py b/bot/plugins/exec-cmds.py similarity index 97% rename from bot/commands/exec.py rename to bot/plugins/exec-cmds.py index 69e28c5..d3af7f2 100644 --- a/bot/commands/exec.py +++ b/bot/plugins/exec-cmds.py @@ -28,8 +28,9 @@ async def cmd_reboot(ctx): Update the timer status save file and reboot the client. """ ctx.client.interface.update_save("reboot") + ctx.client.interface.shutdown() await ctx.reply("Saved state. Rebooting now!") - await ctx.client.logout() + await ctx.client.close() @cmd("async") diff --git a/bot/settings/__init__.py b/bot/settings/__init__.py new file mode 100644 index 0000000..05956a5 --- /dev/null +++ b/bot/settings/__init__.py @@ -0,0 +1,5 @@ +from .base import Setting, ObjectSettings, UserInputError + +from .guild_settings import GuildSettings, GuildSetting +from .timer_settings import TimerSettings, TimerSetting +from .user_settings import UserSettings, UserSetting diff --git a/bot/settings/base.py b/bot/settings/base.py new file mode 100644 index 0000000..1e6c42f --- /dev/null +++ b/bot/settings/base.py @@ -0,0 +1,302 @@ +import discord +from cmdClient.cmdClient import cmdClient, Context +from cmdClient.lib import SafeCancellation +from cmdClient.Check import Check + +from utils.lib import prop_tabulate, DotDict + +from meta import client +from data.data import Table, RowTable + + +class Setting: + """ + Abstract base class describing a stored configuration setting. + A setting consists of logic to load the setting from storage, + present it in a readable form, understand user entered values, + and write it again in storage. + Additionally, the setting has attributes attached describing + the setting in a user-friendly manner for display purposes. + """ + attr_name: str = None # Internal attribute name for the setting + _default: ... = None # Default data value for the setting.. this may be None if the setting overrides 'default'. + + write_ward: Check = None # Check that must be passed to write the setting. Not implemented internally. + + # Configuration interface descriptions + display_name: str = None # User readable name of the setting + desc: str = None # User readable brief description of the setting + long_desc: str = None # User readable long description of the setting + accepts: str = None # User readable description of the acceptable values + + def __init__(self, id, data: ..., **kwargs): + self.client: cmdClient = client + self.id = id + self._data = data + + # Configuration embeds + @property + def embed(self): + """ + Discord Embed showing an information summary about the setting. + """ + embed = discord.Embed( + title="Configuration options for `{}`".format(self.display_name), + ) + fields = ("Current value", "Default value", "Accepted input") + values = (self.formatted or "Not Set", + self._format_data(self.id, self.default) or "None", + self.accepts) + table = prop_tabulate(fields, values) + embed.description = "{}\n{}".format(self.long_desc.format(self=self, client=self.client), table) + return embed + + @property + def success_response(self): + """ + Response message sent when the setting has successfully been updated. + """ + return "Setting Updated!" + + # Instance generation + @classmethod + def get(cls, id: int, **kwargs): + """ + Return a setting instance initialised from the stored value. + """ + data = cls._reader(id, **kwargs) + return cls(id, data, **kwargs) + + @classmethod + async def parse(cls, id: int, ctx: Context, userstr: str, **kwargs): + """ + Return a setting instance initialised from a parsed user string. + """ + data = await cls._parse_userstr(ctx, id, userstr, **kwargs) + return cls(id, data, **kwargs) + + # Main interface + @property + def data(self): + """ + Retrieves the current internal setting data if it is set, otherwise the default data + """ + return self._data if self._data is not None else self.default + + @data.setter + def data(self, new_data): + """ + Sets the internal setting data and writes the changes. + """ + self._data = new_data + self.write() + + @property + def default(self): + """ + Retrieves the default value for this setting. + Settings should override this if the default depends on the object id. + """ + return self._default + + @property + def value(self): + """ + Discord-aware object or objects associated with the setting. + """ + return self._data_to_value(self.id, self.data) + + @value.setter + def value(self, new_value): + """ + Setter which reads the discord-aware object, converts it to data, and writes it. + """ + self._data = self._data_from_value(self.id, new_value) + self.write() + + @property + def formatted(self): + """ + User-readable form of the setting. + """ + return self._format_data(self.id, self.data) + + def write(self, **kwargs): + """ + Write value to the database. + For settings which override this, + ensure you handle deletion of values when internal data is None. + """ + self._writer(self.id, self._data, **kwargs) + + # Raw converters + @classmethod + def _data_from_value(cls, id: int, value, **kwargs): + """ + Convert a high-level setting value to internal data. + Must be overriden by the setting. + Be aware of None values, these should always pass through as None + to provide an unsetting interface. + """ + raise NotImplementedError + + @classmethod + def _data_to_value(cls, id: int, data: ..., **kwargs): + """ + Convert internal data to high-level setting value. + Must be overriden by the setting. + """ + raise NotImplementedError + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Parse user provided input into internal data. + Must be overriden by the setting if the setting is user-configurable. + """ + raise NotImplementedError + + @classmethod + def _format_data(cls, id: int, data: ..., **kwargs): + """ + Convert internal data into a formatted user-readable string. + Must be overriden by the setting if the setting is user-viewable. + """ + raise NotImplementedError + + # Database access classmethods + @classmethod + def _reader(cls, id: int, **kwargs): + """ + Read a setting from storage and return setting data or None. + Must be overriden by the setting. + """ + raise NotImplementedError + + @classmethod + def _writer(cls, id: int, data: ..., **kwargs): + """ + Write provided setting data to storage. + Must be overriden by the setting unless the `write` method is overidden. + If the data is None, the setting is empty and should be unset. + """ + raise NotImplementedError + + @classmethod + async def command(cls, ctx, id): + """ + Standardised command viewing/setting interface for the setting. + """ + if not ctx.args: + # View config embed for provided cls + await ctx.reply(embed=cls.get(id).embed) + else: + # Check the write ward + if cls.write_ward and not await cls.write_ward.run(ctx): + await ctx.error_reply(cls.write_ward.msg) + else: + # Attempt to set config cls + try: + cls = await cls.parse(id, ctx, ctx.args) + except UserInputError as e: + await ctx.reply(embed=discord.Embed( + description="{} {}".format('❌', e.msg), + Colour=discord.Colour.red() + )) + else: + cls.write() + await ctx.reply(embed=discord.Embed( + description="{} {}".format('✅', cls.success_response), + Colour=discord.Colour.green() + )) + + +class ObjectSettings: + """ + Abstract class representing a linked collection of settings for a single object. + Initialised settings are provided as instance attributes in the form of properties. + """ + __slots__ = ('id', 'params') + + settings: DotDict = None + + def __init__(self, id, **kwargs): + self.id = id + self.params = tuple(kwargs.items()) + + @classmethod + def _setting_property(cls, setting): + def wrapped_setting(self): + return setting.get(self.id, **dict(self.params)) + return wrapped_setting + + @classmethod + def attach_setting(cls, setting: Setting): + name = setting.attr_name or setting.__name__ + setattr(cls, name, property(cls._setting_property(setting))) + cls.settings[name] = setting + + +class ColumnData: + """ + Mixin for settings stored in a single row and column of a Table. + Intended to be used with tables where the only primary key is the object id. + """ + # Table storing the desired data + _table_interface: Table = None + + # Name of the column storing the setting object id + _id_column: str = None + + # Name of the column with the desired data + _data_column: str = None + + # Whether to upsert or update for updates + _upsert: bool = True + + @classmethod + def _reader(cls, id: int, **kwargs): + """ + Read in the requested entry associated to the id. + Supports reading cached values from a `RowTable`. + """ + table = cls._table_interface + if isinstance(table, RowTable) and cls._id_column == table.id_col: + row = table.fetch(id) + return row.data[cls._data_column] if row else None + else: + params = { + "select_columns": (cls._data_column,), + cls._id_column: id + } + row = table.select_one_where(**params) + return row[cls._data_column] if row else None + + @classmethod + def _writer(cls, id: int, data: ..., **kwargs): + """ + Write the provided entry to the table, allowing replacements. + """ + table = cls._table_interface + params = { + cls._id_column: id + } + values = { + cls._data_column: data + } + + # Update data + if cls._upsert: + # Upsert data + table.upsert( + constraint=cls._id_column, + **params, + **values + ) + else: + # Update data + table.update_where(values, **params) + + +class UserInputError(SafeCancellation): + pass diff --git a/bot/settings/guild_settings.py b/bot/settings/guild_settings.py new file mode 100644 index 0000000..b097167 --- /dev/null +++ b/bot/settings/guild_settings.py @@ -0,0 +1,163 @@ +import datetime + +from data import tables +from wards import timer_admin, guild_admin +from meta import client +from utils.lib import DotDict + +from .base import Setting, ObjectSettings, ColumnData, UserInputError +from .setting_types import Role, Boolean, String, Timezone + + +class GuildSettings(ObjectSettings): + settings = DotDict() + + +class GuildSetting(ColumnData, Setting): + _table_interface = tables.guilds + _id_column = 'guildid' + + write_ward = timer_admin + + +@GuildSettings.attach_setting +class timer_admin_role(Role, GuildSetting): + attr_name = 'timer_admin_role' + _data_column = 'timer_admin_roleid' + + write_ward = guild_admin + + display_name = 'timer_admin_role' + desc = 'Role required to create and configure timers.' + long_desc = ( + "Role required to create and configure timers.\n" + "Having this role allows members to use most configuration commands " + "such as those under `Group Admin`, `Server Configuration`, " + "and `Registry Admin`.\n" + "Having the administrator server permission also allows use of these commands, " + "and some commands, such as `timeradmin` itself, require this permission instead.\n" + "(Required permissions for commands are listed in their `help` pages.)" + ) + + @property + def success_response(self): + return "The timer admin role is now {}.".format(self.formatted) + + +@GuildSettings.attach_setting +class show_tips(Boolean, GuildSetting): + attr_name = 'show_tips' + _data_column = 'show_tips' + + _default = True + + display_name = 'display_tips' + desc = 'Display usage tips for setting up and using the bot.' + long_desc = ( + "Display usage tips and hints on the output of various commands." + ) + + @property + def success_response(self): + return "Usage tips are now {}.".format("Enabled" if self.value else "Disabled") + + +@GuildSettings.attach_setting +class globalgroups(Boolean, GuildSetting): + attr_name = 'globalgroups' + _data_column = 'globalgroups' + + _default = False + + display_name = 'globalgroups' + desc = 'Whether timers may be joined from any channel.' + long_desc = ( + "By default, groups may only be joined from the text channel they are bound to. " + "This setting allows members to join study groups from any channel." + ) + + @property + def success_response(self): + if self.value: + return "Groups may now be joined from any channel." + else: + return "Groups may now only be joined from the text channel they are bound to." + + +@GuildSettings.attach_setting +class prefix(String, GuildSetting): + attr_name = 'prefix' + _data_column = 'prefix' + + _default = client.prefix + + write_ward = guild_admin + + display_name = 'prefix' + desc = 'The bot command prefix.' + long_desc = ( + "The command prefix required to run any command.\n" + "My mention will also always function as a prefix." + ) + + @property + def success_response(self): + return "The command prefix is now `{0}`. (E.g. `{0}help`.)".format(self.value) + + +@GuildSettings.attach_setting +class timezone(Timezone, GuildSetting): + attr_name = 'timezone' + _data_column = 'timezone' + + _default = 'UTC' + + write_ward = guild_admin + + display_name = 'timezone' + desc = 'The server leaderboard timezone.' + long_desc = ( + "The leaderboard timezone.\n" + "The current day/week/month/year displayed on the leaderboard will be calculated using this timezone." + ) + + @property + def success_response(self): + return ( + "The leaderboard timezone is now {}. " + "The current time is **{}**." + ).format(self.formatted, datetime.datetime.now(tz=self.value).strftime("%H:%M")) + + +@GuildSettings.attach_setting +class studyrole(Role, GuildSetting): + attr_name = 'studyrole' + _data_column = 'studyrole_roleid' + + write_ward = guild_admin + + display_name = 'studyrole' + desc = 'Common study role given to all timer members.' + long_desc = ( + "This role will be given to members when they join any group, " + "and removed when they leave the group, acting as a global study role.\n" + "The purpose is to facilitate easier simpler study permission management, " + "for example to control what channels studying members see." + ) + + @classmethod + async def _parse_userstr(cls, ctx, id: int, userstr: str, **kwargs): + roleid = await super()._parse_userstr(ctx, id, userstr, **kwargs) + if roleid: + role = ctx.guild.get_role(roleid) + # Check permissions + if role >= ctx.guild.me.top_role: + raise UserInputError("The study role must be lower than my top role!") + return roleid + + @property + def success_response(self): + if self.data: + return "The global study role is now {}.".format(self.value.mention) + else: + return "The global study role has been removed." diff --git a/bot/settings/setting_types.py b/bot/settings/setting_types.py new file mode 100644 index 0000000..43d65b2 --- /dev/null +++ b/bot/settings/setting_types.py @@ -0,0 +1,632 @@ +from enum import IntEnum +from typing import Any, Optional + +import pytz +import discord +from cmdClient.Context import Context +from cmdClient.lib import SafeCancellation + +from meta import client +import Timer + +from .base import UserInputError + + +class SettingType: + """ + Abstract class representing a setting type. + Intended to be used as a mixin for a Setting, + with the provided methods implementing converter methods for the setting. + """ + accepts: str = None # User readable description of the acceptable values + + # Raw converters + @classmethod + def _data_from_value(cls, id: int, value, **kwargs): + """ + Convert a high-level setting value to internal data. + """ + raise NotImplementedError + + @classmethod + def _data_to_value(cls, id: int, data: Any, **kwargs): + """ + Convert internal data to high-level setting value. + """ + raise NotImplementedError + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Parse user provided input into internal data. + """ + raise NotImplementedError + + @classmethod + def _format_data(cls, id: int, data: Any, **kwargs): + """ + Convert internal data into a formatted user-readable string. + """ + raise NotImplementedError + + +class Boolean(SettingType): + """ + Boolean type, supporting truthy and falsey user input. + Configurable to change truthy and falsey values, and the output map. + + Types: + data: Optional[bool] + The stored boolean value. + value: Optional[bool] + The stored boolean value. + """ + accepts = "Yes/No, On/Off, True/False, Enabled/Disabled" + + # Values that are accepted as truthy and falsey by the parser + _truthy = {"yes", "true", "on", "enable", "enabled"} + _falsey = {"no", "false", "off", "disable", "disabled"} + + # The user-friendly output strings to use for each value + _outputs = {True: "On", False: "Off", None: "Not Set"} + + @classmethod + def _data_from_value(cls, id: int, value: Optional[bool], **kwargs): + """ + Both data and value are of type Optional[bool]. + Directly return the provided value as data. + """ + return value + + @classmethod + def _data_to_value(cls, id: int, data: Optional[bool], **kwargs): + """ + Both data and value are of type Optional[bool]. + Directly return the internal data as the value. + """ + return data + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Looks up the provided string in the truthy and falsey tables. + """ + _userstr = userstr.lower() + if _userstr == "none": + return None + if _userstr in cls._truthy: + return True + elif _userstr in cls._falsey: + return False + else: + raise UserInputError("Unknown boolean type `{}`".format(userstr)) + + @classmethod + def _format_data(cls, id: int, data: bool, **kwargs): + """ + Pass the provided value through the outputs map. + """ + return cls._outputs[data] + + +class Integer(SettingType): + """ + Integer type. Storing any integer. + + Types: + data: Optional[int] + The stored integer value. + value: Optional[int] + The stored integer value. + """ + accepts = "An integer." + + # Set limits on the possible integers + _min = -4096 + _max = 4096 + + @classmethod + def _data_from_value(cls, id: int, value: Optional[bool], **kwargs): + """ + Both data and value are of type Optional[int]. + Directly return the provided value as data. + """ + return value + + @classmethod + def _data_to_value(cls, id: int, data: Optional[bool], **kwargs): + """ + Both data and value are of type Optional[int]. + Directly return the internal data as the value. + """ + return data + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Relies on integer casting to convert the user string + """ + if userstr.lower() == "none": + return None + + try: + num = int(userstr) + except Exception: + raise UserInputError("Couldn't parse provided integer.") from None + + if num > cls._max: + raise UserInputError("Provided integer was too large!") + elif num < cls._min: + raise UserInputError("Provided integer was too small!") + + return num + + @classmethod + def _format_data(cls, id: int, data: Optional[int], **kwargs): + """ + Return the string version of the data. + """ + if data is None: + return None + else: + return str(data) + + +class String(SettingType): + """ + String type, storing arbitrary text. + Configurable to limit text length and restrict input options. + + Types: + data: Optional[str] + The stored string. + value: Optional[str] + The stored string. + """ + accepts = "Any text" + + # Maximum length of string to accept + _maxlen: int = None + + # Set of input options to accept + _options: set = None + + # Whether to quote the string as code + _quote: bool = True + + @classmethod + def _data_from_value(cls, id: int, value: Optional[str], **kwargs): + """ + Return the provided value string as the data string. + """ + return value + + @classmethod + def _data_to_value(cls, id: int, data: Optional[str], **kwargs): + """ + Return the provided data string as the value string. + """ + return data + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Check that the user-entered string is of the correct length. + Accept "None" to unset. + """ + if userstr.lower() == "none": + # Unsetting case + return None + elif cls._maxlen is not None and len(userstr) > cls._maxlen: + raise UserInputError("Provided string was too long! Maximum length is `{}`".format(cls._maxlen)) + elif cls._options is not None and not userstr.lower() in cls._options: + raise UserInputError("Invalid option! Valid options are `{}`".format("`, `".join(cls._options))) + else: + return userstr + + @classmethod + def _format_data(cls, id: int, data: str, **kwargs): + """ + Wrap the string in backtics for formatting. + Handle the special case where the string is empty. + """ + if data: + return "`{}`".format(data) if cls._quote else str(data) + else: + return None + + +class Channel(SettingType): + """ + Channel type, storing a single `discord.Channel`. + + Types: + data: Optional[int] + The id of the stored Channel. + value: Optional[discord.abc.GuildChannel] + The stored Channel. + """ + accepts = "Channel mention/id/name, or 'None' to unset" + + # Type of channel, if any + _chan_type: discord.ChannelType = None + + @classmethod + def _data_from_value(cls, id: int, value: Optional[discord.abc.GuildChannel], **kwargs): + """ + Returns the channel id. + """ + return value.id if value is not None else None + + @classmethod + def _data_to_value(cls, id: int, data: Optional[int], **kwargs): + """ + Uses the client to look up the channel id. + Returns the Channel if found, otherwise None. + """ + # Always passthrough None + if data is None: + return None + + return client.get_channel(data) + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Pass to the channel seeker utility to find the requested channel. + Handle `0` and variants of `None` to unset. + """ + if userstr.lower() in ('0', 'none'): + return None + else: + channel = await ctx.find_channel(userstr, interactive=True, chan_type=cls._chan_type) + if channel is None: + raise SafeCancellation + else: + return channel.id + + @classmethod + def _format_data(cls, id: int, data: Optional[int], **kwargs): + """ + Retrieve an artificially created channel mention. + If the channel does not exist, this will show up as invalid-channel. + """ + if data is None: + return None + else: + return "<#{}>".format(data) + + +class Role(SettingType): + """ + Role type, storing a single `discord.Role`. + Configurably allows returning roles which don't exist or are not seen by the client + as `discord.Object`. + + Settings may override `get_guildid` if the setting object `id` is not the guildid. + + Types: + data: Optional[int] + The id of the stored Role. + value: Optional[Union[discord.Role, discord.Object]] + The stored Role, or, if the role wasn't found and `_strict` is not set, + a discord Object with the role id set. + """ + accepts = "Role mention/id/name, or 'None' to unset" + + # Whether to disallow returning roles which don't exist as `discord.Object`s + _strict = True + + @classmethod + def _data_from_value(cls, id: int, value: Optional[discord.Role], **kwargs): + """ + Returns the role id. + """ + return value.id if value is not None else None + + @classmethod + def _data_to_value(cls, id: int, data: Optional[int], **kwargs): + """ + Uses the client to look up the guild and role id. + Returns the role if found, otherwise returns a `discord.Object` with the id set, + depending on the `_strict` setting. + """ + # Always passthrough None + if data is None: + return None + + # Fetch guildid + guildid = cls._get_guildid(id, **kwargs) + + # Search for the role + role = None + guild = client.get_guild(guildid) + if guild is not None: + role = guild.get_role(data) + + if role is not None: + return role + elif not cls._strict: + return discord.Object(id=data) + else: + return None + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Pass to the role seeker utility to find the requested role. + Handle `0` and variants of `None` to unset. + """ + if userstr.lower() in ('0', 'none'): + return None + else: + role = await ctx.find_role(userstr, create=False, interactive=True) + if role is None: + raise SafeCancellation + else: + return role.id + + @classmethod + def _format_data(cls, id: int, data: Optional[int], **kwargs): + """ + Retrieve the role mention if found, otherwise the role id or None depending on `_strict`. + """ + role = cls._data_to_value(id, data, **kwargs) + if role is None: + return "Not Set" + elif isinstance(role, discord.Role): + return role.mention + else: + return "`{}`".format(role.id) + + @classmethod + def _get_guildid(cls, id: int, **kwargs): + """ + Fetch the current guildid. + Assumes that the guilid is either passed as a kwarg or is the object id. + Should be overriden in other cases. + """ + return kwargs.get('guildid', id) + + +class Emoji(SettingType): + """ + Emoji type. Stores both custom and unicode emojis. + """ + accepts = "Emoji, either built in or custom. Use 'None' to unset." + + @staticmethod + def _parse_emoji(emojistr): + """ + Converts a provided string into a PartialEmoji. + If the string is badly formatted, returns None. + """ + if ":" in emojistr: + emojistr = emojistr.strip('<>') + splits = emojistr.split(":") + if len(splits) == 3: + animated, name, id = splits + animated = bool(animated) + return discord.PartialEmoji(name, animated=animated, id=int(id)) + else: + # TODO: Check whether this is a valid emoji + return discord.PartialEmoji(emojistr) + + @classmethod + def _data_from_value(cls, id: int, value: Optional[discord.PartialEmoji], **kwargs): + """ + Both data and value are of type Optional[discord.PartialEmoji]. + Directly return the provided value as data. + """ + return value + + @classmethod + def _data_to_value(cls, id: int, data: Optional[discord.PartialEmoji], **kwargs): + """ + Both data and value are of type Optional[discord.PartialEmoji]. + Directly return the internal data as the value. + """ + return data + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Pass to the emoji string parser to get the emoji. + Handle `0` and variants of `None` to unset. + """ + if userstr.lower() in ('0', 'none'): + return None + else: + return cls._parse_emoji(userstr) + + @classmethod + def _format_data(cls, id: int, data: Optional[discord.PartialEmoji], **kwargs): + """ + Return a string form of the partial emoji, which generally displays the emoji. + """ + if data is None: + return None + else: + return str(data) + + +class PatternType(SettingType): + """ + Pattern type. Stores a valid Timer Pattern. + + Types: + data: Optional[int] + The stored patternid + value: Optional[Timer.Pattern] + The Timer.Pattern stored + """ + accepts = "A valid timer pattern." + + @classmethod + def _data_from_value(cls, id: int, value, **kwargs): + if value is not None: + return value.row.patternid + + @classmethod + def _data_to_value(cls, id: int, data: Optional[int], **kwargs): + return Timer.Pattern.get(data) if data is not None else None + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + if userstr.lower() == "none": + return None + + pattern = Timer.Pattern.from_userstr( + userstr, + userid=ctx.author.id, + guildid=ctx.guild.id if ctx.guild else None, + timerid=id + ) + + return pattern.row.patternid + + @classmethod + def _format_data(cls, id: int, data: Optional[int], brief=True, **kwargs): + """ + Return the string version of the data. + """ + if data is not None: + return cls._data_to_value(id, data, **kwargs).display(brief=brief) + + +class Timezone(SettingType): + """ + Timezone type, storing a valid timezone string. + + Types: + data: Optional[str] + The string representing the timezone in POSIX format. + value: Optional[timezone] + The pytz timezone. + """ + accepts = ( + "A timezone name from [this list](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) " + "(e.g. `Europe/London`)." + ) + + @classmethod + def _data_from_value(cls, id: int, value: Optional[str], **kwargs): + """ + Return the provided value string as the data string. + """ + if value is not None: + return str(value) + + @classmethod + def _data_to_value(cls, id: int, data: Optional[str], **kwargs): + """ + Return the provided data string as the value string. + """ + if data is not None: + return pytz.timezone(data) + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Check that the user-entered string is of the correct length. + Accept "None" to unset. + """ + if userstr.lower() == "none": + # Unsetting case + return None + try: + timezone = pytz.timezone(userstr) + except pytz.exceptions.UnknownTimeZoneError: + timezones = [tz for tz in pytz.all_timezones if userstr.lower() in tz.lower()] + if len(timezones) == 1: + timezone = timezones[0] + elif timezones: + result = await ctx.selector( + "Multiple matching timezones found, please select one.", + timezones + ) + timezone = timezones[result] + else: + raise UserInputError( + "Unknown timezone `{}`. " + "Please provide a TZ name from " + "[this list](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones)".format(userstr) + ) from None + + return str(timezone) + + @classmethod + def _format_data(cls, id: int, data: str, **kwargs): + """ + Wrap the string in backtics for formatting. + Handle the special case where the string is empty. + """ + if data: + return "`{}`".format(data) + else: + return 'Not Set' + + +class IntegerEnum(SettingType): + """ + Integer Enum type, accepting limited strings, storing an integer, and returning an IntEnum value + + Types: + data: Optional[int] + The stored integer. + value: Optional[Any] + The corresponding Enum member + """ + accepts = "A valid option." + + # Enum to use for mapping values + _enum: IntEnum = None + + # Custom map to format the value. If None, uses the enum names. + _output_map = None + + @classmethod + def _data_from_value(cls, id: int, value: ..., **kwargs): + """ + Return the value corresponding to the enum member + """ + if value is not None: + return value.value + + @classmethod + def _data_to_value(cls, id: int, data: ..., **kwargs): + """ + Return the enum member corresponding to the provided integer + """ + if data is not None: + return cls._enum(data) + + @classmethod + async def _parse_userstr(cls, ctx: Context, id: int, userstr: str, **kwargs): + """ + Find the corresponding enum member's value to the provided user input. + Accept "None" to unset. + """ + userstr = userstr.lower() + + options = {name.lower(): mem.value for name, mem in cls._enum.__members__.items()} + + if userstr == "none": + # Unsetting case + return None + elif userstr not in options: + raise UserInputError("Invalid option!") + else: + return options[userstr] + + @classmethod + def _format_data(cls, id: int, data: int, **kwargs): + """ + Format the data using either the `_enum` or the provided output map. + """ + if data is not None: + value = cls._enum(data) + if cls._output_map: + return cls._output_map[value] + else: + return "`{}`".format(value.name) diff --git a/bot/settings/timer_settings.py b/bot/settings/timer_settings.py new file mode 100644 index 0000000..8c57559 --- /dev/null +++ b/bot/settings/timer_settings.py @@ -0,0 +1,320 @@ +from data import tables +from utils.lib import prop_tabulate +from wards import timer_admin +from utils.lib import DotDict + +from .base import Setting, ObjectSettings, ColumnData +from .setting_types import Boolean, String, Channel, PatternType, UserInputError + + +class TimerSettings(ObjectSettings): + settings = DotDict() + + +class TimerSetting(ColumnData, Setting): + _table_interface = tables.timers + _id_column = 'roleid' + _upsert = False + + write_ward = timer_admin + + category: str = 'Misc' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._timer = kwargs.get('timer', None) + + @property + def timer(self): + if self._timer is None: + self._timer = self.client.interface.fetch_timer(self.id) + return self._timer + + @property + def embed(self): + embed = super().embed + embed.title = "Configuration options for `{}` in `{}`".format( + self.display_name, + self.timer.name + ) + return embed + + def write(self, **kwargs): + super().write(**kwargs) + self.timer.load() + + +@TimerSettings.attach_setting +class name(String, TimerSetting): + _data_column = 'name' + + category = 'Core' + display_name = 'name' + desc = "Name of the study group." + + long_desc = "Name of the group, shown in timer messages and used to join the timer." + + @classmethod + async def _parse_userstr(cls, ctx, id: int, userstr: str, **kwargs): + name = await super()._parse_userstr(ctx, id, userstr, **kwargs) + if name is None or not name: + raise UserInputError("Timer name cannot be none or empty!") + if len(name) > 30: + raise UserInputError("Timer name must be between `1` and `30` characters long!") + if name.lower() in (timer.name.lower() for timer in ctx.timers.get_timers_in(ctx.guild.id)): + raise UserInputError("Another timer already exists with this name!") + return name + + +@TimerSettings.attach_setting +class channel(Channel, TimerSetting): + _data_column = 'channelid' + + category = 'Core' + display_name = 'channel' + desc = "Text channel for timer subscription and messages." + long_desc = ( + "Text channel where the timer sends stage updates and other messages. " + "Unless the `globalgroups` server option is set, " + "the timer may also only be joined from this channel." + ) + + @classmethod + async def _parse_userstr(cls, ctx, id: int, userstr: str, **kwargs): + channelid = await super()._parse_userstr(ctx, id, userstr, **kwargs) + if channelid: + channel = ctx.guild.get_channel(channelid) + chan_perms = channel.permissions_for(ctx.guild.me) + if not chan_perms.read_messages: + raise UserInputError("Cannot read messages in {}.".format(channel.mention)) + elif not chan_perms.send_messages: + raise UserInputError("Cannot send messages in {}.".format(channel.mention)) + elif not chan_perms.read_message_history: + raise UserInputError("Cannot read message history in {}.".format(channel.mention)) + elif not chan_perms.embed_links: + raise UserInputError("Cannot send embeds in {}.".format(channel.mention)) + elif not chan_perms.manage_messages: + raise UserInputError("Cannot manage messages in {}.".format(channel.mention)) + else: + raise UserInputError("Timer channel cannot be empty!") + return channelid + + def write(self, **kwargs): + self.client.interface.move_timer(self.timer, self.data) + + +@TimerSettings.attach_setting +class default_pattern(PatternType, TimerSetting): + _data_column = 'patternid' + + category = 'Core' + display_name = 'default_pattern' + desc = "Default timer pattern to use when timer is reset." + + long_desc = ( + "The timer pattern applied when the timer is reset.\n" + "The timer may be reset either manually through the `reset` command, " + "or automatically if the `auto_reset` timer setting is on." + ) + + +@TimerSettings.attach_setting +class auto_reset(Boolean, TimerSetting): + _data_column = 'auto_reset' + + _default = False + + category = 'Core' + display_name = 'auto_reset' + desc = "Automatically reset when there are no members." + + long_desc = ( + "Automatically reset empty timers to their default pattern.\n" + "When set, the timer will automatically stop and reset itself to the default pattern " + "when it is empty. The reset occurs on the next stage change." + ) + + +@TimerSettings.attach_setting +class admin_locked(Boolean, TimerSetting): + _data_column = 'admin_locked' + + _default = False + + category = 'Core' + display_name = 'admin_locked' + desc = "Whether timer members are restricted from controlling the timer." + + long_desc = ( + "Whether timer admin permissions are required to control the timer.\n" + "When this is set, all **Timer Control** commands (such as `start`, `skip`, and `stop`) " + "require timer admin permissions. This essentially makes the timer 'static', " + "locked to a fixed pattern, and not modifiable by regular members.\n" + "There is one exception to the timer control rule, " + "in that members may start a timer that has been stopped (but not change its pattern). " + "This is to support use of the `auto_reset` setting." + ) + + +@TimerSettings.attach_setting +class voice_channel(Channel, TimerSetting): + _data_column = 'voice_channelid' + + category = 'Voice' + display_name = 'voice_channel' + desc = "Associated voice channel for alerts and auto-subscriptions." + + long_desc = ( + "Voice channel used for alerts and automatic subscriptions.\n" + "When set, this channel will be used for voice alerts when changing stage (see `voice_alerts`), " + "and automatic (un)subscriptions when members join or leave the channel " + "(see `track_vc_join` and `track_vc_leave`).\n" + "The name of the voice channel will also be updated to reflect the timer status (see `vc_name`).\n" + "To avoid ambiguitiy, each voice channel can be bound to at most one group." + ) + + @classmethod + async def _parse_userstr(cls, ctx, id: int, userstr: str, **kwargs): + channelid = await super()._parse_userstr(ctx, id, userstr, **kwargs) + if channelid: + channel = ctx.guild.get_channel(channelid) + # Check whether another timer exists with this voice channel + other = next( + (timer for timer in ctx.timers.get_timers_in(ctx.guild.id) if timer.voice_channelid == channel.id), + None + ) + if other is not None: + raise UserInputError("{} is already bound to the group **{}**".format(channel.mention, other.name)) + + # Check voice channel permissions + voice_perms = channel.permissions_for(ctx.guild.me) + if not voice_perms.connect: + raise UserInputError("Cannot connect to {}.".format(channel.mention)) + elif not voice_perms.speak: + raise UserInputError("Cannot speak in {}.".format(channel.mention)) + elif not voice_perms.view_channel: + raise UserInputError("Cannot see {}.".format(channel.mention)) + return channelid + + +@TimerSettings.attach_setting +class voice_alert(Boolean, TimerSetting): + _data_column = 'voice_alert' + + _default = True + + category = 'Voice' + display_name = 'voice_alerts' + desc = "Emit voice alerts on stage changes." + + long_desc = ( + "When set, the bot will join the voice channel and emit an audio alert upon each stage change." + ) + + +@TimerSettings.attach_setting +class track_voice_join(Boolean, TimerSetting): + _data_column = 'track_voice_join' + + _default = True + + category = 'Voice' + display_name = 'track_vc_join' + desc = "Automatically subscribe members joining the voice channel." + + long_desc = ( + "Whether to automatically subscribe members joining the voice channel." + ) + + +@TimerSettings.attach_setting +class track_voice_leave(Boolean, TimerSetting): + _data_column = 'track_voice_leave' + + _default = True + + category = 'Voice' + display_name = 'track_vc_leave' + desc = "Automatically unsubscribe members leaving the voice channel." + + long_desc = ( + "Whether to automatically unsubscribe members leaving the voice channel." + ) + + +@TimerSettings.attach_setting +class compact(Boolean, TimerSetting): + _data_column = 'compact' + + _default = False + + category = 'Format' + display_name = 'compact' + desc = "Use a more compact format for timer messages." + + long_desc = ( + "Whether to use a more compact format on timer messages, including " + "stage change messages and subscription/unsubscription messages. " + "Some information is lost, but this is generally safe to use on " + "servers where the members have experience with the timer." + ) + + +@TimerSettings.attach_setting +class vc_name(String, TimerSetting): + _data_column = 'voice_channel_name' + + _default = "{name} - {stage_name} ({remaining})" + + category = 'Format' + display_name = 'vc_name' + desc = "Updating name for the associated voice channel." + accepts = "A short text string, accepting the following substitutions." + + long_desc = ( + "When a voice channel is associated to the timer (see `voice_channel`), " + "the name of the voice channel will be updated to reflect the current status of the timer. " + "This setting controls the format of that name.\n" + "*Note that due to discord restrictions the name can update at most once per 10 minutes. " + "The remaining time property will thus generally be inaccurate.*" + ) + + @classmethod + async def _parse_userstr(cls, ctx, id: int, userstr: str, **kwargs): + name = await super()._parse_userstr(ctx, id, userstr, **kwargs) + + if not (2 < len(name) < 100): + raise UserInputError("Channel names must be between `2` and `100` characters long.") + + return name + + @property + def embed(self): + embed = super().embed + + fields = ("Current value", "Preview", "Default value", "Accepted input") + values = (self.formatted or "Not Set", + "`{}`".format(self.timer.voice_channel_name), + self._format_data(self.id, self.default) or "None", + self.accepts) + table = prop_tabulate(fields, values) + embed.description = "{}\n{}".format(self.long_desc, table) + + subs = { + '{name}': "Name of the study group.", + '{stage_name}': "Name of the current stage.", + '{stage_dur}': "Duration of the current stage.", + '{remaining}': "(Approximate) number of minutes left in the stage.", + '{sub_count}': "Number of members subscribed to the group.", + '{pattern}': "Short-form of the current timer pattern." + } + + embed.add_field( + name="Accepted substitutions.", + value=prop_tabulate(*zip(*subs.items())) + ) + return embed + + +# TODO: default_work_name, default_work_message, default_break_name, default_break_message +# Q: How do we update the default pattern with this info? Maybe we should use placeholders for the defaults instead? diff --git a/bot/settings/user_settings.py b/bot/settings/user_settings.py new file mode 100644 index 0000000..5aa7d81 --- /dev/null +++ b/bot/settings/user_settings.py @@ -0,0 +1,97 @@ +import datetime +import Timer + +from data import tables +from utils.lib import prop_tabulate, DotDict + +from .base import Setting, ObjectSettings, ColumnData, UserInputError +from .setting_types import Timezone, IntegerEnum + + +class UserSettings(ObjectSettings): + settings = DotDict() + + +class UserSetting(ColumnData, Setting): + _table_interface = tables.users + _id_column = 'userid' + + +@UserSettings.attach_setting +class timezone(Timezone, UserSetting): + attr_name = 'timezone' + _data_column = 'timezone' + + _default = 'UTC' + + display_name = 'timezone' + desc = "Timezone for displaying history and session data." + long_desc = ( + "Timezone used for displaying your historical sessions and study profile." + ) + + @property + def success_response(self): + return ( + "Your personal timezone is now {}. " + "This will apply to your session history and profile, but *not* to the server leaderboard.\n" + "Your current time is **{}**." + ).format(self.formatted, datetime.datetime.now(tz=self.value).strftime("%H:%M")) + + +@UserSettings.attach_setting +class notify_level(IntegerEnum, UserSetting): + attr_name = 'notify_level' + _data_column = 'notify_level' + + _enum = Timer.NotifyLevel + _default = _enum.WARNING + + display_name = 'notify_level' + desc = 'Control when you receive DM stage notifications.' + long_desc = ( + "Control when you receive notifications " + "via DM when a timer you are in changes stage (e.g. from `Work` to `Break`)." + ) + + accepts = "One of the following options." + accepted_dict = { + 'all': "Receive all stage changes and status updates via DM.", + 'warning': "Only receive a DM for inactivity warnings.", + 'final': "Only receive a DM after being kicked for inactivity.", + 'never': "Never receive status updates via DM." + } + accepted_table = prop_tabulate(*zip(*accepted_dict.items())) + + success_responses = { + _enum.ALL: "You will receive all stage changes and status updates via DM.", + _enum.WARNING: "You will only receive a DM for inactivity warnings.", + _enum.FINAL: "You will only receive a DM after being kicked for inactivity.", + _enum.NEVER: "You will never receive status updates via DM." + } + + @property + def embed(self): + embed = super().embed + embed.add_field( + name="Accepted Values", + value=self.accepted_table + ) + return embed + + @property + def success_response(self): + return ( + "Your notification level is now {}.\n{}" + ).format(self.formatted, self.success_responses[self.value]) + + @classmethod + async def _parse_userstr(cls, ctx, id, userstr, **kwargs): + try: + value = await super()._parse_userstr(ctx, id, userstr, **kwargs) + except UserInputError: + raise UserInputError( + "Unrecognised notification level `{}`. " + "Please use one of the options below.\n{}".format(userstr, cls.accepted_table) + ) from None + return value diff --git a/bot/utils/ctx_addons.py b/bot/utils/ctx_addons.py index 37a27cb..b6fdae2 100644 --- a/bot/utils/ctx_addons.py +++ b/bot/utils/ctx_addons.py @@ -1,10 +1,13 @@ -import asyncio import discord from cmdClient import Context +from data import tables + +from settings import GuildSettings, UserSettings + @Context.util -async def embedreply(ctx, desc, colour=discord.Colour(0x9b59b6), **kwargs): +async def embed_reply(ctx, desc, colour=discord.Colour(0x9b59b6), **kwargs): """ Simple helper to embed replies. All arguments are passed to the embed constructor. @@ -15,59 +18,51 @@ async def embedreply(ctx, desc, colour=discord.Colour(0x9b59b6), **kwargs): @Context.util -async def live_reply(ctx, reply_func, update_interval=5, max_messages=20): +async def error_reply(ctx, error_str, **kwargs): """ - Acts as `ctx.reply`, but asynchronously updates the reply every `update_interval` seconds - with the value of `reply_func`, until the value is `None`. - - Parameters - ---------- - reply_func: coroutine - An async coroutine with no arguments. - Expected to return a dictionary of arguments suitable for `ctx.reply()` and `Message.edit()`. - update_interval: int - An integer number of seconds. - max_messages: int - Maximum number of messages in channel to keep the reply live for. - - Returns - ------- - The output message after the first reply. + Notify the user of a user level error. + Typically, this will occur in a red embed, posted in the command channel. """ - # Send the initial message - message = await ctx.reply(**(await reply_func())) + embed = discord.Embed( + colour=discord.Colour.red(), + description=error_str + ) + try: + message = await ctx.ch.send(embed=embed, reference=ctx.msg, **kwargs) + ctx.sent_messages.append(message) + return message + except discord.Forbidden: + message = await ctx.reply(error_str) + ctx.sent_messages.append(message) + return message - # Start the counter - future = asyncio.ensure_future(_message_counter(ctx.client, ctx.ch, max_messages)) - # Build the loop function - async def _reply_loop(): - while not future.done(): - await asyncio.sleep(update_interval) - args = await reply_func() - if args is not None: - await message.edit(**args) - else: - break +def context_property(func): + setattr(Context, func.__name__, property(func)) + return func - # Start the loop - asyncio.ensure_future(_reply_loop()) - # Return the original message - return message +@context_property +def best_prefix(ctx): + guild_prefix = tables.guilds.fetch_or_create(ctx.guild.id).prefix if ctx.guild else '' + return guild_prefix or ctx.client.prefix -async def _message_counter(client, channel, max_count): - """ - Helper for live_reply - """ - # Build check function - def _check(message): - return message.channel == channel - - # Loop until the message counter reaches maximum - count = 0 - while count < max_count: - await client.wait_for('message', check=_check) - count += 1 - return +@context_property +def example_group_name(ctx): + name = "AwesomeStudyGroup" + if ctx.guild: + groups = ctx.timers.get_timers_in(ctx.guild.id) + if groups: + name = groups[0].name + return name + + +@context_property +def guild_settings(ctx): + return GuildSettings(ctx.guild.id if ctx.guild else 0) + + +@context_property +def author_settings(ctx): + return UserSettings(ctx.author.id) diff --git a/bot/utils/interactive.py b/bot/utils/interactive.py index 496368d..af0be6c 100644 --- a/bot/utils/interactive.py +++ b/bot/utils/interactive.py @@ -5,6 +5,62 @@ from .lib import paginate_list +# TODO: Interactive locks +cancel_emoji = '❌' +number_emojis = ( + '1️⃣', '2️⃣', '3️⃣', '4️⃣', '5️⃣', '6️⃣', '7️⃣', '8️⃣', '9️⃣' +) + + +async def discord_shield(coro): + try: + await coro + except discord.HTTPException: + pass + + +@Context.util +async def cancellable(ctx, msg, add_reaction=True, cancel_message=None, timeout=300): + """ + Add a cancellation reaction to the given message. + Pressing the reaction triggers cancellation of the original context, and a UserCancelled-style error response. + """ + # TODO: Not consistent with the exception driven flow, make a decision here? + # Add reaction + if add_reaction and cancel_emoji not in (str(r.emoji) for r in msg.reactions): + try: + await msg.add_reaction(cancel_emoji) + except discord.HTTPException: + return + + # Define cancellation function + async def _cancel(): + try: + await ctx.client.wait_for( + 'reaction_add', + timeout=timeout, + check=lambda r, u: (u == ctx.author + and r.message == msg + and str(r.emoji) == cancel_emoji) + ) + except asyncio.TimeoutError: + pass + else: + await ctx.client.active_command_response_cleaner(ctx) + if cancel_message: + await ctx.error_reply(cancel_message) + else: + try: + await ctx.msg.add_reaction(cancel_emoji) + except discord.HTTPException: + pass + [task.cancel() for task in ctx.tasks] + + # Launch cancellation task + task = asyncio.create_task(_cancel()) + ctx.tasks.append(task) + return task + @Context.util async def listen_for(ctx, allowed_input=None, timeout=120, lower=True, check=None): @@ -94,40 +150,72 @@ async def selector(ctx, header, select_from, timeout=120, max_len=20): raise ValueError("Selection list passed to `selector` cannot be empty.") # Generate the selector pages - footer = "Please type the number corresponding to your selection, or type `c` now to cancel." + footer = "Please reply with the number of your selection, or press {} to cancel.".format(cancel_emoji) list_pages = paginate_list(select_from, block_length=max_len) pages = ["\n".join([header, page, footer]) for page in list_pages] # Post the pages in a paged message - out_msg = await ctx.pager(pages) + out_msg = await ctx.pager(pages, add_cancel=True) + cancel_task = await ctx.cancellable(out_msg, add_reaction=False, timeout=None) - # Listen for valid input - valid_input = [str(i+1) for i in range(0, len(select_from))] + ['c', 'C'] - try: - result_msg = await ctx.listen_for(valid_input, timeout=timeout) - except ResponseTimedOut: - raise ResponseTimedOut("Selector timed out waiting for a response.") + if len(select_from) <= 5: + for i, _ in enumerate(select_from): + asyncio.create_task(discord_shield(out_msg.add_reaction(number_emojis[i]))) - # Try and delete the selector message and the user response. + # Build response tasks + valid_input = [str(i+1) for i in range(0, len(select_from))] + ['c', 'C'] + listen_task = asyncio.create_task(ctx.listen_for(valid_input, timeout=None)) + emoji_task = asyncio.create_task(ctx.client.wait_for( + 'reaction_add', + check=lambda r, u: (u == ctx.author + and r.message == out_msg + and str(r.emoji) in number_emojis) + )) + # Wait for the response tasks + done, pending = await asyncio.wait( + (listen_task, emoji_task), + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED + ) + + # Cleanup try: await out_msg.delete() - await result_msg.delete() - except discord.NotFound: - pass - except discord.Forbidden: + except discord.HTTPException: pass - # Handle user cancellation - if result_msg.content in ['c', 'C']: - raise UserCancelled("User cancelled selection.") + # Handle different return cases + if listen_task in done: + emoji_task.cancel() - # The content must now be a valid index. Collect and return it. - index = int(result_msg.content) - 1 - return index + result_msg = listen_task.result() + try: + await result_msg.delete() + except discord.HTTPException: + pass + if result_msg.content.lower() == 'c': + raise UserCancelled("Selection cancelled!") + result = int(result_msg.content) - 1 + elif emoji_task in done: + listen_task.cancel() + + reaction, _ = emoji_task.result() + result = number_emojis.index(str(reaction.emoji)) + elif cancel_task in done: + # Manually cancelled case.. the current task should have been cancelled + # Raise UserCancelled in case the task wasn't cancelled for some reason + raise UserCancelled("Selection cancelled!") + elif not done: + # Timeout case + raise ResponseTimedOut("Selector timed out waiting for a response.") + + # Finally cancel the canceller and return the provided index + cancel_task.cancel() + return result @Context.util -async def pager(ctx, pages, locked=True, **kwargs): +async def pager(ctx, pages, locked=True, start_at=0, add_cancel=False, **kwargs): """ Shows the user each page from the provided list `pages` one at a time, providing reactions to page back and forth between pages. @@ -150,25 +238,28 @@ async def pager(ctx, pages, locked=True, **kwargs): raise ValueError("Pager cannot page with no pages!") # Post first page. Method depends on whether the page is an embed or not. - if isinstance(pages[0], discord.Embed): - out_msg = await ctx.reply(embed=pages[0]) + if isinstance(pages[start_at], discord.Embed): + out_msg = await ctx.reply(embed=pages[start_at], **kwargs) else: - out_msg = await ctx.reply(pages[0]) + out_msg = await ctx.reply(pages[start_at], **kwargs) # Run the paging loop if required if len(pages) > 1: - asyncio.ensure_future(_pager(ctx, out_msg, pages, locked)) + task = asyncio.create_task(_pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs)) + ctx.tasks.append(task) + elif add_cancel: + await out_msg.add_reaction(cancel_emoji) # Return the output message return out_msg -async def _pager(ctx, out_msg, pages, locked): +async def _pager(ctx, out_msg, pages, locked, start_at, add_cancel, **kwargs): """ Asynchronous initialiser and loop for the `pager` utility above. """ # Page number - page = 0 + page = start_at # Add reactions to the output message next_emoji = "▶" @@ -176,11 +267,30 @@ async def _pager(ctx, out_msg, pages, locked): try: await out_msg.add_reaction(prev_emoji) - await out_msg.add_reaction( next_emoji) + if add_cancel: + await out_msg.add_reaction(cancel_emoji) + await out_msg.add_reaction(next_emoji) except discord.Forbidden: # We don't have permission to add paging emojis # Die as gracefully as we can - await ctx.error_reply("Cannot page results because I do not have permissions to react!") + if ctx.guild: + perms = ctx.ch.permissions_for(ctx.guild.me) + if not perms.add_reactions: + await ctx.error_reply( + "Cannot page results because I do not have the `add_reactions` permission!" + ) + elif not perms.read_message_history: + await ctx.error_reply( + "Cannot page results because I do not have the `read_message_history` permission!" + ) + else: + await ctx.error_reply( + "Cannot page results due to insufficient permissions!" + ) + else: + await ctx.error_reply( + "Cannot page results!" + ) return # Check function to determine whether a reaction is valid @@ -209,9 +319,9 @@ def check(reaction, user): # Edit the message with the new page active_page = pages[page] if isinstance(active_page, discord.Embed): - await out_msg.edit(embed=active_page) + await out_msg.edit(embed=active_page, **kwargs) else: - await out_msg.edit(content=active_page) + await out_msg.edit(content=active_page, **kwargs) # Clean up by removing the reactions try: @@ -225,6 +335,7 @@ def check(reaction, user): except discord.NotFound: pass + @Context.util async def input(ctx, msg="", timeout=120): """ diff --git a/bot/utils/lib.py b/bot/utils/lib.py index 99ec86c..dd0f94e 100644 --- a/bot/utils/lib.py +++ b/bot/utils/lib.py @@ -1,4 +1,5 @@ import datetime +import pytz def prop_tabulate(prop_list, value_list): @@ -46,7 +47,7 @@ def paginate_list(item_list, block_length=20, style="markdown", title=None): List of pages, each formatted into a codeblock, and containing at most `block_length` of the provided strings. """ - lines = ["{0:<5}{1:<5}".format("{}.".format(i + 1), str(line)) for i, line in enumerate(item_list)] + lines = ["{0:<5}{1:<5}".format("{}. ".format(i + 1), str(line)) for i, line in enumerate(item_list)] page_blocks = [lines[i:i + block_length] for i in range(0, len(lines), block_length)] pages = [] for i, block in enumerate(page_blocks): @@ -65,4 +66,13 @@ def timestamp_utcnow(): """ Return the current integer UTC timestamp. """ - return int(datetime.datetime.timestamp(datetime.datetime.utcnow())) + return int(datetime.datetime.now(tz=pytz.utc).timestamp()) + + +class DotDict(dict): + """ + Dict-type allowing dot access to keys. + """ + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ diff --git a/bot/utils/live_messages.py b/bot/utils/live_messages.py new file mode 100644 index 0000000..4c54590 --- /dev/null +++ b/bot/utils/live_messages.py @@ -0,0 +1,33 @@ +import asyncio + +from meta import client + +current_live = {} # token -> task + + +async def live_edit(msg, update_func, label='global', update_interval=5, max_distance=20, **kwargs): + if not msg: + msg = await update_func(None, **kwargs) + if not msg: + return + + token = (msg.channel.id, label) + task = current_live.pop(token, None) + if task is not None: + task.cancel() + + task = current_live[token] = asyncio.create_task(_message_counter(msg, max_distance)) + while not task.done(): + await asyncio.sleep(update_interval) + if await update_func(msg, **kwargs) is None: + task.cancel() + + +async def _message_counter(msg, max_count): + count = 0 + while count < max_count: + try: + await client.wait_for('message', check=lambda m: m.channel == msg.channel) + except asyncio.CancelledError: + break + count += 1 diff --git a/bot/utils/seekers.py b/bot/utils/seekers.py index da2b7d9..f648059 100644 --- a/bot/utils/seekers.py +++ b/bot/utils/seekers.py @@ -1,10 +1,12 @@ +import discord + from cmdClient import Context -from cmdClient.lib import InvalidContext, UserCancelled, ResponseTimedOut +from cmdClient.lib import InvalidContext, UserCancelled, ResponseTimedOut, SafeCancellation from . import interactive @Context.util -async def find_role(ctx, userstr, interactive=False, collection=None): +async def find_role(ctx, userstr, create=False, interactive=False, collection=None, allow_notfound=True): """ Find a guild role given a partial matching string, allowing custom role collections and several behavioural switches. @@ -14,12 +16,18 @@ async def find_role(ctx, userstr, interactive=False, collection=None): userstr: str String obtained from a user, expected to partially match a role in the collection. The string will be tested against both the id and the name of the role. + create: bool + Whether to offer to create the role if it does not exist. + The bot will only offer to create the role if it has the `manage_channels` permission. interactive: bool Whether to offer the user a list of roles to choose from, or pick the first matching role. - collection: List(discord.Role) + collection: List[Union[discord.Role, discord.Object]] Collection of roles to search amongst. If none, uses the guild role list. + allow_notfound: bool + Whether to return `None` when there are no matches, instead of raising `SafeCancellation`. + Overriden by `create`, if it is set. Returns ------- @@ -34,6 +42,8 @@ async def find_role(ctx, userstr, interactive=False, collection=None): If the user cancels interactive role selection. cmdClient.lib.ResponseTimedOut: If the user fails to respond to interactive role selection within `60` seconds` + cmdClient.lib.SafeCancellation: + If `allow_notfound` is `False`, and the search returned no matches. """ # Handle invalid situations and input if not ctx.guild: @@ -43,10 +53,11 @@ async def find_role(ctx, userstr, interactive=False, collection=None): raise ValueError("User string passed to find_role was empty.") # Create the collection to search from args or guild roles - collection = collection if collection else ctx.guild.roles + collection = collection if collection is not None else ctx.guild.roles # If the unser input was a number or possible role mention, get it out - roleid = userstr.strip('<#@&!>') + userstr = userstr.strip() + roleid = userstr.strip('<#@&!> ') roleid = int(roleid) if roleid.isdigit() else None searchstr = userstr.lower() @@ -69,8 +80,10 @@ def check(role): else: # We have multiple matching roles! if interactive: - # Interactive prompt with the list of roles - role_names = [role.name for role in roles] + # Interactive prompt with the list of roles, handle `Object`s + role_names = [ + role.name if isinstance(role, discord.Role) else str(role.id) for role in roles + ] try: selected = await ctx.selector( @@ -88,8 +101,35 @@ def check(role): # Just select the first one role = roles[0] + # Handle non-existence of the role if role is None: - await ctx.error_reply("Couldn't find a role matching `{}`!".format(userstr)) + msgstr = "Couldn't find a role matching `{}`!".format(userstr) + if create: + # Inform the user + msg = await ctx.error_reply(msgstr) + if ctx.guild.me.guild_permissions.manage_roles: + # Offer to create it + resp = await ctx.ask("Would you like to create this role?", timeout=30) + if resp: + # They accepted, create the role + # Before creation, check if the role name is too long + if len(userstr) > 100: + await ctx.error_reply("Could not create a role with a name over 100 characters long!") + else: + role = await ctx.guild.create_role( + name=userstr, + reason="Interactive role creation for {} (uid:{})".format(ctx.author, ctx.author.id) + ) + await msg.delete() + await ctx.reply("You have created the role `{}`!".format(userstr)) + + # If we still don't have a role, cancel unless allow_notfound is set + if role is None and not allow_notfound: + raise SafeCancellation + elif not allow_notfound: + raise SafeCancellation(msgstr) + else: + await ctx.error_reply(msgstr) return role @@ -189,7 +229,7 @@ def check(chan): return chan @Context.util -async def find_member(ctx, userstr, interactive=False, collection=None): +async def find_member(ctx, userstr, interactive=False, collection=None, silent=False): """ Find a guild member given a partial matching string, allowing custom member collections. @@ -205,6 +245,8 @@ async def find_member(ctx, userstr, interactive=False, collection=None): collection: List(discord.Member) Collection of members to search amongst. If none, uses the full guild member list. + silent: bool + Whether to reply with an error when there are no matches. Returns ------- @@ -283,7 +325,7 @@ def check(member): # Just select the first one member = members[0] - if member is None: + if member is None and not silent: await ctx.error_reply("Couldn't find a member matching `{}`!".format(userstr)) return member diff --git a/bot/utils/timer_utils.py b/bot/utils/timer_utils.py index 3c03d89..faad4e6 100644 --- a/bot/utils/timer_utils.py +++ b/bot/utils/timer_utils.py @@ -1,9 +1,11 @@ from cmdClient import Context from cmdClient.lib import UserCancelled, ResponseTimedOut +from data import tables + @Context.util -async def get_timers_matching(ctx, name_str, channel_only=True, info=False): +async def get_timers_matching(ctx, name_str, channel_only=True, info=False, header=None): """ Interactively get a guild timer matching the given string. @@ -26,18 +28,26 @@ async def get_timers_matching(ctx, name_str, channel_only=True, info=False): Raised if the user fails to respond to the selector within `120` seconds. """ # Get the full timer list - if channel_only: - timers = ctx.client.interface.get_channel_timers(ctx.ch.id) - else: - timers = ctx.client.interface.get_guild_timers(ctx.guild.id) + all_timers = ctx.timers.get_timers_in(ctx.guild.id, ctx.ch.id if channel_only else None) # If there are no timers, quit early - if not timers: + if not all_timers: return None # Build a list of matching timers name_str = name_str.strip() - timers = [timer for timer in timers if name_str.lower() in timer.name.lower()] + timers = [ + timer for timer in all_timers + if (name_str.lower() in timer.data.name.lower() + or name_str.lower() in timer.role.name.lower()) + ] + + if not timers: + # Try matching on subscribers instead + timers = [ + timer for timer in all_timers + if any(name_str.lower() in sub.name.lower() for sub in timer.subscribers.values()) + ] if len(timers) == 0: return None @@ -45,15 +55,23 @@ async def get_timers_matching(ctx, name_str, channel_only=True, info=False): return timers[0] else: if info: - select_from = [timer.oneline_summary() for timer in timers] + select_from = [timer.oneline_summary for timer in timers] else: - select_from = [timer.name for timer in timers] + select_from = [timer.data.name for timer in timers] try: - selected = await ctx.selector("Multiple matching groups found, please select one.", select_from) + selected = await ctx.selector(header or "Multiple matching groups found, please select one.", select_from) except ResponseTimedOut: raise ResponseTimedOut("Group selection timed out.") from None except UserCancelled: raise UserCancelled("User cancelled group selection.") from None return timers[selected] + + +async def is_timer_admin(member): + result = member.guild_permissions.administrator + if not result: + tarid = tables.guilds.fetch_or_create(member.guild.id).timer_admin_roleid + result = tarid in (role.id for role in member.roles) + return result diff --git a/bot/wards.py b/bot/wards.py index 35e72dc..1e6bf5e 100644 --- a/bot/wards.py +++ b/bot/wards.py @@ -1,26 +1,41 @@ from cmdClient import check +from cmdClient.checks import in_guild + +from utils.timer_utils import is_timer_admin + + +@check( + name="TIMER_READY", + msg="I am restarting! Please try again in a moment." +) +async def timer_ready(ctx, *args, **kwargs): + return ctx.client.interface.ready @check( name="TIMER_ADMIN", - msg=("You need to have one of the following to use this command.\n" - "- The `manage_guild` permission in this guild.\n" - "- The timer admin role (refer to the `adminrole` command).") + msg=("You need to have one of the following to do this!\n" + "- The `administrator` server permission.\n" + "- The timer admin role (see the `timeradmin` command)."), + requires=[in_guild] ) async def timer_admin(ctx, *args, **kwargs): - if ctx.author.guild_permissions.manage_guild: - return True + return await is_timer_admin(ctx.author) - roleid = ctx.client.config.guilds.get(ctx.guild.id, "timeradmin") - if roleid is None: - return False - return roleid in [r.id for r in ctx.author.roles] +@check( + name="HAS_TIMERS", + msg="No study groups have been created! Create a new group with the `newgroup` command.", + requires=[in_guild, timer_ready] +) +async def has_timers(ctx, *args, **kwargs): + return bool(ctx.timers.get_timers_in(ctx.guild.id)) @check( - name="TIMER_READY", - msg="I am restarting! Please try again in a moment." + name="ADMIN", + msg=("You need to be a server admin to do this!"), + requires=[in_guild] ) -async def timer_ready(ctx, *args, **kwargs): - return ctx.client.interface.ready +async def guild_admin(ctx, *args, **kwargs): + return ctx.author.guild_permissions.administrator diff --git a/data/migration/v0-v1/README.md b/data/migration/v0-v1/README.md new file mode 100644 index 0000000..692ff5a --- /dev/null +++ b/data/migration/v0-v1/README.md @@ -0,0 +1,41 @@ +# Data migration from version 0 to version 1 + +## Summary +Version 1 represents a paradigm shift in how PomoBot's data is stored. +In particular, all user and guild properties are moving into appropriate tables, and the sessions database is merging with the properties database. +Timers now also have configuration properties, and Timer patterns are stored in the database, with presets referencing the central pattern table. +Several views are being added to simplify analysis of the data, both internally and via external programs or services. + + +## Instructions +Copy `sample-migration.conf` to `migration.conf` and edit the data paths as required. Then run `migration.py`. + +## Object migration notes +### Guilds +Originally stored in the `guilds` and `guild_props` tables, guild properties have been split into appropriate tables. +- `timeradmin` property -> `guilds.timer_admin_roleid` + - Straightforward transfer, integer type +- `globalgroups` property -> `guilds.globalgroups` + - Straightforward transfer, boolean type +- `timers` property -> `timers` table + - Originally a json-encoded list of timer data, with each timer encoded as a tuple `[name, roleid, channelid, clock_channelid]`. + - Now each timer is encoded in its own row in `timers`, with each tuple-field given its own column. + - Each new `timer` also holds considerably more data due to the new configuration. +- `timer_presets` property -> `guild_presets` table + - Originally a json-encoded dictionary of the form `presetname: setupstring`. + - Each preset is now given by a single row of `guild_presets`. + - Setupstrings are stored in `patterns` as their associated patterns, and referred to by `patternid`. + +### Users +Originally stored in the `users` and `user_props` tables, user properties have been split into appropriate tables. +- `notify_level` property -> `users.notify_level` + - Straightforward transfer, integer (enum data) type. +- `timer_presets` property -> `user_presets` table + - Originally a json-encoded dictionary of the form `presetname: setupstring`. + - Each preset is now given by a single row of `user_presets`. + - Setupstrings are stored in `patterns` as their associated patterns, and referred to by `patternid`. + +### Sessions +The `sessions` table has been moved from the separate `registry` database into the central database. +Overall the format is the same, with the exception of the `starttime` column being renamed to `start_time`. +The `sessions` table now also tracks the `focused_duration`, `patternid`, and `stages` information for each session. diff --git a/data/migration/v0-v1/lib.py b/data/migration/v0-v1/lib.py new file mode 100644 index 0000000..39a3c92 --- /dev/null +++ b/data/migration/v0-v1/lib.py @@ -0,0 +1,104 @@ +from collections import namedtuple +import datetime +import json +import pytz + + +class Base: + __slots__ = () + + def __init__(self, **kwargs): + for prop in self.__slots__: + setattr(self, prop, kwargs.get(prop, None)) + + +class Guild(Base): + __slots__ = ( + 'guildid', + 'timer_admin_roleid', + 'globalgroups', + 'presets', + ) + + +class Timer(Base): + __slots__ = ( + 'roleid', + 'guildid', + 'name', + 'channelid', + 'voice_channelid' + ) + + +Stage = namedtuple('Stage', ('name', 'duration', 'message', 'focus')) + + +class Pattern(Base): + __slots__ = ( + 'stages', + 'patternid', + 'stage_str' + ) + + pattern_cache = {} # pattern_str -> id + lastid = 0 + + @classmethod + def parse(cls, string): + """ + Parse a setup string into a pattern + """ + # Accepts stages as 'name, length' or 'name, length, message' + stage_blocks = string.strip(';').split(';') + stages = [] + for block in stage_blocks: + # Extract stage components + parts = block.split(',', maxsplit=2) + if len(parts) == 2: + name, dur = parts + message = None + else: + name, dur, message = parts + + # Parse duration + dur = dur.strip() + focus = dur.startswith('*') or dur.endswith('*') + if focus: + dur = dur.strip('* ') + + # Build and add stage + stages.append(Stage(name.strip(), int(dur), (message or '').strip(), focus)) + + stage_str = json.dumps(stages) + if stage_str in cls.pattern_cache: + pattern = cls.pattern_cache[stage_str] + else: + cls.lastid += 1 + pattern = cls(stages=stages, patternid=cls.lastid, stage_str=stage_str) + cls.pattern_cache[stage_str] = pattern + + return pattern + + +class Preset(Base): + __slots__ = ( + 'name', + 'patternid', + ) + + +class User(Base): + __slots__ = ( + 'userid', + 'notify_level', + 'presets' + ) + + +time_diff = ( + int(datetime.datetime.now(tz=pytz.utc).timestamp()) + - int(datetime.datetime.timestamp(datetime.datetime.utcnow())) +) +def adjust_timestamp(ts): + return ts + time_diff diff --git a/data/migration/v0-v1/migration.py b/data/migration/v0-v1/migration.py new file mode 100644 index 0000000..850ff18 --- /dev/null +++ b/data/migration/v0-v1/migration.py @@ -0,0 +1,311 @@ +import os +import json +import configparser as cfgp +import pickle + +import sqlite3 as sq + +import lib + + +CONFFILE = "migration.conf" +DATA_DIR = "../../" + +# Read config file +print("Reading configuration file...", end='') +if not os.path.isfile(CONFFILE): + raise Exception( + "Couldn't find migration configuration file '{}'. " + "Please copy 'sample-migration.conf' to 'migration.conf' and edit as required." + ) + +config = cfgp.ConfigParser() +config.read(CONFFILE) + +orig_settings_path = DATA_DIR + config['Original']['settings_db'] +orig_session_path = DATA_DIR + config['Original']['session_db'] +if config['Original']['savefile'].lower() != 'none': + orig_savefile_path = DATA_DIR + config['Original']['savefile'] +else: + orig_savefile_path = None + +target_database_path = DATA_DIR + config['Target']['database'] +if config['Target']['savefile'] .lower() != 'none': + target_savefile_path = DATA_DIR + config['Target']['savefile'] +else: + target_savefile_path = None + +if not os.path.isfile(orig_session_path): + raise Exception("Provided original sessions database not found.") +if not os.path.isfile(orig_settings_path): + raise Exception("Provided original settings database not found.") +if orig_savefile_path is not None and not os.path.isfile(orig_savefile_path): + raise Exception("Provided original savefile not found.") +if os.path.isfile(target_database_path): + raise Exception("Target database file already exists! Refusing to overwrite.") + +print("Done") + +# Open databases +print("Opening databases...", end='') +orig_session_conn = sq.connect(orig_session_path) +orig_settings_conn = sq.connect(orig_settings_path) +target_database_conn = sq.connect(target_database_path) +print("Done") + +# Initialise the new database +print("Initialising target database...", end='') +with target_database_conn as conn: + with open("v1_schema.sql", 'r') as script: + conn.executescript(script.read()) +print("Done") + + +# ---------------------------------------------------- +# Initial setup done, start migration +# ---------------------------------------------------- +guilds = {} +timers = {} +users = {} + + +# Migrate guild properties +# First read properties +count = 0 +with orig_settings_conn as conn: + cursor = conn.cursor() + rows = cursor.execute( + "SELECT * FROM guilds" + ) + for guildid, prop, value in cursor.fetchall(): + count += 1 + if guildid not in guilds: + guild = guilds[guildid] = lib.Guild(guildid=guildid, presets={}) + else: + guild = guilds[guildid] + + if prop == 'timeradmin': + guild.timer_admin_roleid = json.loads(value) + elif prop == 'globalgroups': + guild.globalgroups = json.loads(value) + elif prop == 'timers': + timer_list = json.loads(value) + if timer_list: + for name, roleid, channelid, vc_channelid in timer_list: + timer = lib.Timer( + roleid=roleid, + guildid=guildid, + name=name, + channelid=channelid, + voice_channelid=vc_channelid + ) + timers[roleid] = timer + elif prop == 'timer_presets': + presets = json.loads(value) + if presets: + for name, setupstr in presets.items(): + pattern = lib.Pattern.parse(setupstr) + while name.lower() in guild.presets: + name += '_' + guild.presets[name.lower()] = lib.Preset(name=name, patternid=pattern.patternid) +print("Read {} guild properties.".format(count)) + +# Insert the guild rows +with target_database_conn as conn: + cursor = conn.cursor() + cursor.executemany( + "INSERT INTO guilds (guildid, timer_admin_roleid, globalgroups) VALUES (?, ?, ?)", + ( + (guildid, guild.timer_admin_roleid, guild.globalgroups) + for guildid, guild in guilds.items() + ) + ) +print("Inserted {} guilds.".format(len(guilds))) + +# Insert the timer rows +with target_database_conn as conn: + cursor = conn.cursor() + cursor.executemany( + "INSERT INTO timers (roleid, guildid, name, channelid, voice_channelid) VALUES (?, ?, ?, ?, ?)", + ( + (timer.roleid, timer.guildid, timer.name, timer.channelid, timer.voice_channelid) + for timer in timers.values() + ) + ) +print("Inserted {} timers.".format(len(timers))) + + +# Read user properties +count = 0 +with orig_settings_conn as conn: + cursor = conn.cursor() + rows = cursor.execute( + "SELECT * FROM users" + ) + for userid, prop, value in cursor.fetchall(): + count += 1 + if userid not in users: + user = users[userid] = lib.User(userid=userid, presets={}) + else: + user = users[userid] + + if prop == 'notify_level': + user.notify_level = json.loads(value) + elif prop == 'timer_presets': + presets = json.loads(value) + if presets: + for name, setupstr in presets.items(): + pattern = lib.Pattern.parse(setupstr) + while name.lower() in user.presets: + name += '_' + user.presets[name.lower()] = lib.Preset(name=name, patternid=pattern.patternid) +print("Read {} user properties.".format(count)) + + +# Insert the user rows +with target_database_conn as conn: + cursor = conn.cursor() + cursor.executemany( + "INSERT INTO users (userid, notify_level) VALUES (?, ?)", + ( + (userid, user.notify_level) + for userid, user in users.items() + ) + ) +print("Inserted {} users.".format(len(users))) + + +# Migrate savedata +save_data = {} +if orig_savefile_path: + with open(orig_savefile_path) as f: + old_data = json.load(f) + flat_timers = old_data['timers'] + flat_subscribers = old_data['subscribers'] + flat_channels = old_data['timer_channels'] + + channels = {} + for channel in flat_channels: + channel_data = {} + channel_data['channelid'] = channel['id'] + channel_data['pinned_msg_id'] = channel['msgid'] + channel_data['timers'] = [] + channels[channel['id']] = channel_data + + for timer in flat_timers: + timer_data = {} + + if timer['stages']: + stages = [ + lib.Stage(stage['name'], stage['duration'], stage['message'], False) + for stage in timer['stages'] + ] + stage_str = json.dumps(stages) + if stage_str in lib.Pattern.pattern_cache: + pattern = lib.Pattern.pattern_cache[stage_str] + else: + lib.Pattern.lastid += 1 + pattern = lib.Pattern(stages=stages, patternid=lib.Pattern.lastid, stage_str=stage_str) + lib.Pattern.pattern_cache[stage_str] = pattern + patternid = pattern.patternid + else: + patternid = 0 + timer_data['roleid'] = timer['roleid'] + timer_data['state'] = timer['state'] + timer_data['patternid'] = patternid + timer_data['stage_index'] = timer['current_stage'] or 0 + timer_data['stage_start'] = lib.adjust_timestamp(timer['current_stage_start'] or 0) + timer_data['message_ids'] = timer['messages'] + timer_data['subscribers'] = [] + timer_data['last_voice_update'] = 0 + + _timer = timers[timer['roleid']] + if _timer: + timer_data['guildid'] = _timer.guildid + if _timer.channelid not in channels: + channels[_timer.channelid] = { + 'channelid': _timer.channelid, + 'pinned_msg_id': None, + 'timers': [], + } + channels[_timer.channelid]['timers'].append(timer_data) + + for channelid, channel in channels.items(): + if channel['timers']: + guildid = channel['timers'][0]['guildid'] + if guildid not in save_data: + guild_channels = save_data[guildid] = [] + else: + guild_channels = save_data[guildid] + guild_channels.append(channel) +print("Read and parsed old savefile.") + + +# Write patterns +with target_database_conn as conn: + cursor = conn.cursor() + cursor.executemany( + "INSERT INTO patterns (patternid, short_repr, stage_str) VALUES (?, ?, ?)", + ( + (pattern.patternid, False, pattern.stage_str) + for pattern in lib.Pattern.pattern_cache.values() + ) + ) +print("Created {} patterns.".format(len(lib.Pattern.pattern_cache))) + + +# Write user presets +with target_database_conn as conn: + cursor = conn.cursor() + cursor.executemany( + "INSERT INTO user_presets (userid, patternid, preset_name) VALUES (?, ?, ?)", + ( + (userid, preset.patternid, preset.name) + for userid, user in users.items() + for preset in user.presets.values() + ) + ) +print("Transferred user presets.") + + +# Write guild presets +with target_database_conn as conn: + cursor = conn.cursor() + cursor.executemany( + "INSERT INTO guild_presets (guildid, patternid, preset_name, created_by) VALUES (?, ?, ?, ?)", + ( + (guildid, preset.patternid, preset.name, 0) + for guildid, guild in guilds.items() + for preset in guild.presets.values() + ) + ) +print("Transferred guild presets.") + + +# Write new save file, if required +if target_savefile_path: + with open(target_savefile_path, 'wb') as savefile: + pickle.dump(save_data, savefile, pickle.HIGHEST_PROTOCOL) + print("Written new save file.") + + +# Transfer the session data +print("Migrating session data....", end='') +with orig_session_conn as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT * FROM sessions" + ) + with target_database_conn as tconn: + tcursor = tconn.cursor() + tcursor.executemany( + "INSERT INTO sessions (guildid, userid, roleid, start_time, duration) VALUES (?, ?, ?, ?, ?)", + ( + (row[1], row[0], row[2], lib.adjust_timestamp(row[3]), row[4]) + for row in cursor.fetchall() + if row[4] > 60 + ) + ) +print("Done") + +print("Data migration v0 -> v1 complete!") diff --git a/data/migration/v0-v1/sample-migration.conf b/data/migration/v0-v1/sample-migration.conf new file mode 100644 index 0000000..d4e80b4 --- /dev/null +++ b/data/migration/v0-v1/sample-migration.conf @@ -0,0 +1,13 @@ +# Copy this file to `migration.conf`. +# Don't change any values unless the database files have been moved. +# All paths are relative to `bot/data`. +[Original] +settings_db = config_data.db +session_db = sessions.db + +# This can also be None if the session data doesn't need to be transferred +savefile = timerstatus.json + +[Target] +database = data.db +savefile = timerstatus/timerstatus.pickle diff --git a/data/migration/v0-v1/v1_schema.sql b/data/migration/v0-v1/v1_schema.sql new file mode 100644 index 0000000..ad151c5 --- /dev/null +++ b/data/migration/v0-v1/v1_schema.sql @@ -0,0 +1,157 @@ +PRAGMA foreign_keys = 1; + +--Meta +CREATE TABLE VersionHistory ( + version INTEGER NOT NULL, + time INTEGER NOT NULL, + author TEXT +); + +INSERT INTO VersionHistory VALUES (1, strftime('%s', 'now'), 'Initial Creation'); + + +-- Guild configuration +CREATE TABLE guilds ( + guildid INTEGER NOT NULL PRIMARY KEY, + timer_admin_roleid INTEGER, + show_tips BOOLEAN, + globalgroups BOOLEAN, + autoclean INTEGER, + studyrole_roleid INTEGER, + timezone TEXT, + prefix TEXT +); + + +-- User configuration +CREATE TABLE users ( + userid INTEGER NOT NULL PRIMARY KEY, + notify_level INTEGER, + timezone TEXT, + name TEXT +); + + +-- Timer patterns +CREATE TABLE patterns ( + patternid INTEGER PRIMARY KEY AUTOINCREMENT, + short_repr BOOL NOT NULL, + stage_str TEXT NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')) +); +CREATE UNIQUE INDEX pattern_strings ON patterns(stage_str); +INSERT INTO patterns(patternid, short_repr, stage_str) + VALUES (0, 1, '[["Study \ud83d\udd25", 25, "Good luck!", false], ["Break\ud83c\udf1b", 5, "Have a rest.", false], ["Study \ud83d\udd25", 25, "Good luck!", false], ["Break \ud83c\udf1c", 5, "Have a rest.", false], ["Study \ud83d\udd25", 25, "Good luck!", false], ["Long Break \ud83c\udf1d", 10, "Have a rest.", false]]'); + + + +-- Timer pattern presets +CREATE TABLE user_presets ( + userid INTEGER NOT NULL, + preset_name TEXT NOT NULL COLLATE NOCASE, + patternid INTEGER NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE CASCADE +); +CREATE UNIQUE INDEX user_preset_names ON user_presets(userid, preset_name); + +CREATE TABLE guild_presets ( + guildid INTEGER NOT NULL, + preset_name TEXT NOT NULL COLLATE NOCASE, + created_by INTEGER NOT NULL, + patternid INTEGER NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE CASCADE +); +CREATE UNIQUE INDEX guild_preset_names ON guild_presets(guildid, preset_name); + +CREATE VIEW user_preset_patterns AS + SELECT + userid, + preset_name, + patternid, + patterns.stage_str AS preset_string + FROM user_presets + INNER JOIN patterns USING (patternid); + +CREATE VIEW guild_preset_patterns AS + SELECT + guildid, + preset_name, + created_by, + patternid, + patterns.stage_str AS preset_string + FROM guild_presets + INNER JOIN patterns USING (patternid); + + +-- Timers +CREATE TABLE timers ( + roleid INTEGER NOT NULL PRIMARY KEY, + guildid INTEGER NOT NULL, + name TEXT NOT NULL, + channelid INTEGER NOT NULL, + patternid INTEGER NOT NULL DEFAULT 0, + brief BOOLEAN, + voice_channelid INTEGER, + voice_alert BOOLEAN, + track_voice_join BOOLEAN, + track_voice_leave BOOLEAN, + auto_reset BOOLEAN, + admin_locked BOOLEAN, + track_role BOOLEAN, + compact BOOLEAN, + voice_channel_name TEXT, + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE SET NULL +); + + +CREATE TABLE timer_pattern_history ( + timerid INTEGER NOT NULL, + patternid INTEGER NOT NULL, + modified_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + modified_by INTEGER, + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE CASCADE, + FOREIGN KEY (timerid) REFERENCES timers (roleid) ON DELETE CASCADE +); + +CREATE INDEX idx_timerid_modified_at on timer_pattern_history (timerid, modified_at); + + +CREATE VIEW timer_patterns AS + SELECT * + FROM patterns + INNER JOIN timers USING (patternid); + + +CREATE VIEW current_timer_patterns AS + SELECT + timerid, + patternid, + max(modified_at) + FROM timer_pattern_history + GROUP BY timerid; + + +-- Session storage +CREATE TABLE sessions ( + guildid INTEGER NOT NULL, + userid INTEGER NOT NULL, + roleid INTEGER NOT NULL, + start_time INTEGER NOT NULL, + duration INTEGER NOT NULL, + focused_duration INTEGER, + patternid INTEGER, + stages TEXT, + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE SET NULL +); +CREATE INDEX idx_sessions_guildid_userid on sessions (guildid, userid); + +CREATE VIEW session_patterns AS + SELECT + *, + patterns.stage_str AS stage_str, + users.name AS user_name + FROM sessions + LEFT JOIN patterns USING (patternid) + LEFT JOIN users USING (userid); diff --git a/data/schema.sql b/data/schema.sql new file mode 100644 index 0000000..ad151c5 --- /dev/null +++ b/data/schema.sql @@ -0,0 +1,157 @@ +PRAGMA foreign_keys = 1; + +--Meta +CREATE TABLE VersionHistory ( + version INTEGER NOT NULL, + time INTEGER NOT NULL, + author TEXT +); + +INSERT INTO VersionHistory VALUES (1, strftime('%s', 'now'), 'Initial Creation'); + + +-- Guild configuration +CREATE TABLE guilds ( + guildid INTEGER NOT NULL PRIMARY KEY, + timer_admin_roleid INTEGER, + show_tips BOOLEAN, + globalgroups BOOLEAN, + autoclean INTEGER, + studyrole_roleid INTEGER, + timezone TEXT, + prefix TEXT +); + + +-- User configuration +CREATE TABLE users ( + userid INTEGER NOT NULL PRIMARY KEY, + notify_level INTEGER, + timezone TEXT, + name TEXT +); + + +-- Timer patterns +CREATE TABLE patterns ( + patternid INTEGER PRIMARY KEY AUTOINCREMENT, + short_repr BOOL NOT NULL, + stage_str TEXT NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')) +); +CREATE UNIQUE INDEX pattern_strings ON patterns(stage_str); +INSERT INTO patterns(patternid, short_repr, stage_str) + VALUES (0, 1, '[["Study \ud83d\udd25", 25, "Good luck!", false], ["Break\ud83c\udf1b", 5, "Have a rest.", false], ["Study \ud83d\udd25", 25, "Good luck!", false], ["Break \ud83c\udf1c", 5, "Have a rest.", false], ["Study \ud83d\udd25", 25, "Good luck!", false], ["Long Break \ud83c\udf1d", 10, "Have a rest.", false]]'); + + + +-- Timer pattern presets +CREATE TABLE user_presets ( + userid INTEGER NOT NULL, + preset_name TEXT NOT NULL COLLATE NOCASE, + patternid INTEGER NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE CASCADE +); +CREATE UNIQUE INDEX user_preset_names ON user_presets(userid, preset_name); + +CREATE TABLE guild_presets ( + guildid INTEGER NOT NULL, + preset_name TEXT NOT NULL COLLATE NOCASE, + created_by INTEGER NOT NULL, + patternid INTEGER NOT NULL, + created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE CASCADE +); +CREATE UNIQUE INDEX guild_preset_names ON guild_presets(guildid, preset_name); + +CREATE VIEW user_preset_patterns AS + SELECT + userid, + preset_name, + patternid, + patterns.stage_str AS preset_string + FROM user_presets + INNER JOIN patterns USING (patternid); + +CREATE VIEW guild_preset_patterns AS + SELECT + guildid, + preset_name, + created_by, + patternid, + patterns.stage_str AS preset_string + FROM guild_presets + INNER JOIN patterns USING (patternid); + + +-- Timers +CREATE TABLE timers ( + roleid INTEGER NOT NULL PRIMARY KEY, + guildid INTEGER NOT NULL, + name TEXT NOT NULL, + channelid INTEGER NOT NULL, + patternid INTEGER NOT NULL DEFAULT 0, + brief BOOLEAN, + voice_channelid INTEGER, + voice_alert BOOLEAN, + track_voice_join BOOLEAN, + track_voice_leave BOOLEAN, + auto_reset BOOLEAN, + admin_locked BOOLEAN, + track_role BOOLEAN, + compact BOOLEAN, + voice_channel_name TEXT, + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE SET NULL +); + + +CREATE TABLE timer_pattern_history ( + timerid INTEGER NOT NULL, + patternid INTEGER NOT NULL, + modified_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')), + modified_by INTEGER, + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE CASCADE, + FOREIGN KEY (timerid) REFERENCES timers (roleid) ON DELETE CASCADE +); + +CREATE INDEX idx_timerid_modified_at on timer_pattern_history (timerid, modified_at); + + +CREATE VIEW timer_patterns AS + SELECT * + FROM patterns + INNER JOIN timers USING (patternid); + + +CREATE VIEW current_timer_patterns AS + SELECT + timerid, + patternid, + max(modified_at) + FROM timer_pattern_history + GROUP BY timerid; + + +-- Session storage +CREATE TABLE sessions ( + guildid INTEGER NOT NULL, + userid INTEGER NOT NULL, + roleid INTEGER NOT NULL, + start_time INTEGER NOT NULL, + duration INTEGER NOT NULL, + focused_duration INTEGER, + patternid INTEGER, + stages TEXT, + FOREIGN KEY (patternid) REFERENCES patterns (patternid) ON DELETE SET NULL +); +CREATE INDEX idx_sessions_guildid_userid on sessions (guildid, userid); + +CREATE VIEW session_patterns AS + SELECT + *, + patterns.stage_str AS stage_str, + users.name AS user_name + FROM sessions + LEFT JOIN patterns USING (patternid) + LEFT JOIN users USING (userid); diff --git a/data/timerstatus/.gitignore b/data/timerstatus/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/requirements.txt b/requirements.txt index 844f49a..dbfe9af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,4 @@ discord.py +pyNacl +cachetools +pytz