unilab.algos.torch.fast_sac.learner

FastSAC Learner — replicated from holosoma’s FastSAC implementation.

Network architecture:
- Actor: MLP with SiLU + LayerNorm, tanh-squashed Gaussian
- Critic: Distributional Q-Networks (C51 variant, num_atoms=101)
- Automatic entropy coefficient (alpha) learning

Hyperparameters aligned with holosoma FastSACConfig defaults.

Classes

DistributionalQNetwork – Single distributional Q-network (C51).

FastSACLearner – FastSAC learner with holosoma-aligned hyperparameters.

SACActor – Stochastic actor for SAC with a tanh-squashed Gaussian policy.

SACCritic – Ensemble of distributional Q-networks for SAC.

class unilab.algos.torch.fast_sac.learner.SACActor

Bases: Module

Stochastic actor for SAC with a tanh-squashed Gaussian policy.

Architecture: Linear→LN→SiLU → Linear→LN→SiLU → Linear→LN→SiLU → fc_mu + fc_logstd

Hidden dims: [hidden_dim, hidden_dim//2, hidden_dim//4]
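To make the layer layout concrete, here is a minimal NumPy sketch of the actor's forward pass. It assumes random weights, LayerNorm without a learnable gain/bias, and a CleanRL-style tanh rescaling of the log-std head into [log_std_min, log_std_max]; the real module is PyTorch and may bound log_std differently. The helper names init_actor and actor_forward are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm over the feature axis; the learnable gain and bias are omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def silu(x):
    # SiLU(x) = x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def init_actor(obs_dim, action_dim, hidden_dim=512, seed=0):
    rng = np.random.default_rng(seed)
    # Trunk widths follow [hidden_dim, hidden_dim // 2, hidden_dim // 4].
    dims = [obs_dim, hidden_dim, hidden_dim // 2, hidden_dim // 4]
    trunk = [(rng.normal(0.0, 0.05, (d_in, d_out)), np.zeros(d_out))
             for d_in, d_out in zip(dims[:-1], dims[1:])]
    heads = {name: (rng.normal(0.0, 0.05, (dims[-1], action_dim)), np.zeros(action_dim))
             for name in ("fc_mu", "fc_logstd")}
    return trunk, heads

def actor_forward(obs, trunk, heads, log_std_min=-5.0, log_std_max=0.0):
    h = obs
    for w, b in trunk:
        h = silu(layer_norm(h @ w + b))  # Linear -> LN -> SiLU
    mean = h @ heads["fc_mu"][0] + heads["fc_mu"][1]
    raw = h @ heads["fc_logstd"][0] + heads["fc_logstd"][1]
    # Assumption: tanh rescaling maps the raw head output into [log_std_min, log_std_max].
    log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (np.tanh(raw) + 1.0)
    return mean, log_std
```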

Attributes:
action_scale: torch.Tensor
action_bias: torch.Tensor

__init__(obs_dim, action_dim, hidden_dim=512, log_std_max=0.0, log_std_min=-5.0, use_tanh=True, use_layer_norm=True, device='cpu', action_scale=None, action_bias=None)

Construct the actor trunk and its fc_mu / fc_logstd output heads.

forward(obs)

Returns (action, mean, log_std).

Parameters:

obs (Tensor)

Return type:

Tuple[Tensor, Tensor, Tensor]

as_export_module()

Return a single-input/single-output wrapper suitable for torch.onnx.export.

Return type:

Module

get_actions_and_log_probs(obs)

Sample actions and compute log probabilities. Returns (action, log_prob, log_std).

Parameters:

obs (Tensor)

Return type:

Tuple[Tensor, Tensor, Tensor]
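The log probability of a tanh-squashed Gaussian needs a change-of-variables correction. A minimal NumPy sketch, assuming the reparameterized sample u = mean + exp(log_std) * noise and an affine map action = tanh(u) * action_scale + action_bias (how the real actor applies action_scale and action_bias is an assumption); the function name is hypothetical.

```python
import numpy as np

def squashed_gaussian_sample(mean, log_std, noise, action_scale=1.0, action_bias=0.0, eps=1e-6):
    """Return (action, log_prob) for a tanh-squashed diagonal Gaussian."""
    u = mean + np.exp(log_std) * noise  # reparameterized pre-squash sample
    y = np.tanh(u)
    action = y * action_scale + action_bias
    # log N(u; mean, std) per dimension, using (u - mean) / std == noise
    log_prob_u = -0.5 * noise**2 - log_std - 0.5 * np.log(2.0 * np.pi)
    # Change of variables: |da/du| = action_scale * (1 - tanh(u)^2); eps avoids log(0)
    log_prob = log_prob_u - np.log(action_scale * (1.0 - y**2) + eps)
    return action, log_prob.sum(axis=-1)
```

Summing the correction over action dimensions gives the per-sample log_prob that enters the SAC actor and temperature losses.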

explore(obs, dones=None, deterministic=False)

Get exploration actions.

Parameters:
  • obs (Tensor) – Batched observations.

  • dones (Tensor | None) – Unused for SAC; kept for API alignment with TD3 actor.

  • deterministic (bool) – Whether to return deterministic policy actions.

Return type:

Tensor
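A hypothetical sketch of the explore contract: with deterministic=True the squashed mean is returned, otherwise a fresh Gaussian sample is squashed, and dones is accepted but ignored. For brevity it takes mean and log_std directly instead of obs, and the affine action_scale/action_bias mapping is an assumption.

```python
import numpy as np

def explore(mean, log_std, dones=None, deterministic=False, rng=None,
            action_scale=1.0, action_bias=0.0):
    # `dones` mirrors the documented signature and is ignored, as for SAC.
    if deterministic:
        pre_tanh = mean  # mode of the Gaussian before squashing
    else:
        rng = np.random.default_rng() if rng is None else rng
        pre_tanh = mean + np.exp(log_std) * rng.standard_normal(mean.shape)
    return np.tanh(pre_tanh) * action_scale + action_bias
```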

class unilab.algos.torch.fast_sac.learner.DistributionalQNetwork

Bases: Module

Single distributional Q-network (C51).

Architecture: Linear→LN→SiLU → Linear→LN→SiLU → Linear→LN→SiLU → Linear(num_atoms)

Input: concat(obs, action)

__init__(obs_dim, action_dim, num_atoms=101, v_min=-20.0, v_max=20.0, hidden_dim=768, use_layer_norm=True, device='cpu')

Construct the Q-network MLP with a num_atoms-way output layer.

forward(obs, actions)

Return logits over the num_atoms return atoms for each (obs, actions) pair, with shape (batch, num_atoms).

Return type:

Tensor

projection(obs, actions, rewards, bootstrap, discount, q_support, device)

Categorical (C51) projection: shift the atoms of q_support by the Bellman backup and project the resulting target distribution back onto q_support.

Return type:

Tensor
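The categorical projection spreads each Bellman-shifted atom r + discount * bootstrap * z_i over its two nearest support atoms so that the target lives on the fixed q_support again. A NumPy sketch of the standard C51 projection, assuming it takes next-state atom probabilities directly (the documented method computes them from obs and actions) and that bootstrap is 0 on terminal transitions:

```python
import numpy as np

def categorical_projection(next_probs, rewards, bootstrap, discount, q_support):
    """next_probs: (batch, num_atoms); rewards, bootstrap: (batch,) arrays."""
    num_atoms = q_support.shape[0]
    v_min, v_max = q_support[0], q_support[-1]
    delta_z = (v_max - v_min) / (num_atoms - 1)

    # Bellman-shifted atom locations, clipped to the support range.
    target_z = rewards[:, None] + (bootstrap * discount)[:, None] * q_support[None, :]
    target_z = np.clip(target_z, v_min, v_max)

    # Fractional index of each shifted atom and its two neighbouring atoms.
    b = np.clip((target_z - v_min) / delta_z, 0, num_atoms - 1)
    lower = np.floor(b).astype(np.int64)
    upper = np.ceil(b).astype(np.int64)
    # If b lands exactly on an atom, lower == upper; nudge one index so the
    # interpolation weights below still put all mass on that atom.
    lower[(upper > 0) & (lower == upper)] -= 1
    upper[(lower < num_atoms - 1) & (lower == upper)] += 1

    proj = np.zeros_like(next_probs)
    rows = np.broadcast_to(np.arange(next_probs.shape[0])[:, None], lower.shape)
    np.add.at(proj, (rows, lower), next_probs * (upper - b))
    np.add.at(proj, (rows, upper), next_probs * (b - lower))
    return proj
```

Because each atom's mass is split linearly between its neighbours, the projection preserves the distribution's mean whenever no shifted atom is clipped at v_min or v_max.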

class unilab.algos.torch.fast_sac.learner.SACCritic

Bases: Module

Ensemble of distributional Q-networks for SAC.

Uses num_q_networks independent DistributionalQNetwork instances.

Attributes:
q_support: torch.Tensor

__init__(obs_dim, action_dim, num_atoms=101, v_min=-20.0, v_max=20.0, hidden_dim=768, use_layer_norm=True, num_q_networks=2, device='cpu')

Construct num_q_networks DistributionalQNetwork instances and the value support q_support (num_atoms atoms spanning [v_min, v_max]).

forward(obs, actions)

Returns stacked logits: (num_q_nets, batch, num_atoms).

Return type:

Tensor

projection(obs, actions, rewards, bootstrap, discount)

Apply the categorical projection for every Q-network in the ensemble; returns a tensor of shape (num_q_nets, batch, num_atoms).

Return type:

Tensor

get_value(probs)

Compute the expected Q-value from atom probabilities, i.e. the probability-weighted sum over q_support.

Parameters:

probs (Tensor)

Return type:

Tensor
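With the default v_min=-20.0, v_max=20.0 and num_atoms=101, an evenly spaced support has atom spacing (20 - (-20)) / 100 = 0.4. A small sketch of how logits become a scalar Q-value, assuming q_support is that linspace and probs come from a softmax over the atom axis (make_q_support and softmax are hypothetical helpers):

```python
import numpy as np

def make_q_support(v_min=-20.0, v_max=20.0, num_atoms=101):
    # Evenly spaced atoms z_0 = v_min, ..., z_{N-1} = v_max
    return np.linspace(v_min, v_max, num_atoms)

def softmax(logits, axis=-1):
    shifted = logits - logits.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def get_value(probs, q_support):
    # Expected return E[Z] = sum_i p_i * z_i over the trailing atom axis;
    # works for (batch, num_atoms) or stacked (num_q_nets, batch, num_atoms).
    return (probs * q_support).sum(axis=-1)
```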

class unilab.algos.torch.fast_sac.learner.FastSACLearner

Bases: object

FastSAC learner with holosoma-aligned hyperparameters.

Key hyperparameters (aligned with holosoma FastSACConfig):
- gamma=0.97, tau=0.125
- batch_size=8192, num_updates=8, policy_frequency=4
- alpha_init=0.001, target_entropy_ratio=0.0
- AdamW with betas=(0.9, 0.95), weight_decay=0.001
- Distributional critic (C51, num_atoms=101)

__init__(obs_dim, action_dim, critic_obs_dim, device='cpu', gamma=0.97, tau=0.125, actor_lr=0.0003, critic_lr=0.0003, alpha_lr=0.0003, alpha_init=0.001, target_entropy_ratio=0.0, actor_hidden_dim=512, critic_hidden_dim=768, num_atoms=101, v_min=-20.0, v_max=20.0, num_q_networks=2, use_layer_norm=True, use_tanh=True, log_std_max=0.0, log_std_min=-5.0, weight_decay=0.001, max_grad_norm=0.0, use_autotune=True, use_symmetry=False, use_amp=False, amp_dtype='auto', use_compile=False, symmetry_augmentation=None, world_size=1)
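When use_autotune=True, the entropy coefficient alpha is learned. The sketch below uses the standard SAC temperature loss on log_alpha, starting from alpha_init=0.001 with the default alpha_lr=0.0003; the target-entropy formula (-target_entropy_ratio * action_dim) and the plain SGD step are assumptions, not confirmed details of FastSACLearner.

```python
import math

def target_entropy(action_dim, target_entropy_ratio=0.0):
    # Assumption: the target scales -action_dim by target_entropy_ratio.
    return -target_entropy_ratio * action_dim

def alpha_sgd_step(log_alpha, mean_log_prob, target_ent, lr=3e-4):
    # SAC temperature loss J(log_alpha) = -log_alpha * (mean_log_prob + target_ent),
    # with mean_log_prob held constant, so dJ/dlog_alpha = -(mean_log_prob + target_ent).
    grad = -(mean_log_prob + target_ent)
    return log_alpha - lr * grad  # plain SGD; the learner presumably uses its own optimizer

log_alpha = math.log(0.001)  # alpha_init
```

The sign convention means alpha grows when the policy's entropy (-log_prob) falls below the target and shrinks otherwise.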
update_critic(batch)

One critic update step.

Parameters:

batch (Dict[str, Tensor])

Return type:

Dict[str, float]
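For a distributional critic, the usual regression target is the projected distribution, trained with cross-entropy against each Q-network's logits. A NumPy sketch of that loss, assuming logits and target probabilities shaped (num_q_nets, batch, num_atoms) as returned by SACCritic.forward and SACCritic.projection, and assuming a mean over the batch summed over networks (the learner's actual reduction may differ):

```python
import numpy as np

def log_softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)  # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def distributional_critic_loss(logits, target_probs):
    # logits, target_probs: (num_q_nets, batch, num_atoms); target_probs is treated as fixed.
    cross_entropy = -(target_probs * log_softmax(logits)).sum(axis=-1)  # (num_q_nets, batch)
    # Assumption: mean over the batch, summed over Q-networks.
    return cross_entropy.mean(axis=-1).sum()
```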

update_actor(batch)

One actor update step.

Parameters:

batch (Dict[str, Tensor])

Return type:

Dict[str, float]
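The defaults num_updates=8 and policy_frequency=4 suggest several critic steps per iteration with a less frequent actor step. One hypothetical way to drive update_critic, update_actor and soft_update_target, assuming the actor runs on every policy_frequency-th critic step and the target network is updated after every critic step (run_updates and sample_batch are illustrative names):

```python
from typing import Any, Callable, Dict, List

def run_updates(
    sample_batch: Callable[[], Dict[str, Any]],
    update_critic: Callable[[Dict[str, Any]], Dict[str, float]],
    update_actor: Callable[[Dict[str, Any]], Dict[str, float]],
    soft_update_target: Callable[[], None],
    num_updates: int = 8,
    policy_frequency: int = 4,
) -> List[Dict[str, float]]:
    logs = []
    for i in range(num_updates):
        batch = sample_batch()  # hypothetical replay-buffer sampler
        info = dict(update_critic(batch))
        # Assumption: the actor steps on every policy_frequency-th critic step.
        if (i + 1) % policy_frequency == 0:
            info.update(update_actor(batch))
        # Assumption: Polyak target update after every critic step.
        soft_update_target()
        logs.append(info)
    return logs
```

Under these assumptions, each call performs 8 critic updates and 2 actor updates.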

soft_update_target()

Polyak-average update of the target Q-network.

Return type:

None
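Polyak averaging moves each target parameter a fraction tau toward its online counterpart: target <- tau * online + (1 - tau) * target. A NumPy sketch with the default tau=0.125, using a plain dict of arrays in place of module parameters (soft_update is an illustrative name):

```python
import numpy as np

def soft_update(target, online, tau=0.125):
    # In-place Polyak averaging: target <- tau * online + (1 - tau) * target
    for name, param in online.items():
        target[name] *= 1.0 - tau
        target[name] += tau * param
```

In PyTorch, the same in-place update is commonly written inside torch.no_grad() as target_param.lerp_(online_param, tau).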

get_state_dict()

Collect the state of all components into a single dictionary (e.g. for checkpointing).

Return type:

Dict[str, Any]

load_state_dict(state_dict)

Restore all components from a dictionary produced by get_state_dict().

Parameters:

state_dict (Dict)

Return type:

None