# Source code for unilab.envs.locomotion.go2.base

from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np

from unilab.envs.locomotion.common.base import (
    BaseNoiseConfig,
    LocomotionBaseCfg,
    LocomotionBaseEnv,
    PdControlConfig,
)


@dataclass
class NoiseConfig(BaseNoiseConfig):
    """Noise configuration for the Go2 environments.

    Declares no fields of its own: every field and default is inherited
    unchanged from ``BaseNoiseConfig``. It exists as a Go2-specific type
    that ``Go2BaseCfg.noise_config`` can reference.
    """
@dataclass
class ControlConfig(PdControlConfig):
    """Control configuration for the Go2 environments.

    Declares no fields of its own: every field and default is inherited
    unchanged from ``PdControlConfig``. It exists as a Go2-specific type
    that ``Go2BaseCfg.control_config`` can reference.
    """
@dataclass
class Asset:
    """Names of model elements the Go2 environment refers to.

    The attributes carry type annotations so that ``@dataclass`` registers
    them as fields. Without annotations they would be plain class
    attributes: ``Asset(base_name=...)`` would raise ``TypeError``, and
    ``dataclasses.fields`` / ``asdict`` would report no fields at all.
    The defaults are unchanged.
    """

    # NOTE(review): presumably these must match element names in the robot
    # model loaded by the backend — confirm against the model file.
    base_name: str = "base"
    foot_name: str = "foot"
    ground: str = "floor"
@dataclass
class Go2BaseCfg(LocomotionBaseCfg):
    """Configuration for the Go2 locomotion environments.

    Specialises ``LocomotionBaseCfg`` in two ways:

    - It uses the Go2-specific noise, control and asset sub-configs, each
      built fresh per instance via ``default_factory``.
    - It sets the simulation and control time steps.

    Field order is significant for the generated ``__init__`` and is kept
    as declared.
    """

    # The ignores are presumably needed because the base class declares these
    # fields with its own (base) config types — confirm in LocomotionBaseCfg.
    noise_config: NoiseConfig = field(default_factory=NoiseConfig)  # type: ignore[assignment]
    control_config: ControlConfig = field(default_factory=ControlConfig)  # type: ignore[assignment]
    asset: Asset = field(default_factory=Asset)

    # Default steps: ctrl_dt is twice sim_dt here. Units look like seconds —
    # NOTE(review): confirm against how LocomotionBaseEnv consumes them.
    sim_dt: float = 0.01
    ctrl_dt: float = 0.02
class Go2BaseEnv(LocomotionBaseEnv):
    """Go2 quadruped environment built on ``LocomotionBaseEnv``.

    Adds helpers that read per-foot sensors from ``self._backend`` in the
    fixed leg order FL, FR, RL, RR.
    """

    _cfg: Go2BaseCfg

    # Backend sensor names, one per foot, in FL, FR, RL, RR order.
    _FOOT_POS_SENSORS = ("FL_pos", "FR_pos", "RL_pos", "RR_pos")
    _FOOT_CONTACT_SENSORS = (
        "FL_foot_contact",
        "FR_foot_contact",
        "RL_foot_contact",
        "RR_foot_contact",
    )

    def get_foot_pos(self) -> np.ndarray:
        """Return foot positions, one slice per foot along axis 1.

        Reads each ``*_pos`` sensor in FL, FR, RL, RR order and stacks the
        results on a new axis 1. The documented result shape is
        (num_envs, 4, 3), which assumes each sensor yields (num_envs, 3).
        """
        read = self._backend.get_sensor_data
        positions = [read(sensor) for sensor in self._FOOT_POS_SENSORS]
        return np.stack(positions, axis=1)

    def get_foot_contact(self) -> np.ndarray:
        """Return foot contact forces, one column per foot.

        For each ``*_foot_contact`` sensor (FL, FR, RL, RR order) keeps
        column 0 of the sensor data, then stacks the columns on axis 1.
        The documented result shape is (num_envs, 4).
        """
        columns = []
        for sensor in self._FOOT_CONTACT_SENSORS:
            data = self._backend.get_sensor_data(sensor)
            # NOTE(review): which force component column 0 holds depends on
            # the backend's contact-sensor layout — confirm.
            columns.append(data[:, 0])
        return np.stack(columns, axis=1)