-
Notifications
You must be signed in to change notification settings - Fork 47
Add fleet CLI for W&B-backed dispatch and status #4984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
How to use the Graphite Merge QueueAdd either label to this PR to merge it via the merge queue:
You must have a Graphite account in order to use the merge queue. Sign up using this link. An organization admin has enabled the Graphite Merge Queue in this repository. Please do not merge from GitHub as this will restart CI on PRs being processed by the merge queue. This stack of pull requests is managed by Graphite. Learn more about stacking. |
This comment has been minimized.
This comment has been minimized.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d2b34751a8
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if isinstance(payload, list): | ||
| plugs = payload | ||
| server_host = None | ||
| auth_key = None | ||
| elif isinstance(payload, dict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle list-form smart plug configs without fatal error
When the smart plug config JSON is a list (which this function explicitly allows), server_host and auth_key are set to None here, but the subsequent if not server_host or not auth_key check in _load_smart_plug_config always raises. That means any list-form config (e.g., legacy or documented list-only configs) will fail every time, so metta fleet status --power and power-cycle are unusable with those inputs. Either accept server host/auth key from list configs or reject lists explicitly.
Useful? React with 👍 / 👎.
| def _shell_path(path: str) -> str: | ||
| if path.startswith("~"): | ||
| return f"\"$HOME{path[1:]}\"" | ||
| return shlex.quote(path) | ||
|
|
||
|
|
||
| def _ensure_wandb_key() -> None: | ||
| import wandb | ||
|
|
||
| if not wandb.api.api_key: | ||
| raise RuntimeError("WandB API key not found. Run 'metta install wandb' first.") | ||
|
|
||
|
|
||
| def _fetch_runs(entity: str, project: str, fetch_limit: int): | ||
| import wandb | ||
|
|
||
| api = wandb.Api(timeout=60) | ||
| return api.runs(f"{entity}/{project}", order="-created_at", per_page=fetch_limit) | ||
|
|
||
|
|
||
| def _collect_host_runs( | ||
| entity: str, | ||
| project: str, | ||
| hosts: Optional[list[str]], | ||
| fetch_limit: int, | ||
| ) -> tuple[dict[str, object], set[str]]: | ||
| runs_by_host: dict[str, object] = {} | ||
| run_hosts: set[str] = set() | ||
| host_set = set(hosts) if hosts else None | ||
| scanned = 0 | ||
|
|
||
| for run in _fetch_runs(entity, project, fetch_limit): | ||
| scanned += 1 | ||
| tags = list(getattr(run, "tags", []) or []) | ||
| host = _extract_host_from_tags(tags) | ||
| if not host: | ||
| continue | ||
| run_hosts.add(host) | ||
| if host_set is not None and host not in host_set: | ||
| continue | ||
| if host not in runs_by_host: | ||
| runs_by_host[host] = run | ||
| if host_set is not None and len(runs_by_host) >= len(host_set): | ||
| break | ||
| if scanned >= fetch_limit: | ||
| break | ||
|
|
||
| return runs_by_host, run_hosts | ||
|
|
||
|
|
||
| def _derive_status( | ||
| run: Optional[object], | ||
| stale_minutes: int, | ||
| now: datetime, | ||
| ) -> tuple[str, Optional[datetime]]: | ||
| if run is None: | ||
| return "no-run", None | ||
|
|
||
| summary = getattr(run, "summary", {}) or {} | ||
| last_updated_at = _parse_run_timestamp(summary, run) | ||
| state = getattr(run, "state", None) or "unknown" | ||
| if state == "running" and last_updated_at: | ||
| if now - last_updated_at > timedelta(minutes=stale_minutes): | ||
| return "stalled", last_updated_at | ||
| return state, last_updated_at | ||
|
|
||
|
|
||
| @app.command(name="status") | ||
| def cmd_status( | ||
| hosts: Annotated[ | ||
| Optional[list[str]], | ||
| typer.Option("--host", "-H", help="Host to inspect (repeatable)."), | ||
| ] = None, | ||
| entity: Annotated[str, typer.Option("--entity", "-e", help="WandB entity")] = METTA_WANDB_ENTITY, | ||
| project: Annotated[str, typer.Option("--project", "-p", help="WandB project")] = METTA_WANDB_PROJECT, | ||
| fetch_limit: Annotated[int, typer.Option("--fetch-limit", help="Max runs to fetch from WandB")] = 200, | ||
| stale_minutes: Annotated[int, typer.Option("--stale-min", help="Minutes without heartbeat to mark stalled")] = 20, | ||
| power: Annotated[bool, typer.Option("--power", help="Include smart plug wattage")] = False, | ||
| smart_plug_config: Annotated[ | ||
| Optional[Path], | ||
| typer.Option( | ||
| "--smart-plug-config", | ||
| help="Path to smart_plugs.json (defaults to ~/.config/metta/smart_plugs.json)", | ||
| ), | ||
| ] = None, | ||
| ssh_check: Annotated[bool, typer.Option("--ssh-check", help="Check SSH reachability")] = False, | ||
| ssh_user: Annotated[Optional[str], typer.Option("--ssh-user", help="SSH user (default: metta)")] = "metta", | ||
| json_output: Annotated[bool, typer.Option("--json", help="Emit JSON instead of table")] = False, | ||
| ) -> None: | ||
| _ensure_wandb_key() | ||
|
|
||
| config_path = smart_plug_config or Path("~/.config/metta/smart_plugs.json").expanduser() | ||
| smart_plug_status: Optional[list[SmartPlugStatus]] = None | ||
| if power and config_path.exists(): | ||
| smart_plug_status = _fetch_smart_plug_status(config_path) | ||
| elif power: | ||
| console.print(f"Smart plug config not found: {config_path}") | ||
| raise typer.Exit(1) | ||
|
|
||
| runs_by_host, run_hosts = _collect_host_runs(entity, project, hosts, fetch_limit) | ||
| resolved_hosts = _resolve_hosts( | ||
| hosts, | ||
| config_path if config_path.exists() else None, | ||
| power, | ||
| smart_plug_status, | ||
| run_hosts, | ||
| ) | ||
|
|
||
| if not resolved_hosts and not runs_by_host: | ||
| console.print("No hosts found. Provide --host or ensure smart_plug_config exists.") | ||
| raise typer.Exit(1) | ||
|
|
||
| now = datetime.now(timezone.utc) | ||
| power_by_label = {item.label: item for item in smart_plug_status or []} | ||
| statuses: list[HostStatus] = [] | ||
|
|
||
| for host in resolved_hosts or sorted(runs_by_host.keys()): | ||
| run = runs_by_host.get(host) | ||
| state, last_updated_at = _derive_status(run, stale_minutes, now) | ||
| watts = None | ||
| if power_by_label: | ||
| plug = power_by_label.get(host) | ||
| watts = plug.apower if plug else None | ||
| run_id = getattr(run, "id", None) if run else None | ||
| run_url = _wandb_run_url(entity, project, run_id) if run_id else None | ||
| statuses.append( | ||
| HostStatus( | ||
| host=host, | ||
| run_id=run_id, | ||
| state=state, | ||
| last_updated_at=last_updated_at, | ||
| watts=watts, | ||
| run_url=run_url, | ||
| ) | ||
| ) | ||
|
|
||
| if json_output: | ||
| payload = [] | ||
| for status in statuses: | ||
| payload.append( | ||
| { | ||
| "host": status.host, | ||
| "run_id": status.run_id, | ||
| "state": status.state, | ||
| "last_updated_at": status.last_updated_at.isoformat() if status.last_updated_at else None, | ||
| "watts": status.watts, | ||
| "run_url": status.run_url, | ||
| "ssh_ok": _check_ssh(_format_ssh_target(status.host, ssh_user), 5) | ||
| if ssh_check | ||
| else None, | ||
| } | ||
| ) | ||
| console.print(json.dumps(payload, indent=2)) | ||
| return | ||
|
|
||
| table = Table(title="Metta Fleet Status") | ||
| table.add_column("Host", style="bold") | ||
| table.add_column("Run", style="cyan") | ||
| table.add_column("State") | ||
| table.add_column("Last Update") | ||
| if power: | ||
| table.add_column("Watts") | ||
| if ssh_check: | ||
| table.add_column("SSH") | ||
|
|
||
| for status in statuses: | ||
| age = _format_age(now, status.last_updated_at) | ||
| watts = f"{status.watts:.1f} W" if status.watts is not None else "—" | ||
| row = [status.host, status.run_id or "—", status.state, age] | ||
| if power: | ||
| row.append(watts) | ||
| if ssh_check: | ||
| target = _format_ssh_target(status.host, ssh_user) | ||
| row.append("ok" if _check_ssh(target, 5) else "no") | ||
| table.add_row(*row) | ||
|
|
||
| console.print(table) | ||
|
|
||
|
|
||
| @app.command( | ||
| name="start", | ||
| context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, | ||
| help="Dispatch a run on a remote host via SSH.", | ||
| ) | ||
| def cmd_start( | ||
| ctx: typer.Context, | ||
| host: Annotated[str, typer.Option("--host", "-H", help="Host to dispatch to")], | ||
| recipe: Annotated[ | ||
| str, | ||
| typer.Option("--recipe", "-r", help="Tool path (e.g., recipes.experiment.arena.train)"), | ||
| ], | ||
| run_id: Annotated[Optional[str], typer.Option("--run", help="Run id (optional)")] = None, | ||
| group: Annotated[Optional[str], typer.Option("--group", help="WandB group")] = None, | ||
| tag: Annotated[ | ||
| Optional[list[str]], | ||
| typer.Option("--tag", help="Extra WandB tags (repeatable)"), | ||
| ] = None, | ||
| inject_tags: Annotated[ | ||
| bool, typer.Option("--wandb-tags/--no-wandb-tags", help="Inject WandB tags") | ||
| ] = True, | ||
| repo: Annotated[str, typer.Option("--repo", help="Repo path on remote host")] = "~/metta", | ||
| ssh_user: Annotated[Optional[str], typer.Option("--ssh-user", help="SSH user (default: metta)")] = "metta", | ||
| ssh_check: Annotated[ | ||
| bool, | ||
| typer.Option("--ssh-check/--no-ssh-check", help="Check SSH before dispatch"), | ||
| ] = True, | ||
| detach: Annotated[ | ||
| bool, | ||
| typer.Option("--detach/--no-detach", help="Run via nohup in background"), | ||
| ] = True, | ||
| log_dir: Annotated[str, typer.Option("--log-dir", help="Remote log directory")] = "~/metta_logs", | ||
| dry_run: Annotated[bool, typer.Option("--dry-run", help="Print the command without executing")] = False, | ||
| ) -> None: | ||
| job_id = run_id or f"{host}-{uuid.uuid4().hex[:8]}" | ||
| tags = _build_wandb_tags(host, recipe, job_id, tag) if inject_tags else [] | ||
| wandb_tags = json.dumps(tags) if inject_tags else None | ||
|
|
||
| cmd_parts = ["uv", "run", "./tools/run.py", recipe] | ||
| cmd_parts.extend(ctx.args) | ||
| if run_id: | ||
| cmd_parts.append(f"run={run_id}") | ||
| if group: | ||
| cmd_parts.append(f"wandb.group={group}") | ||
| if wandb_tags is not None: | ||
| cmd_parts.append(f"wandb.tags={wandb_tags}") | ||
|
|
||
| remote_cmd = " ".join(shlex.quote(part) for part in cmd_parts) | ||
| repo_path = _shell_path(repo) | ||
| log_dir_path = _shell_path(log_dir) | ||
| log_path = _shell_path(f"{log_dir.rstrip('/')}/fleet-{job_id}.log") | ||
|
|
||
| if detach: | ||
| remote_cmd = ( | ||
| f"mkdir -p {log_dir_path} && cd {repo_path} && " | ||
| f"nohup {remote_cmd} > {log_path} 2>&1 &" | ||
| ) | ||
| else: | ||
| remote_cmd = f"cd {repo_path} && {remote_cmd}" | ||
|
|
||
| ssh_target = _format_ssh_target(host, ssh_user) | ||
| if ssh_check and not _check_ssh(ssh_target, 5): | ||
| console.print(f"SSH check failed for {ssh_target}") | ||
| raise typer.Exit(1) | ||
|
|
||
| remote_shell_cmd = f"bash -lc {shlex.quote(remote_cmd)}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shell variable expansion bug in remote command execution. The _shell_path() function returns "$HOME..." for tilde paths, which gets embedded into remote_cmd. However, on line 485, remote_cmd is wrapped with shlex.quote() before being passed to bash -lc, which will prevent $HOME from being expanded.
When shlex.quote(remote_cmd) is called, it wraps the entire command in single quotes (or escapes special chars), causing $HOME to be treated as a literal string rather than expanded by the shell.
Example failure:
# What gets executed:
ssh user@host "bash -lc 'cd \"\$HOME/metta\" && ...'"
# $HOME is NOT expanded, causing "cd: $HOME/metta: No such file or directory"Fix:
Remove the shlex.quote() call on line 485 since remote_cmd is already properly constructed, or handle tilde expansion differently:
# Option 1: Don't quote the entire command
remote_shell_cmd = f"bash -lc '{remote_cmd}'" # Use plain quotes, not shlex.quote
# Option 2: Expand ~ on remote via shell
def _shell_path(path: str) -> str:
# Let remote shell handle tilde expansion naturally
return shlex.quote(path)Spotted by Graphite Agent
Is this helpful? React 👍 or 👎 to let us know.

Summary
metta fleetCLI for W&B-based status and job dispatchTesting
metta fleet --help