Skip to content

Commit a2373e0

Browse files
committed
Add initial Websocket
The websocket is used to send immediate notification for newly created sensor data / state
1 parent 17c60a1 commit a2373e0

File tree

7 files changed

+151
-3
lines changed

7 files changed

+151
-3
lines changed

api/main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1+
import asyncio
12
from contextlib import asynccontextmanager
23
from typing import Annotated
34

45
from fastapi import Depends, FastAPI
56
from sqlalchemy.ext.asyncio import AsyncSession
67

7-
from api.routers import authentication, forecast, permissions, sensors, serverstats, users
8+
from api.routers import authentication, forecast, permissions, sensors, serverstats, users, websocket
89
from api.utils.database import dispose_database, get_session
910
from api.utils.security import get_current_user
11+
from api.utils.websocket_connection_handler import get_websocket_handler
1012

1113

1214
@asynccontextmanager
1315
async def lifespan(_app: FastAPI):
16+
asyncio.get_event_loop().create_task(get_websocket_handler().event_loop())
1417
yield
1518
await dispose_database()
1619

20+
1721
app: FastAPI = FastAPI(root_path="/weatherapi", lifespan=lifespan)
1822

1923
app.include_router(authentication.auth_router)
@@ -22,6 +26,7 @@ async def lifespan(_app: FastAPI):
2226
app.include_router(forecast.forecast_router)
2327
app.include_router(permissions.permissions_router)
2428
app.include_router(serverstats.serverstats_router)
29+
app.include_router(websocket.websocket_router)
2530

2631

2732
@app.get("/")

api/routers/sensors.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Annotated
22

3-
from fastapi import APIRouter, Depends, status
3+
from fastapi import APIRouter, BackgroundTasks, Depends, status
44
from sqlalchemy.ext.asyncio.session import AsyncSession
55
from sqlalchemy.sql import func
66
from sqlmodel import NUMERIC, cast, or_, select
@@ -22,6 +22,7 @@
2222
from api.utils.permissions import get_user_read_permissions, get_user_write_permissions
2323
from api.utils.security import get_current_superuser, get_current_user
2424
from api.utils.sensor_utils import get_is_valid_sensor_type, get_sensor_from_db
25+
from api.utils.websocket_connection_handler import WebsocketHandler, get_websocket_handler
2526

2627
sensors_router = APIRouter(tags=["Sensors"], prefix="/sensor")
2728

