# Source code for unilab.cli

"""Thin package CLI for routing to existing UniLab training entrypoints."""

from __future__ import annotations

import argparse
import platform
import re
import shutil
import subprocess
import sys
from dataclasses import dataclass
from importlib.util import find_spec
from pathlib import Path
from typing import Sequence

from unilab.demo import run_demo

# Values accepted by the `--algo` flag; also listed in build_route's
# "Unsupported algo" error message.
SUPPORTED_ALGOS = ("ppo", "mlx_ppo", "appo", "sac", "td3", "flashsac")
# Values accepted by the `--sim` flag.
SUPPORTED_SIMS = ("mujoco", "motrix")
# Values accepted by the `--render-mode` flag.
SUPPORTED_RENDER_MODES = ("auto", "interactive", "record", "none")
# Algorithms that build_route sends to `train_offpolicy.py` with the
# `offpolicy` config group.
OFFPOLICY_ALGOS = {"sac", "td3", "flashsac"}
# Hydra override keys that _check_reserved_overrides rejects when they appear
# in passthrough arguments.
RESERVED_OVERRIDE_KEYS = {
    "algo",
    "task",
    "training.sim_backend",
    "training.play_only",
}
# Shape required of `--task` and `--profile`: letters, digits, `_` and `-`,
# not starting with `-`.
TASK_NAME_PATTERN = re.compile(r"^[A-Za-z0-9_][A-Za-z0-9_-]*$")
# Shape required of `--load-run` values other than `-1`: letters, digits,
# `_`, `.` and `-` only.
RUN_ID_PATTERN = re.compile(r"^[A-Za-z0-9_.-]+$")


@dataclass(frozen=True)
class Route:
    """Immutable description of where one algo/task/sim combination is routed."""

    # File name of the entrypoint, looked up under `<root>/scripts/`.
    script_name: str
    # Config group directory, looked up under `<root>/conf/`.
    config_group: str
    # Owner YAML path relative to `<root>/conf/<config_group>/task/`.
    owner_task: str
    # Hydra overrides that build_command places ahead of passthrough overrides.
    generated_overrides: tuple[str, ...]
def repo_root() -> Path:
    """Return the third ancestor of this module's resolved file path."""
    module_file = Path(__file__).resolve()
    ancestors = module_file.parents
    return ancestors[2]
def _script_path(route: Route, root: Path) -> Path:
    """Return the path of the route's entrypoint script under `root/scripts`."""
    return root.joinpath("scripts", route.script_name)


def _owner_yaml_path(route: Route, root: Path) -> Path:
    """Return the path of the route's owner YAML under `root/conf`."""
    return root.joinpath("conf", route.config_group, "task", route.owner_task)


def _check_private_checkout(root: Path) -> None:
    """Exit unless `root` contains both a `conf` and a `scripts` directory."""
    if all((root / name).is_dir() for name in ("conf", "scripts")):
        return
    raise SystemExit(
        "The current UniLab CLI expects a UniLab source checkout. "
        "Run it from the uv-managed editable environment created by this repo."
    )


def _check_reserved_overrides(overrides: Sequence[str]) -> None:
    """Exit when a passthrough override targets a key in RESERVED_OVERRIDE_KEYS."""
    reserved = [item for item in overrides if _override_key(item) in RESERVED_OVERRIDE_KEYS]
    if not reserved:
        return
    joined = ", ".join(reserved)
    raise SystemExit(
        "Route-defining Hydra overrides must be provided through CLI flags, "
        f"not passthrough: {joined}"
    )


def _override_key(override: str) -> str:
    """Return the key of an override: text before the first `=`, minus `+`/`~` prefixes."""
    head, _, _ = override.partition("=")
    return head.strip().lstrip("+~")


def _check_task_name(task: str) -> None:
    """Exit unless `task` fully matches TASK_NAME_PATTERN."""
    if TASK_NAME_PATTERN.fullmatch(task) is not None:
        return
    raise SystemExit(
        "--task must be a registry task name such as `go1_joystick`; "
        "do not include slashes, dots, or path separators."
    )


def _check_profile(profile: str | None) -> None:
    """Exit unless `profile` is None or fully matches TASK_NAME_PATTERN."""
    if profile is None or TASK_NAME_PATTERN.fullmatch(profile) is not None:
        return
    raise SystemExit(
        "--profile must be a task owner variant such as `hora`; "
        "do not include slashes, dots, or path separators."
    )


def _check_load_run(load_run: str) -> None:
    """Exit unless `load_run` is `-1` or a bare directory name."""
    if load_run == "-1":
        return
    # `.` and `..` satisfy the pattern but are directory references, not names.
    is_plain_name = (
        RUN_ID_PATTERN.fullmatch(load_run) is not None and load_run not in {".", ".."}
    )
    if not is_plain_name:
        raise SystemExit("--load-run must be `-1` or a run directory name, not a path.")


