unilab.algos.torch.common
- class unilab.algos.torch.common.EmpiricalNormalization
Bases: torch.nn.Module
Normalize mean and variance of observations using running statistics.
- count: torch.Tensor
- __init__(shape, device, eps=0.01)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- property mean
- property std
- forward(x, center=True, update=True)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
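The running-statistics idea behind EmpiricalNormalization can be illustrated with a minimal sketch. It assumes the batch dimension is dim 0, that batch statistics are merged into the running mean and variance with the parallel (Chan et al.) form of Welford's update, that eps is added to the standard deviation, and that center=False means "scale only". The class name RunningNormSketch and its internals are hypothetical and are not the library's implementation.

```python
import torch
from torch import nn


class RunningNormSketch(nn.Module):
    """Hypothetical sketch of running mean/variance observation normalization.

    Assumptions (not taken from unilab): batch is dim 0, eps is added to std,
    and center=False divides by std without subtracting the mean.
    """

    def __init__(self, shape, device="cpu", eps=0.01):
        super().__init__()
        self.eps = eps
        # Buffers move with .to(device) and are saved in state_dict, but get no gradients.
        self.register_buffer("_mean", torch.zeros(shape, device=device))
        self.register_buffer("_var", torch.ones(shape, device=device))
        self.register_buffer("count", torch.zeros((), device=device))

    @property
    def mean(self):
        return self._mean

    @property
    def std(self):
        return torch.sqrt(self._var)

    @torch.no_grad()
    def update(self, x):
        # Merge this batch's mean/variance into the running statistics.
        batch_count = x.shape[0]
        batch_mean = x.mean(dim=0)
        batch_var = x.var(dim=0, unbiased=False)
        total = self.count + batch_count
        delta = batch_mean - self._mean
        new_mean = self._mean + delta * batch_count / total
        # Combined sum of squared deviations (M2) of old and new samples.
        m2 = (
            self._var * self.count
            + batch_var * batch_count
            + delta.pow(2) * self.count * batch_count / total
        )
        self._mean.copy_(new_mean)
        self._var.copy_(m2 / total)
        self.count.copy_(total)

    def forward(self, x, center=True, update=True):
        if update:
            self.update(x)
        if center:
            return (x - self._mean) / (self.std + self.eps)
        return x / (self.std + self.eps)
```

Calling the module on successive batches keeps the statistics up to date; passing update=False (for example at evaluation time) normalizes with the frozen statistics instead.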
- class unilab.algos.torch.common.DistributionalQNetwork
Bases: torch.nn.Module
Single distributional Q-network (C51 variant).
Architecture: Linear→ReLU → Linear→ReLU → Linear→ReLU → Linear. Outputs num_atoms logits over the value distribution.
- __init__(obs_dim, n_act, num_atoms, v_min, v_max, hidden_dim, device=None)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(obs, actions)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
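The documented four-layer architecture can be written down as a minimal sketch. It assumes that obs and actions are concatenated along the last dimension and that every hidden layer has width hidden_dim; neither detail is stated above. DistributionalQSketch is a hypothetical name, and v_min/v_max are omitted because they only define the atom support, not the network's layers.

```python
import torch
from torch import nn


class DistributionalQSketch(nn.Module):
    """Hypothetical sketch: Linear→ReLU ×3 → Linear producing num_atoms logits.

    Assumes obs and actions are concatenated and all hidden widths equal hidden_dim.
    """

    def __init__(self, obs_dim, n_act, num_atoms, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + n_act, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, num_atoms),  # one logit per value atom, no activation
        )

    def forward(self, obs, actions):
        # (batch, obs_dim) + (batch, n_act) -> (batch, num_atoms) unnormalized logits
        return self.net(torch.cat([obs, actions], dim=-1))
```

The final layer is left linear so that a softmax over its outputs can later turn the logits into a probability distribution over the atoms.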
- class unilab.algos.torch.common.Critic
Bases: torch.nn.Module
Twin distributional Q-networks for off-policy RL (SAC/TD3).
- q_support: torch.Tensor
- __init__(obs_dim, n_act, num_atoms, v_min, v_max, hidden_dim, device=None)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(obs, actions)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
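How a twin distributional critic with a q_support tensor is typically used can be sketched as follows. This assumes, without confirmation from the page, that q_support holds num_atoms evenly spaced atoms in [v_min, v_max], that forward returns one set of logits per Q-network, and that a scalar Q-value is the expectation of the atoms under softmax(logits). Taking the minimum of the two expectations is the clipped double-Q choice common in TD3/SAC, used here for illustration only. TwinCriticSketch, make_q and expected_q are hypothetical names.

```python
import torch
from torch import nn


def make_q(obs_dim, n_act, num_atoms, hidden_dim):
    # Hypothetical helper: same Linear→ReLU ×3 → Linear stack as a single Q-network.
    return nn.Sequential(
        nn.Linear(obs_dim + n_act, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        nn.Linear(hidden_dim, num_atoms),
    )


class TwinCriticSketch(nn.Module):
    def __init__(self, obs_dim, n_act, num_atoms, v_min, v_max, hidden_dim):
        super().__init__()
        self.q1 = make_q(obs_dim, n_act, num_atoms, hidden_dim)
        self.q2 = make_q(obs_dim, n_act, num_atoms, hidden_dim)
        # Assumed: fixed atom locations z_i evenly spaced in [v_min, v_max].
        self.register_buffer("q_support", torch.linspace(v_min, v_max, num_atoms))

    def forward(self, obs, actions):
        x = torch.cat([obs, actions], dim=-1)
        return self.q1(x), self.q2(x)  # two independent sets of atom logits

    def expected_q(self, logits):
        # E[Z] = sum_i softmax(logits)_i * z_i, reduced over the atom dimension.
        probs = torch.softmax(logits, dim=-1)
        return (probs * self.q_support).sum(dim=-1)


critic = TwinCriticSketch(obs_dim=17, n_act=6, num_atoms=51, v_min=-10.0, v_max=10.0, hidden_dim=64)
logits1, logits2 = critic(torch.randn(4, 17), torch.randn(4, 6))
# Pessimistic estimate: element-wise minimum of the two expected values.
q_value = torch.minimum(critic.expected_q(logits1), critic.expected_q(logits2))
```

Keeping q_support as a registered buffer means it follows the module across devices and is saved with the weights without being trained.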
- unilab.algos.torch.common.get_env_dims(env_name, sim_backend='mujoco', env_cfg_override=None)
Get (actor_obs_dim, action_dim, critic_obs_dim) from the environment.
- unilab.algos.torch.common.check_nan_loss(loss, default_metrics)
Check if loss contains NaN or Inf values.
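The core NaN/Inf test can be expressed with torch.isfinite, which is False for NaN, +Inf and -Inf. The sketch below is a hypothetical helper named loss_is_bad. It does not reproduce check_nan_loss, whose return value and use of default_metrics are not documented here. It only shows the check and one common reaction, skipping the optimizer step.

```python
import torch


def loss_is_bad(loss: torch.Tensor) -> bool:
    # Hypothetical helper: True if any element of the loss is NaN, +Inf or -Inf.
    return not bool(torch.isfinite(loss).all())


params = [torch.nn.Parameter(torch.ones(3))]
opt = torch.optim.SGD(params, lr=0.1)
loss = (params[0] * float("nan")).sum()  # deliberately produce a NaN loss

if loss_is_bad(loss):
    opt.zero_grad()  # skip this update instead of writing NaNs into the weights
else:
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Skipping the step leaves the parameters untouched, so a single bad batch does not poison the model.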
- unilab.algos.torch.common.clip_gradients(parameters, max_norm=10.0)
Clip gradients by global norm.
- Parameters:
parameters – Model parameters
max_norm (float) – Maximum gradient norm
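Clipping by global norm treats all gradients as one long vector. If its L2 norm exceeds max_norm, every gradient is scaled by the same factor so that the combined norm becomes max_norm. The sketch below uses PyTorch's own torch.nn.utils.clip_grad_norm_. Whether clip_gradients wraps that exact call is an assumption; the toy model and loss are illustrative only.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
# Deliberately large loss so the raw gradient norm exceeds max_norm.
loss = (100.0 * model(torch.randn(8, 4))).pow(2).sum()
loss.backward()

# Rescales all gradients jointly so their combined L2 norm is at most max_norm;
# the return value is the global norm measured before clipping.
pre_clip_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

# Recompute the global norm: sqrt of the sum of squared entries over all grads.
post_clip_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(f"before: {pre_clip_norm.item():.1f}, after: {post_clip_norm.item():.4f}")
```

Because the scale factor is shared, the direction of the overall update is preserved; only its length is capped.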
- unilab.algos.torch.common.safe_tensor(tensor, nan_value=0.0, clamp_range=(-10.0, 10.0))
Make tensor numerically safe by clamping and replacing NaN values.
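A minimal sketch of the described behavior, built from torch.nan_to_num and Tensor.clamp, is shown below. The mapping of +Inf and -Inf to the clamp bounds is an assumption, since only NaN replacement and clamping are documented. The name safe_tensor_sketch is hypothetical, and if nan_value lies outside clamp_range this sketch clamps it as well.

```python
import torch


def safe_tensor_sketch(tensor, nan_value=0.0, clamp_range=(-10.0, 10.0)):
    # Hypothetical: replace NaN with nan_value and +/-Inf with the clamp bounds,
    # then clamp every element into [low, high].
    low, high = clamp_range
    cleaned = torch.nan_to_num(tensor, nan=nan_value, posinf=high, neginf=low)
    return cleaned.clamp(min=low, max=high)


x = torch.tensor([float("nan"), float("inf"), float("-inf"), 3.0, 42.0])
print(safe_tensor_sketch(x).tolist())  # [0.0, 10.0, -10.0, 3.0, 10.0]
```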
- unilab.algos.torch.common.ensure_registries(packages=None, *, optional_packages=None, fail_on_error=True)
Import env registry bootstrap modules.
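Importing a module purely for its registration side effects is commonly done with importlib.import_module. The sketch below assumes that required packages raise on import failure when fail_on_error is True, and that optional packages are skipped silently. It also assumes an empty default package list, because the real defaults are not documented here. ensure_registries_sketch is a hypothetical name and does not reproduce the library's function.

```python
import importlib
import logging

logger = logging.getLogger(__name__)


def ensure_registries_sketch(packages=None, *, optional_packages=None, fail_on_error=True):
    """Hypothetical: import modules so their top-level code registers environments."""
    loaded = []
    for name in packages or ():
        try:
            importlib.import_module(name)
            loaded.append(name)
        except ImportError:
            if fail_on_error:
                raise
            logger.warning("could not import registry module %s", name)
    for name in optional_packages or ():
        try:
            importlib.import_module(name)
            loaded.append(name)
        except ImportError:
            # Optional registries are allowed to be absent.
            logger.debug("optional registry module %s not available", name)
    return loaded
```

Importing the same module twice is cheap, because Python caches it in sys.modules, so calling such a helper at the start of every entry point is safe.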
- unilab.algos.torch.common.build_actor(algo_type, obs_dim, action_dim, actor_hidden_dim, use_layer_norm, device, num_envs=1, actor_num_blocks=2, actor_noise_zeta_mu=2.0, actor_noise_zeta_max=16, priv_info_dim=None, priv_info_embed_dim=9, priv_mlp_hidden_dims=(256, 128, 9), **kwargs)
Build the correct actor model based on algorithm type.
Modules
- Actor factory helpers for torch off-policy algorithms.
- Simplified ANE backend using deterministic inference.
- ANE (Apple Neural Engine) inference wrapper for Fast SAC.
- Base collector class for shared functionality.
- Neural network architectures for RL algorithms.
- Observation normalization for RL training.
- Numerical stability utilities for RL training.