Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/bot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ async def initialize(self) -> None:
builder.defaults(Defaults(do_quote=self.settings.reply_quote))
builder.rate_limiter(AIORateLimiter(max_retries=1))

from .update_processor import StopAwareUpdateProcessor

builder.concurrent_updates(StopAwareUpdateProcessor())

# Configure connection settings
builder.connect_timeout(30)
builder.read_timeout(30)
Expand Down
98 changes: 92 additions & 6 deletions src/bot/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import asyncio
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

Expand Down Expand Up @@ -109,12 +110,23 @@ def _tool_icon(name: str) -> str:
return _TOOL_ICONS.get(name, "\U0001f527")


@dataclass
class ActiveRequest:
"""Tracks an in-flight Claude request so it can be interrupted."""

user_id: int
interrupt_event: asyncio.Event = field(default_factory=asyncio.Event)
interrupted: bool = False
progress_msg: Any = None # telegram Message object


class MessageOrchestrator:
"""Routes messages based on mode. Single entry point for all Telegram updates."""

def __init__(self, settings: Settings, deps: Dict[str, Any]):
self.settings = settings
self.deps = deps
self._active_requests: Dict[int, ActiveRequest] = {}

def _inject_deps(self, handler: Callable) -> Callable: # type: ignore[type-arg]
"""Wrap handler to inject dependencies into context.bot_data."""
Expand Down Expand Up @@ -344,6 +356,14 @@ def _register_agentic_handlers(self, app: Application) -> None:
group=10,
)

# Stop button callback (must be before cd: handler)
app.add_handler(
CallbackQueryHandler(
self._inject_deps(self._handle_stop_callback),
pattern=r"^stop:",
)
)

