|
17 | 17 | # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. |
18 | 18 |
|
19 | 19 | import asyncio |
| 20 | +from contextlib import suppress |
20 | 21 | import functools |
21 | 22 | import inspect |
22 | 23 | import logging |
@@ -395,6 +396,7 @@ def __init__( |
395 | 396 | self.sessions = {} |
396 | 397 | self.media_sessions = {} |
397 | 398 | self.sessions_lock = asyncio.Lock() |
| 399 | + self._session_futures = {} |
398 | 400 |
|
399 | 401 | self.save_file_semaphore = asyncio.Semaphore(self.max_concurrent_transmissions) |
400 | 402 | self.get_file_semaphore = asyncio.Semaphore(self.max_concurrent_transmissions) |
@@ -1371,67 +1373,121 @@ async def get_session( |
1371 | 1373 | return self.session |
1372 | 1374 |
|
1373 | 1375 | sessions = self.media_sessions if is_media else self.sessions |
| 1376 | + session_key = (dc_id, is_media) |
1374 | 1377 |
|
1375 | | - if not temporary and sessions.get(dc_id): |
1376 | | - return sessions[dc_id] |
| 1378 | + creator = False |
1377 | 1379 |
|
1378 | | - if not server_address or not port: |
1379 | | - dc_option = await self.get_dc_option(dc_id, is_media=is_media, ipv6=self.ipv6, is_cdn=is_cdn) |
| 1380 | + if not temporary: |
| 1381 | + async with self.sessions_lock: |
| 1382 | + if sessions.get(dc_id): |
| 1383 | + return sessions[dc_id] |
1380 | 1384 |
|
1381 | | - server_address = server_address or dc_option.ip_address |
1382 | | - port = port or dc_option.port |
| 1385 | + pending_session = self._session_futures.get(session_key) |
1383 | 1386 |
|
1384 | | - if is_media: |
1385 | | - auth_key = (await self.get_session(dc_id)).auth_key |
1386 | | - else: |
1387 | | - if not is_current_dc: |
1388 | | - auth_key = await Auth( |
1389 | | - self, |
1390 | | - dc_id, |
1391 | | - server_address, |
1392 | | - port, |
1393 | | - await self.storage.test_mode() |
1394 | | - ).create() |
1395 | | - else: |
1396 | | - auth_key = await self.storage.auth_key() |
1397 | | - |
1398 | | - session = Session( |
1399 | | - self, |
1400 | | - dc_id, |
1401 | | - server_address, |
1402 | | - port, |
1403 | | - auth_key, |
1404 | | - await self.storage.test_mode(), |
1405 | | - is_media=is_media |
1406 | | - ) |
| 1387 | + if pending_session is None: |
| 1388 | + pending_session = self.loop.create_future() |
| 1389 | + pending_session.add_done_callback( |
| 1390 | + lambda future: None if future.cancelled() else future.exception() |
| 1391 | + ) |
| 1392 | + self._session_futures[session_key] = pending_session |
| 1393 | + creator = True |
1407 | 1394 |
|
1408 | | - if not temporary: |
1409 | | - sessions[dc_id] = session |
| 1395 | + if not creator: |
| 1396 | + return await pending_session |
1410 | 1397 |
|
1411 | | - await session.start() |
| 1398 | + session = None |
1412 | 1399 |
|
1413 | | - if not is_current_dc and export_authorization: |
1414 | | - for _ in range(3): |
1415 | | - exported_auth = await self.invoke( |
1416 | | - raw.functions.auth.ExportAuthorization( |
1417 | | - dc_id=dc_id |
1418 | | - ) |
1419 | | - ) |
| 1400 | + try: |
| 1401 | + if not server_address or not port: |
| 1402 | + dc_option = await self.get_dc_option(dc_id, is_media=is_media, ipv6=self.ipv6, is_cdn=is_cdn) |
1420 | 1403 |
|
1421 | | - try: |
1422 | | - await session.invoke( |
1423 | | - raw.functions.auth.ImportAuthorization( |
1424 | | - id=exported_auth.id, |
1425 | | - bytes=exported_auth.bytes |
| 1404 | + server_address = server_address or dc_option.ip_address |
| 1405 | + port = port or dc_option.port |
| 1406 | + |
| 1407 | + if is_media: |
| 1408 | + auth_key = (await self.get_session(dc_id)).auth_key |
| 1409 | + else: |
| 1410 | + if not is_current_dc: |
| 1411 | + auth_key = await Auth( |
| 1412 | + self, |
| 1413 | + dc_id, |
| 1414 | + server_address, |
| 1415 | + port, |
| 1416 | + await self.storage.test_mode() |
| 1417 | + ).create() |
| 1418 | + else: |
| 1419 | + auth_key = await self.storage.auth_key() |
| 1420 | + |
| 1421 | + session = Session( |
| 1422 | + self, |
| 1423 | + dc_id, |
| 1424 | + server_address, |
| 1425 | + port, |
| 1426 | + auth_key, |
| 1427 | + await self.storage.test_mode(), |
| 1428 | + is_media=is_media |
| 1429 | + ) |
| 1430 | + |
| 1431 | + await session.start() |
| 1432 | + |
| 1433 | + try: |
| 1434 | + await asyncio.wait_for(session.is_started.wait(), Session.WAIT_TIMEOUT) |
| 1435 | + except asyncio.TimeoutError as e: |
| 1436 | + with suppress(Exception): |
| 1437 | + await session.stop() |
| 1438 | + session = None |
| 1439 | + raise ConnectionError(f"Failed to start session for DC{dc_id}") from e |
| 1440 | + |
| 1441 | + if not is_current_dc and export_authorization: |
| 1442 | + for _ in range(3): |
| 1443 | + exported_auth = await self.invoke( |
| 1444 | + raw.functions.auth.ExportAuthorization( |
| 1445 | + dc_id=dc_id |
1426 | 1446 | ) |
1427 | 1447 | ) |
1428 | | - except AuthBytesInvalid: |
1429 | | - continue |
| 1448 | + |
| 1449 | + try: |
| 1450 | + await session.invoke( |
| 1451 | + raw.functions.auth.ImportAuthorization( |
| 1452 | + id=exported_auth.id, |
| 1453 | + bytes=exported_auth.bytes |
| 1454 | + ) |
| 1455 | + ) |
| 1456 | + except AuthBytesInvalid: |
| 1457 | + continue |
| 1458 | + else: |
| 1459 | + break |
1430 | 1460 | else: |
1431 | | - break |
1432 | | - else: |
1433 | | - await session.stop() |
1434 | | - raise AuthBytesInvalid |
| 1461 | + await session.stop() |
| 1462 | + session = None |
| 1463 | + raise AuthBytesInvalid |
| 1464 | + |
| 1465 | + if not temporary: |
| 1466 | + async with self.sessions_lock: |
| 1467 | + cached_session = sessions.get(dc_id) |
| 1468 | + |
| 1469 | + if cached_session is None: |
| 1470 | + sessions[dc_id] = session |
| 1471 | + else: |
| 1472 | + session = cached_session |
| 1473 | + |
| 1474 | + pending_session = self._session_futures.pop(session_key, None) |
| 1475 | + |
| 1476 | + if pending_session is not None and not pending_session.done(): |
| 1477 | + pending_session.set_result(session) |
| 1478 | + except Exception as e: |
| 1479 | + if session is not None: |
| 1480 | + with suppress(Exception): |
| 1481 | + await session.stop() |
| 1482 | + |
| 1483 | + if not temporary: |
| 1484 | + async with self.sessions_lock: |
| 1485 | + pending_session = self._session_futures.pop(session_key, None) |
| 1486 | + |
| 1487 | + if pending_session is not None and not pending_session.done(): |
| 1488 | + pending_session.set_exception(e) |
| 1489 | + |
| 1490 | + raise |
1435 | 1491 |
|
1436 | 1492 | return session |
1437 | 1493 |
|
|
0 commit comments