"""Base configs and environment for the Sharpa in-hand task (unilab.envs.manipulation.sharpa_inhand.base)."""

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Sequence, cast

import gymnasium as gym
import numpy as np

from unilab.assets import ASSETS_ROOT_PATH
from unilab.base.backend import SimBackend
from unilab.base.base import EnvCfg
from unilab.base.np_env import NpEnv, NpEnvState
from unilab.base.scene import SceneCfg
from unilab.dtype_config import get_global_dtype
from unilab.envs.common.rotation import np_quat_apply, np_quat_mul

# Default value of ``SharpaInhandBaseCfg.actuated_joint_names`` (22 right-hand joints,
# thumb -> pinky).
# NOTE(review): the env slices the first ``num_hand_dofs`` actuators / dof columns from
# the backend by index, so this order presumably must match the XML actuator order —
# confirm against scene.xml.
DEFAULT_ACTUATED_JOINT_NAMES: list[str] = [
    "right_thumb_CMC_FE",
    "right_thumb_CMC_AA",
    "right_thumb_MCP_FE",
    "right_thumb_MCP_AA",
    "right_thumb_IP",
    "right_index_MCP_FE",
    "right_index_MCP_AA",
    "right_index_PIP",
    "right_index_DIP",
    "right_middle_MCP_FE",
    "right_middle_MCP_AA",
    "right_middle_PIP",
    "right_middle_DIP",
    "right_ring_MCP_FE",
    "right_ring_MCP_AA",
    "right_ring_PIP",
    "right_ring_DIP",
    "right_pinky_CMC",
    "right_pinky_MCP_FE",
    "right_pinky_MCP_AA",
    "right_pinky_PIP",
    "right_pinky_DIP",
]

# Default value of ``SharpaInhandBaseCfg.fingertip_body_names``: one distal-phalanx body
# per finger. The env allocates one tactile channel per configured fingertip body.
DEFAULT_FINGERTIP_BODY_NAMES: list[str] = [
    "right_thumb_DP",
    "right_index_DP",
    "right_middle_DP",
    "right_ring_DP",
    "right_pinky_DP",
]

# Source parity anchor from sharpa-rl-lab:
# rl_isaaclab/tasks/inhand_rotate/sharpa_wave_env_cfg.py (hand init_state joint_pos).
# Values are in degrees, one per hand DoF. The env converts them with ``np.deg2rad``
# and raises if the count differs from ``num_hand_dofs``.
SOURCE_DEFAULT_HAND_JOINT_POS_DEG: tuple[float, ...] = (
    95.12771,
    -3.11244,
    14.81626,
    -1.03493,
    12.23986,
    65.21091,
    6.1133,
    15.58495,
    5.90325,
    31.74149,
    -0.95812,
    41.88173,
    12.844,
    31.72383,
    9.84458,
    35.22366,
    18.02839,
    10.9712,
    68.30895,
    7.99151,
    5.89626,
    5.89875,
)


@dataclass
class SharpaControlConfig:
    """Action-scaling and PD-control settings for the Sharpa hand."""

    # Multiplier applied to clipped policy actions before they are added to the
    # previous joint-position targets (delta-position control).
    action_scale: float = 1.0 / 24.0
    # MuJoCo Sharpa loads PD defaults from XML actuator gains. These fields remain
    # fallback defaults for backends that cannot expose actuator gains yet.
    p_gain: float = 1.0
    d_gain: float = 0.1
    # MuJoCo Sharpa currently uses position actuators, so torque-control mode
    # stays declared here for owner-config structure but is rejected at runtime.
    torque_control: bool = False
    # Multiplicative scale applied to both raw actuator control bounds to obtain
    # the target-position clipping limits (must be positive).
    dof_limits_scale: float = 0.9
@dataclass
class SharpaSensorConfig:
    """Sensor names the Sharpa env reads from the backend."""

    # Per-finger tactile force sensors, read in list order. Only the first
    # ``len(fingertip_body_names)`` entries are used. An empty list yields
    # all-zero tactile force.
    tactile_force_sensor_names: list[str] = field(default_factory=list)
@dataclass
class SharpaObservationConfig:
    """Observation-construction settings, mostly for the tactile channel."""

    # NOTE(review): not read in the visible base env; presumably selects how
    # observation groups are laid out — confirm with the consumer.
    observation_mode: str = "separated"
    # When False, tactile observations are all zeros and tactile history is cleared.
    enable_tactile: bool = True
    # When True, smoothed forces are thresholded into 0/1 contacts.
    binary_contact: bool = False
    # NOTE(review): contact-position observations are currently always zeros in the base env.
    enable_contact_pos: bool = False
    # Weight of the current reading in the force EMA:
    # ``smooth = current * contact_smooth + previous * (1 - contact_smooth)``.
    contact_smooth: float = 0.5
    # Threshold applied to the smoothed force when ``binary_contact`` is enabled.
    contact_threshold: float = 0.05
    # Upper clamp for raw tactile force magnitudes; a value <= 0 disables clamping.
    tactile_force_clip_max: float = 4.0
