# Source code for unilab.algos.torch.hora.rsl_rl_compat

"""Compatibility helpers for HORA's supported RSL-RL config schemas.

The HORA APPO code uses these helpers to normalize owner configs before
constructing grouped actor/critic modules across supported RSL-RL releases.
"""

from __future__ import annotations

import importlib.metadata
from copy import deepcopy
from functools import lru_cache
from typing import Any

from packaging.version import Version

_MLX_PPO_ONLY_KEYS = {
    "adaptive_kl_beta",
    "adaptive_lr_decay",
    "adaptive_lr_growth",
    "adaptive_lr_update_interval",
    "disable_finite_checks",
    "enable_compile",
    "finite_check_interval",
    "metrics_interval",
    "target_kl_stop",
    "warmup_finite_check_interval",
    "warmup_metrics_interval",
    "warmup_strict_iters",
}


@lru_cache(maxsize=1)
def get_rsl_rl_version() -> str:
    """Return the version string of the installed RSL-RL distribution.

    The current distribution name ``rsl-rl-lib`` is tried first, then the
    legacy ``rsl-rl`` name. A successful lookup is memoized by ``lru_cache``.

    Args:
        None.

    Returns:
        Version string reported by ``importlib.metadata`` for whichever of
        the two distribution names is installed.

    Raises:
        ImportError: If neither distribution is installed. The
            ``PackageNotFoundError`` from the legacy-name lookup is attached
            as the cause.
    """
    missing: importlib.metadata.PackageNotFoundError | None = None
    for dist_name in ("rsl-rl-lib", "rsl-rl"):
        try:
            return importlib.metadata.version(dist_name)
        except importlib.metadata.PackageNotFoundError as err:
            # Remember the most recent failure so it can be chained below.
            missing = err
    raise ImportError(
        "rsl_rl is not installed. Install via: pip install rsl-rl-lib"
    ) from missing


def is_rsl_rl_v4() -> bool:
    """Check whether the active RSL-RL runtime is version 4 or newer.

    Only the major release component is compared. Under PEP 440 ordering,
    pre-release and development builds of a major line (for example
    ``4.0.0rc1`` or ``4.0.0.dev0``) sort *below* ``4.0.0``. A plain
    ``>= Version("4.0.0")`` test would therefore misclassify them as the
    previous major schema.

    Args:
        None.

    Returns:
        ``True`` when the installed package's major version is ``>= 4``.
    """
    return Version(get_rsl_rl_version()).major >= 4


def is_rsl_rl_v5() -> bool:
    """Check whether the active RSL-RL runtime is version 5 or newer.

    Only the major release component is compared. Under PEP 440 ordering,
    pre-release and development builds of a major line (for example
    ``5.0.0rc1`` or ``5.0.0.dev0``) sort *below* ``5.0.0``. A plain
    ``>= Version("5.0.0")`` test would therefore misclassify them as the
    previous major schema.

    Args:
        None.

    Returns:
        ``True`` when the installed package's major version is ``>= 5``.
    """
    return Version(get_rsl_rl_version()).major >= 5


