Skip to content
This repository was archived by the owner on Jul 9, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion app/common/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import sys
from pathlib import Path
from typing import Any

from aiosmtplib import SMTP
from cryptography.fernet import Fernet
from pydantic import AmqpDsn, BaseModel, Field, PostgresDsn, computed_field
from pydantic import AmqpDsn, BaseModel, Field, PostgresDsn, computed_field, RedisDsn
from pydantic_settings import BaseSettings, SettingsConfigDict
from redis.asyncio import ConnectionPool
from sqlalchemy import MetaData
from sqlalchemy.ext.asyncio import AsyncAttrs, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
Expand Down Expand Up @@ -105,6 +107,21 @@ def postgres_dsn(self) -> str:
path=self.postgres_database,
).unicode_string()

redis_host: str = "localhost"
redis_port: int = 5800
redis_password: str = "test"
redis_pochta_stream: str = "pochta:send"

@computed_field
@property
def redis_dsn(self) -> str:
return RedisDsn.build(
scheme="redis",
password=self.redis_password,
host=self.redis_host,
port=self.redis_port,
).unicode_string()

mq_host: str = "localhost"
mq_port: int = 5672
mq_username: str = "guest"
Expand Down Expand Up @@ -160,6 +177,10 @@ class Base(AsyncAttrs, DeclarativeBase, MappingBase):
metadata = db_meta


redis_pool: ConnectionPool[Any] = ConnectionPool.from_url(
settings.redis_dsn, decode_responses=True, max_connections=20
)

pochta_producer = RabbitDirectProducer(queue_name=settings.mq_pochta_queue)

