diff --git a/app.py b/app.py index 31f89e1..188430e 100644 --- a/app.py +++ b/app.py @@ -3,12 +3,13 @@ from datetime import datetime import pytz -from py_tools.connections.db.mysql import DBManager, BaseOrmTable, SQLAlchemyManager +from py_tools.connections.db.mysql import DBManager, BaseOrmTable, \ + SQLAlchemyManager from sqlalchemy import text from sqlalchemy.ext.asyncio import create_async_engine +from bot.command import CommandHandler from bot.bot_client import BotClient -from bot.commands import CommandHandler from config import config from core.emby_api import EmbyApi, EmbyRouterAPI from services import UserService @@ -20,7 +21,8 @@ async def create_database_if_not_exists() -> None: """创建数据库。""" engine_without_db = create_async_engine( - f"mysql+asyncmy://{config.db_user}:{config.db_pass}@{config.db_host}:{config.db_port}/", + f"mysql+asyncmy://{config.db_user}:{config.db_pass}@" + f"{config.db_host}:{config.db_port}/", echo=True, ) async with engine_without_db.begin() as conn: @@ -69,6 +71,20 @@ def _init_logger() -> None: file_handler.setFormatter(formatter) logger.addHandler(file_handler) + # 创建 logger 并设置级别 + logger_i = logging.getLogger() + logger_i.setLevel(config.log_level) + + # 文件处理器,记录到 default.log + file_handler = logging.FileHandler("default.log") + file_handler.setFormatter(formatter) + logger_i.addHandler(file_handler) + + # 控制台处理器,打印到终端 + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger_i.addHandler(console_handler) + def _init_tz() -> None: """初始化时区设置。""" @@ -97,7 +113,8 @@ async def setup_bot() -> BotClient: async def fetch_group_members(bot_client: BotClient) -> None: """获取群组成员并更新配置。""" - members_in_group = await bot_client.get_group_members(config.telegram_group_ids) + members_in_group = await bot_client.get_group_members( + config.telegram_group_ids) for group_members in members_in_group.values(): for telegram_id in group_members: config.group_members[telegram_id] = group_members[telegram_id] @@ -121,9 +138,10 @@ async def main() -> None: # 初始化 Emby API 和命令处理器 emby_api = EmbyApi(config.emby_url, config.emby_api) emby_router_api = EmbyRouterAPI(config.api_url, config.api_key) - command_handler = CommandHandler( + CommandHandler( bot_client=bot_client, - user_service=UserService(emby_api=emby_api, emby_router_api=emby_router_api), + user_service=UserService(emby_api=emby_api, + emby_router_api=emby_router_api), ) logger.info("Emby API 和命令处理器初始化完成。") @@ -133,7 +151,6 @@ async def main() -> None: logger.info("群组成员信息已更新。") # 设置命令并进入空闲状态 - command_handler.setup_commands() logger.info("命令处理器设置完成,Bot 进入运行状态。") await bot_client.idle() diff --git a/bot/__init__.py b/bot/__init__.py index 7e6f38c..01db925 100644 --- a/bot/__init__.py +++ b/bot/__init__.py @@ -1,17 +1,17 @@ import logging +from bot.utils.filters import user_in_group_on_filter, admin_user_on_filter, \ + emby_user_on_filter from .bot_client import BotClient -from .commands import CommandHandler -from .filters import user_in_group_on_filter, admin_user_on_filter, emby_user_on_filter -from .message_helper import get_user_telegram_id +from .command import CommandHandler from .utils import parse_iso8601_to_normal_date, parse_timestamp_to_normal_date +from .utils.message_helper import get_user_telegram_id logger = logging.getLogger(__name__) logger.info("Bot module initialized") __all__ = [ "BotClient", - "CommandHandler", "user_in_group_on_filter", "admin_user_on_filter", "emby_user_on_filter", diff --git a/bot/bot_client.py b/bot/bot_client.py index 971601f..0c027d8 100644 --- a/bot/bot_client.py +++ b/bot/bot_client.py @@ -7,11 +7,11 @@ class BotClient: def __init__( - self, - api_id: str, - api_hash: str, - bot_token: str, - name="emby_bot", + self, + api_id: str, + api_hash: str, + bot_token: str, + name="emby_bot", ): self.client = Client( name=name, api_id=api_id, api_hash=api_hash, bot_token=bot_token diff --git a/bot/command/__init__.py b/bot/command/__init__.py new file mode 100644 index 0000000..5c0637a --- /dev/null +++ b/bot/command/__init__.py @@ -0,0 +1,25 @@ +import logging + +from bot import BotClient +from bot.command.admin_command import AdminCommandHandler +from bot.command.event_command import EventHandler +from bot.command.user_command import UserCommandHandler +from bot.command_router import setup_command_routes +from services import UserService + +logger = logging.getLogger(__name__) + + +class CommandHandler: + def __init__(self, bot_client: BotClient, user_service: UserService): + self.bot_client = bot_client + self.user_service = user_service + self.code_to_message_id = {} + self.user_command_handler = UserCommandHandler(bot_client, + user_service) + self.admin_command_handler = AdminCommandHandler(bot_client, + user_service) + self.event_handler = EventHandler(bot_client, user_service) + setup_command_routes(bot_client, self.user_command_handler, + self.admin_command_handler, self.event_handler) + logger.info("CommandHandler initialized") diff --git a/bot/command/admin_command.py b/bot/command/admin_command.py new file mode 100644 index 0000000..54028f1 --- /dev/null +++ b/bot/command/admin_command.py @@ -0,0 +1,180 @@ +import logging +from datetime import datetime + +from pyrogram.enums import ParseMode +from pyrogram.types import Message + +from bot import BotClient +from bot.utils import with_parsed_args, reply_html, send_error, \ + with_ensure_args +from bot.utils.message_helper import get_user_telegram_id +from services import UserService + +logger = logging.getLogger(__name__) + + +class AdminCommandHandler: + def __init__(self, bot_client: BotClient, user_service: UserService): + self.bot_client = bot_client + self.user_service = user_service + self.code_to_message_id = {} + logger.info("AdminCommandHandler initialized") + + @with_parsed_args + async def new_code(self, message: Message, args: list[str]): + """ + /new_code [数量] + """ + num = 1 + if args: + try: + num = int(args[0]) + except ValueError: + return await reply_html(message, + "❌ 请输入有效数量 /new_code [整数]") + + num = min(num, 20) + try: + code_list = await ( + self.user_service + .create_invite_code(message.from_user.id, num) + ) + await self.send_code(code_list, message) + except Exception as e: + await send_error(message, e, prefix="创建邀请码失败") + + @with_parsed_args + async def new_whitelist_code(self, message: Message, args: list[str]): + """ + /new_whitelist_code [数量] + """ + num = 1 + if args: + try: + num = int(args[0]) + except ValueError: + return await reply_html( + message, + "❌ 请输入有效数量 /new_whitelist_code [整数]") + + num = min(num, 20) + try: + code_list = await self.user_service.create_whitelist_code( + message.from_user.id, num) + await self.send_code(code_list, message, whitelist=True) + except Exception as e: + await send_error(message, e, prefix="创建白名单邀请码失败") + + async def send_code(self, code_list, message, whitelist: bool = False): + if whitelist: + base_text = "📌 白名单邀请码:\n点击复制👉" + else: + base_text = "📌 邀请码:\n点击复制👉" + for code_obj in code_list: + # 每次用 base_text 重置消息文本哦~ + message_text = f"{base_text}{code_obj.code}" + if message.reply_to_message is not None: + await self.bot_client.client.send_message( + chat_id=message.from_user.id, + text=message_text, + parse_mode=ParseMode.HTML, + ) + await self.bot_client.client.send_message( + chat_id=message.reply_to_message.from_user.id, + text=message_text, + parse_mode=ParseMode.HTML, + ) + await reply_html(message, "✅ 已发送邀请码") + else: + msg = await reply_html( + message, + message_text + ) + self.code_to_message_id[code_obj.code] = ( + message.chat.id, msg.id + ) + + @with_parsed_args + async def ban_emby(self, message: Message, args: list[str]): + """ + /ban_emby [原因] (群里需回复某人或手动指定) + """ + reason = args[0] if args else "管理员禁用" + + operator_id = message.from_user.id + telegram_id = await get_user_telegram_id(self.bot_client.client, + message) + try: + if await self.user_service.emby_ban(telegram_id, reason, + operator_id): + await reply_html( + message, + f"✅ 已禁用用户 {telegram_id} 的Emby账号" + ) + else: + await reply_html(message, "❌ 禁用失败,请稍后重试。") + except Exception as e: + await send_error(message, e, prefix="禁用失败") + + async def unban_emby(self, message: Message): + """ + /unban_emby (群里需回复某人或手动指定) + """ + operator_id = message.from_user.id + telegram_id = await get_user_telegram_id(self.bot_client.client, + message) + try: + if await self.user_service.emby_unban(telegram_id, operator_id): + await reply_html( + message, + f"✅ 已解禁用户 {telegram_id} 的Emby账号" + ) + else: + await reply_html(message, "❌ 解禁失败,请稍后重试。") + except Exception as e: + await send_error(message, e, prefix="解禁失败") + + @with_parsed_args + @with_ensure_args(2, "/register_until 2023-10-01 12:00:00") + async def register_until(self, message: Message, args: list[str]): + """ + /register_until <时间: YYYY-MM-DD HH:MM:SS> + 限时开放注册 + """ + time_str = " ".join(args) + try: + time = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S") + now = datetime.now() + if time < now: + return await reply_html(message, "❌ 时间必须晚于当前时间") + + await self.user_service.set_emby_config( + message.from_user.id, + register_public_time=int(time.timestamp()) + ) + await reply_html( + message, + f"✅ 已开放注册,截止时间:{time_str}" + ) + except Exception as e: + await send_error(message, e, prefix="开放注册失败") + + @with_parsed_args + @with_ensure_args(1, "/register_amount <人数>") + async def register_amount(self, message: Message, args: list[str]): + """ + /register_amount <人数> + 开放指定数量的注册名额 + """ + try: + amount = int(args[0]) + await self.user_service.set_emby_config( + message.from_user.id, + register_public_user=amount + ) + await reply_html( + message, + f"✅ 已开放注册,名额:{amount}" + ) + except Exception as e: + await send_error(message, e, prefix="开放注册失败") diff --git a/bot/command/event_command.py b/bot/command/event_command.py new file mode 100644 index 0000000..04f8f1c --- /dev/null +++ b/bot/command/event_command.py @@ -0,0 +1,66 @@ +import logging + +from pyrogram.types import Message, CallbackQuery + +from bot import BotClient +from config import config +from services import UserService + +logger = logging.getLogger(__name__) + + +class EventHandler: + def __init__(self, bot_client: BotClient, user_service: UserService): + self.bot_client = bot_client + self.user_service = user_service + self.code_to_message_id = {} + logger.info("EventHandler initialized") + + async def handle_callback_query(self, _, + callback_query: CallbackQuery): + """ + 回调按钮事件统一处理,如切换线路。 + """ + data = callback_query.data.split('_') + if data[0] == 'SELECTROUTE': + index = data[1] + try: + if not config.router_list: + await callback_query.answer("尚未加载线路列表,请稍后重试") + return + + selected_router = next( + (r for r in config.router_list if r['index'] == index), + None) + if not selected_router: + await callback_query.answer("线路不存在") + return + + await self.user_service.update_user_router( + callback_query.from_user.id, index) + await callback_query.answer("线路已更新") + await callback_query.message.edit( + f"已选择 {selected_router['name']}\n" + "生效可能会有 30 秒延迟,请耐心等候。" + ) + except Exception as e: + await callback_query.answer(f"操作失败:{str(e)}", + show_alert=True) + logger.error(f"Callback query failed: {e}", exc_info=True) + + async def group_member_change_handler(self, _, message: Message): + """ + 群组成员变动处理器。 + """ + if message.left_chat_member: + left_member_id = message.left_chat_member.id + left_member = await self.user_service.must_get_user(left_member_id) + if (left_member.has_emby_account() + and not left_member.is_emby_baned() + and not left_member.is_whitelist): + await self.user_service.emby_ban(message.left_chat_member.id, + "用户已退出群组") + config.group_members.pop(message.left_chat_member.id, None) + if message.new_chat_members: + for new_member in message.new_chat_members: + config.group_members[new_member.id] = new_member diff --git a/bot/command/user_command.py b/bot/command/user_command.py new file mode 100644 index 0000000..847e9a1 --- /dev/null +++ b/bot/command/user_command.py @@ -0,0 +1,233 @@ +import logging +from datetime import datetime + +from pyrogram.types import Message, InlineKeyboardButton, InlineKeyboardMarkup + +from bot import BotClient +from bot.utils import reply_html, send_error, parse_iso8601_to_normal_date, \ + with_parsed_args, with_ensure_args +from bot.utils.message_helper import get_user_telegram_id +from config import config +from models.invite_code_model import InviteCodeType +from services import UserService + +logger = logging.getLogger(__name__) + + +class UserCommandHandler: + def __init__(self, bot_client: BotClient, user_service: UserService): + self.bot_client = bot_client + self.user_service = user_service + self.code_to_message_id = {} + logger.info("UserCommandHandler initialized") + + async def count(self, message: Message): + """ + /count + 查询服务器内片子数量 + """ + try: + count_data = self.user_service.emby_count() + if not count_data: + return await reply_html(message, "❌ 查询失败:无法获取数据") + + await reply_html( + message, + ( + f"🎬 电影数量:" + f"{count_data.get('MovieCount', 0)}" + f"\n" + f"📽️ 剧集数量:" + f"{count_data.get('SeriesCount', 0)}" + f"\n" + f"🎞️ 总集数:" + f"{count_data.get('EpisodeCount', 0)}" + f"\n" + ) + ) + except Exception as e: + await send_error(message, e, prefix="查询失败") + + async def info(self, message: Message): + """ + /info + 如果是私聊,查看自己信息;如果群里回复某人,则查看对方信息 + """ + telegram_id = await get_user_telegram_id(self.bot_client.client, + message) + try: + user, emby_info = await self.user_service.emby_info(telegram_id) + last_active = ( + parse_iso8601_to_normal_date(emby_info.get("LastActivityDate")) + if emby_info.get("LastActivityDate") else "无") + date_created = parse_iso8601_to_normal_date( + emby_info.get("DateCreated", "")) + ban_status = "正常" if ( + user.ban_time is None or user.ban_time == 0) else "已禁用" + + reply_text = ( + f"👤 用户信息:\n" + f"• Emby用户名:{user.emby_name}\n" + f"• 上次活动时间:{last_active}\n" + f"• 创建时间:{date_created}\n" + f"• 白名单:{'是' if user.is_whitelist else '否'}\n" + f"• 管理员:{'是' if user.is_admin else '否'}\n" + f"• 账号状态:{ban_status}\n" + ) + + if user.ban_time and user.ban_time > 0: + ban_time = datetime.fromtimestamp(user.ban_time).strftime( + '%Y-%m-%d %H:%M:%S') + reply_text += f"• 被ban时间:{ban_time}\n" + if user.reason: + reply_text += f"• 被ban原因:{user.reason}\n" + + await reply_html(message, reply_text) + except Exception as e: + await send_error(message, e, prefix="查询失败") + + @with_parsed_args + @with_ensure_args(1, "/use_code <邀请码>") + async def use_code(self, message: Message, args: list[str]): + """ + /use_code <邀请码> + """ + code = args[0] + telegram_id = message.from_user.id + try: + used_code = await self.user_service.redeem_code(telegram_id, code) + if not used_code: + return await reply_html(message, "❌ 邀请码使用失败") + # 根据类型给出不同的回复 + if used_code.code_type == InviteCodeType.REGISTER: + await reply_html(message, + "✅ 邀请码使用成功,您已获得创建账号资格") + else: + await reply_html(message, + "✅ 邀请码使用成功,您已获得白名单资格") + + # 如果该邀请码在bot中记录了消息,需要删除 + if self.code_to_message_id.get(code): + code_to_message_id = self.code_to_message_id[code] + await ( + self.bot_client + .client.delete_messages( + code_to_message_id[0], + code_to_message_id[1]) + ) + del self.code_to_message_id[code] + except Exception as e: + await send_error(message, e, prefix="邀请码使用失败") + + async def select_line(self, message: Message): + """ + /select_line + 用户选择线路(将返回可选线路按钮)。 + """ + try: + telegram_id = message.from_user.id + router_list = ( + config.router_list or + await self.user_service.get_router_list(telegram_id) + ) + # 缓存到 config 中,减少重复获取 + if router_list and not config.router_list: + config.router_list = router_list + + user_router = await self.user_service.get_user_router(telegram_id) + user_router_index = user_router.get('index', '') + message_text = f"当前线路:{user_router_index}\n请选择线路:" + message_buttons = [] + + for router in router_list: + index = router.get('index') + name = router.get('name') + # 已选线路高亮 + button_text = f"🔵 {name}" if index == user_router_index \ + else f"⚪ {name}" + ( + message_buttons + .append( + [InlineKeyboardButton( + button_text, + callback_data=f"SELECTROUTE_{index}")] + ) + ) + + keyboard = InlineKeyboardMarkup(message_buttons) + await reply_html(message, message_text, reply_markup=keyboard) + except Exception as e: + await send_error(message, e, prefix="查询失败") + + @with_parsed_args + @with_ensure_args(1, "/create <用户名>") + async def create_user(self, message: Message, args: list[str]): + """ + /create <用户名> + """ + emby_name = args[0] + try: + default_password = self.user_service.gen_default_passwd() + user = await ( + self.user_service.emby_create_user( + message.from_user.id, emby_name, default_password + ) + ) + if user and user.has_emby_account(): + await reply_html( + message, + f"✅ 创建用户成功。\n初始密码:{default_password}" + ) + else: + await reply_html(message, "❌ 创建用户失败,请稍后重试。") + except Exception as e: + await send_error(message, e, prefix="创建用户失败") + + async def reset_emby_password(self, message: Message): + """ + /reset_emby_password + """ + default_password = self.user_service.gen_default_passwd() + try: + if await ( + self.user_service + .reset_password( + message.from_user.id, default_password + ) + ): + await reply_html( + message, + f"✅ 密码重置成功。\n新密码:{default_password}" + ) + else: + await reply_html(message, "❌ 密码重置失败,请稍后重试。") + except Exception as e: + await send_error(message, e, prefix="密码重置失败") + + async def help_command(self, message: Message): + """ + /help 或 /start + 查看命令帮助。 + """ + help_message = ( + "用户命令:\n" + "/use_code [code] - 使用邀请码获取创建账号资格\n" + "/create [username] - 创建Emby用户 (英文/下划线, 至少5位)\n" + "/info - 查看用户信息(私聊查看自己的,群里可回复他人)\n" + "/select_line - 选择线路\n" + "/reset_emby_password - 重置Emby账号密码\n" + "/count - 查看服务器内影片数量\n" + "/help - 显示本帮助\n" + ) + if await self.user_service.is_admin(message.from_user.id): + help_message += ( + "\n管理命令:\n" + "/new_code [数量] - 创建新的普通邀请码\n" + "/new_whitelist_code [数量] - 创建新的白名单邀请码\n" + "/register_until [YYYY-MM-DD HH:MM:SS] - 限时开放注册\n" + "/register_amount [人数] - 开放指定注册名额\n" + "/info (群里回复某人) - 查看他人信息\n" + "/ban_emby [原因] - 禁用某用户的Emby账号\n" + "/unban_emby - 解禁某用户的Emby账号\n" + ) + await reply_html(message, help_message) diff --git a/bot/command_router.py b/bot/command_router.py new file mode 100644 index 0000000..06f3e6f --- /dev/null +++ b/bot/command_router.py @@ -0,0 +1,80 @@ +from pyrogram import filters + +from bot import BotClient +from bot.command.admin_command import AdminCommandHandler +from bot.command.event_command import EventHandler +from bot.command.user_command import UserCommandHandler +from bot.utils.filters import user_in_group_on_filter, emby_user_on_filter, \ + admin_user_on_filter + + +def setup_command_routes(bot_client: BotClient, + user_command_handler: UserCommandHandler, + admin_command_handler: AdminCommandHandler, + event_handler: EventHandler): + # 定义命令配置,每项为 (命令, 过滤器, 处理函数) + command_definitions = [ + ( + ["help", "start"], + filters.private, + user_command_handler.help_command + ), + ("count", user_in_group_on_filter, user_command_handler.count), + ("info", user_in_group_on_filter, user_command_handler.info), + ("use_code", filters.private & user_in_group_on_filter, + user_command_handler.use_code), + ("create", filters.private & user_in_group_on_filter, + user_command_handler.create_user), + ("reset_emby_password", + filters.private & user_in_group_on_filter & emby_user_on_filter, + user_command_handler.reset_emby_password), + ("select_line", + filters.private & user_in_group_on_filter & emby_user_on_filter, + user_command_handler.select_line), + ("new_code", admin_user_on_filter, admin_command_handler.new_code), + ("new_whitelist_code", admin_user_on_filter, + admin_command_handler.new_whitelist_code), + ("ban_emby", admin_user_on_filter, admin_command_handler.ban_emby), + ("unban_emby", admin_user_on_filter, admin_command_handler.unban_emby), + ("register_until", admin_user_on_filter, + admin_command_handler.register_until), + ("register_amount", admin_user_on_filter, + admin_command_handler.register_amount), + ] + + # 循环注册消息处理器 + for cmd, f, func in command_definitions: + if isinstance(cmd, list): + # 对于多个命令,一般只用于私聊 + def make_handler(func_=func, f_=f, cmd_=None): + if cmd_ is None: + cmd_ = cmd + + @bot_client.client.on_message( + filters.private & filters.command(cmd_) & f_) + async def handler(_, message): + await func_(message) + + return handler + + make_handler() + else: + def make_handler(func_=func, f_=f, cmd_=cmd): + @bot_client.client.on_message(filters.command(cmd_) & f_) + async def handler(_, message): + await func_(message) + + return handler + + make_handler() + + # 注册回调查询处理器 + @bot_client.client.on_callback_query() + async def c_select_line_cb(client, callback_query): + await event_handler.handle_callback_query(client, callback_query) + + # 注册群组成员变动处理器 + @bot_client.client.on_message( + filters.left_chat_member | filters.new_chat_members) + async def group_member_change_handler(client, message): + await event_handler.group_member_change_handler(client, message) diff --git a/bot/commands.py b/bot/commands.py deleted file mode 100644 index c861633..0000000 --- a/bot/commands.py +++ /dev/null @@ -1,589 +0,0 @@ -import logging -import functools -from datetime import datetime - -from pyrogram import filters -from pyrogram.enums import ParseMode -from pyrogram.types import ( - InlineKeyboardButton, - InlineKeyboardMarkup, - CallbackQuery, - Message, -) - -from bot.bot_client import BotClient -from bot.filters import ( - user_in_group_on_filter, - admin_user_on_filter, - emby_user_on_filter, -) -from bot.message_helper import get_user_telegram_id -from bot.utils import parse_iso8601_to_normal_date -from config import config -from models.invite_code_model import InviteCodeType -from services import UserService - -logger = logging.getLogger(__name__) - - -class CommandHandler: - def __init__(self, bot_client: BotClient, user_service: UserService): - self.bot_client = bot_client - self.user_service = user_service - self.code_to_message_id = {} - logger.info("CommandHandler initialized") - - # =============== 辅助方法 =============== - - @staticmethod - async def _reply_html(message: Message, text: str, **kwargs): - """ - 统一回复方法,使用 HTML parse_mode。 - """ - return await message.reply(text, parse_mode=ParseMode.HTML, **kwargs) - - @staticmethod - def _parse_args(message: Message) -> list[str]: - """ - 将用户输入拆分为命令 + 参数列表,如: - '/create testuser' -> ['testuser'] - """ - parts = message.text.strip().split(" ") - return parts[1:] if len(parts) > 1 else [] - - @staticmethod - def ensure_args(min_len: int, usage: str): - """ - 装饰器:确保命令行参数长度足够,不足则回复用法说明。 - """ - - def decorator(func): - @functools.wraps(func) - async def wrapper(self, message, *args, **kwargs): - # 从消息中解析参数 - parsed_args = self._parse_args(message) - if len(parsed_args) < min_len: - await self._reply_html( - message, f"参数不足,请参考用法:\n{usage}" - ) - return - # 将解析好的参数传递给目标函数,避免在函数内部再调用 _parse_args - return await func(self, message, parsed_args, *args, **kwargs) - - return wrapper - - return decorator - - async def _send_error( - self, message: Message, error: Exception, prefix: str = "操作失败" - ): - """ - 统一的异常捕获后回复方式。 - """ - logger.error(f"{prefix}:{error}", exc_info=True) - await self._reply_html(message, f"{prefix}:{error}") - - # =============== 各类命令逻辑 =============== - - @ensure_args(1, "/create <用户名>") - async def create_user(self, message: Message, args: list[str]): - """ - /create <用户名> - """ - - emby_name = args[0] - try: - default_password = self.user_service.gen_default_passwd() - user = await self.user_service.emby_create_user( - message.from_user.id, emby_name, default_password - ) - if user and user.has_emby_account(): - await self._reply_html( - message, - f"✅ 创建用户成功。\n初始密码:{default_password}", - ) - else: - await self._reply_html(message, "❌ 创建用户失败,请稍后重试。") - except Exception as e: - await self._send_error(message, e, prefix="创建用户失败") - - async def info(self, message: Message): - """ - /info - 如果是私聊,查看自己信息;如果群里回复某人,则查看对方信息 - """ - telegram_id = await get_user_telegram_id(self.bot_client.client, message) - try: - user, emby_info = await self.user_service.emby_info(telegram_id) - last_active = ( - parse_iso8601_to_normal_date(emby_info.get("LastActivityDate")) - if emby_info.get("LastActivityDate") - else "无" - ) - date_created = parse_iso8601_to_normal_date( - emby_info.get("DateCreated", "") - ) - ban_status = ( - "正常" if (user.ban_time is None or user.ban_time == 0) else "已禁用" - ) - - reply_text = ( - f"👤 用户信息:\n" - f"• Emby用户名:{user.emby_name}\n" - f"• 上次活动时间:{last_active}\n" - f"• 创建时间:{date_created}\n" - f"• 白名单:{'是' if user.is_whitelist else '否'}\n" - f"• 管理员:{'是' if user.is_admin else '否'}\n" - f"• 账号状态:{ban_status}\n" - ) - - if user.ban_time and user.ban_time > 0: - ban_time = datetime.fromtimestamp(user.ban_time).strftime( - "%Y-%m-%d %H:%M:%S" - ) - reply_text += f"• 被ban时间:{ban_time}\n" - if user.reason: - reply_text += f"• 被ban原因:{user.reason}\n" - - await self._reply_html(message, reply_text) - except Exception as e: - await self._send_error(message, e, prefix="查询失败") - - @ensure_args(1, "/use_code <邀请码>") - async def use_code(self, message: Message, args: list[str]): - """ - /use_code <邀请码> - """ - - code = args[0] - telegram_id = message.from_user.id - try: - used_code = await self.user_service.redeem_code(telegram_id, code) - if not used_code: - return await self._reply_html(message, "❌ 邀请码使用失败") - # 根据类型给出不同的回复 - if used_code.code_type == InviteCodeType.REGISTER: - await self._reply_html( - message, "✅ 邀请码使用成功,您已获得创建账号资格" - ) - else: - await self._reply_html(message, "✅ 邀请码使用成功,您已获得白名单资格") - - # 如果该邀请码在bot中记录了消息,需要删除 - if self.code_to_message_id.get(code): - code_to_message_id = self.code_to_message_id[code] - await self.bot_client.client.delete_messages( - code_to_message_id[0], code_to_message_id[1] - ) - del self.code_to_message_id[code] - except Exception as e: - await self._send_error(message, e, prefix="邀请码使用失败") - - async def reset_emby_password(self, message: Message): - """ - /reset_emby_password - """ - default_password = self.user_service.gen_default_passwd() - try: - if await self.user_service.reset_password( - message.from_user.id, default_password - ): - await self._reply_html( - message, - f"✅ 密码重置成功。\n新密码:{default_password}", - ) - else: - await self._reply_html(message, "❌ 密码重置失败,请稍后重试。") - except Exception as e: - await self._send_error(message, e, prefix="密码重置失败") - - async def new_code(self, message: Message): - """ - /new_code [数量] - """ - args = self._parse_args(message) - num = 1 - if args: - try: - num = int(args[0]) - except ValueError: - return await self._reply_html( - message, "❌ 请输入有效数量 /new_code [整数]" - ) - - num = min(num, 20) - try: - code_list = await self.user_service.create_invite_code( - message.from_user.id, num - ) - for code_obj in code_list: - message_text = f"📌 邀请码:\n点击复制👉{code_obj.code}" - if message.reply_to_message is not None: - await self.bot_client.client.send_message( - chat_id=message.from_user.id, - text=message_text, - parse_mode=ParseMode.HTML, - ) - await self.bot_client.client.send_message( - chat_id=message.reply_to_message.from_user.id, - text=message_text, - parse_mode=ParseMode.HTML, - ) - await self._reply_html(message, "✅ 已发送邀请码") - else: - msg = await self._reply_html(message, message_text) - self.code_to_message_id[code_obj.code] = (message.chat.id, msg.id) - except Exception as e: - await self._send_error(message, e, prefix="创建邀请码失败") - - async def new_whitelist_code(self, message: Message): - """ - /new_whitelist_code [数量] - """ - args = self._parse_args(message) - num = 1 - if args: - try: - num = int(args[0]) - except ValueError: - return await self._reply_html( - message, "❌ 请输入有效数量 /new_whitelist_code [整数]" - ) - - num = min(num, 20) - try: - code_list = await self.user_service.create_whitelist_code( - message.from_user.id, num - ) - for code_obj in code_list: - message_text = ( - f"📌 白名单邀请码:\n点击复制👉{code_obj.code}" - ) - if message.reply_to_message is not None: - await self.bot_client.client.send_message( - chat_id=message.from_user.id, - text=message_text, - parse_mode=ParseMode.HTML, - ) - await self.bot_client.client.send_message( - chat_id=message.reply_to_message.from_user.id, - text=message_text, - parse_mode=ParseMode.HTML, - ) - await self._reply_html(message, "✅ 已发送邀请码") - else: - msg = await self._reply_html(message, message_text) - self.code_to_message_id[code_obj.code] = (message.chat.id, msg.id) - except Exception as e: - await self._send_error(message, e, prefix="创建白名单邀请码失败") - - async def ban_emby(self, message: Message): - """ - /ban_emby [原因] (群里需回复某人或手动指定) - """ - args = self._parse_args(message) - reason = args[0] if args else "管理员禁用" - - operator_id = message.from_user.id - telegram_id = await get_user_telegram_id(self.bot_client.client, message) - try: - if await self.user_service.emby_ban(telegram_id, reason, operator_id): - await self._reply_html( - message, f"✅ 已禁用用户 {telegram_id} 的Emby账号" - ) - else: - await self._reply_html(message, "❌ 禁用失败,请稍后重试。") - except Exception as e: - await self._send_error(message, e, prefix="禁用失败") - - async def unban_emby(self, message: Message): - """ - /unban_emby (群里需回复某人或手动指定) - """ - operator_id = message.from_user.id - telegram_id = await get_user_telegram_id(self.bot_client.client, message) - try: - if await self.user_service.emby_unban(telegram_id, operator_id): - await self._reply_html( - message, f"✅ 已解禁用户 {telegram_id} 的Emby账号" - ) - else: - await self._reply_html(message, "❌ 解禁失败,请稍后重试。") - except Exception as e: - await self._send_error(message, e, prefix="解禁失败") - - async def select_line(self, message: Message): - """ - /select_line - 用户选择线路(将返回可选线路按钮)。 - """ - try: - telegram_id = message.from_user.id - router_list = config.router_list or await self.user_service.get_router_list( - telegram_id - ) - # 缓存到 config 中,减少重复获取 - if router_list and not config.router_list: - config.router_list = router_list - - user_router = await self.user_service.get_user_router(telegram_id) - user_router_index = user_router.get("index", "") - message_text = f"当前线路:{user_router_index}\n请选择线路:" - message_buttons = [] - - for router in router_list: - index = router.get("index") - name = router.get("name") - # 已选线路高亮 - button_text = ( - f"🔵 {name}" if index == user_router_index else f"⚪ {name}" - ) - message_buttons.append( - [ - InlineKeyboardButton( - button_text, callback_data=f"SELECTROUTE_{index}" - ) - ] - ) - - keyboard = InlineKeyboardMarkup(message_buttons) - await self._reply_html(message, message_text, reply_markup=keyboard) - except Exception as e: - await self._send_error(message, e, prefix="查询失败") - - async def group_member_change_handler(self, clent, message: Message): - """ - 群组成员变动处理器。 - """ - if message.left_chat_member: - left_member_id = message.left_chat_member.id - left_member = await self.user_service.must_get_user(left_member_id) - if ( - left_member.has_emby_account() - and not left_member.is_emby_baned() - and not left_member.is_whitelist - ): - await self.user_service.emby_ban( - message.left_chat_member.id, "用户已退出群组" - ) - config.group_members.pop(message.left_chat_member.id, None) - if message.new_chat_members: - for new_member in message.new_chat_members: - config.group_members[new_member.id] = new_member - - async def handle_callback_query(self, client, callback_query: CallbackQuery): - """ - 回调按钮事件统一处理,如切换线路。 - """ - data = callback_query.data.split("_") - if data[0] == "SELECTROUTE": - index = data[1] - try: - if not config.router_list: - await callback_query.answer("尚未加载线路列表,请稍后重试") - return - - selected_router = next( - (r for r in config.router_list if r["index"] == index), None - ) - if not selected_router: - await callback_query.answer("线路不存在") - return - - await self.user_service.update_user_router( - callback_query.from_user.id, index - ) - await callback_query.answer("线路已更新") - await callback_query.message.edit( - f"已选择 {selected_router['name']}\n" - "生效可能会有 30 秒延迟,请耐心等候。" - ) - except Exception as e: - await callback_query.answer(f"操作失败:{str(e)}", show_alert=True) - logger.error(f"Callback query failed: {e}", exc_info=True) - - async def count(self, message: Message): - """ - /count - 查询服务器内片子数量 - """ - try: - count_data = self.user_service.emby_count() - if not count_data: - return await self._reply_html(message, "❌ 查询失败:无法获取数据") - - await self._reply_html( - message, - ( - f"🎬 电影数量:{count_data.get('MovieCount', 0)}\n" - f"📽️ 剧集数量:{count_data.get('SeriesCount', 0)}\n" - f"🎞️ 总集数:{count_data.get('EpisodeCount', 0)}\n" - ), - ) - except Exception as e: - await self._send_error(message, e, prefix="查询失败") - - @ensure_args(2, "/register_until 2023-10-01 12:00:00") - async def register_until(self, message: Message, args: list[str]): - """ - /register_until <时间: YYYY-MM-DD HH:MM:SS> - 限时开放注册 - """ - - time_str = " ".join(args) - try: - time = datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S") - now = datetime.now() - if time < now: - return await self._reply_html(message, "❌ 时间必须晚于当前时间") - - await self.user_service.set_emby_config( - message.from_user.id, register_public_time=int(time.timestamp()) - ) - await self._reply_html( - message, f"✅ 已开放注册,截止时间:{time_str}" - ) - except Exception as e: - await self._send_error(message, e, prefix="开放注册失败") - - @ensure_args(1, "/register_amount <人数>") - async def register_amount(self, message: Message, args: list[str]): - """ - /register_amount <人数> - 开放指定数量的注册名额 - """ - - try: - amount = int(args[0]) - await self.user_service.set_emby_config( - message.from_user.id, register_public_user=amount - ) - await self._reply_html( - message, f"✅ 已开放注册,名额:{amount}" - ) - except Exception as e: - await self._send_error(message, e, prefix="开放注册失败") - - async def help_command(self, message: Message): - """ - /help 或 /start - 查看命令帮助。 - """ - help_message = ( - "用户命令:\n" - "/use_code [code] - 使用邀请码获取创建账号资格\n" - "/create [username] - 创建Emby用户 (英文/下划线, 至少5位)\n" - "/info - 查看用户信息(私聊查看自己的,群里可回复他人)\n" - "/select_line - 选择线路\n" - "/reset_emby_password - 重置Emby账号密码\n" - "/count - 查看服务器内影片数量\n" - "/help - 显示本帮助\n" - ) - if await self.user_service.is_admin(message.from_user.id): - help_message += ( - "\n管理命令:\n" - "/new_code [数量] - 创建新的普通邀请码\n" - "/new_whitelist_code [数量] - 创建新的白名单邀请码\n" - "/register_until [YYYY-MM-DD HH:MM:SS] - 限时开放注册\n" - "/register_amount [人数] - 开放指定注册名额\n" - "/info (群里回复某人) - 查看他人信息\n" - "/ban_emby [原因] - 禁用某用户的Emby账号\n" - "/unban_emby - 解禁某用户的Emby账号\n" - ) - await self._reply_html(message, help_message) - - # =============== 命令挂载 =============== - def setup_commands(self): - @self.bot_client.client.on_message( - filters.private & filters.command(["help", "start"]) - ) - async def c_help(client, message): - await self.help_command(message) - - @self.bot_client.client.on_message( - filters.command("count") & user_in_group_on_filter - ) - async def c_count(client, message): - await self.count(message) - - @self.bot_client.client.on_message( - filters.command("info") & user_in_group_on_filter - ) - async def c_info(client, message): - await self.info(message) - - @self.bot_client.client.on_message( - filters.private & filters.command("use_code") & user_in_group_on_filter - ) - async def c_use_code(client, message): - await self.use_code(message) - - @self.bot_client.client.on_message( - filters.private & filters.command("create") & user_in_group_on_filter - ) - async def c_create_user(client, message): - await self.create_user(message) - - @self.bot_client.client.on_message( - filters.private - & filters.command("reset_emby_password") - & user_in_group_on_filter - & emby_user_on_filter - ) - async def c_reset_emby_password(client, message): - await self.reset_emby_password(message) - - @self.bot_client.client.on_message( - filters.private - & filters.command("select_line") - & user_in_group_on_filter - & emby_user_on_filter - ) - async def c_select_line(client, message): - await self.select_line(message) - - @self.bot_client.client.on_message( - filters.command("new_code") & admin_user_on_filter - ) - async def c_new_code(client, message): - await self.new_code(message) - - @self.bot_client.client.on_message( - filters.command("new_whitelist_code") & admin_user_on_filter - ) - async def c_new_whitelist_code(client, message): - await self.new_whitelist_code(message) - - @self.bot_client.client.on_message( - filters.command("ban_emby") & admin_user_on_filter - ) - async def c_ban_emby(client, message): - await self.ban_emby(message) - - @self.bot_client.client.on_message( - filters.command("unban_emby") & admin_user_on_filter - ) - async def c_unban_emby(client, message): - await self.unban_emby(message) - - @self.bot_client.client.on_message( - filters.command("register_until") & admin_user_on_filter - ) - async def c_register_until(client, message): - await self.register_until(message) - - @self.bot_client.client.on_message( - filters.command("register_amount") & admin_user_on_filter - ) - async def c_register_amount(client, message): - await self.register_amount(message) - - @self.bot_client.client.on_callback_query() - async def c_select_line_cb(client, callback_query): - await self.handle_callback_query(client, callback_query) - - @self.bot_client.client.on_message( - filters.left_chat_member | filters.new_chat_members - ) - async def group_member_change_handler(client, message): - await self.group_member_change_handler(client, message) diff --git a/bot/utils.py b/bot/utils.py deleted file mode 100644 index 98abc64..0000000 --- a/bot/utils.py +++ /dev/null @@ -1,43 +0,0 @@ -import logging -from datetime import datetime - -logger = logging.getLogger(__name__) - - -def parse_iso8601(datetime_str: str): - # 解析字符串为 datetime 对象 - try: - dt = datetime.strptime( - datetime_str[:26], "%Y-%m-%dT%H:%M:%S.%f" - ) # 截取到微秒部分 - logger.debug(f"Parsed ISO8601 datetime string: {datetime_str}") - return dt - except Exception as e: - logger.error( - f"Error parsing ISO8601 datetime string: {datetime_str}: {e}", exc_info=True - ) - return None - - -def parse_iso8601_to_timestamp(datetime_str: str): - dt = parse_iso8601(datetime_str) - if dt: - return dt.timestamp() - return None - - -def parse_iso8601_to_normal_date(datetime_str: str): - dt = parse_iso8601(datetime_str) - if dt: - return dt.strftime("%Y-%m-%d %H:%M:%S") - return None - - -def parse_timestamp_to_normal_date(timestamp: int): - try: - dt = datetime.fromtimestamp(timestamp) - logger.debug(f"Parsed timestamp: {timestamp}") - return dt.strftime("%Y-%m-%d %H:%M:%S") - except Exception as e: - logger.error(f"Error parsing timestamp {timestamp}: {e}", exc_info=True) - return None diff --git a/bot/utils/__init__.py b/bot/utils/__init__.py new file mode 100644 index 0000000..3b72aef --- /dev/null +++ b/bot/utils/__init__.py @@ -0,0 +1,112 @@ +import functools +import logging +from datetime import datetime + +from pyrogram.enums import ParseMode +from pyrogram.types import Message + +logger = logging.getLogger(__name__) + + +def parse_iso8601(datetime_str: str): + # 解析字符串为 datetime 对象 + try: + dt = datetime.strptime( + datetime_str[:26], "%Y-%m-%dT%H:%M:%S.%f" + ) # 截取到微秒部分 + logger.debug(f"Parsed ISO8601 datetime string: {datetime_str}") + return dt + except Exception as e: + logger.error( + f"Error parsing ISO8601 datetime string: {datetime_str}: {e}", + exc_info=True + ) + return None + + +def parse_iso8601_to_timestamp(datetime_str: str): + dt = parse_iso8601(datetime_str) + if dt: + return dt.timestamp() + return None + + +def parse_iso8601_to_normal_date(datetime_str: str): + dt = parse_iso8601(datetime_str) + if dt: + return dt.strftime("%Y-%m-%d %H:%M:%S") + return None + + +def parse_timestamp_to_normal_date(timestamp: int): + try: + dt = datetime.fromtimestamp(timestamp) + logger.debug(f"Parsed timestamp: {timestamp}") + return dt.strftime("%Y-%m-%d %H:%M:%S") + except Exception as e: + logger.error(f"Error parsing timestamp {timestamp}: {e}", + exc_info=True) + return None + + +async def reply_html(message: Message, text: str, **kwargs): + """ + 统一回复方法,使用 HTML parse_mode。 + """ + return await message.reply(text, parse_mode=ParseMode.HTML, **kwargs) + + +def with_parsed_args(func): + """ + 用于自动解析消息文本参数的装饰器 喵~。 + 这个装饰器会从消息的文本中提取以空格分割的参数(除第一个命令外), + 并将解析后的参数列表传递给被装饰的函数 喵~ + """ + + @functools.wraps(func) + async def wrapper(self, message: Message, *args, **kwargs): + parts = message.text.strip().split(" ") + parsed_args = parts[1:] if len(parts) > 1 else [] + return await func(self, message, parsed_args, *args, **kwargs) + + return wrapper + + +def with_ensure_args(min_len: int, usage: str): + """ + 用于确保命令参数数量足够的装饰器 喵~。 + 如果传入的参数数量少于要求的最小值,则自动回复提示信息,并终止函数的执行 喵~ + 参数: + min_len - 所需最小参数数量 + usage - 命令的正确用法示例 + """ + + def decorator(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + # 判断是否为方法: + # 如果第一个参数是 Message,则视为普通函数,否则视为类方法(第一个参数为 self) + if args and isinstance(args[0], Message): + message_obj = args[0] + command_args = args[1] + else: + message_obj = args[1] + command_args = args[2] + if len(command_args) < min_len: + await reply_html(message_obj, + f"参数不足,请参考用法:\n{usage}") + return + return await func(*args, **kwargs) + + return wrapper + + return decorator + + +async def send_error(message: Message, error: Exception, + prefix: str = "操作失败"): + """ + 统一的异常捕获后回复方式。 + """ + logger.error(f"{prefix}:{error}", exc_info=True) + await reply_html(message, f"{prefix}:{error}") diff --git a/bot/filters.py b/bot/utils/filters.py similarity index 83% rename from bot/filters.py rename to bot/utils/filters.py index 5284478..a79f4b0 100644 --- a/bot/filters.py +++ b/bot/utils/filters.py @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) -async def user_in_group_on_filter(filter, client, update) -> bool: +async def user_in_group_on_filter(_, __, update) -> bool: user = update.from_user or update.sender_chat telegram_id = user.id if config.group_members and telegram_id in config.group_members: @@ -22,7 +22,7 @@ async def user_in_group_on_filter(filter, client, update) -> bool: return False -async def admin_user_on_filter(filter, client, update) -> bool: +async def admin_user_on_filter(_, __, update) -> bool: user = update.from_user or update.sender_chat telegram_id = user.id try: @@ -32,7 +32,8 @@ async def admin_user_on_filter(filter, client, update) -> bool: return True except Exception as e: logger.error( - f"Error checking admin status for user {telegram_id}: {e}", exc_info=True + f"Error checking admin status for user {telegram_id}: {e}", + exc_info=True ) return False @@ -40,7 +41,7 @@ async def admin_user_on_filter(filter, client, update) -> bool: return False -async def emby_user_on_filter(filter, client, update) -> bool: +async def emby_user_on_filter(_, __, update) -> bool: user = update.from_user or update.sender_chat telegram_id = user.id try: @@ -50,7 +51,8 @@ async def emby_user_on_filter(filter, client, update) -> bool: return True except Exception as e: logger.error( - f"Error checking Emby status for user {telegram_id}: {e}", exc_info=True + f"Error checking Emby status for user {telegram_id}: {e}", + exc_info=True ) return False @@ -58,6 +60,7 @@ async def emby_user_on_filter(filter, client, update) -> bool: return False -user_in_group_on_filter = create(user_in_group_on_filter, "user_in_group_on_filter") +user_in_group_on_filter = create(user_in_group_on_filter, + "user_in_group_on_filter") admin_user_on_filter = create(admin_user_on_filter, "admin_user_on_filter") emby_user_on_filter = create(emby_user_on_filter, "emby_user_on_filter") diff --git a/bot/message_helper.py b/bot/utils/message_helper.py similarity index 75% rename from bot/message_helper.py rename to bot/utils/message_helper.py index 4efa283..bff23d6 100644 --- a/bot/message_helper.py +++ b/bot/utils/message_helper.py @@ -24,12 +24,14 @@ async def get_user_telegram_id(client, message): # 直接提供 Telegram ID(纯数字) if telegram_str.isdigit(): telegram_id = int(telegram_str) - logger.debug(f"Telegram ID from arguments (numeric): {telegram_id}") + logger.debug( + f"Telegram ID from arguments (numeric): {telegram_id}") # 使用 @username elif telegram_str.startswith("@"): telegram_username = telegram_str[1:] # 去掉 `@` - logger.debug(f"Telegram username from arguments: {telegram_username}") + logger.debug( + f"Telegram username from arguments: {telegram_username}") # 通过用户名查找 ID if telegram_username: @@ -37,7 +39,9 @@ async def get_user_telegram_id(client, message): user = await client.get_users(telegram_username) telegram_id = user.id logger.debug( - f"Telegram ID resolved from username {telegram_username}: {telegram_id}" + f"Telegram ID resolved from username " + f"{telegram_username}: " + f"{telegram_id}" ) except UsernameNotOccupied: error_message = f"❌ 用户名 @{telegram_username} 不存在" @@ -46,12 +50,15 @@ async def get_user_telegram_id(client, message): return None except PeerIdInvalid: error_message = f"❌ 无法获取用户 @{telegram_username} 的 ID" - logger.warning(f"Peer ID invalid for username: {telegram_username}") + logger.warning( + f"Peer ID invalid for username: {telegram_username}") await message.reply(error_message) return None except Exception as e: logger.error( - f"Error getting user ID from username {telegram_username}: {e}", + f"Error getting user ID from username " + f"{telegram_username}: " + f"{e}", exc_info=True, ) return None diff --git a/config.py b/config.py index 3290b57..ec26c86 100644 --- a/config.py +++ b/config.py @@ -1,5 +1,6 @@ import logging import os + from dotenv import load_dotenv # 加载 .env 文件 diff --git a/core/emby_api.py b/core/emby_api.py index b15ff0c..4f8c14a 100644 --- a/core/emby_api.py +++ b/core/emby_api.py @@ -1,4 +1,5 @@ import logging + import requests logger = logging.getLogger(__name__) @@ -19,7 +20,9 @@ def __init__(self, emby_url: str, emby_api: str, timeout: int = 10): self.api_key: str = emby_api self.timeout: int = timeout logger.info( - f"EmbyApi initialized with URL: {self.base_url}, timeout: {self.timeout}" + f"EmbyApi initialized with URL: " + f"{self.base_url}, timeout: " + f"{self.timeout}" ) def _request(self, method: str, path: str, data=None, params=None): @@ -44,7 +47,11 @@ def _request(self, method: str, path: str, data=None, params=None): url = f"{self.base_url}{path}" logger.debug( - f"Making {method} request to {url} with params: {params}, data: {data}" + f"Making " + f"{method} request to " + f"{url} with params: " + f"{params}, data: " + f"{data}" ) try: if method.upper() == "GET": @@ -53,7 +60,8 @@ def _request(self, method: str, path: str, data=None, params=None): ) elif method.upper() == "POST": response = requests.post( - url, params=params, json=data, timeout=self.timeout, headers=headers + url, params=params, json=data, timeout=self.timeout, + headers=headers ) else: raise Exception(f"暂不支持的 HTTP 方法: {method}") @@ -64,18 +72,21 @@ def _request(self, method: str, path: str, data=None, params=None): raise Exception("请求 Emby 服务器超时,请稍后重试或检查网络连接。") except requests.exceptions.ConnectionError as e: # 连接异常 - logger.error(f"Failed to connect to Emby server: {e}", exc_info=True) + logger.error(f"Failed to connect to Emby server: {e}", + exc_info=True) raise Exception(f"无法连接到 Emby 服务器: {str(e)}") except requests.exceptions.RequestException as e: # 其他 requests 异常 logger.error( - f"An unknown error occurred while requesting Emby: {e}", exc_info=True + f"An unknown error occurred while requesting Emby: {e}", + exc_info=True ) raise Exception(f"请求 Emby 时发生未知错误: {str(e)}") try: response.raise_for_status() - logger.debug(f"Request successful, status code: {response.status_code}") + logger.debug( + f"Request successful, status code: {response.status_code}") except Exception as e: logger.error(f"Emby API request failed: {e}", exc_info=True) raise Exception(f"Emby API 请求失败") @@ -93,7 +104,8 @@ def get_user(self, emby_id: str): return self._request("GET", path) except Exception as e: logger.error( - f"Failed to get user with Emby ID {emby_id}: {e}", exc_info=True + f"Failed to get user with Emby ID {emby_id}: {e}", + exc_info=True ) raise @@ -109,7 +121,8 @@ def create_user(self, name: str): try: return self._request("POST", path, data=data) except Exception as e: - logger.error(f"Failed to create user with name {name}: {e}", exc_info=True) + logger.error(f"Failed to create user with name {name}: {e}", + exc_info=True) raise def ban_user(self, emby_id: str): @@ -147,7 +160,8 @@ def ban_user(self, emby_id: str): return self.update_user_policy(emby_id, data) except Exception as e: logger.error( - f"Failed to ban user with Emby ID {emby_id}: {e}", exc_info=True + f"Failed to ban user with Emby ID {emby_id}: {e}", + exc_info=True ) raise @@ -186,7 +200,8 @@ def set_default_policy(self, emby_id: str): return self.update_user_policy(emby_id, data) except Exception as e: logger.error( - f"Failed to set default policy for user with Emby ID {emby_id}: {e}", + f"Failed to set default policy for user with Emby ID " + f"{emby_id}: {e}", exc_info=True, ) raise @@ -200,7 +215,8 @@ def update_user_policy(self, emby_id: str, policy_data: dict): """ path = f"/emby/Users/{emby_id}/Policy" logger.info( - f"Updating user policy for Emby ID: {emby_id} with data: {policy_data}" + f"Updating user policy for Emby ID: " + f"{emby_id} with data: {policy_data}" ) try: return self._request("POST", path, data=policy_data) @@ -224,7 +240,8 @@ def reset_user_password(self, emby_id: str): return self._request("POST", path, data=data) except Exception as e: logger.error( - f"Failed to reset password for user with Emby ID {emby_id}: {e}", + f"Failed to reset password for user with Emby ID " + f"{emby_id}: {e}", exc_info=True, ) raise @@ -293,7 +310,8 @@ def __init__(self, api_url: str, api_key: str = "", timeout: int = 10): self.api_key = api_key self.timeout = timeout logger.info( - f"EmbyRouterAPI initialized with URL: {self.api_url}, timeout: {self.timeout}" + f"EmbyRouterAPI initialized with URL: {self.api_url}" + f", timeout: {self.timeout}" ) def call_api(self, path: str): @@ -303,7 +321,8 @@ def call_api(self, path: str): :return: 成功时返回 JSON,失败抛出异常 """ url = f"{self.api_url}{path}" - headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {} + headers = { + "Authorization": f"Bearer {self.api_key}"} if self.api_key else {} logger.debug(f"Calling API at {url}") try: response = requests.get(url, headers=headers, timeout=self.timeout) @@ -313,11 +332,13 @@ def call_api(self, path: str): logger.error("Request to router service timed out", exc_info=True) raise Exception("请求路由服务超时,请稍后重试或检查网络连接。") except requests.exceptions.ConnectionError as e: - logger.error(f"Failed to connect to router service: {e}", exc_info=True) + logger.error(f"Failed to connect to router service: {e}", + exc_info=True) raise Exception(f"无法连接到路由服务: {str(e)}") except requests.exceptions.RequestException as e: logger.error( - f"An unknown error occurred while requesting router service: {e}", + f"An unknown error occurred while requesting router service: " + f"{e}", exc_info=True, ) raise Exception(f"请求路由服务时发生错误: {str(e)}") @@ -342,7 +363,8 @@ def query_user_route(self, user_id: str): return self.call_api(f"/api/route/{user_id}") except Exception as e: logger.error( - f"Failed to query user route for user ID {user_id}: {e}", exc_info=True + f"Failed to query user route for user ID {user_id}: {e}", + exc_info=True ) raise @@ -350,12 +372,15 @@ def update_user_route(self, user_id: str, new_index: str): """ 更新用户当前所使用的线路。 """ - logger.info(f"Updating user route for user ID: {user_id} to index: {new_index}") + logger.info( + f"Updating user route for user ID: " + f"{user_id} to index: {new_index}") try: return self.call_api(f"/api/route/{user_id}/{new_index}") except Exception as e: logger.error( - f"Failed to update user route for user ID {user_id} to index {new_index}: {e}", + f"Failed to update user route for user ID " + f"{user_id} to index {new_index}: {e}", exc_info=True, ) raise diff --git a/models/__init__.py b/models/__init__.py index d128db6..87ac476 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,3 +1,3 @@ -from .user_model import User from .config_model import Config from .invite_code_model import InviteCode +from .user_model import User diff --git a/models/config_model.py b/models/config_model.py index 0b269a2..1ddab7e 100644 --- a/models/config_model.py +++ b/models/config_model.py @@ -1,4 +1,5 @@ import logging + from py_tools.connections.db.mysql import DBManager from py_tools.connections.db.mysql.orm_model import BaseOrmTableWithTS from sqlalchemy import Integer, BigInteger @@ -10,11 +11,13 @@ class Config(BaseOrmTableWithTS): __tablename__ = "config" - total_register_user: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + total_register_user: Mapped[int] = mapped_column(Integer, nullable=False, + default=0) register_public_user: Mapped[int] = mapped_column( Integer, nullable=False, default=0 ) - register_public_time: Mapped[int] = mapped_column(BigInteger, nullable=True) + register_public_time: Mapped[int] = mapped_column(BigInteger, + nullable=True) class ConfigOrm(DBManager): diff --git a/models/invite_code_model.py b/models/invite_code_model.py index 3705697..025b081 100644 --- a/models/invite_code_model.py +++ b/models/invite_code_model.py @@ -1,5 +1,6 @@ -import logging import enum +import logging + from py_tools.connections.db.mysql import DBManager from py_tools.connections.db.mysql.orm_model import BaseOrmTableWithTS from sqlalchemy import String, BigInteger, Boolean, Enum @@ -22,12 +23,15 @@ class InviteCode(BaseOrmTableWithTS): code: Mapped[str] = mapped_column( String(50), index=True, unique=True, nullable=False ) - telegram_id: Mapped[int] = mapped_column(BigInteger, index=True, nullable=False) + telegram_id: Mapped[int] = mapped_column(BigInteger, index=True, + nullable=False) code_type: Mapped[InviteCodeType] = mapped_column( Enum(InviteCodeType), nullable=False ) - is_used: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) - used_time: Mapped[int] = mapped_column(BigInteger, default=None, nullable=True) + is_used: Mapped[bool] = mapped_column(Boolean, default=False, + nullable=False) + used_time: Mapped[int] = mapped_column(BigInteger, default=None, + nullable=True) used_user_id: Mapped[int] = mapped_column( BigInteger, default=None, nullable=True, index=True ) diff --git a/models/user_model.py b/models/user_model.py index 8d4069a..df3630f 100644 --- a/models/user_model.py +++ b/models/user_model.py @@ -1,4 +1,5 @@ import logging + from py_tools.connections.db.mysql import DBManager from py_tools.connections.db.mysql.orm_model import BaseOrmTableWithTS from sqlalchemy import String, Boolean, BigInteger @@ -20,8 +21,10 @@ class User(BaseOrmTableWithTS): emby_id: Mapped[str] = mapped_column( String(50), index=True, unique=True, nullable=True ) - is_admin: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) - is_whitelist: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + is_admin: Mapped[bool] = mapped_column(Boolean, default=False, + nullable=False) + is_whitelist: Mapped[bool] = mapped_column(Boolean, default=False, + nullable=False) enable_register: Mapped[bool] = mapped_column( Boolean, default=False, nullable=False ) @@ -63,7 +66,8 @@ def check_use_redeem_code(self) -> None: if self.emby_id is not None: raise Exception("该用户已拥有 Emby 账号,无法再次使用注册邀请码。") if self.enable_register: - raise Exception("该用户已经具备创建 Emby 账号的资格,无需再次使用邀请码。") + raise Exception( + "该用户已经具备创建 Emby 账号的资格,无需再次使用邀请码。") def check_use_whitelist_code(self) -> None: """检查是否可使用白名单邀请码。""" diff --git a/services/user_service.py b/services/user_service.py index c460761..dbb0dbc 100644 --- a/services/user_service.py +++ b/services/user_service.py @@ -18,6 +18,38 @@ logger = logging.getLogger(__name__) +async def first_or_create_emby_config() -> Config: + """获取或创建 Emby 配置。""" + emby_config = await ConfigOrm().query_one(conds=[Config.id == 1]) + if not emby_config: + emby_config = Config( + register_public_user=0, register_public_time=0, + total_register_user=0 + ) + await ConfigOrm().add(emby_config) + return emby_config + + +async def _check_register_permission(user: User, + emby_config: Config) -> bool: + """检查用户是否有权限注册 Emby 账号""" + enable_register = user.enable_register + if not enable_register and emby_config.register_public_user > 0: + enable_register = True + if ( + not enable_register + and emby_config.register_public_time > 0 + and (datetime.now().timestamp() + < emby_config.register_public_time) + ): + enable_register = True + if 0 < emby_config.register_public_time < datetime.now().timestamp(): + await ConfigOrm().update( + values={"register_public_time": 0}, conds=[Config.id == 1] + ) + return enable_register + + class UserService: """用户与 Emby 相关的业务逻辑层""" @@ -28,12 +60,14 @@ def __init__(self, emby_api: EmbyApi, emby_router_api: EmbyRouterAPI): @staticmethod async def get_or_create_user_by_telegram_id(telegram_id: int) -> User: """通过 telegram_id 从数据库获取用户,如果不存在则创建一个默认用户""" - user = await UserOrm().query_one(conds=[User.telegram_id == telegram_id]) + user = await UserOrm().query_one( + conds=[User.telegram_id == telegram_id]) if not user: default_user = User( telegram_id=telegram_id, is_admin=telegram_id in config.admin_list, - telegram_name=config.group_members.get(telegram_id, {}).username + telegram_name=config.group_members.get(telegram_id, + {}).username if config.group_members.get(telegram_id) else None, ) @@ -65,13 +99,14 @@ async def must_get_emby_user(self, telegram_id: int) -> User: return user async def _emby_create_user( - self, telegram_id: int, username: str, password: str + self, telegram_id: int, username: str, password: str ) -> User: """内部使用:真正调用 Emby API 创建用户,并设置初始密码""" user = await self.get_or_create_user_by_telegram_id(telegram_id) emby_user = self.emby_api.create_user(username) if not emby_user or not emby_user.get("Id"): - raise Exception("在 Emby 系统中创建账号失败,请检查 Emby 服务是否正常。") + raise Exception( + "在 Emby 系统中创建账号失败,请检查 Emby 服务是否正常。") emby_id = emby_user["Id"] user.emby_id = emby_id @@ -99,7 +134,7 @@ def gen_whitelist_code(num: int) -> List[str]: return [f"epw-{str(shortuuid.uuid())}" for _ in range(num)] async def create_invite_code( - self, telegram_id: int, count: int = 1 + self, telegram_id: int, count: int = 1 ) -> List[InviteCode]: """创建普通邀请码,需检测用户是否有权限""" user = await self.must_get_user(telegram_id) @@ -108,14 +143,15 @@ async def create_invite_code( code_objs = [ InviteCode( - code=code, telegram_id=telegram_id, code_type=InviteCodeType.REGISTER + code=code, telegram_id=telegram_id, + code_type=InviteCodeType.REGISTER ) for code in self.gen_register_code(count) ] return await InviteCodeOrm().bulk_add(code_objs) async def create_whitelist_code( - self, telegram_id: int, count: int = 1 + self, telegram_id: int, count: int = 1 ) -> List[InviteCode]: """创建白名单邀请码,需检测用户是否有权限""" user = await self.must_get_user(telegram_id) @@ -124,7 +160,8 @@ async def create_whitelist_code( code_objs = [ InviteCode( - code=code, telegram_id=telegram_id, code_type=InviteCodeType.WHITELIST + code=code, telegram_id=telegram_id, + code_type=InviteCodeType.WHITELIST ) for code in self.gen_whitelist_code(count) ] @@ -142,60 +179,36 @@ async def emby_info(self, telegram_id: int) -> Tuple[User, Dict]: ) return user, emby_user - async def first_or_create_emby_config(self) -> Config: - """获取或创建 Emby 配置。""" - emby_config = await ConfigOrm().query_one(conds=[Config.id == 1]) - if not emby_config: - emby_config = Config( - register_public_user=0, register_public_time=0, total_register_user=0 - ) - await ConfigOrm().add(emby_config) - return emby_config - async def emby_create_user( - self, telegram_id: int, username: str, password: str + self, telegram_id: int, username: str, password: str ) -> User: """创建 Emby 用户(外部调用入口),先判断各种配置是否允许注册,然后调用内部的 _emby_create_user""" user = await self.get_or_create_user_by_telegram_id(telegram_id) if user.has_emby_account(): - raise Exception("该 Telegram 用户已经绑定过 Emby 账号,无法重复创建。") + raise Exception( + "该 Telegram 用户已经绑定过 Emby 账号,无法重复创建。") - emby_config = await self.first_or_create_emby_config() + emby_config = await first_or_create_emby_config() if not emby_config: raise Exception("未找到 Emby 配置,无法创建账号。") - if not await self._check_register_permission(user, emby_config): + if not await _check_register_permission(user, emby_config): raise Exception("当前没有可用的注册权限或名额,创建账号被拒绝。") async with ConfigOrm().transaction() as session: - if not user.enable_register and emby_config.register_public_user > 0: + if (not user.enable_register + and emby_config.register_public_user > 0): emby_config.register_public_user -= 1 emby_config.total_register_user += 1 - new_user = await self._emby_create_user(telegram_id, username, password) + new_user = await self._emby_create_user(telegram_id, username, + password) session.add(new_user) session.add(emby_config) await session.commit() return new_user - async def _check_register_permission(self, user: User, emby_config: Config) -> bool: - """检查用户是否有权限注册 Emby 账号""" - enable_register = user.enable_register - if not enable_register and emby_config.register_public_user > 0: - enable_register = True - if ( - not enable_register - and emby_config.register_public_time > 0 - and datetime.now().timestamp() < emby_config.register_public_time - ): - enable_register = True - if 0 < emby_config.register_public_time < datetime.now().timestamp(): - await ConfigOrm().update( - values={"register_public_time": 0}, conds=[Config.id == 1] - ) - return enable_register - async def redeem_code(self, telegram_id: int, code: str): """使用邀请码,分为普通注册邀请码和白名单邀请码""" pattern = re.compile(r"^(epr|epw)-[A-Za-z0-9]+$") @@ -207,7 +220,8 @@ async def redeem_code(self, telegram_id: int, code: str): # 使用事务块,并通过行锁防止并发问题 async with InviteCodeOrm().transaction() as session: # 构造 SELECT 语句,并加上 FOR UPDATE 行锁 - stmt = select(InviteCode).where(InviteCode.code == code).with_for_update() + stmt = select(InviteCode).where( + InviteCode.code == code).with_for_update() result = await session.execute(stmt) valid_code = result.scalars().first() @@ -239,7 +253,8 @@ async def redeem_code(self, telegram_id: int, code: str): return valid_code - async def reset_password(self, telegram_id: int, password: str = "") -> bool: + async def reset_password(self, telegram_id: int, + password: str = "") -> bool: """重置用户的 Emby 密码。""" user = await self.must_get_emby_user(telegram_id) try: @@ -251,7 +266,8 @@ async def reset_password(self, telegram_id: int, password: str = "") -> bool: return False async def emby_ban( - self, telegram_id: int, reason: str, operator_telegram_id: Optional[int] = None + self, telegram_id: int, reason: str, + operator_telegram_id: Optional[int] = None ) -> bool: """禁用用户""" if operator_telegram_id is not None: @@ -276,7 +292,7 @@ async def emby_ban( return False async def emby_unban( - self, telegram_id: int, operator_telegram_id: Optional[int] = None + self, telegram_id: int, operator_telegram_id: Optional[int] = None ) -> bool: """解禁用户""" if operator_telegram_id is not None: @@ -300,16 +316,16 @@ async def emby_unban( return False async def set_emby_config( - self, - telegram_id: int, - register_public_user: Optional[int] = None, - register_public_time: Optional[int] = None, + self, + telegram_id: int, + register_public_user: Optional[int] = None, + register_public_time: Optional[int] = None, ) -> Config: """设置 Emby 注册相关配置,如公共注册名额和公共注册截止时间""" user = await self.must_get_user(telegram_id) user.check_set_emby_config() - emby_config = await self.first_or_create_emby_config() + emby_config = await first_or_create_emby_config() if not emby_config: raise Exception("未找到全局 Emby 配置,无法设置。") @@ -336,10 +352,12 @@ async def get_user_router(self, telegram_id: int) -> Dict: user = await self.must_get_emby_user(telegram_id) return self.emby_router_api.query_user_route(user.emby_id) - async def update_user_router(self, telegram_id: int, new_index: str) -> bool: + async def update_user_router(self, telegram_id: int, + new_index: str) -> bool: """更新用户线路信息""" user = await self.must_get_emby_user(telegram_id) - return self.emby_router_api.update_user_route(str(user.emby_id), str(new_index)) + return self.emby_router_api.update_user_route(str(user.emby_id), + str(new_index)) async def get_router_list(self, telegram_id: int) -> List[Dict]: """获取所有可用线路"""