Source code for unilab.envs.locomotion.go2w.rough

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import numpy as np

from unilab.assets import ASSETS_ROOT_PATH
from unilab.base import registry
from unilab.base.np_env import NpEnvState
from unilab.base.scene import SceneCfg, TerrainSceneCfg
from unilab.dr import DomainRandomizationManager, ResetPlan
from unilab.dr.dr_utils import zero_actions
from unilab.dtype_config import get_global_dtype
from unilab.envs.common.rotation import (
    np_quat_from_euler_xyz,
    np_quat_mul,
)
from unilab.envs.locomotion.common import rewards
from unilab.envs.locomotion.common.commands import (
    Commands,
    apply_heading_yaw_feedback,
    zero_small_xy_commands,
)
from unilab.envs.locomotion.common.height_scan import (
    HeightScanConfig,
    base_height_from_scan,
    height_scan_obs,
    init_height_scan_sensor,
    raw_height_scan_obs,
    terrain_out_of_bounds,
)
from unilab.envs.locomotion.common.rewards import RewardContext
from unilab.envs.locomotion.common.terrain_spawn import (
    TerrainCurriculumCfg,
    TerrainSpawnManager,
)
from unilab.envs.locomotion.go2w.base import NUM_GO2W_ACTIONS, NUM_LEG_ACTIONS
from unilab.envs.locomotion.go2w.joystick import (
    Go2WJoystickCfg,
    Go2WJoystickDomainRandomizationProvider,
    Go2WJoystickEnv,
    build_go2w_backend_reset_randomization,
    sample_go2w_heading_commands,
)
from unilab.terrains import (
    SubTerrainCfg,
    TerrainGeneratorCfg,
    flat,
    hf_pyramid_slope,
    hf_pyramid_slope_inv,
    pyramid_stairs,
    pyramid_stairs_inv,
    random_rough,
    wave_terrain,
)

# pyright: reportIncompatibleVariableOverride=false, reportAttributeAccessIssue=false, reportCallIssue=false


@dataclass
class Go2WRoughCommands(Commands):
    """Command-sampling settings for the Go2W rough-terrain task.

    Overrides the defaults inherited from ``Commands`` with the values used by
    the rough-terrain environment.
    """

    # Row 0 holds the lower bound and row 1 the upper bound of the three
    # command channels; the environment samples uniformly between the rows.
    vel_limit: list[list[float]] = field(default_factory=lambda: [[-1.0] * 3, [1.0] * 3])
    # Interval between command resamples; the environment divides it by
    # ``ctrl_dt`` to obtain a step count. Values <= 0 disable resampling.
    resampling_time: float = 10.0
    # When true the environment tracks a heading target and applies heading
    # feedback to the commands on every update.
    heading_command: bool = True
    # [min, max] heading bounds -- presumably consumed by the heading sampler
    # imported from the joystick module; TODO confirm.
    heading_range: list[float] = field(default_factory=lambda: [-np.pi, np.pi])
@dataclass
class RoughTerminationConfig:
    """Episode-ending options specific to the rough-terrain task."""

    # Enables the terrain out-of-bounds check that is OR-ed into the
    # truncation mask by the environment.
    terrain_out_of_bounds: bool = True
    # Passed (as a float) to the out-of-bounds helper as its distance buffer.
    # NOTE(review): units are presumably metres -- confirm against the helper.
    terrain_distance_buffer: float = 3.0
def _go2w_rough_sub_terrains() -> dict[str, SubTerrainCfg]:
    """Build the default sub-terrain table for ``Go2WRoughTerrainCfg``.

    A fresh dict is returned on every call so config instances never share
    the mapping. Insertion order matches the original literal.
    """
    # The upright and inverted variants of each family share their settings.
    stairs_kwargs = {
        "proportion": 0.1,
        "step_height_range": (0.025, 0.10),
        "step_width": 0.4,
        "platform_width": 3.0,
        "border_width": 0.2,
    }
    slope_kwargs = {
        "proportion": 0.2,
        "slope_range": (0.0, 0.3),
        "platform_width": 2.0,
        "border_width": 0.2,
    }
    # NOTE(review): the proportions below sum to 1.2 (0.0 + 2*0.1 + 2*0.2 +
    # 2*0.3); presumably the generator normalises them -- confirm.
    return {
        "flat": flat(proportion=0.0),
        "pyramid_stairs": pyramid_stairs(**stairs_kwargs),
        "pyramid_stairs_inv": pyramid_stairs_inv(**stairs_kwargs),
        "hf_pyramid_slope": hf_pyramid_slope(**slope_kwargs),
        "hf_pyramid_slope_inv": hf_pyramid_slope_inv(**slope_kwargs),
        "random_rough": random_rough(
            proportion=0.3,
            noise_range=(0.01, 0.06),
            noise_step=0.01,
            border_width=0.2,
        ),
        "wave_terrain": wave_terrain(
            proportion=0.3,
            amplitude_range=(0.0, 0.12),
            num_waves=4,
            border_width=0.2,
        ),
    }