@dataclass
class SharpaPrivilegedInfoConfig:
    """Flags selecting extra privileged (critic-only) information."""

    # NOTE(review): neither flag is read in the visible base env; presumably they
    # toggle entries in ``critic_info_buf`` — confirm with the consumer.
    include_friction_scale: bool = True
    include_gravity_direction: bool = False
@dataclass
class SharpaDomainRandConfig:
    """Domain-randomization settings for the Sharpa in-hand task.

    Only ``scale_list``, ``contact_latency`` and ``contact_sensor_noise`` are read
    by the visible base env. The remaining fields are presumably consumed by
    subclasses or reset logic outside this view.
    """

    # Object scales. Envs are split into near-equal contiguous buckets, one per scale.
    scale_list: list[float] = field(default_factory=lambda: [0.5])
    randomize_base_mass: bool = False
    added_mass_range: list[float] = field(default_factory=lambda: [0.0, 0.0])
    random_com: bool = False
    com_offset_x: list[float] = field(default_factory=lambda: [0.0, 0.0])
    randomize_gravity: bool = False
    gravity_range: list[list[float]] = field(
        default_factory=lambda: [[0.0, 0.0, -9.81], [0.0, 0.0, -9.81]]
    )
    randomize_gravity_direction: bool = False
    gravity_direction_magnitude: float = 9.81
    # PD gain scale bounds; the lower bound must be <= 1 and the upper bound >= 1.
    randomize_pd_gains: bool = True
    randomize_p_gain_scale_lower: float = 0.5
    randomize_p_gain_scale_upper: float = 2.0
    randomize_d_gain_scale_lower: float = 0.5
    randomize_d_gain_scale_upper: float = 2.0
    randomize_friction: bool = True
    randomize_friction_scale_lower: float = 0.5
    randomize_friction_scale_upper: float = 2.0
    elastomer_base_friction: float = 1.6
    metal_base_friction: float = 0.2
    object_base_friction: float = 1.0
    randomize_com: bool = True
    randomize_com_lower: float = -0.01
    randomize_com_upper: float = 0.01
    randomize_mass: bool = True
    randomize_mass_lower: float = 0.01
    randomize_mass_upper: float = 0.25
    force_scale: float = 2.0
    random_force_prob_scalar: float = 0.25
    force_decay: float = 0.9
    force_decay_interval: float = 0.08
    joint_noise_scale: float = 0.02
    # Per-step, per-channel probability that the tactile output keeps its
    # previous value instead of the new reading (simulated sensor latency).
    # This is a probability, not a duration in seconds.
    contact_latency: float = 0.005
    # In binary-contact mode, probability that an active contact is dropped.
    contact_sensor_noise: float = 0.01
    push_body_name: str | None = None
@dataclass
class SharpaInhandBaseCfg(EnvCfg):
    """Top-level configuration for the Sharpa in-hand manipulation base env."""

    scene: SceneCfg = field(
        default_factory=lambda: SceneCfg(
            model_file=str(ASSETS_ROOT_PATH / "robots" / "sharpa_wave" / "scene.xml")
        )
    )
    max_episode_seconds: float = 20.0
    # Control period is 12x the simulation step.
    sim_dt: float = 1.0 / 240.0
    ctrl_dt: float = 12.0 / 240.0
    action_space: int = 22
    # NOTE(review): 192 == obs_lag_steps * frame_obs_dim with the defaults —
    # presumably intentional; confirm.
    observation_space: int = 192
    # Length of the proprioceptive history buffer, in frames.
    prop_hist_len: int = 30
    critic_info_dim: int = 8
    clip_obs: float = 5.0
    # Actions are clipped to [-clip_actions, clip_actions].
    clip_actions: float = 1.0
    # Number of leading actuators / dof columns treated as the hand.
    num_hand_dofs: int = 22
    # Width of one observation frame in the history buffers.
    frame_obs_dim: int = 64
    obs_lag_steps: int = 3
    # Length of the lagged observation history buffer, in frames.
    obs_history_len: int = 80
    base_name: str = "right_hand_C_MC"
    object_body_name: str = "object"
    object_geom_name: str = "object"
    actuated_joint_names: list[str] = field(
        default_factory=lambda: list(DEFAULT_ACTUATED_JOINT_NAMES)
    )
    # One tactile channel is allocated per fingertip body.
    fingertip_body_names: list[str] = field(
        default_factory=lambda: list(DEFAULT_FINGERTIP_BODY_NAMES)
    )
    control_config: SharpaControlConfig = field(default_factory=SharpaControlConfig)
    sensor: SharpaSensorConfig = field(default_factory=SharpaSensorConfig)  # type: ignore[assignment]
    obs: SharpaObservationConfig = field(default_factory=SharpaObservationConfig)
    priv_info: SharpaPrivilegedInfoConfig = field(default_factory=SharpaPrivilegedInfoConfig)
    domain_rand: SharpaDomainRandConfig = field(default_factory=SharpaDomainRandConfig)
    reset_height_lower: float = 0.59906
    reset_height_upper: float = 0.63906
    reset_angle_diff: float = 45.0 / 180.0 * np.pi
    rot_axis: tuple[float, float, float] = (0.0, 0.0, 1.0)
    # Prefix or ``{scale}`` template; resolved per object scale into a ``.npy`` file.
    grasp_cache_path: str = str(ASSETS_ROOT_PATH / "caches" / "sharpa_grasp_linspace")
    # Tactile channel indices forced to zero in observations (out-of-range ids are ignored).
    disable_tactile_ids: list[int] = field(default_factory=list)
    # Match the reference Sharpa object-position reward/privileged-info anchor
    # by using the fixed XML/default object pose instead of the sampled grasp reset.
    use_default_object_pose_for_object_pos_anchor: bool = False
    debug_show_axes: bool = False
