Source code for unilab.algos.torch.him_ppo.storage

# SPDX-License-Identifier: BSD-3-Clause
#
# Adapted from the HIMLoco RSL-RL HIM rollout storage for UniLab.

from __future__ import annotations

from collections.abc import Sequence
from typing import cast

import torch


class HIMRolloutStorage:
    """Fixed-size rollout buffer for on-policy updates.

    Every buffer is pre-allocated with leading dimensions
    ``(num_transitions_per_env, num_envs)``. Transitions are appended one time
    step at a time with :meth:`add_transition`, returns and advantages are
    filled in by :meth:`compute_returns`, and shuffled mini-batches are
    produced by :meth:`mini_batch_generator`. Call :meth:`clear` to rewind the
    write cursor before collecting the next rollout.
    """

    class Transition:
        """Mutable holder for the tensors describing one environment step.

        All fields start as ``None``; the caller is expected to fill them in
        before handing the object to ``HIMRolloutStorage.add_transition``.
        """

        def __init__(self) -> None:
            """Create an empty transition with every field unset."""
            self.observations: torch.Tensor | None = None
            self.critic_observations: torch.Tensor | None = None
            self.next_critic_observations: torch.Tensor | None = None
            self.actions: torch.Tensor | None = None
            self.rewards: torch.Tensor | None = None
            self.dones: torch.Tensor | None = None
            self.values: torch.Tensor | None = None
            self.actions_log_prob: torch.Tensor | None = None
            self.action_mean: torch.Tensor | None = None
            self.action_sigma: torch.Tensor | None = None

        def clear(self) -> None:
            """Reset every field back to ``None``."""
            self.observations = None
            self.critic_observations = None
            self.next_critic_observations = None
            self.actions = None
            self.rewards = None
            self.dones = None
            self.values = None
            self.actions_log_prob = None
            self.action_mean = None
            self.action_sigma = None

    def __init__(
        self,
        num_envs: int,
        num_transitions_per_env: int,
        obs_shape: Sequence[int],
        privileged_obs_shape: Sequence[int | None],
        actions_shape: Sequence[int],
        device: str = "cpu",
    ) -> None:
        """Allocate all rollout buffers.

        Args:
            num_envs: Number of environments stored per time step.
            num_transitions_per_env: Number of time steps the buffer can hold.
            obs_shape: Per-environment observation shape.
            privileged_obs_shape: Per-environment critic observation shape.
                An empty sequence, or one whose first entry is ``None``,
                disables the separate critic observation buffers.
            actions_shape: Per-environment action shape.
            device: Device on which the buffers are allocated.

        Raises:
            ValueError: If ``privileged_obs_shape`` starts with a concrete
                dimension but contains ``None`` somewhere else.
        """
        self.device = device

        self.obs_shape = tuple(obs_shape)
        self.privileged_obs_shape = tuple(privileged_obs_shape)
        self.actions_shape = tuple(actions_shape)

        self.num_transitions_per_env = int(num_transitions_per_env)
        self.num_envs = int(num_envs)
        # Index of the next time-step slot to be written.
        self.step = 0

        self.observations = torch.zeros(
            self.num_transitions_per_env,
            self.num_envs,
            *self.obs_shape,
            device=self.device,
        )

        if self.privileged_obs_shape and self.privileged_obs_shape[0] is not None:
            if any(dim is None for dim in self.privileged_obs_shape):
                raise ValueError("privileged_obs_shape cannot contain None values")
            privileged_obs_shape = cast(tuple[int, ...], self.privileged_obs_shape)
            self.privileged_observations = torch.zeros(
                self.num_transitions_per_env,
                self.num_envs,
                *privileged_obs_shape,
                device=self.device,
            )
            self.next_privileged_observations = torch.zeros_like(
                self.privileged_observations
            )
        else:
            # No separate critic input: mini-batches fall back to the regular
            # observations for both critic slots.
            self.privileged_observations = None
            self.next_privileged_observations = None

        self.rewards = torch.zeros(
            self.num_transitions_per_env, self.num_envs, 1, device=self.device
        )
        self.actions = torch.zeros(
            self.num_transitions_per_env,
            self.num_envs,
            *self.actions_shape,
            device=self.device,
        )
        self.dones = torch.zeros(
            self.num_transitions_per_env, self.num_envs, 1, device=self.device
        ).bool()

        # Per-step scalars share the (steps, envs, 1) layout of ``rewards``.
        self.actions_log_prob = torch.zeros_like(self.rewards)
        self.values = torch.zeros_like(self.rewards)
        self.returns = torch.zeros_like(self.rewards)
        self.advantages = torch.zeros_like(self.rewards)

        # Action distribution statistics share the layout of ``actions``.
        self.mu = torch.zeros_like(self.actions)
        self.sigma = torch.zeros_like(self.actions)

    def add_transition(self, transition: Transition) -> None:
        """Copy one time step of data into the slot at ``self.step``.

        ``rewards``, ``dones`` and ``actions_log_prob`` are reshaped to
        ``(-1, 1)`` before being stored; the remaining tensors are copied
        as given and must therefore be compatible with their buffer slices.

        Args:
            transition: Filled-in transition for the current time step.

        Raises:
            AssertionError: If the buffer already holds
                ``num_transitions_per_env`` steps.
            ValueError: If a required field of ``transition`` is ``None``.
                The critic observation fields are only required when the
                storage was built with privileged observation buffers.
        """
        if self.step >= self.num_transitions_per_env:
            # Kept as AssertionError so existing callers see the same type.
            raise AssertionError("Rollout buffer overflow")
        if transition.observations is None:
            raise ValueError("transition.observations is required")
        if transition.actions is None:
            raise ValueError("transition.actions is required")
        if transition.rewards is None:
            raise ValueError("transition.rewards is required")
        if transition.dones is None:
            raise ValueError("transition.dones is required")
        if transition.values is None:
            raise ValueError("transition.values is required")
        if transition.actions_log_prob is None:
            raise ValueError("transition.actions_log_prob is required")
        if transition.action_mean is None or transition.action_sigma is None:
            raise ValueError("transition action distribution stats are required")

        self.observations[self.step].copy_(transition.observations)
        if self.privileged_observations is not None:
            if transition.critic_observations is None:
                raise ValueError("transition.critic_observations is required")
            if transition.next_critic_observations is None:
                raise ValueError("transition.next_critic_observations is required")
            # Both privileged buffers are always allocated together.
            assert self.next_privileged_observations is not None
            self.privileged_observations[self.step].copy_(
                transition.critic_observations
            )
            self.next_privileged_observations[self.step].copy_(
                transition.next_critic_observations
            )
        self.actions[self.step].copy_(transition.actions)
        self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
        self.dones[self.step].copy_(transition.dones.view(-1, 1).bool())
        self.values[self.step].copy_(transition.values)
        self.actions_log_prob[self.step].copy_(
            transition.actions_log_prob.view(-1, 1)
        )
        self.mu[self.step].copy_(transition.action_mean)
        self.sigma[self.step].copy_(transition.action_sigma)
        self.step += 1

    def add_transitions(self, transition: Transition) -> None:
        """Alias of :meth:`add_transition`."""
        self.add_transition(transition)

    def clear(self) -> None:
        """Rewind the write cursor; buffer contents are left in place."""
        self.step = 0

    def compute_returns(
        self, last_values: torch.Tensor, gamma: float, lam: float
    ) -> None:
        """Fill ``returns`` and ``advantages`` with a backward GAE(lambda) pass.

        Walking backwards over the stored steps, the temporal-difference
        residual ``r + gamma * V_next * (1 - done) - V`` is accumulated with
        decay ``gamma * lam`` (reset wherever ``done`` is set). ``returns``
        holds the accumulated advantage plus the stored value, and
        ``advantages`` is standardised over the whole buffer afterwards.

        Args:
            last_values: Value estimate for the state following the final
                stored step. Presumably shaped ``(num_envs, 1)`` like one
                slice of ``values`` -- confirm against the caller.
            gamma: Discount factor.
            lam: Decay factor applied to the accumulated advantage.
        """
        advantage = torch.zeros_like(last_values)
        final_step = self.num_transitions_per_env - 1
        for step in reversed(range(self.num_transitions_per_env)):
            if step == final_step:
                next_values = last_values
            else:
                next_values = self.values[step + 1]
            # 0.0 where the episode ended at this step, 1.0 otherwise.
            next_is_not_terminal = 1.0 - self.dones[step].float()
            delta = (
                self.rewards[step]
                + next_is_not_terminal * gamma * next_values
                - self.values[step]
            )
            advantage = delta + next_is_not_terminal * gamma * lam * advantage
            self.returns[step] = advantage + self.values[step]

        self.advantages = self.returns - self.values
        # The epsilon keeps the division finite when all advantages are equal.
        self.advantages = (self.advantages - self.advantages.mean()) / (
            self.advantages.std() + 1e-8
        )

    def mini_batch_generator(self, num_mini_batches: int, num_epochs: int = 8):
        """Yield shuffled mini-batches over the flattened rollout.

        The time and environment dimensions are flattened into one batch of
        ``num_envs * num_transitions_per_env`` samples. A single random
        permutation is drawn and reused for every epoch. When the batch is
        not evenly divisible by ``num_mini_batches``, the surplus samples are
        dropped *after* shuffling, so which samples are skipped is random
        rather than always the tail of the buffer.

        Args:
            num_mini_batches: Number of mini-batches per epoch.
            num_epochs: Number of passes over the batch.

        Yields:
            Tuples ``(observations, critic_observations, actions,
            next_critic_observations, values, advantages, returns,
            old_actions_log_prob, old_mu, old_sigma)`` indexed by the
            mini-batch. Without privileged buffers both critic entries are
            the regular observations.

        Raises:
            ValueError: If ``num_mini_batches`` is not positive, or exceeds
                the number of stored samples.
        """
        num_mini_batches = int(num_mini_batches)
        if num_mini_batches <= 0:
            raise ValueError("num_mini_batches must be a positive integer")

        batch_size = self.num_envs * self.num_transitions_per_env
        mini_batch_size = batch_size // num_mini_batches
        if mini_batch_size <= 0:
            raise ValueError("num_mini_batches is too large for the rollout batch")

        # Shuffle the whole batch, then truncate to a multiple of the
        # mini-batch size. For evenly divisible batches this is the same
        # permutation as shuffling only the used prefix.
        used_samples = num_mini_batches * mini_batch_size
        indices = torch.randperm(
            batch_size,
            requires_grad=False,
            device=self.device,
        )[:used_samples]

        observations = self.observations.flatten(0, 1)
        if self.privileged_observations is not None:
            assert self.next_privileged_observations is not None
            critic_observations = self.privileged_observations.flatten(0, 1)
            next_critic_observations = self.next_privileged_observations.flatten(
                0, 1
            )
        else:
            critic_observations = observations
            next_critic_observations = observations

        actions = self.actions.flatten(0, 1)
        values = self.values.flatten(0, 1)
        returns = self.returns.flatten(0, 1)
        old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
        advantages = self.advantages.flatten(0, 1)
        old_mu = self.mu.flatten(0, 1)
        old_sigma = self.sigma.flatten(0, 1)

        for _ in range(int(num_epochs)):
            for start in range(0, used_samples, mini_batch_size):
                batch_idx = indices[start : start + mini_batch_size]
                yield (
                    observations[batch_idx],
                    critic_observations[batch_idx],
                    actions[batch_idx],
                    next_critic_observations[batch_idx],
                    values[batch_idx],
                    advantages[batch_idx],
                    returns[batch_idx],
                    old_actions_log_prob[batch_idx],
                    old_mu[batch_idx],
                    old_sigma[batch_idx],
                )