# Source code for unilab.envs.manipulation.allegro_inhand.rotation

"""Allegro in-hand rotation environment."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, cast

import numpy as np
from etils import epath

from unilab.assets import ASSETS_ROOT_PATH
from unilab.base import registry
from unilab.base.backend import create_backend
from unilab.base.np_env import NpEnvState
from unilab.base.scene import SceneCfg
from unilab.dr import (
    DomainRandomizationCapabilities,
    DomainRandomizationProvider,
    IntervalRandomizationPlan,
    ResetPlan,
)
from unilab.dr.dr_utils import (
    build_common_reset_randomization,
    build_interval_push_plan,
    validate_common_reset_randomization,
    validate_interval_push_support,
    zero_actions,
)
from unilab.dtype_config import get_global_dtype
from unilab.envs.common.rotation import np_quat_conjugate, np_quat_mul, np_quat_to_axis_angle

from .base import AllegroBaseCfg, AllegroBaseEnv


def resolve_grasp_cache_path(cache_path: str) -> epath.Path:
    """Resolve an Allegro grasp cache path using the asset-root convention.

    Args:
        cache_path: Path to the grasp cache as written in the configuration.

    Returns:
        ``cache_path`` unchanged (as an ``epath.Path``) when it is absolute or
        already exists; otherwise the same path re-rooted under
        ``ASSETS_ROOT_PATH``.
    """
    candidate = epath.Path(cache_path)
    # Only fall back to the asset root for relative paths that cannot be found as given.
    needs_asset_root = not (candidate.is_absolute() or candidate.exists())
    if needs_asset_root:
        return epath.Path(ASSETS_ROOT_PATH / cache_path)
    return candidate
def normalize_rotation_axis(rotation_axis: tuple[float, float, float]) -> np.ndarray:
    """Return ``rotation_axis`` scaled to unit length, in the global dtype.

    Args:
        rotation_axis: Components of the desired rotation axis.

    Returns:
        The unit-length axis as an array of dtype ``get_global_dtype()``.

    Raises:
        ValueError: If the axis has zero or non-finite length. Dividing by such a
            norm would yield NaN/inf components that silently propagate into every
            value computed from the axis.
    """
    dtype = get_global_dtype()
    axis = np.asarray(rotation_axis, dtype=dtype)
    norm = np.linalg.norm(axis)
    if not np.isfinite(norm) or norm == 0.0:
        raise ValueError(
            f"rotation_axis must have a finite, non-zero length; got {rotation_axis!r}"
        )
    return np.asarray(axis / norm, dtype=dtype)
def compute_ball_angvel(
    ball_quat: np.ndarray, prev_ball_quat: np.ndarray, ctrl_dt: float
) -> np.ndarray:
    """Estimate the ball angular velocity from two consecutive orientations.

    The relative rotation ``ball_quat * conj(prev_ball_quat)`` is converted to an
    axis-angle vector and divided by the control period.

    Args:
        ball_quat: Current ball orientation quaternion(s).
        prev_ball_quat: Ball orientation quaternion(s) from the previous control step.
        ctrl_dt: Time between the two orientations.

    Returns:
        The finite-difference angular velocity in the global dtype.
    """
    prev_inverse = np_quat_conjugate(prev_ball_quat)
    delta_quat = np_quat_mul(ball_quat, prev_inverse)
    axis_angle = np_quat_to_axis_angle(delta_quat)
    return np.asarray(axis_angle / ctrl_dt, dtype=get_global_dtype())
def compute_pd_torques(
    targets: np.ndarray,
    dof_pos: np.ndarray,
    dof_vel: np.ndarray,
    kp: float,
    kd: float,
    torque_limit: float = 0.5,
) -> np.ndarray:
    """Compute clipped PD torques ``kp * (targets - dof_pos) - kd * dof_vel``.

    Args:
        targets: Desired joint positions.
        dof_pos: Measured joint positions.
        dof_vel: Measured joint velocities.
        kp: Proportional gain.
        kd: Derivative gain.
        torque_limit: Symmetric saturation bound; torques are clipped to
            ``[-torque_limit, torque_limit]``. Defaults to ``0.5``, the value that
            was previously hard-coded.

    Returns:
        The saturated torques in the global dtype.

    Raises:
        ValueError: If ``torque_limit`` is negative, which would make the clip
            interval inverted.
    """
    if torque_limit < 0:
        raise ValueError(f"torque_limit must be non-negative; got {torque_limit!r}")
    torques = kp * (targets - dof_pos) - kd * dof_vel
    clipped = np.clip(torques, -torque_limit, torque_limit)
    return np.asarray(clipped, dtype=get_global_dtype())
def build_obs_lag_history(
    init_obs: np.ndarray, num_lag_steps: int, num_obs_per_step: int
) -> np.ndarray:
    """Create a lag buffer in which every slot holds the initial observation.

    Args:
        init_obs: Per-environment observation; axis 0 is the environment axis and
            the array must broadcast to ``(num_envs, num_lag_steps, num_obs_per_step)``
            after a lag axis is inserted at position 1.
        num_lag_steps: Number of history slots per environment.
        num_obs_per_step: Observation width of a single slot.

    Returns:
        A freshly allocated, writable array of shape
        ``(num_envs, num_lag_steps, num_obs_per_step)`` with the dtype of ``init_obs``.
    """
    num_envs = init_obs.shape[0]
    lag_shape = (num_envs, num_lag_steps, num_obs_per_step)
    per_env_slot = init_obs[:, None, :]
    # broadcast_to yields a read-only view; copy() materializes an independent buffer
    # that already has init_obs's dtype.
    return np.broadcast_to(per_env_slot, lag_shape).copy()
def sample_cached_grasps(
    grasp_cache: np.ndarray, num_reset: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Draw ``num_reset`` random rows from the grasp cache and split them by field.

    Rows are drawn uniformly with replacement through NumPy's global RNG.

    Args:
        grasp_cache: 2-D array of cached grasps, one grasp per row.
        num_reset: Number of grasps to draw.

    Returns:
        A tuple ``(columns 0:16, columns 16:19, columns 19:23)`` of the sampled rows;
        the caller uses them as hand joint positions, ball position and ball
        quaternion respectively.
    """
    hand_cols = slice(None, 16)
    ball_pos_cols = slice(16, 19)
    ball_quat_cols = slice(19, 23)

    row_ids = np.random.randint(0, len(grasp_cache), size=num_reset)
    rows = grasp_cache[row_ids]
    return rows[:, hand_cols], rows[:, ball_pos_cols], rows[:, ball_quat_cols]
