Skip to content
This repository was archived by the owner on Jul 9, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ extend-ignore =
WPS428 # fails to understand overloading
WPS465 # fails to understand pipe-unions for types
WPS601 # fails to same-name class and instance attributes (pydantic & sqlalchemy)
# to many
WPS201 WPS202 WPS204 WPS210 WPS214 WPS217 WPS218 WPS220 WPS221 WPS234 WPS235
# too many
WPS201 WPS202 WPS204 WPS210 WPS211 WPS214 WPS217 WPS218 WPS220 WPS221 WPS230 WPS231 WPS234 WPS235

# don't block features
WPS100 # utils is a module name
Expand Down
8 changes: 8 additions & 0 deletions app/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from cryptography.fernet import Fernet
from dotenv import load_dotenv
from redis.asyncio import ConnectionPool
from sqlalchemy import MetaData
from sqlalchemy.ext.asyncio import AsyncAttrs, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
Expand All @@ -31,6 +32,9 @@
DB_URL: str = getenv("DB_LINK", "postgresql+psycopg://test:test@localhost:5432/test")
DB_SCHEMA: str | None = getenv("DB_SCHEMA", None)

REDIS_URL: str = getenv("REDIS_URL", "redis://localhost:6379")
REDIS_POCHTA_STREAM: str = getenv("REDIS_POCHTA_STREAM", "pochta:send")

MQ_URL: str = getenv("MQ_URL", "amqp://guest:guest@localhost/")
MQ_POCHTA_QUEUE: str = getenv("MQ_POCHTA_QUEUE", "pochta.send")

Expand Down Expand Up @@ -77,6 +81,10 @@ class Base(AsyncAttrs, DeclarativeBase, MappingBase):
metadata = db_meta


redis_pool = ConnectionPool.from_url(
REDIS_URL, decode_responses=True, max_connections=20
)

pochta_producer = RabbitDirectProducer(queue_name=MQ_POCHTA_QUEUE)

