Source code for unilab.ipc.weight_sync
"""Shared weight synchronization for actor networks."""
from __future__ import annotations
import multiprocessing as mp
import time
from multiprocessing import shared_memory
from typing import Any, Dict
import numpy as np
_SPAWN_CTX = mp.get_context("spawn")
[docs]
class SharedWeightSync:
"""Synchronize actor weights between learner and collector."""
[docs]
def __init__(
self, param_shapes: Dict, *, create: bool = True, shm_name: str | None = None, lock=None
):
self._param_shapes = param_shapes
self._param_names = list(param_shapes.keys())
self.trace_recorder: Any | None = None
self.trace_thread_time = False
total_numel = sum(s.numel() for s in param_shapes.values())
_f32 = np.dtype(np.float32).itemsize
_i64 = np.dtype(np.int64).itemsize
data_bytes = total_numel * _f32
meta_bytes = _i64
total_bytes = data_bytes + meta_bytes
if create:
self._shm = shared_memory.SharedMemory(create=True, size=max(total_bytes, 1))
self._lock = _SPAWN_CTX.Lock()
else:
assert shm_name is not None
self._shm = shared_memory.SharedMemory(name=shm_name, create=False)
# lock must be passed in from the parent process when attaching
self._lock = lock
buf = self._shm.buf
assert buf is not None
self._buffer: np.ndarray = np.ndarray((total_numel,), dtype=np.float32, buffer=buf)
self._version_arr: np.ndarray = np.ndarray((1,), dtype=np.int64, buffer=buf[data_bytes:])
if create:
self._version_arr[0] = 0
@property
def name(self) -> str:
return self._shm.name
@property
def version(self) -> int:
return int(self._version_arr[0])
[docs]
@classmethod
def from_state_dict(cls, state_dict, **kwargs):
param_shapes = {name: p.shape for name, p in state_dict.items()}
obj = cls(param_shapes, **kwargs)
obj.write_weights(state_dict)
return obj
[docs]
def write_weights(self, state_dict) -> None:
_trace_ns = time.perf_counter_ns()
_thread_ns = time.thread_time_ns() if self.trace_thread_time else None
if self._lock is not None:
with self._lock:
offset = 0
for name in self._param_names:
param = state_dict[name]
arr = param.detach().cpu().numpy().ravel()
n = arr.size
self._buffer[offset : offset + n] = arr
offset += n
self._version_arr[0] += 1
else:
# No lock - direct write
offset = 0
for name in self._param_names:
param = state_dict[name]
arr = param.detach().cpu().numpy().ravel()
n = arr.size
self._buffer[offset : offset + n] = arr
offset += n
self._version_arr[0] += 1
if self.trace_recorder is not None:
self.trace_recorder.add_slice(
"weight_sync/write_weights_d2h",
category="weight_sync",
start_ns=_trace_ns,
end_ns=time.perf_counter_ns(),
args={"version": int(self._version_arr[0]), "mode": "sync"},
)
if _thread_ns is not None:
self.trace_recorder.add_counter(
"weight_sync/write_thread_cpu_us",
(time.thread_time_ns() - _thread_ns) / 1000.0,
category="weight_sync",
)
[docs]
def read_weights_into(self, state_dict) -> int:
import torch
_trace_ns = time.perf_counter_ns()
_thread_ns = time.thread_time_ns() if self.trace_thread_time else None
if self._lock is not None:
with self._lock:
offset = 0
for name in self._param_names:
param = state_dict[name]
n = param.numel()
data = self._buffer[offset : offset + n].copy()
param.data.copy_(torch.from_numpy(data.reshape(param.shape)))
offset += n
version = int(self._version_arr[0])
else:
# No lock - direct read (for subprocess)
offset = 0
for name in self._param_names:
param = state_dict[name]
n = param.numel()
data = self._buffer[offset : offset + n].copy()
param.data.copy_(torch.from_numpy(data.reshape(param.shape)))
offset += n
version = int(self._version_arr[0])
if self.trace_recorder is not None:
self.trace_recorder.add_slice(
"weight_sync/read_weights_into_cpu_actor",
category="weight_sync",
start_ns=_trace_ns,
end_ns=time.perf_counter_ns(),
args={"version": version},
)
if _thread_ns is not None:
self.trace_recorder.add_counter(
"weight_sync/read_thread_cpu_us",
(time.thread_time_ns() - _thread_ns) / 1000.0,
category="weight_sync",
)
return version
[docs]
def cleanup(self) -> None:
try:
self._shm.close()
self._shm.unlink()
except Exception:
pass
[docs]
def close(self) -> None:
try:
self._shm.close()
except Exception:
pass