Source code for unilab.algos.mlx.common.mlp

"""MLP module used by MLX RL algorithms."""

from __future__ import annotations

import math
from typing import Sequence

import mlx.core as mx
import mlx.nn as nn

from .activations import get_activation


class MLP(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers with configurable activations.

    One activation is applied after every layer except the last; the last
    layer's output is optionally passed through a separate activation.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: Sequence[int],
        activation: str = "elu",
        last_activation: str | None = None,
    ) -> None:
        """Build the linear layers and resolve the activation callables.

        Args:
            input_dim: Width of the input features (coerced with ``int``).
            output_dim: Width of the final layer's output (coerced with ``int``).
            hidden_dims: Widths of the intermediate layers, in order. May be
                empty, in which case the network is a single linear layer.
            activation: Name passed to ``get_activation`` for the activation
                applied after every non-final layer.
            last_activation: Name passed to ``get_activation`` for the
                activation applied after the final layer, or ``None`` to
                leave the final output untouched.
        """
        super().__init__()

        # Full chain of layer widths: input -> hidden... -> output.
        widths = [int(input_dim), *(int(h) for h in hidden_dims), int(output_dim)]

        # Each consecutive pair of widths defines one linear layer.
        self.layers = [
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ]

        self.activation = get_activation(activation)
        if last_activation is None:
            self.last_activation = None
        else:
            self.last_activation = get_activation(last_activation)

    def __call__(self, x: mx.array) -> mx.array:
        """Run ``x`` through every layer and return the final output.

        Args:
            x: Input array fed to the first linear layer.

        Returns:
            The output of the last linear layer, passed through
            ``last_activation`` when one was configured.
        """
        final_idx = len(self.layers) - 1
        for position, linear in enumerate(self.layers):
            x = linear(x)
            if position != final_idx:
                x = self.activation(x)
            elif self.last_activation is not None:
                x = self.last_activation(x)
        return x

    def init_orthogonal(
        self, hidden_gain: float = math.sqrt(2.0), output_gain: float = 1.0
    ) -> None:
        """Re-initialize every linear layer orthogonally and zero its bias.

        Args:
            hidden_gain: Gain given to ``nn.init.orthogonal`` for every layer
                except the last.
            output_gain: Gain given to ``nn.init.orthogonal`` for the last
                layer.
        """
        final_idx = len(self.layers) - 1
        for position, linear in enumerate(self.layers):
            # Only the last layer uses the separate output gain.
            gain = hidden_gain if position != final_idx else output_gain
            initializer = nn.init.orthogonal(gain=gain)
            linear.weight = initializer(linear.weight)
            linear.bias = mx.zeros_like(linear.bias)