Source code for unilab.algos.torch.common.base_collector
"""Base collector class for shared functionality."""
from __future__ import annotations
from collections import defaultdict
from typing import Any
import numpy as np
from unilab.ipc import SharedWeightSync
[docs]
class BaseCollector:
    """Base class for collectors with common weight sync and episode tracking.

    Subclasses must implement ``_get_state_dict_template`` and
    ``_load_state_dict``; ``sync_weights_if_needed`` relies on both.
    """

    def __init__(
        self,
        env_name: str,
        num_envs: int,
        weight_sync_name: str,
        weight_sync_lock,
        weight_param_shapes: dict,
        metrics_queue,
        stop_event,
    ):
        """Store the collector configuration and initialise its bookkeeping.

        Args:
            env_name: Name of the environment; stored as-is.
            num_envs: Number of environments tracked in parallel. Sizes the
                per-environment running reward/length accumulators.
            weight_sync_name: Passed to ``SharedWeightSync`` as ``shm_name``.
            weight_sync_lock: Passed to ``SharedWeightSync`` as ``lock``.
            weight_param_shapes: Passed to ``SharedWeightSync`` as its first
                positional argument.
            metrics_queue: Stored as-is; not used by this base class.
            stop_event: Stored as-is; not used by this base class.
        """
        self.env_name = env_name
        self.num_envs = num_envs
        self.metrics_queue = metrics_queue
        self.stop_event = stop_event

        # Weight sync.
        # NOTE(review): create=False presumably attaches to a shared segment
        # created by another process -- confirm against SharedWeightSync.
        self.weight_sync = SharedWeightSync(
            weight_param_shapes,
            create=False,
            shm_name=weight_sync_name,
            lock=weight_sync_lock,
        )
        # Version of the weights most recently loaded by this collector.
        self.local_weight_version = 0

        # Episode tracking: totals of finished episodes, plus running
        # per-environment accumulators for the episodes still in progress.
        self.episode_rewards: list[float] = []
        self.episode_lengths: list[int] = []
        self.current_episode_rewards = np.zeros(num_envs, dtype=np.float32)
        self.current_episode_lengths = np.zeros(num_envs, dtype=np.int32)
        # Initialised here but not updated by this base class.
        self.ep_reward_components: defaultdict[str, list[float]] = defaultdict(list)

        # Timing accumulators; initialised here but not updated by this base class.
        self.timing_accum_ms: defaultdict[str, float] = defaultdict(float)
        self.timing_count = 0

    def sync_weights_if_needed(self):
        """Load new weights when the shared version is ahead of the local one.

        When ``weight_sync.version`` exceeds ``local_weight_version``, a
        template from ``_get_state_dict_template`` is filled by
        ``weight_sync.read_weights_into`` and handed to ``_load_state_dict``.
        The value returned by ``read_weights_into`` becomes the new
        ``local_weight_version``.
        """
        if self.weight_sync.version > self.local_weight_version:
            sd = self._get_state_dict_template()
            self.local_weight_version = self.weight_sync.read_weights_into(sd)
            self._load_state_dict(sd)

    def track_episode(self, rewards, dones, state=None):
        """Accumulate one step of rewards and record episodes that finished.

        Args:
            rewards: Per-environment rewards for this step; added in place to
                the running episode totals.
            dones: Per-environment done flags. An entry ``> 0`` marks that
                environment's episode as finished. Only the first
                ``num_envs`` entries are considered.
            state: Unused by this base implementation; kept in the signature
                for callers and subclasses.
        """
        self.current_episode_rewards += rewards
        self.current_episode_lengths += 1

        # Locate finished environments in one vectorized pass instead of a
        # Python-level loop over every environment. Indices come back in
        # ascending order, so results are appended in environment order.
        done_flags = np.asarray(dones).reshape(-1)[: self.num_envs]
        finished = np.flatnonzero(done_flags > 0)
        if finished.size == 0:
            return

        # tolist() converts the NumPy scalars to built-in float / int.
        self.episode_rewards.extend(self.current_episode_rewards[finished].tolist())
        self.episode_lengths.extend(self.current_episode_lengths[finished].tolist())

        # Reset the accumulators for the environments that just finished.
        self.current_episode_rewards[finished] = 0
        self.current_episode_lengths[finished] = 0

    # Abstract methods

    def _get_state_dict_template(self) -> dict:
        """Return the dict that ``sync_weights_if_needed`` fills with weights.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_state_dict_template()"
        )

    def _load_state_dict(self, sd: dict):
        """Apply the weights in ``sd`` that were read from the weight sync.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _load_state_dict()"
        )