Source code for unilab.training.experiment

"""Shared experiment tracking utilities for local files and W&B."""

from __future__ import annotations

import dataclasses
import getpass
import importlib
import importlib.util
import json
import os
import platform
import socket
import subprocess
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from omegaconf import OmegaConf


def _cfg_get(cfg: Any, key: str, default: Any = None) -> Any:
    """Read ``key`` from a config that may be ``None``, a ``dict`` or an object.

    Plain dictionaries are queried with ``dict.get``; any other non-``None``
    value is queried with ``getattr``. ``default`` is returned when ``cfg`` is
    ``None`` or the key/attribute is absent.
    """
    if isinstance(cfg, dict):
        return cfg.get(key, default)
    return default if cfg is None else getattr(cfg, key, default)


def _plain_dict(value: Any) -> Any:
    """Convert OmegaConf configs and dataclass instances to plain containers.

    OmegaConf configs are converted with interpolations resolved; dataclass
    instances (not dataclass types) go through ``dataclasses.asdict``. Any
    other value is returned untouched.
    """
    if OmegaConf.is_config(value):
        return OmegaConf.to_container(value, resolve=True)
    is_dataclass_instance = not isinstance(value, type) and dataclasses.is_dataclass(value)
    return dataclasses.asdict(value) if is_dataclass_instance else value


def _load_wandb() -> Any | None:
    try:
        return importlib.import_module("wandb")
    except ImportError:
        return None


