# Source code for unilab.dr.types
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
import numpy as np
# Identifiers for the reset-time randomization terms. Each string is identical
# to the name of the ResetRandomizationPayload field that carries the term.
# These are the values compared against
# DomainRandomizationCapabilities.supported_reset_terms.

# Base-body terms.
RESET_TERM_BASE_MASS = "base_mass_delta"
RESET_TERM_BASE_COM = "base_com_offset"

# Gravity term.
RESET_TERM_GRAVITY = "gravity"

# Per-body terms (the body_* payload fields).
RESET_TERM_BODY_MASS = "body_mass"
RESET_TERM_BODY_IPOS = "body_ipos"
RESET_TERM_BODY_IQUAT = "body_iquat"
RESET_TERM_BODY_INERTIA = "body_inertia"

# DOF and geom terms.
RESET_TERM_DOF_ARMATURE = "dof_armature"
RESET_TERM_GEOM_FRICTION = "geom_friction"

# kp / kd terms.
RESET_TERM_KP = "kp"
RESET_TERM_KD = "kd"
@dataclass(frozen=True)
class GeomSizeOverride:
    """Immutable request to replace the size of one named geom.

    Being a frozen dataclass, instances compare by value and cannot be
    mutated after construction.
    """

    # Name of the geom whose size is overridden.
    geom_name: str
    # Replacement size values; how many entries are expected presumably
    # depends on the geom type (TODO confirm with the consumer).
    size: tuple[float, ...]
@dataclass(frozen=True)
class ModelVariantSpec:
    """Immutable description of one model variant.

    A variant is described solely by its geom size overrides; a spec without
    any overrides is reported as empty by ``is_empty``.
    """

    # Overrides that make up this variant; defaults to an empty tuple.
    geom_size_overrides: tuple[GeomSizeOverride, ...] = field(default_factory=tuple)

    def is_empty(self) -> bool:
        """Return True when the spec carries no geom size overrides."""
        has_overrides = bool(self.geom_size_overrides)
        return not has_overrides
@dataclass(frozen=True)
class DomainRandomizationCapabilities:
    """Immutable declaration of supported randomization features.

    ``supported_reset_terms`` holds the reset-term identifiers (the
    RESET_TERM_* strings) that may be applied; reset payloads are filtered
    against it by ``filter_reset_payload``. The ``supports_interval_*`` flags
    are not consulted in this class — presumably they are read by callers
    scheduling interval randomization (TODO confirm).
    """

    supported_reset_terms: frozenset[str] = field(default_factory=frozenset)
    supports_interval_push: bool = False
    supports_interval_body_velocity_delta: bool = False
    supports_interval_body_force: bool = False

    def supports_reset_term(self, term: str) -> bool:
        """Return True if ``term`` is a member of ``supported_reset_terms``."""
        return term in self.supported_reset_terms

    def get_unsupported_reset_terms(
        self, requested_terms: frozenset[str]
    ) -> frozenset[str]:
        """Return the members of ``requested_terms`` that are not supported.

        Support is decided per term via ``supports_reset_term``.
        """
        missing = {term for term in requested_terms if not self.supports_reset_term(term)}
        return frozenset(missing)

    def filter_reset_payload(
        self, payload: ResetRandomizationPayload
    ) -> tuple[ResetRandomizationPayload | None, frozenset[str]]:
        """Strip unsupported terms from a reset randomization payload.

        Args:
            payload: The requested reset randomization.

        Returns:
            A ``(filtered, unsupported)`` pair. When no requested term is
            unsupported (including when nothing is requested), ``filtered``
            is ``payload`` itself — the same object — and ``unsupported`` is
            an empty frozenset. Otherwise ``filtered`` is a new payload in
            which every unsupported field is ``None`` (or ``None`` outright
            if no requested term survives), and ``unsupported`` holds the
            dropped terms.
        """
        unsupported = self.get_unsupported_reset_terms(payload.requested_terms())
        if not unsupported:
            return payload, frozenset()

        # (payload field, reset term) pairs: a field is copied only when its
        # term is supported and is set to None otherwise.
        field_terms = (
            ("base_mass_delta", RESET_TERM_BASE_MASS),
            ("base_com_offset", RESET_TERM_BASE_COM),
            ("gravity", RESET_TERM_GRAVITY),
            ("body_iquat", RESET_TERM_BODY_IQUAT),
            ("body_inertia", RESET_TERM_BODY_INERTIA),
            ("body_ipos", RESET_TERM_BODY_IPOS),
            ("body_mass", RESET_TERM_BODY_MASS),
            ("dof_armature", RESET_TERM_DOF_ARMATURE),
            ("geom_friction", RESET_TERM_GEOM_FRICTION),
            ("kp", RESET_TERM_KP),
            ("kd", RESET_TERM_KD),
        )
        kept = {
            name: getattr(payload, name) if self.supports_reset_term(term) else None
            for name, term in field_terms
        }
        filtered = ResetRandomizationPayload(**kept)
        if filtered.is_empty():
            return None, unsupported
        return filtered, unsupported
