unilab.algos.torch.fast_td3.learner.FastTD3Learner
- class unilab.algos.torch.fast_td3.learner.FastTD3Learner
Bases: object

FastTD3 learner aligned with the reference FastTD3 repository.

Key hyperparameters (from Go1JoystickFlat):
- gamma=0.97, tau=0.1
- AdamW with weight_decay=0.1
- Cosine LR schedule
- Distributional critic (C51, num_atoms=101, v_min/v_max=±10)
- CDQ (Clipped Double Q-learning) toggle
- Observation normalization

The tau and weight_decay values above are the Go1JoystickFlat settings; the constructor defaults are tau=0.01 and weight_decay=0.001.
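The distributional critic listed above represents returns as a categorical distribution over num_atoms evenly spaced atoms between v_min and v_max. The following is a minimal pure-Python sketch of the standard C51 target projection for a single transition, shown only to illustrate the technique; the helper names `c51_project` and `expected_value` are hypothetical and are not methods of this class.

```python
import math
from typing import List


def expected_value(probs: List[float], v_min: float, v_max: float) -> float:
    """Mean of a categorical distribution on an evenly spaced support."""
    delta_z = (v_max - v_min) / (len(probs) - 1)
    return sum(p * (v_min + j * delta_z) for j, p in enumerate(probs))


def c51_project(
    next_probs: List[float],
    reward: float,
    done: float,
    gamma: float = 0.97,
    v_min: float = -10.0,
    v_max: float = 10.0,
) -> List[float]:
    """Project r + gamma * (1 - done) * Z' back onto the fixed support."""
    num_atoms = len(next_probs)
    delta_z = (v_max - v_min) / (num_atoms - 1)
    projected = [0.0] * num_atoms
    for j, p in enumerate(next_probs):
        z_j = v_min + j * delta_z
        # Bellman-shift the atom, then clip it to [v_min, v_max].
        tz = min(max(reward + (1.0 - done) * gamma * z_j, v_min), v_max)
        # Fractional index of the shifted atom on the support (guarded against rounding).
        b = min(max((tz - v_min) / delta_z, 0.0), num_atoms - 1.0)
        lower, upper = math.floor(b), math.ceil(b)
        if lower == upper:
            projected[lower] += p  # lands exactly on an atom
        else:
            # Split the mass between the two neighbouring atoms by proximity.
            projected[lower] += p * (upper - b)
            projected[upper] += p * (b - lower)
    return projected


# Usage: uniform next-state distribution over 101 atoms.
uniform = [1.0 / 101] * 101
target = c51_project(uniform, reward=0.1, done=0.0)
print(round(sum(target), 6), round(expected_value(target, -10.0, 10.0), 6))
```

When no atom is clipped, the linear split preserves the mean, so the projected distribution's expected value equals r + gamma * E[Z']. A batched torch implementation would typically vectorize the same steps over the batch (for example with `Tensor.scatter_add_`).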
- Parameters:
  - obs_dim (int)
  - action_dim (int)
  - critic_obs_dim (int)
  - num_envs (int)
  - device (str)
  - gamma (float)
  - tau (float)
  - actor_lr (float)
  - critic_lr (float)
  - actor_hidden_dim (int)
  - critic_hidden_dim (int)
  - num_atoms (int)
  - v_min (float)
  - v_max (float)
  - init_scale (float)
  - log_std_min (float)
  - log_std_max (float)
  - weight_decay (float)
  - use_cdq (bool)
  - policy_noise (float)
  - noise_clip (float)
  - policy_frequency (int)
  - max_iterations (int)
  - obs_normalization (bool)
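The docstring mentions a cosine learning-rate schedule, and the constructor takes max_iterations. A minimal sketch of cosine annealing follows, assuming (this page does not confirm it) that the schedule spans max_iterations update steps and decays towards a floor `min_lr`, a hypothetical parameter that is not in the signature. It uses the same closed form as `torch.optim.lr_scheduler.CosineAnnealingLR`, but holds the rate at the floor past the end instead of letting the cosine rise again.

```python
import math


def cosine_lr(
    step: int,
    base_lr: float = 3e-4,         # matches the actor_lr / critic_lr defaults
    max_iterations: int = 50_000,  # assumed schedule length
    min_lr: float = 0.0,           # hypothetical floor, not a documented parameter
) -> float:
    """Cosine-annealed learning rate at a given update step."""
    t = min(max(step, 0), max_iterations)  # clamp so the rate never rises again
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / max_iterations))
    return min_lr + (base_lr - min_lr) * cosine


for s in (0, 25_000, 50_000, 60_000):
    print(s, cosine_lr(s))
```

In torch, such a setup would plausibly pair `torch.optim.AdamW(params, lr=..., weight_decay=...)` with `CosineAnnealingLR(optimizer, T_max=..., eta_min=...)`; how this class actually wires its optimizers is not shown here.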
Methods
- __init__(obs_dim, action_dim, critic_obs_dim)
- load_state_dict(state_dict)
- normalize_obs(obs[, update]): Normalize observations using running statistics.
- Backward-compatible alias for older call sites.
- Polyak-average update of the critic target network only (matching reference FastTD3).
- update_actor(data): One actor update step.
- update_critic(data): One critic update step.
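The methods above split training into a critic update, an actor update, and a Polyak update of the critic target only. The sketch below illustrates that pattern on plain Python lists standing in for parameter tensors. The call ordering (critic every step, actor every `policy_frequency` steps, critic target after every critic step) follows common TD3 practice and is an assumption here, as are the hypothetical helper names `polyak_update` and `run_updates`.

```python
from typing import List


def polyak_update(target: List[float], online: List[float], tau: float) -> None:
    """In place: target <- (1 - tau) * target + tau * online."""
    for i, (t, o) in enumerate(zip(target, online)):
        target[i] = (1.0 - tau) * t + tau * o


def run_updates(num_steps: int, policy_frequency: int = 2) -> List[str]:
    """Record which updates a TD3-style loop would perform at each step."""
    log: List[str] = []
    for step in range(1, num_steps + 1):
        log.append("critic")  # cf. update_critic(data)
        if step % policy_frequency == 0:
            log.append("actor")  # cf. update_actor(data), delayed
        log.append("critic_target")  # Polyak average, critic target only
    return log


online = [1.0, 2.0]
target = [0.0, 0.0]
polyak_update(target, online, tau=0.1)
print(target)
print(run_updates(4))
```

With torch tensors the same soft update is often written under `torch.no_grad()` as `target_param.lerp_(param, tau)`, which computes the same convex combination in place.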
- __init__(obs_dim, action_dim, critic_obs_dim, num_envs=1024, device='cpu', gamma=0.97, tau=0.01, actor_lr=0.0003, critic_lr=0.0003, actor_hidden_dim=512, critic_hidden_dim=1024, num_atoms=101, v_min=-10.0, v_max=10.0, init_scale=0.01, log_std_min=-3.0, log_std_max=0.0, weight_decay=0.001, use_cdq=True, policy_noise=0.1, noise_clip=0.2, policy_frequency=2, max_iterations=50000, obs_normalization=True)
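The `policy_noise`, `noise_clip` and `use_cdq` arguments govern how the critic target is formed. A minimal sketch follows, assuming standard TD3 target policy smoothing with actions bounded to [-1, 1], and assuming that with CDQ enabled each sample bootstraps from whichever target critic's distribution has the lower expected value. That is one way to combine clipped double Q-learning with a C51 critic; this page does not confirm it is what the class does. The helper names are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def smoothed_target_action(
    action: Sequence[float],
    rng: random.Random,
    policy_noise: float = 0.1,
    noise_clip: float = 0.2,
) -> List[float]:
    """Add clipped Gaussian noise to a target action, then clip to [-1, 1]."""
    out = []
    for a in action:
        eps = max(-noise_clip, min(noise_clip, rng.gauss(0.0, policy_noise)))
        out.append(max(-1.0, min(1.0, a + eps)))
    return out


def select_target_dists(
    dist1: List[float], dist2: List[float], support: List[float], use_cdq: bool
) -> Tuple[List[float], List[float]]:
    """Return the target distributions used by critic 1 and critic 2."""
    if not use_cdq:
        return dist1, dist2  # each critic bootstraps from its own target
    q1 = sum(p * z for p, z in zip(dist1, support))
    q2 = sum(p * z for p, z in zip(dist2, support))
    lower = dist1 if q1 < q2 else dist2  # pessimistic (clipped) choice
    return lower, lower


rng = random.Random(0)
print(smoothed_target_action([0.0, 0.5, -0.95], rng))
support = [-1.0, 0.0, 1.0]
print(select_target_dists([0.0, 0.0, 1.0], [1.0, 0.0, 0.0], support, use_cdq=True))
```

In a batched torch implementation, the per-sample choice between the two distributions would typically be a `torch.where` over the batch.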