diff --git a/src/bot/orchestrator.py b/src/bot/orchestrator.py
index ac1d5304..da997cac 100644
--- a/src/bot/orchestrator.py
+++ b/src/bot/orchestrator.py
@@ -306,6 +306,7 @@ def _register_agentic_handlers(self, app: Application) -> None:
("new", self.agentic_new),
("status", self.agentic_status),
("verbose", self.agentic_verbose),
+ ("model", self.agentic_model),
("repo", self.agentic_repo),
("restart", command.restart_command),
]
@@ -415,6 +416,7 @@ async def get_bot_commands(self) -> list: # type: ignore[type-arg]
BotCommand("new", "Start a fresh session"),
BotCommand("status", "Show session status"),
BotCommand("verbose", "Set output verbosity (0/1/2)"),
+ BotCommand("model", "Switch Claude model"),
BotCommand("repo", "List repos / switch workspace"),
BotCommand("restart", "Restart the bot"),
]
@@ -578,6 +580,81 @@ async def agentic_verbose(
parse_mode="HTML",
)
+ def _get_model_override(self, context: ContextTypes.DEFAULT_TYPE) -> Optional[str]:
+ """Return per-user model override, or None to use the default."""
+ return context.user_data.get("model_override")
+
+ @staticmethod
+ def _resolve_model_display(
+ user_override: Optional[str],
+ config_model: Optional[str],
+ last_model: Optional[str] = None,
+ ) -> str:
+ """Return a human-readable model string showing what will actually be used."""
+ if user_override:
+ return user_override
+ if config_model:
+ return config_model
+ if last_model:
+ return last_model
+ return "unknown (send a message first to detect)"
+
+ async def agentic_model(
+ self, update: Update, context: ContextTypes.DEFAULT_TYPE
+ ) -> None:
+ """Set Claude model: /model [model_name]."""
+ args = update.message.text.split()[1:] if update.message.text else []
+ user_override = self._get_model_override(context)
+ last_model = context.user_data.get("last_model")
+ current = self._resolve_model_display(
+ user_override, self.settings.claude_model, last_model
+ )
+
+ if not args:
+ source = "user override" if user_override else (
+ "server config" if self.settings.claude_model else "Claude Code default"
+ )
+ await update.message.reply_text(
+ f"Model: {escape_html(current)} ({source})\n\n"
+ "Usage: /model model_name\n"
+ "Reset: /model default",
+ parse_mode="HTML",
+ )
+ return
+
+ model_name = args[0].strip()
+ if not model_name or len(model_name) > 100:
+ await update.message.reply_text("Invalid model name.")
+ return
+ audit_logger = context.bot_data.get("audit_logger")
+ if model_name == "default":
+ context.user_data.pop("model_override", None)
+ default = self._resolve_model_display(None, self.settings.claude_model)
+ await update.message.reply_text(
+ f"Model reset to {escape_html(default)}",
+ parse_mode="HTML",
+ )
+ if audit_logger:
+ await audit_logger.log_command(
+ user_id=update.effective_user.id,
+ command="model_reset",
+ args=[],
+ success=True,
+ )
+ else:
+ context.user_data["model_override"] = model_name
+ await update.message.reply_text(
+ f"Model set to {escape_html(model_name)}",
+ parse_mode="HTML",
+ )
+ if audit_logger:
+ await audit_logger.log_command(
+ user_id=update.effective_user.id,
+ command="model",
+ args=[model_name],
+ success=True,
+ )
+
def _format_verbose_progress(
self,
activity_log: List[Dict[str, Any]],
@@ -941,6 +1018,7 @@ async def agentic_text(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
+ model_override=self._get_model_override(context),
)
# New session created successfully — clear the one-shot flag
@@ -948,6 +1026,8 @@ async def agentic_text(
context.user_data["force_new_session"] = False
context.user_data["claude_session_id"] = claude_response.session_id
+ if claude_response.model:
+ context.user_data["last_model"] = claude_response.model
# Track directory changes
from .handlers.message import _update_working_directory_from_claude_response
@@ -1185,12 +1265,15 @@ async def agentic_document(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
+ model_override=self._get_model_override(context),
)
if force_new:
context.user_data["force_new_session"] = False
context.user_data["claude_session_id"] = claude_response.session_id
+ if claude_response.model:
+ context.user_data["last_model"] = claude_response.model
from .handlers.message import _update_working_directory_from_claude_response
@@ -1384,6 +1467,7 @@ async def _handle_agentic_media_message(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
+ model_override=self._get_model_override(context),
)
finally:
heartbeat.cancel()
@@ -1392,6 +1476,7 @@ async def _handle_agentic_media_message(
context.user_data["force_new_session"] = False
context.user_data["claude_session_id"] = claude_response.session_id
+ context.user_data["last_model"] = claude_response.model
from .handlers.message import _update_working_directory_from_claude_response
diff --git a/src/claude/facade.py b/src/claude/facade.py
index fcb2ada6..09545ff6 100644
--- a/src/claude/facade.py
+++ b/src/claude/facade.py
@@ -37,6 +37,7 @@ async def run_command(
session_id: Optional[str] = None,
on_stream: Optional[Callable[[StreamUpdate], None]] = None,
force_new: bool = False,
+ model_override: Optional[str] = None,
) -> ClaudeResponse:
"""Run Claude Code command with full integration."""
logger.info(
@@ -85,6 +86,7 @@ async def run_command(
session_id=claude_session_id,
continue_session=should_continue,
stream_callback=on_stream,
+ model_override=model_override,
)
except Exception as resume_error:
# If resume failed (e.g., session expired/missing on Claude's side),
@@ -109,6 +111,7 @@ async def run_command(
session_id=None,
continue_session=False,
stream_callback=on_stream,
+ model_override=model_override,
)
else:
raise
@@ -152,6 +155,7 @@ async def _execute(
session_id: Optional[str] = None,
continue_session: bool = False,
stream_callback: Optional[Callable] = None,
+ model_override: Optional[str] = None,
) -> ClaudeResponse:
"""Execute command via SDK."""
return await self.sdk_manager.execute_command(
@@ -160,6 +164,7 @@ async def _execute(
session_id=session_id,
continue_session=continue_session,
stream_callback=stream_callback,
+ model_override=model_override,
)
async def _find_resumable_session(
diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py
index adf553f4..c6b5e591 100644
--- a/src/claude/sdk_integration.py
+++ b/src/claude/sdk_integration.py
@@ -53,6 +53,7 @@ class ClaudeResponse:
is_error: bool = False
error_type: Optional[str] = None
tools_used: List[Dict[str, Any]] = field(default_factory=list)
+ model: Optional[str] = None
@dataclass
@@ -153,6 +154,7 @@ async def execute_command(
session_id: Optional[str] = None,
continue_session: bool = False,
stream_callback: Optional[Callable[[StreamUpdate], None]] = None,
+ model_override: Optional[str] = None,
) -> ClaudeResponse:
"""Execute Claude Code command via SDK."""
start_time = asyncio.get_event_loop().time()
@@ -197,7 +199,7 @@ def _stderr_callback(line: str) -> None:
# Build Claude Agent options
options = ClaudeAgentOptions(
max_turns=self.config.claude_max_turns,
- model=self.config.claude_model or None,
+ model=model_override or self.config.claude_model or None,
max_budget_usd=self.config.claude_max_cost_per_request,
cwd=str(working_directory),
allowed_tools=sdk_allowed_tools,
@@ -294,11 +296,12 @@ async def _run_client() -> None:
timeout=self.config.claude_timeout_seconds,
)
- # Extract cost, tools, and session_id from result message
+ # Extract cost, tools, session_id, and model from result message
cost = 0.0
tools_used: List[Dict[str, Any]] = []
claude_session_id = None
result_content = None
+ response_model: Optional[str] = None
for message in messages:
if isinstance(message, ResultMessage):
cost = getattr(message, "total_cost_usd", 0.0) or 0.0
@@ -307,6 +310,8 @@ async def _run_client() -> None:
current_time = asyncio.get_event_loop().time()
for msg in messages:
if isinstance(msg, AssistantMessage):
+ if not response_model:
+ response_model = getattr(msg, "model", None)
msg_content = getattr(msg, "content", [])
if msg_content and isinstance(msg_content, list):
for block in msg_content:
@@ -377,6 +382,7 @@ async def _run_client() -> None:
]
),
tools_used=tools_used,
+ model=response_model,
)
except asyncio.TimeoutError:
diff --git a/tests/unit/test_claude/test_facade.py b/tests/unit/test_claude/test_facade.py
index 666a2246..814522e2 100644
--- a/tests/unit/test_claude/test_facade.py
+++ b/tests/unit/test_claude/test_facade.py
@@ -269,6 +269,85 @@ async def test_retry_after_failure_still_skips_auto_resume(
assert user_data["force_new_session"] is False
+class TestModelOverride:
+ """Verify model_override is passed through to _execute."""
+
+ async def test_model_override_forwarded_to_execute(self, facade, session_manager):
+ """run_command passes model_override through to _execute."""
+ project = Path("/test/project")
+ user_id = 123
+
+ with patch.object(
+ facade,
+ "_execute",
+ return_value=_make_mock_response(),
+ ) as mock_execute:
+ await facade.run_command(
+ prompt="hello",
+ working_directory=project,
+ user_id=user_id,
+ model_override="opus",
+ )
+
+ mock_execute.assert_called_once()
+ assert mock_execute.call_args.kwargs["model_override"] == "opus"
+
+ async def test_model_override_none_by_default(self, facade, session_manager):
+ """run_command passes model_override=None when not specified."""
+ project = Path("/test/project")
+ user_id = 123
+
+ with patch.object(
+ facade,
+ "_execute",
+ return_value=_make_mock_response(),
+ ) as mock_execute:
+ await facade.run_command(
+ prompt="hello",
+ working_directory=project,
+ user_id=user_id,
+ )
+
+ mock_execute.assert_called_once()
+ assert mock_execute.call_args.kwargs["model_override"] is None
+
+ async def test_model_override_survives_session_retry(self, facade, session_manager):
+ """model_override is preserved when session resume fails and retries."""
+ project = Path("/test/project")
+ user_id = 123
+
+ # Seed an existing session so resume is attempted
+ existing = ClaudeSession(
+ session_id="old-session",
+ user_id=user_id,
+ project_path=project,
+ created_at=datetime.utcnow(),
+ last_used=datetime.utcnow(),
+ )
+ await session_manager.storage.save_session(existing)
+ session_manager.active_sessions[existing.session_id] = existing
+
+ call_count = [0]
+
+ async def _execute_side_effect(**kwargs):
+ call_count[0] += 1
+ if call_count[0] == 1:
+ raise RuntimeError("session expired")
+ return _make_mock_response()
+
+ with patch.object(facade, "_execute", side_effect=_execute_side_effect):
+ await facade.run_command(
+ prompt="hello",
+ working_directory=project,
+ user_id=user_id,
+ session_id="old-session",
+ model_override="haiku",
+ )
+
+ # Both the initial call and retry should have model_override="haiku"
+ assert call_count[0] == 2
+
+
class TestEmptySessionIdWarning:
"""Verify facade warns when final session_id is empty."""
diff --git a/tests/unit/test_claude/test_sdk_integration.py b/tests/unit/test_claude/test_sdk_integration.py
index 17ba58ab..0af40e46 100644
--- a/tests/unit/test_claude/test_sdk_integration.py
+++ b/tests/unit/test_claude/test_sdk_integration.py
@@ -696,6 +696,66 @@ async def test_claude_model_none_when_unset(self, tmp_path):
assert len(captured_options) == 1
assert captured_options[0].model is None
+ async def test_model_override_takes_priority(self, tmp_path):
+ """Test that model_override overrides claude_model from config."""
+ config = Settings(
+ telegram_bot_token="test:token",
+ telegram_bot_username="testbot",
+ approved_directory=tmp_path,
+ claude_timeout_seconds=2,
+ claude_model="claude-sonnet-4-6",
+ )
+ manager = ClaudeSDKManager(config)
+
+ captured_options = []
+ mock_factory = _mock_client_factory(
+ _make_assistant_message("Test response"),
+ _make_result_message(total_cost_usd=0.01),
+ capture_options=captured_options,
+ )
+
+ with patch(
+ "src.claude.sdk_integration.ClaudeSDKClient", side_effect=mock_factory
+ ):
+ await manager.execute_command(
+ prompt="Test prompt",
+ working_directory=tmp_path,
+ model_override="claude-opus-4-6",
+ )
+
+ assert len(captured_options) == 1
+ assert captured_options[0].model == "claude-opus-4-6"
+
+ async def test_model_override_none_uses_config(self, tmp_path):
+ """Test that model_override=None falls back to config model."""
+ config = Settings(
+ telegram_bot_token="test:token",
+ telegram_bot_username="testbot",
+ approved_directory=tmp_path,
+ claude_timeout_seconds=2,
+ claude_model="claude-haiku-4-5-20251001",
+ )
+ manager = ClaudeSDKManager(config)
+
+ captured_options = []
+ mock_factory = _mock_client_factory(
+ _make_assistant_message("Test response"),
+ _make_result_message(total_cost_usd=0.01),
+ capture_options=captured_options,
+ )
+
+ with patch(
+ "src.claude.sdk_integration.ClaudeSDKClient", side_effect=mock_factory
+ ):
+ await manager.execute_command(
+ prompt="Test prompt",
+ working_directory=tmp_path,
+ model_override=None,
+ )
+
+ assert len(captured_options) == 1
+ assert captured_options[0].model == "claude-haiku-4-5-20251001"
+
class TestClaudeMCPErrors:
"""Test MCP-specific error handling."""
diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py
index cc02b7c0..6029e0cc 100644
--- a/tests/unit/test_orchestrator.py
+++ b/tests/unit/test_orchestrator.py
@@ -82,8 +82,8 @@ def deps():
}
-def test_agentic_registers_6_commands(agentic_settings, deps):
- """Agentic mode registers start, new, status, verbose, repo, restart commands."""
+def test_agentic_registers_7_commands(agentic_settings, deps):
+ """Agentic mode registers start, new, status, verbose, model, repo, restart."""
orchestrator = MessageOrchestrator(agentic_settings, deps)
app = MagicMock()
app.add_handler = MagicMock()
@@ -100,11 +100,12 @@ def test_agentic_registers_6_commands(agentic_settings, deps):
]
commands = [h[0][0].commands for h in cmd_handlers]
- assert len(cmd_handlers) == 6
+ assert len(cmd_handlers) == 7
assert frozenset({"start"}) in commands
assert frozenset({"new"}) in commands
assert frozenset({"status"}) in commands
assert frozenset({"verbose"}) in commands
+ assert frozenset({"model"}) in commands
assert frozenset({"repo"}) in commands
assert frozenset({"restart"}) in commands
@@ -156,13 +157,13 @@ def test_agentic_registers_text_document_photo_handlers(agentic_settings, deps):
async def test_agentic_bot_commands(agentic_settings, deps):
- """Agentic mode returns 6 bot commands."""
+ """Agentic mode returns 7 bot commands."""
orchestrator = MessageOrchestrator(agentic_settings, deps)
commands = await orchestrator.get_bot_commands()
- assert len(commands) == 6
+ assert len(commands) == 7
cmd_names = [c.command for c in commands]
- assert cmd_names == ["start", "new", "status", "verbose", "repo", "restart"]
+ assert cmd_names == ["start", "new", "status", "verbose", "model", "repo", "restart"]
async def test_classic_bot_commands(classic_settings, deps):
@@ -926,3 +927,265 @@ async def help_command(update, context):
assert called["value"] is False
update.effective_message.reply_text.assert_called_once()
+
+
+# --- /model command tests ---
+
+
+async def test_agentic_model_shows_last_model_when_unset(agentic_settings, deps):
+ """/model with no override shows the model from the last response."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ update = MagicMock()
+ update.message.text = "/model"
+ update.message.reply_text = AsyncMock()
+
+ context = MagicMock()
+ context.user_data = {"last_model": "claude-opus-4-6"}
+
+ await orchestrator.agentic_model(update, context)
+
+ call_args = update.message.reply_text.call_args
+ text = call_args.args[0]
+ assert "claude-opus-4-6" in text
+ assert "Claude Code default" in text
+
+
+async def test_agentic_model_shows_unknown_before_first_message(agentic_settings, deps):
+ """/model before any message shows unknown."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ update = MagicMock()
+ update.message.text = "/model"
+ update.message.reply_text = AsyncMock()
+
+ context = MagicMock()
+ context.user_data = {}
+
+ await orchestrator.agentic_model(update, context)
+
+ call_args = update.message.reply_text.call_args
+ text = call_args.args[0]
+ assert "unknown" in text.lower()
+ assert call_args.kwargs.get("parse_mode") == "HTML"
+
+
+async def test_agentic_model_shows_config_model(tmp_dir, deps):
+ """/model shows the server-configured model when CLAUDE_MODEL is set."""
+ settings = create_test_config(
+ approved_directory=str(tmp_dir),
+ agentic_mode=True,
+ claude_model="claude-opus-4-6",
+ )
+ orchestrator = MessageOrchestrator(settings, deps)
+
+ update = MagicMock()
+ update.message.text = "/model"
+ update.message.reply_text = AsyncMock()
+
+ context = MagicMock()
+ context.user_data = {}
+
+ await orchestrator.agentic_model(update, context)
+
+ text = update.message.reply_text.call_args.args[0]
+ assert "claude-opus-4-6" in text
+ assert "server config" in text
+
+
+async def test_agentic_model_shows_user_override(agentic_settings, deps):
+ """/model shows the user's override when one is set."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ update = MagicMock()
+ update.message.text = "/model"
+ update.message.reply_text = AsyncMock()
+
+ context = MagicMock()
+ context.user_data = {"model_override": "haiku"}
+
+ await orchestrator.agentic_model(update, context)
+
+ text = update.message.reply_text.call_args.args[0]
+ assert "haiku" in text
+ assert "user override" in text
+
+
+async def test_agentic_model_sets_override(agentic_settings, deps):
+ """/model sonnet sets the user's model override."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ update = MagicMock()
+ update.message.text = "/model sonnet"
+ update.message.reply_text = AsyncMock()
+ update.effective_user.id = 123
+
+ context = MagicMock()
+ context.user_data = {}
+ context.bot_data = {"audit_logger": AsyncMock()}
+
+ await orchestrator.agentic_model(update, context)
+
+ assert context.user_data["model_override"] == "sonnet"
+ text = update.message.reply_text.call_args.args[0]
+ assert "sonnet" in text
+
+
+async def test_agentic_model_reset_to_default(agentic_settings, deps):
+ """/model default clears the user's model override."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ update = MagicMock()
+ update.message.text = "/model default"
+ update.message.reply_text = AsyncMock()
+ update.effective_user.id = 123
+
+ context = MagicMock()
+ context.user_data = {"model_override": "opus"}
+ context.bot_data = {"audit_logger": AsyncMock()}
+
+ await orchestrator.agentic_model(update, context)
+
+ assert "model_override" not in context.user_data
+ text = update.message.reply_text.call_args.args[0]
+ assert "reset" in text.lower()
+
+
+async def test_agentic_model_audit_logged(agentic_settings, deps):
+ """/model sonnet logs the action to audit logger."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ update = MagicMock()
+ update.message.text = "/model sonnet"
+ update.message.reply_text = AsyncMock()
+ update.effective_user.id = 42
+
+ audit_logger = AsyncMock()
+ context = MagicMock()
+ context.user_data = {}
+ context.bot_data = {"audit_logger": audit_logger}
+
+ await orchestrator.agentic_model(update, context)
+
+ audit_logger.log_command.assert_called_once_with(
+ user_id=42, command="model", args=["sonnet"], success=True,
+ )
+
+
+async def test_agentic_model_reset_audit_logged(agentic_settings, deps):
+ """/model default logs as model_reset with empty args."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ update = MagicMock()
+ update.message.text = "/model default"
+ update.message.reply_text = AsyncMock()
+ update.effective_user.id = 42
+
+ audit_logger = AsyncMock()
+ context = MagicMock()
+ context.user_data = {"model_override": "opus"}
+ context.bot_data = {"audit_logger": audit_logger}
+
+ await orchestrator.agentic_model(update, context)
+
+ audit_logger.log_command.assert_called_once_with(
+ user_id=42, command="model_reset", args=[], success=True,
+ )
+
+
+
+async def test_agentic_model_rejects_long_name(agentic_settings, deps):
+ """/model with overly long name is rejected."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ update = MagicMock()
+ update.message.text = "/model " + "a" * 101
+ update.message.reply_text = AsyncMock()
+
+ context = MagicMock()
+ context.user_data = {}
+
+ await orchestrator.agentic_model(update, context)
+
+ assert "model_override" not in context.user_data
+ text = update.message.reply_text.call_args.args[0]
+ assert "Invalid" in text
+
+
+async def test_model_override_passed_to_run_command(agentic_settings, deps):
+ """User model override is passed through to claude_integration.run_command."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ mock_response = MagicMock()
+ mock_response.session_id = "session-abc"
+ mock_response.content = "Hello!"
+ mock_response.tools_used = []
+
+ claude_integration = AsyncMock()
+ claude_integration.run_command = AsyncMock(return_value=mock_response)
+
+ update = MagicMock()
+ update.effective_user.id = 123
+ update.message.text = "Help me"
+ update.message.message_id = 1
+ update.message.chat.send_action = AsyncMock()
+ update.message.reply_text = AsyncMock()
+
+ progress_msg = AsyncMock()
+ progress_msg.delete = AsyncMock()
+ update.message.reply_text.return_value = progress_msg
+
+ context = MagicMock()
+ context.user_data = {"model_override": "haiku"}
+ context.bot_data = {
+ "settings": agentic_settings,
+ "claude_integration": claude_integration,
+ "storage": None,
+ "rate_limiter": None,
+ "audit_logger": None,
+ }
+
+ await orchestrator.agentic_text(update, context)
+
+ claude_integration.run_command.assert_called_once()
+ call_kwargs = claude_integration.run_command.call_args.kwargs
+ assert call_kwargs["model_override"] == "haiku"
+
+
+async def test_model_override_none_when_not_set(agentic_settings, deps):
+ """model_override is None when user hasn't set one."""
+ orchestrator = MessageOrchestrator(agentic_settings, deps)
+
+ mock_response = MagicMock()
+ mock_response.session_id = "session-abc"
+ mock_response.content = "Hello!"
+ mock_response.tools_used = []
+
+ claude_integration = AsyncMock()
+ claude_integration.run_command = AsyncMock(return_value=mock_response)
+
+ update = MagicMock()
+ update.effective_user.id = 123
+ update.message.text = "Help me"
+ update.message.message_id = 1
+ update.message.chat.send_action = AsyncMock()
+ update.message.reply_text = AsyncMock()
+
+ progress_msg = AsyncMock()
+ progress_msg.delete = AsyncMock()
+ update.message.reply_text.return_value = progress_msg
+
+ context = MagicMock()
+ context.user_data = {}
+ context.bot_data = {
+ "settings": agentic_settings,
+ "claude_integration": claude_integration,
+ "storage": None,
+ "rate_limiter": None,
+ "audit_logger": None,
+ }
+
+ await orchestrator.agentic_text(update, context)
+
+ call_kwargs = claude_integration.run_command.call_args.kwargs
+ assert call_kwargs["model_override"] is None