# Source code for unilab.algos.torch.hora.appo_learner

"""HORA-owned APPO learner with grouped actor and privileged observations."""

from __future__ import annotations

from typing import cast

import torch
from tensordict import TensorDict

from unilab.algos.torch.appo.learner import (
    APPOLearner,
    _distribution_std,
    _sample_tensor_for_metric,
    vtrace_advantages,
)
from unilab.algos.torch.hora.models import HoraActorModel, HoraCriticModel


def _build_hora_obs_td(
    actor_obs: torch.Tensor,
    *,
    device: str,
    priv_info: torch.Tensor | None = None,
) -> TensorDict:
    """Wrap HORA observation tensors in a ``TensorDict``.

    Args:
        actor_obs: Actor observations; the size of its first dimension is
            used as the batch size of the result.
        device: Device passed through to the ``TensorDict`` constructor.
        priv_info: Optional privileged features. When given, the result
            holds the entries ``"actor"`` and ``"priv_info"``; when omitted,
            it holds a single ``"policy"`` entry.

    Returns:
        The assembled ``TensorDict``.
    """
    if priv_info is None:
        entries = {"policy": actor_obs}
    else:
        entries = {"actor": actor_obs, "priv_info": priv_info}
    leading_dim = actor_obs.shape[0]
    return TensorDict(entries, batch_size=leading_dim, device=device)


def _derive_priv_info_from_critic(
    actor_obs: torch.Tensor,
    critic_obs: torch.Tensor | None,
    *,
    context: str,
) -> torch.Tensor:
    """Slice the privileged tail features out of the critic observations.

    Everything in the last dimension of ``critic_obs`` past the width of
    ``actor_obs`` is returned as the privileged information.

    Args:
        actor_obs: Actor observations; only the size of its last dimension
            is used.
        critic_obs: Critic observations, which must be strictly wider than
            ``actor_obs`` along the last dimension.
        context: Short label interpolated into error messages.

    Returns:
        ``critic_obs[..., actor_width:]``.

    Raises:
        ValueError: If ``critic_obs`` is ``None`` or is not wider than
            ``actor_obs`` along the last dimension.
    """
    if critic_obs is None:
        raise ValueError(f"HORA APPO {context} requires critic observations.")

    actor_width = int(actor_obs.shape[-1])
    critic_width = int(critic_obs.shape[-1])
    if critic_width > actor_width:
        return critic_obs[..., actor_width:]

    # No tail left after removing the actor-sized prefix.
    message = (
        f"HORA APPO {context} requires critic observations to include privileged tail "
        f"features; got actor_dim={actor_width}, critic_dim={critic_width}."
    )
    raise ValueError(message)


