Source code for unilab.utils.nan_guard
"""NaN/Inf guard for env-layer numerical anomaly detection and state dumping."""
from __future__ import annotations

import logging
import shutil
import time
from collections import deque
from dataclasses import dataclass
from pathlib import Path

import numpy as np
# Module-level logger; NanGuard.dump emits its "NaN guard triggered" warning through it.
logger = logging.getLogger(__name__)
@dataclass
class NanGuardCfg:
    """Configuration for :class:`NanGuard`.

    Attributes:
        enabled: Master switch. While False, ``NanGuard.check`` never reports
            anomalies.
        buffer_size: Number of most recent captured physics states kept for a
            dump.
        max_envs_to_dump: Upper bound on how many of the flagged env ids are
            selected for the state dump.
        output_dir: Directory for dump files; ``None`` selects the guard's
            built-in default location.

    Raises:
        ValueError: If ``buffer_size`` or ``max_envs_to_dump`` is negative.
    """

    enabled: bool = False
    buffer_size: int = 100
    max_envs_to_dump: int = 5
    output_dir: str | None = None

    def __post_init__(self) -> None:
        # Fail fast at construction instead of mid-training. A negative history
        # length breaks the guard's state buffer. A negative dump cap would turn
        # the slice ``nan_env_ids[:cap]`` into "all but the last |cap| ids".
        if self.buffer_size < 0:
            raise ValueError(f"buffer_size must be >= 0, got {self.buffer_size}")
        if self.max_envs_to_dump < 0:
            raise ValueError(
                f"max_envs_to_dump must be >= 0, got {self.max_envs_to_dump}"
            )
