diff --git a/Dockerfile b/Dockerfile index f169ffa..94fa53c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,36 +1,52 @@ -# 用于构建和设置变量 -FROM python:3.12-alpine AS builder +# Single-stage build for Python application +FROM python:3.12-slim -# 设置时区为Asia/Shanghai, DOCKER_MODE为1 +# Set environment variables ENV TZ=Asia/Shanghai \ - DOCKER_MODE=1 \ - PUID=0 \ - PGID=0 \ - UMASK=000 \ - PYTHONWARNINGS="ignore:semaphore_tracker:UserWarning" \ - WORKDIR="/app" - -# 设置默认工作目录 + DOCKER_MODE=1 \ + PUID=0 \ + PGID=0 \ + UMASK=000 \ + PYTHONWARNINGS="ignore:semaphore_tracker:UserWarning" \ + WORKDIR="/app" \ + PATH="/root/.local/bin:${PATH}" + +# Set working directory WORKDIR ${WORKDIR} -#复制uv lockfile到工作目录中 -COPY uv.lock ${WORKDIR} - -# 安装必要的环境 -RUN apk add --no-cache --virtual .build-deps gcc git musl-dev \ - && wget -qO- https://astral.sh/uv/install.sh | sh \ - && source /root/.local/bin/env \ - && uv sync \ - && uv cache clean \ - && apk del --purge .build-deps \ - && rm -rf /tmp/* /root/.cache /var/cache/apk/* - -# 将从构建上下文目录中的文件和目录复制到新的一层的镜像内的工作目录中 +# Copy requirements files first for better caching +COPY pyproject.toml uv.lock .python-version ./ + +# Install uv and application dependencies +# Use bash explicitly to support 'source' command +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + gcc \ + git \ + wget \ + ca-certificates \ + bash \ + libc6-dev \ + python3-dev && \ + # Log Python version from .python-version file + echo "Target Python version from .python-version: $(cat .python-version)" && \ + # Install uv + wget -qO- https://astral.sh/uv/install.sh | bash && \ + # Make uv available without source + bash -c 'export PATH="/root/.local/bin:$PATH" && \ + # Install project dependencies from pyproject.toml + /root/.local/bin/uv sync' && \ + # Clean up build dependencies + apt-get purge -y --auto-remove gcc git wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* /tmp/* /root/.cache + +# Copy the rest of the application COPY . . -# 将应用日志输出到stdout +# Redirect logs to stdout RUN ln -sf /dev/stdout /app/default.log -# 定义容器启动时执行的默认命令 -ENTRYPOINT ["/root/.local/bin/uv","run","app.py"] +# Define entrypoint using uv +ENTRYPOINT ["/root/.local/bin/uv", "run", "app.py"] diff --git a/app.py b/app.py index 1c35cba..81ae541 100644 --- a/app.py +++ b/app.py @@ -3,54 +3,54 @@ from datetime import datetime import pytz -from py_tools.connections.db.mysql import DBManager, BaseOrmTable, SQLAlchemyManager -from sqlalchemy import text -from sqlalchemy.ext.asyncio import create_async_engine from bot.bot_client import BotClient from bot.commands import CommandHandler from config import config from core.emby_api import EmbyApi, EmbyRouterAPI from services import UserService +from models.database import ( + init_db, + create_database_if_not_exists, + create_tables, +) # Initialize logger logger = logging.getLogger(__name__) - - -async def create_database_if_not_exists() -> None: - """创建数据库。""" - engine_without_db = create_async_engine( - f"mysql+asyncmy://{config.db_user}:{config.db_pass}@{config.db_host}:{config.db_port}/", - echo=True, - ) - async with engine_without_db.begin() as conn: - query = f"CREATE DATABASE IF NOT EXISTS {config.db_name}" - logger.info(f"SQL Query: {query}, Context: Creating database") - await conn.execute(text(query)) - await engine_without_db.dispose() +# async def _init_db() -> None: """初始化数据库连接并创建表。""" - await create_database_if_not_exists() + # Create database if it doesn't exist + await create_database_if_not_exists( + host=config.db_host, + port=config.db_port, + user=config.db_user, + password=config.db_pass, + db_name=config.db_name, + ) - db_client = SQLAlchemyManager( + # Initialize the engine and session factory + await init_db( host=config.db_host, port=config.db_port, user=config.db_user, password=config.db_pass, db_name=config.db_name, + echo=True, ) - db_client.init_mysql_engine() - DBManager.init_db_client(db_client) - async with DBManager.connection() as conn: - logger.info("Context: Creating tables") - await conn.run_sync(BaseOrmTable.metadata.create_all) + # Create all tables + await create_tables() def _init_logger() -> None: """初始化日志记录器。""" + # Clear any existing handlers to prevent duplicates + logger.handlers = [] + + # Create handlers handler = logging.StreamHandler() # 输出到终端 fmt = "%(levelname)s [%(asctime)s] %(name)s - %(message)s" datefmt = "%Y-%m-%d %H:%M:%S" @@ -60,16 +60,13 @@ def _init_logger() -> None: ) handler.setFormatter(formatter) - # 设置日志默认值 - logging.basicConfig(format=fmt, datefmt=datefmt, level=config.log_level) + # Set log level + logger.setLevel(config.log_level) - # 添加流处理器 + # Add handler (don't use basicConfig if you're manually configuring) logger.addHandler(handler) - # 设置日志级别 - logger.setLevel(config.log_level) - - # 如果你需要将日志写入文件,可以继续保持原有的文件配置: + # Add file handler if needed file_handler = logging.FileHandler("default.log") file_handler.setFormatter(formatter) logger.addHandler(file_handler) diff --git a/models/__init__.py b/models/__init__.py index d128db6..e69de29 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,3 +0,0 @@ -from .user_model import User -from .config_model import Config -from .invite_code_model import InviteCode diff --git a/models/config_model.py b/models/config_model.py index 0b269a2..4ef26bc 100644 --- a/models/config_model.py +++ b/models/config_model.py @@ -1,13 +1,14 @@ import logging -from py_tools.connections.db.mysql import DBManager -from py_tools.connections.db.mysql.orm_model import BaseOrmTableWithTS -from sqlalchemy import Integer, BigInteger +from sqlalchemy import Integer, BigInteger, select from sqlalchemy.orm import mapped_column, Mapped +from .database import Base, BaseModelWithTS, DbOperations, get_session +from .invite_code_model import InviteCode + logger = logging.getLogger(__name__) -class Config(BaseOrmTableWithTS): +class Config(Base, BaseModelWithTS): __tablename__ = "config" total_register_user: Mapped[int] = mapped_column(Integer, nullable=False, default=0) @@ -17,8 +18,62 @@ class Config(BaseOrmTableWithTS): register_public_time: Mapped[int] = mapped_column(BigInteger, nullable=True) -class ConfigOrm(DBManager): - orm_table = Config +class ConfigRepository: + """Replaces ConfigOrm to handle Config database operations""" + + @staticmethod + async def create_config(**kwargs): + return await DbOperations.create(Config, **kwargs) + + @staticmethod + async def get_by_id(config_id: int): + return await DbOperations.get_by_id(Config, config_id) + + @staticmethod + async def get_first_config(): + """Get the first (and typically only) config record""" + async for session in get_session(): + result = await session.execute(select(Config).limit(1)) + return result.scalars().first() + + @staticmethod + async def update_config(config_id: int, **kwargs): + return await DbOperations.update(Config, config_id, **kwargs) + + @staticmethod + async def create_invite_code(**kwargs): + return await DbOperations.create(InviteCode, **kwargs) + + # @staticmethod + # async def get_by_id(code_id: int): + # return await DbOperations.get_by_id(InviteCode, code_id) + + @staticmethod + async def get_by_code(code: str): + async for session in get_session(): + result = await session.execute( + select(InviteCode).where(InviteCode.code == code) + ) + return result.scalars().first() + + @staticmethod + async def get_by_telegram_id(telegram_id: int): + async for session in get_session(): + result = await session.execute( + select(InviteCode).where(InviteCode.telegram_id == telegram_id) + ) + return result.scalars().all() + @staticmethod + async def update_invite_code(code_id: int, **kwargs): + return await DbOperations.update(InviteCode, code_id, **kwargs) -logger.info("Config model initialized") + @staticmethod + async def mark_as_used(code_id: int, used_time: int, used_user_id: int): + return await DbOperations.update( + InviteCode, + code_id, + is_used=True, + used_time=used_time, + used_user_id=used_user_id, + ) diff --git a/models/database.py b/models/database.py new file mode 100644 index 0000000..549d11a --- /dev/null +++ b/models/database.py @@ -0,0 +1,139 @@ +import logging +from datetime import datetime +from typing import AsyncGenerator, Optional, Type, TypeVar, Any + +from sqlalchemy import Column, DateTime, Integer, text +from sqlalchemy.ext.asyncio import ( + AsyncSession, + create_async_engine, + async_sessionmaker, + AsyncEngine, +) +from sqlalchemy.orm import declarative_base, DeclarativeMeta + +logger = logging.getLogger(__name__) + +# Global engine reference +engine: Optional[AsyncEngine] = None +async_session_factory: Optional[async_sessionmaker] = None + +# Create base model class +Base = declarative_base() + +# Type variable for ORM models +T = TypeVar("T", bound=DeclarativeMeta) + + +class BaseModel: + """Base model class to replace BaseOrmTable.""" + + id = Column(Integer, primary_key=True, autoincrement=True) + + +class BaseModelWithTS(BaseModel): + """Base model with timestamp columns to replace BaseOrmTableWithTS.""" + + created_at = Column(DateTime, default=datetime.utcnow, nullable=False) + updated_at = Column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False + ) + + +async def init_db( + host: str, port: int, user: str, password: str, db_name: str, echo: bool = False +) -> None: + """Initialize database connection.""" + global engine, async_session_factory + + connection_string = f"mysql+asyncmy://{user}:{password}@{host}:{port}/{db_name}" + engine = create_async_engine(connection_string, echo=echo) + async_session_factory = async_sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) + + +async def create_database_if_not_exists( + host: str, port: int, user: str, password: str, db_name: str +) -> None: + """Create database if it doesn't exist.""" + engine_without_db = create_async_engine( + f"mysql+asyncmy://{user}:{password}@{host}:{port}/", + echo=True, + ) + async with engine_without_db.begin() as conn: + query = f"CREATE DATABASE IF NOT EXISTS {db_name}" + logger.info(f"SQL Query: {query}, Context: Creating database") + await conn.execute(text(query)) + await engine_without_db.dispose() + + +async def create_tables() -> None: + """Create all tables defined in the models.""" + if engine is None: + raise RuntimeError("Database engine not initialized") + + async with engine.begin() as conn: + logger.info("Context: Creating tables") + await conn.run_sync(Base.metadata.create_all) + + +async def get_session() -> AsyncGenerator[AsyncSession, None]: + """Provide a session for database operations.""" + if async_session_factory is None: + raise RuntimeError("Session factory not initialized") + + async with async_session_factory() as session: + try: + yield session + finally: + await session.close() + + +class DbOperations: + """Class to replace DBManager for common database operations.""" + + @staticmethod + async def create(model: Type[T], **kwargs) -> T: + """Create a new record.""" + async for session in get_session(): + instance = model(**kwargs) + session.add(instance) + await session.commit() + await session.refresh(instance) + return instance + + @staticmethod + async def get_by_id(model: Type[T], id: int) -> Optional[T]: + """Get record by ID.""" + async for session in get_session(): + return await session.get(model, id) + + @staticmethod + async def update(model: Type[T], id: int, **kwargs) -> Optional[T]: + """Update a record by ID.""" + async for session in get_session(): + instance = await session.get(model, id) + if instance: + for key, value in kwargs.items(): + setattr(instance, key, value) + await session.commit() + await session.refresh(instance) + return instance + + @staticmethod + async def delete(model: Type[T], id: int) -> bool: + """Delete a record by ID.""" + async for session in get_session(): + instance = await session.get(model, id) + if instance: + await session.delete(instance) + await session.commit() + return True + return False + + @staticmethod + async def execute(query: Any) -> Any: + """Execute a custom query.""" + async for session in get_session(): + result = await session.execute(query) + return result diff --git a/models/invite_code_model.py b/models/invite_code_model.py index 3705697..5b018af 100644 --- a/models/invite_code_model.py +++ b/models/invite_code_model.py @@ -1,10 +1,10 @@ import logging import enum -from py_tools.connections.db.mysql import DBManager -from py_tools.connections.db.mysql.orm_model import BaseOrmTableWithTS -from sqlalchemy import String, BigInteger, Boolean, Enum +from sqlalchemy import String, BigInteger, Boolean, Enum, select from sqlalchemy.orm import Mapped, mapped_column +from .database import Base, BaseModelWithTS, DbOperations, get_session + logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ def __str__(self): return self.value -class InviteCode(BaseOrmTableWithTS): +class InviteCode(Base, BaseModelWithTS): __tablename__ = "invite_code" code: Mapped[str] = mapped_column( @@ -40,8 +40,43 @@ def __repr__(self): ) -class InviteCodeOrm(DBManager): - orm_table = InviteCode +class InviteCodeRepository: + """Replaces InviteCodeOrm to handle InviteCode database operations""" + + @staticmethod + async def create_invite_code(**kwargs): + return await DbOperations.create(InviteCode, **kwargs) + + @staticmethod + async def get_by_id(code_id: int): + return await DbOperations.get_by_id(InviteCode, code_id) + @staticmethod + async def get_by_code(code: str): + async for session in get_session(): + result = await session.execute( + select(InviteCode).where(InviteCode.code == code) + ) + return result.scalars().first() -logger.info("InviteCode model initialized") + @staticmethod + async def get_by_telegram_id(telegram_id: int): + async for session in get_session(): + result = await session.execute( + select(InviteCode).where(InviteCode.telegram_id == telegram_id) + ) + return result.scalars().all() + + @staticmethod + async def update_invite_code(code_id: int, **kwargs): + return await DbOperations.update(InviteCode, code_id, **kwargs) + + @staticmethod + async def mark_as_used(code_id: int, used_time: int, used_user_id: int): + return await DbOperations.update( + InviteCode, + code_id, + is_used=True, + used_time=used_time, + used_user_id=used_user_id, + ) diff --git a/models/user_model.py b/models/user_model.py index 8d4069a..05c879a 100644 --- a/models/user_model.py +++ b/models/user_model.py @@ -1,15 +1,14 @@ import logging -from py_tools.connections.db.mysql import DBManager -from py_tools.connections.db.mysql.orm_model import BaseOrmTableWithTS -from sqlalchemy import String, Boolean, BigInteger +from sqlalchemy import String, Boolean, BigInteger, select from sqlalchemy.orm import Mapped, mapped_column +from .database import Base, BaseModelWithTS, DbOperations, get_session from config import config logger = logging.getLogger(__name__) -class User(BaseOrmTableWithTS): +class User(Base, BaseModelWithTS): __tablename__ = "user" telegram_id: Mapped[int] = mapped_column( @@ -108,8 +107,35 @@ def emby_ban_info(self) -> tuple[int, str]: return self.ban_time, self.reason -class UserOrm(DBManager): - orm_table = User +class UserRepository: + """Replaces UserOrm to handle User database operations""" + @staticmethod + async def create_user(**kwargs): + return await DbOperations.create(User, **kwargs) -logger.info("User model initialized") + @staticmethod + async def get_by_id(user_id: int): + return await DbOperations.get_by_id(User, user_id) + + @staticmethod + async def get_by_telegram_id(telegram_id: int): + async for session in get_session(): + result = await session.execute( + select(User).where(User.telegram_id == telegram_id) + ) + return result.scalars().first() + + @staticmethod + async def get_by_emby_id(emby_id: str): + async for session in get_session(): + result = await session.execute(select(User).where(User.emby_id == emby_id)) + return result.scalars().first() + + @staticmethod + async def update_user(user_id: int, **kwargs): + return await DbOperations.update(User, user_id, **kwargs) + + @staticmethod + async def delete_user(user_id: int): + return await DbOperations.delete(User, user_id) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 0ea63cc..0000000 --- a/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -git+https://github.com/rebeeh/pyrogram.git@master -python-dotenv==1.0.1 -urllib3==2.3.0 -PyMySQL==1.1.1 -TgCrypto==1.2.5 -pytz~=2025.1 -requests==2.31.0 -SQLAlchemy~=2.0.20 -huidevkit[db-orm]~=0.6.0 -shortuuid~=1.0.13 -cryptography -asyncmy \ No newline at end of file diff --git a/services/user_service.py b/services/user_service.py index ed8cf5c..01a0379 100644 --- a/services/user_service.py +++ b/services/user_service.py @@ -6,21 +6,23 @@ from typing import Optional, List, Dict, Tuple import shortuuid -from sqlalchemy import select from config import config from core.emby_api import EmbyApi, EmbyRouterAPI -from models import User, Config, InviteCode -from models.config_model import ConfigOrm -from models.invite_code_model import InviteCodeOrm, InviteCodeType -from models.user_model import UserOrm +from models.config_model import Config, ConfigRepository +from models.invite_code_model import InviteCode, InviteCodeRepository, InviteCodeType +from models.user_model import User, UserRepository +from models.database import get_session logger = logging.getLogger(__name__) + class NotBoundError(Exception): """用户未绑定 Emby 账号的异常""" + pass + class UserService: """用户与 Emby 相关的业务逻辑层""" @@ -31,18 +33,16 @@ def __init__(self, emby_api: EmbyApi, emby_router_api: EmbyRouterAPI): @staticmethod async def get_or_create_user_by_telegram_id(telegram_id: int) -> User: """通过 telegram_id 从数据库获取用户,如果不存在则创建一个默认用户""" - user = await UserOrm().query_one(conds=[User.telegram_id == telegram_id]) + user = await UserRepository.get_by_telegram_id(telegram_id) if not user: - default_user = User( + # Create new user with parameters instead of User object + user = await UserRepository.create_user( telegram_id=telegram_id, is_admin=telegram_id in config.admin_list, telegram_name=config.group_members.get(telegram_id, {}).username if config.group_members.get(telegram_id) else None, ) - user_id = await UserOrm().add(default_user) - user = default_user - user.id = user_id return user @staticmethod @@ -77,9 +77,13 @@ async def _emby_create_user( raise Exception("在 Emby 系统中创建账号失败,请检查 Emby 服务是否正常。") emby_id = emby_user["Id"] - user.emby_id = emby_id - user.emby_name = username - user.enable_register = False + # Update user directly with UserRepository + await UserRepository.update_user( + user.id, emby_id=emby_id, emby_name=username, enable_register=False + ) + + # Reload user after update + user = await UserRepository.get_by_id(user.id) # 设置初始密码 & 默认Policy self.emby_api.set_user_password(emby_id, password) @@ -109,13 +113,15 @@ async def create_invite_code( if not user.check_create_invite_code(): raise Exception("您没有权限生成普通邀请码。") - code_objs = [ - InviteCode( + # Create and store invite codes one by one + created_codes = [] + for code in self.gen_register_code(count): + invite_code = await InviteCodeRepository.create_invite_code( code=code, telegram_id=telegram_id, code_type=InviteCodeType.REGISTER ) - for code in self.gen_register_code(count) - ] - return await InviteCodeOrm().bulk_add(code_objs) + created_codes.append(invite_code) + + return created_codes async def create_whitelist_code( self, telegram_id: int, count: int = 1 @@ -125,13 +131,15 @@ async def create_whitelist_code( if not user.check_create_whitelist_code(): raise Exception("您没有权限生成白名单邀请码。") - code_objs = [ - InviteCode( + # Create and store whitelist codes one by one + created_codes = [] + for code in self.gen_whitelist_code(count): + invite_code = await InviteCodeRepository.create_invite_code( code=code, telegram_id=telegram_id, code_type=InviteCodeType.WHITELIST ) - for code in self.gen_whitelist_code(count) - ] - return await InviteCodeOrm().bulk_add(code_objs) + created_codes.append(invite_code) + + return created_codes async def emby_info(self, telegram_id: int) -> Tuple[User, Dict]: """获取当前用户在 Emby 的信息""" @@ -140,17 +148,18 @@ async def emby_info(self, telegram_id: int) -> Tuple[User, Dict]: raise NotBoundError("该用户尚未绑定 Emby 账号。") emby_user = self.emby_api.get_user(str(user.emby_id)) if not emby_user: - raise Exception("从 Emby 服务器获取用户信息失败,请检查 Emby 服务是否正常。") + raise Exception( + "从 Emby 服务器获取用户信息失败,请检查 Emby 服务是否正常。" + ) return user, emby_user async def first_or_create_emby_config(self) -> Config: """获取或创建 Emby 配置。""" - emby_config = await ConfigOrm().query_one(conds=[Config.id == 1]) + emby_config = await ConfigRepository.get_by_id(1) if not emby_config: - emby_config = Config( + emby_config = await ConfigRepository.create_config( register_public_user=0, register_public_time=0, total_register_user=0 ) - await ConfigOrm().add(emby_config) return emby_config async def emby_create_user( @@ -168,17 +177,28 @@ async def emby_create_user( if not await self._check_register_permission(user, emby_config): raise Exception("当前没有可用的注册权限或名额,创建账号被拒绝。") - async with ConfigOrm().transaction() as session: - if not user.enable_register and emby_config.register_public_user > 0: - emby_config.register_public_user -= 1 - - emby_config.total_register_user += 1 - new_user = await self._emby_create_user(telegram_id, username, password) - - session.add(new_user) - session.add(emby_config) - await session.commit() - return new_user + # Use manual session management instead of transaction context manager + async for session in get_session(): + try: + if not user.enable_register and emby_config.register_public_user > 0: + emby_config.register_public_user -= 1 + + emby_config.total_register_user += 1 + await ConfigRepository.update_config( + emby_config.id, + register_public_user=emby_config.register_public_user, + total_register_user=emby_config.total_register_user, + ) + + # Create user in Emby system + new_user = await self._emby_create_user(telegram_id, username, password) + + await session.commit() + return new_user + except Exception as e: + await session.rollback() + logger.error(f"创建用户失败: {e}") + raise async def _check_register_permission(self, user: User, emby_config: Config) -> bool: """检查用户是否有权限注册 Emby 账号""" @@ -192,9 +212,7 @@ async def _check_register_permission(self, user: User, emby_config: Config) -> b ): enable_register = True if 0 < emby_config.register_public_time < datetime.now().timestamp(): - await ConfigOrm().update( - values={"register_public_time": 0}, conds=[Config.id == 1] - ) + await ConfigRepository.update_config(1, register_public_time=0) return enable_register async def redeem_code(self, telegram_id: int, code: str): @@ -205,40 +223,42 @@ async def redeem_code(self, telegram_id: int, code: str): user = await self.must_get_user(telegram_id) - # 使用事务块,并通过行锁防止并发问题 - async with InviteCodeOrm().transaction() as session: - # 构造 SELECT 语句,并加上 FOR UPDATE 行锁 - stmt = select(InviteCode).where(InviteCode.code == code).with_for_update() - result = await session.execute(stmt) - valid_code = result.scalars().first() - - if not valid_code or valid_code.is_used: - raise Exception("该邀请码无效或已被使用。") - - # 根据邀请码类型执行不同的业务逻辑校验 - if valid_code.code_type == InviteCodeType.REGISTER: - user.check_use_redeem_code() - elif valid_code.code_type == InviteCodeType.WHITELIST: - user.check_use_whitelist_code() - if user.is_emby_baned(): - await self.emby_unban(telegram_id) - - # 标记邀请码已使用,并记录使用时间和使用者 - valid_code.is_used = True - valid_code.used_time = datetime.now().timestamp() - valid_code.used_user_id = telegram_id - - # 根据邀请码类型更新用户状态 - if valid_code.code_type == InviteCodeType.REGISTER: - user.enable_register = True - elif valid_code.code_type == InviteCodeType.WHITELIST: - user.is_whitelist = True - - session.add(valid_code) - session.add(user) - await session.commit() - - return valid_code + # Use direct session rather than transaction context manager + async for session in get_session(): + try: + # Get invite code + valid_code = await InviteCodeRepository.get_by_code(code) + + if not valid_code or valid_code.is_used: + raise Exception("该邀请码无效或已被使用。") + + # 根据邀请码类型执行不同的业务逻辑校验 + if valid_code.code_type == InviteCodeType.REGISTER: + user.check_use_redeem_code() + elif valid_code.code_type == InviteCodeType.WHITELIST: + user.check_use_whitelist_code() + if user.is_emby_baned(): + await self.emby_unban(telegram_id) + + # Mark code as used + now = int(datetime.now().timestamp()) + await InviteCodeRepository.mark_as_used(valid_code.id, now, telegram_id) + + # Update user based on code type + if valid_code.code_type == InviteCodeType.REGISTER: + await UserRepository.update_user(user.id, enable_register=True) + elif valid_code.code_type == InviteCodeType.WHITELIST: + await UserRepository.update_user(user.id, is_whitelist=True) + + await session.commit() + + # Refresh user object after update + user = await UserRepository.get_by_id(user.id) + return valid_code + except Exception as e: + await session.rollback() + logger.error(f"使用邀请码失败: {e}") + raise async def reset_password(self, telegram_id: int, password: str = "") -> bool: """重置用户的 Emby 密码。""" @@ -265,12 +285,8 @@ async def emby_ban( try: self.emby_api.ban_user(str(user.emby_id)) - user.ban_time = int(datetime.now().timestamp()) - user.reason = reason - await UserOrm().update( - {"ban_time": user.ban_time, "reason": reason}, - conds=[User.id == user.id], - ) + ban_time = int(datetime.now().timestamp()) + await UserRepository.update_user(user.id, ban_time=ban_time, reason=reason) return True except Exception as e: logger.error(f"禁用用户失败: {e}") @@ -290,11 +306,7 @@ async def emby_unban( try: self.emby_api.set_default_policy(str(user.emby_id)) - user.ban_time = 0 - user.reason = "" - await UserOrm().update( - {"ban_time": 0, "reason": None}, conds=[User.id == user.id] - ) + await UserRepository.update_user(user.id, ban_time=0, reason=None) return True except Exception as e: logger.error(f"解禁用户失败: {e}") @@ -314,18 +326,17 @@ async def set_emby_config( if not emby_config: raise Exception("未找到全局 Emby 配置,无法设置。") + update_data = {} if register_public_user is not None: - emby_config.register_public_user = register_public_user + update_data["register_public_user"] = register_public_user if register_public_time is not None: - emby_config.register_public_time = register_public_time - - await ConfigOrm().update( - values={ - "register_public_user": emby_config.register_public_user, - "register_public_time": emby_config.register_public_time, - }, - conds=[Config.id == 1], - ) + update_data["register_public_time"] = register_public_time + + if update_data: + await ConfigRepository.update_config(emby_config.id, **update_data) + # Refresh config after update + emby_config = await ConfigRepository.get_by_id(emby_config.id) + return emby_config def emby_count(self) -> Dict: