# Source code for unilab.training.run

"""Run directory and checkpoint resolution helpers."""

from __future__ import annotations

import os
from os import PathLike
from pathlib import Path

from omegaconf import DictConfig, OmegaConf

from unilab.base.backend.base import BackendPlayRenderPlan, normalize_play_render_mode

_TEST_LOG_ROOT_ENV = "UNILAB_TEST_LOG_ROOT"


def should_run_playback(*, play_only: bool, no_play: bool, play_render_mode: str | None) -> bool:
    """Decide whether train/eval should enter the playback phase.

    Playback is always skipped when the normalized render mode is ``"none"``.
    Otherwise ``play_only`` forces playback, and without it playback runs unless
    ``no_play`` is set.
    """
    if normalize_play_render_mode(play_render_mode) == "none":
        return False
    if play_only:
        return True
    return not no_play
def log_playback_plan(plan: BackendPlayRenderPlan, *, prefix: str = "") -> None:
    """Print user-facing status lines describing how playback will run.

    Args:
        plan: Already-resolved backend plan; only ``mode``, ``record_video`` and
            ``output_video`` are read.
        prefix: Text prepended to every printed line.
    """
    if plan.mode == "none":
        print(f"{prefix}Skipping playback because training.play_render_mode=none.")
        return

    # Recording a video takes precedence over the interactive/headless distinction.
    if plan.record_video:
        messages = [f"{prefix}Rendering video to {plan.output_video}..."]
    elif plan.mode == "interactive":
        messages = [
            f"{prefix}Starting interactive visualization (motrix native renderer)...",
            f"{prefix}Close the render window to exit.",
        ]
    else:
        messages = [f"{prefix}Running playback without video recording..."]

    # NOTE(review): the scraped original was collapsed onto one line; this footer is
    # read as printed for every non-"none" mode — confirm against the upstream file.
    messages.append(f"{prefix}Rendering playback frames...")

    for message in messages:
        print(message)
def get_log_root(root_dir: str | Path, cfg: DictConfig) -> Path:
    """Resolve the algorithm log root for a Hydra config.

    Resolution order:
      1. ``training.log_root`` when set (truthy); a relative value is joined onto
         ``root_dir``.
      2. ``$UNILAB_TEST_LOG_ROOT / <algo.algo_log_name>`` when that variable is set.
      3. ``root_dir / "logs" / <algo.algo_log_name>``.
    """
    override = OmegaConf.select(cfg, "training.log_root")
    if override:
        override_path = Path(str(override))
        if override_path.is_absolute():
            return override_path
        return Path(root_dir) / override_path

    env_root = os.environ.get(_TEST_LOG_ROOT_ENV)
    parent = Path(env_root) if env_root else Path(root_dir) / "logs"
    return parent / str(OmegaConf.select(cfg, "algo.algo_log_name"))
def get_entrypoint_log_root(
    root_dir: str | Path,
    *,
    algo_log_name: str,
    log_root: str | Path | None = None,
) -> Path:
    """Resolve the log root for non-Hydra entrypoints using training helper semantics.

    Resolution order:
      1. ``log_root`` when given and non-empty; a relative value is joined onto
         ``root_dir``.
      2. ``$UNILAB_TEST_LOG_ROOT / algo_log_name`` when that variable is set.
      3. ``root_dir / "logs" / algo_log_name``.

    Args:
        root_dir: Repository root used for relative overrides and the default root.
        algo_log_name: Sub-directory name for the algorithm's logs.
        log_root: Optional explicit log root override.

    Returns:
        The resolved log root directory (not checked for existence).
    """
    # A truthiness check (rather than ``is not None``) matches get_log_root: an empty
    # string override means "unset". Otherwise Path("") == Path(".") would silently
    # resolve to root_dir itself and diverge from where training writes its logs.
    if log_root:
        configured_root = Path(log_root)
        if configured_root.is_absolute():
            return configured_root
        return Path(root_dir) / configured_root
    test_log_root = os.environ.get(_TEST_LOG_ROOT_ENV)
    if test_log_root:
        return Path(test_log_root) / algo_log_name
    return Path(root_dir) / "logs" / algo_log_name
