diff --git a/src/bot/handlers/callback.py b/src/bot/handlers/callback.py
index 66dd660c..1a985e81 100644
--- a/src/bot/handlers/callback.py
+++ b/src/bot/handlers/callback.py
@@ -57,6 +57,11 @@ async def handle_callback_query(
             action, param = data, None
 
         # Route to appropriate handler
+        from .command import _handle_model_selection
+
+        async def _model_effort_handler(query, param, context):
+            await _handle_model_selection(query, f"{action}:{param}", context)
+
         handlers = {
             "cd": handle_cd_callback,
             "action": handle_action_callback,
@@ -66,6 +71,8 @@ async def handle_callback_query(
             "conversation": handle_conversation_callback,
             "git": handle_git_callback,
             "export": handle_export_callback,
+            "model": _model_effort_handler,
+            "effort": _model_effort_handler,
         }
 
         handler = handlers.get(action)
diff --git a/src/bot/handlers/command.py b/src/bot/handlers/command.py
index 651a08f8..18f16428 100644
--- a/src/bot/handlers/command.py
+++ b/src/bot/handlers/command.py
@@ -174,7 +174,8 @@ async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
         "• /status - Show session and usage status\n"
         "• /export - Export session history\n"
         "• /actions - Show context-aware quick actions\n"
-        "• /git - Git repository information\n\n"
+        "• /git - Git repository information\n"
+        "• /model - View or switch Claude model and effort\n\n"
         "Session Behavior:\n"
         "• Sessions are automatically maintained per project directory\n"
         "• Switching directories with /cd resumes the session for that project\n"
@@ -1232,6 +1233,182 @@ async def git_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non
         logger.error("Error in git_command", error=str(e), user_id=user_id)
 
 
+# Model IDs mapped to user-friendly labels
+_MODELS = {
+    "opus": "claude-opus-4-6",
+    "sonnet": "claude-sonnet-4-6",
+    "haiku": "claude-haiku-4-5-20251001",
+}
+
+# Effort levels per model (Haiku doesn't support effort; "max" is Opus-only)
+_EFFORT_BY_MODEL = {
+    "opus": ["low", "medium", "high", "max"],
+    "sonnet": ["low", "medium", "high"],
+    "haiku": [],
+}
+
+
+def _current_model_label(context: ContextTypes.DEFAULT_TYPE) -> str:
+    """Return a human-friendly label for the active model + effort."""
+    override = context.user_data.get("model_override")
+    effort = context.user_data.get("effort_override")
+    # Reverse-map model ID to short name
+    model_id = override or ""
+    label = model_id
+    for short, full in _MODELS.items():
+        if full == model_id:
+            label = short.capitalize()
+            break
+    if not override:
+        settings = context.bot_data.get("settings")
+        server_model = getattr(settings, "claude_model", None) if settings else None
+        label = f"Default ({server_model or 'CLI default'})"
+    parts = [label]
+    if effort:
+        parts.append(f"effort={effort}")
+    return " | ".join(parts)
+
+
+def _allowed_effort_levels(context: ContextTypes.DEFAULT_TYPE) -> set:
+    """Return the effort levels valid for the currently selected model."""
+    model_id = context.user_data.get("model_override")
+    for short, full in _MODELS.items():
+        if full == model_id:
+            return set(_EFFORT_BY_MODEL.get(short, []))
+    # No override (or an unrecognised ID): the server-side model is unknown
+    # here, so accept any level that at least one model supports.
+    return {lvl for levels in _EFFORT_BY_MODEL.values() for lvl in levels}
+
+
+async def model_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Handle /model command - show model selection keyboard."""
+    current = _current_model_label(context)
+
+    keyboard = [
+        [
+            InlineKeyboardButton("Opus", callback_data="model:opus"),
+            InlineKeyboardButton("Sonnet", callback_data="model:sonnet"),
+            InlineKeyboardButton("Haiku", callback_data="model:haiku"),
+        ],
+        [InlineKeyboardButton("Reset to default", callback_data="model:default")],
+    ]
+
+    await update.message.reply_text(
+        f"🤖 Current: {escape_html(current)}\n\n"
+        "Choose a model:\n"
+        "⚠️ Switching will start a new session.",
+        parse_mode="HTML",
+        reply_markup=InlineKeyboardMarkup(keyboard),
+    )
+
+
+async def _handle_model_selection(query, data: str, context) -> None:
+    """Shared logic for model/effort selection (used by both callback routes)."""
+    if data.startswith("model:"):
+        choice = data.split(":", 1)[1]
+
+        if choice == "default":
+            context.user_data.pop("model_override", None)
+            context.user_data.pop("effort_override", None)
+            context.user_data["force_new_session"] = True
+            await query.edit_message_text(
+                "🤖 Model and effort reset to server defaults.\n"
+                "Next message starts a fresh session.",
+                parse_mode="HTML",
+            )
+            logger.info("Model override cleared", user_id=query.from_user.id)
+            return
+
+        model_id = _MODELS.get(choice)
+        if not model_id:
+            await query.edit_message_text("Unknown model.")
+            return
+
+        context.user_data["model_override"] = model_id
+        # Clear stale effort when switching models
+        context.user_data.pop("effort_override", None)
+        # Force new session so the model change takes effect immediately
+        context.user_data["force_new_session"] = True
+
+        logger.info(
+            "Model override set",
+            user_id=query.from_user.id,
+            model=model_id,
+        )
+
+        # Show effort level selection (if supported by this model)
+        effort_levels = _EFFORT_BY_MODEL.get(choice, [])
+        if not effort_levels:
+            # Model doesn't support effort (e.g. Haiku)
+            current = _current_model_label(context)
+            await query.edit_message_text(
+                f"🤖 {escape_html(current)} — ready.\n"
+                "New session will start with your next message.",
+                parse_mode="HTML",
+            )
+            return
+
+        rows = []
+        row = []
+        for level in effort_levels:
+            row.append(
+                InlineKeyboardButton(level.capitalize(), callback_data=f"effort:{level}")
+            )
+            if len(row) == 2:
+                rows.append(row)
+                row = []
+        if row:
+            rows.append(row)
+        rows.append(
+            [InlineKeyboardButton("Skip (use default)", callback_data="effort:skip")]
+        )
+
+        await query.edit_message_text(
+            f"🤖 Model set to {escape_html(choice.capitalize())}.\n\n"
+            "Choose effort level:",
+            parse_mode="HTML",
+            reply_markup=InlineKeyboardMarkup(rows),
+        )
+
+    elif data.startswith("effort:"):
+        level = data.split(":", 1)[1]
+
+        if level == "skip":
+            current = _current_model_label(context)
+            await query.edit_message_text(
+                f"🤖 {escape_html(current)} — ready.\n"
+                "New session will start with your next message.",
+                parse_mode="HTML",
+            )
+            return
+
+        # Reject levels the selected model does not support (e.g. a stale
+        # keyboard offering "max" after switching away from Opus).
+        if level not in _allowed_effort_levels(context):
+            await query.edit_message_text("Unsupported effort level.")
+            return
+
+        context.user_data["effort_override"] = level
+        current = _current_model_label(context)
+        await query.edit_message_text(
+            f"🤖 {escape_html(current)} — ready.\n"
+            "New session will start with your next message.",
+            parse_mode="HTML",
+        )
+        logger.info(
+            "Effort override set",
+            user_id=query.from_user.id,
+            effort=level,
+        )
+
+
+async def model_callback(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Handle model and effort selection callbacks (agentic mode route)."""
+    query = update.callback_query
+    await query.answer()
+    await _handle_model_selection(query, query.data, context)
+
+
 async def restart_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
     """Handle /restart command - gracefully restart the bot process.
 
diff --git a/src/bot/handlers/message.py b/src/bot/handlers/message.py
index e5fa9f78..798e855b 100644
--- a/src/bot/handlers/message.py
+++ b/src/bot/handlers/message.py
@@ -393,6 +393,8 @@ async def stream_handler(update_obj):
                 session_id=session_id,
                 on_stream=stream_handler,
                 force_new=force_new,
+                model_override=context.user_data.get("model_override"),
+                effort_override=context.user_data.get("effort_override"),
             )
 
             # New session created successfully — clear the one-shot flag
@@ -824,6 +826,8 @@ async def handle_document(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
                 working_directory=current_dir,
                 user_id=user_id,
                 session_id=session_id,
+                model_override=context.user_data.get("model_override"),
+                effort_override=context.user_data.get("effort_override"),
             )
 
             # Update session ID
@@ -951,6 +955,8 @@ async def handle_photo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
                 working_directory=current_dir,
                 user_id=user_id,
                 session_id=session_id,
+                model_override=context.user_data.get("model_override"),
+                effort_override=context.user_data.get("effort_override"),
             )
 
             # Update session ID
@@ -1068,6 +1074,8 @@ async def handle_voice(update: Update, context: ContextTypes.DEFAULT_TYPE) -> No
                 working_directory=current_dir,
                 user_id=user_id,
                 session_id=session_id,
+                model_override=context.user_data.get("model_override"),
+                effort_override=context.user_data.get("effort_override"),
             )
 
             context.user_data["claude_session_id"] = claude_response.session_id
diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py
index a18248b0..73fcab18 100644
--- a/src/bot/orchestrator.py
+++ b/src/bot/orchestrator.py
@@ -319,6 +319,7 @@ def _register_agentic_handlers(self, app: Application) -> None:
             ("status", self.agentic_status),
             ("verbose", self.agentic_verbose),
             ("repo", self.agentic_repo),
+            ("model", command.model_command),
             ("restart", command.restart_command),
         ]
         if self.settings.enable_project_threads:
@@ -364,6 +365,14 @@ def _register_agentic_handlers(self, app: Application) -> None:
             )
         )
 
+        # Model/effort selection callbacks
+        app.add_handler(
+            CallbackQueryHandler(
+                self._inject_deps(command.model_callback),
+                pattern=r"^(model|effort):",
+            )
+        )
+
         # Only cd: callbacks (for project selection), scoped by pattern
         app.add_handler(
             CallbackQueryHandler(
@@ -392,6 +401,7 @@ def _register_classic_handlers(self, app: Application) -> None:
             ("export", command.export_session),
             ("actions", command.quick_actions),
             ("git", command.git_command),
+            ("model", command.model_command),
             ("restart", command.restart_command),
         ]
         if self.settings.enable_project_threads:
@@ -436,6 +446,7 @@ async def get_bot_commands(self) -> list:  # type: ignore[type-arg]
                 BotCommand("status", "Show session status"),
                 BotCommand("verbose", "Set output verbosity (0/1/2)"),
                 BotCommand("repo", "List repos / switch workspace"),
+                BotCommand("model", "Switch Claude model and effort"),
                 BotCommand("restart", "Restart the bot"),
             ]
             if self.settings.enable_project_threads:
@@ -456,6 +467,7 @@ async def get_bot_commands(self) -> list:  # type: ignore[type-arg]
             BotCommand("export", "Export current session"),
             BotCommand("actions", "Show quick actions"),
             BotCommand("git", "Git repository commands"),
+            BotCommand("model", "Switch Claude model and effort"),
             BotCommand("restart", "Restart the bot"),
         ]
         if self.settings.enable_project_threads:
@@ -990,6 +1002,8 @@ async def agentic_text(
                     on_stream=on_stream,
                     force_new=force_new,
                     interrupt_event=interrupt_event,
+                    model_override=context.user_data.get("model_override"),
+                    effort_override=context.user_data.get("effort_override"),
                 )
 
                 # New session created successfully — clear the one-shot flag
@@ -1240,6 +1254,8 @@ async def agentic_document(
                     session_id=session_id,
                     on_stream=on_stream,
                     force_new=force_new,
+                    model_override=context.user_data.get("model_override"),
+                    effort_override=context.user_data.get("effort_override"),
                 )
 
                 if force_new:
@@ -1439,6 +1455,8 @@ async def _handle_agentic_media_message(
                     session_id=session_id,
                     on_stream=on_stream,
                     force_new=force_new,
+                    model_override=context.user_data.get("model_override"),
+                    effort_override=context.user_data.get("effort_override"),
                 )
             finally:
                 heartbeat.cancel()
diff --git a/src/claude/facade.py b/src/claude/facade.py
index 5c7276eb..18225f98 100644
--- a/src/claude/facade.py
+++ b/src/claude/facade.py
@@ -39,6 +39,8 @@ async def run_command(
         on_stream: Optional[Callable[[StreamUpdate], None]] = None,
         force_new: bool = False,
         interrupt_event: Optional["asyncio.Event"] = None,
+        model_override: Optional[str] = None,
+        effort_override: Optional[str] = None,
     ) -> ClaudeResponse:
         """Run Claude Code command with full integration."""
         logger.info(
@@ -88,6 +90,8 @@ async def run_command(
                 continue_session=should_continue,
                 stream_callback=on_stream,
                 interrupt_event=interrupt_event,
+                model_override=model_override,
+                effort_override=effort_override,
             )
         except Exception as resume_error:
             # If resume failed (e.g., session expired/missing on Claude's side),
@@ -113,6 +117,8 @@ async def run_command(
                     continue_session=False,
                     stream_callback=on_stream,
                     interrupt_event=interrupt_event,
+                    model_override=model_override,
+                    effort_override=effort_override,
                 )
             else:
                 raise
@@ -157,6 +163,8 @@ async def _execute(
         continue_session: bool = False,
         stream_callback: Optional[Callable] = None,
         interrupt_event: Optional[asyncio.Event] = None,
+        model_override: Optional[str] = None,
+        effort_override: Optional[str] = None,
     ) -> ClaudeResponse:
         """Execute command via SDK."""
         return await self.sdk_manager.execute_command(
@@ -166,6 +174,8 @@ async def _execute(
             continue_session=continue_session,
             stream_callback=stream_callback,
             interrupt_event=interrupt_event,
+            model_override=model_override,
+            effort_override=effort_override,
         )
 
     async def _find_resumable_session(
diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py
index ab9c4046..8978fb87 100644
--- a/src/claude/sdk_integration.py
+++ b/src/claude/sdk_integration.py
@@ -155,6 +155,8 @@ async def execute_command(
         continue_session: bool = False,
         stream_callback: Optional[Callable[[StreamUpdate], None]] = None,
         interrupt_event: Optional[asyncio.Event] = None,
+        model_override: Optional[str] = None,
+        effort_override: Optional[str] = None,
     ) -> ClaudeResponse:
         """Execute Claude Code command via SDK."""
         start_time = asyncio.get_event_loop().time()
@@ -199,7 +201,7 @@ def _stderr_callback(line: str) -> None:
             # Build Claude Agent options
             options = ClaudeAgentOptions(
                 max_turns=self.config.claude_max_turns,
-                model=self.config.claude_model or None,
+                model=model_override or self.config.claude_model or None,
                 max_budget_usd=self.config.claude_max_cost_per_request,
                 cwd=str(working_directory),
                 allowed_tools=sdk_allowed_tools,
@@ -212,6 +214,7 @@ def _stderr_callback(line: str) -> None:
                     "excludedCommands": self.config.sandbox_excluded_commands or [],
                 },
                 system_prompt=base_prompt,
+                effort=effort_override,
                 setting_sources=["project"],
                 stderr=_stderr_callback,
             )
diff --git a/tests/unit/test_bot/test_model_command.py b/tests/unit/test_bot/test_model_command.py
new file mode 100644
index 00000000..538f5994
--- /dev/null
+++ b/tests/unit/test_bot/test_model_command.py
@@ -0,0 +1,272 @@
+"""Tests for the /model command — runtime model and effort switching.
+
+Covers:
+- /model shows inline keyboard with model choices
+- Model selection sets model_override and force_new_session
+- Effort selection sets effort_override
+- "default" clears all overrides
+- Haiku skips effort keyboard (not supported)
+- Opus shows "max" effort, Sonnet does not
+- _current_model_label returns correct labels
+"""
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from telegram import InlineKeyboardMarkup
+
+from src.bot.handlers.command import (
+    _EFFORT_BY_MODEL,
+    _MODELS,
+    _current_model_label,
+    _handle_model_selection,
+    model_command,
+)
+from src.config.settings import Settings
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+def settings(tmp_path):
+    return Settings(
+        telegram_bot_token="test:token",
+        telegram_bot_username="testbot",
+        approved_directory=tmp_path,
+    )
+
+
+@pytest.fixture
+def context(settings):
+    ctx = MagicMock()
+    ctx.user_data = {}
+    ctx.bot_data = {"settings": settings}
+    ctx.args = None
+    return ctx
+
+
+@pytest.fixture
+def update(context):
+    upd = MagicMock()
+    upd.message = AsyncMock()
+    upd.effective_user.id = 12345
+    return upd
+
+
+@pytest.fixture
+def callback_query():
+    query = MagicMock()
+    query.answer = AsyncMock()
+    query.edit_message_text = AsyncMock()
+    query.from_user.id = 12345
+    return query
+
+
+# ---------------------------------------------------------------------------
+# /model command (keyboard display)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_model_command_shows_keyboard(update, context):
+    """Verify /model sends an inline keyboard with model choices."""
+    await model_command(update, context)
+
+    update.message.reply_text.assert_called_once()
+    call_kwargs = update.message.reply_text.call_args
+    assert isinstance(call_kwargs.kwargs["reply_markup"], InlineKeyboardMarkup)
+    # Should contain the session warning
+    assert "new session" in call_kwargs.args[0].lower()
+
+
+@pytest.mark.asyncio
+async def test_model_command_shows_current_override(update, context):
+    """When an override is active, /model should show it."""
+    context.user_data["model_override"] = _MODELS["sonnet"]
+    await model_command(update, context)
+
+    text = update.message.reply_text.call_args.args[0]
+    assert "Sonnet" in text
+
+
+# ---------------------------------------------------------------------------
+# Model selection callback
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_select_opus_sets_override(callback_query, context):
+    """Selecting Opus sets model_override and force_new_session."""
+    await _handle_model_selection(callback_query, "model:opus", context)
+
+    assert context.user_data["model_override"] == _MODELS["opus"]
+    assert context.user_data["force_new_session"] is True
+
+
+@pytest.mark.asyncio
+async def test_select_sonnet_sets_override(callback_query, context):
+    """Selecting Sonnet sets the correct model ID."""
+    await _handle_model_selection(callback_query, "model:sonnet", context)
+
+    assert context.user_data["model_override"] == _MODELS["sonnet"]
+
+
+@pytest.mark.asyncio
+async def test_select_haiku_skips_effort(callback_query, context):
+    """Selecting Haiku should not show effort keyboard (not supported)."""
+    await _handle_model_selection(callback_query, "model:haiku", context)
+
+    assert context.user_data["model_override"] == _MODELS["haiku"]
+    # Final message, no reply_markup (no effort keyboard)
+    call_kwargs = callback_query.edit_message_text.call_args
+    assert "reply_markup" not in call_kwargs.kwargs or call_kwargs.kwargs.get("reply_markup") is None
+    assert "ready" in call_kwargs.args[0].lower()
+
+
+@pytest.mark.asyncio
+async def test_select_opus_shows_effort_with_max(callback_query, context):
+    """Opus should show effort keyboard including 'max'."""
+    await _handle_model_selection(callback_query, "model:opus", context)
+
+    call_kwargs = callback_query.edit_message_text.call_args
+    markup = call_kwargs.kwargs.get("reply_markup")
+    assert markup is not None
+    # Flatten button labels
+    labels = [btn.text for row in markup.inline_keyboard for btn in row]
+    assert "Max" in labels
+    assert "effort" in call_kwargs.args[0].lower()
+
+
+@pytest.mark.asyncio
+async def test_select_sonnet_shows_effort_without_max(callback_query, context):
+    """Sonnet should show effort keyboard without 'max'."""
+    await _handle_model_selection(callback_query, "model:sonnet", context)
+
+    call_kwargs = callback_query.edit_message_text.call_args
+    markup = call_kwargs.kwargs.get("reply_markup")
+    assert markup is not None
+    labels = [btn.text for row in markup.inline_keyboard for btn in row]
+    assert "Max" not in labels
+    assert "High" in labels
+
+
+@pytest.mark.asyncio
+async def test_default_clears_overrides(callback_query, context):
+    """Selecting 'default' clears model, effort, and forces new session."""
+    context.user_data["model_override"] = _MODELS["opus"]
+    context.user_data["effort_override"] = "high"
+
+    await _handle_model_selection(callback_query, "model:default", context)
+
+    assert "model_override" not in context.user_data
+    assert "effort_override" not in context.user_data
+    assert context.user_data["force_new_session"] is True
+
+
+# ---------------------------------------------------------------------------
+# Effort selection callback
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_effort_sets_override(callback_query, context):
+    """Selecting an effort level stores it in user_data."""
+    context.user_data["model_override"] = _MODELS["opus"]
+
+    await _handle_model_selection(callback_query, "effort:high", context)
+
+    assert context.user_data["effort_override"] == "high"
+
+
+@pytest.mark.asyncio
+async def test_effort_skip_keeps_existing(callback_query, context):
+    """Selecting 'skip' should not set effort_override."""
+    context.user_data["model_override"] = _MODELS["sonnet"]
+
+    await _handle_model_selection(callback_query, "effort:skip", context)
+
+    assert "effort_override" not in context.user_data
+
+
+@pytest.mark.asyncio
+async def test_model_switch_clears_stale_effort(callback_query, context):
+    """Switching models should clear any previous effort override."""
+    context.user_data["effort_override"] = "high"
+
+    await _handle_model_selection(callback_query, "model:haiku", context)
+
+    assert "effort_override" not in context.user_data
+
+
+@pytest.mark.asyncio
+async def test_effort_rejects_level_unsupported_by_model(callback_query, context):
+    """'max' is Opus-only and must not be stored for Sonnet."""
+    context.user_data["model_override"] = _MODELS["sonnet"]
+
+    await _handle_model_selection(callback_query, "effort:max", context)
+
+    assert "effort_override" not in context.user_data
+    callback_query.edit_message_text.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_effort_rejects_unknown_level(callback_query, context):
+    """Unknown effort levels are reported instead of silently ignored."""
+    context.user_data["model_override"] = _MODELS["opus"]
+
+    await _handle_model_selection(callback_query, "effort:bogus", context)
+
+    assert "effort_override" not in context.user_data
+    callback_query.edit_message_text.assert_called_once()
+
+
+# ---------------------------------------------------------------------------
+# Label helper
+# ---------------------------------------------------------------------------
+
+
+def test_label_default_no_settings():
+    ctx = MagicMock()
+    ctx.user_data = {}
+    ctx.bot_data = {}
+    assert _current_model_label(ctx) == "Default (CLI default)"
+
+
+def test_label_default_with_server_model():
+    ctx = MagicMock()
+    ctx.user_data = {}
+    ctx.bot_data = {"settings": MagicMock(claude_model="claude-sonnet-4-6")}
+    assert _current_model_label(ctx) == "Default (claude-sonnet-4-6)"
+
+
+def test_label_with_model_and_effort():
+    ctx = MagicMock()
+    ctx.user_data = {"model_override": _MODELS["sonnet"], "effort_override": "medium"}
+    assert _current_model_label(ctx) == "Sonnet | effort=medium"
+
+
+def test_label_model_only():
+    ctx = MagicMock()
+    ctx.user_data = {"model_override": _MODELS["opus"]}
+    assert _current_model_label(ctx) == "Opus"
+
+
+# ---------------------------------------------------------------------------
+# Effort level configuration
+# ---------------------------------------------------------------------------
+
+
+def test_haiku_has_no_effort_levels():
+    assert _EFFORT_BY_MODEL["haiku"] == []
+
+
+def test_sonnet_has_no_max():
+    assert "max" not in _EFFORT_BY_MODEL["sonnet"]
+    assert "high" in _EFFORT_BY_MODEL["sonnet"]
+
+
+def test_opus_has_max():
+    assert "max" in _EFFORT_BY_MODEL["opus"]
diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py
index 320f54ae..556f7071 100644
--- a/tests/unit/test_orchestrator.py
+++ b/tests/unit/test_orchestrator.py
@@ -82,8 +82,8 @@ def deps():
     }
 
 
-def test_agentic_registers_6_commands(agentic_settings, deps):
-    """Agentic mode registers start, new, status, verbose, repo, restart commands."""
+def test_agentic_registers_7_commands(agentic_settings, deps):
+    """Agentic mode registers start, new, status, verbose, repo, model, restart commands."""
     orchestrator = MessageOrchestrator(agentic_settings, deps)
     app = MagicMock()
     app.add_handler = MagicMock()
@@ -100,17 +100,18 @@ def test_agentic_registers_6_commands(agentic_settings, deps):
     ]
     commands = [h[0][0].commands for h in cmd_handlers]
 
-    assert len(cmd_handlers) == 6
+    assert len(cmd_handlers) == 7
     assert frozenset({"start"}) in commands
     assert frozenset({"new"}) in commands
     assert frozenset({"status"}) in commands
     assert frozenset({"verbose"}) in commands
     assert frozenset({"repo"}) in commands
+    assert frozenset({"model"}) in commands
     assert frozenset({"restart"}) in commands
 
 
-def test_classic_registers_14_commands(classic_settings, deps):
-    """Classic mode registers all 14 commands."""
+def test_classic_registers_15_commands(classic_settings, deps):
+    """Classic mode registers all 15 commands."""
     orchestrator = MessageOrchestrator(classic_settings, deps)
     app = MagicMock()
     app.add_handler = MagicMock()
@@ -125,7 +126,7 @@ def test_classic_registers_14_commands(classic_settings, deps):
         if isinstance(call[0][0], CommandHandler)
     ]
 
-    assert len(cmd_handlers) == 14
+    assert len(cmd_handlers) == 15
 
 
 def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps):
@@ -151,8 +152,8 @@ def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps):
 
     # 4 message handlers (text, document, photo, voice)
     assert len(msg_handlers) == 4
-    # 2 callback handlers (stop: + cd:)
-    assert len(cb_handlers) == 2
+    # 3 callback handlers (stop: + model/effort: + cd:)
+    assert len(cb_handlers) == 3
 
 
 async def test_agentic_bot_commands(agentic_settings, deps):
@@ -160,17 +161,17 @@ async def test_agentic_bot_commands(agentic_settings, deps):
     orchestrator = MessageOrchestrator(agentic_settings, deps)
     commands = await orchestrator.get_bot_commands()
 
-    assert len(commands) == 6
+    assert len(commands) == 7
     cmd_names = [c.command for c in commands]
-    assert cmd_names == ["start", "new", "status", "verbose", "repo", "restart"]
+    assert cmd_names == ["start", "new", "status", "verbose", "repo", "model", "restart"]
 
 
 async def test_classic_bot_commands(classic_settings, deps):
-    """Classic mode returns 14 bot commands."""
+    """Classic mode returns 15 bot commands."""
     orchestrator = MessageOrchestrator(classic_settings, deps)
     commands = await orchestrator.get_bot_commands()
 
-    assert len(commands) == 14
+    assert len(commands) == 15
     cmd_names = [c.command for c in commands]
     assert "start" in cmd_names
     assert "help" in cmd_names
@@ -338,7 +339,7 @@ async def test_agentic_callback_scoped_to_cd_pattern(agentic_settings, deps):
         if isinstance(call[0][0], CallbackQueryHandler)
     ]
 
-    assert len(cb_handlers) == 2
+    assert len(cb_handlers) == 3
     # Find the cd: handler by pattern
     cd_handler = [h for h in cb_handlers if h.pattern and h.pattern.match("cd:x")]
     assert len(cd_handler) == 1