@dataclass(kw_only=True)
class Go2WRoughTerrainCfg(TerrainGeneratorCfg):
    """Terrain-generator defaults for the Go2W rough-terrain task.

    Describes a 6 x 6 grid of 8 x 8 cells populated from a mix of stairs,
    slopes, random roughness and wave sub-terrains.
    """

    # Per-cell extent; the environment uses ``size[0]`` as the spawn cell size.
    size: tuple[float, float] = (8.0, 8.0)
    num_rows: int = 6
    num_cols: int = 6
    border_width: float = 1.0
    add_lights: bool = True
    horizontal_scale: float = 0.1
    # Mapping from sub-terrain name to its configuration.
    sub_terrains: dict[str, SubTerrainCfg] = field(default_factory=_go2w_rough_sub_terrains)
def _go2w_rough_scene() -> SceneCfg:
    """Create the default scene: Go2W model, task fragment and generated terrain."""
    go2w_dir = ASSETS_ROOT_PATH / "robots" / "go2w"
    terrain = TerrainSceneCfg(
        generator=Go2WRoughTerrainCfg(),
        hfield_name="terrain_hfield",
        geom_name="floor",
    )
    return SceneCfg(
        model_file=str(go2w_dir / "go2w.xml"),
        fragment_files=[str(go2w_dir / "locomotion_task.xml")],
        terrain=terrain,
    )


@registry.envcfg("Go2WJoystickRough")
@dataclass
class Go2WJoystickRoughCfg(Go2WJoystickCfg):
    """Go2W rough terrain task with procedurally generated sub-terrains."""

    # Scene with the Go2W robot and a height-field terrain built from
    # ``Go2WRoughTerrainCfg``.
    scene: SceneCfg = field(default_factory=_go2w_rough_scene)
    # Command sampling / heading settings.
    commands: Go2WRoughCommands = field(default_factory=Go2WRoughCommands)
    # Height-scan sensor settings; the scan is appended to the critic observation.
    terrain_scan: HeightScanConfig = field(default_factory=HeightScanConfig)
    # Rough-terrain specific truncation options.
    termination_config: RoughTerminationConfig = field(default_factory=RoughTerminationConfig)
    # Settings forwarded to the terrain spawn manager.
    terrain_curriculum: TerrainCurriculumCfg = field(default_factory=TerrainCurriculumCfg)
