Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,677 changes: 951 additions & 726 deletions Cargo.lock

Large diffs are not rendered by default.

15 changes: 4 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,10 @@ serde_json = "1"
utoipa = { version = "5", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "9", features = ["axum"] }

# Database (multi-backend: PostgreSQL, MySQL, SQLite)
sqlx = { version = "0.8", features = [
"runtime-tokio",
"tls-rustls",
"any",
"postgres",
"mysql",
"sqlite",
"chrono",
"uuid",
] }
# Database (multi-backend: PostgreSQL, MySQL)
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4", "with-uuid-1"] }
deadpool-postgres = "0.14"
mysql_async = { version = "0.35", default-features = false, features = ["default-rustls"] }

# Environment
dotenvy = "0.15"
Expand Down
85 changes: 35 additions & 50 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,68 +3,53 @@ pub mod session;

pub use session::SessionManager;

use sqlx::AnyPool;

/// Create a database pool from DATABASE_URL env var.
/// Supports: postgres://, mysql://, sqlite://
///
/// Falls back to legacy POSTGRES_* env vars if DATABASE_URL is not set.
pub async fn create_pool() -> anyhow::Result<AnyPool> {
// Install all drivers
sqlx::any::install_default_drivers();
/// Detect backend from DATABASE_URL or legacy POSTGRES_* vars.
/// Returns the connection URL and backend type.
pub fn resolve_database_url() -> (String, DbBackend) {
if let Ok(url) = std::env::var("DATABASE_URL") {
let backend = if url.starts_with("mysql") {
DbBackend::MySQL
} else {
DbBackend::Postgres
};
return (url, backend);
}

let database_url = if let Ok(url) = std::env::var("DATABASE_URL") {
url
} else if std::env::var("POSTGRES_HOST").is_ok() || std::env::var("POSTGRES_USER").is_ok() {
// Legacy fallback: only if POSTGRES_* vars are explicitly set
if std::env::var("POSTGRES_HOST").is_ok() || std::env::var("POSTGRES_USER").is_ok() {
let host = std::env::var("POSTGRES_HOST").unwrap_or_else(|_| "localhost".to_string());
let port = std::env::var("POSTGRES_PORT").unwrap_or_else(|_| "5432".to_string());
let user = std::env::var("POSTGRES_USER").unwrap_or_else(|_| "postgres".to_string());
let password =
std::env::var("POSTGRES_PASSWORD").unwrap_or_else(|_| "postgres".to_string());
let db = std::env::var("POSTGRES_DB").unwrap_or_else(|_| "wagateway".to_string());
let url = format!("postgres://{}:{}@{}:{}/{}", user, password, host, port, db);
return (url, DbBackend::Postgres);
}

format!("postgres://{}:{}@{}:{}/{}", user, password, host, port, db)
} else {
// Default: SQLite, fully local, no external services needed
tracing::info!("No DATABASE_URL set, using local SQLite database");
"sqlite://wa-rs.db".to_string()
};

let backend_name = if database_url.starts_with("postgres") {
"PostgreSQL"
} else if database_url.starts_with("mysql") {
"MySQL"
} else if database_url.starts_with("sqlite") {
"SQLite"
} else {
"Unknown"
};

// Mask password in log
let masked_url = mask_url(&database_url);
tracing::info!("Connecting to {} ({})", backend_name, masked_url);

// SQLite: ensure create mode is enabled
let connect_url = if database_url.starts_with("sqlite") && !database_url.contains("mode=") {
if database_url.contains('?') {
format!("{}&mode=rwc", database_url)
} else {
format!("{}?mode=rwc", database_url)
}
} else {
database_url.clone()
};

let pool = AnyPool::connect(&connect_url).await?;
if std::env::var("MYSQL_HOST").is_ok() || std::env::var("MYSQL_USER").is_ok() {
let host = std::env::var("MYSQL_HOST").unwrap_or_else(|_| "localhost".to_string());
let port = std::env::var("MYSQL_PORT").unwrap_or_else(|_| "3306".to_string());
let user = std::env::var("MYSQL_USER").unwrap_or_else(|_| "root".to_string());
let password = std::env::var("MYSQL_PASSWORD").unwrap_or_else(|_| "".to_string());
let db = std::env::var("MYSQL_DB").unwrap_or_else(|_| "wagateway".to_string());
let url = format!("mysql://{}:{}@{}:{}/{}", user, password, host, port, db);
return (url, DbBackend::MySQL);
}

tracing::info!("Connected to {}", backend_name);
// Default: PostgreSQL localhost
(
"postgres://postgres:postgres@localhost:5432/wagateway".to_string(),
DbBackend::Postgres,
)
}

Ok(pool)
#[derive(Clone, Copy, Debug)]
pub enum DbBackend {
Postgres,
MySQL,
}

fn mask_url(url: &str) -> String {
// Mask password in connection URL for safe logging
pub fn mask_url(url: &str) -> String {
if let Some(at_pos) = url.find('@') {
if let Some(colon_pos) = url[..at_pos].rfind(':') {
if let Some(slash_pos) = url[..colon_pos].rfind('/') {
Expand Down
209 changes: 69 additions & 140 deletions src/db/schema.rs
Original file line number Diff line number Diff line change
@@ -1,109 +1,68 @@
use sqlx::AnyPool;
use crate::db::session::DbPool;

pub async fn init_schema(pool: &AnyPool) -> anyhow::Result<()> {
let backend = detect_backend(pool);

match backend {
DbBackend::Postgres => init_postgres(pool).await,
DbBackend::MySQL => {
init_mysql(pool).await?;
migrate_mysql(pool).await
}
DbBackend::SQLite => init_sqlite(pool).await,
pub async fn init_schema(pool: &DbPool) -> anyhow::Result<()> {
match pool {
DbPool::Postgres(pg) => init_postgres(pg).await,
DbPool::MySQL(my) => init_mysql(my).await,
}
}

#[derive(Debug, Clone, Copy)]
enum DbBackend {
Postgres,
MySQL,
SQLite,
}

fn detect_backend(pool: &AnyPool) -> DbBackend {
let url = std::env::var("DATABASE_URL").unwrap_or_default();
if url.starts_with("postgres") {
DbBackend::Postgres
} else if url.starts_with("mysql") {
DbBackend::MySQL
} else if url.starts_with("sqlite") {
DbBackend::SQLite
} else {
// Fallback: check pool kind name
let name = format!("{:?}", pool);
if name.contains("Postgres") {
DbBackend::Postgres
} else if name.contains("MySql") {
DbBackend::MySQL
} else {
DbBackend::SQLite
}
}
}

async fn init_postgres(pool: &AnyPool) -> anyhow::Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id VARCHAR(255) PRIMARY KEY,
name VARCHAR(255),
storage_path TEXT NOT NULL,
phone_number VARCHAR(50),
push_name VARCHAR(255),
status VARCHAR(50) NOT NULL DEFAULT 'disconnected',
is_logged_in BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_connected_at TIMESTAMPTZ
async fn init_postgres(pool: &deadpool_postgres::Pool) -> anyhow::Result<()> {
let client = pool.get().await?;

client
.execute(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id VARCHAR(255) PRIMARY KEY,
name VARCHAR(255),
storage_path TEXT NOT NULL,
phone_number VARCHAR(50),
push_name VARCHAR(255),
status VARCHAR(50) NOT NULL DEFAULT 'disconnected',
is_logged_in BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_connected_at TIMESTAMPTZ
)
"#,
&[],
)
"#,
)
.execute(pool)
.await?;
.await?;

sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS webhooks (
id VARCHAR(255) PRIMARY KEY,
session_id VARCHAR(255) NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
url TEXT NOT NULL,
events TEXT NOT NULL DEFAULT '',
secret VARCHAR(255),
enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
client
.execute(
r#"
CREATE TABLE IF NOT EXISTS webhooks (
id VARCHAR(255) PRIMARY KEY,
session_id VARCHAR(255) NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
url TEXT NOT NULL,
events TEXT NOT NULL DEFAULT '',
secret VARCHAR(255),
enabled BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
"#,
&[],
)
"#,
)
.execute(pool)
.await?;
.await?;

sqlx::query("CREATE INDEX IF NOT EXISTS idx_webhooks_session_id ON webhooks(session_id)")
.execute(pool)
client
.execute(
"CREATE INDEX IF NOT EXISTS idx_webhooks_session_id ON webhooks(session_id)",
&[],
)
.await?;

Ok(())
}

async fn migrate_mysql(pool: &AnyPool) -> anyhow::Result<()> {
// SQLx Any driver doesn't support TINYINT, TIMESTAMP, or DATETIME — use INT/VARCHAR
let migrations = [
"ALTER TABLE sessions MODIFY COLUMN is_logged_in INT NOT NULL DEFAULT 0",
"ALTER TABLE sessions MODIFY COLUMN storage_path VARCHAR(500) NOT NULL",
"ALTER TABLE sessions MODIFY COLUMN created_at VARCHAR(30) NOT NULL DEFAULT '1970-01-01 00:00:00'",
"ALTER TABLE sessions MODIFY COLUMN updated_at VARCHAR(30) NOT NULL DEFAULT '1970-01-01 00:00:00'",
"ALTER TABLE sessions MODIFY COLUMN last_connected_at VARCHAR(30) NULL",
"ALTER TABLE webhooks MODIFY COLUMN enabled INT NOT NULL DEFAULT 1",
"ALTER TABLE webhooks MODIFY COLUMN created_at VARCHAR(30) NOT NULL DEFAULT '1970-01-01 00:00:00'",
"ALTER TABLE webhooks MODIFY COLUMN url VARCHAR(2000) NOT NULL",
];
for sql in &migrations {
let _ = sqlx::query(sql).execute(pool).await;
}
Ok(())
}
async fn init_mysql(pool: &mysql_async::Pool) -> anyhow::Result<()> {
use mysql_async::prelude::*;

let mut conn = pool.get_conn().await?;

async fn init_mysql(pool: &AnyPool) -> anyhow::Result<()> {
sqlx::query(
conn.query_drop(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id VARCHAR(255) PRIMARY KEY,
Expand All @@ -113,16 +72,15 @@ async fn init_mysql(pool: &AnyPool) -> anyhow::Result<()> {
push_name VARCHAR(255),
status VARCHAR(50) NOT NULL DEFAULT 'disconnected',
is_logged_in INT NOT NULL DEFAULT 0,
created_at VARCHAR(30) NOT NULL,
updated_at VARCHAR(30) NOT NULL,
created_at VARCHAR(30) NOT NULL DEFAULT '1970-01-01 00:00:00',
updated_at VARCHAR(30) NOT NULL DEFAULT '1970-01-01 00:00:00',
last_connected_at VARCHAR(30) NULL
)
"#,
)
.execute(pool)
.await?;

sqlx::query(
conn.query_drop(
r#"
CREATE TABLE IF NOT EXISTS webhooks (
id VARCHAR(255) PRIMARY KEY,
Expand All @@ -131,57 +89,28 @@ async fn init_mysql(pool: &AnyPool) -> anyhow::Result<()> {
events VARCHAR(2000) NOT NULL DEFAULT '',
secret VARCHAR(255),
enabled INT NOT NULL DEFAULT 1,
created_at VARCHAR(30) NOT NULL,
created_at VARCHAR(30) NOT NULL DEFAULT '1970-01-01 00:00:00',
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE,
INDEX idx_webhooks_session_id (session_id)
)
"#,
)
.execute(pool)
.await?;

Ok(())
}

async fn init_sqlite(pool: &AnyPool) -> anyhow::Result<()> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
name TEXT,
storage_path TEXT NOT NULL,
phone_number TEXT,
push_name TEXT,
status TEXT NOT NULL DEFAULT 'disconnected',
is_logged_in INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
last_connected_at TEXT
)
"#,
)
.execute(pool)
.await?;

sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS webhooks (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
url TEXT NOT NULL,
events TEXT NOT NULL DEFAULT '',
secret TEXT,
enabled INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)
"#,
)
.execute(pool)
.await?;

sqlx::query("CREATE INDEX IF NOT EXISTS idx_webhooks_session_id ON webhooks(session_id)")
.execute(pool)
.await?;
// Auto-migrate existing tables with incompatible types
let migrations = [
"ALTER TABLE sessions MODIFY COLUMN is_logged_in INT NOT NULL DEFAULT 0",
"ALTER TABLE sessions MODIFY COLUMN storage_path VARCHAR(500) NOT NULL",
"ALTER TABLE sessions MODIFY COLUMN created_at VARCHAR(30) NOT NULL DEFAULT '1970-01-01 00:00:00'",
"ALTER TABLE sessions MODIFY COLUMN updated_at VARCHAR(30) NOT NULL DEFAULT '1970-01-01 00:00:00'",
"ALTER TABLE sessions MODIFY COLUMN last_connected_at VARCHAR(30) NULL",
"ALTER TABLE webhooks MODIFY COLUMN enabled INT NOT NULL DEFAULT 1",
"ALTER TABLE webhooks MODIFY COLUMN url VARCHAR(2000) NOT NULL",
"ALTER TABLE webhooks MODIFY COLUMN created_at VARCHAR(30) NOT NULL DEFAULT '1970-01-01 00:00:00'",
];
for sql in &migrations {
let _ = conn.query_drop(*sql).await;
}

Ok(())
}
Loading
Loading