# Source code for unilab.logging.onpolicy

from __future__ import annotations

from typing import Any

from rich import box
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from unilab.logging.common import BaseTrainingLogger, _fmt_number, _load_wandb


class OnPolicyLogger(BaseTrainingLogger):
    """Rich logger for on-policy RL (PPO, A2C, etc.).

    Builds a Rich panel with losses/metrics, the reward breakdown and timing
    information (presumably redrawn by the base class's ``_refresh``), and
    mirrors the same scalars to TensorBoard and/or Weights & Biases when the
    base class has set up ``_tb_writer`` / ``_wandb_run``.
    """

    # Loss values strictly greater than this are rendered in red in the
    # metrics table; override on a subclass or instance to tune it.
    loss_warning_threshold: float = 10.0

    def __init__(
        self,
        algo_name: str = "PPO",
        max_iterations: int = 1500,
        num_envs: int = 4096,
        num_steps: int = 24,
        env_name: str = "",
        log_dir: str = "",
        log_backend: str = "tensorboard",
        wandb_project: str = "unilab",
        wandb_entity: str | None = None,
        wandb_name: str = "",
        wandb_group: str | None = None,
        wandb_job_type: str | None = None,
        wandb_tags: list[str] | None = None,
        wandb_notes: str | None = None,
    ):
        """Create the logger.

        Args:
            algo_name: Algorithm label forwarded to the base logger.
            max_iterations: Forwarded to the base logger.
            num_envs: Number of parallel environments; forwarded to the base
                logger and used for the steps/second figure.
            num_steps: Steps per environment per iteration; multiplied by
                ``num_envs`` to compute steps/second.
            env_name, log_dir, log_backend, wandb_*: Forwarded unchanged to
                ``BaseTrainingLogger`` (which also receives
                ``tensorboard_subdir="tb"``).
        """
        super().__init__(
            algo_name=algo_name,
            max_iterations=max_iterations,
            num_envs=num_envs,
            env_name=env_name,
            log_dir=log_dir,
            log_backend=log_backend,
            wandb_project=wandb_project,
            wandb_entity=wandb_entity,
            wandb_name=wandb_name,
            wandb_group=wandb_group,
            wandb_job_type=wandb_job_type,
            wandb_tags=wandb_tags,
            wandb_notes=wandb_notes,
            tensorboard_subdir="tb",
        )
        self.num_steps = num_steps

    def start(self, *, status: str = ""):
        """Start the logger; delegates to the base class with ``status``."""
        super().start(status=status)

    def finish(self, *, title: str = "Training Summary", extra_summary: str = ""):
        """Finish the logger; delegates to the base class summary rendering."""
        super().finish(title=title, extra_summary=extra_summary)

    def log_step(
        self,
        iteration: int,
        metrics: dict[str, float] | None = None,
        reward: float | None = None,
        reward_components: dict[str, float] | None = None,
        collect_time: float = 0.0,
        train_time: float = 0.0,
    ):
        """Record one training iteration, refresh the display and log scalars.

        Args:
            iteration: Current iteration index; also used as the backend step.
            metrics: Losses/metrics merged into the latest-metrics view (keys
                persist across calls until overwritten).
            reward: Mean reward; appended to the reward history when given.
            reward_components: Per-term reward breakdown; replaces the
                previously displayed breakdown when non-empty.
            collect_time: Rollout collection time in seconds.
            train_time: Learner update time in seconds.
        """
        self._iteration = iteration
        self._collect_time = collect_time
        self._train_time = train_time
        if metrics:
            self._latest_metrics.update(metrics)
        if reward is not None:
            self._reward_history.append(reward)
        if reward_components:
            # Store a copy so later mutation of the caller's dict cannot
            # silently change what the display shows.
            self._latest_reward_components = dict(reward_components)
        self._refresh()
        self._backend_log_step(iteration, metrics, reward, reward_components)

    def _collect_scalars(
        self,
        metrics: dict[str, float] | None,
        reward: float | None,
        reward_components: dict[str, float] | None,
    ) -> dict[str, float]:
        """Build the ordered ``tag -> value`` mapping shared by all backends."""
        scalars: dict[str, float] = {}
        if metrics:
            scalars.update({f"train/{k}": v for k, v in metrics.items()})
        if reward is not None:
            scalars["reward/mean"] = reward
        if reward_components:
            scalars.update({f"reward/{k}": v for k, v in reward_components.items()})
        # _mean_ep_length is not set in this class; presumably maintained by
        # the base logger. Only logged once it is positive.
        if self._mean_ep_length > 0:
            scalars["episode/length"] = self._mean_ep_length
        # Timings are stored in seconds; backends receive milliseconds.
        scalars["perf/collect_time_ms"] = self._collect_time * 1000
        scalars["perf/train_time_ms"] = self._train_time * 1000
        return scalars

    def _backend_log_step(
        self,
        iteration: int,
        metrics: dict[str, float] | None,
        reward: float | None,
        reward_components: dict[str, float] | None,
    ):
        """Write this iteration's scalars to TensorBoard and/or W&B, if active.

        W&B logging is skipped when ``_load_wandb()`` returns ``None``.
        """
        if not (self._tb_writer or self._wandb_run):
            return
        scalars = self._collect_scalars(metrics, reward, reward_components)
        if self._tb_writer:
            writer = self._tb_writer
            for tag, value in scalars.items():
                writer.add_scalar(tag, value, iteration)
        if self._wandb_run:
            wandb = _load_wandb()
            if wandb is None:
                return
            log_dict: dict[str, Any] = {"iteration": iteration, **scalars}
            wandb.log(log_dict, step=iteration)

    def _steps_per_second(self) -> float | None:
        """Return environment steps/second for the last iteration.

        Returns ``None`` when no positive iteration time has been recorded.
        """
        iter_time = self._collect_time + self._train_time
        # Written as `not >` (rather than `<=`) so NaN is also rejected.
        if not iter_time > 0:
            return None
        return self.num_envs * self.num_steps / iter_time

    def _build_display(self) -> Panel:
        """Assemble the full panel: header, metrics | rewards, timing row."""
        header = self._build_compact_header(include_status=True)
        left = self._build_metrics_table()
        right = self._build_reward_table()
        bottom = self._build_timing_table()
        grid = Table.grid(expand=True)
        grid.add_column(ratio=1)
        grid.add_column(width=2)  # spacer between the two tables
        grid.add_column(ratio=1)
        grid.add_row(left, "", right)
        return Panel(
            Group(header, Text(""), grid, Text(""), bottom),
            title="[bold] 🚀 UniLab On-Policy Training [/]",
            border_style="bright_blue",
            padding=(0, 1),
        )

    def _build_metrics_table(self) -> Table:
        """Table of the latest metrics: loss keys first, then the rest.

        Loss keys are those containing "loss" (case-insensitive); values above
        ``loss_warning_threshold`` are shown in red.
        """
        table = Table(
            box=box.SIMPLE_HEAVY,
            show_header=True,
            show_edge=False,
            header_style="bold cyan",
            expand=True,
            pad_edge=False,
        )
        table.add_column("Losses & Metrics", style="white", ratio=2)
        table.add_column("Value", style="yellow", justify="right", ratio=1)
        if not self._latest_metrics:
            table.add_row("[dim]Waiting for data...[/]", "")
            return table

        loss_keys = sorted(k for k in self._latest_metrics if "loss" in k.lower())
        other_keys = sorted(k for k in self._latest_metrics if "loss" not in k.lower())
        for key in loss_keys:
            value = self._latest_metrics[key]
            style = "red" if value > self.loss_warning_threshold else "yellow"
            table.add_row(key.replace("_", " ").title(), f"[{style}]{_fmt_number(value)}[/]")
        for key in other_keys:
            value = self._latest_metrics[key]
            # Leading space visually indents non-loss metrics under the losses.
            table.add_row(f" {key.replace('_', ' ').title()}", _fmt_number(value))
        return table

    def _build_reward_table(self) -> Table:
        """Reward table via the shared base-class builder (no episode length)."""
        return self._build_reward_table_common(
            wait_message="[dim]Waiting for data...[/]",
            include_ep_length=False,
        )

    def _build_timing_table(self) -> Table:
        """Three side-by-side label/value column pairs: Learner, Collector, System."""
        table = Table(
            box=box.SIMPLE_HEAVY,
            show_header=True,
            show_edge=False,
            header_style="bold blue",
            expand=True,
            pad_edge=False,
        )
        for section in ("Learner", "Collector", "System"):
            table.add_column(section, style="white", ratio=2, no_wrap=True)
            table.add_column("Value", style="yellow", justify="right", ratio=1, no_wrap=True)

        iter_time = self._collect_time + self._train_time
        steps_per_second = self._steps_per_second()
        learner_items = [
            ("Train", f"{self._train_time * 1000:.1f}ms"),
            ("Iter Time", f"{iter_time * 1000:.1f}ms"),
        ]
        collector_items = [
            ("Collect", f"{self._collect_time * 1000:.1f}ms"),
        ]
        system_items = [
            ("Envs", f"{self.num_envs:,}"),
            # Same rounding as the header so both figures always agree.
            ("Steps/s", f"{(steps_per_second or 0.0):,.0f}"),
        ]

        sections = (learner_items, collector_items, system_items)
        row_count = max(len(items) for items in sections)
        for index in range(row_count):
            row: list[str] = []
            for items in sections:
                # Pad shorter sections with empty label/value cells.
                row.extend(items[index] if index < len(items) else ("", ""))
            table.add_row(*row)
        return table

    def _build_compact_header(
        self,
        *,
        include_status: bool,
        extra_fields: list[tuple[str, str]] | None = None,
    ) -> Text:
        """Base compact header, prefixed with a Steps/s field once timed."""
        header_extra_fields: list[tuple[str, str]] = []
        steps_per_second = self._steps_per_second()
        if steps_per_second is not None:
            header_extra_fields.append((f"Steps/s {steps_per_second:,.0f}", "bold green"))
        if extra_fields:
            header_extra_fields.extend(extra_fields)
        return super()._build_compact_header(
            include_status=include_status,
            extra_fields=header_extra_fields,
        )