Source code for unilab.training.common
"""Shared helpers for training entrypoints."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from unilab.base.registry import ensure_registries as _ensure_registries
[docs]
def ensure_registries() -> None:
    """Populate the task registries before a registry-based entrypoint runs.

    Thin public wrapper that delegates to
    ``unilab.base.registry.ensure_registries`` (imported at module level as
    ``_ensure_registries``). The delegate's return value is discarded.
    """
    _ensure_registries()
[docs]
def get_hydra_runtime_choice(cfg: DictConfig, group: str) -> str | None:
    """Look up which option was selected for a Hydra config group.

    Two sources are consulted, in order:

    1. ``hydra.runtime.choices.<group>`` inside the supplied ``cfg``.
    2. The global ``HydraConfig`` singleton, but only if it is initialized.

    Args:
        cfg: Config to search first.
        group: Name of the Hydra config group (e.g. ``"task"``).

    Returns:
        The choice coerced to ``str``, or ``None`` when neither source
        provides a value or reading the global Hydra config raises.
    """
    from_cfg = OmegaConf.select(cfg, f"hydra.runtime.choices.{group}")
    if from_cfg is not None:
        return str(from_cfg)

    if HydraConfig.initialized():
        try:
            from_global = HydraConfig.get().runtime.choices.get(group)
        except Exception:
            # Best-effort lookup: any failure while reading the global Hydra
            # runtime metadata is treated as "no choice available".
            return None
        if from_global is not None:
            return str(from_global)

    return None
[docs]
def assert_offpolicy_task_choice_matches_algo(
    cfg: DictConfig,
    *,
    algo_name: str | None = None,
) -> None:
    """Validate that the selected off-policy algo owns the selected task.

    The task choice is expected to look like ``<algo>/<task>/<backend>``; its
    first path segment must equal the algorithm in use.

    Args:
        cfg: Config providing ``algo.algo`` and, when available, the Hydra
            runtime choice for the ``task`` group.
        algo_name: Optional algorithm name supplied by the caller. When given
            it must equal ``cfg.algo.algo``.

    Raises:
        ValueError: If ``algo_name`` disagrees with ``cfg.algo.algo``, if the
            task choice contains no ``/`` separator, or if the task's algo
            prefix differs from the selected algo.

    Note:
        ``cfg.algo.algo`` is passed through ``str()`` unconditionally, so a
        ``None`` result from ``OmegaConf.select`` is compared as the literal
        string ``"None"``.
    """
    algo_in_cfg = str(OmegaConf.select(cfg, "algo.algo"))
    caller_disagrees = not (algo_name is None or algo_in_cfg == algo_name)
    if caller_disagrees:
        raise ValueError(
            f"Off-policy algo argument {algo_name!r} is inconsistent with cfg.algo.algo={algo_in_cfg!r}"
        )

    # A falsy algo_name (None or empty string) falls back to the config value.
    expected_prefix = algo_name if algo_name else algo_in_cfg

    chosen_task = get_hydra_runtime_choice(cfg, "task")
    if chosen_task is None:
        # No runtime choice metadata available, so there is nothing to check.
        return

    prefix, slash, _rest = chosen_task.partition("/")
    if slash == "":
        raise ValueError(
            f"Off-policy task choice must use task=<algo>/<task>/<backend>; got task={chosen_task}"
        )
    if prefix != expected_prefix:
        raise ValueError(
            f"Off-policy algo/task mismatch: algo={expected_prefix} is inconsistent with task={chosen_task}. "
            "Use task=<algo>/<task>/<backend> with the same algo prefix."
        )
[docs]
def setup_logger(
    log_dir: str | Path,
    algo_name: str,
    *,
    echo: bool = True,
    filename: str = "train.log",
) -> logging.Logger:
    """Build a message-only logger that writes into ``log_dir``.

    The logger is named ``unilab.training.<algo_name>.<resolved log_dir>``,
    runs at ``INFO`` level and does not propagate to ancestor loggers. Any
    handlers already attached to that logger are removed and closed first, so
    calling this again for the same directory reconfigures it from scratch.

    Args:
        log_dir: Directory for the log file; created (with parents) if missing.
        algo_name: Algorithm name embedded in the logger name.
        echo: When true, also attach a default ``StreamHandler``.
        filename: Name of the log file inside ``log_dir``.

    Returns:
        The configured ``logging.Logger``.
    """
    directory = Path(log_dir)
    directory.mkdir(parents=True, exist_ok=True)

    log = logging.getLogger(f"unilab.training.{algo_name}.{directory.resolve()}")
    log.setLevel(logging.INFO)
    log.propagate = False

    # Iterate over a snapshot because removeHandler mutates log.handlers.
    for stale in tuple(log.handlers):
        log.removeHandler(stale)
        stale.close()

    # File handler first, optional console echo second.
    sinks: list[logging.Handler] = [
        logging.FileHandler(directory / filename, encoding="utf-8")
    ]
    if echo:
        sinks.append(logging.StreamHandler())

    plain = logging.Formatter("%(message)s")
    for sink in sinks:
        sink.setFormatter(plain)
        log.addHandler(sink)

    return log
[docs]
def create_env(
    cfg: DictConfig,
    *,
    num_envs: int,
    env_cfg_override: dict[str, Any] | None = None,
    sim_backend: str | None = None,
    task_name: str | None = None,
):
    """Build an environment through ``unilab.base.registry.make``.

    Args:
        cfg: Config used for fallbacks: ``training.task_name`` and
            ``training.sim_backend`` are read (and stringified) when the
            matching keyword argument is falsy.
        num_envs: Forwarded to ``registry.make`` unchanged.
        env_cfg_override: Forwarded to ``registry.make`` unchanged.
        sim_backend: Explicit simulator backend; overrides the config value.
        task_name: Explicit task name; overrides the config value.

    Returns:
        Whatever ``registry.make`` returns.
    """
    # Imported lazily inside the function, as in the original implementation.
    from unilab.base import registry

    if task_name:
        resolved_task = task_name
    else:
        resolved_task = str(OmegaConf.select(cfg, "training.task_name"))

    if sim_backend:
        resolved_backend = sim_backend
    else:
        resolved_backend = str(OmegaConf.select(cfg, "training.sim_backend"))

    return registry.make(
        resolved_task,
        num_envs=num_envs,
        sim_backend=resolved_backend,
        env_cfg_override=env_cfg_override,
    )