"""Sharpa in-hand grasp-cache generation (``unilab.envs.manipulation.sharpa_inhand.grasp_gen``)."""

from __future__ import annotations

import sys
from dataclasses import dataclass, field
from typing import Any, cast

import numpy as np

from unilab.assets import ASSETS_ROOT_PATH
from unilab.base import registry
from unilab.base.np_env import NpEnvState
from unilab.dr import ResetPlan
from unilab.dr.dr_utils import build_common_reset_randomization
from unilab.envs.common.rotation import np_quat_error_magnitude
from unilab.envs.manipulation.sharpa_inhand.base import (
    SOURCE_DEFAULT_HAND_JOINT_POS_DEG,
    SharpaDomainRandConfig,
    resolve_grasp_cache_file,
)
from unilab.envs.manipulation.sharpa_inhand.rotation import (
    RewardConfig,
    SharpaInhandRotationCfg,
    SharpaInhandRotationDRProvider,
    SharpaInhandRotationEnv,
)


def _default_sharpa_grasp_domain_rand() -> SharpaDomainRandConfig:
    """Return the domain-randomization defaults used while collecting Sharpa grasps.

    Only the mass toggle is enabled (bounds 0.05 to 0.051). PD-gain, friction and
    centre-of-mass randomization are off, and both force-related knobs are zero.

    Returns:
        A newly constructed ``SharpaDomainRandConfig`` on every call, so it is
        safe to use as a dataclass ``default_factory``.
    """
    grasp_overrides: dict[str, Any] = {
        # Dynamics that stay fixed during grasp collection.
        "randomize_pd_gains": False,
        "randomize_friction": False,
        "randomize_com": False,
        # Mass is the only randomized quantity, over a very narrow range.
        "randomize_mass": True,
        "randomize_mass_lower": 0.05,
        "randomize_mass_upper": 0.051,
        # Force perturbation settings are zeroed.
        "force_scale": 0.0,
        "random_force_prob_scalar": 0.0,
    }
    return SharpaDomainRandConfig(**grasp_overrides)


@dataclass
class SharpaInhandRotationGraspCfg(SharpaInhandRotationCfg):
    """Configuration for collecting stable Sharpa in-hand grasps into a cache.

    All reward scales are zero because grasp collection drives zero actions and
    produces zero rewards. Only mass is randomized (see
    ``_default_sharpa_grasp_domain_rand``).
    """

    # Episode length; an env that reaches the end presumably via truncation without
    # being terminated is stored as a successful grasp. (Original note: 12.0.)
    max_episode_seconds: float = 3.0
    # Forwarded per reset env into the reset info updates; meaning is defined by
    # the base rotation env (NOTE(review): confirm units/semantics there).
    reset_height_lower: float = 0.61406
    reset_height_upper: float = 0.62406
    # 30 degrees expressed in radians: an env whose object orientation error from
    # its default pose reaches this value is terminated as an invalid grasp.
    reset_angle_diff: float = 30.0 / 180.0 * np.pi
    # Cache location; empty falls back to "caches/sharpa_grasp_linspace".
    # Relative results are placed under ASSETS_ROOT_PATH.
    grasp_cache_path: str = ""
    domain_rand: SharpaDomainRandConfig = field(default_factory=_default_sharpa_grasp_domain_rand)
    reward_config: RewardConfig = field(
        default_factory=lambda: RewardConfig(
            scales=dict.fromkeys(
                ("rotate", "obj_linvel", "pose_diff", "torque", "work", "object_pos"),
                0.0,
            )
        )
    )
    # Number of grasps to collect per object scale (clamped to at least 1 by the env).
    grasp_collection_target: int = 50_000
    # When False, collection stops at the target without writing the cache file.
    grasp_auto_save: bool = True
@registry.envcfg("SharpaInhandRotationGrasp")
@dataclass
class SharpaInhandGraspEnvCfg(SharpaInhandRotationGraspCfg):
    """Config registered via ``registry.envcfg`` for ``SharpaInhandRotationGrasp``.

    Adds no fields of its own; every setting comes from ``SharpaInhandRotationGraspCfg``.
    """
