Skip to content
This repository was archived by the owner on Jul 24, 2025. It is now read-only.

Commit 8697c59

Browse files
authored
Merge pull request #4 from danfimov/add-broker
feat: add AsyncpgBroker
2 parents bb07f89 + 0ccf471 commit 8697c59

6 files changed

Lines changed: 458 additions & 10 deletions

File tree

README.md

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# TaskIQ - Asyncpg
22

3-
TaskIQ-Asyncpg is a plugin for taskiq that adds a new result backend based on PostgreSQL and [Asyncpg](https://github.com/MagicStack/asyncpg).
3+
TaskIQ-Asyncpg is a plugin for taskiq that adds a new result backend and broker based on PostgreSQL and [Asyncpg](https://github.com/MagicStack/asyncpg).
44

55
## Installation
66
To use this project you must have installed core taskiq library:
@@ -34,18 +34,15 @@ Let's see the example with the redis broker and PostgreSQL Asyncpg result backen
3434
# broker.py
3535
import asyncio
3636

37-
from taskiq_redis import ListQueueBroker
38-
from taskiq_asyncpg import AsyncpgResultBackend
37+
from taskiq_asyncpg import AsyncpgResultBackend, AsyncpgBroker
3938

40-
asyncpg_result_backend = AsyncpgResultBackend(
39+
result_backend = AsyncpgResultBackend(
4140
dsn="postgres://postgres:postgres@localhost:5432/postgres",
4241
)
4342

44-
# Or you can use PubSubBroker if you need broadcasting
45-
broker = ListQueueBroker(
46-
url="redis://localhost:6379",
47-
result_backend=asyncpg_result_backend,
48-
)
43+
broker = AsyncpgBroker(
44+
dsn="postgres://postgres:postgres@localhost:5432/postgres",
45+
).with_result_backend(result_backend)
4946

5047

5148
@broker.task

taskiq_asyncpg/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
"""Public API of the taskiq-asyncpg package: PostgreSQL-backed broker and result backend."""
from taskiq_asyncpg.broker import AsyncpgBroker
from taskiq_asyncpg.result_backend import AsyncpgResultBackend

__all__ = [
    "AsyncpgBroker",
    "AsyncpgResultBackend",
]

taskiq_asyncpg/broker.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
import asyncio
2+
import json
3+
import logging
4+
from collections.abc import AsyncGenerator
5+
from typing import Any, Final, Optional, Union, cast
6+
7+
import asyncpg
8+
from taskiq import AckableMessage, AsyncBroker, BrokerMessage
9+
from typing_extensions import override
10+
11+
from taskiq_asyncpg.exceptions import DatabaseConnectionError
12+
from taskiq_asyncpg.queries import (
13+
CREATE_TABLE_MESSAGES_QUERY,
14+
DELETE_MESSAGE_QUERY,
15+
INSERT_MESSAGE_QUERY,
16+
SELECT_MESSAGE_QUERY,
17+
)
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class AsyncpgBroker(AsyncBroker):
    """Broker for TaskIQ based on Asyncpg.

    Messages are persisted in a PostgreSQL table; delivery is signalled
    with LISTEN/NOTIFY, where the NOTIFY payload is the message row id.
    """

    def __init__(
        self,
        dsn: Optional[str] = "postgres://postgres:postgres@localhost:5432/postgres",
        channel_name: str = "taskiq",
        table_name: str = "taskiq_messages",
        max_retry_attempts: int = 5,
        **connect_kwargs: Any,
    ) -> None:
        """
        Construct a new broker.

        :param dsn: connection string to PostgreSQL.
        :param channel_name: Name of the channel to listen on.
        :param table_name: Name of the table to store messages.
        :param max_retry_attempts: Maximum number of message processing attempts.
        :param connect_kwargs: additional arguments for asyncpg `connect()`.
        """
        super().__init__()
        self._dsn: Final = dsn
        self.channel_name: Final = channel_name
        self.table_name: Final = table_name
        self.connect_kwargs: Final = connect_kwargs
        # NOTE(review): max_retry_attempts is stored but never read by this
        # broker's code path — confirm whether retry handling is planned.
        self.max_retry_attempts: Final = max_retry_attempts

        self.read_conn: Optional["asyncpg.Connection[asyncpg.Record]"] = None
        self.write_pool: Optional["asyncpg.pool.Pool[asyncpg.Record]"] = None
        self._queue: Optional[asyncio.Queue[str]] = None
        # Strong references to in-flight delayed-notification tasks.
        # asyncio keeps only a weak reference to tasks, so a fire-and-forget
        # create_task() result may be garbage-collected before it runs.
        self._delay_tasks: set["asyncio.Task[None]"] = set()

    @override
    async def startup(self) -> None:
        """Open connections, create the messages table and start listening."""
        await super().startup()

        try:
            self.read_conn = await asyncpg.connect(self._dsn, **self.connect_kwargs)
            # NOTE(review): connect_kwargs are not forwarded to the pool;
            # verify whether pool connections need the same options.
            self.write_pool = await asyncpg.create_pool(self._dsn)

            if self.read_conn is None:
                msg = "read_conn not initialized"
                raise RuntimeError(msg)
            if self.write_pool is None:
                msg = "write_pool not initialized"
                raise RuntimeError(msg)

            async with self.write_pool.acquire() as conn:
                _ = await conn.execute(
                    CREATE_TABLE_MESSAGES_QUERY.format(self.table_name),
                )

            await self.read_conn.add_listener(
                self.channel_name,
                self._notification_handler,
            )
            self._queue = asyncio.Queue()
        except Exception as error:
            # Wrap any startup failure in the package's connection error.
            raise DatabaseConnectionError(str(error)) from error

    @override
    async def shutdown(self) -> None:
        """Close all connections on shutdown."""
        await super().shutdown()
        if self.read_conn is not None:
            await self.read_conn.close()
        if self.write_pool is not None:
            await self.write_pool.close()

    def _notification_handler(
        self,
        con_ref: Union[
            "asyncpg.Connection[asyncpg.Record]",
            "asyncpg.pool.PoolConnectionProxy[asyncpg.Record]",
        ],
        pid: int,
        channel: str,
        payload: object,
        /,
    ) -> None:
        """Handle NOTIFY messages.

        From asyncpg.connection.add_listener docstring:
        A callable or a coroutine function receiving the following arguments:
        **con_ref**: a Connection the callback is registered with;
        **pid**: PID of the Postgres server that sent the notification;
        **channel**: name of the channel the notification was sent to;
        **payload**: the payload.
        """
        logger.debug("Received notification on channel %s: %s", channel, payload)
        if self._queue is not None:
            # The payload is the message row id; listen() parses it back to int.
            self._queue.put_nowait(str(payload))

    @override
    async def kick(self, message: BrokerMessage) -> None:
        """
        Send message to the channel.

        Inserts the message into the database and sends a NOTIFY
        (immediately, or after the optional "delay" label in seconds).

        :param message: Message to send.
        """
        if self.write_pool is None:
            raise ValueError("Please run startup before kicking.")

        async with self.write_pool.acquire() as conn:
            # Insert the message into the database
            message_inserted_id = cast(
                int,
                await conn.fetchval(
                    INSERT_MESSAGE_QUERY.format(self.table_name),
                    message.task_id,
                    message.task_name,
                    message.message.decode(),
                    json.dumps(message.labels),
                ),
            )

            delay_value = message.labels.get("delay")
            if delay_value is not None:
                delay_seconds = int(delay_value)
                task = asyncio.create_task(
                    self._schedule_notification(message_inserted_id, delay_seconds),
                )
                # Keep a strong reference so the task is not garbage-collected
                # before it fires; drop it automatically once done.
                self._delay_tasks.add(task)
                task.add_done_callback(self._delay_tasks.discard)
            else:
                # Send a NOTIFY with the message ID as payload
                _ = await conn.execute(
                    f"NOTIFY {self.channel_name}, '{message_inserted_id}'",
                )

    async def _schedule_notification(self, message_id: int, delay_seconds: int) -> None:
        """Schedule a notification to be sent after a delay."""
        await asyncio.sleep(delay_seconds)
        if self.write_pool is None:
            # Broker was shut down while waiting; nothing to notify.
            return
        async with self.write_pool.acquire() as conn:
            # Send NOTIFY
            _ = await conn.execute(f"NOTIFY {self.channel_name}, '{message_id}'")

    @override
    async def listen(self) -> AsyncGenerator[AckableMessage, None]:
        """
        Listen to the channel.

        Yields messages as they are received.

        :yields: AckableMessage instances.
        """
        if self.read_conn is None:
            raise ValueError("Call startup before starting listening.")
        if self._queue is None:
            raise ValueError("Startup did not initialize the queue.")

        while True:
            try:
                payload = await self._queue.get()
                message_id = int(payload)
                message_row = await self.read_conn.fetchrow(
                    SELECT_MESSAGE_QUERY.format(self.table_name), message_id,
                )
                if message_row is None:
                    # NOTIFY arrived for a row that no longer exists
                    # (e.g. already acked elsewhere) — skip it.
                    logger.warning(
                        "Message with id %s not found in database.", message_id,
                    )
                    continue
                if message_row.get("message") is None:
                    msg = "Message row does not have 'message' column"
                    raise ValueError(msg)
                message_str = message_row["message"]
                if not isinstance(message_str, str):
                    msg = "message is not a string"
                    raise ValueError(msg)
                message_data = message_str.encode()

                # Bind message_id as a default so each ack closure keeps
                # the id of its own message.
                async def ack(*, _message_id: int = message_id) -> None:
                    if self.write_pool is None:
                        raise ValueError("Call startup before starting listening.")

                    async with self.write_pool.acquire() as conn:
                        _ = await conn.execute(
                            DELETE_MESSAGE_QUERY.format(self.table_name),
                            _message_id,
                        )

                yield AckableMessage(data=message_data, ack=ack)
            except Exception:
                # Keep the consumer loop alive on per-message failures;
                # logger.exception records the traceback.
                logger.exception("Error processing message")
                continue

taskiq_asyncpg/queries.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,24 @@
2626
DELETE_RESULT_QUERY = """
2727
DELETE FROM {} WHERE task_id = $1
2828
"""
29+
30+
# DDL for the broker's message table; `{}` is filled with the table name.
# `id` is the surrogate key used as the NOTIFY payload.
CREATE_TABLE_MESSAGES_QUERY = """
CREATE TABLE IF NOT EXISTS {} (
    id SERIAL PRIMARY KEY,
    task_id VARCHAR NOT NULL,
    task_name VARCHAR NOT NULL,
    message TEXT NOT NULL,
    labels JSONB NOT NULL,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
"""

# Persist one broker message and return its generated `id`.
INSERT_MESSAGE_QUERY = """
INSERT INTO {} (task_id, task_name, message, labels)
VALUES ($1, $2, $3, $4)
RETURNING id
"""

# Fetch a message row by its `id` (the NOTIFY payload).
SELECT_MESSAGE_QUERY = "SELECT * FROM {} WHERE id = $1"

# Acknowledge (delete) a processed message by `id`.
DELETE_MESSAGE_QUERY = "DELETE FROM {} WHERE id = $1"

tests/conftest.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import string
44
from typing import AsyncGenerator, TypeVar
55

6+
import asyncpg
67
import pytest
78

9+
from taskiq_asyncpg.broker import AsyncpgBroker
810
from taskiq_asyncpg.result_backend import AsyncpgResultBackend
911

1012
_ReturnType = TypeVar("_ReturnType")
@@ -30,7 +32,7 @@ def postgres_table() -> str:
3032
"""
3133
return "".join(
3234
random.choice(
33-
string.ascii_uppercase,
35+
string.ascii_lowercase,
3436
)
3537
for _ in range(10)
3638
)
@@ -48,6 +50,18 @@ def postgresql_dsn() -> str:
4850
or "postgresql://postgres:postgres@localhost:5432/taskiqasyncpg"
4951
)
5052

53+
@pytest.fixture
async def connection(postgresql_dsn: str) -> AsyncGenerator[asyncpg.Connection, None]:
    """
    Provide an open asyncpg connection for a test.

    Connects using *postgresql_dsn*, hands the connection to the test,
    and closes it afterwards.

    :param postgresql_dsn: DSN to PostgreSQL.
    :return: connection to PostgreSQL.
    """
    pg_connection = await asyncpg.connect(postgresql_dsn)
    yield pg_connection
    await pg_connection.close()
64+
5165

5266
@pytest.fixture()
5367
async def asyncpg_result_backend(
@@ -62,3 +76,27 @@ async def asyncpg_result_backend(
6276
yield backend
6377
await backend._database_pool.execute(f"DROP TABLE {postgres_table}")
6478
await backend.shutdown()
79+
80+
81+
@pytest.fixture()
async def asyncpg_broker(
    postgresql_dsn: str,
    postgres_table: str,
) -> AsyncGenerator[AsyncpgBroker, None]:
    """
    Fixture to set up and tear down the broker.

    The broker gets its own table and NOTIFY channel derived from the
    random *postgres_table* name so concurrent tests do not collide.

    :param postgresql_dsn: DSN to PostgreSQL.
    :param postgres_table: randomly generated table name.
    :yields: a started AsyncpgBroker.
    """
    broker = AsyncpgBroker(
        dsn=postgresql_dsn,
        channel_name=f"{postgres_table}_channel",
        table_name=postgres_table,
    )
    await broker.startup()
    try:
        yield broker
    finally:
        # Run cleanup even when the test body raises; otherwise the
        # randomly named table leaks into the test database.
        if broker.write_pool is not None:
            await broker.write_pool.execute(
                f"DROP TABLE {postgres_table}",
            )
        await broker.shutdown()

0 commit comments

Comments
 (0)