diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index abfeb29b2..31f0390b2 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -14,7 +14,7 @@ from abc import ABC from collections import defaultdict -from collections.abc import Callable, Mapping +from collections.abc import Callable, Mapping, MutableMapping from dataclasses import dataclass, field, replace from functools import cached_property, reduce import operator @@ -22,6 +22,8 @@ from types import MappingProxyType from typing import TYPE_CHECKING, Any, Literal, get_args, get_origin, get_type_hints +from pydantic import BaseModel, create_model + if TYPE_CHECKING: from dimos.protocol.service.system_configurator.base import SystemConfigurator @@ -130,6 +132,11 @@ def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint": def disabled_modules(self, *modules: type[ModuleBase]) -> "Blueprint": return replace(self, disabled_modules_tuple=self.disabled_modules_tuple + modules) + def config(self) -> type[BaseModel]: + configs = {b.module.name: (b.module.default_config | None, None) for b in self.blueprints} + configs["g"] = (GlobalConfig | None, None) + return create_model("BlueprintConfig", __config__={"extra": "forbid"}, **configs) # type: ignore[call-overload,no-any-return] + def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint": return replace(self, transport_map=MappingProxyType({**self.transport_map, **transports})) @@ -274,13 +281,16 @@ def _verify_no_name_conflicts(self) -> None: raise ValueError("\n".join(error_lines)) def _deploy_all_modules( - self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig + self, + module_coordinator: ModuleCoordinator, + global_config: GlobalConfig, + blueprint_args: Mapping[str, Mapping[str, Any]], ) -> None: module_specs: list[ModuleSpec] = [] for blueprint in self._active_blueprints: - module_specs.append((blueprint.module, global_config, blueprint.kwargs)) + module_specs.append((blueprint.module, global_config, blueprint.kwargs.copy())) - module_coordinator.deploy_parallel(module_specs) + module_coordinator.deploy_parallel(module_specs, blueprint_args) def _connect_streams(self, module_coordinator: ModuleCoordinator) -> None: # dict when given (final/remapped) stream name+type, provides a list of modules + original (non-remapped) stream names @@ -472,12 +482,13 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: def build( self, - cli_config_overrides: Mapping[str, Any] | None = None, + blueprint_args: MutableMapping[str, Any] | None = None, ) -> ModuleCoordinator: logger.info("Building the blueprint") global_config.update(**dict(self.global_config_overrides)) - if cli_config_overrides: - global_config.update(**dict(cli_config_overrides)) + blueprint_args = blueprint_args or {} + if "g" in blueprint_args: + global_config.update(**blueprint_args.pop("g")) self._run_configurators() self._check_requirements() @@ -488,7 +499,7 @@ def build( module_coordinator.start() # all module constructors are called here (each of them setup their own) - self._deploy_all_modules(module_coordinator, global_config) + self._deploy_all_modules(module_coordinator, global_config, blueprint_args) self._connect_streams(module_coordinator) self._connect_rpc_methods(module_coordinator) self._connect_module_refs(module_coordinator) diff --git a/dimos/core/module.py b/dimos/core/module.py index ab21ce17a..b99a47a75 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -117,6 +117,11 @@ def __init__(self, config_args: dict[str, Any]): except ValueError: ... + @classproperty + def name(self) -> str: + """Name for this module to be used for blueprint configs.""" + return self.__name__.lower() # type: ignore[attr-defined,no-any-return] + @property def frame_id(self) -> str: base = self.config.frame_id or self.__class__.__name__ diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index 10227eae9..7a4f10cee 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -14,6 +14,7 @@ from __future__ import annotations +from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor import threading from typing import TYPE_CHECKING, Any @@ -128,11 +129,13 @@ def deploy( self._deployed_modules[module_class] = module # type: ignore[assignment] return module # type: ignore[return-value] - def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]: + def deploy_parallel( + self, module_specs: list[ModuleSpec], blueprint_args: Mapping[str, Mapping[str, Any]] + ) -> list[ModuleProxy]: if not self._client: raise ValueError("Not started") - modules = self._client.deploy_parallel(module_specs) + modules = self._client.deploy_parallel(module_specs, blueprint_args) for (module_class, _, _), module in zip(module_specs, modules, strict=True): self._deployed_modules[module_class] = module # type: ignore[assignment] return modules # type: ignore[return-value] diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index 19dbf62c7..1d4a310e5 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from types import MappingProxyType from typing import Protocol import pytest @@ -37,9 +38,11 @@ from dimos.spec.utils import Spec # Disable Rerun for tests (prevents viewer spawn and gRPC flush errors) -_BUILD_WITHOUT_RERUN = { - "cli_config_overrides": {"viewer": "none"}, -} +_BUILD_WITHOUT_RERUN = MappingProxyType( + { + "g": {"viewer": "none"}, + } +) class Scratch: @@ -152,6 +155,14 @@ def test_autoconnect() -> None: ) +def test_config() -> None: + blueprint = autoconnect(module_a(), module_b()) + config = blueprint.config() + assert config.model_fields.keys() == {"modulea", "moduleb", "g"} + assert config.model_fields["modulea"].annotation == ModuleA.default_config | None + assert config.model_fields["moduleb"].annotation == ModuleB.default_config | None + + def test_transports() -> None: custom_transport = LCMTransport("/custom_topic", Data1) blueprint_set = autoconnect(module_a(), module_b()).transports( @@ -175,7 +186,7 @@ def test_global_config() -> None: def test_build_happy_path() -> None: blueprint_set = autoconnect(module_a(), module_b(), module_c()) - coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) + coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy()) try: assert isinstance(coordinator, ModuleCoordinator) @@ -304,7 +315,7 @@ def test_remapping() -> None: assert ("color_image", Data1) not in blueprint_set._all_name_types # Build and verify streams work - coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) + coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy()) try: source_instance = coordinator.get_instance(SourceModule) @@ -354,7 +365,7 @@ def test_future_annotations_autoconnect() -> None: blueprint_set = autoconnect(FutureModuleOut.blueprint(), FutureModuleIn.blueprint()) - coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) + coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy()) try: out_instance = coordinator.get_instance(FutureModuleOut) @@ -446,7 +457,7 @@ def test_module_ref_direct() -> None: coordinator = autoconnect( Calculator1.blueprint(), Mod1.blueprint(), - ).build(**_BUILD_WITHOUT_RERUN) + ).build(_BUILD_WITHOUT_RERUN.copy()) try: mod1 = coordinator.get_instance(Mod1) @@ -462,7 +473,7 @@ def test_module_ref_spec() -> None: coordinator = autoconnect( Calculator1.blueprint(), Mod2.blueprint(), - ).build(**_BUILD_WITHOUT_RERUN) + ).build(_BUILD_WITHOUT_RERUN.copy()) try: mod2 = coordinator.get_instance(Mod2) @@ -477,7 +488,7 @@ def test_module_ref_spec() -> None: def test_disabled_modules_are_skipped_during_build() -> None: blueprint_set = autoconnect(module_a(), module_b(), module_c()).disabled_modules(ModuleC) - coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) + coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy()) try: assert coordinator.get_instance(ModuleA) is not None @@ -515,7 +526,7 @@ def test_module_ref_remap_ambiguous() -> None: (Mod2, "calc", Calculator1), ] ) - .build(**_BUILD_WITHOUT_RERUN) + .build(_BUILD_WITHOUT_RERUN.copy()) ) try: diff --git a/dimos/core/test_worker.py b/dimos/core/test_worker.py index 306b3fdb3..c05b4047d 100644 --- a/dimos/core/test_worker.py +++ b/dimos/core/test_worker.py @@ -145,7 +145,8 @@ def test_worker_manager_parallel_deployment(create_worker_manager): (SimpleModule, global_config, {}), (AnotherModule, global_config, {}), (ThirdModule, global_config, {}), - ] + ], + {}, ) assert len(modules) == 3 diff --git a/dimos/core/worker_manager.py b/dimos/core/worker_manager.py index 4cd5eec8d..96b9e227f 100644 --- a/dimos/core/worker_manager.py +++ b/dimos/core/worker_manager.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from concurrent.futures import ThreadPoolExecutor from typing import Any @@ -61,7 +61,11 @@ def deploy( actor = worker.deploy_module(module_class, global_config, kwargs=kwargs) return RPCClient(actor, module_class) - def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient]: + def deploy_parallel( + self, + module_specs: Iterable[ModuleSpec], + blueprint_args: Mapping[str, Mapping[str, Any]], + ) -> list[RPCClient]: if self._closed: raise RuntimeError("WorkerManager is closed") @@ -76,6 +80,7 @@ def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient] for module_class, global_config, kwargs in module_specs: worker = self._select_worker() worker.reserve_slot() + kwargs.update(blueprint_args.get(module_class.name, {})) assignments.append((worker, module_class, global_config, kwargs)) def _deploy( diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index 1137a612f..0c768562e 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -14,24 +14,39 @@ from __future__ import annotations +from collections.abc import Iterable +from contextlib import suppress from datetime import datetime, timezone import inspect import json import os +from pathlib import Path import sys import time -from typing import Any, get_args, get_origin +import types +from typing import Any, Union, get_args, get_origin import click from dotenv import load_dotenv +from pydantic import BaseModel +from pydantic_core import PydanticUndefined import requests import typer from dimos.agents.mcp.mcp_adapter import McpAdapter, McpError +from dimos.core.blueprints import Blueprint, _BlueprintAtom from dimos.core.global_config import GlobalConfig, global_config from dimos.core.run_registry import get_most_recent, is_pid_alive, stop_entry from dimos.utils.logging_config import setup_logger +try: + # Not a dependency, just the best way to get config path if available. + from gi.repository import GLib # type: ignore[import-untyped,import-not-found] +except ImportError: + CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config")) +else: + CONFIG_DIR = Path(GLib.get_user_config_dir()) + logger = setup_logger() main = typer.Typer( @@ -108,12 +123,84 @@ def callback(**kwargs) -> None: # type: ignore[no-untyped-def] main.callback()(create_dynamic_callback()) # type: ignore[no-untyped-call] +def arghelp( + config: type[BaseModel], + blueprint: Blueprint, + indent: str = " ", + module: str = "", + _atom: _BlueprintAtom | None = None, +) -> str: + output = "" + for k, info in config.model_fields.items(): + if k == "g": + continue + t = info.annotation + if isinstance(t, types.GenericAlias): + # Can't be specified on CLI + continue + + # TODO(PY314): if isinstance(t, Union): + if get_origin(t) in {Union, types.UnionType}: + with suppress(StopIteration): + t = next(u for u in get_args(t) if issubclass(u, BaseModel)) + + if inspect.isclass(t) and issubclass(t, BaseModel): + output += f"{indent}{module}{k}:\n" + # Find blueprint atom + bp = next(bp for bp in blueprint.blueprints if bp.module.name == k) + output += arghelp(t, blueprint, indent=indent + " ", module=module + k + ".", _atom=bp) + else: + assert _atom is not None + # Use __name__ to avoid "" style output on basic types. + display_type = t.__name__ if isinstance(t, type) else t + required = "[Required] " if info.is_required() and k not in _atom.kwargs else "" + d = _atom.kwargs.get(k, info.default) + default = f" (default: {d})" if d is not PydanticUndefined else "" + output += f"{indent}* {required}{module}{k}: {display_type}{default}\n" + return output + + +def load_config_args(config: type[BaseModel], args: Iterable[str], path: Path) -> dict[str, Any]: + try: + kwargs = json.loads(path.read_text()) + except (OSError, json.JSONDecodeError): + kwargs = {} + + for k, v in os.environ.items(): + parts = k.lower().split("__") + if parts[0] not in config.model_fields: + continue + d = kwargs + for p in parts[:-1]: + d = d.setdefault(p, {}) + d[parts[-1]] = v + + for arg in args: + k, _, v = arg.partition("=") + parts = k.split(".") + d = kwargs + for p in parts[:-1]: + d = d.setdefault(p, {}) + d[parts[-1]] = v + + # We don't need this config, but this atleast validates the user input first. + # This will help catch misspellings and similar mistakes. + config(**kwargs) + + return kwargs # type: ignore[no-any-return] + + @main.command() def run( ctx: typer.Context, robot_types: list[str] = typer.Argument(..., help="Blueprints or modules to run"), daemon: bool = typer.Option(False, "--daemon", "-d", help="Run in background"), disable: list[str] = typer.Option([], "--disable", help="Module names to disable"), + blueprint_args: list[str] = typer.Option((), "--option", "-o"), + config_path: Path = typer.Option( + CONFIG_DIR / "dimos", "--config", "-c", help="Path to config file" + ), + show_help: bool = typer.Option(False, "--help"), ) -> None: """Start a robot blueprint""" logger.info("Starting DimOS") @@ -132,7 +219,6 @@ def run( setup_exception_handler() cli_config_overrides: dict[str, Any] = ctx.obj - global_config.update(**cli_config_overrides) # Clean stale registry entries stale = cleanup_stale() @@ -163,7 +249,17 @@ def run( disabled_classes = tuple(get_module_by_name(name).blueprints[0].module for name in disable) blueprint = blueprint.disabled_modules(*disabled_classes) - coordinator = blueprint.build(cli_config_overrides=cli_config_overrides) + if show_help: + print("Blueprint arguments:") + print(arghelp(blueprint.config(), blueprint)) + return + + blueprint_config = blueprint.config() + kwargs = load_config_args(blueprint_config, blueprint_args, config_path) + if cli_config_overrides: + kwargs["g"] = cli_config_overrides + + coordinator = blueprint.build(kwargs) if daemon: from dimos.core.daemon import ( diff --git a/dimos/robot/cli/test_dimos.py b/dimos/robot/cli/test_dimos.py new file mode 100644 index 000000000..706e6c787 --- /dev/null +++ b/dimos/robot/cli/test_dimos.py @@ -0,0 +1,90 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core.blueprints import autoconnect +from dimos.core.module import Module, ModuleConfig +from dimos.robot.cli.dimos import arghelp +from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.visualization.rerun.bridge import RerunBridgeModule, _default_blueprint + + +def test_blueprint_arghelp(): + blueprint = autoconnect(RerunBridgeModule.blueprint(), GO2Connection.blueprint()) + output = arghelp(blueprint.config(), blueprint) + # List output produces better diff in pytest error output. + assert output.split("\n") == [ + " rerunbridgemodule:", + " * rerunbridgemodule.frame_id_prefix: str | None (default: None)", + " * rerunbridgemodule.frame_id: str | None (default: None)", + " * rerunbridgemodule.min_interval_sec: float (default: 0.1)", + " * rerunbridgemodule.entity_prefix: str (default: world)", + " * rerunbridgemodule.topic_to_entity: collections.abc.Callable[[typing.Any], str] | None (default: None)", + " * rerunbridgemodule.viewer_mode: typing.Literal['native', 'web', 'connect', 'none']", + " * rerunbridgemodule.connect_url: str (default: rerun+http://127.0.0.1:9877/proxy)", + " * rerunbridgemodule.memory_limit: str (default: 25%)", + f" * rerunbridgemodule.blueprint: collections.abc.Callable[rerun.blueprint.api.Blueprint] | None (default: {_default_blueprint})", + " go2connection:", + " * go2connection.frame_id_prefix: str | None (default: None)", + " * go2connection.frame_id: str | None (default: None)", + " * go2connection.ip: str", + "", + ] + + +def test_blueprint_arghelp_extra_args(): + """Test defaults passed to .blueprint() override.""" + + bridge = RerunBridgeModule.blueprint(frame_id_prefix="foo", viewer_mode="web") + blueprint = autoconnect(bridge, GO2Connection.blueprint(ip="1.1.1.1")) + output = arghelp(blueprint.config(), blueprint) + # List output produces better diff in pytest error output. + assert output.split("\n") == [ + " rerunbridgemodule:", + " * rerunbridgemodule.frame_id_prefix: str | None (default: foo)", + " * rerunbridgemodule.frame_id: str | None (default: None)", + " * rerunbridgemodule.min_interval_sec: float (default: 0.1)", + " * rerunbridgemodule.entity_prefix: str (default: world)", + " * rerunbridgemodule.topic_to_entity: collections.abc.Callable[[typing.Any], str] | None (default: None)", + " * rerunbridgemodule.viewer_mode: typing.Literal['native', 'web', 'connect', 'none'] (default: web)", + " * rerunbridgemodule.connect_url: str (default: rerun+http://127.0.0.1:9877/proxy)", + " * rerunbridgemodule.memory_limit: str (default: 25%)", + f" * rerunbridgemodule.blueprint: collections.abc.Callable[rerun.blueprint.api.Blueprint] | None (default: {_default_blueprint})", + " go2connection:", + " * go2connection.frame_id_prefix: str | None (default: None)", + " * go2connection.frame_id: str | None (default: None)", + " * go2connection.ip: str (default: 1.1.1.1)", + "", + ] + + +def test_blueprint_arghelp_required(): + """Test required arguments.""" + + class Config(ModuleConfig): + foo: int + spam: str = "eggs" + + class TestModule(Module[Config]): + default_config = Config + + blueprint = TestModule.blueprint() + output = arghelp(blueprint.config(), blueprint) + assert output.split("\n") == [ + " testmodule:", + " * testmodule.frame_id_prefix: str | None (default: None)", + " * testmodule.frame_id: str | None (default: None)", + " * [Required] testmodule.foo: int", + " * testmodule.spam: str (default: eggs)", + "", + ] diff --git a/docs/usage/blueprints.md b/docs/usage/blueprints.md index 80a6b24b1..04976c879 100644 --- a/docs/usage/blueprints.md +++ b/docs/usage/blueprints.md @@ -230,6 +230,45 @@ The config is normally taken from .env or from environment variables. But you ca blueprint = ModuleA.blueprint().global_config(n_workers=8) ``` +## Providing blueprint configuration to users + +`Blueprint.config()` can be used to get a `pydantic.BaseModel` that can be used to +inspect or test configuration settings that can be passed to `Blueprint.build()`: + +```python session=blueprint-ex1 +# Validate config input +blueprint_args = { + "module1": {"arg1": 5} +} +config = base_blueprint.config() +config(**blueprint_args) # raises pydantic.ValidationError if args are incorrect +``` + +`dimos.robot.cli.dimos.arghelp()` is a helper function that will return a string +containing all details of these arguments (this is how the output is produced when +running `dimos run unitree-go2 --help`, for example): + +```python session=blueprint-ex1 +from dimos.robot.cli.dimos import arghelp +print(arghelp(base_blueprint.config(), base_blueprint)) +``` + +Another function is `dimos.robot.cli.dimos.load_config_args()` which can create the +argument dict for users from a config file, environment variables and CLI arguments: + + +```python session=blueprint-ex1 +from dimos.robot.cli.dimos import load_config_args + +config_path = Path.home() / "base-blueprint-config.json" +cli_args = ["arg1=5"] +blueprint_args = load_config_args(base_blueprint.config(), cli_args, config_path) +# Test user input is valid +config(**blueprint_args) +# Then we can build the blueprint +base_blueprint.build(blueprint_args) +``` + ## Calling the methods of other modules Imagine you have this code: