Skip to content

Commit 47c5615

Browse files
committed
add draft messages
1 parent c6dc110 commit 47c5615

File tree

9 files changed

+449
-4
lines changed

9 files changed

+449
-4
lines changed

stream_chat/async_chat/channel.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Dict, Iterable, List, Union
2+
from typing import Any, Dict, Iterable, List, Optional, Union
33

44
from stream_chat.base.channel import ChannelInterface, add_user_id
55
from stream_chat.types.stream_response import StreamResponse
@@ -205,3 +205,23 @@ async def unmute(self, user_id: str) -> StreamResponse:
205205
"channel_cid": self.cid,
206206
}
207207
return await self.client.post("moderation/unmute/channel", data=params)
208+
209+
async def create_draft(self, message: Dict, user_id: str) -> StreamResponse:
210+
payload = {"message": add_user_id(message, user_id)}
211+
return await self.client.post(f"{self.url}/draft", data=payload)
212+
213+
async def delete_draft(
214+
self, user_id: str, parent_id: Optional[str] = None
215+
) -> StreamResponse:
216+
params = {"user_id": user_id}
217+
if parent_id:
218+
params["parent_id"] = parent_id
219+
return await self.client.delete(f"{self.url}/draft", params=params)
220+
221+
async def get_draft(
222+
self, user_id: str, parent_id: Optional[str] = None
223+
) -> StreamResponse:
224+
params = {"user_id": user_id}
225+
if parent_id:
226+
params["parent_id"] = parent_id
227+
return await self.client.get(f"{self.url}/draft", params=params)

