From 7c261340948d5b55762c656a3b7b4dbcde53313d Mon Sep 17 00:00:00 2001 From: evgeny Date: Thu, 15 Jan 2026 11:44:35 +0000 Subject: [PATCH] [AIT-258] feat: add Realtime mutable message support - Updated `ConnectionManager` and `MessageQueue` to process `PublishResult` during acknowledgments (ACK/NACK). - Extended `send_protocol_message` to return `PublishResult` for publish tracking. - Bumped default `protocol_version` to 5. - Added tests for message update, delete, append operations, and PublishResult handling. --- ably/realtime/connectionmanager.py | 45 ++- ably/realtime/realtime_channel.py | 211 ++++++++++++- ably/transport/defaults.py | 2 +- ably/transport/websockettransport.py | 6 +- .../realtimechannelmutablemessages_test.py | 289 ++++++++++++++++++ 5 files changed, 534 insertions(+), 19 deletions(-) create mode 100644 test/ably/realtime/realtimechannelmutablemessages_test.py diff --git a/ably/realtime/connectionmanager.py b/ably/realtime/connectionmanager.py index 01a0735b..9b09e126 100644 --- a/ably/realtime/connectionmanager.py +++ b/ably/realtime/connectionmanager.py @@ -4,6 +4,7 @@ import logging from collections import deque from datetime import datetime +from itertools import zip_longest from typing import TYPE_CHECKING import httpx @@ -13,6 +14,7 @@ from ably.types.connectiondetails import ConnectionDetails from ably.types.connectionerrors import ConnectionErrors from ably.types.connectionstate import ConnectionEvent, ConnectionState, ConnectionStateChange +from ably.types.operations import PublishResult from ably.types.tokendetails import TokenDetails from ably.util.eventemitter import EventEmitter from ably.util.exceptions import AblyException, IncompatibleClientIdException @@ -29,7 +31,7 @@ class PendingMessage: def __init__(self, message: dict): self.message = message - self.future: asyncio.Future | None = None + self.future: asyncio.Future[PublishResult] | None = None action = message.get('action') # Messages that require acknowledgment: MESSAGE, PRESENCE, ANNOTATION, OBJECT @@ -58,15 +60,22 @@ def count(self) -> int: """Return the number of pending messages""" return len(self.messages) - def complete_messages(self, serial: int, count: int, err: AblyException | None = None) -> None: + def complete_messages( + self, + serial: int, + count: int, + res: list[PublishResult] | None, + err: AblyException | None = None + ) -> None: """Complete messages based on serial and count from ACK/NACK Args: serial: The msgSerial of the first message being acknowledged count: The number of messages being acknowledged + res: List of PublishResult objects for each message acknowledged, or None if not available err: Error from NACK, or None for successful ACK """ - log.debug(f'MessageQueue.complete_messages(): serial={serial}, count={count}, err={err}') + log.debug(f'MessageQueue.complete_messages(): serial={serial}, count={count}, res={res}, err={err}') if not self.messages: log.warning('MessageQueue.complete_messages(): called on empty queue') @@ -87,12 +96,17 @@ def complete_messages(self, serial: int, count: int, err: AblyException | None = completed_messages = self.messages[:num_to_complete] self.messages = self.messages[num_to_complete:] - for msg in completed_messages: + # Default res to empty list if None + res_list = res if res is not None else [] + for (msg, publish_result) in zip_longest(completed_messages, res_list): if msg.future and not msg.future.done(): if err: msg.future.set_exception(err) else: - msg.future.set_result(None) + # If publish_result is None, return empty PublishResult + if publish_result is None: + publish_result = PublishResult() + msg.future.set_result(publish_result) def complete_all_messages(self, err: AblyException) -> None: """Complete all pending messages with an error""" @@ -199,7 +213,7 @@ async def close_impl(self) -> None: self.notify_state(ConnectionState.CLOSED) - async def send_protocol_message(self, protocol_message: dict) -> None: + async def send_protocol_message(self, protocol_message: dict) -> PublishResult | None: """Send a protocol message and optionally track it for acknowledgment Args: @@ -233,12 +247,14 @@ async def send_protocol_message(self, protocol_message: dict) -> None: if state_should_queue: self.queued_messages.appendleft(pending_message) if pending_message.ack_required: - await pending_message.future + return await pending_message.future return None return await self._send_protocol_message_on_connected_state(pending_message) - async def _send_protocol_message_on_connected_state(self, pending_message: PendingMessage) -> None: + async def _send_protocol_message_on_connected_state( + self, pending_message: PendingMessage + ) -> PublishResult | None: if self.state == ConnectionState.CONNECTED and self.transport: # Add to pending queue before sending (for messages being resent from queue) if pending_message.ack_required and pending_message not in self.pending_message_queue.messages: @@ -253,7 +269,7 @@ async def _send_protocol_message_on_connected_state(self, pending_message: Pendi AblyException("No active transport", 500, 50000) ) if pending_message.ack_required: - await pending_message.future + return await pending_message.future return None def send_queued_messages(self) -> None: @@ -449,15 +465,18 @@ def on_heartbeat(self, id: str | None) -> None: self.__ping_future.set_result(None) self.__ping_future = None - def on_ack(self, serial: int, count: int) -> None: + def on_ack( + self, serial: int, count: int, res: list[PublishResult] | None + ) -> None: """Handle ACK protocol message from server Args: serial: The msgSerial of the first message being acknowledged count: The number of messages being acknowledged + res: List of PublishResult objects for each message acknowledged, or None if not available """ - log.debug(f'ConnectionManager.on_ack(): serial={serial}, count={count}') - self.pending_message_queue.complete_messages(serial, count) + log.debug(f'ConnectionManager.on_ack(): serial={serial}, count={count}, res={res}') + self.pending_message_queue.complete_messages(serial, count, res) def on_nack(self, serial: int, count: int, err: AblyException | None) -> None: """Handle NACK protocol message from server @@ -471,7 +490,7 @@ def on_nack(self, serial: int, count: int, err: AblyException | None) -> None: err = AblyException('Unable to send message; channel not responding', 50001, 500) log.error(f'ConnectionManager.on_nack(): serial={serial}, count={count}, err={err}') - self.pending_message_queue.complete_messages(serial, count, err) + self.pending_message_queue.complete_messages(serial, count, None, err) def deactivate_transport(self, reason: AblyException | None = None): # RTN19a: Before disconnecting, requeue any pending messages diff --git a/ably/realtime/realtime_channel.py b/ably/realtime/realtime_channel.py index fa6f396d..2f5a03d2 100644 --- a/ably/realtime/realtime_channel.py +++ b/ably/realtime/realtime_channel.py @@ -10,8 +10,9 @@ from ably.transport.websockettransport import ProtocolMessageAction from ably.types.channelstate import ChannelState, ChannelStateChange from ably.types.flags import Flag, has_flag -from ably.types.message import Message +from ably.types.message import Message, MessageAction, MessageVersion from ably.types.mixins import DecodingContext +from ably.types.operations import MessageOperation, PublishResult, UpdateDeleteResult from ably.types.presence import PresenceMessage from ably.util.eventemitter import EventEmitter from ably.util.exceptions import AblyException, IncompatibleClientIdException @@ -390,7 +391,7 @@ def unsubscribe(self, *args) -> None: self.__message_emitter.off(listener) # RTL6 - async def publish(self, *args, **kwargs) -> None: + async def publish(self, *args, **kwargs) -> PublishResult: """Publish a message or messages on this channel Publishes a single message or an array of messages to the channel. @@ -490,7 +491,7 @@ async def publish(self, *args, **kwargs) -> None: } # RTL6b: Await acknowledgment from server - await self.__realtime.connection.connection_manager.send_protocol_message(protocol_message) + return await self.__realtime.connection.connection_manager.send_protocol_message(protocol_message) def _throw_if_unpublishable_state(self) -> None: """Check if the channel and connection are in a state that allows publishing @@ -522,6 +523,200 @@ def _throw_if_unpublishable_state(self) -> None: 90001, ) + async def _send_update(self, message: Message, action: MessageAction, + operation: MessageOperation = None) -> UpdateDeleteResult: + """Internal method to send update/delete/append operations via websocket. + + Parameters + ---------- + message : Message + Message object with serial field required + action : MessageAction + The action type (MESSAGE_UPDATE, MESSAGE_DELETE, MESSAGE_APPEND) + operation : MessageOperation, optional + Operation metadata (description, metadata) + + Returns + ------- + UpdateDeleteResult + Result containing version serial of the operation + + Raises + ------ + AblyException + If message serial is missing or connection/channel state prevents operation + """ + # Check message has serial + if not message.serial: + raise AblyException( + "Message serial is required for update/delete/append operations", + 400, + 40000 + ) + + # Check connection and channel state + self._throw_if_unpublishable_state() + + # Create version from operation if provided + if not operation: + version = None + else: + version = MessageVersion( + client_id=operation.client_id, + description=operation.description, + metadata=operation.metadata + ) + + # Create a new message with the operation fields + update_message = Message( + name=message.name, + data=message.data, + client_id=message.client_id, + serial=message.serial, + action=action, + version=version, + ) + + # Encrypt if needed + if self.cipher: + update_message.encrypt(self.cipher) + + # Convert to dict representation + msg_dict = update_message.as_dict(binary=self.ably.options.use_binary_protocol) + + log.info( + f'RealtimeChannel._send_update(): sending {action.name} message; ' + f'channel = {self.name}, state = {self.state}, serial = {message.serial}' + ) + + # Send protocol message + protocol_message = { + "action": ProtocolMessageAction.MESSAGE, + "channel": self.name, + "messages": [msg_dict], + } + + # Send and await acknowledgment + result = await self.__realtime.connection.connection_manager.send_protocol_message(protocol_message) + + # Return UpdateDeleteResult - we don't have version_serial from the result yet + # The server will send ACK with the result + if result and hasattr(result, 'serials') and result.serials: + return UpdateDeleteResult(version_serial=result.serials[0]) + return UpdateDeleteResult() + + async def update_message(self, message: Message, operation: MessageOperation = None) -> UpdateDeleteResult: + """Updates an existing message on this channel. + + Parameters + ---------- + message : Message + Message object to update. Must have a serial field. + operation : MessageOperation, optional + Optional MessageOperation containing description and metadata for the update. + + Returns + ------- + UpdateDeleteResult + Result containing the version serial of the updated message. + + Raises + ------ + AblyException + If message serial is missing or connection/channel state prevents the update + """ + return await self._send_update(message, MessageAction.MESSAGE_UPDATE, operation) + + async def delete_message(self, message: Message, operation: MessageOperation = None) -> UpdateDeleteResult: + """Deletes a message on this channel. + + Parameters + ---------- + message : Message + Message object to delete. Must have a serial field. + operation : MessageOperation, optional + Optional MessageOperation containing description and metadata for the delete. + + Returns + ------- + UpdateDeleteResult + Result containing the version serial of the deleted message. + + Raises + ------ + AblyException + If message serial is missing or connection/channel state prevents the delete + """ + return await self._send_update(message, MessageAction.MESSAGE_DELETE, operation) + + async def append_message(self, message: Message, operation: MessageOperation = None) -> UpdateDeleteResult: + """Appends data to an existing message on this channel. + + Parameters + ---------- + message : Message + Message object with data to append. Must have a serial field. + operation : MessageOperation, optional + Optional MessageOperation containing description and metadata for the append. + + Returns + ------- + UpdateDeleteResult + Result containing the version serial of the appended message. + + Raises + ------ + AblyException + If message serial is missing or connection/channel state prevents the append + """ + return await self._send_update(message, MessageAction.MESSAGE_APPEND, operation) + + async def get_message(self, serial_or_message, timeout=None): + """Retrieves a single message by its serial using the REST API. + + Parameters + ---------- + serial_or_message : str or Message + Either a string serial or a Message object with a serial field. + timeout : float, optional + Timeout for the request. + + Returns + ------- + Message + Message object for the requested serial. + + Raises + ------ + AblyException + If the serial is missing or the message cannot be retrieved. + """ + # Delegate to parent Channel (REST) implementation + return await Channel.get_message(self, serial_or_message, timeout=timeout) + + async def get_message_versions(self, serial_or_message, params=None): + """Retrieves version history for a message using the REST API. + + Parameters + ---------- + serial_or_message : str or Message + Either a string serial or a Message object with a serial field. + params : dict, optional + Optional dict of query parameters for pagination. + + Returns + ------- + PaginatedResult + PaginatedResult containing Message objects representing each version. + + Raises + ------ + AblyException + If the serial is missing or versions cannot be retrieved. + """ + # Delegate to parent Channel (REST) implementation + return await Channel.get_message_versions(self, serial_or_message, params=params) + def _on_message(self, proto_msg: dict) -> None: action = proto_msg.get('action') # RTL4c1 @@ -766,7 +961,7 @@ class Channels(RestChannels): """ # RTS3 - def get(self, name: str, options: ChannelOptions | None = None) -> RealtimeChannel: + def get(self, name: str, options: ChannelOptions | None = None, **kwargs) -> RealtimeChannel: """Creates a new RealtimeChannel object, or returns the existing channel object. Parameters @@ -776,7 +971,15 @@ def get(self, name: str, options: ChannelOptions | None = None) -> RealtimeChann Channel name options: ChannelOptions or dict, optional Channel options for the channel + **kwargs: + Additional keyword arguments to create ChannelOptions (e.g., cipher, params) """ + # Convert kwargs to ChannelOptions if provided + if kwargs and not options: + options = ChannelOptions(**kwargs) + elif options and isinstance(options, dict): + options = ChannelOptions.from_dict(options) + if name not in self.__all: channel = self.__all[name] = RealtimeChannel(self.__ably, name, options) else: diff --git a/ably/transport/defaults.py b/ably/transport/defaults.py index 7a732d9a..b6b1098a 100644 --- a/ably/transport/defaults.py +++ b/ably/transport/defaults.py @@ -1,5 +1,5 @@ class Defaults: - protocol_version = "2" + protocol_version = "5" fallback_hosts = [ "a.ably-realtime.com", "b.ably-realtime.com", diff --git a/ably/transport/websockettransport.py b/ably/transport/websockettransport.py index 325685b7..bdd8780f 100644 --- a/ably/transport/websockettransport.py +++ b/ably/transport/websockettransport.py @@ -12,6 +12,7 @@ from ably.http.httputils import HttpUtils from ably.types.connectiondetails import ConnectionDetails +from ably.types.operations import PublishResult from ably.util.eventemitter import EventEmitter from ably.util.exceptions import AblyException from ably.util.helper import Timer, unix_time_ms @@ -172,7 +173,10 @@ async def on_protocol_message(self, msg): # Handle acknowledgment of sent messages msg_serial = msg.get('msgSerial', 0) count = msg.get('count', 1) - self.connection_manager.on_ack(msg_serial, count) + res = msg.get('res') + if res is not None: + res = [PublishResult.from_dict(result) for result in res] + self.connection_manager.on_ack(msg_serial, count, res) elif action == ProtocolMessageAction.NACK: # Handle negative acknowledgment (error sending messages) msg_serial = msg.get('msgSerial', 0) diff --git a/test/ably/realtime/realtimechannelmutablemessages_test.py b/test/ably/realtime/realtimechannelmutablemessages_test.py new file mode 100644 index 00000000..370ac5fe --- /dev/null +++ b/test/ably/realtime/realtimechannelmutablemessages_test.py @@ -0,0 +1,289 @@ +import logging +from typing import List + +import pytest + +from ably import AblyException, CipherParams, MessageAction +from ably.types.message import Message +from ably.types.operations import MessageOperation +from test.ably.testapp import TestApp +from test.ably.utils import BaseAsyncTestCase, WaitableEvent, assert_waiter + +log = logging.getLogger(__name__) + + +@pytest.mark.parametrize("transport", ["json", "msgpack"], ids=["JSON", "MsgPack"]) +class TestRealtimeChannelMutableMessages(BaseAsyncTestCase): + + @pytest.fixture(autouse=True) + async def setup(self, transport): + self.test_vars = await TestApp.get_test_vars() + self.ably = await TestApp.get_ably_realtime( + use_binary_protocol=True if transport == 'msgpack' else False, + ) + + async def test_update_message_success(self): + """Test successfully updating a message""" + channel = self.ably.channels[self.get_channel_name('mutable:update_test')] + + # First publish a message + result = await channel.publish('test-event', 'original data') + assert result.serials is not None + assert len(result.serials) > 0 + serial = result.serials[0] + + # Create message with serial for update + message = Message( + data='updated data', + serial=serial, + ) + + # Update the message + update_result = await channel.update_message(message) + assert update_result is not None + updated_message = await self.wait_until_message_with_action_appears( + channel, serial, MessageAction.MESSAGE_UPDATE + ) + assert updated_message.data == 'updated data' + assert updated_message.version.serial == update_result.version_serial + assert updated_message.serial == serial + + async def test_update_message_without_serial_fails(self): + """Test that updating without a serial raises an exception""" + channel = self.ably.channels[self.get_channel_name('mutable:update_test_no_serial')] + + message = Message(name='test-event', data='data') + + with pytest.raises(AblyException) as exc_info: + await channel.update_message(message) + + assert exc_info.value.status_code == 400 + assert 'serial is required' in str(exc_info.value).lower() + + async def test_delete_message_success(self): + """Test successfully deleting a message""" + channel = self.ably.channels[self.get_channel_name('mutable:delete_test')] + + # First publish a message + result = await channel.publish('test-event', 'data to delete') + assert result.serials is not None + assert len(result.serials) > 0 + serial = result.serials[0] + + # Create message with serial for deletion + message = Message(serial=serial) + + operation = MessageOperation( + description='Inappropriate content', + metadata={'reason': 'moderation'} + ) + + # Delete the message + delete_result = await channel.delete_message(message, operation) + assert delete_result is not None + + # Verify the deletion propagated + deleted_message = await self.wait_until_message_with_action_appears( + channel, serial, MessageAction.MESSAGE_DELETE + ) + assert deleted_message.action == MessageAction.MESSAGE_DELETE + assert deleted_message.version.serial == delete_result.version_serial + assert deleted_message.version.description == 'Inappropriate content' + assert deleted_message.version.metadata == {'reason': 'moderation'} + assert deleted_message.serial == serial + + async def test_delete_message_without_serial_fails(self): + """Test that deleting without a serial raises an exception""" + channel = self.ably.channels[self.get_channel_name('mutable:delete_test_no_serial')] + + message = Message(name='test-event', data='data') + + with pytest.raises(AblyException) as exc_info: + await channel.delete_message(message) + + assert exc_info.value.status_code == 400 + assert 'serial is required' in str(exc_info.value).lower() + + async def test_append_message_success(self): + """Test successfully appending to a message""" + channel = self.ably.channels[self.get_channel_name('mutable:append_test')] + + # First publish a message + result = await channel.publish('test-event', 'original content') + assert result.serials is not None + assert len(result.serials) > 0 + serial = result.serials[0] + + # Create message with serial and data to append + message = Message( + data=' appended content', + serial=serial + ) + + operation = MessageOperation( + description='Added more info', + metadata={'type': 'amendment'} + ) + + # Append to the message + append_result = await channel.append_message(message, operation) + assert append_result is not None + + # Verify the append propagated - action will be MESSAGE_UPDATE, data should be concatenated + appended_message = await self.wait_until_message_with_action_appears( + channel, serial, MessageAction.MESSAGE_UPDATE + ) + assert appended_message.data == 'original content appended content' + assert appended_message.version.serial == append_result.version_serial + assert appended_message.version.description == 'Added more info' + assert appended_message.version.metadata == {'type': 'amendment'} + assert appended_message.serial == serial + + async def test_append_message_without_serial_fails(self): + """Test that appending without a serial raises an exception""" + channel = self.ably.channels[self.get_channel_name('mutable:append_test_no_serial')] + + message = Message(name='test-event', data='data to append') + + with pytest.raises(AblyException) as exc_info: + await channel.append_message(message) + + assert exc_info.value.status_code == 400 + assert 'serial is required' in str(exc_info.value).lower() + + async def test_update_message_with_encryption(self): + """Test updating an encrypted message""" + # Create channel with encryption + channel_name = self.get_channel_name('mutable:update_encrypted') + cipher_params = CipherParams(secret_key='keyfordecrypt_16', algorithm='aes') + channel = self.ably.channels.get(channel_name, cipher=cipher_params) + + # Publish encrypted message + result = await channel.publish('encrypted-event', 'secret data') + assert result.serials is not None + assert len(result.serials) > 0 + + # Update the encrypted message + message = Message( + name='encrypted-event', + data='updated secret data', + serial=result.serials[0] + ) + + operation = MessageOperation(description='Updated encrypted message') + update_result = await channel.update_message(message, operation) + assert update_result is not None + + async def test_publish_returns_serials(self): + """Test that publish returns PublishResult with serials""" + channel = self.ably.channels[self.get_channel_name('mutable:publish_serials')] + + # Publish multiple messages + messages = [ + Message('event1', 'data1'), + Message('event2', 'data2'), + Message('event3', 'data3') + ] + + result = await channel.publish(messages=messages) + assert result is not None + assert hasattr(result, 'serials') + assert len(result.serials) == 3 + + async def test_complete_workflow_publish_update_delete(self): + """Test complete workflow: publish, update, delete""" + channel = self.ably.channels[self.get_channel_name('mutable:complete_workflow')] + + # 1. Publish a message + result = await channel.publish('workflow_event', 'Initial data') + assert result.serials is not None + assert len(result.serials) > 0 + serial = result.serials[0] + + # 2. Update the message + update_message = Message( + name='workflow_event_updated', + data='Updated data', + serial=serial + ) + update_operation = MessageOperation(description='Updated message') + update_result = await channel.update_message(update_message, update_operation) + assert update_result is not None + + # 3. Delete the message + delete_message = Message(serial=serial, data='Deleted') + delete_operation = MessageOperation(description='Deleted message') + delete_result = await channel.delete_message(delete_message, delete_operation) + assert delete_result is not None + + versions = await self.wait_until_get_all_message_version(channel, serial, 3) + + assert versions[0].version.serial == serial + assert versions[1].version.serial == update_result.version_serial + assert versions[2].version.serial == delete_result.version_serial + + async def test_append_message_with_string_data(self): + """Test appending string data to a message""" + channel = self.ably.channels[self.get_channel_name('mutable:append_string')] + + # Publish initial message + result = await channel.publish('append_event', 'Initial data') + assert len(result.serials) > 0 + serial = result.serials[0] + + messages_received = [] + append_received = WaitableEvent() + + def on_message(message): + messages_received.append(message) + append_received.finish() + + await channel.subscribe(on_message) + + # Append data + append_message = Message( + data=' appended data', + serial=serial + ) + append_operation = MessageOperation(description='Appended to message') + append_result = await channel.append_message(append_message, append_operation) + assert append_result is not None + + # Verify the append + appended_message = await self.wait_until_message_with_action_appears( + channel, serial, MessageAction.MESSAGE_UPDATE + ) + + await append_received.wait() + + assert messages_received[0].data == ' appended data' + assert messages_received[0].action == MessageAction.MESSAGE_APPEND + assert appended_message.data == 'Initial data appended data' + assert appended_message.version.serial == append_result.version_serial + assert appended_message.version.description == 'Appended to message' + assert appended_message.serial == serial + + async def wait_until_message_with_action_appears(self, channel, serial, action): + message: Message | None = None + async def check_message_action(): + nonlocal message + try: + message = await channel.get_message(serial) + return message.action == action + except Exception: + return False + + await assert_waiter(check_message_action) + + return message + + async def wait_until_get_all_message_version(self, channel, serial, count): + versions: List[Message] = [] + async def check_message_versions(): + nonlocal versions + versions = (await channel.get_message_versions(serial)).items + return len(versions) >= count + + await assert_waiter(check_message_versions) + + return versions