Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions ably/realtime/connectionmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from collections import deque
from datetime import datetime
from itertools import zip_longest
from typing import TYPE_CHECKING

import httpx
Expand All @@ -13,6 +14,7 @@
from ably.types.connectiondetails import ConnectionDetails
from ably.types.connectionerrors import ConnectionErrors
from ably.types.connectionstate import ConnectionEvent, ConnectionState, ConnectionStateChange
from ably.types.operations import PublishResult
from ably.types.tokendetails import TokenDetails
from ably.util.eventemitter import EventEmitter
from ably.util.exceptions import AblyException, IncompatibleClientIdException
Expand All @@ -29,7 +31,7 @@ class PendingMessage:

def __init__(self, message: dict):
self.message = message
self.future: asyncio.Future | None = None
self.future: asyncio.Future[PublishResult] | None = None
action = message.get('action')

# Messages that require acknowledgment: MESSAGE, PRESENCE, ANNOTATION, OBJECT
Expand Down Expand Up @@ -58,15 +60,22 @@ def count(self) -> int:
"""Return the number of pending messages"""
return len(self.messages)

def complete_messages(self, serial: int, count: int, err: AblyException | None = None) -> None:
def complete_messages(
self,
serial: int,
count: int,
res: list[PublishResult] | None,
err: AblyException | None = None
) -> None:
"""Complete messages based on serial and count from ACK/NACK

Args:
serial: The msgSerial of the first message being acknowledged
count: The number of messages being acknowledged
res: List of PublishResult objects for each message acknowledged, or None if not available
err: Error from NACK, or None for successful ACK
"""
log.debug(f'MessageQueue.complete_messages(): serial={serial}, count={count}, err={err}')
log.debug(f'MessageQueue.complete_messages(): serial={serial}, count={count}, res={res}, err={err}')

