diff --git a/dimos/__init__.py b/dimos/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index 37e1a4757c..6e24cee870 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import json from queue import Empty, Queue from threading import Event, RLock, Thread @@ -31,14 +30,13 @@ from dimos.core.module import Module, ModuleConfig, SkillInfo from dimos.core.rpc_client import RpcCall, RPCClient from dimos.core.stream import In, Out -from dimos.protocol.rpc import RPCSpec +from dimos.protocol.rpc.spec import RPCSpec from dimos.spec.utils import Spec if TYPE_CHECKING: from langchain_core.language_models import BaseChatModel -@dataclass class AgentConfig(ModuleConfig): system_prompt: str | None = SYSTEM_PROMPT model: str = "gpt-4o" @@ -58,8 +56,8 @@ class Agent(Module[AgentConfig]): _thread: Thread _stop_event: Event - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._lock = RLock() self._state_graph = None self._message_queue = Queue() diff --git a/dimos/agents/agent_test_runner.py b/dimos/agents/agent_test_runner.py index 7d7fbab03d..7a4ba2a94e 100644 --- a/dimos/agents/agent_test_runner.py +++ b/dimos/agents/agent_test_runner.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterable from threading import Event, Thread +from typing import Any from langchain_core.messages import AIMessage from langchain_core.messages.base import BaseMessage @@ -20,21 +22,26 @@ from dimos.agents.agent import AgentSpec from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.rpc_client import RPCClient from dimos.core.stream import In, Out -class AgentTestRunner(Module): +class Config(ModuleConfig): + messages: Iterable[BaseMessage] + + +class AgentTestRunner(Module[Config]): + default_config = Config + agent_spec: AgentSpec agent: In[BaseMessage] agent_idle: In[bool] finished: Out[bool] added: Out[bool] - def __init__(self, messages: list[BaseMessage]) -> None: - super().__init__() - self._messages = messages + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._idle_event = Event() self._subscription_ready = Event() self._thread = Thread(target=self._thread_loop, daemon=True) @@ -71,7 +78,7 @@ def _thread_loop(self) -> None: if not self._subscription_ready.wait(5): raise TimeoutError("Timed out waiting for subscription to be ready.") - for message in self._messages: + for message in self.config.messages: self._idle_event.clear() self.agent_spec.add_message(message) if not self._idle_event.wait(60): diff --git a/dimos/agents/demo_agent.py b/dimos/agents/demo_agent.py index bd69fc6cae..b839b0809c 100644 --- a/dimos/agents/demo_agent.py +++ b/dimos/agents/demo_agent.py @@ -14,9 +14,9 @@ from dimos.agents.agent import Agent from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera import zed from dimos.hardware.sensors.camera.module import camera_module from dimos.hardware.sensors.camera.webcam import Webcam +from dimos.hardware.sensors.camera.zed import compat as zed demo_agent = autoconnect(Agent.blueprint()) diff --git a/dimos/agents/mcp/__init__.py b/dimos/agents/mcp/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents/mcp/mcp_adapter.py b/dimos/agents/mcp/mcp_adapter.py index 9b8cc5c4b9..213bf71e23 100644 --- a/dimos/agents/mcp/mcp_adapter.py +++ b/dimos/agents/mcp/mcp_adapter.py @@ -63,10 +63,6 @@ def __init__(self, url: str | None = None, timeout: int = DEFAULT_TIMEOUT) -> No self.url = url self.timeout = timeout - # ------------------------------------------------------------------ - # Low-level JSON-RPC - # ------------------------------------------------------------------ - def call(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]: """Send a JSON-RPC request and return the parsed response. @@ -87,10 +83,6 @@ def call(self, method: str, params: dict[str, Any] | None = None) -> dict[str, A raise McpError(f"HTTP {resp.status_code}: {e}") from e return resp.json() # type: ignore[no-any-return] - # ------------------------------------------------------------------ - # MCP standard methods - # ------------------------------------------------------------------ - def initialize(self) -> dict[str, Any]: """Send ``initialize`` and return server info.""" return self.call("initialize") @@ -112,10 +104,6 @@ def call_tool_text(self, name: str, arguments: dict[str, Any] | None = None) -> return "" return content[0].get("text", str(content[0])) # type: ignore[no-any-return] - # ------------------------------------------------------------------ - # Readiness probes - # ------------------------------------------------------------------ - def wait_for_ready(self, timeout: float = 10.0, interval: float = 0.5) -> bool: """Poll until the MCP server responds, or return False on timeout.""" deadline = time.monotonic() + timeout @@ -148,10 +136,6 @@ def wait_for_down(self, timeout: float = 10.0, interval: float = 0.5) -> bool: time.sleep(interval) return False - # ------------------------------------------------------------------ - # Class methods for discovery - # ------------------------------------------------------------------ - @classmethod def from_run_entry(cls, entry: Any | None = None, timeout: int = DEFAULT_TIMEOUT) -> McpAdapter: """Create an adapter from a RunEntry, or discover the latest one. @@ -173,10 +157,6 @@ def from_run_entry(cls, entry: Any | None = None, timeout: int = DEFAULT_TIMEOUT url = f"http://localhost:{global_config.mcp_port}/mcp" return cls(url=url, timeout=timeout) - # ------------------------------------------------------------------ - # Internals - # ------------------------------------------------------------------ - @staticmethod def _unwrap(response: dict[str, Any]) -> dict[str, Any]: """Extract the ``result`` from a JSON-RPC response, raising on error.""" diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index 7c5eda5302..a2ee872e16 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from queue import Empty, Queue from threading import Event, RLock, Thread import time @@ -39,7 +38,6 @@ logger = setup_logger() -@dataclass class McpClientConfig(ModuleConfig): system_prompt: str | None = SYSTEM_PROMPT model: str = "gpt-4o" @@ -62,8 +60,8 @@ class McpClient(Module[McpClientConfig]): _http_client: httpx.Client _seq_ids: SequentialIds - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._lock = RLock() self._state_graph = None self._message_queue = Queue() diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index bfd45bc58a..9149de06ec 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -14,6 +14,7 @@ from __future__ import annotations import asyncio +import concurrent.futures import json import os import time @@ -22,7 +23,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse -from starlette.requests import Request # noqa: TC002 +from starlette.requests import Request from starlette.responses import Response import uvicorn @@ -32,14 +33,11 @@ from dimos.core.rpc_client import RpcCall, RPCClient from dimos.utils.logging_config import setup_logger -logger = setup_logger() - - if TYPE_CHECKING: - import concurrent.futures - from dimos.core.module import SkillInfo +logger = setup_logger() + app = FastAPI() app.add_middleware( @@ -52,11 +50,6 @@ app.state.rpc_calls = {} -# --------------------------------------------------------------------------- -# JSON-RPC helpers -# --------------------------------------------------------------------------- - - def _jsonrpc_result(req_id: Any, result: Any) -> dict[str, Any]: return {"jsonrpc": "2.0", "id": req_id, "result": result} @@ -69,11 +62,6 @@ def _jsonrpc_error(req_id: Any, code: int, message: str) -> dict[str, Any]: return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}} -# --------------------------------------------------------------------------- -# JSON-RPC handlers (standard MCP protocol only) -# --------------------------------------------------------------------------- - - def _handle_initialize(req_id: Any) -> dict[str, Any]: return _jsonrpc_result( req_id, @@ -179,16 +167,9 @@ async def mcp_endpoint(request: Request) -> Response: return JSONResponse(result) -# --------------------------------------------------------------------------- -# McpServer Module -# --------------------------------------------------------------------------- - - class McpServer(Module): - def __init__(self) -> None: - super().__init__() - self._uvicorn_server: uvicorn.Server | None = None - self._serve_future: concurrent.futures.Future[None] | None = None + _uvicorn_server: uvicorn.Server | None = None + _serve_future: concurrent.futures.Future[None] | None = None @rpc def start(self) -> None: @@ -219,10 +200,6 @@ def on_system_modules(self, modules: list[RPCClient]) -> None: for skill_info in app.state.skills } - # ------------------------------------------------------------------ - # Introspection skills (exposed as MCP tools via tools/list) - # ------------------------------------------------------------------ - @skill def server_status(self) -> str: """Get MCP server status: main process PID, deployed modules, and skill count.""" diff --git a/dimos/agents/mcp/test_mcp_client.py b/dimos/agents/mcp/test_mcp_client.py index 16427103e4..c903e5f11c 100644 --- a/dimos/agents/mcp/test_mcp_client.py +++ b/dimos/agents/mcp/test_mcp_client.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from langchain_core.messages import HumanMessage import pytest from dimos.agents.annotation import skill from dimos.core.module import Module -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data @@ -40,10 +41,8 @@ def test_can_call_tool(agent_setup): class UserRegistration(Module): - def __init__(self): - super().__init__() - self._first_call = True - self._use_upper = False + _first_call = True + _use_upper = False @skill def register_user(self, name: str) -> str: @@ -79,8 +78,8 @@ def test_can_call_again_on_error(agent_setup): class MultipleTools(Module): - def __init__(self): - super().__init__() + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) self._people = {"Ben": "office", "Bob": "garage"} @skill diff --git a/dimos/agents/skills/demo_robot.py b/dimos/agents/skills/demo_robot.py index aa4e81e2cc..789e26d7e1 100644 --- a/dimos/agents/skills/demo_robot.py +++ b/dimos/agents/skills/demo_robot.py @@ -17,7 +17,7 @@ from dimos.core.module import Module from dimos.core.stream import Out -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon class DemoRobot(Module): diff --git a/dimos/agents/skills/google_maps_skill_container.py b/dimos/agents/skills/google_maps_skill_container.py index 7e402e32d7..e218601696 100644 --- a/dimos/agents/skills/google_maps_skill_container.py +++ b/dimos/agents/skills/google_maps_skill_container.py @@ -20,7 +20,7 @@ from dimos.core.module import Module from dimos.core.stream import In from dimos.mapping.google_maps.google_maps import GoogleMaps -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -32,8 +32,8 @@ class GoogleMapsSkillContainer(Module): gps_location: In[LatLon] - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) try: self._client = GoogleMaps() except ValueError: diff --git a/dimos/agents/skills/gps_nav_skill.py b/dimos/agents/skills/gps_nav_skill.py index 721119f6e6..1464665131 100644 --- a/dimos/agents/skills/gps_nav_skill.py +++ b/dimos/agents/skills/gps_nav_skill.py @@ -19,7 +19,7 @@ from dimos.core.module import Module from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon from dimos.mapping.utils.distance import distance_in_meters from dimos.utils.logging_config import setup_logger @@ -34,9 +34,6 @@ class GpsNavSkillContainer(Module): gps_location: In[LatLon] gps_goal: Out[LatLon] - def __init__(self) -> None: - super().__init__() - @rpc def start(self) -> None: super().start() diff --git a/dimos/agents/skills/navigation.py b/dimos/agents/skills/navigation.py index b02ff3a446..47ae21c799 100644 --- a/dimos/agents/skills/navigation.py +++ b/dimos/agents/skills/navigation.py @@ -22,9 +22,10 @@ from dimos.core.module import Module from dimos.core.stream import In from dimos.models.qwen.bbox import BBox -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.geometry_msgs.Vector3 import make_vector3 -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3, make_vector3 +from dimos.msgs.sensor_msgs.Image import Image from dimos.navigation.base import NavigationState from dimos.navigation.visual.query import get_object_bbox_from_image from dimos.types.robot_location import RobotLocation @@ -55,8 +56,8 @@ class NavigationSkillContainer(Module): color_image: In[Image] odom: In[PoseStamped] - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._skill_started = False # Here to prevent unwanted imports in the file. diff --git a/dimos/agents/skills/osm.py b/dimos/agents/skills/osm.py index 613bc0806e..d0281fb808 100644 --- a/dimos/agents/skills/osm.py +++ b/dimos/agents/skills/osm.py @@ -16,8 +16,8 @@ from dimos.agents.annotation import skill from dimos.core.module import Module from dimos.core.stream import In +from dimos.mapping.models import LatLon from dimos.mapping.osm.current_location_map import CurrentLocationMap -from dimos.mapping.types import LatLon from dimos.mapping.utils.distance import distance_in_meters from dimos.models.vl.qwen import QwenVlModel from dimos.utils.logging_config import setup_logger diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index e59ddb3b2a..f1cafed6cd 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -14,7 +14,7 @@ from threading import Event, RLock, Thread import time -from typing import TYPE_CHECKING +from typing import Any from langchain_core.messages import HumanMessage import numpy as np @@ -23,26 +23,30 @@ from dimos.agents.agent import AgentSpec from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.models.qwen.bbox import BBox +from dimos.models.segmentation.edge_tam import EdgeTAMProcessor +from dimos.models.vl.base import VlModel from dimos.models.vl.create import create -from dimos.msgs.geometry_msgs import Twist -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.navigation.visual.query import get_object_bbox_from_image from dimos.navigation.visual_servoing.detection_navigation import DetectionNavigation from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.models.segmentation.edge_tam import EdgeTAMProcessor - from dimos.models.vl.base import VlModel - logger = setup_logger() -class PersonFollowSkillContainer(Module): +class Config(ModuleConfig): + camera_info: CameraInfo + use_3d_navigation: bool = False + + +class PersonFollowSkillContainer(Module[Config]): """Skill container for following a person. This skill uses: @@ -52,6 +56,8 @@ class PersonFollowSkillContainer(Module): - Does not do obstacle avoidance; assumes a clear path. """ + default_config = Config + color_image: In[Image] global_map: In[PointCloud2] cmd_vel: Out[Twist] @@ -60,38 +66,31 @@ class PersonFollowSkillContainer(Module): _frequency: float = 20.0 # Hz - control loop frequency _max_lost_frames: int = 15 # number of frames to wait before declaring person lost - def __init__( - self, - camera_info: CameraInfo, - cfg: GlobalConfig, - use_3d_navigation: bool = False, - ) -> None: - super().__init__() - self._global_config: GlobalConfig = cfg - self._use_3d_navigation: bool = use_3d_navigation + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._latest_image: Image | None = None self._latest_pointcloud: PointCloud2 | None = None - self._vl_model: VlModel = create("qwen") + self._vl_model: VlModel[Any] = create("qwen") self._tracker: EdgeTAMProcessor | None = None self._thread: Thread | None = None self._should_stop: Event = Event() self._lock = RLock() # Use MuJoCo camera intrinsics in simulation mode - if self._global_config.simulation: + camera_info = self.config.camera_info + if self.config.g.simulation: from dimos.robot.unitree.mujoco_connection import MujocoConnection camera_info = MujocoConnection.camera_info_static - self._camera_info = camera_info - self._visual_servo = VisualServoing2D(camera_info, self._global_config.simulation) + self._visual_servo = VisualServoing2D(camera_info, self.config.g.simulation) self._detection_navigation = DetectionNavigation(self.tf, camera_info) @rpc def start(self) -> None: super().start() self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image))) - if self._use_3d_navigation: + if self.config.use_3d_navigation: self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud))) @rpc @@ -230,7 +229,7 @@ def _follow_loop(self, tracker: "EdgeTAMProcessor", query: str) -> None: lost_count = 0 best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume()) - if self._use_3d_navigation: + if self.config.use_3d_navigation: with self._lock: pointcloud = self._latest_pointcloud if pointcloud is None: diff --git a/dimos/agents/skills/test_google_maps_skill_container.py b/dimos/agents/skills/test_google_maps_skill_container.py index 1d8e4549b0..376d6d306e 100644 --- a/dimos/agents/skills/test_google_maps_skill_container.py +++ b/dimos/agents/skills/test_google_maps_skill_container.py @@ -13,6 +13,7 @@ # limitations under the License. import re +from typing import Any from langchain_core.messages import HumanMessage import pytest @@ -20,8 +21,8 @@ from dimos.agents.skills.google_maps_skill_container import GoogleMapsSkillContainer from dimos.core.module import Module from dimos.core.stream import Out -from dimos.mapping.google_maps.types import Coordinates, LocationContext, Position -from dimos.mapping.types import LatLon +from dimos.mapping.google_maps.models import Coordinates, LocationContext, Position +from dimos.mapping.models import LatLon class FakeGPS(Module): @@ -39,8 +40,8 @@ def get_location_context(self, location, radius=200): class MockedWhereAmISkill(GoogleMapsSkillContainer): - def __init__(self): - Module.__init__(self) # Skip GoogleMapsSkillContainer's __init__. + def __init__(self, **kwargs: Any): + Module.__init__(self, **kwargs) # Skip GoogleMapsSkillContainer's __init__. self._client = FakeLocationClient() self._latest_location = LatLon(lat=37.782654, lon=-122.413273) self._started = True @@ -62,8 +63,8 @@ def get_position(self, query, location): class MockedPositionSkill(GoogleMapsSkillContainer): - def __init__(self): - Module.__init__(self) + def __init__(self, **kwargs: Any): + Module.__init__(self, **kwargs) self._client = FakePositionClient() self._latest_location = LatLon(lat=37.782654, lon=-122.413273) self._started = True diff --git a/dimos/agents/skills/test_gps_nav_skills.py b/dimos/agents/skills/test_gps_nav_skills.py index d701d469ca..c1e380ccd1 100644 --- a/dimos/agents/skills/test_gps_nav_skills.py +++ b/dimos/agents/skills/test_gps_nav_skills.py @@ -18,7 +18,7 @@ from dimos.agents.skills.gps_nav_skill import GpsNavSkillContainer from dimos.core.module import Module from dimos.core.stream import Out -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon class FakeGPS(Module): @@ -28,11 +28,9 @@ class FakeGPS(Module): class MockedGpsNavSkill(GpsNavSkillContainer): - def __init__(self): - Module.__init__(self) - self._latest_location = LatLon(lat=37.782654, lon=-122.413273) - self._started = True - self._max_valid_distance = 50000 + _latest_location = LatLon(lat=37.782654, lon=-122.413273) + _started = True + _max_valid_distance = 50000 @pytest.mark.slow diff --git a/dimos/agents/skills/test_navigation.py b/dimos/agents/skills/test_navigation.py index a7505b23c7..e4a60db081 100644 --- a/dimos/agents/skills/test_navigation.py +++ b/dimos/agents/skills/test_navigation.py @@ -18,8 +18,8 @@ from dimos.agents.skills.navigation import NavigationSkillContainer from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image class FakeCamera(Module): @@ -31,23 +31,17 @@ class FakeOdom(Module): class MockedStopNavSkill(NavigationSkillContainer): + _skill_started = True rpc_calls: list[str] = [] - def __init__(self): - Module.__init__(self) - self._skill_started = True - def _cancel_goal_and_stop(self): pass class MockedExploreNavSkill(NavigationSkillContainer): + _skill_started = True rpc_calls: list[str] = [] - def __init__(self): - Module.__init__(self) - self._skill_started = True - def _start_exploration(self, timeout): return "Exploration completed successfuly" @@ -56,12 +50,9 @@ def _cancel_goal_and_stop(self): class MockedSemanticNavSkill(NavigationSkillContainer): + _skill_started = True rpc_calls: list[str] = [] - def __init__(self): - Module.__init__(self) - self._skill_started = True - def _navigate_by_tagged_location(self, query): return None diff --git a/dimos/agents/skills/test_unitree_skill_container.py b/dimos/agents/skills/test_unitree_skill_container.py index dde7239bbd..92b006dce5 100644 --- a/dimos/agents/skills/test_unitree_skill_container.py +++ b/dimos/agents/skills/test_unitree_skill_container.py @@ -13,6 +13,7 @@ # limitations under the License. import difflib +from typing import Any from langchain_core.messages import HumanMessage import pytest @@ -23,8 +24,8 @@ class MockedUnitreeSkill(UnitreeSkillContainer): rpc_calls: list[str] = [] - def __init__(self): - super().__init__() + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) # Provide a fake RPC so the real execute_sport_command runs end-to-end. self._bound_rpc_calls["GO2Connection.publish_request"] = lambda *args, **kwargs: None diff --git a/dimos/agents/test_agent.py b/dimos/agents/test_agent.py index 2464e622ca..e925e52a4a 100644 --- a/dimos/agents/test_agent.py +++ b/dimos/agents/test_agent.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from langchain_core.messages import HumanMessage import pytest from dimos.agents.annotation import skill from dimos.core.module import Module -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data @@ -40,10 +41,8 @@ def test_can_call_tool(agent_setup): class UserRegistration(Module): - def __init__(self): - super().__init__() - self._first_call = True - self._use_upper = False + _first_call = True + _use_upper = False @skill def register_user(self, name: str) -> str: @@ -81,8 +80,8 @@ def test_can_call_again_on_error(agent_setup): class MultipleTools(Module): - def __init__(self): - super().__init__() + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) self._people = {"Ben": "office", "Bob": "garage"} @skill diff --git a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py index ec0aec1442..81bad79ae5 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import TYPE_CHECKING, Any from langchain.chat_models import init_chat_model @@ -22,7 +21,7 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -31,24 +30,22 @@ logger = setup_logger() -@dataclass class VLMAgentConfig(ModuleConfig): model: str = "gpt-4o" system_prompt: str | None = SYSTEM_PROMPT -class VLMAgent(Module): +class VLMAgent(Module[VLMAgentConfig]): """Stream-first agent for vision queries with optional RPC access.""" - default_config: type[VLMAgentConfig] = VLMAgentConfig - config: VLMAgentConfig + default_config = VLMAgentConfig color_image: In[Image] query_stream: In[HumanMessage] answer_stream: Out[AIMessage] - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) if self.config.model.startswith("ollama:"): from dimos.agents.ollama_agent import ensure_ollama_model diff --git a/dimos/agents/vlm_stream_tester.py b/dimos/agents/vlm_stream_tester.py index 4126c6b3a0..5f2165dc8d 100644 --- a/dimos/agents/vlm_stream_tester.py +++ b/dimos/agents/vlm_stream_tester.py @@ -20,7 +20,7 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/agents_deprecated/__init__.py b/dimos/agents_deprecated/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents_deprecated/agent.py b/dimos/agents_deprecated/agent.py index 0443b2cc94..1d48ce2fa4 100644 --- a/dimos/agents_deprecated/agent.py +++ b/dimos/agents_deprecated/agent.py @@ -68,9 +68,6 @@ _MAX_SAVED_FRAMES = 100 # Maximum number of frames to save -# ----------------------------------------------------------------------------- -# region Agent Base Class -# ----------------------------------------------------------------------------- class Agent: """Base agent that manages memory and subscriptions.""" @@ -105,12 +102,6 @@ def dispose_all(self) -> None: logger.info("No disposables to dispose.") -# endregion Agent Base Class - - -# ----------------------------------------------------------------------------- -# region LLMAgent Base Class (Generic LLM Agent) -# ----------------------------------------------------------------------------- class LLMAgent(Agent): """Generic LLM agent containing common logic for LLM-based agents. @@ -689,12 +680,6 @@ def dispose_all(self) -> None: self.response_subject.on_completed() -# endregion LLMAgent Base Class (Generic LLM Agent) - - -# ----------------------------------------------------------------------------- -# region OpenAIAgent Subclass (OpenAI-Specific Implementation) -# ----------------------------------------------------------------------------- class OpenAIAgent(LLMAgent): """OpenAI agent implementation that uses OpenAI's API for processing. @@ -914,4 +899,3 @@ def stream_query(self, query_text: str) -> Observable: # type: ignore[type-arg] ) -# endregion OpenAIAgent Subclass (OpenAI-Specific Implementation) diff --git a/dimos/agents_deprecated/memory/__init__.py b/dimos/agents_deprecated/memory/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents_deprecated/modules/__init__.py b/dimos/agents_deprecated/modules/__init__.py deleted file mode 100644 index 99163d55d0..0000000000 --- a/dimos/agents_deprecated/modules/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Agent modules for DimOS.""" diff --git a/dimos/agents_deprecated/modules/base.py b/dimos/agents_deprecated/modules/base.py index 891edbe4bd..0927e184fc 100644 --- a/dimos/agents_deprecated/modules/base.py +++ b/dimos/agents_deprecated/modules/base.py @@ -29,9 +29,9 @@ from dimos.utils.logging_config import setup_logger try: - from .gateway import UnifiedGatewayClient + from dimos.agents_deprecated.modules.gateway.client import UnifiedGatewayClient except ImportError: - from dimos.agents_deprecated.modules.gateway import UnifiedGatewayClient + from dimos.agents_deprecated.modules.gateway.client import UnifiedGatewayClient logger = setup_logger() diff --git a/dimos/agents_deprecated/modules/base_agent.py b/dimos/agents_deprecated/modules/base_agent.py index 18ac15b317..d524861f77 100644 --- a/dimos/agents_deprecated/modules/base_agent.py +++ b/dimos/agents_deprecated/modules/base_agent.py @@ -21,7 +21,7 @@ from dimos.agents_deprecated.agent_types import AgentResponse from dimos.agents_deprecated.memory.base import AbstractAgentSemanticMemory from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.skills.skills import AbstractSkill, SkillLibrary from dimos.utils.logging_config import setup_logger @@ -34,32 +34,34 @@ logger = setup_logger() -class BaseAgentModule(BaseAgent, Module): # type: ignore[misc] +class BaseAgentConfig(ModuleConfig): + model: str = "openai::gpt-4o-mini" + system_prompt: str | None = None + skills: SkillLibrary | list[AbstractSkill] | AbstractSkill | None = None + memory: AbstractAgentSemanticMemory | None = None + temperature: float = 0.0 + max_tokens: int = 4096 + max_input_tokens: int = 128000 + max_history: int = 20 + rag_n: int = 4 + rag_threshold: float = 0.45 + process_all_inputs: bool = False + + +class BaseAgentModule(BaseAgent, Module[BaseAgentConfig]): # type: ignore[misc] """Agent module that inherits from BaseAgent and adds DimOS module interface. This provides a thin wrapper around BaseAgent functionality, exposing it through the DimOS module system with RPC methods and stream I/O. """ + default_config = BaseAgentConfig + # Module I/O - AgentMessage based communication message_in: In[AgentMessage] # Primary input for AgentMessage response_out: Out[AgentResponse] # Output AgentResponse objects - def __init__( # type: ignore[no-untyped-def] - self, - model: str = "openai::gpt-4o-mini", - system_prompt: str | None = None, - skills: SkillLibrary | list[AbstractSkill] | AbstractSkill | None = None, - memory: AbstractAgentSemanticMemory | None = None, - temperature: float = 0.0, - max_tokens: int = 4096, - max_input_tokens: int = 128000, - max_history: int = 20, - rag_n: int = 4, - rag_threshold: float = 0.45, - process_all_inputs: bool = False, - **kwargs, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize the agent module. Args: @@ -82,17 +84,17 @@ def __init__( # type: ignore[no-untyped-def] # Initialize BaseAgent with all functionality BaseAgent.__init__( self, - model=model, - system_prompt=system_prompt, - skills=skills, - memory=memory, - temperature=temperature, - max_tokens=max_tokens, - max_input_tokens=max_input_tokens, - max_history=max_history, - rag_n=rag_n, - rag_threshold=rag_threshold, - process_all_inputs=process_all_inputs, + model=self.config.model, + system_prompt=self.config.system_prompt, + skills=self.config.skills, + memory=self.config.memory, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + max_input_tokens=self.config.max_input_tokens, + max_history=self.config.max_history, + rag_n=self.config.rag_n, + rag_threshold=self.config.rag_threshold, + process_all_inputs=self.config.process_all_inputs, # Don't pass streams - we'll connect them in start() input_query_stream=None, input_data_stream=None, diff --git a/dimos/agents_deprecated/modules/gateway/__init__.py b/dimos/agents_deprecated/modules/gateway/__init__.py deleted file mode 100644 index 58ed40cd95..0000000000 --- a/dimos/agents_deprecated/modules/gateway/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Gateway module for unified LLM access.""" - -from .client import UnifiedGatewayClient -from .utils import convert_tools_to_standard_format, parse_streaming_response - -__all__ = ["UnifiedGatewayClient", "convert_tools_to_standard_format", "parse_streaming_response"] diff --git a/dimos/agents_deprecated/modules/gateway/utils.py b/dimos/agents_deprecated/modules/gateway/utils.py deleted file mode 100644 index 526d3b9724..0000000000 --- a/dimos/agents_deprecated/modules/gateway/utils.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utility functions for gateway operations.""" - -import logging -from typing import Any - -logger = logging.getLogger(__name__) - - -def convert_tools_to_standard_format(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Convert DimOS tool format to standard format accepted by gateways. - - DimOS tools come from pydantic_function_tool and have this format: - { - "type": "function", - "function": { - "name": "tool_name", - "description": "tool description", - "parameters": { - "type": "object", - "properties": {...}, - "required": [...] - } - } - } - - We keep this format as it's already standard JSON Schema format. - """ - if not tools: - return [] - - # Tools are already in the correct format from pydantic_function_tool - return tools - - -def parse_streaming_response(chunk: dict[str, Any]) -> dict[str, Any]: - """Parse a streaming response chunk into a standard format. - - Args: - chunk: Raw chunk from the gateway - - Returns: - Parsed chunk with standard fields: - - type: "content" | "tool_call" | "error" | "done" - - content: The actual content (text for content type, tool info for tool_call) - - metadata: Additional information - """ - # Handle TensorZero streaming format - if "choices" in chunk: - # OpenAI-style format from TensorZero - choice = chunk["choices"][0] if chunk["choices"] else {} - delta = choice.get("delta", {}) - - if "content" in delta: - return { - "type": "content", - "content": delta["content"], - "metadata": {"index": choice.get("index", 0)}, - } - elif "tool_calls" in delta: - tool_calls = delta["tool_calls"] - if tool_calls: - tool_call = tool_calls[0] - return { - "type": "tool_call", - "content": { - "id": tool_call.get("id"), - "name": tool_call.get("function", {}).get("name"), - "arguments": tool_call.get("function", {}).get("arguments", ""), - }, - "metadata": {"index": tool_call.get("index", 0)}, - } - elif choice.get("finish_reason"): - return { - "type": "done", - "content": None, - "metadata": {"finish_reason": choice["finish_reason"]}, - } - - # Handle direct content chunks - if isinstance(chunk, str): - return {"type": "content", "content": chunk, "metadata": {}} - - # Handle error responses - if "error" in chunk: - return {"type": "error", "content": chunk["error"], "metadata": chunk} - - # Default fallback - return {"type": "unknown", "content": chunk, "metadata": {}} - - -def create_tool_response(tool_id: str, result: Any, is_error: bool = False) -> dict[str, Any]: - """Create a properly formatted tool response. - - Args: - tool_id: The ID of the tool call - result: The result from executing the tool - is_error: Whether this is an error response - - Returns: - Formatted tool response message - """ - content = str(result) if not isinstance(result, str) else result - - return { - "role": "tool", - "tool_call_id": tool_id, - "content": content, - "name": None, # Will be filled by the calling code - } - - -def extract_image_from_message(message: dict[str, Any]) -> dict[str, Any] | None: - """Extract image data from a message if present. - - Args: - message: Message dict that may contain image data - - Returns: - Dict with image data and metadata, or None if no image - """ - content = message.get("content", []) - - # Handle list content (multimodal) - if isinstance(content, list): - for item in content: - if isinstance(item, dict): - # OpenAI format - if item.get("type") == "image_url": - return { - "format": "openai", - "data": item["image_url"]["url"], - "detail": item["image_url"].get("detail", "auto"), - } - # Anthropic format - elif item.get("type") == "image": - return { - "format": "anthropic", - "data": item["source"]["data"], - "media_type": item["source"].get("media_type", "image/jpeg"), - } - - return None diff --git a/dimos/agents_deprecated/prompt_builder/__init__.py b/dimos/agents_deprecated/prompt_builder/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/agents_deprecated/prompt_builder/impl.py b/dimos/agents_deprecated/prompt_builder/impl.py index 35c864062a..354057464f 100644 --- a/dimos/agents_deprecated/prompt_builder/impl.py +++ b/dimos/agents_deprecated/prompt_builder/impl.py @@ -148,7 +148,6 @@ def build( # type: ignore[no-untyped-def] # print("system_prompt: ", system_prompt) # print("rag_context: ", rag_context) - # region Token Counts if not override_token_limit: rag_token_cnt = self.tokenizer.token_count(rag_context) system_prompt_token_cnt = self.tokenizer.token_count(system_prompt) @@ -163,7 +162,6 @@ def build( # type: ignore[no-untyped-def] system_prompt_token_cnt = 0 user_query_token_cnt = 0 image_token_cnt = 0 - # endregion Token Counts # Create a component dictionary for dynamic allocation components = { diff --git a/dimos/agents_deprecated/tokenizer/__init__.py b/dimos/agents_deprecated/tokenizer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/control/__init__.py b/dimos/control/__init__.py deleted file mode 100644 index 639f0ba38a..0000000000 --- a/dimos/control/__init__.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ControlCoordinator - Centralized control for multi-arm coordination. - -This module provides a centralized control coordinator that replaces -per-driver/per-controller loops with a single deterministic tick-based system. - -Features: -- Single tick loop (read -> compute -> arbitrate -> route -> write) -- Per-joint arbitration (highest priority wins) -- Mode conflict detection -- Partial command support (hold last value) -- Aggregated preemption notifications - -Example: - >>> from dimos.control import ControlCoordinator - >>> from dimos.control.tasks import JointTrajectoryTask, JointTrajectoryTaskConfig - >>> from dimos.hardware.manipulators.xarm import XArmAdapter - >>> - >>> # Create coordinator - >>> coord = ControlCoordinator(tick_rate=100.0) - >>> - >>> # Add hardware - >>> adapter = XArmAdapter(ip="192.168.1.185", dof=7) - >>> adapter.connect() - >>> coord.add_hardware("left_arm", adapter) - >>> - >>> # Add task - >>> joints = [f"left_arm_joint{i+1}" for i in range(7)] - >>> task = JointTrajectoryTask( - ... "traj_left", - ... JointTrajectoryTaskConfig(joint_names=joints, priority=10), - ... ) - >>> coord.add_task(task) - >>> - >>> # Start - >>> coord.start() -""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "components": [ - "HardwareComponent", - "HardwareId", - "HardwareType", - "JointName", - "JointState", - "make_gripper_joints", - "make_joints", - ], - "coordinator": [ - "ControlCoordinator", - "ControlCoordinatorConfig", - "TaskConfig", - "control_coordinator", - ], - "hardware_interface": ["ConnectedHardware"], - "task": [ - "ControlMode", - "ControlTask", - "CoordinatorState", - "JointCommandOutput", - "JointStateSnapshot", - "ResourceClaim", - ], - "tick_loop": ["TickLoop"], - }, -) diff --git a/dimos/control/blueprints.py b/dimos/control/blueprints.py index 0384c69160..fff2083322 100644 --- a/dimos/control/blueprints.py +++ b/dimos/control/blueprints.py @@ -39,8 +39,9 @@ ) from dimos.control.coordinator import TaskConfig, control_coordinator from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped, Twist -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.quest.quest_types import Buttons from dimos.utils.data import LfsPath @@ -49,10 +50,6 @@ _XARM7_MODEL_PATH = LfsPath("xarm_description/urdf/xarm7/xarm7.urdf") -# ============================================================================= -# Single Arm Blueprints -# ============================================================================= - # Mock 7-DOF arm (for testing) coordinator_mock = control_coordinator( tick_rate=100.0, @@ -168,10 +165,6 @@ ) -# ============================================================================= -# Dual Arm Blueprints -# ============================================================================= - # Dual mock arms (7-DOF left, 6-DOF right) coordinator_dual_mock = control_coordinator( tick_rate=100.0, @@ -298,10 +291,6 @@ ) -# ============================================================================= -# Streaming Control Blueprints -# ============================================================================= - # XArm6 teleop - streaming position control coordinator_teleop_xarm6 = control_coordinator( tick_rate=100.0, @@ -399,11 +388,6 @@ ) -# ============================================================================= -# Cartesian IK Blueprints (internal Pinocchio IK solver) -# ============================================================================= - - # Mock 6-DOF arm with CartesianIK coordinator_cartesian_ik_mock = control_coordinator( tick_rate=100.0, @@ -471,10 +455,6 @@ ) -# ============================================================================= -# Teleop IK Blueprints (VR teleoperation with internal Pinocchio IK) -# ============================================================================= - # Single XArm7 with TeleopIK coordinator_teleop_xarm7 = control_coordinator( tick_rate=100.0, @@ -605,10 +585,6 @@ ) -# ============================================================================= -# Twist Base Blueprints (velocity-commanded platforms) -# ============================================================================= - # Mock holonomic twist base (3-DOF: vx, vy, wz) _base_joints = make_twist_base_joints("base") coordinator_mock_twist_base = control_coordinator( @@ -636,10 +612,6 @@ ) -# ============================================================================= -# Mobile Manipulation Blueprints (arm + twist base) -# ============================================================================= - # Mock arm (7-DOF) + mock holonomic base (3-DOF) _mm_base_joints = make_twist_base_joints("base") coordinator_mobile_manip_mock = control_coordinator( @@ -679,10 +651,6 @@ ) -# ============================================================================= -# Raw Blueprints (for programmatic setup) -# ============================================================================= - coordinator_basic = control_coordinator( tick_rate=100.0, publish_joint_state=True, @@ -694,10 +662,6 @@ ) -# ============================================================================= -# Exports -# ============================================================================= - __all__ = [ # Raw "coordinator_basic", diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index 21d4c9d06c..0757f27705 100644 --- a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -49,13 +49,9 @@ TwistBaseAdapter, ) from dimos.hardware.manipulators.spec import ManipulatorAdapter -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Twist, -) -from dimos.msgs.sensor_msgs import ( - JointState, -) +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.quest.quest_types import ( Buttons, ) @@ -68,11 +64,6 @@ logger = setup_logger() -# ============================================================================= -# Configuration -# ============================================================================= - - @dataclass class TaskConfig: """Configuration for a control task. @@ -104,7 +95,6 @@ class TaskConfig: gripper_closed_pos: float = 0.0 -@dataclass class ControlCoordinatorConfig(ModuleConfig): """Configuration for the ControlCoordinator. @@ -125,11 +115,6 @@ class ControlCoordinatorConfig(ModuleConfig): tasks: list[TaskConfig] = field(default_factory=lambda: []) -# ============================================================================= -# ControlCoordinator Module -# ============================================================================= - - class ControlCoordinator(Module[ControlCoordinatorConfig]): """Centralized control coordinator with per-joint arbitration. @@ -202,10 +187,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: logger.info(f"ControlCoordinator initialized at {self.config.tick_rate}Hz") - # ========================================================================= - # Config-based Setup - # ========================================================================= - def _setup_from_config(self) -> None: """Create hardware and tasks from config (called on start).""" hardware_added: list[str] = [] @@ -273,7 +254,10 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: task_type = cfg.type.lower() if task_type == "trajectory": - from dimos.control.tasks import JointTrajectoryTask, JointTrajectoryTaskConfig + from dimos.control.tasks.trajectory_task import ( + JointTrajectoryTask, + JointTrajectoryTaskConfig, + ) return JointTrajectoryTask( cfg.name, @@ -284,7 +268,7 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: ) elif task_type == "servo": - from dimos.control.tasks import JointServoTask, JointServoTaskConfig + from dimos.control.tasks.servo_task import JointServoTask, JointServoTaskConfig return JointServoTask( cfg.name, @@ -295,7 +279,7 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: ) elif task_type == "velocity": - from dimos.control.tasks import JointVelocityTask, JointVelocityTaskConfig + from dimos.control.tasks.velocity_task import JointVelocityTask, JointVelocityTaskConfig return JointVelocityTask( cfg.name, @@ -306,7 +290,7 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: ) elif task_type == "cartesian_ik": - from dimos.control.tasks import CartesianIKTask, CartesianIKTaskConfig + from dimos.control.tasks.cartesian_ik_task import CartesianIKTask, CartesianIKTaskConfig if cfg.model_path is None: raise ValueError(f"CartesianIKTask '{cfg.name}' requires model_path in TaskConfig") @@ -344,10 +328,6 @@ def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: else: raise ValueError(f"Unknown task type: {task_type}") - # ========================================================================= - # Hardware Management (RPC) - # ========================================================================= - @rpc def add_hardware( self, @@ -447,10 +427,6 @@ def get_joint_positions(self) -> dict[str, float]: positions[joint_name] = joint_state.position return positions - # ========================================================================= - # Task Management (RPC) - # ========================================================================= - @rpc def add_task(self, task: ControlTask) -> bool: """Register a task with the coordinator.""" @@ -493,10 +469,6 @@ def get_active_tasks(self) -> list[str]: with self._task_lock: return [name for name, task in self._tasks.items() if task.is_active()] - # ========================================================================= - # Streaming Control - # ========================================================================= - def _on_joint_command(self, msg: JointState) -> None: """Route incoming JointState to streaming tasks by joint name. @@ -604,10 +576,6 @@ def task_invoke( return getattr(task, method)(**kwargs) - # ========================================================================= - # Gripper - # ========================================================================= - @rpc def set_gripper_position(self, hardware_id: str, position: float) -> bool: """Set gripper position on a specific hardware device. @@ -641,10 +609,6 @@ def get_gripper_position(self, hardware_id: str) -> float | None: return None return hw.adapter.read_gripper_position() - # ========================================================================= - # Lifecycle - # ========================================================================= - @rpc def start(self) -> None: """Start the coordinator control loop.""" diff --git a/dimos/control/examples/cartesian_ik_jogger.py b/dimos/control/examples/cartesian_ik_jogger.py index d2a2f4d119..bf3b36a972 100644 --- a/dimos/control/examples/cartesian_ik_jogger.py +++ b/dimos/control/examples/cartesian_ik_jogger.py @@ -116,7 +116,7 @@ def to_pose_stamped(self, task_name: str) -> Any: Args: task_name: Task name to use as frame_id for routing """ - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -168,7 +168,7 @@ def run_jogger_ui(model_path: str | None = None, ee_joint_id: int = 6) -> None: ee_joint_id: End-effector joint ID in the model """ from dimos.core.transport import LCMTransport - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped # Use Piper model if not specified if model_path is None: diff --git a/dimos/control/task.py b/dimos/control/task.py index ecdf9ab7f4..afad70bb05 100644 --- a/dimos/control/task.py +++ b/dimos/control/task.py @@ -34,13 +34,10 @@ from dimos.hardware.manipulators.spec import ControlMode if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import Pose, PoseStamped + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_types import Buttons -# ============================================================================= -# Data Types -# ============================================================================= - @dataclass(frozen=True) class ResourceClaim: @@ -168,11 +165,6 @@ def get_values(self) -> list[float] | None: return None -# ============================================================================= -# ControlTask Protocol -# ============================================================================= - - @runtime_checkable class ControlTask(Protocol): """Protocol for passive tasks that run within the coordinator. diff --git a/dimos/control/tasks/__init__.py b/dimos/control/tasks/__init__.py deleted file mode 100644 index 5b869b01f9..0000000000 --- a/dimos/control/tasks/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Task implementations for the ControlCoordinator.""" - -from dimos.control.tasks.cartesian_ik_task import ( - CartesianIKTask, - CartesianIKTaskConfig, -) -from dimos.control.tasks.servo_task import ( - JointServoTask, - JointServoTaskConfig, -) -from dimos.control.tasks.teleop_task import ( - TeleopIKTask, - TeleopIKTaskConfig, -) -from dimos.control.tasks.trajectory_task import ( - JointTrajectoryTask, - JointTrajectoryTaskConfig, -) -from dimos.control.tasks.velocity_task import ( - JointVelocityTask, - JointVelocityTaskConfig, -) - -__all__ = [ - "CartesianIKTask", - "CartesianIKTaskConfig", - "JointServoTask", - "JointServoTaskConfig", - "JointTrajectoryTask", - "JointTrajectoryTaskConfig", - "JointVelocityTask", - "JointVelocityTaskConfig", - "TeleopIKTask", - "TeleopIKTaskConfig", -] diff --git a/dimos/control/tasks/cartesian_ik_task.py b/dimos/control/tasks/cartesian_ik_task.py index 6ea5ddc55b..2525db69e6 100644 --- a/dimos/control/tasks/cartesian_ik_task.py +++ b/dimos/control/tasks/cartesian_ik_task.py @@ -50,7 +50,8 @@ from numpy.typing import NDArray import pinocchio # type: ignore[import-untyped] - from dimos.msgs.geometry_msgs import Pose, PoseStamped + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped logger = setup_logger() @@ -255,10 +256,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: f"CartesianIKTask {self._name} preempted by {by_task} on joints {joints}" ) - # ========================================================================= - # Task-specific methods - # ========================================================================= - def on_cartesian_command(self, pose: Pose | PoseStamped, t_now: float) -> bool: """Handle incoming cartesian command (target EE pose). diff --git a/dimos/control/tasks/servo_task.py b/dimos/control/tasks/servo_task.py index b69b4dd099..50805bfa2c 100644 --- a/dimos/control/tasks/servo_task.py +++ b/dimos/control/tasks/servo_task.py @@ -159,10 +159,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: if joints & self._joint_names: logger.warning(f"JointServoTask {self._name} preempted by {by_task} on joints {joints}") - # ========================================================================= - # Task-specific methods - # ========================================================================= - def set_target(self, positions: list[float], t_now: float) -> bool: """Set target joint positions. diff --git a/dimos/control/tasks/teleop_task.py b/dimos/control/tasks/teleop_task.py index ce63dc4006..3f20502759 100644 --- a/dimos/control/tasks/teleop_task.py +++ b/dimos/control/tasks/teleop_task.py @@ -51,7 +51,8 @@ from numpy.typing import NDArray - from dimos.msgs.geometry_msgs import Pose, PoseStamped + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_types import Buttons logger = setup_logger() @@ -295,10 +296,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: if joints & self._joint_names: logger.warning(f"TeleopIKTask {self._name} preempted by {by_task} on joints {joints}") - # ========================================================================= - # Task-specific methods - # ========================================================================= - def on_buttons(self, msg: Buttons) -> bool: """Press-and-hold engage: hold primary button to track, release to stop.""" is_left = self._config.hand == "left" diff --git a/dimos/control/tasks/trajectory_task.py b/dimos/control/tasks/trajectory_task.py index 4d2eaa188b..fd0a9fda6e 100644 --- a/dimos/control/tasks/trajectory_task.py +++ b/dimos/control/tasks/trajectory_task.py @@ -32,7 +32,8 @@ JointCommandOutput, ResourceClaim, ) -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -171,10 +172,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: if joints & self._joint_names: self._state = TrajectoryState.ABORTED - # ========================================================================= - # Task-specific methods - # ========================================================================= - def execute(self, trajectory: JointTrajectory) -> bool: """Start executing a trajectory. diff --git a/dimos/control/tasks/velocity_task.py b/dimos/control/tasks/velocity_task.py index 163bc09827..5da475114d 100644 --- a/dimos/control/tasks/velocity_task.py +++ b/dimos/control/tasks/velocity_task.py @@ -191,10 +191,6 @@ def on_preempted(self, by_task: str, joints: frozenset[str]) -> None: f"JointVelocityTask {self._name} preempted by {by_task} on joints {joints}" ) - # ========================================================================= - # Task-specific methods - # ========================================================================= - def set_velocities(self, velocities: list[float], t_now: float) -> bool: """Set target joint velocities. diff --git a/dimos/control/test_control.py b/dimos/control/test_control.py index 656678d167..3de7865ae3 100644 --- a/dimos/control/test_control.py +++ b/dimos/control/test_control.py @@ -38,11 +38,8 @@ ) from dimos.control.tick_loop import TickLoop from dimos.hardware.manipulators.spec import ManipulatorAdapter -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint - -# ============================================================================= -# Fixtures -# ============================================================================= +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint @pytest.fixture @@ -112,11 +109,6 @@ def coordinator_state(): return CoordinatorState(joints=joints, t_now=time.perf_counter(), dt=0.01) -# ============================================================================= -# Test JointCommandOutput -# ============================================================================= - - class TestJointCommandOutput: def test_position_output(self): output = JointCommandOutput( @@ -153,11 +145,6 @@ def test_no_values_returns_none(self): assert output.get_values() is None -# ============================================================================= -# Test JointStateSnapshot -# ============================================================================= - - class TestJointStateSnapshot: def test_get_position(self): snapshot = JointStateSnapshot( @@ -171,11 +158,6 @@ def test_get_position(self): assert snapshot.get_position("nonexistent") is None -# ============================================================================= -# Test ConnectedHardware -# ============================================================================= - - class TestConnectedHardware: def test_joint_names_prefixed(self, connected_hardware): names = connected_hardware.joint_names @@ -206,11 +188,6 @@ def test_write_command(self, connected_hardware, mock_adapter): mock_adapter.write_joint_positions.assert_called() -# ============================================================================= -# Test JointTrajectoryTask -# ============================================================================= - - class TestJointTrajectoryTask: def test_initial_state(self, trajectory_task): assert trajectory_task.name == "test_traj" @@ -314,11 +291,6 @@ def test_progress(self, trajectory_task, simple_trajectory, coordinator_state): assert trajectory_task.get_progress(t_start + 1.0) == pytest.approx(1.0, abs=0.01) -# ============================================================================= -# Test Arbitration Logic -# ============================================================================= - - class TestArbitration: def test_single_task_wins(self): outputs = [ @@ -422,11 +394,6 @@ def test_non_overlapping_joints(self): assert winners["j4"][3] == "task2" -# ============================================================================= -# Test TickLoop -# ============================================================================= - - class TestTickLoop: def test_tick_loop_starts_and_stops(self, mock_adapter): component = HardwareComponent( @@ -498,11 +465,6 @@ def test_tick_loop_calls_compute(self, mock_adapter): assert mock_task.compute.call_count > 0 -# ============================================================================= -# Integration Test -# ============================================================================= - - class TestIntegration: def test_full_trajectory_execution(self, mock_adapter): component = HardwareComponent( diff --git a/dimos/control/tick_loop.py b/dimos/control/tick_loop.py index e0020a34da..dc1ed32dbb 100644 --- a/dimos/control/tick_loop.py +++ b/dimos/control/tick_loop.py @@ -38,7 +38,7 @@ JointStateSnapshot, ResourceClaim, ) -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -172,26 +172,19 @@ def _tick(self) -> None: self._last_tick_time = t_now self._tick_count += 1 - # === PHASE 1: READ ALL HARDWARE === joint_states = self._read_all_hardware() state = CoordinatorState(joints=joint_states, t_now=t_now, dt=dt) - # === PHASE 2: COMPUTE ALL ACTIVE TASKS === commands = self._compute_all_tasks(state) - # === PHASE 3: ARBITRATE (with mode validation) === joint_commands, preemptions = self._arbitrate(commands) - # === PHASE 4: NOTIFY PREEMPTIONS (once per task) === self._notify_preemptions(preemptions) - # === PHASE 5: ROUTE TO HARDWARE === hw_commands = self._route_to_hardware(joint_commands) - # === PHASE 6: WRITE TO HARDWARE === self._write_all_hardware(hw_commands) - # === PHASE 7: PUBLISH AGGREGATED STATE === if self._publish_callback: self._publish_joint_state(joint_states) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 287697f6c0..cac8507881 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -17,7 +17,6 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass, field, replace from functools import cached_property, reduce -import inspect import operator import sys from types import MappingProxyType @@ -27,7 +26,7 @@ from dimos.protocol.service.system_configurator.base import SystemConfigurator from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module, is_module_type +from dimos.core.module import Module, ModuleBase, ModuleSpec, is_module_type from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport @@ -35,6 +34,11 @@ from dimos.utils.generic import short_id from dimos.utils.logging_config import setup_logger +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing import Any as Self + logger = setup_logger() @@ -48,21 +52,18 @@ class StreamRef: @dataclass(frozen=True) class ModuleRef: name: str - spec: type[Spec] | type[Module] + spec: type[Spec] | type[ModuleBase] @dataclass(frozen=True) class _BlueprintAtom: - module: type[Module] + kwargs: dict[str, Any] + module: type[ModuleBase[Any]] streams: tuple[StreamRef, ...] module_refs: tuple[ModuleRef, ...] - args: tuple[Any, ...] - kwargs: dict[str, Any] @classmethod - def create( - cls, module: type[Module], args: tuple[Any, ...], kwargs: dict[str, Any] - ) -> "_BlueprintAtom": + def create(cls, module: type[ModuleBase[Any]], kwargs: dict[str, Any]) -> Self: streams: list[StreamRef] = [] module_refs: list[ModuleRef] = [] @@ -103,7 +104,6 @@ def create( module=module, streams=tuple(streams), module_refs=tuple(module_refs), - args=args, kwargs=kwargs, ) @@ -111,23 +111,23 @@ def create( @dataclass(frozen=True) class Blueprint: blueprints: tuple[_BlueprintAtom, ...] - disabled_modules_tuple: tuple[type[Module], ...] = field(default_factory=tuple) + disabled_modules_tuple: tuple[type[ModuleBase], ...] = field(default_factory=tuple) transport_map: Mapping[tuple[str, type], PubSubTransport[Any]] = field( default_factory=lambda: MappingProxyType({}) ) global_config_overrides: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({})) - remapping_map: Mapping[tuple[type[Module], str], str | type[Module] | type[Spec]] = field( - default_factory=lambda: MappingProxyType({}) + remapping_map: Mapping[tuple[type[ModuleBase], str], str | type[ModuleBase] | type[Spec]] = ( + field(default_factory=lambda: MappingProxyType({})) ) requirement_checks: tuple[Callable[[], str | None], ...] = field(default_factory=tuple) configurator_checks: "tuple[SystemConfigurator, ...]" = field(default_factory=tuple) @classmethod - def create(cls, module: type[Module], *args: Any, **kwargs: Any) -> "Blueprint": - blueprint = _BlueprintAtom.create(module, args, kwargs) + def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint": + blueprint = _BlueprintAtom.create(module, kwargs) return cls(blueprints=(blueprint,)) - def disabled_modules(self, *modules: type[Module]) -> "Blueprint": + def disabled_modules(self, *modules: type[ModuleBase]) -> "Blueprint": return replace(self, disabled_modules_tuple=self.disabled_modules_tuple + modules) def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint": @@ -140,7 +140,10 @@ def global_config(self, **kwargs: Any) -> "Blueprint": ) def remappings( - self, remappings: list[tuple[type[Module], str, str | type[Module] | type[Spec]]] + self, + remappings: list[ + tuple[type[ModuleBase[Any]], str, str | type[ModuleBase[Any]] | type[Spec]] + ], ) -> "Blueprint": remappings_dict = dict(self.remapping_map) for module, old, new in remappings: @@ -163,8 +166,8 @@ def _active_blueprints(self) -> tuple[_BlueprintAtom, ...]: def _check_ambiguity( self, requested_method_name: str, - interface_methods: Mapping[str, list[tuple[type[Module], Callable[..., Any]]]], - requesting_module: type[Module], + interface_methods: Mapping[str, list[tuple[type[ModuleBase], Callable[..., Any]]]], + requesting_module: type[ModuleBase], ) -> None: if ( requested_method_name in interface_methods @@ -206,7 +209,8 @@ def _is_name_unique(self, name: str) -> bool: return sum(1 for n, _ in self._all_name_types if n == name) == 1 def _run_configurators(self) -> None: - from dimos.protocol.service.system_configurator import configure_system, lcm_configurators + from dimos.protocol.service.system_configurator.base import configure_system + from dimos.protocol.service.system_configurator.lcm_config import lcm_configurators configurators = [*lcm_configurators(), *self.configurator_checks] @@ -273,13 +277,9 @@ def _verify_no_name_conflicts(self) -> None: def _deploy_all_modules( self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig ) -> None: - module_specs: list[tuple[type[Module], tuple[Any, ...], dict[str, Any]]] = [] + module_specs: list[ModuleSpec] = [] for blueprint in self._active_blueprints: - kwargs = {**blueprint.kwargs} - sig = inspect.signature(blueprint.module.__init__) - if "cfg" in sig.parameters: - kwargs["cfg"] = global_config - module_specs.append((blueprint.module, blueprint.args, kwargs)) + module_specs.append((blueprint.module, global_config, blueprint.kwargs)) module_coordinator.deploy_parallel(module_specs) @@ -399,12 +399,12 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: rpc_methods_dot = {} # Track interface methods to detect ambiguity. - interface_methods: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( + interface_methods: defaultdict[str, list[tuple[type[ModuleBase], Callable[..., Any]]]] = ( defaultdict(list) ) # interface_name_method -> [(module_class, method)] - interface_methods_dot: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( - defaultdict(list) - ) # interface_name.method -> [(module_class, method)] + interface_methods_dot: defaultdict[ + str, list[tuple[type[ModuleBase], Callable[..., Any]]] + ] = defaultdict(list) # interface_name.method -> [(module_class, method)] for blueprint in self._active_blueprints: for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined] diff --git a/dimos/core/daemon.py b/dimos/core/daemon.py index f4a19c9403..61060b2a73 100644 --- a/dimos/core/daemon.py +++ b/dimos/core/daemon.py @@ -31,10 +31,6 @@ logger = setup_logger() -# --------------------------------------------------------------------------- -# Health check (delegates to ModuleCoordinator.health_check) -# --------------------------------------------------------------------------- - def health_check(coordinator: ModuleCoordinator) -> bool: """Verify all coordinator workers are alive after build. @@ -45,11 +41,6 @@ def health_check(coordinator: ModuleCoordinator) -> bool: return coordinator.health_check() -# --------------------------------------------------------------------------- -# Daemonize (double-fork) -# --------------------------------------------------------------------------- - - def daemonize(log_dir: Path) -> None: """Double-fork daemonize the current process. @@ -83,11 +74,6 @@ def daemonize(log_dir: Path) -> None: devnull.close() -# --------------------------------------------------------------------------- -# Signal handler for clean shutdown -# --------------------------------------------------------------------------- - - def install_signal_handlers(entry: RunEntry, coordinator: ModuleCoordinator) -> None: """Install SIGTERM/SIGINT handlers that stop the coordinator and clean the registry.""" diff --git a/dimos/core/docker_runner.py b/dimos/core/docker_runner.py index ee56163ca6..dcb75fbdee 100644 --- a/dimos/core/docker_runner.py +++ b/dimos/core/docker_runner.py @@ -15,7 +15,7 @@ import argparse from contextlib import suppress -from dataclasses import dataclass, field +from dataclasses import field import importlib import json import os @@ -28,7 +28,7 @@ from dimos.core.docker_build import build_image, image_exists from dimos.core.module import Module, ModuleConfig from dimos.core.rpc_client import RpcCall -from dimos.protocol.rpc import LCMRPC +from dimos.protocol.rpc.pubsubrpc import LCMRPC from dimos.utils.logging_config import setup_logger from dimos.visualization.rerun.bridge import RERUN_GRPC_PORT, RERUN_WEB_PORT @@ -46,7 +46,6 @@ LOG_TAIL_LINES = 200 # Number of log lines to include in error messages -@dataclass(kw_only=True) class DockerModuleConfig(ModuleConfig): """ Configuration for running a DimOS module inside Docker. diff --git a/dimos/core/introspection/__init__.py b/dimos/core/introspection/__init__.py deleted file mode 100644 index c40c3d49e6..0000000000 --- a/dimos/core/introspection/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Module and blueprint introspection utilities.""" - -from dimos.core.introspection.module import INTERNAL_RPCS, render_module_io -from dimos.core.introspection.svg import to_svg - -__all__ = ["INTERNAL_RPCS", "render_module_io", "to_svg"] diff --git a/dimos/core/introspection/blueprint/__init__.py b/dimos/core/introspection/blueprint/__init__.py deleted file mode 100644 index 6545b39dfa..0000000000 --- a/dimos/core/introspection/blueprint/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Blueprint introspection and rendering. - -Renderers: - - dot: Graphviz DOT format (hub-style with type nodes as intermediate hubs) -""" - -from dimos.core.introspection.blueprint import dot -from dimos.core.introspection.blueprint.dot import LayoutAlgo, render_svg - -__all__ = ["LayoutAlgo", "dot", "render_svg"] diff --git a/dimos/core/introspection/blueprint/dot.py b/dimos/core/introspection/blueprint/dot.py index ea66401033..74ee9406a9 100644 --- a/dimos/core/introspection/blueprint/dot.py +++ b/dimos/core/introspection/blueprint/dot.py @@ -31,7 +31,7 @@ color_for_string, sanitize_id, ) -from dimos.core.module import Module +from dimos.core.module import ModuleBase from dimos.utils.cli import theme @@ -82,11 +82,11 @@ def render( ignored_modules = DEFAULT_IGNORED_MODULES # Collect all outputs: (name, type) -> list of producer modules - producers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list) + producers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list) # Collect all inputs: (name, type) -> list of consumer modules - consumers: dict[tuple[str, type], list[type[Module]]] = defaultdict(list) + consumers: dict[tuple[str, type], list[type[ModuleBase]]] = defaultdict(list) # Module name -> module class (for getting package info) - module_classes: dict[str, type[Module]] = {} + module_classes: dict[str, type[ModuleBase]] = {} for bp in blueprint_set.blueprints: module_classes[bp.module.__name__] = bp.module @@ -117,7 +117,7 @@ def render( active_channels[key] = color_for_string(TYPE_COLORS, label) # Group modules by package - def get_group(mod_class: type[Module]) -> str: + def get_group(mod_class: type[ModuleBase]) -> str: module_path = mod_class.__module__ parts = module_path.split(".") if len(parts) >= 2 and parts[0] == "dimos": diff --git a/dimos/core/introspection/module/__init__.py b/dimos/core/introspection/module/__init__.py deleted file mode 100644 index 444d0e24f3..0000000000 --- a/dimos/core/introspection/module/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Module introspection and rendering. - -Renderers: - - ansi: ANSI terminal output (default) - - dot: Graphviz DOT format -""" - -from dimos.core.introspection.module import ansi, dot -from dimos.core.introspection.module.info import ( - INTERNAL_RPCS, - ModuleInfo, - ParamInfo, - RpcInfo, - SkillInfo, - StreamInfo, - extract_module_info, -) -from dimos.core.introspection.module.render import render_module_io - -__all__ = [ - "INTERNAL_RPCS", - "ModuleInfo", - "ParamInfo", - "RpcInfo", - "SkillInfo", - "StreamInfo", - "ansi", - "dot", - "extract_module_info", - "render_module_io", -] diff --git a/dimos/core/module.py b/dimos/core/module.py index 48a99a79a3..1c5b311883 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -17,38 +17,45 @@ from functools import partial import inspect import json +import sys import threading from typing import ( TYPE_CHECKING, Any, + Protocol, get_args, get_origin, get_type_hints, overload, ) -from typing_extensions import TypeVar as TypeVarExtension - -if TYPE_CHECKING: - from dimos.core.introspection.module import ModuleInfo - from dimos.core.rpc_client import RPCClient - -from typing import TypeVar - from langchain_core.tools import tool from reactivex.disposable import CompositeDisposable from dimos.core.core import T, rpc -from dimos.core.introspection.module import extract_module_info, render_module_io +from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.introspection.module.info import extract_module_info +from dimos.core.introspection.module.render import render_module_io from dimos.core.resource import Resource from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out, RemoteOut, Transport -from dimos.protocol.rpc import LCMRPC, RPCSpec -from dimos.protocol.service import Configurable # type: ignore[attr-defined] -from dimos.protocol.tf import LCMTF, TFSpec +from dimos.protocol.rpc.pubsubrpc import LCMRPC +from dimos.protocol.rpc.spec import RPCSpec +from dimos.protocol.service.spec import BaseConfig, Configurable +from dimos.protocol.tf.tf import LCMTF, TFSpec from dimos.utils import colors from dimos.utils.generic import classproperty +if TYPE_CHECKING: + from dimos.core.blueprints import Blueprint + from dimos.core.introspection.module.info import ModuleInfo + from dimos.core.rpc_client import RPCClient + +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar + @dataclass(frozen=True) class SkillInfo: @@ -70,20 +77,27 @@ def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: return loop, thr -@dataclass -class ModuleConfig: +class ModuleConfig(BaseConfig): rpc_transport: type[RPCSpec] = LCMRPC - tf_transport: type[TFSpec] = LCMTF + tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] frame_id_prefix: str | None = None frame_id: str | None = None + g: GlobalConfig = global_config + +ModuleConfigT = TypeVar("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) -ModuleConfigT = TypeVarExtension("ModuleConfigT", bound=ModuleConfig, default=ModuleConfig) + +class _BlueprintPartial(Protocol): + def __call__(self, **kwargs: Any) -> "Blueprint": ... class ModuleBase(Configurable[ModuleConfigT], Resource): + # This won't type check against the TypeVar, but we need it as the default. + default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] + _rpc: RPCSpec | None = None - _tf: TFSpec | None = None + _tf: TFSpec[Any] | None = None _loop: asyncio.AbstractEventLoop | None = None _loop_thread: threading.Thread | None _disposables: CompositeDisposable @@ -93,10 +107,8 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): rpc_calls: list[str] = [] - default_config: type[ModuleConfigT] = ModuleConfig # type: ignore[assignment] - - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - super().__init__(*args, **kwargs) + def __init__(self, config_args: dict[str, Any]): + super().__init__(**config_args) self._module_closed_lock = threading.Lock() self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() @@ -338,7 +350,7 @@ def __get__( module_info = _module_info_descriptor() @classproperty - def blueprint(self): # type: ignore[no-untyped-def] + def blueprint(self) -> _BlueprintPartial: # Here to prevent circular imports. from dimos.core.blueprints import Blueprint @@ -409,7 +421,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not hasattr(cls, name) or getattr(cls, name) is None: setattr(cls, name, None) - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any): self.ref = None # type: ignore[assignment] try: @@ -427,7 +439,7 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] inner, *_ = get_args(ann) or (Any,) stream = In(inner, name, self) # type: ignore[assignment] setattr(self, name, stream) - super().__init__(*args, **kwargs) + super().__init__(config_args=kwargs) def __str__(self) -> str: return f"{self.__class__.__name__}" @@ -465,7 +477,7 @@ def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): # type: input_stream.connection = remote_stream -ModuleT = TypeVar("ModuleT", bound="Module[Any]") +ModuleSpec = tuple[type[ModuleBase], GlobalConfig, dict[str, Any]] def is_module_type(value: Any) -> bool: diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index 3a7961fcea..10227eae93 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -19,12 +19,12 @@ from typing import TYPE_CHECKING, Any from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import ModuleBase, ModuleSpec from dimos.core.resource import Resource from dimos.core.worker_manager import WorkerManager from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.core.module import Module, ModuleT from dimos.core.resource_monitor.monitor import StatsMonitor from dimos.core.rpc_client import ModuleProxy from dimos.core.worker import Worker @@ -37,7 +37,7 @@ class ModuleCoordinator(Resource): # type: ignore[misc] _global_config: GlobalConfig _n: int | None = None _memory_limit: str = "auto" - _deployed_modules: dict[type[Module], ModuleProxy] + _deployed_modules: dict[type[ModuleBase], ModuleProxy] _stats_monitor: StatsMonitor | None = None def __init__( @@ -115,17 +115,20 @@ def stop(self) -> None: self._client.close_all() # type: ignore[union-attr] - def deploy(self, module_class: type[ModuleT], *args, **kwargs) -> ModuleProxy: # type: ignore[no-untyped-def] + def deploy( + self, + module_class: type[ModuleBase[Any]], + global_config: GlobalConfig = global_config, + **kwargs: Any, + ) -> ModuleProxy: if not self._client: raise ValueError("Trying to dimos.deploy before the client has started") - module: ModuleProxy = self._client.deploy(module_class, *args, **kwargs) # type: ignore[union-attr, attr-defined, assignment] - self._deployed_modules[module_class] = module - return module + module = self._client.deploy(module_class, global_config, kwargs) + self._deployed_modules[module_class] = module # type: ignore[assignment] + return module # type: ignore[return-value] - def deploy_parallel( - self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[str, Any]]] - ) -> list[ModuleProxy]: + def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]: if not self._client: raise ValueError("Not started") @@ -148,7 +151,7 @@ def start_all_modules(self) -> None: if hasattr(module, "on_system_modules"): module.on_system_modules(module_list) - def get_instance(self, module: type[ModuleT]) -> ModuleProxy: + def get_instance(self, module: type[ModuleBase]) -> ModuleProxy: return self._deployed_modules.get(module) # type: ignore[return-value, no-any-return] def loop(self) -> None: diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 6a93e6453a..f4a674cb5d 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -40,7 +40,6 @@ class MyCppModule(NativeModule): from __future__ import annotations -from dataclasses import dataclass, field, fields import enum import inspect import json @@ -48,13 +47,21 @@ class MyCppModule(NativeModule): from pathlib import Path import signal import subprocess +import sys import threading from typing import IO, Any +from pydantic import Field + from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.utils.logging_config import setup_logger +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + logger = setup_logger() @@ -63,15 +70,14 @@ class LogFormat(enum.Enum): JSON = "json" -@dataclass(kw_only=True) class NativeModuleConfig(ModuleConfig): """Configuration for a native (C/C++) subprocess module.""" executable: str build_command: str | None = None cwd: str | None = None - extra_args: list[str] = field(default_factory=list) - extra_env: dict[str, str] = field(default_factory=dict) + extra_args: list[str] = Field(default_factory=list) + extra_env: dict[str, str] = Field(default_factory=dict) shutdown_timeout: float = 10.0 log_format: LogFormat = LogFormat.TEXT @@ -85,26 +91,29 @@ def to_cli_args(self) -> list[str]: or its parents) and converts them to ``["--name", str(value)]`` pairs. Skips fields whose values are ``None`` and fields in ``cli_exclude``. """ - ignore_fields = {f.name for f in fields(NativeModuleConfig)} + ignore_fields = {f for f in NativeModuleConfig.model_fields} args: list[str] = [] - for f in fields(self): - if f.name in ignore_fields: + for f in self.__class__.model_fields: + if f in ignore_fields: continue - if f.name in self.cli_exclude: + if f in self.cli_exclude: continue - val = getattr(self, f.name) + val = getattr(self, f) if val is None: continue if isinstance(val, bool): - args.extend([f"--{f.name}", str(val).lower()]) + args.extend([f"--{f}", str(val).lower()]) elif isinstance(val, list): - args.extend([f"--{f.name}", ",".join(str(v) for v in val)]) + args.extend([f"--{f}", ",".join(str(v) for v in val)]) else: - args.extend([f"--{f.name}", str(val)]) + args.extend([f"--{f}", str(val)]) return args -class NativeModule(Module[NativeModuleConfig]): +_NativeConfig = TypeVar("_NativeConfig", bound=NativeModuleConfig, default=NativeModuleConfig) + + +class NativeModule(Module[_NativeConfig]): """Module that wraps a native executable as a managed subprocess. Subclass this, declare In/Out ports, and set ``default_config`` to a @@ -118,13 +127,13 @@ class NativeModule(Module[NativeModuleConfig]): LCM topics directly. On ``stop()``, the process receives SIGTERM. """ - default_config: type[NativeModuleConfig] = NativeModuleConfig + default_config: type[_NativeConfig] = NativeModuleConfig # type: ignore[assignment] _process: subprocess.Popen[bytes] | None = None _watchdog: threading.Thread | None = None _stopping: bool = False - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._resolve_paths() @rpc diff --git a/dimos/core/resource_monitor/__init__.py b/dimos/core/resource_monitor/__init__.py deleted file mode 100644 index 217941a2ec..0000000000 --- a/dimos/core/resource_monitor/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from dimos.core.resource_monitor.logger import ( - LCMResourceLogger, - ResourceLogger, - StructlogResourceLogger, -) -from dimos.core.resource_monitor.monitor import StatsMonitor -from dimos.core.resource_monitor.stats import ProcessStats, WorkerStats, collect_process_stats - -__all__ = [ - "LCMResourceLogger", - "ProcessStats", - "ResourceLogger", - "StatsMonitor", - "StructlogResourceLogger", - "WorkerStats", - "collect_process_stats", -] diff --git a/dimos/core/resource_monitor/stats.py b/dimos/core/resource_monitor/stats.py index 485132db46..6264d5c7f9 100644 --- a/dimos/core/resource_monitor/stats.py +++ b/dimos/core/resource_monitor/stats.py @@ -19,7 +19,7 @@ import psutil -from dimos.utils.decorators import ttl_cache +from dimos.utils.decorators.decorators import ttl_cache # Cache Process objects so cpu_percent(interval=None) has a previous sample. _proc_cache: dict[int, psutil.Process] = {} diff --git a/dimos/core/rpc_client.py b/dimos/core/rpc_client.py index e46124469c..84de18d671 100644 --- a/dimos/core/rpc_client.py +++ b/dimos/core/rpc_client.py @@ -17,7 +17,8 @@ from dimos.core.stream import RemoteStream from dimos.core.worker import MethodCallProxy -from dimos.protocol.rpc import LCMRPC, RPCSpec +from dimos.protocol.rpc.pubsubrpc import LCMRPC +from dimos.protocol.rpc.spec import RPCSpec from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index 30677bd1f7..5f7bf33b8b 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -33,7 +33,7 @@ from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.spec.utils import Spec # Disable Rerun for tests (prevents viewer spawn and gRPC flush errors) @@ -113,14 +113,13 @@ class ModuleC(Module): def test_get_connection_set() -> None: - assert _BlueprintAtom.create(CatModule, args=("arg1",), kwargs={"k": "v"}) == _BlueprintAtom( + assert _BlueprintAtom.create(CatModule, kwargs={"k": "v"}) == _BlueprintAtom( module=CatModule, streams=( StreamRef(name="pet_cat", type=Petting, direction="in"), StreamRef(name="scratches", type=Scratch, direction="out"), ), module_refs=(), - args=("arg1",), kwargs={"k": "v"}, ) @@ -137,7 +136,6 @@ def test_autoconnect() -> None: StreamRef(name="data2", type=Data2, direction="out"), ), module_refs=(), - args=(), kwargs={}, ), _BlueprintAtom( @@ -148,7 +146,6 @@ def test_autoconnect() -> None: StreamRef(name="data3", type=Data3, direction="out"), ), module_refs=(), - args=(), kwargs={}, ), ) @@ -342,11 +339,11 @@ def test_future_annotations_support() -> None: """ # Test that streams are properly extracted from modules with future annotations - out_blueprint = _BlueprintAtom.create(FutureModuleOut, args=(), kwargs={}) + out_blueprint = _BlueprintAtom.create(FutureModuleOut, kwargs={}) assert len(out_blueprint.streams) == 1 assert out_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="out") - in_blueprint = _BlueprintAtom.create(FutureModuleIn, args=(), kwargs={}) + in_blueprint = _BlueprintAtom.create(FutureModuleIn, kwargs={}) assert len(in_blueprint.streams) == 1 assert in_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="in") diff --git a/dimos/core/test_cli_stop_status.py b/dimos/core/test_cli_stop_status.py index c04d8d2499..5c628f6d92 100644 --- a/dimos/core/test_cli_stop_status.py +++ b/dimos/core/test_cli_stop_status.py @@ -72,11 +72,6 @@ def _entry(run_id: str, pid: int, blueprint: str = "test", **kwargs) -> RunEntry return e -# --------------------------------------------------------------------------- -# STATUS -# --------------------------------------------------------------------------- - - class TestStatusCLI: """Tests for `dimos status` command.""" @@ -132,11 +127,6 @@ def test_status_filters_dead_pids(self): assert "No running" in result.output -# --------------------------------------------------------------------------- -# STOP -# --------------------------------------------------------------------------- - - class TestStopCLI: """Tests for `dimos stop` command.""" diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 197539ef67..f9a89829d5 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -22,8 +22,8 @@ from dimos.core.stream import In, Out from dimos.core.testing import MockRobotClient from dimos.core.transport import LCMTransport, pLCMTransport -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.odometry import Odometry @@ -39,9 +39,6 @@ class Navigation(Module): @rpc def navigate_to(self, target: Vector3) -> bool: ... - def __init__(self) -> None: - super().__init__() - @rpc def start(self) -> None: def _odom(msg) -> None: diff --git a/dimos/core/test_daemon.py b/dimos/core/test_daemon.py index bd7c6b9ad8..f6dae51433 100644 --- a/dimos/core/test_daemon.py +++ b/dimos/core/test_daemon.py @@ -24,9 +24,6 @@ import pytest -# --------------------------------------------------------------------------- -# Registry tests -# --------------------------------------------------------------------------- from dimos.core import run_registry from dimos.core.run_registry import ( RunEntry, @@ -158,10 +155,6 @@ def test_port_conflict_no_false_positive(self, tmp_registry: Path): assert conflict is None -# --------------------------------------------------------------------------- -# Health check tests -# --------------------------------------------------------------------------- - from dimos.core.module_coordinator import ModuleCoordinator @@ -212,10 +205,6 @@ def test_partial_death(self): assert coord.health_check() is False -# --------------------------------------------------------------------------- -# Daemon tests -# --------------------------------------------------------------------------- - from dimos.core.daemon import daemonize, install_signal_handlers @@ -275,11 +264,6 @@ def test_signal_handler_tolerates_stop_error(self, tmp_registry: Path): assert not entry.registry_path.exists() -# --------------------------------------------------------------------------- -# dimos status tests -# --------------------------------------------------------------------------- - - class TestStatusCommand: """Tests for `dimos status` CLI command.""" @@ -327,11 +311,6 @@ def test_status_filters_dead(self, tmp_path, monkeypatch): assert len(entries) == 0 -# --------------------------------------------------------------------------- -# dimos stop tests -# --------------------------------------------------------------------------- - - class TestStopCommand: """Tests for `dimos stop` CLI command.""" diff --git a/dimos/core/test_e2e_daemon.py b/dimos/core/test_e2e_daemon.py index 7043d0384e..d8ac016faa 100644 --- a/dimos/core/test_e2e_daemon.py +++ b/dimos/core/test_e2e_daemon.py @@ -35,10 +35,6 @@ from dimos.core.stream import Out from dimos.robot.cli.dimos import main -# --------------------------------------------------------------------------- -# Lightweight test modules -# --------------------------------------------------------------------------- - class PingModule(Module): data: Out[str] @@ -54,11 +50,6 @@ def start(self): super().start() -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - @pytest.fixture(autouse=True) def _ci_env(monkeypatch): """Set CI=1 to skip sysctl interactive prompt — scoped per test, not module.""" @@ -114,11 +105,6 @@ def registry_entry(): entry.remove() -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - @pytest.mark.slow class TestDaemonE2E: """End-to-end daemon lifecycle with real workers.""" @@ -216,11 +202,6 @@ def test_stale_cleanup(self, coordinator, registry_entry): assert remaining[0].run_id == registry_entry.run_id -# --------------------------------------------------------------------------- -# E2E: CLI status + stop against real running blueprint -# --------------------------------------------------------------------------- - - @pytest.fixture() def live_blueprint(): """Build PingPong and register. Yields (coord, entry). Cleans up on teardown.""" diff --git a/dimos/core/test_mcp_integration.py b/dimos/core/test_mcp_integration.py index 543b9a7fbd..d7527e31f8 100644 --- a/dimos/core/test_mcp_integration.py +++ b/dimos/core/test_mcp_integration.py @@ -55,11 +55,6 @@ MCP_URL = f"http://localhost:{global_config.mcp_port}/mcp" -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - @pytest.fixture(autouse=True) def _ci_env(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("CI", "1") @@ -121,11 +116,6 @@ def _adapter() -> McpAdapter: return McpAdapter() -# --------------------------------------------------------------------------- -# Tests -- read-only against a shared MCP server -# --------------------------------------------------------------------------- - - @pytest.mark.slow class TestMCPLifecycle: """MCP server lifecycle: start -> respond -> stop -> dead.""" @@ -323,11 +313,6 @@ def test_agent_send_cli(self, mcp_shared: ModuleCoordinator) -> None: assert "hello from CLI" in result.output -# --------------------------------------------------------------------------- -# Tests -- lifecycle management (own setup/teardown per test) -# --------------------------------------------------------------------------- - - @pytest.mark.slow class TestDaemonMCPRecovery: """Test MCP recovery after daemon crashes and restarts.""" diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index d17775130e..e77b8f9a53 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -18,7 +18,6 @@ The echo script writes received CLI args to a temp file for assertions. """ -from dataclasses import dataclass import json from pathlib import Path import time @@ -59,7 +58,6 @@ def read_json_file(path: str) -> dict[str, str]: return result -@dataclass(kw_only=True) class StubNativeConfig(NativeModuleConfig): executable: str = _ECHO log_format: LogFormat = LogFormat.TEXT diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index a7c949b33a..fdea17d2a3 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -15,6 +15,7 @@ from collections.abc import Callable import threading import time +from typing import Any import pytest @@ -23,20 +24,20 @@ from dimos.core.stream import In from dimos.core.testing import MockRobotClient from dimos.core.transport import LCMTransport, pLCMTransport -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.odometry import Odometry class SubscriberBase(Module): - sub1_msgs: list[Odometry] = None - sub2_msgs: list[Odometry] = None + sub1_msgs: list[Odometry] + sub2_msgs: list[Odometry] - def __init__(self) -> None: + def __init__(self, **kwargs: Any) -> None: self.sub1_msgs = [] self.sub2_msgs = [] self._sub1_received = threading.Event() self._sub2_received = threading.Event() - super().__init__() + super().__init__(**kwargs) def _sub1_callback(self, msg) -> None: self.sub1_msgs.append(msg) diff --git a/dimos/core/test_worker.py b/dimos/core/test_worker.py index a5217f2dd6..021b2e21c4 100644 --- a/dimos/core/test_worker.py +++ b/dimos/core/test_worker.py @@ -17,10 +17,11 @@ import pytest from dimos.core.core import rpc +from dimos.core.global_config import global_config from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.core.worker_manager import WorkerManager -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 if TYPE_CHECKING: from dimos.core.resource_monitor.stats import WorkerStats @@ -99,7 +100,7 @@ def _create(n_workers): @pytest.mark.slow def test_worker_manager_basic(create_worker_manager): worker_manager = create_worker_manager(n_workers=2) - module = worker_manager.deploy(SimpleModule) + module = worker_manager.deploy(SimpleModule, global_config, {}) module.start() result = module.increment() @@ -117,8 +118,8 @@ def test_worker_manager_basic(create_worker_manager): @pytest.mark.slow def test_worker_manager_multiple_different_modules(create_worker_manager): worker_manager = create_worker_manager(n_workers=2) - module1 = worker_manager.deploy(SimpleModule) - module2 = worker_manager.deploy(AnotherModule) + module1 = worker_manager.deploy(SimpleModule, global_config, {}) + module2 = worker_manager.deploy(AnotherModule, global_config, {}) module1.start() module2.start() @@ -141,9 +142,9 @@ def test_worker_manager_parallel_deployment(create_worker_manager): worker_manager = create_worker_manager(n_workers=2) modules = worker_manager.deploy_parallel( [ - (SimpleModule, (), {}), - (AnotherModule, (), {}), - (ThirdModule, (), {}), + (SimpleModule, global_config, {}), + (AnotherModule, global_config, {}), + (ThirdModule, global_config, {}), ] ) @@ -175,8 +176,8 @@ def test_collect_stats(create_worker_manager): from dimos.core.resource_monitor.monitor import StatsMonitor manager = create_worker_manager(n_workers=2) - module1 = manager.deploy(SimpleModule) - module2 = manager.deploy(AnotherModule) + module1 = manager.deploy(SimpleModule, global_config, {}) + module2 = manager.deploy(AnotherModule, global_config, {}) module1.start() module2.start() @@ -219,8 +220,8 @@ def log_stats(self, coordinator, workers): @pytest.mark.slow def test_worker_pool_modules_share_workers(create_worker_manager): manager = create_worker_manager(n_workers=1) - module1 = manager.deploy(SimpleModule) - module2 = manager.deploy(AnotherModule) + module1 = manager.deploy(SimpleModule, global_config, {}) + module2 = manager.deploy(AnotherModule, global_config, {}) module1.start() module2.start() diff --git a/dimos/core/testing.py b/dimos/core/testing.py index 6431c09dbd..a128fc4767 100644 --- a/dimos/core/testing.py +++ b/dimos/core/testing.py @@ -14,15 +14,16 @@ from threading import Event, Thread import time +from typing import Any from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry -from dimos.utils.testing import SensorReplay +from dimos.utils.testing.replay import SensorReplay class MockRobotClient(Module): @@ -32,8 +33,8 @@ class MockRobotClient(Module): mov_msg_count = 0 - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._stop_event = Event() self._thread = None diff --git a/dimos/core/tests/demo_devex.py b/dimos/core/tests/demo_devex.py index b9ac1393d7..243c870fab 100644 --- a/dimos/core/tests/demo_devex.py +++ b/dimos/core/tests/demo_devex.py @@ -98,9 +98,6 @@ def main() -> None: print(" Simulating: OpenClaw agent using DimOS") print("=" * 60) - # --------------------------------------------------------------- - # Step 1: dimos run stress-test --daemon - # --------------------------------------------------------------- section("Step 1: dimos run stress-test --daemon") result = run_dimos("run", "stress-test", "--daemon", timeout=60) print(f" stdout: {result.stdout.strip()[:200]}") @@ -131,9 +128,6 @@ def main() -> None: print(" Cannot continue without MCP. Exiting.") sys.exit(1) - # --------------------------------------------------------------- - # Step 2: dimos status - # --------------------------------------------------------------- section("Step 2: dimos status") result = run_dimos("status") print(f" output: {result.stdout.strip()[:300]}") @@ -145,9 +139,6 @@ def main() -> None: p(f"Status unclear (exit={result.returncode})", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 3: dimos mcp list-tools - # --------------------------------------------------------------- section("Step 3: dimos mcp list-tools") result = run_dimos("mcp", "list-tools") if result.returncode == 0: @@ -167,9 +158,6 @@ def main() -> None: p(f"list-tools failed (exit={result.returncode}): {result.stdout[:100]}", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 4: dimos mcp call echo --arg message=hello - # --------------------------------------------------------------- section("Step 4: dimos mcp call echo --arg message=hello") result = run_dimos("mcp", "call", "echo", "--arg", "message=hello-from-devex-test") if result.returncode == 0 and "hello-from-devex-test" in result.stdout: @@ -178,9 +166,6 @@ def main() -> None: p(f"echo call failed (exit={result.returncode}): {result.stdout[:100]}", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 5: dimos mcp status - # --------------------------------------------------------------- section("Step 5: dimos mcp status") result = run_dimos("mcp", "status") if result.returncode == 0: @@ -196,9 +181,6 @@ def main() -> None: p(f"mcp status failed (exit={result.returncode})", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 6: dimos mcp modules - # --------------------------------------------------------------- section("Step 6: dimos mcp modules") result = run_dimos("mcp", "modules") if result.returncode == 0: @@ -213,9 +195,6 @@ def main() -> None: p(f"mcp modules failed (exit={result.returncode})", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 7: dimos agent-send "hello" - # --------------------------------------------------------------- section("Step 7: dimos agent-send 'what tools do you have?'") result = run_dimos("agent-send", "what tools do you have?") if result.returncode == 0: @@ -224,9 +203,6 @@ def main() -> None: p(f"agent-send failed (exit={result.returncode}): {result.stdout[:100]}", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 8: Check logs - # --------------------------------------------------------------- section("Step 8: Check per-run logs") log_base = os.path.expanduser("~/.local/state/dimos/logs") if os.path.isdir(log_base): @@ -257,9 +233,6 @@ def main() -> None: p(f"Log base dir not found: {log_base}", ok=False) failures += 1 - # --------------------------------------------------------------- - # Step 9: dimos stop - # --------------------------------------------------------------- section("Step 9: dimos stop") result = run_dimos("stop") print(f" output: {result.stdout.strip()[:200]}") @@ -272,9 +245,6 @@ def main() -> None: # Wait for shutdown time.sleep(2) - # --------------------------------------------------------------- - # Step 10: dimos status (verify stopped) - # --------------------------------------------------------------- section("Step 10: dimos status (verify stopped)") result = run_dimos("status") print(f" output: {result.stdout.strip()[:200]}") @@ -288,9 +258,6 @@ def main() -> None: p(f"Unexpected status after stop (exit={result.returncode})", ok=False) failures += 1 - # --------------------------------------------------------------- - # Summary - # --------------------------------------------------------------- print("\n" + "=" * 60) if failures == 0: print(" \u2705 FULL DEVELOPER EXPERIENCE TEST PASSED") diff --git a/dimos/core/worker.py b/dimos/core/worker.py index 3a98e6b7ba..dca561f16c 100644 --- a/dimos/core/worker.py +++ b/dimos/core/worker.py @@ -15,19 +15,19 @@ import logging import multiprocessing +from multiprocessing.connection import Connection import os import sys import threading import traceback from typing import TYPE_CHECKING, Any +from dimos.core.global_config import GlobalConfig, global_config from dimos.utils.logging_config import setup_logger from dimos.utils.sequential_ids import SequentialIds if TYPE_CHECKING: - from multiprocessing.connection import Connection - - from dimos.core.module import ModuleT + from dimos.core.module import ModuleBase logger = setup_logger() @@ -75,7 +75,7 @@ class Actor: def __init__( self, conn: Connection | None, - module_class: type[ModuleT], + module_class: type[ModuleBase], worker_id: int, module_id: int = 0, lock: threading.Lock | None = None, @@ -143,8 +143,6 @@ def reset_forkserver_context() -> None: class Worker: - """Generic worker process that can host multiple modules.""" - def __init__(self) -> None: self._lock = threading.Lock() self._modules: dict[int, Actor] = {} @@ -198,14 +196,15 @@ def start_process(self) -> None: def deploy_module( self, - module_class: type[ModuleT], - args: tuple[Any, ...] = (), - kwargs: dict[Any, Any] | None = None, + module_class: type[ModuleBase], + global_config: GlobalConfig = global_config, + kwargs: dict[str, Any] | None = None, ) -> Actor: if self._conn is None: raise RuntimeError("Worker process not started") kwargs = kwargs or {} + kwargs["g"] = global_config module_id = _module_ids.next() # Send deploy_module request to the worker process @@ -213,7 +212,6 @@ def deploy_module( "type": "deploy_module", "module_id": module_id, "module_class": module_class, - "args": args, "kwargs": kwargs, } with self._lock: @@ -293,10 +291,7 @@ def _suppress_console_output() -> None: ] -def _worker_entrypoint( - conn: Connection, - worker_id: int, -) -> None: +def _worker_entrypoint(conn: Connection, worker_id: int) -> None: instances: dict[int, Any] = {} try: @@ -346,10 +341,9 @@ def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) -> if req_type == "deploy_module": module_class = request["module_class"] - args = request.get("args", ()) - kwargs = request.get("kwargs", {}) + kwargs = request["kwargs"] module_id = request["module_id"] - instance = module_class(*args, **kwargs) + instance = module_class(**kwargs) instances[module_id] = instance response["result"] = module_id diff --git a/dimos/core/worker_manager.py b/dimos/core/worker_manager.py index 2b41f634e8..4cd5eec8d7 100644 --- a/dimos/core/worker_manager.py +++ b/dimos/core/worker_manager.py @@ -14,16 +14,16 @@ from __future__ import annotations +from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any +from typing import Any +from dimos.core.global_config import GlobalConfig +from dimos.core.module import ModuleBase, ModuleSpec from dimos.core.rpc_client import RPCClient from dimos.core.worker import Worker from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.core.module import ModuleT - logger = setup_logger() @@ -47,7 +47,9 @@ def start(self) -> None: def _select_worker(self) -> Worker: return min(self._workers, key=lambda w: w.module_count) - def deploy(self, module_class: type[ModuleT], *args: Any, **kwargs: Any) -> RPCClient: + def deploy( + self, module_class: type[ModuleBase], global_config: GlobalConfig, kwargs: dict[str, Any] + ) -> RPCClient: if self._closed: raise RuntimeError("WorkerManager is closed") @@ -56,12 +58,10 @@ def deploy(self, module_class: type[ModuleT], *args: Any, **kwargs: Any) -> RPCC self.start() worker = self._select_worker() - actor = worker.deploy_module(module_class, args=args, kwargs=kwargs) + actor = worker.deploy_module(module_class, global_config, kwargs=kwargs) return RPCClient(actor, module_class) - def deploy_parallel( - self, module_specs: list[tuple[type[ModuleT], tuple[Any, ...], dict[Any, Any]]] - ) -> list[RPCClient]: + def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient]: if self._closed: raise RuntimeError("WorkerManager is closed") @@ -72,17 +72,17 @@ def deploy_parallel( # Pre-assign workers sequentially (so least-loaded accounting is # correct), then deploy concurrently via threads. The per-worker lock # serializes deploys that land on the same worker process. - assignments: list[tuple[Worker, type[ModuleT], tuple[Any, ...], dict[Any, Any]]] = [] - for module_class, args, kwargs in module_specs: + assignments: list[tuple[Worker, type[ModuleBase], GlobalConfig, dict[str, Any]]] = [] + for module_class, global_config, kwargs in module_specs: worker = self._select_worker() worker.reserve_slot() - assignments.append((worker, module_class, args, kwargs)) + assignments.append((worker, module_class, global_config, kwargs)) def _deploy( - item: tuple[Worker, type[ModuleT], tuple[Any, ...], dict[Any, Any]], + item: tuple[Worker, type[ModuleBase], GlobalConfig, dict[str, Any]], ) -> RPCClient: - worker, module_class, args, kwargs = item - actor = worker.deploy_module(module_class, args=args, kwargs=kwargs) + worker, module_class, global_config, kwargs = item + actor = worker.deploy_module(module_class, global_config=global_config, kwargs=kwargs) return RPCClient(actor, module_class) with ThreadPoolExecutor(max_workers=len(assignments)) as pool: diff --git a/dimos/e2e_tests/conftest.py b/dimos/e2e_tests/conftest.py index 51ab7c2c18..12f4a674a6 100644 --- a/dimos/e2e_tests/conftest.py +++ b/dimos/e2e_tests/conftest.py @@ -22,7 +22,8 @@ from dimos.e2e_tests.conf_types import StartPersonTrack from dimos.e2e_tests.dimos_cli_call import DimosCliCall from dimos.e2e_tests.lcm_spy import LcmSpy -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.msgs.std_msgs.Bool import Bool from dimos.simulation.mujoco.person_on_track import PersonTrackPublisher diff --git a/dimos/e2e_tests/lcm_spy.py b/dimos/e2e_tests/lcm_spy.py index 9efed09d5e..030591f52e 100644 --- a/dimos/e2e_tests/lcm_spy.py +++ b/dimos/e2e_tests/lcm_spy.py @@ -22,8 +22,8 @@ import lcm -from dimos.msgs import DimosMsg -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.protocol import DimosMsg from dimos.protocol.service.lcmservice import LCMService diff --git a/dimos/e2e_tests/test_control_coordinator.py b/dimos/e2e_tests/test_control_coordinator.py index 5bb7a096f7..80b63c529f 100644 --- a/dimos/e2e_tests/test_control_coordinator.py +++ b/dimos/e2e_tests/test_control_coordinator.py @@ -24,8 +24,10 @@ from dimos.control.coordinator import ControlCoordinator from dimos.core.rpc_client import RPCClient -from dimos.msgs.sensor_msgs import JointState -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint, TrajectoryState +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint +from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState @pytest.mark.skipif_in_ci diff --git a/dimos/e2e_tests/test_simulation_module.py b/dimos/e2e_tests/test_simulation_module.py index b5902ad7e2..e08183fc24 100644 --- a/dimos/e2e_tests/test_simulation_module.py +++ b/dimos/e2e_tests/test_simulation_module.py @@ -16,7 +16,9 @@ import pytest -from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState def _positions_within_tolerance( diff --git a/dimos/exceptions/__init__.py b/dimos/exceptions/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/hardware/__init__.py b/dimos/hardware/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/hardware/drive_trains/__init__.py b/dimos/hardware/drive_trains/__init__.py deleted file mode 100644 index c6e843feea..0000000000 --- a/dimos/hardware/drive_trains/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Drive train hardware adapters for velocity-commanded platforms.""" diff --git a/dimos/hardware/drive_trains/flowbase/__init__.py b/dimos/hardware/drive_trains/flowbase/__init__.py deleted file mode 100644 index 25f95e399c..0000000000 --- a/dimos/hardware/drive_trains/flowbase/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""FlowBase twist base adapter for holonomic base control via Portal RPC.""" diff --git a/dimos/hardware/drive_trains/flowbase/adapter.py b/dimos/hardware/drive_trains/flowbase/adapter.py index 5b5563792d..ec96365c78 100644 --- a/dimos/hardware/drive_trains/flowbase/adapter.py +++ b/dimos/hardware/drive_trains/flowbase/adapter.py @@ -62,10 +62,6 @@ def __init__(self, dof: int = 3, address: str | None = None, **_: object) -> Non # Last commanded velocities (in standard frame, before negation) self._last_velocities = [0.0, 0.0, 0.0] - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Connect to FlowBase controller via Portal RPC.""" try: @@ -98,18 +94,10 @@ def is_connected(self) -> bool: """Check if connected to FlowBase.""" return self._connected - # ========================================================================= - # Info - # ========================================================================= - def get_dof(self) -> int: """FlowBase is always 3 DOF (vx, vy, wz).""" return 3 - # ========================================================================= - # State Reading - # ========================================================================= - def read_velocities(self) -> list[float]: """Return last commanded velocities (FlowBase doesn't report actual).""" with self._lock: @@ -134,10 +122,6 @@ def read_odometry(self) -> list[float] | None: logger.error(f"Error reading FlowBase odometry: {e}") return None - # ========================================================================= - # Control - # ========================================================================= - def write_velocities(self, velocities: list[float]) -> bool: """Send velocity command to FlowBase. @@ -165,10 +149,6 @@ def write_stop(self) -> bool: return False return self._send_velocity(0.0, 0.0, 0.0) - # ========================================================================= - # Enable/Disable - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable/disable the platform (FlowBase is always enabled when connected).""" self._enabled = enable @@ -178,10 +158,6 @@ def read_enabled(self) -> bool: """Check if platform is enabled.""" return self._enabled - # ========================================================================= - # Internal - # ========================================================================= - def _send_velocity(self, vx: float, vy: float, wz: float) -> bool: """Send raw velocity to FlowBase via Portal RPC.""" try: diff --git a/dimos/hardware/drive_trains/mock/__init__.py b/dimos/hardware/drive_trains/mock/__init__.py deleted file mode 100644 index 9b6f630040..0000000000 --- a/dimos/hardware/drive_trains/mock/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Mock twist base adapter for testing without hardware. - -Usage: - >>> from dimos.hardware.drive_trains.mock import MockTwistBaseAdapter - >>> adapter = MockTwistBaseAdapter(dof=3) - >>> adapter.connect() - True - >>> adapter.write_velocities([0.5, 0.0, 0.1]) - True - >>> adapter.read_velocities() - [0.5, 0.0, 0.1] -""" - -from dimos.hardware.drive_trains.mock.adapter import MockTwistBaseAdapter - -__all__ = ["MockTwistBaseAdapter"] diff --git a/dimos/hardware/drive_trains/mock/adapter.py b/dimos/hardware/drive_trains/mock/adapter.py index 2091ec59d0..d6131305e6 100644 --- a/dimos/hardware/drive_trains/mock/adapter.py +++ b/dimos/hardware/drive_trains/mock/adapter.py @@ -48,10 +48,6 @@ def __init__(self, dof: int = 3, **_: object) -> None: self._enabled = False self._connected = False - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Simulate connection.""" self._connected = True @@ -65,18 +61,10 @@ def is_connected(self) -> bool: """Check mock connection status.""" return self._connected - # ========================================================================= - # Info - # ========================================================================= - def get_dof(self) -> int: """Return DOF.""" return self._dof - # ========================================================================= - # State Reading - # ========================================================================= - def read_velocities(self) -> list[float]: """Return mock velocities.""" return self._velocities.copy() @@ -87,10 +75,6 @@ def read_odometry(self) -> list[float] | None: return None return self._odometry.copy() - # ========================================================================= - # Control - # ========================================================================= - def write_velocities(self, velocities: list[float]) -> bool: """Set mock velocities.""" if len(velocities) != self._dof: @@ -103,10 +87,6 @@ def write_stop(self) -> bool: self._velocities = [0.0] * self._dof return True - # ========================================================================= - # Enable/Disable - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable/disable mock platform.""" self._enabled = enable @@ -116,10 +96,6 @@ def read_enabled(self) -> bool: """Check mock enable state.""" return self._enabled - # ========================================================================= - # Test Helpers (not part of Protocol) - # ========================================================================= - def set_odometry(self, odometry: list[float] | None) -> None: """Set odometry directly for testing.""" self._odometry = list(odometry) if odometry is not None else None diff --git a/dimos/hardware/drive_trains/spec.py b/dimos/hardware/drive_trains/spec.py index 0b288edfd4..1380ef1fa9 100644 --- a/dimos/hardware/drive_trains/spec.py +++ b/dimos/hardware/drive_trains/spec.py @@ -35,8 +35,6 @@ class TwistBaseAdapter(Protocol): - Angle: radians """ - # --- Connection --- - def connect(self) -> bool: """Connect to hardware. Returns True on success.""" ... @@ -49,14 +47,10 @@ def is_connected(self) -> bool: """Check if connected.""" ... - # --- Info --- - def get_dof(self) -> int: """Get number of velocity DOFs (e.g., 3 for holonomic, 2 for differential).""" ... - # --- State Reading --- - def read_velocities(self) -> list[float]: """Read current velocities in virtual joint order (m/s or rad/s).""" ... @@ -69,8 +63,6 @@ def read_odometry(self) -> list[float] | None: """ ... - # --- Control --- - def write_velocities(self, velocities: list[float]) -> bool: """Command velocities in virtual joint order. Returns success.""" ... @@ -79,8 +71,6 @@ def write_stop(self) -> bool: """Stop all motion immediately (zero velocities).""" ... - # --- Enable/Disable --- - def write_enable(self, enable: bool) -> bool: """Enable or disable the platform. Returns success.""" ... diff --git a/dimos/hardware/end_effectors/__init__.py b/dimos/hardware/end_effectors/__init__.py deleted file mode 100644 index 9a7aa9759a..0000000000 --- a/dimos/hardware/end_effectors/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .end_effector import EndEffector - -__all__ = ["EndEffector"] diff --git a/dimos/hardware/manipulators/__init__.py b/dimos/hardware/manipulators/__init__.py deleted file mode 100644 index 58986c9211..0000000000 --- a/dimos/hardware/manipulators/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Manipulator drivers for robotic arms. - -Architecture: Protocol-based adapters for different manipulator hardware. -- spec.py: ManipulatorAdapter Protocol and shared types -- xarm/: XArm adapter -- piper/: Piper adapter -- mock/: Mock adapter for testing - -Usage: - >>> from dimos.hardware.manipulators.xarm import XArm - >>> arm = XArm(ip="192.168.1.185") - >>> arm.start() - >>> arm.enable_servos() - >>> arm.move_joint([0, 0, 0, 0, 0, 0]) - -Testing: - >>> from dimos.hardware.manipulators.xarm import XArm - >>> from dimos.hardware.manipulators.mock import MockAdapter - >>> arm = XArm(adapter=MockAdapter()) - >>> arm.start() # No hardware needed! -""" - -from dimos.hardware.manipulators.spec import ( - ControlMode, - DriverStatus, - JointLimits, - ManipulatorAdapter, - ManipulatorInfo, -) - -__all__ = [ - "ControlMode", - "DriverStatus", - "JointLimits", - "ManipulatorAdapter", - "ManipulatorInfo", -] diff --git a/dimos/hardware/manipulators/mock/__init__.py b/dimos/hardware/manipulators/mock/__init__.py deleted file mode 100644 index 63be6f7e98..0000000000 --- a/dimos/hardware/manipulators/mock/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Mock adapter for testing manipulator drivers without hardware. - -Usage: - >>> from dimos.hardware.manipulators.xarm import XArm - >>> from dimos.hardware.manipulators.mock import MockAdapter - >>> arm = XArm(adapter=MockAdapter()) - >>> arm.start() # No hardware needed! - >>> arm.move_joint([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]) - >>> assert arm.adapter.read_joint_positions() == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] -""" - -from dimos.hardware.manipulators.mock.adapter import MockAdapter - -__all__ = ["MockAdapter"] diff --git a/dimos/hardware/manipulators/mock/adapter.py b/dimos/hardware/manipulators/mock/adapter.py index ff299669f7..53c53c722d 100644 --- a/dimos/hardware/manipulators/mock/adapter.py +++ b/dimos/hardware/manipulators/mock/adapter.py @@ -66,10 +66,6 @@ def __init__(self, dof: int = 6, **_: object) -> None: self._error_code: int = 0 self._error_message: str = "" - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Simulate connection.""" self._connected = True @@ -83,10 +79,6 @@ def is_connected(self) -> bool: """Check mock connection status.""" return self._connected - # ========================================================================= - # Info - # ========================================================================= - def get_info(self) -> ManipulatorInfo: """Return mock info.""" return ManipulatorInfo( @@ -109,10 +101,6 @@ def get_limits(self) -> JointLimits: velocity_max=[1.0] * self._dof, ) - # ========================================================================= - # Control Mode - # ========================================================================= - def set_control_mode(self, mode: ControlMode) -> bool: """Set mock control mode.""" self._control_mode = mode @@ -122,10 +110,6 @@ def get_control_mode(self) -> ControlMode: """Get mock control mode.""" return self._control_mode - # ========================================================================= - # State Reading - # ========================================================================= - def read_joint_positions(self) -> list[float]: """Return mock joint positions.""" return self._positions.copy() @@ -151,10 +135,6 @@ def read_error(self) -> tuple[int, str]: """Return mock error.""" return self._error_code, self._error_message - # ========================================================================= - # Motion Control - # ========================================================================= - def write_joint_positions( self, positions: list[float], @@ -178,10 +158,6 @@ def write_stop(self) -> bool: self._velocities = [0.0] * self._dof return True - # ========================================================================= - # Servo Control - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable/disable mock servos.""" self._enabled = enable @@ -197,10 +173,6 @@ def write_clear_errors(self) -> bool: self._error_message = "" return True - # ========================================================================= - # Cartesian Control (Optional) - # ========================================================================= - def read_cartesian_position(self) -> dict[str, float] | None: """Return mock cartesian position.""" return self._cartesian_position.copy() @@ -214,10 +186,6 @@ def write_cartesian_position( self._cartesian_position.update(pose) return True - # ========================================================================= - # Gripper (Optional) - # ========================================================================= - def read_gripper_position(self) -> float | None: """Return mock gripper position.""" return self._gripper_position @@ -227,18 +195,10 @@ def write_gripper_position(self, position: float) -> bool: self._gripper_position = position return True - # ========================================================================= - # Force/Torque (Optional) - # ========================================================================= - def read_force_torque(self) -> list[float] | None: """Return mock F/T sensor data (not supported in mock).""" return None - # ========================================================================= - # Test Helpers (not part of Protocol) - # ========================================================================= - def set_error(self, code: int, message: str) -> None: """Inject an error for testing error handling.""" self._error_code = code diff --git a/dimos/hardware/manipulators/piper/__init__.py b/dimos/hardware/manipulators/piper/__init__.py deleted file mode 100644 index bfeb89b1c0..0000000000 --- a/dimos/hardware/manipulators/piper/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Piper manipulator hardware adapter. - -Usage: - >>> from dimos.hardware.manipulators.piper import PiperAdapter - >>> adapter = PiperAdapter(can_port="can0") - >>> adapter.connect() - >>> positions = adapter.read_joint_positions() -""" - -from dimos.hardware.manipulators.piper.adapter import PiperAdapter - -__all__ = ["PiperAdapter"] diff --git a/dimos/hardware/manipulators/piper/adapter.py b/dimos/hardware/manipulators/piper/adapter.py index 68b5769a95..49ed68bcf9 100644 --- a/dimos/hardware/manipulators/piper/adapter.py +++ b/dimos/hardware/manipulators/piper/adapter.py @@ -75,10 +75,6 @@ def __init__( self._enabled: bool = False self._control_mode: ControlMode = ControlMode.POSITION - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Connect to Piper via CAN bus.""" try: @@ -139,10 +135,6 @@ def is_connected(self) -> bool: except Exception: return False - # ========================================================================= - # Info - # ========================================================================= - def get_info(self) -> ManipulatorInfo: """Get Piper information.""" firmware_version = None @@ -176,10 +168,6 @@ def get_limits(self) -> JointLimits: velocity_max=max_vel, ) - # ========================================================================= - # Control Mode - # ========================================================================= - def set_control_mode(self, mode: ControlMode) -> bool: """Set Piper control mode via MotionCtrl_2.""" if not self._sdk: @@ -207,10 +195,6 @@ def get_control_mode(self) -> ControlMode: """Get current control mode.""" return self._control_mode - # ========================================================================= - # State Reading - # ========================================================================= - def read_joint_positions(self) -> list[float]: """Read joint positions (Piper units -> radians).""" if not self._sdk: @@ -295,10 +279,6 @@ def read_error(self) -> tuple[int, str]: return 0, "" - # ========================================================================= - # Motion Control (Joint Space) - # ========================================================================= - def write_joint_positions( self, positions: list[float], @@ -366,10 +346,6 @@ def write_stop(self) -> bool: # Fallback: disable arm return self.write_enable(False) - # ========================================================================= - # Servo Control - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable or disable servos.""" if not self._sdk: @@ -427,10 +403,6 @@ def write_clear_errors(self) -> bool: time.sleep(0.1) return self.write_enable(True) - # ========================================================================= - # Cartesian Control (Optional) - # ========================================================================= - def read_cartesian_position(self) -> dict[str, float] | None: """Read end-effector pose. @@ -470,10 +442,6 @@ def write_cartesian_position( # Cartesian control not commonly supported in Piper SDK return False - # ========================================================================= - # Gripper (Optional) - # ========================================================================= - def read_gripper_position(self) -> float | None: """Read gripper position (percentage -> meters).""" if not self._sdk: @@ -508,10 +476,6 @@ def write_gripper_position(self, position: float) -> bool: return False - # ========================================================================= - # Force/Torque Sensor (Optional) - # ========================================================================= - def read_force_torque(self) -> list[float] | None: """Read F/T sensor data. diff --git a/dimos/hardware/manipulators/registry.py b/dimos/hardware/manipulators/registry.py index 65dbe74b50..9e63fa349b 100644 --- a/dimos/hardware/manipulators/registry.py +++ b/dimos/hardware/manipulators/registry.py @@ -33,7 +33,6 @@ import importlib import logging -import pkgutil from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -78,19 +77,25 @@ def available(self) -> list[str]: def discover(self) -> None: """Discover and register adapters from subpackages. + Scans for subdirectories containing an adapter.py module. Can be called multiple times to pick up newly added adapters. """ - import dimos.hardware.manipulators as pkg + from pathlib import Path - for _, name, ispkg in pkgutil.iter_modules(pkg.__path__): - if not ispkg: + pkg_dir = Path(__file__).parent + for child in sorted(pkg_dir.iterdir()): + if not child.is_dir() or child.name.startswith(("_", ".")): + continue + if not (child / "adapter.py").exists(): continue try: - module = importlib.import_module(f"dimos.hardware.manipulators.{name}.adapter") + module = importlib.import_module( + f"dimos.hardware.manipulators.{child.name}.adapter" + ) if hasattr(module, "register"): module.register(self) except ImportError as e: - logger.debug(f"Skipping adapter {name}: {e}") + logger.debug(f"Skipping adapter {child.name}: {e}") adapter_registry = AdapterRegistry() diff --git a/dimos/hardware/manipulators/spec.py b/dimos/hardware/manipulators/spec.py index ff4d38c54f..868b714bfa 100644 --- a/dimos/hardware/manipulators/spec.py +++ b/dimos/hardware/manipulators/spec.py @@ -26,11 +26,9 @@ from enum import Enum from typing import Protocol, runtime_checkable -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 - -# ============================================================================ -# SHARED TYPES -# ============================================================================ +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 class DriverStatus(Enum): @@ -83,11 +81,6 @@ def default_base_transform() -> Transform: ) -# ============================================================================ -# ADAPTER PROTOCOL -# ============================================================================ - - @runtime_checkable class ManipulatorAdapter(Protocol): """Protocol for hardware-specific IO. @@ -100,8 +93,6 @@ class ManipulatorAdapter(Protocol): - Force: Newtons """ - # --- Connection --- - def connect(self) -> bool: """Connect to hardware. Returns True on success.""" ... @@ -114,8 +105,6 @@ def is_connected(self) -> bool: """Check if connected.""" ... - # --- Info --- - def get_info(self) -> ManipulatorInfo: """Get manipulator info (vendor, model, DOF).""" ... @@ -128,8 +117,6 @@ def get_limits(self) -> JointLimits: """Get joint limits.""" ... - # --- Control Mode --- - def set_control_mode(self, mode: ControlMode) -> bool: """Set control mode (position, velocity, torque, cartesian, etc). @@ -152,8 +139,6 @@ def get_control_mode(self) -> ControlMode: """ ... - # --- State Reading --- - def read_joint_positions(self) -> list[float]: """Read current joint positions (radians).""" ... @@ -174,8 +159,6 @@ def read_error(self) -> tuple[int, str]: """Read error code and message. (0, '') means no error.""" ... - # --- Motion Control (Joint Space) --- - def write_joint_positions( self, positions: list[float], @@ -192,8 +175,6 @@ def write_stop(self) -> bool: """Stop all motion immediately.""" ... - # --- Servo Control --- - def write_enable(self, enable: bool) -> bool: """Enable or disable servos. Returns success.""" ... @@ -206,7 +187,6 @@ def write_clear_errors(self) -> bool: """Clear error state. Returns success.""" ... - # --- Optional: Cartesian Control --- # Return None/False if not supported def read_cartesian_position(self) -> dict[str, float] | None: @@ -234,8 +214,6 @@ def write_cartesian_position( """ ... - # --- Optional: Gripper --- - def read_gripper_position(self) -> float | None: """Read gripper position (meters). None if no gripper.""" ... @@ -244,8 +222,6 @@ def write_gripper_position(self, position: float) -> bool: """Command gripper position. False if no gripper.""" ... - # --- Optional: Force/Torque Sensor --- - def read_force_torque(self) -> list[float] | None: """Read F/T sensor [fx, fy, fz, tx, ty, tz]. None if no sensor.""" ... diff --git a/dimos/hardware/manipulators/xarm/__init__.py b/dimos/hardware/manipulators/xarm/__init__.py deleted file mode 100644 index 8bcab667c1..0000000000 --- a/dimos/hardware/manipulators/xarm/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""XArm manipulator hardware adapter. - -Usage: - >>> from dimos.hardware.manipulators.xarm import XArmAdapter - >>> adapter = XArmAdapter(ip="192.168.1.185", dof=6) - >>> adapter.connect() - >>> positions = adapter.read_joint_positions() -""" - -from dimos.hardware.manipulators.xarm.adapter import XArmAdapter - -__all__ = ["XArmAdapter"] diff --git a/dimos/hardware/manipulators/xarm/adapter.py b/dimos/hardware/manipulators/xarm/adapter.py index 80cc8edb38..3e24c530d1 100644 --- a/dimos/hardware/manipulators/xarm/adapter.py +++ b/dimos/hardware/manipulators/xarm/adapter.py @@ -64,10 +64,6 @@ def __init__(self, address: str, dof: int = 6, **_: object) -> None: self._control_mode: ControlMode = ControlMode.POSITION self._gripper_enabled: bool = False - # ========================================================================= - # Connection - # ========================================================================= - def connect(self) -> bool: """Connect to XArm via TCP/IP.""" try: @@ -98,10 +94,6 @@ def is_connected(self) -> bool: """Check if connected to XArm.""" return self._arm is not None and self._arm.connected - # ========================================================================= - # Info - # ========================================================================= - def get_info(self) -> ManipulatorInfo: """Get XArm information.""" return ManipulatorInfo( @@ -124,10 +116,6 @@ def get_limits(self) -> JointLimits: velocity_max=[math.pi] * self._dof, # ~180 deg/s ) - # ========================================================================= - # Control Mode - # ========================================================================= - def set_control_mode(self, mode: ControlMode) -> bool: """Set XArm control mode. @@ -161,10 +149,6 @@ def get_control_mode(self) -> ControlMode: """Get current control mode.""" return self._control_mode - # ========================================================================= - # State Reading - # ========================================================================= - def read_joint_positions(self) -> list[float]: """Read joint positions (degrees -> radians).""" if not self._arm: @@ -214,10 +198,6 @@ def read_error(self) -> tuple[int, str]: return 0, "" return code, f"XArm error {code}" - # ========================================================================= - # Motion Control (Joint Space) - # ========================================================================= - def write_joint_positions( self, positions: list[float], @@ -263,10 +243,6 @@ def write_stop(self) -> bool: code: int = self._arm.emergency_stop() return code == 0 - # ========================================================================= - # Servo Control - # ========================================================================= - def write_enable(self, enable: bool) -> bool: """Enable or disable servos.""" if not self._arm: @@ -289,10 +265,6 @@ def write_clear_errors(self) -> bool: code: int = self._arm.clean_error() return code == 0 - # ========================================================================= - # Cartesian Control (Optional) - # ========================================================================= - def read_cartesian_position(self) -> dict[str, float] | None: """Read end-effector pose (mm -> meters, degrees -> radians).""" if not self._arm: @@ -331,10 +303,6 @@ def write_cartesian_position( ) return code == 0 - # ========================================================================= - # Gripper (Optional) - # ========================================================================= - def read_gripper_position(self) -> float | None: """Read gripper position (mm -> meters).""" if not self._arm: @@ -359,10 +327,6 @@ def write_gripper_position(self, position: float) -> bool: code: int = self._arm.set_gripper_position(pos_mm, wait=False) return code == 0 - # ========================================================================= - # Force/Torque Sensor (Optional) - # ========================================================================= - def read_force_torque(self) -> list[float] | None: """Read F/T sensor data if available.""" if not self._arm: diff --git a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py index 9161185d50..c723cab130 100644 --- a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py +++ b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py @@ -14,18 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import logging import sys import threading import time +from typing import Any import numpy as np from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.utils.logging_config import setup_logger # Add system path for gi module if needed @@ -43,28 +43,22 @@ Gst.init(None) -@dataclass class Config(ModuleConfig): frame_id: str = "camera" + host: str = "localhost" + port: int = 5000 + timestamp_offset: float = 0.0 + reconnect_interval: float = 5.0 -class GstreamerCameraModule(Module): +class GstreamerCameraModule(Module[Config]): """Module that captures frames from a remote camera using GStreamer TCP with absolute timestamps.""" default_config = Config - config: Config video: Out[Image] - def __init__( # type: ignore[no-untyped-def] - self, - host: str = "localhost", - port: int = 5000, - timestamp_offset: float = 0.0, - reconnect_interval: float = 5.0, - *args, - **kwargs, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize the GStreamer TCP camera module. Args: @@ -74,10 +68,10 @@ def __init__( # type: ignore[no-untyped-def] timestamp_offset: Offset to add to timestamps (useful for clock synchronization) reconnect_interval: Seconds to wait before attempting reconnection """ - self.host = host - self.port = port - self.timestamp_offset = timestamp_offset - self.reconnect_interval = reconnect_interval + super().__init__(**kwargs) + self.host = self.config.host + self.port = self.config.port + self.reconnect_interval = self.config.reconnect_interval self.pipeline = None self.appsink = None @@ -88,7 +82,6 @@ def __init__( # type: ignore[no-untyped-def] self.frame_count = 0 self.last_log_time = time.time() self.reconnect_timer_id = None - super().__init__(**kwargs) @rpc def start(self) -> None: @@ -257,7 +250,7 @@ def _on_new_sample(self, appsink): # type: ignore[no-untyped-def] if buffer.pts != Gst.CLOCK_TIME_NONE: # Convert nanoseconds to seconds and add offset # This is the absolute time from when the frame was captured - timestamp = (buffer.pts / 1e9) + self.timestamp_offset + timestamp = (buffer.pts / 1e9) + self.config.timestamp_offset # Skip frames with invalid timestamps (before year 2000) # This filters out initial gray frames with relative timestamps diff --git a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera_test_script.py b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera_test_script.py index 8785a9260b..a18d52fbb0 100755 --- a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera_test_script.py +++ b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera_test_script.py @@ -21,8 +21,8 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.transport import LCMTransport from dimos.hardware.sensors.camera.gstreamer.gstreamer_camera import GstreamerCameraModule -from dimos.msgs.sensor_msgs import Image -from dimos.protocol import pubsub +from dimos.msgs.sensor_msgs.Image import Image +from dimos.protocol.pubsub.impl import lcmpubsub as _lcm logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -59,7 +59,7 @@ def main() -> None: logging.getLogger().setLevel(logging.DEBUG) # Initialize LCM - pubsub.lcm.autoconf() # type: ignore[attr-defined] + _lcm.autoconf() # type: ignore[attr-defined] # Start dimos dimos = ModuleCoordinator() diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 11821d4724..e0d0b3407e 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -13,21 +13,22 @@ # limitations under the License. from collections.abc import Callable -from dataclasses import dataclass, field import time from typing import Any +from pydantic import Field import reactivex as rx from dimos.agents.annotation import skill from dimos.core.blueprints import autoconnect from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware from dimos.hardware.sensors.camera.webcam import Webcam -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception @@ -43,10 +44,9 @@ def default_transform() -> Transform: ) -@dataclass class CameraModuleConfig(ModuleConfig): frame_id: str = "camera_link" - transform: Transform | None = field(default_factory=default_transform) + transform: Transform | None = Field(default_factory=default_transform) hardware: Callable[[], CameraHardware[Any]] | CameraHardware[Any] = Webcam frequency: float = 0.0 # Hz, 0 means no limit @@ -55,16 +55,9 @@ class CameraModule(Module[CameraModuleConfig], perception.Camera): color_image: Out[Image] camera_info: Out[CameraInfo] - hardware: CameraHardware[Any] - - config: CameraModuleConfig default_config = CameraModuleConfig - _global_config: GlobalConfig - - def __init__(self, *args: Any, cfg: GlobalConfig = global_config, **kwargs: Any) -> None: - self._global_config = cfg - self._latest_image: Image | None = None - super().__init__(*args, **kwargs) + hardware: CameraHardware[Any] + _latest_image: Image | None = None @rpc def start(self) -> None: diff --git a/dimos/hardware/sensors/camera/realsense/__init__.py b/dimos/hardware/sensors/camera/realsense/__init__.py deleted file mode 100644 index 58f519a12e..0000000000 --- a/dimos/hardware/sensors/camera/realsense/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from dimos.hardware.sensors.camera.realsense.camera import ( - RealSenseCamera, - RealSenseCameraConfig, - realsense_camera, - ) - -__all__ = ["RealSenseCamera", "RealSenseCameraConfig", "realsense_camera"] - - -def __getattr__(name: str) -> object: - if name in __all__: - from dimos.hardware.sensors.camera.realsense.camera import ( - RealSenseCamera, - RealSenseCameraConfig, - realsense_camera, - ) - - globals().update( - RealSenseCamera=RealSenseCamera, - RealSenseCameraConfig=RealSenseCameraConfig, - realsense_camera=realsense_camera, - ) - return globals()[name] - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dimos/hardware/sensors/camera/realsense/camera.py b/dimos/hardware/sensors/camera/realsense/camera.py index f34b9a2881..23bc19cdad 100644 --- a/dimos/hardware/sensors/camera/realsense/camera.py +++ b/dimos/hardware/sensors/camera/realsense/camera.py @@ -15,13 +15,13 @@ from __future__ import annotations import atexit -from dataclasses import dataclass, field import threading import time from typing import TYPE_CHECKING import cv2 import numpy as np +from pydantic import Field import reactivex as rx from scipy.spatial.transform import Rotation # type: ignore[import-untyped] @@ -35,8 +35,10 @@ DepthCameraConfig, DepthCameraHardware, ) -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.foxglove_bridge import FoxgloveBridge @@ -55,14 +57,13 @@ def default_base_transform() -> Transform: ) -@dataclass class RealSenseCameraConfig(ModuleConfig, DepthCameraConfig): width: int = 848 height: int = 480 fps: int = 15 camera_name: str = "camera" base_frame_id: str = "base_link" - base_transform: Transform | None = field(default_factory=default_base_transform) + base_transform: Transform | None = Field(default_factory=default_base_transform) align_depth_to_color: bool = True enable_depth: bool = True enable_pointcloud: bool = False @@ -71,14 +72,13 @@ class RealSenseCameraConfig(ModuleConfig, DepthCameraConfig): serial_number: str | None = None -class RealSenseCamera(DepthCameraHardware, Module, perception.DepthCamera): +class RealSenseCamera(DepthCameraHardware, Module[RealSenseCameraConfig], perception.DepthCamera): color_image: Out[Image] depth_image: Out[Image] pointcloud: Out[PointCloud2] camera_info: Out[CameraInfo] depth_camera_info: Out[CameraInfo] - config: RealSenseCameraConfig default_config = RealSenseCameraConfig @property diff --git a/dimos/hardware/sensors/camera/spec.py b/dimos/hardware/sensors/camera/spec.py index 23fd1a076e..dcb0196ff2 100644 --- a/dimos/hardware/sensors/camera/spec.py +++ b/dimos/hardware/sensors/camera/spec.py @@ -13,19 +13,20 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Generic, Protocol, TypeVar +from typing import TypeVar from reactivex.observable import Observable -from dimos.msgs.geometry_msgs import Quaternion, Transform -from dimos.msgs.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.service.spec import BaseConfig, Configurable OPTICAL_ROTATION = Quaternion(-0.5, 0.5, -0.5, 0.5) -class CameraConfig(Protocol): +class CameraConfig(BaseConfig): frame_id_prefix: str | None width: int height: int @@ -35,7 +36,7 @@ class CameraConfig(Protocol): CameraConfigT = TypeVar("CameraConfigT", bound=CameraConfig) -class CameraHardware(ABC, Configurable[CameraConfigT], Generic[CameraConfigT]): +class CameraHardware(ABC, Configurable[CameraConfigT]): @abstractmethod def image_stream(self) -> Observable[Image]: pass @@ -62,8 +63,6 @@ class DepthCameraConfig(CameraConfig): class DepthCameraHardware(ABC): """Abstract class for depth camera modules (RealSense, ZED, etc.).""" - config: DepthCameraConfig - @abstractmethod def get_color_camera_info(self) -> CameraInfo | None: """Get color camera intrinsics.""" diff --git a/dimos/hardware/sensors/camera/webcam.py b/dimos/hardware/sensors/camera/webcam.py index 51199624fe..cfd1a080a0 100644 --- a/dimos/hardware/sensors/camera/webcam.py +++ b/dimos/hardware/sensors/camera/webcam.py @@ -23,8 +23,8 @@ from reactivex.observable import Observable from dimos.hardware.sensors.camera.spec import CameraConfig, CameraHardware -from dimos.msgs.sensor_msgs import CameraInfo, Image -from dimos.msgs.sensor_msgs.Image import ImageFormat +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.utils.reactive import backpressure diff --git a/dimos/hardware/sensors/camera/zed/camera.py b/dimos/hardware/sensors/camera/zed/camera.py index 6ce2fc86b2..214b1f73e3 100644 --- a/dimos/hardware/sensors/camera/zed/camera.py +++ b/dimos/hardware/sensors/camera/zed/camera.py @@ -15,11 +15,11 @@ from __future__ import annotations import atexit -from dataclasses import dataclass, field import threading import time import cv2 +from pydantic import Field import pyzed.sl as sl import reactivex as rx @@ -33,8 +33,10 @@ DepthCameraConfig, DepthCameraHardware, ) -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.foxglove_bridge import FoxgloveBridge @@ -50,14 +52,13 @@ def default_base_transform() -> Transform: ) -@dataclass class ZEDCameraConfig(ModuleConfig, DepthCameraConfig): width: int = 1280 height: int = 720 fps: int = 15 camera_name: str = "camera" base_frame_id: str = "base_link" - base_transform: Transform | None = field(default_factory=default_base_transform) + base_transform: Transform | None = Field(default_factory=default_base_transform) align_depth_to_color: bool = True enable_depth: bool = True enable_pointcloud: bool = False @@ -76,14 +77,13 @@ class ZEDCameraConfig(ModuleConfig, DepthCameraConfig): world_frame: str = "world" -class ZEDCamera(DepthCameraHardware, Module, perception.DepthCamera): +class ZEDCamera(DepthCameraHardware, Module[ZEDCameraConfig], perception.DepthCamera): color_image: Out[Image] depth_image: Out[Image] pointcloud: Out[PointCloud2] camera_info: Out[CameraInfo] depth_camera_info: Out[CameraInfo] - config: ZEDCameraConfig default_config = ZEDCameraConfig @property diff --git a/dimos/hardware/sensors/camera/zed/__init__.py b/dimos/hardware/sensors/camera/zed/compat.py similarity index 89% rename from dimos/hardware/sensors/camera/zed/__init__.py rename to dimos/hardware/sensors/camera/zed/compat.py index f8e73273bf..3cec8d9566 100644 --- a/dimos/hardware/sensors/camera/zed/__init__.py +++ b/dimos/hardware/sensors/camera/zed/compat.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ZED camera hardware interfaces.""" +"""ZED camera compatibility layer and SDK detection.""" from pathlib import Path from dimos.msgs.sensor_msgs.CameraInfo import CalibrationProvider -# Check if ZED SDK is available try: - import pyzed.sl as sl # noqa: F401 + import pyzed.sl # noqa: F401 + # This awkwardness is needed as pytest implicitly imports this to collect + # the test in this directory. HAS_ZED_SDK = True except ImportError: HAS_ZED_SDK = False -# Only import ZED classes if SDK is available if HAS_ZED_SDK: from dimos.hardware.sensors.camera.zed.camera import ZEDCamera, ZEDModule, zed_camera else: @@ -43,7 +43,7 @@ def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." ) - def zed_camera(*args: object, **kwargs: object) -> None: # type: ignore[no-redef] + def zed_camera(*args: object, **kwargs: object) -> None: # type: ignore[misc,no-redef] raise ModuleNotFoundError( "ZED SDK not installed. Please install pyzed package to use ZED camera functionality.", name="pyzed", diff --git a/dimos/hardware/sensors/camera/zed/test_zed.py b/dimos/hardware/sensors/camera/zed/test_zed.py index 2d912553c6..a98055a355 100644 --- a/dimos/hardware/sensors/camera/zed/test_zed.py +++ b/dimos/hardware/sensors/camera/zed/test_zed.py @@ -13,14 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + +from dimos.hardware.sensors.camera.zed import compat as zed from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +@pytest.mark.skipif(not zed.HAS_ZED_SDK, reason="ZED SDK not installed") def test_zed_import_and_calibration_access() -> None: """Test that zed module can be imported and calibrations accessed.""" - # Import zed module from camera - from dimos.hardware.sensors.camera import zed - # Test that CameraInfo is accessible assert hasattr(zed, "CameraInfo") diff --git a/dimos/hardware/sensors/fake_zed_module.py b/dimos/hardware/sensors/fake_zed_module.py index ec5613077d..16e85aa93c 100644 --- a/dimos/hardware/sensors/fake_zed_module.py +++ b/dimos/hardware/sensors/fake_zed_module.py @@ -17,9 +17,9 @@ FakeZEDModule - Replays recorded ZED data for testing without hardware. """ -from dataclasses import dataclass import functools import logging +from typing import Any from dimos_lcm.sensor_msgs import CameraInfo import numpy as np @@ -27,18 +27,18 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.msgs.std_msgs import Header -from dimos.protocol.tf import TF +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.std_msgs.Header import Header +from dimos.protocol.tf.tf import TF from dimos.utils.logging_config import setup_logger -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay logger = setup_logger(level=logging.INFO) -@dataclass class FakeZEDModuleConfig(ModuleConfig): + recording_path: str frame_id: str = "zed_camera" @@ -54,9 +54,8 @@ class FakeZEDModule(Module[FakeZEDModuleConfig]): pose: Out[PoseStamped] default_config = FakeZEDModuleConfig - config: FakeZEDModuleConfig - def __init__(self, recording_path: str, **kwargs: object) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize FakeZEDModule with recording path. @@ -65,7 +64,7 @@ def __init__(self, recording_path: str, **kwargs: object) -> None: """ super().__init__(**kwargs) - self.recording_path = recording_path + self.recording_path = self.config.recording_path self._running = False # Initialize TF publisher @@ -279,7 +278,9 @@ def _publish_pose(self, msg) -> None: # type: ignore[no-untyped-def] # Publish TF transform from world to camera import time - from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.geometry_msgs.Transform import Transform + from dimos.msgs.geometry_msgs.Vector3 import Vector3 transform = Transform( translation=Vector3(*msg.position), diff --git a/dimos/hardware/sensors/lidar/__init__.py b/dimos/hardware/sensors/lidar/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/hardware/sensors/lidar/fastlio2/__init__.py b/dimos/hardware/sensors/lidar/fastlio2/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/hardware/sensors/lidar/fastlio2/module.py b/dimos/hardware/sensors/lidar/fastlio2/module.py index fb894ddce5..c1a96a525b 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/module.py +++ b/dimos/hardware/sensors/lidar/fastlio2/module.py @@ -30,12 +30,13 @@ from __future__ import annotations -from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Annotated + +from pydantic.experimental.pipeline import validate_as from dimos.core.native_module import NativeModule, NativeModuleConfig -from dimos.core.stream import Out # noqa: TC001 +from dimos.core.stream import Out from dimos.hardware.sensors.lidar.livox.ports import ( SDK_CMD_DATA_PORT, SDK_HOST_CMD_DATA_PORT, @@ -48,14 +49,13 @@ SDK_POINT_DATA_PORT, SDK_PUSH_MSG_PORT, ) -from dimos.msgs.nav_msgs.Odometry import Odometry # noqa: TC001 -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # noqa: TC001 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.spec import mapping, perception _CONFIG_DIR = Path(__file__).parent / "config" -@dataclass(kw_only=True) class FastLio2Config(NativeModuleConfig): """Config for the FAST-LIO2 + Livox Mid-360 native module.""" @@ -92,7 +92,9 @@ class FastLio2Config(NativeModuleConfig): # FAST-LIO YAML config (relative to config/ dir, or absolute path) # C++ binary reads YAML directly via yaml-cpp - config: str = "mid360.yaml" + config: Annotated[ + Path, validate_as(...).transform(lambda p: p if p.is_absolute() else _CONFIG_DIR / p) + ] = Path("mid360.yaml") # SDK port configuration (see livox/ports.py for defaults) cmd_data_port: int = SDK_CMD_DATA_PORT @@ -112,15 +114,10 @@ class FastLio2Config(NativeModuleConfig): # config is not a CLI arg (config_path is) cli_exclude: frozenset[str] = frozenset({"config"}) - def __post_init__(self) -> None: - if self.config_path is None: - path = Path(self.config) - if not path.is_absolute(): - path = _CONFIG_DIR / path - self.config_path = str(path.resolve()) - -class FastLio2(NativeModule, perception.Lidar, perception.Odometry, mapping.GlobalPointcloud): +class FastLio2( + NativeModule[FastLio2Config], perception.Lidar, perception.Odometry, mapping.GlobalPointcloud +): """FAST-LIO2 SLAM module with integrated Livox Mid-360 driver. Ports: @@ -129,7 +126,7 @@ class FastLio2(NativeModule, perception.Lidar, perception.Odometry, mapping.Glob global_map (Out[PointCloud2]): Global voxel map (optional, enable via map_freq > 0). """ - default_config: type[FastLio2Config] = FastLio2Config # type: ignore[assignment] + default_config = FastLio2Config lidar: Out[PointCloud2] odometry: Out[Odometry] global_map: Out[PointCloud2] diff --git a/dimos/hardware/sensors/lidar/livox/__init__.py b/dimos/hardware/sensors/lidar/livox/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/hardware/sensors/lidar/livox/module.py b/dimos/hardware/sensors/lidar/livox/module.py index 2e470b21ef..999cdd9aa1 100644 --- a/dimos/hardware/sensors/lidar/livox/module.py +++ b/dimos/hardware/sensors/lidar/livox/module.py @@ -26,11 +26,10 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING from dimos.core.native_module import NativeModule, NativeModuleConfig -from dimos.core.stream import Out # noqa: TC001 +from dimos.core.stream import Out from dimos.hardware.sensors.lidar.livox.ports import ( SDK_CMD_DATA_PORT, SDK_HOST_CMD_DATA_PORT, @@ -43,12 +42,11 @@ SDK_POINT_DATA_PORT, SDK_PUSH_MSG_PORT, ) -from dimos.msgs.sensor_msgs.Imu import Imu # noqa: TC001 -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # noqa: TC001 +from dimos.msgs.sensor_msgs.Imu import Imu +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.spec import perception -@dataclass(kw_only=True) class Mid360Config(NativeModuleConfig): """Config for the C++ Mid-360 native module.""" @@ -76,7 +74,7 @@ class Mid360Config(NativeModuleConfig): host_log_data_port: int = SDK_HOST_LOG_DATA_PORT -class Mid360(NativeModule, perception.Lidar, perception.IMU): +class Mid360(NativeModule[Mid360Config], perception.Lidar, perception.IMU): """Livox Mid-360 LiDAR module backed by a native C++ binary. Ports: diff --git a/dimos/manipulation/__init__.py b/dimos/manipulation/__init__.py deleted file mode 100644 index d2a511d146..0000000000 --- a/dimos/manipulation/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Manipulation module for robot arm motion planning and control.""" - -from dimos.manipulation.manipulation_module import ( - ManipulationModule, - ManipulationModuleConfig, - ManipulationState, - manipulation_module, -) -from dimos.manipulation.pick_and_place_module import ( - PickAndPlaceModule, - PickAndPlaceModuleConfig, - pick_and_place_module, -) - -__all__ = [ - "ManipulationModule", - "ManipulationModuleConfig", - "ManipulationState", - "PickAndPlaceModule", - "PickAndPlaceModuleConfig", - "manipulation_module", - "pick_and_place_module", -] diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index 97657b9cae..8ef2c03279 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -35,20 +35,19 @@ from dimos.control.coordinator import TaskConfig, control_coordinator from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera.realsense import realsense_camera +from dimos.hardware.sensors.camera.realsense.camera import realsense_camera from dimos.manipulation.manipulation_module import manipulation_module from dimos.manipulation.pick_and_place_module import pick_and_place_module -from dimos.manipulation.planning.spec import RobotModelConfig -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import JointState +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.perception.object_scene_registration import object_scene_registration_module from dimos.robot.foxglove_bridge import foxglove_bridge # TODO: migrate to rerun from dimos.utils.data import get_data -# ============================================================================= -# Pose Helpers -# ============================================================================= - def _make_base_pose( x: float = 0.0, @@ -70,11 +69,6 @@ def _make_base_pose( ) -# ============================================================================= -# URDF Helpers -# ============================================================================= - - def _get_xarm_urdf_path() -> Path: """Get path to xarm URDF.""" return get_data("xarm_description") / "urdf/xarm_device.urdf.xacro" @@ -133,11 +127,6 @@ def _get_piper_package_paths() -> dict[str, Path]: ] -# ============================================================================= -# Robot Configs -# ============================================================================= - - def _make_xarm6_config( name: str = "arm", y_offset: float = 0.0, @@ -283,11 +272,6 @@ def _make_piper_config( ) -# ============================================================================= -# Blueprints -# ============================================================================= - - # Single XArm6 planner (standalone, no coordinator) xarm6_planner_only = manipulation_module( robots=[_make_xarm6_config()], diff --git a/dimos/manipulation/control/__init__.py b/dimos/manipulation/control/__init__.py deleted file mode 100644 index ec85660eb3..0000000000 --- a/dimos/manipulation/control/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Manipulation Control Modules - -Hardware-agnostic controllers for robotic manipulation tasks. - -Submodules: -- servo_control: Real-time servo-level controllers (Cartesian motion control) -- trajectory_controller: Trajectory planning and execution -""" - -# Re-export from servo_control for backwards compatibility -from dimos.manipulation.control.servo_control import ( - CartesianMotionController, - CartesianMotionControllerConfig, - cartesian_motion_controller, -) - -# Re-export from trajectory_controller -from dimos.manipulation.control.trajectory_controller import ( - JointTrajectoryController, - JointTrajectoryControllerConfig, - joint_trajectory_controller, -) - -__all__ = [ - # Servo control - "CartesianMotionController", - "CartesianMotionControllerConfig", - # Trajectory control - "JointTrajectoryController", - "JointTrajectoryControllerConfig", - "cartesian_motion_controller", - "joint_trajectory_controller", -] diff --git a/dimos/manipulation/control/coordinator_client.py b/dimos/manipulation/control/coordinator_client.py index 4e277fae97..dfa99371a6 100644 --- a/dimos/manipulation/control/coordinator_client.py +++ b/dimos/manipulation/control/coordinator_client.py @@ -54,7 +54,7 @@ ) if TYPE_CHECKING: - from dimos.msgs.trajectory_msgs import JointTrajectory + from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory class CoordinatorClient: @@ -98,10 +98,6 @@ def stop(self) -> None: """Stop the RPC client.""" self._rpc.stop_rpc_client() - # ========================================================================= - # Query methods (RPC calls) - # ========================================================================= - def list_hardware(self) -> list[str]: """List all hardware IDs.""" return self._rpc.list_hardware() or [] @@ -129,10 +125,6 @@ def get_trajectory_status(self, task_name: str) -> dict[str, Any]: return {"state": int(result), "task": task_name} return {} - # ========================================================================= - # Trajectory execution (via task_invoke) - # ========================================================================= - def execute_trajectory(self, task_name: str, trajectory: JointTrajectory) -> bool: """Execute a trajectory on a task via task_invoke.""" result = self._rpc.task_invoke(task_name, "execute", {"trajectory": trajectory}) @@ -143,10 +135,6 @@ def cancel_trajectory(self, task_name: str) -> bool: result = self._rpc.task_invoke(task_name, "cancel", {}) return bool(result) - # ========================================================================= - # Task selection and setup - # ========================================================================= - def select_task(self, task_name: str) -> bool: """ Select a task and setup its trajectory generator. @@ -248,11 +236,6 @@ def set_acceleration_limit(self, acceleration: float, task_name: str | None = No gen.set_limits(gen.max_velocity, acceleration) -# ============================================================================= -# Interactive CLI -# ============================================================================= - - def parse_joint_input(line: str, num_joints: int) -> list[float] | None: """Parse joint positions from user input (degrees by default, 'r' suffix for radians).""" parts = line.strip().split() diff --git a/dimos/manipulation/control/dual_trajectory_setter.py b/dimos/manipulation/control/dual_trajectory_setter.py index 05793eeb76..3fdccea400 100644 --- a/dimos/manipulation/control/dual_trajectory_setter.py +++ b/dimos/manipulation/control/dual_trajectory_setter.py @@ -37,8 +37,8 @@ from dimos.manipulation.planning.trajectory_generator.joint_trajectory_generator import ( JointTrajectoryGenerator, ) -from dimos.msgs.sensor_msgs import JointState -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory @dataclass diff --git a/dimos/manipulation/control/servo_control/__init__.py b/dimos/manipulation/control/servo_control/__init__.py deleted file mode 100644 index 5418a7e24b..0000000000 --- a/dimos/manipulation/control/servo_control/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Servo Control Modules - -Real-time servo-level controllers for robotic manipulation. -Includes Cartesian motion control with PID-based tracking. -""" - -from dimos.manipulation.control.servo_control.cartesian_motion_controller import ( - CartesianMotionController, - CartesianMotionControllerConfig, - cartesian_motion_controller, -) - -__all__ = [ - "CartesianMotionController", - "CartesianMotionControllerConfig", - "cartesian_motion_controller", -] diff --git a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py index 2c11b0cc10..0cbd41e218 100644 --- a/dimos/manipulation/control/servo_control/cartesian_motion_controller.py +++ b/dimos/manipulation/control/servo_control/cartesian_motion_controller.py @@ -26,7 +26,6 @@ - Supports velocity-based and position-based control modes """ -from dataclasses import dataclass import math import threading import time @@ -35,18 +34,25 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Twist, Vector3 -from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState from dimos.utils.logging_config import setup_logger from dimos.utils.simple_controller import PIDController logger = setup_logger() -@dataclass class CartesianMotionControllerConfig(ModuleConfig): """Configuration for Cartesian motion controller.""" + arm_driver: Any = None + # Control loop parameters control_frequency: float = 20.0 # Hz - Cartesian control loop rate command_timeout: float = 30.0 # seconds - timeout for stale targets (RPC mode needs longer) @@ -78,7 +84,7 @@ class CartesianMotionControllerConfig(ModuleConfig): control_frame: str = "world" # Frame for target poses (world, base_link, etc.) -class CartesianMotionController(Module): +class CartesianMotionController(Module[CartesianMotionControllerConfig]): """ Hardware-agnostic Cartesian motion controller. @@ -94,7 +100,6 @@ class CartesianMotionController(Module): """ default_config = CartesianMotionControllerConfig - config: CartesianMotionControllerConfig # Type hint for proper attribute access # RPC methods to request from other modules (resolved at blueprint build time) rpc_calls = [ @@ -112,7 +117,7 @@ class CartesianMotionController(Module): cartesian_velocity: Out[Twist] = None # type: ignore[assignment] current_pose: Out[PoseStamped] = None # type: ignore[assignment] - def __init__(self, arm_driver: Any = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize the Cartesian motion controller. @@ -120,10 +125,10 @@ def __init__(self, arm_driver: Any = None, *args: Any, **kwargs: Any) -> None: arm_driver: (Optional) Hardware driver reference (legacy mode). When using blueprints, this is resolved automatically via rpc_calls. """ - super().__init__(*args, **kwargs) + super().__init__(**kwargs) # Hardware driver reference - set via arm_driver param (legacy) or RPC wiring (blueprint) - self._arm_driver_legacy = arm_driver + self._arm_driver_legacy = self.config.arm_driver # State tracking self._latest_joint_state: JointState | None = None @@ -273,10 +278,6 @@ def stop(self) -> None: super().stop() logger.info("CartesianMotionController stopped") - # ========================================================================= - # RPC Methods - High-level control - # ========================================================================= - @rpc def set_target_pose( self, position: list[float], orientation: list[float], frame_id: str = "world" @@ -350,10 +351,6 @@ def is_converged(self) -> bool: and ori_error < self.config.orientation_tolerance ) - # ========================================================================= - # Private Methods - Callbacks - # ========================================================================= - def _on_joint_state(self, msg: JointState) -> None: """Callback when new joint state is received.""" logger.debug(f"Received joint_state: {len(msg.position)} joints") @@ -373,10 +370,6 @@ def _on_target_pose(self, msg: PoseStamped) -> None: self._is_tracking = True logger.debug(f"New target received: {msg}") - # ========================================================================= - # Private Methods - Control Loop - # ========================================================================= - def _control_loop(self) -> None: """ Main control loop running at control_frequency Hz. diff --git a/dimos/manipulation/control/target_setter.py b/dimos/manipulation/control/target_setter.py index f54a6af2f0..a0228c6a24 100644 --- a/dimos/manipulation/control/target_setter.py +++ b/dimos/manipulation/control/target_setter.py @@ -25,7 +25,9 @@ import time from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 class TargetSetter: diff --git a/dimos/manipulation/control/trajectory_controller/__init__.py b/dimos/manipulation/control/trajectory_controller/__init__.py deleted file mode 100644 index fb4360d4cc..0000000000 --- a/dimos/manipulation/control/trajectory_controller/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Trajectory Controller Module - -Joint-space trajectory execution for robotic manipulators. -""" - -from dimos.manipulation.control.trajectory_controller.joint_trajectory_controller import ( - JointTrajectoryController, - JointTrajectoryControllerConfig, - joint_trajectory_controller, -) - -__all__ = [ - "JointTrajectoryController", - "JointTrajectoryControllerConfig", - "joint_trajectory_controller", -] diff --git a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py index 1ce3149dd2..465df7afea 100644 --- a/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py +++ b/dimos/manipulation/control/trajectory_controller/joint_trajectory_controller.py @@ -29,7 +29,6 @@ - reset(): Required to recover from FAULT state """ -from dataclasses import dataclass import threading import time from typing import Any @@ -37,21 +36,23 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryState, TrajectoryStatus +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState, TrajectoryStatus from dimos.utils.logging_config import setup_logger logger = setup_logger() -@dataclass class JointTrajectoryControllerConfig(ModuleConfig): """Configuration for joint trajectory controller.""" control_frequency: float = 100.0 # Hz - trajectory execution rate -class JointTrajectoryController(Module): +class JointTrajectoryController(Module[JointTrajectoryControllerConfig]): """ Joint-space trajectory executor. @@ -72,7 +73,6 @@ class JointTrajectoryController(Module): """ default_config = JointTrajectoryControllerConfig - config: JointTrajectoryControllerConfig # Type hint for proper attribute access # Input topics joint_state: In[JointState] = None # type: ignore[assignment] # Feedback from arm driver @@ -82,8 +82,8 @@ class JointTrajectoryController(Module): # Output topics joint_position_command: Out[JointCommand] = None # type: ignore[assignment] # To arm driver - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # State machine self._state = TrajectoryState.IDLE @@ -156,10 +156,6 @@ def stop(self) -> None: super().stop() logger.info("JointTrajectoryController stopped") - # ========================================================================= - # RPC Methods - Action-server-like interface - # ========================================================================= - @rpc def execute_trajectory(self, trajectory: JointTrajectory) -> bool: """ @@ -273,10 +269,6 @@ def get_status(self) -> TrajectoryStatus: error=self._error_message, ) - # ========================================================================= - # Callbacks - # ========================================================================= - def _on_joint_state(self, msg: JointState) -> None: """Callback for joint state feedback.""" self._latest_joint_state = msg @@ -292,10 +284,6 @@ def _on_trajectory(self, msg: JointTrajectory) -> None: ) self.execute_trajectory(msg) - # ========================================================================= - # Execution Loop - # ========================================================================= - def _execution_loop(self) -> None: """ Main execution loop running at control_frequency Hz. diff --git a/dimos/manipulation/control/trajectory_controller/spec.py b/dimos/manipulation/control/trajectory_controller/spec.py index e11da91847..b696f2dc6a 100644 --- a/dimos/manipulation/control/trajectory_controller/spec.py +++ b/dimos/manipulation/control/trajectory_controller/spec.py @@ -30,8 +30,11 @@ if TYPE_CHECKING: from dimos.core.stream import In, Out - from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState - from dimos.msgs.trajectory_msgs import JointTrajectory as JointTrajectoryMsg, TrajectoryState + from dimos.msgs.sensor_msgs.JointCommand import JointCommand + from dimos.msgs.sensor_msgs.JointState import JointState + from dimos.msgs.sensor_msgs.RobotState import RobotState + from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory as JointTrajectoryMsg + from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState # Input topics joint_state: In[JointState] | None = None # Feedback from arm driver diff --git a/dimos/manipulation/control/trajectory_setter.py b/dimos/manipulation/control/trajectory_setter.py index a5baa512b5..25f9db2a3f 100644 --- a/dimos/manipulation/control/trajectory_setter.py +++ b/dimos/manipulation/control/trajectory_setter.py @@ -36,8 +36,8 @@ from dimos.manipulation.planning.trajectory_generator.joint_trajectory_generator import ( JointTrajectoryGenerator, ) -from dimos.msgs.sensor_msgs import JointState -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory class TrajectorySetter: diff --git a/dimos/manipulation/grasping/__init__.py b/dimos/manipulation/grasping/__init__.py deleted file mode 100644 index 41779f55e7..0000000000 --- a/dimos/manipulation/grasping/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dimos.manipulation.grasping.graspgen_module import ( - GraspGenConfig, - GraspGenModule, - graspgen, -) -from dimos.manipulation.grasping.grasping import ( - GraspingModule, - grasping_module, -) - -__all__ = [ - "GraspGenConfig", - "GraspGenModule", - "GraspingModule", - "graspgen", - "grasping_module", -] diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 01e34f905f..43a6c9a20a 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -16,8 +16,8 @@ from dimos.agents.agent import agent from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera.realsense import realsense_camera -from dimos.manipulation.grasping import graspgen +from dimos.hardware.sensors.camera.realsense.camera import realsense_camera +from dimos.manipulation.grasping.graspgen_module import graspgen from dimos.manipulation.grasping.grasping import grasping_module from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import object_scene_registration_module diff --git a/dimos/manipulation/grasping/graspgen_module.py b/dimos/manipulation/grasping/graspgen_module.py index c988d3df51..c883126840 100644 --- a/dimos/manipulation/grasping/graspgen_module.py +++ b/dimos/manipulation/grasping/graspgen_module.py @@ -13,7 +13,6 @@ # limitations under the License. from __future__ import annotations -from dataclasses import dataclass import os from pathlib import Path import sys @@ -26,13 +25,13 @@ from dimos.core.docker_runner import DockerModuleConfig from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseArray -from dimos.msgs.std_msgs import Header +from dimos.msgs.geometry_msgs.PoseArray import PoseArray +from dimos.msgs.std_msgs.Header import Header from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import matrix_to_pose if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import PointCloud2 + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 logger = setup_logger() @@ -42,7 +41,6 @@ COLLISION_FILTER_THRESHOLD = 0.02 -@dataclass class GraspGenConfig(DockerModuleConfig): """Configuration for GraspGen module.""" @@ -68,11 +66,9 @@ class GraspGenModule(Module[GraspGenConfig]): default_config = GraspGenConfig grasps: Out[PoseArray] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._sampler = self._gripper_info = None - self._initialized = False + _sampler = None + _gripper_info = None + _initialized = False @rpc def start(self) -> None: @@ -212,7 +208,7 @@ def _run_inference( return grasps_np, scores_np pc_mean = object_pc_filtered.mean(axis=0) - T_center = tra.translation_matrix(-pc_mean) + T_center = tra.translation_matrix(-pc_mean) # type: ignore[no-untyped-call] grasps_centered = np.array([T_center @ g for g in grasps_np]) scene_pc_centered = tra.transform_points(scene_pc, T_center) diff --git a/dimos/manipulation/grasping/grasping.py b/dimos/manipulation/grasping/grasping.py index 433a07d846..ef05dc29e2 100644 --- a/dimos/manipulation/grasping/grasping.py +++ b/dimos/manipulation/grasping/grasping.py @@ -25,12 +25,12 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseArray +from dimos.msgs.geometry_msgs.PoseArray import PoseArray from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import quaternion_to_euler if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import PointCloud2 + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 logger = setup_logger() diff --git a/dimos/manipulation/manipulation_interface.py b/dimos/manipulation/manipulation_interface.py index 524562520d..c60cbfd9c6 100644 --- a/dimos/manipulation/manipulation_interface.py +++ b/dimos/manipulation/manipulation_interface.py @@ -157,8 +157,6 @@ def update_task_result(self, task_id: str, result: dict[str, Any]) -> Manipulati return task return None - # === Perception stream methods === - def _setup_perception_subscription(self) -> None: """ Set up subscription to perception stream if available. @@ -239,8 +237,6 @@ def cleanup_perception_subscription(self) -> None: self.stream_subscription.dispose() self.stream_subscription = None - # === Utility methods === - def clear(self) -> None: """ Clear all manipulation tasks and agent constraints. diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index 40dd6734c5..fe5561c705 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -24,7 +24,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from collections.abc import Iterable from enum import Enum import threading import time @@ -34,23 +34,20 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In -from dimos.manipulation.planning import ( - JointPath, +from dimos.manipulation.planning.factory import create_kinematics, create_planner +from dimos.manipulation.planning.monitor.world_monitor import WorldMonitor +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.manipulation.planning.spec.enums import ObstacleType +from dimos.manipulation.planning.spec.models import JointPath, Obstacle, RobotName, WorldRobotID +from dimos.manipulation.planning.spec.protocols import KinematicsSpec, PlannerSpec +from dimos.manipulation.planning.trajectory_generator.joint_trajectory_generator import ( JointTrajectoryGenerator, - KinematicsSpec, - Obstacle, - ObstacleType, - PlannerSpec, - RobotModelConfig, - RobotName, - WorldRobotID, - create_kinematics, - create_planner, ) -from dimos.manipulation.planning.monitor import WorldMonitor -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import JointState -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -82,18 +79,17 @@ class ManipulationState(Enum): FAULT = 4 -@dataclass class ManipulationModuleConfig(ModuleConfig): """Configuration for ManipulationModule.""" - robots: list[RobotModelConfig] = field(default_factory=list) + robots: Iterable[RobotModelConfig] = () planning_timeout: float = 10.0 enable_viz: bool = False planner_name: str = "rrt_connect" # "rrt_connect" kinematics_name: str = "jacobian" # "jacobian" or "drake_optimization" -class ManipulationModule(Module): +class ManipulationModule(Module[ManipulationModuleConfig]): """Base motion planning module with ControlCoordinator execution. - @rpc: Low-level building blocks (plan, execute, gripper) @@ -104,14 +100,11 @@ class ManipulationModule(Module): default_config = ManipulationModuleConfig - # Type annotation for the config attribute (mypy uses this) - config: ManipulationModuleConfig - # Input: Joint state from coordinator (for world sync) joint_state: In[JointState] - def __init__(self, *args: object, **kwargs: object) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # State machine self._state = ManipulationState.IDLE @@ -251,7 +244,7 @@ def _on_joint_state(self, msg: JointState) -> None: def _tf_publish_loop(self) -> None: """Publish TF transforms at 10Hz for EE and extra links.""" - from dimos.msgs.geometry_msgs import Transform + from dimos.msgs.geometry_msgs.Transform import Transform period = 0.1 # 10Hz while not self._tf_stop_event.is_set(): @@ -282,10 +275,6 @@ def _tf_publish_loop(self) -> None: self._tf_stop_event.wait(period) - # ========================================================================= - # RPC Methods - # ========================================================================= - @rpc def get_state(self) -> str: """Get current manipulation state name.""" @@ -360,10 +349,6 @@ def is_collision_free(self, joints: list[float], robot_name: RobotName | None = return self._world_monitor.is_state_valid(robot_id, joint_state) return False - # ========================================================================= - # Plan/Preview/Execute Workflow RPC Methods - # ========================================================================= - def _begin_planning( self, robot_name: RobotName | None = None ) -> tuple[RobotName, WorldRobotID] | None: @@ -418,7 +403,7 @@ def plan_to_pose(self, pose: Pose, robot_name: RobotName | None = None) -> bool: return self._fail("No joint state") # Convert Pose to PoseStamped for the IK solver - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped target_pose = PoseStamped( frame_id="world", @@ -634,10 +619,6 @@ def set_init_joints_to_current(self, robot_name: RobotName | None = None) -> boo ) return True - # ========================================================================= - # Coordinator Integration RPC Methods - # ========================================================================= - def _get_coordinator_client(self) -> RPCClient | None: """Get or create coordinator RPC client (lazy init).""" if not any( @@ -766,7 +747,7 @@ def add_obstacle( return "" # Import PoseStamped here to avoid circular imports - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped obstacle = Obstacle( name=name, @@ -784,10 +765,6 @@ def remove_obstacle(self, obstacle_id: str) -> bool: return False return self._world_monitor.remove_obstacle(obstacle_id) - # ========================================================================= - # Gripper Methods - # ========================================================================= - def _get_gripper_hardware_id(self, robot_name: RobotName | None = None) -> str | None: """Get gripper hardware ID for a robot.""" robot = self._get_robot(robot_name) @@ -860,10 +837,6 @@ def close_gripper(self, robot_name: str | None = None) -> str: return "Gripper closed" return "Error: Failed to close gripper" - # ========================================================================= - # Skill Helpers (internal) - # ========================================================================= - def _wait_for_trajectory_completion( self, robot_name: RobotName | None = None, timeout: float = 60.0, poll_interval: float = 0.2 ) -> bool: @@ -948,10 +921,6 @@ def _preview_execute_wait( return None - # ========================================================================= - # Short-Horizon Skills — Single-step actions - # ========================================================================= - @skill def get_robot_state(self, robot_name: str | None = None) -> str: """Get current robot state: joint positions, end-effector pose, and gripper. @@ -1136,10 +1105,6 @@ def go_init(self, robot_name: str | None = None) -> str: return "Reached init position" - # ========================================================================= - # Lifecycle - # ========================================================================= - @rpc def stop(self) -> None: """Stop the manipulation module.""" diff --git a/dimos/manipulation/pick_and_place_module.py b/dimos/manipulation/pick_and_place_module.py index 84ede61793..b433df6801 100644 --- a/dimos/manipulation/pick_and_place_module.py +++ b/dimos/manipulation/pick_and_place_module.py @@ -22,7 +22,6 @@ from __future__ import annotations -from dataclasses import dataclass, field import math from pathlib import Path import time @@ -32,22 +31,24 @@ from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core.core import rpc from dimos.core.docker_runner import DockerModule as DockerRunner -from dimos.core.stream import In # noqa: TC001 +from dimos.core.stream import In from dimos.manipulation.grasping.graspgen_module import GraspGenModule from dimos.manipulation.manipulation_module import ( ManipulationModule, ManipulationModuleConfig, ) -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.perception.detection.type.detection3d.object import ( - Object as DetObject, # noqa: TC001 + Object as DetObject, ) from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import PoseArray - from dimos.msgs.sensor_msgs import PointCloud2 + from dimos.msgs.geometry_msgs.PoseArray import PoseArray + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 logger = setup_logger() @@ -56,7 +57,6 @@ _GRASPGEN_VIZ_CONTAINER_PATH = f"{_GRASPGEN_VIZ_CONTAINER_DIR}/visualization.json" -@dataclass class PickAndPlaceModuleConfig(ManipulationModuleConfig): """Configuration for PickAndPlaceModule (adds GraspGen settings).""" @@ -68,8 +68,8 @@ class PickAndPlaceModuleConfig(ManipulationModuleConfig): graspgen_grasp_threshold: float = -1.0 graspgen_filter_collisions: bool = False graspgen_save_visualization_data: bool = False - graspgen_visualization_output_path: Path = field( - default_factory=lambda: Path.home() / ".dimos" / "graspgen" / "visualization.json" + graspgen_visualization_output_path: Path = ( + Path.home() / ".dimos" / "graspgen" / "visualization.json" ) @@ -90,8 +90,8 @@ class PickAndPlaceModule(ManipulationModule): # Input: Objects from perception (for obstacle integration) objects: In[list[DetObject]] - def __init__(self, *args: object, **kwargs: object) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # GraspGen Docker runner (lazy initialized on first generate_grasps call) self._graspgen: DockerRunner | None = None @@ -104,10 +104,6 @@ def __init__(self, *args: object, **kwargs: object) -> None: # so pick/place use this stable snapshot instead. self._detection_snapshot: list[DetObject] = [] - # ========================================================================= - # Lifecycle (perception integration) - # ========================================================================= - @rpc def start(self) -> None: """Start the pick-and-place module (adds perception subscriptions).""" @@ -132,10 +128,6 @@ def _on_objects(self, objects: list[DetObject]) -> None: except Exception as e: logger.error(f"Exception in _on_objects: {e}") - # ========================================================================= - # Perception RPC Methods - # ========================================================================= - @rpc def refresh_obstacles(self, min_duration: float = 0.0) -> list[dict[str, Any]]: """Refresh perception obstacles. Returns the list of obstacles added. @@ -184,10 +176,6 @@ def list_added_obstacles(self) -> list[dict[str, Any]]: return [] return self._world_monitor.list_added_obstacles() - # ========================================================================= - # GraspGen - # ========================================================================= - def _get_graspgen(self) -> DockerRunner: """Get or create GraspGen Docker module (lazy init, thread-safe).""" # Fast path: already initialized (no lock needed for read) @@ -252,10 +240,6 @@ def generate_grasps( logger.error(f"Grasp generation failed: {e}") return None - # ========================================================================= - # Pick/Place Helpers - # ========================================================================= - def _compute_pre_grasp_pose(self, grasp_pose: Pose, offset: float = 0.10) -> Pose: """Compute a pre-grasp pose offset along the approach direction (local -Z). @@ -323,10 +307,6 @@ def _generate_grasps_for_pick( logger.info(f"Heuristic grasp for '{object_name}' at ({c.x:.3f}, {c.y:.3f}, {c.z:.3f})") return [grasp_pose] - # ========================================================================= - # Perception Skills - # ========================================================================= - @skill def get_scene_info(self, robot_name: str | None = None) -> str: """Get current robot state, detected objects, and scene information. @@ -412,10 +392,6 @@ def scan_objects(self, min_duration: float = 1.0, robot_name: str | None = None) return "\n".join(lines) - # ========================================================================= - # Long-Horizon Skills — Pick and Place - # ========================================================================= - @skill def pick( self, @@ -604,10 +580,6 @@ def pick_and_place( # Place phase return self.place(place_x, place_y, place_z, robot_name) - # ========================================================================= - # Lifecycle - # ========================================================================= - @rpc def stop(self) -> None: """Stop the pick-and-place module (cleanup GraspGen + delegate to base).""" diff --git a/dimos/manipulation/planning/__init__.py b/dimos/manipulation/planning/__init__.py deleted file mode 100644 index 8aaf0caa25..0000000000 --- a/dimos/manipulation/planning/__init__.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Manipulation Planning Module - -Motion planning stack for robotic manipulators using Protocol-based architecture. - -## Architecture - -- WorldSpec: Core backend owning physics/collision (DrakeWorld, future: MuJoCoWorld) -- KinematicsSpec: IK solvers - - JacobianIK: Backend-agnostic iterative/differential IK - - DrakeOptimizationIK: Drake-specific nonlinear optimization IK -- PlannerSpec: Backend-agnostic joint-space path planning - - RRTConnectPlanner: Bi-directional RRT-Connect - - RRTStarPlanner: RRT* (asymptotically optimal) - -## Factory Functions - -Use factory functions to create components: - -```python -from dimos.manipulation.planning.factory import ( - create_world, - create_kinematics, - create_planner, -) - -world = create_world(backend="drake", enable_viz=True) -kinematics = create_kinematics(name="jacobian") # or "drake_optimization" -planner = create_planner(name="rrt_connect") # backend-agnostic -``` - -## Monitors - -Use WorldMonitor for reactive state synchronization: - -```python -from dimos.manipulation.planning.monitor import WorldMonitor - -monitor = WorldMonitor(enable_viz=True) -robot_id = monitor.add_robot(config) -monitor.finalize() -monitor.start_state_monitor(robot_id) -``` -""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "factory": ["create_kinematics", "create_planner", "create_planning_stack", "create_world"], - "spec": [ - "CollisionObjectMessage", - "IKResult", - "IKStatus", - "JointPath", - "KinematicsSpec", - "Obstacle", - "ObstacleType", - "PlannerSpec", - "PlanningResult", - "PlanningStatus", - "RobotModelConfig", - "RobotName", - "WorldRobotID", - "WorldSpec", - ], - "trajectory_generator.joint_trajectory_generator": ["JointTrajectoryGenerator"], - }, -) diff --git a/dimos/manipulation/planning/examples/__init__.py b/dimos/manipulation/planning/examples/__init__.py deleted file mode 100644 index 7971835dab..0000000000 --- a/dimos/manipulation/planning/examples/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Manipulation planning examples. -""" diff --git a/dimos/manipulation/planning/examples/manipulation_client.py b/dimos/manipulation/planning/examples/manipulation_client.py index ac098ac52a..4dcd2fe9e8 100644 --- a/dimos/manipulation/planning/examples/manipulation_client.py +++ b/dimos/manipulation/planning/examples/manipulation_client.py @@ -49,7 +49,9 @@ from dimos.core.rpc_client import RPCClient from dimos.manipulation.manipulation_module import ManipulationModule -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 _client = RPCClient(None, ManipulationModule) @@ -71,7 +73,7 @@ def state() -> str: def plan(target_joints: list[float], robot_name: str | None = None) -> bool: """Plan to joint configuration. e.g. plan([0.1]*7)""" - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.sensor_msgs.JointState import JointState js = JointState(position=target_joints) return _client.plan_to_joints(js, robot_name) @@ -106,7 +108,7 @@ def execute(robot_name: str | None = None) -> bool: def home(robot_name: str | None = None) -> bool: """Plan and execute move to home position.""" - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.sensor_msgs.JointState import JointState home_joints = _client.get_robot_info(robot_name).get("home_joints", [0.0] * 7) success = _client.plan_to_joints(JointState(position=home_joints), robot_name) diff --git a/dimos/manipulation/planning/factory.py b/dimos/manipulation/planning/factory.py index d392bac563..65173dfd18 100644 --- a/dimos/manipulation/planning/factory.py +++ b/dimos/manipulation/planning/factory.py @@ -19,11 +19,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from dimos.manipulation.planning.spec import ( - KinematicsSpec, - PlannerSpec, - WorldSpec, - ) + from dimos.manipulation.planning.spec.protocols import KinematicsSpec, PlannerSpec, WorldSpec def create_world( diff --git a/dimos/manipulation/planning/kinematics/__init__.py b/dimos/manipulation/planning/kinematics/__init__.py deleted file mode 100644 index dacd2007cb..0000000000 --- a/dimos/manipulation/planning/kinematics/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Kinematics Module - -Contains IK solver implementations that use WorldSpec. - -## Implementations - -- JacobianIK: Backend-agnostic iterative/differential IK (works with any WorldSpec) -- DrakeOptimizationIK: Drake-specific nonlinear optimization IK (requires DrakeWorld) - -## Usage - -Use factory functions to create IK solvers: - -```python -from dimos.manipulation.planning.factory import create_kinematics - -# Backend-agnostic (works with any WorldSpec) -kinematics = create_kinematics(name="jacobian") - -# Drake-specific (requires DrakeWorld, more accurate) -kinematics = create_kinematics(name="drake_optimization") - -result = kinematics.solve(world, robot_id, target_pose) -``` -""" - -from dimos.manipulation.planning.kinematics.drake_optimization_ik import ( - DrakeOptimizationIK, -) -from dimos.manipulation.planning.kinematics.jacobian_ik import JacobianIK -from dimos.manipulation.planning.kinematics.pinocchio_ik import ( - PinocchioIK, - PinocchioIKConfig, -) - -__all__ = ["DrakeOptimizationIK", "JacobianIK", "PinocchioIK", "PinocchioIKConfig"] diff --git a/dimos/manipulation/planning/kinematics/drake_optimization_ik.py b/dimos/manipulation/planning/kinematics/drake_optimization_ik.py index 1e6b1962a5..b13aa8947a 100644 --- a/dimos/manipulation/planning/kinematics/drake_optimization_ik.py +++ b/dimos/manipulation/planning/kinematics/drake_optimization_ik.py @@ -20,10 +20,13 @@ import numpy as np -from dimos.manipulation.planning.spec import IKResult, IKStatus, WorldRobotID, WorldSpec +from dimos.manipulation.planning.spec.enums import IKStatus +from dimos.manipulation.planning.spec.models import IKResult, WorldRobotID +from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.kinematics_utils import compute_pose_error -from dimos.msgs.geometry_msgs import PoseStamped, Transform -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import pose_to_matrix diff --git a/dimos/manipulation/planning/kinematics/jacobian_ik.py b/dimos/manipulation/planning/kinematics/jacobian_ik.py index 5f80642058..fb493e2d5f 100644 --- a/dimos/manipulation/planning/kinematics/jacobian_ik.py +++ b/dimos/manipulation/planning/kinematics/jacobian_ik.py @@ -28,7 +28,9 @@ import numpy as np -from dimos.manipulation.planning.spec import IKResult, IKStatus, WorldRobotID, WorldSpec +from dimos.manipulation.planning.spec.enums import IKStatus +from dimos.manipulation.planning.spec.models import IKResult, WorldRobotID +from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.kinematics_utils import ( check_singularity, compute_error_twist, @@ -41,8 +43,11 @@ if TYPE_CHECKING: from numpy.typing import NDArray -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Twist, Vector3 -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState logger = setup_logger() @@ -395,7 +400,7 @@ def solve_differential_position_only( return JointState(name=joint_names, velocity=q_dot.tolist()) -# ============= Result Helpers ============= +# Result Helpers def _create_success_result( diff --git a/dimos/manipulation/planning/kinematics/pinocchio_ik.py b/dimos/manipulation/planning/kinematics/pinocchio_ik.py index 4224dda556..cb6ee91608 100644 --- a/dimos/manipulation/planning/kinematics/pinocchio_ik.py +++ b/dimos/manipulation/planning/kinematics/pinocchio_ik.py @@ -44,16 +44,12 @@ if TYPE_CHECKING: from numpy.typing import NDArray - from dimos.msgs.geometry_msgs import Pose, PoseStamped + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped logger = setup_logger() -# ============================================================================= -# Configuration -# ============================================================================= - - @dataclass class PinocchioIKConfig: """Configuration for the Pinocchio IK solver. @@ -73,11 +69,6 @@ class PinocchioIKConfig: max_velocity: float = 10.0 -# ============================================================================= -# PinocchioIK Solver -# ============================================================================= - - class PinocchioIK: """Pinocchio-based damped least-squares IK solver. @@ -162,10 +153,6 @@ def ee_joint_id(self) -> int: """End-effector joint ID.""" return self._ee_joint_id - # ========================================================================= - # Core IK - # ========================================================================= - def solve( self, target_pose: pinocchio.SE3, @@ -208,10 +195,6 @@ def solve( return q, False, final_err - # ========================================================================= - # Forward Kinematics - # ========================================================================= - def forward_kinematics(self, joint_positions: NDArray[np.floating[Any]]) -> pinocchio.SE3: """Compute end-effector pose from joint positions. @@ -225,11 +208,6 @@ def forward_kinematics(self, joint_positions: NDArray[np.floating[Any]]) -> pino return self._data.oMi[self._ee_joint_id].copy() -# ============================================================================= -# Pose Conversion Helpers -# ============================================================================= - - def pose_to_se3(pose: Pose | PoseStamped) -> pinocchio.SE3: """Convert Pose or PoseStamped to pinocchio SE3""" @@ -239,11 +217,6 @@ def pose_to_se3(pose: Pose | PoseStamped) -> pinocchio.SE3: return pinocchio.SE3(rotation, position) -# ============================================================================= -# Safety Utilities -# ============================================================================= - - def check_joint_delta( q_new: NDArray[np.floating[Any]], q_current: NDArray[np.floating[Any]], diff --git a/dimos/manipulation/planning/monitor/__init__.py b/dimos/manipulation/planning/monitor/__init__.py deleted file mode 100644 index c280bd4d56..0000000000 --- a/dimos/manipulation/planning/monitor/__init__.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -World Monitor Module - -Provides reactive monitoring for keeping WorldSpec synchronized with the real world. - -## Components - -- WorldMonitor: Top-level monitor using WorldSpec Protocol -- WorldStateMonitor: Syncs joint state to WorldSpec -- WorldObstacleMonitor: Syncs obstacles to WorldSpec - -All monitors use the factory pattern and Protocol types. - -## Example - -```python -from dimos.manipulation.planning.monitor import WorldMonitor - -monitor = WorldMonitor(enable_viz=True) -robot_id = monitor.add_robot(config) -monitor.finalize() - -# Start monitoring -monitor.start_state_monitor(robot_id) -monitor.start_obstacle_monitor() - -# Handle joint state messages -monitor.on_joint_state(msg, robot_id) - -# Thread-safe collision checking -is_valid = monitor.is_state_valid(robot_id, q_test) -``` -""" - -from dimos.manipulation.planning.monitor.world_monitor import WorldMonitor -from dimos.manipulation.planning.monitor.world_obstacle_monitor import ( - WorldObstacleMonitor, -) -from dimos.manipulation.planning.monitor.world_state_monitor import WorldStateMonitor - -# Re-export message types from spec for convenience -from dimos.manipulation.planning.spec import CollisionObjectMessage - -__all__ = [ - "CollisionObjectMessage", - "WorldMonitor", - "WorldObstacleMonitor", - "WorldStateMonitor", -] diff --git a/dimos/manipulation/planning/monitor/world_monitor.py b/dimos/manipulation/planning/monitor/world_monitor.py index 33017957dc..32f519dfd4 100644 --- a/dimos/manipulation/planning/monitor/world_monitor.py +++ b/dimos/manipulation/planning/monitor/world_monitor.py @@ -23,8 +23,8 @@ from dimos.manipulation.planning.factory import create_world from dimos.manipulation.planning.monitor.world_obstacle_monitor import WorldObstacleMonitor from dimos.manipulation.planning.monitor.world_state_monitor import WorldStateMonitor -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -33,15 +33,15 @@ import numpy as np from numpy.typing import NDArray - from dimos.manipulation.planning.spec import ( + from dimos.manipulation.planning.spec.config import RobotModelConfig + from dimos.manipulation.planning.spec.models import ( CollisionObjectMessage, JointPath, Obstacle, - RobotModelConfig, WorldRobotID, - WorldSpec, ) - from dimos.msgs.vision_msgs import Detection3D + from dimos.manipulation.planning.spec.protocols import WorldSpec + from dimos.msgs.vision_msgs.Detection3D import Detection3D from dimos.perception.detection.type.detection3d.object import Object logger = setup_logger() @@ -66,7 +66,7 @@ def __init__( self._viz_stop_event = threading.Event() self._viz_rate_hz: float = 10.0 - # ============= Robot Management ============= + # Robot Management def add_robot(self, config: RobotModelConfig) -> WorldRobotID: """Add a robot. Returns robot_id.""" @@ -93,7 +93,7 @@ def get_joint_limits( with self._lock: return self._world.get_joint_limits(robot_id) - # ============= Obstacle Management ============= + # Obstacle Management def add_obstacle(self, obstacle: Obstacle) -> str: """Add an obstacle. Returns obstacle_id.""" @@ -110,7 +110,7 @@ def clear_obstacles(self) -> None: with self._lock: self._world.clear_obstacles() - # ============= Monitor Control ============= + # Monitor Control def start_state_monitor( self, @@ -181,7 +181,7 @@ def stop_all_monitors(self) -> None: self._world.close() - # ============= Message Handlers ============= + # Message Handlers def on_joint_state(self, msg: JointState, robot_id: WorldRobotID | None = None) -> None: """Handle joint state message. Broadcasts to all monitors if robot_id is None.""" @@ -252,7 +252,7 @@ def list_added_obstacles(self) -> list[dict[str, Any]]: return self._obstacle_monitor.list_added_obstacles() return [] - # ============= State Access ============= + # State Access def get_current_joint_state(self, robot_id: WorldRobotID) -> JointState | None: """Get current joint state. Returns None if not yet received.""" @@ -294,7 +294,7 @@ def is_state_stale(self, robot_id: WorldRobotID, max_age: float = 1.0) -> bool: return self._state_monitors[robot_id].is_state_stale(max_age) return True - # ============= Context Management ============= + # Context Management @contextmanager def scratch_context(self) -> Generator[Any, None, None]: @@ -306,7 +306,7 @@ def get_live_context(self) -> Any: """Get live context. Prefer scratch_context() for planning.""" return self._world.get_live_context() - # ============= Collision Checking ============= + # Collision Checking def is_state_valid(self, robot_id: WorldRobotID, joint_state: JointState) -> bool: """Check if configuration is collision-free.""" @@ -340,7 +340,7 @@ def get_min_distance(self, robot_id: WorldRobotID) -> float: with self._world.scratch_context() as ctx: return self._world.get_min_distance(ctx, robot_id) - # ============= Kinematics ============= + # Kinematics def get_ee_pose( self, robot_id: WorldRobotID, joint_state: JointState | None = None @@ -366,7 +366,7 @@ def get_link_pose( link_name: Name of the link in the URDF joint_state: Joint state to use (uses current if None) """ - from dimos.msgs.geometry_msgs import Quaternion + from dimos.msgs.geometry_msgs.Quaternion import Quaternion with self._world.scratch_context() as ctx: if joint_state is None: @@ -394,7 +394,7 @@ def get_jacobian(self, robot_id: WorldRobotID, joint_state: JointState) -> NDArr self._world.set_joint_state(ctx, robot_id, joint_state) return self._world.get_jacobian(ctx, robot_id) - # ============= Lifecycle ============= + # Lifecycle def finalize(self) -> None: """Finalize world. Must be called before collision checking.""" @@ -407,7 +407,7 @@ def is_finalized(self) -> bool: """Check if world is finalized.""" return self._world.is_finalized - # ============= Visualization ============= + # Visualization def get_visualization_url(self) -> str | None: """Get visualization URL or None if not enabled.""" @@ -466,7 +466,7 @@ def _visualization_loop(self) -> None: logger.debug(f"Visualization publish failed: {e}") time.sleep(period) - # ============= Direct World Access ============= + # Direct World Access @property def world(self) -> WorldSpec: diff --git a/dimos/manipulation/planning/monitor/world_obstacle_monitor.py b/dimos/manipulation/planning/monitor/world_obstacle_monitor.py index a96d3efaf6..a21ee68726 100644 --- a/dimos/manipulation/planning/monitor/world_obstacle_monitor.py +++ b/dimos/manipulation/planning/monitor/world_obstacle_monitor.py @@ -29,20 +29,17 @@ import time from typing import TYPE_CHECKING, Any -from dimos.manipulation.planning.spec import ( - CollisionObjectMessage, - Obstacle, - ObstacleType, -) -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.manipulation.planning.spec.enums import ObstacleType +from dimos.manipulation.planning.spec.models import CollisionObjectMessage, Obstacle +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: from collections.abc import Callable import threading - from dimos.manipulation.planning.spec import WorldSpec - from dimos.msgs.vision_msgs import Detection3D + from dimos.manipulation.planning.spec.protocols import WorldSpec + from dimos.msgs.vision_msgs.Detection3D import Detection3D from dimos.perception.detection.type.detection3d.object import Object logger = setup_logger() @@ -406,7 +403,7 @@ def remove_obstacle_callback( if callback in self._obstacle_callbacks: self._obstacle_callbacks.remove(callback) - # ============= Object-Based Perception (from ObjectDB) ============= + # Object-Based Perception (from ObjectDB) def on_objects(self, objects: list[object]) -> None: """Cache objects from ObjectDB (preserves stable object_id). diff --git a/dimos/manipulation/planning/monitor/world_state_monitor.py b/dimos/manipulation/planning/monitor/world_state_monitor.py index 87d61bb66f..8548251c73 100644 --- a/dimos/manipulation/planning/monitor/world_state_monitor.py +++ b/dimos/manipulation/planning/monitor/world_state_monitor.py @@ -31,7 +31,7 @@ import numpy as np -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -40,7 +40,7 @@ from numpy.typing import NDArray - from dimos.manipulation.planning.spec import WorldSpec + from dimos.manipulation.planning.spec.protocols import WorldSpec logger = setup_logger() diff --git a/dimos/manipulation/planning/planners/__init__.py b/dimos/manipulation/planning/planners/__init__.py deleted file mode 100644 index 8fb8ae042b..0000000000 --- a/dimos/manipulation/planning/planners/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Motion Planners Module - -Contains motion planning implementations that use WorldSpec. - -All planners are backend-agnostic - they only use WorldSpec methods and -work with any physics backend (Drake, MuJoCo, PyBullet, etc.). - -## Implementations - -- RRTConnectPlanner: Bi-directional RRT-Connect planner (fast, reliable) - -## Usage - -Use factory functions to create planners: - -```python -from dimos.manipulation.planning.factory import create_planner - -planner = create_planner(name="rrt_connect") # Returns PlannerSpec -result = planner.plan_joint_path(world, robot_id, q_start, q_goal) -``` -""" - -from dimos.manipulation.planning.planners.rrt_planner import RRTConnectPlanner - -__all__ = ["RRTConnectPlanner"] diff --git a/dimos/manipulation/planning/planners/rrt_planner.py b/dimos/manipulation/planning/planners/rrt_planner.py index f2be8736d5..7f308dce0c 100644 --- a/dimos/manipulation/planning/planners/rrt_planner.py +++ b/dimos/manipulation/planning/planners/rrt_planner.py @@ -26,15 +26,11 @@ import numpy as np -from dimos.manipulation.planning.spec import ( - JointPath, - PlanningResult, - PlanningStatus, - WorldRobotID, - WorldSpec, -) +from dimos.manipulation.planning.spec.enums import PlanningStatus +from dimos.manipulation.planning.spec.models import JointPath, PlanningResult, WorldRobotID +from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.path_utils import compute_path_length -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -315,7 +311,7 @@ def _simplify_path( return simplified -# ============= Result Helpers ============= +# Result Helpers def _create_success_result( diff --git a/dimos/manipulation/planning/spec/__init__.py b/dimos/manipulation/planning/spec/__init__.py deleted file mode 100644 index a78fb6e5fd..0000000000 --- a/dimos/manipulation/planning/spec/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Manipulation Planning Specifications.""" - -from dimos.manipulation.planning.spec.config import RobotModelConfig -from dimos.manipulation.planning.spec.enums import IKStatus, ObstacleType, PlanningStatus -from dimos.manipulation.planning.spec.protocols import ( - KinematicsSpec, - PlannerSpec, - WorldSpec, -) -from dimos.manipulation.planning.spec.types import ( - CollisionObjectMessage, - IKResult, - Jacobian, - JointPath, - Obstacle, - PlanningResult, - RobotName, - WorldRobotID, -) - -__all__ = [ - "CollisionObjectMessage", - "IKResult", - "IKStatus", - "Jacobian", - "JointPath", - "KinematicsSpec", - "Obstacle", - "ObstacleType", - "PlannerSpec", - "PlanningResult", - "PlanningStatus", - "RobotModelConfig", - "RobotName", - "WorldRobotID", - "WorldSpec", -] diff --git a/dimos/manipulation/planning/spec/config.py b/dimos/manipulation/planning/spec/config.py index dc302689ea..80cf248f08 100644 --- a/dimos/manipulation/planning/spec/config.py +++ b/dimos/manipulation/planning/spec/config.py @@ -16,17 +16,16 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from collections.abc import Iterable, Sequence +from pathlib import Path -if TYPE_CHECKING: - from pathlib import Path +from pydantic import Field - from dimos.msgs.geometry_msgs import PoseStamped +from dimos.core.module import ModuleConfig +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -@dataclass -class RobotModelConfig: +class RobotModelConfig(ModuleConfig): """Configuration for adding a robot to the world. Attributes: @@ -60,24 +59,24 @@ class RobotModelConfig: joint_names: list[str] end_effector_link: str base_link: str = "base_link" - package_paths: dict[str, Path] = field(default_factory=dict) + package_paths: dict[str, Path] = Field(default_factory=dict) joint_limits_lower: list[float] | None = None joint_limits_upper: list[float] | None = None velocity_limits: list[float] | None = None auto_convert_meshes: bool = False - xacro_args: dict[str, str] = field(default_factory=dict) - collision_exclusion_pairs: list[tuple[str, str]] = field(default_factory=list) + xacro_args: dict[str, str] = Field(default_factory=dict) + collision_exclusion_pairs: Iterable[tuple[str, str]] = () # Motion constraints for trajectory generation max_velocity: float = 1.0 max_acceleration: float = 2.0 # Coordinator integration - joint_name_mapping: dict[str, str] = field(default_factory=dict) + joint_name_mapping: dict[str, str] = Field(default_factory=dict) coordinator_task_name: str | None = None gripper_hardware_id: str | None = None # TF publishing for extra links (e.g., camera mount) - tf_extra_links: list[str] = field(default_factory=list) + tf_extra_links: Sequence[str] = () # Home/observe joint configuration for go_home skill - home_joints: list[float] | None = None + home_joints: Iterable[float] | None = None # Pre-grasp offset distance in meters (along approach direction) pre_grasp_offset: float = 0.10 diff --git a/dimos/manipulation/planning/spec/types.py b/dimos/manipulation/planning/spec/models.py similarity index 87% rename from dimos/manipulation/planning/spec/types.py rename to dimos/manipulation/planning/spec/models.py index a38cc0da26..37daa331e4 100644 --- a/dimos/manipulation/planning/spec/types.py +++ b/dimos/manipulation/planning/spec/models.py @@ -29,12 +29,9 @@ import numpy as np from numpy.typing import NDArray - from dimos.msgs.geometry_msgs import PoseStamped - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.sensor_msgs.JointState import JointState -# ============================================================================= -# Semantic ID Types (documentation only, not enforced at runtime) -# ============================================================================= RobotName: TypeAlias = str """User-facing robot name (e.g., 'left_arm', 'right_arm')""" @@ -45,19 +42,11 @@ JointPath: TypeAlias = "list[JointState]" """List of joint states forming a path (each waypoint has names + positions)""" -# ============================================================================= -# Numeric Array Types -# ============================================================================= Jacobian: TypeAlias = "NDArray[np.float64]" """6 x n Jacobian matrix (rows: [vx, vy, vz, wx, wy, wz])""" -# ============================================================================= -# Data Classes -# ============================================================================= - - @dataclass class Obstacle: """Obstacle specification for collision avoidance. diff --git a/dimos/manipulation/planning/spec/protocols.py b/dimos/manipulation/planning/spec/protocols.py index dea4718abb..76ecd1780b 100644 --- a/dimos/manipulation/planning/spec/protocols.py +++ b/dimos/manipulation/planning/spec/protocols.py @@ -29,15 +29,15 @@ from numpy.typing import NDArray from dimos.manipulation.planning.spec.config import RobotModelConfig - from dimos.manipulation.planning.spec.types import ( + from dimos.manipulation.planning.spec.models import ( IKResult, JointPath, Obstacle, PlanningResult, WorldRobotID, ) - from dimos.msgs.geometry_msgs import PoseStamped - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.sensor_msgs.JointState import JointState @runtime_checkable diff --git a/dimos/manipulation/planning/trajectory_generator/__init__.py b/dimos/manipulation/planning/trajectory_generator/__init__.py deleted file mode 100644 index a7449cf45f..0000000000 --- a/dimos/manipulation/planning/trajectory_generator/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Trajectory Generator Module - -Generates time-parameterized trajectories from waypoints. -""" - -from dimos.manipulation.planning.trajectory_generator.joint_trajectory_generator import ( - JointTrajectoryGenerator, -) - -__all__ = ["JointTrajectoryGenerator"] diff --git a/dimos/manipulation/planning/trajectory_generator/joint_trajectory_generator.py b/dimos/manipulation/planning/trajectory_generator/joint_trajectory_generator.py index 6b732d133c..1ac6b74351 100644 --- a/dimos/manipulation/planning/trajectory_generator/joint_trajectory_generator.py +++ b/dimos/manipulation/planning/trajectory_generator/joint_trajectory_generator.py @@ -32,7 +32,8 @@ import math -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint class JointTrajectoryGenerator: diff --git a/dimos/manipulation/planning/trajectory_generator/spec.py b/dimos/manipulation/planning/trajectory_generator/spec.py index 5357679f28..0814f5dc0b 100644 --- a/dimos/manipulation/planning/trajectory_generator/spec.py +++ b/dimos/manipulation/planning/trajectory_generator/spec.py @@ -35,7 +35,7 @@ from typing import Protocol -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory class JointTrajectoryGeneratorSpec(Protocol): diff --git a/dimos/manipulation/planning/utils/__init__.py b/dimos/manipulation/planning/utils/__init__.py deleted file mode 100644 index 04ec1806b5..0000000000 --- a/dimos/manipulation/planning/utils/__init__.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Manipulation Planning Utilities - -Standalone utility functions for kinematics and path operations. -These are extracted from the old ABC base classes to enable composition over inheritance. - -## Modules - -- kinematics_utils: Jacobian operations, singularity detection, pose error computation -- path_utils: Path interpolation, simplification, length computation -""" - -from dimos.manipulation.planning.utils.kinematics_utils import ( - check_singularity, - compute_error_twist, - compute_pose_error, - damped_pseudoinverse, - get_manipulability, -) -from dimos.manipulation.planning.utils.path_utils import ( - compute_path_length, - interpolate_path, - interpolate_segment, -) - -__all__ = [ - # Kinematics utilities - "check_singularity", - "compute_error_twist", - # Path utilities - "compute_path_length", - "compute_pose_error", - "damped_pseudoinverse", - "get_manipulability", - "interpolate_path", - "interpolate_segment", -] diff --git a/dimos/manipulation/planning/utils/kinematics_utils.py b/dimos/manipulation/planning/utils/kinematics_utils.py index c9f3f95a3d..02e885f1ae 100644 --- a/dimos/manipulation/planning/utils/kinematics_utils.py +++ b/dimos/manipulation/planning/utils/kinematics_utils.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: from numpy.typing import NDArray - from dimos.manipulation.planning.spec import Jacobian + from dimos.manipulation.planning.spec.models import Jacobian def damped_pseudoinverse( diff --git a/dimos/manipulation/planning/utils/path_utils.py b/dimos/manipulation/planning/utils/path_utils.py index fbf8af4032..dd5de1a0a4 100644 --- a/dimos/manipulation/planning/utils/path_utils.py +++ b/dimos/manipulation/planning/utils/path_utils.py @@ -32,12 +32,13 @@ import numpy as np -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState if TYPE_CHECKING: from numpy.typing import NDArray - from dimos.manipulation.planning.spec import JointPath, WorldRobotID, WorldSpec + from dimos.manipulation.planning.spec.models import JointPath, WorldRobotID + from dimos.manipulation.planning.spec.protocols import WorldSpec def interpolate_path( diff --git a/dimos/manipulation/planning/world/__init__.py b/dimos/manipulation/planning/world/__init__.py deleted file mode 100644 index 8ddef7fdff..0000000000 --- a/dimos/manipulation/planning/world/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -World Module - -Contains world implementations that own the physics/collision backend. - -## Implementations - -- DrakeWorld: Uses Drake MultibodyPlant + SceneGraph -""" - -from dimos.manipulation.planning.world.drake_world import DrakeWorld - -__all__ = ["DrakeWorld"] diff --git a/dimos/manipulation/planning/world/drake_world.py b/dimos/manipulation/planning/world/drake_world.py index 2ab996f410..ce155253ca 100644 --- a/dimos/manipulation/planning/world/drake_world.py +++ b/dimos/manipulation/planning/world/drake_world.py @@ -25,14 +25,10 @@ import numpy as np -from dimos.manipulation.planning.spec import ( - JointPath, - Obstacle, - ObstacleType, - RobotModelConfig, - WorldRobotID, - WorldSpec, -) +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.manipulation.planning.spec.enums import ObstacleType +from dimos.manipulation.planning.spec.models import JointPath, Obstacle, WorldRobotID +from dimos.manipulation.planning.spec.protocols import WorldSpec from dimos.manipulation.planning.utils.mesh_utils import prepare_urdf_for_drake from dimos.utils.logging_config import setup_logger @@ -41,8 +37,9 @@ from numpy.typing import NDArray -from dimos.msgs.geometry_msgs import PoseStamped, Transform -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.JointState import JointState try: from pydrake.geometry import ( # type: ignore[import-not-found] @@ -124,8 +121,6 @@ def _call(self, fn: Any, *args: Any, **kwargs: Any) -> Any: return fn(*args, **kwargs) return self._executor.submit(fn, *args, **kwargs).result() - # --- Meshcat proxies --- - def SetObject(self, *args: Any, **kwargs: Any) -> Any: return self._call(self._inner.SetObject, *args, **kwargs) @@ -327,7 +322,7 @@ def get_joint_limits( np.full(n_joints, np.pi), ) - # ============= Obstacle Management ============= + # Obstacle Management def add_obstacle(self, obstacle: Obstacle) -> str: """Add an obstacle to the world.""" @@ -536,7 +531,7 @@ def clear_obstacles(self) -> None: for obs_id in obstacle_ids: self.remove_obstacle(obs_id) - # ============= Preview Robot Setup ============= + # Preview Robot Setup def _set_preview_colors(self) -> None: """Set all preview robot visual geometries to yellow/semi-transparent.""" @@ -565,7 +560,7 @@ def _remove_preview_collision_roles(self) -> None: for geom_id in self._plant.GetCollisionGeometriesForBody(body): self._scene_graph.RemoveRole(source_id, geom_id, Role.kProximity) - # ============= Lifecycle ============= + # Lifecycle def finalize(self) -> None: """Finalize world - locks robot topology, enables collision checking.""" @@ -683,7 +678,7 @@ def _exclude_body_pair(self, body1: Any, body2: Any) -> None: ) ) - # ============= Context Management ============= + # Context Management def get_live_context(self) -> Context: """Get the live context (mirrors current robot state). @@ -736,7 +731,7 @@ def sync_from_joint_state(self, robot_id: WorldRobotID, joint_state: JointState) # Calling ForcedPublish from the LCM callback thread blocks message processing. # Visualization can be updated via publish_to_meshcat() from non-callback contexts. - # ============= State Operations (context-based) ============= + # State Operations (context-based) def set_joint_state( self, ctx: Context, robot_id: WorldRobotID, joint_state: JointState @@ -782,7 +777,7 @@ def get_joint_state(self, ctx: Context, robot_id: WorldRobotID) -> JointState: positions = [float(full_positions[idx]) for idx in robot_data.joint_indices] return JointState(name=robot_data.config.joint_names, position=positions) - # ============= Collision Checking (context-based) ============= + # Collision Checking (context-based) def is_collision_free(self, ctx: Context, robot_id: WorldRobotID) -> bool: """Check if current configuration in context is collision-free.""" @@ -812,7 +807,7 @@ def get_min_distance(self, ctx: Context, robot_id: WorldRobotID) -> float: return float(min(pair.distance for pair in signed_distance_pairs)) - # ============= Collision Checking (context-free, for planning) ============= + # Collision Checking (context-free, for planning) def check_config_collision_free(self, robot_id: WorldRobotID, joint_state: JointState) -> bool: """Check if a joint state is collision-free (manages context internally). @@ -859,7 +854,7 @@ def check_edge_collision_free( return True - # ============= Forward Kinematics (context-based) ============= + # Forward Kinematics (context-based) def get_ee_pose(self, ctx: Context, robot_id: WorldRobotID) -> PoseStamped: """Get end-effector pose.""" @@ -944,7 +939,7 @@ def get_jacobian(self, ctx: Context, robot_id: WorldRobotID) -> NDArray[np.float return J_reordered - # ============= Visualization ============= + # Visualization def get_visualization_url(self) -> str | None: """Get visualization URL if enabled.""" @@ -1029,7 +1024,7 @@ def close(self) -> None: if self._meshcat is not None: self._meshcat.close() - # ============= Direct Access (use with caution) ============= + # Direct Access (use with caution) @property def plant(self) -> MultibodyPlant: diff --git a/dimos/manipulation/test_manipulation_module.py b/dimos/manipulation/test_manipulation_module.py index c30ba9b55c..46a196e28c 100644 --- a/dimos/manipulation/test_manipulation_module.py +++ b/dimos/manipulation/test_manipulation_module.py @@ -30,9 +30,12 @@ ManipulationModule, ManipulationState, ) -from dimos.manipulation.planning.spec import RobotModelConfig -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import JointState +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.utils.data import get_data diff --git a/dimos/manipulation/test_manipulation_unit.py b/dimos/manipulation/test_manipulation_unit.py index 4aa232c74f..67ca8332b4 100644 --- a/dimos/manipulation/test_manipulation_unit.py +++ b/dimos/manipulation/test_manipulation_unit.py @@ -26,13 +26,12 @@ ManipulationModule, ManipulationState, ) -from dimos.manipulation.planning.spec import RobotModelConfig -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.trajectory_msgs import JointTrajectory, TrajectoryPoint - -# ============================================================================= -# Fixtures -# ============================================================================= +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory +from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint @pytest.fixture @@ -103,11 +102,6 @@ def _make_module(): return module -# ============================================================================= -# Test State Machine -# ============================================================================= - - class TestStateMachine: """Test state transitions.""" @@ -167,11 +161,6 @@ def test_begin_planning_state_checks(self, robot_config): assert module._begin_planning() is None -# ============================================================================= -# Test Robot Selection -# ============================================================================= - - class TestRobotSelection: """Test robot selection logic.""" @@ -201,11 +190,6 @@ def test_multiple_robots_require_name(self, robot_config): assert result[0] == "left" -# ============================================================================= -# Test Joint Name Translation (for coordinator integration) -# ============================================================================= - - class TestJointNameTranslation: """Test trajectory joint name translation for coordinator.""" @@ -227,11 +211,6 @@ def test_mapping_translates_names(self, robot_config_with_mapping, simple_trajec assert len(result.points) == 2 # Points preserved -# ============================================================================= -# Test Execute Method -# ============================================================================= - - class TestExecute: """Test coordinator execution.""" @@ -288,11 +267,6 @@ def test_execute_rejected(self, robot_config, simple_trajectory): assert module._state == ManipulationState.FAULT -# ============================================================================= -# Test RobotModelConfig Mapping Helpers -# ============================================================================= - - class TestRobotModelConfigMapping: """Test RobotModelConfig joint name mapping helpers.""" diff --git a/dimos/mapping/__init__.py b/dimos/mapping/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/mapping/costmapper.py b/dimos/mapping/costmapper.py index fa0ce826f2..06bf493564 100644 --- a/dimos/mapping/costmapper.py +++ b/dimos/mapping/costmapper.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import asdict, dataclass, field +from dataclasses import asdict import time +from pydantic import Field from reactivex import operators as ops from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.mapping.pointclouds.occupancy import ( @@ -26,30 +26,24 @@ HeightCostConfig, OccupancyConfig, ) -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.logging_config import setup_logger logger = setup_logger() -@dataclass class Config(ModuleConfig): algo: str = "height_cost" - config: OccupancyConfig = field(default_factory=HeightCostConfig) + config: OccupancyConfig = Field(default_factory=HeightCostConfig) -class CostMapper(Module): +class CostMapper(Module[Config]): default_config = Config - config: Config global_map: In[PointCloud2] global_costmap: Out[OccupancyGrid] - def __init__(self, cfg: GlobalConfig = global_config, **kwargs: object) -> None: - super().__init__(**kwargs) - self._global_config = cfg - @rpc def start(self) -> None: super().start() diff --git a/dimos/mapping/google_maps/google_maps.py b/dimos/mapping/google_maps/google_maps.py index 7f5ce32e99..18a1e25e2b 100644 --- a/dimos/mapping/google_maps/google_maps.py +++ b/dimos/mapping/google_maps/google_maps.py @@ -16,14 +16,14 @@ import googlemaps # type: ignore[import-untyped] -from dimos.mapping.google_maps.types import ( +from dimos.mapping.google_maps.models import ( Coordinates, LocationContext, NearbyPlace, PlacePosition, Position, ) -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon from dimos.mapping.utils.distance import distance_in_meters from dimos.utils.logging_config import setup_logger diff --git a/dimos/mapping/google_maps/types.py b/dimos/mapping/google_maps/models.py similarity index 100% rename from dimos/mapping/google_maps/types.py rename to dimos/mapping/google_maps/models.py diff --git a/dimos/mapping/google_maps/test_google_maps.py b/dimos/mapping/google_maps/test_google_maps.py index 13f7fa8eaa..2805f5589c 100644 --- a/dimos/mapping/google_maps/test_google_maps.py +++ b/dimos/mapping/google_maps/test_google_maps.py @@ -13,7 +13,7 @@ # limitations under the License. -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon def test_get_position(maps_client, maps_fixture) -> None: diff --git a/dimos/mapping/types.py b/dimos/mapping/models.py similarity index 100% rename from dimos/mapping/types.py rename to dimos/mapping/models.py diff --git a/dimos/mapping/occupancy/path_mask.py b/dimos/mapping/occupancy/path_mask.py index 5ad3010111..7744ab95ba 100644 --- a/dimos/mapping/occupancy/path_mask.py +++ b/dimos/mapping/occupancy/path_mask.py @@ -16,8 +16,8 @@ import numpy as np from numpy.typing import NDArray -from dimos.msgs.nav_msgs import Path from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path def make_path_mask( diff --git a/dimos/mapping/occupancy/path_resampling.py b/dimos/mapping/occupancy/path_resampling.py index 2090bf8f04..4d957a1aad 100644 --- a/dimos/mapping/occupancy/path_resampling.py +++ b/dimos/mapping/occupancy/path_resampling.py @@ -18,8 +18,11 @@ import numpy as np from scipy.ndimage import uniform_filter1d # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Vector3 -from dimos.msgs.nav_msgs import Path +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Path import Path from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion diff --git a/dimos/mapping/occupancy/test_path_mask.py b/dimos/mapping/occupancy/test_path_mask.py index dede997946..f566af2a23 100644 --- a/dimos/mapping/occupancy/test_path_mask.py +++ b/dimos/mapping/occupancy/test_path_mask.py @@ -19,9 +19,9 @@ from dimos.mapping.occupancy.path_mask import make_path_mask from dimos.mapping.occupancy.path_resampling import smooth_resample_path from dimos.mapping.occupancy.visualizations import visualize_occupancy_grid -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.navigation.replanning_a_star.min_cost_astar import min_cost_astar from dimos.utils.data import get_data diff --git a/dimos/mapping/occupancy/test_path_resampling.py b/dimos/mapping/occupancy/test_path_resampling.py index c23f71cf89..aeda7d11ad 100644 --- a/dimos/mapping/occupancy/test_path_resampling.py +++ b/dimos/mapping/occupancy/test_path_resampling.py @@ -18,7 +18,7 @@ from dimos.mapping.occupancy.gradient import gradient from dimos.mapping.occupancy.path_resampling import simple_resample_path, smooth_resample_path from dimos.mapping.occupancy.visualize_path import visualize_path -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid from dimos.msgs.sensor_msgs.Image import Image diff --git a/dimos/mapping/occupancy/visualizations.py b/dimos/mapping/occupancy/visualizations.py index 2ed0364257..36321896be 100644 --- a/dimos/mapping/occupancy/visualizations.py +++ b/dimos/mapping/occupancy/visualizations.py @@ -19,8 +19,8 @@ import numpy as np from numpy.typing import NDArray -from dimos.msgs.nav_msgs import Path from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.msgs.sensor_msgs.Image import Image, ImageFormat Palette: TypeAlias = Literal["rainbow", "turbo"] diff --git a/dimos/mapping/occupancy/visualize_path.py b/dimos/mapping/occupancy/visualize_path.py index 0662582f72..89dcf83067 100644 --- a/dimos/mapping/occupancy/visualize_path.py +++ b/dimos/mapping/occupancy/visualize_path.py @@ -16,8 +16,8 @@ import numpy as np from dimos.mapping.occupancy.visualizations import visualize_occupancy_grid -from dimos.msgs.nav_msgs import Path from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.msgs.sensor_msgs.Image import Image, ImageFormat diff --git a/dimos/mapping/osm/__init__.py b/dimos/mapping/osm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/mapping/osm/current_location_map.py b/dimos/mapping/osm/current_location_map.py index ef0a832cd6..4cfeddc9b8 100644 --- a/dimos/mapping/osm/current_location_map.py +++ b/dimos/mapping/osm/current_location_map.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + from PIL import Image as PILImage, ImageDraw +from dimos.mapping.models import LatLon from dimos.mapping.osm.osm import MapImage, get_osm_map from dimos.mapping.osm.query import query_for_one_position, query_for_one_position_and_context -from dimos.mapping.types import LatLon from dimos.models.vl.base import VlModel from dimos.utils.logging_config import setup_logger @@ -24,11 +26,11 @@ class CurrentLocationMap: - _vl_model: VlModel + _vl_model: VlModel[Any] _position: LatLon | None _map_image: MapImage | None - def __init__(self, vl_model: VlModel) -> None: + def __init__(self, vl_model: VlModel[Any]) -> None: self._vl_model = vl_model self._position = None self._map_image = None diff --git a/dimos/mapping/osm/osm.py b/dimos/mapping/osm/osm.py index 31fb044087..f9b7eaafda 100644 --- a/dimos/mapping/osm/osm.py +++ b/dimos/mapping/osm/osm.py @@ -21,8 +21,8 @@ from PIL import Image as PILImage import requests # type: ignore[import-untyped] -from dimos.mapping.types import ImageCoord, LatLon -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.mapping.models import ImageCoord, LatLon +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat @dataclass(frozen=True) diff --git a/dimos/mapping/osm/query.py b/dimos/mapping/osm/query.py index 410f879c20..7a3c3b0154 100644 --- a/dimos/mapping/osm/query.py +++ b/dimos/mapping/osm/query.py @@ -13,9 +13,10 @@ # limitations under the License. import re +from typing import Any +from dimos.mapping.models import LatLon from dimos.mapping.osm.osm import MapImage -from dimos.mapping.types import LatLon from dimos.models.vl.base import VlModel from dimos.utils.generic import extract_json_from_llm_response from dimos.utils.logging_config import setup_logger @@ -25,7 +26,9 @@ logger = setup_logger() -def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) -> LatLon | None: +def query_for_one_position( + vl_model: VlModel[Any], map_image: MapImage, query: str +) -> LatLon | None: full_query = f"{_PROLOGUE} {query} {_JSON} If there's a match return the x, y coordinates from the image. Example: `[123, 321]`. If there's no match return `null`." response = vl_model.query(map_image.image, full_query) coords = tuple(map(int, re.findall(r"\d+", response))) @@ -35,7 +38,7 @@ def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) - def query_for_one_position_and_context( - vl_model: VlModel, map_image: MapImage, query: str, robot_position: LatLon + vl_model: VlModel[Any], map_image: MapImage, query: str, robot_position: LatLon ) -> tuple[LatLon, str] | None: example = '{"coordinates": [123, 321], "description": "A Starbucks on 27th Street"}' x, y = map_image.latlon_to_pixel(robot_position) diff --git a/dimos/mapping/osm/test_osm.py b/dimos/mapping/osm/test_osm.py index 475e2b40fc..64fbb72b02 100644 --- a/dimos/mapping/osm/test_osm.py +++ b/dimos/mapping/osm/test_osm.py @@ -21,8 +21,8 @@ from requests import Request import requests_mock +from dimos.mapping.models import LatLon from dimos.mapping.osm.osm import get_osm_map -from dimos.mapping.types import LatLon from dimos.utils.data import get_data _fixture_dir = get_data("osm_map_test") diff --git a/dimos/mapping/pointclouds/demo.py b/dimos/mapping/pointclouds/demo.py index 5251fc3406..2812aaae42 100644 --- a/dimos/mapping/pointclouds/demo.py +++ b/dimos/mapping/pointclouds/demo.py @@ -25,8 +25,8 @@ read_pointcloud, visualize, ) -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data app = typer.Typer() diff --git a/dimos/mapping/pointclouds/occupancy.py b/dimos/mapping/pointclouds/occupancy.py index 0f6ad8c0de..c9cd7e7af3 100644 --- a/dimos/mapping/pointclouds/occupancy.py +++ b/dimos/mapping/pointclouds/occupancy.py @@ -21,7 +21,7 @@ import numpy as np from scipy import ndimage # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid if TYPE_CHECKING: @@ -99,7 +99,7 @@ def _simple_occupancy_kernel( if TYPE_CHECKING: from collections.abc import Callable - from dimos.msgs.sensor_msgs import PointCloud2 + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 @dataclass(frozen=True) diff --git a/dimos/mapping/pointclouds/test_occupancy.py b/dimos/mapping/pointclouds/test_occupancy.py index d265800f24..93b5793dc8 100644 --- a/dimos/mapping/pointclouds/test_occupancy.py +++ b/dimos/mapping/pointclouds/test_occupancy.py @@ -26,8 +26,8 @@ ) from dimos.mapping.pointclouds.util import read_pointcloud from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data from dimos.utils.testing.moment import OutputMoment from dimos.utils.testing.test_moment import Go2Moment diff --git a/dimos/mapping/pointclouds/test_occupancy_speed.py b/dimos/mapping/pointclouds/test_occupancy_speed.py index 2def839dd5..ac4085e971 100644 --- a/dimos/mapping/pointclouds/test_occupancy_speed.py +++ b/dimos/mapping/pointclouds/test_occupancy_speed.py @@ -21,7 +21,7 @@ from dimos.mapping.voxels import VoxelGridMapper from dimos.utils.cli.plot import bar from dimos.utils.data import get_data, get_data_dir -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay @pytest.mark.tool diff --git a/dimos/mapping/test_voxels.py b/dimos/mapping/test_voxels.py index 95e70e1d6d..bb5f4ed764 100644 --- a/dimos/mapping/test_voxels.py +++ b/dimos/mapping/test_voxels.py @@ -20,7 +20,7 @@ from dimos.core.transport import LCMTransport from dimos.mapping.voxels import VoxelGridMapper -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data from dimos.utils.testing.moment import OutputMoment from dimos.utils.testing.replay import TimedSensorReplay diff --git a/dimos/mapping/utils/distance.py b/dimos/mapping/utils/distance.py index 6e8c48c205..42b8a9be04 100644 --- a/dimos/mapping/utils/distance.py +++ b/dimos/mapping/utils/distance.py @@ -14,7 +14,7 @@ import math -from dimos.mapping.types import LatLon +from dimos.mapping.models import LatLon def distance_in_meters(location1: LatLon, location2: LatLon) -> float: diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py index 124073cf49..e4e03dfc01 100644 --- a/dimos/mapping/voxels.py +++ b/dimos/mapping/voxels.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import time +from typing import Any import numpy as np import open3d as o3d # type: ignore[import-untyped] @@ -23,18 +23,16 @@ from reactivex.subject import Subject from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.utils.decorators import simple_mcache +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.logging_config import setup_logger from dimos.utils.reactive import backpressure logger = setup_logger() -@dataclass class Config(ModuleConfig): frame_id: str = "world" # -1 never publishes, 0 publishes on every frame, >0 publishes at interval in seconds @@ -45,16 +43,14 @@ class Config(ModuleConfig): carve_columns: bool = True -class VoxelGridMapper(Module): +class VoxelGridMapper(Module[Config]): default_config = Config - config: Config lidar: In[PointCloud2] global_map: Out[PointCloud2] - def __init__(self, cfg: GlobalConfig = global_config, **kwargs: object) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self._global_config = cfg dev = ( o3c.Device(self.config.device) diff --git a/dimos/memory/embedding.py b/dimos/memory/embedding.py index 4627ecfc35..be73d01ac1 100644 --- a/dimos/memory/embedding.py +++ b/dimos/memory/embedding.py @@ -13,9 +13,10 @@ # limitations under the License. from collections.abc import Callable -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import cast +from pydantic import Field import reactivex as rx from reactivex import operators as ops from reactivex.observable import Observable @@ -25,16 +26,14 @@ from dimos.core.stream import In from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.models.embedding.clip import CLIPModel -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.utils.reactive import getter_hot -@dataclass class Config(ModuleConfig): - embedding_model: EmbeddingModel = field(default_factory=CLIPModel) + embedding_model: EmbeddingModel = Field(default_factory=CLIPModel) @dataclass @@ -50,7 +49,6 @@ class SpatialEmbedding(SpatialEntry): class EmbeddingMemory(Module[Config]): default_config = Config - config: Config color_image: In[Image] global_costmap: In[OccupancyGrid] diff --git a/dimos/memory/test_embedding.py b/dimos/memory/test_embedding.py index b7e7fbb294..9a59ed51e1 100644 --- a/dimos/memory/test_embedding.py +++ b/dimos/memory/test_embedding.py @@ -15,9 +15,9 @@ import pytest from dimos.memory.embedding import EmbeddingMemory, SpatialEntry -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay dir_name = "unitree_go2_bigoffice" diff --git a/dimos/memory/timeseries/__init__.py b/dimos/memory/timeseries/__init__.py deleted file mode 100644 index debc14ab3a..0000000000 --- a/dimos/memory/timeseries/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Time series storage and replay.""" - -from dimos.memory.timeseries.base import TimeSeriesStore -from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.memory.timeseries.pickledir import PickleDirStore -from dimos.memory.timeseries.sqlite import SqliteStore - - -def __getattr__(name: str): # type: ignore[no-untyped-def] - if name == "PostgresStore": - from dimos.memory.timeseries.postgres import PostgresStore - - return PostgresStore - if name == "reset_db": - from dimos.memory.timeseries.postgres import reset_db - - return reset_db - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -__all__ = [ - "InMemoryStore", - "PickleDirStore", - "PostgresStore", - "SqliteStore", - "TimeSeriesStore", - "reset_db", -] diff --git a/dimos/memory/timeseries/base.py b/dimos/memory/timeseries/base.py index 0d88355b5b..2831836020 100644 --- a/dimos/memory/timeseries/base.py +++ b/dimos/memory/timeseries/base.py @@ -92,8 +92,6 @@ def _find_after(self, timestamp: float) -> tuple[float, T] | None: """Find the first (ts, data) strictly after the given timestamp.""" ... - # --- Collection API (built on abstract methods) --- - def __len__(self) -> int: return self._count() diff --git a/dimos/memory/timeseries/legacy.py b/dimos/memory/timeseries/legacy.py index 15a4ff90fa..a98b0baddf 100644 --- a/dimos/memory/timeseries/legacy.py +++ b/dimos/memory/timeseries/legacy.py @@ -232,8 +232,6 @@ def _find_after(self, timestamp: float) -> tuple[float, T] | None: return (ts, data) return None - # === Backward-compatible API (TimedSensorReplay/SensorReplay) === - @property def files(self) -> list[Path]: """Return list of pickle files (backward compatibility with SensorReplay).""" diff --git a/dimos/models/__init__.py b/dimos/models/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/models/base.py b/dimos/models/base.py index 2269a6d0b8..fd18d8ba93 100644 --- a/dimos/models/base.py +++ b/dimos/models/base.py @@ -16,21 +16,19 @@ from __future__ import annotations -from dataclasses import dataclass from functools import cached_property from typing import Annotated, Any import torch from dimos.core.resource import Resource -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.protocol.service.spec import BaseConfig, Configurable # Device string type - 'cuda', 'cpu', 'cuda:0', 'cuda:1', etc. DeviceType = Annotated[str, "Device identifier (e.g., 'cuda', 'cpu', 'cuda:0')"] -@dataclass -class LocalModelConfig: +class LocalModelConfig(BaseConfig): device: DeviceType = "cuda" if torch.cuda.is_available() else "cpu" dtype: torch.dtype = torch.float32 warmup: bool = False @@ -127,7 +125,6 @@ def _ensure_cuda_initialized(self) -> None: pass -@dataclass class HuggingFaceModelConfig(LocalModelConfig): model_name: str = "" trust_remote_code: bool = True diff --git a/dimos/models/embedding/__init__.py b/dimos/models/embedding/__init__.py deleted file mode 100644 index 050d35467e..0000000000 --- a/dimos/models/embedding/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -from dimos.models.embedding.base import Embedding, EmbeddingModel - -__all__ = [ - "Embedding", - "EmbeddingModel", -] - -# Optional: CLIP support -try: - from dimos.models.embedding.clip import CLIPModel - - __all__.append("CLIPModel") -except ImportError: - pass - -# Optional: MobileCLIP support -try: - from dimos.models.embedding.mobileclip import MobileCLIPModel - - __all__.append("MobileCLIPModel") -except ImportError: - pass - -# Optional: TorchReID support -try: - from dimos.models.embedding.treid import TorchReIDModel - - __all__.append("TorchReIDModel") -except ImportError: - pass diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py index c6b78fcf2c..0c80cafc0a 100644 --- a/dimos/models/embedding/base.py +++ b/dimos/models/embedding/base.py @@ -15,7 +15,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass import time from typing import TYPE_CHECKING @@ -26,17 +25,15 @@ from dimos.types.timestamped import Timestamped if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image -@dataclass class EmbeddingModelConfig(LocalModelConfig): """Base config for embedding models.""" normalize: bool = True -@dataclass class HuggingFaceEmbeddingModelConfig(HuggingFaceModelConfig): """Base config for HuggingFace-based embedding models.""" diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py index 1b8d3e68bb..6fb42b7ccf 100644 --- a/dimos/models/embedding/clip.py +++ b/dimos/models/embedding/clip.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from functools import cached_property from PIL import Image as PILImage @@ -22,10 +21,9 @@ from dimos.models.base import HuggingFaceModel from dimos.models.embedding.base import Embedding, EmbeddingModel, HuggingFaceEmbeddingModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image -@dataclass class CLIPModelConfig(HuggingFaceEmbeddingModelConfig): model_name: str = "openai/clip-vit-base-patch32" dtype: torch.dtype = torch.float32 diff --git a/dimos/models/embedding/mobileclip.py b/dimos/models/embedding/mobileclip.py index c02361b367..84bba74829 100644 --- a/dimos/models/embedding/mobileclip.py +++ b/dimos/models/embedding/mobileclip.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from functools import cached_property from typing import Any @@ -23,11 +22,10 @@ from dimos.models.base import LocalModel from dimos.models.embedding.base import Embedding, EmbeddingModel, EmbeddingModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data -@dataclass class MobileCLIPModelConfig(EmbeddingModelConfig): model_name: str = "MobileCLIP2-S4" diff --git a/dimos/models/embedding/test_embedding.py b/dimos/models/embedding/test_embedding.py index 466c974b32..20aac83dbb 100644 --- a/dimos/models/embedding/test_embedding.py +++ b/dimos/models/embedding/test_embedding.py @@ -7,7 +7,7 @@ from dimos.models.embedding.clip import CLIPModel from dimos.models.embedding.mobileclip import MobileCLIPModel from dimos.models.embedding.treid import TorchReIDModel -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data diff --git a/dimos/models/embedding/treid.py b/dimos/models/embedding/treid.py index 85e32cd39b..21a4527781 100644 --- a/dimos/models/embedding/treid.py +++ b/dimos/models/embedding/treid.py @@ -16,7 +16,6 @@ warnings.filterwarnings("ignore", message="Cython evaluation.*unavailable", category=UserWarning) -from dataclasses import dataclass from functools import cached_property import torch @@ -25,14 +24,13 @@ from dimos.models.base import LocalModel from dimos.models.embedding.base import Embedding, EmbeddingModel, EmbeddingModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data # osnet models downloaded from https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html # into dimos/data/models_torchreid/ # feel free to add more -@dataclass class TorchReIDModelConfig(EmbeddingModelConfig): model_name: str = "osnet_x1_0" diff --git a/dimos/models/segmentation/edge_tam.py b/dimos/models/segmentation/edge_tam.py index 54158b2b92..e9744f6d81 100644 --- a/dimos/models/segmentation/edge_tam.py +++ b/dimos/models/segmentation/edge_tam.py @@ -28,9 +28,9 @@ from PIL import Image as PILImage import torch -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.detectors.base import Detector +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.perception.detection.type.detection2d.seg import Detection2DSeg from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger diff --git a/dimos/models/vl/__init__.py b/dimos/models/vl/__init__.py deleted file mode 100644 index 482a907cbd..0000000000 --- a/dimos/models/vl/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "base": ["Captioner", "VlModel"], - "florence": ["Florence2Model"], - "moondream": ["MoondreamVlModel"], - "moondream_hosted": ["MoondreamHostedVlModel"], - "openai": ["OpenAIVlModel"], - "qwen": ["QwenVlModel"], - }, -) diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index 41b240eaf9..08b83fc503 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -1,21 +1,26 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass import json import logging -from typing import TYPE_CHECKING, Any +import sys +from typing import Any import warnings from dimos.core.resource import Resource -from dimos.msgs.sensor_msgs import Image -from dimos.protocol.service import Configurable # type: ignore[attr-defined] +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection2d.point import Detection2DPoint +from dimos.protocol.service.spec import BaseConfig, Configurable from dimos.utils.data import get_data -from dimos.utils.decorators import retry +from dimos.utils.decorators.decorators import retry from dimos.utils.llm_utils import extract_json -if TYPE_CHECKING: - from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar logger = logging.getLogger(__name__) @@ -70,7 +75,7 @@ def vlm_detection_to_detection2d( Detection2DBBox instance or None if invalid """ # Here to prevent unwanted imports in the file. - from dimos.perception.detection.type import Detection2DBBox + from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox # Validate list/tuple structure if not isinstance(vlm_detection, (list, tuple)): @@ -127,7 +132,7 @@ def vlm_point_to_detection2d_point( Returns: Detection2DPoint instance or None if invalid """ - from dimos.perception.detection.type import Detection2DPoint + from dimos.perception.detection.type.detection2d.point import Detection2DPoint # Validate list/tuple structure if not isinstance(vlm_point, (list, tuple)): @@ -159,15 +164,17 @@ def vlm_point_to_detection2d_point( ) -@dataclass -class VlModelConfig: +class VlModelConfig(BaseConfig): """Configuration for VlModel.""" auto_resize: tuple[int, int] | None = None """Optional (width, height) tuple. If set, images are resized to fit.""" -class VlModel(Captioner, Resource, Configurable[VlModelConfig]): +_VlConfig = TypeVar("_VlConfig", bound=VlModelConfig) + + +class VlModel(Captioner, Resource, Configurable[_VlConfig]): """Vision-language model that can answer questions about images. Inherits from Captioner, providing a default caption() implementation @@ -176,8 +183,7 @@ class VlModel(Captioner, Resource, Configurable[VlModelConfig]): Implements Resource interface for lifecycle management. """ - default_config = VlModelConfig - config: VlModelConfig + default_config: type[_VlConfig] = VlModelConfig # type: ignore[assignment] def _prepare_image(self, image: Image) -> tuple[Image, float]: """Prepare image for inference, applying any configured transformations. @@ -256,7 +262,7 @@ def query_detections( self, image: Image, query: str, **kwargs: Any ) -> ImageDetections2D[Detection2DBBox]: # Here to prevent unwanted imports in the file. - from dimos.perception.detection.type import ImageDetections2D + from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D full_query = f"""show me bounding boxes in pixels for this query: `{query}` @@ -317,7 +323,7 @@ def query_points( ImageDetections2D containing Detection2DPoint instances """ # Here to prevent unwanted imports in the file. - from dimos.perception.detection.type import ImageDetections2D + from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D full_query = f"""Show me point coordinates in pixels for this query: `{query}` diff --git a/dimos/models/vl/create.py b/dimos/models/vl/create.py index 1f8819c8db..6c778d4104 100644 --- a/dimos/models/vl/create.py +++ b/dimos/models/vl/create.py @@ -1,11 +1,11 @@ -from typing import Literal +from typing import Any, Literal from dimos.models.vl.base import VlModel VlModelName = Literal["qwen", "moondream"] -def create(name: VlModelName) -> VlModel: +def create(name: VlModelName) -> VlModel[Any]: # This uses inline imports to only import what's needed. match name: case "qwen": diff --git a/dimos/models/vl/florence.py b/dimos/models/vl/florence.py index 2e6cf822a8..b68441328a 100644 --- a/dimos/models/vl/florence.py +++ b/dimos/models/vl/florence.py @@ -20,7 +20,7 @@ from dimos.models.base import HuggingFaceModel from dimos.models.vl.base import Captioner -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image class Florence2Model(HuggingFaceModel, Captioner): diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index f31611e867..0f5e501ef6 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from functools import cached_property from typing import Any import warnings @@ -9,16 +8,17 @@ from transformers import AutoModelForCausalLM # type: ignore[import-untyped] from dimos.models.base import HuggingFaceModel, HuggingFaceModelConfig -from dimos.models.vl.base import VlModel -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D +from dimos.models.vl.base import VlModel, VlModelConfig +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection2d.point import Detection2DPoint # Moondream works well with 512x512 max MOONDREAM_DEFAULT_AUTO_RESIZE = (512, 512) -@dataclass -class MoondreamConfig(HuggingFaceModelConfig): +class MoondreamConfig(HuggingFaceModelConfig, VlModelConfig): """Configuration for MoondreamVlModel.""" model_name: str = "vikhyatk/moondream2" @@ -26,10 +26,9 @@ class MoondreamConfig(HuggingFaceModelConfig): auto_resize: tuple[int, int] | None = MOONDREAM_DEFAULT_AUTO_RESIZE -class MoondreamVlModel(HuggingFaceModel, VlModel): +class MoondreamVlModel(HuggingFaceModel, VlModel[MoondreamConfig]): _model_class = AutoModelForCausalLM default_config = MoondreamConfig # type: ignore[assignment] - config: MoondreamConfig # type: ignore[assignment] @cached_property def _model(self) -> AutoModelForCausalLM: diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py index fc1f8b7a17..aad9fe514c 100644 --- a/dimos/models/vl/moondream_hosted.py +++ b/dimos/models/vl/moondream_hosted.py @@ -6,20 +6,23 @@ import numpy as np from PIL import Image as PILImage -from dimos.models.vl.base import VlModel -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D +from dimos.models.vl.base import VlModel, VlModelConfig +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection2d.point import Detection2DPoint -class MoondreamHostedVlModel(VlModel): - _api_key: str | None +class Config(VlModelConfig): + api_key: str | None = None - def __init__(self, api_key: str | None = None) -> None: - self._api_key = api_key + +class MoondreamHostedVlModel(VlModel[Config]): + default_config = Config @cached_property def _client(self) -> md.vl: - api_key = self._api_key or os.getenv("MOONDREAM_API_KEY") + api_key = self.config.api_key or os.getenv("MOONDREAM_API_KEY") if not api_key: raise ValueError( "Moondream API key must be provided or set in MOONDREAM_API_KEY environment variable" diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index f596f1ee1e..0486bbdb30 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from functools import cached_property import os from typing import Any @@ -7,21 +6,19 @@ from openai import OpenAI from dimos.models.vl.base import VlModel, VlModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.logging_config import setup_logger logger = setup_logger() -@dataclass class OpenAIVlModelConfig(VlModelConfig): model_name: str = "gpt-4o-mini" api_key: str | None = None -class OpenAIVlModel(VlModel): +class OpenAIVlModel(VlModel[OpenAIVlModelConfig]): default_config = OpenAIVlModelConfig - config: OpenAIVlModelConfig @cached_property def _client(self) -> OpenAI: diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index 93b31bf74c..202ce6759e 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from functools import cached_property import os from typing import Any @@ -7,10 +6,9 @@ from openai import OpenAI from dimos.models.vl.base import VlModel, VlModelConfig -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image -@dataclass class QwenVlModelConfig(VlModelConfig): """Configuration for Qwen VL model.""" @@ -18,9 +16,8 @@ class QwenVlModelConfig(VlModelConfig): api_key: str | None = None -class QwenVlModel(VlModel): +class QwenVlModel(VlModel[QwenVlModelConfig]): default_config = QwenVlModelConfig - config: QwenVlModelConfig @cached_property def _client(self) -> OpenAI: diff --git a/dimos/models/vl/test_base.py b/dimos/models/vl/test_base.py index 0cc5c90d0e..b0b03e70fa 100644 --- a/dimos/models/vl/test_base.py +++ b/dimos/models/vl/test_base.py @@ -6,8 +6,8 @@ from dimos.core.transport import LCMTransport from dimos.models.vl.moondream import MoondreamVlModel from dimos.models.vl.qwen import QwenVlModel -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.data import get_data # Captured actual response from Qwen API for cafe.jpg with query "humans" diff --git a/dimos/models/vl/test_captioner.py b/dimos/models/vl/test_captioner.py index c7ebb8fc63..734c83290e 100644 --- a/dimos/models/vl/test_captioner.py +++ b/dimos/models/vl/test_captioner.py @@ -6,7 +6,7 @@ from dimos.models.vl.florence import Florence2Model from dimos.models.vl.moondream import MoondreamVlModel -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data diff --git a/dimos/models/vl/test_vlm.py b/dimos/models/vl/test_vlm.py index 43dad0ef94..f0fd3b8d5a 100644 --- a/dimos/models/vl/test_vlm.py +++ b/dimos/models/vl/test_vlm.py @@ -11,8 +11,8 @@ from dimos.models.vl.moondream import MoondreamVlModel from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel from dimos.models.vl.qwen import QwenVlModel -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.cli.plot import bar from dimos.utils.data import get_data @@ -228,7 +228,7 @@ def test_vlm_query_multi(model_class: "type[VlModel]", model_name: str) -> None: @pytest.mark.slow def test_vlm_query_batch(model_class: "type[VlModel]", model_name: str) -> None: """Test query_batch optimization - multiple images, same query.""" - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay # Load 5 frames at 1-second intervals using TimedSensorReplay replay = TimedSensorReplay[Image]("unitree_go2_office_walk2/video") @@ -285,7 +285,7 @@ def test_vlm_resize( sizes: list[tuple[int, int] | None], ) -> None: """Test VLM auto_resize effect on performance.""" - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay replay = TimedSensorReplay[Image]("unitree_go2_office_walk2/video") image = replay.find_closest_seek(0).to_rgb() diff --git a/dimos/msgs/__init__.py b/dimos/msgs/__init__.py deleted file mode 100644 index 4395dbcc51..0000000000 --- a/dimos/msgs/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from dimos.msgs.helpers import resolve_msg_type -from dimos.msgs.protocol import DimosMsg - -__all__ = ["DimosMsg", "resolve_msg_type"] diff --git a/dimos/msgs/foxglove_msgs/__init__.py b/dimos/msgs/foxglove_msgs/__init__.py deleted file mode 100644 index 945ebf94c9..0000000000 --- a/dimos/msgs/foxglove_msgs/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations - -__all__ = ["ImageAnnotations"] diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index 5f50f9b9d1..9b08c8dadd 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -29,7 +29,7 @@ from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.msgs.std_msgs import Header +from dimos.msgs.std_msgs.Header import Header from dimos.types.timestamped import Timestamped diff --git a/dimos/msgs/geometry_msgs/__init__.py b/dimos/msgs/geometry_msgs/__init__.py deleted file mode 100644 index 01069d765c..0000000000 --- a/dimos/msgs/geometry_msgs/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -from dimos.msgs.geometry_msgs.Point import Point -from dimos.msgs.geometry_msgs.PointStamped import PointStamped -from dimos.msgs.geometry_msgs.Pose import Pose, PoseLike, to_pose -from dimos.msgs.geometry_msgs.PoseArray import PoseArray -from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance -from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import PoseWithCovarianceStamped -from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Transform import Transform -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped -from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance -from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import TwistWithCovarianceStamped -from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike -from dimos.msgs.geometry_msgs.Wrench import Wrench -from dimos.msgs.geometry_msgs.WrenchStamped import WrenchStamped - -__all__ = [ - "Point", - "PointStamped", - "Pose", - "PoseArray", - "PoseLike", - "PoseStamped", - "PoseWithCovariance", - "PoseWithCovarianceStamped", - "Quaternion", - "Transform", - "Twist", - "TwistStamped", - "TwistWithCovariance", - "TwistWithCovarianceStamped", - "Vector3", - "VectorLike", - "Wrench", - "WrenchStamped", - "to_pose", -] diff --git a/dimos/msgs/geometry_msgs/test_PoseStamped.py b/dimos/msgs/geometry_msgs/test_PoseStamped.py index 82250a9113..a486f33303 100644 --- a/dimos/msgs/geometry_msgs/test_PoseStamped.py +++ b/dimos/msgs/geometry_msgs/test_PoseStamped.py @@ -15,7 +15,7 @@ import pickle import time -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped def test_lcm_encode_decode() -> None: diff --git a/dimos/msgs/geometry_msgs/test_Transform.py b/dimos/msgs/geometry_msgs/test_Transform.py index 0c15610b05..056238719a 100644 --- a/dimos/msgs/geometry_msgs/test_Transform.py +++ b/dimos/msgs/geometry_msgs/test_Transform.py @@ -18,7 +18,11 @@ import numpy as np import pytest -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 def test_transform_initialization() -> None: diff --git a/dimos/msgs/geometry_msgs/test_Twist.py b/dimos/msgs/geometry_msgs/test_Twist.py index df4bd8b6a2..a4dc93f3cc 100644 --- a/dimos/msgs/geometry_msgs/test_Twist.py +++ b/dimos/msgs/geometry_msgs/test_Twist.py @@ -15,7 +15,9 @@ from dimos_lcm.geometry_msgs import Twist as LCMTwist import numpy as np -from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 def test_twist_initialization() -> None: diff --git a/dimos/msgs/geometry_msgs/test_publish.py b/dimos/msgs/geometry_msgs/test_publish.py index b3d2324af0..01c5cf7842 100644 --- a/dimos/msgs/geometry_msgs/test_publish.py +++ b/dimos/msgs/geometry_msgs/test_publish.py @@ -17,7 +17,7 @@ import lcm import pytest -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 @pytest.mark.tool diff --git a/dimos/msgs/helpers.py b/dimos/msgs/helpers.py index 8464ec4ab1..91466f7fdd 100644 --- a/dimos/msgs/helpers.py +++ b/dimos/msgs/helpers.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from dimos.msgs import DimosMsg + from dimos.msgs.protocol import DimosMsg @lru_cache(maxsize=256) @@ -38,7 +38,10 @@ def resolve_msg_type(type_name: str) -> type[DimosMsg] | None: return None # Try different import paths + # First try the direct submodule path (e.g., dimos.msgs.geometry_msgs.Quaternion) + # then fall back to parent package (for dimos_lcm or other packages) import_paths = [ + f"dimos.msgs.{module_name}.{class_name}", f"dimos.msgs.{module_name}", f"dimos_lcm.{module_name}", ] diff --git a/dimos/msgs/nav_msgs/OccupancyGrid.py b/dimos/msgs/nav_msgs/OccupancyGrid.py index d45e1b6232..4760884620 100644 --- a/dimos/msgs/nav_msgs/OccupancyGrid.py +++ b/dimos/msgs/nav_msgs/OccupancyGrid.py @@ -28,7 +28,8 @@ import numpy as np from PIL import Image -from dimos.msgs.geometry_msgs import Pose, Vector3, VectorLike +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike from dimos.types.timestamped import Timestamped diff --git a/dimos/msgs/nav_msgs/__init__.py b/dimos/msgs/nav_msgs/__init__.py deleted file mode 100644 index 9d099068ad..0000000000 --- a/dimos/msgs/nav_msgs/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from dimos.msgs.nav_msgs.OccupancyGrid import ( # type: ignore[attr-defined] - CostValues, - MapMetaData, - OccupancyGrid, -) -from dimos.msgs.nav_msgs.Odometry import Odometry -from dimos.msgs.nav_msgs.Path import Path - -__all__ = ["CostValues", "MapMetaData", "OccupancyGrid", "Odometry", "Path"] diff --git a/dimos/msgs/nav_msgs/test_OccupancyGrid.py b/dimos/msgs/nav_msgs/test_OccupancyGrid.py index d1ec8938b4..7aae8abfac 100644 --- a/dimos/msgs/nav_msgs/test_OccupancyGrid.py +++ b/dimos/msgs/nav_msgs/test_OccupancyGrid.py @@ -23,9 +23,9 @@ from dimos.mapping.occupancy.gradient import gradient from dimos.mapping.occupancy.inflation import simple_inflate from dimos.mapping.pointclouds.occupancy import general_occupancy -from dimos.msgs.geometry_msgs import Pose -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data diff --git a/dimos/msgs/sensor_msgs/Imu.py b/dimos/msgs/sensor_msgs/Imu.py index 7fe03ce03f..f3461975ff 100644 --- a/dimos/msgs/sensor_msgs/Imu.py +++ b/dimos/msgs/sensor_msgs/Imu.py @@ -18,7 +18,8 @@ from dimos_lcm.sensor_msgs.Imu import Imu as LCMImu -from dimos.msgs.geometry_msgs import Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.types.timestamped import Timestamped diff --git a/dimos/msgs/sensor_msgs/PointCloud2.py b/dimos/msgs/sensor_msgs/PointCloud2.py index 22fe731a70..67af1c5ac3 100644 --- a/dimos/msgs/sensor_msgs/PointCloud2.py +++ b/dimos/msgs/sensor_msgs/PointCloud2.py @@ -28,7 +28,8 @@ import open3d as o3d # type: ignore[import-untyped] import open3d.core as o3c # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Transform, Vector3 +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.types.timestamped import Timestamped if TYPE_CHECKING: diff --git a/dimos/msgs/sensor_msgs/__init__.py b/dimos/msgs/sensor_msgs/__init__.py deleted file mode 100644 index 7fec2d2793..0000000000 --- a/dimos/msgs/sensor_msgs/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo -from dimos.msgs.sensor_msgs.Image import Image, ImageFormat -from dimos.msgs.sensor_msgs.Imu import Imu -from dimos.msgs.sensor_msgs.JointCommand import JointCommand -from dimos.msgs.sensor_msgs.JointState import JointState -from dimos.msgs.sensor_msgs.Joy import Joy -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.msgs.sensor_msgs.RobotState import RobotState - -__all__ = [ - "CameraInfo", - "Image", - "ImageFormat", - "Imu", - "JointCommand", - "JointState", - "Joy", - "PointCloud2", - "RobotState", -] diff --git a/dimos/msgs/sensor_msgs/test_PointCloud2.py b/dimos/msgs/sensor_msgs/test_PointCloud2.py index f48802ab7a..70e6e35aec 100644 --- a/dimos/msgs/sensor_msgs/test_PointCloud2.py +++ b/dimos/msgs/sensor_msgs/test_PointCloud2.py @@ -16,9 +16,9 @@ import numpy as np -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar -from dimos.utils.testing import SensorReplay +from dimos.utils.testing.replay import SensorReplay def test_lcm_encode_decode() -> None: diff --git a/dimos/msgs/sensor_msgs/test_image.py b/dimos/msgs/sensor_msgs/test_image.py index 24375139b3..cc2fc9f096 100644 --- a/dimos/msgs/sensor_msgs/test_image.py +++ b/dimos/msgs/sensor_msgs/test_image.py @@ -18,7 +18,7 @@ from dimos.msgs.sensor_msgs.Image import Image, ImageFormat, sharpness_barrier from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay @pytest.fixture diff --git a/dimos/msgs/std_msgs/__init__.py b/dimos/msgs/std_msgs/__init__.py deleted file mode 100644 index ae8e3dd8f6..0000000000 --- a/dimos/msgs/std_msgs/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .Bool import Bool -from .Header import Header -from .Int8 import Int8 -from .Int32 import Int32 -from .UInt32 import UInt32 - -__all__ = ["Bool", "Header", "Int8", "Int32", "UInt32"] diff --git a/dimos/msgs/std_msgs/test_header.py b/dimos/msgs/std_msgs/test_header.py index 93f20da283..29f4ee2c0e 100644 --- a/dimos/msgs/std_msgs/test_header.py +++ b/dimos/msgs/std_msgs/test_header.py @@ -15,7 +15,7 @@ from datetime import datetime import time -from dimos.msgs.std_msgs import Header +from dimos.msgs.std_msgs.Header import Header def test_header_initialization_methods() -> None: diff --git a/dimos/msgs/tf2_msgs/__init__.py b/dimos/msgs/tf2_msgs/__init__.py deleted file mode 100644 index 69d4e0137e..0000000000 --- a/dimos/msgs/tf2_msgs/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.msgs.tf2_msgs.TFMessage import TFMessage - -__all__ = ["TFMessage"] diff --git a/dimos/msgs/tf2_msgs/test_TFMessage.py b/dimos/msgs/tf2_msgs/test_TFMessage.py index 8567de9988..c379481f1d 100644 --- a/dimos/msgs/tf2_msgs/test_TFMessage.py +++ b/dimos/msgs/tf2_msgs/test_TFMessage.py @@ -14,8 +14,10 @@ from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -from dimos.msgs.tf2_msgs import TFMessage +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.tf2_msgs.TFMessage import TFMessage def test_tfmessage_initialization() -> None: diff --git a/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py b/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py index 8b58a61a44..2a03b7ee71 100644 --- a/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py +++ b/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py @@ -16,8 +16,10 @@ import pytest -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -from dimos.msgs.tf2_msgs import TFMessage +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic diff --git a/dimos/msgs/trajectory_msgs/__init__.py b/dimos/msgs/trajectory_msgs/__init__.py deleted file mode 100644 index 44039e594e..0000000000 --- a/dimos/msgs/trajectory_msgs/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Trajectory message types. - -Similar to ROS trajectory_msgs package. -""" - -from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory -from dimos.msgs.trajectory_msgs.TrajectoryPoint import TrajectoryPoint -from dimos.msgs.trajectory_msgs.TrajectoryStatus import TrajectoryState, TrajectoryStatus - -__all__ = [ - "JointTrajectory", - "TrajectoryPoint", - "TrajectoryState", - "TrajectoryStatus", -] diff --git a/dimos/msgs/vision_msgs/__init__.py b/dimos/msgs/vision_msgs/__init__.py deleted file mode 100644 index 0f1c9c8dc1..0000000000 --- a/dimos/msgs/vision_msgs/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .BoundingBox2DArray import BoundingBox2DArray -from .BoundingBox3DArray import BoundingBox3DArray -from .Detection2D import Detection2D -from .Detection2DArray import Detection2DArray -from .Detection3D import Detection3D -from .Detection3DArray import Detection3DArray - -__all__ = [ - "BoundingBox2DArray", - "BoundingBox3DArray", - "Detection2D", - "Detection2DArray", - "Detection3D", - "Detection3DArray", -] diff --git a/dimos/navigation/base.py b/dimos/navigation/base.py index 347c4ad124..1530308711 100644 --- a/dimos/navigation/base.py +++ b/dimos/navigation/base.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from enum import Enum -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped class NavigationState(Enum): diff --git a/dimos/navigation/bbox_navigation.py b/dimos/navigation/bbox_navigation.py index e0752dfd00..c96ba9efad 100644 --- a/dimos/navigation/bbox_navigation.py +++ b/dimos/navigation/bbox_navigation.py @@ -18,26 +18,30 @@ from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.utils.logging_config import setup_logger logger = setup_logger(level=logging.DEBUG) -class BBoxNavigationModule(Module): +class Config(ModuleConfig): + goal_distance: float = 1.0 + + +class BBoxNavigationModule(Module[Config]): """Minimal module that converts 2D bbox center to navigation goals.""" + default_config = Config + detection2d: In[Detection2DArray] camera_info: In[CameraInfo] goal_request: Out[PoseStamped] - - def __init__(self, goal_distance: float = 1.0) -> None: - super().__init__() - self.goal_distance = goal_distance - self.camera_intrinsics = None + camera_intrinsics = None @rpc def start(self) -> None: @@ -62,9 +66,9 @@ def _on_detection(self, det: Detection2DArray) -> None: det.detections[0].bbox.center.position.y, ) x, y, z = ( - (center_x - cx) / fx * self.goal_distance, - (center_y - cy) / fy * self.goal_distance, - self.goal_distance, + (center_x - cx) / fx * self.config.goal_distance, + (center_y - cy) / fy * self.config.goal_distance, + self.config.goal_distance, ) goal = PoseStamped( position=Vector3(z, -x, -y), diff --git a/dimos/navigation/demo_ros_navigation.py b/dimos/navigation/demo_ros_navigation.py index 4d57867d59..0efa04cd44 100644 --- a/dimos/navigation/demo_ros_navigation.py +++ b/dimos/navigation/demo_ros_navigation.py @@ -15,7 +15,9 @@ import time from dimos.core.module_coordinator import ModuleCoordinator -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.navigation import rosnav from dimos.protocol.service.lcmservice import autoconf from dimos.utils.logging_config import setup_logger diff --git a/dimos/navigation/frontier_exploration/__init__.py b/dimos/navigation/frontier_exploration/__init__.py deleted file mode 100644 index 24ce957ccf..0000000000 --- a/dimos/navigation/frontier_exploration/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .wavefront_frontier_goal_selector import WavefrontFrontierExplorer, wavefront_frontier_explorer - -__all__ = ["WavefrontFrontierExplorer", "wavefront_frontier_explorer"] diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index 1c8082b414..834897d396 100644 --- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -17,8 +17,8 @@ import numpy as np import pytest -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) @@ -56,7 +56,7 @@ def quick_costmap(): # One obstacle grid[9:10, 9:10] = CostValues.OCCUPIED - from dimos.msgs.geometry_msgs import Pose + from dimos.msgs.geometry_msgs.Pose import Pose origin = Pose() origin.position.x = -1.0 @@ -97,7 +97,7 @@ def create_test_costmap(width: int = 40, height: int = 40, resolution: float = 0 grid[13:14, 18:22] = CostValues.OCCUPIED # Top corridor obstacle # Create origin at bottom-left, adjusted for map size - from dimos.msgs.geometry_msgs import Pose + from dimos.msgs.geometry_msgs.Pose import Pose origin = Pose() # Center the map around (0, 0) in world coordinates @@ -262,7 +262,7 @@ def test_frontier_ranking(explorer) -> None: # Note: Goals might be closer than safe_distance if that's the best available frontier # The safe_distance is used for scoring, not as a hard constraint print( - f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.safe_distance}m)" + f"Distance to obstacles: {obstacle_dist:.2f}m (safe distance: {explorer.config.safe_distance}m)" ) print(f"Frontier ranking test passed - selected goal at ({goal1.x:.2f}, {goal1.y:.2f})") diff --git a/dimos/navigation/frontier_exploration/utils.py b/dimos/navigation/frontier_exploration/utils.py index 28644cdd41..d5ed7df61c 100644 --- a/dimos/navigation/frontier_exploration/utils.py +++ b/dimos/navigation/frontier_exploration/utils.py @@ -19,8 +19,8 @@ import numpy as np from PIL import Image, ImageDraw -from dimos.msgs.geometry_msgs import Vector3 -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid def costmap_to_pil_image(costmap: OccupancyGrid, scale_factor: int = 2) -> Image.Image: diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 6e598e8316..20fab41b35 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -23,6 +23,7 @@ from dataclasses import dataclass from enum import IntFlag import threading +from typing import Any from dimos_lcm.std_msgs import Bool import numpy as np @@ -30,11 +31,12 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.mapping.occupancy.inflation import simple_inflate -from dimos.msgs.geometry_msgs import PoseStamped, Vector3 -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import get_distance @@ -78,7 +80,18 @@ def clear(self) -> None: self.points.clear() -class WavefrontFrontierExplorer(Module): +class WavefrontConfig(ModuleConfig): + min_frontier_perimeter: float = 0.5 + occupancy_threshold: int = 99 + safe_distance: float = 3.0 + lookahead_distance: float = 5.0 + max_explored_distance: float = 10.0 + info_gain_threshold: float = 0.03 + num_no_gain_attempts: int = 2 + goal_timeout: float = 15.0 + + +class WavefrontFrontierExplorer(Module[WavefrontConfig]): """ Wavefront frontier exploration algorithm implementation. @@ -93,6 +106,8 @@ class WavefrontFrontierExplorer(Module): - goal_request: Exploration goals sent to the navigator """ + default_config = WavefrontConfig + # LCM inputs global_costmap: In[OccupancyGrid] odom: In[PoseStamped] @@ -103,17 +118,7 @@ class WavefrontFrontierExplorer(Module): # LCM outputs goal_request: Out[PoseStamped] - def __init__( - self, - min_frontier_perimeter: float = 0.5, - occupancy_threshold: int = 99, - safe_distance: float = 3.0, - lookahead_distance: float = 5.0, - max_explored_distance: float = 10.0, - info_gain_threshold: float = 0.03, - num_no_gain_attempts: int = 2, - goal_timeout: float = 15.0, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize the frontier explorer. @@ -124,20 +129,12 @@ def __init__( info_gain_threshold: Minimum percentage increase in costmap information required to continue exploration (0.05 = 5%) num_no_gain_attempts: Maximum number of consecutive attempts with no information gain """ - super().__init__() - self.min_frontier_perimeter = min_frontier_perimeter - self.occupancy_threshold = occupancy_threshold - self.safe_distance = safe_distance - self.max_explored_distance = max_explored_distance - self.lookahead_distance = lookahead_distance - self.info_gain_threshold = info_gain_threshold - self.num_no_gain_attempts = num_no_gain_attempts + super().__init__(**kwargs) self._cache = FrontierCache() self.explored_goals = [] # type: ignore[var-annotated] # list of explored goals self.exploration_direction = Vector3(0.0, 0.0, 0.0) # current exploration direction self.last_costmap = None # store last costmap for information comparison self.no_gain_counter = 0 # track consecutive no-gain attempts - self.goal_timeout = goal_timeout # Latest data self.latest_costmap: OccupancyGrid | None = None @@ -214,7 +211,7 @@ def _count_costmap_information(self, costmap: OccupancyGrid) -> int: Number of cells that are free space or obstacles (not unknown) """ free_count = np.sum(costmap.grid == CostValues.FREE) - obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold) + obstacle_count = np.sum(costmap.grid >= self.config.occupancy_threshold) return int(free_count + obstacle_count) def _get_neighbors(self, point: GridPoint, costmap: OccupancyGrid) -> list[GridPoint]: @@ -252,7 +249,7 @@ def _is_frontier_point(self, point: GridPoint, costmap: OccupancyGrid) -> bool: neighbor_cost = costmap.grid[neighbor.y, neighbor.x] # If adjacent to occupied space, not a frontier - if neighbor_cost > self.occupancy_threshold: + if neighbor_cost > self.config.occupancy_threshold: return False # Check if adjacent to free space @@ -376,7 +373,7 @@ def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> list[ # Check if we found a large enough frontier # Convert minimum perimeter to minimum number of cells based on resolution - min_cells = int(self.min_frontier_perimeter / costmap.resolution) + min_cells = int(self.config.min_frontier_perimeter / costmap.resolution) if len(new_frontier) >= min_cells: world_points = [] for point in new_frontier: @@ -489,7 +486,7 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr min_distance = float("inf") search_radius = ( - int(self.safe_distance / costmap.resolution) + 5 + int(self.config.safe_distance / costmap.resolution) + 5 ) # Search a bit beyond minimum # Search in a square around the frontier point @@ -508,14 +505,14 @@ def _compute_distance_to_obstacles(self, frontier: Vector3, costmap: OccupancyGr continue # Check if this cell is an obstacle - if costmap.grid[check_y, check_x] >= self.occupancy_threshold: + if costmap.grid[check_y, check_x] >= self.config.occupancy_threshold: # Calculate distance in meters distance = np.sqrt(dx**2 + dy**2) * costmap.resolution min_distance = min(min_distance, distance) # If no obstacles found within search radius, return the safe distance # This indicates the frontier is safely away from obstacles - return min_distance if min_distance != float("inf") else self.safe_distance + return min_distance if min_distance != float("inf") else self.config.safe_distance def _compute_comprehensive_frontier_score( self, frontier: Vector3, frontier_size: int, robot_pose: Vector3, costmap: OccupancyGrid @@ -527,25 +524,25 @@ def _compute_comprehensive_frontier_score( # Distance score: prefer moderate distances (not too close, not too far) # Normalized to 0-1 range - distance_score = 1.0 / (1.0 + abs(robot_distance - self.lookahead_distance)) + distance_score = 1.0 / (1.0 + abs(robot_distance - self.config.lookahead_distance)) # 2. Information gain (frontier size) # Normalize by a reasonable max frontier size - max_expected_frontier_size = self.min_frontier_perimeter / costmap.resolution * 10 + max_expected_frontier_size = self.config.min_frontier_perimeter / costmap.resolution * 10 info_gain_score = min(frontier_size / max_expected_frontier_size, 1.0) # 3. Distance to explored goals (bonus for being far from explored areas) # Normalize by a reasonable max distance (e.g., 10 meters) explored_goals_distance = self._compute_distance_to_explored_goals(frontier) - explored_goals_score = min(explored_goals_distance / self.max_explored_distance, 1.0) + explored_goals_score = min(explored_goals_distance / self.config.max_explored_distance, 1.0) # 4. Distance to obstacles (score based on safety) # 0 = too close to obstacles, 1 = at or beyond safe distance obstacles_distance = self._compute_distance_to_obstacles(frontier, costmap) - if obstacles_distance >= self.safe_distance: + if obstacles_distance >= self.config.safe_distance: obstacles_score = 1.0 # Fully safe else: - obstacles_score = obstacles_distance / self.safe_distance # Linear penalty + obstacles_score = obstacles_distance / self.config.safe_distance # Linear penalty # 5. Direction momentum (already in 0-1 range from dot product) momentum_score = self._compute_direction_momentum_score(frontier, robot_pose) @@ -628,15 +625,15 @@ def get_exploration_goal(self, robot_pose: Vector3, costmap: OccupancyGrid) -> V # Check if information increase meets minimum percentage threshold if last_info > 0: # Avoid division by zero info_increase_percent = (current_info - last_info) / last_info - if info_increase_percent < self.info_gain_threshold: + if info_increase_percent < self.config.info_gain_threshold: logger.info( - f"Information increase ({info_increase_percent:.2f}) below threshold ({self.info_gain_threshold:.2f})" + f"Information increase ({info_increase_percent:.2f}) below threshold ({self.config.info_gain_threshold:.2f})" ) logger.info( f"Current information: {current_info}, Last information: {last_info}" ) self.no_gain_counter += 1 - if self.no_gain_counter >= self.num_no_gain_attempts: + if self.no_gain_counter >= self.config.num_no_gain_attempts: logger.info( f"No information gain for {self.no_gain_counter} consecutive attempts" ) @@ -797,7 +794,7 @@ def _exploration_loop(self) -> None: # Wait for goal to be reached or timeout logger.info("Waiting for goal to be reached...") - goal_reached = self.goal_reached_event.wait(timeout=self.goal_timeout) + goal_reached = self.goal_reached_event.wait(timeout=self.config.goal_timeout) if goal_reached: logger.info("Goal reached, finding next frontier") diff --git a/dimos/navigation/replanning_a_star/controllers.py b/dimos/navigation/replanning_a_star/controllers.py index 865aafb8be..07ba8c7119 100644 --- a/dimos/navigation/replanning_a_star/controllers.py +++ b/dimos/navigation/replanning_a_star/controllers.py @@ -19,8 +19,9 @@ from numpy.typing import NDArray from dimos.core.global_config import GlobalConfig -from dimos.msgs.geometry_msgs import Twist, Vector3 from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.utils.trigonometry import angle_diff diff --git a/dimos/navigation/replanning_a_star/global_planner.py b/dimos/navigation/replanning_a_star/global_planner.py index df2680a4a7..4c4e79cb7b 100644 --- a/dimos/navigation/replanning_a_star/global_planner.py +++ b/dimos/navigation/replanning_a_star/global_planner.py @@ -23,8 +23,8 @@ from dimos.core.global_config import GlobalConfig from dimos.core.resource import Resource from dimos.mapping.occupancy.path_resampling import smooth_resample_path -from dimos.msgs.geometry_msgs import Twist from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid from dimos.msgs.nav_msgs.Path import Path diff --git a/dimos/navigation/replanning_a_star/goal_validator.py b/dimos/navigation/replanning_a_star/goal_validator.py index 5cd093e955..b717c76295 100644 --- a/dimos/navigation/replanning_a_star/goal_validator.py +++ b/dimos/navigation/replanning_a_star/goal_validator.py @@ -16,8 +16,8 @@ import numpy as np -from dimos.msgs.geometry_msgs import Vector3, VectorLike -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid def find_safe_goal( diff --git a/dimos/navigation/replanning_a_star/local_planner.py b/dimos/navigation/replanning_a_star/local_planner.py index a5f8d9e457..d50d0def84 100644 --- a/dimos/navigation/replanning_a_star/local_planner.py +++ b/dimos/navigation/replanning_a_star/local_planner.py @@ -23,9 +23,10 @@ from dimos.core.global_config import GlobalConfig from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Twist from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.navigation.base import NavigationState from dimos.navigation.replanning_a_star.controllers import Controller, PController from dimos.navigation.replanning_a_star.navigation_map import NavigationMap diff --git a/dimos/navigation/replanning_a_star/min_cost_astar.py b/dimos/navigation/replanning_a_star/min_cost_astar.py index c3430e64d9..55f502680c 100644 --- a/dimos/navigation/replanning_a_star/min_cost_astar.py +++ b/dimos/navigation/replanning_a_star/min_cost_astar.py @@ -14,8 +14,11 @@ import heapq -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, VectorLike -from dimos.msgs.nav_msgs import CostValues, OccupancyGrid, Path +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import VectorLike +from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.utils.logging_config import setup_logger # Try to import C++ extension for faster pathfinding diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index 4dad9a2843..796390f06c 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -13,16 +13,19 @@ # limitations under the License. import os +from typing import Any from dimos_lcm.std_msgs import Bool, String from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PointStamped, PoseStamped, Twist -from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.navigation.base import NavigationInterface, NavigationState from dimos.navigation.replanning_a_star.global_planner import GlobalPlanner @@ -41,12 +44,10 @@ class ReplanningAStarPlanner(Module, NavigationInterface): navigation_costmap: Out[OccupancyGrid] _planner: GlobalPlanner - _global_config: GlobalConfig - def __init__(self, cfg: GlobalConfig = global_config) -> None: - super().__init__() - self._global_config = cfg - self._planner = GlobalPlanner(self._global_config) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._planner = GlobalPlanner(self.config.g) @rpc def start(self) -> None: diff --git a/dimos/navigation/replanning_a_star/path_clearance.py b/dimos/navigation/replanning_a_star/path_clearance.py index e99fba26c3..7dc08d49e0 100644 --- a/dimos/navigation/replanning_a_star/path_clearance.py +++ b/dimos/navigation/replanning_a_star/path_clearance.py @@ -19,8 +19,8 @@ from dimos.core.global_config import GlobalConfig from dimos.mapping.occupancy.path_mask import make_path_mask -from dimos.msgs.nav_msgs import Path from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path class PathClearance: diff --git a/dimos/navigation/replanning_a_star/path_distancer.py b/dimos/navigation/replanning_a_star/path_distancer.py index 04d844267f..c50583ca33 100644 --- a/dimos/navigation/replanning_a_star/path_distancer.py +++ b/dimos/navigation/replanning_a_star/path_distancer.py @@ -17,7 +17,7 @@ import numpy as np from numpy.typing import NDArray -from dimos.msgs.nav_msgs import Path +from dimos.msgs.nav_msgs.Path import Path class PathDistancer: diff --git a/dimos/navigation/replanning_a_star/test_goal_validator.py b/dimos/navigation/replanning_a_star/test_goal_validator.py index 4cda9de863..69c7147696 100644 --- a/dimos/navigation/replanning_a_star/test_goal_validator.py +++ b/dimos/navigation/replanning_a_star/test_goal_validator.py @@ -15,7 +15,7 @@ import numpy as np import pytest -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, OccupancyGrid from dimos.navigation.replanning_a_star.goal_validator import find_safe_goal from dimos.utils.data import get_data diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 89b299ae5b..38c8e32847 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -18,53 +18,60 @@ Encapsulates ROS transport and topic remapping for Unitree robots. """ -from dataclasses import dataclass, field import logging import threading import time +from typing import Any +from pydantic import Field from reactivex import operators as ops from reactivex.subject import Subject -from dimos import spec from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, ROSTransport -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Quaternion, - Transform, - Twist, - TwistStamped, - Vector3, -) -from dimos.msgs.nav_msgs import Path -from dimos.msgs.sensor_msgs import Joy, PointCloud2 -from dimos.msgs.std_msgs import Bool, Int8 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Path import Path +from dimos.msgs.sensor_msgs.Joy import Joy +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Bool import Bool +from dimos.msgs.std_msgs.Int8 import Int8 from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.navigation.base import NavigationInterface, NavigationState +from dimos.spec.control import LocalPlanner +from dimos.spec.mapping import GlobalPointcloud +from dimos.spec.nav import Nav +from dimos.spec.perception import Pointcloud from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion logger = setup_logger(level=logging.INFO) -@dataclass class Config(ModuleConfig): local_pointcloud_freq: float = 2.0 global_map_freq: float = 1.0 - sensor_to_base_link_transform: Transform = field( + sensor_to_base_link_transform: Transform = Field( default_factory=lambda: Transform(frame_id="sensor", child_frame_id="base_link") ) class ROSNav( - Module, NavigationInterface, spec.Nav, spec.GlobalPointcloud, spec.Pointcloud, spec.LocalPlanner + Module[Config], + NavigationInterface, + Nav, + GlobalPointcloud, + Pointcloud, + LocalPlanner, ): - config: Config default_config = Config # Existing ports (default LCM/pSHM transport) @@ -106,8 +113,8 @@ class ROSNav( _current_goal: PoseStamped | None = None _goal_reached: bool = False - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # Initialize RxPY Subjects for streaming data self._local_pointcloud_subject = Subject() diff --git a/dimos/navigation/visual/query.py b/dimos/navigation/visual/query.py index 37b743506a..0693ca5dd1 100644 --- a/dimos/navigation/visual/query.py +++ b/dimos/navigation/visual/query.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from dimos.models.qwen.bbox import BBox from dimos.models.vl.base import VlModel -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.generic import extract_json_from_llm_response def get_object_bbox_from_image( - vl_model: VlModel, image: Image, object_description: str + vl_model: VlModel[Any], image: Image, object_description: str ) -> BBox | None: prompt = ( f"Look at this image and find the '{object_description}'. " diff --git a/dimos/navigation/visual_servoing/detection_navigation.py b/dimos/navigation/visual_servoing/detection_navigation.py index 5f89bd1faa..351883e8ac 100644 --- a/dimos/navigation/visual_servoing/detection_navigation.py +++ b/dimos/navigation/visual_servoing/detection_navigation.py @@ -15,11 +15,15 @@ from dimos_lcm.sensor_msgs import CameraInfo as DimosLcmCameraInfo import numpy as np -from dimos.msgs.geometry_msgs import Transform, Twist, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox -from dimos.perception.detection.type.detection3d import Detection3DPC -from dimos.protocol.tf import LCMTF +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.protocol.tf.tf import LCMTF from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/navigation/visual_servoing/visual_servoing_2d.py b/dimos/navigation/visual_servoing/visual_servoing_2d.py index 032b5f3370..f424b21466 100644 --- a/dimos/navigation/visual_servoing/visual_servoing_2d.py +++ b/dimos/navigation/visual_servoing/visual_servoing_2d.py @@ -14,8 +14,9 @@ import numpy as np -from dimos.msgs.geometry_msgs import Twist, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo class VisualServoing2D: diff --git a/dimos/perception/__init__.py b/dimos/perception/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/perception/common/__init__.py b/dimos/perception/common/__init__.py deleted file mode 100644 index 5902f54bb8..0000000000 --- a/dimos/perception/common/__init__.py +++ /dev/null @@ -1,81 +0,0 @@ -from .utils import ( - BoundingBox2D, - CameraInfo, - Detection2D, - Detection3D, - Header, - Image, - ObjectData, - Pose, - Quaternion, - Union, - Vector, - Vector3, - bbox2d_to_corners, - colorize_depth, - combine_object_data, - cp, - cv2, - detection_results_to_object_data, - draw_bounding_box, - draw_object_detection_visualization, - draw_segmentation_mask, - extract_pose_from_detection3d, - find_clicked_detection, - load_camera_info, - load_camera_info_opencv, - logger, - np, - point_in_bbox, - project_2d_points_to_3d, - project_2d_points_to_3d_cpu, - project_2d_points_to_3d_cuda, - project_3d_points_to_2d, - project_3d_points_to_2d_cpu, - project_3d_points_to_2d_cuda, - rectify_image, - setup_logger, - torch, - yaml, -) - -__all__ = [ - "BoundingBox2D", - "CameraInfo", - "Detection2D", - "Detection3D", - "Header", - "Image", - "ObjectData", - "Pose", - "Quaternion", - "Union", - "Vector", - "Vector3", - "bbox2d_to_corners", - "colorize_depth", - "combine_object_data", - "cp", - "cv2", - "detection_results_to_object_data", - "draw_bounding_box", - "draw_object_detection_visualization", - "draw_segmentation_mask", - "extract_pose_from_detection3d", - "find_clicked_detection", - "load_camera_info", - "load_camera_info_opencv", - "logger", - "np", - "point_in_bbox", - "project_2d_points_to_3d", - "project_2d_points_to_3d_cpu", - "project_2d_points_to_3d_cuda", - "project_3d_points_to_2d", - "project_3d_points_to_2d_cpu", - "project_3d_points_to_2d_cuda", - "rectify_image", - "setup_logger", - "torch", - "yaml", -] diff --git a/dimos/perception/common/utils.py b/dimos/perception/common/utils.py index c5f550ade3..1670d31998 100644 --- a/dimos/perception/common/utils.py +++ b/dimos/perception/common/utils.py @@ -25,9 +25,11 @@ import torch import yaml # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.std_msgs import Header +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.std_msgs.Header import Header from dimos.types.manipulation import ObjectData from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py index ad98d0474a..cdb09d359e 100644 --- a/dimos/perception/demo_object_scene_registration.py +++ b/dimos/perception/demo_object_scene_registration.py @@ -15,8 +15,8 @@ from dimos.agents.agent import agent from dimos.core.blueprints import autoconnect -from dimos.hardware.sensors.camera.realsense import realsense_camera -from dimos.hardware.sensors.camera.zed import zed_camera +from dimos.hardware.sensors.camera.realsense.camera import realsense_camera +from dimos.hardware.sensors.camera.zed.compat import zed_camera from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import object_scene_registration_module from dimos.robot.foxglove_bridge import foxglove_bridge diff --git a/dimos/perception/detection/__init__.py b/dimos/perception/detection/__init__.py deleted file mode 100644 index ae9f8cb14d..0000000000 --- a/dimos/perception/detection/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "detectors": ["Detector", "Yolo2DDetector"], - "module2D": ["Detection2DModule"], - "module3D": ["Detection3DModule"], - }, -) diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index e81ab2ab4a..5f8f1bc4b9 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -15,6 +15,7 @@ from collections.abc import Callable, Generator import functools from typing import TypedDict +from unittest import mock from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate @@ -22,23 +23,23 @@ import pytest from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import Transform -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.module2D import Detection2DModule from dimos.perception.detection.module3D import Detection3DModule from dimos.perception.detection.moduleDB import ObjectDBModule -from dimos.perception.detection.type import ( - Detection2D, - Detection3DPC, - ImageDetections2D, - ImageDetections3DPC, -) -from dimos.protocol.tf import TF +from dimos.perception.detection.type.detection2d.base import Detection2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.protocol.tf.tf import TF from dimos.robot.unitree.go2 import connection from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay class Moment(TypedDict, total=False): @@ -202,9 +203,10 @@ def detection3dpc(detections3dpc) -> Detection3DPC: @pytest.fixture(scope="session") def get_moment_2d(get_moment) -> Generator[Callable[[], Moment2D], None, None]: - from dimos.perception.detection.detectors import Yolo2DDetector + from dimos.perception.detection.detectors.yolo import Yolo2DDetector - module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) + c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) + module = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) @functools.lru_cache(maxsize=1) def moment_provider(**kwargs) -> Moment2D: @@ -260,9 +262,10 @@ def moment_provider(**kwargs) -> Moment3D: @pytest.fixture(scope="session") def object_db_module(get_moment): """Create and populate an ObjectDBModule with detections from multiple frames.""" - from dimos.perception.detection.detectors import Yolo2DDetector + from dimos.perception.detection.detectors.yolo import Yolo2DDetector - module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) + c = mock.create_autospec(CameraInfo, spec_set=True, instance=True) + module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu"), camera_info=c) module3d = Detection3DModule(camera_info=connection._camera_info_static()) moduleDB = ObjectDBModule(camera_info=connection._camera_info_static()) diff --git a/dimos/perception/detection/detectors/__init__.py b/dimos/perception/detection/detectors/__init__.py deleted file mode 100644 index 2f151fe3ef..0000000000 --- a/dimos/perception/detection/detectors/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# from dimos.perception.detection.detectors.detic import Detic2DDetector -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.detectors.yolo import Yolo2DDetector - -__all__ = [ - "Detector", - "Yolo2DDetector", -] diff --git a/dimos/perception/detection/detectors/types.py b/dimos/perception/detection/detectors/base.py similarity index 84% rename from dimos/perception/detection/detectors/types.py rename to dimos/perception/detection/detectors/base.py index e85c5ae18e..40aa82e5bd 100644 --- a/dimos/perception/detection/detectors/types.py +++ b/dimos/perception/detection/detectors/base.py @@ -14,8 +14,8 @@ from abc import ABC, abstractmethod -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D class Detector(ABC): diff --git a/dimos/perception/detection/detectors/conftest.py b/dimos/perception/detection/detectors/conftest.py index 6a2c041a8b..bb9a47e0eb 100644 --- a/dimos/perception/detection/detectors/conftest.py +++ b/dimos/perception/detection/detectors/conftest.py @@ -14,7 +14,7 @@ import pytest -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection.detectors.yolo import Yolo2DDetector from dimos.perception.detection.detectors.yoloe import Yoloe2DDetector, YoloePromptMode diff --git a/dimos/perception/detection/detectors/person/test_person_detectors.py b/dimos/perception/detection/detectors/person/test_person_detectors.py index 2ed7cdc7dc..6130e5888a 100644 --- a/dimos/perception/detection/detectors/person/test_person_detectors.py +++ b/dimos/perception/detection/detectors/person/test_person_detectors.py @@ -14,7 +14,8 @@ import pytest -from dimos.perception.detection.type import Detection2DPerson, ImageDetections2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.perception.detection.type.detection2d.person import Detection2DPerson @pytest.fixture(scope="session") diff --git a/dimos/perception/detection/detectors/person/yolo.py b/dimos/perception/detection/detectors/person/yolo.py index 519f45f2f6..26d68a4510 100644 --- a/dimos/perception/detection/detectors/person/yolo.py +++ b/dimos/perception/detection/detectors/person/yolo.py @@ -14,9 +14,9 @@ from ultralytics import YOLO # type: ignore[attr-defined, import-not-found] -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.detectors.base import Detector +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.data import get_data from dimos.utils.gpu_utils import is_cuda_available from dimos.utils.logging_config import setup_logger diff --git a/dimos/perception/detection/detectors/test_bbox_detectors.py b/dimos/perception/detection/detectors/test_bbox_detectors.py index 2e69016eb5..c8112e9aab 100644 --- a/dimos/perception/detection/detectors/test_bbox_detectors.py +++ b/dimos/perception/detection/detectors/test_bbox_detectors.py @@ -17,8 +17,9 @@ from reactivex.disposable import CompositeDisposable from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2D, ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.type.detection2d.base import Detection2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D @pytest.fixture(params=["bbox_detector", "person_detector", "yoloe_detector"], scope="session") diff --git a/dimos/perception/detection/detectors/yolo.py b/dimos/perception/detection/detectors/yolo.py index c9a65a120e..64565cce7a 100644 --- a/dimos/perception/detection/detectors/yolo.py +++ b/dimos/perception/detection/detectors/yolo.py @@ -14,9 +14,9 @@ from ultralytics import YOLO # type: ignore[attr-defined, import-not-found] -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.detectors.base import Detector +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.data import get_data from dimos.utils.gpu_utils import is_cuda_available from dimos.utils.logging_config import setup_logger diff --git a/dimos/perception/detection/detectors/yoloe.py b/dimos/perception/detection/detectors/yoloe.py index 9c9881209c..536dd9f497 100644 --- a/dimos/perception/detection/detectors/yoloe.py +++ b/dimos/perception/detection/detectors/yoloe.py @@ -20,9 +20,9 @@ from numpy.typing import NDArray from ultralytics import YOLOE # type: ignore[attr-defined, import-not-found] -from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.types import Detector -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.detection.detectors.base import Detector +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.utils.data import get_data from dimos.utils.gpu_utils import is_cuda_available diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index f86794a1f7..b6d0c9358c 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -11,51 +11,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any +from collections.abc import Callable, Sequence +from typing import Annotated, Any from dimos_lcm.foxglove_msgs.ImageAnnotations import ( ImageAnnotations, ) +from pydantic.experimental.pipeline import validate_as from reactivex import operators as ops from reactivex.observable import Observable from reactivex.subject import Subject -from dimos import spec from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image -from dimos.msgs.sensor_msgs.Image import sharpness_barrier -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection.detectors import Detector # type: ignore[attr-defined] +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.perception.detection.detectors.base import Detector from dimos.perception.detection.detectors.yolo import Yolo2DDetector -from dimos.perception.detection.type import Filter2D, ImageDetections2D +from dimos.perception.detection.type.detection2d.base import Filter2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D +from dimos.spec.perception import Camera from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.reactive import backpressure -@dataclass class Config(ModuleConfig): max_freq: float = 10 detector: Callable[[Any], Detector] | None = Yolo2DDetector publish_detection_images: bool = True - camera_info: CameraInfo = None # type: ignore[assignment] - filter: list[Filter2D] | Filter2D | None = None + camera_info: CameraInfo + filter: Annotated[ + Sequence[Filter2D], + validate_as(Sequence[Filter2D] | Filter2D).transform( + lambda f: f if isinstance(f, Sequence) else (f,) + ), + ] = () - def __post_init__(self) -> None: - if self.filter is None: - self.filter = [] - elif not isinstance(self.filter, list): - self.filter = [self.filter] - -class Detection2DModule(Module): +class Detection2DModule(Module[Config]): default_config = Config - config: Config detector: Detector color_image: In[Image] @@ -161,7 +160,7 @@ def stop(self) -> None: def deploy( # type: ignore[no-untyped-def] dimos: ModuleCoordinator, - camera: spec.Camera, + camera: Camera, prefix: str = "/detector2d", **kwargs, ) -> Detection2DModule: diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index 96ae4e8297..fa392dc799 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -22,19 +22,23 @@ from reactivex import operators as ops from reactivex.observable import Observable -from dimos import spec from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.module2D import Detection2DModule from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D -from dimos.perception.detection.type.detection3d import Detection3DPC from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC +from dimos.spec.perception import Camera, Pointcloud from dimos.types.timestamped import align_timestamped from dimos.utils.reactive import backpressure @@ -177,7 +181,7 @@ def detection2d_to_3d(args): # type: ignore[no-untyped-def] transform = self.tf.get("camera_optical", pc.frame_id, detections.image.ts, 5.0) return self.process_frame(detections, pc, transform) - self.detection_stream_3d = align_timestamped( + self.detection_stream_3d = align_timestamped( # type: ignore[type-var] backpressure(self.detection_stream_2d()), self.pointcloud.observable(), # type: ignore[no-untyped-call] match_tolerance=0.25, @@ -203,8 +207,8 @@ def _publish_detections(self, detections: ImageDetections3DPC) -> None: def deploy( # type: ignore[no-untyped-def] dimos: ModuleCoordinator, - lidar: spec.Pointcloud, - camera: spec.Camera, + lidar: Pointcloud, + camera: Camera, prefix: str = "/detector3d", **kwargs, ) -> "ModuleProxy": diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index bc0a346a59..5672786b94 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -25,12 +25,16 @@ from dimos.core.core import rpc from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.module3D import Detection3DModule -from dimos.perception.detection.type.detection3d import Detection3DPC from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC +from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC from dimos.perception.detection.type.utils import TableStr diff --git a/dimos/perception/detection/objectDB.py b/dimos/perception/detection/objectDB.py index 9af8058c55..5b73e97742 100644 --- a/dimos/perception/detection/objectDB.py +++ b/dimos/perception/detection/objectDB.py @@ -20,11 +20,11 @@ import open3d as o3d # type: ignore[import-untyped] -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import Vector3 + from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.perception.detection.type.detection3d.object import Object logger = setup_logger() diff --git a/dimos/perception/detection/person_tracker.py b/dimos/perception/detection/person_tracker.py index 913043f312..9dbba210a2 100644 --- a/dimos/perception/detection/person_tracker.py +++ b/dimos/perception/detection/person_tracker.py @@ -21,10 +21,13 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection.type import ImageDetections2D +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.types.timestamped import align_timestamped from dimos.utils.reactive import backpressure diff --git a/dimos/perception/detection/reid/__init__.py b/dimos/perception/detection/reid/__init__.py deleted file mode 100644 index 31d50a894b..0000000000 --- a/dimos/perception/detection/reid/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem -from dimos.perception.detection.reid.module import Config, ReidModule -from dimos.perception.detection.reid.type import IDSystem, PassthroughIDSystem - -__all__ = [ - "Config", - "EmbeddingIDSystem", - # ID Systems - "IDSystem", - "PassthroughIDSystem", - # Module - "ReidModule", -] diff --git a/dimos/perception/detection/reid/embedding_id_system.py b/dimos/perception/detection/reid/embedding_id_system.py index 15bb491f5c..faf322de07 100644 --- a/dimos/perception/detection/reid/embedding_id_system.py +++ b/dimos/perception/detection/reid/embedding_id_system.py @@ -19,7 +19,7 @@ from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.perception.detection.reid.type import IDSystem -from dimos.perception.detection.type import Detection2DBBox +from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox class EmbeddingIDSystem(IDSystem): diff --git a/dimos/perception/detection/reid/module.py b/dimos/perception/detection/reid/module.py index 0a359746d3..2bb0ecfbb2 100644 --- a/dimos/perception/detection/reid/module.py +++ b/dimos/perception/detection/reid/module.py @@ -24,8 +24,8 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem from dimos.perception.detection.reid.type import IDSystem from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D @@ -48,7 +48,7 @@ def __init__(self, idsystem: IDSystem | None = None, **kwargs) -> None: # type: super().__init__(**kwargs) if idsystem is None: try: - from dimos.models.embedding import TorchReIDModel + from dimos.models.embedding.treid import TorchReIDModel idsystem = EmbeddingIDSystem(model=TorchReIDModel, padding=0) # type: ignore[arg-type] except Exception as e: diff --git a/dimos/perception/detection/reid/test_embedding_id_system.py b/dimos/perception/detection/reid/test_embedding_id_system.py index cc8632627f..2916c9040d 100644 --- a/dimos/perception/detection/reid/test_embedding_id_system.py +++ b/dimos/perception/detection/reid/test_embedding_id_system.py @@ -15,7 +15,7 @@ import numpy as np import pytest -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem from dimos.utils.data import get_data diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py index f5672c1f67..aac6ba11d1 100644 --- a/dimos/perception/detection/reid/test_module.py +++ b/dimos/perception/detection/reid/test_module.py @@ -15,7 +15,7 @@ import pytest from dimos.core.transport import LCMTransport -from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem from dimos.perception.detection.reid.module import ReidModule @@ -23,7 +23,7 @@ @pytest.mark.tool def test_reid_ingress(imageDetections2d) -> None: try: - from dimos.models.embedding import TorchReIDModel + from dimos.models.embedding.treid import TorchReIDModel except Exception: pytest.skip("TorchReIDModel not available") diff --git a/dimos/perception/detection/type/__init__.py b/dimos/perception/detection/type/__init__.py deleted file mode 100644 index b14464d4fa..0000000000 --- a/dimos/perception/detection/type/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "detection2d.base": [ - "Detection2D", - "Filter2D", - ], - "detection2d.bbox": [ - "Detection2DBBox", - ], - "detection2d.person": [ - "Detection2DPerson", - ], - "detection2d.point": [ - "Detection2DPoint", - ], - "detection2d.imageDetections2D": [ - "ImageDetections2D", - ], - "detection3d": [ - "Detection3D", - "Detection3DBBox", - "Detection3DPC", - "ImageDetections3DPC", - "PointCloudFilter", - "height_filter", - "radius_outlier", - "raycast", - "statistical", - ], - "imageDetections": ["ImageDetections"], - "utils": ["TableStr"], - }, -) diff --git a/dimos/perception/detection/type/detection2d/__init__.py b/dimos/perception/detection/type/detection2d/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/perception/detection/type/detection2d/base.py b/dimos/perception/detection/type/detection2d/base.py index ee9374af8c..ef05813118 100644 --- a/dimos/perception/detection/type/detection2d/base.py +++ b/dimos/perception/detection/type/detection2d/base.py @@ -17,8 +17,8 @@ from dimos_lcm.vision_msgs import Detection2D as ROSDetection2D -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.sensor_msgs.Image import Image from dimos.types.timestamped import Timestamped diff --git a/dimos/perception/detection/type/detection2d/bbox.py b/dimos/perception/detection/type/detection2d/bbox.py index 45dc848e9d..9ce3f11b96 100644 --- a/dimos/perception/detection/type/detection2d/bbox.py +++ b/dimos/perception/detection/type/detection2d/bbox.py @@ -22,7 +22,7 @@ from typing_extensions import Self from ultralytics.engine.results import Results # type: ignore[import-not-found] - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from dimos_lcm.foxglove_msgs.ImageAnnotations import ( PointsAnnotation, @@ -40,9 +40,9 @@ from rich.console import Console from rich.text import Text -from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.std_msgs import Header +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.std_msgs.Header import Header from dimos.perception.detection.type.detection2d.base import Detection2D from dimos.types.timestamped import to_ros_stamp, to_timestamp from dimos.utils.decorators.decorators import simple_mcache diff --git a/dimos/perception/detection/type/detection2d/imageDetections2D.py b/dimos/perception/detection/type/detection2d/imageDetections2D.py index 34033a9c50..507125c333 100644 --- a/dimos/perception/detection/type/detection2d/imageDetections2D.py +++ b/dimos/perception/detection/type/detection2d/imageDetections2D.py @@ -27,8 +27,8 @@ if TYPE_CHECKING: from ultralytics.engine.results import Results - from dimos.msgs.sensor_msgs import Image - from dimos.msgs.vision_msgs import Detection2DArray + from dimos.msgs.sensor_msgs.Image import Image + from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray T2D = TypeVar("T2D", bound=Detection2D, default=Detection2DBBox) diff --git a/dimos/perception/detection/type/detection2d/person.py b/dimos/perception/detection/type/detection2d/person.py index efb12ebdbc..e85229719a 100644 --- a/dimos/perception/detection/type/detection2d/person.py +++ b/dimos/perception/detection/type/detection2d/person.py @@ -25,7 +25,7 @@ import numpy as np from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.type.detection2d.bbox import Bbox, Detection2DBBox from dimos.types.timestamped import to_ros_stamp from dimos.utils.decorators.decorators import simple_mcache diff --git a/dimos/perception/detection/type/detection2d/point.py b/dimos/perception/detection/type/detection2d/point.py index 216ec57b82..0155bcb9cd 100644 --- a/dimos/perception/detection/type/detection2d/point.py +++ b/dimos/perception/detection/type/detection2d/point.py @@ -31,14 +31,14 @@ Pose2D, ) -from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.std_msgs import Header +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.std_msgs.Header import Header from dimos.perception.detection.type.detection2d.base import Detection2D from dimos.types.timestamped import to_ros_stamp if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image @dataclass diff --git a/dimos/perception/detection/type/detection2d/seg.py b/dimos/perception/detection/type/detection2d/seg.py index 5d4d55d0c3..aca1e34b7e 100644 --- a/dimos/perception/detection/type/detection2d/seg.py +++ b/dimos/perception/detection/type/detection2d/seg.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from ultralytics.engine.results import Results - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image @dataclass diff --git a/dimos/perception/detection/type/detection2d/test_imageDetections2D.py b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py index 83487d2c25..4897d8d034 100644 --- a/dimos/perception/detection/type/detection2d/test_imageDetections2D.py +++ b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py @@ -13,7 +13,7 @@ # limitations under the License. import pytest -from dimos.perception.detection.type import ImageDetections2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D def test_from_ros_detection2d_array(get_moment_2d) -> None: diff --git a/dimos/perception/detection/type/detection2d/test_person.py b/dimos/perception/detection/type/detection2d/test_person.py index 06c5883ae2..988222e120 100644 --- a/dimos/perception/detection/type/detection2d/test_person.py +++ b/dimos/perception/detection/type/detection2d/test_person.py @@ -17,7 +17,7 @@ def test_person_ros_confidence() -> None: """Test that Detection2DPerson preserves confidence when converting to ROS format.""" - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection.type.detection2d.person import Detection2DPerson from dimos.utils.data import get_data diff --git a/dimos/perception/detection/type/detection3d/__init__.py b/dimos/perception/detection/type/detection3d/__init__.py deleted file mode 100644 index 53ab73259e..0000000000 --- a/dimos/perception/detection/type/detection3d/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.perception.detection.type.detection3d.base import Detection3D -from dimos.perception.detection.type.detection3d.bbox import Detection3DBBox -from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC -from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC -from dimos.perception.detection.type.detection3d.pointcloud_filters import ( - PointCloudFilter, - height_filter, - radius_outlier, - raycast, - statistical, -) - -__all__ = [ - "Detection3D", - "Detection3DBBox", - "Detection3DPC", - "ImageDetections3DPC", - "PointCloudFilter", - "height_filter", - "radius_outlier", - "raycast", - "statistical", -] diff --git a/dimos/perception/detection/type/detection3d/base.py b/dimos/perception/detection/type/detection3d/base.py index a5dbb742b8..afe37aac6e 100644 --- a/dimos/perception/detection/type/detection3d/base.py +++ b/dimos/perception/detection/type/detection3d/base.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING -from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox if TYPE_CHECKING: diff --git a/dimos/perception/detection/type/detection3d/bbox.py b/dimos/perception/detection/type/detection3d/bbox.py index bdf2d27a7c..a3ae68a766 100644 --- a/dimos/perception/detection/type/detection3d/bbox.py +++ b/dimos/perception/detection/type/detection3d/bbox.py @@ -20,9 +20,13 @@ from dimos_lcm.vision_msgs import ObjectHypothesis, ObjectHypothesisWithPose -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection3D +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection3D import Detection3D from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox diff --git a/dimos/perception/detection/type/detection3d/object.py b/dimos/perception/detection/type/detection3d/object.py index ec160c4a68..639ea73ae5 100644 --- a/dimos/perception/detection/type/detection3d/object.py +++ b/dimos/perception/detection/type/detection3d/object.py @@ -24,10 +24,15 @@ import numpy as np import open3d as o3d # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection3D as ROSDetection3D, Detection3DArray +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection3D import Detection3D as ROSDetection3D +from dimos.msgs.vision_msgs.Detection3DArray import Detection3DArray from dimos.perception.detection.type.detection2d.seg import Detection2DSeg from dimos.perception.detection.type.detection3d.base import Detection3D diff --git a/dimos/perception/detection/type/detection3d/pointcloud.py b/dimos/perception/detection/type/detection3d/pointcloud.py index 741b9c7498..5ddec06fd5 100644 --- a/dimos/perception/detection/type/detection3d/pointcloud.py +++ b/dimos/perception/detection/type/detection3d/pointcloud.py @@ -33,8 +33,10 @@ import numpy as np from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.perception.detection.type.detection3d.base import Detection3D from dimos.perception.detection.type.detection3d.pointcloud_filters import ( PointCloudFilter, diff --git a/dimos/perception/detection/type/detection3d/pointcloud_filters.py b/dimos/perception/detection/type/detection3d/pointcloud_filters.py index 59ad6200d9..fdb2afeebb 100644 --- a/dimos/perception/detection/type/detection3d/pointcloud_filters.py +++ b/dimos/perception/detection/type/detection3d/pointcloud_filters.py @@ -18,8 +18,8 @@ from dimos_lcm.sensor_msgs import CameraInfo -from dimos.msgs.geometry_msgs import Transform -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox # Filters take Detection2DBBox, PointCloud2, CameraInfo, Transform and return filtered PointCloud2 or None diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 12a1f4efb9..25cd45545a 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -20,14 +20,14 @@ from dimos_lcm.vision_msgs import Detection2DArray -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.std_msgs import Header +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.std_msgs.Header import Header from dimos.perception.detection.type.utils import TableStr if TYPE_CHECKING: from collections.abc import Callable, Iterator - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.type.detection2d.base import Detection2D T = TypeVar("T", bound=Detection2D) diff --git a/dimos/perception/detection/type/test_object3d.py b/dimos/perception/detection/type/test_object3d.py index 7057fbb9cb..ff8931e353 100644 --- a/dimos/perception/detection/type/test_object3d.py +++ b/dimos/perception/detection/type/test_object3d.py @@ -15,7 +15,7 @@ import pytest from dimos.perception.detection.moduleDB import Object3D -from dimos.perception.detection.type.detection3d import ImageDetections3DPC +from dimos.perception.detection.type.detection3d.imageDetections3DPC import ImageDetections3DPC def test_first_object(first_object) -> None: diff --git a/dimos/perception/experimental/__init__.py b/dimos/perception/experimental/__init__.py deleted file mode 100644 index 39ef33521d..0000000000 --- a/dimos/perception/experimental/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Experimental perception modules.""" diff --git a/dimos/perception/experimental/temporal_memory/clip_filter.py b/dimos/perception/experimental/temporal_memory/clip_filter.py index d747899452..9bea000712 100644 --- a/dimos/perception/experimental/temporal_memory/clip_filter.py +++ b/dimos/perception/experimental/temporal_memory/clip_filter.py @@ -18,7 +18,7 @@ import numpy as np -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/perception/experimental/temporal_memory/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py index 0d5531bada..bdc7137ce7 100644 --- a/dimos/perception/experimental/temporal_memory/entity_graph_db.py +++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py @@ -30,9 +30,12 @@ from dimos.utils.logging_config import setup_logger +from .temporal_utils.parsers import parse_batch_distance_response +from .temporal_utils.prompts import build_batch_distance_estimation_prompt + if TYPE_CHECKING: from dimos.models.vl.base import VlModel - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image logger = setup_logger() @@ -122,7 +125,7 @@ def _init_schema(self) -> None: conn.commit() - # ==================== Entity Operations ==================== + # Entity Operations def upsert_entity( self, @@ -216,7 +219,7 @@ def get_entities_by_time( for row in cursor.fetchall() ] - # ==================== Relation Operations ==================== + # Relation Operations def add_relation( self, @@ -290,7 +293,7 @@ def get_recent_relations(self, limit: int = 50) -> list[dict[str, Any]]: for row in cursor.fetchall() ] - # ==================== Distance Operations ==================== + # Distance Operations def add_distance( self, @@ -424,7 +427,7 @@ def get_nearby_entities( for row in cursor.fetchall() ] - # ==================== Neighborhood Query ==================== + # Neighborhood Query def get_entity_neighborhood( self, @@ -471,7 +474,7 @@ def get_entity_neighborhood( "num_hops": max_hops, } - # ==================== Stats / Summary ==================== + # Stats / Summary def get_stats(self) -> dict[str, Any]: conn = self._get_connection() @@ -491,7 +494,7 @@ def get_summary(self, recent_relations_limit: int = 5) -> dict[str, Any]: "recent_relations": self.get_recent_relations(limit=recent_relations_limit), } - # ==================== Bulk Save ==================== + # Bulk Save def save_window_data( self, @@ -557,14 +560,13 @@ def estimate_and_save_distances( self, parsed: dict[str, Any], frame_image: Image, - vlm: VlModel, + vlm: VlModel[Any], timestamp_s: float, max_distance_pairs: int = 5, ) -> None: """Estimate distances between entities using VLM and save to database.""" if not frame_image: return - from . import temporal_utils as tu enriched_entities: list[dict[str, Any]] = [] for entity in parsed.get("new_entities", []): @@ -593,8 +595,8 @@ def estimate_and_save_distances( if not pairs: return try: - response = vlm.query(frame_image, tu.build_batch_distance_estimation_prompt(pairs)) - for r in tu.parse_batch_distance_response(response, pairs): + response = vlm.query(frame_image, build_batch_distance_estimation_prompt(pairs)) + for r in parse_batch_distance_response(response, pairs): if r["category"] in ("near", "medium", "far"): self.add_distance( entity_a_id=r["entity_a_id"], @@ -608,7 +610,7 @@ def estimate_and_save_distances( except Exception as e: logger.warning(f"Failed to estimate distances: {e}", exc_info=True) - # ==================== Lifecycle ==================== + # Lifecycle def commit(self) -> None: if hasattr(self._local, "conn"): diff --git a/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py b/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py index 7af13ad9c2..4c910a1b88 100644 --- a/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py +++ b/dimos/perception/experimental/temporal_memory/frame_window_accumulator.py @@ -25,7 +25,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image @dataclass @@ -72,10 +72,6 @@ def __init__( self.stride_s = stride_s self.fps = fps - # ------------------------------------------------------------------ - # Ingest - # ------------------------------------------------------------------ - def set_start_time(self, wall_time: float) -> None: with self._lock: if self._video_start_wall_time is None: @@ -103,10 +99,6 @@ def add_frame(self, image: Image, wall_time: float) -> None: self._buffer.append(frame) self._frame_count += 1 - # ------------------------------------------------------------------ - # Window extraction - # ------------------------------------------------------------------ - def try_extract_window(self) -> list[Frame] | None: """Try to extract a window of frames. @@ -131,10 +123,6 @@ def mark_analysis_time(self, t: float) -> None: with self._lock: self._last_analysis_time = t - # ------------------------------------------------------------------ - # Accessors - # ------------------------------------------------------------------ - @property def frame_count(self) -> int: with self._lock: diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py index b651d3e0af..d4e343872b 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -22,13 +22,12 @@ from __future__ import annotations from collections import deque -from dataclasses import dataclass import json import os from pathlib import Path import threading import time -from typing import TYPE_CHECKING, Any +from typing import Any from reactivex import Subject, interval from reactivex.disposable import Disposable @@ -37,22 +36,20 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out +from dimos.models.vl.base import VlModel from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.sensor_msgs.Image import sharpness_barrier +from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.msgs.visualization_msgs.EntityMarkers import EntityMarkers, Marker from dimos.utils.logging_config import get_run_log_dir, setup_logger -from . import temporal_utils as tu from .clip_filter import CLIP_AVAILABLE, adaptive_keyframes from .entity_graph_db import EntityGraphDB from .frame_window_accumulator import Frame, FrameWindowAccumulator from .temporal_state import TemporalState +from .temporal_utils.graph_utils import build_graph_context, extract_time_window +from .temporal_utils.helpers import is_scene_stale from .window_analyzer import WindowAnalyzer -if TYPE_CHECKING: - from dimos.models.vl.base import VlModel - try: from .clip_filter import CLIPFrameFilter except ImportError: @@ -63,7 +60,6 @@ MAX_RECENT_WINDOWS = 50 -@dataclass class TemporalMemoryConfig(ModuleConfig): """Configuration for the temporal memory module. @@ -71,6 +67,8 @@ class TemporalMemoryConfig(ModuleConfig): tune cost / latency / accuracy without touching code. """ + vlm: VlModel[Any] | None = None + # Frame processing fps: float = 1.0 window_s: float = 5.0 @@ -106,38 +104,35 @@ class TemporalMemoryConfig(ModuleConfig): nearby_distance_meters: float = 5.0 -class TemporalMemory(Module): +class TemporalMemory(Module[TemporalMemoryConfig]): """Thin orchestrator that wires frames → window accumulator → VLM → state + DB. Uses RxPY reactive streams for the frame pipeline and ``interval`` for periodic window analysis. """ + default_config = TemporalMemoryConfig + color_image: In[Image] odom: In[PoseStamped] entity_markers: Out[EntityMarkers] - def __init__( - self, - vlm: VlModel | None = None, - config: TemporalMemoryConfig | None = None, - ) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) - self._vlm_raw = vlm - self._config: TemporalMemoryConfig = config or TemporalMemoryConfig() + self._vlm_raw = self.config.vlm # new_memory is set via TemporalMemoryConfig by the blueprint factory # (which runs in the main process where GlobalConfig is available). # Components self._accumulator = FrameWindowAccumulator( - max_buffer_frames=self._config.max_buffer_frames, - window_s=self._config.window_s, - stride_s=self._config.stride_s, - fps=self._config.fps, + max_buffer_frames=self.config.max_buffer_frames, + window_s=self.config.window_s, + stride_s=self.config.stride_s, + fps=self.config.fps, ) - self._state = TemporalState(next_summary_at_s=self._config.summary_interval_s) + self._state = TemporalState(next_summary_at_s=self.config.summary_interval_s) self._recent_windows: deque[dict[str, Any]] = deque(maxlen=MAX_RECENT_WINDOWS) self._stopped = False @@ -150,10 +145,10 @@ def __init__( # CLIP filter self._clip_filter: CLIPFrameFilter | None = None - self._use_clip_filtering = self._config.use_clip_filtering + self._use_clip_filtering = self.config.use_clip_filtering if self._use_clip_filtering and CLIP_AVAILABLE: try: - self._clip_filter = CLIPFrameFilter(model_name=self._config.clip_model) + self._clip_filter = CLIPFrameFilter(model_name=self.config.clip_model) logger.info("clip filtering enabled") except Exception as e: logger.warning(f"clip init failed: {e}") @@ -163,8 +158,8 @@ def __init__( self._use_clip_filtering = False # Persistent DB — stored in XDG state dir (same root as per-run logs) - if self._config.db_dir: - db_dir = Path(self._config.db_dir) + if self.config.db_dir: + db_dir = Path(self.config.db_dir) else: # Default: ~/.local/state/dimos/temporal_memory/ # XDG state dir — predictable, works for pip install and git clone. @@ -173,7 +168,7 @@ def __init__( db_dir = state_root / "dimos" / "temporal_memory" db_dir.mkdir(parents=True, exist_ok=True) db_path = db_dir / "entity_graph.db" - if self._config.new_memory and db_path.exists(): + if self.config.new_memory and db_path.exists(): db_path.unlink() logger.info("Deleted existing DB (new_memory=True)") self._graph_db = EntityGraphDB(db_path=db_path) @@ -181,7 +176,7 @@ def __init__( # Persistent JSONL — accumulates across runs (raw VLM output + parsed) self._persistent_jsonl_path: Path = db_dir / "temporal_memory.jsonl" - if self._config.new_memory and self._persistent_jsonl_path.exists(): + if self.config.new_memory and self._persistent_jsonl_path.exists(): self._persistent_jsonl_path.unlink() logger.info("Deleted existing persistent JSONL (new_memory=True)") logger.info(f"persistent JSONL: {self._persistent_jsonl_path}") @@ -204,16 +199,12 @@ def __init__( logger.warning("no run log dir found — JSONL logging disabled") logger.info( - f"TemporalMemory init: fps={self._config.fps}, " - f"window={self._config.window_s}s, stride={self._config.stride_s}s" + f"TemporalMemory init: fps={self.config.fps}, " + f"window={self.config.window_s}s, stride={self.config.stride_s}s" ) - # ------------------------------------------------------------------ - # VLM access (lazy) - # ------------------------------------------------------------------ - @property - def vlm(self) -> VlModel: + def vlm(self) -> VlModel[Any]: if self._vlm_raw is None: from dimos.models.vl.openai import OpenAIVlModel @@ -230,15 +221,11 @@ def _analyzer(self) -> WindowAnalyzer: if not hasattr(self, "__analyzer"): self.__analyzer = WindowAnalyzer( self.vlm, - max_tokens=self._config.max_tokens, - temperature=self._config.temperature, + max_tokens=self.config.max_tokens, + temperature=self.config.temperature, ) return self.__analyzer - # ------------------------------------------------------------------ - # JSONL logging - # ------------------------------------------------------------------ - def _log_jsonl(self, record: dict[str, Any]) -> None: line = json.dumps(record, ensure_ascii=False) + "\n" # Write to per-run JSONL @@ -255,13 +242,9 @@ def _log_jsonl(self, record: dict[str, Any]) -> None: except Exception as e: logger.warning(f"persistent jsonl log failed: {e}") - # ------------------------------------------------------------------ - # Rerun visualization - # ------------------------------------------------------------------ - def _publish_entity_markers(self) -> None: """Publish entity positions as 3D markers for Rerun overlay on the map.""" - if not self._config.visualize: + if not self.config.visualize: return try: all_entities = self._graph_db.get_all_entities() @@ -293,10 +276,6 @@ def _publish_entity_markers(self) -> None: except Exception as e: logger.debug(f"entity marker publish error: {e}") - # ------------------------------------------------------------------ - # Lifecycle - # ------------------------------------------------------------------ - @rpc def start(self) -> None: super().start() @@ -319,7 +298,7 @@ def _on_frame(img: Image) -> None: ) self._disposables.add( - frame_subject.pipe(sharpness_barrier(self._config.fps)).subscribe(_on_frame) + frame_subject.pipe(sharpness_barrier(self.config.fps)).subscribe(_on_frame) ) unsub_image = self.color_image.subscribe(frame_subject.on_next) self._disposables.add(Disposable(unsub_image)) @@ -342,7 +321,7 @@ def _on_odom(msg: PoseStamped) -> None: # Periodic window analysis self._disposables.add( - interval(self._config.stride_s).subscribe(lambda _: self._analyze_window()) + interval(self.config.stride_s).subscribe(lambda _: self._analyze_window()) ) logger.info("TemporalMemory started") @@ -366,7 +345,7 @@ def stop(self) -> None: self._accumulator.clear() self._recent_windows.clear() - self._state.clear(self._config.summary_interval_s) + self._state.clear(self.config.summary_interval_s) super().stop() @@ -379,10 +358,6 @@ def stop(self) -> None: logger.info("TemporalMemory stopped") - # ------------------------------------------------------------------ - # Core loop - # ------------------------------------------------------------------ - def _analyze_window(self) -> None: if self._stopped: return @@ -401,13 +376,13 @@ def _analyze_window(self) -> None: w_start, w_end = window_frames[0].timestamp_s, window_frames[-1].timestamp_s # Skip stale scenes (frames too close together / camera not moving) - if tu.is_scene_stale(window_frames, self._config.stale_scene_threshold): + if is_scene_stale(window_frames, self.config.stale_scene_threshold): logger.info(f"[temporal-memory] skipping stale window [{w_start:.1f}-{w_end:.1f}s]") return # Select diverse keyframes window_frames = adaptive_keyframes( - window_frames, max_frames=self._config.max_frames_per_window + window_frames, max_frames=self.config.max_frames_per_window ) logger.info(f"analyzing [{w_start:.1f}-{w_end:.1f}s] with {len(window_frames)} frames") @@ -458,7 +433,7 @@ def _analyze_window(self) -> None: ) # VLM Call #2: distance estimation (background thread) - if self._graph_db and self._config.enable_distance_estimation and window_frames: + if self._graph_db and self.config.enable_distance_estimation and window_frames: mid_frame = window_frames[len(window_frames) // 2] if mid_frame.image: thread = threading.Thread( @@ -468,7 +443,7 @@ def _analyze_window(self) -> None: mid_frame.image, self.vlm, w_end, - self._config.max_distance_pairs, + self.config.max_distance_pairs, ), daemon=True, ) @@ -478,7 +453,7 @@ def _analyze_window(self) -> None: # Update state needs_summary = self._state.update_from_window( - parsed, w_end, self._config.summary_interval_s + parsed, w_end, self.config.summary_interval_s ) self._recent_windows.append(parsed) @@ -512,7 +487,7 @@ def _update_rolling_summary(self, w_end: float) -> None: sr = self._analyzer.update_summary(latest.image, snap.rolling_summary, snap.chunk_buffer) if sr is not None: - self._state.apply_summary(sr.summary_text, w_end, self._config.summary_interval_s) + self._state.apply_summary(sr.summary_text, w_end, self.config.summary_interval_s) self._log_jsonl( { "ts": time.time(), @@ -523,10 +498,6 @@ def _update_rolling_summary(self, w_end: float) -> None: ) logger.info(f"[temporal-memory] SUMMARY: {sr.summary_text[:300]}") - # ------------------------------------------------------------------ - # Query (agent skill) - # ------------------------------------------------------------------ - @skill def query(self, question: str) -> str: """Answer a question about the video stream using temporal memory and graph knowledge. @@ -582,18 +553,18 @@ def query(self, question: str) -> str: # Graph context if self._graph_db: - time_window_s = tu.extract_time_window(question) + time_window_s = extract_time_window(question) all_entity_ids = [ e["id"] for e in snap.entity_roster if isinstance(e, dict) and "id" in e ] if all_entity_ids: logger.info(f"query: building graph context for {len(all_entity_ids)} entities") - graph_context = tu.build_graph_context( + graph_context = build_graph_context( graph_db=self._graph_db, entity_ids=all_entity_ids, time_window_s=time_window_s, - max_relations_per_entity=self._config.max_relations_per_entity, - nearby_distance_meters=self._config.nearby_distance_meters, + max_relations_per_entity=self.config.max_relations_per_entity, + nearby_distance_meters=self.config.nearby_distance_meters, current_video_time_s=current_video_time_s, ) context["graph_knowledge"] = graph_context @@ -616,14 +587,10 @@ def query(self, question: str) -> str: ) return qr.answer - # ------------------------------------------------------------------ - # RPC accessors (backward compat) - # ------------------------------------------------------------------ - @rpc def clear_history(self) -> bool: try: - self._state.clear(self._config.summary_interval_s) + self._state.clear(self.config.summary_interval_s) self._recent_windows.clear() logger.info("cleared history") return True diff --git a/dimos/perception/experimental/temporal_memory/temporal_state.py b/dimos/perception/experimental/temporal_memory/temporal_state.py index 64914761b1..dfc440872d 100644 --- a/dimos/perception/experimental/temporal_memory/temporal_state.py +++ b/dimos/perception/experimental/temporal_memory/temporal_state.py @@ -39,10 +39,6 @@ class TemporalState: _lock: threading.Lock = field(default_factory=threading.Lock, repr=False, compare=False) - # ------------------------------------------------------------------ - # Snapshot - # ------------------------------------------------------------------ - def snapshot(self) -> TemporalState: """Return a deep-copy snapshot (safe to read outside the lock).""" with self._lock: @@ -65,10 +61,6 @@ def to_dict(self) -> dict[str, Any]: "last_present": copy.deepcopy(self.last_present), } - # ------------------------------------------------------------------ - # Mutators - # ------------------------------------------------------------------ - def update_from_window( self, parsed: dict[str, Any], diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py b/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py deleted file mode 100644 index d8119a5159..0000000000 --- a/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Temporal memory utilities.""" - -from .graph_utils import build_graph_context, extract_time_window -from .helpers import clamp_text, format_timestamp, is_scene_stale, next_entity_id_hint -from .parsers import parse_batch_distance_response, parse_window_response -from .prompts import ( - WINDOW_RESPONSE_SCHEMA, - build_batch_distance_estimation_prompt, - build_distance_estimation_prompt, - build_query_prompt, - build_summary_prompt, - build_window_prompt, - get_structured_output_format, -) - -__all__ = [ - "WINDOW_RESPONSE_SCHEMA", - "build_batch_distance_estimation_prompt", - "build_distance_estimation_prompt", - "build_graph_context", - "build_query_prompt", - "build_summary_prompt", - "build_window_prompt", - "clamp_text", - "extract_time_window", - "format_timestamp", - "get_structured_output_format", - "is_scene_stale", - "next_entity_id_hint", - "parse_batch_distance_response", - "parse_window_response", -] diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py index d98074bd5d..81df107ecf 100644 --- a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py +++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py @@ -21,7 +21,7 @@ import threading import time from typing import TYPE_CHECKING -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch from dotenv import load_dotenv import numpy as np @@ -33,15 +33,18 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import Out from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import Image -from dimos.perception.experimental.temporal_memory import ( +from dimos.models.vl.base import VlModel +from dimos.msgs.sensor_msgs.Image import Image +from dimos.perception.experimental.temporal_memory.entity_graph_db import EntityGraphDB +from dimos.perception.experimental.temporal_memory.frame_window_accumulator import ( Frame, FrameWindowAccumulator, +) +from dimos.perception.experimental.temporal_memory.temporal_memory import ( TemporalMemory, TemporalMemoryConfig, - TemporalState, ) -from dimos.perception.experimental.temporal_memory.entity_graph_db import EntityGraphDB +from dimos.perception.experimental.temporal_memory.temporal_state import TemporalState from dimos.perception.experimental.temporal_memory.temporal_utils.graph_utils import ( extract_time_window, ) @@ -63,11 +66,6 @@ def _make_image(value: int = 128, shape: tuple[int, ...] = (64, 64, 3)) -> Image return Image.from_numpy(data) -# ====================================================================== -# 1. FrameWindowAccumulator tests -# ====================================================================== - - class TestFrameWindowAccumulator: def test_bounded_buffer(self) -> None: acc = FrameWindowAccumulator(max_buffer_frames=5, window_s=1.0, stride_s=1.0, fps=1.0) @@ -124,11 +122,6 @@ def test_clear(self) -> None: assert acc.buffer_size == 0 -# ====================================================================== -# 2. TemporalState tests -# ====================================================================== - - class TestTemporalState: def test_update_and_snapshot(self) -> None: state = TemporalState(next_summary_at_s=10.0) @@ -225,11 +218,6 @@ def test_auto_add_referenced(self) -> None: assert "E2" in ids -# ====================================================================== -# 3. extract_time_window (regex-only) tests -# ====================================================================== - - class TestExtractTimeWindow: def test_keyword_patterns(self) -> None: assert extract_time_window("just now") == 60 @@ -247,11 +235,6 @@ def test_no_time_reference(self) -> None: assert extract_time_window("is there a person?") is None -# ====================================================================== -# 4. EntityGraphDB tests -# ====================================================================== - - class TestEntityGraphDB: @pytest.fixture def db(self, tmp_path: Path) -> EntityGraphDB: @@ -314,11 +297,6 @@ def test_stats(self, db: EntityGraphDB) -> None: assert "semantic_relations" not in stats -# ====================================================================== -# 5. Persistence test (new_memory flag) -# ====================================================================== - - class TestPersistence: def test_new_memory_clears_db(self, tmp_path: Path) -> None: db_dir = tmp_path / "memory" / "temporal" @@ -337,8 +315,9 @@ def test_new_memory_clears_db(self, tmp_path: Path) -> None: return_value=None, ): tm = TemporalMemory( - vlm=MagicMock(), - config=TemporalMemoryConfig(db_dir=str(db_dir), new_memory=True), + vlm=create_autospec(VlModel, spec_set=True, instance=True), + db_dir=str(db_dir), + new_memory=True, ) # DB should be empty since we cleared it stats = tm._graph_db.get_stats() @@ -361,19 +340,15 @@ def test_persistent_memory_survives(self, tmp_path: Path) -> None: return_value=None, ): tm = TemporalMemory( - vlm=MagicMock(), - config=TemporalMemoryConfig(db_dir=str(db_dir), new_memory=False), + vlm=create_autospec(VlModel, spec_set=True, instance=True), + db_dir=str(db_dir), + new_memory=False, ) stats = tm._graph_db.get_stats() assert stats["entities"] == 1 tm.stop() -# ====================================================================== -# 6. Per-run JSONL logging test -# ====================================================================== - - class TestJSONLLogging: def test_log_entries(self, tmp_path: Path) -> None: db_dir = tmp_path / "db" @@ -386,8 +361,8 @@ def test_log_entries(self, tmp_path: Path) -> None: return_value=log_dir, ): tm = TemporalMemory( - vlm=MagicMock(), - config=TemporalMemoryConfig(db_dir=str(db_dir)), + vlm=create_autospec(VlModel, spec_set=True, instance=True), + db_dir=str(db_dir), ) jsonl_path = log_dir / "temporal_memory" / "temporal_memory.jsonl" @@ -412,11 +387,6 @@ def test_log_entries(self, tmp_path: Path) -> None: tm.stop() -# ====================================================================== -# 7. Rerun visualization test -# ====================================================================== - - class TestEntityMarkers: def test_publish_entity_markers(self, tmp_path: Path) -> None: db_dir = tmp_path / "db" @@ -427,8 +397,9 @@ def test_publish_entity_markers(self, tmp_path: Path) -> None: return_value=None, ): tm = TemporalMemory( - vlm=MagicMock(), - config=TemporalMemoryConfig(db_dir=str(db_dir), visualize=True), + vlm=create_autospec(VlModel, spec_set=True, instance=True), + db_dir=str(db_dir), + visualize=True, ) # Populate DB with world positions @@ -478,16 +449,11 @@ def test_markers_to_rerun(self) -> None: assert isinstance(archetype, rr.Points3D) -# ====================================================================== -# 8. WindowAnalyzer mock tests -# ====================================================================== - - class TestWindowAnalyzer: def test_analyze_window_calls_vlm(self) -> None: from dimos.perception.experimental.temporal_memory.window_analyzer import WindowAnalyzer - mock_vlm = MagicMock() + mock_vlm = create_autospec(VlModel, spec_set=True, instance=True) mock_vlm.query.return_value = json.dumps( { "window": {"start_s": 0.0, "end_s": 2.0}, @@ -513,7 +479,7 @@ def test_analyze_window_calls_vlm(self) -> None: def test_analyze_window_vlm_error(self) -> None: from dimos.perception.experimental.temporal_memory.window_analyzer import WindowAnalyzer - mock_vlm = MagicMock() + mock_vlm = create_autospec(VlModel, spec_set=True, instance=True) mock_vlm.query.side_effect = RuntimeError("VLM error") analyzer = WindowAnalyzer(mock_vlm) @@ -527,7 +493,7 @@ def test_analyze_window_vlm_error(self) -> None: def test_update_summary(self) -> None: from dimos.perception.experimental.temporal_memory.window_analyzer import WindowAnalyzer - mock_vlm = MagicMock() + mock_vlm = create_autospec(VlModel, spec_set=True, instance=True) mock_vlm.query.return_value = "Updated summary text" analyzer = WindowAnalyzer(mock_vlm) @@ -540,7 +506,7 @@ def test_update_summary(self) -> None: def test_answer_query(self) -> None: from dimos.perception.experimental.temporal_memory.window_analyzer import WindowAnalyzer - mock_vlm = MagicMock() + mock_vlm = create_autospec(VlModel, spec_set=True, instance=True) mock_vlm.query.return_value = "The answer is 42" analyzer = WindowAnalyzer(mock_vlm) @@ -551,11 +517,6 @@ def test_answer_query(self) -> None: assert result.answer == "The answer is 42" -# ====================================================================== -# 9. Integration test with ModuleCoordinator -# ====================================================================== - - class VideoReplayModule(Module): """Module that replays synthetic video data for tests.""" diff --git a/dimos/perception/experimental/temporal_memory/window_analyzer.py b/dimos/perception/experimental/temporal_memory/window_analyzer.py index a8b1899258..3c233f8e5b 100644 --- a/dimos/perception/experimental/temporal_memory/window_analyzer.py +++ b/dimos/perception/experimental/temporal_memory/window_analyzer.py @@ -25,11 +25,17 @@ from dimos.utils.logging_config import setup_logger -from . import temporal_utils as tu +from .temporal_utils.parsers import parse_window_response +from .temporal_utils.prompts import ( + build_query_prompt, + build_summary_prompt, + build_window_prompt, + get_structured_output_format, +) if TYPE_CHECKING: from dimos.models.vl.base import VlModel - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from .frame_window_accumulator import Frame @@ -68,19 +74,17 @@ class WindowAnalyzer: Stateless — caller provides frames, state snapshots, and config. """ - def __init__(self, vlm: VlModel, *, max_tokens: int = 900, temperature: float = 0.2) -> None: + def __init__( + self, vlm: VlModel[Any], *, max_tokens: int = 900, temperature: float = 0.2 + ) -> None: self._vlm = vlm self.max_tokens = max_tokens self.temperature = temperature @property - def vlm(self) -> VlModel: + def vlm(self) -> VlModel[Any]: return self._vlm - # ------------------------------------------------------------------ - # VLM Call #1: Window analysis - # ------------------------------------------------------------------ - def analyze_window( self, frames: list[Frame], @@ -89,14 +93,14 @@ def analyze_window( w_end: float, ) -> AnalysisResult | None: """Run VLM window analysis. Returns None on failure.""" - query = tu.build_window_prompt( + query = build_window_prompt( w_start=w_start, w_end=w_end, frame_count=len(frames), state=state_dict, ) try: - fmt = tu.get_structured_output_format() + fmt = get_structured_output_format() if len(frames) > 1: responses = self._vlm.query_batch( [f.image for f in frames], query, response_format=fmt @@ -111,19 +115,11 @@ def analyze_window( if raw is None: return None - parsed = tu.parse_window_response(raw, w_start, w_end, len(frames)) + parsed = parse_window_response(raw, w_start, w_end, len(frames)) return AnalysisResult(parsed=parsed, raw_vlm_response=raw, w_start=w_start, w_end=w_end) - # ------------------------------------------------------------------ - # VLM Call #2: Distance estimation (delegated to EntityGraphDB) - # ------------------------------------------------------------------ - # Distance estimation is handled by EntityGraphDB.estimate_and_save_distances. # It's called from the orchestrator, not here. - # ------------------------------------------------------------------ - # VLM Call #3: Rolling summary - # ------------------------------------------------------------------ - def update_summary( self, latest_frame: Image, @@ -134,7 +130,7 @@ def update_summary( if not chunk_buffer or not latest_frame: return None - prompt = tu.build_summary_prompt( + prompt = build_summary_prompt( rolling_summary=rolling_summary, chunk_windows=chunk_buffer, ) @@ -146,10 +142,6 @@ def update_summary( logger.error(f"summary update failed: {e}", exc_info=True) return None - # ------------------------------------------------------------------ - # VLM Call #5: Query answer - # ------------------------------------------------------------------ - def answer_query( self, question: str, @@ -157,7 +149,7 @@ def answer_query( latest_frame: Image, ) -> QueryResult | None: """Answer a user query. Returns None on failure.""" - prompt = tu.build_query_prompt(question=question, context=context) + prompt = build_query_prompt(question=question, context=context) try: raw = self._vlm.query(latest_frame, prompt) return QueryResult(answer=raw.strip(), raw_vlm_response=raw) diff --git a/dimos/perception/object_scene_registration.py b/dimos/perception/object_scene_registration.py index ee7b87b534..5fb1748032 100644 --- a/dimos/perception/object_scene_registration.py +++ b/dimos/perception/object_scene_registration.py @@ -24,14 +24,16 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 -from dimos.msgs.sensor_msgs.Image import ImageFormat -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.msgs.foxglove_msgs.ImageAnnotations import ImageAnnotations +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.msgs.vision_msgs.Detection3DArray import Detection3DArray from dimos.perception.detection.detectors.yoloe import Yoloe2DDetector, YoloePromptMode from dimos.perception.detection.objectDB import ObjectDB -from dimos.perception.detection.type import ImageDetections2D +from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.perception.detection.type.detection3d.object import ( Object, Object as DetObject, diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index da415ac32a..6afc5e0814 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import threading import time +from typing import Any import cv2 @@ -31,15 +31,16 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import ( - CameraInfo, - Image, - ImageFormat, -) -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray -from dimos.protocol.tf import TF +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.msgs.vision_msgs.Detection3DArray import Detection3DArray +from dimos.protocol.tf.tf import TF from dimos.types.timestamped import align_timestamped from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( @@ -51,9 +52,10 @@ logger = setup_logger() -@dataclass class ObjectTrackingConfig(ModuleConfig): frame_id: str = "camera_link" + reid_threshold: int = 10 + reid_fail_tolerance: int = 5 class ObjectTracking(Module[ObjectTrackingConfig]): @@ -70,11 +72,8 @@ class ObjectTracking(Module[ObjectTrackingConfig]): tracked_overlay: Out[Image] # Visualization output default_config = ObjectTrackingConfig - config: ObjectTrackingConfig - def __init__( - self, reid_threshold: int = 10, reid_fail_tolerance: int = 5, **kwargs: object - ) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize an object tracking module using OpenCV's CSRT tracker with ORB re-ID. @@ -89,8 +88,6 @@ def __init__( super().__init__(**kwargs) self.camera_intrinsics = None - self.reid_threshold = reid_threshold - self.reid_fail_tolerance = reid_fail_tolerance self.tracker = None self.tracking_bbox = None # Stores (x, y, w, h) for tracker initialization @@ -276,7 +273,7 @@ def reid(self, frame, current_bbox) -> bool: # type: ignore[no-untyped-def] good_matches += 1 self.last_good_matches = good_matches_list # Store good matches for visualization - return good_matches >= self.reid_threshold + return good_matches >= self.config.reid_threshold def _start_tracking_thread(self) -> None: """Start the tracking thread.""" @@ -389,7 +386,7 @@ def _process_tracking(self) -> None: # Determine final success if tracker_succeeded: - if self.reid_fail_count >= self.reid_fail_tolerance: + if self.reid_fail_count >= self.config.reid_fail_tolerance: logger.warning( f"Re-ID failed consecutively {self.reid_fail_count} times. Target lost." ) @@ -589,11 +586,11 @@ def _draw_reid_matches(self, image: NDArray[np.uint8]) -> NDArray[np.uint8]: # f"REID: WARMING UP ({self.tracking_frame_count}/{self.reid_warmup_frames})" ) status_color = (255, 255, 0) # Yellow - elif len(self.last_good_matches) >= self.reid_threshold: + elif len(self.last_good_matches) >= self.config.reid_threshold: status_text = "REID: CONFIRMED" status_color = (0, 255, 0) # Green else: - status_text = f"REID: WEAK ({self.reid_fail_count}/{self.reid_fail_tolerance})" + status_text = f"REID: WEAK ({self.reid_fail_count}/{self.config.reid_fail_tolerance})" status_color = (0, 165, 255) # Orange cv2.putText( diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index 1264b0e92b..a53d331aef 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass import logging import threading import time +from typing import Any import cv2 @@ -35,15 +35,14 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.utils.logging_config import setup_logger logger = setup_logger(level=logging.INFO) -@dataclass class ObjectTracker2DConfig(ModuleConfig): frame_id: str = "camera_link" @@ -57,9 +56,8 @@ class ObjectTracker2D(Module[ObjectTracker2DConfig]): tracked_overlay: Out[Image] # Visualization output default_config = ObjectTracker2DConfig - config: ObjectTracker2DConfig - def __init__(self, **kwargs: object) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize 2D object tracking module using OpenCV's CSRT tracker.""" super().__init__(**kwargs) diff --git a/dimos/perception/object_tracker_3d.py b/dimos/perception/object_tracker_3d.py index da35577d0d..317a58dba0 100644 --- a/dimos/perception/object_tracker_3d.py +++ b/dimos/perception/object_tracker_3d.py @@ -24,12 +24,16 @@ from dimos.core.core import rpc from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.std_msgs.Header import Header +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray +from dimos.msgs.vision_msgs.Detection3DArray import Detection3DArray from dimos.perception.object_tracker_2d import ObjectTracker2D -from dimos.protocol.tf import TF +from dimos.protocol.tf.tf import TF from dimos.types.timestamped import align_timestamped from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( diff --git a/dimos/perception/perceive_loop_skill.py b/dimos/perception/perceive_loop_skill.py index 53362977f5..4532e61c2e 100644 --- a/dimos/perception/perceive_loop_skill.py +++ b/dimos/perception/perceive_loop_skill.py @@ -16,7 +16,7 @@ import json from threading import RLock -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from langchain_core.messages import HumanMessage @@ -26,16 +26,13 @@ from dimos.core.module import Module from dimos.core.stream import In from dimos.models.vl.create import create -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.sensor_msgs.Image import sharpness_window +from dimos.msgs.sensor_msgs.Image import Image, sharpness_window from dimos.utils.logging_config import setup_logger from dimos.utils.reactive import backpressure if TYPE_CHECKING: from reactivex.abc import DisposableBase - from dimos.core.global_config import GlobalConfig - from dimos.models.vl.base import VlModel logger = setup_logger() @@ -46,13 +43,9 @@ class PerceiveLoopSkill(Module): _agent_spec: AgentSpec _period: float = 0.5 # seconds - how often to run the perceive loop - def __init__( - self, - cfg: GlobalConfig, - ) -> None: - super().__init__() - self._global_config: GlobalConfig = cfg - self._vl_model: VlModel = create(cfg.detection_model) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._vl_model = create(self.config.g.detection_model) self._active_lookout: tuple[str, ...] = () self._lookout_subscription: DisposableBase | None = None self._model_started: bool = False diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index bf62d50bcf..fe6d7d50e0 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -19,7 +19,7 @@ from datetime import datetime import os import time -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import uuid import cv2 @@ -27,21 +27,21 @@ from reactivex import Observable, interval, operators as ops from reactivex.disposable import Disposable -from dimos import spec from dimos.agents_deprecated.memory.image_embedding import ImageEmbeddingProvider from dimos.agents_deprecated.memory.spatial_vector_db import SpatialVectorDB from dimos.agents_deprecated.memory.visual_memory import VisualMemory from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image +from dimos.spec.perception import Camera from dimos.types.robot_location import RobotLocation from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import Vector3 + from dimos.msgs.geometry_msgs.Vector3 import Vector3 _OUTPUT_DIR = DIMOS_PROJECT_ROOT / "assets" / "output" _MEMORY_DIR = _OUTPUT_DIR / "memory" @@ -53,7 +53,23 @@ logger = setup_logger() -class SpatialMemory(Module): +class SpatialConfig(ModuleConfig): + collection_name: str = "spatial_memory" + embedding_model: str = "clip" + embedding_dimensions: int = 512 + min_distance_threshold: float = 0.01 # Min distance in meters to store a new frame + min_time_threshold: float = 1.0 # Min time in seconds to record a new frame + db_path: str | None = str(_DB_PATH) # Path for ChromaDB persistence + visual_memory_path: str | None = str( + _VISUAL_MEMORY_PATH + ) # Path for saving/loading visual memory + new_memory: bool = True # Whether to create a new memory from scratch + output_dir: str | None = str(_SPATIAL_MEMORY_DIR) # Directory for storing visual memory data + chroma_client: Any = None # Optional ChromaDB client for persistence + visual_memory: VisualMemory | None = None # Optional VisualMemory instance for storing images + + +class SpatialMemory(Module[SpatialConfig]): """ A Dimos module for building and querying Robot spatial memory. @@ -63,29 +79,12 @@ class SpatialMemory(Module): robot locations that can be queried by name. """ + default_config = SpatialConfig + # LCM inputs color_image: In[Image] - def __init__( - self, - collection_name: str = "spatial_memory", - embedding_model: str = "clip", - embedding_dimensions: int = 512, - min_distance_threshold: float = 0.01, # Min distance in meters to store a new frame - min_time_threshold: float = 1.0, # Min time in seconds to record a new frame - db_path: str | None = str(_DB_PATH), # Path for ChromaDB persistence - visual_memory_path: str | None = str( - _VISUAL_MEMORY_PATH - ), # Path for saving/loading visual memory - new_memory: bool = True, # Whether to create a new memory from scratch - output_dir: str | None = str( - _SPATIAL_MEMORY_DIR - ), # Directory for storing visual memory data - chroma_client: Any = None, # Optional ChromaDB client for persistence - visual_memory: Optional[ - "VisualMemory" - ] = None, # Optional VisualMemory instance for storing images - ) -> None: + def __init__(self, **kwargs: Any) -> None: """ Initialize the spatial perception system. @@ -99,39 +98,36 @@ def __init__( visual_memory: Optional VisualMemory instance for storing images output_dir: Directory for storing visual memory data if visual_memory is not provided """ - self.collection_name = collection_name - self.embedding_model = embedding_model - self.embedding_dimensions = embedding_dimensions - self.min_distance_threshold = min_distance_threshold - self.min_time_threshold = min_time_threshold - - # Set up paths for persistence - # Call parent Module init - super().__init__() + super().__init__(**kwargs) - self.db_path = db_path - self.visual_memory_path = visual_memory_path + self.collection_name = self.config.collection_name + self.embedding_model = self.config.embedding_model + self.embedding_dimensions = self.config.embedding_dimensions + self.min_distance_threshold = self.config.min_distance_threshold + self.min_time_threshold = self.config.min_time_threshold + self.db_path = self.config.db_path + self.visual_memory_path = self.config.visual_memory_path # Setup ChromaDB client if not provided - self._chroma_client = chroma_client - if chroma_client is None and db_path is not None: + self._chroma_client = self.config.chroma_client + if self._chroma_client is None and self.db_path is not None: # Create db directory if needed - os.makedirs(db_path, exist_ok=True) + os.makedirs(self.db_path, exist_ok=True) # Clean up existing DB if creating new memory - if new_memory and os.path.exists(db_path): + if self.config.new_memory and os.path.exists(self.db_path): try: logger.info("Creating new ChromaDB database (new_memory=True)") # Try to delete any existing database files import shutil - for item in os.listdir(db_path): - item_path = os.path.join(db_path, item) + for item in os.listdir(self.db_path): + item_path = os.path.join(self.db_path, item) if os.path.isfile(item_path): os.unlink(item_path) elif os.path.isdir(item_path): shutil.rmtree(item_path) - logger.info(f"Removed existing ChromaDB files from {db_path}") + logger.info(f"Removed existing ChromaDB files from {self.db_path}") except Exception as e: logger.error(f"Error clearing ChromaDB directory: {e}") @@ -139,33 +135,33 @@ def __init__( from chromadb.config import Settings self._chroma_client = chromadb.PersistentClient( - path=db_path, settings=Settings(anonymized_telemetry=False) + path=self.db_path, settings=Settings(anonymized_telemetry=False) ) # Initialize or load visual memory - self._visual_memory = visual_memory - if visual_memory is None: - if new_memory or not os.path.exists(visual_memory_path or ""): + self._visual_memory = self.config.visual_memory + if self._visual_memory is None: + if self.config.new_memory or not os.path.exists(self.visual_memory_path or ""): logger.info("Creating new visual memory") - self._visual_memory = VisualMemory(output_dir=output_dir) + self._visual_memory = VisualMemory(output_dir=self.config.output_dir) else: try: - logger.info(f"Loading existing visual memory from {visual_memory_path}...") + logger.info(f"Loading existing visual memory from {self.visual_memory_path}...") self._visual_memory = VisualMemory.load( - visual_memory_path, # type: ignore[arg-type] - output_dir=output_dir, + self.visual_memory_path, # type: ignore[arg-type] + output_dir=self.config.output_dir, ) logger.info(f"Loaded {self._visual_memory.count()} images from previous runs") except Exception as e: logger.error(f"Error loading visual memory: {e}") - self._visual_memory = VisualMemory(output_dir=output_dir) + self._visual_memory = VisualMemory(output_dir=self.config.output_dir) self.embedding_provider: ImageEmbeddingProvider = ImageEmbeddingProvider( - model_name=embedding_model, dimensions=embedding_dimensions + model_name=self.embedding_model, dimensions=self.embedding_dimensions ) self.vector_db: SpatialVectorDB = SpatialVectorDB( - collection_name=collection_name, + collection_name=self.collection_name, chroma_client=self._chroma_client, visual_memory=self._visual_memory, embedding_provider=self.embedding_provider, @@ -184,7 +180,7 @@ def __init__( self._latest_video_frame: np.ndarray | None = None # type: ignore[type-arg] self._process_interval = 1 - logger.info(f"SpatialMemory initialized with model {embedding_model}") + logger.info(f"SpatialMemory initialized with model {self.embedding_model}") @rpc def start(self) -> None: @@ -581,7 +577,7 @@ def query_tagged_location(self, query: str) -> RobotLocation | None: def deploy( # type: ignore[no-untyped-def] dimos: ModuleCoordinator, - camera: spec.Camera, + camera: Camera, ): spatial_memory = dimos.deploy(SpatialMemory, db_path="/tmp/spatial_memory_db") # type: ignore[attr-defined] spatial_memory.color_image.connect(camera.color_image) diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py index 433896aefe..322513d459 100644 --- a/dimos/perception/test_spatial_memory.py +++ b/dimos/perception/test_spatial_memory.py @@ -22,7 +22,7 @@ from reactivex import operators as ops from reactivex.scheduler import ThreadPoolScheduler -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose from dimos.perception.spatial_perception import SpatialMemory from dimos.stream.video_provider import VideoProvider diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py index ac9b132a69..d8567036bf 100644 --- a/dimos/perception/test_spatial_memory_module.py +++ b/dimos/perception/test_spatial_memory_module.py @@ -20,36 +20,37 @@ from reactivex import operators as ops from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import Out from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import Transform -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.spatial_perception import SpatialMemory from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay logger = setup_logger() -class VideoReplayModule(Module): +class VideoReplayConfig(ModuleConfig): + video_path: str + + +class VideoReplayModule(Module[VideoReplayConfig]): """Module that replays video data from TimedSensorReplay.""" + default_config = VideoReplayConfig video_out: Out[Image] - - def __init__(self, video_path: str) -> None: - super().__init__() - self.video_path = video_path - self._subscription = None + _subscription = None @rpc def start(self) -> None: """Start replaying video data.""" # Use TimedSensorReplay to replay video frames - video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy) + video_replay = TimedSensorReplay(self.config.video_path, autocast=Image.from_numpy) # Subscribe to the replay stream and publish to LCM self._subscription = ( diff --git a/dimos/protocol/__init__.py b/dimos/protocol/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/protocol/encode/__init__.py b/dimos/protocol/encode/encoder.py similarity index 82% rename from dimos/protocol/encode/__init__.py rename to dimos/protocol/encode/encoder.py index 87386a09e5..b6e00e4b1c 100644 --- a/dimos/protocol/encode/__init__.py +++ b/dimos/protocol/encode/encoder.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import ABC, abstractmethod import json from typing import Generic, Protocol, TypeVar diff --git a/dimos/protocol/pubsub/__init__.py b/dimos/protocol/pubsub/__init__.py deleted file mode 100644 index 94a58b60de..0000000000 --- a/dimos/protocol/pubsub/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import dimos.protocol.pubsub.impl.lcmpubsub as lcm -from dimos.protocol.pubsub.impl.memory import Memory -from dimos.protocol.pubsub.spec import PubSub - -__all__ = [ - "Memory", - "PubSub", - "lcm", -] diff --git a/dimos/protocol/pubsub/bridge.py b/dimos/protocol/pubsub/bridge.py index f312caed7b..72cbe155d9 100644 --- a/dimos/protocol/pubsub/bridge.py +++ b/dimos/protocol/pubsub/bridge.py @@ -16,10 +16,9 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING, Generic, Protocol, TypeVar -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service if TYPE_CHECKING: from collections.abc import Callable @@ -66,8 +65,7 @@ def pass_msg(msg: MsgFrom, topic: TopicFrom) -> None: return pubsub1.subscribe_all(pass_msg) -@dataclass -class BridgeConfig(Generic[TopicFrom, TopicTo, MsgFrom, MsgTo]): +class BridgeConfig(BaseConfig, Generic[TopicFrom, TopicTo, MsgFrom, MsgTo]): """Configuration for a one-way bridge.""" source: AllPubSub[TopicFrom, MsgFrom] diff --git a/dimos/protocol/pubsub/encoders.py b/dimos/protocol/pubsub/encoders.py index 6b2056fa8b..69aa328765 100644 --- a/dimos/protocol/pubsub/encoders.py +++ b/dimos/protocol/pubsub/encoders.py @@ -20,8 +20,8 @@ import pickle from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast -from dimos.msgs import DimosMsg -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.protocol import DimosMsg +from dimos.msgs.sensor_msgs.Image import Image if TYPE_CHECKING: from collections.abc import Callable diff --git a/dimos/protocol/pubsub/impl/__init__.py b/dimos/protocol/pubsub/impl/__init__.py deleted file mode 100644 index 63a5bfa6d6..0000000000 --- a/dimos/protocol/pubsub/impl/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from dimos.protocol.pubsub.impl.lcmpubsub import ( - LCM as LCM, - LCMPubSubBase as LCMPubSubBase, - PickleLCM as PickleLCM, -) -from dimos.protocol.pubsub.impl.memory import Memory as Memory diff --git a/dimos/protocol/pubsub/impl/lcmpubsub.py b/dimos/protocol/pubsub/impl/lcmpubsub.py index bf6bbd0dec..50c7c49f2f 100644 --- a/dimos/protocol/pubsub/impl/lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/lcmpubsub.py @@ -14,10 +14,13 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass import re -from typing import TYPE_CHECKING, Any +import threading +from typing import Any +from dimos.msgs.protocol import DimosMsg from dimos.protocol.pubsub.encoders import ( JpegEncoderMixin, LCMEncoderMixin, @@ -25,15 +28,9 @@ ) from dimos.protocol.pubsub.patterns import Glob from dimos.protocol.pubsub.spec import AllPubSub -from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf +from dimos.protocol.service.lcmservice import LCMService, autoconf from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from collections.abc import Callable - import threading - - from dimos.msgs import DimosMsg - logger = setup_logger() @@ -66,7 +63,7 @@ def from_channel_str(channel: str, default_lcm_type: type[DimosMsg] | None = Non Channel format: /topic#module.ClassName Falls back to default_lcm_type if type cannot be parsed. """ - from dimos.msgs import resolve_msg_type + from dimos.msgs.helpers import resolve_msg_type if "#" not in channel: return Topic(topic=channel, lcm_type=default_lcm_type) @@ -83,7 +80,6 @@ class LCMPubSubBase(LCMService, AllPubSub[Topic, Any]): RegexSubscribable directly without needing discovery-based fallback. """ - default_config = LCMConfig _stop_event: threading.Event _thread: threading.Thread | None diff --git a/dimos/protocol/pubsub/impl/memory.py b/dimos/protocol/pubsub/impl/memory.py index 3425a5ee3d..25e10efe32 100644 --- a/dimos/protocol/pubsub/impl/memory.py +++ b/dimos/protocol/pubsub/impl/memory.py @@ -16,7 +16,7 @@ from collections.abc import Callable from typing import Any -from dimos.protocol import encode +from dimos.protocol.encode import encoder as encode from dimos.protocol.pubsub.encoders import PubSubEncoderMixin from dimos.protocol.pubsub.spec import PubSub diff --git a/dimos/protocol/pubsub/impl/redispubsub.py b/dimos/protocol/pubsub/impl/redispubsub.py index 6cc089e953..b299d6b883 100644 --- a/dimos/protocol/pubsub/impl/redispubsub.py +++ b/dimos/protocol/pubsub/impl/redispubsub.py @@ -14,25 +14,24 @@ from collections import defaultdict from collections.abc import Callable -from dataclasses import dataclass, field import json import threading import time from types import TracebackType from typing import Any +from pydantic import Field import redis # type: ignore[import-not-found] from dimos.protocol.pubsub.spec import PubSub -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service -@dataclass -class RedisConfig: +class RedisConfig(BaseConfig): host: str = "localhost" port: int = 6379 db: int = 0 - kwargs: dict[str, Any] = field(default_factory=dict) + kwargs: dict[str, Any] = Field(default_factory=dict) class Redis(PubSub[str, Any], Service[RedisConfig]): diff --git a/dimos/protocol/pubsub/impl/rospubsub.py b/dimos/protocol/pubsub/impl/rospubsub.py index 1a3c989a4d..1e18b3759a 100644 --- a/dimos/protocol/pubsub/impl/rospubsub.py +++ b/dimos/protocol/pubsub/impl/rospubsub.py @@ -37,7 +37,7 @@ import uuid -from dimos.msgs import DimosMsg +from dimos.msgs.protocol import DimosMsg from dimos.protocol.pubsub.impl.rospubsub_conversion import ( derive_ros_type, dimos_to_ros, diff --git a/dimos/protocol/pubsub/impl/rospubsub_conversion.py b/dimos/protocol/pubsub/impl/rospubsub_conversion.py index 275033a5ac..150c3eeb8f 100644 --- a/dimos/protocol/pubsub/impl/rospubsub_conversion.py +++ b/dimos/protocol/pubsub/impl/rospubsub_conversion.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: - from dimos.msgs import DimosMsg + from dimos.msgs.protocol import DimosMsg from dimos.protocol.pubsub.impl.rospubsub import ROSMessage diff --git a/dimos/protocol/pubsub/impl/shmpubsub.py b/dimos/protocol/pubsub/impl/shmpubsub.py index db0a91e579..883afcdcc0 100644 --- a/dimos/protocol/pubsub/impl/shmpubsub.py +++ b/dimos/protocol/pubsub/impl/shmpubsub.py @@ -13,9 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# --------------------------------------------------------------------------- -# SharedMemory Pub/Sub over unified IPC channels (CPU/CUDA) -# --------------------------------------------------------------------------- from __future__ import annotations @@ -101,7 +98,7 @@ def __init__(self, channel, capacity: int, cp_mod) -> None: # type: ignore[no-u # Lock for thread-safe publish buffer access self.publish_lock = threading.Lock() - # ----- init / lifecycle ------------------------------------------------- + # init / lifecycle def __init__( self, @@ -146,7 +143,7 @@ def stop(self) -> None: self._topics.clear() logger.debug("SharedMemory PubSub stopped.") - # ----- PubSub API (bytes on the wire) ---------------------------------- + # PubSub API (bytes on the wire) def publish(self, topic: str, message: bytes) -> None: if not isinstance(message, bytes | bytearray | memoryview): @@ -212,7 +209,7 @@ def _unsub() -> None: return _unsub - # ----- Capacity mgmt ---------------------------------------------------- + # Capacity mgmt def reconfigure(self, topic: str, *, capacity: int) -> dict: # type: ignore[type-arg] """Change payload capacity (bytes) for a topic; returns new descriptor.""" @@ -229,7 +226,7 @@ def reconfigure(self, topic: str, *, capacity: int) -> dict: # type: ignore[typ st.publish_buffer = np.zeros(new_shape, dtype=np.uint8) return desc # type: ignore[no-any-return] - # ----- Internals -------------------------------------------------------- + # Internals def _ensure_topic(self, topic: str) -> _TopicState: with self._lock: diff --git a/dimos/protocol/pubsub/impl/test_lcmpubsub.py b/dimos/protocol/pubsub/impl/test_lcmpubsub.py index ea80b4c445..ba29c70958 100644 --- a/dimos/protocol/pubsub/impl/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/impl/test_lcmpubsub.py @@ -18,7 +18,9 @@ import pytest -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.protocol.pubsub.impl.lcmpubsub import ( LCM, LCMPubSubBase, diff --git a/dimos/protocol/pubsub/impl/test_rospubsub.py b/dimos/protocol/pubsub/impl/test_rospubsub.py index 5f574065ba..ef9df74227 100644 --- a/dimos/protocol/pubsub/impl/test_rospubsub.py +++ b/dimos/protocol/pubsub/impl/test_rospubsub.py @@ -28,7 +28,7 @@ # Add msg_name to LCM PointStamped for testing nested message conversion PointStamped.msg_name = "geometry_msgs.PointStamped" from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay def ros_node(): diff --git a/dimos/protocol/pubsub/shm/ipc_factory.py b/dimos/protocol/pubsub/shm/ipc_factory.py index fbf98d379e..29ed682f8d 100644 --- a/dimos/protocol/pubsub/shm/ipc_factory.py +++ b/dimos/protocol/pubsub/shm/ipc_factory.py @@ -54,11 +54,6 @@ def _open_shm_with_retry(name: str) -> SharedMemory: raise FileNotFoundError(f"SHM not found after {tries} retries: {name}") from last -# --------------------------- -# 1) Abstract interface -# --------------------------- - - class FrameChannel(ABC): """Single-slot 'freshest frame' IPC channel with a tiny control block. - Double-buffered to avoid torn reads. @@ -125,11 +120,6 @@ def _safe_unlink(name: str) -> None: pass -# --------------------------- -# 2) CPU shared-memory backend -# --------------------------- - - class CpuShmChannel(FrameChannel): def __init__( # type: ignore[no-untyped-def] self, @@ -300,11 +290,6 @@ def close(self) -> None: pass -# --------------------------- -# 3) Factories -# --------------------------- - - class CPU_IPC_Factory: """Creates/attaches CPU shared-memory channels.""" @@ -318,11 +303,6 @@ def attach(desc: dict) -> CpuShmChannel: # type: ignore[type-arg] return CpuShmChannel.attach(desc) # type: ignore[arg-type, no-any-return] -# --------------------------- -# 4) Runtime selector -# --------------------------- - - def make_frame_channel( # type: ignore[no-untyped-def] shape, dtype=np.uint8, prefer: str = "auto", device: int = 0 ) -> FrameChannel: diff --git a/dimos/protocol/pubsub/test_pattern_sub.py b/dimos/protocol/pubsub/test_pattern_sub.py index cdbce5d5a6..ac94ba1b3b 100644 --- a/dimos/protocol/pubsub/test_pattern_sub.py +++ b/dimos/protocol/pubsub/test_pattern_sub.py @@ -24,7 +24,9 @@ import pytest -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.protocol.pubsub.impl.lcmpubsub import LCM, LCMPubSubBase, Topic from dimos.protocol.pubsub.patterns import Glob from dimos.protocol.pubsub.spec import AllPubSub, PubSub diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index a240319fdf..e36741bbfd 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -23,7 +23,7 @@ import pytest -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.protocol.pubsub.impl.memory import Memory diff --git a/dimos/protocol/rpc/__init__.py b/dimos/protocol/rpc/__init__.py deleted file mode 100644 index 1eb892d956..0000000000 --- a/dimos/protocol/rpc/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.protocol.rpc.pubsubrpc import LCMRPC, ShmRPC -from dimos.protocol.rpc.spec import RPCClient, RPCServer, RPCSpec - -__all__ = ["LCMRPC", "RPCClient", "RPCServer", "RPCSpec", "ShmRPC"] diff --git a/dimos/protocol/rpc/test_lcmrpc.py b/dimos/protocol/rpc/test_lcmrpc.py index f31d20cf19..5baa5ac40c 100644 --- a/dimos/protocol/rpc/test_lcmrpc.py +++ b/dimos/protocol/rpc/test_lcmrpc.py @@ -17,7 +17,7 @@ import pytest from dimos.constants import LCM_MAX_CHANNEL_NAME_LENGTH -from dimos.protocol.rpc import LCMRPC +from dimos.protocol.rpc.pubsubrpc import LCMRPC @pytest.fixture diff --git a/dimos/protocol/service/__init__.py b/dimos/protocol/service/__init__.py deleted file mode 100644 index fb9df08ca9..0000000000 --- a/dimos/protocol/service/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from dimos.protocol.service.lcmservice import LCMService -from dimos.protocol.service.spec import Configurable as Configurable, Service as Service - -__all__ = [ - "Configurable", - "LCMService", - "Service", -] diff --git a/dimos/protocol/service/ddsservice.py b/dimos/protocol/service/ddsservice.py index 6ed04c07ad..b5562defff 100644 --- a/dimos/protocol/service/ddsservice.py +++ b/dimos/protocol/service/ddsservice.py @@ -14,9 +14,8 @@ from __future__ import annotations -from dataclasses import dataclass import threading -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING try: from cyclonedds.domain import DomainParticipant @@ -26,7 +25,7 @@ DDS_AVAILABLE = False DomainParticipant = None # type: ignore[assignment, misc] -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import BaseConfig, Service from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -38,8 +37,7 @@ _participants_lock = threading.Lock() -@dataclass -class DDSConfig: +class DDSConfig(BaseConfig): """Configuration for DDS service.""" domain_id: int = 0 @@ -49,9 +47,6 @@ class DDSConfig: class DDSService(Service[DDSConfig]): default_config = DDSConfig - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - def start(self) -> None: """Start the DDS service.""" domain_id = self.config.domain_id diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index 5cd4563fd1..0211b34129 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -15,18 +15,25 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass import os import platform +import sys import threading import traceback +from typing import Any -import lcm +import lcm as lcm_mod -from dimos.protocol.service.spec import Service -from dimos.protocol.service.system_configurator import configure_system, lcm_configurators +from dimos.protocol.service.spec import BaseConfig, Service +from dimos.protocol.service.system_configurator.base import configure_system +from dimos.protocol.service.system_configurator.lcm_config import lcm_configurators from dimos.utils.logging_config import setup_logger +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + logger = setup_logger() _DEFAULT_LCM_HOST = "239.255.76.67" @@ -45,40 +52,37 @@ def autoconf(check_only: bool = False) -> None: configure_system(checks, check_only=check_only) -@dataclass -class LCMConfig: +class LCMConfig(BaseConfig): ttl: int = 0 - url: str | None = None - lcm: lcm.LCM | None = None - - def __post_init__(self) -> None: - if self.url is None: - self.url = _DEFAULT_LCM_URL + url: str = _DEFAULT_LCM_URL + lcm: lcm_mod.LCM | None = None +_Config = TypeVar("_Config", bound=LCMConfig, default=LCMConfig) _LCM_LOOP_TIMEOUT = 50 # this class just sets up cpp LCM instance # and runs its handle loop in a thread # higher order stuff is done by pubsub/impl/lcmpubsub.py -class LCMService(Service[LCMConfig]): - default_config = LCMConfig - l: lcm.LCM | None +class LCMService(Service[_Config]): + default_config = LCMConfig # type: ignore[assignment] + + l: lcm_mod.LCM | None _stop_event: threading.Event _l_lock: threading.Lock _thread: threading.Thread | None _call_thread_pool: ThreadPoolExecutor | None = None _call_thread_pool_lock: threading.RLock = threading.RLock() - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) # we support passing an existing LCM instance if self.config.lcm: self.l = self.config.lcm else: - self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + self.l = lcm_mod.LCM(self.config.url) if self.config.url else lcm_mod.LCM() self._l_lock = threading.Lock() self._stop_event = threading.Event() @@ -113,7 +117,7 @@ def start(self) -> None: if self.config.lcm: self.l = self.config.lcm else: - self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + self.l = lcm_mod.LCM(self.config.url) if self.config.url else lcm_mod.LCM() self._stop_event.clear() self._thread = threading.Thread(target=self._lcm_loop) diff --git a/dimos/protocol/service/spec.py b/dimos/protocol/service/spec.py index c4e6758614..c9796cf2b5 100644 --- a/dimos/protocol/service/spec.py +++ b/dimos/protocol/service/spec.py @@ -13,17 +13,24 @@ # limitations under the License. from abc import ABC -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar + +from pydantic import BaseModel + + +class BaseConfig(BaseModel): + model_config = {"arbitrary_types_allowed": True, "extra": "forbid"} + # Generic type for service configuration -ConfigT = TypeVar("ConfigT") +ConfigT = TypeVar("ConfigT", bound=BaseConfig) class Configurable(Generic[ConfigT]): default_config: type[ConfigT] - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] - self.config: ConfigT = self.default_config(**kwargs) + def __init__(self, **kwargs: Any) -> None: + self.config = self.default_config(**kwargs) class Service(Configurable[ConfigT], ABC): diff --git a/dimos/protocol/service/system_configurator/base.py b/dimos/protocol/service/system_configurator/base.py index c221af890f..e5f65bdc18 100644 --- a/dimos/protocol/service/system_configurator/base.py +++ b/dimos/protocol/service/system_configurator/base.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) -# ----------------------------- sudo helpers ----------------------------- +# sudo helpers @cache @@ -66,7 +66,7 @@ def _write_sysctl_int(name: str, value: int) -> None: sudo_run("sysctl", "-w", f"{name}={value}", check=True, text=True, capture_output=False) -# -------------------------- base class for system config checks/requirements -------------------------- +# base class for system config checks/requirements class SystemConfigurator(ABC): @@ -91,7 +91,7 @@ def fix(self) -> None: raise NotImplementedError -# ----------------------------- generic enforcement of system configs ----------------------------- +# generic enforcement of system configs def configure_system(checks: list[SystemConfigurator], check_only: bool = False) -> None: diff --git a/dimos/protocol/service/system_configurator/lcm.py b/dimos/protocol/service/system_configurator/lcm.py index 6599f97407..9e1b3e5c61 100644 --- a/dimos/protocol/service/system_configurator/lcm.py +++ b/dimos/protocol/service/system_configurator/lcm.py @@ -25,7 +25,7 @@ sudo_run, ) -# ------------------------------ specific checks: multicast ------------------------------ +# specific checks: multicast class MulticastConfiguratorLinux(SystemConfigurator): @@ -182,7 +182,7 @@ def fix(self) -> None: sudo_run(*self.add_route_cmd, check=True, text=True, capture_output=True) -# ------------------------------ specific checks: buffers ------------------------------ +# specific checks: buffers IDEAL_RMEM_SIZE = 67_108_864 # 64MB @@ -254,7 +254,7 @@ def fix(self) -> None: _write_sysctl_int(key, target) -# ------------------------------ specific checks: ulimit ------------------------------ +# specific checks: ulimit class MaxFileConfiguratorMacOS(SystemConfigurator): diff --git a/dimos/protocol/service/system_configurator/__init__.py b/dimos/protocol/service/system_configurator/lcm_config.py similarity index 54% rename from dimos/protocol/service/system_configurator/__init__.py rename to dimos/protocol/service/system_configurator/lcm_config.py index 31b5af4d8c..72f1e5d774 100644 --- a/dimos/protocol/service/system_configurator/__init__.py +++ b/dimos/protocol/service/system_configurator/lcm_config.py @@ -12,18 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""System configurator package — re-exports for backward compatibility.""" +"""Platform-appropriate LCM system configurators.""" import platform -from dimos.protocol.service.system_configurator.base import ( - SystemConfigurator, - configure_system, - sudo_run, -) -from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator +from dimos.protocol.service.system_configurator.base import SystemConfigurator from dimos.protocol.service.system_configurator.lcm import ( - IDEAL_RMEM_SIZE, BufferConfiguratorLinux, BufferConfiguratorMacOS, MaxFileConfiguratorMacOS, @@ -33,17 +27,6 @@ from dimos.protocol.service.system_configurator.libpython import LibPythonConfiguratorMacOS -# TODO: This is a configurator API issue and inserted here temporarily -# -# We need to use different configurators based on the underlying OS -# -# We should have separation of concerns, nothing but configurators themselves care about the OS in this context -# -# So configurators with multi-os behavior should be responsible for the right per-OS behaviour, and -# not external systems -# -# We might want to have some sort of recursive configurators -# def lcm_configurators() -> list[SystemConfigurator]: """Return the platform-appropriate LCM system configurators.""" system = platform.system() @@ -56,23 +39,7 @@ def lcm_configurators() -> list[SystemConfigurator]: return [ MulticastConfiguratorMacOS(loopback_interface="lo0"), BufferConfiguratorMacOS(), - MaxFileConfiguratorMacOS(), # TODO: this is not LCM related and shouldn't be here at all + MaxFileConfiguratorMacOS(), LibPythonConfiguratorMacOS(), ] return [] - - -__all__ = [ - "IDEAL_RMEM_SIZE", - "BufferConfiguratorLinux", - "BufferConfiguratorMacOS", - "ClockSyncConfigurator", - "LibPythonConfiguratorMacOS", - "MaxFileConfiguratorMacOS", - "MulticastConfiguratorLinux", - "MulticastConfiguratorMacOS", - "SystemConfigurator", - "configure_system", - "lcm_configurators", - "sudo_run", -] diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py index 857bc305a2..cbab6ff3ab 100644 --- a/dimos/protocol/service/test_lcmservice.py +++ b/dimos/protocol/service/test_lcmservice.py @@ -14,7 +14,9 @@ import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, create_autospec, patch + +from lcm import LCM from dimos.protocol.pubsub.impl.lcmpubsub import Topic from dimos.protocol.service.lcmservice import ( @@ -23,22 +25,23 @@ LCMService, autoconf, ) -from dimos.protocol.service.system_configurator import ( +from dimos.protocol.service.system_configurator.lcm import ( BufferConfiguratorLinux, BufferConfiguratorMacOS, - LibPythonConfiguratorMacOS, MaxFileConfiguratorMacOS, MulticastConfiguratorLinux, MulticastConfiguratorMacOS, ) +from dimos.protocol.service.system_configurator.libpython import LibPythonConfiguratorMacOS -# ----------------------------- autoconf tests ----------------------------- +# autoconf tests class TestConfigureSystemForLcm: def test_creates_linux_checks_on_linux(self) -> None: with patch( - "dimos.protocol.service.system_configurator.platform.system", return_value="Linux" + "dimos.protocol.service.system_configurator.lcm_config.platform.system", + return_value="Linux", ): with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure: autoconf() @@ -51,7 +54,8 @@ def test_creates_linux_checks_on_linux(self) -> None: def test_creates_macos_checks_on_darwin(self) -> None: with patch( - "dimos.protocol.service.system_configurator.platform.system", return_value="Darwin" + "dimos.protocol.service.system_configurator.lcm_config.platform.system", + return_value="Darwin", ): with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure: autoconf() @@ -66,7 +70,8 @@ def test_creates_macos_checks_on_darwin(self) -> None: def test_passes_check_only_flag(self) -> None: with patch( - "dimos.protocol.service.system_configurator.platform.system", return_value="Linux" + "dimos.protocol.service.system_configurator.lcm_config.platform.system", + return_value="Linux", ): with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure: autoconf(check_only=True) @@ -75,7 +80,8 @@ def test_passes_check_only_flag(self) -> None: def test_logs_error_on_unsupported_system(self) -> None: with patch( - "dimos.protocol.service.system_configurator.platform.system", return_value="Windows" + "dimos.protocol.service.system_configurator.lcm_config.platform.system", + return_value="Windows", ): with patch("dimos.protocol.service.lcmservice.configure_system") as mock_configure: with patch("dimos.protocol.service.lcmservice.logger") as mock_logger: @@ -85,7 +91,7 @@ def test_logs_error_on_unsupported_system(self) -> None: assert "Windows" in mock_logger.error.call_args[0][0] -# ----------------------------- LCMConfig tests ----------------------------- +# LCMConfig tests class TestLCMConfig: @@ -100,12 +106,8 @@ def test_custom_url(self) -> None: config = LCMConfig(url=custom_url) assert config.url == custom_url - def test_post_init_sets_default_url_when_none(self) -> None: - config = LCMConfig(url=None) - assert config.url == _DEFAULT_LCM_URL - -# ----------------------------- Topic tests ----------------------------- +# Topic tests class TestTopic: @@ -120,13 +122,13 @@ def test_str_with_lcm_type(self) -> None: assert str(topic) == "my_topic#TestMessage" -# ----------------------------- LCMService tests ----------------------------- +# LCMService tests class TestLCMService: def test_init_with_default_config(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -136,8 +138,8 @@ def test_init_with_default_config(self) -> None: def test_init_with_custom_url(self) -> None: custom_url = "udpm://192.168.1.1:7777?ttl=1" - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance # Pass url as kwarg, not config= @@ -145,17 +147,17 @@ def test_init_with_custom_url(self) -> None: mock_lcm_class.assert_called_once_with(custom_url) def test_init_with_existing_lcm_instance(self) -> None: - mock_lcm_instance = MagicMock() + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: # Pass lcm as kwarg service = LCMService(lcm=mock_lcm_instance) mock_lcm_class.assert_not_called() assert service.l == mock_lcm_instance def test_start_and_stop(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -172,8 +174,8 @@ def test_start_and_stop(self) -> None: assert not service._thread.is_alive() def test_getstate_excludes_unpicklable_attrs(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -187,8 +189,8 @@ def test_getstate_excludes_unpicklable_attrs(self) -> None: assert "_call_thread_pool_lock" not in state def test_setstate_reinitializes_runtime_attrs(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -207,8 +209,8 @@ def test_setstate_reinitializes_runtime_attrs(self) -> None: assert hasattr(new_service._l_lock, "release") def test_start_reinitializes_lcm_after_unpickling(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -227,8 +229,8 @@ def test_start_reinitializes_lcm_after_unpickling(self) -> None: new_service.stop() def test_stop_cleans_up_lcm_instance(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -239,7 +241,7 @@ def test_stop_cleans_up_lcm_instance(self) -> None: assert service.l is None def test_stop_preserves_external_lcm_instance(self) -> None: - mock_lcm_instance = MagicMock() + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) # Pass lcm as kwarg service = LCMService(lcm=mock_lcm_instance) @@ -250,8 +252,8 @@ def test_stop_preserves_external_lcm_instance(self) -> None: assert service.l == mock_lcm_instance def test_get_call_thread_pool_creates_pool(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() @@ -269,8 +271,8 @@ def test_get_call_thread_pool_creates_pool(self) -> None: pool.shutdown(wait=False) def test_stop_shuts_down_thread_pool(self) -> None: - with patch("dimos.protocol.service.lcmservice.lcm.LCM") as mock_lcm_class: - mock_lcm_instance = MagicMock() + with patch("dimos.protocol.service.lcmservice.lcm_mod.LCM") as mock_lcm_class: + mock_lcm_instance = create_autospec(LCM, spec_set=True, instance=True) mock_lcm_class.return_value = mock_lcm_instance service = LCMService() diff --git a/dimos/protocol/service/test_system_configurator.py b/dimos/protocol/service/test_system_configurator.py index 62de2a61ea..715d9eede7 100644 --- a/dimos/protocol/service/test_system_configurator.py +++ b/dimos/protocol/service/test_system_configurator.py @@ -19,25 +19,25 @@ import pytest -from dimos.protocol.service.system_configurator import ( +from dimos.protocol.service.system_configurator.base import ( + SystemConfigurator, + _is_root_user, + _read_sysctl_int, + _write_sysctl_int, + configure_system, + sudo_run, +) +from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator +from dimos.protocol.service.system_configurator.lcm import ( IDEAL_RMEM_SIZE, BufferConfiguratorLinux, BufferConfiguratorMacOS, - ClockSyncConfigurator, MaxFileConfiguratorMacOS, MulticastConfiguratorLinux, MulticastConfiguratorMacOS, - SystemConfigurator, - configure_system, - sudo_run, -) -from dimos.protocol.service.system_configurator.base import ( - _is_root_user, - _read_sysctl_int, - _write_sysctl_int, ) -# ----------------------------- Helper function tests ----------------------------- +# Helper function tests class TestIsRootUser: @@ -122,7 +122,7 @@ def test_calls_sudo_run_with_correct_args(self) -> None: ) -# ----------------------------- configure_system tests ----------------------------- +# configure_system tests class MockConfigurator(SystemConfigurator): @@ -186,7 +186,7 @@ def test_exits_on_no_with_critical_check(self, mocker) -> None: assert exc_info.value.code == 1 -# ----------------------------- MulticastConfiguratorLinux tests ----------------------------- +# MulticastConfiguratorLinux tests class TestMulticastConfiguratorLinux: @@ -259,7 +259,7 @@ def test_fix_runs_needed_commands(self) -> None: assert mock_run.call_count == 2 -# ----------------------------- MulticastConfiguratorMacOS tests ----------------------------- +# MulticastConfiguratorMacOS tests class TestMulticastConfiguratorMacOS: @@ -311,7 +311,7 @@ def test_fix_runs_route_command(self) -> None: assert "224.0.0.0/4" in add_args -# ----------------------------- BufferConfiguratorLinux tests ----------------------------- +# BufferConfiguratorLinux tests class TestBufferConfiguratorLinux: @@ -354,7 +354,7 @@ def test_fix_writes_needed_values(self) -> None: mock_write.assert_called_once_with("net.core.rmem_max", IDEAL_RMEM_SIZE) -# ----------------------------- BufferConfiguratorMacOS tests ----------------------------- +# BufferConfiguratorMacOS tests class TestBufferConfiguratorMacOS: @@ -398,7 +398,7 @@ def test_fix_writes_needed_values(self) -> None: ) -# ----------------------------- MaxFileConfiguratorMacOS tests ----------------------------- +# MaxFileConfiguratorMacOS tests class TestMaxFileConfiguratorMacOS: @@ -489,7 +489,7 @@ def test_fix_raises_on_setrlimit_error(self) -> None: configurator.fix() -# ----------------------------- ClockSyncConfigurator tests ----------------------------- +# ClockSyncConfigurator tests class TestClockSyncConfigurator: diff --git a/dimos/protocol/tf/__init__.py b/dimos/protocol/tf/__init__.py deleted file mode 100644 index cb00dbde3c..0000000000 --- a/dimos/protocol/tf/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.protocol.tf.tf import LCMTF, TF, MultiTBuffer, PubSubTF, TBuffer, TFConfig, TFSpec - -__all__ = ["LCMTF", "TF", "MultiTBuffer", "PubSubTF", "TBuffer", "TFConfig", "TFSpec"] diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index c1f0b13fa2..b0843bfccd 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -19,8 +19,11 @@ import pytest -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.protocol.tf import TF, MultiTBuffer, TBuffer +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.protocol.tf.tf import TF, MultiTBuffer, TBuffer # from https://foxglove.dev/blog/understanding-ros-transforms diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 825e89fc8c..97b2132bbb 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -16,32 +16,32 @@ from abc import abstractmethod from collections import deque -from dataclasses import dataclass, field +from dataclasses import field from functools import reduce from typing import TypeVar from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.msgs.geometry_msgs import PoseStamped, Transform -from dimos.msgs.tf2_msgs import TFMessage +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.protocol.pubsub.impl.lcmpubsub import LCM, Topic from dimos.protocol.pubsub.spec import PubSub -from dimos.protocol.service.lcmservice import Service # type: ignore[attr-defined] +from dimos.protocol.service.spec import BaseConfig, Service CONFIG = TypeVar("CONFIG") # generic configuration for transform service -@dataclass -class TFConfig: +class TFConfig(BaseConfig): buffer_size: float = 10.0 # seconds rate_limit: float = 10.0 # Hz -# generic specification for transform service -class TFSpec(Service[TFConfig]): - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] - super().__init__(**kwargs) +_TFConfig = TypeVar("_TFConfig", bound=TFConfig) + +# generic specification for transform service +class TFSpec(Service[_TFConfig]): @abstractmethod def publish(self, *args: Transform) -> None: ... @@ -244,15 +244,17 @@ def __str__(self) -> str: return "\n".join(lines) -@dataclass class PubSubTFConfig(TFConfig): topic: Topic | None = None # Required field but needs default for dataclass inheritance pubsub: type[PubSub] | PubSub | None = None # type: ignore[type-arg] autostart: bool = True -class PubSubTF(MultiTBuffer, TFSpec): - default_config: type[PubSubTFConfig] = PubSubTFConfig +_PubSubConfig = TypeVar("_PubSubConfig", bound=PubSubTFConfig) + + +class PubSubTF(MultiTBuffer, TFSpec[_PubSubConfig]): + default_config: type[_PubSubConfig] = PubSubTFConfig # type: ignore[assignment] def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] TFSpec.__init__(self, **kwargs) @@ -330,15 +332,14 @@ def receive_msg(self, msg: TFMessage, topic: Topic) -> None: self.receive_tfmessage(msg) -@dataclass class LCMPubsubConfig(PubSubTFConfig): topic: Topic = field(default_factory=lambda: Topic("/tf", TFMessage)) pubsub: type[PubSub] | PubSub | None = LCM # type: ignore[type-arg] autostart: bool = True -class LCMTF(PubSubTF): - default_config: type[LCMPubsubConfig] = LCMPubsubConfig +class LCMTF(PubSubTF[LCMPubsubConfig]): + default_config = LCMPubsubConfig TF = LCMTF diff --git a/dimos/protocol/tf/tflcmcpp.py b/dimos/protocol/tf/tflcmcpp.py index 158a68d3d8..aec1f947ce 100644 --- a/dimos/protocol/tf/tflcmcpp.py +++ b/dimos/protocol/tf/tflcmcpp.py @@ -13,15 +13,18 @@ # limitations under the License. from datetime import datetime -from typing import Union -from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.protocol.tf.tf import TFConfig, TFSpec +class Config(TFConfig, LCMConfig): + """Combined config""" + + # this doesn't work due to tf_lcm_py package -class TFLCM(TFSpec, LCMService): +class TFLCM(TFSpec[Config], LCMService[Config]): """A service for managing and broadcasting transforms using LCM. This is not a separete module, You can include this in your module if you need to access transforms. @@ -34,7 +37,7 @@ class TFLCM(TFSpec, LCMService): for each module. """ - default_config = Union[TFConfig, LCMConfig] + default_config = Config def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) diff --git a/dimos/robot/__init__.py b/dimos/robot/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/drone/__init__.py b/dimos/robot/drone/__init__.py deleted file mode 100644 index 828059e99d..0000000000 --- a/dimos/robot/drone/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Generic drone module for MAVLink-based drones.""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "camera_module": ["DroneCameraModule"], - "connection_module": ["DroneConnectionModule"], - "mavlink_connection": ["MavlinkConnection"], - }, -) diff --git a/dimos/robot/drone/blueprints/__init__.py b/dimos/robot/drone/blueprints/__init__.py deleted file mode 100644 index d011c6e4fb..0000000000 --- a/dimos/robot/drone/blueprints/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DimOS Drone blueprints.""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "basic.drone_basic": ["drone_basic"], - "agentic.drone_agentic": ["drone_agentic"], - }, -) diff --git a/dimos/robot/drone/blueprints/agentic/__init__.py b/dimos/robot/drone/blueprints/agentic/__init__.py deleted file mode 100644 index a7386b8f45..0000000000 --- a/dimos/robot/drone/blueprints/agentic/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Agentic drone blueprint.""" - -from dimos.robot.drone.blueprints.agentic.drone_agentic import drone_agentic - -__all__ = ["drone_agentic"] diff --git a/dimos/robot/drone/blueprints/basic/__init__.py b/dimos/robot/drone/blueprints/basic/__init__.py deleted file mode 100644 index 3bf4ec60ff..0000000000 --- a/dimos/robot/drone/blueprints/basic/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Basic drone blueprint.""" - -from dimos.robot.drone.blueprints.basic.drone_basic import drone_basic - -__all__ = ["drone_basic"] diff --git a/dimos/robot/drone/camera_module.py b/dimos/robot/drone/camera_module.py index 63389aa358..5343549c66 100644 --- a/dimos/robot/drone/camera_module.py +++ b/dimos/robot/drone/camera_module.py @@ -26,9 +26,9 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.std_msgs import Header +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.std_msgs.Header import Header from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/robot/drone/connection_module.py b/dimos/robot/drone/connection_module.py index 7b44cea607..863f719bad 100644 --- a/dimos/robot/drone/connection_module.py +++ b/dimos/robot/drone/connection_module.py @@ -26,11 +26,15 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.mapping.types import LatLon -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 -from dimos.msgs.sensor_msgs import Image +from dimos.mapping.models import LatLon +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image from dimos.robot.drone.dji_video_stream import DJIDroneVideoStream from dimos.robot.drone.mavlink_connection import MavlinkConnection from dimos.utils.logging_config import setup_logger @@ -45,9 +49,17 @@ def _add_disposable(composite: CompositeDisposable, item: Disposable | Any) -> N composite.add(Disposable(item)) -class DroneConnectionModule(Module): +class Config(ModuleConfig): + connection_string: str = "udp:0.0.0.0:14550" + video_port: int = 5600 + outdoor: bool = False + + +class DroneConnectionModule(Module[Config]): """Module that handles drone sensor data and movement commands.""" + default_config = Config + # Inputs movecmd: In[Vector3] movecmd_twist: In[Twist] # Twist commands from tracking/navigation @@ -62,9 +74,6 @@ class DroneConnectionModule(Module): video: Out[Image] follow_object_cmd: Out[Any] - # Parameters - connection_string: str - # Internal state _odom: PoseStamped | None = None _status: dict[str, Any] = {} @@ -73,14 +82,7 @@ class DroneConnectionModule(Module): _latest_status: dict[str, Any] | None = None _latest_status_lock: threading.RLock - def __init__( - self, - connection_string: str = "udp:0.0.0.0:14550", - video_port: int = 5600, - outdoor: bool = False, - *args: Any, - **kwargs: Any, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize drone connection module. Args: @@ -88,9 +90,7 @@ def __init__( video_port: UDP port for video stream outdoor: Use GPS only mode (no velocity integration) """ - self.connection_string = connection_string - self.video_port = video_port - self.outdoor = outdoor + super().__init__(**kwargs) self.connection: MavlinkConnection | None = None self.video_stream: DJIDroneVideoStream | None = None self._latest_video_frame = None @@ -99,23 +99,24 @@ def __init__( self._latest_status_lock = threading.RLock() self._running = False self._telemetry_thread: threading.Thread | None = None - Module.__init__(self, *args, **kwargs) @rpc def start(self) -> None: """Start the connection and subscribe to sensor streams.""" # Check for replay mode - if self.connection_string == "replay": + if self.config.connection_string == "replay": from dimos.robot.drone.dji_video_stream import FakeDJIVideoStream from dimos.robot.drone.mavlink_connection import FakeMavlinkConnection self.connection = FakeMavlinkConnection("replay") - self.video_stream = FakeDJIVideoStream(port=self.video_port) + self.video_stream = FakeDJIVideoStream(port=self.config.video_port) else: - self.connection = MavlinkConnection(self.connection_string, outdoor=self.outdoor) + self.connection = MavlinkConnection( + self.config.connection_string, outdoor=self.config.outdoor + ) self.connection.connect() - self.video_stream = DJIDroneVideoStream(port=self.video_port) + self.video_stream = DJIDroneVideoStream(port=self.config.video_port) if not self.connection.connected: logger.error("Failed to connect to drone") diff --git a/dimos/robot/drone/dji_video_stream.py b/dimos/robot/drone/dji_video_stream.py index 1810fd4212..60618ae712 100644 --- a/dimos/robot/drone/dji_video_stream.py +++ b/dimos/robot/drone/dji_video_stream.py @@ -26,7 +26,7 @@ import numpy as np from reactivex import Observable, Subject -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -214,7 +214,7 @@ def get_stream(self) -> Observable[Image]: # type: ignore[override] """ from reactivex import operators as ops - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay def _fix_format(img: Image) -> Image: if img.format == ImageFormat.BGR: diff --git a/dimos/robot/drone/drone_tracking_module.py b/dimos/robot/drone/drone_tracking_module.py index 276b636633..5798db374b 100644 --- a/dimos/robot/drone/drone_tracking_module.py +++ b/dimos/robot/drone/drone_tracking_module.py @@ -29,8 +29,9 @@ from dimos.core.module import Module from dimos.core.stream import In, Out from dimos.models.qwen.video_query import get_bbox_from_qwen_frame -from dimos.msgs.geometry_msgs import Twist, Vector3 -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.robot.drone.drone_visual_servoing_controller import ( DroneVisualServoingController, PIDParams, diff --git a/dimos/robot/drone/mavlink_connection.py b/dimos/robot/drone/mavlink_connection.py index d8a7c97c4a..076d9cd369 100644 --- a/dimos/robot/drone/mavlink_connection.py +++ b/dimos/robot/drone/mavlink_connection.py @@ -23,7 +23,10 @@ from pymavlink import mavutil # type: ignore[import-not-found, import-untyped] from reactivex import Subject -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Twist, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.utils.logging_config import setup_logger logger = setup_logger(level=logging.INFO) @@ -1028,7 +1031,7 @@ def __init__(self, connection_string: str) -> None: class FakeMavlink: def __init__(self) -> None: from dimos.utils.data import get_data - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay get_data("drone") diff --git a/dimos/robot/drone/test_drone.py b/dimos/robot/drone/test_drone.py index 88c45c9aa8..0b30c22c35 100644 --- a/dimos/robot/drone/test_drone.py +++ b/dimos/robot/drone/test_drone.py @@ -25,8 +25,10 @@ import numpy as np -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.robot.drone.connection_module import DroneConnectionModule from dimos.robot.drone.dji_video_stream import FakeDJIVideoStream @@ -192,7 +194,7 @@ class TestReplayMode(unittest.TestCase): def test_fake_mavlink_connection(self) -> None: """Test FakeMavlinkConnection replays messages correctly.""" - with patch("dimos.utils.testing.TimedSensorReplay") as mock_replay: + with patch("dimos.utils.testing.replay.TimedSensorReplay") as mock_replay: # Mock the replay stream MagicMock() mock_messages = [ @@ -218,7 +220,7 @@ def test_fake_mavlink_connection(self) -> None: def test_fake_video_stream_no_throttling(self) -> None: """Test FakeDJIVideoStream returns replay stream with format fix.""" - with patch("dimos.utils.testing.TimedSensorReplay") as mock_replay: + with patch("dimos.utils.testing.replay.TimedSensorReplay") as mock_replay: mock_stream = MagicMock() mock_replay.return_value.stream.return_value = mock_stream @@ -280,7 +282,7 @@ def test_connection_module_replay_with_messages(self) -> None: os.environ["DRONE_CONNECTION"] = "replay" - with patch("dimos.utils.testing.TimedSensorReplay") as mock_replay: + with patch("dimos.utils.testing.replay.TimedSensorReplay") as mock_replay: # Set up MAVLink replay stream mavlink_messages = [ {"mavpackettype": "HEARTBEAT", "type": 2, "base_mode": 193}, @@ -433,7 +435,7 @@ def tearDown(self) -> None: self.foxglove_patch.stop() @patch("dimos.robot.drone.drone.ModuleCoordinator") - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") def test_full_system_with_replay(self, mock_replay, mock_coordinator_class) -> None: """Test full drone system initialization and operation with replay mode.""" # Set up mock replay data @@ -567,7 +569,7 @@ def deploy_side_effect(module_class, **kwargs): class TestDroneControlCommands(unittest.TestCase): """Test drone control commands with FakeMavlinkConnection.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_arm_disarm_commands(self, mock_get_data, mock_replay) -> None: """Test arm and disarm commands work with fake connection.""" @@ -586,7 +588,7 @@ def test_arm_disarm_commands(self, mock_get_data, mock_replay) -> None: result = conn.disarm() self.assertIsInstance(result, bool) # Should return bool without crashing - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_takeoff_land_commands(self, mock_get_data, mock_replay) -> None: """Test takeoff and land commands with fake connection.""" @@ -605,7 +607,7 @@ def test_takeoff_land_commands(self, mock_get_data, mock_replay) -> None: result = conn.land() self.assertIsNotNone(result) - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_set_mode_command(self, mock_get_data, mock_replay) -> None: """Test flight mode setting with fake connection.""" @@ -626,7 +628,7 @@ def test_set_mode_command(self, mock_get_data, mock_replay) -> None: class TestDronePerception(unittest.TestCase): """Test drone perception capabilities.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_video_stream_replay(self, mock_get_data, mock_replay) -> None: """Test video stream works with replay data.""" @@ -696,7 +698,7 @@ def piped_subscribe(callback): # type: ignore[no-untyped-def] class TestDroneMovementAndOdometry(unittest.TestCase): """Test drone movement commands and odometry.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_movement_command_conversion(self, mock_get_data, mock_replay) -> None: """Test movement commands are properly converted from ROS to NED.""" @@ -716,7 +718,7 @@ def test_movement_command_conversion(self, mock_get_data, mock_replay) -> None: # Movement should be converted to NED internally # The fake connection doesn't actually send commands, but it should not crash - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_odometry_from_replay(self, mock_get_data, mock_replay) -> None: """Test odometry is properly generated from replay messages.""" @@ -763,7 +765,7 @@ def replay_stream_subscribe(callback) -> None: self.assertIsNotNone(odom.orientation) self.assertEqual(odom.frame_id, "world") - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_position_integration_indoor(self, mock_get_data, mock_replay) -> None: """Test position integration for indoor flight without GPS.""" @@ -808,7 +810,7 @@ def replay_stream_subscribe(callback) -> None: class TestDroneStatusAndTelemetry(unittest.TestCase): """Test drone status and telemetry reporting.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_status_extraction(self, mock_get_data, mock_replay) -> None: """Test status is properly extracted from MAVLink messages.""" @@ -853,7 +855,7 @@ def replay_stream_subscribe(callback) -> None: self.assertIn("altitude", status) self.assertIn("heading", status) - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_telemetry_json_publishing(self, mock_get_data, mock_replay) -> None: """Test full telemetry is published as JSON.""" @@ -907,7 +909,7 @@ def replay_stream_subscribe(callback) -> None: class TestFlyToErrorHandling(unittest.TestCase): """Test fly_to() error handling paths.""" - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_concurrency_lock(self, mock_get_data, mock_replay) -> None: """flying_to_target=True rejects concurrent fly_to() calls.""" @@ -921,7 +923,7 @@ def test_concurrency_lock(self, mock_get_data, mock_replay) -> None: result = conn.fly_to(37.0, -122.0, 10.0) self.assertIn("Already flying to target", result) - @patch("dimos.utils.testing.TimedSensorReplay") + @patch("dimos.utils.testing.replay.TimedSensorReplay") @patch("dimos.utils.data.get_data") def test_error_when_not_connected(self, mock_get_data, mock_replay) -> None: """connected=False returns error immediately.""" diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index 78fbdaf168..9f0fc938e5 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -13,21 +13,21 @@ # limitations under the License. import asyncio +from collections.abc import Sequence import logging import threading -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from dimos_lcm.foxglove_bridge import ( FoxgloveBridge as LCMFoxgloveBridge, ) from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.core.global_config import GlobalConfig from dimos.core.rpc_client import ModuleProxy logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) @@ -36,31 +36,23 @@ logger = setup_logger() -class FoxgloveBridge(Module): +class FoxgloveConfig(ModuleConfig): + shm_channels: Sequence[str] = () + jpeg_shm_channels: Sequence[str] = () + + +class FoxgloveBridge(Module[FoxgloveConfig]): _thread: threading.Thread _loop: asyncio.AbstractEventLoop - _global_config: "GlobalConfig | None" = None - - def __init__( - self, - *args: Any, - shm_channels: list[str] | None = None, - jpeg_shm_channels: list[str] | None = None, - global_config: "GlobalConfig | None" = None, - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - self.shm_channels = shm_channels or [] - self.jpeg_shm_channels = jpeg_shm_channels or [] - self._global_config = global_config + default_config = FoxgloveConfig @rpc def start(self) -> None: super().start() # Skip if Rerun is the selected viewer - if self._global_config and self._global_config.viewer.startswith("rerun"): - logger.info("Foxglove bridge skipped", viewer=self._global_config.viewer) + if self.config.g.viewer.startswith("rerun"): + logger.info("Foxglove bridge skipped", viewer=self.config.g.viewer) return def run_bridge() -> None: @@ -78,8 +70,8 @@ def run_bridge() -> None: port=8765, debug=False, num_threads=4, - shm_channels=self.shm_channels, - jpeg_shm_channels=self.jpeg_shm_channels, + shm_channels=self.config.shm_channels, + jpeg_shm_channels=self.config.jpeg_shm_channels, ) self._loop.run_until_complete(bridge.run()) except Exception as e: diff --git a/dimos/robot/manipulators/__init__.py b/dimos/robot/manipulators/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/manipulators/piper/__init__.py b/dimos/robot/manipulators/piper/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/manipulators/piper/blueprints.py b/dimos/robot/manipulators/piper/blueprints.py index 68e02fc994..ead27fd54b 100644 --- a/dimos/robot/manipulators/piper/blueprints.py +++ b/dimos/robot/manipulators/piper/blueprints.py @@ -27,9 +27,11 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport from dimos.manipulation.manipulation_module import manipulation_module -from dimos.manipulation.planning.spec import RobotModelConfig -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 -from dimos.msgs.sensor_msgs import JointState +from dimos.manipulation.planning.spec.config import RobotModelConfig +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.keyboard.keyboard_teleop_module import keyboard_teleop_module from dimos.utils.data import LfsPath, get_data diff --git a/dimos/robot/manipulators/xarm/__init__.py b/dimos/robot/manipulators/xarm/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/manipulators/xarm/blueprints.py b/dimos/robot/manipulators/xarm/blueprints.py index 9a1732217b..e699057b44 100644 --- a/dimos/robot/manipulators/xarm/blueprints.py +++ b/dimos/robot/manipulators/xarm/blueprints.py @@ -32,8 +32,8 @@ _make_xarm7_config, ) from dimos.manipulation.manipulation_module import manipulation_module -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.JointState import JointState from dimos.teleop.keyboard.keyboard_teleop_module import keyboard_teleop_module from dimos.utils.data import LfsPath diff --git a/dimos/robot/unitree/__init__.py b/dimos/robot/unitree/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/unitree/b1/__init__.py b/dimos/robot/unitree/b1/__init__.py deleted file mode 100644 index db85984070..0000000000 --- a/dimos/robot/unitree/b1/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. - -"""Unitree B1 robot module.""" - -from .unitree_b1 import UnitreeB1 - -__all__ = ["UnitreeB1"] diff --git a/dimos/robot/unitree/b1/connection.py b/dimos/robot/unitree/b1/connection.py index 4279f78399..11af31b296 100644 --- a/dimos/robot/unitree/b1/connection.py +++ b/dimos/robot/unitree/b1/connection.py @@ -21,15 +21,18 @@ import socket import threading import time +from typing import Any from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry -from dimos.msgs.std_msgs import Int32 +from dimos.msgs.std_msgs.Int32 import Int32 from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.utils.logging_config import setup_logger @@ -48,13 +51,21 @@ class RobotMode: RECOVERY = 6 -class B1ConnectionModule(Module): +class B1ConnectionConfig(ModuleConfig): + ip: str = "192.168.12.1" + port: int = 9090 + test_mode: bool = False + + +class B1ConnectionModule(Module[B1ConnectionConfig]): """UDP connection module for B1 robot with standard Twist interface. Accepts standard ROS Twist messages on /cmd_vel and mode changes on /b1/mode, internally converts to B1Command format, and sends UDP packets at 50Hz. """ + default_config = B1ConnectionConfig + # LCM ports (inter-module communication) cmd_vel: In[TwistStamped] mode_cmd: In[Int32] @@ -67,9 +78,7 @@ class B1ConnectionModule(Module): ros_odom_in: In[Odometry] ros_tf: In[TFMessage] - def __init__( # type: ignore[no-untyped-def] - self, ip: str = "192.168.12.1", port: int = 9090, test_mode: bool = False, *args, **kwargs - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize B1 connection module. Args: @@ -77,11 +86,11 @@ def __init__( # type: ignore[no-untyped-def] port: UDP port for joystick server test_mode: If True, print commands instead of sending UDP """ - Module.__init__(self, *args, **kwargs) + super().__init__(**kwargs) - self.ip = ip - self.port = port - self.test_mode = test_mode + self.ip = self.config.ip + self.port = self.config.port + self.test_mode = self.config.test_mode self.current_mode = RobotMode.IDLE # Start in IDLE mode self._current_cmd = B1Command(mode=RobotMode.IDLE) self.cmd_lock = threading.Lock() # Thread lock for _current_cmd access @@ -383,9 +392,10 @@ def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> bool: class MockB1ConnectionModule(B1ConnectionModule): """Test connection module that prints commands instead of sending UDP.""" - def __init__(self, ip: str = "127.0.0.1", port: int = 9090, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: # type: ignore[no-untyped-def] """Initialize test connection without creating socket.""" - super().__init__(ip, port, test_mode=True, *args, **kwargs) # type: ignore[misc] + kwargs["test_mode"] = True + super().__init__(**kwargs) def _send_loop(self) -> None: """Override to provide better test output with timeout detection.""" diff --git a/dimos/robot/unitree/b1/joystick_module.py b/dimos/robot/unitree/b1/joystick_module.py index 0a72f81617..234ff129c9 100644 --- a/dimos/robot/unitree/b1/joystick_module.py +++ b/dimos/robot/unitree/b1/joystick_module.py @@ -28,8 +28,10 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 -from dimos.msgs.std_msgs import Int32 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.std_msgs.Int32 import Int32 class JoystickModule(Module): @@ -41,12 +43,9 @@ class JoystickModule(Module): twist_out: Out[TwistStamped] # Timestamped velocity commands mode_out: Out[Int32] # Mode changes - - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - Module.__init__(self, *args, **kwargs) - self.pygame_ready = False - self.running = False - self.current_mode = 0 # Start in IDLE mode for safety + pygame_ready = False + running = False + current_mode = 0 # Start in IDLE mode for safety @rpc def start(self) -> None: diff --git a/dimos/robot/unitree/b1/test_connection.py b/dimos/robot/unitree/b1/test_connection.py index e43a3124dc..f1ff5ad861 100644 --- a/dimos/robot/unitree/b1/test_connection.py +++ b/dimos/robot/unitree/b1/test_connection.py @@ -25,7 +25,8 @@ import threading import time -from dimos.msgs.geometry_msgs import TwistStamped, Vector3 +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.std_msgs.Int32 import Int32 from .connection import MockB1ConnectionModule diff --git a/dimos/robot/unitree/b1/unitree_b1.py b/dimos/robot/unitree/b1/unitree_b1.py index 2c0c918942..9a6d04a7ff 100644 --- a/dimos/robot/unitree/b1/unitree_b1.py +++ b/dimos/robot/unitree/b1/unitree_b1.py @@ -26,9 +26,10 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.resource import Resource from dimos.core.transport import LCMTransport, ROSTransport -from dimos.msgs.geometry_msgs import PoseStamped, TwistStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry -from dimos.msgs.std_msgs import Int32 +from dimos.msgs.std_msgs.Int32 import Int32 from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.robot.robot import Robot from dimos.robot.unitree.b1.connection import ( @@ -92,9 +93,9 @@ def start(self) -> None: logger.info("Deploying connection module...") if self.test_mode: - self.connection = self._dimos.deploy(MockB1ConnectionModule, self.ip, self.port) # type: ignore[assignment] + self.connection = self._dimos.deploy(MockB1ConnectionModule, ip=self.ip, port=self.port) # type: ignore[assignment] else: - self.connection = self._dimos.deploy(B1ConnectionModule, self.ip, self.port) # type: ignore[assignment] + self.connection = self._dimos.deploy(B1ConnectionModule, ip=self.ip, port=self.port) # type: ignore[assignment] # Configure LCM transports for connection (matching G1 pattern) self.connection.cmd_vel.transport = LCMTransport("/cmd_vel", TwistStamped) # type: ignore[attr-defined] diff --git a/dimos/robot/unitree/connection.py b/dimos/robot/unitree/connection.py index ff73d922ee..7e60080f01 100644 --- a/dimos/robot/unitree/connection.py +++ b/dimos/robot/unitree/connection.py @@ -35,9 +35,11 @@ ) from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Pose, Transform, Twist -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.sensor_msgs.Image import ImageFormat +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import RawLidarMsg, pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.lowstate import LowStateMsg from dimos.robot.unitree.type.odometry import Odometry diff --git a/dimos/robot/unitree/g1/blueprints/__init__.py b/dimos/robot/unitree/g1/blueprints/__init__.py deleted file mode 100644 index ebc18da8d3..0000000000 --- a/dimos/robot/unitree/g1/blueprints/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Cascaded G1 blueprints split into focused modules.""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "agentic._agentic_skills": ["_agentic_skills"], - "agentic.unitree_g1_agentic": ["unitree_g1_agentic"], - "agentic.unitree_g1_agentic_sim": ["unitree_g1_agentic_sim"], - "agentic.unitree_g1_full": ["unitree_g1_full"], - "basic.unitree_g1_basic": ["unitree_g1_basic"], - "basic.unitree_g1_basic_sim": ["unitree_g1_basic_sim"], - "basic.unitree_g1_joystick": ["unitree_g1_joystick"], - "perceptive._perception_and_memory": ["_perception_and_memory"], - "perceptive.unitree_g1": ["unitree_g1"], - "perceptive.unitree_g1_detection": ["unitree_g1_detection"], - "perceptive.unitree_g1_shm": ["unitree_g1_shm"], - "perceptive.unitree_g1_sim": ["unitree_g1_sim"], - "primitive.uintree_g1_primitive_no_nav": ["uintree_g1_primitive_no_nav", "basic_no_nav"], - }, -) diff --git a/dimos/robot/unitree/g1/blueprints/agentic/__init__.py b/dimos/robot/unitree/g1/blueprints/agentic/__init__.py deleted file mode 100644 index 5e6db90d91..0000000000 --- a/dimos/robot/unitree/g1/blueprints/agentic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Agentic blueprints for Unitree G1.""" diff --git a/dimos/robot/unitree/g1/blueprints/basic/__init__.py b/dimos/robot/unitree/g1/blueprints/basic/__init__.py deleted file mode 100644 index 87e6586f56..0000000000 --- a/dimos/robot/unitree/g1/blueprints/basic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Basic blueprints for Unitree G1.""" diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/__init__.py b/dimos/robot/unitree/g1/blueprints/perceptive/__init__.py deleted file mode 100644 index 9bd838e8b8..0000000000 --- a/dimos/robot/unitree/g1/blueprints/perceptive/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Perceptive blueprints for Unitree G1.""" diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py index 25bff97c73..18884bd7af 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_detection.py @@ -22,10 +22,11 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera import zed -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.hardware.sensors.camera.zed import compat as zed +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection.module3D import Detection3DModule, detection3d_module from dimos.perception.detection.moduleDB import ObjectDBModule, detection_db_module diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index 5ee4d4c9d1..be67194b62 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -18,7 +18,7 @@ from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.blueprints import autoconnect from dimos.core.transport import pSHMTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.robot.foxglove_bridge import foxglove_bridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 diff --git a/dimos/robot/unitree/g1/blueprints/primitive/__init__.py b/dimos/robot/unitree/g1/blueprints/primitive/__init__.py deleted file mode 100644 index 833f767728..0000000000 --- a/dimos/robot/unitree/g1/blueprints/primitive/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Primitive blueprints for Unitree G1.""" diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index c47fdc377b..242fcaf38f 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -22,16 +22,24 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config from dimos.core.transport import LCMTransport -from dimos.hardware.sensors.camera import zed from dimos.hardware.sensors.camera.module import camera_module # type: ignore[attr-defined] from dimos.hardware.sensors.camera.webcam import Webcam +from dimos.hardware.sensors.camera.zed import compat as zed from dimos.mapping.costmapper import cost_mapper from dimos.mapping.voxels import voxel_mapper -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 -from dimos.msgs.nav_msgs import Odometry, Path -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.std_msgs import Bool -from dimos.navigation.frontier_exploration import wavefront_frontier_explorer +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.nav_msgs.Path import Path +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Bool import Bool +from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( + wavefront_frontier_explorer, +) from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.web.websocket_vis.websocket_vis_module import websocket_vis diff --git a/dimos/robot/unitree/g1/blueprints/unitree_g1_blueprints.py b/dimos/robot/unitree/g1/blueprints/unitree_g1_blueprints.py new file mode 100644 index 0000000000..4dcbedcca7 --- /dev/null +++ b/dimos/robot/unitree/g1/blueprints/unitree_g1_blueprints.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Blueprint configurations for Unitree G1 humanoid robot. + +This module provides pre-configured blueprints for various G1 robot setups, +from basic teleoperation to full autonomous agent configurations. +""" + +from dimos_lcm.foxglove_msgs import SceneUpdate +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( + ImageAnnotations, +) +from dimos_lcm.sensor_msgs import CameraInfo + +from dimos.agents.agent import llm_agent +from dimos.agents.cli.human import human_input +from dimos.agents.skills.navigation import navigation_skill +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.core.blueprints import autoconnect +from dimos.core.transport import LCMTransport, pSHMTransport +from dimos.hardware.sensors.camera import zed +from dimos.hardware.sensors.camera.module import camera_module # type: ignore[attr-defined] +from dimos.hardware.sensors.camera.webcam import Webcam +from dimos.mapping.costmapper import cost_mapper +from dimos.mapping.voxels import voxel_mapper +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + Twist, + Vector3, +) +from dimos.msgs.nav_msgs import Odometry, Path +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.std_msgs import Bool +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.navigation.frontier_exploration import wavefront_frontier_explorer +from dimos.navigation.replanning_a_star.module import replanning_a_star_planner +from dimos.navigation.rosnav import ros_nav +from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector +from dimos.perception.detection.module3D import Detection3DModule, detection3d_module +from dimos.perception.detection.moduleDB import ObjectDBModule, detectionDB_module +from dimos.perception.detection.person_tracker import PersonTracker, person_tracker_module +from dimos.perception.object_tracker import object_tracking +from dimos.perception.spatial_perception import spatial_memory +from dimos.robot.foxglove_bridge import foxglove_bridge +from dimos.robot.unitree.connection.g1 import g1_connection +from dimos.robot.unitree.connection.g1sim import g1_sim_connection +from dimos.robot.unitree_webrtc.keyboard_teleop import keyboard_teleop +from dimos.robot.unitree_webrtc.keyboard_pose_teleop import keyboard_pose_teleop +from dimos.robot.unitree_webrtc.unitree_g1_skill_container import g1_skills +from dimos.robot.doom_teleop import doom_teleop +from dimos.utils.monitoring import utilization +from dimos.web.websocket_vis.websocket_vis_module import websocket_vis + +_basic_no_nav = ( + autoconnect( + camera_module( + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0.0, 0.2, 0.0)), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + camera_index=0, + fps=15, + stereo_slice="left", + camera_info=zed.CameraInfo.SingleWebcam, + ), + ), + voxel_mapper(voxel_size=0.1), + cost_mapper(), + wavefront_frontier_explorer(), + # Visualization + websocket_vis(), + foxglove_bridge(), + ) + .global_config(n_dask_workers=4, robot_model="unitree_g1") + .transports( + { + # G1 uses Twist for movement commands + ("cmd_vel", Twist): LCMTransport("/cmd_vel", Twist), + # State estimation from ROS + ("state_estimation", Odometry): LCMTransport("/state_estimation", Odometry), + # Odometry output from ROSNavigationModule + ("odom", PoseStamped): LCMTransport("/odom", PoseStamped), + # Navigation module topics from nav_bot + ("goal_req", PoseStamped): LCMTransport("/goal_req", PoseStamped), + ("goal_active", PoseStamped): LCMTransport("/goal_active", PoseStamped), + ("path_active", Path): LCMTransport("/path_active", Path), + ("pointcloud", PointCloud2): LCMTransport("/lidar", PointCloud2), + ("global_pointcloud", PointCloud2): LCMTransport("/map", PointCloud2), + # Original navigation topics for backwards compatibility + ("goal_pose", PoseStamped): LCMTransport("/goal_pose", PoseStamped), + ("goal_reached", Bool): LCMTransport("/goal_reached", Bool), + ("cancel_goal", Bool): LCMTransport("/cancel_goal", Bool), + # Camera topics (if camera module is added) + ("color_image", Image): LCMTransport("/g1/color_image", Image), + ("camera_info", CameraInfo): LCMTransport("/g1/camera_info", CameraInfo), + } + ) +) + +basic_ros = autoconnect( + _basic_no_nav, + g1_connection(), + ros_nav(), +) + +basic_sim = autoconnect( + _basic_no_nav, + g1_sim_connection(), + replanning_a_star_planner(), +) + +_perception_and_memory = autoconnect( + spatial_memory(), + object_tracking(frame_id="camera_link"), + utilization(), +) + +standard = autoconnect( + basic_ros, + _perception_and_memory, +).global_config(n_dask_workers=8) + +standard_sim = autoconnect( + basic_sim, + _perception_and_memory, +).global_config(n_dask_workers=8) + +# Optimized configuration using shared memory for images +standard_with_shm = autoconnect( + standard.transports( + { + ("color_image", Image): pSHMTransport( + "/g1/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ), + } + ), + foxglove_bridge( + shm_channels=[ + "/g1/color_image#sensor_msgs.Image", + ] + ), +) + +_agentic_skills = autoconnect( + llm_agent(), + human_input(), + navigation_skill(), + g1_skills(), +) + +# Full agentic configuration with LLM and skills +agentic = autoconnect( + standard, + _agentic_skills, +) + +agentic_sim = autoconnect( + standard_sim, + _agentic_skills, +) + +# Configuration with joystick/DOOM-style control for teleoperation +with_joystick = autoconnect( + basic_ros, + doom_teleop(), # Doom-style keyboard + mouse teleop (cmd_vel / goal_pose / cancel_goal) +) + +# Detection configuration with person tracking and 3D detection +detection = ( + autoconnect( + basic_ros, + # Person detection modules with YOLO + detection3d_module( + camera_info=zed.CameraInfo.SingleWebcam, + detector=YoloPersonDetector, + ), + detectionDB_module( + camera_info=zed.CameraInfo.SingleWebcam, + filter=lambda det: det.class_id == 0, # Filter for person class only + ), + person_tracker_module( + cameraInfo=zed.CameraInfo.SingleWebcam, + ), + ) + .global_config(n_dask_workers=8) + .remappings( + [ + # Connect detection modules to camera and lidar + (Detection3DModule, "image", "color_image"), + (Detection3DModule, "pointcloud", "pointcloud"), + (ObjectDBModule, "image", "color_image"), + (ObjectDBModule, "pointcloud", "pointcloud"), + (PersonTracker, "image", "color_image"), + (PersonTracker, "detections", "detections_2d"), + ] + ) + .transports( + { + # Detection 3D module outputs + ("detections", Detection3DModule): LCMTransport( + "/detector3d/detections", Detection2DArray + ), + ("annotations", Detection3DModule): LCMTransport( + "/detector3d/annotations", ImageAnnotations + ), + ("scene_update", Detection3DModule): LCMTransport( + "/detector3d/scene_update", SceneUpdate + ), + ("detected_pointcloud_0", Detection3DModule): LCMTransport( + "/detector3d/pointcloud/0", PointCloud2 + ), + ("detected_pointcloud_1", Detection3DModule): LCMTransport( + "/detector3d/pointcloud/1", PointCloud2 + ), + ("detected_pointcloud_2", Detection3DModule): LCMTransport( + "/detector3d/pointcloud/2", PointCloud2 + ), + ("detected_image_0", Detection3DModule): LCMTransport("/detector3d/image/0", Image), + ("detected_image_1", Detection3DModule): LCMTransport("/detector3d/image/1", Image), + ("detected_image_2", Detection3DModule): LCMTransport("/detector3d/image/2", Image), + # Detection DB module outputs + ("detections", ObjectDBModule): LCMTransport( + "/detectorDB/detections", Detection2DArray + ), + ("annotations", ObjectDBModule): LCMTransport( + "/detectorDB/annotations", ImageAnnotations + ), + ("scene_update", ObjectDBModule): LCMTransport("/detectorDB/scene_update", SceneUpdate), + ("detected_pointcloud_0", ObjectDBModule): LCMTransport( + "/detectorDB/pointcloud/0", PointCloud2 + ), + ("detected_pointcloud_1", ObjectDBModule): LCMTransport( + "/detectorDB/pointcloud/1", PointCloud2 + ), + ("detected_pointcloud_2", ObjectDBModule): LCMTransport( + "/detectorDB/pointcloud/2", PointCloud2 + ), + ("detected_image_0", ObjectDBModule): LCMTransport("/detectorDB/image/0", Image), + ("detected_image_1", ObjectDBModule): LCMTransport("/detectorDB/image/1", Image), + ("detected_image_2", ObjectDBModule): LCMTransport("/detectorDB/image/2", Image), + # Person tracker outputs + ("target", PersonTracker): LCMTransport("/person_tracker/target", PoseStamped), + } + ) +) + +# Full featured configuration with everything +full_featured = autoconnect( + standard_with_shm, + _agentic_skills, + keyboard_teleop(), +) diff --git a/dimos/robot/unitree/g1/connection.py b/dimos/robot/unitree/g1/connection.py index c2dbc6ab2d..1f3788de98 100644 --- a/dimos/robot/unitree/g1/connection.py +++ b/dimos/robot/unitree/g1/connection.py @@ -14,27 +14,33 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar +from pydantic import Field from reactivex.disposable import Disposable -from dimos import spec from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In -from dimos.msgs.geometry_msgs import Twist +from dimos.msgs.geometry_msgs.Twist import Twist from dimos.robot.unitree.connection import UnitreeWebRTCConnection +from dimos.spec.control import LocalPlanner from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: from dimos.core.rpc_client import ModuleProxy logger = setup_logger() +_Config = TypeVar("_Config", bound=ModuleConfig) -class G1ConnectionBase(Module, ABC): +class G1Config(ModuleConfig): + ip: str = Field(default_factory=lambda m: m["g"].robot_ip) + connection_type: str = Field(default_factory=lambda m: m["g"].unitree_connection_type) + + +class G1ConnectionBase(Module[_Config], ABC): """Abstract base for G1 connections (real hardware and simulation). Modules that depend on G1 connection RPC methods should reference this @@ -61,36 +67,19 @@ def move(self, twist: Twist, duration: float = 0.0) -> None: ... def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: ... -class G1Connection(G1ConnectionBase): +class G1Connection(G1ConnectionBase[G1Config]): + default_config = G1Config + cmd_vel: In[Twist] - ip: str | None - connection_type: str | None = None - _global_config: GlobalConfig - - connection: UnitreeWebRTCConnection | None - - def __init__( - self, - ip: str | None = None, - connection_type: str | None = None, - cfg: GlobalConfig = global_config, - *args: Any, - **kwargs: Any, - ) -> None: - self._global_config = cfg - self.ip = ip if ip is not None else self._global_config.robot_ip - self.connection_type = connection_type or self._global_config.unitree_connection_type - self.connection = None - super().__init__(*args, **kwargs) + connection: UnitreeWebRTCConnection | None = None @rpc def start(self) -> None: super().start() - match self.connection_type: + match self.config.connection_type: case "webrtc": - assert self.ip is not None, "IP address must be provided" - self.connection = UnitreeWebRTCConnection(self.ip) + self.connection = UnitreeWebRTCConnection(self.config.ip) case "replay": raise ValueError("Replay connection not implemented for G1 robot") case "mujoco": @@ -98,7 +87,7 @@ def start(self) -> None: "This module does not support simulation, use G1SimConnection instead" ) case _: - raise ValueError(f"Unknown connection type: {self.connection_type}") + raise ValueError(f"Unknown connection type: {self.config.connection_type}") assert self.connection is not None self.connection.start() @@ -126,8 +115,8 @@ def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: g1_connection = G1Connection.blueprint -def deploy(dimos: ModuleCoordinator, ip: str, local_planner: spec.LocalPlanner) -> "ModuleProxy": - connection = dimos.deploy(G1Connection, ip) # type: ignore[attr-defined] +def deploy(dimos: ModuleCoordinator, ip: str, local_planner: LocalPlanner) -> "ModuleProxy": + connection = dimos.deploy(G1Connection, ip=ip) connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() return connection diff --git a/dimos/robot/unitree/g1/sim.py b/dimos/robot/unitree/g1/sim.py index 06950c6f0d..206a689284 100644 --- a/dimos/robot/unitree/g1/sim.py +++ b/dimos/robot/unitree/g1/sim.py @@ -16,53 +16,48 @@ import threading from threading import Thread import time -from typing import TYPE_CHECKING, Any +from typing import Any +from pydantic import Field from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config +from dimos.core.module import ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Quaternion, - Transform, - Twist, - Vector3, -) -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.g1.connection import G1ConnectionBase +from dimos.robot.unitree.mujoco_connection import MujocoConnection from dimos.robot.unitree.type.odometry import Odometry as SimOdometry from dimos.utils.logging_config import setup_logger -if TYPE_CHECKING: - from dimos.robot.unitree.mujoco_connection import MujocoConnection - logger = setup_logger() -class G1SimConnection(G1ConnectionBase): +class G1SimConfig(ModuleConfig): + ip: str = Field(default_factory=lambda m: m["g"].robot_ip) + + +class G1SimConnection(G1ConnectionBase[G1SimConfig]): + default_config = G1SimConfig + cmd_vel: In[Twist] lidar: Out[PointCloud2] odom: Out[PoseStamped] color_image: Out[Image] camera_info: Out[CameraInfo] - ip: str | None - _global_config: GlobalConfig + connection: MujocoConnection | None = None _camera_info_thread: Thread | None = None - def __init__( - self, - ip: str | None = None, - cfg: GlobalConfig = global_config, - *args: Any, - **kwargs: Any, - ) -> None: - self._global_config = cfg - self.ip = ip if ip is not None else self._global_config.robot_ip - self.connection: MujocoConnection | None = None + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._stop_event = threading.Event() - super().__init__(*args, **kwargs) @rpc def start(self) -> None: @@ -70,7 +65,7 @@ def start(self) -> None: from dimos.robot.unitree.mujoco_connection import MujocoConnection - self.connection = MujocoConnection(self._global_config) + self.connection = MujocoConnection(self.config.g) assert self.connection is not None self.connection.start() diff --git a/dimos/robot/unitree/g1/skill_container.py b/dimos/robot/unitree/g1/skill_container.py index 2bd5bcdb49..b1342ca96d 100644 --- a/dimos/robot/unitree/g1/skill_container.py +++ b/dimos/robot/unitree/g1/skill_container.py @@ -22,7 +22,8 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module -from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.utils.logging_config import setup_logger logger = setup_logger() diff --git a/dimos/robot/unitree/go2/blueprints/__init__.py b/dimos/robot/unitree/go2/blueprints/__init__.py deleted file mode 100644 index cbc49694f3..0000000000 --- a/dimos/robot/unitree/go2/blueprints/__init__.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Cascaded GO2 blueprints split into focused modules.""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "agentic._common_agentic": ["_common_agentic"], - "agentic.unitree_go2_agentic": ["unitree_go2_agentic"], - "agentic.unitree_go2_agentic_huggingface": ["unitree_go2_agentic_huggingface"], - "agentic.unitree_go2_agentic_mcp": ["unitree_go2_agentic_mcp"], - "agentic.unitree_go2_agentic_ollama": ["unitree_go2_agentic_ollama"], - "agentic.unitree_go2_temporal_memory": ["unitree_go2_temporal_memory"], - "basic.unitree_go2_basic": ["_linux", "_mac", "unitree_go2_basic"], - "smart._with_jpeg": ["_with_jpeglcm"], - "smart.unitree_go2": ["unitree_go2"], - "smart.unitree_go2_detection": ["unitree_go2_detection"], - "smart.unitree_go2_ros": ["unitree_go2_ros"], - "smart.unitree_go2_spatial": ["unitree_go2_spatial"], - "smart.unitree_go2_vlm_stream_test": ["unitree_go2_vlm_stream_test"], - }, -) diff --git a/dimos/robot/unitree/go2/blueprints/agentic/__init__.py b/dimos/robot/unitree/go2/blueprints/agentic/__init__.py deleted file mode 100644 index 84d1b41b23..0000000000 --- a/dimos/robot/unitree/go2/blueprints/agentic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Agentic blueprints for Unitree GO2.""" diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py index 13a1eec1ff..24ab47ad3b 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_temporal_memory.py @@ -15,7 +15,10 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config -from dimos.perception.experimental.temporal_memory import TemporalMemoryConfig, temporal_memory +from dimos.perception.experimental.temporal_memory.temporal_memory import ( + TemporalMemoryConfig, + temporal_memory, +) from dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic import unitree_go2_agentic # This module is imported lazily by `get_by_name()` in the CLI run command, diff --git a/dimos/robot/unitree/go2/blueprints/basic/__init__.py b/dimos/robot/unitree/go2/blueprints/basic/__init__.py deleted file mode 100644 index 79964b0297..0000000000 --- a/dimos/robot/unitree/go2/blueprints/basic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Basic blueprints for Unitree GO2.""" diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index ce8aef2222..3325290bf7 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -21,9 +21,9 @@ from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.protocol.service.system_configurator import ClockSyncConfigurator +from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.connection import go2_connection from dimos.web.websocket_vis.websocket_vis_module import websocket_vis diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index 015cfcdba4..908444b2fd 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -21,7 +21,7 @@ """ from dimos.core.blueprints import autoconnect -from dimos.protocol.service.system_configurator import ClockSyncConfigurator +from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis from dimos.robot.unitree.go2.fleet_connection import go2_fleet_connection from dimos.web.websocket_vis.websocket_vis_module import websocket_vis diff --git a/dimos/robot/unitree/go2/blueprints/smart/__init__.py b/dimos/robot/unitree/go2/blueprints/smart/__init__.py deleted file mode 100644 index 7d5bdbc3ab..0000000000 --- a/dimos/robot/unitree/go2/blueprints/smart/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Smart blueprints for Unitree GO2.""" diff --git a/dimos/robot/unitree/go2/blueprints/smart/_with_jpeg.py b/dimos/robot/unitree/go2/blueprints/smart/_with_jpeg.py index 9c77d599cf..a759b1ca50 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/_with_jpeg.py +++ b/dimos/robot/unitree/go2/blueprints/smart/_with_jpeg.py @@ -14,7 +14,7 @@ # limitations under the License. from dimos.core.transport import JpegLcmTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 _with_jpeglcm = unitree_go2.transports( diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py index 22743ac135..80e6ec701a 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py @@ -16,7 +16,9 @@ from dimos.core.blueprints import autoconnect from dimos.mapping.costmapper import cost_mapper from dimos.mapping.voxels import voxel_mapper -from dimos.navigation.frontier_exploration import wavefront_frontier_explorer +from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( + wavefront_frontier_explorer, +) from dimos.navigation.replanning_a_star.module import replanning_a_star_planner from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py index f2edf2cb3b..a9bb7729ae 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_detection.py @@ -20,8 +20,9 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection2DArray import Detection2DArray from dimos.perception.detection.module3D import Detection3DModule, detection3d_module from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 from dimos.robot.unitree.go2.connection import GO2Connection diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_ros.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_ros.py index a335b1e9af..b63b8f5f6c 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_ros.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2_ros.py @@ -14,8 +14,9 @@ # limitations under the License. from dimos.core.transport import ROSTransport -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.go2.blueprints.smart.unitree_go2 import unitree_go2 unitree_go2_ros = unitree_go2.transports( diff --git a/dimos/robot/unitree/go2/blueprints/unitree_go2_blueprints.py b/dimos/robot/unitree/go2/blueprints/unitree_go2_blueprints.py new file mode 100644 index 0000000000..40c960a6d5 --- /dev/null +++ b/dimos/robot/unitree/go2/blueprints/unitree_go2_blueprints.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +import platform + +from dimos_lcm.foxglove_msgs.ImageAnnotations import ( + ImageAnnotations, # type: ignore[import-untyped] +) +from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate # type: ignore[import-untyped] + +from dimos.agents.agent import llm_agent +from dimos.agents.cli.human import human_input +from dimos.agents.cli.web import web_input +from dimos.agents.ollama_agent import ollama_installed +from dimos.agents.skills.navigation import navigation_skill +from dimos.agents.skills.person_follow import person_follow_skill +from dimos.agents.skills.speak_skill import speak_skill +from dimos.agents.spec import Provider +from dimos.agents.vlm_agent import vlm_agent +from dimos.agents.vlm_stream_tester import vlm_stream_tester +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.core.blueprints import autoconnect +from dimos.core.transport import ( + JpegLcmTransport, + JpegShmTransport, + LCMTransport, + ROSTransport, + pSHMTransport, +) +from dimos.dashboard.tf_rerun_module import tf_rerun +from dimos.mapping.costmapper import cost_mapper +from dimos.mapping.voxels import voxel_mapper +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.navigation.frontier_exploration import ( + wavefront_frontier_explorer, +) +from dimos.navigation.replanning_a_star.module import ( + replanning_a_star_planner, +) +from dimos.perception.detection.module3D import Detection3DModule, detection3d_module +from dimos.perception.experimental.temporal_memory import temporal_memory +from dimos.perception.spatial_perception import spatial_memory +from dimos.protocol.mcp.mcp import MCPModule +from dimos.robot.foxglove_bridge import foxglove_bridge +import dimos.robot.unitree.connection.go2 as _go2_mod +from dimos.robot.unitree.connection.go2 import GO2Connection, go2_connection +from dimos.robot.unitree_webrtc.unitree_skill_container import unitree_skills +from dimos.utils.monitoring import utilization +from dimos.web.websocket_vis.websocket_vis_module import websocket_vis +from dimos.robot.doom_teleop import doom_teleop + +_GO2_URDF = Path(_go2_mod.__file__).parent.parent / "go2" / "go2.urdf" + +# Mac has some issue with high bandwidth UDP +# +# so we use pSHMTransport for color_image +# (Could we adress this on the system config layer? Is this fixable on mac?) +mac = autoconnect( + foxglove_bridge( + shm_channels=[ + "/color_image#sensor_msgs.Image", + ] + ), +).transports( + { + ("color_image", Image): pSHMTransport( + "color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ), + } +) + + +linux = autoconnect(foxglove_bridge()) + +basic = autoconnect( + go2_connection(), + linux if platform.system() == "Linux" else mac, + websocket_vis(), + tf_rerun( + urdf_path=str(_GO2_URDF), + cameras=[ + ("world/robot/camera", "camera_optical", GO2Connection.camera_info_static), + ], + ), +).global_config(n_dask_workers=4, robot_model="unitree_go2") + +go2_with_doom = autoconnect( + basic, + doom_teleop(), # Doom-style keyboard + mouse teleop (cmd_vel / goal_pose / cancel_goal) +) + +nav = autoconnect( + basic, + voxel_mapper(voxel_size=0.1), + cost_mapper(), + replanning_a_star_planner(), + wavefront_frontier_explorer(), +).global_config(n_dask_workers=6, robot_model="unitree_go2") + +ros = nav.transports( + { + ("lidar", PointCloud2): ROSTransport("lidar", PointCloud2), + ("global_map", PointCloud2): ROSTransport("global_map", PointCloud2), + ("odom", PoseStamped): ROSTransport("odom", PoseStamped), + ("color_image", Image): ROSTransport("color_image", Image), + } +) + +detection = ( + autoconnect( + nav, + detection3d_module( + camera_info=GO2Connection.camera_info_static, + ), + ) + .remappings( + [ + (Detection3DModule, "pointcloud", "global_map"), + ] + ) + .transports( + { + # Detection 3D module outputs + ("detections", Detection3DModule): LCMTransport( + "/detector3d/detections", Detection2DArray + ), + ("annotations", Detection3DModule): LCMTransport( + "/detector3d/annotations", ImageAnnotations + ), + ("scene_update", Detection3DModule): LCMTransport( + "/detector3d/scene_update", SceneUpdate + ), + ("detected_pointcloud_0", Detection3DModule): LCMTransport( + "/detector3d/pointcloud/0", PointCloud2 + ), + ("detected_pointcloud_1", Detection3DModule): LCMTransport( + "/detector3d/pointcloud/1", PointCloud2 + ), + ("detected_pointcloud_2", Detection3DModule): LCMTransport( + "/detector3d/pointcloud/2", PointCloud2 + ), + ("detected_image_0", Detection3DModule): LCMTransport("/detector3d/image/0", Image), + ("detected_image_1", Detection3DModule): LCMTransport("/detector3d/image/1", Image), + ("detected_image_2", Detection3DModule): LCMTransport("/detector3d/image/2", Image), + } + ) +) + + +spatial = autoconnect( + nav, + spatial_memory(), + utilization(), +).global_config(n_dask_workers=8) + +with_jpeglcm = nav.transports( + { + ("color_image", Image): JpegLcmTransport("/color_image", Image), + } +) + +with_jpegshm = autoconnect( + nav.transports( + { + ("color_image", Image): JpegShmTransport("/color_image", quality=75), + } + ), + foxglove_bridge( + jpeg_shm_channels=[ + "/color_image#sensor_msgs.Image", + ] + ), +) + +_common_agentic = autoconnect( + human_input(), + navigation_skill(), + person_follow_skill(camera_info=GO2Connection.camera_info_static), + unitree_skills(), + web_input(), + speak_skill(), +) + +agentic = autoconnect( + spatial, + llm_agent(), + _common_agentic, +) + +agentic_mcp = autoconnect( + agentic, + MCPModule.blueprint(), +) + +agentic_ollama = autoconnect( + spatial, + llm_agent( + model="qwen3:8b", + provider=Provider.OLLAMA, # type: ignore[attr-defined] + ), + _common_agentic, +).requirements( + ollama_installed, +) + +agentic_huggingface = autoconnect( + spatial, + llm_agent( + model="Qwen/Qwen2.5-1.5B-Instruct", + provider=Provider.HUGGINGFACE, # type: ignore[attr-defined] + ), + _common_agentic, +) + +vlm_stream_test = autoconnect( + basic, + vlm_agent(), + vlm_stream_tester(), +) + +temporal_memory = autoconnect( + agentic, + temporal_memory(), +) diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index afd5c25ed6..38da7fb439 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -13,42 +13,52 @@ # limitations under the License. import logging +import sys from threading import Thread import time from typing import TYPE_CHECKING, Any, Protocol +from pydantic import Field from reactivex.disposable import Disposable from reactivex.observable import Observable import rerun.blueprint as rrb -from dimos import spec from dimos.agents.annotation import skill from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module +from dimos.core.global_config import GlobalConfig +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, pSHMTransport +from dimos.spec.perception import Camera, Pointcloud if TYPE_CHECKING: from dimos.core.rpc_client import ModuleProxy -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Quaternion, - Transform, - Twist, - Vector3, -) -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 -from dimos.msgs.sensor_msgs.Image import ImageFormat +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.connection import UnitreeWebRTCConnection from dimos.utils.data import get_data from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.testing.replay import TimedSensorReplay, TimedSensorStorage +if sys.version_info < (3, 13): + from typing_extensions import TypeVar +else: + from typing import TypeVar + logger = logging.getLogger(__name__) +class ConnectionConfig(ModuleConfig): + ip: str = Field(default_factory=lambda m: m["g"].robot_ip) + + class Go2ConnectionProtocol(Protocol): """Protocol defining the interface for Go2 robot connections.""" @@ -170,7 +180,12 @@ def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-de return {"status": "ok", "message": "Fake publish"} -class GO2Connection(Module, spec.Camera, spec.Pointcloud): +_Config = TypeVar("_Config", bound=ConnectionConfig, default=ConnectionConfig) + + +class GO2Connection(Module[_Config], Camera, Pointcloud): + default_config = ConnectionConfig # type: ignore[assignment] + cmd_vel: In[Twist] pointcloud: Out[PointCloud2] odom: Out[PoseStamped] @@ -180,7 +195,6 @@ class GO2Connection(Module, spec.Camera, spec.Pointcloud): connection: Go2ConnectionProtocol camera_info_static: CameraInfo = _camera_info_static() - _global_config: GlobalConfig _camera_info_thread: Thread | None = None _latest_video_frame: Image | None = None @@ -194,23 +208,13 @@ def rerun_views(cls): # type: ignore[no-untyped-def] ), ] - def __init__( # type: ignore[no-untyped-def] - self, - ip: str | None = None, - cfg: GlobalConfig = global_config, - *args, - **kwargs, - ) -> None: - self._global_config = cfg - - ip = ip if ip is not None else self._global_config.robot_ip - self.connection = make_connection(ip, self._global_config) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.connection = make_connection(self.config.ip, self.config.g) if hasattr(self.connection, "camera_info_static"): self.camera_info_static = self.connection.camera_info_static - Module.__init__(self, *args, **kwargs) - @rpc def record(self, recording_name: str) -> None: lidar_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/lidar") # type: ignore[type-arg] @@ -246,7 +250,7 @@ def onimage(image: Image) -> None: self.standup() time.sleep(3) self.connection.balance_stand() - self.connection.set_obstacle_avoidance(self._global_config.obstacle_avoidance) + self.connection.set_obstacle_avoidance(self.config.g.obstacle_avoidance) # self.record("go2_bigoffice") @@ -339,7 +343,7 @@ def observe(self) -> Image | None: def deploy(dimos: ModuleCoordinator, ip: str, prefix: str = "") -> "ModuleProxy": from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE - connection = dimos.deploy(GO2Connection, ip) # type: ignore[attr-defined] + connection = dimos.deploy(GO2Connection, ip=ip) connection.pointcloud.transport = pSHMTransport( f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE diff --git a/dimos/robot/unitree/go2/fleet_connection.py b/dimos/robot/unitree/go2/fleet_connection.py index 4dd2be2984..f0e904648a 100644 --- a/dimos/robot/unitree/go2/fleet_connection.py +++ b/dimos/robot/unitree/go2/fleet_connection.py @@ -16,52 +16,62 @@ from __future__ import annotations +from collections.abc import Sequence +import sys from typing import TYPE_CHECKING, Any +from pydantic import Field, model_validator + from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config from dimos.robot.unitree.go2.connection import ( + ConnectionConfig, GO2Connection, Go2ConnectionProtocol, make_connection, ) from dimos.utils.logging_config import setup_logger +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing import Any as Self + if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import Twist + from dimos.msgs.geometry_msgs.Twist import Twist logger = setup_logger() -class Go2FleetConnection(GO2Connection): +class FleetConnectionConfig(ConnectionConfig): + ips: Sequence[str] = Field( + default_factory=lambda m: [ip.strip() for ip in m["g"].robot_ips.split(",")] + ) + + @model_validator(mode="after") + def set_ip_after_validation(self) -> Self: + if self.ip is None: + self.ip = self.ips[0] + return self + + +class Go2FleetConnection(GO2Connection[FleetConnectionConfig]): """Inherits all single-robot behaviour from GO2Connection for the primary (first) robot. Additional robots only receive broadcast commands (move, standup, liedown, publish_request). """ - def __init__( - self, - ips: list[str] | None = None, - cfg: GlobalConfig = global_config, - *args: object, - **kwargs: object, - ) -> None: - if not ips: - raw = cfg.robot_ips - if not raw: - raise ValueError( - "No IPs provided. Pass ips= or set ROBOT_IPS (e.g. ROBOT_IPS=10.0.0.102,10.0.0.209)" - ) - ips = [ip.strip() for ip in raw.split(",") if ip.strip()] - self._extra_ips = ips[1:] + default_config = FleetConnectionConfig + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._extra_ips = self.config.ips[1:] self._extra_connections: list[Go2ConnectionProtocol] = [] - super().__init__(ips[0], cfg, *args, **kwargs) @rpc def start(self) -> None: self._extra_connections.clear() for ip in self._extra_ips: - conn = make_connection(ip, self._global_config) + conn = make_connection(ip, self.config.g) conn.start() self._extra_connections.append(conn) @@ -69,7 +79,7 @@ def start(self) -> None: super().start() for conn in self._extra_connections: conn.balance_stand() - conn.set_obstacle_avoidance(self._global_config.obstacle_avoidance) + conn.set_obstacle_avoidance(self.config.g.obstacle_avoidance) @rpc def stop(self) -> None: diff --git a/dimos/robot/unitree/keyboard_teleop.py b/dimos/robot/unitree/keyboard_teleop.py index 14be8432e5..86885bc446 100644 --- a/dimos/robot/unitree/keyboard_teleop.py +++ b/dimos/robot/unitree/keyboard_teleop.py @@ -15,13 +15,15 @@ import os import threading +from typing import Any import pygame from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 # Force X11 driver to avoid OpenGL threading issues os.environ["SDL_VIDEODRIVER"] = "x11" @@ -42,8 +44,8 @@ class KeyboardTeleop(Module): _clock: pygame.time.Clock | None = None _font: pygame.font.Font | None = None - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._stop_event = threading.Event() @rpc diff --git a/dimos/robot/unitree/modular/detect.py b/dimos/robot/unitree/modular/detect.py index 99faddc946..d6ed78d101 100644 --- a/dimos/robot/unitree/modular/detect.py +++ b/dimos/robot/unitree/modular/detect.py @@ -16,8 +16,9 @@ from dimos_lcm.sensor_msgs import CameraInfo -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.std_msgs import Header +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.std_msgs.Header import Header from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry @@ -71,8 +72,10 @@ def camera_info() -> CameraInfo: def transform_chain(odom_frame: Odometry) -> list: # type: ignore[type-arg] - from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 - from dimos.protocol.tf import TF + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.geometry_msgs.Transform import Transform + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + from dimos.protocol.tf.tf import TF camera_link = Transform( translation=Vector3(0.3, 0.0, 0.0), @@ -113,7 +116,7 @@ def broadcast( # type: ignore[no-untyped-def] ) from dimos.core.transport import LCMTransport - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped lidar_transport = LCMTransport("/lidar", PointCloud2) # type: ignore[var-annotated] odom_transport = LCMTransport("/odom", PoseStamped) # type: ignore[var-annotated] @@ -136,14 +139,14 @@ def broadcast( # type: ignore[no-untyped-def] def process_data(): # type: ignore[no-untyped-def] - from dimos.msgs.sensor_msgs import Image + from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.module2D import ( # type: ignore[attr-defined] Detection2DModule, build_imageannotations, ) from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data - from dimos.utils.testing import TimedSensorReplay + from dimos.utils.testing.replay import TimedSensorReplay get_data("unitree_office_walk") target = 1751591272.9654856 diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index 36673ecb3e..d7c98cffd3 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -35,8 +35,12 @@ from reactivex.disposable import Disposable from dimos.core.global_config import GlobalConfig -from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image, ImageFormat, PointCloud2 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.odometry import Odometry from dimos.simulation.mujoco.constants import ( LAUNCHER_PATH, @@ -126,6 +130,7 @@ def start(self) -> None: self.process = subprocess.Popen( [executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], + stderr=subprocess.PIPE, ) except Exception as e: diff --git a/dimos/robot/unitree/rosnav.py b/dimos/robot/unitree/rosnav.py index adc97eb4a2..b2fe42fde5 100644 --- a/dimos/robot/unitree/rosnav.py +++ b/dimos/robot/unitree/rosnav.py @@ -19,8 +19,8 @@ from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Joy +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Joy import Joy from dimos.msgs.std_msgs.Bool import Bool from dimos.utils.logging_config import setup_logger @@ -33,11 +33,7 @@ class NavigationModule(Module): goal_reached: In[Bool] cancel_goal: Out[Bool] joy: Out[Joy] - - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - """Initialize NavigationModule.""" - Module.__init__(self, *args, **kwargs) - self.goal_reach = None + goal_reach = None @rpc def start(self) -> None: diff --git a/dimos/robot/unitree/testing/__init__.py b/dimos/robot/unitree/testing/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/unitree/testing/mock.py b/dimos/robot/unitree/testing/mock.py index 26e6a90018..4c5e52e4b0 100644 --- a/dimos/robot/unitree/testing/mock.py +++ b/dimos/robot/unitree/testing/mock.py @@ -21,7 +21,7 @@ from reactivex import from_iterable, interval, operators as ops from reactivex.observable import Observable -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import RawLidarMsg, pointcloud2_from_webrtc_lidar diff --git a/dimos/robot/unitree/testing/test_actors.py b/dimos/robot/unitree/testing/test_actors.py index ed0b05d664..77c3d7c56f 100644 --- a/dimos/robot/unitree/testing/test_actors.py +++ b/dimos/robot/unitree/testing/test_actors.py @@ -20,7 +20,7 @@ from dimos.core.module import Module from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.map import Map as Mapper diff --git a/dimos/robot/unitree/testing/test_tooling.py b/dimos/robot/unitree/testing/test_tooling.py index d1f2eeb169..40db01feee 100644 --- a/dimos/robot/unitree/testing/test_tooling.py +++ b/dimos/robot/unitree/testing/test_tooling.py @@ -19,7 +19,7 @@ from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.reactive import backpressure -from dimos.utils.testing import TimedSensorReplay +from dimos.utils.testing.replay import TimedSensorReplay @pytest.mark.tool diff --git a/dimos/robot/unitree/type/__init__.py b/dimos/robot/unitree/type/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/unitree/type/lidar.py b/dimos/robot/unitree/type/lidar.py index df2909dc38..f58268d442 100644 --- a/dimos/robot/unitree/type/lidar.py +++ b/dimos/robot/unitree/type/lidar.py @@ -20,7 +20,7 @@ import numpy as np import open3d as o3d # type: ignore[import-untyped] -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 # Backwards compatibility alias for pickled data LidarMessage = PointCloud2 diff --git a/dimos/robot/unitree/type/map.py b/dimos/robot/unitree/type/map.py index 95b2bf6f6b..da45c003f7 100644 --- a/dimos/robot/unitree/type/map.py +++ b/dimos/robot/unitree/type/map.py @@ -21,52 +21,47 @@ from reactivex.disposable import Disposable from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport from dimos.mapping.pointclouds.accumulators.general import GeneralPointCloudAccumulator from dimos.mapping.pointclouds.accumulators.protocol import PointCloudAccumulator from dimos.mapping.pointclouds.occupancy import general_occupancy -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.go2.connection import Go2ConnectionProtocol -class Map(Module): +class MapConfig(ModuleConfig): + voxel_size: float = 0.05 + cost_resolution: float = 0.05 + global_publish_interval: float | None = None + min_height: float = 0.10 + max_height: float = 0.5 + + +class Map(Module[MapConfig]): + default_config = MapConfig + lidar: In[PointCloud2] global_map: Out[PointCloud2] global_costmap: Out[OccupancyGrid] _point_cloud_accumulator: PointCloudAccumulator - _global_config: GlobalConfig _preloaded_occupancy: OccupancyGrid | None = None - def __init__( # type: ignore[no-untyped-def] - self, - voxel_size: float = 0.05, - cost_resolution: float = 0.05, - global_publish_interval: float | None = None, - min_height: float = 0.10, - max_height: float = 0.5, - cfg: GlobalConfig = global_config, - **kwargs, - ) -> None: - self.voxel_size = voxel_size - self.cost_resolution = cost_resolution - self.global_publish_interval = global_publish_interval - self.min_height = min_height - self.max_height = max_height - self._global_config = cfg - self._point_cloud_accumulator = GeneralPointCloudAccumulator( - self.voxel_size, self._global_config - ) - - if self._global_config.simulation: - self.min_height = 0.3 - + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) + self.voxel_size = self.config.voxel_size + self.cost_resolution = self.config.cost_resolution + self.global_publish_interval = self.config.global_publish_interval + self.min_height = self.config.min_height + self.max_height = self.config.max_height + self._point_cloud_accumulator = GeneralPointCloudAccumulator(self.voxel_size, self.config.g) + + if self.config.g.simulation: + self.min_height = 0.3 @rpc def start(self) -> None: @@ -108,9 +103,9 @@ def _publish(self, _: Any) -> None: ) # When debugging occupancy navigation, load a predefined occupancy grid. - if self._global_config.mujoco_global_costmap_from_occupancy: + if self.config.g.mujoco_global_costmap_from_occupancy: if self._preloaded_occupancy is None: - path = Path(self._global_config.mujoco_global_costmap_from_occupancy) + path = Path(self.config.g.mujoco_global_costmap_from_occupancy) self._preloaded_occupancy = OccupancyGrid.from_path(path) occupancygrid = self._preloaded_occupancy diff --git a/dimos/robot/unitree/type/odometry.py b/dimos/robot/unitree/type/odometry.py index aa664b32ef..fabf800b6c 100644 --- a/dimos/robot/unitree/type/odometry.py +++ b/dimos/robot/unitree/type/odometry.py @@ -13,7 +13,9 @@ # limitations under the License. from typing import Literal, TypedDict -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.robot.unitree.type.timeseries import ( Timestamped, ) diff --git a/dimos/robot/unitree/type/test_lidar.py b/dimos/robot/unitree/type/test_lidar.py index 719088d77a..9a743d65b5 100644 --- a/dimos/robot/unitree/type/test_lidar.py +++ b/dimos/robot/unitree/type/test_lidar.py @@ -16,9 +16,9 @@ import itertools from typing import cast -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import RawLidarMsg, pointcloud2_from_webrtc_lidar -from dimos.utils.testing import SensorReplay +from dimos.utils.testing.replay import SensorReplay def test_init() -> None: diff --git a/dimos/robot/unitree/type/test_odometry.py b/dimos/robot/unitree/type/test_odometry.py index d0fe2b290e..8020684fb7 100644 --- a/dimos/robot/unitree/type/test_odometry.py +++ b/dimos/robot/unitree/type/test_odometry.py @@ -17,7 +17,7 @@ import pytest from dimos.robot.unitree.type.odometry import Odometry -from dimos.utils.testing import SensorReplay +from dimos.utils.testing.replay import SensorReplay _EXPECTED_TOTAL_RAD = -4.05212 diff --git a/dimos/robot/unitree/unitree_skill_container.py b/dimos/robot/unitree/unitree_skill_container.py index d2f15b9efe..a79c061567 100644 --- a/dimos/robot/unitree/unitree_skill_container.py +++ b/dimos/robot/unitree/unitree_skill_container.py @@ -24,7 +24,9 @@ from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.navigation.base import NavigationState from dimos.utils.logging_config import setup_logger diff --git a/dimos/robot/unitree_webrtc/type/__init__.py b/dimos/robot/unitree_webrtc/type/__init__.py deleted file mode 100644 index 03ff4f4563..0000000000 --- a/dimos/robot/unitree_webrtc/type/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Compatibility re-exports for legacy dimos.robot.unitree_webrtc.type.* imports.""" - -import importlib - -__all__ = [] - - -def __getattr__(name: str): # type: ignore[no-untyped-def] - module = importlib.import_module("dimos.robot.unitree.type") - try: - return getattr(module, name) - except AttributeError as exc: - raise AttributeError(f"No {__name__} attribute {name}") from exc - - -def __dir__() -> list[str]: - module = importlib.import_module("dimos.robot.unitree.type") - return [name for name in dir(module) if not name.startswith("_")] diff --git a/dimos/rxpy_backpressure/__init__.py b/dimos/rxpy_backpressure/__init__.py deleted file mode 100644 index ff3b1f37c0..0000000000 --- a/dimos/rxpy_backpressure/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from dimos.rxpy_backpressure.backpressure import BackPressure - -__all__ = [BackPressure] diff --git a/dimos/simulation/__init__.py b/dimos/simulation/__init__.py deleted file mode 100644 index 1a68191a36..0000000000 --- a/dimos/simulation/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Try to import Isaac Sim components -try: - from .isaac import IsaacSimulator, IsaacStream -except ImportError: - IsaacSimulator = None # type: ignore[assignment, misc] - IsaacStream = None # type: ignore[assignment, misc] - -# Try to import Genesis components -try: - from .genesis import GenesisSimulator, GenesisStream -except ImportError: - GenesisSimulator = None # type: ignore[assignment, misc] - GenesisStream = None # type: ignore[assignment, misc] - -__all__ = ["GenesisSimulator", "GenesisStream", "IsaacSimulator", "IsaacStream"] diff --git a/dimos/simulation/base/__init__.py b/dimos/simulation/base/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/simulation/engines/__init__.py b/dimos/simulation/engines/__init__.py deleted file mode 100644 index d437f9a7cd..0000000000 --- a/dimos/simulation/engines/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Simulation engines for manipulator backends.""" - -from __future__ import annotations - -from typing import Literal - -from dimos.simulation.engines.base import SimulationEngine -from dimos.simulation.engines.mujoco_engine import MujocoEngine - -EngineType = Literal["mujoco"] - -_ENGINES: dict[EngineType, type[SimulationEngine]] = { - "mujoco": MujocoEngine, -} - - -def get_engine(engine_name: EngineType) -> type[SimulationEngine]: - return _ENGINES[engine_name] - - -__all__ = [ - "EngineType", - "SimulationEngine", - "get_engine", -] diff --git a/dimos/simulation/engines/base.py b/dimos/simulation/engines/base.py index d450614c62..58e76ecba6 100644 --- a/dimos/simulation/engines/base.py +++ b/dimos/simulation/engines/base.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from pathlib import Path - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.sensor_msgs.JointState import JointState class SimulationEngine(ABC): diff --git a/dimos/simulation/engines/mujoco_engine.py b/dimos/simulation/engines/mujoco_engine.py index ddaaa25ad3..2d1cdf92ac 100644 --- a/dimos/simulation/engines/mujoco_engine.py +++ b/dimos/simulation/engines/mujoco_engine.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: from pathlib import Path - from dimos.msgs.sensor_msgs import JointState + from dimos.msgs.sensor_msgs.JointState import JointState logger = setup_logger() diff --git a/dimos/msgs/visualization_msgs/__init__.py b/dimos/simulation/engines/registry.py similarity index 56% rename from dimos/msgs/visualization_msgs/__init__.py rename to dimos/simulation/engines/registry.py index 0df5006c76..deadf3a404 100644 --- a/dimos/msgs/visualization_msgs/__init__.py +++ b/dimos/simulation/engines/registry.py @@ -12,8 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Visualization message types.""" +"""Simulation engine registry.""" -from dimos.msgs.visualization_msgs.EntityMarkers import EntityMarkers +from __future__ import annotations -__all__ = ["EntityMarkers"] +from typing import Literal + +from dimos.simulation.engines.base import SimulationEngine +from dimos.simulation.engines.mujoco_engine import MujocoEngine + +EngineType = Literal["mujoco"] + +_ENGINES: dict[EngineType, type[SimulationEngine]] = { + "mujoco": MujocoEngine, +} + + +def get_engine(engine_name: EngineType) -> type[SimulationEngine]: + return _ENGINES[engine_name] diff --git a/dimos/simulation/genesis/__init__.py b/dimos/simulation/genesis/__init__.py deleted file mode 100644 index 5657d9167b..0000000000 --- a/dimos/simulation/genesis/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .simulator import GenesisSimulator -from .stream import GenesisStream - -__all__ = ["GenesisSimulator", "GenesisStream"] diff --git a/dimos/simulation/isaac/__init__.py b/dimos/simulation/isaac/__init__.py deleted file mode 100644 index 2b9bdc082d..0000000000 --- a/dimos/simulation/isaac/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .simulator import IsaacSimulator -from .stream import IsaacStream - -__all__ = ["IsaacSimulator", "IsaacStream"] diff --git a/dimos/simulation/manipulators/__init__.py b/dimos/simulation/manipulators/__init__.py deleted file mode 100644 index 816de0a18d..0000000000 --- a/dimos/simulation/manipulators/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Simulation manipulator utilities.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface - from dimos.simulation.manipulators.sim_module import ( - SimulationModule, - SimulationModuleConfig, - simulation, - ) - -__all__ = [ - "SimManipInterface", - "SimulationModule", - "SimulationModuleConfig", - "simulation", -] - - -def __getattr__(name: str): # type: ignore[no-untyped-def] - if name == "SimManipInterface": - from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface - - return SimManipInterface - if name in {"SimulationModule", "SimulationModuleConfig", "simulation"}: - from dimos.simulation.manipulators.sim_module import ( - SimulationModule, - SimulationModuleConfig, - simulation, - ) - - return { - "SimulationModule": SimulationModule, - "SimulationModuleConfig": SimulationModuleConfig, - "simulation": simulation, - }[name] - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/dimos/simulation/manipulators/sim_manip_interface.py b/dimos/simulation/manipulators/sim_manip_interface.py index c829f0c864..6de570ae15 100644 --- a/dimos/simulation/manipulators/sim_manip_interface.py +++ b/dimos/simulation/manipulators/sim_manip_interface.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING from dimos.hardware.manipulators.spec import ControlMode, JointLimits, ManipulatorInfo -from dimos.msgs.sensor_msgs import JointState +from dimos.msgs.sensor_msgs.JointState import JointState if TYPE_CHECKING: from dimos.simulation.engines.base import SimulationEngine diff --git a/dimos/simulation/manipulators/sim_module.py b/dimos/simulation/manipulators/sim_module.py index 831ea6ee34..5e873ba634 100644 --- a/dimos/simulation/manipulators/sim_module.py +++ b/dimos/simulation/manipulators/sim_module.py @@ -15,7 +15,6 @@ """Simulator-agnostic manipulator simulation module.""" from collections.abc import Callable -from dataclasses import dataclass from pathlib import Path import threading import time @@ -26,12 +25,13 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.sensor_msgs import JointCommand, JointState, RobotState -from dimos.simulation.engines import EngineType, get_engine +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState +from dimos.simulation.engines.registry import EngineType, get_engine from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface -@dataclass(kw_only=True) class SimulationModuleConfig(ModuleConfig): engine: EngineType config_path: Path | Callable[[], Path] @@ -42,7 +42,6 @@ class SimulationModule(Module[SimulationModuleConfig]): """Module wrapper for manipulator simulation across engines.""" default_config = SimulationModuleConfig - config: SimulationModuleConfig joint_state: Out[JointState] robot_state: Out[RobotState] @@ -51,8 +50,8 @@ class SimulationModule(Module[SimulationModuleConfig]): MIN_CONTROL_RATE = 1.0 - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._backend: SimManipInterface | None = None self._control_rate = 100.0 self._monitor_rate = 100.0 diff --git a/dimos/simulation/manipulators/test_sim_module.py b/dimos/simulation/manipulators/test_sim_module.py index 334e2ce85f..951d4790e3 100644 --- a/dimos/simulation/manipulators/test_sim_module.py +++ b/dimos/simulation/manipulators/test_sim_module.py @@ -17,10 +17,11 @@ import pytest +from dimos.protocol.rpc.spec import RPCSpec from dimos.simulation.manipulators.sim_module import SimulationModule -class _DummyRPC: +class _DummyRPC(RPCSpec): def serve_module_rpc(self, _module) -> None: # type: ignore[no-untyped-def] return None diff --git a/dimos/simulation/mujoco/mujoco_process.py b/dimos/simulation/mujoco/mujoco_process.py index 21baec473f..2644dddd36 100755 --- a/dimos/simulation/mujoco/mujoco_process.py +++ b/dimos/simulation/mujoco/mujoco_process.py @@ -29,7 +29,7 @@ import open3d as o3d # type: ignore[import-untyped] from dimos.core.global_config import GlobalConfig -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.simulation.mujoco.constants import ( DEPTH_CAMERA_FOV, LIDAR_FPS, diff --git a/dimos/simulation/mujoco/person_on_track.py b/dimos/simulation/mujoco/person_on_track.py index a816b5f3ee..f19b49e4c6 100644 --- a/dimos/simulation/mujoco/person_on_track.py +++ b/dimos/simulation/mujoco/person_on_track.py @@ -19,7 +19,7 @@ from numpy.typing import NDArray from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import Pose +from dimos.msgs.geometry_msgs.Pose import Pose class PersonPositionController: diff --git a/dimos/simulation/mujoco/shared_memory.py b/dimos/simulation/mujoco/shared_memory.py index 6dad60b4b4..f677863edf 100644 --- a/dimos/simulation/mujoco/shared_memory.py +++ b/dimos/simulation/mujoco/shared_memory.py @@ -21,7 +21,7 @@ import numpy as np from numpy.typing import NDArray -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.simulation.mujoco.constants import VIDEO_HEIGHT, VIDEO_WIDTH from dimos.utils.logging_config import setup_logger diff --git a/dimos/simulation/sim_blueprints.py b/dimos/simulation/sim_blueprints.py index 8b91ff817a..494b97ccbf 100644 --- a/dimos/simulation/sim_blueprints.py +++ b/dimos/simulation/sim_blueprints.py @@ -14,12 +14,10 @@ from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import ( # type: ignore[attr-defined] - JointCommand, - JointState, - RobotState, -) -from dimos.msgs.trajectory_msgs import JointTrajectory +from dimos.msgs.sensor_msgs.JointCommand import JointCommand +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.msgs.sensor_msgs.RobotState import RobotState +from dimos.msgs.trajectory_msgs.JointTrajectory import JointTrajectory from dimos.simulation.manipulators.sim_module import simulation from dimos.utils.data import LfsPath diff --git a/dimos/skills/__init__.py b/dimos/skills/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/skills/rest/__init__.py b/dimos/skills/rest/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/skills/skills.py b/dimos/skills/skills.py index 94f8b3726f..1fbf6266ef 100644 --- a/dimos/skills/skills.py +++ b/dimos/skills/skills.py @@ -30,12 +30,8 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -# region SkillLibrary - class SkillLibrary: - # ==== Flat Skill Library ==== - def __init__(self) -> None: self.registered_skills: list[AbstractSkill] = [] self.class_skills: list[AbstractSkill] = [] @@ -111,8 +107,6 @@ def __contains__(self, skill: AbstractSkill) -> bool: def __getitem__(self, index): # type: ignore[no-untyped-def] return self.registered_skills[index] - # ==== Calling a Function ==== - _instances: dict[str, dict] = {} # type: ignore[type-arg] def create_instance(self, name: str, **kwargs) -> None: # type: ignore[no-untyped-def] @@ -154,8 +148,6 @@ def call(self, name: str, **args): # type: ignore[no-untyped-def] logger.error(error_msg) return error_msg - # ==== Tools ==== - def get_tools(self) -> Any: tools_json = self.get_list_of_skills_as_json(list_of_skills=self.registered_skills) # print(f"{Colors.YELLOW_PRINT_COLOR}Tools JSON: {tools_json}{Colors.RESET_COLOR}") @@ -250,11 +242,6 @@ def terminate_skill(self, name: str): # type: ignore[no-untyped-def] return f"No running skill found with name: {name}" -# endregion SkillLibrary - -# region AbstractSkill - - class AbstractSkill(BaseModel): def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] print("Initializing AbstractSkill Class") @@ -289,7 +276,6 @@ def unregister_as_running(self, name: str, skill_library: SkillLibrary) -> None: """ skill_library.unregister_running_skill(name) - # ==== Tools ==== def get_tools(self) -> Any: tools_json = self.get_list_of_skills_as_json(list_of_skills=self._list_of_skills) # print(f"Tools JSON: {tools_json}") @@ -299,10 +285,6 @@ def get_list_of_skills_as_json(self, list_of_skills: list[AbstractSkill]) -> lis return list(map(pydantic_function_tool, list_of_skills)) # type: ignore[arg-type] -# endregion AbstractSkill - -# region Abstract Robot Skill - if TYPE_CHECKING: from dimos.robot.robot import Robot else: @@ -338,6 +320,3 @@ def __call__(self): # type: ignore[no-untyped-def] print( f"{Colors.BLUE_PRINT_COLOR}Robot Instance provided to Robot Skill: {self.__class__.__name__}{Colors.RESET_COLOR}" ) - - -# endregion Abstract Robot Skill diff --git a/dimos/skills/unitree/__init__.py b/dimos/skills/unitree/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/spec/__init__.py b/dimos/spec/__init__.py deleted file mode 100644 index 1423bec9a1..0000000000 --- a/dimos/spec/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from dimos.spec.control import LocalPlanner -from dimos.spec.mapping import GlobalCostmap, GlobalPointcloud -from dimos.spec.nav import Nav -from dimos.spec.perception import Camera, Image, Pointcloud - -__all__ = [ - "Camera", - "GlobalCostmap", - "GlobalPointcloud", - "Image", - "LocalPlanner", - "Nav", - "Pointcloud", -] diff --git a/dimos/spec/control.py b/dimos/spec/control.py index 48d58a926a..b597b4faaf 100644 --- a/dimos/spec/control.py +++ b/dimos/spec/control.py @@ -15,7 +15,7 @@ from typing import Protocol from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist +from dimos.msgs.geometry_msgs.Twist import Twist class LocalPlanner(Protocol): diff --git a/dimos/spec/mapping.py b/dimos/spec/mapping.py index 0ba88cfaa9..f35778f40b 100644 --- a/dimos/spec/mapping.py +++ b/dimos/spec/mapping.py @@ -15,8 +15,8 @@ from typing import Protocol from dimos.core.stream import Out -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 class GlobalPointcloud(Protocol): diff --git a/dimos/spec/nav.py b/dimos/spec/nav.py index 08f6f42b35..ae971e7b5c 100644 --- a/dimos/spec/nav.py +++ b/dimos/spec/nav.py @@ -15,8 +15,9 @@ from typing import Protocol from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Twist -from dimos.msgs.nav_msgs import Path +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.nav_msgs.Path import Path class Nav(Protocol): diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py index 1cfe352390..4fac65ad02 100644 --- a/dimos/spec/perception.py +++ b/dimos/spec/perception.py @@ -16,7 +16,10 @@ from dimos.core.stream import Out from dimos.msgs.nav_msgs.Odometry import Odometry as OdometryMsg -from dimos.msgs.sensor_msgs import CameraInfo, Image as ImageMsg, Imu, PointCloud2 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image as ImageMsg +from dimos.msgs.sensor_msgs.Imu import Imu +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 class Image(Protocol): diff --git a/dimos/stream/__init__.py b/dimos/stream/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/stream/audio/__init__.py b/dimos/stream/audio/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/stream/frame_processor.py b/dimos/stream/frame_processor.py index ab18400c88..c2db47dc23 100644 --- a/dimos/stream/frame_processor.py +++ b/dimos/stream/frame_processor.py @@ -154,8 +154,6 @@ def visualize_flow(self, flow): # type: ignore[no-untyped-def] rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) return rgb - # ============================== - def process_stream_edge_detection(self, frame_stream): # type: ignore[no-untyped-def] return frame_stream.pipe( ops.map(self.edge_detection), diff --git a/dimos/stream/video_providers/__init__.py b/dimos/stream/video_providers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/teleop/__init__.py b/dimos/teleop/__init__.py deleted file mode 100644 index 8324113111..0000000000 --- a/dimos/teleop/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Teleoperation modules for DimOS.""" diff --git a/dimos/teleop/keyboard/__init__.py b/dimos/teleop/keyboard/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/teleop/keyboard/doom_teleop.py b/dimos/teleop/keyboard/doom_teleop.py new file mode 100644 index 0000000000..47d85dcc76 --- /dev/null +++ b/dimos/teleop/keyboard/doom_teleop.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +import os +import threading +from typing import Optional, Set + +import pygame + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Twist, Vector3 +from dimos.msgs.std_msgs.Bool import Bool + +# Match existing teleop modules: force X11 to avoid OpenGL threading issues. +os.environ.setdefault("SDL_VIDEODRIVER", "x11") + + +class DoomTeleop(Module): + """Keyboard + mouse teleoperation in a DOOM/FPS style. + + - Keyboard: W/S for forward/back, A/D for turn left/right, Space for e-stop. + - Mouse: + - MOUSEMOTION controls yaw based on horizontal motion. + - Optional: right-click sends a small forward goal pose, middle-click + rotates in place using a discrete goal (if goal topics are wired). + - Outputs: + - cmd_vel: continuous Twist on the standard /cmd_vel interface. + - goal_pose + cancel_goal: optional discrete pose goals on existing + navigation topics (e.g. /goal_pose, /cancel_goal). + + The module is robot-agnostic; it only publishes Twist / PoseStamped / + Bool messages and relies on existing transports in the blueprint. + """ + + # Continuous velocity interface + cmd_vel: Out[Twist] + + # Optional discrete navigation goal interface + goal_pose: Out[PoseStamped] + cancel_goal: Out[Bool] + odom: In[PoseStamped] + + _stop_event: threading.Event + _thread: threading.Thread | None = None + _screen: pygame.Surface | None = None + _clock: pygame.time.Clock | None = None + _font: pygame.font.Font | None = None + + _keys_held: Set[int] | None = None + _mouse_buttons_held: Set[int] | None = None + _has_focus: bool = True + + _current_pose: Optional[PoseStamped] = None + + # Tunable parameters + _base_linear_speed: float = 0.6 # m/s + _base_angular_speed: float = 1.2 # rad/s + _mouse_yaw_sensitivity: float = 0.003 # rad per pixel + _goal_step_forward: float = 0.5 # m + _goal_step_degrees: float = 20.0 # deg + + def __init__(self) -> None: + super().__init__() + self._stop_event = threading.Event() + + @rpc + def start(self) -> bool: + super().start() + + self._stop_event.clear() + self._keys_held = set() + self._mouse_buttons_held = set() + + # Subscribe to odom if wired, to enable discrete pose goals + if self.odom: + self.odom.subscribe(self._on_odom) + + self._thread = threading.Thread(target=self._pygame_loop, daemon=True) + self._thread.start() + + return True + + @rpc + def stop(self) -> None: + # Publish a final stop twist + stop_twist = Twist() + stop_twist.linear = Vector3(0.0, 0.0, 0.0) + stop_twist.angular = Vector3(0.0, 0.0, 0.0) + self.cmd_vel.publish(stop_twist) + + # Optionally cancel any active goal + if self.cancel_goal: + cancel_msg = Bool(data=True) + self.cancel_goal.publish(cancel_msg) + + self._stop_event.set() + + if self._thread is None: + raise RuntimeError("Cannot stop: thread was never started") + self._thread.join(2) + + super().stop() + + def _on_odom(self, pose: PoseStamped) -> None: + self._current_pose = pose + + def _clear_motion(self) -> None: + """Clear key/mouse state and send a hard stop.""" + if self._keys_held is not None: + self._keys_held.clear() + if self._mouse_buttons_held is not None: + self._mouse_buttons_held.clear() + + stop_twist = Twist() + stop_twist.linear = Vector3(0.0, 0.0, 0.0) + stop_twist.angular = Vector3(0.0, 0.0, 0.0) + self.cmd_vel.publish(stop_twist) + + def _pygame_loop(self) -> None: + if self._keys_held is None or self._mouse_buttons_held is None: + raise RuntimeError("Internal state not initialized") + + pygame.init() + self._screen = pygame.display.set_mode((640, 320), pygame.SWSURFACE) + pygame.display.set_caption("Doom Teleop (WSAD + Mouse)") + self._clock = pygame.time.Clock() + self._font = pygame.font.Font(None, 24) + + # Center the mouse and start with relative motion + pygame.mouse.set_visible(True) + pygame.mouse.get_rel() # reset relative movement + + while not self._stop_event.is_set(): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + self._stop_event.set() + + elif event.type == pygame.KEYDOWN: + self._handle_keydown(event.key) + + elif event.type == pygame.KEYUP: + self._handle_keyup(event.key) + + elif event.type == pygame.MOUSEBUTTONDOWN: + self._mouse_buttons_held.add(event.button) + self._handle_mouse_button_down(event.button) + + elif event.type == pygame.MOUSEBUTTONUP: + self._mouse_buttons_held.discard(event.button) + + elif event.type == pygame.ACTIVEEVENT: + # Lose focus → immediately stop and ignore motion until focus returns. + if getattr(event, "gain", 0) == 0: + self._has_focus = False + self._clear_motion() + else: + self._has_focus = True + + # Compute continuous Twist command + twist = Twist() + twist.linear = Vector3(0.0, 0.0, 0.0) + twist.angular = Vector3(0.0, 0.0, 0.0) + + if self._has_focus: + # Keyboard WSAD mapping (DOOM-style) + if pygame.K_w in self._keys_held: + twist.linear.x += self._base_linear_speed + if pygame.K_s in self._keys_held: + twist.linear.x -= self._base_linear_speed + + # A/D = turn left/right + if pygame.K_a in self._keys_held: + twist.angular.z += self._base_angular_speed + if pygame.K_d in self._keys_held: + twist.angular.z -= self._base_angular_speed + + # Mouse horizontal motion → yaw + dx, _dy = pygame.mouse.get_rel() + twist.angular.z += float(-dx) * self._mouse_yaw_sensitivity + + # Left mouse button acts as a "drive" enable: if held with no WS, + # move forward slowly; if released, rely on keys only. + if 1 in self._mouse_buttons_held and twist.linear.x == 0.0: + twist.linear.x = 0.3 + + # Always publish at a fixed rate, even when zero, so downstream + # modules see that control has stopped. + self.cmd_vel.publish(twist) + + self._update_display(twist) + + if self._clock is None: + raise RuntimeError("_clock not initialized") + self._clock.tick(50) + + pygame.quit() + + def _handle_keydown(self, key: int) -> None: + if self._keys_held is None: + raise RuntimeError("_keys_held not initialized") + + self._keys_held.add(key) + + if key == pygame.K_SPACE: + # Emergency stop: clear all motion and cancel any goal. + self._clear_motion() + if self.cancel_goal: + cancel_msg = Bool(data=True) + self.cancel_goal.publish(cancel_msg) + print("EMERGENCY STOP!") + elif key == pygame.K_ESCAPE: + # ESC quits the teleop module. + self._stop_event.set() + + def _handle_keyup(self, key: int) -> None: + if self._keys_held is None: + raise RuntimeError("_keys_held not initialized") + self._keys_held.discard(key) + + def _handle_mouse_button_down(self, button: int) -> None: + """Map mouse button clicks to optional discrete goals.""" + if self._current_pose is None: + return + + # Right click → small forward step goal + if button == 3 and self.goal_pose: + goal = self._relative_goal( + self._current_pose, + forward=self._goal_step_forward, + yaw_degrees=0.0, + ) + self.goal_pose.publish(goal) + print("Published forward step goal from right click.") + # Middle click → in-place rotation goal + elif button == 2 and self.goal_pose: + goal = self._relative_goal( + self._current_pose, + forward=0.0, + yaw_degrees=self._goal_step_degrees, + ) + self.goal_pose.publish(goal) + print("Published rotate-in-place goal from middle click.") + + @staticmethod + def _relative_goal( + current_pose: PoseStamped, + forward: float, + yaw_degrees: float, + ) -> PoseStamped: + """Generate a new PoseStamped goal in the global frame. + + - forward is measured in the robot's local x direction. + - yaw_degrees is the desired change in yaw at the goal. + """ + local_offset = Vector3(forward, 0.0, 0.0) + global_offset = current_pose.orientation.rotate_vector(local_offset) + goal_position = current_pose.position + global_offset + + current_euler = current_pose.orientation.to_euler() + goal_yaw = current_euler.yaw + math.radians(yaw_degrees) + goal_euler = Vector3(current_euler.roll, current_euler.pitch, goal_yaw) + goal_orientation = Quaternion.from_euler(goal_euler) + + return PoseStamped( + position=goal_position, + orientation=goal_orientation, + frame_id=current_pose.frame_id, + ) + + def _update_display(self, twist: Twist) -> None: + if self._screen is None or self._font is None or self._keys_held is None: + raise RuntimeError("Display not initialized correctly") + + self._screen.fill((20, 20, 20)) + + y = 20 + focus_text = "FOCUSED" if self._has_focus else "OUT OF FOCUS (stopped)" + lines = [ + f"Doom Teleop - {focus_text}", + "", + f"Linear X: {twist.linear.x:+.2f} m/s", + f"Angular Z: {twist.angular.z:+.2f} rad/s", + "", + "Keyboard: W/S = forward/back, A/D = turn", + "Mouse: move = look/turn, LMB = slow forward drive", + "Mouse: RMB = step goal, MMB = rotate goal", + "Space: E-stop (also cancels goal), ESC: quit", + ] + + for text in lines: + color = (0, 255, 255) if text.startswith("Doom Teleop") else (230, 230, 230) + surf = self._font.render(text, True, color) + self._screen.blit(surf, (20, y)) + y += 26 + + # Simple status LED + moving = ( + abs(twist.linear.x) > 1e-3 + or abs(twist.linear.y) > 1e-3 + or abs(twist.angular.z) > 1e-3 + ) + color = (255, 0, 0) if moving else (0, 200, 0) + pygame.draw.circle(self._screen, color, (600, 30), 12) + + pygame.display.flip() + + +doom_teleop = DoomTeleop.blueprint + +__all__ = ["DoomTeleop", "doom_teleop"] + diff --git a/dimos/teleop/keyboard/keyboard_teleop_module.py b/dimos/teleop/keyboard/keyboard_teleop_module.py index cc3c301804..a90dc3cf44 100644 --- a/dimos/teleop/keyboard/keyboard_teleop_module.py +++ b/dimos/teleop/keyboard/keyboard_teleop_module.py @@ -28,7 +28,6 @@ ESC: Quit """ -from dataclasses import dataclass import os import threading import time @@ -45,7 +44,7 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped # Force X11 driver to avoid OpenGL threading issues os.environ["SDL_VIDEODRIVER"] = "x11" @@ -64,7 +63,6 @@ def _clamp(value: float, min_val: float, max_val: float) -> float: return max(min_val, min(max_val, value)) -@dataclass class KeyboardTeleopConfig(ModuleConfig): model_path: str = "" ee_joint_id: int = 6 @@ -84,8 +82,8 @@ class KeyboardTeleopModule(Module[KeyboardTeleopConfig]): _stop_event: threading.Event _thread: threading.Thread | None = None - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._stop_event = threading.Event() @rpc diff --git a/dimos/teleop/phone/__init__.py b/dimos/teleop/phone/__init__.py deleted file mode 100644 index 552032a47b..0000000000 --- a/dimos/teleop/phone/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Phone teleoperation module for DimOS.""" - -from dimos.teleop.phone.phone_extensions import ( - SimplePhoneTeleop, - simple_phone_teleop_module, -) -from dimos.teleop.phone.phone_teleop_module import ( - PhoneTeleopConfig, - PhoneTeleopModule, - phone_teleop_module, -) - -__all__ = [ - "PhoneTeleopConfig", - "PhoneTeleopModule", - "SimplePhoneTeleop", - "phone_teleop_module", - "simple_phone_teleop_module", -] diff --git a/dimos/teleop/phone/phone_extensions.py b/dimos/teleop/phone/phone_extensions.py index 0f52fce2e0..c5cdc1fc80 100644 --- a/dimos/teleop/phone/phone_extensions.py +++ b/dimos/teleop/phone/phone_extensions.py @@ -20,7 +20,9 @@ """ from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.teleop.phone.phone_teleop_module import PhoneTeleopModule diff --git a/dimos/teleop/phone/phone_teleop_module.py b/dimos/teleop/phone/phone_teleop_module.py index 4d40b995f3..3f32063cce 100644 --- a/dimos/teleop/phone/phone_teleop_module.py +++ b/dimos/teleop/phone/phone_teleop_module.py @@ -22,7 +22,6 @@ velocity commands via configurable gains, and publishes. """ -from dataclasses import dataclass from pathlib import Path import threading import time @@ -37,7 +36,9 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.std_msgs.Bool import Bool from dimos.utils.logging_config import setup_logger from dimos.utils.path_utils import get_project_root @@ -48,7 +49,6 @@ STATIC_DIR = Path(__file__).parent / "web" / "static" -@dataclass class PhoneTeleopConfig(ModuleConfig): control_loop_hz: float = 50.0 linear_gain: float = 1.0 / 30.0 # Gain: maps degrees of tilt to m/s. 30 deg -> 1.0 m/s @@ -71,12 +71,8 @@ class PhoneTeleopModule(Module[PhoneTeleopConfig]): # Output: velocity command to robot twist_output: Out[TwistStamped] - # ------------------------------------------------------------------------- - # Initialization - # ------------------------------------------------------------------------- - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._is_engaged: bool = False self._teleop_button: bool = False @@ -100,10 +96,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._setup_routes() - # ------------------------------------------------------------------------- - # Web Server Routes - # ------------------------------------------------------------------------- - def _setup_routes(self) -> None: """Register teleop routes on the embedded web server.""" @@ -135,10 +127,6 @@ async def websocket_endpoint(ws: WebSocket) -> None: except Exception: logger.exception("WebSocket error") - # ------------------------------------------------------------------------- - # Lifecycle - # ------------------------------------------------------------------------- - @rpc def start(self) -> None: super().start() @@ -151,10 +139,6 @@ def stop(self) -> None: self._stop_server() super().stop() - # ------------------------------------------------------------------------- - # Internal engage / disengage (assumes lock is held) - # ------------------------------------------------------------------------- - def _engage(self) -> bool: """Engage: capture current sensors as initial""" if self._current_sensors is None: @@ -171,10 +155,6 @@ def _disengage(self) -> None: self._initial_sensors = None logger.info("Phone teleop disengaged") - # ------------------------------------------------------------------------- - # WebSocket Message Decoders - # ------------------------------------------------------------------------- - def _on_sensors_bytes(self, data: bytes) -> None: """Decode raw LCM bytes into TwistStamped and update sensor state.""" msg = TwistStamped.lcm_decode(data) @@ -187,10 +167,6 @@ def _on_button_bytes(self, data: bytes) -> None: with self._lock: self._teleop_button = bool(msg.data) - # ------------------------------------------------------------------------- - # Embedded Web Server - # ------------------------------------------------------------------------- - def _start_server(self) -> None: """Start the embedded FastAPI server with HTTPS in a daemon thread.""" if self._web_server_thread is not None and self._web_server_thread.is_alive(): @@ -214,10 +190,6 @@ def _stop_server(self) -> None: self._web_server_thread = None logger.info("Phone teleop web server stopped") - # ------------------------------------------------------------------------- - # Control Loop - # ------------------------------------------------------------------------- - def _start_control_loop(self) -> None: if self._control_loop_thread is not None and self._control_loop_thread.is_alive(): return @@ -256,10 +228,6 @@ def _control_loop(self) -> None: if sleep_time > 0: self._stop_event.wait(sleep_time) - # ------------------------------------------------------------------------- - # Control Loop Internal Methods - # ------------------------------------------------------------------------- - def _handle_engage(self) -> None: """ Override to customize engagement logic. diff --git a/dimos/teleop/quest/__init__.py b/dimos/teleop/quest/__init__.py deleted file mode 100644 index 83daf4347b..0000000000 --- a/dimos/teleop/quest/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Quest teleoperation module.""" - -from dimos.teleop.quest.quest_extensions import ( - ArmTeleopModule, - TwistTeleopModule, - VisualizingTeleopModule, - arm_teleop_module, - twist_teleop_module, - visualizing_teleop_module, -) -from dimos.teleop.quest.quest_teleop_module import ( - Hand, - QuestTeleopConfig, - QuestTeleopModule, - QuestTeleopStatus, - quest_teleop_module, -) -from dimos.teleop.quest.quest_types import ( - Buttons, - QuestControllerState, - ThumbstickState, -) - -__all__ = [ - "ArmTeleopModule", - "Buttons", - "Hand", - "QuestControllerState", - "QuestTeleopConfig", - "QuestTeleopModule", - "QuestTeleopStatus", - "ThumbstickState", - "TwistTeleopModule", - "VisualizingTeleopModule", - # Blueprints - "arm_teleop_module", - "quest_teleop_module", - "twist_teleop_module", - "visualizing_teleop_module", -] diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index 5672a2bea0..a3aa54ee08 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -22,14 +22,10 @@ ) from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_extensions import arm_teleop_module, visualizing_teleop_module from dimos.teleop.quest.quest_types import Buttons -# ----------------------------------------------------------------------------- -# Quest Teleop Blueprints -# ----------------------------------------------------------------------------- - # Arm teleop with press-and-hold engage arm_teleop = autoconnect( arm_teleop_module(), @@ -53,10 +49,6 @@ ) -# ----------------------------------------------------------------------------- -# Teleop wired to Coordinator (TeleopIK) -# ----------------------------------------------------------------------------- - # Single XArm7 teleop: right controller -> xarm7 # Usage: dimos run arm-teleop-xarm7 diff --git a/dimos/teleop/quest/quest_extensions.py b/dimos/teleop/quest/quest_extensions.py index 68ec279efb..46e868837d 100644 --- a/dimos/teleop/quest/quest_extensions.py +++ b/dimos/teleop/quest/quest_extensions.py @@ -20,11 +20,13 @@ - VisualizingTeleopModule: Adds Rerun visualization (inherits press-and-hold engage) """ -from dataclasses import dataclass, field from typing import Any +from pydantic import Field + from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped, TwistStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped from dimos.teleop.quest.quest_teleop_module import Hand, QuestTeleopConfig, QuestTeleopModule from dimos.teleop.quest.quest_types import Buttons, QuestControllerState from dimos.teleop.utils.teleop_visualization import ( @@ -33,7 +35,6 @@ ) -@dataclass class TwistTeleopConfig(QuestTeleopConfig): """Configuration for TwistTeleopModule.""" @@ -42,7 +43,7 @@ class TwistTeleopConfig(QuestTeleopConfig): # Example implementation to show how to extend QuestTeleopModule for different teleop behaviors and outputs. -class TwistTeleopModule(QuestTeleopModule): +class TwistTeleopModule(QuestTeleopModule[TwistTeleopConfig]): """Quest teleop that outputs TwistStamped instead of PoseStamped. Config: @@ -56,7 +57,6 @@ class TwistTeleopModule(QuestTeleopModule): """ default_config = TwistTeleopConfig - config: TwistTeleopConfig left_twist: Out[TwistStamped] right_twist: Out[TwistStamped] @@ -75,7 +75,6 @@ def _publish_msg(self, hand: Hand, output_msg: PoseStamped) -> None: self.right_twist.publish(twist) -@dataclass class ArmTeleopConfig(QuestTeleopConfig): """Configuration for ArmTeleopModule. @@ -85,10 +84,10 @@ class ArmTeleopConfig(QuestTeleopConfig): hand's commands to the correct TeleopIKTask. """ - task_names: dict[str, str] = field(default_factory=dict) + task_names: dict[str, str] = Field(default_factory=dict) -class ArmTeleopModule(QuestTeleopModule): +class ArmTeleopModule(QuestTeleopModule[ArmTeleopConfig]): """Quest teleop with per-hand press-and-hold engage and task name routing. Each controller's primary button (X for left, A for right) @@ -105,10 +104,9 @@ class ArmTeleopModule(QuestTeleopModule): """ default_config = ArmTeleopConfig - config: ArmTeleopConfig - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self._task_names: dict[Hand, str] = { Hand[k.upper()]: v for k, v in self.config.task_names.items() diff --git a/dimos/teleop/quest/quest_teleop_module.py b/dimos/teleop/quest/quest_teleop_module.py index f862558424..5868aab620 100644 --- a/dimos/teleop/quest/quest_teleop_module.py +++ b/dimos/teleop/quest/quest_teleop_module.py @@ -26,7 +26,7 @@ from pathlib import Path import threading import time -from typing import Any +from typing import Any, TypeVar from dimos_lcm.geometry_msgs import PoseStamped as LCMPoseStamped from dimos_lcm.sensor_msgs import Joy as LCMJoy @@ -37,8 +37,8 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Joy +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Joy import Joy from dimos.teleop.quest.quest_types import Buttons, QuestControllerState from dimos.teleop.utils.teleop_transforms import webxr_to_robot from dimos.utils.logging_config import setup_logger @@ -68,7 +68,6 @@ class QuestTeleopStatus: buttons: Buttons -@dataclass class QuestTeleopConfig(ModuleConfig): """Configuration for Quest Teleoperation Module.""" @@ -76,7 +75,10 @@ class QuestTeleopConfig(ModuleConfig): server_port: int = 8443 -class QuestTeleopModule(Module[QuestTeleopConfig]): +_Config = TypeVar("_Config", bound=QuestTeleopConfig) + + +class QuestTeleopModule(Module[_Config]): """Quest Teleoperation Module for Meta Quest controllers. Receives controller data from the Quest web app via an embedded WebSocket @@ -89,19 +91,15 @@ class QuestTeleopModule(Module[QuestTeleopConfig]): - buttons: Buttons (button states for both controllers) """ - default_config = QuestTeleopConfig + default_config = QuestTeleopConfig # type: ignore[assignment] # Outputs: delta poses for each controller left_controller_output: Out[PoseStamped] right_controller_output: Out[PoseStamped] buttons: Out[Buttons] - # ------------------------------------------------------------------------- - # Initialization - # ------------------------------------------------------------------------- - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) # Engage state (per-hand) self._is_engaged: dict[Hand, bool] = {Hand.LEFT: False, Hand.RIGHT: False} @@ -129,10 +127,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._setup_routes() - # ------------------------------------------------------------------------- - # Web Server Routes - # ------------------------------------------------------------------------- - def _setup_routes(self) -> None: """Register teleop routes on the embedded web server.""" @@ -164,10 +158,6 @@ async def websocket_endpoint(ws: WebSocket) -> None: except Exception: logger.exception("WebSocket error") - # ------------------------------------------------------------------------- - # Lifecycle - # ------------------------------------------------------------------------- - @rpc def start(self) -> None: super().start() @@ -181,10 +171,6 @@ def stop(self) -> None: self._stop_server() super().stop() - # ------------------------------------------------------------------------- - # Internal engage/disengage (assumes lock is held) - # ------------------------------------------------------------------------- - def _engage(self, hand: Hand | None = None) -> bool: """Engage a hand. Assumes self._lock is held.""" hands = [hand] if hand is not None else list(Hand) @@ -217,10 +203,6 @@ def get_status(self) -> QuestTeleopStatus: buttons=Buttons.from_controllers(left, right), ) - # ------------------------------------------------------------------------- - # WebSocket Message Decoders - # ------------------------------------------------------------------------- - @staticmethod def _resolve_hand(frame_id: str) -> Hand: if frame_id == "left": @@ -251,10 +233,6 @@ def _on_joy_bytes(self, data: bytes) -> None: with self._lock: self._controllers[hand] = controller - # ------------------------------------------------------------------------- - # Embedded Web Server - # ------------------------------------------------------------------------- - def _start_server(self) -> None: """Start the embedded FastAPI server with HTTPS in a daemon thread.""" if self._web_server_thread is not None and self._web_server_thread.is_alive(): @@ -333,10 +311,6 @@ def _control_loop(self) -> None: if sleep_time > 0: self._stop_event.wait(sleep_time) - # ------------------------------------------------------------------------- - # Control Loop Internals - # ------------------------------------------------------------------------- - def _handle_engage(self) -> None: """Check for engage button press and update per-hand engage state. diff --git a/dimos/teleop/quest/quest_types.py b/dimos/teleop/quest/quest_types.py index 7fd991a76c..7e7cfc7620 100644 --- a/dimos/teleop/quest/quest_types.py +++ b/dimos/teleop/quest/quest_types.py @@ -18,8 +18,8 @@ from dataclasses import dataclass, field from typing import ClassVar -from dimos.msgs.sensor_msgs import Joy -from dimos.msgs.std_msgs import UInt32 +from dimos.msgs.sensor_msgs.Joy import Joy +from dimos.msgs.std_msgs.UInt32 import UInt32 @dataclass diff --git a/dimos/teleop/utils/__init__.py b/dimos/teleop/utils/__init__.py deleted file mode 100644 index ae8c375e8f..0000000000 --- a/dimos/teleop/utils/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Teleoperation utilities.""" diff --git a/dimos/teleop/utils/teleop_transforms.py b/dimos/teleop/utils/teleop_transforms.py index 15fd3be120..f1e9e9381d 100644 --- a/dimos/teleop/utils/teleop_transforms.py +++ b/dimos/teleop/utils/teleop_transforms.py @@ -22,7 +22,7 @@ import numpy as np from scipy.spatial.transform import Rotation as R # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.utils.transform_utils import matrix_to_pose, pose_to_matrix if TYPE_CHECKING: diff --git a/dimos/teleop/utils/teleop_visualization.py b/dimos/teleop/utils/teleop_visualization.py index a59b0666ef..5a7acd06e9 100644 --- a/dimos/teleop/utils/teleop_visualization.py +++ b/dimos/teleop/utils/teleop_visualization.py @@ -24,7 +24,7 @@ from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: - from dimos.msgs.geometry_msgs import PoseStamped + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped logger = setup_logger() diff --git a/dimos/perception/experimental/temporal_memory/__init__.py b/dimos/test_no_init_files.py similarity index 50% rename from dimos/perception/experimental/temporal_memory/__init__.py rename to dimos/test_no_init_files.py index 1056e82e8b..39efb7ad24 100644 --- a/dimos/perception/experimental/temporal_memory/__init__.py +++ b/dimos/test_no_init_files.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Temporal memory package.""" +from dimos.constants import DIMOS_PROJECT_ROOT -from .frame_window_accumulator import Frame, FrameWindowAccumulator -from .temporal_memory import TemporalMemory, TemporalMemoryConfig, temporal_memory -from .temporal_state import TemporalState -from .window_analyzer import WindowAnalyzer -__all__ = [ - "Frame", - "FrameWindowAccumulator", - "TemporalMemory", - "TemporalMemoryConfig", - "TemporalState", - "WindowAnalyzer", - "temporal_memory", -] +def test_no_init_files(): + dimos_dir = DIMOS_PROJECT_ROOT / "dimos" + init_files = sorted(dimos_dir.rglob("__init__.py")) + if init_files: + listing = "\n".join(f" - {f.relative_to(dimos_dir)}" for f in init_files) + raise AssertionError( + f"Found __init__.py files in dimos/:\n{listing}\n\n" + "__init__.py files are not allowed because they lead to unnecessary " + "extraneous imports. Everything should be imported straight from the " + "source module." + ) diff --git a/dimos/test_no_sections.py b/dimos/test_no_sections.py new file mode 100644 index 0000000000..9523c0aae2 --- /dev/null +++ b/dimos/test_no_sections.py @@ -0,0 +1,143 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +from dimos.constants import DIMOS_PROJECT_ROOT + +REPO_ROOT = str(DIMOS_PROJECT_ROOT) + +# Matches lines that are purely separator characters (=== or ---) with optional +# whitespace, e.g.: # ============= or # --------------- +SEPARATOR_LINE = re.compile(r"^\s*#\s*[-=]{10,}\s*$") + +# Matches section headers wrapped in separators, e.g.: +# # === My Section === or # ===== My Section ===== +INLINE_SECTION = re.compile(r"^\s*#\s*[-=]{3,}.+[-=]{3,}\s*$") + +# VS Code-style region markers +REGION_MARKER = re.compile(r"^\s*#\s*(region|endregion)\b") + +SCANNED_EXTENSIONS = { + ".py", + ".yml", + ".yaml", +} + +SCANNED_PREFIXES = { + "Dockerfile", +} + +IGNORED_DIRS = { + ".venv", + "venv", + "__pycache__", + "node_modules", + ".git", + "dist", + "build", + ".egg-info", + ".tox", + # third-party vendored code + "gtsam", +} + +# Lines that match section patterns but are actually programmatic / intentional. +# Each entry is (relative_path, line_substring) — if both match, the line is skipped. +WHITELIST = [ + # Sentinel marker used at runtime to detect already-converted Dockerfiles + ("dimos/core/docker_build.py", "DIMOS_SENTINEL"), +] + + +def _should_scan(path: str) -> bool: + basename = os.path.basename(path) + _, ext = os.path.splitext(basename) + if ext in SCANNED_EXTENSIONS: + return True + for prefix in SCANNED_PREFIXES: + if basename.startswith(prefix): + return True + return False + + +def _is_ignored_dir(dirpath: str) -> bool: + parts = dirpath.split(os.sep) + return bool(IGNORED_DIRS.intersection(parts)) + + +def _is_whitelisted(rel_path: str, line: str) -> bool: + for allowed_path, allowed_substr in WHITELIST: + if rel_path == allowed_path and allowed_substr in line: + return True + return False + + +def find_section_markers() -> list[tuple[str, int, str]]: + """Return a list of (file, line_number, line_text) for every section marker.""" + violations: list[tuple[str, int, str]] = [] + + for dirpath, dirnames, filenames in os.walk(REPO_ROOT): + # Prune ignored directories in-place + dirnames[:] = [d for d in dirnames if d not in IGNORED_DIRS] + + if _is_ignored_dir(dirpath): + continue + + rel_dir = os.path.relpath(dirpath, REPO_ROOT) + + for fname in filenames: + full_path = os.path.join(dirpath, fname) + rel_path = os.path.join(rel_dir, fname) + + if not _should_scan(full_path): + continue + + try: + with open(full_path, encoding="utf-8", errors="replace") as f: + for lineno, line in enumerate(f, start=1): + stripped = line.rstrip("\n") + if _is_whitelisted(rel_path, stripped): + continue + if ( + SEPARATOR_LINE.match(stripped) + or INLINE_SECTION.match(stripped) + or REGION_MARKER.match(stripped) + ): + violations.append((rel_path, lineno, stripped)) + except (OSError, UnicodeDecodeError): + continue + + return violations + + +def test_no_section_markers(): + """ + Fail if any file contains section-style comment markers. + + If a file is too complicated to be understood without sections, then the + sections should be files. We don't need "subfiles". + """ + violations = find_section_markers() + if violations: + report_lines = [ + f"Found {len(violations)} section marker(s). " + "If a file is too complicated to be understood without sections, " + 'then the sections should be files. We don\'t need "subfiles".', + "", + ] + for path, lineno, text in violations: + report_lines.append(f" {path}:{lineno}: {text.strip()}") + raise AssertionError("\n".join(report_lines)) diff --git a/dimos/types/ros_polyfill.py b/dimos/types/ros_polyfill.py index 4bad99740d..70140336b8 100644 --- a/dimos/types/ros_polyfill.py +++ b/dimos/types/ros_polyfill.py @@ -15,7 +15,7 @@ try: from geometry_msgs.msg import Vector3 # type: ignore[attr-defined] except ImportError: - from dimos.msgs.geometry_msgs import Vector3 + from dimos.msgs.geometry_msgs.Vector3 import Vector3 try: from geometry_msgs.msg import ( # type: ignore[attr-defined] diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py index 7de82e8f9a..e62b275dfc 100644 --- a/dimos/types/test_timestamped.py +++ b/dimos/types/test_timestamped.py @@ -20,7 +20,7 @@ from reactivex.scheduler import ThreadPoolScheduler from dimos.memory.timeseries.inmemory import InMemoryStore -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.types.timestamped import ( Timestamped, TimestampedBufferCollection, @@ -28,9 +28,9 @@ to_datetime, to_ros_stamp, ) -from dimos.utils import testing from dimos.utils.data import get_data from dimos.utils.reactive import backpressure +from dimos.utils.testing.replay import TimedSensorReplay def test_timestamped_dt_method() -> None: @@ -296,7 +296,7 @@ def spy(image): # sensor reply of raw video frames video_raw = ( - testing.TimedSensorReplay( + TimedSensorReplay( "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() ) .stream(speed) diff --git a/dimos/utils/cli/__init__.py b/dimos/utils/cli/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/utils/cli/agentspy/demo_agentspy.py b/dimos/utils/cli/agentspy/demo_agentspy.py index 5229295038..851229131b 100755 --- a/dimos/utils/cli/agentspy/demo_agentspy.py +++ b/dimos/utils/cli/agentspy/demo_agentspy.py @@ -24,7 +24,7 @@ ToolMessage, ) -from dimos.protocol.pubsub import lcm # type: ignore[attr-defined] +import dimos.protocol.pubsub.impl.lcmpubsub as lcm from dimos.protocol.pubsub.impl.lcmpubsub import PickleLCM diff --git a/dimos/utils/cli/dtop.py b/dimos/utils/cli/dtop.py index fa463c15d6..64529a6bc3 100644 --- a/dimos/utils/cli/dtop.py +++ b/dimos/utils/cli/dtop.py @@ -40,10 +40,6 @@ if TYPE_CHECKING: from collections.abc import Callable -# --------------------------------------------------------------------------- -# Color helpers -# --------------------------------------------------------------------------- - def _heat(ratio: float) -> str: """Map 0..1 ratio to a cyan → yellow → red gradient.""" @@ -96,11 +92,6 @@ def _rel_style(value: float, lo: float, hi: float) -> str: return _heat(min((value - lo) / (hi - lo), 1.0)) -# --------------------------------------------------------------------------- -# Metric formatters (plain strings — color applied separately via _rel_style) -# --------------------------------------------------------------------------- - - def _fmt_pct(v: float) -> str: return f"{v:3.0f}%" @@ -128,11 +119,6 @@ def _fmt_io(v: float) -> str: return f"{v / 1048576:.0f} MB" -# --------------------------------------------------------------------------- -# Metric definitions — add a tuple here to add a new field -# (label, dict_key, format_fn) -# --------------------------------------------------------------------------- - _LINE1: list[tuple[str, str, Callable[[float], str]]] = [ ("CPU", "cpu_percent", _fmt_pct), ("PSS", "pss", _fmt_mem), @@ -162,11 +148,6 @@ def _compute_ranges(data_dicts: list[dict[str, Any]]) -> dict[str, tuple[float, return ranges -# --------------------------------------------------------------------------- -# App -# --------------------------------------------------------------------------- - - class ResourceSpyApp(App[None]): CSS_PATH = "dimos.tcss" @@ -367,10 +348,6 @@ def _make_lines( return [line1, line2] -# --------------------------------------------------------------------------- -# Preview -# --------------------------------------------------------------------------- - _PREVIEW_DATA: dict[str, Any] = { "coordinator": { "cpu_percent": 12.3, diff --git a/dimos/utils/cli/lcmspy/lcmspy.py b/dimos/utils/cli/lcmspy/lcmspy.py index 651e8d551b..5b2d0be4ef 100755 --- a/dimos/utils/cli/lcmspy/lcmspy.py +++ b/dimos/utils/cli/lcmspy/lcmspy.py @@ -13,9 +13,9 @@ # limitations under the License. from collections import deque -from dataclasses import dataclass import threading import time +from typing import Any from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.utils.human import human_bytes @@ -98,20 +98,19 @@ def __str__(self) -> str: return f"topic({self.name})" -@dataclass class LCMSpyConfig(LCMConfig): topic_history_window: float = 60.0 -class LCMSpy(LCMService, Topic): +class LCMSpy(LCMService[LCMSpyConfig], Topic): default_config = LCMSpyConfig topic = dict[str, Topic] graph_log_window: float = 1.0 topic_class: type[Topic] = Topic - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - Topic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] + Topic.__init__(self, name="total", history_window=self.config.topic_history_window) self.topic = {} # type: ignore[assignment] self._topic_lock = threading.Lock() @@ -150,7 +149,6 @@ def update_graphs(self, step_window: float = 1.0) -> None: self.bandwidth_history.append(kbps) -@dataclass class GraphLCMSpyConfig(LCMSpyConfig): graph_log_window: float = 1.0 @@ -162,9 +160,9 @@ class GraphLCMSpy(LCMSpy, GraphTopic): graph_log_stop_event: threading.Event = threading.Event() topic_class: type[Topic] = GraphTopic - def __init__(self, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) # type: ignore[attr-defined] + GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) def start(self) -> None: super().start() diff --git a/dimos/utils/decorators/__init__.py b/dimos/utils/decorators/__init__.py deleted file mode 100644 index d0f91a4939..0000000000 --- a/dimos/utils/decorators/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Decorators and accumulators for rate limiting and other utilities.""" - -from .accumulators import Accumulator, LatestAccumulator, RollingAverageAccumulator -from .decorators import CachedMethod, limit, retry, simple_mcache, ttl_cache - -__all__ = [ - "Accumulator", - "CachedMethod", - "LatestAccumulator", - "RollingAverageAccumulator", - "limit", - "retry", - "simple_mcache", - "ttl_cache", -] diff --git a/dimos/utils/decorators/test_decorators.py b/dimos/utils/decorators/test_decorators.py index 98545a2e37..8923151667 100644 --- a/dimos/utils/decorators/test_decorators.py +++ b/dimos/utils/decorators/test_decorators.py @@ -16,7 +16,8 @@ import pytest -from dimos.utils.decorators import RollingAverageAccumulator, limit, retry, simple_mcache, ttl_cache +from dimos.utils.decorators.accumulators import RollingAverageAccumulator +from dimos.utils.decorators.decorators import limit, retry, simple_mcache, ttl_cache def test_limit() -> None: diff --git a/dimos/utils/demo_image_encoding.py b/dimos/utils/demo_image_encoding.py index 42374029f2..84b91acf79 100644 --- a/dimos/utils/demo_image_encoding.py +++ b/dimos/utils/demo_image_encoding.py @@ -34,7 +34,7 @@ from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import JpegLcmTransport, LCMTransport -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import Image from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.utils.fast_image_generator import random_image diff --git a/dimos/utils/docs/test_doclinks.py b/dimos/utils/docs/test_doclinks.py index 7da6a6281b..a5a50b03e5 100644 --- a/dimos/utils/docs/test_doclinks.py +++ b/dimos/utils/docs/test_doclinks.py @@ -16,7 +16,9 @@ from pathlib import Path -from doclinks import ( +import pytest + +from dimos.utils.docs.doclinks import ( build_doc_index, build_file_index, extract_other_backticks, @@ -27,7 +29,6 @@ score_path_similarity, split_by_ignore_regions, ) -import pytest # Use the actual repo root REPO_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/dimos/utils/reactive.py b/dimos/utils/reactive.py index 4397e0171e..623556d6b7 100644 --- a/dimos/utils/reactive.py +++ b/dimos/utils/reactive.py @@ -24,7 +24,7 @@ from reactivex.observable import Observable from reactivex.scheduler import ThreadPoolScheduler -from dimos.rxpy_backpressure import BackPressure +from dimos.rxpy_backpressure.backpressure import BackPressure from dimos.utils.threadpool import get_scheduler T = TypeVar("T") diff --git a/dimos/utils/simple_controller.py b/dimos/utils/simple_controller.py index f95350552c..c8a6ade19d 100644 --- a/dimos/utils/simple_controller.py +++ b/dimos/utils/simple_controller.py @@ -20,9 +20,6 @@ def normalize_angle(angle: float): # type: ignore[no-untyped-def] return math.atan2(math.sin(angle), math.cos(angle)) -# ---------------------------- -# PID Controller Class -# ---------------------------- class PIDController: def __init__( # type: ignore[no-untyped-def] self, @@ -120,9 +117,6 @@ def _apply_deadband_compensation(self, error): # type: ignore[no-untyped-def] return error -# ---------------------------- -# Visual Servoing Controller Class -# ---------------------------- class VisualServoingController: def __init__(self, distance_pid_params, angle_pid_params) -> None: # type: ignore[no-untyped-def] """ diff --git a/dimos/utils/test_data.py b/dimos/utils/test_data.py index e55c8b20f3..9970fc5912 100644 --- a/dimos/utils/test_data.py +++ b/dimos/utils/test_data.py @@ -132,11 +132,6 @@ def test_pull_dir() -> None: assert sha256 == expected_hash -# ============================================================================ -# LfsPath Tests -# ============================================================================ - - def test_lfs_path_lazy_creation() -> None: """Test that creating LfsPath doesn't trigger download.""" lfs_path = LfsPath("test_data_file") diff --git a/dimos/utils/test_transform_utils.py b/dimos/utils/test_transform_utils.py index 7923124c9f..77852a7bb2 100644 --- a/dimos/utils/test_transform_utils.py +++ b/dimos/utils/test_transform_utils.py @@ -16,7 +16,10 @@ import pytest from scipy.spatial.transform import Rotation as R -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.utils import transform_utils diff --git a/dimos/utils/testing/__init__.py b/dimos/utils/testing/__init__.py deleted file mode 100644 index 568cd3604f..0000000000 --- a/dimos/utils/testing/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "moment": ["Moment", "OutputMoment", "SensorMoment"], - "replay": ["SensorReplay", "TimedSensorReplay", "TimedSensorStorage"], - }, -) diff --git a/dimos/utils/testing/test_moment.py b/dimos/utils/testing/test_moment.py index 75f11d2657..dcca3d7d01 100644 --- a/dimos/utils/testing/test_moment.py +++ b/dimos/utils/testing/test_moment.py @@ -14,9 +14,12 @@ import time from dimos.core.transport import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped, Transform -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 -from dimos.protocol.tf import TF +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.protocol.tf.tf import TF from dimos.robot.unitree.go2 import connection from dimos.utils.data import get_data from dimos.utils.testing.moment import Moment, SensorMoment diff --git a/dimos/utils/testing/test_replay.py b/dimos/utils/testing/test_replay.py index e3020777b4..10ace353f7 100644 --- a/dimos/utils/testing/test_replay.py +++ b/dimos/utils/testing/test_replay.py @@ -16,7 +16,7 @@ from reactivex import operators as ops -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index ed82f6116f..bfd38ce14f 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -16,7 +16,10 @@ import numpy as np from scipy.spatial.transform import Rotation as R # type: ignore[import-untyped] -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 def normalize_angle(angle: float) -> float: diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index 9bba9dd82f..12f998d96d 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -16,11 +16,11 @@ from __future__ import annotations -from dataclasses import dataclass, field +from collections.abc import Callable +from dataclasses import field from functools import lru_cache import time from typing import ( - TYPE_CHECKING, Any, Literal, Protocol, @@ -31,14 +31,18 @@ ) from reactivex.disposable import Disposable +from rerun._baseclasses import Archetype +from rerun.blueprint import Blueprint from toolz import pipe # type: ignore[import-untyped] import typer from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.pubsub.patterns import Glob, pattern_matches +from dimos.protocol.pubsub.spec import SubscribeAllCapable from dimos.utils.logging_config import setup_logger # Message types with large payloads that need rate-limiting. @@ -96,15 +100,7 @@ logger = setup_logger() -if TYPE_CHECKING: - from collections.abc import Callable - - from rerun._baseclasses import Archetype - from rerun.blueprint import Blueprint - - from dimos.protocol.pubsub.spec import SubscribeAllCapable - -BlueprintFactory: TypeAlias = "Callable[[], Blueprint]" +BlueprintFactory: TypeAlias = Callable[[], "Blueprint"] # to_rerun() can return a single archetype or a list of (entity_path, archetype) tuples RerunMulti: TypeAlias = "list[tuple[str, Archetype]]" @@ -113,8 +109,6 @@ def is_rerun_multi(data: Any) -> TypeGuard[RerunMulti]: """Check if data is a list of (entity_path, archetype) tuples.""" - from rerun._baseclasses import Archetype - return ( isinstance(data, list) and bool(data) @@ -167,7 +161,6 @@ def _resolve_viewer_mode() -> ViewerMode: return _BACKEND_TO_MODE.get(global_config.viewer, "native") -@dataclass class Config(ModuleConfig): """Configuration for RerunBridgeModule.""" @@ -190,7 +183,7 @@ class Config(ModuleConfig): blueprint: BlueprintFactory | None = _default_blueprint -class RerunBridgeModule(Module): +class RerunBridgeModule(Module[Config]): """Bridge that logs messages from pubsubs to Rerun. Spawns its own Rerun viewer and subscribes to all topics on each provided @@ -207,7 +200,6 @@ class RerunBridgeModule(Module): """ default_config = Config - config: Config @lru_cache(maxsize=256) def _visual_override_for_entity_path( @@ -218,8 +210,6 @@ def _visual_override_for_entity_path( Chains matching overrides from config, ending with final_convert which handles .to_rerun() or passes through Archetypes. """ - from rerun._baseclasses import Archetype - # find all matching converters for this entity path matches = [ fn diff --git a/dimos/web/__init__.py b/dimos/web/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/web/dimos_interface/__init__.py b/dimos/web/dimos_interface/__init__.py deleted file mode 100644 index 3bdc622cee..0000000000 --- a/dimos/web/dimos_interface/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Dimensional Interface package -""" - -import lazy_loader as lazy - -__getattr__, __dir__, __all__ = lazy.attach( - __name__, - submod_attrs={ - "api.server": ["FastAPIServer"], - }, -) diff --git a/dimos/web/dimos_interface/api/__init__.py b/dimos/web/dimos_interface/api/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/web/websocket_vis/costmap_viz.py b/dimos/web/websocket_vis/costmap_viz.py index 21309c94bc..f24628e6c7 100644 --- a/dimos/web/websocket_vis/costmap_viz.py +++ b/dimos/web/websocket_vis/costmap_viz.py @@ -19,7 +19,7 @@ import numpy as np -from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid class CostmapViz: diff --git a/dimos/web/websocket_vis/path_history.py b/dimos/web/websocket_vis/path_history.py index 39b6be08a3..c69e7e9508 100644 --- a/dimos/web/websocket_vis/path_history.py +++ b/dimos/web/websocket_vis/path_history.py @@ -17,7 +17,7 @@ This is a minimal implementation to support websocket visualization. """ -from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.geometry_msgs.Vector3 import Vector3 class PathHistory: diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 0304e3b77b..5514144570 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -46,14 +46,17 @@ ) from dimos.core.core import rpc -from dimos.core.global_config import GlobalConfig, global_config -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out +from dimos.mapping.models import LatLon from dimos.mapping.occupancy.gradient import gradient from dimos.mapping.occupancy.inflation import simple_inflate -from dimos.mapping.types import LatLon -from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped, Vector3 -from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path from dimos.utils.logging_config import setup_logger from .optimized_costmap import OptimizedCostmapEncoder @@ -64,7 +67,11 @@ _browser_opened = False -class WebsocketVisModule(Module): +class WebsocketConfig(ModuleConfig): + port: int = 7779 + + +class WebsocketVisModule(Module[WebsocketConfig]): """ WebSocket-based visualization module for real-time navigation data. @@ -83,6 +90,8 @@ class WebsocketVisModule(Module): - click_goal: Goal position from user clicks """ + default_config = WebsocketConfig + # LCM inputs odom: In[PoseStamped] gps_location: In[LatLon] @@ -97,12 +106,7 @@ class WebsocketVisModule(Module): cmd_vel: Out[Twist] movecmd_stamped: Out[TwistStamped] - def __init__( - self, - port: int = 7779, - cfg: GlobalConfig = global_config, - **kwargs: Any, - ) -> None: + def __init__(self, **kwargs: Any) -> None: """Initialize the WebSocket visualization module. Args: @@ -110,9 +114,6 @@ def __init__( cfg: Optional global config for viewer settings """ super().__init__(**kwargs) - self._global_config = cfg - - self.port = port self._uvicorn_server_thread: threading.Thread | None = None self.sio: socketio.AsyncServer | None = None self.app = None @@ -127,7 +128,7 @@ def __init__( # Track GPS goal points for visualization self.gps_goal_points: list[dict[str, float]] = [] logger.info( - f"WebSocket visualization module initialized on port {port}, GPS goal tracking enabled" + f"WebSocket visualization module initialized on port {self.config.port}, GPS goal tracking enabled" ) def _start_broadcast_loop(self) -> None: @@ -157,8 +158,8 @@ def start(self) -> None: # Auto-open browser only for rerun-web (dashboard with Rerun iframe + command center) # For rerun and foxglove, users access the command center manually if needed - if self._global_config.viewer == "rerun-web": - url = f"http://localhost:{self.port}/" + if self.config.g.viewer == "rerun-web": + url = f"http://localhost:{self.config.port}/" logger.info(f"Dimensional Command Center: {url}") global _browser_opened @@ -234,7 +235,7 @@ def _create_server(self) -> None: async def serve_index(request): # type: ignore[no-untyped-def] """Serve appropriate HTML based on viewer mode.""" # If running native Rerun, redirect to standalone command center - if self._global_config.viewer != "rerun-web": + if self.config.g.viewer != "rerun-web": return RedirectResponse(url="/command-center") # Otherwise serve full dashboard with Rerun iframe @@ -355,7 +356,7 @@ def _run_uvicorn_server(self) -> None: config = uvicorn.Config( self.app, # type: ignore[arg-type] host="0.0.0.0", - port=self.port, + port=self.config.port, log_level="error", # Reduce verbosity ) self._uvicorn_server = uvicorn.Server(config) diff --git a/docker/navigation/.env.hardware b/docker/navigation/.env.hardware index 234e58545c..fc0e34581e 100644 --- a/docker/navigation/.env.hardware +++ b/docker/navigation/.env.hardware @@ -1,16 +1,8 @@ # Hardware Configuration Environment Variables # Copy this file to .env and customize for your hardware setup -# ============================================ -# NVIDIA GPU Support -# ============================================ -# Set the Docker runtime to nvidia for GPU support (it's runc by default) #DOCKER_RUNTIME=nvidia -# ============================================ -# ROS Configuration -# ============================================ -# ROS domain ID for multi-robot setups ROS_DOMAIN_ID=42 # Robot configuration ('mechanum_drive', 'unitree/unitree_g1', 'unitree/unitree_g1', etc) @@ -21,10 +13,6 @@ ROBOT_CONFIG_PATH=mechanum_drive # This can be found in the unitree app under Device settings or via network scan ROBOT_IP= -# ============================================ -# Mid-360 Lidar Configuration -# ============================================ -# Network interface connected to the lidar (e.g., eth0, enp0s3) # Find with: ip addr show LIDAR_INTERFACE=eth0 @@ -43,24 +31,12 @@ LIDAR_GATEWAY=192.168.1.1 # LIDAR_IP=192.168.123.120 # FOR UNITREE G1 EDU LIDAR_IP=192.168.1.116 -# ============================================ -# Motor Controller Configuration -# ============================================ -# Serial device for motor controller # Check with: ls /dev/ttyACM* or ls /dev/ttyUSB* MOTOR_SERIAL_DEVICE=/dev/ttyACM0 -# ============================================ -# Network Communication (for base station) -# ============================================ -# Enable WiFi buffer optimization for data transmission # Set to true if using wireless base station ENABLE_WIFI_BUFFER=false -# ============================================ -# Unitree Robot Configuration -# ============================================ -# Enable Unitree WebRTC control (for Go2, G1) #USE_UNITREE=true # Unitree robot IP address @@ -69,10 +45,6 @@ UNITREE_IP=192.168.12.1 # Unitree connection method (LocalAP or Ethernet) UNITREE_CONN=LocalAP -# ============================================ -# Navigation Options -# ============================================ -# Enable route planner (FAR planner for goal navigation) USE_ROUTE_PLANNER=false # Enable RViz visualization @@ -83,10 +55,6 @@ USE_RVIZ=false # The system will load: MAP_PATH.pcd for SLAM, MAP_PATH_tomogram.pickle for PCT planner MAP_PATH= -# ============================================ -# Device Group IDs -# ============================================ -# Group ID for /dev/input devices (joystick) # Find with: getent group input | cut -d: -f3 INPUT_GID=995 @@ -94,8 +62,4 @@ INPUT_GID=995 # Find with: getent group dialout | cut -d: -f3 DIALOUT_GID=20 -# ============================================ -# Display Configuration -# ============================================ -# X11 display (usually auto-detected) # DISPLAY=:0 diff --git a/docker/navigation/Dockerfile b/docker/navigation/Dockerfile index fa51fd621c..dc2ce54f39 100644 --- a/docker/navigation/Dockerfile +++ b/docker/navigation/Dockerfile @@ -1,39 +1,23 @@ -# ============================================================================= -# DimOS Navigation Docker Image -# ============================================================================= -# # Multi-stage build for ROS 2 navigation with SLAM support. # Includes both arise_slam and FASTLIO2 - select at runtime via LOCALIZATION_METHOD. -# # Supported configurations: # - ROS distributions: humble, jazzy # - SLAM methods: arise_slam (default), fastlio (set LOCALIZATION_METHOD=fastlio) -# # Build: # ./build.sh --humble # Build for ROS 2 Humble # ./build.sh --jazzy # Build for ROS 2 Jazzy -# # Run: # ./start.sh --hardware --route-planner # Uses arise_slam # LOCALIZATION_METHOD=fastlio ./start.sh --hardware --route-planner # Uses FASTLIO2 -# -# ============================================================================= # Build argument for ROS distribution (default: humble) ARG ROS_DISTRO=humble ARG TARGETARCH -# ----------------------------------------------------------------------------- -# Platform-specific base images # - amd64: Use osrf/ros desktop-full (includes Gazebo, full GUI) -# - arm64: Use ros-base (desktop-full not available for ARM) -# ----------------------------------------------------------------------------- FROM osrf/ros:${ROS_DISTRO}-desktop-full AS base-amd64 FROM ros:${ROS_DISTRO}-ros-base AS base-arm64 -# ----------------------------------------------------------------------------- -# STAGE 1: Build Stage - compile all C++ dependencies -# ----------------------------------------------------------------------------- FROM base-${TARGETARCH} AS builder ARG ROS_DISTRO @@ -200,9 +184,6 @@ RUN /bin/bash -c "source /opt/ros/${ROS_DISTRO}/setup.bash && \ echo 'Building with both arise_slam and FASTLIO2' && \ colcon build --cmake-args -DCMAKE_BUILD_TYPE=Release" -# ----------------------------------------------------------------------------- -# STAGE 2: Runtime Stage - minimal image for running -# ----------------------------------------------------------------------------- ARG ROS_DISTRO ARG TARGETARCH FROM base-${TARGETARCH} AS runtime diff --git a/docker/navigation/docker-compose.dev.yml b/docker/navigation/docker-compose.dev.yml index defbdae846..537e00581d 100644 --- a/docker/navigation/docker-compose.dev.yml +++ b/docker/navigation/docker-compose.dev.yml @@ -1,13 +1,6 @@ -# ============================================================================= -# DEVELOPMENT OVERRIDES - Mount source for live editing -# ============================================================================= -# # Usage: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -# # This file adds development-specific volume mounts for editing ROS configs # without rebuilding the image. -# -# ============================================================================= services: dimos_simulation: diff --git a/docs/capabilities/manipulation/adding_a_custom_arm.md b/docs/capabilities/manipulation/adding_a_custom_arm.md index 2b435a50fe..3e931a7f73 100644 --- a/docs/capabilities/manipulation/adding_a_custom_arm.md +++ b/docs/capabilities/manipulation/adding_a_custom_arm.md @@ -116,9 +116,6 @@ class YourArmAdapter: self._sdk: YourArmSDK | None = None self._control_mode: ControlMode = ControlMode.POSITION - # ========================================================================= - # Connection - # ========================================================================= def connect(self) -> bool: """Connect to hardware. Returns True on success.""" @@ -144,9 +141,6 @@ class YourArmAdapter: """Check if connected.""" return self._sdk is not None and self._sdk.is_alive() - # ========================================================================= - # Info - # ========================================================================= def get_info(self) -> ManipulatorInfo: """Get manipulator info (vendor, model, DOF).""" @@ -173,9 +167,6 @@ class YourArmAdapter: velocity_max=[math.pi] * self._dof, # rad/s ) - # ========================================================================= - # Control Mode - # ========================================================================= def set_control_mode(self, mode: ControlMode) -> bool: """Set control mode. @@ -206,9 +197,6 @@ class YourArmAdapter: """Get current control mode.""" return self._control_mode - # ========================================================================= - # State Reading - # ========================================================================= def read_joint_positions(self) -> list[float]: """Read current joint positions in radians. @@ -262,9 +250,6 @@ class YourArmAdapter: return 0, "" return code, f"YourArm error {code}" - # ========================================================================= - # Motion Control (Joint Space) - # ========================================================================= def write_joint_positions( self, @@ -300,9 +285,6 @@ class YourArmAdapter: return False return self._sdk.emergency_stop() - # ========================================================================= - # Servo Control - # ========================================================================= def write_enable(self, enable: bool) -> bool: """Enable or disable servos.""" @@ -322,10 +304,6 @@ class YourArmAdapter: return False return self._sdk.clear_errors() - # ========================================================================= - # Optional: Cartesian Control - # Return None/False if not supported by your arm. - # ========================================================================= def read_cartesian_position(self) -> dict[str, float] | None: """Read end-effector pose. @@ -343,9 +321,6 @@ class YourArmAdapter: """Command end-effector pose. Return False if not supported.""" return False - # ========================================================================= - # Optional: Gripper - # ========================================================================= def read_gripper_position(self) -> float | None: """Read gripper position in meters. Return None if no gripper.""" @@ -355,9 +330,6 @@ class YourArmAdapter: """Command gripper position in meters. Return False if no gripper.""" return False - # ========================================================================= - # Optional: Force/Torque Sensor - # ========================================================================= def read_force_torque(self) -> list[float] | None: """Read F/T sensor data [fx, fy, fz, tx, ty, tz]. None if no sensor.""" @@ -470,9 +442,6 @@ from dimos.control.coordinator import TaskConfig, control_coordinator from dimos.core.transport import LCMTransport from dimos.msgs.sensor_msgs import JointState -# ============================================================================= -# Coordinator Blueprints -# ============================================================================= # YourArm (6-DOF) — real hardware coordinator_yourarm = control_coordinator( @@ -589,9 +558,6 @@ def _make_yourarm_config( Add this to your `dimos/robot/yourarm/blueprints.py` alongside the coordinator blueprint: ```python -# ============================================================================= -# Planner Blueprints (requires URDF) -# ============================================================================= yourarm_planner = manipulation_module( robots=[_make_yourarm_config("arm", joint_prefix="arm_", coordinator_task="traj_arm")], diff --git a/docs/capabilities/navigation/native/index.md b/docs/capabilities/navigation/native/index.md index a750d3bfba..6a8c5224e9 100644 --- a/docs/capabilities/navigation/native/index.md +++ b/docs/capabilities/navigation/native/index.md @@ -118,7 +118,7 @@ All visualization layers shown together ## Blueprint Composition -The navigation stack is composed in the [`unitree_go2`](/dimos/robot/unitree/go2/blueprints/__init__.py) blueprint: +The navigation stack is composed in the [`unitree_go2`](/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py) blueprint: ```python fold output=assets/go2_blueprint.svg from dimos.core.blueprints import autoconnect diff --git a/docs/usage/blueprints.md b/docs/usage/blueprints.md index ed48670cb4..80a6b24b19 100644 --- a/docs/usage/blueprints.md +++ b/docs/usage/blueprints.md @@ -9,13 +9,16 @@ You create a `Blueprint` from a single module (say `ConnectionModule`) with: ```python session=blueprint-ex1 from dimos.core.blueprints import Blueprint from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig -class ConnectionModule(Module): - def __init__(self, arg1, arg2, kwarg='value') -> None: - super().__init__() +class ConnectionConfig(ModuleConfig): + arg1: int + arg2: str = "value" + +class ConnectionModule(Module[ConnectionConfig]): + default_config = ConnectionConfig -blueprint = Blueprint.create(ConnectionModule, 'arg1', 'arg2', kwarg='value') +blueprint = Blueprint.create(ConnectionModule, arg1=5, arg2="foo") ``` But the same thing can be accomplished more succinctly as: @@ -37,9 +40,11 @@ You can link multiple blueprints together with `autoconnect`: ```python session=blueprint-ex1 from dimos.core.blueprints import autoconnect -class Module1(Module): - def __init__(self, arg1) -> None: - super().__init__() +class Config(ModuleConfig): + arg1: int = 42 + +class Module1(Module[Config]): + default_config = Config class Module2(Module): ... @@ -206,7 +211,7 @@ blueprint.remappings([ ## Overriding global configuration. -Each module can optionally take global config as a `cfg` option in `__init__`. E.g.: +Each module includes the global config available as `self.config.g`. E.g.: ```python session=blueprint-ex3 from dimos.core.core import rpc @@ -214,9 +219,8 @@ from dimos.core.module import Module from dimos.core.global_config import GlobalConfig class ModuleA(Module): - - def __init__(self, cfg: GlobalConfig | None = None): - self._global_config: GlobalConfig = cfg + def some_method(self): + print(self.config.g.viewer) ... ``` diff --git a/docs/usage/configuration.md b/docs/usage/configuration.md index fe6e0029f0..384ef5240e 100644 --- a/docs/usage/configuration.md +++ b/docs/usage/configuration.md @@ -2,23 +2,19 @@ Dimos provides a `Configurable` base class. See [`service/spec.py`](/dimos/protocol/service/spec.py#L22). -This allows using dataclasses to specify configuration structure and default values per module. +This allows using pydantic models to specify configuration structure and default values per module. ```python from dimos.protocol.service import Configurable +from dimos.protocol.service.spec import BaseConfig from rich import print -from dataclasses import dataclass -@dataclass -class Config(): +class Config(BaseConfig): x: int = 3 hello: str = "world" -class MyClass(Configurable): +class MyClass(Configurable[Config]): default_config = Config - config: Config - def __init__(self, **kwargs) -> None: - super().__init__(**kwargs) myclass1 = MyClass() print(myclass1.config) @@ -48,22 +44,19 @@ Error: Config.__init__() got an unexpected keyword argument 'something' [Modules](/docs/usage/modules.md) inherit from `Configurable`, so all of the above applies. Module configs should inherit from `ModuleConfig` ([`core/module.py`](/dimos/core/module.py#L40)), which includes shared configuration for all modules like transport protocols, frame IDs, etc. ```python -from dataclasses import dataclass from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out from rich import print -@dataclass class Config(ModuleConfig): frame_id: str = "world" publish_interval: float = 0 voxel_size: float = 0.05 device: str = "CUDA:0" -class MyModule(Module): +class MyModule(Module[Config]): default_config = Config - config: Config def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/docs/usage/native_modules.md b/docs/usage/native_modules.md index 929ac18424..de12417b4a 100644 --- a/docs/usage/native_modules.md +++ b/docs/usage/native_modules.md @@ -17,7 +17,6 @@ Python side native module is just a definition of a **config** dataclass and **m Both the config dataclass and pubsub topics get converted to CLI args passed down to your executable once the module is started. ```python no-result session=nativemodule -from dataclasses import dataclass from dimos.core.stream import Out from dimos.core.transport import LCMTransport from dimos.core.native_module import NativeModule, NativeModuleConfig @@ -25,13 +24,12 @@ from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.msgs.sensor_msgs.Imu import Imu import time -@dataclass(kw_only=True) class MyLidarConfig(NativeModuleConfig): executable: str = "./build/my_lidar" host_ip: str = "192.168.1.5" frequency: float = 10.0 -class MyLidar(NativeModule): +class MyLidar(NativeModule[MyLidarConfig]): default_config = MyLidarConfig pointcloud: Out[PointCloud2] imu: Out[Imu] @@ -98,18 +96,18 @@ When `stop()` is called, the process receives SIGTERM. If it doesn't exit within Any field you add to your config subclass automatically becomes a `--name value` CLI arg. Fields from `NativeModuleConfig` itself (like `executable`, `extra_args`, `cwd`) are **not** passed — they're for Python-side orchestration only. ```python skip +from pydantic import Field class LogFormat(enum.Enum): TEXT = "text" JSON = "json" -@dataclass(kw_only=True) class MyConfig(NativeModuleConfig): executable: str = "./build/my_module" # relative or absolute path to your executable host_ip: str = "192.168.1.5" # becomes --host_ip 192.168.1.5 frequency: float = 10.0 # becomes --frequency 10.0 enable_imu: bool = True # becomes --enable_imu true - filters: list[str] = field(default_factory=lambda: ["a", "b"]) # becomes --filters a,b + filters: list[str] = Field(default_factory=lambda: ["a", "b"]) # becomes --filters a,b ``` - `None` values are skipped. @@ -121,16 +119,11 @@ class MyConfig(NativeModuleConfig): If a config field shouldn't be a CLI arg, add it to `cli_exclude`: ```python skip -@dataclass(kw_only=True) class FastLio2Config(NativeModuleConfig): executable: str = "./build/fastlio2" config: str = "mid360.yaml" # human-friendly name - config_path: str | None = None # resolved absolute path + config_path: str = Field(default_factory=lambda m: str(Path(m["config"]).resolve())) cli_exclude: frozenset[str] = frozenset({"config"}) # only config_path is passed - - def __post_init__(self) -> None: - if self.config_path is None: - self.config_path = str(Path(self.config).resolve()) ``` ## Using with blueprints @@ -173,7 +166,6 @@ NativeModule pipes subprocess stdout and stderr through structlog: If your native binary outputs structured JSON lines, set `log_format=LogFormat.JSON`: ```python skip -@dataclass(kw_only=True) class MyConfig(NativeModuleConfig): executable: str = "./build/my_module" log_format: LogFormat = LogFormat.JSON @@ -236,7 +228,6 @@ from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.msgs.sensor_msgs.Imu import Imu from dimos.spec import perception -@dataclass(kw_only=True) class Mid360Config(NativeModuleConfig): cwd: str | None = "cpp" executable: str = "result/bin/mid360_native" @@ -248,7 +239,7 @@ class Mid360Config(NativeModuleConfig): frame_id: str = "lidar_link" # ... SDK port configuration -class Mid360(NativeModule, perception.Lidar, perception.IMU): +class Mid360(NativeModule[Mid360Config], perception.Lidar, perception.IMU): default_config = Mid360Config lidar: Out[PointCloud2] imu: Out[Imu] @@ -271,7 +262,6 @@ If `build_command` is set in the module config, and the executable doesn't exist Build output is piped through structlog (stdout at `info`, stderr at `warning`). ```python skip -@dataclass(kw_only=True) class MyLidarConfig(NativeModuleConfig): cwd: str | None = "cpp" executable: str = "result/bin/my_lidar" diff --git a/docs/usage/transforms.md b/docs/usage/transforms.md index 8b98e4e81d..8a3f708cd2 100644 --- a/docs/usage/transforms.md +++ b/docs/usage/transforms.md @@ -173,9 +173,7 @@ Modules in DimOS automatically get a `frame_id` property. This is controlled by ```python from dimos.core.module import Module, ModuleConfig -from dataclasses import dataclass -@dataclass class MyModuleConfig(ModuleConfig): frame_id: str = "sensor_link" frame_id_prefix: str | None = None @@ -228,8 +226,6 @@ from dimos.core.module_coordinator import ModuleCoordinator class RobotBaseModule(Module): """Publishes the robot's position in the world frame at 10Hz.""" - def __init__(self, **kwargs: object) -> None: - super().__init__(**kwargs) @rpc def start(self) -> None: diff --git a/docs/usage/transports/index.md b/docs/usage/transports/index.md index b930671906..db931872bd 100644 --- a/docs/usage/transports/index.md +++ b/docs/usage/transports/index.md @@ -81,7 +81,7 @@ We’ll go through these layers top-down. See [Blueprints](/docs/usage/blueprints.md) for the blueprint API. -From [`unitree/go2/blueprints/__init__.py`](/dimos/robot/unitree/go2/blueprints/__init__.py). +From [`unitree/go2/blueprints/smart/unitree_go2.py`](/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py). Example: rebind a few streams from the default `LCMTransport` to `ROSTransport` (defined at [`transport.py`](/dimos/core/transport.py#L226)) so you can visualize in **rviz2**. diff --git a/docs/usage/visualization.md b/docs/usage/visualization.md index 809f7881e4..57ad460354 100644 --- a/docs/usage/visualization.md +++ b/docs/usage/visualization.md @@ -96,7 +96,7 @@ This happens on lower-end hardware (NUC, older laptops) with large maps. ### Increase Voxel Size -Edit [`dimos/robot/unitree/go2/blueprints/__init__.py`](/dimos/robot/unitree/go2/blueprints/__init__.py) line 82: +Edit [`dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py`](/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py): ```python # Before (high detail, slower on large maps) diff --git a/examples/simplerobot/simplerobot.py b/examples/simplerobot/simplerobot.py index 010b3bf2eb..517684d7cd 100644 --- a/examples/simplerobot/simplerobot.py +++ b/examples/simplerobot/simplerobot.py @@ -22,17 +22,19 @@ Subscribes to Twist commands and publishes PoseStamped. """ -from dataclasses import dataclass import math import time -from typing import Any import reactivex as rx from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Twist, Vector3 +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 def apply_twist(pose: Pose, twist: Twist, dt: float) -> Pose: @@ -48,7 +50,6 @@ def apply_twist(pose: Pose, twist: Twist, dt: float) -> Pose: ) -@dataclass class SimpleRobotConfig(ModuleConfig): frame_id: str = "world" update_rate: float = 30.0 @@ -61,12 +62,9 @@ class SimpleRobot(Module[SimpleRobotConfig]): cmd_vel: In[Twist] pose: Out[PoseStamped] default_config = SimpleRobotConfig - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self._pose = Pose() - self._vel = Twist() - self._vel_time = 0.0 + _pose = Pose() + _vel = Twist() + _vel_time = 0.0 @rpc def start(self) -> None: diff --git a/pyproject.toml b/pyproject.toml index 017562a78a..722e3b0485 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -354,8 +354,12 @@ exclude = [ [tool.ruff.lint] extend-select = ["E", "W", "F", "B", "UP", "N", "I", "C90", "A", "RUF", "TCH"] -# TODO: All of these should be fixed, but it's easier commit autofixes first -ignore = ["A001", "A002", "B008", "B017", "B019", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N817", "N999", "RUF003", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007"] +ignore = [ + # TODO: All of these should be fixed, but it's easier commit autofixes first + "A001", "A002", "B008", "B017", "B019", "B024", "B026", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N817", "N999", "RUF003", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "UP007", + # This breaks runtime type checking (both for us, and users introspecting our APIs) + "TC001", "TC002", "TC003" +] [tool.ruff.lint.per-file-ignores] "dimos/models/Detic/*" = ["ALL"] @@ -373,6 +377,7 @@ python_version = "3.12" incremental = true strict = true warn_unused_ignores = false +explicit_package_bases = true exclude = "^dimos/models/Detic(/|$)|^dimos/rxpy_backpressure(/|$)|.*/test_.|.*/conftest.py*" [[tool.mypy.overrides]] @@ -425,7 +430,7 @@ env = [ "GOOGLE_MAPS_API_KEY=AIzafake_google_key", "PYTHONWARNINGS=ignore:cupyx.jit.rawkernel is experimental:FutureWarning", ] -addopts = "-v -r a -p no:warnings --color=yes -m 'not (tool or slow or mujoco)'" +addopts = "-v -r a -p no:warnings -p no:launch_testing -p no:launch_ros --import-mode=importlib --color=yes -m 'not (tool or slow or mujoco)'" asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function"