# Source code for unilab.logging.offpolicy

"""Rich-based training logger for off-policy RL algorithms (SAC, TD3, etc)."""

from __future__ import annotations

from collections import deque
from typing import Any

from rich import box
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from unilab.logging.common import BaseTrainingLogger, _fmt_number, _load_wandb

# Sort rank for collector timing keys in the timing table: lower ranks are
# listed first. Keys that are missing here are ranked after every known key
# (the lookup falls back to ``len(OFFPOLICY_COLLECTOR_TIMING_ORDER)``) and are
# then ordered by key name.
OFFPOLICY_COLLECTOR_TIMING_ORDER = {
    "weight_sync_ms": 0,
    "action_select_ms": 1,
    "env_step_ms": 2,
    "replay_ms": 3,
    "sync_coordination_ms": 4,
}

# Row labels shown in the timing table for each collector timing key. A key
# without an entry here is displayed using the raw key string.
OFFPOLICY_COLLECTOR_TIMING_LABELS = {
    "weight_sync_ms": "Weight Sync",
    "action_select_ms": "Action Select",
    "env_step_ms": "Env Step",
    "replay_ms": "Replay",
    "sync_coordination_ms": "Sync Coordination",
}


def _metric_backend_key(key: str) -> str:
    """Keep canonical slash metrics intact; namespace legacy flat metrics under train/."""
    return key if "/" in key else f"train/{key}"


def _reward_backend_key(key: str) -> str:
    """Keep canonical reward/* keys intact; namespace bare component names under reward/."""
    return key if key.startswith("reward/") else f"reward/{key}"


def _dedupe_metric_aliases(metrics: dict[str, float] | None) -> dict[str, float] | None:
    """Drop legacy flat APPO aliases when canonical metrics are present."""
    if not metrics:
        return metrics
    normalized = dict(metrics)
    aliases = {
        "surrogate_loss": "loss/policy_loss",
        "value_loss": "loss/value_loss",
        "entropy": "policy/entropy",
        "kl": "ppo/approx_kl",
    }
    for legacy_key, canonical_key in aliases.items():
        if canonical_key in normalized:
            normalized.pop(legacy_key, None)
    return normalized


