diff --git a/config.yaml b/config.yaml
index 1b62f96..deb2041 100644
--- a/config.yaml
+++ b/config.yaml
@@ -94,3 +94,12 @@ security:
   pii_protection: false
   output_blocklist: true
   # Input: block injections/PII. Output: mask PII, block secrets.
+
+context_management:
+  strategy: "heuristic"  # allowed values: "heuristic" or "karl"
+
+  karl:
+    model: "same_as_chat"  # or explicit Ollama model name
+    summary_max_tokens: 512
+    keep_last_messages: 2
+    log_dir: "logs"
diff --git a/src/core/context_summarizer.py b/src/core/context_summarizer.py
new file mode 100644
index 0000000..c088d41
--- /dev/null
+++ b/src/core/context_summarizer.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+
+
+class KarlSummarizationError(RuntimeError):
+    """Raised when Karl cannot summarize the current conversation."""
+
+
+class KarlSummarizer:
+    """Pure service for context compression via an LLM summary."""
+
+    def __init__(
+        self,
+        llm_core,
+        config: dict[str, Any],
+        chat_model_name: str,
+    ) -> None:
+        self._llm_core = llm_core
+        self._chat_model_name = chat_model_name
+        self._model_name = self._resolve_model_name(config)
+        self._summary_max_tokens = self._require_positive_int(
+            config, "summary_max_tokens"
+        )
+        self._keep_last_messages = self._require_non_negative_int(
+            config, "keep_last_messages"
+        )
+        self._log_dir = self._require_non_empty_str(config, "log_dir")
+
+    def summarize(self, messages: list[dict]) -> list[dict]:
+        items = [dict(message) for message in messages]
+        if len(items) <= self._keep_last_messages:
+            return items
+
+        split_index = len(items) - self._keep_last_messages
+        history = items[:split_index]
+        tail = items[split_index:]
+
+        prompt_messages = [
+            {
+                "role": "system",
+                "content": (
+                    "Compress the conversation history precisely. "
+                    "Preserve facts, open tasks, constraints, decisions, and unresolved questions. "
+                    "Do not invent content. Keep it concise and actionable."
+                ),
+            },
+            {
+                "role": "user",
+                "content": self._format_history(history),
+            },
+        ]
+
+        try:
+            stream = self._llm_core.stream_chat(
+                model_name=self._model_name,
+                messages=prompt_messages,
+                options={"num_predict": self._summary_max_tokens},
+                keep_alive=600,
+            )
+            summary = "".join(
+                chunk.get("message", {}).get("content", "") for chunk in stream
+            ).strip()
+        except Exception as exc:
+            raise KarlSummarizationError(
+                f"Karl summarization failed with model '{self._model_name}'."
+            ) from exc
+
+        if not summary:
+            raise KarlSummarizationError(
+                f"Karl summarization returned an empty summary with model '{self._model_name}'."
+            )
+
+        self._append_log_entry(len(history), len(summary), self._model_name)
+
+        return [{"role": "system", "content": summary}, *tail]
+
+    def _append_log_entry(
+        self, summarized_count: int, summary_length: int, model_name: str
+    ) -> None:
+        Path(self._log_dir).mkdir(parents=True, exist_ok=True)
+        log_file = Path(self._log_dir) / f"karl_{datetime.now().strftime('%Y-%m-%d')}.log"
+        timestamp = datetime.now().astimezone().isoformat(timespec="seconds")
+        with log_file.open("a", encoding="utf-8") as handle:
+            handle.write(
+                f"{timestamp} summarized={summarized_count} summary_chars={summary_length} model={model_name}\n"
+            )
+
+    def _resolve_model_name(self, config: dict[str, Any]) -> str:
+        model_name = self._require_non_empty_str(config, "model")
+        if model_name == "same_as_chat":
+            return self._chat_model_name
+        return model_name
+
+    @staticmethod
+    def _format_history(messages: list[dict]) -> str:
+        lines: list[str] = []
+        for message in messages:
+            role = str(message.get("role", "unknown")).upper()
+            content = str(message.get("content", "")).strip()
+            lines.append(f"{role}: {content}")
+        return "\n".join(lines)
+
+    @staticmethod
+    def _require_non_empty_str(config: dict[str, Any], key: str) -> str:
+        value = config.get(key)
+        if not isinstance(value, str) or not value.strip():
+            raise ValueError(
+                f"context_management.karl.{key} must be a non-empty string."
+            )
+        return value.strip()
+
+    @staticmethod
+    def _require_positive_int(config: dict[str, Any], key: str) -> int:
+        value = config.get(key)
+        if not isinstance(value, int) or value <= 0:
+            raise ValueError(
+                f"context_management.karl.{key} must be a positive integer."
+            )
+        return value
+
+    @staticmethod
+    def _require_non_negative_int(config: dict[str, Any], key: str) -> int:
+        value = config.get(key)
+        if not isinstance(value, int) or value < 0:
+            raise ValueError(
+                f"context_management.karl.{key} must be a non-negative integer."
+            )
+        return value
diff --git a/src/ui/terminal_ui.py b/src/ui/terminal_ui.py
index 19705e6..6aa79fa 100644
--- a/src/ui/terminal_ui.py
+++ b/src/ui/terminal_ui.py
@@ -7,6 +7,7 @@
 from colorama import Fore, Style, init
 
 from config.personas import get_all_persona_names, get_drink, _load_system_prompts
