# Source code for unilab.ipc.replay_pipelines.transfer.factory

"""Replay transfer backend factory."""

from __future__ import annotations

import torch

from unilab.ipc.replay_pipelines.transfer.base import ReplayTransferBackend
from unilab.ipc.replay_pipelines.transfer.cuda_like import CudaLikeReplayTransferBackend
from unilab.ipc.replay_pipelines.transfer.torch_copy import TorchCopyReplayTransferBackend
from unilab.ipc.replay_pipelines.transfer.xpu import XpuReplayTransferBackend


def build_replay_transfer_backend(
    *,
    device: torch.device,
    ring_depth: int,
) -> ReplayTransferBackend:
    """Select and construct the replay transfer backend for ``device``.

    The backend class is chosen from ``device.type``:

    * ``"cuda"`` -> ``CudaLikeReplayTransferBackend``
    * ``"xpu"``  -> ``XpuReplayTransferBackend``
    * any other value -> ``TorchCopyReplayTransferBackend``

    Args:
        device: Learner device; only its ``type`` string is inspected here.
            The device itself is forwarded to the backend constructor.
        ring_depth: Forwarded unchanged to the backend constructor. Its
            meaning is defined by the backend classes, not by this factory.

    Returns:
        A newly constructed backend instance for the given device.
    """
    # Device types with a dedicated backend; everything else falls back to
    # the generic torch-copy implementation.
    specialised_backends = {
        "cuda": CudaLikeReplayTransferBackend,
        "xpu": XpuReplayTransferBackend,
    }
    backend_cls = specialised_backends.get(
        device.type,
        TorchCopyReplayTransferBackend,
    )
    return backend_cls(device=device, ring_depth=ring_depth)