@dataclass
class RewardConfigPPO:
    """Reward settings for the Allegro in-hand rotation task."""

    # Weight per reward-term name; terms with a zero weight or an unknown name are skipped.
    scales: dict[str, float]
    # Lower clip bound applied to the angular velocity projected on the rotation axis.
    angvel_clip_min: float
    # Upper clip bound applied to the angular velocity projected on the rotation axis.
    angvel_clip_max: float
    # An episode terminates once the ball's z coordinate drops below this value.
    reset_z_threshold: float
@dataclass
class DomainRandConfig:
    """Domain-randomization switches and ranges for the Allegro rotation task.

    NOTE(review): the mass, centre-of-mass, gravity and push fields are not read in
    this module; they are presumably consumed by the shared ``unilab.dr`` helpers —
    confirm there before changing their meaning.
    """

    # --- mass / centre of mass ---
    randomize_base_mass: bool = False
    added_mass_range: list[float] = field(default_factory=lambda: [0.0, 0.0])
    random_com: bool = False
    com_offset_x: list[float] = field(default_factory=lambda: [0.0, 0.0])

    # --- gravity ---
    randomize_gravity: bool = False
    gravity_range: list[list[float]] = field(
        default_factory=lambda: [[0.0, 0.0, -9.81], [0.0, 0.0, -9.81]]
    )

    # --- external pushes ---
    push_robots: bool = False
    push_interval: int = 750
    max_force: list[float] = field(default_factory=lambda: [1.0, 1.0, 0.5])
    # Forwarded to the simulation backend as ``push_body_name``.
    push_body_name: str | None = None

    # --- reset noise (used only when no grasp cache is available) ---
    # Half-width of the uniform noise added to the default hand joint angles.
    joint_noise: float = 0.0
    # Half-width of the uniform noise applied to the ball's initial linear velocity.
    ball_vel_noise: float = 0.0
    # Constant offset added to the ball's initial z position.
    ball_z_offset: float = 0.0
