unilab.algos.torch.fast_sac.learner.FastSACLearner
- class unilab.algos.torch.fast_sac.learner.FastSACLearner
Bases: object
FastSAC learner with holosoma-aligned hyperparameters.
Key hyperparameters (aligned with holosoma FastSACConfig):
- gamma=0.97, tau=0.125
- batch_size=8192, num_updates=8, policy_frequency=4
- alpha_init=0.001, target_entropy_ratio=0.0
- AdamW with betas=(0.9, 0.95), weight_decay=0.001
- Distributional critic (C51, num_atoms=101)
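To illustrate what the distributional critic settings imply, here is a minimal sketch of a C51-style support built from num_atoms=101 and the default v_min=-20.0, v_max=20.0. It is written in plain Python to stay dependency-free and is not the library's code; the helper names make_support and expected_value are hypothetical.

```python
# Hypothetical sketch of a C51-style value support; not taken from unilab.
NUM_ATOMS = 101            # default num_atoms
V_MIN, V_MAX = -20.0, 20.0  # default v_min / v_max


def make_support(v_min, v_max, num_atoms):
    """Return num_atoms evenly spaced atom values covering [v_min, v_max]."""
    step = (v_max - v_min) / (num_atoms - 1)
    return [v_min + i * step for i in range(num_atoms)]


def expected_value(probs, support):
    """Collapse a categorical distribution over atoms into a scalar Q estimate."""
    return sum(p * z for p, z in zip(probs, support))


support = make_support(V_MIN, V_MAX, NUM_ATOMS)

# A uniform distribution over a symmetric support has (approximately) zero mean.
uniform = [1.0 / NUM_ATOMS] * NUM_ATOMS
q_uniform = expected_value(uniform, support)
```

With these defaults the atoms are spaced 0.4 apart, so returns outside [-20, 20] cannot be represented and would have to be clipped or rescaled.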
- Parameters:
obs_dim (int)
action_dim (int)
critic_obs_dim (int)
device (str)
gamma (float)
tau (float)
actor_lr (float)
critic_lr (float)
alpha_lr (float)
alpha_init (float)
target_entropy_ratio (float)
actor_hidden_dim (int)
critic_hidden_dim (int)
num_atoms (int)
v_min (float)
v_max (float)
num_q_networks (int)
use_layer_norm (bool)
use_tanh (bool)
log_std_max (float)
log_std_min (float)
weight_decay (float)
max_grad_norm (float)
use_autotune (bool)
use_symmetry (bool)
use_amp (bool)
amp_dtype (str)
use_compile (bool)
symmetry_augmentation (SymmetryAugmentation | None)
world_size (int)
Methods
__init__(obs_dim, action_dim, critic_obs_dim)
Save all components.
load_state_dict(state_dict): Load all components.
Polyak-average update of the target Q-network.
update_actor(batch): One actor update step.
update_critic(batch): One critic update step.
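The Polyak-average target update listed above can be sketched as follows. This assumes the common convention target = (1 - tau) * target + tau * online, which the source does not spell out; polyak_update is a hypothetical helper operating on plain lists of floats, not the learner's method.

```python
# Hypothetical sketch of a Polyak (soft) target update; not taken from unilab.
TAU = 0.125  # default tau


def polyak_update(target_params, online_params, tau):
    """Move each target parameter a fraction tau toward its online counterpart."""
    return [
        (1.0 - tau) * t + tau * o
        for t, o in zip(target_params, online_params)
    ]


target = [0.0, 1.0]
online = [1.0, 1.0]

# One soft update: the first entry moves 12.5% of the way from 0.0 to 1.0,
# the second entry already matches and stays put.
target = polyak_update(target, online, TAU)
```

Under this convention a larger tau makes the target network track the online critic more closely.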
- __init__(obs_dim, action_dim, critic_obs_dim, device='cpu', gamma=0.97, tau=0.125, actor_lr=0.0003, critic_lr=0.0003, alpha_lr=0.0003, alpha_init=0.001, target_entropy_ratio=0.0, actor_hidden_dim=512, critic_hidden_dim=768, num_atoms=101, v_min=-20.0, v_max=20.0, num_q_networks=2, use_layer_norm=True, use_tanh=True, log_std_max=0.0, log_std_min=-5.0, weight_decay=0.001, max_grad_norm=0.0, use_autotune=True, use_symmetry=False, use_amp=False, amp_dtype='auto', use_compile=False, symmetry_augmentation=None, world_size=1)
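A minimal construction sketch based only on the __init__ signature above. The dimension values are placeholders chosen for illustration, and the import is guarded so the snippet still runs where unilab is not installed.

```python
import importlib.util

# Placeholder dimensions for illustration; real values depend on the environment.
learner_kwargs = dict(
    obs_dim=48,
    action_dim=12,
    critic_obs_dim=60,
    device="cpu",
    gamma=0.97,      # documented default, repeated here for visibility
    tau=0.125,       # documented default
    num_atoms=101,   # documented default
    v_min=-20.0,
    v_max=20.0,
)

learner = None
# Only construct the learner when the package is actually available.
if importlib.util.find_spec("unilab") is not None:
    from unilab.algos.torch.fast_sac.learner import FastSACLearner

    learner = FastSACLearner(**learner_kwargs)
```

Only the three dimension arguments are required; every other parameter has a default. batch_size, num_updates and policy_frequency appear in the class description but not in this signature, so they are not passed here.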