Skip to content

Commit e75cc8c

Browse files
committed
refactor(session): harden mtproto auth and optimize session handling
1 parent 8639dd0 commit e75cc8c

4 files changed

Lines changed: 153 additions & 134 deletions

File tree

pyrogram/session/auth.py

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,12 @@ def __init__(
6060

6161
@staticmethod
6262
def pack(data: TLObject, server_time: float) -> bytes:
63+
payload = data.write()
6364
return (
6465
bytes(8)
6566
+ Long(int(server_time * (2**32)) & ~0b11)
66-
+ Int(len(data.write()))
67-
+ data.write()
67+
+ Int(len(payload))
68+
+ payload
6869
)
6970

7071
@staticmethod
@@ -150,7 +151,7 @@ async def create(self):
150151

151152
log.debug("Done encrypt data with RSA")
152153

153-
# Step 5. TODO: Handle "server_DH_params_fail". Code assumes response is ok
154+
# Step 5
154155
log.debug("Send req_DH_params")
155156
server_dh_params = await self.invoke(
156157
raw.functions.ReqDHParams(
@@ -163,6 +164,9 @@ async def create(self):
163164
)
164165
)
165166

167+
if isinstance(server_dh_params, raw.types.ServerDhParamsFail):
168+
raise Exception("Server DH params generation failed")
169+
166170
encrypted_answer = server_dh_params.encrypted_answer
167171

168172
server_nonce = server_nonce.to_bytes(16, "little", signed=True)
@@ -189,21 +193,49 @@ async def create(self):
189193

190194
dh_prime = int.from_bytes(server_dh_inner_data.dh_prime, "big")
191195
delta_time = server_dh_inner_data.server_time - time.time()
192-
193196
log.debug("Delta time: %s", round(delta_time, 3))
194197

195-
# Step 6
196198
g = server_dh_inner_data.g
199+
g_a = int.from_bytes(server_dh_inner_data.g_a, "big")
200+
201+
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
202+
answer = server_dh_inner_data.write()
203+
SecurityCheckMismatch.check(
204+
answer_with_hash[:20] == sha1(answer).digest(),
205+
"answer_with_hash[:20] == sha1(answer).digest()"
206+
)
207+
log.debug("SHA1 hash values check: OK")
208+
209+
# Validate DH parameters BEFORE expensive modular exponentiation
210+
SecurityCheckMismatch.check(dh_prime == prime.CURRENT_DH_PRIME, "dh_prime == prime.CURRENT_DH_PRIME")
211+
SecurityCheckMismatch.check(1 < g < dh_prime - 1, "1 < g < dh_prime - 1")
212+
SecurityCheckMismatch.check(1 < g_a < dh_prime - 1, "1 < g_a < dh_prime - 1")
213+
SecurityCheckMismatch.check(
214+
2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64),
215+
"2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64)"
216+
)
217+
log.debug("DH parameters and g_a validation: OK")
218+
219+
# Step 6 — now safe to compute
197220
b = int.from_bytes(urandom(256), "big")
198-
g_b = pow(g, b, dh_prime).to_bytes(256, "big")
221+
g_b = pow(g, b, dh_prime)
222+
223+
SecurityCheckMismatch.check(1 < g_b < dh_prime - 1, "1 < g_b < dh_prime - 1")
224+
SecurityCheckMismatch.check(
225+
2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64),
226+
"2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64)"
227+
)
228+
log.debug("g_b validation: OK")
229+
230+
g_b_bytes = g_b.to_bytes(256, "big")
199231

200232
retry_id = 0
201233

202234
data = raw.types.ClientDHInnerData(
203235
nonce=nonce,
204236
server_nonce=server_nonce,
205237
retry_id=retry_id,
206-
g_b=g_b
238+
g_b=g_b_bytes
207239
).write()
208240

209241
sha = sha1(data).digest()
@@ -220,56 +252,25 @@ async def create(self):
220252
)
221253
)
222254

223-
# TODO: Handle "auth_key_aux_hash" if the previous step fails
255+
if isinstance(set_client_dh_params_answer, raw.types.DhGenFail):
256+
raise Exception("DH key generation failed (dh_gen_fail)")
257+
258+
if isinstance(set_client_dh_params_answer, raw.types.DhGenRetry):
259+
raise Exception("DH key generation requires retry (dh_gen_retry)")
224260

225261
# Step 7; Step 8
226-
g_a = int.from_bytes(server_dh_inner_data.g_a, "big")
227262
auth_key = pow(g_a, b, dh_prime).to_bytes(256, "big")
228263
server_nonce = server_nonce.to_bytes(16, "little", signed=True)
229264

