diff --git a/.env.example b/.env.example
index e3a4b66..f9c1001 100644
--- a/.env.example
+++ b/.env.example
@@ -17,3 +17,7 @@ RAG_WARMUP_ON_LAUNCH=1
# YouTube API Configuration (Optional - for robust metadata)
# Get your key from: https://console.cloud.google.com/apis/credentials
YOUTUBE_API_KEY=
+
+# Telegram Bot Configuration
+# Get your token from @BotFather on Telegram
+TELEGRAM_BOT_TOKEN=
diff --git a/deploy/entrypoint.sh b/deploy/entrypoint.sh
index 5d3a047..eb7c75a 100755
--- a/deploy/entrypoint.sh
+++ b/deploy/entrypoint.sh
@@ -1,4 +1,13 @@
#!/usr/bin/env bash
set -euo pipefail
+# Start Telegram bot in background (if token is configured)
+# NOTE(review): the bot is not supervised. The `exec` at the end replaces this
+# shell, so nothing waits on TELEGRAM_PID; if the bot exits, the container
+# keeps running without it. Confirm that best-effort behaviour is intended.
+# NOTE(review): every container started from this entrypoint launches its own
+# long-polling bot. Telegram presumably allows only one getUpdates consumer
+# per token, so running several replicas needs to be confirmed/handled.
+if [ -n "${TELEGRAM_BOT_TOKEN:-}" ]; then
+ echo "[entrypoint] Starting Telegram bot (polling mode)..."
+ python run_telegram_bot.py &
+ TELEGRAM_PID=$!
+ echo "[entrypoint] Telegram bot started (PID: $TELEGRAM_PID)"
+fi
+
+# Start Streamlit frontend (foreground — container stays alive)
exec python deploy/start_frontend.py
diff --git a/deploy/main.tf b/deploy/main.tf
index 0ffe244..7dae53a 100644
--- a/deploy/main.tf
+++ b/deploy/main.tf
@@ -20,7 +20,8 @@ locals {
streamlit_container_secrets = concat(
local.streamlit_openai_value != "" && local.streamlit_openai_is_arn ? [{ name = "OPENAI_API_KEY", valueFrom = local.streamlit_openai_value }] : [],
- var.youtube_api_key_secret_arn != "" ? [{ name = "YOUTUBE_API_KEY", valueFrom = var.youtube_api_key_secret_arn }] : []
+ var.youtube_api_key_secret_arn != "" ? [{ name = "YOUTUBE_API_KEY", valueFrom = var.youtube_api_key_secret_arn }] : [],
+ var.telegram_bot_token_secret_arn != "" ? [{ name = "TELEGRAM_BOT_TOKEN", valueFrom = var.telegram_bot_token_secret_arn }] : []
)
worker_container_secrets = concat(
@@ -37,7 +38,8 @@ locals {
exec_secret_arns = concat(
local.openai_secret_arns,
- var.youtube_api_key_secret_arn != "" ? [var.youtube_api_key_secret_arn] : []
+ var.youtube_api_key_secret_arn != "" ? [var.youtube_api_key_secret_arn] : [],
+ var.telegram_bot_token_secret_arn != "" ? [var.telegram_bot_token_secret_arn] : []
)
}
diff --git a/deploy/terraform.tfvars.example b/deploy/terraform.tfvars.example
index 831088b..8c81650 100644
--- a/deploy/terraform.tfvars.example
+++ b/deploy/terraform.tfvars.example
@@ -49,6 +49,9 @@ streamlit_desired_count = 2
streamlit_min_capacity = 2
streamlit_max_capacity = 10
+# Telegram Bot (optional — leave empty to disable)
+telegram_bot_token_secret_arn = ""
+
# Enable aws ecs execute-command support
enable_ecs_exec = true
diff --git a/deploy/variables.tf b/deploy/variables.tf
index c9cbcba..0068ec5 100644
--- a/deploy/variables.tf
+++ b/deploy/variables.tf
@@ -222,3 +222,9 @@ variable "ask_assistant_enabled" {
type = string
default = "0"
}
+
+variable "telegram_bot_token_secret_arn" {
+ description = "Secrets Manager ARN for TELEGRAM_BOT_TOKEN. Leave empty to disable the Telegram bot."
+ type = string
+ default = ""
+}
diff --git a/requirements.txt b/requirements.txt
index 14aebce..7e7f283 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -52,3 +52,9 @@ mypy>=1.10.0
# Testing
pytest>=8.0.0
pytest-cov>=5.0.0
+
+# Telegram Bot
+python-telegram-bot>=21.0
+
+# Testing (async)
+pytest-asyncio>=0.23.0
diff --git a/run_telegram_bot.py b/run_telegram_bot.py
new file mode 100644
index 0000000..b6af6c4
--- /dev/null
+++ b/run_telegram_bot.py
@@ -0,0 +1,40 @@
+"""Entry point for the NEET PYQ Telegram Bot."""
+
+import logging
+import os
+from importlib import import_module
+
+
+def _load_dotenv() -> None:
+    """Load variables from a ``.env`` file when python-dotenv is available.
+
+    The dependency is optional: if the ``dotenv`` module cannot be imported,
+    or it exposes no callable ``load_dotenv``, nothing happens.
+    """
+    try:
+        module = import_module("dotenv")
+    except ModuleNotFoundError:
+        return
+
+    loader = getattr(module, "load_dotenv", None)
+    if not callable(loader):
+        return
+    loader()
+
+
+# Runs at import time, so any .env values are in the environment before
+# main() reads TELEGRAM_BOT_TOKEN.
+_load_dotenv()
+
+# Process-wide logging setup: INFO level, timestamped lines tagged with the
+# logger name.
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    """Build the Telegram application and run it in polling mode.
+
+    Raises:
+        ValueError: if ``TELEGRAM_BOT_TOKEN`` is unset or empty.
+    """
+    bot_token = os.environ.get("TELEGRAM_BOT_TOKEN")
+    if not bot_token:
+        raise ValueError("TELEGRAM_BOT_TOKEN environment variable is required")
+
+    logger.info("Initializing NEET PYQ Telegram Bot...")
+    # Imported only after the token check, so a missing token fails before the
+    # bot package is imported.
+    from src.telegram_bot.bot import create_application, run_polling
+
+    run_polling(create_application(bot_token))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/telegram_bot/__init__.py b/src/telegram_bot/__init__.py
new file mode 100644
index 0000000..7ac60c0
--- /dev/null
+++ b/src/telegram_bot/__init__.py
@@ -0,0 +1,5 @@
+"""Telegram bot interface for the NEET PYQ Assistant."""
+
+from src.telegram_bot.bot import create_application
+
+__all__ = ["create_application"]
diff --git a/src/telegram_bot/bot.py b/src/telegram_bot/bot.py
new file mode 100644
index 0000000..6beb958
--- /dev/null
+++ b/src/telegram_bot/bot.py
@@ -0,0 +1,225 @@
+import asyncio
+import logging
+from typing import Any, Mapping
+
+from telegram import Update
+from telegram.constants import ChatAction, ParseMode
+from telegram.ext import (
+ Application,
+ CommandHandler,
+ ContextTypes,
+ MessageHandler,
+ filters,
+)
+
+from src.telegram_bot.formatting import format_response
+from src.telegram_bot.history import TelegramChatHistory
+
+logger = logging.getLogger(__name__)
+
+# Fixed user-facing replies. The handlers in this module send these texts and
+# log the underlying exception instead of echoing it to the chat.
+_ERR_GENERIC = "Sorry, I couldn't process your question right now. Please try again."
+_ERR_IMAGE_DOWNLOAD = "I couldn't download your image. Please try sending it again."
+_ERR_IMAGE_EXTRACT = "I couldn't read the question from your image. Try a clearer photo or type the question."
+
+
+def _is_error_response(result: Mapping[str, object]) -> bool:
+    """Return True when a RAG result should be treated as a failure.
+
+    A result is a failure when it has an ``"error"`` key, or when its
+    ``"answer"`` is a string starting with ``"Error generating answer:"``.
+    """
+    if "error" not in result:
+        answer = result.get("answer", "")
+        if not isinstance(answer, str):
+            return False
+        return answer.startswith("Error generating answer:")
+    return True
+
+
+def create_application(token: str, rag: Any | None = None):
+    """Build the Telegram ``Application`` with shared state and handlers.
+
+    Args:
+        token: Bot token passed to the application builder.
+        rag: RAG system to answer questions with. When omitted, the shared
+            instance from ``src.utils.rag_singleton.get_rag_system`` is used.
+
+    Returns:
+        The configured application; ``bot_data`` holds ``"rag"`` and
+        ``"history"``.
+    """
+    application = Application.builder().token(token).build()
+
+    if rag is None:
+        # Imported lazily so callers that inject their own ``rag`` never
+        # import the singleton module.
+        from src.utils.rag_singleton import get_rag_system
+
+        rag = get_rag_system()
+
+    application.bot_data["rag"] = rag
+    application.bot_data["history"] = TelegramChatHistory()
+
+    handlers = (
+        CommandHandler("start", start_command),
+        CommandHandler("help", help_command),
+        MessageHandler(filters.PHOTO, handle_photo),
+        MessageHandler(filters.TEXT & ~filters.COMMAND, handle_message),
+    )
+    for handler in handlers:
+        application.add_handler(handler)
+    application.add_error_handler(error_handler)
+    return application
+
+
+def run_polling(app) -> None:
+ """Run ``app`` in polling mode, passing ``drop_pending_updates=True``.
+
+ NOTE(review): presumably this discards updates queued while the bot was
+ offline; confirm against the python-telegram-bot documentation.
+ """
+ app.run_polling(drop_pending_updates=True)
+
+
+async def start_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Reply to ``/start`` with the welcome text; ``context`` is unused."""
+    incoming = update.message
+    if incoming is None:
+        return
+
+    await incoming.reply_text(
+        "Welcome to the NEET PYQ Assistant!\n\n"
+        "Send a text question or a photo question and I will help with answers and sources.\n\n"
+        "Type /help for usage instructions.",
+        parse_mode=ParseMode.HTML,
+    )
+
+
+async def help_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Reply to ``/help`` with short usage instructions; ``context`` is unused."""
+    incoming = update.message
+    if incoming is None:
+        return
+
+    await incoming.reply_text(
+        "How to use:\n"
+        "- Send a text question\n"
+        "- Or send a photo of a question with an optional caption",
+        parse_mode=ParseMode.HTML,
+    )
+
+
+async def _send_typing_periodically(
+    chat_id: int, context: ContextTypes.DEFAULT_TYPE
+) -> None:
+    """Send the "typing" chat action every 5 seconds until cancelled.
+
+    Intended to run as a background task that the caller cancels once the
+    reply is ready. The indicator is cosmetic, so a failed send is logged at
+    debug level and retried on the next tick instead of ending the task with
+    an exception nobody retrieves.
+    """
+    try:
+        while True:
+            try:
+                await context.bot.send_chat_action(
+                    chat_id=chat_id, action=ChatAction.TYPING
+                )
+            except Exception:
+                # CancelledError is a BaseException, so cancellation still
+                # propagates to the outer handler below.
+                logger.debug("Failed to send typing action", exc_info=True)
+            await asyncio.sleep(5)
+    except asyncio.CancelledError:
+        return
+
+
+async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Answer a plain-text question with the RAG system.
+
+    Loads the user's stored history, queries ``rag.query_with_history``, sends
+    the formatted answer (possibly in several messages) and then stores the
+    raw, unformatted answer as a new history turn. Error results and
+    exceptions produce the generic error reply and are not stored.
+
+    The RAG and history calls are synchronous, so they run in worker threads
+    via ``asyncio.to_thread``. Calling them directly would block the event
+    loop, which stops the typing indicator task from firing and stalls every
+    other update until the query returns.
+    NOTE(review): this assumes the RAG and history objects may be called from
+    a non-main thread; confirm.
+    """
+    message = update.message
+    user = update.effective_user
+    chat = update.effective_chat
+    if message is None or user is None or chat is None or message.text is None:
+        return
+
+    question = message.text
+    user_id = user.id
+    chat_id = chat.id
+    rag = context.bot_data["rag"]
+    history = context.bot_data["history"]
+
+    await context.bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING)
+    typing_task = asyncio.create_task(_send_typing_periodically(chat_id, context))
+    try:
+        chat_history = await asyncio.to_thread(history.load_history, user_id)
+        result = await asyncio.to_thread(
+            rag.query_with_history,
+            question,
+            chat_history=chat_history,
+            session_id=str(user_id),
+            user_id=str(user_id),
+        )
+
+        if _is_error_response(result):
+            await message.reply_text(_ERR_GENERIC)
+            return
+
+        raw_answer = str(result.get("answer", ""))
+        parts = format_response(
+            answer_text=raw_answer,
+            youtube_sources=result.get("sources", []),
+            question_sources=result.get("question_sources", []),
+        )
+        for part in parts:
+            await message.reply_text(part, parse_mode=ParseMode.HTML)
+
+        # Store the raw answer, not the Telegram-formatted text, so later
+        # prompts see exactly what the model produced.
+        await asyncio.to_thread(
+            history.save_turn,
+            user_id=user_id,
+            user_message=question,
+            assistant_message=raw_answer,
+        )
+    except Exception:
+        logger.exception("Failed to handle text message")
+        await message.reply_text(_ERR_GENERIC)
+    finally:
+        typing_task.cancel()
+
+
+async def handle_photo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Answer a question sent as a photo, with an optional caption.
+
+    Steps: download the last entry of ``message.photo``, extract the question
+    text with ``rag.llm_manager.extract_image_context``, then query the RAG
+    system as for text messages. Download and extraction failures each get
+    their own user-facing message; any other failure gets the generic one.
+
+    The extraction, RAG and history calls are synchronous, so they run in
+    worker threads via ``asyncio.to_thread``; calling them directly would
+    block the event loop and freeze the typing indicator and other updates.
+    NOTE(review): this assumes those objects may be called from a non-main
+    thread; confirm.
+    """
+    message = update.message
+    user = update.effective_user
+    chat = update.effective_chat
+    if message is None or user is None or chat is None:
+        return
+
+    user_id = user.id
+    chat_id = chat.id
+    caption = message.caption or ""
+    rag = context.bot_data["rag"]
+    history = context.bot_data["history"]
+
+    await context.bot.send_chat_action(chat_id=chat_id, action=ChatAction.TYPING)
+    typing_task = asyncio.create_task(_send_typing_periodically(chat_id, context))
+    try:
+        try:
+            photo = message.photo[-1]
+            file = await context.bot.get_file(photo.file_id)
+            image_bytes = bytes(await file.download_as_bytearray())
+        except Exception:
+            logger.exception("Failed downloading photo")
+            await message.reply_text(_ERR_IMAGE_DOWNLOAD)
+            return
+
+        try:
+            extracted = await asyncio.to_thread(
+                rag.llm_manager.extract_image_context,
+                image_bytes=image_bytes,
+                filename="telegram_photo.jpg",
+                user_hint=caption,
+                session_id=str(user_id),
+                user_id=str(user_id),
+            )
+        except Exception:
+            logger.exception("Failed extracting photo context")
+            await message.reply_text(_ERR_IMAGE_EXTRACT)
+            return
+
+        if caption:
+            question = f"{caption}\n\nImage context:\n{extracted}"
+        else:
+            question = str(extracted)
+
+        chat_history = await asyncio.to_thread(history.load_history, user_id)
+        result = await asyncio.to_thread(
+            rag.query_with_history,
+            question,
+            chat_history=chat_history,
+            session_id=str(user_id),
+            user_id=str(user_id),
+        )
+
+        if _is_error_response(result):
+            await message.reply_text(_ERR_GENERIC)
+            return
+
+        raw_answer = str(result.get("answer", ""))
+        parts = format_response(
+            answer_text=raw_answer,
+            youtube_sources=result.get("sources", []),
+            question_sources=result.get("question_sources", []),
+        )
+        for part in parts:
+            await message.reply_text(part, parse_mode=ParseMode.HTML)
+
+        # Store the raw answer, not the Telegram-formatted text.
+        await asyncio.to_thread(
+            history.save_turn,
+            user_id=user_id,
+            user_message=question,
+            assistant_message=raw_answer,
+        )
+    except Exception:
+        logger.exception("Failed to handle photo message")
+        await message.reply_text(_ERR_GENERIC)
+    finally:
+        typing_task.cancel()
+
+
+async def error_handler(update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
+    """Log an unhandled handler exception and, if possible, notify the chat.
+
+    The generic error text is sent only when ``update`` is an ``Update`` with
+    an effective chat; a failure to send it is logged and swallowed.
+    """
+    logger.error("Unhandled exception", exc_info=context.error)
+    if not (isinstance(update, Update) and update.effective_chat):
+        return
+
+    target_chat_id = update.effective_chat.id
+    try:
+        await context.bot.send_message(chat_id=target_chat_id, text=_ERR_GENERIC)
+    except Exception:
+        logger.exception("Failed to send error message")
diff --git a/src/telegram_bot/formatting.py b/src/telegram_bot/formatting.py
new file mode 100644
index 0000000..fc5f0db
--- /dev/null
+++ b/src/telegram_bot/formatting.py
@@ -0,0 +1,176 @@
+# pyright: reportPrivateUsage=false
+
+import re
+from html import escape as html_escape
+from typing import TypedDict
+
+from src.utils.answer_formatting import (
+ _SUBSCRIPT_MAP,
+ _SUPERSCRIPT_MAP,
+)
+
+
+class YouTubeSource(TypedDict, total=False):
+ """One YouTube citation as read by ``format_youtube_sources``; all keys optional."""
+
+ # Display title; the formatter falls back to "YouTube Source N" when missing.
+ title: str
+ # Video URL; the formatter appends a ``t=<seconds>`` query parameter to it.
+ url: str
+ # Colon-separated position parsed by ``_parse_timestamp_to_seconds``.
+ timestamp: str
+
+
+class QuestionSource(TypedDict, total=False):
+ """One related-question citation as read by ``format_question_sources``; all keys optional."""
+
+ # Identifier appended to the neetprep.com/epubQuestion/ URL.
+ id: int | str
+ # Preferred display text.
+ question: str
+ # Fallback display text when ``question`` is missing.
+ title: str
+
+
+# LaTeX delimiters whose contents are kept while the delimiters are dropped.
+_INLINE_LATEX_PATTERN = re.compile(r"\$(.+?)\$", re.DOTALL)
+_BLOCK_LATEX_PATTERN = re.compile(r"\$\$(.+?)\$\$", re.DOTALL)
+_PAREN_LATEX_PATTERN = re.compile(r"\\\((.+?)\\\)", re.DOTALL)
+_BRACKET_LATEX_PATTERN = re.compile(r"\\\[(.+?)\\\]", re.DOTALL)
+# HTML superscript/subscript elements (attributes allowed, any letter case).
+# Each pattern must name its own tag: without the tag name both patterns are
+# identical and match bare ">" characters, so ``_convert_sup_sub`` would
+# delete every ">" and never convert anything.
+_SUP_PATTERN = re.compile(r"<sup[^>]*>(.*?)</sup>", re.IGNORECASE | re.DOTALL)
+_SUB_PATTERN = re.compile(r"<sub[^>]*>(.*?)</sub>", re.IGNORECASE | re.DOTALL)
+
+
+def _render_with_map(text: str, mapping: dict[str, str]) -> str:
+    """Translate ``text`` through ``mapping`` when every character is mappable.
+
+    Surrounding whitespace is stripped first. If any remaining character has
+    no entry in ``mapping``, the stripped text is returned untranslated.
+    """
+    stripped = text.strip()
+    if stripped and all(ch in mapping for ch in stripped):
+        return "".join(map(mapping.__getitem__, stripped))
+    return stripped
+
+
+def _strip_latex_delimiters(text: str) -> str:
+    """Remove LaTeX math delimiters from ``text`` while keeping their contents.
+
+    ``$$...$$`` is handled before ``$...$`` so a block is not mistaken for two
+    inline expressions.
+    """
+    ordered_patterns = (
+        _BLOCK_LATEX_PATTERN,
+        _PAREN_LATEX_PATTERN,
+        _BRACKET_LATEX_PATTERN,
+        _INLINE_LATEX_PATTERN,
+    )
+    for pattern in ordered_patterns:
+        text = pattern.sub(r"\1", text)
+    return text
+
+
+def _convert_sup_sub(text: str) -> str:
+    """Replace superscript and subscript markup with mapped characters.
+
+    Superscripts are processed first, then subscripts; each matched body goes
+    through ``_render_with_map`` with the corresponding table.
+    """
+    for pattern, table in (
+        (_SUP_PATTERN, _SUPERSCRIPT_MAP),
+        (_SUB_PATTERN, _SUBSCRIPT_MAP),
+    ):
+        # ``table`` is bound as a default so each substitution uses its own map.
+        text = pattern.sub(
+            lambda match, table=table: _render_with_map(match.group(1), table),
+            text,
+        )
+    return text
+
+
+def format_answer_text(answer_text: str) -> str:
+    """Prepare a raw answer for sending with HTML parse mode.
+
+    Order matters: sup/sub markup is converted first, LaTeX delimiters are
+    stripped next, and HTML escaping is applied last.
+    """
+    return html_escape(_strip_latex_delimiters(_convert_sup_sub(answer_text)))
+
+
+def _parse_timestamp_to_seconds(timestamp: str) -> int | None:
+    """Convert ``"SS"``, ``"MM:SS"`` or ``"HH:MM:SS"`` to a number of seconds.
+
+    Whitespace around each field is ignored. Returns ``None`` when there are
+    more than three fields or any field is not made of ASCII digits.
+    """
+    fields = [field.strip() for field in timestamp.split(":")]
+    if len(fields) > 3:
+        return None
+    # ``str.isdigit`` alone also accepts characters such as "²" that ``int()``
+    # rejects with ValueError, so require ASCII as well.
+    if not all(field.isascii() and field.isdigit() for field in fields):
+        return None
+
+    total = 0
+    for field in fields:
+        # Each field to the right is worth 60 times less than its neighbour.
+        total = total * 60 + int(field)
+    return total
+
+
+def format_youtube_sources(sources: list[YouTubeSource]) -> str:
+    """Render YouTube sources as a numbered list of HTML links.
+
+    Each line is ``N. <a href="URL">TITLE</a>``. When the source has a
+    parseable timestamp, ``t=<seconds>`` is appended to the URL and the
+    timestamp is shown after the title. Title and URL are HTML-escaped. A
+    source without a URL is rendered as plain text, because an anchor with an
+    empty or ``?t=``-only target is useless.
+
+    Returns:
+        The lines joined with newlines, or ``""`` when ``sources`` is empty.
+    """
+    if not sources:
+        return ""
+
+    lines: list[str] = []
+    for index, source in enumerate(sources, start=1):
+        raw_title = str(source.get("title") or f"YouTube Source {index}")
+        raw_url = str(source.get("url") or "")
+        timestamp = source.get("timestamp")
+
+        if isinstance(timestamp, str) and timestamp.strip():
+            label = timestamp.strip()
+            seconds = _parse_timestamp_to_seconds(label)
+            if seconds is not None:
+                if raw_url:
+                    separator = "&" if "?" in raw_url else "?"
+                    raw_url = f"{raw_url}{separator}t={seconds}"
+                raw_title = f"{raw_title} ({label})"
+
+        escaped_title = html_escape(raw_title)
+        if raw_url:
+            escaped_url = html_escape(raw_url, quote=True)
+            lines.append(f'{index}. <a href="{escaped_url}">{escaped_title}</a>')
+        else:
+            lines.append(f"{index}. {escaped_title}")
+
+    return "\n".join(lines)
+
+
+def format_question_sources(sources: list[QuestionSource]) -> str:
+    """Render related questions as a numbered list of HTML links.
+
+    Each line is ``N. <a href="https://neetprep.com/epubQuestion/ID">TEXT</a>``
+    where TEXT is the ``question`` value, else ``title``, else ``Question N``.
+    Text and URL are HTML-escaped. A source without an id is rendered as plain
+    text instead of linking to the bare ``epubQuestion/`` path.
+
+    Returns:
+        The lines joined with newlines, or ``""`` when ``sources`` is empty.
+    """
+    if not sources:
+        return ""
+
+    lines: list[str] = []
+    for index, source in enumerate(sources, start=1):
+        raw_id = source.get("id")
+        # Compare with None rather than truthiness so a numeric id of 0 survives.
+        question_id = "" if raw_id is None else str(raw_id).strip()
+        raw_title = str(
+            source.get("question") or source.get("title") or f"Question {index}"
+        )
+        escaped_title = html_escape(raw_title)
+
+        if question_id:
+            url = f"https://neetprep.com/epubQuestion/{question_id}"
+            escaped_url = html_escape(url, quote=True)
+            lines.append(f'{index}. <a href="{escaped_url}">{escaped_title}</a>')
+        else:
+            lines.append(f"{index}. {escaped_title}")
+
+    return "\n".join(lines)
+
+
+def format_response(
+    answer_text: str,
+    youtube_sources: list[YouTubeSource] | None = None,
+    question_sources: list[QuestionSource] | None = None,
+    max_length: int = 4096,
+) -> list[str]:
+    """Assemble the full reply and split it into message-sized chunks.
+
+    The formatted answer comes first, followed by a "YouTube Sources" section
+    and a "Related Questions" section when the respective lists render to
+    something non-empty.
+
+    Returns:
+        The chunks produced by ``split_message`` for ``max_length``.
+    """
+    sections = [format_answer_text(answer_text)]
+
+    youtube_block = format_youtube_sources(youtube_sources or [])
+    if youtube_block:
+        sections.append(f"\n\nYouTube Sources\n{youtube_block}")
+
+    question_block = format_question_sources(question_sources or [])
+    if question_block:
+        sections.append(f"\n\nRelated Questions\n{question_block}")
+
+    return split_message("".join(sections), max_length=max_length)
+
+
+def split_message(text: str, max_length: int = 4096) -> list[str]:
+    """Split ``text`` into chunks of at most ``max_length`` characters.
+
+    Cut positions are chosen in this order of preference:
+
+    1. just after the last newline inside the window, otherwise at
+       ``max_length``;
+    2. moved back to before an HTML tag that the cut would slice open;
+    3. moved back to before an HTML entity (such as ``&amp;``) that the cut
+       would slice open.
+
+    Concatenating the chunks always yields ``text`` again.
+    NOTE(review): a matching open/close tag pair can still land in different
+    chunks; only the tags themselves are kept intact.
+
+    Raises:
+        ValueError: if splitting is required but ``max_length`` is below 1
+            (no progress would be possible).
+    """
+    if len(text) <= max_length:
+        return [text]
+    if max_length < 1:
+        raise ValueError("max_length must be a positive integer")
+
+    chunks: list[str] = []
+    remaining = text
+
+    while len(remaining) > max_length:
+        window = remaining[:max_length]
+        newline_at = window.rfind("\n")
+        cut = newline_at + 1 if newline_at != -1 else max_length
+
+        # Unclosed "<" near the end: cut before the tag instead of inside it.
+        candidate = remaining[:cut]
+        last_lt = candidate.rfind("<")
+        if last_lt > candidate.rfind(">") and last_lt > 0:
+            cut = last_lt
+            candidate = remaining[:cut]
+
+        # Unterminated "&...": if it is an entity completed after the cut,
+        # keep the entity whole by cutting in front of it.
+        last_amp = candidate.rfind("&")
+        if last_amp > 0 and ";" not in candidate[last_amp:]:
+            entity = re.match(
+                r"&(?:#[0-9]+|#[xX][0-9a-fA-F]+|[A-Za-z][A-Za-z0-9]*);",
+                remaining[last_amp : last_amp + 40],
+            )
+            if entity is not None and last_amp + entity.end() > cut:
+                cut = last_amp
+
+        chunks.append(remaining[:cut])
+        remaining = remaining[cut:]
+
+    if remaining:
+        chunks.append(remaining)
+    return chunks
diff --git a/src/telegram_bot/history.py b/src/telegram_bot/history.py
new file mode 100644
index 0000000..5296475
--- /dev/null
+++ b/src/telegram_bot/history.py
@@ -0,0 +1,95 @@
+# pyright: reportMissingImports=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnannotatedClassAttribute=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportAny=false
+
+import json
+import logging
+import os
+from urllib.parse import urlparse
+
+import redis
+
+
+logger = logging.getLogger(__name__)
+
+
+class TelegramChatHistory:
+    """Per-user chat history stored in Redis as JSON ``[user, assistant]`` pairs.
+
+    Every operation is best-effort: when Redis cannot be reached at
+    construction time the instance stays usable, ``load_history`` returns an
+    empty list and ``save_turn`` does nothing. Read and write errors are
+    logged and swallowed.
+    """
+
+    def __init__(self, redis_url=None, max_turns=4, ttl_seconds=604800):
+        """Connect to Redis, retrying once over TLS for ``redis://`` URLs.
+
+        Args:
+            redis_url: Connection URL. Falls back to the ``REDIS_URL``
+                environment variable, then to ``redis://localhost:6379/0``.
+            max_turns: Number of most recent turns kept per user; a value of
+                0 or less keeps none.
+            ttl_seconds: Expiry applied on every write (default 604800, i.e.
+                seven days).
+        """
+        self._redis = None
+        self._max_turns = max_turns
+        self._ttl_seconds = ttl_seconds
+
+        # Chain with ``or`` so an empty REDIS_URL (for example a blank entry
+        # in .env) falls back to the default instead of an unusable "" URL.
+        effective_redis_url = (
+            redis_url or os.getenv("REDIS_URL") or "redis://localhost:6379/0"
+        )
+
+        def _build_client(url: str):
+            # Short timeouts so an unreachable Redis cannot stall start-up.
+            return redis.from_url(
+                url,
+                socket_connect_timeout=2,
+                socket_timeout=2,
+                retry_on_timeout=False,
+            )
+
+        try:
+            client = _build_client(effective_redis_url)
+            client.ping()
+            self._redis = client
+            return
+        except Exception:
+            parsed = urlparse(effective_redis_url)
+            if parsed.scheme == "redis":
+                # The endpoint may only accept TLS; retry with ``rediss://``.
+                tls_url = effective_redis_url.replace("redis://", "rediss://", 1)
+                try:
+                    client = _build_client(tls_url)
+                    client.ping()
+                    self._redis = client
+                    return
+                except Exception:
+                    logger.exception("Telegram history: Redis TLS retry failed")
+
+            logger.exception("Telegram history: Redis connection failed")
+
+    def _key(self, user_id: int) -> str:
+        """Return the Redis key under which ``user_id``'s history is stored."""
+        return f"telegram_chat:{user_id}"
+
+    def load_history(self, user_id: int) -> list[tuple[str, str]]:
+        """Return the stored ``(user_message, assistant_message)`` turns.
+
+        Returns an empty list when Redis is unavailable, nothing is stored,
+        the payload is not a JSON list, or reading fails. Entries that are
+        not two-element sequences are skipped.
+        """
+        if not self._redis:
+            return []
+
+        try:
+            raw_payload = self._redis.get(self._key(user_id))
+            if not raw_payload:
+                return []
+
+            if isinstance(raw_payload, bytes):
+                raw_payload = raw_payload.decode("utf-8")
+
+            decoded = json.loads(raw_payload)
+            if not isinstance(decoded, list):
+                return []
+
+            return [
+                (str(item[0]), str(item[1]))
+                for item in decoded
+                if isinstance(item, (list, tuple)) and len(item) == 2
+            ]
+        except Exception:
+            logger.exception(
+                "Telegram history: failed to load history for user_id=%s", user_id
+            )
+            return []
+
+    def save_turn(self, user_id, user_message, assistant_message):
+        """Append one turn, trim to ``max_turns`` and refresh the TTL.
+
+        Does nothing when Redis is unavailable; write errors are logged and
+        swallowed.
+        """
+        if not self._redis:
+            return
+
+        try:
+            current = self.load_history(user_id)
+            current.append((str(user_message), str(assistant_message)))
+            # ``current[-0:]`` is the whole list, so a non-positive limit has
+            # to be handled explicitly to really keep nothing.
+            trimmed = current[-self._max_turns :] if self._max_turns > 0 else []
+            payload = json.dumps([[u, a] for u, a in trimmed], ensure_ascii=False)
+            self._redis.setex(self._key(user_id), self._ttl_seconds, payload)
+        except Exception:
+            logger.exception(
+                "Telegram history: failed to save history for user_id=%s", user_id
+            )
diff --git a/tests/test_telegram_bot.py b/tests/test_telegram_bot.py
new file mode 100644
index 0000000..0ecd0af
--- /dev/null
+++ b/tests/test_telegram_bot.py
@@ -0,0 +1,316 @@
+import os
+import sys
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+
+def _make_mock_rag(
+    answer="The answer is 42",
+    sources=None,
+    question_sources=None,
+    error=None,
+    extract_return="QUESTION: test",
+):
+    """Build a mock RAG system with canned query and image-extraction results.
+
+    The query result contains an ``"error"`` key only when ``error`` is given.
+    """
+    payload = {
+        "answer": answer,
+        "sources": sources or [],
+        "question_sources": question_sources or [],
+    }
+    if error is not None:
+        payload["error"] = error
+
+    mock_rag = MagicMock()
+    mock_rag.query_with_history.return_value = payload
+    mock_rag.llm_manager.extract_image_context.return_value = extract_return
+    return mock_rag
+
+
+def _make_mock_history():
+    """Build a mock history store whose ``load_history`` returns no turns."""
+    mock_history = MagicMock()
+    mock_history.load_history.return_value = []
+    return mock_history
+
+
+def _make_text_update(text, user_id=123):
+    """Build a mock update carrying a text message from ``user_id``.
+
+    The chat id equals the user id and ``reply_text`` is an ``AsyncMock``.
+    """
+    tg_update = MagicMock()
+    tg_update.effective_user.id = user_id
+    tg_update.effective_chat.id = user_id
+
+    message = tg_update.message
+    message.text = text
+    message.caption = None
+    message.photo = None
+    message.reply_text = AsyncMock()
+    return tg_update
+
+
+def _make_photo_update(caption=None, user_id=123):
+    """Build a mock update carrying a photo message with two sizes.
+
+    ``message.photo`` lists the small size first and the large size last.
+    """
+    tg_update = MagicMock()
+    tg_update.effective_user.id = user_id
+    tg_update.effective_chat.id = user_id
+
+    message = tg_update.message
+    message.text = None
+    message.caption = caption
+
+    sizes = []
+    for file_id in ("small_file_id", "large_file_id"):
+        size = MagicMock()
+        size.file_id = file_id
+        sizes.append(size)
+    message.photo = sizes
+
+    message.reply_text = AsyncMock()
+    return tg_update
+
+
+def _make_context(rag=None, history=None):
+    """Build a mock handler context with ``bot_data`` and async bot methods.
+
+    Default mocks are used for ``rag`` and ``history`` when none are given;
+    ``bot.get_file`` resolves to a file whose download yields ``b"fake_image"``.
+    """
+    downloaded = MagicMock()
+    downloaded.download_as_bytearray = AsyncMock(
+        return_value=bytearray(b"fake_image")
+    )
+
+    ctx = MagicMock()
+    ctx.bot_data = {
+        "rag": rag or _make_mock_rag(),
+        "history": history or _make_mock_history(),
+    }
+    ctx.bot.send_chat_action = AsyncMock()
+    ctx.bot.get_file = AsyncMock(return_value=downloaded)
+    return ctx
+
+
+@pytest.mark.asyncio
+async def test_start_command_sends_welcome_message():
+    """``/start`` replies exactly once with text naming NEET or PYQ."""
+    from src.telegram_bot.bot import start_command
+
+    tg_update = _make_text_update("/start")
+
+    await start_command(tg_update, _make_context())
+
+    reply = tg_update.message.reply_text
+    reply.assert_called_once()
+    sent = reply.call_args.args[0]
+    assert any(word in sent for word in ("NEET", "PYQ"))
+
+
+@pytest.mark.asyncio
+async def test_help_command_sends_usage_message():
+    """``/help`` replies exactly once with text mentioning questions or photos."""
+    from src.telegram_bot.bot import help_command
+
+    tg_update = _make_text_update("/help")
+
+    await help_command(tg_update, _make_context())
+
+    reply = tg_update.message.reply_text
+    reply.assert_called_once()
+    sent = reply.call_args.args[0].lower()
+    assert "question" in sent or "photo" in sent
+
+
+@pytest.mark.asyncio
+async def test_handle_message_sends_typing_action():
+    """Handling a text message triggers at least one chat action."""
+    from src.telegram_bot.bot import handle_message
+
+    ctx = _make_context()
+
+    await handle_message(_make_text_update("What is osmosis?"), ctx)
+
+    ctx.bot.send_chat_action.assert_called()
+
+
+@pytest.mark.asyncio
+async def test_handle_message_calls_rag_with_question_and_session_id():
+    """The question is passed positionally and the user id becomes the session id."""
+    from src.telegram_bot.bot import handle_message
+
+    mock_rag = _make_mock_rag()
+    tg_update = _make_text_update("What is osmosis?", user_id=456)
+
+    await handle_message(tg_update, _make_context(rag=mock_rag))
+
+    query = mock_rag.query_with_history
+    query.assert_called_once()
+    assert query.call_args.args[0] == "What is osmosis?"
+    assert query.call_args.kwargs["session_id"] == "456"
+
+
+@pytest.mark.asyncio
+async def test_handle_message_reply_uses_html_parse_mode():
+    """Answers are sent with ``parse_mode`` set to HTML."""
+    from src.telegram_bot.bot import handle_message
+
+    tg_update = _make_text_update("Explain diffusion")
+
+    await handle_message(tg_update, _make_context())
+
+    reply_kwargs = tg_update.message.reply_text.call_args.kwargs
+    assert reply_kwargs.get("parse_mode") == "HTML"
+
+
+@pytest.mark.asyncio
+async def test_handle_message_saves_turn_after_reply():
+    """A successful answer stores one turn for the asking user."""
+    from src.telegram_bot.bot import handle_message
+
+    mock_history = _make_mock_history()
+    tg_update = _make_text_update("What is osmosis?", user_id=789)
+
+    await handle_message(tg_update, _make_context(history=mock_history))
+
+    mock_history.save_turn.assert_called_once()
+    saved = mock_history.save_turn.call_args.kwargs
+    assert saved["user_id"] == 789
+    assert saved["user_message"] == "What is osmosis?"
+
+
+@pytest.mark.asyncio
+async def test_handle_message_exception_returns_friendly_message_no_stack_trace():
+    """An exception from the RAG call yields an apology without error details."""
+    from src.telegram_bot.bot import handle_message
+
+    failing_rag = _make_mock_rag()
+    failing_rag.query_with_history.side_effect = Exception("DB connection failed")
+    tg_update = _make_text_update("Test question")
+
+    await handle_message(tg_update, _make_context(rag=failing_rag))
+
+    sent = tg_update.message.reply_text.call_args.args[0]
+    lowered = sent.lower()
+    assert "sorry" in lowered or "couldn't" in lowered
+    assert "DB connection failed" not in sent
+
+
+@pytest.mark.asyncio
+async def test_handle_message_error_as_answer_returns_friendly_message_and_not_saved():
+    """An "Error generating answer" result is hidden from the user and not stored."""
+    from src.telegram_bot.bot import handle_message
+
+    mock_history = _make_mock_history()
+    error_rag = _make_mock_rag(answer="Error generating answer: timeout")
+    tg_update = _make_text_update("Test question")
+
+    await handle_message(
+        tg_update, _make_context(rag=error_rag, history=mock_history)
+    )
+
+    sent = tg_update.message.reply_text.call_args.args[0]
+    assert "Error generating answer" not in sent
+    mock_history.save_turn.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_handle_message_error_key_returns_friendly_message_and_not_saved():
+    """A result carrying an ``error`` key is hidden from the user and not stored."""
+    from src.telegram_bot.bot import handle_message
+
+    mock_history = _make_mock_history()
+    error_rag = _make_mock_rag(
+        answer="Knowledge base is empty or unavailable.", error="No vectorstore found"
+    )
+    tg_update = _make_text_update("Test question")
+
+    await handle_message(
+        tg_update, _make_context(rag=error_rag, history=mock_history)
+    )
+
+    sent = tg_update.message.reply_text.call_args.args[0]
+    assert "No vectorstore found" not in sent
+    mock_history.save_turn.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_handle_photo_downloads_highest_resolution_photo():
+    """Only the last (largest) photo size is fetched."""
+    from src.telegram_bot.bot import handle_photo
+
+    ctx = _make_context()
+
+    await handle_photo(_make_photo_update(caption=""), ctx)
+
+    ctx.bot.get_file.assert_called_once_with("large_file_id")
+
+
+@pytest.mark.asyncio
+async def test_handle_photo_calls_extract_image_context():
+    """A photo message triggers exactly one image-extraction call."""
+    from src.telegram_bot.bot import handle_photo
+
+    mock_rag = _make_mock_rag()
+
+    await handle_photo(_make_photo_update(caption=""), _make_context(rag=mock_rag))
+
+    mock_rag.llm_manager.extract_image_context.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_handle_photo_caption_included_with_extracted_text():
+    """With a caption, the query contains both the caption and the extracted text."""
+    from src.telegram_bot.bot import handle_photo
+
+    mock_rag = _make_mock_rag(extract_return="QUESTION: velocity problem")
+    tg_update = _make_photo_update(caption="Solve this physics problem")
+
+    await handle_photo(tg_update, _make_context(rag=mock_rag))
+
+    asked = mock_rag.query_with_history.call_args.args[0]
+    for expected in ("Solve this physics problem", "QUESTION: velocity problem"):
+        assert expected in asked
+
+
+@pytest.mark.asyncio
+async def test_handle_photo_without_caption_uses_extracted_text_only():
+    """Without a caption, the query is exactly the extracted text."""
+    from src.telegram_bot.bot import handle_photo
+
+    mock_rag = _make_mock_rag(extract_return="QUESTION: biology cell division")
+    tg_update = _make_photo_update(caption=None)
+
+    await handle_photo(tg_update, _make_context(rag=mock_rag))
+
+    asked = mock_rag.query_with_history.call_args.args[0]
+    assert asked == "QUESTION: biology cell division"
+
+
+@pytest.mark.asyncio
+async def test_handle_photo_download_failure_returns_couldnt_download_message():
+    """A failed download yields the download-specific message without details."""
+    from src.telegram_bot.bot import handle_photo
+
+    ctx = _make_context()
+    ctx.bot.get_file.side_effect = Exception("Network timeout")
+    tg_update = _make_photo_update(caption="")
+
+    await handle_photo(tg_update, ctx)
+
+    sent = tg_update.message.reply_text.call_args.args[0]
+    assert "couldn't download" in sent.lower()
+    assert "Network timeout" not in sent
+
+
+@pytest.mark.asyncio
+async def test_handle_photo_extraction_failure_returns_couldnt_read_message():
+    """A failed extraction yields the extraction-specific message without details."""
+    from src.telegram_bot.bot import handle_photo
+
+    failing_rag = _make_mock_rag()
+    failing_rag.llm_manager.extract_image_context.side_effect = Exception(
+        "Vision API down"
+    )
+    tg_update = _make_photo_update(caption="")
+
+    await handle_photo(tg_update, _make_context(rag=failing_rag))
+
+    sent = tg_update.message.reply_text.call_args.args[0]
+    assert "couldn't read" in sent.lower()
+    assert "Vision API" not in sent
+
+
+@pytest.mark.asyncio
+async def test_history_boundary_saves_raw_answer_not_formatted_html():
+    """History keeps the raw answer while the reply is the converted text.
+
+    The fixture must contain real ``<sub>``/``<sup>`` markup: without it the
+    reply cannot contain a subscript character, and asserting that the empty
+    string is absent from the reply can never pass.
+    """
+    from src.telegram_bot.bot import handle_message
+
+    raw_answer = "The answer has H<sub>2</sub>O and 10<sup>2</sup>"
+    rag = _make_mock_rag(answer=raw_answer)
+    history = _make_mock_history()
+    update = _make_text_update("What is water?", user_id=42)
+    context = _make_context(rag=rag, history=history)
+
+    await handle_message(update, context)
+
+    save_call = history.save_turn.call_args.kwargs
+    assert save_call["assistant_message"] == raw_answer
+
+    reply_text = update.message.reply_text.call_args.args[0]
+    assert "<sub>" not in reply_text
+    assert "₂" in reply_text
diff --git a/tests/test_telegram_formatting.py b/tests/test_telegram_formatting.py
new file mode 100644
index 0000000..388b956
--- /dev/null
+++ b/tests/test_telegram_formatting.py
@@ -0,0 +1,155 @@
+# pyright: reportMissingImports=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false
+
+import os
+import sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+from src.telegram_bot.formatting import (
+ format_answer_text,
+ format_question_sources,
+ format_response,
+ format_youtube_sources,
+ split_message,
+)
+
+
+def test_format_answer_text_plain_passthrough():
+    """Text without markup comes back unchanged."""
+    plain = "Simple NEET answer"
+    assert format_answer_text(plain) == plain
+
+
+def test_format_answer_text_escapes_html_special_chars():
+    """``<``, ``>`` and ``&`` are replaced by their HTML entities."""
+    raw = "Use <tag> & keep > and < safely"
+    assert (
+        format_answer_text(raw)
+        == "Use &lt;tag&gt; &amp; keep &gt; and &lt; safely"
+    )
+
+
+def test_format_answer_text_strips_inline_latex_delimiters():
+    """``$...$`` delimiters are removed and the contents kept."""
+    rendered = format_answer_text("Energy is $x^2$ law")
+    assert rendered == "Energy is x^2 law"
+
+
+def test_format_answer_text_strips_block_latex_delimiters():
+    """``$$...$$`` delimiters are removed and the contents kept."""
+    rendered = format_answer_text("Ratio $$\\frac{a}{b}$$ done")
+    assert rendered == "Ratio \\frac{a}{b} done"
+
+
+def test_format_answer_text_strips_parenthesis_latex_delimiters():
+    """Backslash-parenthesis delimiters are removed and the contents kept."""
+    rendered = format_answer_text(r"Compute \(x+1\) quickly")
+    assert rendered == "Compute x+1 quickly"
+
+
+def test_format_answer_text_strips_bracket_latex_delimiters():
+    """Backslash-bracket delimiters are removed and the contents kept."""
+    rendered = format_answer_text(r"Matrix form: \[a+b\]")
+    assert rendered == "Matrix form: a+b"
+
+
+def test_format_answer_text_converts_sup_tag_to_unicode():
+    """A ``<sup>`` element becomes the Unicode superscript character."""
+    assert format_answer_text("x<sup>2</sup>") == "x²"
+
+
+def test_format_answer_text_converts_sub_tag_to_unicode():
+    """A ``<sub>`` element becomes the Unicode subscript character."""
+    assert format_answer_text("H<sub>2</sub>O") == "H₂O"
+
+
+def test_format_answer_text_preserves_newlines_and_unicode_math_symbols():
+    """Newlines and non-ASCII math symbols pass through untouched."""
+    original = "Line 1\nπ and √ stay"
+    rendered = format_answer_text(original)
+    assert rendered == original
+
+
+def test_format_youtube_sources_single_source_with_timestamp_link():
+    """The link target carries the timestamp in seconds (12:34 -> t=754).
+
+    The ``&`` joining the query parameters is HTML-escaped inside ``href``.
+    """
+    sources = [
+        {
+            "title": "Kinematics intro",
+            "url": "https://youtube.com/watch?v=abc123",
+            "timestamp": "12:34",
+        }
+    ]
+
+    rendered = format_youtube_sources(sources)
+
+    assert rendered.startswith(
+        '1. <a href="https://youtube.com/watch?v=abc123&amp;t=754">'
+    )
+    assert "Kinematics intro" in rendered
+    assert "(12:34)" in rendered
+
+
+def test_format_youtube_sources_empty_returns_blank_string():
+    """No sources render to an empty string."""
+    rendered = format_youtube_sources([])
+    assert rendered == ""
+
+
+def test_format_youtube_sources_multiple_sources_are_numbered():
+    """Two sources produce two numbered lines separated by one newline."""
+    first = {"title": "First", "url": "https://youtu.be/1"}
+    second = {"title": "Second", "url": "https://youtu.be/2", "timestamp": "00:10"}
+
+    rendered = format_youtube_sources([first, second])
+
+    for prefix in ("1. ", "2. "):
+        assert prefix in rendered
+    assert rendered.count("\n") == 1
+
+
+def test_format_question_sources_single_question_with_neetprep_link():
+    """A question renders as a numbered link to its neetprep.com page."""
+    rendered = format_question_sources([{"id": 12345, "question": "Find acceleration"}])
+
+    assert (
+        rendered
+        == '1. <a href="https://neetprep.com/epubQuestion/12345">Find acceleration</a>'
+    )
+
+
+def test_format_question_sources_empty_returns_blank_string():
+    """No question sources render to an empty string."""
+    rendered = format_question_sources([])
+    assert rendered == ""
+
+
+def test_format_response_with_sources_combines_answer_and_source_sections():
+    """The combined output holds the answer, both headings and the question link."""
+    youtube = [
+        {
+            "title": "Work-Energy",
+            "url": "https://youtube.com/watch?v=xyz",
+            "timestamp": "00:30",
+        }
+    ]
+    questions = [{"id": "777", "question": "PYQ on energy"}]
+
+    combined = "".join(
+        format_response(
+            answer_text="Final answer",
+            youtube_sources=youtube,
+            question_sources=questions,
+        )
+    )
+
+    for expected in (
+        "Final answer",
+        "YouTube Sources",
+        "Related Questions",
+        "neetprep.com/epubQuestion/777",
+    ):
+        assert expected in combined
+
+
+def test_format_response_without_sources_returns_answer_only():
+    """Without sources, the result is a single chunk holding only the answer."""
+    assert format_response(answer_text="Only answer") == ["Only answer"]
+
+
+def test_split_message_short_text_not_split():
+    """Text within the limit is returned as a single chunk."""
+    pieces = split_message("short", max_length=4096)
+    assert pieces == ["short"]
+
+
+def test_split_message_long_text_split_at_max_length():
+    """Text without newlines is cut at the limit and reassembles losslessly."""
+    limit = 4096
+    original = "a" * 5000
+
+    pieces = split_message(original, max_length=limit)
+
+    assert len(pieces) == 2
+    assert max(len(piece) for piece in pieces) <= limit
+    assert "".join(pieces) == original
+
+
+def test_split_message_prefers_newline_boundary_and_avoids_mid_html_tag():
+    """Cuts land after a newline when possible and never inside an HTML tag.
+
+    The fixtures need real ``<b>`` tags; without them the tag case only checks
+    ``startswith("")``, which is true for every string.
+    """
+    text = ("a" * 4000) + "\n" + ("b" * 80) + "<b>bold</b>" + ("c" * 80)
+    chunks = split_message(text, max_length=4096)
+
+    assert chunks[0].endswith("\n")
+    assert "".join(chunks) == text
+    assert chunks[1].startswith("b" * 80)
+
+    # The limit falls between "<" and "b", so the cut must move before "<".
+    tag_split_text = ("x" * 4094) + "<b>ok</b>"
+    tag_chunks = split_message(tag_split_text, max_length=4096)
+    assert tag_chunks[0].endswith("x")
+    assert tag_chunks[1].startswith("<b>")
diff --git a/tests/test_telegram_history.py b/tests/test_telegram_history.py
new file mode 100644
index 0000000..3ebef36
--- /dev/null
+++ b/tests/test_telegram_history.py
@@ -0,0 +1,116 @@
+# pyright: reportMissingImports=false, reportUnknownParameterType=false, reportMissingParameterType=false, reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportAny=false
+
+import json
+import os
+import sys
+from unittest.mock import MagicMock
+
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+
+def _make_history(mock_redis=None):
+    """Build a TelegramChatHistory without running its __init__.
+
+    The instance is allocated through ``__new__`` and its private fields are set
+    by hand: the supplied client (or None), a cap of 4 turns, and a TTL of
+    604800 seconds (7 days).
+    """
+    from src.telegram_bot.history import TelegramChatHistory
+
+    instance = TelegramChatHistory.__new__(TelegramChatHistory)
+    preset_fields = {
+        "_redis": mock_redis,
+        "_max_turns": 4,
+        "_ttl_seconds": 604800,
+    }
+    for attr_name, attr_value in preset_fields.items():
+        setattr(instance, attr_name, attr_value)
+    return instance
+
+
+def test_load_returns_empty_for_unknown_user():
+    """When the client's get() yields None, load_history returns an empty list."""
+    redis_stub = MagicMock()
+    redis_stub.get.return_value = None
+
+    loaded = _make_history(redis_stub).load_history(1)
+
+    assert loaded == []
+
+
+def test_save_and_load_single_turn():
+    """A turn saved through a dict-backed fake client is read back as one tuple."""
+    backing = {}
+    redis_stub = MagicMock()
+    redis_stub.get.side_effect = lambda key: backing.get(key)
+    # The fake accepts setex's TTL argument but does not use it.
+    redis_stub.setex.side_effect = lambda key, _ttl, value: backing.update({key: value})
+
+    chat_history = _make_history(redis_stub)
+    chat_history.save_turn(1, "hello", "hi there")
+
+    assert chat_history.load_history(1) == [("hello", "hi there")]
+
+
+def test_history_trimmed_to_max_turns():
+    """With the cap lowered to 2, only the last two of five saved turns are loaded."""
+    backing = {}
+    redis_stub = MagicMock()
+    redis_stub.get.side_effect = lambda key: backing.get(key)
+    # The fake accepts setex's TTL argument but does not use it.
+    redis_stub.setex.side_effect = lambda key, _ttl, value: backing.update({key: value})
+
+    chat_history = _make_history(redis_stub)
+    chat_history._max_turns = 2
+
+    turns = [(f"u{i}", f"a{i}") for i in range(5)]
+    for user_text, bot_text in turns:
+        chat_history.save_turn(1, user_text, bot_text)
+
+    # The final two entries are ("u3", "a3") and ("u4", "a4").
+    assert chat_history.load_history(1) == turns[-2:]
+
+
+def test_different_users_isolated():
+    """Turns saved under user ids 1 and 2 each load back without the other's turn."""
+    backing = {}
+    redis_stub = MagicMock()
+    redis_stub.get.side_effect = lambda key: backing.get(key)
+    # The fake accepts setex's TTL argument but does not use it.
+    redis_stub.setex.side_effect = lambda key, _ttl, value: backing.update({key: value})
+
+    chat_history = _make_history(redis_stub)
+    per_user = {1: ("u1", "a1"), 2: ("u2", "a2")}
+    for user_id, (question, answer) in per_user.items():
+        chat_history.save_turn(user_id, question, answer)
+
+    for user_id, turn in per_user.items():
+        assert chat_history.load_history(user_id) == [turn]
+
+
+def test_ttl_set_on_save():
+    """save_turn makes exactly one setex call with the expected key, TTL and payload."""
+    redis_stub = MagicMock()
+    redis_stub.get.return_value = None
+
+    _make_history(redis_stub).save_turn(1, "hello", "raw answer")
+
+    setex = redis_stub.setex
+    setex.assert_called_once()
+    # The three values are read from the call's positional arguments.
+    key, ttl, payload = setex.call_args.args
+    assert (key, ttl) == ("telegram_chat:1", 604800)
+    # The payload is JSON, so the saved turn decodes as a two-item list.
+    assert json.loads(payload) == [["hello", "raw answer"]]
+
+
+def test_load_graceful_when_redis_unavailable():
+    """With no client set (None), load_history returns an empty list."""
+    loaded = _make_history(None).load_history(1)
+
+    assert loaded == []
+
+
+def test_save_graceful_when_redis_unavailable():
+    """With no client set (None), save_turn must not raise; nothing else is asserted."""
+    clientless = _make_history(None)
+
+    clientless.save_turn(1, "hello", "hi")
diff --git a/tests/test_telegram_integration.py b/tests/test_telegram_integration.py
new file mode 100644
index 0000000..e975c29
--- /dev/null
+++ b/tests/test_telegram_integration.py
@@ -0,0 +1,167 @@
+import os
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+from tests.test_telegram_bot import (
+ _make_context,
+ _make_mock_history,
+ _make_mock_rag,
+ _make_photo_update,
+ _make_text_update,
+)
+
+
+class TestTextEndToEnd:
+    """handle_message driven by a mocked RAG that returns one video and one question."""
+
+    @pytest.mark.asyncio
+    async def test_text_question_returns_formatted_html_with_sources(self):
+        """The single reply uses HTML parse mode and contains both source sections."""
+        from src.telegram_bot.bot import handle_message
+
+        lecture = {
+            "title": "Cell Division Lecture",
+            "url": "https://www.youtube.com/watch?v=abc123",
+            "timestamp": "01:30",
+        }
+        related_question = {
+            "id": 101,
+            "question": "Which phase includes chromosome alignment?",
+        }
+        rag = _make_mock_rag(
+            answer="Mitosis creates two identical daughter cells.",
+            sources=[lecture],
+            question_sources=[related_question],
+        )
+        update = _make_text_update("What is mitosis?")
+        context = _make_context(rag=rag)
+
+        await handle_message(update, context)
+
+        send = update.message.reply_text
+        send.assert_called_once()
+        reply = send.call_args.args[0]
+        options = send.call_args.kwargs
+
+        assert options.get("parse_mode") == "HTML"
+        # The "01:30" timestamp is expected to appear as a t=90 URL parameter.
+        expected_fragments = (
+            "YouTube Sources",
+            "https://www.youtube.com/watch?v=abc123&t=90",
+            "Related Questions",
+            "https://neetprep.com/epubQuestion/101",
+        )
+        for fragment in expected_fragments:
+            assert fragment in reply
+
+
+class TestPhotoEndToEnd:
+    """handle_photo driven by a mocked RAG whose image extraction returns fixed text."""
+
+    @pytest.mark.asyncio
+    async def test_photo_extracts_queries_and_replies_with_answer(self):
+        """Caption and extracted text both appear in the query; the answer is sent as HTML."""
+        from src.telegram_bot.bot import handle_photo
+
+        rag = _make_mock_rag(
+            answer="The block accelerates downward due to gravity.",
+            extract_return="QUESTION: A 2kg block is dropped from rest.",
+        )
+        update = _make_photo_update(caption="Solve this physics problem")
+        context = _make_context(rag=rag)
+
+        await handle_photo(update, context)
+
+        rag.llm_manager.extract_image_context.assert_called_once()
+        # The query is the first positional argument of the RAG call.
+        query = rag.query_with_history.call_args.args[0]
+        for fragment in (
+            "Solve this physics problem",
+            "A 2kg block is dropped from rest",
+        ):
+            assert fragment in query
+
+        sent = update.message.reply_text.call_args
+        reply = sent.args[0]
+        options = sent.kwargs
+        assert "accelerates downward" in reply
+        assert options.get("parse_mode") == "HTML"
+
+
+class TestConversationContinuity:
+    """Two handle_message calls sharing one in-memory history double."""
+
+    @pytest.mark.asyncio
+    async def test_second_message_receives_history_from_first_message(self):
+        """The second RAG call's chat_history holds exactly the first exchange."""
+        from src.telegram_bot.bot import handle_message
+
+        history_store: list[tuple[str, str]] = []
+
+        def _snapshot(_user_id):
+            # Return a copy so later appends do not change an earlier result.
+            return list(history_store)
+
+        def _record(user_id, user_message, assistant_message):
+            # One shared list serves every user id in this double.
+            del user_id
+            history_store.append((user_message, assistant_message))
+
+        history = MagicMock()
+        history.load_history.side_effect = _snapshot
+        history.save_turn.side_effect = _record
+
+        rag = _make_mock_rag(
+            answer="Osmosis is water movement across a semipermeable membrane."
+        )
+        context = _make_context(rag=rag, history=history)
+
+        first_update = _make_text_update("What is osmosis?", user_id=100)
+        await handle_message(first_update, context)
+
+        # Swap in a different answer for the follow-up question.
+        rag.query_with_history.return_value = {
+            "answer": "Diffusion needs no membrane boundary.",
+            "sources": [],
+            "question_sources": [],
+        }
+        second_update = _make_text_update(
+            "How is it different from diffusion?", user_id=100
+        )
+        await handle_message(second_update, context)
+
+        follow_up_call = rag.query_with_history.call_args_list[1]
+        chat_history = follow_up_call.kwargs["chat_history"]
+        assert len(chat_history) == 1
+        first_turn = chat_history[0]
+        assert first_turn[0] == "What is osmosis?"
+        assert "water movement" in first_turn[1]
+
+
+class TestErrorRecovery:
+    """A failing RAG call followed by a successful one on the same context."""
+
+    @pytest.mark.asyncio
+    async def test_rag_failure_on_first_message_recovers_on_second(self):
+        """First reply is an apology, second carries the answer, and only msg2 is saved."""
+        from src.telegram_bot.bot import handle_message
+
+        rag = _make_mock_rag()
+        recovered = {
+            "answer": "Recovered answer after retry path.",
+            "sources": [],
+            "question_sources": [],
+        }
+        # A list side_effect is consumed one item per call: raise, then return.
+        rag.query_with_history.side_effect = [
+            Exception("temporary backend failure"),
+            recovered,
+        ]
+
+        history = _make_mock_history()
+        failing_update = _make_text_update("msg1")
+        passing_update = _make_text_update("msg2")
+        context = _make_context(rag=rag, history=history)
+
+        await handle_message(failing_update, context)
+        apology = failing_update.message.reply_text.call_args.args[0].lower()
+        assert any(marker in apology for marker in ("sorry", "couldn't"))
+
+        await handle_message(passing_update, context)
+        answer_reply = passing_update.message.reply_text.call_args.args[0]
+        assert "Recovered answer" in answer_reply
+
+        saved = history.save_turn
+        assert saved.call_count == 1
+        assert saved.call_args.kwargs["user_message"] == "msg2"
+
+
+class TestEmptyKnowledgeBase:
+    """handle_message when the mocked RAG answers that nothing relevant was found."""
+
+    @pytest.mark.asyncio
+    async def test_no_relevant_information_reply_is_handled_gracefully(self):
+        """The fallback answer is sent to the user with HTML parse mode."""
+        from src.telegram_bot.bot import handle_message
+
+        rag = _make_mock_rag(answer="No relevant information found.")
+        update = _make_text_update("What is dark matter?")
+        context = _make_context(rag=rag)
+
+        await handle_message(update, context)
+
+        sent = update.message.reply_text.call_args
+        reply = sent.args[0]
+        options = sent.kwargs
+        assert "No relevant information found." in reply
+        assert options.get("parse_mode") == "HTML"