diff --git a/pages/1_Chat.py b/pages/1_Chat.py index 7c1be5c..ea3e361 100644 --- a/pages/1_Chat.py +++ b/pages/1_Chat.py @@ -14,18 +14,16 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.utils.rag_singleton import get_rag_system +from src.utils.answer_formatting import ( + format_assistant_answer_for_streamlit, + format_chat_message_for_streamlit, +) from src.utils.ui_helpers import setup_public_page_chrome load_dotenv() logger = logging.getLogger(__name__) -def _latex_to_streamlit(text: str) -> str: - text = re.sub(r"\\\[(.+?)\\\]", r"$$\1$$", text, flags=re.DOTALL) - text = re.sub(r"\\\((.+?)\\\)", r"$\1$", text, flags=re.DOTALL) - return text - - def debug_log(message: str): logger.info(message) print(message) @@ -548,7 +546,9 @@ def save_history(): last_user_query = message["content"] with st.chat_message(message["role"]): - st.markdown(_latex_to_streamlit(message["content"])) + st.markdown( + format_chat_message_for_streamlit(message["role"], message["content"]) + ) if message["role"] == "assistant" and ( message.get("sources") or message.get("question_sources") @@ -692,7 +692,7 @@ def save_history(): answer = response.get("answer", "No answer generated.") sources = response.get("sources", []) - st.markdown(_latex_to_streamlit(answer)) + st.markdown(format_assistant_answer_for_streamlit(answer)) question_sources = response.get("question_sources", []) diff --git a/src/rag/llm_manager.py b/src/rag/llm_manager.py index cbbcb56..33c5cac 100644 --- a/src/rag/llm_manager.py +++ b/src/rag/llm_manager.py @@ -207,6 +207,8 @@ def _get_default_system_prompt(self) -> str: - Use bullet points when listing multiple items - For NEET-specific topics, provide accurate scientific information - Do NOT say "Document 1", "Document 2", etc. in your answer +- Do NOT use HTML tags like , , , , or
in your answer +- For exponents, subscripts, and formulas, use plain text or LaTeX instead of HTML - Use LaTeX with $...$ for inline math and $$...$$ for display math""" def build_prompt( diff --git a/src/utils/answer_formatting.py b/src/utils/answer_formatting.py new file mode 100644 index 0000000..350babb --- /dev/null +++ b/src/utils/answer_formatting.py @@ -0,0 +1,121 @@ +import re + + +_SCRIPT_STYLE_BLOCK_PATTERN = re.compile( + r"<(script|style)\b[^>]*>.*?", re.IGNORECASE | re.DOTALL +) +_LINE_BREAK_PATTERN = re.compile(r"", re.IGNORECASE) +_SUPPORTED_TAG_PATTERNS = { + "bold": re.compile(r"<(b|strong)\b[^>]*>(.*?)", re.IGNORECASE | re.DOTALL), + "italic": re.compile(r"<(i|em)\b[^>]*>(.*?)", re.IGNORECASE | re.DOTALL), + "sup": re.compile(r"]*>(.*?)
", re.IGNORECASE | re.DOTALL), + "sub": re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL), +} + + +_SUPERSCRIPT_MAP = { + "0": "⁰", + "1": "¹", + "2": "²", + "3": "³", + "4": "⁴", + "5": "⁵", + "6": "⁶", + "7": "⁷", + "8": "⁸", + "9": "⁹", + "+": "⁺", + "-": "⁻", + "=": "⁼", + "(": "⁽", + ")": "⁾", + "i": "ⁱ", + "n": "ⁿ", +} + +_SUBSCRIPT_MAP = { + "0": "₀", + "1": "₁", + "2": "₂", + "3": "₃", + "4": "₄", + "5": "₅", + "6": "₆", + "7": "₇", + "8": "₈", + "9": "₉", + "+": "₊", + "-": "₋", + "=": "₌", + "(": "₍", + ")": "₎", + "a": "ₐ", + "e": "ₑ", + "h": "ₕ", + "i": "ᵢ", + "j": "ⱼ", + "k": "ₖ", + "l": "ₗ", + "m": "ₘ", + "n": "ₙ", + "o": "ₒ", + "p": "ₚ", + "r": "ᵣ", + "s": "ₛ", + "t": "ₜ", + "u": "ᵤ", + "v": "ᵥ", + "x": "ₓ", +} + + +def _normalize_latex_delimiters(text: str) -> str: + text = re.sub(r"\\\[(.+?)\\\]", r"$$\1$$", text, flags=re.DOTALL) + text = re.sub(r"\\\((.+?)\\\)", r"$\1$", text, flags=re.DOTALL) + return text + + +def _render_with_map(text: str, mapping: dict[str, str], fallback_prefix: str) -> str: + cleaned = text.strip() + if cleaned and all(char in mapping for char in cleaned): + return "".join(mapping[char] for char in cleaned) + return f"{fallback_prefix}({cleaned})" if cleaned else "" + + +def _normalize_html_like_tags(text: str) -> str: + if "<" not in text or ">" not in text: + return text + + normalized = _SCRIPT_STYLE_BLOCK_PATTERN.sub("", text) + normalized = _LINE_BREAK_PATTERN.sub("\n", normalized) + + normalized = _SUPPORTED_TAG_PATTERNS["bold"].sub( + lambda match: f"**{match.group(2).strip()}**", normalized + ) + normalized = _SUPPORTED_TAG_PATTERNS["italic"].sub( + lambda match: f"*{match.group(2).strip()}*", normalized + ) + normalized = _SUPPORTED_TAG_PATTERNS["sup"].sub( + lambda match: _render_with_map(match.group(1), _SUPERSCRIPT_MAP, "^"), + normalized, + ) + normalized = _SUPPORTED_TAG_PATTERNS["sub"].sub( + lambda match: _render_with_map(match.group(1), _SUBSCRIPT_MAP, "_"), + normalized, + ) + + normalized = re.sub(r"[ \t]+\n", "\n", normalized) + normalized = re.sub(r"\n[ \t]+", "\n", normalized) + normalized = re.sub(r"[ \t]{2,}", " ", normalized) + return normalized.strip() + + +def format_assistant_answer_for_streamlit(text: str) -> str: + normalized = _normalize_html_like_tags(text) + return _normalize_latex_delimiters(normalized) + + +def format_chat_message_for_streamlit(role: str, text: str) -> str: + if role != "assistant": + return text + return format_assistant_answer_for_streamlit(text) diff --git a/tests/test_answer_formatting.py b/tests/test_answer_formatting.py new file mode 100644 index 0000000..7f01af1 --- /dev/null +++ b/tests/test_answer_formatting.py @@ -0,0 +1,70 @@ +import unittest + + +class TestAssistantAnswerFormatting(unittest.TestCase): + def test_normalize_assistant_answer_converts_known_html_tags(self): + from src.utils.answer_formatting import format_assistant_answer_for_streamlit + + raw = ( + "Kinetic energy is proportional to mv2/2. " + "Water is H2O.
" + "Important and revised." + ) + + formatted = format_assistant_answer_for_streamlit(raw) + + self.assertEqual( + formatted, + "Kinetic energy is proportional to mv²/2. Water is H₂O.\n" + + "**Important** and *revised*.", + ) + + def test_normalize_assistant_answer_falls_back_for_non_unicode_subscripts(self): + from src.utils.answer_formatting import format_assistant_answer_for_streamlit + + raw = "For ideal solutions, ΔHxyz = 0 and xab is generic." + + formatted = format_assistant_answer_for_streamlit(raw) + + self.assertIn("ΔH_(xyz) = 0", formatted) + self.assertIn("x^(ab)", formatted) + + def test_normalize_assistant_answer_strips_scripts_without_rewriting_literals(self): + from src.utils.answer_formatting import format_assistant_answer_for_streamlit + + raw = "Use work-energy theorem safely." + + formatted = format_assistant_answer_for_streamlit(raw) + + self.assertEqual(formatted, "Use work-energy theorem safely.") + + def test_normalize_assistant_answer_preserves_literal_supported_tag_text(self): + from src.utils.answer_formatting import format_assistant_answer_for_streamlit + + raw = "The literal token is in HTML and should stay visible." + + formatted = format_assistant_answer_for_streamlit(raw) + + self.assertEqual(formatted, raw) + + def test_format_chat_message_leaves_user_text_unchanged(self): + from src.utils.answer_formatting import format_chat_message_for_streamlit + + raw = "What does mean in HTML?" + + formatted = format_chat_message_for_streamlit("user", raw) + + self.assertEqual(formatted, raw) + + def test_prompt_builder_explicitly_forbids_html_tags(self): + from src.rag.llm_manager import RAGPromptBuilder + + prompt = RAGPromptBuilder().default_system_prompt + + self.assertIn("Do NOT use HTML tags", prompt) + self.assertIn("", prompt) + self.assertIn("LaTeX", prompt) + + +if __name__ == "__main__": + _ = unittest.main(verbosity=2) diff --git a/tests/test_rag.py b/tests/test_rag.py index 723bbae..c59b900 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -218,7 +218,8 @@ def test_build_with_sources(self): query="What is mitochondria?", context_docs=docs, include_sources=True ) - self.assertIn("bio.txt", prompt) + self.assertIn("Previous Year Question", prompt) + self.assertIn("Mitochondria is the powerhouse of the cell.", prompt) def run_tests():