Source code for unilab.envs.manipulation.sharpa_inhand.rotation

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, cast

import numpy as np

from unilab.assets.hub import resolve_grasp_cache_files
from unilab.base import registry
from unilab.base.backend import create_backend
from unilab.base.np_env import NpEnvState
from unilab.dr import (
    DomainRandomizationCapabilities,
    DomainRandomizationProvider,
    GeomSizeOverride,
    InitRandomizationPlan,
    IntervalRandomizationPlan,
    ModelVariantSpec,
    ResetPlan,
)
from unilab.dr.dr_utils import build_common_reset_randomization, validate_common_reset_randomization
from unilab.dr.types import (
    RESET_TERM_BODY_IPOS,
    RESET_TERM_BODY_MASS,
    RESET_TERM_GEOM_FRICTION,
    RESET_TERM_GRAVITY,
    RESET_TERM_KD,
    RESET_TERM_KP,
    ResetRandomizationPayload,
)
from unilab.dtype_config import get_global_dtype
from unilab.envs.common.rotation import (
    np_quat_apply,
    np_quat_conjugate,
    np_quat_mul,
    np_quat_to_axis_angle,
)
from unilab.envs.manipulation.sharpa_inhand.base import (
    SharpaInhandBaseCfg,
    SharpaInhandBaseEnv,
    repeat_obs_history,
    resolve_grasp_cache_file,
    sample_scale_grasp_caches,
)


