Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bot/cogs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ def __str__(self) -> str:

EXTENSIONS = [module.name for module in iter_modules(__path__, f"{__package__}.")]
VERSION: VersionInfo = VersionInfo(major=0, minor=3, micro=1, releaselevel="final")

del Literal, NamedTuple
119 changes: 1 addition & 118 deletions bot/cogs/admin.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
from __future__ import annotations

import asyncio
import importlib
import os
import re
import subprocess # nosec # We already know this is dangerous, but it's needed
import sys
from typing import TYPE_CHECKING, Literal, Optional

import discord
from discord.ext import commands
from discord.ext.commands import Greedy

if TYPE_CHECKING:
from libs.utils import RoboContext
from utils.context import RoboContext

from bot.rodhaj import Rodhaj

GIT_PULL_REGEX = re.compile(r"\s+(?P<filename>.*)\b\s+\|\s+[\d]")


class Admin(commands.Cog, command_attrs=dict(hidden=True)):
"""Administrative commands for Rodhaj"""
Expand All @@ -33,78 +25,6 @@ def display_emoji(self) -> discord.PartialEmoji:
async def cog_check(self, ctx: RoboContext) -> bool:
return await self.bot.is_owner(ctx.author)

async def reload_or_load_extension(self, module: str) -> None:
try:
await self.bot.reload_extension(module)
except commands.ExtensionNotLoaded:
await self.bot.load_extension(module)

def find_modules_from_git(self, output: str) -> list[tuple[int, str]]:
files = GIT_PULL_REGEX.findall(output)
ret: list[tuple[int, str]] = []
for file in files:
root, ext = os.path.splitext(file)
if ext != ".py" or root.endswith("__init__"):
continue

true_root = ".".join(root.split("/")[1:])

if true_root.startswith("cogs") or true_root.startswith("libs"):
# A subdirectory within these are a part of the codebase

ret.append((true_root.count(".") + 1, true_root))

# For reload order, the submodules should be reloaded first
ret.sort(reverse=True)
return ret

async def run_process(self, command: str) -> list[str]:
process = await asyncio.create_subprocess_shell(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
result = await process.communicate()

return [output.decode() for output in result]

def tick(self, opt: Optional[bool], label: Optional[str] = None) -> str:
lookup = {
True: "\U00002705",
False: "\U0000274c",
None: "\U000023e9",
}
emoji = lookup.get(opt, "\U0000274c")
if label is not None:
return f"{emoji}: {label}"
return emoji

def format_results(self, statuses: list) -> str:
desc = "\U00002705 - Successful reload | \U0000274c - Failed reload | \U000023e9 - Skipped\n\n"
status = "\n".join(f"- {status}: `{module}`" for status, module in statuses)
desc += status
return desc

async def reload_exts(self, module: str) -> list[tuple[str, str]]:
statuses = []
try:
await self.reload_or_load_extension(module)
statuses.append((self.tick(True), module))
except commands.ExtensionError:
statuses.append((self.tick(False), module))

return statuses

def reload_lib_modules(self, module: str) -> list[tuple[str, str]]:
statuses = []
try:
actual_module = sys.modules[module]
importlib.reload(actual_module)
statuses.append((self.tick(True), module))
except KeyError:
statuses.append((self.tick(None), module))
except Exception:
statuses.append((self.tick(False), module))
return statuses

# Umbra's sync command
# To learn more about it, see the link below (and ?tag ass on the dpy server):
# https://about.abstractumbra.dev/discord.py/2023/01/29/sync-command-example.html
Expand Down Expand Up @@ -147,43 +67,6 @@ async def sync(

await ctx.send(f"Synced the tree to {ret}/{len(guilds)}.")

@commands.command(name="reload-all", hidden=True)
async def reload(self, ctx: RoboContext) -> None:
"""Reloads all cogs and utils"""
async with ctx.typing():
stdout, _ = await self.run_process("git pull")

# progress and stuff is redirected to stderr in git pull
# however, things like "fast forward" and files
# along with the text "already up-to-date" are in stdout

if stdout.startswith("Already up-to-date."):
await ctx.send(stdout)
return

modules = self.find_modules_from_git(stdout)

mods_text = "\n".join(
f"{index}. `{module}`" for index, (_, module) in enumerate(modules, start=1)
)
prompt_text = (
f"This will update the following modules, are you sure?\n{mods_text}"
)

confirm = await ctx.prompt(prompt_text)
if not confirm:
await ctx.send("Aborting....")
return

statuses = []
for is_submodule, module in modules:
if is_submodule:
statuses = self.reload_lib_modules(module)
else:
statuses = await self.reload_exts(module)

await ctx.send(self.format_results(statuses))


async def setup(bot: Rodhaj) -> None:
await bot.add_cog(Admin(bot))
Loading