diff --git a/configs/LeRobot/lerobot_datasource_base.toml b/configs/LeRobot/lerobot_datasource_base.toml new file mode 100644 index 000000000..7228f4848 --- /dev/null +++ b/configs/LeRobot/lerobot_datasource_base.toml @@ -0,0 +1,9 @@ + +# The training and testing dataset +datasource = "LeRobot" + +# IID or non-IID? LeRobot handles deterministic episode partitioning internally. +sampler = "iid" + +# The random seed used for deterministic train/test split and client partitioning. +random_seed = 1 diff --git a/configs/LeRobot/smolvla_fedavg_two_client_smoke.toml b/configs/LeRobot/smolvla_fedavg_two_client_smoke.toml new file mode 100644 index 000000000..45db563b9 --- /dev/null +++ b/configs/LeRobot/smolvla_fedavg_two_client_smoke.toml @@ -0,0 +1,70 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 2 + +# The number of clients selected in each round +per_round = 2 + +# Should the clients compute test accuracy locally? +do_test = false + +[server] +address = "127.0.0.1" +port = 8001 +simulate_wall_time = false +checkpoint_path = "checkpoints/lerobot/smolvla_fedavg_two_client_smoke" +model_path = "models/lerobot/smolvla_fedavg_two_client_smoke" + +[data] +include = "lerobot_datasource_base.toml" + +[trainer] + +# The type of the trainer +type = "lerobot" + +# The maximum number of training rounds +rounds = 1 + +# The maximum number of clients running concurrently +max_concurrency = 1 + +# The machine learning model +model_type = "smolvla" +model_name = "smolvla" + +# Number of epoches for local training in each communication round +epochs = 1 +batch_size = 2 +optimizer = "AdamW" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "adapter" +precision = "fp32" +device = "cpu" + +[parameters.dataset] +repo_id = "lerobot/pusht_image" +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] 
+image_size = [224, 224] +normalize = true +interpolation = "bilinear" + +[parameters.optimizer] +lr = 0.0001 +weight_decay = 0.0 diff --git a/configs/LeRobot/smolvla_full_finetune.toml b/configs/LeRobot/smolvla_full_finetune.toml new file mode 100644 index 000000000..89c83b6f6 --- /dev/null +++ b/configs/LeRobot/smolvla_full_finetune.toml @@ -0,0 +1,77 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 2 + +# The number of clients selected in each round +per_round = 2 + +# Should the clients compute test accuracy locally? +do_test = true + +[server] +address = "127.0.0.1" +port = 8002 +simulate_wall_time = false +checkpoint_path = "checkpoints/lerobot/smolvla_full_finetune" +model_path = "models/lerobot/smolvla_full_finetune" + +[data] +include = "lerobot_datasource_base.toml" + +[trainer] + +# The type of the trainer +type = "lerobot" + +# The maximum number of training rounds +rounds = 10 + +# The maximum number of clients running concurrently +max_concurrency = 1 + +# The machine learning model +model_type = "smolvla" +model_name = "smolvla" + +# Number of epoches for local training in each communication round +epochs = 5 + +# SmolVLA upstream fine-tuning guidance uses batch size 64. 
+batch_size = 64 +optimizer = "AdamW" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "full" +precision = "bf16" +device = "cuda" + +[parameters.dataset] +repo_id = "lerobot/pusht_image" +split_seed = 7 +train_split = 0.9 +task_aware_split = true +task_aware_partition = true +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] +image_size = [224, 224] +center_crop = true +normalize = true +interpolation = "bilinear" + +[parameters.optimizer] +lr = 0.0001 +weight_decay = 0.01 diff --git a/configs/LeRobot/smolvla_single_client_smoke.toml b/configs/LeRobot/smolvla_single_client_smoke.toml new file mode 100644 index 000000000..0c48f00f5 --- /dev/null +++ b/configs/LeRobot/smolvla_single_client_smoke.toml @@ -0,0 +1,70 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 1 + +# The number of clients selected in each round +per_round = 1 + +# Should the clients compute test accuracy locally? 
+do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/lerobot/smolvla_single_client_smoke" +model_path = "models/lerobot/smolvla_single_client_smoke" + +[data] +include = "lerobot_datasource_base.toml" + +[trainer] + +# The type of the trainer +type = "lerobot" + +# The maximum number of training rounds +rounds = 1 + +# The maximum number of clients running concurrently +max_concurrency = 1 + +# The machine learning model +model_type = "smolvla" +model_name = "smolvla" + +# Number of epoches for local training in each communication round +epochs = 1 +batch_size = 2 +optimizer = "AdamW" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "adapter" +precision = "fp32" +device = "cpu" + +[parameters.dataset] +repo_id = "lerobot/pusht_image" +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] +image_size = [224, 224] +normalize = true +interpolation = "bilinear" + +[parameters.optimizer] +lr = 0.0001 +weight_decay = 0.0 diff --git a/docs/docs/configurations/parameters.md b/docs/docs/configurations/parameters.md index 68cd6eddf..8a0119def 100644 --- a/docs/docs/configurations/parameters.md +++ b/docs/docs/configurations/parameters.md @@ -18,3 +18,65 @@ !!! example "loss_criterion" All the parameter settings that need to be passed as keyword parameters when initializing the loss criterion, such as `size_average`. The set of parameters permitted or needed depends on the loss criterion. + +## SmolVLA + LeRobot parameter contract + +`Config()` keeps nested keys under `[parameters]` as dot-accessible nodes. For +the SmolVLA + LeRobot integration, define the following sections. + +| Config key | Purpose | Consumption path | +| --- | --- | --- | +| `data.datasource = "LeRobot"` | Selects the robotics datasource family. 
| `plato.datasources.registry.get()` chooses the datasource module. | +| `trainer.type = "lerobot"` | Selects the robotics trainer backend. | `plato.trainers.registry.get()` chooses the trainer class. | +| `trainer.model_type = "smolvla"` | Selects the model family. | `plato.models.registry.get()` resolves the model factory. | +| `trainer.model_name = "smolvla"` | Selects the concrete model entry point. | `plato.models.registry.get()` resolves the model name. | +| `parameters.policy.type` | Policy family identifier (`smolvla` in v1). | Consumed by `plato/models/smolvla.py` and `plato/trainers/lerobot.py` via `Config().parameters.policy`. | +| `parameters.policy.path` | Pretrained policy source (Hub/local path). | Consumed by `plato/models/smolvla.py` via `Config().parameters.policy.path`. | +| `parameters.policy.finetune_mode` | Full fine-tune vs adapter mode switch. | Consumed by `plato/trainers/lerobot.py` to decide trainable params. | +| `parameters.policy.precision` | Runtime precision flag (`fp32`/`fp16`/`bf16`). | Consumed by `plato/trainers/lerobot.py` for dtype/autocast setup. | +| `parameters.policy.device` | Runtime device flag (`cpu`/`cuda`/`mps`). | Consumed by `plato/trainers/lerobot.py` for device placement. | +| `parameters.dataset.repo_id` | LeRobot dataset identifier. | Consumed by `plato/datasources/lerobot.py` dataset loader. | +| `parameters.dataset.delta_timestamps` | Temporal window selection per modality key. | Consumed by `plato/datasources/lerobot.py` sampling/windowing logic. | +| `parameters.transforms.*` | Image transform controls (`image_size`, `normalize`, interpolation/crop options). | Consumed by `plato/datasources/lerobot.py` preprocessing pipeline. 
| + +### Constructor-ready dictionaries + +SmolVLA/LeRobot components should convert section nodes to plain dictionaries +when passing keyword arguments into constructors: + +```python +from plato.config import Config + +cfg = Config() +policy_kwargs = cfg.parameters.policy._asdict() +dataset_kwargs = cfg.parameters.dataset._asdict() +transform_kwargs = cfg.parameters.transforms._asdict() +``` + +### Example + +```toml +[data] +datasource = "LeRobot" + +[trainer] +type = "lerobot" +model_type = "smolvla" +model_name = "smolvla" + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "full" +precision = "bf16" +device = "cuda" + +[parameters.dataset] +repo_id = "lerobot/pusht_image" +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] +image_size = [224, 224] +normalize = true +interpolation = "bilinear" +``` diff --git a/docs/docs/examples/case-studies/3. SmolVLA Trainer with LeRobot.md b/docs/docs/examples/case-studies/3. SmolVLA Trainer with LeRobot.md new file mode 100644 index 000000000..ba99ac107 --- /dev/null +++ b/docs/docs/examples/case-studies/3. SmolVLA Trainer with LeRobot.md @@ -0,0 +1,247 @@ +# SmolVLA Trainer with LeRobot + +This runbook is for operators running SmolVLA training in Plato with LeRobot datasets. It complements the parameter contract in [Configuration Parameters](../../configurations/parameters.md). + +## 1) Setup + +1. Install core dependencies: + +```bash +uv sync +``` + +2. Install robotics extras: + +```bash +uv sync --extra robotics +``` + +3. Authenticate to Hugging Face when using private repos: + +```bash +huggingface-cli login +``` + +4. Verify optional stack import: + +```bash +uv run python -c "import lerobot; print(lerobot.__version__)" +``` + +5. If your dataset is video-backed, ensure `ffmpeg` is installed on the host. 
+ +## 2) Config Profiles to Start From + +Use the configs added under `configs/LeRobot/`: + +- `configs/LeRobot/lerobot_datasource_base.toml`: shared LeRobot datasource include. +- `configs/LeRobot/smolvla_single_client_smoke.toml`: minimal single-client smoke run. +- `configs/LeRobot/smolvla_fedavg_two_client_smoke.toml`: 2-client FedAvg smoke run. +- `configs/LeRobot/smolvla_full_finetune.toml`: longer full fine-tune profile. + +## 3) Run Commands + +Single-client smoke: + +```bash +uv run python plato.py --config configs/LeRobot/smolvla_single_client_smoke.toml +``` + +Two-client federated smoke: + +```bash +uv run python plato.py --config configs/LeRobot/smolvla_fedavg_two_client_smoke.toml +``` + +Full-finetune profile: + +```bash +uv run python plato.py --config configs/LeRobot/smolvla_full_finetune.toml +``` + +## 4) Required Plato TOML Fields + +Minimum contract for this integration: + +```toml +[data] +datasource = "LeRobot" + +[trainer] +type = "lerobot" +model_type = "smolvla" +model_name = "smolvla" + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "adapter" # or "full" +precision = "fp32" +device = "cpu" # or "cuda" / "mps" + +[parameters.dataset] +repo_id = "lerobot/pusht_image" +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] +image_size = [224, 224] +normalize = true +interpolation = "bilinear" +``` + +## 5) Plato ↔ `lerobot-train` Mapping + +| Plato config field(s) | `lerobot-train` equivalent | Type | +| --- | --- | --- | +| `parameters.policy.path` | `--policy.path` | Direct | +| `parameters.dataset.repo_id` | `--dataset.repo_id` | Direct | +| `trainer.batch_size` | `--batch_size` | Direct | +| `parameters.policy.device` | `--policy.device` | Direct | +| `trainer.rounds` + `trainer.epochs` | `--steps` | Conceptual scheduling mapping | +| `server.checkpoint_path` / `server.model_path` | `--output_dir` | Conceptual output-location mapping | +| 
`parameters.dataset.delta_timestamps` | LeRobot dataset `delta_timestamps` usage during training | Conceptual data-window mapping | +| `parameters.policy.finetune_mode` (`full`/`adapter`) | Trainable-parameter strategy during policy training | Conceptual finetune-mode mapping | + +Notes: + +- Upstream LeRobot examples for SmolVLA commonly use `--steps`; Plato uses round/epoch scheduling. +- Adapter-mode behavior in Plato is implemented via `parameters.policy.finetune_mode` and + adapter parameter selection in the SmolVLA model wrapper. + +## 6) Troubleshooting + +### Missing optional robotics dependencies + +Symptom: + +- `ImportError: ... Install them with "uv sync --extra robotics" ...` + +Action: + +```bash +uv sync --extra robotics +``` + +### Dataset repo not configured + +Symptom: + +- `LeRobot datasource requires "parameters.dataset.repo_id" to be set.` + +Action: + +- Set `parameters.dataset.repo_id` in your TOML. + +### Private dataset/model access failure + +Symptoms: + +- SmolVLA load failure from `parameters.policy.path`. +- Dataset access/auth failures from the Hub. + +Actions: + +```bash +huggingface-cli login +# optionally for non-interactive runs +export HF_TOKEN= +``` + +### No episodes found + +Symptom: + +- `No episodes found for LeRobot dataset ""` + +Actions: + +- Verify `parameters.dataset.repo_id` exists and contains episodes. +- Confirm access permissions for the dataset repository. + +### Invalid `delta_timestamps` shape + +Symptom: + +- `"parameters.dataset.delta_timestamps" must be a mapping of key -> list[float].` + +Action: + +- Use mapping syntax, for example: + +```toml +[parameters.dataset] +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } +``` + +### Device/precision mismatch or OOM + +Symptoms: + +- CUDA/MPS initialization failures. +- Out-of-memory during training. + +Actions: + +- Start from `configs/LeRobot/smolvla_single_client_smoke.toml` (CPU, tiny batch). +- Reduce `trainer.batch_size`. 
+- Use `parameters.policy.device = "cpu"` for smoke checks. +- Move to `cuda` + higher batch sizes only after smoke passes. + +### FFmpeg / build issues in robotics stack + +Symptom: + +- Build/runtime errors mentioning FFmpeg or PyAV dependencies. + +Actions: + +- Install host FFmpeg libraries and build toolchain (`cmake`, `build-essential`, FFmpeg libs), then reinstall robotics extras. + + + + +## SmolVLA + LeRobot Optional Setup + +This setup path is optional. Core Plato federated workloads continue to use the default dependency set from `uv sync`. + +### Install the robotics extra + +From the repository root: + +```bash +uv sync --extra robotics +``` + +This installs `lerobot[smolvla]` and the associated training stack only when the +`robotics` extra is requested. + +### Environment gating + +When adding LeRobot-backed modules, keep imports guarded so non-robotics +environments fail with a clear action instead of a hard crash at import time. + +```python +try: + import lerobot +except ImportError as exc: + raise ImportError( + "LeRobot support is optional. Install with: uv sync --extra robotics" + ) from exc +``` + +### Runtime notes for SmolVLA/LeRobot + +- CUDA-capable GPUs are recommended for practical SmolVLA fine-tuning; CPU is + mainly suitable for smoke checks. +- Install `ffmpeg` on hosts that read video-backed LeRobot datasets. +- Authenticate with Hugging Face (`huggingface-cli login`) when accessing + private dataset repositories. +- LeRobot currently constrains the Torch stack used by this optional path; + if you need different Torch constraints for non-robotics research, keep a + separate virtual environment. 
+ +### Quick verification + +```bash +uv run python -c "import lerobot; print(lerobot.__version__)" +``` diff --git a/docs/docs/install.md b/docs/docs/install.md index 342426e35..f62f4289e 100644 --- a/docs/docs/install.md +++ b/docs/docs/install.md @@ -65,6 +65,20 @@ uv sync --extra mlx See the [Quick Start guide](quickstart.md#using-mlx-as-a-backend) for configuration details. +### Optional: SmolVLA + LeRobot Robotics Stack + +LeRobot and SmolVLA dependencies are available behind Plato's optional +`robotics` extra so default federated-learning installs remain lightweight: + +```bash +uv sync --extra robotics +``` + +See [SmolVLA + LeRobot Optional Setup](<examples/case-studies/3. SmolVLA Trainer with LeRobot.md>) for runtime +requirements and guarded-import guidance, and +[SmolVLA + LeRobot Runbook](<examples/case-studies/3. SmolVLA Trainer with LeRobot.md>) for setup-to-run +operational steps and troubleshooting. + ### Building the `plato-learn` PyPi Package The `plato-learn` PyPi package will be automatically built and published by a GitHub action workflow every time a release is created on GitHub. To build the package manually, follow these steps: diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index c5e428749..64f5f5ca9 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -69,6 +69,8 @@ nav: - Case Studies: - Federated LoRA Fine-Tuning: examples/case-studies/1. LoRA.md - Composable Trainer API: examples/case-studies/2. 
Composable Trainer.md + - SmolVLA + LeRobot Optional Setup: examples/case-studies/3. SmolVLA Trainer with LeRobot.md + - SmolVLA + LeRobot Runbook: examples/case-studies/3. SmolVLA Trainer with LeRobot.md - Configuration Settings: - Overview: configurations/overview.md - General: configurations/general.md diff --git a/examples/server_aggregation/fedadp/fedadp_algorithm.py b/examples/server_aggregation/fedadp/fedadp_algorithm.py index 8f841ab23..0956ddbfb 100644 --- a/examples/server_aggregation/fedadp/fedadp_algorithm.py +++ b/examples/server_aggregation/fedadp/fedadp_algorithm.py @@ -174,7 +174,9 @@ def _to_float_tensor(tensor: torch.Tensor) -> torch.Tensor: @staticmethod def _cast_tensor_like( - tensor: torch.Tensor, reference: torch.Tensor + tensor: torch.Tensor, + reference: torch.Tensor, + tensor_name: str = "tensor", ) -> torch.Tensor: """Cast a tensor to match the dtype of a reference tensor.""" if tensor.dtype == reference.dtype: diff --git a/examples/server_aggregation/moon/moon_algorithm.py b/examples/server_aggregation/moon/moon_algorithm.py index cd6ded9b7..87658644a 100644 --- a/examples/server_aggregation/moon/moon_algorithm.py +++ b/examples/server_aggregation/moon/moon_algorithm.py @@ -21,7 +21,9 @@ class Algorithm(fedavg.Algorithm): @staticmethod def _cast_tensor_like( - tensor: torch.Tensor, reference: torch.Tensor + tensor: torch.Tensor, + reference: torch.Tensor, + tensor_name: str = "tensor", ) -> torch.Tensor: """Cast a tensor to match a reference dtype (handles bool/int safely).""" if tensor.dtype == reference.dtype: diff --git a/plato/algorithms/fedavg.py b/plato/algorithms/fedavg.py index 605e0b945..e1e0b3d97 100644 --- a/plato/algorithms/fedavg.py +++ b/plato/algorithms/fedavg.py @@ -2,10 +2,14 @@ The federated averaging algorithm for PyTorch. 
""" +from __future__ import annotations + +import os from collections import OrderedDict -from collections.abc import MutableMapping -from typing import Any, Optional +from collections.abc import Iterable, Mapping, MutableMapping +from typing import Any +import torch from torch.nn import Module from plato.algorithms import base @@ -14,21 +18,231 @@ class Algorithm(base.Algorithm): """PyTorch-based federated averaging algorithm, used by both the client and the server.""" + @staticmethod + def _as_state_mapping(weights: Any, context: str) -> Mapping[str, torch.Tensor]: + """Validate and cast a state-dict-like payload.""" + if not isinstance(weights, Mapping): + raise TypeError(f"{context} must be a mapping of parameter names to tensors.") + return weights + + @staticmethod + def _to_transport_tensor( + tensor: torch.Tensor, tensor_name: str + ) -> torch.Tensor: + """ + Convert a tensor to a wire-safe representation for payload transport. + + Safetensor serialization in the current runtime path does not support + `torch.bfloat16` conversion through numpy. Cast bf16 payload tensors to + fp32 for transport, then cast back in `load_weights`. + """ + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Payload tensor '{tensor_name}' must be a torch.Tensor, " + f"received {type(tensor).__name__}." + ) + + prepared = tensor.detach().cpu().clone() + if prepared.dtype == torch.bfloat16: + return prepared.to(torch.float32) + return prepared + + @staticmethod + def _cast_tensor_like( + tensor: torch.Tensor, reference: torch.Tensor, tensor_name: str + ) -> torch.Tensor: + """Cast an incoming tensor to match the dtype expected by `reference`.""" + if tensor.shape != reference.shape: + raise ValueError( + f"Tensor shape mismatch for '{tensor_name}': " + f"received {tuple(tensor.shape)}, expected {tuple(reference.shape)}." 
+ ) + + if tensor.dtype == reference.dtype: + return tensor.detach() + + if reference.dtype == torch.bool: + if torch.is_floating_point(tensor): + return (tensor >= 0.5).detach() + return tensor.ne(0).detach() + + if torch.is_floating_point(reference) or torch.is_complex(reference): + return tensor.to(reference.dtype).detach() + + if torch.is_floating_point(tensor): + return torch.round(tensor).to(reference.dtype).detach() + + return tensor.to(reference.dtype).detach() + + @staticmethod + def _compute_tensor_delta( + current_weight: torch.Tensor, + baseline_weight: torch.Tensor, + tensor_name: str, + ) -> torch.Tensor: + """Compute a dtype-safe delta tensor for a parameter.""" + current_casted = Algorithm._cast_tensor_like( + current_weight, baseline_weight, tensor_name + ) + + if baseline_weight.dtype == torch.bool: + return current_casted.to(torch.int8) - baseline_weight.to(torch.int8) + + if torch.is_floating_point(baseline_weight) or torch.is_complex(baseline_weight): + return current_casted.to(baseline_weight.dtype) - baseline_weight + + return current_casted.to(torch.int64) - baseline_weight.to(torch.int64) + + @staticmethod + def _apply_tensor_delta( + baseline_weight: torch.Tensor, delta: torch.Tensor, tensor_name: str + ) -> torch.Tensor: + """Apply a delta tensor to the baseline tensor with dtype safeguards.""" + if delta.shape != baseline_weight.shape: + raise ValueError( + f"Delta shape mismatch for '{tensor_name}': " + f"received {tuple(delta.shape)}, expected {tuple(baseline_weight.shape)}." 
+ ) + + if baseline_weight.dtype == torch.bool: + if torch.is_floating_point(delta): + delta_integral = torch.round(delta).to(torch.int8) + else: + delta_integral = delta.to(torch.int8) + return (baseline_weight.to(torch.int8) + delta_integral).ne(0) + + if torch.is_floating_point(baseline_weight) or torch.is_complex(baseline_weight): + return baseline_weight + delta.to(baseline_weight.dtype) + + if torch.is_floating_point(delta): + delta_casted = torch.round(delta).to(baseline_weight.dtype) + else: + delta_casted = delta.to(baseline_weight.dtype) + return baseline_weight + delta_casted + + def _resolve_payload_limit_mb(self) -> float | None: + """Resolve an optional payload-size limit from model attrs or env vars.""" + model = self.require_model() + configured_limit = getattr(model, "plato_max_payload_size_mb", None) + if configured_limit is None: + configured_limit = os.getenv("PLATO_FEDAVG_MAX_PAYLOAD_MB") + + if configured_limit is None: + return None + + try: + limit_mb = float(configured_limit) + except (TypeError, ValueError) as exc: + raise ValueError( + f"Invalid payload size limit: {configured_limit!r}. " + "Expected a positive numeric value in megabytes." + ) from exc + + if limit_mb <= 0: + raise ValueError("Payload size limit must be greater than 0 MB.") + + return limit_mb + + @staticmethod + def _estimate_payload_size_bytes(weights: Mapping[str, torch.Tensor]) -> int: + """Estimate payload size by summing tensor storage bytes.""" + size_bytes = 0 + for name, tensor in weights.items(): + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Payload tensor '{name}' must be a torch.Tensor, " + f"received {type(tensor).__name__}." 
+ ) + size_bytes += tensor.numel() * tensor.element_size() + return size_bytes + + def _assert_payload_size(self, weights: Mapping[str, torch.Tensor], source: str) -> None: + """Enforce an optional payload-size safeguard.""" + limit_mb = self._resolve_payload_limit_mb() + if limit_mb is None: + return + + payload_size_mb = self._estimate_payload_size_bytes(weights) / 1024**2 + if payload_size_mb > limit_mb: + raise ValueError( + f"{source} payload size {payload_size_mb:.2f} MB exceeds " + f"configured limit {limit_mb:.2f} MB." + ) + + @staticmethod + def _resolve_adapter_parameter_names( + target_model: Module, state_dict: Mapping[str, torch.Tensor] + ) -> list[str] | None: + """Resolve parameter names to exchange for adapter-only finetuning.""" + finetune_mode = getattr(target_model, "plato_finetune_mode", None) + if not isinstance(finetune_mode, str) or finetune_mode.strip().lower() != "adapter": + return None + + trainable_names_attr = getattr(target_model, "plato_trainable_parameter_names", None) + names_from_attr = ( + [ + name + for name in trainable_names_attr + if isinstance(name, str) and name in state_dict + ] + if isinstance(trainable_names_attr, Iterable) + else [] + ) + + if names_from_attr: + return names_from_attr + + names_from_requires_grad = [ + name + for name, parameter in target_model.named_parameters() + if parameter.requires_grad and name in state_dict + ] + if names_from_requires_grad: + return names_from_requires_grad + + raise ValueError( + "Adapter finetune mode is enabled, but no trainable parameters " + "were resolved for federated payload exchange." 
+ ) + def compute_weight_deltas( self, baseline_weights: MutableMapping[str, Any], weights_received, ): """Compute the deltas between baseline weights and weights received.""" + baseline_mapping = self._as_state_mapping( + baseline_weights, context="baseline_weights" + ) + # Calculate updates from the received weights deltas = [] for weight in weights_received: + weight_mapping = self._as_state_mapping(weight, context="received weights") + self._assert_payload_size(weight_mapping, source="Received") + + unknown_keys = set(weight_mapping).difference(baseline_mapping) + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise KeyError(f"Received weights include unexpected parameter(s): {unknown}.") + delta = OrderedDict() - for name, current_weight in weight.items(): - baseline = baseline_weights[name] + for name, current_weight in weight_mapping.items(): + if not isinstance(current_weight, torch.Tensor): + raise TypeError( + f"Received tensor '{name}' must be a torch.Tensor, " + f"received {type(current_weight).__name__}." + ) + + baseline = baseline_mapping[name] + if not isinstance(baseline, torch.Tensor): + raise TypeError( + f"Baseline tensor '{name}' must be a torch.Tensor, " + f"received {type(baseline).__name__}." 
+ ) # Calculate update - _delta = current_weight - baseline + _delta = self._compute_tensor_delta(current_weight, baseline, name) delta[name] = _delta deltas.append(delta) @@ -37,10 +251,27 @@ def compute_weight_deltas( def update_weights(self, deltas): """Updates the existing model weights from the provided deltas.""" baseline_weights = self.extract_weights() + delta_mapping = self._as_state_mapping(deltas, context="deltas") updated_weights = OrderedDict() for name, weight in baseline_weights.items(): - updated_weights[name] = weight + deltas[name] + updated_weights[name] = weight + + unknown_keys = set(delta_mapping).difference(baseline_weights) + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise KeyError(f"Delta includes unexpected parameter(s): {unknown}.") + + for name, delta in delta_mapping.items(): + baseline = baseline_weights[name] + if not isinstance(delta, torch.Tensor): + raise TypeError( + f"Delta tensor '{name}' must be a torch.Tensor, " + f"received {type(delta).__name__}." 
+ ) + updated_weights[name] = self._apply_tensor_delta(baseline, delta, name) + + self._assert_payload_size(updated_weights, source="Updated") return updated_weights @@ -51,9 +282,43 @@ def extract_weights(self, model: Module | None = None): target_model = self.require_model() else: target_model = model - return target_model.cpu().state_dict() + + state_dict = target_model.state_dict() + adapter_names = self._resolve_adapter_parameter_names(target_model, state_dict) + keys_to_exchange = adapter_names or list(state_dict.keys()) + + payload = OrderedDict( + ( + name, + self._to_transport_tensor(state_dict[name], name), + ) + for name in keys_to_exchange + ) + self._assert_payload_size(payload, source="Extracted") + return payload def load_weights(self, weights): """Loads the model weights passed in as a parameter.""" + weights_mapping = self._as_state_mapping(weights, context="weights") + self._assert_payload_size(weights_mapping, source="Inbound") + model: Module = self.require_model() - model.load_state_dict(weights, strict=True) + current_state = model.state_dict() + + unknown_keys = set(weights_mapping).difference(current_state) + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise KeyError(f"Inbound weights include unexpected parameter(s): {unknown}.") + + merged_state = OrderedDict(current_state.items()) + for name, incoming_tensor in weights_mapping.items(): + if not isinstance(incoming_tensor, torch.Tensor): + raise TypeError( + f"Inbound tensor '{name}' must be a torch.Tensor, " + f"received {type(incoming_tensor).__name__}." 
+ ) + merged_state[name] = self._cast_tensor_like( + incoming_tensor, current_state[name], name + ) + + model.load_state_dict(merged_state, strict=True) diff --git a/plato/datasources/lerobot.py b/plato/datasources/lerobot.py new file mode 100644 index 000000000..f6d632ff0 --- /dev/null +++ b/plato/datasources/lerobot.py @@ -0,0 +1,771 @@ +"""LeRobot datasource with deterministic, episode-aware client partitioning.""" + +from __future__ import annotations + +import hashlib +import inspect +import logging +import random +from collections import defaultdict +from collections.abc import Mapping, Sequence +from typing import Any + +from plato.config import Config +from plato.datasources import base + +_DEFAULT_TRAIN_SPLIT = 0.8 +_DEFAULT_SPLIT_SEED = 1 +_DEFAULT_NORMALIZE_MEAN = (0.485, 0.456, 0.406) +_DEFAULT_NORMALIZE_STD = (0.229, 0.224, 0.225) + +_EPISODE_INDEX_KEYS = ("episode_index", "episode_id", "index") +_TASK_KEYS = ( + "task", + "task_name", + "language_instruction", + "language_instruction_2", + "language_instruction_3", + "task_id", +) + + +class _EmptyDataset: + """A minimal dataset object for empty episode partitions.""" + + targets: list[Any] = [] + classes: list[str] = [] + + def __len__(self) -> int: + return 0 + + def __getitem__(self, index: int): + raise IndexError(f"Empty dataset does not contain index {index}.") + + +class _MappedLeRobotDataset: + """Wrap a LeRobot dataset and attach Plato-friendly canonical keys.""" + + def __init__(self, dataset: Any): + self._dataset = dataset + self.targets = getattr(dataset, "targets", None) + self.classes = getattr(dataset, "classes", None) + + def __len__(self) -> int: + return len(self._dataset) + + def __getitem__(self, index: int) -> dict[str, Any]: + sample = self._dataset[index] + + if isinstance(sample, Mapping): + raw_sample = dict(sample) + else: + return { + "plato_inputs": sample, + "plato_targets": None, + "plato_metadata": {}, + } + + inputs: dict[str, Any] = {} + targets: dict[str, Any] = {} + 
metadata: dict[str, Any] = {} + + for key, value in raw_sample.items(): + if key.startswith("observation"): + inputs[key] = value + elif key == "action" or key.startswith("action."): + targets[key] = value + else: + metadata[key] = value + + mapped = dict(raw_sample) + mapped["plato_inputs"] = inputs + mapped["plato_targets"] = raw_sample.get("action", targets or None) + mapped["plato_metadata"] = metadata + return mapped + + def __getattr__(self, name: str) -> Any: + return getattr(self._dataset, name) + + +def _import_lerobot() -> tuple[Any, Any]: + try: + from lerobot.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, + ) + except ImportError as exc: # pragma: no cover - exercised without robotics extra. + raise ImportError( + "LeRobot datasource requires optional robotics dependencies. " + "Install them with \"uv sync --extra robotics\" before using " + '"data.datasource = \"LeRobot\"". ' + ) from exc + + return LeRobotDataset, LeRobotDatasetMetadata + + +def _to_plain(value: Any) -> Any: + if value is None: + return None + if isinstance(value, Mapping): + return {str(key): _to_plain(val) for key, val in value.items()} + if hasattr(value, "_asdict"): + return {str(key): _to_plain(val) for key, val in value._asdict().items()} + if isinstance(value, (list, tuple, set)): + return [_to_plain(item) for item in value] + return value + + +def _to_plain_dict(value: Any) -> dict[str, Any]: + plain = _to_plain(value) + if plain is None: + return {} + if isinstance(plain, Mapping): + return {str(key): val for key, val in plain.items()} + raise TypeError(f"Expected mapping-like configuration, got {type(value)}.") + + +def _as_int(value: Any) -> int | None: + if value is None: + return None + if isinstance(value, bool): + return int(value) + if isinstance(value, int): + return value + try: + return int(value) + except (TypeError, ValueError): + return None + + +def _stable_seed(seed: int, key: str) -> int: + digest = 
hashlib.sha256(f"{seed}:{key}".encode("utf-8")).digest() + return int.from_bytes(digest[:8], "big") + + +def _parse_size(value: Any) -> tuple[int, int] | int | None: + if value is None: + return None + if isinstance(value, int): + return value + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + parts = [int(item) for item in value] + if not parts: + return None + if len(parts) == 1: + return parts[0] + return (parts[0], parts[1]) + raise TypeError("Expected an int or sequence for transform size.") + + +def _parse_float_sequence(value: Any, default: Sequence[float]) -> tuple[float, ...]: + if value is None: + return tuple(float(item) for item in default) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + parsed = [float(item) for item in value] + if parsed: + return tuple(parsed) + raise TypeError("Expected a numeric sequence for normalization values.") + + +def _resolve_interpolation(value: Any) -> Any: + if value is None: + return None + + try: + from torchvision.transforms import InterpolationMode + except ImportError: + return None + + interpolation_map = { + "nearest": InterpolationMode.NEAREST, + "bilinear": InterpolationMode.BILINEAR, + "bicubic": InterpolationMode.BICUBIC, + "lanczos": InterpolationMode.LANCZOS, + } + return interpolation_map.get(str(value).strip().lower(), None) + + +def _build_image_transforms(transform_cfg: Mapping[str, Any]) -> Any | None: + if not transform_cfg: + return None + + enable = transform_cfg.get("enable", True) + if isinstance(enable, bool) and not enable: + return None + + try: + import torch + from torchvision import transforms as tv_transforms + except ImportError as exc: + raise ImportError( + "LeRobot image transforms require torchvision and torch." 
+ ) from exc + + transform_steps: list[Any] = [] + + image_size = _parse_size(transform_cfg.get("image_size")) + interpolation = _resolve_interpolation(transform_cfg.get("interpolation")) + + if image_size is not None: + resize_kwargs: dict[str, Any] = {} + if interpolation is not None: + resize_kwargs["interpolation"] = interpolation + transform_steps.append(tv_transforms.Resize(image_size, **resize_kwargs)) + + center_crop_cfg = transform_cfg.get("center_crop", None) + crop_size = None + if isinstance(center_crop_cfg, bool): + if center_crop_cfg and image_size is not None: + crop_size = image_size + elif center_crop_cfg is not None: + crop_size = _parse_size(center_crop_cfg) + elif transform_cfg.get("crop_size") is not None: + crop_size = _parse_size(transform_cfg.get("crop_size")) + + if crop_size is not None: + transform_steps.append(tv_transforms.CenterCrop(crop_size)) + + normalize = bool(transform_cfg.get("normalize", False)) + if normalize: + convert_dtype = getattr(tv_transforms, "ConvertImageDtype", None) + if callable(convert_dtype): + transform_steps.append(convert_dtype(torch.float32)) + + mean = _parse_float_sequence( + transform_cfg.get("mean"), + _DEFAULT_NORMALIZE_MEAN, + ) + std = _parse_float_sequence(transform_cfg.get("std"), _DEFAULT_NORMALIZE_STD) + transform_steps.append(tv_transforms.Normalize(mean=mean, std=std)) + + if not transform_steps: + return None + + compose = getattr(tv_transforms, "Compose", None) + if not callable(compose): + return None + + return compose(transform_steps) + + +def _normalize_delta_timestamps(value: Any) -> dict[str, list[float]] | None: + if value is None: + return None + + parsed = _to_plain(value) + if not isinstance(parsed, Mapping): + raise TypeError( + '"parameters.dataset.delta_timestamps" must be a mapping of ' + "key -> list[float]." 
+ ) + + normalized: dict[str, list[float]] = {} + for key, offsets in parsed.items(): + if not isinstance(offsets, Sequence) or isinstance(offsets, (str, bytes)): + raise TypeError( + f"Delta timestamps for '{key}' must be a list-like sequence." + ) + normalized[str(key)] = [float(offset) for offset in offsets] + + return normalized + + +def _columnar_to_rows(columns: Mapping[str, Any]) -> list[dict[str, Any]]: + normalized: dict[str, list[Any]] = {} + max_len = 0 + + for key, value in columns.items(): + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + values = list(value) + elif hasattr(value, "tolist"): + values = list(value.tolist()) + else: + values = [value] + + normalized[str(key)] = values + max_len = max(max_len, len(values)) + + rows: list[dict[str, Any]] = [] + for row_idx in range(max_len): + row: dict[str, Any] = {} + for key, values in normalized.items(): + if row_idx < len(values): + row[key] = values[row_idx] + rows.append(row) + + return rows + + +def _episode_rows(episodes: Any) -> list[dict[str, Any]]: + if episodes is None: + return [] + + if isinstance(episodes, Mapping): + return _columnar_to_rows(episodes) + + to_pylist = getattr(episodes, "to_pylist", None) + if callable(to_pylist): + pylist = to_pylist() + if isinstance(pylist, list): + return [ + dict(item) if isinstance(item, Mapping) else {"value": item} + for item in pylist + ] + + to_dict = getattr(episodes, "to_dict", None) + if callable(to_dict): + try: + as_dict = to_dict(orient="list") + except TypeError: + as_dict = to_dict() + if isinstance(as_dict, Mapping): + return _columnar_to_rows(as_dict) + + if isinstance(episodes, Sequence) and not isinstance(episodes, (str, bytes)): + rows: list[dict[str, Any]] = [] + for entry in episodes: + if isinstance(entry, Mapping): + rows.append(dict(entry)) + elif hasattr(entry, "_asdict"): + rows.append(_to_plain_dict(entry)) + else: + rows.append({"value": entry}) + return rows + + return [] + + +def 
_resolve_episode_indices(metadata: Any) -> list[int]: + total_episodes = _as_int(getattr(metadata, "total_episodes", None)) + + if total_episodes is None: + total_episodes = len(_episode_rows(getattr(metadata, "episodes", None))) + + if total_episodes is None or total_episodes < 0: + return [] + + return list(range(total_episodes)) + + +def _resolve_task_name(row: Mapping[str, Any], tasks_lookup: Any) -> str | None: + task_index = _as_int(row.get("task_index")) + if task_index is not None and isinstance(tasks_lookup, Sequence): + if not isinstance(tasks_lookup, (str, bytes)) and 0 <= task_index < len( + tasks_lookup + ): + return str(tasks_lookup[task_index]) + + for key in _TASK_KEYS: + if key in row and row[key] is not None: + return str(row[key]) + + return None + + +def _resolve_episode_tasks(metadata: Any, episodes: Sequence[int]) -> dict[int, str | None]: + episode_tasks = {episode: None for episode in episodes} + episode_rows = _episode_rows(getattr(metadata, "episodes", None)) + tasks_lookup = _to_plain(getattr(metadata, "tasks", None)) + + for position, row in enumerate(episode_rows): + episode_index = None + for key in _EPISODE_INDEX_KEYS: + if key in row: + episode_index = _as_int(row.get(key)) + if episode_index is not None: + break + + if episode_index is None: + episode_index = position + + if episode_index not in episode_tasks: + continue + + task_name = _resolve_task_name(row, tasks_lookup) + if task_name is not None: + episode_tasks[episode_index] = task_name + + return episode_tasks + + +def _split_count(total: int, train_ratio: float) -> int: + if total <= 1: + return total + + if train_ratio <= 0: + return 0 + if train_ratio >= 1: + return total + + split_count = int(round(total * train_ratio)) + split_count = max(1, min(total - 1, split_count)) + return split_count + + +def _split_episodes( + episodes: Sequence[int], + episode_tasks: Mapping[int, str | None], + train_ratio: float, + seed: int, + task_aware: bool, +) -> tuple[list[int], 
list[int]]: + if not episodes: + return [], [] + + all_episodes = [int(episode) for episode in episodes] + has_task_info = task_aware and any(episode_tasks.get(ep) for ep in all_episodes) + + if has_task_info: + grouped: dict[str, list[int]] = defaultdict(list) + for episode in all_episodes: + grouped[str(episode_tasks.get(episode) or "__unknown_task__")].append( + episode + ) + + train_episodes: list[int] = [] + test_episodes: list[int] = [] + + for task_name in sorted(grouped): + task_episodes = sorted(grouped[task_name]) + rng = random.Random(_stable_seed(seed, f"split:{task_name}")) + rng.shuffle(task_episodes) + + split_idx = _split_count(len(task_episodes), train_ratio) + train_episodes.extend(task_episodes[:split_idx]) + test_episodes.extend(task_episodes[split_idx:]) + else: + shuffled = sorted(all_episodes) + random.Random(seed).shuffle(shuffled) + split_idx = _split_count(len(shuffled), train_ratio) + train_episodes = shuffled[:split_idx] + test_episodes = shuffled[split_idx:] + + if not train_episodes and test_episodes: + train_episodes.append(test_episodes.pop(0)) + + return sorted(train_episodes), sorted(test_episodes) + + +def _normalize_episode_list(value: Any) -> list[int] | None: + if value is None: + return None + + parsed = _to_plain(value) + + if isinstance(parsed, int): + return [int(parsed)] + + if isinstance(parsed, Sequence) and not isinstance(parsed, (str, bytes)): + return [int(item) for item in parsed] + + raise TypeError("Episode lists must be an int or a list of ints.") + + +def _resolve_episode_split( + all_episodes: Sequence[int], + explicit_train: list[int] | None, + explicit_test: list[int] | None, + episode_tasks: Mapping[int, str | None], + train_ratio: float, + seed: int, + task_aware: bool, +) -> tuple[list[int], list[int]]: + episode_set = set(int(episode) for episode in all_episodes) + + if explicit_train is None and explicit_test is None: + return _split_episodes(all_episodes, episode_tasks, train_ratio, seed, task_aware) + + 
train_episodes = [ + int(episode) for episode in (explicit_train or []) if int(episode) in episode_set + ] + test_episodes = [ + int(episode) + for episode in (explicit_test or []) + if int(episode) in episode_set and int(episode) not in train_episodes + ] + + if not train_episodes: + train_episodes = [ + episode for episode in all_episodes if episode not in set(test_episodes) + ] + + if not test_episodes: + test_episodes = [ + episode for episode in all_episodes if episode not in set(train_episodes) + ] + + if not train_episodes and test_episodes: + train_episodes.append(test_episodes.pop(0)) + + return sorted(train_episodes), sorted(test_episodes) + + +def _partition_episodes( + episodes: Sequence[int], + episode_tasks: Mapping[int, str | None], + total_clients: int, + client_id: int, + seed: int, + task_aware: bool, +) -> list[int]: + if not episodes: + return [] + + if total_clients <= 1 or client_id <= 0: + return sorted(int(episode) for episode in episodes) + + client_slot = (int(client_id) - 1) % int(total_clients) + all_episodes = [int(episode) for episode in episodes] + + has_task_info = task_aware and any(episode_tasks.get(ep) for ep in all_episodes) + + selected: list[int] = [] + + if has_task_info: + grouped: dict[str, list[int]] = defaultdict(list) + for episode in all_episodes: + grouped[str(episode_tasks.get(episode) or "__unknown_task__")].append( + episode + ) + + for task_name in sorted(grouped): + task_episodes = sorted(grouped[task_name]) + rng = random.Random(_stable_seed(seed, f"partition:{task_name}")) + rng.shuffle(task_episodes) + + for idx, episode in enumerate(task_episodes): + if idx % total_clients == client_slot: + selected.append(episode) + else: + shuffled = sorted(all_episodes) + random.Random(seed).shuffle(shuffled) + selected = [ + episode + for idx, episode in enumerate(shuffled) + if idx % total_clients == client_slot + ] + + if not selected: + fallback = sorted(all_episodes) + random.Random(_stable_seed(seed, 
"fallback")).shuffle(fallback) + selected = [fallback[client_slot % len(fallback)]] + + return sorted(selected) + + +def _resolve_default_seed() -> int: + data_cfg = getattr(Config(), "data", None) + configured_seed = _as_int(getattr(data_cfg, "random_seed", None)) + if configured_seed is not None: + return configured_seed + return _DEFAULT_SPLIT_SEED + + +def _resolve_total_clients(config: Any) -> int: + clients_cfg = getattr(config, "clients", None) + total_clients = _as_int(getattr(clients_cfg, "total_clients", 1)) + if total_clients is None or total_clients <= 0: + return 1 + return total_clients + + +def _filter_constructor_kwargs(dataset_cls: Any, kwargs: Mapping[str, Any]) -> dict[str, Any]: + try: + signature = inspect.signature(dataset_cls.__init__) + except (TypeError, ValueError): + return dict(kwargs) + + accepts_var_kwargs = any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in signature.parameters.values() + ) + if accepts_var_kwargs: + return dict(kwargs) + + valid_parameters = { + name for name in signature.parameters.keys() if name != "self" + } + filtered = {key: value for key, value in kwargs.items() if key in valid_parameters} + + dropped = sorted(set(kwargs.keys()) - set(filtered.keys())) + if dropped: + logging.warning( + "LeRobot datasource ignored unsupported dataset kwargs: %s", + ", ".join(dropped), + ) + + return filtered + + +class DataSource(base.DataSource): + """LeRobot datasource with deterministic train/test and per-client episode splits.""" + + def __init__(self, client_id: int = 0, **kwargs): + super().__init__() + + LeRobotDataset, LeRobotDatasetMetadata = _import_lerobot() + + config = Config() + params_cfg = getattr(config, "parameters", None) + + dataset_cfg = _to_plain_dict(getattr(params_cfg, "dataset", None)) + transform_cfg = _to_plain_dict(getattr(params_cfg, "transforms", None)) + + dataset_cfg.update(_to_plain_dict(kwargs.pop("dataset_kwargs", None))) + 
transform_cfg.update(_to_plain_dict(kwargs.pop("transform_kwargs", None))) + + for key in ( + "repo_id", + "delta_timestamps", + "split_seed", + "train_split", + "test_split", + "train_episodes", + "test_episodes", + "task_aware_split", + "task_aware_partition", + ): + if key in kwargs: + dataset_cfg[key] = kwargs.pop(key) + + for key in ( + "image_size", + "interpolation", + "center_crop", + "crop_size", + "normalize", + "mean", + "std", + "enable", + ): + if key in kwargs: + transform_cfg[key] = kwargs.pop(key) + + dataset_cfg.update(_to_plain_dict(kwargs)) + + repo_id = str(dataset_cfg.pop("repo_id", "")).strip() + if not repo_id: + raise ValueError( + "LeRobot datasource requires " + '"parameters.dataset.repo_id" to be set.' + ) + + train_split_raw = dataset_cfg.pop("train_split", _DEFAULT_TRAIN_SPLIT) + test_split_raw = dataset_cfg.pop("test_split", None) + + train_split = float(train_split_raw) + if test_split_raw is not None: + train_split = 1.0 - float(test_split_raw) + train_split = max(0.0, min(1.0, train_split)) + + split_seed = _as_int(dataset_cfg.pop("split_seed", None)) + if split_seed is None: + split_seed = _resolve_default_seed() + + task_aware_split = bool(dataset_cfg.pop("task_aware_split", True)) + task_aware_partition = bool(dataset_cfg.pop("task_aware_partition", True)) + + delta_timestamps = _normalize_delta_timestamps( + dataset_cfg.pop("delta_timestamps", None) + ) + + explicit_train_episodes = _normalize_episode_list( + dataset_cfg.pop("train_episodes", None) + ) + explicit_test_episodes = _normalize_episode_list( + dataset_cfg.pop("test_episodes", None) + ) + + metadata = LeRobotDatasetMetadata(repo_id) + all_episodes = _resolve_episode_indices(metadata) + + if not all_episodes: + raise ValueError(f"No episodes found for LeRobot dataset '{repo_id}'.") + + episode_tasks = _resolve_episode_tasks(metadata, all_episodes) + train_episodes, test_episodes = _resolve_episode_split( + all_episodes, + explicit_train_episodes, + 
explicit_test_episodes, + episode_tasks, + train_split, + split_seed, + task_aware_split, + ) + + total_clients = _resolve_total_clients(config) + resolved_client_id = int(client_id) + + client_train_episodes = _partition_episodes( + train_episodes, + episode_tasks, + total_clients, + resolved_client_id, + split_seed, + task_aware_partition, + ) + client_test_episodes = _partition_episodes( + test_episodes, + episode_tasks, + total_clients, + resolved_client_id, + split_seed + 1, + task_aware_partition, + ) + + image_transforms = _build_image_transforms(transform_cfg) + dataset_kwargs = _filter_constructor_kwargs(LeRobotDataset, dataset_cfg) + + if client_train_episodes: + train_dataset = LeRobotDataset( + repo_id, + episodes=client_train_episodes, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + **dataset_kwargs, + ) + else: + train_dataset = _EmptyDataset() + + if client_test_episodes: + test_dataset = LeRobotDataset( + repo_id, + episodes=client_test_episodes, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + **dataset_kwargs, + ) + else: + test_dataset = _EmptyDataset() + + self.trainset = _MappedLeRobotDataset(train_dataset) + self.testset = _MappedLeRobotDataset(test_dataset) + + self.repo_id = repo_id + self.client_id = resolved_client_id + self.train_episodes = client_train_episodes + self.test_episodes = client_test_episodes + self.meta = metadata + + logging.info( + "LeRobot datasource ready for client %s: train episodes=%s, test episodes=%s", + resolved_client_id, + len(client_train_episodes), + len(client_test_episodes), + ) + + @staticmethod + def input_shape(): + """Return shape hint from configured transform image size when available.""" + params_cfg = getattr(Config(), "parameters", None) + transform_cfg = _to_plain_dict(getattr(params_cfg, "transforms", None)) + + image_size = _parse_size(transform_cfg.get("image_size")) + if isinstance(image_size, tuple): + return (3, image_size[0], 
image_size[1]) + + raise ValueError( + "LeRobot datasource input shape is dataset-dependent. " + 'Set "parameters.transforms.image_size" ' + "for a static shape hint." + ) diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index 42525a5ef..0a609d5e5 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -11,6 +11,7 @@ feature, femnist, huggingface, + lerobot, lora, nanochat, purchase, @@ -31,7 +32,7 @@ "Nanochat": nanochat, } -registered_partitioned_datasources = {"FEMNIST": femnist} +registered_partitioned_datasources = {"FEMNIST": femnist, "LeRobot": lerobot} _datasource_aliases = { "STL10": ("Torchvision", {"dataset_name": "STL10"}), diff --git a/plato/models/registry.py b/plato/models/registry.py index 382a2d064..1dbda56cf 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -17,6 +17,7 @@ multilayer, nanochat, resnet, + smolvla, torch_hub, vgg, vit, @@ -42,6 +43,7 @@ "huggingface": huggingface.Model, "vit": vit.Model, "nanochat": nanochat.Model, + "smolvla": smolvla.Model, } registered_mlx_models = {} diff --git a/plato/models/smolvla.py b/plato/models/smolvla.py new file mode 100644 index 000000000..dd570d8d3 --- /dev/null +++ b/plato/models/smolvla.py @@ -0,0 +1,284 @@ +""" +Factory for SmolVLA policies integrated with Plato's model registry. 
+""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterable +from typing import Any + +import torch.nn as nn + +from plato.config import Config + +DEFAULT_POLICY_PATH = "lerobot/smolvla_base" +DEFAULT_FINETUNE_MODE = "full" +SUPPORTED_FINETUNE_MODES = {"full", "adapter"} +DEFAULT_ADAPTER_PATTERNS = ("adapter", "lora", "peft") + + +def _node_to_dict(node: Any) -> dict[str, Any]: + """Convert a config section into a plain dictionary.""" + if node is None: + return {} + if isinstance(node, dict): + return dict(node) + if hasattr(node, "_asdict"): + return dict(node._asdict()) + if hasattr(node, "__dict__"): + return { + key: value + for key, value in node.__dict__.items() + if not key.startswith("_") + } + raise TypeError("Unsupported policy configuration format.") + + +def _import_smolvla_policy() -> type[Any]: + """Import SmolVLAPolicy lazily to keep robotics dependencies optional.""" + try: + from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy + except ImportError as exc: # pragma: no cover - environment dependent + raise ImportError( + "SmolVLA requires LeRobot robotics dependencies. " + "Install them with `uv sync --extra robotics`." 
+ ) from exc + return SmolVLAPolicy + + +def _resolve_policy_config() -> dict[str, Any]: + """Read [parameters.policy] config values, if present.""" + parameters = getattr(Config(), "parameters", None) + policy = getattr(parameters, "policy", None) + return _node_to_dict(policy) + + +def _resolve_policy_path( + model_name: str | None, + policy_config: dict[str, Any], + overrides: dict[str, Any], +) -> str: + """Resolve pretrained source from kwargs, config, or defaults.""" + candidate = ( + overrides.get("policy_path") + or overrides.get("path") + or policy_config.get("path") + ) + + if candidate is None and isinstance(model_name, str) and model_name: + if model_name.lower() not in {"smolvla", "smolvla_base"}: + candidate = model_name + + if candidate is None: + candidate = DEFAULT_POLICY_PATH + + if not isinstance(candidate, str): + raise TypeError("SmolVLA policy path must be provided as a string.") + + resolved = candidate.strip() + if not resolved: + raise ValueError("SmolVLA policy path cannot be empty.") + + if resolved.lower() == "smolvla_base": + return DEFAULT_POLICY_PATH + + return resolved + + +def _resolve_finetune_mode( + policy_config: dict[str, Any], + overrides: dict[str, Any], +) -> str: + """Resolve the finetune mode with validation.""" + mode = overrides.get("finetune_mode", policy_config.get("finetune_mode")) + if mode is None: + mode = DEFAULT_FINETUNE_MODE + + if not isinstance(mode, str): + raise TypeError("`finetune_mode` must be a string.") + + normalized_mode = mode.strip().lower() + if normalized_mode not in SUPPORTED_FINETUNE_MODES: + raise ValueError( + "Unsupported SmolVLA finetune mode " + f"'{mode}'. Expected one of {sorted(SUPPORTED_FINETUNE_MODES)}." 
+ ) + return normalized_mode + + +def _resolve_adapter_patterns( + policy_config: dict[str, Any], + overrides: dict[str, Any], +) -> list[str]: + """Resolve adapter parameter patterns from kwargs or config.""" + raw_patterns = overrides.get( + "adapter_parameter_patterns", + policy_config.get("adapter_parameter_patterns"), + ) + + if raw_patterns is None: + return list(DEFAULT_ADAPTER_PATTERNS) + + if isinstance(raw_patterns, str): + parsed = [token.strip() for token in raw_patterns.split(",") if token.strip()] + return parsed or list(DEFAULT_ADAPTER_PATTERNS) + + if isinstance(raw_patterns, Iterable): + parsed = [ + token.strip() + for token in raw_patterns + if isinstance(token, str) and token.strip() + ] + return parsed or list(DEFAULT_ADAPTER_PATTERNS) + + raise TypeError( + "`adapter_parameter_patterns` must be a comma-separated string " + "or list of strings." + ) + + +def _count_trainable_parameters(model: nn.Module) -> int: + return sum( + parameter.numel() for parameter in model.parameters() if parameter.requires_grad + ) + + +def _apply_finetune_mode( + policy: nn.Module, + finetune_mode: str, + adapter_patterns: list[str], +) -> dict[str, Any]: + """Set requires_grad according to the requested finetune mode.""" + named_parameters = list(policy.named_parameters()) + if not named_parameters: + raise ValueError("Loaded SmolVLA policy has no named parameters.") + + if finetune_mode == "full": + for _, parameter in named_parameters: + parameter.requires_grad = True + + trainable_names = [name for name, _ in named_parameters] + return { + "trainable_names": trainable_names, + "fallback_mode": None, + } + + original_trainable_names = { + name for name, parameter in named_parameters if parameter.requires_grad + } + + for _, parameter in named_parameters: + parameter.requires_grad = False + + lowered_patterns = [pattern.lower() for pattern in adapter_patterns] + matched_names: set[str] = set() + for name, parameter in named_parameters: + lowered_name = 
name.lower() + if any(pattern in lowered_name for pattern in lowered_patterns): + parameter.requires_grad = True + matched_names.add(name) + + fallback_mode: str | None = None + if not matched_names and original_trainable_names: + for name, parameter in named_parameters: + if name in original_trainable_names: + parameter.requires_grad = True + matched_names.add(name) + fallback_mode = "original_requires_grad" + logging.warning( + "SmolVLA adapter mode found no parameter names matching patterns %s; " + "falling back to model-defined requires_grad flags.", + adapter_patterns, + ) + + if _count_trainable_parameters(policy) == 0: + raise ValueError( + "SmolVLA adapter mode selected, but no trainable parameters were " + "resolved. Configure `adapter_parameter_patterns` or set " + "`finetune_mode` to 'full'." + ) + + return { + "trainable_names": sorted(matched_names), + "fallback_mode": fallback_mode, + } + + +def _ensure_checkpoint_compatibility(policy: nn.Module) -> None: + """Validate required methods for Plato aggregation and checkpoint flow.""" + required_methods = ("state_dict", "load_state_dict", "save_pretrained") + missing = [name for name in required_methods if not hasattr(policy, name)] + if missing: + joined = ", ".join(missing) + raise TypeError( + "Loaded SmolVLA policy is incompatible with Plato checkpoints; " + f"missing method(s): {joined}." + ) + + +class Model: + """Factory for LeRobot SmolVLA policies.""" + + @staticmethod + def get(model_name: str | None = None, **kwargs: Any) -> nn.Module: + """ + Build a SmolVLA policy. + + Keyword Args: + policy_path/path: Hugging Face repo id or local path for pretrained policy. + token: HF token for private repositories. + strict: Strictness flag forwarded to SmolVLAPolicy.from_pretrained(). + finetune_mode: "full" or "adapter". + adapter_parameter_patterns: Names/patterns used to select adapter params. 
+ """ + policy_config = _resolve_policy_config() + policy_path = _resolve_policy_path(model_name, policy_config, kwargs) + + token = kwargs.get("token", policy_config.get("token", os.getenv("HF_TOKEN"))) + strict = kwargs.get("strict", policy_config.get("strict", True)) + + SmolVLAPolicy = _import_smolvla_policy() + + try: + policy = SmolVLAPolicy.from_pretrained( + policy_path, + token=token, + strict=strict, + ) + except Exception as exc: # pragma: no cover - exercised via integration + raise RuntimeError( + "Failed to load SmolVLA policy from " + f"'{policy_path}'. Check `parameters.policy.path` and access token." + ) from exc + + if not isinstance(policy, nn.Module): + raise TypeError( + "SmolVLA policy loader returned a non-module object. " + "Expected a torch.nn.Module-compatible policy." + ) + + finetune_mode = _resolve_finetune_mode(policy_config, kwargs) + adapter_patterns = _resolve_adapter_patterns(policy_config, kwargs) + trainable_metadata = _apply_finetune_mode( + policy, + finetune_mode=finetune_mode, + adapter_patterns=adapter_patterns, + ) + + _ensure_checkpoint_compatibility(policy) + + trainable_count = _count_trainable_parameters(policy) + setattr(policy, "plato_policy_path", policy_path) + setattr(policy, "plato_finetune_mode", finetune_mode) + setattr(policy, "plato_adapter_patterns", tuple(adapter_patterns)) + setattr(policy, "plato_adapter_fallback_mode", trainable_metadata["fallback_mode"]) + setattr(policy, "plato_trainable_parameter_count", trainable_count) + setattr( + policy, + "plato_trainable_parameter_names", + tuple(trainable_metadata["trainable_names"]), + ) + + return policy diff --git a/plato/samplers/modality_iid.py b/plato/samplers/modality_iid.py index 66ed03c85..351f11862 100644 --- a/plato/samplers/modality_iid.py +++ b/plato/samplers/modality_iid.py @@ -16,7 +16,7 @@ class Sampler(base.Sampler): """Create a data sampler for each client to use a randomly divided partition of the dataset.""" - def __init__(self, datasource, 
client_id): + def __init__(self, datasource, client_id, testing=False): super().__init__() self.client_id = client_id diff --git a/plato/samplers/modality_quantity_noniid.py b/plato/samplers/modality_quantity_noniid.py index 606762d77..f17113ba7 100644 --- a/plato/samplers/modality_quantity_noniid.py +++ b/plato/samplers/modality_quantity_noniid.py @@ -17,7 +17,7 @@ class Sampler(base.Sampler): """Create a data sampler for each client to use a randomly divided partition of the dataset.""" - def __init__(self, datasource, client_id): + def __init__(self, datasource, client_id, testing=False): super().__init__() self.client_id = client_id if hasattr(datasource, "get_modality_name"): diff --git a/plato/trainers/lerobot.py b/plato/trainers/lerobot.py new file mode 100644 index 000000000..e66d4ae9e --- /dev/null +++ b/plato/trainers/lerobot.py @@ -0,0 +1,689 @@ +"""Composable trainer for LeRobot policies such as SmolVLA.""" + +from __future__ import annotations + +import contextlib +import logging +from collections.abc import Callable, Iterable, Mapping +from typing import Any, cast + +import torch +import torch.nn as nn +import torch.utils.data +from torch.utils.data._utils.collate import default_collate + +from plato.config import Config +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import CustomCollateFnDataLoaderStrategy +from plato.trainers.strategies.base import ( + TestingStrategy, + TrainingContext, + TrainingStepStrategy, +) + +_RESERVED_KEYS = frozenset({"plato_inputs", "plato_targets", "plato_metadata"}) +_SUPPORTED_POLICY_PRECISIONS = frozenset({"fp32", "fp16", "bf16"}) + + +def _config_node_to_dict(node: Any) -> dict[str, Any]: + """Convert config sections to plain dictionaries.""" + if node is None: + return {} + if isinstance(node, dict): + return dict(node) + if hasattr(node, "_asdict"): + return dict(node._asdict()) + if hasattr(node, "__dict__"): + return { + key: value + for key, value in 
node.__dict__.items() + if not key.startswith("_") + } + return {} + + +def _import_make_pre_post_processors() -> Callable[..., tuple[Callable, Callable]]: + """Import LeRobot pre/post processors lazily to keep robotics deps optional.""" + try: + from lerobot.policies.factory import make_pre_post_processors + except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "LeRobot trainer requires optional robotics dependencies. " + "Install them with `uv sync --extra robotics`." + ) from exc + + return make_pre_post_processors + + +def _move_to_device(value: Any, device: torch.device | str) -> Any: + """Recursively move tensors inside nested containers to target device.""" + if torch.is_tensor(value): + return value.to(device) + if isinstance(value, Mapping): + return {key: _move_to_device(item, device) for key, item in value.items()} + if isinstance(value, list): + return [_move_to_device(item, device) for item in value] + if isinstance(value, tuple): + return tuple(_move_to_device(item, device) for item in value) + if hasattr(value, "to"): + try: + return value.to(device) + except (TypeError, AttributeError): + return value + return value + + +def _resolve_batch_size(value: Any) -> int | None: + """Infer batch size from the first tensor-like value in a nested structure.""" + if torch.is_tensor(value): + if value.ndim == 0: + return 1 + return int(value.shape[0]) + if isinstance(value, Mapping): + for nested in value.values(): + size = _resolve_batch_size(nested) + if size is not None: + return size + if isinstance(value, list): + return len(value) if value else None + return None + + +def _collate_values(values: list[Any]) -> Any: + """Collate values with a robust fallback for heterogeneous metadata fields.""" + if not values: + return values + if any(value is None for value in values): + return list(values) + + if all(isinstance(value, Mapping) for value in values): + keys: list[str] = [] + for value in values: + for key in value: + 
if key not in keys: + keys.append(str(key)) + return { + key: _collate_values( + [cast(Mapping[str, Any], value).get(key) for value in values] + ) + for key in keys + } + + try: + return default_collate(values) + except (TypeError, RuntimeError): + return list(values) + + +def _extract_tensor_label(label: Any) -> torch.Tensor | None: + """Extract a tensor label from common LeRobot sample target structures.""" + if torch.is_tensor(label): + return label + if isinstance(label, Mapping): + action = label.get("action") + if torch.is_tensor(action): + return action + return None + + +class LeRobotBatch(dict): + """Dictionary batch wrapper that supports `.to(device)`.""" + + def to(self, device: torch.device | str): + for key, value in list(self.items()): + self[key] = _move_to_device(value, device) + return self + + +class LeRobotCollateWrapper: + """Collate LeRobot dict samples into `(inputs, labels)` for ComposableTrainer.""" + + def __call__( + self, + examples: Iterable[Any], + ) -> tuple[LeRobotBatch, torch.Tensor]: + example_list = list(examples) + if not example_list: + raise ValueError("LeRobot collate received an empty batch.") + + normalized_inputs: list[dict[str, Any]] = [] + raw_labels: list[Any] = [] + + for sample in example_list: + if isinstance(sample, Mapping): + sample_dict = dict(sample) + label = sample_dict.get("plato_targets") + if label is None: + label = sample_dict.get("action") + + payload = { + key: value + for key, value in sample_dict.items() + if key not in _RESERVED_KEYS + } + + if not payload: + plato_inputs = sample_dict.get("plato_inputs") + if isinstance(plato_inputs, Mapping): + payload = dict(plato_inputs) + + if "action" not in payload and torch.is_tensor(label): + payload["action"] = label + + normalized_inputs.append(payload) + raw_labels.append(label) + else: + normalized_inputs.append({"observation": sample}) + raw_labels.append(None) + + batched_inputs = _collate_values(normalized_inputs) + if not isinstance(batched_inputs, 
Mapping): + batched_inputs = {"inputs": batched_inputs} + batch_dict = LeRobotBatch(dict(batched_inputs)) + + tensor_labels = [_extract_tensor_label(label) for label in raw_labels] + if tensor_labels and all(label is not None for label in tensor_labels): + labels = _collate_values(cast(list[Any], tensor_labels)) + if torch.is_tensor(labels): + return batch_dict, labels + + action_tensor = batch_dict.get("action") + if torch.is_tensor(action_tensor): + return batch_dict, action_tensor + + batch_size = _resolve_batch_size(batch_dict) + if batch_size is None: + batch_size = len(example_list) + labels = torch.zeros(batch_size, dtype=torch.float32) + return batch_dict, labels + + +def _resolve_policy_forward( + model: nn.Module, + batch: Mapping[str, Any], + reduction: str, +) -> tuple[torch.Tensor, Mapping[str, Any]]: + """Call the policy and normalize output into `(loss_tensor, loss_dict)`.""" + forward_result = model.forward(batch, reduction=reduction) + + if torch.is_tensor(forward_result): + return forward_result, {} + + if isinstance(forward_result, tuple): + if len(forward_result) == 0: + raise ValueError("LeRobot policy forward returned an empty tuple.") + + first = forward_result[0] + second = forward_result[1] if len(forward_result) > 1 else {} + + if torch.is_tensor(first): + loss_dict = second if isinstance(second, Mapping) else {} + return first, cast(Mapping[str, Any], loss_dict) + + if torch.is_tensor(second): + loss_dict = first if isinstance(first, Mapping) else {} + return second, cast(Mapping[str, Any], loss_dict) + + if isinstance(first, Mapping): + maybe_loss = first.get("loss") + if torch.is_tensor(maybe_loss): + return maybe_loss, first + + if isinstance(forward_result, Mapping): + maybe_loss = forward_result.get("loss") + if torch.is_tensor(maybe_loss): + return maybe_loss, forward_result + + raise TypeError( + "LeRobot policy forward must return a tensor loss or a tuple containing " + "a tensor loss. Received: " + f"{type(forward_result)}." 
+ ) + + +def _apply_preprocessor( + batch: LeRobotBatch, + context: TrainingContext, +) -> LeRobotBatch: + """Apply the optional LeRobot preprocessor.""" + preprocessor = context.state.get("lerobot_preprocessor") + if preprocessor is None: + return batch + + processed = preprocessor(dict(batch)) + if not isinstance(processed, Mapping): + raise TypeError( + "LeRobot preprocessor must return a mapping batch. " + f"Received {type(processed)}." + ) + return LeRobotBatch(dict(processed)) + + +def _summarize_loss_dict(loss_dict: Mapping[str, Any]) -> dict[str, float]: + """Store scalar loss dictionary values for debugging/monitoring hooks.""" + summary: dict[str, float] = {} + for key, value in loss_dict.items(): + if torch.is_tensor(value): + try: + summary[str(key)] = float(value.detach().cpu().item()) + except (RuntimeError, ValueError): + continue + elif isinstance(value, (int, float)): + summary[str(key)] = float(value) + return summary + + +def _resolve_sampler_for_loader(sampler: Any) -> Any: + """Resolve a sampler config value to a torch DataLoader sampler object.""" + if sampler is None: + return None + if isinstance(sampler, torch.utils.data.Sampler): + return sampler + if isinstance(sampler, (list, range)): + return torch.utils.data.SubsetRandomSampler(sampler) + if hasattr(sampler, "get"): + return sampler.get() + return sampler + + +def _resolve_precision(precision: Any) -> str: + """Normalize policy precision values from config.""" + if precision is None: + return "fp32" + if not isinstance(precision, str): + raise TypeError("`parameters.policy.precision` must be a string.") + + normalized = precision.strip().lower() + if normalized not in _SUPPORTED_POLICY_PRECISIONS: + supported = ", ".join(sorted(_SUPPORTED_POLICY_PRECISIONS)) + raise ValueError( + "Unsupported `parameters.policy.precision` value " + f"'{precision}'. Expected one of: {supported}." 
def _resolve_runtime_device(device_value: Any, fallback_device: Any) -> torch.device:
    """
    Resolve runtime device from policy config, falling back to trainer default.

    Raises explicit errors when a requested accelerator is unavailable so users
    can detect mismatched config/environment early.

    :param device_value: Raw `parameters.policy.device` value, or None.
    :param fallback_device: Trainer default, used when no device is configured.
    :raises TypeError: if the configured device is not a string.
    :raises ValueError: for malformed or unsupported device strings.
    :raises RuntimeError: when the requested accelerator is unavailable.
    """
    if isinstance(fallback_device, torch.device):
        fallback = fallback_device
    else:
        fallback = torch.device(str(fallback_device))

    if device_value is None:
        return fallback
    if not isinstance(device_value, str):
        raise TypeError("`parameters.policy.device` must be a string.")

    normalized = device_value.strip().lower()
    if not normalized:
        # Empty/whitespace config value behaves the same as "unset".
        return fallback

    if normalized == "cpu":
        return torch.device("cpu")

    if normalized == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "`parameters.policy.device` is set to 'cuda' but CUDA is not "
                "available on this host."
            )
        # Bare "cuda" is pinned to device 0 for deterministic placement.
        return torch.device("cuda:0")

    if normalized.startswith("cuda:"):
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"`parameters.policy.device` is set to '{device_value}' but CUDA "
                "is not available on this host."
            )
        try:
            gpu_index = int(normalized.split(":", 1)[1])
        except (IndexError, ValueError) as exc:
            raise ValueError(
                f"Invalid CUDA device value: '{device_value}'."
            ) from exc
        if gpu_index < 0 or gpu_index >= torch.cuda.device_count():
            raise RuntimeError(
                f"`parameters.policy.device` requested CUDA device {gpu_index}, "
                f"but only {torch.cuda.device_count()} device(s) are available."
            )
        return torch.device(normalized)

    if normalized == "mps":
        # torch.backends.mps may not exist on older torch builds.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is None or not mps_backend.is_available():
            raise RuntimeError(
                "`parameters.policy.device` is set to 'mps' but MPS is not "
                "available on this host."
            )
        return torch.device("mps")

    raise ValueError(
        "Unsupported `parameters.policy.device` value "
        f"'{device_value}'. Expected cpu, cuda[:index], or mps."
    )


