# Source code for unilab.algos.mlx.common.distributions

"""Distribution utilities for RL policies."""

from __future__ import annotations

import math

import mlx.core as mx


def diag_gaussian_log_prob(actions: mx.array, mean: mx.array, log_std: mx.array) -> mx.array:
    """Evaluate the log-density of ``actions`` under a diagonal Gaussian.

    Each dimension contributes
    ``-0.5 * ((a - mu)^2 / sigma^2 + 2 * log(sigma) + log(2 * pi))``,
    and the contributions are summed over the last axis.

    Args:
        actions: Points at which the density is evaluated.
        mean: Mean of the Gaussian.
        log_std: Natural log of the per-dimension standard deviation.

    Returns:
        The per-dimension log-densities summed over the last axis.

    Note:
        The three inputs are combined elementwise, so their shapes have to
        be compatible for elementwise arithmetic.
    """
    # sigma^2 = exp(2 * log(sigma)); working from log_std keeps it positive.
    variance = mx.exp(2.0 * log_std)
    squared_error = (actions - mean) ** 2
    log_two_pi = math.log(2.0 * math.pi)

    per_dim_log_prob = -0.5 * (squared_error / variance + 2.0 * log_std + log_two_pi)
    return mx.sum(per_dim_log_prob, axis=-1)
def diag_gaussian_entropy(log_std: mx.array) -> mx.array:
    """Compute the differential entropy of a diagonal Gaussian.

    Each dimension contributes ``log(sigma) + 0.5 * (1 + log(2 * pi))``;
    the contributions are summed over the last axis.

    Args:
        log_std: Natural log of the per-dimension standard deviation.

    Returns:
        The entropy, i.e. the per-dimension terms summed over the last axis.
    """
    # Part of the per-dimension entropy that does not depend on log_std.
    per_dim_constant = 0.5 * (1.0 + math.log(2.0 * math.pi))
    per_dim_entropy = log_std + per_dim_constant
    return mx.sum(per_dim_entropy, axis=-1)