"""Off-policy runner using CPU-pinned double-buffer replay pipeline (B path)."""
from __future__ import annotations
import os
import statistics
import sys
import time
from collections import defaultdict, deque
from contextlib import nullcontext
from pathlib import Path
import torch
from unilab.algos.torch.offpolicy.runner import (
OffPolicyRunner,
build_reward_comparison_metrics,
compute_train_start_threshold,
replay_buffer_ready_for_learning,
)
from unilab.algos.torch.offpolicy.worker import off_policy_collector_fn
from unilab.ipc import SharedObsNormStats, SharedWeightSync
from unilab.ipc.async_runner import _SPAWN_CTX
from unilab.ipc.replay_buffer import ReplayBuffer
from unilab.ipc.replay_pipelines.cpu_pinned_double_buffer import (
CPUPinnedDoubleBufferReplayPipeline,
)
from unilab.logging import OffPolicyLogger, TraceRecorder
from unilab.training.seed import derive_worker_seed
[docs]
class DoubleBufferOffPolicyRunner(OffPolicyRunner):
"""OffPolicyRunner variant that uses CPUPinnedDoubleBufferReplayPipeline.
The only behavioural difference from the parent class is in learn():
- ReplayBuffer is created as packed CPU shared storage.
- Sampling goes through CPUPinnedDoubleBufferReplayPipeline instead of
ReplayBuffer.sample().
"""
LEARNER_LOG_INTERVAL = 10
[docs]
def __init__(
    self,
    *,
    replay_prefetch_mode: str = "one_tick",
    verbose_metrics: bool = False,
    **kwargs,
):
    """Configure the runner and pin the double-buffer replay settings.

    Args:
        replay_prefetch_mode: Prefetch strategy for the replay pipeline.
            ``"one_tick"`` is the only accepted value.
        verbose_metrics: Coerced to ``bool`` and stored; ``learn()`` forwards
            it to the replay pipeline as its ``verbose`` flag.
        **kwargs: Forwarded unchanged to the parent constructor.

    Raises:
        ValueError: If ``replay_prefetch_mode`` is not ``"one_tick"``.
    """
    # Reject an unsupported mode before the parent constructor runs, so a
    # call that is going to fail does not first build the parent's state.
    if replay_prefetch_mode != "one_tick":
        raise ValueError(
            "DoubleBufferOffPolicyRunner only supports replay_prefetch_mode='one_tick'"
        )
    super().__init__(**kwargs)
    self.replay_prefetch_mode = replay_prefetch_mode
    self.verbose_metrics = bool(verbose_metrics)
    # Fixed descriptors of this runner's replay path; learn() reports them in
    # logger status lines and trace arguments.
    self.replay_pack_layout = "packed"
    self.replay_pack_executor = "collector_thread"
    # Placeholders: learn() overwrites these from the replay pipeline's
    # ``h2d_submitter`` / ``transfer_manifest`` attributes when present.
    self.replay_h2d_submitter = "auto"
    self.replay_transfer_backend: dict[str, object] = {}
