Source code for unilab.algos.mlx.ppo.ppo

"""PPO trainer implemented with MLX."""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, Dict, cast

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_map

from unilab.algos.mlx.common import RolloutBuffer, diag_gaussian_entropy, diag_gaussian_log_prob

from .model import MLPActorCritic


@dataclass
class PPOConfig:
    """Hyperparameters and runtime switches consumed by `PPOTrainer`.

    Field order is part of the constructor signature (positional arguments),
    so new fields should only ever be appended.
    """

    # --- mini-batch schedule (forwarded to the buffer's mini_batch_generator) ---
    num_learning_epochs: int = 4
    num_mini_batches: int = 4

    # --- PPO objective ---
    # Half-width of the ratio clip; also reused as the value-clip range.
    clip_param: float = 0.2
    # NOTE(review): gamma/lam are not read by the trainer in this file;
    # presumably used where returns/advantages are computed -- confirm.
    gamma: float = 0.99
    lam: float = 0.95
    value_loss_coef: float = 0.5
    entropy_coef: float = 0.0
    learning_rate: float = 3e-4
    use_clipped_value_loss: bool = True
    # Global gradient-norm ceiling; values <= 0 disable clipping.
    max_grad_norm: float = 1.0
    # Symmetric clamp on the log-probability ratio before exponentiation.
    log_ratio_clip: float = 20.0

    # --- learning-rate schedule ---
    # The trainer only adapts the learning rate when this equals "adaptive".
    schedule: str = "fixed"
    desired_kl: float = 0.01
    min_learning_rate: float = 1e-5
    max_learning_rate: float = 1e-2
    normalize_advantage_per_mini_batch: bool = False
    # EMA coefficient for the KL estimate (clamped to [0, 0.999] by the trainer).
    adaptive_kl_beta: float = 0.9
    # Divisor / multiplier applied to the learning rate (floored at 1.01 / 1.0).
    adaptive_lr_decay: float = 1.5
    adaptive_lr_growth: float = 1.2
    adaptive_lr_update_interval: int = 1
    # When set, the update loop stops once the reported KL exceeds this value.
    target_kl_stop: float | None = None

    # --- diagnostics cadence (in mini-batches; values below 1 are treated as 1) ---
    metrics_interval: int = 8
    finite_check_interval: int = 8
    enable_compile: bool = False

    # --- warm-up: iterations [0, warmup_strict_iters) use the intervals below ---
    warmup_strict_iters: int = 0
    warmup_metrics_interval: int = 1
    warmup_finite_check_interval: int = 1

    # Turns off every loss/gradient/parameter finite check in the update loop.
    disable_finite_checks: bool = False
