Source code for unilab.envs.locomotion.go2.joystick

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, cast

import numpy as np

from unilab.assets import ASSETS_ROOT_PATH
from unilab.base import registry
from unilab.base.backend import create_backend
from unilab.base.np_env import NpEnvState
from unilab.base.scene import SceneCfg
from unilab.dtype_config import get_global_dtype
from unilab.envs.locomotion.common import rewards
from unilab.envs.locomotion.common.base import Sensor
from unilab.envs.locomotion.common.commands import Commands
from unilab.envs.locomotion.common.domain_rand import DomainRandConfig
from unilab.envs.locomotion.common.dr_provider import LocomotionDRProvider
from unilab.envs.locomotion.common.rewards import RewardContext
from unilab.envs.locomotion.common.terrain_spawn import (
    TerrainCurriculumCfg,
    TerrainSpawnManager,
)
from unilab.envs.locomotion.go2.base import Go2BaseCfg, Go2BaseEnv


@dataclass
class InitState:
    """Initial-state settings referenced by ``Go2JoystickCfg.init_state``.

    Attributes:
        pos: Three-element position; presumably the base ``[x, y, z]`` used at
            spawn time (the consumer is not visible here -- confirm).
    """

    # Declared as an annotated field with a default factory so that it is a
    # real dataclass field: each instance gets its own list (mutating one
    # config can no longer leak into others) and ``InitState(pos=...)`` works.
    pos: list[float] = field(default_factory=lambda: [0.0, 0.0, 0.42])
@dataclass
class Go2DomainRandConfig(DomainRandConfig):
    """Domain-randomization options for Go2 on top of ``DomainRandConfig``.

    Adds an on/off switch and a two-element multiplier range for each of the
    ``kp`` and ``kd`` gains.
    """

    # kp gain randomization; the range is presumably [low, high] multipliers
    # applied to the nominal gain -- confirm against the DR provider.
    randomize_kp: bool = True
    kp_multiplier_range: list[float] = field(
        default_factory=lambda: [0.9, 1.1],
    )

    # kd gain randomization, same layout as the kp settings above.
    randomize_kd: bool = True
    kd_multiplier_range: list[float] = field(
        default_factory=lambda: [0.9, 1.1],
    )
@dataclass
class RewardConfig:
    """Reward settings consumed by the joystick task.

    Attributes:
        scales: Per-term scale values keyed by reward name; handed to the
            reward dispatcher together with the table of reward functions.
        tracking_sigma: Forwarded to ``RewardContext.tracking_sigma``.
        base_height_target: Forwarded to ``RewardContext.base_height_target``.
        target_foot_height: Foot-height setting; half of this value is the
            height threshold used by the foot-drag reward.
    """

    # Required settings (no defaults) -- must be supplied by the caller.
    scales: dict[str, float]
    tracking_sigma: float
    base_height_target: float

    # Optional setting.
    target_foot_height: float = 0.1
@dataclass
class JoystickSensor(Sensor):
    """Sensor names read by the joystick task.

    The two per-foot lists are index-aligned with the task's per-foot arrays
    (foot ``i`` of ``feet_force`` / ``feet_pos`` fills slot ``i``).
    """

    # NOTE(review): these attributes carry no type annotations, so this class
    # itself does not declare them as dataclass fields; whether they are
    # fields at all depends on the ``Sensor`` base, which is not visible here.
    local_linvel = "local_linvel"
    gyro = "gyro"

    # One contact sensor name per foot.
    feet_force = [
        "FL_foot_contact",
        "FR_foot_contact",
        "RL_foot_contact",
        "RR_foot_contact",
    ]

    # One position sensor name per foot, in the same foot order as above.
    feet_pos = [
        "FL_pos",
        "FR_pos",
        "RL_pos",
        "RR_pos",
    ]