def _default_allegro_scene() -> SceneCfg:
    """Build the default scene configuration pointing at the bundled Allegro hand scene."""
    scene_xml = ASSETS_ROOT_PATH / "robots" / "allegro_hand" / "scene.xml"
    return SceneCfg(model_file=str(scene_xml))


@registry.envcfg("AllegroInhandRotation")
@dataclass
class AllegroRotationPPOCfg(AllegroBaseCfg):
    """Configuration of the ``AllegroInhandRotation`` environment."""

    # Scene description; defaults to the Allegro hand ``scene.xml`` under the asset root.
    scene: SceneCfg = field(default_factory=_default_allegro_scene)
    max_episode_seconds: float = 20.0
    # Must be supplied before the environment is constructed (the env raises otherwise).
    reward_config: RewardConfigPPO | None = None
    domain_rand: DomainRandConfig = field(default_factory=DomainRandConfig)
    # Desired ball rotation axis; normalized by the environment before use.
    rotation_axis: tuple[float, float, float] = (0.0, 0.0, 1.0)
    # Grasp cache location; relative paths are resolved against the asset root.
    grasp_cache_path: str = "caches/allegro_grasp_50k.npy"
    # When True the grasp cache is not loaded and resets use the default pose instead.
    gen_grasp: bool = False