class PPOTrainer:
    """PPO update logic for `MLPActorCritic` and `RolloutBuffer`.

    The trainer owns an Adam optimizer and a (optionally compiled)
    value-and-grad function for the clipped PPO objective.  `update` walks
    the mini-batches produced by the buffer, applies one optimizer step per
    mini-batch, and returns averaged diagnostics plus skip/rollback counters.
    """

    # Diagnostics produced by `_metrics` and averaged by `update`, in output order.
    _METRIC_KEYS = (
        "surrogate",
        "value",
        "entropy",
        "approx_kl",
        "clip_fraction",
        "ratio_mean",
        "ratio_max",
        "std_mean",
        "adv_std",
        "value_explained_variance",
    )

    def __init__(self, model: MLPActorCritic, cfg: PPOConfig) -> None:
        """Bind the model and config, and build the optimizer and grad function.

        Args:
            model: Actor-critic network to optimise. Its ``dtype`` attribute,
                when present, becomes the compute dtype (default ``mx.float32``).
            cfg: Hyperparameters; ``cfg.learning_rate`` seeds the mutable
                ``self.learning_rate`` that the adaptive schedule adjusts.
        """
        self.model = model
        self.cfg = cfg
        self._dtype = getattr(model, "dtype", mx.float32)
        self.learning_rate = float(cfg.learning_rate)
        self.optimizer = optim.Adam(learning_rate=self.learning_rate)
        self.loss_and_grad = nn.value_and_grad(model, self._loss_fn)
        self.compiled_loss_and_grad = self.loss_and_grad
        if self.cfg.enable_compile and hasattr(mx, "compile"):
            try:
                self.compiled_loss_and_grad = mx.compile(self.loss_and_grad)
            except Exception:
                # Compilation is an optimisation only; keep the eager path.
                self.compiled_loss_and_grad = self.loss_and_grad
        # Exponential moving average of the KL estimate (adaptive schedule only).
        self._kl_ema: float | None = None

    @staticmethod
    def _tree_leaves(tree: Any) -> list[Any]:
        """Return the leaves of a nested parameter/gradient tree as a flat list."""
        flat_tree = cast(list[tuple[str, Any]], tree_flatten(tree))
        return [leaf for _, leaf in flat_tree]

    @staticmethod
    def _all_finite(tree) -> bool:
        """Return True when every leaf of ``tree`` contains only finite values.

        An empty tree counts as finite.  All per-leaf checks are evaluated in
        one ``mx.eval`` call before being read back.
        """
        leaves = PPOTrainer._tree_leaves(tree)
        if not leaves:
            return True
        checks = [mx.all(mx.isfinite(leaf)) for leaf in leaves]  # type: ignore[arg-type]
        mx.eval(*checks)
        return all(bool(c.item()) for c in checks)

    def _clip_grads(self, grads):
        """Global gradient clipping similar to rsl-rl max_grad_norm.

        Scales every gradient by ``min(1, max_grad_norm / ||g||)`` where the
        norm is taken over all leaves.  Returns ``grads`` unchanged when
        clipping is disabled (``max_grad_norm <= 0``) or the tree is empty.
        """
        if self.cfg.max_grad_norm <= 0.0:
            return grads
        leaves = self._tree_leaves(grads)
        if not leaves:
            return grads
        sq_norm = mx.array(0.0, dtype=self._dtype)
        for leaf in leaves:
            sq_norm = sq_norm + mx.sum(leaf * leaf)
        # Small epsilons keep the sqrt and the division well-defined at zero.
        global_norm = mx.sqrt(sq_norm + 1e-12)
        clip_coef = mx.minimum(1.0, self.cfg.max_grad_norm / (global_norm + 1e-6))
        return tree_map(lambda g: g * clip_coef, grads)

    def _loss_terms(self, model: MLPActorCritic, batch: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Compute the PPO loss components shared by `_loss_fn` and `_metrics`.

        Returns a dict with ``policy_loss``, ``value_loss`` and ``entropy``
        (scalars) plus the intermediates ``ratio``, ``values``, ``advantages``
        (after optional per-mini-batch normalisation), ``mean`` and ``sigma``.
        """
        cfg = self.cfg
        obs = batch["obs"]
        returns = batch["returns"]
        old_values = batch["old_values"]
        advantages = batch["advantages"]
        if cfg.normalize_advantage_per_mini_batch:
            advantages = (advantages - mx.mean(advantages)) / (mx.std(advantages) + 1e-8)

        mean, sigma, log_std = model.distribution_params(obs)
        values = model.value(obs)
        log_probs = diag_gaussian_log_prob(batch["actions"], mean, log_std)
        entropy = mx.mean(diag_gaussian_entropy(log_std))

        # Clamp the log-ratio before exponentiating so exp() cannot overflow.
        log_ratio = mx.clip(
            log_probs - batch["old_log_probs"], -cfg.log_ratio_clip, cfg.log_ratio_clip
        )
        ratio = mx.exp(log_ratio)
        surr1 = ratio * advantages
        surr2 = mx.clip(ratio, 1.0 - cfg.clip_param, 1.0 + cfg.clip_param) * advantages
        policy_loss = -mx.mean(mx.minimum(surr1, surr2))

        if cfg.use_clipped_value_loss:
            value_pred_clipped = old_values + mx.clip(
                values - old_values, -cfg.clip_param, cfg.clip_param
            )
            value_losses = (values - returns) ** 2
            value_losses_clipped = (value_pred_clipped - returns) ** 2
            value_loss = mx.mean(mx.maximum(value_losses, value_losses_clipped))
        else:
            value_loss = mx.mean((returns - values) ** 2)

        return {
            "policy_loss": policy_loss,
            "value_loss": value_loss,
            "entropy": entropy,
            "ratio": ratio,
            "values": values,
            "advantages": advantages,
            "mean": mean,
            "sigma": sigma,
        }

    def _loss_fn(self, model: MLPActorCritic, batch: Dict[str, mx.array]) -> mx.array:
        """Scalar PPO objective: surrogate + value_loss_coef * value - entropy_coef * entropy."""
        terms = self._loss_terms(model, batch)
        return (
            terms["policy_loss"]
            + self.cfg.value_loss_coef * terms["value_loss"]
            - self.cfg.entropy_coef * terms["entropy"]
        )

    def _metrics(self, batch: Dict[str, mx.array]) -> Dict[str, float]:
        """Evaluate diagnostics for ``batch`` with the current model weights.

        Besides the keys used by the loss, ``batch`` must carry ``old_mu`` and
        ``old_sigma`` for the analytic KL.  Returns Python floats keyed by
        `_METRIC_KEYS`; values may be non-finite (the caller checks).
        """
        terms = self._loss_terms(self.model, batch)
        returns = batch["returns"]
        old_mu = batch["old_mu"]
        old_sigma = batch["old_sigma"]
        ratio = terms["ratio"]
        values = terms["values"]
        mean = terms["mean"]
        # Floor sigma so the KL below never divides by (near) zero.
        sigma = mx.maximum(terms["sigma"], 1e-5)

        clip_fraction = mx.mean(
            (mx.abs(ratio - 1.0) > self.cfg.clip_param).astype(self._dtype)
        )
        returns_var = mx.var(returns)
        explained_variance = 1.0 - mx.var(returns - values) / (returns_var + 1e-8)

        # Match rsl-rl style analytic KL for adaptive LR.
        kl = mx.sum(
            mx.log(sigma / (old_sigma + 1e-5) + 1e-5)
            + (old_sigma**2 + (old_mu - mean) ** 2) / (2.0 * sigma**2)
            - 0.5,
            axis=-1,
        )

        arrays = {
            "surrogate": terms["policy_loss"],
            "value": terms["value_loss"],
            "entropy": terms["entropy"],
            "approx_kl": mx.mean(kl),
            "clip_fraction": clip_fraction,
            "ratio_mean": mx.mean(ratio),
            "ratio_max": mx.max(ratio),
            "std_mean": mx.mean(sigma),
            "adv_std": mx.std(terms["advantages"]),
            "value_explained_variance": explained_variance,
        }
        # Evaluate everything in one go before reading scalars back.
        mx.eval(*arrays.values())
        return {key: float(arr.item()) for key, arr in arrays.items()}

    def _compute_loss_and_grads(self, batch: Dict[str, mx.array]):
        """Run the (possibly compiled) value-and-grad function on ``batch``.

        If the compiled callable raises, permanently fall back to the eager
        one and retry once.  If the eager callable itself raised, re-raise
        rather than re-running the same failing computation.
        """
        try:
            return self.compiled_loss_and_grad(self.model, batch)
        except Exception:
            if self.compiled_loss_and_grad is self.loss_and_grad:
                raise
            # Fallback: some MLX versions do not support compiling this closure shape.
            self.compiled_loss_and_grad = self.loss_and_grad
            return self.loss_and_grad(self.model, batch)

    def _snapshot_state(self):
        """Capture model parameters and optimizer state for a possible rollback.

        Only the container structure is copied; array leaves are shared, which
        is sufficient because the optimizer step replaces leaves rather than
        writing into them.
        """
        params = tree_map(lambda leaf: leaf, self.model.parameters())
        opt_state = tree_map(lambda leaf: leaf, self.optimizer.state)
        return params, opt_state

    def _restore_state(self, snapshot) -> None:
        """Put back parameters and optimizer state captured by `_snapshot_state`."""
        params, opt_state = snapshot
        self.model.update(params)
        self.optimizer.state = opt_state

    def _update_kl_ema(self, kl: float) -> float:
        """Fold ``kl`` into the running KL average and return the new average.

        The first observation initialises the average; afterwards
        ``adaptive_kl_beta`` (clamped to [0, 0.999]) weights the old value.
        """
        if self._kl_ema is None:
            self._kl_ema = kl
        else:
            beta = min(max(self.cfg.adaptive_kl_beta, 0.0), 0.999)
            self._kl_ema = beta * self._kl_ema + (1.0 - beta) * kl
        return self._kl_ema

    def _next_learning_rate(self, kl_for_lr: float) -> float:
        """Return the learning rate the adaptive schedule picks for ``kl_for_lr``.

        KL above ``2 * desired_kl`` shrinks the rate (bounded below by
        ``min_learning_rate``); KL strictly between 0 and ``desired_kl / 2``
        grows it (bounded above by ``max_learning_rate``); otherwise the
        current rate is returned unchanged.  Does not mutate the trainer.
        """
        cfg = self.cfg
        if kl_for_lr > cfg.desired_kl * 2.0:
            return max(
                cfg.min_learning_rate,
                self.learning_rate / max(cfg.adaptive_lr_decay, 1.01),
            )
        if 0.0 < kl_for_lr < cfg.desired_kl / 2.0:
            return min(
                cfg.max_learning_rate,
                self.learning_rate * max(cfg.adaptive_lr_growth, 1.0),
            )
        return self.learning_rate

    @staticmethod
    def _summarize(
        agg: Dict[str, float],
        updates: int,
        learning_rate: float,
        counters: Dict[str, int],
    ) -> Dict[str, float]:
        """Build the dict returned by `update`.

        Metric sums in ``agg`` are divided by ``updates`` (left as-is when no
        update was counted), then the learning rate, ``updates_applied`` and
        the counters are appended as floats.
        """
        if updates:
            out = {key: total / updates for key, total in agg.items()}
        else:
            out = dict(agg)
        out["learning_rate"] = learning_rate
        out["updates_applied"] = float(updates)
        for name, count in counters.items():
            out[name] = float(count)
        return out

    def update(self, buffer: RolloutBuffer, iteration: int = -1) -> Dict[str, float]:
        """Run the PPO epochs over ``buffer`` and return averaged diagnostics.

        Args:
            buffer: Rollout storage; its ``mini_batch_generator`` is called
                with ``cfg.num_mini_batches`` and ``cfg.num_learning_epochs``.
            iteration: Training iteration index. While
                ``0 <= iteration < cfg.warmup_strict_iters`` the warm-up
                metrics/finite-check intervals are used. Negative disables
                warm-up handling.

        Returns:
            Mean of each `_METRIC_KEYS` entry over counted mini-batches, plus
            ``learning_rate``, ``updates_applied`` and the counters
            ``skipped_nonfinite_loss``, ``skipped_nonfinite_grads``,
            ``rolled_back_updates``, ``skipped_nonfinite_metrics`` and
            ``early_stopped_kl``.

        Notes:
            * On mini-batches where finite checks run, an optimizer step that
              leaves non-finite parameters is undone (parameters and optimizer
              state restored) and counted in ``rolled_back_updates``.
            * A mini-batch whose metrics are non-finite, or that triggers the
              ``target_kl_stop`` break, has already had its optimizer step
              applied but is not included in ``updates_applied``.
        """
        cfg = self.cfg
        agg = dict.fromkeys(self._METRIC_KEYS, 0.0)
        updates = 0
        skipped_nonfinite_loss = 0
        skipped_nonfinite_grads = 0
        rolled_back_updates = 0
        skipped_nonfinite_metrics = 0
        early_stopped_kl = 0
        last_metrics: Dict[str, float] | None = None

        in_warmup = 0 <= iteration < int(cfg.warmup_strict_iters)
        if in_warmup:
            metrics_interval = max(1, int(cfg.warmup_metrics_interval))
            finite_check_interval = max(1, int(cfg.warmup_finite_check_interval))
        else:
            metrics_interval = max(1, int(cfg.metrics_interval))
            finite_check_interval = max(1, int(cfg.finite_check_interval))
        lr_interval = max(1, int(cfg.adaptive_lr_update_interval))
        adaptive = cfg.schedule == "adaptive" and cfg.desired_kl is not None
        target_dtype = self._dtype

        def to_model_dtype(x):
            # Mixed precision: cast batch to model dtype (e.g. float32) when buffer is float16.
            if hasattr(x, "astype") and getattr(x, "dtype", None) != target_dtype:
                return x.astype(target_dtype)
            return x

        batches = buffer.mini_batch_generator(cfg.num_mini_batches, cfg.num_learning_epochs)
        for batch_idx, batch in enumerate(batches):
            batch = tree_map(to_model_dtype, batch)
            do_full_checks = (
                not cfg.disable_finite_checks
            ) and batch_idx % finite_check_interval == 0
            # Always compute metrics until one finite set exists to reuse.
            do_metrics = (batch_idx % metrics_interval == 0) or (last_metrics is None)

            loss, grads = self._compute_loss_and_grads(batch)

            if do_full_checks:
                if not mx.all(mx.isfinite(loss)).item():
                    skipped_nonfinite_loss += 1
                    continue
                if not self._all_finite(grads):
                    skipped_nonfinite_grads += 1
                    continue
            grads = self._clip_grads(grads)
            if do_full_checks and not self._all_finite(grads):
                skipped_nonfinite_grads += 1
                continue

            # Keep the pre-step state so a step that produces NaN/Inf weights
            # can be undone instead of poisoning every later mini-batch.
            snapshot = self._snapshot_state() if do_full_checks else None
            self.optimizer.update(self.model, grads)
            mx.eval(loss, self.model.parameters(), self.optimizer.state)
            if snapshot is not None and not self._all_finite(self.model.parameters()):
                self._restore_state(snapshot)
                rolled_back_updates += 1
                continue

            if do_metrics:
                metrics = self._metrics(batch)
                if not all(math.isfinite(v) for v in metrics.values()):
                    skipped_nonfinite_metrics += 1
                    continue
                last_metrics = metrics
            else:
                # do_metrics is only False once last_metrics has been set.
                metrics = cast(Dict[str, float], last_metrics)

            if cfg.target_kl_stop is not None and metrics["approx_kl"] > cfg.target_kl_stop:
                early_stopped_kl += 1
                break

            if do_metrics and adaptive:
                kl_for_lr = self._update_kl_ema(metrics["approx_kl"])
                if (updates + 1) % lr_interval == 0:
                    self.learning_rate = self._next_learning_rate(kl_for_lr)
                    self.optimizer.learning_rate = mx.array(
                        self.learning_rate, dtype=self._dtype
                    )

            for key in agg:
                agg[key] += metrics[key]
            updates += 1

        counters = {
            "skipped_nonfinite_loss": skipped_nonfinite_loss,
            "skipped_nonfinite_grads": skipped_nonfinite_grads,
            "rolled_back_updates": rolled_back_updates,
            "skipped_nonfinite_metrics": skipped_nonfinite_metrics,
            "early_stopped_kl": early_stopped_kl,
        }
        return self._summarize(agg, updates, self.learning_rate, counters)