Source code for unilab.base.observations

from __future__ import annotations

import numpy as np


[docs] def flatten_obs_dict(obs: dict[str, np.ndarray]) -> np.ndarray: """Concatenate obs groups in insertion order -> flat (N, total_dim) array.""" return np.concatenate(list(obs.values()), axis=1)
[docs] def flatten_policy_obs_dict(obs: dict[str, np.ndarray]) -> np.ndarray: """Build actor-policy inputs from the single actor observation group.""" return obs["obs"]
[docs] def split_obs_dict(obs: dict[str, np.ndarray]) -> tuple[np.ndarray, np.ndarray]: """Split observation dict into (actor_obs, critic_obs). When no separate critic group exists, critic_obs == actor_obs. """ actor = obs["obs"] return actor, obs.get("critic", actor)
[docs] def get_obs_dims(obs_groups_spec: dict[str, int]) -> tuple[int, int]: """Extract (actor_obs_dim, critic_obs_dim) from obs_groups_spec. When no separate critic group exists, critic_obs_dim == actor_obs_dim. """ obs_dim = obs_groups_spec.get("obs", 0) return obs_dim, obs_groups_spec.get("critic", obs_dim)
[docs] def get_critic_base_dim(obs_groups_spec: dict[str, int]) -> int: """Get critic observation dim, falling back to actor obs when absent.""" critic_dim = obs_groups_spec.get("critic", 0) return critic_dim if critic_dim > 0 else obs_groups_spec.get("obs", 0)