Skip to content

Commit d653980

Browse files
committed
fix: some more fixes with cache
1 parent 0b4bb80 commit d653980

File tree

2 files changed

+58
-126
lines changed

2 files changed

+58
-126
lines changed

userbot/__init__.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from faker import Faker
77
from telethon import TelegramClient as TelethonTelegramClient, events
8-
from telethon.sessions import StringSession
98
from telethon.errors.rpcerrorlist import UserAlreadyParticipantError
109
from python_socks import ProxyType
1110

@@ -18,31 +17,22 @@
1817
from userbot.src.log_handler import DBLogHandler
1918
from userbot.src.locales import translator
2019

21-
# --- Basic Setup ---
2220
logger: logging.Logger = logging.getLogger("userbot")
2321
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
2422

25-
# --- Globals ---
2623
ACTIVE_CLIENTS: Dict[int, "TelegramClient"] = {}
2724
FAKE: Faker = Faker()
2825
GLOBAL_HELP_INFO: Dict[int, Dict[str, str]] = {}
2926

30-
# --- Helper Functions ---
3127
def _generate_random_device() -> Dict[str, str]:
32-
"""Generates a dictionary with random device information."""
3328
return {
3429
"device_model": FAKE.user_agent(),
3530
"system_version": f"SDK {FAKE.random_int(min=28, max=33)}",
3631
"app_version": f"{FAKE.random_int(min=9, max=10)}.{FAKE.random_int(min=0, max=9)}.{FAKE.random_int(min=0, max=9)}"
3732
}
3833

39-
# --- Core TelegramClient Class ---
4034
class TelegramClient(TelethonTelegramClient):
41-
"""
42-
Custom TelegramClient class with database interaction and localization methods.
43-
"""
4435
def __init__(self, *args, **kwargs):
45-
"""Initializes the custom Telegram client."""
4636
super().__init__(*args, **kwargs)
4737
self.lang_code: str = 'ru'
4838

@@ -55,7 +45,6 @@ def current_account_id(self) -> Optional[int]:
5545
async def get_string(self, key: str, module_name: Optional[str] = None, **kwargs) -> str:
5646
return translator.get_string(self.lang_code, key, module_name, **kwargs)
5747

58-
# --- Startup Logic ---
5948
async def db_setup() -> None:
6049
if not API_ID or not API_HASH:
6150
logger.critical("API_ID or API_HASH is not set. Please run 'python3 -m scripts.setup'. Exiting.")
@@ -146,16 +135,16 @@ async def manage_clients() -> None:
146135
proxy_type_enum = proxy_map.get(account.proxy_type.lower())
147136
if proxy_type_enum:
148137
proxy_details = (
149-
proxy_type_enum,
150-
account.proxy_ip,
151-
account.proxy_port,
152-
True,
138+
proxy_type_enum, account.proxy_ip, account.proxy_port, True,
153139
encryption_manager.decrypt(account.proxy_username).decode() if account.proxy_username else None,
154140
encryption_manager.decrypt(account.proxy_password).decode() if account.proxy_password else None
155141
)
142+
143+
# Use the asynchronous factory method to create the session
144+
session = await DbSession.create(account_id=account.account_id)
156145

