from __future__ import annotations
import importlib
import os
import time
from collections import deque
from typing import Any
from rich import box
from rich.console import Console
from rich.live import Live
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
def _fmt_time(seconds: float) -> str:
if seconds < 60:
return f"{seconds:.0f}s"
m, s = divmod(int(seconds), 60)
if m < 60:
return f"{m}m{s:02d}s"
h, m = divmod(m, 60)
return f"{h}h{m:02d}m{s:02d}s"
def _fmt_number(v: float) -> str:
if abs(v) == 0:
return "0"
if abs(v) >= 1e6:
return f"{v:.2e}"
if abs(v) >= 100:
return f"{v:.1f}"
if abs(v) >= 1:
return f"{v:.3f}"
if abs(v) >= 0.001:
return f"{v:.4f}"
return f"{v:.2e}"
def _load_wandb() -> Any | None:
"""Load wandb lazily so it remains an optional dependency."""
try:
return importlib.import_module("wandb")
except ImportError:
return None
[docs]
class BaseTrainingLogger:
"""Shared lifecycle and backend logging setup for rich training loggers."""
[docs]
def __init__(
self,
*,
algo_name: str,
max_iterations: int,
num_envs: int,
env_name: str,
log_dir: str,
log_backend: str,
wandb_project: str,
wandb_entity: str | None,
wandb_name: str,
wandb_group: list[str] | None | str,
wandb_job_type: str | None,
wandb_tags: list[str] | None,
wandb_notes: str | None,
refresh_per_second: int = 4,
tensorboard_subdir: str | None = "tb",
wandb_config: dict[str, Any] | None = None,
):
self.algo_name = algo_name
self.max_iterations = max_iterations
self.num_envs = num_envs
self.env_name = env_name
self._no_print = log_backend.lower() == "no_print"
self._log_backend = "none" if self._no_print else log_backend.lower()
self._console = Console()
self._live: Live | None = None
self._refresh_rate = refresh_per_second
self._last_live_refresh_time: float | None = None
self._start_time: float = 0.0
self._iteration: int = 0
self._reward_history: deque[float] = deque(maxlen=200)
self._latest_metrics: dict[str, float] = {}
self._latest_reward_components: dict[str, float] = {}
self._collect_time: float = 0.0
self._train_time: float = 0.0
self._mean_ep_length: float = 0.0
self._last_save: str = ""
self._status: str = ""
self._log_dir = log_dir
self._tb_writer: Any | None = None
self._wandb_run = None
self._owns_wandb_run = False
self._finished = False
self._closed = False
if self._log_backend == "tensorboard" and log_dir:
self._init_tensorboard(log_dir, tensorboard_subdir)
elif self._log_backend == "wandb":
self._init_wandb(
project=wandb_project,
entity=wandb_entity,
name=wandb_name or f"{algo_name}_{env_name}",
log_dir=log_dir,
group=wandb_group,
job_type=wandb_job_type,
tags=wandb_tags,
notes=wandb_notes,
extra_config=wandb_config,
)
def _format_tensorboard_message(self, tb_dir: str) -> str:
return f"[dim]TensorBoard: {tb_dir}[/]"
def _format_wandb_message(self, project: str, name: str) -> str:
return f"[dim]W&B: {project}/{name}[/]"
def _init_tensorboard(self, log_dir: str, subdir: str | None):
try:
from torch.utils.tensorboard import SummaryWriter
tb_dir = log_dir if subdir is None else os.path.join(log_dir, subdir)
os.makedirs(tb_dir, exist_ok=True)
self._tb_writer = SummaryWriter(log_dir=tb_dir)
if not self._no_print:
self._console.print(self._format_tensorboard_message(tb_dir))
except ImportError:
if not self._no_print:
self._console.print("[yellow]tensorboard not installed[/]")
def _init_wandb(
self,
*,
project: str,
entity: str | None,
name: str,
log_dir: str,
group: str | None | list[str],
job_type: str | None,
tags: list[str] | None,
notes: str | None,
extra_config: dict[str, Any] | None = None,
):
wandb = _load_wandb()
if wandb is None:
if not self._no_print:
self._console.print("[yellow]wandb not installed[/]")
return
self._wandb_run = wandb.run
if self._wandb_run is None:
config: dict[str, Any] = {
"algo": self.algo_name,
"env": self.env_name,
"num_envs": self.num_envs,
}
if extra_config:
config.update(extra_config)
kwargs: dict[str, Any] = {
"project": project,
"name": name,
"config": config,
"dir": log_dir or None,
"reinit": True,
}
if entity:
kwargs["entity"] = entity
if group:
kwargs["group"] = group
if job_type:
kwargs["job_type"] = job_type
if tags:
kwargs["tags"] = tags
if notes:
kwargs["notes"] = notes
self._wandb_run = wandb.init(**kwargs)
self._owns_wandb_run = True
if not self._no_print:
self._console.print(self._format_wandb_message(project, name))
[docs]
def start(self, *, status: str = ""):
if self._live is not None:
if status:
self._status = status
self._refresh()
return
self._start_time = time.time()
self._status = status
if not self._no_print:
self._live = Live(
self._build_display(),
console=self._console,
auto_refresh=False,
refresh_per_second=self._refresh_rate,
transient=False,
)
self._live.start(refresh=False)
def _stop_live(self) -> None:
if self._live is not None:
self._live.update(self._build_display(), refresh=True)
self._live.stop()
self._live = None
self._last_live_refresh_time = None
def _close_backends(self) -> None:
if self._tb_writer:
self._tb_writer.close()
self._tb_writer = None
if self._wandb_run and self._owns_wandb_run:
wandb = _load_wandb()
if wandb is not None:
wandb.finish()
self._wandb_run = None
self._owns_wandb_run = False
[docs]
def close(self) -> None:
"""Release live terminal state and backend handles without printing a summary."""
if self._closed:
return
self._stop_live()
self._close_backends()
self._closed = True
[docs]
def finish(self, *, title: str = "Training Summary", extra_summary: str = ""):
if self._finished:
return
self._stop_live()
elapsed = time.time() - self._start_time
if not self._no_print:
summary = (
f"[bold green]Training complete[/]\n"
f" Algo: [cyan]{self.algo_name}[/] | Env: [cyan]{self.env_name}[/]\n"
f" Iterations: [yellow]{self._iteration}[/]/{self.max_iterations}\n"
f" Total time: [yellow]{_fmt_time(elapsed)}[/]\n"
)
if extra_summary:
summary += extra_summary
if self._last_save:
summary += f" Last checkpoint: [dim]{self._last_save}[/]"
self._console.print()
self._console.print(Panel(summary, title=f"[bold]{title}[/]", border_style="green"))
self._close_backends()
self._closed = True
self._finished = True
[docs]
def update_ep_length(self, length: float):
self._mean_ep_length = length
[docs]
def log_save(self, path: str):
self._last_save = path
def _refresh(self, *, force: bool = False):
if self._live is None:
return
now = time.time()
if not force and self._refresh_rate > 0 and self._last_live_refresh_time is not None:
min_interval_s = 1.0 / self._refresh_rate
if now - self._last_live_refresh_time < min_interval_s:
return
self._last_live_refresh_time = now
self._live.update(self._build_display(), refresh=True)
def _estimate_eta(self) -> str:
if self._iteration <= 0:
return ""
elapsed = time.time() - self._start_time
remaining = self.max_iterations - self._iteration
avg_iter = elapsed / self._iteration
eta_s = remaining * avg_iter
return _fmt_time(eta_s)
def _build_header(self, *, include_status: bool) -> Panel:
elapsed = time.time() - self._start_time if self._start_time else 0
eta = self._estimate_eta()
header_text = Text()
header_text.append(f" {self.algo_name}", style="bold cyan")
header_text.append(" │ ", style="dim")
header_text.append(f"{self.env_name}", style="bold white")
header_text.append(" │ ", style="dim")
header_text.append(f"iter {self._iteration}/{self.max_iterations}", style="yellow")
header_text.append(" │ ", style="dim")
header_text.append(f"⏱ {_fmt_time(elapsed)}", style="green")
if eta:
header_text.append(" │ ETA ", style="dim")
header_text.append(eta, style="bold magenta")
if include_status and self._status:
header_text.append(" │ ", style="dim")
header_text.append(self._status, style="dim italic")
return Panel(header_text, style="dim", box=box.SIMPLE)
def _build_compact_header(
self,
*,
include_status: bool,
extra_fields: list[tuple[str, str]] | None = None,
) -> Text:
elapsed = time.time() - self._start_time if self._start_time else 0
eta = self._estimate_eta()
fields: list[tuple[str, str]] = [
(f" {self.algo_name}", "bold cyan"),
(self.env_name, "bold white"),
(f"iter {self._iteration}/{self.max_iterations}", "yellow"),
(f"⏱ {_fmt_time(elapsed)}", "green"),
]
if eta:
fields.append((f"ETA {eta}", "bold magenta"))
if self._mean_ep_length > 0:
fields.append((f"Ep Len {self._mean_ep_length:.1f}", "yellow"))
if extra_fields:
fields.extend(extra_fields)
if include_status and self._status:
fields.append((self._status, "dim italic"))
header_text = Text(no_wrap=True, overflow="ellipsis")
for index, (text, style) in enumerate(fields):
if index > 0:
header_text.append(" | ", style="dim")
header_text.append(text, style=style)
return header_text
def _build_reward_table_common(
self,
*,
wait_message: str,
include_ep_length: bool = True,
) -> Table:
table = Table(
box=box.SIMPLE_HEAVY,
show_header=True,
show_edge=False,
header_style="bold green",
expand=True,
pad_edge=False,
)
table.add_column("Rewards", style="white", ratio=1)
table.add_column("Value", justify="right", ratio=2)
if self._reward_history:
recent = list(self._reward_history)
mean_rew = sum(recent[-50:]) / max(len(recent[-50:]), 1)
peak_rew = max(recent) if recent else 0
if len(recent) >= 10:
old = sum(recent[-20:-10]) / 10
new = sum(recent[-10:]) / 10
trend = (
"[green]▲[/]"
if new > old * 1.05
else "[red]▼[/]"
if new < old * 0.95
else "[yellow]━[/]"
)
else:
trend = ""
table.add_row(
f"[bold]Reward[/] {trend}",
f"Mean [bold green]{mean_rew:.3f}[/] / Peak [dim]{peak_rew:.3f}[/]",
)
if include_ep_length and self._mean_ep_length > 0:
table.add_row(" Ep Len", f"[dim]{self._mean_ep_length:.1f}[/]")
if include_ep_length:
table.add_row("", "")
else:
table.add_row(wait_message, "")
if self._latest_reward_components:
for name, val in sorted(self._latest_reward_components.items()):
display = name.replace("reward/", "").replace("_", " ")
color = "green" if val > 0 else "red" if val < 0 else "dim"
table.add_row(f" {display}", f"[{color}]{val:+.4f}[/]")
return table
def _build_display(self) -> Panel:
raise NotImplementedError