# Source code for unilab.ipc.replay_pipelines.transfer.torch_copy
"""Portable torch-copy replay transfer backend."""
from __future__ import annotations
import threading
import time
import torch
from unilab.ipc.replay_pipelines.base import ReplayTickMetadata
class TorchCopyReplayTransferBackend:
    """Portable transfer backend for CPU, MPS, and other torch devices.

    Host-to-device transfers are plain ``Tensor.copy_`` calls. Each ring slot
    owns a ``threading.Event`` that is set once the slot may be consumed.

    On MPS devices the copy is *deferred*: ``submit_h2d`` only records the
    ``(dst, src)`` pair and marks the slot ready, and the copy itself is
    executed by ``wait_current_stream_for_ready`` on the calling thread.
    """

    # Capability flags describing this backend.
    h2d_submitter = "torch_copy"
    host_memory_kind = "pageable_shared"
    host_pinned = False
    direct_pinned_shared = False
    supports_async_submit = False
    supports_timing_events = False

    def __init__(self, *, device: torch.device, ring_depth: int) -> None:
        """Create per-slot ready events and pending-copy bookkeeping.

        Args:
            device: Target torch device for the device-side slots.
            ring_depth: Number of ring slots; coerced with ``int()``.
        """
        self.device = device
        self.device_family = device.type
        self._ring_depth = int(ring_depth)
        self._ready_events = [threading.Event() for _ in range(self._ring_depth)]
        # MPS copies are postponed until the consumer waits on the slot.
        self._defer_copy_to_wait = device.type == "mps"
        self._pending_copies: list[tuple[torch.Tensor, torch.Tensor] | None] = [
            None for _ in range(self._ring_depth)
        ]
        # Duration (seconds) of the most recent deferred copy performed in
        # ``wait_current_stream_for_ready``; 0.0 when none was performed.
        self.last_wait_copy_time_s = 0.0
        if self._defer_copy_to_wait:
            # Instance-level override of the class attribute.
            self.h2d_submitter = "torch_copy_main_thread"

    def register_host_slots(self, slots: list[torch.Tensor]) -> None:
        """Accept host slots; this backend needs no registration (no-op)."""
        del slots
        return None

    def allocate_device_slots(
        self,
        *,
        count: int,
        shape: tuple[int, int],
        dtype: torch.dtype,
    ) -> list[torch.Tensor]:
        """Return ``count`` uninitialised tensors of ``shape``/``dtype`` on ``self.device``."""
        return [torch.empty(shape, dtype=dtype, device=self.device) for _ in range(count)]

    def submit_h2d(
        self,
        *,
        slot: int,
        dst: torch.Tensor,
        src: torch.Tensor,
        metadata: ReplayTickMetadata | None,
        trace_recorder,
        trace_cuda_events: bool,
        h2d_bytes: int,
        pack_layout: str,
        pack_executor: str,
    ) -> float:
        """Copy ``src`` into ``dst`` (or queue the copy) and mark ``slot`` ready.

        Args:
            slot: Ring slot index whose ready event is updated.
            dst: Destination tensor.
            src: Source tensor.
            metadata, trace_recorder, trace_cuda_events, h2d_bytes,
            pack_layout, pack_executor: Accepted for interface compatibility
                and ignored by this backend.

        Returns:
            Wall-clock seconds spent inside this call. In deferred (MPS) mode
            this covers only the bookkeeping, not the copy itself.
        """
        del metadata, trace_recorder, trace_cuda_events, h2d_bytes, pack_layout, pack_executor
        h2d_begin_ns = time.perf_counter_ns()
        # Clears the event, drops any stale pending copy and resets
        # ``last_wait_copy_time_s`` to 0.0.
        self.clear_ready(slot)
        if self._defer_copy_to_wait:
            # PyTorch MPS command submission from a background transfer thread
            # can trip Metal command-buffer assertions. Keep the collector CPU
            # pack overlap, but submit the actual MPS copy from the learner
            # thread when the batch is consumed.
            self._pending_copies[slot] = (dst, src)
            self._ready_events[slot].set()
            return (time.perf_counter_ns() - h2d_begin_ns) / 1e9
        dst.copy_(src, non_blocking=src.is_pinned())
        # The copy may have been issued non-blocking; make sure it has
        # finished before the slot is published as ready.
        self._synchronize_device()
        self._ready_events[slot].set()
        return (time.perf_counter_ns() - h2d_begin_ns) / 1e9

    def clear_ready(self, slot: int) -> None:
        """Mark ``slot`` not ready, drop its pending copy and reset the wait timer."""
        self._ready_events[slot].clear()
        self._pending_copies[slot] = None
        self.last_wait_copy_time_s = 0.0

    def ready_query(self, slot: int) -> bool:
        """Return whether the ready event of ``slot`` is currently set (non-blocking)."""
        return self._ready_events[slot].is_set()

    def synchronize_ready(self, slot: int) -> None:
        """Block the calling thread until the ready event of ``slot`` is set."""
        self._ready_events[slot].wait()

    def wait_current_stream_for_ready(self, slot: int) -> None:
        """Wait for ``slot`` and execute its deferred copy, if one is queued.

        When a deferred copy runs, its duration is stored in
        ``last_wait_copy_time_s`` and the pending entry is cleared.
        """
        self.synchronize_ready(slot)
        pending = self._pending_copies[slot]
        if pending is None:
            return
        dst, src = pending
        copy_begin_ns = time.perf_counter_ns()
        dst.copy_(src, non_blocking=False)
        self._synchronize_device()
        self.last_wait_copy_time_s = (time.perf_counter_ns() - copy_begin_ns) / 1e9
        self._pending_copies[slot] = None

    def close(self) -> None:
        """Drop every pending copy and reset the wait timer."""
        # In-place slice assignment keeps the same list object.
        self._pending_copies[:] = [None] * len(self._pending_copies)
        self.last_wait_copy_time_s = 0.0
        return None

    def _synchronize_device(self) -> None:
        """Block until outstanding work on ``self.device`` has completed.

        Handles ``mps``, ``cuda`` and ``xpu``; any other device type is a
        no-op.
        """
        device_type = self.device.type
        if device_type == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()
            return
        if device_type == "cuda" and torch.cuda.is_available():
            # ``submit_h2d`` requests a non-blocking copy when the source is
            # pinned; without this the ready event could be set while the
            # transfer is still in flight.
            torch.cuda.synchronize(self.device)
            return
        if device_type == "xpu" and hasattr(torch, "xpu"):
            synchronize = getattr(torch.xpu, "synchronize", None)
            if synchronize is not None:
                try:
                    synchronize(self.device)
                except TypeError:
                    # Fall back to the argument-less call signature.
                    synchronize()