unilab.algos.mlx.ppo.runner.MLXPPOAgent

class unilab.algos.mlx.ppo.runner.MLXPPOAgent

Bases: object

High-level PPO wrapper that keeps the training script lightweight.

Parameters:
  • cfg

  • obs_dim

  • action_dim

  • learning_rate

Methods

__init__(cfg, obs_dim, action_dim, learning_rate)

act(obs)

current_action_std(action_shape)

load_trainer_state(trainer_state_path)

load_weights(path)

mean_noise_std()

normalize_rewards(rewards)

policy_mean(obs)

save_checkpoint(model_path, ...)

update(buffer, last_obs)

update_normalization(obs)

Attributes

learning_rate

__init__(cfg, obs_dim, action_dim, learning_rate)
Parameters:
  • cfg

  • obs_dim

  • action_dim

  • learning_rate

property learning_rate: float

update_normalization(obs)
Parameters:
Parameters:

obs (array)

Return type:

None
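This page does not describe how update_normalization(obs) tracks its statistics. A common pattern in PPO code is to keep a running per-feature mean and variance of observations. The sketch below illustrates that pattern only. RunningMeanStd and its methods are hypothetical names, not part of unilab, and NumPy stands in for MLX arrays.

```python
import numpy as np


class RunningMeanStd:
    """Hypothetical helper (not part of unilab): running per-feature mean/variance."""

    def __init__(self, dim: int, eps: float = 1e-8) -> None:
        self.mean = np.zeros(dim)
        self.var = np.ones(dim)
        self.count = 0
        self.eps = eps

    def update(self, obs: np.ndarray) -> None:
        # Flatten any leading batch dimensions to (n, dim).
        batch = np.asarray(obs, dtype=np.float64).reshape(-1, self.mean.shape[0])
        n = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)

        # Merge batch statistics into the running ones (parallel variance formula).
        total = self.count + n
        delta = batch_mean - self.mean
        m2 = self.var * self.count + batch_var * n + delta**2 * self.count * n / total
        self.mean = self.mean + delta * n / total
        self.var = m2 / total
        self.count = total

    def normalize(self, obs: np.ndarray) -> np.ndarray:
        # Standardize with the statistics accumulated so far.
        return (np.asarray(obs) - self.mean) / np.sqrt(self.var + self.eps)
```

Updating in batches this way gives the same mean and variance as computing them over all observations seen so far, without storing those observations.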

act(obs)
Parameters:

obs (array)

policy_mean(obs)
Parameters:

obs (array)

Return type:

array

normalize_rewards(rewards)
Parameters:

rewards (array)

Return type:

array

current_action_std(action_shape)
Parameters:

action_shape (tuple[int, ...])

Return type:

array

mean_noise_std()
Return type:

float
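The signatures of current_action_std(action_shape) and mean_noise_std() suggest a Gaussian policy whose exploration noise can be expanded to a batch shape and summarized as one scalar. The sketch below assumes a state-independent log-standard-deviation vector, which this page does not confirm. The function names and the log_std parameter are hypothetical, and NumPy stands in for MLX arrays.

```python
import numpy as np


def action_std_sketch(log_std: np.ndarray, action_shape: tuple[int, ...]) -> np.ndarray:
    """Hypothetical: expand a per-dimension std to the requested action shape."""
    # exp keeps the standard deviation strictly positive.
    return np.broadcast_to(np.exp(log_std), action_shape)


def mean_noise_std_sketch(log_std: np.ndarray) -> float:
    """Hypothetical: average std over action dimensions, handy for logging."""
    return float(np.exp(log_std).mean())


log_std = np.array([0.0, -1.0])          # assumed learnable parameter, one entry per action dim
std = action_std_sketch(log_std, (4, 2)) # batch of 4 environments, 2 action dims
avg = mean_noise_std_sketch(log_std)
```

A scalar summary of this kind is typically logged during training to watch exploration noise shrink as the policy converges.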

update(buffer, last_obs)
Parameters:
  • buffer

  • last_obs

load_weights(path)
Parameters:
Parameters:

path (Path)

Return type:

None

save_checkpoint(model_path, trainer_state_path, iteration)
Parameters:
  • model_path (Path)

  • trainer_state_path (Path)

  • iteration (int)

Return type:

None

load_trainer_state(trainer_state_path)
Parameters:

trainer_state_path (Path)

Return type:

int
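save_checkpoint takes an iteration and load_trainer_state returns an int, which suggests that the returned value is the iteration to resume from. This page does not state that, nor does it give the on-disk format. The sketch below shows one possible round trip under that assumption. The JSON layout and both function names are hypothetical and are not the unilab implementation.

```python
import json
import tempfile
from pathlib import Path


def save_trainer_state_sketch(trainer_state_path: Path, iteration: int, learning_rate: float) -> None:
    """Hypothetical: persist resume metadata in a small JSON file."""
    state = {"iteration": iteration, "learning_rate": learning_rate}
    trainer_state_path.write_text(json.dumps(state))


def load_trainer_state_sketch(trainer_state_path: Path) -> int:
    """Hypothetical: return the iteration stored by the save function above."""
    state = json.loads(trainer_state_path.read_text())
    return int(state["iteration"])


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "trainer_state.json"
    save_trainer_state_sketch(path, iteration=120, learning_rate=3e-4)
    resumed_iteration = load_trainer_state_sketch(path)  # training would continue from here
```

Keeping the trainer state in a file separate from the model weights, as the two path arguments of save_checkpoint imply, lets weights be loaded for evaluation without any optimizer or schedule metadata.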