unilab.algos.mlx.ppo.model

Actor-Critic model for MLX PPO.

Classes

MLPActorCritic

Shared utility class containing actor and critic MLPs.

class unilab.algos.mlx.ppo.model.MLPActorCritic

Bases: Module

Shared utility class containing actor and critic MLPs.

__init__(obs_dim, action_dim, actor_hidden_dims, critic_hidden_dims, activation='tanh', init_log_std=0.0, min_log_std=-5.0, max_log_std=2.0, obs_normalization=False, noise_std_type='log', state_dependent_std=False, dtype=None)

clipped_log_std()

Clamp the log standard deviation to keep it in a numerically safe range (see min_log_std and max_log_std in __init__).

Return type:

array
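
The clamping idea can be illustrated with a minimal, framework-agnostic sketch. It uses plain Python lists rather than MLX arrays, and the helper name clip_log_std is hypothetical; only the default bounds are taken from the __init__ signature. The actual method may differ in detail.

```python
import math


def clip_log_std(log_std, min_log_std=-5.0, max_log_std=2.0):
    """Clamp each log-std entry to [min_log_std, max_log_std].

    Hypothetical helper; the defaults mirror the __init__ signature.
    """
    return [min(max(v, min_log_std), max_log_std) for v in log_std]


raw = [-12.0, 0.0, 7.5]          # unconstrained log-std values
clipped = clip_log_std(raw)      # out-of-range entries move to the bounds
std = [math.exp(v) for v in clipped]  # std stays within [e^-5, e^2]
```

Clamping in log space bounds the standard deviation between exp(min_log_std) and exp(max_log_std), so a drifting parameter can neither collapse the policy to zero noise nor overflow when exponentiated.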

policy(obs)
Parameters:

obs (array)

Return type:

array

distribution_params(obs)
Parameters:

obs (array)

Return type:

tuple[array, array, array]

value(obs)
Parameters:

obs (array)

Return type:

array

update_normalization(obs)
Parameters:

obs (array)

Return type:

None
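
The update rule used for observation normalization is not documented here. As a rough illustration of what a running normalizer typically does, the sketch below tracks the mean and variance of a single scalar feature with Welford's algorithm. The class RunningNorm and its methods are hypothetical and are not part of unilab.

```python
class RunningNorm:
    """Hypothetical running mean/variance tracker for one scalar feature."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x):
        """Fold one new observation into the statistics (Welford's algorithm)."""
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def normalize(self, x, eps=1e-8):
        """Shift and scale x using the population variance seen so far."""
        var = self.m2 / self.count if self.count > 0 else 1.0
        return (x - self.mean) / (var + eps) ** 0.5


norm = RunningNorm()
for x in [1.0, 2.0, 3.0, 4.0]:
    norm.update(x)
```

A batched implementation would keep one mean and variance per observation dimension and merge whole batches at once, but the statistics it maintains are the same.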

act(obs)

Sample actions and return the results as MLX arrays.

Parameters:

obs (array)

Return type:

Tuple[array, array, array, array, array]
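
The contents of the five returned arrays are not documented here, so the sketch below does not try to reproduce them. It only illustrates the usual sampling step of a diagonal Gaussian policy, assuming the policy outputs a mean and a log-std per action dimension. The function sample_diag_gaussian is hypothetical and uses the standard library rather than MLX.

```python
import math
import random


def sample_diag_gaussian(mean, log_std, rng):
    """Sample one action from a diagonal Gaussian and return (action, log_prob).

    Hypothetical helper: mean and log_std are per-dimension lists of floats.
    """
    action = []
    log_prob = 0.0
    for mu, ls in zip(mean, log_std):
        std = math.exp(ls)
        a = mu + std * rng.gauss(0.0, 1.0)  # reparameterized sample
        z = (a - mu) / std
        # Log-density of a univariate normal, summed over dimensions.
        log_prob += -0.5 * z * z - ls - 0.5 * math.log(2.0 * math.pi)
        action.append(a)
    return action, log_prob


rng = random.Random(0)  # seeded for reproducibility
action, log_prob = sample_diag_gaussian([0.0, 1.0], [0.0, -1.0], rng)
```

PPO stores the log-probability of each sampled action so that the later policy update can form the probability ratio between the new and old policies.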

current_action_std(action_shape)

Return the current policy's standard deviation, broadcast to action_shape.

Parameters:

action_shape (tuple[int, ...])

Return type:

array
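
As a rough picture of the broadcasting involved, the sketch below expands a per-dimension log-std vector to a two-dimensional (batch, action_dim) shape using nested lists. The helper broadcast_std is hypothetical, it assumes a state-independent log-std parameterization, and it handles only the two-dimensional case, whereas action_shape may have any number of axes.

```python
import math


def broadcast_std(log_std, action_shape):
    """Expand a per-dimension log-std vector to a (batch, action_dim) grid of stds.

    Hypothetical helper; assumes action_shape has exactly two axes.
    """
    batch, action_dim = action_shape
    if len(log_std) != action_dim:
        raise ValueError("log_std length must match the last axis of action_shape")
    std_row = [math.exp(v) for v in log_std]
    # Every batch row gets its own copy of the same per-dimension stds.
    return [list(std_row) for _ in range(batch)]


std = broadcast_std([0.0, math.log(0.5)], (3, 2))
```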