"""RSL-RL-specific training helpers."""
from __future__ import annotations
from copy import deepcopy
from typing import Any
import numpy as np
import torch
from tensordict import TensorDict
from unilab.base.final_observation import resolve_terminal_observation_contract
from unilab.base.np_env import NpEnvState
from unilab.utils.tensor import to_numpy, to_torch
def get_policy_obs_dims(obs_groups_spec: dict[str, int]) -> tuple[int, int]:
    """Return ``(actor_obs_dim, flat_policy_obs_dim)`` for RSL-RL policies.

    Args:
        obs_groups_spec: Mapping from observation group name to that group's
            dimension.

    Returns:
        A pair ``(actor_obs_dim, flat_policy_obs_dim)``. The first item is the
        size of the ``"obs"`` group (``0`` when the group is absent). The
        second is the summed size of every group except ``"critic"``; when
        that sum is zero the actor dimension is returned in its place.
    """
    actor_obs_dim = int(obs_groups_spec.get("obs", 0))
    flat_policy_obs_dim = int(
        sum(dim for group_name, dim in obs_groups_spec.items() if group_name != "critic")
    )
    # ``or`` falls back to the actor size when no non-critic group contributes.
    return actor_obs_dim, flat_policy_obs_dim or actor_obs_dim
def normalize_ppo_train_cfg(train_cfg: dict[str, Any]) -> dict[str, Any]:
    """Map UniLab PPO owner config to the current RSL-RL schema.

    The input is never mutated; all work happens on a deep copy.

    Steps, in order:

    1. Remove a fixed set of keys from ``train_cfg["algorithm"]`` (when that
       entry is a dict).
    2. If both ``"actor"`` and ``"critic"`` entries already exist, return.
    3. Otherwise pop ``"policy"``; if it is not a dict, return.
    4. Build ``"actor"`` and ``"critic"`` model configs from the policy
       settings and the top-level ``"empirical_normalization"`` flag.
    5. In ``"obs_groups"``, move a ``"default"`` entry to ``"actor"`` when no
       ``"actor"`` entry exists yet.

    Args:
        train_cfg: Training configuration dictionary.

    Returns:
        A new, normalized configuration dictionary.
    """
    # Keys removed from the algorithm config before it is handed on.
    stripped_algorithm_keys = (
        "target_kl_stop",
        "adaptive_kl_beta",
        "adaptive_lr_growth",
        "adaptive_lr_decay",
        "adaptive_lr_update_interval",
        "metrics_interval",
        "finite_check_interval",
        "warmup_strict_iters",
        "warmup_metrics_interval",
        "warmup_finite_check_interval",
        "disable_finite_checks",
    )

    normalized = deepcopy(train_cfg)

    algorithm_cfg = normalized.get("algorithm")
    if isinstance(algorithm_cfg, dict):
        for key in stripped_algorithm_keys:
            algorithm_cfg.pop(key, None)

    # Already in the actor/critic layout: nothing left to translate.
    if "actor" in normalized and "critic" in normalized:
        return normalized

    policy_cfg = normalized.pop("policy", None)
    if not isinstance(policy_cfg, dict):
        return normalized

    actor_hidden_dims = policy_cfg.get("actor_hidden_dims", [512, 256, 128])
    if "critic_hidden_dims" in policy_cfg:
        critic_hidden_dims = policy_cfg["critic_hidden_dims"]
    else:
        # Copy so the actor and critic configs never share one mutable
        # ``hidden_dims`` object.
        critic_hidden_dims = deepcopy(actor_hidden_dims)
    activation = policy_cfg.get("activation", "elu")
    init_noise_std = float(policy_cfg.get("init_noise_std", 1.0))
    obs_normalization = bool(normalized.get("empirical_normalization", False))

    normalized["actor"] = {
        "class_name": "rsl_rl.models.MLPModel",
        "hidden_dims": actor_hidden_dims,
        "activation": activation,
        "obs_normalization": obs_normalization,
        "distribution_cfg": {
            "class_name": "rsl_rl.modules.distribution.GaussianDistribution",
            "init_std": init_noise_std,
            "std_type": "scalar",
        },
    }
    normalized["critic"] = {
        "class_name": "rsl_rl.models.MLPModel",
        "hidden_dims": critic_hidden_dims,
        "activation": activation,
        "obs_normalization": obs_normalization,
    }

    obs_groups = normalized.get("obs_groups")
    if isinstance(obs_groups, dict) and "actor" not in obs_groups and "default" in obs_groups:
        # "default" is always removed; it becomes "actor" only when it is a
        # non-empty list.
        default_groups = obs_groups.pop("default")
        if isinstance(default_groups, list) and default_groups:
            obs_groups["actor"] = list(default_groups)

    return normalized
