@@ -366,6 +366,11 @@ def __init__(
366366 self .connection_factory = connection_factory
367367 self .protocol_factory = protocol_factory
368368
369+ if self .workers < 1 :
370+ raise ValueError (f"workers must be >= 1, got { self .workers } " )
371+ if self .max_concurrent_transmissions < 1 :
372+ raise ValueError (f"max_concurrent_transmissions must be >= 1, got { self .max_concurrent_transmissions } " )
373+
369374 self .executor = ThreadPoolExecutor (self .workers , thread_name_prefix = "Handler" )
370375
371376 self .storage : Storage
@@ -442,14 +447,15 @@ def loop(self) -> asyncio.AbstractEventLoop:
442447 def loop (self , value : asyncio .AbstractEventLoop ):
443448 self ._loop = value
444449
445- self .listeners = {listener_type : [] for listener_type in pyrogram .enums .ListenerTypes }
450+ if not hasattr (self , "listeners" ):
451+ self .listeners = {listener_type : [] for listener_type in pyrogram .enums .ListenerTypes }
446452
447453 def __enter__ (self ):
448- return self .start ()
454+ return utils . get_event_loop (). run_until_complete ( self .start () )
449455
450456 def __exit__ (self , * args ):
451457 try :
452- self .stop ()
458+ utils . get_event_loop (). run_until_complete ( self .stop () )
453459 except ConnectionError :
454460 pass
455461
@@ -646,10 +652,10 @@ async def authorize(self) -> User:
646652
647653 return signed_up
648654
649- async def authorize_qr (self , except_ids : List [int ] = [] ) -> "User" :
655+ async def authorize_qr (self , except_ids : List [int ] = None ) -> "User" :
650656 from qrcode import QRCode
651657
652- qr_login = QRLogin (self , except_ids )
658+ qr_login = QRLogin (self , except_ids or [] )
653659 await qr_login .recreate ()
654660
655661 qr = QRCode (version = 1 )
@@ -896,7 +902,9 @@ async def handle_updates(self, updates):
896902 )
897903 )
898904
899- if diff .new_messages :
905+ if isinstance (diff , (raw .types .updates .DifferenceEmpty , raw .types .updates .DifferenceTooLong )):
906+ pass
907+ elif getattr (diff , "new_messages" , None ):
900908 self .dispatcher .updates_queue .put_nowait ((
901909 raw .types .UpdateNewMessage (
902910 message = diff .new_messages [0 ],
@@ -906,9 +914,8 @@ async def handle_updates(self, updates):
906914 {u .id : u for u in diff .users },
907915 {c .id : c for c in diff .chats }
908916 ))
909- else :
910- if diff .other_updates : # The other_updates list can be empty
911- self .dispatcher .updates_queue .put_nowait ((diff .other_updates [0 ], {}, {}))
917+ elif getattr (diff , "other_updates" , None ):
918+ self .dispatcher .updates_queue .put_nowait ((diff .other_updates [0 ], {}, {}))
912919 elif isinstance (updates , raw .types .UpdateShort ):
913920 self .dispatcher .updates_queue .put_nowait ((updates .update , {}, {}))
914921 elif isinstance (updates , raw .types .UpdatesTooLong ):
@@ -1106,17 +1113,21 @@ async def handle_download(self, packet):
11061113 try :
11071114 async for chunk in self .get_file (file_id , file_size , 0 , 0 , progress , progress_args ):
11081115 file .write (chunk )
1109- except BaseException as e :
1116+ except ( SystemExit , KeyboardInterrupt , GeneratorExit ) :
11101117 if not in_memory :
11111118 file .close ()
11121119 os .remove (temp_file_path )
1113-
1114- if isinstance (e , asyncio .CancelledError ):
1115- raise e
1116-
1117- if isinstance (e , (FloodWaitX , FloodPremiumWaitX )):
1118- raise e
1119-
1120+ raise
1121+ except (asyncio .CancelledError , FloodWaitX , FloodPremiumWaitX ):
1122+ if not in_memory :
1123+ file .close ()
1124+ os .remove (temp_file_path )
1125+ raise
1126+ except Exception as e :
1127+ if not in_memory :
1128+ file .close ()
1129+ os .remove (temp_file_path )
1130+ log .exception ("Download failed: %s" , e )
11201131 return None
11211132 else :
11221133 if in_memory :
@@ -1279,8 +1290,8 @@ async def get_file(
12791290
12801291 # https://core.telegram.org/cdn#verifying-files
12811292 def _check_all_hashes ():
1282- for i , h in enumerate ( hashes ) :
1283- cdn_chunk = decrypted_chunk [h .limit * i : h .limit * ( i + 1 ) ]
1293+ for h in hashes :
1294+ cdn_chunk = decrypted_chunk [h .offset - offset_bytes : h .offset - offset_bytes + h . limit ]
12841295 CDNFileHashMismatch .check (
12851296 h .hash == sha256 (cdn_chunk ).digest (),
12861297 "h.hash == sha256(cdn_chunk).digest()"
@@ -1308,16 +1319,15 @@ def _check_all_hashes():
13081319
13091320 if len (chunk ) < chunk_size or current >= total :
13101321 break
1311- except Exception as e :
1312- raise e
13131322 finally :
13141323 await cdn_session .stop ()
13151324 except pyrogram .StopTransmission :
13161325 raise
13171326 except (FloodWaitX , FloodPremiumWaitX ):
13181327 raise
13191328 except Exception as e :
1320- log .exception (e )
1329+ log .exception ("get_file error: %s" , e )
1330+ raise
13211331
13221332 async def get_session (
13231333 self ,
@@ -1507,7 +1517,8 @@ async def get_dc_option(
15071517 is_cdn : bool = False ,
15081518 ipv6 : bool = False
15091519 ) -> "raw.types.DcOption" :
1510- self .__config = await self .invoke (raw .functions .help .GetConfig ())
1520+ if self .__config is None :
1521+ self .__config = await self .invoke (raw .functions .help .GetConfig ())
15111522
15121523 if dc_id is None :
15131524 dc_id = self .__config .this_dc
0 commit comments