"""Source code for ``unilab.algos.torch.hora.models``: HORA actor-critic model definitions."""

from __future__ import annotations

import copy
from dataclasses import dataclass
from typing import Any, cast

import torch
import torch.nn as nn
from rsl_rl.modules import EmpiricalNormalization, GaussianDistribution
from tensordict import TensorDict


def _build_activation(name: str) -> nn.Module:
    """Instantiate the activation layer selected by ``name``.

    The lookup ignores case and surrounding whitespace. Supported names are
    ``"elu"``, ``"relu"`` and ``"tanh"``.

    Args:
        name: Activation identifier.

    Returns:
        A freshly constructed activation module.

    Raises:
        ValueError: If ``name`` does not match a supported activation.
    """
    factories: dict[str, type[nn.Module]] = {
        "elu": nn.ELU,
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
    }
    factory = factories.get(str(name).strip().lower())
    if factory is None:
        # Report the caller's original spelling, not the normalized key.
        raise ValueError(f"Unsupported activation: {name!r}")
    return factory()


class _MLP(nn.Module):
    """Stack of ``Linear`` + activation pairs, one pair per hidden width.

    The activation follows every linear layer, including the last one.

    Attributes:
        net: ``nn.Sequential`` holding the layers. It is empty when ``hidden_dims``
            is empty, in which case the input passes through unchanged.
        output_dim: Width of the last linear layer, or ``input_dim`` when no layer
            was built.
    """

    def __init__(
        self, input_dim: int, hidden_dims: list[int] | tuple[int, ...], activation: str
    ) -> None:
        """Build the layer stack.

        Args:
            input_dim: Feature width of the input.
            hidden_dims: Output width of each linear layer, in order.
            activation: Activation name understood by ``_build_activation``.
        """
        super().__init__()
        stages: list[nn.Module] = []
        width_in = input_dim
        for width_out in map(int, hidden_dims):
            stages += [nn.Linear(width_in, width_out), _build_activation(activation)]
            width_in = width_out
        self.net = nn.Sequential(*stages)
        self.output_dim = width_in

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x``."""
        return self.net(x)


class ProprioAdaptTConv(nn.Module):
    """Temporal adaptation encoder used by HORA stage-2 distillation.

    A per-frame MLP is applied along the last axis, then three unpadded 1-D
    convolutions (kernel 9 / stride 2, kernel 5 / stride 1, kernel 5 / stride 1)
    run over the time axis, and the flattened result is projected to the latent.

    Note:
        The projection takes ``frame_dim * 3`` features, so the convolution stack
        must leave exactly three time steps. With the kernel sizes and strides
        above, a 30-step history does.
    """

    def __init__(self, frame_dim: int, latent_dim: int) -> None:
        """Create the encoder.

        Args:
            frame_dim: Feature width of one history frame; also the channel count
                of every convolution.
            latent_dim: Width of the returned latent.
        """
        super().__init__()
        self.channel_transform = nn.Sequential(
            nn.Linear(frame_dim, frame_dim),
            nn.ReLU(inplace=True),
            nn.Linear(frame_dim, frame_dim),
            nn.ReLU(inplace=True),
        )
        # (kernel_size, stride) of each temporal convolution, in order.
        temporal_layers: list[nn.Module] = []
        for kernel_size, stride in ((9, 2), (5, 1), (5, 1)):
            temporal_layers.append(
                nn.Conv1d(frame_dim, frame_dim, kernel_size=kernel_size, stride=stride)
            )
            temporal_layers.append(nn.ReLU(inplace=True))
        self.temporal_aggregation = nn.Sequential(*temporal_layers)
        self.low_dim_proj = nn.Linear(frame_dim * 3, latent_dim)
        self._init_weights()

    def _init_weights(self) -> None:
        """Re-draw convolution weights and zero every conv/linear bias.

        Convolution weights are sampled from a zero-mean normal with standard
        deviation ``sqrt(2 / (kernel_size * out_channels))``. Linear weights keep
        the initialization they were constructed with.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv1d):
                fan_out = layer.kernel_size[0] * layer.out_channels
                nn.init.normal_(layer.weight, mean=0.0, std=(2.0 / fan_out) ** 0.5)
            if isinstance(layer, (nn.Conv1d, nn.Linear)) and layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, proprio_hist: torch.Tensor) -> torch.Tensor:
        """Encode a proprioceptive history into a latent vector.

        Args:
            proprio_hist: Rank-3 tensor whose last axis has ``frame_dim`` entries
                and whose middle axis is the one the convolutions slide over.

        Returns:
            Tensor with ``latent_dim`` features per leading (batch) entry.
        """
        per_frame = self.channel_transform(proprio_hist)
        # Conv1d convolves over the last axis, so move the feature axis to the
        # channel position first.
        channels_first = per_frame.permute(0, 2, 1)
        aggregated = self.temporal_aggregation(channels_first)
        return self.low_dim_proj(aggregated.flatten(1))


