"""G1 box tracking environment with object-aware motion imitation."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, cast
import numpy as np
from unilab.assets import ASSETS_ROOT_PATH
from unilab.base import registry
from unilab.base.scene import SceneCfg
from unilab.dr import DomainRandomizationManager, ResetPlan
from unilab.dr.dr_utils import build_common_reset_randomization, zero_actions
from unilab.dtype_config import get_global_dtype
from unilab.envs.common.math import np_sample_uniform
from unilab.envs.common.rotation import (
np_matrix_from_quat,
np_quat_apply,
np_quat_error_magnitude,
np_quat_from_euler_xyz,
np_quat_inv,
np_quat_mul,
np_subtract_frame_transforms,
)
from .motion_box_loader import BoxMotionData, BoxMotionLoader
from .tracking import (
G1MotionTrackingCfg,
G1MotionTrackingDomainRandomizationProvider,
G1MotionTrackingEnv,
RewardConfig,
)
@dataclass
class BoxRewardConfig(RewardConfig):
    """Reward config extended with object-tracking terms.

    Inherits every reward scale from :class:`RewardConfig` and adds the
    box-specific terms. ``std_object_pos`` / ``std_object_ori`` are the widths
    of the exponential kernels used by the object position and orientation
    tracking rewards.
    """

    # Base scales, plus an ``undesired_contacts`` penalty (overrides any base
    # value with that key) and the two object-tracking reward terms.
    scales: dict[str, float] = field(
        default_factory=lambda: {
            **RewardConfig().scales,
            "undesired_contacts": -0.1,
            "object_global_ref_position_error_exp": 1.0,
            "object_global_ref_orientation_error_exp": 1.0,
        }
    )
    # Kernel width for exp(-||p_obj - p_ref||^2 / std^2); units presumably meters.
    std_object_pos: float = 0.3
    # Kernel width for exp(-angle_err^2 / std^2); units presumably radians.
    std_object_ori: float = 0.4
@dataclass
class G1BoxTrackingCfg(G1MotionTrackingCfg):
    """Configuration for the G1 large-box tracking task.

    Extends the motion-tracking config with:

    * a scene model that contains the large box,
    * a default motion clip that also carries the box reference trajectory
      (the env rejects motion files without object data),
    * the name of the box body, and
    * termination thresholds on the box tracking error.
    """

    scene: SceneCfg = field(
        default_factory=lambda: SceneCfg(
            model_file=str(ASSETS_ROOT_PATH / "robots" / "g1" / "scene_flat_with_largebox.xml")
        )
    )
    motion_file: str | list[str] = str(
        ASSETS_ROOT_PATH / "motions" / "g1" / "sub3_largebox_003_boxconverted.npz"
    )
    # Name of the tracked object body in the scene model.
    object_body_name: str = "largebox"
    # Terminate when ||p_obj - p_ref|| exceeds this value.
    object_pos_threshold: float = 0.25
    # Terminate when the object orientation error magnitude exceeds this value.
    object_ori_threshold: float = 0.8
    reward_config: BoxRewardConfig = field(default_factory=BoxRewardConfig)
@registry.envcfg("G1BoxTracking")
@dataclass
class G1BoxTrackingEnvCfg(G1BoxTrackingCfg):
    """Registered config for G1 box tracking.

    Identical to :class:`G1BoxTrackingCfg`; it exists only so the config is
    registered under the ``"G1BoxTracking"`` name.
    """
# Attribute names read from pose/velocity randomization configs, in sample-column order.
_RANDOMIZATION_AXES = ("x", "y", "z", "roll", "pitch", "yaw")


def _sample_randomization_offsets(rand_cfg: Any, num: int, dtype: Any) -> np.ndarray:
    """Draw ``num`` rows of uniform offsets, one column per axis in ``_RANDOMIZATION_AXES``.

    Each axis attribute of ``rand_cfg`` is indexed as ``[0]`` (low) and ``[1]`` (high).
    Samples come from the global NumPy RNG in row-major order (row by row, axis by axis).

    Returns:
        Array of shape ``(num, 6)`` cast to ``dtype``.
    """
    lows = np.array([getattr(rand_cfg, axis)[0] for axis in _RANDOMIZATION_AXES], dtype=np.float64)
    highs = np.array([getattr(rand_cfg, axis)[1] for axis in _RANDOMIZATION_AXES], dtype=np.float64)
    samples = np.random.uniform(lows, highs, size=(num, len(_RANDOMIZATION_AXES)))
    return samples.astype(dtype, copy=False)


def _build_box_motion_reference_state(
    env: Any, env_ids: np.ndarray, motion_data: BoxMotionData
) -> tuple[np.ndarray, np.ndarray]:
    """Build ``(qpos, qvel)`` placing robot and object at sampled reference frames.

    Robot root pose/velocity and joint state come from ``motion_data``. The
    configured pose, velocity and joint-position randomization is applied to them,
    and joint positions are clipped to the joint range when one is available.
    The object (box) state is copied from the reference without randomization.

    Args:
        env: Box tracking env. Reads ``cfg.pose_randomization``,
            ``cfg.velocity_randomization``, ``cfg.joint_position_range``,
            ``_get_joint_range()``, ``_init_qpos`` / ``_init_qvel`` and the
            ``_obj_*_slice`` attributes.
        env_ids: Envs being reset; only its length is used here.
        motion_data: Reference frames; presumably one row per entry of ``env_ids``.

    Returns:
        ``(qpos, qvel)`` arrays with one row per reset env.
    """
    dtype = get_global_dtype()
    num_reset = len(env_ids)

    root_pos = motion_data.body_pos_w[:, 0].copy()
    root_ori = motion_data.body_quat_w[:, 0].copy()
    root_lin_vel = motion_data.body_lin_vel_w[:, 0].copy()
    root_ang_vel = motion_data.body_ang_vel_w[:, 0].copy()
    joint_pos = motion_data.joint_pos.copy()
    joint_vel = motion_data.joint_vel.copy()

    # Root pose randomization: translate, then compose an xyz-Euler rotation on the left.
    pose_samples = _sample_randomization_offsets(env.cfg.pose_randomization, num_reset, dtype)
    root_pos += pose_samples[:, 0:3]
    root_ori = np_quat_mul(
        np_quat_from_euler_xyz(pose_samples[:, 3], pose_samples[:, 4], pose_samples[:, 5]),
        root_ori,
    )

    vel_samples = _sample_randomization_offsets(env.cfg.velocity_randomization, num_reset, dtype)
    root_lin_vel += vel_samples[:, :3]
    root_ang_vel += vel_samples[:, 3:]

    joint_pos += np_sample_uniform(
        env.cfg.joint_position_range[0],
        env.cfg.joint_position_range[1],
        joint_pos.shape,
        dtype=dtype,
    )
    joint_range = env._get_joint_range()
    if joint_range is not None:
        joint_pos = np.clip(joint_pos, joint_range[:, 0], joint_range[:, 1])

    qpos = np.tile(env._init_qpos, (num_reset, 1))
    qvel = np.tile(env._init_qvel, (num_reset, 1))
    qpos[:, 0:3] = root_pos
    qpos[:, 3:7] = root_ori
    qpos[:, 7 : 7 + joint_pos.shape[1]] = joint_pos
    qvel[:, 0:3] = root_lin_vel
    # Root angular velocity is written in the root body frame, not the world frame.
    qvel[:, 3:6] = np_quat_apply(np_quat_inv(root_ori), root_ang_vel)
    qvel[:, 6 : 6 + joint_vel.shape[1]] = joint_vel

    if motion_data.object_pos_w is not None:
        obj_quat_w = motion_data.object_quat_w
        qpos[:, env._obj_pos_slice] = motion_data.object_pos_w
        qpos[:, env._obj_quat_slice] = obj_quat_w
        qvel[:, env._obj_lin_vel_slice] = motion_data.object_lin_vel_w
        # Same convention as the robot root above: free-joint angular velocity is
        # expressed in the body frame, so rotate the world-frame reference into it.
        qvel[:, env._obj_ang_vel_slice] = np_quat_apply(
            np_quat_inv(obj_quat_w), motion_data.object_ang_vel_w
        )
    return qpos, qvel
class G1BoxTrackingDomainRandomizationProvider(G1MotionTrackingDomainRandomizationProvider):
    """Reset provider that restores both robot and object state from motion data."""

    def build_reset_plan(self, env: Any, env_ids: np.ndarray) -> ResetPlan:
        """Sample reference frames for ``env_ids`` and build the matching reset plan.

        Robot and box ``qpos``/``qvel`` come from
        :func:`_build_box_motion_reference_state`. Current and last actions are
        zeroed, and the common reset randomization is built with this provider's
        base PD gains.

        Args:
            env: Box tracking env providing ``motion_sampler``, ``motion_loader``
                and ``_num_action``.
            env_ids: Indices of the envs being reset.

        Returns:
            The :class:`ResetPlan` for ``env_ids``.
        """
        num_reset = len(env_ids)
        frames = env.motion_sampler.sample_frames(env_ids)
        motion_data = cast(BoxMotionData, env.motion_loader.get_motion_at_frame(frames))
        qpos, qvel = _build_box_motion_reference_state(env, env_ids, motion_data)
        return ResetPlan(
            env_ids=env_ids,
            qpos=qpos,
            qvel=qvel,
            info_updates={
                "current_actions": zero_actions(num_reset, env._num_action),
                "last_actions": zero_actions(num_reset, env._num_action),
            },
            randomization=build_common_reset_randomization(
                env, num_reset, base_kp=self._base_kp, base_kd=self._base_kd
            ),
        )
@registry.env("G1BoxTracking", sim_backend="mujoco")
@registry.env("G1BoxTracking", sim_backend="motrix")
class G1BoxTrackingEnv(G1MotionTrackingEnv):
    """Motion tracking env extended with large-box state and rewards.

    Compared with :class:`G1MotionTrackingEnv`, this env:

    * loads motions with :class:`BoxMotionLoader` and requires object reference data,
    * resets the box together with the robot from the same reference frame,
    * appends box features, expressed relative to the anchor body, to the critic
      observation,
    * leaves anchor position and base linear velocity out of the actor observation,
    * adds box position/orientation tracking rewards and terminations.
    """

    _cfg: G1BoxTrackingCfg

    # Critic features appended per env: object position (3) + first two columns of
    # the relative rotation matrix (6) + object linear velocity (3), anchor frame.
    _OBJECT_OBS_DIM = 12

    def __init__(self, cfg: G1BoxTrackingCfg, num_envs=1, backend_type="mujoco"):
        super().__init__(cfg, num_envs, backend_type)

        # Replace the parent's motion loader/sampler with box-aware ones.
        motion_body_ids = self._backend.get_motion_body_ids(cfg.body_names)
        self.motion_loader = BoxMotionLoader(cfg.motion_file, body_indices=motion_body_ids)
        # Validate right away so nothing else is built around unusable motion data.
        if not self.motion_loader.has_object:
            raise ValueError(
                f"Motion file '{cfg.motion_file}' does not contain object data. "
                "Expected keys: object_pos_w, object_quat_w, object_lin_vel_w, object_ang_vel_w"
            )
        self.motion_sampler = type(self.motion_sampler)(
            self.motion_loader, mode=cfg.sampling_mode, num_envs=num_envs
        )

        if cfg.domain_rand.randomize_kp or cfg.domain_rand.randomize_kd:
            base_kp, base_kd = self._backend.get_actuator_gains()
            dr_provider = G1BoxTrackingDomainRandomizationProvider(base_kp=base_kp, base_kd=base_kd)
        else:
            dr_provider = G1BoxTrackingDomainRandomizationProvider()
        # Parent init already applied init randomization and materialized the backend.
        # Box tracking only needs to swap in a box-aware reset/obs provider for future resets.
        self._dr_manager = DomainRandomizationManager(self, dr_provider)

        self._object_body_ids = self._backend.get_body_ids([cfg.object_body_name])
        # NOTE(review): assumes the object's free joint occupies the last 7 qpos and
        # last 6 qvel entries of the model (pos, quat / lin vel, ang vel) — confirm
        # against the scene XML.
        nq = self._init_qpos.shape[0]
        self._obj_pos_slice = slice(nq - 7, nq - 4)
        self._obj_quat_slice = slice(nq - 4, nq)
        nv = self._init_qvel.shape[0]
        self._obj_lin_vel_slice = slice(nv - 6, nv - 3)
        self._obj_ang_vel_slice = slice(nv - 3, nv)

    def _get_joint_range(self) -> np.ndarray | None:
        """Return the parent's joint limits, trimmed to the joints in the motion data."""
        joint_range = super()._get_joint_range()
        num_joints = self.motion_loader.num_joints
        if joint_range is not None and joint_range.shape[0] > num_joints:
            joint_range = joint_range[:num_joints]
        return joint_range

    def _resample_reference_state(self, env_ids: np.ndarray) -> None:
        """Sample fresh reference frames for ``env_ids`` and write robot + box state."""
        motion_frames = self.motion_sampler.sample_frames(env_ids)
        motion_data = cast(BoxMotionData, self.motion_loader.get_motion_at_frame(motion_frames))
        qpos, qvel = _build_box_motion_reference_state(self, env_ids, motion_data)
        self._backend.set_state(env_ids, qpos, qvel)

    def get_dof_pos(self) -> np.ndarray:
        """Joint positions restricted to the first ``motion_loader.num_joints`` columns."""
        dof_pos = super().get_dof_pos()
        return dof_pos[:, : self.motion_loader.num_joints]

    def get_dof_vel(self) -> np.ndarray:
        """Joint velocities restricted to the first ``motion_loader.num_joints`` columns."""
        dof_vel = super().get_dof_vel()
        return dof_vel[:, : self.motion_loader.num_joints]

    @property
    def obs_groups_spec(self) -> dict[str, int]:
        """Observation sizes per group; the critic gains the object features."""
        spec = super().obs_groups_spec
        return {**spec, "critic": spec["critic"] + self._OBJECT_OBS_DIM}

    def _actor_obs_dim(self, n: int) -> int:
        """Actor observation size for ``n`` joints.

        Must match the concatenation in :meth:`_build_actor_obs`. Presumably this is
        6 anchor-orientation + 3 gyro + 5 * n per-joint terms (command, joint pos,
        joint vel, last action). Confirm against the parent env.
        """
        return 6 + 3 + n * 5

    def _build_actor_obs(
        self,
        *,
        command: np.ndarray,
        motion_anchor_pos_b: np.ndarray,
        motion_anchor_ori_b: np.ndarray,
        noisy_linvel: np.ndarray,
        noisy_gyro: np.ndarray,
        noisy_joint_pos_rel: np.ndarray,
        noisy_dof_vel: np.ndarray,
        last_actions: np.ndarray,
    ) -> np.ndarray:
        """Concatenate the actor observation.

        ``motion_anchor_pos_b`` and ``noisy_linvel`` are accepted for interface
        compatibility but are deliberately not part of the actor observation.
        """
        return np.concatenate(
            [
                command,
                motion_anchor_ori_b,
                noisy_gyro,
                noisy_joint_pos_rel,
                noisy_dof_vel,
                last_actions,
            ],
            axis=1,
            dtype=get_global_dtype(),
        )

    def _init_reward_functions(self):
        """Register the parent rewards plus the two object-tracking rewards."""
        super()._init_reward_functions()
        self._reward_fns["object_global_ref_position_error_exp"] = self._reward_object_position
        self._reward_fns["object_global_ref_orientation_error_exp"] = (
            self._reward_object_orientation
        )

    def _compute_terminations(
        self,
        motion_data: BoxMotionData,
        robot_body_pos_w: np.ndarray,
        robot_body_quat_w: np.ndarray,
    ) -> np.ndarray:
        """Parent terminations OR object position / orientation error over threshold."""
        terminated = super()._compute_terminations(motion_data, robot_body_pos_w, robot_body_quat_w)
        if motion_data.object_pos_w is not None:
            obj_pos_w = self._backend.get_body_pos_w(self._object_body_ids)[:, 0, :]
            obj_pos_error = np.linalg.norm(obj_pos_w - motion_data.object_pos_w, axis=-1)
            terminated |= obj_pos_error > self._cfg.object_pos_threshold
        if motion_data.object_quat_w is not None:
            obj_quat_w = self._backend.get_body_quat_w(self._object_body_ids)[:, 0, :]
            obj_ori_error = np_quat_error_magnitude(obj_quat_w, motion_data.object_quat_w)
            terminated |= obj_ori_error > self._cfg.object_ori_threshold
        return terminated

    def _compute_obs(
        self,
        info: dict,
        motion_data: BoxMotionData,
        linvel: np.ndarray,
        gyro: np.ndarray,
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        robot_body_pos_w: np.ndarray,
        robot_body_quat_w: np.ndarray,
    ) -> dict[str, np.ndarray]:
        """Parent observations with object features appended to the critic group.

        The object pose and linear velocity are expressed relative to the anchor
        body. Orientation is encoded as the first two columns of the relative
        rotation matrix.
        """
        obs = super()._compute_obs(
            info, motion_data, linvel, gyro, dof_pos, dof_vel, robot_body_pos_w, robot_body_quat_w
        )
        num_envs = linvel.shape[0]
        env_ids = info.get("env_ids")
        # Select this batch's rows: ``env_ids`` when given, otherwise the first ``num_envs``.
        rows = env_ids if isinstance(env_ids, np.ndarray) else slice(None, num_envs)
        obj_pos_w = self._backend.get_body_pos_w(self._object_body_ids)[rows, 0, :]
        obj_quat_w = self._backend.get_body_quat_w(self._object_body_ids)[rows, 0, :]
        obj_lin_vel_w = self._backend.get_body_lin_vel_w(self._object_body_ids)[rows, 0, :]

        anchor_pos_w = robot_body_pos_w[:, self.anchor_body_idx]
        anchor_quat_w = robot_body_quat_w[:, self.anchor_body_idx]
        obj_pos_b, obj_ori_rel = np_subtract_frame_transforms(
            anchor_pos_w, anchor_quat_w, obj_pos_w, obj_quat_w
        )
        obj_ori_b = np_matrix_from_quat(obj_ori_rel)[:, :, :2].reshape(num_envs, 6)
        obj_lin_vel_b = np_quat_apply(np_quat_inv(anchor_quat_w), obj_lin_vel_w)

        dtype = get_global_dtype()
        object_obs = np.concatenate([obj_pos_b, obj_ori_b, obj_lin_vel_b], axis=1, dtype=dtype)
        obs["critic"] = np.concatenate([obs["critic"], object_obs], axis=1, dtype=dtype)
        return obs

    def _reward_object_position(self, info: dict) -> np.ndarray:
        """exp(-||p_obj - p_ref||^2 / std_object_pos^2); zeros if no object reference."""
        motion_data: BoxMotionData = info["motion_data"]
        if motion_data.object_pos_w is None:
            return np.zeros((self._num_envs,), dtype=get_global_dtype())
        obj_pos_w = self._backend.get_body_pos_w(self._object_body_ids)[:, 0, :]
        error = np.sum(np.square(obj_pos_w - motion_data.object_pos_w), axis=-1)
        return np.asarray(
            np.exp(-error / self._cfg.reward_config.std_object_pos**2), dtype=get_global_dtype()
        )

    def _reward_object_orientation(self, info: dict) -> np.ndarray:
        """exp(-err^2 / std_object_ori^2) on quaternion error magnitude; zeros if no reference."""
        motion_data: BoxMotionData = info["motion_data"]
        if motion_data.object_quat_w is None:
            return np.zeros((self._num_envs,), dtype=get_global_dtype())
        obj_quat_w = self._backend.get_body_quat_w(self._object_body_ids)[:, 0, :]
        error = np_quat_error_magnitude(obj_quat_w, motion_data.object_quat_w) ** 2
        return np.asarray(
            np.exp(-error / self._cfg.reward_config.std_object_ori**2), dtype=get_global_dtype()
        )