Skip to content

Commit 79f1fac

Browse files
committed
fix: sqlite
1 parent ac247bf commit 79f1fac

File tree

7 files changed

+129
-234
lines changed

7 files changed

+129
-234
lines changed

scripts/manage_account.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def add_account_logic(args: argparse.Namespace) -> None:
3434
"""
3535
logger.info(f"--- Adding new account: {args.name} ---")
3636

37-
session_file: str = f"temp_cli_{args.name}.session"
37+
session_file_path: str = f"temp_cli_{args.name}.session"
3838
temp_client: Optional[TelegramClient] = None
3939
try:
4040
async with get_db() as db:
@@ -50,7 +50,7 @@ async def add_account_logic(args: argparse.Namespace) -> None:
5050
return
5151

5252
logger.info("Initializing temporary session to verify credentials...")
53-
temp_client = TelegramClient(SQLiteSession(session_file), int(api_id), api_hash)
53+
temp_client = TelegramClient(SQLiteSession(session_file_path), int(api_id), api_hash)
5454
await temp_client.connect()
5555

5656
if not await temp_client.is_user_authorized():
@@ -137,24 +137,14 @@ async def add_account_logic(args: argparse.Namespace) -> None:
137137
return
138138

139139
# Now extract session from file and save to DB
140-
logger.info("Extracting session data and saving to the database...")
141-
reader_session = SQLiteSession(session_file)
142-
reader_session.load()
143-
144-
update_state = reader_session.get_update_state(0)
145-
pts, qts, date_ts, seq, _ = (None, None, None, None, None)
146-
if update_state:
147-
pts, qts, date_ts, seq, _ = update_state
140+
logger.info("Reading session file and saving to the database...")
141+
with open(session_file_path, 'rb') as f:
142+
session_bytes: bytes = f.read()
148143

149144
await db_manager.add_or_update_session(
150145
db,
151146
account_id=new_account.account_id,
152-
dc_id=reader_session.dc_id,
153-
server_address=reader_session.server_address,
154-
port=reader_session.port,
155-
auth_key_data=reader_session.auth_key.key,
156-
takeout_id=reader_session.takeout_id,
157-
pts=pts, qts=qts, date=date_ts, seq=seq
147+
session_file=session_bytes
158148
)
159149

160150
logger.info(f"\nSuccess! Account '{args.name}' was added. Restart the bot to activate.")
@@ -164,9 +154,9 @@ async def add_account_logic(args: argparse.Namespace) -> None:
164154
finally:
165155
if temp_client and temp_client.is_connected():
166156
await temp_client.disconnect()
167-
if os.path.exists(session_file):
168-
os.remove(session_file)
169-
logger.info(f"Cleaned up temporary session file: {session_file}")
157+
if os.path.exists(session_file_path):
158+
os.remove(session_file_path)
159+
logger.info(f"Cleaned up temporary session file: {session_file_path}")
170160

171161

172162
async def edit_account_logic(args: argparse.Namespace) -> None:

userbot/__init__.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
from faker import Faker
77
from telethon import TelegramClient as TelethonTelegramClient, events
8+
from telethon.tl.functions.channels import JoinChannelRequest
89
from telethon.errors.rpcerrorlist import UserAlreadyParticipantError
910
from python_socks import ProxyType
1011

1112
from userbot.src.config import API_ID, API_HASH, LOG_LEVEL
1213
from userbot.src.db.session import initialize_database, get_db
13-
from userbot.src.db.models import Account
14-
from userbot.src.db_session import DbSession
14+
from userbot.src.db.models import Account, Session
15+
from userbot.src.memory_session import MemorySession
1516
from userbot.src.encrypt import encryption_manager
1617
import userbot.src.db_manager as db_manager
1718
from userbot.src.log_handler import DBLogHandler
@@ -35,12 +36,11 @@ class TelegramClient(TelethonTelegramClient):
3536
def __init__(self, *args, **kwargs):
3637
super().__init__(*args, **kwargs)
3738
self.lang_code: str = 'ru'
39+
self.account_id_override: Optional[int] = kwargs.get('account_id_override')
3840

3941
@property
4042
def current_account_id(self) -> Optional[int]:
41-
if hasattr(self, 'session') and isinstance(self.session, DbSession):
42-
return self.session.account_id
43-
return None
43+
return self.account_id_override
4444

4545
async def get_string(self, key: str, module_name: Optional[str] = None, **kwargs) -> str:
4646
return translator.get_string(self.lang_code, key, module_name, **kwargs)
@@ -139,21 +139,29 @@ async def manage_clients() -> None:
139139
encryption_manager.decrypt(account.proxy_username).decode() if account.proxy_username else None,
140140
encryption_manager.decrypt(account.proxy_password).decode() if account.proxy_password else None
141141
)
142+
143+
session_instance: MemorySession
144+
if account.session and account.session.session_file:
145+
try:
146+
decrypted_session_bytes: bytes = encryption_manager.decrypt(account.session.session_file)
147+
session_instance = MemorySession(decrypted_session_bytes)
148+
except Exception as e:
149+
logger.error(f"Could not decrypt session for '{account.account_name}'. Skipping. Error: {e}")
150+
continue
151+
else:
152+
logger.warning(f"No session file found in DB for account '{account.account_name}'. A new login will be required if client is used directly.")
153+
session_instance = MemorySession(None)
142154

143-
session = DbSession(
144-
account_id=account.account_id,
145-
user_id=account.user_telegram_id,
146-
access_hash=account.access_hash
147-
)
148155

149156
new_client: TelegramClient = TelegramClient(
150-
session=session,
157+
session=session_instance,
151158
api_id=int(acc_api_id),
152159
api_hash=acc_api_hash,
153160
device_model=account.device_model,
154161
system_version=account.system_version,
155162
app_version=account.app_version,
156-
proxy=proxy_details
163+
proxy=proxy_details,
164+
account_id_override=account.account_id
157165
)
158166
ACTIVE_CLIENTS[account.account_id] = new_client
159167
tasks.append(start_individual_client(new_client, account))

userbot/src/core_handlers.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
# --- Helper ---
2323
async def get_account_id_from_client(client) -> int | None:
24-
return next((acc_id for acc_id, c in ACTIVE_CLIENTS.items() if c == client), None)
24+
# Access the override property directly
25+
return client.current_account_id if hasattr(client, 'current_account_id') else None
2526

2627
# --- Module Management ---
2728
async def load_account_modules(account_id: int, client_instance: TelegramClient, current_help_info: Dict[str, str]):
@@ -75,7 +76,7 @@ async def list_accounts_handler(event: events.NewMessage.Event):
7576

7677
async def add_account_handler(event: events.NewMessage.Event):
7778
account_name: str = event.pattern_match.group(1)
78-
session_file: str = f"temp_add_{account_name}.session"
79+
session_file_path: str = f"temp_add_{account_name}.session"
7980
temp_client: Optional[TelegramClient] = None
8081

8182
try:
@@ -86,7 +87,7 @@ async def add_account_handler(event: events.NewMessage.Event):
8687
api_hash_resp = await conv.get_response(); api_hash = api_hash_resp.text.strip()
8788

8889
await conv.send_message(await event.client.get_string("verifying_creds"))
89-
temp_client = TelegramClient(SQLiteSession(session_file), int(api_id), api_hash)
90+
temp_client = TelegramClient(SQLiteSession(session_file_path), int(api_id), api_hash)
9091

9192
await temp_client.connect()
9293
if not await temp_client.is_user_authorized():
@@ -141,23 +142,13 @@ async def add_account_handler(event: events.NewMessage.Event):
141142
return
142143

143144
# Now extract session data and save it
144-
reader_session = SQLiteSession(session_file)
145-
reader_session.load()
146-
147-
update_state = reader_session.get_update_state(0)
148-
pts, qts, date_ts, seq, _ = (None, None, None, None, None)
149-
if update_state:
150-
pts, qts, date_ts, seq, _ = update_state
151-
145+
with open(session_file_path, 'rb') as f:
146+
session_bytes: bytes = f.read()
147+
152148
await db_manager.add_or_update_session(
153149
db,
154150
account_id=new_acc.account_id,
155-
dc_id=reader_session.dc_id,
156-
server_address=reader_session.server_address,
157-
port=reader_session.port,
158-
auth_key_data=reader_session.auth_key.key,
159-
takeout_id=reader_session.takeout_id,
160-
pts=pts, qts=qts, date=date_ts, seq=seq
151+
session_file=session_bytes
161152
)
162153

163154
await conv.send_message(await event.client.get_string("add_acc_success", account_name=account_name))
@@ -170,8 +161,8 @@ async def add_account_handler(event: events.NewMessage.Event):
170161
finally:
171162
if temp_client and temp_client.is_connected():
172163
await temp_client.disconnect()
173-
if os.path.exists(session_file):
174-
os.remove(session_file)
164+
if os.path.exists(session_file_path):
165+
os.remove(session_file_path)
175166

176167

177168
async def delete_account_handler(event: events.NewMessage.Event):

userbot/src/db/models.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Column, Integer, String, Boolean, ForeignKey, BIGINT, TEXT,
33
TIMESTAMP, UniqueConstraint
44
)
5-
from sqlalchemy.dialects.postgresql import BYTEA, JSONB
5+
from sqlalchemy.dialects.postgresql import BYTEA
66
from sqlalchemy.orm import declarative_base, relationship
77
from sqlalchemy.sql import func
88

@@ -43,18 +43,7 @@ class Session(Base):
4343
__tablename__ = 'sessions'
4444
session_id = Column(Integer, primary_key=True)
4545
account_id = Column(Integer, ForeignKey('accounts.account_id', ondelete='CASCADE'), nullable=False, unique=True)
46-
dc_id = Column(Integer, nullable=False)
47-
server_address = Column(TEXT)
48-
port = Column(Integer)
49-
auth_key_data = Column(BYTEA)
50-
51-
# Update state fields
52-
pts = Column(Integer)
53-
qts = Column(Integer)
54-
date = Column(BIGINT)
55-
seq = Column(Integer)
56-
takeout_id = Column(BIGINT)
57-
46+
session_file = Column(BYTEA, nullable=False)
5847
last_used_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False)
5948
created_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False)
6049

@@ -80,7 +69,7 @@ class AccountModule(Base):
8069
module_id = Column(Integer, ForeignKey('modules.module_id', ondelete='CASCADE'), nullable=False)
8170
is_active = Column(Boolean, default=True)
8271
is_trusted = Column(Boolean, default=False)
83-
configuration = Column(JSONB)
72+
configuration = Column(BYTEA) # Changed to BYTEA for encrypted JSON
8473
activated_at = Column(TIMESTAMP(timezone=True), server_default=func.now(), nullable=False)
8574
updated_at = Column(TIMESTAMP(timezone=True), onupdate=func.now())
8675

userbot/src/db_manager.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ async def get_all_accounts(db: AsyncSession) -> List[Account]:
116116

117117
async def get_all_active_accounts(db: AsyncSession) -> List[Account]:
118118
"""Retrieves all enabled accounts from the database."""
119-
result = await db.execute(select(Account).where(Account.is_enabled == True))
119+
result = await db.execute(select(Account).where(Account.is_enabled == True).options(selectinload(Account.session)))
120120
return result.scalars().all()
121121

122122
async def delete_account(db: AsyncSession, account_name: str) -> bool:
@@ -161,34 +161,35 @@ async def update_account_self_info(account_id: int, user_id: int, access_hash: i
161161

162162
# --- Session CRUD ---
163163
async def get_session(db: AsyncSession, account_id: int) -> Optional[Session]:
164-
"""Retrieves a session for a given account and decrypts its auth key."""
164+
"""Retrieves a session for a given account ID."""
165165
result = await db.execute(select(Session).where(Session.account_id == account_id))
166-
session = result.scalars().first()
167-
if session and session.auth_key_data:
168-
try:
169-
session.auth_key_data = encryption_manager.decrypt(session.auth_key_data)
170-
except Exception as e:
171-
logger.error(f"Failed to decrypt session auth_key for account {account_id}: {e}")
172-
return None
173-
return session
166+
return result.scalars().first()
174167

175-
async def add_or_update_session(db: AsyncSession, **kwargs) -> Optional[Session]:
176-
"""Adds or updates a session in the database, encrypting the auth key."""
177-
account_id = kwargs.get("account_id")
168+
async def add_or_update_session(db: AsyncSession, account_id: int, session_file: bytes) -> Optional[Session]:
169+
"""
170+
Adds or updates a session in the database.
171+
172+
Args:
173+
db (AsyncSession): The database session.
174+
account_id (int): The ID of the account this session belongs to.
175+
session_file (bytes): The raw, unencrypted bytes of the session file.
176+
177+
Returns:
178+
Optional[Session]: The created or updated Session object.
179+
"""
178180
if not account_id: return None
179181

180182
result = await db.execute(select(Session).where(Session.account_id == account_id))
181183
session = result.scalars().first()
182184

185+
encrypted_session_file = encryption_manager.encrypt(session_file)
186+
183187
if not session:
184-
session = Session(account_id=account_id)
188+
session = Session(account_id=account_id, session_file=encrypted_session_file)
185189
db.add(session)
186-
187-
for key, value in kwargs.items():
188-
if key == "auth_key_data" and value is not None:
189-
value = encryption_manager.encrypt(value)
190-
setattr(session, key, value)
191-
190+
else:
191+
session.session_file = encrypted_session_file
192+
192193
session.last_used_at = datetime.now(timezone.utc)
193194
await db.flush()
194195
return session

0 commit comments

Comments
 (0)