def get_latest_run(log_dir: str | Path) -> Path | None:
    """Return the lexicographically latest run directory under a task log root.

    Args:
        log_dir: Task log root whose immediate sub-directories are runs.

    Returns:
        The sub-directory that sorts last, or ``None`` when ``log_dir`` is missing,
        is not a directory, or contains no sub-directories.
    """
    base_dir = Path(log_dir)
    # is_dir() rather than exists(): a regular file here would otherwise make
    # iterdir() raise NotADirectoryError instead of reporting "no run".
    if not base_dir.is_dir():
        return None
    return max((path for path in base_dir.iterdir() if path.is_dir()), default=None)
def get_latest_checkpoint(run_dir: str | Path, *, suffix: str = ".pt") -> Path | None:
    """Return the latest ``model_<iteration><suffix>`` checkpoint inside a run directory.

    Only regular files whose name starts with ``model_`` and whose final suffix equals
    ``suffix`` are considered. Files with a plain numeric iteration rank by that
    number; any other ``model_*`` file (e.g. ``model_best.pt``) ranks below them. Ties
    are broken by file name so the result does not depend on directory order.

    Args:
        run_dir: Run directory to scan.
        suffix: Checkpoint file extension, including the leading dot.

    Returns:
        The selected checkpoint path, or ``None`` when ``run_dir`` is missing, is not a
        directory, or holds no matching files.
    """
    run_path = Path(run_dir)
    # is_dir() rather than exists(): iterdir() on a regular file would raise.
    if not run_path.is_dir():
        return None

    def _rank(path: Path) -> tuple[int, str]:
        # Text after the first "_" of the stem is the iteration. Require plain ASCII
        # digits: int() alone would accept "1_000" (PEP 515 underscores) or padded
        # whitespace and mis-rank such names.
        iteration = path.stem.partition("_")[2]
        number = int(iteration) if iteration.isascii() and iteration.isdigit() else -1
        return number, path.name

    model_files = [
        path
        for path in run_path.iterdir()
        if path.is_file() and path.name.startswith("model_") and path.suffix == suffix
    ]
    if not model_files:
        return None
    return max(model_files, key=_rank)
def _normalize_load_run(load_run: str | int | PathLike[str]) -> str: return str(load_run)
def resolve_checkpoint_path(
    base_log_dir: str | Path,
    load_run: str | int | PathLike[str],
    *,
    suffix: str = ".pt",
) -> tuple[Path | None, Path | None]:
    """Resolve a checkpoint from a task log root.

    ``load_run == -1`` selects the lexicographically latest run. Any other value is
    tried first as a path on its own and, if that does not exist, relative to
    ``base_log_dir``. A file is returned as-is together with its parent; a directory
    has its latest checkpoint selected.

    Returns:
        ``(checkpoint, run_dir)`` on success, otherwise ``(None, None)``.
    """
    base_dir = Path(base_log_dir)
    run_name = _normalize_load_run(load_run)

    if run_name == "-1":
        run_dir = get_latest_run(base_dir)
    else:
        target = Path(run_name)
        if not target.exists():
            target = base_dir / run_name
        if target.is_file():
            return target, target.parent
        run_dir = target if target.is_dir() else None

    if run_dir is None:
        return None, None
    latest = get_latest_checkpoint(run_dir, suffix=suffix)
    if latest is None:
        return None, None
    return latest, run_dir