class SharpaInhandGraspDRProvider(SharpaInhandRotationDRProvider):
    """Reset provider for grasp collection.

    On every reset it first lets the env harvest successful grasps from the
    outgoing envs. It then re-initializes them from a jittered default hand pose,
    with the object at its initial-qpos pose and all velocities zero.
    """

    def build_reset_plan(self, env: Any, env_ids: np.ndarray) -> ResetPlan:
        """Collect finished grasps for ``env_ids`` and build their reset plan.

        Args:
            env: The grasp environment being reset; its private attributes
                (``_grasp_default_angles``, ``_ctrl_lower``, ...) are read directly.
            env_ids: Indices of the envs to reset.

        Returns:
            A ``ResetPlan`` with per-env ``qpos``/``qvel`` and info updates. When
            ``env_ids`` is empty, the plan has zero rows, no info updates and no
            randomization.
        """
        # Keep original grasp task behavior: harvest successful pre-reset states
        # before they are overwritten by the reset.
        env._collect_successful_grasps(env_ids)

        count = len(env_ids)
        if not count:
            return ResetPlan(
                env_ids=env_ids,
                qpos=np.zeros((0, env.nq), dtype=np.float64),
                qvel=np.zeros((0, env.nv), dtype=np.float64),
                info_updates={},
                randomization=None,
            )

        n_act = env._num_action

        # Hand: default angles plus uniform noise in [-0.15, 0.15), added in place so
        # hand_qpos keeps the default-angle dtype, then clipped to the control range.
        hand_qpos = np.broadcast_to(env._grasp_default_angles, (count, n_act)).copy()
        hand_qpos += 0.15 * (2.0 * np.random.rand(count, n_act) - 1.0)
        hand_qpos = np.clip(hand_qpos, env._ctrl_lower, env._ctrl_upper)

        # Object: every reset env starts from the pose stored in the initial qpos.
        start_qpos = env._init_qpos
        object_pos = np.broadcast_to(start_qpos[env._obj_pos_slice], (count, 3)).copy()
        object_quat = np.broadcast_to(start_qpos[env._obj_quat_slice], (count, 4)).copy()

        qpos = np.zeros((count, env.nq), dtype=np.float64)
        qpos[:, :n_act] = hand_qpos
        qpos[:, env._obj_pos_slice] = object_pos
        qpos[:, env._obj_quat_slice] = object_quat
        qvel = np.zeros((count, env.nv), dtype=np.float64)

        p_gain, d_gain = self._sample_reset_pd_gains(env, count, dtype=env._np_dtype)

        cfg = env.cfg
        info_updates = self._build_info_updates(
            env,
            env_ids=env_ids,
            hand_qpos=hand_qpos,
            object_pos=object_pos,
            object_quat=object_quat,
            reset_height_lower=np.full((count,), cfg.reset_height_lower, dtype=np.float64),
            reset_height_upper=np.full((count,), cfg.reset_height_upper, dtype=np.float64),
            rot_axis=np.broadcast_to(env._rot_axis, (count, 3)).astype(np.float64),
            p_gain=p_gain,
            d_gain=d_gain,
            # Grasp collection leaves these randomizations out of the info updates.
            friction_scale=None,
            randomized_mass=None,
            randomized_com_offset=None,
            gravity=None,
        )
        env._clear_tactile_history(env_ids)

        randomization = build_common_reset_randomization(env, count)
        return ResetPlan(
            env_ids=env_ids,
            qpos=qpos,
            qvel=qvel,
            info_updates=info_updates,
            randomization=randomization,
        )
