# Source code for unilab.ipc.rollout_ring_buffer

"""Shared rollout IPC ring buffer for APPO / async PPO."""

from __future__ import annotations

import multiprocessing as mp
from multiprocessing import shared_memory
from typing import Dict

import numpy as np

# Multiprocessing context used to create the shared read/write pointers.
_SPAWN_CTX = mp.get_context("spawn")

# Shape builders for every rollout field. Each callable is invoked
# positionally as (num_slots, num_envs, num_steps, obs_dim, action_dim,
# critic_dim) and returns the full array shape; the leading axis always
# indexes the ring slot. Per-step fields carry a num_steps axis, while the
# "last_*" fields have no step axis (one entry per env).
_FIELD_SHAPES = {
    "obs": lambda slots, envs, steps, o_dim, a_dim, c_dim: (slots, envs, steps, o_dim),
    "critic": lambda slots, envs, steps, o_dim, a_dim, c_dim: (slots, envs, steps, c_dim),
    "actions": lambda slots, envs, steps, o_dim, a_dim, c_dim: (slots, envs, steps, a_dim),
    "log_probs": lambda slots, envs, steps, o_dim, a_dim, c_dim: (slots, envs, steps),
    "rewards": lambda slots, envs, steps, o_dim, a_dim, c_dim: (slots, envs, steps),
    "dones": lambda slots, envs, steps, o_dim, a_dim, c_dim: (slots, envs, steps),
    "truncated": lambda slots, envs, steps, o_dim, a_dim, c_dim: (slots, envs, steps),
    "last_obs": lambda slots, envs, steps, o_dim, a_dim, c_dim: (slots, envs, o_dim),
    "last_critic": lambda slots, envs, steps, o_dim, a_dim, c_dim: (slots, envs, c_dim),
}