@@ -117,6 +118,8 @@ async def create_sensor(
117118
@sensors_router.post("/{sensor_id}/data", response_model=SensorData, status_code=status.HTTP_201_CREATED)
118119
async def create_sensor_data(
119120
session: Annotated[AsyncSession, Depends(get_session)],
121+
ws_handler: Annotated[WebsocketHandler, Depends(get_websocket_handler)],
122+
background_tasks: BackgroundTasks,
120123
current_user: Annotated[DBUser, Depends(get_current_user)],
121124
sensor_id: int,
122125
data: SensorDataCreate,
@@ -140,6 +143,9 @@ async def create_sensor_data(
140143
session.add(data)
141144
await session.commit()
142145
await session.refresh(data)
146+
147+
background_tasks.add_task(ws_handler.add_event, data)
148+
143149
return data
144150

145151

@@ -210,6 +216,8 @@ async def get_sensor_data_daily(
210216
@sensors_router.post("/{sensor_id}/state", response_model=SensorState, status_code=status.HTTP_201_CREATED)
211217
async def create_sensor_state(
212218
session: Annotated[AsyncSession, Depends(get_session)],
219+
ws_handler: Annotated[WebsocketHandler, Depends(get_websocket_handler)],
220+
background_tasks: BackgroundTasks,
213221
current_user: Annotated[DBUser, Depends(get_current_user)],
214222
sensor_id: int,
215223
data: SensorStateCreate,
@@ -233,6 +241,9 @@ async def create_sensor_state(
233241
session.add(data)
234242
await session.commit()
235243
await session.refresh(data)
244+
245+
background_tasks.add_task(ws_handler.add_event, data)
246+
236247
return data
237248

238249

api/routers/websocket.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import Annotated
2+
3+
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
4+
5+
from api.models.database_models import DBUser
6+
from api.utils.security import get_current_user_ws
7+
from api.utils.websocket_connection_handler import WebsocketHandler, get_websocket_handler
8+
9+
websocket_router = APIRouter()
10+
11+
12+
@websocket_router.websocket("/ws")
13+
async def websocket_endpoint(
14+
websocket: WebSocket,
15+
current_user: Annotated[DBUser, Depends(get_current_user_ws)],
16+
ws_handler: Annotated[WebsocketHandler, Depends(get_websocket_handler)],
17+
):
18+
await websocket.accept()
19+
ws_handler.add(current_user, websocket)
20+
print(f"Accepted connection with {current_user.username}")
21+
try:
22+
while True:
23+
await websocket.receive_text()
24+
except WebSocketDisconnect:
25+
ws_handler.remove(current_user)

api/utils/security.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import datetime, timedelta, timezone
22
from typing import Annotated
33

4-
from fastapi import Depends
4+
from fastapi import Depends, Header, WebSocketException, status
55
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
66
from jose import JWTError, jwt
77
from passlib.context import CryptContext
@@ -92,6 +92,7 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None):
9292
async def get_current_user(
9393
token: Annotated[str, Depends(oauth2_scheme)], session: Annotated[AsyncSession, Depends(get_session)]
9494
) -> DBUser:
95+
# ToDo Handle expired token somehow
9596
try:
9697
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
9798
username: str = payload.get("sub")
@@ -105,6 +106,29 @@ async def get_current_user(
105106
return user
106107

107108

109+
async def get_current_user_ws(
110+
session: Annotated[AsyncSession, Depends(get_session)], authorization: Annotated[str | None, Header()] = None
111+
) -> DBUser:
112+
if authorization is None:
113+
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
114+
115+
if not authorization.startswith("Bearer "):
116+
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
117+
118+
token = authorization.split(" ")[1]
119+
try:
120+
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
121+
username: str = payload.get("sub")
122+
if username is None:
123+
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
124+
except JWTError:
125+
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
126+
user = await get_user(username, session)
127+
if user is None:
128+
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
129+
return user
130+
131+
108132
async def get_current_superuser(user: Annotated[DBUser, Depends(get_current_user)]):
109133
if not user.superuser:
110134
raise MISSING_PRIVILEGES
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import asyncio
2+
3+
from fastapi import HTTPException, WebSocket
4+
from pydantic import BaseModel
5+
from sqlalchemy.ext.asyncio import AsyncEngine
6+
7+
from api.models.database_models import DatabaseModelBase, DBUser
8+
from api.utils.database import get_engine, get_session
9+
from api.utils.permissions import get_user_read_permissions
10+
11+
12+
class WebsocketHandler:
13+
def __init__(self):
14+
self._connections: dict[int, tuple[DBUser, WebSocket]] = {}
15+
self._message_queue: asyncio.Queue[DatabaseModelBase] = asyncio.Queue()
16+
17+
def add(self, user: DBUser, websocket: WebSocket) -> bool:
18+
if self._connections.get(user.id):
19+
return False
20+
21+
self._connections[user.id] = (user, websocket)
22+
23+
def remove(self, user: DBUser) -> bool:
24+
if self._connections.pop(user.id, None):
25+
return True
26+
return False
27+
28+
async def event_loop(self):
29+
engine: AsyncEngine = await get_engine()
30+
while True:
31+
event = await self._message_queue.get()
32+
async for session in get_session(engine):
33+
for user, websocket in self._connections.values():
34+
try:
35+
await get_user_read_permissions(session, user, event.id)
36+
await websocket.send_json(event.model_dump_json())
37+
except HTTPException:
38+
pass
39+
self._message_queue.task_done()
40+
41+
async def add_event(self, data: BaseModel):
42+
await self._message_queue.put(data)
43+
44+
45+
_websocket_handler = WebsocketHandler()
46+
47+
48+
def get_websocket_handler() -> WebsocketHandler:
49+
return _websocket_handler

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies = [
1717
"sqlalchemy==2.0.25",
1818
"sqlmodel==0.0.14",
1919
"uvicorn==0.27.0",
20+
"websockets>=15.0.1",
2021
]
2122

2223
[tool.black]

uv.lock

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)