Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 70 additions & 8 deletions dimos/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
from dimos.core.stream import In, Out
from dimos.protocol.rpc import RPCSpec
from dimos.spec.utils import Spec
from dimos.utils.logging_config import setup_logger

logger = setup_logger()

if TYPE_CHECKING:
from langchain_core.language_models import BaseChatModel
Expand All @@ -54,6 +57,7 @@ class Agent(Module[AgentConfig]):
_lock: RLock
_state_graph: CompiledStateGraph[Any, Any, Any, Any] | None
_message_queue: Queue[BaseMessage]
_skill_registry: dict[str, SkillInfo]
_history: list[BaseMessage]
_thread: Thread
_stop_event: Event
Expand All @@ -64,6 +68,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._state_graph = None
self._message_queue = Queue()
self._history = []
self._skill_registry = {}
self._thread = Thread(
target=self._thread_loop,
name=f"{self.__class__.__name__}-thread",
Expand Down Expand Up @@ -102,13 +107,16 @@ def on_system_modules(self, modules: list[RPCClient]) -> None:

model = MockModel(json_path=self.config.model_fixture)

skills = [skill for module in modules for skill in (module.get_skills() or [])]
self._skill_registry = {skill.func_name: skill for skill in skills}

with self._lock:
# Here to prevent unwanted imports in the file.
from langchain.agents import create_agent

self._state_graph = create_agent(
model=model,
tools=_get_tools_from_modules(self, modules, self.rpc),
tools=[_skill_to_tool(self, skill, self.rpc) for skill in skills],
system_prompt=self.config.system_prompt,
)
self._thread.start()
Expand All @@ -117,6 +125,64 @@ def on_system_modules(self, modules: list[RPCClient]) -> None:
def add_message(self, message: BaseMessage) -> None:
self._message_queue.put(message)

@rpc
def dispatch_continuation(
    self, continuation: dict[str, Any], continuation_context: dict[str, Any]
) -> None:
    """Run a follow-up tool directly, skipping the LLM reasoning step.

    Trigger skills (e.g. look_out_for) call this when a detection fires so
    the configured follow-up tool executes immediately instead of waiting
    for the model to decide on the next action.

    Args:
        continuation: ``{"tool": "<name>", "args": {…}}`` describing which
            registered skill to invoke. String argument values beginning
            with ``$`` are template placeholders looked up in
            *continuation_context* (e.g. ``"$bbox"``).
        continuation_context: runtime detection payload, e.g.
            ``{"bbox": [x1, y1, x2, y2], "label": "person"}``.
    """
    name = continuation.get("tool")
    if not name:
        self._message_queue.put(
            HumanMessage(f"Continuation failed: missing 'tool' key in {continuation}")
        )
        return

    info = self._skill_registry.get(name)
    if info is None:
        self._message_queue.put(
            HumanMessage(f"Continuation failed: tool '{name}' not found")
        )
        return

    # Resolve $-prefixed placeholders against the detection context; values
    # without a matching context key are passed through untouched.
    resolved: dict[str, Any] = {}
    for arg_name, arg_value in dict(continuation.get("args", {})).items():
        if (
            isinstance(arg_value, str)
            and arg_value.startswith("$")
            and arg_value[1:] in continuation_context
        ):
            resolved[arg_name] = continuation_context[arg_value[1:]]
        else:
            resolved[arg_name] = arg_value

    call = RpcCall(None, self.rpc, info.func_name, info.class_name, [])
    try:
        outcome = call(**resolved)
    except Exception as e:
        self._message_queue.put(
            HumanMessage(f"Continuation '{name}' failed with error: {e}")
        )
        return

    label = continuation_context.get("label", "unknown")
    self._message_queue.put(
        HumanMessage(
            f"Automatically executed '{name}' as a continuation of lookout "
            f"detection (detected: {label}). Result: {outcome or 'started'}"
        )
    )

def _thread_loop(self) -> None:
while not self._stop_event.is_set():
try:
Expand Down Expand Up @@ -150,13 +216,9 @@ def _process_message(

class AgentSpec(Spec, Protocol):
def add_message(self, message: BaseMessage) -> None: ...


def _get_tools_from_modules(
agent: Agent, modules: list[RPCClient], rpc: RPCSpec
) -> list[StructuredTool]:
skills = [skill for module in modules for skill in (module.get_skills() or [])]
return [_skill_to_tool(agent, skill, rpc) for skill in skills]
def dispatch_continuation(
self, continuation: dict[str, Any], continuation_context: dict[str, Any]
) -> None: ...


def _skill_to_tool(agent: Agent, skill: SkillInfo, rpc: RPCSpec) -> StructuredTool:
Expand Down
65 changes: 64 additions & 1 deletion dimos/agents/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class McpClient(Module[McpClientConfig]):
_lock: RLock
_state_graph: CompiledStateGraph[Any, Any, Any, Any] | None
_message_queue: Queue[BaseMessage]
_tool_registry: dict[str, dict[str, Any]]
_history: list[BaseMessage]
_thread: Thread
_stop_event: Event
Expand All @@ -67,6 +68,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._lock = RLock()
self._state_graph = None
self._message_queue = Queue()
self._tool_registry = {}
self._history = []
self._thread = Thread(
target=self._thread_loop,
Expand Down Expand Up @@ -106,7 +108,9 @@ def _fetch_tools(self, timeout: float = 60.0, interval: float = 1.0) -> list[Str
f"Failed to fetch tools from MCP server {self.config.mcp_server_url}"
)

tools = [self._mcp_tool_to_langchain(t) for t in result.get("tools", [])]
raw_tools = result.get("tools", [])
self._tool_registry = {t["name"]: t for t in raw_tools}
tools = [self._mcp_tool_to_langchain(t) for t in raw_tools]

if not tools:
logger.warning("No tools found from MCP server.")
Expand Down Expand Up @@ -198,6 +202,65 @@ def stop(self) -> None:
def add_message(self, message: BaseMessage) -> None:
self._message_queue.put(message)

@rpc
def dispatch_continuation(
    self, continuation: dict[str, Any], continuation_context: dict[str, Any]
) -> None:
    """Run a follow-up MCP tool directly, skipping the LLM reasoning step.

    Trigger tools (e.g. look_out_for) call this when a detection fires so
    the configured follow-up tool executes immediately instead of waiting
    for the model to decide on the next action.

    Args:
        continuation: ``{"tool": "<name>", "args": {…}}`` describing which
            MCP tool to invoke. String argument values beginning with
            ``$`` are template placeholders looked up in
            *continuation_context* (e.g. ``"$bbox"``).
        continuation_context: runtime detection payload, e.g.
            ``{"bbox": [x1, y1, x2, y2], "label": "person"}``.
    """
    name = continuation.get("tool")
    if not name:
        self._message_queue.put(
            HumanMessage(f"Continuation failed: missing 'tool' key in {continuation}")
        )
        return

    if name not in self._tool_registry:
        self._message_queue.put(
            HumanMessage(f"Continuation failed: tool '{name}' not found")
        )
        return

    # Resolve $-prefixed placeholders against the detection context; values
    # without a matching context key are passed through untouched.
    resolved: dict[str, Any] = {}
    for arg_name, arg_value in dict(continuation.get("args", {})).items():
        if (
            isinstance(arg_value, str)
            and arg_value.startswith("$")
            and arg_value[1:] in continuation_context
        ):
            resolved[arg_name] = continuation_context[arg_value[1:]]
        else:
            resolved[arg_name] = arg_value

    try:
        response = self._mcp_request("tools/call", {"name": name, "arguments": resolved})
        # Concatenate only the text-typed content parts of the MCP response.
        text = "\n".join(
            part.get("text", "")
            for part in response.get("content", [])
            if part.get("type") == "text"
        )
    except Exception as e:
        self._message_queue.put(
            HumanMessage(f"Continuation '{name}' failed with error: {e}")
        )
        return

    label = continuation_context.get("label", "unknown")
    self._message_queue.put(
        HumanMessage(
            f"Automatically executed '{name}' as a continuation of lookout "
            f"detection (detected: {label}). Result: {text or 'started'}"
        )
    )

def _thread_loop(self) -> None:
while not self._stop_event.is_set():
try:
Expand Down
76 changes: 61 additions & 15 deletions dimos/agents/skills/person_follow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
from threading import Event, RLock, Thread
import time
from typing import TYPE_CHECKING

from langchain_core.messages import HumanMessage
import numpy as np
from reactivex.disposable import Disposable
from turbojpeg import TurboJPEG

from dimos.agents.agent import AgentSpec
from dimos.agents.annotation import skill
Expand All @@ -29,7 +31,8 @@
from dimos.models.qwen.bbox import BBox
from dimos.models.vl.create import create
from dimos.msgs.geometry_msgs import Twist
from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2
from dimos.msgs.sensor_msgs import CameraInfo, Image, ImageFormat, PointCloud2
from dimos.navigation.patrolling.patrolling_module_spec import PatrollingModuleSpec
from dimos.navigation.visual.query import get_object_bbox_from_image
from dimos.navigation.visual_servoing.detection_navigation import DetectionNavigation
from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D
Expand Down Expand Up @@ -59,6 +62,7 @@ class PersonFollowSkillContainer(Module):
_agent_spec: AgentSpec
_frequency: float = 20.0 # Hz - control loop frequency
_max_lost_frames: int = 15 # number of frames to wait before declaring person lost
_patrolling_module_spec: PatrollingModuleSpec

def __init__(
self,
Expand Down Expand Up @@ -107,14 +111,25 @@ def stop(self) -> None:
super().stop()

@skill
def follow_person(self, query: str) -> str:
def follow_person(
self,
query: str,
initial_bbox: list[float] | None = None,
initial_image: str | None = None,
) -> str:
"""Follow a person matching the given description using visual servoing.

The robot will continuously track and follow the person, while keeping
them centered in the camera view.

Args:
query: Description of the person to follow (e.g., "man with blue shirt")
initial_bbox: Optional pre-computed bounding box [x1, y1, x2, y2].
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is exactly what Detection2d/bbox model is, should we use those for detections to follow?

I can imagine visual following making sense in many cases, would be good to support standard detection types?

If provided, skips the initial VL model detection step. This is
used by the continuation system to pass detection data directly
from look_out_for, avoiding a redundant detection.
initial_image: Optional base64-encoded JPEG of the frame on which
initial_bbox was detected.

Returns:
Status message indicating the result of the following action.
Expand All @@ -134,16 +149,27 @@ def follow_person(self, query: str) -> str:
if latest_image is None:
return "No image available to detect person."

initial_bbox = get_object_bbox_from_image(
self._vl_model,
latest_image,
query,
)

if initial_bbox is None:
return f"Could not find '{query}' in the current view."

return self._follow_person(query, initial_bbox)
detection_image: Image | None = None
if initial_bbox is not None:
bbox: BBox = (
initial_bbox[0],
initial_bbox[1],
initial_bbox[2],
initial_bbox[3],
)
if initial_image is not None:
detection_image = _decode_base64_image(initial_image)
else:
detected = get_object_bbox_from_image(
self._vl_model,
latest_image,
query,
)
if detected is None:
return f"Could not find '{query}' in the current view."
bbox = detected

return self._follow_person(query, bbox, detection_image)

@skill
def stop_following(self) -> str:
Expand All @@ -170,7 +196,9 @@ def _on_pointcloud(self, pointcloud: PointCloud2) -> None:
with self._lock:
self._latest_pointcloud = pointcloud

def _follow_person(self, query: str, initial_bbox: BBox) -> str:
def _follow_person(
self, query: str, initial_bbox: BBox, detection_image: Image | None = None
) -> str:
x1, y1, x2, y2 = initial_bbox
box = np.array([x1, y1, x2, y2], dtype=np.float32)

Expand All @@ -185,8 +213,11 @@ def _follow_person(self, query: str, initial_bbox: BBox) -> str:
if latest_image is None:
return "No image available to start tracking."

# Use the detection frame for tracker init when available, so the bbox
# matches the image it was computed on.
init_image = detection_image if detection_image is not None else latest_image
initial_detections = tracker.init_track(
image=latest_image,
image=init_image,
box=box,
obj_id=1,
)
Expand All @@ -200,11 +231,21 @@ def _follow_person(self, query: str, initial_bbox: BBox) -> str:
self._thread = Thread(target=self._follow_loop, args=(tracker, query), daemon=True)
self._thread.start()

return (
message = (
"Found the person. Starting to follow. You can stop following by calling "
"the 'stop_following' tool."
)

if self._patrolling_module_spec.is_patrolling():
message += (
" Note: since the robot was patrolling, this has been stopped automatically "
"(the equivalent of calling the `stop_patrol` tool call) so you don't have "
"to do it. "
)
self._patrolling_module_spec.stop_patrol()

return message

def _follow_loop(self, tracker: "EdgeTAMProcessor", query: str) -> None:
lost_count = 0
period = 1.0 / self._frequency
Expand Down Expand Up @@ -268,6 +309,11 @@ def _send_stop_reason(self, query: str, reason: str) -> None:
logger.info("Person follow stopped", query=query, reason=reason)


def _decode_base64_image(b64: str) -> Image:
    """Decode a base64-encoded JPEG string into a BGR ``Image``."""
    raw = base64.b64decode(b64)
    pixels = TurboJPEG().decode(raw)
    return Image(data=pixels, format=ImageFormat.BGR)


person_follow_skill = PersonFollowSkillContainer.blueprint

__all__ = ["PersonFollowSkillContainer", "person_follow_skill"]
3 changes: 1 addition & 2 deletions dimos/core/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from pydantic_settings import BaseSettings, SettingsConfigDict

from dimos.mapping.occupancy.path_map import NavigationStrategy
from dimos.models.vl.create import VlModelName

ViewerBackend: TypeAlias = Literal["rerun", "rerun-web", "rerun-connect", "foxglove", "none"]
Expand Down Expand Up @@ -47,7 +46,7 @@ class GlobalConfig(BaseSettings):
robot_model: str | None = None
robot_width: float = 0.3
robot_rotation_diameter: float = 0.6
planner_strategy: NavigationStrategy = "simple"
nerf_speed: float = 1.0
planner_robot_speed: float | None = None
mcp_port: int = 9990
mcp_host: str = "0.0.0.0"
Expand Down
Loading
Loading