Source code for unilab.tools.viz_nan

"""Interactive viewer for NaN guard state dumps."""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Any

import numpy as np


[docs] def load_dump(dump_path: str) -> dict: data = np.load(dump_path, allow_pickle=True) states = data["states"] metadata = {} for key in data.files: if key.startswith("meta_"): val = data[key] metadata[key[5:]] = val.item() if val.ndim == 0 else val return {"states": states, "metadata": metadata}
[docs] def replay_dump(dump_path: str, env_index: int = 0) -> None: import mujoco as _mujoco import mujoco.viewer as _viewer # noqa: F401 mujoco: Any = _mujoco dump = load_dump(dump_path) states = dump["states"] meta = dump["metadata"] if states.size == 0: print("No physics states in dump (backend may not support state playback).") print(f"Metadata: {meta}") return model_file = str(meta.get("model_file", "")) dump_dir = Path(dump_path).parent model_path = None if model_file and Path(model_file).is_file(): model_path = model_file else: for f in sorted(dump_dir.glob("*_model.*")): model_path = str(f) break if model_path is None: print(f"Cannot find model file. model_file in metadata: {model_file}") return model = ( mujoco.MjModel.from_xml_path(model_path) if model_path.endswith(".xml") else mujoco.MjModel.from_binary_path(model_path) ) d = mujoco.MjData(model) num_steps = states.shape[0] num_envs = states.shape[1] if states.ndim >= 2 else 1 nan_env_ids = meta.get("nan_env_ids", np.array([])) step_detected = meta.get("detection_step", "?") print(f"Dump: {dump_path}") print(f"Steps in buffer: {num_steps}, Envs: {num_envs}") print(f"NaN detected at step {step_detected}, env ids: {nan_env_ids}") print(f"Viewing env index: {env_index}") print("Press Ctrl+C to exit.") state_size = mujoco.mj_stateSize(model, mujoco.mjtState.mjSTATE_PHYSICS) with mujoco.viewer.launch_passive(model, d) as viewer: step_idx = 0 while viewer.is_running(): if states.ndim >= 3: flat = states[step_idx, env_index] elif states.ndim == 2: flat = states[step_idx] else: break if flat.shape[0] >= state_size: mujoco.mj_setState(model, d, flat[:state_size], mujoco.mjtState.mjSTATE_PHYSICS) mujoco.mj_forward(model, d) viewer.sync() import time time.sleep(1.0 / 30.0) step_idx = (step_idx + 1) % num_steps
[docs] def main(argv: list[str] | None = None) -> int: parser = argparse.ArgumentParser( prog="unilab-viz-nan", description="Replay a NaN guard state dump in MuJoCo viewer.", ) parser.add_argument("dump_path", help="Path to the .npz dump file") parser.add_argument("--env-index", type=int, default=0, help="Environment index to visualize") args = parser.parse_args(argv) replay_dump(args.dump_path, env_index=args.env_index) return 0
if __name__ == "__main__": raise SystemExit(main())