"""Asynchronous PPO (APPO) Learner.
Based on IMPACT (Luo et al. 2020): Importance Weighted Asynchronous
Architectures with Clipped Target Networks.
Key differences from standard PPO:
- V-trace importance sampling correction for off-policy data
- Target network with soft update for stable IS ratio computation
- PPO clipping applied over IS-corrected ratios
"""
import copy
import math
from collections.abc import Iterable
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from rsl_rl.models import MLPModel
from rsl_rl.utils import resolve_optimizer
from tensordict import TensorDict
# log(2 * pi); math.tau is exactly 2.0 * math.pi, so this is bit-identical.
_LOG_2_PI = math.log(math.tau)
# Per-dimension entropy offset of a normal distribution:
# H(N(mu, sigma^2)) = log(sigma) + (1 + log(2 * pi)) / 2.
_NORMAL_ENTROPY_OFFSET = (1.0 + _LOG_2_PI) / 2.0
def _distribution_std(distribution: Any, mean: torch.Tensor) -> torch.Tensor:
    """Broadcast the standard deviation of an rsl_rl Gaussian distribution to ``mean``.

    Args:
        distribution: Object exposing ``std_type`` plus ``std_param`` (read when
            ``std_type == "scalar"``) or ``log_std_param`` (read otherwise, and
            exponentiated).
        mean: Tensor whose shape the returned std is expanded to.

    Returns:
        The std tensor expanded (as a view) to ``mean``'s shape.

    Raises:
        RuntimeError: If ``distribution`` is ``None``.
    """
    if distribution is None:
        raise RuntimeError("APPO actor must expose a stochastic distribution")
    std = (
        distribution.std_param
        if distribution.std_type == "scalar"
        else distribution.log_std_param.exp()
    )
    return std.expand_as(mean)
