"""Off-policy collector for SAC and TD3.
Collects (obs, action, reward, next_obs, done) transitions using the current
actor policy. Runs in a subprocess; writes to ReplayBuffer.
"""
import queue
import sys
import time
from typing import cast
import numpy as np
import torch
from unilab.algos.torch.common.actor_factory import build_actor
from unilab.base.final_observation import resolve_terminal_observation_contract
from unilab.base.observations import get_obs_dims, split_obs_dict
from unilab.base.registry import ensure_registries
from unilab.training.seed import apply_training_seed
# Mutually exclusive phases of a single collector loop iteration (one
# vectorized env.step). Each phase receives exactly one sample per iteration,
# so all reported averages share the same denominator and can be added
# together without double counting.
COLLECTOR_TIMING_KEYS = (
    "weight_sync_ms",  # checking for / loading fresh actor weights
    "action_select_ms",  # actor-input preparation and actor inference
    "env_step_ms",  # the vectorized env.step call
    "replay_ms",  # transition extraction, replay writes, pack-request servicing
    "sync_coordination_ms",  # episode bookkeeping, learner handshake, metrics
)
[docs]
def resolve_collector_actor_dims(
env,
obs_dim: int | None = None,
action_dim: int | None = None,
) -> tuple[int, int]:
"""Resolve actor dims for the collector.
Prefer explicit dims from the parent process so learner and collector
build identical actor shapes on override-heavy env paths.
"""
if obs_dim is None:
obs_dim, _ = get_obs_dims(env.obs_groups_spec)
if action_dim is None:
assert env.action_space.shape is not None
action_dim = env.action_space.shape[0]
assert obs_dim is not None
assert action_dim is not None
return obs_dim, action_dim
[docs]
def sample_offpolicy_actions(
actor,
algo_type: str,
obs_torch: torch.Tensor,
prev_dones_torch: torch.Tensor,
priv_info_torch: torch.Tensor | None = None,
) -> torch.Tensor:
"""Sample collector actions using the algorithm's exploration policy."""
if algo_type in ("sac", "td3", "flashsac"):
return cast(
torch.Tensor,
actor.explore(obs_torch, dones=prev_dones_torch, deterministic=False),
)
if algo_type == "hora_sac":
if priv_info_torch is None:
raise ValueError("HORA-SAC collector action sampling requires priv_info_torch.")
return cast(
torch.Tensor,
actor.explore(obs_torch, priv_info_torch, deterministic=False),
)
raise ValueError(f"Unsupported off-policy algo_type for collector action sampling: {algo_type}")
[docs]
def resolve_offpolicy_actor_priv_info(
*,
algo_type: str,
obs_np: np.ndarray,
critic_np: np.ndarray,
info: dict | None,
) -> np.ndarray | None:
"""Resolve optional collector-side actor context for privileged off-policy actors."""
if algo_type != "hora_sac":
return None
from unilab.algos.torch.hora.observations import split_hora_obs_with_priv_info
_, _, priv_info_np = split_hora_obs_with_priv_info(
{"obs": obs_np, "critic": critic_np},
info,
)
if priv_info_np is None:
raise ValueError(
"HORA-SAC collector requires privileged info from info['critic_info'] "
"or the critic observation tail."
)
return np.asarray(priv_info_np, dtype=np.float32)
def _record_timing_ms(timing_accum_ms, timing_counts, key: str, value: float) -> None:
    """Add one millisecond sample for ``key``: accumulate its total and bump its count.

    Both mappings must already contain ``key`` or create it on access (the
    collector passes ``defaultdict`` instances).
    """
    timing_accum_ms[key] = timing_accum_ms[key] + float(value)
    timing_counts[key] = timing_counts[key] + 1
def _record_phase_ms(cycle_timing_ms: dict[str, float], key: str, start_ns: int) -> int:
    """Charge the time elapsed since ``start_ns`` to phase ``key``.

    The ``perf_counter_ns`` delta is converted to milliseconds and added to
    ``cycle_timing_ms[key]``. The end timestamp is returned so callers can
    use it as the start of the next phase, leaving no gap between phases.
    """
    now_ns = time.perf_counter_ns()
    elapsed_ms = (now_ns - start_ns) / 1e6
    cycle_timing_ms[key] += elapsed_ms
    return now_ns
