unilab.algos.mlx.common

MLX RL base modules.

This package contains framework-level building blocks that are reused by algorithm implementations (e.g. PPO).

class unilab.algos.mlx.common.MLP

Bases: Module

Simple feed-forward MLP with configurable activations.

__init__(input_dim, output_dim, hidden_dims, activation='elu', last_activation=None)

init_orthogonal(hidden_gain=1.4142135623730951, output_gain=1.0)

Orthogonally initialize the linear layers, using hidden_gain (default √2) for the hidden layers and a separate output_gain for the output layer.

Return type:

None
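
For illustration, here is a stdlib-only sketch of orthogonal initialization with a separate output gain. It assumes hidden_gain is applied to every linear layer except the last and output_gain to the last one, and it stores each weight as an (out, in) matrix. The helpers orthogonal_matrix and init_mlp_weights are hypothetical and are not part of unilab or MLX; the library method presumably works on MLX arrays rather than nested lists.

```python
import math
import random


def orthogonal_matrix(rows, cols, gain=1.0, seed=None):
    """Return a rows x cols matrix (list of lists) whose rows (if rows <= cols)
    or columns (if rows > cols) are orthonormal, then scaled by `gain`."""
    rng = random.Random(seed)
    n_vec, dim = min(rows, cols), max(rows, cols)
    basis = []
    while len(basis) < n_vec:
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        # Modified Gram-Schmidt: remove the components along already accepted vectors.
        for b in basis:
            proj = sum(vi * bi for vi, bi in zip(v, b))
            v = [vi - proj * bi for vi, bi in zip(v, b)]
        norm = math.sqrt(sum(vi * vi for vi in v))
        if norm < 1e-8:
            continue  # numerically degenerate draw; sample again
        basis.append([vi / norm for vi in v])
    if rows <= cols:
        mat = basis  # `rows` orthonormal vectors of length `cols`
    else:
        mat = [list(r) for r in zip(*basis)]  # transpose (cols x rows) -> (rows x cols)
    return [[gain * x for x in row] for row in mat]


def init_mlp_weights(layer_dims, hidden_gain=math.sqrt(2.0), output_gain=1.0, seed=0):
    """Hypothetical helper: one (out, in) weight matrix per linear layer.
    Every layer uses hidden_gain except the last, which uses output_gain."""
    weights = []
    num_layers = len(layer_dims) - 1
    for i in range(num_layers):
        d_in, d_out = layer_dims[i], layer_dims[i + 1]
        gain = output_gain if i == num_layers - 1 else hidden_gain
        weights.append(orthogonal_matrix(d_out, d_in, gain=gain, seed=seed + i))
    return weights
```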

class unilab.algos.mlx.common.EmpiricalNormalization

Bases: object

Normalize features using a running mean and standard deviation computed over the batch axis.

__init__(shape, eps=0.01, dtype=None)

update(x)
Parameters:

x (array)

Return type:

None
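
A minimal sketch of how running mean/std statistics over the batch axis can be maintained, written with plain Python lists. RunningMeanStd is a hypothetical class, not EmpiricalNormalization itself. The update uses the standard rule for merging the statistics of all samples seen so far with those of a new batch; adding eps to the std in normalize is an assumption.

```python
import math


class RunningMeanStd:
    """Hypothetical stand-in: per-feature running mean/variance over the batch axis.
    A batch is a list of rows, each row a list of `dim` floats."""

    def __init__(self, dim, eps=0.01):
        self.eps = eps
        self.count = 0
        self.mean = [0.0] * dim
        self.var = [1.0] * dim  # overwritten by the first update (rate == 1)

    def update(self, batch):
        n_b = len(batch)
        if n_b == 0:
            return
        self.count += n_b
        rate = n_b / self.count  # weight of the new batch in the merged statistics
        for j in range(len(self.mean)):
            col = [row[j] for row in batch]
            m_b = sum(col) / n_b
            v_b = sum((c - m_b) ** 2 for c in col) / n_b  # population variance of the batch
            delta = m_b - self.mean[j]
            self.mean[j] += rate * delta
            # Merge the (count, mean, var) summaries of old data and new batch.
            self.var[j] += rate * (v_b - self.var[j] + delta * (m_b - self.mean[j]))

    def normalize(self, x):
        # Assumption: eps is added to the std to avoid dividing by ~0.
        return [(xi - m) / (math.sqrt(v) + self.eps) for xi, m, v in zip(x, self.mean, self.var)]
```

Feeding the data in two batches gives the same mean and variance as computing them over all rows at once, which is what makes this usable with streaming rollouts.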

class unilab.algos.mlx.common.EmpiricalDiscountedVariationNormalization

Bases: object

Reward normalization using the running standard deviation of discounted returns.

__init__(eps=0.01, gamma=0.99, dtype=None)

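
A hedged sketch of the idea behind this class: each environment keeps a discounted return G_t = gamma * G_{t-1} + r_t, a running variance is updated on those returns, and each reward is divided by the resulting std. DiscountedReturnNormalizer and its num_envs argument are hypothetical; adding eps to the std is an assumption about how eps is used.

```python
import math


class DiscountedReturnNormalizer:
    """Hypothetical sketch (not the library class): scale rewards by the running
    std of per-environment discounted returns G_t = gamma * G_{t-1} + r_t."""

    def __init__(self, num_envs, gamma=0.99, eps=0.01):
        self.gamma = gamma
        self.eps = eps
        self.returns = [0.0] * num_envs
        self.count = 0
        self.mean = 0.0
        self.var = 1.0  # overwritten by the first update

    def _update_stats(self, values):
        n_b = len(values)
        m_b = sum(values) / n_b
        v_b = sum((v - m_b) ** 2 for v in values) / n_b
        self.count += n_b
        rate = n_b / self.count
        delta = m_b - self.mean
        self.mean += rate * delta
        self.var += rate * (v_b - self.var + delta * (m_b - self.mean))

    def __call__(self, rewards):
        # Accumulate the discounted return per environment, then update the stats on it.
        self.returns = [self.gamma * g + r for g, r in zip(self.returns, rewards)]
        self._update_stats(self.returns)
        std = math.sqrt(self.var)
        # Scale only: the mean is not subtracted, so each reward keeps its sign.
        return [r / (std + self.eps) for r in rewards]
```

Some implementations also reset G when an episode ends; this sketch does not.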
class unilab.algos.mlx.common.RolloutBuffer

Bases: object

On-policy rollout storage for vectorized environments.

Parameters:
num_steps: int
num_envs: int
obs_dim: int
action_dim: int
gamma: float
lam: float
dtype: Any | None = None
add(obs, actions, log_probs, action_mean, action_std, rewards, dones, values)
Parameters:
  • obs (array)

  • actions (array)

  • log_probs (array)

  • action_mean (array)

  • action_std (array)

  • rewards (array)

  • dones (array)

  • values (array)

Return type:

None

compute_returns_and_advantages(last_values)
Parameters:

last_values (array)

Return type:

None
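
Because the buffer is constructed with gamma and lam, compute_returns_and_advantages most likely implements Generalized Advantage Estimation (GAE); that is an assumption. The sketch below shows GAE over [num_steps][num_envs] nested lists, with last_values bootstrapping the value after the final step and dones[t][e] = 1 meaning that the episode in environment e ended at step t. compute_gae is a hypothetical helper.

```python
def compute_gae(rewards, values, dones, last_values, gamma, lam):
    """GAE over [num_steps][num_envs] nested lists (hypothetical helper).

    delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
    A_t     = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
    returns = A_t + V_t
    """
    num_steps, num_envs = len(rewards), len(rewards[0])
    advantages = [[0.0] * num_envs for _ in range(num_steps)]
    next_adv = [0.0] * num_envs
    next_values = list(last_values)  # bootstrap values for the state after the last step
    for t in reversed(range(num_steps)):
        for e in range(num_envs):
            not_done = 1.0 - dones[t][e]  # cut the recursion at episode boundaries
            delta = rewards[t][e] + gamma * next_values[e] * not_done - values[t][e]
            next_adv[e] = delta + gamma * lam * not_done * next_adv[e]
            advantages[t][e] = next_adv[e]
        next_values = values[t]
    returns = [[a + v for a, v in zip(a_row, v_row)] for a_row, v_row in zip(advantages, values)]
    return advantages, returns
```

Libraries differ in whether a stored done flag refers to the current or the next step, so check that convention before comparing numbers against this sketch.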

mini_batch_generator(num_mini_batches, num_epochs)
Parameters:
  • num_mini_batches (int)

  • num_epochs (int)

Return type:

Generator[Dict[str, array], None, None]
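
One common way such a generator works, sketched under assumptions: the time and environment axes are flattened into num_steps * num_envs samples, every epoch reshuffles them, and each epoch yields num_mini_batches dictionaries. iterate_mini_batches is hypothetical and uses Python lists instead of MLX arrays; leftover samples are dropped when the count is not divisible by num_mini_batches.

```python
import random


def iterate_mini_batches(flat, num_mini_batches, num_epochs, seed=0):
    """Hypothetical sketch: `flat` maps field names to lists of num_steps * num_envs samples."""
    rng = random.Random(seed)
    batch_size = len(next(iter(flat.values())))
    mini_batch_size = batch_size // num_mini_batches
    for _ in range(num_epochs):
        perm = list(range(batch_size))
        rng.shuffle(perm)  # fresh shuffle every epoch
        for i in range(num_mini_batches):
            idx = perm[i * mini_batch_size:(i + 1) * mini_batch_size]
            # Index every field with the same indices so samples stay aligned across keys.
            yield {key: [values[j] for j in idx] for key, values in flat.items()}
```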

clear()
Return type:

None

__init__(num_steps, num_envs, obs_dim, action_dim, gamma, lam, dtype=None)
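
As an illustration of the add/clear lifecycle, the sketch below preallocates num_steps rows per field, writes one row (one value per environment) on each add call, and resets only the write pointer on clear. TinyRolloutStorage and its field_names argument are hypothetical; this is not RolloutBuffer's actual layout.

```python
class TinyRolloutStorage:
    """Hypothetical fixed-capacity on-policy storage (not RolloutBuffer itself)."""

    def __init__(self, num_steps, num_envs, field_names):
        self.num_steps = num_steps
        self.num_envs = num_envs
        # Preallocate one [num_steps][num_envs] table per field.
        self.data = {name: [[0.0] * num_envs for _ in range(num_steps)] for name in field_names}
        self.step = 0  # index of the next row to write

    def add(self, **values):
        if self.step >= self.num_steps:
            raise RuntimeError("buffer is full; call clear() after the policy update")
        for name, row in values.items():
            if len(row) != self.num_envs:
                raise ValueError(f"{name}: expected {self.num_envs} values, got {len(row)}")
            self.data[name][self.step] = list(row)  # KeyError for unknown field names
        self.step += 1

    def clear(self):
        # The next rollout overwrites every row, so resetting the pointer is enough.
        self.step = 0
```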
unilab.algos.mlx.common.diag_gaussian_log_prob(actions, mean, log_std)

Log-probability of actions under a diagonal Gaussian parameterized by mean and log_std.

Parameters:
  • actions (array)

  • mean (array)

  • log_std (array)

Return type:

array
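
A diagonal Gaussian is a product of independent 1-D Gaussians, so its log-density is a sum over action dimensions of -0.5 * ((a - mu) / sigma)^2 - log(sigma) - 0.5 * log(2*pi), with sigma = exp(log_std). The per-vector sketch below assumes the result is summed over the action dimension (the reduction axis of the library function is not stated here); diag_gaussian_log_prob_sketch is a hypothetical name.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def diag_gaussian_log_prob_sketch(actions, mean, log_std):
    """log N(a; mu, diag(sigma^2)) for one action vector, summed over dimensions."""
    total = 0.0
    for a, mu, ls in zip(actions, mean, log_std):
        z = (a - mu) / math.exp(ls)  # standardized action component
        total += -0.5 * z * z - ls - 0.5 * LOG_2PI
    return total
```

Working with log_std rather than std keeps sigma positive without constraints and avoids taking the log of a possibly tiny std.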

unilab.algos.mlx.common.diag_gaussian_entropy(log_std)

Entropy of a diagonal Gaussian parameterized by log_std.

Parameters:

log_std (array)

Return type:

array
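
For a diagonal Gaussian the entropy does not depend on the mean: each dimension contributes 0.5 * (1 + log(2*pi)) + log_std. A sketch that sums over dimensions (an assumption about the reduction; diag_gaussian_entropy_sketch is a hypothetical name):

```python
import math


def diag_gaussian_entropy_sketch(log_std):
    """Differential entropy of N(mu, diag(sigma^2)), summed over dimensions."""
    per_dim_const = 0.5 * (1.0 + math.log(2.0 * math.pi))
    return sum(per_dim_const + ls for ls in log_std)
```

In PPO-style training this quantity is typically used as an entropy bonus that discourages the policy std from collapsing too early.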

Modules

activations

Activation helpers for MLX models.

distributions

Distribution utilities for RL policies.

mlp

MLP module used by MLX RL algorithms.

normalization

Running-stat normalization utilities for MLX RL.

rollout_storage

Rollout buffer for on-policy algorithms.

rotation