Source code for unilab.algos.torch.offpolicy.runtime

"""Runtime resolution helpers for off-policy script assembly."""

from __future__ import annotations

from dataclasses import dataclass, field
from importlib import import_module
from typing import Any


@dataclass(frozen=True)
class OffPolicyRuntime:
    """Bundle of optional overrides for the generic off-policy SAC path.

    Every field has a default, so a custom runtime only needs to set the
    pieces that differ from standard SAC.
    """

    # Learner class override; ``None`` means no override is declared.
    learner_cls: type[Any] | None = None
    # Algorithm type override; ``None`` means no override is declared.
    algo_type: str | None = None
    # Extra keyword arguments; copied out by ``build_model_kwargs``.
    actor_kwargs: dict[str, Any] = field(default_factory=dict)
    supports_symmetry: bool = True

    def build_model_kwargs(
        self,
        *,
        obs_dim: int,
        critic_obs_dim: int,
    ) -> dict[str, Any]:
        """Return the kwargs shared by learner and collector-actor construction.

        The base implementation ignores both dimensions and hands back a
        shallow copy of ``actor_kwargs``, so callers can mutate the result
        without touching the (frozen) runtime instance.
        """
        # Both dimensions are part of the signature but unused here.
        del obs_dim, critic_obs_dim
        shared_kwargs = dict(self.actor_kwargs)
        return shared_kwargs
def _resolve_callable(path: str) -> Any: module_path: str attr_name: str if ":" in path: module_path, attr_name = path.split(":", 1) else: module_path, _, attr_name = path.rpartition(".") if not module_path or not attr_name: raise ValueError(f"Invalid runtime resolver path: {path!r}") resolved = getattr(import_module(module_path), attr_name) if not callable(resolved): raise TypeError(f"Runtime resolver {path!r} is not callable.") return resolved
def resolve_custom_offpolicy_runtime(rl_cfg: dict[str, Any]) -> OffPolicyRuntime | None:
    """Look up and invoke the custom off-policy runtime named in ``rl_cfg``.

    Reads ``rl_cfg["runtime_resolver"]``. When that key is missing, ``None``
    or the empty string, no custom runtime is configured and ``None`` is
    returned -- unless ``rl_cfg["runtime_impl"]`` is set, which is an error
    because an implementation was selected without a resolver to build it.

    Otherwise the resolver path is imported, called with ``rl_cfg``, and its
    result is returned after validation.

    Raises:
        ValueError: If ``runtime_impl`` is set without ``runtime_resolver``,
            or if the resolver returns ``None``.
        TypeError: If the resolver returns something other than an
            ``OffPolicyRuntime``.
    """
    resolver_path = rl_cfg.get("runtime_resolver")

    if resolver_path in (None, ""):
        impl_name = rl_cfg.get("runtime_impl")
        if impl_name in (None, ""):
            return None
        raise ValueError(
            "Off-policy owner config selected "
            f"runtime_impl={impl_name!r} but did not define algo.runtime_resolver."
        )

    build_runtime = _resolve_callable(str(resolver_path))
    candidate = build_runtime(rl_cfg)

    if isinstance(candidate, OffPolicyRuntime):
        return candidate
    if candidate is None:
        raise ValueError(
            f"Off-policy runtime resolver {resolver_path!r} returned None "
            "for rl_cfg runtime selection."
        )
    raise TypeError(
        f"Off-policy runtime resolver {resolver_path!r} must return "
        "an OffPolicyRuntime instance."
    )