if not self.messages:
log.warning('MessageQueue.complete_messages(): called on empty queue')
Expand All @@ -87,12 +96,17 @@ def complete_messages(self, serial: int, count: int, err: AblyException | None =
completed_messages = self.messages[:num_to_complete]
self.messages = self.messages[num_to_complete:]

for msg in completed_messages:
# Default res to empty list if None
res_list = res if res is not None else []
for (msg, publish_result) in zip_longest(completed_messages, res_list):
if msg.future and not msg.future.done():
if err:
msg.future.set_exception(err)
else:
msg.future.set_result(None)
# If publish_result is None, return empty PublishResult
if publish_result is None:
publish_result = PublishResult()
msg.future.set_result(publish_result)

def complete_all_messages(self, err: AblyException) -> None:
"""Complete all pending messages with an error"""
Expand Down Expand Up @@ -199,7 +213,7 @@ async def close_impl(self) -> None:

self.notify_state(ConnectionState.CLOSED)

async def send_protocol_message(self, protocol_message: dict) -> None:
async def send_protocol_message(self, protocol_message: dict) -> PublishResult | None:
"""Send a protocol message and optionally track it for acknowledgment

Args:
Expand Down Expand Up @@ -233,12 +247,14 @@ async def send_protocol_message(self, protocol_message: dict) -> None:
if state_should_queue:
self.queued_messages.appendleft(pending_message)
if pending_message.ack_required:
await pending_message.future
return await pending_message.future
return None

return await self._send_protocol_message_on_connected_state(pending_message)

async def _send_protocol_message_on_connected_state(self, pending_message: PendingMessage) -> None:
async def _send_protocol_message_on_connected_state(
self, pending_message: PendingMessage
) -> PublishResult | None:
if self.state == ConnectionState.CONNECTED and self.transport:
# Add to pending queue before sending (for messages being resent from queue)
if pending_message.ack_required and pending_message not in self.pending_message_queue.messages:
Expand All @@ -253,7 +269,7 @@ async def _send_protocol_message_on_connected_state(self, pending_message: Pendi
AblyException("No active transport", 500, 50000)
)
if pending_message.ack_required:
await pending_message.future
return await pending_message.future
return None

def send_queued_messages(self) -> None:
Expand Down Expand Up @@ -449,15 +465,18 @@ def on_heartbeat(self, id: str | None) -> None:
self.__ping_future.set_result(None)
self.__ping_future = None

def on_ack(self, serial: int, count: int) -> None:
def on_ack(
self, serial: int, count: int, res: list[PublishResult] | None
) -> None:
"""Handle ACK protocol message from server

Args:
serial: The msgSerial of the first message being acknowledged
count: The number of messages being acknowledged
res: List of PublishResult objects for each message acknowledged, or None if not available
"""
log.debug(f'ConnectionManager.on_ack(): serial={serial}, count={count}')
self.pending_message_queue.complete_messages(serial, count)
log.debug(f'ConnectionManager.on_ack(): serial={serial}, count={count}, res={res}')
self.pending_message_queue.complete_messages(serial, count, res)

def on_nack(self, serial: int, count: int, err: AblyException | None) -> None:
"""Handle NACK protocol message from server
Expand All @@ -471,7 +490,7 @@ def on_nack(self, serial: int, count: int, err: AblyException | None) -> None:
err = AblyException('Unable to send message; channel not responding', 50001, 500)

log.error(f'ConnectionManager.on_nack(): serial={serial}, count={count}, err={err}')
self.pending_message_queue.complete_messages(serial, count, err)
self.pending_message_queue.complete_messages(serial, count, None, err)

def deactivate_transport(self, reason: AblyException | None = None):
# RTN19a: Before disconnecting, requeue any pending messages
Expand Down
211 changes: 207 additions & 4 deletions ably/realtime/realtime_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from ably.transport.websockettransport import ProtocolMessageAction
from ably.types.channelstate import ChannelState, ChannelStateChange
from ably.types.flags import Flag, has_flag
from ably.types.message import Message
from ably.types.message import Message, MessageAction, MessageVersion
from ably.types.mixins import DecodingContext
from ably.types.operations import MessageOperation, PublishResult, UpdateDeleteResult
from ably.types.presence import PresenceMessage
from ably.util.eventemitter import EventEmitter
from ably.util.exceptions import AblyException, IncompatibleClientIdException
Expand Down Expand Up @@ -390,7 +391,7 @@ def unsubscribe(self, *args) -> None:
self.__message_emitter.off(listener)

# RTL6
async def publish(self, *args, **kwargs) -> None:
async def publish(self, *args, **kwargs) -> PublishResult:
"""Publish a message or messages on this channel

Publishes a single message or an array of messages to the channel.
Expand Down Expand Up @@ -490,7 +491,7 @@ async def publish(self, *args, **kwargs) -> None:
}

# RTL6b: Await acknowledgment from server
await self.__realtime.connection.connection_manager.send_protocol_message(protocol_message)
return await self.__realtime.connection.connection_manager.send_protocol_message(protocol_message)

def _throw_if_unpublishable_state(self) -> None:
"""Check if the channel and connection are in a state that allows publishing
Expand Down Expand Up @@ -522,6 +523,200 @@ def _throw_if_unpublishable_state(self) -> None:
90001,
)

async def _send_update(self, message: Message, action: MessageAction,
operation: MessageOperation = None) -> UpdateDeleteResult:
"""Internal method to send update/delete/append operations via websocket.

Parameters
----------
message : Message
Message object with serial field required
action : MessageAction
The action type (MESSAGE_UPDATE, MESSAGE_DELETE, MESSAGE_APPEND)
operation : MessageOperation, optional
Operation metadata (description, metadata)

Returns
-------
UpdateDeleteResult
Result containing version serial of the operation

Raises
------
AblyException
If message serial is missing or connection/channel state prevents operation
"""
# Check message has serial
if not message.serial:
raise AblyException(
"Message serial is required for update/delete/append operations",
400,
40000
)

# Check connection and channel state
self._throw_if_unpublishable_state()

# Create version from operation if provided
if not operation:
version = None
else:
version = MessageVersion(
client_id=operation.client_id,
description=operation.description,
metadata=operation.metadata
)

# Create a new message with the operation fields
update_message = Message(
name=message.name,
data=message.data,
client_id=message.client_id,
serial=message.serial,
action=action,
version=version,
)

# Encrypt if needed
if self.cipher:
update_message.encrypt(self.cipher)

# Convert to dict representation
msg_dict = update_message.as_dict(binary=self.ably.options.use_binary_protocol)

log.info(
f'RealtimeChannel._send_update(): sending {action.name} message; '
f'channel = {self.name}, state = {self.state}, serial = {message.serial}'
)

# Send protocol message
protocol_message = {
"action": ProtocolMessageAction.MESSAGE,
"channel": self.name,
"messages": [msg_dict],
}

# Send and await acknowledgment
result = await self.__realtime.connection.connection_manager.send_protocol_message(protocol_message)