def _warn_fp32_fallback(
    context: TrainingContext,
    message: str,
    *args: Any,
) -> tuple[contextlib.AbstractContextManager[Any], bool]:
    """Log a one-time precision-fallback warning and return an fp32 no-op guard.

    Consolidates the previously copy-pasted warn-once blocks; the warning flag
    in `context.state` ensures at most one warning per training context.
    """
    if not context.state.get("lerobot_precision_warning_emitted"):
        logging.warning(message, *args)
        context.state["lerobot_precision_warning_emitted"] = True
    return contextlib.nullcontext(), False


def _autocast_context(
    context: TrainingContext,
) -> tuple[contextlib.AbstractContextManager[Any], bool]:
    """Resolve an autocast context from runtime precision and device settings.

    Returns `(context_manager, autocast_enabled)`. Unsupported combinations of
    precision and device fall back to fp32 with a one-time warning instead of
    raising, so training continues on mismatched hosts.
    """
    precision = str(context.state.get("lerobot_precision", "fp32")).lower()
    device = context.device

    if not isinstance(device, torch.device):
        device = torch.device(str(device))

    if precision == "fp32":
        return contextlib.nullcontext(), False

    if device.type == "cuda":
        if not torch.cuda.is_available():
            return _warn_fp32_fallback(
                context,
                "LeRobot precision '%s' requested, but CUDA is unavailable. "
                "Falling back to fp32 execution.",
                precision,
            )
        dtype = torch.float16 if precision == "fp16" else torch.bfloat16
        return torch.autocast(device_type="cuda", dtype=dtype), True

    if device.type == "cpu":
        # CPU autocast supports bf16 only.
        if precision == "bf16":
            return torch.autocast(device_type="cpu", dtype=torch.bfloat16), True
        return _warn_fp32_fallback(
            context,
            "LeRobot precision '%s' is not supported on CPU autocast. "
            "Falling back to fp32 execution.",
            precision,
        )

    if device.type == "mps":
        # MPS autocast supports fp16 only.
        if precision == "fp16":
            return torch.autocast(device_type="mps", dtype=torch.float16), True
        return _warn_fp32_fallback(
            context,
            "LeRobot precision '%s' is not supported on MPS autocast. "
            "Falling back to fp32 execution.",
            precision,
        )

    return _warn_fp32_fallback(
        context,
        "LeRobot precision '%s' is not supported on device type '%s'. "
        "Falling back to fp32 execution.",
        precision,
        device.type,
    )


