# Source code for unilab.structured_configs
"""Typed dataclass configs for all training algorithms.
Replaces ml_collections.ConfigDict factory functions.
Use OmegaConf / Hydra to compose these at runtime.
"""
from __future__ import annotations
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Optional, cast
class BaseConfig:
    """Shared base for the top-level algorithm config dataclasses.

    The class itself carries no fields; it only contributes ``to_dict``.
    Subclasses are expected to be decorated with ``@dataclass``.
    """

    def to_dict(self) -> dict[str, Any]:
        """Return this config, including nested dataclasses, as a plain dict.

        Returns:
            A new dictionary produced by ``dataclasses.asdict``.

        Raises:
            TypeError: If the instance is not a dataclass instance
                (raised by ``dataclasses.asdict``).
        """
        # ``self`` is widened to ``Any`` only for the type checker: this base
        # class is not itself a dataclass, the concrete subclass is.
        as_mapping = dataclasses.asdict(cast(Any, self))
        return cast(dict[str, Any], as_mapping)
# ── Off-policy: SAC ──────────────────────────────────────────────────────────
@dataclass
class SACAlgoParams:
    """SAC-specific defaults, nested under ``SACConfig.algo_params``."""

    # Entropy-temperature ("alpha") related values.
    alpha_lr: float = 3e-4
    alpha_init: float = 0.01
    target_entropy_ratio: float = 0.0

    # NOTE(review): 0.0 presumably means "no gradient clipping" -- confirm
    # against the trainer that consumes this value.
    max_grad_norm: float = 0.0

    # Runtime / execution switches.
    amp_dtype: str = "auto"
    use_compile: bool = True
@dataclass
class SACConfig(BaseConfig):
    """Top-level default configuration for the ``"sac"`` algorithm."""

    # Identification and runtime selection.
    algo: str = "sac"
    algo_log_name: str = "fast_sac"
    runtime_impl: Optional[str] = None
    runtime_resolver: Optional[str] = None
    seed: int = 1

    # Environment count, replay and update cadence.
    num_envs: int = 4096
    batch_size: int = 8192
    replay_buffer_n: int = 512
    updates_per_step: int = 4
    learning_starts: int = 1
    policy_frequency: int = 4
    env_steps_per_sync: int = 1

    # Run length and checkpoint cadence.
    max_iterations: int = 500
    save_interval: int = 500

    # Core optimisation values.
    gamma: float = 0.97
    tau: float = 0.125
    actor_lr: float = 3e-4
    critic_lr: float = 3e-4

    # Network sizing and feature toggles.
    actor_hidden_dim: int = 512
    critic_hidden_dim: int = 768
    num_atoms: int = 101
    obs_normalization: bool = True
    use_layer_norm: bool = True
    use_symmetry: bool = False

    # Free-form actor options; a fresh empty dict is created per instance.
    actor: dict[str, Any] = field(default_factory=dict)

    # Algorithm-specific extras; a fresh SACAlgoParams is created per instance.
    algo_params: SACAlgoParams = field(default_factory=SACAlgoParams)
# ── Off-policy: TD3 ──────────────────────────────────────────────────────────
@dataclass
class TD3AlgoParams:
    """TD3-specific defaults, nested under ``TD3Config.algo_params``."""

    weight_decay: float = 0.1

    # Value-range bounds.
    v_min: float = -10.0
    v_max: float = 10.0

    # Initialisation scale and log-std bounds.
    init_scale: float = 0.01
    log_std_min: float = -0.9
    log_std_max: float = 0.0

    # Noise settings.
    policy_noise: float = 0.2
    noise_clip: float = 0.5

    use_cdq: bool = True
@dataclass
class TD3Config(BaseConfig):
    """Top-level default configuration for the ``"td3"`` algorithm."""

    # Identification.
    algo: str = "td3"
    algo_log_name: str = "fast_td3"
    seed: int = 1

    # Environment count, replay and update cadence.
    num_envs: int = 4096
    batch_size: int = 8192
    replay_buffer_n: int = 1000
    updates_per_step: int = 4
    learning_starts: int = 1
    policy_frequency: int = 2
    env_steps_per_sync: int = 1

    # Run length and checkpoint cadence.
    max_iterations: int = 5000
    save_interval: int = 500

    # Core optimisation values.
    gamma: float = 0.97
    tau: float = 0.1
    actor_lr: float = 3e-4
    critic_lr: float = 3e-4

    # Network sizing and feature toggles.
    actor_hidden_dim: int = 256
    critic_hidden_dim: int = 512
    num_atoms: int = 101
    obs_normalization: bool = True
    use_layer_norm: bool = False

    # Algorithm-specific extras; a fresh TD3AlgoParams is created per instance.
    algo_params: TD3AlgoParams = field(default_factory=TD3AlgoParams)
