From 9ed57fb613e5a336debe44c8ac8c33795db743c7 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 11:53:24 -0500 Subject: [PATCH 01/15] Defined the SmolVLA and LeRobot integration contract for T1. --- plans/smolvla-lerobot-integration-contract.md | 151 ++++++++++++++++++ .../smolvla-lerobot-plato-integration-plan.md | 140 ++++++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 plans/smolvla-lerobot-integration-contract.md create mode 100644 plans/smolvla-lerobot-plato-integration-plan.md diff --git a/plans/smolvla-lerobot-integration-contract.md b/plans/smolvla-lerobot-integration-contract.md new file mode 100644 index 000000000..8fb329f43 --- /dev/null +++ b/plans/smolvla-lerobot-integration-contract.md @@ -0,0 +1,151 @@ +# SmolVLA + LeRobot Integration Contract (Release v1) + +Date: 2026-02-19 +Plan Task: T1 (`depends_on: []`) +Status: accepted baseline for implementation tasks T2+ + +## 1. Objective + +Define a concrete, testable contract for first-release SmolVLA + LeRobot support in Plato without changing Plato's core federated runtime model. + +## 2. In-Scope (Release v1) + +1. SmolVLA fine-tuning runs inside Plato's existing federated lifecycle. +2. LeRobot datasets are ingested via Plato datasource APIs. +3. Existing client/server/algorithm loops remain the orchestration path. +4. End users run experiments through TOML configuration only (no source edits). + +## 3. Integration Surface (Concrete Components) + +The implementation must integrate through these existing extension points. + +Runtime entry and lifecycle: +1. `plato.py` +2. `plato/client.py` +3. `plato/clients/registry.py` +4. `plato/clients/base.py` +5. `plato/servers/fedavg.py` +6. `plato/servers/registry.py` +7. `plato/algorithms/registry.py` + +Config loading and propagation: +1. `plato/config.py` + +Datasource extension points: +1. `plato/datasources/base.py` +2. `plato/datasources/registry.py` +3. 
New module target: `plato/datasources/lerobot.py` + +Model extension points: +1. `plato/models/registry.py` +2. New module target: `plato/models/smolvla.py` + +Trainer extension points: +1. `plato/trainers/base.py` +2. `plato/trainers/composable.py` +3. `plato/trainers/registry.py` +4. New module target: `plato/trainers/lerobot.py` + +Compatibility rule: v1 must work with existing `fedavg` server/algorithm paths. Any special handling must be encapsulated inside the new datasource/model/trainer modules and their registry wiring. + +## 4. Configuration Contract (v1) + +The following fields are the required contract for SmolVLA + LeRobot configs. T3 is responsible for schema wiring/validation. + +```toml +[data] +datasource = "LeRobot" +# existing partitioning keys stay valid (sampler, partition_size, random_seed) + +[trainer] +type = "lerobot" +model_type = "smolvla" +model_name = "smolvla" + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "full" # "full" or "adapter" +precision = "bf16" # expected values: fp32/fp16/bf16 +device = "cuda" # expected values: cpu/cuda/mps + +[parameters.dataset] +repo_id = "/" +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] +image_size = [224, 224] +normalize = true +``` + +Required semantics: + +1. `parameters.policy.path` resolves pretrained policy source. +2. `parameters.policy.type` selects policy family (`smolvla` for v1). +3. `parameters.dataset.repo_id` selects LeRobot dataset source. +4. `parameters.dataset.delta_timestamps` controls temporal windowing. +5. `parameters.transforms.*` controls image preprocessing. +6. `parameters.policy.finetune_mode` controls full vs adapter updates. +7. `parameters.policy.precision` and `parameters.policy.device` govern runtime dtype/device behavior. + +## 5. Scope Boundaries and Non-Goals + +Explicitly out of scope for release v1: + +1. New federated algorithms or server types beyond existing registry options. +2. 
Live robot inference/control loops and async teleoperation workflows. +3. Non-LeRobot robotics dataset backends. +4. End-to-end convergence/benchmark claims beyond smoke/stability checks. +5. Automated dependency bootstrap for platform-specific robotics stacks. + +## 6. Acceptance Checks (Concrete and Testable) + +These checks define go/no-go for the integration scope. + +### AC1: Single-client local training run + +- Config target: `configs/LeRobot/smolvla_single_client_smoke.toml`. +- Command: + +```bash +uv run python plato.py --config configs/LeRobot/smolvla_single_client_smoke.toml +``` + +Pass criteria: +1. Process exits with code `0`. +2. Trainer completes at least one local epoch in one communication round. +3. A model artifact is written under configured `model_path`. + +### AC2: Multi-client federated round-trip + +- Config target: `configs/LeRobot/smolvla_fedavg_two_client_smoke.toml`. +- Command: + +```bash +uv run python plato.py --config configs/LeRobot/smolvla_fedavg_two_client_smoke.toml +``` + +Pass criteria: +1. Server starts and selects two clients in the same round. +2. Server receives two client updates and completes aggregation. +3. Round counter advances to at least round 1 completion without runtime exceptions. + +### AC3: Config-first workflow (no source edits) + +Validation procedure: +1. Run `AC1` and `AC2` using committed TOML files only. +2. Confirm no local source modifications are required between runs. + +Pass criteria: +1. Both runs succeed from clean checkout with only config selection changed. + +## 7. Deliverables Expected From Downstream Tasks + +1. Config files under `configs/LeRobot/` implementing AC1/AC2 targets. +2. Registry wiring and implementation modules listed in Section 3. +3. Smoke/integration tests that encode AC1/AC2/AC3 behavior. + +## 8. Notes + +- This contract intentionally locks only v1 integration behavior and acceptance gates. 
+- Performance tuning and broader robotics feature surface are deferred to post-v1 tasks. diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md new file mode 100644 index 000000000..d167294ae --- /dev/null +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -0,0 +1,140 @@ +# SmolVLA + LeRobot Integration Plan for Plato + +Date: 2026-02-19 +Scope: Add support for training Hugging Face SmolVLA with LeRobot datasets/framework inside Plato. + +## Dependency Graph + +```text +T1 -> T2, T3 +T2 -> T4, T5 +T3 -> T4, T5, T9 +T4, T5 -> T6 +T6 -> T7, T8 +T4, T5, T6 -> T9 +T7, T8, T9 -> T10 +T10 -> T11 +T11 -> T12 +``` + +## Tasks + +### T1. Define integration contract and acceptance criteria +depends_on: [] +status: completed (2026-02-19) +- Lock exact scope for first release: +- Support SmolVLA fine-tuning in Plato’s existing FL lifecycle. +- Support LeRobot dataset ingestion through Plato datasource APIs. +- Ensure compatibility with existing client/server/algorithm loops. +- Define acceptance checks: +- Single-client local training run works. +- Multi-client federated round-trip works. +- Config-first workflow (no code edits required to run experiment). +work_log: +- Added `plans/smolvla-lerobot-integration-contract.md` to lock v1 scope, concrete integration touchpoints, config contract, and testable acceptance checks. +- Made explicit scope boundaries and non-goals for v1. +files_touched: +- `plans/smolvla-lerobot-integration-contract.md` (created) +- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) +gotchas: +- Context7 returned `LeRobot` as the primary documented surface; SmolVLA details were discovered through the LeRobot documentation set. + +### T2. Add dependencies and environment gating +depends_on: [T1] +- Update `pyproject.toml` with required LeRobot and training stack dependencies. +- Regenerate `uv.lock`. 
+- Add guarded imports so environments without robotics extras still run existing Plato workloads. +- Document required system/runtime notes for optional robotics path. + +### T3. Extend Plato configuration schema for SmolVLA/LeRobot +depends_on: [T1] +- Add/validate config keys needed for SmolVLA + LeRobot: +- `policy.path` / `policy.type` +- `dataset.repo_id` +- `delta_timestamps` +- image transform controls +- precision/device flags +- full-finetune vs adapter mode switch +- Ensure keys flow through `Config()` and into trainer/model/datasource constructors. + +### T4. Implement LeRobot datasource adapter +depends_on: [T2, T3] +- Add `plato/datasources/lerobot.py`. +- Implement dataset loading via LeRobot APIs and map samples into Plato’s expected batch format. +- Register datasource in `plato/datasources/registry.py`. +- Add deterministic client partitioning strategy (episode/task aware split). +- Provide train/test dataset access methods compatible with existing samplers. + +### T5. Implement SmolVLA model/policy wrapper +depends_on: [T2, T3] +- Add `plato/models/smolvla.py`. +- Implement pretrained loading path (`smolvla_base` and custom repo id/path). +- Expose trainable-parameter policy (full model or adapter path). +- Register model in `plato/models/registry.py`. +- Ensure state dict save/load compatibility with Plato aggregation pipeline. + +### T6. Implement LeRobot trainer backend +depends_on: [T4, T5] +- Add `plato/trainers/lerobot.py` (ComposableTrainer-compatible). +- Implement multimodal collate + preprocessing for LeRobot samples. +- Wire forward/loss/backward/optimizer/scheduler flow for SmolVLA policy. +- Implement evaluation hooks suitable for regression checks. +- Register trainer in `plato/trainers/registry.py`. + +### T7. Harden federated payload/aggregation behavior +depends_on: [T6] +- Ensure only intended trainable tensors are exchanged/aggregated. +- Add safeguards for payload size and dtype handling. 
+- Verify checkpoint/state restore consistency across rounds. +- Validate no regressions in FedAvg flow with large model weights. + +### T8. Validate runtime lifecycle compatibility +depends_on: [T6] +- Confirm integration works with existing lifecycle code paths: +- client setup strategies +- server trainer initialization +- training/report/aggregation loop +- Avoid special-case branching unless strictly necessary. + +### T9. Add runnable experiment configs +depends_on: [T3, T4, T5, T6] +- Add `configs/LeRobot/` config set: +- reusable base datasource fragment +- minimal smoke config +- full fine-tune config aligned to SmolVLA guidance +- Ensure includes/overrides follow repository config conventions. + +### T10. Add tests (unit + integration smoke) +depends_on: [T7, T8, T9] +- Datasource registry + constructor tests for LeRobot datasource. +- Model registry + construction tests for SmolVLA wrapper. +- Trainer step test with tiny synthetic batch. +- End-to-end config smoke test covering startup and one short training run. +- Add regression tests for any bug fixes discovered during integration. + +### T11. Add documentation and runbook +depends_on: [T10] +- Document setup and dependency extras. +- Document config fields and examples. +- Add troubleshooting notes (dataset access, device setup, common failures). +- Add mapping between Plato config and equivalent `lerobot-train` concepts. + +### T12. Stage validation and rollout gate +depends_on: [T11] +- Execute staged validation: +- single-client local run +- 2-client federated smoke run +- larger run for convergence and stability check +- Compare behavior/runtime against expected baseline. +- Define go/no-go criteria and recommended defaults for first public release. + +## Milestones + +- Milestone A (Core plumbing): T1-T6 complete. +- Milestone B (Federated reliability): T7-T8 complete. +- Milestone C (Usability + confidence): T9-T12 complete. 
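The milestone gating above follows directly from the task graph, and the graph can be sanity-checked mechanically. A minimal sketch using Python's standard-library `graphlib`, with the edges transcribed from the Dependency Graph section of this plan (the check itself is illustrative, not part of the deliverables):

```python
# Sketch: confirm the T1-T12 dependency graph is acyclic and that the
# milestone ordering (T1 first, T12 last) falls out of a topological sort.
from graphlib import TopologicalSorter

# Each task maps to the tasks it depends_on, as listed in this plan.
depends_on = {
    "T1": [],
    "T2": ["T1"],
    "T3": ["T1"],
    "T4": ["T2", "T3"],
    "T5": ["T2", "T3"],
    "T6": ["T4", "T5"],
    "T7": ["T6"],
    "T8": ["T6"],
    "T9": ["T3", "T4", "T5", "T6"],
    "T10": ["T7", "T8", "T9"],
    "T11": ["T10"],
    "T12": ["T11"],
}

# static_order() raises graphlib.CycleError if the graph has a cycle.
order = list(TopologicalSorter(depends_on).static_order())
print(order[0], order[-1])  # T1 is the sole source task, T12 the sole sink
```

Because T1 is the only task with no dependencies and every other task is a transitive ancestor of T12, any valid ordering starts at T1 and ends at T12.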
+ +## Notes + +- Existing codebase discovery found no native SmolVLA/LeRobot implementation yet. +- Primary extension anchors are current registries and Hugging Face integration patterns. From 0e95554ed96344950f5963a57e5f8bc717746ab6 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 11:59:39 -0500 Subject: [PATCH 02/15] Added optional SmolVLA and LeRobot dependency gating. --- docs/docs/install.md | 12 +++++ docs/docs/smolvla_lerobot_setup.md | 46 +++++++++++++++++++ .../smolvla-lerobot-plato-integration-plan.md | 15 ++++++ pyproject.toml | 3 ++ 4 files changed, 76 insertions(+) create mode 100644 docs/docs/smolvla_lerobot_setup.md diff --git a/docs/docs/install.md b/docs/docs/install.md index 342426e35..0b12aa823 100644 --- a/docs/docs/install.md +++ b/docs/docs/install.md @@ -65,6 +65,18 @@ uv sync --extra mlx See the [Quick Start guide](quickstart.md#using-mlx-as-a-backend) for configuration details. +### Optional: SmolVLA + LeRobot Robotics Stack + +LeRobot and SmolVLA dependencies are available behind Plato's optional +`robotics` extra so default federated-learning installs remain lightweight: + +```bash +uv sync --extra robotics +``` + +See [SmolVLA + LeRobot Optional Setup](smolvla_lerobot_setup.md) for runtime +requirements and guarded-import guidance. + ### Building the `plato-learn` PyPi Package The `plato-learn` PyPi package will be automatically built and published by a GitHub action workflow every time a release is created on GitHub. To build the package manually, follow these steps: diff --git a/docs/docs/smolvla_lerobot_setup.md b/docs/docs/smolvla_lerobot_setup.md new file mode 100644 index 000000000..e36be2a22 --- /dev/null +++ b/docs/docs/smolvla_lerobot_setup.md @@ -0,0 +1,46 @@ +# SmolVLA + LeRobot Optional Setup + +This setup path is optional. Core Plato federated workloads continue to use the +default dependency set from `uv sync`. 
+ +## Install the robotics extra + +From the repository root: + +```bash +uv sync --extra robotics +``` + +This installs `lerobot[smolvla]` and the associated training stack only when the +`robotics` extra is requested. + +## Environment gating + +When adding LeRobot-backed modules, keep imports guarded so non-robotics +environments fail with a clear action instead of a hard crash at import time. + +```python +try: + import lerobot +except ImportError as exc: + raise ImportError( + "LeRobot support is optional. Install with: uv sync --extra robotics" + ) from exc +``` + +## Runtime notes for SmolVLA/LeRobot + +- CUDA-capable GPUs are recommended for practical SmolVLA fine-tuning; CPU is + mainly suitable for smoke checks. +- Install `ffmpeg` on hosts that read video-backed LeRobot datasets. +- Authenticate with Hugging Face (`huggingface-cli login`) when accessing + private dataset repositories. +- LeRobot currently constrains the Torch stack used by this optional path; + if you need different Torch constraints for non-robotics research, keep a + separate virtual environment. + +## Quick verification + +```bash +uv run python -c "import lerobot; print(lerobot.__version__)" +``` diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index d167294ae..b23674647 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -41,10 +41,25 @@ gotchas: ### T2. Add dependencies and environment gating depends_on: [T1] +status: completed (2026-02-19) - Update `pyproject.toml` with required LeRobot and training stack dependencies. - Regenerate `uv.lock`. - Add guarded imports so environments without robotics extras still run existing Plato workloads. - Document required system/runtime notes for optional robotics path. +work_log: +- Added a new optional extra (`robotics`) in `pyproject.toml` with `lerobot[smolvla]>=0.4.3,<0.5.0` so default installs remain unchanged. 
+- Regenerated `uv.lock` with `uv lock`, then validated both paths: +- `uv sync --frozen` + import check for core Plato. +- `uv sync --frozen --extra robotics` + `import lerobot` check for the optional robotics stack. +- Added focused setup docs for SmolVLA/LeRobot and linked them from `docs/docs/install.md`. +files_touched: +- `pyproject.toml` +- `uv.lock` +- `docs/docs/install.md` +- `docs/docs/smolvla_lerobot_setup.md` (created) +- `plans/smolvla-lerobot-plato-integration-plan.md` +gotchas: +- The optional LeRobot path constrains parts of the Torch stack in lock resolution; keeping it under `--extra robotics` avoids forcing robotics dependencies into default `uv sync` environments. ### T3. Extend Plato configuration schema for SmolVLA/LeRobot depends_on: [T1] diff --git a/pyproject.toml b/pyproject.toml index c61a4c059..2ca72df43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,9 @@ nanochat = [ "jinja2", "PyYAML", ] +robotics = [ + "lerobot[smolvla]>=0.4.3,<0.5.0", +] [project.urls] Homepage = "https://github.com/TL-System/plato" From ce59d320044aa8c62924232cf42e13b7feeb5de0 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 12:00:26 -0500 Subject: [PATCH 03/15] Completed T3 SmolVLA and LeRobot configuration wiring. --- docs/docs/configurations/parameters.md | 62 ++++++++++ .../smolvla-lerobot-plato-integration-plan.md | 11 ++ tests/test_config_loader.py | 110 ++++++++++++++++++ 3 files changed, 183 insertions(+) diff --git a/docs/docs/configurations/parameters.md b/docs/docs/configurations/parameters.md index 68cd6eddf..8a0119def 100644 --- a/docs/docs/configurations/parameters.md +++ b/docs/docs/configurations/parameters.md @@ -18,3 +18,65 @@ !!! example "loss_criterion" All the parameter settings that need to be passed as keyword parameters when initializing the loss criterion, such as `size_average`. The set of parameters permitted or needed depends on the loss criterion. 
+ +## SmolVLA + LeRobot parameter contract + +`Config()` keeps nested keys under `[parameters]` as dot-accessible nodes. For +the SmolVLA + LeRobot integration, define the following sections. + +| Config key | Purpose | Consumption path | +| --- | --- | --- | +| `data.datasource = "LeRobot"` | Selects the robotics datasource family. | `plato.datasources.registry.get()` chooses the datasource module. | +| `trainer.type = "lerobot"` | Selects the robotics trainer backend. | `plato.trainers.registry.get()` chooses the trainer class. | +| `trainer.model_type = "smolvla"` | Selects the model family. | `plato.models.registry.get()` resolves the model factory. | +| `trainer.model_name = "smolvla"` | Selects the concrete model entry point. | `plato.models.registry.get()` resolves the model name. | +| `parameters.policy.type` | Policy family identifier (`smolvla` in v1). | Consumed by `plato/models/smolvla.py` and `plato/trainers/lerobot.py` via `Config().parameters.policy`. | +| `parameters.policy.path` | Pretrained policy source (Hub/local path). | Consumed by `plato/models/smolvla.py` via `Config().parameters.policy.path`. | +| `parameters.policy.finetune_mode` | Full fine-tune vs adapter mode switch. | Consumed by `plato/trainers/lerobot.py` to decide trainable params. | +| `parameters.policy.precision` | Runtime precision flag (`fp32`/`fp16`/`bf16`). | Consumed by `plato/trainers/lerobot.py` for dtype/autocast setup. | +| `parameters.policy.device` | Runtime device flag (`cpu`/`cuda`/`mps`). | Consumed by `plato/trainers/lerobot.py` for device placement. | +| `parameters.dataset.repo_id` | LeRobot dataset identifier. | Consumed by `plato/datasources/lerobot.py` dataset loader. | +| `parameters.dataset.delta_timestamps` | Temporal window selection per modality key. | Consumed by `plato/datasources/lerobot.py` sampling/windowing logic. | +| `parameters.transforms.*` | Image transform controls (`image_size`, `normalize`, interpolation/crop options). 
| Consumed by `plato/datasources/lerobot.py` preprocessing pipeline. | + +### Constructor-ready dictionaries + +SmolVLA/LeRobot components should convert section nodes to plain dictionaries +when passing keyword arguments into constructors: + +```python +from plato.config import Config + +cfg = Config() +policy_kwargs = cfg.parameters.policy._asdict() +dataset_kwargs = cfg.parameters.dataset._asdict() +transform_kwargs = cfg.parameters.transforms._asdict() +``` + +### Example + +```toml +[data] +datasource = "LeRobot" + +[trainer] +type = "lerobot" +model_type = "smolvla" +model_name = "smolvla" + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "full" +precision = "bf16" +device = "cuda" + +[parameters.dataset] +repo_id = "lerobot/pusht_image" +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] +image_size = [224, 224] +normalize = true +interpolation = "bilinear" +``` diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index b23674647..870e99cfb 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -63,6 +63,7 @@ gotchas: ### T3. Extend Plato configuration schema for SmolVLA/LeRobot depends_on: [T1] +status: completed (2026-02-19) - Add/validate config keys needed for SmolVLA + LeRobot: - `policy.path` / `policy.type` - `dataset.repo_id` @@ -71,6 +72,16 @@ depends_on: [T1] - precision/device flags - full-finetune vs adapter mode switch - Ensure keys flow through `Config()` and into trainer/model/datasource constructors. +work_log: +- Verified that `plato/config.py` already preserves nested TOML keys under `Config().parameters` without schema whitelisting, so SmolVLA/LeRobot keys are backward-compatible pass-through. 
+- Added a focused config loader test to assert `parameters.policy`, `parameters.dataset`, and `parameters.transforms` keys are parsed and exposed as constructor-ready dictionaries via `_asdict()`. +- Extended configuration documentation with an explicit mapping table from config keys to trainer/model/datasource consumption paths and a full SmolVLA/LeRobot TOML example. +files_touched: +- `tests/test_config_loader.py` (updated) +- `docs/docs/configurations/parameters.md` (updated) +- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) +gotchas: +- No `Config()` code change was required; introducing strict validation at this stage would have been intrusive and risked regressions for existing custom `parameters.*` users. ### T4. Implement LeRobot datasource adapter depends_on: [T2, T3] diff --git a/tests/test_config_loader.py b/tests/test_config_loader.py index 2b1bb9b75..6e587ce82 100644 --- a/tests/test_config_loader.py +++ b/tests/test_config_loader.py @@ -208,3 +208,113 @@ def test_config_loads_evaluation_section(tmp_path: Path, monkeypatch): if hasattr(Config, "args"): delattr(Config, "args") Config._cli_overrides = {} + + +def test_config_loads_smolvla_lerobot_parameter_contract(tmp_path: Path, monkeypatch): + """SmolVLA/LeRobot config keys should round-trip through Config().""" + config_base = tmp_path / "runtime" + config_path = tmp_path / "smolvla_lerobot.toml" + + config_data = { + "clients": {"type": "simple", "total_clients": 2, "per_round": 1}, + "server": {"address": "127.0.0.1", "port": 8000}, + "data": {"datasource": "LeRobot"}, + "trainer": { + "type": "lerobot", + "rounds": 1, + "epochs": 1, + "batch_size": 2, + "model_type": "smolvla", + "model_name": "smolvla", + }, + "algorithm": {"type": "fedavg"}, + "parameters": { + "policy": { + "type": "smolvla", + "path": "lerobot/smolvla_base", + "finetune_mode": "adapter", + "precision": "bf16", + "device": "cuda", + }, + "dataset": { + "repo_id": "lerobot/pusht_image", + "delta_timestamps": { + 
"observation_image": [-0.2, -0.1, 0.0], + }, + }, + "transforms": { + "image_size": [224, 224], + "normalize": True, + "interpolation": "bilinear", + }, + }, + } + + toml_writer.dump(config_data, config_path) + + monkeypatch.delenv("config_file", raising=False) + monkeypatch.setattr( + sys, + "argv", + [ + sys.argv[0], + "--config", + str(config_path), + "--base", + str(config_base), + ], + ) + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} + + config = Config() + + assert config.data.datasource == "LeRobot" + assert config.trainer.type == "lerobot" + assert config.trainer.model_type == "smolvla" + assert config.trainer.model_name == "smolvla" + + assert config.parameters.policy.type == "smolvla" + assert config.parameters.policy.path == "lerobot/smolvla_base" + assert config.parameters.policy.finetune_mode == "adapter" + assert config.parameters.policy.precision == "bf16" + assert config.parameters.policy.device == "cuda" + + assert config.parameters.dataset.repo_id == "lerobot/pusht_image" + assert config.parameters.dataset.delta_timestamps.observation_image == [ + -0.2, + -0.1, + 0.0, + ] + + assert config.parameters.transforms.image_size == [224, 224] + assert config.parameters.transforms.normalize is True + assert config.parameters.transforms.interpolation == "bilinear" + + # These dictionaries are the direct constructor kwargs expected by T4/T5/T6. 
+ assert config.parameters.policy._asdict() == { + "type": "smolvla", + "path": "lerobot/smolvla_base", + "finetune_mode": "adapter", + "precision": "bf16", + "device": "cuda", + } + assert config.parameters.dataset._asdict() == { + "repo_id": "lerobot/pusht_image", + "delta_timestamps": { + "observation_image": [-0.2, -0.1, 0.0], + }, + } + assert config.parameters.transforms._asdict() == { + "image_size": [224, 224], + "normalize": True, + "interpolation": "bilinear", + } + + Config._instance = None + if hasattr(Config, "args"): + delattr(Config, "args") + Config._cli_overrides = {} From ee379027ef541530987140a18aebb5426f125628 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 12:10:39 -0500 Subject: [PATCH 04/15] Implemented SmolVLA model wrapper and registry wiring. --- .../smolvla-lerobot-plato-integration-plan.md | 13 + plato/models/registry.py | 2 + plato/models/smolvla.py | 284 ++++++++++++++++++ 3 files changed, 299 insertions(+) create mode 100644 plato/models/smolvla.py diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index 870e99cfb..b38c51959 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -93,11 +93,24 @@ depends_on: [T2, T3] ### T5. Implement SmolVLA model/policy wrapper depends_on: [T2, T3] +status: completed (2026-02-19) - Add `plato/models/smolvla.py`. - Implement pretrained loading path (`smolvla_base` and custom repo id/path). - Expose trainable-parameter policy (full model or adapter path). - Register model in `plato/models/registry.py`. - Ensure state dict save/load compatibility with Plato aggregation pipeline. +work_log: +- Added `plato/models/smolvla.py` with lazy LeRobot import guards, actionable installation errors for missing robotics extras, and a SmolVLA factory path compatible with Plato model registry usage. 
+- Implemented pretrained policy source resolution with support for `smolvla_base` aliasing to `lerobot/smolvla_base`, config-based `parameters.policy.path`, and explicit constructor overrides (`policy_path` / `path`). +- Added finetune policy modes for `full` and `adapter`; adapter mode uses configurable name-pattern matching and falls back to the loaded policy's existing `requires_grad` flags when patterns do not match. +- Added compatibility checks for `state_dict`, `load_state_dict`, and `save_pretrained`, then registered `model_type = "smolvla"` in `plato/models/registry.py`. +- Ran a targeted constructor/import validation without downloads by monkeypatching `SmolVLAPolicy.from_pretrained` and verifying both direct wrapper construction and registry resolution. +files_touched: +- `plato/models/smolvla.py` (created) +- `plato/models/registry.py` (updated) +- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) +gotchas: +- Adapter parameter names are model-dependent; when no configured adapter patterns match, the wrapper intentionally reuses the model's preconfigured trainable flags instead of silently leaving zero trainable tensors. ### T6. Implement LeRobot trainer backend depends_on: [T4, T5] diff --git a/plato/models/registry.py b/plato/models/registry.py index 382a2d064..1dbda56cf 100644 --- a/plato/models/registry.py +++ b/plato/models/registry.py @@ -17,6 +17,7 @@ multilayer, nanochat, resnet, + smolvla, torch_hub, vgg, vit, @@ -42,6 +43,7 @@ "huggingface": huggingface.Model, "vit": vit.Model, "nanochat": nanochat.Model, + "smolvla": smolvla.Model, } registered_mlx_models = {} diff --git a/plato/models/smolvla.py b/plato/models/smolvla.py new file mode 100644 index 000000000..dd570d8d3 --- /dev/null +++ b/plato/models/smolvla.py @@ -0,0 +1,284 @@ +""" +Factory for SmolVLA policies integrated with Plato's model registry. 
+""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Iterable +from typing import Any + +import torch.nn as nn + +from plato.config import Config + +DEFAULT_POLICY_PATH = "lerobot/smolvla_base" +DEFAULT_FINETUNE_MODE = "full" +SUPPORTED_FINETUNE_MODES = {"full", "adapter"} +DEFAULT_ADAPTER_PATTERNS = ("adapter", "lora", "peft") + + +def _node_to_dict(node: Any) -> dict[str, Any]: + """Convert a config section into a plain dictionary.""" + if node is None: + return {} + if isinstance(node, dict): + return dict(node) + if hasattr(node, "_asdict"): + return dict(node._asdict()) + if hasattr(node, "__dict__"): + return { + key: value + for key, value in node.__dict__.items() + if not key.startswith("_") + } + raise TypeError("Unsupported policy configuration format.") + + +def _import_smolvla_policy() -> type[Any]: + """Import SmolVLAPolicy lazily to keep robotics dependencies optional.""" + try: + from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy + except ImportError as exc: # pragma: no cover - environment dependent + raise ImportError( + "SmolVLA requires LeRobot robotics dependencies. " + "Install them with `uv sync --extra robotics`." 
+ ) from exc + return SmolVLAPolicy + + +def _resolve_policy_config() -> dict[str, Any]: + """Read [parameters.policy] config values, if present.""" + parameters = getattr(Config(), "parameters", None) + policy = getattr(parameters, "policy", None) + return _node_to_dict(policy) + + +def _resolve_policy_path( + model_name: str | None, + policy_config: dict[str, Any], + overrides: dict[str, Any], +) -> str: + """Resolve pretrained source from kwargs, config, or defaults.""" + candidate = ( + overrides.get("policy_path") + or overrides.get("path") + or policy_config.get("path") + ) + + if candidate is None and isinstance(model_name, str) and model_name: + if model_name.lower() not in {"smolvla", "smolvla_base"}: + candidate = model_name + + if candidate is None: + candidate = DEFAULT_POLICY_PATH + + if not isinstance(candidate, str): + raise TypeError("SmolVLA policy path must be provided as a string.") + + resolved = candidate.strip() + if not resolved: + raise ValueError("SmolVLA policy path cannot be empty.") + + if resolved.lower() == "smolvla_base": + return DEFAULT_POLICY_PATH + + return resolved + + +def _resolve_finetune_mode( + policy_config: dict[str, Any], + overrides: dict[str, Any], +) -> str: + """Resolve the finetune mode with validation.""" + mode = overrides.get("finetune_mode", policy_config.get("finetune_mode")) + if mode is None: + mode = DEFAULT_FINETUNE_MODE + + if not isinstance(mode, str): + raise TypeError("`finetune_mode` must be a string.") + + normalized_mode = mode.strip().lower() + if normalized_mode not in SUPPORTED_FINETUNE_MODES: + raise ValueError( + "Unsupported SmolVLA finetune mode " + f"'{mode}'. Expected one of {sorted(SUPPORTED_FINETUNE_MODES)}." 
+ ) + return normalized_mode + + +def _resolve_adapter_patterns( + policy_config: dict[str, Any], + overrides: dict[str, Any], +) -> list[str]: + """Resolve adapter parameter patterns from kwargs or config.""" + raw_patterns = overrides.get( + "adapter_parameter_patterns", + policy_config.get("adapter_parameter_patterns"), + ) + + if raw_patterns is None: + return list(DEFAULT_ADAPTER_PATTERNS) + + if isinstance(raw_patterns, str): + parsed = [token.strip() for token in raw_patterns.split(",") if token.strip()] + return parsed or list(DEFAULT_ADAPTER_PATTERNS) + + if isinstance(raw_patterns, Iterable): + parsed = [ + token.strip() + for token in raw_patterns + if isinstance(token, str) and token.strip() + ] + return parsed or list(DEFAULT_ADAPTER_PATTERNS) + + raise TypeError( + "`adapter_parameter_patterns` must be a comma-separated string " + "or list of strings." + ) + + +def _count_trainable_parameters(model: nn.Module) -> int: + return sum( + parameter.numel() for parameter in model.parameters() if parameter.requires_grad + ) + + +def _apply_finetune_mode( + policy: nn.Module, + finetune_mode: str, + adapter_patterns: list[str], +) -> dict[str, Any]: + """Set requires_grad according to the requested finetune mode.""" + named_parameters = list(policy.named_parameters()) + if not named_parameters: + raise ValueError("Loaded SmolVLA policy has no named parameters.") + + if finetune_mode == "full": + for _, parameter in named_parameters: + parameter.requires_grad = True + + trainable_names = [name for name, _ in named_parameters] + return { + "trainable_names": trainable_names, + "fallback_mode": None, + } + + original_trainable_names = { + name for name, parameter in named_parameters if parameter.requires_grad + } + + for _, parameter in named_parameters: + parameter.requires_grad = False + + lowered_patterns = [pattern.lower() for pattern in adapter_patterns] + matched_names: set[str] = set() + for name, parameter in named_parameters: + lowered_name = 
name.lower() + if any(pattern in lowered_name for pattern in lowered_patterns): + parameter.requires_grad = True + matched_names.add(name) + + fallback_mode: str | None = None + if not matched_names and original_trainable_names: + for name, parameter in named_parameters: + if name in original_trainable_names: + parameter.requires_grad = True + matched_names.add(name) + fallback_mode = "original_requires_grad" + logging.warning( + "SmolVLA adapter mode found no parameter names matching patterns %s; " + "falling back to model-defined requires_grad flags.", + adapter_patterns, + ) + + if _count_trainable_parameters(policy) == 0: + raise ValueError( + "SmolVLA adapter mode selected, but no trainable parameters were " + "resolved. Configure `adapter_parameter_patterns` or set " + "`finetune_mode` to 'full'." + ) + + return { + "trainable_names": sorted(matched_names), + "fallback_mode": fallback_mode, + } + + +def _ensure_checkpoint_compatibility(policy: nn.Module) -> None: + """Validate required methods for Plato aggregation and checkpoint flow.""" + required_methods = ("state_dict", "load_state_dict", "save_pretrained") + missing = [name for name in required_methods if not hasattr(policy, name)] + if missing: + joined = ", ".join(missing) + raise TypeError( + "Loaded SmolVLA policy is incompatible with Plato checkpoints; " + f"missing method(s): {joined}." + ) + + +class Model: + """Factory for LeRobot SmolVLA policies.""" + + @staticmethod + def get(model_name: str | None = None, **kwargs: Any) -> nn.Module: + """ + Build a SmolVLA policy. + + Keyword Args: + policy_path/path: Hugging Face repo id or local path for pretrained policy. + token: HF token for private repositories. + strict: Strictness flag forwarded to SmolVLAPolicy.from_pretrained(). + finetune_mode: "full" or "adapter". + adapter_parameter_patterns: Names/patterns used to select adapter params. 
+ """ + policy_config = _resolve_policy_config() + policy_path = _resolve_policy_path(model_name, policy_config, kwargs) + + token = kwargs.get("token", policy_config.get("token", os.getenv("HF_TOKEN"))) + strict = kwargs.get("strict", policy_config.get("strict", True)) + + SmolVLAPolicy = _import_smolvla_policy() + + try: + policy = SmolVLAPolicy.from_pretrained( + policy_path, + token=token, + strict=strict, + ) + except Exception as exc: # pragma: no cover - exercised via integration + raise RuntimeError( + "Failed to load SmolVLA policy from " + f"'{policy_path}'. Check `parameters.policy.path` and access token." + ) from exc + + if not isinstance(policy, nn.Module): + raise TypeError( + "SmolVLA policy loader returned a non-module object. " + "Expected a torch.nn.Module-compatible policy." + ) + + finetune_mode = _resolve_finetune_mode(policy_config, kwargs) + adapter_patterns = _resolve_adapter_patterns(policy_config, kwargs) + trainable_metadata = _apply_finetune_mode( + policy, + finetune_mode=finetune_mode, + adapter_patterns=adapter_patterns, + ) + + _ensure_checkpoint_compatibility(policy) + + trainable_count = _count_trainable_parameters(policy) + setattr(policy, "plato_policy_path", policy_path) + setattr(policy, "plato_finetune_mode", finetune_mode) + setattr(policy, "plato_adapter_patterns", tuple(adapter_patterns)) + setattr(policy, "plato_adapter_fallback_mode", trainable_metadata["fallback_mode"]) + setattr(policy, "plato_trainable_parameter_count", trainable_count) + setattr( + policy, + "plato_trainable_parameter_names", + tuple(trainable_metadata["trainable_names"]), + ) + + return policy From 18874d4c10b1704a9a097b1fb25fcaa3e2dea801 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 12:17:51 -0500 Subject: [PATCH 05/15] Implemented LeRobot datasource adapter with deterministic partitioning. 
--- .../smolvla-lerobot-plato-integration-plan.md | 13 + plato/datasources/lerobot.py | 771 ++++++++++++++++++ plato/datasources/registry.py | 3 +- 3 files changed, 786 insertions(+), 1 deletion(-) create mode 100644 plato/datasources/lerobot.py diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index b38c51959..6f50293a2 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -85,11 +85,24 @@ gotchas: ### T4. Implement LeRobot datasource adapter depends_on: [T2, T3] +status: completed (2026-02-19) - Add `plato/datasources/lerobot.py`. - Implement dataset loading via LeRobot APIs and map samples into Plato’s expected batch format. - Register datasource in `plato/datasources/registry.py`. - Add deterministic client partitioning strategy (episode/task aware split). - Provide train/test dataset access methods compatible with existing samplers. +work_log: +- Added `plato/datasources/lerobot.py` with guarded LeRobot imports, config parsing for `parameters.dataset.*` and `parameters.transforms.*`, and sample mapping that preserves raw fields while attaching `plato_inputs`, `plato_targets`, and `plato_metadata`. +- Implemented deterministic episode-level train/test splitting with optional explicit episode overrides, task-aware stratification when task metadata is available, and deterministic per-client episode partitioning keyed by `data.random_seed`/`parameters.dataset.split_seed`. +- Wired `"LeRobot"` through `plato/datasources/registry.py` as a partitioned datasource so `datasources_registry.get(client_id=...)` passes client identity into the adapter. +- Ran a targeted no-download constructor/registry validation using stubbed `LeRobotDataset` and `LeRobotDatasetMetadata`, confirming deterministic splits and registry retrieval. 
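The deterministic per-client episode partitioning noted in the work log can be sketched as a standalone example. This is a simplified illustration of the approach (hash-stable seeding plus round-robin assignment); the helper names are hypothetical and the real adapter additionally groups episodes by task before dealing them out:

```python
import hashlib
import random


def stable_seed(seed: int, key: str) -> int:
    # Hash-derived seed: reproducible across processes, unlike the built-in
    # hash(), which is randomized per interpreter run.
    digest = hashlib.sha256(f"{seed}:{key}".encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big")


def partition_episodes(episodes, total_clients, client_id, seed):
    # Shuffle once with a stable seed, then deal episodes round-robin;
    # client k (1-based) keeps positions congruent to (k - 1) mod N.
    shuffled = sorted(episodes)
    random.Random(stable_seed(seed, "partition")).shuffle(shuffled)
    slot = (client_id - 1) % total_clients
    return sorted(
        episode
        for index, episode in enumerate(shuffled)
        if index % total_clients == slot
    )
```

Because every client derives the same shuffle from `(seed, key)`, the per-client episode sets are disjoint, cover all episodes, and are identical across runs, which is what makes the no-download split checks reproducible.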
+files_touched: +- `plato/datasources/lerobot.py` (created) +- `plato/datasources/registry.py` (updated) +- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) +gotchas: +- Constructor/registry validation was intentionally monkeypatched to avoid external dataset access and to keep split checks deterministic and offline. +- When task metadata is sparse or missing, the adapter falls back to deterministic episode-only splitting. ### T5. Implement SmolVLA model/policy wrapper depends_on: [T2, T3] diff --git a/plato/datasources/lerobot.py b/plato/datasources/lerobot.py new file mode 100644 index 000000000..f6d632ff0 --- /dev/null +++ b/plato/datasources/lerobot.py @@ -0,0 +1,771 @@ +"""LeRobot datasource with deterministic, episode-aware client partitioning.""" + +from __future__ import annotations + +import hashlib +import inspect +import logging +import random +from collections import defaultdict +from collections.abc import Mapping, Sequence +from typing import Any + +from plato.config import Config +from plato.datasources import base + +_DEFAULT_TRAIN_SPLIT = 0.8 +_DEFAULT_SPLIT_SEED = 1 +_DEFAULT_NORMALIZE_MEAN = (0.485, 0.456, 0.406) +_DEFAULT_NORMALIZE_STD = (0.229, 0.224, 0.225) + +_EPISODE_INDEX_KEYS = ("episode_index", "episode_id", "index") +_TASK_KEYS = ( + "task", + "task_name", + "language_instruction", + "language_instruction_2", + "language_instruction_3", + "task_id", +) + + +class _EmptyDataset: + """A minimal dataset object for empty episode partitions.""" + + targets: list[Any] = [] + classes: list[str] = [] + + def __len__(self) -> int: + return 0 + + def __getitem__(self, index: int): + raise IndexError(f"Empty dataset does not contain index {index}.") + + +class _MappedLeRobotDataset: + """Wrap a LeRobot dataset and attach Plato-friendly canonical keys.""" + + def __init__(self, dataset: Any): + self._dataset = dataset + self.targets = getattr(dataset, "targets", None) + self.classes = getattr(dataset, "classes", None) + + def 
__len__(self) -> int: + return len(self._dataset) + + def __getitem__(self, index: int) -> dict[str, Any]: + sample = self._dataset[index] + + if isinstance(sample, Mapping): + raw_sample = dict(sample) + else: + return { + "plato_inputs": sample, + "plato_targets": None, + "plato_metadata": {}, + } + + inputs: dict[str, Any] = {} + targets: dict[str, Any] = {} + metadata: dict[str, Any] = {} + + for key, value in raw_sample.items(): + if key.startswith("observation"): + inputs[key] = value + elif key == "action" or key.startswith("action."): + targets[key] = value + else: + metadata[key] = value + + mapped = dict(raw_sample) + mapped["plato_inputs"] = inputs + mapped["plato_targets"] = raw_sample.get("action", targets or None) + mapped["plato_metadata"] = metadata + return mapped + + def __getattr__(self, name: str) -> Any: + return getattr(self._dataset, name) + + +def _import_lerobot() -> tuple[Any, Any]: + try: + from lerobot.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, + ) + except ImportError as exc: # pragma: no cover - exercised without robotics extra. + raise ImportError( + "LeRobot datasource requires optional robotics dependencies. " + "Install them with \"uv sync --extra robotics\" before using " + '"data.datasource = \"LeRobot\"". 
' + ) from exc + + return LeRobotDataset, LeRobotDatasetMetadata + + +def _to_plain(value: Any) -> Any: + if value is None: + return None + if isinstance(value, Mapping): + return {str(key): _to_plain(val) for key, val in value.items()} + if hasattr(value, "_asdict"): + return {str(key): _to_plain(val) for key, val in value._asdict().items()} + if isinstance(value, (list, tuple, set)): + return [_to_plain(item) for item in value] + return value + + +def _to_plain_dict(value: Any) -> dict[str, Any]: + plain = _to_plain(value) + if plain is None: + return {} + if isinstance(plain, Mapping): + return {str(key): val for key, val in plain.items()} + raise TypeError(f"Expected mapping-like configuration, got {type(value)}.") + + +def _as_int(value: Any) -> int | None: + if value is None: + return None + if isinstance(value, bool): + return int(value) + if isinstance(value, int): + return value + try: + return int(value) + except (TypeError, ValueError): + return None + + +def _stable_seed(seed: int, key: str) -> int: + digest = hashlib.sha256(f"{seed}:{key}".encode("utf-8")).digest() + return int.from_bytes(digest[:8], "big") + + +def _parse_size(value: Any) -> tuple[int, int] | int | None: + if value is None: + return None + if isinstance(value, int): + return value + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + parts = [int(item) for item in value] + if not parts: + return None + if len(parts) == 1: + return parts[0] + return (parts[0], parts[1]) + raise TypeError("Expected an int or sequence for transform size.") + + +def _parse_float_sequence(value: Any, default: Sequence[float]) -> tuple[float, ...]: + if value is None: + return tuple(float(item) for item in default) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + parsed = [float(item) for item in value] + if parsed: + return tuple(parsed) + raise TypeError("Expected a numeric sequence for normalization values.") + + +def _resolve_interpolation(value: Any) -> 
Any: + if value is None: + return None + + try: + from torchvision.transforms import InterpolationMode + except ImportError: + return None + + interpolation_map = { + "nearest": InterpolationMode.NEAREST, + "bilinear": InterpolationMode.BILINEAR, + "bicubic": InterpolationMode.BICUBIC, + "lanczos": InterpolationMode.LANCZOS, + } + return interpolation_map.get(str(value).strip().lower(), None) + + +def _build_image_transforms(transform_cfg: Mapping[str, Any]) -> Any | None: + if not transform_cfg: + return None + + enable = transform_cfg.get("enable", True) + if isinstance(enable, bool) and not enable: + return None + + try: + import torch + from torchvision import transforms as tv_transforms + except ImportError as exc: + raise ImportError( + "LeRobot image transforms require torchvision and torch." + ) from exc + + transform_steps: list[Any] = [] + + image_size = _parse_size(transform_cfg.get("image_size")) + interpolation = _resolve_interpolation(transform_cfg.get("interpolation")) + + if image_size is not None: + resize_kwargs: dict[str, Any] = {} + if interpolation is not None: + resize_kwargs["interpolation"] = interpolation + transform_steps.append(tv_transforms.Resize(image_size, **resize_kwargs)) + + center_crop_cfg = transform_cfg.get("center_crop", None) + crop_size = None + if isinstance(center_crop_cfg, bool): + if center_crop_cfg and image_size is not None: + crop_size = image_size + elif center_crop_cfg is not None: + crop_size = _parse_size(center_crop_cfg) + elif transform_cfg.get("crop_size") is not None: + crop_size = _parse_size(transform_cfg.get("crop_size")) + + if crop_size is not None: + transform_steps.append(tv_transforms.CenterCrop(crop_size)) + + normalize = bool(transform_cfg.get("normalize", False)) + if normalize: + convert_dtype = getattr(tv_transforms, "ConvertImageDtype", None) + if callable(convert_dtype): + transform_steps.append(convert_dtype(torch.float32)) + + mean = _parse_float_sequence( + transform_cfg.get("mean"), + 
_DEFAULT_NORMALIZE_MEAN, + ) + std = _parse_float_sequence(transform_cfg.get("std"), _DEFAULT_NORMALIZE_STD) + transform_steps.append(tv_transforms.Normalize(mean=mean, std=std)) + + if not transform_steps: + return None + + compose = getattr(tv_transforms, "Compose", None) + if not callable(compose): + return None + + return compose(transform_steps) + + +def _normalize_delta_timestamps(value: Any) -> dict[str, list[float]] | None: + if value is None: + return None + + parsed = _to_plain(value) + if not isinstance(parsed, Mapping): + raise TypeError( + '"parameters.dataset.delta_timestamps" must be a mapping of ' + "key -> list[float]." + ) + + normalized: dict[str, list[float]] = {} + for key, offsets in parsed.items(): + if not isinstance(offsets, Sequence) or isinstance(offsets, (str, bytes)): + raise TypeError( + f"Delta timestamps for '{key}' must be a list-like sequence." + ) + normalized[str(key)] = [float(offset) for offset in offsets] + + return normalized + + +def _columnar_to_rows(columns: Mapping[str, Any]) -> list[dict[str, Any]]: + normalized: dict[str, list[Any]] = {} + max_len = 0 + + for key, value in columns.items(): + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + values = list(value) + elif hasattr(value, "tolist"): + values = list(value.tolist()) + else: + values = [value] + + normalized[str(key)] = values + max_len = max(max_len, len(values)) + + rows: list[dict[str, Any]] = [] + for row_idx in range(max_len): + row: dict[str, Any] = {} + for key, values in normalized.items(): + if row_idx < len(values): + row[key] = values[row_idx] + rows.append(row) + + return rows + + +def _episode_rows(episodes: Any) -> list[dict[str, Any]]: + if episodes is None: + return [] + + if isinstance(episodes, Mapping): + return _columnar_to_rows(episodes) + + to_pylist = getattr(episodes, "to_pylist", None) + if callable(to_pylist): + pylist = to_pylist() + if isinstance(pylist, list): + return [ + dict(item) if isinstance(item, 
Mapping) else {"value": item} + for item in pylist + ] + + to_dict = getattr(episodes, "to_dict", None) + if callable(to_dict): + try: + as_dict = to_dict(orient="list") + except TypeError: + as_dict = to_dict() + if isinstance(as_dict, Mapping): + return _columnar_to_rows(as_dict) + + if isinstance(episodes, Sequence) and not isinstance(episodes, (str, bytes)): + rows: list[dict[str, Any]] = [] + for entry in episodes: + if isinstance(entry, Mapping): + rows.append(dict(entry)) + elif hasattr(entry, "_asdict"): + rows.append(_to_plain_dict(entry)) + else: + rows.append({"value": entry}) + return rows + + return [] + + +def _resolve_episode_indices(metadata: Any) -> list[int]: + total_episodes = _as_int(getattr(metadata, "total_episodes", None)) + + if total_episodes is None: + total_episodes = len(_episode_rows(getattr(metadata, "episodes", None))) + + if total_episodes is None or total_episodes < 0: + return [] + + return list(range(total_episodes)) + + +def _resolve_task_name(row: Mapping[str, Any], tasks_lookup: Any) -> str | None: + task_index = _as_int(row.get("task_index")) + if task_index is not None and isinstance(tasks_lookup, Sequence): + if not isinstance(tasks_lookup, (str, bytes)) and 0 <= task_index < len( + tasks_lookup + ): + return str(tasks_lookup[task_index]) + + for key in _TASK_KEYS: + if key in row and row[key] is not None: + return str(row[key]) + + return None + + +def _resolve_episode_tasks(metadata: Any, episodes: Sequence[int]) -> dict[int, str | None]: + episode_tasks = {episode: None for episode in episodes} + episode_rows = _episode_rows(getattr(metadata, "episodes", None)) + tasks_lookup = _to_plain(getattr(metadata, "tasks", None)) + + for position, row in enumerate(episode_rows): + episode_index = None + for key in _EPISODE_INDEX_KEYS: + if key in row: + episode_index = _as_int(row.get(key)) + if episode_index is not None: + break + + if episode_index is None: + episode_index = position + + if episode_index not in episode_tasks: + 
continue + + task_name = _resolve_task_name(row, tasks_lookup) + if task_name is not None: + episode_tasks[episode_index] = task_name + + return episode_tasks + + +def _split_count(total: int, train_ratio: float) -> int: + if total <= 1: + return total + + if train_ratio <= 0: + return 0 + if train_ratio >= 1: + return total + + split_count = int(round(total * train_ratio)) + split_count = max(1, min(total - 1, split_count)) + return split_count + + +def _split_episodes( + episodes: Sequence[int], + episode_tasks: Mapping[int, str | None], + train_ratio: float, + seed: int, + task_aware: bool, +) -> tuple[list[int], list[int]]: + if not episodes: + return [], [] + + all_episodes = [int(episode) for episode in episodes] + has_task_info = task_aware and any(episode_tasks.get(ep) for ep in all_episodes) + + if has_task_info: + grouped: dict[str, list[int]] = defaultdict(list) + for episode in all_episodes: + grouped[str(episode_tasks.get(episode) or "__unknown_task__")].append( + episode + ) + + train_episodes: list[int] = [] + test_episodes: list[int] = [] + + for task_name in sorted(grouped): + task_episodes = sorted(grouped[task_name]) + rng = random.Random(_stable_seed(seed, f"split:{task_name}")) + rng.shuffle(task_episodes) + + split_idx = _split_count(len(task_episodes), train_ratio) + train_episodes.extend(task_episodes[:split_idx]) + test_episodes.extend(task_episodes[split_idx:]) + else: + shuffled = sorted(all_episodes) + random.Random(seed).shuffle(shuffled) + split_idx = _split_count(len(shuffled), train_ratio) + train_episodes = shuffled[:split_idx] + test_episodes = shuffled[split_idx:] + + if not train_episodes and test_episodes: + train_episodes.append(test_episodes.pop(0)) + + return sorted(train_episodes), sorted(test_episodes) + + +def _normalize_episode_list(value: Any) -> list[int] | None: + if value is None: + return None + + parsed = _to_plain(value) + + if isinstance(parsed, int): + return [int(parsed)] + + if isinstance(parsed, Sequence) and 
not isinstance(parsed, (str, bytes)): + return [int(item) for item in parsed] + + raise TypeError("Episode lists must be an int or a list of ints.") + + +def _resolve_episode_split( + all_episodes: Sequence[int], + explicit_train: list[int] | None, + explicit_test: list[int] | None, + episode_tasks: Mapping[int, str | None], + train_ratio: float, + seed: int, + task_aware: bool, +) -> tuple[list[int], list[int]]: + episode_set = set(int(episode) for episode in all_episodes) + + if explicit_train is None and explicit_test is None: + return _split_episodes(all_episodes, episode_tasks, train_ratio, seed, task_aware) + + train_episodes = [ + int(episode) for episode in (explicit_train or []) if int(episode) in episode_set + ] + test_episodes = [ + int(episode) + for episode in (explicit_test or []) + if int(episode) in episode_set and int(episode) not in train_episodes + ] + + if not train_episodes: + train_episodes = [ + episode for episode in all_episodes if episode not in set(test_episodes) + ] + + if not test_episodes: + test_episodes = [ + episode for episode in all_episodes if episode not in set(train_episodes) + ] + + if not train_episodes and test_episodes: + train_episodes.append(test_episodes.pop(0)) + + return sorted(train_episodes), sorted(test_episodes) + + +def _partition_episodes( + episodes: Sequence[int], + episode_tasks: Mapping[int, str | None], + total_clients: int, + client_id: int, + seed: int, + task_aware: bool, +) -> list[int]: + if not episodes: + return [] + + if total_clients <= 1 or client_id <= 0: + return sorted(int(episode) for episode in episodes) + + client_slot = (int(client_id) - 1) % int(total_clients) + all_episodes = [int(episode) for episode in episodes] + + has_task_info = task_aware and any(episode_tasks.get(ep) for ep in all_episodes) + + selected: list[int] = [] + + if has_task_info: + grouped: dict[str, list[int]] = defaultdict(list) + for episode in all_episodes: + grouped[str(episode_tasks.get(episode) or 
"__unknown_task__")].append( + episode + ) + + for task_name in sorted(grouped): + task_episodes = sorted(grouped[task_name]) + rng = random.Random(_stable_seed(seed, f"partition:{task_name}")) + rng.shuffle(task_episodes) + + for idx, episode in enumerate(task_episodes): + if idx % total_clients == client_slot: + selected.append(episode) + else: + shuffled = sorted(all_episodes) + random.Random(seed).shuffle(shuffled) + selected = [ + episode + for idx, episode in enumerate(shuffled) + if idx % total_clients == client_slot + ] + + if not selected: + fallback = sorted(all_episodes) + random.Random(_stable_seed(seed, "fallback")).shuffle(fallback) + selected = [fallback[client_slot % len(fallback)]] + + return sorted(selected) + + +def _resolve_default_seed() -> int: + data_cfg = getattr(Config(), "data", None) + configured_seed = _as_int(getattr(data_cfg, "random_seed", None)) + if configured_seed is not None: + return configured_seed + return _DEFAULT_SPLIT_SEED + + +def _resolve_total_clients(config: Any) -> int: + clients_cfg = getattr(config, "clients", None) + total_clients = _as_int(getattr(clients_cfg, "total_clients", 1)) + if total_clients is None or total_clients <= 0: + return 1 + return total_clients + + +def _filter_constructor_kwargs(dataset_cls: Any, kwargs: Mapping[str, Any]) -> dict[str, Any]: + try: + signature = inspect.signature(dataset_cls.__init__) + except (TypeError, ValueError): + return dict(kwargs) + + accepts_var_kwargs = any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in signature.parameters.values() + ) + if accepts_var_kwargs: + return dict(kwargs) + + valid_parameters = { + name for name in signature.parameters.keys() if name != "self" + } + filtered = {key: value for key, value in kwargs.items() if key in valid_parameters} + + dropped = sorted(set(kwargs.keys()) - set(filtered.keys())) + if dropped: + logging.warning( + "LeRobot datasource ignored unsupported dataset kwargs: %s", + ", ".join(dropped), + ) + + 
return filtered + + +class DataSource(base.DataSource): + """LeRobot datasource with deterministic train/test and per-client episode splits.""" + + def __init__(self, client_id: int = 0, **kwargs): + super().__init__() + + LeRobotDataset, LeRobotDatasetMetadata = _import_lerobot() + + config = Config() + params_cfg = getattr(config, "parameters", None) + + dataset_cfg = _to_plain_dict(getattr(params_cfg, "dataset", None)) + transform_cfg = _to_plain_dict(getattr(params_cfg, "transforms", None)) + + dataset_cfg.update(_to_plain_dict(kwargs.pop("dataset_kwargs", None))) + transform_cfg.update(_to_plain_dict(kwargs.pop("transform_kwargs", None))) + + for key in ( + "repo_id", + "delta_timestamps", + "split_seed", + "train_split", + "test_split", + "train_episodes", + "test_episodes", + "task_aware_split", + "task_aware_partition", + ): + if key in kwargs: + dataset_cfg[key] = kwargs.pop(key) + + for key in ( + "image_size", + "interpolation", + "center_crop", + "crop_size", + "normalize", + "mean", + "std", + "enable", + ): + if key in kwargs: + transform_cfg[key] = kwargs.pop(key) + + dataset_cfg.update(_to_plain_dict(kwargs)) + + repo_id = str(dataset_cfg.pop("repo_id", "")).strip() + if not repo_id: + raise ValueError( + "LeRobot datasource requires " + '"parameters.dataset.repo_id" to be set.' 
+ ) + + train_split_raw = dataset_cfg.pop("train_split", _DEFAULT_TRAIN_SPLIT) + test_split_raw = dataset_cfg.pop("test_split", None) + + train_split = float(train_split_raw) + if test_split_raw is not None: + train_split = 1.0 - float(test_split_raw) + train_split = max(0.0, min(1.0, train_split)) + + split_seed = _as_int(dataset_cfg.pop("split_seed", None)) + if split_seed is None: + split_seed = _resolve_default_seed() + + task_aware_split = bool(dataset_cfg.pop("task_aware_split", True)) + task_aware_partition = bool(dataset_cfg.pop("task_aware_partition", True)) + + delta_timestamps = _normalize_delta_timestamps( + dataset_cfg.pop("delta_timestamps", None) + ) + + explicit_train_episodes = _normalize_episode_list( + dataset_cfg.pop("train_episodes", None) + ) + explicit_test_episodes = _normalize_episode_list( + dataset_cfg.pop("test_episodes", None) + ) + + metadata = LeRobotDatasetMetadata(repo_id) + all_episodes = _resolve_episode_indices(metadata) + + if not all_episodes: + raise ValueError(f"No episodes found for LeRobot dataset '{repo_id}'.") + + episode_tasks = _resolve_episode_tasks(metadata, all_episodes) + train_episodes, test_episodes = _resolve_episode_split( + all_episodes, + explicit_train_episodes, + explicit_test_episodes, + episode_tasks, + train_split, + split_seed, + task_aware_split, + ) + + total_clients = _resolve_total_clients(config) + resolved_client_id = int(client_id) + + client_train_episodes = _partition_episodes( + train_episodes, + episode_tasks, + total_clients, + resolved_client_id, + split_seed, + task_aware_partition, + ) + client_test_episodes = _partition_episodes( + test_episodes, + episode_tasks, + total_clients, + resolved_client_id, + split_seed + 1, + task_aware_partition, + ) + + image_transforms = _build_image_transforms(transform_cfg) + dataset_kwargs = _filter_constructor_kwargs(LeRobotDataset, dataset_cfg) + + if client_train_episodes: + train_dataset = LeRobotDataset( + repo_id, + episodes=client_train_episodes, 
+ delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + **dataset_kwargs, + ) + else: + train_dataset = _EmptyDataset() + + if client_test_episodes: + test_dataset = LeRobotDataset( + repo_id, + episodes=client_test_episodes, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + **dataset_kwargs, + ) + else: + test_dataset = _EmptyDataset() + + self.trainset = _MappedLeRobotDataset(train_dataset) + self.testset = _MappedLeRobotDataset(test_dataset) + + self.repo_id = repo_id + self.client_id = resolved_client_id + self.train_episodes = client_train_episodes + self.test_episodes = client_test_episodes + self.meta = metadata + + logging.info( + "LeRobot datasource ready for client %s: train episodes=%s, test episodes=%s", + resolved_client_id, + len(client_train_episodes), + len(client_test_episodes), + ) + + @staticmethod + def input_shape(): + """Return shape hint from configured transform image size when available.""" + params_cfg = getattr(Config(), "parameters", None) + transform_cfg = _to_plain_dict(getattr(params_cfg, "transforms", None)) + + image_size = _parse_size(transform_cfg.get("image_size")) + if isinstance(image_size, tuple): + return (3, image_size[0], image_size[1]) + + raise ValueError( + "LeRobot datasource input shape is dataset-dependent. " + 'Set "parameters.transforms.image_size" ' + "for a static shape hint." 
+ ) diff --git a/plato/datasources/registry.py b/plato/datasources/registry.py index d4bda0c8d..1abc8108a 100644 --- a/plato/datasources/registry.py +++ b/plato/datasources/registry.py @@ -11,6 +11,7 @@ feature, femnist, huggingface, + lerobot, lora, nanochat, purchase, @@ -31,7 +32,7 @@ "Nanochat": nanochat, } -registered_partitioned_datasources = {"FEMNIST": femnist} +registered_partitioned_datasources = {"FEMNIST": femnist, "LeRobot": lerobot} _datasource_aliases = { "STL10": ("Torchvision", {"dataset_name": "STL10"}), From bd6dd3fbb3046a4ec954a84a5d4c01ff3505d26e Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 12:27:02 -0500 Subject: [PATCH 06/15] Implemented LeRobot trainer backend for SmolVLA integration. --- .../smolvla-lerobot-plato-integration-plan.md | 17 + plato/trainers/lerobot.py | 501 ++++++++++++++++++ plato/trainers/registry.py | 4 + 3 files changed, 522 insertions(+) create mode 100644 plato/trainers/lerobot.py diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index 6f50293a2..57b1d8ec7 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -127,11 +127,28 @@ gotchas: ### T6. Implement LeRobot trainer backend depends_on: [T4, T5] +status: completed (2026-02-19) - Add `plato/trainers/lerobot.py` (ComposableTrainer-compatible). - Implement multimodal collate + preprocessing for LeRobot samples. - Wire forward/loss/backward/optimizer/scheduler flow for SmolVLA policy. - Implement evaluation hooks suitable for regression checks. - Register trainer in `plato/trainers/registry.py`. +work_log: +- Added `plato/trainers/lerobot.py` with a ComposableTrainer-compatible backend that wires custom dict/multimodal collation, processor-aware training steps, and evaluation loss reporting for regression checks. 
+- Implemented LeRobot pre/post-processor integration via `make_pre_post_processors(policy_cfg, pretrained_path=..., dataset_stats=...)`, with lazy optional-dependency import guards and actionable installation errors.
+- Implemented SmolVLA policy forward integration handling tuple loss outputs and preserving optimizer + scheduler flow through the base composable lifecycle.
+- Registered `trainer.type = "lerobot"` in `plato/trainers/registry.py`.
+- Ran targeted offline validation with monkeypatched processor stubs:
+  - trainer registry resolution and construction (`trainer.type = "lerobot"`),
+  - synthetic one-epoch training-step path,
+  - synthetic evaluation pass returning numeric loss.
+- Ran `uv run ruff check` on touched trainer files.
+files_touched:
+- `plato/trainers/lerobot.py` (created)
+- `plato/trainers/registry.py` (updated)
+- `plans/smolvla-lerobot-plato-integration-plan.md` (updated)
+gotchas:
+- LeRobot preprocessing depends on optional robotics extras at runtime; the trainer defers imports until processor initialization so non-robotics workloads remain unaffected, and it fails with a clear `uv sync --extra robotics` message when required dependencies are missing.

### T7.
Harden federated payload/aggregation behavior depends_on: [T6] diff --git a/plato/trainers/lerobot.py b/plato/trainers/lerobot.py new file mode 100644 index 000000000..e48caafdc --- /dev/null +++ b/plato/trainers/lerobot.py @@ -0,0 +1,501 @@ +"""Composable trainer for LeRobot policies such as SmolVLA.""" + +from __future__ import annotations + +import logging +from collections.abc import Callable, Iterable, Mapping +from typing import Any, cast + +import torch +import torch.nn as nn +import torch.utils.data +from torch.utils.data._utils.collate import default_collate + +from plato.config import Config +from plato.trainers.composable import ComposableTrainer +from plato.trainers.strategies import CustomCollateFnDataLoaderStrategy +from plato.trainers.strategies.base import ( + TestingStrategy, + TrainingContext, + TrainingStepStrategy, +) + +_RESERVED_KEYS = frozenset({"plato_inputs", "plato_targets", "plato_metadata"}) + + +def _config_node_to_dict(node: Any) -> dict[str, Any]: + """Convert config sections to plain dictionaries.""" + if node is None: + return {} + if isinstance(node, dict): + return dict(node) + if hasattr(node, "_asdict"): + return dict(node._asdict()) + if hasattr(node, "__dict__"): + return { + key: value + for key, value in node.__dict__.items() + if not key.startswith("_") + } + return {} + + +def _import_make_pre_post_processors() -> Callable[..., tuple[Callable, Callable]]: + """Import LeRobot pre/post processors lazily to keep robotics deps optional.""" + try: + from lerobot.policies.factory import make_pre_post_processors + except ImportError as exc: # pragma: no cover - optional dependency + raise ImportError( + "LeRobot trainer requires optional robotics dependencies. " + "Install them with `uv sync --extra robotics`." 
+ ) from exc + + return make_pre_post_processors + + +def _move_to_device(value: Any, device: torch.device | str) -> Any: + """Recursively move tensors inside nested containers to target device.""" + if torch.is_tensor(value): + return value.to(device) + if isinstance(value, Mapping): + return {key: _move_to_device(item, device) for key, item in value.items()} + if isinstance(value, list): + return [_move_to_device(item, device) for item in value] + if isinstance(value, tuple): + return tuple(_move_to_device(item, device) for item in value) + if hasattr(value, "to"): + try: + return value.to(device) + except (TypeError, AttributeError): + return value + return value + + +def _resolve_batch_size(value: Any) -> int | None: + """Infer batch size from the first tensor-like value in a nested structure.""" + if torch.is_tensor(value): + if value.ndim == 0: + return 1 + return int(value.shape[0]) + if isinstance(value, Mapping): + for nested in value.values(): + size = _resolve_batch_size(nested) + if size is not None: + return size + if isinstance(value, list): + return len(value) if value else None + return None + + +def _collate_values(values: list[Any]) -> Any: + """Collate values with a robust fallback for heterogeneous metadata fields.""" + if not values: + return values + if any(value is None for value in values): + return list(values) + + if all(isinstance(value, Mapping) for value in values): + keys: list[str] = [] + for value in values: + for key in value: + if key not in keys: + keys.append(str(key)) + return { + key: _collate_values( + [cast(Mapping[str, Any], value).get(key) for value in values] + ) + for key in keys + } + + try: + return default_collate(values) + except (TypeError, RuntimeError): + return list(values) + + +def _extract_tensor_label(label: Any) -> torch.Tensor | None: + """Extract a tensor label from common LeRobot sample target structures.""" + if torch.is_tensor(label): + return label + if isinstance(label, Mapping): + action = 
label.get("action") + if torch.is_tensor(action): + return action + return None + + +class LeRobotBatch(dict): + """Dictionary batch wrapper that supports `.to(device)`.""" + + def to(self, device: torch.device | str): + for key, value in list(self.items()): + self[key] = _move_to_device(value, device) + return self + + +class LeRobotCollateWrapper: + """Collate LeRobot dict samples into `(inputs, labels)` for ComposableTrainer.""" + + def __call__( + self, + examples: Iterable[Any], + ) -> tuple[LeRobotBatch, torch.Tensor]: + example_list = list(examples) + if not example_list: + raise ValueError("LeRobot collate received an empty batch.") + + normalized_inputs: list[dict[str, Any]] = [] + raw_labels: list[Any] = [] + + for sample in example_list: + if isinstance(sample, Mapping): + sample_dict = dict(sample) + label = sample_dict.get("plato_targets") + if label is None: + label = sample_dict.get("action") + + payload = { + key: value + for key, value in sample_dict.items() + if key not in _RESERVED_KEYS + } + + if not payload: + plato_inputs = sample_dict.get("plato_inputs") + if isinstance(plato_inputs, Mapping): + payload = dict(plato_inputs) + + if "action" not in payload and torch.is_tensor(label): + payload["action"] = label + + normalized_inputs.append(payload) + raw_labels.append(label) + else: + normalized_inputs.append({"observation": sample}) + raw_labels.append(None) + + batched_inputs = _collate_values(normalized_inputs) + if not isinstance(batched_inputs, Mapping): + batched_inputs = {"inputs": batched_inputs} + batch_dict = LeRobotBatch(dict(batched_inputs)) + + tensor_labels = [_extract_tensor_label(label) for label in raw_labels] + if tensor_labels and all(label is not None for label in tensor_labels): + labels = _collate_values(cast(list[Any], tensor_labels)) + if torch.is_tensor(labels): + return batch_dict, labels + + action_tensor = batch_dict.get("action") + if torch.is_tensor(action_tensor): + return batch_dict, action_tensor + + batch_size 
= _resolve_batch_size(batch_dict) + if batch_size is None: + batch_size = len(example_list) + labels = torch.zeros(batch_size, dtype=torch.float32) + return batch_dict, labels + + +def _resolve_policy_forward( + model: nn.Module, + batch: Mapping[str, Any], + reduction: str, +) -> tuple[torch.Tensor, Mapping[str, Any]]: + """Call the policy and normalize output into `(loss_tensor, loss_dict)`.""" + forward_result = model.forward(batch, reduction=reduction) + + if torch.is_tensor(forward_result): + return forward_result, {} + + if isinstance(forward_result, tuple): + if len(forward_result) == 0: + raise ValueError("LeRobot policy forward returned an empty tuple.") + + first = forward_result[0] + second = forward_result[1] if len(forward_result) > 1 else {} + + if torch.is_tensor(first): + loss_dict = second if isinstance(second, Mapping) else {} + return first, cast(Mapping[str, Any], loss_dict) + + if torch.is_tensor(second): + loss_dict = first if isinstance(first, Mapping) else {} + return second, cast(Mapping[str, Any], loss_dict) + + if isinstance(first, Mapping): + maybe_loss = first.get("loss") + if torch.is_tensor(maybe_loss): + return maybe_loss, first + + if isinstance(forward_result, Mapping): + maybe_loss = forward_result.get("loss") + if torch.is_tensor(maybe_loss): + return maybe_loss, forward_result + + raise TypeError( + "LeRobot policy forward must return a tensor loss or a tuple containing " + "a tensor loss. Received: " + f"{type(forward_result)}." + ) + + +def _apply_preprocessor( + batch: LeRobotBatch, + context: TrainingContext, +) -> LeRobotBatch: + """Apply the optional LeRobot preprocessor.""" + preprocessor = context.state.get("lerobot_preprocessor") + if preprocessor is None: + return batch + + processed = preprocessor(dict(batch)) + if not isinstance(processed, Mapping): + raise TypeError( + "LeRobot preprocessor must return a mapping batch. " + f"Received {type(processed)}." 
+ ) + return LeRobotBatch(dict(processed)) + + +def _summarize_loss_dict(loss_dict: Mapping[str, Any]) -> dict[str, float]: + """Store scalar loss dictionary values for debugging/monitoring hooks.""" + summary: dict[str, float] = {} + for key, value in loss_dict.items(): + if torch.is_tensor(value): + try: + summary[str(key)] = float(value.detach().cpu().item()) + except (RuntimeError, ValueError): + continue + elif isinstance(value, (int, float)): + summary[str(key)] = float(value) + return summary + + +def _resolve_sampler_for_loader(sampler: Any) -> Any: + """Resolve a sampler config value to a torch DataLoader sampler object.""" + if sampler is None: + return None + if isinstance(sampler, torch.utils.data.Sampler): + return sampler + if isinstance(sampler, (list, range)): + return torch.utils.data.SubsetRandomSampler(sampler) + if hasattr(sampler, "get"): + return sampler.get() + return sampler + + +class LeRobotTrainingStepStrategy(TrainingStepStrategy): + """Training step strategy for LeRobot policies with dict-style batches.""" + + def __init__(self, reduction: str = "mean"): + self.reduction = reduction + + def training_step( + self, + model: nn.Module, + optimizer: torch.optim.Optimizer, + examples: LeRobotBatch, + labels: torch.Tensor, # pylint: disable=unused-argument + loss_criterion, # pylint: disable=unused-argument + context: TrainingContext, + ) -> torch.Tensor: + optimizer.zero_grad() + + batch = _apply_preprocessor(examples, context) + loss, loss_dict = _resolve_policy_forward( + model, + batch, + reduction=self.reduction, + ) + + if not torch.is_tensor(loss): + raise TypeError( + "LeRobot policy forward did not return a tensor loss." 
+ ) + + loss.backward() + optimizer.step() + + context.state["optimizer_step_completed"] = True + context.state["lerobot_loss_dict"] = _summarize_loss_dict(loss_dict) + return loss.detach() + + +class LeRobotTestingStrategy(TestingStrategy): + """Compute a stable average evaluation loss for regression checks.""" + + def __init__(self, collate_fn: LeRobotCollateWrapper, reduction: str = "mean"): + self.collate_fn = collate_fn + self.reduction = reduction + + def test_model( + self, + model: nn.Module, + config: dict[str, Any], + testset, + sampler, + context: TrainingContext, + ) -> float: + batch_size = int(config.get("batch_size", 1)) + sampler_obj = _resolve_sampler_for_loader(sampler) + + test_loader = torch.utils.data.DataLoader( + testset, + batch_size=batch_size, + shuffle=False, + sampler=sampler_obj, + collate_fn=self.collate_fn, + ) + + model.to(context.device) + model.eval() + context.state["eval_loader"] = test_loader + + total_loss = 0.0 + total_weight = 0 + + with torch.no_grad(): + for examples, labels in test_loader: + examples = examples.to(context.device) + labels = labels.to(context.device) + + batch = _apply_preprocessor(examples, context) + loss, _ = _resolve_policy_forward( + model, + batch, + reduction=self.reduction, + ) + + batch_weight = int(labels.size(0)) + if batch_weight <= 0: + inferred = _resolve_batch_size(batch) + batch_weight = inferred if inferred is not None else 1 + + total_loss += float(loss.detach().item()) * batch_weight + total_weight += batch_weight + + model.train() + context.state.pop("eval_loader", None) + + if total_weight == 0: + return float("inf") + + eval_loss = total_loss / total_weight + context.state["lerobot_eval_loss"] = eval_loss + return eval_loss + + +def _resolve_dataset_stats(dataset: Any) -> Any: + """Extract dataset statistics from LeRobot datasets/subsets when available.""" + if dataset is None: + return None + + metadata_candidates = ( + getattr(dataset, "meta", None), + getattr(dataset, "metadata", 
None), + ) + for metadata in metadata_candidates: + stats = getattr(metadata, "stats", None) + if stats is not None: + return stats + + nested_dataset = getattr(dataset, "dataset", None) + if nested_dataset is not None and nested_dataset is not dataset: + return _resolve_dataset_stats(nested_dataset) + + return None + + +class Trainer(ComposableTrainer): + """Composable LeRobot trainer backend.""" + + def __init__(self, model=None, callbacks=None): + self._collate_wrapper = LeRobotCollateWrapper() + self._processors_initialised = False + self._pretrained_path = self._resolve_policy_path() + self._preprocessor_factory: Callable[..., tuple[Callable, Callable]] | None = ( + None + ) + + super().__init__( + model=model, + callbacks=callbacks, + loss_strategy=None, + optimizer_strategy=None, + training_step_strategy=LeRobotTrainingStepStrategy(), + lr_scheduler_strategy=None, + model_update_strategy=None, + data_loader_strategy=CustomCollateFnDataLoaderStrategy( + collate_fn=self._collate_wrapper, + num_workers=0, + pin_memory=True, + ), + testing_strategy=LeRobotTestingStrategy(self._collate_wrapper), + ) + + self.context.state["lerobot_preprocessor"] = None + self.context.state["lerobot_postprocessor"] = None + + @staticmethod + def _resolve_policy_path() -> str | None: + parameters = getattr(Config(), "parameters", None) + policy_cfg = _config_node_to_dict(getattr(parameters, "policy", None)) + candidate = policy_cfg.get("path") + if isinstance(candidate, str): + value = candidate.strip() + return value if value else None + return None + + def _resolve_model_pretrained_path(self) -> str | None: + model = self._require_model() + model_path = getattr(model, "plato_policy_path", None) + if isinstance(model_path, str) and model_path.strip(): + return model_path.strip() + return self._pretrained_path + + def _ensure_pre_post_processors(self, dataset: Any) -> None: + if self._processors_initialised: + return + + model = self._require_model() + policy_config = 
getattr(model, "config", None) + if policy_config is None: + raise AttributeError( + "LeRobot trainer expects the model to expose a `config` attribute " + "compatible with `make_pre_post_processors`." + ) + + if self._preprocessor_factory is None: + self._preprocessor_factory = _import_make_pre_post_processors() + + dataset_stats = _resolve_dataset_stats(dataset) + if dataset_stats is None: + logging.warning( + "LeRobot dataset statistics are unavailable; preprocessing will " + "be created without explicit dataset stats." + ) + + kwargs: dict[str, Any] = {"dataset_stats": dataset_stats} + pretrained_path = self._resolve_model_pretrained_path() + if pretrained_path: + kwargs["pretrained_path"] = pretrained_path + + preprocessor, postprocessor = self._preprocessor_factory( + policy_config, **kwargs + ) + if not callable(preprocessor) or not callable(postprocessor): + raise TypeError( + "LeRobot `make_pre_post_processors` must return two callables." + ) + + self.context.state["lerobot_preprocessor"] = preprocessor + self.context.state["lerobot_postprocessor"] = postprocessor + self._processors_initialised = True + + def train_model(self, config, trainset, sampler, **kwargs): + self._ensure_pre_post_processors(trainset) + return super().train_model(config, trainset, sampler, **kwargs) + + def test_model(self, config, testset, sampler=None, **kwargs): + self._ensure_pre_post_processors(testset) + return super().test_model(config, testset, sampler, **kwargs) diff --git a/plato/trainers/registry.py b/plato/trainers/registry.py index 5f54403d6..77e7666e9 100644 --- a/plato/trainers/registry.py +++ b/plato/trainers/registry.py @@ -13,6 +13,9 @@ pfedgraph, split_learning, ) +from plato.trainers import ( + lerobot as lerobot_trainer, +) from plato.trainers import ( nanochat as nanochat_trainer, ) @@ -25,6 +28,7 @@ "pfedgraph": pfedgraph.Trainer, "split_learning": split_learning.Trainer, "nanochat": nanochat_trainer.Trainer, + "lerobot": lerobot_trainer.Trainer, } From 
97e34d948248e8a59c0399606103d760af36253d Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 12:32:51 -0500 Subject: [PATCH 07/15] Added runnable LeRobot SmolVLA experiment configs. --- configs/LeRobot/lerobot_datasource_base.toml | 9 +++ .../smolvla_fedavg_two_client_smoke.toml | 70 +++++++++++++++++ configs/LeRobot/smolvla_full_finetune.toml | 77 +++++++++++++++++++ .../LeRobot/smolvla_single_client_smoke.toml | 70 +++++++++++++++++ .../smolvla-lerobot-plato-integration-plan.md | 14 ++++ 5 files changed, 240 insertions(+) create mode 100644 configs/LeRobot/lerobot_datasource_base.toml create mode 100644 configs/LeRobot/smolvla_fedavg_two_client_smoke.toml create mode 100644 configs/LeRobot/smolvla_full_finetune.toml create mode 100644 configs/LeRobot/smolvla_single_client_smoke.toml diff --git a/configs/LeRobot/lerobot_datasource_base.toml b/configs/LeRobot/lerobot_datasource_base.toml new file mode 100644 index 000000000..7228f4848 --- /dev/null +++ b/configs/LeRobot/lerobot_datasource_base.toml @@ -0,0 +1,9 @@ + +# The training and testing dataset +datasource = "LeRobot" + +# IID or non-IID? LeRobot handles deterministic episode partitioning internally. +sampler = "iid" + +# The random seed used for deterministic train/test split and client partitioning. +random_seed = 1 diff --git a/configs/LeRobot/smolvla_fedavg_two_client_smoke.toml b/configs/LeRobot/smolvla_fedavg_two_client_smoke.toml new file mode 100644 index 000000000..45db563b9 --- /dev/null +++ b/configs/LeRobot/smolvla_fedavg_two_client_smoke.toml @@ -0,0 +1,70 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 2 + +# The number of clients selected in each round +per_round = 2 + +# Should the clients compute test accuracy locally? 
+do_test = false + +[server] +address = "127.0.0.1" +port = 8001 +simulate_wall_time = false +checkpoint_path = "checkpoints/lerobot/smolvla_fedavg_two_client_smoke" +model_path = "models/lerobot/smolvla_fedavg_two_client_smoke" + +[data] +include = "lerobot_datasource_base.toml" + +[trainer] + +# The type of the trainer +type = "lerobot" + +# The maximum number of training rounds +rounds = 1 + +# The maximum number of clients running concurrently +max_concurrency = 1 + +# The machine learning model +model_type = "smolvla" +model_name = "smolvla" + +# Number of epochs for local training in each communication round +epochs = 1 +batch_size = 2 +optimizer = "AdamW" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "adapter" +precision = "fp32" +device = "cpu" + +[parameters.dataset] +repo_id = "lerobot/pusht_image" +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] +image_size = [224, 224] +normalize = true +interpolation = "bilinear" + +[parameters.optimizer] +lr = 0.0001 +weight_decay = 0.0 diff --git a/configs/LeRobot/smolvla_full_finetune.toml b/configs/LeRobot/smolvla_full_finetune.toml new file mode 100644 index 000000000..89c83b6f6 --- /dev/null +++ b/configs/LeRobot/smolvla_full_finetune.toml @@ -0,0 +1,77 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 2 + +# The number of clients selected in each round +per_round = 2 + +# Should the clients compute test accuracy locally? 
+do_test = true + +[server] +address = "127.0.0.1" +port = 8002 +simulate_wall_time = false +checkpoint_path = "checkpoints/lerobot/smolvla_full_finetune" +model_path = "models/lerobot/smolvla_full_finetune" + +[data] +include = "lerobot_datasource_base.toml" + +[trainer] + +# The type of the trainer +type = "lerobot" + +# The maximum number of training rounds +rounds = 10 + +# The maximum number of clients running concurrently +max_concurrency = 1 + +# The machine learning model +model_type = "smolvla" +model_name = "smolvla" + +# Number of epochs for local training in each communication round +epochs = 5 + +# SmolVLA upstream fine-tuning guidance uses batch size 64. +batch_size = 64 +optimizer = "AdamW" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "full" +precision = "bf16" +device = "cuda" + +[parameters.dataset] +repo_id = "lerobot/pusht_image" +split_seed = 7 +train_split = 0.9 +task_aware_split = true +task_aware_partition = true +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] +image_size = [224, 224] +center_crop = true +normalize = true +interpolation = "bilinear" + +[parameters.optimizer] +lr = 0.0001 +weight_decay = 0.01 diff --git a/configs/LeRobot/smolvla_single_client_smoke.toml b/configs/LeRobot/smolvla_single_client_smoke.toml new file mode 100644 index 000000000..0c48f00f5 --- /dev/null +++ b/configs/LeRobot/smolvla_single_client_smoke.toml @@ -0,0 +1,70 @@ +[clients] + +# Type +type = "simple" + +# The total number of clients +total_clients = 1 + +# The number of clients selected in each round +per_round = 1 + +# Should the clients compute test accuracy locally? 
+do_test = false + +[server] +address = "127.0.0.1" +port = 8000 +simulate_wall_time = false +checkpoint_path = "checkpoints/lerobot/smolvla_single_client_smoke" +model_path = "models/lerobot/smolvla_single_client_smoke" + +[data] +include = "lerobot_datasource_base.toml" + +[trainer] + +# The type of the trainer +type = "lerobot" + +# The maximum number of training rounds +rounds = 1 + +# The maximum number of clients running concurrently +max_concurrency = 1 + +# The machine learning model +model_type = "smolvla" +model_name = "smolvla" + +# Number of epochs for local training in each communication round +epochs = 1 +batch_size = 2 +optimizer = "AdamW" + +[algorithm] + +# Aggregation algorithm +type = "fedavg" + +[parameters] + +[parameters.policy] +type = "smolvla" +path = "lerobot/smolvla_base" +finetune_mode = "adapter" +precision = "fp32" +device = "cpu" + +[parameters.dataset] +repo_id = "lerobot/pusht_image" +delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } + +[parameters.transforms] +image_size = [224, 224] +normalize = true +interpolation = "bilinear" + +[parameters.optimizer] +lr = 0.0001 +weight_decay = 0.0 diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index 57b1d8ec7..529da90f7 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -167,11 +167,25 @@ depends_on: [T6] ### T9. Add runnable experiment configs depends_on: [T3, T4, T5, T6] +status: completed (2026-02-19) - Add `configs/LeRobot/` config set: - reusable base datasource fragment - minimal smoke config - full fine-tune config aligned to SmolVLA guidance - Ensure includes/overrides follow repository config conventions. +work_log: +- Added `configs/LeRobot/` with a reusable datasource include fragment plus runnable single-client smoke, two-client FedAvg smoke, and a fuller full-fine-tune config. 
+- Aligned all new configs with T4-T6 integration keys: `data.datasource = "LeRobot"`, `trainer.type = "lerobot"`, `trainer.model_type = "smolvla"`, and explicit `[parameters.policy]`, `[parameters.dataset]`, `[parameters.transforms]` sections. +- Mapped SmolVLA fine-tuning guidance into Plato semantics by keeping `policy.path = "lerobot/smolvla_base"`, `policy.finetune_mode = "full"`, `policy.device = "cuda"`, and `batch_size = 64` in the fuller config. +files_touched: +- `configs/LeRobot/lerobot_datasource_base.toml` (created) +- `configs/LeRobot/smolvla_single_client_smoke.toml` (created) +- `configs/LeRobot/smolvla_fedavg_two_client_smoke.toml` (created) +- `configs/LeRobot/smolvla_full_finetune.toml` (created) +- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) +gotchas: +- The datasource include fragment is intentionally sectionless so `[data].include` merges it directly into the `data` table. +- SmolVLA upstream examples are step-based (`lerobot-train --steps`), while Plato scheduling is round/epoch-based, so the fuller config mirrors guidance through batch/device/fine-tune mode and keeps runtime knobs in `trainer.rounds` + `trainer.epochs`. ### T10. Add tests (unit + integration smoke) depends_on: [T7, T8, T9] From ae9c860653f108944d2fdb7e259b5971443640fa Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 12:33:04 -0500 Subject: [PATCH 08/15] Validated LeRobot lifecycle compatibility with existing runtime paths. --- plans/smolvla-lerobot-plato-integration-plan.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index 529da90f7..5ac1cf11b 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -159,11 +159,21 @@ depends_on: [T6] ### T8. 
Validate runtime lifecycle compatibility depends_on: [T6] +status: completed (2026-02-19) - Confirm integration works with existing lifecycle code paths: - client setup strategies - server trainer initialization - training/report/aggregation loop - Avoid special-case branching unless strictly necessary. +work_log: +- Ran a focused runtime smoke with `data.datasource = "LeRobot"`, `trainer.type = "lerobot"`, and `trainer.model_type = "smolvla"` using monkeypatched LeRobot/SmolVLA externals to avoid downloads, then exercised the default `simple.Client` lifecycle (`_load_data` -> `configure` -> `_allocate_data` -> `_train`). +- Verified lifecycle construction path through existing registries and strategy plumbing: datasource (`LeRobot`) + trainer (`lerobot`) + algorithm (`fedavg`) were all instantiated through default client/server setup with no special-case branching. +- Executed a short mocked client/server round-trip by feeding the client-produced payload/report into `fedavg.Server._process_reports()` after `Server.configure()`, confirming server trainer initialization and aggregation/report processing completed successfully. +- No lifecycle compatibility bug was found in this scope, so no runtime code patch was applied. +files_touched: +- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) +gotchas: +- The focused smoke directly called `client._train()`; because `report.processing_time` is normally attached in the payload strategy path, the smoke sets `report.processing_time = 0.0` before invoking server report processing. ### T9. Add runnable experiment configs depends_on: [T3, T4, T5, T6] From 1edacc1e7083b46b1bef6a9f4b7686e36f2856f2 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 12:35:25 -0500 Subject: [PATCH 09/15] Hardened FedAvg payload filtering and dtype-safe state loading. 
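
The dtype-safe casting rule this patch introduces in `Algorithm._cast_tensor_like` can be summarized by the standalone sketch below (`cast_like` is an illustrative stand-in name, not part of the Plato API): tensors destined for bool buffers are thresholded, tensors destined for integer buffers are rounded before the cast, and floating/complex targets get a plain dtype cast.

```python
# Illustrative sketch of the dtype-safe cast; `cast_like` is a stand-in
# name for exposition, not the actual Plato method.
import torch

def cast_like(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Cast `tensor` to `reference`'s dtype without silently corrupting values."""
    if tensor.shape != reference.shape:
        raise ValueError("shape mismatch")
    if tensor.dtype == reference.dtype:
        return tensor.detach()
    if reference.dtype == torch.bool:
        # Threshold floats at 0.5; any nonzero integer becomes True.
        if torch.is_floating_point(tensor):
            return (tensor >= 0.5).detach()
        return tensor.ne(0).detach()
    if torch.is_floating_point(reference) or torch.is_complex(reference):
        return tensor.to(reference.dtype).detach()
    if torch.is_floating_point(tensor):
        # Round instead of truncating toward zero when filling integer buffers.
        return torch.round(tensor).to(reference.dtype).detach()
    return tensor.to(reference.dtype).detach()
```

For example, `cast_like(torch.tensor([0.6, 1.4]), torch.zeros(2, dtype=torch.int64))` yields `tensor([1, 1])`, whereas a bare `.to(torch.int64)` would truncate to `[0, 1]`.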
--- .../smolvla-lerobot-plato-integration-plan.md | 13 + plato/algorithms/fedavg.py | 255 +++++++++++++++++- tests/algorithms/test_fedavg_algorithm.py | 124 +++++++++ 3 files changed, 384 insertions(+), 8 deletions(-) create mode 100644 tests/algorithms/test_fedavg_algorithm.py diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index 5ac1cf11b..dc5992076 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -152,10 +152,23 @@ gotchas: ### T7. Harden federated payload/aggregation behavior depends_on: [T6] +status: completed (2026-02-19) - Ensure only intended trainable tensors are exchanged/aggregated. - Add safeguards for payload size and dtype handling. - Verify checkpoint/state restore consistency across rounds. - Validate no regressions in FedAvg flow with large model weights. +work_log: +- Hardened `plato/algorithms/fedavg.py` to exchange adapter-only tensors when `plato_finetune_mode = "adapter"` and `plato_trainable_parameter_names` are provided, while preserving full-state behavior for existing non-adapter models. +- Added dtype-safe tensor casting and partial payload merge logic in `load_weights()`, plus stricter key/shape validation and delta application safeguards for partial/full state dicts across rounds. +- Added payload-size safeguards with an optional limit (`model.plato_max_payload_size_mb` or `PLATO_FEDAVG_MAX_PAYLOAD_MB`) and fail-fast checks when payloads exceed the configured cap. +- Added targeted regression tests for filtered extract/load round-trip, dtype safety, optional payload-size guard, and full-mode FedAvg round-trip with large weights. +- Ran `uv run ruff check` and focused `uv run pytest` for the new FedAvg algorithm tests. 
+files_touched: +- `plato/algorithms/fedavg.py` (updated) +- `tests/algorithms/test_fedavg_algorithm.py` (created) +- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) +gotchas: +- Payload-size enforcement is intentionally opt-in to maintain backward compatibility for existing workloads that may exchange large full-model state dicts. ### T8. Validate runtime lifecycle compatibility depends_on: [T6] diff --git a/plato/algorithms/fedavg.py b/plato/algorithms/fedavg.py index 605e0b945..6c00bf2ef 100644 --- a/plato/algorithms/fedavg.py +++ b/plato/algorithms/fedavg.py @@ -2,10 +2,14 @@ The federated averaging algorithm for PyTorch. """ +from __future__ import annotations + +import os from collections import OrderedDict -from collections.abc import MutableMapping -from typing import Any, Optional +from collections.abc import Iterable, Mapping, MutableMapping +from typing import Any +import torch from torch.nn import Module from plato.algorithms import base @@ -14,21 +18,209 @@ class Algorithm(base.Algorithm): """PyTorch-based federated averaging algorithm, used by both the client and the server.""" + @staticmethod + def _as_state_mapping(weights: Any, context: str) -> Mapping[str, torch.Tensor]: + """Validate and cast a state-dict-like payload.""" + if not isinstance(weights, Mapping): + raise TypeError(f"{context} must be a mapping of parameter names to tensors.") + return weights + + @staticmethod + def _cast_tensor_like( + tensor: torch.Tensor, reference: torch.Tensor, tensor_name: str + ) -> torch.Tensor: + """Cast an incoming tensor to match the dtype expected by `reference`.""" + if tensor.shape != reference.shape: + raise ValueError( + f"Tensor shape mismatch for '{tensor_name}': " + f"received {tuple(tensor.shape)}, expected {tuple(reference.shape)}." 
+ ) + + if tensor.dtype == reference.dtype: + return tensor.detach() + + if reference.dtype == torch.bool: + if torch.is_floating_point(tensor): + return (tensor >= 0.5).detach() + return tensor.ne(0).detach() + + if torch.is_floating_point(reference) or torch.is_complex(reference): + return tensor.to(reference.dtype).detach() + + if torch.is_floating_point(tensor): + return torch.round(tensor).to(reference.dtype).detach() + + return tensor.to(reference.dtype).detach() + + @staticmethod + def _compute_tensor_delta( + current_weight: torch.Tensor, + baseline_weight: torch.Tensor, + tensor_name: str, + ) -> torch.Tensor: + """Compute a dtype-safe delta tensor for a parameter.""" + current_casted = Algorithm._cast_tensor_like( + current_weight, baseline_weight, tensor_name + ) + + if baseline_weight.dtype == torch.bool: + return current_casted.to(torch.int8) - baseline_weight.to(torch.int8) + + if torch.is_floating_point(baseline_weight) or torch.is_complex(baseline_weight): + return current_casted.to(baseline_weight.dtype) - baseline_weight + + return current_casted.to(torch.int64) - baseline_weight.to(torch.int64) + + @staticmethod + def _apply_tensor_delta( + baseline_weight: torch.Tensor, delta: torch.Tensor, tensor_name: str + ) -> torch.Tensor: + """Apply a delta tensor to the baseline tensor with dtype safeguards.""" + if delta.shape != baseline_weight.shape: + raise ValueError( + f"Delta shape mismatch for '{tensor_name}': " + f"received {tuple(delta.shape)}, expected {tuple(baseline_weight.shape)}." 
+ ) + + if baseline_weight.dtype == torch.bool: + if torch.is_floating_point(delta): + delta_integral = torch.round(delta).to(torch.int8) + else: + delta_integral = delta.to(torch.int8) + return (baseline_weight.to(torch.int8) + delta_integral).ne(0) + + if torch.is_floating_point(baseline_weight) or torch.is_complex(baseline_weight): + return baseline_weight + delta.to(baseline_weight.dtype) + + if torch.is_floating_point(delta): + delta_casted = torch.round(delta).to(baseline_weight.dtype) + else: + delta_casted = delta.to(baseline_weight.dtype) + return baseline_weight + delta_casted + + def _resolve_payload_limit_mb(self) -> float | None: + """Resolve an optional payload-size limit from model attrs or env vars.""" + model = self.require_model() + configured_limit = getattr(model, "plato_max_payload_size_mb", None) + if configured_limit is None: + configured_limit = os.getenv("PLATO_FEDAVG_MAX_PAYLOAD_MB") + + if configured_limit is None: + return None + + try: + limit_mb = float(configured_limit) + except (TypeError, ValueError) as exc: + raise ValueError( + f"Invalid payload size limit: {configured_limit!r}. " + "Expected a positive numeric value in megabytes." + ) from exc + + if limit_mb <= 0: + raise ValueError("Payload size limit must be greater than 0 MB.") + + return limit_mb + + @staticmethod + def _estimate_payload_size_bytes(weights: Mapping[str, torch.Tensor]) -> int: + """Estimate payload size by summing tensor storage bytes.""" + size_bytes = 0 + for name, tensor in weights.items(): + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Payload tensor '{name}' must be a torch.Tensor, " + f"received {type(tensor).__name__}." 
+ ) + size_bytes += tensor.numel() * tensor.element_size() + return size_bytes + + def _assert_payload_size(self, weights: Mapping[str, torch.Tensor], source: str) -> None: + """Enforce an optional payload-size safeguard.""" + limit_mb = self._resolve_payload_limit_mb() + if limit_mb is None: + return + + payload_size_mb = self._estimate_payload_size_bytes(weights) / 1024**2 + if payload_size_mb > limit_mb: + raise ValueError( + f"{source} payload size {payload_size_mb:.2f} MB exceeds " + f"configured limit {limit_mb:.2f} MB." + ) + + @staticmethod + def _resolve_adapter_parameter_names( + target_model: Module, state_dict: Mapping[str, torch.Tensor] + ) -> list[str] | None: + """Resolve parameter names to exchange for adapter-only finetuning.""" + finetune_mode = getattr(target_model, "plato_finetune_mode", None) + if not isinstance(finetune_mode, str) or finetune_mode.strip().lower() != "adapter": + return None + + trainable_names_attr = getattr(target_model, "plato_trainable_parameter_names", None) + names_from_attr = ( + [ + name + for name in trainable_names_attr + if isinstance(name, str) and name in state_dict + ] + if isinstance(trainable_names_attr, Iterable) + else [] + ) + + if names_from_attr: + return names_from_attr + + names_from_requires_grad = [ + name + for name, parameter in target_model.named_parameters() + if parameter.requires_grad and name in state_dict + ] + if names_from_requires_grad: + return names_from_requires_grad + + raise ValueError( + "Adapter finetune mode is enabled, but no trainable parameters " + "were resolved for federated payload exchange." 
+ ) + def compute_weight_deltas( self, baseline_weights: MutableMapping[str, Any], weights_received, ): """Compute the deltas between baseline weights and weights received.""" + baseline_mapping = self._as_state_mapping( + baseline_weights, context="baseline_weights" + ) + # Calculate updates from the received weights deltas = [] for weight in weights_received: + weight_mapping = self._as_state_mapping(weight, context="received weights") + self._assert_payload_size(weight_mapping, source="Received") + + unknown_keys = set(weight_mapping).difference(baseline_mapping) + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise KeyError(f"Received weights include unexpected parameter(s): {unknown}.") + delta = OrderedDict() - for name, current_weight in weight.items(): - baseline = baseline_weights[name] + for name, current_weight in weight_mapping.items(): + if not isinstance(current_weight, torch.Tensor): + raise TypeError( + f"Received tensor '{name}' must be a torch.Tensor, " + f"received {type(current_weight).__name__}." + ) + + baseline = baseline_mapping[name] + if not isinstance(baseline, torch.Tensor): + raise TypeError( + f"Baseline tensor '{name}' must be a torch.Tensor, " + f"received {type(baseline).__name__}." 
+ ) # Calculate update - _delta = current_weight - baseline + _delta = self._compute_tensor_delta(current_weight, baseline, name) delta[name] = _delta deltas.append(delta) @@ -37,10 +229,27 @@ def compute_weight_deltas( def update_weights(self, deltas): """Updates the existing model weights from the provided deltas.""" baseline_weights = self.extract_weights() + delta_mapping = self._as_state_mapping(deltas, context="deltas") updated_weights = OrderedDict() for name, weight in baseline_weights.items(): - updated_weights[name] = weight + deltas[name] + updated_weights[name] = weight + + unknown_keys = set(delta_mapping).difference(baseline_weights) + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise KeyError(f"Delta includes unexpected parameter(s): {unknown}.") + + for name, delta in delta_mapping.items(): + baseline = baseline_weights[name] + if not isinstance(delta, torch.Tensor): + raise TypeError( + f"Delta tensor '{name}' must be a torch.Tensor, " + f"received {type(delta).__name__}." 
+ ) + updated_weights[name] = self._apply_tensor_delta(baseline, delta, name) + + self._assert_payload_size(updated_weights, source="Updated") return updated_weights @@ -51,9 +260,39 @@ def extract_weights(self, model: Module | None = None): target_model = self.require_model() else: target_model = model - return target_model.cpu().state_dict() + + state_dict = target_model.state_dict() + adapter_names = self._resolve_adapter_parameter_names(target_model, state_dict) + keys_to_exchange = adapter_names or list(state_dict.keys()) + + payload = OrderedDict( + (name, state_dict[name].detach().cpu().clone()) for name in keys_to_exchange + ) + self._assert_payload_size(payload, source="Extracted") + return payload def load_weights(self, weights): """Loads the model weights passed in as a parameter.""" + weights_mapping = self._as_state_mapping(weights, context="weights") + self._assert_payload_size(weights_mapping, source="Inbound") + model: Module = self.require_model() - model.load_state_dict(weights, strict=True) + current_state = model.state_dict() + + unknown_keys = set(weights_mapping).difference(current_state) + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise KeyError(f"Inbound weights include unexpected parameter(s): {unknown}.") + + merged_state = OrderedDict(current_state.items()) + for name, incoming_tensor in weights_mapping.items(): + if not isinstance(incoming_tensor, torch.Tensor): + raise TypeError( + f"Inbound tensor '{name}' must be a torch.Tensor, " + f"received {type(incoming_tensor).__name__}." 
+ ) + merged_state[name] = self._cast_tensor_like( + incoming_tensor, current_state[name], name + ) + + model.load_state_dict(merged_state, strict=True) diff --git a/tests/algorithms/test_fedavg_algorithm.py b/tests/algorithms/test_fedavg_algorithm.py new file mode 100644 index 000000000..f6efcf837 --- /dev/null +++ b/tests/algorithms/test_fedavg_algorithm.py @@ -0,0 +1,124 @@ +"""Tests for FedAvg payload filtering and dtype-safe weight handling.""" + +from __future__ import annotations + +from collections import OrderedDict +from types import SimpleNamespace + +import pytest +import torch + +from plato.algorithms.fedavg import Algorithm as FedAvgAlgorithm + + +class AdapterToyModel(torch.nn.Module): + """Toy model exposing adapter-mode metadata used by SmolVLA integration.""" + + def __init__(self) -> None: + super().__init__() + self.backbone = torch.nn.Linear(4, 4) + self.adapter = torch.nn.Linear(4, 4, bias=False) + self.register_buffer("token_count", torch.tensor([7], dtype=torch.int64)) + self.plato_finetune_mode = "adapter" + self.plato_trainable_parameter_names = ("adapter.weight",) + + +class DtypeToyModel(torch.nn.Module): + """Toy model with mixed dtypes for casting safeguards.""" + + def __init__(self) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones((2, 2), dtype=torch.float32)) + self.register_buffer("step", torch.tensor([1], dtype=torch.int64)) + self.register_buffer("flag", torch.tensor([True, False], dtype=torch.bool)) + + +def _algorithm_for(model: torch.nn.Module) -> FedAvgAlgorithm: + trainer = SimpleNamespace(model=model) + return FedAvgAlgorithm(trainer=trainer) + + +def _clone_state_dict(model: torch.nn.Module) -> OrderedDict[str, torch.Tensor]: + return OrderedDict( + (name, tensor.detach().clone()) for name, tensor in model.state_dict().items() + ) + + +def test_adapter_payload_extract_load_round_trip(): + """Adapter mode should exchange only trainable tensors and load them safely.""" + torch.manual_seed(1) + 
source_model = AdapterToyModel() + source_algorithm = _algorithm_for(source_model) + + payload = source_algorithm.extract_weights() + assert list(payload.keys()) == ["adapter.weight"] + + torch.manual_seed(2) + target_model = AdapterToyModel() + before = _clone_state_dict(target_model) + + target_algorithm = _algorithm_for(target_model) + target_algorithm.load_weights(payload) + + after = target_model.state_dict() + assert torch.equal(after["adapter.weight"], payload["adapter.weight"]) + assert torch.equal(after["backbone.weight"], before["backbone.weight"]) + assert torch.equal(after["backbone.bias"], before["backbone.bias"]) + assert torch.equal(after["token_count"], before["token_count"]) + + round_trip_payload = target_algorithm.extract_weights() + assert list(round_trip_payload.keys()) == ["adapter.weight"] + assert torch.equal(round_trip_payload["adapter.weight"], payload["adapter.weight"]) + + +def test_load_weights_casts_dtype_and_rounds_non_float_tensors(): + """Incoming partial payloads should be cast to model dtypes.""" + model = DtypeToyModel() + algorithm = _algorithm_for(model) + + inbound = OrderedDict( + { + "weight": torch.full((2, 2), 2.5, dtype=torch.float64), + "step": torch.tensor([3.6], dtype=torch.float32), + "flag": torch.tensor([0.2, 0.8], dtype=torch.float32), + } + ) + + algorithm.load_weights(inbound) + + state = model.state_dict() + assert state["weight"].dtype == torch.float32 + assert torch.allclose(state["weight"], torch.full_like(state["weight"], 2.5)) + assert state["step"].dtype == torch.int64 + assert int(state["step"].item()) == 4 + assert state["flag"].dtype == torch.bool + assert torch.equal(state["flag"], torch.tensor([False, True])) + + +def test_extract_weights_respects_optional_payload_size_limit(): + """Payload extraction should fail fast if a configured max size is exceeded.""" + model = torch.nn.Linear(32, 32, bias=False) + setattr(model, "plato_max_payload_size_mb", 0.0001) + algorithm = _algorithm_for(model) + + with 
pytest.raises(ValueError, match="payload size"): + algorithm.extract_weights() + + +def test_fedavg_full_mode_round_trip_with_large_weights(): + """Full-state FedAvg flow should remain compatible for non-adapter models.""" + torch.manual_seed(3) + model = torch.nn.Linear(1024, 1024, bias=False) + algorithm = _algorithm_for(model) + + baseline = algorithm.extract_weights() + assert list(baseline.keys()) == ["weight"] + + received = [OrderedDict((name, tensor + 0.25) for name, tensor in baseline.items())] + deltas = algorithm.compute_weight_deltas(baseline, received) + updated = algorithm.update_weights(deltas[0]) + + algorithm.load_weights(updated) + + state = model.state_dict() + assert torch.allclose(state["weight"], received[0]["weight"]) From d488414878e5b41d6e3faaf82409214a597f7fd0 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 12:45:46 -0500 Subject: [PATCH 10/15] Added SmolVLA-LeRobot unit and smoke integration tests. --- .../smolvla-lerobot-plato-integration-plan.md | 18 ++ tests/__init__.py | 1 + tests/datasources/test_lerobot_datasource.py | 100 +++++++++++ .../integration/test_lerobot_smolvla_smoke.py | 135 ++++++++++++++ tests/models/test_smolvla_model.py | 62 +++++++ tests/test_utils/__init__.py | 1 + tests/test_utils/lerobot_stubs.py | 170 ++++++++++++++++++ tests/trainers/test_lerobot_trainer.py | 133 ++++++++++++++ 8 files changed, 620 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/datasources/test_lerobot_datasource.py create mode 100644 tests/integration/test_lerobot_smolvla_smoke.py create mode 100644 tests/models/test_smolvla_model.py create mode 100644 tests/test_utils/__init__.py create mode 100644 tests/test_utils/lerobot_stubs.py create mode 100644 tests/trainers/test_lerobot_trainer.py diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index dc5992076..5eaa3f824 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ 
b/plans/smolvla-lerobot-plato-integration-plan.md @@ -212,11 +212,29 @@ gotchas: ### T10. Add tests (unit + integration smoke) depends_on: [T7, T8, T9] +status: completed (2026-02-19) - Datasource registry + constructor tests for LeRobot datasource. - Model registry + construction tests for SmolVLA wrapper. - Trainer step test with tiny synthetic batch. - End-to-end config smoke test covering startup and one short training run. - Add regression tests for any bug fixes discovered during integration. +work_log: +- Added focused LeRobot datasource tests covering partitioned registry resolution, deterministic constructor split behavior, and mapped `plato_inputs`/`plato_targets` sample keys. +- Added SmolVLA model tests covering registry-based wrapper construction and a FedAvg regression check asserting adapter-mode metadata results in adapter-only payload extraction. +- Added a LeRobot trainer tiny-batch unit test that exercises one short training step with synthetic dict samples, stubbed pre/post processors, and parameter-update assertions. +- Added an end-to-end LeRobot+SmolVLA smoke test that boots from config, runs one short client training pass, and processes a server FedAvg report/update loop with external dependencies fully monkeypatched. +- Fixed local test import shadowing discovered during validation by adding package markers under `tests/` and `tests/test_utils/`. 
+files_touched: +- `tests/__init__.py` (created) +- `tests/test_utils/__init__.py` (created) +- `tests/test_utils/lerobot_stubs.py` (created) +- `tests/datasources/test_lerobot_datasource.py` (created) +- `tests/models/test_smolvla_model.py` (created) +- `tests/trainers/test_lerobot_trainer.py` (created) +- `tests/integration/test_lerobot_smolvla_smoke.py` (created) +- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) +gotchas: +- The local environment includes a third-party `tests` package in site-packages; without `tests/__init__.py`, pytest imports can resolve to the wrong module namespace. ### T11. Add documentation and runbook depends_on: [T10] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..cc83e035e --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Project test package marker for reliable local imports.""" diff --git a/tests/datasources/test_lerobot_datasource.py b/tests/datasources/test_lerobot_datasource.py new file mode 100644 index 000000000..3e76c73d0 --- /dev/null +++ b/tests/datasources/test_lerobot_datasource.py @@ -0,0 +1,100 @@ +"""Tests for LeRobot datasource registry wiring and constructor behavior.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch + +from plato.datasources import lerobot as lerobot_datasource +from plato.datasources import registry as datasources_registry +from tests.test_utils.lerobot_stubs import ( + FakeLeRobotDataset, + FakeLeRobotDatasetMetadata, +) + + +@pytest.fixture +def patched_lerobot_backend(monkeypatch): + """Patch LeRobot imports with deterministic local stubs.""" + FakeLeRobotDataset.reset_calls() + fake_transforms = object() + + monkeypatch.setattr( + lerobot_datasource, + "_import_lerobot", + lambda: (FakeLeRobotDataset, FakeLeRobotDatasetMetadata), + ) + monkeypatch.setattr( + lerobot_datasource, + "_build_image_transforms", + lambda _cfg: fake_transforms, + ) + + return 
SimpleNamespace(fake_transforms=fake_transforms) + + +def test_lerobot_is_registered_as_partitioned_datasource(): + """LeRobot datasource should be wired through the partitioned registry.""" + assert "LeRobot" in datasources_registry.registered_partitioned_datasources + + +def test_registry_constructs_lerobot_datasource_for_client( + temp_config, + patched_lerobot_backend, +): + """Registry should build LeRobot datasource and pass client-aware options.""" + datasource = datasources_registry.get( + datasource_name="LeRobot", + client_id=1, + repo_id="stub/lerobot", + split_seed=7, + train_split=0.5, + delta_timestamps={"observation.image": [-0.1, 0.0]}, + dataset_kwargs={"streaming": True}, + task_aware_split=False, + task_aware_partition=False, + ) + + assert isinstance(datasource, lerobot_datasource.DataSource) + assert datasource.client_id == 1 + assert datasource.repo_id == "stub/lerobot" + assert len(datasource.train_episodes) > 0 + + train_call = FakeLeRobotDataset.constructor_calls[0] + assert train_call["episodes"] == datasource.train_episodes + assert train_call["delta_timestamps"] == {"observation.image": [-0.1, 0.0]} + assert train_call["extra_kwargs"]["streaming"] is True + assert train_call["image_transforms"] is patched_lerobot_backend.fake_transforms + + +def test_lerobot_constructor_is_deterministic_and_maps_samples( + temp_config, + patched_lerobot_backend, +): + """Constructor should produce stable splits and mapped Plato sample keys.""" + kwargs = { + "client_id": 2, + "repo_id": "stub/lerobot", + "split_seed": 11, + "train_split": 0.5, + "task_aware_split": True, + "task_aware_partition": True, + } + + first = lerobot_datasource.DataSource(**kwargs) + second = lerobot_datasource.DataSource(**kwargs) + + assert first.train_episodes == second.train_episodes + assert first.test_episodes == second.test_episodes + assert first.num_train_examples() == len(first.get_train_set()) + assert first.num_test_examples() == len(first.get_test_set()) + + sample 
= first.get_train_set()[0] + assert "plato_inputs" in sample + assert "plato_targets" in sample + assert "plato_metadata" in sample + assert "observation.image" in sample["plato_inputs"] + assert torch.equal(sample["plato_targets"], sample["action"]) + assert sample["plato_metadata"]["episode_index"] in first.train_episodes diff --git a/tests/integration/test_lerobot_smolvla_smoke.py b/tests/integration/test_lerobot_smolvla_smoke.py new file mode 100644 index 000000000..c51370f1c --- /dev/null +++ b/tests/integration/test_lerobot_smolvla_smoke.py @@ -0,0 +1,135 @@ +"""Integration smoke for LeRobot datasource + SmolVLA model + LeRobot trainer.""" + +from __future__ import annotations + +from importlib import import_module +from types import SimpleNamespace + +import pytest + +from plato.config import Config +from tests.integration.utils import ( + async_run, + build_minimal_config, + configure_environment, +) +from tests.test_utils.lerobot_stubs import ( + FakeLeRobotDataset, + FakeLeRobotDatasetMetadata, + FakeSmolVLAPolicy, + identity_pre_post_processors, +) + +pytestmark = pytest.mark.integration + + +def test_lerobot_smolvla_end_to_end_smoke(monkeypatch): + """Smoke test startup + one short local train + server report processing.""" + config = build_minimal_config( + rounds=1, + clients_per_round=1, + total_clients=1, + model_name="smolvla", + trainer_type="lerobot", + ) + config["server"]["do_test"] = False + config["data"] = { + "datasource": "LeRobot", + "partition_size": 2, + "sampler": "iid", + "random_seed": 1, + } + config["trainer"].update( + { + "type": "lerobot", + "model_type": "smolvla", + "model_name": "smolvla", + "epochs": 1, + "batch_size": 2, + "optimizer": "SGD", + } + ) + config["parameters"] = { + "policy": { + "path": "stub/smolvla", + "finetune_mode": "adapter", + "adapter_parameter_patterns": ["adapter"], + }, + "dataset": { + "repo_id": "stub/lerobot", + "split_seed": 4, + "train_split": 0.5, + "task_aware_split": True, + 
"task_aware_partition": True, + }, + "optimizer": { + "lr": 0.05, + "momentum": 0.0, + "weight_decay": 0.0, + }, + } + + with configure_environment(config): + Config.args.id = 1 + FakeLeRobotDataset.reset_calls() + FakeSmolVLAPolicy.reset_calls() + + lerobot_datasource = import_module("plato.datasources.lerobot") + smolvla_model = import_module("plato.models.smolvla") + lerobot_trainer = import_module("plato.trainers.lerobot") + processor_registry = import_module("plato.processors.registry") + client_mod = import_module("plato.clients.simple") + server_mod = import_module("plato.servers.fedavg") + + monkeypatch.setattr( + lerobot_datasource, + "_import_lerobot", + lambda: (FakeLeRobotDataset, FakeLeRobotDatasetMetadata), + ) + monkeypatch.setattr( + lerobot_datasource, + "_build_image_transforms", + lambda _cfg: None, + ) + monkeypatch.setattr( + smolvla_model, + "_import_smolvla_policy", + lambda: FakeSmolVLAPolicy, + ) + monkeypatch.setattr( + lerobot_trainer, + "_import_make_pre_post_processors", + lambda: identity_pre_post_processors, + ) + monkeypatch.setattr( + processor_registry, + "get", + lambda *args, **kwargs: (None, None), + ) + + client = client_mod.Client() + client._load_data() + client.configure() + client._allocate_data() + + report, payload = async_run(client._train()) + report.processing_time = 0.0 + + assert report.num_samples > 0 + assert list(payload.keys()) == ["adapter.weight"] + + server = server_mod.Server() + server.configure() + server.updates = [ + SimpleNamespace( + client_id=1, + report=report, + payload=payload, + ) + ] + server.current_round = 0 + server.context.current_round = 0 + + async_run(server._process_reports()) + + assert server.accuracy >= 0 diff --git a/tests/models/test_smolvla_model.py b/tests/models/test_smolvla_model.py new file mode 100644 index 000000000..682bfc948 --- /dev/null +++ b/tests/models/test_smolvla_model.py @@ -0,0 +1,62 @@ +"""Tests for SmolVLA model registry construction and adapter metadata.""" + 
+from __future__ import annotations + +from types import SimpleNamespace + +from plato.algorithms.fedavg import Algorithm as FedAvgAlgorithm +from plato.models import registry as models_registry +from plato.models import smolvla as smolvla_model +from tests.test_utils.lerobot_stubs import FakeSmolVLAPolicy + + +def test_model_registry_constructs_smolvla_wrapper(temp_config, monkeypatch): + """Model registry should construct SmolVLA via the registered factory.""" + FakeSmolVLAPolicy.reset_calls() + monkeypatch.setattr( + smolvla_model, + "_import_smolvla_policy", + lambda: FakeSmolVLAPolicy, + ) + + model = models_registry.get( + model_name="smolvla", + model_type="smolvla", + model_params={ + "path": "stub/smolvla", + "finetune_mode": "adapter", + "adapter_parameter_patterns": ["adapter"], + "strict": True, + }, + ) + + assert isinstance(model, FakeSmolVLAPolicy) + assert model.plato_policy_path == "stub/smolvla" + assert model.plato_finetune_mode == "adapter" + assert model.plato_trainable_parameter_names == ("adapter.weight",) + assert FakeSmolVLAPolicy.load_calls[-1]["path"] == "stub/smolvla" + + +def test_smolvla_adapter_metadata_filters_fedavg_payload( + temp_config, + monkeypatch, +): + """Regression: adapter mode metadata must drive adapter-only FedAvg payloads.""" + FakeSmolVLAPolicy.reset_calls() + monkeypatch.setattr( + smolvla_model, + "_import_smolvla_policy", + lambda: FakeSmolVLAPolicy, + ) + + model = smolvla_model.Model.get( + policy_path="stub/smolvla", + finetune_mode="adapter", + adapter_parameter_patterns=["adapter"], + ) + + algorithm = FedAvgAlgorithm(trainer=SimpleNamespace(model=model)) + payload = algorithm.extract_weights() + + assert list(payload.keys()) == ["adapter.weight"] + assert "backbone.weight" not in payload diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 000000000..a20f1a914 --- /dev/null +++ b/tests/test_utils/__init__.py @@ -0,0 +1 @@ +"""Helpers and fakes shared across the 
test suite.""" diff --git a/tests/test_utils/lerobot_stubs.py b/tests/test_utils/lerobot_stubs.py new file mode 100644 index 000000000..2272e4b77 --- /dev/null +++ b/tests/test_utils/lerobot_stubs.py @@ -0,0 +1,170 @@ +"""Deterministic LeRobot/SmolVLA stubs for offline integration tests.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import torch +import torch.nn as nn + + +class FakeLeRobotDatasetMetadata: + """Minimal metadata surface consumed by the LeRobot datasource adapter.""" + + def __init__(self, repo_id: str): + self.repo_id = repo_id + self.total_episodes = 6 + self.tasks = ["pick", "place"] + self.episodes = [ + { + "episode_index": episode, + "task_index": episode % 2, + "task": self.tasks[episode % 2], + } + for episode in range(self.total_episodes) + ] + + +class FakeLeRobotDataset: + """Small dict-style dataset that mimics LeRobot sample payloads.""" + + constructor_calls: list[dict[str, Any]] = [] + + def __init__( + self, + repo_id: str, + episodes: list[int] | None = None, + delta_timestamps: dict[str, list[float]] | None = None, + image_transforms: Any = None, + **kwargs: Any, + ): + self.repo_id = repo_id + self.episodes = [int(episode) for episode in (episodes or [])] + self.delta_timestamps = delta_timestamps + self.image_transforms = image_transforms + self.extra_kwargs = dict(kwargs) + self.meta = SimpleNamespace( + stats={ + "action": { + "mean": [0.0, 0.0], + "std": [1.0, 1.0], + } + } + ) + + self.samples: list[dict[str, Any]] = [] + for step, episode in enumerate(self.episodes): + observation = torch.tensor( + [float(episode), float(step)], + dtype=torch.float32, + ) + action = torch.tensor( + [float(episode) + 0.25, float(step) + 0.5], + dtype=torch.float32, + ) + if callable(image_transforms): + observation = image_transforms(observation) + + self.samples.append( + { + "observation.image": observation, + "action": action, + "episode_index": episode, + "step_index": step, + 
"task": "pick" if episode % 2 == 0 else "place", + } + ) + + self.targets = [sample["task"] for sample in self.samples] + self.classes = ("pick", "place") + + type(self).constructor_calls.append( + { + "repo_id": repo_id, + "episodes": list(self.episodes), + "delta_timestamps": delta_timestamps, + "image_transforms": image_transforms, + "extra_kwargs": dict(kwargs), + } + ) + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> dict[str, Any]: + return self.samples[index] + + @classmethod + def reset_calls(cls) -> None: + """Clear constructor call history for test isolation.""" + cls.constructor_calls = [] + + +class FakeSmolVLAPolicy(nn.Module): + """Tiny policy compatible with SmolVLA wrapper expectations.""" + + load_calls: list[dict[str, Any]] = [] + + def __init__(self): + super().__init__() + self.backbone = nn.Linear(2, 2, bias=False) + self.adapter = nn.Linear(2, 2, bias=False) + self.config = {"policy": "fake-smolvla"} + + with torch.no_grad(): + self.backbone.weight.copy_(torch.eye(2)) + self.adapter.weight.copy_(0.5 * torch.eye(2)) + + @classmethod + def from_pretrained( + cls, + policy_path: str, + token: str | None = None, + strict: bool = True, + ) -> "FakeSmolVLAPolicy": + """Match the LeRobot loader interface used by the wrapper.""" + cls.load_calls.append( + { + "path": policy_path, + "token": token, + "strict": strict, + } + ) + return cls() + + def forward(self, batch: dict[str, Any], reduction: str = "mean"): + """Return tensor loss + dict payload like SmolVLA policies do.""" + action = batch["action"].float() + prediction = self.adapter(self.backbone(action)) + per_sample = (prediction - action) ** 2 + + if reduction == "sum": + loss = per_sample.sum() + else: + loss = per_sample.mean() + + return loss, {"mse": loss.detach()} + + def save_pretrained(self, *args: Any, **kwargs: Any) -> None: + """Compatibility no-op for checkpoint contract checks.""" + + @classmethod + def reset_calls(cls) -> None: + 
"""Clear loader call history for test isolation.""" + cls.load_calls = [] + + +def identity_pre_post_processors(policy_config: Any, **kwargs: Any): + """Return identity processors while recording constructor args.""" + + def preprocessor(batch): + return batch + + def postprocessor(outputs): + return outputs + + setattr(preprocessor, "policy_config", policy_config) + setattr(preprocessor, "kwargs", kwargs) + + return preprocessor, postprocessor diff --git a/tests/trainers/test_lerobot_trainer.py b/tests/trainers/test_lerobot_trainer.py new file mode 100644 index 000000000..1851e43b2 --- /dev/null +++ b/tests/trainers/test_lerobot_trainer.py @@ -0,0 +1,133 @@ +"""Tests for LeRobot trainer training-step behavior with synthetic data.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import torch +import torch.nn as nn + +from plato.config import Config +from plato.trainers import lerobot as lerobot_trainer + + +class _SyntheticLeRobotDataset(torch.utils.data.Dataset): + """Tiny deterministic dataset exposing LeRobot-like dict samples.""" + + def __init__(self): + self.meta = SimpleNamespace(stats={"action": {"mean": [0.0], "std": [1.0]}}) + self.samples = [ + { + "observation.image": torch.tensor([0.0, 1.0], dtype=torch.float32), + "action": torch.tensor([0.0, 2.0], dtype=torch.float32), + "episode_index": 0, + }, + { + "observation.image": torch.tensor([1.0, 2.0], dtype=torch.float32), + "action": torch.tensor([1.0, 3.0], dtype=torch.float32), + "episode_index": 0, + }, + { + "observation.image": torch.tensor([2.0, 3.0], dtype=torch.float32), + "action": torch.tensor([2.0, 4.0], dtype=torch.float32), + "episode_index": 1, + }, + { + "observation.image": torch.tensor([3.0, 4.0], dtype=torch.float32), + "action": torch.tensor([3.0, 5.0], dtype=torch.float32), + "episode_index": 1, + }, + ] + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, index: int) -> dict[str, Any]: + return 
self.samples[index] + + +class _TinyLeRobotPolicy(nn.Module): + """Tiny policy returning tuple output variants consumed by LeRobot trainer.""" + + def __init__(self): + super().__init__() + self.adapter_scale = nn.Parameter(torch.tensor(1.0)) + self.config = {"policy": "tiny"} + self.plato_policy_path = "stub/smolvla" + + def forward(self, batch: dict[str, Any], reduction: str = "mean"): + action = batch["action"].float() + prediction = self.adapter_scale * torch.ones_like(action) + loss = torch.mean((prediction - action) ** 2) + # Return mapping + tensor tuple to exercise normalization branch. + return {"loss_component": loss.detach()}, loss + + +def test_lerobot_trainer_train_model_runs_on_tiny_synthetic_batch( + temp_config, + monkeypatch, +): + """Trainer should complete one short run and update model parameters.""" + config = Config() + config.trainer = config.trainer._replace( + type="lerobot", + model_type="smolvla", + model_name="smolvla_unit", + batch_size=2, + epochs=1, + optimizer="SGD", + ) + config.parameters = Config.node_from_dict( + { + "optimizer": { + "lr": 0.1, + "momentum": 0.0, + "weight_decay": 0.0, + }, + "policy": {"path": "stub/smolvla"}, + } + ) + + factory_calls: dict[str, Any] = {} + + def _fake_pre_post_factory(policy_config, **kwargs): + factory_calls["policy_config"] = policy_config + factory_calls["kwargs"] = kwargs + + def _pre(batch): + return batch + + def _post(outputs): + return outputs + + return _pre, _post + + monkeypatch.setattr( + lerobot_trainer, + "_import_make_pre_post_processors", + lambda: _fake_pre_post_factory, + ) + + model = _TinyLeRobotPolicy() + trainer = lerobot_trainer.Trainer(model=model) + trainset = _SyntheticLeRobotDataset() + + start_value = float(model.adapter_scale.detach().item()) + trainer.train_model( + { + "batch_size": 2, + "epochs": 1, + "run_id": "lerobot-unit", + }, + trainset, + sampler=list(range(len(trainset))), + ) + end_value = float(model.adapter_scale.detach().item()) + + assert end_value 
!= start_value + assert callable(trainer.context.state["lerobot_preprocessor"]) + assert trainer.context.state["lerobot_loss_dict"]["loss_component"] >= 0.0 + assert trainer.run_history.get_metric_values("train_loss") + assert factory_calls["kwargs"]["pretrained_path"] == "stub/smolvla" + assert factory_calls["kwargs"]["dataset_stats"] == trainset.meta.stats From ec1d7d76fdfdaaca0b50012b97915ebc42bd8fe3 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 12:50:15 -0500 Subject: [PATCH 11/15] Added SmolVLA LeRobot runbook and completed T11 documentation updates. --- docs/docs/install.md | 4 +- docs/docs/smolvla_lerobot_runbook.md | 200 ++++++++++++++++++ docs/mkdocs.yml | 2 + .../smolvla-lerobot-plato-integration-plan.md | 13 ++ 4 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 docs/docs/smolvla_lerobot_runbook.md diff --git a/docs/docs/install.md b/docs/docs/install.md index 0b12aa823..c2c797305 100644 --- a/docs/docs/install.md +++ b/docs/docs/install.md @@ -75,7 +75,9 @@ uv sync --extra robotics ``` See [SmolVLA + LeRobot Optional Setup](smolvla_lerobot_setup.md) for runtime -requirements and guarded-import guidance. +requirements and guarded-import guidance, and +[SmolVLA + LeRobot Runbook](smolvla_lerobot_runbook.md) for setup-to-run +operational steps and troubleshooting. ### Building the `plato-learn` PyPi Package diff --git a/docs/docs/smolvla_lerobot_runbook.md b/docs/docs/smolvla_lerobot_runbook.md new file mode 100644 index 000000000..31a79fce5 --- /dev/null +++ b/docs/docs/smolvla_lerobot_runbook.md @@ -0,0 +1,200 @@ +# SmolVLA + LeRobot Runbook + +This runbook is for operators running SmolVLA training in Plato with LeRobot datasets. +It complements the setup notes in [SmolVLA + LeRobot Optional Setup](smolvla_lerobot_setup.md) +and the parameter contract in [Configuration Parameters](configurations/parameters.md). + +## 1) Setup + +1. Install core dependencies: + +```bash +uv sync +``` + +2. 
Install robotics extras:
+
+```bash
+uv sync --extra robotics
+```
+
+3. Authenticate to Hugging Face when using private repos:
+
+```bash
+huggingface-cli login
+```
+
+4. Verify optional stack import:
+
+```bash
+uv run python -c "import lerobot; print(lerobot.__version__)"
+```
+
+5. If your dataset is video-backed, ensure `ffmpeg` is installed on the host.
+
+## 2) Config Profiles to Start From
+
+Use the configs added under `configs/LeRobot/`:
+
+- `configs/LeRobot/lerobot_datasource_base.toml`: shared LeRobot datasource include.
+- `configs/LeRobot/smolvla_single_client_smoke.toml`: minimal single-client smoke run.
+- `configs/LeRobot/smolvla_fedavg_two_client_smoke.toml`: 2-client FedAvg smoke run.
+- `configs/LeRobot/smolvla_full_finetune.toml`: longer full fine-tune profile.
+
+## 3) Run Commands
+
+Single-client smoke:
+
+```bash
+uv run python plato.py --config configs/LeRobot/smolvla_single_client_smoke.toml
+```
+
+Two-client federated smoke:
+
+```bash
+uv run python plato.py --config configs/LeRobot/smolvla_fedavg_two_client_smoke.toml
+```
+
+Full-finetune profile:
+
+```bash
+uv run python plato.py --config configs/LeRobot/smolvla_full_finetune.toml
+```
+
+## 4) Required Plato TOML Fields
+
+Minimum contract for this integration:
+
+```toml
+[data]
+datasource = "LeRobot"
+
+[trainer]
+type = "lerobot"
+model_type = "smolvla"
+model_name = "smolvla"
+
+[parameters.policy]
+type = "smolvla"
+path = "lerobot/smolvla_base"
+finetune_mode = "adapter" # or "full"
+precision = "fp32"
+device = "cpu" # or "cuda" / "mps"
+
+[parameters.dataset]
+repo_id = "lerobot/pusht_image"
+delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] }
+
+[parameters.transforms]
+image_size = [224, 224]
+normalize = true
+interpolation = "bilinear"
+```
+
+## 5) Plato ↔ `lerobot-train` Mapping
+
+| Plato config field(s) | `lerobot-train` equivalent | Type |
+| --- | --- | --- |
+| `parameters.policy.path` | `--policy.path` | Direct |
+| `parameters.dataset.repo_id` | `--dataset.repo_id` | Direct |
+| `trainer.batch_size` | `--batch_size` | Direct |
+| `parameters.policy.device` | `--policy.device` | Direct |
+| `trainer.rounds` + `trainer.epochs` | `--steps` | Conceptual scheduling mapping |
+| `server.checkpoint_path` / `server.model_path` | `--output_dir` | Conceptual output-location mapping |
+| `parameters.dataset.delta_timestamps` | LeRobot dataset `delta_timestamps` usage during training | Conceptual data-window mapping |
+| `parameters.policy.finetune_mode` (`full`/`adapter`) | Trainable-parameter strategy during policy training | Conceptual finetune-mode mapping |
+
+Notes:
+
+- Upstream LeRobot examples for SmolVLA commonly use `--steps`; Plato uses round/epoch scheduling.
+- Adapter-mode behavior in Plato is implemented via `parameters.policy.finetune_mode` and
+  adapter parameter selection in the SmolVLA model wrapper.
+
+## 6) Troubleshooting
+
+### Missing optional robotics dependencies
+
+Symptom:
+
+- `ImportError: ... Install them with "uv sync --extra robotics" ...`
+
+Action:
+
+```bash
+uv sync --extra robotics
+```
+
+### Dataset repo not configured
+
+Symptom:
+
+- `LeRobot datasource requires "parameters.dataset.repo_id" to be set.`
+
+Action:
+
+- Set `parameters.dataset.repo_id` in your TOML.
+
+### Private dataset/model access failure
+
+Symptoms:
+
+- SmolVLA load failure from `parameters.policy.path`.
+- Dataset access/auth failures from the Hub.
+
+Actions:
+
+```bash
+huggingface-cli login
+# optionally for non-interactive runs
+export HF_TOKEN=
+```
+
+### No episodes found
+
+Symptom:
+
+- `No episodes found for LeRobot dataset ""`
+
+Actions:
+
+- Verify `parameters.dataset.repo_id` exists and contains episodes.
+- Confirm access permissions for the dataset repository.
+
+### Invalid `delta_timestamps` shape
+
+Symptom:
+
+- `"parameters.dataset.delta_timestamps" must be a mapping of key -> list[float].`
+
+Action:
+
+- Use mapping syntax, for example:
+
+```toml
+[parameters.dataset]
+delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] }
+```
+
+### Device/precision mismatch or OOM
+
+Symptoms:
+
+- CUDA/MPS initialization failures.
+- Out-of-memory during training.
+
+Actions:
+
+- Start from `configs/LeRobot/smolvla_single_client_smoke.toml` (CPU, tiny batch).
+- Reduce `trainer.batch_size`.
+- Use `parameters.policy.device = "cpu"` for smoke checks.
+- Move to `cuda` + higher batch sizes only after smoke passes.
+
+### FFmpeg / build issues in robotics stack
+
+Symptom:
+
+- Build/runtime errors mentioning FFmpeg or PyAV dependencies.
+
+Actions:
+
+- Install host FFmpeg libraries and build toolchain (`cmake`, `build-essential`, FFmpeg libs), then reinstall robotics extras.
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index c5e428749..a31e28f1e 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -48,6 +48,8 @@ extra_css:
 nav:
   - Home: index.md
   - Installation: install.md
+  - SmolVLA + LeRobot Optional Setup: smolvla_lerobot_setup.md
+  - SmolVLA + LeRobot Runbook: smolvla_lerobot_runbook.md
   - Quick Start: quickstart.md
   - Examples:
     - Getting Started: examples/Getting Started.md
diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md
index 5eaa3f824..846e4d249 100644
--- a/plans/smolvla-lerobot-plato-integration-plan.md
+++ b/plans/smolvla-lerobot-plato-integration-plan.md
@@ -238,10 +238,23 @@ gotchas:
 
 ### T11. Add documentation and runbook
 
 depends_on: [T10]
+status: completed (2026-02-19)
 
 - Document setup and dependency extras.
 - Document config fields and examples.
 - Add troubleshooting notes (dataset access, device setup, common failures).
 - Add mapping between Plato config and equivalent `lerobot-train` concepts.
+work_log: +- Added an operator-facing runbook covering dependency setup, runnable commands, minimum TOML contract, and troubleshooting for common LeRobot/SmolVLA failures. +- Added an explicit Plato TOML to `lerobot-train` mapping table with direct flag mappings (`policy.path`, `dataset.repo_id`, `batch_size`, `policy.device`) and conceptual mappings (`rounds`/`epochs` vs `steps`, output paths). +- Referenced all new `configs/LeRobot/*` profiles directly in the runbook and linked the runbook from installation docs and top-level docs navigation. +- Grounded mapping/troubleshooting notes against current LeRobot documentation via Context7 and implementation-specific runtime errors from Plato's LeRobot datasource/trainer/model integration. +files_touched: +- `docs/docs/smolvla_lerobot_runbook.md` (created) +- `docs/docs/install.md` (updated) +- `docs/mkdocs.yml` (updated) +- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) +gotchas: +- `lerobot-train` examples are primarily step-based (`--steps`) while Plato scheduling is round/epoch-based; documentation uses explicit conceptual mapping instead of implying a one-to-one flag conversion. ### T12. Stage validation and rollout gate depends_on: [T11] From 9b69889f151340362b2be712f754bacea96b35ab Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 13:06:35 -0500 Subject: [PATCH 12/15] Completed T12 staged validation and rollout gate. --- .../smolvla-lerobot-plato-integration-plan.md | 15 ++ plans/smolvla-lerobot-validation-report.md | 175 ++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100644 plans/smolvla-lerobot-validation-report.md diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md index 846e4d249..623bb1b53 100644 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ b/plans/smolvla-lerobot-plato-integration-plan.md @@ -258,12 +258,27 @@ gotchas: ### T12. 
Stage validation and rollout gate
 
 depends_on: [T11]
+status: completed (2026-02-19)
 
 - Execute staged validation:
   - single-client local run
   - 2-client federated smoke run
   - larger run for convergence and stability check
 - Compare behavior/runtime against expected baseline.
 - Define go/no-go criteria and recommended defaults for first public release.
+work_log:
+- Captured validation window and environment (`2026-02-19 12:51-13:05 EST`, `uv 0.9.18`, Python `3.13.11`, `lerobot` + `torch` importable, `torch.cuda.is_available() == False`).
+- Ran focused preflight baseline: `uv run pytest -q tests/test_config_loader.py::test_config_loads_smolvla_lerobot_parameter_contract tests/datasources/test_lerobot_datasource.py tests/models/test_smolvla_model.py tests/trainers/test_lerobot_trainer.py tests/integration/test_lerobot_smolvla_smoke.py tests/algorithms/test_fedavg_algorithm.py` -> `12 passed`.
+- Ran staged real-config commands with bounded runtime:
+  - `timeout 300 uv run python plato.py --config configs/LeRobot/smolvla_single_client_smoke.toml` -> fail (exit `124`, hit `TypeError: Got unsupported ScalarType BFloat16` during round-1 model payload serialization).
+  - `timeout 240 uv run python plato.py --config configs/LeRobot/smolvla_fedavg_two_client_smoke.toml` -> fail (exit `124`, same BFloat16 serialization failure).
+  - `timeout 120 uv run python plato.py --config configs/LeRobot/smolvla_full_finetune.toml -u` -> fail (exit `124`, same BFloat16 serialization failure before convergence phase).
+- Verified generated runtime CSVs (`runtime/results/94032.csv`, `runtime/results/94157.csv`, `runtime/results/94326.csv`) contain headers only and no completed round rows.
+- Recorded gate decision and release defaults in `plans/smolvla-lerobot-validation-report.md`: current gate is `NO-GO` until bfloat16 payload serialization is fixed.
+files_touched:
+- `plans/smolvla-lerobot-validation-report.md` (created)
+- `plans/smolvla-lerobot-plato-integration-plan.md` (updated)
+gotchas:
+- Staged runs can download/load SmolVLA and initialize LeRobot datasets, but federated round dispatch currently blocks on safetensor/tree serialization of bfloat16 tensors (`Got unsupported ScalarType BFloat16`), resulting in stalled runs that require timeout fencing.
 
 ## Milestones
 
diff --git a/plans/smolvla-lerobot-validation-report.md b/plans/smolvla-lerobot-validation-report.md
new file mode 100644
index 000000000..472f13571
--- /dev/null
+++ b/plans/smolvla-lerobot-validation-report.md
@@ -0,0 +1,175 @@
+# SmolVLA + LeRobot Validation Report (T12)
+
+Date: 2026-02-19
+Validation window: 2026-02-19 12:51:36 EST to 13:05:19 EST (UTC: 17:51:36 to 18:05:19)
+
+## 1) Environment context
+
+- Repo: `/Users/bli/Playground/plato`
+- `uv`: `0.9.18`
+- Python: `3.13.11`
+- Dependency probes:
+  - `import lerobot` -> available
+  - `import torch` -> available
+  - `torch.cuda.is_available()` -> `False`
+
+## 2) Commands executed and concrete outcomes
+
+### A. Baseline preflight (lightweight, offline-safe)
+
+Command:
+
+```bash
+/usr/bin/time -p uv run pytest -q \
+  tests/test_config_loader.py::test_config_loads_smolvla_lerobot_parameter_contract \
+  tests/datasources/test_lerobot_datasource.py \
+  tests/models/test_smolvla_model.py \
+  tests/trainers/test_lerobot_trainer.py \
+  tests/integration/test_lerobot_smolvla_smoke.py \
+  tests/algorithms/test_fedavg_algorithm.py
+```
+
+Outcome:
+
+- Pass: `12 passed in 0.08s`
+- Wall clock (`time -p`): `real 4.92`, `user 4.42`, `sys 0.50`
+
+Interpretation:
+
+- Local unit/integration coverage for LeRobot datasource, SmolVLA model wrapper, trainer, and FedAvg adapter behavior is healthy.
+
+### B. Stage 1: Single-client local run
+
+Command:
+
+```bash
+/usr/bin/time -p timeout 300 uv run python plato.py \
+  --config configs/LeRobot/smolvla_single_client_smoke.toml
+```
+
+Observed key runtime behavior:
+
+- Server and client initialized.
+- LeRobot datasource loaded (`train episodes=165, test episodes=41`).
+- Failure during round-1 model dispatch:
+  - `TypeError: Got unsupported ScalarType BFloat16`
+  - stack path includes `plato/processors/safetensor_encode.py` -> `plato/serialization/safetensor.py` -> `plato/utils/tree.py`.
+
+Exit/timing:
+
+- Fail: exit code `124` (timeout)
+- `real 300.96`, `user 20.08`, `sys 15.71`
+
+### C. Stage 2: 2-client federated smoke run
+
+Command:
+
+```bash
+/usr/bin/time -p timeout 240 uv run python plato.py \
+  --config configs/LeRobot/smolvla_fedavg_two_client_smoke.toml
+```
+
+Observed key runtime behavior:
+
+- Server started with 2 clients configured.
+- First client connected; round started.
+- Failure on first payload send with the same exception:
+  - `TypeError: Got unsupported ScalarType BFloat16`
+
+Exit/timing:
+
+- Fail: exit code `124` (timeout)
+- `real 240.91`, `user 12.78`, `sys 3.63`
+
+### D. Stage 3: Larger run (convergence/stability gate proxy)
+
+Command:
+
+```bash
+/usr/bin/time -p timeout 120 uv run python plato.py \
+  --config configs/LeRobot/smolvla_full_finetune.toml -u
+```
+
+Notes:
+
+- `-u` is used because this environment reports `torch.cuda.is_available() == False`.
+
+Observed key runtime behavior:
+
+- Server initialized (`Training: 10 rounds`).
+- Datasource initialized (`train episodes=185, test episodes=21`).
+- Failure at round-1 dispatch with the same exception:
+  - `TypeError: Got unsupported ScalarType BFloat16`
+
+Exit/timing:
+
+- Fail: exit code `124` (timeout)
+- `real 120.99`, `user 13.03`, `sys 4.22`
+
+### E. Runtime artifact check
+
+Command:
+
+```bash
+ls runtime/results | rg "^(94032|94157|94326)\\.csv$"
+```
+
+Observed files:
+
+- `runtime/results/94032.csv`
+- `runtime/results/94157.csv`
+- `runtime/results/94326.csv`
+
+Content check:
+
+- Each file contains header only: `round,accuracy,elapsed_time`
+- No completed round rows were recorded.
+
+## 3) Baseline comparison
+
+Expected baseline for staged gate:
+
+- Single-client smoke: complete 1/1 round and exit without unhandled exception.
+- Two-client smoke: complete 1/1 federated round with both clients selected and aggregated.
+- Larger run: progress beyond round 1 to provide initial convergence/stability signal.
+
+Actual:
+
+- All three runs failed before completing round 1 due to the same bfloat16 serialization issue.
+- Therefore runtime behavior is below baseline for release readiness.
+
+## 4) What could not be fully validated and why
+
+- Full convergence/stability behavior (multi-round trend) could not be validated because execution stopped before any completed round.
+- End-to-end federated completion for the 2-client path could not be validated for the same reason.
+- CUDA path in `smolvla_full_finetune.toml` could not be validated in this environment (`torch.cuda.is_available() == False`).
+
+## 5) Go/No-Go rollout gate
+
+Decision: **NO-GO** for first public release in current state.
+
+Blocking condition:
+
+- Federated payload serialization does not currently handle bfloat16 tensors emitted by SmolVLA/LeRobot policy state, causing unhandled exceptions and hung runs until timeout.
+
+Suggested release gate criteria to re-run after fix:
+
+1. No unhandled exception across all three staged commands.
+2. Single-client smoke completes within timeout and writes >=1 runtime CSV data row.
+3. Two-client smoke completes round 1 with aggregation and writes >=1 runtime CSV data row.
+4. Larger profile completes at least 3 rounds in the same environment class used for release qualification.
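The blocking condition can be reproduced in isolation, without a federated run. The following sketch assumes only that `torch` is importable; it triggers the same error class the staged runs hit and shows the kind of fp32 transport cast that would sidestep it:

```python
import torch

weights = torch.ones(2, 2, dtype=torch.bfloat16)

# numpy has no bfloat16 dtype, so tensor -> numpy conversion (as used on the
# serialization path) fails for bf16 tensors.
try:
    weights.numpy()
except TypeError as exc:
    print(exc)  # Got unsupported ScalarType BFloat16

# Casting to fp32 before transport avoids the error; the original dtype can
# be restored after the payload is loaded on the receiving side.
wire = weights.detach().cpu().to(torch.float32)
assert wire.numpy().dtype.name == "float32"
assert wire.to(torch.bfloat16).dtype == torch.bfloat16
```

This mirrors the failure signature recorded above; the cast/restore pair is one possible remediation shape, not the committed fix.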
+ +## 6) Recommended default settings for first public release + +These are recommended defaults **after** the blocking serialization issue is fixed: + +- Entry profile: `configs/LeRobot/smolvla_single_client_smoke.toml` +- `parameters.policy.finetune_mode = "adapter"` +- `parameters.policy.precision = "fp32"` +- `parameters.policy.device = "cpu"` for first-run smoke docs, then move to accelerator. +- `trainer.rounds = 1`, `trainer.epochs = 1`, `trainer.batch_size = 2` as onboarding default. +- Federated smoke default: `configs/LeRobot/smolvla_fedavg_two_client_smoke.toml` as second gate after single-client pass. + +Operational note: + +- Keep explicit timeout wrappers in CI/staging commands to avoid indefinite hangs when async server/client exceptions occur. From 63f6222a55d821c58149281a2d59faf6807b1041 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 13:49:38 -0500 Subject: [PATCH 13/15] Fixed all 'ty check' diagnostics and revalidated. --- .../fedadp/fedadp_algorithm.py | 4 +- .../server_aggregation/moon/moon_algorithm.py | 4 +- plato/algorithms/fedavg.py | 28 ++- plato/samplers/modality_iid.py | 2 +- plato/samplers/modality_quantity_noniid.py | 2 +- plato/trainers/lerobot.py | 206 +++++++++++++++++- plato/utils/rl_env.py | 17 +- tests/algorithms/test_fedavg_algorithm.py | 34 ++- tests/datasources/test_lerobot_datasource.py | 27 ++- tests/models/test_smolvla_model.py | 4 +- tests/trainers/test_lerobot_trainer.py | 78 +++++++ 11 files changed, 375 insertions(+), 31 deletions(-) diff --git a/examples/server_aggregation/fedadp/fedadp_algorithm.py b/examples/server_aggregation/fedadp/fedadp_algorithm.py index 8f841ab23..0956ddbfb 100644 --- a/examples/server_aggregation/fedadp/fedadp_algorithm.py +++ b/examples/server_aggregation/fedadp/fedadp_algorithm.py @@ -174,7 +174,9 @@ def _to_float_tensor(tensor: torch.Tensor) -> torch.Tensor: @staticmethod def _cast_tensor_like( - tensor: torch.Tensor, reference: torch.Tensor + tensor: torch.Tensor, + 
reference: torch.Tensor, + tensor_name: str = "tensor", ) -> torch.Tensor: """Cast a tensor to match the dtype of a reference tensor.""" if tensor.dtype == reference.dtype: diff --git a/examples/server_aggregation/moon/moon_algorithm.py b/examples/server_aggregation/moon/moon_algorithm.py index cd6ded9b7..87658644a 100644 --- a/examples/server_aggregation/moon/moon_algorithm.py +++ b/examples/server_aggregation/moon/moon_algorithm.py @@ -21,7 +21,9 @@ class Algorithm(fedavg.Algorithm): @staticmethod def _cast_tensor_like( - tensor: torch.Tensor, reference: torch.Tensor + tensor: torch.Tensor, + reference: torch.Tensor, + tensor_name: str = "tensor", ) -> torch.Tensor: """Cast a tensor to match a reference dtype (handles bool/int safely).""" if tensor.dtype == reference.dtype: diff --git a/plato/algorithms/fedavg.py b/plato/algorithms/fedavg.py index 6c00bf2ef..e1e0b3d97 100644 --- a/plato/algorithms/fedavg.py +++ b/plato/algorithms/fedavg.py @@ -25,6 +25,28 @@ def _as_state_mapping(weights: Any, context: str) -> Mapping[str, torch.Tensor]: raise TypeError(f"{context} must be a mapping of parameter names to tensors.") return weights + @staticmethod + def _to_transport_tensor( + tensor: torch.Tensor, tensor_name: str + ) -> torch.Tensor: + """ + Convert a tensor to a wire-safe representation for payload transport. + + Safetensor serialization in the current runtime path does not support + `torch.bfloat16` conversion through numpy. Cast bf16 payload tensors to + fp32 for transport, then cast back in `load_weights`. + """ + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Payload tensor '{tensor_name}' must be a torch.Tensor, " + f"received {type(tensor).__name__}." 
+ ) + + prepared = tensor.detach().cpu().clone() + if prepared.dtype == torch.bfloat16: + return prepared.to(torch.float32) + return prepared + @staticmethod def _cast_tensor_like( tensor: torch.Tensor, reference: torch.Tensor, tensor_name: str @@ -266,7 +288,11 @@ def extract_weights(self, model: Module | None = None): keys_to_exchange = adapter_names or list(state_dict.keys()) payload = OrderedDict( - (name, state_dict[name].detach().cpu().clone()) for name in keys_to_exchange + ( + name, + self._to_transport_tensor(state_dict[name], name), + ) + for name in keys_to_exchange ) self._assert_payload_size(payload, source="Extracted") return payload diff --git a/plato/samplers/modality_iid.py b/plato/samplers/modality_iid.py index 66ed03c85..351f11862 100644 --- a/plato/samplers/modality_iid.py +++ b/plato/samplers/modality_iid.py @@ -16,7 +16,7 @@ class Sampler(base.Sampler): """Create a data sampler for each client to use a randomly divided partition of the dataset.""" - def __init__(self, datasource, client_id): + def __init__(self, datasource, client_id, testing=False): super().__init__() self.client_id = client_id diff --git a/plato/samplers/modality_quantity_noniid.py b/plato/samplers/modality_quantity_noniid.py index 606762d77..f17113ba7 100644 --- a/plato/samplers/modality_quantity_noniid.py +++ b/plato/samplers/modality_quantity_noniid.py @@ -17,7 +17,7 @@ class Sampler(base.Sampler): """Create a data sampler for each client to use a randomly divided partition of the dataset.""" - def __init__(self, datasource, client_id): + def __init__(self, datasource, client_id, testing=False): super().__init__() self.client_id = client_id if hasattr(datasource, "get_modality_name"): diff --git a/plato/trainers/lerobot.py b/plato/trainers/lerobot.py index e48caafdc..e66d4ae9e 100644 --- a/plato/trainers/lerobot.py +++ b/plato/trainers/lerobot.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import logging from collections.abc import Callable, 
Iterable, Mapping from typing import Any, cast @@ -21,6 +22,7 @@ ) _RESERVED_KEYS = frozenset({"plato_inputs", "plato_targets", "plato_metadata"}) +_SUPPORTED_POLICY_PRECISIONS = frozenset({"fp32", "fp16", "bf16"}) def _config_node_to_dict(node: Any) -> dict[str, Any]: @@ -284,6 +286,151 @@ def _resolve_sampler_for_loader(sampler: Any) -> Any: return sampler +def _resolve_precision(precision: Any) -> str: + """Normalize policy precision values from config.""" + if precision is None: + return "fp32" + if not isinstance(precision, str): + raise TypeError("`parameters.policy.precision` must be a string.") + + normalized = precision.strip().lower() + if normalized not in _SUPPORTED_POLICY_PRECISIONS: + supported = ", ".join(sorted(_SUPPORTED_POLICY_PRECISIONS)) + raise ValueError( + "Unsupported `parameters.policy.precision` value " + f"'{precision}'. Expected one of: {supported}." + ) + return normalized + + +def _resolve_runtime_device(device_value: Any, fallback_device: Any) -> torch.device: + """ + Resolve runtime device from policy config, falling back to trainer default. + + Raises explicit errors when a requested accelerator is unavailable so users + can detect mismatched config/environment early. + """ + if isinstance(fallback_device, torch.device): + fallback = fallback_device + else: + fallback = torch.device(str(fallback_device)) + + if device_value is None: + return fallback + if not isinstance(device_value, str): + raise TypeError("`parameters.policy.device` must be a string.") + + normalized = device_value.strip().lower() + if not normalized: + return fallback + + if normalized == "cpu": + return torch.device("cpu") + + if normalized == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError( + "`parameters.policy.device` is set to 'cuda' but CUDA is not " + "available on this host." 
+ ) + return torch.device("cuda:0") + + if normalized.startswith("cuda:"): + if not torch.cuda.is_available(): + raise RuntimeError( + f"`parameters.policy.device` is set to '{device_value}' but CUDA " + "is not available on this host." + ) + try: + gpu_index = int(normalized.split(":", 1)[1]) + except (IndexError, ValueError) as exc: + raise ValueError( + f"Invalid CUDA device value: '{device_value}'." + ) from exc + if gpu_index < 0 or gpu_index >= torch.cuda.device_count(): + raise RuntimeError( + f"`parameters.policy.device` requested CUDA device {gpu_index}, " + f"but only {torch.cuda.device_count()} device(s) are available." + ) + return torch.device(normalized) + + if normalized == "mps": + mps_backend = getattr(torch.backends, "mps", None) + if mps_backend is None or not mps_backend.is_available(): + raise RuntimeError( + "`parameters.policy.device` is set to 'mps' but MPS is not " + "available on this host." + ) + return torch.device("mps") + + raise ValueError( + "Unsupported `parameters.policy.device` value " + f"'{device_value}'. Expected cpu, cuda[:index], or mps." + ) + + +def _autocast_context( + context: TrainingContext, +) -> tuple[contextlib.AbstractContextManager[Any], bool]: + """Resolve an autocast context from runtime precision and device settings.""" + precision = str(context.state.get("lerobot_precision", "fp32")).lower() + device = context.device + + if not isinstance(device, torch.device): + device = torch.device(str(device)) + + if precision == "fp32": + return contextlib.nullcontext(), False + + if device.type == "cuda": + if not torch.cuda.is_available(): + if not context.state.get("lerobot_precision_warning_emitted"): + logging.warning( + "LeRobot precision '%s' requested, but CUDA is unavailable. 
" + "Falling back to fp32 execution.", + precision, + ) + context.state["lerobot_precision_warning_emitted"] = True + return contextlib.nullcontext(), False + + dtype = torch.float16 if precision == "fp16" else torch.bfloat16 + return torch.autocast(device_type="cuda", dtype=dtype), True + + if device.type == "cpu": + if precision == "bf16": + return torch.autocast(device_type="cpu", dtype=torch.bfloat16), True + if not context.state.get("lerobot_precision_warning_emitted"): + logging.warning( + "LeRobot precision '%s' is not supported on CPU autocast. " + "Falling back to fp32 execution.", + precision, + ) + context.state["lerobot_precision_warning_emitted"] = True + return contextlib.nullcontext(), False + + if device.type == "mps": + if precision == "fp16": + return torch.autocast(device_type="mps", dtype=torch.float16), True + if not context.state.get("lerobot_precision_warning_emitted"): + logging.warning( + "LeRobot precision '%s' is not supported on MPS autocast. " + "Falling back to fp32 execution.", + precision, + ) + context.state["lerobot_precision_warning_emitted"] = True + return contextlib.nullcontext(), False + + if not context.state.get("lerobot_precision_warning_emitted"): + logging.warning( + "LeRobot precision '%s' is not supported on device type '%s'. 
" + "Falling back to fp32 execution.", + precision, + device.type, + ) + context.state["lerobot_precision_warning_emitted"] = True + return contextlib.nullcontext(), False + + class LeRobotTrainingStepStrategy(TrainingStepStrategy): """Training step strategy for LeRobot policies with dict-style batches.""" @@ -294,19 +441,29 @@ def training_step( self, model: nn.Module, optimizer: torch.optim.Optimizer, - examples: LeRobotBatch, + examples: torch.Tensor | Mapping[str, Any], labels: torch.Tensor, # pylint: disable=unused-argument - loss_criterion, # pylint: disable=unused-argument + loss_criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], context: TrainingContext, ) -> torch.Tensor: optimizer.zero_grad() + del labels, loss_criterion - batch = _apply_preprocessor(examples, context) - loss, loss_dict = _resolve_policy_forward( - model, - batch, - reduction=self.reduction, - ) + if not isinstance(examples, Mapping): + raise TypeError( + "LeRobot training expects dictionary-style batches. " + f"Received {type(examples).__name__}." 
+ ) + + autocast_guard, autocast_enabled = _autocast_context(context) + context.state["lerobot_autocast_enabled"] = autocast_enabled + with autocast_guard: + batch = _apply_preprocessor(LeRobotBatch(dict(examples)), context) + loss, loss_dict = _resolve_policy_forward( + model, + batch, + reduction=self.reduction, + ) if not torch.is_tensor(loss): raise TypeError( @@ -354,7 +511,9 @@ def test_model( total_loss = 0.0 total_weight = 0 - with torch.no_grad(): + autocast_guard, autocast_enabled = _autocast_context(context) + context.state["lerobot_autocast_enabled"] = autocast_enabled + with torch.no_grad(), autocast_guard: for examples, labels in test_loader: examples = examples.to(context.device) labels = labels.to(context.device) @@ -413,6 +572,8 @@ def __init__(self, model=None, callbacks=None): self._collate_wrapper = LeRobotCollateWrapper() self._processors_initialised = False self._pretrained_path = self._resolve_policy_path() + self._runtime_precision = self._resolve_policy_precision() + self._policy_device = self._resolve_policy_device() self._preprocessor_factory: Callable[..., tuple[Callable, Callable]] | None = ( None ) @@ -433,6 +594,11 @@ def __init__(self, model=None, callbacks=None): testing_strategy=LeRobotTestingStrategy(self._collate_wrapper), ) + resolved_device = _resolve_runtime_device(self._policy_device, self.device) + self.device = str(resolved_device) + self.context.device = resolved_device + self.context.state["lerobot_precision"] = self._runtime_precision + self.context.state["lerobot_runtime_device"] = str(resolved_device) self.context.state["lerobot_preprocessor"] = None self.context.state["lerobot_postprocessor"] = None @@ -446,6 +612,24 @@ def _resolve_policy_path() -> str | None: return value if value else None return None + @staticmethod + def _resolve_policy_precision() -> str: + parameters = getattr(Config(), "parameters", None) + policy_cfg = _config_node_to_dict(getattr(parameters, "policy", None)) + return 
_resolve_precision(policy_cfg.get("precision", "fp32")) + + @staticmethod + def _resolve_policy_device() -> str | None: + parameters = getattr(Config(), "parameters", None) + policy_cfg = _config_node_to_dict(getattr(parameters, "policy", None)) + candidate = policy_cfg.get("device") + if candidate is None: + return None + if not isinstance(candidate, str): + raise TypeError("`parameters.policy.device` must be a string.") + value = candidate.strip() + return value if value else None + def _resolve_model_pretrained_path(self) -> str | None: model = self._require_model() model_path = getattr(model, "plato_policy_path", None) @@ -494,8 +678,12 @@ def _ensure_pre_post_processors(self, dataset: Any) -> None: def train_model(self, config, trainset, sampler, **kwargs): self._ensure_pre_post_processors(trainset) + self.context.state["lerobot_precision"] = self._runtime_precision + self.context.device = torch.device(str(self.device)) return super().train_model(config, trainset, sampler, **kwargs) def test_model(self, config, testset, sampler=None, **kwargs): self._ensure_pre_post_processors(testset) + self.context.state["lerobot_precision"] = self._runtime_precision + self.context.device = torch.device(str(self.device)) return super().test_model(config, testset, sampler, **kwargs) diff --git a/plato/utils/rl_env.py b/plato/utils/rl_env.py index b807a67a5..cef5e57ea 100644 --- a/plato/utils/rl_env.py +++ b/plato/utils/rl_env.py @@ -10,6 +10,7 @@ import asyncio import logging +from typing import Any import gymnasium as gym import numpy as np @@ -43,19 +44,27 @@ def __init__(self, rl_agent): # https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html#tips-and-tricks-when-creating-a-custom-environment n_actions = 1 self.action_space = spaces.Box( - low=-1, high=1, shape=(n_actions,), dtype="float32" + low=-1, high=1, shape=(n_actions,), dtype=np.float32 ) # Use only global model accurarcy as state for now self.n_states = 1 # Also normalize observation space for 
better RL training self.observation_space = spaces.Box( - low=-1, high=1, shape=(self.n_states,), dtype="float32" + low=-1, high=1, shape=(self.n_states,), dtype=np.float32 ) self.state = np.zeros(self.n_states, dtype=np.float32) - def reset(self): + def reset( + self, + *, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, dict[str, Any]]: + super().reset(seed=seed) + del options + if self.rl_agent.rl_episode >= Config().algorithm.rl_episodes: while True: # Give RL agent some time to close connections and exit @@ -72,7 +81,7 @@ def reset(self): self.rl_agent.new_episode_begin.set() self.state = np.zeros(self.n_states, dtype=np.float32) - return self.state.copy() + return self.state.copy(), {} def step(self, action): """One step of reinforcement learning.""" diff --git a/tests/algorithms/test_fedavg_algorithm.py b/tests/algorithms/test_fedavg_algorithm.py index f6efcf837..1ccdd8399 100644 --- a/tests/algorithms/test_fedavg_algorithm.py +++ b/tests/algorithms/test_fedavg_algorithm.py @@ -4,6 +4,7 @@ from collections import OrderedDict from types import SimpleNamespace +from typing import Any, cast import pytest import torch @@ -33,8 +34,18 @@ def __init__(self) -> None: self.register_buffer("flag", torch.tensor([True, False], dtype=torch.bool)) +class BFloat16ToyModel(torch.nn.Module): + """Toy model for bf16 transport-cast regression coverage.""" + + def __init__(self) -> None: + super().__init__() + self.weight = torch.nn.Parameter( + torch.ones((2, 2), dtype=torch.bfloat16) + ) + + def _algorithm_for(model: torch.nn.Module) -> FedAvgAlgorithm: - trainer = SimpleNamespace(model=model) + trainer = cast(Any, SimpleNamespace(model=model)) return FedAvgAlgorithm(trainer=trainer) @@ -95,6 +106,27 @@ def test_load_weights_casts_dtype_and_rounds_non_float_tensors(): assert torch.equal(state["flag"], torch.tensor([False, True])) +def test_extract_weights_casts_bfloat16_payloads_for_transport(): + """bf16 tensors should be cast to 
fp32 for safe payload serialization.""" + model = BFloat16ToyModel() + algorithm = _algorithm_for(model) + + payload = algorithm.extract_weights() + assert payload["weight"].dtype == torch.float32 + + inbound = OrderedDict( + {"weight": torch.full((2, 2), 3.5, dtype=torch.float32)} + ) + algorithm.load_weights(inbound) + + state = model.state_dict() + assert state["weight"].dtype == torch.bfloat16 + assert torch.allclose( + state["weight"].float(), + torch.full((2, 2), 3.5, dtype=torch.float32), + ) + + def test_extract_weights_respects_optional_payload_size_limit(): """Payload extraction should fail fast if a configured max size is exceeded.""" model = torch.nn.Linear(32, 32, bias=False) diff --git a/tests/datasources/test_lerobot_datasource.py b/tests/datasources/test_lerobot_datasource.py index 3e76c73d0..61dc2e81b 100644 --- a/tests/datasources/test_lerobot_datasource.py +++ b/tests/datasources/test_lerobot_datasource.py @@ -74,17 +74,22 @@ def test_lerobot_constructor_is_deterministic_and_maps_samples( patched_lerobot_backend, ): """Constructor should produce stable splits and mapped Plato sample keys.""" - kwargs = { - "client_id": 2, - "repo_id": "stub/lerobot", - "split_seed": 11, - "train_split": 0.5, - "task_aware_split": True, - "task_aware_partition": True, - } - - first = lerobot_datasource.DataSource(**kwargs) - second = lerobot_datasource.DataSource(**kwargs) + first = lerobot_datasource.DataSource( + client_id=2, + repo_id="stub/lerobot", + split_seed=11, + train_split=0.5, + task_aware_split=True, + task_aware_partition=True, + ) + second = lerobot_datasource.DataSource( + client_id=2, + repo_id="stub/lerobot", + split_seed=11, + train_split=0.5, + task_aware_split=True, + task_aware_partition=True, + ) assert first.train_episodes == second.train_episodes assert first.test_episodes == second.test_episodes diff --git a/tests/models/test_smolvla_model.py b/tests/models/test_smolvla_model.py index 682bfc948..4c477fa6d 100644 --- 
a/tests/models/test_smolvla_model.py +++ b/tests/models/test_smolvla_model.py @@ -3,6 +3,7 @@ from __future__ import annotations from types import SimpleNamespace +from typing import Any, cast from plato.algorithms.fedavg import Algorithm as FedAvgAlgorithm from plato.models import registry as models_registry @@ -55,7 +56,8 @@ def test_smolvla_adapter_metadata_filters_fedavg_payload( adapter_parameter_patterns=["adapter"], ) - algorithm = FedAvgAlgorithm(trainer=SimpleNamespace(model=model)) + trainer = cast(Any, SimpleNamespace(model=model)) + algorithm = FedAvgAlgorithm(trainer=trainer) payload = algorithm.extract_weights() assert list(payload.keys()) == ["adapter.weight"] diff --git a/tests/trainers/test_lerobot_trainer.py b/tests/trainers/test_lerobot_trainer.py index 1851e43b2..1d1316aff 100644 --- a/tests/trainers/test_lerobot_trainer.py +++ b/tests/trainers/test_lerobot_trainer.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from typing import Any +import pytest import torch import torch.nn as nn @@ -131,3 +132,80 @@ def _post(outputs): assert trainer.run_history.get_metric_values("train_loss") assert factory_calls["kwargs"]["pretrained_path"] == "stub/smolvla" assert factory_calls["kwargs"]["dataset_stats"] == trainset.meta.stats + + +def test_lerobot_trainer_consumes_policy_precision_and_device( + temp_config, + monkeypatch, +): + """Trainer should apply policy precision/device runtime settings.""" + config = Config() + config.trainer = config.trainer._replace( + type="lerobot", + model_type="smolvla", + model_name="smolvla_unit", + batch_size=2, + epochs=1, + optimizer="SGD", + ) + config.parameters = Config.node_from_dict( + { + "optimizer": { + "lr": 0.05, + "momentum": 0.0, + "weight_decay": 0.0, + }, + "policy": { + "path": "stub/smolvla", + "precision": "bf16", + "device": "cpu", + }, + } + ) + + monkeypatch.setattr( + lerobot_trainer, + "_import_make_pre_post_processors", + lambda: (lambda *_args, **_kwargs: (lambda batch: batch, lambda out: 
out)), + ) + + trainer = lerobot_trainer.Trainer(model=_TinyLeRobotPolicy()) + trainset = _SyntheticLeRobotDataset() + + assert trainer.device == "cpu" + assert trainer.context.device == torch.device("cpu") + assert trainer.context.state["lerobot_precision"] == "bf16" + + trainer.train_model( + {"batch_size": 2, "epochs": 1, "run_id": "lerobot-precision"}, + trainset, + sampler=list(range(len(trainset))), + ) + assert trainer.context.state["lerobot_precision"] == "bf16" + assert isinstance(trainer.context.state["lerobot_autocast_enabled"], bool) + + +def test_lerobot_trainer_rejects_unavailable_cuda_device( + temp_config, + monkeypatch, +): + """Policy device should be validated against runtime accelerator availability.""" + config = Config() + config.trainer = config.trainer._replace( + type="lerobot", + model_type="smolvla", + model_name="smolvla_unit", + ) + config.parameters = Config.node_from_dict( + { + "policy": { + "path": "stub/smolvla", + "device": "cuda", + }, + } + ) + + monkeypatch.setattr(lerobot_trainer.torch.cuda, "is_available", lambda: False) + + with pytest.raises(RuntimeError, match="CUDA is not available"): + lerobot_trainer.Trainer(model=_TinyLeRobotPolicy()) From 31150968931976efa996e00e6bd392ce477a16a6 Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Thu, 19 Feb 2026 13:57:13 -0500 Subject: [PATCH 14/15] Rearranged the SmolVLA documentation. --- .../3. SmolVLA Trainer with LeRobot.md} | 55 +++++++++++++++++-- docs/docs/install.md | 4 +- docs/docs/smolvla_lerobot_setup.md | 46 ---------------- docs/mkdocs.yml | 4 +- 4 files changed, 55 insertions(+), 54 deletions(-) rename docs/docs/{smolvla_lerobot_runbook.md => examples/case-studies/3. SmolVLA Trainer with LeRobot.md} (75%) delete mode 100644 docs/docs/smolvla_lerobot_setup.md diff --git a/docs/docs/smolvla_lerobot_runbook.md b/docs/docs/examples/case-studies/3. 
SmolVLA Trainer with LeRobot.md similarity index 75% rename from docs/docs/smolvla_lerobot_runbook.md rename to docs/docs/examples/case-studies/3. SmolVLA Trainer with LeRobot.md index 31a79fce5..ba99ac107 100644 --- a/docs/docs/smolvla_lerobot_runbook.md +++ b/docs/docs/examples/case-studies/3. SmolVLA Trainer with LeRobot.md @@ -1,8 +1,6 @@ -# SmolVLA + LeRobot Runbook +# SmolVLA Trainer with LeRobot -This runbook is for operators running SmolVLA training in Plato with LeRobot datasets. -It complements the setup notes in [SmolVLA + LeRobot Optional Setup](smolvla_lerobot_setup.md) -and the parameter contract in [Configuration Parameters](configurations/parameters.md). +This runbook is for operators running SmolVLA training in Plato with LeRobot datasets. It complements the parameter contract in [Configuration Parameters](../../configurations/parameters.md). ## 1) Setup @@ -198,3 +196,52 @@ Symptom: Actions: - Install host FFmpeg libraries and build toolchain (`cmake`, `build-essential`, FFmpeg libs), then reinstall robotics extras. + + + + +## SmolVLA + LeRobot Optional Setup + +This setup path is optional. Core Plato federated workloads continue to use the default dependency set from `uv sync`. + +### Install the robotics extra + +From the repository root: + +```bash +uv sync --extra robotics +``` + +This installs `lerobot[smolvla]` and the associated training stack only when the +`robotics` extra is requested. + +### Environment gating + +When adding LeRobot-backed modules, keep imports guarded so non-robotics +environments fail with a clear action instead of a hard crash at import time. + +```python +try: + import lerobot +except ImportError as exc: + raise ImportError( + "LeRobot support is optional. Install with: uv sync --extra robotics" + ) from exc +``` + +### Runtime notes for SmolVLA/LeRobot + +- CUDA-capable GPUs are recommended for practical SmolVLA fine-tuning; CPU is + mainly suitable for smoke checks. 
+- Install `ffmpeg` on hosts that read video-backed LeRobot datasets. +- Authenticate with Hugging Face (`huggingface-cli login`) when accessing + private dataset repositories. +- LeRobot currently constrains the Torch stack used by this optional path; + if you need different Torch constraints for non-robotics research, keep a + separate virtual environment. + +### Quick verification + +```bash +uv run python -c "import lerobot; print(lerobot.__version__)" +``` diff --git a/docs/docs/install.md b/docs/docs/install.md index c2c797305..f62f4289e 100644 --- a/docs/docs/install.md +++ b/docs/docs/install.md @@ -74,9 +74,9 @@ LeRobot and SmolVLA dependencies are available behind Plato's optional uv sync --extra robotics ``` -See [SmolVLA + LeRobot Optional Setup](smolvla_lerobot_setup.md) for runtime +See [SmolVLA + LeRobot Optional Setup](examples/case-studies/3.%20SmolVLA%20Trainer%20with%20LeRobot.md) for runtime requirements and guarded-import guidance, and -[SmolVLA + LeRobot Runbook](smolvla_lerobot_runbook.md) for setup-to-run +[SmolVLA Trainer with LeRobot](examples/case-studies/3.%20SmolVLA%20Trainer%20with%20LeRobot.md) for setup-to-run operational steps and troubleshooting. ### Building the `plato-learn` PyPi Package diff --git a/docs/docs/smolvla_lerobot_setup.md b/docs/docs/smolvla_lerobot_setup.md deleted file mode 100644 index e36be2a22..000000000 --- a/docs/docs/smolvla_lerobot_setup.md +++ /dev/null @@ -1,46 +0,0 @@ -# SmolVLA + LeRobot Optional Setup - -This setup path is optional. Core Plato federated workloads continue to use the -default dependency set from `uv sync`. - -## Install the robotics extra - -From the repository root: - -```bash -uv sync --extra robotics -``` - -This installs `lerobot[smolvla]` and the associated training stack only when the -`robotics` extra is requested. - -## Environment gating - -When adding LeRobot-backed modules, keep imports guarded so non-robotics -environments fail with a clear action instead of a hard crash at import time.
- -```python -try: - import lerobot -except ImportError as exc: - raise ImportError( - "LeRobot support is optional. Install with: uv sync --extra robotics" - ) from exc -``` - -## Runtime notes for SmolVLA/LeRobot - -- CUDA-capable GPUs are recommended for practical SmolVLA fine-tuning; CPU is - mainly suitable for smoke checks. -- Install `ffmpeg` on hosts that read video-backed LeRobot datasets. -- Authenticate with Hugging Face (`huggingface-cli login`) when accessing - private dataset repositories. -- LeRobot currently constrains the Torch stack used by this optional path; - if you need different Torch constraints for non-robotics research, keep a - separate virtual environment. - -## Quick verification - -```bash -uv run python -c "import lerobot; print(lerobot.__version__)" -``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index a31e28f1e..64f5f5ca9 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -48,8 +48,6 @@ extra_css: nav: - Home: index.md - Installation: install.md - - SmolVLA + LeRobot Optional Setup: smolvla_lerobot_setup.md - - SmolVLA + LeRobot Runbook: smolvla_lerobot_runbook.md - Quick Start: quickstart.md - Examples: - Getting Started: examples/Getting Started.md @@ -71,6 +69,7 @@ nav: - Case Studies: - Federated LoRA Fine-Tuning: examples/case-studies/1. LoRA.md - Composable Trainer API: examples/case-studies/2. Composable Trainer.md + - SmolVLA Trainer with LeRobot: examples/case-studies/3. SmolVLA Trainer with LeRobot.md - Configuration Settings: - Overview: configurations/overview.md - General: configurations/general.md From deaa0c4906a1f645fdf601d0445b470090ddc7db Mon Sep 17 00:00:00 2001 From: Baochun Li Date: Tue, 10 Mar 2026 15:19:54 -0400 Subject: [PATCH 15/15] Removed plans.
--- plans/smolvla-lerobot-integration-contract.md | 151 --------- .../smolvla-lerobot-plato-integration-plan.md | 292 ------------------ plans/smolvla-lerobot-validation-report.md | 175 ----------- 3 files changed, 618 deletions(-) delete mode 100644 plans/smolvla-lerobot-integration-contract.md delete mode 100644 plans/smolvla-lerobot-plato-integration-plan.md delete mode 100644 plans/smolvla-lerobot-validation-report.md diff --git a/plans/smolvla-lerobot-integration-contract.md b/plans/smolvla-lerobot-integration-contract.md deleted file mode 100644 index 8fb329f43..000000000 --- a/plans/smolvla-lerobot-integration-contract.md +++ /dev/null @@ -1,151 +0,0 @@ -# SmolVLA + LeRobot Integration Contract (Release v1) - -Date: 2026-02-19 -Plan Task: T1 (`depends_on: []`) -Status: accepted baseline for implementation tasks T2+ - -## 1. Objective - -Define a concrete, testable contract for first-release SmolVLA + LeRobot support in Plato without changing Plato's core federated runtime model. - -## 2. In-Scope (Release v1) - -1. SmolVLA fine-tuning runs inside Plato's existing federated lifecycle. -2. LeRobot datasets are ingested via Plato datasource APIs. -3. Existing client/server/algorithm loops remain the orchestration path. -4. End users run experiments through TOML configuration only (no source edits). - -## 3. Integration Surface (Concrete Components) - -The implementation must integrate through these existing extension points. - -Runtime entry and lifecycle: -1. `plato.py` -2. `plato/client.py` -3. `plato/clients/registry.py` -4. `plato/clients/base.py` -5. `plato/servers/fedavg.py` -6. `plato/servers/registry.py` -7. `plato/algorithms/registry.py` - -Config loading and propagation: -1. `plato/config.py` - -Datasource extension points: -1. `plato/datasources/base.py` -2. `plato/datasources/registry.py` -3. New module target: `plato/datasources/lerobot.py` - -Model extension points: -1. `plato/models/registry.py` -2. 
New module target: `plato/models/smolvla.py` - -Trainer extension points: -1. `plato/trainers/base.py` -2. `plato/trainers/composable.py` -3. `plato/trainers/registry.py` -4. New module target: `plato/trainers/lerobot.py` - -Compatibility rule: v1 must work with existing `fedavg` server/algorithm paths. Any special handling must be encapsulated inside the new datasource/model/trainer modules and their registry wiring. - -## 4. Configuration Contract (v1) - -The following fields are the required contract for SmolVLA + LeRobot configs. T3 is responsible for schema wiring/validation. - -```toml -[data] -datasource = "LeRobot" -# existing partitioning keys stay valid (sampler, partition_size, random_seed) - -[trainer] -type = "lerobot" -model_type = "smolvla" -model_name = "smolvla" - -[parameters.policy] -type = "smolvla" -path = "lerobot/smolvla_base" -finetune_mode = "full" # "full" or "adapter" -precision = "bf16" # expected values: fp32/fp16/bf16 -device = "cuda" # expected values: cpu/cuda/mps - -[parameters.dataset] -repo_id = "/" -delta_timestamps = { observation_image = [-0.2, -0.1, 0.0] } - -[parameters.transforms] -image_size = [224, 224] -normalize = true -``` - -Required semantics: - -1. `parameters.policy.path` resolves pretrained policy source. -2. `parameters.policy.type` selects policy family (`smolvla` for v1). -3. `parameters.dataset.repo_id` selects LeRobot dataset source. -4. `parameters.dataset.delta_timestamps` controls temporal windowing. -5. `parameters.transforms.*` controls image preprocessing. -6. `parameters.policy.finetune_mode` controls full vs adapter updates. -7. `parameters.policy.precision` and `parameters.policy.device` govern runtime dtype/device behavior. - -## 5. Scope Boundaries and Non-Goals - -Explicitly out of scope for release v1: - -1. New federated algorithms or server types beyond existing registry options. -2. Live robot inference/control loops and async teleoperation workflows. -3. Non-LeRobot robotics dataset backends. 
-4. End-to-end convergence/benchmark claims beyond smoke/stability checks. -5. Automated dependency bootstrap for platform-specific robotics stacks. - -## 6. Acceptance Checks (Concrete and Testable) - -These checks define go/no-go for the integration scope. - -### AC1: Single-client local training run - -- Config target: `configs/LeRobot/smolvla_single_client_smoke.toml`. -- Command: - -```bash -uv run python plato.py --config configs/LeRobot/smolvla_single_client_smoke.toml -``` - -Pass criteria: -1. Process exits with code `0`. -2. Trainer completes at least one local epoch in one communication round. -3. A model artifact is written under configured `model_path`. - -### AC2: Multi-client federated round-trip - -- Config target: `configs/LeRobot/smolvla_fedavg_two_client_smoke.toml`. -- Command: - -```bash -uv run python plato.py --config configs/LeRobot/smolvla_fedavg_two_client_smoke.toml -``` - -Pass criteria: -1. Server starts and selects two clients in the same round. -2. Server receives two client updates and completes aggregation. -3. Round counter advances to at least round 1 completion without runtime exceptions. - -### AC3: Config-first workflow (no source edits) - -Validation procedure: -1. Run `AC1` and `AC2` using committed TOML files only. -2. Confirm no local source modifications are required between runs. - -Pass criteria: -1. Both runs succeed from clean checkout with only config selection changed. - -## 7. Deliverables Expected From Downstream Tasks - -1. Config files under `configs/LeRobot/` implementing AC1/AC2 targets. -2. Registry wiring and implementation modules listed in Section 3. -3. Smoke/integration tests that encode AC1/AC2/AC3 behavior. - -## 8. Notes - -- This contract intentionally locks only v1 integration behavior and acceptance gates. -- Performance tuning and broader robotics feature surface are deferred to post-v1 tasks. 
diff --git a/plans/smolvla-lerobot-plato-integration-plan.md b/plans/smolvla-lerobot-plato-integration-plan.md deleted file mode 100644 index 623bb1b53..000000000 --- a/plans/smolvla-lerobot-plato-integration-plan.md +++ /dev/null @@ -1,292 +0,0 @@ -# SmolVLA + LeRobot Integration Plan for Plato - -Date: 2026-02-19 -Scope: Add support for training Hugging Face SmolVLA with LeRobot datasets/framework inside Plato. - -## Dependency Graph - -```text -T1 -> T2, T3 -T2 -> T4, T5 -T3 -> T4, T5, T9 -T4, T5 -> T6 -T6 -> T7, T8 -T4, T5, T6 -> T9 -T7, T8, T9 -> T10 -T10 -> T11 -T11 -> T12 -``` - -## Tasks - -### T1. Define integration contract and acceptance criteria -depends_on: [] -status: completed (2026-02-19) -- Lock exact scope for first release: -- Support SmolVLA fine-tuning in Plato’s existing FL lifecycle. -- Support LeRobot dataset ingestion through Plato datasource APIs. -- Ensure compatibility with existing client/server/algorithm loops. -- Define acceptance checks: -- Single-client local training run works. -- Multi-client federated round-trip works. -- Config-first workflow (no code edits required to run experiment). -work_log: -- Added `plans/smolvla-lerobot-integration-contract.md` to lock v1 scope, concrete integration touchpoints, config contract, and testable acceptance checks. -- Made explicit scope boundaries and non-goals for v1. -files_touched: -- `plans/smolvla-lerobot-integration-contract.md` (created) -- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) -gotchas: -- Context7 returned `LeRobot` as the primary documented surface; SmolVLA details were discovered through the LeRobot documentation set. - -### T2. Add dependencies and environment gating -depends_on: [T1] -status: completed (2026-02-19) -- Update `pyproject.toml` with required LeRobot and training stack dependencies. -- Regenerate `uv.lock`. -- Add guarded imports so environments without robotics extras still run existing Plato workloads. 
-- Document required system/runtime notes for optional robotics path. -work_log: -- Added a new optional extra (`robotics`) in `pyproject.toml` with `lerobot[smolvla]>=0.4.3,<0.5.0` so default installs remain unchanged. -- Regenerated `uv.lock` with `uv lock`, then validated both paths: -- `uv sync --frozen` + import check for core Plato. -- `uv sync --frozen --extra robotics` + `import lerobot` check for the optional robotics stack. -- Added focused setup docs for SmolVLA/LeRobot and linked them from `docs/docs/install.md`. -files_touched: -- `pyproject.toml` -- `uv.lock` -- `docs/docs/install.md` -- `docs/docs/smolvla_lerobot_setup.md` (created) -- `plans/smolvla-lerobot-plato-integration-plan.md` -gotchas: -- The optional LeRobot path constrains parts of the Torch stack in lock resolution; keeping it under `--extra robotics` avoids forcing robotics dependencies into default `uv sync` environments. - -### T3. Extend Plato configuration schema for SmolVLA/LeRobot -depends_on: [T1] -status: completed (2026-02-19) -- Add/validate config keys needed for SmolVLA + LeRobot: -- `policy.path` / `policy.type` -- `dataset.repo_id` -- `delta_timestamps` -- image transform controls -- precision/device flags -- full-finetune vs adapter mode switch -- Ensure keys flow through `Config()` and into trainer/model/datasource constructors. -work_log: -- Verified that `plato/config.py` already preserves nested TOML keys under `Config().parameters` without schema whitelisting, so SmolVLA/LeRobot keys are backward-compatible pass-through. -- Added a focused config loader test to assert `parameters.policy`, `parameters.dataset`, and `parameters.transforms` keys are parsed and exposed as constructor-ready dictionaries via `_asdict()`. -- Extended configuration documentation with an explicit mapping table from config keys to trainer/model/datasource consumption paths and a full SmolVLA/LeRobot TOML example. 
-files_touched: -- `tests/test_config_loader.py` (updated) -- `docs/docs/configurations/parameters.md` (updated) -- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) -gotchas: -- No `Config()` code change was required; introducing strict validation at this stage would have been intrusive and risked regressions for existing custom `parameters.*` users. - -### T4. Implement LeRobot datasource adapter -depends_on: [T2, T3] -status: completed (2026-02-19) -- Add `plato/datasources/lerobot.py`. -- Implement dataset loading via LeRobot APIs and map samples into Plato’s expected batch format. -- Register datasource in `plato/datasources/registry.py`. -- Add deterministic client partitioning strategy (episode/task aware split). -- Provide train/test dataset access methods compatible with existing samplers. -work_log: -- Added `plato/datasources/lerobot.py` with guarded LeRobot imports, config parsing for `parameters.dataset.*` and `parameters.transforms.*`, and sample mapping that preserves raw fields while attaching `plato_inputs`, `plato_targets`, and `plato_metadata`. -- Implemented deterministic episode-level train/test splitting with optional explicit episode overrides, task-aware stratification when task metadata is available, and deterministic per-client episode partitioning keyed by `data.random_seed`/`parameters.dataset.split_seed`. -- Wired `"LeRobot"` through `plato/datasources/registry.py` as a partitioned datasource so `datasources_registry.get(client_id=...)` passes client identity into the adapter. -- Ran a targeted no-download constructor/registry validation using stubbed `LeRobotDataset` and `LeRobotDatasetMetadata`, confirming deterministic splits and registry retrieval. 
-files_touched: -- `plato/datasources/lerobot.py` (created) -- `plato/datasources/registry.py` (updated) -- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) -gotchas: -- Constructor/registry validation was intentionally monkeypatched to avoid external dataset access and to keep split checks deterministic and offline. -- When task metadata is sparse or missing, the adapter falls back to deterministic episode-only splitting. - -### T5. Implement SmolVLA model/policy wrapper -depends_on: [T2, T3] -status: completed (2026-02-19) -- Add `plato/models/smolvla.py`. -- Implement pretrained loading path (`smolvla_base` and custom repo id/path). -- Expose trainable-parameter policy (full model or adapter path). -- Register model in `plato/models/registry.py`. -- Ensure state dict save/load compatibility with Plato aggregation pipeline. -work_log: -- Added `plato/models/smolvla.py` with lazy LeRobot import guards, actionable installation errors for missing robotics extras, and a SmolVLA factory path compatible with Plato model registry usage. -- Implemented pretrained policy source resolution with support for `smolvla_base` aliasing to `lerobot/smolvla_base`, config-based `parameters.policy.path`, and explicit constructor overrides (`policy_path` / `path`). -- Added finetune policy modes for `full` and `adapter`; adapter mode uses configurable name-pattern matching and falls back to the loaded policy's existing `requires_grad` flags when patterns do not match. -- Added compatibility checks for `state_dict`, `load_state_dict`, and `save_pretrained`, then registered `model_type = "smolvla"` in `plato/models/registry.py`. -- Ran a targeted constructor/import validation without downloads by monkeypatching `SmolVLAPolicy.from_pretrained` and verifying both direct wrapper construction and registry resolution. 
-files_touched: -- `plato/models/smolvla.py` (created) -- `plato/models/registry.py` (updated) -- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) -gotchas: -- Adapter parameter names are model-dependent; when no configured adapter patterns match, the wrapper intentionally reuses the model's preconfigured trainable flags instead of silently leaving zero trainable tensors. - -### T6. Implement LeRobot trainer backend -depends_on: [T4, T5] -status: completed (2026-02-19) -- Add `plato/trainers/lerobot.py` (ComposableTrainer-compatible). -- Implement multimodal collate + preprocessing for LeRobot samples. -- Wire forward/loss/backward/optimizer/scheduler flow for SmolVLA policy. -- Implement evaluation hooks suitable for regression checks. -- Register trainer in `plato/trainers/registry.py`. -work_log: -- Added `plato/trainers/lerobot.py` with a ComposableTrainer-compatible backend that wires custom dict/multimodal collation, processor-aware training steps, and evaluation loss reporting for regression checks. -- Implemented LeRobot pre/post-processor integration via `make_pre_post_processors(policy_cfg, pretrained_path=..., dataset_stats=...)`, with lazy optional-dependency import guards and actionable installation errors. -- Implemented SmolVLA policy forward integration handling tuple loss outputs and preserving optimizer + scheduler flow through the base composable lifecycle. -- Registered `trainer.type = "lerobot"` in `plato/trainers/registry.py`. -- Ran targeted offline validation with monkeypatched processor stubs: -- trainer registry resolution and construction (`trainer.type = "lerobot"`), -- synthetic one-epoch training-step path, -- synthetic evaluation pass returning numeric loss. -- Ran `uv run ruff check` on touched trainer files. 
-files_touched: -- `plato/trainers/lerobot.py` (created) -- `plato/trainers/registry.py` (updated) -- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) -gotchas: -- LeRobot preprocessing depends on optional robotics extras at runtime; the trainer defers imports until processor initialization so non-robotics workloads remain unaffected, and it fails with a clear `uv sync --extra robotics` message when required dependencies are missing. - -### T7. Harden federated payload/aggregation behavior -depends_on: [T6] -status: completed (2026-02-19) -- Ensure only intended trainable tensors are exchanged/aggregated. -- Add safeguards for payload size and dtype handling. -- Verify checkpoint/state restore consistency across rounds. -- Validate no regressions in FedAvg flow with large model weights. -work_log: -- Hardened `plato/algorithms/fedavg.py` to exchange adapter-only tensors when `plato_finetune_mode = "adapter"` and `plato_trainable_parameter_names` are provided, while preserving full-state behavior for existing non-adapter models. -- Added dtype-safe tensor casting and partial payload merge logic in `load_weights()`, plus stricter key/shape validation and delta application safeguards for partial/full state dicts across rounds. -- Added payload-size safeguards with an optional limit (`model.plato_max_payload_size_mb` or `PLATO_FEDAVG_MAX_PAYLOAD_MB`) and fail-fast checks when payloads exceed the configured cap. -- Added targeted regression tests for filtered extract/load round-trip, dtype safety, optional payload-size guard, and full-mode FedAvg round-trip with large weights. -- Ran `uv run ruff check` and focused `uv run pytest` for the new FedAvg algorithm tests. 
-files_touched: -- `plato/algorithms/fedavg.py` (updated) -- `tests/algorithms/test_fedavg_algorithm.py` (created) -- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) -gotchas: -- Payload-size enforcement is intentionally opt-in to maintain backward compatibility for existing workloads that may exchange large full-model state dicts. - -### T8. Validate runtime lifecycle compatibility -depends_on: [T6] -status: completed (2026-02-19) -- Confirm integration works with existing lifecycle code paths: -- client setup strategies -- server trainer initialization -- training/report/aggregation loop -- Avoid special-case branching unless strictly necessary. -work_log: -- Ran a focused runtime smoke with `data.datasource = "LeRobot"`, `trainer.type = "lerobot"`, and `trainer.model_type = "smolvla"` using monkeypatched LeRobot/SmolVLA externals to avoid downloads, then exercised the default `simple.Client` lifecycle (`_load_data` -> `configure` -> `_allocate_data` -> `_train`). -- Verified lifecycle construction path through existing registries and strategy plumbing: datasource (`LeRobot`) + trainer (`lerobot`) + algorithm (`fedavg`) were all instantiated through default client/server setup with no special-case branching. -- Executed a short mocked client/server round-trip by feeding the client-produced payload/report into `fedavg.Server._process_reports()` after `Server.configure()`, confirming server trainer initialization and aggregation/report processing completed successfully. -- No lifecycle compatibility bug was found in this scope, so no runtime code patch was applied. -files_touched: -- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) -gotchas: -- The focused smoke directly called `client._train()`; because `report.processing_time` is normally attached in the payload strategy path, the smoke sets `report.processing_time = 0.0` before invoking server report processing. - -### T9. 
Add runnable experiment configs -depends_on: [T3, T4, T5, T6] -status: completed (2026-02-19) -- Add `configs/LeRobot/` config set: -- reusable base datasource fragment -- minimal smoke config -- full fine-tune config aligned to SmolVLA guidance -- Ensure includes/overrides follow repository config conventions. -work_log: -- Added `configs/LeRobot/` with a reusable datasource include fragment plus runnable single-client smoke, two-client FedAvg smoke, and fuller full-fine-tune configs. -- Aligned all new configs with T4-T6 integration keys: `data.datasource = "LeRobot"`, `trainer.type = "lerobot"`, `trainer.model_type = "smolvla"`, and explicit `[parameters.policy]`, `[parameters.dataset]`, `[parameters.transforms]` sections. -- Mapped SmolVLA fine-tuning guidance into Plato semantics by keeping `policy.path = "lerobot/smolvla_base"`, `policy.finetune_mode = "full"`, `policy.device = "cuda"`, and `batch_size = 64` in the fuller config. -files_touched: -- `configs/LeRobot/lerobot_datasource_base.toml` (created) -- `configs/LeRobot/smolvla_single_client_smoke.toml` (created) -- `configs/LeRobot/smolvla_fedavg_two_client_smoke.toml` (created) -- `configs/LeRobot/smolvla_full_finetune.toml` (created) -- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) -gotchas: -- The datasource include fragment is intentionally sectionless so `[data].include` merges it directly into the `data` table. -- SmolVLA upstream examples are step-based (`lerobot-train --steps`), while Plato scheduling is round/epoch-based, so the fuller config mirrors guidance through batch/device/fine-tune mode and keeps runtime knobs in `trainer.rounds` + `trainer.epochs`. - -### T10. Add tests (unit + integration smoke) -depends_on: [T7, T8, T9] -status: completed (2026-02-19) -- Datasource registry + constructor tests for LeRobot datasource. -- Model registry + construction tests for SmolVLA wrapper. -- Trainer step test with tiny synthetic batch. 
-- End-to-end config smoke test covering startup and one short training run. -- Add regression tests for any bug fixes discovered during integration. -work_log: -- Added focused LeRobot datasource tests covering partitioned registry resolution, deterministic constructor split behavior, and mapped `plato_inputs`/`plato_targets` sample keys. -- Added SmolVLA model tests covering registry-based wrapper construction and a FedAvg regression check asserting adapter-mode metadata results in adapter-only payload extraction. -- Added a LeRobot trainer tiny-batch unit test that exercises one short training step with synthetic dict samples, stubbed pre/post processors, and parameter-update assertions. -- Added an end-to-end LeRobot+SmolVLA smoke test that boots from config, runs one short client training pass, and processes a server FedAvg report/update loop with external dependencies fully monkeypatched. -- Fixed local test import shadowing discovered during validation by adding package markers under `tests/` and `tests/test_utils/`. -files_touched: -- `tests/__init__.py` (created) -- `tests/test_utils/__init__.py` (created) -- `tests/test_utils/lerobot_stubs.py` (created) -- `tests/datasources/test_lerobot_datasource.py` (created) -- `tests/models/test_smolvla_model.py` (created) -- `tests/trainers/test_lerobot_trainer.py` (created) -- `tests/integration/test_lerobot_smolvla_smoke.py` (created) -- `plans/smolvla-lerobot-plato-integration-plan.md` (updated) -gotchas: -- The local environment includes a third-party `tests` package in site-packages; without `tests/__init__.py`, pytest imports can resolve to the wrong module namespace. - -### T11. Add documentation and runbook -depends_on: [T10] -status: completed (2026-02-19) -- Document setup and dependency extras. -- Document config fields and examples. -- Add troubleshooting notes (dataset access, device setup, common failures). -- Add mapping between Plato config and equivalent `lerobot-train` concepts. 
-work_log:
-- Added an operator-facing runbook covering dependency setup, runnable commands, minimum TOML contract, and troubleshooting for common LeRobot/SmolVLA failures.
-- Added an explicit Plato TOML to `lerobot-train` mapping table with direct flag mappings (`policy.path`, `dataset.repo_id`, `batch_size`, `policy.device`) and conceptual mappings (`rounds`/`epochs` vs `steps`, output paths).
-- Referenced all new `configs/LeRobot/*` profiles directly in the runbook and linked the runbook from installation docs and top-level docs navigation.
-- Grounded mapping/troubleshooting notes against current LeRobot documentation via Context7 and implementation-specific runtime errors from Plato's LeRobot datasource/trainer/model integration.
-files_touched:
-- `docs/docs/smolvla_lerobot_runbook.md` (created)
-- `docs/docs/install.md` (updated)
-- `docs/mkdocs.yml` (updated)
-- `plans/smolvla-lerobot-plato-integration-plan.md` (updated)
-gotchas:
-- `lerobot-train` examples are primarily step-based (`--steps`) while Plato scheduling is round/epoch-based; documentation uses explicit conceptual mapping instead of implying a one-to-one flag conversion.
-
-### T12. Stage validation and rollout gate
-depends_on: [T11]
-status: completed (2026-02-19)
-- Execute staged validation:
-- single-client local run
-- 2-client federated smoke run
-- larger run for convergence and stability check
-- Compare behavior/runtime against expected baseline.
-- Define go/no-go criteria and recommended defaults for first public release.
-work_log:
-- Captured validation window and environment (`2026-02-19 12:51-13:05 EST`, `uv 0.9.18`, Python `3.13.11`, `lerobot` + `torch` importable, `torch.cuda.is_available() == False`).
-- Ran focused preflight baseline: `uv run pytest -q tests/test_config_loader.py::test_config_loads_smolvla_lerobot_parameter_contract tests/datasources/test_lerobot_datasource.py tests/models/test_smolvla_model.py tests/trainers/test_lerobot_trainer.py tests/integration/test_lerobot_smolvla_smoke.py tests/algorithms/test_fedavg_algorithm.py` -> `12 passed`.
-- Ran staged real-config commands with bounded runtime:
-- `timeout 300 uv run python plato.py --config configs/LeRobot/smolvla_single_client_smoke.toml` -> fail (exit `124`, hit `TypeError: Got unsupported ScalarType BFloat16` during round-1 model payload serialization).
-- `timeout 240 uv run python plato.py --config configs/LeRobot/smolvla_fedavg_two_client_smoke.toml` -> fail (exit `124`, same BFloat16 serialization failure).
-- `timeout 120 uv run python plato.py --config configs/LeRobot/smolvla_full_finetune.toml -u` -> fail (exit `124`, same BFloat16 serialization failure before convergence phase).
-- Verified generated runtime CSVs (`runtime/results/94032.csv`, `runtime/results/94157.csv`, `runtime/results/94326.csv`) contain headers only and no completed round rows.
-- Recorded gate decision and release defaults in `plans/smolvla-lerobot-validation-report.md`: current gate is `NO-GO` until bfloat16 payload serialization is fixed.
-files_touched:
-- `plans/smolvla-lerobot-validation-report.md` (created)
-- `plans/smolvla-lerobot-plato-integration-plan.md` (updated)
-gotchas:
-- Staged runs can download/load SmolVLA and initialize LeRobot datasets, but federated round dispatch currently blocks on safetensor/tree serialization of bfloat16 tensors (`Got unsupported ScalarType BFloat16`), resulting in stalled runs that require timeout fencing.
-
-## Milestones
-
-- Milestone A (Core plumbing): T1-T6 complete.
-- Milestone B (Federated reliability): T7-T8 complete.
-- Milestone C (Usability + confidence): T9-T12 complete.
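The timeout-fencing pattern used by the staged commands can be sketched concretely. A minimal Python sketch, assuming GNU-`timeout`-style semantics; the `run_stage` helper and its use of exit code 124 (mirroring coreutils `timeout`) are illustrative only, not existing Plato code:

```python
import subprocess

def run_stage(cmd, budget_seconds):
    """Run one staged validation command under a hard time budget so a
    hung async server/client loop fails fast instead of stalling CI."""
    try:
        completed = subprocess.run(cmd, timeout=budget_seconds)
        return completed.returncode
    except subprocess.TimeoutExpired:
        # Mirror GNU coreutils `timeout`, which exits with 124 on a hang.
        return 124

# Example (hypothetical budget, matching the staged commands above):
# run_stage(["uv", "run", "python", "plato.py",
#            "--config", "configs/LeRobot/smolvla_single_client_smoke.toml"], 300)
```

`subprocess.run` kills the child when the timeout expires, so a stalled round dispatch cannot block a validation stage indefinitely.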
-
-## Notes
-
-- Existing codebase discovery found no native SmolVLA/LeRobot implementation yet.
-- Primary extension anchors are current registries and Hugging Face integration patterns.
diff --git a/plans/smolvla-lerobot-validation-report.md b/plans/smolvla-lerobot-validation-report.md
deleted file mode 100644
index 472f13571..000000000
--- a/plans/smolvla-lerobot-validation-report.md
+++ /dev/null
@@ -1,175 +0,0 @@
-# SmolVLA + LeRobot Validation Report (T12)
-
-Date: 2026-02-19
-Validation window: 2026-02-19 12:51:36 EST to 13:05:19 EST (UTC: 17:51:36 to 18:05:19)
-
-## 1) Environment context
-
-- Repo: `/Users/bli/Playground/plato`
-- `uv`: `0.9.18`
-- Python: `3.13.11`
-- Dependency probes:
-  - `import lerobot` -> available
-  - `import torch` -> available
-  - `torch.cuda.is_available()` -> `False`
-
-## 2) Commands executed and concrete outcomes
-
-### A. Baseline preflight (lightweight, offline-safe)
-
-Command:
-
-```bash
-/usr/bin/time -p uv run pytest -q \
-  tests/test_config_loader.py::test_config_loads_smolvla_lerobot_parameter_contract \
-  tests/datasources/test_lerobot_datasource.py \
-  tests/models/test_smolvla_model.py \
-  tests/trainers/test_lerobot_trainer.py \
-  tests/integration/test_lerobot_smolvla_smoke.py \
-  tests/algorithms/test_fedavg_algorithm.py
-```
-
-Outcome:
-
-- Pass: `12 passed in 0.08s`
-- Wall clock (`time -p`): `real 4.92`, `user 4.42`, `sys 0.50`
-
-Interpretation:
-
-- Local unit/integration coverage for LeRobot datasource, SmolVLA model wrapper, trainer, and FedAvg adapter behavior is healthy.
-
-### B. Stage 1: Single-client local run
-
-Command:
-
-```bash
-/usr/bin/time -p timeout 300 uv run python plato.py \
-  --config configs/LeRobot/smolvla_single_client_smoke.toml
-```
-
-Observed key runtime behavior:
-
-- Server and client initialized.
-- LeRobot datasource loaded (`train episodes=165, test episodes=41`).
-- Failure during round-1 model dispatch:
-  - `TypeError: Got unsupported ScalarType BFloat16`
-  - stack path includes `plato/processors/safetensor_encode.py` -> `plato/serialization/safetensor.py` -> `plato/utils/tree.py`.
-
-Exit/timing:
-
-- Fail: exit code `124` (timeout)
-- `real 300.96`, `user 20.08`, `sys 15.71`
-
-### C. Stage 2: 2-client federated smoke run
-
-Command:
-
-```bash
-/usr/bin/time -p timeout 240 uv run python plato.py \
-  --config configs/LeRobot/smolvla_fedavg_two_client_smoke.toml
-```
-
-Observed key runtime behavior:
-
-- Server started with 2 clients configured.
-- First client connected; round started.
-- Failure on first payload send with the same exception:
-  - `TypeError: Got unsupported ScalarType BFloat16`
-
-Exit/timing:
-
-- Fail: exit code `124` (timeout)
-- `real 240.91`, `user 12.78`, `sys 3.63`
-
-### D. Stage 3: Larger run (convergence/stability gate proxy)
-
-Command:
-
-```bash
-/usr/bin/time -p timeout 120 uv run python plato.py \
-  --config configs/LeRobot/smolvla_full_finetune.toml -u
-```
-
-Notes:
-
-- `-u` used because this environment reports `torch.cuda.is_available() == False`.
-
-Observed key runtime behavior:
-
-- Server initialized (`Training: 10 rounds`).
-- Datasource initialized (`train episodes=185, test episodes=21`).
-- Failure at round-1 dispatch with the same exception:
-  - `TypeError: Got unsupported ScalarType BFloat16`
-
-Exit/timing:
-
-- Fail: exit code `124` (timeout)
-- `real 120.99`, `user 13.03`, `sys 4.22`
-
-### E. Runtime artifact check
-
-Command:
-
-```bash
-ls runtime/results | rg "^(94032|94157|94326)\\.csv$"
-```
-
-Observed files:
-
-- `runtime/results/94032.csv`
-- `runtime/results/94157.csv`
-- `runtime/results/94326.csv`
-
-Content check:
-
-- Each file contains header only: `round,accuracy,elapsed_time`
-- No completed round rows were recorded.
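The header-only content check above can be automated when the gate is re-run. A minimal Python sketch, assuming the `round,accuracy,elapsed_time` schema shown in the artifact check; `has_round_rows` is a hypothetical helper, not existing repository code:

```python
import csv

def has_round_rows(csv_path: str) -> bool:
    """Return True when a runtime results CSV holds at least one data
    row beyond the `round,accuracy,elapsed_time` header."""
    with open(csv_path, newline="") as f:
        rows = [row for row in csv.reader(f) if row]
    # rows[0] is the header; any further row is a completed round.
    return len(rows) > 1
```

Against the three header-only files observed here, this check would return `False` for each, matching the recorded outcome.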
-
-## 3) Baseline comparison
-
-Expected baseline for staged gate:
-
-- Single-client smoke: complete 1/1 round and exit without unhandled exception.
-- Two-client smoke: complete 1/1 federated round with both clients selected and aggregated.
-- Larger run: progress beyond round 1 to provide initial convergence/stability signal.
-
-Actual:
-
-- All three runs failed before completing round 1 due to the same bfloat16 serialization issue.
-- Therefore runtime behavior is below baseline for release readiness.
-
-## 4) What could not be fully validated and why
-
-- Full convergence/stability behavior (multi-round trend) could not be validated because execution stopped before any completed round.
-- End-to-end federated completion for the 2-client path could not be validated for the same reason.
-- CUDA path in `smolvla_full_finetune.toml` could not be validated in this environment (`torch.cuda.is_available() == False`).
-
-## 5) Go/No-Go rollout gate
-
-Decision: **NO-GO** for first public release in current state.
-
-Blocking condition:
-
-- Federated payload serialization does not currently handle bfloat16 tensors emitted by SmolVLA/LeRobot policy state, causing unhandled exceptions and hung runs until timeout.
-
-Suggested release gate criteria to re-run after fix:
-
-1. No unhandled exception across all three staged commands.
-2. Single-client smoke completes within timeout and writes >=1 runtime CSV data row.
-3. Two-client smoke completes round 1 with aggregation and writes >=1 runtime CSV data row.
-4. Larger profile completes at least 3 rounds in the same environment class used for release qualification.
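One plausible direction for clearing the blocking condition, sketched here only as an illustration (the `to_serializable` helper is hypothetical; an actual fix would live in the safetensor/tree encoding path): cast bfloat16 tensors to float32 before encoding, since NumPy-backed serialization has no bfloat16 dtype. Every bfloat16 value is exactly representable in float32, so the cast itself is lossless.

```python
import torch

def to_serializable(state_dict):
    """Cast bfloat16 tensors to float32 so downstream NumPy-backed
    encoding does not raise `Got unsupported ScalarType BFloat16`."""
    return {
        name: tensor.to(torch.float32)
        if tensor.dtype == torch.bfloat16
        else tensor
        for name, tensor in state_dict.items()
    }
```

If bf16 training is to resume exactly after aggregation, the inverse cast would be needed on load; that round trip is what a proper fix in the serialization layer would have to cover.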
-
-## 6) Recommended default settings for first public release
-
-These are recommended defaults **after** the blocking serialization issue is fixed:
-
-- Entry profile: `configs/LeRobot/smolvla_single_client_smoke.toml`
-- `parameters.policy.finetune_mode = "adapter"`
-- `parameters.policy.precision = "fp32"`
-- `parameters.policy.device = "cpu"` for first-run smoke docs, then move to an accelerator.
-- `trainer.rounds = 1`, `trainer.epochs = 1`, `trainer.batch_size = 2` as onboarding default.
-- Federated smoke default: `configs/LeRobot/smolvla_fedavg_two_client_smoke.toml` as second gate after single-client pass.
-
-Operational note:
-
-- Keep explicit timeout wrappers in CI/staging commands to avoid indefinite hangs when async server/client exceptions occur.
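Taken together, the defaults above correspond to a TOML override along these lines (a sketch; key names follow the configuration contract defined for this integration, and the values are the recommendations listed here rather than shipped config content):

```toml
[trainer]
type = "lerobot"
model_type = "smolvla"
rounds = 1
epochs = 1
batch_size = 2

[parameters.policy]
type = "smolvla"
finetune_mode = "adapter"
precision = "fp32"
device = "cpu"  # first-run smoke default; move to an accelerator afterwards
```

Starting from the adapter/fp32/cpu combination keeps the onboarding path off both the bfloat16 serialization issue and any accelerator setup until a first round has completed end to end.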