Skip to content

Commit a84626a

Browse files
committed
feat: access_hash gathering
1 parent 894e2d2 commit a84626a

3 files changed

Lines changed: 51 additions & 48 deletions

File tree

userbot/src/db/models.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,17 @@
1111
class Account(Base):
1212
__tablename__ = 'accounts'
1313
account_id = Column(Integer, primary_key=True)
14-
user_telegram_id = Column(BIGINT, unique=True, nullable=True) # Can be null until first login
14+
user_telegram_id = Column(BIGINT, unique=True, nullable=True)
1515
api_id = Column(BYTEA, nullable=False)
1616
api_hash = Column(BYTEA, nullable=False)
1717
account_name = Column(String(255), unique=True, nullable=False)
1818
is_enabled = Column(Boolean, default=True, nullable=False)
1919
lang_code = Column(String(10), default='ru', nullable=False)
2020

21-
# Device info
2221
device_model = Column(TEXT)
2322
system_version = Column(TEXT)
2423
app_version = Column(TEXT)
2524

26-
# Proxy info
2725
proxy_type = Column(TEXT)
2826
proxy_ip = Column(TEXT)
2927
proxy_port = Column(Integer)
@@ -46,11 +44,18 @@ class Session(Base):
4644
server_address = Column(TEXT)
4745
port = Column(Integer)
4846
auth_key_data = Column(BYTEA)
47+
48+
# New fields for self-entity caching
49+
self_user_id = Column(BIGINT)
50+
self_access_hash = Column(BIGINT)
51+
52+
# Update state fields
4953
pts = Column(Integer)
5054
qts = Column(Integer)
5155
date = Column(BIGINT)
5256
seq = Column(Integer)
5357
takeout_id = Column(BIGINT)
58+
5459
last_used_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False)
5560
created_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False)
5661

userbot/src/db_manager.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,37 +28,26 @@ async def add_account(db: AsyncSession, account_name: str, api_id: str, api_hash
2828
)
2929
db.add(new_account)
3030
await db.flush()
31-
logger.info(f"Added account '{account_name}' with ID: {new_account.account_id}")
3231
return new_account
3332
except IntegrityError:
34-
logger.warning(f"Account with name '{account_name}' or user_id '{user_telegram_id}' already exists.")
35-
await db.rollback()
36-
return None
33+
await db.rollback(); return None
3734
except Exception as e:
38-
logger.error(f"Error adding account '{account_name}': {e}")
39-
await db.rollback()
40-
raise
35+
await db.rollback(); raise e
4136

4237
async def get_account(db: AsyncSession, account_name: str) -> Optional[Account]:
43-
result = await db.execute(select(Account).where(Account.account_name == account_name))
44-
return result.scalars().first()
38+
result = await db.execute(select(Account).where(Account.account_name == account_name)); return result.scalars().first()
4539

4640
async def get_account_by_id(db: AsyncSession, account_id: int) -> Optional[Account]:
47-
"""Retrieves an account by its primary key ID."""
48-
result = await db.execute(select(Account).where(Account.account_id == account_id))
49-
return result.scalars().first()
41+
result = await db.execute(select(Account).where(Account.account_id == account_id)); return result.scalars().first()
5042

5143
async def get_account_by_user_id(db: AsyncSession, user_id: int) -> Optional[Account]:
52-
result = await db.execute(select(Account).where(Account.user_telegram_id == user_id))
53-
return result.scalars().first()
44+
result = await db.execute(select(Account).where(Account.user_telegram_id == user_id)); return result.scalars().first()
5445

5546
async def get_all_accounts(db: AsyncSession) -> List[Account]:
56-
result = await db.execute(select(Account).options(selectinload(Account.session)).order_by(Account.account_id))
57-
return result.scalars().all()
47+
result = await db.execute(select(Account).options(selectinload(Account.session)).order_by(Account.account_id)); return result.scalars().all()
5848

5949
async def get_all_active_accounts(db: AsyncSession) -> List[Account]:
60-
result = await db.execute(select(Account).where(Account.is_enabled == True))
61-
return result.scalars().all()
50+
result = await db.execute(select(Account).where(Account.is_enabled == True)); return result.scalars().all()
6251

6352
async def delete_account(db: AsyncSession, account_name: str) -> bool:
6453
account = await get_account(db, account_name)

userbot/src/db_session.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from telethon.sessions.abstract import Session
66
from telethon.crypto import AuthKey
7-
from telethon.tl.types import InputPeerSelf
7+
from telethon.tl.types import User, InputPeerUser
88

99
import userbot.src.db_manager as db_manager
1010
from userbot.src.db.session import get_db
@@ -14,17 +14,12 @@
1414
class DbSession(Session):
1515
"""
1616
A Telethon session class that stores session data in a PostgreSQL database.
17-
This implementation correctly provides all the abstract methods and properties
18-
required by Telethon's base Session class.
17+
It implements entity processing to cache the current user's ID and access_hash,
18+
which is crucial for the client's startup and authorization checks.
1919
"""
2020