def _check_runtime_requirements(algo: str, sim: str) -> None:
    """Exit when the chosen algo or sim cannot run in this environment."""
    wants_mlx = algo == "mlx_ppo"
    if wants_mlx and platform.system() != "Darwin":
        raise SystemExit("mlx_ppo is only supported on macOS; use --algo ppo for torch PPO.")
    wants_motrix = sim == "motrix"
    if wants_motrix and find_spec("motrixsim") is None:
        raise SystemExit(
            "sim=motrix requires the Motrix extra. Install it with `uv sync --extra motrix`."
        )


def _override_bool(overrides: Sequence[str], key: str) -> bool | None:
    """Return the last recognised boolean assigned to `key`, or None if there is none.

    Values that are not a recognised boolean spelling leave the earlier result
    untouched.
    """
    spellings = {
        "true": True,
        "1": True,
        "yes": True,
        "on": True,
        "false": False,
        "0": False,
        "no": False,
        "off": False,
    }
    selected: bool | None = None
    for item in overrides:
        if "=" not in item or _override_key(item) != key:
            continue
        _, _, raw = item.partition("=")
        selected = spellings.get(raw.strip().lower(), selected)
    return selected


def _override_value(overrides: Sequence[str], key: str) -> str | None:
    """Return the stripped value of the last `key=value` override, or None."""
    values = [
        item.partition("=")[2].strip()
        for item in overrides
        if "=" in item and _override_key(item) == key
    ]
    return values[-1] if values else None


def _needs_motrix_renderer(mode: str, sim: str, overrides: Sequence[str]) -> bool:
    """Report whether this run is a Motrix run that will use the playback renderer.

    False for any non-Motrix sim, and when `training.play_render_mode` is
    `none` or `record`. Otherwise eval runs always qualify, and train runs
    qualify unless `training.no_play` is set to a true value.
    """
    if sim != "motrix":
        return False
    requested = _override_value(overrides, "training.play_render_mode")
    if requested is not None and requested.strip().lower() in {"none", "record"}:
        return False
    if mode == "eval":
        return True
    return mode == "train" and (_override_bool(overrides, "training.no_play") is not True)


def _python_executable_for_route(mode: str, sim: str, overrides: Sequence[str]) -> str:
    """Pick the interpreter for the child process: `mxpython` on macOS Motrix playback."""
    if platform.system() == "Darwin" and _needs_motrix_renderer(mode, sim, overrides):
        return _mxpython_executable()
    return sys.executable


def _mxpython_executable() -> str:
    """Locate `mxpython`, exiting with guidance when it cannot be found.

    Checked in order: the running interpreter itself, PATH, then a sibling of
    the running interpreter.
    """
    current = Path(sys.executable)
    if current.name == "mxpython":
        return sys.executable
    on_path = shutil.which("mxpython")
    if on_path is not None:
        return on_path
    sibling = current.with_name("mxpython")
    if sibling.is_file():
        return str(sibling)
    raise SystemExit(
        "macOS Motrix playback uses the native renderer and must be launched with "
        "`mxpython`. Install the Motrix extra so `mxpython` is on PATH, or use "
        "`training.no_play=true` for non-rendering training."
    )
def build_route(algo: str, task: str, sim: str, profile: str | None = None) -> Route:
    """Map an algo/task/sim (and optional profile) onto its training Route.

    Algorithms in OFFPOLICY_ALGOS share `train_offpolicy.py` and nest their
    owner config under the algorithm name. `ppo`, `mlx_ppo` and `appo` each
    have their own script. Any other algorithm exits with the supported list.
    """
    owner = sim if profile is None else f"{sim}_{profile}"
    task_choice: str

    if algo in OFFPOLICY_ALGOS:
        task_choice = f"{algo}/{task}/{owner}"
        return Route(
            script_name="train_offpolicy.py",
            config_group="offpolicy",
            owner_task=f"{algo}/{task}/{owner}.yaml",
            generated_overrides=(f"algo={algo}", f"task={task_choice}"),
        )

    # algo -> (entrypoint script, config group); mlx_ppo shares the ppo group.
    onpolicy_targets = {
        "ppo": ("train_rsl_rl.py", "ppo"),
        "mlx_ppo": ("train_mlx_ppo.py", "ppo"),
        "appo": ("train_appo.py", "appo"),
    }
    target = onpolicy_targets.get(algo)
    if target is None:
        raise SystemExit(f"Unsupported algo={algo!r}; choose one of: {', '.join(SUPPORTED_ALGOS)}")

    script_name, config_group = target
    task_choice = f"{task}/{owner}"
    return Route(
        script_name=script_name,
        config_group=config_group,
        owner_task=f"{task}/{owner}.yaml",
        generated_overrides=(f"task={task_choice}",),
    )
