unilab.algos.mlx.ppo

PPO implementation based on MLX.

class unilab.algos.mlx.ppo.MLPActorCritic

Bases: Module

Shared utility class containing actor and critic MLPs.

__init__(obs_dim, action_dim, actor_hidden_dims, critic_hidden_dims, activation='tanh', init_log_std=0.0, min_log_std=-5.0, max_log_std=2.0, obs_normalization=False, noise_std_type='log', state_dependent_std=False, dtype=None)

clipped_log_std()

Return the log standard deviation clamped to a bounded range, to avoid numerical overflow.

Return type:

array
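The following is a minimal pure-Python sketch of the clamping idea. It assumes the bounds are the constructor's min_log_std (-5.0) and max_log_std (2.0) defaults, and that the clamped value is later exponentiated into a standard deviation. The helper names are hypothetical, and the sketch uses plain lists instead of MLX arrays.

```python
import math


def clipped_log_std(log_std, min_log_std=-5.0, max_log_std=2.0):
    """Clamp every log-std entry into [min_log_std, max_log_std] (hypothetical helper)."""
    return [min(max(v, min_log_std), max_log_std) for v in log_std]


def action_std(log_std, min_log_std=-5.0, max_log_std=2.0):
    """Exponentiate after clamping, so std stays within [exp(-5), exp(2)]."""
    return [math.exp(v) for v in clipped_log_std(log_std, min_log_std, max_log_std)]


print(clipped_log_std([-10.0, 0.0, 7.5]))  # [-5.0, 0.0, 2.0]
```

Clamping in log space keeps the standard deviation strictly positive. It also bounds both vanishing and exploding exploration noise.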

policy(obs)
Parameters:

obs (array)

Return type:

array

distribution_params(obs)
Parameters:

obs (array)

Return type:

tuple[array, array, array]

value(obs)
Parameters:

obs (array)

Return type:

array

update_normalization(obs)
Parameters:

obs (array)

Return type:

None
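When obs_normalization=True, update_normalization presumably refreshes running observation statistics. The sketch below assumes per-feature running mean and variance, merged batch by batch with the standard parallel-variance update. The RunningMeanStd class is hypothetical and is not part of unilab.

```python
import math


class RunningMeanStd:
    """Per-feature running mean/variance over batches (hypothetical helper)."""

    def __init__(self, dim, eps=1e-8):
        self.mean = [0.0] * dim
        self.var = [1.0] * dim
        self.count = eps  # tiny prior count avoids division by zero
        self.eps = eps

    def update(self, batch):
        """Merge the statistics of a batch (list of observation rows)."""
        n = len(batch)
        dim = len(self.mean)
        b_mean = [sum(row[i] for row in batch) / n for i in range(dim)]
        b_var = [sum((row[i] - b_mean[i]) ** 2 for row in batch) / n for i in range(dim)]
        total = self.count + n
        for i in range(dim):
            delta = b_mean[i] - self.mean[i]
            # Combine sums of squared deviations of the old and new data.
            m2 = self.var[i] * self.count + b_var[i] * n + delta ** 2 * self.count * n / total
            self.mean[i] += delta * n / total
            self.var[i] = m2 / total
        self.count = total

    def normalize(self, obs):
        return [(x - m) / math.sqrt(v + self.eps) for x, m, v in zip(obs, self.mean, self.var)]


rms = RunningMeanStd(1)
rms.update([[1.0], [3.0]])
rms.update([[5.0], [7.0]])
print(rms.mean, rms.var)  # close to [4.0] and [5.0]
```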

act(obs)

Sample actions and return them as MLX arrays.

Parameters:

obs (array)

Return type:

Tuple[array, array, array, array, array]
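Which five arrays act returns is not documented here. As a hedged illustration, this sketch shows the standard diagonal-Gaussian quantities a PPO policy typically computes when sampling:

- an action drawn as mean + std * noise;
- its log-probability;
- the policy entropy.

All function names are hypothetical.

```python
import math
import random


def sample_diag_gaussian(mean, log_std, rng):
    """Draw one action: mean + std * standard-normal noise, per dimension."""
    return [m + math.exp(s) * rng.gauss(0.0, 1.0) for m, s in zip(mean, log_std)]


def diag_gaussian_log_prob(action, mean, log_std):
    """Sum of per-dimension Normal log-densities."""
    return sum(
        -0.5 * ((a - m) / math.exp(s)) ** 2 - s - 0.5 * math.log(2 * math.pi)
        for a, m, s in zip(action, mean, log_std)
    )


def diag_gaussian_entropy(log_std):
    """Entropy of a diagonal Gaussian: sum of log(std) + 0.5 * log(2 * pi * e)."""
    return sum(s + 0.5 * math.log(2 * math.pi * math.e) for s in log_std)


rng = random.Random(0)
action = sample_diag_gaussian([0.0, 1.0], [0.0, -1.0], rng)
print(action, diag_gaussian_log_prob(action, [0.0, 1.0], [0.0, -1.0]))
```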

current_action_std(action_shape)

Return the current policy's action standard deviation, broadcast to action_shape.

Parameters:

action_shape (tuple[int, ...])

Return type:

array

class unilab.algos.mlx.ppo.PPOConfig

Bases: object

PPOConfig(num_learning_epochs: int = 4, num_mini_batches: int = 4, clip_param: float = 0.2, gamma: float = 0.99, lam: float = 0.95, value_loss_coef: float = 0.5, entropy_coef: float = 0.0, learning_rate: float = 0.0003, use_clipped_value_loss: bool = True, max_grad_norm: float = 1.0, log_ratio_clip: float = 20.0, schedule: str = 'fixed', desired_kl: float = 0.01, min_learning_rate: float = 1e-05, max_learning_rate: float = 0.01, normalize_advantage_per_mini_batch: bool = False, adaptive_kl_beta: float = 0.9, adaptive_lr_decay: float = 1.5, adaptive_lr_growth: float = 1.2, adaptive_lr_update_interval: int = 1, target_kl_stop: float | None = None, metrics_interval: int = 8, finite_check_interval: int = 8, enable_compile: bool = False, warmup_strict_iters: int = 0, warmup_metrics_interval: int = 1, warmup_finite_check_interval: int = 1, disable_finite_checks: bool = False)

Parameters:
  • num_learning_epochs (int)

  • num_mini_batches (int)

  • clip_param (float)

  • gamma (float)

  • lam (float)

  • value_loss_coef (float)

  • entropy_coef (float)

  • learning_rate (float)

  • use_clipped_value_loss (bool)

  • max_grad_norm (float)

  • log_ratio_clip (float)

  • schedule (str)

  • desired_kl (float)

  • min_learning_rate (float)

  • max_learning_rate (float)

  • normalize_advantage_per_mini_batch (bool)

  • adaptive_kl_beta (float)

  • adaptive_lr_decay (float)

  • adaptive_lr_growth (float)

  • adaptive_lr_update_interval (int)

  • target_kl_stop (float | None)

  • metrics_interval (int)

  • finite_check_interval (int)

  • enable_compile (bool)

  • warmup_strict_iters (int)

  • warmup_metrics_interval (int)

  • warmup_finite_check_interval (int)

  • disable_finite_checks (bool)
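gamma and lam are the discount factor and the GAE lambda. Assuming advantages come from standard Generalized Advantage Estimation, a single-environment version looks like the sketch below. Two things here are assumptions:

- compute_gae is a hypothetical helper;
- dones[t] is taken to mark an episode ending after step t. unilab's buffer may, for example, handle time-out bootstrapping differently.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one environment's trajectory."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value  # bootstrap value for the state after the last step
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        # One-step TD error; no bootstrapping across an episode boundary.
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]  # value-function targets
    return advantages, returns


adv, ret = compute_gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0, 0, 0], 0.0, gamma=1.0, lam=1.0)
print(adv)  # [3.0, 2.0, 1.0]
```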

num_learning_epochs: int = 4
num_mini_batches: int = 4
clip_param: float = 0.2
gamma: float = 0.99
lam: float = 0.95
value_loss_coef: float = 0.5
entropy_coef: float = 0.0
learning_rate: float = 0.0003
use_clipped_value_loss: bool = True
max_grad_norm: float = 1.0
log_ratio_clip: float = 20.0
schedule: str = 'fixed'
desired_kl: float = 0.01
min_learning_rate: float = 1e-05
max_learning_rate: float = 0.01
normalize_advantage_per_mini_batch: bool = False
adaptive_kl_beta: float = 0.9
adaptive_lr_decay: float = 1.5
adaptive_lr_growth: float = 1.2
adaptive_lr_update_interval: int = 1
target_kl_stop: float | None = None
metrics_interval: int = 8
finite_check_interval: int = 8
enable_compile: bool = False
warmup_strict_iters: int = 0
warmup_metrics_interval: int = 1
warmup_finite_check_interval: int = 1
disable_finite_checks: bool = False
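The clipping-related fields can be read against the standard PPO objective. This sketch assumes the common formulation:

- the probability ratio comes from a log-ratio clamped to [-log_ratio_clip, log_ratio_clip];
- the surrogate uses a ratio clipped to [1 - clip_param, 1 + clip_param];
- when use_clipped_value_loss is true, the value prediction is clipped around its old value, as several open-source PPO implementations do.

The combined loss would then be surrogate + value_loss_coef * value_loss - entropy_coef * entropy. ppo_losses is a hypothetical helper, not unilab's code.

```python
import math


def ppo_losses(log_probs, old_log_probs, advantages, values, old_values, returns,
               clip_param=0.2, log_ratio_clip=20.0, use_clipped_value_loss=True):
    """Return (clipped surrogate loss, value loss), both averaged over samples."""
    n = len(log_probs)
    surrogate = 0.0
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        # Clamp the log-ratio before exp() so a stale sample cannot overflow.
        log_ratio = max(-log_ratio_clip, min(lp - old_lp, log_ratio_clip))
        ratio = math.exp(log_ratio)
        clipped = max(1.0 - clip_param, min(ratio, 1.0 + clip_param))
        # Pessimistic bound: take the smaller objective, negate to get a loss.
        surrogate += -min(ratio * adv, clipped * adv)
    surrogate /= n

    value_loss = 0.0
    for v, old_v, ret in zip(values, old_values, returns):
        if use_clipped_value_loss:
            v_clipped = old_v + max(-clip_param, min(v - old_v, clip_param))
            value_loss += max((v - ret) ** 2, (v_clipped - ret) ** 2)
        else:
            value_loss += (v - ret) ** 2
    value_loss /= n
    return surrogate, value_loss


s, v = ppo_losses([math.log(2.0)], [0.0], [1.0], [0.5], [0.0], [1.0])
print(s, v)  # ratio 2.0 is clipped to 1.2; value loss uses the clipped prediction
```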
__init__(num_learning_epochs=4, num_mini_batches=4, clip_param=0.2, gamma=0.99, lam=0.95, value_loss_coef=0.5, entropy_coef=0.0, learning_rate=0.0003, use_clipped_value_loss=True, max_grad_norm=1.0, log_ratio_clip=20.0, schedule='fixed', desired_kl=0.01, min_learning_rate=1e-05, max_learning_rate=0.01, normalize_advantage_per_mini_batch=False, adaptive_kl_beta=0.9, adaptive_lr_decay=1.5, adaptive_lr_growth=1.2, adaptive_lr_update_interval=1, target_kl_stop=None, metrics_interval=8, finite_check_interval=8, enable_compile=False, warmup_strict_iters=0, warmup_metrics_interval=1, warmup_finite_check_interval=1, disable_finite_checks=False)

class unilab.algos.mlx.ppo.PPOTrainer

Bases: object

PPO update logic that optimizes an MLPActorCritic using rollouts stored in a RolloutBuffer.

__init__(model, cfg)

update(buffer, iteration=-1)
Return type:

Dict[str, float]
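If PPOTrainer.update follows the usual PPO scheme (an assumption based on the PPOConfig field names), it makes num_learning_epochs passes over the buffer, each split into num_mini_batches. The hypothetical minibatch_indices generator below sketches that scheme: it reshuffles once per epoch and cuts the permutation into equal-sized minibatches. When the batch size is not divisible by num_mini_batches, this sketch drops the remainder, which may differ from the actual trainer.

```python
import random


def minibatch_indices(batch_size, num_mini_batches, num_learning_epochs, seed=0):
    """Yield index lists: a fresh shuffle per epoch, split into equal minibatches."""
    rng = random.Random(seed)
    mb_size = batch_size // num_mini_batches  # remainder samples are dropped
    for _ in range(num_learning_epochs):
        perm = list(range(batch_size))
        rng.shuffle(perm)
        for i in range(num_mini_batches):
            yield perm[i * mb_size:(i + 1) * mb_size]


# With the PPOConfig defaults (4 epochs x 4 minibatches), 16 gradient steps per update.
print(sum(1 for _ in minibatch_indices(4096, 4, 4)))  # 16
```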

class unilab.algos.mlx.ppo.MLXPPOAgent

Bases: object

High-level PPO wrapper that keeps the training script lightweight.

__init__(cfg, obs_dim, action_dim, learning_rate)
property learning_rate: float
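Under a non-fixed schedule, the PPOConfig fields desired_kl, adaptive_lr_decay, adaptive_lr_growth, min_learning_rate and max_learning_rate suggest a KL-adaptive learning rate. The sketch assumes the widely used rule:

- shrink the rate when the measured KL exceeds twice desired_kl;
- grow it when the KL falls below half of desired_kl.

The thresholds, the hypothetical adaptive_lr function, and the accepted schedule strings are all assumptions. adaptive_kl_beta, adaptive_lr_update_interval and target_kl_stop are not modeled.

```python
def adaptive_lr(lr, kl, desired_kl=0.01, decay=1.5, growth=1.2,
                min_lr=1e-5, max_lr=1e-2):
    """Return the next learning rate given the measured mean KL (assumed rule)."""
    if kl > 2.0 * desired_kl:
        lr = lr / decay  # policy moved too far: slow down
    elif 0.0 < kl < desired_kl / 2.0:
        lr = lr * growth  # policy barely moved: speed up
    return min(max(lr, min_lr), max_lr)  # keep within configured bounds


print(adaptive_lr(3e-4, 0.05))  # 0.0002 (decayed by 1.5)
```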
update_normalization(obs)
Parameters:

obs (array)

Return type:

None

act(obs)
Parameters:

obs (array)

policy_mean(obs)
Parameters:

obs (array)

Return type:

array

normalize_rewards(rewards)
Parameters:

rewards (array)

Return type:

array
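normalize_rewards scales rewards, but the statistics it uses are not documented here. One common approach, shown below purely as an assumption, divides each reward by a running standard deviation of the per-environment discounted return. ReturnScaler is hypothetical. Unlike the documented normalize_rewards(rewards), it also takes a dones list so that returns reset at episode boundaries.

```python
import math


class ReturnScaler:
    """Divide rewards by a running std of per-env discounted returns (hypothetical)."""

    def __init__(self, num_envs, gamma=0.99, eps=1e-8):
        self.gamma = gamma
        self.eps = eps
        self.returns = [0.0] * num_envs  # discounted return accumulated per env
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # Welford sum of squared deviations

    def __call__(self, rewards, dones):
        for i, (r, done) in enumerate(zip(rewards, dones)):
            self.returns[i] = self.returns[i] * self.gamma + r
            # Welford's online update of the mean and variance of the returns.
            self.count += 1
            delta = self.returns[i] - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (self.returns[i] - self.mean)
            if done:
                self.returns[i] = 0.0  # start a fresh return for the next episode
        var = self.m2 / self.count if self.count > 1 else 1.0
        std = math.sqrt(var + self.eps)
        # Only the scale changes; rewards are not mean-centered.
        return [r / std for r in rewards]


scaler = ReturnScaler(num_envs=2, gamma=0.99)
print(scaler([1.0, 3.0], [False, False]))
```

Dividing without subtracting the mean preserves the sign of each reward.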

current_action_std(action_shape)
Parameters:

action_shape (tuple[int, ...])

Return type:

array

mean_noise_std()
Return type:

float

update(buffer, last_obs)

load_weights(path)
Parameters:

path (Path)

Return type:

None

save_checkpoint(model_path, trainer_state_path, iteration)
Parameters:
  • model_path (Path)

  • trainer_state_path (Path)

  • iteration (int)

Return type:

None

load_trainer_state(trainer_state_path)
Parameters:

trainer_state_path (Path)

Return type:

int

class unilab.algos.mlx.ppo.TensorboardScalarWriter

Bases: object

Minimal scalar writer that writes TensorBoard event files.

Parameters:

log_dir (Path)

__init__(log_dir)
Parameters:

log_dir (Path)

add_scalar(tag, value, step)
Return type:

None

flush()
Return type:

None

close()
Return type:

None

unilab.algos.mlx.ppo.get_latest_checkpoint(run_dir)

Find the latest model_*.safetensors checkpoint in a run directory, or return None if there is none.

Parameters:

run_dir (Path)

Return type:

Path | None
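Checkpoint names are only described as model_*.safetensors. Suppose the wildcard is an integer iteration number (an assumption). Then the latest checkpoint should be chosen by numeric value rather than by plain string order, because "model_99" sorts after "model_100" lexically. latest_checkpoint is a hypothetical stand-in for get_latest_checkpoint.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Assumed naming scheme: model_<iteration>.safetensors
_CKPT_RE = re.compile(r"^model_(\d+)\.safetensors$")


def latest_checkpoint(run_dir: Path) -> Optional[Path]:
    """Return the checkpoint with the highest iteration number, or None."""
    best: Optional[Path] = None
    best_iter = -1
    for path in run_dir.glob("model_*.safetensors"):
        match = _CKPT_RE.match(path.name)
        if match is None:
            continue  # e.g. model_final.safetensors has no numeric suffix
        iteration = int(match.group(1))
        if iteration > best_iter:
            best, best_iter = path, iteration
    return best


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        run_dir = Path(tmp)
        for name in ("model_99.safetensors", "model_100.safetensors"):
            (run_dir / name).touch()
        print(latest_checkpoint(run_dir).name)  # model_100.safetensors
```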

unilab.algos.mlx.ppo.get_latest_run(log_dir)

Find the latest run directory under a task's log root, or return None if there is none.

Parameters:

log_dir (Path)

Return type:

Path | None
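How get_latest_run orders runs is not stated. This sketch assumes run directories have timestamp names that sort chronologically. That is a common convention but is not confirmed here; sorting by modification time would be the alternative if names are not sortable. latest_run and the example directory names are hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Optional


def latest_run(log_dir: Path) -> Optional[Path]:
    """Return the lexicographically greatest subdirectory of log_dir, or None."""
    if not log_dir.is_dir():
        return None
    # Assumes names like 2024-03-05_09-30-00, so string order matches creation order.
    runs = sorted(p for p in log_dir.iterdir() if p.is_dir())
    return runs[-1] if runs else None


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for name in ("2024-01-01_10-00-00", "2024-03-05_09-30-00"):
            (root / name).mkdir()
        print(latest_run(root).name)  # 2024-03-05_09-30-00
```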

Modules

model

Actor-Critic model for MLX PPO.

ppo

PPO trainer implemented with MLX.

runner

Runner-style utilities for MLX PPO.