class LeRobotTrainingStepStrategy(TrainingStepStrategy):
    """Training step strategy for LeRobot policies with dict-style batches."""

    def __init__(self, reduction: str = "mean"):
        # Forwarded verbatim to the policy's forward().
        self.reduction = reduction

    def training_step(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        examples: torch.Tensor | Mapping[str, Any],
        labels: torch.Tensor,  # pylint: disable=unused-argument
        loss_criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        context: TrainingContext,
    ) -> torch.Tensor:
        """Run one optimization step; the policy computes its own loss.

        `labels` and `loss_criterion` belong to the strategy interface but are
        unused: LeRobot policies consume the full batch dict and return a loss.

        :raises TypeError: if the batch is not a mapping, or the policy does
            not return a tensor loss.
        """
        optimizer.zero_grad()
        del labels, loss_criterion

        if not isinstance(examples, Mapping):
            raise TypeError(
                "LeRobot training expects dictionary-style batches. "
                f"Received {type(examples).__name__}."
            )

        autocast_guard, autocast_enabled = _autocast_context(context)
        context.state["lerobot_autocast_enabled"] = autocast_enabled
        with autocast_guard:
            batch = _apply_preprocessor(LeRobotBatch(dict(examples)), context)
            loss, loss_dict = _resolve_policy_forward(
                model,
                batch,
                reduction=self.reduction,
            )

        if not torch.is_tensor(loss):
            raise TypeError(
                "LeRobot policy forward did not return a tensor loss."
            )

        loss.backward()
        optimizer.step()

        context.state["optimizer_step_completed"] = True
        context.state["lerobot_loss_dict"] = _summarize_loss_dict(loss_dict)
        return loss.detach()


