diff --git a/.env.example b/.env.example index e3a4b66..f9c1001 100644 --- a/.env.example +++ b/.env.example @@ -17,3 +17,7 @@ RAG_WARMUP_ON_LAUNCH=1 # YouTube API Configuration (Optional - for robust metadata) # Get your key from: https://console.cloud.google.com/apis/credentials YOUTUBE_API_KEY= + +# Telegram Bot Configuration +# Get your token from @BotFather on Telegram +TELEGRAM_BOT_TOKEN= diff --git a/deploy/entrypoint.sh b/deploy/entrypoint.sh index 5d3a047..eb7c75a 100755 --- a/deploy/entrypoint.sh +++ b/deploy/entrypoint.sh @@ -1,4 +1,13 @@ #!/usr/bin/env bash set -euo pipefail +# Start Telegram bot in background (if token is configured) +if [ -n "${TELEGRAM_BOT_TOKEN:-}" ]; then + echo "[entrypoint] Starting Telegram bot (polling mode)..." + python run_telegram_bot.py & + TELEGRAM_PID=$! + echo "[entrypoint] Telegram bot started (PID: $TELEGRAM_PID)" +fi + +# Start Streamlit frontend (foreground — container stays alive) exec python deploy/start_frontend.py diff --git a/deploy/main.tf b/deploy/main.tf index 0ffe244..7dae53a 100644 --- a/deploy/main.tf +++ b/deploy/main.tf @@ -20,7 +20,8 @@ locals { streamlit_container_secrets = concat( local.streamlit_openai_value != "" && local.streamlit_openai_is_arn ? [{ name = "OPENAI_API_KEY", valueFrom = local.streamlit_openai_value }] : [], - var.youtube_api_key_secret_arn != "" ? [{ name = "YOUTUBE_API_KEY", valueFrom = var.youtube_api_key_secret_arn }] : [] + var.youtube_api_key_secret_arn != "" ? [{ name = "YOUTUBE_API_KEY", valueFrom = var.youtube_api_key_secret_arn }] : [], + var.telegram_bot_token_secret_arn != "" ? [{ name = "TELEGRAM_BOT_TOKEN", valueFrom = var.telegram_bot_token_secret_arn }] : [] ) worker_container_secrets = concat( @@ -37,7 +38,8 @@ locals { exec_secret_arns = concat( local.openai_secret_arns, - var.youtube_api_key_secret_arn != "" ? [var.youtube_api_key_secret_arn] : [] + var.youtube_api_key_secret_arn != "" ? [var.youtube_api_key_secret_arn] : [], + var.telegram_bot_token_secret_arn != "" ? [var.telegram_bot_token_secret_arn] : [] ) } diff --git a/deploy/terraform.tfvars.example b/deploy/terraform.tfvars.example index 831088b..8c81650 100644 --- a/deploy/terraform.tfvars.example +++ b/deploy/terraform.tfvars.example @@ -49,6 +49,9 @@ streamlit_desired_count = 2 streamlit_min_capacity = 2 streamlit_max_capacity = 10 +# Telegram Bot (optional — leave empty to disable) +telegram_bot_token_secret_arn = "" + # Enable aws ecs execute-command support enable_ecs_exec = true diff --git a/deploy/variables.tf b/deploy/variables.tf index c9cbcba..0068ec5 100644 --- a/deploy/variables.tf +++ b/deploy/variables.tf @@ -222,3 +222,9 @@ variable "ask_assistant_enabled" { type = string default = "0" } + +variable "telegram_bot_token_secret_arn" { + description = "Secrets Manager ARN for TELEGRAM_BOT_TOKEN. Leave empty to disable the Telegram bot." + type = string + default = "" +} diff --git a/requirements.txt b/requirements.txt index 14aebce..7e7f283 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,3 +52,9 @@ mypy>=1.10.0 # Testing pytest>=8.0.0 pytest-cov>=5.0.0 + +# Telegram Bot +python-telegram-bot>=21.0 + +# Testing (async) +pytest-asyncio>=0.23.0 diff --git a/run_telegram_bot.py b/run_telegram_bot.py new file mode 100644 index 0000000..b6af6c4 --- /dev/null +++ b/run_telegram_bot.py @@ -0,0 +1,40 @@ +"""Entry point for the NEET PYQ Telegram Bot.""" + +import logging +import os +from importlib import import_module + + +def _load_dotenv() -> None: + try: + dotenv = import_module("dotenv") + except ModuleNotFoundError: + return + + load_fn = getattr(dotenv, "load_dotenv", None) + if callable(load_fn): + load_fn() + + +_load_dotenv() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(name)s] %(levelname)s: %(message)s", +) +logger = logging.getLogger(__name__) + + +def main() -> None: + token = os.getenv("TELEGRAM_BOT_TOKEN") + if not token: + raise ValueError("TELEGRAM_BOT_TOKEN environment variable is required") + logger.info("Initializing NEET PYQ Telegram Bot...") + from src.telegram_bot.bot import create_application, run_polling + + app = create_application(token) + run_polling(app) + + +if __name__ == "__main__": + main() diff --git a/src/telegram_bot/__init__.py b/src/telegram_bot/__init__.py new file mode 100644 index 0000000..7ac60c0 --- /dev/null +++ b/src/telegram_bot/__init__.py @@ -0,0 +1,5 @@ +"""Telegram bot interface for the NEET PYQ Assistant.""" + +from src.telegram_bot.bot import create_application + +__all__ = ["create_application"] diff --git a/src/telegram_bot/bot.py b/src/telegram_bot/bot.py new file mode 100644 index 0000000..6beb958 --- /dev/null +++ b/src/telegram_bot/bot.py @@ -0,0 +1,225 @@ +import asyncio +import logging +from typing import Any, Mapping + +from telegram import Update +from telegram.constants import ChatAction, ParseMode +from telegram.ext import ( + Application, + CommandHandler, + ContextTypes, + MessageHandler, + filters, +) + +from src.telegram_bot.formatting import format_response +from src.telegram_bot.history import TelegramChatHistory + +logger = logging.getLogger(__name__) + +_ERR_GENERIC = "Sorry, I couldn't process your question right now. Please try again." +_ERR_IMAGE_DOWNLOAD = "I couldn't download your image. Please try sending it again." +_ERR_IMAGE_EXTRACT = "I couldn't read the question from your image. Try a clearer photo or type the question." + + +def _is_error_response(result: Mapping[str, object]) -> bool: + if "error" in result: + return True + answer = result.get("answer", "") + return isinstance(answer, str) and answer.startswith("Error generating answer:") + + +def create_application(token: str, rag: Any | None = None): + app = Application.builder().token(token).build() + if rag is None: + from src.utils.rag_singleton import get_rag_system + + rag = get_rag_system() + + app.bot_data["rag"] = rag + app.bot_data["history"] = TelegramChatHistory() + app.add_handler(CommandHandler("start", start_command)) + app.add_handler(CommandHandler("help", help_command)) + app.add_handler(MessageHandler(filters.PHOTO, handle_photo)) + app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message)) + app.add_error_handler(error_handler) + return app + + +def run_polling(app) -> None: + app.run_polling(drop_pending_updates=True) + + +async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + _ = context + message = update.message + if message is None: + return + welcome = ( + "Welcome to the NEET PYQ Assistant!\n\n" + "Send a text question or a photo question and I will help with answers and sources.\n\n" + "Type /help for usage instructions." + ) + _ = await message.reply_text(welcome, parse_mode=ParseMode.HTML) + + +async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + _ = context + message = update.message + if message is None: + return + help_text = ( + "How to use:\n" + "- Send a text question\n" + "- Or send a photo of a question with an optional caption" + ) + _ = await message.reply_text(help_text, parse_mode=ParseMode.HTML) + + +async def _send_typing_periodically( + chat_id: int, context: ContextTypes.DEFAULT_TYPE +) -> None: + try: + while True: + await context.bot.send_chat_action( + chat_id=chat_id, action=ChatAction.TYPING + ) + await asyncio.sleep(5) + except asyncio.CancelledError: + return + + +async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + message = update.message + user = update.effective_user + chat = update.effective_chat + if message is None or user is None or chat is None or message.text is None: + return + + question = message.text + user_id = user.id + chat_id = chat.id + rag = context.bot_data["rag"] + history = context.bot_data["history"] + + await context.bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING) + typing_task = asyncio.create_task(_send_typing_periodically(chat_id, context)) + try: + chat_history = history.load_history(user_id) + result = rag.query_with_history( + question, + chat_history=chat_history, + session_id=str(user_id), + user_id=str(user_id), + ) + + if _is_error_response(result): + _ = await message.reply_text(_ERR_GENERIC) + return + + raw_answer = str(result.get("answer", "")) + parts = format_response( + answer_text=raw_answer, + youtube_sources=result.get("sources", []), + question_sources=result.get("question_sources", []), + ) + for part in parts: + _ = await message.reply_text(part, parse_mode=ParseMode.HTML) + + history.save_turn( + user_id=user_id, + user_message=question, + assistant_message=raw_answer, + ) + except Exception: + logger.exception("Failed to handle text message") + _ = await message.reply_text(_ERR_GENERIC) + finally: + typing_task.cancel() + + +async def handle_photo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + message = update.message + user = update.effective_user + chat = update.effective_chat + if message is None or user is None or chat is None: + return + + user_id = user.id + chat_id = chat.id + caption = message.caption or "" + rag = context.bot_data["rag"] + history = context.bot_data["history"] + + await context.bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING) + typing_task = asyncio.create_task(_send_typing_periodically(chat_id, context)) + try: + try: + photo = message.photo[-1] + file = await context.bot.get_file(photo.file_id) + image_bytes = bytes(await file.download_as_bytearray()) + except Exception: + logger.exception("Failed downloading photo") + _ = await message.reply_text(_ERR_IMAGE_DOWNLOAD) + return + + try: + extracted = rag.llm_manager.extract_image_context( + image_bytes=image_bytes, + filename="telegram_photo.jpg", + user_hint=caption, + session_id=str(user_id), + user_id=str(user_id), + ) + except Exception: + logger.exception("Failed extracting photo context") + _ = await message.reply_text(_ERR_IMAGE_EXTRACT) + return + + if caption: + question = f"{caption}\n\nImage context:\n{extracted}" + else: + question = str(extracted) + + chat_history = history.load_history(user_id) + result = rag.query_with_history( + question, + chat_history=chat_history, + session_id=str(user_id), + user_id=str(user_id), + ) + + if _is_error_response(result): + _ = await message.reply_text(_ERR_GENERIC) + return + + raw_answer = str(result.get("answer", "")) + parts = format_response( + answer_text=raw_answer, + youtube_sources=result.get("sources", []), + question_sources=result.get("question_sources", []), + ) + for part in parts: + _ = await message.reply_text(part, parse_mode=ParseMode.HTML) + + history.save_turn( + user_id=user_id, + user_message=question, + assistant_message=raw_answer, + ) + except Exception: + logger.exception("Failed to handle photo message") + _ = await message.reply_text(_ERR_GENERIC) + finally: + typing_task.cancel() + + +async def error_handler(update: object, context: ContextTypes.DEFAULT_TYPE) -> None: + logger.error("Unhandled exception", exc_info=context.error) + if isinstance(update, Update) and update.effective_chat: + try: + _ = await context.bot.send_message( + chat_id=update.effective_chat.id, text=_ERR_GENERIC + ) + except Exception: + logger.exception("Failed to send error message") diff --git a/src/telegram_bot/formatting.py b/src/telegram_bot/formatting.py new file mode 100644 index 0000000..fc5f0db --- /dev/null +++ b/src/telegram_bot/formatting.py @@ -0,0 +1,176 @@ +# pyright: reportPrivateUsage=false + +import re +from html import escape as html_escape +from typing import TypedDict + +from src.utils.answer_formatting import ( + _SUBSCRIPT_MAP, + _SUPERSCRIPT_MAP, +) + + +class YouTubeSource(TypedDict, total=False): + title: str + url: str + timestamp: str + + +class QuestionSource(TypedDict, total=False): + id: int | str + question: str + title: str + + +_INLINE_LATEX_PATTERN = re.compile(r"\$(.+?)\$", re.DOTALL) +_BLOCK_LATEX_PATTERN = re.compile(r"\$\$(.+?)\$\$", re.DOTALL) +_PAREN_LATEX_PATTERN = re.compile(r"\\\((.+?)\\\)", re.DOTALL) +_BRACKET_LATEX_PATTERN = re.compile(r"\\\[(.+?)\\\]", re.DOTALL) +_SUP_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) +_SUB_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) + + +def _render_with_map(text: str, mapping: dict[str, str]) -> str: + cleaned = text.strip() + if not cleaned: + return "" + if all(char in mapping for char in cleaned): + return "".join(mapping[char] for char in cleaned) + return cleaned + + +def _strip_latex_delimiters(text: str) -> str: + stripped = _BLOCK_LATEX_PATTERN.sub(r"\1", text) + stripped = _PAREN_LATEX_PATTERN.sub(r"\1", stripped) + stripped = _BRACKET_LATEX_PATTERN.sub(r"\1", stripped) + stripped = _INLINE_LATEX_PATTERN.sub(r"\1", stripped) + return stripped + + +def _convert_sup_sub(text: str) -> str: + converted = _SUP_PATTERN.sub( + lambda match: _render_with_map(match.group(1), _SUPERSCRIPT_MAP), text + ) + converted = _SUB_PATTERN.sub( + lambda match: _render_with_map(match.group(1), _SUBSCRIPT_MAP), converted + ) + return converted + + +def format_answer_text(answer_text: str) -> str: + converted = _convert_sup_sub(answer_text) + latex_stripped = _strip_latex_delimiters(converted) + return html_escape(latex_stripped) + + +def _parse_timestamp_to_seconds(timestamp: str) -> int | None: + parts = [part.strip() for part in timestamp.split(":")] + if not parts or any(not part.isdigit() for part in parts): + return None + + values = [int(part) for part in parts] + if len(values) == 3: + hours, minutes, seconds = values + return (hours * 3600) + (minutes * 60) + seconds + if len(values) == 2: + minutes, seconds = values + return (minutes * 60) + seconds + if len(values) == 1: + return values[0] + return None + + +def format_youtube_sources(sources: list[YouTubeSource]) -> str: + if not sources: + return "" + + lines: list[str] = [] + for index, source in enumerate(sources, start=1): + raw_title = str(source.get("title") or f"YouTube Source {index}") + raw_url = str(source.get("url") or "") + timestamp = source.get("timestamp") + + if isinstance(timestamp, str) and timestamp.strip(): + seconds = _parse_timestamp_to_seconds(timestamp.strip()) + if seconds is not None: + separator = "&" if "?" in raw_url else "?" + raw_url = f"{raw_url}{separator}t={seconds}" + raw_title = f"{raw_title} ({timestamp.strip()})" + + escaped_title = html_escape(raw_title) + escaped_url = html_escape(raw_url, quote=True) + lines.append(f'{index}. {escaped_title}') + + return "\n".join(lines) + + +def format_question_sources(sources: list[QuestionSource]) -> str: + if not sources: + return "" + + lines: list[str] = [] + for index, source in enumerate(sources, start=1): + question_id = str(source.get("id") or "") + raw_title = str( + source.get("question") or source.get("title") or f"Question {index}" + ) + url = f"https://neetprep.com/epubQuestion/{question_id}" + lines.append( + f'{index}. {html_escape(raw_title)}' + ) + + return "\n".join(lines) + + +def format_response( + answer_text: str, + youtube_sources: list[YouTubeSource] | None = None, + question_sources: list[QuestionSource] | None = None, + max_length: int = 4096, +) -> list[str]: + full_text = format_answer_text(answer_text) + + youtube_block = format_youtube_sources(youtube_sources or []) + if youtube_block: + full_text += f"\n\nYouTube Sources\n{youtube_block}" + + question_block = format_question_sources(question_sources or []) + if question_block: + full_text += f"\n\nRelated Questions\n{question_block}" + + return split_message(full_text, max_length=max_length) + + +def split_message(text: str, max_length: int = 4096) -> list[str]: + if len(text) <= max_length: + return [text] + + chunks: list[str] = [] + remaining = text + + while remaining: + if len(remaining) <= max_length: + chunks.append(remaining) + break + + window = remaining[:max_length] + split_at = window.rfind("\n") + cut = split_at + 1 if split_at != -1 else max_length + + candidate = remaining[:cut] + last_lt = candidate.rfind("<") + last_gt = candidate.rfind(">") + if last_lt > last_gt: + safe_cut = last_lt + if safe_cut > 0: + cut = safe_cut + candidate = remaining[:cut] + + if not candidate: + candidate = remaining[:max_length] + cut = len(candidate) + + chunks.append(candidate) + remaining = remaining[cut:] + + return chunks diff --git a/src/telegram_bot/history.py b/src/telegram_bot/history.py new file mode 100644 index 0000000..5296475 --- /dev/null +++ b/src/telegram_bot/history.py @@ -0,0 +1,95 @@ +# pyright: reportMissingImports=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnannotatedClassAttribute=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportAny=false + +import json +import logging +import os +from urllib.parse import urlparse + +import redis + + +logger = logging.getLogger(__name__) + + +class TelegramChatHistory: + def __init__(self, redis_url=None, max_turns=4, ttl_seconds=604800): + self._redis = None + self._max_turns = max_turns + self._ttl_seconds = ttl_seconds + + effective_redis_url = redis_url or os.getenv( + "REDIS_URL", "redis://localhost:6379/0" + ) + + def _build_client(url: str): + return redis.from_url( + url, + socket_connect_timeout=2, + socket_timeout=2, + retry_on_timeout=False, + ) + + try: + client = _build_client(effective_redis_url) + client.ping() + self._redis = client + return + except Exception: + parsed = urlparse(effective_redis_url) + if parsed.scheme == "redis": + tls_url = effective_redis_url.replace("redis://", "rediss://", 1) + try: + client = _build_client(tls_url) + client.ping() + self._redis = client + return + except Exception: + logger.exception("Telegram history: Redis TLS retry failed") + + logger.exception("Telegram history: Redis connection failed") + + def _key(self, user_id: int) -> str: + return f"telegram_chat:{user_id}" + + def load_history(self, user_id: int) -> list[tuple[str, str]]: + if not self._redis: + return [] + + try: + raw_payload = self._redis.get(self._key(user_id)) + if not raw_payload: + return [] + + if isinstance(raw_payload, bytes): + raw_payload = raw_payload.decode("utf-8") + + decoded = json.loads(raw_payload) + if not isinstance(decoded, list): + return [] + + history: list[tuple[str, str]] = [] + for item in decoded: + if isinstance(item, (list, tuple)) and len(item) == 2: + history.append((str(item[0]), str(item[1]))) + + return history + except Exception: + logger.exception( + "Telegram history: failed to load history for user_id=%s", user_id + ) + return [] + + def save_turn(self, user_id, user_message, assistant_message): + if not self._redis: + return + + try: + current = self.load_history(user_id) + current.append((str(user_message), str(assistant_message))) + trimmed = current[-self._max_turns :] + payload = json.dumps([[u, a] for u, a in trimmed], ensure_ascii=False) + self._redis.setex(self._key(user_id), self._ttl_seconds, payload) + except Exception: + logger.exception( + "Telegram history: failed to save history for user_id=%s", user_id + ) diff --git a/tests/test_telegram_bot.py b/tests/test_telegram_bot.py new file mode 100644 index 0000000..0ecd0af --- /dev/null +++ b/tests/test_telegram_bot.py @@ -0,0 +1,316 @@ +import os +import sys +from unittest.mock import AsyncMock, MagicMock + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + +def _make_mock_rag( + answer="The answer is 42", + sources=None, + question_sources=None, + error=None, + extract_return="QUESTION: test", +): + rag = MagicMock() + result = { + "answer": answer, + "sources": sources or [], + "question_sources": question_sources or [], + } + if error is not None: + result["error"] = error + + rag.query_with_history.return_value = result + rag.llm_manager.extract_image_context.return_value = extract_return + return rag + + +def _make_mock_history(): + history = MagicMock() + history.load_history.return_value = [] + return history + + +def _make_text_update(text, user_id=123): + update = MagicMock() + update.effective_user.id = user_id + update.effective_chat.id = user_id + update.message.text = text + update.message.caption = None + update.message.photo = None + update.message.reply_text = AsyncMock() + return update + + +def _make_photo_update(caption=None, user_id=123): + update = MagicMock() + update.effective_user.id = user_id + update.effective_chat.id = user_id + update.message.text = None + update.message.caption = caption + + photo_small = MagicMock() + photo_small.file_id = "small_file_id" + photo_large = MagicMock() + photo_large.file_id = "large_file_id" + update.message.photo = [photo_small, photo_large] + + update.message.reply_text = AsyncMock() + return update + + +def _make_context(rag=None, history=None): + context = MagicMock() + context.bot_data = { + "rag": rag or _make_mock_rag(), + "history": history or _make_mock_history(), + } + context.bot.send_chat_action = AsyncMock() + + mock_file = MagicMock() + mock_file.download_as_bytearray = AsyncMock(return_value=bytearray(b"fake_image")) + context.bot.get_file = AsyncMock(return_value=mock_file) + return context + + +@pytest.mark.asyncio +async def test_start_command_sends_welcome_message(): + from src.telegram_bot.bot import start_command + + update = _make_text_update("/start") + context = _make_context() + + await start_command(update, context) + + update.message.reply_text.assert_called_once() + message = update.message.reply_text.call_args.args[0] + assert "NEET" in message or "PYQ" in message + + +@pytest.mark.asyncio +async def test_help_command_sends_usage_message(): + from src.telegram_bot.bot import help_command + + update = _make_text_update("/help") + context = _make_context() + + await help_command(update, context) + + update.message.reply_text.assert_called_once() + message = update.message.reply_text.call_args.args[0].lower() + assert "question" in message or "photo" in message + + +@pytest.mark.asyncio +async def test_handle_message_sends_typing_action(): + from src.telegram_bot.bot import handle_message + + update = _make_text_update("What is osmosis?") + context = _make_context() + + await handle_message(update, context) + + context.bot.send_chat_action.assert_called() + + +@pytest.mark.asyncio +async def test_handle_message_calls_rag_with_question_and_session_id(): + from src.telegram_bot.bot import handle_message + + rag = _make_mock_rag() + update = _make_text_update("What is osmosis?", user_id=456) + context = _make_context(rag=rag) + + await handle_message(update, context) + + rag.query_with_history.assert_called_once() + call = rag.query_with_history.call_args + assert call.args[0] == "What is osmosis?" + assert call.kwargs["session_id"] == "456" + + +@pytest.mark.asyncio +async def test_handle_message_reply_uses_html_parse_mode(): + from src.telegram_bot.bot import handle_message + + update = _make_text_update("Explain diffusion") + context = _make_context() + + await handle_message(update, context) + + kwargs = update.message.reply_text.call_args.kwargs + assert kwargs.get("parse_mode") == "HTML" + + +@pytest.mark.asyncio +async def test_handle_message_saves_turn_after_reply(): + from src.telegram_bot.bot import handle_message + + history = _make_mock_history() + update = _make_text_update("What is osmosis?", user_id=789) + context = _make_context(history=history) + + await handle_message(update, context) + + history.save_turn.assert_called_once() + args = history.save_turn.call_args.kwargs + assert args["user_id"] == 789 + assert args["user_message"] == "What is osmosis?" + + +@pytest.mark.asyncio +async def test_handle_message_exception_returns_friendly_message_no_stack_trace(): + from src.telegram_bot.bot import handle_message + + rag = _make_mock_rag() + rag.query_with_history.side_effect = Exception("DB connection failed") + update = _make_text_update("Test question") + context = _make_context(rag=rag) + + await handle_message(update, context) + + message = update.message.reply_text.call_args.args[0] + assert "sorry" in message.lower() or "couldn't" in message.lower() + assert "DB connection failed" not in message + + +@pytest.mark.asyncio +async def test_handle_message_error_as_answer_returns_friendly_message_and_not_saved(): + from src.telegram_bot.bot import handle_message + + rag = _make_mock_rag(answer="Error generating answer: timeout") + history = _make_mock_history() + update = _make_text_update("Test question") + context = _make_context(rag=rag, history=history) + + await handle_message(update, context) + + message = update.message.reply_text.call_args.args[0] + assert "Error generating answer" not in message + history.save_turn.assert_not_called() + + +@pytest.mark.asyncio +async def test_handle_message_error_key_returns_friendly_message_and_not_saved(): + from src.telegram_bot.bot import handle_message + + rag = _make_mock_rag( + answer="Knowledge base is empty or unavailable.", error="No vectorstore found" + ) + history = _make_mock_history() + update = _make_text_update("Test question") + context = _make_context(rag=rag, history=history) + + await handle_message(update, context) + + message = update.message.reply_text.call_args.args[0] + assert "No vectorstore found" not in message + history.save_turn.assert_not_called() + + +@pytest.mark.asyncio +async def test_handle_photo_downloads_highest_resolution_photo(): + from src.telegram_bot.bot import handle_photo + + update = _make_photo_update(caption="") + context = _make_context() + + await handle_photo(update, context) + + context.bot.get_file.assert_called_once_with("large_file_id") + + +@pytest.mark.asyncio +async def test_handle_photo_calls_extract_image_context(): + from src.telegram_bot.bot import handle_photo + + rag = _make_mock_rag() + update = _make_photo_update(caption="") + context = _make_context(rag=rag) + + await handle_photo(update, context) + + rag.llm_manager.extract_image_context.assert_called_once() + + +@pytest.mark.asyncio +async def test_handle_photo_caption_included_with_extracted_text(): + from src.telegram_bot.bot import handle_photo + + rag = _make_mock_rag(extract_return="QUESTION: velocity problem") + update = _make_photo_update(caption="Solve this physics problem") + context = _make_context(rag=rag) + + await handle_photo(update, context) + + query = rag.query_with_history.call_args.args[0] + assert "Solve this physics problem" in query + assert "QUESTION: velocity problem" in query + + +@pytest.mark.asyncio +async def test_handle_photo_without_caption_uses_extracted_text_only(): + from src.telegram_bot.bot import handle_photo + + rag = _make_mock_rag(extract_return="QUESTION: biology cell division") + update = _make_photo_update(caption=None) + context = _make_context(rag=rag) + + await handle_photo(update, context) + + query = rag.query_with_history.call_args.args[0] + assert query == "QUESTION: biology cell division" + + +@pytest.mark.asyncio +async def test_handle_photo_download_failure_returns_couldnt_download_message(): + from src.telegram_bot.bot import handle_photo + + update = _make_photo_update(caption="") + context = _make_context() + context.bot.get_file.side_effect = Exception("Network timeout") + + await handle_photo(update, context) + + message = update.message.reply_text.call_args.args[0] + assert "couldn't download" in message.lower() + assert "Network timeout" not in message + + +@pytest.mark.asyncio +async def test_handle_photo_extraction_failure_returns_couldnt_read_message(): + from src.telegram_bot.bot import handle_photo + + rag = _make_mock_rag() + rag.llm_manager.extract_image_context.side_effect = Exception("Vision API down") + update = _make_photo_update(caption="") + context = _make_context(rag=rag) + + await handle_photo(update, context) + + message = update.message.reply_text.call_args.args[0] + assert "couldn't read" in message.lower() + assert "Vision API" not in message + + +@pytest.mark.asyncio +async def test_history_boundary_saves_raw_answer_not_formatted_html(): + from src.telegram_bot.bot import handle_message + + rag = _make_mock_rag(answer="The answer has H2O and 102") + history = _make_mock_history() + update = _make_text_update("What is water?", user_id=42) + context = _make_context(rag=rag, history=history) + + await handle_message(update, context) + + save_call = history.save_turn.call_args.kwargs + saved_answer = save_call["assistant_message"] + assert saved_answer == "The answer has H2O and 102" + + reply_text = update.message.reply_text.call_args.args[0] + assert "" not in reply_text + assert "₂" in reply_text diff --git a/tests/test_telegram_formatting.py b/tests/test_telegram_formatting.py new file mode 100644 index 0000000..388b956 --- /dev/null +++ b/tests/test_telegram_formatting.py @@ -0,0 +1,155 @@ +# pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from src.telegram_bot.formatting import ( + format_answer_text, + format_question_sources, + format_response, + format_youtube_sources, + split_message, +) + + +def test_format_answer_text_plain_passthrough(): + assert format_answer_text("Simple NEET answer") == "Simple NEET answer" + + +def test_format_answer_text_escapes_html_special_chars(): + raw = "Use & keep > and < safely" + assert format_answer_text(raw) == "Use <tag> & keep > and < safely" + + +def test_format_answer_text_strips_inline_latex_delimiters(): + assert format_answer_text("Energy is $x^2$ law") == "Energy is x^2 law" + + +def test_format_answer_text_strips_block_latex_delimiters(): + assert ( + format_answer_text("Ratio $$\\frac{a}{b}$$ done") == "Ratio \\frac{a}{b} done" + ) + + +def test_format_answer_text_strips_parenthesis_latex_delimiters(): + assert format_answer_text(r"Compute \(x+1\) quickly") == "Compute x+1 quickly" + + +def test_format_answer_text_strips_bracket_latex_delimiters(): + assert format_answer_text(r"Matrix form: \[a+b\]") == "Matrix form: a+b" + + +def test_format_answer_text_converts_sup_tag_to_unicode(): + assert format_answer_text("x2") == "x²" + + +def test_format_answer_text_converts_sub_tag_to_unicode(): + assert format_answer_text("H2O") == "H₂O" + + +def test_format_answer_text_preserves_newlines_and_unicode_math_symbols(): + raw = "Line 1\nπ and √ stay" + assert format_answer_text(raw) == raw + + +def test_format_youtube_sources_single_source_with_timestamp_link(): + sources = [ + { + "title": "Kinematics intro", + "url": "https://youtube.com/watch?v=abc123", + "timestamp": "12:34", + } + ] + + rendered = format_youtube_sources(sources) + + assert rendered.startswith( + '1. ' + ) + assert "Kinematics intro" in rendered + assert "(12:34)" in rendered + + +def test_format_youtube_sources_empty_returns_blank_string(): + assert format_youtube_sources([]) == "" + + +def test_format_youtube_sources_multiple_sources_are_numbered(): + sources = [ + {"title": "First", "url": "https://youtu.be/1"}, + {"title": "Second", "url": "https://youtu.be/2", "timestamp": "00:10"}, + ] + + rendered = format_youtube_sources(sources) + + assert "1. " in rendered + assert "2. " in rendered + assert rendered.count("\n") == 1 + + +def test_format_question_sources_single_question_with_neetprep_link(): + rendered = format_question_sources([{"id": 12345, "question": "Find acceleration"}]) + + assert ( + rendered + == '1. Find acceleration' + ) + + +def test_format_question_sources_empty_returns_blank_string(): + assert format_question_sources([]) == "" + + +def test_format_response_with_sources_combines_answer_and_source_sections(): + chunks = format_response( + answer_text="Final answer", + youtube_sources=[ + { + "title": "Work-Energy", + "url": "https://youtube.com/watch?v=xyz", + "timestamp": "00:30", + } + ], + question_sources=[{"id": "777", "question": "PYQ on energy"}], + ) + + combined = "".join(chunks) + + assert "Final answer" in combined + assert "YouTube Sources" in combined + assert "Related Questions" in combined + assert "neetprep.com/epubQuestion/777" in combined + + +def test_format_response_without_sources_returns_answer_only(): + chunks = format_response(answer_text="Only answer") + assert chunks == ["Only answer"] + + +def test_split_message_short_text_not_split(): + assert split_message("short", max_length=4096) == ["short"] + + +def test_split_message_long_text_split_at_max_length(): + text = "a" * 5000 + chunks = split_message(text, max_length=4096) + + assert len(chunks) == 2 + assert all(len(chunk) <= 4096 for chunk in chunks) + assert "".join(chunks) == text + + +def test_split_message_prefers_newline_boundary_and_avoids_mid_html_tag(): + text = ("a" * 4000) + "\n" + ("b" * 80) + "bold" + ("c" * 80) + chunks = split_message(text, max_length=4096) + + assert chunks[0].endswith("\n") + assert "".join(chunks) == text + assert chunks[1].startswith("b" * 80) + + tag_split_text = ("x" * 4094) + "ok" + tag_chunks = split_message(tag_split_text, max_length=4096) + assert tag_chunks[0].endswith("x") + assert tag_chunks[1].startswith("") diff --git a/tests/test_telegram_history.py b/tests/test_telegram_history.py new file mode 100644 index 0000000..3ebef36 --- /dev/null +++ b/tests/test_telegram_history.py @@ -0,0 +1,116 @@ +# pyright: reportMissingImports=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportAny=false + +import json +import os +import sys +from unittest.mock import MagicMock + + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + +def _make_history(mock_redis=None): + from src.telegram_bot.history import TelegramChatHistory + + history = TelegramChatHistory.__new__(TelegramChatHistory) + history._redis = mock_redis + history._max_turns = 4 + history._ttl_seconds = 604800 + return history + + +def test_load_returns_empty_for_unknown_user(): + mock_redis = MagicMock() + mock_redis.get.return_value = None + + history = _make_history(mock_redis) + + assert history.load_history(1) == [] + + +def test_save_and_load_single_turn(): + storage = {} + mock_redis = MagicMock() + + def _get(key): + return storage.get(key) + + def _setex(key, _ttl, value): + storage[key] = value + + mock_redis.get.side_effect = _get + mock_redis.setex.side_effect = _setex + + history = _make_history(mock_redis) + history.save_turn(1, "hello", "hi there") + + assert history.load_history(1) == [("hello", "hi there")] + + +def test_history_trimmed_to_max_turns(): + storage = {} + mock_redis = MagicMock() + + def _get(key): + return storage.get(key) + + def _setex(key, _ttl, value): + storage[key] = value + + mock_redis.get.side_effect = _get + mock_redis.setex.side_effect = _setex + + history = _make_history(mock_redis) + history._max_turns = 2 + + for i in range(5): + history.save_turn(1, f"u{i}", f"a{i}") + + assert history.load_history(1) == [("u3", "a3"), ("u4", "a4")] + + +def test_different_users_isolated(): + storage = {} + mock_redis = MagicMock() + + def _get(key): + return storage.get(key) + + def _setex(key, _ttl, value): + storage[key] = value + + mock_redis.get.side_effect = _get + mock_redis.setex.side_effect = _setex + + history = _make_history(mock_redis) + history.save_turn(1, "u1", "a1") + history.save_turn(2, "u2", "a2") + + assert history.load_history(1) == [("u1", "a1")] + assert history.load_history(2) == [("u2", "a2")] + + +def test_ttl_set_on_save(): + mock_redis = MagicMock() + mock_redis.get.return_value = None + history = _make_history(mock_redis) + + history.save_turn(1, "hello", "raw answer") + + mock_redis.setex.assert_called_once() + key, ttl, payload = mock_redis.setex.call_args.args + assert key == "telegram_chat:1" + assert ttl == 604800 + assert json.loads(payload) == [["hello", "raw answer"]] + + +def test_load_graceful_when_redis_unavailable(): + history = _make_history(None) + + assert history.load_history(1) == [] + + +def test_save_graceful_when_redis_unavailable(): + history = _make_history(None) + + history.save_turn(1, "hello", "hi") diff --git a/tests/test_telegram_integration.py b/tests/test_telegram_integration.py new file mode 100644 index 0000000..e975c29 --- /dev/null +++ b/tests/test_telegram_integration.py @@ -0,0 +1,167 @@ +import os +import sys +from unittest.mock import MagicMock + +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from tests.test_telegram_bot import ( + _make_context, + _make_mock_history, + _make_mock_rag, + _make_photo_update, + _make_text_update, +) + + +class TestTextEndToEnd: + @pytest.mark.asyncio + async def test_text_question_returns_formatted_html_with_sources(self): + from src.telegram_bot.bot import handle_message + + rag = _make_mock_rag( + answer="Mitosis creates two identical daughter cells.", + sources=[ + { + "title": "Cell Division Lecture", + "url": "https://www.youtube.com/watch?v=abc123", + "timestamp": "01:30", + } + ], + question_sources=[ + { + "id": 101, + "question": "Which phase includes chromosome alignment?", + } + ], + ) + update = _make_text_update("What is mitosis?") + context = _make_context(rag=rag) + + await handle_message(update, context) + + update.message.reply_text.assert_called_once() + reply = update.message.reply_text.call_args.args[0] + kwargs = update.message.reply_text.call_args.kwargs + + assert kwargs.get("parse_mode") == "HTML" + assert "YouTube Sources" in reply + assert "https://www.youtube.com/watch?v=abc123&t=90" in reply + assert "Related Questions" in reply + assert "https://neetprep.com/epubQuestion/101" in reply + + +class TestPhotoEndToEnd: + @pytest.mark.asyncio + async def test_photo_extracts_queries_and_replies_with_answer(self): + from src.telegram_bot.bot import handle_photo + + rag = _make_mock_rag( + answer="The block accelerates downward due to gravity.", + extract_return="QUESTION: A 2kg block is dropped from rest.", + ) + update = _make_photo_update(caption="Solve this physics problem") + context = _make_context(rag=rag) + + await handle_photo(update, context) + + rag.llm_manager.extract_image_context.assert_called_once() + query = rag.query_with_history.call_args.args[0] + assert "Solve this physics problem" in query + assert "A 2kg block is dropped from rest" in query + + reply = update.message.reply_text.call_args.args[0] + kwargs = update.message.reply_text.call_args.kwargs + assert "accelerates downward" in reply + assert kwargs.get("parse_mode") == "HTML" + + +class TestConversationContinuity: + @pytest.mark.asyncio + async def test_second_message_receives_history_from_first_message(self): + from src.telegram_bot.bot import handle_message + + history_store: list[tuple[str, str]] = [] + history = MagicMock() + + def load_history(_user_id): + return list(history_store) + + def save_turn(user_id, user_message, assistant_message): + _ = user_id + history_store.append((user_message, assistant_message)) + + history.load_history.side_effect = load_history + history.save_turn.side_effect = save_turn + + rag = _make_mock_rag( + answer="Osmosis is water movement across a semipermeable membrane." + ) + context = _make_context(rag=rag, history=history) + + update1 = _make_text_update("What is osmosis?", user_id=100) + await handle_message(update1, context) + + rag.query_with_history.return_value = { + "answer": "Diffusion needs no membrane boundary.", + "sources": [], + "question_sources": [], + } + update2 = _make_text_update("How is it different from diffusion?", user_id=100) + await handle_message(update2, context) + + second_call = rag.query_with_history.call_args_list[1] + chat_history = second_call.kwargs["chat_history"] + assert len(chat_history) == 1 + assert chat_history[0][0] == "What is osmosis?" + assert "water movement" in chat_history[0][1] + + +class TestErrorRecovery: + @pytest.mark.asyncio + async def test_rag_failure_on_first_message_recovers_on_second(self): + from src.telegram_bot.bot import handle_message + + rag = _make_mock_rag() + rag.query_with_history.side_effect = [ + Exception("temporary backend failure"), + { + "answer": "Recovered answer after retry path.", + "sources": [], + "question_sources": [], + }, + ] + + history = _make_mock_history() + update1 = _make_text_update("msg1") + update2 = _make_text_update("msg2") + context = _make_context(rag=rag, history=history) + + await handle_message(update1, context) + first_reply = update1.message.reply_text.call_args.args[0] + assert "sorry" in first_reply.lower() or "couldn't" in first_reply.lower() + + await handle_message(update2, context) + second_reply = update2.message.reply_text.call_args.args[0] + assert "Recovered answer" in second_reply + + assert history.save_turn.call_count == 1 + assert history.save_turn.call_args.kwargs["user_message"] == "msg2" + + +class TestEmptyKnowledgeBase: + @pytest.mark.asyncio + async def test_no_relevant_information_reply_is_handled_gracefully(self): + from src.telegram_bot.bot import handle_message + + rag = _make_mock_rag(answer="No relevant information found.") + update = _make_text_update("What is dark matter?") + context = _make_context(rag=rag) + + await handle_message(update, context) + + reply = update.message.reply_text.call_args.args[0] + kwargs = update.message.reply_text.call_args.kwargs + assert "No relevant information found." in reply + assert kwargs.get("parse_mode") == "HTML"