Source code for unilab.algos.mlx.ppo.runner

"""Runner-style utilities for MLX PPO.

This module keeps train script entrypoints thin, similar to rsl-rl runner usage.
"""

from __future__ import annotations

import pickle
import time
from pathlib import Path

import mlx.core as mx
from mlx.utils import tree_map

from unilab.algos.mlx.common import EmpiricalDiscountedVariationNormalization, RolloutBuffer

from .model import MLPActorCritic
from .ppo import PPOConfig, PPOTrainer


[docs] class TensorboardScalarWriter: """Minimal scalar writer based on tensorboard event files."""
[docs] def __init__(self, log_dir: Path) -> None: from tensorboard.compat.proto.event_pb2 import Event from tensorboard.compat.proto.summary_pb2 import Summary from tensorboard.summary.writer.event_file_writer import EventFileWriter self._Event = Event self._Summary = Summary self._writer = EventFileWriter(str(log_dir))
[docs] def add_scalar(self, tag: str, value: float, step: int) -> None: summary = self._Summary(value=[self._Summary.Value(tag=tag, simple_value=float(value))]) event = self._Event(wall_time=time.time(), step=int(step), summary=summary) self._writer.add_event(event)
[docs] def flush(self) -> None: self._writer.flush()
[docs] def close(self) -> None: self._writer.close()
[docs] def get_latest_run(log_dir: Path) -> Path | None: """Find latest run directory under a task log root.""" if not log_dir.exists(): return None runs = sorted([p for p in log_dir.iterdir() if p.is_dir()]) return runs[-1] if runs else None
[docs] def get_latest_checkpoint(run_dir: Path) -> Path | None: """Find the latest model_*.safetensors checkpoint in a run dir.""" if not run_dir.exists(): return None model_files = [p for p in run_dir.glob("model_*.safetensors") if p.is_file()] if not model_files: return None model_files.sort(key=lambda p: int(p.stem.split("_")[1])) return model_files[-1]
[docs] class MLXPPOAgent: """High-level PPO wrapper to keep train script lightweight."""
[docs] def __init__(self, cfg, obs_dim: int, action_dim: int, learning_rate: float) -> None: policy_cfg = cfg.policy algo_cfg = cfg.algorithm init_noise_std = float(getattr(policy_cfg, "init_noise_std", 1.0)) init_log_std = float(mx.log(mx.array(max(init_noise_std, 1e-6))).item()) obs_norm = bool(getattr(cfg, "empirical_normalization", False)) noise_std_type = str(getattr(policy_cfg, "noise_std_type", "scalar")) state_dependent_std = bool(getattr(policy_cfg, "state_dependent_std", False)) self.model = MLPActorCritic( obs_dim=obs_dim, action_dim=action_dim, actor_hidden_dims=policy_cfg.actor_hidden_dims, critic_hidden_dims=policy_cfg.critic_hidden_dims, activation=policy_cfg.activation, init_log_std=init_log_std, obs_normalization=obs_norm, noise_std_type=noise_std_type, state_dependent_std=state_dependent_std, ) ppo_cfg = PPOConfig( num_learning_epochs=int(algo_cfg.num_learning_epochs), num_mini_batches=int(algo_cfg.num_mini_batches), clip_param=float(algo_cfg.clip_param), gamma=float(algo_cfg.gamma), lam=float(algo_cfg.lam), value_loss_coef=float(algo_cfg.value_loss_coef), entropy_coef=float(algo_cfg.entropy_coef), learning_rate=learning_rate, use_clipped_value_loss=bool(algo_cfg.use_clipped_value_loss), max_grad_norm=float(getattr(algo_cfg, "max_grad_norm", 1.0)), schedule=str(getattr(algo_cfg, "schedule", "fixed")), desired_kl=float(getattr(algo_cfg, "desired_kl", 0.01)), normalize_advantage_per_mini_batch=bool( getattr(algo_cfg, "normalize_advantage_per_mini_batch", False) ), ) self.trainer = PPOTrainer(self.model, ppo_cfg) use_reward_norm = bool(getattr(algo_cfg, "reward_normalization", False)) self.reward_normalizer = ( EmpiricalDiscountedVariationNormalization(gamma=ppo_cfg.gamma) if use_reward_norm else None )
@property def learning_rate(self) -> float: return self.trainer.learning_rate
[docs] def update_normalization(self, obs: mx.array) -> None: self.model.update_normalization(obs)
[docs] def act(self, obs: mx.array): return self.model.act(obs)
[docs] def policy_mean(self, obs: mx.array) -> mx.array: return self.model.policy(obs)
[docs] def normalize_rewards(self, rewards: mx.array) -> mx.array: if self.reward_normalizer is not None: return mx.squeeze(self.reward_normalizer(rewards), axis=-1) return rewards
[docs] def current_action_std(self, action_shape: tuple[int, ...]) -> mx.array: return self.model.current_action_std(action_shape)
[docs] def mean_noise_std(self) -> float: std = self.model.current_action_std((1, self.model.action_dim)) mx.eval(std) return float(mx.mean(std).item())
[docs] def update(self, buffer: RolloutBuffer, last_obs: mx.array): last_values = self.model.value(last_obs) buffer.compute_returns_and_advantages(last_values) return self.trainer.update(buffer)
[docs] def load_weights(self, path: Path) -> None: self.model.load_weights(str(path), strict=True)
[docs] def save_checkpoint(self, model_path: Path, trainer_state_path: Path, iteration: int) -> None: self.model.save_weights(str(model_path)) payload = { "iteration": int(iteration), "learning_rate": float(self.trainer.learning_rate), "optimizer_state": tree_map(lambda x: x.tolist(), self.trainer.optimizer.state), } with trainer_state_path.open("wb") as f: pickle.dump(payload, f)
[docs] def load_trainer_state(self, trainer_state_path: Path) -> int: with trainer_state_path.open("rb") as f: payload = pickle.load(f) self.trainer.learning_rate = float(payload.get("learning_rate", self.trainer.learning_rate)) dtype = getattr(self.model, "dtype", mx.float32) self.trainer.optimizer.learning_rate = mx.array(self.trainer.learning_rate, dtype=dtype) self.trainer.optimizer.state = tree_map(lambda x: mx.array(x), payload["optimizer_state"]) return int(payload.get("iteration", -1))