@dataclass
class ResetRandomizationPayload:
    """Optional per-term values to apply when environments are reset.

    Each field corresponds to one reset term, and every field name equals
    that term's RESET_TERM_* string. A field left as ``None`` means the term
    is not requested.
    """

    base_mass_delta: np.ndarray | None = None
    base_com_offset: np.ndarray | None = None
    gravity: np.ndarray | None = None
    body_iquat: np.ndarray | None = None
    body_inertia: np.ndarray | None = None
    body_ipos: np.ndarray | None = None
    body_mass: np.ndarray | None = None
    dof_armature: np.ndarray | None = None
    geom_friction: np.ndarray | None = None
    kp: np.ndarray | None = None
    kd: np.ndarray | None = None

    def requested_terms(self) -> frozenset[str]:
        """Return the identifiers of the terms whose field is not ``None``."""
        term_values = (
            (RESET_TERM_BASE_MASS, self.base_mass_delta),
            (RESET_TERM_BASE_COM, self.base_com_offset),
            (RESET_TERM_GRAVITY, self.gravity),
            (RESET_TERM_BODY_IQUAT, self.body_iquat),
            (RESET_TERM_BODY_INERTIA, self.body_inertia),
            (RESET_TERM_BODY_IPOS, self.body_ipos),
            (RESET_TERM_BODY_MASS, self.body_mass),
            (RESET_TERM_DOF_ARMATURE, self.dof_armature),
            (RESET_TERM_GEOM_FRICTION, self.geom_friction),
            (RESET_TERM_KP, self.kp),
            (RESET_TERM_KD, self.kd),
        )
        # Compare with ``is not None`` so array fields are never truth-tested.
        return frozenset(term for term, value in term_values if value is not None)

    def is_empty(self) -> bool:
        """Return True when every term field is ``None``."""
        return len(self.requested_terms()) == 0
@dataclass
class IntervalRandomizationPlan:
    """Optional interval-randomization actions; a ``None`` field is inactive."""

    push_perturbation_limit: Sequence[float] | np.ndarray | None = None
    # Presumably selects the bodies the per-body arrays below apply to
    # (TODO confirm); it is not consulted by is_empty.
    body_ids: np.ndarray | None = None
    body_linear_velocity_delta: np.ndarray | None = None
    body_force: np.ndarray | None = None

    def is_empty(self) -> bool:
        """Return True when no push limit, velocity delta, or force is set.

        ``body_ids`` is not consulted, so a plan carrying only body ids is
        still reported as empty.
        """
        actions = (
            self.push_perturbation_limit,
            self.body_linear_velocity_delta,
            self.body_force,
        )
        return all(action is None for action in actions)
@dataclass
class InitRandomizationPlan:
    """Model variants plus an array assigning them.

    ``model_assignments`` presumably maps each environment to an index into
    ``model_variants`` — TODO confirm against the code that builds it.
    """

    model_assignments: np.ndarray
    model_variants: tuple[ModelVariantSpec, ...]

    def is_empty(self) -> bool:
        """Return True when there are no model variants.

        ``model_assignments`` is not consulted.
        """
        return not len(self.model_variants)
@dataclass
class ResetPlan:
    """Reset request for the environments identified by ``env_ids``.

    Bundles the state arrays to apply, extra info entries, and an optional
    reset randomization payload (``None`` when no payload is attached).
    """

    # Environment identifiers covered by this reset (assumed indices; TODO confirm).
    env_ids: np.ndarray
    # State arrays for the reset; layout is not established here (TODO confirm).
    qpos: np.ndarray
    qvel: np.ndarray
    # Extra key/value entries carried with the reset (how they are merged is
    # decided by the consumer; TODO confirm).
    info_updates: dict[str, Any]
    randomization: ResetRandomizationPayload | None = None