"""Shared rollout IPC ring buffer for APPO / async PPO."""
from __future__ import annotations
import multiprocessing as mp
from multiprocessing import shared_memory
from typing import Dict
import numpy as np
_SPAWN_CTX = mp.get_context("spawn")
_FIELD_SHAPES = {
"obs": lambda ns_slots, ne, ns, od, ad, cd: (ns_slots, ne, ns, od),
"critic": lambda ns_slots, ne, ns, od, ad, cd: (ns_slots, ne, ns, cd),
"actions": lambda ns_slots, ne, ns, od, ad, cd: (ns_slots, ne, ns, ad),
"log_probs": lambda ns_slots, ne, ns, od, ad, cd: (ns_slots, ne, ns),
"rewards": lambda ns_slots, ne, ns, od, ad, cd: (ns_slots, ne, ns),
"dones": lambda ns_slots, ne, ns, od, ad, cd: (ns_slots, ne, ns),
"truncated": lambda ns_slots, ne, ns, od, ad, cd: (ns_slots, ne, ns),
"last_obs": lambda ns_slots, ne, ns, od, ad, cd: (ns_slots, ne, od),
"last_critic": lambda ns_slots, ne, ns, od, ad, cd: (ns_slots, ne, cd),
}
class RolloutRingBuffer:
    """N-slot shared-memory ring buffer for raw rollout payloads.

    Every field in ``_FIELD_SHAPES`` gets its own ``SharedMemory`` block holding a
    float32 array whose leading axis is the slot index. A producer fills
    ``write_buffer`` and calls ``signal_write_done``; a consumer reads the slot at
    ``read_slot`` (``read_numpy_views`` / ``read_torch``) and calls
    ``advance_read``. The writer is never blocked: once it gets more than
    ``num_slots`` writes ahead, the read pointer is clamped forward so it never
    refers to data older than the newest ``num_slots`` writes.

    An instance built with ``create=True`` allocates the blocks and the shared
    write/read counters. An instance built with ``create=False`` attaches to
    existing blocks by name (``shm_name_prefix``, e.g. the creator's ``name``) and
    must be given the shared counters via ``attach_sync_primitives`` before any
    pointer-based method is used.
    """

    def __init__(
        self,
        num_envs: int,
        num_steps: int,
        obs_dim: int,
        action_dim: int,
        *,
        critic_dim: int = 0,
        num_slots: int = 4,
        create: bool = True,
        shm_name_prefix: Dict[str, str] | None = None,
    ):
        """Allocate (or attach to) one shared-memory block per rollout field.

        Args:
            num_envs: Size of the environment axis.
            num_steps: Size of the time-step axis.
            obs_dim: Width of the ``obs`` / ``last_obs`` feature axis.
            action_dim: Width of the ``actions`` feature axis.
            critic_dim: Width of the ``critic`` / ``last_critic`` feature axis;
                0 omits those two fields entirely.
            num_slots: Number of rollouts the ring holds.
            create: Create new blocks instead of attaching to existing ones.
            shm_name_prefix: Mapping field -> existing block name (as returned by
                ``name``); required when ``create`` is False.

        Raises:
            ValueError: ``create`` is False and ``shm_name_prefix`` is None.
            KeyError: ``shm_name_prefix`` has no entry for an allocated field.

        If allocation/attachment fails part-way, every block opened so far is
        closed (and unlinked when it was created here) before the error
        propagates, so no segments are leaked.
        """
        self.num_envs = num_envs
        self.num_steps = num_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.critic_dim = critic_dim
        self.num_slots = num_slots
        self._shm_blocks: Dict[str, shared_memory.SharedMemory] = {}
        self._arrays: Dict[str, np.ndarray] = {}
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not create and shm_name_prefix is None:
            raise ValueError("shm_name_prefix required when create=False")

        fields_to_allocate = dict(_FIELD_SHAPES)
        if critic_dim == 0:
            # No critic output: skip the critic-sized fields altogether.
            fields_to_allocate.pop("critic", None)
            fields_to_allocate.pop("last_critic", None)

        itemsize = np.dtype(np.float32).itemsize
        try:
            for field, shape_fn in fields_to_allocate.items():
                shape = shape_fn(
                    num_slots,
                    num_envs,
                    num_steps,
                    obs_dim,
                    action_dim,
                    critic_dim,
                )
                if create:
                    nbytes = int(np.prod(shape)) * itemsize
                    # SharedMemory(create=True) rejects size 0, so degenerate
                    # shapes still get a 1-byte block.
                    shm = shared_memory.SharedMemory(create=True, size=max(nbytes, 1))
                else:
                    shm = shared_memory.SharedMemory(
                        name=shm_name_prefix[field], create=False
                    )
                # Register the block before wrapping it so that a failure while
                # building the array view still releases it below.
                self._shm_blocks[field] = shm
                self._arrays[field] = np.ndarray(shape, dtype=np.float32, buffer=shm.buf)
            if create:
                # Monotonic counters of completed writes / consumed reads; slot
                # indices are these values modulo num_slots.
                self._write_ptr = _SPAWN_CTX.Value("l", 0)
                self._read_ptr = _SPAWN_CTX.Value("l", 0)
        except BaseException:
            self._release_blocks(unlink=create)
            raise

    @property
    def name(self) -> Dict[str, str]:
        """Field -> shared-memory block name; pass as ``shm_name_prefix`` to attach."""
        return {field: shm.name for field, shm in self._shm_blocks.items()}

    @property
    def slot_shapes(self) -> Dict[str, tuple[int, ...]]:
        """Per-slot shape of each field (the leading slot axis dropped)."""
        return {field: tuple(arr.shape[1:]) for field, arr in self._arrays.items()}

    def attach_sync_primitives(self, write_ptr, read_ptr) -> None:
        """Install the shared write/read counters on this instance.

        Required for ``create=False`` instances, which do not build their own.
        Both objects must expose ``.value`` and ``.get_lock()``, like the
        ``multiprocessing.Value`` counters a ``create=True`` instance builds.
        """
        self._write_ptr = write_ptr
        self._read_ptr = read_ptr

    def _clamp_read_ptr_to_valid_window(self) -> None:
        """Move the read pointer forward if the writer has lapped it.

        Only the newest ``num_slots`` writes are considered readable. The
        unlocked comparison is a fast path; the condition is re-checked under
        the read-pointer lock before updating.
        """
        wp = int(self._write_ptr.value)
        oldest_available = max(0, wp - self.num_slots)
        if int(self._read_ptr.value) >= oldest_available:
            return
        with self._read_ptr.get_lock():
            if int(self._read_ptr.value) < oldest_available:
                self._read_ptr.value = oldest_available

    @property
    def write_slot(self) -> int:
        """Slot index the next write goes into (write counter modulo num_slots)."""
        return int(self._write_ptr.value) % self.num_slots

    @property
    def write_buffer(self) -> Dict[str, np.ndarray]:
        """Writable shared-memory views of every field for the current write slot.

        NOTE(review): the writer is never blocked. When the ring is full
        (``write_ptr - read_ptr == num_slots`` after clamping), the write slot
        equals the read slot, so filling this buffer can overwrite a slot a
        reader is still copying. Confirm producers/consumers tolerate or
        prevent this.
        """
        s = self.write_slot
        return {field: arr[s] for field, arr in self._arrays.items()}

    def signal_write_done(self) -> None:
        """Publish the current write slot by bumping the write counter under its lock."""
        with self._write_ptr.get_lock():
            self._write_ptr.value += 1

    def available(self) -> int:
        """Return how many published, unconsumed slots are readable (0..num_slots)."""
        self._clamp_read_ptr_to_valid_window()
        pending = int(self._write_ptr.value) - int(self._read_ptr.value)
        return min(max(0, pending), self.num_slots)

    def wait_for_data(self, timeout: float = 60.0) -> bool:
        """Poll (sleeping 1 ms between checks) until at least one slot is readable.

        Returns:
            True once data is available, False if ``timeout`` seconds pass first.
        """
        import time

        deadline = time.monotonic() + timeout
        while self.available() == 0:
            if time.monotonic() > deadline:
                return False
            time.sleep(0.001)
        return True

    @property
    def read_slot(self) -> int:
        """Slot index of the oldest readable rollout (read counter modulo num_slots)."""
        self._clamp_read_ptr_to_valid_window()
        return int(self._read_ptr.value) % self.num_slots

    def read_numpy_views(self) -> dict[str, np.ndarray]:
        """Return shared-memory views for the current read slot.

        The returned arrays are borrowed views. Consumers must copy them into
        owned storage before calling advance_read().
        """
        s = self.read_slot
        return {field: arr[s] for field, arr in self._arrays.items()}

    def copy_read_slot_to_torch(self, destination: dict) -> None:
        """Copy every field of the current read slot into preallocated tensors.

        Args:
            destination: Mapping field -> float32 torch tensor whose shape matches
                that field's per-slot shape. Extra keys are ignored.

        Raises:
            KeyError: A field has no destination tensor.
            ValueError: A destination tensor has the wrong shape.
            TypeError: A destination tensor is not ``torch.float32``.
        """
        import torch

        # Resolve the slot once so all fields are copied from the same rollout.
        s = self.read_slot
        for field, arr in self._arrays.items():
            if field not in destination:
                raise KeyError(f"missing destination tensor for rollout field {field!r}")
            dst = destination[field]
            src_view = arr[s]
            if tuple(dst.shape) != tuple(src_view.shape):
                raise ValueError(
                    f"destination shape mismatch for {field!r}: "
                    f"expected {tuple(src_view.shape)}, got {tuple(dst.shape)}"
                )
            if dst.dtype != torch.float32:
                raise TypeError(f"destination tensor for {field!r} must be torch.float32")
            # Synchronous copy: the shared slot may be reused after advance_read().
            dst.copy_(torch.from_numpy(src_view), non_blocking=False)

    def read_torch(self, device: str) -> dict:
        """Return freshly allocated float32 tensors on ``device`` holding the read slot."""
        import torch

        result = {
            field: torch.empty(tuple(arr.shape[1:]), dtype=torch.float32, device=device)
            for field, arr in self._arrays.items()
        }
        self.copy_read_slot_to_torch(result)
        return result

    def advance_read(self) -> None:
        """Consume the current read slot.

        The read counter never moves past the write counter and is never left
        more than ``num_slots`` behind it.
        """
        with self._read_ptr.get_lock():
            wp = int(self._write_ptr.value)
            rp = min(int(self._read_ptr.value) + 1, wp)
            oldest_available = max(0, wp - self.num_slots)
            self._read_ptr.value = max(rp, oldest_available)

    def _release_blocks(self, *, unlink: bool) -> None:
        """Best-effort close (and optionally unlink) of every shared-memory block."""
        # The numpy arrays hold buffer exports on each shm.buf; while they are
        # alive SharedMemory.close() raises BufferError. Drop our references first.
        self._arrays.clear()
        for shm in self._shm_blocks.values():
            try:
                shm.close()
            except (BufferError, OSError):
                # Views still held by callers, or an OS-level close failure:
                # keep going so unlinking and the remaining blocks still happen.
                pass
            if unlink:
                try:
                    shm.unlink()
                except OSError:
                    # Already unlinked (e.g. repeated cleanup()).
                    pass

    def cleanup(self) -> None:
        """Close this process's mappings and unlink every block.

        Presumably called by the creating process at shutdown. Unlinking is
        attempted even if closing fails (e.g. a caller still holds a view from
        ``write_buffer`` / ``read_numpy_views``, which keeps that mapping alive).
        Safe to call more than once.
        """
        self._release_blocks(unlink=True)

    def close(self) -> None:
        """Close this process's mappings without unlinking the blocks."""
        self._release_blocks(unlink=False)