# ── Off-policy: FlashSAC ─────────────────────────────────────────────────────
@dataclass
class FlashSACAlgoParams:
    """FlashSAC-specific defaults, nested under ``FlashSACConfig.algo_params``."""

    # Reward normalisation.
    normalize_reward: bool = True
    normalized_g_max: float = 5.0

    # Actor / critic structure and actor-side extras.
    actor_num_blocks: int = 2
    critic_num_blocks: int = 2
    actor_bc_alpha: float = 0.0
    actor_noise_zeta_mu: float = 2.0
    actor_noise_zeta_max: int = 16

    # Critic value range.
    critic_min_v: float = -5.0
    critic_max_v: float = 5.0

    # Temperature settings.
    temp_initial_value: float = 0.01
    temp_target_sigma: float = 0.15
    # Spelled ``Optional[float]`` rather than ``float | None``: with postponed
    # annotations the PEP 604 form is stored as a string and cannot be
    # evaluated by ``typing.get_type_hints`` on Python < 3.10. It also matches
    # the ``Optional[...]`` spelling used by the other configs in this file.
    temp_target_entropy: Optional[float] = None

    # Learning-rate schedule values.
    learning_rate_init: float = 3e-4
    learning_rate_peak: float = 3e-4
    learning_rate_end: float = 1.5e-4
    learning_rate_warmup_steps: int = 0
    learning_rate_decay_steps: int = 500000

    n_step: int = 1

    # Runtime / execution switches.
    amp_dtype: str = "auto"
    use_compile: bool = True
@dataclass
class FlashSACConfig(BaseConfig):
    """Top-level default configuration for the ``"flashsac"`` algorithm."""

    # Identification.
    algo: str = "flashsac"
    algo_log_name: str = "flash_sac"
    seed: int = 1

    # Environment count, replay and update cadence.
    num_envs: int = 1024
    batch_size: int = 2048
    replay_buffer_n: int = 512
    updates_per_step: int = 2
    learning_starts: int = 98
    policy_frequency: int = 2
    env_steps_per_sync: int = 1

    # Run length and checkpoint cadence.
    max_iterations: int = 5000
    save_interval: int = 1000

    # Core optimisation values.
    gamma: float = 0.97
    tau: float = 0.01
    actor_lr: float = 3e-4
    critic_lr: float = 3e-4

    # Network sizing and feature toggles.
    actor_hidden_dim: int = 128
    critic_hidden_dim: int = 256
    num_atoms: int = 101
    obs_normalization: bool = False
    use_layer_norm: bool = False

    # Algorithm-specific extras; a fresh FlashSACAlgoParams per instance.
    algo_params: FlashSACAlgoParams = field(default_factory=FlashSACAlgoParams)
# ── APPO ─────────────────────────────────────────────────────────────────────
@dataclass
class APPOAlgorithmConfig:
    """Algorithm-level defaults for APPO, nested under ``APPOConfig.algorithm``."""

    # Update loop shape.
    num_learning_epochs: int = 5
    num_mini_batches: int = 4

    # Loss terms and discounting.
    clip_param: float = 0.2
    gamma: float = 0.99
    lam: float = 0.95
    value_loss_coef: float = 1.0
    entropy_coef: float = 0.01

    # Optimiser step settings.
    learning_rate: float = 1e-3
    max_grad_norm: float = 1.0
    use_clipped_value_loss: bool = True

    # Learning-rate schedule and its KL-related factors.
    schedule: str = "adaptive"
    desired_kl: float = 0.01
    adaptive_kl_factor: float = 1.2
    adaptive_lr_factor: float = 1.1
    optimizer: str = "adam"

    # Target update settings.
    tau: float = 1.0
    target_update_freq: int = 1

    # V-trace clipping values.
    vtrace_clip_rho: float = 1.0
    vtrace_clip_c: float = 1.0

    enable_compile: bool = True
@dataclass
class APPODistributionConfig:
    """Defaults describing the action distribution used by the APPO actor."""

    # Dotted path naming the distribution class.
    class_name: str = "rsl_rl.modules.distribution.GaussianDistribution"
    init_std: float = 1.0
    std_type: str = "scalar"
@dataclass
class APPOActorConfig:
    """Defaults for the APPO actor model, nested under ``APPOConfig.actor``."""

    # Dotted path naming the model class.
    class_name: str = "rsl_rl.models.MLPModel"

    # Built through a factory so every instance owns its own list.
    hidden_dims: list = field(default_factory=lambda: [512, 256, 128])
    activation: str = "elu"

    # A fresh APPODistributionConfig is created per instance.
    distribution_cfg: APPODistributionConfig = field(
        default_factory=APPODistributionConfig
    )