class LeRobotTestingStrategy(TestingStrategy):
    """Compute a stable average evaluation loss for regression checks."""

    def __init__(self, collate_fn: LeRobotCollateWrapper, reduction: str = "mean"):
        self.collate_fn = collate_fn
        self.reduction = reduction

    def test_model(
        self,
        model: nn.Module,
        config: dict[str, Any],
        testset,
        sampler,
        context: TrainingContext,
    ) -> float:
        """Return the sample-weighted mean loss over `testset`.

        Fix: cleanup (restoring train mode and removing the `eval_loader`
        context entry) now runs in a `finally` block, so a failing forward
        pass no longer leaves the model stuck in eval mode or a stale loader
        reference in the context state.
        """
        batch_size = int(config.get("batch_size", 1))
        sampler_obj = _resolve_sampler_for_loader(sampler)

        test_loader = torch.utils.data.DataLoader(
            testset,
            batch_size=batch_size,
            shuffle=False,
            sampler=sampler_obj,
            collate_fn=self.collate_fn,
        )

        model.to(context.device)
        model.eval()
        context.state["eval_loader"] = test_loader

        total_loss = 0.0
        total_weight = 0

        autocast_guard, autocast_enabled = _autocast_context(context)
        context.state["lerobot_autocast_enabled"] = autocast_enabled
        try:
            with torch.no_grad(), autocast_guard:
                for examples, labels in test_loader:
                    examples = examples.to(context.device)
                    labels = labels.to(context.device)

                    batch = _apply_preprocessor(examples, context)
                    loss, _ = _resolve_policy_forward(
                        model,
                        batch,
                        reduction=self.reduction,
                    )

                    # Weight each batch by its sample count for a stable mean.
                    batch_weight = int(labels.size(0))
                    if batch_weight <= 0:
                        inferred = _resolve_batch_size(batch)
                        batch_weight = inferred if inferred is not None else 1

                    total_loss += float(loss.detach().item()) * batch_weight
                    total_weight += batch_weight
        finally:
            # Always restore training mode and drop the loader reference,
            # even if the policy forward raised.
            model.train()
            context.state.pop("eval_loader", None)

        if total_weight == 0:
            return float("inf")

        eval_loss = total_loss / total_weight
        context.state["lerobot_eval_loss"] = eval_loss
        return eval_loss