def format_scale_tag(scale_value: float) -> str:
    """Convert one object scale into a stable cache filename tag.

    The tag uses the ``g`` format (six significant digits, trailing zeros
    dropped), e.g. ``0.5 -> "0.5"`` and ``1.0 -> "1"``. NOTE: scales that agree
    to six significant digits map to the same tag.

    Args:
        scale_value: Single object scale value.

    Returns:
        Scale tag used in cache filenames.

    Raises:
        ValueError: If ``scale_value`` is NaN, infinite, or not positive.
    """
    scale_value = float(scale_value)
    # NaN compares False against everything, so it would slip past the
    # positivity check and yield a "nan" cache tag; reject non-finite values.
    if not np.isfinite(scale_value):
        raise ValueError(f"scale values must be finite, got {scale_value}")
    if scale_value <= 0.0:
        raise ValueError(f"scale values must be positive, got {scale_value}")
    return f"{scale_value:g}"
def resolve_grasp_cache_file(grasp_cache_path: str, scale_value: float) -> Path:
    """Map a configured grasp-cache location to the file for one object scale.

    Three forms of ``grasp_cache_path`` are supported:

    * a template containing ``{scale}``, which is substituted with the scale tag;
    * a path ending in ``.npy``, whose stem gets ``_<tag>`` appended;
    * any other prefix, which becomes ``<prefix>_<tag>.npy``.

    Args:
        grasp_cache_path: Configured cache prefix or template path.
        scale_value: Single object scale value for this cache file.

    Returns:
        Cache path for that exact scale.
    """
    tag = format_scale_tag(scale_value)
    if "{scale}" not in grasp_cache_path:
        prefix = Path(grasp_cache_path)
        if prefix.suffix != ".npy":
            return Path(f"{grasp_cache_path}_{tag}.npy")
        return prefix.with_name(f"{prefix.stem}_{tag}{prefix.suffix}")
    return Path(grasp_cache_path.format(scale=tag))
def sample_scale_grasp_caches(
    grasp_caches: Sequence[np.ndarray],
    scale_ids: np.ndarray,
) -> np.ndarray:
    """Sample one cached grasp per reset environment from per-scale cache files.

    Each environment draws a uniformly random row (via the global
    ``np.random`` stream) from the cache of its scale bucket. Only the first 29
    columns of each cache row are used.

    Args:
        grasp_caches: Cache arrays ordered the same way as env scale ids; each
            must be 2-D with at least one row and at least 29 columns.
        scale_ids: Scale-bucket assignment for each reset environment; values
            must lie in ``[0, len(grasp_caches))``.

    Returns:
        Cached grasp states with shape ``(num_envs, 29)``.

    Raises:
        ValueError: If no caches are given, a cache is malformed or empty, or a
            scale id is out of range.
    """
    # NOTE(review): 29 presumably = 22 hand joints + 7-D object pose (matches
    # the default ``num_hand_dofs``) — confirm against the cache generator.
    state_dim = 29
    scale_ids = np.asarray(scale_ids)
    num_envs = scale_ids.shape[0]
    num_scales = len(grasp_caches)
    if num_scales <= 0:
        raise ValueError("grasp_caches must contain at least one scale bucket")
    # Out-of-range ids would otherwise match no bucket and silently keep an
    # all-zero grasp state (including a zero quaternion).
    if num_envs > 0 and (scale_ids.min() < 0 or scale_ids.max() >= num_scales):
        raise ValueError(
            f"scale_ids must lie in [0, {num_scales}), got values in "
            f"[{scale_ids.min()}, {scale_ids.max()}]"
        )
    sampled = np.zeros((num_envs, state_dim), dtype=np.float64)
    for scale_idx, grasp_cache in enumerate(grasp_caches):
        if grasp_cache.ndim != 2 or grasp_cache.shape[1] < state_dim:
            raise ValueError(f"Expected cached grasp shape (?, 29), got {grasp_cache.shape}")
        if grasp_cache.shape[0] == 0:
            raise ValueError(f"grasp cache for scale id {scale_idx} is empty")
        env_ids = np.flatnonzero(scale_ids == scale_idx)
        if len(env_ids) == 0:
            continue
        sample_ids = np.random.randint(0, grasp_cache.shape[0], size=len(env_ids))
        # Caches wider than 29 columns are accepted above, so keep only the
        # leading state columns to avoid a broadcast error on assignment.
        sampled[env_ids] = grasp_cache[sample_ids, :state_dim]
    return sampled