class Go2WJoystickRoughDomainRandomizationProvider(Go2WJoystickDomainRandomizationProvider):
    """Reset-time randomization for the Go2W rough-terrain environment.

    Extends the joystick provider with terrain-aware spawn positions, a fully
    random initial orientation and rough-terrain command post-processing.
    """

    def _sample_commands(self, env: Any, num_reset: int) -> np.ndarray:
        """Sample reset commands and apply the rough-terrain post-processing.

        Starts from the parent provider's samples, then (in order) zeroes
        small xy commands, zeroes the whole command for a random subset of
        environments and, when heading commands are enabled, zeroes column 2.

        Args:
            env: Environment being reset; only ``env.cfg.commands`` is read here.
            num_reset: Number of environments being reset.

        Returns:
            The post-processed command array produced by the parent sampler.
        """
        sampled = super()._sample_commands(env, num_reset)
        zero_small_xy_commands(sampled, threshold=0.08)

        cmd_cfg = env.cfg.commands
        stand_prob = cmd_cfg.rel_standing_envs
        if stand_prob > 0.0:
            # Each environment stands still with probability ``stand_prob``
            # (capped at 1.0).
            draws = np.random.uniform(size=(num_reset,))
            sampled[draws < min(stand_prob, 1.0)] = 0.0

        if cmd_cfg.heading_command:
            # Column 2 is cleared at reset; the environment later rewrites the
            # commands via heading feedback (presumably the yaw channel).
            sampled[:, 2] = 0.0
        return sampled

    def build_reset_plan(self, env: Any, env_ids: np.ndarray) -> ResetPlan:
        """Assemble the reset plan for the given environments.

        Args:
            env: Environment being reset.
            env_ids: Indices of the environments to reset.

        Returns:
            A ``ResetPlan`` carrying randomized ``qpos``/``qvel``, the per-env
            info updates and the backend randomization payload.
        """
        count = len(env_ids)
        reset_qpos = np.tile(env._init_qpos, (count, 1))
        reset_qvel = np.tile(env._init_qvel, (count, 1))

        # Base position: xy jitter, extra z offset, then the per-env terrain origin.
        reset_qpos[:, :2] += np.random.uniform(-0.5, 0.5, (count, 2))
        reset_qpos[:, 2] += np.random.uniform(0.25, 0.5, (count,))
        reset_qpos[:, :3] += env._spawn.origins_for(env_ids)

        # Base orientation: every Euler angle drawn from [-3.14, 3.14]
        # (presumably radians, i.e. roughly the full circle), drawn in the
        # order roll, pitch, yaw.
        roll, pitch, yaw = (np.random.uniform(-3.14, 3.14, (count,)) for _ in range(3))
        random_quat = np_quat_from_euler_xyz(roll, pitch, yaw)
        reset_qpos[:, 3:7] = np_quat_mul(reset_qpos[:, 3:7], random_quat)

        # First six velocity entries get a uniform perturbation.
        dtype = get_global_dtype()
        reset_qvel[:, :6] = np.asarray(np.random.uniform(-0.5, 0.5, size=(count, 6)), dtype=dtype)

        kp, kd = env.sample_reset_motor_gains(count)
        env.set_motor_gains(env_ids, kp, kd)

        updates: dict[str, Any] = {
            "commands": self._sample_commands(env, count),
            "current_actions": zero_actions(count, env._num_action),
            "last_actions": zero_actions(count, env._num_action),
            "motor_kp": kp.astype(dtype),
            "motor_kd": kd.astype(dtype),
            "torques": np.zeros((count, env._num_action), dtype=dtype),
        }
        if getattr(env.cfg.commands, "heading_command", False):
            updates["heading_commands"] = sample_go2w_heading_commands(env, count)

        backend_randomization = build_go2w_backend_reset_randomization(env, count)
        return ResetPlan(
            env_ids=env_ids,
            qpos=reset_qpos,
            qvel=reset_qvel,
            info_updates=updates,
            randomization=backend_randomization,
        )