+from core.context_summarizer import KarlSummarizationError, KarlSummarizer
 from core.context_utils import context_near_limit, karl_prepare_quick_and_dirty
 from core.orchestrator import broadcast_to_ensemble
 
@@ -389,6 +390,23 @@ def _ensure_context_headroom(self) -> None:
         wait_msg = self._t("context_wait_message", persona_name=self.bot, drink=drink)
         print(wait_msg)
+        context_management = self._require_context_management_config()
+        strategy = context_management["strategy"]
+
+        if strategy == "heuristic":
+            self._apply_heuristic_context_trim(persona_options)
+            return
+
+        if strategy == "karl":
+            self._apply_karl_context_summary(context_management["karl"], persona_options)
+            return
+
+        raise ValueError(
+            "context_management.strategy must be either 'heuristic' or 'karl'."
+        )
+
+    def _apply_heuristic_context_trim(self, persona_options) -> None:
+
         num_ctx = persona_options.get("num_ctx")
 
         if not num_ctx:
             logging.info("TerminalUI: Context limit reached, but 'num_ctx' is not set.")
@@ -416,6 +434,58 @@ def _ensure_context_headroom(self) -> None:
             notice = self.texts["terminal_context_trim_notice"]
             print(f"{Fore.YELLOW}{notice}{Style.RESET_ALL}")
 
+    def _apply_karl_context_summary(self, karl_cfg, persona_options) -> None:
+        summarizer = KarlSummarizer(
+            llm_core=self.streamer._llm_core,
+            config=karl_cfg,
+            chat_model_name=self.streamer.model_name,
+        )
+        original_length = len(self.history)
+        try:
+            self.history = summarizer.summarize(self.history)
+        except KarlSummarizationError:
+            fallback = karl_cfg.get("fallback_strategy")
+            if fallback == "heuristic":
+                self._apply_heuristic_context_trim(persona_options)
+                return
+            logging.exception("Karl summarization failed and no fallback is configured.")
+            raise
+
+        removed = original_length - len(self.history)
+        if removed > 0:
+            notice = self.texts["terminal_context_trim_notice"]
+            print(f"{Fore.YELLOW}{notice}{Style.RESET_ALL}")
+
+    def _require_context_management_config(self):
+        context_management = getattr(self.config, "context_management", None)
+        if not isinstance(context_management, dict):
+            raise ValueError("Missing required 'context_management' configuration section.")
+
+        strategy = context_management.get("strategy")
+        if strategy not in {"heuristic", "karl"}:
+            raise ValueError(
+                "context_management.strategy must be either 'heuristic' or 'karl'."
+            )
+
+        if strategy == "karl":
+            karl_cfg = context_management.get("karl")
+            if not isinstance(karl_cfg, dict):
+                raise ValueError("Missing required 'context_management.karl' section.")
+            required_keys = {
+                "model",
+                "summary_max_tokens",
+                "keep_last_messages",
+                "log_dir",
+            }
+            missing = [key for key in required_keys if key not in karl_cfg]
+            if missing:
+                raise ValueError(
+                    "Missing required context_management.karl keys: "
+                    + ", ".join(sorted(missing))
+                )
+
+        return context_management
+
     def _print_loaded_history(self) -> None:
         if not self.history:
             return