password_reset_cryptography = CryptographyProvider(
Expand Down
180 changes: 180 additions & 0 deletions app/common/redis_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import asyncio
import logging
from collections.abc import Callable
from typing import Any, Final, Protocol, TypeVar, get_type_hints, ClassVar

from pydantic import BaseModel, ValidationError
from redis.asyncio import Redis
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff
from redis.exceptions import ConnectionError, ResponseError, TimeoutError

from app.common.config import settings

BLOCK_TIME_MS: Final[int] = 2000

T = TypeVar("T", bound=BaseModel, contravariant=True)


class MessageHandlerProtocol(Protocol[T]):
async def __call__(self, message: T) -> None:
pass
# TODO nq remove protocol


class ConsumerException(Exception):
message: ClassVar[str]
requeue: ClassVar[bool]

def __init__(self, message_override: str | None = None) -> None:
super().__init__(message_override or self.message)


class SMTPTimeoutException(ConsumerException):
message = "SMTP sad(("
requeue = True


class RedisStreamConsumer:
def __init__(
self,
stream_name: str,
group_name: str,
consumer_name: str,
model: type[T],
message_handler: MessageHandlerProtocol[T],
) -> None:
self.stream_name = stream_name
self.group_name = group_name
self.consumer_name = consumer_name
self.model = model
self.message_handler = message_handler

self.redis_client = Redis.from_url(
url=settings.redis_dsn,
decode_responses=True,
retry=Retry(backoff=ExponentialBackoff(cap=10, base=1), retries=10),
retry_on_error=[ConnectionError, TimeoutError],
)

async def create_group_if_not_exist(self) -> None:
try:
await self.redis_client.xgroup_create(
name=self.stream_name,
groupname=self.group_name,
id="$",
mkstream=True,
)
except ResponseError as response_exc:
# Не поднимаем ошибку о том, что группа уже создана (`BUSYGROUP`)
if "BUSYGROUP" not in str(response_exc):
raise

async def process_message(self, message_id: str, data: dict[str, str]) -> None:
# TODO nq fixup error messages & expand extras
try:
validated_data = self.model.model_validate(data)
except ValidationError as e: # TODO nq mb move to handle_messages
logging.error(
"Invalid message payload",
extra={"original_message": data},
exc_info=e,
)
await self.redis_client.xack( # type: ignore[no-untyped-call]
self.stream_name,
self.group_name,
message_id,
)
return

try:
await self.message_handler(validated_data)
# TODO nq ConsumerException?
except Exception as e: # noqa PIE786 # TODO nq mb move to handle_messages
logging.error(
f"Error in {self.consumer_name} while processing message {data}",
exc_info=e,
)

await self.redis_client.xack( # type: ignore[no-untyped-call]
self.stream_name,
self.group_name,
message_id,
)

async def handle_messages(self) -> None:
await self.create_group_if_not_exist()

last_message_id: str = "0"

while True: # noqa WPS457 # required for continuous message handling
messages = await self.redis_client.xreadgroup(
groupname=self.group_name,
consumername=self.consumer_name,
streams={self.stream_name: last_message_id},
count=1,
block=BLOCK_TIME_MS,
)
if len(messages) == 0:
continue
elif len(messages[0][1]) == 0:
last_message_id = ">"
continue

message_id, data = messages[0][1][0]
await self.process_message(message_id, data)
if last_message_id != ">":
last_message_id = message_id

async def run(self) -> None:
while True: # noqa WPS457 # required for continuous running
try:
await self.handle_messages()
except asyncio.CancelledError:
await self.redis_client.close() # TODO nq move to destruct?
break
except Exception as e: # noqa PIE786
logging.error(
f"An error occurred in worker {self.consumer_name}: {e}",
exc_info=e,
)
await asyncio.sleep(2)
# TODO nq backoff & give up after 10 tries
# or remove `while True` completely
continue


class RedisRouter:
def __init__(self) -> None:
self.consumers: list[RedisStreamConsumer] = []
self.tasks: list[asyncio.Task[Any]] = []

def add_consumer(
self, stream_name: str, group_name: str, consumer_name: str
) -> Callable[[MessageHandlerProtocol[T]], None]:
def redis_consumer_wrapper(func: MessageHandlerProtocol[T]) -> None:
model = next(iter(get_type_hints(func).values())) # TODO nq signature
if not issubclass(model, BaseModel):
raise TypeError(f"Expected a subclass of BaseModel, got {model}")
worker_instance = RedisStreamConsumer(
stream_name=stream_name,
group_name=group_name,
consumer_name=consumer_name,
model=model,
message_handler=func,
)
self.consumers.append(worker_instance)

return redis_consumer_wrapper

def include_router(self, router: "RedisRouter") -> None:
self.consumers.extend(router.consumers)

async def run_consumers(self) -> None:
self.tasks.extend(
[asyncio.create_task(consumer.run()) for consumer in self.consumers]
)

async def terminate_consumers(self) -> None:
for task in self.tasks:
task.cancel()
10 changes: 9 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@

from app import pochta, supbot, users
from app.common.bridges.config_bdg import public_users_bridge
from app.common.config import Base, engine, pochta_producer, sessionmaker, settings
from app.common.config import (
Base,
engine,
pochta_producer,
redis_pool,
sessionmaker,
settings,
)
from app.common.sqlalchemy_ext import session_context
from app.common.starlette_cors_ext import CorrectCORSMiddleware

Expand Down Expand Up @@ -47,6 +54,7 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
yield

await rabbit_connection.close()
await redis_pool.disconnect()


app = FastAPI(
Expand Down
Empty file.
14 changes: 14 additions & 0 deletions app/pochta/dependencies/redis_dep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Annotated

from fastapi import Depends
from redis.asyncio import Redis

from app.common.config import redis_pool


async def get_redis_connection() -> Redis[str]:
# TODO nq add backoff?
return Redis(connection_pool=redis_pool, decode_responses=True)


RedisConnection = Annotated[Redis[str], Depends(get_redis_connection)]
8 changes: 7 additions & 1 deletion app/pochta/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from app.common.config import settings
from app.common.fastapi_ext import APIRouterExt
from app.pochta.routes import pochta_mub
from app.common.redis_ext import RedisRouter
from app.pochta.routes import pochta_mub, pochta_rds
from app.users.utils.mub import MUBProtection

mub_router = APIRouterExt(prefix="/mub", dependencies=[MUBProtection])
Expand All @@ -13,9 +14,14 @@
api_router = APIRouterExt()
api_router.include_router(mub_router)

redis_router = RedisRouter()
redis_router.include_router(pochta_rds.router)


@asynccontextmanager
async def lifespan() -> AsyncIterator[None]:
if settings.production_mode and settings.email is None:
logging.warning("Configuration for email service is missing")
await redis_router.run_consumers()
yield
await redis_router.terminate_consumers()
Empty file added app/pochta/routes/__init__.py
Empty file.
10 changes: 10 additions & 0 deletions app/pochta/routes/pochta_mub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from app.common.config import settings, smtp_client
from app.common.fastapi_ext import APIRouterExt
from app.pochta.dependencies.redis_dep import RedisConnection

router = APIRouterExt(tags=["pochta mub"])

Expand All @@ -30,3 +31,12 @@ async def send_email_from_file(

async with smtp_client as smtp:
await smtp.send_message(message)


@router.post("/")
async def home(r: RedisConnection) -> str:
await r.xadd(
settings.redis_pochta_stream,
{"key": "value"},
)
return f"Message was added to stream {settings.redis_pochta_stream}"
49 changes: 49 additions & 0 deletions app/pochta/routes/pochta_rds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Literal

from pydantic import BaseModel, Field

from app.common.config import settings
from app.common.redis_ext import RedisRouter


class RegistrationEmailV1Data(BaseModel):
template: Literal["registration-v1"] = "registration-v1"
email_confirmation_token: str


class RegistrationEmailV2Data(BaseModel):
template: Literal["registration-v2"] = "registration-v2"
email_confirmation_token: str
username: str


class PasswordResetEmailData(BaseModel):
template: Literal["password-reset-v1"] = "password-reset-v1"
reset_confirmation_token: str


class EmailSendRequest(BaseModel):
email: str
data: RegistrationEmailV1Data | RegistrationEmailV2Data | PasswordResetEmailData = Field(discriminator="template")


# {{ data.email_confirmation_token }}

# {"email": "test@test.test", "data": {"template": "registration-v1", "email_confirmation_token": ""}}
# {"email": "test@test.test", "data": {"template": "password-reset-v1", "reset_confirmation_token": ""}}


class PochtaSchema(BaseModel):
key: str


router = RedisRouter()


@router.add_consumer(
stream_name=settings.redis_pochta_stream,
group_name="pochta:group",
consumer_name="pochta_consumer",
)
async def process_email_message(message: PochtaSchema) -> None:
print(f"Message: {message}") # noqa T201 # temporary print for debugging
Empty file added app/pochta/workers/__init__.py
Empty file.
Loading
Loading