def _normalize_obs_groups_for_rsl(cfg: dict[str, Any]) -> None: """Translate UniLab owner obs-group aliases into RSL-RL actor/critic groups. Args: cfg: Mutable RSL-RL config dictionary to normalize in place. Returns: None. Updates ``cfg["obs_groups"]`` directly. """ obs_groups_raw = cfg.get("obs_groups", {}) obs_groups = obs_groups_raw if isinstance(obs_groups_raw, dict) else {} if "default" in obs_groups: if "actor" not in obs_groups: obs_groups["actor"] = obs_groups["default"] if "critic" not in obs_groups: obs_groups["critic"] = obs_groups["default"] else: # Keep grouped-dict specs intact for owner runtimes like HORA; the legacy # v3 -> v4 rename only applies to flat list-based group aliases. if isinstance(obs_groups.get("actor"), list): obs_groups["actor"] = ["policy"] if isinstance(obs_groups.get("critic"), list): obs_groups["critic"] = ["policy"] cfg["obs_groups"] = obs_groups def _convert_policy_to_actor_critic( cfg: dict[str, Any], *, distribution_class_name: str, ) -> None: """Split a legacy single ``policy`` config into ``actor`` and ``critic`` blocks. Args: cfg: Mutable RSL-RL config dictionary to normalize in place. distribution_class_name: Distribution class name expected by the target RSL-RL runtime. Returns: None. Updates ``cfg`` directly when a legacy ``policy`` block is present. 
""" empirical_normalization = bool(cfg.pop("empirical_normalization", False)) cfg.pop("runner_class_name", None) if "policy" not in cfg or "actor" in cfg or "critic" in cfg: return policy = cfg.pop("policy") if not isinstance(policy, dict): return cfg["actor"] = { "class_name": "MLPModel", "hidden_dims": policy.get("actor_hidden_dims", [256, 256, 256]), "activation": policy.get("activation", "elu"), "obs_normalization": empirical_normalization, "distribution_cfg": { "class_name": distribution_class_name, "init_std": policy.get("init_noise_std", 1.0), "std_type": policy.get("noise_std_type", "scalar"), }, } cfg["critic"] = { "class_name": "MLPModel", "hidden_dims": policy.get("critic_hidden_dims", [256, 256, 256]), "activation": policy.get("activation", "elu"), "obs_normalization": empirical_normalization, } def _normalize_algorithm_cfg(cfg: dict[str, Any]) -> None: """Remove owner-only keys that current RSL-RL releases do not accept. Args: cfg: Mutable RSL-RL config dictionary to normalize in place. Returns: None. Updates ``cfg["algorithm"]`` directly when present. """ algorithm_cfg = cfg.get("algorithm") if not isinstance(algorithm_cfg, dict): return algorithm_cfg.setdefault("rnd_cfg", None) for key in _MLX_PPO_ONLY_KEYS: algorithm_cfg.pop(key, None)
def convert_config_v3_to_v4(cfg: dict[str, Any]) -> dict[str, Any]:
    """Build an RSL-RL v4 config from a legacy UniLab PPO/APPO config.

    The input is deep-copied and never mutated. On the copy:

    * a legacy ``policy`` block is split into ``actor``/``critic`` blocks
      using the dotted ``rsl_rl.modules.distribution.GaussianDistribution``
      class name;
    * the algorithm section is normalized;
    * observation groups are mapped to actor/critic groups;
    * a ``multi_gpu`` entry is guaranteed (``None`` when absent).

    Args:
        cfg: Resolved owner config dictionary before RSL-RL construction.

    Returns:
        A new config dictionary in the RSL-RL v4 actor/critic schema.
    """
    v4_cfg = deepcopy(cfg)
    _convert_policy_to_actor_critic(
        v4_cfg,
        distribution_class_name="rsl_rl.modules.distribution.GaussianDistribution",
    )
    for normalize_section in (_normalize_algorithm_cfg, _normalize_obs_groups_for_rsl):
        normalize_section(v4_cfg)
    v4_cfg.setdefault("multi_gpu", None)
    return v4_cfg


def convert_config_v5(cfg: dict[str, Any]) -> dict[str, Any]:
    """Build an RSL-RL v5 config from a legacy UniLab PPO/APPO config.

    The input is deep-copied and never mutated. On the copy:

    * a legacy ``policy`` block is split into ``actor``/``critic`` blocks
      using the bare ``GaussianDistribution`` class name;
    * the algorithm section is normalized;
    * observation groups are mapped to actor/critic groups.

    Unlike the v4 conversion, no ``multi_gpu`` default is inserted.

    Args:
        cfg: Resolved owner config dictionary before RSL-RL construction.

    Returns:
        A new config dictionary in the RSL-RL v5 actor/critic schema.
    """
    v5_cfg = deepcopy(cfg)
    _convert_policy_to_actor_critic(v5_cfg, distribution_class_name="GaussianDistribution")
    for normalize_section in (_normalize_algorithm_cfg, _normalize_obs_groups_for_rsl):
        normalize_section(v5_cfg)
    return v5_cfg