unilab.algos.torch.hora.appo_learner.HoraAPPOLearner
- class unilab.algos.torch.hora.appo_learner.HoraAPPOLearner
Bases: APPOLearner
APPO learner variant for HORA grouped observations.
- Parameters:
  - actor (MLPModel)
  - critic (MLPModel)
  - num_learning_epochs (int)
  - num_mini_batches (int)
  - clip_param (float)
  - gamma (float)
  - lam (float)
  - value_loss_coef (float)
  - entropy_coef (float)
  - learning_rate (float)
  - max_grad_norm (float)
  - use_clipped_value_loss (bool)
  - schedule (str)
  - desired_kl (float)
  - adaptive_kl_factor (float)
  - adaptive_lr_factor (float)
  - device (str)
  - optimizer (str)
  - tau (float)
  - target_update_freq (int)
  - vtrace_clip_rho (float)
  - vtrace_clip_c (float)
  - enable_compile (bool)
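The schedule, desired_kl, adaptive_kl_factor and adaptive_lr_factor parameters suggest a KL-adaptive learning-rate schedule of the kind found in several PPO implementations. The following is a minimal sketch, assuming that with a non-'fixed' schedule the learner divides the learning rate by adaptive_lr_factor when the measured policy KL exceeds desired_kl * adaptive_kl_factor, and multiplies it by adaptive_lr_factor when the KL drops below desired_kl / adaptive_kl_factor. The function adapt_learning_rate and its min_lr/max_lr bounds are hypothetical and are not part of this class's documented API.

```python
def adapt_learning_rate(
    lr: float,
    kl: float,
    desired_kl: float = 0.01,
    adaptive_kl_factor: float = 1.2,
    adaptive_lr_factor: float = 1.1,
    min_lr: float = 1e-5,  # hypothetical lower bound
    max_lr: float = 1e-2,  # hypothetical upper bound
) -> float:
    """Hypothetical KL-adaptive learning-rate step (assumed behavior, not library code)."""
    if kl > desired_kl * adaptive_kl_factor:
        # Policy moved further than desired: shrink the step size.
        return max(min_lr, lr / adaptive_lr_factor)
    if 0.0 < kl < desired_kl / adaptive_kl_factor:
        # Policy barely moved: grow the step size (non-positive KL is ignored).
        return min(max_lr, lr * adaptive_lr_factor)
    # KL inside the tolerance band: keep the current learning rate.
    return lr


lr = 1e-3
lr = adapt_learning_rate(lr, kl=0.05)  # KL too large, so lr decreases
print(lr)
```

With the defaults, KL values between roughly 0.0083 and 0.012 leave the learning rate unchanged, which keeps the schedule from oscillating around desired_kl.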
Methods
- __init__(actor, critic[, ...])
- eval_mode(): Set actor/critic to eval mode.
- get_state_dict(): Return full learner state for checkpointing.
- get_weights(): Return actor state dict for syncing to workers.
- process_batch(batch_dict): Compute V-trace targets for grouped HORA rollouts.
- sync_target_actor_buffers(): Copy actor buffers such as observation-normalization stats to the target actor.
- train_mode(): Set actor/critic to training mode (enables EmpiricalNormalization.update).
- update(batch_dict): Perform the original main APPO update with additional detached metrics.
- update_target_network(): Soft-update the target actor: target = tau * current + (1 - tau) * target.
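process_batch is documented as computing V-trace targets. As a reference point, here is a minimal single-trajectory V-trace computation in the IMPALA formulation, in which vtrace_clip_rho and vtrace_clip_c are assumed to play the roles of the truncation thresholds for the importance weights rho_t and the trace coefficients c_t. The function name vtrace_targets, its argument layout and the done handling are assumptions; the learner's batched implementation over grouped HORA observations is not described in this reference.

```python
import math
from typing import List, Sequence, Tuple


def vtrace_targets(
    rewards: Sequence[float],
    values: Sequence[float],
    bootstrap_value: float,
    log_rhos: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 0.99,
    clip_rho: float = 1.0,
    clip_c: float = 1.0,
) -> Tuple[List[float], List[float]]:
    """Single-trajectory V-trace (IMPALA form). Returns (vs, pg_advantages).

    log_rhos[t] = log pi(a_t | x_t) - log mu(a_t | x_t), where pi is the learner
    policy and mu the (possibly stale) behavior policy that collected the data.
    """
    T = len(rewards)
    rhos = [math.exp(log_rho) for log_rho in log_rhos]
    clipped_rhos = [min(clip_rho, r) for r in rhos]  # truncated rho_t
    cs = [min(clip_c, r) for r in rhos]  # truncated trace coefficients c_t
    # A terminal step cuts both the bootstrap and the trace.
    discounts = [0.0 if d else gamma for d in dones]
    next_values = list(values[1:]) + [bootstrap_value]

    # Backward recursion: v_t - V(x_t) = delta_t + gamma_t * c_t * (v_{t+1} - V(x_{t+1})).
    corrections = [0.0] * T
    acc = 0.0
    for t in reversed(range(T)):
        delta = clipped_rhos[t] * (rewards[t] + discounts[t] * next_values[t] - values[t])
        acc = delta + discounts[t] * cs[t] * acc
        corrections[t] = acc
    vs = [v + corr for v, corr in zip(values, corrections)]

    # Policy-gradient advantage: rho_t * (r_t + gamma_t * v_{t+1} - V(x_t)).
    next_vs = vs[1:] + [bootstrap_value]
    pg_advantages = [
        clipped_rhos[t] * (rewards[t] + discounts[t] * next_vs[t] - values[t])
        for t in range(T)
    ]
    return vs, pg_advantages


vs, adv = vtrace_targets([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 0.0, [0.0, 0.0, 0.0], [False] * 3, gamma=0.5)
print(vs)  # [1.75, 1.5, 1.0]
```

When the data is on-policy (all log_rhos are zero) and both thresholds are at least 1, the truncation is inactive and vs reduces to the ordinary bootstrapped discounted return, which is what the usage line shows.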
- __init__(actor, critic, num_learning_epochs=5, num_mini_batches=4, clip_param=0.2, gamma=0.99, lam=0.95, value_loss_coef=1.0, entropy_coef=0.01, learning_rate=0.001, max_grad_norm=1.0, use_clipped_value_loss=True, schedule='fixed', desired_kl=0.01, adaptive_kl_factor=1.2, adaptive_lr_factor=1.1, device='cpu', optimizer='adam', tau=1.0, target_update_freq=1, vtrace_clip_rho=1.0, vtrace_clip_c=1.0, enable_compile=False, **kwargs)
- eval_mode()
Set actor/critic to eval mode.
- get_state_dict()
Return full learner state for checkpointing.
- get_weights()
Return the actor state dict for syncing to workers.
Workers act with this behavior policy, which may be stale relative to the learner's current parameters. The returned state dict also includes the EmpiricalNormalization buffers.
- sync_target_actor_buffers()
Copy actor buffers, such as observation-normalization statistics, to the target actor.
- train_mode()
Set actor/critic to training mode (enables EmpiricalNormalization.update).
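train_mode and eval_mode gate EmpiricalNormalization.update, so observation statistics only change while the learner is training, and get_weights and sync_target_actor_buffers carry those statistics along as buffers. The class below is a hypothetical scalar stand-in that mirrors this gating with a running mean and variance, merged batch by batch with the parallel-variance formula of Chan et al. It is not the library's EmpiricalNormalization, whose interface is not documented here; RunningNormalizer and its methods are illustrative names.

```python
import math
from typing import List, Sequence


class RunningNormalizer:
    """Hypothetical scalar normalizer whose statistics update only in training mode."""

    def __init__(self, eps: float = 1e-8) -> None:
        self.training = True
        self.count = 0
        self.mean = 0.0
        self.var = 1.0
        self.eps = eps

    def train(self, mode: bool = True) -> "RunningNormalizer":
        self.training = mode
        return self

    def eval(self) -> "RunningNormalizer":
        return self.train(False)

    def update(self, batch: Sequence[float]) -> None:
        # Merge batch statistics into the running ones (Chan et al. parallel update).
        n = len(batch)
        if n == 0:
            return
        b_mean = sum(batch) / n
        b_var = sum((x - b_mean) ** 2 for x in batch) / n
        total = self.count + n
        delta = b_mean - self.mean
        new_mean = self.mean + delta * n / total
        m2 = self.var * self.count + b_var * n + delta ** 2 * self.count * n / total
        self.mean, self.var, self.count = new_mean, m2 / total, total

    def __call__(self, batch: Sequence[float]) -> List[float]:
        # Statistics move only in training mode; eval mode just normalizes.
        if self.training:
            self.update(batch)
        std = math.sqrt(self.var + self.eps)
        return [(x - self.mean) / std for x in batch]


norm = RunningNormalizer()
norm([1.0, 2.0, 3.0])  # training mode: statistics are updated
norm.eval()
norm([10.0])  # eval mode: statistics stay frozen
print(norm.count, norm.mean)  # 3 2.0
```

Freezing the statistics in eval mode keeps evaluation rollouts from shifting the normalization that the trained policy depends on.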
- update(batch_dict)
Perform the original main APPO update and additionally compute detached metrics.
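The update is described only as the main APPO update. In common APPO and PPO implementations, the per-minibatch loss combines a clipped policy surrogate on the advantages, an optionally clipped value loss, and an entropy bonus, weighted by parameters like clip_param, value_loss_coef and entropy_coef. The following is a minimal sketch of that standard form for one minibatch. Whether HoraAPPOLearner uses exactly these terms (for example, which policy the ratio is taken against, or whether clip_param also bounds the value clip) is an assumption, and appo_loss is a hypothetical helper.

```python
import math
from typing import Sequence


def appo_loss(
    log_probs: Sequence[float],
    old_log_probs: Sequence[float],
    advantages: Sequence[float],
    values: Sequence[float],
    old_values: Sequence[float],
    value_targets: Sequence[float],
    entropies: Sequence[float],
    clip_param: float = 0.2,
    value_loss_coef: float = 1.0,
    entropy_coef: float = 0.01,
    use_clipped_value_loss: bool = True,
) -> float:
    """Hypothetical PPO-style minibatch loss (mean over samples)."""
    n = len(advantages)
    surrogate = 0.0
    value_loss = 0.0
    for i in range(n):
        ratio = math.exp(log_probs[i] - old_log_probs[i])
        clipped_ratio = min(max(ratio, 1.0 - clip_param), 1.0 + clip_param)
        # Pessimistic bound: the smaller of the unclipped and clipped objectives, negated.
        surrogate += -min(ratio * advantages[i], clipped_ratio * advantages[i])
        err = (values[i] - value_targets[i]) ** 2
        if use_clipped_value_loss:
            # Keep the new value within clip_param of the old one, then take the worse error.
            delta = min(max(values[i] - old_values[i], -clip_param), clip_param)
            err = max(err, (old_values[i] + delta - value_targets[i]) ** 2)
        value_loss += err
    mean_entropy = sum(entropies) / n
    return surrogate / n + value_loss_coef * value_loss / n - entropy_coef * mean_entropy


loss = appo_loss([0.5], [0.0], [1.0], [0.0], [0.0], [0.0], [0.0])
print(loss)
```

Taking the minimum of the unclipped and clipped surrogate means a sample gains nothing from pushing the ratio past 1 ± clip_param, which bounds how far one update can move the policy.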
- update_target_network()
Soft-update the target actor: target = tau * current + (1 - tau) * target. With the default tau=1.0, this copies the current actor weights into the target actor.
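The following is a minimal sketch of the documented soft update, with plain Python lists standing in for tensors. The gating by target_update_freq (update only on every target_update_freq-th call) is an assumption based on the parameter name, and soft_update and maybe_update_target are hypothetical helpers.

```python
from typing import Dict, List


def soft_update(
    target: Dict[str, List[float]],
    current: Dict[str, List[float]],
    tau: float = 1.0,
) -> None:
    """In place: target = tau * current + (1 - tau) * target, for every parameter."""
    for name, cur_values in current.items():
        tgt_values = target[name]
        for i, cur in enumerate(cur_values):
            tgt_values[i] = tau * cur + (1.0 - tau) * tgt_values[i]


def maybe_update_target(
    step: int,
    target: Dict[str, List[float]],
    current: Dict[str, List[float]],
    tau: float = 1.0,
    target_update_freq: int = 1,
) -> bool:
    """Hypothetical gating: soft-update only every `target_update_freq` steps."""
    if step % target_update_freq != 0:
        return False
    soft_update(target, current, tau)
    return True


target = {"w": [0.0, 0.0]}
current = {"w": [1.0, 2.0]}
maybe_update_target(step=1, target=target, current=current, tau=0.5, target_update_freq=2)  # skipped
maybe_update_target(step=2, target=target, current=current, tau=0.5, target_update_freq=2)
print(target)  # {'w': [0.5, 1.0]}
```

With tau=1.0 the update reduces to a hard copy. In PyTorch, the per-parameter step is commonly written in place as `p_target.mul_(1 - tau).add_(p, alpha=tau)` inside a `torch.no_grad()` block.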