diff --git a/src/ui/web_ui.py b/src/ui/web_ui.py
index e1f88af..94c81e0 100644
--- a/src/ui/web_ui.py
+++ b/src/ui/web_ui.py
@@ -6,6 +6,7 @@
 import gradio as gr
 
 from config.personas import get_all_persona_names, get_drink, _load_system_prompts
+from core.context_summarizer import KarlSummarizationError, KarlSummarizer
 from core.context_utils import context_near_limit, karl_prepare_quick_and_dirty
 from core.streaming_provider import inject_wiki_context, lookup_wiki_snippet
 from ui.conversation_io_terminal import load_conversation
@@ -127,6 +128,26 @@ def _handle_context_warning(self, llm_history, chat_history):
         chat_history.append((None, warn))
 
         persona_options = getattr(self.streamer, "persona_options", {}) or {}
+        context_management = self._require_context_management_config()
+        strategy = context_management["strategy"]
+
+        if strategy == "heuristic":
+            self._apply_heuristic_context_trim(llm_history, persona_options)
+            return True
+
+        if strategy == "karl":
+            self._apply_karl_context_summary(
+                llm_history,
+                context_management["karl"],
+                persona_options,
+            )
+            return True
+
+        raise ValueError(
+            "context_management.strategy must be either 'heuristic' or 'karl'."
+        )
+
+    def _apply_heuristic_context_trim(self, llm_history, persona_options):
 
         num_ctx_value = persona_options.get("num_ctx")
 
@@ -150,7 +171,51 @@ def _handle_context_warning(self, llm_history, chat_history):
             num_ctx_value,
         )
 
-        return True
+    def _apply_karl_context_summary(self, llm_history, karl_cfg, persona_options):
+        summarizer = KarlSummarizer(
+            llm_core=self.streamer._llm_core,
+            config=karl_cfg,
+            chat_model_name=self.streamer.model_name,
+        )
+        try:
+            llm_history[:] = summarizer.summarize(llm_history)
+        except KarlSummarizationError:
+            fallback = karl_cfg.get("fallback_strategy")
+            if fallback == "heuristic":
+                self._apply_heuristic_context_trim(llm_history, persona_options)
+                return
+            logging.exception("Karl summarization failed and no fallback is configured.")
+            raise
+
+    def _require_context_management_config(self):
+        context_management = getattr(self.cfg, "context_management", None)
+        if not isinstance(context_management, dict):
+            raise ValueError("Missing required 'context_management' configuration section.")
+
+        strategy = context_management.get("strategy")
+        if strategy not in {"heuristic", "karl"}:
+            raise ValueError(
+                "context_management.strategy must be either 'heuristic' or 'karl'."
+            )
+
+        if strategy == "karl":
+            karl_cfg = context_management.get("karl")
+            if not isinstance(karl_cfg, dict):
+                raise ValueError("Missing required 'context_management.karl' section.")
+            required_keys = {
+                "model",
+                "summary_max_tokens",
+                "keep_last_messages",
+                "log_dir",
+            }
+            missing = [key for key in required_keys if key not in karl_cfg]
+            if missing:
+                raise ValueError(
+                    "Missing required context_management.karl keys: "
+                    + ", ".join(sorted(missing))
+                )
+
+        return context_management
 
     # Stream the response (UI updates continuously)
     def _stream_reply(self, message_history, chat_history):