# Only cd: callbacks (for project selection), scoped by pattern
app.add_handler(
CallbackQueryHandler(
Expand Down Expand Up @@ -675,9 +695,11 @@ def _make_stream_callback(
progress_msg: Any,
tool_log: List[Dict[str, Any]],
start_time: float,
reply_markup: Optional[InlineKeyboardMarkup] = None,
mcp_images: Optional[List[ImageAttachment]] = None,
approved_directory: Optional[Path] = None,
draft_streamer: Optional[DraftStreamer] = None,
interrupt_event: Optional[asyncio.Event] = None,
) -> Optional[Callable[[StreamUpdate], Any]]:
"""Create a stream callback for verbose progress updates.

Expand All @@ -701,6 +723,10 @@ def _make_stream_callback(
last_edit_time = [0.0] # mutable container for closure

async def _on_stream(update_obj: StreamUpdate) -> None:
# Stop all streaming activity after interrupt
if interrupt_event is not None and interrupt_event.is_set():
return

# Intercept send_image_to_user MCP tool calls.
# The SDK namespaces MCP tools as "mcp__<server>__<tool>",
# so match both the bare name and the namespaced variant.
Expand Down Expand Up @@ -765,7 +791,9 @@ async def _on_stream(update_obj: StreamUpdate) -> None:
tool_log, verbose_level, start_time
)
try:
await progress_msg.edit_text(new_text)
await progress_msg.edit_text(
new_text, reply_markup=reply_markup
)
except Exception:
pass

Expand Down Expand Up @@ -885,12 +913,30 @@ async def agentic_text(
await chat.send_action("typing")

verbose_level = self._get_verbose_level(context)
progress_msg = await update.message.reply_text("Working...")

# Create Stop button and interrupt event
interrupt_event = asyncio.Event()
stop_kb = InlineKeyboardMarkup(
[[InlineKeyboardButton("Stop", callback_data=f"stop:{user_id}")]]
)
progress_msg = await update.message.reply_text(
"Working...", reply_markup=stop_kb
)

# Register active request for stop callback
active_request = ActiveRequest(
user_id=user_id,
interrupt_event=interrupt_event,
progress_msg=progress_msg,
)
self._active_requests[user_id] = active_request

claude_integration = context.bot_data.get("claude_integration")
if not claude_integration:
self._active_requests.pop(user_id, None)
await progress_msg.edit_text(
"Claude integration not available. Check configuration."
"Claude integration not available. Check configuration.",
reply_markup=None,
)
return

Expand Down Expand Up @@ -924,9 +970,11 @@ async def agentic_text(
progress_msg,
tool_log,
start_time,
reply_markup=stop_kb,
mcp_images=mcp_images,
approved_directory=self.settings.approved_directory,
draft_streamer=draft_streamer,
interrupt_event=interrupt_event,
)

# Independent typing heartbeat — stays alive even with no stream events
Expand All @@ -941,6 +989,7 @@ async def agentic_text(
session_id=session_id,
on_stream=on_stream,
force_new=force_new,
interrupt_event=interrupt_event,
)

# New session created successfully — clear the one-shot flag
Expand Down Expand Up @@ -974,9 +1023,14 @@ async def agentic_text(
from .utils.formatting import ResponseFormatter

formatter = ResponseFormatter(self.settings)
formatted_messages = formatter.format_claude_response(
claude_response.content
)

response_content = claude_response.content
if claude_response.interrupted:
response_content = (
response_content or ""
) + "\n\n_(Interrupted by user)_"

formatted_messages = formatter.format_claude_response(response_content)

except Exception as e:
success = False
Expand All @@ -989,6 +1043,7 @@ async def agentic_text(
]
finally:
heartbeat.cancel()
self._active_requests.pop(user_id, None)
if draft_streamer:
try:
await draft_streamer.flush()
Expand Down Expand Up @@ -1555,6 +1610,37 @@ async def agentic_repo(
reply_markup=reply_markup,
)

async def _handle_stop_callback(
self, update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
"""Handle stop: callbacks — interrupt a running Claude request."""
query = update.callback_query
target_user_id = int(query.data.split(":", 1)[1])

# Only the requesting user can stop their own request
if query.from_user.id != target_user_id:
await query.answer(
"Only the requesting user can stop this.", show_alert=True
)
return

active = self._active_requests.get(target_user_id)
if not active:
await query.answer("Already completed.", show_alert=False)
return
if active.interrupted:
await query.answer("Already stopping...", show_alert=False)
return

active.interrupt_event.set()
active.interrupted = True
await query.answer("Stopping...", show_alert=False)

try:
await active.progress_msg.edit_text("Stopping...", reply_markup=None)
except Exception:
pass

async def _agentic_callback(
self, update: Update, context: ContextTypes.DEFAULT_TYPE
) -> None:
Expand Down
70 changes: 70 additions & 0 deletions src/bot/update_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Selective-concurrency update processor for PTB.

Regular updates (messages, commands) process sequentially -- one at a time.
Priority callbacks (stop:*) bypass the queue and run immediately so they can
interrupt the currently-running handler.
"""

import asyncio
from typing import Any, Awaitable

from telegram import Update
from telegram.ext._baseupdateprocessor import BaseUpdateProcessor


class StopAwareUpdateProcessor(BaseUpdateProcessor):
"""Update processor that lets priority callbacks bypass sequential processing.

PTB calls ``process_update(update, coroutine)`` for every incoming update.
The base class holds a semaphore (max 256) then calls our
``do_process_update()``.

For priority callbacks (``stop:*``): we just ``await coroutine`` -- runs
immediately.
For everything else: we acquire ``_sequential_lock`` first -- only one
runs at a time.

A stop callback arrives while a text handler holds the lock -> stop
callback runs concurrently -> fires the ``asyncio.Event`` -> the watcher
task inside ``execute_command()`` calls ``client.interrupt()`` -> Claude
stops -> ``run_command()`` returns -> handler finishes -> lock released.
"""

_PRIORITY_PREFIXES = ("stop:",)

def __init__(self) -> None:
# High limit so priority callbacks are never blocked by semaphore
super().__init__(max_concurrent_updates=256)
self._sequential_lock = asyncio.Lock()

@classmethod
def _is_priority_callback(cls, update: object) -> bool:
"""Return True if the update is a priority callback query."""
if not isinstance(update, Update):
return False
cb = update.callback_query
return (
cb is not None
and cb.data is not None
and cb.data.startswith(cls._PRIORITY_PREFIXES)
)

async def do_process_update(
self,
update: object,
coroutine: Awaitable[Any],
) -> None:
"""Process an update, applying sequential lock for non-priority updates."""
if self._is_priority_callback(update):
# Run immediately -- no sequential lock
await coroutine
else:
# One at a time for everything else
async with self._sequential_lock:
await coroutine

async def initialize(self) -> None:
"""Initialize the processor (no-op)."""

async def shutdown(self) -> None:
"""Shutdown the processor (no-op)."""
6 changes: 6 additions & 0 deletions src/claude/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Provides simple interface for bot handlers.
"""

import asyncio
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

Expand Down Expand Up @@ -37,6 +38,7 @@ async def run_command(
session_id: Optional[str] = None,
on_stream: Optional[Callable[[StreamUpdate], None]] = None,
force_new: bool = False,
interrupt_event: Optional["asyncio.Event"] = None,
) -> ClaudeResponse:
"""Run Claude Code command with full integration."""
logger.info(
Expand Down Expand Up @@ -85,6 +87,7 @@ async def run_command(
session_id=claude_session_id,
continue_session=should_continue,
stream_callback=on_stream,
interrupt_event=interrupt_event,
)
except Exception as resume_error:
# If resume failed (e.g., session expired/missing on Claude's side),
Expand All @@ -109,6 +112,7 @@ async def run_command(
session_id=None,
continue_session=False,
stream_callback=on_stream,
interrupt_event=interrupt_event,
)
else:
raise
Expand Down Expand Up @@ -152,6 +156,7 @@ async def _execute(
session_id: Optional[str] = None,
continue_session: bool = False,
stream_callback: Optional[Callable] = None,
interrupt_event: Optional[asyncio.Event] = None,
) -> ClaudeResponse:
"""Execute command via SDK."""
return await self.sdk_manager.execute_command(
Expand All @@ -160,6 +165,7 @@ async def _execute(
session_id=session_id,
continue_session=continue_session,
stream_callback=stream_callback,
interrupt_event=interrupt_event,
)

async def _find_resumable_session(
Expand Down
Loading