@registry.env("Go2WJoystickRough", sim_backend="mujoco")
class Go2WJoystickRoughEnv(Go2WJoystickEnv):
    """Go2W joystick environment on procedurally generated rough terrain.

    Adds a terrain spawn manager, a height-scan sensor that feeds the critic
    observation, upright-gated rewards, heading-based command updates and a
    terrain out-of-bounds truncation on top of ``Go2WJoystickEnv``.
    """

    _cfg: Go2WJoystickRoughCfg
    # Extra critic-observation width contributed by the height scan.
    # NOTE(review): presumably overwritten by ``init_height_scan_sensor`` -- confirm.
    _height_scan_dim: int = 0

    def __init__(
        self,
        cfg: Go2WJoystickRoughCfg,
        num_envs: int = 1,
        backend_type: str = "mujoco",
    ):
        """Create the environment and attach the rough-terrain components.

        Args:
            cfg: Rough-terrain task configuration.
            num_envs: Number of parallel environments.
            backend_type: Simulation backend identifier.
        """
        super().__init__(cfg, num_envs=num_envs, backend_type=backend_type)

        terrain_origins = getattr(self._backend, "terrain_origins", None)
        terrain_generator = cfg.scene.terrain.generator if cfg.scene.terrain is not None else None
        # NOTE(review): ``_spawn`` is only assigned here when the backend
        # exposes terrain origins and a generator is configured, yet the reset
        # provider reads ``env._spawn`` unconditionally -- confirm the base
        # class supplies a fallback.
        if terrain_origins is not None and terrain_generator is not None:
            self._spawn = TerrainSpawnManager(
                num_envs,
                terrain_origins,
                cell_size=float(terrain_generator.size[0]),
                cfg=cfg.terrain_curriculum,
                terrain_surface_sampler=getattr(self._backend, "terrain_surface_sampler", None),
            )

        self._dr_manager = DomainRandomizationManager(
            self, Go2WJoystickRoughDomainRandomizationProvider()
        )
        init_height_scan_sensor(self, cfg.terrain_scan, cfg.asset.base_name)

    @property
    def obs_groups_spec(self) -> dict[str, int]:
        """Return the width of each observation group.

        The critic group is the 56-wide base vector plus the height scan.
        """
        return {"obs": 53, "critic": 56 + self._height_scan_dim}

    def _init_reward_functions(self) -> None:
        """Populate ``self._reward_fns`` with the reward terms of this task.

        Most terms are multiplied by the upright scale computed from the
        gravity vector; ``alive``, ``upward`` and ``action_rate`` are not.
        """

        def gated(fn):
            # Wrap ``fn`` so its output is scaled by the upright factor.
            return lambda ctx: fn(ctx) * self._upright_scale(ctx.gravity)

        def _joint_pos_penalty(ctx: RewardContext) -> np.ndarray:
            return self._reward_joint_pos_penalty(ctx) * self._upright_scale(ctx.gravity)

        def _stand_still(ctx: RewardContext) -> np.ndarray:
            return self._reward_stand_still(ctx) * self._upright_scale(ctx.gravity)

        self._reward_fns = {
            "tracking_lin_vel": gated(rewards.tracking_lin_vel),
            "tracking_ang_vel": gated(rewards.tracking_ang_vel),
            "lin_vel_z": gated(rewards.lin_vel_z),
            "ang_vel_xy": gated(rewards.ang_vel_xy),
            "base_height": gated(rewards.base_height),
            "orientation": gated(rewards.orientation),
            "similar_to_default": gated(rewards.similar_to_default),
            "torques": gated(self._reward_torques_l2),
            "joint_torques_l2": gated(self._reward_joint_torques_l2),
            "energy": gated(rewards.energy),
            "dof_vel": gated(self._reward_dof_vel),
            "dof_acc": gated(self._reward_dof_acc),
            "joint_acc_l2": gated(self._reward_dof_acc),
            "wheel_acc": gated(self._reward_wheel_acc),
            "joint_acc_wheel_l2": gated(self._reward_wheel_acc),
            "stand_still": _stand_still,
            "hip_pos": gated(self._reward_hip_pos),
            "dof_error": gated(self._reward_dof_error),
            "joint_pos_penalty": _joint_pos_penalty,
            "joint_power": gated(self._reward_joint_power),
            "joint_mirror": gated(self._reward_joint_mirror),
            "alive": rewards.alive,
            "upward": rewards.upward,
            "wheel_vel": gated(self._reward_wheel_vel),
            "action_rate": rewards.action_rate,
        }

    def _upright_scale(self, gravity: np.ndarray | None) -> np.ndarray:
        """Return the per-env upright scaling factor for ``gravity``."""
        return rewards.upright_scale(gravity, self._num_envs)

    def _compute_obs(
        self,
        info: dict,
        linvel: np.ndarray,
        gyro: np.ndarray,
        gravity: np.ndarray,
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
    ) -> dict[str, np.ndarray]:
        """Build the policy (``obs``) and critic observation arrays.

        The policy observation uses noised and scaled signals and omits the
        linear velocity. The critic observation uses the raw signals, adds the
        linear velocity and appends the height scan.

        Args:
            info: Per-env info dict; ``commands`` is required and
                ``current_actions`` is optional (zeros are used if missing).
            linvel: Base linear velocity, one row per environment.
            gyro: Base angular velocity, one row per environment.
            gravity: Gravity vector signal; it is negated before use.
            dof_pos: Joint positions.
            dof_vel: Joint velocities.

        Returns:
            Dict with the ``obs`` and ``critic`` arrays, concatenated along axis 1.
        """
        noise_cfg = self._cfg.noise_config
        # Only the leg joints contribute a position error; the remaining
        # joints (presumably the wheels) appear through their velocities only.
        leg_diff = dof_pos[:, :NUM_LEG_ACTIONS] - self.default_angles[:NUM_LEG_ACTIONS]

        policy_gyro = self._obs_noise(gyro, noise_cfg.scale_gyro) * 0.25
        policy_gravity = self._obs_noise(-gravity, noise_cfg.scale_gravity)
        policy_leg_diff = self._obs_noise(leg_diff, noise_cfg.scale_joint_angle)
        policy_dof_vel = self._obs_noise(dof_vel, noise_cfg.scale_joint_vel) * 0.05

        num_obs = gyro.shape[0]
        last_actions = info.get(
            "current_actions", np.zeros((num_obs, NUM_GO2W_ACTIONS), dtype=dof_pos.dtype)
        )
        commands = info["commands"]

        obs = np.concatenate(
            [
                policy_gyro,
                policy_gravity,
                commands,
                policy_leg_diff,
                policy_dof_vel,
                last_actions,
            ],
            axis=1,
            dtype=get_global_dtype(),
        )
        critic_base = np.concatenate(
            [linvel, gyro, -gravity, commands, leg_diff, dof_vel, last_actions],
            axis=1,
            dtype=get_global_dtype(),
        )
        critic = np.concatenate(
            [critic_base, height_scan_obs(self, self._cfg.terrain_scan, num_obs)],
            axis=1,
            dtype=get_global_dtype(),
        )
        return {"obs": obs, "critic": critic}

    def _reward_base_height_values(self, num_obs: int) -> np.ndarray:
        """Return base heights from the height scan, else the parent's values.

        The parent implementation is used when the scan does not yield one
        value per observation row.
        """
        height = base_height_from_scan(self, num_obs)
        if height.shape[0] != num_obs:
            return super()._reward_base_height_values(num_obs)
        return height

    def _update_commands(self, info: dict) -> None:
        """Resample commands on schedule and apply heading feedback.

        Environments whose step counter is a positive multiple of the
        resampling interval get a new uniformly sampled command (and a new
        heading target when heading commands are enabled). Afterwards, when
        heading commands are enabled, heading feedback is applied to all
        commands. The result is stored back into ``info["commands"]``.

        Args:
            info: Per-env info dict; does nothing if it has no ``commands``.
        """
        commands = info.get("commands")
        if commands is None:
            return
        commands_arr = np.asarray(commands, dtype=get_global_dtype())

        resampling_time = float(self._cfg.commands.resampling_time)
        if resampling_time > 0.0:
            interval_steps = max(int(round(resampling_time / self._cfg.ctrl_dt)), 1)
            steps = np.asarray(info.get("steps", np.zeros((self._num_envs,), dtype=np.uint32)))
            # Step 0 is excluded because reset already sampled a command.
            resample_mask = (steps > 0) & ((steps % interval_steps) == 0)
            if np.any(resample_mask):
                num_resample = int(np.count_nonzero(resample_mask))
                low = np.asarray(self._cfg.commands.vel_limit[0], dtype=get_global_dtype())
                high = np.asarray(self._cfg.commands.vel_limit[1], dtype=get_global_dtype())
                sampled = np.random.uniform(low=low, high=high, size=(num_resample, 3)).astype(
                    get_global_dtype()
                )
                # Filter the freshly drawn commands before writing them.
                # Filtering the existing buffer instead would let small xy
                # commands through and touch rows that are not being resampled.
                zero_small_xy_commands(sampled, threshold=0.08)
                commands_arr[resample_mask] = sampled
                if self._cfg.commands.heading_command:
                    heading_commands = self._ensure_heading_commands(info, commands_arr.shape[0])
                    heading_commands[resample_mask] = sample_go2w_heading_commands(
                        self, num_resample
                    )
                    info["heading_commands"] = heading_commands

        if self._cfg.commands.heading_command:
            heading_commands = self._ensure_heading_commands(info, commands_arr.shape[0])
            base_quat = np.asarray(self._backend.get_base_quat(), dtype=get_global_dtype())
            # Skip the feedback when the backend does not return one
            # quaternion per command row.
            if base_quat.shape[0] == commands_arr.shape[0]:
                apply_heading_yaw_feedback(commands_arr, base_quat, heading_commands, stiffness=0.5)

        info["commands"] = commands_arr

    def _compute_terminated(self, gravity: np.ndarray) -> np.ndarray:
        """Return an all-False mask: this task never terminates on orientation.

        ``gravity`` is accepted for interface compatibility and ignored.
        """
        del gravity
        return np.zeros((self._num_envs,), dtype=bool)

    def _raw_height_scan_obs(self, num_obs: int) -> tuple[np.ndarray | None, np.ndarray | None]:
        """Forward to the shared ``raw_height_scan_obs`` helper for this env."""
        return raw_height_scan_obs(self, num_obs)

    def _compute_truncated(self, state: NpEnvState) -> np.ndarray:
        """Return the parent's truncation mask OR-ed with terrain out-of-bounds.

        The out-of-bounds check runs only when enabled in
        ``termination_config`` and is merged into the parent's mask in place.
        """
        truncated = super()._compute_truncated(state)
        if self._cfg.termination_config.terrain_out_of_bounds:
            terrain_scene = self._cfg.scene.terrain
            terrain_cfg = terrain_scene.generator if terrain_scene is not None else None
            np.logical_or(
                truncated,
                terrain_out_of_bounds(
                    self,
                    terrain_cfg,
                    float(self._cfg.termination_config.terrain_distance_buffer),
                ),
                out=truncated,
            )
        return truncated
# The class is registered for the "mujoco" backend through its @registry.env
# decorator; this call registers the same class under the same task name for
# the "motrix" backend.
registry.register_env("Go2WJoystickRough", Go2WJoystickRoughEnv, sim_backend="motrix")