Skip to content

Commit a7fe147

Browse files
committed
Create and use a Threadsafe pika connection
It wraps a pika SelectConnection with just the APIs that I need, implemented as thread-safe, converting from callbacks to something more ergonomic when dealing with threads.
1 parent 5767317 commit a7fe147

File tree

5 files changed

+246
-66
lines changed

5 files changed

+246
-66
lines changed

qio/broker_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,20 +74,20 @@ def receive_messages():
7474
receiver = broker.receive(queuespec)
7575
for message in receiver:
7676
received_messages.append(message)
77-
# Don't call start() - this should block after prefetch_limit messages
7877

7978
thread = threading.Thread(target=receive_messages)
8079
thread.start()
8180

8281
# Wait for receiver to reach prefetch limit or timeout
83-
thread.join(timeout=1.0)
82+
thread.join(timeout=0.1)
8483

8584
# Should have received exactly prefetch_limit messages and be blocked
8685
assert len(received_messages) == prefetch_limit
8786
assert thread.is_alive() # Thread should still be alive (blocked)
8887

8988
broker.shutdown()
9089
thread.join(timeout=1.0) # Clean up thread
90+
assert not thread.is_alive()
9191

9292
@pytest.mark.timeout(2)
9393
def test_suspend_resume_affects_prefetch_capacity(self, broker):

qio/pika/broker.py

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from threading import Lock
22

3-
from pika import BlockingConnection
4-
from pika import ConnectionParameters
53
from pika import URLParameters
4+
from pika.connection import Parameters
65

76
from qio.broker import Broker
87
from qio.queuespec import QueueSpec
98

109
from .receiver import PikaReceiver
10+
from .threadsafe import ThreadsafeConnection
1111

1212

1313
class PikaBroker(Broker):
@@ -19,38 +19,23 @@ def from_uri(cls, uri: str, /):
1919
amqp_uri = "amqp:" + uri.removeprefix("pika:")
2020
return cls(URLParameters(amqp_uri))
2121

22-
def __init__(self, connection_params: ConnectionParameters | URLParameters):
23-
self.__connection_params = connection_params
24-
self.__producer_channel_lock = Lock()
25-
self.__producer_channel = BlockingConnection(self.__connection_params).channel()
22+
def __init__(self, connection_params: Parameters):
23+
self.__connection = ThreadsafeConnection(connection_params)
24+
self.__channel = self.__connection.channel()
2625
self.__shutdown_lock = Lock()
2726
self.__shutdown = False
2827
self.__receivers = set[PikaReceiver]()
2928

3029
def enqueue(self, body: bytes, /, *, queue: str):
31-
with self.__producer_channel_lock:
32-
self.__producer_channel.basic_publish(
33-
exchange="",
34-
routing_key=queue,
35-
body=body,
36-
)
30+
self.__channel.declare_queue(queue=queue, durable=True)
31+
self.__channel.publish(exchange="", routing_key=queue, body=body)
3732

3833
def purge(self, *, queue: str):
39-
with self.__producer_channel_lock:
40-
self.__producer_channel.queue_declare(queue=queue, durable=True)
41-
self.__producer_channel.queue_purge(queue=queue)
34+
self.__channel.declare_queue(queue=queue, durable=True)
35+
self.__channel.purge(queue=queue)
4236

4337
def receive(self, queuespec: QueueSpec, /) -> PikaReceiver:
44-
if not queuespec.queues:
45-
raise ValueError("Must specify at least one queue")
46-
if len(queuespec.queues) != 1:
47-
raise ValueError("Only one queue is supported")
48-
49-
receiver = PikaReceiver(
50-
connection_params=self.__connection_params,
51-
queue=queuespec.queues[0],
52-
prefetch=queuespec.concurrency,
53-
)
38+
receiver = PikaReceiver(self.__connection, queuespec)
5439
self.__receivers.add(receiver)
5540
return receiver
5641

@@ -62,3 +47,7 @@ def shutdown(self):
6247
self.__shutdown = True
6348
for receiver in self.__receivers:
6449
receiver.shutdown()
50+
self.__connection.close()
51+
52+
def __del__(self):
53+
self.shutdown()

qio/pika/receiver.py

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,42 @@
22
from threading import Lock
33
from typing import cast
44

5-
from pika import BlockingConnection
6-
from pika import ConnectionParameters
7-
from pika import URLParameters
8-
95
from qio.message import Message
6+
from qio.queuespec import QueueSpec
107
from qio.receiver import Receiver
118

9+
from .threadsafe import ThreadsafeConnection
10+
1211

