Source code for unilab.algos.torch.common.device
from unilab.base import registry
[docs]
def get_env_dims(
    env_name: str, sim_backend: str = "mujoco", env_cfg_override: dict | None = None
) -> tuple[int, int, int]:
    """Get (actor_obs_dim, action_dim, critic_obs_dim) from environment.

    Builds a temporary single-environment instance via ``registry.make``.
    It reads the actor/critic observation dimensions from the environment's
    ``obs_groups_spec`` and the action dimension from the first axis of
    ``action_space.shape``, then closes the environment.

    Args:
        env_name: Registered environment name passed to ``registry.make``.
        sim_backend: Simulation backend name passed to ``registry.make``.
        env_cfg_override: Optional config overrides passed to ``registry.make``.

    Returns:
        A tuple ``(obs_dim, action_dim, critic_dim)``.

    Raises:
        ValueError: If the environment's action space has no shape (``None``)
            or an empty shape, so no action dimension can be read.
    """
    # Imported lazily inside the function; presumably to avoid an import cycle
    # or a heavy import at module load time — TODO confirm.
    from unilab.base.observations import get_obs_dims as get_obs_dims_from_spec

    env = registry.make(
        env_name, num_envs=1, sim_backend=sim_backend, env_cfg_override=env_cfg_override
    )
    # Always close the probe environment, even if dimension extraction fails,
    # so this helper never leaks simulator resources.
    try:
        obs_dim, critic_dim = get_obs_dims_from_spec(env.obs_groups_spec)
        action_shape = env.action_space.shape
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, and an empty shape would otherwise surface as a bare
        # IndexError on the indexing below.
        if not action_shape:
            raise ValueError(
                f"Environment {env_name!r} has an action space without a usable "
                f"shape (got {action_shape!r}); cannot determine action_dim."
            )
        action_dim = action_shape[0]
    finally:
        env.close()  # type: ignore[attr-defined]
    return obs_dim, action_dim, critic_dim
# Public API of this module: `from ... import *` exports only get_env_dims.
__all__ = ["get_env_dims"]