Source code for unilab.algos.mlx.ppo.model

"""Actor-Critic model for MLX PPO."""

from __future__ import annotations

import math
from typing import Any, Sequence, Tuple

import mlx.core as mx
import mlx.nn as nn

from unilab.algos.mlx.common import MLP, EmpiricalNormalization, diag_gaussian_log_prob


class MLPActorCritic(nn.Module):
    """Actor-critic module containing separate actor and critic MLPs.

    The actor outputs the mean of a diagonal Gaussian policy. The standard
    deviation is either a free per-action parameter
    (``state_dependent_std=False``) or read from the second half of the actor
    output (``state_dependent_std=True``). In both cases it is parameterised
    either directly (``noise_std_type="scalar"``) or through its logarithm
    (``noise_std_type="log"``).
    """

    # Floor applied to every standard deviation before it is used or
    # log-transformed.
    _MIN_STD = 1e-4

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        actor_hidden_dims: Sequence[int],
        critic_hidden_dims: Sequence[int],
        activation: str = "tanh",
        init_log_std: float = 0.0,
        min_log_std: float = -5.0,
        max_log_std: float = 2.0,
        obs_normalization: bool = False,
        noise_std_type: str = "log",
        state_dependent_std: bool = False,
        dtype: Any | None = None,
    ) -> None:
        """Build the actor/critic networks and the policy noise parameters.

        Args:
            obs_dim: Size of the observation vector fed to both MLPs.
            action_dim: Number of action dimensions.
            actor_hidden_dims: Hidden layer sizes forwarded to the actor MLP.
            critic_hidden_dims: Hidden layer sizes forwarded to the critic MLP.
            activation: Activation name forwarded to both MLPs.
            init_log_std: Initial log standard deviation. In ``"scalar"`` mode
                its exponential is used as the initial value instead.
            min_log_std: Lower clip bound for log-std values.
            max_log_std: Upper clip bound for log-std values.
            obs_normalization: If true, observations pass through an
                ``EmpiricalNormalization`` layer before either network.
            noise_std_type: ``"scalar"`` or ``"log"``.
            state_dependent_std: If true, the actor emits ``2 * action_dim``
                outputs and the second half drives the standard deviation.
            dtype: dtype for the normalizer and the free std parameter;
                defaults to ``mx.float32``. It is not passed to the MLPs.

        Raises:
            ValueError: If ``noise_std_type`` is neither ``"scalar"`` nor
                ``"log"``.
        """
        super().__init__()
        self.action_dim = int(action_dim)
        self.dtype = mx.float32 if dtype is None else dtype
        self.noise_std_type = noise_std_type
        self.state_dependent_std = bool(state_dependent_std)
        self.obs_normalization = bool(obs_normalization)
        self.min_log_std = float(min_log_std)
        self.max_log_std = float(max_log_std)

        # Fail fast, before any network is allocated and initialised.
        if self.noise_std_type not in ("scalar", "log"):
            raise ValueError(f"Unknown noise_std_type: {self.noise_std_type}")

        self.obs_normalizer = (
            EmpiricalNormalization(obs_dim, dtype=self.dtype)
            if self.obs_normalization
            else None
        )

        actor_output_dim = (
            self.action_dim * 2 if self.state_dependent_std else self.action_dim
        )
        self.actor = MLP(obs_dim, actor_output_dim, actor_hidden_dims, activation=activation)
        self.critic = MLP(obs_dim, 1, critic_hidden_dims, activation=activation)
        self.actor.init_orthogonal(hidden_gain=math.sqrt(2.0), output_gain=0.01)
        self.critic.init_orthogonal(hidden_gain=math.sqrt(2.0), output_gain=1.0)

        # Initial value of the noise parameter in the chosen parameterisation.
        if self.noise_std_type == "scalar":
            init_value = float(mx.exp(mx.array(init_log_std)).item())
        else:
            init_value = float(init_log_std)

        if self.state_dependent_std:
            # Keep std head conservative at init like rsl-rl: zero weights so
            # the head initially outputs only its constant bias.
            std_layer = self.actor.layers[-1]
            std_layer.weight[self.action_dim :] = 0.0
            std_layer.bias[self.action_dim :] = init_value
        elif self.noise_std_type == "scalar":
            self.std = mx.full((self.action_dim,), init_value, dtype=self.dtype)
        else:
            self.log_std = mx.full((self.action_dim,), init_value, dtype=self.dtype)

    def clipped_log_std(self) -> mx.array:
        """Clamp log-std to avoid numerical explosion.

        Reads ``self.log_std`` (``"log"`` mode) or ``self.std`` (``"scalar"``
        mode). ``__init__`` only creates those attributes when
        ``state_dependent_std`` is false, so this method is only usable in
        that configuration.

        Returns:
            The per-action log-std clipped to ``[min_log_std, max_log_std]``.
        """
        if self.noise_std_type == "log":
            return mx.clip(self.log_std, self.min_log_std, self.max_log_std)
        std = mx.maximum(self.std, self._MIN_STD)
        log_std = mx.log(std)
        return mx.clip(log_std, self.min_log_std, self.max_log_std)

    def policy(self, obs: mx.array) -> mx.array:
        """Return the policy mean for ``obs``, without sampling."""
        mean, _, _ = self.distribution_params(obs)
        return mean

    def distribution_params(self, obs: mx.array) -> tuple[mx.array, mx.array, mx.array]:
        """Compute the diagonal Gaussian parameters for ``obs``.

        Observations are normalised first when a normalizer is configured.
        The actor output is split along its last axis, so any number of
        leading batch axes is accepted here (assuming the MLP and normalizer
        accept that shape too -- TODO confirm).

        Returns:
            ``(mean, std, log_std)``, each shaped like the policy mean. In
            ``"log"`` mode ``log_std`` is clipped to
            ``[min_log_std, max_log_std]``; in ``"scalar"`` mode it is the log
            of the floored std and is not clipped.

        Raises:
            ValueError: If ``noise_std_type`` was changed to an unknown value
                after construction.
        """
        if self.obs_normalizer is not None:
            obs = self.obs_normalizer(obs)

        if self.state_dependent_std:
            out = self.actor(obs)
            # Split on the last axis instead of assuming a 2-D (batch, feature)
            # layout; the result is identical for 2-D input.
            mean = out[..., : self.action_dim]
            std_head = out[..., self.action_dim :]
            if self.noise_std_type == "scalar":
                # softplus keeps the predicted std positive.
                std = mx.maximum(nn.softplus(std_head), self._MIN_STD)
                log_std = mx.log(std)
            elif self.noise_std_type == "log":
                log_std = mx.clip(std_head, self.min_log_std, self.max_log_std)
                std = mx.maximum(mx.exp(log_std), self._MIN_STD)
            else:
                raise ValueError(f"Unknown noise_std_type: {self.noise_std_type}")
        else:
            mean = self.actor(obs)
            if self.noise_std_type == "scalar":
                std_base = mx.maximum(self.std, self._MIN_STD)
                std = mx.broadcast_to(std_base, mean.shape)
                log_std = mx.log(std)
            elif self.noise_std_type == "log":
                log_std_base = mx.clip(self.log_std, self.min_log_std, self.max_log_std)
                log_std = mx.broadcast_to(log_std_base, mean.shape)
                std = mx.maximum(mx.exp(log_std), self._MIN_STD)
            else:
                raise ValueError(f"Unknown noise_std_type: {self.noise_std_type}")

        return mean, std, log_std

    def value(self, obs: mx.array) -> mx.array:
        """Return the critic estimate for ``obs`` with the trailing unit axis removed."""
        if self.obs_normalizer is not None:
            obs = self.obs_normalizer(obs)
        return mx.squeeze(self.critic(obs), axis=-1)

    def update_normalization(self, obs: mx.array) -> None:
        """Feed ``obs`` to the observation normalizer; no-op when it is disabled."""
        if self.obs_normalizer is not None:
            self.obs_normalizer.update(obs)

    def act(self, obs: mx.array) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array]:
        """Sample actions and return MLX tensors.

        Returns:
            ``(actions, log_probs, values, mean, std)``. ``actions`` is
            ``mean + noise * std`` with standard-normal noise, ``log_probs``
            comes from ``diag_gaussian_log_prob(actions, mean, log_std)`` and
            ``values`` from :meth:`value`.
        """
        mean, std, log_std = self.distribution_params(obs)
        # Sample in the mean's dtype so the action is not silently promoted
        # to float32 when the policy runs in another precision.
        noise = mx.random.normal(mean.shape, dtype=mean.dtype)
        actions = mean + noise * std
        log_probs = diag_gaussian_log_prob(actions, mean, log_std)
        values = self.value(obs)
        return actions, log_probs, values, mean, std

    def current_action_std(self, action_shape: tuple[int, ...]) -> mx.array:
        """Return broadcasted std tensor for current policy.

        Like :meth:`clipped_log_std`, this reads the free ``self.std`` /
        ``self.log_std`` parameter, which ``__init__`` only creates when
        ``state_dependent_std`` is false.

        Args:
            action_shape: Target shape to broadcast the per-action std to.

        Returns:
            The floored standard deviation broadcast to ``action_shape``.
        """
        if self.noise_std_type == "scalar":
            std = mx.maximum(self.std, self._MIN_STD)
            return mx.broadcast_to(std, action_shape)
        log_std = self.clipped_log_std()
        return mx.broadcast_to(mx.maximum(mx.exp(log_std), self._MIN_STD), action_shape)