def _resolve_dataset_stats(dataset: Any) -> Any:
    """Extract dataset statistics from LeRobot datasets/subsets when available.

    Checks `dataset.meta.stats` then `dataset.metadata.stats`, and recurses
    into wrapper datasets that expose a nested `.dataset` attribute.
    """
    if dataset is None:
        return None

    metadata_candidates = (
        getattr(dataset, "meta", None),
        getattr(dataset, "metadata", None),
    )
    for metadata in metadata_candidates:
        stats = getattr(metadata, "stats", None)
        if stats is not None:
            return stats

    nested_dataset = getattr(dataset, "dataset", None)
    if nested_dataset is not None and nested_dataset is not dataset:
        return _resolve_dataset_stats(nested_dataset)

    return None
available.""" + if dataset is None: + return None + + metadata_candidates = ( + getattr(dataset, "meta", None), + getattr(dataset, "metadata", None), + ) + for metadata in metadata_candidates: + stats = getattr(metadata, "stats", None) + if stats is not None: + return stats + + nested_dataset = getattr(dataset, "dataset", None) + if nested_dataset is not None and nested_dataset is not dataset: + return _resolve_dataset_stats(nested_dataset) + + return None + + +class Trainer(ComposableTrainer): + """Composable LeRobot trainer backend.""" + + def __init__(self, model=None, callbacks=None): + self._collate_wrapper = LeRobotCollateWrapper() + self._processors_initialised = False + self._pretrained_path = self._resolve_policy_path() + self._runtime_precision = self._resolve_policy_precision() + self._policy_device = self._resolve_policy_device() + self._preprocessor_factory: Callable[..., tuple[Callable, Callable]] | None = ( + None + ) + + super().__init__( + model=model, + callbacks=callbacks, + loss_strategy=None, + optimizer_strategy=None, + training_step_strategy=LeRobotTrainingStepStrategy(), + lr_scheduler_strategy=None, + model_update_strategy=None, + data_loader_strategy=CustomCollateFnDataLoaderStrategy( + collate_fn=self._collate_wrapper, + num_workers=0, + pin_memory=True, + ), + testing_strategy=LeRobotTestingStrategy(self._collate_wrapper), + ) + + resolved_device = _resolve_runtime_device(self._policy_device, self.device) + self.device = str(resolved_device) + self.context.device = resolved_device + self.context.state["lerobot_precision"] = self._runtime_precision + self.context.state["lerobot_runtime_device"] = str(resolved_device) + self.context.state["lerobot_preprocessor"] = None + self.context.state["lerobot_postprocessor"] = None + + @staticmethod + def _resolve_policy_path() -> str | None: + parameters = getattr(Config(), "parameters", None) + policy_cfg = _config_node_to_dict(getattr(parameters, "policy", None)) + candidate = 
policy_cfg.get("path") + if isinstance(candidate, str): + value = candidate.strip() + return value if value else None + return None + + @staticmethod + def _resolve_policy_precision() -> str: + parameters = getattr(Config(), "parameters", None) + policy_cfg = _config_node_to_dict(getattr(parameters, "policy", None)) + return _resolve_precision(policy_cfg.get("precision", "fp32")) + + @staticmethod + def _resolve_policy_device() -> str | None: + parameters = getattr(Config(), "parameters", None) + policy_cfg = _config_node_to_dict(getattr(parameters, "policy", None)) + candidate = policy_cfg.get("device") + if candidate is None: + return None + if not isinstance(candidate, str): + raise TypeError("`parameters.policy.device` must be a string.") + value = candidate.strip() + return value if value else None + + def _resolve_model_pretrained_path(self) -> str | None: + model = self._require_model() + model_path = getattr(model, "plato_policy_path", None) + if isinstance(model_path, str) and model_path.strip(): + return model_path.strip() + return self._pretrained_path + + def _ensure_pre_post_processors(self, dataset: Any) -> None: + if self._processors_initialised: + return + + model = self._require_model() + policy_config = getattr(model, "config", None) + if policy_config is None: + raise AttributeError( + "LeRobot trainer expects the model to expose a `config` attribute " + "compatible with `make_pre_post_processors`." + ) + + if self._preprocessor_factory is None: + self._preprocessor_factory = _import_make_pre_post_processors() + + dataset_stats = _resolve_dataset_stats(dataset) + if dataset_stats is None: + logging.warning( + "LeRobot dataset statistics are unavailable; preprocessing will " + "be created without explicit dataset stats." 
+ ) + + kwargs: dict[str, Any] = {"dataset_stats": dataset_stats} + pretrained_path = self._resolve_model_pretrained_path() + if pretrained_path: + kwargs["pretrained_path"] = pretrained_path + + preprocessor, postprocessor = self._preprocessor_factory( + policy_config, **kwargs + ) + if not callable(preprocessor) or not callable(postprocessor): + raise TypeError( + "LeRobot `make_pre_post_processors` must return two callables." + ) + + self.context.state["lerobot_preprocessor"] = preprocessor + self.context.state["lerobot_postprocessor"] = postprocessor + self._processors_initialised = True + + def train_model(self, config, trainset, sampler, **kwargs): + self._ensure_pre_post_processors(trainset) + self.context.state["lerobot_precision"] = self._runtime_precision + self.context.device = torch.device(str(self.device)) + return super().train_model(config, trainset, sampler, **kwargs) + + def test_model(self, config, testset, sampler=None, **kwargs): + self._ensure_pre_post_processors(testset) + self.context.state["lerobot_precision"] = self._runtime_precision + self.context.device = torch.device(str(self.device)) + return super().test_model(config, testset, sampler, **kwargs) diff --git a/plato/trainers/registry.py b/plato/trainers/registry.py index 5f54403d6..77e7666e9 100644 --- a/plato/trainers/registry.py +++ b/plato/trainers/registry.py @@ -13,6 +13,9 @@ pfedgraph, split_learning, ) +from plato.trainers import ( + lerobot as lerobot_trainer, +) from plato.trainers import ( nanochat as nanochat_trainer, ) @@ -25,6 +28,7 @@ "pfedgraph": pfedgraph.Trainer, "split_learning": split_learning.Trainer, "nanochat": nanochat_trainer.Trainer, + "lerobot": lerobot_trainer.Trainer, } diff --git a/plato/utils/rl_env.py b/plato/utils/rl_env.py index b807a67a5..cef5e57ea 100644 --- a/plato/utils/rl_env.py +++ b/plato/utils/rl_env.py @@ -10,6 +10,7 @@ import asyncio import logging +from typing import Any import gymnasium as gym import numpy as np @@ -43,19 +44,27 @@ def 
"""Tests for FedAvg payload filtering and dtype-safe weight handling."""

from __future__ import annotations

from collections import OrderedDict
from types import SimpleNamespace
from typing import Any, cast

import pytest
import torch

from plato.algorithms.fedavg import Algorithm as FedAvgAlgorithm


class AdapterToyModel(torch.nn.Module):
    """Toy model exposing adapter-mode metadata used by SmolVLA integration."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = torch.nn.Linear(4, 4)
        self.adapter = torch.nn.Linear(4, 4, bias=False)
        self.register_buffer("token_count", torch.tensor([7], dtype=torch.int64))
        self.plato_finetune_mode = "adapter"
        self.plato_trainable_parameter_names = ("adapter.weight",)