def _sample_tensor_for_metric(tensor: torch.Tensor, max_items: int = 8192) -> torch.Tensor:
    """Deterministically thin ``tensor`` to at most ``max_items`` detached elements.

    Tensors with at most ``max_items`` elements are returned whole (flattened).
    Larger ones are strided evenly so the sample spans the whole tensor, then
    truncated to ``max_items``. No randomness is involved, so repeated calls
    on the same data give the same sample.
    """
    values = tensor.detach().reshape(-1)
    count = values.numel()
    if count > max_items:
        step = max(1, count // max_items)
        values = values[::step][:max_items]
    return values
def _grad_norm(parameters) -> float:
    """Return the global L2 norm of all present gradients as a Python float.

    Parameters without a ``grad`` attribute, or whose ``grad`` is ``None``, are
    skipped. Gradients are only read (detached), never modified or clipped.
    Returns ``0.0`` when no gradient is present.
    """
    per_param = []
    for param in parameters:
        grad = getattr(param, "grad", None)
        if grad is not None:
            per_param.append(grad.detach().norm(2))
    if len(per_param) == 0:
        return 0.0
    # Norm of the per-parameter norms == norm of all gradients concatenated.
    return float(torch.stack(per_param).norm(2).item())
def _unique_parameters(parameters: Iterable[torch.nn.Parameter]) -> list[torch.nn.Parameter]:
    """Deduplicate ``parameters`` by object identity, keeping first-seen order.

    Useful when several modules share (tie) the same parameter object and the
    combined list is handed to an optimizer, which should see each tensor once.
    """
    by_identity: dict[int, torch.nn.Parameter] = {}
    for param in parameters:
        # setdefault keeps the first occurrence; dicts preserve insertion order.
        by_identity.setdefault(id(param), param)
    return list(by_identity.values())
def vtrace_advantages(
    behavior_log_probs,  # [T, N] log π_b(a|s) from worker
    target_log_probs,  # [T, N] log π_target(a|s) from target network
    rewards,  # [T, N]
    values,  # [T, N]
    bootstrap_values,  # [N] V(s_{T})
    dones,  # [T, N] float
    gamma=0.99,
    clip_rho=1.0,
    clip_c=1.0,
):
    """Compute V-trace value targets and policy-gradient advantages.

    V-trace (Espeholt et al., 2018) corrects for the off-policy nature of
    asynchronous data collection using importance-sampling ratios
    ``rho_t = pi_target(a_t|s_t) / pi_behavior(a_t|s_t)``. The ratios are
    truncated at ``clip_rho`` (rho-bar, used in TD errors and advantages) and
    at ``clip_c`` (c-bar, used in the trace coefficients).

    With ``nt_t = 1 - dones_t`` and ``V(s_T) = vs_T = bootstrap_values``::

        delta_t = min(rho_t, rho_bar) * (r_t + gamma * nt_t * V(s_{t+1}) - V(s_t))
        vs_t    = V(s_t) + delta_t
                  + gamma * nt_t * min(rho_t, c_bar) * (vs_{t+1} - V(s_{t+1}))
        adv_t   = min(rho_t, rho_bar) * (r_t + gamma * nt_t * vs_{t+1} - V(s_t))

    ``nt_t`` cuts both the bootstrap and the trace at episode ends. Everything
    is computed under ``torch.no_grad()``.

    Args:
        behavior_log_probs: Log-probs of the taken actions under the behavior
            policy, [T, N] (rows are time steps, as indexed by the recursion).
        target_log_probs: Log-probs of the same actions under the target
            policy, [T, N].
        rewards: Rewards, [T, N].
        values: Value estimates V(s_t), [T, N].
        bootstrap_values: V(s_T) for the state after the last step, [N].
        dones: Episode-termination flags as floats, [T, N].
        gamma: Discount factor.
        clip_rho: Truncation level rho-bar.
        clip_c: Truncation level c-bar.

    Returns:
        vs: V-trace value targets [T, N], on ``values.device``.
        advantages: Policy-gradient advantages [T, N].
    """
    T, N = rewards.shape
    device = values.device
    with torch.no_grad():
        # IS ratios: ρ_t = π_target(a_t|s_t) / π_behavior(a_t|s_t)
        rhos = torch.exp(target_log_probs - behavior_log_probs)
        clipped_rhos = torch.clamp(rhos, max=clip_rho)
        non_terminal = 1.0 - dones
        # Shift values by one step; the last step bootstraps from V(s_T).
        next_values = torch.cat([values[1:], bootstrap_values.unsqueeze(0)], dim=0)
        deltas = clipped_rhos * (rewards + gamma * next_values * non_terminal - values)
        # Per-step trace coefficient gamma * nt_t * c_t, built on-device so the
        # backward scan needs only two host transfers (each one is a device sync).
        trace_coeffs = gamma * non_terminal * torch.clamp(rhos, max=clip_c)

        # Backward accumulation of the corrections (vs_t - V(s_t)) on CPU/numpy.
        # This avoids T sequential tiny kernel launches on an accelerator.
        deltas_np = deltas.cpu().numpy()
        coeffs_np = trace_coeffs.cpu().numpy()
        corrections_np = np.empty_like(deltas_np)
        running = np.zeros(N, dtype=deltas_np.dtype)
        for t in range(T - 1, -1, -1):
            running = deltas_np[t] + coeffs_np[t] * running
            corrections_np[t] = running
        vs = values + torch.from_numpy(corrections_np).to(device)

        # Advantages bootstrap from the V-trace target of the next step.
        next_vs = torch.cat([vs[1:], bootstrap_values.unsqueeze(0)], dim=0)
        advantages = clipped_rhos * (rewards + gamma * next_vs * non_terminal - values)
    return vs, advantages
class APPOLearner:
    """Asynchronous PPO (APPO) learner.

    Runs a PPO-style update with V-trace off-policy correction and a target
    actor network, decoupled from rollout collection.

    Key features:
    - V-trace importance sampling for off-policy advantage estimation
      (``process_batch``).
    - Target actor refreshed every ``target_update_freq`` calls to ``update``
      with soft-update coefficient ``tau`` (``tau == 1`` is a hard copy).
    - Observation normalization updated centrally in ``process_batch``. The
      actor state dict returned by ``get_weights`` is what workers receive.
    - Adaptive learning rate driven by the KL between the target actor and
      the current actor when ``schedule == "adaptive"``.

    NOTE(review): no time-out (truncation) bootstrap correction is applied in
    this class; confirm it is handled upstream (e.g. folded into rewards).
    """

    def __init__(
        self,
        actor: MLPModel,
        critic: MLPModel,
        num_learning_epochs: int = 5,
        num_mini_batches: int = 4,
        clip_param: float = 0.2,
        gamma: float = 0.99,
        lam: float = 0.95,
        value_loss_coef: float = 1.0,
        entropy_coef: float = 0.01,
        learning_rate: float = 1e-3,
        max_grad_norm: float = 1.0,
        use_clipped_value_loss: bool = True,
        schedule: str = "fixed",
        desired_kl: float = 0.01,
        adaptive_kl_factor: float = 1.2,
        adaptive_lr_factor: float = 1.1,
        device: str = "cpu",
        optimizer: str = "adam",
        # APPO-specific parameters
        tau: float = 1.0,
        target_update_freq: int = 1,
        vtrace_clip_rho: float = 1.0,
        vtrace_clip_c: float = 1.0,
        enable_compile: bool = False,
        **kwargs,
    ):
        """Build the learner, its frozen target actor, and the optimizer.

        Args:
            actor: Policy model. The loss reads ``mlp``, ``obs_normalizer`` and
                ``distribution`` directly. Its target copy is called with
                ``stochastic_output=True`` and queried via
                ``get_output_log_prob`` / ``output_mean`` / ``output_std``.
            critic: Value model; ``mlp`` and ``obs_normalizer`` are read directly.
            num_learning_epochs: Passes over the batch per ``update`` call.
            num_mini_batches: Mini-batches per epoch (capped at the batch size).
            clip_param: PPO clip range for the policy ratio. It is also used as
                the value-clipping range.
            gamma: Discount factor used by V-trace.
            lam: Stored as ``self.lam``; not used by this class's V-trace code.
            value_loss_coef: Weight of the value loss.
            entropy_coef: Weight of the entropy bonus.
            learning_rate: Initial optimizer learning rate.
            max_grad_norm: Global gradient-norm clipping threshold.
            use_clipped_value_loss: Use the PPO-style clipped value loss.
            schedule: ``"adaptive"`` enables KL-driven learning-rate updates.
                Any other value keeps the rate fixed.
            desired_kl: Target KL for the adaptive schedule. ``None`` disables
                adaptation.
            adaptive_kl_factor: Width factor (> 1) of the accepted KL band.
            adaptive_lr_factor: Multiplicative learning-rate step (> 1).
            device: Torch device the models are moved to.
            optimizer: Name passed to ``resolve_optimizer``.
            tau: Target-actor soft-update coefficient.
            target_update_freq: Refresh the target actor every this many
                ``update`` calls (must be >= 1).
            vtrace_clip_rho: V-trace truncation level rho-bar.
            vtrace_clip_c: V-trace truncation level c-bar.
            enable_compile: Compile the mini-batch loss with ``torch.compile``.
                This is honoured only on CUDA.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``adaptive_kl_factor`` or ``adaptive_lr_factor`` is
                not > 1, or if ``target_update_freq`` < 1.
        """
        self.device = device
        self._device_type = torch.device(device).type
        self.actor = actor.to(self.device)
        self.critic = critic.to(self.device)
        # Frozen target actor used for the V-trace IS ratios and the KL reference.
        self.target_actor = copy.deepcopy(self.actor).to(self.device)
        self.target_actor.eval()
        for p in self.target_actor.parameters():
            p.requires_grad = False
        actor_params = _unique_parameters(self.actor.parameters())
        critic_params = _unique_parameters(self.critic.parameters())
        # Shared parameters (if any) must reach the optimizer only once.
        self._combined_params = _unique_parameters([*actor_params, *critic_params])

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss
        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate
        if adaptive_kl_factor <= 1.0:
            raise ValueError(f"adaptive_kl_factor must be > 1.0, got {adaptive_kl_factor}")
        if adaptive_lr_factor <= 1.0:
            raise ValueError(f"adaptive_lr_factor must be > 1.0, got {adaptive_lr_factor}")
        self.adaptive_kl_factor = adaptive_kl_factor
        self.adaptive_lr_factor = adaptive_lr_factor

        # APPO-specific parameters
        if target_update_freq < 1:
            # Used as a modulus in update(); 0 would fail with ZeroDivisionError there.
            raise ValueError(f"target_update_freq must be >= 1, got {target_update_freq}")
        self.tau = tau
        self.target_update_freq = target_update_freq
        self.vtrace_clip_rho = vtrace_clip_rho
        self.vtrace_clip_c = vtrace_clip_c
        self._update_counter = 0
        self.last_update_metrics: dict[str, float] = {}
        self.enable_compile = (
            bool(enable_compile) and self._device_type == "cuda" and hasattr(torch, "compile")
        )

        # Optimizer
        self.optimizer = resolve_optimizer(optimizer)(  # pyright: ignore[reportCallIssue]
            self._combined_params, lr=learning_rate
        )
        self._minibatch_loss_fn = self._minibatch_loss_tensors
        if self.enable_compile:
            self._compile_training_methods()

    def _compile_training_methods(self) -> None:
        """Replace the mini-batch loss with a ``torch.compile``d version (CUDA only)."""
        compile_fn = getattr(torch, "compile", None)
        if compile_fn is None or self._device_type != "cuda":
            return
        self._minibatch_loss_fn = compile_fn(
            self._minibatch_loss_tensors,
            mode="reduce-overhead",
            fullgraph=False,
        )

    def _actor_mean_std(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the current actor's action mean and broadcast std for ``obs``."""
        distribution: Any = self.actor.distribution
        mean = self.actor.mlp(self.actor.obs_normalizer(obs))
        return mean, _distribution_std(distribution, mean)

    def _critic_value(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the critic's value for ``obs`` with the trailing dim squeezed."""
        return self.critic.mlp(self.critic.obs_normalizer(obs)).squeeze(-1)

    def _minibatch_policy_value(
        self,
        obs_mini: torch.Tensor,
        critic_obs_mini: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run actor and critic forward passes for one mini-batch."""
        mu, sigma = self._actor_mean_std(obs_mini)
        value = self._critic_value(critic_obs_mini)
        return mu, sigma, value

    @staticmethod
    def _gaussian_log_prob(
        actions: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        """Diagonal-Gaussian log-density of ``actions``, summed over the last dim."""
        normalized = (actions - mean) / std
        return (-0.5 * (normalized.pow(2) + 2.0 * torch.log(std) + _LOG_2_PI)).sum(dim=-1)

    @staticmethod
    def _gaussian_entropy(std: torch.Tensor) -> torch.Tensor:
        """Diagonal-Gaussian entropy, summed over the last dim."""
        return (torch.log(std) + _NORMAL_ENTROPY_OFFSET).sum(dim=-1)

    def _minibatch_loss_tensors(
        self,
        obs_mini: torch.Tensor,
        critic_obs_mini: torch.Tensor,
        actions_mini: torch.Tensor,
        target_values_mini: torch.Tensor,
        advantages_mini: torch.Tensor,
        behavior_logp_mini: torch.Tensor,
        old_values_mini: torch.Tensor,
        target_logp_mini: torch.Tensor,
        old_mu_mini: torch.Tensor,
        old_sigma_mini: torch.Tensor,
    ) -> tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Compute the APPO loss and diagnostics for one mini-batch.

        Returns:
            ``(loss, surrogate_loss, value_loss, entropy, kl_mean,
            current_log_prob, ratio)``. Only ``loss`` is meant for ``backward``.
            ``kl_mean`` is computed without gradient tracking.
        """
        mu, sigma, value = self._minibatch_policy_value(obs_mini, critic_obs_mini)
        current_log_prob = self._gaussian_log_prob(actions_mini, mu, sigma)
        entropy = self._gaussian_entropy(sigma).mean()

        # ratio = min(mu_b / pi_target, 1) * pi_theta / mu_b. This equals
        # pi_theta / pi_target whenever the behavior policy is no more likely than
        # the target for the taken action; otherwise it falls back to pi_theta / mu_b.
        with torch.no_grad():
            clipped_rho = torch.clamp(torch.exp(behavior_logp_mini - target_logp_mini), max=1.0)
        ratio = clipped_rho * torch.exp(current_log_prob - behavior_logp_mini)
        surrogate = -advantages_mini * ratio
        surrogate_clipped = -advantages_mini * torch.clamp(
            ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
        )
        surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

        if self.use_clipped_value_loss:
            value_clipped = old_values_mini + (value - old_values_mini).clamp(
                -self.clip_param, self.clip_param
            )
            value_losses = (value - target_values_mini).pow(2)
            value_losses_clipped = (value_clipped - target_values_mini).pow(2)
            value_loss = torch.max(value_losses, value_losses_clipped).mean()
        else:
            value_loss = (value - target_values_mini).pow(2).mean()

        # Closed-form KL(old || current) between diagonal Gaussians. It is a
        # diagnostic / LR-schedule signal only, so it is kept out of the graph.
        with torch.no_grad():
            kl = torch.sum(
                torch.log(sigma / old_sigma_mini + 1e-5)
                + (old_sigma_mini.pow(2) + (old_mu_mini - mu).pow(2)) / (2.0 * sigma.pow(2))
                - 0.5,
                dim=-1,
            )
            kl_mean = torch.mean(kl)

        loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy
        return loss, surrogate_loss, value_loss, entropy, kl_mean, current_log_prob, ratio

    def train_mode(self):
        """Set actor/critic to training mode (enables EmpiricalNormalization.update)."""
        self.actor.train()
        self.critic.train()

    def eval_mode(self):
        """Set actor/critic to eval mode."""
        self.actor.eval()
        self.critic.eval()

    def update_target_network(self):
        """Soft-update the target actor: ``target = tau * current + (1 - tau) * target``.

        ``tau == 1`` performs an exact hard copy. Buffers (e.g. normalization
        stats) are always copied verbatim.
        """
        with torch.no_grad():
            for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()):
                if self.tau == 1.0:
                    target_param.copy_(param)
                else:
                    # In-place lerp: target + tau * (param - target), no temporaries.
                    target_param.lerp_(param, self.tau)
        # Also copy buffers (e.g. normalization stats)
        self.sync_target_actor_buffers()

    def _update_adaptive_learning_rate(self, kl_mean: float) -> None:
        """Adapt the optimizer learning rate from the measured KL.

        This is a no-op unless ``schedule == "adaptive"`` and ``desired_kl`` is
        set. The rate is divided by ``adaptive_lr_factor`` (not below 1e-5)
        when KL exceeds ``desired_kl * adaptive_kl_factor``. It is multiplied
        (not above 1e-2) when 0 < KL < ``desired_kl / adaptive_kl_factor``.
        """
        if self.desired_kl is None or self.schedule != "adaptive":
            return
        if kl_mean > self.desired_kl * self.adaptive_kl_factor:
            self.learning_rate = max(1e-5, self.learning_rate / self.adaptive_lr_factor)
        elif 0.0 < kl_mean < self.desired_kl / self.adaptive_kl_factor:
            self.learning_rate = min(1e-2, self.learning_rate * self.adaptive_lr_factor)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.learning_rate

    def sync_target_actor_buffers(self):
        """Copy actor buffers such as observation-normalization stats to target actor."""
        for target_buf, buf in zip(self.target_actor.buffers(), self.actor.buffers()):
            target_buf.data.copy_(buf.data)

    def get_weights(self):
        """Return the actor state dict for syncing to workers.

        Workers act with this behavior policy, which may be stale by the time
        its data is learned from. The state dict contains registered buffers
        (presumably including the EmpiricalNormalization stats).
        """
        return self.actor.state_dict()

    def get_state_dict(self):
        """Return learner state for checkpointing.

        NOTE(review): the target actor, ``learning_rate`` and the update
        counter are not included. After a restore, the target actor is the
        constructor's copy until the next ``update_target_network``.
        """
        return {
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def process_batch(self, batch_dict):
        """Compute V-trace value targets and advantages for a rollout batch.

        The steps are:
        - Update the actor/critic observation normalizers (when the models
          expose ``update_normalization``) and re-sync the target actor's
          buffers.
        - Evaluate the critic and the target actor on the batch.
        - Run ``vtrace_advantages``, comparing target-actor log-probs with the
          behavior log-probs recorded by workers. The backward recursion in
          ``vtrace_advantages`` runs on CPU.

        ``batch_dict`` keys read (shapes per the inline comments; confirm them
        against the rollout producer): "observations", "rewards", "dones",
        "last_obs", "actions_log_prob", "actions", plus the optional "critic"
        and "last_critic".

        Mutates and returns ``batch_dict``. It adds "values", "advantages",
        "returns" (the V-trace targets) and "target_log_probs". It also adds
        the cache keys "_critic_obs_flat", "_old_mu", "_old_sigma" and
        "_appo_process_metrics", which ``update`` consumes.
        """
        obs = batch_dict["observations"]  # [T, N, D]
        rewards = batch_dict["rewards"]  # [T, N]
        dones = batch_dict["dones"].float()  # [T, N]
        last_obs = batch_dict["last_obs"]  # [N, D]
        behavior_log_probs = batch_dict["actions_log_prob"]  # [T, N]
        actions = batch_dict["actions"]  # [T, N, A]
        # Critic: explicit critic obs when available, else actor obs.
        critic_obs = batch_dict.get("critic", None)  # [T, N, C] or None
        if critic_obs is None:
            critic_obs = obs
        critic_last_obs = batch_dict.get("last_critic", None)  # [N, C] or None
        if critic_last_obs is None:
            critic_last_obs = last_obs

        T, N = obs.shape[:2]
        obs_flat = obs.flatten(0, 1)  # [T*N, D]
        critic_obs_flat = critic_obs.flatten(0, 1)  # [T*N, C]
        obs_td = TensorDict({"policy": obs_flat}, batch_size=obs_flat.shape[0], device=self.device)
        last_obs_td = TensorDict({"policy": last_obs}, batch_size=N, device=self.device)
        critic_obs_td = TensorDict(
            {"policy": critic_obs_flat}, batch_size=critic_obs_flat.shape[0], device=self.device
        )
        critic_last_obs_td = TensorDict(
            {"policy": critic_last_obs}, batch_size=N, device=self.device
        )

        # Update observation normalization; the target actor shares the new stats.
        if hasattr(self.actor, "update_normalization"):
            self.actor.update_normalization(obs_td)
            self.actor.update_normalization(last_obs_td)
            self.sync_target_actor_buffers()
        if hasattr(self.critic, "update_normalization"):
            self.critic.update_normalization(critic_obs_td)
            self.critic.update_normalization(critic_last_obs_td)

        # Cache critic_obs_flat for update()
        batch_dict["_critic_obs_flat"] = critic_obs_flat

        with torch.inference_mode():
            # Values from the current critic.
            values_flat = self.critic(critic_obs_td)  # [T*N, 1]
            last_values = self.critic(critic_last_obs_td).squeeze(-1)  # [N]
            values = values_flat.view(T, N, -1).squeeze(-1)  # [T, N]

        # Target-policy log-probs for the V-trace IS ratios. mu/sigma are cached
        # here so update() needs no second target forward pass.
        actions_flat = actions.flatten(0, 1)  # [T*N, A]
        with torch.inference_mode():
            self.target_actor(obs_td, stochastic_output=True)
            target_log_probs_flat = self.target_actor.get_output_log_prob(actions_flat)
            batch_dict["_old_mu"] = self.target_actor.output_mean.clone()
            batch_dict["_old_sigma"] = self.target_actor.output_std.clone()
        target_log_probs = target_log_probs_flat.view(T, N)

        with torch.inference_mode():
            rhos = torch.exp(target_log_probs - behavior_log_probs)
            rho_sample = _sample_tensor_for_metric(rhos)
            batch_dict["_appo_process_metrics"] = {
                "vtrace/rho_clip_fraction": float(
                    (rhos > float(self.vtrace_clip_rho)).float().mean().item()
                ),
                "vtrace/rho_raw_p99": float(torch.quantile(rho_sample, 0.99).item()),
            }

        # V-trace targets and advantages
        vs, advantages = vtrace_advantages(
            behavior_log_probs=behavior_log_probs,
            target_log_probs=target_log_probs,
            rewards=rewards,
            values=values,
            bootstrap_values=last_values,
            dones=dones,
            gamma=self.gamma,
            clip_rho=self.vtrace_clip_rho,
            clip_c=self.vtrace_clip_c,
        )
        batch_dict["values"] = values
        batch_dict["advantages"] = advantages
        batch_dict["returns"] = vs  # V-trace targets as returns
        batch_dict["target_log_probs"] = target_log_probs
        return batch_dict

    def update(self, batch_dict):
        """Run the APPO mini-batch updates on a batch prepared by ``process_batch``.

        ``batch_dict`` must already contain "returns", "advantages", "values",
        "target_log_probs", "_old_mu" and "_old_sigma". Advantages are
        normalized over the whole batch. Each epoch reshuffles the batch into
        ``min(num_mini_batches, batch_size)`` equal mini-batches; any remainder
        samples are left out of that epoch. After all epochs, the target actor
        is refreshed every ``target_update_freq`` calls.

        Returns:
            dict of scalar metrics, merged with the V-trace metrics recorded by
            ``process_batch``. The dict is also stored in
            ``self.last_update_metrics``.
        """
        obs_flat = batch_dict["observations"].flatten(0, 1)
        actions_flat = batch_dict["actions"].flatten(0, 1)
        returns_flat = batch_dict["returns"].flatten(0, 1)
        advantages_flat = batch_dict["advantages"].flatten(0, 1)
        behavior_log_probs_flat = batch_dict["actions_log_prob"].flatten(0, 1)
        old_values_flat = batch_dict["values"].flatten(0, 1)
        target_log_probs_flat = batch_dict["target_log_probs"].flatten(0, 1)
        advantages_flat = (advantages_flat - advantages_flat.mean()) / (
            advantages_flat.std() + 1e-8
        )
        critic_obs_flat = batch_dict.get("_critic_obs_flat")
        if critic_obs_flat is None:
            critic_base_flat = batch_dict.get("critic")
            critic_obs_flat = (
                obs_flat if critic_base_flat is None else critic_base_flat.flatten(0, 1)
            )
        old_mu_flat = batch_dict["_old_mu"]
        old_sigma_flat = batch_dict["_old_sigma"]

        batch_size = obs_flat.shape[0]
        # Never use more mini-batches than samples. Empty mini-batches would give
        # NaN means, and backpropagating them would corrupt the weights.
        num_mini_batches = max(1, min(self.num_mini_batches, batch_size))
        mini_batch_size = batch_size // num_mini_batches
        adaptive_lr = self.desired_kl is not None and self.schedule == "adaptive"

        mean_surrogate_loss = 0.0
        mean_value_loss = 0.0
        mean_entropy = 0.0
        mean_kl = 0.0
        mean_clip_fraction = 0.0
        mean_behavior_to_current_kl = 0.0
        mean_target_to_current_kl = 0.0
        mean_global_grad_norm = 0.0
        num_updates = 0

        for _epoch in range(self.num_learning_epochs):
            indices = torch.randperm(batch_size, device=self.device)
            for i in range(num_mini_batches):
                batch_idx = indices[i * mini_batch_size : (i + 1) * mini_batch_size]
                behavior_logp_mini = behavior_log_probs_flat[batch_idx]
                (
                    loss,
                    surrogate_loss,
                    value_loss,
                    entropy,
                    kl_mean,
                    current_log_prob,
                    ratio,
                ) = self._minibatch_loss_fn(
                    obs_flat[batch_idx],
                    critic_obs_flat[batch_idx],
                    actions_flat[batch_idx],
                    returns_flat[batch_idx],
                    advantages_flat[batch_idx],
                    behavior_logp_mini,
                    old_values_flat[batch_idx],
                    target_log_probs_flat[batch_idx],
                    old_mu_flat[batch_idx],
                    old_sigma_flat[batch_idx],
                )

                # Read the KL once. It feeds the LR schedule (adaptive only) and is
                # accumulated exactly once per mini-batch into the approx_kl metric.
                kl_value = float(kl_mean.item())
                if adaptive_lr:
                    self._update_adaptive_learning_rate(kl_value)
                    mean_kl += kl_value
                mean_target_to_current_kl += kl_value

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # clip_grad_norm_ returns the total norm *before* clipping, so no
                # separate norm pass is needed for the metric.
                global_grad_norm = float(
                    nn.utils.clip_grad_norm_(self._combined_params, self.max_grad_norm)
                )
                self.optimizer.step()

                with torch.inference_mode():
                    clip_fraction = (torch.abs(ratio - 1.0) > self.clip_param).float().mean().item()
                    behavior_to_current_kl = (behavior_logp_mini - current_log_prob).mean().item()
                    mean_value_loss += value_loss.item()
                    mean_surrogate_loss += surrogate_loss.item()
                    mean_entropy += entropy.item()
                mean_clip_fraction += float(clip_fraction)
                mean_behavior_to_current_kl += float(behavior_to_current_kl)
                mean_global_grad_norm += global_grad_norm
                num_updates += 1

        self._update_counter += 1
        if self._update_counter % self.target_update_freq == 0:
            self.update_target_network()

        num_updates = max(num_updates, 1)
        self.learning_rate = float(self.learning_rate)
        final_lr = self.learning_rate
        metrics = {
            "surrogate_loss": mean_surrogate_loss / num_updates,
            "value_loss": mean_value_loss / num_updates,
            "entropy": mean_entropy / num_updates,
            "kl": mean_kl / num_updates if self.schedule == "adaptive" else 0.0,
            "loss/policy_loss": mean_surrogate_loss / num_updates,
            "loss/value_loss": mean_value_loss / num_updates,
            "policy/entropy": mean_entropy / num_updates,
            "ppo/approx_kl": mean_target_to_current_kl / num_updates,
            "ppo/clip_fraction": mean_clip_fraction / num_updates,
            "grad/global_norm": mean_global_grad_norm / num_updates,
            "optim/learning_rate": final_lr,
            "policy_kl/behavior_to_current_kl": mean_behavior_to_current_kl / num_updates,
            "appo/updates_executed": float(num_updates),
        }
        metrics.update(batch_dict.get("_appo_process_metrics", {}))
        self.last_update_metrics = metrics
        return metrics