diff --git a/tests/test_context_summarizer.py b/tests/test_context_summarizer.py
new file mode 100644
index 0000000..3faccfe
--- /dev/null
+++ b/tests/test_context_summarizer.py
@@ -0,0 +1,77 @@
+from pathlib import Path
+
+from core.context_summarizer import KarlSummarizer
+
+
+class _FakeLLMCore:
+    def __init__(self, chunks):
+        self.chunks = chunks
+        self.calls = []
+
+    def stream_chat(self, model_name, messages, options=None, keep_alive=600):
+        self.calls.append(
+            {
+                "model_name": model_name,
+                "messages": messages,
+                "options": options,
+                "keep_alive": keep_alive,
+            }
+        )
+        return iter(self.chunks)
+
+
+def test_karl_summarize_reduces_history_and_keeps_tail(tmp_path):
+    fake = _FakeLLMCore([{"message": {"content": "Kurzfassung"}}])
+    cfg = {
+        "model": "same_as_chat",
+        "summary_max_tokens": 512,
+        "keep_last_messages": 2,
+        "log_dir": str(tmp_path),
+    }
+    summarizer = KarlSummarizer(fake, cfg, chat_model_name="chat-model")
+
+    original = [
+        {"role": "system", "content": "Regeln"},
+        {"role": "user", "content": "Frage alt"},
+        {"role": "assistant", "content": "Antwort alt"},
+        {"role": "user", "content": "Frage neu"},
+        {"role": "assistant", "content": "Antwort neu"},
+    ]
+    frozen = [m.copy() for m in original]
+
+    result = summarizer.summarize(original)
+
+    assert result == [
+        {"role": "system", "content": "Kurzfassung"},
+        {"role": "user", "content": "Frage neu"},
+        {"role": "assistant", "content": "Antwort neu"},
+    ]
+    assert original == frozen
+    assert fake.calls[0]["model_name"] == "chat-model"
+    assert fake.calls[0]["options"] == {"num_predict": 512}
+
+
+def test_karl_summarize_creates_daily_log_file(tmp_path):
+    fake = _FakeLLMCore([{"message": {"content": "Zusammenfassung"}}])
+    cfg = {
+        "model": "karl-model",
+        "summary_max_tokens": 256,
+        "keep_last_messages": 1,
+        "log_dir": str(tmp_path),
+    }
+    summarizer = KarlSummarizer(fake, cfg, chat_model_name="chat-model")
+
+    result = summarizer.summarize(
+        [
+            {"role": "user", "content": "A"},
+            {"role": "assistant", "content": "B"},
+        ]
+    )
+
+    assert result[0]["role"] == "system"
+    log_files = list(Path(tmp_path).glob("karl_*.log"))
+    assert len(log_files) == 1
+    log_text = log_files[0].read_text(encoding="utf-8")
+    assert "summarized=" in log_text
+    assert "summary_chars=" in log_text
+    assert "model=karl-model" in log_text
diff --git a/tests/test_terminal_ui.py b/tests/test_terminal_ui.py
index ea3a3d6..ac84963 100644
--- a/tests/test_terminal_ui.py
+++ b/tests/test_terminal_ui.py
@@ -15,6 +15,15 @@ def _create_terminal_ui() -> TerminalUI:
         t=catalog.format,
         core={"model_name": "dummy"},
         ui={"experimental": {"broadcast_mode": True}},
