Source code for unilab.ipc.shared_obs_stats

"""Shared observation normalization statistics for multi-process training."""

import queue
import sys


class SharedObsNormStats:
    """Synchronize observation normalization statistics between learner and collector.

    Uses a bounded queue to pass (mean, std) tuples from the learner
    (producer, via ``put``) to the collector (consumer, via ``get``). Only the
    most recent stats matter, so both sides drain stale entries.
    """

    def __init__(self, ctx) -> None:
        """Create the shared queue.

        Args:
            ctx: Object exposing a ``Queue(maxsize=...)`` factory, e.g. a
                multiprocessing context.
        """
        # Small bound: put() drains before inserting, so normally at most one
        # item is pending at a time.
        self.q = ctx.Queue(maxsize=2)
        # Most recent stats seen by get(); returned again when nothing new
        # has arrived since the previous call.
        self.last_stats = None

    def _drain(self):
        """Remove every item currently retrievable without blocking.

        Returns:
            A ``(found, latest)`` pair: whether any item was removed, and the
            last item removed (``None`` if none was).
        """
        found, latest = False, None
        while True:
            try:
                latest = self.q.get_nowait()
            except queue.Empty:
                # Normal exit. This also covers the race where the other side
                # consumed the item first, which is not an error.
                break
            except Exception as e:
                # Unexpected failure (e.g. queue closed): report once and stop
                # instead of spinning on a queue that will keep failing.
                print(f"[SharedObsNormStats] queue error: {e}", file=sys.stderr)
                break
            found = True
        return found, latest

    def put(self, stats) -> None:
        """Publish new stats, discarding any older ones still queued.

        Args:
            stats: The statistics to publish (per the class docs, a
                (mean, std) tuple).
        """
        self._drain()
        # NOTE(review): this put() blocks while the queue is still full. On a
        # multiprocessing queue, recently put items may not yet be visible to
        # get_nowait(). In that case it waits for the consumer to read.
        self.q.put(stats)

    def get(self):
        """Return the latest stats received so far.

        Drains the queue without blocking. If anything was queued, the newest
        item is cached and returned. Otherwise the previously cached value is
        returned, which is ``None`` if no stats have ever been received.
        """
        found, latest = self._drain()
        if found:
            self.last_stats = latest
        return self.last_stats