"""Shared DomainRandomizationProvider for locomotion environments.
Implements the common reset/interval randomization logic shared by
G1, Go1, and Go2 joystick environments. Subclasses override hooks
to provide robot-specific behaviour.
"""
from __future__ import annotations
from typing import Any, cast
import numpy as np
from unilab.dr import (
DomainRandomizationCapabilities,
DomainRandomizationProvider,
IntervalRandomizationPlan,
ResetPlan,
)
from unilab.dr.dr_utils import (
build_common_reset_randomization,
build_interval_push_plan,
validate_common_reset_randomization,
validate_interval_push_support,
zero_actions,
)
from unilab.dtype_config import get_global_dtype
from unilab.envs.common.rotation import np_quat_mul, np_yaw_to_quat
[docs]
class LocomotionDRProvider(DomainRandomizationProvider):
"""Base DR provider for locomotion joystick environments.
Shared logic:
- ``validate``, ``build_interval_randomization_plan``, ``_sample_commands``
- ``build_reset_plan`` (template with hooks)
- ``build_reset_observation`` (template with hook)
Override these hooks in subclasses:
- ``_get_qvel_limit`` — default ``0.5``
- ``_build_extra_info_updates`` — default empty dict
- ``_compute_reset_obs`` — must be implemented per robot
"""
# ── shared methods ───────────────────────────────────────────────
def _get_base_actuator_gains(self, env: Any) -> tuple[np.ndarray | None, np.ndarray | None]:
"""Return (base_kp, base_kd) for per-joint kp/kd domain randomization.
Override to provide per-joint gains cached at init time.
Returns ``(None, None)`` by default, which falls back to scalar
``ControlConfig.Kp`` / ``ControlConfig.Kd``.
"""
return None, None
def _get_reset_randomization_baselines(
self, env: Any
) -> tuple[np.ndarray | None, np.ndarray | None, int | None, np.ndarray | None]:
"""Return cached model tables used for reset-time randomization."""
return None, None, None, None
[docs]
def validate(self, env: Any, capabilities: DomainRandomizationCapabilities) -> None:
base_kp, base_kd = self._get_base_actuator_gains(env)
base_body_mass, base_geom_friction, ground_geom_id, base_dof_armature = (
self._get_reset_randomization_baselines(env)
)
validate_common_reset_randomization(
env,
capabilities,
base_kp=base_kp,
base_kd=base_kd,
base_body_mass=base_body_mass,
base_geom_friction=base_geom_friction,
ground_geom_id=ground_geom_id,
base_dof_armature=base_dof_armature,
)
validate_interval_push_support(env, capabilities)
[docs]
def build_interval_randomization_plan(
self, env: Any, step_counter: int
) -> IntervalRandomizationPlan | None:
return build_interval_push_plan(env, step_counter)
def _sample_commands(self, env: Any, num_reset: int) -> np.ndarray:
low = np.asarray(env.cfg.commands.vel_limit[0], dtype=get_global_dtype())
high = np.asarray(env.cfg.commands.vel_limit[1], dtype=get_global_dtype())
return np.asarray(
np.random.uniform(low=low, high=high, size=(num_reset, 3)), dtype=get_global_dtype()
)
# ── template: build_reset_plan ───────────────────────────────────
def _get_qvel_limit(self, env: Any) -> float:
"""Return the base qvel reset limit. Override for configurable limits."""
return 0.5
def _build_extra_info_updates(self, env: Any, num_reset: int) -> dict[str, np.ndarray]:
"""Return additional info_updates entries (e.g. gait_phase for G1)."""
return {}
[docs]
def build_reset_plan(self, env: Any, env_ids: np.ndarray) -> ResetPlan:
num_reset = len(env_ids)
qpos = np.tile(env._init_qpos, (num_reset, 1))
qvel = np.tile(env._init_qvel, (num_reset, 1))
qpos[:, 0:2] += np.random.uniform(-0.5, 0.5, (num_reset, 2))
yaw = np.random.uniform(-np.pi, np.pi, (num_reset,))
qpos[:, 3:7] = np_quat_mul(qpos[:, 3:7], np_yaw_to_quat(yaw))
qpos[:, 0:3] = env._spawn.apply_spawn(env_ids, qpos[:, 0:3], yaw=yaw)
limit = self._get_qvel_limit(env)
qvel[:, 0:6] = np.asarray(
np.random.uniform(-limit, limit, size=(num_reset, 6)), dtype=get_global_dtype()
)
info_updates: dict[str, Any] = {
"commands": self._sample_commands(env, num_reset),
"current_actions": zero_actions(num_reset, env._num_action),
"last_actions": zero_actions(num_reset, env._num_action),
}
info_updates.update(self._build_extra_info_updates(env, num_reset))
base_kp, base_kd = self._get_base_actuator_gains(env)
base_body_mass, base_geom_friction, ground_geom_id, base_dof_armature = (
self._get_reset_randomization_baselines(env)
)
env._spawn.record_episode_start(env_ids, qpos[:, 0:3])
return ResetPlan(
env_ids=env_ids,
qpos=qpos,
qvel=qvel,
info_updates=info_updates,
randomization=build_common_reset_randomization(
env,
num_reset,
base_kp=base_kp,
base_kd=base_kd,
base_body_mass=base_body_mass,
base_geom_friction=base_geom_friction,
ground_geom_id=ground_geom_id,
base_dof_armature=base_dof_armature,
),
)
# ── template: build_reset_observation ────────────────────────────
def _compute_reset_obs(
self,
env: Any,
env_ids: np.ndarray,
info_updates: dict[str, Any],
linvel: np.ndarray,
gyro: np.ndarray,
gravity: np.ndarray,
dof_pos: np.ndarray,
dof_vel: np.ndarray,
) -> dict[str, np.ndarray]:
"""Compute reset observations. Override per robot to pass correct args to _compute_obs."""
raise NotImplementedError
[docs]
def build_reset_observation(
self, env: Any, env_ids: np.ndarray, info_updates: dict[str, Any]
) -> dict[str, np.ndarray]:
linvel = env.get_local_linvel()[env_ids]
gyro = env.get_gyro()[env_ids]
gravity_sensor = getattr(getattr(env.cfg, "sensor", None), "upvector", "upvector")
gravity = env._backend.get_sensor_data(gravity_sensor)[env_ids]
dof_pos = env.get_dof_pos()[env_ids]
dof_vel = env.get_dof_vel()[env_ids]
return cast(
dict[str, np.ndarray],
self._compute_reset_obs(
env, env_ids, info_updates, linvel, gyro, gravity, dof_pos, dof_vel
),
)