password_reset_cryptography = CryptographyProvider(
Expand Down
159 changes: 159 additions & 0 deletions app/common/redis_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import asyncio
import logging
from collections.abc import Callable
from typing import Any, Final, Protocol, TypeVar, get_type_hints

from pydantic import BaseModel, ValidationError
from redis.asyncio import Redis
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff
from redis.exceptions import (
ConnectionError,
MaxConnectionsError,
ResponseError,
TimeoutError,
)

from app.common.config import REDIS_URL

BLOCK_TIME_MS: Final[int] = 2000

T = TypeVar("T", bound=BaseModel, contravariant=True)


class MessageHandlerProtocol(Protocol[T]):
async def __call__(self, message: T) -> None:
pass
# TODO remove protocol


class RedisStreamConsumer:
def __init__(
self,
stream_name: str,
group_name: str,
consumer_name: str,
model: type[T],
message_handler: MessageHandlerProtocol[T],
) -> None:
self.stream_name: str = stream_name
self.group_name: str = group_name
self.consumer_name: str = consumer_name
self.model: type[T] = model
self.message_handler: MessageHandlerProtocol[T] = message_handler

self.redis_client = Redis.from_url(
REDIS_URL,
decode_responses=True,
retry=Retry(ExponentialBackoff(cap=10, base=1), 10),
retry_on_error=(ConnectionError, TimeoutError, MaxConnectionsError),
)

async def create_group_if_not_exist(self) -> None:
try:
await self.redis_client.xgroup_create(
name=self.stream_name,
groupname=self.group_name,
id="$",
mkstream=True,
)
except ResponseError as response_exc:
if "BUSYGROUP" not in str(response_exc):
raise

async def process_message(self, message_id: str, data: dict[str, str]) -> None:
try:
validated_data = self.model.model_validate(data)
except ValidationError:
logging.error(
"Invalid message payload",
extra={"original_message": data},
)
await self.redis_client.xack(self.stream_name, self.group_name, message_id)
return

try:
await self.message_handler(validated_data)
except Exception as handling_exc: # noqa PIE786
logging.error(
f"Error in {self.consumer_name} while processing message {data}",
exc_info=handling_exc,
)
return
await self.redis_client.xack(self.stream_name, self.group_name, message_id)

async def handle_messages(self) -> None:
await self.create_group_if_not_exist()

last_message_id: str = "0"

while True: # noqa WPS457 # required for continuous message handling
messages = await self.redis_client.xreadgroup(
groupname=self.group_name,
consumername=self.consumer_name,
streams={self.stream_name: last_message_id},
count=1,
block=BLOCK_TIME_MS,
)
if len(messages) == 0:
continue
elif len(messages[0][1]) == 0:
last_message_id = ">"
for message_id, data in messages[0][1]:
await self.process_message(message_id, data)
if last_message_id != ">":
last_message_id = message_id

async def run(self) -> None:
while True: # noqa WPS457 # required for continuous running
try:
await self.handle_messages()
except asyncio.CancelledError:
await self.redis_client.close()
break
except (ConnectionError, TimeoutError, MaxConnectionsError):
await asyncio.sleep(10)
continue
except Exception as handling_exc: # noqa PIE786
logging.error(
f"An error occurred in worker {self.consumer_name}: {handling_exc}",
exc_info=handling_exc,
)
await asyncio.sleep(2)
continue


class RedisRouter:
def __init__(self) -> None:
self.consumers: list[RedisStreamConsumer] = []
self.tasks: list[asyncio.Task[Any]] = []

def add_consumer(
self, stream_name: str, group_name: str, consumer_name: str
) -> Callable[[MessageHandlerProtocol[T]], None]:
def redis_consumer_wrapper(func: MessageHandlerProtocol[T]) -> None:
model = next(iter(get_type_hints(func).values()))
if not issubclass(model, BaseModel):
raise TypeError(f"Expected a subclass of BaseModel, got {model}.")
worker_instance = RedisStreamConsumer(
stream_name=stream_name,
group_name=group_name,
consumer_name=consumer_name,
model=model,
message_handler=func,
)
self.consumers.append(worker_instance)

return redis_consumer_wrapper

def include_router(self, router: "RedisRouter") -> None:
self.consumers.extend(router.consumers)

async def run_consumers(self) -> None:
self.tasks.extend(
[asyncio.create_task(consumer.run()) for consumer in self.consumers]
)

async def terminate_consumers(self) -> None:
for task in self.tasks:
task.cancel()
7 changes: 5 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from starlette.responses import Response
from starlette.staticfiles import StaticFiles

from app import supbot, users
from app import pochta, supbot, users
from app.common.config import (
DATABASE_MIGRATED,
MQ_URL,
PRODUCTION_MODE,
Base,
engine,
pochta_producer,
redis_pool,
sessionmaker,
)
from app.common.sqlalchemy_ext import session_context
Expand Down Expand Up @@ -47,10 +48,11 @@ async def lifespan(_: FastAPI) -> AsyncIterator[None]:
async with AsyncExitStack() as stack:
await stack.enter_async_context(users.lifespan())
await stack.enter_async_context(supbot.lifespan())

await stack.enter_async_context(pochta.lifespan())
yield

await rabbit_connection.close()
await redis_pool.disconnect()


app = FastAPI(
Expand Down Expand Up @@ -86,6 +88,7 @@ async def custom_swagger_ui_html() -> Response:
allow_headers=["*"],
)

app.include_router(pochta.api_router)
app.include_router(users.api_router)
app.include_router(supbot.api_router)

Expand Down
3 changes: 3 additions & 0 deletions app/pochta/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from app.pochta.main import api_router, lifespan

__all__ = ["api_router", "lifespan"]
Empty file.
13 changes: 13 additions & 0 deletions app/pochta/dependencies/redis_dep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Annotated

from fastapi import Depends
from redis.asyncio import Redis

from app.common.config import redis_pool


async def get_redis_connection() -> Redis:
return Redis(connection_pool=redis_pool)


RedisConnection = Annotated[Redis, Depends(get_redis_connection)]
24 changes: 24 additions & 0 deletions app/pochta/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from app.common.fastapi_ext import APIRouterExt
from app.common.redis_ext import RedisRouter
from app.pochta.routes import pochta_mub
from app.pochta.workers import pochta_rds
from app.users.utils.mub import MUBProtection

mub_router = APIRouterExt(prefix="/mub", dependencies=[MUBProtection])
mub_router.include_router(pochta_mub.router, prefix="/pochta-service")

api_router = APIRouterExt()
api_router.include_router(mub_router)

redis_router = RedisRouter()
redis_router.include_router(pochta_rds.router)


@asynccontextmanager
async def lifespan() -> AsyncIterator[None]:
await redis_router.run_consumers()
yield
await redis_router.terminate_consumers()
Empty file added app/pochta/routes/__init__.py
Empty file.
14 changes: 14 additions & 0 deletions app/pochta/routes/pochta_mub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from app.common.config import REDIS_POCHTA_STREAM
from app.common.fastapi_ext import APIRouterExt
from app.pochta.dependencies.redis_dep import RedisConnection

router = APIRouterExt(tags=["pochta mub"])


@router.post("/")
async def home(r: RedisConnection) -> dict[str, str]:
await r.xadd(
REDIS_POCHTA_STREAM,
{"key": "value"},
)
return {"msg": f"Message was added to stream {REDIS_POCHTA_STREAM}"}
Empty file added app/pochta/workers/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions app/pochta/workers/pochta_rds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pydantic import BaseModel

from app.common.config import REDIS_POCHTA_STREAM
from app.common.redis_ext import RedisRouter


class PochtaSchema(BaseModel):
key: str


router = RedisRouter()


@router.add_consumer(
stream_name=REDIS_POCHTA_STREAM,
group_name="pochta:group",
consumer_name="pochta_consumer",
)
async def process_email_message(message: PochtaSchema) -> None:
print(f"Message: {message}") # noqa T201 Temporary print for debugging
14 changes: 14 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
services:
redis:
image: redis:7.4.0-alpine
ports:
- "6379:6379"
healthcheck:
test: ["CMD", "redis-cli", "ping"]
start_period: 60s
interval: 10s
timeout: 60s
retries: 5

mq:
image: rabbitmq:3.12.10-management-alpine
volumes:
Expand Down Expand Up @@ -62,6 +73,8 @@ services:
profiles:
- app
depends_on:
redis:
condition: service_healthy
mq:
condition: service_healthy
db:
Expand All @@ -80,6 +93,7 @@ services:
environment:
WATCHFILES_FORCE_POLLING: true
DB_LINK: postgresql+psycopg://test:test@db:5432/test
REDIS_URL: redis://redis:6379
MQ_URL: amqp://guest:guest@mq
DB_SCHEMA: xi_auth
# DATABASE_MIGRATED: "1"
Loading