[docs]
def learn(
    self,
    max_iterations: int = 1500,
    save_interval: int = 50,
    log_dir: str = "logs",
    logger_type: str = "tensorboard",
) -> None:
    """Train the learner from batches served by the double-buffer replay pipeline.

    Builds the packed replay buffer, the replay pipeline, the shared weight
    sync and the collector process, then runs ``max_iterations`` ticks. Each
    tick waits for replay data, takes one large batch of
    ``batch_size * updates_per_step`` samples, runs ``updates_per_step``
    critic/actor updates on consecutive slices of it, and publishes the
    actor weights through the weight sync.

    Args:
        max_iterations: Number of training ticks to run.
        save_interval: Checkpoint every this many ticks; values ``<= 0``
            disable periodic checkpoints (the final one is still written).
        log_dir: Directory for checkpoints and, unless ``trace_output_dir``
            is set, for the trace file. Created if missing.
        logger_type: Passed to ``OffPolicyLogger`` as ``log_backend``.

    Side effects:
        Sets ``self.last_run_summary`` with status ``"completed"`` or, when
        the collector process is found dead while waiting for data,
        ``"collector_died"`` (the method then returns early without a final
        checkpoint or trace file).

    NOTE(review): block nesting here was reconstructed from a listing whose
    indentation had been stripped; spots where the nesting could not be
    read unambiguously carry their own NOTE(review) comment.
    """
    # Imported once per call instead of once per sync-mode iteration.
    import queue

    os.makedirs(log_dir, exist_ok=True)
    trace_output_path = None
    trace_recorder: TraceRecorder | None = None
    if self.trace_enabled:
        trace_root = Path(self.trace_output_dir or log_dir)
        trace_output_path = trace_root / "perfetto_offpolicy_timeline.json"
        trace_recorder = TraceRecorder("offpolicy_learner")
    # Wall-clock start; _make_summary() subtracts it from time.time().
    train_start_wall = time.time()
    best_mean_reward = float("-inf")
    last_mean_reward = 0.0
    ckpt_path: str | None = None
    iteration = 0

    # --- memory budget check ---
    from unilab.ipc.memory_budget import estimate_offpolicy_bytes, warn_if_over_budget

    mem_est = estimate_offpolicy_bytes(
        num_envs=self.num_envs,
        replay_buffer_n=self.replay_buffer_n,
        obs_dim=self.obs_dim,
        action_dim=self.action_dim,
        critic_dim=self.critic_obs_dim,
        batch_size=self.batch_size,
        updates_per_step=self.updates_per_step,
    )
    warn_if_over_budget(mem_est, label=f"Off-policy ({self.algo_type})")

    # --- replay buffer (packed CPU shared storage) ---
    buffer_capacity = self.replay_buffer_n * self.num_envs
    replay_buffer = ReplayBuffer(
        capacity=buffer_capacity,
        obs_dim=self.obs_dim,
        action_dim=self.action_dim,
        device=self.device,
        defer_gpu=True,
        critic_dim=self.critic_obs_dim,
        packed_cpu_storage=self.replay_pack_layout == "packed",
    )
    self._shared_resources.append(replay_buffer)
    replay_buffer.trace_recorder = trace_recorder
    replay_buffer.trace_thread_time = self.trace_thread_time
    replay_buffer.trace_cuda_events = self.trace_cuda_events

    # --- replay pipeline (double buffer) ---
    # One large batch per tick covers every update of that tick.
    sample_count = self.batch_size * self.updates_per_step
    collector_pack_request_queue = _SPAWN_CTX.Queue(maxsize=1)
    collector_pack_ready_queue = _SPAWN_CTX.Queue(maxsize=1)
    packed_width = int(replay_buffer._storage.shape[1])
    # Two shared-memory slots, handed to both the pipeline and the collector.
    collector_pack_shared_slots = [
        torch.empty((sample_count, packed_width), dtype=torch.float32).share_memory_()
        for _ in range(2)
    ]
    _verbose_output_dir: str | None = None
    if self.verbose_metrics:
        _vroot = Path(self.trace_output_dir) if self.trace_output_dir else Path(log_dir)
        _verbose_output_dir = str(_vroot)
    replay_pipeline = CPUPinnedDoubleBufferReplayPipeline(
        replay_buffer,
        device=self.device,
        sample_count=sample_count,
        base_seed=int(self.seed or 0),
        trace_recorder=trace_recorder,
        trace_cuda_events=self.trace_cuda_events,
        verbose=self.verbose_metrics,
        verbose_output_dir=_verbose_output_dir,
        collector_pack_request_queue=collector_pack_request_queue,
        collector_pack_ready_queue=collector_pack_ready_queue,
        collector_pack_shared_slots=collector_pack_shared_slots,
    )

    # The pipeline is local to this call, so nothing else can close it: the
    # try/finally guarantees close() on normal completion, on the early
    # collector-died return, and on any exception raised below.
    try:
        self.replay_h2d_submitter = getattr(
            replay_pipeline,
            "h2d_submitter",
            self.replay_h2d_submitter,
        )
        self.replay_transfer_backend = getattr(
            replay_pipeline,
            "transfer_manifest",
            {},
        )

        # --- weight sync ---
        weight_sync = SharedWeightSync.from_state_dict(self.learner.actor.state_dict(), create=True)
        self._shared_resources.append(weight_sync)
        weight_sync.trace_recorder = trace_recorder
        weight_sync.trace_thread_time = self.trace_thread_time

        # --- sync queues ---
        collection_ready_queue = None
        trainer_done_queue = None
        if self.sync_collection:
            collection_ready_queue = _SPAWN_CTX.Queue(maxsize=1)
            trainer_done_queue = _SPAWN_CTX.Queue(maxsize=1)
            # Pre-load one token before the collector is started.
            trainer_done_queue.put(1)
            print(
                f"[DoubleBufferRunner] Collection sync enabled: "
                f"env_steps_per_sync={self.env_steps_per_sync}"
            )
        metrics_queue = _SPAWN_CTX.Queue(maxsize=100)

        # --- obs normalization ---
        shared_obs_normalizer_stats = None
        if self.obs_normalization:
            shared_obs_normalizer_stats = SharedObsNormStats(_SPAWN_CTX)

        # --- start collector ---
        weight_param_shapes = {k: v.shape for k, v in self.learner.actor.state_dict().items()}
        collector_kwargs = {
            "env_name": self.env_name,
            "num_envs": self.num_envs,
            "replay_buffer": replay_buffer,
            "weight_sync_name": weight_sync.name,
            "weight_sync_lock": weight_sync._lock,
            "weight_param_shapes": weight_param_shapes,
            "algo_type": self.algo_type,
            "actor_hidden_dim": self.actor_hidden_dim,
            "use_layer_norm": self.use_layer_norm,
            "learning_starts": self.learning_starts,
            "metrics_queue": metrics_queue,
            "sync_collection": self.sync_collection,
            "collection_ready_queue": collection_ready_queue,
            "trainer_done_queue": trainer_done_queue,
            "env_steps_per_sync": self.env_steps_per_sync,
            "obs_normalization": self.obs_normalization,
            "shared_obs_normalizer_stats": shared_obs_normalizer_stats,
            "sim_backend": self.sim_backend,
            "env_cfg_override": self.env_cfg_override,
            "obs_dim": self.obs_dim,
            "action_dim": self.action_dim,
            "actor_kwargs": self.actor_kwargs,
            "seed": derive_worker_seed(self.seed, worker_index=0),
            "trace_enabled": self.trace_enabled,
            "trace_thread_time": self.trace_thread_time,
            "collector_pack_request_queue": collector_pack_request_queue,
            "collector_pack_ready_queue": collector_pack_ready_queue,
            "collector_pack_shared_slots": collector_pack_shared_slots,
        }
        self._start_collector(
            target_fn=off_policy_collector_fn,
            kwargs={"stop_event": self._stop_event, **collector_kwargs},
        )
        # Short pause before reporting whether the collector came up.
        time.sleep(0.5)
        if self._collector_process:
            print(
                f"[DoubleBufferRunner] Collector process alive: "
                f"{self._collector_process.is_alive()}"
            )

        # --- logger ---
        logger = OffPolicyLogger(
            algo_name=f"Fast{self.algo_type.upper()}",
            max_iterations=max_iterations,
            num_envs=self.num_envs,
            env_name=self.env_name,
            obs_dim=self.obs_dim,
            action_dim=self.action_dim,
            log_dir=log_dir,
            log_backend=logger_type,
        )
        logger.set_collection_sync(self.sync_collection, self.env_steps_per_sync)
        if hasattr(self.learner, "use_symmetry") and self.learner.use_symmetry:
            logger.log_status("Symmetry augmentation: enabled")
        logger.log_status("Replay pipeline: cpu_pinned_double_buffer")
        logger.log_status(f"Replay prefetch mode: {self.replay_prefetch_mode}")
        logger.log_status(f"Replay pack layout: {self.replay_pack_layout}")
        logger.log_status(f"Replay pack executor: {self.replay_pack_executor}")
        logger.log_status(f"Replay H2D submitter: {self.replay_h2d_submitter}")
        if self.replay_transfer_backend:
            logger.log_status(
                "Replay transfer backend: "
                f"{self.replay_transfer_backend.get('backend')} "
                f"({self.replay_transfer_backend.get('device_family')})"
            )
        logger.log_status(
            f"Replay learner lightweight: fixed (log_interval={self.LEARNER_LOG_INTERVAL})"
        )
        if self.verbose_metrics:
            logger.log_status("Verbose metrics: enabled (field-level pack CSV)")
        logger.start()

        reward_history: deque = deque(maxlen=100)
        latest_reward_components: dict[str, float] = {}
        last_buf_log = 0
        write_read_ema = 0.0
        reward_stats_ptr = 0
        train_start_threshold = self.train_start_threshold
        # Tick for which start_prepare() has already been issued, if any.
        prepared_tick: int | None = None
        training_e2e_start_ns = time.perf_counter_ns() if trace_recorder else 0

        def _abort_collector_died() -> None:
            # Shared shutdown for both wait loops: flush pending metrics,
            # close the logger and record the failure summary. The caller
            # returns right after; the enclosing finally closes the pipeline.
            self._drain_metrics(
                metrics_queue,
                reward_history,
                latest_reward_components,
                logger,
                trace_recorder,
            )
            logger.log_status("[red]ERROR: Collector died[/]")
            logger.finish()
            self.last_run_summary = self._make_summary(
                "collector_died",
                iteration,
                logger,
                None,
                None,
                ckpt_path,
                train_start_wall,
                None,
            )

        # ---- training loop ----
        for iteration in range(1, max_iterations + 1):
            # -- wait for data --
            # perf_counter is monotonic, so the logged durations cannot be
            # skewed by wall-clock adjustments.
            wait_start = time.perf_counter()
            wait_start_ns = time.perf_counter_ns()
            if self.sync_collection and collection_ready_queue:
                while True:
                    try:
                        collection_ready_queue.get(timeout=1.0)
                    except queue.Empty:
                        # Timed out: stop if the collector is gone,
                        # otherwise keep waiting for its signal.
                        if not self._check_collector_alive():
                            _abort_collector_died()
                            return
                        continue
                    self._drain_metrics(
                        metrics_queue,
                        reward_history,
                        latest_reward_components,
                        logger,
                        trace_recorder,
                    )
                    cur_size = int(replay_buffer.size[0])
                    if replay_buffer_ready_for_learning(
                        cur_size,
                        batch_size=self.batch_size,
                        learning_starts=self.learning_starts,
                        num_envs=self.num_envs,
                    ):
                        # Only the first trained tick has no prefetch yet;
                        # later ticks were prepared during the previous one.
                        if prepared_tick != iteration:
                            replay_pipeline.start_prepare(iteration, sample_count)
                            prepared_tick = iteration
                        break
                    if cur_size - last_buf_log >= self.num_envs * 10:
                        last_buf_log = cur_size
                        logger.log_buffer_fill(cur_size, train_start_threshold)
                    # Buffer not ready yet: hand the token back so the
                    # collector can run another round.
                    if trainer_done_queue:
                        trainer_done_queue.put(1)
            else:
                while not replay_buffer_ready_for_learning(
                    int(replay_buffer.size[0]),
                    batch_size=self.batch_size,
                    learning_starts=self.learning_starts,
                    num_envs=self.num_envs,
                ):
                    if not self._check_collector_alive():
                        _abort_collector_died()
                        return
                    cur_size = int(replay_buffer.size[0])
                    if cur_size - last_buf_log >= self.num_envs * 10:
                        last_buf_log = cur_size
                        logger.log_buffer_fill(cur_size, train_start_threshold)
                    time.sleep(0.1)
                    # NOTE(review): drain assumed to run on every poll
                    # (inside the wait loop) — confirm.
                    self._drain_metrics(
                        metrics_queue,
                        reward_history,
                        latest_reward_components,
                        logger,
                        trace_recorder,
                    )
            wait_time = time.perf_counter() - wait_start
            if trace_recorder:
                trace_recorder.add_slice(
                    "learner/wait_for_data",
                    category="learner",
                    start_ns=wait_start_ns,
                    end_ns=time.perf_counter_ns(),
                    args={"iteration": iteration},
                )
            self._drain_metrics(
                metrics_queue,
                reward_history,
                latest_reward_components,
                logger,
                trace_recorder,
            )
            _reward_stats_ns = time.perf_counter_ns()
            reward_stats_ptr = self._update_reward_stats_from_replay(
                replay_buffer,
                reward_stats_ptr,
                int(replay_buffer.ptr[0]),
            )
            if trace_recorder:
                trace_recorder.add_slice(
                    "learner/update_reward_stats",
                    category="learner",
                    start_ns=_reward_stats_ns,
                    end_ns=time.perf_counter_ns(),
                )

            # -- train --
            iter_metrics = defaultdict(list)
            ptr_before = int(replay_buffer.ptr[0])
            collector_released_for_next = False
            learner = self.learner

            _sample_ns = time.perf_counter_ns()
            batch_ready = replay_pipeline.batch_ready(iteration, sample_count)
            _wait_batch_ns = time.perf_counter_ns()
            if not batch_ready:
                batch_ready = replay_pipeline.wait_until_ready(iteration, sample_count)
                # NOTE(review): slice assumed to be recorded only when a
                # wait actually happened — confirm.
                if trace_recorder:
                    trace_recorder.add_slice(
                        "learner/wait_for_replay_batch",
                        category="learner",
                        start_ns=_wait_batch_ns,
                        end_ns=time.perf_counter_ns(),
                        args={"iteration": iteration, "batch_ready": batch_ready},
                    )
            large_batch = replay_pipeline.sample_large_batch(
                tick_id=iteration,
                sample_count=sample_count,
            )
            learner_incremental_h2d_time = float(
                getattr(replay_pipeline, "last_incremental_h2d_time_s", 0.0)
            )
            if iteration < max_iterations:
                # Kick off the next tick's batch before training on this one.
                min_snapshot_ptr = int(replay_buffer.ptr[0]) + (
                    self.num_envs * self.env_steps_per_sync
                )
                replay_pipeline.start_prepare(
                    iteration + 1,
                    sample_count,
                    min_snapshot_ptr=min_snapshot_ptr,
                )
                # NOTE(review): early collector release assumed to belong
                # to this branch (paired with the prefetch) — confirm.
                if self.sync_collection and trainer_done_queue:
                    trainer_done_queue.put(1)
                    collector_released_for_next = True
                prepared_tick = iteration + 1
            if trace_recorder:
                trace_recorder.add_slice(
                    "learner/replay_sample",
                    category="learner",
                    start_ns=_sample_ns,
                    end_ns=time.perf_counter_ns(),
                    args={
                        "total_batch": sample_count,
                        "pipeline": "cpu_pinned_double_buffer",
                        "batch_ready": batch_ready,
                        "prefetch_mode": self.replay_prefetch_mode,
                        "replay_pack_layout": self.replay_pack_layout,
                        "replay_pack_executor": self.replay_pack_executor,
                        "replay_h2d_submitter": self.replay_h2d_submitter,
                        "replay_transfer_backend": self.replay_transfer_backend,
                        "prepared_tick": prepared_tick,
                        "explicit_compute_stream": False,
                    },
                )

            train_start = time.perf_counter()
            for update_idx in range(self.updates_per_step):
                # Each update consumes its own batch_size-row slice of the
                # large batch.
                s = update_idx * self.batch_size
                e = s + self.batch_size
                batch = {k: v[s:e] for k, v in large_batch.items()}
                _critic_ns = time.perf_counter_ns()
                critic_metrics = learner.update_critic(batch)
                if trace_recorder:
                    trace_recorder.add_slice(
                        "learner/update_critic",
                        category="learner",
                        start_ns=_critic_ns,
                        end_ns=time.perf_counter_ns(),
                        args={"update_idx": update_idx},
                    )
                for k, v in critic_metrics.items():
                    iter_metrics[k].append(v)
                # Actor update only on every policy_frequency-th update.
                if update_idx % self.policy_frequency == 0:
                    _actor_ns = time.perf_counter_ns()
                    actor_metrics = learner.update_actor(batch)
                    if trace_recorder:
                        trace_recorder.add_slice(
                            "learner/update_actor",
                            category="learner",
                            start_ns=_actor_ns,
                            end_ns=time.perf_counter_ns(),
                            args={"update_idx": update_idx},
                        )
                    for k, v in actor_metrics.items():
                        iter_metrics[k].append(v)
                # NOTE(review): target update assumed to run after every
                # update, not only on actor updates — confirm.
                _target_ns = time.perf_counter_ns()
                learner.soft_update_target()
                if trace_recorder:
                    trace_recorder.add_slice(
                        "learner/soft_update_target",
                        category="learner",
                        start_ns=_target_ns,
                        end_ns=time.perf_counter_ns(),
                        args={"update_idx": update_idx},
                    )
            replay_pipeline.after_tick()
            if self.obs_normalization and getattr(self.learner, "obs_normalizer", None) is not None:
                assert shared_obs_normalizer_stats is not None
                shared_obs_normalizer_stats.put(
                    (
                        self.learner.obs_normalizer.mean.cpu().numpy(),
                        self.learner.obs_normalizer.std.cpu().numpy(),
                    )
                )
            train_time = time.perf_counter() - train_start
            self.learner.update_count += 1

            _ws_ns = time.perf_counter_ns()
            weight_sync_start = time.perf_counter()
            weight_sync.write_weights(self.learner.actor.state_dict())
            weight_sync_time = time.perf_counter() - weight_sync_start
            if trace_recorder:
                trace_recorder.add_slice(
                    "learner/weight_sync_write",
                    category="learner",
                    start_ns=_ws_ns,
                    end_ns=time.perf_counter_ns(),
                    args={"mode": "sync"},
                )
                trace_recorder.add_counter(
                    "replay_size",
                    int(replay_buffer.size[0]),
                    category="replay",
                )
            # Release the collector here only if the prefetch branch above
            # did not already do so for this tick.
            if self.sync_collection and trainer_done_queue and not collector_released_for_next:
                trainer_done_queue.put(1)

            # Smoothed ratio of rows written during the tick to rows sampled.
            write_delta = int(replay_buffer.ptr[0]) - ptr_before
            write_read_ema = 0.9 * write_read_ema + 0.1 * (write_delta / max(sample_count, 1))
            logger.update_buffer_utilization(write_read_ema)

            avg_metrics = {k: statistics.mean(v) for k, v in iter_metrics.items() if v}
            mean_reward = statistics.mean(reward_history) if reward_history else 0.0
            last_mean_reward = float(mean_reward)
            best_mean_reward = max(best_mean_reward, last_mean_reward)
            # Log the first tick, the last tick and every
            # LEARNER_LOG_INTERVAL-th tick.
            if (
                iteration == 1
                or iteration == max_iterations
                or iteration % self.LEARNER_LOG_INTERVAL == 0
            ):
                logger.log_step(
                    iteration=iteration,
                    metrics=avg_metrics,
                    reward=mean_reward,
                    reward_metrics=build_reward_comparison_metrics(reward_history, mean_reward),
                    reward_components=latest_reward_components,
                    train_time=train_time,
                    wait_time=wait_time,
                    learner_incremental_h2d_time=learner_incremental_h2d_time,
                    weight_sync_time=weight_sync_time,
                    extra_info={
                        "throughput_steps": self.num_envs * self.env_steps_per_sync,
                    },
                )
            if save_interval > 0 and iteration % save_interval == 0:
                ckpt_path = os.path.join(log_dir, f"model_{iteration}.pt")
                torch.save(self.learner.get_state_dict(), ckpt_path)
                logger.log_save(ckpt_path)

        if trace_recorder:
            trace_recorder.add_slice(
                "learner/training_e2e",
                category="learner",
                start_ns=training_e2e_start_ns,
                end_ns=time.perf_counter_ns(),
                args={
                    "iterations": iteration,
                    "pipeline": "cpu_pinned_double_buffer",
                    "replay_h2d_submitter": self.replay_h2d_submitter,
                    "replay_transfer_backend": self.replay_transfer_backend,
                    "learner_log_interval": self.LEARNER_LOG_INTERVAL,
                },
            )
    finally:
        replay_pipeline.close()

    # -- finalize --
    ckpt_path = os.path.join(log_dir, f"model_{max_iterations}.pt")
    torch.save(self.learner.get_state_dict(), ckpt_path)
    logger.log_save(ckpt_path)
    logger.finish()
    if trace_recorder and trace_output_path:
        trace_recorder.write_json(trace_output_path)
        print(f"[DoubleBufferRunner] Perfetto trace written to {trace_output_path}")
    # Rewards are reported as None when no episode reward was ever received.
    self.last_run_summary = self._make_summary(
        "completed",
        iteration,
        logger,
        last_mean_reward if reward_history else None,
        best_mean_reward if reward_history else None,
        ckpt_path,
        train_start_wall,
        str(trace_output_path) if trace_output_path else None,
    )
@staticmethod
def _make_summary(
    status,
    iteration,
    logger,
    final_reward,
    best_reward,
    ckpt_path,
    train_start_wall,
    trace_path,
) -> dict:
    """Assemble the run-summary dictionary for a finished or aborted run.

    ``total_env_steps`` and ``mean_episode_length`` are read from the
    logger's ``_total_steps`` and ``_mean_ep_length`` attributes and cast to
    ``int`` / ``float``. ``training_wall_time_sec`` is ``time.time()`` minus
    ``train_start_wall``. Every other argument is stored as given.
    """
    env_steps = int(logger._total_steps)
    episode_length = float(logger._mean_ep_length)
    elapsed_sec = time.time() - train_start_wall
    return dict(
        status=status,
        completed_iterations=iteration,
        total_env_steps=env_steps,
        final_mean_reward=final_reward,
        best_mean_reward=best_reward,
        mean_episode_length=episode_length,
        last_checkpoint=ckpt_path,
        trace_path=trace_path,
        training_wall_time_sec=elapsed_sec,
    )