1312
class PikaReceiver(Receiver):
1413
def __init__(
1514
self,
16-
*,
17-
connection_params: ConnectionParameters | URLParameters,
18-
queue: str,
19-
prefetch: int,
15+
connection: ThreadsafeConnection,
16+
queuespec: QueueSpec,
17+
/,
2018
):
21-
self.__connection = BlockingConnection(connection_params)
22-
self.__channel = self.__connection.channel()
23-
self.__channel.queue_declare(queue=queue, durable=True)
24-
self.__prefetch_lock = Lock()
25-
self.__prefetch = prefetch
26-
self.__channel.basic_qos(prefetch_count=prefetch, global_qos=True)
27-
self.__iterator = self.__channel.consume(queue=queue)
19+
if len(queuespec.queues) == 0:
20+
raise ValueError("Must specify at least one queue")
21+
if len(queuespec.queues) != 1:
22+
raise ValueError("Only one queue is supported")
23+
24+
self.__channel = connection.channel()
25+
self.__consumer_tag = dict[str, str]()
2826
self.__tag = dict[Message, int]()
29-
self.__suspended = set[Message]()
27+
28+
self.__channel.declare_queue(queue=queuespec.queues[0], durable=True)
29+
self.__prefetch_lock = Lock()
30+
self.__prefetch = 0
31+
self.__adjust_prefetch(+queuespec.concurrency)
32+
self.__channel.consume(queuespec.queues[0])
33+
34+
def __adjust_prefetch(self, change: int) -> None:
35+
with self.__prefetch_lock:
36+
self.__prefetch += change
37+
self.__channel.qos(prefetch_count=self.__prefetch, global_qos=True)
3038