@registry.envcfg("Go2JoystickFlat")
@dataclass
class Go2JoystickCfg(Go2BaseCfg):
    """Configuration registered under the name ``"Go2JoystickFlat"``.

    Attributes:
        scene: Scene description; defaults to the Go2 ``scene_flat.xml`` model
            under ``ASSETS_ROOT_PATH``.
        max_episode_seconds: Episode length setting (20.0 by default).
        init_state: Initial-state settings.
        commands: Command settings.
        reward_config: Reward settings. Defaults to ``None``, but
            ``Go2WalkTask`` raises ``ValueError`` if it is still ``None`` at
            construction time, so it has to be provided by the configuration.
        sensor: Names of the sensors the task reads.
        domain_rand: Domain-randomization settings.
        terrain_curriculum: Settings passed to the terrain spawn manager.
    """

    scene: SceneCfg = field(
        default_factory=lambda: SceneCfg(
            model_file=str(ASSETS_ROOT_PATH / "robots" / "go2" / "scene_flat.xml"),
        ),
    )
    max_episode_seconds: float = 20.0

    init_state: InitState = field(default_factory=InitState)
    commands: Commands = field(default_factory=Commands)

    # Intentionally optional at the dataclass level; validated by the task.
    reward_config: RewardConfig | None = None

    sensor: JoystickSensor = field(default_factory=JoystickSensor)
    domain_rand: Go2DomainRandConfig = field(default_factory=Go2DomainRandConfig)
    terrain_curriculum: TerrainCurriculumCfg = field(
        default_factory=TerrainCurriculumCfg,
    )
class Go2JoystickDomainRandomizationProvider(LocomotionDRProvider):
    """DR provider that builds reset observations via the env's ``_compute_obs``."""

    def _compute_reset_obs(
        self,
        env: Any,
        env_ids: Any,
        info_updates: Any,
        linvel: Any,
        gyro: Any,
        gravity: Any,
        dof_pos: Any,
        dof_vel: Any,
    ) -> dict[str, np.ndarray]:
        """Compute observations for the environments selected by ``env_ids``.

        Args:
            env: Environment exposing ``feet_phase`` and ``_compute_obs``.
            env_ids: Index used to select rows of ``env.feet_phase``.
            info_updates: Passed to ``_compute_obs`` as its ``info`` argument.
            linvel: Forwarded unchanged to ``_compute_obs``.
            gyro: Forwarded unchanged to ``_compute_obs``.
            gravity: Forwarded unchanged to ``_compute_obs``.
            dof_pos: Forwarded unchanged to ``_compute_obs``.
            dof_vel: Forwarded unchanged to ``_compute_obs``.

        Returns:
            The observation dictionary returned by ``env._compute_obs``.
        """
        # Only the selected environments' phase rows go into the observation.
        selected_feet_phase = env.feet_phase[env_ids]
        observations = env._compute_obs(
            info_updates,
            linvel,
            gyro,
            gravity,
            dof_pos,
            dof_vel,
            selected_feet_phase,
        )
        # ``cast`` is for the type checker only; nothing is converted at runtime.
        return cast(dict[str, np.ndarray], observations)
