Source code for unilab.ipc.replay_pipelines.transfer.xpu
"""Intel XPU replay transfer backend."""
from __future__ import annotations
import time
from contextlib import nullcontext
from typing import Any
import torch
from unilab.ipc.replay_pipelines.base import ReplayTickMetadata
[docs]
class XpuReplayTransferBackend:
    """XPU stream/event backend for packed replay batch transfer.

    Host-to-device copies are issued on a dedicated ``torch.xpu`` copy stream.
    Completion of each ring slot's copy is tracked with one ``torch.xpu`` event
    per slot. A per-slot ``_submitted`` flag separates "copy enqueued and event
    recorded" from "cleared / never submitted", so readiness checks never
    consult an event left over from a previous use of the slot.
    """

    # Capability descriptors exposed as class attributes; presumably read by the
    # replay pipeline to pick code paths for this backend (TODO confirm).
    device_family = "xpu"
    # Name of the mechanism ``submit_h2d`` uses to issue copies.
    h2d_submitter = "torch_xpu_copy_stream"
    # Host slots are not pinned by this backend (``register_host_slots`` is a no-op).
    host_memory_kind = "pageable_shared"
    host_pinned = False
    direct_pinned_shared = False
    supports_async_submit = True
    # No timing events are recorded; ``submit_h2d`` ignores ``trace_cuda_events``.
    supports_timing_events = False

    def __init__(self, *, device: torch.device, ring_depth: int) -> None:
        """Create the copy stream and one ready event per ring slot.

        Args:
            device: Device passed to ``torch.xpu.Stream`` and used for the
                device slots and the device context in ``submit_h2d``.
            ring_depth: Number of ring slots; must be at least 1.

        Raises:
            ValueError: If ``ring_depth`` is less than 1.
            RuntimeError: If ``torch.xpu`` or any of its Stream/Event/stream/
                current_stream APIs is unavailable.
        """
        depth = int(ring_depth)
        # Validate before touching torch.xpu: a zero/negative depth would build
        # empty per-slot lists and defer the failure to an IndexError on first use.
        if depth < 1:
            raise ValueError(f"ring_depth must be >= 1, got {ring_depth!r}")
        xpu = getattr(torch, "xpu", None)
        required = ("Stream", "Event", "stream", "current_stream")
        if xpu is None or any(getattr(xpu, name, None) is None for name in required):
            raise RuntimeError("XPU replay transfer requires torch.xpu Stream/Event support")
        self.device = device
        self._xpu: Any = xpu
        self._ring_depth = depth
        # Stream used for every H2D copy issued by submit_h2d.
        self._copy_stream = xpu.Stream(device=device)
        # One event per slot, recorded on the copy stream right after that slot's copy.
        self._ready_events = [xpu.Event() for _ in range(depth)]
        # True once a slot's copy has been enqueued and its event recorded.
        self._submitted = [False] * depth

    def register_host_slots(self, slots: list[torch.Tensor]) -> None:
        """Accept the host slot tensors; a no-op since this backend does not pin them."""
        del slots
        return None

    def allocate_device_slots(
        self,
        *,
        count: int,
        shape: tuple[int, int],
        dtype: torch.dtype,
    ) -> list[torch.Tensor]:
        """Allocate ``count`` uninitialized tensors of ``shape``/``dtype`` on ``self.device``."""
        return [torch.empty(shape, dtype=dtype, device=self.device) for _ in range(count)]

    def submit_h2d(
        self,
        *,
        slot: int,
        dst: torch.Tensor,
        src: torch.Tensor,
        metadata: ReplayTickMetadata | None,
        trace_recorder: Any,
        trace_cuda_events: bool,
        h2d_bytes: int,
        pack_layout: str,
        pack_executor: str,
    ) -> float:
        """Enqueue the copy of ``src`` into ``dst`` for ``slot`` on the copy stream.

        The slot's ready flag is cleared first. The copy is then issued with
        ``non_blocking=True`` under ``torch.xpu.stream(copy_stream)``, inside a
        ``torch.xpu.device`` context when that API exists. Finally the slot's
        ready event is recorded on the copy stream. The metadata/tracing
        arguments are accepted (presumably for interface parity with other
        backends) and ignored.

        Returns:
            Host-side seconds spent in this call, measured with
            ``time.perf_counter_ns``. No explicit synchronization is done, so
            this need not include device-side copy completion.
        """
        del metadata, trace_recorder, trace_cuda_events, h2d_bytes, pack_layout, pack_executor
        h2d_begin_ns = time.perf_counter_ns()
        # Clear first so a failure below leaves the slot reported as not ready.
        self.clear_ready(slot)
        device_context = getattr(self._xpu, "device", None)
        context: Any = device_context(self.device) if callable(device_context) else nullcontext()
        # NOTE(review): the copy stream does not wait on the current stream, so
        # this assumes the caller never resubmits a slot whose ``dst`` may still
        # be read by queued work on another stream — confirm in the ring logic.
        with context:
            with self._xpu.stream(self._copy_stream):
                dst.copy_(src, non_blocking=True)
            self._ready_events[slot].record(self._copy_stream)
        self._submitted[slot] = True
        return (time.perf_counter_ns() - h2d_begin_ns) / 1e9

    def clear_ready(self, slot: int) -> None:
        """Mark ``slot`` as not submitted so readiness checks ignore its old event."""
        self._submitted[slot] = False

    def ready_query(self, slot: int) -> bool:
        """Return True only if ``slot`` was submitted and its ready event reports completion."""
        return bool(self._submitted[slot]) and bool(self._ready_events[slot].query())

    def synchronize_ready(self, slot: int) -> None:
        """Block the host on ``slot``'s ready event; no-op if nothing was submitted."""
        if self._submitted[slot]:
            self._ready_events[slot].synchronize()

    def wait_current_stream_for_ready(self, slot: int) -> None:
        """Order the device's current stream after ``slot``'s copy.

        Does nothing if the slot has not been submitted. If the current stream
        object exposes a callable ``wait_event``, the stream waits on the
        slot's ready event. Otherwise this falls back to blocking the host on
        that event.
        """
        if not self._submitted[slot]:
            return
        current_stream = self._xpu.current_stream(self.device)
        wait_event = getattr(current_stream, "wait_event", None)
        if callable(wait_event):
            wait_event(self._ready_events[slot])
        else:
            self._ready_events[slot].synchronize()

    def close(self) -> None:
        """Release backend resources; currently a no-op.

        NOTE(review): this does not synchronize the copy stream, so in-flight
        copies are not drained here — confirm callers do so before freeing slots.
        """
        return None