diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..8ecbb4e6 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,11 @@ +backend/target +backend/data +backend/*.sqlite* +backend/remote-dbs +backend/uploads +frontend/node_modules +frontend/.svelte-kit +frontend/build +node_modules +.git +.claude diff --git a/.gitignore b/.gitignore index d657935d..01bcb4c1 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,6 @@ node_modules backend/uploads .dev-ports .playwright-cli +.claude/worktrees/ +.vercel .claude diff --git a/.vercelignore b/.vercelignore new file mode 100644 index 00000000..a51dfaf6 --- /dev/null +++ b/.vercelignore @@ -0,0 +1,7 @@ +backend/target +backend/data +backend/*.sqlite* +backend/remote-dbs +.git +.claude +node_modules diff --git a/AGENTS.md b/AGENTS.md index eefc50d4..24d05e73 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -256,7 +256,7 @@ Copy the appropriate template to `frontend/.env` for your use case: #### `backend/src/main.rs` - **Entry point** for the backend server -- Initializes Axum router with WebSocket handler at `/api`, image upload/serving routes, and Airtable sync endpoint +- Initializes Axum router with per-cohort WebSocket handler at `/api/ws/:cohort_name`, admin/user REST endpoints, and image upload/serving routes - Implements port binding with fallback logic (tries sequential ports if in use) - Manages uploads directory and request body size limits - Depends on `lib.rs` for `AppState`, `handle_socket.rs` for WebSocket handling @@ -266,7 +266,7 @@ Copy the appropriate template to `frontend/.env` for your use case: - Defines `AppState` struct containing DB connection pool, pub/sub subscriptions, and rate limiters - Configures separate admin/user rate limit quotas for expensive queries and mutations - Includes protobuf module generation via `build.rs` -- Declares modules: `websocket_api`, `auth`, `db`, `handle_socket`, `subscriptions`, `airtable_users`, `convert`, `seed`, `test_utils` +- Declares modules: `websocket_api`, `auth`, `db`, `global_db`, `handle_socket`, `subscriptions`, `convert`, `seed`, `test_utils` #### `backend/src/handle_socket.rs` - Core WebSocket request/response handler (~1150 lines) @@ -301,11 +301,6 @@ Copy the appropriate template to `frontend/.env` for your use case: - Implements `From` trait for all domain types (Portfolio, Market, Order, Trade, Transfer, Account, Auction) - Converts Rust Decimal to protobuf floats, timestamps to protobuf Timestamp format -#### `backend/src/airtable_users.rs` -- Syncs Airtable student records to Kinde and database -- Creates Kinde accounts and DB entries, assigns initial balances based on product ID -- Caches Kinde API tokens, logs errors back to Airtable - #### `backend/src/seed.rs` - Development seed data (feature-gated behind `dev-mode`) - Seeds fresh databases with test accounts (Alice, Bob, Charlie, Admin), markets, orders, and trades diff --git a/DEPLOY.md b/DEPLOY.md new file mode 100644 index 00000000..c15e0e90 --- /dev/null +++ b/DEPLOY.md @@ -0,0 +1,122 @@ +# Deployment + +## Architecture + +- **Frontend**: Static SPA on Vercel (SvelteKit with `adapter-static`) +- **Backend**: Rust binary on Fly.io (Docker, SQLite with persistent volume) +- Frontend and backend are on different domains, so HTTP API calls use cross-origin requests with CORS. + +## Staging + +- **Frontend**: https://platform-staging-five-gamma.vercel.app (Vercel project `platform-staging` under `trading-bootcamp` team) +- **Backend**: https://trading-bootcamp-staging.fly.dev (Fly app `trading-bootcamp-staging`) + +### Deploy staging backend + +```bash +fly deploy --config backend/fly.staging.toml +``` + +### Deploy staging frontend + +```bash +vercel --prod --scope trading-bootcamp +``` + +### Vercel env vars (Production environment) + +| Variable | Value | +|----------|-------| +| `PUBLIC_KINDE_CLIENT_ID` | `a9869bb1225848b9ad5bad2a04b72b5f` | +| `PUBLIC_KINDE_DOMAIN` | `https://account.trading.camp` | +| `PUBLIC_KINDE_REDIRECT_URI` | `https://platform-staging-five-gamma.vercel.app` | +| `PUBLIC_SERVER_URL` | `wss://trading-bootcamp-staging.fly.dev/api` | +| `PUBLIC_TEST_AUTH` | `false` | + +**Important**: When setting env vars via `vercel env add`, pipe with `printf` (not `echo`) to avoid embedding trailing newlines: +```bash +printf 'value' | vercel env add VAR_NAME production --scope trading-bootcamp +``` + +### Kinde setup + +Add the frontend URL to "Allowed callback URLs" in the Kinde application settings for client ID `a9869bb1225848b9ad5bad2a04b72b5f`. + +## Code changes required for deployment + +### 1. `.dockerignore` + +Excludes `backend/target`, `backend/data`, SQLite files, `node_modules`, `.git`, `.claude` etc. Without this, Docker context transfer is ~733MB instead of ~700KB. + +### 2. `backend/Dockerfile` modifications + +- **`ENV SQLX_OFFLINE=true`**: SQLx does compile-time query checking against a live database by default. In Docker there's no database, so offline mode uses the pre-generated `.sqlx/` cache instead. +- **`COPY ./backend/global_migrations /app/global_migrations`**: The multi-cohort feature added a `global_migrations/` directory that needs to be in the runtime image. + +### 3. `.vercelignore` + +Excludes the same heavy directories from Vercel uploads. + +### 4. CORS: `backend/Cargo.toml` + `backend/src/main.rs` + +Since frontend (Vercel) and backend (Fly.io) are on different domains, the browser blocks cross-origin API requests. The fix: + +- Added `"cors"` feature to `tower-http` in `Cargo.toml` +- Replaced the manual `SetResponseHeaderLayer` for `Access-Control-Allow-Origin: *` with `CorsLayer::permissive()`, which properly handles preflight OPTIONS requests and allows the `Authorization` header + +### 5. Cross-origin HTTP API calls: `frontend/src/lib/apiBase.ts` + +In development, the Vite dev server proxies `/api/*` to the backend (see `frontend/vite.config.ts`). In production, there's no proxy — the frontend and backend are on different domains. + +`apiBase.ts` derives the HTTP base URL from `PUBLIC_SERVER_URL`: +- `wss://host.fly.dev/api` → `https://host.fly.dev` +- `ws://localhost:8080` → `http://localhost:8080` + +`cohortApi.ts` and `adminApi.ts` use `API_BASE` to make absolute URL requests instead of relative `/api/...` paths. This works in both dev (Vite proxy still intercepts) and production (direct cross-origin requests). + +### 6. `frontend/src/routes/[cohort_name]/docs/[slug]/+page.svelte` + +Fixed markdown import paths — the file moved one level deeper into `[cohort_name]/` so the relative imports needed an extra `../` (e.g. `../../../../../docs/` → `../../../../../../docs/`). + +### 7. `backend/fly.staging.toml` + +Fly.io config for the staging app. Key differences from production (`fly.toml`): +- `app = 'trading-bootcamp-staging'` +- Separate persistent volume mount (`trading_bootcamp_staging`) +- Staging database path + +## Gotchas + +### Stale `.sqlx/` cache breaks Fly builds + +Because the Dockerfile sets `SQLX_OFFLINE=true`, the build relies entirely on the committed `.sqlx/` cache. If queries in `db.rs` change (e.g., new column like `account.color`) without regenerating the cache, Fly builds fail with cryptic errors like: + +``` +error: key must be a string at line 3 column 1 + --> src/db.rs:262:9 +``` + +This is *not* a syntax error in `db.rs` — it's SQLx failing to parse/find a matching cache file. Fix: + +```bash +cd backend +sqlx migrate run # ensure local DB matches migrations +cargo sqlx prepare -- --features dev-mode --tests # regenerate .sqlx/ +git add backend/.sqlx && git commit # commit the regenerated cache +``` + +Always include `--tests` so queries used only in tests are cached too. CLAUDE.md's "Required Checks" section mentions this but it's easy to miss — treat any SQL change in `db.rs` as requiring a cache regen before pushing. + +### Vercel deployment URL vs. production alias + +`vercel --prod` prints a line like: + +``` +Production: https://platform-staging-fgbm5rn8e-trading-bootcamp.vercel.app +``` + +That is the **immutable deployment URL** for this specific build, not a different project. The stable production alias (`https://platform-staging-five-gamma.vercel.app`) is updated to point to this deployment. Verify with `vercel project ls --scope trading-bootcamp` — the `Latest Production URL` column shows the alias. + +### Two `.vercel/` project links in the repo + +Both `/.vercel/` (repo root) and `/frontend/.vercel/` exist and point to projects named `platform-staging`, but with **different org IDs**. Running `vercel` from the repo root uses the root link, which is the correct one (the `trading-bootcamp` team's `platform-staging` project that aliases to `platform-staging-five-gamma.vercel.app`). Do not `cd frontend` before deploying. diff --git a/backend/.gitignore b/backend/.gitignore index 3c74fad6..e82a09d7 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -3,3 +3,6 @@ target/ db.sqlite db.sqlite-shm db.sqlite-wal +*.sqlite +*.sqlite-shm +*.sqlite-wal diff --git a/backend/.sqlx/query-aae2c0449b7f47212ceac7d9fc1231968e23b191d4c9f7e9d414e9b37f86872d.json b/backend/.sqlx/query-411e2737e334e6e70fc1b288e146ae70c4360d7c7f087942f34fc880da0960c7.json similarity index 74% rename from backend/.sqlx/query-aae2c0449b7f47212ceac7d9fc1231968e23b191d4c9f7e9d414e9b37f86872d.json rename to backend/.sqlx/query-411e2737e334e6e70fc1b288e146ae70c4360d7c7f087942f34fc880da0960c7.json index 1030f40c..feb8ce89 100644 --- a/backend/.sqlx/query-aae2c0449b7f47212ceac7d9fc1231968e23b191d4c9f7e9d414e9b37f86872d.json +++ b/backend/.sqlx/query-411e2737e334e6e70fc1b288e146ae70c4360d7c7f087942f34fc880da0960c7.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT id, name, kinde_id IS NOT NULL as \"is_user: bool\", universe_id, color FROM account", + "query": "SELECT id, name, (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) as \"is_user: bool\", universe_id, color FROM account", "describe": { "columns": [ { @@ -37,8 +37,8 @@ false, false, false, - true + false ] }, - "hash": "aae2c0449b7f47212ceac7d9fc1231968e23b191d4c9f7e9d414e9b37f86872d" + "hash": "411e2737e334e6e70fc1b288e146ae70c4360d7c7f087942f34fc880da0960c7" } diff --git a/backend/.sqlx/query-433242ec70fe924887b6fde177112404d773129357a3845da91d0ec752fd143a.json b/backend/.sqlx/query-433242ec70fe924887b6fde177112404d773129357a3845da91d0ec752fd143a.json new file mode 100644 index 00000000..12ce415c --- /dev/null +++ b/backend/.sqlx/query-433242ec70fe924887b6fde177112404d773129357a3845da91d0ec752fd143a.json @@ -0,0 +1,26 @@ +{ + "db_name": "SQLite", + "query": "\n SELECT id AS \"id!\", name\n FROM account\n WHERE global_user_id = ?\n ", + "describe": { + "columns": [ + { + "name": "id!", + "ordinal": 0, + "type_info": "Int64" + }, + { + "name": "name", + "ordinal": 1, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + true, + false + ] + }, + "hash": "433242ec70fe924887b6fde177112404d773129357a3845da91d0ec752fd143a" +} diff --git a/backend/.sqlx/query-943427a7ba64285bf8dc6f35a61f4e025b5c2ec12c2448768425245ffe49290d.json b/backend/.sqlx/query-6026fe9a76ae8dd30ead15b3c809a78a546129d2e9bf4e99993b4c92e07f6ad9.json similarity index 62% rename from backend/.sqlx/query-943427a7ba64285bf8dc6f35a61f4e025b5c2ec12c2448768425245ffe49290d.json rename to backend/.sqlx/query-6026fe9a76ae8dd30ead15b3c809a78a546129d2e9bf4e99993b4c92e07f6ad9.json index 46a73e4b..8e0991c5 100644 --- a/backend/.sqlx/query-943427a7ba64285bf8dc6f35a61f4e025b5c2ec12c2448768425245ffe49290d.json +++ b/backend/.sqlx/query-6026fe9a76ae8dd30ead15b3c809a78a546129d2e9bf4e99993b4c92e07f6ad9.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "SELECT EXISTS(\n SELECT 1 FROM account WHERE id = ? AND kinde_id IS NOT NULL\n ) as \"exists!: bool\"", + "query": "SELECT EXISTS(\n SELECT 1 FROM account WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL)\n ) as \"exists!: bool\"", "describe": { "columns": [ { @@ -16,5 +16,5 @@ null ] }, - "hash": "943427a7ba64285bf8dc6f35a61f4e025b5c2ec12c2448768425245ffe49290d" + "hash": "6026fe9a76ae8dd30ead15b3c809a78a546129d2e9bf4e99993b4c92e07f6ad9" } diff --git a/backend/.sqlx/query-0c7c90c8753c66208ee11c0cc2826e7044881937d5d1aadccb5dc412dec9f55a.json b/backend/.sqlx/query-62de5a5e7713956efa9ee7f382a36c1c1646d5f224b97ba88985192518773f92.json similarity index 64% rename from backend/.sqlx/query-0c7c90c8753c66208ee11c0cc2826e7044881937d5d1aadccb5dc412dec9f55a.json rename to backend/.sqlx/query-62de5a5e7713956efa9ee7f382a36c1c1646d5f224b97ba88985192518773f92.json index e2532a22..b42f459b 100644 --- a/backend/.sqlx/query-0c7c90c8753c66208ee11c0cc2826e7044881937d5d1aadccb5dc412dec9f55a.json +++ b/backend/.sqlx/query-62de5a5e7713956efa9ee7f382a36c1c1646d5f224b97ba88985192518773f92.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "\n SELECT EXISTS(\n SELECT 1\n FROM account\n WHERE id = ? AND kinde_id IS NOT NULL\n ) as \"exists!: bool\"\n ", + "query": "\n SELECT EXISTS(\n SELECT 1\n FROM account\n WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL)\n ) AS \"exists!: bool\"\n ", "describe": { "columns": [ { @@ -16,5 +16,5 @@ null ] }, - "hash": "0c7c90c8753c66208ee11c0cc2826e7044881937d5d1aadccb5dc412dec9f55a" + "hash": "62de5a5e7713956efa9ee7f382a36c1c1646d5f224b97ba88985192518773f92" } diff --git a/backend/.sqlx/query-573c8349f125b65d642b439cacd5e900b4fad9fb3a396c3a625c2adbfdb30a7f.json b/backend/.sqlx/query-98aba038f48d4dc2f929955a9fa90166defea7deb41560c9bcd35a8c60b7bb59.json similarity index 64% rename from backend/.sqlx/query-573c8349f125b65d642b439cacd5e900b4fad9fb3a396c3a625c2adbfdb30a7f.json rename to backend/.sqlx/query-98aba038f48d4dc2f929955a9fa90166defea7deb41560c9bcd35a8c60b7bb59.json index 0c19932e..56de5698 100644 --- a/backend/.sqlx/query-573c8349f125b65d642b439cacd5e900b4fad9fb3a396c3a625c2adbfdb30a7f.json +++ b/backend/.sqlx/query-98aba038f48d4dc2f929955a9fa90166defea7deb41560c9bcd35a8c60b7bb59.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "\n SELECT EXISTS(\n SELECT 1\n FROM account\n WHERE id = ? AND kinde_id IS NOT NULL\n ) AS \"exists!: bool\"\n ", + "query": "\n SELECT EXISTS(\n SELECT 1\n FROM account\n WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL)\n ) as \"exists!: bool\"\n ", "describe": { "columns": [ { @@ -16,5 +16,5 @@ null ] }, - "hash": "573c8349f125b65d642b439cacd5e900b4fad9fb3a396c3a625c2adbfdb30a7f" + "hash": "98aba038f48d4dc2f929955a9fa90166defea7deb41560c9bcd35a8c60b7bb59" } diff --git a/backend/.sqlx/query-78b6fbf503baf041f4ebf76c310e39c63ca56bb86a64d001e0853014cf71f340.json b/backend/.sqlx/query-b20cceac140cfcdb62468aa6a2d19dbf1743f978837d9bb971868ed5dd17f0db.json similarity index 64% rename from backend/.sqlx/query-78b6fbf503baf041f4ebf76c310e39c63ca56bb86a64d001e0853014cf71f340.json rename to backend/.sqlx/query-b20cceac140cfcdb62468aa6a2d19dbf1743f978837d9bb971868ed5dd17f0db.json index 97a1b543..1f3149a7 100644 --- a/backend/.sqlx/query-78b6fbf503baf041f4ebf76c310e39c63ca56bb86a64d001e0853014cf71f340.json +++ b/backend/.sqlx/query-b20cceac140cfcdb62468aa6a2d19dbf1743f978837d9bb971868ed5dd17f0db.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "\n SELECT EXISTS (\n SELECT 1\n FROM account\n WHERE id = ? AND kinde_id IS NOT NULL\n ) AS \"exists!: bool\"\n ", + "query": "\n SELECT EXISTS (\n SELECT 1\n FROM account\n WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL)\n ) AS \"exists!: bool\"\n ", "describe": { "columns": [ { @@ -16,5 +16,5 @@ null ] }, - "hash": "78b6fbf503baf041f4ebf76c310e39c63ca56bb86a64d001e0853014cf71f340" + "hash": "b20cceac140cfcdb62468aa6a2d19dbf1743f978837d9bb971868ed5dd17f0db" } diff --git a/backend/.sqlx/query-c4f5b6e260992d4e2df494d9ba3ad09b7ca596147544f719c4b09d6020e8a01b.json b/backend/.sqlx/query-c4f5b6e260992d4e2df494d9ba3ad09b7ca596147544f719c4b09d6020e8a01b.json new file mode 100644 index 00000000..055da3fd --- /dev/null +++ b/backend/.sqlx/query-c4f5b6e260992d4e2df494d9ba3ad09b7ca596147544f719c4b09d6020e8a01b.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n SELECT id\n FROM account\n WHERE name = ? AND (global_user_id != ? OR global_user_id IS NULL)\n ", + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int64" + } + ], + "parameters": { + "Right": 2 + }, + "nullable": [ + true + ] + }, + "hash": "c4f5b6e260992d4e2df494d9ba3ad09b7ca596147544f719c4b09d6020e8a01b" +} diff --git a/backend/.sqlx/query-258bf770e5ae4ba380572b612c2aaa979ff496e650b43498854eb42e2b72848e.json b/backend/.sqlx/query-c5b871f0027f934fbacb865decf5fdd15baf76d68296c07a3bf00a3f33115f3e.json similarity index 67% rename from backend/.sqlx/query-258bf770e5ae4ba380572b612c2aaa979ff496e650b43498854eb42e2b72848e.json rename to backend/.sqlx/query-c5b871f0027f934fbacb865decf5fdd15baf76d68296c07a3bf00a3f33115f3e.json index fe52ad09..a594a3bd 100644 --- a/backend/.sqlx/query-258bf770e5ae4ba380572b612c2aaa979ff496e650b43498854eb42e2b72848e.json +++ b/backend/.sqlx/query-c5b871f0027f934fbacb865decf5fdd15baf76d68296c07a3bf00a3f33115f3e.json @@ -1,6 +1,6 @@ { "db_name": "SQLite", - "query": "\n SELECT\n id,\n name,\n kinde_id IS NOT NULL AS \"is_user: bool\",\n universe_id,\n color\n FROM account\n WHERE id = ?\n ", + "query": "\n SELECT\n id,\n name,\n (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) AS \"is_user: bool\",\n universe_id,\n color\n FROM account\n WHERE id = ?\n ", "describe": { "columns": [ { @@ -37,8 +37,8 @@ false, false, false, - true + false ] }, - "hash": "258bf770e5ae4ba380572b612c2aaa979ff496e650b43498854eb42e2b72848e" + "hash": "c5b871f0027f934fbacb865decf5fdd15baf76d68296c07a3bf00a3f33115f3e" } diff --git a/backend/.sqlx/query-fdcbfffc320184b5fc9c55e1892de89a23ba7b39ae38ca815a9274572f38ecd5.json b/backend/.sqlx/query-fdcbfffc320184b5fc9c55e1892de89a23ba7b39ae38ca815a9274572f38ecd5.json new file mode 100644 index 00000000..774bdea6 --- /dev/null +++ b/backend/.sqlx/query-fdcbfffc320184b5fc9c55e1892de89a23ba7b39ae38ca815a9274572f38ecd5.json @@ -0,0 +1,20 @@ +{ + "db_name": "SQLite", + "query": "\n INSERT INTO account (global_user_id, name, balance)\n VALUES (?, ?, ?)\n RETURNING id\n ", + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Int64" + } + ], + "parameters": { + "Right": 3 + }, + "nullable": [ + false + ] + }, + "hash": "fdcbfffc320184b5fc9c55e1892de89a23ba7b39ae38ca815a9274572f38ecd5" +} diff --git a/backend/Cargo.lock b/backend/Cargo.lock index e1d0244e..355152b5 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -261,7 +261,6 @@ dependencies = [ "tower-http", "tracing", "tracing-subscriber", - "urlencoding", "uuid", ] diff --git a/backend/Cargo.toml b/backend/Cargo.toml index eea1c80e..ba6cbd66 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -22,7 +22,7 @@ reqwest = { version = "0.12.5", default-features = false, features = [ serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.117" tokio = { version = "1.38.0", features = ["full"] } -tower-http = { version = "0.5.2", features = ["trace", "limit", "set-header"] } +tower-http = { version = "0.5.2", features = ["trace", "limit", "set-header", "cors"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } async-once-cell = "0.5.3" @@ -46,7 +46,6 @@ dashmap = "6.0.1" nonzero_ext = "0.3.0" rand = "0.8.5" tokio-stream = { version = "0.1.17", features = ["sync"] } -urlencoding = "2.1.3" tokio-util = { version = "0.7.10", features = ["io"] } mime = "0.3.17" bytes = "1.5.0" diff --git a/backend/Dockerfile b/backend/Dockerfile index 0341e7a0..d76be0a3 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y protobuf-compiler # Build application COPY ./backend . COPY ./schema /schema +ENV SQLX_OFFLINE=true RUN cargo build --release --bin backend # We do not need the Rust toolchain to run the binary! @@ -21,5 +22,6 @@ RUN apt-get update && apt install -y openssl ca-certificates WORKDIR /app COPY --from=builder /app/target/release/backend /usr/local/bin COPY ./backend/migrations /app/migrations +COPY ./backend/global_migrations /app/global_migrations ENTRYPOINT ["/usr/local/bin/backend"] diff --git a/backend/example.env b/backend/example.env index 5db904d6..a30afac4 100644 --- a/backend/example.env +++ b/backend/example.env @@ -1,4 +1,5 @@ KINDE_ISSUER=https://account.trading.camp KINDE_AUDIENCE=trading-server-api,a9869bb1225848b9ad5bad2a04b72b5f DATABASE_URL=sqlite://db.sqlite +GLOBAL_DATABASE_URL=sqlite://global.sqlite UPLOAD_DIR=./uploads diff --git a/backend/global_migrations/001_initial.sql b/backend/global_migrations/001_initial.sql new file mode 100644 index 00000000..6ec8ab01 --- /dev/null +++ b/backend/global_migrations/001_initial.sql @@ -0,0 +1,31 @@ +CREATE TABLE IF NOT EXISTS "global_user" ( + "id" INTEGER PRIMARY KEY, + "kinde_id" TEXT UNIQUE NOT NULL, + "display_name" TEXT NOT NULL, + "is_admin" BOOLEAN NOT NULL DEFAULT FALSE, + "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS "cohort" ( + "id" INTEGER PRIMARY KEY, + "name" TEXT NOT NULL UNIQUE, + "display_name" TEXT NOT NULL, + "db_path" TEXT NOT NULL UNIQUE, + "is_read_only" BOOLEAN NOT NULL DEFAULT FALSE, + "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL +); + +CREATE TABLE IF NOT EXISTS "cohort_member" ( + "id" INTEGER PRIMARY KEY, + "cohort_id" INTEGER NOT NULL REFERENCES "cohort", + "global_user_id" INTEGER REFERENCES "global_user", + "email" TEXT, + "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL, + UNIQUE("cohort_id", "global_user_id"), + UNIQUE("cohort_id", "email") +); + +CREATE TABLE IF NOT EXISTS "global_config" ( + "key" TEXT PRIMARY KEY, + "value" TEXT NOT NULL +); diff --git a/backend/global_migrations/002_add_initial_balance.sql b/backend/global_migrations/002_add_initial_balance.sql new file mode 100644 index 00000000..a0e431ad --- /dev/null +++ b/backend/global_migrations/002_add_initial_balance.sql @@ -0,0 +1 @@ +ALTER TABLE cohort_member ADD COLUMN initial_balance TEXT; diff --git a/backend/global_migrations/003_add_user_email.sql b/backend/global_migrations/003_add_user_email.sql new file mode 100644 index 00000000..29c04de2 --- /dev/null +++ b/backend/global_migrations/003_add_user_email.sql @@ -0,0 +1 @@ +ALTER TABLE "global_user" ADD COLUMN "email" TEXT; diff --git a/backend/migrations/20260302000000_add_global_user_id.sql b/backend/migrations/20260302000000_add_global_user_id.sql new file mode 100644 index 00000000..26694ee2 --- /dev/null +++ b/backend/migrations/20260302000000_add_global_user_id.sql @@ -0,0 +1,4 @@ +-- Add global_user_id column to account table for multi-cohort support +ALTER TABLE "account" ADD COLUMN "global_user_id" INTEGER DEFAULT NULL; +CREATE UNIQUE INDEX IF NOT EXISTS "idx_account_global_user_id" + ON "account" ("global_user_id") WHERE "global_user_id" IS NOT NULL; diff --git a/backend/migrations/20260305000000_add_account_color.sql b/backend/migrations/20260305000000_add_account_color.sql new file mode 100644 index 00000000..0a9b195b --- /dev/null +++ b/backend/migrations/20260305000000_add_account_color.sql @@ -0,0 +1 @@ +ALTER TABLE account ADD COLUMN color TEXT NOT NULL DEFAULT ''; diff --git a/backend/migrations/20260305000000_arbor_pixie_initial_balance.sql b/backend/migrations/20260305000001_arbor_pixie_initial_balance.sql similarity index 100% rename from backend/migrations/20260305000000_arbor_pixie_initial_balance.sql rename to backend/migrations/20260305000001_arbor_pixie_initial_balance.sql diff --git a/backend/migrations/20260306000001_add_account_color.sql b/backend/migrations/20260306000001_add_account_color.sql deleted file mode 100644 index f1b9170b..00000000 --- a/backend/migrations/20260306000001_add_account_color.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE "account" ADD COLUMN "color" TEXT; diff --git a/backend/src/airtable_users.rs b/backend/src/airtable_users.rs deleted file mode 100644 index c10d99b6..00000000 --- a/backend/src/airtable_users.rs +++ /dev/null @@ -1,530 +0,0 @@ -use std::{ - env, - sync::{LazyLock, RwLock}, - time::{Duration, Instant}, -}; - -use anyhow::Context; -use futures::future::join_all; - -use prost::Message; -use reqwest::{header, Client}; -use rust_decimal_macros::dec; -use serde::{Deserialize, Serialize}; -use urlencoding; - -use crate::{ - db::EnsureUserCreatedSuccess, - websocket_api::{server_message::Message as SM, Account, ServerMessage}, - AppState, -}; - -const TRADEGALA_PRODUCT_ID: &str = "2mR3AnL63Z"; -const TRADEGALA_PRODUCT_ID_ALT: &str = "ld6JAxWNn0"; -const ASH_PRODUCT_ID: &str = "0VNrVWONPg"; -const TRADEGALA_INITIAL_CLIPS: rust_decimal::Decimal = dec!(1000); -const ASH_INITIAL_CLIPS: rust_decimal::Decimal = dec!(2000); -const ERROR_SOURCE: &str = "exchange backend"; - -struct CachedKindeToken { - token: String, - expiry: Instant, -} - -static KINDE_TOKEN_CACHE: LazyLock>> = - LazyLock::new(|| RwLock::new(None)); - -#[derive(Debug, Deserialize)] -struct AirtableResponse { - records: Vec, -} - -#[derive(Debug, Deserialize)] -struct AirtableRecord { - id: String, - fields: AirtableFields, -} - -#[derive(Debug, Deserialize)] -struct AirtableFields { - #[serde(rename = "First Name")] - first_name: String, - #[serde(rename = "Last Name")] - last_name: String, - #[serde(rename = "Email")] - email: String, - #[serde(rename = "Ticket Status")] - #[allow(dead_code)] - ticket_status: Option, - #[serde(rename = "Product ID")] - #[allow(dead_code)] - product_id: Option, - #[serde(rename = "Transferred From ID")] - transferred_from_id: Option, - #[serde(rename = "Transferred From Email")] - #[allow(dead_code)] - transferred_from_email: Option, - #[serde(rename = "Initialized correctly?")] - initialized_correctly: Option, -} - -#[derive(Debug, Deserialize)] -struct KindeTokenResponse { - access_token: String, - #[allow(dead_code)] - token_type: String, - #[allow(dead_code)] - expires_in: u64, -} - -#[derive(Debug, Deserialize)] -struct KindeUsersResponse { - users: Option>, -} - -#[derive(Debug, Deserialize)] -struct KindeUser { - id: String, - #[allow(dead_code)] - email: String, - #[allow(dead_code)] - given_name: Option, - #[allow(dead_code)] - family_name: Option, -} - -#[derive(Debug, Serialize)] -struct KindeCreateUserRequest { - profile: KindeProfile, - identities: Vec, -} - -#[derive(Debug, Serialize)] -struct KindeProfile { - given_name: String, - family_name: String, -} - -#[derive(Debug, Serialize)] -struct KindeIdentity { - #[serde(rename = "type")] - identity_type: String, - is_verified: bool, - details: KindeIdentityDetails, -} - -#[derive(Debug, Serialize)] -struct KindeIdentityDetails { - #[serde(skip_serializing_if = "Option::is_none")] - email: Option, -} - -#[derive(Debug, Serialize)] -struct AirtableCreateRecordsRequest { - records: Vec, -} - -#[derive(Debug, Serialize)] -struct AirtableCreateRecord { - fields: AirtableErrorFields, -} - -#[derive(Debug, Serialize)] -struct AirtableErrorFields { - #[serde(rename = "Error Message")] - error_message: String, - #[serde(rename = "Source")] - source: String, -} - -#[derive(Debug, Serialize)] -struct AirtableUpdateRecordRequest { - fields: AirtableUpdateFields, -} - -#[derive(Debug, Serialize)] -struct AirtableUpdateFields { - #[serde(rename = "Initialized correctly?")] - initialized_correctly: bool, -} - -/// Fetches Airtable users, validates with Kinde, and ensures they exist in the database -/// -/// # Errors -/// Returns an error if API calls fail or environment variables are missing -pub async fn sync_airtable_users_to_kinde_and_db(app_state: AppState) -> anyhow::Result<()> { - let airtable_base_id = - env::var("AIRTABLE_BASE_ID").context("Missing AIRTABLE_BASE_ID environment variable")?; - let airtable_token = - env::var("AIRTABLE_TOKEN").context("Missing AIRTABLE_TOKEN environment variable")?; - - let client = Client::new(); - let airtable_url = - format!("https://api.airtable.com/v0/{airtable_base_id}/Tradegala%20Attendees"); - - let response = client - .get(&airtable_url) - .header("Authorization", format!("Bearer {airtable_token}")) - .send() - .await? - .json::() - .await?; - - let kinde_token = get_kinde_token(&client).await?; - - let futures = response - .records - .iter() - .filter(|record| !record.fields.initialized_correctly.is_some_and(|b| b)) - .map(|record| process_user(app_state.clone(), record, &kinde_token, &client)); - - let results = join_all(futures).await; - - let mut errors = Vec::new(); - for (result, record) in results.into_iter().zip(response.records.iter()) { - match result { - Ok(()) => (), - Err(e) => errors.push((record.fields.email.clone(), e)), - } - } - - if !errors.is_empty() { - return Err(anyhow::anyhow!( - "Errors occurred while processing users: {:?}", - errors - )); - } - - Ok(()) -} - -/// Helper function to process each user -async fn process_user( - app_state: AppState, - record: &AirtableRecord, - kinde_token: &str, - client: &Client, -) -> anyhow::Result<()> { - let email = &record.fields.email; - let first_name = &record.fields.first_name; - let last_name = &record.fields.last_name; - - let kinde_user = get_kinde_user_by_email(email, kinde_token, client).await?; - - let kinde_id = match kinde_user { - Some(user) => user.id, - None => create_kinde_user(email, first_name, last_name, kinde_token, client).await?, - }; - - let name = format!("{first_name} {last_name}"); - let result = app_state - .db - .ensure_user_created(&kinde_id, Some(&name), dec!(0)) - .await?; - - let id = match result.map_err(|e| anyhow::anyhow!("Couldn't create user {name}: {e:?}"))? { - EnsureUserCreatedSuccess { id, name: Some(_) } => { - let msg = ServerMessage { - request_id: String::new(), - message: Some(SM::AccountCreated(Account { - id, - name: name.clone(), - is_user: true, - universe_id: 0, - color: None, - })), - }; - app_state.subscriptions.send_public(msg); - tracing::info!("User {name} created"); - id - } - EnsureUserCreatedSuccess { id, name: None } => id, - }; - - if record - .fields - .transferred_from_id - .as_ref() - .is_some_and(|id| !id.is_empty()) - { - tracing::info!("User {name} has a transfer ticket, skipping pixie transfer"); - return Ok(()); - } - - let initial_clips = match record.fields.product_id.as_deref() { - Some(product_id) if product_id == TRADEGALA_PRODUCT_ID => TRADEGALA_INITIAL_CLIPS, - Some(product_id) if product_id == TRADEGALA_PRODUCT_ID_ALT => TRADEGALA_INITIAL_CLIPS, - Some(product_id) if product_id == ASH_PRODUCT_ID => ASH_INITIAL_CLIPS, - Some("") | None => { - tracing::info!("User {name} has no product ID, skipping pixie transfer"); - return Ok(()); - } - Some(unknown_product_id) => { - anyhow::bail!("Unknown product ID for user {name}: {unknown_product_id}"); - } - }; - - let transfer = app_state - .db - .ensure_arbor_pixie_transfer(id, initial_clips) - .await? - .map_err(|e| anyhow::anyhow!("Couldn't transfer initial clips to user {name}: {e:?}"))?; - - if let Some(transfer) = transfer { - let msg = ServerMessage { - request_id: String::new(), - message: Some(SM::TransferCreated(transfer.into())), - }; - app_state - .subscriptions - .send_private(id, msg.encode_to_vec().into()); - app_state.subscriptions.notify_portfolio(id); - tracing::info!("Pixie transfer created for user {name}"); - - // Update Airtable to indicate successful initialization - if let Err(e) = update_airtable_initialized_status(&record.id, client).await { - tracing::error!("Failed to update Airtable status for {name}: {e}"); - } else { - tracing::info!("Updated Airtable initialization status for {name}"); - } - } - - Ok(()) -} - -/// Updates the "Initialized correctly?" field in Airtable for a specific record -/// -/// # Errors -/// Returns an error if API call fails or environment variables are missing -async fn update_airtable_initialized_status( - record_id: &str, - client: &Client, -) -> anyhow::Result<()> { - let airtable_base_id = - env::var("AIRTABLE_BASE_ID").context("Missing AIRTABLE_BASE_ID environment variable")?; - let airtable_token = - env::var("AIRTABLE_TOKEN").context("Missing AIRTABLE_TOKEN environment variable")?; - - let airtable_url = - format!("https://api.airtable.com/v0/{airtable_base_id}/Tradegala%20Attendees/{record_id}"); - - let update_data = AirtableUpdateRecordRequest { - fields: AirtableUpdateFields { - initialized_correctly: true, - }, - }; - - let response = client - .patch(&airtable_url) - .header("Authorization", format!("Bearer {airtable_token}")) - .header("Content-Type", "application/json") - .json(&update_data) - .send() - .await?; - - if !response.status().is_success() { - let error_text = response.text().await?; - anyhow::bail!("Failed to update Airtable record: {}", error_text); - } - - Ok(()) -} - -/// Gets a Kinde access token, using a cached version if available and not expired -/// -/// # Errors -/// Returns an error if API call fails or environment variables are missing -async fn get_kinde_token(client: &Client) -> anyhow::Result { - { - let token_cache = KINDE_TOKEN_CACHE.read().unwrap(); - if let Some(cached) = &*token_cache { - if cached.expiry > Instant::now() { - // Token is still valid, return it - return Ok(cached.token.clone()); - } - } - } - - tracing::info!("Fetching new Kinde token"); - - let kinde_subdomain = - env::var("KINDE_SUBDOMAIN").context("Missing KINDE_SUBDOMAIN environment variable")?; - let kinde_client_id = - env::var("KINDE_CLIENT_ID").context("Missing KINDE_CLIENT_ID environment variable")?; - let kinde_client_secret = env::var("KINDE_CLIENT_SECRET") - .context("Missing KINDE_CLIENT_SECRET environment variable")?; - - let token_url = format!("https://{kinde_subdomain}.kinde.com/oauth2/token"); - - let form_data = [ - ("grant_type", "client_credentials"), - ("client_id", &kinde_client_id), - ("client_secret", &kinde_client_secret), - ( - "audience", - &format!("https://{kinde_subdomain}.kinde.com/api"), - ), - ("scope", "read:users create:users"), - ]; - - let response = client - .post(&token_url) - .form(&form_data) - .send() - .await? - .json::() - .await?; - - let expiry_buffer = Duration::from_secs(60); // 1 minute buffer - let expires_in = - Duration::from_secs(response.expires_in.saturating_sub(expiry_buffer.as_secs())); - let expiry = Instant::now() + expires_in; - - // Store token in cache - let token = response.access_token.clone(); - { - let mut token_cache = KINDE_TOKEN_CACHE.write().unwrap(); - *token_cache = Some(CachedKindeToken { - token: token.clone(), - expiry, - }); - } - - Ok(token) -} - -/// Gets a Kinde user by email -/// -/// # Errors -/// Returns an error if API call fails or environment variables are missing -async fn get_kinde_user_by_email( - email: &str, - token: &str, - client: &Client, -) -> anyhow::Result> { - let kinde_subdomain = - env::var("KINDE_SUBDOMAIN").context("Missing KINDE_SUBDOMAIN environment variable")?; - - let encoded_email = urlencoding::encode(email); - let users_url = - format!("https://{kinde_subdomain}.kinde.com/api/v1/users?email={encoded_email}"); - - let response = client - .get(&users_url) - .header("Authorization", format!("Bearer {token}")) - .send() - .await?; - - if !response.status().is_success() { - return Ok(None); - } - - let users_response = response.json::().await?; - let users = users_response.users.unwrap_or_default(); - - if users.len() > 1 { - tracing::warn!("Multiple Kinde users found for email {email}"); - } - - Ok(users.into_iter().next()) -} - -/// Creates a user in Kinde -/// -/// # Errors -/// Returns an error if API call fails or environment variables are missing -async fn create_kinde_user( - email: &str, - first_name: &str, - last_name: &str, - token: &str, - client: &Client, -) -> anyhow::Result { - let kinde_subdomain = - env::var("KINDE_SUBDOMAIN").context("Missing KINDE_SUBDOMAIN environment variable")?; - - let create_url = format!("https://{kinde_subdomain}.kinde.com/api/v1/user"); - - let user_data = KindeCreateUserRequest { - profile: KindeProfile { - given_name: first_name.to_string(), - family_name: last_name.to_string(), - }, - identities: vec![KindeIdentity { - identity_type: "email".to_string(), - is_verified: true, - details: KindeIdentityDetails { - email: Some(email.to_string()), - }, - }], - }; - - let mut headers = header::HeaderMap::new(); - headers.insert( - header::AUTHORIZATION, - header::HeaderValue::from_str(&format!("Bearer {token}"))?, - ); - headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ); - - let response = client - .post(&create_url) - .headers(headers) - .json(&user_data) - .send() - .await?; - - if !response.status().is_success() { - let error_text = response.text().await?; - anyhow::bail!("Failed to create Kinde user: {}", error_text); - } - - let user_data: serde_json::Value = response.json().await?; - let id = user_data["id"] - .as_str() - .context("No ID returned when creating Kinde user")? - .to_string(); - - Ok(id) -} - -/// Logs an error message to the Airtable "Application Errors" table -/// -/// # Errors -/// Returns an error if API call fails or environment variables are missing -pub async fn log_error_to_airtable(error_message: &str) -> anyhow::Result<()> { - let airtable_base_id = - env::var("AIRTABLE_BASE_ID").context("Missing AIRTABLE_BASE_ID environment variable")?; - let airtable_token = - env::var("AIRTABLE_TOKEN").context("Missing AIRTABLE_TOKEN environment variable")?; - - let client = Client::new(); - let airtable_url = - format!("https://api.airtable.com/v0/{airtable_base_id}/Application%20Errors"); - - let request_data = AirtableCreateRecordsRequest { - records: vec![AirtableCreateRecord { - fields: AirtableErrorFields { - error_message: error_message.to_string(), - source: ERROR_SOURCE.to_string(), - }, - }], - }; - - let response = client - .post(&airtable_url) - .header("Authorization", format!("Bearer {airtable_token}")) - .header("Content-Type", "application/json") - .json(&request_data) - .send() - .await?; - - if !response.status().is_success() { - let error_text = response.text().await?; - anyhow::bail!("Failed to log error to Airtable: {}", error_text); - } - - Ok(()) -} diff --git a/backend/src/auth.rs b/backend/src/auth.rs index bc96df57..594d5a06 100644 --- a/backend/src/auth.rs +++ b/backend/src/auth.rs @@ -37,12 +37,17 @@ pub struct AccessClaims { pub sub: String, #[serde(default)] pub roles: Vec, + /// Email from the token (populated in dev-mode test tokens; not present in real JWTs) + #[serde(default)] + pub email: Option, } #[derive(Debug, Deserialize)] struct IdClaims { pub name: String, pub sub: String, + #[serde(default)] + pub email: Option, } static AUTH_CONFIG: OnceCell = OnceCell::new(); @@ -59,6 +64,30 @@ impl FromRequestParts for AccessClaims { (StatusCode::UNAUTHORIZED, "Missing Authorization header").into_response() })?; let token = bearer.token(); + + // Dev-mode: support test tokens for REST endpoints + #[cfg(feature = "dev-mode")] + if let Some(rest) = token.strip_prefix("test::") { + let parts: Vec<&str> = rest.split("::").collect(); + if parts.len() >= 3 { + let kinde_id = parts[0]; + let is_admin = parts[2].eq_ignore_ascii_case("true"); + let email = parts + .get(3) + .map(ToString::to_string) + .filter(|e| !e.is_empty()); + let mut roles = vec![Role::Trader]; + if is_admin { + roles.push(Role::Admin); + } + return Ok(AccessClaims { + sub: kinde_id.to_string(), + roles, + email, + }); + } + } + let claims = validate_jwt(token).await.map_err(|e| { tracing::error!("JWT validation failed: {:?}", e); (StatusCode::UNAUTHORIZED, "Bad JWT").into_response() @@ -112,6 +141,7 @@ pub struct ValidatedClient { pub id: String, pub roles: Vec, pub name: Option, + pub email: Option, } /// # Errors @@ -135,10 +165,35 @@ pub async fn validate_access_and_id( Ok(ValidatedClient { id: access_claims.sub, roles: access_claims.roles, + email: id_claims.as_ref().and_then(|c| c.email.clone()), name: id_claims.map(|c| c.name), }) } +/// Validate an ID token and return its email if the subject matches `expected_sub`. +/// +/// # Errors +/// Fails if the token is invalid or the subject does not match. +pub async fn validate_id_token_email_for_sub( + id_token: &str, + expected_sub: &str, +) -> anyhow::Result> { + #[cfg(feature = "dev-mode")] + if id_token.starts_with("test::") { + let test_client = validate_test_token(id_token)?; + if test_client.id != expected_sub { + anyhow::bail!("sub mismatch"); + } + return Ok(test_client.email); + } + + let id_claims: IdClaims = validate_jwt(id_token).await?; + if id_claims.sub != expected_sub { + anyhow::bail!("sub mismatch"); + } + Ok(id_claims.email) +} + /// Test-only function to create a `ValidatedClient` from test credentials. /// Token format: `test::::::` /// Example: `test::user123::Test User::true` @@ -152,24 +207,27 @@ pub fn validate_test_token(token: &str) -> anyhow::Result { } let parts: Vec<&str> = token.split("::").collect(); - if parts.len() != 4 { - anyhow::bail!("Invalid test token format: expected test::::::"); + if parts.len() < 4 { + anyhow::bail!( + "Invalid test token format: expected test::::::[::]" + ); } let kinde_id = parts[1].to_string(); let name = parts[2].to_string(); let is_admin = parts[3].parse::().unwrap_or(false); + let email = parts + .get(4) + .map(ToString::to_string) + .filter(|e| !e.is_empty()); - let roles = if is_admin { - vec![Role::Admin] - } else { - vec![] - }; + let roles = if is_admin { vec![Role::Admin] } else { vec![] }; Ok(ValidatedClient { id: kinde_id, roles, name: Some(name), + email, }) } diff --git a/backend/src/db.rs b/backend/src/db.rs index 9a44d88a..edce5571 100644 --- a/backend/src/db.rs +++ b/backend/src/db.rs @@ -152,6 +152,111 @@ impl DB { } } + /// Initialize a DB from a specific file path (for multi-cohort support). + /// Does NOT run seed data. + #[instrument(err)] + pub async fn init_with_path(db_path: &str, create_if_missing: bool) -> anyhow::Result { + let connection_options = SqliteConnectOptions::new() + .filename(db_path) + .create_if_missing(create_if_missing) + .journal_mode(SqliteJournalMode::Wal) + .synchronous(SqliteSynchronous::Normal) + .busy_timeout(std::time::Duration::from_secs(5)) + .optimize_on_close(true, None) + .pragma("optimize", "0x10002") + .pragma("wal_autocheckpoint", "0"); + + // Create pool first, then run migrations through it. + // This avoids holding a raw SqliteConnection across await points, + // which would make this future !Send. + let (release_tx, release_rx) = tokio::sync::broadcast::channel(1); + + let pool = SqlitePoolOptions::new() + .min_connections(8) + .max_connections(64) + .after_release(move |_, _| { + let release_tx = release_tx.clone(); + Box::pin(async move { + if let Err(e) = release_tx.send(()) { + tracing::error!("release_tx.send failed: {:?}", e); + } + Ok(true) + }) + }) + .connect_with(connection_options.clone()) + .await?; + + let mut migrator = sqlx::migrate::Migrator::new(Path::new("./migrations")).await?; + migrator + .set_ignore_missing(true) + .run(&pool) + .await?; + + let arbor_pixie_account_id: i64 = sqlx::query_scalar( + r"SELECT id FROM account WHERE name = ?", + ) + .bind(ARBOR_PIXIE_ACCOUNT_NAME) + .fetch_one(&pool) + .await?; + + // Checkpoint task: create a dedicated connection inside the spawned task + let checkpoint_options = connection_options; + tokio::spawn(async move { + let mut management_conn = + match SqliteConnection::connect_with(&checkpoint_options).await { + Ok(conn) => conn, + Err(e) => { + tracing::error!("Failed to create checkpoint connection: {e}"); + return; + } + }; + let mut release_rx = release_rx; + let mut released_connections: i64 = 0; + let mut remaining_pages: i64 = 0; + loop { + match release_rx.recv().await { + Ok(()) => { + released_connections += 1; + } + #[allow(clippy::cast_possible_wrap)] + Err(RecvError::Lagged(n)) => { + released_connections += n as i64; + } + Err(RecvError::Closed) => { + break; + } + } + let approx_wal_pages = remaining_pages + released_connections * 8; + if approx_wal_pages < CHECKPOINT_PAGE_LIMIT { + continue; + } + match sqlx::query_as::<_, WalCheckPointRow>("PRAGMA wal_checkpoint(PASSIVE)") + .fetch_one(&mut management_conn) + .await + { + Err(e) => { + tracing::error!("wal_checkpoint failed: {:?}", e); + } + Ok(row) => { + released_connections = 0; + remaining_pages = row.log - row.checkpointed; + tracing::info!( + "wal_checkpoint: busy={} log={} checkpointed={}", + row.busy, + row.log, + row.checkpointed + ); + } + } + } + }); + + Ok(Self { + arbor_pixie_account_id, + pool, + }) + } + #[instrument(err, skip(self))] pub async fn get_account(&self, account_id: i64) -> SqlxResult> { sqlx::query_as!( @@ -160,7 +265,7 @@ impl DB { SELECT id, name, - kinde_id IS NOT NULL AS "is_user: bool", + (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) AS "is_user: bool", universe_id, color FROM account @@ -401,7 +506,7 @@ impl DB { } let balance = initial_balance.to_string(); - let color_for_insert = color.clone(); + let color_for_insert = color.clone().unwrap_or_default(); let result = sqlx::query_scalar!( r#" INSERT INTO account (name, balance, universe_id, color) @@ -437,7 +542,7 @@ impl DB { // Owner must be in universe 0 or in the same universe as the new account let (owner_universe, owner_is_user) = sqlx::query_as::<_, (i64, bool)>( - r#"SELECT universe_id, kinde_id IS NOT NULL as "is_user" FROM account WHERE id = ?"#, + r#"SELECT universe_id, (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) as "is_user" FROM account WHERE id = ?"#, ) .bind(owner_id) .fetch_one(transaction.as_mut()) @@ -595,7 +700,7 @@ impl DB { SELECT EXISTS ( SELECT 1 FROM account - WHERE id = ? AND kinde_id IS NOT NULL + WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) ) AS "exists!: bool" "#, existing_owner_id @@ -630,7 +735,7 @@ impl DB { SELECT EXISTS( SELECT 1 FROM account - WHERE id = ? AND kinde_id IS NOT NULL + WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) ) as "exists!: bool" "#, to_account_id @@ -670,7 +775,7 @@ impl DB { SELECT EXISTS( SELECT 1 FROM account - WHERE id = ? AND kinde_id IS NOT NULL + WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) ) as "exists!: bool" "#, from_account_id @@ -837,6 +942,73 @@ impl DB { })) } + /// Ensure a user exists in this cohort DB by `global_user_id`. + /// Used in multi-cohort mode where the global DB tracks the user identity. + #[instrument(err, skip(self))] + pub async fn ensure_user_created_by_global_id( + &self, + global_user_id: i64, + requested_name: &str, + initial_balance: Decimal, + ) -> SqlxResult> { + let balance = Text(initial_balance); + + // First try to find user by global_user_id + let existing_user = sqlx::query!( + r#" + SELECT id AS "id!", name + FROM account + WHERE global_user_id = ? + "#, + global_user_id + ) + .fetch_optional(&self.pool) + .await?; + + if let Some(user) = existing_user { + return Ok(Ok(EnsureUserCreatedSuccess { + id: user.id, + name: None, + })); + } + + // Check for name conflicts + let conflicting_account = sqlx::query!( + r#" + SELECT id + FROM account + WHERE name = ? AND (global_user_id != ? OR global_user_id IS NULL) + "#, + requested_name, + global_user_id + ) + .fetch_optional(&self.pool) + .await?; + + let final_name = if conflicting_account.is_some() { + format!("{requested_name}-g{global_user_id}") + } else { + requested_name.to_string() + }; + + let id = sqlx::query_scalar!( + r#" + INSERT INTO account (global_user_id, name, balance) + VALUES (?, ?, ?) + RETURNING id + "#, + global_user_id, + final_name, + balance, + ) + .fetch_one(&self.pool) + .await?; + Ok(Ok(EnsureUserCreatedSuccess { + id, + name: Some(final_name), + })) + } + /// # Errors /// Fails is there's a database error pub async fn get_portfolio(&self, account_id: i64) -> SqlxResult> { @@ -889,11 +1061,58 @@ impl DB { pub fn get_all_accounts(&self) -> BoxStream<'_, SqlxResult> { sqlx::query_as!( Account, - r#"SELECT id, name, kinde_id IS NOT NULL as "is_user: bool", universe_id, color FROM account"# + r#"SELECT id, name, (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) as "is_user: bool", universe_id, color FROM account"# ) .fetch(&self.pool) } + /// Get all accounts with `kinde_id` but without `global_user_id` (legacy accounts). + /// Used during migration from single-DB to multi-cohort mode. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_legacy_kinde_users(&self) -> SqlxResult> { + sqlx::query_as::<_, (i64, String, String)>( + r"SELECT id, kinde_id, name FROM account WHERE kinde_id IS NOT NULL AND global_user_id IS NULL", + ) + .fetch_all(&self.pool) + .await + } + + /// Get the balance for an account by `global_user_id`. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_balance_by_global_user_id( + &self, + global_user_id: i64, + ) -> SqlxResult> { + let row = sqlx::query_scalar::<_, String>( + r"SELECT balance FROM account WHERE global_user_id = ?", + ) + .bind(global_user_id) + .fetch_optional(&self.pool) + .await?; + Ok(row.and_then(|b| b.parse::().ok())) + } + + /// Set the `global_user_id` for an account (used during migration). + /// + /// # Errors + /// Returns an error on database failure. + pub async fn set_global_user_id( + &self, + account_id: i64, + global_user_id: i64, + ) -> SqlxResult<()> { + sqlx::query("UPDATE account SET global_user_id = ? WHERE id = ?") + .bind(global_user_id) + .bind(account_id) + .execute(&self.pool) + .await?; + Ok(()) + } + #[instrument(err, skip(self))] pub async fn get_all_markets(&self) -> SqlxResult> { let market_rows: Vec = sqlx::query_as( @@ -1538,7 +1757,7 @@ impl DB { let initiator_is_user = sqlx::query_scalar!( r#"SELECT EXISTS( - SELECT 1 FROM account WHERE id = ? AND kinde_id IS NOT NULL + SELECT 1 FROM account WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) ) as "exists!: bool""#, initiator_id ) @@ -1703,7 +1922,7 @@ impl DB { let is_user = sqlx::query_scalar!( r#"SELECT EXISTS( - SELECT 1 FROM account WHERE id = ? AND kinde_id IS NOT NULL + SELECT 1 FROM account WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) ) as "exists!: bool""#, to_account_id ) @@ -3428,7 +3647,7 @@ impl DB { SELECT EXISTS( SELECT 1 FROM account - WHERE id = ? AND kinde_id IS NOT NULL + WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) ) AS "exists!: bool" "#, user_id @@ -3490,7 +3709,7 @@ impl DB { SELECT EXISTS( SELECT 1 FROM account - WHERE id = ? AND kinde_id IS NOT NULL + WHERE id = ? AND (kinde_id IS NOT NULL OR global_user_id IS NOT NULL) ) AS "exists!: bool" "#, user_id @@ -5589,7 +5808,7 @@ mod tests { .await?; match &status { Ok(_) => println!("Transfer succeeded"), - Err(e) => println!("Transfer failed with error: {:?}", e), + Err(e) => println!("Transfer failed with error: {e:?}"), } assert!( status.is_ok(), @@ -5628,7 +5847,7 @@ mod tests { from_account_id: 3, }) .await?; - println!("{:?}", status); + println!("{status:?}"); assert!(status.is_ok()); Ok(()) diff --git a/backend/src/global_db.rs b/backend/src/global_db.rs new file mode 100644 index 00000000..fab4b731 --- /dev/null +++ b/backend/src/global_db.rs @@ -0,0 +1,543 @@ +use std::{env, path::Path}; + +use serde::Serialize; +use sqlx::{ + sqlite::{SqliteConnectOptions, SqliteJournalMode, SqliteSynchronous}, + Connection, FromRow, SqliteConnection, SqlitePool, +}; + +#[derive(Clone, Debug)] +pub struct GlobalDB { + pool: SqlitePool, +} + +#[derive(Debug, Clone, FromRow, Serialize)] +pub struct GlobalUser { + pub id: i64, + pub kinde_id: String, + pub display_name: String, + pub is_admin: bool, + pub email: Option, +} + +#[derive(Debug, Clone, FromRow, Serialize)] +pub struct CohortInfo { + pub id: i64, + pub name: String, + pub display_name: String, + pub db_path: String, + pub is_read_only: bool, +} + +#[derive(Debug, Clone, Serialize)] +pub struct CohortMember { + pub id: i64, + pub cohort_id: i64, + pub global_user_id: Option, + pub email: Option, + pub display_name: Option, + pub initial_balance: Option, +} + +impl GlobalDB { + /// Initialize the global database, creating it if necessary and running migrations. + /// + /// # Errors + /// Returns an error if database creation or migration fails. + pub async fn init() -> anyhow::Result { + let db_url = env::var("GLOBAL_DATABASE_URL") + .unwrap_or_else(|_| "sqlite:///data/global.sqlite".to_string()); + let db_path = db_url.trim_start_matches("sqlite://"); + Self::init_with_path(db_path).await + } + + /// Initialize a `GlobalDB` from a specific file path. + /// + /// # Errors + /// Returns an error if database initialization fails. + pub async fn init_with_path(db_path: &str) -> anyhow::Result { + let connection_options = SqliteConnectOptions::new() + .filename(db_path) + .create_if_missing(true) + .journal_mode(SqliteJournalMode::Wal) + .synchronous(SqliteSynchronous::Normal) + .busy_timeout(std::time::Duration::from_secs(5)); + + let mut management_conn = SqliteConnection::connect_with(&connection_options).await?; + + // Run global migrations + let mut migrator = sqlx::migrate::Migrator::new(Path::new("./global_migrations")).await?; + migrator + .set_ignore_missing(true) + .run(&mut management_conn) + .await?; + + let pool = SqlitePool::connect_with(connection_options).await?; + + tracing::info!("Global database initialized"); + Ok(Self { pool }) + } + + /// Create or find a global user by `kinde_id`. Updates `display_name` and email if changed. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn ensure_global_user( + &self, + kinde_id: &str, + name: &str, + email: Option<&str>, + ) -> Result { + // Try to find existing user + let existing = sqlx::query_as::<_, GlobalUser>( + r"SELECT id, kinde_id, display_name, is_admin, email FROM global_user WHERE kinde_id = ?", + ) + .bind(kinde_id) + .fetch_optional(&self.pool) + .await?; + + if let Some(mut user) = existing { + // Update display_name and email if changed + if user.display_name != name || user.email.as_deref() != email { + sqlx::query("UPDATE global_user SET display_name = ?, email = COALESCE(?, email) WHERE id = ?") + .bind(name) + .bind(email) + .bind(user.id) + .execute(&self.pool) + .await?; + user.display_name = name.to_string(); + if email.is_some() { + user.email = email.map(String::from); + } + } + return Ok(user); + } + + // Create new user + let id = sqlx::query_scalar::<_, i64>( + r"INSERT INTO global_user (kinde_id, display_name, email) VALUES (?, ?, ?) RETURNING id", + ) + .bind(kinde_id) + .bind(name) + .bind(email) + .fetch_one(&self.pool) + .await?; + + Ok(GlobalUser { + id, + kinde_id: kinde_id.to_string(), + display_name: name.to_string(), + is_admin: false, + email: email.map(String::from), + }) + } + + /// Get a global user by `kinde_id`. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_global_user_by_kinde_id( + &self, + kinde_id: &str, + ) -> Result, sqlx::Error> { + sqlx::query_as::<_, GlobalUser>( + r"SELECT id, kinde_id, display_name, is_admin, email FROM global_user WHERE kinde_id = ?", + ) + .bind(kinde_id) + .fetch_optional(&self.pool) + .await + } + + /// Get all cohorts a user is a member of. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_user_cohorts( + &self, + global_user_id: i64, + ) -> Result, sqlx::Error> { + sqlx::query_as::<_, CohortInfo>( + r" + SELECT c.id, c.name, c.display_name, c.db_path, c.is_read_only + FROM cohort c + INNER JOIN cohort_member cm ON cm.cohort_id = c.id + WHERE cm.global_user_id = ? + ORDER BY c.created_at DESC + ", + ) + .bind(global_user_id) + .fetch_all(&self.pool) + .await + } + + /// Get all cohorts. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_all_cohorts(&self) -> Result, sqlx::Error> { + sqlx::query_as::<_, CohortInfo>( + r"SELECT id, name, display_name, db_path, is_read_only FROM cohort ORDER BY created_at DESC", + ) + .fetch_all(&self.pool) + .await + } + + /// Check if a user is a member of a cohort. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn is_cohort_member( + &self, + global_user_id: i64, + cohort_id: i64, + ) -> Result { + let count = sqlx::query_scalar::<_, i64>( + r"SELECT COUNT(*) FROM cohort_member WHERE global_user_id = ? AND cohort_id = ?", + ) + .bind(global_user_id) + .bind(cohort_id) + .fetch_one(&self.pool) + .await?; + Ok(count > 0) + } + + /// Get a config value by key. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_config(&self, key: &str) -> Result, sqlx::Error> { + sqlx::query_scalar::<_, String>(r"SELECT value FROM global_config WHERE key = ?") + .bind(key) + .fetch_optional(&self.pool) + .await + } + + /// Set a config value. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn set_config(&self, key: &str, value: &str) -> Result<(), sqlx::Error> { + sqlx::query( + r"INSERT INTO global_config (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", + ) + .bind(key) + .bind(value) + .execute(&self.pool) + .await?; + Ok(()) + } + + /// Create a new cohort. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn create_cohort( + &self, + name: &str, + display_name: &str, + db_path: &str, + ) -> Result { + let id = sqlx::query_scalar::<_, i64>( + r"INSERT INTO cohort (name, display_name, db_path) VALUES (?, ?, ?) RETURNING id", + ) + .bind(name) + .bind(display_name) + .bind(db_path) + .fetch_one(&self.pool) + .await?; + + Ok(CohortInfo { + id, + name: name.to_string(), + display_name: display_name.to_string(), + db_path: db_path.to_string(), + is_read_only: false, + }) + } + + /// Delete a cohort by id. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn delete_cohort(&self, id: i64) -> Result<(), sqlx::Error> { + sqlx::query("DELETE FROM cohort WHERE id = ?") + .bind(id) + .execute(&self.pool) + .await?; + Ok(()) + } + + /// Update a cohort's `display_name` and/or read-only status. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn update_cohort( + &self, + id: i64, + display_name: Option<&str>, + is_read_only: Option, + ) -> Result<(), sqlx::Error> { + if let Some(dn) = display_name { + sqlx::query("UPDATE cohort SET display_name = ? WHERE id = ?") + .bind(dn) + .bind(id) + .execute(&self.pool) + .await?; + } + if let Some(ro) = is_read_only { + sqlx::query("UPDATE cohort SET is_read_only = ? WHERE id = ?") + .bind(ro) + .bind(id) + .execute(&self.pool) + .await?; + } + Ok(()) + } + + /// Get a cohort by name. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_cohort_by_name(&self, name: &str) -> Result, sqlx::Error> { + sqlx::query_as::<_, CohortInfo>( + r"SELECT id, name, display_name, db_path, is_read_only FROM cohort WHERE name = ?", + ) + .bind(name) + .fetch_optional(&self.pool) + .await + } + + /// Batch add members by email. Returns count of newly added members. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn batch_add_members( + &self, + cohort_id: i64, + emails: &[String], + initial_balance: Option<&str>, + ) -> Result { + let mut added = 0; + for email in emails { + let email = email.trim().to_lowercase(); + if email.is_empty() { + continue; + } + + // Check if this email matches an existing global user + // (we don't have email in global_user, so just store the email for now) + let result = sqlx::query( + r"INSERT INTO cohort_member (cohort_id, email, initial_balance) VALUES (?, ?, ?) ON CONFLICT DO NOTHING", + ) + .bind(cohort_id) + .bind(&email) + .bind(initial_balance) + .execute(&self.pool) + .await?; + + if result.rows_affected() > 0 { + added += 1; + } + } + Ok(added) + } + + /// Add a member by `global_user_id`. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn add_member_by_user_id( + &self, + cohort_id: i64, + global_user_id: i64, + initial_balance: Option<&str>, + ) -> Result<(), sqlx::Error> { + sqlx::query( + r"INSERT INTO cohort_member (cohort_id, global_user_id, initial_balance) VALUES (?, ?, ?) ON CONFLICT DO NOTHING", + ) + .bind(cohort_id) + .bind(global_user_id) + .bind(initial_balance) + .execute(&self.pool) + .await?; + Ok(()) + } + + /// Remove a member from a cohort. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn remove_member(&self, cohort_id: i64, member_id: i64) -> Result<(), sqlx::Error> { + sqlx::query("DELETE FROM cohort_member WHERE id = ? AND cohort_id = ?") + .bind(member_id) + .bind(cohort_id) + .execute(&self.pool) + .await?; + Ok(()) + } + + /// Get all members of a cohort. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_cohort_members( + &self, + cohort_id: i64, + ) -> Result, sqlx::Error> { + let rows = sqlx::query_as::<_, CohortMemberRow>( + r" + SELECT cm.id, cm.cohort_id, cm.global_user_id, COALESCE(cm.email, gu.email) AS email, gu.display_name, cm.initial_balance + FROM cohort_member cm + LEFT JOIN global_user gu ON gu.id = cm.global_user_id + WHERE cm.cohort_id = ? + ORDER BY cm.created_at + ", + ) + .bind(cohort_id) + .fetch_all(&self.pool) + .await?; + + Ok(rows + .into_iter() + .map(|r| CohortMember { + id: r.id, + cohort_id: r.cohort_id, + global_user_id: r.global_user_id, + email: r.email, + display_name: r.display_name, + initial_balance: r.initial_balance, + }) + .collect()) + } + + /// Set a user's admin status. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn set_user_admin( + &self, + global_user_id: i64, + is_admin: bool, + ) -> Result<(), sqlx::Error> { + sqlx::query("UPDATE global_user SET is_admin = ? WHERE id = ?") + .bind(is_admin) + .bind(global_user_id) + .execute(&self.pool) + .await?; + Ok(()) + } + + /// Link a pre-authorized email to a global user. When a user signs up and their email + /// matches a pre-authorized `cohort_member` row, this links them. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn link_email_to_user( + &self, + email: &str, + global_user_id: i64, + ) -> Result<(), sqlx::Error> { + sqlx::query( + r"UPDATE cohort_member SET global_user_id = ? WHERE email = ? AND global_user_id IS NULL", + ) + .bind(global_user_id) + .bind(email.trim().to_lowercase()) + .execute(&self.pool) + .await?; + Ok(()) + } + + /// Get a member's initial balance for a specific cohort. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_member_initial_balance( + &self, + cohort_id: i64, + global_user_id: i64, + ) -> Result, sqlx::Error> { + sqlx::query_scalar::<_, String>( + r"SELECT initial_balance FROM cohort_member WHERE cohort_id = ? AND global_user_id = ? AND initial_balance IS NOT NULL", + ) + .bind(cohort_id) + .bind(global_user_id) + .fetch_optional(&self.pool) + .await + } + + /// Get all global users (for admin UI). + /// + /// # Errors + /// Returns an error on database failure. + pub async fn get_all_users(&self) -> Result, sqlx::Error> { + sqlx::query_as::<_, GlobalUser>( + r"SELECT id, kinde_id, display_name, is_admin, email FROM global_user ORDER BY created_at", + ) + .fetch_all(&self.pool) + .await + } + + /// Update a user's display name. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn update_user_display_name( + &self, + global_user_id: i64, + display_name: &str, + ) -> Result<(), sqlx::Error> { + sqlx::query("UPDATE global_user SET display_name = ? WHERE id = ?") + .bind(display_name) + .bind(global_user_id) + .execute(&self.pool) + .await?; + Ok(()) + } + + /// Delete a global user and all their cohort memberships. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn delete_user(&self, global_user_id: i64) -> Result<(), sqlx::Error> { + sqlx::query("DELETE FROM cohort_member WHERE global_user_id = ?") + .bind(global_user_id) + .execute(&self.pool) + .await?; + sqlx::query("DELETE FROM global_user WHERE id = ?") + .bind(global_user_id) + .execute(&self.pool) + .await?; + Ok(()) + } + + /// Update a member's initial balance. + /// + /// # Errors + /// Returns an error on database failure. + pub async fn update_member_initial_balance( + &self, + cohort_id: i64, + member_id: i64, + initial_balance: Option<&str>, + ) -> Result { + let result = sqlx::query( + "UPDATE cohort_member SET initial_balance = ? WHERE id = ? AND cohort_id = ?", + ) + .bind(initial_balance) + .bind(member_id) + .bind(cohort_id) + .execute(&self.pool) + .await?; + Ok(result.rows_affected() > 0) + } +} + +#[derive(Debug, FromRow)] +struct CohortMemberRow { + id: i64, + cohort_id: i64, + global_user_id: Option, + email: Option, + display_name: Option, + initial_balance: Option, +} diff --git a/backend/src/handle_socket.rs b/backend/src/handle_socket.rs index f1f14903..f2e12138 100644 --- a/backend/src/handle_socket.rs +++ b/backend/src/handle_socket.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::{ auth::{validate_access_and_id_or_test, Role}, db::{self, EnsureUserCreatedSuccess, DB}, @@ -6,13 +8,12 @@ use crate::{ request_failed::{ErrorDetails, RequestDetails}, server_message::Message as SM, Account, Accounts, ActingAs, Auction, AuctionDeleted, Authenticated, ClientMessage, - GetFullOrderHistory, GetFullTradeHistory, SetSudo, - Market, MarketGroup, MarketGroups, MarketType, MarketTypeDeleted, MarketTypes, Order, - Orders, OwnershipGiven, OwnershipRevoked, Portfolio, Portfolios, RequestFailed, - ServerMessage, SettleAuction, SudoStatus, Trade, Trades, Transfer, Transfers, Universe, - Universes, + GetFullOrderHistory, GetFullTradeHistory, Market, MarketGroup, MarketGroups, MarketType, + MarketTypeDeleted, MarketTypes, Order, Orders, OwnershipGiven, OwnershipRevoked, Portfolio, + Portfolios, RequestFailed, ServerMessage, SetSudo, SettleAuction, SudoStatus, Trade, + Trades, Transfer, Transfers, Universe, Universes, }, - AppState, + AppState, CohortState, }; use anyhow::{anyhow, bail}; use async_stream::stream; @@ -24,8 +25,8 @@ use rust_decimal_macros::dec; use tokio::sync::broadcast::error::RecvError; use tokio_stream::wrappers::errors::BroadcastStreamRecvError; -pub async fn handle_socket(socket: WebSocket, app_state: AppState) { - if let Err(e) = handle_socket_fallible(socket, app_state).await { +pub async fn handle_socket(socket: WebSocket, app_state: AppState, cohort: Arc) { + if let Err(e) = handle_socket_fallible(socket, app_state, cohort).await { tracing::error!("Error handling socket: {e}"); } else { tracing::info!("Client disconnected"); @@ -33,19 +34,25 @@ pub async fn handle_socket(socket: WebSocket, app_state: AppState) { } #[allow(clippy::too_many_lines, unused_assignments)] -async fn handle_socket_fallible(mut socket: WebSocket, app_state: AppState) -> anyhow::Result<()> { +async fn handle_socket_fallible( + mut socket: WebSocket, + app_state: AppState, + cohort: Arc, +) -> anyhow::Result<()> { + let is_read_only = &cohort.is_read_only; let AuthenticatedClient { id: mut user_id, is_admin, act_as, mut owned_accounts, - } = authenticate(&app_state, &mut socket).await?; + auction_only, + } = authenticate(&app_state, &cohort, &mut socket).await?; let admin_id = is_admin.then_some(user_id); let mut acting_as = act_as.unwrap_or(user_id); let mut sudo_enabled = false; - let mut subscription_receivers = app_state.subscriptions.subscribe_all(&owned_accounts); - let db = &app_state.db; + let mut subscription_receivers = cohort.subscriptions.subscribe_all(&owned_accounts); + let db = &cohort.db; let mut current_universe_id = db.get_account_universe_id(acting_as).await?.unwrap_or(0); send_initial_private_data(db, &owned_accounts, &mut socket, false).await?; @@ -59,7 +66,7 @@ async fn handle_socket_fallible(mut socket: WebSocket, app_state: AppState) -> a .collect(); for &account_id in &added_owned_accounts { owned_accounts.push(account_id); - app_state + cohort .subscriptions .add_owned_subscription(&mut subscription_receivers, account_id); } @@ -71,7 +78,7 @@ async fn handle_socket_fallible(mut socket: WebSocket, app_state: AppState) -> a .collect(); owned_accounts.retain(|account_id| !removed_owned_accounts.contains(account_id)); for &account_id in &removed_owned_accounts { - app_state + cohort .subscriptions .remove_owned_subscription(&mut subscription_receivers, account_id); } @@ -100,7 +107,8 @@ async fn handle_socket_fallible(mut socket: WebSocket, app_state: AppState) -> a if is_admin { // Since we're not sending it in update_owned_accounts // Pass false because sudo_enabled starts as false - admins must enable sudo to see hidden data - send_initial_public_data(db, false, &owned_accounts, current_universe_id, &mut socket).await?; + send_initial_public_data(db, false, &owned_accounts, current_universe_id, &mut socket) + .await?; } // Important that this is last - it doubles as letting the client know we're done sending initial data @@ -166,10 +174,13 @@ async fn handle_socket_fallible(mut socket: WebSocket, app_state: AppState) -> a if let Some(result) = handle_client_message( &mut socket, &app_state, + &cohort, effective_admin_id, user_id, acting_as, &owned_accounts, + is_read_only, + auction_only, msg, ) .await? { @@ -209,7 +220,7 @@ async fn handle_socket_fallible(mut socket: WebSocket, app_state: AppState) -> a if act_as.admin_as_user { user_id = act_as.account_id; owned_accounts = db.get_owned_accounts(user_id).await?; - subscription_receivers = app_state.subscriptions.subscribe_all(&owned_accounts); + subscription_receivers = cohort.subscriptions.subscribe_all(&owned_accounts); // TODO: somehow notify the client to get rid of existing portfolios send_initial_private_data(db, &owned_accounts, &mut socket, false).await?; update_owned_accounts!(); @@ -545,22 +556,31 @@ struct ActAsInfo { enum HandleResult { ActAs(ActAsInfo), - SudoChange { request_id: String, enabled: bool }, - AdminRequired { request_id: String, msg_type: &'static str }, + SudoChange { + request_id: String, + enabled: bool, + }, + AdminRequired { + request_id: String, + msg_type: &'static str, + }, } -#[allow(clippy::too_many_lines)] +#[allow(clippy::too_many_lines, clippy::too_many_arguments)] async fn handle_client_message( socket: &mut WebSocket, app_state: &AppState, + cohort: &CohortState, admin_id: Option, user_id: i64, acting_as: i64, owned_accounts: &[i64], + is_read_only: &std::sync::atomic::AtomicBool, + auction_only: bool, msg: ws::Message, ) -> anyhow::Result> { - let db = &app_state.db; - let subscriptions = &app_state.subscriptions; + let db = &cohort.db; + let subscriptions = &cohort.subscriptions; let ws::Message::Binary(msg) = msg else { let resp = request_failed(String::new(), "Unknown", "Expected Binary message"); @@ -618,6 +638,18 @@ async fn handle_client_message( }; }; } + // Check read-only and auction-only restrictions + macro_rules! check_mutation_allowed { + ($msg_type:expr) => { + if is_read_only.load(std::sync::atomic::Ordering::Relaxed) { + fail!($msg_type, "Cohort is read-only"); + } + if auction_only { + fail!($msg_type, "Auction access only"); + } + }; + } + match msg { CM::GetFullTradeHistory(GetFullTradeHistory { market_id }) => { check_expensive_rate_limit!("GetFullTradeHistory"); @@ -648,14 +680,17 @@ async fn handle_client_message( socket.send(msg.encode_to_vec().into()).await?; } CM::CreateMarket(create_market) => { + check_mutation_allowed!("CreateMarket"); check_expensive_rate_limit!("CreateMarket"); // Get the universe_id of the acting_as account - let universe_id = db - .get_account_universe_id(acting_as) - .await? - .unwrap_or(0); + let universe_id = db.get_account_universe_id(acting_as).await?.unwrap_or(0); match db - .create_market(user_id, create_market, admin_id.is_some(), universe_id) + .create_market( + admin_id.unwrap_or(user_id), + create_market, + admin_id.is_some(), + universe_id, + ) .await? { Ok(market) => { @@ -668,6 +703,7 @@ async fn handle_client_message( }; } CM::SettleMarket(settle_market) => { + check_mutation_allowed!("SettleMarket"); check_expensive_rate_limit!("SettleMarket"); match db.settle_market(user_id, admin_id, settle_market).await? { Ok(db::MarketSettledWithAffectedAccounts { @@ -687,6 +723,7 @@ async fn handle_client_message( } } CM::CreateOrder(create_order) => { + check_mutation_allowed!("CreateOrder"); check_mutate_rate_limit!("CreateOrder"); match db.create_order(acting_as, create_order).await? { Ok(order_created) => { @@ -703,6 +740,7 @@ async fn handle_client_message( } } CM::CancelOrder(cancel_order) => { + check_mutation_allowed!("CancelOrder"); check_mutate_rate_limit!("CancelOrder"); match db.cancel_order(acting_as, cancel_order).await? { Ok(order_cancelled) => { @@ -717,6 +755,7 @@ async fn handle_client_message( } } CM::MakeTransfer(make_transfer) => { + check_mutation_allowed!("MakeTransfer"); check_mutate_rate_limit!("MakeTransfer"); let from_account_id = make_transfer.from_account_id; let to_account_id = make_transfer.to_account_id; @@ -737,6 +776,7 @@ async fn handle_client_message( } } CM::Out(out) => { + check_mutation_allowed!("Out"); check_mutate_rate_limit!("Out"); match db.out(acting_as, out.clone()).await? { Ok(orders_cancelled_list) => { @@ -744,7 +784,10 @@ async fn handle_client_message( subscriptions.notify_portfolio(acting_as); } for orders_cancelled in orders_cancelled_list { - let msg = server_message(String::new(), SM::OrdersCancelled(orders_cancelled.into())); + let msg = server_message( + String::new(), + SM::OrdersCancelled(orders_cancelled.into()), + ); subscriptions.send_public(msg); } let resp = encode_server_message(request_id, SM::Out(out)); @@ -756,6 +799,7 @@ async fn handle_client_message( } } CM::CreateAccount(create_account) => { + check_mutation_allowed!("CreateAccount"); check_mutate_rate_limit!("CreateAccount"); let owner_id = create_account.owner_id; let status = db.create_account(user_id, create_account).await?; @@ -771,6 +815,7 @@ async fn handle_client_message( } } CM::ShareOwnership(share_ownership) => { + check_mutation_allowed!("ShareOwnership"); check_mutate_rate_limit!("ShareOwnership"); let to_account_id = share_ownership.to_account_id; match db.share_ownership(user_id, share_ownership).await? { @@ -786,6 +831,7 @@ async fn handle_client_message( } } CM::RevokeOwnership(revoke_ownership) => { + check_mutation_allowed!("RevokeOwnership"); check_mutate_rate_limit!("RevokeOwnership"); let from_account_id = revoke_ownership.from_account_id; if admin_id.is_none() { @@ -809,6 +855,7 @@ async fn handle_client_message( } } CM::Redeem(redeem) => { + check_mutation_allowed!("Redeem"); check_mutate_rate_limit!("Redeem"); match db.redeem(acting_as, redeem).await? { Ok(redeemed) => { @@ -858,6 +905,7 @@ async fn handle_client_message( })); } CM::CreateUniverse(create_universe) => { + check_mutation_allowed!("CreateUniverse"); check_expensive_rate_limit!("CreateUniverse"); match db .create_universe(user_id, create_universe.name, create_universe.description) @@ -873,6 +921,7 @@ async fn handle_client_message( } } CM::EditMarket(edit_market) => { + check_mutation_allowed!("EditMarket"); // Check if user is admin or owner of the market let Some((owner_id, status)) = db.get_market_owner_and_status(edit_market.id).await? else { @@ -897,7 +946,10 @@ async fn handle_client_message( } if edit_market.description.is_some() && !is_admin && !is_owner { - fail!("EditMarket", "You can only edit your own market's description"); + fail!( + "EditMarket", + "You can only edit your own market's description" + ); } // Note: admin_id.is_some() already implies sudo is enabled @@ -913,6 +965,9 @@ async fn handle_client_message( }; } CM::CreateAuction(create_auction) => { + if is_read_only.load(std::sync::atomic::Ordering::Relaxed) { + fail!("CreateAuction", "Cohort is read-only"); + } check_expensive_rate_limit!("CreateMarket"); match db .create_auction(user_id, create_auction) @@ -928,38 +983,42 @@ async fn handle_client_message( }; } CM::SettleAuction(settle_auction) => { + if is_read_only.load(std::sync::atomic::Ordering::Relaxed) { + fail!("SettleAuction", "Cohort is read-only"); + } check_expensive_rate_limit!("SettleAuction"); match admin_id { None => { fail!("SettleAuction", "only admins can settle auctions"); } - Some(admin_id) => { - match db.settle_auction(admin_id, settle_auction).await? { - Ok(db::AuctionSettledWithAffectedAccounts { - auction_settled, - affected_accounts, - transfer, - }) => { - let msg = server_message( - request_id, - SM::AuctionSettled(auction_settled.into()), - ); - subscriptions.send_public(msg); - let transfer_msg = - encode_server_message(String::new(), SM::TransferCreated(transfer.into())); - for &account in &affected_accounts { - subscriptions.send_private(account, transfer_msg.clone()); - subscriptions.notify_portfolio(account); - } - } - Err(failure) => { - fail!("SettleAuction", failure.message()); + Some(admin_id) => match db.settle_auction(admin_id, settle_auction).await? { + Ok(db::AuctionSettledWithAffectedAccounts { + auction_settled, + affected_accounts, + transfer, + }) => { + let msg = + server_message(request_id, SM::AuctionSettled(auction_settled.into())); + subscriptions.send_public(msg); + let transfer_msg = encode_server_message( + String::new(), + SM::TransferCreated(transfer.into()), + ); + for &account in &affected_accounts { + subscriptions.send_private(account, transfer_msg.clone()); + subscriptions.notify_portfolio(account); } } - } + Err(failure) => { + fail!("SettleAuction", failure.message()); + } + }, } } CM::BuyAuction(buy_auction) => { + if is_read_only.load(std::sync::atomic::Ordering::Relaxed) { + fail!("BuyAuction", "Cohort is read-only"); + } check_expensive_rate_limit!("SettleAuction"); match db .settle_auction( @@ -993,11 +1052,11 @@ async fn handle_client_message( }; } CM::DeleteAuction(delete_auction) => { + if is_read_only.load(std::sync::atomic::Ordering::Relaxed) { + fail!("DeleteAuction", "Cohort is read-only"); + } check_expensive_rate_limit!("DeleteAuction"); - match db - .delete_auction(user_id, delete_auction, admin_id) - .await? - { + match db.delete_auction(user_id, delete_auction, admin_id).await? { Ok(auction_id) => { let msg = server_message( request_id, @@ -1011,11 +1070,11 @@ async fn handle_client_message( } } CM::EditAuction(edit_auction) => { + if is_read_only.load(std::sync::atomic::Ordering::Relaxed) { + fail!("EditAuction", "Cohort is read-only"); + } check_expensive_rate_limit!("EditAuction"); - match db - .edit_auction(user_id, edit_auction, admin_id) - .await? - { + match db.edit_auction(user_id, edit_auction, admin_id).await? { Ok(auction) => { let msg = server_message(request_id, SM::Auction(auction.into())); subscriptions.send_public(msg); @@ -1026,6 +1085,9 @@ async fn handle_client_message( } } CM::CreateMarketType(create_market_type) => { + if is_read_only.load(std::sync::atomic::Ordering::Relaxed) { + fail!("CreateMarketType", "Cohort is read-only"); + } if admin_id.is_none() { return Ok(Some(HandleResult::AdminRequired { request_id, @@ -1051,6 +1113,9 @@ async fn handle_client_message( }; } CM::DeleteMarketType(delete_market_type) => { + if is_read_only.load(std::sync::atomic::Ordering::Relaxed) { + fail!("DeleteMarketType", "Cohort is read-only"); + } if admin_id.is_none() { return Ok(Some(HandleResult::AdminRequired { request_id, @@ -1075,6 +1140,9 @@ async fn handle_client_message( }; } CM::CreateMarketGroup(create_market_group) => { + if is_read_only.load(std::sync::atomic::Ordering::Relaxed) { + fail!("CreateMarketGroup", "Cohort is read-only"); + } if admin_id.is_none() { return Ok(Some(HandleResult::AdminRequired { request_id, @@ -1128,14 +1196,17 @@ struct AuthenticatedClient { is_admin: bool, act_as: Option, owned_accounts: Vec, + auction_only: bool, } #[allow(clippy::too_many_lines)] async fn authenticate( app_state: &AppState, + cohort: &CohortState, socket: &mut WebSocket, ) -> anyhow::Result { - let db = &app_state.db; + let db = &cohort.db; + let global_db = &app_state.global_db; loop { match socket.recv().await { Some(Ok(ws::Message::Binary(msg))) => { @@ -1151,25 +1222,123 @@ async fn authenticate( }; let id_jwt = (!authenticate.id_jwt.is_empty()).then_some(authenticate.id_jwt); let act_as = (authenticate.act_as != 0).then_some(authenticate.act_as); - let valid_client = - match validate_access_and_id_or_test(&authenticate.jwt, id_jwt.as_deref()).await { - Ok(valid_client) => valid_client, - Err(e) => { - tracing::error!("JWT validation failed: {e}"); - let resp = - request_failed(request_id, "Authenticate", "JWT validation failed"); - socket.send(resp).await?; - continue; - } - }; - let is_admin = valid_client.roles.contains(&Role::Admin); - let initial_balance = if is_admin { dec!(100_000_000) } else { dec!(0) }; - let result = db - .ensure_user_created( + let valid_client = match validate_access_and_id_or_test( + &authenticate.jwt, + id_jwt.as_deref(), + ) + .await + { + Ok(valid_client) => valid_client, + Err(e) => { + tracing::error!("JWT validation failed: {e}"); + let resp = + request_failed(request_id, "Authenticate", "JWT validation failed"); + socket.send(resp).await?; + continue; + } + }; + + // Get or create global user + let display_name = valid_client.name.as_deref().unwrap_or("Unknown"); + let global_user = match global_db + .ensure_global_user( &valid_client.id, - valid_client.name.as_deref(), - initial_balance, + display_name, + valid_client.email.as_deref(), ) + .await + { + Ok(user) => user, + Err(e) => { + tracing::error!("Failed to ensure global user: {e}"); + let resp = request_failed( + request_id, + "Authenticate", + "Failed to create global user", + ); + socket.send(resp).await?; + continue; + } + }; + + // Link email-based pre-authorizations if we have an email + if let Some(email) = &valid_client.email { + if let Err(e) = global_db.link_email_to_user(email, global_user.id).await { + tracing::warn!("Failed to link email to user: {e}"); + } + } + + // Check admin status (Kinde role OR global DB flag) + let is_admin = valid_client.roles.contains(&Role::Admin) || global_user.is_admin; + + // Check cohort access + #[allow(unused_mut)] + let mut is_member = global_db + .is_cohort_member(global_user.id, cohort.info.id) + .await + .unwrap_or(false); + + // In dev-mode, auto-add users as cohort members + #[cfg(feature = "dev-mode")] + if !is_member { + if let Err(e) = global_db + .add_member_by_user_id(cohort.info.id, global_user.id, None) + .await + { + tracing::warn!("Failed to auto-add user as cohort member: {e}"); + } else { + is_member = true; + } + } + + let mut auction_only = false; + + if !is_admin && !is_member { + // Check if this is the active auction cohort with public auction enabled + let public_auction_enabled = global_db + .get_config("public_auction_enabled") + .await + .unwrap_or(None) + .is_some_and(|v| v == "true"); + let active_auction_cohort_id = global_db + .get_config("active_auction_cohort_id") + .await + .unwrap_or(None) + .and_then(|v| v.parse::().ok()); + + if public_auction_enabled && active_auction_cohort_id == Some(cohort.info.id) { + auction_only = true; + } else { + let resp = request_failed( + request_id, + "Authenticate", + "You are not authorized for this cohort", + ); + socket.send(resp).await?; + continue; + } + } + + let initial_balance = match global_db + .get_member_initial_balance(cohort.info.id, global_user.id) + .await + .ok() + .flatten() + .and_then(|s| rust_decimal::Decimal::from_str_exact(&s).ok()) + { + Some(bal) => bal, + None => { + if is_admin { + dec!(100_000_000) + } else { + dec!(0) + } + } + }; + + // Create/find user in cohort DB using global_user_id + let result = db + .ensure_user_created_by_global_id(global_user.id, display_name, initial_balance) .await?; let id = match result { @@ -1187,7 +1356,7 @@ async fn authenticate( color: None, }), ); - app_state.subscriptions.send_public(msg); + cohort.subscriptions.send_public(msg); id } Ok(EnsureUserCreatedSuccess { id, name: None }) => id, @@ -1213,7 +1382,10 @@ async fn authenticate( } let resp = encode_server_message( request_id, - SM::Authenticated(Authenticated { account_id: id }), + SM::Authenticated(Authenticated { + account_id: id, + auction_only, + }), ); socket.send(resp).await?; return Ok(AuthenticatedClient { @@ -1221,6 +1393,7 @@ async fn authenticate( is_admin, act_as, owned_accounts, + auction_only, }); } Some(Ok(ws::Message::Ping(payload))) => { diff --git a/backend/src/lib.rs b/backend/src/lib.rs index 279276f6..4e36e03a 100644 --- a/backend/src/lib.rs +++ b/backend/src/lib.rs @@ -1,6 +1,11 @@ -use std::{path::PathBuf, sync::Arc}; +use std::{ + path::PathBuf, + sync::{atomic::AtomicBool, Arc}, +}; +use dashmap::DashMap; use db::DB; +use global_db::{CohortInfo, GlobalDB}; use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter}; use nonzero_ext::nonzero; use subscriptions::Subscriptions; @@ -10,10 +15,18 @@ pub mod websocket_api { include!(concat!(env!("OUT_DIR"), "/websocket_api.rs")); } -#[derive(Clone)] -pub struct AppState { +/// State for a single cohort (its own DB + pub/sub system). +pub struct CohortState { pub db: DB, pub subscriptions: Subscriptions, + pub info: CohortInfo, + pub is_read_only: AtomicBool, +} + +#[derive(Clone)] +pub struct AppState { + pub global_db: GlobalDB, + pub cohorts: Arc>>, pub expensive_ratelimit: Arc>, pub admin_expensive_ratelimit: Arc>, pub mutate_ratelimit: Arc>, @@ -23,24 +36,88 @@ pub struct AppState { const ADMIN_RATE_LIMIT_MULTIPLIER: u32 = 10; const LARGE_REQUEST_QUOTA: Quota = Quota::per_minute(nonzero!(180u32)); -const ADMIN_LARGE_REQUEST_QUOTA: Quota = Quota::per_minute(nonzero!(180u32 * ADMIN_RATE_LIMIT_MULTIPLIER)); +const ADMIN_LARGE_REQUEST_QUOTA: Quota = + Quota::per_minute(nonzero!(180u32 * ADMIN_RATE_LIMIT_MULTIPLIER)); const MUTATE_QUOTA: Quota = Quota::per_second(nonzero!(100u32)).allow_burst(nonzero!(1000u32)); -const ADMIN_MUTATE_QUOTA: Quota = Quota::per_second(nonzero!(100u32 * ADMIN_RATE_LIMIT_MULTIPLIER)).allow_burst(nonzero!(1000u32 * ADMIN_RATE_LIMIT_MULTIPLIER)); +const ADMIN_MUTATE_QUOTA: Quota = Quota::per_second(nonzero!(100u32 * ADMIN_RATE_LIMIT_MULTIPLIER)) + .allow_burst(nonzero!(1000u32 * ADMIN_RATE_LIMIT_MULTIPLIER)); impl AppState { /// # Errors /// Returns an error if initializing the database failed. pub async fn new() -> anyhow::Result { - let db = DB::init().await?; - let subscriptions = Subscriptions::new(); + let global_db = GlobalDB::init().await?; + + let cohorts = Arc::new(DashMap::new()); + + // Load all cohorts from global DB and initialize their databases + let cohort_list = global_db.get_all_cohorts().await?; + for cohort_info in cohort_list { + match DB::init_with_path(&cohort_info.db_path, true).await { + Ok(db) => { + let cohort_state = Arc::new(CohortState { + db, + subscriptions: Subscriptions::new(), + is_read_only: AtomicBool::new(cohort_info.is_read_only), + info: cohort_info.clone(), + }); + cohorts.insert(cohort_info.name.clone(), cohort_state); + tracing::info!("Loaded cohort: {}", cohort_info.name); + } + Err(e) => { + tracing::error!( + "Failed to load cohort '{}' at '{}': {e}", + cohort_info.name, + cohort_info.db_path + ); + } + } + } + + // If no cohorts exist, check for legacy DATABASE_URL and auto-migrate + if cohorts.is_empty() { + if let Ok(legacy_db_url) = std::env::var("DATABASE_URL") { + tracing::info!("No cohorts found, migrating legacy database"); + let db_path = legacy_db_url.trim_start_matches("sqlite://").to_string(); + let db_path_str = if db_path.starts_with("//") { + db_path.trim_start_matches('/').to_string() + } else { + db_path + }; + + match DB::init_with_path(&db_path_str, true).await { + Ok(db) => { + let cohort_info = global_db + .create_cohort("main", "Main", &db_path_str) + .await?; + + // Migrate existing kinde_id users to global users + Self::migrate_legacy_users(&global_db, &db, &cohort_info).await?; + + let cohort_state = Arc::new(CohortState { + db, + subscriptions: Subscriptions::new(), + is_read_only: AtomicBool::new(cohort_info.is_read_only), + info: cohort_info, + }); + cohorts.insert("main".to_string(), cohort_state); + tracing::info!("Legacy database migrated as 'main' cohort"); + } + Err(e) => { + tracing::error!("Failed to load legacy database: {e}"); + } + } + } + } + let expensive_ratelimit = Arc::new(RateLimiter::keyed(LARGE_REQUEST_QUOTA)); let admin_expensive_ratelimit = Arc::new(RateLimiter::keyed(ADMIN_LARGE_REQUEST_QUOTA)); let mutate_ratelimit = Arc::new(RateLimiter::keyed(MUTATE_QUOTA)); let admin_mutate_ratelimit = Arc::new(RateLimiter::keyed(ADMIN_MUTATE_QUOTA)); let uploads_dir = PathBuf::from("/data/uploads"); // Default value, overridden in main.rs Ok(Self { - db, - subscriptions, + global_db, + cohorts, expensive_ratelimit, admin_expensive_ratelimit, mutate_ratelimit, @@ -48,12 +125,74 @@ impl AppState { uploads_dir, }) } + + /// Migrate legacy users from a cohort DB to the global DB. + async fn migrate_legacy_users( + global_db: &GlobalDB, + cohort_db: &DB, + cohort_info: &CohortInfo, + ) -> anyhow::Result<()> { + let legacy_users = cohort_db.get_legacy_kinde_users().await?; + for (account_id, kinde_id, name) in legacy_users { + let global_user = global_db.ensure_global_user(&kinde_id, &name, None).await?; + cohort_db + .set_global_user_id(account_id, global_user.id) + .await?; + // Also add as cohort member + global_db + .add_member_by_user_id(cohort_info.id, global_user.id, None) + .await?; + } + Ok(()) + } + + /// Add a new cohort at runtime (e.g., from admin API). + /// If the DB is a legacy DB (has `kinde_id` but no `global_user_id`), migrates users. + /// + /// # Errors + /// Returns an error if database initialization fails. + pub async fn add_cohort( + &self, + cohort_info: CohortInfo, + create_if_missing: bool, + ) -> anyhow::Result<()> { + let db = DB::init_with_path(&cohort_info.db_path, create_if_missing).await?; + + // Check if this is a legacy DB that needs migration + let legacy_users = db.get_legacy_kinde_users().await?; + if !legacy_users.is_empty() { + tracing::info!( + "Migrating {} legacy users in cohort '{}'", + legacy_users.len(), + cohort_info.name + ); + for (account_id, kinde_id, name) in legacy_users { + let global_user = self + .global_db + .ensure_global_user(&kinde_id, &name, None) + .await?; + db.set_global_user_id(account_id, global_user.id).await?; + self.global_db + .add_member_by_user_id(cohort_info.id, global_user.id, None) + .await?; + } + } + + let cohort_state = Arc::new(CohortState { + db, + subscriptions: Subscriptions::new(), + is_read_only: AtomicBool::new(cohort_info.is_read_only), + info: cohort_info.clone(), + }); + self.cohorts.insert(cohort_info.name.clone(), cohort_state); + Ok(()) + } } -pub mod airtable_users; pub mod auth; pub mod convert; pub mod db; +pub mod global_db; pub mod handle_socket; pub mod subscriptions; diff --git a/backend/src/main.rs b/backend/src/main.rs index 9ca99618..3c43079e 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,17 +1,18 @@ +use std::sync::Arc; + use axum::{ self, extract::{Multipart, Path as AxumPath, State, WebSocketUpgrade}, - http::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN}, + http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, - routing::{get, post}, - Router, + routing::{delete, get, post, put}, + Json, Router, }; -use backend::{airtable_users, AppState}; +use backend::{auth::AccessClaims, global_db::CohortInfo, AppState}; +use serde::{Deserialize, Serialize}; use std::{env, path::Path, str::FromStr}; use tokio::{fs::create_dir_all, net::TcpListener}; -use tower_http::{ - limit::RequestBodyLimitLayer, set_header::response::SetResponseHeaderLayer, trace::TraceLayer, -}; +use tower_http::{cors::CorsLayer, limit::RequestBodyLimitLayer, trace::TraceLayer}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use uuid::Uuid; @@ -38,17 +39,48 @@ async fn main() -> anyhow::Result<()> { } let app = Router::new() - .route("/api", get(api)) - .route("/sync-airtable-users", get(sync_airtable_users)) + // Per-cohort WebSocket route + .route("/api/ws/:cohort_name", get(cohort_ws)) + // REST endpoints + .route("/api/cohorts", get(list_cohorts)) + // Admin REST endpoints + .route( + "/api/admin/cohorts", + get(admin_list_cohorts).post(create_cohort), + ) + .route("/api/admin/cohorts/:name", put(update_cohort)) + .route( + "/api/admin/cohorts/:name/members", + get(list_members).post(batch_add_members), + ) + .route( + "/api/admin/cohorts/:name/members/:id", + put(update_member).delete(remove_member), + ) + .route("/api/admin/config", get(get_config).put(update_config)) + .route("/api/admin/users", get(list_users)) + .route("/api/admin/users/details", get(list_users_detailed)) + .route("/api/admin/users/:id/admin", put(toggle_admin)) + .route( + "/api/admin/users/:id/display-name", + put(admin_update_display_name), + ) + .route("/api/admin/users/:id", delete(delete_user_endpoint)) + .route("/api/admin/available-dbs", get(list_available_dbs)) + // Authenticated user endpoints + .route("/api/users/me/display-name", put(update_my_display_name)) + // Utility routes .route("/api/upload-image", post(upload_image)) .route("/api/images/:filename", get(serve_image)) .layer(TraceLayer::new_for_http()) - // Limit file uploads to 10MB .layer(RequestBodyLimitLayer::new(50 * 1024 * 1024)) - .layer(SetResponseHeaderLayer::if_not_present( - ACCESS_CONTROL_ALLOW_ORIGIN, - HeaderValue::from_static("*"), - )) + .layer( + CorsLayer::new() + .allow_origin(tower_http::cors::Any) + .allow_methods(tower_http::cors::Any) + .allow_headers(tower_http::cors::Any) + .allow_private_network(true), + ) .with_state(AppState { uploads_dir: uploads_dir.to_path_buf(), ..state @@ -81,71 +113,772 @@ async fn main() -> anyhow::Result<()> { Ok(axum::serve(listener, app).await?) } +// --- WebSocket Handler --- + +#[axum::debug_handler] +async fn cohort_ws( + ws: WebSocketUpgrade, + AxumPath(cohort_name): AxumPath, + State(state): State, +) -> Response { + let Some(cohort) = state.cohorts.get(&cohort_name).map(|c| Arc::clone(&c)) else { + return (StatusCode::NOT_FOUND, "Cohort not found").into_response(); + }; + ws.on_upgrade(move |socket| backend::handle_socket::handle_socket(socket, state, cohort)) +} + +// --- REST Endpoints --- + +#[derive(Serialize)] +struct CohortsResponse { + cohorts: Vec, + active_auction_cohort: Option, + default_cohort: Option, + public_auction_enabled: bool, +} + +#[axum::debug_handler] +async fn list_cohorts( + claims: AccessClaims, + headers: HeaderMap, + State(state): State, +) -> Result, (StatusCode, String)> { + // Prefer email from validated ID token when provided; fall back to access-token claims. + let id_token_email = headers + .get("x-id-token") + .and_then(|v| v.to_str().ok()) + .filter(|v| !v.is_empty()) + .map(str::to_owned); + + let resolved_email = if let Some(id_token) = id_token_email { + match backend::auth::validate_id_token_email_for_sub(&id_token, &claims.sub).await { + Ok(email) => email, + Err(e) => { + tracing::warn!("Invalid x-id-token for /api/cohorts: {e}"); + claims.email.clone() + } + } + } else { + claims.email.clone() + }; + + // Ensure global user exists (creates if needed, same as WS auth flow) + let display_name = claims.sub.clone(); // Fallback; WS auth will update with real name + let global_user = state + .global_db + .ensure_global_user(&claims.sub, &display_name, resolved_email.as_deref()) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + // Link email-based pre-authorizations if we have an email + if let Some(email) = &resolved_email { + if let Err(e) = state + .global_db + .link_email_to_user(email, global_user.id) + .await + { + tracing::warn!("Failed to link email to user in list_cohorts: {e}"); + } + } + + let is_admin_by_role = claims.roles.contains(&backend::auth::Role::Admin); + + let is_admin = global_user.is_admin || is_admin_by_role; + + let cohorts = if is_admin { + state + .global_db + .get_all_cohorts() + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + } else { + state + .global_db + .get_user_cohorts(global_user.id) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + }; + + let public_auction_enabled = state + .global_db + .get_config("public_auction_enabled") + .await + .unwrap_or(None) + .is_some_and(|v| v == "true"); + + let active_auction_cohort = get_active_auction_cohort_name(&state).await; + let default_cohort = get_cohort_name_by_config_key(&state, "default_cohort_id").await; + + Ok(Json(CohortsResponse { + cohorts, + active_auction_cohort, + default_cohort, + public_auction_enabled, + })) +} + +async fn get_cohort_name_by_config_key(state: &AppState, config_key: &str) -> Option { + let cohort_id = state + .global_db + .get_config(config_key) + .await + .ok() + .flatten() + .and_then(|v| v.parse::().ok())?; + + let all_cohorts = state.global_db.get_all_cohorts().await.ok()?; + all_cohorts + .into_iter() + .find(|c| c.id == cohort_id) + .map(|c| c.name) +} + +async fn get_active_auction_cohort_name(state: &AppState) -> Option { + get_cohort_name_by_config_key(state, "active_auction_cohort_id").await +} + +// --- Admin Endpoints --- + +async fn check_admin(state: &AppState, claims: &AccessClaims) -> Result<(), (StatusCode, String)> { + if claims.roles.contains(&backend::auth::Role::Admin) { + return Ok(()); + } + + let is_global_admin = state + .global_db + .get_global_user_by_kinde_id(&claims.sub) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .is_some_and(|user| user.is_admin); + + if is_global_admin { + Ok(()) + } else { + Err((StatusCode::FORBIDDEN, "Admin access required".to_string())) + } +} + +#[axum::debug_handler] +async fn admin_list_cohorts( + claims: AccessClaims, + State(state): State, +) -> Result>, (StatusCode, String)> { + check_admin(&state, &claims).await?; + state + .global_db + .get_all_cohorts() + .await + .map(Json) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string())) +} + +#[derive(Deserialize)] +struct CreateCohortRequest { + name: String, + display_name: String, + #[serde(default)] + existing_db: bool, +} + +#[axum::debug_handler] +async fn create_cohort( + claims: AccessClaims, + State(state): State, + Json(body): Json, +) -> Result, (StatusCode, String)> { + check_admin(&state, &claims).await?; + + // Determine data directory from DATABASE_URL or default + let data_dir = std::env::var("DATABASE_URL") + .ok() + .and_then(|url| { + let path = url.trim_start_matches("sqlite://"); + Path::new(path) + .parent() + .map(|p| p.to_string_lossy().into_owned()) + }) + .unwrap_or_else(|| "/data".to_string()); + + let db_path = format!("{}/{}.sqlite", data_dir, body.name); + + // If using existing DB, verify the file exists first + if body.existing_db && !Path::new(&db_path).exists() { + return Err(( + StatusCode::BAD_REQUEST, + format!("Database file not found: {db_path}"), + )); + } + if !body.existing_db && Path::new(&db_path).exists() { + return Err(( + StatusCode::BAD_REQUEST, + format!( + "Database file already exists: {db_path}. Check 'Use existing database' to adopt it." + ), + )); + } + + let cohort_info = state + .global_db + .create_cohort(&body.name, &body.display_name, &db_path) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + // Initialize and add the cohort at runtime; clean up global row on failure + if let Err(e) = state + .add_cohort(cohort_info.clone(), !body.existing_db) + .await + { + // Roll back the global DB row so retries don't hit UNIQUE constraint + if let Err(del_err) = state.global_db.delete_cohort(cohort_info.id).await { + tracing::error!("Failed to clean up cohort row after init failure: {del_err}"); + } + return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())); + } + + Ok(Json(cohort_info)) +} + +#[derive(Deserialize)] +struct UpdateCohortRequest { + display_name: Option, + is_read_only: Option, +} + +#[axum::debug_handler] +async fn update_cohort( + claims: AccessClaims, + AxumPath(name): AxumPath, + State(state): State, + Json(body): Json, +) -> Result { + check_admin(&state, &claims).await?; + + let cohort = state + .global_db + .get_cohort_by_name(&name) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::NOT_FOUND, "Cohort not found".to_string()))?; + + state + .global_db + .update_cohort(cohort.id, body.display_name.as_deref(), body.is_read_only) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + // Update in-memory read-only flag (takes effect immediately for all connections) + if let Some(is_read_only) = body.is_read_only { + if let Some(cohort_state) = state.cohorts.get(&name) { + cohort_state + .is_read_only + .store(is_read_only, std::sync::atomic::Ordering::Relaxed); + } + } + + Ok(StatusCode::OK) +} + +#[derive(Serialize)] +struct MemberWithBalance { + #[serde(flatten)] + member: backend::global_db::CohortMember, + balance: Option, +} + +#[axum::debug_handler] +async fn list_members( + claims: AccessClaims, + AxumPath(name): AxumPath, + State(state): State, +) -> Result>, (StatusCode, String)> { + check_admin(&state, &claims).await?; + + let cohort = state + .global_db + .get_cohort_by_name(&name) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::NOT_FOUND, "Cohort not found".to_string()))?; + + let members = state + .global_db + .get_cohort_members(cohort.id) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + let cohort_state = state.cohorts.get(&name); + + let mut result = Vec::with_capacity(members.len()); + for member in members { + let balance = + if let (Some(global_user_id), Some(cs)) = (member.global_user_id, &cohort_state) { + cs.db + .get_balance_by_global_user_id(global_user_id) + .await + .ok() + .flatten() + } else { + None + }; + result.push(MemberWithBalance { member, balance }); + } + + Ok(Json(result)) +} + +#[derive(Deserialize)] +struct BatchAddMembersRequest { + #[serde(default)] + emails: Vec, + #[serde(default)] + user_ids: Vec, + initial_balance: Option, +} + +#[derive(Serialize)] +struct BatchAddMembersResponse { + added: usize, +} + +#[axum::debug_handler] +async fn batch_add_members( + claims: AccessClaims, + AxumPath(name): AxumPath, + State(state): State, + Json(body): Json, +) -> Result, (StatusCode, String)> { + check_admin(&state, &claims).await?; + + let cohort = state + .global_db + .get_cohort_by_name(&name) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::NOT_FOUND, "Cohort not found".to_string()))?; + + let mut added = 0; + + if !body.emails.is_empty() { + added += state + .global_db + .batch_add_members(cohort.id, &body.emails, body.initial_balance.as_deref()) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + } + + for user_id in &body.user_ids { + state + .global_db + .add_member_by_user_id(cohort.id, *user_id, body.initial_balance.as_deref()) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + added += 1; + } + + Ok(Json(BatchAddMembersResponse { added })) +} + +#[axum::debug_handler] +async fn remove_member( + claims: AccessClaims, + AxumPath((name, member_id)): AxumPath<(String, i64)>, + State(state): State, +) -> Result { + check_admin(&state, &claims).await?; + + let cohort = state + .global_db + .get_cohort_by_name(&name) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::NOT_FOUND, "Cohort not found".to_string()))?; + + state + .global_db + .remove_member(cohort.id, member_id) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + Ok(StatusCode::OK) +} + +#[derive(Serialize)] +struct GlobalConfig { + active_auction_cohort_id: Option, + default_cohort_id: Option, + public_auction_enabled: bool, +} + +#[axum::debug_handler] +async fn get_config( + claims: AccessClaims, + State(state): State, +) -> Result, (StatusCode, String)> { + check_admin(&state, &claims).await?; + + let active_auction_cohort_id = state + .global_db + .get_config("active_auction_cohort_id") + .await + .unwrap_or(None) + .and_then(|v| v.parse().ok()); + + let default_cohort_id = state + .global_db + .get_config("default_cohort_id") + .await + .unwrap_or(None) + .and_then(|v| v.parse().ok()); + + let public_auction_enabled = state + .global_db + .get_config("public_auction_enabled") + .await + .unwrap_or(None) + .is_some_and(|v| v == "true"); + + Ok(Json(GlobalConfig { + active_auction_cohort_id, + default_cohort_id, + public_auction_enabled, + })) +} + +#[derive(Deserialize)] +#[allow(clippy::option_option)] // Intentional: distinguishes "not provided" from "set to null" +struct UpdateConfigRequest { + active_auction_cohort_id: Option>, + default_cohort_id: Option>, + public_auction_enabled: Option, +} + +#[axum::debug_handler] +async fn update_config( + claims: AccessClaims, + State(state): State, + Json(body): Json, +) -> Result { + check_admin(&state, &claims).await?; + + if let Some(maybe_id) = body.active_auction_cohort_id { + let value = match maybe_id { + Some(id) => id.to_string(), + None => String::new(), + }; + state + .global_db + .set_config("active_auction_cohort_id", &value) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + } + if let Some(maybe_id) = body.default_cohort_id { + let value = match maybe_id { + Some(id) => id.to_string(), + None => String::new(), + }; + state + .global_db + .set_config("default_cohort_id", &value) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + } + if let Some(enabled) = body.public_auction_enabled { + state + .global_db + .set_config("public_auction_enabled", &enabled.to_string()) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + } + + Ok(StatusCode::OK) +} + #[axum::debug_handler] -async fn api(ws: WebSocketUpgrade, State(state): State) -> Response { - ws.on_upgrade(move |socket| backend::handle_socket::handle_socket(socket, state)) +async fn list_users( + claims: AccessClaims, + State(state): State, +) -> Result>, (StatusCode, String)> { + check_admin(&state, &claims).await?; + + state + .global_db + .get_all_users() + .await + .map(Json) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string())) +} + +#[derive(Serialize)] +struct UserCohortDetail { + cohort_name: String, + cohort_display_name: String, + balance: Option, +} + +#[derive(Serialize)] +struct UserWithCohorts { + #[serde(flatten)] + user: backend::global_db::GlobalUser, + cohorts: Vec, } #[axum::debug_handler] -async fn sync_airtable_users(State(state): State) -> Response { - match airtable_users::sync_airtable_users_to_kinde_and_db(state).await { - Ok(()) => { - tracing::info!("Successfully synchronized Airtable users"); - (axum::http::StatusCode::OK, "OK").into_response() +async fn list_users_detailed( + claims: AccessClaims, + State(state): State, +) -> Result>, (StatusCode, String)> { + check_admin(&state, &claims).await?; + + let users = state + .global_db + .get_all_users() + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + let mut result = Vec::with_capacity(users.len()); + for user in users { + let user_cohort_infos = state + .global_db + .get_user_cohorts(user.id) + .await + .unwrap_or_default(); + + let mut cohorts = Vec::new(); + for ci in &user_cohort_infos { + let balance = if let Some(cs) = state.cohorts.get(&ci.name) { + cs.db + .get_balance_by_global_user_id(user.id) + .await + .ok() + .flatten() + } else { + None + }; + cohorts.push(UserCohortDetail { + cohort_name: ci.name.clone(), + cohort_display_name: ci.display_name.clone(), + balance, + }); } - Err(e) => { - tracing::error!("Failed to synchronize Airtable users: {e}"); - if let Err(e) = airtable_users::log_error_to_airtable(&e.to_string()).await { - tracing::error!("Failed to log error to Airtable: {e}"); + + result.push(UserWithCohorts { user, cohorts }); + } + + Ok(Json(result)) +} + +#[derive(Deserialize)] +struct ToggleAdminRequest { + is_admin: bool, +} + +#[axum::debug_handler] +async fn toggle_admin( + claims: AccessClaims, + AxumPath(user_id): AxumPath, + State(state): State, + Json(body): Json, +) -> Result { + check_admin(&state, &claims).await?; + + state + .global_db + .set_user_admin(user_id, body.is_admin) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + Ok(StatusCode::OK) +} + +#[derive(Deserialize)] +struct UpdateDisplayNameRequest { + display_name: String, +} + +#[axum::debug_handler] +async fn update_my_display_name( + claims: AccessClaims, + State(state): State, + Json(body): Json, +) -> Result { + let display_name = body.display_name.trim(); + if display_name.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + "Display name cannot be empty".to_string(), + )); + } + let global_user = state + .global_db + .ensure_global_user(&claims.sub, &claims.sub, claims.email.as_deref()) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + state + .global_db + .update_user_display_name(global_user.id, display_name) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + Ok(StatusCode::OK) +} + +#[axum::debug_handler] +async fn admin_update_display_name( + claims: AccessClaims, + AxumPath(user_id): AxumPath, + State(state): State, + Json(body): Json, +) -> Result { + check_admin(&state, &claims).await?; + let display_name = body.display_name.trim(); + if display_name.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + "Display name cannot be empty".to_string(), + )); + } + state + .global_db + .update_user_display_name(user_id, display_name) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + Ok(StatusCode::OK) +} + +#[axum::debug_handler] +async fn delete_user_endpoint( + claims: AccessClaims, + AxumPath(user_id): AxumPath, + State(state): State, +) -> Result { + check_admin(&state, &claims).await?; + state + .global_db + .delete_user(user_id) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + Ok(StatusCode::OK) +} + +#[derive(Deserialize)] +struct UpdateMemberRequest { + initial_balance: Option, +} + +#[axum::debug_handler] +async fn update_member( + claims: AccessClaims, + AxumPath((name, member_id)): AxumPath<(String, i64)>, + State(state): State, + Json(body): Json, +) -> Result { + check_admin(&state, &claims).await?; + + let cohort = state + .global_db + .get_cohort_by_name(&name) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? + .ok_or((StatusCode::NOT_FOUND, "Cohort not found".to_string()))?; + + let updated = state + .global_db + .update_member_initial_balance(cohort.id, member_id, body.initial_balance.as_deref()) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + if !updated { + return Err(( + StatusCode::NOT_FOUND, + "Member not found in cohort".to_string(), + )); + } + + Ok(StatusCode::OK) +} + +#[axum::debug_handler] +async fn list_available_dbs( + claims: AccessClaims, + State(state): State, +) -> Result>, (StatusCode, String)> { + check_admin(&state, &claims).await?; + + let data_dir = std::env::var("DATABASE_URL") + .ok() + .and_then(|url| { + let path = url.trim_start_matches("sqlite://"); + Path::new(path) + .parent() + .map(|p| p.to_string_lossy().into_owned()) + }) + .unwrap_or_else(|| "/data".to_string()); + + // Collect db_paths already used by cohorts + let cohorts = state + .global_db + .get_all_cohorts() + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + let used: std::collections::HashSet = cohorts.into_iter().map(|c| c.db_path).collect(); + + let mut available = Vec::new(); + let Ok(entries) = std::fs::read_dir(&data_dir) else { + return Ok(Json(available)); + }; + + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().is_some_and(|ext| ext == "sqlite") { + let full_path = path.to_string_lossy().to_string(); + if !used.contains(&full_path) { + if let Some(stem) = path.file_stem() { + available.push(stem.to_string_lossy().to_string()); + } } - ( - axum::http::StatusCode::INTERNAL_SERVER_ERROR, - "Failed to synchronize Airtable users", - ) - .into_response() } } + + available.sort(); + Ok(Json(available)) } +// --- Utility Endpoints --- + #[axum::debug_handler] async fn upload_image( State(state): State, mut multipart: Multipart, -) -> Result { +) -> Result { let Some(field) = multipart.next_field().await.map_err(|e| { ( - axum::http::StatusCode::BAD_REQUEST, + StatusCode::BAD_REQUEST, format!("Failed to process form data: {e}"), ) })? else { return Err(( - axum::http::StatusCode::BAD_REQUEST, + StatusCode::BAD_REQUEST, "No file found in request".to_string(), )); }; - let content_type = field.content_type().ok_or(( - axum::http::StatusCode::BAD_REQUEST, - "Missing content type".to_string(), - ))?; + let content_type = field + .content_type() + .ok_or((StatusCode::BAD_REQUEST, "Missing content type".to_string()))?; // Validate content type if !content_type.starts_with("image/") { return Err(( - axum::http::StatusCode::BAD_REQUEST, + StatusCode::BAD_REQUEST, "Invalid file type. Only images are allowed.".to_string(), )); } // Generate a unique filename with the correct extension let extension = mime::Mime::from_str(content_type) - .map_err(|_| { - ( - axum::http::StatusCode::BAD_REQUEST, - "Invalid content type".to_string(), - ) - })? + .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid content type".to_string()))? .subtype() .as_str() .to_string(); @@ -156,14 +889,14 @@ async fn upload_image( // Read the file data and write it to disk let data = field.bytes().await.map_err(|e| { ( - axum::http::StatusCode::INTERNAL_SERVER_ERROR, + StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to read file data: {e}"), ) })?; tokio::fs::write(&filepath, &data).await.map_err(|e| { ( - axum::http::StatusCode::INTERNAL_SERVER_ERROR, + StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to save file: {e}"), ) })?; @@ -179,23 +912,17 @@ async fn upload_image( async fn serve_image( State(state): State, AxumPath(filename): AxumPath, -) -> Result { +) -> Result { let filepath = state.uploads_dir.join(filename); // Validate the path to prevent directory traversal if !filepath.starts_with(&state.uploads_dir) { - return Err(( - axum::http::StatusCode::BAD_REQUEST, - "Invalid filename".to_string(), - )); + return Err((StatusCode::BAD_REQUEST, "Invalid filename".to_string())); } - let data = tokio::fs::read(&filepath).await.map_err(|e| { - ( - axum::http::StatusCode::NOT_FOUND, - format!("Image not found: {e}"), - ) - })?; + let data = tokio::fs::read(&filepath) + .await + .map_err(|e| (StatusCode::NOT_FOUND, format!("Image not found: {e}")))?; // Try to determine the content type from the file extension let content_type = match filepath.extension().and_then(|e| e.to_str()) { diff --git a/backend/src/test_utils.rs b/backend/src/test_utils.rs index a675dc87..de303c10 100644 --- a/backend/src/test_utils.rs +++ b/backend/src/test_utils.rs @@ -3,6 +3,7 @@ use std::path::Path; use std::sync::Arc; +use dashmap::DashMap; use futures::{SinkExt, StreamExt}; use governor::{Quota, RateLimiter}; use nonzero_ext::nonzero; @@ -17,13 +18,14 @@ use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage}; use crate::{ db::DB, + global_db::GlobalDB, subscriptions::Subscriptions, websocket_api::{ client_message::Message as CM, server_message::Message as SM, ActAs, Authenticate, ClientMessage, CreateMarket, CreateOrder, EditMarket, GetFullTradeHistory, MakeTransfer, Redeem, Redeemable, RevokeOwnership, ServerMessage, SettleMarket, SetSudo, Side, }, - AppState, + AppState, CohortState, }; /// Creates a test `AppState` with a temporary `SQLite` database. @@ -60,12 +62,30 @@ pub async fn create_test_app_state() -> anyhow::Result<(AppState, TempDir)> { let db = DB::new_for_tests(arbor_pixie_account_id, pool); + // Create global DB in temp dir + let global_db_path = temp_dir.path().join("global.db"); + let global_db = GlobalDB::init_with_path(&global_db_path.to_string_lossy()).await?; + + // Create a test cohort in the global DB + let cohort_info = global_db + .create_cohort("test", "Test", &db_path.to_string_lossy()) + .await?; + + let cohorts = Arc::new(DashMap::new()); + let cohort_state = Arc::new(CohortState { + db, + subscriptions: Subscriptions::new(), + is_read_only: std::sync::atomic::AtomicBool::new(false), + info: cohort_info, + }); + cohorts.insert("test".to_string(), cohort_state); + // Use permissive rate limits for testing let quota = Quota::per_second(nonzero!(10000u32)); let state = AppState { - db, - subscriptions: Subscriptions::new(), + global_db, + cohorts, expensive_ratelimit: Arc::new(RateLimiter::keyed(quota)), admin_expensive_ratelimit: Arc::new(RateLimiter::keyed(quota)), mutate_ratelimit: Arc::new(RateLimiter::keyed(quota)), @@ -80,17 +100,31 @@ pub async fn create_test_app_state() -> anyhow::Result<(AppState, TempDir)> { /// /// # Errors /// Returns an error if the server fails to start. +/// +/// # Panics +/// Panics if the test cohort is not found in the app state. pub async fn spawn_test_server(app_state: AppState) -> anyhow::Result { - use axum::{extract::State, routing::get, Router}; + use axum::{ + extract::{Path as AxumPath, State}, + routing::get, + Router, + }; use crate::handle_socket::handle_socket; let app = Router::new() .route( - "/api", + "/api/ws/:cohort_name", get( - |ws: axum::extract::WebSocketUpgrade, State(state): State| async move { - ws.on_upgrade(move |socket| handle_socket(socket, state)) + |ws: axum::extract::WebSocketUpgrade, + AxumPath(cohort_name): AxumPath, + State(state): State| async move { + let cohort = state + .cohorts + .get(&cohort_name) + .map(|c| Arc::clone(&c)) + .unwrap(); + ws.on_upgrade(move |socket| handle_socket(socket, state, cohort)) }, ), ) @@ -106,7 +140,7 @@ pub async fn spawn_test_server(app_state: AppState) -> anyhow::Result { // Give the server a moment to start tokio::time::sleep(std::time::Duration::from_millis(10)).await; - Ok(format!("ws://127.0.0.1:{}/api", addr.port())) + Ok(format!("ws://127.0.0.1:{}/api/ws/test", addr.port())) } /// WebSocket test client for integration tests. diff --git a/backend/tests/websocket_redemptions.rs b/backend/tests/websocket_redemptions.rs index d1cf7033..ebc02a29 100644 --- a/backend/tests/websocket_redemptions.rs +++ b/backend/tests/websocket_redemptions.rs @@ -1,4 +1,5 @@ -//! WebSocket integration tests for redemptions in GetFullTradeHistory +#![allow(clippy::similar_names, clippy::too_many_lines, clippy::used_underscore_binding)] +//! WebSocket integration tests for redemptions in `GetFullTradeHistory` //! //! Run with: `cargo test --features dev-mode` @@ -11,7 +12,7 @@ use backend::{ use tempfile::TempDir; /// Helper: create a standard test setup with an admin, two constituent markets, and a fund market. -/// Returns (url, admin_account_id, market_a_id, market_b_id, fund_id, _temp_dir). +/// Returns (url, `admin_account_id`, `market_a_id`, `market_b_id`, `fund_id`, _`temp_dir`). /// The `TempDir` must be kept alive for the duration of the test. async fn setup_fund_market() -> (String, i64, i64, i64, i64, TempDir) { let (app_state, temp) = create_test_app_state().await.unwrap(); @@ -41,7 +42,7 @@ async fn setup_fund_market() -> (String, i64, i64, i64, i64, TempDir) { .message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; let market_b_id = match admin @@ -51,7 +52,7 @@ async fn setup_fund_market() -> (String, i64, i64, i64, i64, TempDir) { .message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Create a fund market redeemable for A (multiplier 1) and B (multiplier 2), with fee of 5 @@ -78,7 +79,7 @@ async fn setup_fund_market() -> (String, i64, i64, i64, i64, TempDir) { .message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; drop(admin); @@ -162,7 +163,7 @@ async fn test_full_trade_history_includes_redemptions() { assert_eq!(r.fund_id, fund_id); assert!((r.amount - 3.0).abs() < f64::EPSILON); } - other => panic!("Expected Redeemed, got {:?}", other), + other => panic!("Expected Redeemed, got {other:?}"), } // Drain broadcasts while admin @@ -189,7 +190,7 @@ async fn test_full_trade_history_includes_redemptions() { assert!(redemption.transaction_id > 0); assert!(redemption.transaction_timestamp.is_some()); } - other => panic!("Expected Trades, got {:?}", other), + other => panic!("Expected Trades, got {other:?}"), } } @@ -215,7 +216,7 @@ async fn test_full_trade_history_no_redemptions_on_non_fund_market() { "Non-fund market should have no redemptions" ); } - other => panic!("Expected Trades, got {:?}", other), + other => panic!("Expected Trades, got {other:?}"), } } @@ -317,7 +318,7 @@ async fn test_multiple_redemptions_in_history() { assert!((trades.redemptions[0].amount - 2.0).abs() < f64::EPSILON); assert!((trades.redemptions[1].amount - 3.0).abs() < f64::EPSILON); } - other => panic!("Expected Trades, got {:?}", other), + other => panic!("Expected Trades, got {other:?}"), } } @@ -349,7 +350,7 @@ async fn test_redemptions_hidden_on_hide_account_ids_market() { .message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Create fund with hide_account_ids=true @@ -370,7 +371,7 @@ async fn test_redemptions_hidden_on_hide_account_ids_market() { .message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Create a user, give them money, have them trade and redeem @@ -454,7 +455,7 @@ async fn test_redemptions_hidden_on_hide_account_ids_market() { redemption.account_id, user_id, admin_id ); } - other => panic!("Expected Trades, got {:?}", other), + other => panic!("Expected Trades, got {other:?}"), } // With sudo enabled, IDs should be visible @@ -474,7 +475,7 @@ async fn test_redemptions_hidden_on_hide_account_ids_market() { "Admin with sudo should see real redemption account_id" ); } - other => panic!("Expected Trades, got {:?}", other), + other => panic!("Expected Trades, got {other:?}"), } } @@ -554,7 +555,7 @@ async fn test_redemption_broadcast_includes_correct_data() { assert!(r.transaction_id > 0); assert!(r.transaction_timestamp.is_some()); } - other => panic!("Expected Redeemed, got {:?}", other), + other => panic!("Expected Redeemed, got {other:?}"), } // The user (as another connected client) should also receive the broadcast diff --git a/backend/tests/websocket_sudo.rs b/backend/tests/websocket_sudo.rs index 7255707c..e253ac05 100644 --- a/backend/tests/websocket_sudo.rs +++ b/backend/tests/websocket_sudo.rs @@ -1,3 +1,4 @@ +#![allow(clippy::too_many_lines)] //! WebSocket integration tests for sudo functionality and admin permissions //! //! Run with: `cargo test --features dev-mode` @@ -9,7 +10,7 @@ use backend::{ websocket_api::{server_message::Message as SM, RequestFailed, Side, SudoStatus}, }; -/// Helper to assert a RequestFailed response with expected kind and message substring +/// Helper to assert a `RequestFailed` response with expected kind and message substring fn assert_request_failed(msg: &SM, expected_kind: &str, expected_message_contains: &str) { match msg { SM::RequestFailed(RequestFailed { @@ -28,21 +29,20 @@ fn assert_request_failed(msg: &SM, expected_kind: &str, expected_message_contain error.message ); } - other => panic!("Expected RequestFailed, got {:?}", other), + other => panic!("Expected RequestFailed, got {other:?}"), } } -/// Helper to assert a SudoStatus response +/// Helper to assert a `SudoStatus` response fn assert_sudo_status(msg: &SM, expected_enabled: bool) { match msg { SM::SudoStatus(SudoStatus { enabled }) => { assert_eq!( *enabled, expected_enabled, - "Expected sudo enabled={}, got {}", - expected_enabled, enabled + "Expected sudo enabled={expected_enabled}, got {enabled}" ); } - other => panic!("Expected SudoStatus, got {:?}", other), + other => panic!("Expected SudoStatus, got {other:?}"), } } @@ -239,12 +239,11 @@ async fn test_act_as_shows_not_owner_for_non_admin() { async fn test_act_as_shows_sudo_required_for_admin() { let (app_state, _temp) = create_test_app_state().await.unwrap(); - // Create a second user that admin will try to act as - let _ = app_state - .db - .ensure_user_created("user2", Some("Second User"), rust_decimal_macros::dec!(100)) - .await - .unwrap(); + // Create a second user through the global DB flow so they match what WS auth creates + let global_user2 = app_state.global_db.ensure_global_user("user2", "Second User", None).await.unwrap(); + let cohort = app_state.cohorts.get("test").unwrap(); + let _ = cohort.db.ensure_user_created_by_global_id(global_user2.id, "Second User", rust_decimal_macros::dec!(100)).await.unwrap(); + drop(cohort); let url = spawn_test_server(app_state).await.unwrap(); @@ -311,16 +310,13 @@ async fn test_hide_account_ids_respects_sudo() { let (app_state, _temp) = create_test_app_state().await.unwrap(); // Pre-create users with initial balance so they can place orders - let _ = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap(); - let _ = app_state - .db - .ensure_user_created("user2", Some("User Two"), rust_decimal_macros::dec!(1000)) - .await - .unwrap(); + // Create users through the global DB flow so they match what the WS auth creates + let global_user1 = app_state.global_db.ensure_global_user("user1", "User One", None).await.unwrap(); + let global_user2 = app_state.global_db.ensure_global_user("user2", "User Two", None).await.unwrap(); + let cohort = app_state.cohorts.get("test").unwrap(); + let _ = cohort.db.ensure_user_created_by_global_id(global_user1.id, "User One", rust_decimal_macros::dec!(1000)).await.unwrap(); + let _ = cohort.db.ensure_user_created_by_global_id(global_user2.id, "User Two", rust_decimal_macros::dec!(1000)).await.unwrap(); + drop(cohort); let url = spawn_test_server(app_state).await.unwrap(); @@ -349,7 +345,7 @@ async fn test_hide_account_ids_respects_sudo() { .unwrap(); let market_id = match market_response.message { Some(SM::Market(market)) => market.id, - other => panic!("Expected Market response, got {:?}", other), + other => panic!("Expected Market response, got {other:?}"), }; // Disable sudo - now admin should have account IDs hidden @@ -370,7 +366,7 @@ async fn test_hide_account_ids_respects_sudo() { .unwrap(); match &order_response.message { Some(SM::OrderCreated(_)) => {} - other => panic!("User1 should have received OrderCreated, got {:?}", other), + other => panic!("User1 should have received OrderCreated, got {other:?}"), } // Small delay to ensure message propagates @@ -416,7 +412,7 @@ async fn test_hide_account_ids_respects_sudo() { Some(SM::OrderCreated(oc)) => { assert!(!oc.trades.is_empty(), "Should have created a trade"); } - other => panic!("User2 should have received OrderCreated with trade, got {:?}", other), + other => panic!("User2 should have received OrderCreated with trade, got {other:?}"), } // Small delay @@ -555,16 +551,13 @@ async fn test_hide_account_ids_in_full_trade_history() { let (app_state, _temp) = create_test_app_state().await.unwrap(); // Pre-create users with initial balance so they can place orders - let _ = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap(); - let _ = app_state - .db - .ensure_user_created("user2", Some("User Two"), rust_decimal_macros::dec!(1000)) - .await - .unwrap(); + // Create users through the global DB flow so they match what the WS auth creates + let global_user1 = app_state.global_db.ensure_global_user("user1", "User One", None).await.unwrap(); + let global_user2 = app_state.global_db.ensure_global_user("user2", "User Two", None).await.unwrap(); + let cohort = app_state.cohorts.get("test").unwrap(); + let _ = cohort.db.ensure_user_created_by_global_id(global_user1.id, "User One", rust_decimal_macros::dec!(1000)).await.unwrap(); + let _ = cohort.db.ensure_user_created_by_global_id(global_user2.id, "User Two", rust_decimal_macros::dec!(1000)).await.unwrap(); + drop(cohort); let url = spawn_test_server(app_state).await.unwrap(); @@ -593,7 +586,7 @@ async fn test_hide_account_ids_in_full_trade_history() { .unwrap(); let market_id = match market_response.message { Some(SM::Market(market)) => market.id, - other => panic!("Expected Market response, got {:?}", other), + other => panic!("Expected Market response, got {other:?}"), }; // Connect user1 and place an order @@ -660,7 +653,7 @@ async fn test_hide_account_ids_in_full_trade_history() { ); } } - other => panic!("Expected Trades response, got {:?}", other), + other => panic!("Expected Trades response, got {other:?}"), } // Enable sudo and verify IDs are now visible @@ -695,6 +688,6 @@ async fn test_hide_account_ids_in_full_trade_history() { "Admin with sudo should see real account_ids in trade history" ); } - other => panic!("Expected Trades response, got {:?}", other), + other => panic!("Expected Trades response, got {other:?}"), } } diff --git a/backend/tests/websocket_universes.rs b/backend/tests/websocket_universes.rs index c73712f1..4fef861e 100644 --- a/backend/tests/websocket_universes.rs +++ b/backend/tests/websocket_universes.rs @@ -1,3 +1,4 @@ +#![allow(clippy::too_many_lines)] //! WebSocket integration tests for universes functionality //! //! Run with: `cargo test --features dev-mode` @@ -12,7 +13,7 @@ use backend::{ }, }; -/// Helper to assert a RequestFailed response with expected kind and message substring +/// Helper to assert a `RequestFailed` response with expected kind and message substring fn assert_request_failed(msg: &SM, expected_kind: &str, expected_message_contains: &str) { match msg { SM::RequestFailed(RequestFailed { @@ -31,11 +32,11 @@ fn assert_request_failed(msg: &SM, expected_kind: &str, expected_message_contain error.message ); } - other => panic!("Expected RequestFailed, got {:?}", other), + other => panic!("Expected RequestFailed, got {other:?}"), } } -/// Helper to receive a response with the matching request_id, draining other messages +/// Helper to receive a response with the matching `request_id`, draining other messages async fn recv_response( client: &mut TestClient, request_id: &str, @@ -52,13 +53,13 @@ async fn recv_response( } } -/// Helper to send CreateUniverse and get the response +/// Helper to send `CreateUniverse` and get the response async fn create_universe( client: &mut TestClient, name: &str, description: &str, ) -> anyhow::Result { - let request_id = format!("create-universe-{}", name); + let request_id = format!("create-universe-{name}"); let msg = ClientMessage { request_id: request_id.clone(), message: Some(CM::CreateUniverse(CreateUniverse { @@ -70,7 +71,7 @@ async fn create_universe( recv_response(client, &request_id).await } -/// Helper to send CreateAccount and get the response +/// Helper to send `CreateAccount` and get the response async fn create_account( client: &mut TestClient, owner_id: i64, @@ -78,7 +79,7 @@ async fn create_account( universe_id: i64, initial_balance: f64, ) -> anyhow::Result { - let request_id = format!("create-account-{}", name); + let request_id = format!("create-account-{name}"); let msg = ClientMessage { request_id: request_id.clone(), message: Some(CM::CreateAccount(CreateAccount { @@ -93,7 +94,7 @@ async fn create_account( recv_response(client, &request_id).await } -/// Helper to send MakeTransfer and get the response +/// Helper to send `MakeTransfer` and get the response async fn make_transfer( client: &mut TestClient, from_account_id: i64, @@ -101,7 +102,7 @@ async fn make_transfer( amount: f64, note: &str, ) -> anyhow::Result { - let request_id = format!("transfer-{}-to-{}", from_account_id, to_account_id); + let request_id = format!("transfer-{from_account_id}-to-{to_account_id}"); let msg = ClientMessage { request_id: request_id.clone(), message: Some(CM::MakeTransfer(MakeTransfer { @@ -115,12 +116,12 @@ async fn make_transfer( recv_response(client, &request_id).await } -/// Helper to send ActAs and get the response +/// Helper to send `ActAs` and get the response async fn act_as( client: &mut TestClient, account_id: i64, ) -> anyhow::Result { - let request_id = format!("act-as-{}", account_id); + let request_id = format!("act-as-{account_id}"); let msg = ClientMessage { request_id: request_id.clone(), message: Some(CM::ActAs(ActAs { account_id })), @@ -163,7 +164,7 @@ async fn test_any_user_can_create_universe() { assert_eq!(universe.description, "A test universe"); assert!(universe.owner_id > 0, "Universe should have an owner"); } - other => panic!("Expected Universe response, got {:?}", other), + other => panic!("Expected Universe response, got {other:?}"), } } @@ -212,7 +213,7 @@ async fn test_create_account_in_universe() { let universe_response = create_universe(&mut client, "My Universe", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; // Create account in that universe (with initial balance since we own the universe) @@ -224,7 +225,7 @@ async fn test_create_account_in_universe() { assert_eq!(account.name, "Test Account"); assert_eq!(account.universe_id, universe_id); } - other => panic!("Expected AccountCreated response, got {:?}", other), + other => panic!("Expected AccountCreated response, got {other:?}"), } } @@ -244,7 +245,7 @@ async fn test_only_universe_owner_can_set_initial_balance() { let universe_response = create_universe(&mut user1, "User1 Universe", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; // User1 can create account with initial balance (they own the universe) @@ -289,7 +290,7 @@ async fn test_create_account_in_main_universe_without_initial_balance() { Some(SM::AccountCreated(account)) => { assert_eq!(account.universe_id, 0); } - other => panic!("Expected AccountCreated response, got {:?}", other), + other => panic!("Expected AccountCreated response, got {other:?}"), } } @@ -335,14 +336,14 @@ async fn test_only_universe_owner_can_create_market_in_non_main_universe() { let universe_response = create_universe(&mut user1, "Market Universe", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; // Create account in that universe let account_response = create_account(&mut user1, user1_id, "User1 Account", universe_id, 1000.0).await.unwrap(); let user1_account_id = match account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // ActAs that account @@ -352,7 +353,7 @@ async fn test_only_universe_owner_can_create_market_in_non_main_universe() { assert_eq!(acting_as.account_id, user1_account_id); assert_eq!(acting_as.universe_id, universe_id); } - other => panic!("Expected ActingAs, got {:?}", other), + other => panic!("Expected ActingAs, got {other:?}"), } drain_messages(&mut user1).await; @@ -362,7 +363,7 @@ async fn test_only_universe_owner_can_create_market_in_non_main_universe() { Some(SM::Market(market)) => { assert_eq!(market.universe_id, universe_id); } - other => panic!("Expected Market response, got {:?}", other), + other => panic!("Expected Market response, got {other:?}"), } // User2 connects - verify they CANNOT create accounts in user1's universe @@ -405,21 +406,21 @@ async fn test_cross_universe_transfer_rejected() { let main_account_response = create_account(&mut admin, admin_id, "Main Alt", 0, 0.0).await.unwrap(); let main_alt_id = match main_account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Create a universe let universe_response = create_universe(&mut admin, "Transfer Test Universe", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; // Create account in the new universe let universe_account_response = create_account(&mut admin, admin_id, "Universe Alt", universe_id, 1000.0).await.unwrap(); let universe_alt_id = match universe_account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Try to transfer from main universe (admin account) to universe alt @@ -463,14 +464,14 @@ async fn test_same_universe_transfer_works() { let universe_response = create_universe(&mut user, "Same Universe Test", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; // Create an account in that universe with initial balance let account_response = create_account(&mut user, user_id, "Universe Account", universe_id, 1000.0).await.unwrap(); let universe_account_id = match account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Create a second account in the same universe owned by the first account @@ -478,7 +479,7 @@ async fn test_same_universe_transfer_works() { let sub_account_response = create_account(&mut user, universe_account_id, "Sub Account", universe_id, 0.0).await.unwrap(); let sub_account_id = match sub_account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Act as the universe account @@ -492,7 +493,7 @@ async fn test_same_universe_transfer_works() { Some(SM::TransferCreated(_)) => { // Success! } - other => panic!("Expected TransferCreated response, got {:?}", other), + other => panic!("Expected TransferCreated response, got {other:?}"), } } @@ -520,20 +521,20 @@ async fn test_cross_universe_trade_rejected() { let market_response = admin.create_market("Main Universe Market", 0.0, 100.0, false).await.unwrap(); let main_market_id = match market_response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Create a universe and account in it let universe_response = create_universe(&mut admin, "Trading Universe", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; let account_response = create_account(&mut admin, admin_id, "Universe Account", universe_id, 10000.0).await.unwrap(); let universe_account_id = match account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Act as the universe account @@ -566,14 +567,14 @@ async fn test_same_universe_trading_works() { let universe_response = create_universe(&mut user, "Trading Universe", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; // Create account in that universe let account_response = create_account(&mut user, user_id, "Trader Account", universe_id, 10000.0).await.unwrap(); let account_id = match account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Act as that account @@ -587,7 +588,7 @@ async fn test_same_universe_trading_works() { assert_eq!(m.universe_id, universe_id, "Market should be in the same universe"); m.id } - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Trading in same universe should work @@ -597,7 +598,7 @@ async fn test_same_universe_trading_works() { Some(SM::OrderCreated(_)) => { // Success! } - other => panic!("Expected OrderCreated response, got {:?}", other), + other => panic!("Expected OrderCreated response, got {other:?}"), } } @@ -621,14 +622,14 @@ async fn test_act_as_to_universe_account_shows_universe_id() { let universe_response = create_universe(&mut user, "ActAs Universe Test", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; // Create account in that universe let account_response = create_account(&mut user, user_id, "Universe Account", universe_id, 100.0).await.unwrap(); let universe_account_id = match account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Act as that account @@ -640,7 +641,7 @@ async fn test_act_as_to_universe_account_shows_universe_id() { assert_eq!(acting_as.account_id, universe_account_id); assert_eq!(acting_as.universe_id, universe_id, "Should show universe_id in ActingAs"); } - other => panic!("Expected ActingAs response, got {:?}", other), + other => panic!("Expected ActingAs response, got {other:?}"), } } @@ -660,13 +661,13 @@ async fn test_act_as_switches_universe_and_resends_markets() { let universe_response = create_universe(&mut user, "ActAs Test Universe", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; let account_response = create_account(&mut user, user_id, "Universe Account", universe_id, 1000.0).await.unwrap(); let universe_account_id = match account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Act as universe account to create a market there @@ -676,7 +677,7 @@ async fn test_act_as_switches_universe_and_resends_markets() { let market_response = user.create_market("Universe Only Market", 0.0, 100.0, false).await.unwrap(); let universe_market_id = match market_response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Go back to main account @@ -687,7 +688,7 @@ async fn test_act_as_switches_universe_and_resends_markets() { Some(SM::ActingAs(acting_as)) => { assert_eq!(acting_as.universe_id, 0, "Should be in main universe now"); } - other => panic!("Expected ActingAs, got {:?}", other), + other => panic!("Expected ActingAs, got {other:?}"), } // Collect resent markets @@ -735,21 +736,21 @@ async fn test_initial_data_filters_markets_by_universe() { let main_market_response = admin.create_market("Main Market", 0.0, 100.0, false).await.unwrap(); let main_market_id = match main_market_response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Create a universe let universe_response = create_universe(&mut admin, "Filter Test Universe", "").await.unwrap(); let universe_id = match universe_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; // Create account and market in that universe let account_response = create_account(&mut admin, admin_id, "Filter Account", universe_id, 1000.0).await.unwrap(); let universe_account_id = match account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; let _ = act_as(&mut admin, universe_account_id).await.unwrap(); @@ -758,7 +759,7 @@ async fn test_initial_data_filters_markets_by_universe() { let universe_market_response = admin.create_market("Universe Market", 0.0, 100.0, false).await.unwrap(); let universe_market_id = match universe_market_response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // New user connects - should only see main universe market initially @@ -830,20 +831,20 @@ async fn test_owner_must_be_in_same_universe_or_universe_0() { let universe1_response = create_universe(&mut user, "Universe 1", "").await.unwrap(); let universe1_id = match universe1_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; let universe2_response = create_universe(&mut user, "Universe 2", "").await.unwrap(); let universe2_id = match universe2_response.message { Some(SM::Universe(u)) => u.id, - other => panic!("Expected Universe, got {:?}", other), + other => panic!("Expected Universe, got {other:?}"), }; // Create an account in universe 1 let account1_response = create_account(&mut user, user_id, "Account in U1", universe1_id, 100.0).await.unwrap(); let account1_id = match account1_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Try to create an account in universe 2 with owner from universe 1 - should fail @@ -888,7 +889,7 @@ async fn test_owner_must_be_in_same_universe_or_universe_0() { Some(SM::AccountCreated(account)) => { assert_eq!(account.universe_id, universe1_id); } - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), } // Creating account in universe 1 with owner from universe 0 (main user) should work @@ -897,7 +898,7 @@ async fn test_owner_must_be_in_same_universe_or_universe_0() { Some(SM::AccountCreated(account)) => { assert_eq!(account.universe_id, universe1_id); } - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), } // Try to create an account in universe 0 (main) with owner from universe 1 - should fail @@ -926,7 +927,7 @@ async fn test_owner_must_be_in_same_universe_or_universe_0() { let alt_account_response = create_account(&mut user, user_id, "Alt in Main", 0, 0.0).await.unwrap(); let alt_account_id = match alt_account_response.message { Some(SM::AccountCreated(a)) => a.id, - other => panic!("Expected AccountCreated, got {:?}", other), + other => panic!("Expected AccountCreated, got {other:?}"), }; // Try to create account in non-zero universe with alt account (non-user) from main as owner - should fail diff --git a/backend/tests/websocket_visibility.rs b/backend/tests/websocket_visibility.rs index 108a9063..0534bfa8 100644 --- a/backend/tests/websocket_visibility.rs +++ b/backend/tests/websocket_visibility.rs @@ -2,9 +2,9 @@ //! //! Tests two bugs: //! 1. Sudoed admins don't receive real-time updates for visibility-restricted markets -//! (CreateMarket, SettleMarket, EditMarket are sent via send_private to visible_to accounts only) +//! (`CreateMarket`, `SettleMarket`, `EditMarket` are sent via `send_private` to `visible_to` accounts only) //! 2. Non-admin users receive broadcast updates (orders/trades/settlements) for markets -//! they shouldn't see (these go via send_public with no visibility filtering) +//! they shouldn't see (these go via `send_public` with no visibility filtering) //! //! Run with: `cargo test --features dev-mode` @@ -14,11 +14,21 @@ use backend::{ test_utils::{create_test_app_state, spawn_test_server, TestClient}, websocket_api::{server_message::Message as SM, ClientMessage, EditMarket, ServerMessage, Side, client_message::Message as CM, SettleMarket}, + AppState, }; const TIMEOUT: std::time::Duration = std::time::Duration::from_millis(500); const SHORT_DELAY: std::time::Duration = std::time::Duration::from_millis(50); +/// Helper to create a user via the multi-cohort `global_db` + cohort db pattern. +/// Returns the cohort-local account ID. +async fn create_test_user(app_state: &AppState, kinde_id: &str, name: &str, balance: rust_decimal::Decimal) -> i64 { + let global_user = app_state.global_db.ensure_global_user(kinde_id, name, None).await.unwrap(); + let cohort = app_state.cohorts.get("test").unwrap(); + let result = cohort.db.ensure_user_created_by_global_id(global_user.id, name, balance).await.unwrap().unwrap(); + result.id +} + /// Drain all pending messages, returning them async fn drain_messages(client: &mut TestClient) -> Vec { let mut messages = vec![]; @@ -32,7 +42,7 @@ async fn drain_messages(client: &mut TestClient) -> Vec { messages } -/// Send a request and recv messages until we find one with our request_id +/// Send a request and recv messages until we find one with our `request_id` async fn send_and_recv(client: &mut TestClient, request_id: String, cm: CM) -> ServerMessage { client .send_raw(ClientMessage { @@ -59,13 +69,7 @@ async fn test_sudoed_admin_receives_create_market_with_visible_to() { // should receive the Market broadcast. let (app_state, _temp) = create_test_app_state().await.unwrap(); - let user1_id = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; + let user1_id = create_test_user(&app_state, "user1", "User One", rust_decimal_macros::dec!(1000)).await; let url = spawn_test_server(app_state).await.unwrap(); @@ -103,7 +107,7 @@ async fn test_sudoed_admin_receives_create_market_with_visible_to() { .unwrap(); let market_id = match &response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market response, got {:?}", other), + other => panic!("Expected Market response, got {other:?}"), }; // Admin2 (sudoed, NOT in visible_to) should receive the Market message @@ -115,8 +119,7 @@ async fn test_sudoed_admin_receives_create_market_with_visible_to() { assert!( found_market, "Sudoed admin2 should receive Market broadcast for visible_to-restricted market. \ - Got messages: {:#?}", - messages + Got messages: {messages:#?}" ); } @@ -126,13 +129,7 @@ async fn test_sudoed_admin_receives_settle_market_with_visible_to() { // should receive MarketSettled. let (app_state, _temp) = create_test_app_state().await.unwrap(); - let user1_id = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; + let user1_id = create_test_user(&app_state, "user1", "User One", rust_decimal_macros::dec!(1000)).await; let url = spawn_test_server(app_state).await.unwrap(); @@ -170,7 +167,7 @@ async fn test_sudoed_admin_receives_settle_market_with_visible_to() { .unwrap(); let market_id = match &response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Drain market creation broadcast from admin2 drain_messages(&mut admin2).await; @@ -187,7 +184,7 @@ async fn test_sudoed_admin_receives_settle_market_with_visible_to() { .await; match &settle_response.message { Some(SM::MarketSettled(_)) => {} - other => panic!("Expected MarketSettled, got {:?}", other), + other => panic!("Expected MarketSettled, got {other:?}"), } // Admin2 should receive MarketSettled @@ -199,8 +196,7 @@ async fn test_sudoed_admin_receives_settle_market_with_visible_to() { assert!( found_settled, "Sudoed admin2 should receive MarketSettled for visible_to-restricted market. \ - Got messages: {:#?}", - messages + Got messages: {messages:#?}" ); } @@ -210,13 +206,7 @@ async fn test_sudoed_admin_receives_edit_market_with_visible_to() { // should receive the Market update. let (app_state, _temp) = create_test_app_state().await.unwrap(); - let user1_id = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; + let user1_id = create_test_user(&app_state, "user1", "User One", rust_decimal_macros::dec!(1000)).await; let url = spawn_test_server(app_state).await.unwrap(); @@ -254,7 +244,7 @@ async fn test_sudoed_admin_receives_edit_market_with_visible_to() { .unwrap(); let market_id = match &response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Drain market creation broadcast from admin2 drain_messages(&mut admin2).await; @@ -272,7 +262,7 @@ async fn test_sudoed_admin_receives_edit_market_with_visible_to() { .await; match &edit_response.message { Some(SM::Market(_)) => {} - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), } // Admin2 should receive the Market update @@ -284,8 +274,7 @@ async fn test_sudoed_admin_receives_edit_market_with_visible_to() { assert!( found_market, "Sudoed admin2 should receive Market update for visible_to-restricted market. \ - Got messages: {:#?}", - messages + Got messages: {messages:#?}" ); } @@ -298,19 +287,8 @@ async fn test_non_visible_user_does_not_receive_order_updates() { // A user NOT in visible_to should not receive OrderCreated broadcasts. let (app_state, _temp) = create_test_app_state().await.unwrap(); - let user1_id = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; - - let _ = app_state - .db - .ensure_user_created("user2", Some("User Two"), rust_decimal_macros::dec!(1000)) - .await - .unwrap(); + let user1_id = create_test_user(&app_state, "user1", "User One", rust_decimal_macros::dec!(1000)).await; + let _ = create_test_user(&app_state, "user2", "User Two", rust_decimal_macros::dec!(1000)).await; let url = spawn_test_server(app_state).await.unwrap(); @@ -336,7 +314,7 @@ async fn test_non_visible_user_does_not_receive_order_updates() { .unwrap(); let market_id = match &response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Connect user1 and user2 (do this AFTER market creation to avoid broadcast interleaving) @@ -380,27 +358,9 @@ async fn test_non_visible_user_does_not_receive_trade_broadcasts() { // A user NOT in visible_to should not receive trade broadcasts. let (app_state, _temp) = create_test_app_state().await.unwrap(); - let user1_id = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; - - let user2_id = app_state - .db - .ensure_user_created("user2", Some("User Two"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; - - let _ = app_state - .db - .ensure_user_created("user3", Some("User Three"), rust_decimal_macros::dec!(1000)) - .await - .unwrap(); + let user1_id = create_test_user(&app_state, "user1", "User One", rust_decimal_macros::dec!(1000)).await; + let user2_id = create_test_user(&app_state, "user2", "User Two", rust_decimal_macros::dec!(1000)).await; + let _ = create_test_user(&app_state, "user3", "User Three", rust_decimal_macros::dec!(1000)).await; let url = spawn_test_server(app_state).await.unwrap(); @@ -426,7 +386,7 @@ async fn test_non_visible_user_does_not_receive_trade_broadcasts() { .unwrap(); let market_id = match &response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Connect all users AFTER market creation @@ -476,7 +436,7 @@ async fn test_non_visible_user_does_not_receive_trade_broadcasts() { Some(SM::OrderCreated(oc)) => { assert!(!oc.trades.is_empty(), "Should have created a trade"); } - other => panic!("Expected OrderCreated with trades, got {:?}", other), + other => panic!("Expected OrderCreated with trades, got {other:?}"), } // User3 should NOT receive any OrderCreated for this market @@ -496,19 +456,8 @@ async fn test_non_visible_user_does_not_receive_settle_broadcast() { // A user NOT in visible_to should not receive MarketSettled or OrdersCancelled. let (app_state, _temp) = create_test_app_state().await.unwrap(); - let user1_id = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; - - let _ = app_state - .db - .ensure_user_created("user2", Some("User Two"), rust_decimal_macros::dec!(1000)) - .await - .unwrap(); + let user1_id = create_test_user(&app_state, "user1", "User One", rust_decimal_macros::dec!(1000)).await; + let _ = create_test_user(&app_state, "user2", "User Two", rust_decimal_macros::dec!(1000)).await; let url = spawn_test_server(app_state).await.unwrap(); @@ -534,7 +483,7 @@ async fn test_non_visible_user_does_not_receive_settle_broadcast() { .unwrap(); let market_id = match &response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Connect users AFTER market creation @@ -577,7 +526,7 @@ async fn test_non_visible_user_does_not_receive_settle_broadcast() { .await; match &settle_response.message { Some(SM::MarketSettled(_)) => {} - other => panic!("Expected MarketSettled, got {:?}", other), + other => panic!("Expected MarketSettled, got {other:?}"), } // User2 should NOT receive MarketSettled or OrdersCancelled @@ -608,21 +557,8 @@ async fn test_visible_user_receives_order_updates() { // A user IN visible_to should receive OrderCreated broadcasts normally. let (app_state, _temp) = create_test_app_state().await.unwrap(); - let user1_id = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; - - let user2_id = app_state - .db - .ensure_user_created("user2", Some("User Two"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; + let user1_id = create_test_user(&app_state, "user1", "User One", rust_decimal_macros::dec!(1000)).await; + let user2_id = create_test_user(&app_state, "user2", "User Two", rust_decimal_macros::dec!(1000)).await; let url = spawn_test_server(app_state).await.unwrap(); @@ -648,7 +584,7 @@ async fn test_visible_user_receives_order_updates() { .unwrap(); let market_id = match &response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Connect users AFTER market creation @@ -692,19 +628,8 @@ async fn test_non_visible_user_does_not_see_market_in_initial_data() { // A user NOT in visible_to should not see the market in initial data. let (app_state, _temp) = create_test_app_state().await.unwrap(); - let user1_id = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; - - let _ = app_state - .db - .ensure_user_created("user2", Some("User Two"), rust_decimal_macros::dec!(1000)) - .await - .unwrap(); + let user1_id = create_test_user(&app_state, "user1", "User One", rust_decimal_macros::dec!(1000)).await; + let _ = create_test_user(&app_state, "user2", "User Two", rust_decimal_macros::dec!(1000)).await; let url = spawn_test_server(app_state).await.unwrap(); @@ -730,7 +655,7 @@ async fn test_non_visible_user_does_not_see_market_in_initial_data() { .unwrap(); let market_id = match &response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // User2 connects - should NOT see the market in initial data @@ -763,13 +688,7 @@ async fn test_sudoed_admin_sees_visible_to_market_in_initial_data() { // they should see it in initial data. let (app_state, _temp) = create_test_app_state().await.unwrap(); - let user1_id = app_state - .db - .ensure_user_created("user1", Some("User One"), rust_decimal_macros::dec!(1000)) - .await - .unwrap() - .unwrap() - .id; + let user1_id = create_test_user(&app_state, "user1", "User One", rust_decimal_macros::dec!(1000)).await; let url = spawn_test_server(app_state).await.unwrap(); @@ -795,7 +714,7 @@ async fn test_sudoed_admin_sees_visible_to_market_in_initial_data() { .unwrap(); let market_id = match &response.message { Some(SM::Market(m)) => m.id, - other => panic!("Expected Market, got {:?}", other), + other => panic!("Expected Market, got {other:?}"), }; // Admin2 connects, enables sudo → should see the market in resent data diff --git a/frontend/src/lib/adminApi.ts b/frontend/src/lib/adminApi.ts new file mode 100644 index 00000000..90c05b74 --- /dev/null +++ b/frontend/src/lib/adminApi.ts @@ -0,0 +1,232 @@ +import { kinde } from './auth.svelte'; +import { API_BASE } from './apiBase'; + +export interface CohortInfo { + id: number; + name: string; + display_name: string; + db_path: string; + is_read_only: boolean; +} + +export interface CohortMember { + id: number; + cohort_id: number; + global_user_id: number | null; + email: string | null; + display_name: string | null; + initial_balance: string | null; + balance: number | null; +} + +export interface GlobalConfig { + active_auction_cohort_id: number | null; + default_cohort_id: number | null; + public_auction_enabled: boolean; +} + +export interface GlobalUser { + id: number; + kinde_id: string; + display_name: string; + is_admin: boolean; + email: string | null; +} + +export interface UserCohortDetail { + cohort_name: string; + cohort_display_name: string; + balance: number | null; +} + +export interface UserWithCohorts extends GlobalUser { + cohorts: UserCohortDetail[]; +} + +async function authHeaders(): Promise { + const token = await kinde.getToken(); + return { + Authorization: `Bearer ${token}`, + 'Content-Type': 'application/json' + }; +} + +async function handleResponse(res: Response): Promise { + if (!res.ok) { + const text = await res.text(); + throw new Error(text || res.statusText); + } + return res.json(); +} + +export async function fetchAllCohorts(): Promise { + const res = await fetch(`${API_BASE}/api/admin/cohorts`, { headers: await authHeaders() }); + return handleResponse(res); +} + +export async function createCohort( + name: string, + displayName: string, + existingDb?: boolean +): Promise { + const res = await fetch(`${API_BASE}/api/admin/cohorts`, { + method: 'POST', + headers: await authHeaders(), + body: JSON.stringify({ name, display_name: displayName, existing_db: existingDb ?? false }) + }); + return handleResponse(res); +} + +export async function updateCohort( + name: string, + displayName?: string, + isReadOnly?: boolean +): Promise { + const res = await fetch(`${API_BASE}/api/admin/cohorts/${name}`, { + method: 'PUT', + headers: await authHeaders(), + body: JSON.stringify({ display_name: displayName, is_read_only: isReadOnly }) + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || res.statusText); + } +} + +export async function fetchMembers(cohortName: string): Promise { + const res = await fetch(`${API_BASE}/api/admin/cohorts/${cohortName}/members`, { + headers: await authHeaders() + }); + return handleResponse(res); +} + +export async function fetchGlobalUsers(): Promise { + const res = await fetch(`${API_BASE}/api/admin/users`, { headers: await authHeaders() }); + return handleResponse(res); +} + +export async function fetchUsersDetailed(): Promise { + const res = await fetch(`${API_BASE}/api/admin/users/details`, { headers: await authHeaders() }); + return handleResponse(res); +} + +export async function batchAddMembers( + cohortName: string, + opts: { emails?: string[]; user_ids?: number[]; initial_balance?: string } +): Promise<{ added: number; already_existing: number }> { + const res = await fetch(`${API_BASE}/api/admin/cohorts/${cohortName}/members`, { + method: 'POST', + headers: await authHeaders(), + body: JSON.stringify(opts) + }); + return handleResponse(res); +} + +export async function removeMember(cohortName: string, memberId: number): Promise { + const res = await fetch(`${API_BASE}/api/admin/cohorts/${cohortName}/members/${memberId}`, { + method: 'DELETE', + headers: await authHeaders() + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || res.statusText); + } +} + +export async function fetchConfig(): Promise { + const res = await fetch(`${API_BASE}/api/admin/config`, { headers: await authHeaders() }); + return handleResponse(res); +} + +export async function checkAdminAccess(): Promise { + try { + const res = await fetch(`${API_BASE}/api/admin/config`, { headers: await authHeaders() }); + return res.ok; + } catch { + return false; + } +} + +export async function updateConfig(config: { + active_auction_cohort_id?: number | null; + default_cohort_id?: number | null; + public_auction_enabled?: boolean; +}): Promise { + const res = await fetch(`${API_BASE}/api/admin/config`, { + method: 'PUT', + headers: await authHeaders(), + body: JSON.stringify(config) + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || res.statusText); + } +} + +export async function fetchAvailableDbs(): Promise { + const res = await fetch(`${API_BASE}/api/admin/available-dbs`, { headers: await authHeaders() }); + return handleResponse(res); +} + +export async function toggleAdmin(userId: number, isAdmin: boolean): Promise { + const res = await fetch(`${API_BASE}/api/admin/users/${userId}/admin`, { + method: 'PUT', + headers: await authHeaders(), + body: JSON.stringify({ is_admin: isAdmin }) + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || res.statusText); + } +} + +export async function updateDisplayName(userId: number, displayName: string): Promise { + const res = await fetch(`${API_BASE}/api/admin/users/${userId}/display-name`, { + method: 'PUT', + headers: await authHeaders(), + body: JSON.stringify({ display_name: displayName }) + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || res.statusText); + } +} + +export async function deleteUser(userId: number): Promise { + const res = await fetch(`${API_BASE}/api/admin/users/${userId}`, { + method: 'DELETE', + headers: await authHeaders() + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || res.statusText); + } +} + +export async function updateMyDisplayName(displayName: string): Promise { + const res = await fetch(`${API_BASE}/api/users/me/display-name`, { + method: 'PUT', + headers: await authHeaders(), + body: JSON.stringify({ display_name: displayName }) + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || res.statusText); + } +} + +export async function updateMemberInitialBalance( + cohortName: string, + memberId: number, + initialBalance: string | null +): Promise { + const res = await fetch(`${API_BASE}/api/admin/cohorts/${cohortName}/members/${memberId}`, { + method: 'PUT', + headers: await authHeaders(), + body: JSON.stringify({ initial_balance: initialBalance }) + }); + if (!res.ok) { + const text = await res.text(); + throw new Error(text || res.statusText); + } +} diff --git a/frontend/src/lib/api.svelte.ts b/frontend/src/lib/api.svelte.ts index 36e0df1f..0dd8e965 100644 --- a/frontend/src/lib/api.svelte.ts +++ b/frontend/src/lib/api.svelte.ts @@ -6,6 +6,7 @@ import { websocket_api } from 'schema-js'; import { toast } from 'svelte-sonner'; import { SvelteMap } from 'svelte/reactivity'; import { kinde } from './auth.svelte'; +import { API_BASE } from './apiBase'; import { notifyUser } from './notifications'; // const originalConsoleLog = console.log; @@ -32,10 +33,8 @@ import { notifyUser } from './notifications'; // // Using new Error().stack is generally more reliable for the original call site. // }; -const socket = new ReconnectingWebSocket(PUBLIC_SERVER_URL); -socket.binaryType = 'arraybuffer'; - -console.log('Connecting to', PUBLIC_SERVER_URL); +let socket: ReconnectingWebSocket | null = null; +let currentCohort: string | null = null; export class MarketData { definition: websocket_api.IMarket = $state({}); @@ -65,7 +64,8 @@ export const serverState = $state({ universes: new SvelteMap(), tradedMarketIds: new SvelteMap>(), lastKnownTransactionId: 0, - arborPixieAccountId: undefined as number | undefined + arborPixieAccountId: undefined as number | undefined, + auctionOnly: false }); export const hasArborPixieTransfer = () => { @@ -101,6 +101,10 @@ let messageQueue: websocket_api.IClientMessage[] = []; let hasAuthenticated = false; export const sendClientMessage = (msg: websocket_api.IClientMessage) => { + if (!socket) { + messageQueue.push(msg); + return; + } if (hasAuthenticated || 'authenticate' in msg) { const msgType = Object.keys(msg).find((key) => msg[key as keyof typeof msg]); console.log(`sending ${msgType} message`, msg[msgType as keyof typeof msg]); @@ -133,6 +137,17 @@ export const accountName = ( return accountId === serverState.userId && me ? me : formattedName; }; +const checkAdminAccess = async (accessToken: string): Promise => { + try { + const res = await fetch(`${API_BASE}/api/admin/config`, { + headers: { Authorization: `Bearer ${accessToken}` } + }); + return res.ok; + } catch { + return false; + } +}; + /** * Returns a Map of accountId -> display name, using short names when unique * and falling back to raw (full) names or raw + ID when there are duplicates. @@ -184,8 +199,7 @@ const authenticate = async () => { startConnectionToast(); const accessToken = await kinde.getToken(); const idToken = await kinde.getIdToken(); - const isAdmin = await kinde.isAdmin(); - serverState.isAdmin = isAdmin; + const isRoleAdmin = await kinde.isAdmin(); if (!accessToken) { console.log('no access token'); @@ -195,7 +209,10 @@ const authenticate = async () => { console.log('no id token'); return; } - const actAs = Number(localStorage.getItem('actAs')); + const hasAdminApiAccess = await checkAdminAccess(accessToken); + serverState.isAdmin = isRoleAdmin || hasAdminApiAccess; + const actAsKey = currentCohort ? `${currentCohort}:actAs` : 'actAs'; + const actAs = Number(localStorage.getItem(actAsKey)); const authenticate = { jwt: accessToken, idJwt: idToken, @@ -205,13 +222,58 @@ const authenticate = async () => { sendClientMessage({ authenticate }); }; -socket.onopen = authenticate; - -socket.onclose = () => { +const resetServerState = () => { serverState.stale = true; + serverState.userId = undefined; + serverState.actingAs = undefined; + serverState.currentUniverseId = 0; + serverState.sudoEnabled = false; + serverState.portfolio = undefined; + serverState.portfolios.clear(); + serverState.transfers = []; + serverState.accounts.clear(); + serverState.markets.clear(); + serverState.marketTypes.clear(); + serverState.marketGroups.clear(); + serverState.auctions.clear(); + serverState.universes.clear(); + serverState.lastKnownTransactionId = 0; + serverState.arborPixieAccountId = undefined; + serverState.auctionOnly = false; + hasAuthenticated = false; + messageQueue = []; +}; + +export const connectToCohort = (cohortName: string) => { + if (currentCohort === cohortName && socket) return; + if (socket) { + socket.close(); + resetServerState(); + } + currentCohort = cohortName; + const wsUrl = `${PUBLIC_SERVER_URL}/ws/${cohortName}`; + console.log('Connecting to', wsUrl); + socket = new ReconnectingWebSocket(wsUrl); + socket.binaryType = 'arraybuffer'; + socket.onopen = authenticate; + socket.onclose = () => { + serverState.stale = true; + }; + socket.onmessage = handleMessage; }; -socket.onmessage = (event: MessageEvent) => { +export const disconnectFromCohort = () => { + if (socket) { + socket.close(); + socket = null; + } + currentCohort = null; + resetServerState(); +}; + +export const getCurrentCohort = () => currentCohort; + +const handleMessage = (event: MessageEvent) => { const data = event.data; const msg = websocket_api.ServerMessage.decode(new Uint8Array(data)); @@ -219,6 +281,7 @@ socket.onmessage = (event: MessageEvent) => { if (msg.authenticated) { serverState.userId = msg.authenticated.accountId; + serverState.auctionOnly = msg.authenticated.auctionOnly ?? false; serverState.sudoEnabled = false; } @@ -229,7 +292,8 @@ socket.onmessage = (event: MessageEvent) => { resolveConnectionToast = undefined; } if (msg.actingAs.accountId) { - localStorage.setItem('actAs', msg.actingAs.accountId.toString()); + const actAsKey = currentCohort ? `${currentCohort}:actAs` : 'actAs'; + localStorage.setItem(actAsKey, msg.actingAs.accountId.toString()); } serverState.actingAs = msg.actingAs.accountId; serverState.effectiveUserId = msg.actingAs.userId || serverState.userId; @@ -238,8 +302,8 @@ socket.onmessage = (event: MessageEvent) => { if (newUniverseId !== serverState.currentUniverseId) { serverState.markets.clear(); // Redirect to /market if on a specific market page (the market may not exist in the new universe) - if (browser && window.location.pathname.match(/^\/market\/\d+/)) { - goto('/market'); + if (browser && window.location.pathname.match(/\/market\/\d+/)) { + goto(currentCohort ? `/${currentCohort}/market` : '/market'); } } serverState.currentUniverseId = newUniverseId; @@ -535,7 +599,8 @@ socket.onmessage = (event: MessageEvent) => { } if (msg.requestFailed && msg.requestFailed.requestDetails?.kind === 'Authenticate') { - localStorage.removeItem('actAs'); + const actAsKey = currentCohort ? `${currentCohort}:actAs` : 'actAs'; + localStorage.removeItem(actAsKey); console.log('Authentication failed'); authenticate(); } @@ -588,5 +653,5 @@ if (browser) { /** Force WebSocket to reconnect and re-authenticate (useful after login state changes) */ export const reconnect = () => { - socket.reconnect(); + socket?.reconnect(); }; diff --git a/frontend/src/lib/apiBase.ts b/frontend/src/lib/apiBase.ts new file mode 100644 index 00000000..bfe192ef --- /dev/null +++ b/frontend/src/lib/apiBase.ts @@ -0,0 +1,17 @@ +import { PUBLIC_SERVER_URL } from '$env/static/public'; + +// Derive HTTP base URL from the WebSocket PUBLIC_SERVER_URL +// e.g. "wss://host.fly.dev/api" → "https://host.fly.dev" +// e.g. "ws://localhost:8080" → "http://localhost:8080" +// e.g. "/api" (dev mode) → "" (relative, Vite proxy handles it) +function deriveApiBase(serverUrl: string): string { + try { + const wsUrl = new URL(serverUrl); + const protocol = wsUrl.protocol === 'wss:' ? 'https:' : 'http:'; + return `${protocol}//${wsUrl.host}`; + } catch { + // Relative URL (e.g. "/api" in dev mode) — use empty base for relative fetch + return ''; + } +} +export const API_BASE = deriveApiBase(PUBLIC_SERVER_URL); diff --git a/frontend/src/lib/cohortApi.ts b/frontend/src/lib/cohortApi.ts new file mode 100644 index 00000000..04227fc5 --- /dev/null +++ b/frontend/src/lib/cohortApi.ts @@ -0,0 +1,31 @@ +import { kinde } from './auth.svelte'; +import { API_BASE } from './apiBase'; + +export interface CohortInfo { + id: number; + name: string; + display_name: string; + is_read_only: boolean; +} + +export interface CohortsResponse { + cohorts: CohortInfo[]; +} + +export async function fetchCohorts(): Promise { + const token = await kinde.getToken(); + const idToken = await kinde.getIdToken(); + const headers: HeadersInit = { + Authorization: `Bearer ${token}` + }; + if (idToken) { + headers['X-ID-Token'] = idToken; + } + const res = await fetch(`${API_BASE}/api/cohorts`, { + headers + }); + if (!res.ok) { + throw new Error(`Failed to fetch cohorts: ${res.statusText}`); + } + return res.json(); +} diff --git a/frontend/src/lib/components/appSideBar.svelte b/frontend/src/lib/components/appSideBar.svelte index 847ecc3b..c9c7a9d6 100644 --- a/frontend/src/lib/components/appSideBar.svelte +++ b/frontend/src/lib/components/appSideBar.svelte @@ -1,6 +1,7 @@ + +{#if loading} +
+
+
+{:else if error} +
+
+

