from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from unilab.assets import ASSETS_ROOT_PATH
from unilab.base import registry
from unilab.base.backend import create_backend
from unilab.base.np_env import NpEnvState
from unilab.base.scene import SceneCfg
from unilab.dr.types import ResetPlan
from unilab.dtype_config import get_global_dtype
from unilab.envs.common.rotation import np_matrix_from_quat, np_quat_from_euler_xyz
from unilab.envs.locomotion.common import rewards
from unilab.envs.locomotion.common.commands import Commands
from unilab.envs.locomotion.common.domain_rand import DomainRandConfig
from unilab.envs.locomotion.common.dr_provider import LocomotionDRProvider
from unilab.envs.locomotion.common.rewards import RewardContext
from unilab.envs.locomotion.go2_arm.base import (
DEFAULT_LEG_ANGLES,
Go2ArmBaseCfg,
Go2ArmBaseEnv,
Go2ArmSensor,
build_go2_arm_position_gains,
)
def _default_go2_arm_model_file() -> str:
return str(ASSETS_ROOT_PATH / "robots" / "go2_arm" / "scene_flat.xml")
def _default_go2_arm_scene() -> SceneCfg:
return SceneCfg(model_file=_default_go2_arm_model_file())
def _resolve_go2_arm_scene(cfg: "Go2ArmManipLocoCfg") -> SceneCfg:
scene = cfg.scene
default_model_file = _default_go2_arm_model_file()
if scene is None:
scene = SceneCfg(model_file=cfg.model_file)
elif cfg.model_file != default_model_file and scene.model_file == default_model_file:
scene = SceneCfg(
model_file=cfg.model_file,
fragment_files=list(scene.fragment_files),
terrain=scene.terrain,
)
cfg.scene = scene
return scene
def _sphere2cart(sphere: np.ndarray) -> np.ndarray:
"""Convert (..., 3)[l, phi, theta] to (..., 3)[x, y, z]."""
l = sphere[..., 0]
phi = sphere[..., 1]
theta = sphere[..., 2]
x = l * np.cos(phi) * np.cos(theta)
y = l * np.sin(theta)
z = l * np.sin(phi) * np.cos(theta)
return np.stack([x, y, z], axis=-1)
def _cart2sphere(cart: np.ndarray) -> np.ndarray:
"""Convert (..., 3)[x, y, z] to (..., 3)[l, phi, theta]."""
cart = np.asarray(cart)
l_sq = np.sum(cart**2, axis=-1, keepdims=True)
l = np.sqrt(np.maximum(l_sq, 1e-12))
phi = np.arctan2(cart[..., 2:3], cart[..., 0:1])
theta = np.arcsin(np.clip(cart[..., 1:2] / l, -1.0, 1.0))
return np.concatenate([l, phi, theta], axis=-1)
[docs]
@dataclass
class InitState:
pos: list[float] = field(default_factory=lambda: [0.0, 0.0, 0.42])
[docs]
@dataclass
class Go2ArmDomainRandConfig(DomainRandConfig):
randomize_kp: bool = True
kp_multiplier_range: list[float] = field(default_factory=lambda: [0.9, 1.1])
randomize_kd: bool = True
kd_multiplier_range: list[float] = field(default_factory=lambda: [0.9, 1.1])
[docs]
@dataclass
class EEGoalConfig:
"""End-effector goal config in spherical coordinates."""
# Spherical sampling ranges.
sphere_l_range: list[float] = field(default_factory=lambda: [0.3, 0.6])
sphere_phi_range: list[float] = field(default_factory=lambda: [-1.2566, 1.0472])
sphere_theta_range: list[float] = field(default_factory=lambda: [-2.3562, 2.3562])
# Trajectory timing.
traj_time_range: list[float] = field(default_factory=lambda: [1.0, 3.0])
hold_time_range: list[float] = field(default_factory=lambda: [0.5, 2.0])
# Collision checks.
collision_upper_limits: list[float] = field(default_factory=lambda: [0.3, 0.15, 0.05 - 0.165])
collision_lower_limits: list[float] = field(
default_factory=lambda: [-0.2, -0.15, -0.35 - 0.165]
)
underground_limit: float = -0.57
num_collision_check_samples: int = 10
num_resample_attempts: int = 10
# End-effector target orientation sampling (XYZ Euler to wxyz quaternion).
default_orn_roll: float = float(np.pi / 2.0)
arm_induced_pitch: float = 0.78
delta_orn_r: list[float] = field(default_factory=lambda: [-0.5, 0.5])
delta_orn_p: list[float] = field(default_factory=lambda: [-0.5, 0.5])
delta_orn_y: list[float] = field(default_factory=lambda: [-0.5, 0.5])
# Initial goal used as the reset-time start point.
init_ee_cart: list[float] = field(default_factory=lambda: [0.30, 0.0, 0.25])
[docs]
@dataclass
class CommandsConfig(Commands):
# Periodic command resampling time in seconds. None disables mid-episode resampling.
resample_time_s: float | None = None
# Probability of explicitly sampling a zero-velocity command for stable standing.
zero_command_prob: float = 0.2
[docs]
@dataclass
class CurriculumConfig:
"""Expand velocity command ranges when mean tracking_lin_vel exceeds a threshold."""
enable: bool = False
# Expansion threshold: per-step episode mean tracking_lin_vel must exceed this value.
threshold: float = 0.8
# Expansion step applied on each trigger: [vx, vy, vyaw].
step_size: list[float] = field(default_factory=lambda: [0.1, 0.05, 0.1])
# Absolute velocity-range limits to prevent unbounded expansion.
max_vel_limit: list[float] = field(default_factory=lambda: [1.0, 0.4, 0.8])
[docs]
@dataclass
class RewardConfig:
scales: dict[str, float]
tracking_sigma: float
base_height_target: float
target_foot_height: float = 0.1
object_sigma: float = 0.1
# Soft limits for 12 leg joints in radians. Empty lists disable the reward term.
leg_dof_upper_limits: list[float] = field(default_factory=list)
leg_dof_lower_limits: list[float] = field(default_factory=list)
dof_pos_limit_margin: float = 0.01
[docs]
@dataclass
class HistoryConfig:
"""Actor/critic observation history lengths. A value of 1 disables history."""
num_actor_history: int = 1
num_critic_history: int = 1
[docs]
@dataclass
class ArmStageConfig:
freeze_arm_joints: bool = False
disable_ee_goal_trajectory: bool = False
fixed_ee_goal_cart: list[float] = field(default_factory=lambda: [0.30, 0.0, 0.25])
[docs]
@registry.envcfg("Go2ArmManipLoco")
@dataclass
class Go2ArmManipLocoCfg(Go2ArmBaseCfg):
scene: SceneCfg = field(default_factory=_default_go2_arm_scene)
model_file: str = field(default_factory=_default_go2_arm_model_file)
max_episode_seconds: float = 20.0
init_state: InitState = field(default_factory=InitState)
commands: CommandsConfig = field(default_factory=CommandsConfig) # type: ignore[assignment]
reward_config: RewardConfig | None = None
sensor: Go2ArmSensor = field(default_factory=Go2ArmSensor) # type: ignore[assignment]
domain_rand: Go2ArmDomainRandConfig = field(default_factory=Go2ArmDomainRandConfig)
goal_ee: EEGoalConfig = field(default_factory=EEGoalConfig)
history: HistoryConfig = field(default_factory=HistoryConfig)
arm_stage: ArmStageConfig = field(default_factory=ArmStageConfig)
curriculum: CurriculumConfig = field(default_factory=CurriculumConfig)
[docs]
class Go2ArmManipLocoDRProvider(LocomotionDRProvider):
[docs]
def __init__(
self,
*,
base_kp: np.ndarray | None = None,
base_kd: np.ndarray | None = None,
base_body_mass: np.ndarray | None = None,
base_geom_friction: np.ndarray | None = None,
ground_geom_id: int | None = None,
base_dof_armature: np.ndarray | None = None,
):
self._base_kp = base_kp
self._base_kd = base_kd
self._base_body_mass = base_body_mass
self._base_geom_friction = base_geom_friction
self._ground_geom_id = ground_geom_id
self._base_dof_armature = base_dof_armature
def _sample_commands(self, env: Any, num_reset: int) -> np.ndarray:
commands = super()._sample_commands(env, num_reset)
return env._postprocess_velocity_commands(commands)
def _get_base_actuator_gains(self, env: Any) -> tuple[np.ndarray | None, np.ndarray | None]:
return self._base_kp, self._base_kd
def _get_reset_randomization_baselines(
self, env: Any
) -> tuple[np.ndarray | None, np.ndarray | None, int | None, np.ndarray | None]:
return (
self._base_body_mass,
self._base_geom_friction,
self._ground_geom_id,
self._base_dof_armature,
)
[docs]
def build_reset_plan(self, env: Any, env_ids: np.ndarray) -> ResetPlan:
plan = super().build_reset_plan(env, env_ids)
env.reset_ee_goals(env_ids)
# Update command curriculum at episode end before resetting timers.
env._update_command_curriculum(env_ids)
# Reset command timers. reset_ee_goals already clears _arm_goal_timer.
env._cmd_timer[env_ids] = 0
env._arm_goal_timer[env_ids] = 0
# Clear history buffers for reset environments.
env._history_obs_buf[env_ids] = 0.0
env._history_critic_buf[env_ids] = 0.0
env.phase[env_ids] = 0.0
env._write_feet_phase(env_ids, env._command_is_moving(plan.info_updates["commands"]))
return plan
def _compute_reset_obs(
self,
env: Any,
env_ids: np.ndarray,
info_updates: dict[str, Any],
linvel: np.ndarray,
gyro: np.ndarray,
gravity: np.ndarray,
dof_pos: np.ndarray,
dof_vel: np.ndarray,
) -> dict[str, np.ndarray]:
ee_local_pos, _ = env.get_ee_local_pose()
# info_updates may contain global info arrays; slice entries for env_ids.
n = len(env_ids)
sliced_info: dict[str, Any] = {}
for k, v in info_updates.items():
if isinstance(v, np.ndarray) and v.ndim >= 1 and v.shape[0] == env._num_envs:
sliced_info[k] = v[env_ids]
else:
sliced_info[k] = v
actor_raw = env._compute_raw_obs( # type: ignore[no-any-return]
sliced_info,
linvel,
gyro,
gravity,
dof_pos,
dof_vel,
ee_local_pos[env_ids],
env.curr_ee_goal_cart[env_ids],
env.feet_phase[env_ids],
add_noise=True,
)
critic_raw = env._compute_raw_obs( # type: ignore[no-any-return]
sliced_info,
linvel,
gyro,
gravity,
dof_pos,
dof_vel,
ee_local_pos[env_ids],
env.curr_ee_goal_cart[env_ids],
env.feet_phase[env_ids],
add_noise=False,
)
del n
return env._update_history(actor_raw, env_ids=env_ids, critic_raw_obs=critic_raw) # type: ignore[no-any-return]
[docs]
@registry.env("Go2ArmManipLoco", sim_backend="motrix")
@registry.env("Go2ArmManipLoco", sim_backend="mujoco")
class Go2ArmManipLocoEnv(Go2ArmBaseEnv):
_cfg: Go2ArmManipLocoCfg
[docs]
def __init__(self, cfg: Go2ArmManipLocoCfg, num_envs=1, backend_type="mujoco"):
if cfg.reward_config is None:
raise ValueError("reward_config must be provided via Hydra configuration")
if backend_type not in {"mujoco", "motrix"}:
raise ValueError(
"Go2ArmManipLoco supports only the mujoco and motrix backends, "
f"got {backend_type!r}"
)
scene = _resolve_go2_arm_scene(cfg)
backend_kwargs: dict[str, Any] = {
"base_name": cfg.asset.base_name,
"push_body_name": cfg.domain_rand.push_body_name,
}
if backend_type == "motrix":
backend_kwargs["motrix_max_iterations"] = cfg.iterations
else:
backend_kwargs["position_actuator_gains"] = build_go2_arm_position_gains(
cfg.control_config
)
backend_kwargs["iterations"] = cfg.iterations
backend_kwargs["post_step_forward_sensor"] = cfg.post_step_forward_sensor
backend = create_backend(
backend_type,
scene,
num_envs,
cfg.sim_dt,
**backend_kwargs,
)
super().__init__(cfg, backend, num_envs)
if self._num_action != 18:
raise ValueError(f"Go2ArmManipLoco expects 18 actuators, got {self._num_action}")
if not 0.0 <= cfg.commands.zero_command_prob <= 1.0:
raise ValueError(
"env.commands.zero_command_prob must be in [0, 1], "
f"got {cfg.commands.zero_command_prob}"
)
self._enable_reward_log = True
self._reward_cfg = cfg.reward_config
self._leg_pose_weights = np.array([1.0, 1.0, 0.1] * 4 + [0.0] * 6, dtype=get_global_dtype())
self._init_reward_functions()
self._init_ee_goal_buffers(num_envs)
self._current_ee_local_pos = np.zeros((num_envs, 3), dtype=get_global_dtype())
self.phase = np.zeros((num_envs,), dtype=np.float32)
self.feet_phase = np.zeros((num_envs, len(cfg.sensor.feet_force)), dtype=np.float32)
self.gait_frequency = 2.0
self.feet_force = np.zeros((num_envs, len(cfg.sensor.feet_force), 3), dtype=np.float32)
self.feet_pos = np.zeros((num_envs, len(cfg.sensor.feet_pos), 3), dtype=np.float32)
# Mid-episode command resampling. None disables periodic resampling.
if cfg.commands.resample_time_s is not None:
self._cmd_resample_steps: int | None = max(
1, int(cfg.commands.resample_time_s / cfg.ctrl_dt)
)
self._cmd_timer = np.random.randint(
0, self._cmd_resample_steps, size=(num_envs,), dtype=np.int32
)
else:
self._cmd_resample_steps = None
self._cmd_timer = np.zeros((num_envs,), dtype=np.int32)
# Per-env episode tracking_lin_vel accumulator for command curriculum.
self._episode_sum_tracking_vel = np.zeros(num_envs, dtype=np.float64)
self._episode_steps = np.zeros(num_envs, dtype=np.int32)
# History buffers.
# Actor obs excludes linvel (first 3 dims) to avoid bypassing the estimator.
# raw_obs layout: linvel(3)+gyro(3)+(-gravity)(3)+command(3)+feet_phase(4)+
# diff(18)+dof_vel(18)+ee_local_pos(3)+ee_goal_cart(3)+
# ee_error(3)+last_actions(18) = 79 dims.
_CRITIC_ONE = 79 # Single-step critic obs dim, including privileged linvel.
_ACTOR_ONE = 76 # Single-step actor obs dim after removing linvel[0:3].
H_a = cfg.history.num_actor_history
H_c = cfg.history.num_critic_history
self._actor_one_step_dim = _ACTOR_ONE
self._critic_one_step_dim = _CRITIC_ONE
self._history_obs_buf = np.zeros((num_envs, H_a * _ACTOR_ONE), dtype=get_global_dtype())
self._history_critic_buf = np.zeros((num_envs, H_c * _CRITIC_ONE), dtype=get_global_dtype())
base_kp: np.ndarray | None = None
base_kd: np.ndarray | None = None
if cfg.domain_rand.randomize_kp or cfg.domain_rand.randomize_kd:
base_kp, base_kd = backend.get_actuator_gains()
base_body_mass: np.ndarray | None = None
if cfg.domain_rand.randomize_body_mass:
base_body_mass = backend.get_body_mass()
base_geom_friction: np.ndarray | None = None
ground_geom_id: int | None = None
if cfg.domain_rand.randomize_ground_friction:
base_geom_friction = backend.get_geom_friction()
ground_geom_id = backend.get_geom_id(cfg.asset.ground)
base_dof_armature: np.ndarray | None = None
if cfg.domain_rand.randomize_dof_armature:
base_dof_armature = backend.get_dof_armature()
dr_provider = Go2ArmManipLocoDRProvider(
base_kp=base_kp,
base_kd=base_kd,
base_body_mass=base_body_mass,
base_geom_friction=base_geom_friction,
ground_geom_id=ground_geom_id,
base_dof_armature=base_dof_armature,
)
self._init_domain_randomization(dr_provider)
@property
def obs_groups_spec(self) -> dict[str, int]:
H_a = self._cfg.history.num_actor_history
H_c = self._cfg.history.num_critic_history
return {"obs": H_a * self._actor_one_step_dim, "critic": H_c * self._critic_one_step_dim}
def _init_ee_goal_buffers(self, num_envs: int) -> None:
dtype = get_global_dtype()
self.curr_ee_goal_cart = np.zeros((num_envs, 3), dtype=dtype)
self.curr_ee_goal_sphere = np.zeros((num_envs, 3), dtype=dtype)
self.ee_goal_orn_euler = np.zeros((num_envs, 3), dtype=dtype)
self.ee_goal_orn_quat = np.tile(
np.asarray([1.0, 0.0, 0.0, 0.0], dtype=dtype),
(num_envs, 1),
)
self.ee_goal_orn_delta_rpy = np.zeros((num_envs, 3), dtype=dtype)
# Goal position in world coordinates, used for render-time visualization.
self.curr_ee_goal_world = np.zeros((num_envs, 3), dtype=dtype)
self._ee_start_sphere = np.zeros((num_envs, 3), dtype=dtype)
self._ee_goal_sphere = np.zeros((num_envs, 3), dtype=dtype)
self._arm_goal_timer = np.zeros((num_envs,), dtype=np.int32)
self._traj_steps = np.ones((num_envs,), dtype=np.int32)
self._traj_total_steps = np.ones((num_envs,), dtype=np.int32)
def _sample_timing(self, env_ids: np.ndarray) -> None:
"""Sample movement and hold durations for env_ids."""
cfg = self._cfg.goal_ee
dt = self._cfg.ctrl_dt
traj_t = np.random.uniform(*cfg.traj_time_range, size=len(env_ids))
hold_t = np.random.uniform(*cfg.hold_time_range, size=len(env_ids))
traj_s = np.maximum(1, np.round(traj_t / dt).astype(np.int32))
hold_s = np.maximum(0, np.round(hold_t / dt).astype(np.int32))
self._traj_steps[env_ids] = traj_s
self._traj_total_steps[env_ids] = traj_s + hold_s
def _collision_check_sphere(self, starts: np.ndarray, goals: np.ndarray) -> np.ndarray:
"""Check spherical lerp paths for collisions after Cartesian conversion."""
cfg = self._cfg.goal_ee
dtype = get_global_dtype()
n = max(2, cfg.num_collision_check_samples)
t = np.linspace(0.0, 1.0, n, dtype=dtype) # (n,)
path_sphere = (
starts[:, None, :] + (goals - starts)[:, None, :] * t[None, :, None]
) # (N, n, 3)
path_cart = _sphere2cart(path_sphere.reshape(-1, 3)).reshape(len(starts), n, 3)
upper = np.asarray(cfg.collision_upper_limits, dtype=dtype)
lower = np.asarray(cfg.collision_lower_limits, dtype=dtype)
inside_collision_box = np.all(path_cart < upper, axis=2) & np.all(path_cart > lower, axis=2)
collision_mask = np.any(inside_collision_box, axis=1)
underground_mask = np.any(path_cart[..., 2] < float(cfg.underground_limit), axis=1)
return collision_mask | underground_mask
def _sample_goal_spheres(self, env_ids: np.ndarray, start_spheres: np.ndarray) -> None:
"""Sample goal spheres and write them into _ee_goal_sphere[env_ids]."""
cfg = self._cfg.goal_ee
dtype = get_global_dtype()
init_sphere = _cart2sphere(np.asarray(cfg.init_ee_cart, dtype=dtype)[None, :])[0]
candidates = np.broadcast_to(init_sphere, (len(env_ids), 3)).copy()
remaining = np.arange(len(env_ids), dtype=np.int32)
for _ in range(max(1, cfg.num_resample_attempts)):
l = np.random.uniform(*cfg.sphere_l_range, size=len(remaining)).astype(dtype)
phi = np.random.uniform(*cfg.sphere_phi_range, size=len(remaining)).astype(dtype)
theta = np.random.uniform(*cfg.sphere_theta_range, size=len(remaining)).astype(dtype)
new_goals = np.stack([l, phi, theta], axis=1)
candidates[remaining] = new_goals
unsafe = self._collision_check_sphere(start_spheres[remaining], new_goals)
remaining = remaining[unsafe]
if len(remaining) == 0:
break
self._ee_goal_sphere[env_ids] = candidates
def _sample_ee_goal_orn_delta(self, env_ids: np.ndarray, *, is_init: bool) -> None:
if len(env_ids) == 0:
return
if is_init:
self.ee_goal_orn_delta_rpy[env_ids] = 0.0
return
dtype = get_global_dtype()
ranges = (
self._cfg.goal_ee.delta_orn_r,
self._cfg.goal_ee.delta_orn_p,
self._cfg.goal_ee.delta_orn_y,
)
for axis, bounds in enumerate(ranges):
low_high = np.asarray(bounds, dtype=dtype)
if low_high.shape != (2,):
raise ValueError("goal_ee delta orientation ranges must have shape (2,)")
if low_high[1] < low_high[0]:
raise ValueError("goal_ee delta orientation range high must be >= low")
self.ee_goal_orn_delta_rpy[env_ids, axis] = np.random.uniform(
low=low_high[0],
high=low_high[1],
size=(len(env_ids),),
).astype(dtype)
def _update_curr_ee_goal_orientation(self, env_ids: np.ndarray) -> None:
if len(env_ids) == 0:
return
dtype = get_global_dtype()
goal_cfg = self._cfg.goal_ee
goal_local = self.curr_ee_goal_cart[env_ids]
goal_sphere = self.curr_ee_goal_sphere[env_ids]
delta = self.ee_goal_orn_delta_rpy[env_ids]
default_yaw = np.arctan2(goal_local[:, 1], goal_local[:, 0])
default_pitch = -goal_sphere[:, 1] + float(goal_cfg.arm_induced_pitch)
roll = float(goal_cfg.default_orn_roll) + delta[:, 0]
pitch = default_pitch + delta[:, 1]
yaw = default_yaw + delta[:, 2]
self.ee_goal_orn_euler[env_ids] = np.stack([roll, pitch, yaw], axis=1).astype(dtype)
self.ee_goal_orn_quat[env_ids] = np.atleast_2d(
np_quat_from_euler_xyz(roll, pitch, yaw)
).astype(dtype)
[docs]
def reset_ee_goals(self, env_ids: np.ndarray) -> None:
"""Reset EE goals by sampling the first segment from init_ee_cart."""
env_ids = np.asarray(env_ids, dtype=np.int32).reshape(-1)
if len(env_ids) == 0:
return
stage_cfg = self._cfg.arm_stage
if stage_cfg.disable_ee_goal_trajectory:
fixed_goal = np.asarray(stage_cfg.fixed_ee_goal_cart, dtype=get_global_dtype())
if fixed_goal.shape != (3,):
raise ValueError(
f"env.arm_stage.fixed_ee_goal_cart must have shape (3,), got {fixed_goal.shape}"
)
fixed_sphere = _cart2sphere(fixed_goal[None, :])[0]
self._ee_start_sphere[env_ids] = fixed_sphere
self._ee_goal_sphere[env_ids] = fixed_sphere
self._traj_steps[env_ids] = 1
self._traj_total_steps[env_ids] = 1
self._arm_goal_timer[env_ids] = 0
self.curr_ee_goal_cart[env_ids] = fixed_goal
self.curr_ee_goal_sphere[env_ids] = fixed_sphere
self._sample_ee_goal_orn_delta(env_ids, is_init=True)
self._update_curr_ee_goal_orientation(env_ids)
return
dtype = get_global_dtype()
init_sphere = _cart2sphere(
np.asarray(self._cfg.goal_ee.init_ee_cart, dtype=dtype)[None, :]
)[0]
self._ee_start_sphere[env_ids] = init_sphere
self._sample_goal_spheres(
env_ids,
np.broadcast_to(init_sphere, (len(env_ids), 3)).copy(),
)
self._sample_timing(env_ids)
self._arm_goal_timer[env_ids] = 0
self.curr_ee_goal_sphere[env_ids] = init_sphere
self.curr_ee_goal_cart[env_ids] = _sphere2cart(
np.broadcast_to(init_sphere, (len(env_ids), 3))
)
self._sample_ee_goal_orn_delta(env_ids, is_init=True)
self._update_curr_ee_goal_orientation(env_ids)
def _update_command_curriculum(self, env_ids: np.ndarray) -> None:
"""Update velocity command ranges at episode end from tracking_lin_vel.
This follows the go2_arx_robot.py rule:
mean(episode_sum[env_ids] / episode_steps[env_ids]) > threshold
The unweighted tracking_lin_vel maximum is 1.0 per step.
"""
cur = self._cfg.curriculum
if not cur.enable:
return
ep_steps = np.maximum(self._episode_steps[env_ids], 1)
mean_per_step = float(np.mean(self._episode_sum_tracking_vel[env_ids] / ep_steps))
if mean_per_step > cur.threshold:
step = np.asarray(cur.step_size, dtype=np.float64)
max_limit = np.asarray(cur.max_vel_limit, dtype=np.float64)
low = np.asarray(self._cfg.commands.vel_limit[0], dtype=np.float64)
high = np.asarray(self._cfg.commands.vel_limit[1], dtype=np.float64)
low = np.clip(low - step, -max_limit, 0.0)
high = np.clip(high + step, 0.0, max_limit)
self._cfg.commands.vel_limit = [low.tolist(), high.tolist()]
# Clear episode statistics for reset environments.
self._episode_sum_tracking_vel[env_ids] = 0.0
self._episode_steps[env_ids] = 0
# Command clipping threshold: small vx/vy/vyaw commands are zeroed out.
_CMD_CLIP: float = 0.1
def _command_is_moving(self, commands: np.ndarray) -> np.ndarray:
command_arr = np.asarray(commands)
return np.any(np.abs(command_arr[:, :3]) > self._CMD_CLIP, axis=1)
def _normalize_velocity_commands(self, commands: np.ndarray) -> np.ndarray:
normalized = np.asarray(commands, dtype=get_global_dtype()).copy()
normalized[~self._command_is_moving(normalized)] = 0.0
return normalized
def _postprocess_velocity_commands(self, commands: np.ndarray) -> np.ndarray:
processed = self._normalize_velocity_commands(commands)
prob = float(self._cfg.commands.zero_command_prob)
if prob <= 0.0 or processed.shape[0] == 0:
return processed
zero_mask = np.random.random(size=(processed.shape[0],)) < prob
processed[zero_mask] = 0.0
return processed
def _write_feet_phase(self, env_ids: np.ndarray | slice, is_moving: np.ndarray) -> None:
phase = self.phase[env_ids]
feet_phase = self.feet_phase[env_ids].copy()
feet_phase[:, 0] = phase
feet_phase[:, 3] = phase
feet_phase[:, 1] = (phase + 0.5) % 1.0
feet_phase[:, 2] = (phase + 0.5) % 1.0
feet_phase[~is_moving] = 0.0
self.feet_phase[env_ids] = feet_phase
def _resample_commands(self, env_ids: np.ndarray, info: dict) -> None:
"""Resample velocity commands and zero out small commands."""
if len(env_ids) == 0:
return
low = np.asarray(self._cfg.commands.vel_limit[0], dtype=get_global_dtype())
high = np.asarray(self._cfg.commands.vel_limit[1], dtype=get_global_dtype())
new_cmds = np.random.uniform(low=low, high=high, size=(len(env_ids), 3)).astype(
get_global_dtype()
)
new_cmds = self._postprocess_velocity_commands(new_cmds)
if "commands" in info:
info["commands"][env_ids] = new_cmds
[docs]
def apply_action(self, actions: np.ndarray, state: NpEnvState) -> np.ndarray:
state.info["last_actions"] = state.info.get("current_actions", np.zeros_like(actions))
stage_cfg = self._cfg.arm_stage
if stage_cfg.freeze_arm_joints:
effective_actions = actions.copy()
effective_actions[:, 12:18] = 0.0
else:
effective_actions = actions
state.info["current_actions"] = effective_actions
exec_actions = (
state.info["last_actions"]
if self._cfg.control_config.simulate_action_latency
else effective_actions
)
ee_local_pos, ee_local_quat = self.get_ee_local_pose()
dq_ik = self.compute_arm_ik_delta(
self.curr_ee_goal_cart,
ee_local_pos,
self.ee_goal_orn_quat,
ee_local_quat,
)
leg_ctrl = (
exec_actions[:, :12] * self._cfg.control_config.action_scale + self.default_angles[:12]
)
if stage_cfg.freeze_arm_joints:
arm_ctrl = np.broadcast_to(self.default_angles[12:18], (self._num_envs, 6)).astype(
get_global_dtype(),
copy=False,
)
else:
arm_ctrl = (
self.get_arm_dof_pos()
+ exec_actions[:, 12:18] * self._cfg.control_config.arm_action_scale
+ self._cfg.ik.gain * dq_ik
)
ctrl = np.concatenate([leg_ctrl, arm_ctrl], axis=1, dtype=get_global_dtype())
return np.clip(ctrl, self.action_space.low, self.action_space.high)
def _init_reward_functions(self) -> None:
self._reward_fns: dict[str, Any] = {
# Tracking rewards.
"tracking_lin_vel": rewards.tracking_lin_vel,
"tracking_ang_vel": rewards.tracking_ang_vel,
# Velocity and orientation penalties.
"lin_vel_z": rewards.lin_vel_z,
"ang_vel_xy": rewards.ang_vel_xy,
"roll": rewards.roll, # Requires ctx.gravity.
# Height and joint-pose terms.
"base_height": rewards.base_height,
"similar_to_default": rewards.similar_to_default, # Aligns with Go2 Joystick.
"leg_pose": rewards.weighted_pose, # Weighted leg L2 term.
"dof_pos_limits": self._reward_dof_pos_limits, # Leg soft limits.
# Action and effort penalties.
"action_rate": rewards.action_rate,
"torques": rewards.torques, # L1 torque over all 18 DOFs.
"energy": rewards.energy, # Requires ctx.dof_vel and info["torques"].
"dof_vel": self._reward_dof_vel, # L2 velocity over all 18 DOFs.
"dof_acc": rewards.dof_acc, # Requires info["qacc"].
# Standing penalty.
"stand_still": self._reward_stand_still, # Penalizes leg pose at zero command.
# Survival.
"alive": rewards.alive,
# Gait terms.
"swing_feet_z": self._reward_swing_feet_z,
"foot_drag": self._reward_foot_drag,
"contact": self._reward_contact,
# Manipulation rewards.
"object_distance": self._reward_object_distance,
"object_distance_l2": self._reward_object_distance_l2,
# Arm collision penalty.
"arm_collision": self._reward_arm_collision,
}
[docs]
def update_state(self, state: NpEnvState) -> NpEnvState:
# Mid-episode command resampling, enabled only when resample_time_s is set.
if self._cmd_resample_steps is not None:
self._cmd_timer += 1
resample_ids = np.where(self._cmd_timer >= self._cmd_resample_steps)[0].astype(np.int32)
if len(resample_ids) > 0:
self._resample_commands(resample_ids, state.info)
self._cmd_timer[resample_ids] = 0
# Gait phase update: zero commands reset the phase to a full-stance pattern.
# This gives contact a four-feet contact target and naturally disables swing_feet_z.
cmd = state.info.get("commands", np.zeros((self._num_envs, 3), dtype=np.float32))
is_moving = self._command_is_moving(cmd)
advanced = np.fmod(self.phase + self._cfg.ctrl_dt * self.gait_frequency, 1.0)
self.phase = np.where(is_moving, advanced, 0.0)
self._write_feet_phase(slice(None), is_moving)
# EE goal trajectory update.
stage_cfg = self._cfg.arm_stage
if stage_cfg.disable_ee_goal_trajectory:
fixed_goal = np.asarray(stage_cfg.fixed_ee_goal_cart, dtype=get_global_dtype())
if fixed_goal.shape != (3,):
raise ValueError(
f"env.arm_stage.fixed_ee_goal_cart must have shape (3,), got {fixed_goal.shape}"
)
self.curr_ee_goal_cart[:] = fixed_goal
self.curr_ee_goal_sphere[:] = _cart2sphere(fixed_goal[None, :])[0]
else:
self._arm_goal_timer += 1
expired = np.where(self._arm_goal_timer >= self._traj_total_steps)[0].astype(np.int32)
if len(expired) > 0:
self._ee_start_sphere[expired] = self._ee_goal_sphere[expired].copy()
self._sample_goal_spheres(expired, self._ee_start_sphere[expired])
self._sample_ee_goal_orn_delta(expired, is_init=False)
self._sample_timing(expired)
self._arm_goal_timer[expired] = 0
# Spherical interpolation, updated every step.
t_frac = np.clip(self._arm_goal_timer / self._traj_steps, 0.0, 1.0).astype(
get_global_dtype()
)[:, None] # (num_envs, 1)
curr_sphere = (
self._ee_start_sphere + (self._ee_goal_sphere - self._ee_start_sphere) * t_frac
)
self.curr_ee_goal_sphere[:] = curr_sphere
self.curr_ee_goal_cart[:] = _sphere2cart(curr_sphere)
self._update_curr_ee_goal_orientation(np.arange(self._num_envs, dtype=np.int32))
# Compute the world-space goal position for render-time visualization.
ab_pos = self._backend.get_sensor_data("armbasepoint_world_pos") # (N, 3)
ab_quat = self._backend.get_sensor_data("armbasepoint_world_quat") # (N, 4)
R = np_matrix_from_quat(ab_quat) # (N, 3, 3)
self.curr_ee_goal_world[:] = ab_pos + np.einsum("nij,nj->ni", R, self.curr_ee_goal_cart)
linvel = self.get_local_linvel()
gyro = self.get_gyro()
gravity = self._backend.get_sensor_data("upvector")
dof_pos = self.get_dof_pos()
dof_vel = self.get_dof_vel()
ee_local_pos, _ = self.get_ee_local_pose()
self._current_ee_local_pos = ee_local_pos
self.feet_force[:, :, :] = 0
for i, sensor_name in enumerate(self._cfg.sensor.feet_force):
self.feet_force[:, i, :] = self._backend.get_sensor_data(sensor_name)
for i, sensor_name in enumerate(self._cfg.sensor.feet_pos):
self.feet_pos[:, i, :] = self._backend.get_sensor_data(sensor_name)
terminated = gravity[:, 2] <= 0.5
reward = self._compute_reward(
state.info, linvel, gyro, gravity, dof_pos, dof_vel, ee_local_pos
)
obs = self._compute_obs(
state.info,
linvel,
gyro,
gravity,
dof_pos,
dof_vel,
ee_local_pos,
self.curr_ee_goal_cart,
self.feet_phase,
)
return state.replace(obs=obs, reward=reward, terminated=terminated)
def _compute_raw_obs(
self,
info: dict,
linvel: np.ndarray,
gyro: np.ndarray,
gravity: np.ndarray,
dof_pos: np.ndarray,
dof_vel: np.ndarray,
ee_local_pos: np.ndarray,
ee_goal_cart: np.ndarray,
feet_phase: np.ndarray,
*,
add_noise: bool = True,
) -> np.ndarray:
"""Compute single-step 79-dim obs (no history).
Layout: linvel(3)+gyro(3)+(-gravity)(3)+command(3)+feet_phase(4)+
diff(18)+dof_vel(18)+ee_local_pos(3)+ee_goal_cart(3)+ee_error(3)+
last_actions(18) = 79
"""
diff = dof_pos - self.default_angles
if add_noise:
noise_cfg = self._cfg.noise_config
linvel = self._obs_noise(linvel, noise_cfg.scale_linvel)
gyro = self._obs_noise(gyro, noise_cfg.scale_gyro)
gravity = self._obs_noise(gravity, noise_cfg.scale_gravity)
diff = self._obs_noise(diff, noise_cfg.scale_joint_angle)
dof_vel = self._obs_noise(dof_vel, noise_cfg.scale_joint_vel)
ee_local_pos = self._obs_noise(ee_local_pos, noise_cfg.scale_ee_pos)
n = len(dof_pos)
command = info["commands"] if info["commands"].shape[0] == n else info["commands"][:n]
last_actions = info.get(
"current_actions", np.zeros((n, self._num_action), dtype=get_global_dtype())
)
ee_error = ee_local_pos - ee_goal_cart
return np.concatenate(
[
linvel, # 3
gyro, # 3
-gravity, # 3
command, # 3
feet_phase, # 4
diff, # 18
dof_vel, # 18
ee_local_pos, # 3
ee_goal_cart, # 3
ee_error, # 3
last_actions, # 18
],
axis=1,
dtype=get_global_dtype(),
)
def _update_history(
self,
raw_obs: np.ndarray,
env_ids: np.ndarray | None = None,
*,
critic_raw_obs: np.ndarray | None = None,
) -> dict[str, np.ndarray]:
"""Update history buffers and return obs dict (with or without env_ids slice).
Actor buffer stores obs WITHOUT linvel (raw_obs[:, 3:], 76-dim) so the actor
cannot shortcut the estimator. Critic buffer stores clean full 79-dim obs
when a separate critic_raw_obs is provided.
"""
A = self._actor_one_step_dim # 76
C = self._critic_one_step_dim # 79
H_a = self._cfg.history.num_actor_history
H_c = self._cfg.history.num_critic_history
actor_step = raw_obs[:, 3:] if raw_obs.ndim == 2 else raw_obs[3:]
critic_step = raw_obs if critic_raw_obs is None else critic_raw_obs
if env_ids is None:
if H_a > 1:
self._history_obs_buf = np.roll(self._history_obs_buf, -A, axis=1)
self._history_obs_buf[:, -A:] = actor_step
if H_c > 1:
self._history_critic_buf = np.roll(self._history_critic_buf, -C, axis=1)
self._history_critic_buf[:, -C:] = critic_step
return {
"obs": self._history_obs_buf.copy(),
"critic": self._history_critic_buf.copy(),
}
else:
if H_a > 1:
self._history_obs_buf[env_ids] = np.roll(self._history_obs_buf[env_ids], -A, axis=1)
self._history_obs_buf[env_ids, -A:] = actor_step
if H_c > 1:
self._history_critic_buf[env_ids] = np.roll(
self._history_critic_buf[env_ids], -C, axis=1
)
self._history_critic_buf[env_ids, -C:] = critic_step
return {
"obs": self._history_obs_buf[env_ids].copy(),
"critic": self._history_critic_buf[env_ids].copy(),
}
def _compute_obs(
self,
info: dict,
linvel: np.ndarray,
gyro: np.ndarray,
gravity: np.ndarray,
dof_pos: np.ndarray,
dof_vel: np.ndarray,
ee_local_pos: np.ndarray,
ee_goal_cart: np.ndarray,
feet_phase: np.ndarray,
) -> dict[str, np.ndarray]:
actor_raw = self._compute_raw_obs(
info,
linvel,
gyro,
gravity,
dof_pos,
dof_vel,
ee_local_pos,
ee_goal_cart,
feet_phase,
add_noise=True,
)
critic_raw = self._compute_raw_obs(
info,
linvel,
gyro,
gravity,
dof_pos,
dof_vel,
ee_local_pos,
ee_goal_cart,
feet_phase,
add_noise=False,
)
return self._update_history(actor_raw, critic_raw_obs=critic_raw)
def _compute_reward(
self,
info: dict,
linvel: np.ndarray,
gyro: np.ndarray,
gravity: np.ndarray,
dof_pos: np.ndarray,
dof_vel: np.ndarray,
ee_local_pos: np.ndarray,
) -> np.ndarray:
dtype = get_global_dtype()
reward = np.zeros((self._num_envs,), dtype=dtype)
cfg = self._reward_cfg
self._current_ee_local_pos = ee_local_pos
ctx = RewardContext(
info=info,
linvel=linvel,
gyro=gyro,
gravity=gravity,
dof_pos=dof_pos,
dof_vel=dof_vel,
num_envs=self._num_envs,
default_angles=self.default_angles,
tracking_sigma=cfg.tracking_sigma,
base_height_target=cfg.base_height_target,
base_height=self._backend.get_base_pos()[:, 2],
pose_weights=self._leg_pose_weights,
)
step_count = info.get("steps", np.zeros((self._num_envs,), dtype=np.uint32))
should_log = self._enable_reward_log and (int(step_count[0]) % 4 == 0)
log = {} if should_log else info.get("log", {})
for name, scale in cfg.scales.items():
if scale == 0 or name not in self._reward_fns:
continue
rew = self._reward_fns[name](ctx)
weighted_rew = rew * scale
reward += weighted_rew
if name == "tracking_lin_vel":
self._episode_sum_tracking_vel += rew.astype(np.float64)
if should_log:
log[f"reward/{name}"] = float(np.mean(weighted_rew))
self._episode_steps += 1
info["log"] = log
return reward * self._cfg.ctrl_dt
def _reward_swing_feet_z(self, _ctx: RewardContext) -> np.ndarray:
is_swing = self.feet_phase >= 0.6
height_error = np.square(self.feet_pos[:, :, 2] - self._reward_cfg.target_foot_height)
swing_rew = np.exp(-height_error / 0.01) * is_swing
reward: np.ndarray = np.sum(swing_rew, axis=1) / len(self._cfg.sensor.feet_pos)
return reward
def _reward_foot_drag(self, _ctx: RewardContext) -> np.ndarray:
foot_heights = self.feet_pos[..., 2]
foot_contact = self.get_foot_contact()
is_swing = foot_contact < 0.5
safe_height = self._reward_cfg.target_foot_height / 2.0
height_error = np.clip(safe_height - foot_heights, 0.0, None)
error = np.square(height_error) * is_swing
drag_penalty: np.ndarray = np.sum(error, axis=1)
return drag_penalty
def _reward_contact(self, _ctx: RewardContext) -> np.ndarray:
contact = self.feet_force[:, :, 2] > 0.1
res = np.zeros(self._num_envs, dtype=np.float32)
for i in range(len(self._cfg.sensor.feet_force)):
is_contact = (self.feet_phase[:, i] < 0.6) | (self.gait_frequency < 1.0e-8)
res += (contact[:, i] == is_contact).astype(np.float32)
return res / len(self._cfg.sensor.feet_force)
def _reward_object_distance(self, _ctx: RewardContext) -> np.ndarray:
dis_err = np.sum(
np.square(self._current_ee_local_pos - self.curr_ee_goal_cart),
axis=1,
)
return np.exp(-dis_err / self._reward_cfg.object_sigma) # type: ignore[no-any-return]
def _reward_object_distance_l2(self, _ctx: RewardContext) -> np.ndarray:
return np.sum(
np.square(self._current_ee_local_pos - self.curr_ee_goal_cart),
axis=1,
)
def _reward_stand_still(self, ctx: RewardContext) -> np.ndarray:
"""Penalize leg deviation from the default pose when command is near zero."""
commands = ctx.info["commands"]
is_still = (~self._command_is_moving(commands)).astype(get_global_dtype())
assert ctx.dof_pos is not None
dof_error = np.sum(np.abs(ctx.dof_pos[:, :12] - ctx.default_angles[:12]), axis=1)
return is_still * dof_error
def _reward_dof_vel(self, ctx: RewardContext) -> np.ndarray:
"""L2 velocity penalty over all 18 joints."""
assert ctx.dof_vel is not None
return np.sum(np.square(ctx.dof_vel), axis=1) # type: ignore[no-any-return]
def _reward_dof_pos_limits(self, ctx: RewardContext) -> np.ndarray:
"""Leg soft-limit penalty configured through reward_config limits."""
cfg = self._reward_cfg
if not cfg.leg_dof_upper_limits or not cfg.leg_dof_lower_limits:
return np.zeros(self._num_envs, dtype=get_global_dtype())
dtype = get_global_dtype()
upper = np.asarray(cfg.leg_dof_upper_limits, dtype=dtype)
lower = np.asarray(cfg.leg_dof_lower_limits, dtype=dtype)
m = cfg.dof_pos_limit_margin
leg_pos = ctx.dof_pos[:, :12]
over = np.square(np.maximum(leg_pos - upper + m, 0.0))
under = np.square(np.maximum(lower + m - leg_pos, 0.0))
return np.sum(over + under, axis=1) # type: ignore[no-any-return]
_ARM_TOUCH_SENSORS = (
"arm_touch_base",
"arm_touch_link1",
"arm_touch_link2",
"arm_touch_link3",
"arm_touch_link4",
"arm_touch_link5",
"arm_touch_link6",
"arm_touch_eef",
"arm_touch_g2base",
)
def _reward_arm_collision(self, _ctx: RewardContext) -> np.ndarray:
"""Sum arm-link contact forces. The scale should be negative."""
total = np.zeros(self._num_envs, dtype=get_global_dtype())
for name in self._ARM_TOUCH_SENSORS:
total += self._backend.get_sensor_data(name)[:, 0]
return total