Source code for unilab.visualization.viser_scene

# pyright: reportMissingImports=false
"""MuJoCo-to-viser scene adapter for interactive web-based 3D visualization.

This module renders MuJoCo scenes via a viser web server, providing browser-based
interactive 3D viewing without requiring a local display or GLFW.  It is gated
behind the ``viser`` optional-dependency group and is **not** imported by default.

Usage (from ``scripts/play_viser.py``)::

    from unilab.visualization.viser_scene import MujocoViserScene, VISER_AVAILABLE
"""

from __future__ import annotations

import math
from typing import Any

import mujoco
import numpy as np

try:
    import trimesh
    import viser

    VISER_AVAILABLE = True
except ImportError:
    VISER_AVAILABLE = False


# --------------------------------------------------------------------------- #
# Rotation helpers (pure numpy, no scipy dependency)                          #
# --------------------------------------------------------------------------- #


def _rotmat_to_wxyz(mat: np.ndarray) -> tuple[float, float, float, float]:
    """Convert a 3x3 rotation matrix to a (w, x, y, z) quaternion."""
    m = np.asarray(mat, dtype=np.float64).reshape(3, 3)
    trace = m[0, 0] + m[1, 1] + m[2, 2]

    if trace > 0:
        s = 0.5 / math.sqrt(trace + 1.0)
        w = 0.25 / s
        x = (m[2, 1] - m[1, 2]) * s
        y = (m[0, 2] - m[2, 0]) * s
        z = (m[1, 0] - m[0, 1]) * s
    elif m[0, 0] > m[1, 1] and m[0, 0] > m[2, 2]:
        s = 2.0 * math.sqrt(1.0 + m[0, 0] - m[1, 1] - m[2, 2])
        w = (m[2, 1] - m[1, 2]) / s
        x = 0.25 * s
        y = (m[0, 1] + m[1, 0]) / s
        z = (m[0, 2] + m[2, 0]) / s
    elif m[1, 1] > m[2, 2]:
        s = 2.0 * math.sqrt(1.0 + m[1, 1] - m[0, 0] - m[2, 2])
        w = (m[0, 2] - m[2, 0]) / s
        x = (m[0, 1] + m[1, 0]) / s
        y = 0.25 * s
        z = (m[1, 2] + m[2, 1]) / s
    else:
        s = 2.0 * math.sqrt(1.0 + m[2, 2] - m[0, 0] - m[1, 1])
        w = (m[1, 0] - m[0, 1]) / s
        x = (m[0, 2] + m[2, 0]) / s
        y = (m[1, 2] + m[2, 1]) / s
        z = 0.25 * s

    return (float(w), float(x), float(y), float(z))


# --------------------------------------------------------------------------- #
# Geometry extraction helpers                                                 #
# --------------------------------------------------------------------------- #


def _rgba_to_color(rgba: np.ndarray) -> tuple[int, int, int]:
    """Convert MuJoCo float RGBA [0,1] to viser int RGB [0,255]."""
    return (
        int(np.clip(rgba[0] * 255, 0, 255)),
        int(np.clip(rgba[1] * 255, 0, 255)),
        int(np.clip(rgba[2] * 255, 0, 255)),
    )


def _rgba_to_opacity(rgba: np.ndarray) -> float:
    return float(np.clip(rgba[3], 0.0, 1.0))


def _extract_mesh(model: mujoco.MjModel, geom_dataid: int) -> tuple[np.ndarray, np.ndarray]:
    """Extract vertices and faces for a MuJoCo mesh geom."""
    vert_adr = model.mesh_vertadr[geom_dataid]
    vert_num = model.mesh_vertnum[geom_dataid]
    face_adr = model.mesh_faceadr[geom_dataid]
    face_num = model.mesh_facenum[geom_dataid]

    vertices = model.mesh_vert[vert_adr : vert_adr + vert_num].copy()
    faces = model.mesh_face[face_adr : face_adr + face_num].copy()
    return vertices, faces


