Source code for unilab.algos.torch.common.actor_factory

"""Actor factory helpers for torch off-policy algorithms."""

from __future__ import annotations


def build_actor(
    algo_type: str,
    obs_dim,
    action_dim,
    actor_hidden_dim,
    use_layer_norm,
    device,
    num_envs=1,
    actor_num_blocks: int = 2,
    actor_noise_zeta_mu: float = 2.0,
    actor_noise_zeta_max: int = 16,
    priv_info_dim: int | None = None,
    priv_info_embed_dim: int = 9,
    priv_mlp_hidden_dims: tuple[int, ...] | list[int] = (256, 128, 9),
    **kwargs,
):
    """Build the correct actor model based on algorithm type.

    Each supported ``algo_type`` maps to one actor class, which is imported
    inside its own branch and constructed from the subset of arguments that
    class accepts.

    Args:
        algo_type: Selects the actor class. Supported values are ``"sac"``,
            ``"hora_sac"``, ``"td3"`` and ``"flashsac"``.
        obs_dim: Forwarded as ``obs_dim`` (``"sac"``, ``"hora_sac"``,
            ``"td3"``) or as ``input_dim`` (``"flashsac"``).
        action_dim: Forwarded as ``action_dim``, or as ``n_act`` for
            ``"td3"``.
        actor_hidden_dim: Forwarded as ``hidden_dim`` to every actor.
        use_layer_norm: Forwarded to the ``"sac"`` and ``"hora_sac"`` actors
            only; the other branches do not use it.
        device: Forwarded unchanged to every actor.
        num_envs: Used by the ``"td3"`` actor only.
        actor_num_blocks: Forwarded as ``num_blocks`` (``"flashsac"`` only).
        actor_noise_zeta_mu: Forwarded as ``noise_zeta_mu`` (``"flashsac"``
            only).
        actor_noise_zeta_max: Forwarded as ``noise_zeta_max`` (``"flashsac"``
            only).
        priv_info_dim: Required when ``algo_type == "hora_sac"``; converted
            with ``int()`` before being forwarded. Unused otherwise.
        priv_info_embed_dim: Used by the ``"hora_sac"`` actor only.
        priv_mlp_hidden_dims: Used by the ``"hora_sac"`` actor only;
            converted to a ``tuple`` before being forwarded.
        **kwargs: Only the ``"td3"`` branch reads from this mapping:
            ``init_scale`` (default ``0.01``), ``log_std_min`` (default
            ``-1.6``) and ``log_std_max`` (default ``-0.22``). Any other key,
            and all keys for other algorithm types, are ignored.

    Returns:
        The constructed actor instance for the requested algorithm.

    Raises:
        ValueError: If ``algo_type`` is ``"hora_sac"`` and ``priv_info_dim``
            is ``None``, or if ``algo_type`` is not a supported value.
    """
    # Each actor class is imported inside its own branch, so only the module
    # for the requested algorithm is imported by a given call.
    if algo_type == "sac":
        from unilab.algos.torch.fast_sac.learner import SACActor

        return SACActor(
            obs_dim=obs_dim,
            action_dim=action_dim,
            hidden_dim=actor_hidden_dim,
            use_layer_norm=use_layer_norm,
            device=device,
        )

    if algo_type == "hora_sac":
        # Validate before importing so a misconfigured call fails fast with a
        # clear message instead of an error from inside the actor constructor.
        if priv_info_dim is None:
            raise ValueError("build_actor(algo_type='hora_sac') requires priv_info_dim.")
        from unilab.algos.torch.hora.sac_models import HoraSACActor

        return HoraSACActor(
            obs_dim=obs_dim,
            priv_info_dim=int(priv_info_dim),
            action_dim=action_dim,
            hidden_dim=actor_hidden_dim,
            priv_info_embed_dim=priv_info_embed_dim,
            priv_mlp_hidden_dims=tuple(priv_mlp_hidden_dims),
            use_layer_norm=use_layer_norm,
            device=device,
        )

    if algo_type == "td3":
        from unilab.algos.torch.fast_td3.learner import TD3Actor

        return TD3Actor(
            obs_dim=obs_dim,
            n_act=action_dim,
            num_envs=num_envs,
            hidden_dim=actor_hidden_dim,
            # TD3-specific options arrive through **kwargs with these defaults.
            init_scale=kwargs.get("init_scale", 0.01),
            log_std_min=kwargs.get("log_std_min", -1.6),
            log_std_max=kwargs.get("log_std_max", -0.22),
            device=device,
        )

    if algo_type == "flashsac":
        from unilab.algos.torch.flash_sac.network import FlashSACActor

        return FlashSACActor(
            num_blocks=actor_num_blocks,
            input_dim=obs_dim,
            hidden_dim=actor_hidden_dim,
            action_dim=action_dim,
            noise_zeta_mu=actor_noise_zeta_mu,
            noise_zeta_max=actor_noise_zeta_max,
            device=device,
        )

    raise ValueError(f"Unknown algo_type: {algo_type}")