def _collector_pack_shared_batch(replay_buffer, request: dict, shared_slots) -> dict:
    """Gather a seeded random batch of replay rows into a shared slot.

    Row indices are drawn uniformly from ``[0, replay_buffer.size[0])`` using a
    CPU ``torch.Generator`` seeded with ``request["sample_seed"]``. The rows are
    gathered from ``replay_buffer._storage`` with ``torch.index_select`` directly
    into ``shared_slots[request["shared_slot"]]``.

    Args:
        replay_buffer: exposes ``ptr``/``size`` (read at index 0) and a
            row-major ``_storage`` tensor.
        request: must provide ``tick_id``, ``sample_seed``, ``sample_count``,
            ``shared_slot``, ``target_gpu_slot`` and ``learner_hot_gpu_slot``.
        shared_slots: indexable destination tensors.

    Returns:
        The "ready" message: the request fields, the ``ptr``/``size`` snapshot
        taken before sampling, layout/executor tags and the pack start/end
        ``time.perf_counter_ns`` timestamps.

    Raises:
        RuntimeError: if ``target_gpu_slot == learner_hot_gpu_slot`` or the
            replay buffer is empty.
    """
    tick_id = int(request["tick_id"])
    # Snapshot ptr/size before sampling; both are echoed in the ready message.
    snapshot_ptr = int(replay_buffer.ptr[0])
    snapshot_size = int(replay_buffer.size[0])
    sample_seed = int(request["sample_seed"])
    sample_count = int(request["sample_count"])
    shared_slot = int(request["shared_slot"])
    target_gpu_slot = int(request["target_gpu_slot"])
    learner_hot_gpu_slot = int(request["learner_hot_gpu_slot"])
    if target_gpu_slot == learner_hot_gpu_slot:
        raise RuntimeError(
            "collector_thread pack target_gpu_slot must differ from learner_hot_gpu_slot"
        )
    if snapshot_size <= 0:
        # torch.randint(0, 0, ...) would raise an opaque "from >= to" error;
        # report the real cause instead (same exception type).
        raise RuntimeError(
            f"collector_thread pack cannot sample {sample_count} rows from an empty "
            f"replay buffer (tick_id={tick_id})"
        )
    pack_begin_ns = time.perf_counter_ns()
    gen = torch.Generator(device="cpu")
    gen.manual_seed(sample_seed)
    indices = torch.randint(0, snapshot_size, (sample_count,), generator=gen)
    dst = shared_slots[shared_slot]
    torch.index_select(replay_buffer._storage, 0, indices, out=dst)
    pack_end_ns = time.perf_counter_ns()
    return {
        "tick_id": tick_id,
        "snapshot_ptr": snapshot_ptr,
        "snapshot_size": snapshot_size,
        "sample_seed": sample_seed,
        "sample_count": sample_count,
        "shared_slot": shared_slot,
        "target_gpu_slot": target_gpu_slot,
        "learner_hot_gpu_slot": learner_hot_gpu_slot,
        "pack_layout": "packed",
        "pack_executor": "collector_thread",
        "pack_begin_ns": pack_begin_ns,
        "pack_end_ns": pack_end_ns,
    }