@dataclass
class HoraCoreOutput:
    """Tensors produced by one pass through the shared HORA core."""

    # Actor observation after the observation normalizer has been applied.
    policy_obs: torch.Tensor
    # Trunk output for the concatenation of ``policy_obs`` and ``privileged_latent``.
    trunk_latent: torch.Tensor
    # Latent that was actually fed to the trunk (student, teacher, or zeros).
    privileged_latent: torch.Tensor
    # Teacher encoding of ``priv_info``; zeros when ``priv_info`` was not provided.
    privileged_target: torch.Tensor


class HoraSharedActorCritic(nn.Module):
    """Shared-backbone HORA actor-critic with optional adaptation encoder.

    One trunk MLP consumes the normalized actor observation concatenated with a
    privileged latent and feeds both the action-mean head and the value head.
    The latent is produced by the student ``adapt_tconv`` encoder, by the teacher
    ``priv_encoder``, or is all zeros, depending on the caller's preference and on
    which inputs are present.

    Note:
        The trunk is sized for ``obs_dim + priv_info_embed_dim`` inputs, while the
        teacher latent has the width of the last ``priv_mlp_hidden_dims`` entry.
        The two must agree for the teacher path to run.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        *,
        priv_info_dim: int,
        priv_info_embed_dim: int = 8,
        actor_hidden_dims: list[int] | tuple[int, ...] = (512, 256, 128),
        priv_mlp_hidden_dims: list[int] | tuple[int, ...] = (256, 128, 8),
        activation: str = "elu",
        obs_normalization: bool = False,
        distribution_cfg: dict[str, Any] | None = None,
        use_student_encoder: bool = False,
        proprio_hist_len: int = 30,
        proprio_frame_dim: int | None = None,
    ) -> None:
        """Build the encoders, trunk, heads and action distribution.

        Args:
            obs_dim: Width of the ``"actor"`` observation.
            action_dim: Width of the action mean.
            priv_info_dim: Width of the privileged-information input.
            priv_info_embed_dim: Width of the latent concatenated to the observation.
            actor_hidden_dims: Hidden widths of the shared trunk.
            priv_mlp_hidden_dims: Hidden widths of the teacher encoder.
            activation: Activation name used by the trunk and the teacher encoder.
            obs_normalization: Use ``EmpiricalNormalization`` on the actor observation
                instead of an identity.
            distribution_cfg: Keyword arguments for ``GaussianDistribution``; a
                ``"class_name"`` entry is ignored. Defaults to
                ``{"init_std": 1.0, "std_type": "scalar"}``.
            use_student_encoder: Create the ``ProprioAdaptTConv`` student encoder.
            proprio_hist_len: Stored as ``self.proprio_hist_len``; not otherwise used
                by this class.
            proprio_frame_dim: Frame width for the student encoder; defaults to
                ``obs_dim // 3``.
        """
        super().__init__()
        self.obs_dim = int(obs_dim)
        self.action_dim = int(action_dim)
        self.priv_info_dim = int(priv_info_dim)
        self.priv_info_embed_dim = int(priv_info_embed_dim)
        self.use_student_encoder = bool(use_student_encoder)
        self.proprio_hist_len = int(proprio_hist_len)
        if proprio_frame_dim is None:
            self.proprio_frame_dim = self.obs_dim // 3
        else:
            self.proprio_frame_dim = int(proprio_frame_dim)

        self.obs_normalizer = (
            EmpiricalNormalization(self.obs_dim) if obs_normalization else nn.Identity()
        )
        self.priv_encoder = _MLP(self.priv_info_dim, list(priv_mlp_hidden_dims), activation)
        trunk_input_dim = self.obs_dim + self.priv_info_embed_dim
        self.trunk = _MLP(trunk_input_dim, list(actor_hidden_dims), activation)
        self.value_head = nn.Linear(self.trunk.output_dim, 1)
        self.mu_head = nn.Linear(self.trunk.output_dim, self.action_dim)

        if distribution_cfg is None:
            distribution_cfg = {"init_std": 1.0, "std_type": "scalar"}
        # "class_name" selects the distribution class elsewhere; it is not a kwarg.
        distribution_kwargs = {
            key: value for key, value in distribution_cfg.items() if key != "class_name"
        }
        self.distribution = GaussianDistribution(self.action_dim, **distribution_kwargs)

        self.adapt_tconv = (
            ProprioAdaptTConv(self.proprio_frame_dim, self.priv_info_embed_dim)
            if self.use_student_encoder
            else None
        )
        self._init_linear_biases()

    def _init_linear_biases(self) -> None:
        """Zero the bias of every ``nn.Linear`` submodule that has one."""
        for layer in self.modules():
            if not isinstance(layer, nn.Linear) or layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def _normalize_actor_obs(self, actor_obs: torch.Tensor) -> torch.Tensor:
        """Run ``actor_obs`` through the configured observation normalizer."""
        return self.obs_normalizer(actor_obs)

    def update_normalization(self, obs: TensorDict) -> None:
        """Update normalizer statistics from ``obs["actor"]`` when normalization is on."""
        if not isinstance(self.obs_normalizer, EmpiricalNormalization):
            return
        self.obs_normalizer.update(obs["actor"])

    def _zero_privileged_latent(
        self, batch_size: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return a ``(batch_size, priv_info_embed_dim)`` tensor of zeros."""
        shape = (batch_size, self.priv_info_embed_dim)
        return torch.zeros(shape, device=device, dtype=dtype)

    def encode_privileged_info(self, priv_info: torch.Tensor | None) -> torch.Tensor:
        """Return the tanh-squashed teacher latent for ``priv_info``.

        Raises:
            ValueError: If ``priv_info`` is ``None``.
        """
        if priv_info is None:
            raise ValueError("priv_info is required to compute the HORA teacher latent")
        encoded = self.priv_encoder(priv_info)
        return torch.tanh(encoded)

    def encode_proprio_history(self, proprio_hist: torch.Tensor) -> torch.Tensor:
        """Return the tanh-squashed student latent for ``proprio_hist``.

        Raises:
            RuntimeError: If the model was built without the student encoder.
        """
        if self.adapt_tconv is None:
            raise RuntimeError("HORA adaptation encoder is not enabled")
        encoded = self.adapt_tconv(proprio_hist)
        return torch.tanh(encoded)

    def build_core_output(self, obs: TensorDict, *, prefer_student: bool) -> HoraCoreOutput:
        """Run the shared core on an observation dictionary.

        The teacher latent is computed whenever ``"priv_info"`` is present, even when
        the student latent is the one fed to the trunk, so that it can be returned
        as ``privileged_target``.

        Args:
            obs: Observations with an ``"actor"`` entry and optional ``"priv_info"``
                and ``"proprio_hist"`` entries.
            prefer_student: Feed the student latent to the trunk when the student
                encoder exists and ``"proprio_hist"`` is present.

        Returns:
            The normalized observation, trunk output, the latent that was used, and
            the teacher target (zeros when ``"priv_info"`` is absent).
        """
        actor_obs = obs["actor"]
        policy_obs = self._normalize_actor_obs(actor_obs)
        priv_info = obs.get("priv_info")
        proprio_hist = obs.get("proprio_hist")

        has_priv_info = priv_info is not None
        if has_priv_info:
            privileged_target = self.encode_privileged_info(priv_info)
        else:
            privileged_target = self._zero_privileged_latent(
                actor_obs.shape[0], actor_obs.device, actor_obs.dtype
            )

        if prefer_student and self.adapt_tconv is not None and proprio_hist is not None:
            privileged_latent = self.encode_proprio_history(proprio_hist)
        elif has_priv_info:
            # Reuse the teacher encoding instead of running the encoder twice.
            privileged_latent = privileged_target
        else:
            privileged_latent = self._zero_privileged_latent(
                actor_obs.shape[0], actor_obs.device, actor_obs.dtype
            )

        trunk_latent = self.trunk(torch.cat([policy_obs, privileged_latent], dim=-1))
        return HoraCoreOutput(
            policy_obs=policy_obs,
            trunk_latent=trunk_latent,
            privileged_latent=privileged_latent,
            privileged_target=privileged_target,
        )

    def trunk_latent_from_tensors(
        self,
        actor_obs: torch.Tensor,
        priv_info: torch.Tensor | None,
        *,
        prefer_student: bool,
        proprio_hist: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Tensor-only HORA trunk path used by APPO compiled minibatch loss.

        Latent selection follows ``build_core_output``: student latent when
        preferred and available, otherwise the teacher latent when ``priv_info`` is
        given, otherwise zeros.
        """
        policy_obs = self._normalize_actor_obs(actor_obs)
        if prefer_student and self.adapt_tconv is not None and proprio_hist is not None:
            privileged_latent = self.encode_proprio_history(proprio_hist)
        elif priv_info is not None:
            privileged_latent = self.encode_privileged_info(priv_info)
        else:
            privileged_latent = self._zero_privileged_latent(
                actor_obs.shape[0], actor_obs.device, actor_obs.dtype
            )
        trunk_input = torch.cat([policy_obs, privileged_latent], dim=-1)
        return self.trunk(trunk_input)

    def policy_mean_from_tensors(
        self,
        actor_obs: torch.Tensor,
        priv_info: torch.Tensor | None,
        *,
        prefer_student: bool,
        proprio_hist: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the action mean computed from plain tensors."""
        features = self.trunk_latent_from_tensors(
            actor_obs,
            priv_info,
            prefer_student=prefer_student,
            proprio_hist=proprio_hist,
        )
        return self.mu_head(features)

    def value_from_tensors(
        self,
        actor_obs: torch.Tensor,
        priv_info: torch.Tensor | None,
        *,
        prefer_student: bool,
        proprio_hist: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the value estimate computed from plain tensors."""
        features = self.trunk_latent_from_tensors(
            actor_obs,
            priv_info,
            prefer_student=prefer_student,
            proprio_hist=proprio_hist,
        )
        return self.value_head(features)

    def policy_mean(
        self, obs: TensorDict, *, prefer_student: bool
    ) -> tuple[torch.Tensor, HoraCoreOutput]:
        """Return the action mean together with the core output it was derived from."""
        core = self.build_core_output(obs, prefer_student=prefer_student)
        return self.mu_head(core.trunk_latent), core

    def value(
        self, obs: TensorDict, *, prefer_student: bool
    ) -> tuple[torch.Tensor, HoraCoreOutput]:
        """Return the value estimate together with the core output it was derived from."""
        core = self.build_core_output(obs, prefer_student=prefer_student)
        return self.value_head(core.trunk_latent), core


def _pop_shared_option(
    primary_cfg: dict[str, Any], fallback_cfg: dict[str, Any], key: str, default: Any
) -> Any:
    """Remove ``key`` from both configs and return the value that wins.

    The primary value wins unless it is missing or ``None``. Otherwise the fallback
    value is used, and ``default`` only when the fallback lacks the key as well.
    """
    primary_value = primary_cfg.pop(key, None)
    fallback_value = fallback_cfg.pop(key, default)
    return fallback_value if primary_value is None else primary_value


def build_hora_shared_actor_critic(
    *,
    obs_dim: int,
    action_dim: int,
    priv_info_dim: int,
    actor_cfg: dict[str, Any] | None = None,
    critic_cfg: dict[str, Any] | None = None,
) -> HoraSharedActorCritic:
    """Build the shared HORA core from actor/critic config without mutating inputs.

    Options shared by both networks (``activation``, ``obs_normalization``,
    ``priv_info_embed_dim``, ``priv_mlp_hidden_dims``) are taken from ``actor_cfg``
    first and from ``critic_cfg`` only when the actor leaves them unset. All other
    options (``hidden_dims``, ``distribution_cfg``, ``use_student_encoder``,
    ``proprio_hist_len``, ``proprio_frame_dim``) are read from ``actor_cfg`` only.
    ``class_name`` and ``num_actions`` entries are ignored.

    When ``priv_info_embed_dim`` is not configured it defaults to the output width
    of the privileged encoder: the last ``priv_mlp_hidden_dims`` entry, or
    ``priv_info_dim`` when the encoder has no layers. The trunk input is sized from
    this value and the teacher latent is concatenated to the observation, so any
    other default makes the teacher forward pass fail on a shape mismatch.

    Args:
        obs_dim: Width of the actor observation.
        action_dim: Width of the action mean.
        priv_info_dim: Width of the privileged-information input.
        actor_cfg: Actor configuration; deep-copied before use.
        critic_cfg: Critic configuration; deep-copied before use.

    Returns:
        A newly constructed ``HoraSharedActorCritic``.
    """
    shared_cfg = copy.deepcopy(actor_cfg or {})
    critic_shared_cfg = copy.deepcopy(critic_cfg or {})
    for cfg in (shared_cfg, critic_shared_cfg):
        cfg.pop("class_name", None)
        cfg.pop("num_actions", None)
    critic_shared_cfg.pop("distribution_cfg", None)

    activation = _pop_shared_option(shared_cfg, critic_shared_cfg, "activation", "elu")
    obs_normalization = _pop_shared_option(
        shared_cfg, critic_shared_cfg, "obs_normalization", False
    )
    priv_mlp_hidden_dims = _pop_shared_option(
        shared_cfg, critic_shared_cfg, "priv_mlp_hidden_dims", (256, 128, 8)
    )
    # The default embedding width must match what the privileged encoder emits.
    encoder_output_dim = priv_mlp_hidden_dims[-1] if priv_mlp_hidden_dims else priv_info_dim
    priv_info_embed_dim = _pop_shared_option(
        shared_cfg, critic_shared_cfg, "priv_info_embed_dim", encoder_output_dim
    )

    return HoraSharedActorCritic(
        obs_dim=obs_dim,
        action_dim=action_dim,
        priv_info_dim=priv_info_dim,
        actor_hidden_dims=shared_cfg.pop("hidden_dims", (512, 256, 128)),
        activation=str(activation),
        obs_normalization=bool(obs_normalization),
        distribution_cfg=shared_cfg.pop("distribution_cfg", None),
        priv_info_embed_dim=int(priv_info_embed_dim),
        priv_mlp_hidden_dims=priv_mlp_hidden_dims,
        use_student_encoder=bool(shared_cfg.pop("use_student_encoder", False)),
        proprio_hist_len=int(shared_cfg.pop("proprio_hist_len", 30)),
        proprio_frame_dim=shared_cfg.pop("proprio_frame_dim", None),
    )


class _HoraInferenceModule(nn.Module):
    """Export wrapper mapping ``(actor, priv_info, proprio_hist)`` to action means.

    The latent fed to the trunk comes from ``adapt_tconv`` when ``prefer_student``
    is set and from ``priv_encoder`` otherwise. ``forward`` takes all three inputs
    in either case.
    """

    input_names = ["actor", "priv_info", "proprio_hist"]
    output_names = ["actions"]

    def __init__(
        self,
        *,
        obs_normalizer: nn.Module,
        priv_encoder: nn.Module,
        trunk: nn.Module,
        mu_head: nn.Module,
        obs_dim: int,
        priv_info_dim: int,
        proprio_hist_len: int,
        proprio_frame_dim: int,
        verbose: bool = False,
        adapt_tconv: nn.Module | None = None,
        prefer_student: bool = False,
    ) -> None:
        """Store the submodules and the sizes needed to build dummy inputs.

        Args:
            obs_normalizer: Module applied to the actor observation.
            priv_encoder: Teacher encoder for ``priv_info``.
            trunk: Shared trunk applied to ``[observation, latent]``.
            mu_head: Head producing the action mean from the trunk output.
            obs_dim: Width of the dummy actor observation.
            priv_info_dim: Width of the dummy privileged-information input.
            proprio_hist_len: Number of frames in the dummy history.
            proprio_frame_dim: Width of one dummy history frame.
            verbose: Stored as ``self.verbose``; not read by this class.
            adapt_tconv: Student encoder, required when ``prefer_student`` is true.
            prefer_student: Use the student encoder instead of the teacher encoder.
        """
        super().__init__()
        # Submodules are registered in this order; state_dict key order follows it.
        self.obs_normalizer = obs_normalizer
        self.priv_encoder = priv_encoder
        self.trunk = trunk
        self.mu_head = mu_head
        self.adapt_tconv = adapt_tconv
        self.prefer_student = bool(prefer_student)
        self.obs_dim = int(obs_dim)
        self.priv_info_dim = int(priv_info_dim)
        self.proprio_hist_len = int(proprio_hist_len)
        self.proprio_frame_dim = int(proprio_frame_dim)
        self.verbose = bool(verbose)

    def forward(
        self, actor: torch.Tensor, priv_info: torch.Tensor, proprio_hist: torch.Tensor
    ) -> torch.Tensor:
        """Return the action mean for the given inputs.

        Raises:
            RuntimeError: If ``prefer_student`` is set but no ``adapt_tconv`` was given.
        """
        normalized_obs = self.obs_normalizer(actor)
        if self.prefer_student:
            if self.adapt_tconv is None:
                raise RuntimeError("HORA adaptation encoder export requires adapt_tconv")
            latent = torch.tanh(self.adapt_tconv(proprio_hist))
        else:
            latent = torch.tanh(self.priv_encoder(priv_info))
        features = self.trunk(torch.cat([normalized_obs, latent], dim=-1))
        return self.mu_head(features)

    def get_dummy_inputs(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return batch-size-1 zero tensors matching ``forward``'s three inputs."""
        actor = torch.zeros(1, self.obs_dim)
        priv_info = torch.zeros(1, self.priv_info_dim)
        proprio_hist = torch.zeros(1, self.proprio_hist_len, self.proprio_frame_dim)
        return actor, priv_info, proprio_hist


class HoraActorModel(nn.Module):
    """Actor-side wrapper around a :class:`HoraSharedActorCritic` core.

    The wrapper owns no parameters of its own beyond the shared core: it computes
    the action mean through the core, updates the core's action distribution, and
    exposes that distribution's statistics. It is not recurrent, so the
    hidden-state methods are no-ops.
    """

    is_recurrent: bool = False

    def __init__(
        self,
        obs: TensorDict,
        obs_groups: dict[str, list[str]],
        obs_set: str,
        output_dim: int,
        *,
        shared_model: HoraSharedActorCritic | None = None,
        hidden_dims: list[int] | tuple[int, ...] = (512, 256, 128),
        activation: str = "elu",
        obs_normalization: bool = False,
        distribution_cfg: dict[str, Any] | None = None,
        priv_info_dim: int | None = None,
        priv_info_embed_dim: int = 8,
        priv_mlp_hidden_dims: list[int] | tuple[int, ...] = (256, 128, 8),
        use_student_encoder: bool = False,
        proprio_hist_len: int = 30,
        proprio_frame_dim: int | None = None,
    ) -> None:
        """Wrap ``shared_model`` or build a new shared core from ``obs``.

        Args:
            obs: Example observations; ``obs["actor"]`` fixes the observation width
                when a core has to be built.
            obs_groups: Unused.
            obs_set: Unused.
            output_dim: Action width of a newly built core.
            shared_model: Existing core to wrap. When given, the network-shape
                arguments below are not used.
            hidden_dims: Trunk hidden widths.
            activation: Activation name.
            obs_normalization: Enable observation normalization.
            distribution_cfg: Action distribution configuration.
            priv_info_dim: Privileged-information width; inferred from
                ``obs["priv_info"]`` when ``None``.
            priv_info_embed_dim: Width of the privileged latent.
            priv_mlp_hidden_dims: Teacher encoder hidden widths.
            use_student_encoder: Build the student encoder and prefer it in
                ``forward`` and in exported modules.
            proprio_hist_len: History length recorded on the core.
            proprio_frame_dim: History frame width; inferred from
                ``obs["proprio_hist"]`` when ``None`` and that entry exists.

        Raises:
            ValueError: If a core must be built, ``priv_info_dim`` is ``None`` and
                ``obs`` has no ``"priv_info"`` entry to infer it from.
        """
        del obs_groups, obs_set
        super().__init__()
        if shared_model is None:
            obs_dim = int(obs["actor"].shape[-1])
            if priv_info_dim is None:
                priv_info = obs.get("priv_info")
                if priv_info is None:
                    raise ValueError(
                        "priv_info_dim was not given and obs has no 'priv_info' entry "
                        "to infer it from"
                    )
                priv_info_dim = priv_info.shape[-1]
            if proprio_frame_dim is None and "proprio_hist" in obs:
                proprio_frame_dim = int(obs["proprio_hist"].shape[-1])
            shared_model = HoraSharedActorCritic(
                obs_dim=obs_dim,
                action_dim=output_dim,
                priv_info_dim=int(priv_info_dim),
                priv_info_embed_dim=priv_info_embed_dim,
                actor_hidden_dims=hidden_dims,
                priv_mlp_hidden_dims=priv_mlp_hidden_dims,
                activation=activation,
                obs_normalization=obs_normalization,
                distribution_cfg=distribution_cfg,
                use_student_encoder=use_student_encoder,
                proprio_hist_len=proprio_hist_len,
                proprio_frame_dim=proprio_frame_dim,
            )
        self.shared = shared_model
        self.prefer_student = bool(use_student_encoder)

    def forward(
        self,
        obs: TensorDict,
        masks: torch.Tensor | None = None,
        hidden_state=None,
        stochastic_output: bool = False,
    ) -> torch.Tensor:
        """Compute actions and refresh the shared action distribution.

        Args:
            obs: Observations passed to the shared core.
            masks: Unused.
            hidden_state: Unused.
            stochastic_output: Return a sample from the distribution instead of its
                deterministic output.
        """
        del masks, hidden_state
        mean, _ = self.shared.policy_mean(obs, prefer_student=self.prefer_student)
        distribution = self.shared.distribution
        distribution.update(mean)
        if stochastic_output:
            return distribution.sample()
        return distribution.deterministic_output(mean)

    def reset(self, dones: torch.Tensor | None = None, hidden_state=None) -> None:
        """No-op: the model keeps no recurrent state."""
        del dones, hidden_state

    def get_hidden_state(self):
        """Return ``None``: the model keeps no recurrent state."""
        return None

    def detach_hidden_state(self, dones: torch.Tensor | None = None) -> None:
        """No-op: the model keeps no recurrent state."""
        del dones

    @property
    def distribution(self) -> GaussianDistribution:
        """The action distribution owned by the shared core."""
        return self.shared.distribution

    @property
    def output_mean(self) -> torch.Tensor:
        """``mean`` of the shared action distribution."""
        return self.shared.distribution.mean

    @property
    def output_std(self) -> torch.Tensor:
        """``std`` of the shared action distribution."""
        return self.shared.distribution.std

    @property
    def output_entropy(self) -> torch.Tensor:
        """``entropy`` of the shared action distribution."""
        return self.shared.distribution.entropy

    @property
    def output_distribution_params(self) -> tuple[torch.Tensor, ...]:
        """``params`` of the shared action distribution."""
        return cast(tuple[torch.Tensor, ...], self.shared.distribution.params)

    def get_output_log_prob(self, outputs: torch.Tensor) -> torch.Tensor:
        """Return the distribution's ``log_prob`` of ``outputs``."""
        return self.shared.distribution.log_prob(outputs)

    def get_kl_divergence(
        self, old_params: tuple[torch.Tensor, ...], new_params: tuple[torch.Tensor, ...]
    ) -> torch.Tensor:
        """Return the distribution's ``kl_divergence`` between two parameter sets."""
        return self.shared.distribution.kl_divergence(old_params, new_params)

    def update_normalization(self, obs: TensorDict) -> None:
        """Forward ``obs`` to the shared core's normalizer update."""
        self.shared.update_normalization(obs)

    def _build_inference_module(self, *, verbose: bool) -> nn.Module:
        """Create an export module from deep copies of the shared submodules."""
        core = self.shared
        return _HoraInferenceModule(
            obs_normalizer=copy.deepcopy(core.obs_normalizer),
            priv_encoder=copy.deepcopy(core.priv_encoder),
            trunk=copy.deepcopy(core.trunk),
            mu_head=copy.deepcopy(core.mu_head),
            obs_dim=core.obs_dim,
            priv_info_dim=core.priv_info_dim,
            proprio_hist_len=core.proprio_hist_len,
            proprio_frame_dim=core.proprio_frame_dim,
            verbose=verbose,
            adapt_tconv=copy.deepcopy(core.adapt_tconv),
            prefer_student=self.prefer_student,
        )

    def as_jit(self) -> nn.Module:
        """Return a standalone inference module (``verbose`` left at ``False``)."""
        return self._build_inference_module(verbose=False)

    def as_onnx(self, verbose: bool) -> nn.Module:
        """Return a standalone inference module carrying the given ``verbose`` flag."""
        return self._build_inference_module(verbose=verbose)


class HoraCriticModel(nn.Module):
    """Critic-side wrapper around a :class:`HoraSharedActorCritic` core.

    Values always come from the core's value head with ``prefer_student=False``.
    The model is not recurrent, so the hidden-state methods are no-ops.
    """

    is_recurrent: bool = False

    def __init__(
        self,
        obs: TensorDict,
        obs_groups: dict[str, list[str]],
        obs_set: str,
        output_dim: int,
        *,
        shared_model: HoraSharedActorCritic | None = None,
        hidden_dims: list[int] | tuple[int, ...] = (512, 256, 128),
        activation: str = "elu",
        obs_normalization: bool = False,
        priv_info_dim: int | None = None,
        priv_info_embed_dim: int = 8,
        priv_mlp_hidden_dims: list[int] | tuple[int, ...] = (256, 128, 8),
    ) -> None:
        """Wrap ``shared_model`` or build a new shared core from ``obs``.

        Args:
            obs: Example observations; ``obs["actor"]`` fixes the observation width
                when a core has to be built.
            obs_groups: Unused.
            obs_set: Unused.
            output_dim: Unused; the core's value head always has one output.
            shared_model: Existing core to wrap. When given, the remaining
                arguments are not used.
            hidden_dims: Trunk hidden widths.
            activation: Activation name.
            obs_normalization: Enable observation normalization.
            priv_info_dim: Privileged-information width; inferred from
                ``obs["priv_info"]`` when ``None``.
            priv_info_embed_dim: Width of the privileged latent.
            priv_mlp_hidden_dims: Teacher encoder hidden widths.

        Raises:
            ValueError: If a core must be built, ``priv_info_dim`` is ``None`` and
                ``obs`` has no ``"priv_info"`` entry to infer it from.
        """
        del obs_groups, obs_set, output_dim
        super().__init__()
        if shared_model is None:
            obs_dim = int(obs["actor"].shape[-1])
            if priv_info_dim is None:
                priv_info = obs.get("priv_info")
                if priv_info is None:
                    raise ValueError(
                        "priv_info_dim was not given and obs has no 'priv_info' entry "
                        "to infer it from"
                    )
                priv_info_dim = priv_info.shape[-1]
            shared_model = HoraSharedActorCritic(
                obs_dim=obs_dim,
                # A standalone critic still builds an action head; it gets width 1.
                action_dim=1,
                priv_info_dim=int(priv_info_dim),
                priv_info_embed_dim=priv_info_embed_dim,
                actor_hidden_dims=hidden_dims,
                priv_mlp_hidden_dims=priv_mlp_hidden_dims,
                activation=activation,
                obs_normalization=obs_normalization,
            )
        self.shared = shared_model

    def forward(
        self,
        obs: TensorDict,
        masks: torch.Tensor | None = None,
        hidden_state=None,
        stochastic_output: bool = False,
    ) -> torch.Tensor:
        """Return the value estimate for ``obs``.

        ``masks``, ``hidden_state`` and ``stochastic_output`` are accepted and
        ignored.
        """
        del masks, hidden_state, stochastic_output
        value, _ = self.shared.value(obs, prefer_student=False)
        return value

    def reset(self, dones: torch.Tensor | None = None, hidden_state=None) -> None:
        """No-op: the model keeps no recurrent state."""
        del dones, hidden_state

    def get_hidden_state(self):
        """Return ``None``: the model keeps no recurrent state."""
        return None

    def detach_hidden_state(self, dones: torch.Tensor | None = None) -> None:
        """No-op: the model keeps no recurrent state."""
        del dones

    def update_normalization(self, obs: TensorDict) -> None:
        """No-op: the critic wrapper does not touch the shared normalizer.

        NOTE(review): presumably this avoids updating the shared statistics twice
        when an actor wrapper over the same core already forwards the update;
        confirm against the training loop.
        """
        del obs