unilab.algos.mlx.ppo.runner

Runner-style utilities for MLX PPO.

This module keeps training-script entry points thin, similar to how rsl-rl runners are used.

Functions

get_latest_checkpoint(run_dir)

Find the latest model_*.safetensors checkpoint in a run directory.

get_latest_run(log_dir)

Find the latest run directory under a task log root.

Classes

MLXPPOAgent

High-level PPO wrapper that keeps the training script lightweight.

TensorboardScalarWriter

Minimal scalar writer based on tensorboard event files.

unilab.algos.mlx.ppo.runner.tree_map(fn, tree)

class unilab.algos.mlx.ppo.runner.TensorboardScalarWriter

Bases: object

Minimal scalar writer based on tensorboard event files.

Parameters:

log_dir (Path)

__init__(log_dir)
Parameters:

log_dir (Path)

add_scalar(tag, value, step)
Return type:

None

flush()
Return type:

None

close()
Return type:

None
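
The method set above (add_scalar, flush, close) suggests the usual log-then-close pattern. The sketch below shows only that call sequence, run against an in-memory stand-in. InMemoryScalarWriter and log_training are hypothetical names that mirror the documented method names; the real class writes tensorboard event files under log_dir. The argument types (str tag, float value, int step) and the tag "Train/mean_reward" are assumptions, since this page does not list them.

```python
from pathlib import Path
from typing import List, Tuple


class InMemoryScalarWriter:
    """Hypothetical stand-in that mirrors the documented method names only."""

    def __init__(self, log_dir: Path) -> None:
        self.log_dir = log_dir
        self.records: List[Tuple[str, float, int]] = []
        self.closed = False

    def add_scalar(self, tag: str, value: float, step: int) -> None:
        if self.closed:
            raise RuntimeError("writer is closed")
        self.records.append((tag, float(value), int(step)))

    def flush(self) -> None:
        pass  # nothing is buffered in this in-memory sketch

    def close(self) -> None:
        self.flush()
        self.closed = True


def log_training(writer, rewards) -> None:
    """Log one scalar per iteration, then close the writer even on error."""
    try:
        for step, reward in enumerate(rewards):
            writer.add_scalar("Train/mean_reward", reward, step)
        writer.flush()
    finally:
        writer.close()


writer = InMemoryScalarWriter(Path("logs/demo"))
log_training(writer, [0.1, 0.5, 0.9])
```

Closing in a finally block is a design choice of this sketch: it keeps buffered scalars from being lost if the training loop raises.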

unilab.algos.mlx.ppo.runner.get_latest_run(log_dir)

Find the latest run directory under a task log root.

Parameters:

log_dir (Path)

Return type:

Path | None
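
This page does not say how "latest" is decided for run directories. As a rough sketch of one plausible approach, the hypothetical helper find_latest_run below assumes that each run is a subdirectory of log_dir with a sortable timestamp name (for example 2024-05-01_09-30-00), so the name that sorts last is the newest. It is not the module's actual implementation.

```python
from pathlib import Path
from typing import Optional


def find_latest_run(log_dir: Path) -> Optional[Path]:
    """Hypothetical stand-in: return the run directory whose name sorts last."""
    if not log_dir.is_dir():
        return None
    # Assumption: run directories carry sortable timestamp names,
    # so lexicographic order matches chronological order.
    runs = sorted(
        (p for p in log_dir.iterdir() if p.is_dir()),
        key=lambda p: p.name,
    )
    return runs[-1] if runs else None
```

Returning None for a missing or empty log root matches the documented Path | None return type and lets a caller fall back to starting a fresh run.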

unilab.algos.mlx.ppo.runner.get_latest_checkpoint(run_dir)

Find the latest model_*.safetensors checkpoint in a run directory.

Parameters:

run_dir (Path)

Return type:

Path | None
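
The lookup described above can be sketched in plain Python. This is a minimal illustration, not the module's actual implementation. It assumes checkpoints are named model_<iteration>.safetensors and that "latest" means the highest integer iteration. The helper name find_latest_checkpoint is hypothetical.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming scheme: model_<iteration>.safetensors, e.g. model_100.safetensors.
_CKPT_RE = re.compile(r"^model_(\d+)\.safetensors$")


def find_latest_checkpoint(run_dir: Path) -> Optional[Path]:
    """Hypothetical stand-in: return the highest-iteration checkpoint, or None."""
    if not run_dir.is_dir():
        return None
    best_path: Optional[Path] = None
    best_iter = -1
    for path in run_dir.glob("model_*.safetensors"):
        match = _CKPT_RE.match(path.name)
        if match is None:
            continue  # skip non-numeric names such as model_best.safetensors
        iteration = int(match.group(1))
        if iteration > best_iter:
            best_iter, best_path = iteration, path
    return best_path
```

The sketch compares iterations as integers rather than strings, so model_1000.safetensors ranks above model_200.safetensors.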

class unilab.algos.mlx.ppo.runner.MLXPPOAgent

Bases: object

High-level PPO wrapper that keeps the training script lightweight.

__init__(cfg, obs_dim, action_dim, learning_rate)

property learning_rate: float

update_normalization(obs)
Parameters:

obs (array)

Return type:

None

act(obs)
Parameters:

obs (array)

policy_mean(obs)
Parameters:

obs (array)

Return type:

array

normalize_rewards(rewards)
Parameters:

rewards (array)

Return type:

array

current_action_std(action_shape)
Parameters:

action_shape (tuple[int, ...])

Return type:

array

mean_noise_std()
Return type:

float

update(buffer, last_obs)

load_weights(path)
Parameters:

path (Path)

Return type:

None

save_checkpoint(model_path, trainer_state_path, iteration)
Parameters:
  • model_path (Path)

  • trainer_state_path (Path)

  • iteration (int)

Return type:

None

load_trainer_state(trainer_state_path)
Parameters:

trainer_state_path (Path)

Return type:

int