Source code for unilab.algos.torch.fast_sac.runner

"""FastSAC runner using unified OffPolicyRunner."""

from typing import Any

from unilab.algos.torch.common.device import get_env_dims
from unilab.algos.torch.fast_sac.learner import FastSACLearner
from unilab.algos.torch.offpolicy.runner import OffPolicyRunner
from unilab.utils.device import get_default_device


class FastSACRunner(OffPolicyRunner):
    """FastSAC using OffPolicyRunner infrastructure.

    The constructor creates a temporary single-env instance of ``env_name`` to
    read observation/action dimensions (and, when ``use_symmetry`` is set, the
    environment's symmetry augmentation). It builds a :class:`FastSACLearner`
    from those dimensions, then passes the learner to :class:`OffPolicyRunner`
    with ``algo_type="sac"``.
    """

    def __init__(
        self,
        env_name: str,
        env_cfg_override: dict[str, Any] | None = None,
        device: str | None = None,
        num_envs: int = 4096,
        replay_buffer_n: int = 1024,
        batch_size: int = 8192,
        learning_starts: int = 0,
        updates_per_step: int = 8,
        policy_frequency: int = 4,
        sync_collection: bool = True,
        env_steps_per_sync: int = 1,
        gamma: float = 0.97,
        tau: float = 0.125,
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        alpha_lr: float = 3e-4,
        alpha_init: float = 0.001,
        target_entropy_ratio: float = 1.0,
        obs_normalization: bool = True,
        actor_hidden_dim: int = 512,
        critic_hidden_dim: int = 768,
        num_atoms: int = 101,
        use_layer_norm: bool = True,
        max_grad_norm: float = 0.0,
        use_amp: bool = False,
        amp_dtype: str = "auto",
        sim_backend: str = "mujoco",
        use_symmetry: bool = False,
        world_size: int = 1,
        seed: int | None = None,
        trace_enabled: bool = False,
        trace_output_dir: str | None = None,
        trace_thread_time: bool = False,
        trace_cuda_events: bool = True,
    ):
        """Build the FastSAC learner and initialise the off-policy runner.

        Args:
            env_name: Registry name of the environment. It is passed to
                ``registry.make`` for the probe env and to ``OffPolicyRunner``.
            env_cfg_override: Optional config overrides. They are passed to
                ``registry.make`` and ``OffPolicyRunner``.
            device: Device string. ``None`` falls back to
                ``get_default_device()``.
            num_envs, replay_buffer_n, learning_starts, updates_per_step,
            policy_frequency, sync_collection, env_steps_per_sync,
            obs_normalization, trace_enabled, trace_output_dir,
            trace_thread_time, trace_cuda_events: Forwarded unchanged to
                ``OffPolicyRunner``.
            batch_size: Batch size forwarded to ``OffPolicyRunner``. With
                symmetry augmentation it is first divided by the
                augmentation's ``batch_multiplier``.
            gamma, tau, actor_lr, critic_lr, alpha_lr, alpha_init,
            target_entropy_ratio, critic_hidden_dim, num_atoms,
            max_grad_norm, use_amp, amp_dtype: Forwarded unchanged to
                ``FastSACLearner``.
            actor_hidden_dim, use_layer_norm: Forwarded to both
                ``FastSACLearner`` and ``OffPolicyRunner``.
            sim_backend: Simulator backend. It is passed to ``registry.make``
                and ``OffPolicyRunner``.
            use_symmetry: If true, request a symmetry augmentation from the
                environment and pass it to the learner.
            world_size: Forwarded to ``FastSACLearner``. An existing
                ``self.world_size`` attribute takes precedence over it.
            seed: Passed to ``apply_training_seed`` and ``OffPolicyRunner``.

        Raises:
            ValueError: If the environment's action space has no shape, if
                ``use_symmetry`` is set but the environment provides no
                symmetry augmentation, or if ``batch_size`` is not divisible
                by the augmentation's ``batch_multiplier``.
        """
        from unilab.base import registry
        from unilab.base.observations import get_obs_dims
        from unilab.base.registry import ensure_registries
        from unilab.training.seed import apply_training_seed

        ensure_registries()
        apply_training_seed(seed, torch_runtime=True, cuda=True)

        # The probe env exists only to read dimensions and the symmetry
        # augmentation. Close it on every path, including errors.
        env: Any = registry.make(
            env_name, num_envs=1, sim_backend=sim_backend, env_cfg_override=env_cfg_override
        )
        try:
            obs_dim, critic_obs_dim = get_obs_dims(env.obs_groups_spec)

            act_space_shape = env.action_space.shape
            # Explicit check rather than `assert`, which is stripped under -O.
            if act_space_shape is None:
                raise ValueError(
                    f"{env_name} with backend={sim_backend} has an action space without a shape"
                )
            action_dim = act_space_shape[0]

            device = device or get_default_device()

            symmetry_augmentation = None
            if use_symmetry:
                symmetry_augmentation = env.build_symmetry_augmentation(device=device)
                if symmetry_augmentation is None:
                    raise ValueError(
                        f"{env_name} with backend={sim_backend} does not provide symmetry augmentation"
                    )
        finally:
            env.close()

        # Validate and adjust the batch size before constructing the learner,
        # so a bad config fails without building any networks first.
        if symmetry_augmentation is not None:
            multiplier = symmetry_augmentation.batch_multiplier
            if batch_size % multiplier != 0:
                raise ValueError(
                    "Symmetry augmentation requires algo.batch_size to be divisible by "
                    f"{multiplier}, got {batch_size}"
                )
            # The augmentation multiplies each sampled batch by `multiplier`,
            # so sample fewer transitions to keep the effective batch size.
            batch_size = batch_size // multiplier
            print(
                "[FastSAC] Symmetry enabled: "
                f"batch_size adjusted to {batch_size} "
                f"(effective: {batch_size * multiplier})"
            )

        learner = FastSACLearner(
            obs_dim=obs_dim,
            action_dim=action_dim,
            device=device,
            gamma=gamma,
            tau=tau,
            actor_lr=actor_lr,
            critic_lr=critic_lr,
            alpha_lr=alpha_lr,
            alpha_init=alpha_init,
            target_entropy_ratio=target_entropy_ratio,
            actor_hidden_dim=actor_hidden_dim,
            critic_hidden_dim=critic_hidden_dim,
            num_atoms=num_atoms,
            use_layer_norm=use_layer_norm,
            max_grad_norm=max_grad_norm,
            use_amp=use_amp,
            amp_dtype=amp_dtype,
            use_symmetry=use_symmetry,
            symmetry_augmentation=symmetry_augmentation,
            # A `world_size` already present on `self` (e.g. set by a subclass
            # before calling this __init__, or a class attribute) wins over
            # the argument.
            world_size=getattr(self, "world_size", world_size),
            critic_obs_dim=critic_obs_dim,
        )

        super().__init__(
            learner=learner,
            env_name=env_name,
            algo_type="sac",
            num_envs=num_envs,
            replay_buffer_n=replay_buffer_n,
            batch_size=batch_size,
            learning_starts=learning_starts,
            updates_per_step=updates_per_step,
            policy_frequency=policy_frequency,
            sync_collection=sync_collection,
            env_steps_per_sync=env_steps_per_sync,
            device=device,
            actor_hidden_dim=actor_hidden_dim,
            use_layer_norm=use_layer_norm,
            obs_normalization=obs_normalization,
            sim_backend=sim_backend,
            env_cfg_override=env_cfg_override,
            seed=seed,
            trace_enabled=trace_enabled,
            trace_output_dir=trace_output_dir,
            trace_thread_time=trace_thread_time,
            trace_cuda_events=trace_cuda_events,
        )