# Source code for unilab.algos.torch.hora.appo

"""HORA-owned APPO entry helpers."""

from __future__ import annotations

import os
from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, cast

import torch
from omegaconf import DictConfig

from unilab.algos.torch.hora.appo_runner import HoraAPPORunner
from unilab.algos.torch.hora.rsl_rl_compat import (
    convert_config_v3_to_v4,
    is_rsl_rl_v4,
    is_rsl_rl_v5,
)
from unilab.base.observations import get_obs_dims
from unilab.training import BackendAdapter, create_env, log_playback_plan

from .models import build_hora_shared_actor_critic
from .observations import build_hora_actor_tensordict, split_hora_obs_with_priv_info
from .runtime import is_hora_appo_runtime


@dataclass(frozen=True)
class HoraAPPORuntime:
    """Resolved HORA APPO entrypoints used by the generic APPO script.

    Immutable entrypoint bundle consumed by generic APPO script assembly.

    Attributes:
        runner_cls: Runner class used for HORA APPO training mode.
        play_fn: Play-mode callable used for HORA APPO checkpoint playback.
    """

    runner_cls: type[HoraAPPORunner]
    play_fn: Callable[..., str | None]
def resolve_hora_appo_runtime(rl_cfg: dict[str, Any]) -> HoraAPPORuntime | None:
    """Resolve HORA APPO entrypoints from an explicit runtime marker.

    Args:
        rl_cfg: Resolved algorithm config dictionary from Hydra composition.

    Returns:
        ``HoraAPPORuntime`` when the owner config selects HORA APPO, otherwise
        ``None``.
    """
    # The decision is delegated entirely to the runtime-marker predicate.
    if not is_hora_appo_runtime(rl_cfg):
        return None
    return HoraAPPORuntime(runner_cls=HoraAPPORunner, play_fn=play_hora_appo)
def _update_hora_obs_groups(
    rl_cfg: dict[str, Any],
    *,
    obs_dim: int,
    priv_info_dim: int,
) -> None:
    """Update grouped actor/critic dims for the HORA APPO runtime.

    Both the ``"actor"`` and ``"critic"`` entries of ``rl_cfg["obs_groups"]``
    receive the same ``"actor"`` / ``"priv_info"`` dimensions. Missing
    entries are created as empty dicts first.

    Args:
        rl_cfg: Mutable algorithm config dictionary to update in place.
        obs_dim: Actor observation dimension reported by the env contract.
        priv_info_dim: Privileged-info dimension reported by the env contract.

    Returns:
        None. Mutates ``rl_cfg["obs_groups"]`` directly.
    """
    obs_groups = rl_cfg.setdefault("obs_groups", {})
    for group_name in ("actor", "critic"):
        group = obs_groups.setdefault(group_name, {})
        # A pre-existing group that is not a dict is deliberately left as-is.
        if isinstance(group, dict):
            group["actor"] = obs_dim
            group["priv_info"] = priv_info_dim
