Skip to content

Commit 631b337

Browse files
committed
refactor: optimize update dispatching and filter performance
1 parent e75cc8c commit 631b337

10 files changed

Lines changed: 198 additions & 132 deletions

File tree

pyrogram/dispatcher.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -262,11 +262,25 @@ async def deleted_business_messages_parser(update, users, chats):
262262
async def start(self):
263263
if callable(self.client.start_handler):
264264
try:
265-
await self.client.start_handler(self.client)
265+
if inspect.iscoroutinefunction(self.client.start_handler):
266+
await self.client.start_handler(self.client)
267+
else:
268+
result = self.client.start_handler(self.client)
269+
if inspect.isawaitable(result):
270+
await result
266271
except Exception as e:
267-
log.exception(e)
272+
log.exception("start_handler raised: %s", e)
268273

269274
if not self.client.no_updates:
275+
self.locks_list.clear()
276+
self.handler_worker_tasks.clear()
277+
278+
if 0 not in self.groups or self.conversation_handler not in self.groups.get(0, []):
279+
if 0 not in self.groups:
280+
self.groups[0] = []
281+
self.groups = OrderedDict(sorted(self.groups.items()))
282+
self.groups[0].insert(0, self.conversation_handler)
283+
270284
for i in range(self.client.workers):
271285
self.locks_list.append(asyncio.Lock())
272286

@@ -301,7 +315,7 @@ async def stop(self, clear_handlers: bool = True):
301315
self.client.remove_listener(listener)
302316
if getattr(listener, "future", None) and not listener.future.done():
303317
try:
304-
listener.future.set_exception(asyncio.CancelledError())
318+
listener.future.cancel()
305319
except Exception:
306320
pass
307321
except Exception:
@@ -311,9 +325,14 @@ async def stop(self, clear_handlers: bool = True):
311325

312326
if callable(self.client.stop_handler):
313327
try:
314-
await self.client.stop_handler(self.client)
328+
if inspect.iscoroutinefunction(self.client.stop_handler):
329+
await self.client.stop_handler(self.client)
330+
else:
331+
result = self.client.stop_handler(self.client)
332+
if inspect.isawaitable(result):
333+
await result
315334
except Exception as e:
316-
log.exception(e)
335+
log.exception("stop_handler raised: %s", e)
317336

318337
if not self.client.no_updates:
319338
for i in range(self.client.workers):
@@ -322,8 +341,10 @@ async def stop(self, clear_handlers: bool = True):
322341
for i in self.handler_worker_tasks:
323342
await i
324343

344+
self.handler_worker_tasks.clear()
345+
self.locks_list.clear()
346+
325347
if clear_handlers:
326-
self.handler_worker_tasks.clear()
327348
self.groups.clear()
328349

329350
log.info("Stopped %s HandlerTasks", self.client.workers)
@@ -339,6 +360,8 @@ async def fn():
339360
self.groups = OrderedDict(sorted(self.groups.items()))
340361

341362
self.groups[group].append(handler)
363+
except Exception as e:
364+
log.exception("Failed to add handler: %s", e)
342365
finally:
343366
for lock in self.locks_list:
344367
lock.release()
@@ -352,14 +375,17 @@ async def fn():
352375

353376
try:
354377
if group not in self.groups:
355-
raise ValueError(
356-
f"Group {group} does not exist. Handler was not removed."
357-
)
378+
log.warning("Group %s does not exist. Handler was not removed.", group)
379+
return
358380

359381
self.groups[group].remove(handler)
360382

361383
if not self.groups[group]:
362384
del self.groups[group]
385+
except ValueError:
386+
log.warning("Handler not found in group %s.", group)
387+
except Exception as e:
388+
log.exception("Failed to remove handler: %s", e)
363389
finally:
364390
for lock in self.locks_list:
365391
lock.release()
@@ -395,16 +421,16 @@ async def handler_worker(self, lock):
395421
try:
396422
if await handler.check(self.client, parsed_update):
397423
args = (parsed_update,)
398-
except Exception as e:
399-
log.exception(e)
424+
except Exception:
425+
log.exception("Handler check failed")
400426
continue
401427

402428
elif isinstance(handler, RawUpdateHandler):
403429
try:
404430
if await handler.check(self.client, update):
405431
args = (update, users, chats)
406-
except Exception as e:
407-
log.exception(e)
432+
except Exception:
433+
log.exception("Raw handler check failed")
408434
continue
409435

410436
if args is None:
@@ -421,8 +447,7 @@ async def handler_worker(self, lock):
421447
*args
422448
)
423449
except asyncio.CancelledError:
424-
# Swallow task cancellations during shutdown/interrupt
425-
pass
450+
raise
426451
except pyrogram.StopPropagation:
427452
raise
428453
except pyrogram.ContinuePropagation:
@@ -435,8 +460,8 @@ async def handler_worker(self, lock):
435460
break
436461
except pyrogram.StopPropagation:
437462
pass
438-
except Exception as e:
439-
log.exception(e)
463+
except Exception:
464+
log.exception("Unhandled exception in handler worker")
440465