class OffPolicyLogger(BaseTrainingLogger):
    """Rich logger for off-policy RL algorithms (SAC, TD3, etc).

    Extends ``BaseTrainingLogger`` with replay-buffer, collector-timing and
    learner-timing state, renders that state as a Rich panel, and mirrors each
    training step to the TensorBoard / Weights & Biases backends when the
    corresponding writer or run is set.

    Several attributes read here (``_train_time``, ``_latest_metrics``,
    ``_reward_history``, ``_latest_reward_components``, ``_mean_ep_length``,
    ``_tb_writer``, ``_wandb_run``, ``num_envs``) are not assigned in this
    class; they are expected to come from ``BaseTrainingLogger``.
    """

    def __init__(
        self,
        algo_name: str = "RL",
        max_iterations: int = 1500,
        num_envs: int = 4096,
        env_name: str = "",
        obs_dim: int = 0,
        action_dim: int = 0,
        refresh_per_second: int = 4,
        log_dir: str = "",
        log_backend: str = "tensorboard",
        wandb_project: str = "unilab",
        wandb_entity: str | None = None,
        wandb_name: str = "",
        wandb_group: str | None = None,
        wandb_job_type: str | None = None,
        wandb_tags: list[str] | None = None,
        wandb_notes: str | None = None,
    ):
        """Initialise the base logger and the off-policy specific state.

        Every argument except ``obs_dim`` and ``action_dim`` is forwarded
        unchanged to ``BaseTrainingLogger.__init__``. ``obs_dim`` and
        ``action_dim`` are stored on the instance and, together with
        ``max_iterations``, passed to the base class as ``wandb_config``.
        """
        super().__init__(
            algo_name=algo_name,
            max_iterations=max_iterations,
            num_envs=num_envs,
            env_name=env_name,
            log_dir=log_dir,
            log_backend=log_backend,
            wandb_project=wandb_project,
            wandb_entity=wandb_entity,
            wandb_name=wandb_name,
            wandb_group=wandb_group,
            wandb_job_type=wandb_job_type,
            wandb_tags=wandb_tags,
            wandb_notes=wandb_notes,
            refresh_per_second=refresh_per_second,
            tensorboard_subdir=None,
            wandb_config={
                "obs_dim": obs_dim,
                "action_dim": action_dim,
                "max_iterations": max_iterations,
            },
        )
        self.obs_dim = obs_dim
        self.action_dim = action_dim

        # Collector / replay-buffer progress.
        self._total_steps: int = 0
        self._buffer_size: int = 0
        self._buffer_target: int = 0

        # Learner timings for the latest iteration. They are multiplied by
        # 1000 wherever they are reported as milliseconds.
        self._wait_time: float = 0.0
        self._learner_incremental_h2d_time: float = 0.0
        self._weight_sync_time: float = 0.0
        self._throughput_steps: int = 0
        self._has_iteration_extra_info: bool = False
        self._iter_times: deque[float] = deque(maxlen=50)

        # Collector timings, reported as-is under "*_ms" keys.
        self._collector_timing: dict[str, float] = {}

        # Episode / system statistics shown in the "System" columns.
        self._timeout_rate: float = 0.0
        self._terminated_rate: float = 0.0
        self._buffer_utilization: float = 0.0
        self._sync_collection: bool = False
        self._env_steps_per_sync: int = 0
        self._staging_pool_len: int = 0
        self._staging_pool_max: int = 0

        self._status: str = "Initializing..."
        # Set by the first log_step(); from then on buffer-fill and ordinary
        # status updates no longer trigger their own refresh.
        self._terminal_refresh_started: bool = False

    def _format_tensorboard_message(self, tb_dir: str) -> str:
        """Return the Rich-markup line announcing the TensorBoard directory."""
        return f"[dim]TensorBoard logging to: {tb_dir}[/]"

    def _format_wandb_message(self, project: str, name: str) -> str:
        """Return the Rich-markup line announcing the W&B project and run."""
        return f"[dim]W&B logging to project: {project}, run: {name}[/]"

    def start(self, *, status: str = "Warming up..."):
        """Start the base logger, using ``"Warming up..."`` as default status."""
        super().start(status=status)

    def finish(self, *, title: str = "Training Summary", extra_summary: str = ""):
        """Finish the base logger, prepending the total env-step count.

        The total-steps line is placed before the caller's ``extra_summary``.
        """
        super().finish(
            title=title,
            extra_summary=f" Total env steps: [yellow]{self._total_steps:,}[/]\n{extra_summary}",
        )

    def log_buffer_fill(self, current: int, target: int):
        """Record replay-buffer fill progress and show it as the status line.

        The display is refreshed from here only until the first ``log_step``
        call has happened.
        """
        self._buffer_size = current
        self._buffer_target = target
        # max(target, 1) avoids a division by zero for a zero target.
        pct = current / max(target, 1) * 100
        self._status = f"Buffer fill: {current:,}/{target:,} ({pct:.0f}%)"
        if not self._terminal_refresh_started:
            self._refresh()

    def _get_iter_steps_per_sec(self) -> float | None:
        """Return throughput steps per second for the latest iteration.

        Returns ``None`` when the latest ``log_step`` carried no ``extra_info``,
        when the reported throughput step count is not positive, or when the
        pipeline time is not positive.
        """
        if not self._has_iteration_extra_info or self._throughput_steps <= 0:
            return None
        iter_time = self._get_iter_pipeline_time()
        if iter_time <= 0:
            return None
        return self._throughput_steps / iter_time

    def _get_iter_pipeline_time(self) -> float:
        """Return H2D + train + weight-sync time of the latest iteration.

        The learner wait time is not part of this sum.
        """
        return self._learner_incremental_h2d_time + self._train_time + self._weight_sync_time

    def _build_compact_header(
        self,
        *,
        include_status: bool,
        extra_fields: list[tuple[str, str]] | None = None,
    ) -> Text:
        """Build the base header with a leading ``Steps/s`` field when known.

        The ``Steps/s`` field is placed before any caller-supplied
        ``extra_fields``; it is omitted when no throughput figure is available.
        """
        header_extra_fields: list[tuple[str, str]] = []
        iter_steps_per_sec = self._get_iter_steps_per_sec()
        if iter_steps_per_sec is not None:
            header_extra_fields.append((f"Steps/s {iter_steps_per_sec:,.0f}", "bold green"))
        if extra_fields:
            header_extra_fields.extend(extra_fields)
        return super()._build_compact_header(
            include_status=include_status,
            extra_fields=header_extra_fields,
        )

    def update_collector_timing(self, timing_ms: dict[str, float]):
        """Merge new collector timings into the stored ones (keys overwrite)."""
        self._collector_timing.update(timing_ms)

    def update_done_rates(self, timeout_rate: float, terminated_rate: float):
        """Store the timeout and terminated rates.

        The rates are multiplied by 100 when displayed as percentages.
        """
        self._timeout_rate = float(timeout_rate)
        self._terminated_rate = float(terminated_rate)

    def update_buffer_utilization(self, utilization: float):
        """Store the replay-buffer utilization value."""
        self._buffer_utilization = float(utilization)

    def update_replay_queue(self, current_len: int, max_size: int):
        """Alias of :meth:`update_staging_pool`."""
        self.update_staging_pool(current_len, max_size)

    def update_staging_pool(self, current_len: int, max_size: int):
        """Store the current and maximum staging-pool sizes."""
        self._staging_pool_len = current_len
        self._staging_pool_max = max_size

    def set_collection_sync(self, enabled: bool, env_steps_per_sync: int = 0):
        """Store whether synchronous collection is on and its steps per sync."""
        self._sync_collection = enabled
        self._env_steps_per_sync = env_steps_per_sync

    def log_collector(self, total_steps: int, buffer_size: int, mean_reward: float = 0.0):
        """Record collector progress and, optionally, a mean reward sample.

        A ``mean_reward`` equal to 0 (the default) is not appended to the
        reward history.
        """
        self._total_steps = total_steps
        self._buffer_size = buffer_size
        if mean_reward != 0:
            self._reward_history.append(mean_reward)

    def log_step(
        self,
        iteration: int,
        metrics: dict[str, float] | None = None,
        reward: float | None = None,
        reward_metrics: dict[str, float] | None = None,
        reward_components: dict[str, float] | None = None,
        train_time: float = 0.0,
        wait_time: float = 0.0,
        learner_incremental_h2d_time: float = 0.0,
        weight_sync_time: float = 0.0,
        extra_info: dict | None = None,
    ):
        """Record one learner iteration, refresh the display, log to backends.

        Args:
            iteration: Learner iteration index.
            metrics: Training metrics; legacy aliases of canonical metrics are
                dropped before use.
            reward: Mean reward sample appended to the reward history.
            reward_metrics: Extra reward scalars; only sent to the backends.
            reward_components: Per-component rewards; stored for display and
                sent to the backends.
            train_time: Train time of this iteration.
            wait_time: Learner wait time of this iteration.
            learner_incremental_h2d_time: H2D time of this iteration.
            weight_sync_time: Weight-sync time of this iteration.
            extra_info: Optional dict; its ``"throughput_steps"`` entry feeds
                the steps-per-second figure.
        """
        metrics = _dedupe_metric_aliases(metrics)

        self._iteration = iteration
        self._train_time = train_time
        self._wait_time = wait_time
        self._learner_incremental_h2d_time = learner_incremental_h2d_time
        self._weight_sync_time = weight_sync_time

        # An empty dict still counts as "extra info supplied", but yields a
        # throughput of 0, which disables the steps-per-second figure.
        self._has_iteration_extra_info = extra_info is not None
        self._throughput_steps = (
            int(extra_info.get("throughput_steps", 0)) if extra_info else 0
        )
        self._iter_times.append(self._get_iter_pipeline_time())

        if metrics:
            self._latest_metrics.update(metrics)
        if reward is not None:
            self._reward_history.append(reward)
        if reward_components:
            self._latest_reward_components = reward_components

        self._status = "Training"
        self._terminal_refresh_started = True
        self._refresh()
        self._backend_log_step(
            iteration,
            metrics,
            reward,
            reward_metrics,
            reward_components,
            train_time,
        )

    def _collect_backend_scalars(
        self,
        iteration: int,
        global_step: int,
        metrics: dict[str, float] | None,
        reward: float | None,
        reward_metrics: dict[str, float] | None,
        reward_components: dict[str, float] | None,
        train_time: float,
    ) -> dict[str, Any]:
        """Assemble the tag -> value mapping shared by all logging backends.

        Later entries overwrite earlier ones that map to the same tag, so a
        reward component wins over a reward metric of the same name.
        """
        scalars: dict[str, Any] = {
            "axis/iteration": float(iteration),
            "axis/env_steps_total": float(global_step),
        }
        if metrics:
            for key, value in metrics.items():
                scalars[_metric_backend_key(key)] = value
        if reward is not None:
            scalars["reward/mean"] = reward
        for reward_group in (reward_metrics, reward_components):
            if reward_group:
                for key, value in reward_group.items():
                    scalars[_reward_backend_key(key)] = value

        # The episode length is only reported once a positive value exists.
        if self._mean_ep_length > 0:
            scalars["episode/length"] = self._mean_ep_length
        scalars["episode/timeout_rate"] = self._timeout_rate
        scalars["episode/terminated_rate"] = self._terminated_rate

        scalars["timing/learner_wait_ms"] = self._wait_time * 1000
        scalars["timing/learner_incremental_h2d_ms"] = (
            self._learner_incremental_h2d_time * 1000
        )
        scalars["timing/learner_train_ms"] = train_time * 1000
        scalars["timing/learner_weight_sync_ms"] = self._weight_sync_time * 1000
        for key, value in self._collector_timing.items():
            scalars[f"timing/collector_{key}"] = value

        iter_steps_per_sec = self._get_iter_steps_per_sec()
        if iter_steps_per_sec is not None:
            scalars["perf/steps_per_sec"] = iter_steps_per_sec
        scalars["perf/iter_ms"] = self._get_iter_pipeline_time() * 1000
        return scalars

    def _backend_log_step(
        self,
        iteration: int,
        metrics: dict[str, float] | None,
        reward: float | None,
        reward_metrics: dict[str, float] | None,
        reward_components: dict[str, float] | None,
        train_time: float,
    ):
        """Write the scalars of one iteration to TensorBoard and/or W&B.

        The x-axis step is the total env-step count when it is positive and
        the iteration index otherwise. Both backends receive the same scalar
        set; the W&B payload additionally carries an ``"iteration"`` entry.
        When two inputs map to the same tag, a single (last) value is logged.
        """
        if not self._tb_writer and not self._wandb_run:
            return

        global_step = self._total_steps if self._total_steps > 0 else iteration
        scalars = self._collect_backend_scalars(
            iteration,
            global_step,
            metrics,
            reward,
            reward_metrics,
            reward_components,
            train_time,
        )

        if self._tb_writer:
            writer = self._tb_writer
            for tag, value in scalars.items():
                writer.add_scalar(tag, value, global_step)

        if self._wandb_run:
            wandb = _load_wandb()
            if wandb is None:
                return
            log_dict: dict[str, Any] = {"iteration": iteration, **scalars}
            wandb.log(log_dict, step=global_step)

    def log_status(self, status: str):
        """Set the status line.

        A refresh is forced only before the first ``log_step`` call, or when
        the status contains ``"[red]"`` or ``"ERROR"``.
        """
        self._status = status
        if not self._terminal_refresh_started or "[red]" in status or "ERROR" in status:
            self._refresh(force=True)

    def _build_display(self) -> Panel:
        """Compose the full panel: header, metrics | rewards grid, timings."""
        header = self._build_compact_header(include_status=True)
        left = self._build_metrics_table()
        right = self._build_reward_table()
        bottom = self._build_timing_table()

        grid = Table.grid(expand=True)
        grid.add_column(ratio=1)
        grid.add_column(width=2)  # fixed-width spacer between the two tables
        grid.add_column(ratio=1)
        grid.add_row(left, "", right)

        return Panel(
            Group(header, Text(""), grid, Text(""), bottom),
            title="[bold] 🚀 UniLab Off-Policy Training [/]",
            border_style="bright_blue",
            padding=(0, 1),
        )

    def _build_metrics_table(self) -> Table:
        """Build the metrics table: loss rows first, then all other metrics.

        A metric is a "loss" when its name contains ``loss`` (case-insensitive).
        Loss values above 10 are rendered in red, other loss values in yellow.
        Both groups are sorted by name.
        """
        table = Table(
            box=box.SIMPLE_HEAVY,
            show_header=True,
            show_edge=False,
            header_style="bold cyan",
            expand=True,
            pad_edge=False,
        )
        table.add_column("Losses & Metrics", style="white", ratio=2)
        table.add_column("Value", style="yellow", justify="right", ratio=1)

        if not self._latest_metrics:
            table.add_row("[dim]Waiting for data...[/]", "")
            return table

        loss_keys = sorted(key for key in self._latest_metrics if "loss" in key.lower())
        other_keys = sorted(key for key in self._latest_metrics if "loss" not in key.lower())
        for key in loss_keys:
            value = self._latest_metrics[key]
            style = "red" if value > 10 else "yellow"
            table.add_row(key.replace("_", " ").title(), f"[{style}]{_fmt_number(value)}[/]")
        for key in other_keys:
            value = self._latest_metrics[key]
            table.add_row(f" {key.replace('_', ' ').title()}", _fmt_number(value))
        return table

    def _build_reward_table(self) -> Table:
        """Build the reward table via the base helper, without episode length."""
        return self._build_reward_table_common(
            wait_message="[dim]Waiting for data...[/]",
            include_ep_length=False,
        )

    def _build_timing_table(self) -> Table:
        """Build the three-group timing table: Learner, Collector and System.

        Each group takes a label column and a value column. The groups can
        have different numbers of rows; shorter groups are padded with empty
        cells.
        """
        table = Table(
            box=box.SIMPLE_HEAVY,
            show_header=True,
            show_edge=False,
            header_style="bold blue",
            expand=True,
            pad_edge=False,
        )
        for group_title in ("Learner", "Collector", "System"):
            table.add_column(group_title, style="white", ratio=2, no_wrap=True)
            table.add_column("Value", style="yellow", justify="right", ratio=1, no_wrap=True)

        wait_ms = self._wait_time * 1000
        # A learner wait above 1 ms is highlighted in red.
        wait_color = "red" if wait_ms > 1.0 else "yellow"
        learner_items = [
            ("Wait", f"[{wait_color}]{wait_ms:.1f}ms[/]"),
            ("H2D", f"{self._learner_incremental_h2d_time * 1000:.1f}ms"),
            ("Train", f"{self._train_time * 1000:.1f}ms"),
            ("Weight Sync", f"{self._weight_sync_time * 1000:.1f}ms"),
        ]

        # Known keys follow OFFPOLICY_COLLECTOR_TIMING_ORDER; unknown keys
        # come after them, ordered by key name.
        unknown_rank = len(OFFPOLICY_COLLECTOR_TIMING_ORDER)
        collector_items = [
            (OFFPOLICY_COLLECTOR_TIMING_LABELS.get(key, key), f"{value:.1f}ms")
            for key, value in sorted(
                self._collector_timing.items(),
                key=lambda item: (
                    OFFPOLICY_COLLECTOR_TIMING_ORDER.get(item[0], unknown_rank),
                    item[0],
                ),
            )
        ]

        sync_collect = f"✓ ({self._env_steps_per_sync})" if self._sync_collection else "✗"
        system_items = [
            ("Buffer", f"{self._buffer_size:,}"),
            ("Timeout Rate", f"{self._timeout_rate * 100:.1f}%"),
            ("Terminated Rate", f"{self._terminated_rate * 100:.1f}%"),
            ("Envs", f"{self.num_envs:,}"),
            ("Sync Collect", sync_collect),
        ]
        if self._staging_pool_max > 0:
            # Green while the pool is below its maximum, yellow once it is full.
            staging_color = "green" if self._staging_pool_len < self._staging_pool_max else "yellow"
            system_items.append(
                (
                    "Staging Pool",
                    f"[{staging_color}]{self._staging_pool_len}/{self._staging_pool_max}[/]",
                )
            )

        groups = (learner_items, collector_items, system_items)
        row_count = max(len(items) for items in groups)
        for index in range(row_count):
            row: list[str] = []
            for items in groups:
                if index < len(items):
                    row.extend(items[index])
                else:
                    row.extend(["", ""])
            table.add_row(*row)
        return table