Skip to content

Commit b6c4bff

Browse files
committed
fix(session): handle fatal errors and improve session concurrency
1 parent 0d049c2 commit b6c4bff

3 files changed

Lines changed: 138 additions & 54 deletions

File tree

pyrogram/client.py

Lines changed: 106 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
1818

1919
import asyncio
20+
from contextlib import suppress
2021
import functools
2122
import inspect
2223
import logging
@@ -395,6 +396,7 @@ def __init__(
395396
self.sessions = {}
396397
self.media_sessions = {}
397398
self.sessions_lock = asyncio.Lock()
399+
self._session_futures = {}
398400

399401
self.save_file_semaphore = asyncio.Semaphore(self.max_concurrent_transmissions)
400402
self.get_file_semaphore = asyncio.Semaphore(self.max_concurrent_transmissions)
@@ -1371,67 +1373,121 @@ async def get_session(
13711373
return self.session
13721374

13731375
sessions = self.media_sessions if is_media else self.sessions
1376+
session_key = (dc_id, is_media)
13741377

1375-
if not temporary and sessions.get(dc_id):
1376-
return sessions[dc_id]
1378+
creator = False
13771379

1378-
if not server_address or not port:
1379-
dc_option = await self.get_dc_option(dc_id, is_media=is_media, ipv6=self.ipv6, is_cdn=is_cdn)
1380+
if not temporary:
1381+
async with self.sessions_lock:
1382+
if sessions.get(dc_id):
1383+
return sessions[dc_id]
13801384

1381-
server_address = server_address or dc_option.ip_address
1382-
port = port or dc_option.port
1385+
pending_session = self._session_futures.get(session_key)
13831386

1384-
if is_media:
1385-
auth_key = (await self.get_session(dc_id)).auth_key
1386-
else:
1387-
if not is_current_dc:
1388-
auth_key = await Auth(
1389-
self,
1390-
dc_id,
1391-
server_address,
1392-
port,
1393-
await self.storage.test_mode()
1394-
).create()
1395-
else:
1396-
auth_key = await self.storage.auth_key()
1397-
1398-
session = Session(
1399-
self,
1400-
dc_id,
1401-
server_address,
1402-
port,
1403-
auth_key,
1404-
await self.storage.test_mode(),
1405-
is_media=is_media
1406-
)
1387+
if pending_session is None:
1388+
pending_session = self.loop.create_future()
1389+
pending_session.add_done_callback(
1390+
lambda future: None if future.cancelled() else future.exception()
1391+
)
1392+
self._session_futures[session_key] = pending_session
1393+
creator = True
14071394

1408-
if not temporary:
1409-
sessions[dc_id] = session
1395+
if not creator:
1396+
return await pending_session
14101397

1411-
await session.start()
1398+
session = None
14121399

1413-
if not is_current_dc and export_authorization:
1414-
for _ in range(3):
1415-
exported_auth = await self.invoke(
1416-
raw.functions.auth.ExportAuthorization(
1417-
dc_id=dc_id
1418-
)
1419-
)
1400+
try:
1401+
if not server_address or not port:
1402+
dc_option = await self.get_dc_option(dc_id, is_media=is_media, ipv6=self.ipv6, is_cdn=is_cdn)
14201403

1421-
try:
1422-
await session.invoke(
1423-
raw.functions.auth.ImportAuthorization(
1424-
id=exported_auth.id,
1425-
bytes=exported_auth.bytes
1404+
server_address = server_address or dc_option.ip_address
1405+
port = port or dc_option.port
1406+
1407+
if is_media:
1408+
auth_key = (await self.get_session(dc_id)).auth_key
1409+
else:
1410+
if not is_current_dc:
1411+
auth_key = await Auth(
1412+
self,
1413+
dc_id,
1414+
server_address,
1415+
port,
1416+
await self.storage.test_mode()
1417+
).create()
1418+
else:
1419+
auth_key = await self.storage.auth_key()
1420+
1421+
session = Session(
1422+
self,
1423+
dc_id,
1424+
server_address,
1425+
port,
1426+
auth_key,
1427+
await self.storage.test_mode(),
1428+
is_media=is_media
1429+
)
1430+
1431+
await session.start()
1432+
1433+
try:
1434+
await asyncio.wait_for(session.is_started.wait(), Session.WAIT_TIMEOUT)
1435+
except asyncio.TimeoutError as e:
1436+
with suppress(Exception):
1437+
await session.stop()
1438+
session = None
1439+
raise ConnectionError(f"Failed to start session for DC{dc_id}") from e
1440+
1441+
if not is_current_dc and export_authorization:
1442+
for _ in range(3):
1443+
exported_auth = await self.invoke(
1444+
raw.functions.auth.ExportAuthorization(
1445+
dc_id=dc_id
14261446
)
14271447
)
1428-
except AuthBytesInvalid:
1429-
continue
1448+
1449+
try:
1450+
await session.invoke(
1451+
raw.functions.auth.ImportAuthorization(
1452+
id=exported_auth.id,
1453+
bytes=exported_auth.bytes
1454+
)
1455+
)
1456+
except AuthBytesInvalid:
1457+
continue
1458+
else:
1459+
break
14301460
else:
1431-
break
1432-
else:
1433-
await session.stop()
1434-
raise AuthBytesInvalid
1461+
await session.stop()
1462+
session = None
1463+
raise AuthBytesInvalid
1464+
1465+
if not temporary:
1466+
async with self.sessions_lock:
1467+
cached_session = sessions.get(dc_id)
1468+
1469+
if cached_session is None:
1470+
sessions[dc_id] = session
1471+
else:
1472+
session = cached_session
1473+
1474+
pending_session = self._session_futures.pop(session_key, None)
1475+
1476+
if pending_session is not None and not pending_session.done():
1477+
pending_session.set_result(session)
1478+
except Exception as e:
1479+
if session is not None:
1480+
with suppress(Exception):
1481+
await session.stop()
1482+
1483+
if not temporary:
1484+
async with self.sessions_lock:
1485+
pending_session = self._session_futures.pop(session_key, None)
1486+
1487+
if pending_session is not None and not pending_session.done():
1488+
pending_session.set_exception(e)
1489+
1490+
raise
14351491

14361492
return session
14371493

pyrogram/methods/auth/send_phone_number_code.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import pyrogram
2424
from pyrogram import enums, raw, types
25-
from pyrogram.errors import NetworkMigrate, PhoneMigrate
25+
from pyrogram.errors import NetworkMigrateX, PhoneMigrateX
2626

2727
log = logging.getLogger(__name__)
2828

@@ -165,7 +165,7 @@ async def send_phone_number_code(
165165
)
166166

167167
r = await self.invoke(rpc, recaptcha_token=recaptcha_token)
168-
except (PhoneMigrate, NetworkMigrate) as e:
168+
except (PhoneMigrateX, NetworkMigrateX) as e:
169169
dc_option = await self.get_dc_option(e.value, ipv6=self.ipv6)
170170
await self.session.stop()
171171

pyrogram/session/session.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(
140140

141141
self.is_started = asyncio.Event()
142142
self.restart_lock = asyncio.Lock()
143+
self.fatal_error: Optional[BaseException] = None
143144

144145
@property
145146
def state(self) -> SessionState:
@@ -160,6 +161,7 @@ async def start(self):
160161
return
161162

162163
await self._set_state(SessionState.STARTING)
164+
self.fatal_error = None
163165

164166
self.connection = self.client.connection_factory(
165167
dc_id=self.dc_id,
@@ -289,6 +291,16 @@ async def restart(self):
289291
await self.stop()
290292
await self.start()
291293

294+
def _fail_pending_results(self, error: BaseException) -> None:
295+
for result in self.results.values():
296+
if result.value is None:
297+
result.value = error
298+
result.event.set()
299+
300+
def _set_fatal_error(self, error: BaseException) -> None:
301+
self.fatal_error = error
302+
self._fail_pending_results(error)
303+
292304
async def handle_packet(self, packet):
293305
try:
294306
data = await self.client.loop.run_in_executor(
@@ -299,7 +311,7 @@ async def handle_packet(self, packet):
299311
self.auth_key,
300312
self.auth_key_id
301313
)
302-
except ValueError as e:
314+
except (ConnectionError, SecurityCheckMismatch, ValueError) as e:
303315
log.debug(e)
304316
log.info("Restarting session due to - %s - %s", e.__class__.__name__, e)
305317
self.client.loop.create_task(self.restart())
@@ -443,9 +455,10 @@ async def recv_worker(self):
443455
if packet:
444456
error_code = -Int.read(BytesIO(packet))
445457
error_msg = "unknown error"
458+
transport_error = None
446459

447460
if error_code == 404:
448-
raise AuthKeyNotFound(
461+
transport_error = AuthKeyNotFound(
449462
"Auth key not found in the system. Try again or delete your session file "
450463
"and log in again with your phone number or bot token."
451464
)
@@ -460,10 +473,19 @@ async def recv_worker(self):
460473
"Invalid data center. Please check your configuration."
461474
)
462475
except TransportError as e:
476+
transport_error = e
463477
error_msg = str(e)
478+
else:
479+
if transport_error is not None:
480+
error_msg = str(transport_error)
464481

465482
log.warning("Server sent transport error: %s (%s)", error_code, error_msg)
466483

484+
if isinstance(transport_error, AuthKeyNotFound):
485+
self._set_fatal_error(transport_error)
486+
self.client.loop.create_task(self.stop())
487+
break
488+
467489

468490
if self.is_started.is_set():
469491
if packet:
@@ -483,6 +505,9 @@ async def recv_worker(self):
483505
async def send(
484506
self, data: TLObject, wait_response: bool = True, timeout: float = WAIT_TIMEOUT
485507
):
508+
if self.fatal_error is not None:
509+
raise self.fatal_error
510+
486511
message = await self.msg_factory.create(data)
487512
msg_id = message.msg_id
488513

@@ -518,6 +543,9 @@ async def send(
518543
if result is None:
519544
raise TimeoutError("Request timed out")
520545

546+
if isinstance(result, BaseException):
547+
raise result
548+
521549
if isinstance(result, raw.types.RpcError):
522550
if isinstance(
523551
data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)

0 commit comments

Comments
 (0)