# Source code for unilab.envs.motion_tracking.g1.tracking_sac

"""G1 Motion Tracking SAC Environment — thin SAC wrapper over G1MotionTrackingEnv.

Differences from the PPO base:
- Critic observations additionally include ``base_lin_vel`` (3 dims),
  matching holosoma's asymmetric actor-critic design for WBT.
- Registered under a separate name so it can be paired with FastSAC
  configs without affecting the PPO motion-tracking pipeline.
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np

from unilab.base import registry
from unilab.dtype_config import get_global_dtype

from .tracking import G1MotionTrackingCfg, G1MotionTrackingEnv


@registry.envcfg("G1MotionTrackingSAC")
@dataclass
class G1MotionTrackingSACCfg(G1MotionTrackingCfg):
    """Config for SAC-based motion tracking.

    Declares no fields of its own: everything is inherited from
    ``G1MotionTrackingCfg``. The class exists only to provide a separate
    ``"G1MotionTrackingSAC"`` registry entry.
    """
@registry.env("G1MotionTrackingSAC", sim_backend="mujoco")
@registry.env("G1MotionTrackingSAC", sim_backend="motrix")
class G1MotionTrackingSACEnv(G1MotionTrackingEnv):
    """G1 Motion Tracking environment for FastSAC training.

    Extends the PPO motion-tracking environment with ``base_lin_vel``
    appended to the critic observation, matching holosoma's asymmetric
    actor-critic WBT design.

    The motrix backend is registered for sim2sim eval/playback only —
    checkpoints trained on mujoco can be replayed via motrix's native
    renderer through ``eval --sim motrix``.
    """

    @property
    def obs_groups_spec(self) -> dict[str, int]:
        """Return the parent's observation-group sizes with a wider critic.

        Returns:
            A new dict with the same entries as the parent spec, except that
            ``"critic"`` is increased by 3 to account for the ``base_lin_vel``
            columns appended in ``_compute_obs``.
        """
        spec = super().obs_groups_spec
        # Append base_lin_vel (3) to critic observations.
        return {**spec, "critic": spec["critic"] + 3}

    def _compute_obs(
        self,
        info: dict,
        motion_data,
        linvel: np.ndarray,
        gyro: np.ndarray,
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        robot_body_pos_w: np.ndarray,
        robot_body_quat_w: np.ndarray,
    ) -> dict[str, np.ndarray]:
        """Build the parent observations and append ``linvel`` to the critic.

        All arguments are forwarded unchanged, in the same order, to the
        parent ``_compute_obs``; only ``linvel`` is used again here.

        Args:
            linvel: Base linear velocity. It is concatenated along axis 1, so
                it must match ``obs["critic"]`` on every other axis.
                NOTE(review): presumably ``(num_envs, 3)`` to agree with the
                ``+ 3`` in ``obs_groups_spec`` — confirm against the caller.

        Returns:
            The dict returned by the parent, with its ``"critic"`` entry
            replaced by the widened array cast to the global dtype.
        """
        obs = super()._compute_obs(  # pyright: ignore[reportAttributeAccessIssue]
            info,
            motion_data,
            linvel,
            gyro,
            dof_pos,
            dof_vel,
            robot_body_pos_w,
            robot_body_quat_w,
        )
        # Append base_lin_vel to critic observations.
        obs["critic"] = np.concatenate([obs["critic"], linvel], axis=1, dtype=get_global_dtype())  # type: ignore[call-overload]
        return obs