Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions dimos/core/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

from abc import ABC
from collections import defaultdict
from collections.abc import Callable, Mapping
from collections.abc import Callable, Mapping, MutableMapping
from dataclasses import dataclass, field, replace
from functools import cached_property, reduce
import operator
import sys
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, get_args, get_origin, get_type_hints

from pydantic import BaseModel, create_model

if TYPE_CHECKING:
from dimos.protocol.service.system_configurator.base import SystemConfigurator

Expand Down Expand Up @@ -130,6 +132,11 @@ def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint":
def disabled_modules(self, *modules: type[ModuleBase]) -> "Blueprint":
return replace(self, disabled_modules_tuple=self.disabled_modules_tuple + modules)

def config(self) -> type[BaseModel]:
configs = {b.module.name: (b.module.default_config | None, None) for b in self.blueprints}
configs["g"] = (GlobalConfig | None, None)
return create_model("BlueprintConfig", __config__={"extra": "forbid"}, **configs) # type: ignore[call-overload,no-any-return]

def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint":
return replace(self, transport_map=MappingProxyType({**self.transport_map, **transports}))

Expand Down Expand Up @@ -274,13 +281,16 @@ def _verify_no_name_conflicts(self) -> None:
raise ValueError("\n".join(error_lines))

def _deploy_all_modules(
self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig
self,
module_coordinator: ModuleCoordinator,
global_config: GlobalConfig,
blueprint_args: Mapping[str, Mapping[str, Any]],
) -> None:
module_specs: list[ModuleSpec] = []
for blueprint in self._active_blueprints:
module_specs.append((blueprint.module, global_config, blueprint.kwargs))
module_specs.append((blueprint.module, global_config, blueprint.kwargs.copy()))

module_coordinator.deploy_parallel(module_specs)
module_coordinator.deploy_parallel(module_specs, blueprint_args)

def _connect_streams(self, module_coordinator: ModuleCoordinator) -> None:
# dict when given (final/remapped) stream name+type, provides a list of modules + original (non-remapped) stream names
Expand Down Expand Up @@ -472,12 +482,13 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None:

def build(
self,
cli_config_overrides: Mapping[str, Any] | None = None,
blueprint_args: MutableMapping[str, Any] | None = None,
) -> ModuleCoordinator:
logger.info("Building the blueprint")
global_config.update(**dict(self.global_config_overrides))
if cli_config_overrides:
global_config.update(**dict(cli_config_overrides))
blueprint_args = blueprint_args or {}
if "g" in blueprint_args:
global_config.update(**blueprint_args.pop("g"))

self._run_configurators()
self._check_requirements()
Expand All @@ -488,7 +499,7 @@ def build(
module_coordinator.start()

# all module constructors are called here (each of them setup their own)
self._deploy_all_modules(module_coordinator, global_config)
self._deploy_all_modules(module_coordinator, global_config, blueprint_args)
self._connect_streams(module_coordinator)
self._connect_rpc_methods(module_coordinator)
self._connect_module_refs(module_coordinator)
Expand Down
5 changes: 5 additions & 0 deletions dimos/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ def __init__(self, config_args: dict[str, Any]):
except ValueError:
...

@classproperty
def name(self) -> str:
"""Name for this module to be used for blueprint configs."""
return self.__name__.lower() # type: ignore[attr-defined,no-any-return]

@property
def frame_id(self) -> str:
base = self.config.frame_id or self.__class__.__name__
Expand Down
7 changes: 5 additions & 2 deletions dimos/core/module_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
import threading
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -128,11 +129,13 @@ def deploy(
self._deployed_modules[module_class] = module # type: ignore[assignment]
return module # type: ignore[return-value]

def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]:
def deploy_parallel(
self, module_specs: list[ModuleSpec], blueprint_args: Mapping[str, Mapping[str, Any]]
) -> list[ModuleProxy]:
if not self._client:
raise ValueError("Not started")

modules = self._client.deploy_parallel(module_specs)
modules = self._client.deploy_parallel(module_specs, blueprint_args)
for (module_class, _, _), module in zip(module_specs, modules, strict=True):
self._deployed_modules[module_class] = module # type: ignore[assignment]
return modules # type: ignore[return-value]
Expand Down
31 changes: 21 additions & 10 deletions dimos/core/test_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MappingProxyType
from typing import Protocol

import pytest
Expand All @@ -37,9 +38,11 @@
from dimos.spec.utils import Spec

# Disable Rerun for tests (prevents viewer spawn and gRPC flush errors)
_BUILD_WITHOUT_RERUN = {
"cli_config_overrides": {"viewer": "none"},
}
_BUILD_WITHOUT_RERUN = MappingProxyType(
{
"g": {"viewer": "none"},
}
)


class Scratch:
Expand Down Expand Up @@ -152,6 +155,14 @@ def test_autoconnect() -> None:
)


