From 5eb3a0e1bffadc5bbc4669ebdedc59e870fee46b Mon Sep 17 00:00:00 2001 From: Alvin Chang <1977968+alvin-chang@users.noreply.github.com> Date: Sun, 15 Mar 2026 11:10:12 +0000 Subject: [PATCH 1/5] feat: add MLX backend for local Apple Silicon RL training - metaclaw/mlx_backend/: Tinker-compatible adapter package (ServiceClient, SamplingClient, LoraTrainingClient) backed by mlx and mlx-lm - metaclaw/sdk_backend.py: auto-detection falls through to MLX when no cloud credentials are set - metaclaw/api_server.py: guard run_llm when no llm_api_key configured; skip Tinker-specific sample_async kwargs for MLX - metaclaw/setup_wizard.py: add 'mlx' backend choice, skip API key prompts, validate mlx/mlx-lm install, default to HF MLX model ID - tests/: end-to-end smoke test + unit tests (16/16 passing) --- .gitignore | 46 +++ INTEGRATION_NOTES.md | 80 +++++ metaclaw/api_server.py | 29 +- metaclaw/mlx_backend/__init__.py | 32 ++ metaclaw/mlx_backend/data_types.py | 123 +++++++ metaclaw/mlx_backend/lora.py | 59 ++++ metaclaw/mlx_backend/params.py | 34 ++ metaclaw/mlx_backend/service_client.py | 334 +++++++++++++++++++ metaclaw/sdk_backend.py | 42 ++- metaclaw/setup_wizard.py | 91 ++++-- tests/smoke_mlx_proxy.py | 314 ++++++++++++++++++ tests/test_mlx_backend.py | 237 ++++++++++++++ tests/test_mlx_integration.py | 423 +++++++++++++++++++++++++ 13 files changed, 1805 insertions(+), 39 deletions(-) create mode 100644 INTEGRATION_NOTES.md create mode 100644 metaclaw/mlx_backend/__init__.py create mode 100644 metaclaw/mlx_backend/data_types.py create mode 100644 metaclaw/mlx_backend/lora.py create mode 100644 metaclaw/mlx_backend/params.py create mode 100644 metaclaw/mlx_backend/service_client.py create mode 100644 tests/smoke_mlx_proxy.py create mode 100644 tests/test_mlx_backend.py create mode 100644 tests/test_mlx_integration.py diff --git a/.gitignore b/.gitignore index ea243e63..7d7aeee8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,51 @@ +# Python *.pyc __pycache__/ +*.pyo 
+*.egg-info/ +dist/ +build/ +*.egg + +# Virtual environments +.venv/ +venv/ +env/ + +# MetaClaw runtime data memory_data/skills/ records/ +system_prompt_cache.json +evolution_history.jsonl +scheduler_state.json +*.pid + +# RL training artifacts wandb/ +checkpoints/ +*.ckpt + +# MLX model cache (large downloads) +mlx_models/ + +# Smoke test temp files +tests/.smoke_records/ + +# OS junk +.DS_Store +Thumbs.db + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Secrets +.env +config.yaml + +# MLX training output +mlx_metaclaw_output/ +*.egg-info/ diff --git a/INTEGRATION_NOTES.md b/INTEGRATION_NOTES.md new file mode 100644 index 00000000..ab808eb1 --- /dev/null +++ b/INTEGRATION_NOTES.md @@ -0,0 +1,80 @@ +# MLX Backend Integration Notes + +## Files to add +- `metaclaw/mlx_backend/__init__.py` +- `metaclaw/mlx_backend/data_types.py` +- `metaclaw/mlx_backend/params.py` +- `metaclaw/mlx_backend/lora.py` +- `metaclaw/mlx_backend/service_client.py` +- `tests/test_mlx_backend.py` + +## Files to replace +- `metaclaw/sdk_backend.py` (full replacement with MLX support) + +## Files to patch (small edits) + +### metaclaw/setup_wizard.py +# In metaclaw/setup_wizard.py, update line ~156: +# +# BEFORE: +# ["auto", "tinker", "mint"], +# +# AFTER: +# ["auto", "tinker", "mint", "mlx"], +# +# This adds "mlx" to the interactive backend selection menu. + + +### metaclaw/config.py +# In metaclaw/config.py, add these fields to MetaClawConfig: +# +# # MLX backend settings +# mlx_model_path: str = "" # local path or HF repo (e.g. 
mlx-community/Qwen2.5-7B-4bit) +# mlx_output_dir: str = "./mlx_metaclaw_output" +# +# Update training_backend_label() around line 168: +# +# BEFORE: +# def training_backend_label(self) -> str: +# return "MinT" if self.resolved_backend_key() == "mint" else "Tinker" +# +# AFTER: +# def training_backend_label(self) -> str: +# key = self.resolved_backend_key() +# if key == "mlx": +# return "MLX" +# return "MinT" if key == "mint" else "Tinker" +# +# Update training_backend_banner() around line 171: +# +# BEFORE: +# def training_backend_banner(self) -> str: +# return f"{self.training_backend_label()} cloud RL" +# +# AFTER: +# def training_backend_banner(self) -> str: +# label = self.training_backend_label() +# suffix = "local RL" if self.resolved_backend_key() == "mlx" else "cloud RL" +# return f"{label} {suffix}" + + +## Optional: pyproject.toml extras + +```toml +[project.optional-dependencies] +mlx = ["mlx>=0.22.0", "mlx-lm>=0.21.0", "safetensors"] +``` + +## Usage + +```bash +# Install with MLX extras +pip install -e ".[mlx]" + +# Configure +metaclaw setup # select backend → mlx + +# Or via env +export METACLAW_RL_BACKEND=mlx +metaclaw start +``` diff --git a/metaclaw/api_server.py b/metaclaw/api_server.py index e719595c..440a37d9 100644 --- a/metaclaw/api_server.py +++ b/metaclaw/api_server.py @@ -722,10 +722,21 @@ def _prompt_len(msgs): raw_system = _flatten_message_content(m.get("content")) break if raw_system: - cached_system = await asyncio.to_thread( - run_llm, - [{"role": "user", "content": raw_system}], - ) + # System prompt compression requires an external LLM API. + # When running with a local-only backend (e.g. MLX) and no + # llm_api_key configured, skip compression and use raw prompt. 
+ if self.config.llm_api_key: + try: + cached_system = await asyncio.to_thread( + run_llm, + [{"role": "user", "content": raw_system}], + ) + except Exception as exc: + logger.warning( + "[OpenClaw] system prompt compression failed, " + "using raw prompt: %s", exc, + ) + cached_system = None cached_system = (cached_system or raw_system).strip() self._write_cached_system_prompt(cached_system) @@ -953,13 +964,17 @@ async def _forward_to_backend(self, body: dict[str, Any]) -> dict[str, Any]: sampling_params = self._sdk.SamplingParams(**sp_kwargs) # Call active backend - response = await self._sampling_client.sample_async( + # include_prompt_logprobs / topk_prompt_logprobs are Tinker-specific; + # MLX (and potentially other local backends) don't support them. + sample_kwargs: dict[str, Any] = dict( prompt=model_input, num_samples=1, sampling_params=sampling_params, - include_prompt_logprobs=False, - topk_prompt_logprobs=0, ) + if backend_key != "mlx": + sample_kwargs["include_prompt_logprobs"] = False + sample_kwargs["topk_prompt_logprobs"] = 0 + response = await self._sampling_client.sample_async(**sample_kwargs) # Decode response tokens → text seq = response.sequences[0] diff --git a/metaclaw/mlx_backend/__init__.py b/metaclaw/mlx_backend/__init__.py new file mode 100644 index 00000000..e70860d2 --- /dev/null +++ b/metaclaw/mlx_backend/__init__.py @@ -0,0 +1,32 @@ +""" +MLX-native LoRA training backend for MetaClaw. + +Provides a local, zero-cloud alternative to the Tinker and MinT backends +using Apple MLX on Apple Silicon. No API key or network required. 
+""" + +from .data_types import ( + Datum, + EncodedTextChunk, + ModelInput, + SampleResponse, + SampleSequence, + TensorData, +) +from .params import AdamParams, SamplingParams +from .service_client import SamplingClient, SaveStateResult, ServiceClient, TrainingClient + +__all__ = [ + "AdamParams", + "Datum", + "EncodedTextChunk", + "ModelInput", + "SampleResponse", + "SampleSequence", + "SamplingClient", + "SamplingParams", + "SaveStateResult", + "ServiceClient", + "TensorData", + "TrainingClient", +] diff --git a/metaclaw/mlx_backend/data_types.py b/metaclaw/mlx_backend/data_types.py new file mode 100644 index 00000000..9b4e8092 --- /dev/null +++ b/metaclaw/mlx_backend/data_types.py @@ -0,0 +1,123 @@ +""" +Data types that mirror the Tinker SDK surface used by data_formatter.py +and api_server.py. + +Training path: TensorData, ModelInput.from_ints(), Datum +Inference path: EncodedTextChunk, ModelInput(chunks=...), SampleSequence, SampleResponse +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +import mlx.core as mx + + +# ------------------------------------------------------------------ # +# Training types (used by data_formatter.py) # +# ------------------------------------------------------------------ # + +@dataclass +class TensorData: + """Thin wrapper around an MLX array, convertible from PyTorch tensors.""" + + array: mx.array + + @classmethod + def from_torch(cls, tensor) -> "TensorData": + import numpy as np + arr = mx.array(tensor.detach().cpu().numpy()) + return cls(array=arr) + + def to_mlx(self) -> mx.array: + return self.array + + def __len__(self) -> int: + return self.array.shape[0] + + +@dataclass +class Datum: + """One training example in the tinker-cookbook RL convention.""" + + model_input: "ModelInput" + loss_fn_inputs: Dict[str, TensorData] = field(default_factory=dict) + + +# ------------------------------------------------------------------ # +# Inference 
types (used by api_server.py forward_to_backend) # +# ------------------------------------------------------------------ # + +@dataclass +class EncodedTextChunk: + """Mirrors tinker.EncodedTextChunk. + + api_server.py calls: + chunk = sdk.EncodedTextChunk(tokens=list(prompt_ids), type="encoded_text") + """ + tokens: List[int] + type: str = "encoded_text" + + +@dataclass +class SampleSequence: + """One generated sequence returned by SamplingClient.sample_async(). + + api_server.py reads: + seq = response.sequences[0] + seq.tokens -> list[int] + seq.logprobs -> list[float] + seq.stop_reason -> str + """ + tokens: List[int] + logprobs: List[float] + stop_reason: str = "stop" + + +@dataclass +class SampleResponse: + """Container returned by SamplingClient.sample_async(). + + api_server.py reads: response.sequences[0] + """ + sequences: List[SampleSequence] + + +# ------------------------------------------------------------------ # +# ModelInput (dual-purpose: training + inference) # +# ------------------------------------------------------------------ # + +@dataclass +class ModelInput: + """Token sequence for model consumption. 
+ + Training path (data_formatter.py): + sdk.ModelInput.from_ints(all_tokens[:-1]) + -> uses .tokens + + Inference path (api_server.py): + sdk.ModelInput(chunks=[chunk]) + -> uses .chunks[0].tokens + """ + tokens: Optional[mx.array] = None + chunks: Optional[List[EncodedTextChunk]] = None + + @classmethod + def from_ints(cls, token_ids: List[int]) -> "ModelInput": + return cls(tokens=mx.array(token_ids, dtype=mx.int32)) + + def get_token_ids(self) -> List[int]: + """Return plain list of ints regardless of how this was constructed.""" + if self.tokens is not None: + return self.tokens.tolist() + if self.chunks: + return self.chunks[0].tokens + return [] + + def __len__(self) -> int: + if self.tokens is not None: + return self.tokens.shape[0] + if self.chunks: + return len(self.chunks[0].tokens) + return 0 diff --git a/metaclaw/mlx_backend/lora.py b/metaclaw/mlx_backend/lora.py new file mode 100644 index 00000000..795843fd --- /dev/null +++ b/metaclaw/mlx_backend/lora.py @@ -0,0 +1,59 @@ +"""LoRA layer injection and weight I/O using mlx_lm's built-in tuner.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten + +logger = logging.getLogger(__name__) + + +def inject_lora( + model: nn.Module, + rank: int = 16, + alpha: float = 16.0, + num_layers: int = -1, +) -> nn.Module: + from mlx_lm.tuner.utils import linear_to_lora_layers + + lora_cfg = {"rank": rank, "alpha": alpha, "dropout": 0.0, "scale": alpha / rank} + + linear_to_lora_layers( + model, + num_layers=num_layers, + config=lora_cfg, + ) + + n_train = sum(p.size for _, p in tree_flatten(model.trainable_parameters())) + n_total = sum(p.size for _, p in tree_flatten(model.parameters())) + pct = 100 * n_train / n_total if n_total > 0 else 0 + logger.info( + "[MLX-LoRA] injected adapters (rank=%d alpha=%.1f): " + "trainable=%d / %d params (%.2f%%)", + rank, alpha, n_train, n_total, pct, + ) + return model + 
+ +def save_lora_weights(model: nn.Module, path: str | Path) -> Path: + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + out_file = path / "adapters.safetensors" + + trainable = dict(tree_flatten(model.trainable_parameters())) + mx.save_safetensors(str(out_file), trainable) + logger.info("[MLX-LoRA] saved %d tensors -> %s", len(trainable), out_file) + return out_file + + +def load_lora_weights(model: nn.Module, path: str | Path) -> nn.Module: + path = Path(path) + adapter_file = path / "adapters.safetensors" if path.is_dir() else path + + model.load_weights(str(adapter_file), strict=False) + logger.info("[MLX-LoRA] loaded adapters <- %s", adapter_file) + return model diff --git a/metaclaw/mlx_backend/params.py b/metaclaw/mlx_backend/params.py new file mode 100644 index 00000000..50c9ca16 --- /dev/null +++ b/metaclaw/mlx_backend/params.py @@ -0,0 +1,34 @@ +"""Optimizer and sampling parameter dataclasses.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, List + + +@dataclass +class AdamParams: + """Adam optimizer hyperparameters. + + trainer.py calls: sdk.AdamParams(learning_rate=config.learning_rate) + """ + learning_rate: float = 1e-4 + beta1: float = 0.9 + beta2: float = 0.999 + eps: float = 1e-8 + weight_decay: float = 0.0 + + +@dataclass +class SamplingParams: + """Generation parameters for SamplingClient. + + api_server.py calls: + sdk.SamplingParams(temperature=..., max_tokens=..., top_k=50, top_p=0.95, stop=...) + """ + temperature: float = 0.6 + top_p: float = 0.9 + max_tokens: int = 2048 + repetition_penalty: float = 1.0 + top_k: int = 50 + stop: Optional[List[str]] = None diff --git a/metaclaw/mlx_backend/service_client.py b/metaclaw/mlx_backend/service_client.py new file mode 100644 index 00000000..05256e95 --- /dev/null +++ b/metaclaw/mlx_backend/service_client.py @@ -0,0 +1,334 @@ +""" +MLX-native ServiceClient, TrainingClient, and SamplingClient. 
+ +Implements the same async interface as the ``tinker`` SDK so MetaClaw's +trainer.py and api_server.py can use it as a drop-in local backend. +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.utils import tree_flatten, tree_unflatten + +from .data_types import Datum, ModelInput, SampleResponse, SampleSequence +from .lora import inject_lora, load_lora_weights, save_lora_weights +from .params import AdamParams, SamplingParams + +logger = logging.getLogger(__name__) + + +@dataclass +class SaveStateResult: + path: str + + +class SamplingClient: + """Wraps an MLX model + tokenizer for inference. + + Supports the sample_async() interface that api_server.py calls: + response = await self.sampling_client.sample_async( + prompt=model_input, + num_samples=1, + sampling_params=sampling_params, + include_prompt_logprobs=False, + top_k_prompt_logprobs=0, + ) + seq = response.sequences[0] + seq.tokens / seq.logprobs / seq.stop_reason + """ + + def __init__( + self, + model: nn.Module, + tokenizer: Any, + adapter_path: Optional[Path] = None, + ): + self._model = model + self._tokenizer = tokenizer + self._adapter_path = adapter_path + + @property + def model(self) -> nn.Module: + return self._model + + @property + def tokenizer(self) -> Any: + return self._tokenizer + + @property + def adapter_path(self) -> Optional[Path]: + return self._adapter_path + + async def sample_async( + self, + prompt: ModelInput, + num_samples: int = 1, + sampling_params: Optional[SamplingParams] = None, + include_prompt_logprobs: bool = False, + top_k_prompt_logprobs: int = 0, + ) -> SampleResponse: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, self._sample_sync, prompt, num_samples, sampling_params + ) + + def _sample_sync( + self, + prompt: 
ModelInput, + num_samples: int, + sampling_params: Optional[SamplingParams], + ) -> SampleResponse: + from mlx_lm.generate import generate_step + from mlx_lm.sample_utils import make_sampler + + sp = sampling_params or SamplingParams() + prompt_ids = prompt.get_token_ids() + prompt_arr = mx.array(prompt_ids, dtype=mx.int32) + + sampler = make_sampler(temp=sp.temperature, top_p=sp.top_p) + + stop_strings = set(sp.stop or []) + eos_token_id = getattr(self._tokenizer, "eos_token_id", None) + + stop_token_ids = set() + if eos_token_id is not None: + stop_token_ids.add(eos_token_id) + for s in stop_strings: + ids = self._tokenizer.encode(s, add_special_tokens=False) + if len(ids) == 1: + stop_token_ids.add(ids[0]) + + generated_tokens: List[int] = [] + generated_logprobs: List[float] = [] + + for step_out, _ in zip( + generate_step( + prompt_arr, + self._model, + sampler=sampler, + ), + range(sp.max_tokens), + ): + token, logprob_val = step_out + token_id = token.item() if hasattr(token, "item") else int(token) + + if token_id in stop_token_ids: + break + + generated_tokens.append(token_id) + + # logprob_val may be a scalar, a 1-d vocab array, or a dict + if isinstance(logprob_val, dict): + lp = float(logprob_val.get("logprob", 0.0)) + elif hasattr(logprob_val, "shape") and logprob_val.ndim > 0: + lp = float(logprob_val[token_id].item() if hasattr(logprob_val[token_id], "item") else logprob_val[token_id]) + elif hasattr(logprob_val, "item"): + lp = float(logprob_val.item()) + else: + lp = float(logprob_val) + generated_logprobs.append(lp) + + if stop_strings: + partial = self._tokenizer.decode(generated_tokens) + if any(s in partial for s in stop_strings): + break + + stop_reason = "stop" + if len(generated_tokens) >= sp.max_tokens: + stop_reason = "length" + + seq = SampleSequence( + tokens=generated_tokens, + logprobs=generated_logprobs, + stop_reason=stop_reason, + ) + return SampleResponse(sequences=[seq]) + + +class TrainingClient: + def __init__( + self, + model: 
nn.Module, + tokenizer: Any, + rank: int, + base_model_name: str, + output_dir: Path, + ): + self._model = model + self._tokenizer = tokenizer + self._rank = rank + self._base_model_name = base_model_name + self._output_dir = output_dir + self._output_dir.mkdir(parents=True, exist_ok=True) + + self._optimizer: Optional[optim.Adam] = None + self._grads: Optional[list] = None + self._step_count = 0 + self._last_loss: float = 0.0 + + async def forward_backward_async( + self, data: List[Datum], loss_fn: str = "policy_gradient" + ) -> None: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._forward_backward_sync, data, loss_fn) + + def _forward_backward_sync(self, data: List[Datum], loss_fn: str) -> None: + model = self._model + + def _loss_fn(model, data): + total_loss = mx.array(0.0) + for datum in data: + input_tokens = datum.model_input.tokens[None, :] + targets = datum.loss_fn_inputs["target_tokens"].to_mlx() + advantages = datum.loss_fn_inputs["advantages"].to_mlx() + + logits = model(input_tokens).squeeze(0) + log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + token_lp = mx.take_along_axis( + log_probs, targets[:, None].astype(mx.int32), axis=-1 + ).squeeze(-1) + + if loss_fn in ("grpo", "policy_gradient", "opd"): + total_loss = total_loss + (-mx.sum(advantages * token_lp)) + else: + mask = (advantages != 0).astype(mx.float32) + total_loss = total_loss + (-mx.sum(mask * token_lp)) + + return total_loss / max(len(data), 1) + + loss_grad_fn = nn.value_and_grad(model, _loss_fn) + loss_val, grads = loss_grad_fn(model, data) + mx.eval(loss_val) + + self._grads = grads + self._last_loss = loss_val.item() + logger.info("[MLX] forward_backward done — loss=%.4f", self._last_loss) + + async def optim_step_async(self, params: AdamParams) -> None: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self._optim_step_sync, params) + + def _optim_step_sync(self, params: AdamParams) -> None: + if self._grads is None: 
+ logger.warning("[MLX] optim_step called with no gradients — skipping") + return + + if self._optimizer is None or self._optimizer.learning_rate != params.learning_rate: + self._optimizer = optim.Adam( + learning_rate=params.learning_rate, + betas=[params.beta1, params.beta2], + eps=params.eps, + ) + + self._optimizer.update(self._model, self._grads) + mx.eval(self._model.trainable_parameters()) + + self._grads = None + self._step_count += 1 + logger.info("[MLX] optim_step done — step=%d", self._step_count) + + async def save_weights_and_get_sampling_client_async( + self, name: str = "openclaw_lora" + ) -> SamplingClient: + adapter_dir = self._output_dir / name + save_lora_weights(self._model, adapter_dir) + return SamplingClient( + model=self._model, + tokenizer=self._tokenizer, + adapter_path=adapter_dir, + ) + + async def save_state_async(self, name: str = "checkpoint") -> SaveStateResult: + import json + import numpy as np + + ckpt_dir = self._output_dir / "checkpoints" / name + ckpt_dir.mkdir(parents=True, exist_ok=True) + + save_lora_weights(self._model, ckpt_dir) + + if self._optimizer is not None: + opt_state = dict(tree_flatten(self._optimizer.state)) + state_arrays = {k: np.array(v) for k, v in opt_state.items()} + np.savez(str(ckpt_dir / "optimizer_state.npz"), **state_arrays) + + meta = { + "step_count": self._step_count, + "rank": self._rank, + "base_model": self._base_model_name, + "last_loss": self._last_loss, + } + (ckpt_dir / "meta.json").write_text(json.dumps(meta, indent=2)) + + path_str = str(ckpt_dir) + logger.info("[MLX] save_state -> %s", path_str) + return SaveStateResult(path=path_str) + + async def load_state_async(self, path: str) -> None: + import json + import numpy as np + + ckpt_dir = Path(path) + load_lora_weights(self._model, ckpt_dir) + + opt_path = ckpt_dir / "optimizer_state.npz" + if opt_path.exists() and self._optimizer is not None: + npz = np.load(str(opt_path), allow_pickle=False) + state_flat = [(k, mx.array(npz[k])) for k 
in npz.files] + self._optimizer.state = tree_unflatten(state_flat) + + meta_path = ckpt_dir / "meta.json" + if meta_path.exists(): + meta = json.loads(meta_path.read_text()) + self._step_count = meta.get("step_count", 0) + self._last_loss = meta.get("last_loss", 0.0) + + logger.info("[MLX] load_state <- %s (step=%d)", path, self._step_count) + + +class ServiceClient: + def __init__( + self, + base_url: str = "", + api_key: str = "", + model_path: str = "", + output_dir: str = "", + ): + self._base_url = base_url + self._api_key = api_key + self._model_path = model_path + self._output_dir = output_dir or "./mlx_metaclaw_output" + + async def create_lora_training_client_async( + self, base_model: str, rank: int = 16 + ) -> TrainingClient: + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, self._create_sync, base_model, rank + ) + + def _create_sync(self, base_model: str, rank: int) -> TrainingClient: + from mlx_lm import load as mlx_load + + model_id = self._model_path or self._base_url or base_model + logger.info("[MLX] loading model: %s", model_id) + model, tokenizer = mlx_load(model_id) + + model.freeze() + model = inject_lora(model, rank=rank, alpha=float(rank)) + + return TrainingClient( + model=model, + tokenizer=tokenizer, + rank=rank, + base_model_name=base_model, + output_dir=Path(self._output_dir), + ) diff --git a/metaclaw/sdk_backend.py b/metaclaw/sdk_backend.py index 806efab3..b55e8c4d 100644 --- a/metaclaw/sdk_backend.py +++ b/metaclaw/sdk_backend.py @@ -1,9 +1,10 @@ """ Runtime backend selection for RL SDK clients. 
-MetaClaw can talk to either: - - ``tinker`` directly - - ``mint`` via the MindLab compatibility package +MetaClaw can talk to: +- ``tinker`` directly +- ``mint`` via the MindLab compatibility package +- ``mlx`` for local Apple Silicon training (no cloud required) """ from __future__ import annotations @@ -18,8 +19,8 @@ if TYPE_CHECKING: from .config import MetaClawConfig -_VALID_BACKENDS = {"auto", "tinker", "mint"} -_BACKEND_LABELS = {"tinker": "Tinker", "mint": "MinT"} +_VALID_BACKENDS = {"auto", "tinker", "mint", "mlx"} +_BACKEND_LABELS = {"tinker": "Tinker", "mint": "MinT", "mlx": "MLX (local)"} @dataclass(frozen=True) @@ -33,6 +34,8 @@ class SDKBackend: @property def banner(self) -> str: + if self.key == "mlx": + return f"{self.label} local RL" return f"{self.label} cloud RL" @@ -120,14 +123,31 @@ def _has_mint_signal(config: "MetaClawConfig") -> bool: def infer_backend_key(config: "MetaClawConfig") -> str: configured = configured_backend_name(config) - if configured in {"tinker", "mint"}: + if configured in {"tinker", "mint", "mlx"}: return configured + + # auto mode: check for MinT signals first if _has_mint_signal(config) and _module_available("mint"): return "mint" + + # Then check for cloud credentials (Tinker or MinT env vars) + api_key = configured_api_key(config) + base_url = configured_base_url(config) + cloud_env = _first_env("TINKER_API_KEY", "MINT_API_KEY") + + if api_key or base_url or cloud_env: + return "tinker" + + # No cloud credentials at all — fall back to MLX if available + if _module_available("mlx") and _module_available("mlx_lm"): + return "mlx" + return "tinker" def resolve_api_key(config: "MetaClawConfig", backend_key: str | None = None) -> str: + if backend_key == "mlx": + return "" configured = configured_api_key(config) if configured: return configured @@ -136,6 +156,8 @@ def resolve_api_key(config: "MetaClawConfig", backend_key: str | None = None) -> def resolve_base_url(config: "MetaClawConfig", backend_key: str | None = None) -> 
str: + if backend_key == "mlx": + return "" configured = configured_base_url(config) if configured: return configured @@ -144,6 +166,14 @@ def resolve_base_url(config: "MetaClawConfig", backend_key: str | None = None) - def _import_backend_module(backend_key: str, configured_backend: str): + if backend_key == "mlx": + if not _module_available("mlx") or not _module_available("mlx_lm"): + raise RuntimeError( + "rl.backend='mlx' requires mlx and mlx-lm. " + "Install with: pip install mlx mlx-lm" + ) + return importlib.import_module("metaclaw.mlx_backend") + if backend_key == "mint" and configured_backend == "mint" and not _module_available("mint"): raise RuntimeError( "rl.backend=mint requires the MinT compatibility package. " diff --git a/metaclaw/setup_wizard.py b/metaclaw/setup_wizard.py index cb3c29aa..a23cd95a 100644 --- a/metaclaw/setup_wizard.py +++ b/metaclaw/setup_wizard.py @@ -79,6 +79,15 @@ def _prompt_choice(msg: str, choices: list[str], default: str = "") -> str: print(f" Invalid choice. 
Pick one of: {choices}") +def _check_mlx_available() -> bool: + """Return True if mlx and mlx_lm are importable.""" + import importlib.util + return ( + importlib.util.find_spec("mlx") is not None + and importlib.util.find_spec("mlx_lm") is not None + ) + + class SetupWizard: """Interactive configuration wizard.""" @@ -109,6 +118,7 @@ def run(self): ["kimi", "qwen", "openai", "minimax", "custom"], default=current_provider, ) + preset = _PROVIDER_PRESETS[provider] api_base = _prompt( "API base URL", @@ -151,32 +161,63 @@ def run(self): if rl_enabled: print("\n--- RL Training Configuration ---") + print( + " auto — detect from credentials (default)\n" + " tinker — Tinker cloud RL\n" + " mint — MinT / MindLab cloud RL\n" + " mlx — local Apple Silicon (no API key needed)" + ) backend = _prompt_choice( "RL backend", - ["auto", "tinker", "mint"], + ["auto", "tinker", "mint", "mlx"], default=str(rl_config.get("backend", "auto") or "auto"), ) - rl_model = _prompt( - "Base model for RL training", - default=rl_config.get("model") or model_id, - ) - backend_api_key = _prompt( - "RL backend API key", - default=( - rl_config.get("api_key") - or rl_config.get("tinker_api_key", "") - ), - hide=True, - ) - backend_base_url = _prompt( - "RL backend base URL (optional)", - default=( - rl_config.get("base_url") - or rl_config.get("tinker_base_url") - or os.environ.get("TINKER_BASE_URL", "") - or os.environ.get("MINT_BASE_URL", "") - ), - ) + + # -- MLX: validate deps, skip cloud credentials -------- + if backend == "mlx": + if _check_mlx_available(): + print(" \u2713 mlx and mlx-lm are installed") + else: + print( + " \u2717 MLX backend requires mlx and mlx-lm.\n" + " Install with: pip install mlx mlx-lm\n" + " You can finish setup now and install them before 'metaclaw start'." 
+ ) + + rl_model = _prompt( + "Base model for RL training (HuggingFace MLX model ID)", + default=( + rl_config.get("model") + or "mlx-community/Qwen2.5-0.5B-Instruct-4bit" + ), + ) + backend_api_key = "" + backend_base_url = "" + + # -- Cloud backends: prompt for credentials ------------- + else: + rl_model = _prompt( + "Base model for RL training", + default=rl_config.get("model") or model_id, + ) + backend_api_key = _prompt( + "RL backend API key", + default=( + rl_config.get("api_key") + or rl_config.get("tinker_api_key", "") + ), + hide=True, + ) + backend_base_url = _prompt( + "RL backend base URL (optional)", + default=( + rl_config.get("base_url") + or rl_config.get("tinker_base_url") + or os.environ.get("TINKER_BASE_URL", "") + or os.environ.get("MINT_BASE_URL", "") + ), + ) + prm_url = _prompt( "PRM (reward model) URL", default=rl_config.get("prm_url", "https://api.openai.com/v1"), @@ -190,6 +231,7 @@ def run(self): default=rl_config.get("prm_api_key", ""), hide=True, ) + lora_rank = _prompt_int("LoRA rank", default=rl_config.get("lora_rank", 32)) resume_from_ckpt = _prompt( "Resume from checkpoint path (optional)", @@ -255,14 +297,13 @@ def run(self): "Enable smart update scheduler", default=bool(current_sched.get("enabled", False)), ) - if sched_enabled: sleep_start = _prompt( "Sleep start time (HH:MM, 24h)", default=current_sched.get("sleep_start", "23:00"), ) sleep_end = _prompt( - "Sleep end time (HH:MM, 24h)", + "Sleep end time (HH:MM, 24h)", default=current_sched.get("sleep_end", "07:00"), ) idle_mins = _prompt_int( @@ -273,7 +314,6 @@ def run(self): "Minimum window required for one RL step (minutes)", default=current_sched.get("min_window_minutes", 15), ) - use_calendar = _prompt_bool( "Use Google Calendar to detect meeting times (optional)", default=bool(current_sched_cal.get("enabled", False)), @@ -344,7 +384,6 @@ def run(self): "rl": rl_config, "scheduler": scheduler_config, } - cs.save(data) Path(skills_dir).expanduser().mkdir(parents=True, 
exist_ok=True) diff --git a/tests/smoke_mlx_proxy.py b/tests/smoke_mlx_proxy.py new file mode 100644 index 00000000..35b4ba4f --- /dev/null +++ b/tests/smoke_mlx_proxy.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python3 +""" +Smoke test: spin up the MetaClaw proxy with the MLX backend and send +fake OpenClaw chat requests via HTTP. Verifies: + + 1. sdk_backend auto-detects MLX when no cloud credentials are set + 2. api_server.py builds EncodedTextChunk / ModelInput / SamplingParams + from our mlx_backend module and calls sample_async() + 3. Proxy returns valid OpenAI-compatible chat completions with logprobs + 4. Training samples flow through the output queue + 5. A full train_on_batch() step completes with hot-swapped weights + +Usage: + # Make sure no TINKER_API_KEY / MINT_API_KEY are set + unset TINKER_API_KEY MINT_API_KEY TINKER_BASE_URL MINT_BASE_URL + python3 tests/smoke_mlx_proxy.py +""" + +import asyncio +import json +import os +import sys +import time + +# ── Ensure no cloud credentials so auto-detect picks MLX ────────── +for var in ("TINKER_API_KEY", "MINT_API_KEY", "TINKER_BASE_URL", "MINT_BASE_URL"): + os.environ.pop(var, None) + +PASS = "\033[32m✓\033[0m" +FAIL = "\033[31m✗\033[0m" +passed = failed = 0 + + +def check(name, condition, detail=""): + global passed, failed + if condition: + print(f" {PASS} {name}" + (f" ({detail})" if detail else "")) + passed += 1 + else: + print(f" {FAIL} {name}" + (f" ({detail})" if detail else "")) + failed += 1 + + +TEST_MODEL = os.environ.get("MLX_TEST_MODEL", "mlx-community/Qwen2.5-0.5B-Instruct-4bit") +PROXY_PORT = int(os.environ.get("MLX_TEST_PORT", "18899")) + + +def _seed_system_prompt_cache(record_dir: str): + """Pre-populate the system prompt cache so _handle_request never calls + run_llm (which requires an external LLM API key we don't have in the + MLX-only smoke test). This mirrors the cache format written by + api_server._write_cached_system_prompt(). 
+ """ + os.makedirs(record_dir, exist_ok=True) + cache_path = os.path.join(record_dir, "system_prompt_cache.json") + with open(cache_path, "w", encoding="utf-8") as f: + json.dump( + {"compressed_system_prompt": "You are a helpful assistant."}, + f, + ensure_ascii=False, + ) + + +async def main(): + # ── 1. Backend auto-detection ────────────────────────────────── + print(f"\n{'='*60}") + print(f" BACKEND AUTO-DETECTION") + print(f"{'='*60}") + + from types import SimpleNamespace + config = SimpleNamespace( + backend="auto", + api_key="", + tinker_api_key="", + base_url="", + tinker_base_url="", + ) + from metaclaw.sdk_backend import infer_backend_key + detected = infer_backend_key(config) + check("auto-detect picks MLX", detected == "mlx", f"got {detected!r}") + + from metaclaw.sdk_backend import resolve_sdk_backend + backend = resolve_sdk_backend(config) + check("resolve_sdk_backend", backend.key == "mlx" and backend.module is not None, backend.label) + + sdk = backend.module + check("sdk exports EncodedTextChunk", hasattr(sdk, "EncodedTextChunk")) + check("sdk exports SamplingParams", hasattr(sdk, "SamplingParams")) + check("sdk exports ModelInput", hasattr(sdk, "ModelInput")) + + # ── 2. 
Proxy startup ────────────────────────────────────────── + print(f"\n{'='*60}") + print(f" PROXY STARTUP") + print(f"{'='*60}") + + import queue + import threading + from metaclaw.config import MetaClawConfig + + record_dir = os.path.join(os.path.dirname(__file__), ".smoke_records") + + proxy_config = MetaClawConfig( + backend="mlx", + model_name=TEST_MODEL, + lora_rank=16, + proxy_host="127.0.0.1", + proxy_port=PROXY_PORT, + mode="rl", + use_prm=False, + record_enabled=False, + served_model_name="test-qwen", + record_dir=record_dir, + ) + + # Seed the system prompt cache so run_llm is never called + _seed_system_prompt_cache(record_dir) + + # Create training client + sampling client + t0 = time.time() + service_client = sdk.ServiceClient() + tc = await service_client.create_lora_training_client_async( + base_model=TEST_MODEL, rank=16, + ) + sc = await tc.save_weights_and_get_sampling_client_async() + check("Model + LoRA loaded", sc is not None, f"{time.time()-t0:.1f}s") + + # Spin up api_server + from metaclaw.api_server import MetaClawAPIServer + + output_queue = queue.Queue(maxsize=10000) + submission_enabled = threading.Event() + submission_enabled.set() + + server = MetaClawAPIServer( + config=proxy_config, + output_queue=output_queue, + submission_enabled=submission_enabled, + sampling_client=sc, + ) + server.start() + + # Wait for proxy to be ready + import httpx + deadline = time.time() + 15 + ready = False + while time.time() < deadline: + try: + async with httpx.AsyncClient(timeout=1.0) as client: + r = await client.get(f"http://127.0.0.1:{PROXY_PORT}/healthz") + if r.status_code == 200: + ready = True + break + except Exception: + pass + await asyncio.sleep(0.3) + + check("Proxy /healthz", ready) + if not ready: + print(" !! Proxy did not start — aborting") + server.stop() + return + + # ── 3. 
Chat completions ─────────────────────────────────────── + print(f"\n{'='*60}") + print(f" CHAT COMPLETIONS") + print(f"{'='*60}") + + sessions = [ + ("smoke-sess-1", "What is 2+2?"), + ("smoke-sess-2", "Name three colors."), + ("smoke-sess-3", "Say hello."), + ] + + for sid, prompt in sessions: + body = { + "model": "test-qwen", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + "max_tokens": 64, + "temperature": 0.7, + } + try: + async with httpx.AsyncClient(timeout=30.0) as client: + t0 = time.time() + resp = await client.post( + f"http://127.0.0.1:{PROXY_PORT}/v1/chat/completions", + json=body, + headers={ + "X-Session-Id": sid, + "X-Turn-Type": "main", + "X-Session-Done": "true", + }, + ) + elapsed = time.time() - t0 + data = resp.json() + + has_response = "response" in data + if has_response: + data = data["response"] + + content = data["choices"][0]["message"]["content"] + logprobs = data["choices"][0].get("logprobs", {}).get("content", []) + + check( + f"Session {sid}", + resp.status_code == 200 and len(content) > 0, + f"{len(content)} chars, {len(logprobs)} logprobs, {elapsed:.2f}s", + ) + except Exception as e: + check(f"Session {sid}", False, str(e)) + + # ── 4. 
Training samples queued ──────────────────────────────── + print(f"\n{'='*60}") + print(f" TRAINING SAMPLE COLLECTION") + print(f"{'='*60}") + + await asyncio.sleep(1.0) # let background tasks finish + + samples = [] + while not output_queue.empty(): + try: + samples.append(output_queue.get_nowait()) + except queue.Empty: + break + + check( + "Samples in output queue", + len(samples) > 0, + f"{len(samples)} sample groups", + ) + + if samples: + _, first_sample = samples[0] + if isinstance(first_sample, list): + first_sample = first_sample[0] + check( + "Sample has prompt + response tokens", + len(first_sample.prompt_tokens) > 0 and len(first_sample.response_tokens) > 0, + f"prompt={len(first_sample.prompt_tokens)} response={len(first_sample.response_tokens)}", + ) + + # ── 5. Training step with hot-swap ──────────────────────────── + print(f"\n{'='*60}") + print(f" TRAINING STEP + HOT-SWAP") + print(f"{'='*60}") + + if samples: + from metaclaw.data_formatter import batch_to_datums, compute_advantages, ConversationSample + + batch = [] + for _, sample in samples: + if isinstance(sample, list): + batch.extend(sample) + elif isinstance(sample, ConversationSample): + batch.append(sample) + + if batch: + advantages = compute_advantages(batch) + datums = batch_to_datums(batch, advantages, sdk=sdk) + check("Built datums from proxy samples", len(datums) > 0, f"{len(datums)} datums") + + t0 = time.time() + await tc.forward_backward_async(datums, loss_fn="grpo") + await tc.optim_step_async(sdk.AdamParams(learning_rate=1e-4)) + check("Training step", True, f"{time.time()-t0:.2f}s") + + new_sc = await tc.save_weights_and_get_sampling_client_async() + server.update_sampling_client(new_sc) + check("Hot-swap sampling client", server._sampling_client is new_sc) + + # Verify inference still works after swap + body["messages"][1]["content"] = "Quick test after weight update" + try: + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + 
f"http://127.0.0.1:{PROXY_PORT}/v1/chat/completions", + json=body, + headers={ + "X-Session-Id": "smoke-post-swap", + "X-Turn-Type": "main", + "X-Session-Done": "true", + }, + ) + data = resp.json() + if "response" in data: + data = data["response"] + post_swap_content = data["choices"][0]["message"]["content"] + check( + "Inference after hot-swap", + resp.status_code == 200 and len(post_swap_content) > 0, + f"{len(post_swap_content)} chars", + ) + except Exception as e: + check("Inference after hot-swap", False, str(e)) + else: + check("Built datums from proxy samples", False, "no ConversationSamples found") + else: + print(" (skipping training — no samples collected)") + + # ── Cleanup ─────────────────────────────────────────────────── + server.stop() + import shutil + shutil.rmtree(record_dir, ignore_errors=True) + + print(f"\n{'='*60}") + print(f" {passed} passed, {failed} failed") + print(f"{'='*60}") + sys.exit(1 if failed else 0) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_mlx_backend.py b/tests/test_mlx_backend.py new file mode 100644 index 00000000..3f1d0cb2 --- /dev/null +++ b/tests/test_mlx_backend.py @@ -0,0 +1,237 @@ +"""Unit tests for the MLX backend integration into sdk_backend.py. + +Mirrors the patterns in test_sdk_backend.py. All tests mock external +modules so they run on any platform (no Apple Silicon required). + +Requires: the modified metaclaw/sdk_backend.py with MLX support. 
+""" + +from __future__ import annotations + +import types +import pytest +from unittest.mock import MagicMock + +from metaclaw.sdk_backend import ( + SDKBackend, + _normalize_backend_name, + infer_backend_key, + resolve_api_key, + resolve_base_url, + resolve_sdk_backend, +) + + +# ------------------------------------------------------------------ # +# Helpers (same pattern as test_sdk_backend.py) # +# ------------------------------------------------------------------ # + +def _fake_find_spec_factory(*available): + """Return a find_spec that reports *available* modules as importable.""" + def _find_spec(name): + return MagicMock() if name in available else None + return _find_spec + + +def _cfg(**overrides): + """Build a minimal config namespace with sensible defaults.""" + defaults = dict( + backend="auto", + api_key="", + base_url="", + tinker_api_key="", + tinker_base_url="", + ) + defaults.update(overrides) + return types.SimpleNamespace(**defaults) + + +# ------------------------------------------------------------------ # +# Validation # +# ------------------------------------------------------------------ # + +def test_mlx_is_valid_backend_name(): + assert _normalize_backend_name("mlx") == "mlx" + + +def test_mlx_is_case_insensitive(): + assert _normalize_backend_name("MLX") == "mlx" + assert _normalize_backend_name(" Mlx ") == "mlx" + + +# ------------------------------------------------------------------ # +# Explicit backend="mlx" # +# ------------------------------------------------------------------ # + +def test_resolve_sdk_backend_explicit_mlx(monkeypatch): + mlx_module = types.SimpleNamespace(__name__="metaclaw.mlx_backend") + + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.util.find_spec", + _fake_find_spec_factory("mlx", "mlx_lm"), + ) + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.import_module", + lambda name: mlx_module if name == "metaclaw.mlx_backend" else None, + ) + + backend = resolve_sdk_backend(_cfg(backend="mlx")) + + 
assert backend.key == "mlx" + assert backend.label == "MLX" + assert backend.import_name == "metaclaw.mlx_backend" + assert backend.module is mlx_module + assert backend.api_key == "" + assert backend.base_url == "" + assert backend.banner == "MLX local RL" + + +def test_explicit_mlx_ignores_api_key(): + """MLX is local — api_key should always resolve to empty.""" + assert resolve_api_key(_cfg(backend="mlx", api_key="sk-should-ignore"), "mlx") == "" + + +def test_explicit_mlx_passes_base_url_as_model_path(): + """base_url can carry the MLX model path when set explicitly.""" + url = resolve_base_url( + _cfg(backend="mlx", base_url="mlx-community/Qwen2.5-7B-4bit"), "mlx" + ) + assert url == "mlx-community/Qwen2.5-7B-4bit" + + +# ------------------------------------------------------------------ # +# Auto-detection # +# ------------------------------------------------------------------ # + +def test_auto_does_not_select_mlx_without_signal(monkeypatch): + """Without env or config signal, auto should NOT pick MLX.""" + monkeypatch.delenv("METACLAW_RL_BACKEND", raising=False) + monkeypatch.delenv("MINT_API_KEY", raising=False) + monkeypatch.delenv("MINT_BASE_URL", raising=False) + monkeypatch.delenv("TINKER_API_KEY", raising=False) + monkeypatch.delenv("TINKER_BASE_URL", raising=False) + + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.util.find_spec", + _fake_find_spec_factory("mlx", "mlx_lm", "tinker"), + ) + + assert infer_backend_key(_cfg()) == "tinker" + + +def test_auto_selects_mlx_when_env_set(monkeypatch): + """METACLAW_RL_BACKEND=mlx should trigger MLX selection in auto mode.""" + monkeypatch.setenv("METACLAW_RL_BACKEND", "mlx") + monkeypatch.delenv("MINT_API_KEY", raising=False) + monkeypatch.delenv("MINT_BASE_URL", raising=False) + monkeypatch.delenv("TINKER_API_KEY", raising=False) + monkeypatch.delenv("TINKER_BASE_URL", raising=False) + + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.util.find_spec", + _fake_find_spec_factory("mlx", 
"mlx_lm"), + ) + + assert infer_backend_key(_cfg()) == "mlx" + + +def test_auto_skips_mlx_when_mlx_lm_missing(monkeypatch): + """If mlx is importable but mlx_lm is not, auto should skip MLX.""" + monkeypatch.setenv("METACLAW_RL_BACKEND", "mlx") + monkeypatch.delenv("MINT_API_KEY", raising=False) + monkeypatch.delenv("MINT_BASE_URL", raising=False) + monkeypatch.delenv("TINKER_API_KEY", raising=False) + monkeypatch.delenv("TINKER_BASE_URL", raising=False) + + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.util.find_spec", + _fake_find_spec_factory("mlx", "tinker"), # mlx_lm missing + ) + + # Should fall through to tinker since _mlx_available() is False + assert infer_backend_key(_cfg()) == "tinker" + + +# ------------------------------------------------------------------ # +# Error handling # +# ------------------------------------------------------------------ # + +def test_explicit_mlx_missing_deps_raises(monkeypatch): + """backend=mlx without mlx/mlx_lm installed should raise RuntimeError.""" + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.util.find_spec", + _fake_find_spec_factory("tinker"), # no mlx at all + ) + + with pytest.raises(RuntimeError, match="Apple Silicon"): + resolve_sdk_backend(_cfg(backend="mlx")) + + +# ------------------------------------------------------------------ # +# Priority: explicit > mlx > mint > tinker # +# ------------------------------------------------------------------ # + +def test_explicit_mint_wins_over_mlx_env(monkeypatch): + """Explicit backend=mint should override METACLAW_RL_BACKEND=mlx.""" + monkeypatch.setenv("METACLAW_RL_BACKEND", "mlx") + + assert infer_backend_key(_cfg(backend="mint")) == "mint" + + +def test_explicit_tinker_wins_over_mlx_env(monkeypatch): + """Explicit backend=tinker should override METACLAW_RL_BACKEND=mlx.""" + monkeypatch.setenv("METACLAW_RL_BACKEND", "mlx") + + assert infer_backend_key(_cfg(backend="tinker")) == "tinker" + + +# 
------------------------------------------------------------------ # +# Existing backends still work # +# ------------------------------------------------------------------ # + +def test_tinker_still_resolves(monkeypatch): + tinker_module = types.SimpleNamespace(__name__="tinker") + monkeypatch.delenv("METACLAW_RL_BACKEND", raising=False) + monkeypatch.delenv("MINT_API_KEY", raising=False) + monkeypatch.delenv("MINT_BASE_URL", raising=False) + + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.util.find_spec", + _fake_find_spec_factory("tinker"), + ) + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.import_module", + lambda name: tinker_module if name == "tinker" else None, + ) + + backend = resolve_sdk_backend( + _cfg(backend="tinker", api_key="sk-tinker-123", base_url="https://api.tinker.example/v1") + ) + + assert backend.key == "tinker" + assert backend.module is tinker_module + assert backend.api_key == "sk-tinker-123" + assert backend.banner == "Tinker cloud RL" + + +def test_mint_still_resolves(monkeypatch): + mint_module = types.SimpleNamespace(__name__="mint") + monkeypatch.delenv("METACLAW_RL_BACKEND", raising=False) + + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.util.find_spec", + _fake_find_spec_factory("mint"), + ) + monkeypatch.setattr( + "metaclaw.sdk_backend.importlib.import_module", + lambda name: mint_module if name == "mint" else None, + ) + + backend = resolve_sdk_backend( + _cfg(backend="mint", api_key="sk-mint-123", base_url="https://mint.macaron.xin/") + ) + + assert backend.key == "mint" + assert backend.module is mint_module + assert backend.api_key == "sk-mint-123" + assert backend.banner == "MinT cloud RL" diff --git a/tests/test_mlx_integration.py b/tests/test_mlx_integration.py new file mode 100644 index 00000000..918c28ec --- /dev/null +++ b/tests/test_mlx_integration.py @@ -0,0 +1,423 @@ +""" +Integration tests for the MLX LoRA training backend. 
+ +Runs the real MLX backend against a small model to verify the full +trainer.py contract: ServiceClient → TrainingClient → SamplingClient. + +Skipped automatically on machines without Apple Silicon / MLX. + +Run via pytest: + pytest tests/test_mlx_integration.py -v + pytest tests/test_mlx_integration.py -v -k smoke + pytest tests/test_mlx_integration.py -v -k training + pytest tests/test_mlx_integration.py -v -k e2e + +Run standalone: + python tests/test_mlx_integration.py smoke + python tests/test_mlx_integration.py training + python tests/test_mlx_integration.py e2e + python tests/test_mlx_integration.py all + python tests/test_mlx_integration.py all --model mlx-community/Llama-3.2-3B-Instruct-4bit +""" + +from __future__ import annotations + +import asyncio +import importlib.util +import sys +import time +from pathlib import Path + +import pytest + +pytestmark = pytest.mark.skipif( + not (importlib.util.find_spec("mlx") and importlib.util.find_spec("mlx_lm")), + reason="requires Apple Silicon with mlx and mlx_lm installed", +) + +TEST_MODEL = "mlx-community/Qwen2.5-0.5B-Instruct-4bit" +OUTPUT_DIR = "/tmp/mlx_metaclaw_tests" + + +# ------------------------------------------------------------------ # +# Fixtures # +# ------------------------------------------------------------------ # + +@pytest.fixture(scope="module") +def training_client(): + """Create a TrainingClient once for the whole module.""" + from metaclaw.mlx_backend import ServiceClient + + async def _create(): + client = ServiceClient(output_dir=OUTPUT_DIR) + return await client.create_lora_training_client_async( + base_model=TEST_MODEL, rank=8 + ) + + return asyncio.get_event_loop().run_until_complete(_create()) + + +@pytest.fixture(scope="module") +def tokenizer(training_client): + return training_client._tokenizer + + +# ------------------------------------------------------------------ # +# Helpers # +# ------------------------------------------------------------------ # + +def 
_make_datums(tokenizer, sdk, prompts_and_responses): + """Build Datum objects the same way data_formatter.py does.""" + import torch + + datums = [] + for prompt, response, advantage in prompts_and_responses: + p_ids = tokenizer.encode(prompt) + r_ids = tokenizer.encode(response) + all_ids = p_ids + r_ids + T = len(all_ids) - 1 + if T <= 0: + continue + + advantages = [0.0] * (len(p_ids) - 1) + [advantage] * len(r_ids) + logprobs = [0.0] * (len(p_ids) - 1) + [-0.5] * len(r_ids) + + def _fit(lst, length): + return (lst[:length] + [0.0] * max(0, length - len(lst)))[:length] + + datums.append(sdk.Datum( + model_input=sdk.ModelInput.from_ints(all_ids[:-1]), + loss_fn_inputs={ + "target_tokens": sdk.TensorData.from_torch( + torch.tensor(all_ids[1:], dtype=torch.long) + ), + "logprobs": sdk.TensorData.from_torch( + torch.tensor(_fit(logprobs, T), dtype=torch.float32) + ), + "advantages": sdk.TensorData.from_torch( + torch.tensor(_fit(advantages, T), dtype=torch.float32) + ), + }, + )) + return datums + + +def _make_batch(tokenizer): + """Build ConversationSamples as the rollout worker would.""" + from metaclaw.data_formatter import ConversationSample + + samples = [] + for i, (q, a, reward) in enumerate([ + ("What is 2+2?", "4", 1.0), + ("Capital of France?", "Paris", 1.0), + ("Solve x^2=9", "x=2", -1.0), + ("Hello", "Hi there!", 0.0), + ]): + p_ids = tokenizer.encode(q) + r_ids = tokenizer.encode(a) + samples.append(ConversationSample( + session_id=f"test_{i}", + turn_num=0, + prompt_tokens=p_ids, + response_tokens=r_ids, + response_logprobs=[-0.5] * len(r_ids), + loss_mask=[1] * len(r_ids), + reward=reward, + )) + return samples + + +# ------------------------------------------------------------------ # +# Smoke tests # +# ------------------------------------------------------------------ # + +class TestSmoke: + def test_model_loads(self, training_client): + assert training_client._model is not None + assert training_client._tokenizer is not None + + def 
test_lora_injected(self, training_client): + + from mlx.utils import tree_flatten + n_train = sum(p.size for _, p in tree_flatten(training_client._model.trainable_parameters())) + assert n_train > 0, "No trainable (LoRA) parameters found" + + def test_forward_pass(self, training_client): + import mlx.core as mx + + tokens = mx.array([[1, 2, 3, 4, 5]], dtype=mx.int32) + logits = training_client._model(tokens) + assert len(logits.shape) == 3 + + def test_initial_sampling_client(self, training_client): + sc = asyncio.get_event_loop().run_until_complete( + training_client.save_weights_and_get_sampling_client_async() + ) + assert sc.model is not None + assert sc.tokenizer is not None + assert sc.adapter_path is not None + assert (sc.adapter_path / "adapters.safetensors").exists() + + +# ------------------------------------------------------------------ # +# Training step tests # +# ------------------------------------------------------------------ # + +class TestTraining: + def test_forward_backward(self, training_client, tokenizer): + import metaclaw.mlx_backend as sdk + + datums = _make_datums(tokenizer, sdk, [ + ("What is 2+2?", "4", 1.0), + ("Capital of France?", "London", -1.0), + ]) + assert len(datums) == 2 + + asyncio.get_event_loop().run_until_complete( + training_client.forward_backward_async(datums, loss_fn="policy_gradient") + ) + assert training_client._grads is not None + assert len(training_client._grads) > 0 + + def test_optim_step(self, training_client): + from metaclaw.mlx_backend import AdamParams + + asyncio.get_event_loop().run_until_complete( + training_client.optim_step_async(AdamParams(learning_rate=1e-4)) + ) + assert training_client._grads is None + assert training_client._step_count >= 1 + + def test_save_weights_returns_sampling_client(self, training_client): + sc = asyncio.get_event_loop().run_until_complete( + training_client.save_weights_and_get_sampling_client_async( + name="test_adapter" + ) + ) + assert sc.adapter_path is not None + 
assert (sc.adapter_path / "adapters.safetensors").exists() + + def test_checkpoint_roundtrip(self, training_client): + async def _run(): + result = await training_client.save_state_async(name="test_ckpt") + ckpt = Path(result.path) + assert (ckpt / "adapters.safetensors").exists() + assert (ckpt / "optimizer_state.npz").exists() + assert (ckpt / "meta.json").exists() + + step_before = training_client._step_count + await training_client.load_state_async(result.path) + assert training_client._step_count == step_before + + asyncio.get_event_loop().run_until_complete(_run()) + + +# ------------------------------------------------------------------ # +# End-to-end: uses real data_formatter.py like trainer.py does # +# ------------------------------------------------------------------ # + +class TestEndToEnd: + def test_three_step_training_loop(self, training_client, tokenizer): + """Simulate trainer.py _train_on_batch for 3 steps.""" + import metaclaw.mlx_backend as sdk + from metaclaw.data_formatter import batch_to_datums, compute_advantages + + batch = _make_batch(tokenizer) + + async def _run(): + sampling_client = None + for step in range(1, 4): + advantages = compute_advantages(batch) + datums = batch_to_datums(batch, advantages, sdk=sdk) + assert len(datums) > 0, f"Step {step}: no datums" + + await training_client.forward_backward_async(datums, loss_fn="grpo") + await training_client.optim_step_async( + sdk.AdamParams(learning_rate=1e-4) + ) + sampling_client = ( + await training_client.save_weights_and_get_sampling_client_async( + name="openclaw_lora" + ) + ) + + assert sampling_client is not None + assert sampling_client.adapter_path is not None + assert (sampling_client.adapter_path / "adapters.safetensors").exists() + + result = await training_client.save_state_async(name="e2e_final") + assert Path(result.path).exists() + + asyncio.get_event_loop().run_until_complete(_run()) + + +# ------------------------------------------------------------------ # +# 
Standalone CLI runner # +# ------------------------------------------------------------------ # + +PASS = "\033[32m✓\033[0m" +FAIL = "\033[31m✗\033[0m" + + +def _cli_run(model_id: str = TEST_MODEL): + """Run all tests outside pytest with human-readable output.""" + import metaclaw.mlx_backend as sdk + from metaclaw.mlx_backend import ServiceClient, AdamParams + from metaclaw.data_formatter import batch_to_datums, compute_advantages + import mlx.core as mx + + passed, failed = 0, 0 + + def _check(name: str, condition: bool, detail: str = ""): + nonlocal passed, failed + if condition: + print(f" {PASS} {name}" + (f" ({detail})" if detail else "")) + passed += 1 + else: + print(f" {FAIL} {name}" + (f" ({detail})" if detail else "")) + failed += 1 + + async def _run(): + # ---- Smoke ------------------------------------------------ + print(f"\n{'='*60}") + print(f" SMOKE — {model_id}") + print(f"{'='*60}") + + t0 = time.time() + client = ServiceClient(output_dir=OUTPUT_DIR) + tc = await client.create_lora_training_client_async( + base_model=model_id, rank=8 + ) + _check("Model loaded", tc._model is not None, f"{time.time()-t0:.1f}s") + + from mlx.utils import tree_flatten as tf + n_train = sum(p.size for _, p in tf(tc._model.trainable_parameters())) + n_total = sum(p.size for _, p in tf(tc._model.parameters())) + _check("LoRA injected", n_train > 0, f"{n_train:,} trainable / {n_total:,} total params") + + tokens = mx.array([[1, 2, 3, 4, 5]], dtype=mx.int32) + logits = tc._model(tokens) + _check("Forward pass", len(logits.shape) == 3, f"shape={logits.shape}") + + sc = await tc.save_weights_and_get_sampling_client_async() + _check("Initial SamplingClient", sc.adapter_path is not None) + + # ---- Training --------------------------------------------- + print(f"\n{'='*60}") + print(f" TRAINING") + print(f"{'='*60}") + + datums = _make_datums(tc._tokenizer, sdk, [ + ("What is 2+2?", "4", 1.0), + ("Capital of France?", "London", -1.0), + ]) + _check("Built datums", 
len(datums) == 2) + + t0 = time.time() + await tc.forward_backward_async(datums, loss_fn="policy_gradient") + _check("forward_backward_async", tc._grads is not None, f"{time.time()-t0:.2f}s") + + t0 = time.time() + await tc.optim_step_async(AdamParams(learning_rate=1e-4)) + _check("optim_step_async", tc._grads is None, f"{time.time()-t0:.2f}s") + + sc = await tc.save_weights_and_get_sampling_client_async(name="test_adapter") + _check("save_weights", (sc.adapter_path / "adapters.safetensors").exists()) + + result = await tc.save_state_async(name="test_ckpt") + ckpt = Path(result.path) + _check("save_state", all( + (ckpt / f).exists() for f in ("adapters.safetensors", "optimizer_state.npz", "meta.json") + ), result.path) + + step_before = tc._step_count + await tc.load_state_async(result.path) + _check("load_state roundtrip", tc._step_count == step_before) + + + # ---- INFERENCE (sample_async) -------------------------------- + print("\n" + "=" * 60) + print(f" INFERENCE (sample_async)") + print("=" * 60) + + from metaclaw.mlx_backend import EncodedTextChunk, SamplingParams as SP + from metaclaw.mlx_backend import ModelInput as MI + + prompt_text = "Hello, how are you?" 
+ prompt_ids = list(tc._tokenizer.encode(prompt_text, add_special_tokens=False)) + chunk = EncodedTextChunk(tokens=prompt_ids, type="encoded_text") + model_input = MI(chunks=[chunk]) + + _check("ModelInput.get_token_ids()", model_input.get_token_ids() == prompt_ids, f"{len(prompt_ids)} tokens") + + sampling_params = SP(temperature=0.7, max_tokens=32, top_k=50, top_p=0.95) + + t0 = time.time() + response = await sc.sample_async( + prompt=model_input, + num_samples=1, + sampling_params=sampling_params, + include_prompt_logprobs=False, + top_k_prompt_logprobs=0, + ) + inf_time = time.time() - t0 + + seq = response.sequences[0] + _check( + "sample_async", + len(seq.tokens) > 0 and len(seq.logprobs) == len(seq.tokens), + f"{len(seq.tokens)} tokens, stop={seq.stop_reason}, {inf_time:.2f}s", + ) + + decoded = tc._tokenizer.decode(seq.tokens, skip_special_tokens=True) + _check("decode response", len(decoded) > 0, f"{repr(decoded[:80])}") + + # ---- End-to-end ------------------------------------------- + print(f"\n{'='*60}") + print(f" END-TO-END (3-step training loop)") + print(f"{'='*60}") + + batch = _make_batch(tc._tokenizer) + _check("Built batch", len(batch) == 4, f"{len(batch)} samples") + + for step in range(1, 4): + t0 = time.time() + advantages = compute_advantages(batch) + datums = batch_to_datums(batch, advantages, sdk=sdk) + + await tc.forward_backward_async(datums, loss_fn="grpo") + await tc.optim_step_async(AdamParams(learning_rate=1e-4)) + sc = await tc.save_weights_and_get_sampling_client_async(name="openclaw_lora") + + rewards = [s.reward for s in batch] + mean_r = sum(rewards) / len(rewards) + _check( + f"Step {step}/3", + sc.adapter_path is not None, + f"datums={len(datums)} mean_r={mean_r:.2f} {time.time()-t0:.2f}s" + ) + + result = await tc.save_state_async(name="e2e_final") + _check("Final checkpoint", Path(result.path).exists(), result.path) + + asyncio.run(_run()) + + print(f"\n{'='*60}") + if failed == 0: + print(f" \033[32m{passed} passed, 0 
failed\033[0m") + else: + print(f" \033[31m{passed} passed, {failed} failed\033[0m") + print(f"{'='*60}") + return failed + + +if __name__ == "__main__": + args = sys.argv[1:] + model = TEST_MODEL + + if "--model" in args: + idx = args.index("--model") + model = args[idx + 1] + args = [a for i, a in enumerate(args) if i not in (idx, idx + 1)] + + sys.exit(_cli_run(model)) From fbcadcf56c92b9d762cdc9d0de737bdd5dbc9cdf Mon Sep 17 00:00:00 2001 From: Alvin Chang Date: Wed, 18 Mar 2026 15:05:16 +0000 Subject: [PATCH 2/5] Resolve comments. --- metaclaw/sdk_backend.py | 4 -- tests/test_mlx_backend.py | 123 -------------------------------------- 2 files changed, 127 deletions(-) diff --git a/metaclaw/sdk_backend.py b/metaclaw/sdk_backend.py index b55e8c4d..44351602 100644 --- a/metaclaw/sdk_backend.py +++ b/metaclaw/sdk_backend.py @@ -146,8 +146,6 @@ def infer_backend_key(config: "MetaClawConfig") -> str: def resolve_api_key(config: "MetaClawConfig", backend_key: str | None = None) -> str: - if backend_key == "mlx": - return "" configured = configured_api_key(config) if configured: return configured @@ -156,8 +154,6 @@ def resolve_api_key(config: "MetaClawConfig", backend_key: str | None = None) -> def resolve_base_url(config: "MetaClawConfig", backend_key: str | None = None) -> str: - if backend_key == "mlx": - return "" configured = configured_base_url(config) if configured: return configured diff --git a/tests/test_mlx_backend.py b/tests/test_mlx_backend.py index 3f1d0cb2..dbec009d 100644 --- a/tests/test_mlx_backend.py +++ b/tests/test_mlx_backend.py @@ -99,59 +99,6 @@ def test_explicit_mlx_passes_base_url_as_model_path(): assert url == "mlx-community/Qwen2.5-7B-4bit" -# ------------------------------------------------------------------ # -# Auto-detection # -# ------------------------------------------------------------------ # - -def test_auto_does_not_select_mlx_without_signal(monkeypatch): - """Without env or config signal, auto should NOT pick MLX.""" - 
monkeypatch.delenv("METACLAW_RL_BACKEND", raising=False) - monkeypatch.delenv("MINT_API_KEY", raising=False) - monkeypatch.delenv("MINT_BASE_URL", raising=False) - monkeypatch.delenv("TINKER_API_KEY", raising=False) - monkeypatch.delenv("TINKER_BASE_URL", raising=False) - - monkeypatch.setattr( - "metaclaw.sdk_backend.importlib.util.find_spec", - _fake_find_spec_factory("mlx", "mlx_lm", "tinker"), - ) - - assert infer_backend_key(_cfg()) == "tinker" - - -def test_auto_selects_mlx_when_env_set(monkeypatch): - """METACLAW_RL_BACKEND=mlx should trigger MLX selection in auto mode.""" - monkeypatch.setenv("METACLAW_RL_BACKEND", "mlx") - monkeypatch.delenv("MINT_API_KEY", raising=False) - monkeypatch.delenv("MINT_BASE_URL", raising=False) - monkeypatch.delenv("TINKER_API_KEY", raising=False) - monkeypatch.delenv("TINKER_BASE_URL", raising=False) - - monkeypatch.setattr( - "metaclaw.sdk_backend.importlib.util.find_spec", - _fake_find_spec_factory("mlx", "mlx_lm"), - ) - - assert infer_backend_key(_cfg()) == "mlx" - - -def test_auto_skips_mlx_when_mlx_lm_missing(monkeypatch): - """If mlx is importable but mlx_lm is not, auto should skip MLX.""" - monkeypatch.setenv("METACLAW_RL_BACKEND", "mlx") - monkeypatch.delenv("MINT_API_KEY", raising=False) - monkeypatch.delenv("MINT_BASE_URL", raising=False) - monkeypatch.delenv("TINKER_API_KEY", raising=False) - monkeypatch.delenv("TINKER_BASE_URL", raising=False) - - monkeypatch.setattr( - "metaclaw.sdk_backend.importlib.util.find_spec", - _fake_find_spec_factory("mlx", "tinker"), # mlx_lm missing - ) - - # Should fall through to tinker since _mlx_available() is False - assert infer_backend_key(_cfg()) == "tinker" - - # ------------------------------------------------------------------ # # Error handling # # ------------------------------------------------------------------ # @@ -165,73 +112,3 @@ def test_explicit_mlx_missing_deps_raises(monkeypatch): with pytest.raises(RuntimeError, match="Apple Silicon"): 
resolve_sdk_backend(_cfg(backend="mlx")) - - -# ------------------------------------------------------------------ # -# Priority: explicit > mlx > mint > tinker # -# ------------------------------------------------------------------ # - -def test_explicit_mint_wins_over_mlx_env(monkeypatch): - """Explicit backend=mint should override METACLAW_RL_BACKEND=mlx.""" - monkeypatch.setenv("METACLAW_RL_BACKEND", "mlx") - - assert infer_backend_key(_cfg(backend="mint")) == "mint" - - -def test_explicit_tinker_wins_over_mlx_env(monkeypatch): - """Explicit backend=tinker should override METACLAW_RL_BACKEND=mlx.""" - monkeypatch.setenv("METACLAW_RL_BACKEND", "mlx") - - assert infer_backend_key(_cfg(backend="tinker")) == "tinker" - - -# ------------------------------------------------------------------ # -# Existing backends still work # -# ------------------------------------------------------------------ # - -def test_tinker_still_resolves(monkeypatch): - tinker_module = types.SimpleNamespace(__name__="tinker") - monkeypatch.delenv("METACLAW_RL_BACKEND", raising=False) - monkeypatch.delenv("MINT_API_KEY", raising=False) - monkeypatch.delenv("MINT_BASE_URL", raising=False) - - monkeypatch.setattr( - "metaclaw.sdk_backend.importlib.util.find_spec", - _fake_find_spec_factory("tinker"), - ) - monkeypatch.setattr( - "metaclaw.sdk_backend.importlib.import_module", - lambda name: tinker_module if name == "tinker" else None, - ) - - backend = resolve_sdk_backend( - _cfg(backend="tinker", api_key="sk-tinker-123", base_url="https://api.tinker.example/v1") - ) - - assert backend.key == "tinker" - assert backend.module is tinker_module - assert backend.api_key == "sk-tinker-123" - assert backend.banner == "Tinker cloud RL" - - -def test_mint_still_resolves(monkeypatch): - mint_module = types.SimpleNamespace(__name__="mint") - monkeypatch.delenv("METACLAW_RL_BACKEND", raising=False) - - monkeypatch.setattr( - "metaclaw.sdk_backend.importlib.util.find_spec", - 
_fake_find_spec_factory("mint"), - ) - monkeypatch.setattr( - "metaclaw.sdk_backend.importlib.import_module", - lambda name: mint_module if name == "mint" else None, - ) - - backend = resolve_sdk_backend( - _cfg(backend="mint", api_key="sk-mint-123", base_url="https://mint.macaron.xin/") - ) - - assert backend.key == "mint" - assert backend.module is mint_module - assert backend.api_key == "sk-mint-123" - assert backend.banner == "MinT cloud RL" From c016ef70f2c29e1640809ccac23c8e4ac32f4a4f Mon Sep 17 00:00:00 2001 From: Alvin Chang Date: Mon, 23 Mar 2026 19:10:45 +0000 Subject: [PATCH 3/5] docs: add MLX backend (Apple Silicon) section to README --- README.md | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/README.md b/README.md index 4f6d409e..55d67667 100644 --- a/README.md +++ b/README.md @@ -405,6 +405,56 @@ Each `ConversationSample` is tagged with a `skill_generation` version. When skil --- +## 🍎 MLX Backend (Apple Silicon Local RL) + +MetaClaw supports **local RL training on Apple Silicon Macs** via the MLX backend. This enables RL training without cloud GPU instances — everything runs locally on your M-series chip. + +### Quick Start + +```bash +# Install with MLX extras +pip install -e ".[mlx]" + +# Configure +metaclaw setup # select backend → mlx + +# Or via env +export METACLAW_RL_BACKEND=mlx +metaclaw start +``` + +### Configuration + +Add to `config.py`: +```python +# MLX backend settings +mlx_model_path: str = "" # local path or HF repo (e.g. 
mlx-community/Qwen2.5-7B-4bit) +mlx_output_dir: str = "./mlx_metaclaw_output" +``` + +In `setup_wizard.py`, add `"mlx"` to the backend selection list: +```python +# Before: +["auto", "tinker", "mint"], +# After: +["auto", "tinker", "mint", "mlx"], +``` + +### Requirements + +- Apple Silicon Mac (M1/M2/M3/M4) +- macOS 13+ +- Python 3.10+ +- `mlx>=0.22.0`, `mlx-lm>=0.21.0`, `safetensors` + +### Architecture + +The MLX backend implements the same `ServiceClient`, `SamplingClient`, and `LoraTrainingClient` interfaces as the cloud backends, ensuring full compatibility with the MetaClaw training pipeline. + +See [`INTEGRATION_NOTES.md`](INTEGRATION_NOTES.md) for the full integration guide. + +--- + ## 🙏 Acknowledgements MetaClaw builds on top of the following open-source projects: From 906b59280ad7192840fecb2aea3bba68c2bb288e Mon Sep 17 00:00:00 2001 From: Alvin Chang <1977968+alvin-chang@users.noreply.github.com> Date: Tue, 24 Mar 2026 10:13:33 +0000 Subject: [PATCH 4/5] docs: integrate MLX backend details from INTEGRATION_NOTES.md - Replace reference to INTEGRATION_NOTES.md with detailed MLX backend integration information directly in README.md - Add comprehensive MLX backend details including file structure, configuration changes, and setup requirements - Create CLAUDE.md with project guidance for Claude Code instances Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 46 ++++++++++++++++++++++- 2 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..fb5b44fc --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,108 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +MetaClaw is an agent that meta-learns and evolves in the wild. 
It places your model behind a proxy that intercepts interactions from personal agents (OpenClaw, CoPaw, IronClaw, etc.), injects relevant skills at each turn, and meta-learns from accumulated experience. The system supports three operating modes: +- `skills_only`: Lightweight proxy with skill injection (no GPU required) +- `rl`: Skills + RL training with GRPO algorithm +- `madmax`: Skills + RL + smart scheduler (default mode) + +## Key Architecture Components + +- **API Server** (`api_server.py`): Main proxy server that intercepts LLM requests and injects skills +- **CLI** (`cli.py`): Command-line interface for setup, start, stop, and configuration +- **Configuration** (`config.py`, `config_store.py`): Dataclass-based config system with YAML storage +- **Skill Management** (`skill_manager.py`, `skill_evolver.py`): Handles skill retrieval and evolution +- **Claw Adapters** (`claw_adapter.py`): Integration with various personal agents (OpenClaw, CoPaw, etc.) +- **Training System** (`trainer.py`, `sdk_backend.py`): RL training with support for Tinker, MinT, and MLX backends +- **Scheduler** (`scheduler.py`): Manages RL training during idle/sleep windows to avoid interrupting active use + +## MLX Backend Integration + +The repository includes MLX backend support for Apple Silicon Macs, allowing local RL training without cloud GPUs. 
Key MLX files are located in `metaclaw/mlx_backend/`: +- `__init__.py` - Package initialization +- `data_types.py` - MLX-specific data structures +- `params.py` - MLX training parameters +- `lora.py` - LoRA implementation for MLX +- `service_client.py` - MLX implementation of the service client interface + +To use the MLX backend, `metaclaw/config.py` and the setup wizard may need the following updates: +- Add `mlx_model_path` and `mlx_output_dir` settings to the `MetaClawConfig` class +- Update the `training_backend_label()` and `training_backend_banner()` methods to handle the "mlx" backend +- Add "mlx" to the backend selection list in the setup wizard + +## Common Commands + +```bash +# One-time setup +metaclaw setup + +# Start MetaClaw (default: madmax mode) +metaclaw start + +# Start in background +metaclaw start --daemon + +# Start with specific mode +metaclaw start --mode rl # RL mode only +metaclaw start --mode skills_only # Skills only mode + +# Stop running instance +metaclaw stop + +# Check status +metaclaw status + +# View configuration +metaclaw config show + +# Set configuration values +metaclaw config KEY VALUE + +# Skill management +metaclaw skills log --n 10 # Show recent skill evolutions + +# Scheduler management +metaclaw scheduler status # Show scheduler state +``` + +## Installation & Dependencies + +```bash +# Basic installation (skills_only mode) +pip install -e .
+ +# With RL support +pip install -e ".[rl]" + +# With full setup (RL + evolution + scheduler) +pip install -e ".[rl,evolve,scheduler]" + +# With MLX support (Apple Silicon) +pip install -e ".[mlx]" +``` + +## Configuration Structure + +- Main config file: `~/.metaclaw/config.yaml` +- Skills directory: `~/.metaclaw/skills/` +- Recordings: `records/` directory +- MLX output: Configurable via `mlx_output_dir` setting + +## Key Configuration Options + +- `mode`: "madmax" (default), "rl", or "skills_only" +- `claw_type`: "openclaw", "copaw", "ironclaw", "nanoclaw", "nemoclaw", or "none" +- `rl.backend`: "auto", "tinker", "mint", or "mlx" +- `skills.enabled`: Enable/disable skill injection +- `rl.enabled`: Enable/disable RL training +- `scheduler.enabled`: Control meta-learning scheduler + +## Development Workflow + +- Use `metaclaw setup` for initial configuration +- Develop with `metaclaw start` for immediate testing +- Monitor logs and state with `metaclaw status` +- The daemon mode runs in background with logs at `~/.metaclaw/metaclaw.log` \ No newline at end of file diff --git a/README.md b/README.md index e6f835bd..d0b797c1 100644 --- a/README.md +++ b/README.md @@ -460,7 +460,51 @@ In `setup_wizard.py`, add `"mlx"` to the backend selection list: The MLX backend implements the same `ServiceClient`, `SamplingClient`, and `LoraTrainingClient` interfaces as the cloud backends, ensuring full compatibility with the MetaClaw training pipeline. -See [`INTEGRATION_NOTES.md`](INTEGRATION_NOTES.md) for the full integration guide. +### Integration Details + +The MLX backend consists of several key files in `metaclaw/mlx_backend/`: +- `__init__.py` - Package initialization +- `data_types.py` - MLX-specific data structures +- `params.py` - MLX training parameters +- `lora.py` - LoRA implementation for MLX +- `service_client.py` - MLX implementation of the service client interface + +To use MLX backend, you may need to update configurations: + +1. 
Add to `metaclaw/config.py` in the MetaClawConfig class: +```python + # MLX backend settings + mlx_model_path: str = "" # local path or HF repo (e.g. mlx-community/Qwen2.5-7B-4bit) + mlx_output_dir: str = "./mlx_metaclaw_output" +``` + +2. Update training backend methods around line 186 in `metaclaw/config.py`: +```python +def training_backend_label(self) -> str: + key = self.resolved_backend_key() + if key == "mlx": + return "MLX" + return "MinT" if key == "mint" else "Tinker" + +def training_backend_banner(self) -> str: + label = self.training_backend_label() + suffix = "local RL" if self.resolved_backend_key() == "mlx" else "cloud RL" + return f"{label} {suffix}" +``` + +3. Update `metaclaw/setup_wizard.py` to add "mlx" to the backend selection list: +```python +# In the backend selection, change from: +["auto", "tinker", "mint"] +# To: +["auto", "tinker", "mint", "mlx"] +``` + +### Optional: pyproject.toml extras +```toml +[project.optional-dependencies] +mlx = ["mlx>=0.22.0", "mlx-lm>=0.21.0", "safetensors"] +``` --- From f94b83d9629059926d046510661b9fab4de2e2fc Mon Sep 17 00:00:00 2001 From: Alvin Chang <1977968+alvin-chang@users.noreply.github.com> Date: Tue, 24 Mar 2026 10:15:37 +0000 Subject: [PATCH 5/5] docs: enhance MLX backend documentation in README - Update MLX backend section to include usage instructions - Add installation and configuration details for MLX support - Document how MLX support integrates alongside other backends (Tinker, MinT) - Include configuration options and environment variable usage Co-Authored-By: Claude Opus 4.6 --- README.md | 55 +++++++++++++++++-------------------------------------- 1 file changed, 17 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index d0b797c1..978e463f 100644 --- a/README.md +++ b/README.md @@ -460,51 +460,30 @@ In `setup_wizard.py`, add `"mlx"` to the backend selection list: The MLX backend implements the same `ServiceClient`, `SamplingClient`, and `LoraTrainingClient` interfaces as 
the cloud backends, ensuring full compatibility with the MetaClaw training pipeline. -### Integration Details +The MLX backend implements the same `ServiceClient`, `SamplingClient`, and `LoraTrainingClient` interfaces as the cloud backends, ensuring full compatibility with the MetaClaw training pipeline. -The MLX backend consists of several key files in `metaclaw/mlx_backend/`: -- `__init__.py` - Package initialization -- `data_types.py` - MLX-specific data structures -- `params.py` - MLX training parameters -- `lora.py` - LoRA implementation for MLX -- `service_client.py` - MLX implementation of the service client interface +### Usage -To use MLX backend, you may need to update configurations: +To use the MLX backend, simply configure it during setup or via environment variables: -1. Add to `metaclaw/config.py` in the MetaClawConfig class: -```python - # MLX backend settings - mlx_model_path: str = "" # local path or HF repo (e.g. mlx-community/Qwen2.5-7B-4bit) - mlx_output_dir: str = "./mlx_metaclaw_output" -``` +```bash +# Install with MLX extras +pip install -e ".[mlx]" -2. Update training backend methods around line 186 in `metaclaw/config.py`: -```python -def training_backend_label(self) -> str: - key = self.resolved_backend_key() - if key == "mlx": - return "MLX" - return "MinT" if key == "mint" else "Tinker" - -def training_backend_banner(self) -> str: - label = self.training_backend_label() - suffix = "local RL" if self.resolved_backend_key() == "mlx" else "cloud RL" - return f"{label} {suffix}" -``` +# Configure via setup +metaclaw setup # select backend → mlx -3. 
Update `metaclaw/setup_wizard.py` to add "mlx" to the backend selection list: -```python -# In the backend selection, change from: -["auto", "tinker", "mint"] -# To: -["auto", "tinker", "mint", "mlx"] +# Or via environment variable +export METACLAW_RL_BACKEND=mlx +metaclaw start ``` -### Optional: pyproject.toml extras -```toml -[project.optional-dependencies] -mlx = ["mlx>=0.22.0", "mlx-lm>=0.21.0", "safetensors"] -``` +The MLX backend enables local RL training on Apple Silicon Macs without requiring cloud GPU instances - everything runs locally on your M-series chip. Configuration options include: + +- `mlx_model_path`: Local path or Hugging Face repo (e.g., mlx-community/Qwen2.5-7B-4bit) +- `mlx_output_dir`: Directory for MLX output (default: ./mlx_metaclaw_output) + +In `config.py`, the backend will be labeled as "MLX local RL" when selected, distinguishing it from "Tinker cloud RL" or "MinT cloud RL". ---