unilab.algos.torch.appo.learner
Asynchronous PPO (APPO) Learner.
Based on IMPACT (Luo et al., 2020): Importance Weighted Asynchronous Architectures with Clipped Target Networks.
Key differences from standard PPO:
- V-trace importance sampling correction for off-policy data
- Target network with soft update for stable IS ratio computation
- PPO clipping applied over IS-corrected ratios
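The third point, PPO clipping over IS-corrected ratios, can be illustrated with the standard PPO clipped surrogate. This is a minimal pure-Python sketch, not this module's code. It assumes the ratio is formed between the current policy and the target network's log-probabilities, and that the advantages have already been V-trace corrected. The function name `clipped_surrogate_loss` is hypothetical, and how unilab additionally weights by the behavior policy is not specified here.

```python
import math


def clipped_surrogate_loss(current_log_probs, target_log_probs, advantages, clip_param=0.2):
    """Mean PPO clipped-surrogate loss (to be minimized).

    Hypothetical sketch: the ratio is pi_current / pi_target, where
    target_log_probs come from the target network (an assumption based on
    IMPACT), and advantages are assumed to be V-trace corrected already.
    """
    losses = []
    for cur_lp, tgt_lp, adv in zip(current_log_probs, target_log_probs, advantages):
        ratio = math.exp(cur_lp - tgt_lp)
        # Keep the ratio inside [1 - clip_param, 1 + clip_param].
        clipped_ratio = min(max(ratio, 1.0 - clip_param), 1.0 + clip_param)
        # Pessimistic bound: take the smaller objective, then negate for a loss.
        losses.append(-min(ratio * adv, clipped_ratio * adv))
    return sum(losses) / len(losses)
```

For example, with a ratio of 2 and a positive advantage of 1, the objective is capped at 1 + clip_param, so `clipped_surrogate_loss([math.log(2.0)], [0.0], [1.0])` is about -1.2. With a negative advantage the unclipped term is kept, so large ratios are still penalized in full.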
Functions
vtrace_advantages | Compute V-trace targets and advantages.
Classes
APPOLearner | Asynchronous PPO Learner.
- unilab.algos.torch.appo.learner.vtrace_advantages(behavior_log_probs, target_log_probs, rewards, values, bootstrap_values, dones, gamma=0.99, clip_rho=1.0, clip_c=1.0)
Compute V-trace targets and advantages.
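V-trace (Espeholt et al., 2018) corrects for the off-policy nature of asynchronous data collection with truncated importance sampling ratios: ρ̄ (clip_rho) caps the ratios that weight each temporal-difference term, and c̄ (clip_c) caps the trace coefficients that propagate corrections backward in time.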
- Returns:
vs: V-trace value targets [T, N]
advantages: Policy-gradient advantages [T, N]
- Return type:
tuple (vs, advantages)
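The V-trace recursion can be sketched in pure Python for a single environment column (N = 1). This follows the standard formulation from Espeholt et al. (2018) and is not the module's implementation. The name `vtrace_single_env` is hypothetical. Zeroing the discount on `dones` and scaling the policy-gradient advantage by the clipped ρ follow common V-trace implementations and are assumptions about `vtrace_advantages`.

```python
import math


def vtrace_single_env(behavior_log_probs, target_log_probs, rewards, values,
                      bootstrap_value, dones, gamma=0.99, clip_rho=1.0, clip_c=1.0):
    """Hypothetical single-column V-trace: returns (vs, advantages) as lists of length T."""
    T = len(rewards)
    # Importance ratios pi(a|x) / mu(a|x), computed from log-probabilities.
    ratios = [math.exp(t - b) for t, b in zip(target_log_probs, behavior_log_probs)]
    rhos = [min(clip_rho, r) for r in ratios]  # truncated at rho_bar
    cs = [min(clip_c, r) for r in ratios]      # truncated at c_bar
    # V(x_{t+1}); the last step bootstraps from the value after the rollout.
    next_values = values[1:] + [bootstrap_value]
    # Assumption: a done flag zeroes the discount so nothing crosses episodes.
    discounts = [gamma * (1.0 - float(d)) for d in dones]

    vs = [0.0] * T
    correction = 0.0  # holds vs_{t+1} - V(x_{t+1}); zero past the horizon
    for t in reversed(range(T)):
        delta = rhos[t] * (rewards[t] + discounts[t] * next_values[t] - values[t])
        correction = delta + discounts[t] * cs[t] * correction
        vs[t] = values[t] + correction

    # Policy-gradient advantage: rho_t * (r_t + gamma * vs_{t+1} - V(x_t)).
    next_vs = vs[1:] + [bootstrap_value]
    advantages = [rhos[t] * (rewards[t] + discounts[t] * next_vs[t] - values[t])
                  for t in range(T)]
    return vs, advantages
```

When behavior and target log-probabilities match and both clips are 1.0, every ρ and c equals 1 and vs reduces to the bootstrapped n-step return. For rewards [1, 1, 1], zero values, gamma 0.5 and a zero bootstrap value, vs is [1.75, 1.5, 1.0].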
- class unilab.algos.torch.appo.learner.APPOLearner
Bases: object
Asynchronous PPO Learner.
Runs the PPO update with V-trace off-policy correction and a target network, decoupled from rollout collection.
Key features:
- V-trace importance sampling for off-policy advantage estimation
- Target network with soft update (tau) for stable IS computation
- Observation normalization updated centrally and synced to workers
- Time-out (truncation) bootstrap correction
- Adaptive learning rate via a KL-divergence target
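The last two features can be sketched as follows. Both functions are hypothetical illustrations, not this class's methods. For the adaptive rule, the reading of `adaptive_kl_factor` as a threshold multiplier on `desired_kl` and of `adaptive_lr_factor` as the learning-rate step is an assumption based on the parameter names. For the time-out correction, the sketch uses the common approximation of adding gamma times the value of the step's own state to the reward when an episode is truncated rather than truly terminated.

```python
def adapt_learning_rate(lr, kl, desired_kl=0.01,
                        adaptive_kl_factor=1.2, adaptive_lr_factor=1.1):
    """Hypothetical KL-target learning-rate rule; the meaning of both factors is assumed."""
    if kl > desired_kl * adaptive_kl_factor:
        # Policy moved more than desired: shrink the step size.
        return lr / adaptive_lr_factor
    if kl < desired_kl / adaptive_kl_factor:
        # Policy barely moved: grow the step size.
        return lr * adaptive_lr_factor
    # KL is within the tolerance band: keep the learning rate.
    return lr


def bootstrap_timeouts(rewards, values, time_outs, gamma=0.99):
    """Hypothetical time-out correction: reward + gamma * V(s) on truncated steps.

    A truncated episode did not really end, so the state's value is folded into
    the reward before the done flag cuts off the return. V(s_t) stands in for
    the value of the truncated next state (a common approximation).
    """
    return [r + gamma * v * float(t) for r, v, t in zip(rewards, values, time_outs)]
```

With the defaults, a measured KL of 0.05 lowers a learning rate of 1e-3 to about 9.1e-4, while a KL of 0.001 raises it to 1.1e-3.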
- Parameters:
actor (MLPModel)
critic (MLPModel)
num_learning_epochs (int)
num_mini_batches (int)
clip_param (float)
gamma (float)
lam (float)
value_loss_coef (float)
entropy_coef (float)
learning_rate (float)
max_grad_norm (float)
use_clipped_value_loss (bool)
schedule (str)
desired_kl (float)
adaptive_kl_factor (float)
adaptive_lr_factor (float)
device (str)
optimizer (str)
tau (float)
target_update_freq (int)
vtrace_clip_rho (float)
vtrace_clip_c (float)
enable_compile (bool)
- __init__(actor, critic, num_learning_epochs=5, num_mini_batches=4, clip_param=0.2, gamma=0.99, lam=0.95, value_loss_coef=1.0, entropy_coef=0.01, learning_rate=0.001, max_grad_norm=1.0, use_clipped_value_loss=True, schedule='fixed', desired_kl=0.01, adaptive_kl_factor=1.2, adaptive_lr_factor=1.1, device='cpu', optimizer='adam', tau=1.0, target_update_freq=1, vtrace_clip_rho=1.0, vtrace_clip_c=1.0, enable_compile=False, **kwargs)
- update_target_network()
Soft-update the target actor: target = tau * current + (1 - tau) * target. With the default tau=1.0 this is a hard copy of the current actor.
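The soft-update formula can be checked on plain numbers. This sketch models parameters as a dict of float lists and uses the hypothetical name `soft_update`. In PyTorch the same blend is typically applied per parameter tensor under `torch.no_grad()`, for example `target_param.lerp_(current_param, tau)`. That unilab calls this every `target_update_freq` learner steps is an assumption based on the parameter name.

```python
def soft_update(target, current, tau):
    """Hypothetical soft update: target = tau * current + (1 - tau) * target.

    Parameters are modeled as dicts mapping names to lists of floats; the
    target dict is updated in place and also returned for convenience.
    """
    for name, cur_values in current.items():
        target[name] = [tau * c + (1.0 - tau) * t
                        for c, t in zip(cur_values, target[name])]
    return target
```

With tau=0.5 the target moves halfway toward the current weights, while tau=1.0 (the default) copies them outright. Smaller tau values make the target policy, and therefore the IS ratios computed against it, change more slowly between updates.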
- sync_target_actor_buffers()
Copy the actor's buffers, such as observation-normalization statistics, to the target actor.
- get_weights()
Return the actor's state dict for syncing to workers.
Workers act with this behavior policy, which may be stale relative to the learner. The returned state dict includes the EmpiricalNormalization buffers.