def build_command(
    *,
    mode: str,
    algo: str,
    task: str,
    sim: str,
    overrides: Sequence[str],
    profile: str | None = None,
    load_run: str | None = None,
    render_mode: str | None = None,
    root: Path | None = None,
) -> list[str]:
    """Validate the request and build the argv for the training entrypoint.

    The result is `[interpreter, script, *generated overrides, *passthrough
    overrides]`. `load_run` is only consulted when `mode` is `eval`.

    Raises:
        SystemExit: if the checkout, task, profile, overrides, runtime, script,
            owner config or load-run selection is invalid.
    """
    checkout = root if root else repo_root()

    # Validation order is significant: it decides which error is reported first.
    _check_private_checkout(checkout)
    _check_task_name(task)
    _check_profile(profile)
    _check_reserved_overrides(overrides)
    _check_runtime_requirements(algo, sim)

    route = build_route(algo, task, sim, profile)

    script = _script_path(route, checkout)
    if not script.is_file():
        raise SystemExit(f"Entrypoint script not found: {script}")

    owner_yaml = _owner_yaml_path(route, checkout)
    if not owner_yaml.is_file():
        raise SystemExit(
            f"No owner config exists for algo={algo}, task={task}, sim={sim}: {owner_yaml}"
        )

    generated = [*route.generated_overrides]
    if render_mode is not None:
        generated += [f"training.play_render_mode={render_mode}"]

    if mode == "eval":
        generated += ["training.play_only=true"]
        if load_run is not None:
            _check_load_run(load_run)
            # The flag and the raw override would both set the same key.
            if "algo.load_run" in map(_override_key, overrides):
                raise SystemExit("Use either --load-run or algo.load_run=..., not both.")
            generated += [f"algo.load_run={load_run}"]

    effective_overrides = tuple(generated) + tuple(overrides)
    executable = _python_executable_for_route(mode, sim, effective_overrides)
    return [executable, str(script)] + generated + list(overrides)
def _train_eval_parser(*, mode: str) -> argparse.ArgumentParser:
    """Build the argument parser for `train` or `eval`; only `eval` gets `--load-run`."""
    cli = argparse.ArgumentParser(prog=mode)
    cli.add_argument("--algo", choices=SUPPORTED_ALGOS, required=True)
    cli.add_argument("--task", required=True)
    cli.add_argument("--sim", choices=SUPPORTED_SIMS, required=True)
    cli.add_argument("--profile", default=None)
    cli.add_argument("--render-mode", default=None, choices=SUPPORTED_RENDER_MODES)
    if mode != "eval":
        return cli
    cli.add_argument("--load-run", default=None)
    return cli


def _demo_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the `demo` command."""
    cli = argparse.ArgumentParser(prog="demo")
    cli.add_argument("demo_name")
    cli.add_argument("--refresh", action="store_true")
    cli.add_argument("--device", default=None)
    return cli


def _run_train_eval(mode: str, argv: Sequence[str] | None = None) -> int:
    """Parse `argv`, run the routed entrypoint as a child process, return its exit code.

    Arguments the parser does not recognise are forwarded as passthrough
    overrides.
    """
    args, overrides = _train_eval_parser(mode=mode).parse_known_args(argv)
    command = build_command(
        mode=mode,
        algo=args.algo,
        task=args.task,
        sim=args.sim,
        overrides=overrides,
        profile=args.profile,
        # The train parser defines no --load-run, hence the getattr default.
        load_run=getattr(args, "load_run", None),
        render_mode=args.render_mode,
    )
    completed = subprocess.run(command, check=False)
    return completed.returncode
def train_main(argv: Sequence[str] | None = None) -> int:
    """Run the `train` command and return the child process's exit code."""
    return _run_train_eval(mode="train", argv=argv)
def eval_main(argv: Sequence[str] | None = None) -> int:
    """Run the `eval` command and return the child process's exit code."""
    return _run_train_eval(mode="eval", argv=argv)
def demo_main(argv: Sequence[str] | None = None) -> int:
    """Run the `demo` command and return the value produced by `run_demo`.

    Raises:
        SystemExit: if any argument is left over after parsing; the demo
            command takes no passthrough overrides.
    """
    args, overrides = _demo_parser().parse_known_args(argv)
    if not overrides:
        return run_demo(
            demo_name=args.demo_name,
            refresh=args.refresh,
            device=args.device,
        )
    raise SystemExit(
        f"demo does not accept passthrough Hydra overrides: {', '.join(overrides)}"
    )
if __name__ == "__main__":
    # Executing the module directly runs the train command; its return value
    # becomes the process exit status.
    sys.exit(train_main())