@registry.env("SharpaInhandRotationGrasp", sim_backend="mujoco")
@registry.env("SharpaInhandRotationGrasp", sim_backend="motrix")
class SharpaInhandRotationGraspEnv(SharpaInhandRotationEnv):
    """Sharpa in-hand environment that harvests stable grasps into a cache file.

    Resets go through ``SharpaInhandGraspDRProvider``, which jitters the default
    hand pose and places the object at its initial pose. Actions are forced to
    zero and rewards are zero. ``update_state`` terminates envs whose grasp
    becomes invalid.

    Envs that end truncated but not terminated are stored as successful grasps.
    Each stored row is ``[hand joint pos | object pos | object quat]`` as
    float32. Once every scale bucket reaches its target, the cache is written
    (if ``grasp_auto_save``) and the process exits with status 0.
    """

    _cfg: SharpaInhandRotationGraspCfg

    def __init__(
        self,
        cfg: SharpaInhandRotationGraspCfg,
        num_envs: int = 1,
        backend_type: str = "motrix",
    ) -> None:
        """Create the grasp-collection environment.

        Args:
            cfg: Grasp-collection configuration.
            num_envs: Number of environments, forwarded to the base env.
            backend_type: Simulation backend name, forwarded to the base env.

        Raises:
            ValueError: The error is raised in any of these cases:
                - gravity randomization is enabled;
                - more than one object scale is configured;
                - the default hand pose length differs from the action dimension;
                - no environment is assigned to the configured scale.
        """
        if cfg.domain_rand.randomize_gravity or cfg.domain_rand.randomize_gravity_direction:
            raise ValueError(
                "SharpaInhandRotationGrasp does not support gravity randomization; "
                "disable env.domain_rand.randomize_gravity and "
                "env.domain_rand.randomize_gravity_direction."
            )
        super().__init__(
            cfg,
            num_envs=num_envs,
            backend_type=backend_type,
            dr_provider=SharpaInhandGraspDRProvider(),
        )
        # One bucket per object scale; each entry is a (1, state_dim) float32 row.
        self._saved_grasping_states: list[list[np.ndarray]] = [
            [] for _ in range(self._num_scales)
        ]
        if self._num_scales != 1:
            raise ValueError(
                "Sharpa grasp generation now collects exactly one object scale per run; "
                f"got scale_list={list(self.scale_values)}"
            )
        self._grasp_target_per_scale = max(1, int(cfg.grasp_collection_target))
        self._grasp_cache_saved = False
        self._grasp_target_reached_notified = False
        self._last_grasp_progress_step = -1
        self._last_grasp_progress_counts: tuple[int, ...] | None = None
        # Reset pose of the hand, converted from the source's degrees to radians.
        self._grasp_default_angles = np.asarray(
            np.deg2rad(np.asarray(SOURCE_DEFAULT_HAND_JOINT_POS_DEG, dtype=np.float64)),
            dtype=self._np_dtype,
        )
        if self._grasp_default_angles.shape[0] != self._num_action:
            raise ValueError(
                "Source grasp default angle count mismatch: "
                f"{self._grasp_default_angles.shape[0]} vs expected {self._num_action}"
            )
        if np.any(np.bincount(self.scale_ids, minlength=self._num_scales) == 0):
            raise ValueError(
                "Sharpa grasp generation requires at least one environment for the configured scale; "
                f"got num_envs={num_envs}, num_scales={self._num_scales}"
            )

    def apply_action(self, actions: np.ndarray, state: NpEnvState) -> np.ndarray:
        """Ignore ``actions`` and apply an all-zero action of the same shape.

        Grasp-cache collection should not use policy or random actions; zero
        action input keeps the controls at their reset targets.
        """
        zero_actions = np.zeros_like(actions, dtype=self._np_dtype)
        return super().apply_action(zero_actions, state)

    def _total_saved_grasps(self) -> int:
        """Return the number of grasps collected across all scale buckets."""
        return int(sum(len(bucket) for bucket in self._saved_grasping_states))

    def _collection_target_reached(self) -> bool:
        """Return True once every scale bucket holds the per-scale target."""
        return all(
            len(bucket) >= self._grasp_target_per_scale for bucket in self._saved_grasping_states
        )

    def _get_per_scale_grasp_counts(self) -> tuple[int, ...]:
        """Return collected grasp counts for each scale bucket.

        Returns:
            Tuple where index is scale id and value is collected grasp count.
        """
        return tuple(len(bucket) for bucket in self._saved_grasping_states)

    def _maybe_print_grasp_progress(self, force: bool = False) -> None:
        """Print runtime grasp-collection progress grouped by scale.

        Unless forced, a line is printed only when the counts changed since the
        last print. It is also skipped when fewer than 32 steps have passed, per
        env 0's ``steps`` counter.

        Args:
            force: When True, print even if throttling would normally skip.
        """
        if self.state is None:
            return
        counts = self._get_per_scale_grasp_counts()
        step_info = self.state.info.get("steps")
        step = int(step_info[0]) if isinstance(step_info, np.ndarray) and step_info.size > 0 else 0
        if not force:
            if counts == self._last_grasp_progress_counts:
                return
            # A step smaller than the last printed one means the counter restarted
            # (presumably an episode reset), so the 32-step window no longer applies.
            # Without the lower bound, printing could stall until the counter climbed
            # back past the old value.
            elapsed = step - self._last_grasp_progress_step
            if self._last_grasp_progress_step >= 0 and 0 <= elapsed < 32:
                return
        total = int(sum(counts))
        per_scale = ", ".join(
            f"scale={float(self.scale_values[i]):g}:{count}" for i, count in enumerate(counts)
        )
        print(
            "[SharpaInhandRotationGrasp] "
            f"grasp progress total={total}/{int(self._cfg.grasp_collection_target)}, "
            f"per_scale=[{per_scale}]"
        )
        self._last_grasp_progress_step = step
        self._last_grasp_progress_counts = counts

    def _stop_collection(self) -> None:
        """Announce completion and exit the process once the target is reached.

        This is a no-op if the target is not reached yet, or if completion was
        already announced. Otherwise it does four things:
        - prints the final progress and a summary line;
        - sets ``grasp/target_reached`` in the state log;
        - raises ``SystemExit(0)``.
        """
        if self._grasp_target_reached_notified or not self._collection_target_reached():
            return
        self._maybe_print_grasp_progress(force=True)
        self._grasp_target_reached_notified = True
        collected = self._total_saved_grasps()
        target = int(self._cfg.grasp_collection_target)
        print(
            "[SharpaInhandRotationGrasp] Grasp collection target reached "
            f"(saved={collected}, configured_target={target}). Program stopped."
        )
        if self.state is not None:
            log = self.state.info.get("log", {})
            log["grasp/target_reached"] = 1.0
            self.state.info["log"] = log
        # sys.exit instead of the builtin ``exit``: the latter is injected by the
        # ``site`` module for interactive use and is absent under ``python -S``.
        sys.exit(0)

    def _save_grasp_cache(self) -> None:
        """Write the collected grasps of the single configured scale to disk.

        An empty ``grasp_cache_path`` falls back to
        ``"caches/sharpa_grasp_linspace"``. If the path returned by
        ``resolve_grasp_cache_file`` is relative, it is placed under
        ``ASSETS_ROOT_PATH``. At most the per-scale target number of rows is
        saved.
        """
        output_file = resolve_grasp_cache_file(
            self._cfg.grasp_cache_path or "caches/sharpa_grasp_linspace",
            float(self.scale_values[0]),
        )
        if not output_file.is_absolute():
            output_file = ASSETS_ROOT_PATH / output_file
        output_file.parent.mkdir(parents=True, exist_ok=True)
        save_data = np.concatenate(self._saved_grasping_states[0], axis=0)[
            : self._grasp_target_per_scale
        ]
        np.save(output_file, save_data)
        if self.state is not None:
            log = self.state.info.get("log", {})
            log["grasp_cache/saved"] = 1.0
            log["grasp_cache/num_states"] = float(save_data.shape[0])
            self.state.info["log"] = log

    def _collect_successful_grasps(self, env_ids: np.ndarray) -> None:
        """Store the pre-reset states of envs in ``env_ids`` that ended successfully.

        An env is successful when it is truncated and not terminated. Since
        ``update_state`` terminates invalid grasps, this presumably means the
        grasp held until the time limit. Buckets stop growing at the per-scale
        target. Once all buckets are full, the cache is written (when
        ``grasp_auto_save``) and collection stops.

        Args:
            env_ids: Indices of the envs about to be reset.
        """
        if self.state is None or len(env_ids) == 0:
            return
        # Cast defensively so ``~`` is a logical NOT even for non-bool flag arrays.
        truncated = np.asarray(self.state.truncated[env_ids], dtype=bool)
        terminated = np.asarray(self.state.terminated[env_ids], dtype=bool)
        success_mask = truncated & ~terminated
        if not np.any(success_mask):
            return

        success_env_ids = env_ids[np.flatnonzero(success_mask)]
        # Row layout: [hand joint positions | object position | object quaternion].
        all_states = np.concatenate(
            [
                self.get_hand_dof_pos()[success_env_ids],
                self.get_object_pos()[success_env_ids],
                self.get_object_quat()[success_env_ids],
            ],
            axis=1,
        ).astype(np.float32)
        for i, scale_id in enumerate(self.scale_ids[success_env_ids]):
            bucket = self._saved_grasping_states[int(scale_id)]
            if len(bucket) < self._grasp_target_per_scale:
                bucket.append(all_states[i : i + 1])
        self._maybe_print_grasp_progress()

        if self._grasp_cache_saved or not self._collection_target_reached():
            return
        if self._cfg.grasp_auto_save:
            self._save_grasp_cache()
        self._grasp_cache_saved = True
        self._stop_collection()

    def _compute_reward(
        self,
        info: dict[str, Any],
        dof_pos: np.ndarray,
        dof_vel: np.ndarray,
        object_pos: np.ndarray,
        object_linvel: np.ndarray,
        object_angvel: np.ndarray,
        torques: np.ndarray,
    ) -> np.ndarray:
        """Return an all-zero reward; grasp collection does not score behavior."""
        del info, dof_pos, dof_vel, object_pos, object_linvel, object_angvel, torques
        return np.zeros((self._num_envs,), dtype=self._np_dtype)

    def update_state(self, state: NpEnvState) -> NpEnvState:
        """Advance the base state and terminate envs whose grasp is no longer valid.

        A grasp is valid when all of the following hold:
          1. every fingertip is within 0.1 of the object position;
          2. at least 3 entries of ``last_contacts`` exceed 0.5;
          3. the orientation error from ``object_default_pose`` is below
             ``reset_angle_diff``.

        Rewards are set to zero. With reward logging enabled, the condition pass
        rates and cache sizes are logged whenever env 0's step count is a
        multiple of 4.
        """
        next_state = super().update_state(state)
        fingertip_pos = self.get_fingertip_pos()
        object_pos = self.get_object_pos()
        object_quat = self.get_object_quat()
        # Falls back to all-zero poses when the info entry is missing.
        object_default_pose = np.asarray(
            next_state.info.get(
                "object_default_pose", np.zeros((self._num_envs, 7), dtype=self._np_dtype)
            ),
            dtype=self._np_dtype,
        )
        # fingertip_pos is broadcast against one object position per env.
        cond1 = np.all(
            np.linalg.norm(fingertip_pos - object_pos[:, None, :], axis=-1) < 0.1, axis=1
        )
        tactile = np.asarray(self.last_contacts, dtype=self._np_dtype)
        cond2 = np.sum(tactile > 0.5, axis=1) >= 3
        quat_error = np_quat_error_magnitude(object_default_pose[:, 3:7], object_quat)
        cond3 = quat_error < self._cfg.reset_angle_diff
        grasp_valid = cond1 & cond2 & cond3
        terminated = np.asarray(next_state.terminated | (~grasp_valid), dtype=bool)
        reward = np.zeros((self._num_envs,), dtype=self._np_dtype)
        step_count = next_state.info.get("steps", np.zeros((self._num_envs,), dtype=np.uint32))
        should_log = self._enable_reward_log and (int(step_count[0]) % 4 == 0)
        if should_log:
            log = next_state.info.get("log", {})
            log["grasp/cond1"] = float(np.mean(cond1.astype(np.float32)))
            log["grasp/cond2"] = float(np.mean(cond2.astype(np.float32)))
            log["grasp/cond3"] = float(np.mean(cond3.astype(np.float32)))
            log["grasp/valid"] = float(np.mean(grasp_valid.astype(np.float32)))
            per_scale_counts = self._get_per_scale_grasp_counts()
            log["grasp/target_cache_size"] = float(self._cfg.grasp_collection_target)
            for scale_idx, count in enumerate(per_scale_counts):
                scale_value = float(self.scale_values[scale_idx])
                log[f"grasp/cache_size_scale_{scale_value:g}"] = float(count)
            next_state.info["log"] = log
        return next_state.replace(reward=reward, terminated=terminated)
# Alternate name bound to the same class object as ``SharpaInhandGraspEnvCfg`` (not a subclass).
SharpaWaveGraspCfg = SharpaInhandGraspEnvCfg