unilab.algos.mlx.ppo
PPO implementation based on MLX.
- class unilab.algos.mlx.ppo.MLPActorCritic
Bases: Module
Shared utility class containing actor and critic MLPs.
- __init__(obs_dim, action_dim, actor_hidden_dims, critic_hidden_dims, activation='tanh', init_log_std=0.0, min_log_std=-5.0, max_log_std=2.0, obs_normalization=False, noise_std_type='log', state_dependent_std=False, dtype=None)
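The `init_log_std`, `min_log_std`, and `max_log_std` arguments suggest a Gaussian policy whose log standard deviation is kept within bounds. The pure-Python sketch below illustrates that idea only. It assumes that `noise_std_type='log'` means the model stores a log standard deviation and clamps it before exponentiating; the helper name `clamped_std` is hypothetical and is not part of unilab.

```python
import math


def clamped_std(log_std: float,
                min_log_std: float = -5.0,
                max_log_std: float = 2.0) -> float:
    """Clamp a log standard deviation to [min_log_std, max_log_std], then exponentiate.

    Assumption: this mirrors how a bounded log-std is commonly turned into
    an action-noise scale; it is not taken from the unilab source.
    """
    bounded = max(min_log_std, min(max_log_std, log_std))
    return math.exp(bounded)


# With the documented default init_log_std=0.0, the initial std is exp(0) = 1.
initial_std = clamped_std(0.0)
```

Under this assumption, the default bounds keep the standard deviation between exp(-5) and exp(2), which prevents the policy noise from collapsing to zero or growing without limit.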
- class unilab.algos.mlx.ppo.PPOConfig
Bases: object
PPOConfig(num_learning_epochs: 'int' = 4, num_mini_batches: 'int' = 4, clip_param: 'float' = 0.2, gamma: 'float' = 0.99, lam: 'float' = 0.95, value_loss_coef: 'float' = 0.5, entropy_coef: 'float' = 0.0, learning_rate: 'float' = 0.0003, use_clipped_value_loss: 'bool' = True, max_grad_norm: 'float' = 1.0, log_ratio_clip: 'float' = 20.0, schedule: 'str' = 'fixed', desired_kl: 'float' = 0.01, min_learning_rate: 'float' = 1e-05, max_learning_rate: 'float' = 0.01, normalize_advantage_per_mini_batch: 'bool' = False, adaptive_kl_beta: 'float' = 0.9, adaptive_lr_decay: 'float' = 1.5, adaptive_lr_growth: 'float' = 1.2, adaptive_lr_update_interval: 'int' = 1, target_kl_stop: 'float | None' = None, metrics_interval: 'int' = 8, finite_check_interval: 'int' = 8, enable_compile: 'bool' = False, warmup_strict_iters: 'int' = 0, warmup_metrics_interval: 'int' = 1, warmup_finite_check_interval: 'int' = 1, disable_finite_checks: 'bool' = False)
- Parameters:
num_learning_epochs (int)
num_mini_batches (int)
clip_param (float)
gamma (float)
lam (float)
value_loss_coef (float)
entropy_coef (float)
learning_rate (float)
use_clipped_value_loss (bool)
max_grad_norm (float)
log_ratio_clip (float)
schedule (str)
desired_kl (float)
min_learning_rate (float)
max_learning_rate (float)
normalize_advantage_per_mini_batch (bool)
adaptive_kl_beta (float)
adaptive_lr_decay (float)
adaptive_lr_growth (float)
adaptive_lr_update_interval (int)
target_kl_stop (float | None)
metrics_interval (int)
finite_check_interval (int)
enable_compile (bool)
warmup_strict_iters (int)
warmup_metrics_interval (int)
warmup_finite_check_interval (int)
disable_finite_checks (bool)
- __init__(num_learning_epochs=4, num_mini_batches=4, clip_param=0.2, gamma=0.99, lam=0.95, value_loss_coef=0.5, entropy_coef=0.0, learning_rate=0.0003, use_clipped_value_loss=True, max_grad_norm=1.0, log_ratio_clip=20.0, schedule='fixed', desired_kl=0.01, min_learning_rate=1e-05, max_learning_rate=0.01, normalize_advantage_per_mini_batch=False, adaptive_kl_beta=0.9, adaptive_lr_decay=1.5, adaptive_lr_growth=1.2, adaptive_lr_update_interval=1, target_kl_stop=None, metrics_interval=8, finite_check_interval=8, enable_compile=False, warmup_strict_iters=0, warmup_metrics_interval=1, warmup_finite_check_interval=1, disable_finite_checks=False)
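The fields `schedule`, `desired_kl`, `adaptive_lr_decay`, `adaptive_lr_growth`, `min_learning_rate`, and `max_learning_rate` resemble the KL-adaptive learning-rate rule found in many PPO implementations. The sketch below shows one plausible reading of that rule in plain Python. The thresholds (twice and half of `desired_kl`) and the function name `adapt_learning_rate` are assumptions made for illustration; the documentation above does not state how unilab applies these fields, and `adaptive_kl_beta` and `adaptive_lr_update_interval` are not modelled here.

```python
def adapt_learning_rate(lr: float,
                        kl: float,
                        desired_kl: float = 0.01,
                        decay: float = 1.5,
                        growth: float = 1.2,
                        min_lr: float = 1e-5,
                        max_lr: float = 1e-2) -> float:
    """Return a new learning rate given the measured policy KL divergence.

    Assumed rule (hypothetical, not confirmed by the unilab docs):
      - KL well above the target  -> shrink the step size by `decay`
      - KL well below the target  -> enlarge the step size by `growth`
      - otherwise                 -> leave it unchanged
    The result is always clamped to [min_lr, max_lr].
    """
    if kl > 2.0 * desired_kl:
        lr = lr / decay
    elif 0.0 < kl < 0.5 * desired_kl:
        lr = lr * growth
    return min(max_lr, max(min_lr, lr))


# The policy moved too far: the learning rate drops from 3e-4 to 2e-4.
smaller = adapt_learning_rate(3e-4, kl=0.05)
# The policy barely moved: the learning rate grows from 3e-4 to 3.6e-4.
larger = adapt_learning_rate(3e-4, kl=0.001)
```

With the documented defaults, the decay factor (1.5) is larger than the growth factor (1.2), so under this reading the step size would shrink faster than it recovers.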
- class unilab.algos.mlx.ppo.PPOTrainer
Bases: object
PPO update logic for MLPActorCritic and RolloutBuffer.
- Parameters:
model (MLPActorCritic)
cfg (PPOConfig)
- __init__(model, cfg)
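As background for the "PPO update logic" mentioned above, the following sketch computes the standard PPO clipped surrogate loss for a single sample in plain Python. It is not the PPOTrainer implementation. The use of `log_ratio_clip` to bound the log-probability ratio before exponentiation is an assumption based on the parameter name, and `clipped_surrogate_loss` is a hypothetical helper.

```python
import math


def clipped_surrogate_loss(new_log_prob: float,
                           old_log_prob: float,
                           advantage: float,
                           clip_param: float = 0.2,
                           log_ratio_clip: float = 20.0) -> float:
    """Per-sample PPO clipped surrogate loss (lower is better)."""
    # Bound the log ratio first so exp() cannot overflow (assumed purpose
    # of log_ratio_clip).
    log_ratio = new_log_prob - old_log_prob
    log_ratio = max(-log_ratio_clip, min(log_ratio_clip, log_ratio))
    ratio = math.exp(log_ratio)

    # Restrict the probability ratio to [1 - clip_param, 1 + clip_param].
    clipped_ratio = max(1.0 - clip_param, min(1.0 + clip_param, ratio))

    # PPO maximizes the smaller of the two surrogate terms; negate it to
    # obtain a loss to minimize.
    return -min(ratio * advantage, clipped_ratio * advantage)


# Unchanged policy: ratio is 1, so the loss is simply -advantage.
loss_same_policy = clipped_surrogate_loss(-1.0, -1.0, advantage=1.0)
```

For a positive advantage the clip caps how much credit a large ratio can earn; for a negative advantage the unclipped term dominates, so large harmful updates are still penalized in full.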
- class unilab.algos.mlx.ppo.MLXPPOAgent
Bases: object
High-level PPO wrapper that keeps the training script lightweight.
- update(buffer, last_obs)
- Parameters:
buffer (RolloutBuffer)
last_obs (array)
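In on-policy PPO, an observation taken after the final rollout step is usually needed to bootstrap the value of the state that follows the buffer. The sketch below shows generalized advantage estimation (GAE) with the `gamma` and `lam` defaults from PPOConfig, for a single environment and plain Python lists. Whether `update` uses `last_obs` this way is an assumption; `compute_gae` and `last_value` (the critic's estimate for `last_obs`) are hypothetical names.

```python
from typing import List, Sequence, Tuple


def compute_gae(rewards: Sequence[float],
                values: Sequence[float],
                dones: Sequence[bool],
                last_value: float,
                gamma: float = 0.99,
                lam: float = 0.95) -> Tuple[List[float], List[float]]:
    """Return (advantages, returns) for one rollout of a single environment."""
    advantages = [0.0] * len(rewards)
    next_value = last_value   # value of the state after the final step
    running = 0.0             # discounted sum of TD residuals so far
    for t in reversed(range(len(rewards))):
        # Do not bootstrap across an episode boundary.
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        running = delta + gamma * lam * not_done * running
        advantages[t] = running
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


# Two steps, no episode end, undiscounted: advantages are the rewards-to-go.
adv, ret = compute_gae([1.0, 1.0], [0.0, 0.0], [False, False],
                       last_value=0.0, gamma=1.0, lam=1.0)
```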
- class unilab.algos.mlx.ppo.TensorboardScalarWriter
Bases: object
Minimal scalar writer based on TensorBoard event files.
- Parameters:
log_dir (Path)
- unilab.algos.mlx.ppo.get_latest_checkpoint(run_dir)
Find the latest model_*.safetensors checkpoint in a run directory.
- Parameters:
run_dir (Path)
- Return type:
Path | None
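To make the lookup concrete, here is a standalone sketch of how such a search could work with `pathlib`. It assumes that "latest" means the highest integer in the `model_<N>.safetensors` file name; the documentation does not say whether unilab orders checkpoints by number or by modification time. `find_latest_checkpoint` is a hypothetical stand-in, not the library function.

```python
import re
from pathlib import Path
from typing import Optional

# Matches names such as "model_1500.safetensors" and captures the number.
_CKPT_PATTERN = re.compile(r"^model_(\d+)\.safetensors$")


def find_latest_checkpoint(run_dir: Path) -> Optional[Path]:
    """Return the checkpoint with the highest numeric suffix, or None."""
    if not run_dir.is_dir():
        return None
    best_path: Optional[Path] = None
    best_step = -1
    for path in run_dir.glob("model_*.safetensors"):
        match = _CKPT_PATTERN.match(path.name)
        if match is None:
            continue  # skip names without a numeric step, e.g. model_final
        step = int(match.group(1))
        if step > best_step:
            best_step, best_path = step, path
    return best_path
```

Comparing the parsed integers rather than the file names avoids ranking `model_20` above `model_100`, which a plain string sort would do.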
- unilab.algos.mlx.ppo.get_latest_run(log_dir)
Find the latest run directory under a task log root.
- Parameters:
log_dir (Path)
- Return type:
Path | None
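A comparable sketch for locating the newest run directory follows. It assumes that run directories carry sortable, timestamp-like names, so the lexicographically greatest name is the most recent one; the actual naming scheme and ordering used by unilab are not documented here. `find_latest_run` is a hypothetical stand-in.

```python
from pathlib import Path
from typing import Optional


def find_latest_run(log_dir: Path) -> Optional[Path]:
    """Return the subdirectory with the greatest name, or None if there is none."""
    if not log_dir.is_dir():
        return None
    # Only directories count as runs; stray files in the log root are ignored.
    run_dirs = sorted((p for p in log_dir.iterdir() if p.is_dir()),
                      key=lambda p: p.name)
    return run_dirs[-1] if run_dirs else None
```

If run names are not sortable in this way, sorting by `p.stat().st_mtime` instead of `p.name` would be the natural alternative.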