Skip to content

Commit 2fef900

Browse files
committed
Fix ws conn handler, add ws tests
- Fixes permission issue in connection handler event loop - Adds tests for the websockets router and conn handler
1 parent ea2d4da commit 2fef900

3 files changed

Lines changed: 119 additions & 13 deletions

File tree

api/utils/websocket_connection_handler.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@
33
from fastapi import HTTPException, WebSocket
44
from sqlalchemy.ext.asyncio import AsyncEngine
55

6-
from api.models.database_models import DatabaseModelBase, DBUser
6+
from api.models.database_models import DBUser, SensorData, SensorState
77
from api.utils.database import get_engine, get_session
88
from api.utils.permissions import get_user_read_permissions
99

1010

1111
class WebsocketHandler:
1212
def __init__(self):
1313
self._connections: dict[int, tuple[DBUser, WebSocket]] = {}
14-
self._message_queue: asyncio.Queue[DatabaseModelBase] = asyncio.Queue()
14+
self._message_queue: asyncio.Queue[SensorState | SensorData] = asyncio.Queue()
1515

1616
def add(self, user: DBUser, websocket: WebSocket) -> bool:
1717
if not user.id:
18-
return False
18+
return False # pragma: no cover: Just for type safety, can not really happen (Should we remove the check?)
1919
if self._connections.get(user.id):
2020
return False
2121

@@ -24,10 +24,10 @@ def add(self, user: DBUser, websocket: WebSocket) -> bool:
2424

2525
def remove(self, user: DBUser) -> bool:
2626
if not user.id:
27-
return False
27+
return False # pragma: no cover: Just for type safety, can not really happen (Should we remove the check?)
2828
if self._connections.pop(user.id, None):
2929
return True
30-
return False
30+
return False # pragma: no cover: Just for type safety, shouldn't happen
3131

3232
async def event_loop(self):
3333
engine: AsyncEngine = await get_engine()
@@ -36,14 +36,13 @@ async def event_loop(self):
3636
async for session in get_session(engine):
3737
for user, websocket in self._connections.values():
3838
try:
39-
if event.id:
40-
await get_user_read_permissions(session, user, event.id)
41-
await websocket.send_text(event.model_dump_json())
42-
except HTTPException:
39+
await get_user_read_permissions(session, user, event.sensor_id)
40+
await websocket.send_text(event.model_dump_json())
41+
except HTTPException: # pragma: no cover: For now no tests, permission handling is tested elsewhere
4342
pass
4443
self._message_queue.task_done()
4544

46-
async def add_event(self, data: DatabaseModelBase):
45+
async def add_event(self, data: SensorState | SensorData):
4746
await self._message_queue.put(data)
4847