class NanGuard:
    """Detect non-finite observations/rewards and dump recent physics history.

    The guard keeps a bounded history of the most recent physics states passed
    to :meth:`capture`. :meth:`check` flags environments whose observations or
    rewards contain NaN/Inf. :meth:`dump` writes the buffered history for a
    subset of the flagged environments to an ``.npz`` file. The guard is
    one-shot: once :meth:`dump` has run, :meth:`capture` and :meth:`check` are
    no-ops and further :meth:`dump` calls return ``None``.
    """

    def __init__(
        self,
        cfg: NanGuardCfg,
        num_envs: int,
        supports_state_playback: bool,
    ) -> None:
        """Create a guard.

        Args:
            cfg: Guard configuration (enable flag, history length, dump limits).
            num_envs: Number of environments; sizes the per-env mask in check().
            supports_state_playback: Recorded verbatim in the dump metadata.

        Raises:
            ValueError: If ``cfg.buffer_size`` is negative (rejected by deque).
        """
        self._cfg = cfg
        self._num_envs = num_envs
        self._supports_state_playback = supports_state_playback
        # A bounded deque evicts its oldest entry on overflow, so iterating it
        # always yields states oldest -> newest. maxlen=0 simply keeps nothing.
        self._buffer: deque[np.ndarray] = deque(maxlen=cfg.buffer_size)
        self._dumped: bool = False

    def capture(self, physics_state: np.ndarray | None) -> None:
        """Append a snapshot of ``physics_state`` to the history buffer.

        ``None`` is ignored, as is every call while the guard is disabled or
        after it has already dumped. The state is copied in case the caller
        reuses the same array between steps; keeping the reference would make
        every buffered entry alias the latest state.
        """
        if physics_state is None or not self._cfg.enabled or self._dumped:
            return
        self._buffer.append(np.array(physics_state, copy=True))

    def check(self, obs: dict[str, np.ndarray], reward: np.ndarray) -> np.ndarray | None:
        """Return the ids of environments with non-finite observations or rewards.

        Every array is reduced over all axes after the first, so an env is
        flagged if any of its entries is NaN or +/-Inf. This applies to
        ``reward`` as well, which may be ``(num_envs,)`` or carry extra
        trailing dimensions.

        Args:
            obs: Observation arrays whose leading axis indexes environments.
            reward: Reward array whose leading axis indexes environments.

        Returns:
            Ascending ``int32`` array of flagged env ids, or ``None`` when the
            guard is disabled, has already dumped, or nothing is flagged.
        """
        if not self._cfg.enabled or self._dumped:
            return None
        bad_mask = np.zeros(self._num_envs, dtype=bool)
        for value in obs.values():
            bad_mask |= self._nonfinite_rows(value)
        bad_mask |= self._nonfinite_rows(reward)
        if not bad_mask.any():
            return None
        return np.flatnonzero(bad_mask).astype(np.int32)

    def dump(
        self,
        nan_env_ids: np.ndarray,
        model_file: str,
        step: int,
    ) -> str | None:
        """Write the buffered physics history to ``<output_dir>/nan_dump_*.npz``.

        The file contains ``states`` (buffered states stacked oldest -> newest
        along axis 0) plus one ``meta_<key>`` entry per metadata field. If
        ``model_file`` names an existing file, it is copied next to the dump.
        ``nan_dump_latest.npz`` is re-pointed at the new dump on a best-effort
        basis.

        Args:
            nan_env_ids: Ids of the offending envs (e.g. from :meth:`check`);
                any array-like of ints is accepted.
            model_file: Path of the model description to copy; may be empty.
            step: Step at which the anomaly was detected; used in the file name.

        Returns:
            Path of the written ``.npz`` file, or ``None`` if a dump was already
            attempted (the guard dumps at most once).
        """
        if self._dumped:
            return None
        # Set before any I/O, so even a dump that raises is never retried.
        self._dumped = True
        output_dir = Path(self._cfg.output_dir or "/tmp/unilab/nan_dumps")
        output_dir.mkdir(parents=True, exist_ok=True)

        nan_env_ids = np.asarray(nan_env_ids).reshape(-1)
        dump_env_ids = nan_env_ids[: self._cfg.max_envs_to_dump]
        ordered = list(self._buffer)  # oldest -> newest
        states = self._select_states(ordered, dump_env_ids)

        # One clock read keeps the file name and the stored timestamp in sync.
        now = time.time()
        metadata = {
            "num_envs_total": self._num_envs,
            "nan_env_ids": nan_env_ids,
            "dumped_env_ids": dump_env_ids,
            "buffer_size": self._cfg.buffer_size,
            "buffer_len": len(ordered),
            "detection_step": step,
            "timestamp": now,
            "model_file": model_file,
            "supports_state_playback": self._supports_state_playback,
        }
        ts = time.strftime("%Y%m%d_%H%M%S", time.localtime(now))
        dump_name = f"nan_dump_{ts}_step{step}"
        npz_path = output_dir / f"{dump_name}.npz"
        np.savez(
            str(npz_path),
            states=states,
            **{f"meta_{k}": v for k, v in metadata.items()},
        )

        if model_file and Path(model_file).is_file():
            model_dst = output_dir / f"{dump_name}_model{Path(model_file).suffix}"
            self._copy_model(model_file, model_dst)
        self._update_latest_link(output_dir, npz_path)

        logger.warning(
            "NaN guard triggered at step %d for %d envs. Dump: %s",
            step,
            nan_env_ids.size,
            npz_path,
        )
        return str(npz_path)

    @staticmethod
    def _nonfinite_rows(values: np.ndarray) -> np.ndarray:
        """Return a mask over axis 0 that is True where any trailing entry is NaN/Inf."""
        arr = np.asarray(values)
        # axis=() for 1-D input reduces over nothing, i.e. an element-wise test.
        return ~np.all(np.isfinite(arr), axis=tuple(range(1, arr.ndim)))

    @staticmethod
    def _select_states(ordered: list[np.ndarray], env_ids: np.ndarray) -> np.ndarray:
        """Stack buffered states along a new time axis and keep only ``env_ids``."""
        if not ordered:
            return np.array([])
        states = np.stack(ordered, axis=0)
        # Filtering assumes axis 0 of each buffered state indexes envs (TODO
        # confirm with callers). 1-D states and an empty selection are left
        # unfiltered.
        if states.ndim >= 3 and env_ids.size > 0:
            states = states[:, env_ids]
        return states

    @staticmethod
    def _copy_model(src: str, dst: Path) -> None:
        """Copy the model file next to the dump; failures are logged, not raised."""
        try:
            shutil.copy2(src, dst)
        except OSError as err:
            logger.warning("NaN guard could not copy model file %s to %s: %s", src, dst, err)

    @staticmethod
    def _update_latest_link(output_dir: Path, npz_path: Path) -> None:
        """Point ``nan_dump_latest.npz`` at ``npz_path`` via a relative symlink, best effort."""
        latest_link = output_dir / "nan_dump_latest.npz"
        try:
            latest_link.unlink(missing_ok=True)
            latest_link.symlink_to(npz_path.name)
        except OSError as err:
            # Symlinks may be unavailable on this platform/filesystem, or another
            # process may have recreated the link concurrently. The dump itself
            # has already been written, so this is not fatal.
            logger.debug("NaN guard could not update %s: %s", latest_link, err)