def _json_safe(value: Any) -> Any:
    """Recursively coerce ``value`` into something ``json.dumps`` accepts.

    * ``None`` and ``str``/``int``/``float``/``bool`` values pass through.
    * ``Path`` objects become strings.
    * Dictionaries are rebuilt with stringified keys and sanitised values.
    * Lists and tuples become lists of sanitised items.
    * Anything else is kept if ``json.dumps`` accepts it, otherwise replaced
      by its ``str()`` form.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        return {str(key): _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    try:
        # Probe only; the encoded text itself is discarded.
        json.dumps(value)
    except TypeError:
        return str(value)
    return value


def _fallback_device_info_dict() -> dict[str, str]:
    """Describe the host using only the standard library.

    GPU name and memory are always reported as ``"unknown"``; the chip name
    and core count fall back to ``"unknown"`` when the platform gives no value.
    """
    chip_name = platform.processor()
    core_count = os.cpu_count()
    info = {
        "platform": platform.platform(),
        "chip": chip_name if chip_name else "unknown",
        "cpu_total_cores": str(core_count if core_count else "unknown"),
        "gpu_name": "unknown",
        "memory": "unknown",
    }
    return info


def _benchmark_device_info_path() -> Path | None:
    for parent in Path(__file__).resolve().parents:
        candidate = parent / "benchmark" / "core" / "device_info.py"
        if candidate.is_file():
            return candidate
    return None


def get_device_info_dict() -> dict[str, str]:
    """Return hardware information for the current host.

    When a ``benchmark/core/device_info.py`` file is found it is loaded
    directly from its path and its ``get_device_info_dict()`` result is
    returned as a new ``dict``. If the file is missing, or anything goes wrong
    while loading or calling it, the standard-library fallback is used.
    """
    try:
        source_file = _benchmark_device_info_path()
        if source_file is not None:
            module_spec = importlib.util.spec_from_file_location(
                "unilab_benchmark_device_info",
                source_file,
            )
            if module_spec is None or module_spec.loader is None:
                raise ImportError(f"Unable to load device info module from {source_file}")
            device_module = importlib.util.module_from_spec(module_spec)
            module_spec.loader.exec_module(device_module)
            return dict(device_module.get_device_info_dict())
    except Exception:
        # Deliberately best-effort: hardware details must never break a run.
        pass
    return _fallback_device_info_dict()
[docs] def get_git_info(root_dir: str | Path) -> dict[str, Any]: root = Path(root_dir) def _run_git(*args: str) -> str | None: try: result = subprocess.run( ["git", *args], cwd=root, check=True, capture_output=True, text=True, ) except Exception: return None return result.stdout.strip() commit = _run_git("rev-parse", "HEAD") branch = _run_git("rev-parse", "--abbrev-ref", "HEAD") status = _run_git("status", "--short") return { "commit": commit, "branch": branch, "dirty": bool(status), }
[docs] def build_wandb_run_name(algo_name: str, task_name: str, log_dir: str | Path | None) -> str: if log_dir is None: return f"{algo_name}__{task_name}" run_dir = Path(log_dir) return f"{algo_name}__{task_name}__{run_dir.name}"
def build_wandb_settings(
    training_cfg: Any,
    *,
    algo_name: str,
    task_name: str,
    sim_backend: str,
    log_dir: str | Path | None,
) -> dict[str, Any]:
    """Resolve the W&B run settings for a training run.

    Values come from the ``wandb_*`` entries of ``training_cfg``. Empty or
    missing entries fall back as follows: ``name`` to
    ``build_wandb_run_name(...)``, ``group`` to ``task_name``, ``job_type`` to
    ``algo_name`` and ``project`` to ``"unilab"``.

    Tags are the configured ``wandb_tags`` followed by the automatic tags
    (algorithm, task, simulator backend, ``user-<login>``) that are not
    already present. A bare string in ``wandb_tags`` is treated as a single
    tag rather than being split into characters.

    Returns:
        A dict with the keys ``project``, ``entity``, ``name``, ``group``,
        ``job_type``, ``tags``, ``notes`` and ``mode``.
    """
    name = _cfg_get(training_cfg, "wandb_name") or build_wandb_run_name(
        algo_name, task_name, log_dir
    )
    group = _cfg_get(training_cfg, "wandb_group") or task_name
    job_type = _cfg_get(training_cfg, "wandb_job_type") or algo_name

    configured_tags = _cfg_get(training_cfg, "wandb_tags", []) or []
    if isinstance(configured_tags, str):
        # Iterating a string would produce one tag per character.
        configured_tags = [configured_tags]
    tags = [str(tag) for tag in configured_tags]

    try:
        user = getpass.getuser()
    except Exception:
        # getuser() raises when no login name can be determined (e.g. a
        # process whose UID has no password-database entry); the exception
        # type differs between Python versions. A tag must not abort the run.
        user = "unknown"

    for tag in (algo_name, task_name, sim_backend, f"user-{user}"):
        if tag not in tags:
            tags.append(tag)

    return {
        "project": _cfg_get(training_cfg, "wandb_project", "unilab"),
        "entity": _cfg_get(training_cfg, "wandb_entity"),
        "name": name,
        "group": group,
        "job_type": job_type,
        "tags": tags,
        "notes": _cfg_get(training_cfg, "wandb_notes"),
        "mode": _cfg_get(training_cfg, "wandb_mode"),
    }
[docs] class ExperimentTracker: """Tracks experiment metadata locally and optionally in Weights & Biases."""
[docs] def __init__( self, *, root_dir: str | Path, log_dir: str | Path, algo_name: str, task_name: str, sim_backend: str, training_cfg: Any, full_cfg: Any, device: str | None = None, collector_device: str | None = None, seed_info: Any | None = None, ): self.root_dir = Path(root_dir) self.log_dir = Path(log_dir) self.algo_name = algo_name self.task_name = task_name self.sim_backend = sim_backend self.training_cfg = training_cfg self.full_cfg = full_cfg self.device = device self.collector_device = collector_device self.seed_info = seed_info self.enabled = str(_cfg_get(training_cfg, "logger", "tensorboard")).lower() == "wandb" self.log_dir.mkdir(parents=True, exist_ok=True) self._wandb = None self._run = None self._owns_run = False self._started = False self._start_monotonic = 0.0 self._start_utc = "" self._summary: dict[str, Any] = {}
@property def run(self) -> Any | None: return self._run @property def run_url(self) -> str | None: return getattr(self._run, "url", None) if self._run is not None else None @property def wandb_settings(self) -> dict[str, Any]: return build_wandb_settings( self.training_cfg, algo_name=self.algo_name, task_name=self.task_name, sim_backend=self.sim_backend, log_dir=self.log_dir, )
[docs] def start(self) -> None: if self._started: return self._started = True self._start_monotonic = time.perf_counter() self._start_utc = datetime.now(timezone.utc).isoformat() metadata = { "algo": self.algo_name, "task": self.task_name, "sim_backend": self.sim_backend, "device": self.device, "collector_device": self.collector_device, "log_dir": str(self.log_dir), "start_time_utc": self._start_utc, "hostname": socket.gethostname(), "user": getpass.getuser(), "git": get_git_info(self.root_dir), "hardware": get_device_info_dict(), "wandb": self.wandb_settings, } if self.seed_info is not None: if hasattr(self.seed_info, "to_dict"): seed_payload = self.seed_info.to_dict() elif isinstance(self.seed_info, dict): seed_payload = dict(self.seed_info) else: seed_payload = {"effective_seed": self.seed_info} metadata.update(seed_payload) payload = { "run": _json_safe(metadata), "config": _json_safe(_plain_dict(self.full_cfg)), } self._write_json(self.log_dir / "run_config.json", payload) if not self.enabled: return self._wandb = _load_wandb() if self._wandb is None: print("[experiment_tracking] wandb not installed, skipping W&B experiment tracking.") return self._run = self._wandb.run if self._run is None: kwargs = { "project": self.wandb_settings["project"], "name": self.wandb_settings["name"], "config": payload, "dir": str(self.log_dir), "reinit": True, } for key in ("entity", "group", "job_type", "tags", "notes", "mode"): value = self.wandb_settings.get(key) if value not in (None, "", []): kwargs[key] = value self._run = self._wandb.init(**kwargs) self._owns_run = True else: self._run.config.update(payload, allow_val_change=True) if self._run is not None: self._run.summary["algo"] = self.algo_name self._run.summary["task"] = self.task_name self._run.summary["sim_backend"] = self.sim_backend if self.device: self._run.summary["device"] = self.device if self.collector_device: self._run.summary["collector_device"] = self.collector_device self._run.summary["log_dir"] = 
str(self.log_dir)
[docs] def update_summary(self, summary: dict[str, Any] | None = None) -> None: if summary: self._summary.update(summary) if not self._started: return wall_time_sec = time.perf_counter() - self._start_monotonic payload = { **self._summary, "algo": self.algo_name, "task": self.task_name, "sim_backend": self.sim_backend, "log_dir": str(self.log_dir), "start_time_utc": self._start_utc, "end_time_utc": datetime.now(timezone.utc).isoformat(), "wall_time_sec": wall_time_sec, "wandb_run_url": self.run_url, } if self.seed_info is not None: if hasattr(self.seed_info, "to_dict"): payload.update(self.seed_info.to_dict()) elif isinstance(self.seed_info, dict): payload.update(self.seed_info) else: payload["effective_seed"] = self.seed_info self._write_json(self.log_dir / "run_summary.json", _json_safe(payload)) if self._run is not None: for key, value in payload.items(): self._run.summary[key] = _json_safe(value)
[docs] def log_video(self, video_path: str | Path | None, key: str = "media/play_video") -> None: if video_path is None: return video = Path(video_path) if not video.exists(): return self._summary["play_video_path"] = str(video) if self._run is not None and self._wandb is not None: self._wandb.log({key: self._wandb.Video(str(video), format="mp4")})
[docs] def finish(self) -> None: if not self._started: return self.update_summary() if self._run is not None and self._wandb is not None and self._owns_run: self._wandb.finish() self._run = None self._wandb = None
@staticmethod def _write_json(path: Path, payload: Any) -> None: path.parent.mkdir(parents=True, exist_ok=True) path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
def patch_rsl_rl_wandb_writer() -> None:
    """Patch rsl-rl W&B writer so it can reuse an already-open run.

    Replaces ``rsl_rl.utils.wandb_utils.WandbSummaryWriter`` with a writer
    that calls ``wandb.init`` only when no run is active and finishes the run
    only if it started it. Returns without patching when rsl-rl or wandb
    cannot be imported, or when the patch was applied before.
    """
    try:
        import rsl_rl.utils.wandb_utils as wandb_utils
    except Exception:
        return
    if getattr(wandb_utils, "_UNILAB_PATCHED", False):
        return

    wandb_mod = _load_wandb()
    if wandb_mod is None:
        return

    from torch.utils.tensorboard import SummaryWriter as TensorboardSummaryWriter

    class PatchedWandbSummaryWriter(TensorboardSummaryWriter):
        """TensorBoard writer that also logs to W&B."""

        def __init__(self, log_dir: str, flush_secs: int, cfg: dict) -> None:
            super().__init__(log_dir, flush_secs=flush_secs)
            init_kwargs = {
                "project": cfg.get("wandb_project", "unilab"),
                "name": os.path.split(log_dir)[-1],
                "config": {"log_dir": log_dir},
            }
            # Optional settings, in the order they are handed to wandb.init.
            optional = {
                "entity": cfg.get("wandb_entity") or os.environ.get("WANDB_USERNAME"),
                "group": cfg.get("wandb_group"),
                "job_type": cfg.get("wandb_job_type"),
                "tags": cfg.get("wandb_tags"),
                "notes": cfg.get("wandb_notes"),
                "mode": cfg.get("wandb_mode"),
            }
            self.logged_videos: set[str] = set()
            self._owns_run = wandb_mod.run is None
            if not self._owns_run:
                # Someone else opened the run: only record where we log.
                wandb_mod.config.update({"log_dir": log_dir}, allow_val_change=True)
                return
            init_kwargs["settings"] = wandb_mod.Settings(start_method="thread")
            init_kwargs.update({key: value for key, value in optional.items() if value})
            wandb_mod.init(**init_kwargs)

        def store_config(self, env_cfg: dict | object, train_cfg: dict) -> None:
            wandb_mod.config.update({"train_cfg": train_cfg}, allow_val_change=True)
            env_payload: Any = env_cfg
            if not isinstance(env_cfg, dict):
                if dataclasses.is_dataclass(env_cfg) and not isinstance(env_cfg, type):
                    env_payload = dataclasses.asdict(env_cfg)
                elif hasattr(env_cfg, "to_dict"):
                    env_payload = env_cfg.to_dict()  # type: ignore[union-attr]
                else:
                    env_payload = str(env_cfg)
            wandb_mod.config.update({"env_cfg": env_payload}, allow_val_change=True)

        def add_scalar(
            self,
            tag: Any,
            scalar_value: Any,
            global_step: Any = None,
            walltime: Any = None,
            new_style: Any = False,
            double_precision: Any = False,
        ) -> None:
            # Write to TensorBoard first, then mirror the value to W&B.
            tensorboard_options = {
                "global_step": global_step,
                "walltime": walltime,
                "new_style": new_style,
                "double_precision": double_precision,
            }
            super().add_scalar(tag, scalar_value, **tensorboard_options)
            wandb_mod.log({tag: scalar_value}, step=global_step)

        def stop(self) -> None:
            if not self._owns_run:
                return
            wandb_mod.finish()

        def save_model(self, model_path: str, it: int) -> None:
            model_dir = os.path.dirname(model_path)
            wandb_mod.save(model_path, base_path=model_dir)

        def save_file(self, path: str) -> None:
            file_dir = os.path.dirname(path)
            wandb_mod.save(path, base_path=file_dir)

        def save_video(self, video: Path, it: int) -> None:
            # Each file name is uploaded at most once per writer.
            if video.name in self.logged_videos:
                return
            wandb_mod.log({"video": wandb_mod.Video(str(video), format="mp4")}, step=it)
            self.logged_videos.add(video.name)

    wandb_utils.WandbSummaryWriter = PatchedWandbSummaryWriter
    wandb_utils._UNILAB_PATCHED = True
def patch_rsl_rl_resume_state() -> None:
    """Persist + restore ``Logger.tot_time`` / ``tot_timesteps`` across resume.

    Without this patch, rsl-rl's ``Logger.__init__`` writes ``tot_time = 0``
    and ``tot_timesteps = 0`` and ``OnPolicyRunner.load`` never refreshes
    them. The ``Train/mean_reward/time`` and
    ``Train/mean_episode_length/time`` TensorBoard scalars use
    ``int(self.tot_time)`` as their step, so they restart from 0 on every
    resumed run and visually overlap the original segment. See issue #441.

    ``OnPolicyRunner.save`` / ``OnPolicyRunner.load`` are replaced by
    versions that round-trip a ``unilab_logger_state`` key in the saved dict.
    Legacy checkpoints (without the key) load unchanged. Nothing happens when
    rsl-rl cannot be imported or the patch was applied before.
    """
    try:
        from rsl_rl.runners.on_policy_runner import OnPolicyRunner
    except Exception:
        return
    if getattr(OnPolicyRunner, "_UNILAB_RESUME_PATCHED", False):
        return

    import torch

    state_key = "unilab_logger_state"

    def _patched_save(self: Any, path: str, infos: dict | None = None) -> None:
        checkpoint = self.alg.save()
        checkpoint["iter"] = self.current_learning_iteration
        checkpoint["infos"] = infos
        elapsed = getattr(self.logger, "tot_time", 0.0)
        timesteps = getattr(self.logger, "tot_timesteps", 0)
        checkpoint[state_key] = {
            "tot_time": float(elapsed),
            "tot_timesteps": int(timesteps),
        }
        torch.save(checkpoint, path)
        self.logger.save_model(path, self.current_learning_iteration)

    def _patched_load(
        self: Any,
        path: str,
        load_cfg: dict | None = None,
        strict: bool = True,
        map_location: str | None = None,
    ) -> Any:
        checkpoint = torch.load(path, weights_only=False, map_location=map_location)
        if self.alg.load(checkpoint, load_cfg, strict):
            self.current_learning_iteration = checkpoint["iter"]
        # NOTE(review): the flattened source does not show whether this
        # restore was nested under the iteration check above; it is applied
        # whenever the key is present, as the docstring describes - confirm.
        logger_state = checkpoint.get(state_key)
        if logger_state is not None:
            self.logger.tot_time = float(logger_state.get("tot_time", 0.0))
            self.logger.tot_timesteps = int(logger_state.get("tot_timesteps", 0))
        return checkpoint["infos"]

    OnPolicyRunner.save = _patched_save  # type: ignore[assignment]
    OnPolicyRunner.load = _patched_load  # type: ignore[assignment]
    OnPolicyRunner._UNILAB_RESUME_PATCHED = True  # type: ignore[attr-defined]