-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathbuilders.py
More file actions
29 lines (26 loc) · 1.07 KB
/
builders.py
File metadata and controls
29 lines (26 loc) · 1.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from __future__ import annotations
from config import Config
from model import BaseModelAdapter
from model import OpenPiRTCJaxAdapter
from model import OpenPiRTCTritonAdapter
from optimizer import BaseOptimizer
from optimizer import PassThroughOptimizer
from optimizer import TimeParameterizationMPC
def build_model(cfg: Config) -> BaseModelAdapter:
    """Construct the model adapter named by ``cfg.model.adapter``.

    The adapter name is looked up in a fixed registry of supported adapter
    classes. The matching class's ``from_config`` is then called with the
    ``cfg.model`` section.

    Args:
        cfg: Full configuration object; only ``cfg.model`` is read here.

    Returns:
        Whatever the selected adapter class's ``from_config`` returns
        (annotated as ``BaseModelAdapter``).

    Raises:
        ValueError: If ``cfg.model.adapter`` is not one of the registered
            adapter names.
    """
    model_section = cfg.model
    requested = model_section.adapter

    # Registry of supported adapter names -> adapter classes.
    registry = {
        "openpi_rtc_jax": OpenPiRTCJaxAdapter,
        "openpi_rtc_triton": OpenPiRTCTritonAdapter,
    }

    # Guard clause: reject unknown names before attempting construction.
    if requested not in registry:
        raise ValueError(f"Unsupported model.adapter={requested!r}")

    return registry[requested].from_config(model_section)
def build_optimizer(cfg: Config) -> BaseOptimizer:
    """Construct the optimizer named by ``cfg.inference.optimizer``.

    The optimizer name is looked up in a fixed registry of supported
    optimizer classes. The matching class's ``from_config`` is then called
    with the ``cfg.inference`` section.

    Args:
        cfg: Full configuration object; only ``cfg.inference`` is read here.

    Returns:
        Whatever the selected optimizer class's ``from_config`` returns
        (annotated as ``BaseOptimizer``).

    Raises:
        ValueError: If ``cfg.inference.optimizer`` is not one of the
            registered optimizer names.
    """
    inference_section = cfg.inference
    requested = inference_section.optimizer

    # Registry of supported optimizer names -> optimizer classes.
    # Note: the config key "timeaxis_smooth" selects TimeParameterizationMPC.
    registry = {
        "pass_through": PassThroughOptimizer,
        "timeaxis_smooth": TimeParameterizationMPC,
    }

    # Guard clause: reject unknown names before attempting construction.
    if requested not in registry:
        raise ValueError(f"Unsupported inference.optimizer={requested!r}")

    return registry[requested].from_config(inference_section)