diff --git a/doc/conf.py b/doc/conf.py index 57197eb..2a0ac4e 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -42,6 +42,9 @@ ("py:class", r"dotbot_utils.*"), ("py:class", r"ASGIApp"), ("py:class", r"DispatchFunction"), + ("py:class", r"dotbot.models.Annotated"), + ("py:class", r"Query"), + ("py:class", r"PydanticUndefined"), ] # -- Options for HTML output ------------------------------------------------- diff --git a/dotbot/controller.py b/dotbot/controller.py index 2478d4e..cbad560 100644 --- a/dotbot/controller.py +++ b/dotbot/controller.py @@ -736,19 +736,60 @@ def get_dotbots(self, query: DotBotQueryModel) -> List[DotBotModel]: """Returns the list of dotbots matching the query.""" dotbots: List[DotBotModel] = [] for dotbot in self.dotbots.values(): + if query.address is not None and dotbot.address != query.address: + continue if ( query.application is not None - and dotbot.application != query.application + and dotbot.application.value != query.application ): continue - if query.mode is not None and dotbot.mode != query.mode: + if query.status is not None and dotbot.status.value != query.status: + continue + if query.max_battery is not None and dotbot.battery is not None: + if dotbot.battery > query.max_battery: + continue + if query.min_battery is not None and dotbot.battery is not None: + if dotbot.battery < query.min_battery: + continue + if ( + any( + [ + query.max_position_x is not None, + query.min_position_x is not None, + query.max_position_y is not None, + query.min_position_y is not None, + ] + ) + and dotbot.lh2_position is None + ): continue - if query.status is not None and dotbot.status != query.status: + if dotbot.lh2_position is None and query.max_positions is not None: continue + if dotbot.lh2_position is not None: + if query.max_position_x is not None: + if query.max_position_x < dotbot.lh2_position.x: + continue + if query.min_position_x is not None: + if query.min_position_x > dotbot.lh2_position.x: + continue + if query.max_position_y is not None: + if query.max_position_y < dotbot.lh2_position.y: + continue + if query.min_position_y is not None: + if query.min_position_y > dotbot.lh2_position.y: + continue _dotbot = DotBotModel(**dotbot.model_dump()) - _dotbot.position_history = _dotbot.position_history[: query.max_positions] + max_positions = ( + MAX_POSITION_HISTORY_SIZE + if query.max_positions is None + else query.max_positions + ) + _dotbot.position_history = _dotbot.position_history[:max_positions] dotbots.append(_dotbot) - return sorted(dotbots, key=lambda dotbot: dotbot.address) + dotbots = sorted(dotbots, key=lambda dotbot: dotbot.address) + if query.limit is not None: + dotbots = dotbots[: query.limit] + return dotbots async def web(self): """Starts the web server application.""" diff --git a/dotbot/examples/charging_station.py b/dotbot/examples/charging_station.py index 0e18dd6..e5dcbae 100644 --- a/dotbot/examples/charging_station.py +++ b/dotbot/examples/charging_station.py @@ -13,7 +13,9 @@ DotBotLH2Position, DotBotModel, DotBotMoveRawCommandModel, + DotBotQueryModel, DotBotRgbLedCommandModel, + DotBotStatus, DotBotWaypoints, WSRgbLed, WSWaypoints, @@ -52,12 +54,18 @@ async def queue_robots( await send_to_goal(client, ws, goals, params) +async def fetch_active_dotbots(client: RestClient) -> List[DotBotModel]: + return await client.fetch_dotbots( + query=DotBotQueryModel(status=DotBotStatus.ACTIVE) + ) + + async def charge_robots( client: RestClient, ws: DotBotWsClient, params: OrcaParams, ) -> None: - dotbots = await client.fetch_active_dotbots() + dotbots = await fetch_active_dotbots(client) remaining = order_bots(dotbots, QUEUE_HEAD_X, QUEUE_HEAD_Y) total_count = len(dotbots) # The head of the remaining should park @@ -66,7 +74,7 @@ async def charge_robots( parked_count = total_count - len(remaining) while remaining or park_dotbot is not None: - dotbots = await client.fetch_active_dotbots() + dotbots = await fetch_active_dotbots(client) dotbots = [b for b in dotbots if b.address in {r.address for r in remaining}] remaining = order_bots(dotbots, QUEUE_HEAD_X, QUEUE_HEAD_Y) @@ -133,7 +141,7 @@ async def send_to_goal( params: OrcaParams, ) -> None: while True: - dotbots = await client.fetch_active_dotbots() + dotbots = await fetch_active_dotbots(client) agents: List[Agent] = [] for bot in dotbots: @@ -309,7 +317,7 @@ async def main() -> None: port = os.getenv("DOTBOT_CONTROLLER_PORT", "8000") use_https = os.getenv("DOTBOT_CONTROLLER_USE_HTTPS", False) async with rest_client(url, port, use_https) as client: - dotbots = await client.fetch_active_dotbots() + dotbots = await fetch_active_dotbots(client) ws = DotBotWsClient(url, port) await ws.connect() diff --git a/dotbot/joystick.py b/dotbot/joystick.py index 5b5a57d..15d3068 100644 --- a/dotbot/joystick.py +++ b/dotbot/joystick.py @@ -20,7 +20,7 @@ pydotbot_version, ) from dotbot.logger import LOGGER, setup_logging -from dotbot.models import DotBotMoveRawCommandModel +from dotbot.models import DotBotMoveRawCommandModel, DotBotQueryModel, DotBotStatus from dotbot.protocol import ApplicationType from dotbot.rest import rest_client @@ -99,12 +99,18 @@ def pos_from_joystick(self): async def fetch_active_dotbots(self): while 1: - self.dotbots = await self.api.fetch_active_dotbots() + self.dotbots = await self.api.fetch_active_dotbots( + query=DotBotQueryModel(status=DotBotStatus.ACTIVE) + ) await asyncio.sleep(1) async def start(self): """Starts to read continuously joystick positions.""" - asyncio.create_task(self.fetch_active_dotbots()) + asyncio.create_task( + self.fetch_active_dotbots( + query=DotBotQueryModel(status=DotBotStatus.ACTIVE) + ) + ) while True: # fetch positions from joystick positions = self.pos_from_joystick() diff --git a/dotbot/keyboard.py b/dotbot/keyboard.py index a788bde..9e5e37a 100644 --- a/dotbot/keyboard.py +++ b/dotbot/keyboard.py @@ -12,6 +12,7 @@ import click +from dotbot.models import DotBotQueryModel, DotBotStatus from dotbot.rest import rest_client try: @@ -236,12 +237,18 @@ async def refresh_speeds(self): async def fetch_active_dotbots(self): while 1: - self.dotbots = await self.api.fetch_active_dotbots() + self.dotbots = await self.api.fetch_active_dotbots( + query=DotBotQueryModel(status=DotBotStatus.ACTIVE) + ) await asyncio.sleep(1) async def start(self): """Starts to continuously listen on keyboard key press/release events.""" - asyncio.create_task(self.fetch_active_dotbots()) + asyncio.create_task( + self.fetch_active_dotbots( + query=DotBotQueryModel(status=DotBotStatus.ACTIVE) + ) + ) asyncio.create_task(self.update_active_keys()) while 1: await self.refresh_speeds() diff --git a/dotbot/models.py b/dotbot/models.py index 331ce9f..a1ae70f 100644 --- a/dotbot/models.py +++ b/dotbot/models.py @@ -59,7 +59,7 @@ class DotBotLH2Position(BaseModel): x: float y: float - z: float + z: float = 0.0 class DotBotControlModeModel(BaseModel): @@ -100,11 +100,17 @@ class DotBotStatus(IntEnum): class DotBotQueryModel(BaseModel): """Model class used to filter DotBots.""" - max_positions: int = MAX_POSITION_HISTORY_SIZE + limit: Optional[int] = None + address: Optional[str] = None application: Optional[ApplicationType] = None - mode: Optional[ControlModeType] = None status: Optional[DotBotStatus] = None - swarm: Optional[str] = None + max_battery: Optional[float] = None + min_battery: Optional[float] = None + max_positions: int = None + max_position_x: Optional[float] = None + min_position_x: Optional[float] = None + max_position_y: Optional[float] = None + min_position_y: Optional[float] = None class DotBotNotificationCommand(IntEnum): @@ -179,7 +185,7 @@ class DotBotModel(BaseModel): waypoints_threshold: int = 100 # in mm position_history: List[Union[DotBotLH2Position, DotBotGPSPosition]] = [] calibrated: int = 0x00 # Bitmask: first lighthouse = 0x01, second lighthouse = 0x02 - battery: float = 0.0 # Voltage in Volts + battery: float = 3.0 # Voltage in Volts class WSBase(BaseModel): diff --git a/dotbot/rest.py b/dotbot/rest.py index 89b6ab1..ae1a70c 100644 --- a/dotbot/rest.py +++ b/dotbot/rest.py @@ -5,13 +5,14 @@ """Module containing client code to interact with the controller REST API.""" +import urllib.parse from contextlib import asynccontextmanager -from typing import List +from typing import List, Optional import httpx from dotbot.logger import LOGGER, setup_logging -from dotbot.models import DotBotModel, DotBotStatus +from dotbot.models import DotBotModel, DotBotQueryModel from dotbot.protocol import ApplicationType @@ -43,11 +44,16 @@ def base_url(self): async def close(self): await self._client.aclose() - async def fetch_active_dotbots(self) -> List[DotBotModel]: - """Fetch active DotBots.""" + async def fetch_dotbots( + self, query: Optional[DotBotQueryModel] = None + ) -> List[DotBotModel]: + """Fetch DotBots matching the query.""" try: + url = f"{self.base_url}/dotbots" + if query is not None: + url += f"?{urllib.parse.urlencode(query.model_dump(exclude_none=True))}" response = await self._client.get( - f"{self.base_url}/dotbots", + url, headers={ "Accept": "application/json", }, @@ -60,11 +66,7 @@ async def fetch_active_dotbots(self) -> List[DotBotModel]: f"Failed to fetch dotbots: {response} {response.text}" ) else: - return [ - DotBotModel(**dotbot) - for dotbot in response.json() - if dotbot["status"] == DotBotStatus.ACTIVE.value - ] + return [DotBotModel(**dotbot) for dotbot in response.json()] return [] async def _send_command(self, address, application, resource, command): diff --git a/dotbot/server.py b/dotbot/server.py index aab28c8..b12e4bf 100644 --- a/dotbot/server.py +++ b/dotbot/server.py @@ -6,10 +6,16 @@ """Module for the web server application.""" import os -from typing import List +from typing import Annotated, List import httpx -from fastapi import Depends, FastAPI, HTTPException, WebSocket, WebSocketDisconnect +from fastapi import ( + FastAPI, + HTTPException, + Query, + WebSocket, + WebSocketDisconnect, +) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response from fastapi.staticfiles import StaticFiles @@ -19,6 +25,7 @@ from dotbot import pydotbot_version from dotbot.logger import LOGGER from dotbot.models import ( + MAX_POSITION_HISTORY_SIZE, DotBotMapSizeModel, DotBotModel, DotBotMoveRawCommandModel, @@ -228,12 +235,12 @@ async def dotbot_positions_history_clear(address: str): summary="Return information about a dotbot given its address", tags=["dotbots"], ) -async def dotbot(address: str, query: DotBotQueryModel = Depends()): +async def dotbot(address: str, max_positions: int = MAX_POSITION_HISTORY_SIZE): """Dotbot HTTP GET handler.""" if address not in api.controller.dotbots: raise HTTPException(status_code=404, detail="No matching dotbot found") _dotbot = DotBotModel(**api.controller.dotbots[address].model_dump()) - _dotbot.position_history = _dotbot.position_history[: query.max_positions] + _dotbot.position_history = _dotbot.position_history[:max_positions] return _dotbot @@ -244,7 +251,7 @@ async def dotbot(address: str, query: DotBotQueryModel = Depends()): summary="Return the list of available dotbots", tags=["dotbots"], ) -async def dotbots(query: DotBotQueryModel = Depends()): +async def dotbots(query: Annotated[DotBotQueryModel, Query()]): """Dotbots HTTP GET handler.""" return api.controller.get_dotbots(query) diff --git a/dotbot/tests/test_controller.py b/dotbot/tests/test_controller.py index d22e25b..b8bdeb7 100644 --- a/dotbot/tests/test_controller.py +++ b/dotbot/tests/test_controller.py @@ -2,7 +2,7 @@ import asyncio import time -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from dotbot_utils.hdlc import hdlc_encode @@ -11,16 +11,26 @@ from dotbot.adapter import SerialAdapter from dotbot.controller import Controller, ControllerSettings, gps_distance, lh2_distance -from dotbot.models import DotBotGPSPosition, DotBotLH2Position, DotBotModel -from dotbot.protocol import ControlModeType, PayloadControlMode +from dotbot.models import ( + DotBotGPSPosition, + DotBotLH2Position, + DotBotModel, + DotBotQueryModel, + DotBotStatus, +) +from dotbot.protocol import ApplicationType, ControlModeType, PayloadControlMode -@pytest.mark.asyncio -@patch("dotbot_utils.serial_interface.serial.Serial.write") -@patch("dotbot_utils.serial_interface.serial.Serial.open") -@patch("dotbot_utils.serial_interface.serial.Serial.flush") -async def test_controller(_, __, serial_write, capsys): - """Check controller subclass instanciation and write to serial.""" +@pytest.fixture +def controller(monkeypatch): + """Create a controller instance with mocked serial interface.""" + monkeypatch.setattr( + "dotbot_utils.serial_interface.serial.Serial.write", MagicMock() + ) + monkeypatch.setattr("dotbot_utils.serial_interface.serial.Serial.open", MagicMock()) + monkeypatch.setattr( + "dotbot_utils.serial_interface.serial.Serial.flush", MagicMock() + ) settings = ControllerSettings( port="/dev/null", baudrate=115200, @@ -28,18 +38,62 @@ async def test_controller(_, __, serial_write, capsys): dotbot_address="456", gw_address="78", ) - controller = Controller(settings) - controller.dotbots.update( + _controller = Controller(settings) + _controller.dotbots.update( { "0000000000000000": DotBotModel( - address="0000000000000000", last_seen=time.time() - ) + address="0000000000000000", + last_seen=time.time(), + application=ApplicationType.DotBot, + status=DotBotStatus.ACTIVE, + battery=2.0, + lh2_position=DotBotLH2Position(x=1000, y=1000), + position_history=[ + DotBotLH2Position(x=900, y=900), + DotBotLH2Position(x=800, y=800), + ], + ), + "0000000000000001": DotBotModel( + address="0000000000000001", + last_seen=time.time(), + application=ApplicationType.SailBot, + status=DotBotStatus.ACTIVE, + battery=3.0, + ), + "0000000000000002": DotBotModel( + address="0000000000000002", + last_seen=time.time(), + application=ApplicationType.DotBot, + status=DotBotStatus.INACTIVE, + battery=1.0, + lh2_position=DotBotLH2Position(x=500, y=500), + position_history=[ + DotBotLH2Position(x=400, y=400), + DotBotLH2Position(x=300, y=300), + ], + ), + "0000000000000003": DotBotModel( + address="0000000000000003", + last_seen=time.time(), + application=ApplicationType.DotBot, + status=DotBotStatus.LOST, + battery=1.0, + lh2_position=DotBotLH2Position(x=1000, y=1500), + position_history=[], + ), } ) - controller.adapter = SerialAdapter(settings.port, settings.baudrate) - controller.adapter.serial = SerialInterface( + _controller.adapter = SerialAdapter(settings.port, settings.baudrate) + _controller.adapter.serial = SerialInterface( settings.port, settings.baudrate, lambda: None ) + + yield _controller + + +@pytest.mark.asyncio +async def test_controller(controller): + """Check controller subclass instanciation and write to serial.""" frame = Frame( header=Header( destination=0, @@ -48,34 +102,97 @@ async def test_controller(_, __, serial_write, capsys): packet=Packet().from_payload(PayloadControlMode(mode=ControlModeType.AUTO)), ) controller.send_payload(0, PayloadControlMode(mode=ControlModeType.AUTO)) - assert serial_write.call_count == 1 + + serial_write_mock = controller.adapter.serial.serial.write + assert serial_write_mock.call_count == 1 payload_expected = hdlc_encode(frame.to_bytes()) - assert serial_write.call_args_list[0].args[0] == payload_expected + assert serial_write_mock.call_args_list[0].args[0] == payload_expected @pytest.mark.asyncio -@patch("dotbot_utils.serial_interface.serial.Serial.write") -@patch("dotbot_utils.serial_interface.serial.Serial.open") -@patch("dotbot_utils.serial_interface.serial.Serial.flush") -async def test_controller_dont_send(_, __, serial_write): +async def test_controller_dont_send(controller): """Check controller subclass instanciation and write to serial.""" - settings = ControllerSettings( - port="/dev/null", - baudrate=115200, - network_id="0", - dotbot_address="456", - gw_address="78", - ) - controller = Controller(settings) - dotbot = DotBotModel(address="0000000000000000", last_seen=time.time()) - controller.dotbots.update({dotbot.address: dotbot}) - controller.adapter = SerialAdapter(settings.port, settings.baudrate) - controller.adapter.serial = SerialInterface( - settings.port, settings.baudrate, lambda: None - ) + serial_write_mock = controller.adapter.serial.serial.write # DotBot is not in the controller known dotbot, so the payload won't be sent - controller.send_payload(1, PayloadControlMode(mode=ControlModeType.AUTO)) - assert serial_write.call_count == 0 + controller.send_payload(42, PayloadControlMode(mode=ControlModeType.AUTO)) + assert serial_write_mock.call_count == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "query,length", + [ + pytest.param( + DotBotQueryModel(address="0000000000000001"), + 1, + id="by address", + ), + pytest.param( + DotBotQueryModel(application=ApplicationType.SailBot), + 1, + id="by application", + ), + pytest.param( + DotBotQueryModel(status=DotBotStatus.ACTIVE), + 2, + id="by status active", + ), + pytest.param( + DotBotQueryModel(status=DotBotStatus.INACTIVE), + 1, + id="by status inactive", + ), + pytest.param( + DotBotQueryModel(status=DotBotStatus.LOST), + 1, + id="by status lost", + ), + pytest.param( + DotBotQueryModel(min_battery=2.5), + 1, + id="by min battery", + ), + pytest.param( + DotBotQueryModel(max_battery=1.5), + 2, + id="by max battery", + ), + pytest.param( + DotBotQueryModel(max_position_x=600), + 1, + id="by max position x", + ), + pytest.param( + DotBotQueryModel(min_position_x=800), + 2, + id="by min position x", + ), + pytest.param( + DotBotQueryModel(max_position_y=600), + 1, + id="by max position y", + ), + pytest.param( + DotBotQueryModel(min_position_y=1000), + 2, + id="by min position y", + ), + pytest.param( + DotBotQueryModel(max_positions=1), + 3, + id="by max positions", + ), + pytest.param( + DotBotQueryModel(limit=2), + 2, + id="by limit", + ), + ], +) +async def test_controller_get_dotbots_query(query, length, controller): + """Check controller get_dotbots query.""" + dotbots = controller.get_dotbots(query=query) + assert len(dotbots) == length @pytest.mark.filterwarnings("ignore::DeprecationWarning") diff --git a/dotbot/tests/test_experiment_charging_station.py b/dotbot/tests/test_experiment_charging_station.py index 191559c..197bc4f 100644 --- a/dotbot/tests/test_experiment_charging_station.py +++ b/dotbot/tests/test_experiment_charging_station.py @@ -51,7 +51,7 @@ def __init__(self, dotbots: List[DotBotModel]): self.move_raw_commands = [] self.rgb_commands = [] - async def fetch_active_dotbots(self) -> List[DotBotModel]: + async def fetch_dotbots(self, query=None) -> List[DotBotModel]: return list(self._dotbots.values()) async def send_waypoint_command( diff --git a/dotbot/tests/test_rest.py b/dotbot/tests/test_rest.py index 1e182e1..ca2fea6 100644 --- a/dotbot/tests/test_rest.py +++ b/dotbot/tests/test_rest.py @@ -7,7 +7,9 @@ DotBotLH2Position, DotBotModel, DotBotMoveRawCommandModel, + DotBotQueryModel, DotBotRgbLedCommandModel, + DotBotStatus, DotBotWaypoints, ) from dotbot.protocol import ApplicationType @@ -25,7 +27,7 @@ httpx.Response( 200, json=[{"address": "test", "status": 1, "last_seen": 0}] ), - [], + [DotBotModel(**{"address": "test", "status": 1, "last_seen": 0})], id="none active", ), pytest.param( @@ -38,17 +40,42 @@ ], ) @mock.patch("httpx.AsyncClient.get") -async def test_fetch_active_dotbots(get, response, expected): +async def test_fetch_dotbots(get, response, expected): if response == httpx.ConnectError: get.side_effect = response("error") else: get.return_value = response async with rest_client("localhost", 1234, False) as client: - result = await client.fetch_active_dotbots() + result = await client.fetch_dotbots() get.assert_called_once() + get.assert_called_with( + "http://localhost:1234/controller/dotbots", + headers={ + "Accept": "application/json", + }, + ) assert result == expected +@pytest.mark.asyncio +@mock.patch("httpx.AsyncClient.get") +async def test_fetch_dotbots_with_query(get): + dotbots = [{"address": "test", "status": 0, "last_seen": 0}] + get.return_value = httpx.Response(200, json=dotbots) + async with rest_client("localhost", 1234, False) as client: + result = await client.fetch_dotbots( + query=DotBotQueryModel(status=DotBotStatus.ACTIVE) + ) + get.assert_called_once() + get.assert_called_with( + "http://localhost:1234/controller/dotbots?status=0", + headers={ + "Accept": "application/json", + }, + ) + assert result == [DotBotModel(**dotbots[0])] + + @pytest.mark.asyncio @pytest.mark.parametrize( "response,application,command",