Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions pages/1_Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.utils.rag_singleton import get_rag_system
from src.utils.answer_formatting import (
format_assistant_answer_for_streamlit,
format_chat_message_for_streamlit,
)
from src.utils.ui_helpers import setup_public_page_chrome

load_dotenv()
logger = logging.getLogger(__name__)


def _latex_to_streamlit(text: str) -> str:
text = re.sub(r"\\\[(.+?)\\\]", r"$$\1$$", text, flags=re.DOTALL)
text = re.sub(r"\\\((.+?)\\\)", r"$\1$", text, flags=re.DOTALL)
return text


def debug_log(message: str):
logger.info(message)
print(message)
Expand Down Expand Up @@ -548,7 +546,9 @@ def save_history():
last_user_query = message["content"]

with st.chat_message(message["role"]):
st.markdown(_latex_to_streamlit(message["content"]))
st.markdown(
format_chat_message_for_streamlit(message["role"], message["content"])
)

if message["role"] == "assistant" and (
message.get("sources") or message.get("question_sources")
Expand Down Expand Up @@ -692,7 +692,7 @@ def save_history():
answer = response.get("answer", "No answer generated.")
sources = response.get("sources", [])

st.markdown(_latex_to_streamlit(answer))
st.markdown(format_assistant_answer_for_streamlit(answer))

question_sources = response.get("question_sources", [])

Expand Down
2 changes: 2 additions & 0 deletions src/rag/llm_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def _get_default_system_prompt(self) -> str:
- Use bullet points when listing multiple items
- For NEET-specific topics, provide accurate scientific information
- Do NOT say "Document 1", "Document 2", etc. in your answer
- Do NOT use HTML tags like <sup>, <sub>, <b>, <i>, or <br> in your answer
- For exponents, subscripts, and formulas, use plain text or LaTeX instead of HTML
- Use LaTeX with $...$ for inline math and $$...$$ for display math"""

def build_prompt(
Expand Down
121 changes: 121 additions & 0 deletions src/utils/answer_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import re


_SCRIPT_STYLE_BLOCK_PATTERN = re.compile(
r"<(script|style)\b[^>]*>.*?</\1>", re.IGNORECASE | re.DOTALL
)
_LINE_BREAK_PATTERN = re.compile(r"<br\s*/?>", re.IGNORECASE)
_SUPPORTED_TAG_PATTERNS = {
"bold": re.compile(r"<(b|strong)\b[^>]*>(.*?)</\1>", re.IGNORECASE | re.DOTALL),
"italic": re.compile(r"<(i|em)\b[^>]*>(.*?)</\1>", re.IGNORECASE | re.DOTALL),
"sup": re.compile(r"<sup\b[^>]*>(.*?)</sup>", re.IGNORECASE | re.DOTALL),
"sub": re.compile(r"<sub\b[^>]*>(.*?)</sub>", re.IGNORECASE | re.DOTALL),
}


_SUPERSCRIPT_MAP = {
"0": "⁰",
"1": "¹",
"2": "²",
"3": "³",
"4": "⁴",
"5": "⁵",
"6": "⁶",
"7": "⁷",
"8": "⁸",
"9": "⁹",
"+": "⁺",
"-": "⁻",
"=": "⁼",
"(": "⁽",
")": "⁾",
"i": "ⁱ",
"n": "ⁿ",
}

_SUBSCRIPT_MAP = {
"0": "₀",
"1": "₁",
"2": "₂",
"3": "₃",
"4": "₄",
"5": "₅",
"6": "₆",
"7": "₇",
"8": "₈",
"9": "₉",
"+": "₊",
"-": "₋",
"=": "₌",
"(": "₍",
")": "₎",
"a": "ₐ",
"e": "ₑ",
"h": "ₕ",
"i": "ᵢ",
"j": "ⱼ",
"k": "ₖ",
"l": "ₗ",
"m": "ₘ",
"n": "ₙ",
"o": "ₒ",
"p": "ₚ",
"r": "ᵣ",
"s": "ₛ",
"t": "ₜ",
"u": "ᵤ",
"v": "ᵥ",
"x": "ₓ",
}


def _normalize_latex_delimiters(text: str) -> str:
text = re.sub(r"\\\[(.+?)\\\]", r"$$\1$$", text, flags=re.DOTALL)
text = re.sub(r"\\\((.+?)\\\)", r"$\1$", text, flags=re.DOTALL)
return text


def _render_with_map(text: str, mapping: dict[str, str], fallback_prefix: str) -> str:
cleaned = text.strip()
if cleaned and all(char in mapping for char in cleaned):
return "".join(mapping[char] for char in cleaned)
return f"{fallback_prefix}({cleaned})" if cleaned else ""


def _normalize_html_like_tags(text: str) -> str:
if "<" not in text or ">" not in text:
return text

normalized = _SCRIPT_STYLE_BLOCK_PATTERN.sub("", text)
normalized = _LINE_BREAK_PATTERN.sub("\n", normalized)

normalized = _SUPPORTED_TAG_PATTERNS["bold"].sub(
lambda match: f"**{match.group(2).strip()}**", normalized
)
normalized = _SUPPORTED_TAG_PATTERNS["italic"].sub(
lambda match: f"*{match.group(2).strip()}*", normalized
)
normalized = _SUPPORTED_TAG_PATTERNS["sup"].sub(
lambda match: _render_with_map(match.group(1), _SUPERSCRIPT_MAP, "^"),
normalized,
)
normalized = _SUPPORTED_TAG_PATTERNS["sub"].sub(
lambda match: _render_with_map(match.group(1), _SUBSCRIPT_MAP, "_"),
normalized,
)

normalized = re.sub(r"[ \t]+\n", "\n", normalized)
normalized = re.sub(r"\n[ \t]+", "\n", normalized)
normalized = re.sub(r"[ \t]{2,}", " ", normalized)
return normalized.strip()


def format_assistant_answer_for_streamlit(text: str) -> str:
normalized = _normalize_html_like_tags(text)
return _normalize_latex_delimiters(normalized)


def format_chat_message_for_streamlit(role: str, text: str) -> str:
if role != "assistant":
return text
return format_assistant_answer_for_streamlit(text)
70 changes: 70 additions & 0 deletions tests/test_answer_formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import unittest


class TestAssistantAnswerFormatting(unittest.TestCase):
def test_normalize_assistant_answer_converts_known_html_tags(self):
from src.utils.answer_formatting import format_assistant_answer_for_streamlit

raw = (
"Kinetic energy is proportional to mv<sup>2</sup>/2. "
"Water is H<sub>2</sub>O.<br>"
"<b>Important</b> and <i>revised</i>."
)

formatted = format_assistant_answer_for_streamlit(raw)

self.assertEqual(
formatted,
"Kinetic energy is proportional to mv²/2. Water is H₂O.\n"
+ "**Important** and *revised*.",
)

def test_normalize_assistant_answer_falls_back_for_non_unicode_subscripts(self):
from src.utils.answer_formatting import format_assistant_answer_for_streamlit

raw = "For ideal solutions, ΔH<sub>xyz</sub> = 0 and x<sup>ab</sup> is generic."

formatted = format_assistant_answer_for_streamlit(raw)

self.assertIn("ΔH_(xyz) = 0", formatted)
self.assertIn("x^(ab)", formatted)

def test_normalize_assistant_answer_strips_scripts_without_rewriting_literals(self):
from src.utils.answer_formatting import format_assistant_answer_for_streamlit

raw = "Use work-energy theorem<script>alert(1)</script> safely."

formatted = format_assistant_answer_for_streamlit(raw)

self.assertEqual(formatted, "Use work-energy theorem safely.")

def test_normalize_assistant_answer_preserves_literal_supported_tag_text(self):
from src.utils.answer_formatting import format_assistant_answer_for_streamlit

raw = "The literal token is <sup> in HTML and should stay visible."

formatted = format_assistant_answer_for_streamlit(raw)

self.assertEqual(formatted, raw)

def test_format_chat_message_leaves_user_text_unchanged(self):
from src.utils.answer_formatting import format_chat_message_for_streamlit

raw = "What does <b> mean in HTML?"

formatted = format_chat_message_for_streamlit("user", raw)

self.assertEqual(formatted, raw)

def test_prompt_builder_explicitly_forbids_html_tags(self):
from src.rag.llm_manager import RAGPromptBuilder

prompt = RAGPromptBuilder().default_system_prompt

self.assertIn("Do NOT use HTML tags", prompt)
self.assertIn("<sup>", prompt)
self.assertIn("LaTeX", prompt)


if __name__ == "__main__":
_ = unittest.main(verbosity=2)
3 changes: 2 additions & 1 deletion tests/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def test_build_with_sources(self):
query="What is mitochondria?", context_docs=docs, include_sources=True
)

self.assertIn("bio.txt", prompt)
self.assertIn("Previous Year Question", prompt)
self.assertIn("Mitochondria is the powerhouse of the cell.", prompt)


def run_tests():
Expand Down
Loading