4948

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import asyncio
2+
from typing import Any
3+
4+
import pytest
5+
from fastapi import WebSocketDisconnect
6+
from fastapi.testclient import TestClient
7+
8+
from api.main import app
9+
from api.models.database_models import Sensor, SensorData
10+
from api.utils.database import get_engine
11+
from api.utils.websocket_connection_handler import get_websocket_handler
12+
from tests.utils.fake_db import override_get_engine
13+
from tests.utils.fixtures import superuser_token, token
14+
from tests.utils.sensor_utils import create_sensor
15+
16+
app.dependency_overrides[get_engine] = override_get_engine
17+
client: TestClient = TestClient(app)
18+
19+
20+
def test_ws_connect_no_authentication():
21+
try:
22+
with client.websocket_connect("/ws") as _:
23+
return
24+
except WebSocketDisconnect as exception:
25+
assert exception.code == 1008
26+
27+
28+
def test_ws_connect_no_bearer():
29+
try:
30+
with client.websocket_connect("/ws", headers={"Authorization": "1234"}) as _:
31+
return
32+
except WebSocketDisconnect as exception:
33+
assert exception.code == 1008
34+
35+
36+
def test_ws_connect_wrong_token(token: str):
37+
try:
38+
with client.websocket_connect("/ws", headers={"Authorization": f"Bearer {token}1234"}) as _:
39+
return
40+
except WebSocketDisconnect as exception:
41+
assert exception.code == 1008
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_ws_connect(superuser_token: str):
46+
ws_handler = get_websocket_handler()
47+
task = asyncio.get_event_loop().create_task(get_websocket_handler().event_loop())
48+
49+
sensor: Sensor = await create_sensor(name="TestSensorWS")
50+
51+
with client.websocket_connect("/ws", headers={"Authorization": f"Bearer {superuser_token}"}) as websocket:
52+
data = websocket.receive_json()
53+
assert data == {"message": "Hello World!"}
54+
55+
data: dict[str, Any] = {
56+
"id": -1,
57+
"temperature": 32.1,
58+
"humidity": 56.78,
59+
"pressure": 123.45,
60+
"voltage": 3.45,
61+
"sensor_id": sensor.id,
62+
}
63+
sensor_data: SensorData = SensorData(**data)
64+
await ws_handler.add_event(sensor_data)
65+
66+
# As the websocket.receive_json() is synchronous and TestClient does not really work with async ws,
67+
# the delay is need, so that the ws_handler can run its loop.
68+
await asyncio.sleep(0.1)
69+
70+
received_data = websocket.receive_json()
71+
assert data.get("id") == received_data.get("id")
72+
assert data.get("temperature") == received_data.get("temperature")
73+
task.cancel()
74+
75+
76+
# ToDo Investigate performance issue
77+
def test_ws_connect_double_connection(token: str):
78+
with client.websocket_connect("/ws", headers={"Authorization": f"Bearer {token}"}) as websocket:
79+
data = websocket.receive_json()
80+
assert data == {"message": "Hello World!"}
81+
82+
try:
83+
with client.websocket_connect("/ws", headers={"Authorization": f"Bearer {token}"}) as _:
84+
pass
85+
except WebSocketDisconnect as exception:
86+
assert exception.code == 1008

tests/utils_tests/test_security.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
import datetime
22

33
import pytest
4-
from fastapi import HTTPException
4+
from fastapi import HTTPException, WebSocketException
5+
from pytest_mock import MockerFixture
56
from sqlmodel import delete
67

78
from api.models.database_models import DBUser
89
from api.utils.http_exceptions import INVALID_CREDENTIALS
9-
from api.utils.security import create_access_token, get_current_user
10+
from api.utils.security import create_access_token, get_current_user, get_current_user_ws
1011
from tests.utils.fake_db import async_fake_session_maker, initialize_fake_database
1112

1213

1314
@pytest.mark.parametrize("token_data", [{}, {"sub": "fake_username"}])
1415
@pytest.mark.asyncio
15-
async def test_get_current_user_invalid_token(token_data: str):
16+
async def test_get_current_user_invalid_token(token_data: dict):
1617
await initialize_fake_database()
1718
async with async_fake_session_maker() as session:
1819
await session.execute(delete(DBUser))
@@ -24,3 +25,23 @@ async def test_get_current_user_invalid_token(token_data: str):
2425
await get_current_user(access_token, session)
2526

2627
assert exception.value == INVALID_CREDENTIALS
28+
29+
30+
@pytest.mark.parametrize("token_data", [{}, {"sub": "fake_username"}])
31+
@pytest.mark.asyncio
32+
async def test_get_current_user_ws_invalid_token(token_data: dict, mocker: MockerFixture):
33+
await initialize_fake_database()
34+
async with async_fake_session_maker() as session:
35+
await session.execute(delete(DBUser))
36+
await session.commit()
37+
38+
access_token = create_access_token(data=token_data, expires_delta=datetime.timedelta(minutes=1))
39+
40+
jwt_decode_mock = mocker.patch("api.utils.security.jwt.decode")
41+
jwt_decode_mock.return_value = token_data
42+
43+
with pytest.raises(WebSocketException) as exception:
44+
await get_current_user_ws(session, f"Bearer {access_token}")
45+
46+
# assert jwt_decode_mock.assert_called_once()
47+
assert exception.value.code == 1008

0 commit comments

Comments
 (0)