Source code for unilab.algos.torch.appo.runtime

"""Runtime resolution helpers for APPO script assembly.

This module keeps entrypoint scripts generic: they resolve an APPO runtime bundle
from owner config and then call the returned runner/play entrypoints without
knowing which concrete runtime implementation is active.
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class APPORuntime:
    """Immutable pair of APPO entrypoints consumed by the generic script.

    Built by ``resolve_appo_runtime`` from the owner config. Because the
    dataclass is frozen, fields cannot be reassigned after construction.

    Attributes:
        runner_cls: Runner class used for APPO training mode.
        play_fn: Play-mode callable used for checkpoint playback; annotated
            as returning ``str`` or ``None``.
    """

    # Training-mode runner class selected for this runtime.
    runner_cls: type[Any]
    # Checkpoint-playback entrypoint selected for this runtime.
    play_fn: Callable[..., str | None]
def resolve_appo_runtime(
    rl_cfg: dict[str, Any],
    *,
    default_play_fn: Callable[..., str | None],
) -> APPORuntime:
    """Resolve the APPO runtime bundle from owner config.

    The optional ``runtime_resolver`` entry of ``rl_cfg`` selects a custom
    runtime. It may be either a callable or a value whose string form is
    passed to ``rsl_rl.utils.resolve_callable``. The resolver is called with
    ``rl_cfg`` and must return an object exposing callable ``runner_cls`` and
    ``play_fn`` attributes. When the entry is missing, ``None``, or an empty /
    whitespace-only string, the built-in ``APPORunner`` is paired with
    ``default_play_fn``.

    Args:
        rl_cfg: Resolved algorithm config dictionary from Hydra composition.
        default_play_fn: Generic APPO play function used when no custom
            runtime resolver is selected by the owner config.

    Returns:
        ``APPORuntime`` containing the train and play entrypoints for the
        selected APPO runtime.

    Raises:
        ValueError: If the custom runtime resolver returns ``None``.
        TypeError: If the resolver's result does not expose callable
            ``runner_cls`` and ``play_fn`` attributes.
    """
    runtime_resolver = rl_cfg.get("runtime_resolver")
    if isinstance(runtime_resolver, str):
        # A blank or whitespace-only string means "not configured"; stripping
        # also tolerates stray whitespace around a real resolver path.
        runtime_resolver = runtime_resolver.strip()

    if runtime_resolver in (None, ""):
        # Imported inside the branch so the default runner module is only
        # loaded when the default runtime is actually selected.
        from unilab.algos.torch.appo.runner import APPORunner

        return APPORuntime(runner_cls=APPORunner, play_fn=default_play_fn)

    if callable(runtime_resolver):
        # Already a resolved callable (e.g. injected programmatically);
        # stringifying it would yield an unresolvable "<function ...>" path.
        resolver = runtime_resolver
    else:
        from rsl_rl.utils import resolve_callable

        resolver = resolve_callable(str(runtime_resolver))

    runtime = resolver(rl_cfg)
    if runtime is None:
        raise ValueError(
            f"APPO runtime resolver {runtime_resolver!r} returned None for rl_cfg runtime selection."
        )

    runner_cls = getattr(runtime, "runner_cls", None)
    play_fn = getattr(runtime, "play_fn", None)
    # Reject missing (None) and non-callable entries up front so a malformed
    # resolver fails here rather than later when the script invokes them.
    if not callable(runner_cls) or not callable(play_fn):
        raise TypeError(
            f"APPO runtime resolver {runtime_resolver!r} must return an object with "
            "callable 'runner_cls' and 'play_fn' attributes."
        )
    return APPORuntime(runner_cls=runner_cls, play_fn=play_fn)