[docs] def build_visible_env_indices(num_envs: int, visible_envs: int) -> np.ndarray: """Select a stable subset of env indices spread across the full batch. Args: num_envs: Total number of runtime environments. visible_envs: Number of env slots exposed in the viewer. Returns: A monotonically increasing array of runtime env indices. """ if visible_envs <= 0: raise ValueError(f"visible_envs must be positive, got {visible_envs}") if visible_envs >= num_envs: return np.arange(num_envs, dtype=np.int32) return np.floor(np.linspace(0, num_envs, visible_envs, endpoint=False)).astype(np.int32)
# --------------------------------------------------------------------------- # # MujocoViserScene # # --------------------------------------------------------------------------- #
[docs] class MujocoViserScene: """Bridges a ``mujoco.MjModel`` to a ``viser.ViserServer`` scene graph. Call :meth:`build` once to populate the scene with geometry handles, then call :meth:`update` each frame to sync body transforms from ``MjData``. """
[docs] def __init__( self, server: Any, model: mujoco.MjModel, *, name_prefix: str = "/mujoco", position_offset: tuple[float, float, float] = (0.0, 0.0, 0.0), render_plane: bool = True, ) -> None: if not VISER_AVAILABLE: raise ImportError("viser is not installed. Install with: uv sync --extra viser") self._server: viser.ViserServer = server self._model = model self._name_prefix = name_prefix.rstrip("/") or "/mujoco" self._position_offset = np.asarray(position_offset, dtype=np.float64) self._render_plane = bool(render_plane) self._handles: dict[int, Any] = {} self._build()
[docs] def reset( self, model: mujoco.MjModel, *, position_offset: tuple[float, float, float] | None = None, render_plane: bool | None = None, ) -> None: """Rebuild the viser scene for a new MuJoCo model. Args: model: MuJoCo model whose geoms should populate the scene. position_offset: Optional XYZ offset applied to all geoms. render_plane: Optional override for whether plane geoms should be built. Returns: None. """ self.close() self._model = model if position_offset is not None: self._position_offset = np.asarray(position_offset, dtype=np.float64) if render_plane is not None: self._render_plane = bool(render_plane) self._build()
[docs] def close(self) -> None: """Remove all scene handles owned by this adapter.""" for handle in self._handles.values(): handle.remove() self._handles.clear()
# ------------------------------------------------------------------ # # Scene construction # # ------------------------------------------------------------------ # def _build(self) -> None: """Create viser scene nodes for every MuJoCo geom.""" model = self._model server = self._server server.scene.set_up_direction("+z") for i in range(model.ngeom): geom_type = model.geom_type[i] size = model.geom_size[i] rgba = model.geom_rgba[i] color = _rgba_to_color(rgba) opacity = _rgba_to_opacity(rgba) name = f"{self._name_prefix}/geom/{i}" handle: Any | None = None if geom_type == mujoco.mjtGeom.mjGEOM_PLANE: if not self._render_plane: continue # Render ground plane as a grid plane_size = float(size[0]) if size[0] > 0 else 10.0 handle = server.scene.add_grid( name, width=plane_size * 2, height=plane_size * 2, cell_size=0.5, ) elif geom_type == mujoco.mjtGeom.mjGEOM_SPHERE: handle = server.scene.add_icosphere( name, radius=float(size[0]), color=color, opacity=opacity, ) elif geom_type == mujoco.mjtGeom.mjGEOM_CAPSULE: half_len = float(size[1]) radius = float(size[0]) mesh = trimesh.creation.capsule(height=half_len * 2, radius=radius) handle = server.scene.add_mesh_trimesh(name, mesh=mesh) # Manually set color since trimesh mesh may not carry it if hasattr(handle, "color"): handle.color = color elif geom_type == mujoco.mjtGeom.mjGEOM_ELLIPSOID: # Use a unit sphere mesh scaled non-uniformly mesh = trimesh.creation.icosphere(subdivisions=3, radius=1.0) mesh.vertices *= np.array([float(size[0]), float(size[1]), float(size[2])]) handle = server.scene.add_mesh_trimesh(name, mesh=mesh) elif geom_type == mujoco.mjtGeom.mjGEOM_CYLINDER: handle = server.scene.add_cylinder( name, radius=float(size[0]), height=float(size[1]) * 2, color=color, opacity=opacity, ) elif geom_type == mujoco.mjtGeom.mjGEOM_BOX: handle = server.scene.add_box( name, dimensions=( float(size[0]) * 2, float(size[1]) * 2, float(size[2]) * 2, ), color=color, opacity=opacity, ) elif geom_type == mujoco.mjtGeom.mjGEOM_MESH: dataid = model.geom_dataid[i] if dataid >= 0: vertices, faces = _extract_mesh(model, dataid) handle = server.scene.add_mesh_simple( name, vertices=vertices.astype(np.float32), faces=faces.astype(np.int32), color=color, opacity=opacity, ) if handle is not None: self._handles[i] = handle # ------------------------------------------------------------------ # # Per-frame update # # ------------------------------------------------------------------ #
[docs] def update(self, data: mujoco.MjData) -> None: """Sync all geom transforms from *data* into the viser scene.""" with self._server.atomic(): for i, handle in self._handles.items(): xpos = data.geom_xpos[i] + self._position_offset xmat = data.geom_xmat[i] handle.position = (float(xpos[0]), float(xpos[1]), float(xpos[2])) handle.wxyz = _rotmat_to_wxyz(xmat)