class RslRlVecEnvWrapper:
    """Adapter from UniLab's env contract to the RSL-RL VecEnv contract.

    The wrapper converts the wrapped environment's observation dictionaries
    into ``TensorDict`` objects on ``device``, tracks per-environment episode
    returns and lengths, and exposes the attributes read by RSL-RL runners
    (``num_envs``, ``num_obs``, ``num_privileged_obs``, ``num_actions``,
    ``episode_length_buf``, ``max_episode_length``).
    """

    def __init__(
        self,
        env: Any,
        device: str = "cpu",
        policy_obs_mode: str = "flat",
    ) -> None:
        """Wrap ``env`` and perform an initial reset.

        Args:
            env: Environment to wrap. This constructor reads ``cfg``,
                ``num_envs``, ``observation_space``, ``action_space`` and
                ``obs_groups_spec`` from it.
            device: Torch device on which tensors are created.
            policy_obs_mode: ``"actor"`` to feed the policy only the ``"obs"``
                group, ``"flat"`` to feed it every non-critic group
                concatenated. ``"auto"`` is treated as ``"flat"``.

        Raises:
            ValueError: If ``policy_obs_mode`` is unsupported or
                ``env.action_space.shape`` is ``None``.
        """
        if policy_obs_mode == "auto":
            policy_obs_mode = "flat"
        if policy_obs_mode not in {"actor", "flat"}:
            raise ValueError(
                f"Unsupported policy_obs_mode={policy_obs_mode!r}; expected 'actor' or 'flat'."
            )

        self.env = env
        self.cfg = env.cfg
        self.device = device
        self.policy_obs_mode = policy_obs_mode
        self.num_envs = env.num_envs
        self.observation_space = env.observation_space
        self.action_space = env.action_space

        self._actor_obs_dim, self._flat_obs_dim = get_policy_obs_dims(env.obs_groups_spec)
        self.num_obs = (
            self._actor_obs_dim if self.policy_obs_mode == "actor" else self._flat_obs_dim
        )
        # Without a dedicated critic group the critic sees the policy input size.
        self.num_privileged_obs = int(env.obs_groups_spec.get("critic", self.num_obs))

        action_shape = env.action_space.shape
        if action_shape is None:
            raise ValueError("env.action_space.shape must be defined")
        self.num_actions = int(action_shape[0])

        self.episode_returns = torch.zeros(self.num_envs, device=device)
        self.episode_lengths = torch.zeros(self.num_envs, device=device)
        # Alias of ``episode_lengths`` (same tensor object), exposed under the
        # name ``episode_length_buf``.
        self.episode_length_buf = self.episode_lengths
        # NOTE(review): this is the raw ``np.ceil`` result (a NumPy float), not
        # an ``int`` -- confirm that is what consumers expect.
        self.max_episode_length = np.ceil(env.cfg.max_episode_seconds / env.cfg.ctrl_dt)

        self.reset()

    def _policy_obs(self, obs: dict[str, Any]) -> torch.Tensor:
        """Build the policy input tensor from an observation dict.

        In ``"actor"`` mode only ``obs["obs"]`` is used. In ``"flat"`` mode
        every group except ``"critic"`` is concatenated along axis 1, in the
        dict's iteration order.

        Raises:
            KeyError: In ``"flat"`` mode when ``obs`` has no non-critic group.
        """
        if self.policy_obs_mode == "actor":
            return to_torch(obs["obs"], self.device)

        policy_groups = [
            to_numpy(value) for group_name, value in obs.items() if group_name != "critic"
        ]
        if not policy_groups:
            raise KeyError("Observation dict must contain at least one non-critic group")
        if len(policy_groups) == 1:
            return to_torch(policy_groups[0], self.device)
        return to_torch(np.concatenate(policy_groups, axis=1), self.device)

    def _obs_to_tensordict(
        self,
        obs: dict[str, Any],
        info: dict[str, Any] | None = None,
    ) -> TensorDict:
        """Convert an observation dict into a ``TensorDict`` on ``self.device``.

        The result always has ``"actor"`` (from ``obs["obs"]``) and
        ``"policy"`` (see ``_policy_obs``) entries, plus ``"critic"`` when the
        observation dict provides one. ``info`` is accepted but not used.
        """
        del info
        td_dict: dict[str, torch.Tensor] = {
            "actor": to_torch(obs["obs"], self.device),
            "policy": self._policy_obs(obs),
        }
        if "critic" in obs:
            td_dict["critic"] = to_torch(obs["critic"], self.device)
        return TensorDict(td_dict, batch_size=self.num_envs, device=self.device)

    def _resolve_final_observation(self, state: NpEnvState) -> dict[str, Any] | None:
        """Return the final-observation dict carried by ``state``, if any.

        ``state.final_observation`` takes precedence; otherwise
        ``state.info["final_observation"]`` is used when ``state.info`` is a
        dict. Returns ``None`` when neither location holds a dict.
        """
        if isinstance(state.final_observation, dict):
            return state.final_observation
        if isinstance(state.info, dict):
            final_observation = state.info.get("final_observation")
            if isinstance(final_observation, dict):
                return final_observation
        return None

    def _resolve_done(self, state: NpEnvState) -> torch.Tensor:
        """Return a boolean tensor that is true where ``state`` is terminated or truncated."""
        return to_torch(state.terminated | state.truncated, self.device).bool()

    def step(
        self, actions: torch.Tensor | np.ndarray
    ) -> tuple[TensorDict, torch.Tensor, torch.Tensor, dict]:
        """Advance the wrapped environment by one step.

        Args:
            actions: Actions for all environments; converted with
                ``to_numpy`` before being passed to ``env.step``.

        Returns:
            ``(observations, rewards, dones, infos)``. ``infos`` may contain:

            * ``"time_outs"``: boolean tensor built from ``state.truncated``,
              present only when at least one environment is done.
            * ``"time_out_bootstrap_obs"``: the final observation as a
              ``TensorDict``, present only when the terminal-observation
              contract reports a timeout-terminal environment and a final
              observation is available.
            * ``"log"``: copied from ``state.info["log"]`` when present.
        """
        state = self.env.step(to_numpy(actions))
        rewards = to_torch(state.reward, self.device)
        dones = self._resolve_done(state)

        self.episode_returns += rewards
        self.episode_lengths += 1

        infos: dict[str, torch.Tensor | TensorDict | dict[str, Any]] = {}
        done_idx = torch.nonzero(dones).flatten()
        if done_idx.numel() > 0:
            time_outs = to_torch(state.truncated, self.device).bool()
            infos["time_outs"] = time_outs
            final_observation = self._resolve_final_observation(state)
            # Opaque project helper; only ``timeout_terminal_mask`` of its
            # result is consumed here.
            terminal_contract = resolve_terminal_observation_contract(
                next_obs_batch_size=self.num_envs,
                final_observation=final_observation,
                done=to_numpy(dones),
                info=state.info,
                truncated=to_numpy(time_outs),
            )
            if np.any(terminal_contract.timeout_terminal_mask) and final_observation is not None:
                infos["time_out_bootstrap_obs"] = self._obs_to_tensordict(final_observation)
            # Restart the statistics of the environments that just finished.
            self.episode_returns[done_idx] = 0
            self.episode_lengths[done_idx] = 0

        # ``info`` may be missing or not a dict; only a dict can carry a log.
        step_info = getattr(state, "info", None)
        if isinstance(step_info, dict) and "log" in step_info:
            infos["log"] = step_info["log"]

        return (
            self._obs_to_tensordict(state.obs, step_info),
            rewards,
            dones,
            infos,
        )

    def reset(self) -> tuple[TensorDict, dict[str, Any]]:
        """Reset every environment and zero the episode statistics.

        Calls ``env.init_state()`` first when the environment has no state yet.

        Returns:
            ``(observations, info)`` where ``info`` is whatever
            ``env.reset`` returned alongside the observations.
        """
        if self.env.state is None:
            self.env.init_state()
        env_indices = np.arange(self.num_envs, dtype=np.int32)
        obs_out, info = self.env.reset(env_indices)
        self.episode_returns[:] = 0
        self.episode_lengths[:] = 0
        return self._obs_to_tensordict(obs_out, info), info

    def get_observations(self) -> TensorDict:
        """Return the current observations of the wrapped environment as a ``TensorDict``."""
        assert self.env.state is not None
        return self._obs_to_tensordict(self.env.state.obs, self.env.state.info)

    def get_privileged_observations(self) -> torch.Tensor:
        """Return the ``"critic"`` observations, or ``"obs"`` when no critic group exists."""
        assert self.env.state is not None
        obs = self.env.state.obs
        # Look up "obs" only when needed, so an observation dict that has a
        # "critic" group but no "obs" group does not raise ``KeyError``.
        privileged = obs["critic"] if "critic" in obs else obs["obs"]
        return to_torch(privileged, self.device)

    def close(self) -> None:
        """Close the wrapped environment."""
        self.env.close()