2121
def __init__(self, account_id: int):
22-
"""
23-
Initializes the database-backed session.
24-
25-
Args:
26-
account_id (int): The unique identifier for the account this session belongs to.
27-
"""
22+
"""Initializes the database-backed session."""
2823
super().__init__()
2924
if not isinstance(account_id, int):
3025
raise ValueError("DbSession requires a valid integer account_id.")
@@ -36,16 +31,18 @@ def __init__(self, account_id: int):
3631
self._server_address: Optional[str] = None
3732
self._port: int = 443
3833
self._takeout_id: Optional[int] = None
34+
35+
# In-memory cache for the self user entity
36+
self._self_user_id: Optional[int] = None
37+
self._self_access_hash: Optional[int] = None
3938

4039
self._pts: Optional[int] = None
4140
self._qts: Optional[int] = None
4241
self._date: Optional[int] = None
4342
self._seq: Optional[int] = None
4443

4544
async def load(self) -> None:
46-
"""
47-
Loads the session data for the current account_id from the database.
48-
"""
45+
"""Loads session data, including the self-user cache, from the database."""
4946
logger.debug(f"Attempting to load session for account_id: {self.account_id}")
5047
async with get_db() as db:
5148
session_data = await db_manager.get_session(db, self.account_id)
@@ -61,15 +58,20 @@ async def load(self) -> None:
6158
self._qts = session_data.qts
6259
self._date = session_data.date
6360
self._seq = session_data.seq
61+
62+
# Load self entity from cache
63+
self._self_user_id = session_data.self_user_id
64+
self._self_access_hash = session_data.self_access_hash
6465
else:
65-
logger.info(f"No session data in DB for account_id: {self.account_id}. New login required.")
66+
logger.info(f"No session data in DB for account_id: {self.account_id}.")
6667

6768
async def save(self) -> None:
68-
"""Saves the current session data to the database."""
69+
"""Saves session data, including the self-user cache, to the database."""
6970
session_data = {
7071
"account_id": self.account_id, "dc_id": self._dc_id,
7172
"server_address": self._server_address, "port": self._port,
7273
"auth_key_data": self._auth_key.key if self._auth_key else None,
74+
"self_user_id": self._self_user_id, "self_access_hash": self._self_access_hash,
7375
"pts": self._pts, "qts": self._qts, "date": self._date,
7476
"seq": self._seq, "takeout_id": self._takeout_id,
7577
}
@@ -78,7 +80,6 @@ async def save(self) -> None:
7880
logger.info(f"Session saved for account_id: {self.account_id}")
7981

8082
async def delete(self) -> None:
81-
"""Deletes the session for the current account from the database."""
8283
async with get_db() as db:
8384
await db_manager.delete_session(db, self.account_id)
8485
self._auth_key = None; self._dc_id = 0
@@ -108,8 +109,7 @@ def get_update_state(self, entity_id: int) -> Optional[Tuple[int, int, int, int,
108109
def set_update_state(self, entity_id: int, state: Any):
109110
if isinstance(state.date, datetime):
110111
date_ts = int(state.date.replace(tzinfo=timezone.utc).timestamp())
111-
else:
112-
date_ts = int(state.date)
112+
else: date_ts = int(state.date)
113113
self._pts, self._qts, self._date, self._seq = state.pts, state.qts, date_ts, state.seq
114114

115115
async def close(self) -> None: pass
@@ -118,18 +118,27 @@ def get_update_states(self) -> List[Tuple[int, int, int, int, int, int]]:
118118
if self._pts is None: return []
119119
return [(0, self._pts, self._qts, self._date, self._seq, 0)]
120120

121-
def process_entities(self, tlo: object) -> None: pass
122-
123121
def get_input_entity(self, key: Any) -> Any:
122+
if key == 0:
123+
if self._self_user_id and self._self_access_hash:
124+
return InputPeerUser(self._self_user_id, self._self_access_hash)
125+
raise KeyError("Entity not found in DbSession cache. It should be populated by process_entities.")
126+
127+
def process_entities(self, tlo: object) -> None:
124128
"""
125-
Returns InputPeerSelf() if key is 0, which tells Telethon to use the
126-
currently authorized user. This is the correct way to handle "self" lookups.
129+
Processes a TLObject to find and cache the 'self' user entity.
130+
This is the standard mechanism Telethon uses to update the session
131+
with the current user's ID and access_hash after login.
127132
"""
128-
if key == 0:
129-
return InputPeerSelf()
130-
raise KeyError("Entity not found in DbSession cache (caching is not implemented).")
133+
if not hasattr(tlo, '__iter__'):
134+
tlo = (tlo,)
135+
136+
for entity in tlo:
137+
if isinstance(entity, User) and entity.is_self:
138+
self._self_user_id = entity.id
139+
self._self_access_hash = entity.access_hash
140+
# No need to save immediately, Telethon will call .save() when it's appropriate.
141+
break
131142

132143
def cache_file(self, md5_digest: bytes, file_size: int, instance: Any) -> None: pass
133-
134-
def get_file(self, md5_digest: bytes, file_size: int, exact: bool = True) -> Optional[Any]:
135-
return None
144+
def get_file(self, md5_digest: bytes, file_size: int, exact: bool = True) -> Optional[Any]: return None

0 commit comments

Comments
 (0)