Source code for unilab.ipc.replay_pipelines.transfer.cuda_like

"""CUDA/ROCm replay transfer backend."""

from __future__ import annotations

import time
from typing import Any, cast

import torch

from unilab.ipc.replay_pipelines.base import ReplayTickMetadata


class CudaLikeReplayTransferBackend:
    """Replay transfer backend copying pinned host slots to a CUDA-like device.

    Only the ``torch.cuda`` API is used. PyTorch's ROCm builds expose runtime
    streams and events through that same namespace, so the backend
    distinguishes device families by inspecting the PyTorch build rather than
    by NVIDIA-specific platform names.
    """

    host_memory_kind = "registered_pinned_shared"
    supports_async_submit = True
    supports_timing_events = True

    def __init__(self, *, device: torch.device, ring_depth: int) -> None:
        """Create the copy stream, the ready-event ring and pick a submitter.

        Args:
            device: Target device for the copy stream and device slots.
            ring_depth: Number of ready events to create, one per ring slot.

        Raises:
            RuntimeError: If ``torch.cuda.cudart()`` returns ``None``.
        """
        self.device = device

        # ``torch.version.hip`` is looked up defensively; a truthy value is
        # treated as a ROCm build, which always uses the torch copy stream.
        hip_version = getattr(getattr(torch, "version", None), "hip", None)
        if hip_version:
            self.device_family = "rocm"
            self.h2d_submitter = "torch_copy_stream"
        else:
            self.device_family = "cuda"
            self.h2d_submitter = "pybind11"

        self._ring_depth = int(ring_depth)

        self._cudart: Any = torch.cuda.cudart()
        if self._cudart is None:
            raise RuntimeError("torch.cuda.cudart() is required for replay host registration")

        self._copy_stream = torch.cuda.Stream(device=device)
        ring_events = []
        for _ in range(self._ring_depth):
            ring_events.append(torch.cuda.Event())
        self._ready_events = ring_events

        self._registered_shared_slots: list[torch.Tensor] = []
        self._registered_shared_ptrs: list[int] = []
        self.host_pinned = False
        self.direct_pinned_shared = False

        if self.h2d_submitter != "pybind11":
            return

        # Imported lazily so the native module is only touched when the
        # pybind11 submitter was actually selected.
        from unilab.ipc.replay_pipelines.native_h2d import get_diagnostic, is_available

        if is_available():
            return

        import sys

        notice = (
            f"[ReplayTransfer] Native H2D unavailable, using torch_copy_stream.\n"
            f" Reason: {get_diagnostic()}\n"
            f" Performance impact: negligible for pinned-memory transfers."
        )
        print(notice, file=sys.stderr, flush=True)
        self.h2d_submitter = "torch_copy_stream"

    def register_host_slots(self, slots: list[torch.Tensor]) -> None:
        """Register each host slot with the runtime via ``cudaHostRegister``.

        Successfully registered slots and their base pointers are remembered
        so that :meth:`close` can unregister them later.

        Raises:
            RuntimeError: If registration reports a non-success status, or if
                the slot does not report itself as pinned afterwards (in which
                case that slot is unregistered again before raising).
        """
        for host_slot in slots:
            address = int(host_slot.data_ptr())
            byte_count = int(host_slot.numel() * host_slot.element_size())

            status = self._cudart.cudaHostRegister(address, byte_count, 0)
            if status != self._cudart.cudaError.success:
                raise RuntimeError(f"cudaHostRegister failed for collector replay slot: {status}")

            if host_slot.is_pinned():
                self._registered_shared_slots.append(host_slot)
                self._registered_shared_ptrs.append(address)
                continue

            # Registration "succeeded" but torch does not see pinned memory:
            # undo it for this slot and report the problem.
            self._cudart.cudaHostUnregister(address)
            raise RuntimeError("cudaHostRegister did not make collector replay slot pinned")

        self.host_pinned = True
        self.direct_pinned_shared = True

    def allocate_device_slots(
        self,
        *,
        count: int,
        shape: tuple[int, int],
        dtype: torch.dtype,
    ) -> list[torch.Tensor]:
        """Return ``count`` uninitialised tensors of ``shape``/``dtype`` on ``self.device``."""
        device_slots: list[torch.Tensor] = []
        for _ in range(count):
            device_slots.append(torch.empty(shape, dtype=dtype, device=self.device))
        return device_slots

    def submit_h2d(
        self,
        *,
        slot: int,
        dst: torch.Tensor,
        src: torch.Tensor,
        metadata: ReplayTickMetadata | None,
        trace_recorder,
        trace_cuda_events: bool,
        h2d_bytes: int,
        pack_layout: str,
        pack_executor: str,
    ) -> float:
        """Enqueue a host-to-device copy on the copy stream and mark the slot ready.

        The copy is issued either through the native submitter or through
        ``Tensor.copy_(non_blocking=True)``, depending on ``h2d_submitter``.
        The slot's ready event is recorded on the copy stream afterwards. When
        a trace recorder is supplied and ``trace_cuda_events`` is truthy, a
        pair of timing events brackets the copy and is handed to the recorder
        as a pending span.

        Returns:
            Host-side time spent submitting the copy, in seconds.
        """
        submit_started_ns = time.perf_counter_ns()
        timing_begin = None
        timing_end = None
        want_timing = trace_recorder is not None and trace_cuda_events

        with torch.cuda.device(self.device):
            stream = cast(torch.cuda.Stream, self._copy_stream)
            with torch.cuda.stream(stream):
                if want_timing:
                    timing_begin = torch.cuda.Event(enable_timing=True)
                    timing_end = torch.cuda.Event(enable_timing=True)
                    cast(Any, timing_begin).record()

                if self.h2d_submitter != "pybind11":
                    dst.copy_(src, non_blocking=True)
                else:
                    from unilab.ipc.replay_pipelines.native_h2d import submit_h2d

                    submit_h2d(dst, src, stream)

                if timing_end is not None:
                    cast(Any, timing_end).record()

            # Explicit stream argument: the event tracks the copy stream
            # regardless of which stream is current here.
            self._ready_events[slot].record(stream)

        submit_finished_ns = time.perf_counter_ns()
        elapsed_s = (submit_finished_ns - submit_started_ns) / 1e9

        if not want_timing or timing_begin is None or timing_end is None:
            return elapsed_s

        span_args: dict[str, object] = {
            "slot": slot,
            "h2d_bytes": h2d_bytes,
            "pinned_memory": self.host_pinned,
            "pack_layout": pack_layout,
            "pack_executor": pack_executor,
            "h2d_submitter": self.h2d_submitter,
            "direct_pinned_shared": self.direct_pinned_shared,
        }
        if metadata is not None:
            span_args["tick_id"] = int(metadata.tick_id)
            span_args["snapshot_ptr"] = int(metadata.snapshot_ptr)
            span_args["snapshot_size"] = int(metadata.snapshot_size)
            span_args["sample_seed"] = int(metadata.sample_seed)
            span_args["sample_count"] = int(metadata.sample_count)
            span_args["batch_host_slot"] = metadata.batch_host_slot
            span_args["batch_gpu_slot"] = metadata.batch_gpu_slot

        trace_recorder.add_cuda_pending_span(
            "gpu/replay_pipeline_batch_h2d",
            category="gpu",
            cpu_begin_ns=submit_started_ns,
            start_event=cast(Any, timing_begin),
            end_event=cast(Any, timing_end),
            args=span_args,
        )
        return elapsed_s

    def clear_ready(self, slot: int) -> None:
        """Do nothing; the ``slot`` argument is accepted and ignored."""
        _ = slot
        return None

    def ready_query(self, slot: int) -> bool:
        """Return the result of querying the slot's ready event, as a ``bool``."""
        ready_event = self._ready_events[slot]
        return bool(ready_event.query())

    def synchronize_ready(self, slot: int) -> None:
        """Block the calling host thread on the slot's ready event."""
        ready_event = self._ready_events[slot]
        ready_event.synchronize()

    def wait_current_stream_for_ready(self, slot: int) -> None:
        """Make the device's current stream wait on the slot's ready event."""
        consumer_stream = cast(Any, torch.cuda.current_stream(self.device))
        ready_event = self._ready_events[slot]
        consumer_stream.wait_event(ready_event)

    def close(self) -> None:
        """Unregister every tracked host pointer and reset the pinned flags."""
        pending_ptrs = self._registered_shared_ptrs
        while pending_ptrs:
            address = pending_ptrs.pop()
            try:
                self._cudart.cudaHostUnregister(int(address))
            except Exception:
                # Teardown is best effort: a failure on one pointer must not
                # prevent the remaining pointers from being unregistered.
                continue
        self._registered_shared_slots.clear()
        self.host_pinned = False
        self.direct_pinned_shared = False