unilab.algos.mlx.ppo.model
Actor-Critic model for MLX PPO.
Classes
-
class unilab.algos.mlx.ppo.model.MLPActorCritic
Bases: Module
Shared utility class containing actor and critic MLPs.
-
__init__(obs_dim, action_dim, actor_hidden_dims, critic_hidden_dims, activation='tanh', init_log_std=0.0, min_log_std=-5.0, max_log_std=2.0, obs_normalization=False, noise_std_type='log', state_dependent_std=False, dtype=None)
-
clipped_log_std()
Clamp log-std to avoid numerical explosion.
- Return type:
array
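The clamping idea can be sketched in a few lines. This is a minimal illustration, not the library's implementation: it assumes the bounds are the `min_log_std` and `max_log_std` constructor defaults (-5.0 and 2.0), and it uses plain Python lists instead of MLX arrays so it runs anywhere. The helper name `clip_log_std` is hypothetical.

```python
import math


def clip_log_std(log_std, min_log_std=-5.0, max_log_std=2.0):
    """Clamp every log-std entry to [min_log_std, max_log_std].

    Hypothetical helper: the bounds mirror the __init__ defaults, but
    whether the real method uses exactly these bounds is an assumption.
    """
    return [min(max(v, min_log_std), max_log_std) for v in log_std]


raw = [-20.0, 0.0, 50.0]
clipped = clip_log_std(raw)

# Without the clamp, exp(50.0) is about 5e21, which would swamp the policy mean.
stds = [math.exp(v) for v in clipped]
```

With the default bounds, the resulting standard deviation always stays between exp(-5.0) and exp(2.0), so neither the sampling noise nor the log-probability can overflow.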
-
policy(obs)
- Parameters:
obs (array)
- Return type:
array
-
distribution_params(obs)
- Parameters:
obs (array)
- Return type:
tuple[array, array, array]
-
value(obs)
- Parameters:
obs (array)
- Return type:
array
-
update_normalization(obs)
- Parameters:
obs (array)
- Return type:
None
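The page does not say how the normalization statistics are maintained. One common approach, shown here purely as an assumption, is a running mean and variance updated with Welford's algorithm. The `RunningNorm` class and its methods are hypothetical names, and the sketch handles scalar observations in plain Python rather than MLX arrays.

```python
import math


class RunningNorm:
    """Hypothetical running mean/variance tracker (Welford's algorithm)."""

    def __init__(self, eps=1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean
        self.eps = eps

    def update(self, batch):
        """Fold a batch of scalar observations into the statistics."""
        for x in batch:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

    def normalize(self, x):
        """Return x shifted and scaled by the current statistics."""
        var = self.m2 / self.count if self.count > 0 else 1.0
        return (x - self.mean) / math.sqrt(var + self.eps)


norm = RunningNorm()
norm.update([1.0, 2.0, 3.0, 4.0])
centered = norm.normalize(2.5)
```

Updating incrementally avoids storing past observations, which matters when the statistics are refreshed on every rollout.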
-
act(obs)
Sample actions and return MLX tensors.
- Parameters:
obs (array)
- Return type:
tuple[array, array, array, array, array]
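The page does not name the five returned arrays. As a rough sketch of the sampling step, assuming a diagonal Gaussian policy parameterized by a mean and a log-std, an action and its log-probability could be computed as below. The function names are hypothetical, and plain Python lists stand in for MLX arrays.

```python
import math
import random


def gaussian_log_prob(action, mean, log_std):
    """Log-density of a diagonal Gaussian, summed over action dimensions."""
    total = 0.0
    for a, m, ls in zip(action, mean, log_std):
        z = (a - m) / math.exp(ls)
        total += -0.5 * z * z - ls - 0.5 * math.log(2.0 * math.pi)
    return total


def sample_action(mean, log_std, rng):
    """Draw action = mean + std * noise and return it with its log-prob."""
    action = [m + math.exp(ls) * rng.gauss(0.0, 1.0)
              for m, ls in zip(mean, log_std)]
    return action, gaussian_log_prob(action, mean, log_std)


rng = random.Random(0)  # seeded so the sketch is reproducible
action, log_prob = sample_action([0.0, 1.0], [0.0, -1.0], rng)
```

PPO typically stores the log-probability at sampling time so that the probability ratio against the updated policy can be formed later without resampling.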
-
current_action_std(action_shape)
Return the standard deviation of the current policy, broadcast to action_shape.
- Parameters:
action_shape (tuple[int, ...])
- Return type:
array
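To illustrate what broadcasting means here, the sketch below repeats a per-dimension std vector across a batch. It assumes a state-independent std (`state_dependent_std=False`) and a two-dimensional `(batch, action_dim)` shape. Both are assumptions, and `broadcast_std` is a hypothetical helper written with plain Python lists rather than MLX arrays.

```python
import math


def broadcast_std(log_std, action_shape):
    """Repeat exp(log_std) along the batch axis of a (batch, action_dim) shape."""
    batch, action_dim = action_shape
    if len(log_std) != action_dim:
        raise ValueError("log_std length must match the last axis of action_shape")
    std = [math.exp(v) for v in log_std]
    # Copy the row for each batch entry so the rows are independent lists.
    return [list(std) for _ in range(batch)]


std_matrix = broadcast_std([0.0, 0.0], (3, 2))
```

An array library would normally express this as a broadcast rather than an explicit copy. The loop is only there to make the resulting shape visible.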