"""Packed shared-memory replay buffer for off-policy RL."""
import time
from typing import Any, Dict
import torch
from unilab.ipc.shared_buffer import SharedBufferBase
class ReplayBuffer(SharedBufferBase):
    """Shared replay buffer backed by authoritative packed CPU storage.

    Every transition occupies one row of a single shared-memory tensor. The
    columns of a row are laid out as::

        obs | next_obs | actions | reward | done | truncated | critic | next_critic

    The two critic segments exist only when ``critic_dim > 0``.

    Device transfer is owned by replay pipeline transfer backends. The
    fallback sample() path copies a sampled packed batch to ``self.device`` and
    keeps no per-device replay cache.
    """

    def __init__(
        self,
        capacity: int,
        obs_dim: int,
        action_dim: int,
        device: str,
        defer_gpu: bool = False,
        critic_dim: int = 0,
        packed_cpu_storage: bool = False,
    ):
        """Allocate the shared packed storage and the shared size counter.

        Args:
            capacity: Maximum number of transitions kept in the ring buffer.
            obs_dim: Width of one observation vector.
            action_dim: Width of one action vector.
            device: Target device of ``sample()``; forwarded to the base class.
            defer_gpu: Forwarded unchanged to ``SharedBufferBase``.
            critic_dim: Width of the optional critic observation; ``0``
                disables the critic columns.
            packed_cpu_storage: Accepted for backward compatibility and
                ignored; storage is always packed.
        """
        # NOTE(review): the base class is assumed to provide ``self.capacity``,
        # ``self.device`` and the write cursor ``self.ptr`` used below -- it is
        # not visible from this file.
        super().__init__(capacity, device, defer_gpu=defer_gpu)
        # The flag no longer selects anything: packed storage is the only mode.
        del packed_cpu_storage
        self._obs_dim = obs_dim
        self._action_dim = action_dim
        self._critic_dim = critic_dim
        # Reset to 0.0 at the start of every sample(); never set otherwise here.
        self.last_incremental_h2d_time_s = 0.0
        self._packed_cpu_storage = True
        # Optional tracing hook; add()/sample() emit slices only when it is set.
        self.trace_recorder: Any | None = None
        # Not read inside this class -- presumably consumed by tracing or
        # transfer code elsewhere (TODO confirm).
        self.trace_thread_time = False
        self.trace_cuda_events = True
        # Number of valid rows, kept in shared memory like the storage itself.
        self.size = torch.zeros(1, dtype=torch.int64).share_memory_()
        self._init_packed_storage(capacity, obs_dim, action_dim, critic_dim)

    def _init_packed_storage(
        self, capacity: int, obs_dim: int, action_dim: int, critic_dim: int
    ) -> None:
        """Allocate the packed storage tensor and record the column layout.

        The three scalar fields (reward, done, truncated) take one column each,
        hence the ``+ 3`` in the row width.
        """
        total_dim = 2 * obs_dim + action_dim + 3 + 2 * critic_dim
        self._storage = torch.zeros(capacity, total_dim).share_memory_()

        # Walk a column cursor across the row, handing out one segment at a time.
        c = 0
        self._obs_sl = slice(c, c + obs_dim)
        c += obs_dim
        self._nobs_sl = slice(c, c + obs_dim)
        c += obs_dim
        self._act_sl = slice(c, c + action_dim)
        c += action_dim
        self._rew_col = c
        c += 1
        self._done_col = c
        c += 1
        self._trunc_col = c
        c += 1
        # The critic slices are deliberately left undefined without a critic.
        if critic_dim > 0:
            self._critic_sl = slice(c, c + critic_dim)
            c += critic_dim
            self._ncritic_sl = slice(c, c + critic_dim)
            c += critic_dim

    def __getstate__(self) -> dict:
        """Custom pickle support.

        The collector subprocess only calls add(), which writes to the CPU
        shared-memory tensor. The original object in the learner process is
        unaffected.

        The pickled copy carries no trace recorder: ``trace_recorder`` is
        replaced by ``None`` in the returned state.
        """
        state = self.__dict__.copy()
        state["trace_recorder"] = None
        return state

    def add(
        self,
        obs,
        actions,
        rewards,
        next_obs,
        dones,
        truncated,
        terminal_mask=None,
        terminal_next_obs=None,
        critic=None,
        next_critic=None,
        terminal_next_critic=None,
    ):
        """Add batch (called by collector).

        `dones` follows the UniLab env lifecycle contract:
        done = terminated | truncated. Learners must pair it with
        `truncated` when computing bootstrap masks.

        The batch is packed into rows and written at the ring-buffer cursor,
        wrapping to the start of the storage when it runs past the end. If
        ``terminal_mask`` and ``terminal_next_obs`` are given, the stored
        ``next_obs`` (and ``next_critic``, when available) of the masked rows
        is replaced by the terminal values.

        A batch larger than ``capacity`` keeps only its last ``capacity``
        transitions; the cursor still advances by the full batch size.

        Args:
            obs, next_obs: 2-D tensors, one row per transition.
            actions: 2-D tensor, one row per transition.
            rewards, dones, truncated: 1-D tensors (each gets a trailing
                column axis before packing).
            terminal_mask: Optional 1-D boolean mask over the batch.
            terminal_next_obs: Optional replacement ``next_obs`` rows.
            critic, next_critic: Required when ``critic_dim > 0``; ignored
                otherwise.
            terminal_next_critic: Optional replacement ``next_critic`` rows.

        Raises:
            ValueError: If ``critic_dim > 0`` and ``critic`` or ``next_critic``
                is missing.
        """
        _trace_ns = time.perf_counter_ns() if self.trace_recorder is not None else 0
        n = obs.shape[0]

        if self._critic_dim > 0 and (critic is None or next_critic is None):
            raise ValueError("ReplayBuffer with critic_dim > 0 requires critic and next_critic")
        # Past the check above, critic data is present whenever it is configured.
        has_critic = self._critic_dim > 0

        parts = [
            obs,
            next_obs,
            actions,
            rewards.unsqueeze(1),
            dones.unsqueeze(1),
            truncated.unsqueeze(1),
        ]
        if has_critic:
            parts.extend([critic, next_critic])
        row = torch.cat(parts, dim=1)

        capacity = self.capacity
        start = int(self.ptr[0]) % capacity

        # An oversized batch would overwrite its own oldest rows anyway. Drop
        # them up front so the two wrap segments below can never overlap, and
        # move the start slot to where the first surviving row belongs.
        overflow = n - capacity
        if overflow > 0:
            keep = slice(overflow, None)
            row = row[keep]
            terminal_mask = self._slice_optional(terminal_mask, keep)
            terminal_next_obs = self._slice_optional(terminal_next_obs, keep)
            terminal_next_critic = self._slice_optional(terminal_next_critic, keep)
            start = (start + overflow) % capacity
        count = row.shape[0]

        if start + count <= capacity:
            # Contiguous write; the terminal tensors are passed through as-is.
            self._write_span(
                start,
                start + count,
                row,
                terminal_mask,
                terminal_next_obs,
                terminal_next_critic,
                has_critic,
            )
        else:
            # Wrap-around: fill the tail of the storage, then continue at row 0.
            split = capacity - start
            head, rest = slice(None, split), slice(split, None)
            self._write_span(
                start,
                capacity,
                row[head],
                self._slice_optional(terminal_mask, head),
                self._slice_optional(terminal_next_obs, head),
                self._slice_optional(terminal_next_critic, head),
                has_critic,
            )
            self._write_span(
                0,
                count - split,
                row[rest],
                self._slice_optional(terminal_mask, rest),
                self._slice_optional(terminal_next_obs, rest),
                self._slice_optional(terminal_next_critic, rest),
                has_critic,
            )

        # The cursor counts every transition ever offered, including dropped ones.
        self.ptr[0] += n
        self.size[0] = min(int(self.size[0]) + n, self.capacity)

        if self.trace_recorder is not None:
            self.trace_recorder.add_slice(
                "replay/add",
                category="replay",
                start_ns=_trace_ns,
                end_ns=time.perf_counter_ns(),
                args={"batch_size": int(n), "device": self.device},
            )

    @staticmethod
    def _slice_optional(tensor, rows):
        """Return ``tensor[rows]``, or ``None`` when ``tensor`` is ``None``."""
        return None if tensor is None else tensor[rows]

    def _write_span(
        self,
        lo,
        hi,
        rows,
        terminal_mask,
        terminal_next_obs,
        terminal_next_critic,
        has_critic,
    ) -> None:
        """Store ``rows`` in storage rows ``[lo, hi)`` and patch terminal rows."""
        self._storage[lo:hi] = rows
        # Basic slicing yields views, so the patch below edits the storage in place.
        self._patch_terminal_next_observations(
            self._storage[lo:hi, self._nobs_sl],
            terminal_mask,
            terminal_next_obs,
            self._storage[lo:hi, self._ncritic_sl] if has_critic else None,
            terminal_next_critic,
        )

    @staticmethod
    def _patch_terminal_next_observations(
        target_next_obs,
        terminal_mask,
        terminal_next_obs,
        target_next_critic=None,
        terminal_next_critic=None,
    ) -> None:
        """Overwrite masked rows of the targets with their terminal values.

        This is a best-effort step. It silently does nothing when the mask or
        ``terminal_next_obs`` is missing, when the mask is not 1-D with exactly
        one entry per target row, or when no mask entry is set.

        Args:
            target_next_obs: Tensor modified in place (a storage view in add()).
            terminal_mask: Row selector for the rows to replace.
            terminal_next_obs: Source rows for ``target_next_obs``.
            target_next_critic: Optional tensor modified in place.
            terminal_next_critic: Optional source rows for ``target_next_critic``.
        """
        if terminal_mask is None or terminal_next_obs is None:
            return
        if terminal_mask.ndim != 1 or terminal_mask.shape[0] != target_next_obs.shape[0]:
            return
        if not torch.any(terminal_mask):
            return
        target_next_obs[terminal_mask] = terminal_next_obs[terminal_mask]
        # The critic patch needs both its target and its source.
        if target_next_critic is not None and terminal_next_critic is not None:
            target_next_critic[terminal_mask] = terminal_next_critic[terminal_mask]

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """Sample batch (called by learner).

        Row indices are drawn uniformly with replacement from the valid part
        of the buffer. The gathered rows are copied to ``self.device`` and then
        split into named column views.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            A dict with 2-D ``"obs"``, ``"next_obs"`` and ``"actions"``, and
            1-D ``"rewards"``, ``"dones"`` and ``"truncated"``. When
            ``critic_dim > 0`` it also holds 2-D ``"critic"`` and
            ``"next_critic"``.

        Raises:
            RuntimeError: If the buffer holds no transitions yet.
        """
        self.last_incremental_h2d_time_s = 0.0
        _trace_ns = time.perf_counter_ns() if self.trace_recorder is not None else 0
        size = int(self.size[0])
        # torch.randint rejects an empty range with an opaque "from >= to"
        # RuntimeError; fail with the same exception type but a clear message.
        if size <= 0:
            raise RuntimeError("Cannot sample from an empty ReplayBuffer; call add() first")

        _indices_ns = time.perf_counter_ns() if self.trace_recorder is not None else 0
        indices = torch.randint(0, size, (batch_size,))
        if self.trace_recorder is not None:
            self.trace_recorder.add_slice(
                "replay/sample_indices",
                category="replay",
                start_ns=_indices_ns,
                end_ns=time.perf_counter_ns(),
                args={"batch_size": int(batch_size), "size": int(size)},
            )

        # Gather the packed rows first, then move the single chunk to the device.
        chunk = self._storage[indices].to(self.device)
        batch = {
            "obs": chunk[:, self._obs_sl],
            "next_obs": chunk[:, self._nobs_sl],
            "actions": chunk[:, self._act_sl],
            "rewards": chunk[:, self._rew_col],
            "dones": chunk[:, self._done_col],
            "truncated": chunk[:, self._trunc_col],
        }
        if self._critic_dim > 0:
            batch["critic"] = chunk[:, self._critic_sl]
            batch["next_critic"] = chunk[:, self._ncritic_sl]

        if self.trace_recorder is not None:
            self.trace_recorder.add_slice(
                "replay/sample",
                category="replay",
                start_ns=_trace_ns,
                end_ns=time.perf_counter_ns(),
                args={"batch_size": int(batch_size), "device": self.device},
            )
        return batch