unilab.algos.torch.flash_sac.runner.FlashSACRunner

class unilab.algos.torch.flash_sac.runner.FlashSACRunner

Bases: OffPolicyRunner

Parameters:
  • env_name (str)

  • env_cfg_override (dict[str, Any] | None)

  • device (str | None)

  • num_envs (int)

  • replay_buffer_n (int)

  • batch_size (int)

  • learning_starts (int)

  • updates_per_step (int)

  • policy_frequency (int)

  • sync_collection (bool)

  • env_steps_per_sync (int)

  • gamma (float)

  • tau (float)

  • actor_lr (float)

  • critic_lr (float)

  • obs_normalization (bool)

  • actor_hidden_dim (int)

  • critic_hidden_dim (int)

  • num_atoms (int)

  • use_amp (bool)

  • sim_backend (str)

  • actor_num_blocks (int)

  • critic_num_blocks (int)

  • actor_bc_alpha (float)

  • actor_noise_zeta_mu (float)

  • actor_noise_zeta_max (int)

  • critic_min_v (float)

  • critic_max_v (float)

  • target_sigma (float)

  • target_entropy (float | None)

  • temp_initial_value (float)

  • learning_rate_init (float)

  • learning_rate_peak (float)

  • learning_rate_end (float)

  • learning_rate_warmup_steps (int)

  • learning_rate_decay_steps (int)

  • normalize_reward (bool)

  • normalized_g_max (float)

  • n_step (int)

  • amp_dtype (str)

  • use_compile (bool)

  • seed (int | None)

  • trace_enabled (bool)

  • trace_output_dir (str | None)

  • trace_thread_time (bool)

  • trace_cuda_events (bool)
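The num_atoms, critic_min_v and critic_max_v parameters suggest a categorical (C51-style) distributional critic, where each Q-value is a probability distribution over a fixed grid of value atoms. This page does not document how FlashSACRunner builds that grid, so the following is only a minimal sketch, assuming evenly spaced atoms from critic_min_v to critic_max_v inclusive. The helper names make_support and expected_value are hypothetical and not part of unilab. Under that assumption the defaults (101 atoms over [-5.0, 5.0]) give a spacing of 0.1.

```python
import math


def make_support(num_atoms: int = 101, v_min: float = -5.0, v_max: float = 5.0) -> list[float]:
    """Evenly spaced value atoms (assumed C51-style layout, not unilab's code)."""
    if num_atoms < 2:
        raise ValueError("num_atoms must be at least 2")
    delta = (v_max - v_min) / (num_atoms - 1)
    return [v_min + i * delta for i in range(num_atoms)]


def expected_value(probs: list[float], support: list[float]) -> float:
    """Collapse a categorical value distribution to a scalar Q-value (its mean)."""
    if len(probs) != len(support):
        raise ValueError("probs and support must have the same length")
    if not math.isclose(sum(probs), 1.0, rel_tol=1e-6):
        raise ValueError("probs must sum to 1")
    return sum(p * z for p, z in zip(probs, support))


support = make_support()
print(len(support), support[0], support[-1])
```

A fixed, bounded support is also a plausible reason the signature pairs these parameters with normalize_reward and normalized_g_max: returns that fall outside [critic_min_v, critic_max_v] cannot be represented on the grid, although this page does not state that link explicitly.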

Methods

__init__(env_name[, env_cfg_override, ...])

close()

learn([max_iterations, save_interval, ...]): Unified training loop for off-policy algorithms.

__init__(env_name, env_cfg_override=None, device=None, num_envs=2048, replay_buffer_n=512, batch_size=2048, learning_starts=0, updates_per_step=1, policy_frequency=2, sync_collection=True, env_steps_per_sync=1, gamma=0.99, tau=0.01, actor_lr=0.0003, critic_lr=0.0003, obs_normalization=False, actor_hidden_dim=128, critic_hidden_dim=256, num_atoms=101, use_amp=False, sim_backend='mujoco', actor_num_blocks=2, critic_num_blocks=2, actor_bc_alpha=0.0, actor_noise_zeta_mu=2.0, actor_noise_zeta_max=16, critic_min_v=-5.0, critic_max_v=5.0, target_sigma=0.15, target_entropy=None, temp_initial_value=0.01, learning_rate_init=0.0003, learning_rate_peak=0.0003, learning_rate_end=0.00015, learning_rate_warmup_steps=0, learning_rate_decay_steps=500000, normalize_reward=True, normalized_g_max=5.0, n_step=1, amp_dtype='auto', use_compile=False, seed=None, trace_enabled=False, trace_output_dir=None, trace_thread_time=False, trace_cuda_events=True)

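The learning_rate_init, learning_rate_peak, learning_rate_end, learning_rate_warmup_steps and learning_rate_decay_steps parameters describe a warmup-then-decay learning-rate schedule. This page does not say whether the decay is linear or cosine, or how the schedule interacts with actor_lr and critic_lr, so the sketch below assumes linear interpolation in both phases and a constant value after decay ends. The function lr_at_step is hypothetical and not taken from unilab; its defaults mirror the signature above.

```python
def lr_at_step(
    step: int,
    init: float = 3e-4,          # learning_rate_init
    peak: float = 3e-4,          # learning_rate_peak
    end: float = 1.5e-4,         # learning_rate_end
    warmup_steps: int = 0,       # learning_rate_warmup_steps
    decay_steps: int = 500_000,  # learning_rate_decay_steps
) -> float:
    """Assumed linear warmup, then linear decay, then hold (illustrative only)."""
    if step < 0:
        raise ValueError("step must be non-negative")
    # Phase 1: ramp linearly from `init` to `peak`.
    if step < warmup_steps:
        return init + (peak - init) * step / warmup_steps
    # Phase 2: move linearly from `peak` to `end` over `decay_steps`.
    t = step - warmup_steps
    if decay_steps > 0 and t < decay_steps:
        return peak + (end - peak) * t / decay_steps
    # Phase 3: stay at the final value.
    return end


for s in (0, 250_000, 500_000, 1_000_000):
    print(s, lr_at_step(s))
```

With the documented defaults there is no warmup (learning_rate_warmup_steps=0, and init equals peak), so under this assumption the rate starts at 3e-4, halves to 1.5e-4 by step 500000, and stays there.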

close()
Return type:

None

learn(max_iterations=1500, save_interval=50, log_dir='logs', logger_type='tensorboard')

Unified training loop for off-policy algorithms.

Parameters:
  • max_iterations (int)

  • save_interval (int)

  • log_dir (str)

  • logger_type (str)

Return type:

None
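Putting the documented signatures together, a typical run constructs the runner, calls learn(), and then calls close(). The sketch below is hedged in two ways: ENV_NAME is a hypothetical placeholder (this page does not list valid environment names), and the unilab import is guarded with importlib.util.find_spec so the snippet still executes, printing the arguments it would use, where unilab is not installed. Every keyword argument comes from the __init__ and learn() signatures above; the smaller num_envs and the fixed seed are illustrative choices, not recommendations.

```python
import importlib.util
from typing import Any, Optional

ENV_NAME = "YourEnv-v0"  # hypothetical placeholder; substitute a real unilab environment name

# Documented learn() defaults.
LEARN_KWARGS: dict[str, Any] = {
    "max_iterations": 1500,
    "save_interval": 50,
    "log_dir": "logs",
    "logger_type": "tensorboard",
}


def runner_kwargs(env_name: str, num_envs: int = 2048, seed: Optional[int] = None) -> dict[str, Any]:
    """Constructor arguments, using parameter names from the documented __init__ signature."""
    return {
        "env_name": env_name,
        "num_envs": num_envs,
        "seed": seed,
        "sim_backend": "mujoco",  # documented default
        "use_amp": False,         # documented default
    }


def main() -> None:
    kwargs = runner_kwargs(ENV_NAME, num_envs=256, seed=0)
    if importlib.util.find_spec("unilab") is None:
        print("unilab not installed; would call FlashSACRunner(**%r)" % (kwargs,))
        return
    from unilab.algos.torch.flash_sac.runner import FlashSACRunner

    runner = FlashSACRunner(**kwargs)
    try:
        runner.learn(**LEARN_KWARGS)
    finally:
        runner.close()  # release the runner even if training fails


if __name__ == "__main__":
    main()
```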