class DtypeToyModel(torch.nn.Module):
    """Toy model with mixed dtypes for casting safeguards."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones((2, 2), dtype=torch.float32))
        self.register_buffer("step", torch.tensor([1], dtype=torch.int64))
        self.register_buffer("flag", torch.tensor([True, False], dtype=torch.bool))


class BFloat16ToyModel(torch.nn.Module):
    """Toy model for bf16 transport-cast regression coverage."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones((2, 2), dtype=torch.bfloat16))


def _make_algorithm(model: torch.nn.Module) -> FedAvgAlgorithm:
    """Wrap the model in a stub trainer and build a FedAvg algorithm around it."""
    stub_trainer = cast(Any, SimpleNamespace(model=model))
    return FedAvgAlgorithm(trainer=stub_trainer)


def _snapshot_state(model: torch.nn.Module) -> OrderedDict[str, torch.Tensor]:
    """Clone the state dict so later mutations can be compared against it."""
    return OrderedDict(
        (name, tensor.detach().clone())
        for name, tensor in model.state_dict().items()
    )


def test_adapter_payload_extract_load_round_trip():
    """Adapter mode should exchange only trainable tensors and load them safely."""
    torch.manual_seed(1)
    sender = AdapterToyModel()
    payload = _make_algorithm(sender).extract_weights()
    assert list(payload.keys()) == ["adapter.weight"]

    torch.manual_seed(2)
    receiver = AdapterToyModel()
    before = _snapshot_state(receiver)

    receiver_algorithm = _make_algorithm(receiver)
    receiver_algorithm.load_weights(payload)

    after = receiver.state_dict()
    # Only the adapter tensor may change; everything else stays untouched.
    assert torch.equal(after["adapter.weight"], payload["adapter.weight"])
    assert torch.equal(after["backbone.weight"], before["backbone.weight"])
    assert torch.equal(after["backbone.bias"], before["backbone.bias"])
    assert torch.equal(after["token_count"], before["token_count"])

    round_trip = receiver_algorithm.extract_weights()
    assert list(round_trip.keys()) == ["adapter.weight"]
    assert torch.equal(round_trip["adapter.weight"], payload["adapter.weight"])