class RolloutRingBuffer:
    """N-slot shared-memory ring buffer for raw rollout payloads.

    Every rollout field lives in its own ``SharedMemory`` segment, viewed as a
    float32 array whose leading axis indexes the ring slot (shapes come from
    ``_FIELD_SHAPES``). Two shared counters track progress: the write pointer
    counts completed writes, the read pointer counts consumed slots, and the
    physical slot for either is ``pointer % num_slots``.

    The writer never blocks. Once it is ``num_slots`` ahead of the reader,
    the read pointer is clamped forward so the readable window is always
    ``[write_ptr - num_slots, write_ptr)``.

    NOTE(review): while the reader lags by a full ``num_slots``, the oldest
    readable slot is the same physical slot that ``write_buffer`` points at.
    Nothing here detects the writer overwriting a slot mid-read (there is no
    sequence check), so a slow reader can observe a torn rollout. Confirm
    this lossy behaviour is acceptable for the caller.
    """

    def __init__(
        self,
        num_envs: int,
        num_steps: int,
        obs_dim: int,
        action_dim: int,
        *,
        critic_dim: int = 0,
        num_slots: int = 4,
        create: bool = True,
        shm_name_prefix: Dict[str, str] | None = None,
    ):
        """Allocate (``create=True``) or attach to (``create=False``) the segments.

        Args:
            num_envs, num_steps, obs_dim, action_dim: per-slot array dimensions.
            critic_dim: width of the ``critic``/``last_critic`` fields; 0 skips
                allocating those two fields entirely.
            num_slots: number of ring slots.
            create: when True, create fresh segments plus the shared pointers.
                When False, attach to existing segments named by
                ``shm_name_prefix``; ``attach_sync_primitives`` must then be
                called before any pointer-based method is used.
            shm_name_prefix: mapping field -> segment name (the shape returned
                by :attr:`name`); required when ``create`` is False.

        Raises:
            ValueError: ``create`` is False and ``shm_name_prefix`` is None.

        NOTE(review): on CPython < 3.13, attaching with ``create=False`` also
        registers the segment with this process's resource tracker, which may
        warn about or unlink it at exit. Confirm on the target Python version.
        """
        self.num_envs = num_envs
        self.num_steps = num_steps
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.critic_dim = critic_dim
        self.num_slots = num_slots
        self._shm_blocks: Dict[str, shared_memory.SharedMemory] = {}
        self._arrays: Dict[str, np.ndarray] = {}

        fields_to_allocate = dict(_FIELD_SHAPES)
        if critic_dim == 0:
            fields_to_allocate.pop("critic", None)
            fields_to_allocate.pop("last_critic", None)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not create and shm_name_prefix is None:
            raise ValueError("shm_name_prefix required when create=False")

        itemsize = np.dtype(np.float32).itemsize
        try:
            for field, shape_fn in fields_to_allocate.items():
                shape = shape_fn(
                    num_slots, num_envs, num_steps, obs_dim, action_dim, critic_dim,
                )
                nbytes = int(np.prod(shape)) * itemsize
                if create:
                    # SharedMemory refuses size 0, so degenerate shapes get 1 byte.
                    shm = shared_memory.SharedMemory(create=True, size=max(nbytes, 1))
                else:
                    shm = shared_memory.SharedMemory(
                        name=shm_name_prefix[field], create=False,
                    )
                self._shm_blocks[field] = shm
                self._arrays[field] = np.ndarray(shape, dtype=np.float32, buffer=shm.buf)
        except BaseException:
            # Don't leak the segments created/attached before the failure.
            # Only the creating side unlinks.
            self._release_blocks(unlink=create)
            raise

        if create:
            self._write_ptr = _SPAWN_CTX.Value("l", 0)
            self._read_ptr = _SPAWN_CTX.Value("l", 0)

    @property
    def name(self) -> Dict[str, str]:
        """Mapping field -> shared-memory segment name, for ``shm_name_prefix``."""
        return {field: shm.name for field, shm in self._shm_blocks.items()}

    @property
    def slot_shapes(self) -> Dict[str, tuple[int, ...]]:
        """Per-field array shape of a single slot (leading slot axis dropped)."""
        return {field: tuple(arr.shape[1:]) for field, arr in self._arrays.items()}

    def attach_sync_primitives(self, write_ptr, read_ptr) -> None:
        """Install the shared write/read pointers.

        Required after ``create=False``, because only the creating instance
        allocates them. Both objects must expose ``.value`` and ``.get_lock()``
        (presumably the owner's ``multiprocessing.Value`` pair; verify at the
        call site).
        """
        self._write_ptr = write_ptr
        self._read_ptr = read_ptr

    def _clamp_read_ptr_to_valid_window(self) -> None:
        """Move the read pointer up to ``write_ptr - num_slots`` if it fell behind."""
        wp = int(self._write_ptr.value)
        oldest_available = max(0, wp - self.num_slots)
        # Cheap unlocked check first; re-check under the lock before writing.
        if int(self._read_ptr.value) >= oldest_available:
            return
        with self._read_ptr.get_lock():
            if int(self._read_ptr.value) < oldest_available:
                self._read_ptr.value = oldest_available

    @property
    def write_slot(self) -> int:
        """Physical slot index the writer fills next."""
        return int(self._write_ptr.value) % self.num_slots

    @property
    def write_buffer(self) -> Dict[str, np.ndarray]:
        """Shared-memory views of the current write slot, keyed by field."""
        s = self.write_slot
        return {field: arr[s] for field, arr in self._arrays.items()}

    def signal_write_done(self) -> None:
        """Publish the current write slot by advancing the write pointer."""
        with self._write_ptr.get_lock():
            self._write_ptr.value += 1

    def available(self) -> int:
        """Number of unread slots, in ``[0, num_slots]`` (clamps the read pointer first)."""
        self._clamp_read_ptr_to_valid_window()
        pending = int(self._write_ptr.value) - int(self._read_ptr.value)
        return min(max(0, pending), self.num_slots)

    def wait_for_data(self, timeout: float = 60.0) -> bool:
        """Poll (1 ms sleeps) until a slot is readable.

        Returns:
            True once data is available, False if ``timeout`` seconds elapse first.
        """
        import time

        deadline = time.monotonic() + timeout
        while self.available() == 0:
            if time.monotonic() > deadline:
                return False
            time.sleep(0.001)
        return True

    @property
    def read_slot(self) -> int:
        """Physical slot index the reader consumes next (after clamping)."""
        self._clamp_read_ptr_to_valid_window()
        return int(self._read_ptr.value) % self.num_slots

    def read_numpy_views(self) -> dict[str, np.ndarray]:
        """Return shared-memory views for the current read slot.

        The returned arrays are borrowed views. Consumers must copy them into
        owned storage before calling advance_read(). Views still held when
        ``close()``/``cleanup()`` run keep this process's mapping alive.
        """
        s = self.read_slot
        return {field: arr[s] for field, arr in self._arrays.items()}

    def copy_read_slot_to_torch(self, destination: dict) -> None:
        """Copy the current read slot into preallocated float32 tensors.

        Raises:
            KeyError: a field has no destination tensor.
            ValueError: a destination tensor has the wrong shape.
            TypeError: a destination tensor is not ``torch.float32``.
        """
        import torch

        s = self.read_slot
        for field, arr in self._arrays.items():
            if field not in destination:
                raise KeyError(f"missing destination tensor for rollout field {field!r}")
            dst = destination[field]
            src_view = arr[s]
            if tuple(dst.shape) != tuple(src_view.shape):
                raise ValueError(
                    f"destination shape mismatch for {field!r}: "
                    f"expected {tuple(src_view.shape)}, got {tuple(dst.shape)}"
                )
            if dst.dtype != torch.float32:
                raise TypeError(f"destination tensor for {field!r} must be torch.float32")
            dst.copy_(torch.from_numpy(src_view), non_blocking=False)

    def read_torch(self, device: str) -> dict:
        """Copy the current read slot into freshly allocated tensors on ``device``."""
        import torch

        result = {
            field: torch.empty(tuple(arr.shape[1:]), dtype=torch.float32, device=device)
            for field, arr in self._arrays.items()
        }
        self.copy_read_slot_to_torch(result)
        return result

    def advance_read(self) -> None:
        """Mark the current read slot consumed.

        The new pointer never passes the write pointer and never falls below
        the oldest slot still in the window.
        """
        with self._read_ptr.get_lock():
            wp = int(self._write_ptr.value)
            rp = min(int(self._read_ptr.value) + 1, wp)
            oldest_available = max(0, wp - self.num_slots)
            self._read_ptr.value = max(rp, oldest_available)

    def _release_blocks(self, *, unlink: bool) -> None:
        """Best-effort close of every segment, optionally unlinking it too."""
        # Our numpy arrays export shm.buf, and SharedMemory.close() raises
        # BufferError while exports exist, so drop our own views first.
        self._arrays.clear()
        for shm in self._shm_blocks.values():
            try:
                shm.close()
            except (BufferError, OSError):
                # A caller may still hold a borrowed view (e.g. from
                # read_numpy_views); the mapping lives until it is released.
                pass
            if unlink:
                # Unlink independently of close so a failed close cannot leak
                # the segment.
                try:
                    shm.unlink()
                except OSError:
                    # e.g. FileNotFoundError when already unlinked.
                    pass

    def cleanup(self) -> None:
        """Owner-side teardown: close this process's mappings and unlink all segments.

        Safe to call more than once. The array views are dropped, so the
        buffer must not be used afterwards.
        """
        self._release_blocks(unlink=True)

    def close(self) -> None:
        """Close this process's mappings without unlinking the segments.

        The array views are dropped, so the buffer must not be used afterwards.
        """
        self._release_blocks(unlink=False)