# Source code for unilab.algos.torch.flash_sac.double_buffer

"""FlashSAC builder for the CPU-pinned double-buffer replay path."""

from __future__ import annotations

from typing import Any

from omegaconf import DictConfig

from unilab.algos.torch.flash_sac.learner import FlashSACLearner
from unilab.algos.torch.offpolicy.double_buffer_runner import DoubleBufferOffPolicyRunner
from unilab.training import create_env, ensure_registries
from unilab.training.seed import apply_training_seed
from unilab.utils.device import get_default_device


def _validate_flashsac_double_buffer_runtime(
    cfg: DictConfig,
    *,
    device: str,
    replay_prefetch_mode: str,
) -> None:
    """Reject runtime settings the FlashSAC-B double-buffer path cannot run with.

    Rules are evaluated in a fixed order and the first violated one raises,
    so a config with several problems only reports the earliest.

    Args:
        cfg: Training config. Reads ``training.num_gpus``,
            ``training.no_sync_collection`` and ``algo.algo_params.n_step``.
        device: Accepted from the caller but not inspected by any rule.
        replay_prefetch_mode: Must be exactly ``"one_tick"``.

    Raises:
        ValueError: On the first unsupported setting encountered.
    """
    del device  # intentionally unused

    # Each predicate is a thunk so a config field is only read once every
    # earlier rule has passed.
    rules = (
        (
            lambda: cfg.training.num_gpus > 1,
            "FlashSAC-B cpu_pinned_double_buffer is single-GPU only",
        ),
        (
            lambda: cfg.training.no_sync_collection,
            "FlashSAC-B cpu_pinned_double_buffer requires synchronized collection",
        ),
        (
            lambda: replay_prefetch_mode != "one_tick",
            "FlashSAC-B cpu_pinned_double_buffer requires replay_prefetch_mode='one_tick'",
        ),
        (
            lambda: cfg.algo.algo_params.n_step != 1,
            "FlashSAC-B initially supports n_step=1 only",
        ),
    )
    for is_violated, message in rules:
        if is_violated():
            raise ValueError(message)


