unilab.base.observations

Functions

flatten_obs_dict(obs)

Concatenate the observation groups in insertion order into a flat (N, total_dim) array.

flatten_policy_obs_dict(obs)

Build actor-policy inputs from the single actor observation group.

get_critic_base_dim(obs_groups_spec)

Get the critic observation dimension, falling back to the actor observation dimension when no critic group is present.

get_obs_dims(obs_groups_spec)

Extract (actor_obs_dim, critic_obs_dim) from obs_groups_spec.

split_obs_dict(obs)

Split observation dict into (actor_obs, critic_obs).

unilab.base.observations.flatten_obs_dict(obs)

Concatenate the observation groups in insertion order into a flat (N, total_dim) array.

Parameters:

obs (dict[str, ndarray])

Return type:

ndarray
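A minimal sketch of what "concatenate in insertion order" means, assuming every group is a 2-D array of shape (N, dim_g) that shares the same batch size N. The group names "actor" and "critic" are hypothetical; this is an illustration of the behaviour described above, not the unilab implementation.

```python
from __future__ import annotations

import numpy as np


def flatten_obs_dict(obs: dict[str, np.ndarray]) -> np.ndarray:
    # Python dicts preserve insertion order (3.7+), so values() yields the
    # groups in the order they were added; join them along the feature axis.
    return np.concatenate(list(obs.values()), axis=-1)


obs = {
    "actor": np.zeros((4, 3), dtype=np.float32),   # hypothetical group name
    "critic": np.ones((4, 5), dtype=np.float32),   # hypothetical group name
}
flat = flatten_obs_dict(obs)
print(flat.shape)  # (4, 8)
```

Because ordering follows insertion, building the dict in a different order changes the column layout of the result, so producers and consumers need to agree on it.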

unilab.base.observations.flatten_policy_obs_dict(obs)

Build actor-policy inputs from the single actor observation group.

Parameters:

obs (dict[str, ndarray])

Return type:

ndarray
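A hedged sketch of building policy inputs from the single actor group. The key name "actor" is an assumption (this page does not name the group), as is the choice to merge any trailing feature axes into one so the result is (N, actor_dim).

```python
from __future__ import annotations

import numpy as np

# Hypothetical name of the actor observation group; not stated on this page.
ACTOR_KEY = "actor"


def flatten_policy_obs_dict(obs: dict[str, np.ndarray]) -> np.ndarray:
    """Return the actor group as an (N, actor_dim) array, ignoring other groups."""
    if ACTOR_KEY not in obs:
        raise KeyError(f"observation dict has no {ACTOR_KEY!r} group")
    actor = np.asarray(obs[ACTOR_KEY])
    # Keep the batch axis N and merge any remaining feature axes (assumption).
    return actor.reshape(actor.shape[0], -1)


obs = {
    "actor": np.zeros((4, 3), dtype=np.float32),
    "critic": np.ones((4, 5), dtype=np.float32),  # not used by the policy
}
policy_in = flatten_policy_obs_dict(obs)
print(policy_in.shape)  # (4, 3)
```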

unilab.base.observations.split_obs_dict(obs)

Split observation dict into (actor_obs, critic_obs).

When no separate critic group exists, critic_obs == actor_obs.

Parameters:

obs (dict[str, ndarray])

Return type:

tuple[ndarray, ndarray]
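The fallback rule (critic_obs == actor_obs when there is no separate critic group) can be sketched as below. The group names "actor" and "critic" are hypothetical, and the sketch assumes critic_obs is simply the critic group itself; the page does not say whether the real function combines groups.

```python
from __future__ import annotations

import numpy as np

# Hypothetical group names; the real keys are not given on this page.
ACTOR_KEY = "actor"
CRITIC_KEY = "critic"


def split_obs_dict(obs: dict[str, np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    actor_obs = obs[ACTOR_KEY]
    # Without a separate critic group, the critic sees the actor observations.
    critic_obs = obs[CRITIC_KEY] if CRITIC_KEY in obs else actor_obs
    return actor_obs, critic_obs


with_critic = {"actor": np.zeros((4, 3)), "critic": np.ones((4, 5))}
actor_only = {"actor": np.zeros((4, 3))}

a, c = split_obs_dict(with_critic)
print(a.shape, c.shape)  # (4, 3) (4, 5)

a, c = split_obs_dict(actor_only)
print(c is a)  # True
```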

unilab.base.observations.get_obs_dims(obs_groups_spec)

Extract (actor_obs_dim, critic_obs_dim) from obs_groups_spec.

When no separate critic group exists, critic_obs_dim == actor_obs_dim.

Parameters:

obs_groups_spec (dict[str, int])

Return type:

tuple[int, int]
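The same fallback applied to dimensions, as a minimal sketch. It assumes obs_groups_spec maps group names to feature sizes (matching the dict[str, int] type above) and uses the hypothetical keys "actor" and "critic"; the sizes 48 and 60 are made-up illustration values.

```python
from __future__ import annotations

# Hypothetical group names; the real keys are not given on this page.
ACTOR_KEY = "actor"
CRITIC_KEY = "critic"


def get_obs_dims(obs_groups_spec: dict[str, int]) -> tuple[int, int]:
    actor_dim = obs_groups_spec[ACTOR_KEY]
    # Fall back to the actor dimension when no critic group is declared.
    critic_dim = obs_groups_spec.get(CRITIC_KEY, actor_dim)
    return actor_dim, critic_dim


print(get_obs_dims({"actor": 48, "critic": 60}))  # (48, 60)
print(get_obs_dims({"actor": 48}))                # (48, 48)
```

A typical use is sizing the actor and critic network inputs from the spec before any observations exist.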

unilab.base.observations.get_critic_base_dim(obs_groups_spec)

Get the critic observation dimension, falling back to the actor observation dimension when no critic group is present.

Parameters:

obs_groups_spec (dict[str, int])

Return type:

int
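A short sketch of the fallback, again assuming hypothetical group keys "actor" and "critic" in obs_groups_spec. This page does not say what "base" refers to, so the sketch returns the declared group size unchanged.

```python
from __future__ import annotations

# Hypothetical group names; the real keys are not given on this page.
ACTOR_KEY = "actor"
CRITIC_KEY = "critic"


def get_critic_base_dim(obs_groups_spec: dict[str, int]) -> int:
    # Explicit branch: the actor key is only looked up when no critic group exists.
    if CRITIC_KEY in obs_groups_spec:
        return obs_groups_spec[CRITIC_KEY]
    return obs_groups_spec[ACTOR_KEY]


print(get_critic_base_dim({"actor": 48, "critic": 60}))  # 60
print(get_critic_base_dim({"actor": 48}))                # 48
```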