# Source code for unilab.algos.torch.him_ppo.actor_critic

# SPDX-License-Identifier: BSD-3-Clause
#
# Adapted from the HIMLoco RSL-RL HIM actor-critic for UniLab.

from __future__ import annotations

import torch
import torch.nn as nn
from torch.distributions import Normal

from unilab.algos.torch.him_ppo.estimator import HIMEstimator, get_activation


class HIMActorCritic(nn.Module):
    """Actor-critic networks whose actor is fed by a :class:`HIMEstimator`.

    The actor MLP consumes the leading ``num_one_step_obs`` columns of the
    observation history concatenated with the two tensors returned by the
    estimator. The critic is an independent MLP over the critic observations.
    Actions are drawn from a per-dimension ``Normal`` whose standard deviation
    is the learned, state-independent parameter ``self.std``.
    """

    # No recurrent hidden state is kept by this module; ``reset`` is a no-op.
    is_recurrent = False

    def __init__(
        self,
        num_actor_obs: int,
        num_critic_obs: int,
        num_one_step_obs: int,
        num_actions: int,
        actor_hidden_dims: list[int] | tuple[int, ...] = (512, 256, 128),
        critic_hidden_dims: list[int] | tuple[int, ...] = (512, 256, 128),
        activation: str = "elu",
        init_noise_std: float = 1.0,
        estimator: dict | None = None,
    ) -> None:
        """Build the estimator, actor MLP, critic MLP and action-noise parameter.

        Args:
            num_actor_obs: Width of the flattened observation history. Must be
                an integer multiple of ``num_one_step_obs``.
            num_critic_obs: Input width of the critic MLP.
            num_one_step_obs: Width of a single-step observation slice.
            num_actions: Output width of the actor and length of ``self.std``.
            actor_hidden_dims: Hidden layer widths of the actor (non-empty).
            critic_hidden_dims: Hidden layer widths of the critic (non-empty).
            activation: Activation name, forwarded to ``HIMEstimator`` and to
                ``_build_mlp``.
            init_noise_std: Initial value of every entry of ``self.std``.
            estimator: Extra keyword arguments forwarded to ``HIMEstimator``;
                ``None`` forwards none.

        Raises:
            ValueError: If ``num_one_step_obs`` is not positive, if
                ``num_actor_obs`` is not a multiple of it, or if either
                hidden-dims sequence is empty.
        """
        super().__init__()

        # Validate before any sub-module is constructed.
        if num_one_step_obs <= 0:
            raise ValueError("num_one_step_obs must be positive")
        if num_actor_obs % num_one_step_obs != 0:
            raise ValueError(
                "num_actor_obs must be an integer multiple of num_one_step_obs "
                f"for HIM history obs, got {num_actor_obs} and {num_one_step_obs}"
            )
        if len(actor_hidden_dims) == 0 or len(critic_hidden_dims) == 0:
            raise ValueError("actor_hidden_dims and critic_hidden_dims must not be empty")

        # Number of single-step slices contained in one history vector.
        self.history_size = int(num_actor_obs // num_one_step_obs)
        self.num_actor_obs = int(num_actor_obs)
        self.num_critic_obs = int(num_critic_obs)
        self.num_actions = int(num_actions)
        self.num_one_step_obs = int(num_one_step_obs)

        # Copy so the caller's dict is never mutated; None becomes {}.
        estimator_cfg = dict(estimator or {})
        self.estimator = HIMEstimator(
            temporal_steps=self.history_size,
            num_one_step_obs=self.num_one_step_obs,
            activation=activation,
            **estimator_cfg,
        )

        # Actor input = one-step slice + first estimator output + latent.
        # NOTE(review): the literal 3 is presumably the width of the
        # estimator's velocity output -- confirm against HIMEstimator.
        actor_input_dim = self.num_one_step_obs + 3 + self.estimator.num_latent
        self.actor = _build_mlp(actor_input_dim, self.num_actions, actor_hidden_dims, activation)
        self.critic = _build_mlp(self.num_critic_obs, 1, critic_hidden_dims, activation)

        # Learned, state-independent standard deviation, one entry per action.
        self.std = nn.Parameter(float(init_noise_std) * torch.ones(self.num_actions))
        # Populated by update_distribution(); None until then.
        self.distribution: Normal | None = None
        # Process-wide switch: turns off argument validation for
        # torch.distributions objects created from now on.
        Normal.set_default_validate_args(False)

    @property
    def action_mean(self) -> torch.Tensor:
        """Mean of the most recently built action distribution."""
        assert self.distribution is not None, "call update_distribution() first"
        return self.distribution.mean

    @property
    def action_std(self) -> torch.Tensor:
        """Standard deviation of the most recently built action distribution."""
        assert self.distribution is not None, "call update_distribution() first"
        return self.distribution.stddev

    @property
    def entropy(self) -> torch.Tensor:
        """Entropy of the current distribution, summed over the last dim."""
        assert self.distribution is not None, "call update_distribution() first"
        return self.distribution.entropy().sum(dim=-1)

    def reset(self, dones: torch.Tensor | None = None) -> None:
        """Do nothing; ``dones`` is accepted and ignored."""
        del dones

    def forward(self) -> torch.Tensor:
        """Unsupported; use ``act``, ``act_inference`` or ``evaluate``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def _actor_input(
        self,
        obs_history: torch.Tensor,
        vel: torch.Tensor,
        latent: torch.Tensor,
    ) -> torch.Tensor:
        """Concatenate the one-step slice with the estimator outputs."""
        # The first num_one_step_obs columns of the history are used as the
        # one-step observation (presumably the latest step -- confirm the
        # history ordering with the environment).
        one_step_obs = obs_history[:, : self.num_one_step_obs]
        return torch.cat((one_step_obs, vel, latent), dim=-1)

    def update_distribution(self, obs_history: torch.Tensor) -> None:
        """Rebuild ``self.distribution`` from a batch of observation histories.

        Args:
            obs_history: 2-D tensor; columns are indexed along dim 1.
        """
        # Only the estimator runs without autograd here, so gradients of the
        # action distribution reach the actor and ``self.std`` but not the
        # estimator.
        # NOTE(review): no_grad scope reconstructed from flattened source --
        # confirm it covers only the estimator call.
        with torch.no_grad():
            vel, latent = self.estimator(obs_history)
        mean = self.actor(self._actor_input(obs_history, vel, latent))
        # ``mean * 0.0 + self.std`` broadcasts std to the shape of ``mean``.
        self.distribution = Normal(mean, mean * 0.0 + self.std)

    def act(self, obs_history: torch.Tensor, **kwargs) -> torch.Tensor:
        """Refresh the distribution and return one sampled action per row.

        Extra keyword arguments are accepted and ignored.
        """
        del kwargs
        self.update_distribution(obs_history)
        assert self.distribution is not None
        return self.distribution.sample()

    def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """Log-probability of ``actions`` under the current distribution.

        Per-dimension log-probabilities are summed over the last dim.
        """
        assert self.distribution is not None, "call update_distribution() first"
        return self.distribution.log_prob(actions).sum(dim=-1)

    def act_inference(self, obs_history: torch.Tensor, observations=None) -> torch.Tensor:
        """Return the actor output (no sampling) for ``obs_history``.

        ``observations`` is accepted and ignored. Unlike
        ``update_distribution``, the estimator call is not wrapped in
        ``torch.no_grad()`` here.
        """
        del observations
        vel, latent = self.estimator(obs_history)
        return self.actor(self._actor_input(obs_history, vel, latent))

    def evaluate(self, critic_observations: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the critic output for ``critic_observations``.

        Extra keyword arguments are accepted and ignored.
        """
        del kwargs
        return self.critic(critic_observations)
def _build_mlp(
    input_dim: int,
    output_dim: int,
    hidden_dims: list[int] | tuple[int, ...],
    activation: str,
) -> nn.Sequential:
    """Build a fully connected network that ends in a linear output layer.

    Every entry of ``hidden_dims`` contributes an ``nn.Linear`` followed by
    the module returned from ``get_activation(activation)`` (called once per
    hidden layer). A final ``nn.Linear`` maps to ``output_dim`` with no
    activation after it. With empty ``hidden_dims`` the result is a single
    ``nn.Linear(input_dim, output_dim)``.

    Args:
        input_dim: Input feature width (coerced with ``int``).
        output_dim: Output feature width (coerced with ``int``).
        hidden_dims: Hidden layer widths, in order from input to output.
        activation: Activation name passed to ``get_activation``.

    Returns:
        The assembled ``nn.Sequential``.
    """
    layers: list[nn.Module] = []
    in_features = int(input_dim)
    for hidden_dim in hidden_dims:
        out_features = int(hidden_dim)
        # Linear first, then its activation: keeps module order (and RNG
        # consumption during weight init) aligned with layer order.
        layers.append(nn.Linear(in_features, out_features))
        layers.append(get_activation(activation))
        in_features = out_features
    layers.append(nn.Linear(in_features, int(output_dim)))
    return nn.Sequential(*layers)