Source code for unilab.algos.mlx.common.activations

"""Activation helpers for MLX models."""

from __future__ import annotations

from typing import Callable

import mlx.core as mx


[docs] def get_activation(name: str | None) -> Callable[[mx.array], mx.array]: """Resolve a string activation name to a callable.""" if name is None: return lambda x: x name = name.lower() if name == "relu": return lambda x: mx.maximum(x, 0.0) if name == "elu": return lambda x: mx.where(x > 0.0, x, mx.exp(x) - 1.0) if name == "tanh": return lambda x: mx.tanh(x) if name == "sigmoid": return lambda x: mx.sigmoid(x) if name == "swish": return lambda x: x * mx.sigmoid(x) if name == "identity": return lambda x: x raise ValueError(f"Unsupported activation: {name}")