class AllegroRotationDomainRandomizationProvider(DomainRandomizationProvider):
    """Reset and interval randomization for the Allegro in-hand rotation environment.

    Reset states come from a cached set of grasps when one is available; otherwise
    the default hand pose (plus configured noise) and the initial ball pose are used.
    """

    def validate(self, env: Any, capabilities: DomainRandomizationCapabilities) -> None:
        """Check that the backend supports the common reset and interval-push randomization."""
        validate_common_reset_randomization(env, capabilities)
        validate_interval_push_support(env, capabilities)

    def build_interval_randomization_plan(
        self, env: Any, step_counter: int
    ) -> IntervalRandomizationPlan | None:
        """Delegate interval randomization to the shared push-plan builder."""
        return build_interval_push_plan(env, step_counter)

    def _load_grasp_cache(self, env: Any) -> np.ndarray | None:
        """Return the grasp cache, loading it from disk on first use.

        The outcome (including "no cache") is memoized on ``env`` through
        ``_grasp_cache`` / ``_grasp_cache_loaded`` so the disk is touched at most once.

        Returns:
            The cache as a float64 array, or ``None`` when ``gen_grasp`` is enabled
            or the cache file does not exist.
        """
        if env._grasp_cache_loaded:
            return cast(np.ndarray | None, env._grasp_cache)

        if env.cfg.gen_grasp:
            # Grasp-generation mode never uses cached grasps.
            env._grasp_cache = None
            env._grasp_cache_loaded = True
            return None

        cache_path = resolve_grasp_cache_path(env.cfg.grasp_cache_path)
        if not cache_path.exists():
            missing_msg = (
                "[allegro_inhand] Grasp cache is missing; no Hugging Face download will be "
                f"attempted. Expected local cache: {cache_path}. Generate one with "
                "`uv run train --algo ppo --task allegro_inhand_grasp --sim mujoco "
                "training.no_play=true`, or point `env.grasp_cache_path` at an existing "
                "local cache."
            )
            print(missing_msg)
            env._grasp_cache = None
            env._grasp_cache_loaded = True
            return None

        env._grasp_cache = np.load(cache_path).astype(np.float64)
        env._grasp_cache_loaded = True
        loaded_msg = (
            "[allegro_inhand] Loaded grasp cache: "
            f"{cache_path}, shape={env._grasp_cache.shape}, dtype={env._grasp_cache.dtype}"
        )
        print(loaded_msg)
        return cast(np.ndarray | None, env._grasp_cache)

    def _sample_reset_state(
        self, env: Any, num_reset: int
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Sample hand joint positions, ball pose and initial velocities for a reset.

        Returns:
            ``(hand_qpos, ball_pos, ball_quat, qvel)`` with one row per reset
            environment. ``qvel`` is zero except for the ball's linear velocity,
            which receives uniform noise of half-width ``ball_vel_noise``.
        """
        dr = env.cfg.domain_rand
        hand_dof = env._NUM_HAND_DOF
        ball_lin = slice(hand_dof, hand_dof + 3)

        grasp_cache = self._load_grasp_cache(env)
        if grasp_cache is None:
            # Fallback: default hand pose with joint noise, clipped to the control range.
            hand_qpos = np.broadcast_to(env.default_angles, (num_reset, hand_dof)).copy()
            joint_jitter = np.random.uniform(-dr.joint_noise, dr.joint_noise, hand_qpos.shape)
            hand_qpos += joint_jitter.astype(np.float64)
            lower = env._ctrl_lower.astype(np.float64)
            upper = env._ctrl_upper.astype(np.float64)
            hand_qpos = np.clip(hand_qpos, lower, upper)

            # Ball starts at its initial position (raised by ball_z_offset), identity rotation.
            ball_init_pos = env._init_qpos[ball_lin]
            ball_pos = np.broadcast_to(ball_init_pos, (num_reset, 3)).copy()
            ball_pos[:, 2] += dr.ball_z_offset
            ball_quat = np.tile([1.0, 0.0, 0.0, 0.0], (num_reset, 1))
        else:
            hand_qpos, ball_pos, ball_quat = sample_cached_grasps(grasp_cache, num_reset)

        qvel = np.zeros((num_reset, env.nv), dtype=np.float64)
        qvel[:, ball_lin] = np.random.uniform(
            -dr.ball_vel_noise, dr.ball_vel_noise, (num_reset, 3)
        )
        return hand_qpos, ball_pos, ball_quat, qvel

    def _build_info_updates(
        self,
        env: Any,
        hand_qpos: np.ndarray,
        ball_pos: np.ndarray,
        ball_quat: np.ndarray,
    ) -> dict[str, np.ndarray]:
        """Build the per-environment ``info`` entries that accompany a reset.

        Actions are zeroed, the control target and "previous" buffers are seeded from
        the sampled state, and the observation lag history is filled with the initial
        observation ``[normalized joint pos, joint pos, ball pos]``.
        """
        dtype = get_global_dtype()
        num_reset = hand_qpos.shape[0]

        init_ctrl = np.asarray(hand_qpos, dtype=dtype)
        init_ball_pos = np.asarray(ball_pos, dtype=dtype)

        # Same normalization as the environment's observation: map the control range to ~[-1, 1].
        dof_pos_norm = 2.0 * (init_ctrl - env._dof_mid) / (env._dof_range + 1e-8)
        init_obs = np.concatenate([dof_pos_norm, init_ctrl, init_ball_pos], axis=1, dtype=dtype)
        obs_lag_history = build_obs_lag_history(
            init_obs, env._NUM_LAG_STEPS, env._NUM_OBS_PER_STEP
        )

        info_updates: dict[str, np.ndarray] = {
            "current_actions": zero_actions(num_reset, env._num_action),
            "last_actions": zero_actions(num_reset, env._num_action),
            "prev_ctrl": init_ctrl,
            "init_pose": init_ctrl.copy(),
            "prev_dof_pos": init_ctrl.copy(),
            "prev_ball_pos": init_ball_pos.copy(),
            "prev_ball_quat": np.asarray(ball_quat, dtype=dtype).copy(),
            "obs_lag_history": obs_lag_history,
        }
        return info_updates

    def build_reset_plan(self, env: Any, env_ids: np.ndarray) -> ResetPlan:
        """Assemble the full reset plan (qpos, qvel, info, randomization) for ``env_ids``."""
        num_reset = len(env_ids)
        hand_qpos, ball_pos, ball_quat, qvel = self._sample_reset_state(env, num_reset)

        # Generalized positions are laid out as [hand joints, ball position, ball quaternion].
        qpos = np.concatenate([hand_qpos, ball_pos, ball_quat], axis=1, dtype=np.float64)
        info_updates = self._build_info_updates(env, hand_qpos, ball_pos, ball_quat)
        randomization = build_common_reset_randomization(env, num_reset)

        return ResetPlan(
            env_ids=env_ids,
            qpos=qpos,
            qvel=qvel,
            info_updates=info_updates,
            randomization=randomization,
        )

    def build_reset_observation(
        self, env: Any, env_ids: np.ndarray, info_updates: dict[str, Any]
    ) -> dict[str, np.ndarray]:
        """Compute the first observation after a reset from the reset ``info`` entries."""
        del env_ids  # the observation depends only on the per-row info updates
        reset_obs = env._compute_obs(
            info_updates,
            info_updates["prev_ctrl"],
            info_updates["prev_ball_pos"],
        )
        return cast(dict[str, np.ndarray], reset_obs)
# ─────────────────────────── Environment ──────────────────────────────
@registry.env("AllegroInhandRotation", sim_backend="mujoco")
@registry.env("AllegroInhandRotation", sim_backend="motrix")
class AllegroRotationPPO(AllegroBaseEnv):
    """Allegro hand environment rewarding rotation of a ball about a fixed axis.

    Observations are a flattened history of ``_NUM_LAG_STEPS`` frames, each frame being
    ``[normalized joint positions, control targets, ball position]``.
    """

    _cfg: AllegroRotationPPOCfg
    _reward_cfg: RewardConfigPPO
    # Width of one observation frame (joint positions + targets + ball position).
    _NUM_OBS_PER_STEP = 35
    # Number of frames kept in the observation lag history.
    _NUM_LAG_STEPS = 3

    def __init__(
        self, cfg: AllegroRotationPPOCfg, num_envs: int = 1, backend_type: str = "mujoco"
    ) -> None:
        """Create the simulation backend and initialize rewards and randomization.

        Args:
            cfg: Environment configuration; ``cfg.reward_config`` must be set.
            num_envs: Number of parallel environments.
            backend_type: Simulation backend identifier passed to ``create_backend``.

        Raises:
            ValueError: If ``cfg.reward_config`` is ``None``.
        """
        if cfg.reward_config is None:
            raise ValueError("reward_config must be provided via Hydra configuration")

        backend = create_backend(
            backend_type,
            cfg.scene,
            num_envs,
            cfg.sim_dt,
            base_name="palm",
            push_body_name=cfg.domain_rand.push_body_name,
            add_body_sensors=True,
            position_actuator_gains={
                "kp": cfg.control_config.kp,
                "kd": cfg.control_config.kd,
                "actuator_ids": slice(0, 16),
            },
            motrix_max_iterations=cfg.motrix_max_iterations,
            post_step_forward_sensor=cfg.post_step_forward_sensor,
        )
        super().__init__(cfg, backend, num_envs)

        self._enable_reward_log = True
        self._reward_cfg = cfg.reward_config
        # Control-range width and midpoint, used to normalize joint positions.
        self._dof_range = self._ctrl_upper - self._ctrl_lower
        self._dof_mid = (self._ctrl_upper + self._ctrl_lower) / 2.0
        self._rot_axis = normalize_rotation_axis(cfg.rotation_axis)
        # Grasp cache is loaded lazily by the domain-randomization provider.
        self._grasp_cache: np.ndarray | None = None
        self._grasp_cache_loaded = False

        self._init_reward_functions()
        self._init_domain_randomization(AllegroRotationDomainRandomizationProvider())

    @property
    def obs_groups_spec(self) -> dict[str, int]:
        """Return the flattened size of each observation group."""
        return {"obs": self._NUM_OBS_PER_STEP * self._NUM_LAG_STEPS}

    def _init_reward_functions(self) -> None:
        """Register the reward terms by the names used in ``reward_config.scales``."""
        self._reward_fns = {
            "rotate": self._reward_rotate,
            "obj_linvel": self._reward_obj_linvel,
            "pose_diff": self._reward_pose_diff,
            "torque": self._reward_torque,
            "work": self._reward_work,
            "drop": self._reward_drop,
        }

    # All reward terms share one signature so they can be dispatched uniformly from
    # ``_compute_reward``; each term deletes the arguments it does not use.

    def _reward_rotate(
        self,
        info: dict[str, Any],
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        ball_pos: np.ndarray,
        ball_linvel: np.ndarray,
        ball_angvel: np.ndarray,
        torques: np.ndarray,
        terminated: np.ndarray,
    ) -> np.ndarray:
        """Ball angular velocity projected on the rotation axis, clipped to the configured range."""
        del info, dof_pos, dof_vel, ball_pos, ball_linvel, torques, terminated
        vec_dot = ball_angvel @ self._rot_axis
        reward: np.ndarray = np.clip(
            vec_dot, self._reward_cfg.angvel_clip_min, self._reward_cfg.angvel_clip_max
        )
        return reward

    def _reward_obj_linvel(
        self,
        info: dict[str, Any],
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        ball_pos: np.ndarray,
        ball_linvel: np.ndarray,
        ball_angvel: np.ndarray,
        torques: np.ndarray,
        terminated: np.ndarray,
    ) -> np.ndarray:
        """L1 norm of the ball's linear velocity."""
        del info, dof_pos, dof_vel, ball_pos, ball_angvel, torques, terminated
        penalty: np.ndarray = np.sum(np.abs(ball_linvel), axis=1)
        return penalty

    def _reward_pose_diff(
        self,
        info: dict[str, Any],
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        ball_pos: np.ndarray,
        ball_linvel: np.ndarray,
        ball_angvel: np.ndarray,
        torques: np.ndarray,
        terminated: np.ndarray,
    ) -> np.ndarray:
        """Squared distance between the joint positions and the pose stored at reset."""
        del dof_vel, ball_pos, ball_linvel, ball_angvel, torques, terminated
        diff = dof_pos - info["init_pose"]
        penalty: np.ndarray = np.sum(np.square(diff), axis=1)
        return penalty

    def _reward_torque(
        self,
        info: dict[str, Any],
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        ball_pos: np.ndarray,
        ball_linvel: np.ndarray,
        ball_angvel: np.ndarray,
        torques: np.ndarray,
        terminated: np.ndarray,
    ) -> np.ndarray:
        """Sum of squared joint torques."""
        del info, dof_pos, dof_vel, ball_pos, ball_linvel, ball_angvel, terminated
        penalty: np.ndarray = np.sum(np.square(torques), axis=1)
        return penalty

    def _reward_work(
        self,
        info: dict[str, Any],
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        ball_pos: np.ndarray,
        ball_linvel: np.ndarray,
        ball_angvel: np.ndarray,
        torques: np.ndarray,
        terminated: np.ndarray,
    ) -> np.ndarray:
        """Square of the summed joint power ``sum(torque * velocity)``."""
        del info, dof_pos, ball_pos, ball_linvel, ball_angvel, terminated
        work = np.sum(torques * dof_vel, axis=1)
        penalty: np.ndarray = np.square(work)
        return penalty

    def _reward_drop(
        self,
        info: dict[str, Any],
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        ball_pos: np.ndarray,
        ball_linvel: np.ndarray,
        ball_angvel: np.ndarray,
        torques: np.ndarray,
        terminated: np.ndarray,
    ) -> np.ndarray:
        """Indicator (1.0 / 0.0) of environments that terminated this step."""
        del info, dof_pos, dof_vel, ball_pos, ball_linvel, ball_angvel, torques
        return np.asarray(terminated, dtype=get_global_dtype())

    def update_state(self, state: NpEnvState) -> NpEnvState:
        """Compute observation, reward and termination after a simulation step.

        Joint and ball velocities are finite differences against the values stored in
        ``state.info`` by the previous call (zero when no previous value exists).
        """
        dof_pos = self.get_hand_dof_pos()
        ball_pos = self.get_ball_pos()
        ball_quat = self.get_ball_quat()
        ctrl_dt = self._cfg.ctrl_dt

        dof_vel = (dof_pos - state.info.get("prev_dof_pos", dof_pos)) / ctrl_dt
        ball_linvel = (ball_pos - state.info.get("prev_ball_pos", ball_pos)) / ctrl_dt
        prev_ball_quat = state.info.get("prev_ball_quat", ball_quat)
        ball_angvel = compute_ball_angvel(ball_quat, prev_ball_quat, ctrl_dt)

        # Store copies so later in-place changes to the backend buffers cannot leak in.
        state.info["curr_dof_pos"] = dof_pos.copy()
        state.info["curr_ball_pos"] = ball_pos.copy()
        state.info["curr_ball_quat"] = ball_quat.copy()
        state.info["prev_dof_pos"] = dof_pos.copy()
        state.info["prev_ball_pos"] = ball_pos.copy()
        state.info["prev_ball_quat"] = ball_quat.copy()

        # Torques are re-estimated from the PD law for the torque/work reward terms.
        targets = state.info["prev_ctrl"]
        torques = compute_pd_torques(
            targets=targets,
            dof_pos=dof_pos,
            dof_vel=dof_vel,
            kp=self._cfg.control_config.kp,
            kd=self._cfg.control_config.kd,
        )

        # The episode ends when the ball falls below the configured height.
        terminated = ball_pos[:, 2] < self._reward_cfg.reset_z_threshold
        reward = self._compute_reward(
            state.info, dof_pos, dof_vel, ball_pos, ball_linvel, ball_angvel, torques, terminated
        )
        obs = self._compute_obs(state.info, dof_pos, ball_pos)
        return state.replace(obs=obs, reward=reward, terminated=terminated)

    def _compute_reward(
        self,
        info: dict[str, Any],
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        ball_pos: np.ndarray,
        ball_linvel: np.ndarray,
        ball_angvel: np.ndarray,
        torques: np.ndarray,
        terminated: np.ndarray,
    ) -> np.ndarray:
        """Sum the weighted reward terms and scale the total by ``ctrl_dt``.

        Terms with a zero scale, or whose name is not registered, are skipped. Every
        fourth step (judged by environment 0's step counter) the per-term means are
        written to ``info["log"]``; on other steps the previous log is kept.
        """
        dtype = get_global_dtype()
        reward = np.zeros(self._num_envs, dtype=dtype)
        step_count = info.get("steps", np.zeros(self._num_envs, dtype=np.uint32))
        should_log = self._enable_reward_log and (int(step_count[0]) % 4 == 0)
        log = {} if should_log else info.get("log", {})

        for name, scale in self._reward_cfg.scales.items():
            if scale == 0 or name not in self._reward_fns:
                continue
            rew = self._reward_fns[name](
                info, dof_pos, dof_vel, ball_pos, ball_linvel, ball_angvel, torques, terminated
            )
            weighted_rew = rew * scale
            reward += weighted_rew
            if should_log:
                log[f"reward/{name}"] = float(np.mean(weighted_rew))

        if should_log:
            # Logged total is the unscaled sum (before multiplication by ctrl_dt).
            log["reward/total"] = float(np.mean(reward))
        info["log"] = log
        return reward * self._cfg.ctrl_dt

    def _compute_obs(
        self, info: dict[str, Any], dof_pos: np.ndarray, ball_pos: np.ndarray
    ) -> dict[str, np.ndarray]:
        """Append the current frame to the lag history and return the flattened history.

        The history in ``info["obs_lag_history"]`` is shifted in place by one slot and
        the newest frame is written to the last slot.

        Returns:
            ``{"obs": array}`` where ``array`` has shape
            ``(num_envs, _NUM_LAG_STEPS * _NUM_OBS_PER_STEP)`` and owns its memory.
        """
        dtype = get_global_dtype()
        targets = info["prev_ctrl"]
        # Map joint positions from the control range to roughly [-1, 1].
        dof_pos_norm = 2.0 * (dof_pos - self._dof_mid) / (self._dof_range + 1e-8)

        noise_cfg = self._cfg.noise_config
        if noise_cfg.level > 0.0:
            dof_pos_norm += (
                np.random.uniform(-1.0, 1.0, dof_pos_norm.shape).astype(dtype)
                * noise_cfg.level
                * noise_cfg.scale_joint_angle
            )

        current_obs = np.concatenate(
            [dof_pos_norm, targets, ball_pos.astype(dtype)], axis=1, dtype=dtype
        )
        num_envs = dof_pos.shape[0]
        obs_lag_history = info.get(
            "obs_lag_history",
            np.zeros(
                (num_envs, self._NUM_LAG_STEPS, self._NUM_OBS_PER_STEP),
                dtype=dtype,
            ),
        )
        # Drop the oldest frame and append the newest one.
        obs_lag_history[:, :-1] = obs_lag_history[:, 1:]
        obs_lag_history[:, -1] = current_obs
        info["obs_lag_history"] = obs_lag_history

        # reshape() normally returns a view of the history buffer, which is shifted in
        # place on the next call. np.array() copies, so an observation handed out earlier
        # is never modified retroactively.
        flat_obs = obs_lag_history.reshape(num_envs, -1)
        return {
            "obs": np.array(flat_obs, dtype=dtype),
        }
# Alternative names bound to the same classes defined above.
RewardConfig = RewardConfigPPO
Domain_Rand = DomainRandConfig
AllegroRotationCfg = AllegroRotationPPOCfg
AllegroRotationMj = AllegroRotationPPO