"""Double-buffer replay pipeline for packed CPU replay samples.
CUDA uses the original CPU-pinned shared slot + native async H2D fast path.
Non-CUDA devices use the same collector-thread packing contract and a
portable torch copy into the device batch slots.
"""
from __future__ import annotations
import csv
import os
import queue
import threading
import time
from typing import Dict, List, Tuple
import torch
from unilab.ipc.replay_buffer import ReplayBuffer
from unilab.ipc.replay_pipelines.base import ReplayTickMetadata
from unilab.ipc.replay_pipelines.transfer import build_replay_transfer_backend
[docs]
class CPUPinnedDoubleBufferReplayPipeline:
"""Double-buffered packed replay batch pipeline.
CUDA keeps the pinned-host → GPU fast path. MPS/CPU keep the same
collector-thread pack and hot/cold batch contract with a portable torch
copy into the learner device slot.
"""
[docs]
def __init__(
self,
replay_buffer: ReplayBuffer,
*,
device: str,
sample_count: int,
base_seed: int = 0,
trace_recorder=None,
trace_cuda_events: bool = True,
verbose: bool = False,
verbose_output_dir: str | None = None,
collector_pack_request_queue=None,
collector_pack_ready_queue=None,
collector_pack_shared_slots=None,
) -> None:
self._replay_buffer = replay_buffer
self._device = torch.device(device)
self._device_type = self._device.type
self._sample_count = sample_count
self._base_seed = base_seed
self._trace_recorder = trace_recorder
self._trace_cuda_events = bool(trace_cuda_events) and self._device_type == "cuda"
self._verbose = bool(verbose)
self._verbose_output_dir = verbose_output_dir if self._verbose else None
self._pack_layout = "packed"
self._pack_executor = "collector_thread"
self._ring_depth = 2
self._transfer_backend = build_replay_transfer_backend(
device=self._device,
ring_depth=self._ring_depth,
)
self._trace_cuda_events = bool(trace_cuda_events) and (
self._transfer_backend.supports_timing_events
)
self._h2d_submitter = self._transfer_backend.h2d_submitter
self._host_pinned = self._transfer_backend.host_pinned
self._direct_pinned_shared = self._transfer_backend.direct_pinned_shared
self._device_family = self._transfer_backend.device_family
self._host_memory_kind = self._transfer_backend.host_memory_kind
self._supports_async_submit = self._transfer_backend.supports_async_submit
if not getattr(replay_buffer, "_packed_cpu_storage", False):
raise ValueError("pack_layout='packed' requires ReplayBuffer(packed_cpu_storage=True)")
if (
collector_pack_request_queue is None
or collector_pack_ready_queue is None
or collector_pack_shared_slots is None
):
raise ValueError(
"collector_thread pack executor requires collector pack IPC queues and slots"
)
self._verbose_pack_records: List[Tuple[int, int, str, int, int, int, int]] | None = (
[] if self._verbose else None
)
self._collector_pack_request_queue = collector_pack_request_queue
self._collector_pack_ready_queue = collector_pack_ready_queue
self._collector_pack_shared_slots = collector_pack_shared_slots
self._fields: Dict[str, tuple[torch.Tensor, int]] = {}
self._packed_width = int(replay_buffer._storage.shape[1])
self._host_packed: list[torch.Tensor] = []
self._register_collector_shared_slots()
self._gpu_packed = self._transfer_backend.allocate_device_slots(
count=self._ring_depth,
shape=(self._sample_count, self._packed_width),
dtype=torch.float32,
)
self._host: list[Dict[str, torch.Tensor]] = []
self._gpu: list[Dict[str, torch.Tensor]] = []
self._hot = 0
self._cold = 1
self._has_hot_batch = False
self._hot_metadata: ReplayTickMetadata | None = None
self._prepared_metadata: ReplayTickMetadata | None = None
self._prepare_tick_id: int | None = None
self._prepare_state = "idle"
self._prepare_error: BaseException | None = None
self.last_incremental_h2d_time_s = 0.0
self._prepare_condition = threading.Condition()
self._closed = False
self._collector_h2d_thread = threading.Thread(
target=self._collector_h2d_worker,
name="replay_collector_h2d",
daemon=True,
)
self._collector_h2d_thread.start()
@property
def h2d_submitter(self) -> str:
return self._h2d_submitter
@property
def transfer_manifest(self) -> dict[str, object]:
return {
"backend": type(self._transfer_backend).__name__,
"device": str(self._device),
"device_family": self._device_family,
"host_memory_kind": self._host_memory_kind,
"host_pinned": self._host_pinned,
"direct_pinned_shared": self._direct_pinned_shared,
"supports_async_submit": self._supports_async_submit,
"supports_timing_events": self._transfer_backend.supports_timing_events,
"h2d_submitter": self._h2d_submitter,
"ring_depth": self._ring_depth,
}
# -- allocation helpers --------------------------------------------------
def _register_collector_shared_slots(self) -> None:
assert self._collector_pack_shared_slots is not None
self._transfer_backend.register_host_slots(self._collector_pack_shared_slots)
self._host_pinned = self._transfer_backend.host_pinned
self._direct_pinned_shared = self._transfer_backend.direct_pinned_shared
def _unregister_collector_shared_slots(self) -> None:
self._transfer_backend.close()
self._host_pinned = self._transfer_backend.host_pinned
self._direct_pinned_shared = self._transfer_backend.direct_pinned_shared
def _packed_h2d_source(self, slot: int) -> torch.Tensor:
assert self._collector_pack_shared_slots is not None
return self._collector_pack_shared_slots[slot]
def _packed_batch_view(self, packed: torch.Tensor) -> Dict[str, torch.Tensor]:
rb = self._replay_buffer
batch = {
"obs": packed[:, rb._obs_sl],
"next_obs": packed[:, rb._nobs_sl],
"actions": packed[:, rb._act_sl],
"rewards": packed[:, rb._rew_col],
"dones": packed[:, rb._done_col],
"truncated": packed[:, rb._trunc_col],
}
if rb._critic_dim > 0:
batch["critic"] = packed[:, rb._critic_sl]
batch["next_critic"] = packed[:, rb._ncritic_sl]
return batch
# -- snapshot / H2D --------------------------------------------------------
def _snapshot(self) -> tuple[int, int]:
return int(self._replay_buffer.ptr[0]), int(self._replay_buffer.size[0])
def _submit_h2d(self, slot: int, metadata: ReplayTickMetadata | None = None) -> float:
self._clear_ready(slot)
return self._transfer_backend.submit_h2d(
slot=slot,
dst=self._gpu_packed[slot],
src=self._packed_h2d_source(slot),
metadata=metadata,
trace_recorder=self._trace_recorder,
trace_cuda_events=self._trace_cuda_events,
h2d_bytes=self._h2d_bytes(),
pack_layout=self._pack_layout,
pack_executor=self._pack_executor,
)
def _submit_collector_packed_h2d(self, ready: dict) -> ReplayTickMetadata:
metadata = ReplayTickMetadata(
tick_id=int(ready["tick_id"]),
snapshot_ptr=int(ready["snapshot_ptr"]),
snapshot_size=int(ready["snapshot_size"]),
sample_seed=int(ready["sample_seed"]),
sample_count=int(ready["sample_count"]),
batch_host_slot=int(ready["shared_slot"]),
batch_gpu_slot=int(ready["target_gpu_slot"]),
)
slot = metadata.batch_gpu_slot
assert slot is not None
shared_slot = int(ready["shared_slot"])
if shared_slot != slot:
raise RuntimeError("collector_thread shared slot must match target GPU slot")
self._submit_device_transfer(metadata)
return metadata
def _submit_device_transfer(self, metadata: ReplayTickMetadata) -> None:
slot = metadata.batch_gpu_slot
shared_slot = metadata.batch_host_slot
assert slot is not None
assert shared_slot is not None
h2d_submit_ns = time.perf_counter_ns()
self.last_incremental_h2d_time_s = self._submit_h2d(slot, metadata)
if self._trace_recorder is not None:
self._trace_recorder.add_slice(
"replay_pipeline/batch_h2d_submit",
category="replay_pipeline",
start_ns=h2d_submit_ns,
end_ns=time.perf_counter_ns(),
args={
"tick_id": metadata.tick_id,
"batch_gpu_slot": slot,
"shared_slot": shared_slot,
"pack_layout": self._pack_layout,
"pack_executor": self._pack_executor,
"h2d_submitter": self._h2d_submitter,
"h2d_bytes": self._h2d_bytes(),
"h2d_submitted": True,
"pinned_memory": self._host_pinned,
"direct_pinned_shared": self._direct_pinned_shared,
"device_family": self._device_family,
"host_memory_kind": self._host_memory_kind,
"supports_async_submit": self._supports_async_submit,
"supports_timing_events": self._transfer_backend.supports_timing_events,
"ring_depth": self._ring_depth,
"transfer_worker_submit": True,
},
)
def _ensure_device_transfer_ready(self, metadata: ReplayTickMetadata) -> None:
slot = metadata.batch_gpu_slot
assert slot is not None
if self._ready_query(slot):
return
self._synchronize_ready(slot)
def _collector_h2d_worker(self) -> None:
while True:
if self._closed:
return
try:
ready = self._collector_pack_ready_queue.get(timeout=0.1)
except queue.Empty:
continue
if ready is None:
return
try:
metadata = self._submit_collector_packed_h2d(ready)
with self._prepare_condition:
if self._prepare_tick_id != metadata.tick_id:
raise RuntimeError(
f"Collector packed tick {metadata.tick_id} does not match "
f"pending tick {self._prepare_tick_id}"
)
self._prepared_metadata = metadata
self._prepare_state = "h2d_submitted"
self._prepare_error = None
self._prepare_condition.notify_all()
except BaseException as exc:
with self._prepare_condition:
self._prepare_error = exc
self._prepare_condition.notify_all()
def _h2d_bytes(self) -> int:
source = self._packed_h2d_source(0)
return int(source.numel() * source.element_size())
def _clear_ready(self, slot: int) -> None:
self._transfer_backend.clear_ready(slot)
def _ready_query(self, slot: int) -> bool:
return self._transfer_backend.ready_query(slot)
def _synchronize_ready(self, slot: int) -> None:
self._transfer_backend.synchronize_ready(slot)
def _wait_current_stream_for_ready(self, slot: int) -> None:
self._transfer_backend.wait_current_stream_for_ready(slot)
wait_copy_time_s = float(getattr(self._transfer_backend, "last_wait_copy_time_s", 0.0))
if wait_copy_time_s > 0.0:
self.last_incremental_h2d_time_s = wait_copy_time_s
# -- public API -----------------------------------------------------------
def _validate_sample_count(self, sample_count: int) -> None:
if int(sample_count) != int(self._sample_count):
raise ValueError("sample_count must match the value used to allocate the double buffer")
def _refresh_prepare_state(self) -> None:
if self._prepare_error is not None:
raise self._prepare_error
if self._prepared_metadata is not None:
slot = self._prepared_metadata.batch_gpu_slot
if slot is not None and self._ready_query(slot):
self._prepare_state = "ready"
[docs]
def start_prepare(
self,
tick_id: int,
sample_count: int,
min_snapshot_ptr: int | None = None,
) -> bool:
"""Start CPU pack + device transfer for the current cold slot.
Returns True when this call launches new work. If the same tick is
already pending or prepared, returns False.
"""
self._validate_sample_count(sample_count)
if self._closed:
raise RuntimeError("Cannot prepare replay batch after pipeline.close()")
self._refresh_prepare_state()
active_tick = self._prepare_tick_id
if self._prepared_metadata is not None or self._prepare_state not in {"idle", "ready"}:
prepared_tick = (
self._prepared_metadata.tick_id
if self._prepared_metadata is not None
else active_tick
)
if prepared_tick == int(tick_id):
return False
raise RuntimeError(
"Cannot prepare a new replay batch before the previous batch is consumed"
)
slot = self._cold
self._clear_ready(slot)
self._prepare_tick_id = int(tick_id)
self._prepare_error = None
snapshot_ptr, snapshot_size = self._snapshot()
sample_seed = self._base_seed + int(tick_id)
min_snapshot_ptr = snapshot_ptr if min_snapshot_ptr is None else int(min_snapshot_ptr)
request = {
"tick_id": int(tick_id),
"snapshot_ptr": snapshot_ptr,
"snapshot_size": snapshot_size,
"min_snapshot_ptr": min_snapshot_ptr,
"sample_seed": sample_seed,
"sample_count": self._sample_count,
"shared_slot": slot,
"learner_hot_gpu_slot": self._hot,
"target_gpu_slot": slot,
"pack_layout": self._pack_layout,
"pack_executor": self._pack_executor,
}
if self._trace_recorder is not None:
_req_ns = time.perf_counter_ns()
self._trace_recorder.add_slice(
"replay_pipeline/collector_pack_request",
category="replay_pipeline",
start_ns=_req_ns,
end_ns=time.perf_counter_ns(),
args=request,
)
self._prepare_state = "collector_pack_requested"
self._collector_pack_request_queue.put(request)
return True
[docs]
def batch_ready(self, tick_id: int, sample_count: int) -> bool:
self._validate_sample_count(sample_count)
if self._has_hot_batch:
if self._hot_metadata is not None and self._hot_metadata.tick_id != int(tick_id):
return False
return True
self._refresh_prepare_state()
if self._prepared_metadata is None:
return False
if self._prepared_metadata.tick_id != int(tick_id):
return False
return self._prepare_state == "ready"
[docs]
def wait_ready(self) -> None:
return None
[docs]
def wait_until_ready(self, tick_id: int, sample_count: int) -> bool:
self._validate_sample_count(sample_count)
metadata = self._prepared_or_wait(tick_id)
slot = metadata.batch_gpu_slot
assert slot is not None
self._ensure_device_transfer_ready(metadata)
self._synchronize_ready(slot)
self._prepare_state = "ready"
return True
def _prepared_or_wait(self, tick_id: int) -> ReplayTickMetadata:
self._refresh_prepare_state()
if self._prepared_metadata is None:
if self._prepare_tick_id is None:
self.start_prepare(tick_id, self._sample_count)
with self._prepare_condition:
while self._prepared_metadata is None and self._prepare_error is None:
self._prepare_condition.wait(timeout=0.1)
if self._prepare_error is not None:
raise self._prepare_error
assert self._prepared_metadata is not None
return self._prepared_metadata
if self._prepared_metadata.tick_id != int(tick_id):
raise RuntimeError(
f"Prepared replay batch tick {self._prepared_metadata.tick_id} "
f"does not match requested tick {tick_id}"
)
return self._prepared_metadata
[docs]
def sample_large_batch(self, tick_id: int, sample_count: int) -> Dict[str, torch.Tensor]:
self._validate_sample_count(sample_count)
if self._has_hot_batch:
if self._hot_metadata is not None and self._hot_metadata.tick_id != int(tick_id):
raise RuntimeError(
f"Hot batch tick {self._hot_metadata.tick_id} does not match "
f"requested tick {tick_id}"
)
return self._packed_batch_view(self._gpu_packed[self._hot])
if not self._has_hot_batch:
if not self.batch_ready(tick_id, sample_count):
self.wait_until_ready(tick_id, sample_count)
metadata = self._prepared_or_wait(tick_id)
slot = metadata.batch_gpu_slot
assert slot is not None
_t0 = time.perf_counter_ns()
self._wait_current_stream_for_ready(slot)
if self._trace_recorder is not None:
_wait_end = time.perf_counter_ns()
self._trace_recorder.add_slice(
"replay_pipeline/batch_h2d_wait",
category="replay_pipeline",
start_ns=_t0,
end_ns=_wait_end,
args={"tick_id": tick_id, "batch_gpu_slot": slot},
)
self._trace_recorder.add_slice(
"replay_pipeline/gpu_wait_for_batch",
category="replay_pipeline",
start_ns=_t0,
end_ns=_wait_end,
args={"tick_id": tick_id, "batch_gpu_slot": slot},
)
_swap_ns = time.perf_counter_ns()
old_hot = self._hot
old_cold = self._cold
if slot != self._cold:
raise RuntimeError("Prepared replay batch is not in the current cold slot")
self._hot, self._cold = self._cold, self._hot
if self._trace_recorder is not None:
self._trace_recorder.add_slice(
"replay_pipeline/hot_cold_swap",
category="replay_pipeline",
start_ns=_swap_ns,
end_ns=time.perf_counter_ns(),
args={
"tick_id": tick_id,
"old_hot": old_hot,
"old_cold": old_cold,
"new_hot": self._hot,
"new_cold": self._cold,
},
)
self._has_hot_batch = True
self._hot_metadata = metadata
self._prepared_metadata = None
self._prepare_tick_id = None
self._prepare_state = "idle"
return self._packed_batch_view(self._gpu_packed[self._hot])
[docs]
def after_tick(self) -> None:
self._has_hot_batch = False
self._hot_metadata = None
[docs]
def close(self) -> None:
self._closed = True
if self._collector_pack_ready_queue is not None:
try:
self._collector_pack_ready_queue.put_nowait(None)
except Exception:
pass
if self._collector_h2d_thread is not None:
self._collector_h2d_thread.join(timeout=2.0)
if self._prepared_metadata is not None:
slot = self._prepared_metadata.batch_gpu_slot
if slot is not None:
self._synchronize_ready(slot)
self._unregister_collector_shared_slots()
if self._verbose and self._verbose_output_dir and self._verbose_pack_records:
try:
verbose_dir = os.path.join(self._verbose_output_dir, "verbose")
os.makedirs(verbose_dir, exist_ok=True)
csv_path = os.path.join(verbose_dir, "pack_fields.csv")
with open(csv_path, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["tick_id", "slot", "field", "rows", "cols", "bytes", "dur_ns"])
for row in self._verbose_pack_records:
writer.writerow(row)
except OSError:
pass
self._host.clear()
self._gpu.clear()
if hasattr(self, "_host_packed"):
self._host_packed.clear()
if hasattr(self, "_gpu_packed"):
self._gpu_packed.clear()