unilab.algos.mlx.common.normalization

Running-statistics normalization utilities for reinforcement learning (RL) with MLX.

Classes

EmpiricalDiscountedVariationNormalization

Reward normalization using a running standard deviation of discounted returns.

EmpiricalNormalization

Normalize features using a running mean/std over the batch axis.

class unilab.algos.mlx.common.normalization.EmpiricalNormalization

Bases: object

Normalize features using a running mean/std over the batch axis.
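A common way to fold a new batch into running statistics is the parallel-variance merge (Chan et al.), sketched here under stated assumptions: N is the number of samples seen so far, m is the batch size, \bar{x}_B and s_B^2 are the batch mean and population variance along the batch axis, and \epsilon is added to the standard deviation at normalization time. Whether this class uses exactly this form, or applies eps differently, is an assumption rather than something this page states.

```latex
\begin{aligned}
N' &= N + m, \qquad \rho = \frac{m}{N'}, \qquad \delta = \bar{x}_B - \mu \\
\mu' &= \mu + \rho\,\delta \\
\sigma'^2 &= (1-\rho)\,\sigma^2 + \rho\, s_B^2 + \rho\,(1-\rho)\,\delta^2 \\
\hat{x} &= \frac{x - \mu'}{\sigma' + \epsilon}
\end{aligned}
```

Because the merge is exact, feeding the data in any sequence of batches yields the same mean and variance as computing them over all samples at once.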

__init__(shape, eps=0.01, dtype=None)
update(x)

Update the running mean/std with a new batch x, reducing over the batch axis.
Parameters:

x (array)

Return type:

None
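The following is a minimal sketch of a batch-axis running normalizer, assuming axis 0 is the batch axis, `shape` is the per-sample feature shape, and `eps` is added to the standard deviation when normalizing. The class name `RunningMeanStd` and its `normalize` method are hypothetical, not the unilab API. NumPy stands in for `mlx.core` so the example runs anywhere; the same reductions (`mean`, `var`, `sqrt`) exist in MLX.

```python
import numpy as np


class RunningMeanStd:
    """Hypothetical sketch of a running mean/std normalizer (not the unilab class).

    Statistics are accumulated over axis 0, assumed to be the batch axis;
    `shape` is the per-sample feature shape (batch axis excluded).
    """

    def __init__(self, shape, eps=0.01, dtype=np.float32):
        self.eps = eps  # assumption: added to the std when normalizing
        self.count = 0  # number of samples seen so far
        self.mean = np.zeros(shape, dtype=dtype)
        self.var = np.ones(shape, dtype=dtype)  # overwritten by the first update

    def update(self, x):
        """Merge the statistics of batch `x` (shape: (m, *shape)) into the running ones."""
        x = np.asarray(x, dtype=self.mean.dtype)
        m = x.shape[0]
        if m == 0:
            return
        batch_mean = x.mean(axis=0)
        batch_var = x.var(axis=0)  # population variance (ddof=0)
        self.count += m
        rate = m / self.count  # weight of the new batch
        delta = batch_mean - self.mean
        self.mean = self.mean + rate * delta
        # Incremental form of the parallel-variance merge:
        # var' = (1 - rate) * var + rate * batch_var + rate * (1 - rate) * delta**2
        self.var = self.var + rate * (batch_var - self.var + delta * (batch_mean - self.mean))

    @property
    def std(self):
        return np.sqrt(self.var)

    def normalize(self, x):
        """Standardize `x` with the current running statistics."""
        return (np.asarray(x, dtype=self.mean.dtype) - self.mean) / (self.std + self.eps)


# Usage: streaming 10 batches reproduces the full-data statistics.
rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=2.0, size=(1000, 4))
norm = RunningMeanStd(shape=(4,))
for batch in np.split(data, 10):
    norm.update(batch)
print(np.allclose(norm.mean, data.mean(axis=0), atol=1e-4))  # True
```

Keeping a sample count, rather than an exponential moving average, weights every sample equally, so early batches are never forgotten.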

class unilab.algos.mlx.common.normalization.EmpiricalDiscountedVariationNormalization

Bases: object

Reward normalization using a running standard deviation of discounted returns.

__init__(eps=0.01, gamma=0.99, dtype=None)

gamma is the discount factor used to accumulate the returns whose running standard deviation is tracked.
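Here is a minimal sketch of reward scaling by the running std of discounted returns, assuming rewards arrive as a 1-D batch with one entry per parallel environment, the return follows R_t = gamma * R_{t-1} + r_t, and `eps` is added to the std. The class `DiscountedRewardNormalizer` and its optional `dones` argument (which resets returns at episode boundaries) are hypothetical, not the unilab API. NumPy stands in for `mlx.core`.

```python
import math

import numpy as np


class DiscountedRewardNormalizer:
    """Hypothetical sketch: divide rewards by a running std of discounted returns.

    Rewards are assumed to arrive as a 1-D batch, one entry per environment.
    """

    def __init__(self, eps=0.01, gamma=0.99, dtype=np.float32):
        self.eps = eps  # assumption: added to the std before dividing
        self.gamma = gamma
        self.dtype = dtype
        self.ret = None  # per-environment discounted return R_t
        self.count = 0
        self.mean = 0.0
        self.var = 1.0  # overwritten by the first update

    def _update_stats(self, x):
        # Parallel-variance merge of the batch of returns into the running stats.
        m = x.shape[0]
        batch_mean = float(x.mean())
        batch_var = float(x.var())  # population variance
        self.count += m
        rate = m / self.count
        delta = batch_mean - self.mean
        self.mean += rate * delta
        self.var += rate * (batch_var - self.var + delta * (batch_mean - self.mean))

    def __call__(self, rew, dones=None):
        rew = np.asarray(rew, dtype=self.dtype)
        # R_t = gamma * R_{t-1} + r_t
        self.ret = rew.copy() if self.ret is None else self.gamma * self.ret + rew
        self._update_stats(self.ret)
        # Rescale only (no mean subtraction), so each reward keeps its sign.
        out = rew / (math.sqrt(self.var) + self.eps)
        if dones is not None:
            # Assumption: restart the return accumulator when an episode ends.
            done_mask = np.asarray(dones, dtype=bool)
            self.ret = np.where(done_mask, np.zeros_like(self.ret), self.ret)
        return out


# Usage with 8 hypothetical parallel environments.
rng = np.random.default_rng(0)
norm = DiscountedRewardNormalizer(gamma=0.99)
for _ in range(100):
    scaled = norm(rng.normal(size=8), dones=rng.random(8) < 0.01)
```

Dividing by the std of the return, rather than of the per-step reward, keeps the scale of value targets roughly stable, and skipping the mean shift preserves the sign of every reward.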