From ee494a6543c8084cd2d64a934e8c162a634a14d0 Mon Sep 17 00:00:00 2001 From: uael Date: Tue, 7 Nov 2023 00:42:59 -0800 Subject: [PATCH 1/2] l2cap: refactor server side to allow deferred accept In order to avoid any breaking changes this re-impl current APIs with the exact same behavior. The previous impl was preventing one to defer the response to an l2cap channel connection request, both for BR/EDR basic channels and LE credit based ones. This commit change this to spawn a task on every channel incoming connection request, then all registered listeners are given a chance to accept it through a `asyncio.Future`. After a bit of delay, if none had accepted it, the connection is automatically rejected. --- bumble/l2cap.py | 570 +++++++++++++++++++++++++++++++----------------- 1 file changed, 375 insertions(+), 195 deletions(-) diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 7a2f0ede..f245b3ea 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -35,8 +35,10 @@ Union, Deque, Iterable, + Set, SupportsBytes, TYPE_CHECKING, + overload, ) from .utils import deprecated @@ -237,6 +239,8 @@ class L2CAP_Control_Frame: classes: Dict[int, Type[L2CAP_Control_Frame]] = {} code = 0 name: str + identifier: int + pdu: bytes @staticmethod def from_bytes(pdu: bytes) -> L2CAP_Control_Frame: @@ -391,6 +395,9 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame): See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST ''' + psm: int + source_cid: int + @staticmethod def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]: psm_length = 2 @@ -637,7 +644,11 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame): (CODE 0x14) ''' + le_psm: int source_cid: int + mtu: int + mps: int + initial_credits: int # ----------------------------------------------------------------------------- @@ -1375,19 +1386,14 @@ def __str__(self) -> str: # ----------------------------------------------------------------------------- +@dataclasses.dataclass class ClassicChannelServer(EventEmitter): - def __init__( - self, - manager: ChannelManager, - psm: int, - handler: Optional[Callable[[ClassicChannel], Any]], - mtu: int, - ) -> None: + _close_closure: Callable[[], None] + psm: int + handler: Optional[Callable[[ClassicChannel], Any]] + + def __post_init__(self) -> None: super().__init__() - self.manager = manager - self.handler = handler - self.psm = psm - self.mtu = mtu def on_connection(self, channel: ClassicChannel) -> None: self.emit('connection', channel) @@ -1395,28 +1401,18 @@ def on_connection(self, channel: ClassicChannel) -> None: self.handler(channel) def close(self) -> None: - if self.psm in self.manager.servers: - del self.manager.servers[self.psm] + self._close_closure() # ----------------------------------------------------------------------------- +@dataclasses.dataclass class LeCreditBasedChannelServer(EventEmitter): - def __init__( - self, - manager: ChannelManager, - psm: int, - handler: Optional[Callable[[LeCreditBasedChannel], Any]], - max_credits: int, - mtu: int, - mps: int, - ) -> None: + _close_closure: Callable[[], None] + psm: int + handler: Optional[Callable[[LeCreditBasedChannel], Any]] + + def __post_init__(self) -> None: super().__init__() - self.manager = manager - self.handler = handler - self.psm = psm - self.max_credits = max_credits - self.mtu = mtu - self.mps = mps def on_connection(self, channel: LeCreditBasedChannel) -> None: self.emit('connection', channel) @@ -1424,21 +1420,107 @@ def on_connection(self, channel: LeCreditBasedChannel) -> None: self.handler(channel) def close(self) -> None: - if self.psm in self.manager.le_coc_servers: - del self.manager.le_coc_servers[self.psm] + self._close_closure() + + +# ----------------------------------------------------------------------------- +class PendingConnection: + """ + All pending connection types. + A `PendingConnection` is a temporary object used to accept an incoming connection + request, it contains the acceptor channel configuration preferences and transition + to the connected state through the `on_connection` callback. + This object is not supposed to live anymore once the channel is connected. + """ + + class Any: + """L2CAP any channel pending connection.""" + + on_connection: Callable[[Any], None] + mtu: int + + @dataclasses.dataclass + class Basic(Any): + """L2CAP basic channel pending connection.""" + + on_connection: Callable[[ClassicChannel], None] = lambda _: None + mtu: int = L2CAP_MIN_BR_EDR_MTU + + @dataclasses.dataclass + class LeCreditBased(Any): + """L2CAP LE credit based channel pending connection.""" + + on_connection: Callable[[LeCreditBasedChannel], None] = lambda _: None + mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU + mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS + max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS + + +# ----------------------------------------------------------------------------- +class IncomingConnection: + """ + All incoming connection types. + A `IncomingConnection` is a temporary object used to notify listeners of an + incoming channel connection request. It can accepted through the `future` field. + Multiple listeners can observe the same incoming connection request, but no more + than one can actually accept, first come first served. Thus it's recommended for + delayed accept to before check the state of the future field. + This object is not supposed to live anymore once accepted. + + Example: + ```python + fut = asyncio.Future() + + def listener(incoming: IncomingConnection.Any) -> None: + if isinstance(incoming, IncomingConnection.Basic) and incoming.psm == 0xcafe: + incoming.future.set_result(PendingConnection.Basic(fut.set_result, mtu=123)) + + device.l2cap_manager.listen(listener) + channel = await fut + ``` + """ + + @dataclasses.dataclass + class Any: + """L2CAP any incoming channel connection request.""" + + connection: Connection + psm: int + source_cid: int + + def __post_init__(self) -> None: + self.future: asyncio.Future[Any] = asyncio.Future() + + @dataclasses.dataclass + class Basic(Any): + """L2CAP incoming basic channel connection request.""" + + future: asyncio.Future[PendingConnection.Basic] = dataclasses.field(init=False) + + @dataclasses.dataclass + class LeCreditBased(Any): + """L2CAP incoming LE credit based channel connection request.""" + + mtu: int + mps: int + initial_credits: int + + future: asyncio.Future[PendingConnection.LeCreditBased] = dataclasses.field( + init=False + ) # ----------------------------------------------------------------------------- class ChannelManager: identifiers: Dict[int, int] channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]] - servers: Dict[int, ClassicChannelServer] le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]] - le_coc_servers: Dict[int, LeCreditBasedChannelServer] le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request] fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]] _host: Optional[Host] connection_parameters_update_response: Optional[asyncio.Future[int]] + listeners: List[Callable[[IncomingConnection.Any], None]] + used_psm: Set[int] def __init__( self, @@ -1452,15 +1534,15 @@ def __init__( L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None, } - self.servers = {} # Servers accepting connections, by PSM self.le_coc_channels = ( {} ) # LE CoC channels, mapped by connection and destination cid - self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM self.le_coc_requests = {} # LE CoC connection requests, by identifier self.extended_features = extended_features self.connectionless_mtu = connectionless_mtu self.connection_parameters_update_response = None + self.listeners = [] + self.used_psm = set() @property def host(self) -> Host: @@ -1513,6 +1595,31 @@ def find_free_le_cid(channels: Iterable[int]) -> int: raise RuntimeError('no free CID') + def allocate_psm(self) -> int: + # Find a free PSM + for candidate in range( + L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2 + ): + if (candidate >> 8) % 2 == 1: + continue + if candidate in self.used_psm: + continue + return candidate + raise InvalidStateError('no free PSM') + + def allocate_spsm(self) -> int: + # Find a free sPSM + for candidate in range( + L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1 + ): + if candidate in self.used_psm: + continue + return candidate + raise InvalidStateError('no free PSM') + + def free_psm(self, psm: int) -> None: + self.used_psm.remove(psm) + def next_identifier(self, connection: Connection) -> int: identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 self.identifiers[connection.handle] = identifier @@ -1527,6 +1634,35 @@ def deregister_fixed_channel(self, cid: int) -> None: if cid in self.fixed_channels: del self.fixed_channels[cid] + @overload + def listen( + self, cb: Callable[[IncomingConnection.Basic], None] + ) -> Callable[[IncomingConnection.Basic], None]: + ... + + @overload + def listen( + self, cb: Callable[[IncomingConnection.LeCreditBased], None] + ) -> Callable[[IncomingConnection.LeCreditBased], None]: + ... + + def listen(self, cb: Any) -> Any: + if cb in self.listeners: + raise ValueError('listener already registered') + self.listeners.append(cb) + return cb + + @overload + def unlisten(self, cb: Callable[[IncomingConnection.Basic], None]) -> None: + ... + + @overload + def unlisten(self, cb: Callable[[IncomingConnection.LeCreditBased], None]) -> None: + ... + + def unlisten(self, cb: Any) -> None: + self.listeners.remove(cb) + @deprecated("Please use create_classic_server") def register_server( self, @@ -1534,7 +1670,7 @@ def register_server( server: Callable[[ClassicChannel], Any], ) -> int: return self.create_classic_server( - handler=server, spec=ClassicChannelSpec(psm=psm) + handler=server, spec=ClassicChannelSpec(psm=None if psm == 0 else psm) ).psm def create_classic_server( @@ -1542,24 +1678,12 @@ def create_classic_server( spec: ClassicChannelSpec, handler: Optional[Callable[[ClassicChannel], Any]] = None, ) -> ClassicChannelServer: - if not spec.psm: - # Find a free PSM - for candidate in range( - L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2 - ): - if (candidate >> 8) % 2 == 1: - continue - if candidate in self.servers: - continue - spec.psm = candidate - break - else: - raise InvalidStateError('no free PSM') + server: ClassicChannelServer + if spec.psm is None: + spec.psm = self.allocate_psm() else: - # Check that the PSM isn't already in use - if spec.psm in self.servers: - raise ValueError('PSM already in use') - + if spec.psm is self.used_psm: + raise ValueError(f'{spec.psm}: PSM already in use') # Check that the PSM is valid if spec.psm % 2 == 0: raise ValueError('invalid PSM (not odd)') @@ -1568,10 +1692,22 @@ def create_classic_server( if check % 2 != 0: raise ValueError('invalid PSM') check >>= 8 + self.used_psm.add(spec.psm) + + def listener(incoming: IncomingConnection.Basic) -> None: + if incoming.psm == spec.psm: + incoming.future.set_result( + PendingConnection.Basic(server.on_connection, spec.mtu) + ) - self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu) + def close() -> None: + self.unlisten(listener) + assert spec.psm is not None + self.free_psm(spec.psm) - return self.servers[spec.psm] + self.listen(listener) + server = ClassicChannelServer(close, spec.psm, handler) + return server @deprecated("Please use create_le_credit_based_server()") def register_le_coc_server( @@ -1594,32 +1730,30 @@ def create_le_credit_based_server( spec: LeCreditBasedChannelSpec, handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None, ) -> LeCreditBasedChannelServer: - if not spec.psm: - # Find a free PSM - for candidate in range( - L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1 - ): - if candidate in self.le_coc_servers: - continue - spec.psm = candidate - break - else: - raise InvalidStateError('no free PSM') + server: LeCreditBasedChannelServer + if spec.psm is None: + spec.psm = self.allocate_psm() else: - # Check that the PSM isn't already in use - if spec.psm in self.le_coc_servers: - raise ValueError('PSM already in use') + if spec.psm is self.used_psm: + raise ValueError(f'{spec.psm}: SPSM already in use') + self.used_psm.add(spec.psm) + + def listener(incoming: IncomingConnection.LeCreditBased) -> None: + if incoming.psm == spec.psm: + incoming.future.set_result( + PendingConnection.LeCreditBased( + server.on_connection, spec.mtu, spec.mps, spec.max_credits + ) + ) - self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer( - self, - spec.psm, - handler, - max_credits=spec.max_credits, - mtu=spec.mtu, - mps=spec.mps, - ) + def close() -> None: + self.unlisten(listener) + assert spec.psm is not None + self.free_psm(spec.psm) - return self.le_coc_servers[spec.psm] + self.listen(listener) + server = LeCreditBasedChannelServer(close, spec.psm, handler) + return server def on_disconnection(self, connection_handle: int, _reason: int) -> None: logger.debug(f'disconnection from {connection_handle}, cleaning up channels') @@ -1719,15 +1853,62 @@ def on_l2cap_command_reject( logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}') def on_l2cap_connection_request( - self, connection: Connection, cid: int, request + self, connection: Connection, cid: int, request: L2CAP_Connection_Request ) -> None: - # Check if there's a server for this PSM - server = self.servers.get(request.psm) - if server: - # Find a free CID for this new channel - connection_channels = self.channels.setdefault(connection.handle, {}) - source_cid = self.find_free_br_edr_cid(connection_channels) - if source_cid is None: # Should never happen! + + # Asynchronous connection request handling. + async def handle_connection_request() -> None: + incoming = IncomingConnection.Basic( + connection, request.psm, request.source_cid + ) + + # Dispatch incoming connection. + for listener in self.listeners: + if not incoming.future.done(): + listener(incoming) + + try: + pending = await asyncio.wait_for(incoming.future, timeout=3.0) + except asyncio.TimeoutError as e: + incoming.future.cancel(e) + pending = None + + if pending: + # Find a free CID for this new channel + connection_channels = self.channels.setdefault(connection.handle, {}) + source_cid = self.find_free_br_edr_cid(connection_channels) + if source_cid is None: # Should never happen! + self.send_control_frame( + connection, + cid, + L2CAP_Connection_Response( + identifier=request.identifier, + destination_cid=request.source_cid, + source_cid=0, + # pylint: disable=line-too-long + result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + status=0x0000, + ), + ) + return + + # Create a new channel + logger.debug( + f'creating server channel with cid={source_cid} for psm {request.psm}' + ) + channel = ClassicChannel( + self, connection, cid, request.psm, source_cid, pending.mtu + ) + connection_channels[source_cid] = channel + + # Notify + pending.on_connection(channel) + channel.on_connection_request(request) + else: + logger.warning( + f'No server for connection 0x{connection.handle:04X} ' + f'on PSM {request.psm}' + ) self.send_control_frame( connection, cid, @@ -1736,41 +1917,13 @@ def on_l2cap_connection_request( destination_cid=request.source_cid, source_cid=0, # pylint: disable=line-too-long - result=L2CAP_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED, status=0x0000, ), ) - return - # Create a new channel - logger.debug( - f'creating server channel with cid={source_cid} for psm {request.psm}' - ) - channel = ClassicChannel( - self, connection, cid, request.psm, source_cid, server.mtu - ) - connection_channels[source_cid] = channel - - # Notify - server.on_connection(channel) - channel.on_connection_request(request) - else: - logger.warning( - f'No server for connection 0x{connection.handle:04X} ' - f'on PSM {request.psm}' - ) - self.send_control_frame( - connection, - cid, - L2CAP_Connection_Response( - identifier=request.identifier, - destination_cid=request.source_cid, - source_cid=0, - # pylint: disable=line-too-long - result=L2CAP_Connection_Response.CONNECTION_REFUSED_PSM_NOT_SUPPORTED, - status=0x0000, - ), - ) + # Spawn connection request handling. + connection.abort_on('disconnection', handle_connection_request()) def on_l2cap_connection_response( self, connection: Connection, cid: int, response @@ -1971,108 +2124,135 @@ def on_l2cap_connection_parameter_update_response( ) def on_l2cap_le_credit_based_connection_request( - self, connection: Connection, cid: int, request + self, + connection: Connection, + cid: int, + request: L2CAP_LE_Credit_Based_Connection_Request, ) -> None: - if request.le_psm in self.le_coc_servers: - server = self.le_coc_servers[request.le_psm] - # Check that the CID isn't already used - le_connection_channels = self.le_coc_channels.setdefault( - connection.handle, {} + # Asynchronous connection request handling. + async def handle_connection_request() -> None: + incoming = IncomingConnection.LeCreditBased( + connection, + request.le_psm, + request.source_cid, + request.mtu, + request.mps, + request.initial_credits, ) - if request.source_cid in le_connection_channels: - logger.warning(f'source CID {request.source_cid} already in use') + + # Dispatch incoming connection. + for listener in self.listeners: + if not incoming.future.done(): + listener(incoming) + + try: + pending = await asyncio.wait_for(incoming.future, timeout=3.0) + except asyncio.TimeoutError as e: + incoming.future.cancel(e) + pending = None + + if pending: + # Check that the CID isn't already used + le_connection_channels = self.le_coc_channels.setdefault( + connection.handle, {} + ) + if request.source_cid in le_connection_channels: + logger.warning(f'source CID {request.source_cid} already in use') + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=0, + mtu=pending.mtu, + mps=pending.mps, + initial_credits=0, + # pylint: disable=line-too-long + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, + ), + ) + return + + # Find a free CID for this new channel + connection_channels = self.channels.setdefault(connection.handle, {}) + source_cid = self.find_free_le_cid(connection_channels) + if source_cid is None: # Should never happen! + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier=request.identifier, + destination_cid=0, + mtu=pending.mtu, + mps=pending.mps, + initial_credits=0, + # pylint: disable=line-too-long + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + ), + ) + return + + # Create a new channel + logger.debug( + f'creating LE CoC server channel with cid={source_cid} for psm ' + f'{request.le_psm}' + ) + channel = LeCreditBasedChannel( + self, + connection, + request.le_psm, + source_cid, + request.source_cid, + pending.mtu, + pending.mps, + request.initial_credits, + request.mtu, + request.mps, + pending.max_credits, + True, + ) + connection_channels[source_cid] = channel + le_connection_channels[request.source_cid] = channel + + # Respond self.send_control_frame( connection, cid, L2CAP_LE_Credit_Based_Connection_Response( identifier=request.identifier, - destination_cid=0, - mtu=server.mtu, - mps=server.mps, - initial_credits=0, + destination_cid=source_cid, + mtu=pending.mtu, + mps=pending.mps, + initial_credits=pending.max_credits, # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED, + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL, ), ) - return - # Find a free CID for this new channel - connection_channels = self.channels.setdefault(connection.handle, {}) - source_cid = self.find_free_le_cid(connection_channels) - if source_cid is None: # Should never happen! + # Notify + pending.on_connection(channel) + else: + logger.info( + f'No LE server for connection 0x{connection.handle:04X} ' + f'on PSM {request.le_psm}' + ) self.send_control_frame( connection, cid, L2CAP_LE_Credit_Based_Connection_Response( identifier=request.identifier, destination_cid=0, - mtu=server.mtu, - mps=server.mps, + mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, initial_credits=0, # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, ), ) - return - - # Create a new channel - logger.debug( - f'creating LE CoC server channel with cid={source_cid} for psm ' - f'{request.le_psm}' - ) - channel = LeCreditBasedChannel( - self, - connection, - request.le_psm, - source_cid, - request.source_cid, - server.mtu, - server.mps, - request.initial_credits, - request.mtu, - request.mps, - server.max_credits, - True, - ) - connection_channels[source_cid] = channel - le_connection_channels[request.source_cid] = channel - - # Respond - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier=request.identifier, - destination_cid=source_cid, - mtu=server.mtu, - mps=server.mps, - initial_credits=server.max_credits, - # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL, - ), - ) - # Notify - server.on_connection(channel) - else: - logger.info( - f'No LE server for connection 0x{connection.handle:04X} ' - f'on PSM {request.le_psm}' - ) - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier=request.identifier, - destination_cid=0, - mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, - initial_credits=0, - # pylint: disable=line-too-long - result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, - ), - ) + # Spawn connection request handling. + connection.abort_on('disconnection', handle_connection_request()) def on_l2cap_le_credit_based_connection_response( self, connection: Connection, _cid: int, response From 412fd0f78a663a39143757995e3137d5a68edbdc Mon Sep 17 00:00:00 2001 From: uael Date: Tue, 7 Nov 2023 00:53:56 -0800 Subject: [PATCH 2/2] pandora: implement L2CAP pandora service Co-authored-by: Josh Wu --- bumble/pandora/__init__.py | 3 + bumble/pandora/l2cap.py | 289 +++++++++++++++++++++++++++++++++++++ setup.cfg | 2 +- 3 files changed, 293 insertions(+), 1 deletion(-) create mode 100644 bumble/pandora/l2cap.py diff --git a/bumble/pandora/__init__.py b/bumble/pandora/__init__.py index e02f54a9..cdf9da94 100644 --- a/bumble/pandora/__init__.py +++ b/bumble/pandora/__init__.py @@ -26,11 +26,13 @@ from .device import PandoraDevice from .host import HostService from .security import SecurityService, SecurityStorageService +from .l2cap import L2CAPService from pandora.host_grpc_aio import add_HostServicer_to_server from pandora.security_grpc_aio import ( add_SecurityServicer_to_server, add_SecurityStorageServicer_to_server, ) +from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server from typing import Callable, List, Optional # public symbols @@ -77,6 +79,7 @@ async def serve( add_SecurityStorageServicer_to_server( SecurityStorageService(bumble.device, config), server ) + add_L2CAPServicer_to_server(L2CAPService(bumble.device, config), server) # call hooks if any. for hook in _SERVICERS_HOOKS: diff --git a/bumble/pandora/l2cap.py b/bumble/pandora/l2cap.py new file mode 100644 index 00000000..47cb7a83 --- /dev/null +++ b/bumble/pandora/l2cap.py @@ -0,0 +1,289 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import dataclasses +import grpc +import struct + +from bumble import device +from bumble import l2cap +from bumble.pandora import config +from bumble.pandora import utils +from bumble.utils import EventWatcher +from google.protobuf import any_pb2 # pytype: disable=pyi-error +from google.protobuf import empty_pb2 # pytype: disable=pyi-error +from pandora import l2cap_pb2 +from pandora import l2cap_grpc_aio +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Union + + +@dataclasses.dataclass +class ChannelProxy: + channel: Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel, None] + + def __post_init__(self) -> None: + assert self.channel + self.rx: asyncio.Queue[bytes] = asyncio.Queue() + self._disconnection_result: asyncio.Future[None] = asyncio.Future() + self.channel.sink = self.rx.put_nowait + + def on_close() -> None: + assert not self._disconnection_result.done() + self.channel = None + self._disconnection_result.set_result(None) + + self.channel.on('close', on_close) + + def send(self, data: bytes) -> None: + assert self.channel + if isinstance(self.channel, l2cap.ClassicChannel): + self.channel.send_pdu(data) + else: + self.channel.write(data) + + async def disconnect(self) -> None: + assert self.channel + await self.channel.disconnect() + + async def wait_disconnect(self) -> None: + await self._disconnection_result + assert not self.channel + + +@dataclasses.dataclass +class ChannelIndex: + connection_handle: int + cid: int + + @classmethod + def from_token(cls, token: l2cap_pb2.Channel) -> 'ChannelIndex': + connection_handle, cid = struct.unpack('>HH', token.cookie.value) + return cls(connection_handle, cid) + + def into_token(self) -> l2cap_pb2.Channel: + return l2cap_pb2.Channel( + cookie=any_pb2.Any( + value=struct.pack('>HH', self.connection_handle, self.cid) + ) + ) + + def __hash__(self): + return hash(self.connection_handle | (self.cid << 12)) + + +class L2CAPService(l2cap_grpc_aio.L2CAPServicer): + channels: Dict[ChannelIndex, ChannelProxy] = {} + pending: List[l2cap.IncomingConnection.Any] = [] + accepts: List[asyncio.Queue[l2cap.IncomingConnection.Any]] = [] + + def __init__(self, dev: device.Device, config: config.Config) -> None: + self.device = dev + self.config = config + + def on_connection(incoming: l2cap.IncomingConnection.Any) -> None: + self.pending.append(incoming) + for acceptor in self.accepts: + acceptor.put_nowait(incoming) + + # Make sure our listener is called before the builtins ones. + self.device.l2cap_channel_manager.listeners.insert(0, on_connection) + + def register(self, index: ChannelIndex, proxy: ChannelProxy) -> None: + self.channels[index] = proxy + + def on_close(*_: Any) -> None: + # TODO: Fix Bumble L2CAP which emit `close` event twice. + if index in self.channels: + del self.channels[index] + + # Listen for disconnection. + assert proxy.channel + proxy.channel.on('close', on_close) + + async def listen(self) -> AsyncIterator[l2cap.IncomingConnection.Any]: + for incoming in self.pending: + if incoming.future.done(): + self.pending.remove(incoming) + continue + yield incoming + queue: asyncio.Queue[l2cap.IncomingConnection.Any] = asyncio.Queue() + self.accepts.append(queue) + try: + while incoming := await queue.get(): + yield incoming + finally: + self.accepts.remove(queue) + + @utils.rpc + async def Connect( + self, request: l2cap_pb2.ConnectRequest, context: grpc.ServicerContext + ) -> l2cap_pb2.ConnectResponse: + # Retrieve Bumble `Connection` from request. + connection_handle = int.from_bytes(request.connection.cookie.value, 'big') + connection = self.device.lookup_connection(connection_handle) + if connection is None: + raise RuntimeError(f'{connection_handle}: not connection for handle') + + channel: Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel] + if request.type_variant() == 'basic': + assert request.basic + channel = await connection.create_l2cap_channel( + spec=l2cap.ClassicChannelSpec( + psm=request.basic.psm, mtu=request.basic.mtu + ) + ) + elif request.type_variant() == 'le_credit_based': + assert request.le_credit_based + channel = await connection.create_l2cap_channel( + spec=l2cap.LeCreditBasedChannelSpec( + psm=request.le_credit_based.spsm, + max_credits=request.le_credit_based.initial_credit, + mtu=request.le_credit_based.mtu, + mps=request.le_credit_based.mps, + ) + ) + else: + raise NotImplementedError(f"{request.type_variant()}: unsupported type") + + index = ChannelIndex(channel.connection.handle, channel.source_cid) + self.register(index, ChannelProxy(channel)) + return l2cap_pb2.ConnectResponse(channel=index.into_token()) + + @utils.rpc + async def WaitConnection( + self, request: l2cap_pb2.WaitConnectionRequest, context: grpc.ServicerContext + ) -> l2cap_pb2.WaitConnectionResponse: + iter = self.listen() + fut: asyncio.Future[ + Union[l2cap.ClassicChannel, l2cap.LeCreditBasedChannel] + ] = asyncio.Future() + + # Filter by connection. + if request.connection: + handle = int.from_bytes(request.connection.cookie.value, 'big') + iter = (it async for it in iter if it.connection.handle == handle) + + if request.type_variant() == 'basic': + assert request.basic + basic = l2cap.PendingConnection.Basic( + fut.set_result, + request.basic.mtu or l2cap.L2CAP_MIN_BR_EDR_MTU, + ) + async for i in ( + it + async for it in iter + if isinstance(it, l2cap.IncomingConnection.Basic) + ): + if not i.future.done() and i.psm == request.basic.psm: + i.future.set_result(basic) + break + elif request.type_variant() == 'le_credit_based': + assert request.le_credit_based + le_credit_based = l2cap.PendingConnection.LeCreditBased( + fut.set_result, + request.le_credit_based.mtu + or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + request.le_credit_based.mps + or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, + request.le_credit_based.initial_credit + or l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, + ) + async for j in ( + it + async for it in iter + if isinstance(it, l2cap.IncomingConnection.LeCreditBased) + ): + if not j.future.done() and j.psm == request.le_credit_based.spsm: + j.future.set_result(le_credit_based) + break + else: + raise NotImplementedError(f"{request.type_variant()}: unsupported type") + + channel = await fut + index = ChannelIndex(channel.connection.handle, channel.source_cid) + self.register(index, ChannelProxy(channel)) + return l2cap_pb2.WaitConnectionResponse(channel=index.into_token()) + + @utils.rpc + async def Disconnect( + self, request: l2cap_pb2.DisconnectRequest, context: grpc.ServicerContext + ) -> l2cap_pb2.DisconnectResponse: + channel = self.channels[ChannelIndex.from_token(request.channel)] + await channel.disconnect() + return l2cap_pb2.DisconnectResponse(success=empty_pb2.Empty()) + + @utils.rpc + async def WaitDisconnection( + self, request: l2cap_pb2.WaitDisconnectionRequest, context: grpc.ServicerContext + ) -> l2cap_pb2.WaitDisconnectionResponse: + channel = self.channels[ChannelIndex.from_token(request.channel)] + await channel.wait_disconnect() + return l2cap_pb2.WaitDisconnectionResponse(success=empty_pb2.Empty()) + + @utils.rpc + async def Receive( + self, request: l2cap_pb2.ReceiveRequest, context: grpc.ServicerContext + ) -> AsyncGenerator[l2cap_pb2.ReceiveResponse, None]: + watcher = EventWatcher() + if request.source_variant() == 'channel': + assert request.channel + channel = self.channels[ChannelIndex.from_token(request.channel)] + rx = channel.rx + elif request.source_variant() == 'fixed_channel': + assert request.fixed_channel + rx = asyncio.Queue() + handle = request.fixed_channel.connection is not None and int.from_bytes( + request.fixed_channel.connection.cookie.value, 'big' + ) + + @watcher.on(self.device.host, 'l2cap_pdu') + def _(connection: device.Connection, cid: int, pdu: bytes) -> None: + assert request.fixed_channel + if cid == request.fixed_channel.cid and ( + handle is None or handle == connection.handle + ): + rx.put_nowait(pdu) + + else: + raise NotImplementedError(f"{request.source_variant()}: unsupported type") + try: + while data := await rx.get(): + yield l2cap_pb2.ReceiveResponse(data=data) + finally: + watcher.close() + + @utils.rpc + async def Send( + self, request: l2cap_pb2.SendRequest, context: grpc.ServicerContext + ) -> l2cap_pb2.SendResponse: + if request.sink_variant() == 'channel': + assert request.channel + channel = self.channels[ChannelIndex.from_token(request.channel)] + channel.send(request.data) + elif request.sink_variant() == 'fixed_channel': + assert request.fixed_channel + # Retrieve Bumble `Connection` from request. + connection_handle = int.from_bytes( + request.fixed_channel.connection.cookie.value, 'big' + ) + connection = self.device.lookup_connection(connection_handle) + if connection is None: + raise RuntimeError(f'{connection_handle}: not connection for handle') + self.device.l2cap_channel_manager.send_pdu( + connection, request.fixed_channel.cid, request.data + ) + else: + raise NotImplementedError(f"{request.sink_variant()}: unsupported type") + return l2cap_pb2.SendResponse(success=empty_pb2.Empty()) diff --git a/setup.cfg b/setup.cfg index 5cdf35ab..d34df086 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,7 @@ include_package_data = True install_requires = aiohttp ~= 3.8; platform_system!='Emscripten' appdirs >= 1.4; platform_system!='Emscripten' - bt-test-interfaces >= 0.0.2; platform_system!='Emscripten' + bt-test-interfaces >= 0.0.5; platform_system!='Emscripten' click == 8.1.3; platform_system!='Emscripten' cryptography == 39; platform_system!='Emscripten' # Pyodide bundles a version of cryptography that is built for wasm, which may not match the