def test_load_weights_casts_dtype_and_rounds_non_float_tensors():
    """Incoming partial payloads should be cast to model dtypes."""
    model = DtypeToyModel()
    algorithm = _make_algorithm(model)

    inbound = OrderedDict(
        {
            "weight": torch.full((2, 2), 2.5, dtype=torch.float64),
            "step": torch.tensor([3.6], dtype=torch.float32),
            "flag": torch.tensor([0.2, 0.8], dtype=torch.float32),
        }
    )

    algorithm.load_weights(inbound)

    state = model.state_dict()
    assert state["weight"].dtype == torch.float32
    assert torch.allclose(state["weight"], torch.full_like(state["weight"], 2.5))
    assert state["step"].dtype == torch.int64
    assert int(state["step"].item()) == 4
    assert state["flag"].dtype == torch.bool
    assert torch.equal(state["flag"], torch.tensor([False, True]))


def test_extract_weights_casts_bfloat16_payloads_for_transport():
    """bf16 tensors should be cast to fp32 for safe payload serialization."""
    model = BFloat16ToyModel()
    algorithm = _make_algorithm(model)

    payload = algorithm.extract_weights()
    assert payload["weight"].dtype == torch.float32

    algorithm.load_weights(
        OrderedDict({"weight": torch.full((2, 2), 3.5, dtype=torch.float32)})
    )

    state = model.state_dict()
    assert state["weight"].dtype == torch.bfloat16
    assert torch.allclose(
        state["weight"].float(),
        torch.full((2, 2), 3.5, dtype=torch.float32),
    )


def test_extract_weights_respects_optional_payload_size_limit():
    """Payload extraction should fail fast if a configured max size is exceeded."""
    model = torch.nn.Linear(32, 32, bias=False)
    setattr(model, "plato_max_payload_size_mb", 0.0001)
    algorithm = _make_algorithm(model)

    with pytest.raises(ValueError, match="payload size"):
        algorithm.extract_weights()


def test_fedavg_full_mode_round_trip_with_large_weights():
    """Full-state FedAvg flow should remain compatible for non-adapter models."""
    torch.manual_seed(3)
    model = torch.nn.Linear(1024, 1024, bias=False)
    algorithm = _make_algorithm(model)

    baseline = algorithm.extract_weights()
    assert list(baseline.keys()) == ["weight"]

    received = [
        OrderedDict((name, tensor + 0.25) for name, tensor in baseline.items())
    ]
    deltas = algorithm.compute_weight_deltas(baseline, received)
    updated = algorithm.update_weights(deltas[0])

    algorithm.load_weights(updated)

    state = model.state_dict()
    assert torch.allclose(state["weight"], received[0]["weight"])