[docs] @dataclass class RewardConfig: scales: dict[str, float] = field( default_factory=lambda: { "rotate": 2.5, "obj_linvel": -0.3, "pose_diff": -0.4, "torque": -0.1, "work": -0.5, "object_pos": 0.003, } ) angvel_clip_min: float = -0.5 angvel_clip_max: float = 0.5
[docs] @registry.envcfg("SharpaInhandRotation") @dataclass class SharpaInhandRotationCfg(SharpaInhandBaseCfg): critic_info_dim: int = 9 reward_config: RewardConfig | None = None zero_action_test_mode: bool = False
[docs] def sample_random_quaternion(num_envs: int) -> np.ndarray: """Sample uniformly distributed random quaternions in wxyz convention. Args: num_envs: Number of quaternions to sample. Returns: Quaternion array with shape ``(num_envs, 4)``. """ u1 = np.random.rand(num_envs) u2 = np.random.rand(num_envs) * 2.0 * np.pi u3 = np.random.rand(num_envs) * 2.0 * np.pi q1 = np.sqrt(1.0 - u1) * np.sin(u2) q2 = np.sqrt(1.0 - u1) * np.cos(u2) q3 = np.sqrt(u1) * np.sin(u3) q4 = np.sqrt(u1) * np.cos(u3) return np.stack([q4, q1, q2, q3], axis=1).astype(np.float64)
[docs] class SharpaInhandRotationDRProvider(DomainRandomizationProvider):
[docs] def validate(self, env: Any, capabilities: DomainRandomizationCapabilities) -> None: unsupported = validate_common_reset_randomization(env, capabilities) domain_rand = getattr(env.cfg, "domain_rand", None) if domain_rand is not None and getattr(domain_rand, "randomize_gravity_direction", False): if getattr(domain_rand, "randomize_gravity", False): raise ValueError( "Use only one Sharpa gravity randomization mode: " "domain_rand.randomize_gravity_direction or domain_rand.randomize_gravity" ) if not capabilities.supports_reset_term(RESET_TERM_GRAVITY): unsupported = unsupported | frozenset({RESET_TERM_GRAVITY}) if ( domain_rand is not None and domain_rand.force_scale > 0.0 and not capabilities.supports_interval_body_force ): raise NotImplementedError( f"{env._backend.backend_type} backend does not support interval body force perturbation" ) if domain_rand is not None and domain_rand.randomize_pd_gains: if not capabilities.supports_reset_term(RESET_TERM_KP): unsupported = unsupported | frozenset({RESET_TERM_KP}) if not capabilities.supports_reset_term(RESET_TERM_KD): unsupported = unsupported | frozenset({RESET_TERM_KD}) if ( domain_rand is not None and domain_rand.randomize_com and not capabilities.supports_reset_term(RESET_TERM_BODY_IPOS) ): unsupported = unsupported | frozenset({RESET_TERM_BODY_IPOS}) if ( domain_rand is not None and domain_rand.randomize_mass and not capabilities.supports_reset_term(RESET_TERM_BODY_MASS) ): unsupported = unsupported | frozenset({RESET_TERM_BODY_MASS}) if ( domain_rand is not None and domain_rand.randomize_friction and not capabilities.supports_reset_term(RESET_TERM_GEOM_FRICTION) ): unsupported = unsupported | frozenset({RESET_TERM_GEOM_FRICTION}) if unsupported: names = ", ".join(sorted(unsupported)) raise NotImplementedError( f"{env._backend.backend_type} backend does not support reset randomization terms: {names}" )
[docs] def build_init_randomization_plan(self, env: Any) -> InitRandomizationPlan | None: base_size = getattr(env, "_object_geom_base_size", None) if base_size is None: return None model_variants = tuple( ModelVariantSpec( geom_size_overrides=( GeomSizeOverride( geom_name=env.cfg.object_geom_name, size=tuple(np.asarray(base_size * scale, dtype=np.float64)), ), ) ) for scale in np.asarray(env.scale_values, dtype=np.float64) ) return InitRandomizationPlan( model_assignments=np.asarray(env.scale_ids, dtype=np.int32).copy(), model_variants=model_variants, )
def _load_grasp_cache(self, env: Any) -> tuple[np.ndarray, ...]: """Load one grasp cache file for each configured object scale. Args: env: Sharpa rotation env instance. Returns: Tuple of cache arrays ordered the same as ``env.scale_values``. """ if getattr(env, "_grasp_cache", None) is not None: return cast(tuple[np.ndarray, ...], env._grasp_cache) grasp_caches: list[np.ndarray] = [] missing_files: list[str] = [] for scale_value in np.asarray(env.scale_values, dtype=np.float64): cache_file = resolve_grasp_cache_file(env.cfg.grasp_cache_path, float(scale_value)) # Auto-download from HF if the cache file is missing locally. resolved = cast(str, resolve_grasp_cache_files(str(cache_file))) cache_file = Path(resolved) if not cache_file.exists(): missing_files.append(str(cache_file)) continue grasp_caches.append(np.load(cache_file).astype(np.float64)) if missing_files: missing = ", ".join(missing_files) raise RuntimeError( f"Missing Sharpa grasp cache file(s): {missing}\n" "Generate them with:\n" " bash scripts/sharpa_collect_grasps.sh 0.8 0.9 1.0 1.1 1.2 1.3 1.4 1.5 1.6\n" "See docs/sphinx/source/zh_CN/user_guide/D-tasks/04-sharpa-inhand.md" ) env._grasp_cache = tuple(grasp_caches) return cast(tuple[np.ndarray, ...], env._grasp_cache) def _sample_reset_pd_gains( self, env: Any, num_reset: int, *, dtype: np.dtype[Any], ) -> tuple[np.ndarray, np.ndarray]: """Sample absolute reset-time PD gains from split-around-1 scale ranges. Args: env: Sharpa rotation env instance. num_reset: Number of environments being reset. dtype: Output dtype for the returned gain arrays. Returns: Tuple of absolute ``(p_gain, d_gain)`` arrays with shape ``(num_reset, env._num_action)``. """ p_gain = np.broadcast_to(env._default_p_gain, (num_reset, env._num_action)).astype( dtype, copy=True ) d_gain = np.broadcast_to(env._default_d_gain, (num_reset, env._num_action)).astype( dtype, copy=True ) domain_rand = env.cfg.domain_rand if domain_rand.randomize_pd_gains: p_scale = env._sample_pd_scales( domain_rand.randomize_p_gain_scale_lower, domain_rand.randomize_p_gain_scale_upper, shape=(num_reset, env._num_action), ) d_scale = env._sample_pd_scales( domain_rand.randomize_d_gain_scale_lower, domain_rand.randomize_d_gain_scale_upper, shape=(num_reset, env._num_action), ) p_gain *= p_scale d_gain *= d_scale return p_gain, d_gain def _build_info_updates( self, env: Any, env_ids: np.ndarray, hand_qpos: np.ndarray, object_pos: np.ndarray, object_quat: np.ndarray, reset_height_lower: np.ndarray, reset_height_upper: np.ndarray, rot_axis: np.ndarray, p_gain: np.ndarray, d_gain: np.ndarray, friction_scale: np.ndarray | None, randomized_mass: np.ndarray | None, randomized_com_offset: np.ndarray | None, gravity: np.ndarray | None, ) -> dict[str, np.ndarray]: num_reset = hand_qpos.shape[0] dtype = get_global_dtype() critic_info = env._build_reset_critic_info( num_reset, env_ids, friction_scale=friction_scale, randomized_mass=randomized_mass, randomized_com_offset=randomized_com_offset, gravity=gravity, ).astype(dtype) tactile = np.zeros((num_reset, env._num_tactile), dtype=dtype) contact_pos = np.zeros((num_reset, env._num_tactile * 3), dtype=dtype) hand_qpos_f = hand_qpos.astype(dtype) targets = hand_qpos_f.copy() object_pos_f = object_pos.astype(dtype) init_frame = env._build_policy_frame( dof_pos=hand_qpos_f, targets=targets, tactile=tactile, contact_pos=contact_pos, add_noise=True, ) critic_init_frame = env._build_policy_frame( dof_pos=hand_qpos_f, targets=targets, tactile=tactile, contact_pos=contact_pos, add_noise=False, ) obs_lag_history = repeat_obs_history(init_frame, env.cfg.obs_history_len).astype(dtype) critic_obs_lag_history = repeat_obs_history( critic_init_frame, env.cfg.obs_history_len ).astype(dtype) object_default_pose = np.concatenate( [object_pos_f, object_quat.astype(dtype)], axis=1 ).astype(dtype) object_pos_anchor = env._build_object_pos_anchor(object_pos_f).astype(dtype) critic_info = env._fill_critic_info( critic_info=critic_info, object_pos=object_pos_f, object_pos_anchor=object_pos_anchor, ) info_updates = { "current_actions": np.zeros((num_reset, env._num_action), dtype=dtype), "last_actions": np.zeros((num_reset, env._num_action), dtype=dtype), "prev_targets": hand_qpos_f.copy(), "init_pose": hand_qpos_f.copy(), "prev_hand_pos": hand_qpos_f.copy(), "prev_object_pos": object_pos.astype(dtype).copy(), "prev_object_quat": object_quat.astype(dtype).copy(), "object_default_pose": object_default_pose, "object_pos_anchor": object_pos_anchor, "reset_height_lower": reset_height_lower.astype(dtype), "reset_height_upper": reset_height_upper.astype(dtype), "rot_axis": rot_axis.astype(dtype), "p_gain": p_gain, "d_gain": d_gain, "critic_info": critic_info, "obs_lag_history": obs_lag_history, "critic_obs_lag_history": critic_obs_lag_history, "proprio_hist": env._update_proprio_history(obs_lag_history), } return info_updates
[docs] def build_reset_plan(self, env: Any, env_ids: np.ndarray) -> ResetPlan: num_reset = len(env_ids) if num_reset == 0: return ResetPlan( env_ids=env_ids, qpos=np.zeros((0, env.nq), dtype=np.float64), qvel=np.zeros((0, env.nv), dtype=np.float64), info_updates={}, randomization=None, ) friction_scale = env._sample_friction_scale(num_reset) randomized_mass = env._sample_object_mass(num_reset) randomized_com_offset = env._sample_object_com_offset(num_reset) gravity = env._sample_reset_gravity(num_reset) p_gain, d_gain = self._sample_reset_pd_gains(env, num_reset, dtype=get_global_dtype()) grasp_cache = self._load_grasp_cache(env) sampled_pose = sample_scale_grasp_caches(grasp_cache, env.scale_ids[env_ids]) hand_qpos = sampled_pose[:, : env._num_action] object_pos = sampled_pose[:, env._num_action : env._num_action + 3] object_quat = sampled_pose[:, env._num_action + 3 : env._num_action + 7] rot_axis = np.broadcast_to(env._rot_axis, (num_reset, 3)).copy().astype(np.float64) qpos = np.zeros((num_reset, env.nq), dtype=np.float64) qpos[:, : env._num_action] = hand_qpos qpos[:, env._obj_pos_slice] = object_pos qpos[:, env._obj_quat_slice] = object_quat qvel = np.zeros((num_reset, env.nv), dtype=np.float64) height_range = env.cfg.reset_height_upper - env.cfg.reset_height_lower reset_height_lower = object_pos[:, 2] - 0.5 * height_range reset_height_upper = object_pos[:, 2] + 0.5 * height_range info_updates = self._build_info_updates( env, env_ids=env_ids, hand_qpos=hand_qpos, object_pos=object_pos, object_quat=object_quat, reset_height_lower=reset_height_lower, reset_height_upper=reset_height_upper, rot_axis=rot_axis, p_gain=p_gain, d_gain=d_gain, friction_scale=friction_scale, randomized_mass=randomized_mass, randomized_com_offset=randomized_com_offset, gravity=gravity, ) # Match the source task by clearing any cached external object force on reset. env._random_object_force[env_ids] = 0.0 env._clear_tactile_history(env_ids) return ResetPlan( env_ids=env_ids, qpos=qpos, qvel=qvel, info_updates=info_updates, randomization=env._build_reset_randomization( num_reset, p_gain=p_gain, d_gain=d_gain, friction_scale=friction_scale, randomized_mass=randomized_mass, randomized_com_offset=randomized_com_offset, gravity=gravity, ), )
[docs] def build_reset_observation( self, env: Any, env_ids: np.ndarray, info_updates: dict[str, Any], ) -> dict[str, np.ndarray]: del env_ids tactile, contact_pos = env._policy_frame_zeros(len(info_updates["prev_targets"])) return cast( dict[str, np.ndarray], env._compute_obs_from_inputs( info_updates, dof_pos=np.asarray(info_updates["prev_targets"]), object_pos=np.asarray(info_updates["prev_object_pos"]), tactile=tactile, contact_pos=contact_pos, ), )
[docs] def build_interval_randomization_plan( self, env: Any, step_counter: int, ) -> IntervalRandomizationPlan | None: """Build Sharpa object-force perturbations for the upcoming control step. Args: env: Sharpa rotation env instance. step_counter: Global environment step counter. Returns: Interval randomization plan carrying direct object-force perturbations, or ``None`` when object-force injection is disabled. """ del step_counter domain_rand = env.cfg.domain_rand if domain_rand.force_scale <= 0.0: return None decay = float( np.power( domain_rand.force_decay, env.cfg.ctrl_dt / max(domain_rand.force_decay_interval, 1.0e-8), ) ) env._random_object_force *= decay random_mask = np.random.rand(env._num_envs) < float(domain_rand.random_force_prob_scalar) if np.any(random_mask): object_mass = env._resolve_current_object_mass() env._random_object_force[random_mask] = ( np.random.randn(int(np.sum(random_mask)), 3).astype(np.float64) * object_mass[random_mask, None] * float(domain_rand.force_scale) ) return IntervalRandomizationPlan( body_ids=np.asarray([env._object_body_id], dtype=np.int32), body_force=env._random_object_force[:, None, :].copy(), )
[docs] @registry.env("SharpaInhandRotation", sim_backend="mujoco") @registry.env("SharpaInhandRotation", sim_backend="motrix") class SharpaInhandRotationEnv(SharpaInhandBaseEnv): _cfg: SharpaInhandRotationCfg _reward_cfg: RewardConfig _OBS_MODE_ALIASES: dict[str, str] = { "separated": "separated", "flattened": "flattened", } _CRITIC_BASE_DIM_WITHOUT_OPTIONALS = 8
[docs] def __init__( self, cfg: SharpaInhandRotationCfg, num_envs: int = 1, backend_type: str = "motrix", dr_provider: DomainRandomizationProvider | None = None, ) -> None: if cfg.reward_config is None: raise ValueError("reward_config must be provided via Hydra configuration") backend = create_backend( backend_type, cfg.scene, num_envs, cfg.sim_dt, base_name=cfg.base_name, push_body_name=cfg.domain_rand.push_body_name, add_body_sensors=True, motrix_max_iterations=cfg.motrix_max_iterations, post_step_forward_sensor=cfg.post_step_forward_sensor, ) super().__init__(cfg, backend, num_envs) self._observation_mode = self._resolve_observation_mode(cfg.obs.observation_mode) expected_critic_info_dim = self._expected_critic_info_dim() if cfg.critic_info_dim != expected_critic_info_dim: cfg.critic_info_dim = expected_critic_info_dim policy_frame_dim = self._policy_frame_dim() self.obs_buf_lag_history = np.zeros( (num_envs, cfg.obs_history_len, policy_frame_dim), dtype=self._np_dtype ) self.proprio_hist_buf = np.zeros( (num_envs, cfg.prop_hist_len, policy_frame_dim), dtype=self._np_dtype ) self.critic_info_buf = np.zeros((num_envs, expected_critic_info_dim), dtype=self._np_dtype) self._friction_geom_ids = self._resolve_friction_geom_ids() self._base_geom_friction = self._resolve_base_geom_friction() self._base_gravity = self._resolve_base_gravity() self._object_body_id = self._resolve_object_body_id() self._base_body_mass = self._resolve_base_body_mass() self._base_body_ipos = self._resolve_base_body_ipos() self._default_p_gain, self._default_d_gain = self._load_default_pd_gains() self._random_object_force = np.zeros((num_envs, 3), dtype=np.float64) if cfg.control_config.torque_control: raise NotImplementedError( "Sharpa torque_control=True is not implemented with the current position-actuator XML setup. " "Set env.control_config.torque_control=false. Virtual torques are still computed explicitly for reward terms." ) self._reward_cfg = cfg.reward_config self._zero_action_test_mode = bool(cfg.zero_action_test_mode) self._enable_reward_log = True self._grasp_cache: np.ndarray | None = None axis = np.asarray(cfg.rot_axis, dtype=self._np_dtype) axis_norm = np.linalg.norm(axis) if axis_norm <= 1.0e-8: raise ValueError("rot_axis must be non-zero") self._rot_axis = np.asarray(axis / axis_norm, dtype=self._np_dtype) provider = dr_provider if dr_provider is not None else SharpaInhandRotationDRProvider() self._init_domain_randomization(provider)
[docs] def apply_action(self, actions: np.ndarray, state: NpEnvState) -> np.ndarray: actions_np = np.asarray(actions, dtype=self._np_dtype) if self._zero_action_test_mode: actions_np = np.zeros_like(actions_np, dtype=self._np_dtype) return super().apply_action(actions_np, state)
def _scale_randomization_enabled(self) -> bool: return self._num_scales > 1 or not np.allclose(self.scale_values, self.scale_values[0]) def _expected_critic_info_dim(self) -> int: priv_info_cfg = self._cfg.priv_info dim = self._CRITIC_BASE_DIM_WITHOUT_OPTIONALS if priv_info_cfg.include_friction_scale: dim += 1 if priv_info_cfg.include_gravity_direction: dim += 3 return dim def _resolve_friction_geom_ids(self) -> dict[str, np.ndarray]: """Resolve MuJoCo geom ids touched by Sharpa friction randomization. Args: None. Returns: Mapping with object, elastomer, and metal collision geom id arrays. """ try: object_geom_id = self._backend.get_geom_id(self._cfg.object_geom_name) base_body_id = self._backend.get_body_id(self._cfg.base_name) hand_body_ids = set( int(body_id) for body_id in self._backend.get_body_subtree_ids(base_body_id) ) geom_body_ids = self._backend.get_geom_body_ids() geom_contype, geom_conaffinity = self._backend.get_geom_contact_masks() geom_names = self._backend.get_geom_names() except NotImplementedError: empty = np.zeros((0,), dtype=np.int32) return {"object": empty, "elastomer": empty, "metal": empty} elastomer_ids: list[int] = [] metal_ids: list[int] = [] for geom_id, geom_name in enumerate(geom_names): if int(geom_body_ids[geom_id]) not in hand_body_ids: continue if int(geom_contype[geom_id]) == 0 and int(geom_conaffinity[geom_id]) == 0: continue if "elastomer" in geom_name: elastomer_ids.append(geom_id) else: metal_ids.append(geom_id) if not elastomer_ids: raise ValueError("No Sharpa elastomer collision geoms found for friction randomization") if not metal_ids: raise ValueError("No Sharpa metal collision geoms found for friction randomization") return { "object": np.asarray([object_geom_id], dtype=np.int32), "elastomer": np.asarray(elastomer_ids, dtype=np.int32), "metal": np.asarray(metal_ids, dtype=np.int32), } def _resolve_object_body_id(self) -> int: """Resolve the MuJoCo body id of the manipulated object. Args: None. Returns: The MuJoCo object body id, or -1 when the backend is not MuJoCo. """ if self._object_body_ids.size == 0: return -1 return int(self._object_body_ids[0]) def _resolve_current_object_mass(self) -> np.ndarray: """Resolve the current object mass used by force perturbation sampling. Args: None. Returns: Array of shape ``(num_envs,)`` with current object masses. """ if self.state is not None: critic_info = np.asarray( self.state.info.get( "critic_info", np.zeros((self._num_envs, self._cfg.critic_info_dim), dtype=self._np_dtype), ), dtype=np.float64, ) mass = critic_info[:, self._critic_info_layout()["mass"]].reshape(self._num_envs) if np.all(mass > 0.0): return mass if self._base_body_mass is None or self._object_body_id < 0: raise ValueError("MuJoCo object body-mass cache is unavailable") return np.full( (self._num_envs,), float(self._base_body_mass[self._object_body_id]), dtype=np.float64, ) def _resolve_base_geom_friction(self) -> np.ndarray | None: """Cache MuJoCo model friction vectors used as torsional/rolling templates. Args: None. Returns: Full geom friction table, or None when the backend has no geom-friction hook. """ try: return np.asarray(self._backend.get_geom_friction(), dtype=np.float64).copy() except NotImplementedError: return None def _resolve_base_gravity(self) -> np.ndarray: """Cache the default gravity vector for privileged-info fallbacks. Args: None. Returns: Gravity vector with shape ``(3,)``. """ try: return np.asarray(self._backend.get_gravity(), dtype=np.float64).copy() except NotImplementedError: pass return np.asarray([0.0, 0.0, -9.81], dtype=np.float64) def _resolve_base_body_mass(self) -> np.ndarray | None: """Cache the MuJoCo body-mass table used as reset randomization baseline. Args: None. Returns: Full body-mass table, or None when the backend is not MuJoCo. """ try: return np.asarray(self._backend.get_body_mass(), dtype=np.float64).copy() except NotImplementedError: return None def _resolve_base_body_ipos(self) -> np.ndarray | None: """Cache the MuJoCo inertial-position table used for object COM randomization. Args: None. Returns: Full body inertial-position table, or None when the backend is not MuJoCo. """ try: return np.asarray(self._backend.get_body_ipos(), dtype=np.float64).copy() except NotImplementedError: return None def _sample_friction_scale(self, batch_size: int) -> np.ndarray | None: """Sample one friction multiplier per reset environment. Args: batch_size: Number of reset environments. Returns: Shape (batch_size, 1) multipliers, or None when disabled. """ domain_rand = self._cfg.domain_rand if not domain_rand.randomize_friction: return None return np.random.uniform( domain_rand.randomize_friction_scale_lower, domain_rand.randomize_friction_scale_upper, size=(batch_size, 1), ).astype(np.float64) def _sample_object_mass(self, batch_size: int) -> np.ndarray | None: """Sample reset-time object masses. Args: batch_size: Number of reset environments. Returns: Shape (batch_size, 1) absolute object masses, or None when disabled. """ domain_rand = self._cfg.domain_rand if not domain_rand.randomize_mass: return None return np.random.uniform( domain_rand.randomize_mass_lower, domain_rand.randomize_mass_upper, size=(batch_size, 1), ).astype(np.float64) def _sample_object_com_offset(self, batch_size: int) -> np.ndarray | None: """Sample reset-time object COM offsets. Args: batch_size: Number of reset environments. Returns: Shape (batch_size, 3) COM offsets, or None when disabled. """ domain_rand = self._cfg.domain_rand if not domain_rand.randomize_com: return None return np.random.uniform( domain_rand.randomize_com_lower, domain_rand.randomize_com_upper, size=(batch_size, 3), ).astype(np.float64) def _friction_profile(self, material: str, base_sliding_friction: float) -> np.ndarray: """Build one MuJoCo friction vector while preserving XML friction ratios. Args: material: One of object, elastomer, or metal. base_sliding_friction: Sliding-friction coefficient before random scaling. Returns: Shape (3,) MuJoCo friction vector. """ if self._base_geom_friction is None: raise ValueError("MuJoCo base geom friction cache is unavailable") geom_ids = self._friction_geom_ids[material] template = np.asarray(self._base_geom_friction[int(geom_ids[0])], dtype=np.float64) sliding = float(template[0]) if sliding <= 0.0: raise ValueError(f"{material} base sliding friction must be positive, got {sliding}") return np.asarray(template / sliding * base_sliding_friction, dtype=np.float64) def _sample_reset_gravity(self, batch_size: int) -> np.ndarray | None: """Sample the exact reset-time gravity vector applied by Sharpa. Args: batch_size: Number of reset environments. Returns: Gravity array with shape ``(batch_size, 3)``, or ``None`` when reset-time gravity randomization is disabled. """ gravity = self._build_gravity_direction_randomization(batch_size) if gravity is not None: return gravity domain_rand = self._cfg.domain_rand if not getattr(domain_rand, "randomize_gravity", False): return None gravity_range = np.asarray(domain_rand.gravity_range, dtype=np.float64) if gravity_range.shape != (2, 3): raise ValueError( f"domain_rand.gravity_range must have shape (2, 3), got {gravity_range.shape}" ) low = np.minimum(gravity_range[0], gravity_range[1]) high = np.maximum(gravity_range[0], gravity_range[1]) return np.random.uniform(low=low, high=high, size=(batch_size, 3)).astype(np.float64) def _resolved_privileged_gravity( self, batch_size: int, gravity: np.ndarray | None, ) -> np.ndarray: """Resolve the gravity vector written into privileged info for one reset batch. Args: batch_size: Number of reset environments. gravity: Sampled reset gravity, if reset-time randomization is active. Returns: Gravity array with shape ``(batch_size, 3)``. """ if gravity is not None: return np.asarray(gravity, dtype=self._np_dtype) return np.broadcast_to(self._base_gravity, (batch_size, 3)).astype( self._np_dtype, copy=True ) def _build_friction_randomization( self, batch_size: int, friction_scale: np.ndarray | None, ) -> np.ndarray | None: """Build the reset-time MuJoCo geom_friction randomization table. Args: batch_size: Number of reset environments. friction_scale: Shape (batch_size, 1) sampled friction multipliers. Returns: Shape (batch_size, ngeom, 3) friction table, or None when disabled. """ if friction_scale is None: return None if self._base_geom_friction is None: raise ValueError("MuJoCo base geom friction cache is unavailable") geom_friction = np.broadcast_to( self._base_geom_friction, (batch_size, *self._base_geom_friction.shape), ).copy() scale = np.asarray(friction_scale, dtype=np.float64).reshape(batch_size, 1, 1) domain_rand = self._cfg.domain_rand material_profiles = { "object": self._friction_profile("object", domain_rand.object_base_friction), "elastomer": self._friction_profile("elastomer", domain_rand.elastomer_base_friction), "metal": self._friction_profile("metal", domain_rand.metal_base_friction), } for material, profile in material_profiles.items(): geom_friction[:, self._friction_geom_ids[material], :] = scale * profile.reshape( 1, 1, 3 ) return geom_friction def _build_object_mass_randomization( self, batch_size: int, randomized_mass: np.ndarray | None, ) -> np.ndarray | None: """Build the reset-time MuJoCo body_mass table for object mass randomization. Args: batch_size: Number of reset environments. randomized_mass: Shape (batch_size, 1) sampled absolute object masses. Returns: Shape (batch_size, nbody) mass table, or None when disabled. """ if randomized_mass is None: return None if self._base_body_mass is None or self._object_body_id < 0: raise ValueError("MuJoCo base body-mass cache is unavailable") body_mass = np.broadcast_to( self._base_body_mass, (batch_size, self._base_body_mass.size) ).copy() body_mass[:, self._object_body_id] = np.asarray(randomized_mass, dtype=np.float64).reshape( batch_size ) return body_mass def _build_object_com_randomization( self, batch_size: int, randomized_com_offset: np.ndarray | None, ) -> np.ndarray | None: """Build the reset-time MuJoCo body_ipos table for object COM randomization. Args: batch_size: Number of reset environments. randomized_com_offset: Shape (batch_size, 3) sampled COM offsets. Returns: Shape (batch_size, nbody, 3) inertial-position table, or None when disabled. """ if randomized_com_offset is None: return None if self._base_body_ipos is None or self._object_body_id < 0: raise ValueError("MuJoCo base body-ipos cache is unavailable") body_ipos = np.broadcast_to( self._base_body_ipos, (batch_size, *self._base_body_ipos.shape), ).copy() body_ipos[:, self._object_body_id, :] += np.asarray(randomized_com_offset, dtype=np.float64) return body_ipos def _build_gravity_direction_randomization(self, batch_size: int) -> np.ndarray | None: """Sample fixed-magnitude gravity vectors with randomized directions. Args: batch_size: Number of reset environments. Returns: Gravity vectors with shape ``(batch_size, 3)``, or None when disabled. """ domain_rand = self._cfg.domain_rand if not getattr(domain_rand, "randomize_gravity_direction", False): return None magnitude = float(getattr(domain_rand, "gravity_direction_magnitude", 9.81)) if magnitude <= 0.0: raise ValueError(f"gravity_direction_magnitude must be positive, got {magnitude}") gravity = np.zeros((batch_size, 3), dtype=np.float64) gravity[:, 2] = -magnitude random_quat = sample_random_quaternion(batch_size) # A uniform quaternion and its inverse have the same distribution, so # rotating gravity samples the same relative gravity directions induced # by uniformly rotating the global hand/object/task frame. return np.asarray(np_quat_apply(random_quat, gravity), dtype=np.float64) def _build_reset_randomization( self, batch_size: int, *, p_gain: np.ndarray | None, d_gain: np.ndarray | None, friction_scale: np.ndarray | None, randomized_mass: np.ndarray | None, randomized_com_offset: np.ndarray | None, gravity: np.ndarray | None, ) -> ResetRandomizationPayload | None: """Build reset-randomization payloads owned by the Sharpa rotation env. Args: batch_size: Number of reset environments. friction_scale: Shape (batch_size, 1) sampled friction multipliers. Returns: Reset-randomization payload, or None when no backend randomization is requested. """ payload = build_common_reset_randomization(self, batch_size) if self._cfg.domain_rand.randomize_pd_gains: if payload is None: payload = ResetRandomizationPayload() payload.kp = np.asarray(p_gain, dtype=np.float64) payload.kd = np.asarray(d_gain, dtype=np.float64) body_mass = self._build_object_mass_randomization(batch_size, randomized_mass) body_ipos = self._build_object_com_randomization(batch_size, randomized_com_offset) geom_friction = self._build_friction_randomization(batch_size, friction_scale) if body_mass is not None: if payload is None: payload = ResetRandomizationPayload() payload.body_mass = body_mass if body_ipos is not None: if payload is None: payload = ResetRandomizationPayload() payload.body_ipos = body_ipos if geom_friction is not None: if payload is None: payload = ResetRandomizationPayload() payload.geom_friction = geom_friction if gravity is not None: if payload is None: payload = ResetRandomizationPayload() payload.gravity = gravity return payload def _critic_info_layout(self) -> dict[str, slice]: """Describe the flat critic_info channel layout. Args: None. Returns: Mapping from logical field names to channel slices. """ offset = 0 layout = {"object_pos_delta": slice(offset, offset + 3)} offset += 3 priv_info_cfg = self._cfg.priv_info if priv_info_cfg.include_friction_scale: layout["friction"] = slice(offset, offset + 1) offset += 1 layout["mass"] = slice(offset, offset + 1) offset += 1 layout["com"] = slice(offset, offset + 3) offset += 3 layout["scale"] = slice(offset, offset + 1) offset += 1 if priv_info_cfg.include_gravity_direction: layout["gravity"] = slice(offset, offset + 3) return layout def _assign_critic_info_field( self, critic_info: np.ndarray, field_name: str, values: np.ndarray, ) -> None: """Assign one critic_info field according to the declared layout. Args: critic_info: Critic-info buffer to update in place. field_name: Logical field name from the layout. values: Batch-major values for the field. Returns: None. The critic_info array is updated in place. """ field_slice = self._critic_info_layout().get(field_name) if field_slice is None: return field_values = np.asarray(values, dtype=self._np_dtype).reshape(critic_info.shape[0], -1) critic_info[:, field_slice] = field_values def _build_reset_critic_info( self, batch_size: int, env_ids: np.ndarray, *, friction_scale: np.ndarray | None, randomized_mass: np.ndarray | None, randomized_com_offset: np.ndarray | None, gravity: np.ndarray | None, ) -> np.ndarray: """Build reset-time critic_info for all randomized object properties. Args: batch_size: Number of reset environments. env_ids: Global environment ids for this reset batch. Returns: Critic-info tensor with shape (batch_size, critic_info_dim). """ critic_info = np.zeros((batch_size, self._cfg.critic_info_dim), dtype=self._np_dtype) priv_info_cfg = self._cfg.priv_info if priv_info_cfg.include_friction_scale: self._assign_critic_info_field( critic_info, "friction", ( friction_scale if friction_scale is not None else np.ones((batch_size, 1), dtype=np.float64) ), ) if randomized_mass is not None: self._assign_critic_info_field( critic_info, "mass", randomized_mass, ) if randomized_com_offset is not None: self._assign_critic_info_field( critic_info, "com", randomized_com_offset, ) if self._scale_randomization_enabled(): self._assign_critic_info_field( critic_info, "scale", self.scale_values[self.scale_ids[env_ids]].reshape(batch_size, 1), ) if priv_info_cfg.include_gravity_direction: self._assign_critic_info_field( critic_info, "gravity", self._resolved_privileged_gravity(batch_size, gravity), ) return critic_info def _policy_frame_dim(self) -> int: obs_cfg = self._cfg.obs dim = self._num_action + self._num_action if obs_cfg.enable_tactile: dim += self._num_tactile if obs_cfg.enable_contact_pos: dim += self._num_tactile * 3 return dim def _policy_frame_zeros(self, batch_size: int) -> tuple[np.ndarray, np.ndarray]: """Build zero-filled optional policy inputs for reset observations. Args: batch_size: Number of environments in the batch. Returns: Tuple of tactile and contact-position arrays sized for the current config. """ obs_cfg = self._cfg.obs tactile_dim = self._num_tactile if obs_cfg.enable_tactile else 0 contact_pos_dim = self._num_tactile * 3 if obs_cfg.enable_contact_pos else 0 return ( np.zeros((batch_size, tactile_dim), dtype=self._np_dtype), np.zeros((batch_size, contact_pos_dim), dtype=self._np_dtype), ) def _fixed_default_object_pose(self, batch_size: int) -> np.ndarray: """Build the fixed default object pose from the backend init state. Args: batch_size: Number of environments in the batch. Returns: Array with shape ``(batch_size, 7)`` containing the default object position and quaternion loaded from the backend/model init qpos. """ object_pos = np.broadcast_to( np.asarray(self._init_qpos[self._obj_pos_slice], dtype=self._np_dtype), (batch_size, 3), ).copy() object_quat = np.broadcast_to( np.asarray(self._init_qpos[self._obj_quat_slice], dtype=self._np_dtype), (batch_size, 4), ).copy() return np.concatenate([object_pos, object_quat], axis=1).astype(self._np_dtype) def _build_object_pos_anchor( self, object_pos: np.ndarray, ) -> np.ndarray: """Choose the reward/privileged-info object-position anchor. Args: object_pos: Reset-time object positions with shape ``(batch_size, 3)``. Returns: Array with shape ``(batch_size, 3)``. When the config flag is set, the fixed XML/default object position is used; otherwise the sampled reset position is returned to preserve the legacy UniLab behavior. """ object_pos = np.asarray(object_pos, dtype=self._np_dtype) if self._cfg.use_default_object_pose_for_object_pos_anchor: return self._fixed_default_object_pose(object_pos.shape[0])[:, 0:3] return object_pos.copy() def _resolve_object_pos_anchor( self, info: dict[str, Any], batch_size: int, ) -> np.ndarray: """Resolve the cached object-position anchor for one runtime batch. Args: info: Mutable env-state info dictionary. batch_size: Number of environments in the current batch. Returns: Array with shape ``(batch_size, 3)`` used by the object-position reward and privileged-info delta channels. """ cached_anchor = info.get("object_pos_anchor") if cached_anchor is not None: return np.asarray(cached_anchor, dtype=self._np_dtype) object_default_pose = info.get("object_default_pose") if object_default_pose is not None: return np.asarray(object_default_pose, dtype=self._np_dtype)[:, 0:3] if self._cfg.use_default_object_pose_for_object_pos_anchor: return self._fixed_default_object_pose(batch_size)[:, 0:3] return np.zeros((batch_size, 3), dtype=self._np_dtype) def _policy_frame_parts( self, dof_norm: np.ndarray, targets: np.ndarray, tactile: np.ndarray, contact_pos: np.ndarray, ) -> list[np.ndarray]: """Collect policy-frame components according to the configured observation layout. Args: dof_norm: Normalized hand joint positions. targets: Hand joint targets. tactile: Tactile features. contact_pos: Contact-position features. Returns: Ordered list of arrays that should be concatenated into the policy frame. """ parts = [dof_norm, targets] obs_cfg = self._cfg.obs if obs_cfg.enable_tactile: parts.append(np.asarray(tactile, dtype=self._np_dtype)) if obs_cfg.enable_contact_pos: parts.append(np.asarray(contact_pos, dtype=self._np_dtype)) return parts def _fill_critic_info( self, critic_info: np.ndarray, object_pos: np.ndarray, object_pos_anchor: np.ndarray, ) -> np.ndarray: """Populate privileged channels that are derived from runtime state. Args: critic_info: Critic info buffer to update in place. object_pos: Current object positions with shape (batch, 3). object_pos_anchor: Object-position anchor with shape (batch, 3). Returns: Updated critic info array with shape (batch, critic_info_dim). """ self._assign_critic_info_field( critic_info, "object_pos_delta", object_pos - object_pos_anchor, ) return critic_info @classmethod def _resolve_observation_mode(cls, observation_mode: str) -> str: normalized_mode = str(observation_mode).strip().lower() resolved_mode = cls._OBS_MODE_ALIASES.get(normalized_mode) if resolved_mode is None: supported_modes = "', '".join(sorted(cls._OBS_MODE_ALIASES)) raise ValueError( f"observation_mode must be one of '{supported_modes}', got {observation_mode!r}" ) return resolved_mode def _build_policy_frame( self, dof_pos: np.ndarray, targets: np.ndarray, tactile: np.ndarray, contact_pos: np.ndarray, *, add_noise: bool = True, ) -> np.ndarray: dof_pos_f = np.asarray(dof_pos, dtype=self._np_dtype) targets_f = np.asarray(targets, dtype=self._np_dtype) dof_norm = self._normalize_joint_pos(dof_pos_f) joint_noise_scale = float(self._cfg.domain_rand.joint_noise_scale) if add_noise and joint_noise_scale > 0.0: dof_norm += ( np.random.uniform(-1.0, 1.0, size=dof_norm.shape).astype(self._np_dtype) * joint_noise_scale ) frame = np.asarray( np.concatenate( self._policy_frame_parts( dof_norm=dof_norm, targets=targets_f, tactile=tactile, contact_pos=contact_pos, ), axis=1, ), dtype=self._np_dtype, ) return self._clip_observation_values(frame) def _clip_observation_values(self, values: np.ndarray) -> np.ndarray: """Clamp observation tensors to a configurable absolute value bound. Args: values: Observation-like array to clip. Returns: Clipped array with the same shape. Non-positive ``clip_obs`` disables the clamp so callers can opt out when needed. """ clip_max = float(getattr(self._cfg, "clip_obs", 5.0)) values = np.asarray(values, dtype=self._np_dtype) if clip_max <= 0.0: return values return np.asarray(np.clip(values, -clip_max, clip_max), dtype=self._np_dtype) @property def obs_groups_spec(self) -> dict[str, int]: policy_obs_dim = self._cfg.obs_lag_steps * self._policy_frame_dim() if self._observation_mode == "flattened": return {"obs": policy_obs_dim + self._cfg.critic_info_dim} return {"obs": policy_obs_dim, "critic": policy_obs_dim + self._cfg.critic_info_dim} def _build_critic_info( self, info: dict[str, Any], batch_size: int, object_pos: np.ndarray, ) -> np.ndarray: """Build privileged critic info for the current batch. Args: info: Mutable state info dictionary carrying reset-time caches. batch_size: Number of environments in the current batch. object_pos: Current object positions with shape (batch, 3). Returns: Privileged info array with shape (batch, critic_info_dim). """ critic_info = np.asarray( info.get( "critic_info", np.zeros((batch_size, self._cfg.critic_info_dim), dtype=self._np_dtype), ), dtype=self._np_dtype, ) object_pos_anchor = self._resolve_object_pos_anchor(info, batch_size) critic_info = self._fill_critic_info( critic_info=critic_info, object_pos=object_pos, object_pos_anchor=object_pos_anchor, ) info["critic_info"] = critic_info return critic_info def _pack_observations( self, policy_obs: np.ndarray, critic_info: np.ndarray, critic_policy_obs: np.ndarray | None = None, ) -> dict[str, np.ndarray]: """Pack actor and privileged info into the env observation groups. Args: policy_obs: Actor observation tensor with shape (batch, actor_dim). critic_info: Privileged tensor with shape (batch, critic_info_dim). Returns: Observation groups that satisfy the UniLab env contract. """ if self._observation_mode == "flattened": flattened_obs = self._clip_observation_values( np.concatenate([policy_obs, critic_info], axis=1).astype(self._np_dtype) ) return {"obs": flattened_obs} if critic_policy_obs is None: critic_policy_obs = policy_obs return { "obs": self._clip_observation_values(policy_obs), "critic": self._clip_observation_values( np.concatenate([critic_policy_obs, critic_info], axis=1).astype(self._np_dtype) ), } def _compute_obs_from_inputs( self, info: dict[str, Any], dof_pos: np.ndarray, object_pos: np.ndarray, tactile: np.ndarray, contact_pos: np.ndarray, ) -> dict[str, np.ndarray]: targets = np.asarray(info.get("prev_targets", dof_pos), dtype=self._np_dtype) frame = self._build_policy_frame( dof_pos=dof_pos, targets=targets, tactile=tactile, contact_pos=contact_pos, add_noise=True, ) critic_frame = self._build_policy_frame( dof_pos=dof_pos, targets=targets, tactile=tactile, contact_pos=contact_pos, add_noise=False, ) batch_size = int(frame.shape[0]) history = info.get("obs_lag_history") critic_history = info.get("critic_obs_lag_history") if history is None: history = repeat_obs_history(frame, self._cfg.obs_history_len).astype(self._np_dtype) else: history = np.asarray(history, dtype=self._np_dtype) history[:, :-1] = history[:, 1:] history[:, -1] = frame if critic_history is None: critic_history = repeat_obs_history(critic_frame, self._cfg.obs_history_len).astype( self._np_dtype ) else: critic_history = np.asarray(critic_history, dtype=self._np_dtype) critic_history[:, :-1] = critic_history[:, 1:] critic_history[:, -1] = critic_frame info["obs_lag_history"] = history info["critic_obs_lag_history"] = critic_history info["proprio_hist"] = self._update_proprio_history(history) obs = np.asarray( history[:, -self._cfg.obs_lag_steps :].reshape(batch_size, -1), dtype=self._np_dtype, ) critic_obs = np.asarray( critic_history[:, -self._cfg.obs_lag_steps :].reshape(batch_size, -1), dtype=self._np_dtype, ) critic_info = self._build_critic_info(info, batch_size=batch_size, object_pos=object_pos) return self._pack_observations(obs, critic_info, critic_policy_obs=critic_obs) def _compute_reward( self, info: dict[str, Any], dof_pos: np.ndarray, dof_vel: np.ndarray, object_pos: np.ndarray, object_linvel: np.ndarray, object_angvel: np.ndarray, torques: np.ndarray, ) -> np.ndarray: rot_axis = np.asarray( info.get("rot_axis", np.broadcast_to(self._rot_axis, (self._num_envs, 3))), dtype=self._np_dtype, ) rotate_reward = np.clip( np.sum(object_angvel * rot_axis, axis=1), self._reward_cfg.angvel_clip_min, self._reward_cfg.angvel_clip_max, ) object_linvel_penalty = np.sum(np.abs(object_linvel), axis=1) pos_diff_penalty = np.sum(np.square(dof_pos - self.default_angles), axis=1) torque_penalty = np.sum(np.square(torques), axis=1) work_penalty = np.square(np.sum(torques * dof_vel, axis=1)) object_pos_anchor = self._resolve_object_pos_anchor(info, object_pos.shape[0]) object_pos_reward = 1.0 / (np.linalg.norm(object_pos - object_pos_anchor, axis=1) + 0.001) reward_terms: dict[str, np.ndarray] = { "rotate": np.asarray(rotate_reward, dtype=self._np_dtype), "obj_linvel": np.asarray(object_linvel_penalty, dtype=self._np_dtype), "pose_diff": np.asarray(pos_diff_penalty, dtype=self._np_dtype), "torque": np.asarray(torque_penalty, dtype=self._np_dtype), "work": np.asarray(work_penalty, dtype=self._np_dtype), "object_pos": np.asarray(object_pos_reward, dtype=self._np_dtype), } reward = np.zeros((self._num_envs,), dtype=self._np_dtype) step_count = info.get("steps", np.zeros((self._num_envs,), dtype=np.uint32)) should_log = self._enable_reward_log and (int(step_count[0]) % 4 == 0) log = {} if should_log else info.get("log", {}) for name, scale in self._reward_cfg.scales.items(): if scale == 0.0 or name not in reward_terms: continue weighted = reward_terms[name] * scale reward += weighted if should_log: log[f"reward/{name}"] = float(np.mean(weighted)) if should_log: log["reward/total"] = float(np.mean(reward)) info["log"] = log return np.asarray(reward, dtype=self._np_dtype) * self._cfg.ctrl_dt
[docs] def update_state(self, state: NpEnvState) -> NpEnvState: dof_pos = self.get_hand_dof_pos() dof_vel = self.get_hand_dof_vel() object_pos = self.get_object_pos() object_quat = self.get_object_quat() prev_object_pos = np.asarray( state.info.get("prev_object_pos", object_pos), dtype=self._np_dtype ) prev_object_quat = np.asarray( state.info.get("prev_object_quat", object_quat), dtype=self._np_dtype ) object_linvel = (object_pos - prev_object_pos) / self._cfg.ctrl_dt object_angvel = ( np_quat_to_axis_angle(np_quat_mul(object_quat, np_quat_conjugate(prev_object_quat))) / self._cfg.ctrl_dt ) targets = np.asarray( state.info.get( "prev_targets", np.broadcast_to(self.default_angles, (self._num_envs, self._num_action)).copy(), ), dtype=self._np_dtype, ) p_gain, d_gain = self._resolve_pd_gains(state.info) # Explicit virtual torque used for reward parity with source Sharpa formulation. virtual_torques = np.asarray( p_gain * (targets - dof_pos) - d_gain * dof_vel, dtype=self._np_dtype, ) tactile = self._compute_tactile_observation() contact_pos = self._compute_contact_positions(tactile) reward = self._compute_reward( state.info, dof_pos=dof_pos, dof_vel=dof_vel, object_pos=object_pos, object_linvel=object_linvel, object_angvel=object_angvel, torques=virtual_torques, ) reset_height_lower = np.asarray( state.info.get( "reset_height_lower", np.full((self._num_envs,), self._cfg.reset_height_lower, dtype=self._np_dtype), ), dtype=self._np_dtype, ) reset_height_upper = np.asarray( state.info.get( "reset_height_upper", np.full((self._num_envs,), self._cfg.reset_height_upper, dtype=self._np_dtype), ), dtype=self._np_dtype, ) terminated = (object_pos[:, 2] > reset_height_upper) | ( object_pos[:, 2] < reset_height_lower ) obs = self._compute_obs_from_inputs( state.info, dof_pos=dof_pos, object_pos=object_pos, tactile=tactile, contact_pos=contact_pos, ) state.info["prev_hand_pos"] = dof_pos.copy() state.info["hand_dof_vel"] = dof_vel.copy() state.info["prev_object_pos"] = object_pos.copy() state.info["prev_object_quat"] = object_quat.copy() state.info["torques"] = virtual_torques state.info["virtual_torques"] = virtual_torques.copy() state.info["object_linvel"] = object_linvel state.info["object_angvel"] = object_angvel return state.replace( obs=obs, reward=reward, terminated=np.asarray(terminated, dtype=bool), )
SharpaWaveRewardConfig = RewardConfig SharpaWaveRotationCfg = SharpaInhandRotationCfg