"""APPO Runner — Asynchronous PPO with native multiprocessing.

Pipeline:
1. A collector subprocess publishes rollout payloads into a shared-memory
   RolloutRingBuffer.
2. The learner drains all available rollouts into a bounded staging pool and
   computes V-trace-corrected updates.
3. Updated actor and critic weights are synced back to the collector via
   SharedWeightSync.
"""
import multiprocessing as mp
import os
import sys
import time
from collections import deque
from copy import deepcopy
from typing import Any
import torch
from rsl_rl.utils import resolve_callable
from unilab.algos.torch.appo.learner import APPOLearner
from unilab.algos.torch.appo.staging import RolloutStagingPool
from unilab.algos.torch.appo.worker import appo_collector_fn
from unilab.ipc import AsyncRunner, RolloutRingBuffer, SharedWeightSync
from unilab.logging import OffPolicyLogger
from unilab.training.seed import apply_training_seed, derive_worker_seed
def _optimizer_lr_from_state(optimizer: torch.optim.Optimizer) -> float:
    """Return the learning rate stored in the optimizer's first param group.

    Falls back to ``0.0`` when the optimizer has no param groups, or when the
    first group carries no ``"lr"`` entry. The result is always a ``float``.
    """
    groups = optimizer.param_groups
    lr = groups[0].get("lr", 0.0) if groups else 0.0
    return float(lr)
def _sync_resume_target_actor(learner: APPOLearner) -> None:
    """Make the learner's target actor an exact, frozen copy of its actor.

    Intended to run right after a checkpoint restore, so the target actor
    starts from the restored actor weights. If the learner exposes no
    ``target_actor`` (missing attribute or ``None``), this defers to
    ``learner.update_target_network()`` instead.
    """
    target = getattr(learner, "target_actor", None)
    if target is not None:
        # Copy weights verbatim, then freeze: the target is never trained directly.
        target.load_state_dict(learner.actor.state_dict())
        target.eval()
        for weight in target.parameters():
            weight.requires_grad = False
        return
    learner.update_target_network()
