unilab.algos.torch.flash_sac.layers

FlashSAC layers and lightweight normalization helpers.

Functions

safe_tanh_log_det_jacobian(x)

Stable log|det J_tanh(x)| term.

Classes

EnsembleCategoricalValue

EnsembleFlashSACBlock

EnsembleFlashSACEmbedder

EnsembleUnitBatchNorm

EnsembleUnitLinear

EnsembleUnitRMSNorm

FlashSACBlock

FlashSACEmbedder

NormalTanhPolicy

UnitBatchNorm

BatchNorm variant with normalized affine parameters.

UnitLinear

Linear layer with post-step weight normalization.

UnitRMSNorm

RMSNorm with unit-length scale vector.

unilab.algos.torch.flash_sac.layers.safe_tanh_log_det_jacobian(x)

Numerically stable log|det J_tanh(x)| term. The element-wise derivative of tanh is d tanh(x)/dx = 1 - tanh(x)^2.

Parameters:

x (Tensor)

Return type:

Tensor
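Here is a sketch of one common way to stabilize this term. Because 1 - tanh(x)^2 = 4 / (e^x + e^-x)^2, the log-derivative can be rewritten as 2 * (log 2 - x - softplus(-2x)). That form never computes tanh(x), so it stays finite where tanh saturates. The function name below is hypothetical. This page also does not say whether the library version sums over the last dimension.

```python
import math

import torch
import torch.nn.functional as F


def tanh_log_det_jacobian_sketch(x: torch.Tensor) -> torch.Tensor:
    """Element-wise log(1 - tanh(x)^2) without forming tanh(x) (hypothetical helper)."""
    # 1 - tanh(x)^2 = 4 / (e^x + e^-x)^2  =>  log = 2 * (log 2 - x - softplus(-2x))
    return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


x = torch.tensor([0.0, 1.0, 50.0])
naive = torch.log(1.0 - torch.tanh(x) ** 2)  # last entry is -inf in float32
stable = tanh_log_det_jacobian_sketch(x)     # all entries finite
print(naive, stable)
```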

class unilab.algos.torch.flash_sac.layers.UnitLinear

Bases: Module

Linear layer with post-step weight normalization.

Parameters:
  • input_dim (int)

  • output_dim (int)

__init__(input_dim, output_dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Parameters:
  • input_dim (int)

  • output_dim (int)

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

Tensor

normalize_parameters()

Re-apply the weight normalization in place. Call it after each optimizer step.

Return type:

None
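Below is a minimal sketch of the post-step pattern. It is not the library's implementation. It assumes, hypothetically, a bias-free layer whose normalize_parameters() projects each weight row back to unit L2 norm. The class name UnitLinearSketch is illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UnitLinearSketch(nn.Module):
    """Hypothetical stand-in for UnitLinear; the real layer may differ."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
        nn.init.orthogonal_(self.weight)
        self.normalize_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Bias-free linear map (assumption: no bias term).
        return F.linear(x, self.weight)

    @torch.no_grad()
    def normalize_parameters(self) -> None:
        # Assumption: project each output row back onto the unit L2 sphere.
        self.weight.copy_(F.normalize(self.weight, dim=-1))


torch.manual_seed(0)
layer = UnitLinearSketch(8, 4)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
loss = layer(torch.randn(16, 8)).pow(2).mean()
loss.backward()
opt.step()                    # the gradient step moves rows off the unit sphere
layer.normalize_parameters()  # this projects them back
print(layer.weight.norm(dim=-1))
```

Renormalizing outside the autograd graph keeps the forward pass a plain matrix multiply. The cost is one extra call per optimizer step.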

class unilab.algos.torch.flash_sac.layers.UnitBatchNorm

Bases: Module

BatchNorm variant with normalized affine parameters.

Parameters:
  • input_dim (int)

  • momentum (float)

  • eps (float)

running_mean: torch.Tensor

running_var: torch.Tensor

__init__(input_dim, momentum=0.01, eps=1e-05)

Parameters:
  • input_dim (int)

  • momentum (float)

  • eps (float)

forward(x, training)

Parameters:
  • x (Tensor)

  • training

Return type:

Tensor

normalize_parameters()

Re-normalize the affine parameters in place.

Return type:

None
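Below is a minimal sketch of the documented interface: input_dim, momentum, eps, the running_mean/running_var buffers, and the explicit training flag. This page does not define "normalized affine parameters" precisely. The sketch assumes, hypothetically, that normalize_parameters() rescales the scale vector to unit L2 norm and leaves the shift unconstrained. It also assumes the torch.nn.BatchNorm momentum convention. The class name is illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UnitBatchNormSketch(nn.Module):
    """Hypothetical stand-in for UnitBatchNorm; details are assumptions."""

    running_mean: torch.Tensor
    running_var: torch.Tensor

    def __init__(self, input_dim: int, momentum: float = 0.01, eps: float = 1e-5) -> None:
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(input_dim))
        self.shift = nn.Parameter(torch.zeros(input_dim))
        self.register_buffer("running_mean", torch.zeros(input_dim))
        self.register_buffer("running_var", torch.ones(input_dim))
        self.normalize_parameters()

    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
        if training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            with torch.no_grad():
                # EMA: new = (1 - momentum) * old + momentum * batch statistic
                self.running_mean.lerp_(mean, self.momentum)
                self.running_var.lerp_(var, self.momentum)
        else:
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) * torch.rsqrt(var + self.eps)
        return x_hat * self.scale + self.shift

    @torch.no_grad()
    def normalize_parameters(self) -> None:
        # Assumption: keep the scale vector on the unit L2 sphere.
        self.scale.copy_(F.normalize(self.scale, dim=0))


torch.manual_seed(0)
bn = UnitBatchNormSketch(4)
x = torch.randn(32, 4) * 3.0 + 1.0
y_train = bn(x, training=True)   # batch statistics; running stats updated
before = bn.running_mean.clone()
y_eval = bn(x, training=False)   # running statistics; buffers untouched
```

The documented forward(x, training) signature takes training as an argument. That lets the caller choose batch or running statistics on each call instead of relying on module.train() / module.eval().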

class unilab.algos.torch.flash_sac.layers.UnitRMSNorm

Bases: Module

RMSNorm with unit-length scale vector.

Parameters:
  • input_dim (int)

  • eps (float)

__init__(input_dim, eps=1e-06)

Parameters:
  • input_dim (int)

  • eps (float)

forward(x)

Parameters:

x (Tensor)

Return type:

Tensor

normalize_parameters()

Rescale the scale vector back to unit length in place.

Return type:

None
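Below is a minimal sketch. It assumes, hypothetically, that the layer divides by the root-mean-square over the last dimension and then multiplies by a learnable scale vector. It further assumes that normalize_parameters() projects that vector back to unit length. The class name is illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UnitRMSNormSketch(nn.Module):
    """Hypothetical stand-in for UnitRMSNorm; details are assumptions."""

    def __init__(self, input_dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(input_dim))
        self.normalize_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Divide by the root-mean-square over the feature dimension.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.scale

    @torch.no_grad()
    def normalize_parameters(self) -> None:
        # Keep the scale vector at unit L2 length.
        self.scale.copy_(F.normalize(self.scale, dim=0))


torch.manual_seed(0)
norm = UnitRMSNormSketch(4)
x = torch.randn(3, 4)
y = norm(x)  # (almost) invariant to rescaling x, since eps is tiny
```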

class unilab.algos.torch.flash_sac.layers.FlashSACEmbedder

Bases: Module

Parameters:
  • input_dim (int)

  • hidden_dim (int)

__init__(input_dim, hidden_dim)

Parameters:
  • input_dim (int)

  • hidden_dim (int)

forward(x, training)

Parameters:
  • x (Tensor)

  • training

Return type:

Tensor

class unilab.algos.torch.flash_sac.layers.FlashSACBlock

Bases: Module

Parameters:
  • hidden_dim (int)

  • expansion (int)

__init__(hidden_dim, expansion=4)

Parameters:
  • hidden_dim (int)

  • expansion (int)

forward(x, training)

Parameters:
  • x (Tensor)

  • training

Return type:

Tensor
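The expansion argument suggests an inverted-bottleneck MLP (hidden_dim → expansion * hidden_dim → hidden_dim). The sketch below assumes that shape. It also assumes a pre-normalization residual connection, a ReLU, and standard torch.nn.BatchNorm1d parameters driven by the explicit training flag. This page confirms none of these choices, and the class name is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlashSACBlockSketch(nn.Module):
    """Hypothetical residual MLP block; the real FlashSACBlock may differ."""

    def __init__(self, hidden_dim: int, expansion: int = 4) -> None:
        super().__init__()
        self.norm = nn.BatchNorm1d(hidden_dim, momentum=0.01)
        self.up = nn.Linear(hidden_dim, expansion * hidden_dim)
        self.down = nn.Linear(expansion * hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
        # Pre-norm: the explicit flag chooses batch vs. running statistics.
        h = F.batch_norm(
            x,
            self.norm.running_mean,
            self.norm.running_var,
            self.norm.weight,
            self.norm.bias,
            training=training,
            momentum=self.norm.momentum,
            eps=self.norm.eps,
        )
        # Expand to expansion * hidden_dim, apply the nonlinearity, project back.
        h = self.down(F.relu(self.up(h)))
        return x + h  # residual connection keeps the width at hidden_dim


torch.manual_seed(0)
block = FlashSACBlockSketch(hidden_dim=16, expansion=4)
x = torch.randn(8, 16)
y = block(x, training=True)
print(y.shape)  # torch.Size([8, 16])
```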

class unilab.algos.torch.flash_sac.layers.NormalTanhPolicy

Bases: Module

Parameters:
  • hidden_dim (int)

  • action_dim

  • log_std_min (float)

  • log_std_max (float)

__init__(hidden_dim, action_dim, log_std_min=-10.0, log_std_max=2.0)

Parameters:
  • hidden_dim (int)

  • action_dim

  • log_std_min (float)

  • log_std_max (float)

get_mean_and_std(x)

Parameters:

x (Tensor)

Return type:

tuple[Tensor, Tensor]

forward(x)

Parameters:

x (Tensor)

Return type:

tuple[Tensor, dict[str, Tensor]]
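The sketch below shows one standard way a tanh-squashed Gaussian policy with these constructor arguments can work. The log-std is clamped to [log_std_min, log_std_max], and an action is drawn with the reparameterization trick. The log-probability is then corrected with the tanh log-determinant. The separate linear heads, the hard clamp, and the info-dict keys (log_prob, mean) are assumptions. The class name is hypothetical.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class NormalTanhPolicySketch(nn.Module):
    """Hypothetical stand-in for NormalTanhPolicy; details are assumptions."""

    def __init__(self, hidden_dim: int, action_dim: int,
                 log_std_min: float = -10.0, log_std_max: float = 2.0) -> None:
        super().__init__()
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

    def get_mean_and_std(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        mean = self.mean_head(x)
        log_std = self.log_std_head(x).clamp(self.log_std_min, self.log_std_max)
        return mean, log_std.exp()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        mean, std = self.get_mean_and_std(x)
        dist = torch.distributions.Normal(mean, std)
        u = dist.rsample()  # reparameterized pre-tanh sample
        action = torch.tanh(u)
        # Change of variables: subtract log|d tanh(u)/du| per action dimension.
        log_det = 2.0 * (math.log(2.0) - u - F.softplus(-2.0 * u))
        log_prob = (dist.log_prob(u) - log_det).sum(dim=-1)
        return action, {"log_prob": log_prob, "mean": torch.tanh(mean)}


torch.manual_seed(0)
policy = NormalTanhPolicySketch(hidden_dim=32, action_dim=6)
action, info = policy(torch.randn(5, 32))
```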

class unilab.algos.torch.flash_sac.layers.EnsembleUnitLinear

Bases: Module

Parameters:
  • num_ensemble (int)

  • input_dim (int)

  • output_dim (int)

__init__(num_ensemble, input_dim, output_dim)

Parameters:
  • num_ensemble (int)

  • input_dim (int)

  • output_dim (int)

forward(x)

Parameters:

x (Tensor)

Return type:

Tensor

normalize_parameters()
Return type:

None
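Below is a minimal sketch of a batched ensemble layer. It assumes, hypothetically, that weights are stacked along a leading num_ensemble axis and that inputs are shaped (num_ensemble, batch, input_dim). It also assumes that normalize_parameters() rescales every output row of every member to unit L2 norm. The class name is illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EnsembleUnitLinearSketch(nn.Module):
    """Hypothetical stand-in for EnsembleUnitLinear; details are assumptions."""

    def __init__(self, num_ensemble: int, input_dim: int, output_dim: int) -> None:
        super().__init__()
        # One weight matrix per ensemble member, stacked on a leading axis.
        self.weight = nn.Parameter(torch.randn(num_ensemble, output_dim, input_dim))
        self.normalize_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_ensemble, batch, input_dim) -> (num_ensemble, batch, output_dim)
        return torch.einsum("ebi,eoi->ebo", x, self.weight)

    @torch.no_grad()
    def normalize_parameters(self) -> None:
        # Unit-normalize each output row of every member's weight matrix.
        self.weight.copy_(F.normalize(self.weight, dim=-1))


torch.manual_seed(0)
layer = EnsembleUnitLinearSketch(num_ensemble=3, input_dim=8, output_dim=4)
x = torch.randn(3, 5, 8)
y = layer(x)
print(y.shape)  # torch.Size([3, 5, 4])
```

A single einsum evaluates all members in one batched call. This avoids looping over num_ensemble separate nn.Linear modules.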

class unilab.algos.torch.flash_sac.layers.EnsembleUnitBatchNorm

Bases: Module

Parameters:
  • num_ensemble (int)

  • input_dim (int)

  • momentum (float)

  • eps (float)

running_mean: torch.Tensor

running_var: torch.Tensor

__init__(num_ensemble, input_dim, momentum=0.01, eps=1e-05)

Parameters:
  • num_ensemble (int)

  • input_dim (int)

  • momentum (float)

  • eps (float)

forward(x, training)

Parameters:
  • x (Tensor)

  • training

Return type:

Tensor

normalize_parameters()
Return type:

None

class unilab.algos.torch.flash_sac.layers.EnsembleUnitRMSNorm

Bases: Module

Parameters:
  • num_ensemble (int)

  • input_dim (int)

  • eps (float)

__init__(num_ensemble, input_dim, eps=1e-06)

Parameters:
  • num_ensemble (int)

  • input_dim (int)

  • eps (float)

forward(x)

Parameters:

x (Tensor)

Return type:

Tensor

normalize_parameters()
Return type:

None

class unilab.algos.torch.flash_sac.layers.EnsembleFlashSACEmbedder

Bases: Module

Parameters:
  • num_ensemble (int)

  • input_dim (int)

  • hidden_dim (int)

__init__(num_ensemble, input_dim, hidden_dim)

Parameters:
  • num_ensemble (int)

  • input_dim (int)

  • hidden_dim (int)

forward(x, training)

Parameters:
  • x (Tensor)

  • training

Return type:

Tensor

class unilab.algos.torch.flash_sac.layers.EnsembleFlashSACBlock

Bases: Module

Parameters:
  • num_ensemble (int)

  • hidden_dim (int)

  • expansion (int)

__init__(num_ensemble, hidden_dim, expansion=4)

Parameters:
  • num_ensemble (int)

  • hidden_dim (int)

  • expansion (int)

forward(x, training)

Parameters:
  • x (Tensor)

  • training

Return type:

Tensor

class unilab.algos.torch.flash_sac.layers.EnsembleCategoricalValue

Bases: Module

Parameters:
  • num_ensemble (int)

  • hidden_dim (int)

  • num_bins

  • min_v

  • max_v

support: torch.Tensor

__init__(num_ensemble, hidden_dim, num_bins, min_v, max_v)

Parameters:
  • num_ensemble (int)

  • hidden_dim (int)

  • num_bins

  • min_v

  • max_v

forward(x)

Parameters:

x (Tensor)

Return type:

tuple[Tensor, dict[str, Tensor]]
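Below is a minimal sketch of a categorical (distributional) value head. It assumes, hypothetically, that support holds num_bins evenly spaced bin centres from min_v to max_v. It also assumes the scalar value is the probability-weighted mean of the support. The per-member linear head, the input layout (num_ensemble, batch, hidden_dim), and the info-dict keys are further assumptions. The class name is illustrative.

```python
import torch
import torch.nn as nn


class EnsembleCategoricalValueSketch(nn.Module):
    """Hypothetical stand-in for EnsembleCategoricalValue; details are assumptions."""

    support: torch.Tensor

    def __init__(self, num_ensemble: int, hidden_dim: int, num_bins: int,
                 min_v: float, max_v: float) -> None:
        super().__init__()
        # Per-member linear heads mapping features to bin logits.
        self.weight = nn.Parameter(
            torch.randn(num_ensemble, hidden_dim, num_bins) * hidden_dim ** -0.5
        )
        self.bias = nn.Parameter(torch.zeros(num_ensemble, 1, num_bins))
        # Evenly spaced bin centres between min_v and max_v.
        self.register_buffer("support", torch.linspace(min_v, max_v, num_bins))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        # x: (num_ensemble, batch, hidden_dim) -> logits: (num_ensemble, batch, num_bins)
        logits = torch.bmm(x, self.weight) + self.bias
        probs = logits.softmax(dim=-1)
        value = (probs * self.support).sum(dim=-1)  # expected value, (num_ensemble, batch)
        return value, {"logits": logits, "probs": probs}


torch.manual_seed(0)
head = EnsembleCategoricalValueSketch(num_ensemble=2, hidden_dim=16, num_bins=51,
                                      min_v=-10.0, max_v=10.0)
value, info = head(torch.randn(2, 4, 16))
print(value.shape)  # torch.Size([2, 4])
```

The expected value is a convex combination of bin centres, so it always lies in [min_v, max_v].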