unilab.algos.torch.fast_td3.runner
FastTD3 runner built on top of the unified off-policy infra.
Classes
FastTD3Runner | FastTD3 runner using the shared OffPolicyRunner training loop.
- class unilab.algos.torch.fast_td3.runner.FastTD3Runner
Bases: OffPolicyRunner
FastTD3 runner using the shared OffPolicyRunner training loop.
- Parameters:
  - env_name (str)
  - num_envs (int)
  - replay_buffer_n (int)
  - batch_size (int)
  - learning_starts (int)
  - num_updates (int)
  - policy_frequency (int)
  - sync_collection (bool)
  - env_steps_per_sync (int)
  - gamma (float)
  - tau (float)
  - actor_lr (float)
  - critic_lr (float)
  - actor_hidden_dim (int)
  - critic_hidden_dim (int)
  - num_atoms (int)
  - v_min (float)
  - v_max (float)
  - init_scale (float)
  - log_std_min (float)
  - log_std_max (float)
  - policy_noise (float)
  - noise_clip (float)
  - weight_decay (float)
  - use_cdq (bool)
  - obs_normalization (bool)
  - sim_backend (str)
  - trace_enabled (bool)
  - trace_thread_time (bool)
  - trace_cuda_events (bool)
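The parameter names above follow common TD3 and distributional-critic conventions. The page itself does not describe how the runner uses them, so the sketch below only illustrates the usual textbook meaning of a few of them, using the documented default values. All four helper functions are hypothetical and are not part of unilab. The action range of [-1, 1] and the exact update schedule are assumptions.

```python
import random

# Default values copied from the documented __init__ signature.
NUM_ATOMS, V_MIN, V_MAX = 101, -10.0, 10.0
POLICY_NOISE, NOISE_CLIP = 0.1, 0.2
TAU, POLICY_FREQUENCY = 0.01, 2


def categorical_support(num_atoms, v_min, v_max):
    """Evenly spaced return values for a categorical critic (assumed usage)."""
    step = (v_max - v_min) / (num_atoms - 1)
    return [v_min + i * step for i in range(num_atoms)]


def smoothed_target_action(action, policy_noise, noise_clip, rng):
    """TD3-style target smoothing: add clipped Gaussian noise to the action.

    Assumes actions are bounded in [-1, 1].
    """
    noise = rng.gauss(0.0, policy_noise)
    noise = max(-noise_clip, min(noise_clip, noise))
    return max(-1.0, min(1.0, action + noise))


def polyak_update(target, online, tau):
    """Soft target update: target <- tau * online + (1 - tau) * target."""
    return [tau * o + (1.0 - tau) * t for t, o in zip(target, online)]


def actor_update_due(update_index, policy_frequency):
    """Delayed policy update: one actor step per `policy_frequency` critic steps."""
    return update_index % policy_frequency == 0


support = categorical_support(NUM_ATOMS, V_MIN, V_MAX)
rng = random.Random(0)
noisy_action = smoothed_target_action(0.5, POLICY_NOISE, NOISE_CLIP, rng)
```

With the defaults, the support has 101 atoms spaced 0.2 apart, and the smoothing noise can shift a target action by at most 0.2 before the final clamp.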
- __init__(env_name, env_cfg_override=None, device=None, num_envs=4096, replay_buffer_n=1000, batch_size=8192, learning_starts=0, num_updates=4, policy_frequency=2, sync_collection=True, env_steps_per_sync=1, gamma=0.97, tau=0.01, actor_lr=0.0003, critic_lr=0.0003, actor_hidden_dim=256, critic_hidden_dim=512, num_atoms=101, v_min=-10.0, v_max=10.0, init_scale=0.01, log_std_min=-0.9, log_std_max=0.0, policy_noise=0.1, noise_clip=0.2, weight_decay=0.001, use_cdq=True, obs_normalization=True, sim_backend='mujoco', seed=None, trace_enabled=False, trace_output_dir=None, trace_thread_time=False, trace_cuda_events=True)
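A minimal sketch of constructing the runner from the signature above. The keyword names and default values come from the documented `__init__`. The environment name `"MyEnv-v0"` and the seed value are placeholders, and `build_runner` is a hypothetical helper, not part of unilab. The import is deferred so that the configuration can be inspected without unilab installed.

```python
# Keyword arguments taken from the documented signature. Values are the
# documented defaults, except env_name and seed, which are placeholders.
runner_kwargs = dict(
    env_name="MyEnv-v0",      # hypothetical environment name
    num_envs=4096,
    replay_buffer_n=1000,
    batch_size=8192,
    learning_starts=0,
    num_updates=4,
    policy_frequency=2,
    gamma=0.97,
    tau=0.01,
    actor_lr=0.0003,
    critic_lr=0.0003,
    use_cdq=True,
    obs_normalization=True,
    sim_backend="mujoco",
    seed=0,                   # placeholder; the documented default is None
)


def build_runner(**overrides):
    """Construct a FastTD3Runner; requires the unilab package to be installed."""
    # Deferred import so this module loads even when unilab is absent.
    from unilab.algos.torch.fast_td3.runner import FastTD3Runner

    return FastTD3Runner(**{**runner_kwargs, **overrides})
```

Calling `build_runner(num_envs=1024)` would override a single default while leaving the rest unchanged. Any arguments not listed here fall back to the defaults in the signature.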