def repeat_obs_history(init_frame: np.ndarray, history_len: int) -> np.ndarray:
    """Build a history buffer by repeating one frame along a new time axis.

    Args:
        init_frame: Per-environment frame, indexed as ``(num_envs, frame_dim)``.
        history_len: Number of copies to stack along axis 1.

    Returns:
        A new, writable array of shape ``(num_envs, history_len, frame_dim)``
        with the same dtype as ``init_frame``.
    """
    num_envs, frame_dim = init_frame.shape[0], init_frame.shape[1]
    target_shape = (num_envs, history_len, frame_dim)
    # broadcast_to returns a read-only view; copy so callers can write into it.
    tiled = np.broadcast_to(np.expand_dims(init_frame, axis=1), target_shape)
    return tiled.copy()
class SharpaInhandBaseEnv(NpEnv):
    """Shared NumPy environment base for Sharpa in-hand manipulation tasks.

    ``__init__`` resolves the following from the backend:

    * actuator limits, initial qpos/qvel, and object/fingertip body ids;
    * the history buffers;
    * per-env object-scale buckets.

    The other methods are building blocks for delta-position action targets,
    tactile observations, joint normalization, and PD-gain randomization.
    NOTE(review): reset/step/observation assembly is not visible in this class
    body; presumably subclasses provide it. Confirm.
    """

    _cfg: SharpaInhandBaseCfg
    # Baseline per-DOF PD gains read by ``_resolve_pd_gains``. They are not
    # assigned in ``__init__``; presumably a subclass or reset path fills them
    # (e.g. from ``_load_default_pd_gains``) before first use — confirm.
    _default_p_gain: np.ndarray
    _default_d_gain: np.ndarray

    def __init__(self, cfg: SharpaInhandBaseCfg, backend: SimBackend, num_envs: int = 1) -> None:
        """Resolve model-dependent constants and allocate per-env buffers.

        Args:
            cfg: Task configuration.
            backend: Simulation backend providing actuator, body and qpos queries.
            num_envs: Number of vectorized environments.

        Raises:
            ValueError: If the model has too few actuators or qpos entries for
                the configured hand, if the source default pose has the wrong
                size, or if scale / limit configuration is invalid.
        """
        super().__init__(cfg, backend, num_envs)
        self._np_dtype = get_global_dtype()
        self._num_action = int(cfg.num_hand_dofs)
        actuator_range = np.asarray(self._backend.get_actuator_ctrl_range(), dtype=self._np_dtype)
        if actuator_range.shape[0] < self._num_action:
            raise ValueError(
                f"Model has {actuator_range.shape[0]} actuators, but Sharpa task needs {self._num_action}"
            )
        # Keep the raw XML actuator limits for observation normalization and any
        # logic that needs the original backend contract. Target-position clipping
        # uses a separate scaled limit pair to match Sharpa source behavior.
        self._ctrl_lower = np.asarray(actuator_range[: self._num_action, 0], dtype=self._np_dtype)
        self._ctrl_upper = np.asarray(actuator_range[: self._num_action, 1], dtype=self._np_dtype)
        self._target_lower, self._target_upper = self._resolve_target_joint_limits(
            self._ctrl_lower,
            self._ctrl_upper,
            cfg.control_config.dof_limits_scale,
        )
        self._init_qpos = self._resolve_init_qpos()
        self._init_qvel = np.asarray(self._backend.get_init_qvel(), dtype=np.float64)
        self.nq = int(self._init_qpos.shape[0])
        self.nv = int(self._init_qvel.shape[0])
        if self.nq < self._num_action + 7:
            raise ValueError(
                f"Model qpos dim {self.nq} is too small for {self._num_action} hand DoFs + object pose"
            )
        # qpos layout: hand joints first, then object position (3) and quaternion (4).
        self._obj_pos_slice = slice(self._num_action, self._num_action + 3)
        self._obj_quat_slice = slice(self._num_action + 3, self._num_action + 7)
        source_default_angles = np.deg2rad(
            np.asarray(SOURCE_DEFAULT_HAND_JOINT_POS_DEG, dtype=np.float64)
        )
        if source_default_angles.shape[0] != self._num_action:
            raise ValueError(
                "Source default hand joint pose size mismatch: "
                f"{source_default_angles.shape[0]} vs expected {self._num_action}"
            )
        self.default_angles = np.asarray(source_default_angles, dtype=self._np_dtype)
        self._action_space = gym.spaces.Box(
            low=-float(cfg.clip_actions),
            high=float(cfg.clip_actions),
            shape=(self._num_action,),
            dtype=np.float32,
        )
        self._object_body_ids = self._backend.get_body_ids([cfg.object_body_name])
        self._fingertip_body_ids = self._backend.get_body_ids(cfg.fingertip_body_names)
        self._object_geom_base_size = self._resolve_object_geom_base_size()
        # One tactile channel per fingertip body.
        self._num_tactile = len(cfg.fingertip_body_names)
        self.last_contacts = np.zeros((num_envs, self._num_tactile), dtype=self._np_dtype)
        self._prev_tactile_force = np.zeros((num_envs, self._num_tactile), dtype=self._np_dtype)
        self.object_default_pose = np.zeros((num_envs, 7), dtype=self._np_dtype)
        self.obs_buf_lag_history = np.zeros(
            (num_envs, cfg.obs_history_len, cfg.frame_obs_dim), dtype=self._np_dtype
        )
        self.proprio_hist_buf = np.zeros(
            (num_envs, cfg.prop_hist_len, cfg.frame_obs_dim), dtype=self._np_dtype
        )
        self.critic_info_buf = np.zeros((num_envs, cfg.critic_info_dim), dtype=self._np_dtype)
        self.scale_ids, self._num_scales, self._bucket_env = self._build_scale_ids(
            num_envs, cfg.domain_rand.scale_list
        )
        self.scale_values = self._build_scale_values(cfg.domain_rand.scale_list)

    @property
    def action_space(self) -> gym.spaces.Box:
        """Box action space of shape ``(num_action,)`` bounded by ``clip_actions``."""
        return self._action_space  # type: ignore[no-any-return]

    def _resolve_target_joint_limits(
        self,
        raw_lower: np.ndarray,
        raw_upper: np.ndarray,
        scale: float,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Build scaled target-position clipping limits from raw XML actuator bounds.

        Args:
            raw_lower: Original lower control limits loaded from the backend XML.
            raw_upper: Original upper control limits loaded from the backend XML.
            scale: Multiplicative scale applied to both bound arrays.

        Returns:
            Tuple of scaled ``(lower, upper)`` target-position limits.

        Raises:
            ValueError: If ``scale`` is not positive or the scaled limits are inverted.
        """
        scale_value = float(scale)
        if scale_value <= 0.0:
            raise ValueError(f"dof_limits_scale must be positive, got {scale_value}")
        target_lower = np.asarray(raw_lower, dtype=self._np_dtype) * scale_value
        target_upper = np.asarray(raw_upper, dtype=self._np_dtype) * scale_value
        if np.any(target_lower > target_upper):
            raise ValueError("Scaled Sharpa target joint limits are invalid")
        return target_lower.astype(self._np_dtype, copy=False), target_upper.astype(
            self._np_dtype,
            copy=False,
        )

    def _resolve_init_qpos(self) -> np.ndarray:
        """Return the initial qpos from the first available keyframe or the model default.

        Tries the keyframes ``home``, ``stand`` and ``default`` in that order, then
        falls back to ``get_default_qpos``.

        Raises:
            ValueError: If no keyframe exists and the backend has no default qpos.
        """
        for key_name in ("home", "stand", "default"):
            try:
                return np.asarray(self._backend.get_keyframe_qpos(key_name), dtype=np.float64)
            except Exception:
                # Best-effort probe: the exception type raised for a missing
                # keyframe is backend-specific, so any failure means "try next".
                continue
        try:
            return np.asarray(self._backend.get_default_qpos(), dtype=np.float64)
        except NotImplementedError as exc:
            raise ValueError("Could not resolve initial qpos from backend contract") from exc

    def _build_scale_ids(
        self, num_envs: int, scale_list: Sequence[float]
    ) -> tuple[np.ndarray, int, int]:
        """Build deterministic near-even environment assignments for each scale.

        Environments are assigned in contiguous runs (``0, 0, ..., 1, 1, ...``).

        Args:
            num_envs: Number of vectorized environments to assign.
            scale_list: Explicit object scale values used by this env instance.

        Returns:
            Tuple of scale id per environment, total scale count, and the minimum
            number of environments assigned to any scale.

        Raises:
            ValueError: If ``scale_list`` is empty or contains non-positive values.
        """
        if len(scale_list) == 0:
            raise ValueError("scale_list must contain at least one scale")
        scale_values = np.asarray(scale_list, dtype=np.float64)
        if np.any(scale_values <= 0.0):
            raise ValueError(f"scale_list values must be positive, got {list(scale_list)}")
        num_scales = int(scale_values.shape[0])
        bucket_env = num_envs // num_scales
        remainder = num_envs % num_scales
        # Assign the remainder to the lowest scale ids to keep bucket sizes within one env.
        counts = np.full((num_scales,), bucket_env, dtype=np.int32)
        counts[:remainder] += 1
        scale_ids = np.repeat(np.arange(num_scales, dtype=np.int32), counts)
        return scale_ids, num_scales, bucket_env

    def _build_scale_values(self, scale_list: Sequence[float]) -> np.ndarray:
        """Normalize configured scale values into a stable numpy array.

        Args:
            scale_list: Explicit list of object scales.

        Returns:
            Array of configured scale values in config order.

        Raises:
            ValueError: If the list is empty, not flat, or has non-positive values.
        """
        scale_values = np.asarray(scale_list, dtype=np.float64)
        if scale_values.ndim != 1 or scale_values.size == 0:
            raise ValueError(f"scale_list must be a non-empty flat list, got {list(scale_list)}")
        if np.any(scale_values <= 0.0):
            raise ValueError(f"scale_list values must be positive, got {list(scale_list)}")
        return scale_values

    def _resolve_object_geom_base_size(self) -> np.ndarray | None:
        """Return the object's geom size, or ``None`` if the backend cannot report it."""
        try:
            return cast(np.ndarray, self._backend.get_geom_size(self._cfg.object_geom_name))
        except NotImplementedError:
            return None

    def apply_action(self, actions: np.ndarray, state: NpEnvState) -> np.ndarray:
        """Integrate clipped delta actions into clipped joint-position targets.

        ``targets = clip(prev_targets + action_scale * clip(actions))``. On the
        first call ``prev_targets`` starts from the source default hand pose.

        Args:
            actions: Policy actions, indexed as ``(num_envs, >= num_action)``;
                extra trailing columns are ignored.
            state: Env state whose ``info`` carries ``current_actions``,
                ``last_actions`` and ``prev_targets`` between calls.

        Returns:
            New position targets of shape ``(num_envs, num_action)``, also stored
            in ``state.info["prev_targets"]``.
        """
        clip = self._cfg.clip_actions
        clipped_actions = np.clip(actions, -clip, clip)
        clipped_actions = np.asarray(clipped_actions[:, : self._num_action], dtype=self._np_dtype)
        info = state.info
        # Build defaults only when the key is missing instead of on every step.
        # On the first call there is no previous action, so "last" mirrors "current".
        if "current_actions" in info:
            info["last_actions"] = info["current_actions"]
        else:
            info["last_actions"] = clipped_actions.copy()
        info["current_actions"] = clipped_actions
        if "prev_targets" in info:
            prev_targets = info["prev_targets"]
        else:
            prev_targets = np.broadcast_to(
                self.default_angles, (self._num_envs, self._num_action)
            ).copy()
        targets = prev_targets + self._cfg.control_config.action_scale * clipped_actions
        # Clip action targets by the scaled control range only. Observation
        # normalization continues to use the raw XML actuator limits.
        targets = np.clip(targets, self._target_lower, self._target_upper)
        new_targets = np.asarray(targets, dtype=self._np_dtype)
        info["prev_targets"] = new_targets
        return new_targets

    def get_hand_dof_pos(self) -> np.ndarray:
        """Return the first ``num_action`` dof positions (the hand joints) per env."""
        return np.asarray(self._backend.get_dof_pos()[:, : self._num_action], dtype=self._np_dtype)

    def get_hand_dof_vel(self) -> np.ndarray:
        """Return the first ``num_action`` dof velocities (the hand joints) per env."""
        return np.asarray(self._backend.get_dof_vel()[:, : self._num_action], dtype=self._np_dtype)

    def get_fingertip_pos(self) -> np.ndarray:
        """Return world positions of the configured fingertip bodies.

        Presumably shaped ``(num_envs, num_fingertips, 3)`` — confirm with the backend.
        """
        return np.asarray(
            self._backend.get_body_pos_w(self._fingertip_body_ids), dtype=self._np_dtype
        )

    def get_object_pos(self) -> np.ndarray:
        """Return the object body's world position, one row per env."""
        return np.asarray(
            self._backend.get_body_pos_w(self._object_body_ids)[:, 0, :], dtype=self._np_dtype
        )

    def get_object_quat(self) -> np.ndarray:
        """Return the object body's world orientation quaternion, one row per env."""
        return np.asarray(
            self._backend.get_body_quat_w(self._object_body_ids)[:, 0, :], dtype=self._np_dtype
        )

    def _extract_sensor_scalar(self, sensor_name: str) -> np.ndarray:
        """Reduce one backend sensor reading to a single scalar per env.

        1-D data is returned as is; ``(N, 1)`` data yields its only column;
        ``(N, >=3)`` data yields the L2 norm of the first three components; any
        other layout yields the first flattened element of each row.
        """
        data = np.asarray(self._backend.get_sensor_data(sensor_name), dtype=self._np_dtype)
        if data.ndim == 1:
            return data
        if data.ndim == 2 and data.shape[1] == 1:
            return data[:, 0]
        if data.ndim == 2 and data.shape[1] >= 3:
            return np.asarray(np.linalg.norm(data[:, :3], axis=1), dtype=self._np_dtype)
        flat = data.reshape(data.shape[0], -1)
        return np.asarray(flat[:, 0], dtype=self._np_dtype)

    def _read_tactile_force(self) -> np.ndarray:
        """Read per-finger tactile force magnitudes in configured sensor order.

        A sensor that cannot be read contributes zeros for its channel. A
        ``RuntimeWarning`` is emitted so the misconfiguration is not silent.

        Returns:
            Array of shape ``(num_envs, num_tactile)`` ordered exactly as
            ``sensor.tactile_force_sensor_names``.
        """
        tactile_force = np.zeros((self._num_envs, self._num_tactile), dtype=self._np_dtype)
        if not self._cfg.sensor.tactile_force_sensor_names:
            return tactile_force
        for sensor_id, sensor_name in enumerate(
            self._cfg.sensor.tactile_force_sensor_names[: self._num_tactile]
        ):
            try:
                tactile_force[:, sensor_id] = self._extract_sensor_scalar(sensor_name)
            except Exception as exc:
                # Deliberate best-effort fallback, but surface it. The message
                # text is stable per sensor, so the default warning filter
                # reports each failing sensor once instead of every step.
                warnings.warn(
                    f"Sharpa tactile sensor {sensor_name!r} could not be read "
                    f"({type(exc).__name__}); using zero force for this channel.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                tactile_force[:, sensor_id] = 0.0
        return tactile_force

    def _clip_tactile_force(self, tactile_force: np.ndarray) -> np.ndarray:
        """Clip raw tactile-force magnitudes before they enter observation smoothing.

        Args:
            tactile_force: Raw per-finger tactile magnitudes with shape
                ``(num_envs, num_tactile)``.

        Returns:
            Clipped tactile-force array with the same shape. Non-positive clip
            values disable this clamp so the caller can opt out explicitly.
        """
        clip_max = float(self._cfg.obs.tactile_force_clip_max)
        tactile_force = np.asarray(tactile_force, dtype=self._np_dtype)
        if clip_max <= 0.0:
            return tactile_force
        return np.asarray(np.clip(tactile_force, 0.0, clip_max), dtype=self._np_dtype)

    def _clear_tactile_history(self, env_ids: np.ndarray | None = None) -> None:
        """Clear tactile-output and raw-force history buffers.

        Args:
            env_ids: Optional environment ids to clear. When ``None``, clear all envs.

        Returns:
            None. Buffers are updated in place.
        """
        if env_ids is None:
            self.last_contacts.fill(0.0)
            self._prev_tactile_force.fill(0.0)
            return
        self.last_contacts[env_ids] = 0.0
        self._prev_tactile_force[env_ids] = 0.0

    def _compute_tactile_observation(self) -> np.ndarray:
        """Build tactile observations with source-equivalent smoothing and latency.

        Steps: clip raw force, EMA-smooth against the previous raw force, zero
        disabled channels, then per channel keep the previous output with
        probability ``contact_latency``. In binary mode the smoothed force is
        thresholded, and active contacts are dropped with probability
        ``contact_sensor_noise``.

        Returns:
            Tactile-force observation array with shape ``(num_envs, num_tactile)``.
        """
        obs_cfg = self._cfg.obs
        domain_rand = self._cfg.domain_rand
        if not obs_cfg.enable_tactile:
            self._clear_tactile_history()
            return np.zeros((self._num_envs, self._num_tactile), dtype=self._np_dtype)
        # Explicitly calls the base-class clip, bypassing any subclass override.
        # NOTE(review): presumably intentional for source parity — confirm.
        current_force = SharpaInhandBaseEnv._clip_tactile_force(self, self._read_tactile_force())
        smooth_contact = (
            current_force * obs_cfg.contact_smooth
            + self._prev_tactile_force * (1.0 - obs_cfg.contact_smooth)
        ).astype(self._np_dtype)
        self._prev_tactile_force[:] = current_force
        for disabled_id in self._cfg.disable_tactile_ids:
            if 0 <= disabled_id < self._num_tactile:
                smooth_contact[:, disabled_id] = 0.0
        # 1.0 where this step's reading is held back (previous output kept).
        latency = np.where(
            np.random.rand(self._num_envs, self._num_tactile) < domain_rand.contact_latency,
            1.0,
            0.0,
        ).astype(self._np_dtype)
        if obs_cfg.binary_contact:
            binary_contact = (smooth_contact > obs_cfg.contact_threshold).astype(self._np_dtype)
            self.last_contacts = self.last_contacts * latency + binary_contact * (1.0 - latency)
            noise_mask = (
                np.random.rand(self._num_envs, self._num_tactile)
                >= domain_rand.contact_sensor_noise
            ).astype(self._np_dtype)
            return np.where(
                self.last_contacts > 0.1,
                noise_mask * self.last_contacts,
                self.last_contacts,
            )
        self.last_contacts = self.last_contacts * latency + smooth_contact * (1.0 - latency)
        return self.last_contacts.copy()

    def _compute_contact_positions(self, tactile: np.ndarray) -> np.ndarray:
        """Return per-finger contact positions (currently always zeros, ``num_tactile * 3`` wide)."""
        del tactile
        # TODO(sharpa_inhand): IsaacLab contact positions are defined through contact sensor
        # frame transforms. Backend-level contact point parity is not available yet in the
        # current UniLab sensor contract, so we keep this channel as zeros for now.
        return np.zeros((self._num_envs, self._num_tactile * 3), dtype=self._np_dtype)

    def _normalize_joint_pos(self, dof_pos: np.ndarray) -> np.ndarray:
        """Map joint positions from the raw actuator range ``[lower, upper]`` to ``[-1, 1]``.

        A small epsilon in the denominator guards against zero-width ranges.
        """
        return np.asarray(
            (2.0 * dof_pos - self._ctrl_upper - self._ctrl_lower)
            / (self._ctrl_upper - self._ctrl_lower + 1.0e-8),
            dtype=self._np_dtype,
        )

    def _sample_pd_scales(self, lower: float, upper: float, shape: tuple[int, int]) -> np.ndarray:
        """Sample PD gain multipliers around 1.

        Each entry is drawn with equal probability from either ``U(lower, 1)``
        or ``U(1, upper)``.

        Raises:
            ValueError: Unless ``lower <= 1 <= upper``.
        """
        if lower > 1.0 or upper < 1.0:
            raise ValueError("PD randomization scales must satisfy lower <= 1 <= upper")
        small = np.random.uniform(lower, 1.0, size=shape)
        large = np.random.uniform(1.0, upper, size=shape)
        use_small = np.random.rand(*shape) > 0.5
        return np.where(use_small, small, large).astype(self._np_dtype)

    def _load_default_pd_gains(self) -> tuple[np.ndarray, np.ndarray]:
        """Resolve the default per-DOF PD gains used as the randomization baseline.

        Uses the backend's actuator gains when available, otherwise the scalar
        ``control_config.p_gain`` / ``d_gain`` fallbacks.

        Returns:
            Tuple of ``(p_gain, d_gain)`` arrays with shape ``(num_action,)``.
        """
        try:
            p_gain, d_gain = self._backend.get_actuator_gains()
            return (
                np.asarray(p_gain[: self._num_action], dtype=self._np_dtype).copy(),
                np.asarray(d_gain[: self._num_action], dtype=self._np_dtype).copy(),
            )
        except NotImplementedError:
            return (
                np.full((self._num_action,), self._cfg.control_config.p_gain, dtype=self._np_dtype),
                np.full((self._num_action,), self._cfg.control_config.d_gain, dtype=self._np_dtype),
            )

    def _resolve_pd_gains(self, info: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]:
        """Return per-env ``(p_gain, d_gain)`` from ``info``, defaulting to the baseline gains."""
        # Build the broadcast defaults only when the key is absent.
        if "p_gain" in info:
            p_gain = info["p_gain"]
        else:
            p_gain = np.broadcast_to(
                self._default_p_gain, (self._num_envs, self._num_action)
            ).copy()
        if "d_gain" in info:
            d_gain = info["d_gain"]
        else:
            d_gain = np.broadcast_to(
                self._default_d_gain, (self._num_envs, self._num_action)
            ).copy()
        return np.asarray(p_gain, dtype=self._np_dtype), np.asarray(d_gain, dtype=self._np_dtype)

    def _update_proprio_history(self, obs_history: np.ndarray) -> np.ndarray:
        """Return the most recent ``prop_hist_len`` frames of an observation history.

        Args:
            obs_history: History array indexed as ``(num_envs, time, ...)``. The
                newest frame is assumed to be last along axis 1; confirm.

        Returns:
            The trailing ``prop_hist_len`` frames (all frames when the history
            is shorter), cast to the env dtype.
        """
        hist_len = int(self._cfg.prop_hist_len)
        # ``obs_history[:, -0:]`` would return the whole history, so compute the
        # start index explicitly; ``prop_hist_len == 0`` then yields an empty slice.
        start = max(obs_history.shape[1] - hist_len, 0)
        return np.asarray(obs_history[:, start:], dtype=self._np_dtype)