# Source code for unilab.base.backend

from typing import Any, cast

from unilab.base.scene import SceneCfg

from .base import SimBackend
from .motrix.scene import (
    add_motrix_tracking_frame_sensors,
    materialize_motrix_hfield_attached_scene,
    materialize_motrix_scene,
)

# Attribute names that the module-level ``__getattr__`` resolves by importing
# ``.mujoco.xml`` on first access and returning the attribute of the same name.
_MUJOCO_XML_EXPORTS = frozenset(
    {
        "add_sensor",
        "create_discardvisual_xml",
        "get_named_body_ids",
        "inject_mujoco_tracking_sensors",
        "materialize_mujoco_hfield_attached_scene",
        "materialize_scene_fragments",
        "materialize_scene_visual_override",
        "processed_xml",
    }
)


def _load_mujoco_backend() -> Any:
    """Import ``MuJoCoBackend`` on demand and return the class.

    The import statement lives inside the function body, so
    ``.mujoco.backend`` is only imported when this helper is called.
    """
    from .mujoco.backend import MuJoCoBackend as backend_cls

    return backend_cls


def _load_motrix_backend() -> tuple[Any, bool]:
    """Import the Motrix backend on demand.

    Returns:
        A ``(MotrixBackend, available)`` pair, where ``available`` is the
        module's ``MOTRIX_AVAILABLE`` flag coerced to ``bool``.
    """
    from .motrix.backend import MOTRIX_AVAILABLE as available_flag
    from .motrix.backend import MotrixBackend as backend_cls

    return backend_cls, bool(available_flag)


def create_backend(
    backend_type: str,
    scene: SceneCfg,
    num_envs: int,
    sim_dt: float,
    **kwargs: Any,
) -> SimBackend:
    """Create a simulation backend.

    Args:
        backend_type: ``"mujoco"`` or ``"motrix"``.
        scene: SceneCfg; both static scenes and composed scenes are expressed
            through this contract.
        num_envs: Number of environments.
        sim_dt: Simulation time step.
        **kwargs: Other parameters (``position_actuator_gains``,
            ``motrix_max_iterations``, etc.). Whatever remains after the
            backend-specific handling below is forwarded to the backend
            constructor unchanged.

    Returns:
        A SimBackend instance.

    Raises:
        ValueError: If ``scene`` is ``None`` or ``backend_type`` is neither
            ``"mujoco"`` nor ``"motrix"``.
        ImportError: If ``"motrix"`` is requested but the Motrix backend
            reports that it is not available.

    Note:
        ``position_actuator_gains`` and ``post_step_forward_sensor`` are only
        forwarded to the MuJoCo backend; ``motrix_max_iterations`` is only
        forwarded to the Motrix backend, renamed to ``max_iterations``. Each
        is forwarded only when it is not ``None``.
    """
    if scene is None:
        raise ValueError("SceneCfg must be provided")

    # Always strip the backend-specific options from kwargs first so that an
    # option meant for one backend never reaches the other's constructor.
    position_actuator_gains = kwargs.pop("position_actuator_gains", None)
    motrix_max_iterations = kwargs.pop("motrix_max_iterations", None)
    post_step_forward_sensor = kwargs.pop("post_step_forward_sensor", None)

    if backend_type == "mujoco":
        MuJoCoBackend = _load_mujoco_backend()
        if position_actuator_gains is not None:
            kwargs["position_actuator_gains"] = position_actuator_gains
        if post_step_forward_sensor is not None:
            kwargs["post_step_forward_sensor"] = post_step_forward_sensor
        return cast(SimBackend, MuJoCoBackend(scene, num_envs, sim_dt, **kwargs))

    if backend_type == "motrix":
        MotrixBackend, motrix_available = _load_motrix_backend()
        if not motrix_available:
            raise ImportError("MotrixSim not available, install motrixsim package")
        if motrix_max_iterations is not None:
            # The Motrix constructor receives this option as ``max_iterations``.
            kwargs["max_iterations"] = motrix_max_iterations
        return cast(SimBackend, MotrixBackend(scene, num_envs, sim_dt, **kwargs))

    raise ValueError(f"Unknown backend: {backend_type}")
def __getattr__(name: str) -> Any:
    """Resolve lazily provided module attributes on first access.

    Args:
        name: Attribute name that was not found through normal module lookup.

    Returns:
        The ``MuJoCoBackend`` class, the ``MotrixBackend`` class, the
        ``MOTRIX_AVAILABLE`` flag (as ``bool``), or the attribute of the same
        name from ``.mujoco.xml`` when ``name`` is in ``_MUJOCO_XML_EXPORTS``.

    Raises:
        AttributeError: If ``name`` is not one of the lazily provided names.
    """
    if name == "MuJoCoBackend":
        return _load_mujoco_backend()
    if name == "MotrixBackend":
        return _load_motrix_backend()[0]
    if name == "MOTRIX_AVAILABLE":
        return _load_motrix_backend()[1]
    if name in _MUJOCO_XML_EXPORTS:
        # Imported here rather than at module level so ``.mujoco.xml`` is only
        # loaded when one of its exports is actually requested.
        from .mujoco import xml

        return getattr(xml, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


# Public API of this module.
# NOTE(review): "MOTRIX_AVAILABLE" is resolvable through __getattr__ above but
# is not listed here — confirm whether that omission is intentional.
__all__ = [
    "SimBackend",
    "MuJoCoBackend",
    "MotrixBackend",
    "add_sensor",
    "create_discardvisual_xml",
    "create_backend",
    "get_named_body_ids",
    "inject_mujoco_tracking_sensors",
    "add_motrix_tracking_frame_sensors",
    "materialize_motrix_hfield_attached_scene",
    "materialize_motrix_scene",
    "materialize_mujoco_hfield_attached_scene",
    "materialize_scene_fragments",
    "materialize_scene_visual_override",
    "processed_xml",
]