unilab.algos.torch.flash_sac.learner

FlashSAC learner adapted to UniLab’s off-policy contract.

Classes

  • FlashSACLearner

  • RewardNormalizer: Adaptive reward scaling with running discounted-return statistics.

  • RunningMeanStd(mean: 'torch.Tensor', var: 'torch.Tensor', count: 'torch.Tensor')

class unilab.algos.torch.flash_sac.learner.RunningMeanStd

Bases: object

RunningMeanStd(mean: 'torch.Tensor', var: 'torch.Tensor', count: 'torch.Tensor')

Parameters:
  • mean (Tensor)

  • var (Tensor)

  • count (Tensor)

classmethod create(device)
Parameters:

device (device)

Return type:

RunningMeanStd

update(x)
Parameters:

x (Tensor)

Return type:

None

state_dict()
Return type:

dict[str, Tensor]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Tensor])

Return type:

None
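
RunningMeanStd holds a running mean, variance and count, and update(x) presumably folds a new batch x into those statistics. The sketch below shows the standard pairwise variance-merge formula for that step. It uses plain Python floats instead of tensors. The class name RunningMeanStdSketch and the 1e-4 default count (a convention from other RL codebases that avoids dividing by zero) are assumptions, not UniLab's implementation.

```python
from dataclasses import dataclass


@dataclass
class RunningMeanStdSketch:
    """Hypothetical scalar stand-in for RunningMeanStd (not UniLab's code)."""

    mean: float = 0.0
    var: float = 1.0
    count: float = 1e-4  # assumed small prior count, avoids division by zero

    def update(self, xs: list[float]) -> None:
        batch_count = len(xs)
        if batch_count == 0:
            return
        # Population statistics of the incoming batch.
        batch_mean = sum(xs) / batch_count
        batch_var = sum((x - batch_mean) ** 2 for x in xs) / batch_count
        # Merge the batch into the running statistics (pairwise variance update).
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = (
            self.var * self.count
            + batch_var * batch_count
            + delta**2 * self.count * batch_count / total
        )
        self.mean, self.var, self.count = new_mean, m2 / total, total

    def state_dict(self) -> dict[str, float]:
        return {"mean": self.mean, "var": self.var, "count": self.count}

    def load_state_dict(self, state: dict[str, float]) -> None:
        self.mean, self.var, self.count = state["mean"], state["var"], state["count"]


rms = RunningMeanStdSketch(mean=0.0, var=0.0, count=0.0)
rms.update([1.0, 2.0, 3.0, 4.0])
rms.update([10.0, 20.0])  # now matches the statistics of all six values
```

With merging, each update costs time proportional to the batch only, and the state_dict round trip reduces to saving three numbers.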

__init__(mean, var, count)
Parameters:
  • mean (Tensor)

  • var (Tensor)

  • count (Tensor)

class unilab.algos.torch.flash_sac.learner.RewardNormalizer

Bases: object

Adaptive reward scaling with running discounted-return statistics.

__init__(gamma, g_max, device, eps=1e-08)
Parameters:
  • gamma

  • g_max

  • device

  • eps

update_from_transitions(rewards, dones)
Parameters:
  • rewards

  • dones

Return type:

None

normalize(rewards)
Parameters:

rewards (Tensor)

Return type:

Tensor

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None
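
The class summary says only "adaptive reward scaling with running discounted-return statistics". One common form of that technique works as follows. It keeps a per-environment discounted return G and tracks the variance of G and its largest magnitude. Rewards are then divided by max(std(G), max|G| / g_max). The sketch below implements that form in plain Python. The name RewardNormalizerSketch, the num_envs argument (standing in for tensor batch shapes) and the exact scaling rule are assumptions, not UniLab's code.

```python
import math


class RewardNormalizerSketch:
    """Hypothetical reward normalizer driven by discounted-return statistics."""

    def __init__(self, gamma: float, g_max: float, num_envs: int, eps: float = 1e-8):
        self.gamma, self.g_max, self.eps = gamma, g_max, eps
        self.returns = [0.0] * num_envs  # running discounted return G per environment
        self.count, self.mean, self.m2 = 0, 0.0, 0.0  # Welford statistics of G
        self.max_abs_return = 0.0  # largest |G| observed so far

    def update_from_transitions(self, rewards: list[float], dones: list[bool]) -> None:
        for i, (r, done) in enumerate(zip(rewards, dones)):
            # Accumulate G_t = r_t + gamma * G_{t-1}.
            g = r + self.gamma * self.returns[i]
            # Welford update of the mean and sum of squared deviations of G.
            self.count += 1
            delta = g - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (g - self.mean)
            self.max_abs_return = max(self.max_abs_return, abs(g))
            # Start a fresh return once the episode has ended.
            self.returns[i] = 0.0 if done else g

    def normalize(self, reward: float) -> float:
        var = self.m2 / self.count if self.count else 0.0
        # Divide by the return std, but never by less than max|G| / g_max,
        # so the largest observed |G| is at most g_max after scaling.
        scale = max(math.sqrt(var + self.eps), self.max_abs_return / self.g_max)
        return reward / scale


norm = RewardNormalizerSketch(gamma=0.0, g_max=5.0, num_envs=1)
norm.update_from_transitions([10.0], [False])
norm.update_from_transitions([10.0], [False])
scaled = norm.normalize(10.0)  # std(G) is 0, so the max|G| / g_max = 2.0 branch applies
```

The g_max floor matters early in training. When G has almost no spread, dividing by its standard deviation alone would blow rewards up.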

class unilab.algos.torch.flash_sac.learner.FlashSACLearner

Bases: object

__init__(obs_dim, action_dim, critic_obs_dim, device='cpu', gamma=0.99, tau=0.01, actor_lr=0.0003, critic_lr=0.0003, actor_hidden_dim=128, critic_hidden_dim=256, actor_num_blocks=2, critic_num_blocks=2, num_atoms=101, critic_min_v=-5.0, critic_max_v=5.0, temp_initial_value=0.01, temp_target_sigma=0.15, temp_target_entropy=None, actor_bc_alpha=0.0, actor_noise_zeta_mu=2.0, actor_noise_zeta_max=16, learning_rate_init=0.0003, learning_rate_peak=0.0003, learning_rate_end=0.00015, learning_rate_warmup_steps=0, learning_rate_decay_steps=500000, normalize_reward=True, normalized_g_max=5.0, n_step=1, obs_normalization=False, use_amp=False, amp_dtype='auto', use_compile=False)
Parameters:
  • obs_dim (int)

  • action_dim (int)

  • critic_obs_dim (int)

  • device (str)

  • gamma (float)

  • tau (float)

  • actor_lr (float)

  • critic_lr (float)

  • actor_hidden_dim (int)

  • critic_hidden_dim (int)

  • actor_num_blocks (int)

  • critic_num_blocks (int)

  • num_atoms (int)

  • critic_min_v (float)

  • critic_max_v (float)

  • temp_initial_value (float)

  • temp_target_sigma (float)

  • temp_target_entropy (float | None)

  • actor_bc_alpha (float)

  • actor_noise_zeta_mu (float)

  • actor_noise_zeta_max (int)

  • learning_rate_init (float)

  • learning_rate_peak (float)

  • learning_rate_end (float)

  • learning_rate_warmup_steps (int)

  • learning_rate_decay_steps (int)

  • normalize_reward (bool)

  • normalized_g_max (float)

  • n_step (int)

  • obs_normalization (bool)

  • use_amp (bool)

  • amp_dtype (str)

  • use_compile (bool)
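
The learning_rate_init, learning_rate_peak, learning_rate_end, learning_rate_warmup_steps and learning_rate_decay_steps parameters suggest a warmup-then-decay schedule. The shape of each phase is not documented here. The hypothetical learning_rate_at below assumes linear warmup from init to peak, then linear decay to end, and uses the documented defaults. How this schedule interacts with actor_lr and critic_lr is also not stated.

```python
def learning_rate_at(
    step: int,
    learning_rate_init: float = 3e-4,
    learning_rate_peak: float = 3e-4,
    learning_rate_end: float = 1.5e-4,
    learning_rate_warmup_steps: int = 0,
    learning_rate_decay_steps: int = 500_000,
) -> float:
    """Hypothetical schedule: linear warmup, then linear decay (shape assumed)."""
    if learning_rate_warmup_steps > 0 and step < learning_rate_warmup_steps:
        # Warmup: interpolate from init to peak.
        frac = step / learning_rate_warmup_steps
        return learning_rate_init + (learning_rate_peak - learning_rate_init) * frac
    # Decay: interpolate from peak to end, then hold at end.
    if learning_rate_decay_steps > 0:
        progress = min((step - learning_rate_warmup_steps) / learning_rate_decay_steps, 1.0)
    else:
        progress = 1.0
    return learning_rate_peak + (learning_rate_end - learning_rate_peak) * progress


midway = learning_rate_at(250_000)  # halfway through decay with the default settings
```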

update_reward_stats(rewards, dones)
Parameters:
  • rewards

  • dones

Return type:

None

update_critic(batch)
Parameters:

batch (dict[str, Tensor])

Return type:

dict[str, float]
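
The constructor's num_atoms, critic_min_v and critic_max_v parameters (defaults 101, -5.0 and 5.0) suggest a categorical (distributional) critic whose value distribution lives on a fixed grid of atoms. n_step suggests multi-step targets. Suppose update_critic builds C51-style targets. Then each shifted atom r + gamma^n * z must be projected back onto the grid, splitting its probability between the two nearest atoms. The helper project_categorical_target below is hypothetical. It assumes rewards are already-accumulated n-step returns, and it illustrates the standard projection rather than UniLab's code.

```python
import math


def project_categorical_target(
    next_probs: list[list[float]],
    rewards: list[float],
    dones: list[bool],
    gamma: float,
    n_step: int = 1,
    v_min: float = -5.0,
    v_max: float = 5.0,
) -> list[list[float]]:
    """Project Bellman-shifted atom distributions back onto a fixed support."""
    num_atoms = len(next_probs[0])
    delta_z = (v_max - v_min) / (num_atoms - 1)
    support = [v_min + j * delta_z for j in range(num_atoms)]
    projected = []
    for probs, r, done in zip(next_probs, rewards, dones):
        target = [0.0] * num_atoms
        # Terminal transitions do not bootstrap from the next state.
        discount = 0.0 if done else gamma**n_step
        for p, z in zip(probs, support):
            tz = min(max(r + discount * z, v_min), v_max)  # clip to the support range
            b = min(max((tz - v_min) / delta_z, 0.0), num_atoms - 1)  # fractional index
            lower, upper = math.floor(b), math.ceil(b)
            if lower == upper:
                target[lower] += p  # landed exactly on an atom
            else:
                # Split the mass between the neighbours, weighted by distance.
                target[lower] += p * (upper - b)
                target[upper] += p * (b - lower)
        projected.append(target)
    return projected


# Three atoms at -1, 0, 1; all next-state mass at 0, reward 0.5, not done.
example = project_categorical_target([[0.0, 1.0, 0.0]], [0.5], [False], gamma=0.99, v_min=-1.0, v_max=1.0)
```

The projected distribution still sums to one. That lets a cross-entropy loss against the critic's predicted atom probabilities serve as the critic objective.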

update_actor(batch)
Parameters:

batch (dict[str, Tensor])

Return type:

dict[str, float]
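
temp_initial_value and temp_target_entropy point to SAC-style automatic temperature tuning, which SAC implementations usually run next to the actor update. This page does not document whether update_actor also adjusts the temperature. It also does not say how a target is chosen when temp_target_entropy is None (temp_target_sigma may be involved). The hypothetical temperature_step below shows the standard SAC rule with an explicit target, using plain floats.

```python
import math


def temperature_step(
    log_alpha: float, log_probs: list[float], target_entropy: float, lr: float = 3e-4
) -> float:
    """One gradient-descent step on the SAC temperature loss (hypothetical helper).

    Loss J(log_alpha) = -log_alpha * mean(log_pi + target_entropy),
    so dJ/dlog_alpha = -mean(log_pi + target_entropy).
    """
    grad = -sum(lp + target_entropy for lp in log_probs) / len(log_probs)
    return log_alpha - lr * grad


# Start from the documented default temp_initial_value = 0.01, kept in log space.
log_alpha = math.log(0.01)
# Estimated entropy (-mean log_pi = -2.0) is below the target of -1.0, so alpha grows.
log_alpha = temperature_step(log_alpha, [2.0, 2.0], target_entropy=-1.0, lr=0.1)
```

Optimizing log_alpha rather than alpha keeps the temperature positive without an explicit constraint.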

soft_update_target()
Return type:

None
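
soft_update_target() takes no arguments, and the constructor exposes tau (default 0.01). In SAC-style learners, tau is usually the Polyak coefficient that lets target-network weights slowly track the online weights. The sketch below assumes that interpretation and applies the update to a plain list of floats. The helper soft_update is hypothetical.

```python
def soft_update(target: list[float], online: list[float], tau: float) -> None:
    """Polyak-average online parameters into target parameters in place."""
    for i, (t, o) in enumerate(zip(target, online)):
        # target <- (1 - tau) * target + tau * online
        target[i] = (1.0 - tau) * t + tau * o


target_params = [0.0, 1.0]
online_params = [1.0, 1.0]
soft_update(target_params, online_params, tau=0.01)  # tau = documented default
```

With PyTorch tensors, the same update is commonly written as target_param.lerp_(param, tau) inside torch.no_grad(), because lerp_ computes self + weight * (end - self).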

get_state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None