@dataclass
class APPOCriticConfig:
    """Defaults for the APPO critic model, nested under ``APPOConfig.critic``."""

    # Dotted path naming the model class.
    class_name: str = "rsl_rl.models.MLPModel"

    # Built through a factory so every instance owns its own list.
    hidden_dims: list = field(default_factory=lambda: [512, 256, 128])
    activation: str = "elu"
@dataclass
class APPOConfig(BaseConfig):
    """Top-level default configuration for the ``"appo"`` algorithm."""

    # Identification.
    algo: str = "appo"
    algo_log_name: str = "appo"
    seed: int = 1

    # Rollout size, run length and checkpoint cadence.
    num_envs: int = 2048
    steps_per_env: int = 24
    max_iterations: int = 150
    save_interval: int = 50

    # Observation-group mapping; the dict is rebuilt for every instance.
    obs_groups: dict = field(default_factory=lambda: {"actor": {"policy": 0}})

    # Nested sections, each created fresh per instance.
    actor: APPOActorConfig = field(default_factory=APPOActorConfig)
    critic: APPOCriticConfig = field(default_factory=APPOCriticConfig)
    algorithm: APPOAlgorithmConfig = field(default_factory=APPOAlgorithmConfig)
# ── PPO (rsl-rl) ─────────────────────────────────────────────────────────────
@dataclass
class PPOPolicyConfig:
    """Policy-network defaults for PPO, nested under ``PPOConfig.policy``."""

    init_noise_std: float = 1.0

    # Both lists come from factories so instances never share them.
    actor_hidden_dims: list = field(default_factory=lambda: [512, 256, 128])
    critic_hidden_dims: list = field(default_factory=lambda: [512, 256, 128])

    activation: str = "elu"
    class_name: str = "ActorCritic"
@dataclass
class PPOAlgorithmConfig:
    """Algorithm-level defaults for PPO, nested under ``PPOConfig.algorithm``."""

    # ``module:attribute`` style reference to the algorithm class.
    class_name: str = "unilab.algos.torch.rsl_rl_ppo:FinalObservationAwarePPO"

    # Loss terms.
    value_loss_coef: float = 1.0
    use_clipped_value_loss: bool = True
    clip_param: float = 0.2
    entropy_coef: float = 0.01

    # Update loop shape and step size.
    num_learning_epochs: int = 5
    num_mini_batches: int = 4
    learning_rate: float = 1e-3
    schedule: str = "adaptive"

    # Discounting.
    gamma: float = 0.99
    lam: float = 0.95

    # KL targets and gradient clipping.
    desired_kl: float = 0.01
    target_kl_stop: Optional[float] = None
    max_grad_norm: float = 1.0

    # Adaptive learning-rate values.
    adaptive_kl_beta: float = 0.9
    adaptive_lr_growth: float = 1.1
    adaptive_lr_decay: float = 1.2
    adaptive_lr_update_interval: int = 5

    # Metrics / finite-check cadence and compile switch.
    metrics_interval: int = 8
    finite_check_interval: int = 8
    enable_compile: bool = True

    # Warm-up phase overrides.
    warmup_strict_iters: int = 10
    warmup_metrics_interval: int = 2
    warmup_finite_check_interval: int = 2

    disable_finite_checks: bool = True
@dataclass
class PPOConfig(BaseConfig):
    """Top-level default configuration for the ``"ppo"`` algorithm."""

    # Identification.
    algo: str = "ppo"
    algo_log_name: str = "rsl_rl_ppo"
    seed: int = 1

    # Rollout size, run length and checkpoint cadence.
    num_envs: int = 4096
    num_steps_per_env: int = 24
    max_iterations: int = 101
    save_interval: int = 100

    # Runner options.
    empirical_normalization: bool = False
    runner_class_name: str = "OnPolicyRunner"

    # Observation-group mapping; the dict is rebuilt for every instance.
    obs_groups: dict = field(default_factory=lambda: {"default": ["policy"]})

    # Experiment naming and resume/checkpoint selection.
    experiment_name: str = "test"
    run_name: str = ""
    resume: bool = False
    load_run: str = "-1"
    checkpoint: int = -1
    resume_path: Optional[str] = None

    # Nested sections, each created fresh per instance.
    policy: PPOPolicyConfig = field(default_factory=PPOPolicyConfig)
    algorithm: PPOAlgorithmConfig = field(default_factory=PPOAlgorithmConfig)