class APPORunner(AsyncRunner):
    """APPO async runner using shared memory.

    A single collector subprocess writes rollouts into a shared-memory
    ``RolloutRingBuffer``. The learner (this process) stages them into a
    bounded ``RolloutStagingPool``, runs ``APPOLearner`` updates, and publishes
    actor/critic weights back to the collector through ``SharedWeightSync``.
    """

    def __init__(
        self,
        env_name: str,
        env_cfg_overrides: dict,
        rl_cfg: dict,
        device: str | None = None,
        collector_device: str | None = None,
        sim_backend: str = "mujoco",
        num_envs: int = 1024,
        steps_per_env: int = 24,
        num_workers: int = 1,  # kept for API compat, but only 1 collector used
        replay_queue_size: int = 3,
        seed: int | None = None,
        resume_path: str | None = None,
    ):
        """Configure the runner and detect observation/action dimensions.

        Args:
            env_name: Registry name of the environment.
            env_cfg_overrides: Env config overrides; forwarded only if truthy.
            rl_cfg: RL config. ``rl_cfg["obs_groups"]`` is updated in place with
                the detected observation dimensions.
            device: Learner device.
            collector_device: Device used by the collector subprocess.
            sim_backend: Simulation backend name.
            num_envs: Number of parallel environments in the collector.
            steps_per_env: Rollout length per environment per published slot.
            num_workers: Ignored; accepted for API compatibility.
            replay_queue_size: Capacity of the learner-side staging pool.
            seed: Training seed (also used to derive the collector seed).
            resume_path: Optional checkpoint to restore in :meth:`learn`.

        Raises:
            ValueError: if ``replay_queue_size`` is smaller than 1.
        """
        del num_workers
        # Validate before base-class setup so a bad config fails fast.
        if replay_queue_size < 1:
            raise ValueError("APPO staging pool size must be >= 1")
        super().__init__(
            env_name=env_name,
            env_cfg_overrides=env_cfg_overrides,
            rl_cfg=rl_cfg,
            device=device,
            collector_device=collector_device,
            sim_backend=sim_backend,
            num_envs=num_envs,
        )
        self.steps_per_env = steps_per_env
        self.replay_queue_size = replay_queue_size
        self.staging_pool_size = replay_queue_size
        self.seed = seed
        self.resume_path = resume_path
        # Needs self.seed: dimension detection seeds the runtime before building a probe env.
        self._resolve_dims()

    def _get_default_device(self) -> str:
        """Prefer CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _resolve_dims(self):
        """Detect env dims and write them into ``self.rl_cfg["obs_groups"]``.

        Sets ``self.obs_dim`` / ``self.action_dim`` (plus ``critic_dim`` and
        ``critic_input_dim`` via :meth:`_detect_dims`) and mutates
        ``self.rl_cfg`` in place so the RSL-RL networks built later see the
        correct observation sizes.
        """
        self.obs_dim, self.action_dim = self._detect_dims()
        if "obs_groups" not in self.rl_cfg:
            self.rl_cfg["obs_groups"] = {
                "actor": {"policy": self.obs_dim},
                "critic": {"policy": self.critic_input_dim},
            }
            return
        obs_groups = self.rl_cfg["obs_groups"]
        actor_group = obs_groups.get("actor", obs_groups.get("policy", {}))
        if isinstance(actor_group, dict) and "policy" in actor_group:
            actor_group["policy"] = self.obs_dim
        critic_group = obs_groups.get("critic")
        if critic_group is None:
            obs_groups["critic"] = {"policy": self.critic_input_dim}
        elif isinstance(critic_group, dict) and "policy" in critic_group:
            critic_group["policy"] = self.critic_input_dim

    def _detect_dims(self):
        """Create a tiny env to read obs/action dims, then close it.

        Also sets ``self.critic_dim`` and ``self.critic_input_dim``.

        Returns:
            ``(obs_dim, action_dim)``.

        Raises:
            ValueError: if the env's action space has no shape.
        """
        from unilab.base import registry
        from unilab.base.observations import get_critic_base_dim, get_obs_dims
        from unilab.base.registry import ensure_registries

        ensure_registries()
        apply_training_seed(self.seed, torch_runtime=True, cuda=True)
        env = registry.make(
            self.env_name,
            num_envs=1,
            sim_backend=self.sim_backend,
            env_cfg_override=self.env_cfg_overrides if self.env_cfg_overrides else None,
        )
        # Always release the probe env, even if dimension extraction fails.
        try:
            obs_dim, critic_dim = get_obs_dims(env.obs_groups_spec)
            self.critic_dim = critic_dim
            self.critic_input_dim = get_critic_base_dim(env.obs_groups_spec)
            action_shape = env.action_space.shape
            if action_shape is None:
                raise ValueError(
                    f"Environment {self.env_name!r} has an action space without a shape"
                )
            action_dim = action_shape[0]
        finally:
            env.close()
        return obs_dim, action_dim

    def _build_learner(self):
        """Construct the :class:`APPOLearner` with actor/critic from ``rl_cfg``.

        The actor class comes from ``rl_cfg["actor"]["class_name"]``. The
        critic config is ``rl_cfg["critic"]``, falling back to the actor config,
        with class defaulting to ``rsl_rl.models.MLPModel`` and any
        ``distribution_cfg`` removed (the critic is deterministic).
        Hyperparameters are read from ``rl_cfg["algorithm"]`` if present,
        otherwise from the top level of ``rl_cfg``.
        """
        from tensordict import TensorDict

        cfg = dict(self.rl_cfg)
        # Re-seed so network initialisation is reproducible for a given seed.
        apply_training_seed(self.seed, torch_runtime=True, cuda=True)
        obs_example = torch.zeros((self.num_envs, self.obs_dim), device=self.device)
        td_example = TensorDict({"policy": obs_example}, batch_size=self.num_envs)
        critic_obs_example = torch.zeros(
            (self.num_envs, self.critic_input_dim), device=self.device
        )
        critic_td_example = TensorDict({"policy": critic_obs_example}, batch_size=self.num_envs)

        # Build actor (stochastic MLPModel — distribution_cfg carries GaussianDistribution).
        # deepcopy so MLPModel.__init__'s distribution_cfg.pop("class_name") doesn't
        # mutate the shared rl_cfg that gets sent to the collector subprocess.
        actor_cfg = deepcopy(cfg.get("actor", {}))
        actor_cls = resolve_callable(actor_cfg.pop("class_name"))
        actor_cfg.pop("num_actions", None)
        actor = actor_cls(td_example, cfg["obs_groups"], "actor", self.action_dim, **actor_cfg)

        # Build critic (deterministic MLPModel, no distribution).
        critic_cfg: dict[str, Any] = deepcopy(cfg.get("critic") or cfg.get("actor") or {})
        critic_cls = resolve_callable(critic_cfg.pop("class_name", "rsl_rl.models.MLPModel"))
        critic_cfg.pop("num_actions", None)
        critic_cfg.pop("distribution_cfg", None)  # critic is deterministic
        critic = critic_cls(critic_td_example, cfg["obs_groups"], "critic", 1, **critic_cfg)

        # Extract algorithm hyperparams from rl_cfg["algorithm"] (or top-level).
        algo_cfg = cfg.get("algorithm", cfg)
        return APPOLearner(
            actor=actor,
            critic=critic,
            device=self.device,
            num_learning_epochs=algo_cfg.get("num_learning_epochs", 5),
            num_mini_batches=algo_cfg.get("num_mini_batches", 4),
            clip_param=algo_cfg.get("clip_param", 0.2),
            gamma=algo_cfg.get("gamma", 0.99),
            lam=algo_cfg.get("lam", 0.95),
            value_loss_coef=algo_cfg.get("value_loss_coef", 1.0),
            entropy_coef=algo_cfg.get("entropy_coef", 0.01),
            learning_rate=algo_cfg.get("learning_rate", 1e-3),
            max_grad_norm=algo_cfg.get("max_grad_norm", 1.0),
            use_clipped_value_loss=algo_cfg.get("use_clipped_value_loss", True),
            schedule=algo_cfg.get("schedule", "fixed"),
            desired_kl=algo_cfg.get("desired_kl", 0.01),
            adaptive_kl_factor=algo_cfg.get("adaptive_kl_factor", 1.2),
            adaptive_lr_factor=algo_cfg.get("adaptive_lr_factor", 1.1),
            optimizer=algo_cfg.get("optimizer", "adam"),
            tau=algo_cfg.get("tau", 1.0),
            target_update_freq=algo_cfg.get("target_update_freq", 1),
            vtrace_clip_rho=algo_cfg.get("vtrace_clip_rho", 1.0),
            vtrace_clip_c=algo_cfg.get("vtrace_clip_c", 1.0),
            enable_compile=algo_cfg.get("enable_compile", True),
        )

    def _collector_fn(self, stop_event, **kwargs):
        """Collector entry point: forwards to :func:`appo_collector_fn`."""
        appo_collector_fn(stop_event=stop_event, **kwargs)

    def learn(
        self,
        max_iterations: int = 1500,
        save_interval: int = 50,
        log_dir: str = "logs",
        logger_type: str = "tensorboard",
    ) -> None:
        """Run the asynchronous APPO training loop.

        Builds the learner (restoring ``self.resume_path`` if set), creates the
        shared rollout ring buffer and weight-sync segments, and starts the
        collector subprocess. Each iteration then does the following:

        1. Waits up to 60 s for rollouts.
        2. Stages every available slot.
        3. Runs one learner update on the staged batch.
        4. Publishes the new actor/critic weights to the collector.

        A summary dict is stored on ``self.last_run_summary`` when training ends.

        Args:
            max_iterations: Number of loop iterations (timed-out waits count).
            save_interval: Save ``model_<iter>.pt`` every this many iterations;
                ``<= 0`` disables periodic saves.
            log_dir: Directory for checkpoints and logs (created if missing).
            logger_type: Backend name passed to ``OffPolicyLogger``.

        Raises:
            RuntimeError: if the collector dies before producing data, or
                reports an error over the metrics queue.
        """
        os.makedirs(log_dir, exist_ok=True)
        train_start_wall = time.time()
        best_mean_reward = float("-inf")
        last_mean_reward = 0.0
        ckpt_path: str | None = None
        iteration = 0
        # Shared by the memory estimate and the ring buffer so they cannot drift apart.
        num_ring_slots = 4

        learner = self._build_learner()
        if self.resume_path:
            checkpoint = torch.load(self.resume_path, map_location=self.device, weights_only=True)
            learner.actor.load_state_dict(checkpoint["actor"])
            learner.critic.load_state_dict(checkpoint["critic"])
            if "optimizer" in checkpoint:
                learner.optimizer.load_state_dict(checkpoint["optimizer"])
                # Keep the learner's LR bookkeeping in line with the restored optimizer.
                learner.learning_rate = _optimizer_lr_from_state(learner.optimizer)
            _sync_resume_target_actor(learner)

        # --- memory budget check ---
        from unilab.ipc.memory_budget import estimate_appo_bytes, warn_if_over_budget

        mem_est = estimate_appo_bytes(
            num_envs=self.num_envs,
            steps_per_env=self.steps_per_env,
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
            critic_dim=self.critic_dim,
            num_slots=num_ring_slots,
        )
        warn_if_over_budget(mem_est, label="APPO")

        # Create shared rollout IPC ring buffer; learner-side tensor lifetime is
        # owned by the bounded staging pool below.
        rollout_ring_buffer = RolloutRingBuffer(
            num_envs=self.num_envs,
            num_steps=self.steps_per_env,
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
            critic_dim=self.critic_dim,
            num_slots=num_ring_slots,
            create=True,
        )
        self._shared_resources.append(rollout_ring_buffer)

        # Create weight sync for collector-side actor and critic bootstrap values.
        actor_state = learner.actor.state_dict()
        critic_state = learner.critic.state_dict()
        actor_weight_sync = SharedWeightSync.from_state_dict(actor_state, create=True)
        critic_weight_sync = SharedWeightSync.from_state_dict(critic_state, create=True)
        self._shared_resources.extend([actor_weight_sync, critic_weight_sync])
        actor_weight_param_shapes = {name: p.shape for name, p in actor_state.items()}
        critic_weight_param_shapes = {name: p.shape for name, p in critic_state.items()}

        metrics_queue: mp.Queue = mp.get_context("spawn").Queue(maxsize=100)

        # Start collector
        collector_kwargs = {
            "env_name": self.env_name,
            "rl_cfg": self.rl_cfg,
            "num_envs": self.num_envs,
            "steps_per_env": self.steps_per_env,
            "shm_rollout_ring_buffer_name": rollout_ring_buffer.name,
            "sync_primitives": (
                rollout_ring_buffer._write_ptr,
                rollout_ring_buffer._read_ptr,
            ),
            "obs_dim": self.obs_dim,
            "action_dim": self.action_dim,
            "critic_dim": self.critic_dim,
            "actor_weight_sync_name": actor_weight_sync.name,
            "actor_weight_param_shapes": actor_weight_param_shapes,
            "critic_weight_sync_name": critic_weight_sync.name,
            "critic_weight_param_shapes": critic_weight_param_shapes,
            "metrics_queue": metrics_queue,
            "collector_device": self.collector_device,
            "sim_backend": self.sim_backend,
            "env_cfg_override": self.env_cfg_overrides if self.env_cfg_overrides else None,
            "seed": derive_worker_seed(self.seed, worker_index=0),
        }
        self._start_collector(
            target_fn=appo_collector_fn,
            kwargs={"stop_event": self._stop_event, **collector_kwargs},
        )

        env_steps_per_sync = self.steps_per_env * self.num_envs
        logger = OffPolicyLogger(
            algo_name="APPO",
            max_iterations=max_iterations,
            num_envs=self.num_envs,
            env_name=self.env_name,
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
            log_dir=log_dir,
            log_backend=logger_type,
        )
        logger.set_collection_sync(True, env_steps_per_sync)
        logger.log_status(
            f"Waiting for first rollout... "
            f"(staging_pool={self.staging_pool_size}, "
            f"epochs={learner.num_learning_epochs})"
        )
        logger_started = False
        reward_history: deque = deque(maxlen=200)
        latest_reward_components: dict = {}
        staging_pool = RolloutStagingPool(
            capacity=self.staging_pool_size,
            num_envs=self.num_envs,
            slot_shapes=rollout_ring_buffer.slot_shapes,
            device=self.device,
        )

        for iteration in range(1, max_iterations + 1):
            # Drain collector metrics while waiting for next rollout.
            self._drain_metrics(metrics_queue, reward_history, latest_reward_components, logger)
            wait_start = time.time()
            data_ready = rollout_ring_buffer.wait_for_data(timeout=60.0)
            # Measure only the blocking wait, not logger start-up below.
            wait_time = time.time() - wait_start
            if not data_ready:
                # Check if the collector subprocess died — fail fast instead of
                # burning through remaining iterations with 60s timeouts each.
                if not self._check_collector_alive():
                    # Surfaces the collector's own error message first, if it sent one.
                    self._drain_metrics(
                        metrics_queue, reward_history, latest_reward_components, logger
                    )
                    raise RuntimeError(
                        "APPO collector process died before producing data. "
                        "Check stderr for [APPO WORKER CRASH] messages."
                    )
                logger.log_status(
                    f"[yellow]Warning: Timeout waiting for data at iteration {iteration}[/]"
                )
                continue

            if not logger_started:
                logger.start(status="Training")
                logger_started = True
            available_on_arrive = rollout_ring_buffer.available()

            # Drain ALL available slots into the staging pool in one pass.
            # This keeps the GPU busy: if the collector produced 3 rollouts while
            # the learner was training, we consume all 3 immediately rather than
            # processing them one-per-iteration.
            num_new = rollout_ring_buffer.available()
            learner_incremental_h2d_time = 0.0
            for _ in range(num_new):
                h2d_start = time.perf_counter()
                staging_pool.stage_numpy_views(rollout_ring_buffer.read_numpy_views())
                learner_incremental_h2d_time += time.perf_counter() - h2d_start
                rollout_ring_buffer.advance_read()
            self._drain_metrics(metrics_queue, reward_history, latest_reward_components, logger)

            combined = staging_pool.batch()
            train_start = time.time()
            learner.process_batch(combined)
            metrics = learner.update(combined)
            train_time = time.time() - train_start

            weight_sync_start = time.perf_counter()
            actor_weight_sync.write_weights(learner.actor.state_dict())
            critic_weight_sync.write_weights(learner.critic.state_dict())
            weight_sync_time = time.perf_counter() - weight_sync_start

            metrics["staging_pool_len"] = float(staging_pool.active_count)
            metrics["staging_pool_capacity"] = float(staging_pool.capacity)
            metrics["available_on_arrive"] = float(available_on_arrive)
            metrics["rollouts_read"] = float(num_new)
            logger.update_staging_pool(staging_pool.active_count, staging_pool.capacity)

            # Mean over the 50 most recent episode rewards reported by the collector.
            recent_rewards = list(reward_history)[-50:]
            mean_reward = sum(recent_rewards) / len(recent_rewards) if recent_rewards else 0.0
            last_mean_reward = float(mean_reward)
            best_mean_reward = max(best_mean_reward, last_mean_reward)

            logger.log_step(
                iteration=iteration,
                metrics=metrics,
                reward=mean_reward,
                reward_components=latest_reward_components,
                train_time=train_time,
                wait_time=wait_time,
                learner_incremental_h2d_time=learner_incremental_h2d_time,
                weight_sync_time=weight_sync_time,
                extra_info={
                    "throughput_steps": num_new * env_steps_per_sync,
                },
            )

            if save_interval > 0 and iteration % save_interval == 0:
                ckpt_path = os.path.join(log_dir, f"model_{iteration}.pt")
                torch.save(learner.get_state_dict(), ckpt_path)
                logger.log_save(ckpt_path)

        final_ckpt_path = os.path.join(log_dir, f"model_{max_iterations}.pt")
        # The weights cannot change after the loop, so skip the save if the last
        # periodic checkpoint already wrote this exact file.
        if ckpt_path != final_ckpt_path:
            torch.save(learner.get_state_dict(), final_ckpt_path)
            logger.log_save(final_ckpt_path)
        ckpt_path = final_ckpt_path
        logger.finish()

        summary = {
            "status": "completed",
            "completed_iterations": iteration,
            "total_env_steps": int(logger._total_steps),
            "final_mean_reward": last_mean_reward if reward_history else None,
            "best_mean_reward": best_mean_reward if reward_history else None,
            "mean_episode_length": float(logger._mean_ep_length),
            "last_checkpoint": ckpt_path,
            "training_wall_time_sec": time.time() - train_start_wall,
        }
        self.last_run_summary = summary

    # _check_collector_alive() inherited from AsyncRunner base class

    @staticmethod
    def _drain_metrics(queue, reward_history, reward_components, logger):
        """Drain all pending messages from the collector metrics queue.

        Mirrors OffPolicyRunner._drain_metrics so APPO has the same
        logger update coverage (ep_length, done rates, collector timing).

        Args:
            queue: Metrics queue fed by the collector subprocess; read with
                ``get_nowait`` until it reports empty.
            reward_history: Receives each message's ``mean_ep_reward``.
            reward_components: Dict replaced in place by the latest
                ``reward_components`` payload.
            logger: ``OffPolicyLogger`` receiving the remaining fields.

        Raises:
            RuntimeError: if a message carries an ``"error"`` key. This is
                raised outside the best-effort guard so it always propagates.
        """
        from queue import Empty

        while True:
            try:
                m = queue.get_nowait()
            except Empty:
                break
            except Exception as e:  # any other read failure: stay best-effort
                print(f"[APPORunner] metrics drain error: {e}", file=sys.stderr)
                break

            # A collector failure must reach the training loop, so it is
            # handled before (and outside) the best-effort try below.
            if isinstance(m, dict) and "error" in m:
                logger.log_status(f"[red]Collector ERROR: {m['error']}[/]")
                raise RuntimeError(f"Collector process failed: {m['error']}")

            try:
                if "mean_ep_reward" in m:
                    reward_history.append(m["mean_ep_reward"])
                if "reward_components" in m:
                    reward_components.clear()
                    reward_components.update(m["reward_components"])
                if "mean_ep_length" in m:
                    logger.update_ep_length(m["mean_ep_length"])
                if "collector_timing_ms" in m:
                    logger.update_collector_timing(m["collector_timing_ms"])
                if "timeout_rate" in m or "terminated_rate" in m:
                    logger.update_done_rates(
                        timeout_rate=float(m.get("timeout_rate", 0.0)),
                        terminated_rate=float(m.get("terminated_rate", 0.0)),
                    )
                if "total_steps" in m:
                    logger.log_collector(
                        m["total_steps"],
                        0,  # APPO uses shared memory, not a separate buffer
                        m.get("mean_ep_reward", 0.0),
                    )
            except Exception as e:
                print(f"[APPORunner] metrics drain error: {e}", file=sys.stderr)
                break