Source code for unilab.algos.torch.hora.ppo

from __future__ import annotations

from collections.abc import Callable
from itertools import chain
from typing import Any, cast

import torch
import torch.optim as optim
from rsl_rl.algorithms.ppo import PPO
from rsl_rl.env import VecEnv
from rsl_rl.extensions import resolve_rnd_config, resolve_symmetry_config
from rsl_rl.storage import RolloutStorage
from rsl_rl.utils import resolve_obs_groups, resolve_optimizer
from tensordict import TensorDict

from unilab.algos.torch.hora.models import HoraActorModel, HoraCriticModel, HoraSharedActorCritic
from unilab.algos.torch.rsl_rl_ppo import FinalObservationAwarePPO


class HoraPPO(FinalObservationAwarePPO):
    """PPO variant that constructs a shared HORA actor-critic backbone.

    ``construct_algorithm`` builds a single ``HoraSharedActorCritic`` and hands
    it to both the actor and the critic wrapper, so the two models may expose
    the same ``Parameter`` objects. The optimizer is therefore created from a
    de-duplicated parameter list (see ``_unique_trainable_parameters``).
    """

    def __init__(
        self,
        actor: HoraActorModel,
        critic: HoraCriticModel,
        storage: RolloutStorage,
        num_learning_epochs: int = 5,
        num_mini_batches: int = 4,
        clip_param: float = 0.2,
        gamma: float = 0.99,
        lam: float = 0.95,
        value_loss_coef: float = 1.0,
        entropy_coef: float = 0.01,
        learning_rate: float = 0.001,
        max_grad_norm: float = 1.0,
        optimizer: str = "adam",
        use_clipped_value_loss: bool = True,
        schedule: str = "adaptive",
        desired_kl: float = 0.01,
        normalize_advantage_per_mini_batch: bool = False,
        device: str = "cpu",
        rnd_cfg: dict | None = None,
        symmetry_cfg: dict | None = None,
        multi_gpu_cfg: dict | None = None,
        enable_compile: bool = False,
    ) -> None:
        """Set up models, optimizer, storage and hyper-parameters.

        This initializer assigns every attribute itself and does not delegate
        to the parent initializer.

        Args:
            actor: Actor model; moved to ``device``.
            critic: Critic model; moved to ``device``.
            storage: Rollout storage that receives the collected transitions.
            num_learning_epochs, num_mini_batches, clip_param, gamma, lam,
            value_loss_coef, entropy_coef, learning_rate, max_grad_norm,
            use_clipped_value_loss, schedule, desired_kl,
            normalize_advantage_per_mini_batch: Hyper-parameters stored
                unchanged on the instance. Only ``gamma`` (time-out
                bootstrapping) and ``learning_rate`` (optimizer) are used
                inside this class; the rest are presumably consumed by the
                inherited update code, which is not part of this class.
            optimizer: Optimizer name, resolved through ``resolve_optimizer``.
            device: Device the models are moved to.
            rnd_cfg: Optional random-network-distillation configuration. Its
                ``"learning_rate"`` entry (default ``1e-3``) configures the
                predictor optimizer; the remaining entries are forwarded to
                ``RandomNetworkDistillation``. The dict is not modified.
            symmetry_cfg: Optional symmetry configuration. Its
                ``"data_augmentation_func"`` entry is replaced in place by the
                resolved callable.
            multi_gpu_cfg: Optional dict providing ``"global_rank"`` and
                ``"world_size"``.
            enable_compile: Accepted for config compatibility only; the
                attribute is always set to ``False``.

        Raises:
            ValueError: If the symmetry function does not resolve to a
                callable, or symmetry is configured with a recurrent model.
        """
        self.device = device

        # Multi-GPU bookkeeping; fall back to a single-process layout.
        self.is_multi_gpu = multi_gpu_cfg is not None
        if multi_gpu_cfg is not None:
            self.gpu_global_rank = multi_gpu_cfg["global_rank"]
            self.gpu_world_size = multi_gpu_cfg["world_size"]
        else:
            self.gpu_global_rank = 0
            self.gpu_world_size = 1

        if rnd_cfg:
            # Work on a shallow copy: popping from the caller's dict would make
            # a second construction from the same config silently fall back to
            # the default RND learning rate.
            rnd_kwargs = dict(rnd_cfg)
            rnd_lr = rnd_kwargs.pop("learning_rate", 1e-3)
            from rsl_rl.extensions import RandomNetworkDistillation

            self.rnd = RandomNetworkDistillation(device=self.device, **rnd_kwargs)
            self.rnd_optimizer = optim.Adam(self.rnd.predictor.parameters(), lr=rnd_lr)
        else:
            self.rnd = None
            self.rnd_optimizer = None

        self.symmetry: dict[str, Any] | None
        if symmetry_cfg is not None:
            use_symmetry = (
                symmetry_cfg["use_data_augmentation"] or symmetry_cfg["use_mirror_loss"]
            )
            if not use_symmetry:
                print("Symmetry not used for learning. We will use it for logging instead.")
            from rsl_rl.utils import resolve_callable

            symmetry_cfg["data_augmentation_func"] = resolve_callable(
                symmetry_cfg["data_augmentation_func"]
            )
            if not callable(symmetry_cfg["data_augmentation_func"]):
                raise ValueError(
                    "Symmetry configuration exists but the function is not callable: "
                    f"{symmetry_cfg['data_augmentation_func']}"
                )
            if actor.is_recurrent or critic.is_recurrent:
                raise ValueError("Symmetry augmentation is not supported for recurrent policies.")
            self.symmetry = symmetry_cfg
        else:
            self.symmetry = None

        # Assign through an ``Any`` view so the HORA model types can be stored
        # without clashing with attribute types declared on the parent class.
        self_ref = cast(Any, self)
        self_ref.actor = actor.to(self.device)
        self_ref.critic = critic.to(self.device)

        # Shared parameters must reach the optimizer exactly once.
        optimizer_cls = cast(Callable[..., optim.Optimizer], resolve_optimizer(optimizer))
        self.optimizer = optimizer_cls(
            self._unique_trainable_parameters(),
            lr=learning_rate,
        )

        self.storage = storage
        self.transition = RolloutStorage.Transition()

        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss
        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate
        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch

        # FinalObservationAwarePPO supports a compiled MLP fast path. HORA PPO
        # uses grouped TensorDict models with a shared privileged trunk, so keep
        # the regular RSL-RL update while accepting the shared PPO config field.
        self.enable_compile = False

    def _unique_trainable_parameters(self) -> list[torch.nn.Parameter]:
        """Return actor and critic parameters with shared objects listed once.

        Parameters are de-duplicated by object identity and keep the order of
        their first occurrence (actor first, then critic). No filtering on
        ``requires_grad`` is performed.
        """
        params: list[torch.nn.Parameter] = []
        seen_ids: set[int] = set()
        for param in chain(self.actor.parameters(), self.critic.parameters()):
            param_id = id(param)
            if param_id in seen_ids:
                continue
            seen_ids.add(param_id)
            params.append(param)
        return params

    @staticmethod
    def construct_algorithm(obs: TensorDict, env: VecEnv, cfg: dict, device: str) -> PPO:
        """Build the shared backbone, actor, critic and storage from ``cfg``.

        Args:
            obs: Example observations; the ``"actor"``, ``"priv_info"`` and
                ``"proprio_hist"`` entries are read to size the networks.
            env: Environment providing ``num_actions`` and ``num_envs``.
            cfg: Training configuration. ``cfg["obs_groups"]`` and
                ``cfg["algorithm"]`` are overwritten in place with their
                resolved versions; the ``"actor"``, ``"critic"`` and
                ``"algorithm"`` sections are copied before keys are popped.
            device: Device on which the models and the storage are created.

        Returns:
            A ``HoraPPO`` instance whose actor and critic share one
            ``HoraSharedActorCritic``.
        """
        cfg["obs_groups"] = resolve_obs_groups(obs, cfg["obs_groups"], ["actor", "critic"])
        cfg["algorithm"] = resolve_rnd_config(cfg["algorithm"], obs, cfg["obs_groups"], env)
        cfg["algorithm"] = resolve_symmetry_config(cfg["algorithm"], env)

        # Copy the model sections so the pops below leave ``cfg`` untouched.
        actor_cfg = dict(cfg["actor"])
        critic_cfg = dict(cfg["critic"])
        actor_cfg.pop("class_name", None)
        critic_cfg.pop("class_name", None)

        # History dimensions default to the shape of the example observation.
        proprio_hist_len = int(actor_cfg.pop("proprio_hist_len", obs["proprio_hist"].shape[1]))
        proprio_frame_dim = int(
            actor_cfg.pop("proprio_frame_dim", obs["proprio_hist"].shape[-1])
        )

        # Every backbone option is popped from ``actor_cfg`` so that only the
        # remaining keys are forwarded to ``HoraActorModel`` below.
        shared_model = HoraSharedActorCritic(
            obs_dim=int(obs["actor"].shape[-1]),
            action_dim=int(env.num_actions),
            priv_info_dim=int(obs["priv_info"].shape[-1]),
            actor_hidden_dims=actor_cfg.pop("hidden_dims", (512, 256, 128)),
            activation=actor_cfg.pop("activation", "elu"),
            obs_normalization=bool(actor_cfg.pop("obs_normalization", False)),
            distribution_cfg=actor_cfg.pop("distribution_cfg", None),
            priv_info_embed_dim=int(
                actor_cfg.pop("priv_info_embed_dim", obs["priv_info"].shape[-1])
            ),
            priv_mlp_hidden_dims=actor_cfg.pop("priv_mlp_hidden_dims", (256, 128, 8)),
            use_student_encoder=bool(actor_cfg.pop("use_student_encoder", False)),
            proprio_hist_len=proprio_hist_len,
            proprio_frame_dim=proprio_frame_dim,
        ).to(device)

        actor = HoraActorModel(
            obs,
            cfg["obs_groups"],
            "actor",
            env.num_actions,
            shared_model=shared_model,
            **actor_cfg,
        ).to(device)
        critic = HoraCriticModel(
            obs,
            cfg["obs_groups"],
            "critic",
            1,
            shared_model=shared_model,
            **critic_cfg,
        ).to(device)

        storage = RolloutStorage(
            "rl", env.num_envs, cfg["num_steps_per_env"], obs, [env.num_actions], device
        )

        algorithm_cfg = dict(cfg["algorithm"])
        algorithm_cfg.pop("class_name", None)
        return HoraPPO(
            actor, critic, storage, device=device, **algorithm_cfg, multi_gpu_cfg=cfg["multi_gpu"]
        )

    def process_env_step(
        self,
        obs: TensorDict,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        extras: dict[str, torch.Tensor | TensorDict],
    ) -> None:
        """Record one environment step in the rollout storage.

        Updates the actor (and RND) normalization, fills the pending
        transition with rewards and dones, optionally adds the RND intrinsic
        reward, applies time-out bootstrapping, stores the transition and
        resets the actor and critic for finished environments.

        Time-out bootstrapping is applied when ``extras["time_outs"]`` is a
        tensor: ``gamma * value * timeout_mask`` is added to the rewards. The
        value comes from the critic evaluated on
        ``extras["time_out_bootstrap_obs"]`` when that entry is a
        ``TensorDict`` containing ``"priv_info"`` and at least one environment
        timed out; otherwise the value already stored on the transition is
        used.

        Args:
            obs: Observations after the step.
            rewards: Step rewards; cloned before being modified.
            dones: Done flags of the step.
            extras: Extra step information, see above for the keys read.
        """
        self.actor.update_normalization(obs)
        if self.rnd:
            self.rnd.update_normalization(obs)

        # Clone so the in-place additions below never touch the caller's tensor.
        self.transition.rewards = rewards.clone()
        self.transition.dones = dones

        if self.rnd:
            self.intrinsic_rewards = self.rnd.get_intrinsic_reward(obs)
            self.transition.rewards += self.intrinsic_rewards

        timeouts = extras.get("time_outs")
        timeout_bootstrap_obs = extras.get("time_out_bootstrap_obs")
        if isinstance(timeouts, torch.Tensor):
            timeout_mask = timeouts.to(self.device).float()
            # Membership is tested on ``keys()``: that is the documented lookup
            # and, unlike ``key in tensordict``, it does not depend on the
            # installed tensordict release supporting ``in`` directly.
            can_bootstrap = (
                isinstance(timeout_bootstrap_obs, TensorDict)
                and "priv_info" in timeout_bootstrap_obs.keys()
                and bool(torch.count_nonzero(timeout_mask) > 0)
            )
            if can_bootstrap:
                assert isinstance(timeout_bootstrap_obs, TensorDict)
                bootstrap_obs = timeout_bootstrap_obs.to(self.device)
                # The values only act as a reward correction, so no autograd
                # graph is needed for this forward pass.
                with torch.no_grad():
                    bootstrap_values = self.critic(bootstrap_obs)
                self.transition.rewards += self.gamma * torch.squeeze(
                    bootstrap_values * timeout_mask.unsqueeze(1), 1
                )
            else:
                transition_values = self.transition.values
                assert transition_values is not None
                self.transition.rewards += self.gamma * torch.squeeze(
                    transition_values * timeout_mask.unsqueeze(1), 1
                )

        self.storage.add_transition(self.transition)
        self.transition.clear()
        self.actor.reset(dones)
        self.critic.reset(dones)