def build_flashsac_double_buffer_runner(
    cfg: DictConfig,
    *,
    env_cfg_override: dict[str, Any] | None,
    replay_prefetch_mode: str,
    verbose_metrics: bool,
) -> DoubleBufferOffPolicyRunner:
    """Build FlashSAC with the opt-in CPU-pinned double-buffer replay pipeline.

    Steps, in order: ensure registries are populated, apply the training seed,
    resolve the device, validate the runtime configuration, create a
    single-env instance to read observation/action dimensions (closed again
    before anything else is built), construct the ``FlashSACLearner``, and
    wrap it in a ``DoubleBufferOffPolicyRunner``.

    Args:
        cfg: Full training config. Hyperparameters are read from ``cfg.algo``
            and ``cfg.algo.algo_params``; runtime settings from
            ``cfg.training``.
        env_cfg_override: Optional environment config overrides, passed both
            to the dimension-probing env and to the runner.
        replay_prefetch_mode: Replay prefetch mode; validation only accepts
            ``"one_tick"``.
        verbose_metrics: Forwarded unchanged to the runner.

    Returns:
        A configured ``DoubleBufferOffPolicyRunner`` with ``algo_type="flashsac"``.

    Raises:
        ValueError: If the runtime configuration is unsupported (see
            ``_validate_flashsac_double_buffer_runtime``) or the environment's
            action space exposes no usable shape.
    """
    # Function-scope import kept from the original.
    # NOTE(review): presumably avoids an import cycle or import-time cost — confirm.
    from unilab.base.observations import get_obs_dims

    ensure_registries()
    apply_training_seed(cfg.algo.seed, torch_runtime=True, cuda=True)
    device = cfg.training.device or get_default_device()

    # Fail fast on unsupported settings before any environment is created.
    _validate_flashsac_double_buffer_runtime(
        cfg,
        device=device,
        replay_prefetch_mode=replay_prefetch_mode,
    )

    algo_cfg = cfg.algo
    algo_params = algo_cfg.algo_params
    training_cfg = cfg.training

    # A one-env instance is created only to read the dimensions; it is always
    # closed, even if reading them fails.
    env = create_env(cfg, num_envs=1, env_cfg_override=env_cfg_override)
    try:
        obs_dim, critic_obs_dim = get_obs_dims(env.obs_groups_spec)
        action_shape = env.action_space.shape
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would turn this into an opaque TypeError/IndexError.
        if action_shape is None or len(action_shape) == 0:
            raise ValueError(
                "FlashSAC-B requires an action space with a non-empty shape, "
                f"got {action_shape!r}"
            )
        action_dim = int(action_shape[0])
    finally:
        env.close()

    learner = FlashSACLearner(
        obs_dim=obs_dim,
        action_dim=action_dim,
        critic_obs_dim=critic_obs_dim,
        device=device,
        gamma=algo_cfg.gamma,
        tau=algo_cfg.tau,
        actor_lr=algo_cfg.actor_lr,
        critic_lr=algo_cfg.critic_lr,
        actor_hidden_dim=algo_cfg.actor_hidden_dim,
        critic_hidden_dim=algo_cfg.critic_hidden_dim,
        actor_num_blocks=algo_params.actor_num_blocks,
        critic_num_blocks=algo_params.critic_num_blocks,
        num_atoms=algo_cfg.num_atoms,
        critic_min_v=algo_params.critic_min_v,
        critic_max_v=algo_params.critic_max_v,
        temp_initial_value=algo_params.temp_initial_value,
        temp_target_sigma=algo_params.temp_target_sigma,
        temp_target_entropy=algo_params.temp_target_entropy,
        actor_bc_alpha=algo_params.actor_bc_alpha,
        actor_noise_zeta_mu=algo_params.actor_noise_zeta_mu,
        actor_noise_zeta_max=algo_params.actor_noise_zeta_max,
        learning_rate_init=algo_params.learning_rate_init,
        learning_rate_peak=algo_params.learning_rate_peak,
        learning_rate_end=algo_params.learning_rate_end,
        learning_rate_warmup_steps=algo_params.learning_rate_warmup_steps,
        learning_rate_decay_steps=algo_params.learning_rate_decay_steps,
        normalize_reward=algo_params.normalize_reward,
        normalized_g_max=algo_params.normalized_g_max,
        n_step=algo_params.n_step,
        obs_normalization=algo_cfg.obs_normalization,
        use_amp=training_cfg.use_amp,
        amp_dtype=algo_params.amp_dtype,
        use_compile=algo_params.use_compile,
    )

    return DoubleBufferOffPolicyRunner(
        learner=learner,
        env_name=training_cfg.task_name,
        algo_type="flashsac",
        num_envs=algo_cfg.num_envs,
        replay_buffer_n=algo_cfg.replay_buffer_n,
        batch_size=algo_cfg.batch_size,
        learning_starts=algo_cfg.learning_starts,
        updates_per_step=algo_cfg.updates_per_step,
        policy_frequency=algo_cfg.policy_frequency,
        # Validation above rejects a truthy no_sync_collection, so this is
        # always True on this path.
        sync_collection=not training_cfg.no_sync_collection,
        env_steps_per_sync=training_cfg.env_steps_per_sync,
        device=device,
        actor_hidden_dim=algo_cfg.actor_hidden_dim,
        use_layer_norm=False,
        obs_normalization=algo_cfg.obs_normalization,
        sim_backend=training_cfg.sim_backend,
        env_cfg_override=env_cfg_override,
        # The actor-specific settings are the same values given to the learner.
        actor_kwargs={
            "actor_num_blocks": algo_params.actor_num_blocks,
            "actor_noise_zeta_mu": algo_params.actor_noise_zeta_mu,
            "actor_noise_zeta_max": algo_params.actor_noise_zeta_max,
        },
        seed=algo_cfg.seed,
        trace_enabled=training_cfg.trace_enabled,
        trace_output_dir=training_cfg.trace_output_dir,
        trace_thread_time=training_cfg.trace_thread_time,
        trace_cuda_events=training_cfg.trace_cuda_events,
        replay_prefetch_mode=replay_prefetch_mode,
        verbose_metrics=verbose_metrics,
    )