# Source code for unilab.algos.torch.him_ppo.runner

# SPDX-License-Identifier: BSD-3-Clause
#
# HIM-PPO OnPolicy Runner for UniLab.

from __future__ import annotations

import os
import time
from collections import deque
from typing import Any, Callable, cast

import torch
from tensordict import TensorDict

from unilab.algos.torch.him_ppo.actor_critic import HIMActorCritic
from unilab.algos.torch.him_ppo.algorithm import HIMPPO


class _HIMLogger:
    """Tiny statistics holder mirroring the fields train_rsl_rl.py reads for its summary.

    Attributes are plain public fields; the runner appends to the two rolling
    windows and increments the timestep counter directly.
    """

    def __init__(self) -> None:
        # Rolling window of the most recent 100 finished-episode returns.
        self.rewbuffer: deque[float] = deque([], 100)
        # Rolling window of the most recent 100 finished-episode lengths.
        self.lenbuffer: deque[float] = deque([], 100)
        # Running count of environment steps collected so far.
        self.tot_timesteps: int = 0


class HIMOnPolicyRunner:
    """On-policy training runner for HIM-PPO.

    Provides the same external interface as rsl_rl's ``OnPolicyRunner`` so that
    ``train_him_ppo.py`` can share most of its logic with ``train_rsl_rl.py``.
    """

    def __init__(
        self,
        env: Any,
        train_cfg: dict[str, Any],
        log_dir: str | None = None,
        device: str = "cpu",
    ) -> None:
        """Build the actor-critic, the HIM-PPO algorithm and rollout storage.

        Args:
            env: Environment object. This constructor reads ``num_envs``,
                ``num_actions``, ``num_obs`` and (optionally)
                ``num_privileged_obs`` from it.
            train_cfg: Training configuration. ``num_one_step_obs`` is
                required; ``num_actor_history`` (default 1),
                ``num_steps_per_env`` (default 24), ``save_interval``
                (default 100) and the ``policy`` / ``estimator`` /
                ``algorithm`` sub-dicts are optional.
            log_dir: Directory for checkpoints and TensorBoard events, or
                ``None`` to disable both.
            device: Torch device string used for the model and tracking tensors.

        Raises:
            KeyError: If ``train_cfg`` has no ``num_one_step_obs`` entry.
        """
        self.env = env
        self.device = device
        self.log_dir = log_dir
        self.current_learning_iteration: int = 0
        self.logger = _HIMLogger()

        # ── Parse config ────────────────────────────────────────────────────
        cfg: dict[str, Any] = dict(train_cfg)
        num_one_step_obs = int(cfg["num_one_step_obs"])
        num_actor_history = int(cfg.get("num_actor_history", 1))
        # The actor consumes a flattened history of single-step observations.
        num_actor_obs = num_actor_history * num_one_step_obs
        # Fall back to the plain observation size when the env exposes no
        # (or a zero/None) privileged observation size.
        num_critic_obs = int(getattr(env, "num_privileged_obs", None) or env.num_obs)
        num_actions = int(env.num_actions)

        policy_cfg: dict[str, Any] = dict(cfg.get("policy") or {})
        estimator_cfg: dict[str, Any] = dict(cfg.get("estimator") or {})
        algo_cfg: dict[str, Any] = dict(cfg.get("algorithm") or {})

        # ── Build actor-critic ───────────────────────────────────────────────
        self.actor_critic = HIMActorCritic(
            num_actor_obs=num_actor_obs,
            num_critic_obs=num_critic_obs,
            num_one_step_obs=num_one_step_obs,
            num_actions=num_actions,
            actor_hidden_dims=list(policy_cfg.get("actor_hidden_dims", [512, 256, 128])),
            critic_hidden_dims=list(policy_cfg.get("critic_hidden_dims", [512, 256, 128])),
            activation=str(policy_cfg.get("activation", "elu")),
            init_noise_std=float(policy_cfg.get("init_noise_std", 1.0)),
            estimator=estimator_cfg,
        ).to(device)

        # ── Build algorithm ──────────────────────────────────────────────────
        self.alg = HIMPPO(self.actor_critic, device=device, **algo_cfg)

        self.num_steps_per_env = int(cfg.get("num_steps_per_env", 24))
        self.save_interval = int(cfg.get("save_interval", 100))

        # ── Init storage ─────────────────────────────────────────────────────
        self.alg.init_storage(
            env.num_envs,
            self.num_steps_per_env,
            [num_actor_obs],
            [num_critic_obs],
            [num_actions],
        )

        # Per-env episode tracking (independent of env wrapper's counters)
        self._ep_returns = torch.zeros(env.num_envs, device=device)
        self._ep_lengths = torch.zeros(env.num_envs, device=device)

        # Tensorboard writer (optional)
        self._writer: Any = None
        if log_dir is not None:
            try:
                from torch.utils.tensorboard import SummaryWriter  # type: ignore[import-untyped]

                self._writer = SummaryWriter(log_dir=log_dir)
            except ImportError:
                # Deliberate best-effort: training proceeds without
                # TensorBoard logging when the package is not installed.
                pass

    # ── Public interface ─────────────────────────────────────────────────────

    def learn(
        self,
        num_learning_iterations: int,
        init_at_random_ep_len: bool = True,
    ) -> None:
        """Run rollout collection and policy updates.

        Continues from ``self.current_learning_iteration`` so a loaded
        checkpoint resumes its iteration numbering.

        Args:
            num_learning_iterations: Number of collect/update iterations to run.
            init_at_random_ep_len: When true, overwrite the environment's
                ``episode_length_buf`` with random values below
                ``max_episode_length`` before training starts.
        """
        obs_td, _ = self.env.reset()
        obs = obs_td["actor"].to(self.device)
        critic_obs = obs_td.get("critic", obs).to(self.device)

        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(
                self.env.episode_length_buf,
                high=int(self.env.max_episode_length),
            )

        self.alg.train_mode()

        start_iter = self.current_learning_iteration
        tot_iter = start_iter + num_learning_iterations
        start_time = time.time()

        for it in range(start_iter, tot_iter):
            infos: dict[str, Any] = {}

            # ── Rollout collection ───────────────────────────────────────────
            with torch.inference_mode():
                for _ in range(self.num_steps_per_env):
                    actions = self.alg.act(obs, critic_obs)
                    obs_td, rewards, dones, infos = self.env.step(actions)
                    next_obs = obs_td["actor"].to(self.device)
                    next_critic_obs = obs_td.get("critic", next_obs).to(self.device)

                    # Track episode stats before env wrapper resets counters.
                    # Indices are moved to the runner device so they can index
                    # the tracking tensors regardless of where the env lives.
                    done_ids = dones.to(self.device).nonzero(as_tuple=False).flatten()
                    self._ep_returns += rewards.to(self.device)
                    self._ep_lengths += 1
                    if done_ids.numel() > 0:
                        # One bulk transfer instead of one host sync per env.
                        self.logger.rewbuffer.extend(self._ep_returns[done_ids].tolist())
                        self.logger.lenbuffer.extend(self._ep_lengths[done_ids].tolist())
                        self._ep_returns[done_ids] = 0.0
                        self._ep_lengths[done_ids] = 0.0

                    self.alg.process_env_step(obs_td, rewards, dones, infos)
                    obs = next_obs
                    critic_obs = next_critic_obs

                # NOTE(review): kept inside the inference-mode block, following
                # the rsl_rl runner layout — confirm against the upstream file.
                self.alg.compute_returns(critic_obs)

            # ── Update ───────────────────────────────────────────────────────
            value_loss, surrogate_loss, estimation_loss, swap_loss = self.alg.update()
            self.current_learning_iteration = it + 1
            self.logger.tot_timesteps += self.num_steps_per_env * self.env.num_envs

            # ── Logging ──────────────────────────────────────────────────────
            elapsed = time.time() - start_time
            self._print_iter(
                it + 1,
                tot_iter,
                value_loss,
                surrogate_loss,
                estimation_loss,
                swap_loss,
                elapsed,
                infos,
                start_it=start_iter,
            )
            if self._writer is not None:
                global_step = self.current_learning_iteration
                self._writer.add_scalar("train/value_loss", value_loss, global_step)
                self._writer.add_scalar("train/surrogate_loss", surrogate_loss, global_step)
                self._writer.add_scalar("train/estimation_loss", estimation_loss, global_step)
                self._writer.add_scalar("train/swap_loss", swap_loss, global_step)
                for k, v in (infos.get("log") or {}).items():
                    self._writer.add_scalar(k, v, global_step)

            # ── Checkpoint ───────────────────────────────────────────────────
            # A save_interval of 0 disables periodic checkpoints instead of
            # raising ZeroDivisionError in the modulo below.
            if (
                self.log_dir is not None
                and self.save_interval
                and self.current_learning_iteration % self.save_interval == 0
            ):
                self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))

        # Final checkpoint
        if self.log_dir is not None:
            self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))

        # Push buffered scalars to disk so they are not lost if the process
        # exits right after training.
        if self._writer is not None:
            self._writer.flush()

    def save(self, path: str) -> None:
        """Write model weights, optimizer state and iteration counter to ``path``.

        The parent directory is created when ``path`` has one; a bare file
        name is written into the current working directory.
        """
        parent = os.path.dirname(path)
        if parent:
            # os.makedirs("") raises FileNotFoundError, so skip it for bare names.
            os.makedirs(parent, exist_ok=True)
        torch.save(
            {
                "actor_state_dict": self.actor_critic.state_dict(),
                "optimizer_state_dict": self.alg.optimizer.state_dict(),
                "iteration": self.current_learning_iteration,
            },
            path,
        )

    def load(self, path: str) -> None:
        """Restore a checkpoint written by :meth:`save`.

        The optimizer state and iteration counter are restored only when the
        checkpoint contains them; the model weights are always required.
        """
        ckpt = torch.load(path, map_location=self.device, weights_only=True)
        self.actor_critic.load_state_dict(ckpt["actor_state_dict"])
        if "optimizer_state_dict" in ckpt:
            self.alg.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "iteration" in ckpt:
            self.current_learning_iteration = int(ckpt["iteration"])

    def get_inference_policy(self, device: str | None = None) -> Callable[..., Any]:
        """Return the actor-critic's ``act_inference`` callable.

        Switches the model to eval mode and, when ``device`` is given, moves
        it to that device first.
        """
        self.actor_critic.eval()
        if device is not None:
            self.actor_critic.to(device)
        return cast(Callable[..., Any], self.actor_critic.act_inference)

    def export_policy_to_onnx(self, path: str, filename: str = "policy.onnx") -> None:
        """Export the end-to-end HIM-PPO policy (estimator + actor) to ONNX.

        Input:  obs_history  (1, H_a * num_one_step_obs) — flattened actor obs history
        Output: actions      (1, num_actions)

        The model is moved to CPU and put in eval mode for the export. It is
        moved back to its original device afterwards, even if the export
        fails; it is left in eval mode.

        Args:
            path: Output directory, created if missing.
            filename: Name of the ONNX file inside ``path``.
        """

        class _PolicyExport(torch.nn.Module):
            def __init__(self, ac: Any) -> None:
                super().__init__()
                self.ac = ac

            def forward(self, obs_history: torch.Tensor) -> torch.Tensor:
                return self.ac.act_inference(obs_history)

        orig_device = next(self.actor_critic.parameters()).device
        try:
            model = _PolicyExport(self.actor_critic).cpu().eval()
            dummy = torch.zeros(1, self.actor_critic.num_actor_obs)

            os.makedirs(path, exist_ok=True)
            save_path = os.path.join(path, filename)

            with torch.inference_mode():
                torch.onnx.export(
                    model,
                    (dummy,),
                    save_path,
                    export_params=True,
                    opset_version=18,
                    input_names=["obs_history"],
                    output_names=["actions"],
                    dynamic_axes={"obs_history": {0: "batch_size"}, "actions": {0: "batch_size"}},
                )
        finally:
            # Never leave the training model stranded on CPU after a failed export.
            self.actor_critic.to(orig_device)
        print(f"Exported HIM-PPO policy to {save_path}")

    def export_policy_to_jit(self, path: str, filename: str = "policy.pt") -> None:
        """Export the end-to-end HIM-PPO policy (estimator + actor) via TorchScript trace.

        Wraps estimator + actor as a plain module (without HIMActorCritic's
        properties) so that torch.jit.trace can introspect it without hitting
        the distribution assert.

        The model is moved to CPU and put in eval mode for the trace. It is
        moved back to its original device afterwards, even if tracing or
        saving fails; it is left in eval mode.

        Args:
            path: Output directory, created if missing.
            filename: Name of the TorchScript file inside ``path``.
        """
        orig_device = next(self.actor_critic.parameters()).device
        try:
            ac = self.actor_critic.cpu().eval()
            num_one_step_obs = ac.num_one_step_obs

            class _PolicyExport(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.estimator = ac.estimator
                    self.actor_mlp = ac.actor

                def forward(self, obs_history: torch.Tensor) -> torch.Tensor:
                    vel, latent = self.estimator.get_latent(obs_history)
                    # Actor input: leading one-step slice of the history plus
                    # the estimator's two outputs.
                    actor_input = torch.cat(
                        (obs_history[:, :num_one_step_obs], vel, latent), dim=-1
                    )
                    return self.actor_mlp(actor_input)

            model = _PolicyExport().eval()
            dummy = torch.zeros(1, ac.num_actor_obs)

            with torch.inference_mode():
                traced = cast(torch.jit.ScriptModule, torch.jit.trace(model, (dummy,)))

            os.makedirs(path, exist_ok=True)
            save_path = os.path.join(path, filename)
            traced.save(save_path)
        finally:
            # Never leave the training model stranded on CPU after a failed export.
            self.actor_critic.to(orig_device)
        print(f"Exported HIM-PPO policy (JIT) to {save_path}")

    # ── Helpers ──────────────────────────────────────────────────────────────

    @staticmethod
    def _format_duration(seconds: float) -> str:
        """Format a duration as ``HH:MM:SS`` without wrapping at 24 hours."""
        total = max(int(seconds), 0)
        hours, rem = divmod(total, 3600)
        minutes, secs = divmod(rem, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"

    def _print_iter(
        self,
        it: int,
        tot: int,
        value_loss: float,
        surrogate_loss: float,
        estimation_loss: float,
        swap_loss: float,
        elapsed: float,
        infos: dict,
        start_it: int = 0,
    ) -> None:
        """Print a console summary for one training iteration.

        Args:
            it: Iteration number just completed (1-based, absolute).
            tot: Absolute iteration number at which this run ends.
            value_loss: Mean value loss of the update.
            surrogate_loss: Mean surrogate loss of the update.
            estimation_loss: Mean estimation loss of the update.
            swap_loss: Mean swap loss of the update.
            elapsed: Seconds since the current ``learn`` call started.
            infos: Last step's info dict; entries under ``"log"`` are printed.
            start_it: Absolute iteration at which the current ``learn`` call
                started. ``elapsed`` only covers iterations after this point,
                so the ETA is based on ``it - start_it`` completed iterations.
        """
        sep = "-" * 80
        rewards = self.logger.rewbuffer
        lengths = self.logger.lenbuffer

        time_str = self._format_duration(elapsed)
        done_this_run = it - start_it
        eta = elapsed / done_this_run * (tot - it) if done_this_run > 0 else 0.0
        eta_str = self._format_duration(eta)

        print(sep)
        print(f"{'Iteration':>40}: {it}/{tot}")
        print(f"{'Mean value loss':>40}: {value_loss:.4f}")
        print(f"{'Mean surrogate loss':>40}: {surrogate_loss:.4f}")
        print(f"{'Mean estimation loss':>40}: {estimation_loss:.4f}")
        print(f"{'Mean swap loss':>40}: {swap_loss:.4f}")
        # Gate on "any finished episode", not on the mean's truthiness, so a
        # mean of exactly 0.0 is still reported.
        if rewards:
            mean_rew = sum(rewards) / len(rewards)
            print(f"{'Mean episode reward':>40}: {mean_rew:.4f}")
        if lengths:
            mean_len = sum(lengths) / len(lengths)
            print(f"{'Mean episode length':>40}: {mean_len:.1f}")
        for k, v in (infos.get("log") or {}).items():
            print(f"{k:>40}: {v:.4f}")
        print(f"{'Time elapsed':>40}: {time_str}")
        print(f"{'ETA':>40}: {eta_str}")
        print(sep)