def _service_collector_pack_requests(
replay_buffer,
request_queue,
ready_queue,
shared_slots,
trace_recorder=None,
*,
block_timeout: float = 0.0,
pending_request: dict | None = None,
) -> tuple[bool, dict | None]:
if request_queue is None or ready_queue is None or shared_slots is None:
return False, pending_request
request = pending_request
if request is None:
try:
request = (
request_queue.get(timeout=block_timeout)
if block_timeout > 0
else request_queue.get_nowait()
)
except queue.Empty:
return False, None
if request is None:
return False, None
min_snapshot_ptr = int(request.get("min_snapshot_ptr", 0))
if int(replay_buffer.ptr[0]) < min_snapshot_ptr:
return False, request
ready = _collector_pack_shared_batch(replay_buffer, request, shared_slots)
if trace_recorder:
trace_recorder.add_slice(
"collector/cpu_pack_sample_batch",
category="collector",
start_ns=int(ready["pack_begin_ns"]),
end_ns=int(ready["pack_end_ns"]),
args={
"tick_id": int(ready["tick_id"]),
"sample_count": int(ready["sample_count"]),
"shared_slot": int(ready["shared_slot"]),
"target_gpu_slot": int(ready["target_gpu_slot"]),
"learner_hot_gpu_slot": int(ready["learner_hot_gpu_slot"]),
"pack_layout": "packed",
"pack_executor": "collector_thread",
"shared_memory": True,
"pinned_memory": False,
},
)
ready_queue.put(ready)
return True, None
[docs]
def off_policy_collector_fn(
    stop_event,
    env_name: str,
    num_envs: int,
    replay_buffer,
    weight_sync_name: str,
    weight_param_shapes: dict,
    algo_type: str = "sac",
    actor_hidden_dim: int = 512,
    use_layer_norm: bool = True,
    learning_starts: int = 0,
    metrics_queue=None,
    weight_sync_lock=None,
    sync_collection: bool = False,
    collection_ready_queue=None,
    trainer_done_queue=None,
    env_steps_per_sync: int = 1,
    obs_normalization: bool = False,
    shared_obs_normalizer_stats=None,
    sim_backend: str = "mujoco",
    env_cfg_override: dict | None = None,
    obs_dim: int | None = None,
    action_dim: int | None = None,
    actor_kwargs: dict | None = None,
    seed: int | None = None,
    trace_enabled: bool = False,
    trace_thread_time: bool = False,
    collector_pack_request_queue=None,
    collector_pack_ready_queue=None,
    collector_pack_shared_slots=None,
    **kwargs,
):
    """Subprocess entry point for the off-policy collector.

    Announces itself on stderr, then runs ``_run_collector`` with every named
    argument. Extra keyword arguments collected in ``**kwargs`` are accepted
    but not forwarded. Exceptions propagate; error handling is provided by
    ``_collector_entry_wrapper`` in ``async_runner.py``.
    """
    print("[Collector] Entry point called", file=sys.stderr, flush=True)
    collector_config = {
        "stop_event": stop_event,
        "env_name": env_name,
        "num_envs": num_envs,
        "replay_buffer": replay_buffer,
        "weight_sync_name": weight_sync_name,
        "weight_param_shapes": weight_param_shapes,
        "algo_type": algo_type,
        "actor_hidden_dim": actor_hidden_dim,
        "use_layer_norm": use_layer_norm,
        "learning_starts": learning_starts,
        "metrics_queue": metrics_queue,
        "weight_sync_lock": weight_sync_lock,
        "sync_collection": sync_collection,
        "collection_ready_queue": collection_ready_queue,
        "trainer_done_queue": trainer_done_queue,
        "env_steps_per_sync": env_steps_per_sync,
        "obs_normalization": obs_normalization,
        "shared_obs_normalizer_stats": shared_obs_normalizer_stats,
        "sim_backend": sim_backend,
        "env_cfg_override": env_cfg_override,
        "obs_dim": obs_dim,
        "action_dim": action_dim,
        "actor_kwargs": actor_kwargs,
        "seed": seed,
        "trace_enabled": trace_enabled,
        "trace_thread_time": trace_thread_time,
        "collector_pack_request_queue": collector_pack_request_queue,
        "collector_pack_ready_queue": collector_pack_ready_queue,
        "collector_pack_shared_slots": collector_pack_shared_slots,
    }
    _run_collector(**collector_config)
