diff --git a/src/__version__.py b/src/__version__.py index 364e7ba..4c513f3 100644 --- a/src/__version__.py +++ b/src/__version__.py @@ -1 +1 @@ -__version__ = "0.6.4" +__version__ = "0.6.6" diff --git a/src/bot/middlewares/error.py b/src/bot/middlewares/error.py index dbc4603..56000ee 100644 --- a/src/bot/middlewares/error.py +++ b/src/bot/middlewares/error.py @@ -4,6 +4,12 @@ from aiogram.types import ErrorEvent, TelegramObject from aiogram.types import User as AiogramUser from aiogram.utils.formatting import Text +from aiogram_dialog.api.exceptions import ( + InvalidStackIdError, + OutdatedIntent, + UnknownIntent, + UnknownState, +) from dishka import AsyncContainer from src.bot.keyboards import get_user_keyboard @@ -29,8 +35,19 @@ async def middleware_logic( data: dict[str, Any], ) -> Any: aiogram_user: Optional[AiogramUser] = self._get_aiogram_user(event) - error_event = cast(ErrorEvent, event) + + if isinstance( + error_event.exception, + ( + InvalidStackIdError, + OutdatedIntent, + UnknownIntent, + UnknownState, + ), + ): + return await handler(event, data) + error = error_event.exception traceback_str = traceback.format_exc() error_type_name = type(error).__name__ @@ -59,7 +76,7 @@ async def middleware_logic( "user": True if user else False, "user_id": str(user.telegram_id) if user else False, "user_name": user.name if user else False, - "username": user.username if user else False, + "username": user.username if user and user.username else False, "error": f"{error_type_name}: {error_message.as_html()}", }, reply_markup=reply_markup, diff --git a/src/core/logger.py b/src/core/logger.py index 4f8ed2d..7e757fb 100644 --- a/src/core/logger.py +++ b/src/core/logger.py @@ -78,9 +78,9 @@ def setup_logger() -> None: sink=LOG_DIR / LOG_FILENAME, level=LOG_LEVEL, format=LOG_FORMAT, - rotation=LOG_ROTATION, - retention=LOG_RETENTION, - compression=compress_log_file, + rotation="1GB", + retention="3 days", + compression="zip", encoding=LOG_ENCODING, ) diff --git a/src/infrastructure/database/migrations/versions/0017_fix_external_squad.py b/src/infrastructure/database/migrations/versions/0017_fix_external_squad.py new file mode 100644 index 0000000..693cf42 --- /dev/null +++ b/src/infrastructure/database/migrations/versions/0017_fix_external_squad.py @@ -0,0 +1,37 @@ +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "0017" +down_revision: Union[str, None] = "0016" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("plans", sa.Column("external_squad_new", sa.UUID(), nullable=True)) + + op.execute(""" + UPDATE plans + SET external_squad_new = external_squad[1] + WHERE external_squad IS NOT NULL; + """) + + op.drop_column("plans", "external_squad") + + op.alter_column("plans", "external_squad_new", new_column_name="external_squad") + + +def downgrade() -> None: + op.add_column("plans", sa.Column("external_squad_new", sa.ARRAY(sa.UUID()), nullable=True)) + + op.execute(""" + UPDATE plans + SET external_squad_new = ARRAY[external_squad] + WHERE external_squad IS NOT NULL; + """) + + op.drop_column("plans", "external_squad") + + op.alter_column("plans", "external_squad_new", new_column_name="external_squad") diff --git a/src/infrastructure/database/models/sql/plan.py b/src/infrastructure/database/models/sql/plan.py index ffe61fe..2fe9a60 100644 --- a/src/infrastructure/database/models/sql/plan.py +++ b/src/infrastructure/database/models/sql/plan.py @@ -56,7 +56,7 @@ class Plan(BaseSql, TimestampMixin): ) allowed_user_ids: Mapped[list[int]] = mapped_column(ARRAY(BigInteger), nullable=True) internal_squads: Mapped[list[UUID]] = mapped_column(ARRAY(PG_UUID), nullable=False) - external_squad: Mapped[Optional[UUID]] = mapped_column(ARRAY(PG_UUID), nullable=True) + external_squad: Mapped[Optional[UUID]] = mapped_column(PG_UUID, nullable=True) durations: Mapped[list["PlanDuration"]] = relationship( "PlanDuration", diff --git a/src/infrastructure/database/repositories/base.py b/src/infrastructure/database/repositories/base.py index bc56b70..18bc8ec 100644 --- a/src/infrastructure/database/repositories/base.py +++ b/src/infrastructure/database/repositories/base.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type, TypeVar, Union, cast +from typing import Any, Optional, Type, TypeVar, Union from sqlalchemy import ColumnExpressionArgument, delete, func, select, update from sqlalchemy.ext.asyncio import AsyncSession @@ -42,7 +42,8 @@ async def delete_instance(self, instance: T) -> None: await self.session.delete(instance) async def _get_one(self, model: ModelType[T], *conditions: ConditionType) -> Optional[T]: - result = await self.session.execute(select(model).where(*conditions)) + stmt = select(model).where(*conditions) + result = await self.session.execute(stmt) return result.unique().scalar_one_or_none() async def _get_many( @@ -78,23 +79,16 @@ async def _update( **kwargs: Any, ) -> Optional[T]: if not kwargs: - if not load_result: - return None - return cast(Optional[T], await self._get_one(model, *conditions)) + return await self._get_one(model, *conditions) if load_result else None - query = update(model).where(*conditions).values(**kwargs) + stmt = update(model).where(*conditions).values(**kwargs) if load_result: - query = query.returning(model.id) # type: ignore [attr-defined] - - result = await self.session.execute(query) - obj_id: Optional[int] = result.scalar_one_or_none() - - if obj_id is not None and load_result: - db_obj = await self.session.get(model, obj_id) - await self.session.refresh(db_obj) - return db_obj + stmt = stmt.returning(model) + result = await self.session.execute(stmt) + return result.scalar_one_or_none() + await self.session.execute(stmt) return None async def _delete(self, model: ModelType[T], *conditions: ConditionType) -> int: diff --git a/src/infrastructure/database/uow.py b/src/infrastructure/database/uow.py index d0df29a..26967f6 100644 --- a/src/infrastructure/database/uow.py +++ b/src/infrastructure/database/uow.py @@ -8,18 +8,22 @@ class UnitOfWork: - session_pool: async_sessionmaker[AsyncSession] - session: Optional[AsyncSession] = None + def __init__(self, session_maker: async_sessionmaker[AsyncSession]) -> None: + self.session_maker = session_maker + self.session: Optional[AsyncSession] = None + self._repository: Optional[RepositoriesFacade] = None - repository: RepositoriesFacade - - def __init__(self, session_pool: async_sessionmaker[AsyncSession]) -> None: - self.session_pool = session_pool + @property + def repository(self) -> RepositoriesFacade: + if self._repository is None: + raise RuntimeError("SQL session not started. Use 'async with uow:'") + return self._repository async def __aenter__(self) -> Self: - self.session = self.session_pool() - self.repository = RepositoriesFacade(session=self.session) - logger.debug(f"Opened session '{id(self.session)}'") + self.session = self.session_maker() + self._repository = RepositoriesFacade(session=self.session) + + logger.debug(f"SQL session started. Session ID: '{id(self.session)}'") return self async def __aexit__( @@ -31,19 +35,18 @@ async def __aexit__( if self.session is None: return - session_id = id(self.session) try: - if exc_type is None: - await self.commit() + if exc_type: + await self.session.rollback() + logger.warning(f"SQL transaction rolled back due to error: '{exc_val}'") else: - logger.warning( - f"Exception detected '{exc_val}', rolling back session '{session_id}'" - ) - await self.rollback() + await self.session.commit() + logger.debug("SQL transaction committed successfully") finally: await self.session.close() - logger.debug(f"Closed session '{session_id}'") self.session = None + self._repository = None + logger.debug("SQL session closed") async def commit(self) -> None: if self.session: diff --git a/src/infrastructure/di/providers/database.py b/src/infrastructure/di/providers/database.py index ece7c5e..093977c 100644 --- a/src/infrastructure/di/providers/database.py +++ b/src/infrastructure/di/providers/database.py @@ -43,5 +43,5 @@ async def get_uow( self, session_maker: async_sessionmaker[AsyncSession], ) -> AsyncIterable[UnitOfWork]: - async with UnitOfWork(session_maker) as uow: - yield uow + uow = UnitOfWork(session_maker) + yield uow diff --git a/src/infrastructure/taskiq/broker.py b/src/infrastructure/taskiq/broker.py index 8b9dba9..3fa1a75 100644 --- a/src/infrastructure/taskiq/broker.py +++ b/src/infrastructure/taskiq/broker.py @@ -8,8 +8,17 @@ def create_broker(config: AppConfig) -> RedisStreamBroker: - result_backend: AsyncResultBackend[Any] = RedisAsyncResultBackend(redis_url=config.redis.dsn) - broker = RedisStreamBroker(url=config.redis.dsn).with_result_backend(result_backend) + result_backend: AsyncResultBackend[Any] = RedisAsyncResultBackend( + redis_url=config.redis.dsn, + keep_results=False, + result_ex_time=3600, + ) + + broker = RedisStreamBroker( + url=config.redis.dsn, + maxlen=1000, + ).with_result_backend(result_backend) + return broker diff --git a/src/infrastructure/taskiq/tasks/notifications.py b/src/infrastructure/taskiq/tasks/notifications.py index 16f832d..1d7afd2 100644 --- a/src/infrastructure/taskiq/tasks/notifications.py +++ b/src/infrastructure/taskiq/tasks/notifications.py @@ -1,12 +1,11 @@ import asyncio from typing import Any, Union, cast -from aiogram.types import BufferedInputFile from dishka.integrations.taskiq import FromDishka, inject from src.bot.keyboards import get_buy_keyboard, get_renew_keyboard from src.core.constants import BATCH_DELAY, BATCH_SIZE -from src.core.enums import MediaType, UserNotificationType +from src.core.enums import UserNotificationType from src.core.utils.iterables import chunked from src.core.utils.message_payload import MessagePayload from src.core.utils.types import RemnaUserDto @@ -23,13 +22,11 @@ async def send_error_notification_task( payload: MessagePayload, notification_service: FromDishka[NotificationService], ) -> None: - file_data = BufferedInputFile( - file=traceback_str.encode(), - filename=f"error_{error_id}.txt", + await notification_service.error_notify( + traceback_str=traceback_str, + payload=payload, + error_id=error_id, ) - payload.media = file_data - payload.media_type = MediaType.DOCUMENT - await notification_service.notify_super_dev(payload=payload) @broker.task diff --git a/src/services/broadcast.py b/src/services/broadcast.py index 3f2847a..1000659 100644 --- a/src/services/broadcast.py +++ b/src/services/broadcast.py @@ -41,8 +41,10 @@ def __init__( async def create(self, broadcast: BroadcastDto) -> BroadcastDto: db_broadcast = Broadcast(**broadcast.model_dump()) - db_created_broadcast = await self.uow.repository.broadcasts.create(db_broadcast) - await self.uow.commit() + + async with self.uow: + db_created_broadcast = await self.uow.repository.broadcasts.create(db_broadcast) + logger.info(f"Created broadcast '{broadcast.task_id}'") return BroadcastDto.from_model(db_created_broadcast) # type: ignore[return-value] @@ -59,11 +61,15 @@ async def create_messages( ) for m in messages ] - db_created_messages = await self.uow.repository.broadcasts.create_messages(db_messages) + + async with self.uow: + db_created_messages = await self.uow.repository.broadcasts.create_messages(db_messages) + return BroadcastMessageDto.from_model_list(db_created_messages) async def get(self, task_id: UUID) -> Optional[BroadcastDto]: - db_broadcast = await self.uow.repository.broadcasts.get(task_id) + async with self.uow: + db_broadcast = await self.uow.repository.broadcasts.get(task_id) if db_broadcast: logger.debug(f"Retrieved broadcast '{task_id}'") @@ -73,14 +79,17 @@ async def get(self, task_id: UUID) -> Optional[BroadcastDto]: return BroadcastDto.from_model(db_broadcast) async def get_all(self) -> list[BroadcastDto]: - db_broadcasts = await self.uow.repository.broadcasts.get_all() + async with self.uow: + db_broadcasts = await self.uow.repository.broadcasts.get_all() + return BroadcastDto.from_model_list(list(reversed(db_broadcasts))) async def update(self, broadcast: BroadcastDto) -> Optional[BroadcastDto]: - db_updated_broadcast = await self.uow.repository.broadcasts.update( - task_id=broadcast.task_id, - **broadcast.changed_data, - ) + async with self.uow: + db_updated_broadcast = await self.uow.repository.broadcasts.update( + task_id=broadcast.task_id, + **broadcast.changed_data, + ) if db_updated_broadcast: logger.info(f"Updated broadcast '{broadcast.task_id}' successfully") @@ -93,22 +102,27 @@ async def update(self, broadcast: BroadcastDto) -> Optional[BroadcastDto]: return BroadcastDto.from_model(db_updated_broadcast) async def update_message(self, broadcast_id: int, message: BroadcastMessageDto) -> None: - await self.uow.repository.broadcasts.update_message( - broadcast_id=broadcast_id, - user_id=message.user_id, - **message.changed_data, - ) + async with self.uow: + await self.uow.repository.broadcasts.update_message( + broadcast_id=broadcast_id, + user_id=message.user_id, + **message.changed_data, + ) async def bulk_update_messages(self, messages: list[BroadcastMessageDto]) -> None: - await self.uow.repository.broadcasts.bulk_update_messages( - data=[m.model_dump() for m in messages], - ) + async with self.uow: + await self.uow.repository.broadcasts.bulk_update_messages( + data=[m.model_dump() for m in messages], + ) async def delete_broadcast(self, broadcast_id: int) -> None: - await self.uow.repository.broadcasts._delete(Broadcast, Broadcast.id == broadcast_id) + async with self.uow: + await self.uow.repository.broadcasts._delete(Broadcast, Broadcast.id == broadcast_id) async def get_status(self, task_id: UUID) -> Optional[BroadcastStatus]: - db_broadcast = await self.uow.repository.broadcasts.get(task_id) + async with self.uow: + db_broadcast = await self.uow.repository.broadcasts.get(task_id) + return db_broadcast.status if db_broadcast else None # @@ -127,7 +141,9 @@ async def get_audience_count( if audience == BroadcastAudience.PLAN: if plan_id: - db_subs = await self.uow.repository.subscriptions.filter_by_plan_id(plan_id) + async with self.uow: + db_subs = await self.uow.repository.subscriptions.filter_by_plan_id(plan_id) + active_subs = [ s for s in db_subs @@ -137,43 +153,62 @@ async def get_audience_count( ] return len(active_subs) - count = await self.uow.repository.plans._count( - Plan, - Plan.availability != PlanAvailability.TRIAL, - ) + async with self.uow: + count = await self.uow.repository.plans._count( + Plan, + Plan.availability != PlanAvailability.TRIAL, + ) + logger.debug(f"Audience count for '{audience}' (plan={plan_id}) is '{count}'") return count if audience == BroadcastAudience.ALL: - return await self.uow.repository.users._count(User, is_not_block) + async with self.uow: + result = await self.uow.repository.users._count(User, is_not_block) + return result if audience == BroadcastAudience.SUBSCRIBED: conditions = and_( is_not_block, User.current_subscription.has(Subscription.status == SubscriptionStatus.ACTIVE), ) - return await self.uow.repository.users._count(User, conditions) + + async with self.uow: + result = await self.uow.repository.users._count(User, conditions) + + return result if audience == BroadcastAudience.UNSUBSCRIBED: conditions = and_( is_not_block, User.current_subscription_id.is_(None), ) - return await self.uow.repository.users._count(User, conditions) + + async with self.uow: + result = await self.uow.repository.users._count(User, conditions) + return result if audience == BroadcastAudience.EXPIRED: conditions = and_( is_not_block, User.current_subscription.has(Subscription.status == SubscriptionStatus.EXPIRED), ) - return await self.uow.repository.users._count(User, conditions) + + async with self.uow: + result = await self.uow.repository.users._count(User, conditions) + + return result if audience == BroadcastAudience.TRIAL: conditions = and_( is_not_block, User.current_subscription.has(Subscription.is_trial.is_(True)), ) - return await self.uow.repository.users._count(User, conditions) + + async with self.uow: + result = await self.uow.repository.users._count(User, conditions) + + return result raise Exception(f"Unknown broadcast audience: {audience}") @@ -190,7 +225,11 @@ async def get_audience_users( ) if audience == BroadcastAudience.PLAN and plan_id: - db_subscriptions = await self.uow.repository.subscriptions.filter_by_plan_id(plan_id) + async with self.uow: + db_subscriptions = await self.uow.repository.subscriptions.filter_by_plan_id( + plan_id + ) + active_subs = [ s for s in db_subscriptions @@ -199,14 +238,19 @@ async def get_audience_users( and not s.user.is_bot_blocked ] user_ids = [sub.user_telegram_id for sub in active_subs] - db_users = await self.uow.repository.users.get_by_ids(telegram_ids=user_ids) + + async with self.uow: + db_users = await self.uow.repository.users.get_by_ids(telegram_ids=user_ids) + logger.debug( f"Retrieved '{len(db_users)}' users for audience '{audience}' (plan={plan_id})" ) return UserDto.from_model_list(db_users) if audience == BroadcastAudience.ALL: - db_users = await self.uow.repository.users._get_many(User, is_not_block) + async with self.uow: + db_users = await self.uow.repository.users._get_many(User, is_not_block) + return UserDto.from_model_list(db_users) if audience == BroadcastAudience.SUBSCRIBED: @@ -214,12 +258,18 @@ async def get_audience_users( is_not_block, User.current_subscription.has(Subscription.status == SubscriptionStatus.ACTIVE), ) - db_users = await self.uow.repository.users._get_many(User, conditions) + + async with self.uow: + db_users = await self.uow.repository.users._get_many(User, conditions) + return UserDto.from_model_list(db_users) if audience == BroadcastAudience.UNSUBSCRIBED: conditions = and_(is_not_block, User.current_subscription_id.is_(None)) - db_users = await self.uow.repository.users._get_many(User, conditions) + + async with self.uow: + db_users = await self.uow.repository.users._get_many(User, conditions) + return UserDto.from_model_list(db_users) if audience == BroadcastAudience.EXPIRED: @@ -227,7 +277,10 @@ async def get_audience_users( is_not_block, User.current_subscription.has(Subscription.status == SubscriptionStatus.EXPIRED), ) - db_users = await self.uow.repository.users._get_many(User, conditions) + + async with self.uow: + db_users = await self.uow.repository.users._get_many(User, conditions) + return UserDto.from_model_list(db_users) if audience == BroadcastAudience.TRIAL: @@ -235,7 +288,10 @@ async def get_audience_users( is_not_block, User.current_subscription.has(Subscription.is_trial.is_(True)), ) - db_users = await self.uow.repository.users._get_many(User, conditions) + + async with self.uow: + db_users = await self.uow.repository.users._get_many(User, conditions) + return UserDto.from_model_list(db_users) raise Exception(f"Unknown broadcast audience: {audience}") diff --git a/src/services/payment_gateway.py b/src/services/payment_gateway.py index c71b16a..539414c 100644 --- a/src/services/payment_gateway.py +++ b/src/services/payment_gateway.py @@ -113,24 +113,27 @@ async def create_default(self) -> None: logger.warning(f"Unhandled payment gateway type '{gateway_type}' - skipping") continue - order_index = await self.uow.repository.gateways.get_max_index() - order_index = (order_index or 0) + 1 - - payment_gateway = PaymentGatewayDto( - order_index=order_index, - type=gateway_type, - currency=Currency.from_gateway_type(gateway_type), - is_active=is_active, - settings=settings, - ) + async with self.uow: + order_index = await self.uow.repository.gateways.get_max_index() + + order_index = (order_index or 0) + 1 + + payment_gateway = PaymentGatewayDto( + order_index=order_index, + type=gateway_type, + currency=Currency.from_gateway_type(gateway_type), + is_active=is_active, + settings=settings, + ) - db_payment_gateway = PaymentGateway(**payment_gateway.model_dump()) - db_payment_gateway = await self.uow.repository.gateways.create(db_payment_gateway) + db_payment_gateway = PaymentGateway(**payment_gateway.model_dump()) + db_payment_gateway = await self.uow.repository.gateways.create(db_payment_gateway) logger.info(f"Payment gateway '{gateway_type}' created") async def get(self, gateway_id: int) -> Optional[PaymentGatewayDto]: - db_gateway = await self.uow.repository.gateways.get(gateway_id) + async with self.uow: + db_gateway = await self.uow.repository.gateways.get(gateway_id) if not db_gateway: logger.warning(f"Payment gateway '{gateway_id}' not found") @@ -140,7 +143,8 @@ async def get(self, gateway_id: int) -> Optional[PaymentGatewayDto]: return PaymentGatewayDto.from_model(db_gateway, decrypt=True) async def get_by_type(self, gateway_type: PaymentGatewayType) -> Optional[PaymentGatewayDto]: - db_gateway = await self.uow.repository.gateways.get_by_type(gateway_type) + async with self.uow: + db_gateway = await self.uow.repository.gateways.get_by_type(gateway_type) if not db_gateway: logger.warning(f"Payment gateway of type '{gateway_type}' not found") @@ -150,7 +154,9 @@ async def get_by_type(self, gateway_type: PaymentGatewayType) -> Optional[Paymen return PaymentGatewayDto.from_model(db_gateway, decrypt=True) async def get_all(self, sorted: bool = False) -> list[PaymentGatewayDto]: - db_gateways = await self.uow.repository.gateways.get_all(sorted) + async with self.uow: + db_gateways = await self.uow.repository.gateways.get_all(sorted) + logger.debug(f"Retrieved '{len(db_gateways)}' payment gateways") return PaymentGatewayDto.from_model_list(db_gateways, decrypt=False) @@ -160,10 +166,11 @@ async def update(self, gateway: PaymentGatewayDto) -> Optional[PaymentGatewayDto if gateway.settings and gateway.settings.changed_data: updated_data["settings"] = gateway.settings.prepare_init_data(encrypt=True) - db_updated_gateway = await self.uow.repository.gateways.update( - gateway_id=gateway.id, # type: ignore[arg-type] - **updated_data, - ) + async with self.uow: + db_updated_gateway = await self.uow.repository.gateways.update( + gateway_id=gateway.id, # type: ignore[arg-type] + **updated_data, + ) if db_updated_gateway: logger.info(f"Payment gateway '{gateway.type}' updated successfully") @@ -176,29 +183,37 @@ async def update(self, gateway: PaymentGatewayDto) -> Optional[PaymentGatewayDto return PaymentGatewayDto.from_model(db_updated_gateway, decrypt=True) async def filter_active(self, is_active: bool = True) -> list[PaymentGatewayDto]: - db_gateways = await self.uow.repository.gateways.filter_active(is_active) + async with self.uow: + db_gateways = await self.uow.repository.gateways.filter_active(is_active) + logger.debug(f"Filtered active gateways: '{is_active}', found '{len(db_gateways)}'") return PaymentGatewayDto.from_model_list(db_gateways, decrypt=False) async def move_gateway_up(self, gateway_id: int) -> bool: - db_gateways = await self.uow.repository.gateways.get_all() - db_gateways.sort(key=lambda p: p.order_index) - - index = next((i for i, p in enumerate(db_gateways) if p.id == gateway_id), None) - if index is None: - logger.warning(f"Payment gateway with ID '{gateway_id}' not found for move operation") - return False - - if index == 0: - gateway = db_gateways.pop(0) - db_gateways.append(gateway) - logger.debug(f"Payment gateway '{gateway_id}' moved from top to bottom") - else: - db_gateways[index - 1], db_gateways[index] = db_gateways[index], db_gateways[index - 1] - logger.debug(f"Payment gateway '{gateway_id}' moved up one position") - - for i, gateway in enumerate(db_gateways, start=1): - gateway.order_index = i + async with self.uow: + db_gateways = await self.uow.repository.gateways.get_all() + db_gateways.sort(key=lambda p: p.order_index) + + index = next((i for i, p in enumerate(db_gateways) if p.id == gateway_id), None) + if index is None: + logger.warning( + f"Payment gateway with ID '{gateway_id}' not found for move operation" + ) + return False + + if index == 0: + gateway = db_gateways.pop(0) + db_gateways.append(gateway) + logger.debug(f"Payment gateway '{gateway_id}' moved from top to bottom") + else: + db_gateways[index - 1], db_gateways[index] = ( + db_gateways[index], + db_gateways[index - 1], + ) + logger.debug(f"Payment gateway '{gateway_id}' moved up one position") + + for i, gateway in enumerate(db_gateways, start=1): + gateway.order_index = i logger.info(f"Payment gateway '{gateway_id}' reorder successfully") return True diff --git a/src/services/plan.py b/src/services/plan.py index e0ffc4d..d56d8c8 100644 --- a/src/services/plan.py +++ b/src/services/plan.py @@ -32,17 +32,19 @@ def __init__( self.uow = uow async def create(self, plan: PlanDto) -> PlanDto: - order_index = await self.uow.repository.plans.get_max_index() - order_index = (order_index or 0) + 1 - plan.order_index = order_index + async with self.uow: + order_index = await self.uow.repository.plans.get_max_index() + order_index = (order_index or 0) + 1 + plan.order_index = order_index + db_plan = self._dto_to_model(plan) + db_created_plan = await self.uow.repository.plans.create(db_plan) - db_plan = self._dto_to_model(plan) - db_created_plan = await self.uow.repository.plans.create(db_plan) logger.info(f"Created plan '{plan.name}' with ID '{db_created_plan.id}'") return PlanDto.from_model(db_created_plan) # type: ignore[return-value] async def get(self, plan_id: int) -> Optional[PlanDto]: - db_plan = await self.uow.repository.plans.get(plan_id) + async with self.uow: + db_plan = await self.uow.repository.plans.get(plan_id) if db_plan: logger.debug(f"Retrieved plan '{plan_id}'") @@ -52,7 +54,8 @@ async def get(self, plan_id: int) -> Optional[PlanDto]: return PlanDto.from_model(db_plan) async def get_by_name(self, plan_name: str) -> Optional[PlanDto]: - db_plan = await self.uow.repository.plans.get_by_name(plan_name) + async with self.uow: + db_plan = await self.uow.repository.plans.get_by_name(plan_name) if db_plan: logger.debug(f"Retrieved plan by name '{plan_name}'") @@ -62,13 +65,17 @@ async def get_by_name(self, plan_name: str) -> Optional[PlanDto]: return PlanDto.from_model(db_plan) async def get_all(self) -> list[PlanDto]: - db_plans = await self.uow.repository.plans.get_all() + async with self.uow: + db_plans = await self.uow.repository.plans.get_all() + logger.debug(f"Retrieved '{len(db_plans)}' plans") return PlanDto.from_model_list(db_plans) async def update(self, plan: PlanDto) -> Optional[PlanDto]: db_plan = self._dto_to_model(plan) - db_updated_plan = await self.uow.repository.plans.update(db_plan) + + async with self.uow: + db_updated_plan = await self.uow.repository.plans.update(db_plan) if db_updated_plan: logger.info(f"Updated plan '{plan.name}' (ID: '{plan.id}') successfully") @@ -81,7 +88,8 @@ async def update(self, plan: PlanDto) -> Optional[PlanDto]: return PlanDto.from_model(db_updated_plan) async def delete(self, plan_id: int) -> bool: - result = await self.uow.repository.plans.delete(plan_id) + async with self.uow: + result = await self.uow.repository.plans.delete(plan_id) if result: logger.info(f"Plan '{plan_id}' deleted successfully") @@ -91,16 +99,18 @@ async def delete(self, plan_id: int) -> bool: return result async def count(self) -> int: - count = await self.uow.repository.plans.count() + async with self.uow: + count = await self.uow.repository.plans.count() logger.debug(f"Total plans count: '{count}'") return count # async def get_trial_plan(self) -> Optional[PlanDto]: - db_plans: list[Plan] = await self.uow.repository.plans.filter_by_availability( - availability=PlanAvailability.TRIAL - ) + async with self.uow: + db_plans: list[Plan] = await self.uow.repository.plans.filter_by_availability( + availability=PlanAvailability.TRIAL + ) if db_plans: if len(db_plans) > 1: @@ -123,7 +133,9 @@ async def get_trial_plan(self) -> Optional[PlanDto]: async def get_available_plans(self, user: UserDto) -> list[PlanDto]: logger.debug(f"Fetching available plans for user '{user.telegram_id}'") - db_plans: list[Plan] = await self.uow.repository.plans.filter_active(is_active=True) + async with self.uow: + db_plans: list[Plan] = await self.uow.repository.plans.filter_active(is_active=True) + logger.debug(f"Total active plans retrieved: '{len(db_plans)}'") db_filtered_plans = [] @@ -163,9 +175,10 @@ async def get_available_plans(self, user: UserDto) -> list[PlanDto]: return PlanDto.from_model_list(db_filtered_plans) async def get_allowed_plans(self) -> list[PlanDto]: - db_plans: list[Plan] = await self.uow.repository.plans.filter_by_availability( - availability=PlanAvailability.ALLOWED, - ) + async with self.uow: + db_plans: list[Plan] = await self.uow.repository.plans.filter_by_availability( + availability=PlanAvailability.ALLOWED, + ) if db_plans: logger.debug( @@ -177,24 +190,25 @@ async def get_allowed_plans(self) -> list[PlanDto]: return PlanDto.from_model_list(db_plans) async def move_plan_up(self, plan_id: int) -> bool: - db_plans = await self.uow.repository.plans.get_all() - db_plans.sort(key=lambda p: p.order_index) - - index = next((i for i, p in enumerate(db_plans) if p.id == plan_id), None) - if index is None: - logger.warning(f"Plan with ID '{plan_id}' not found for move operation") - return False - - if index == 0: - plan = db_plans.pop(0) - db_plans.append(plan) - logger.debug(f"Plan '{plan_id}' moved from top to bottom") - else: - db_plans[index - 1], db_plans[index] = db_plans[index], db_plans[index - 1] - logger.debug(f"Plan '{plan_id}' moved up one position") + async with self.uow: + db_plans = await self.uow.repository.plans.get_all() + db_plans.sort(key=lambda p: p.order_index) + + index = next((i for i, p in enumerate(db_plans) if p.id == plan_id), None) + if index is None: + logger.warning(f"Plan with ID '{plan_id}' not found for move operation") + return False + + if index == 0: + plan = db_plans.pop(0) + db_plans.append(plan) + logger.debug(f"Plan '{plan_id}' moved from top to bottom") + else: + db_plans[index - 1], db_plans[index] = db_plans[index], db_plans[index - 1] + logger.debug(f"Plan '{plan_id}' moved up one position") - for i, plan in enumerate(db_plans, start=1): - plan.order_index = i + for i, plan in enumerate(db_plans, start=1): + plan.order_index = i logger.info(f"Plan '{plan_id}' reorder successfully") return True diff --git a/src/services/promocode.py b/src/services/promocode.py index c81c901..4ee057d 100644 --- a/src/services/promocode.py +++ b/src/services/promocode.py @@ -35,7 +35,8 @@ async def create(self, promocode: PromocodeDto) -> PromocodeDto: # type: ignore pass async def get(self, promocode_id: int) -> Optional[PromocodeDto]: - db_promocode = await self.uow.repository.promocodes.get(promocode_id) + async with self.uow: + db_promocode = await self.uow.repository.promocodes.get(promocode_id) if db_promocode: logger.debug(f"Retrieved promocode '{promocode_id}'") @@ -45,7 +46,8 @@ async def get(self, promocode_id: int) -> Optional[PromocodeDto]: return PromocodeDto.from_model(db_promocode) async def get_by_code(self, promocode_code: str) -> Optional[PromocodeDto]: - db_promocode = await self.uow.repository.promocodes.get_by_code(promocode_code) + async with self.uow: + db_promocode = await self.uow.repository.promocodes.get_by_code(promocode_code) if db_promocode: logger.debug(f"Retrieved promocode by code '{promocode_code}'") @@ -55,15 +57,18 @@ async def get_by_code(self, promocode_code: str) -> Optional[PromocodeDto]: return PromocodeDto.from_model(db_promocode) async def get_all(self) -> list[PromocodeDto]: - db_promocodes = await self.uow.repository.promocodes.get_all() + async with self.uow: + db_promocodes = await self.uow.repository.promocodes.get_all() + logger.debug(f"Retrieved '{len(db_promocodes)}' promocodes") return PromocodeDto.from_model_list(db_promocodes) async def update(self, promocode: PromocodeDto) -> Optional[PromocodeDto]: - db_updated_promocode = await self.uow.repository.promocodes.update( - promocode_id=promocode.id, # type: ignore[arg-type] - **promocode.changed_data, - ) + async with self.uow: + db_updated_promocode = await self.uow.repository.promocodes.update( + promocode_id=promocode.id, # type: ignore[arg-type] + **promocode.changed_data, + ) if db_updated_promocode: logger.info(f"Updated promocode '{promocode.code}' successfully") @@ -76,7 +81,8 @@ async def update(self, promocode: PromocodeDto) -> Optional[PromocodeDto]: return PromocodeDto.from_model(db_updated_promocode) async def delete(self, promocode_id: int) -> bool: - result = await self.uow.repository.promocodes.delete(promocode_id) + async with self.uow: + result = await self.uow.repository.promocodes.delete(promocode_id) if result: logger.info(f"Promocode '{promocode_id}' deleted successfully") @@ -89,13 +95,17 @@ async def delete(self, promocode_id: int) -> bool: return result async def filter_by_type(self, promocode_type: PromocodeRewardType) -> list[PromocodeDto]: - db_promocodes = await self.uow.repository.promocodes.filter_by_type(promocode_type) + async with self.uow: + db_promocodes = await self.uow.repository.promocodes.filter_by_type(promocode_type) + logger.debug( f"Filtered promocodes by type '{promocode_type}', found '{len(db_promocodes)}'" ) return PromocodeDto.from_model_list(db_promocodes) async def filter_active(self, is_active: bool = True) -> list[PromocodeDto]: - db_promocodes = await self.uow.repository.promocodes.filter_active(is_active) + async with self.uow: + db_promocodes = await self.uow.repository.promocodes.filter_active(is_active) + logger.debug(f"Filtered active promocodes: '{is_active}', found '{len(db_promocodes)}'") return PromocodeDto.from_model_list(db_promocodes) diff --git a/src/services/referral.py b/src/services/referral.py index 35ecbf1..aea64d8 100644 --- a/src/services/referral.py +++ b/src/services/referral.py @@ -71,13 +71,14 @@ async def create_referral( referred: UserDto, level: ReferralLevel, ) -> ReferralDto: - referral = await self.uow.repository.referrals.create_referral( - Referral( - referrer_telegram_id=referrer.telegram_id, - referred_telegram_id=referred.telegram_id, - level=level, + async with self.uow: + referral = await self.uow.repository.referrals.create_referral( + Referral( + referrer_telegram_id=referrer.telegram_id, + referred_telegram_id=referred.telegram_id, + level=level, + ) ) - ) await self.user_service.clear_user_cache(referrer.telegram_id) await self.user_service.clear_user_cache(referred.telegram_id) @@ -85,20 +86,28 @@ async def create_referral( return ReferralDto.from_model(referral) # type: ignore[return-value] async def get_referral_by_referred(self, telegram_id: int) -> Optional[ReferralDto]: - referral = await self.uow.repository.referrals.get_referral_by_referred(telegram_id) + async with self.uow: + referral = await self.uow.repository.referrals.get_referral_by_referred(telegram_id) + return ReferralDto.from_model(referral) if referral else None async def get_referrals_by_referrer(self, telegram_id: int) -> List[ReferralDto]: - referrals = await self.uow.repository.referrals.get_referrals_by_referrer(telegram_id) + async with self.uow: + referrals = await self.uow.repository.referrals.get_referrals_by_referrer(telegram_id) + return ReferralDto.from_model_list(referrals) async def get_referral_count(self, telegram_id: int) -> int: - count = await self.uow.repository.referrals.count_referrals_by_referrer(telegram_id) + async with self.uow: + count = await self.uow.repository.referrals.count_referrals_by_referrer(telegram_id) + logger.debug(f"Retrieved counted '{count}' referrals for user '{telegram_id}'") return count async def get_reward_count(self, telegram_id: int) -> int: - count = await self.uow.repository.referrals.count_rewards_by_referrer(telegram_id) + async with self.uow: + count = await self.uow.repository.referrals.count_rewards_by_referrer(telegram_id) + logger.debug(f"Retrieved counted '{count}' rewards for user '{telegram_id}'") return count @@ -107,10 +116,12 @@ async def get_total_rewards_amount( telegram_id: int, reward_type: ReferralRewardType, ) -> int: - total_amount = await self.uow.repository.referrals.sum_rewards_by_user( - telegram_id, - reward_type, - ) + async with self.uow: + total_amount = await self.uow.repository.referrals.sum_rewards_by_user( + telegram_id, + reward_type, + ) + logger.debug( f"Retrieved calculated total rewards amount as '{total_amount}' " f"for user 'user_telegram_id' for type '{reward_type.name}'" @@ -124,30 +135,38 @@ async def create_reward( type: ReferralRewardType, amount: int, ) -> ReferralRewardDto: - reward = await self.uow.repository.referrals.create_reward( - ReferralReward( - referral_id=referral_id, - user_telegram_id=user_telegram_id, - type=type, - amount=amount, - is_issued=False, + async with self.uow: + reward = await self.uow.repository.referrals.create_reward( + ReferralReward( + referral_id=referral_id, + user_telegram_id=user_telegram_id, + type=type, + amount=amount, + is_issued=False, + ) ) - ) + logger.info(f"ReferralReward '{referral_id} created, user '{user_telegram_id}'") return ReferralRewardDto.from_model(reward) # type: ignore[return-value] async def get_rewards_by_user(self, telegram_id: int) -> List[ReferralRewardDto]: - rewards = await self.uow.repository.referrals.get_rewards_by_user(telegram_id) + async with self.uow: + rewards = await self.uow.repository.referrals.get_rewards_by_user(telegram_id) + return ReferralRewardDto.from_model_list(rewards) async def get_rewards_by_referral(self, referral_id: int) -> List[ReferralRewardDto]: - rewards = await self.uow.repository.referrals.get_rewards_by_referral(referral_id) + async with self.uow: + rewards = await self.uow.repository.referrals.get_rewards_by_referral(referral_id) + return ReferralRewardDto.from_model_list(rewards) # async def mark_reward_as_issued(self, reward_id: int) -> None: - await self.uow.repository.referrals.update_reward(reward_id, is_issued=True) + async with self.uow: + await self.uow.repository.referrals.update_reward(reward_id, is_issued=True) + logger.info(f"Marked reward '{reward_id}' as issued") async def handle_referral(self, user: UserDto, code: Optional[str]) -> None: diff --git a/src/services/settings.py b/src/services/settings.py index 37f48b9..43c4f2d 100644 --- a/src/services/settings.py +++ b/src/services/settings.py @@ -38,7 +38,9 @@ def __init__( async def create(self) -> SettingsDto: settings = SettingsDto() db_settings = Settings(**settings.prepare_init_data()) - db_settings = await self.uow.repository.settings.create(db_settings) + + async with self.uow: + db_settings = await self.uow.repository.settings.create(db_settings) await self._clear_cache() logger.info("Default settings created in DB") @@ -46,7 +48,9 @@ async def create(self) -> SettingsDto: @redis_cache(prefix="get_settings", ttl=TIME_10M) async def get(self) -> SettingsDto: - db_settings = await self.uow.repository.settings.get() + async with self.uow: + db_settings = await self.uow.repository.settings.get() + if not db_settings: return await self.create() else: @@ -65,7 +69,10 @@ async def update(self, settings: SettingsDto) -> SettingsDto: settings.referral = settings.referral changed_data = settings.prepare_changed_data() - db_updated_settings = await self.uow.repository.settings.update(**changed_data) + + async with self.uow: + db_updated_settings = await self.uow.repository.settings.update(**changed_data) + await self._clear_cache() if changed_data: diff --git a/src/services/subscription.py b/src/services/subscription.py index a4f4941..9d054d6 100644 --- a/src/services/subscription.py +++ b/src/services/subscription.py @@ -55,13 +55,16 @@ async def create(self, user: UserDto, subscription: SubscriptionDto) -> Subscrip data["plan"] = subscription.plan.model_dump(mode="json") db_subscription = Subscription(**data, user_telegram_id=user.telegram_id) - db_created_subscription = await self.uow.repository.subscriptions.create(db_subscription) + + async with self.uow: + db_created_subscription = await self.uow.repository.subscriptions.create( + db_subscription + ) await self.user_service.set_current_subscription( telegram_id=user.telegram_id, subscription_id=db_created_subscription.id, ) - await self.uow.commit() await self.clear_subscription_cache(db_subscription.id, db_subscription.user_telegram_id) logger.info(f"Created subscription '{db_subscription.id}' for user '{user.telegram_id}'") @@ -69,7 +72,8 @@ async def create(self, user: UserDto, subscription: SubscriptionDto) -> Subscrip @redis_cache(prefix="get_subscription", ttl=TIME_5M) async def get(self, subscription_id: int) -> Optional[SubscriptionDto]: - db_subscription = await self.uow.repository.subscriptions.get(subscription_id) + async with self.uow: + db_subscription = await self.uow.repository.subscriptions.get(subscription_id) if db_subscription: logger.debug(f"Retrieved subscription '{subscription_id}'") @@ -80,16 +84,17 @@ async def get(self, subscription_id: int) -> Optional[SubscriptionDto]: @redis_cache(prefix="get_current_subscription", ttl=TIME_1M) async def get_current(self, telegram_id: int) -> Optional[SubscriptionDto]: - db_user = await self.uow.repository.users.get(telegram_id) + async with self.uow: + db_user = await self.uow.repository.users.get(telegram_id) - if not db_user or not db_user.current_subscription_id: - logger.debug( - f"Current subscription check: User '{telegram_id}' has no active subscription" - ) - return None + if not db_user or not db_user.current_subscription_id: + logger.debug( + f"Current subscription check: User '{telegram_id}' has no active subscription" + ) + return None - subscription_id = db_user.current_subscription_id - db_active_subscription = await self.uow.repository.subscriptions.get(subscription_id) + subscription_id = db_user.current_subscription_id + db_active_subscription = await self.uow.repository.subscriptions.get(subscription_id) if db_active_subscription: logger.debug( @@ -105,12 +110,16 @@ async def get_current(self, telegram_id: int) -> Optional[SubscriptionDto]: return SubscriptionDto.from_model(db_active_subscription) async def get_all_by_user(self, telegram_id: int) -> list[SubscriptionDto]: - db_subscriptions = await self.uow.repository.subscriptions.get_all_by_user(telegram_id) + async with self.uow: + db_subscriptions = await self.uow.repository.subscriptions.get_all_by_user(telegram_id) + logger.debug(f"Retrieved '{len(db_subscriptions)}' subscriptions for user '{telegram_id}'") return SubscriptionDto.from_model_list(db_subscriptions) async def get_all(self) -> list[SubscriptionDto]: - db_subscriptions = await self.uow.repository.subscriptions.get_all() + async with self.uow: + db_subscriptions = await self.uow.repository.subscriptions.get_all() + logger.debug(f"Retrieved '{len(db_subscriptions)}' total subscriptions") return SubscriptionDto.from_model_list(db_subscriptions) @@ -120,12 +129,11 @@ async def update(self, subscription: SubscriptionDto) -> Optional[SubscriptionDt if subscription.plan.changed_data or "plan" in data: data["plan"] = subscription.plan.model_dump(mode="json") - db_updated_subscription = await self.uow.repository.subscriptions.update( - subscription_id=subscription.id, # type: ignore[arg-type] - **data, - ) - - await self.uow.commit() + async with self.uow: + db_updated_subscription = await self.uow.repository.subscriptions.update( + subscription_id=subscription.id, # type: ignore[arg-type] + **data, + ) if db_updated_subscription: await self.clear_subscription_cache( @@ -150,7 +158,9 @@ async def has_used_trial(self, user_telegram_id: int) -> bool: Subscription.status != SubscriptionStatus.DELETED, ) - count = await self.uow.repository.subscriptions._count(Subscription, conditions) + async with self.uow: + count = await self.uow.repository.subscriptions._count(Subscription, conditions) + return count > 0 async def clear_subscription_cache(self, subscription_id: int, user_telegram_id: int) -> None: diff --git a/src/services/transaction.py b/src/services/transaction.py index fbc7790..2f230aa 100644 --- a/src/services/transaction.py +++ b/src/services/transaction.py @@ -38,12 +38,16 @@ async def create(self, user: UserDto, transaction: TransactionDto) -> Transactio data["pricing"] = transaction.pricing.model_dump(mode="json") db_transaction = Transaction(**data, user_telegram_id=user.telegram_id) - db_created_transaction = await self.uow.repository.transactions.create(db_transaction) + + async with self.uow: + db_created_transaction = await self.uow.repository.transactions.create(db_transaction) + logger.info(f"Created transaction '{transaction.payment_id}' for user '{user.telegram_id}'") return TransactionDto.from_model(db_created_transaction) # type: ignore[return-value] async def get(self, payment_id: UUID) -> Optional[TransactionDto]: - db_transaction = await self.uow.repository.transactions.get(payment_id) + async with self.uow: + db_transaction = await self.uow.repository.transactions.get(payment_id) if db_transaction: logger.debug(f"Retrieved transaction '{payment_id}'") @@ -53,25 +57,32 @@ async def get(self, payment_id: UUID) -> Optional[TransactionDto]: return TransactionDto.from_model(db_transaction) async def get_by_user(self, telegram_id: int) -> list[TransactionDto]: - db_transactions = await self.uow.repository.transactions.get_by_user(telegram_id) + async with self.uow: + db_transactions = await self.uow.repository.transactions.get_by_user(telegram_id) + logger.debug(f"Retrieved '{len(db_transactions)}' transactions for user '{telegram_id}'") return TransactionDto.from_model_list(db_transactions) async def get_all(self) -> list[TransactionDto]: - db_transactions = await self.uow.repository.transactions.get_all() + async with self.uow: + db_transactions = await self.uow.repository.transactions.get_all() + logger.debug(f"Retrieved '{len(db_transactions)}' total transactions") return TransactionDto.from_model_list(db_transactions) async def get_by_status(self, status: TransactionStatus) -> list[TransactionDto]: - db_transactions = await self.uow.repository.transactions.get_by_status(status) + async with self.uow: + db_transactions = await self.uow.repository.transactions.get_by_status(status) + logger.debug(f"Retrieved '{len(db_transactions)}' transactions with status '{status}'") return TransactionDto.from_model_list(db_transactions) async def update(self, transaction: TransactionDto) -> Optional[TransactionDto]: - db_updated_transaction = await self.uow.repository.transactions.update( - payment_id=transaction.payment_id, - **transaction.changed_data, - ) + async with self.uow: + db_updated_transaction = await self.uow.repository.transactions.update( + payment_id=transaction.payment_id, + **transaction.changed_data, + ) if db_updated_transaction: logger.info(f"Updated transaction '{transaction.payment_id}' successfully") @@ -84,11 +95,15 @@ async def update(self, transaction: TransactionDto) -> Optional[TransactionDto]: return TransactionDto.from_model(db_updated_transaction) async def count(self) -> int: - count = await self.uow.repository.transactions.count() + async with self.uow: + count = await self.uow.repository.transactions.count() + logger.debug(f"Total transactions count: '{count}'") return count async def count_by_status(self, status: TransactionStatus) -> int: - count = await self.uow.repository.transactions.count_by_status(status) + async with self.uow: + count = await self.uow.repository.transactions.count_by_status(status) + logger.debug(f"Transactions count with status '{status}': '{count}'") return count diff --git a/src/services/user.py b/src/services/user.py index bf72941..7f4a959 100644 --- a/src/services/user.py +++ b/src/services/user.py @@ -63,8 +63,9 @@ async def create(self, aiogram_user: AiogramUser) -> UserDto: ), ) db_user = User(**user.model_dump()) - db_created_user = await self.uow.repository.users.create(db_user) - await self.uow.commit() + + async with self.uow: + db_created_user = await self.uow.repository.users.create(db_user) await self.clear_user_cache(user.telegram_id) logger.info(f"Created new user '{user.telegram_id}'") @@ -82,8 +83,9 @@ async def create_from_panel(self, remna_user: RemnaUserDto) -> UserDto: language=self.config.default_locale, ) db_user = User(**user.model_dump()) - db_created_user = await self.uow.repository.users.create(db_user) - await self.uow.commit() + + async with self.uow: + db_created_user = await self.uow.repository.users.create(db_user) await self.clear_user_cache(user.telegram_id) logger.info(f"Created new user '{user.telegram_id}' from panel") @@ -91,7 +93,8 @@ async def create_from_panel(self, remna_user: RemnaUserDto) -> UserDto: @redis_cache(prefix="get_user", ttl=TIME_5M) async def get(self, telegram_id: int) -> Optional[UserDto]: - db_user = await self.uow.repository.users.get(telegram_id) + async with self.uow: + db_user = await self.uow.repository.users.get(telegram_id) if db_user: logger.debug(f"Retrieved user '{telegram_id}'") @@ -101,10 +104,11 @@ async def get(self, telegram_id: int) -> Optional[UserDto]: return UserDto.from_model(db_user) async def update(self, user: UserDto) -> Optional[UserDto]: - db_updated_user = await self.uow.repository.users.update( - telegram_id=user.telegram_id, - **user.prepare_changed_data(), - ) + async with self.uow: + db_updated_user = await self.uow.repository.users.update( + telegram_id=user.telegram_id, + **user.prepare_changed_data(), + ) if db_updated_user: await self.clear_user_cache(db_updated_user.telegram_id) @@ -156,7 +160,8 @@ async def compare_and_update( return await self.update(user) async def delete(self, user: UserDto) -> bool: - result = await self.uow.repository.users.delete(user.telegram_id) + async with self.uow: + result = await self.uow.repository.users.delete(user.telegram_id) if result: await self.clear_user_cache(user.telegram_id) @@ -166,62 +171,83 @@ async def delete(self, user: UserDto) -> bool: return result async def get_by_partial_name(self, query: str) -> list[UserDto]: - db_users = await self.uow.repository.users.get_by_partial_name(query) + async with self.uow: + db_users = await self.uow.repository.users.get_by_partial_name(query) + logger.debug(f"Retrieved '{len(db_users)}' users for query '{query}'") return UserDto.from_model_list(db_users) async def get_by_referral_code(self, referral_code: str) -> Optional[UserDto]: - user = await self.uow.repository.users.get_by_referral_code(referral_code) + async with self.uow: + user = await self.uow.repository.users.get_by_referral_code(referral_code) + return UserDto.from_model(user) @redis_cache(prefix="users_count", ttl=TIME_10M) async def count(self) -> int: - count = await self.uow.repository.users.count() + async with self.uow: + count = await self.uow.repository.users.count() + logger.debug(f"Total users count: '{count}'") return count @redis_cache(prefix="get_by_role", ttl=TIME_10M) async def get_by_role(self, role: UserRole) -> list[UserDto]: - db_users = await self.uow.repository.users.filter_by_role(role) + async with self.uow: + db_users = await self.uow.repository.users.filter_by_role(role) + logger.debug(f"Retrieved '{len(db_users)}' users with role '{role}'") return UserDto.from_model_list(db_users) @redis_cache(prefix="get_blocked_users", ttl=TIME_10M) async def get_blocked_users(self) -> list[UserDto]: - db_users = await self.uow.repository.users.filter_by_blocked(blocked=True) + async with self.uow: + db_users = await self.uow.repository.users.filter_by_blocked(blocked=True) + logger.debug(f"Retrieved '{len(db_users)}' blocked users") return UserDto.from_model_list(list(reversed(db_users))) @redis_cache(prefix="get_all", ttl=TIME_10M) async def get_all(self) -> list[UserDto]: - db_users = await self.uow.repository.users.get_all() + async with self.uow: + db_users = await self.uow.repository.users.get_all() + logger.debug(f"Retrieved '{len(db_users)}' users") return UserDto.from_model_list(db_users) async def set_block(self, user: UserDto, blocked: bool) -> None: user.is_blocked = blocked - await self.uow.repository.users.update( - user.telegram_id, - **user.prepare_changed_data(), - ) + + async with self.uow: + await self.uow.repository.users.update( + user.telegram_id, + **user.prepare_changed_data(), + ) + await self.clear_user_cache(user.telegram_id) logger.info(f"Set block={blocked} for user '{user.telegram_id}'") async def set_bot_blocked(self, user: UserDto, blocked: bool) -> None: user.is_bot_blocked = blocked - await self.uow.repository.users.update( - user.telegram_id, - **user.prepare_changed_data(), - ) + + async with self.uow: + await self.uow.repository.users.update( + user.telegram_id, + **user.prepare_changed_data(), + ) + await self.clear_user_cache(user.telegram_id) logger.info(f"Set bot_blocked={blocked} for user '{user.telegram_id}'") async def set_role(self, user: UserDto, role: UserRole) -> None: user.role = role - await self.uow.repository.users.update( - user.telegram_id, - **user.prepare_changed_data(), - ) + + async with self.uow: + await self.uow.repository.users.update( + user.telegram_id, + **user.prepare_changed_data(), + ) + await self.clear_user_cache(user.telegram_id) logger.info(f"Set role='{role.name}' for user '{user.telegram_id}'") @@ -231,11 +257,12 @@ async def update_recent_activity(self, telegram_id: int) -> None: await self._add_to_recent_activity(RecentActivityUsersKey(), telegram_id) async def get_recent_registered_users(self) -> list[UserDto]: - db_users = await self.uow.repository.users._get_many( - User, - order_by=User.id.asc(), - limit=RECENT_REGISTERED_MAX_COUNT, - ) + async with self.uow: + db_users = await self.uow.repository.users._get_many( + User, + order_by=User.id.asc(), + limit=RECENT_REGISTERED_MAX_COUNT, + ) logger.debug(f"Retrieved '{len(db_users)}' recent registered users") return UserDto.from_model_list(list(reversed(db_users))) @@ -314,26 +341,32 @@ async def search_users(self, message: Message) -> list[UserDto]: return found_users async def set_current_subscription(self, telegram_id: int, subscription_id: int) -> None: - await self.uow.repository.users.update( - telegram_id=telegram_id, - current_subscription_id=subscription_id, - ) + async with self.uow: + await self.uow.repository.users.update( + telegram_id=telegram_id, + current_subscription_id=subscription_id, + ) + await self.clear_user_cache(telegram_id) logger.info(f"Set current_subscription='{subscription_id}' for user '{telegram_id}'") async def delete_current_subscription(self, telegram_id: int) -> None: - await self.uow.repository.users.update( - telegram_id=telegram_id, - current_subscription_id=None, - ) + async with self.uow: + await self.uow.repository.users.update( + telegram_id=telegram_id, + current_subscription_id=None, + ) + await self.clear_user_cache(telegram_id) logger.info(f"Delete current subscription for user '{telegram_id}'") async def add_points(self, user: Union[BaseUserDto, UserDto], points: int) -> None: - await self.uow.repository.users.update( - telegram_id=user.telegram_id, - points=user.points + points, - ) + async with self.uow: + await self.uow.repository.users.update( + telegram_id=user.telegram_id, + points=user.points + points, + ) + await self.clear_user_cache(user.telegram_id) logger.info(f"Add '{points}' points for user '{user.telegram_id}'")