FakeLeRobotDataset.reset_calls() + fake_transforms = object() + + monkeypatch.setattr( + lerobot_datasource, + "_import_lerobot", + lambda: (FakeLeRobotDataset, FakeLeRobotDatasetMetadata), + ) + monkeypatch.setattr( + lerobot_datasource, + "_build_image_transforms", + lambda _cfg: fake_transforms, + ) + + return SimpleNamespace(fake_transforms=fake_transforms) + + +def test_lerobot_is_registered_as_partitioned_datasource(): + """LeRobot datasource should be wired through the partitioned registry.""" + assert "LeRobot" in datasources_registry.registered_partitioned_datasources + + +def test_registry_constructs_lerobot_datasource_for_client( + temp_config, + patched_lerobot_backend, +): + """Registry should build LeRobot datasource and pass client-aware options.""" + datasource = datasources_registry.get( + datasource_name="LeRobot", + client_id=1, + repo_id="stub/lerobot", + split_seed=7, + train_split=0.5, + delta_timestamps={"observation.image": [-0.1, 0.0]}, + dataset_kwargs={"streaming": True}, + task_aware_split=False, + task_aware_partition=False, + ) + + assert isinstance(datasource, lerobot_datasource.DataSource) + assert datasource.client_id == 1 + assert datasource.repo_id == "stub/lerobot" + assert len(datasource.train_episodes) > 0 + + train_call = FakeLeRobotDataset.constructor_calls[0] + assert train_call["episodes"] == datasource.train_episodes + assert train_call["delta_timestamps"] == {"observation.image": [-0.1, 0.0]} + assert train_call["extra_kwargs"]["streaming"] is True + assert train_call["image_transforms"] is patched_lerobot_backend.fake_transforms + + +def test_lerobot_constructor_is_deterministic_and_maps_samples( + temp_config, + patched_lerobot_backend, +): + """Constructor should produce stable splits and mapped Plato sample keys.""" + first = lerobot_datasource.DataSource( + client_id=2, + repo_id="stub/lerobot", + split_seed=11, + train_split=0.5, + task_aware_split=True, + task_aware_partition=True, + ) + second = 
lerobot_datasource.DataSource( + client_id=2, + repo_id="stub/lerobot", + split_seed=11, + train_split=0.5, + task_aware_split=True, + task_aware_partition=True, + ) + + assert first.train_episodes == second.train_episodes + assert first.test_episodes == second.test_episodes + assert first.num_train_examples() == len(first.get_train_set()) + assert first.num_test_examples() == len(first.get_test_set()) + + sample = first.get_train_set()[0] + assert "plato_inputs" in sample + assert "plato_targets" in sample + assert "plato_metadata" in sample + assert "observation.image" in sample["plato_inputs"] + assert torch.equal(sample["plato_targets"], sample["action"]) + assert sample["plato_metadata"]["episode_index"] in first.train_episodes diff --git a/tests/integration/test_lerobot_smolvla_smoke.py b/tests/integration/test_lerobot_smolvla_smoke.py new file mode 100644 index 000000000..c51370f1c --- /dev/null +++ b/tests/integration/test_lerobot_smolvla_smoke.py @@ -0,0 +1,135 @@ +"""Integration smoke for LeRobot datasource + SmolVLA model + LeRobot trainer.""" + +from __future__ import annotations + +from importlib import import_module +from types import SimpleNamespace + +import pytest + +from plato.config import Config +from tests.integration.utils import ( + async_run, + build_minimal_config, + configure_environment, +) +from tests.test_utils.lerobot_stubs import ( + FakeLeRobotDataset, + FakeLeRobotDatasetMetadata, + FakeSmolVLAPolicy, + identity_pre_post_processors, +) + +pytestmark = pytest.mark.integration + + +def test_lerobot_smolvla_end_to_end_smoke(monkeypatch): + """Smoke test startup + one short local train + server report processing.""" + config = build_minimal_config( + rounds=1, + clients_per_round=1, + total_clients=1, + model_name="smolvla", + trainer_type="lerobot", + ) + config["server"]["do_test"] = False + config["data"] = { + "datasource": "LeRobot", + "partition_size": 2, + "sampler": "iid", + "random_seed": 1, + } + config["trainer"].update( + { 
+ "type": "lerobot", + "model_type": "smolvla", + "model_name": "smolvla", + "epochs": 1, + "batch_size": 2, + "optimizer": "SGD", + } + ) + config["parameters"] = { + "policy": { + "path": "stub/smolvla", + "finetune_mode": "adapter", + "adapter_parameter_patterns": ["adapter"], + }, + "dataset": { + "repo_id": "stub/lerobot", + "split_seed": 4, + "train_split": 0.5, + "task_aware_split": True, + "task_aware_partition": True, + }, + "optimizer": { + "lr": 0.05, + "momentum": 0.0, + "weight_decay": 0.0, + }, + } + + with configure_environment(config): + Config.args.id = 1 + FakeLeRobotDataset.reset_calls() + FakeSmolVLAPolicy.reset_calls() + + lerobot_datasource = import_module("plato.datasources.lerobot") + smolvla_model = import_module("plato.models.smolvla") + lerobot_trainer = import_module("plato.trainers.lerobot") + processor_registry = import_module("plato.processors.registry") + client_mod = import_module("plato.clients.simple") + server_mod = import_module("plato.servers.fedavg") + + monkeypatch.setattr( + lerobot_datasource, + "_import_lerobot", + lambda: (FakeLeRobotDataset, FakeLeRobotDatasetMetadata), + ) + monkeypatch.setattr( + lerobot_datasource, + "_build_image_transforms", + lambda _cfg: None, + ) + monkeypatch.setattr( + smolvla_model, + "_import_smolvla_policy", + lambda: FakeSmolVLAPolicy, + ) + monkeypatch.setattr( + lerobot_trainer, + "_import_make_pre_post_processors", + lambda: identity_pre_post_processors, + ) + monkeypatch.setattr( + processor_registry, + "get", + lambda *args, **kwargs: (None, None), + ) + + client = client_mod.Client() + client._load_data() + client.configure() + client._allocate_data() + + report, payload = async_run(client._train()) + report.processing_time = 0.0 + + assert report.num_samples > 0 + assert list(payload.keys()) == ["adapter.weight"] + + server = server_mod.Server() + server.configure() + server.updates = [ + SimpleNamespace( + client_id=1, + report=report, + payload=payload, + ) + ] + 
server.current_round = 0 + server.context.current_round = 0 + + async_run(server._process_reports()) + + assert server.accuracy >= 0 diff --git a/tests/models/test_smolvla_model.py b/tests/models/test_smolvla_model.py new file mode 100644 index 000000000..4c477fa6d --- /dev/null +++ b/tests/models/test_smolvla_model.py @@ -0,0 +1,64 @@ +"""Tests for SmolVLA model registry construction and adapter metadata.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast + +from plato.algorithms.fedavg import Algorithm as FedAvgAlgorithm +from plato.models import registry as models_registry +from plato.models import smolvla as smolvla_model +from tests.test_utils.lerobot_stubs import FakeSmolVLAPolicy + + +def test_model_registry_constructs_smolvla_wrapper(temp_config, monkeypatch): + """Model registry should construct SmolVLA via the registered factory.""" + FakeSmolVLAPolicy.reset_calls() + monkeypatch.setattr( + smolvla_model, + "_import_smolvla_policy", + lambda: FakeSmolVLAPolicy, + ) + + model = models_registry.get( + model_name="smolvla", + model_type="smolvla", + model_params={ + "path": "stub/smolvla", + "finetune_mode": "adapter", + "adapter_parameter_patterns": ["adapter"], + "strict": True, + }, + ) + + assert isinstance(model, FakeSmolVLAPolicy) + assert model.plato_policy_path == "stub/smolvla" + assert model.plato_finetune_mode == "adapter" + assert model.plato_trainable_parameter_names == ("adapter.weight",) + assert FakeSmolVLAPolicy.load_calls[-1]["path"] == "stub/smolvla" + + +def test_smolvla_adapter_metadata_filters_fedavg_payload( + temp_config, + monkeypatch, +): + """Regression: adapter mode metadata must drive adapter-only FedAvg payloads.""" + FakeSmolVLAPolicy.reset_calls() + monkeypatch.setattr( + smolvla_model, + "_import_smolvla_policy", + lambda: FakeSmolVLAPolicy, + ) + + model = smolvla_model.Model.get( + policy_path="stub/smolvla", + finetune_mode="adapter", + 
adapter_parameter_patterns=["adapter"], + ) + + trainer = cast(Any, SimpleNamespace(model=model)) + algorithm = FedAvgAlgorithm(trainer=trainer) + payload = algorithm.extract_weights() + + assert list(payload.keys()) == ["adapter.weight"] + assert "backbone.weight" not in payload diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 20b7514cc..41da1e1d1 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -212,6 +212,117 @@ def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch): Config._cli_overrides = {} + + +def test_config_loads_smolvla_lerobot_parameter_contract(tmp_path: Path, monkeypatch): + """SmolVLA/LeRobot config keys should round-trip through Config().""" + config_base = tmp_path / "runtime" + config_path = tmp_path / "smolvla_lerobot.toml" + + config_data = { + "clients": {"type": "simple", "total_clients": 2, "per_round": 1}, + "server": {"address": "127.0.0.1", "port": 8000}, + "data": {"datasource": "LeRobot"}, + "trainer": { + "type": "lerobot", + "rounds": 1, + "epochs": 1, + "batch_size": 2, + "model_type": "smolvla", + "model_name": "smolvla", + }, + "algorithm": {"type": "fedavg"}, + "parameters": { + "policy": { + "type": "smolvla", + "path": "lerobot/smolvla_base", + "finetune_mode": "adapter", + "precision": "bf16", + "device": "cuda", + }, + "dataset": { + "repo_id": "lerobot/pusht_image", + "delta_timestamps": { + "observation_image": [-0.2, -0.1, 0.0], + }, + }, + "transforms": { + "image_size": [224, 224], + "normalize": True, + "interpolation": "bilinear", + }, + }, + } + + toml_writer.dump(config_data, config_path) + + monkeypatch.delenv("config_file", raising=False) + monkeypatch.setattr( + sys, + "argv", + [ + sys.argv[0], + "--config", + str(config_path), + "--base", + str(config_base), + ], + ) + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} + + config = Config() + + assert config.data.datasource == 
"LeRobot" + assert config.trainer.type == "lerobot" + assert config.trainer.model_type == "smolvla" + assert config.trainer.model_name == "smolvla" + + assert config.parameters.policy.type == "smolvla" + assert config.parameters.policy.path == "lerobot/smolvla_base" + assert config.parameters.policy.finetune_mode == "adapter" + assert config.parameters.policy.precision == "bf16" + assert config.parameters.policy.device == "cuda" + + assert config.parameters.dataset.repo_id == "lerobot/pusht_image" + assert config.parameters.dataset.delta_timestamps.observation_image == [ + -0.2, + -0.1, + 0.0, + ] + + assert config.parameters.transforms.image_size == [224, 224] + assert config.parameters.transforms.normalize is True + assert config.parameters.transforms.interpolation == "bilinear" + + assert config.parameters.policy._asdict() == { + "type": "smolvla", + "path": "lerobot/smolvla_base", + "finetune_mode": "adapter", + "precision": "bf16", + "device": "cuda", + } + assert config.parameters.dataset._asdict() == { + "repo_id": "lerobot/pusht_image", + "delta_timestamps": { + "observation_image": [-0.2, -0.1, 0.0], + }, + } + assert config.parameters.transforms._asdict() == { + "image_size": [224, 224], + "normalize": True, + "interpolation": "bilinear", + } + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} + + def test_is_central_server_requires_cross_silo_true(tmp_path: Path, monkeypatch): """Central-server detection should respect `cross_silo = false`.""" config_path = tmp_path / "config.toml" @@ -222,6 +333,7 @@ def test_is_central_server_requires_cross_silo_true(tmp_path: Path, monkeypatch) "trainer": {"type": "basic", "rounds": 1}, "algorithm": {"type": "fedavg", "cross_silo": False}, } + toml_writer.dump(config_data, config_path) monkeypatch.delenv("config_file", raising=False) @@ -240,6 +352,16 @@ def test_is_central_server_requires_cross_silo_true(tmp_path: Path, monkeypatch) delattr(Config, "args") 
Config._cli_overrides = {} + config = Config() + + assert Config.is_central_server() is False + assert getattr(config.algorithm, "cross_silo", False) is False + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} + def test_toml_loader_allows_shared_includes(tmp_path: Path): """Shared include files should not be treated as circular includes.""" diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 000000000..a20f1a914 --- /dev/null +++ b/tests/test_utils/__init__.py @@ -0,0 +1 @@ +"""Helpers and fakes shared across the test suite.""" diff --git a/tests/test_utils/lerobot_stubs.py b/tests/test_utils/lerobot_stubs.py new file mode 100644 index 000000000..2272e4b77 --- /dev/null +++ b/tests/test_utils/lerobot_stubs.py @@ -0,0 +1,170 @@ +"""Deterministic LeRobot/SmolVLA stubs for offline integration tests.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import torch +import torch.nn as nn + + +class FakeLeRobotDatasetMetadata: + """Minimal metadata surface consumed by the LeRobot datasource adapter.""" + + def __init__(self, repo_id: str): + self.repo_id = repo_id + self.total_episodes = 6 + self.tasks = ["pick", "place"] + self.episodes = [ + { + "episode_index": episode, + "task_index": episode % 2, + "task": self.tasks[episode % 2], + } + for episode in range(self.total_episodes) + ] + + +class FakeLeRobotDataset: + """Small dict-style dataset that mimics LeRobot sample payloads.""" + + constructor_calls: list[dict[str, Any]] = [] + + def __init__( + self, + repo_id: str, + episodes: list[int] | None = None, + delta_timestamps: dict[str, list[float]] | None = None, + image_transforms: Any = None, + **kwargs: Any, + ): + self.repo_id = repo_id + self.episodes = [int(episode) for episode in (episodes or [])] + self.delta_timestamps = delta_timestamps + self.image_transforms = image_transforms + 
self.extra_kwargs = dict(kwargs) + self.meta = SimpleNamespace( + stats={ + "action": { + "mean": [0.0, 0.0], + "std": [1.0, 1.0], + } + } + ) + + self.samples: list[dict[str, Any]] = [] + for step, episode in enumerate(self.episodes): + observation = torch.tensor( + [float(episode), float(step)], + dtype=torch.float32, + ) + action = torch.tensor( + [float(episode) + 0.25, float(step) + 0.5], + dtype=torch.float32, + ) + if callable(image_transforms): + observation = image_transforms(observation) + + self.samples.append( + { + "observation.image": observation, + "action": action, + "episode_index": episode, + "step_index": step, + "task": "pick" if episode % 2 == 0 else "place", + } + ) + + self.targets = [sample["task"] for sample in self.samples] + self.classes = ("pick", "place") + + type(self).constructor_calls.append( + { + "repo_id": repo_id, + "episodes": list(self.episodes), + "delta_timestamps": delta_timestamps, + "image_transforms": image_transforms, + "extra_kwargs": dict(kwargs), + } + ) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> dict[str, Any]: + return self.samples[index] + + @classmethod + def reset_calls(cls) -> None: + """Clear constructor call history for test isolation.""" + cls.constructor_calls = [] + + +class FakeSmolVLAPolicy(nn.Module): + """Tiny policy compatible with SmolVLA wrapper expectations.""" + + load_calls: list[dict[str, Any]] = [] + + def __init__(self): + super().__init__() + self.backbone = nn.Linear(2, 2, bias=False) + self.adapter = nn.Linear(2, 2, bias=False) + self.config = {"policy": "fake-smolvla"} + + with torch.no_grad(): + self.backbone.weight.copy_(torch.eye(2)) + self.adapter.weight.copy_(0.5 * torch.eye(2)) + + @classmethod + def from_pretrained( + cls, + policy_path: str, + token: str | None = None, + strict: bool = True, + ) -> "FakeSmolVLAPolicy": + """Match the LeRobot loader interface used by the wrapper.""" + cls.load_calls.append( + { + "path": 
policy_path, + "token": token, + "strict": strict, + } + ) + return cls() + + def forward(self, batch: dict[str, Any], reduction: str = "mean"): + """Return tensor loss + dict payload like SmolVLA policies do.""" + action = batch["action"].float() + prediction = self.adapter(self.backbone(action)) + per_sample = (prediction - action) ** 2 + + if reduction == "sum": + loss = per_sample.sum() + else: + loss = per_sample.mean() + + return loss, {"mse": loss.detach()} + + def save_pretrained(self, *args: Any, **kwargs: Any) -> None: + """Compatibility no-op for checkpoint contract checks.""" + + @classmethod + def reset_calls(cls) -> None: + """Clear loader call history for test isolation.""" + cls.load_calls = [] + + +def identity_pre_post_processors(policy_config: Any, **kwargs: Any): + """Return identity processors while recording constructor args.""" + + def preprocessor(batch): + return batch + + def postprocessor(outputs): + return outputs + + setattr(preprocessor, "policy_config", policy_config) + setattr(preprocessor, "kwargs", kwargs) + + return preprocessor, postprocessor diff --git a/tests/trainers/test_lerobot_trainer.py b/tests/trainers/test_lerobot_trainer.py new file mode 100644 index 000000000..1d1316aff --- /dev/null +++ b/tests/trainers/test_lerobot_trainer.py @@ -0,0 +1,211 @@ +"""Tests for LeRobot trainer training-step behavior with synthetic data.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +import torch +import torch.nn as nn + +from plato.config import Config +from plato.trainers import lerobot as lerobot_trainer + + +class _SyntheticLeRobotDataset(torch.utils.data.Dataset): + """Tiny deterministic dataset exposing LeRobot-like dict samples.""" + + def __init__(self): + self.meta = SimpleNamespace(stats={"action": {"mean": [0.0], "std": [1.0]}}) + self.samples = [ + { + "observation.image": torch.tensor([0.0, 1.0], dtype=torch.float32), + "action": torch.tensor([0.0, 2.0], 
dtype=torch.float32), + "episode_index": 0, + }, + { + "observation.image": torch.tensor([1.0, 2.0], dtype=torch.float32), + "action": torch.tensor([1.0, 3.0], dtype=torch.float32), + "episode_index": 0, + }, + { + "observation.image": torch.tensor([2.0, 3.0], dtype=torch.float32), + "action": torch.tensor([2.0, 4.0], dtype=torch.float32), + "episode_index": 1, + }, + { + "observation.image": torch.tensor([3.0, 4.0], dtype=torch.float32), + "action": torch.tensor([3.0, 5.0], dtype=torch.float32), + "episode_index": 1, + }, + ] + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> dict[str, Any]: + return self.samples[index] + + +class _TinyLeRobotPolicy(nn.Module): + """Tiny policy returning tuple output variants consumed by LeRobot trainer.""" + + def __init__(self): + super().__init__() + self.adapter_scale = nn.Parameter(torch.tensor(1.0)) + self.config = {"policy": "tiny"} + self.plato_policy_path = "stub/smolvla" + + def forward(self, batch: dict[str, Any], reduction: str = "mean"): + action = batch["action"].float() + prediction = self.adapter_scale * torch.ones_like(action) + loss = torch.mean((prediction - action) ** 2) + # Return mapping + tensor tuple to exercise normalization branch. 
+ return {"loss_component": loss.detach()}, loss + + +def test_lerobot_trainer_train_model_runs_on_tiny_synthetic_batch( + temp_config, + monkeypatch, +): + """Trainer should complete one short run and update model parameters.""" + config = Config() + config.trainer = config.trainer._replace( + type="lerobot", + model_type="smolvla", + model_name="smolvla_unit", + batch_size=2, + epochs=1, + optimizer="SGD", + ) + config.parameters = Config.node_from_dict( + { + "optimizer": { + "lr": 0.1, + "momentum": 0.0, + "weight_decay": 0.0, + }, + "policy": {"path": "stub/smolvla"}, + } + ) + + factory_calls: dict[str, Any] = {} + + def _fake_pre_post_factory(policy_config, **kwargs): + factory_calls["policy_config"] = policy_config + factory_calls["kwargs"] = kwargs + + def _pre(batch): + return batch + + def _post(outputs): + return outputs + + return _pre, _post + + monkeypatch.setattr( + lerobot_trainer, + "_import_make_pre_post_processors", + lambda: _fake_pre_post_factory, + ) + + model = _TinyLeRobotPolicy() + trainer = lerobot_trainer.Trainer(model=model) + trainset = _SyntheticLeRobotDataset() + + start_value = float(model.adapter_scale.detach().item()) + trainer.train_model( + { + "batch_size": 2, + "epochs": 1, + "run_id": "lerobot-unit", + }, + trainset, + sampler=list(range(len(trainset))), + ) + end_value = float(model.adapter_scale.detach().item()) + + assert end_value != start_value + assert callable(trainer.context.state["lerobot_preprocessor"]) + assert trainer.context.state["lerobot_loss_dict"]["loss_component"] >= 0.0 + assert trainer.run_history.get_metric_values("train_loss") + assert factory_calls["kwargs"]["pretrained_path"] == "stub/smolvla" + assert factory_calls["kwargs"]["dataset_stats"] == trainset.meta.stats + + +def test_lerobot_trainer_consumes_policy_precision_and_device( + temp_config, + monkeypatch, +): + """Trainer should apply policy precision/device runtime settings.""" + config = Config() + config.trainer = config.trainer._replace( + 
type="lerobot", + model_type="smolvla", + model_name="smolvla_unit", + batch_size=2, + epochs=1, + optimizer="SGD", + ) + config.parameters = Config.node_from_dict( + { + "optimizer": { + "lr": 0.05, + "momentum": 0.0, + "weight_decay": 0.0, + }, + "policy": { + "path": "stub/smolvla", + "precision": "bf16", + "device": "cpu", + }, + } + ) + + monkeypatch.setattr( + lerobot_trainer, + "_import_make_pre_post_processors", + lambda: (lambda *_args, **_kwargs: (lambda batch: batch, lambda out: out)), + ) + + trainer = lerobot_trainer.Trainer(model=_TinyLeRobotPolicy()) + trainset = _SyntheticLeRobotDataset() + + assert trainer.device == "cpu" + assert trainer.context.device == torch.device("cpu") + assert trainer.context.state["lerobot_precision"] == "bf16" + + trainer.train_model( + {"batch_size": 2, "epochs": 1, "run_id": "lerobot-precision"}, + trainset, + sampler=list(range(len(trainset))), + ) + assert trainer.context.state["lerobot_precision"] == "bf16" + assert isinstance(trainer.context.state["lerobot_autocast_enabled"], bool) + + +def test_lerobot_trainer_rejects_unavailable_cuda_device( + temp_config, + monkeypatch, +): + """Policy device should be validated against runtime accelerator availability.""" + config = Config() + config.trainer = config.trainer._replace( + type="lerobot", + model_type="smolvla", + model_name="smolvla_unit", + ) + config.parameters = Config.node_from_dict( + { + "policy": { + "path": "stub/smolvla", + "device": "cuda", + }, + } + ) + + monkeypatch.setattr(lerobot_trainer.torch.cuda, "is_available", lambda: False) + + with pytest.raises(RuntimeError, match="CUDA is not available"): + lerobot_trainer.Trainer(model=_TinyLeRobotPolicy())