# Return UpdateDeleteResult - we don't have version_serial from the result yet
# The server will send ACK with the result
if result and hasattr(result, 'serials') and result.serials:
return UpdateDeleteResult(version_serial=result.serials[0])
return UpdateDeleteResult()

async def update_message(self, message: Message, operation: MessageOperation = None) -> UpdateDeleteResult:
"""Updates an existing message on this channel.

Parameters
----------
message : Message
Message object to update. Must have a serial field.
operation : MessageOperation, optional
Optional MessageOperation containing description and metadata for the update.

Returns
-------
UpdateDeleteResult
Result containing the version serial of the updated message.

Raises
------
AblyException
If message serial is missing or connection/channel state prevents the update
"""
return await self._send_update(message, MessageAction.MESSAGE_UPDATE, operation)

async def delete_message(self, message: Message, operation: MessageOperation = None) -> UpdateDeleteResult:
"""Deletes a message on this channel.

Parameters
----------
message : Message
Message object to delete. Must have a serial field.
operation : MessageOperation, optional
Optional MessageOperation containing description and metadata for the delete.

Returns
-------
UpdateDeleteResult
Result containing the version serial of the deleted message.

Raises
------
AblyException
If message serial is missing or connection/channel state prevents the delete
"""
return await self._send_update(message, MessageAction.MESSAGE_DELETE, operation)

async def append_message(self, message: Message, operation: MessageOperation = None) -> UpdateDeleteResult:
"""Appends data to an existing message on this channel.

Parameters
----------
message : Message
Message object with data to append. Must have a serial field.
operation : MessageOperation, optional
Optional MessageOperation containing description and metadata for the append.

Returns
-------
UpdateDeleteResult
Result containing the version serial of the appended message.

Raises
------
AblyException
If message serial is missing or connection/channel state prevents the append
"""
return await self._send_update(message, MessageAction.MESSAGE_APPEND, operation)

async def get_message(self, serial_or_message, timeout=None):
"""Retrieves a single message by its serial using the REST API.

Parameters
----------
serial_or_message : str or Message
Either a string serial or a Message object with a serial field.
timeout : float, optional
Timeout for the request.

Returns
-------
Message
Message object for the requested serial.

Raises
------
AblyException
If the serial is missing or the message cannot be retrieved.
"""
# Delegate to parent Channel (REST) implementation
return await Channel.get_message(self, serial_or_message, timeout=timeout)

async def get_message_versions(self, serial_or_message, params=None):
"""Retrieves version history for a message using the REST API.

Parameters
----------
serial_or_message : str or Message
Either a string serial or a Message object with a serial field.
params : dict, optional
Optional dict of query parameters for pagination.

Returns
-------
PaginatedResult
PaginatedResult containing Message objects representing each version.

Raises
------
AblyException
If the serial is missing or versions cannot be retrieved.
"""
# Delegate to parent Channel (REST) implementation
return await Channel.get_message_versions(self, serial_or_message, params=params)

def _on_message(self, proto_msg: dict) -> None:
action = proto_msg.get('action')
# RTL4c1
Expand Down Expand Up @@ -766,7 +961,7 @@ class Channels(RestChannels):
"""

# RTS3
def get(self, name: str, options: ChannelOptions | None = None) -> RealtimeChannel:
def get(self, name: str, options: ChannelOptions | None = None, **kwargs) -> RealtimeChannel:
"""Creates a new RealtimeChannel object, or returns the existing channel object.

Parameters
Expand All @@ -776,7 +971,15 @@ def get(self, name: str, options: ChannelOptions | None = None) -> RealtimeChann
Channel name
options: ChannelOptions or dict, optional
Channel options for the channel
**kwargs:
Additional keyword arguments to create ChannelOptions (e.g., cipher, params)
"""
# Convert kwargs to ChannelOptions if provided
if kwargs and not options:
options = ChannelOptions(**kwargs)
elif options and isinstance(options, dict):
options = ChannelOptions.from_dict(options)

if name not in self.__all:
channel = self.__all[name] = RealtimeChannel(self.__ably, name, options)
else:
Expand Down
2 changes: 1 addition & 1 deletion ably/transport/defaults.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
class Defaults:
protocol_version = "2"
protocol_version = "5"
fallback_hosts = [
"a.ably-realtime.com",
"b.ably-realtime.com",
Expand Down
Loading
Loading