Source code for unilab.envs.locomotion.common.terrain_spawn

"""Spawn-origin managers for locomotion envs.

``BaseSpawnManager`` is a no-op default: every env spawns at the world origin
(plus the existing per-env xy jitter from the dr_provider). Used whenever the
env has no procedural terrain — flat scenes don't need spatial separation.

``TerrainSpawnManager`` overrides this for terrain scenes: it indexes
``terrain_origins[level, type_col]`` so each env spawns on a specific cell, and
optionally promotes/demotes ``level`` per-env on episode end. With
``enabled=True`` levels start at 0; with ``enabled=False`` levels are uniformly
distributed and never change — but spawn still uses cell-aware xyz so robots
land on the correct surface height.
"""

from __future__ import annotations

from dataclasses import dataclass

import numpy as np


[docs] class BaseSpawnManager: """Default no-op spawn manager: returns zeros, records nothing."""
[docs] def origins_for(self, env_ids: np.ndarray) -> np.ndarray: return np.zeros((env_ids.shape[0], 3), dtype=np.float64)
[docs] def apply_spawn( self, env_ids: np.ndarray, qpos_xyz: np.ndarray, *, yaw: np.ndarray | None = None, ) -> np.ndarray: del yaw return np.asarray(qpos_xyz, dtype=np.float64) + self.origins_for(env_ids)
[docs] def record_episode_start(self, env_ids: np.ndarray, qpos_xyz: np.ndarray) -> None: del env_ids, qpos_xyz
[docs] def update_on_done(self, done_indices: np.ndarray, current_xyz: np.ndarray) -> dict[str, float]: del done_indices, current_xyz return {}
[docs] @dataclass class TerrainCurriculumCfg: enabled: bool = False """If True, levels start at 0 and evolve via promote/demote.""" promote_frac: float = 0.5 """Walked distance > promote_frac * cell_size promotes one level.""" demote_frac: float = 0.25 """Walked distance < demote_frac * cell_size demotes one level.""" cycle_top_frac: float = 0.5 """When level overflows the top row, resample uniformly in ``[num_rows * cycle_top_frac, num_rows - 1]``.""" spawn_height_margin: float = 0.05 """Extra z added on top of the sampled terrain surface height.""" seed: int | None = None
class TerrainSpawnManager(BaseSpawnManager):
    """Spawn manager that places each env on a cell of a terrain grid.

    Each env gets a terrain-type column (``type_cols``, drawn once at
    construction) and a row (``levels``). Spawns are offset by
    ``terrain_origins[level, type_col]``. When ``cfg.enabled`` is True,
    ``update_on_done`` promotes/demotes ``levels`` from the distance each env
    covered during its episode; otherwise levels never change.
    """

    def __init__(
        self,
        num_envs: int,
        terrain_origins: np.ndarray,
        cell_size: float,
        cfg: TerrainCurriculumCfg,
        terrain_surface_sampler: object | None = None,
        spawn_height_points: np.ndarray | None = None,
    ) -> None:
        """Validate inputs and draw the initial per-env type columns and levels.

        Args:
            num_envs: Number of envs tracked.
            terrain_origins: Array-like of shape ``(num_rows, num_cols, 3)``
                with the xyz origin of every terrain cell.
            cell_size: Cell size used to scale the promote/demote thresholds.
            cfg: Curriculum configuration.
            terrain_surface_sampler: Optional object exposing a callable
                ``sample_height``; when given, spawn z is derived from the
                sampled surface instead of the cell origin z.
            spawn_height_points: Optional ``(num_points, 3)`` points, relative
                to the base, that must not end up below the sampled surface at
                spawn. Defaults to a single point at the base origin.

        Raises:
            ValueError: If ``terrain_origins`` or ``spawn_height_points`` has
                the wrong shape, or the curriculum is enabled with fewer than
                two rows.
        """
        # asarray accepts nested lists; for an ndarray that is already float64
        # it returns the input unchanged (same as astype(..., copy=False)).
        origins = np.asarray(terrain_origins, dtype=np.float64)
        if origins.ndim != 3 or origins.shape[2] != 3:
            raise ValueError(
                f"terrain_origins must have shape (num_rows, num_cols, 3); "
                f"got {origins.shape}"
            )
        num_rows, num_cols, _ = origins.shape
        if cfg.enabled and num_rows < 2:
            raise ValueError(
                f"Curriculum requires terrain_generator.num_rows >= 2; got {num_rows}."
            )
        self._terrain_origins = origins
        self._num_rows = num_rows
        self._num_cols = num_cols
        self._cell_size = float(cell_size)
        self._cfg = cfg
        self._terrain_surface_sampler = terrain_surface_sampler
        if spawn_height_points is None:
            # A single point at the base origin: only the base itself is checked.
            self._spawn_height_points = np.zeros((1, 3), dtype=np.float64)
        else:
            points = np.asarray(spawn_height_points, dtype=np.float64)
            if points.ndim != 2 or points.shape[1] != 3:
                raise ValueError(
                    f"spawn_height_points must have shape (num_points, 3), got {points.shape}"
                )
            self._spawn_height_points = points

        # Draw order matters for reproducibility under a fixed seed: type
        # columns first, then (only when the curriculum is off) levels.
        self._rng = np.random.default_rng(cfg.seed)
        self.type_cols = self._rng.integers(0, num_cols, size=num_envs).astype(np.int32)
        if cfg.enabled:
            # Curriculum: every env starts at level 0.
            self.levels = np.zeros(num_envs, dtype=np.int32)
        else:
            # No curriculum: spread envs uniformly over all rows.
            self.levels = self._rng.integers(0, num_rows, size=num_envs).astype(np.int32)
        # Position recorded by record_episode_start, and whether one was ever
        # recorded (envs without one are skipped by update_on_done).
        self._episode_start_xyz = np.zeros((num_envs, 3), dtype=np.float64)
        self._has_started = np.zeros(num_envs, dtype=bool)

    @property
    def enabled(self) -> bool:
        """Whether the promote/demote curriculum is active."""
        return self._cfg.enabled

    def origins_for(self, env_ids: np.ndarray) -> np.ndarray:
        """Return each env's cell origin, with z raised by ``spawn_height_margin``.

        Shape ``(len(env_ids), 3)``. The surface sampler is not consulted.
        """
        rows = self.levels[env_ids]
        cols = self.type_cols[env_ids]
        out = self._terrain_origins[rows, cols].copy()
        out[:, 2] += self._cfg.spawn_height_margin
        return out

    def apply_spawn(
        self,
        env_ids: np.ndarray,
        qpos_xyz: np.ndarray,
        *,
        yaw: np.ndarray | None = None,
    ) -> np.ndarray:
        """Move local spawn positions onto each env's terrain cell.

        xy is offset by the cell origin. Without a surface sampler, z is offset
        by the origin z plus the margin. With a sampler, the incoming z is used
        as the base height above the surface, and z becomes the lowest base
        height that keeps the base and all spawn-height points (rotated by
        ``yaw`` if given) above the sampled terrain, plus the margin.

        Returns:
            A new float64 ``(len(env_ids), 3)`` array; ``qpos_xyz`` is not modified.
        """
        rows = self.levels[env_ids]
        cols = self.type_cols[env_ids]
        origins = self._terrain_origins[rows, cols]
        out = np.asarray(qpos_xyz, dtype=np.float64).copy()
        base_height = out[:, 2].copy()  # nominal base height, before any offset
        out[:, 0:2] += origins[:, 0:2]
        if self._terrain_surface_sampler is None:
            out[:, 2] += origins[:, 2] + self._cfg.spawn_height_margin
            return out
        required_base_z = self._sample_spawn_required_base_height(
            out[:, 0:2],
            base_height=base_height,
            yaw=yaw,
        )
        out[:, 2] = required_base_z + self._cfg.spawn_height_margin
        return out

    def _sample_spawn_required_base_height(
        self,
        base_xy: np.ndarray,
        *,
        base_height: np.ndarray,
        yaw: np.ndarray | None,
    ) -> np.ndarray:
        """Return, per base, the lowest base z that satisfies every clearance constraint.

        The constraints are: the base sits ``base_height`` above the surface
        under ``base_xy``, and each spawn-height point (``base + rotated local
        offset``) is at or above the surface at its own xy.

        Raises:
            TypeError: If the sampler has no callable ``sample_height``.
            ValueError: If ``yaw`` does not hold one value per base.
        """
        sampler = self._terrain_surface_sampler
        sample_height = getattr(sampler, "sample_height", None)
        if not callable(sample_height):
            raise TypeError("terrain_surface_sampler must expose sample_height(xy)")
        points = self._spawn_height_points
        local_xy = points[:, :2]
        local_z = points[:, 2]
        # NOTE(review): assumes sample_height maps an (M, 2) xy array to M
        # heights — confirm against the sampler implementation.
        base_surface = np.asarray(sample_height(base_xy), dtype=np.float64)
        required_base_z = base_surface + np.asarray(base_height, dtype=np.float64)
        num_bases = base_xy.shape[0]
        if yaw is None:
            rotated_xy = np.broadcast_to(local_xy, (num_bases,) + local_xy.shape)
        else:
            yaw_arr = np.asarray(yaw, dtype=np.float64).reshape(-1)
            if yaw_arr.shape != (num_bases,):
                raise ValueError(f"yaw must have shape ({base_xy.shape[0]},), got {yaw_arr.shape}")
            cos_yaw = np.cos(yaw_arr)
            sin_yaw = np.sin(yaw_arr)
            rotated_xy = np.empty((num_bases, local_xy.shape[0], 2), dtype=np.float64)
            rotated_xy[:, :, 0] = (
                cos_yaw[:, None] * local_xy[None, :, 0] - sin_yaw[:, None] * local_xy[None, :, 1]
            )
            rotated_xy[:, :, 1] = (
                sin_yaw[:, None] * local_xy[None, :, 0] + cos_yaw[:, None] * local_xy[None, :, 1]
            )
        if points.shape[0] == 0:
            # No support points configured: only the base clearance applies.
            # (Without this guard np.max over an empty axis would raise.)
            return required_base_z
        sample_xy = base_xy[:, None, :] + rotated_xy
        sampled = np.asarray(sample_height(sample_xy.reshape(-1, 2)), dtype=np.float64).reshape(
            num_bases, points.shape[0]
        )
        # base_z + local_z >= surface  <=>  base_z >= surface - local_z
        required_support_z = np.max(sampled - local_z[None, :], axis=1)
        return np.maximum(required_base_z, required_support_z)

    def record_episode_start(self, env_ids: np.ndarray, qpos_xyz: np.ndarray) -> None:
        """Store the start position of each env in ``env_ids`` and mark it as started."""
        self._episode_start_xyz[env_ids] = qpos_xyz
        self._has_started[env_ids] = True

    def update_on_done(self, done_indices: np.ndarray, current_xyz: np.ndarray) -> dict[str, float]:
        """Update levels for finished episodes and return logging stats.

        Args:
            done_indices: Integer env ids whose episodes ended.
            current_xyz: Final positions, one row per entry of ``done_indices``.

        Envs with no recorded episode start are skipped. With the curriculum
        enabled, an env whose straight-line xy displacement exceeds
        ``promote_frac * cell_size`` moves up one level. An env whose
        displacement is below ``demote_frac * cell_size`` moves down one level.
        Promotions past the top row are resampled in
        ``[int(num_rows * cycle_top_frac), num_rows - 1]``, and all levels are
        clipped to valid rows.

        Returns:
            ``mean_level`` / ``max_level`` over all envs, ``mean_walked`` over
            the non-skipped done envs (0.0 if none), and the ``num_promoted``,
            ``num_demoted`` and ``num_skipped`` counts.
        """
        active_mask = self._has_started[done_indices]
        active = done_indices[active_mask]
        num_skipped = int((~active_mask).sum())
        if active.size == 0:
            return self._curriculum_stats(0.0, 0, 0, num_skipped)
        # NOTE(review): _has_started is never cleared here, so an env that
        # finishes again without a new record_episode_start is measured from
        # its previous start.
        starts = self._episode_start_xyz[active, :2]
        # current_xyz rows are aligned with done_indices, so use the same mask.
        ends = current_xyz[active_mask, :2]
        walked = np.linalg.norm(ends - starts, axis=1)
        num_promoted = 0
        num_demoted = 0
        if self._cfg.enabled:
            promote_threshold = self._cfg.promote_frac * self._cell_size
            demote_threshold = self._cfg.demote_frac * self._cell_size
            promote_ids = active[walked > promote_threshold]
            demote_ids = active[walked < demote_threshold]
            num_promoted = int(promote_ids.size)
            num_demoted = int(demote_ids.size)
            self.levels[promote_ids] += 1
            self.levels[demote_ids] -= 1
            overflow_mask = self.levels[promote_ids] >= self._num_rows
            if overflow_mask.any():
                lo = int(self._num_rows * self._cfg.cycle_top_frac)
                lo = min(max(lo, 0), self._num_rows - 1)
                overflow_ids = promote_ids[overflow_mask]
                self.levels[overflow_ids] = self._rng.integers(
                    lo, self._num_rows, size=overflow_ids.size
                ).astype(np.int32)
            # Demotions from level 0 go to -1; clip back into the valid range.
            np.clip(self.levels, 0, self._num_rows - 1, out=self.levels)
        return self._curriculum_stats(float(walked.mean()), num_promoted, num_demoted, num_skipped)

    def _curriculum_stats(
        self,
        mean_walked: float,
        num_promoted: int,
        num_demoted: int,
        num_skipped: int,
    ) -> dict[str, float]:
        """Assemble the stats dict returned by ``update_on_done`` (fixed key order)."""
        return {
            "mean_level": float(self.levels.mean()),
            "max_level": float(self.levels.max()),
            "mean_walked": mean_walked,
            "num_promoted": num_promoted,
            "num_demoted": num_demoted,
            "num_skipped": num_skipped,
        }