# Source code for unilab.algos.torch.hora.sac
"""HORA-owned SAC entry helpers."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from unilab.algos.torch.hora.runtime import HORA_SAC_RUNTIME_IMPL, is_hora_sac_runtime
from unilab.algos.torch.hora.sac_learner import HoraSACLearner
from unilab.algos.torch.offpolicy.runtime import OffPolicyRuntime
[docs]
@dataclass(frozen=True)
class HoraSACRuntime(OffPolicyRuntime):
    """Bundle of HORA-SAC hooks handed to the generic off-policy script.

    The fields default to the HORA-SAC learner class and runtime marker,
    with symmetry support switched off. ``actor_cfg`` carries the optional
    actor overrides that ``build_model_kwargs`` reads.
    """

    learner_cls: type[Any] | None = HoraSACLearner
    algo_type: str | None = HORA_SAC_RUNTIME_IMPL
    supports_symmetry: bool = False
    actor_cfg: dict[str, Any] = field(default_factory=dict)

    def build_model_kwargs(self, *, obs_dim: int, critic_obs_dim: int) -> dict[str, Any]:
        """Assemble the actor kwargs shared by the learner and the collector.

        The privileged-information width is the difference
        ``critic_obs_dim - obs_dim``. The embedding size and the MLP hidden
        sizes are read from ``actor_cfg`` and fall back to ``9`` and
        ``(256, 128, 9)`` respectively.

        Args:
            obs_dim: Width of the actor observation.
            critic_obs_dim: Width of the critic observation.

        Returns:
            A dict with the keys ``priv_info_dim``, ``priv_info_embed_dim``
            and ``priv_mlp_hidden_dims``.

        Raises:
            ValueError: If the critic observation is not strictly wider than
                the actor observation.
        """
        tail_width = int(critic_obs_dim - obs_dim)
        if tail_width > 0:
            overrides = self.actor_cfg
            # Read the overrides in the same order as the returned keys.
            embed_dim = int(overrides.get("priv_info_embed_dim", 9))
            hidden_dims = tuple(overrides.get("priv_mlp_hidden_dims", (256, 128, 9)))
            return {
                "priv_info_dim": tail_width,
                "priv_info_embed_dim": embed_dim,
                "priv_mlp_hidden_dims": hidden_dims,
            }
        # No privileged features are left once the actor observation is removed.
        raise ValueError(
            "HORA-SAC requires critic observations to contain privileged tail "
            f"features; got obs_dim={obs_dim}, critic_obs_dim={critic_obs_dim}."
        )
[docs]
def resolve_hora_sac_runtime(rl_cfg: dict[str, Any]) -> HoraSACRuntime | None:
    """Return the HORA-SAC runtime when ``rl_cfg`` is marked as HORA-SAC.

    Args:
        rl_cfg: RL configuration; its optional ``"actor"`` entry supplies
            the actor overrides.

    Returns:
        A ``HoraSACRuntime`` built from a shallow copy of the ``"actor"``
        dict, or ``None`` when ``is_hora_sac_runtime(rl_cfg)`` is false.
    """
    if is_hora_sac_runtime(rl_cfg):
        actor_section = rl_cfg.get("actor", {})
        # Any non-dict "actor" entry is treated as if no overrides were given.
        if not isinstance(actor_section, dict):
            actor_section = {}
        return HoraSACRuntime(actor_cfg=dict(actor_section))
    return None
# Names exported by ``from ... import *``: the runtime dataclass and its resolver.
__all__ = ["HoraSACRuntime", "resolve_hora_sac_runtime"]