unilab.algos.torch.flash_sac.learner.FlashSACLearner

class unilab.algos.torch.flash_sac.learner.FlashSACLearner

Bases: object

Parameters:
  • obs_dim (int)

  • action_dim (int)

  • critic_obs_dim (int)

  • device (str)

  • gamma (float)

  • tau (float)

  • actor_lr (float)

  • critic_lr (float)

  • actor_hidden_dim (int)

  • critic_hidden_dim (int)

  • actor_num_blocks (int)

  • critic_num_blocks (int)

  • num_atoms (int)

  • critic_min_v (float)

  • critic_max_v (float)

  • temp_initial_value (float)

  • temp_target_sigma (float)

  • temp_target_entropy (float | None)

  • actor_bc_alpha (float)

  • actor_noise_zeta_mu (float)

  • actor_noise_zeta_max (int)

  • learning_rate_init (float)

  • learning_rate_peak (float)

  • learning_rate_end (float)

  • learning_rate_warmup_steps (int)

  • learning_rate_decay_steps (int)

  • normalize_reward (bool)

  • normalized_g_max (float)

  • n_step (int)

  • obs_normalization (bool)

  • use_amp (bool)

  • amp_dtype (str)

  • use_compile (bool)
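
As a minimal sketch of how these parameters might be supplied, the snippet below collects a few keyword arguments and constructs the learner only when the package is importable. The dimension values (48, 12, 64) are hypothetical placeholders; the other values mirror the defaults shown in the `__init__` signature.

```python
# Hypothetical sizes for illustration; use the dimensions of your own environment.
learner_kwargs = dict(
    obs_dim=48,
    action_dim=12,
    critic_obs_dim=64,
    device="cpu",
    gamma=0.99,
    tau=0.01,
    n_step=1,
    normalize_reward=True,
)

try:
    from unilab.algos.torch.flash_sac.learner import FlashSACLearner
except ImportError:
    # Package not installed here; the kwargs above can still be inspected.
    FlashSACLearner = None

# Construct the learner only if the import succeeded.
learner = FlashSACLearner(**learner_kwargs) if FlashSACLearner is not None else None
```

Only `obs_dim`, `action_dim`, and `critic_obs_dim` lack defaults, so every other keyword above could be omitted.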

Methods

__init__(obs_dim, action_dim, critic_obs_dim)

get_state_dict()

load_state_dict(state_dict)

soft_update_target()

update_actor(batch)

update_critic(batch)

update_reward_stats(rewards, dones)

__init__(obs_dim, action_dim, critic_obs_dim, device='cpu', gamma=0.99, tau=0.01, actor_lr=0.0003, critic_lr=0.0003, actor_hidden_dim=128, critic_hidden_dim=256, actor_num_blocks=2, critic_num_blocks=2, num_atoms=101, critic_min_v=-5.0, critic_max_v=5.0, temp_initial_value=0.01, temp_target_sigma=0.15, temp_target_entropy=None, actor_bc_alpha=0.0, actor_noise_zeta_mu=2.0, actor_noise_zeta_max=16, learning_rate_init=0.0003, learning_rate_peak=0.0003, learning_rate_end=0.00015, learning_rate_warmup_steps=0, learning_rate_decay_steps=500000, normalize_reward=True, normalized_g_max=5.0, n_step=1, obs_normalization=False, use_amp=False, amp_dtype='auto', use_compile=False)
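
The defaults `num_atoms=101`, `critic_min_v=-5.0`, and `critic_max_v=5.0` in the signature above suggest a critic that predicts a categorical distribution over a fixed grid of return values. This reference does not describe the critic head, so the following is only a sketch under that assumption: it builds an evenly spaced support and collapses a distribution over it into a scalar. The helper names `make_support` and `expected_value` are hypothetical and not part of the library.

```python
def make_support(num_atoms=101, min_v=-5.0, max_v=5.0):
    """Return num_atoms evenly spaced values covering [min_v, max_v]."""
    if num_atoms < 2:
        raise ValueError("num_atoms must be at least 2")
    span = max_v - min_v
    return [min_v + span * i / (num_atoms - 1) for i in range(num_atoms)]


def expected_value(probs, support):
    """Collapse a categorical distribution over the support into a scalar value."""
    if len(probs) != len(support):
        raise ValueError("probs and support must have the same length")
    return sum(p * z for p, z in zip(probs, support))


support = make_support()
# With the defaults, neighbouring atoms are (5.0 - (-5.0)) / 100 apart.
atom_spacing = support[1] - support[0]

# A distribution with all its mass on the last atom has value max_v.
one_hot_top = [0.0] * 100 + [1.0]
top_value = expected_value(one_hot_top, support)
```

Under this assumption, values outside `[critic_min_v, critic_max_v]` cannot be represented, which may be why reward normalization is enabled by default.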

update_reward_stats(rewards, dones)
Parameters:
  • rewards

  • dones

Return type:

None
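
This reference does not say which statistics `update_reward_stats` maintains. One plausible reading, given the `normalize_reward`, `normalized_g_max`, and `gamma` parameters, is that it tracks running statistics of the discounted return and uses them to rescale rewards. The sketch below illustrates that idea in plain Python; the class `RunningReturnStats` and its scaling rule are assumptions, not the library's implementation.

```python
import math


class RunningReturnStats:
    """Hypothetical tracker of discounted-return statistics, one slot per environment."""

    def __init__(self, num_envs, gamma=0.99, g_max=5.0, eps=1e-8):
        self.gamma = gamma
        self.g_max = g_max
        self.eps = eps
        self.returns = [0.0] * num_envs  # running discounted return per environment
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations (Welford's algorithm)
        self.max_abs_return = 0.0

    def update(self, rewards, dones):
        for i, (reward, done) in enumerate(zip(rewards, dones)):
            g = self.gamma * self.returns[i] + reward
            # Welford update of the running mean and variance of g.
            self.count += 1
            delta = g - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (g - self.mean)
            self.max_abs_return = max(self.max_abs_return, abs(g))
            # Restart the accumulator when the episode ends.
            self.returns[i] = 0.0 if done else g

    def scale(self, reward):
        """Rescale a reward so that returns stay roughly within +/- g_max."""
        if self.count == 0:
            return reward
        variance = self.m2 / self.count
        denom = max(math.sqrt(variance + self.eps), self.max_abs_return / self.g_max)
        return reward / denom


stats = RunningReturnStats(num_envs=1, gamma=0.5, g_max=5.0)
stats.update([1.0], [False])  # return becomes 1.0
stats.update([1.0], [True])   # return becomes 1.5, then resets
scaled = stats.scale(3.0)
```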

update_critic(batch)
Parameters:

batch (dict[str, Tensor])

Return type:

dict[str, float]
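
The keys expected in `batch` are not listed in this reference. To illustrate how the `gamma` and `n_step` constructor parameters typically enter a critic target, the sketch below accumulates a discounted reward sum over a short window and returns the discount to apply to the bootstrapped value. The function `n_step_return` is hypothetical, and whether the learner or the replay buffer performs this step is an assumption.

```python
def n_step_return(rewards, dones, gamma=0.99):
    """Accumulate discounted rewards over a window of consecutive steps.

    Returns (g, bootstrap_discount). g is the discounted reward sum up to and
    including the first terminal step. bootstrap_discount is gamma**n when no
    terminal step occurs in the window, and 0.0 otherwise.
    """
    g = 0.0
    discount = 1.0
    for reward, done in zip(rewards, dones):
        g += discount * reward
        discount *= gamma
        if done:
            # No bootstrapping past the end of an episode.
            return g, 0.0
    return g, discount


# Three-step window without termination.
g_full, boot_full = n_step_return([1.0, 1.0, 1.0], [False, False, False], gamma=0.5)

# Episode ends on the second step, so the third reward is ignored.
g_cut, boot_cut = n_step_return([1.0, 1.0, 1.0], [False, True, False], gamma=0.5)

# Target for a hypothetical next-state value estimate of 2.0.
target = g_full + boot_full * 2.0
```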

update_actor(batch)
Parameters:

batch (dict[str, Tensor])

Return type:

dict[str, float]

soft_update_target()
Return type:

None
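
The method name and the `tau` constructor parameter point to the conventional Polyak (exponential moving average) target update used in SAC-style algorithms. The sketch below shows that rule on plain Python floats standing in for tensors; which networks have target copies is not stated here, and `soft_update` is a hypothetical helper.

```python
def soft_update(target_params, online_params, tau=0.01):
    """In-place Polyak averaging: target <- (1 - tau) * target + tau * online."""
    for name, online_value in online_params.items():
        target_params[name] = (1.0 - tau) * target_params[name] + tau * online_value


# Floats stand in for parameter tensors in this illustration.
target = {"w": 0.0, "b": 1.0}
online = {"w": 1.0, "b": 1.0}
soft_update(target, online, tau=0.01)
```

With `tau=0.01`, each call moves the target 1% of the way toward the online parameters, so parameters that already agree stay unchanged.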

get_state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None
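
The methods above can be combined into a single training step followed by a checkpoint round trip. The sketch below assumes a critic, actor, target-update call order, which this reference does not prescribe. `StubLearner` is a hypothetical stand-in that exposes the same method names so the example runs without the package; its metric names and state contents are invented for illustration.

```python
def train_step(learner, batch):
    """Run one critic update, one actor update, then refresh the target network."""
    critic_metrics = learner.update_critic(batch)
    actor_metrics = learner.update_actor(batch)
    learner.soft_update_target()
    # Prefix the keys so critic and actor metrics cannot collide.
    metrics = {f"critic/{k}": v for k, v in critic_metrics.items()}
    metrics.update({f"actor/{k}": v for k, v in actor_metrics.items()})
    return metrics


class StubLearner:
    """Hypothetical stand-in with the documented method names, for demonstration only."""

    def __init__(self):
        self.calls = []
        self.step = 0

    def update_critic(self, batch):
        self.calls.append("update_critic")
        return {"loss": 0.5}  # invented metric

    def update_actor(self, batch):
        self.calls.append("update_actor")
        return {"loss": -0.1}  # invented metric

    def soft_update_target(self):
        self.calls.append("soft_update_target")
        self.step += 1

    def get_state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state_dict):
        self.step = state_dict["step"]


learner = StubLearner()
metrics = train_step(learner, batch={})  # real batches map str keys to tensors

# Copy the learner state into a fresh instance, as when resuming from a checkpoint.
restored = StubLearner()
restored.load_state_dict(learner.get_state_dict())
```

With the real class, the dictionary returned by `get_state_dict()` would normally be written to disk before being passed to `load_state_dict()` in a later run.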