157146
new_client: TelegramClient = TelegramClient(
158-
session=DbSession(account_id=account.account_id),
147+
session=session,
159148
api_id=int(acc_api_id),
160149
api_hash=acc_api_hash,
161150
device_model=account.device_model,

userbot/src/db_session.py

Lines changed: 53 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from telethon.sessions.abstract import Session
66
from telethon.crypto import AuthKey
7-
from telethon.tl.types import InputPhoto, InputDocument, PeerUser, PeerChat, PeerChannel
7+
from telethon.tl.types import PeerUser
88

99
import userbot.src.db_manager as db_manager
1010
from userbot.src.db.session import get_db
@@ -14,22 +14,23 @@
1414
class DbSession(Session):
1515
"""
1616
A Telethon session class that stores session data in a PostgreSQL database.
17-
This implementation correctly provides all the abstract methods and properties
18-
required by Telethon's base Session class.
1917
"""
2018

21-
def __init__(self, account_id: int):
19+
def __init__(self, account_id: int, self_user_id: Optional[int]):
2220
"""
2321
Initializes the database-backed session.
2422
2523
Args:
2624
account_id (int): The unique identifier for the account this session belongs to.
25+
self_user_id (Optional[int]): The Telegram user ID of the account holder.
2726
"""
2827
super().__init__()
2928
if not isinstance(account_id, int):
3029
raise ValueError("DbSession requires a valid integer account_id.")
3130

3231
self.account_id: int = account_id
32+
self._self_user_id: Optional[int] = self_user_id
33+
3334
self._auth_key: Optional[AuthKey] = None
3435
self._dc_id: int = 0
3536
self._server_address: Optional[str] = None
@@ -40,11 +41,29 @@ def __init__(self, account_id: int):
4041
self._qts: Optional[int] = None
4142
self._date: Optional[int] = None
4243
self._seq: Optional[int] = None
43-
44-
async def load(self) -> None:
44+
45+
@classmethod
46+
async def create(cls, account_id: int) -> "DbSession":
4547
"""
46-
Loads the session data for the current account_id from the database.
48+
Asynchronously creates and pre-loads a DbSession instance.
49+
This factory method is used to fetch necessary data like the user_telegram_id
50+
before the synchronous parts of the session are accessed by Telethon.
51+
52+
Args:
53+
account_id (int): The unique identifier for the account.
54+
55+
Returns:
56+
DbSession: A new instance of DbSession.
4757
"""
58+
self_user_id: Optional[int] = None
59+
async with get_db() as db:
60+
account = await db_manager.get_account_by_id(db, account_id)
61+
if account:
62+
self_user_id = account.user_telegram_id
63+
return cls(account_id, self_user_id)
64+
65+
async def load(self) -> None:
66+
"""Loads the session data for the current account_id from the database."""
4867
logger.debug(f"Attempting to load session for account_id: {self.account_id}")
4968
async with get_db() as db:
5069
session_data = await db_manager.get_session(db, self.account_id)
@@ -54,157 +73,81 @@ async def load(self) -> None:
5473
self._dc_id = session_data.dc_id
5574
self._server_address = session_data.server_address
5675
self._port = session_data.port
57-
58-
auth_key_bytes = session_data.auth_key_data
59-
if auth_key_bytes:
60-
self._auth_key = AuthKey(data=bytes(auth_key_bytes))
61-
else:
62-
self._auth_key = None
63-
76+
self._auth_key = AuthKey(data=bytes(session_data.auth_key_data)) if session_data.auth_key_data else None
6477
self._takeout_id = session_data.takeout_id
6578
self._pts = session_data.pts
6679
self._qts = session_data.qts
6780
self._date = session_data.date
6881
self._seq = session_data.seq
6982
else:
70-
logger.info(f"No session data found in DB for account_id: {self.account_id}. New login required.")
83+
logger.info(f"No session data in DB for account_id: {self.account_id}. New login required.")
7184

7285
async def save(self) -> None:
73-
"""
74-
Saves the current session data to the database.
75-
This is called automatically by Telethon.
76-
"""
77-
logger.debug(f"Attempting to save session for account_id: {self.account_id}")
86+
"""Saves the current session data to the database."""
7887
session_data = {
79-
"account_id": self.account_id,
80-
"dc_id": self._dc_id,
81-
"server_address": self._server_address,
82-
"port": self._port,
88+
"account_id": self.account_id, "dc_id": self._dc_id,
89+
"server_address": self._server_address, "port": self._port,
8390
"auth_key_data": self._auth_key.key if self._auth_key else None,
84-
"pts": self._pts,
85-
"qts": self._qts,
86-
"date": self._date,
87-
"seq": self._seq,
88-
"takeout_id": self._takeout_id,
91+
"pts": self._pts, "qts": self._qts, "date": self._date,
92+
"seq": self._seq, "takeout_id": self._takeout_id,
8993
}
9094
async with get_db() as db:
9195
await db_manager.add_or_update_session(db, **session_data)
9296
logger.info(f"Session saved for account_id: {self.account_id}")
9397

9498
async def delete(self) -> None:
9599
"""Deletes the session for the current account from the database."""
96-
logger.info(f"Deleting session from DB for account_id: {self.account_id}")
97100
async with get_db() as db:
98101
await db_manager.delete_session(db, self.account_id)
99-
# Clear in-memory data
100-
self._auth_key = None
101-
self._dc_id = 0
102-
103-
# --- Abstract Properties Implementation ---
102+
self._auth_key = None; self._dc_id = 0
104103

105104
@property
106-
def dc_id(self) -> int:
107-
return self._dc_id
108-
105+
def dc_id(self) -> int: return self._dc_id
109106
@property
110-
def server_address(self) -> Optional[str]:
111-
return self._server_address
112-
107+
def server_address(self) -> Optional[str]: return self._server_address
113108
@property
114-
def port(self) -> int:
115-
return self._port
116-
109+
def port(self) -> int: return self._port
117110
@property
118-
def auth_key(self) -> Optional[AuthKey]:
119-
return self._auth_key
120-
111+
def auth_key(self) -> Optional[AuthKey]: return self._auth_key
121112
@auth_key.setter
122-
def auth_key(self, value: Optional[AuthKey]):
123-
self._auth_key = value
124-
113+
def auth_key(self, value: Optional[AuthKey]): self._auth_key = value
125114
@property
126-
def takeout_id(self) -> Optional[int]:
127-
return self._takeout_id
128-
115+
def takeout_id(self) -> Optional[int]: return self._takeout_id
129116
@takeout_id.setter
130-
def takeout_id(self, value: Optional[int]):
131-
self._takeout_id = value
132-
133-
# --- Abstract Methods Implementation ---
117+
def takeout_id(self, value: Optional[int]): self._takeout_id = value
134118

135119
def set_dc(self, dc_id: int, server_address: str, port: int):
136-
self._dc_id = dc_id
137-
self._server_address = server_address
138-
self._port = port
120+
self._dc_id, self._server_address, self._port = dc_id, server_address, port
139121

140122
def get_update_state(self, entity_id: int) -> Optional[Tuple[int, int, int, int, int]]:
141-
if self._pts is None:
142-
return None
123+
if self._pts is None: return None
143124
return self._pts, self._qts, self._date, self._seq, 0
144125

145126
def set_update_state(self, entity_id: int, state: Any):
146127
if isinstance(state.date, datetime):
147128
date_ts = int(state.date.replace(tzinfo=timezone.utc).timestamp())
148129
else:
149130
date_ts = int(state.date)
150-
151-
self._pts = state.pts
152-
self._qts = state.qts
153-
self._date = date_ts
154-
self._seq = state.seq
131+
self._pts, self._qts, self._date, self._seq = state.pts, state.qts, date_ts, state.seq
155132

156-
async def close(self) -> None:
157-
pass
133+
async def close(self) -> None: pass
158134

159135
def get_update_states(self) -> List[Tuple[int, int, int, int, int, int]]:
160-
if self._pts is None:
161-
return []
136+
if self._pts is None: return []
162137
return [(0, self._pts, self._qts, self._date, self._seq, 0)]
163138

164-
# --- New Stubs for Entity and File Caching ---
165-
166-
def process_entities(self, tlo: object) -> None:
167-
"""
168-
This session does not cache entities, so this method does nothing.
169-
170-
Args:
171-
tlo (object): A TLObject containing entities.
172-
"""
173-
pass
139+
def process_entities(self, tlo: object) -> None: pass
174140

175141
def get_input_entity(self, key: Any) -> Any:
176142
"""
177-
This session does not cache entities, so this method always fails.
178-
179-
Args:
180-
key (Any): The key to look up an entity.
181-
182-
Raises:
183-
KeyError: Always, as no entities are cached.
143+
Returns the input entity for the current user if key is 0.
144+
This is crucial for the client to know "who it is" upon startup.
184145
"""
185-
raise KeyError("Entity not found in DbSession cache (caching is not implemented).")
146+
if key == 0 and self._self_user_id is not None:
147+
return PeerUser(self._self_user_id)
148+
raise KeyError("Entity not found in DbSession cache (caching is not implemented for other entities).")
186149

187-
def cache_file(self, md5_digest: bytes, file_size: int, instance: Any) -> None:
188-
"""
189-
This session does not cache files, so this method does nothing.
190-
191-
Args:
192-
md5_digest (bytes): The MD5 digest of the file.
193-
file_size (int): The size of the file.
194-
instance (Any): The InputFile or InputPhoto instance.
195-
"""
196-
pass
150+
def cache_file(self, md5_digest: bytes, file_size: int, instance: Any) -> None: pass
197151

198152
def get_file(self, md5_digest: bytes, file_size: int, exact: bool = True) -> Optional[Any]:
199-
"""
200-
This session does not cache files, so this method always returns None.
201-
202-
Args:
203-
md5_digest (bytes): The MD5 digest of the file.
204-
file_size (int): The size of the file.
205-
exact (bool): Whether the file size must be exact.
206-
207-
Returns:
208-
None: Always, as no files are cached.
209-
"""
210153
return None

0 commit comments

Comments
 (0)