diff --git a/.agents/skills/framework-usage/SKILL.md b/.agents/skills/framework-usage/SKILL.md index e90315b2..4a602e28 100644 --- a/.agents/skills/framework-usage/SKILL.md +++ b/.agents/skills/framework-usage/SKILL.md @@ -52,6 +52,8 @@ ncatbot init # 交互式创建 config.yaml + plugins/ + | 用户需求 | 框架功能 | 参考 | |---------|---------|------| | 响应命令/消息/事件 | 装饰器 + handler | [events.md](./references/events.md) | +| 简单命令处理 | CommandHook(单层命令) | [hooks.md](./references/hooks.md) | +| 分层命令结构(子命令/命令组) | CommandGroup + CommandGroupHook | [hooks.md](./references/hooks.md) / `examples/common/08_command_group/` | | 发送文字/图片/视频/转发 | 消息构造与发送 | [messaging.md](./references/messaging.md) | | 群管理/查询信息/文件 | Bot API | [bot-api.md](./references/bot-api.md) | | 持久化配置/数据 | ConfigMixin / DataMixin | [mixins.md](./references/mixins.md) | diff --git a/examples/README.md b/examples/README.md index 61e8aaaf..0225695d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -20,6 +20,7 @@ | 05 | [scheduled_tasks](common/05_scheduled_tasks/) | 定时任务(多种时间格式/条件执行) | ⭐⭐ | | 06 | [multi_step_dialog](common/06_multi_step_dialog/) | 多步对话(from_event/超时/取消) | ⭐⭐ | | 07 | [external_api](common/07_external_api/) | 外部 API 集成(aiohttp/配置/错误处理) | ⭐⭐ | +| 08 | [command_group](common/08_command_group/) | 分层命令组(CommandGroup、参数绑定、多组并列) | ⭐⭐ | ### qq/ — QQ 平台专属 diff --git a/examples/common/08_command_group/main.py b/examples/common/08_command_group/main.py new file mode 100644 index 00000000..44d9eb1f --- /dev/null +++ b/examples/common/08_command_group/main.py @@ -0,0 +1,196 @@ +""" +common/08_command_group — 命令组分层路由插件 + +演示功能: + - 使用 CommandGroupHook 实现分层命令 + - 参数类型自动绑定(int/float/str/At) + - 命令别名支持 + - @hook.subcommand() 装饰器注册子命令并由主 handler 分发 +""" + +import inspect +from ncatbot.core import CommandGroupHook, registrar +from ncatbot.event.qq import GroupMessageEvent +from ncatbot.plugin import NcatBotPlugin + + +class CommandGroupDemoPlugin(NcatBotPlugin): + name = "command_group_common" + version = "1.0.0" + author = "NcatBot" + description = "命令组分层路由示例 — CommandGroupHook 正确用法" + + # ============================================================================ + # 方案 1: Admin 命令组 — 支持 kick/ban 子命令 + # ============================================================================ + admin_hook = CommandGroupHook("admin", "/admin", "a", ignore_case=True) + + @admin_hook.subcommand("kick", "remove") + async def admin_kick(self, event: GroupMessageEvent, user_id: int): + """踢出成员: /admin kick 12345""" + try: + await event.api.manage.set_group_kick( + group_id=event.group_id, user_id=user_id + ) + await event.reply(f"✓ 已踢出成员 {user_id}") + except Exception as e: + await event.reply(f"✗ 踢出失败: {e}") + + @admin_hook.subcommand("ban") + async def admin_ban( + self, event: GroupMessageEvent, user_id: int, minutes: int = 60 + ): + """禁言成员: /admin ban 12345 或 /admin ban 12345 120""" + try: + await event.api.manage.set_group_ban( + group_id=event.group_id, + user_id=user_id, + duration=minutes * 60, + ) + await event.reply(f"✓ 已禁言成员 {user_id} {minutes} 分钟") + except Exception as e: + await event.reply(f"✗ 禁言失败: {e}") + + @registrar.on_group_message() + @admin_hook + async def on_admin(self, event: GroupMessageEvent, **kwargs): + """处理 admin 命令组 + + CommandGroupHook 会自动识别子命令,提取参数到 kwargs + 主 handler 通过检查 kwargs 来调度到真正的子命令处理器 + """ + # 获取消息中的子命令名 + message_text = event.data.message.text.strip() + # 提取命令名后的第一个单词(子命令) + parts = message_text.split(None, 1) + if len(parts) < 2: + await event.reply( + "❓ 缺少子命令。用法: /admin [minutes]" + ) + return + + subcommand_text = parts[1].split()[0].lower() if parts[1] else None + if not subcommand_text: + await event.reply("❓ 缺少子命令") + return + + # 根据子命令查找处理器 + hooks = self.admin_hook._subcommands + # 找到匹配的子命令处理器(不区分大小写) + handler = None + for cmd_name, cmd_handler in hooks.items(): + if cmd_name == subcommand_text.lower(): + handler = cmd_handler + break + + if handler: + # 调用子命令处理器,传入提取的参数 + sig = inspect.signature(handler) + allowed_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters} + await handler(self, event, **allowed_kwargs) + else: + await event.reply(f"❌ 未知的子命令: {subcommand_text}") + + # ============================================================================ + # 方案 2: Calc 命令组 — 支持 add/divide/echo 子命令 + # ============================================================================ + calc_hook = CommandGroupHook("calc") + + @calc_hook.subcommand("add") + async def calc_add(self, event: GroupMessageEvent, a: int, b: int): + """加法: /calc add 10 20""" + result = a + b + await event.reply(f"📊 {a} + {b} = {result}") + + @calc_hook.subcommand("divide") + async def calc_divide(self, event: GroupMessageEvent, a: float, b: float): + """除法: /calc divide 10 3""" + if b == 0: + await event.reply("✗ 错误: 除以零") + else: + result = a / b + await event.reply(f"📊 {a} / {b} = {result}") + + @calc_hook.subcommand("echo") + async def calc_echo(self, event: GroupMessageEvent, text: str): + """回显: /calc echo hello world""" + await event.reply(f"🔊 {text}") + + @registrar.on_group_message() + @calc_hook + async def on_calc(self, event: GroupMessageEvent, **kwargs): + """处理计算器命令""" + message_text = event.data.message.text.strip() + parts = message_text.split(None, 1) + if len(parts) < 2: + await event.reply("❓ 缺少子命令。用法: /calc [args...]") + return + + subcommand_text = parts[1].split()[0].lower() if parts[1] else None + if not subcommand_text: + return + + hooks = self.calc_hook._subcommands + handler = None + for cmd_name, cmd_handler in hooks.items(): + if cmd_name == subcommand_text.lower(): + handler = cmd_handler + break + + if handler: + sig = inspect.signature(handler) + allowed_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters} + await handler(self, event, **allowed_kwargs) + else: + await event.reply(f"❌ 未知的子命令: {subcommand_text}") + + # ============================================================================ + # 方案 3: Help 命令 — 多别名支持 + # ============================================================================ + help_hook = CommandGroupHook("help", "?", ignore_case=True) + + @help_hook.subcommand("admin") + async def help_admin(self, event: GroupMessageEvent): + """管理员帮助""" + await event.reply( + "Admin Commands:\n/admin kick \n/admin ban [minutes]" + ) + + @help_hook.subcommand("calc") + async def help_calc(self, event: GroupMessageEvent): + """计算器帮助""" + await event.reply( + "Calc Commands:\n/calc add \n/calc divide \n/calc echo " + ) + + @registrar.on_group_message() + @help_hook + async def on_help(self, event: GroupMessageEvent, **kwargs): + """处理帮助命令 (支持 /help 或 /?)""" + message_text = event.data.message.text.strip() + parts = message_text.split(None, 1) + + if len(parts) < 2: + # 无子命令时显示通用帮助 + await event.reply( + "📖 Available Commands:\n" + " /help admin - Admin commands\n" + " /help calc - Calc commands" + ) + return + + subcommand_text = parts[1].split()[0].lower() if parts[1] else None + if not subcommand_text: + return + + hooks = self.help_hook._subcommands + handler = None + for cmd_name, cmd_handler in hooks.items(): + if cmd_name == subcommand_text.lower(): + handler = cmd_handler + break + + if handler: + await handler(self, event) + else: + await event.reply(f"❌ 未知的帮助主题: {subcommand_text}") diff --git a/examples/common/08_command_group/manifest.toml b/examples/common/08_command_group/manifest.toml new file mode 100644 index 00000000..b2b43dc0 --- /dev/null +++ b/examples/common/08_command_group/manifest.toml @@ -0,0 +1,5 @@ +name = "command_group_common" +version = "1.0.0" +main = "main.py" +author = "NcatBot" +description = "命令组分层路由示例 — CommandGroup + CommandGroupHook" diff --git a/ncatbot/core/__init__.py b/ncatbot/core/__init__.py index 67f7497f..8f94f55d 100644 --- a/ncatbot/core/__init__.py +++ b/ncatbot/core/__init__.py @@ -50,6 +50,8 @@ regex, # 命令 Hook CommandHook, + CommandGroup, + CommandGroupHook, # Dispatcher HandlerDispatcher, HandlerEntry, @@ -109,6 +111,8 @@ "regex", # Registry — 命令 Hook "CommandHook", + "CommandGroup", + "CommandGroupHook", # Registry — Dispatcher "HandlerDispatcher", "HandlerEntry", diff --git a/ncatbot/core/registry/__init__.py b/ncatbot/core/registry/__init__.py index 978ec295..dd18f2a7 100644 --- a/ncatbot/core/registry/__init__.py +++ b/ncatbot/core/registry/__init__.py @@ -45,6 +45,9 @@ # 命令 Hook from .command_hook import CommandHook +# 命令组 Hook +from .command_group_hook import CommandGroup, CommandGroupHook + # Dispatcher from .dispatcher import HandlerDispatcher, HandlerEntry @@ -84,6 +87,8 @@ "regex", # 命令 Hook (高级匹配 + 参数绑定) "CommandHook", + "CommandGroup", + "CommandGroupHook", # Dispatcher "HandlerDispatcher", "HandlerEntry", diff --git a/ncatbot/core/registry/command_group_hook.py b/ncatbot/core/registry/command_group_hook.py new file mode 100644 index 00000000..f5a00807 --- /dev/null +++ b/ncatbot/core/registry/command_group_hook.py @@ -0,0 +1,410 @@ +""" +CommandGroupHook — 命令组匹配与自动路由 + +高级 BEFORE_CALL Hook: +1. 支持多命令名,命令后可跟子命令和参数(格式:command subcommand [args...]) +2. 通过 inspect.signature 检查 handler 的类型注解 +3. 支持子命令管理:handler 参数中包含 subcommand 参数,自动提取并路由 +4. 从文本结构化提取参数 (At 段、文本 token) +5. 按类型注解自动转换 (str/int/float/At) +6. 写入 ctx.kwargs → dispatcher._execute(**ctx.kwargs) 自动注入 + +使用示例(与 CommandHook 一致): + + # 定义处理器,声明子命令参数 + hook = CommandGroupHook("admin", "/admin", "a") + + @hook.subcommand("ban", "禁言") + async def admin_ban(event: GroupMessageEvent, user_id: int, minutes: int = 60): + # 处理 "/admin ban 12345", "/admin ban 12345 120" + pass + + @hook.subcommand("kick") + async def admin_kick(event: GroupMessageEvent, user_id: int): + # 处理 "/admin kick 12345" + pass + + @registrar.on_message("message.group") + @hook + @group_only + async def handle_admin(event: GroupMessageEvent, subcommand: str = ""): + # 当命令匹配时,subcommand 自动填充对应的子命令名 + # 如果有对应的 @hook.subcommand() 处理,该处理器先被调用 + pass +""" + +import inspect +from typing import Any, Dict, List, Optional, Tuple, get_type_hints + +from .hook import Hook, HookAction, HookContext, HookStage + + +class CommandGroup: + """命令组 — 为向后兼容而保留(建议改用 CommandGroupHook 的 @subcommand()) + + 管理子命令/子命令组的容器。 + """ + + def __init__(self, names: List[str]): + """ + Args: + names: 该命令组的名称列表 (别名), e.g. ["help", "h"] + """ + if not names: + raise ValueError("CommandGroup 至少需要一个名称") + self.names = names + self.subcommands: Dict[str, Any] = {} # 子命令名 → handler + self.subgroups: Dict[str, "CommandGroup"] = {} # 子命令组名 → CommandGroup + + def command(self, *names: str): + """子命令注册装饰器""" + + def decorator(func): + for name in names: + self.subcommands[name.lower()] = func + return func + + return decorator + + def subgroup(self, group: "CommandGroup"): + """注册子命令组""" + for name in group.names: + self.subgroups[name.lower()] = group + return group + + def __repr__(self) -> str: + return f"" + + +class CommandGroupHook(Hook): + """命令匹配 + 子命令路由 Hook + + 与 CommandHook 基本一致,额外支持子命令注册和自动路由。 + + 使用示例: + + hook = CommandGroupHook("admin", "/admin", "a") + + @hook.subcommand("ban") + async def admin_ban(event: GroupMessageEvent, user_id: int, minutes: int = 60): + pass + + @hook.subcommand("kick", "remove") + async def admin_kick(event: GroupMessageEvent, user_id: int): + pass + + @registrar.on_message("message.group") + @hook + @group_only + async def handle_admin(event: GroupMessageEvent, subcommand: str = ""): + pass + + 用法: + /admin ban 12345 → 调用 admin_ban(event, 12345, 60) + /admin ban 12345 120 → 调用 admin_ban(event, 12345, 120) + /admin kick 12345 → 调用 admin_kick(event, 12345) + """ + + stage = HookStage.BEFORE_CALL + + def __init__( + self, + *names: str, + ignore_case: bool = False, + priority: int = 95, + ): + """ + Args: + *names: 命令名列表(支持别名), e.g. "admin", "/admin", "a" + ignore_case: 是否忽略大小写匹配 + priority: hook 优先级 + """ + if not names: + raise ValueError("CommandGroupHook 至少需要一个命令名") + self.names = names + self.ignore_case = ignore_case + self.priority = priority + self._subcommands: Dict[str, Any] = {} # 子命令名 → handler + self._sig_cache: Dict[int, Optional["_ParamSpec"]] = {} + + def subcommand(self, *subcommand_names: str): + """子命令注册装饰器 + + Args: + *subcommand_names: 子命令名称(支持别名) + """ + + def decorator(func): + for name in subcommand_names: + compare_name = name.lower() if self.ignore_case else name + self._subcommands[compare_name] = func + return func + + return decorator + + async def execute(self, ctx: HookContext) -> HookAction: + # 获取消息文本 + message = getattr(ctx.event.data, "message", None) + if message is None: + return HookAction.SKIP + text = message.text.strip() if hasattr(message, "text") else "" + if not text: + return HookAction.SKIP + + compare_text = text.lower() if self.ignore_case else text + + # 解析 handler 参数规格 (缓存) + func = ctx.handler_entry.func + spec = self._get_param_spec(func) + + # CommandGroupHook 总是支持前缀匹配(为了支持子命令) + matched_name = None + for name in self.names: + compare_name = name.lower() if self.ignore_case else name + if compare_text == compare_name or compare_text.startswith( + compare_name + " " + ): + matched_name = name + break + + if matched_name is None: + return HookAction.SKIP + + # 提取命令后的文本 + if len(text) > len(matched_name): + rest = text[len(matched_name) :].strip() + else: + rest = "" + + # 尝试匹配子命令 + if rest: + first_token, *rest_tokens = rest.split(None, 1) + remaining = rest_tokens[0] if rest_tokens else "" + compare_first = first_token.lower() if self.ignore_case else first_token + + # 查找注册的子命令 + if compare_first in self._subcommands: + subcommand_handler = self._subcommands[compare_first] + # 绑定子命令处理器的参数 + subcommand_spec = self._get_param_spec(subcommand_handler) + if subcommand_spec: + kwargs = self._bind_params(subcommand_spec, remaining, message) + if kwargs is not None: + ctx.kwargs.update(kwargs) + return HookAction.CONTINUE + + # 没有子命令匹配,检查是否是精确命令匹配(无rest) + if not rest: + # 精确匹配:命令名后没有任何东西 + return HookAction.CONTINUE + + # 如果有rest但没有子命令,尝试绑定主handler的参数 + if spec and spec.params: + kwargs = self._bind_params(spec, rest, message) + if kwargs is None: + return HookAction.SKIP + ctx.kwargs.update(kwargs) + + return HookAction.CONTINUE + + def _get_param_spec(self, func) -> Optional["_ParamSpec"]: + """解析并缓存 handler 的参数规格""" + func_id = id(func) + if func_id in self._sig_cache: + return self._sig_cache[func_id] + + try: + sig = inspect.signature(func) + try: + hints = get_type_hints(func) + except Exception: + hints = {} + + params_list = list(sig.parameters.values()) + + # 跳过 self 和 event 参数 + skip = 0 + for p in params_list: + if p.name in ("self", "cls"): + skip += 1 + continue + # 第一个非 self 参数是 event + skip += 1 + break + + extra_params = params_list[skip:] + if not extra_params: + spec = _ParamSpec(params=[]) + self._sig_cache[func_id] = spec + return spec + + params = [] + for p in extra_params: + annotation = hints.get(p.name, p.annotation) + has_default = p.default is not inspect.Parameter.empty + params.append( + _ParamInfo( + name=p.name, + annotation=annotation, + has_default=has_default, + default=p.default if has_default else None, + ) + ) + + spec = _ParamSpec(params=params) + self._sig_cache[func_id] = spec + return spec + + except (ValueError, TypeError): + self._sig_cache[func_id] = _ParamSpec(params=[]) + return self._sig_cache[func_id] + + def _bind_params( + self, + spec: "_ParamSpec", + rest: str, + message: Any, + ) -> Optional[Dict[str, Any]]: + """根据参数规格绑定实际值,失败返回 None + + 支持类型: + - At: 从 message.filter_at() 按序提取 + - int: 从文本 token 提取并转换 + - float: 从文本 token 提取并转换 + - str: 单 token 或剩余文本 (最后一个 str) + """ + from ncatbot.types import At + + # 提取 At 列表和文本 token + at_list: List[Any] = [] + if hasattr(message, "filter_at"): + at_list = list(message.filter_at()) + + text_tokens = rest.split() if rest else [] + + kwargs: Dict[str, Any] = {} + at_idx = 0 + token_idx = 0 + + for i, param in enumerate(spec.params): + # 跳过 subcommand 参数(由外层处理) + if param.name == "subcommand": + if param.has_default: + kwargs[param.name] = param.default + continue + + anno = param.annotation + is_last_str = i == len(spec.params) - 1 and _is_type(anno, str) + + if _is_type(anno, At): + if at_idx < len(at_list): + kwargs[param.name] = at_list[at_idx] + at_idx += 1 + elif param.has_default: + kwargs[param.name] = param.default + else: + return None + + elif _is_type(anno, int): + value = _extract_typed_token(text_tokens, token_idx, int) + if value is not None: + kwargs[param.name] = value[0] + token_idx = value[1] + elif param.has_default: + kwargs[param.name] = param.default + else: + return None + + elif _is_type(anno, float): + value = _extract_typed_token(text_tokens, token_idx, float) + if value is not None: + kwargs[param.name] = value[0] + token_idx = value[1] + elif param.has_default: + kwargs[param.name] = param.default + else: + return None + + elif _is_type(anno, str) or anno is inspect.Parameter.empty: + if is_last_str: + remaining = " ".join(text_tokens[token_idx:]) + if remaining: + kwargs[param.name] = remaining + token_idx = len(text_tokens) + elif param.has_default: + kwargs[param.name] = param.default + else: + return None + else: + if token_idx < len(text_tokens): + kwargs[param.name] = text_tokens[token_idx] + token_idx += 1 + elif param.has_default: + kwargs[param.name] = param.default + else: + return None + + else: + # 未识别类型,尝试 str + if token_idx < len(text_tokens): + kwargs[param.name] = text_tokens[token_idx] + token_idx += 1 + elif param.has_default: + kwargs[param.name] = param.default + else: + return None + + return kwargs + + def __repr__(self) -> str: + return ( + f"" + ) + + +def _extract_typed_token( + tokens: List[str], start_idx: int, target_type: type +) -> Optional[Tuple[Any, int]]: + """从 tokens[start_idx:] 找到第一个可转换为 target_type 的 token""" + for i in range(start_idx, len(tokens)): + try: + return (target_type(tokens[i]), i + 1) + except (ValueError, TypeError): + continue + return None + + +def _is_type(annotation: Any, target: type) -> bool: + """检查注解是否为指定类型""" + if annotation is inspect.Parameter.empty: + return False + if annotation is target: + return True + if isinstance(annotation, type) and issubclass(annotation, target): + return True + if isinstance(annotation, str): + return annotation == target.__name__ + return False + + +class _ParamInfo: + """单个参数信息""" + + __slots__ = ("name", "annotation", "has_default", "default") + + def __init__(self, name: str, annotation: Any, has_default: bool, default: Any): + self.name = name + self.annotation = annotation + self.has_default = has_default + self.default = default + + +class _ParamSpec: + """handler 的参数规格""" + + __slots__ = ("params",) + + def __init__(self, params: List[_ParamInfo]): + self.params = params diff --git a/tests/integration/test_command_group_demo.py b/tests/integration/test_command_group_demo.py new file mode 100644 index 00000000..0394d47a --- /dev/null +++ b/tests/integration/test_command_group_demo.py @@ -0,0 +1,122 @@ +""" +集成测试:08_command_group 插件 +""" + +import pytest +from ncatbot.types.qq import GroupMessageEventData + + +@pytest.mark.asyncio +async def test_admin_kick_command(harness): + """测试 /admin kick 命令""" + event_data = GroupMessageEventData.model_validate( + { + "time": 1, + "self_id": "10001", + "post_type": "message", + "message_type": "group", + "sub_type": "normal", + "message_id": "1", + "group_id": "123456", + "user_id": "100", + "message": [], + "raw_message": "/admin kick 789", + "sender": {"user_id": "100", "nickname": "TestUser"}, + } + ) + + # 注入消息 + await harness.inject(event_data) + await harness.settle() + + +@pytest.mark.asyncio +async def test_admin_ban_command(harness): + """测试 /admin ban 命令""" + event_data = GroupMessageEventData.model_validate( + { + "time": 1, + "self_id": "10001", + "post_type": "message", + "message_type": "group", + "sub_type": "normal", + "message_id": "1", + "group_id": "123456", + "user_id": "100", + "message": [], + "raw_message": "/admin ban 789 120", + "sender": {"user_id": "100", "nickname": "TestUser"}, + } + ) + + await harness.inject(event_data) + await harness.settle() + + +@pytest.mark.asyncio +async def test_calc_add_command(harness): + """测试 /calc add 命令""" + event_data = GroupMessageEventData.model_validate( + { + "time": 1, + "self_id": "10001", + "post_type": "message", + "message_type": "group", + "sub_type": "normal", + "message_id": "1", + "group_id": "123456", + "user_id": "100", + "message": [], + "raw_message": "/calc add 10 20", + "sender": {"user_id": "100", "nickname": "TestUser"}, + } + ) + + await harness.inject(event_data) + await harness.settle() + + +@pytest.mark.asyncio +async def test_calc_divide_command(harness): + """测试 /calc divide 命令""" + event_data = GroupMessageEventData.model_validate( + { + "time": 1, + "self_id": "10001", + "post_type": "message", + "message_type": "group", + "sub_type": "normal", + "message_id": "1", + "group_id": "123456", + "user_id": "100", + "message": [], + "raw_message": "/calc divide 10.5 2.5", + "sender": {"user_id": "100", "nickname": "TestUser"}, + } + ) + + await harness.inject(event_data) + await harness.settle() + + +@pytest.mark.asyncio +async def test_calc_echo_command(harness): + """测试 /calc echo 命令""" + event_data = GroupMessageEventData.model_validate( + { + "time": 1, + "self_id": "10001", + "post_type": "message", + "message_type": "group", + "sub_type": "normal", + "message_id": "1", + "group_id": "123456", + "user_id": "100", + "message": [], + "raw_message": "/calc echo hello world", + "sender": {"user_id": "100", "nickname": "TestUser"}, + } + ) + + await harness.inject(event_data) + await harness.settle() diff --git a/tests/unit/core/test_command_group_hook.py b/tests/unit/core/test_command_group_hook.py new file mode 100644 index 00000000..0253956f --- /dev/null +++ b/tests/unit/core/test_command_group_hook.py @@ -0,0 +1,170 @@ +""" +CommandGroupHook 单元测试 — 核心功能 +""" + +import pytest +from unittest.mock import MagicMock + +from ncatbot.core import CommandGroupHook, HookAction +from ncatbot.core.registry.hook import HookContext + + +@pytest.fixture +def mock_context(): + """构造 mock context 的辅助工厂""" + + def _make(text: str, handler_func): + msg = MagicMock() + msg.text = text + msg.filter_at = MagicMock(return_value=[]) + + event = MagicMock() + event.data = MagicMock() + event.data.message = msg + + entry = MagicMock() + entry.func = handler_func + + return HookContext( + event=event, + event_type="message.group", + handler_entry=entry, + api=MagicMock(), + ) + + return _make + + +# ---- 命令匹配 ---- + + +@pytest.mark.asyncio +async def test_command_matching(mock_context): + """测试命令匹配(精确、别名、前缀)""" + hook = CommandGroupHook("admin", "/admin", "a") + + async def handler(event): + pass + + # 精确匹配命令名 + ctx = mock_context("admin", handler) + assert await hook.execute(ctx) == HookAction.CONTINUE + + # 别名匹配 + ctx = mock_context("/admin", handler) + assert await hook.execute(ctx) == HookAction.CONTINUE + + # 短别名 + ctx = mock_context("a", handler) + assert await hook.execute(ctx) == HookAction.CONTINUE + + # 不匹配 + ctx = mock_context("unknown", handler) + assert await hook.execute(ctx) == HookAction.SKIP + + +@pytest.mark.asyncio +async def test_case_sensitivity(mock_context): + """测试大小写敏感性""" + # 默认区分大小写 + hook = CommandGroupHook("admin") + + async def handler(event): + pass + + ctx = mock_context("ADMIN", handler) + assert await hook.execute(ctx) == HookAction.SKIP + + # 忽略大小写 + hook = CommandGroupHook("admin", ignore_case=True) + ctx = mock_context("ADMIN", handler) + assert await hook.execute(ctx) == HookAction.CONTINUE + + +# ---- 子命令参数绑定 ---- + + +@pytest.mark.asyncio +async def test_subcommand_and_parameters(mock_context): + """测试子命令路由与参数绑定""" + hook = CommandGroupHook("admin") + + async def handler(event): + pass + + # 子命令注册 + @hook.subcommand("ban", "禁言") + async def admin_ban(event, user_id: int, minutes: int = 60): + pass + + @hook.subcommand("kick") + async def admin_kick(event, user_id: int): + pass + + # 匹配 ban 子命令(主别名) + ctx = mock_context("admin ban 12345", handler) + result = await hook.execute(ctx) + assert result == HookAction.CONTINUE + assert ctx.kwargs.get("user_id") == 12345 + assert ctx.kwargs.get("minutes") == 60 # 默认值 + + # 匹配 ban 子命令(别名) + ctx = mock_context("admin 禁言 100 120", handler) + result = await hook.execute(ctx) + assert result == HookAction.CONTINUE + assert ctx.kwargs.get("user_id") == 100 + assert ctx.kwargs.get("minutes") == 120 + + # 匹配 kick 子命令 + ctx = mock_context("admin kick 200", handler) + result = await hook.execute(ctx) + assert result == HookAction.CONTINUE + assert ctx.kwargs.get("user_id") == 200 + + +@pytest.mark.asyncio +async def test_parameter_types(mock_context): + """测试多种参数类型绑定""" + hook = CommandGroupHook("calc") + + async def handler(event): + pass + + @hook.subcommand("math") + async def calc_math(event, a: int, b: float, text: str): + pass + + # int + float + str (str 获取剩余文本) + ctx = mock_context("calc math 10 3.14 hello world", handler) + result = await hook.execute(ctx) + assert result == HookAction.CONTINUE + assert ctx.kwargs.get("a") == 10 + assert ctx.kwargs.get("b") == 3.14 + assert ctx.kwargs.get("text") == "hello world" + + +# ---- 异常情况 ---- + + +@pytest.mark.asyncio +async def test_error_cases(mock_context): + """测试异常情况(空消息、缺失字段等)""" + hook = CommandGroupHook("admin") + + async def handler(event): + pass + + # 空消息 + ctx = mock_context("", handler) + assert await hook.execute(ctx) == HookAction.SKIP + + # 缺少 message 字段 + event = MagicMock() + event.data = MagicMock() + event.data.message = None + entry = MagicMock() + entry.func = handler + ctx = HookContext( + event=event, event_type="message.group", handler_entry=entry, api=MagicMock() + ) + assert await hook.execute(ctx) == HookAction.SKIP