# Source code for unilab.ipc.memory_budget

"""Memory budget estimation for async RL training buffers.

Pure functions that estimate memory usage and warn if the system
is likely to OOM before allocating large shared buffers.
"""

from __future__ import annotations

import os
import sys


def estimate_offpolicy_bytes(
    num_envs: int,
    replay_buffer_n: int,
    obs_dim: int,
    action_dim: int,
    critic_dim: int,
    batch_size: int,
    updates_per_step: int,
) -> dict[str, int | str]:
    """Estimate the bytes needed by the off-policy replay buffer and its slots.

    Every stored row is ``2 * obs_dim + action_dim + 3 + 2 * critic_dim``
    values wide and every value is counted as 4 bytes.

    Args:
        num_envs: Number of parallel environments.
        replay_buffer_n: Number of steps retained per environment.
        obs_dim: Width of one observation.
        action_dim: Width of one action.
        critic_dim: Width of one critic observation.
        batch_size: Number of samples per update.
        updates_per_step: Number of updates performed per step.

    Returns:
        A dict with integer byte counts under ``"replay_buffer"``,
        ``"double_buffer_slots"`` and ``"total"``, plus a human-readable
        ``"breakdown"`` string.
    """
    value_bytes = 4  # presumably float32 storage -- TODO confirm against the allocator
    slot_count = 2
    mib = 1024**2

    # NOTE(review): looks like obs + next obs, the action, three scalar
    # columns, and critic obs + next critic obs -- confirm against the buffer.
    cols = 2 * obs_dim + action_dim + 3 + 2 * critic_dim

    replay_size = replay_buffer_n * num_envs * cols * value_bytes

    samples = batch_size * updates_per_step
    slots_size = samples * cols * value_bytes * slot_count

    replay_text = (
        f"Replay: {replay_size / mib:.0f} MB "
        f"({num_envs} envs × {replay_buffer_n} steps × {cols} cols × 4B)\n"
    )
    slots_text = (
        f" Double-buffer: {slots_size / mib:.0f} MB "
        f"({samples} samples × {cols} cols × 4B × 2 slots)"
    )

    return {
        "replay_buffer": replay_size,
        "double_buffer_slots": slots_size,
        "total": replay_size + slots_size,
        "breakdown": replay_text + slots_text,
    }
def estimate_appo_bytes(
    num_envs: int,
    steps_per_env: int,
    obs_dim: int,
    action_dim: int,
    critic_dim: int,
    num_slots: int = 4,
) -> dict[str, int | str]:
    """Estimate the bytes needed by the APPO rollout ring buffer.

    Each slot holds ``num_envs * steps_per_env`` rollout rows plus one
    final observation (actor and critic parts) per environment; every
    value is counted as 4 bytes.

    Args:
        num_envs: Number of parallel environments.
        steps_per_env: Rollout length per environment.
        obs_dim: Width of one observation.
        action_dim: Width of one action.
        critic_dim: Width of one critic observation.
        num_slots: Number of slots in the ring buffer.

    Returns:
        A dict with the same integer byte count under ``"ring_buffer"``
        and ``"total"``, plus a human-readable ``"breakdown"`` string.
        The breakdown text lists only the rollout-row term; the byte
        counts additionally include the per-slot final observations.
    """
    value_bytes = 4  # presumably float32 storage -- TODO confirm against the allocator

    # Four scalar columns per step on top of obs, action and critic obs.
    # NOTE(review): SOURCE does not name these four scalars -- confirm.
    cols = obs_dim + action_dim + 4 + critic_dim

    rollout_size = num_envs * steps_per_env * cols * value_bytes
    final_obs_size = num_envs * (obs_dim + critic_dim) * value_bytes
    ring_size = (rollout_size + final_obs_size) * num_slots

    size_mb = ring_size / 1024**2
    summary = (
        f"Ring buffer: {size_mb:.0f} MB "
        f"({num_slots} slots × {num_envs} envs × {steps_per_env} steps × "
        f"{cols} cols × 4B)"
    )

    return {
        "ring_buffer": ring_size,
        "total": ring_size,
        "breakdown": summary,
    }
def get_available_memory_bytes() -> int | None:
    """Return the available system memory in bytes, or ``None`` if unknown.

    Detection is best-effort and tries two sources in order:

    1. The ``MemAvailable:`` line of ``/proc/meminfo``; its value is
       multiplied by 1024 to convert to bytes.
    2. ``psutil.virtual_memory().available``, if ``psutil`` is importable.

    Returns:
        The available memory in bytes, or ``None`` when neither source
        yields a value.
    """
    try:
        with open("/proc/meminfo", encoding="utf-8") as meminfo:
            for line in meminfo:
                if line.startswith("MemAvailable:"):
                    return int(line.split()[1]) * 1024
    except (OSError, ValueError, IndexError):
        # Missing file, unparsable number, or a "MemAvailable:" line with
        # no value field: fall through to psutil instead of raising.
        pass

    try:
        import psutil
    except ImportError:
        return None
    return int(psutil.virtual_memory().available)
def warn_if_over_budget(
    estimated: dict[str, int | str],
    label: str,
    threshold: float = 0.8,
) -> None:
    """Print a warning to stderr if the estimate exceeds the memory budget.

    The check is skipped when the ``UNILAB_SKIP_MEMORY_CHECK`` environment
    variable is set to any non-empty value, or when the available memory
    cannot be determined.

    Args:
        estimated: Estimate dict; ``"total"`` (bytes) is required and
            ``"breakdown"`` (text) is included in the warning if present.
        label: Name shown in the warning to identify the buffer.
        threshold: Fraction of available memory above which to warn.
    """
    if os.environ.get("UNILAB_SKIP_MEMORY_CHECK"):
        return

    available = get_available_memory_bytes()
    if available is None:
        return

    total = int(estimated["total"])
    if available > 0:
        ratio = total / available
    else:
        # A reported availability of zero must not crash an advisory check
        # with ZeroDivisionError; any non-zero demand is then over budget.
        ratio = float("inf") if total > 0 else 0.0

    if ratio <= threshold:
        return

    est_gb = total / 1024**3
    avail_gb = available / 1024**3
    breakdown = estimated.get("breakdown", "")
    print(
        f"\n[Memory Warning] {label}: estimated {est_gb:.1f} GB, "
        f"available {avail_gb:.1f} GB ({ratio:.0%} usage).\n"
        f" {breakdown}\n"
        f" Consider reducing algo.num_envs or algo.replay_buffer_n.\n"
        f" Suppress: export UNILAB_SKIP_MEMORY_CHECK=1\n",
        file=sys.stderr,
        flush=True,
    )