diff --git a/chatkit/server.py b/chatkit/server.py index b740a3f..267769e 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -977,4 +977,5 @@ def is_hidden(item: ThreadItem) -> bool: created_at=thread.created_at, items=items, status=thread.status, + allowed_image_domains=thread.allowed_image_domains, ) diff --git a/chatkit/types.py b/chatkit/types.py index da099b9..f9da6a0 100644 --- a/chatkit/types.py +++ b/chatkit/types.py @@ -562,7 +562,7 @@ class ThreadMetadata(BaseModel): id: str created_at: datetime status: ThreadStatus = Field(default_factory=lambda: ActiveStatus()) - # TODO - make not client rendered + allowed_image_domains: list[str] | None = None metadata: dict[str, Any] = Field(default_factory=dict) diff --git a/pyproject.toml b/pyproject.toml index fd44002..04ad045 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-chatkit" -version = "1.6.2" +version = "1.6.3" description = "A ChatKit backend SDK." readme = "README.md" requires-python = ">=3.10" diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index 054eba4..ec30b51 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -1,5 +1,6 @@ import asyncio import base64 +import json import sqlite3 from contextlib import contextmanager from datetime import datetime @@ -639,6 +640,78 @@ async def responder( assert events[-1].thread.status == LockedStatus(reason="Because") +async def test_saves_allowed_image_domains_and_streams_thread_updated(): + async def responder( + thread: ThreadMetadata, input: UserMessageItem | None, context: Any + ) -> AsyncIterator[ThreadStreamEvent]: + thread.allowed_image_domains = ["example.com", "images.example.com"] + return + yield + + with make_server(responder) as server: + events = await server.process_streaming( + ThreadsCreateReq( + params=ThreadCreateParams( + input=UserMessageInput( + content=[UserMessageTextContent(text="Hello, world!")], + attachments=[], + inference_options=InferenceOptions(), + ) + ) + ) + ) + thread = next( + event.thread for event in events if event.type == "thread.created" + ) + loaded = await server.store.load_thread(thread.id, DEFAULT_CONTEXT) + assert loaded.allowed_image_domains == ["example.com", "images.example.com"] + assert events[-1].type == "thread.updated" + assert events[-1].thread.allowed_image_domains == [ + "example.com", + "images.example.com", + ] + + +async def test_omits_unset_allowed_image_domains_in_created_and_updated_json_events(): + async def responder( + thread: ThreadMetadata, input: UserMessageItem | None, context: Any + ) -> AsyncIterator[ThreadStreamEvent]: + thread.title = "Updated title" + return + yield + + with make_server(responder) as server: + stream = await server.process( + ThreadsCreateReq( + params=ThreadCreateParams( + input=UserMessageInput( + content=[UserMessageTextContent(text="Hello, world!")], + attachments=[], + inference_options=InferenceOptions(), + ) + ) + ).model_dump_json(), + DEFAULT_CONTEXT, + ) + assert isinstance(stream, StreamingResult) + + thread_created_event: dict[str, Any] | None = None + thread_updated_event: dict[str, Any] | None = None + async for raw in stream.json_events: + event = json.loads(raw.split(b"data: ")[1]) + if event["type"] == "thread.created": + thread_created_event = event + if event["type"] == "thread.updated": + thread_updated_event = event + + assert thread_created_event is not None + assert "allowed_image_domains" not in thread_created_event["thread"] + + assert thread_updated_event is not None + assert thread_updated_event["thread"]["title"] == "Updated title" + assert "allowed_image_domains" not in thread_updated_event["thread"] + + async def test_emits_thread_updated_mid_stream_and_persists(): async def responder( thread: ThreadMetadata, input: UserMessageItem | None, context: Any diff --git a/uv.lock b/uv.lock index e0cb9c1..7e68417 100644 --- a/uv.lock +++ b/uv.lock @@ -819,7 +819,7 @@ wheels = [ [[package]] name = "openai-chatkit" -version = "1.6.2" +version = "1.6.3" source = { virtual = "." } dependencies = [ { name = "jinja2" },