"""FastSAC Learner — replicated from holosoma's FastSAC implementation.
Network architecture:
- Actor: MLP with SiLU + LayerNorm, tanh-squashed Gaussian
- Critic: Distributional Q-Networks (C51 variant, num_atoms=101)
- Automatic entropy coefficient (alpha) learning
Hyperparameters aligned with holosoma FastSACConfig defaults.
"""
from __future__ import annotations
import math
from typing import Any, Dict, Tuple, cast
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from unilab.base.augmentation import SymmetryAugmentation
# ---------------------------------------------------------------------------
# Actor Network (holosoma-style: SiLU + LayerNorm + Tanh squashing)
# ---------------------------------------------------------------------------
[docs]
class SACActor(nn.Module):
    """Tanh-squashed Gaussian policy network for SAC.

    The trunk is three ``Linear -> (LayerNorm) -> SiLU`` stages with widths
    ``hidden_dim``, ``hidden_dim // 2`` and ``hidden_dim // 4``. Two linear
    heads (``fc_mu`` and ``fc_logstd``) sit on top of the trunk and are
    zero-initialised, so a freshly built actor outputs a zero mean.
    """

    action_scale: torch.Tensor
    action_bias: torch.Tensor

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        hidden_dim: int = 512,
        log_std_max: float = 0.0,
        log_std_min: float = -5.0,
        use_tanh: bool = True,
        use_layer_norm: bool = True,
        device: str | torch.device = "cpu",
        action_scale: torch.Tensor | None = None,
        action_bias: torch.Tensor | None = None,
    ):
        """Build the trunk, the two output heads and the action-scaling buffers.

        Args:
            obs_dim: Size of the observation vector fed to the trunk.
            action_dim: Number of action dimensions produced by each head.
            hidden_dim: Width of the first trunk stage (later stages halve it).
            log_std_max: Upper bound of the squashed log standard deviation.
            log_std_min: Lower bound of the squashed log standard deviation.
            use_tanh: Whether actions are tanh-squashed and rescaled.
            use_layer_norm: Use ``LayerNorm`` after each trunk ``Linear``;
                otherwise an ``Identity`` takes its place.
            device: Device on which layers and buffers are created.
            action_scale: Optional per-dimension scale; defaults to ones.
            action_bias: Optional per-dimension offset; defaults to zeros.
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.log_std_max = log_std_max
        self.log_std_min = log_std_min
        self.use_tanh = use_tanh
        # Remember the device the layers were created on.
        self.device_ = device

        # Trunk: consecutive width pairs become Linear -> norm -> SiLU stages.
        widths = (obs_dim, hidden_dim, hidden_dim // 2, hidden_dim // 4)
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out, device=device))
            stages.append(
                nn.LayerNorm(fan_out, device=device) if use_layer_norm else nn.Identity()
            )
            stages.append(nn.SiLU())
        self.net = nn.Sequential(*stages)

        feature_dim = widths[-1]
        self.fc_mu = nn.Linear(feature_dim, action_dim, device=device)
        self.fc_logstd = nn.Linear(feature_dim, action_dim, device=device)

        # Zero-init output heads (holosoma style).
        for head in (self.fc_mu, self.fc_logstd):
            nn.init.constant_(head.weight, 0.0)
            nn.init.constant_(head.bias, 0.0)

        # Action scaling buffers: identity transform unless the caller overrides.
        scale = (
            torch.ones(action_dim, device=device)
            if action_scale is None
            else action_scale.to(device)
        )
        bias = (
            torch.zeros(action_dim, device=device)
            if action_bias is None
            else action_bias.to(device)
        )
        self.register_buffer("action_scale", scale)
        self.register_buffer("action_bias", bias)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(action, mean, log_std)`` for a batch of observations.

        ``action`` is the deterministic action: the tanh-squashed, rescaled
        mean when ``use_tanh`` is set, otherwise the mean itself.
        """
        features = self.net(obs)
        mean = self.fc_mu(features)
        raw_log_std = self.fc_logstd(features)

        # Squash the head output into [log_std_min, log_std_max] through tanh.
        span = self.log_std_max - self.log_std_min
        log_std = self.log_std_min + 0.5 * span * (torch.tanh(raw_log_std) + 1)

        # NaN protection: bound the mean, then replace any NaNs with safe values.
        mean = torch.nan_to_num(torch.clamp(mean, -10.0, 10.0), nan=0.0)
        log_std = torch.nan_to_num(log_std, nan=self.log_std_min)

        if not self.use_tanh:
            return mean, mean, log_std
        action = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, mean, log_std

    def as_export_module(self) -> "nn.Module":
        """Return a single-input/single-output wrapper suitable for torch.onnx.export."""
        policy = self

        class _Wrapper(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.base = policy

            def forward(self, obs: torch.Tensor) -> torch.Tensor:
                # Keep only the action; mean and log_std are dropped.
                outputs = self.base(obs)
                return cast(torch.Tensor, outputs[0])

        return _Wrapper()

    def get_actions_and_log_probs(
        self, obs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample reparameterised actions. Returns ``(action, log_prob, log_std)``.

        ``log_prob`` is summed over dimension 1. With ``use_tanh`` it includes
        the tanh and action-scale change-of-variables corrections.
        """
        _, mean, log_std = self(obs)
        gaussian = torch.distributions.Normal(mean, log_std.exp())
        pre_squash = gaussian.rsample()
        log_prob = gaussian.log_prob(pre_squash)

        if self.use_tanh:
            squashed = torch.tanh(pre_squash)
            action = squashed * self.action_scale + self.action_bias
            # Change-of-variables terms; 1e-6 keeps the logs finite.
            log_prob = log_prob - torch.log(1 - squashed.pow(2) + 1e-6)
            log_prob = log_prob - torch.log(self.action_scale + 1e-6)
        else:
            action = pre_squash

        return action, log_prob.sum(1), log_std

    @torch.no_grad()
    def explore(
        self,
        obs: torch.Tensor,
        dones: torch.Tensor | None = None,
        deterministic: bool = False,
    ) -> torch.Tensor:
        """Get exploration actions.

        Args:
            obs: Batched observations.
            dones: Unused for SAC; kept for API alignment with TD3 actor.
            deterministic: Whether to return deterministic policy actions.
        """
        # Backward compatibility: previous signature was explore(obs, deterministic=False),
        # so a bool in the second positional slot is the deterministic flag.
        if isinstance(dones, bool):
            deterministic, dones = dones, None
        del dones

        _, mean, log_std = self.forward(obs)
        if deterministic:
            pre_squash = mean
        else:
            pre_squash = torch.distributions.Normal(mean, log_std.exp()).rsample()

        if self.use_tanh:
            return torch.tanh(pre_squash) * self.action_scale + self.action_bias
        return pre_squash
# ---------------------------------------------------------------------------
# Distributional Q-Network (C51 variant, from holosoma)
# ---------------------------------------------------------------------------
[docs]
class DistributionalQNetwork(nn.Module):
    """Single distributional Q-network (C51).

    Architecture: Linear→LN→SiLU → Linear→LN→SiLU → Linear→LN→SiLU → Linear(num_atoms)
    Input: concat(obs, action) along the last dimension.
    Output: unnormalised logits over ``num_atoms`` return atoms.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        num_atoms: int = 101,
        v_min: float = -20.0,
        v_max: float = 20.0,
        hidden_dim: int = 768,
        use_layer_norm: bool = True,
        device: str | torch.device = "cpu",
    ):
        """Build the MLP.

        Args:
            obs_dim: Size of the observation vector.
            action_dim: Size of the action vector.
            num_atoms: Number of atoms in the categorical return distribution.
            v_min: Value of the lowest atom.
            v_max: Value of the highest atom.
            hidden_dim: Width of the first hidden layer (later layers halve it).
            use_layer_norm: Use ``LayerNorm`` after each hidden ``Linear``;
                otherwise an ``Identity`` takes its place.
            device: Device on which the layers are created.
        """
        super().__init__()
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max
        input_dim = obs_dim + action_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, device=device),
            nn.LayerNorm(hidden_dim, device=device) if use_layer_norm else nn.Identity(),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim // 2, device=device),
            nn.LayerNorm(hidden_dim // 2, device=device) if use_layer_norm else nn.Identity(),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4, device=device),
            nn.LayerNorm(hidden_dim // 4, device=device) if use_layer_norm else nn.Identity(),
            nn.SiLU(),
            nn.Linear(hidden_dim // 4, num_atoms, device=device),
        )

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Return atom logits for the concatenated ``(obs, actions)`` input."""
        x = torch.cat([obs, actions], dim=-1)
        return self.net(x)  # type: ignore[no-any-return]

    def projection(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        bootstrap: torch.Tensor,
        discount: torch.Tensor,
        q_support: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        """Categorical projection for distributional RL.

        Every atom of ``q_support`` is moved to
        ``rewards + bootstrap * discount * atom``, clamped to
        ``[v_min, v_max]``, and this network's softmax probability for that
        atom is split linearly between the two nearest atoms of the grid.

        Args:
            obs: Observations passed to this network.
            actions: Actions passed to this network.
            rewards: Per-sample rewards; indexed as ``rewards.shape[0]`` and
                broadcast via ``unsqueeze(1)``, so presumably shape ``(batch,)``.
            bootstrap: Per-sample multiplier on the discounted support
                (presumably 0 for non-bootstrapped transitions — confirm
                against the caller).
            discount: Per-sample discount factor.
            q_support: Atom values, broadcast against the batch dimension.
            device: Device on which the integer row offsets are created.

        Returns:
            Projected probabilities with the same shape as the softmax of
            ``self(obs, actions)``.
        """
        delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
        batch_size = rewards.shape[0]

        target_z = rewards.unsqueeze(1) + bootstrap.unsqueeze(1) * discount.unsqueeze(1) * q_support
        target_z = target_z.clamp(self.v_min, self.v_max)

        # Fractional atom index of every shifted atom.
        b = (target_z - self.v_min) / delta_z
        lower = torch.floor(b).long()
        upper = torch.ceil(b).long()

        # When b lands exactly on an atom, floor == ceil and both interpolation
        # weights would be zero. Spread the pair apart so the full mass is kept.
        is_integer = upper == lower
        lower_mask = torch.logical_and((lower > 0), is_integer)
        upper_mask = torch.logical_and((lower == 0), is_integer)
        lower = torch.where(lower_mask, lower - 1, lower)
        upper = torch.where(upper_mask, upper + 1, upper)

        next_dist = F.softmax(self(obs, actions), dim=1)
        proj_dist = torch.zeros_like(next_dist)

        # Row offsets into the flattened (batch * num_atoms) buffer. Integer
        # arithmetic is exact; a float linspace cast to long can lose precision
        # once batch_size * num_atoms exceeds the float mantissa range.
        offset = (
            torch.arange(batch_size, device=device, dtype=torch.long) * self.num_atoms
        ).unsqueeze(1)

        lower_indices = (lower + offset).view(-1)
        upper_indices = (upper + offset).view(-1)

        # Defensive clamp so a stray index can never write out of bounds.
        max_index = proj_dist.numel() - 1
        lower_indices = torch.clamp(lower_indices, 0, max_index)
        upper_indices = torch.clamp(upper_indices, 0, max_index)

        # Scatter-add each atom's mass to its two neighbours.
        proj_dist.view(-1).index_add_(0, lower_indices, (next_dist * (upper.float() - b)).view(-1))
        proj_dist.view(-1).index_add_(0, upper_indices, (next_dist * (b - lower.float())).view(-1))
        return proj_dist
[docs]
class SACCritic(nn.Module):
    """Ensemble of distributional Q-networks for SAC.

    Holds ``num_q_networks`` independently initialised
    ``DistributionalQNetwork`` members plus the shared atom support
    (``q_support``), which is registered as a buffer.
    """

    q_support: torch.Tensor

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        num_atoms: int = 101,
        v_min: float = -20.0,
        v_max: float = 20.0,
        hidden_dim: int = 768,
        use_layer_norm: bool = True,
        num_q_networks: int = 2,
        device: str | torch.device = "cpu",
    ):
        """Create the ensemble members and the atom support.

        Args:
            obs_dim: Size of the critic observation vector.
            action_dim: Size of the action vector.
            num_atoms: Number of atoms per return distribution.
            v_min: Value of the lowest atom.
            v_max: Value of the highest atom.
            hidden_dim: Width of each member's first hidden layer.
            use_layer_norm: Forwarded to every member network.
            num_q_networks: Number of ensemble members.
            device: Device on which members and the support are created.
        """
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.num_q_networks = num_q_networks

        # Every member is built from the same configuration.
        member_kwargs = {
            "obs_dim": obs_dim,
            "action_dim": action_dim,
            "num_atoms": num_atoms,
            "v_min": v_min,
            "v_max": v_max,
            "hidden_dim": hidden_dim,
            "use_layer_norm": use_layer_norm,
            "device": device,
        }
        members = [DistributionalQNetwork(**member_kwargs) for _ in range(num_q_networks)]
        self.qnets = nn.ModuleList(members)

        support = torch.linspace(v_min, v_max, num_atoms, device=device)
        self.register_buffer("q_support", support)

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Returns stacked logits: (num_q_nets, batch, num_atoms)."""
        per_member = []
        for member in self.qnets:
            per_member.append(member(obs, actions))
        return torch.stack(per_member, dim=0)

    def projection(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        bootstrap: torch.Tensor,
        discount: torch.Tensor,
    ) -> torch.Tensor:
        """Project for all Q-networks: (num_q_nets, batch, num_atoms)."""
        support = self.q_support
        per_member = []
        for member in self.qnets:
            projected = member.projection(  # type: ignore[operator]
                obs, actions, rewards, bootstrap, discount, support, support.device
            )
            per_member.append(projected)
        return torch.stack(per_member, dim=0)

    def get_value(self, probs: torch.Tensor) -> torch.Tensor:
        """Return the expected value: probabilities weighted by the support, summed over the last dim."""
        weighted = probs * self.q_support
        return weighted.sum(dim=-1)
# ---------------------------------------------------------------------------
# FastSACLearner — the training algorithm
# ---------------------------------------------------------------------------
[docs]
class FastSACLearner:
    """FastSAC learner with holosoma-aligned hyperparameters.

    Key hyperparameters (aligned with holosoma FastSACConfig):
    - gamma=0.97, tau=0.125
    - batch_size=8192, num_updates=8, policy_frequency=4
    - alpha_init=0.001, target_entropy_ratio=0.0
    - AdamW with betas=(0.9, 0.95), weight_decay=0.001
    - Distributional critic (C51, num_atoms=101)

    The learner owns the actor, the online and target critics, the learnable
    log-temperature ``log_alpha`` and one AdamW optimizer for each of them.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        critic_obs_dim: int,
        device: str = "cpu",
        # Hyperparameters aligned with holosoma
        gamma: float = 0.97,
        tau: float = 0.125,
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        alpha_lr: float = 3e-4,
        alpha_init: float = 0.001,
        target_entropy_ratio: float = 0.0,
        actor_hidden_dim: int = 512,
        critic_hidden_dim: int = 768,
        num_atoms: int = 101,
        v_min: float = -20.0,
        v_max: float = 20.0,
        num_q_networks: int = 2,
        use_layer_norm: bool = True,
        use_tanh: bool = True,
        log_std_max: float = 0.0,
        log_std_min: float = -5.0,
        weight_decay: float = 0.001,
        max_grad_norm: float = 0.0,
        use_autotune: bool = True,
        use_symmetry: bool = False,
        use_amp: bool = False,
        amp_dtype: str = "auto",
        use_compile: bool = False,
        symmetry_augmentation: SymmetryAugmentation | None = None,
        world_size: int = 1,
    ):
        """Build networks, optimizers and optional AMP / compile machinery.

        Args:
            obs_dim: Actor observation size.
            action_dim: Action size.
            critic_obs_dim: Critic observation size.
            device: Device for all networks and tensors.
            gamma: Discount factor.
            tau: Polyak coefficient used by ``soft_update_target``.
            actor_lr: Actor learning rate.
            critic_lr: Critic learning rate.
            alpha_lr: Temperature learning rate.
            alpha_init: Initial temperature (stored as its log).
            target_entropy_ratio: Target entropy is ``-action_dim * ratio``.
            actor_hidden_dim: Actor first-layer width.
            critic_hidden_dim: Critic first-layer width.
            num_atoms: Atoms in the distributional critic.
            v_min: Lowest atom value.
            v_max: Highest atom value.
            num_q_networks: Critic ensemble size.
            use_layer_norm: Forwarded to actor and critics.
            use_tanh: Forwarded to the actor.
            log_std_max: Forwarded to the actor.
            log_std_min: Forwarded to the actor.
            weight_decay: AdamW weight decay for actor and critic.
            max_grad_norm: Gradient-norm clip threshold; ``<= 0`` disables clipping.
            use_autotune: Whether the temperature is learned.
            use_symmetry: Whether updates use symmetry augmentation.
            use_amp: Request autocast; only honoured on ``cuda`` / ``xpu``.
            amp_dtype: One of ``auto``, ``fp16``, ``bf16``.
            use_compile: Request ``torch.compile``; only honoured on ``cuda``.
            symmetry_augmentation: Required when ``use_symmetry`` is true.
            world_size: Number of distributed workers for gradient averaging.

        Raises:
            ValueError: If ``use_symmetry`` is true without an augmentation
                object, or ``amp_dtype`` is not recognised.
        """
        self.device = device
        self._device_type = torch.device(device).type
        self.gamma = gamma
        self.tau = tau
        self.max_grad_norm = max_grad_norm
        self.use_autotune = use_autotune
        self.use_amp = bool(use_amp) and self._device_type in ("cuda", "xpu")
        self.use_compile = (
            bool(use_compile) and self._device_type == "cuda" and hasattr(torch, "compile")
        )
        self.amp_dtype = amp_dtype
        self._amp_dtype = self._resolve_amp_dtype(amp_dtype, self._device_type)
        self.world_size = world_size
        self.critic_obs_dim = critic_obs_dim

        # Build actor (uses obs only)
        self.actor = SACActor(
            obs_dim=obs_dim,
            action_dim=action_dim,
            hidden_dim=actor_hidden_dim,
            log_std_max=log_std_max,
            log_std_min=log_std_min,
            use_tanh=use_tanh,
            use_layer_norm=use_layer_norm,
            device=device,
        )

        self.qnet = SACCritic(
            obs_dim=critic_obs_dim,
            action_dim=action_dim,
            num_atoms=num_atoms,
            v_min=v_min,
            v_max=v_max,
            hidden_dim=critic_hidden_dim,
            use_layer_norm=use_layer_norm,
            num_q_networks=num_q_networks,
            device=device,
        )

        # Target critic: same architecture, starts as an exact copy of the online critic.
        self.qnet_target = SACCritic(
            obs_dim=critic_obs_dim,
            action_dim=action_dim,
            num_atoms=num_atoms,
            v_min=v_min,
            v_max=v_max,
            hidden_dim=critic_hidden_dim,
            use_layer_norm=use_layer_norm,
            num_q_networks=num_q_networks,
            device=device,
        )
        self.qnet_target.load_state_dict(self.qnet.state_dict())

        # Entropy coefficient (learned in log space)
        self.log_alpha = torch.tensor([math.log(alpha_init)], requires_grad=True, device=device)
        self.target_entropy = -action_dim * target_entropy_ratio

        # fused AdamW requires CUDA; MPS and CPU do not support it.
        # Derived from the parsed device type so a torch.device works like a string.
        _fused = self._device_type == "cuda"

        # Optimizers (AdamW with holosoma betas)
        self.q_optimizer = optim.AdamW(
            self.qnet.parameters(),
            lr=critic_lr,
            weight_decay=weight_decay,
            fused=_fused,
            betas=(0.9, 0.95),
        )
        self.actor_optimizer = optim.AdamW(
            self.actor.parameters(),
            lr=actor_lr,
            weight_decay=weight_decay,
            fused=_fused,
            betas=(0.9, 0.95),
        )
        self.alpha_optimizer = optim.AdamW(
            [self.log_alpha],
            lr=alpha_lr,
            fused=_fused,
            betas=(0.9, 0.95),
            weight_decay=0.0,
        )

        # Step counter (not advanced by this class; persisted in checkpoints).
        self.update_count = 0

        # AMP scaler for mixed precision (fp16 only; bf16 has fp32 range and skips scaler)
        self.scaler = (
            torch.amp.GradScaler("cuda")  # pyright: ignore[reportPrivateImportUsage]
            if self._should_use_grad_scaler(self.use_amp, self._device_type, self._amp_dtype)
            else None
        )

        self.symmetry = symmetry_augmentation
        if use_symmetry and symmetry_augmentation is None:
            raise ValueError(
                "FastSACLearner use_symmetry=True requires a symmetry_augmentation contract"
            )
        self.use_symmetry = use_symmetry

        if self.use_compile:
            self._compile_training_methods()

    @staticmethod
    def _resolve_amp_dtype(amp_dtype: str, device_type: str) -> torch.dtype:
        """Map the ``amp_dtype`` option (case-insensitive) to a torch dtype.

        ``auto`` and ``bf16`` give ``torch.bfloat16``; ``fp16`` gives
        ``torch.float16``. ``device_type`` is accepted but not consulted.

        Raises:
            ValueError: For any other option value.
        """
        normalized = amp_dtype.lower()
        if normalized == "auto":
            return torch.bfloat16
        if normalized == "fp16":
            return torch.float16
        if normalized == "bf16":
            return torch.bfloat16
        raise ValueError("FastSAC amp_dtype must be one of: auto, fp16, bf16")

    @staticmethod
    def _should_use_grad_scaler(
        use_amp: bool,
        device_type: str,
        amp_dtype: torch.dtype,
    ) -> bool:
        """Return True only for AMP on CUDA with float16."""
        return bool(use_amp) and device_type == "cuda" and amp_dtype == torch.float16

    def _compile_training_methods(self) -> None:
        """Wrap the two loss functions with ``torch.compile`` (CUDA only)."""
        compile_fn = getattr(torch, "compile", None)
        if compile_fn is None or torch.device(self.device).type != "cuda":
            return
        compile_kwargs = {"options": {"triton.cudagraphs": False}}
        self._critic_loss_tensors = compile_fn(  # type: ignore[method-assign]
            self._critic_loss_tensors, **compile_kwargs
        )
        self._actor_loss_tensors = compile_fn(  # type: ignore[method-assign]
            self._actor_loss_tensors, **compile_kwargs
        )

    def _autocast(self):
        """Return an autocast context; it is a no-op when AMP is disabled."""
        return torch.amp.autocast(  # pyright: ignore[reportPrivateImportUsage]
            self._device_type, dtype=self._amp_dtype, enabled=self.use_amp
        )

    def _reduce_gradients(self, model: nn.Module) -> None:
        """All-reduce gradients across all workers and divide by world_size.

        Must be called after ``backward()`` and, when using AMP, after
        ``scaler.unscale_(optimizer)`` so that gradients are in full precision.
        """
        if self.world_size <= 1:
            return
        # reshape (not view) so a non-contiguous grad cannot raise here.
        grads = [p.grad.reshape(-1) for p in model.parameters() if p.grad is not None]
        if not grads:
            return
        # One flat buffer means a single collective call for the whole model.
        flat = torch.cat(grads)
        dist.all_reduce(flat, op=dist.ReduceOp.SUM)
        flat /= self.world_size
        offset = 0
        for p in model.parameters():
            if p.grad is not None:
                n = p.grad.numel()
                p.grad.copy_(flat[offset : offset + n].view_as(p.grad))
                offset += n

    def _backward_and_step(
        self,
        loss: torch.Tensor,
        optimizer: optim.Optimizer,
        model: nn.Module,
    ) -> torch.Tensor:
        """Backpropagate ``loss``, sync and clip gradients, step ``optimizer``.

        Returns the pre-clip gradient norm reported by ``clip_grad_norm_``, or
        a zero tensor when clipping is disabled (``max_grad_norm <= 0``).
        """
        optimizer.zero_grad(set_to_none=True)
        if self.scaler is not None:
            self.scaler.scale(loss).backward()
            # Unscale before reducing/clipping so both see true gradient magnitudes.
            self.scaler.unscale_(optimizer)
        else:
            loss.backward()

        self._reduce_gradients(model)

        if self.max_grad_norm > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=self.max_grad_norm
            )
        else:
            grad_norm = torch.tensor(0.0, device=self.device)

        if self.scaler is not None:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()
        return grad_norm

    def _get_actions_and_log_probs_for_critic(
        self,
        actor_obs: torch.Tensor,
        critic_obs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample actor actions for critic targets.

        Subclasses can use ``critic_obs`` to supply auxiliary policy context while
        preserving the standard SAC update path.
        """
        del critic_obs
        return self.actor.get_actions_and_log_probs(actor_obs)

    def _get_actions_and_log_probs_for_actor(
        self,
        actor_obs: torch.Tensor,
        critic_obs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample actor actions for the actor loss update."""
        del critic_obs
        return self.actor.get_actions_and_log_probs(actor_obs)

    def _critic_loss_tensors(
        self,
        critic_obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_obs: torch.Tensor,
        critic_next_obs: torch.Tensor,
        dones: torch.Tensor,
        truncated: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the distributional critic loss.

        Returns:
            ``(qf_loss, target_q_max, target_q_min, next_log_probs)`` where
            ``next_log_probs`` is detached for the temperature update.
        """
        # Bootstrap unless the episode ended without being truncated.
        bootstrap = torch.clamp(1.0 - dones.float() + truncated.float(), 0.0, 1.0)
        # Build the discount from the float ``bootstrap`` tensor: full_like inherits
        # its input's dtype, so a bool/integer ``dones`` would corrupt gamma.
        discount = torch.full_like(bootstrap, self.gamma)

        with torch.no_grad():
            with self._autocast():
                next_actions, next_log_probs, _ = self._get_actions_and_log_probs_for_critic(
                    next_obs,
                    critic_next_obs,
                )
            # Fold the entropy bonus into the reward before projecting.
            adjusted_rewards = (
                rewards - discount * bootstrap * self.log_alpha.exp() * next_log_probs
            )
            with self._autocast():
                target_distributions = self.qnet_target.projection(
                    critic_next_obs, next_actions, adjusted_rewards, bootstrap, discount
                )
                target_values = self.qnet_target.get_value(target_distributions)
            target_q_max = target_values.max()
            target_q_min = target_values.min()

        with self._autocast():
            q_outputs = self.qnet(critic_obs, actions)
            # Clamp log-probabilities so the cross-entropy stays bounded.
            critic_log_probs = F.log_softmax(q_outputs, dim=-1).clamp(min=-30.0)
            critic_losses = -torch.sum(target_distributions * critic_log_probs, dim=-1)
            # Mean over the batch, summed over ensemble members.
            qf_loss = critic_losses.mean(dim=1).sum(dim=0)

        return qf_loss, target_q_max, target_q_min, next_log_probs.detach()

    def _actor_loss_tensors(
        self,
        obs: torch.Tensor,
        critic_obs: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the actor loss.

        Returns:
            ``(actor_loss, policy_entropy, action_std)``; the last two are
            logging statistics computed without gradients.
        """
        with self._autocast():
            actions, log_probs, log_std = self._get_actions_and_log_probs_for_actor(
                obs,
                critic_obs,
            )
        with torch.no_grad():
            action_std = log_std.exp().mean()
            policy_entropy = -log_probs.mean()
        with self._autocast():
            q_outputs = self.qnet(critic_obs, actions)
            q_probs = F.softmax(q_outputs, dim=-1)
            q_values = self.qnet.get_value(q_probs)
            # Average the expected value over ensemble members.
            qf_value = q_values.mean(dim=0)
            actor_loss = (self.log_alpha.exp().detach() * log_probs - qf_value).mean()
        return actor_loss, policy_entropy, action_std

    def update_critic(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """One critic update step, followed by the temperature update.

        Args:
            batch: Must contain ``obs``, ``critic``, ``actions``, ``rewards``,
                ``next_obs``, ``next_critic``, ``dones`` and ``truncated``.

        Returns:
            Scalar logging metrics.
        """
        obs = batch["obs"]
        critic_obs = batch["critic"]
        actions = batch["actions"]
        rewards = batch["rewards"]
        next_obs = batch["next_obs"]
        critic_next_obs = batch["next_critic"]
        dones = batch["dones"]
        truncated = batch["truncated"]

        # Apply symmetry augmentation
        if self.use_symmetry:
            orig_actions = actions
            assert self.symmetry is not None
            obs, actions = self.symmetry.augment_obs_and_actions(obs, actions, obs_group="obs")
            next_obs, _ = self.symmetry.augment_obs_and_actions(
                next_obs, orig_actions, obs_group="obs"
            )
            critic_obs, _ = self.symmetry.augment_obs_and_actions(
                critic_obs,
                orig_actions,
                obs_group="critic",
            )
            critic_next_obs, _ = self.symmetry.augment_obs_and_actions(
                critic_next_obs,
                orig_actions,
                obs_group="critic",
            )
            # Double the batch size for other tensors
            rewards = rewards.repeat(2)
            dones = dones.repeat(2)
            truncated = truncated.repeat(2)

        qf_loss, target_q_max, target_q_min, next_log_probs = self._critic_loss_tensors(
            critic_obs,
            actions,
            rewards,
            next_obs,
            critic_next_obs,
            dones,
            truncated,
        )

        # Skip the step if the loss is NaN/inf.
        # NOTE(review): with world_size > 1 this check is rank-local, so a rank that
        # skips does not join the all-reduce the other ranks issue — confirm how
        # the caller guards against that.
        if torch.isfinite(qf_loss):
            critic_grad_norm = self._backward_and_step(qf_loss, self.q_optimizer, self.qnet)
        else:
            critic_grad_norm = torch.tensor(0.0, device=self.device)

        # Alpha loss (temperature update) - matching holosoma
        alpha_loss = torch.tensor(0.0, device=self.device)
        if self.use_autotune:
            self.alpha_optimizer.zero_grad(set_to_none=True)
            alpha_loss = (-self.log_alpha.exp() * (next_log_probs + self.target_entropy)).mean()
            if torch.isfinite(alpha_loss):
                alpha_loss.backward()
                if self.world_size > 1 and self.log_alpha.grad is not None:
                    dist.all_reduce(self.log_alpha.grad, op=dist.ReduceOp.SUM)
                    self.log_alpha.grad /= self.world_size
                self.alpha_optimizer.step()

        return {
            "qf_loss": qf_loss.item(),
            "critic_grad_norm": critic_grad_norm.item(),
            "target_q_max": target_q_max.item(),
            "target_q_min": target_q_min.item(),
            "alpha_loss": alpha_loss.item(),
            "alpha": self.log_alpha.exp().item(),
        }

    def update_actor(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """One actor update step.

        Args:
            batch: Must contain ``obs`` and ``critic``.

        Returns:
            Scalar logging metrics.
        """
        obs = batch["obs"]
        critic_obs = batch["critic"]

        # Apply symmetry augmentation: append the mirrored observations to the batch.
        if self.use_symmetry:
            assert self.symmetry is not None
            obs = torch.cat([obs, self.symmetry.mirror_obs(obs, obs_group="obs")], dim=0)
            critic_obs = torch.cat(
                [critic_obs, self.symmetry.mirror_obs(critic_obs, obs_group="critic")],
                dim=0,
            )

        actor_loss, policy_entropy, action_std = self._actor_loss_tensors(obs, critic_obs)

        # Skip the step if the loss is NaN/inf (rank-local check; see update_critic).
        if torch.isfinite(actor_loss):
            actor_grad_norm = self._backward_and_step(
                actor_loss, self.actor_optimizer, self.actor
            )
        else:
            actor_grad_norm = torch.tensor(0.0, device=self.device)

        return {
            "actor_loss": actor_loss.item(),
            "actor_grad_norm": actor_grad_norm.item(),
            "policy_entropy": policy_entropy.item(),
            "action_std": action_std.item(),
        }

    def soft_update_target(self) -> None:
        """Polyak-average update of the target Q-network."""
        with torch.no_grad():
            for tgt, src in zip(self.qnet_target.parameters(), self.qnet.parameters()):
                # tgt <- (1 - tau) * tgt + tau * src, in place.
                tgt.data.mul_(1.0 - self.tau).add_(src.data, alpha=self.tau)

    def get_state_dict(self) -> Dict[str, Any]:
        """Return a checkpoint dict with networks, optimizers, alpha and step count."""
        return {
            "actor": self.actor.state_dict(),
            "qnet": self.qnet.state_dict(),
            "qnet_target": self.qnet_target.state_dict(),
            "log_alpha": self.log_alpha.detach().cpu(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "q_optimizer": self.q_optimizer.state_dict(),
            "alpha_optimizer": self.alpha_optimizer.state_dict(),
            "update_count": self.update_count,
        }

    def load_state_dict(self, state_dict: Dict) -> None:
        """Restore all components from a dict produced by ``get_state_dict``.

        ``update_count`` falls back to 0 when the key is missing.
        """
        self.actor.load_state_dict(state_dict["actor"])
        self.qnet.load_state_dict(state_dict["qnet"])
        self.qnet_target.load_state_dict(state_dict["qnet_target"])
        # Copy in place so the optimizer keeps pointing at the same tensor.
        self.log_alpha.data.copy_(state_dict["log_alpha"].to(self.device))
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        self.q_optimizer.load_state_dict(state_dict["q_optimizer"])
        self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
        self.update_count = state_dict.get("update_count", 0)
# ---------------------------------------------------------------------------