441466
async def handle_update_handler_exception(
442467
self,

pyrogram/filters.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@ def __or__(self, other):
4343
class InvertFilter(Filter):
4444
def __init__(self, base):
4545
self.base = base
46+
self._base_is_async = inspect.iscoroutinefunction(base.__call__)
4647

4748
async def __call__(self, client: "pyrogram.Client", update: Update):
48-
if inspect.iscoroutinefunction(self.base.__call__):
49+
if self._base_is_async:
4950
x = await self.base(client, update)
5051
else:
5152
x = await client.loop.run_in_executor(
@@ -61,9 +62,11 @@ class AndFilter(Filter):
6162
def __init__(self, base, other):
6263
self.base = base
6364
self.other = other
65+
self._base_is_async = inspect.iscoroutinefunction(base.__call__)
66+
self._other_is_async = inspect.iscoroutinefunction(other.__call__)
6467

6568
async def __call__(self, client: "pyrogram.Client", update: Update):
66-
if inspect.iscoroutinefunction(self.base.__call__):
69+
if self._base_is_async:
6770
x = await self.base(client, update)
6871
else:
6972
x = await client.loop.run_in_executor(
@@ -72,11 +75,10 @@ async def __call__(self, client: "pyrogram.Client", update: Update):
7275
client, update
7376
)
7477

75-
# short circuit
7678
if not x:
7779
return False
7880

79-
if inspect.iscoroutinefunction(self.other.__call__):
81+
if self._other_is_async:
8082
y = await self.other(client, update)
8183
else:
8284
y = await client.loop.run_in_executor(
@@ -85,16 +87,18 @@ async def __call__(self, client: "pyrogram.Client", update: Update):
8587
client, update
8688
)
8789

88-
return x and y
90+
return y
8991

9092

9193
class OrFilter(Filter):
9294
def __init__(self, base, other):
9395
self.base = base
9496
self.other = other
97+
self._base_is_async = inspect.iscoroutinefunction(base.__call__)
98+
self._other_is_async = inspect.iscoroutinefunction(other.__call__)
9599

96100
async def __call__(self, client: "pyrogram.Client", update: Update):
97-
if inspect.iscoroutinefunction(self.base.__call__):
101+
if self._base_is_async:
98102
x = await self.base(client, update)
99103
else:
100104
x = await client.loop.run_in_executor(
@@ -103,11 +107,10 @@ async def __call__(self, client: "pyrogram.Client", update: Update):
103107
client, update
104108
)
105109

106-
# short circuit
107110
if x:
108111
return True
109112

110-
if inspect.iscoroutinefunction(self.other.__call__):
113+
if self._other_is_async:
111114
y = await self.other(client, update)
112115
else:
113116
y = await client.loop.run_in_executor(
@@ -116,7 +119,7 @@ async def __call__(self, client: "pyrogram.Client", update: Update):
116119
client, update
117120
)
118121

119-
return x or y
122+
return y
120123

121124

122125
CUSTOM_FILTER_NAME = "CustomFilter"
@@ -146,7 +149,7 @@ def create(func: Callable, name: str = None, **kwargs) -> Filter:
146149
:meth:`~pyrogram.filters.command` or :meth:`~pyrogram.filters.regex`.
147150
"""
148151
return type(
149-
name or func.__name__ or CUSTOM_FILTER_NAME,
152+
name or getattr(func, "__name__", None) or CUSTOM_FILTER_NAME,
150153
(Filter,),
151154
{"__call__": func, **kwargs}
152155
)()
@@ -923,11 +926,11 @@ async def paid_message_filter(_, __, m: Message):
923926

924927
# region linked_channel_filter
925928
async def linked_channel_filter(_, __, m: Message):
926-
return bool(
927-
m.forward_origin and
928-
m.forward_origin.type == enums.MessageOriginType.CHANNEL and
929-
m.forward_origin.chat == m.sender_chat
930-
)
929+
origin = m.forward_origin
930+
if not (origin and origin.type == enums.MessageOriginType.CHANNEL and m.sender_chat):
931+
return False
932+
origin_chat = getattr(origin, "chat", None)
933+
return origin_chat is not None and origin_chat.id == m.sender_chat.id
931934

932935

933936
linked_channel = create(linked_channel_filter)
@@ -999,7 +1002,7 @@ def command(commands: Union[str, List[str]], prefixes: Optional[Union[str, List[
9991002
command_re = re.compile(r"([\"'])(.*?)(?<!\\)\1|(\S+)")
10001003

10011004
async def func(flt, client: pyrogram.Client, message: Message):
1002-
username = client.me.username or ""
1005+
username = client.me.username if client.me else ""
10031006
text = message.text or message.caption
10041007
message.command = None
10051008

@@ -1013,11 +1016,12 @@ async def func(flt, client: pyrogram.Client, message: Message):
10131016
without_prefix = text[len(prefix):]
10141017

10151018
for cmd in flt.commands:
1016-
if not re.match(rf"^(?:{cmd}(?:@?{username})?)(?:\s|$)", without_prefix,
1019+
escaped_cmd = re.escape(cmd)
1020+
if not re.match(rf"^(?:{escaped_cmd}(?:@?{username})?)(?:\s|$)", without_prefix,
10171021
flags=re.IGNORECASE if not flt.case_sensitive else 0):
10181022
continue
10191023

1020-
without_command = re.sub(rf"{cmd}(?:@?{username})?\s?", "", without_prefix, count=1,
1024+
without_command = re.sub(rf"{escaped_cmd}(?:@?{username})?\s?", "", without_prefix, count=1,
10211025
flags=re.IGNORECASE if not flt.case_sensitive else 0)
10221026

10231027
# match.groups are 1-indexed, group(1) is the quote, group(2) is the text
@@ -1087,6 +1091,8 @@ async def func(flt, _, update: Update):
10871091

10881092
if value:
10891093
update.matches = list(flt.p.finditer(value)) or None
1094+
else:
1095+
update.matches = None
10901096

10911097
return bool(update.matches)
10921098

pyrogram/handlers/callback_query_handler.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ async def check_if_has_matching_listener(
112112
listener_does_match = await filters(client, query)
113113
else:
114114
listener_does_match = await client.loop.run_in_executor(
115-
None, filters, client, query
115+
client.executor, filters, client, query
116116
)
117117
else:
118118
listener_does_match = True
@@ -132,32 +132,30 @@ async def check(self, client: "pyrogram.Client", query: CallbackQuery):
132132
client, query
133133
)
134134

135+
query._matched_listener = listener if listener_does_match else None
136+
135137
if callable(self.filters):
136138
if iscoroutinefunction(self.filters.__call__):
137139
handler_does_match = await self.filters(client, query)
138140
else:
139141
handler_does_match = await client.loop.run_in_executor(
140-
None, self.filters, client, query
142+
client.executor, self.filters, client, query
141143
)
142144
else:
143145
handler_does_match = True
144146

145-
data = self.compose_data_identifier(query)
146-
147-
if PyromodConfig.unallowed_click_alert:
148-
# matches with the current query but from any user
147+
if PyromodConfig.unallowed_click_alert and listener:
148+
data = self.compose_data_identifier(query)
149149
permissive_identifier = Identifier(
150150
chat_id=data.chat_id,
151151
message_id=data.message_id,
152152
inline_message_id=data.inline_message_id,
153153
from_user_id=None,
154154
)
155155

156-
matches = permissive_identifier.matches(data)
157-
158156
if (
159-
listener
160-
and (matches and not listener_does_match)
157+
permissive_identifier.matches(data)
158+
and not listener_does_match
161159
and listener.unallowed_click_alert
162160
):
163161
alert = (
@@ -168,8 +166,6 @@ async def check(self, client: "pyrogram.Client", query: CallbackQuery):
168166
await query.answer(alert)
169167
return False
170168

171-
# let handler get the chance to handle if listener
172-
# exists but its filters doesn't match
173169
return listener_does_match or handler_does_match
174170

175171
async def resolve_future_or_callback(
@@ -183,9 +179,11 @@ async def resolve_future_or_callback(
183179
:param args: The arguments to call the callback with.
184180
:return: None
185181
"""
186-
listener_does_match, listener = await self.check_if_has_matching_listener(
187-
client, query
188-
)
182+
listener = getattr(query, '_matched_listener', None)
183+
listener_does_match = listener is not None
184+
185+
if not listener_does_match:
186+
listener_does_match, listener = await self.check_if_has_matching_listener(client, query)
189187

190188
if listener and listener_does_match:
191189
client.remove_listener(listener)

pyrogram/handlers/conversation_handler.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,20 @@
2929
class ConversationHandler(MessageHandler, CallbackQueryHandler):
3030
"""The Conversation handler class."""
3131
def __init__(self):
32-
# Initialize base handler with a no-op async callback to satisfy handler interface
3332
Handler.__init__(self, self.callback)
33+
self.original_callback = self.callback
3434
self.waiters = {}
3535

36+
def register_waiter(self, chat_id, waiter):
37+
old = self.waiters.get(chat_id)
38+
if old and not old['future'].done():
39+
old['future'].cancel()
40+
self.waiters[chat_id] = waiter
41+
3642
async def check(self, client: "pyrogram.Client", update: Union[Message, CallbackQuery]):
43+
if not self.waiters:
44+
return False
45+
3746
if isinstance(update, Message) and update.outgoing:
3847
return False
3948

@@ -71,8 +80,9 @@ async def check(self, client: "pyrogram.Client", update: Union[Message, Callback
7180

7281
@staticmethod
7382
async def callback(_, __):
74-
pass
83+
raise pyrogram.StopPropagation
7584

7685
def delete_waiter(self, chat_id, future):
77-
if future == self.waiters[chat_id]['future']:
86+
waiter = self.waiters.get(chat_id)
87+
if waiter and waiter.get('future') == future:
7888
del self.waiters[chat_id]

0 commit comments

Comments
 (0)