Source code for unilab.algos.torch.hora.sac_models

"""HORA SAC actor models."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Tuple, cast

import torch
import torch.nn as nn


def _build_mlp(
    input_dim: int,
    hidden_dims: Sequence[int],
    *,
    activation: type[nn.Module],
    use_layer_norm: bool,
    device: str | torch.device,
) -> tuple[nn.Sequential, int]:
    """Assemble a stack of ``Linear -> [LayerNorm] -> activation`` stages.

    Args:
        input_dim: Width of the features fed to the first stage.
        hidden_dims: Output width of each stage, in order. An empty sequence
            produces an empty ``nn.Sequential`` whose output width is
            ``input_dim``.
        activation: Module class, instantiated with no arguments once per stage.
        use_layer_norm: When true, an ``nn.LayerNorm`` follows every linear layer.
        device: Device on which the linear and normalization layers are created.

    Returns:
        The assembled ``nn.Sequential`` and the width of its final output.
    """
    widths = [int(input_dim), *(int(width) for width in hidden_dims)]
    stages: list[nn.Module] = []
    # Consecutive (in, out) width pairs define each stage's linear layer.
    for in_width, out_width in zip(widths[:-1], widths[1:]):
        stages.append(nn.Linear(in_width, out_width, device=device))
        if use_layer_norm:
            stages.append(nn.LayerNorm(out_width, device=device))
        stages.append(activation())
    return nn.Sequential(*stages), widths[-1]


class HoraSACActor(nn.Module):
    """Privileged HORA teacher actor for SAC.

    The named modules below are intentionally stable for later distillation:
    ``priv_encoder``, ``actor_trunk``, ``action_mean_head``, and
    ``action_logstd_head``.

    Privileged information is encoded into a ``tanh``-bounded latent of width
    ``priv_info_embed_dim``, concatenated with the observation, and passed
    through a three-stage SiLU trunk to a diagonal Gaussian (mean, log-std).
    When ``use_tanh`` is enabled, actions are squashed with ``tanh`` and then
    mapped affinely by the ``action_scale`` / ``action_bias`` buffers.
    """

    action_scale: torch.Tensor
    action_bias: torch.Tensor

    def __init__(
        self,
        obs_dim: int,
        priv_info_dim: int,
        action_dim: int,
        *,
        hidden_dim: int = 512,
        priv_info_embed_dim: int = 9,
        priv_mlp_hidden_dims: Sequence[int] = (256, 128, 9),
        log_std_max: float = 0.0,
        log_std_min: float = -5.0,
        use_tanh: bool = True,
        use_layer_norm: bool = True,
        device: str | torch.device = "cpu",
        action_scale: torch.Tensor | None = None,
        action_bias: torch.Tensor | None = None,
    ) -> None:
        """Build the encoder, trunk, and Gaussian heads.

        Args:
            obs_dim: Width of the (non-privileged) observation.
            priv_info_dim: Width of the privileged-information input.
            action_dim: Number of action dimensions.
            hidden_dim: Width of the first trunk stage; the following stages
                use ``hidden_dim // 2`` and ``hidden_dim // 4``.
            priv_info_embed_dim: Width of the privileged latent. A linear
                projection is inserted when the encoder's final width differs.
            priv_mlp_hidden_dims: Stage widths of the privileged encoder MLP.
            log_std_max: Upper bound of the predicted log standard deviation.
            log_std_min: Lower bound of the predicted log standard deviation.
            use_tanh: Squash actions with ``tanh`` and apply scale/bias.
            use_layer_norm: Insert ``LayerNorm`` after each hidden linear layer.
            device: Device on which parameters and buffers are created.
            action_scale: Multiplier applied to squashed actions (defaults to
                ones of length ``action_dim``). It is copied and cast to the
                default floating dtype.
            action_bias: Offset added to squashed actions (defaults to zeros
                of length ``action_dim``). It is copied and cast like
                ``action_scale``.
        """
        super().__init__()
        self.obs_dim = int(obs_dim)
        self.priv_info_dim = int(priv_info_dim)
        self.action_dim = int(action_dim)
        self.priv_info_embed_dim = int(priv_info_embed_dim)
        self.log_std_max = float(log_std_max)
        self.log_std_min = float(log_std_min)
        self.use_tanh = bool(use_tanh)

        self.priv_encoder, priv_out_dim = _build_mlp(
            self.priv_info_dim,
            tuple(priv_mlp_hidden_dims),
            activation=nn.SiLU,
            use_layer_norm=use_layer_norm,
            device=device,
        )
        # Only project when the encoder does not already emit the embed width.
        if priv_out_dim != self.priv_info_embed_dim:
            self.priv_projection = nn.Linear(
                priv_out_dim,
                self.priv_info_embed_dim,
                device=device,
            )
        else:
            self.priv_projection = nn.Identity()

        trunk_dims = (hidden_dim, hidden_dim // 2, hidden_dim // 4)
        self.actor_trunk, trunk_out_dim = _build_mlp(
            self.obs_dim + self.priv_info_embed_dim,
            trunk_dims,
            activation=nn.SiLU,
            use_layer_norm=use_layer_norm,
            device=device,
        )
        self.action_mean_head = nn.Linear(trunk_out_dim, self.action_dim, device=device)
        self.action_logstd_head = nn.Linear(trunk_out_dim, self.action_dim, device=device)
        # Zero-initialised heads: the initial policy has mean 0 and a log-std
        # at the midpoint of [log_std_min, log_std_max] for every input.
        for head in (self.action_mean_head, self.action_logstd_head):
            nn.init.constant_(head.weight, 0.0)
            nn.init.constant_(head.bias, 0.0)

        self.register_buffer(
            "action_scale",
            self._make_action_buffer(action_scale, 1.0, self.action_dim, device),
        )
        self.register_buffer(
            "action_bias",
            self._make_action_buffer(action_bias, 0.0, self.action_dim, device),
        )

    @staticmethod
    def _make_action_buffer(
        value: torch.Tensor | None,
        fill: float,
        size: int,
        device: str | torch.device,
    ) -> torch.Tensor:
        """Return a private copy of ``value``, or a length-``size`` tensor of ``fill``.

        The result uses ``torch.get_default_dtype()``, which is the dtype newly
        created ``nn.Linear`` parameters use.
        """
        dtype = torch.get_default_dtype()
        if value is None:
            return torch.full((size,), fill, dtype=dtype, device=device)
        # ``.to(device)`` may hand back the caller's own tensor, so copy.
        # Casting avoids silently promoting actions/log-probs to e.g. float64.
        return torch.as_tensor(value, device=device).to(dtype=dtype).detach().clone()

    def encode_privileged_info(self, priv_info: torch.Tensor) -> torch.Tensor:
        """Encode privileged information into a latent bounded to [-1, 1] by ``tanh``."""
        encoded = self.priv_encoder(priv_info)
        return torch.tanh(self.priv_projection(encoded))

    def _distribution_params(
        self,
        obs: torch.Tensor,
        priv_info: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the Gaussian ``(mean, log_std)`` for ``obs`` and ``priv_info``.

        ``mean`` is clamped to [-10, 10]. ``log_std`` lies in
        [log_std_min, log_std_max]. NaNs are replaced by 0 and
        ``log_std_min`` respectively.
        """
        z = self.encode_privileged_info(priv_info)
        latent = self.actor_trunk(torch.cat([obs, z], dim=-1))
        mean = self.action_mean_head(latent)
        # Bound the raw head output to [-1, 1] with tanh, then map it affinely
        # onto [log_std_min, log_std_max].
        log_std = torch.tanh(self.action_logstd_head(latent))
        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1)
        # clamp() propagates NaN, so sanitise after clamping.
        mean = torch.nan_to_num(torch.clamp(mean, -10.0, 10.0), nan=0.0)
        log_std = torch.nan_to_num(log_std, nan=self.log_std_min)
        return mean, log_std

    def forward(
        self,
        obs: torch.Tensor,
        priv_info: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(action, mean, log_std)``.

        ``action`` is the deterministic action:
        ``tanh(mean) * action_scale + action_bias`` when ``use_tanh`` is set,
        otherwise ``mean`` itself.
        """
        mean, log_std = self._distribution_params(obs, priv_info)
        if self.use_tanh:
            action = torch.tanh(mean) * self.action_scale + self.action_bias
        else:
            action = mean
        return action, mean, log_std

    def as_export_module(self) -> nn.Module:
        """Return a wrapper with explicit actor/priv inputs for ONNX export.

        The wrapper holds this actor as its ``base`` submodule and its
        ``forward(obs, priv_info)`` returns only the deterministic action.
        """
        actor = self

        class _Wrapper(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.base = actor

            def forward(self, obs: torch.Tensor, priv_info: torch.Tensor) -> torch.Tensor:
                action, _, _ = self.base(obs, priv_info)
                return cast(torch.Tensor, action)

        return _Wrapper()

    def get_actions_and_log_probs(
        self,
        obs: torch.Tensor,
        priv_info: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample reparameterised actions and return ``(action, log_prob, log_std)``.

        ``log_prob`` is summed over the last (action) dimension. When
        ``use_tanh`` is set, it includes the change-of-variables correction
        for ``a = tanh(u) * action_scale + action_bias``.
        """
        _, mean, log_std = self(obs, priv_info)
        dist = torch.distributions.Normal(mean, log_std.exp())
        # rsample keeps the sample differentiable w.r.t. mean and log_std.
        raw_action = dist.rsample()
        log_prob = dist.log_prob(raw_action)
        if self.use_tanh:
            tanh_action = torch.tanh(raw_action)
            action = tanh_action * self.action_scale + self.action_bias
            # log p(a) = log p(u) - log(1 - tanh(u)^2) - log(scale); the 1e-6
            # terms guard against log(0) at saturation / zero scale.
            log_prob = log_prob - torch.log(1 - tanh_action.pow(2) + 1e-6)
            log_prob = log_prob - torch.log(self.action_scale + 1e-6)
        else:
            action = raw_action
        # Reduce over the action axis (last dim) so unbatched and
        # multi-leading-dim inputs are handled, not just (batch, action).
        return action, log_prob.sum(-1), log_std

    @torch.no_grad()
    def explore(
        self,
        obs: torch.Tensor,
        priv_info: torch.Tensor,
        deterministic: bool = False,
    ) -> torch.Tensor:
        """Return an action for environment interaction without tracking gradients.

        When ``deterministic`` is true, returns the same action as ``forward``.
        Otherwise, samples from the Gaussian and applies the same
        squash/scale/bias as ``forward``.
        """
        action, mean, log_std = self(obs, priv_info)
        if deterministic:
            return action
        raw_action = torch.distributions.Normal(mean, log_std.exp()).rsample()
        if self.use_tanh:
            return torch.tanh(raw_action) * self.action_scale + self.action_bias
        return raw_action