230-
# TODO: Handle errors
231-
232-
#######################
233-
# Security checks
234-
#######################
235-
236-
SecurityCheckMismatch.check(dh_prime == prime.CURRENT_DH_PRIME, "dh_prime == prime.CURRENT_DH_PRIME")
237-
log.debug("DH parameters check: OK")
238-
239-
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
240-
g_b = int.from_bytes(g_b, "big")
241-
SecurityCheckMismatch.check(1 < g < dh_prime - 1, "1 < g < dh_prime - 1")
242-
SecurityCheckMismatch.check(1 < g_a < dh_prime - 1, "1 < g_a < dh_prime - 1")
243-
SecurityCheckMismatch.check(1 < g_b < dh_prime - 1, "1 < g_b < dh_prime - 1")
244-
SecurityCheckMismatch.check(
245-
2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64),
246-
"2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64)"
247-
)
248-
SecurityCheckMismatch.check(
249-
2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64),
250-
"2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64)"
251-
)
252-
log.debug("g_a and g_b validation: OK")
253-
254-
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
255-
answer = server_dh_inner_data.write() # Call .write() to remove padding
256-
SecurityCheckMismatch.check(
257-
answer_with_hash[:20] == sha1(answer).digest(),
258-
"answer_with_hash[:20] == sha1(answer).digest()"
259-
)
260-
log.debug("SHA1 hash values check: OK")
261-
262265
# https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
263-
# 1st message
264266
SecurityCheckMismatch.check(nonce == res_pq.nonce, "nonce == res_pq.nonce")
265-
# 2nd message
267+
266268
server_nonce = int.from_bytes(server_nonce, "little", signed=True)
267269
SecurityCheckMismatch.check(nonce == server_dh_params.nonce, "nonce == server_dh_params.nonce")
268270
SecurityCheckMismatch.check(
269271
server_nonce == server_dh_params.server_nonce,
270272
"server_nonce == server_dh_params.server_nonce"
271273
)
272-
# 3rd message
273274
SecurityCheckMismatch.check(
274275
nonce == set_client_dh_params_answer.nonce,
275276
"nonce == set_client_dh_params_answer.nonce"
@@ -287,13 +288,12 @@ async def create(self):
287288
log.debug("Server salt: %s", int.from_bytes(server_salt, "little"))
288289

289290
log.info("Done auth key exchange: %s", set_client_dh_params_answer.__class__.__name__)
290-
except ConnectionError as e:
291+
except ConnectionError:
291292
log.info("Unable to connect due to network issues. Retrying...")
292-
# Treat like transient network error: retry according to MAX_RETRIES
293293
if retries_left:
294294
retries_left -= 1
295295
else:
296-
raise e
296+
raise
297297
await asyncio.sleep(1)
298298
continue
299299
except Exception as e:
@@ -302,7 +302,7 @@ async def create(self):
302302
if retries_left > 0:
303303
retries_left -= 1
304304
elif retries_left == 0:
305-
raise e
305+
raise
306306

307307
await asyncio.sleep(1)
308308
continue

pyrogram/session/internals/msg_factory.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,11 @@ def __init__(self, client: "pyrogram.Client"):
3030
self.client = client
3131

3232
self._last_msg_id = 0
33-
34-
self._msg_id_lock = asyncio.Lock()
35-
self._seq_no_lock = asyncio.Lock()
36-
33+
self._lock = asyncio.Lock()
3734
self._content_related_messages_sent = 0
3835

3936
async def allocate_message_identity(self) -> int:
40-
async with self._msg_id_lock:
37+
async with self._lock:
4138
base_msg_id = int(self.client.server_time * (2**32)) & ~0b11
4239

4340
if base_msg_id <= self._last_msg_id:
@@ -47,19 +44,19 @@ async def allocate_message_identity(self) -> int:
4744

4845
return base_msg_id
4946

50-
async def allocate_message_sequence(self, is_content_related: bool) -> int:
51-
async with self._seq_no_lock:
52-
seq_no = (self._content_related_messages_sent * 2) + (1 if is_content_related else 0)
47+
async def create(self, body: TLObject) -> Message:
48+
async with self._lock:
49+
base_msg_id = int(self.client.server_time * (2**32)) & ~0b11
5350

54-
if is_content_related:
55-
self._content_related_messages_sent += 1
51+
if base_msg_id <= self._last_msg_id:
52+
base_msg_id = self._last_msg_id + 4
5653

57-
return seq_no
54+
self._last_msg_id = base_msg_id
5855

59-
async def create(self, body: TLObject) -> Message:
60-
msg_id = await self.allocate_message_identity()
56+
is_content_related = not isinstance(body, (Ping, HttpWait, MsgsAck, MsgContainer))
57+
seq_no = (self._content_related_messages_sent * 2) + (1 if is_content_related else 0)
6158

62-
is_content_related = not isinstance(body, (Ping, HttpWait, MsgsAck, MsgContainer))
63-
seq_no = await self.allocate_message_sequence(is_content_related)
59+
if is_content_related:
60+
self._content_related_messages_sent += 1
6461

65-
return Message(body, msg_id, seq_no, len(body))
62+
return Message(body, base_msg_id, seq_no, len(body))

0 commit comments

Comments
 (0)