class HoraAPPOLearner(APPOLearner):
    """APPO learner variant for HORA grouped observations."""

    def _minibatch_policy_value(
        self,
        obs_mini: torch.Tensor,
        critic_obs_mini: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate the policy distribution and the value for one minibatch.

        Args:
            obs_mini: Actor observations of the minibatch.
            critic_obs_mini: Critic observations of the minibatch; the
                features past the actor width are used as privileged info.

        Returns:
            A ``(mean, std, value)`` tuple; ``value`` has its last dimension
            squeezed.

        Raises:
            ValueError: If ``critic_obs_mini`` is ``None`` or not wider than
                ``obs_mini`` along the last dimension.
        """
        priv_info = _derive_priv_info_from_critic(
            obs_mini,
            critic_obs_mini,
            context="minibatch update",
        )
        actor = cast(HoraActorModel, self.actor)
        critic = cast(HoraCriticModel, self.critic)
        shared_actor = actor.shared
        mean = shared_actor.policy_mean_from_tensors(
            obs_mini,
            priv_info,
            prefer_student=actor.prefer_student,
        )
        std = _distribution_std(shared_actor.distribution, mean)
        # The critic is always evaluated with prefer_student=False.
        value = critic.shared.value_from_tensors(
            obs_mini,
            priv_info,
            prefer_student=False,
        ).squeeze(-1)
        return mean, std, value

    def process_batch(self, batch_dict):
        """Compute V-trace targets for grouped HORA rollouts.

        Reads ``"observations"``, ``"critic"``, ``"rewards"``, ``"dones"``,
        ``"last_obs"``, ``"last_critic"``, ``"actions_log_prob"`` and
        ``"actions"`` from ``batch_dict``.

        Writes ``"_critic_obs_flat"``, ``"_old_mu"``, ``"_old_sigma"``,
        ``"_appo_process_metrics"``, ``"values"``, ``"advantages"``,
        ``"returns"`` and ``"target_log_probs"`` back into it.

        Args:
            batch_dict: Mapping holding the rollout tensors listed above.

        Returns:
            The same ``batch_dict`` object, updated in place.

        Raises:
            ValueError: If ``"critic"`` or ``"last_critic"`` is missing, or
                is not wider than the matching actor observations.
        """
        obs = batch_dict["observations"]
        critic_base = batch_dict.get("critic")
        rewards = batch_dict["rewards"]
        dones = batch_dict["dones"].float()
        last_obs = batch_dict["last_obs"]
        last_critic = batch_dict.get("last_critic")
        behavior_log_probs = batch_dict["actions_log_prob"]
        actions = batch_dict["actions"]

        # Leading two dims of the rollout; presumably (time, env) — TODO confirm.
        T, N = obs.shape[:2]

        # Raises ValueError when the critic observations are absent or carry
        # no privileged tail, so the flatten calls below never see ``None``.
        priv_info = _derive_priv_info_from_critic(
            obs,
            critic_base,
            context="rollout batch",
        )
        last_priv_info = _derive_priv_info_from_critic(
            last_obs,
            last_critic,
            context="bootstrap batch",
        )

        obs_flat = obs.flatten(0, 1)
        priv_info_flat = priv_info.flatten(0, 1)
        obs_td = _build_hora_obs_td(obs_flat, device=self.device, priv_info=priv_info_flat)
        last_obs_td = _build_hora_obs_td(last_obs, device=self.device, priv_info=last_priv_info)

        critic_obs_flat = critic_base.flatten(0, 1)
        # The critic gets its own TensorDict objects built from the same
        # tensors as the actor's, rather than sharing the actor's instances.
        critic_obs_td = _build_hora_obs_td(obs_flat, device=self.device, priv_info=priv_info_flat)
        critic_last_obs_td = _build_hora_obs_td(
            last_obs,
            device=self.device,
            priv_info=last_priv_info,
        )

        # Normalization statistics are updated before any forward pass below.
        if hasattr(self.actor, "update_normalization"):
            self.actor.update_normalization(obs_td)
            self.actor.update_normalization(last_obs_td)
            # NOTE(review): assumed to run only when the actor's normalizer
            # was updated — confirm the nesting against the upstream source.
            self.sync_target_actor_buffers()
        if hasattr(self.critic, "update_normalization"):
            self.critic.update_normalization(critic_obs_td)
            self.critic.update_normalization(critic_last_obs_td)

        batch_dict["_critic_obs_flat"] = critic_obs_flat

        with torch.inference_mode():
            values_flat = self.critic(critic_obs_td)
            last_values = self.critic(critic_last_obs_td).squeeze(-1)
        values = values_flat.view(T, N, -1).squeeze(-1)

        actions_flat = actions.flatten(0, 1)
        with torch.inference_mode():
            # The forward call's return value is discarded; the outputs are
            # read back from the target actor's attributes afterwards.
            self.target_actor(obs_td, stochastic_output=True)
            target_log_probs_flat = self.target_actor.get_output_log_prob(actions_flat)
            # NOTE(review): clones assumed to be taken inside inference mode
            # — confirm the nesting against the upstream source.
            batch_dict["_old_mu"] = self.target_actor.output_mean.clone()
            batch_dict["_old_sigma"] = self.target_actor.output_std.clone()
        target_log_probs = target_log_probs_flat.view(T, N)

        with torch.inference_mode():
            # Unclipped target/behavior probability ratios, used for metrics only.
            rhos = torch.exp(target_log_probs - behavior_log_probs)
            rho_sample = _sample_tensor_for_metric(rhos)
            batch_dict["_appo_process_metrics"] = {
                "vtrace/rho_clip_fraction": float(
                    (rhos > float(self.vtrace_clip_rho)).float().mean().item()
                ),
                "vtrace/rho_raw_p99": float(torch.quantile(rho_sample, 0.99).item()),
            }

        vs, advantages = vtrace_advantages(
            behavior_log_probs=behavior_log_probs,
            target_log_probs=target_log_probs,
            rewards=rewards,
            values=values,
            bootstrap_values=last_values,
            dones=dones,
            gamma=self.gamma,
            clip_rho=self.vtrace_clip_rho,
            clip_c=self.vtrace_clip_c,
        )

        batch_dict["values"] = values
        batch_dict["advantages"] = advantages
        batch_dict["returns"] = vs
        batch_dict["target_log_probs"] = target_log_probs
        return batch_dict