stream_chat/async_chat/client.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from stream_chat.async_chat.segment import Segment
2222
from stream_chat.types.base import SortParam
2323
from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions
24+
from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions
2425
from stream_chat.types.segment import (
2526
QuerySegmentsOptions,
2627
QuerySegmentTargetsOptions,
@@ -797,6 +798,28 @@ async def unread_counts(self, user_id: str) -> StreamResponse:
797798
async def unread_counts_batch(self, user_ids: List[str]) -> StreamResponse:
798799
return await self.post("unread_batch", data={"user_ids": user_ids})
799800

801+
async def query_drafts(
802+
self,
803+
user_id: str,
804+
filter: Optional[QueryDraftsFilter] = None,
805+
sort: Optional[List[SortParam]] = None,
806+
options: Optional[QueryDraftsOptions] = None,
807+
) -> StreamResponse:
808+
data: Dict[str, Union[str, Dict[str, Any], List[SortParam]]] = {
809+
"user_id": user_id
810+
}
811+
812+
if filter is not None:
813+
data["filter"] = cast(dict, filter)
814+
815+
if sort is not None:
816+
data["sort"] = cast(dict, sort)
817+
818+
if options is not None:
819+
data.update(cast(dict, options))
820+
821+
return await self.post("drafts/query", data=data)
822+
800823
async def close(self) -> None:
801824
await self.session.close()
802825

stream_chat/base/channel.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from typing import Any, Awaitable, Dict, Iterable, List, Union
2+
from typing import Any, Awaitable, Dict, Iterable, List, Optional, Union
33

44
from stream_chat.base.client import StreamChatInterface
55
from stream_chat.base.exceptions import StreamChannelException
@@ -426,6 +426,45 @@ def unmute(self, user_id: str) -> Union[StreamResponse, Awaitable[StreamResponse
426426
"""
427427
pass
428428

429+
@abc.abstractmethod
430+
def create_draft(
431+
self, message: Dict, user_id: str
432+
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
433+
"""
434+
Creates or updates a draft message in a channel.
435+
436+
:param message: The message object
437+
:param user_id: The ID of the user creating the draft
438+
:return: The Server Response
439+
"""
440+
pass
441+
442+
@abc.abstractmethod
443+
def delete_draft(
444+
self, user_id: str, parent_id: Optional[str] = None
445+
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
446+
"""
447+
Deletes a draft message from a channel.
448+
449+
:param user_id: The ID of the user who owns the draft
450+
:param parent_id: Optional ID of the parent message if this is a thread draft
451+
:return: The Server Response
452+
"""
453+
pass
454+
455+
@abc.abstractmethod
456+
def get_draft(
457+
self, user_id: str, parent_id: Optional[str] = None
458+
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
459+
"""
460+
Retrieves a draft message from a channel.
461+
462+
:param user_id: The ID of the user who owns the draft
463+
:param parent_id: Optional ID of the parent message if this is a thread draft
464+
:return: The Server Response
465+
"""
466+
pass
467+
429468

430469
def add_user_id(payload: Dict, user_id: str) -> Dict:
431470
return {**payload, "user": {"id": user_id}}

stream_chat/base/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from stream_chat.types.base import SortParam
1111
from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions
12+
from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions
1213
from stream_chat.types.segment import (
1314
QuerySegmentsOptions,
1415
QuerySegmentTargetsOptions,
@@ -1337,6 +1338,16 @@ def unread_counts_batch(
13371338
"""
13381339
pass
13391340

1341+
@abc.abstractmethod
1342+
def query_drafts(
1343+
self,
1344+
user_id: str,
1345+
filter: Optional[QueryDraftsFilter] = None,
1346+
sort: Optional[List[SortParam]] = None,
1347+
options: Optional[QueryDraftsOptions] = None,
1348+
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
1349+
pass
1350+
13401351
#####################
13411352
# Private methods #
13421353
#####################

stream_chat/channel.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Dict, Iterable, List, Union
2+
from typing import Any, Dict, Iterable, List, Optional, Union
33

44
from stream_chat.base.channel import ChannelInterface, add_user_id
55
from stream_chat.types.stream_response import StreamResponse
@@ -206,3 +206,26 @@ def unmute(self, user_id: str) -> StreamResponse:
206206
"channel_cid": self.cid,
207207
}
208208
return self.client.post("moderation/unmute/channel", data=params)
209+
210+
def create_draft(self, message: Dict, user_id: str) -> StreamResponse:
211+
message["user_id"] = user_id
212+
payload = {"message": message}
213+
return self.client.post(f"{self.url}/draft", data=payload)
214+
215+
def delete_draft(
216+
self, user_id: str, parent_id: Optional[str] = None
217+
) -> StreamResponse:
218+
params = {"user_id": user_id}
219+
if parent_id:
220+
params["parent_id"] = parent_id
221+
222+
return self.client.delete(f"{self.url}/draft", params=params)
223+
224+
def get_draft(
225+
self, user_id: str, parent_id: Optional[str] = None
226+
) -> StreamResponse:
227+
params = {"user_id": user_id}
228+
if parent_id:
229+
params["parent_id"] = parent_id
230+
231+
return self.client.get(f"{self.url}/draft", params=params)

stream_chat/client.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from stream_chat.segment import Segment
1111
from stream_chat.types.base import SortParam
1212
from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions
13+
from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions
1314
from stream_chat.types.segment import (
1415
QuerySegmentsOptions,
1516
QuerySegmentTargetsOptions,
@@ -758,4 +759,22 @@ def unread_counts(self, user_id: str) -> StreamResponse:
758759
return self.get("unread", params={"user_id": user_id})
759760

760761
def unread_counts_batch(self, user_ids: List[str]) -> StreamResponse:
761-
return self.post("unread_batch", data={"user_ids": user_ids})
762+
return self.post("unread_counts/batch", data={"user_ids": user_ids})
763+
764+
def query_drafts(
765+
self,
766+
user_id: str,
767+
filter: Optional[QueryDraftsFilter] = None,
768+
sort: Optional[List[SortParam]] = None,
769+
options: Optional[QueryDraftsOptions] = None,
770+
) -> StreamResponse:
771+
data: Dict[str, Union[str, Dict[str, Any], List[SortParam]]] = {
772+
"user_id": user_id
773+
}
774+
if filter is not None:
775+
data["filter"] = cast(dict, filter)
776+
if sort is not None:
777+
data["sort"] = cast(dict, sort)
778+
if options is not None:
779+
data.update(cast(dict, options))
780+
return self.post("drafts/query", data=data)
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import uuid
2+
from typing import Dict
3+
4+
import pytest
5+
6+
from stream_chat.async_chat.channel import Channel
7+
from stream_chat.async_chat.client import StreamChatAsync
8+
from stream_chat.types.base import SortOrder
9+
10+
11+
@pytest.mark.incremental
12+
class TestDraft:
13+
async def test_create_draft(self, channel: Channel, random_user: Dict):
14+
draft_message = {"text": "This is a draft message"}
15+
response = await channel.create_draft(draft_message, random_user["id"])
16+
17+
assert "draft" in response
18+
assert response["draft"]["message"]["text"] == "This is a draft message"
19+
assert response["draft"]["channel_cid"] == channel.cid
20+
21+
async def test_get_draft(self, channel: Channel, random_user: Dict):
22+
# First create a draft
23+
draft_message = {"text": "This is a draft to retrieve"}
24+
await channel.create_draft(draft_message, random_user["id"])
25+
26+
# Then get the draft
27+
response = await channel.get_draft(random_user["id"])
28+
29+
assert "draft" in response
30+
assert response["draft"]["message"]["text"] == "This is a draft to retrieve"
31+
assert response["draft"]["channel_cid"] == channel.cid
32+
33+
async def test_delete_draft(self, channel: Channel, random_user: Dict):
34+
# First create a draft
35+
draft_message = {"text": "This is a draft to delete"}
36+
await channel.create_draft(draft_message, random_user["id"])
37+
38+
# Then delete the draft
39+
await channel.delete_draft(random_user["id"])
40+
41+
# Verify it's deleted by trying to get it
42+
try:
43+
await channel.get_draft(random_user["id"])
44+
raise AssertionError("Draft should be deleted")
45+
except Exception:
46+
# Expected behavior, draft should not be found
47+
pass
48+
49+
async def test_thread_draft(self, channel: Channel, random_user: Dict):
50+
# First create a parent message
51+
msg = await channel.send_message({"text": "Parent message"}, random_user["id"])
52+
parent_id = msg["message"]["id"]
53+
54+
# Create a draft reply
55+
draft_reply = {"text": "This is a draft reply", "parent_id": parent_id}
56+
response = await channel.create_draft(draft_reply, random_user["id"])
57+
58+
assert "draft" in response
59+
assert response["draft"]["message"]["text"] == "This is a draft reply"
60+
assert response["draft"]["parent_id"] == parent_id
61+
62+
# Get the draft reply
63+
response = await channel.get_draft(random_user["id"], parent_id=parent_id)
64+
65+
assert "draft" in response
66+
assert response["draft"]["message"]["text"] == "This is a draft reply"
67+
assert response["draft"]["parent_id"] == parent_id
68+
69+
# Delete the draft reply
70+
await channel.delete_draft(random_user["id"], parent_id=parent_id)
71+
72+
# Verify it's deleted
73+
try:
74+
await channel.get_draft(random_user["id"], parent_id=parent_id)
75+
raise AssertionError("Thread draft should be deleted")
76+
except Exception:
77+
# Expected behavior
78+
pass
79+
80+
async def test_query_drafts(
81+
self, client: StreamChatAsync, channel: Channel, random_user: Dict
82+
):
83+
# Create multiple drafts in different channels
84+
draft1 = {"text": "Draft in channel 1"}
85+
await channel.create_draft(draft1, random_user["id"])
86+
87+
# Create another channel with a draft
88+
channel2 = client.channel("messaging", str(uuid.uuid4()))
89+
await channel2.create(random_user["id"])
90+
91+
draft2 = {"text": "Draft in channel 2"}
92+
await channel2.create_draft(draft2, random_user["id"])
93+
94+
# Query all drafts for the user
95+
response = await client.query_drafts(random_user["id"])
96+
97+
assert "drafts" in response
98+
assert len(response["drafts"]) == 2
99+
100+
# Query drafts for a specific channel
101+
response = await client.query_drafts(
102+
random_user["id"], filter={"channel_cid": channel2.cid}
103+
)
104+
105+
assert "drafts" in response
106+
assert len(response["drafts"]) == 1
107+
draft = response["drafts"][0]
108+
assert draft["channel_cid"] == channel2.cid
109+
assert draft["message"]["text"] == "Draft in channel 2"
110+
111+
# Query drafts with sort
112+
response = await client.query_drafts(
113+
random_user["id"],
114+
sort=[{"field": "created_at", "direction": SortOrder.ASC}],
115+
)
116+
117+
assert "drafts" in response
118+
assert len(response["drafts"]) == 2
119+
assert response["drafts"][0]["channel_cid"] == channel.cid
120+
assert response["drafts"][1]["channel_cid"] == channel2.cid
121+
122+
# Query drafts with pagination
123+
response = await client.query_drafts(
124+
random_user["id"],
125+
options={"limit": 1},
126+
)
127+
128+
assert "drafts" in response
129+
assert len(response["drafts"]) == 1
130+
assert response["drafts"][0]["channel_cid"] == channel2.cid
131+
132+
assert response["next"] is not None
133+
134+
# Query drafts with pagination
135+
response = await client.query_drafts(
136+
random_user["id"],
137+
options={"limit": 1, "next": response["next"]},
138+
)
139+
140+
assert "drafts" in response
141+
assert len(response["drafts"]) == 1
142+
assert response["drafts"][0]["channel_cid"] == channel.cid
143+
144+
# Cleanup
145+
try:
146+
await channel2.delete()
147+
except Exception:
148+
pass

0 commit comments

Comments
 (0)