+        context_management={
+            "strategy": "heuristic",
+            "karl": {
+                "model": "same_as_chat",
+                "summary_max_tokens": 512,
+                "keep_last_messages": 2,
+                "log_dir": "logs",
+            },
+        },
     )
 
     return TerminalUI(
@@ -86,6 +95,15 @@ def test_terminal_ui_broadcast_flag_hides_askall(monkeypatch, capsys) -> None:
         t=catalog.format,
         core={"model_name": "dummy"},
         ui={"experimental": {"broadcast_mode": False}},
+        context_management={
+            "strategy": "heuristic",
+            "karl": {
+                "model": "same_as_chat",
+                "summary_max_tokens": 512,
+                "keep_last_messages": 2,
+                "log_dir": "logs",
+            },
+        },
     )
 
     ui = TerminalUI(
diff --git a/tests/test_web_ui.py b/tests/test_web_ui.py
index 61f2f0f..7df02f4 100644
--- a/tests/test_web_ui.py
+++ b/tests/test_web_ui.py
@@ -38,7 +38,20 @@
 def _create_web_ui(ui_config=None):
     if ui_config is None:
         ui_config = {"experimental": {"broadcast_mode": True}}
-    dummy_config = SimpleNamespace(texts={}, t=lambda key, **kwargs: key, ui=ui_config)
+    dummy_config = SimpleNamespace(
+        texts={},
+        t=lambda key, **kwargs: key,
+        ui=ui_config,
+        context_management={
+            "strategy": "heuristic",
+            "karl": {
+                "model": "same_as_chat",
+                "summary_max_tokens": 512,
+                "keep_last_messages": 2,
+                "log_dir": "logs",
+            },
+        },
+    )
     return WebUI(
         factory=Mock(),
         config=dummy_config,
@@ -81,6 +94,54 @@ def test_respond_streaming_prepares_history_with_valid_num_ctx():
     streamer.stream.assert_called_once()
+
+def test_webui_heuristic_strategy_never_instantiates_karl():
+    web_ui = _create_web_ui()
+    web_ui.bot = "Karl"
+    streamer = Mock()
+    streamer.persona_options = {"num_ctx": "4096"}
+    streamer.stream.return_value = iter(["Hallo"])
+    web_ui.streamer = streamer
+
+    with (
+        patch("ui.web_ui.lookup_wiki_snippet", return_value=([], [])),
+        patch("ui.web_ui.context_near_limit", return_value=True),
+        patch("ui.web_ui.get_drink", return_value="☕"),
+        patch("ui.web_ui.karl_prepare_quick_and_dirty", side_effect=lambda h, c: h),
+        patch("ui.web_ui.KarlSummarizer") as mock_karl,
+    ):
+        list(web_ui.respond_streaming("Hallo", [], []))
+
+    mock_karl.assert_not_called()
+
+
+def test_webui_karl_strategy_uses_karl_instead_of_heuristic():
+    web_ui = _create_web_ui()
+    web_ui.cfg.context_management["strategy"] = "karl"
+    web_ui.bot = "Karl"
+    streamer = Mock()
+    streamer.persona_options = {"num_ctx": "4096"}
+    streamer.model_name = "chat-model"
+    streamer._llm_core = Mock()
+    streamer.stream.return_value = iter(["Antwort"])
+    web_ui.streamer = streamer
+
+    with (
+        patch("ui.web_ui.lookup_wiki_snippet", return_value=([], [])),
+        patch("ui.web_ui.context_near_limit", return_value=True),
+        patch("ui.web_ui.get_drink", return_value="☕"),
+        patch("ui.web_ui.karl_prepare_quick_and_dirty") as mock_prepare,
+        patch("ui.web_ui.KarlSummarizer") as mock_karl,
+    ):
+        instance = mock_karl.return_value
+        instance.summarize.return_value = [{"role": "system", "content": "S"}]
+        outputs = list(web_ui.respond_streaming("Hallo", [], []))
+
+    mock_prepare.assert_not_called()
+    mock_karl.assert_called_once()
+    final_state = outputs[-1][2]
+    assert final_state[0] == {"role": "system", "content": "S"}
+
 
 def test_respond_streaming_skips_history_preparation_without_num_ctx(caplog):
     caplog.set_level(logging.DEBUG)