Source code for unilab.algos.torch.offpolicy.runner

"""Unified runner for off-policy RL algorithms (SAC, TD3)."""

import os
import queue
import statistics
import sys
import time
from collections import defaultdict, deque
from pathlib import Path
from typing import Any, cast

import torch

from unilab.algos.torch.common.device import get_env_dims
from unilab.algos.torch.offpolicy.worker import off_policy_collector_fn
from unilab.ipc import SharedObsNormStats, SharedWeightSync
from unilab.ipc.async_runner import _SPAWN_CTX, AsyncRunner
from unilab.ipc.replay_buffer import ReplayBuffer
from unilab.logging import OffPolicyLogger, TraceRecorder
from unilab.training.seed import apply_training_seed, derive_worker_seed
from unilab.utils.device import get_default_device


def compute_train_start_threshold(batch_size: int, learning_starts: int, num_envs: int) -> int:
    """Return the minimum replay size required before learner updates may start.

    The threshold is the larger of one learner batch (``batch_size``) and
    ``learning_starts`` collection steps across every environment
    (``learning_starts * num_envs``). Negative ``learning_starts`` counts as 0,
    ``num_envs`` is floored at 1, and the result is never negative.
    """
    warmup_steps = int(learning_starts)
    if warmup_steps < 0:
        warmup_steps = 0
    env_count = int(num_envs)
    if env_count < 1:
        env_count = 1
    candidates = (int(batch_size), warmup_steps * env_count, 0)
    return max(candidates)
def replay_buffer_ready_for_learning(
    replay_buffer_size: int,
    *,
    batch_size: int,
    learning_starts: int,
    num_envs: int,
) -> bool:
    """Report whether the replay buffer holds enough samples for the first learner step.

    The buffer is ready once ``replay_buffer_size`` reaches the value returned by
    ``compute_train_start_threshold(batch_size, learning_starts, num_envs)``.
    """
    current_size = int(replay_buffer_size)
    required_size = compute_train_start_threshold(batch_size, learning_starts, num_envs)
    return not current_size < required_size
def build_reward_comparison_metrics(
    reward_history: deque,
    smoothed_reward: float,
) -> dict[str, float]:
    """Return the latest collector-side 100-episode mean for reward comparison.

    Only the newest entry of ``reward_history`` is reported, as a float under
    the key ``"mean_ep100"``. An empty history produces an empty dict.
    ``smoothed_reward`` is accepted for call-site compatibility and ignored.
    """
    del smoothed_reward  # intentionally unused
    metrics: dict[str, float] = {}
    if reward_history:
        metrics["mean_ep100"] = float(reward_history[-1])
    return metrics
