unilab.algos.torch.hora.sac_learner.HoraSACLearner

class unilab.algos.torch.hora.sac_learner.HoraSACLearner

Bases: FastSACLearner

FastSAC learner variant whose actor takes HORA privileged information (a vector of size priv_info_dim) as an additional input alongside the observation.

Methods

__init__(*, obs_dim, critic_obs_dim, ...[, ...])

get_state_dict()

Collect the state of all components into a single dictionary.

load_state_dict(state_dict)

Restore all components from a dictionary returned by get_state_dict().

soft_update_target()

Polyak-average update of the target Q-network.

update_actor(batch)

Perform one actor update step on a batch and return scalar metrics.

update_critic(batch)

Perform one critic update step on a batch and return scalar metrics.

__init__(*, obs_dim, critic_obs_dim, priv_info_dim, action_dim, device='cpu', actor_hidden_dim=512, priv_info_embed_dim=9, priv_mlp_hidden_dims=(256, 128, 9), log_std_max=0.0, log_std_min=-5.0, use_tanh=True, use_layer_norm=True, actor_lr=0.0003, weight_decay=0.001, use_symmetry=False, symmetry_augmentation=None, **kwargs)

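The constructor arguments suggest an actor that encodes the privileged-information vector with a small MLP and feeds the resulting embedding to the policy together with the observation. The following is a minimal plain-PyTorch sketch of that idea, not unilab's implementation: the class name PrivInfoActorSketch, concatenation as the fusion step, ReLU activations, and the two-layer trunk are all hypothetical. Only the argument names and default values come from the __init__ signature. It assumes the last entry of priv_mlp_hidden_dims equals priv_info_embed_dim (both default to 9) and always applies tanh squashing (the use_tanh=True case).

```python
import torch
import torch.nn as nn


class PrivInfoActorSketch(nn.Module):
    """Hypothetical actor: embeds privileged info and concatenates it with obs."""

    def __init__(self, obs_dim, priv_info_dim, action_dim, actor_hidden_dim=512,
                 priv_mlp_hidden_dims=(256, 128, 9), log_std_min=-5.0,
                 log_std_max=0.0, use_layer_norm=True):
        super().__init__()
        # Privileged-info encoder: priv_info_dim -> 256 -> 128 -> 9 with the defaults.
        layers, in_dim = [], priv_info_dim
        for i, width in enumerate(priv_mlp_hidden_dims):
            layers.append(nn.Linear(in_dim, width))
            if i < len(priv_mlp_hidden_dims) - 1:  # no activation on the embedding
                layers.append(nn.ReLU())
            in_dim = width
        self.priv_mlp = nn.Sequential(*layers)
        embed_dim = priv_mlp_hidden_dims[-1]  # assumed equal to priv_info_embed_dim
        norm = nn.LayerNorm if use_layer_norm else nn.Identity
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim + embed_dim, actor_hidden_dim), norm(actor_hidden_dim), nn.ReLU(),
            nn.Linear(actor_hidden_dim, actor_hidden_dim), norm(actor_hidden_dim), nn.ReLU(),
        )
        self.mean_head = nn.Linear(actor_hidden_dim, action_dim)
        self.log_std_head = nn.Linear(actor_hidden_dim, action_dim)
        self.log_std_min, self.log_std_max = log_std_min, log_std_max

    def forward(self, obs, priv_info):
        embedding = self.priv_mlp(priv_info)
        h = self.trunk(torch.cat([obs, embedding], dim=-1))
        # Clamp log-std into [log_std_min, log_std_max] for numerical stability.
        log_std = self.log_std_head(h).clamp(self.log_std_min, self.log_std_max)
        return self.mean_head(h), log_std

    def sample(self, obs, priv_info):
        mean, log_std = self(obs, priv_info)
        dist = torch.distributions.Normal(mean, log_std.exp())
        u = dist.rsample()        # reparameterized sample, keeps gradients
        action = torch.tanh(u)    # squash into [-1, 1]
        # Change-of-variables correction for the tanh squashing.
        log_prob = dist.log_prob(u) - torch.log(1.0 - action.pow(2) + 1e-6)
        return action, log_prob.sum(dim=-1)


actor = PrivInfoActorSketch(obs_dim=32, priv_info_dim=8, action_dim=4)
action, log_prob = actor.sample(torch.randn(5, 32), torch.randn(5, 8))
print(action.shape, log_prob.shape)  # torch.Size([5, 4]) torch.Size([5])
```

Clamping log_std to [-5.0, 0.0] (the signature defaults) caps the policy's standard deviation at 1 before squashing, which keeps early exploration bounded.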
get_state_dict()

Collect the state of all components into a single dictionary.

Return type:

Dict[str, Any]

load_state_dict(state_dict)

Restore all components from a dictionary returned by get_state_dict().

Parameters:

state_dict (Dict)

Return type:

None
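The keys that get_state_dict() returns are not documented here. As a sketch of the usual checkpoint round-trip pattern, the hypothetical ToyLearner below bundles one PyTorch state_dict() per component into a single dictionary and restores each one in load_state_dict. Its component names (actor, critic, actor_optimizer) are assumptions, not HoraSACLearner's real keys.

```python
import io
from typing import Any, Dict

import torch
import torch.nn as nn


class ToyLearner:
    """Hypothetical stand-in for a learner with several components."""

    def __init__(self) -> None:
        self.actor = nn.Linear(4, 2)
        self.critic = nn.Linear(6, 1)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

    def get_state_dict(self) -> Dict[str, Any]:
        # One entry per component, each holding that component's own state_dict().
        return {
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict: Dict) -> None:
        # Restore each component from its entry; keys must match get_state_dict().
        self.actor.load_state_dict(state_dict["actor"])
        self.critic.load_state_dict(state_dict["critic"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])


src, dst = ToyLearner(), ToyLearner()  # independently initialized weights
buffer = io.BytesIO()
torch.save(src.get_state_dict(), buffer)  # write the checkpoint to an in-memory file
buffer.seek(0)
dst.load_state_dict(torch.load(buffer))
print(torch.equal(src.actor.weight, dst.actor.weight))  # True
```

Including optimizer state alongside the network weights lets training resume with the same Adam moment estimates instead of restarting them from zero.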

soft_update_target()

Polyak-average update of the target Q-network.

Return type:

None
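Polyak averaging moves each target parameter a small step toward its online counterpart: target ← tau * online + (1 - tau) * target. The helper below is a generic sketch of that rule. The name polyak_update and the coefficient tau are assumptions: the __init__ signature does not expose a tau argument, so the real value may be configured elsewhere. This sketch averages only parameters, not buffers.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def polyak_update(online: nn.Module, target: nn.Module, tau: float) -> None:
    """In place: target_param <- tau * online_param + (1 - tau) * target_param."""
    for p, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(1.0 - tau).add_(p, alpha=tau)


q_net = nn.Linear(3, 1)
q_target = copy.deepcopy(q_net)  # target starts as an exact copy
with torch.no_grad():
    q_net.weight.fill_(1.0)
    q_target.weight.fill_(0.0)
polyak_update(q_net, q_target, tau=0.1)
# Each target weight is now 0.1 * 1.0 + 0.9 * 0.0 = 0.1.
```

A small tau makes the target network trail the online network slowly, which stabilizes the bootstrapped targets used in the critic update.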

update_actor(batch)

Perform one actor update step on a batch and return scalar metrics.

Parameters:

batch (Dict[str, Tensor])

Return type:

Dict[str, float]
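update_actor takes a batch dictionary and returns scalar metrics, but its internals are not documented here. Assuming it follows the standard SAC actor objective, alpha * log pi(a|s) - min(Q1, Q2)(s, a) averaged over the batch, one step can be sketched as below. The function name, the explicit tensor arguments (used instead of batch keys), the fixed alpha, and the metric names are hypothetical. In HoraSACLearner the actor would also receive the privileged information.

```python
import torch
import torch.nn as nn


def actor_update_step(actor, q1, q2, optimizer, obs, alpha, log_std_min=-5.0, log_std_max=0.0):
    """One SAC-style actor step: minimize E[alpha * log pi(a|s) - min(Q1, Q2)(s, a)]."""
    mean, log_std = actor(obs).chunk(2, dim=-1)
    std = log_std.clamp(log_std_min, log_std_max).exp()
    dist = torch.distributions.Normal(mean, std)
    u = dist.rsample()  # reparameterization lets gradients reach the actor through Q
    action = torch.tanh(u)
    log_prob = (dist.log_prob(u) - torch.log(1.0 - action.pow(2) + 1e-6)).sum(dim=-1)
    sa = torch.cat([obs, action], dim=-1)
    q = torch.min(q1(sa), q2(sa)).squeeze(-1)
    loss = (alpha * log_prob - q).mean()
    optimizer.zero_grad()
    loss.backward()   # also fills critic .grad; the critic step should zero them first
    optimizer.step()  # the optimizer holds only actor parameters, so critics are unchanged
    return {"actor_loss": loss.item(), "entropy_estimate": -log_prob.mean().item()}


torch.manual_seed(0)
obs_dim, act_dim = 6, 2
actor = nn.Linear(obs_dim, 2 * act_dim)  # outputs [mean, log_std]
q1 = nn.Linear(obs_dim + act_dim, 1)
q2 = nn.Linear(obs_dim + act_dim, 1)
q1_before = q1.weight.detach().clone()
opt = torch.optim.Adam(actor.parameters(), lr=3e-4)
metrics = actor_update_step(actor, q1, q2, opt, torch.randn(8, obs_dim), alpha=0.2)
```

Taking the minimum of the two Q-estimates counters the overestimation bias that a single learned critic tends to have.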

update_critic(batch)

Perform one critic update step on a batch and return scalar metrics.

Parameters:

batch (Dict[str, Tensor])

Return type:

Dict[str, float]
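A standard SAC critic step regresses both Q-networks onto a soft Bellman target computed with the target networks: y = r + gamma * (1 - done) * (min(Q1', Q2')(s', a') - alpha * log pi(a'|s')). The sketch below uses a plain mean-squared-error loss. FastSAC-style learners may use a different critic (for example a distributional one), so treat the function names, arguments, and loss choice as assumptions rather than a description of what update_critic does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def soft_td_target(reward, done, next_q1, next_q2, next_log_prob, gamma, alpha):
    """Soft Bellman target; no gradients flow into the target networks."""
    next_v = torch.min(next_q1, next_q2) - alpha * next_log_prob
    return reward + gamma * (1.0 - done) * next_v


def critic_update_step(q1, q2, optimizer, obs, action, target):
    """Regress both Q-networks onto the same target with an MSE loss."""
    sa = torch.cat([obs, action], dim=-1)
    loss = F.mse_loss(q1(sa).squeeze(-1), target) + F.mse_loss(q2(sa).squeeze(-1), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {"critic_loss": loss.item()}


# Worked target: 1 + 0.99 * (min(2, 3) - 0.2 * -1) = 3.178; a terminal step keeps only r.
y = soft_td_target(reward=torch.tensor([1.0, 1.0]), done=torch.tensor([0.0, 1.0]),
                   next_q1=torch.tensor([2.0, 2.0]), next_q2=torch.tensor([3.0, 3.0]),
                   next_log_prob=torch.tensor([-1.0, -1.0]), gamma=0.99, alpha=0.2)
print(y)  # approximately tensor([3.1780, 1.0000])

q1, q2 = nn.Linear(6 + 2, 1), nn.Linear(6 + 2, 1)
opt = torch.optim.Adam(list(q1.parameters()) + list(q2.parameters()), lr=3e-4)
metrics = critic_update_step(q1, q2, opt, torch.randn(2, 6), torch.randn(2, 2), y)
```

The (1 - done) factor stops bootstrapping at terminal transitions, which is why the second target in the worked example reduces to the reward alone.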