@registry.env("Go2JoystickFlat", sim_backend="mujoco")
@registry.env("Go2JoystickFlat", sim_backend="motrix")
class Go2WalkTask(Go2BaseEnv):
    """Go2 joystick task with a phase-based gait signal.

    Registered as ``"Go2JoystickFlat"`` for both the ``mujoco`` and ``motrix``
    simulation backends. The environment keeps a per-environment gait phase,
    derives a per-foot phase from it, and uses that phase both as an
    observation and in the foot-related reward terms.
    """

    _cfg: Go2JoystickCfg

    def __init__(
        self,
        cfg: Go2JoystickCfg,
        num_envs: int = 1,
        backend_type: str = "mujoco",
    ):
        """Create the backend, reward table, DR provider and per-foot buffers.

        Args:
            cfg: Task configuration; ``cfg.reward_config`` must not be ``None``.
            num_envs: Number of environments.
            backend_type: Backend name passed to ``create_backend``.

        Raises:
            ValueError: If ``cfg.reward_config`` is ``None``.
        """
        if cfg.reward_config is None:
            raise ValueError("reward_config must be provided via Hydra configuration")

        self._scene_terrain_origins: np.ndarray | None = None
        scene_cfg = cfg.scene
        terrain_generator = (
            scene_cfg.terrain.generator if scene_cfg.terrain is not None else None
        )

        backend = create_backend(
            backend_type,
            cfg.scene,
            num_envs,
            cfg.sim_dt,
            base_name=cfg.asset.base_name,
            push_body_name=cfg.domain_rand.push_body_name,
            position_actuator_gains={
                "kp": cfg.control_config.Kp,
                "kd": cfg.control_config.Kd,
            },
            motrix_max_iterations=cfg.motrix_max_iterations,
            post_step_forward_sensor=cfg.post_step_forward_sensor,
        )

        # Terrain helpers are optional backend attributes.
        # NOTE(review): they are resolved before super().__init__(), presumably
        # because base-class initialisation may already need them -- confirm.
        self._terrain_surface_sampler = getattr(backend, "terrain_surface_sampler", None)
        self._terrain_surface_sample_height = self._resolve_terrain_surface_sample_height()
        terrain_origins = getattr(backend, "terrain_origins", None)
        if terrain_origins is not None:
            self._scene_terrain_origins = terrain_origins

        super().__init__(cfg, backend, num_envs)

        self._enable_reward_log = True
        self._reward_cfg = cfg.reward_config
        self._init_reward_functions()
        self._init_domain_randomization(Go2JoystickDomainRandomizationProvider())

        # NOTE(review): ``_spawn`` is only assigned here when terrain data is
        # available, yet ``update_state`` always uses it; presumably the base
        # class installs a default spawn manager -- confirm.
        if self._scene_terrain_origins is not None and terrain_generator is not None:
            self._spawn = TerrainSpawnManager(
                num_envs,
                self._scene_terrain_origins,
                cell_size=float(terrain_generator.size[0]),
                cfg=cfg.terrain_curriculum,
                terrain_surface_sampler=self._terrain_surface_sampler,
            )

        num_force_feet = len(cfg.sensor.feet_force)
        num_pos_feet = len(cfg.sensor.feet_pos)
        self.phase = np.zeros((num_envs,), dtype=np.float32)
        self.feet_phase = np.zeros((num_envs, num_force_feet), dtype=np.float32)
        self.gait_frequency = 2
        self.feet_force = np.zeros((num_envs, num_force_feet, 3), dtype=np.float32)
        self.feet_pos = np.zeros((num_envs, num_pos_feet, 3), dtype=np.float32)

    def get_playback_model(self, env_index: int | None = None) -> Any:
        """Return the playback model from the base class implementation."""
        return super().get_playback_model(env_index)

    def _resolve_terrain_surface_sample_height(
        self,
    ) -> Callable[[np.ndarray], np.ndarray] | None:
        """Return the sampler's ``sample_height`` callable, if a sampler exists.

        Returns:
            ``None`` when no terrain surface sampler is set, otherwise the
            sampler's ``sample_height`` attribute.

        Raises:
            TypeError: If a sampler is set but has no callable ``sample_height``.
        """
        sampler = self._terrain_surface_sampler
        if sampler is None:
            return None
        sample_height = getattr(sampler, "sample_height", None)
        if not callable(sample_height):
            raise TypeError("terrain_surface_sampler must expose sample_height(xy)")
        return cast(Callable[[np.ndarray], np.ndarray], sample_height)

    @property
    def obs_groups_spec(self) -> dict[str, int]:
        """Return the width of each observation group."""
        # gyro(3) + gravity(3) + diff(12) + dof_vel(12) + action(12) + cmd(3) + phase(4) = 49
        # critic = the same 49 entries + linvel(3) = 52
        return {"obs": 49, "critic": 52}

    def _init_reward_functions(self) -> None:
        """Build the reward-name -> callable table used by the reward dispatcher."""
        self._reward_fns: dict[str, Any] = {
            # Shared locomotion rewards.
            "tracking_lin_vel": rewards.tracking_lin_vel,
            "tracking_ang_vel": rewards.tracking_ang_vel,
            "lin_vel_z": rewards.lin_vel_z,
            "ang_vel_xy": rewards.ang_vel_xy,
            "base_height": rewards.base_height,
            "action_rate": rewards.action_rate,
            "similar_to_default": rewards.similar_to_default,
            "alive": rewards.alive,
            # Rewards implemented on this class.
            "swing_feet_z": self._reward_swing_feet_z,
            "contact": self._reward_contact,
            "foot_drag": self._reward_foot_drag,
        }

    def update_state(self, state: NpEnvState) -> NpEnvState:
        """Advance the gait phase, refresh sensors and fill in the new state.

        Args:
            state: Current environment state; ``state.info`` is read for
                commands/actions and may receive ``"log"`` entries.

        Returns:
            The state with ``obs``, ``reward`` and ``terminated`` replaced.
        """
        # Advance the per-environment gait phase and wrap it into [0, 1).
        self.phase = np.fmod(self.phase + self._cfg.ctrl_dt * self.gait_frequency, 1.0)

        # Feet 0 and 3 follow the base phase; feet 1 and 2 are half a cycle
        # behind. With the default sensor order (FL, FR, RL, RR) these are the
        # two diagonal pairs.
        shifted_phase = (self.phase + 0.5) % 1
        self.feet_phase[:, 0] = self.phase
        self.feet_phase[:, 3] = self.phase
        self.feet_phase[:, 1] = shifted_phase
        self.feet_phase[:, 2] = shifted_phase

        linvel = self.get_local_linvel()
        gyro = self.get_gyro()
        gravity = self._backend.get_sensor_data("upvector")
        dof_pos = self.get_dof_pos()
        dof_vel = self.get_dof_vel()

        # Refresh the cached per-foot sensor readings used by the rewards.
        sensor_cfg = self._cfg.sensor
        self.feet_force[:, :, :] = 0
        for i, sensor_name in enumerate(sensor_cfg.feet_force):
            self.feet_force[:, i, :] = self._backend.get_sensor_data(sensor_name)
        for i, sensor_name in enumerate(sensor_cfg.feet_pos):
            self.feet_pos[:, i, :] = self._backend.get_sensor_data(sensor_name)

        # Terminate when the z-component of the "upvector" sensor drops to 0.5
        # or below (presumably the base has tipped too far -- confirm).
        terminated = gravity[:, 2] <= 0.5

        reward = self._compute_reward(state.info, linvel, gyro, dof_pos)
        obs = self._compute_obs(
            state.info, linvel, gyro, gravity, dof_pos, dof_vel, self.feet_phase
        )
        state = state.replace(obs=obs, reward=reward, terminated=terminated)

        # Report finished environments to the spawn manager and log its stats.
        done = state.terminated | state.truncated
        if np.any(done):
            done_indices = np.where(done)[0]
            stats = self._spawn.update_on_done(
                done_indices, self._backend.get_base_pos()[done_indices]
            )
            if stats:
                if "log" not in state.info:
                    state.info["log"] = {}
                for key, value in stats.items():
                    state.info["log"][f"terrain_curriculum/{key}"] = float(value)
        return state

    def _compute_obs(
        self,
        info: dict,
        linvel: np.ndarray,
        gyro: np.ndarray,
        gravity: np.ndarray,
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        feet_phase: np.ndarray,
    ) -> dict[str, np.ndarray]:
        """Assemble the ``obs`` and ``critic`` observation arrays.

        ``obs`` uses the readings after ``self._obs_noise`` has been applied
        with the configured scales; ``critic`` uses the raw readings and
        additionally appends ``linvel``. The gravity reading is negated in
        both. Actions come from ``info["current_actions"]`` (zeros if absent)
        and commands from ``info["commands"]``.

        Returns:
            ``{"obs": ..., "critic": ...}`` concatenated along axis 1 in the
            global dtype.
        """
        noise_cfg = self._cfg.noise_config
        diff = dof_pos - self.default_angles

        noisy_gyro = self._obs_noise(gyro, noise_cfg.scale_gyro)
        noisy_gravity = self._obs_noise(gravity, noise_cfg.scale_gravity)
        noisy_diff = self._obs_noise(diff, noise_cfg.scale_joint_angle)
        noisy_dof_vel = self._obs_noise(dof_vel, noise_cfg.scale_joint_vel)

        command = info["commands"]
        last_actions = info.get("current_actions", np.zeros_like(diff))

        obs = np.concatenate(
            [
                noisy_gyro,
                -noisy_gravity,
                noisy_diff,
                noisy_dof_vel,
                last_actions,
                command,
                feet_phase,
            ],
            axis=1,
            dtype=get_global_dtype(),
        )
        critic = np.concatenate(
            [gyro, -gravity, diff, dof_vel, last_actions, command, feet_phase, linvel],
            axis=1,
            dtype=get_global_dtype(),
        )
        return {"obs": obs, "critic": critic}

    def _compute_reward(
        self,
        info: dict,
        linvel: np.ndarray,
        gyro: np.ndarray,
        dof_pos: np.ndarray,
    ) -> np.ndarray:
        """Build a ``RewardContext`` and run the reward dispatcher on it."""
        cfg = self._reward_cfg
        ctx = RewardContext(
            info=info,
            linvel=linvel,
            gyro=gyro,
            dof_pos=dof_pos,
            num_envs=self._num_envs,
            default_angles=self.default_angles,
            tracking_sigma=cfg.tracking_sigma,
            base_height_target=cfg.base_height_target,
            base_height=self._reward_base_height_values(),
        )
        return rewards.run_reward_dispatch(
            scales=cfg.scales,
            fns=self._reward_fns,
            ctx=ctx,
            info=info,
            enable_log=self._enable_reward_log,
            ctrl_dt=self._cfg.ctrl_dt,
        )

    # ── reward functions (robot-specific) ────────────────────────────

    def _reward_base_height_values(self) -> np.ndarray:
        """Return the base height, relative to the terrain surface if available.

        Without a terrain height sampler this is the z-coordinate of the base
        position; with one, the sampled surface height at the base xy position
        is subtracted.
        """
        base_pos = np.asarray(self._backend.get_base_pos(), dtype=get_global_dtype())
        sample_height = self._terrain_surface_sample_height
        if sample_height is None:
            return np.asarray(base_pos[:, 2], dtype=get_global_dtype())
        surface = np.asarray(sample_height(base_pos[:, :2]), dtype=get_global_dtype())
        return np.asarray(base_pos[:, 2] - surface, dtype=get_global_dtype())

    def _reward_swing_feet_z(self, ctx: RewardContext) -> np.ndarray:
        """Reward feet whose z-position is near the target while their phase >= 0.6.

        Args:
            ctx: Reward context (unused; cached foot data on ``self`` is read).

        Returns:
            One value per environment: the mean over feet of
            ``exp(-(z - target)^2 / 0.01)`` masked to feet with phase >= 0.6.
        """
        is_swing = self.feet_phase >= 0.6
        # Take the target from the reward config (default 0.1) instead of a
        # hard-coded 0.1, so this term stays in step with the foot-drag term,
        # which is derived from the same setting.
        target_height = self._reward_cfg.target_foot_height
        height_error = np.square(self.feet_pos[:, :, 2] - target_height)
        swing_rew = np.exp(-height_error / 0.01) * is_swing
        reward: np.ndarray = np.sum(swing_rew, axis=1) / len(self._cfg.sensor.feet_pos)
        return reward

    def _reward_foot_drag(self, ctx: RewardContext) -> np.ndarray:
        """Sum the squared height shortfall of feet whose contact value is < 0.5.

        The shortfall is measured against half of
        ``reward_config.target_foot_height``; feet above that height
        contribute zero.

        Args:
            ctx: Reward context (unused).

        Returns:
            One value per environment.
        """
        foot_heights = self.get_foot_pos()[..., 2]
        is_swing = self.get_foot_contact() < 0.5
        safe_height = self._reward_cfg.target_foot_height / 2.0
        # Only feet below the safe height produce a non-zero error.
        height_error = np.clip(safe_height - foot_heights, 0.0, None)
        drag_penalty: np.ndarray = np.sum(np.square(height_error) * is_swing, axis=1)
        return drag_penalty

    def _reward_contact(self, ctx: RewardContext) -> np.ndarray:
        """Return the fraction of feet whose contact state matches the gait phase.

        A foot counts as in contact when the z-component of its cached force
        reading exceeds 0.1. Contact is expected while the foot's phase is
        below 0.6, or always when the gait frequency is effectively zero.

        Args:
            ctx: Reward context (unused).

        Returns:
            One float32 value per environment in ``[0, 1]``.
        """
        contact = self.feet_force[:, :, 2] > 0.1
        expected_contact = (self.feet_phase < 0.6) | (self.gait_frequency < 1.0e-8)
        matches = (contact == expected_contact).astype(np.float32)
        fraction: np.ndarray = np.sum(matches, axis=1) / len(self._cfg.sensor.feet_force)
        return fraction