class OffPolicyRunner(AsyncRunner):
    """Unified runner for SAC and TD3 (``algo_type`` may also be ``"flashsac"``).

    The learner trains in this process. ``learn`` starts a collector process
    (``off_policy_collector_fn``) and hands it the shared ``ReplayBuffer`` and the
    ``SharedWeightSync`` name/lock; after every learner iteration the actor's
    ``state_dict`` is written to that weight sync.
    """

    def __init__(
        self,
        learner,
        env_name: str,
        algo_type: str,  # "sac", "td3", or "flashsac"
        num_envs: int = 4096,
        replay_buffer_n: int = 1024,
        batch_size: int = 8192,
        learning_starts: int = 0,
        updates_per_step: int = 8,
        policy_frequency: int = 4,
        sync_collection: bool = True,
        env_steps_per_sync: int = 1,
        device: str | None = None,
        actor_hidden_dim: int = 512,
        use_layer_norm: bool = True,
        obs_normalization: bool = False,
        sim_backend: str = "mujoco",
        env_cfg_override: dict | None = None,
        actor_kwargs: dict | None = None,
        seed: int | None = None,
        trace_enabled: bool = False,
        trace_output_dir: str | None = None,
        trace_thread_time: bool = False,
        trace_cuda_events: bool = True,
    ):
        """Configure the runner; no processes or shared buffers are created here.

        Args:
            learner: Algorithm object. ``learn`` uses its ``actor``,
                ``update_critic``, ``update_actor``, ``soft_update_target``,
                ``get_state_dict`` and ``update_count``, plus the optional
                ``update_reward_stats``/``reward_normalizer``,
                ``obs_normalizer`` and ``use_symmetry``.
            env_name: Environment name forwarded to the collector.
            algo_type: ``"sac"``, ``"td3"`` or ``"flashsac"``. Forwarded to the
                collector and used to name the logger.
            num_envs: Number of parallel environments in the collector.
            replay_buffer_n: Replay rows per environment; total capacity is
                ``replay_buffer_n * num_envs``.
            batch_size: Transitions per critic/actor update.
            learning_starts: Per-environment collection steps required before
                training starts. Negative values are treated as 0.
            updates_per_step: Critic updates per learner iteration.
            policy_frequency: Run one actor update every this many critic updates.
            sync_collection: Alternate collection and training through
                handshake queues instead of running them freely in parallel.
            env_steps_per_sync: Forwarded to the collector; also used for the
                per-iteration throughput figure.
            device: Learner device (``None`` lets ``AsyncRunner`` choose).
            seed: Training seed. The collector receives a seed derived from it.
            trace_enabled / trace_output_dir / trace_thread_time /
            trace_cuda_events: Perfetto trace options.
        Other arguments are forwarded to the collector unchanged.
        """
        super().__init__(
            env_name=env_name,
            env_cfg_overrides={},
            rl_cfg={},
            device=device,
            collector_device="cpu",
            num_envs=num_envs,
            sim_backend=sim_backend,
        )
        self.learner = learner
        self.env_cfg_override = env_cfg_override
        self.algo_type = algo_type
        self.replay_buffer_n = replay_buffer_n
        self.batch_size = batch_size
        self.learning_starts = max(int(learning_starts), 0)
        self.train_start_threshold = compute_train_start_threshold(
            batch_size,
            self.learning_starts,
            num_envs,
        )
        self.updates_per_step = updates_per_step
        self.policy_frequency = policy_frequency
        self.sync_collection = sync_collection
        self.env_steps_per_sync = env_steps_per_sync
        self.actor_hidden_dim = actor_hidden_dim
        self.use_layer_norm = use_layer_norm
        self.obs_normalization = obs_normalization
        self.actor_kwargs = actor_kwargs or {}
        self.seed = seed
        # Logger owned by an in-progress ``learn`` call; ``close`` shuts it down.
        self._active_logger: OffPolicyLogger | None = None
        self.trace_enabled = trace_enabled
        self.trace_output_dir = trace_output_dir
        self.trace_thread_time = trace_thread_time
        self.trace_cuda_events = trace_cuda_events
        apply_training_seed(self.seed, torch_runtime=True, cuda=True)
        self.obs_dim, self.action_dim, self.critic_obs_dim = get_env_dims(
            self.env_name, sim_backend, env_cfg_override
        )

    def _get_default_device(self) -> str:
        """Return the project-wide default device (presumably an ``AsyncRunner`` hook)."""
        return get_default_device()

    def _build_learner(self):
        """Return the learner passed to ``__init__``."""
        return self.learner

    def _collector_fn(self, stop_event, **kwargs):
        """Run the off-policy collector loop with the given stop event and kwargs."""
        off_policy_collector_fn(stop_event=stop_event, **kwargs)

    @staticmethod
    def _read_recent_replay_field(
        replay_buffer, field_name: str, start_ptr: int, count: int
    ) -> torch.Tensor:
        """Copy ``count`` consecutive rows of one replay field starting at ``start_ptr``.

        ``start_ptr`` is reduced modulo ``replay_buffer.capacity``, and the read
        wraps around the end of the ring storage. If the buffer has no attribute
        called ``field_name``, the column is taken from the packed ``_storage``
        tensor using the ``_rew_col``/``_done_col``/``_trunc_col`` index for
        ``"rewards"``/``"dones"``/``"truncated"``. Any other missing field raises
        ``KeyError``. The result never aliases replay storage.
        """
        capacity = replay_buffer.capacity
        idx = start_ptr % capacity
        if hasattr(replay_buffer, field_name):
            source = getattr(replay_buffer, field_name)
        else:
            packed_key = {
                "rewards": "_rew_col",
                "dones": "_done_col",
                "truncated": "_trunc_col",
            }[field_name]
            source = replay_buffer._storage[:, getattr(replay_buffer, packed_key)]
        if idx + count <= capacity:
            return cast(torch.Tensor, source[idx : idx + count].clone())
        split = capacity - idx
        # torch.cat always allocates its output, so no extra clone is needed here.
        return cast(torch.Tensor, torch.cat([source[idx:], source[: count - split]], dim=0))

    def _update_reward_stats_from_replay(self, replay_buffer, start_ptr: int, end_ptr: int) -> int:
        """Feed rewards/dones written since ``start_ptr`` to ``learner.update_reward_stats``.

        This runs only when the learner has ``update_reward_stats`` and a
        non-None ``reward_normalizer``. The window ``[start_ptr, end_ptr)`` is
        first clipped to the newest ``capacity`` rows. It is then trimmed,
        dropping the oldest rows, to a multiple of ``num_envs`` so the data can
        be viewed as ``(num_steps, num_envs)``. This assumes the collector writes
        ``num_envs`` consecutive rows per step — TODO confirm.

        Returns:
            ``end_ptr`` in every case; pass it back as the next ``start_ptr``.
        """
        if not hasattr(self.learner, "update_reward_stats"):
            return end_ptr
        if getattr(self.learner, "reward_normalizer", None) is None:
            return end_ptr
        count = end_ptr - start_ptr
        if count <= 0:
            return end_ptr
        count = min(count, replay_buffer.capacity)
        count -= count % self.num_envs
        if count <= 0:
            return end_ptr
        start_ptr = end_ptr - count
        rewards = self._read_recent_replay_field(replay_buffer, "rewards", start_ptr, count)
        dones = self._read_recent_replay_field(replay_buffer, "dones", start_ptr, count)
        num_steps = count // self.num_envs
        self.learner.update_reward_stats(
            rewards.view(num_steps, self.num_envs),
            dones.view(num_steps, self.num_envs),
        )
        return end_ptr

    def learn(
        self,
        max_iterations: int = 1500,
        save_interval: int = 50,
        log_dir: str = "logs",
        logger_type: str = "tensorboard",
    ) -> None:
        """Unified training loop for off-policy algorithms.

        Each iteration does the following:
          1. Wait until the replay buffer reaches the train-start threshold. With
             ``sync_collection``, also wait for the collector's "collection
             ready" signal.
          2. Sample ``batch_size * updates_per_step`` transitions once.
          3. Split them into ``updates_per_step`` critic updates, running an actor
             update whenever ``update_idx % policy_frequency == 0``.
          4. Publish the actor weights through the shared weight sync.

        Checkpoints go to ``log_dir/model_<iteration>.pt`` every ``save_interval``
        iterations (``<= 0`` disables periodic saves) and once at the end. A run
        summary is stored in ``self.last_run_summary``; its ``status`` is
        ``"completed"`` or ``"collector_died"``.

        Args:
            max_iterations: Number of learner iterations.
            save_interval: Periodic checkpoint interval in iterations.
            log_dir: Directory for checkpoints, logs and (by default) the trace.
            logger_type: Backend forwarded to ``OffPolicyLogger``.

        Raises:
            ValueError: If ``train_start_threshold`` exceeds the replay capacity
                ``replay_buffer_n * num_envs``, since training could never start.
            RuntimeError: If the collector reports an error on the metrics queue.
        """
        buffer_capacity = self.replay_buffer_n * self.num_envs
        train_start_threshold = self.train_start_threshold
        if train_start_threshold > buffer_capacity:
            # Assumes the replay size saturates at capacity; otherwise the wait
            # loop below would block forever.
            raise ValueError(
                f"Replay buffer capacity {buffer_capacity} (replay_buffer_n="
                f"{self.replay_buffer_n} x num_envs={self.num_envs}) is smaller than "
                f"the train-start threshold {train_start_threshold}; increase "
                "replay_buffer_n or lower batch_size / learning_starts."
            )

        os.makedirs(log_dir, exist_ok=True)
        trace_output_path: Path | None = None
        trace_recorder: TraceRecorder | None = None
        if self.trace_enabled:
            trace_root = Path(self.trace_output_dir or log_dir)
            trace_output_path = trace_root / "perfetto_offpolicy_timeline.json"
            trace_recorder = TraceRecorder("offpolicy_learner")

        train_start_wall = time.time()
        best_mean_reward = float("-inf")
        last_mean_reward = 0.0
        ckpt_path: str | None = None
        iteration = 0

        # Replay buffer shared with the collector process.
        replay_buffer = ReplayBuffer(
            capacity=buffer_capacity,
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
            device=self.device,
            critic_dim=self.critic_obs_dim,
        )
        self._shared_resources.append(replay_buffer)
        replay_buffer.trace_recorder = trace_recorder
        replay_buffer.trace_thread_time = self.trace_thread_time
        replay_buffer.trace_cuda_events = self.trace_cuda_events

        # Shared actor weights, created from the learner's current actor.
        actor_state = self.learner.actor.state_dict()
        weight_sync = SharedWeightSync.from_state_dict(actor_state, create=True)
        self._shared_resources.append(weight_sync)
        weight_sync.trace_recorder = trace_recorder
        weight_sync.trace_thread_time = self.trace_thread_time

        # Handshake queues for lock-step collection/training.
        collection_ready_queue = None
        trainer_done_queue = None
        if self.sync_collection:
            collection_ready_queue = _SPAWN_CTX.Queue(maxsize=1)
            trainer_done_queue = _SPAWN_CTX.Queue(maxsize=1)
            # Pre-seed one "trainer done" token before the collector starts.
            trainer_done_queue.put(1)
            print(f"[Runner] Collection sync enabled: env_steps_per_sync={self.env_steps_per_sync}")
        metrics_queue = _SPAWN_CTX.Queue(maxsize=100)

        shared_obs_normalizer_stats = None
        if self.obs_normalization:
            shared_obs_normalizer_stats = SharedObsNormStats(_SPAWN_CTX)

        # Start the collector process.
        weight_param_shapes = {k: v.shape for k, v in actor_state.items()}
        collector_kwargs = {
            "env_name": self.env_name,
            "num_envs": self.num_envs,
            "replay_buffer": replay_buffer,
            "weight_sync_name": weight_sync.name,
            "weight_sync_lock": weight_sync._lock,
            "weight_param_shapes": weight_param_shapes,
            "algo_type": self.algo_type,
            "actor_hidden_dim": self.actor_hidden_dim,
            "use_layer_norm": self.use_layer_norm,
            "learning_starts": self.learning_starts,
            "metrics_queue": metrics_queue,
            "sync_collection": self.sync_collection,
            "collection_ready_queue": collection_ready_queue,
            "trainer_done_queue": trainer_done_queue,
            "env_steps_per_sync": self.env_steps_per_sync,
            "obs_normalization": self.obs_normalization,
            "shared_obs_normalizer_stats": shared_obs_normalizer_stats,
            "sim_backend": self.sim_backend,
            "env_cfg_override": self.env_cfg_override,
            "obs_dim": self.obs_dim,
            "action_dim": self.action_dim,
            "actor_kwargs": self.actor_kwargs,
            "seed": derive_worker_seed(self.seed, worker_index=0),
            "trace_enabled": self.trace_enabled,
            "trace_thread_time": self.trace_thread_time,
        }
        self._start_collector(
            target_fn=off_policy_collector_fn,
            kwargs={"stop_event": self._stop_event, **collector_kwargs},
        )
        # Give the collector a moment to start before reporting its liveness.
        time.sleep(0.5)
        if self._collector_process:
            print(f"[Runner] Collector process alive: {self._collector_process.is_alive()}")

        logger = OffPolicyLogger(
            algo_name=(
                "FlashSAC" if self.algo_type == "flashsac" else f"Fast{self.algo_type.upper()}"
            ),
            max_iterations=max_iterations,
            num_envs=self.num_envs,
            env_name=self.env_name,
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
            log_dir=log_dir,
            log_backend=logger_type,
        )
        logger.set_collection_sync(self.sync_collection, self.env_steps_per_sync)
        if hasattr(self.learner, "use_symmetry") and self.learner.use_symmetry:
            logger.log_status("Symmetry augmentation: enabled")
        self._active_logger = logger
        logger.start()

        reward_history: deque = deque(maxlen=100)
        latest_reward_components: dict[str, float] = {}
        last_buf_log = 0
        write_read_ema = 0.0
        reward_stats_ptr = 0
        consume_per_iter = self.batch_size * self.updates_per_step
        training_e2e_start_ns = time.perf_counter_ns() if trace_recorder else 0

        def drain_metrics() -> None:
            """Pull pending collector metrics into the local history and logger."""
            self._drain_metrics(
                metrics_queue,
                reward_history,
                latest_reward_components,
                logger,
                trace_recorder,
            )

        def buffer_ready(cur_size: int) -> bool:
            """Whether ``cur_size`` replay rows suffice for a learner step."""
            return replay_buffer_ready_for_learning(
                cur_size,
                batch_size=self.batch_size,
                learning_starts=self.learning_starts,
                num_envs=self.num_envs,
            )

        def report_buffer_fill(cur_size: int) -> None:
            """Log fill progress at most once per ``10 * num_envs`` new rows."""
            nonlocal last_buf_log
            if cur_size - last_buf_log >= self.num_envs * 10:
                last_buf_log = cur_size
                logger.log_buffer_fill(cur_size, train_start_threshold)

        def abort_collector_died() -> None:
            """Flush metrics/trace, close the logger and record a failure summary."""
            drain_metrics()
            logger.log_status("[red]ERROR: Collector died[/]")
            logger.finish()
            if trace_recorder and trace_output_path:
                trace_recorder.write_json(trace_output_path)
                print(f"[Runner] Perfetto trace written to {trace_output_path}")
            self.last_run_summary = {
                "status": "collector_died",
                # ``iteration`` was still waiting for data, so it did not complete.
                "completed_iterations": max(iteration - 1, 0),
                "total_env_steps": int(logger._total_steps),
                "final_mean_reward": None,
                "best_mean_reward": None,
                "mean_episode_length": float(logger._mean_ep_length),
                "last_checkpoint": ckpt_path,
                "trace_path": str(trace_output_path) if trace_output_path is not None else None,
                "training_wall_time_sec": time.time() - train_start_wall,
            }
            self._active_logger = None

        for iteration in range(1, max_iterations + 1):
            # --- Wait until the replay buffer can feed a learner step. ---
            wait_start = time.time()
            wait_start_ns = time.perf_counter_ns() if trace_recorder else 0
            if self.sync_collection and collection_ready_queue:
                while True:
                    try:
                        collection_ready_queue.get(timeout=1.0)
                    except queue.Empty:
                        if not self._check_collector_alive():
                            abort_collector_died()
                            return
                        continue
                    drain_metrics()
                    cur_size = int(replay_buffer.size[0])
                    if buffer_ready(cur_size):
                        break
                    report_buffer_fill(cur_size)
                    # Not enough data yet: hand control back for another round.
                    if trainer_done_queue:
                        trainer_done_queue.put(1)
            else:
                while not buffer_ready(int(replay_buffer.size[0])):
                    if not self._check_collector_alive():
                        abort_collector_died()
                        return
                    report_buffer_fill(int(replay_buffer.size[0]))
                    time.sleep(0.1)
                    drain_metrics()
            wait_time = time.time() - wait_start
            if trace_recorder:
                trace_recorder.add_slice(
                    "learner/wait_for_data",
                    category="learner",
                    start_ns=wait_start_ns,
                    end_ns=time.perf_counter_ns(),
                    args={"iteration": iteration},
                )
            drain_metrics()

            # --- Reward normalization statistics from newly written rows. ---
            reward_stats_ns = time.perf_counter_ns() if trace_recorder else 0
            reward_stats_ptr = self._update_reward_stats_from_replay(
                replay_buffer,
                reward_stats_ptr,
                int(replay_buffer.ptr[0]),
            )
            if trace_recorder:
                trace_recorder.add_slice(
                    "learner/update_reward_stats",
                    category="learner",
                    start_ns=reward_stats_ns,
                    end_ns=time.perf_counter_ns(),
                )

            iter_metrics: defaultdict[str, list] = defaultdict(list)
            ptr_before = int(replay_buffer.ptr[0])
            learner = self.learner  # local alias for the hot update loop

            # --- One large sample, sliced into per-update mini-batches. ---
            sample_ns = time.perf_counter_ns() if trace_recorder else 0
            large_batch = replay_buffer.sample(consume_per_iter)
            learner_incremental_h2d_time = float(
                getattr(replay_buffer, "last_incremental_h2d_time_s", 0.0)
            )
            if trace_recorder:
                trace_recorder.add_slice(
                    "learner/replay_sample",
                    category="learner",
                    start_ns=sample_ns,
                    end_ns=time.perf_counter_ns(),
                    args={"total_batch": consume_per_iter},
                )

            train_start = time.time()
            for update_idx in range(self.updates_per_step):
                lo = update_idx * self.batch_size
                hi = lo + self.batch_size
                batch = {k: v[lo:hi] for k, v in large_batch.items()}

                critic_ns = time.perf_counter_ns() if trace_recorder else 0
                critic_metrics = learner.update_critic(batch)
                if trace_recorder:
                    trace_recorder.add_slice(
                        "learner/update_critic",
                        category="learner",
                        start_ns=critic_ns,
                        end_ns=time.perf_counter_ns(),
                        args={"update_idx": update_idx},
                    )
                for k, v in critic_metrics.items():
                    iter_metrics[k].append(v)

                # Delayed policy updates: one actor step per ``policy_frequency`` critic steps.
                if update_idx % self.policy_frequency == 0:
                    actor_ns = time.perf_counter_ns() if trace_recorder else 0
                    actor_metrics = learner.update_actor(batch)
                    if trace_recorder:
                        trace_recorder.add_slice(
                            "learner/update_actor",
                            category="learner",
                            start_ns=actor_ns,
                            end_ns=time.perf_counter_ns(),
                            args={"update_idx": update_idx},
                        )
                    for k, v in actor_metrics.items():
                        iter_metrics[k].append(v)

                target_ns = time.perf_counter_ns() if trace_recorder else 0
                learner.soft_update_target()
                if trace_recorder:
                    trace_recorder.add_slice(
                        "learner/soft_update_target",
                        category="learner",
                        start_ns=target_ns,
                        end_ns=time.perf_counter_ns(),
                        args={"update_idx": update_idx},
                    )

            # Publish the learner's observation-normalizer statistics to the collector.
            if self.obs_normalization and getattr(self.learner, "obs_normalizer", None) is not None:
                assert shared_obs_normalizer_stats is not None  # type narrowing
                shared_obs_normalizer_stats.put(
                    (
                        self.learner.obs_normalizer.mean.cpu().numpy(),
                        self.learner.obs_normalizer.std.cpu().numpy(),
                    )
                )
            train_time = time.time() - train_start
            self.learner.update_count += 1

            # --- Publish actor weights for the collector. ---
            ws_ns = time.perf_counter_ns() if trace_recorder else 0
            weight_sync_start = time.perf_counter()
            weight_sync.write_weights(self.learner.actor.state_dict())
            weight_sync_time = time.perf_counter() - weight_sync_start
            if trace_recorder:
                trace_recorder.add_slice(
                    "learner/weight_sync_write",
                    category="learner",
                    start_ns=ws_ns,
                    end_ns=time.perf_counter_ns(),
                )
                trace_recorder.add_counter(
                    "replay_size",
                    int(replay_buffer.size[0]),
                    category="replay",
                )
                trace_recorder.flush_cuda_pending()
            if self.sync_collection and trainer_done_queue:
                trainer_done_queue.put(1)

            # EMA of rows written per row consumed during this iteration.
            write_delta = int(replay_buffer.ptr[0]) - ptr_before
            write_read_ema = 0.9 * write_read_ema + 0.1 * (write_delta / max(consume_per_iter, 1))
            logger.update_buffer_utilization(write_read_ema)

            avg_metrics = {k: statistics.mean(v) for k, v in iter_metrics.items() if v}
            mean_reward = statistics.mean(reward_history) if reward_history else 0.0
            last_mean_reward = float(mean_reward)
            best_mean_reward = max(best_mean_reward, last_mean_reward)
            logger.log_step(
                iteration=iteration,
                metrics=avg_metrics,
                reward=mean_reward,
                reward_metrics=build_reward_comparison_metrics(reward_history, mean_reward),
                reward_components=latest_reward_components,
                train_time=train_time,
                wait_time=wait_time,
                learner_incremental_h2d_time=learner_incremental_h2d_time,
                weight_sync_time=weight_sync_time,
                extra_info={
                    "throughput_steps": self.num_envs * self.env_steps_per_sync,
                },
            )

            if save_interval > 0 and iteration % save_interval == 0:
                ckpt_path = os.path.join(log_dir, f"model_{iteration}.pt")
                torch.save(self.learner.get_state_dict(), ckpt_path)
                logger.log_save(ckpt_path)

        if trace_recorder:
            trace_recorder.add_slice(
                "learner/training_e2e",
                category="learner",
                start_ns=training_e2e_start_ns,
                end_ns=time.perf_counter_ns(),
                args={"iterations": iteration},
            )

        final_ckpt_path = os.path.join(log_dir, f"model_{max_iterations}.pt")
        if ckpt_path != final_ckpt_path:
            # Skip the duplicate write when the last periodic save already produced it.
            torch.save(self.learner.get_state_dict(), final_ckpt_path)
            logger.log_save(final_ckpt_path)
            ckpt_path = final_ckpt_path
        logger.finish()

        if trace_recorder and trace_output_path:
            trace_recorder.write_json(trace_output_path)
            print(f"[Runner] Perfetto trace written to {trace_output_path}")

        self.last_run_summary = {
            "status": "completed",
            "completed_iterations": iteration,
            "total_env_steps": int(logger._total_steps),
            "final_mean_reward": last_mean_reward if reward_history else None,
            "best_mean_reward": best_mean_reward if reward_history else None,
            "mean_episode_length": float(logger._mean_ep_length),
            "last_checkpoint": ckpt_path,
            "trace_path": str(trace_output_path) if trace_output_path is not None else None,
            "training_wall_time_sec": time.time() - train_start_wall,
        }
        self._active_logger = None

    def close(self) -> None:
        """Close any logger left open by an interrupted ``learn``, then ``AsyncRunner.close``.

        ``super().close()`` runs even if closing the logger raises, so shared
        resources are still released.
        """
        active_logger = getattr(self, "_active_logger", None)
        try:
            if active_logger is not None:
                self._active_logger = None
                active_logger.close()
        finally:
            super().close()

    # _check_collector_alive() inherited from AsyncRunner base class
    @staticmethod
    def _drain_metrics(queue, reward_history, reward_components, logger, trace_recorder=None):
        """Consume every message currently waiting on the collector metrics queue.

        Keys handled in each message dict:
          * ``mean_ep_reward``: appended to ``reward_history``.
          * ``reward_components``: replaces the contents of ``reward_components``.
          * ``mean_ep_length``, ``collector_timing_ms``, ``timeout_rate`` /
            ``terminated_rate``, and ``total_steps`` + ``buffer_size``: forwarded
            to the matching ``logger`` method.
          * ``trace_events``: added to ``trace_recorder`` when tracing.

        A message that fails to process is reported on stderr and ends this
        drain pass.

        Raises:
            RuntimeError: If a message carries an ``"error"`` key. The check
                runs outside the per-message ``except`` so the failure reaches
                the caller instead of being swallowed.
        """
        while True:
            try:
                m = queue.get_nowait()
            except Exception:
                # Best-effort: an empty (or unreadable) queue ends this pass.
                break
            if isinstance(m, dict) and "error" in m:
                logger.log_status(f"[red]Collector ERROR: {m['error']}[/]")
                raise RuntimeError(f"Collector process failed: {m['error']}")
            try:
                updated_rew = False
                if "mean_ep_reward" in m:
                    reward_history.append(m["mean_ep_reward"])
                    updated_rew = True
                if "reward_components" in m:
                    reward_components.clear()
                    reward_components.update(m["reward_components"])
                if "mean_ep_length" in m:
                    logger.update_ep_length(m["mean_ep_length"])
                if "collector_timing_ms" in m:
                    logger.update_collector_timing(m["collector_timing_ms"])
                if "timeout_rate" in m or "terminated_rate" in m:
                    logger.update_done_rates(
                        timeout_rate=float(m.get("timeout_rate", 0.0)),
                        terminated_rate=float(m.get("terminated_rate", 0.0)),
                    )
                if "total_steps" in m and "buffer_size" in m:
                    logger.log_collector(
                        m["total_steps"],
                        m["buffer_size"],
                        m.get("mean_ep_reward", 0.0) if updated_rew else 0.0,
                    )
                if trace_recorder and "trace_events" in m:
                    trace_recorder.extend(m["trace_events"])
            except Exception as e:
                print(f"[OffPolicyRunner] metrics drain error: {e}", file=sys.stderr)
                break