def test_config() -> None:
blueprint = autoconnect(module_a(), module_b())
config = blueprint.config()
assert config.model_fields.keys() == {"modulea", "moduleb", "g"}
assert config.model_fields["modulea"].annotation == ModuleA.default_config | None
assert config.model_fields["moduleb"].annotation == ModuleB.default_config | None


def test_transports() -> None:
custom_transport = LCMTransport("/custom_topic", Data1)
blueprint_set = autoconnect(module_a(), module_b()).transports(
Expand All @@ -175,7 +186,7 @@ def test_global_config() -> None:
def test_build_happy_path() -> None:
blueprint_set = autoconnect(module_a(), module_b(), module_c())

coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN)
coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy())

try:
assert isinstance(coordinator, ModuleCoordinator)
Expand Down Expand Up @@ -304,7 +315,7 @@ def test_remapping() -> None:
assert ("color_image", Data1) not in blueprint_set._all_name_types

# Build and verify streams work
coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN)
coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy())

try:
source_instance = coordinator.get_instance(SourceModule)
Expand Down Expand Up @@ -354,7 +365,7 @@ def test_future_annotations_autoconnect() -> None:

blueprint_set = autoconnect(FutureModuleOut.blueprint(), FutureModuleIn.blueprint())

coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN)
coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy())

try:
out_instance = coordinator.get_instance(FutureModuleOut)
Expand Down Expand Up @@ -446,7 +457,7 @@ def test_module_ref_direct() -> None:
coordinator = autoconnect(
Calculator1.blueprint(),
Mod1.blueprint(),
).build(**_BUILD_WITHOUT_RERUN)
).build(_BUILD_WITHOUT_RERUN.copy())

try:
mod1 = coordinator.get_instance(Mod1)
Expand All @@ -462,7 +473,7 @@ def test_module_ref_spec() -> None:
coordinator = autoconnect(
Calculator1.blueprint(),
Mod2.blueprint(),
).build(**_BUILD_WITHOUT_RERUN)
).build(_BUILD_WITHOUT_RERUN.copy())

try:
mod2 = coordinator.get_instance(Mod2)
Expand All @@ -477,7 +488,7 @@ def test_module_ref_spec() -> None:
def test_disabled_modules_are_skipped_during_build() -> None:
blueprint_set = autoconnect(module_a(), module_b(), module_c()).disabled_modules(ModuleC)

coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN)
coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy())

try:
assert coordinator.get_instance(ModuleA) is not None
Expand Down Expand Up @@ -515,7 +526,7 @@ def test_module_ref_remap_ambiguous() -> None:
(Mod2, "calc", Calculator1),
]
)
.build(**_BUILD_WITHOUT_RERUN)
.build(_BUILD_WITHOUT_RERUN.copy())
)

try:
Expand Down
3 changes: 2 additions & 1 deletion dimos/core/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def test_worker_manager_parallel_deployment(create_worker_manager):
(SimpleModule, global_config, {}),
(AnotherModule, global_config, {}),
(ThirdModule, global_config, {}),
]
],
{},
)

assert len(modules) == 3
Expand Down
9 changes: 7 additions & 2 deletions dimos/core/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

Expand Down Expand Up @@ -61,7 +61,11 @@ def deploy(
actor = worker.deploy_module(module_class, global_config, kwargs=kwargs)
return RPCClient(actor, module_class)

def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient]:
def deploy_parallel(
self,
module_specs: Iterable[ModuleSpec],
blueprint_args: Mapping[str, Mapping[str, Any]],
) -> list[RPCClient]:
if self._closed:
raise RuntimeError("WorkerManager is closed")

Expand All @@ -76,6 +80,7 @@ def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient]
for module_class, global_config, kwargs in module_specs:
worker = self._select_worker()
worker.reserve_slot()
kwargs.update(blueprint_args.get(module_class.name, {}))
assignments.append((worker, module_class, global_config, kwargs))

def _deploy(
Expand Down
Loading
Loading