Skip to content

Commit faa7568

Browse files
committed
fix: improve session management and resource cleanup
1 parent 2c84cd3 commit faa7568

3 files changed

Lines changed: 54 additions & 29 deletions

File tree

pyrogram/client.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,11 @@ def __init__(
366366
self.connection_factory = connection_factory
367367
self.protocol_factory = protocol_factory
368368

369+
if self.workers < 1:
370+
raise ValueError(f"workers must be >= 1, got {self.workers}")
371+
if self.max_concurrent_transmissions < 1:
372+
raise ValueError(f"max_concurrent_transmissions must be >= 1, got {self.max_concurrent_transmissions}")
373+
369374
self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")
370375

371376
self.storage: Storage
@@ -442,14 +447,15 @@ def loop(self) -> asyncio.AbstractEventLoop:
442447
def loop(self, value: asyncio.AbstractEventLoop):
443448
self._loop = value
444449

445-
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
450+
if not hasattr(self, "listeners"):
451+
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
446452

447453
def __enter__(self):
448-
return self.start()
454+
return utils.get_event_loop().run_until_complete(self.start())
449455

450456
def __exit__(self, *args):
451457
try:
452-
self.stop()
458+
utils.get_event_loop().run_until_complete(self.stop())
453459
except ConnectionError:
454460
pass
455461

@@ -646,10 +652,10 @@ async def authorize(self) -> User:
646652

647653
return signed_up
648654

649-
async def authorize_qr(self, except_ids: List[int] = []) -> "User":
655+
async def authorize_qr(self, except_ids: List[int] = None) -> "User":
650656
from qrcode import QRCode
651657

652-
qr_login = QRLogin(self, except_ids)
658+
qr_login = QRLogin(self, except_ids or [])
653659
await qr_login.recreate()
654660

655661
qr = QRCode(version=1)
@@ -896,7 +902,9 @@ async def handle_updates(self, updates):
896902
)
897903
)
898904

899-
if diff.new_messages:
905+
if isinstance(diff, (raw.types.updates.DifferenceEmpty, raw.types.updates.DifferenceTooLong)):
906+
pass
907+
elif getattr(diff, "new_messages", None):
900908
self.dispatcher.updates_queue.put_nowait((
901909
raw.types.UpdateNewMessage(
902910
message=diff.new_messages[0],
@@ -906,9 +914,8 @@ async def handle_updates(self, updates):
906914
{u.id: u for u in diff.users},
907915
{c.id: c for c in diff.chats}
908916
))
909-
else:
910-
if diff.other_updates: # The other_updates list can be empty
911-
self.dispatcher.updates_queue.put_nowait((diff.other_updates[0], {}, {}))
917+
elif getattr(diff, "other_updates", None):
918+
self.dispatcher.updates_queue.put_nowait((diff.other_updates[0], {}, {}))
912919
elif isinstance(updates, raw.types.UpdateShort):
913920
self.dispatcher.updates_queue.put_nowait((updates.update, {}, {}))
914921
elif isinstance(updates, raw.types.UpdatesTooLong):
@@ -1106,17 +1113,21 @@ async def handle_download(self, packet):
11061113
try:
11071114
async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args):
11081115
file.write(chunk)
1109-
except BaseException as e:
1116+
except (SystemExit, KeyboardInterrupt, GeneratorExit):
11101117
if not in_memory:
11111118
file.close()
11121119
os.remove(temp_file_path)
1113-
1114-
if isinstance(e, asyncio.CancelledError):
1115-
raise e
1116-
1117-
if isinstance(e, (FloodWaitX, FloodPremiumWaitX)):
1118-
raise e
1119-
1120+
raise
1121+
except (asyncio.CancelledError, FloodWaitX, FloodPremiumWaitX):
1122+
if not in_memory:
1123+
file.close()
1124+
os.remove(temp_file_path)
1125+
raise
1126+
except Exception as e:
1127+
if not in_memory:
1128+
file.close()
1129+
os.remove(temp_file_path)
1130+
log.exception("Download failed: %s", e)
11201131
return None
11211132
else:
11221133
if in_memory:
@@ -1279,8 +1290,8 @@ async def get_file(
12791290

12801291
# https://core.telegram.org/cdn#verifying-files
12811292
def _check_all_hashes():
1282-
for i, h in enumerate(hashes):
1283-
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
1293+
for h in hashes:
1294+
cdn_chunk = decrypted_chunk[h.offset - offset_bytes: h.offset - offset_bytes + h.limit]
12841295
CDNFileHashMismatch.check(
12851296
h.hash == sha256(cdn_chunk).digest(),
12861297
"h.hash == sha256(cdn_chunk).digest()"
@@ -1308,16 +1319,15 @@ def _check_all_hashes():
13081319

13091320
if len(chunk) < chunk_size or current >= total:
13101321
break
1311-
except Exception as e:
1312-
raise e
13131322
finally:
13141323
await cdn_session.stop()
13151324
except pyrogram.StopTransmission:
13161325
raise
13171326
except (FloodWaitX, FloodPremiumWaitX):
13181327
raise
13191328
except Exception as e:
1320-
log.exception(e)
1329+
log.exception("get_file error: %s", e)
1330+
raise
13211331

13221332
async def get_session(
13231333
self,
@@ -1507,7 +1517,8 @@ async def get_dc_option(
15071517
is_cdn: bool = False,
15081518
ipv6: bool = False
15091519
) -> "raw.types.DcOption":
1510-
self.__config = await self.invoke(raw.functions.help.GetConfig())
1520+
if self.__config is None:
1521+
self.__config = await self.invoke(raw.functions.help.GetConfig())
15111522

15121523
if dc_id is None:
15131524
dc_id = self.__config.this_dc

pyrogram/methods/auth/connect.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,17 @@ async def connect(
4141

4242
await self.load_session()
4343

44-
self.session = await self.get_session(
45-
server_address=await self.storage.server_address(),
46-
port=await self.storage.port(),
47-
export_authorization=False,
48-
temporary=True
49-
)
44+
try:
45+
self.session = await self.get_session(
46+
server_address=await self.storage.server_address(),
47+
port=await self.storage.port(),
48+
export_authorization=False,
49+
temporary=True
50+
)
51+
except Exception:
52+
await self.storage.close()
53+
raise
54+
5055
self.is_connected = True
5156

5257
is_ipv6_session = ":" in await self.storage.server_address()

pyrogram/methods/auth/terminate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,20 @@ async def terminate(
5757

5858
self.media_sessions.clear()
5959

60+
for aux_session in self.sessions.values():
61+
await aux_session.stop()
62+
63+
self.sessions.clear()
64+
self._session_futures.clear()
65+
6066
self.updates_watchdog_event.set()
6167

6268
if self.updates_watchdog_task is not None:
6369
await self.updates_watchdog_task
6470

6571
self.updates_watchdog_event.clear()
6672

73+
if hasattr(self, "executor") and self.executor:
74+
self.executor.shutdown(wait=False)
75+
6776
self.is_initialized = False

0 commit comments

Comments
 (0)