def play_hora_appo(
    cfg: DictConfig,
    rl_cfg: dict[str, Any],
    *,
    root_dir: Any,
    resolve_checkpoint_path: Callable[..., Any],
) -> str | None:
    """Play HORA APPO checkpoints with grouped actor and privileged inputs.

    Builds the play environment, reconstructs the actor around a shared
    actor/critic model, loads the ``"actor"`` weights from the resolved
    checkpoint and hands reset/step callbacks to ``env.run_playback_mode``.

    Args:
        cfg: Composed run config. Reads ``cfg.training`` (``device``,
            ``play_env_num``, the ``cam_*`` camera settings and the optional
            ``play_render_mode`` / ``play_steps`` / ``render_spacing``).
        rl_cfg: Resolved algorithm config dictionary. Must contain an
            ``"actor"`` entry with a ``"class_name"``; ``"critic"`` is
            optional and falls back to the actor entry. A deep copy is used
            internally, so the caller's dictionary is not modified.
        root_dir: Forwarded unchanged to ``BackendAdapter``.
        resolve_checkpoint_path: Callable invoked as
            ``resolve_checkpoint_path(cfg)`` that returns a
            ``(load_path, load_path_dir)`` pair.

    Returns:
        The value returned by ``env.run_playback_mode`` (a video path or
        ``None``), or ``None`` when no checkpoint file could be found.

    Raises:
        ValueError: If the environment exposes no privileged info, if
            ``env.action_space.shape`` is ``None``, or if a play step runs
            without privileged info from the previous reset/step.
    """
    # Heavy optional dependencies are imported lazily, only in play mode.
    import numpy as np
    from rsl_rl.utils import resolve_callable
    from tensordict import TensorDict

    env_cfg_override = BackendAdapter(
        cfg,
        root_dir=root_dir,
        algo_name="appo",
    ).build_task_env_cfg_override()

    # An explicitly configured device wins; otherwise prefer CUDA, then MPS.
    device = cfg.training.device or (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using device for play: {device}")

    num_envs = cfg.training.play_env_num
    env = cast(
        Any,
        create_env(
            cfg,
            num_envs=num_envs,
            env_cfg_override=env_cfg_override,
        ),
    )

    # Probe the environment once to learn the observation/privileged dims.
    obs_dim, _ = get_obs_dims(env.obs_groups_spec)
    if env.state is None:
        env.init_state()
    initial_state = env.state
    _, _, state_priv_info = split_hora_obs_with_priv_info(
        initial_state.obs,
        initial_state.info,
    )
    priv_info_dim = int(state_priv_info.shape[1]) if state_priv_info is not None else 0
    if priv_info_dim <= 0:
        raise ValueError("HORA APPO play requires privileged info from the environment.")
    action_shape = env.action_space.shape
    if action_shape is None:
        raise ValueError("env.action_space.shape must be defined")
    action_dim = int(action_shape[0])

    # Deep copy: a shallow ``dict(rl_cfg)`` would still share the nested
    # ``obs_groups`` mapping, so the in-place update below would leak into
    # the caller's config.
    rl_cfg_dict = deepcopy(dict(rl_cfg))
    _update_hora_obs_groups(rl_cfg_dict, obs_dim=obs_dim, priv_info_dim=priv_info_dim)
    # Only the v4 (non-v5) runtime needs the v3 -> v4 config conversion.
    if not is_rsl_rl_v5() and is_rsl_rl_v4():
        rl_cfg_dict = convert_config_v3_to_v4(rl_cfg_dict)

    # Zero-filled example batch handed to the actor constructor.
    td_example = TensorDict(
        {
            "actor": torch.zeros((num_envs, obs_dim), device=device),
            "priv_info": torch.zeros((num_envs, priv_info_dim), device=device),
        },
        batch_size=num_envs,
    )

    actor_cfg = deepcopy(rl_cfg_dict["actor"])
    actor_cls = resolve_callable(actor_cfg.pop("class_name"))
    actor_cfg.pop("num_actions", None)

    critic_cfg = deepcopy(rl_cfg_dict.get("critic") or rl_cfg_dict.get("actor") or {})
    for dropped_key in ("class_name", "num_actions", "distribution_cfg"):
        critic_cfg.pop(dropped_key, None)

    shared_model = build_hora_shared_actor_critic(
        obs_dim=obs_dim,
        action_dim=action_dim,
        priv_info_dim=priv_info_dim,
        actor_cfg=actor_cfg,
        critic_cfg=critic_cfg,
    ).to(device)
    actor = actor_cls(
        td_example,
        rl_cfg_dict["obs_groups"],
        "actor",
        action_dim,
        shared_model=shared_model,
        **actor_cfg,
    )
    actor = actor.to(device)
    actor.eval()

    load_path, load_path_dir = resolve_checkpoint_path(cfg)
    if not load_path or not os.path.exists(load_path):
        print(f"Could not find run to load. load_path={load_path}")
        return None
    print(f"Loading model: {load_path}")
    checkpoint = torch.load(load_path, map_location=device, weights_only=True)
    actor.load_state_dict(checkpoint["actor"])

    # Privileged info from the latest reset/step; consumed by the next step.
    current_priv_info: np.ndarray | None = None

    def initialize_play_obs() -> np.ndarray:
        """Reset every play env and return the initial actor observations."""
        nonlocal current_priv_info
        obs_out, info_out = env.reset(np.arange(num_envs, dtype=np.int32))
        actor_obs, _, priv_info = split_hora_obs_with_priv_info(obs_out, info_out)
        current_priv_info = priv_info.astype(np.float32) if priv_info is not None else None
        return np.asarray(actor_obs, dtype=np.float32)

    def step_play_obs(obs_np: np.ndarray) -> np.ndarray:
        """Run the actor on ``obs_np``, step the env, return the next obs."""
        nonlocal current_priv_info
        if current_priv_info is None:
            raise ValueError("HORA APPO play step is missing privileged info.")
        td = build_hora_actor_tensordict(
            obs_np,
            priv_info=current_priv_info,
            device=device,
            batch_size=num_envs,
        )
        actions = actor(td).cpu().numpy().astype(np.float32)
        state = env.step(actions)
        actor_obs, _, priv_info = split_hora_obs_with_priv_info(state.obs, state.info)
        current_priv_info = priv_info.astype(np.float32) if priv_info is not None else None
        return np.asarray(actor_obs, dtype=np.float32)

    # Without a checkpoint directory no output video path is passed on.
    output_video = os.path.join(load_path_dir, "play_video.mp4") if load_path_dir else None

    print("Collecting physics states...")
    with torch.inference_mode():
        play_video_path = cast(
            str | None,
            env.run_playback_mode(
                play_render_mode=getattr(cfg.training, "play_render_mode", "auto"),
                play_steps=getattr(cfg.training, "play_steps", None),
                output_video=output_video,
                render_spacing=float(
                    getattr(cfg.training, "render_spacing", getattr(env.cfg, "render_spacing", 1.0))
                ),
                initialize=initialize_play_obs,
                step=step_play_obs,
                camera_kwargs={
                    "cam_distance": cfg.training.cam_distance,
                    "cam_elevation": cfg.training.cam_elevation,
                    "cam_azimuth": cfg.training.cam_azimuth,
                    "cam_lookat": getattr(cfg.training, "cam_lookat", None),
                    "cam_tracking": getattr(cfg.training, "cam_tracking", False),
                    "cam_tracking_env_idx": getattr(cfg.training, "cam_tracking_env_idx", 0),
                    "cam_tracking_extra_envs": getattr(cfg.training, "cam_tracking_extra_envs", 2),
                },
                on_plan=log_playback_plan,
            ),
        )
    if play_video_path is not None:
        print(f"Saving video to {play_video_path} with mediapy...")
    print("Done.")
    return play_video_path
# Names exported by ``from ... import *``.
__all__ = [
    "HoraAPPORunner",
    "HoraAPPORuntime",
    "play_hora_appo",
    "resolve_hora_appo_runtime",
]