def parse_checkpoint_path(
    cfg: DictConfig,
    *,
    root_dir: str | Path,
    load_run: str | int | PathLike[str] | None = None,
    task_name: str | None = None,
    checkpoint: str | int | None = None,
    suffix: str = ".pt",
) -> tuple[Path | None, Path | None]:
    """Resolve a checkpoint path from a Hydra config and the repository root.

    Explicit keyword arguments take precedence over the config values
    ``training.task_name``, ``algo.load_run`` and ``algo.checkpoint``.
    A missing, null or empty ``algo.load_run`` selects the latest run (``-1``).
    A checkpoint of ``None``, ``""``, ``-1`` or ``"-1"`` selects the latest checkpoint.

    Returns:
        The ``(checkpoint, run_dir)`` pair produced by ``resolve_task_checkpoint_path``.
    """
    selected_task = task_name or str(OmegaConf.select(cfg, "training.task_name"))

    if load_run is not None:
        selected_run = _normalize_load_run(load_run)
    else:
        configured_run = OmegaConf.select(cfg, "algo.load_run", default="-1")
        # ``load_run: null`` (or "") in the config means "latest", like the -1 sentinel.
        # Without this, str(None) would search for a run directory named "None".
        selected_run = "-1" if configured_run in (None, "") else str(configured_run)

    selected_checkpoint = checkpoint
    if selected_checkpoint is None:
        selected_checkpoint = OmegaConf.select(cfg, "algo.checkpoint", default=-1)
    if selected_checkpoint in (None, "", -1, "-1"):
        selected_checkpoint = None

    # Normalize training.log_root exactly like get_log_root: a falsy value means
    # "unset", and a set value is stringified before becoming a Path. Checkpoints are
    # then looked up in the same place training wrote them.
    configured_log_root = OmegaConf.select(cfg, "training.log_root")

    return resolve_task_checkpoint_path(
        root_dir,
        task_name=selected_task,
        load_run=selected_run,
        algo_log_name=str(OmegaConf.select(cfg, "algo.algo_log_name")),
        checkpoint=str(selected_checkpoint) if selected_checkpoint is not None else None,
        suffix=suffix,
        log_root=str(configured_log_root) if configured_log_root else None,
    )
def resolve_task_checkpoint_path(
    root_dir: str | Path,
    *,
    task_name: str,
    load_run: str | int | PathLike[str],
    algo_log_name: str,
    checkpoint: str | None = None,
    suffix: str = ".pt",
    log_root: str | Path | None = None,
) -> tuple[Path | None, Path | None]:
    """Resolve checkpoint paths for auxiliary entrypoints through shared training semantics.

    The task log root is ``get_entrypoint_log_root(...) / task_name``.

    The run is chosen as follows:
      - ``load_run == -1`` selects the lexicographically latest run.
      - Any other value is tried as a path on its own, then relative to the task log
        root.
      - If that path is a file, it is returned as the checkpoint along with its parent.

    Inside the run directory, the checkpoint is chosen as follows:
      - ``checkpoint`` of ``None``, ``""`` or ``"-1"`` selects the latest
        ``model_*`` file.
      - An all-digit value ``N`` selects ``model_N<suffix>``.
      - Any other value is used as a file name.

    Returns:
        ``(None, None)`` when no run directory is found. Otherwise
        ``(checkpoint, run_dir)``, where ``checkpoint`` is ``None`` if no matching
        checkpoint file exists.
    """
    task_log_root = (
        get_entrypoint_log_root(
            root_dir,
            algo_log_name=algo_log_name,
            log_root=log_root,
        )
        / task_name
    )

    run_dir: Path | None
    selected_run = _normalize_load_run(load_run)
    if selected_run == "-1":
        run_dir = get_latest_run(task_log_root)
    else:
        candidate = Path(selected_run)
        if not candidate.exists():
            candidate = task_log_root / selected_run
        if candidate.is_file():
            return candidate, candidate.parent
        run_dir = candidate if candidate.is_dir() else None
    if run_dir is None:
        return None, None

    checkpoint_text = None if checkpoint is None else str(checkpoint)
    # "" and "-1" mean "latest", consistent with parse_checkpoint_path. Previously ""
    # resolved to run_dir / "" == run_dir, returning the run directory itself.
    if checkpoint_text in (None, "", "-1"):
        return get_latest_checkpoint(run_dir, suffix=suffix), run_dir

    checkpoint_name = (
        f"model_{checkpoint_text}{suffix}" if checkpoint_text.isdigit() else checkpoint_text
    )
    checkpoint_path = run_dir / checkpoint_name
    # is_file() rather than exists(): a directory must never be reported as a checkpoint.
    return (checkpoint_path if checkpoint_path.is_file() else None), run_dir