|
2 | 2 | from fastapi.responses import StreamingResponse, JSONResponse |
3 | 3 | from sqlmodel import Session, select, text, or_ |
4 | 4 | from backend.db import engine, create_db_and_tables |
5 | | -from backend.models import Exercise, Workout, SetEntry, Profile, BodyweightEntry |
| 5 | +from backend.models import Exercise, Workout, SetEntry, Profile, BodyweightEntry, PushSubscription |
6 | 6 | from backend import schemas |
7 | 7 | from backend.seed_exercises import seed as seed_exercises |
8 | 8 | from prometheus_client import generate_latest, CONTENT_TYPE_LATEST, Counter |
9 | 9 | from typing import List, Optional |
10 | 10 | from datetime import datetime, timedelta |
11 | 11 | import io, csv, re, hashlib, os as _os, time as _time, secrets as _secrets |
| 12 | +import asyncio, json, base64 |
12 | 13 | import uvicorn |
13 | 14 | from backend import auth as _auth |
14 | 15 |
|
|
17 | 18 | _PW_MAX_ATTEMPTS = 5 |
18 | 19 | _PW_LOCKOUT_SECS = 30 |
19 | 20 |
|
| 21 | +# ── Web Push state ── |
| 22 | +_vapid_state: dict = {"private_b64url": None, "public_b64url": None} |
| 23 | +_push_tasks: dict = {} # key -> asyncio.Task |
| 24 | + |
20 | 25 | app = FastAPI(title="lifty API") |
21 | 26 |
|
22 | 27 | REQUEST_COUNTER = Counter("lifty_requests_total", "Total HTTP requests", ["method", "endpoint", "status"]) |
@@ -110,6 +115,39 @@ def on_startup(): |
110 | 115 | # ── Seed built-in global exercises (idempotent) ── |
111 | 116 | seed_exercises() |
112 | 117 |
|
| 118 | + # ── VAPID key pair: generate once, persist in app_config ── |
| 119 | + from cryptography.hazmat.primitives.asymmetric.ec import generate_private_key, SECP256R1 |
| 120 | + from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat |
| 121 | + with Session(engine) as session: |
| 122 | + session.exec(text(""" |
| 123 | + CREATE TABLE IF NOT EXISTS pushsubscription ( |
| 124 | + id INTEGER PRIMARY KEY AUTOINCREMENT, |
| 125 | + profile_id INTEGER NOT NULL REFERENCES profile(id), |
| 126 | + endpoint TEXT NOT NULL, |
| 127 | + p256dh TEXT NOT NULL, |
| 128 | + auth TEXT NOT NULL, |
| 129 | + created_at DATETIME |
| 130 | + ) |
| 131 | + """)) |
| 132 | + session.commit() |
| 133 | + |
| 134 | + row = session.exec(text("SELECT value FROM app_config WHERE key='vapid_private'")).first() |
| 135 | + if row: |
| 136 | + _vapid_state["private_b64url"] = row[0] |
| 137 | + pub_row = session.exec(text("SELECT value FROM app_config WHERE key='vapid_public'")).first() |
| 138 | + _vapid_state["public_b64url"] = pub_row[0] if pub_row else None |
| 139 | + else: |
| 140 | + sk = generate_private_key(SECP256R1()) |
| 141 | + raw_priv = sk.private_numbers().private_value.to_bytes(32, 'big') |
| 142 | + priv_b64 = base64.urlsafe_b64encode(raw_priv).decode().rstrip('=') |
| 143 | + pub_bytes = sk.public_key().public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) |
| 144 | + pub_b64 = base64.urlsafe_b64encode(pub_bytes).decode().rstrip('=') |
| 145 | + session.execute(text("INSERT INTO app_config (key, value) VALUES (:k, :v)"), {"k": "vapid_private", "v": priv_b64}) |
| 146 | + session.execute(text("INSERT INTO app_config (key, value) VALUES (:k, :v)"), {"k": "vapid_public", "v": pub_b64}) |
| 147 | + session.commit() |
| 148 | + _vapid_state["private_b64url"] = priv_b64 |
| 149 | + _vapid_state["public_b64url"] = pub_b64 |
| 150 | + |
113 | 151 |
|
114 | 152 | # ───────────────────────────────────────────── |
115 | 153 | # Utility |
@@ -161,6 +199,100 @@ async def auth_change_password(request: Request): |
161 | 199 | return {"token": _auth.create_access_token(_auth_state["jwt_secret"])} |
162 | 200 |
|
163 | 201 |
|
| 202 | +# ───────────────────────────────────────────── |
| 203 | +# Web Push (VAPID) |
| 204 | +# ───────────────────────────────────────────── |
| 205 | + |
| 206 | +@app.get("/api/push/vapid-public-key") |
| 207 | +def push_vapid_key(): |
| 208 | + """Return the VAPID public key so the client can create a push subscription.""" |
| 209 | + return {"publicKey": _vapid_state.get("public_b64url")} |
| 210 | + |
| 211 | + |
| 212 | +@app.post("/api/push/subscribe") |
| 213 | +async def push_subscribe(request: Request): |
| 214 | + """Store or update a push subscription for a profile.""" |
| 215 | + body = await request.json() |
| 216 | + profile_id = body.get("profileId") |
| 217 | + endpoint = body.get("endpoint") |
| 218 | + p256dh = body.get("p256dh") |
| 219 | + auth = body.get("auth") |
| 220 | + if not all([profile_id, endpoint, p256dh, auth]): |
| 221 | + return JSONResponse({"detail": "Missing fields"}, status_code=422) |
| 222 | + with Session(engine) as session: |
| 223 | + existing = session.exec( |
| 224 | + select(PushSubscription).where(PushSubscription.profile_id == profile_id) |
| 225 | + ).first() |
| 226 | + if existing: |
| 227 | + existing.endpoint = endpoint |
| 228 | + existing.p256dh = p256dh |
| 229 | + existing.auth = auth |
| 230 | + session.add(existing) |
| 231 | + else: |
| 232 | + session.add(PushSubscription( |
| 233 | + profile_id=profile_id, endpoint=endpoint, p256dh=p256dh, auth=auth |
| 234 | + )) |
| 235 | + session.commit() |
| 236 | + return {"ok": True} |
| 237 | + |
| 238 | + |
| 239 | +@app.post("/api/push/schedule") |
| 240 | +async def push_schedule(request: Request): |
| 241 | + """Schedule a Web Push notification after delay_ms milliseconds.""" |
| 242 | + body = await request.json() |
| 243 | + profile_id = body.get("profileId") |
| 244 | + delay_ms = int(body.get("delayMs", 0)) |
| 245 | + title = body.get("title", "lifty") |
| 246 | + msg_body = body.get("body", "Rest done — time to lift!") |
| 247 | + |
| 248 | + key = f"rest-{profile_id}" |
| 249 | + |
| 250 | + # Cancel any previous task for this slot |
| 251 | + if key in _push_tasks and not _push_tasks[key].done(): |
| 252 | + _push_tasks[key].cancel() |
| 253 | + |
| 254 | + async def _send(): |
| 255 | + try: |
| 256 | + await asyncio.sleep(delay_ms / 1000) |
| 257 | + with Session(engine) as session: |
| 258 | + sub = session.exec( |
| 259 | + select(PushSubscription).where(PushSubscription.profile_id == profile_id) |
| 260 | + ).first() |
| 261 | + if not sub: |
| 262 | + return |
| 263 | + from pywebpush import webpush, WebPushException |
| 264 | + webpush( |
| 265 | + subscription_info={ |
| 266 | + "endpoint": sub.endpoint, |
| 267 | + "keys": {"p256dh": sub.p256dh, "auth": sub.auth}, |
| 268 | + }, |
| 269 | + data=json.dumps({"title": title, "body": msg_body}), |
| 270 | + vapid_private_key=_vapid_state["private_b64url"], |
| 271 | + vapid_claims={"sub": "mailto:lifty@lifty.app"}, |
| 272 | + ) |
| 273 | + except asyncio.CancelledError: |
| 274 | + pass |
| 275 | + except Exception as exc: |
| 276 | + print(f"[push] send error: {exc}") |
| 277 | + finally: |
| 278 | + _push_tasks.pop(key, None) |
| 279 | + |
| 280 | + _push_tasks[key] = asyncio.create_task(_send()) |
| 281 | + return {"ok": True} |
| 282 | + |
| 283 | + |
| 284 | +@app.post("/api/push/cancel") |
| 285 | +async def push_cancel(request: Request): |
| 286 | + """Cancel a pending push notification for a profile.""" |
| 287 | + body = await request.json() |
| 288 | + profile_id = body.get("profileId") |
| 289 | + key = f"rest-{profile_id}" |
| 290 | + if key in _push_tasks and not _push_tasks[key].done(): |
| 291 | + _push_tasks[key].cancel() |
| 292 | + _push_tasks.pop(key, None) |
| 293 | + return {"ok": True} |
| 294 | + |
| 295 | + |
164 | 296 | # ───────────────────────────────────────────── |
165 | 297 | # Utility |
166 | 298 | # ───────────────────────────────────────────── |
|
0 commit comments