|
1 | 1 | import logging |
| 2 | +import asyncio |
| 3 | +import sys |
2 | 4 | from contextlib import asynccontextmanager |
3 | 5 | from typing import AsyncGenerator |
4 | 6 |
|
5 | 7 | from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession |
6 | 8 | from sqlalchemy.exc import SQLAlchemyError |
7 | 9 |
|
8 | | -from userbot.src.config import DB_TYPE, DB_HOST, DB_PORT, DB_USER, DB_PASS, DB_NAME |
| 10 | +from userbot.src.config import ( |
| 11 | + DB_TYPE, DB_HOST, DB_PORT, DB_USER, DB_PASS, DB_NAME, |
| 12 | + DB_CONN_RETRIES, DB_CONN_RETRY_DELAY |
| 13 | +) |
9 | 14 | from userbot.src.db.models import Base |
10 | 15 |
|
11 | 16 | logger: logging.Logger = logging.getLogger(__name__) |
|
31 | 36 |
|
32 | 37 | async def initialize_database() -> None: |
33 | 38 | """ |
34 | | - Creates all tables in the database based on the SQLAlchemy models. |
| 39 | + Connects to the database and creates all tables based on the SQLAlchemy models. |
| 40 | + Includes a retry mechanism to handle database startup delays. |
35 | 41 | """ |
36 | | - async with async_engine.begin() as conn: |
37 | | - await conn.run_sync(Base.metadata.create_all) |
38 | | - logger.info("Database schema initialization check complete.") |
| 42 | + for attempt in range(DB_CONN_RETRIES): |
| 43 | + try: |
| 44 | + async with async_engine.begin() as conn: |
| 45 | + await conn.run_sync(Base.metadata.create_all) |
| 46 | + logger.info("Database schema initialization check complete.") |
| 47 | + return # Success, exit the function |
| 48 | + except (ConnectionRefusedError, OSError) as e: |
| 49 | + if attempt < DB_CONN_RETRIES - 1: |
| 50 | + logger.warning( |
| 51 | + f"Database connection failed (attempt {attempt + 1}/{DB_CONN_RETRIES}): {e}. " |
| 52 | + f"Retrying in {DB_CONN_RETRY_DELAY} seconds..." |
| 53 | + ) |
| 54 | + await asyncio.sleep(DB_CONN_RETRY_DELAY) |
| 55 | + else: |
| 56 | + logger.critical( |
| 57 | + f"Could not connect to the database after {DB_CONN_RETRIES} attempts. " |
| 58 | + "Please ensure the database is running and the .env file is configured correctly." |
| 59 | + ) |
| 60 | + raise # Re-raise the final exception to stop the application |
| 61 | + |
| 62 | + # This part should not be reachable if all retries fail, but as a safeguard: |
| 63 | + logger.critical("Exhausted all retries to connect to the database. Exiting.") |
| 64 | + sys.exit(1) |
| 65 | + |
39 | 66 |
|
40 | 67 | @asynccontextmanager |
41 | 68 | async def get_db() -> AsyncGenerator[AsyncSession, None]: |
|
0 commit comments