def _run_collector(
    stop_event,
    env_name,
    num_envs,
    replay_buffer,
    weight_sync_name,
    weight_param_shapes,
    algo_type,
    actor_hidden_dim,
    use_layer_norm,
    learning_starts,
    metrics_queue,
    weight_sync_lock,
    sync_collection,
    collection_ready_queue,
    trainer_done_queue,
    env_steps_per_sync,
    obs_normalization,
    shared_obs_normalizer_stats,
    sim_backend,
    env_cfg_override,
    obs_dim,
    action_dim,
    actor_kwargs,
    seed,
    trace_enabled,
    trace_thread_time,
    collector_pack_request_queue,
    collector_pack_ready_queue,
    collector_pack_shared_slots,
):
    """Run the off-policy collection loop until ``stop_event`` is set.

    Setup builds the env and a CPU actor, then loads the learner's weights
    through ``SharedWeightSync``.

    Each loop iteration does the following:

    * reloads actor weights when the shared version advanced;
    * selects exploration actions;
    * steps the vectorized env;
    * writes the transition batch to ``replay_buffer``;
    * services pending collector-side pack requests.

    With ``sync_collection`` (and both handshake queues) the loop signals
    ``collection_ready_queue`` every ``env_steps_per_sync`` env steps. It then
    waits on ``trainer_done_queue`` while still servicing pack requests.
    Episode statistics, per-phase timings (``COLLECTOR_TIMING_KEYS``) and trace
    events are pushed to ``metrics_queue`` when one is provided.

    ``learning_starts`` is accepted for interface compatibility and ignored.
    The shared-weight handle is closed on exit, including when the loop raises.
    """
    del learning_starts
    from collections import defaultdict, deque

    from unilab.base import registry
    from unilab.ipc import SharedWeightSync

    ensure_registries()
    apply_training_seed(seed, torch_runtime=True, cuda=True)

    trace_recorder = None
    if trace_enabled:
        from unilab.logging.trace_event import TraceRecorder

        trace_recorder = TraceRecorder("offpolicy_collector")

    # Initialize environment
    env = registry.make(
        env_name, num_envs=num_envs, sim_backend=sim_backend, env_cfg_override=env_cfg_override
    )
    if env.state is None:
        env.init_state()

    # Connect to weight sync
    weight_sync = SharedWeightSync(
        weight_param_shapes, create=False, shm_name=weight_sync_name, lock=weight_sync_lock
    )
    weight_sync.trace_recorder = trace_recorder
    weight_sync.trace_thread_time = trace_thread_time
    try:
        # Build actor (always on CPU for env interaction)
        obs_dim, action_dim = resolve_collector_actor_dims(
            env,
            obs_dim=obs_dim,
            action_dim=action_dim,
        )
        actor = build_actor(
            algo_type,
            obs_dim,
            action_dim,
            actor_hidden_dim,
            use_layer_norm,
            "cpu",
            num_envs,
            **(actor_kwargs or {}),
        )
        actor.eval()
        replay_buffer.trace_recorder = trace_recorder
        replay_buffer.trace_thread_time = trace_thread_time

        # Load initial weights
        sd = dict(actor.state_dict())
        weight_sync.read_weights_into(sd)
        actor.load_state_dict(sd)
        local_weight_version = weight_sync.version

        total_steps = 0
        # Only the most recent finished episodes are reported, so keep a bounded
        # window instead of letting the history grow for the whole run.
        episode_stats_window = 100
        ep_rewards = deque(maxlen=episode_stats_window)
        ep_lengths = deque(maxlen=episode_stats_window)
        current_ep_rewards = np.zeros(num_envs, dtype=np.float32)
        current_ep_lengths = np.zeros(num_envs, dtype=np.int32)
        ep_reward_components = defaultdict(list)
        timing_accum_ms = defaultdict(float)
        timing_counts = defaultdict(int)
        done_count_window = 0
        timeout_count_window = 0
        terminated_count_window = 0

        # Initial step with zero actions to obtain the first observation.
        actions_np = np.zeros((num_envs, action_dim), dtype=np.float32)
        state = env.step(actions_np)
        obs_np, critic_np = split_obs_dict(state.obs)
        obs_np = np.asarray(obs_np, dtype=np.float32)
        critic_np = np.asarray(critic_np, dtype=np.float32)
        info_dict = state.info
        prev_dones_np = np.zeros(num_envs, dtype=np.float32)
        last_log_time = time.time()
        # Track env.step calls collected since the last learner phase.
        env_steps_since_sync = 0
        pending_collector_pack_request = None

        # Collection loop
        while not stop_event.is_set():
            cycle_timing_ms: dict[str, float] = dict.fromkeys(COLLECTOR_TIMING_KEYS, 0.0)
            phase_start_ns = time.perf_counter_ns()

            # Check for weight updates
            if weight_sync.version > local_weight_version:
                weight_ns = time.perf_counter_ns()
                sd = dict(actor.state_dict())
                local_weight_version = weight_sync.read_weights_into(sd)
                actor.load_state_dict(sd)
                if trace_recorder:
                    trace_recorder.add_slice(
                        "collector/check_weight_update",
                        category="collector",
                        start_ns=weight_ns,
                        end_ns=time.perf_counter_ns(),
                    )
            phase_start_ns = _record_phase_ms(cycle_timing_ms, "weight_sync_ms", phase_start_ns)

            # Normalize the actor input with the shared normalizer statistics.
            obs_np_input = obs_np
            if obs_normalization and shared_obs_normalizer_stats is not None:
                stats = shared_obs_normalizer_stats.get()
                if stats is not None:
                    mean, std = stats
                    # Keep the actor input float32 like obs_np; stats of another
                    # dtype (e.g. float64) would otherwise promote it.
                    obs_np_input = np.asarray((obs_np - mean) / (std + 1e-8), dtype=np.float32)

            # Select action
            with torch.no_grad():
                infer_ns = time.perf_counter_ns()
                obs_torch = torch.from_numpy(obs_np_input)
                dones_torch = torch.from_numpy(prev_dones_np)
                priv_info_np = resolve_offpolicy_actor_priv_info(
                    algo_type=algo_type,
                    obs_np=obs_np,
                    critic_np=critic_np,
                    info=info_dict,
                )
                priv_info_torch = (
                    torch.from_numpy(priv_info_np) if priv_info_np is not None else None
                )
                actions_torch = sample_offpolicy_actions(
                    actor=actor,
                    algo_type=algo_type,
                    obs_torch=obs_torch,
                    prev_dones_torch=dones_torch,
                    priv_info_torch=priv_info_torch,
                )
                actions_np = actions_torch.numpy()
            if trace_recorder:
                trace_recorder.add_slice(
                    "collector/actor_infer_cpu",
                    category="collector",
                    start_ns=infer_ns,
                    end_ns=time.perf_counter_ns(),
                )
            phase_start_ns = _record_phase_ms(cycle_timing_ms, "action_select_ms", phase_start_ns)

            # Step environment
            env_ns = time.perf_counter_ns()
            state = env.step(actions_np)
            if trace_recorder:
                trace_recorder.add_slice(
                    "collector/env_step",
                    category="collector",
                    start_ns=env_ns,
                    end_ns=time.perf_counter_ns(),
                    args={"num_envs": num_envs},
                )
            phase_start_ns = _record_phase_ms(cycle_timing_ms, "env_step_ms", phase_start_ns)

            # Extract data as numpy
            next_obs_np, next_critic_np = split_obs_dict(state.obs)
            next_obs_np = np.asarray(next_obs_np, dtype=np.float32)
            next_critic_np = np.asarray(next_critic_np, dtype=np.float32)
            rewards_np = np.asarray(state.reward, dtype=np.float32).ravel()
            terminated_np = state.terminated.astype(np.float32, copy=False).ravel()
            truncated_np = state.truncated.astype(np.float32, copy=False).ravel()
            combined_dones = (
                (state.terminated | state.truncated).astype(np.float32, copy=False).ravel()
            )
            prev_dones_np = combined_dones
            done_mask_np = combined_dones > 0.5
            timeout_mask_np = truncated_np > 0.5
            terminated_mask_np = np.logical_and(terminated_np > 0.5, ~timeout_mask_np)
            done_count_window += int(np.count_nonzero(done_mask_np))
            timeout_count_window += int(np.count_nonzero(timeout_mask_np))
            terminated_count_window += int(np.count_nonzero(terminated_mask_np))
            terminal_contract = resolve_terminal_observation_contract(
                next_obs_batch_size=next_obs_np.shape[0],
                final_observation=state.final_observation,
                done=done_mask_np,
                info=state.info,
                truncated=truncated_np,
            )
            phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns)

            # ReplayBuffer `dones` follows the UniLab env lifecycle contract:
            # done = terminated | truncated. Learners use `truncated` to keep
            # bootstrap enabled for timeout/truncation rows.
            replay_ns = time.perf_counter_ns()
            replay_buffer.add(
                torch.from_numpy(obs_np),
                torch.from_numpy(actions_np),
                torch.from_numpy(rewards_np),
                torch.from_numpy(next_obs_np),
                torch.from_numpy(combined_dones),
                torch.from_numpy(truncated_np),
                terminal_mask=torch.from_numpy(terminal_contract.terminal_mask),
                terminal_next_obs=(
                    torch.from_numpy(terminal_contract.terminal_obs)
                    if terminal_contract.terminal_obs is not None
                    else None
                ),
                critic=torch.from_numpy(critic_np),
                next_critic=torch.from_numpy(next_critic_np),
                terminal_next_critic=(
                    torch.from_numpy(terminal_contract.terminal_critic)
                    if terminal_contract.terminal_critic is not None
                    else None
                ),
            )
            if trace_recorder:
                trace_recorder.add_slice(
                    "collector/replay_add",
                    category="collector",
                    start_ns=replay_ns,
                    end_ns=time.perf_counter_ns(),
                )
            phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns)
            _, pending_collector_pack_request = _service_collector_pack_requests(
                replay_buffer,
                collector_pack_request_queue,
                collector_pack_ready_queue,
                collector_pack_shared_slots,
                trace_recorder,
                block_timeout=0.0,
                pending_request=pending_collector_pack_request,
            )
            phase_start_ns = _record_phase_ms(cycle_timing_ms, "replay_ms", phase_start_ns)

            # Track episode rewards - vectorized
            current_ep_rewards += rewards_np
            current_ep_lengths += 1
            reset_indices = np.flatnonzero(done_mask_np)
            if reset_indices.size > 0:
                ep_rewards.extend(current_ep_rewards[reset_indices].tolist())
                ep_lengths.extend(current_ep_lengths[reset_indices].tolist())
                current_ep_rewards[reset_indices] = 0.0
                current_ep_lengths[reset_indices] = 0
            obs_np = next_obs_np
            critic_np = next_critic_np
            info_dict = state.info
            total_steps += num_envs
            env_steps_since_sync += 1
            phase_start_ns = _record_phase_ms(
                cycle_timing_ms, "sync_coordination_ms", phase_start_ns
            )

            # Signal the learner once this collection chunk is ready.
            if (
                sync_collection
                and collection_ready_queue is not None
                and trainer_done_queue is not None
            ):
                if env_steps_since_sync >= env_steps_per_sync:
                    signal_ns = time.perf_counter_ns()
                    collection_ready_queue.put(1)
                    if trace_recorder:
                        trace_recorder.add_slice(
                            "collector/signal_ready",
                            category="collector",
                            start_ns=signal_ns,
                            end_ns=time.perf_counter_ns(),
                        )
                    phase_start_ns = _record_phase_ms(
                        cycle_timing_ms, "sync_coordination_ms", phase_start_ns
                    )
                    wait_ns = time.perf_counter_ns()
                    # Keep serving pack requests while the learner trains.
                    while not stop_event.is_set():
                        _, pending_collector_pack_request = _service_collector_pack_requests(
                            replay_buffer,
                            collector_pack_request_queue,
                            collector_pack_ready_queue,
                            collector_pack_shared_slots,
                            trace_recorder,
                            block_timeout=0.0,
                            pending_request=pending_collector_pack_request,
                        )
                        phase_start_ns = _record_phase_ms(
                            cycle_timing_ms, "replay_ms", phase_start_ns
                        )
                        try:
                            trainer_done_queue.get(timeout=0.001)
                        except queue.Empty:
                            phase_start_ns = _record_phase_ms(
                                cycle_timing_ms, "sync_coordination_ms", phase_start_ns
                            )
                            continue
                        phase_start_ns = _record_phase_ms(
                            cycle_timing_ms, "sync_coordination_ms", phase_start_ns
                        )
                        _, pending_collector_pack_request = _service_collector_pack_requests(
                            replay_buffer,
                            collector_pack_request_queue,
                            collector_pack_ready_queue,
                            collector_pack_shared_slots,
                            trace_recorder,
                            block_timeout=0.0,
                            pending_request=pending_collector_pack_request,
                        )
                        phase_start_ns = _record_phase_ms(
                            cycle_timing_ms, "replay_ms", phase_start_ns
                        )
                        break
                    if trace_recorder:
                        trace_recorder.add_slice(
                            "collector/wait_trainer_done",
                            category="collector",
                            start_ns=wait_ns,
                            end_ns=time.perf_counter_ns(),
                        )
                        if metrics_queue is not None:
                            # Best-effort: dropping trace events must not stop collection.
                            try:
                                metrics_queue.put_nowait(
                                    {"trace_events": trace_recorder.drain_events()}
                                )
                            except Exception:
                                pass
                    phase_start_ns = _record_phase_ms(
                        cycle_timing_ms, "sync_coordination_ms", phase_start_ns
                    )
                    env_steps_since_sync = 0
            elif env_steps_since_sync >= env_steps_per_sync:
                env_steps_since_sync = 0
            phase_start_ns = _record_phase_ms(
                cycle_timing_ms, "sync_coordination_ms", phase_start_ns
            )

            # Every ~2 s, sample the env's per-step reward-component log so the
            # next metrics message can report its mean.
            now = time.time()
            if now - last_log_time > 2.0:
                last_log_time = now
                # NOTE(review): info is typed ``dict | None`` elsewhere in this
                # module, so fall back to an empty mapping when it is None.
                log_info = (state.info or {}).get("log", {})
                if log_info:
                    for k, v in log_info.items():
                        if k.startswith("reward/"):
                            ep_reward_components[k].append(v)

            # Send metrics every 10 loop iterations (total_steps grows by num_envs).
            if metrics_queue is not None and total_steps % (num_envs * 10) == 0:
                import statistics

                try:
                    msg = {
                        "total_steps": total_steps,
                        "buffer_size": int(replay_buffer.size[0]),
                    }
                    if ep_rewards:
                        # The deques already hold only the latest finished episodes.
                        msg["mean_ep_reward"] = statistics.mean(ep_rewards)
                        msg["mean_ep_length"] = (
                            statistics.mean(ep_lengths) if ep_lengths else 0.0
                        )
                    # Add mean reward components
                    if ep_reward_components:
                        msg["reward_components"] = {
                            k: statistics.mean(vals)
                            for k, vals in ep_reward_components.items()
                            if vals
                        }
                        ep_reward_components.clear()  # reset after sending
                    if timing_counts:
                        msg["collector_timing_ms"] = {
                            k: (v / timing_counts[k])
                            for k, v in timing_accum_ms.items()
                            if timing_counts[k] > 0
                        }
                    if done_count_window > 0:
                        msg["timeout_rate"] = timeout_count_window / done_count_window
                        msg["terminated_rate"] = terminated_count_window / done_count_window
                    done_count_window = 0
                    timeout_count_window = 0
                    terminated_count_window = 0
                    if trace_recorder:
                        msg["trace_events"] = trace_recorder.drain_events()
                    metrics_queue.put_nowait(msg)
                    if "collector_timing_ms" in msg:
                        timing_accum_ms.clear()
                        timing_counts.clear()
                except Exception as e:
                    print(f"[OffPolicyWorker] metrics enqueue error: {e}", file=sys.stderr)
            phase_start_ns = _record_phase_ms(
                cycle_timing_ms, "sync_coordination_ms", phase_start_ns
            )
            # One sample per key per iteration keeps the averages comparable.
            for key in COLLECTOR_TIMING_KEYS:
                _record_timing_ms(timing_accum_ms, timing_counts, key, cycle_timing_ms[key])

        if metrics_queue is not None and trace_recorder:
            # Best-effort final flush of trace events.
            try:
                metrics_queue.put_nowait({"trace_events": trace_recorder.drain_events()})
            except Exception:
                pass
    finally:
        weight_sync.close()