Source code for unilab.envs.manipulation.allegro_inhand.grasp_gen

"""Allegro grasp-generation task built on top of rotation env."""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path

import numpy as np

from unilab.assets import ASSETS_ROOT_PATH
from unilab.base import registry
from unilab.base.np_env import NpEnvState

from .rotation import AllegroRotationPPO, AllegroRotationPPOCfg, RewardConfigPPO


[docs] @registry.envcfg("AllegroInhandRotationGrasp") @dataclass class AllegroRotationGraspCfg(AllegroRotationPPOCfg): # These are fallback defaults. Hydra task env overrides (e.g. # conf/ppo/task/allegro_inhand_grasp/mujoco.yaml and CLI env.*) # are applied at env construction and take precedence. max_episode_seconds: float = 2.0 reward_config: RewardConfigPPO = field( default_factory=lambda: RewardConfigPPO( scales={ "rotate": 0.0, "obj_linvel": 0.0, "pose_diff": 0.0, "torque": 0.0, "work": 0.0, "drop": 0.0, }, angvel_clip_min=-0.5, angvel_clip_max=0.5, reset_z_threshold=0.125, ) ) gen_grasp: bool = True grasp_collection_target: int = 50_000 grasp_auto_save: bool = True grasp_quality_check: bool = True grasp_min_contacts: int = 2
[docs] @registry.env("AllegroInhandRotationGrasp", sim_backend="mujoco") @registry.env("AllegroInhandRotationGrasp", sim_backend="motrix") class AllegroRotationGrasp(AllegroRotationPPO): _cfg: AllegroRotationGraspCfg _CONTACT_SENSORS = ("ff_contact", "mf_contact", "rf_contact", "th_contact")
[docs] def __init__( self, cfg: AllegroRotationGraspCfg, num_envs: int = 1, backend_type: str = "mujoco" ) -> None: super().__init__(cfg, num_envs=num_envs, backend_type=backend_type) self._saved_grasping_states: list[np.ndarray] = [] self._grasp_cache_saved = False self._grasp_target_reached_notified = False
[docs] def apply_action(self, actions: np.ndarray, state: NpEnvState) -> np.ndarray: del actions zero_actions = np.zeros((self._num_envs, self._NUM_HAND_DOF), dtype=self._np_dtype) return super().apply_action(zero_actions, state)
@staticmethod def _sensor_scalar(sensor_data: np.ndarray) -> np.ndarray: sensor_data = np.asarray(sensor_data) if sensor_data.ndim == 1: return sensor_data return sensor_data.reshape(sensor_data.shape[0], -1)[:, 0] def _contact_count(self) -> np.ndarray: contacts = np.stack( [self._sensor_scalar(self.get_sensor_data(name)) for name in self._CONTACT_SENSORS], axis=1, ) return np.asarray(np.sum(contacts > 0.5, axis=1), dtype=np.int32) def _compute_grasp_conditions(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: # Hora-style good grasp criteria from compute_reward: # 1) all fingertips near object, 2) enough fingertip contacts, 3) object not dropped. ball_pos = self.get_ball_pos() fingertip_pos = self.get_fingertip_pos() cond1 = np.all(np.linalg.norm(fingertip_pos - ball_pos[:, None, :], axis=-1) < 0.1, axis=1) cond2 = self._contact_count() >= int(self._cfg.grasp_min_contacts) cond3 = ball_pos[:, 2] > float(self._reward_cfg.reset_z_threshold) return ( np.asarray(cond1, dtype=bool), np.asarray(cond2, dtype=bool), np.asarray(cond3, dtype=bool), ) def _check_grasp_quality(self, env_ids: np.ndarray) -> np.ndarray: cond1, cond2, cond3 = self._compute_grasp_conditions() return np.asarray(cond1[env_ids] & cond2[env_ids] & cond3[env_ids], dtype=bool) def _total_saved_grasps(self) -> int: if not self._saved_grasping_states: return 0 return int(sum(states.shape[0] for states in self._saved_grasping_states)) def _stop_collection(self) -> None: if self._grasp_target_reached_notified: return target = int(self._cfg.grasp_collection_target) if target <= 0: return total = self._total_saved_grasps() if total < target: return self._grasp_target_reached_notified = True print( "[AllegroInhandRotationGrasp] Grasp collection target reached " f"({total}/{target}). Program stopped." ) if self.state is not None: log = self.state.info.get("log", {}) log["grasp/target_reached"] = 1.0 self.state.info["log"] = log exit(0) def _save_grasp_cache(self, force: bool = False) -> None: if self._grasp_cache_saved and not force: return total = self._total_saved_grasps() target = int(self._cfg.grasp_collection_target) if not force and total < target: return if total == 0: return all_states = np.concatenate(self._saved_grasping_states, axis=0).astype(np.float32) if target > 0: all_states = all_states[:target] output_file = Path(self._cfg.grasp_cache_path or "caches/allegro_grasp_50k.npy") if not output_file.is_absolute(): output_file = ASSETS_ROOT_PATH / output_file output_file.parent.mkdir(parents=True, exist_ok=True) np.save(output_file, all_states) self._grasp_cache_saved = True if self.state is not None: log = self.state.info.get("log", {}) log["grasp_cache/saved"] = 1.0 log["grasp_cache/num_states"] = float(all_states.shape[0]) self.state.info["log"] = log def _collect_successful_grasps(self, env_ids: np.ndarray) -> None: if self.state is None or env_ids.size == 0: return success_mask = self.state.truncated[env_ids] & ~self.state.terminated[env_ids] if not np.any(success_mask): return success_env_ids = env_ids[np.flatnonzero(success_mask)] if self._cfg.grasp_quality_check: quality_mask = self._check_grasp_quality(success_env_ids) success_env_ids = success_env_ids[np.flatnonzero(quality_mask)] if success_env_ids.size == 0: return curr_dof_pos = np.asarray(self.state.info.get("curr_dof_pos", self.get_hand_dof_pos())) curr_ball_pos = np.asarray(self.state.info.get("curr_ball_pos", self.get_ball_pos())) curr_ball_quat = np.asarray(self.state.info.get("curr_ball_quat", self.get_ball_quat())) hand_qpos = curr_dof_pos[success_env_ids, : self._NUM_HAND_DOF] ball_pos = curr_ball_pos[success_env_ids] ball_quat = curr_ball_quat[success_env_ids] states = np.concatenate([hand_qpos, ball_pos, ball_quat], axis=1).astype(np.float32) self._saved_grasping_states.append(states) self._save_grasp_cache() self._stop_collection() if self.state is not None: log = self.state.info.get("log", {}) log["grasp/cache_size"] = float(self._total_saved_grasps()) self.state.info["log"] = log
[docs] def update_state(self, state: NpEnvState) -> NpEnvState: next_state = super().update_state(state) reward = np.zeros((self._num_envs,), dtype=self._np_dtype) cond1, cond2, cond3 = self._compute_grasp_conditions() if self._cfg.grasp_quality_check: grasp_valid = cond1 & cond2 & cond3 terminated = np.asarray(next_state.terminated | (~grasp_valid), dtype=bool) else: grasp_valid = np.ones((self._num_envs,), dtype=bool) terminated = np.asarray(next_state.terminated, dtype=bool) step_count = next_state.info.get("steps", np.zeros((self._num_envs,), dtype=np.uint32)) should_log = self._enable_reward_log and (int(step_count[0]) % 4 == 0) if should_log: log = next_state.info.get("log", {}) log["grasp/cond1"] = float(np.mean(cond1.astype(np.float32))) log["grasp/cond2"] = float(np.mean(cond2.astype(np.float32))) log["grasp/cond3"] = float(np.mean(cond3.astype(np.float32))) log["grasp/valid"] = float(np.mean(grasp_valid.astype(np.float32))) log["grasp/cache_size"] = float(self._total_saved_grasps()) next_state.info["log"] = log return next_state.replace(reward=reward, terminated=terminated)
def _reset_done_envs(self) -> None: if self.state is not None: done = self.state.terminated | self.state.truncated if np.any(done): env_ids = np.flatnonzero(done).astype(np.int32) self._collect_successful_grasps(env_ids) super()._reset_done_envs()
[docs] def close(self) -> None: self._save_grasp_cache(force=bool(self._cfg.grasp_auto_save)) super().close()
AllegroRotationGraspEnv = AllegroRotationGrasp AllegroRotationGraspCfgAlias = AllegroRotationGraspCfg