{error}

+ +
+
+{:else if cohorts.length === 0} +
+
+

Your account isn't authorized for any cohort yet.

+

Contact an administrator to get access.

+ +
+
+{:else} +
+
+
+

Select a Cohort

+ +
+
+ {#each cohorts as cohort} + + {/each} +
+
+
+{/if} diff --git a/frontend/src/routes/[cohort_name]/+layout.svelte b/frontend/src/routes/[cohort_name]/+layout.svelte new file mode 100644 index 00000000..c4b2d2b9 --- /dev/null +++ b/frontend/src/routes/[cohort_name]/+layout.svelte @@ -0,0 +1,257 @@ + + + + +
+
+ +
+
+
+
+
+
+ {@render children()} +
+
+
+
diff --git a/frontend/src/routes/[cohort_name]/+page.ts b/frontend/src/routes/[cohort_name]/+page.ts new file mode 100644 index 00000000..83eb0328 --- /dev/null +++ b/frontend/src/routes/[cohort_name]/+page.ts @@ -0,0 +1,5 @@ +import { redirect } from '@sveltejs/kit'; + +export function load({ params }) { + redirect(307, `/${params.cohort_name}/market`); +} diff --git a/frontend/src/routes/accounts/+page.svelte b/frontend/src/routes/[cohort_name]/accounts/+page.svelte similarity index 81% rename from frontend/src/routes/accounts/+page.svelte rename to frontend/src/routes/[cohort_name]/accounts/+page.svelte index 42a03638..475b58ee 100644 --- a/frontend/src/routes/accounts/+page.svelte +++ b/frontend/src/routes/[cohort_name]/accounts/+page.svelte @@ -11,7 +11,8 @@ import ShareOwnership from '$lib/components/forms/shareOwnership.svelte'; import { Button } from '$lib/components/ui/button'; import * as Select from '$lib/components/ui/select'; - import { Copy } from '@lucide/svelte/icons'; + import { updateMyDisplayName } from '$lib/adminApi'; + import { Check, Copy, Pencil, X } from '@lucide/svelte/icons'; import { toast } from 'svelte-sonner'; import { universeMode } from '$lib/universeMode.svelte'; @@ -72,6 +73,28 @@ // Scroll to create account form document.getElementById('create-account-section')?.scrollIntoView({ behavior: 'smooth' }); } + + let editingDisplayName = $state(false); + let newDisplayName = $state(''); + + function startEditingName() { + editingDisplayName = true; + newDisplayName = accountName(serverState.userId) ?? ''; + } + + async function saveDisplayName() { + if (!newDisplayName.trim()) return; + try { + await updateMyDisplayName(newDisplayName.trim()); + // Update the account name in local state + const account = serverState.accounts.get(serverState.userId ?? 0); + if (account) account.name = newDisplayName.trim(); + toast.success('Display name updated'); + editingDisplayName = false; + } catch (e) { + toast.error('Failed to update: ' + (e instanceof Error ? e.message : String(e))); + } + } @@ -85,6 +108,38 @@

Accounts

+
+ Display name: + {#if editingDisplayName} + { + if (e.key === 'Enter') saveDisplayName(); + if (e.key === 'Escape') editingDisplayName = false; + }} + /> + + + {:else} + {accountName(serverState.userId)} + + {/if} +
{#if serverState.actingAs && serverState.accounts.get(serverState.actingAs)}

Currently acting as {accountName(serverState.actingAs)} diff --git a/frontend/src/routes/auction/+page.svelte b/frontend/src/routes/[cohort_name]/auction/+page.svelte similarity index 94% rename from frontend/src/routes/auction/+page.svelte rename to frontend/src/routes/[cohort_name]/auction/+page.svelte index a1247cb4..dd31b045 100644 --- a/frontend/src/routes/auction/+page.svelte +++ b/frontend/src/routes/[cohort_name]/auction/+page.svelte @@ -29,7 +29,9 @@

Auction

- + {#if serverState.isAdmin && serverState.sudoEnabled} + + {/if}
{#each Array.from(serverState.auctions.values()).sort((a, b) => (a.transactionTimestamp?.seconds ?? 0) - (b.transactionTimestamp?.seconds ?? 0)) as auction} diff --git a/frontend/src/routes/docs/+page.svelte b/frontend/src/routes/[cohort_name]/docs/+page.svelte similarity index 93% rename from frontend/src/routes/docs/+page.svelte rename to frontend/src/routes/[cohort_name]/docs/+page.svelte index 056a2e93..721feb1e 100644 --- a/frontend/src/routes/docs/+page.svelte +++ b/frontend/src/routes/[cohort_name]/docs/+page.svelte @@ -1,4 +1,5 @@ + + +{#if confirmModal.open} + +
(confirmModal.open = false)} + onkeydown={(e) => e.key === 'Escape' && (confirmModal.open = false)} + role="dialog" + tabindex="-1" + > + +
e.stopPropagation()} + role="document" + > +

{confirmModal.title}

+

{confirmModal.message}

+
+ + +
+
+
+{/if} + +{#if loading} +
+
+
+{:else if !isAdmin} +
+

Admin access required.

+
+{:else} +
+
+

Admin

+ {#if lastCohortName} + + Back to {lastCohortDisplay} + + {:else} + + Back to cohorts + + {/if} +
+ + +
+

General Config

+
+
+
+ + +
+

+ This only affects the default cohort that the python client and scenarios server uses +

+
+
+
+ + +
+

Auction Config

+
+
+ + +
+ +
+
+ + +
+

Cohorts

+
+

Create New Cohort

+
+ + +
+ + +
+ {#if availableDbs.length > 0} +
+

Available Databases

+
+ {#each availableDbs as db} + + {/each} +
+
+ {/if} + +
+ + +
+
+

All Users

+ +
+ {#if allUsers.length > 0} + +
+ {#each filteredUsers as user (user.id)} +
+
+
+ {#if editingUserId === user.id} + { + if (e.key === 'Enter') saveEditingName(); + if (e.key === 'Escape') editingUserId = null; + }} + /> + + + {:else} + {user.display_name} + + {#if user.email} + {user.email} + {/if} + {/if} +
+
+ + +
+
+ {#if user.cohorts.length > 0} +
+ {#each user.cohorts as uc} + + {uc.cohort_display_name} + {#if uc.balance != null} + ({formatBalance(uc.balance)}) + {/if} + + {/each} +
+ {:else} +

No cohorts

+ {/if} +
+ {/each} +
+

+ {filteredUsers.length} of {allUsers.length} users +

+ {:else if !loadingAllUsers} +

Click "Load Users" to view all users.

+ {/if} +
+
+{/if} diff --git a/frontend/src/routes/admin/cohorts/[name]/+page.svelte b/frontend/src/routes/admin/cohorts/[name]/+page.svelte new file mode 100644 index 00000000..75791f2d --- /dev/null +++ b/frontend/src/routes/admin/cohorts/[name]/+page.svelte @@ -0,0 +1,411 @@ + + +{#if loading} +
+
+
+{:else if !cohort} +
+

Cohort not found.

+
+{:else} +
+ + +
+
+

{cohort.display_name}

+

{cohort.name}

+
+
+ {#if cohort.is_read_only} + + Read-only + + {/if} + +
+
+ + +
+

Add Existing User

+
+ + + {selectedUserId ? selectedUserName : 'Search users...'} + + + + + + + No users found + + {#each availableUsers as user (user.id)} + { + selectedUserId = user.id; + closePopoverAndFocusTrigger(); + }} + > + {user.display_name} + + + {/each} + + + + + + + +
+
+ + +
+

Add by Email

+ +
+ + +
+
+ + +
+

+ Members ({members.length}) +

+ {#if loadingMembers} +

Loading members...

+ {:else if members.length === 0} +

No members yet.

+ {:else} +
+ {#each members as member} +
+
+
+ {#if member.display_name} + {member.display_name} + {/if} + {#if member.email} + {member.email} + {:else if !member.display_name} + User #{member.global_user_id} + {/if} +
+ {#if member.balance != null} + + {formatBalance(member.balance)} + + {/if} + {#if member.balance == null} + + {#if editingMemberId === member.id} +
+ { + if (e.key === 'Enter') saveEditingBalance(); + if (e.key === 'Escape') editingMemberId = null; + }} + /> + + +
+ {:else} + + {/if} + {:else if member.initial_balance} + + initial: {member.initial_balance} + + {/if} +
+ +
+ {/each} +
+ {/if} +
+
+{/if} diff --git a/frontend/src/routes/auction/+page.ts b/frontend/src/routes/auction/+page.ts deleted file mode 100644 index 69833202..00000000 --- a/frontend/src/routes/auction/+page.ts +++ /dev/null @@ -1,3 +0,0 @@ -export function load() { - // redirect(307, '/market'); -} diff --git a/python-client/README.md b/python-client/README.md index 140d8515..7b900708 100644 --- a/python-client/README.md +++ b/python-client/README.md @@ -40,7 +40,9 @@ This following may not work on Windows without WSL. 1. Install the dependencies with `uv sync` 2. Copy `example.env` to `.env` 3. Go to the "Accounts" page on the exchange and copy your JWT into `.env` -4. Make sure you are acting as the account you are going to be trading from, the copy the ACT_AS into `.env` +4. Make sure you are acting as the account you are going to be trading from, then copy the ACT_AS into `.env` +5. Set `API_URL` to the base server URL (e.g. `https://trading-bootcamp.fly.dev`) +6. Optionally set `COHORT` to the cohort name you want to connect to. If omitted, the server's default cohort is used. You can test if it is working by running ``` @@ -50,4 +52,12 @@ This command places orders at the min and max settlement prices, so you shouldn' You can look at other example bots in examples/, like the code for an (older?) version of mark (`market_maker_bot.py`) and bob (`naive.py`). +You can also use `list_cohorts()` to discover available cohorts programmatically: +```python +from metagame import list_cohorts +info = list_cohorts("https://trading-bootcamp.fly.dev", jwt) +print(info["cohorts"]) # Available cohorts +print(info["default_cohort"]) # Server default cohort name +``` + You can figure out the Jupyter Notebook if you wish. diff --git a/python-client/example.env b/python-client/example.env index 71a8d63b..422aa79a 100644 --- a/python-client/example.env +++ b/python-client/example.env @@ -1,3 +1,4 @@ JWT= ACT_AS= -API_URL=wss://trading-bootcamp.fly.dev/api +API_URL=https://trading-bootcamp.fly.dev +COHORT= diff --git a/python-client/examples/market_maker_bot.py b/python-client/examples/market_maker_bot.py index f00b527b..0da5a372 100644 --- a/python-client/examples/market_maker_bot.py +++ b/python-client/examples/market_maker_bot.py @@ -19,13 +19,14 @@ def main( jwt: Annotated[str, typer.Option(envvar="JWT")], api_url: Annotated[str, typer.Option(envvar="API_URL")], act_as: Annotated[int, typer.Option(envvar="ACT_AS")], - market_name: str, + cohort: Annotated[str, typer.Option(envvar="COHORT")] = "", + market_name: str = typer.Argument(), spread: float = 1.0, size: float = 1.0, fade_per_order: float = 1.0, prior: Optional[float] = None, ): - with TradingClient(api_url, jwt, act_as) as client: + with TradingClient(api_url, jwt, act_as, cohort=cohort or None) as client: market_maker_bot( client, market_name=market_name, diff --git a/python-client/examples/min_max_bot.py b/python-client/examples/min_max_bot.py index b4ee35b9..90d269b6 100644 --- a/python-client/examples/min_max_bot.py +++ b/python-client/examples/min_max_bot.py @@ -18,10 +18,11 @@ def main( jwt: Annotated[str, typer.Option(envvar="JWT")], api_url: Annotated[str, typer.Option(envvar="API_URL")], act_as: Annotated[int, typer.Option(envvar="ACT_AS")], - market_name: str, + cohort: Annotated[str, typer.Option(envvar="COHORT")] = "", + market_name: str = typer.Argument(), size: float = 100.0, ): - with TradingClient(api_url, jwt, act_as) as client: + with TradingClient(api_url, jwt, act_as, cohort=cohort or None) as client: min_max_bot( client, market_name=market_name, diff --git a/python-client/examples/naive_bot.py b/python-client/examples/naive_bot.py index 5babc1a4..8c9bab72 100644 --- a/python-client/examples/naive_bot.py +++ b/python-client/examples/naive_bot.py @@ -20,12 +20,13 @@ def main( jwt: Annotated[str, typer.Option(envvar="JWT")], api_url: Annotated[str, typer.Option(envvar="API_URL")], act_as: Annotated[int, typer.Option(envvar="ACT_AS")], - market_name: str, + cohort: Annotated[str, typer.Option(envvar="COHORT")] = "", + market_name: str = typer.Argument(), loss_per_trade: float = 1.0, max_size: float = 1.0, seconds_per_trade: float = 1.0, ): - with TradingClient(api_url, jwt, act_as) as client: + with TradingClient(api_url, jwt, act_as, cohort=cohort or None) as client: naive_bot( client, market_name=market_name, diff --git a/python-client/examples/twap_bot.py b/python-client/examples/twap_bot.py index 05271e80..831bef09 100644 --- a/python-client/examples/twap_bot.py +++ b/python-client/examples/twap_bot.py @@ -27,12 +27,13 @@ def main( jwt: Annotated[str, typer.Option(envvar="JWT")], api_url: Annotated[str, typer.Option(envvar="API_URL")], act_as: Annotated[int, typer.Option(envvar="ACT_AS")], - market_name: str, - desired_position: float, + cohort: Annotated[str, typer.Option(envvar="COHORT")] = "", + market_name: str = typer.Argument(), + desired_position: float = typer.Argument(), seconds_per_trade: float = 5.0, end_time: float = 300.0, ): - with TradingClient(api_url, jwt, act_as) as client: + with TradingClient(api_url, jwt, act_as, cohort=cohort or None) as client: market_id = client.state().market_name_to_id[market_name] state = TWAPState(next_trade_time=None, end_time=end_time, desired_position=desired_position) start_time = time() diff --git a/python-client/pyproject.toml b/python-client/pyproject.toml index 10951fea..48829082 100644 --- a/python-client/pyproject.toml +++ b/python-client/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "metagame" -version = "0.4.16" +version = "0.5.0" description = "MetaGame Trading Bootcamp Python Client" readme = "README.md" authors = [ diff --git a/python-client/src/metagame/__init__.py b/python-client/src/metagame/__init__.py index 5bbf7112..3a2a62bc 100644 --- a/python-client/src/metagame/__init__.py +++ b/python-client/src/metagame/__init__.py @@ -1,4 +1,4 @@ from . import websocket_api -from .trading_client import MarketData, RequestFailed, State, TradingClient +from .trading_client import MarketData, RequestFailed, State, TradingClient, list_cohorts -__all__ = ["MarketData", "RequestFailed", "State", "TradingClient", "websocket_api"] +__all__ = ["MarketData", "RequestFailed", "State", "TradingClient", "list_cohorts", "websocket_api"] diff --git a/python-client/src/metagame/trading_client.py b/python-client/src/metagame/trading_client.py index e63a7074..10c528ef 100644 --- a/python-client/src/metagame/trading_client.py +++ b/python-client/src/metagame/trading_client.py @@ -1,5 +1,8 @@ import bisect +import json import logging +import re +import urllib.request import uuid from dataclasses import dataclass, field from typing import Dict, List, Optional @@ -13,6 +16,25 @@ logger = logging.getLogger(__name__) +def list_cohorts(api_url: str, jwt: str) -> dict: + """Fetch cohorts and config from the server. + + Args: + api_url: Base HTTP(S) URL of the server (e.g. "https://trading-bootcamp.fly.dev") + jwt: JWT token for authentication + + Returns: + dict with 'cohorts', 'default_cohort', 'active_auction_cohort', 'public_auction_enabled' + """ + base = re.sub(r"/+$", "", api_url) + # Normalize ws/wss URLs to http/https + base = re.sub(r"^ws(s?)://", r"http\1://", base) + url = f"{base}/api/cohorts" + req = urllib.request.Request(url, headers={"Authorization": f"Bearer {jwt}"}) + with urllib.request.urlopen(req) as resp: + return json.loads(resp.read()) + + class TradingClient: """ Client for interacting with the exchange server. @@ -21,17 +43,34 @@ class TradingClient: _ws: ClientConnection _state: "State" - def __init__(self, api_url: str, jwt: str, act_as: int, close_timeout: float = 1.0): + def __init__(self, api_url: str, jwt: str, act_as: int, cohort: Optional[str] = None, close_timeout: float = 1.0): """ Connect, Authenticate, then make sure all of the messages holding initial state have been received. Args: - api_url: WebSocket URL of the exchange server + api_url: Base URL of the exchange server (e.g. "https://trading-bootcamp.fly.dev") jwt: JWT token for authentication act_as: Account ID to act as (0 for own account) + cohort: Cohort name to connect to. If None, uses the server's default cohort. close_timeout: Timeout in seconds for WebSocket close handshake (default: 1.0) """ - self._ws = connect(api_url, max_size=2**27, close_timeout=close_timeout) + if cohort is None: + info = list_cohorts(api_url, jwt) + cohort = info.get("default_cohort") + if not cohort: + raise ValueError( + "No cohort specified and no default cohort configured on the server. " + "Pass cohort='' or ask an admin to set a default cohort." + ) + logger.info(f"Using default cohort: {cohort}") + + # Build WebSocket URL: convert http(s) to ws(s) and append path + base = re.sub(r"/+$", "", api_url) + base = re.sub(r"^http(s?)://", r"ws\1://", base) + base = re.sub(r"^ws(s?)://", r"ws\1://", base) + ws_url = f"{base}/api/ws/{cohort}" + + self._ws = connect(ws_url, max_size=2**27, close_timeout=close_timeout) self._state = State() self._outstanding_requests = set() authenticate = websocket_api.Authenticate(jwt=jwt, act_as=act_as) @@ -759,6 +798,7 @@ class State: _initializing: bool = True user_id: int = 0 acting_as: int = 0 + auction_only: bool = False current_universe_id: int = 0 sudo_enabled: bool = False portfolio: websocket_api.Portfolio = field(default_factory=websocket_api.Portfolio) @@ -777,6 +817,7 @@ def _update(self, server_message: websocket_api.ServerMessage): if isinstance(message, websocket_api.Authenticated): self.user_id = message.account_id + self.auction_only = message.auction_only elif isinstance(message, websocket_api.ActingAs): # ActingAs is always the last message in the initialization sequence diff --git a/python-client/src/metagame/websocket_api.py b/python-client/src/metagame/websocket_api.py index 2db9fcc1..4cc9697c 100644 --- a/python-client/src/metagame/websocket_api.py +++ b/python-client/src/metagame/websocket_api.py @@ -1,5 +1,5 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! -# sources: portfolio.proto, redeemable.proto, market.proto, market-type.proto, market-group.proto, market-settled.proto, orders-cancelled.proto, side.proto, order.proto, trade.proto, order-created.proto, transfer.proto, request-failed.proto, out.proto, account.proto, redeem.proto, orders.proto, trades.proto, auction.proto, auction-settled.proto, market-positions.proto, universe.proto, server-message.proto, make-transfer.proto, create-market.proto, create-universe.proto, create-market-type.proto, delete-market-type.proto, create-market-group.proto, create-auction.proto, settle-market.proto, edit-market.proto, settle-auction.proto, edit-auction.proto, create-order.proto, buy-auction.proto, client-message.proto +# sources: portfolio.proto, redeemable.proto, market.proto, market-type.proto, market-group.proto, market-settled.proto, orders-cancelled.proto, side.proto, order.proto, trade.proto, order-created.proto, transfer.proto, request-failed.proto, out.proto, account.proto, redeem.proto, orders.proto, trades.proto, auction.proto, auction-settled.proto, universe.proto, server-message.proto, make-transfer.proto, create-market.proto, create-universe.proto, create-market-type.proto, delete-market-type.proto, create-market-group.proto, create-auction.proto, settle-market.proto, edit-market.proto, settle-auction.proto, edit-auction.proto, create-order.proto, buy-auction.proto, client-message.proto # plugin: python-betterproto from dataclasses import dataclass from datetime import datetime @@ -277,26 +277,6 @@ class AuctionSettled(betterproto.Message): buyer_id: int = betterproto.int64_field(5) -@dataclass -class GetMarketPositions(betterproto.Message): - market_id: int = betterproto.int64_field(1) - - -@dataclass -class MarketPositions(betterproto.Message): - market_id: int = betterproto.int64_field(1) - positions: List["ParticipantPosition"] = betterproto.message_field(2) - - -@dataclass -class ParticipantPosition(betterproto.Message): - account_id: int = betterproto.int64_field(1) - gross: float = betterproto.double_field(2) - net: float = betterproto.double_field(3) - avg_buy_price: float = betterproto.double_field(4, group="_avg_buy_price") - avg_sell_price: float = betterproto.double_field(5, group="_avg_sell_price") - - @dataclass class Universe(betterproto.Message): id: int = betterproto.int64_field(1) @@ -339,7 +319,6 @@ class ServerMessage(betterproto.Message): ) market_group: "MarketGroup" = betterproto.message_field(29, group="message") market_groups: "MarketGroups" = betterproto.message_field(30, group="message") - market_positions: "MarketPositions" = betterproto.message_field(31, group="message") sudo_status: "SudoStatus" = betterproto.message_field(32, group="message") universe: "Universe" = betterproto.message_field(33, group="message") universes: "Universes" = betterproto.message_field(34, group="message") @@ -353,6 +332,7 @@ class MarketTypeDeleted(betterproto.Message): @dataclass class Authenticated(betterproto.Message): account_id: int = betterproto.int64_field(1) + auction_only: bool = betterproto.bool_field(2) @dataclass @@ -559,9 +539,6 @@ class ClientMessage(betterproto.Message): create_market_group: "CreateMarketGroup" = betterproto.message_field( 23, group="message" ) - get_market_positions: "GetMarketPositions" = betterproto.message_field( - 24, group="message" - ) set_sudo: "SetSudo" = betterproto.message_field(26, group="message") create_universe: "CreateUniverse" = betterproto.message_field(27, group="message") diff --git a/schema-js/index.d.ts b/schema-js/index.d.ts index 104c0579..2a8b6151 100644 --- a/schema-js/index.d.ts +++ b/schema-js/index.d.ts @@ -385,6 +385,9 @@ export namespace websocket_api { /** Authenticated accountId */ accountId?: (number|Long|null); + + /** Authenticated auctionOnly */ + auctionOnly?: (boolean|null); } /** Represents an Authenticated. */ @@ -399,6 +402,9 @@ export namespace websocket_api { /** Authenticated accountId. */ public accountId: (number|Long); + /** Authenticated auctionOnly. */ + public auctionOnly: boolean; + /** * Creates a new Authenticated instance using the specified properties. * @param [properties] Properties to set diff --git a/schema-js/index.js b/schema-js/index.js index 10e7cdd5..0b925795 100644 --- a/schema-js/index.js +++ b/schema-js/index.js @@ -1508,6 +1508,7 @@ $root.websocket_api = (function() { * @memberof websocket_api * @interface IAuthenticated * @property {number|Long|null} [accountId] Authenticated accountId + * @property {boolean|null} [auctionOnly] Authenticated auctionOnly */ /** @@ -1533,6 +1534,14 @@ $root.websocket_api = (function() { */ Authenticated.prototype.accountId = $util.Long ? $util.Long.fromBits(0,0,false) : 0; + /** + * Authenticated auctionOnly. + * @member {boolean} auctionOnly + * @memberof websocket_api.Authenticated + * @instance + */ + Authenticated.prototype.auctionOnly = false; + /** * Creates a new Authenticated instance using the specified properties. * @function create @@ -1559,6 +1568,8 @@ $root.websocket_api = (function() { writer = $Writer.create(); if (message.accountId != null && Object.hasOwnProperty.call(message, "accountId")) writer.uint32(/* id 1, wireType 0 =*/8).int64(message.accountId); + if (message.auctionOnly != null && Object.hasOwnProperty.call(message, "auctionOnly")) + writer.uint32(/* id 2, wireType 0 =*/16).bool(message.auctionOnly); return writer; }; @@ -1597,6 +1608,10 @@ $root.websocket_api = (function() { message.accountId = reader.int64(); break; } + case 2: { + message.auctionOnly = reader.bool(); + break; + } default: reader.skipType(tag & 7); break; @@ -1635,6 +1650,9 @@ $root.websocket_api = (function() { if (message.accountId != null && message.hasOwnProperty("accountId")) if (!$util.isInteger(message.accountId) && !(message.accountId && $util.isInteger(message.accountId.low) && $util.isInteger(message.accountId.high))) return "accountId: integer|Long expected"; + if (message.auctionOnly != null && message.hasOwnProperty("auctionOnly")) + if (typeof message.auctionOnly !== "boolean") + return "auctionOnly: boolean expected"; return null; }; @@ -1659,6 +1677,8 @@ $root.websocket_api = (function() { message.accountId = object.accountId; else if (typeof object.accountId === "object") message.accountId = new $util.LongBits(object.accountId.low >>> 0, object.accountId.high >>> 0).toNumber(); + if (object.auctionOnly != null) + message.auctionOnly = Boolean(object.auctionOnly); return message; }; @@ -1675,17 +1695,21 @@ $root.websocket_api = (function() { if (!options) options = {}; var object = {}; - if (options.defaults) + if (options.defaults) { if ($util.Long) { var long = new $util.Long(0, 0, false); object.accountId = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; } else object.accountId = options.longs === String ? "0" : 0; + object.auctionOnly = false; + } if (message.accountId != null && message.hasOwnProperty("accountId")) if (typeof message.accountId === "number") object.accountId = options.longs === String ? String(message.accountId) : message.accountId; else object.accountId = options.longs === String ? $util.Long.prototype.toString.call(message.accountId) : options.longs === Number ? new $util.LongBits(message.accountId.low >>> 0, message.accountId.high >>> 0).toNumber() : message.accountId; + if (message.auctionOnly != null && message.hasOwnProperty("auctionOnly")) + object.auctionOnly = message.auctionOnly; return object; }; diff --git a/schema/server-message.proto b/schema/server-message.proto index c915eecd..314c812d 100644 --- a/schema/server-message.proto +++ b/schema/server-message.proto @@ -60,6 +60,7 @@ message MarketTypeDeleted { } message Authenticated { int64 account_id = 1; + bool auction_only = 2; } message ActingAs { int64 account_id = 1; diff --git a/vercel.json b/vercel.json new file mode 100644 index 00000000..0d14cfa0 --- /dev/null +++ b/vercel.json @@ -0,0 +1,14 @@ +{ + "rewrites": [{ "source": "/(.*)", "destination": "/" }], + "headers": [ + { + "source": "/(.*)", + "headers": [ + { + "key": "Access-Control-Allow-Origin", + "value": "*" + } + ] + } + ] +}