Source code for unilab.algos.mlx.common.normalization
"""Running-stat normalization utilities for MLX RL."""
from __future__ import annotations
from typing import Any
import mlx.core as mx
[docs]
class EmpiricalNormalization:
"""Normalize features using running mean/std over batch axis."""
[docs]
def __init__(self, shape: int, eps: float = 1e-2, dtype: Any | None = None) -> None:
self.eps = float(eps)
self.dtype = mx.float32 if dtype is None else dtype
self.mean = mx.zeros((1, shape), dtype=self.dtype)
self.var = mx.ones((1, shape), dtype=self.dtype)
self.std = mx.ones((1, shape), dtype=self.dtype)
self.count = mx.array(0.0, dtype=self.dtype)
[docs]
def __call__(self, x: mx.array) -> mx.array:
return (x - self.mean) / (self.std + self.eps)
[docs]
def update(self, x: mx.array) -> None:
x = mx.array(x, dtype=self.dtype)
batch_count = mx.array(float(x.shape[0]), dtype=self.dtype)
batch_mean = mx.mean(x, axis=0, keepdims=True)
batch_var = mx.var(x, axis=0, keepdims=True)
total = self.count + batch_count
rate = batch_count / (total + 1e-8)
delta = batch_mean - self.mean
self.mean = self.mean + rate * delta
self.var = self.var + rate * (batch_var - self.var + delta * (batch_mean - self.mean))
self.std = mx.sqrt(mx.maximum(self.var, 1e-8))
self.count = total
mx.eval(self.mean, self.var, self.std, self.count)
[docs]
class EmpiricalDiscountedVariationNormalization:
"""Reward normalization with running std of discounted returns."""
[docs]
def __init__(self, eps: float = 1e-2, gamma: float = 0.99, dtype: Any | None = None) -> None:
self.dtype = mx.float32 if dtype is None else dtype
self.emp_norm = EmpiricalNormalization(shape=1, eps=eps, dtype=self.dtype)
self.gamma = float(gamma)
self.avg: mx.array | None = None
[docs]
def __call__(self, rew: mx.array) -> mx.array:
"""Normalize reward tensor of shape [N] or [N, 1]."""
if rew.ndim == 1:
rew = mx.expand_dims(rew, axis=-1)
rew = mx.array(rew, dtype=self.dtype)
if self.avg is None:
self.avg = rew
else:
self.avg = self.avg * self.gamma + rew
avg: mx.array = self.avg
self.emp_norm.update(avg)
return rew / (self.emp_norm.std + self.emp_norm.eps)