From 814f1c8d760a1fb16ceb0369f070fa053b8c959e Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 3 Feb 2026 15:32:26 -0600 Subject: [PATCH] feat(py): Allow deferred chat client initialization (#205) When using Posit Connect managed OAuth credentials, chat client connections need access to HTTP headers in the Shiny session object, requiring creation inside the server function. Changes: - Defer client initialization when data_source=None and client=None - Add chat_client property getter/setter for setting client after init - Add client parameter to server() method for deferred pattern - Add _require_client() method for runtime checks - Update methods (client, console, generate_greeting) to require client Closes #205 Co-Authored-By: Claude Opus 4.5 --- pkg-py/src/querychat/_querychat_base.py | 52 ++++++- pkg-py/src/querychat/_shiny.py | 10 ++ pkg-py/tests/test_deferred_client.py | 190 +++++++++++++++++++++++ pkg-py/tests/test_deferred_datasource.py | 9 +- 4 files changed, 251 insertions(+), 10 deletions(-) create mode 100644 pkg-py/tests/test_deferred_client.py diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index e8a7c7f1..87d96ff7 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -81,12 +81,18 @@ def __init__( self._extra_instructions = extra_instructions self._categorical_threshold = categorical_threshold - # Normalize and initialize client (doesn't need data_source) - client = normalize_client(client) - self._client = copy.deepcopy(client) - self._client.set_turns([]) - + # Initialize client + # When data_source is None (deferred pattern), also defer client initialization + # unless an explicit client is provided self._client_console = None + if data_source is None and client is None: + # Deferred pattern: don't try to create a default client + self._client: chatlas.Chat | None = None + else: + # Normalize and initialize client + normalized_client = normalize_client(client) + self._client = copy.deepcopy(normalized_client) + self._client.set_turns([]) # Initialize data source (may be None for deferred pattern) if data_source is not None: @@ -114,7 +120,9 @@ def _build_system_prompt(self) -> None: extra_instructions=self._extra_instructions, categorical_threshold=self._categorical_threshold, ) - self._client.system_prompt = self._system_prompt.render(self.tools) + # Only set system_prompt on client if client is available + if self._client is not None: + self._client.system_prompt = self._system_prompt.render(self.tools) def _require_data_source(self, method_name: str) -> DataSource[IntoFrameT]: """Raise if data_source is not set, otherwise return it for type narrowing.""" @@ -126,6 +134,16 @@ def _require_data_source(self, method_name: str) -> DataSource[IntoFrameT]: ) return self._data_source + def _require_client(self, method_name: str) -> chatlas.Chat: + """Raise if client is not set, otherwise return it for type narrowing.""" + if self._client is None: + raise RuntimeError( + f"client must be set before calling {method_name}(). " + "Either pass client to __init__(), set the chat_client property, " + "or pass client to server()." + ) + return self._client + def client( self, *, @@ -152,11 +170,12 @@ def client( """ data_source = self._require_data_source("client") + base_client = self._require_client("client") if self._system_prompt is None: raise RuntimeError("System prompt not initialized") tools = normalize_tools(tools, default=self.tools) - chat = copy.deepcopy(self._client) + chat = copy.deepcopy(base_client) chat.set_turns([]) chat.system_prompt = self._system_prompt.render(tools) @@ -177,7 +196,8 @@ def client( def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str: """Generate a welcome greeting for the chat.""" self._require_data_source("generate_greeting") - client = copy.deepcopy(self._client) + base_client = self._require_client("generate_greeting") + client = copy.deepcopy(base_client) client.set_turns([]) return str(client.chat(GREETING_PROMPT, echo=echo)) @@ -190,6 +210,7 @@ def console( ) -> None: """Launch an interactive console chat with the data.""" self._require_data_source("console") + self._require_client("console") tools = normalize_tools(tools, default=("query",)) if new or self._client_console is None: @@ -216,6 +237,21 @@ def data_source(self, value: IntoFrame | sqlalchemy.Engine) -> None: self._data_source = normalize_data_source(value, self._table_name) self._build_system_prompt() + @property + def chat_client(self) -> chatlas.Chat | None: + """Get the current chat client.""" + return self._client + + @chat_client.setter + def chat_client(self, value: str | chatlas.Chat) -> None: + """Set the chat client, normalizing and updating system prompt if needed.""" + normalized_client = normalize_client(value) + self._client = copy.deepcopy(normalized_client) + self._client.set_turns([]) + # Update system prompt on client if data_source is already set + if self._data_source is not None and self._system_prompt is not None: + self._client.system_prompt = self._system_prompt.render(self.tools) + def cleanup(self) -> None: """Clean up resources associated with the data source.""" if self._data_source is not None: diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index c1dcc9a1..844edd78 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -405,6 +405,7 @@ def server( self, *, data_source: Optional[IntoFrame | sqlalchemy.Engine | ibis.Table] = None, + client: Optional[str | chatlas.Chat] = None, enable_bookmarking: bool = False, id: Optional[str] = None, ) -> ServerValues[IntoFrameT]: @@ -422,6 +423,11 @@ def server( Optional data source to use. If provided, sets the data_source property before initializing server logic. This is useful for the deferred pattern where data_source is not known at initialization time. + client + Optional chat client to use. If provided, sets the chat_client property + before initializing server logic. This is useful for the deferred pattern + where the client cannot be created at initialization time (e.g., when + using Posit Connect managed OAuth credentials that require session access). enable_bookmarking Whether to enable bookmarking for the querychat module. id @@ -485,7 +491,11 @@ def title(): if data_source is not None: self.data_source = data_source + if client is not None: + self.chat_client = client + resolved_data_source = self._require_data_source("server") + self._require_client("server") return mod_server( id or self.id, diff --git a/pkg-py/tests/test_deferred_client.py b/pkg-py/tests/test_deferred_client.py new file mode 100644 index 00000000..9940cf4f --- /dev/null +++ b/pkg-py/tests/test_deferred_client.py @@ -0,0 +1,190 @@ +"""Tests for deferred chat client initialization.""" + +import pandas as pd +import pytest +from chatlas import ChatOpenAI +from querychat._querychat_base import QueryChatBase + + +@pytest.fixture +def sample_df(): + """Create a sample pandas DataFrame for testing.""" + return pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + }, + ) + + +class TestDeferredClientInit: + """Tests for initializing QueryChatBase with deferred client.""" + + def test_init_with_none_data_source_defers_client(self): + """When data_source is None and client is not provided, client should be None.""" + qc = QueryChatBase(None, "users") + assert qc._client is None + assert qc.chat_client is None + + def test_init_with_explicit_client_and_none_data_source(self, monkeypatch): + """When data_source is None but client is provided, client should be initialized.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + qc = QueryChatBase(None, "users", client="openai") + assert qc._client is not None + assert qc.chat_client is not None + + def test_init_with_data_source_initializes_client(self, sample_df, monkeypatch): + """When data_source is provided, client should be initialized with default.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + qc = QueryChatBase(sample_df, "users") + assert qc._client is not None + assert qc.chat_client is not None + + +class TestChatClientProperty: + """Tests for the chat_client property setter.""" + + def test_chat_client_setter(self, monkeypatch): + """Setting chat_client should normalize and store the client.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + qc = QueryChatBase(None, "users") + assert qc.chat_client is None + + qc.chat_client = "openai" + assert qc.chat_client is not None + + def test_chat_client_setter_with_chat_object(self, monkeypatch): + """Setting chat_client with a Chat object should work.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + qc = QueryChatBase(None, "users") + assert qc.chat_client is None + + chat = ChatOpenAI() + qc.chat_client = chat + assert qc.chat_client is not None + + def test_chat_client_setter_updates_system_prompt(self, sample_df, monkeypatch): + """Setting chat_client should update system_prompt if data_source is set.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + # Start with data_source but deferred client + qc = QueryChatBase(None, "users") + qc.data_source = sample_df + + # Now set the client - it should get the system prompt + qc.chat_client = "openai" + assert qc._client is not None + # The system prompt should have been set on the client + assert qc._client.system_prompt is not None + + def test_chat_client_getter_returns_none_when_not_set(self): + """chat_client property returns None when not set.""" + qc = QueryChatBase(None, "users") + assert qc.chat_client is None + + +class TestClientMethodRequirements: + """Tests that methods properly require client to be set.""" + + def test_client_method_requires_client(self, sample_df, monkeypatch): + """client() should raise if client not set.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + # Initialize with data_source but no client + qc = QueryChatBase(None, "users") + qc.data_source = sample_df + + with pytest.raises(RuntimeError, match="client must be set"): + qc.client() + + def test_console_requires_client(self, sample_df, monkeypatch): + """console() should raise if client not set.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + qc = QueryChatBase(None, "users") + qc.data_source = sample_df + + with pytest.raises(RuntimeError, match="client must be set"): + qc.console() + + def test_generate_greeting_requires_client(self, sample_df, monkeypatch): + """generate_greeting() should raise if client not set.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + qc = QueryChatBase(None, "users") + qc.data_source = sample_df + + with pytest.raises(RuntimeError, match="client must be set"): + qc.generate_greeting() + + +class TestDeferredClientIntegration: + """Integration tests for the full deferred client workflow.""" + + def test_deferred_data_source_and_client(self, sample_df, monkeypatch): + """Test setting both data_source and client after init.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + + # Create with both deferred + qc = QueryChatBase(None, "users") + assert qc.data_source is None + assert qc.chat_client is None + + # Set data_source first + qc.data_source = sample_df + assert qc.data_source is not None + + # Set client second + qc.chat_client = "openai" + assert qc.chat_client is not None + + # Now methods should work + client = qc.client() + assert client is not None + assert "users" in qc.system_prompt + + def test_deferred_client_then_data_source(self, sample_df, monkeypatch): + """Test setting client before data_source.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + + # Create with both deferred + qc = QueryChatBase(None, "users") + + # Set client first + qc.chat_client = "openai" + assert qc.chat_client is not None + + # Set data_source second + qc.data_source = sample_df + assert qc.data_source is not None + + # Now methods should work + client = qc.client() + assert client is not None + + def test_no_openai_key_error_when_deferred(self, monkeypatch): + """When data_source is None, no OpenAI API key error should occur.""" + # Remove OpenAI API key if set + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("QUERYCHAT_CLIENT", raising=False) + + # This should NOT raise an error about missing API key + qc = QueryChatBase(None, "users") + assert qc._client is None + assert qc.chat_client is None + + +class TestBackwardCompatibility: + """Tests that existing patterns continue to work.""" + + def test_immediate_pattern_unchanged(self, sample_df, monkeypatch): + """Existing code with data_source continues to work.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing") + qc = QueryChatBase(sample_df, "test_table") + + assert qc.data_source is not None + assert qc.chat_client is not None + + # All methods should work immediately + client = qc.client() + assert client is not None + + prompt = qc.system_prompt + assert "test_table" in prompt diff --git a/pkg-py/tests/test_deferred_datasource.py b/pkg-py/tests/test_deferred_datasource.py index af46b0d4..ffd97add 100644 --- a/pkg-py/tests/test_deferred_datasource.py +++ b/pkg-py/tests/test_deferred_datasource.py @@ -139,14 +139,19 @@ class TestDeferredPatternIntegration: def test_deferred_then_set_property(self, sample_df): """Test setting data_source via property after init.""" - # Create with None + # Create with None - both data_source and client are deferred qc = QueryChatBase(None, "users") assert qc.data_source is None + assert qc.chat_client is None - # Set via property + # Set data_source via property qc.data_source = sample_df assert qc.data_source is not None + # Set client via property (required now that we defer both) + qc.chat_client = "openai" + assert qc.chat_client is not None + # Now methods should work client = qc.client() assert client is not None