3139
def __iter__(self) -> Iterator[Message]:
32-
for method, _, body in self.__iterator:
40+
for method, _, body in self.__channel.messages():
3341
message = Message(body)
3442
tag = cast(int, method.delivery_tag)
3543
self.__tag[message] = tag
@@ -41,39 +49,24 @@ def pause(self, message: Message, /):
4149
The message processing is not completed, and is expected to unpause,
4250
but its assigned capacity may be allocated elsewhere temporarily.
4351
"""
44-
with self.__prefetch_lock:
45-
self.__prefetch += 1
46-
prefetch = self.__prefetch # Memo for the lambda
47-
self.__connection.add_callback_threadsafe(
48-
lambda: self.__channel.basic_qos(prefetch_count=prefetch)
49-
)
52+
self.__adjust_prefetch(+1)
5053

5154
def unpause(self, message: Message, /):
5255
"""Unpause processing of a message.
5356
5457
The previously paused message processing is resuming, so its assigned
5558
capacity is no longer available for allocation elsewhere.
5659
"""
57-
with self.__prefetch_lock:
58-
self.__prefetch -= 1
59-
prefetch = self.__prefetch # Memo for the lambda
60-
self.__connection.add_callback_threadsafe(
61-
lambda: self.__channel.basic_qos(prefetch_count=prefetch)
62-
)
60+
self.__adjust_prefetch(-1)
6361

6462
def finish(self, message: Message, /):
6563
"""Finish processing a message.
6664
6765
The message is done processing, and its assigned capacity may be
6866
allocated elsewhere permanently.
6967
"""
70-
self.__connection.add_callback_threadsafe(
71-
lambda: self.__channel.basic_ack(delivery_tag=self.__tag.pop(message))
72-
)
68+
self.__channel.ack(delivery_tag=self.__tag.pop(message))
7369

7470
def shutdown(self):
75-
self.__connection.add_callback_threadsafe(self.__shutdown)
76-
77-
def __shutdown(self):
78-
self.__channel.cancel()
79-
self.__connection.close()
71+
for consumer_tag in list(self.__consumer_tag.values()):
72+
self.__channel.cancel(consumer_tag=consumer_tag)

qio/pika/threadsafe.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
from collections.abc import Callable
2+
from collections.abc import Iterator
3+
from collections.abc import Mapping
4+
from concurrent.futures import Future
5+
from contextlib import suppress
6+
from queue import Queue
7+
from queue import ShutDown
8+
from threading import Event
9+
from threading import Thread
10+
from typing import Any
11+
from typing import AnyStr
12+
from typing import cast
13+
14+
from pika import SelectConnection
15+
from pika import frame
16+
from pika import spec
17+
from pika.channel import Channel
18+
from pika.connection import Parameters
19+
from pika.exceptions import ConnectionClosedByClient
20+
21+
22+
class ThreadsafeConnection:
23+
def __init__(self, connection_params: Parameters):
24+
opened = Future()
25+
self.__closed = Future()
26+
self.__connection = SelectConnection(
27+
connection_params,
28+
lambda _: opened.set_result(None),
29+
lambda _, exc: opened.set_exception(cast(BaseException, exc)),
30+
lambda _, exc: self.__closed.set_exception(exc),
31+
)
32+
self.__thread = Thread(target=self.__connection.ioloop.start)
33+
self.__thread.start()
34+
opened.result()
35+
36+
def __wait[T](self, fn: Callable[[], T]):
37+
# TODO: Throw connection exceptions on all waiters
38+
event = Event()
39+
self.__connection.ioloop.add_callback_threadsafe(
40+
lambda: event.set() if fn() else event.set()
41+
)
42+
event.wait()
43+
44+
def channel(self, channel_number: int | None = None) -> ThreadsafeChannel:
45+
future = Future[Channel]()
46+
self.__wait(
47+
lambda: self.__connection.channel(
48+
channel_number=channel_number,
49+
on_open_callback=future.set_result,
50+
)
51+
)
52+
return ThreadsafeChannel(self.__wait, future.result())
53+
54+
def close(self, reply_code: int = 200, reply_text: str = "Normal shutdown"):
55+
self.__wait(
56+
lambda: self.__connection.close(
57+
reply_code=reply_code,
58+
reply_text=reply_text,
59+
)
60+
)
61+
with suppress(ConnectionClosedByClient):
62+
self.__closed.result()
63+
self.__connection.ioloop.stop()
64+
self.__thread.join()
65+
66+
67+
class ThreadsafeChannel:
68+
def __init__(
69+
self,
70+
wait: Callable[[Callable[[], Any]], None],
71+
channel: Channel,
72+
):
73+
self.__wait = wait
74+
self.__channel = channel
75+
self.__messages = Queue[
76+
tuple[spec.Basic.Deliver, spec.BasicProperties, bytes]
77+
]()
78+
self.__channel.add_on_close_callback(
79+
lambda c, e: self.__messages.shutdown(immediate=True)
80+
)
81+
82+
def declare_queue(
83+
self,
84+
queue: str,
85+
passive: bool = False,
86+
durable: bool = False,
87+
exclusive: bool = False,
88+
auto_delete: bool = False,
89+
arguments: Mapping[str, Any] | None = None,
90+
) -> frame.Method[spec.Queue.DeclareOk]:
91+
future: Future[frame.Method[spec.Queue.DeclareOk]] = Future()
92+
self.__wait(
93+
lambda: self.__channel.queue_declare(
94+
queue=queue,
95+
passive=passive,
96+
durable=durable,
97+
exclusive=exclusive,
98+
auto_delete=auto_delete,
99+
arguments=arguments,
100+
callback=future.set_result,
101+
)
102+
)
103+
return future.result()
104+
105+
def publish(
106+
self,
107+
exchange: str,
108+
routing_key: str,
109+
body: AnyStr,
110+
properties: spec.BasicProperties | None = None,
111+
mandatory: bool = False,
112+
):
113+
self.__wait(
114+
lambda: self.__channel.basic_publish(
115+
exchange=exchange,
116+
routing_key=routing_key,
117+
body=body,
118+
properties=properties,
119+
mandatory=mandatory,
120+
)
121+
)
122+
123+
def purge(self, queue: str) -> frame.Method[spec.Queue.PurgeOk]:
124+
future: Future[frame.Method[spec.Queue.PurgeOk]] = Future()
125+
self.__wait(
126+
lambda: self.__channel.queue_purge(
127+
queue=queue,
128+
callback=future.set_result,
129+
)
130+
)
131+
132+
return future.result()
133+
134+
def consume(
135+
self,
136+
queue: str,
137+
auto_ack: bool = False,
138+
exclusive: bool = False,
139+
consumer_tag: str | None = None,
140+
arguments: Mapping[str, Any] | None = None,
141+
) -> frame.Method[spec.Basic.ConsumeOk]:
142+
future: Future[frame.Method[spec.Basic.ConsumeOk]] = Future()
143+
self.__wait(
144+
lambda: self.__channel.basic_consume(
145+
queue=queue,
146+
on_message_callback=lambda _, m, p, b: self.__messages.put((m, p, b)),
147+
auto_ack=auto_ack,
148+
exclusive=exclusive,
149+
consumer_tag=consumer_tag,
150+
arguments=arguments,
151+
callback=future.set_result,
152+
)
153+
)
154+
return future.result()
155+
156+
def cancel(self, consumer_tag: str) -> frame.Method[spec.Basic.CancelOk]:
157+
future: Future[frame.Method[spec.Basic.CancelOk]] = Future()
158+
self.__wait(
159+
lambda: self.__channel.basic_cancel(
160+
consumer_tag=consumer_tag,
161+
callback=lambda r: future.set_result(r),
162+
)
163+
)
164+
return future.result()
165+
166+
def qos(
167+
self,
168+
prefetch_size: int = 0,
169+
prefetch_count: int = 0,
170+
global_qos: bool = False,
171+
) -> frame.Method[spec.Basic.QosOk]:
172+
future: Future[frame.Method[spec.Basic.QosOk]] = Future()
173+
self.__wait(
174+
lambda: self.__channel.basic_qos(
175+
prefetch_size=prefetch_size,
176+
prefetch_count=prefetch_count,
177+
global_qos=global_qos,
178+
callback=future.set_result,
179+
)
180+
)
181+
return future.result()
182+
183+
def ack(self, delivery_tag: int = 0, multiple: bool = False):
184+
self.__wait(
185+
lambda: self.__channel.basic_ack(
186+
delivery_tag=delivery_tag,
187+
multiple=multiple,
188+
)
189+
)
190+
191+
def messages(
192+
self,
193+
) -> Iterator[tuple[spec.Basic.Deliver, spec.BasicProperties, bytes]]:
194+
while True:
195+
try:
196+
yield self.__messages.get()
197+
except ShutDown:
198+
return

0 commit comments

Comments
 (0)