Source code for unilab.algos.mlx.common.rotation
from __future__ import annotations
import importlib
def _require_mlx_core():
"""Import MLX lazily so non-MLX workflows don't crash at module import time."""
try:
return importlib.import_module("mlx.core")
except Exception as exc:
raise RuntimeError(
"MLX backend is unavailable. Install the MLX extra to use MLX rotation helpers."
) from exc
[docs]
def quat_mul(q1, q2):
"""Multiply two MLX quaternion batches."""
mx = _require_mlx_core()
w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3]
w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3]
return mx.stack(
[
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
],
axis=1,
)
[docs]
def axis_angle_to_quat(axis, angle):
"""Convert MLX axis-angle batches to quaternions."""
mx = _require_mlx_core()
half_angle = angle / 2
c = mx.cos(half_angle)
s = mx.sin(half_angle)
return mx.stack([c, axis[:, 0] * s, axis[:, 1] * s, axis[:, 2] * s], axis=1)