"""FlashSAC actor, critic, and temperature modules."""
from __future__ import annotations
import math
from typing import cast
import torch
import torch.nn as nn
from unilab.algos.torch.flash_sac.layers import (
EnsembleCategoricalValue,
EnsembleFlashSACBlock,
EnsembleFlashSACEmbedder,
EnsembleUnitRMSNorm,
FlashSACBlock,
FlashSACEmbedder,
NormalTanhPolicy,
UnitRMSNorm,
)
def _normalize_module_tree(module: nn.Module) -> None:
    """Invoke ``normalize_parameters()`` on every descendant of *module*.

    The root *module* itself is skipped: the model classes in this file
    implement their own ``normalize_parameters`` by calling this helper, so
    including the root would recurse forever.  Descendants that do not expose
    a callable ``normalize_parameters`` attribute are ignored.  All hooks run
    with autograd disabled.
    """
    with torch.no_grad():
        for descendant in module.modules():
            if descendant is not module:
                hook = getattr(descendant, "normalize_parameters", None)
                if callable(hook):
                    hook()
class FlashSACActor(nn.Module):
    """FlashSAC policy network with temporally correlated exploration noise.

    Observations flow through a ``FlashSACEmbedder``, ``num_blocks``
    ``FlashSACBlock`` layers and a ``UnitRMSNorm``; a ``NormalTanhPolicy``
    head then produces the action distribution.

    For rollouts, :meth:`explore` draws one standard-normal noise vector per
    batch row and reuses it over several consecutive calls.  The number of
    repeats is sampled from a truncated zeta distribution, i.e. with
    probability proportional to ``n ** -noise_zeta_mu`` for
    ``n = 1 .. noise_zeta_max``.
    """

    # Cumulative distribution of the repeat length (persistent buffer).
    zeta_cdf: torch.Tensor
    # Per-row exploration state, (re)allocated lazily by ``explore``; these are
    # non-persistent buffers, so they are not part of ``state_dict()``.
    _noise: torch.Tensor
    _repeat_count: torch.Tensor
    _repeat_target: torch.Tensor

    def __init__(
        self,
        num_blocks: int,
        input_dim: int,
        hidden_dim: int,
        action_dim: int,
        noise_zeta_mu: float = 2.0,
        noise_zeta_max: int = 16,
        device: str | torch.device = "cpu",
    ):
        """Build the network, move it to *device* and normalize its parameters.

        Args:
            num_blocks: Number of ``FlashSACBlock`` layers in the encoder.
            input_dim: Size of the observation vector fed to the embedder.
            hidden_dim: Width of the embedder, blocks and final norm.
            action_dim: Size of the action vector produced by the policy head.
            noise_zeta_mu: Exponent of the power law governing how long an
                exploration noise sample is held.
            noise_zeta_max: Largest repeat length that can be sampled.
            device: Device the module (and its buffers) is moved to.
        """
        super().__init__()
        self.embedder = FlashSACEmbedder(input_dim=input_dim, hidden_dim=hidden_dim)
        self.encoder = nn.ModuleList([FlashSACBlock(hidden_dim) for _ in range(num_blocks)])
        self.post_norm = UnitRMSNorm(hidden_dim)
        self.predictor = NormalTanhPolicy(hidden_dim=hidden_dim, action_dim=action_dim)

        self.noise_zeta_mu = noise_zeta_mu
        self.noise_zeta_max = noise_zeta_max

        # Normalized power-law pmf over 1..noise_zeta_max, stored as its CDF so
        # that repeat lengths can be drawn by inverse-CDF sampling.
        ns = torch.arange(1, noise_zeta_max + 1, dtype=torch.float32)
        pmf = ns.pow(-noise_zeta_mu)
        self.register_buffer("zeta_cdf", torch.cumsum(pmf / pmf.sum(), dim=0))

        # Empty placeholders; real shapes are only known on the first explore().
        self.register_buffer("_noise", torch.zeros(0), persistent=False)
        self.register_buffer("_repeat_count", torch.zeros(0, dtype=torch.int32), persistent=False)
        self.register_buffer("_repeat_target", torch.zeros(0, dtype=torch.int32), persistent=False)

        self.to(device)
        self.normalize_parameters()

    def normalize_parameters(self) -> None:
        """Call ``normalize_parameters()`` on every sub-module that provides it."""
        _normalize_module_tree(self)

    def _encode(self, observations: torch.Tensor, training: bool) -> torch.Tensor:
        """Run embedder, encoder blocks and the final norm on *observations*."""
        x = self.embedder(observations, training=training)
        for block in self.encoder:
            x = block(x, training=training)
        return cast(torch.Tensor, self.post_norm(x))

    def get_mean_and_std(
        self, observations: torch.Tensor, training: bool
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the policy head's ``(mean, std)`` for *observations*.

        The mean is the pre-squash value: callers in this class apply
        ``tanh`` themselves.
        """
        encoded = self._encode(observations, training=training)
        return self.predictor.get_mean_and_std(encoded)

    def forward(
        self, observations: torch.Tensor, training: bool
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Encode *observations* and return the policy head's output.

        The result is whatever ``NormalTanhPolicy.__call__`` returns, typed
        here as a tensor plus a dict of tensors.
        NOTE(review): presumably a sampled action and auxiliary info such as
        log-probabilities -- confirm against ``NormalTanhPolicy``.
        """
        encoded = self._encode(observations, training=training)
        return cast(tuple[torch.Tensor, dict[str, torch.Tensor]], self.predictor(encoded))

    def as_export_module(self) -> "nn.Module":
        """Return a single-input/single-output wrapper suitable for torch.onnx.export.

        The wrapper shares this actor (no copy) and maps observations to the
        deterministic action ``tanh(mean)`` with ``training=False``.
        """
        actor = self

        class _Wrapper(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.base = actor

            def forward(self, obs: torch.Tensor) -> torch.Tensor:
                mean, _std = self.base.get_mean_and_std(obs, training=False)
                return cast(torch.Tensor, torch.tanh(mean))

        return _Wrapper()

    def _ensure_exploration_state(
        self, batch_size: int, action_dim: int, device: torch.device, dtype: torch.dtype
    ) -> None:
        """Allocate zeroed exploration buffers unless the current ones already fit.

        The state is kept only when the noise buffer has exactly the shape
        ``(batch_size, action_dim)`` and the requested device and dtype.
        Comparing the full shape (rather than the element count) avoids
        reusing a buffer that merely has the same number of elements.
        """
        if (
            self._noise.shape == (batch_size, action_dim)
            and self._noise.device == device
            and self._noise.dtype == dtype
        ):
            return
        self._noise = torch.zeros(batch_size, action_dim, device=device, dtype=dtype)
        self._repeat_count = torch.zeros(batch_size, device=device, dtype=torch.int32)
        self._repeat_target = torch.zeros(batch_size, device=device, dtype=torch.int32)

    def _sample_repeat_targets(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Draw *batch_size* repeat lengths by inverse-CDF sampling of ``zeta_cdf``.

        Returns an int32 tensor whose values lie in ``1 .. len(zeta_cdf)``.
        """
        draws = torch.rand(batch_size, device=device)
        cdf = self.zeta_cdf.to(device)
        bins = torch.searchsorted(cdf, draws)
        # Float32 rounding can leave cdf[-1] slightly below 1.0; a draw above it
        # would then map to index len(cdf) and yield a length one past the
        # maximum.  Clamp so the result always stays in range.
        bins = bins.clamp(max=max(cdf.numel() - 1, 0))
        return cast(torch.Tensor, bins.to(torch.int32) + 1)

    @torch.no_grad()
    def explore(
        self,
        obs: torch.Tensor,
        dones: torch.Tensor | None = None,
        deterministic: bool = False,
    ) -> torch.Tensor:
        """Compute actions for environment interaction (no gradients).

        Args:
            obs: Batch of observations; the policy head must yield a 2-D
                ``(batch, action_dim)`` mean for the stochastic path.
            dones: Optional per-row episode-termination flags; entries greater
                than 0.5 force fresh noise for that row.  A plain ``bool`` in
                this position is interpreted as ``deterministic`` so that
                ``explore(obs, True)`` keeps working.
            deterministic: When true, return ``tanh(mean)`` and leave the
                exploration state untouched.

        Returns:
            Actions squashed to ``(-1, 1)`` by ``tanh``.
        """
        if isinstance(dones, bool):
            deterministic = dones
            dones = None

        mean, std = self.get_mean_and_std(obs, training=False)
        if deterministic:
            return torch.tanh(mean)

        batch_size, action_dim = mean.shape
        self._ensure_exploration_state(batch_size, action_dim, mean.device, mean.dtype)

        if dones is None:
            done_mask = torch.zeros(batch_size, device=mean.device, dtype=torch.bool)
        else:
            done_mask = dones.to(device=mean.device).reshape(-1) > 0.5

        # A row needs new noise when its episode ended, when it has never been
        # initialised (count == 0), or when the current noise has been used
        # for its sampled number of steps.
        reinit = done_mask | (self._repeat_count <= 0) | (self._repeat_count >= self._repeat_target)
        if torch.any(reinit):
            new_noise = torch.randn_like(mean)
            new_target = self._sample_repeat_targets(batch_size, mean.device)
            self._noise = torch.where(reinit.unsqueeze(-1), new_noise, self._noise)
            self._repeat_target = torch.where(reinit, new_target, self._repeat_target)
            self._repeat_count = torch.where(
                reinit, torch.zeros_like(self._repeat_count), self._repeat_count
            )

        actions = torch.tanh(mean + std * self._noise)
        self._repeat_count = self._repeat_count + 1
        return actions
class FlashSACDoubleCritic(nn.Module):
    """Ensemble of categorical Q-networks evaluated in a single pass.

    The concatenated ``(observation, action)`` input is shared by all
    ensemble members and flows through an ``EnsembleFlashSACEmbedder``,
    ``num_blocks`` ``EnsembleFlashSACBlock`` layers and an
    ``EnsembleUnitRMSNorm`` before an ``EnsembleCategoricalValue`` head.
    """

    def __init__(
        self,
        num_blocks: int,
        input_dim: int,
        hidden_dim: int,
        num_bins: int,
        min_v: float,
        max_v: float,
        num_qs: int = 2,
        device: str | torch.device = "cpu",
    ):
        """Build the ensemble, move it to *device* and normalize its parameters.

        Args:
            num_blocks: Number of ensemble blocks in the encoder.
            input_dim: Size of the concatenated observation/action vector.
            hidden_dim: Width of the embedder, blocks and final norm.
            num_bins: Passed to the categorical value head as ``num_bins``.
            min_v: Passed to the categorical value head as ``min_v``.
            max_v: Passed to the categorical value head as ``max_v``.
            num_qs: Number of ensemble members.
            device: Device the module is moved to.
        """
        super().__init__()
        self.embedder = EnsembleFlashSACEmbedder(num_qs, input_dim, hidden_dim)
        encoder_layers = [EnsembleFlashSACBlock(num_qs, hidden_dim) for _ in range(num_blocks)]
        self.encoder = nn.ModuleList(encoder_layers)
        self.post_norm = EnsembleUnitRMSNorm(num_qs, hidden_dim)
        self.predictor = EnsembleCategoricalValue(
            num_ensemble=num_qs,
            hidden_dim=hidden_dim,
            num_bins=num_bins,
            min_v=min_v,
            max_v=max_v,
        )
        self.to(device)
        self.normalize_parameters()

    def normalize_parameters(self) -> None:
        """Call ``normalize_parameters()`` on every sub-module that provides it."""
        _normalize_module_tree(self)

    def forward(
        self,
        observations: torch.Tensor,
        actions: torch.Tensor,
        training: bool,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Evaluate all ensemble members on ``(observations, actions)``.

        The two inputs are concatenated on the last axis and broadcast along a
        new leading ensemble axis.  The ensemble size is read from the first
        dimension of ``predictor.logit_w.weight``.
        NOTE(review): presumably equal to ``num_qs`` -- confirm against
        ``EnsembleCategoricalValue``.

        Returns:
            The value head's output, typed as a tensor plus a dict of tensors.
        """
        joint = torch.cat((observations, actions), dim=-1)
        ensemble_size = self.predictor.logit_w.weight.shape[0]
        hidden = joint.unsqueeze(0).expand(ensemble_size, -1, -1)
        hidden = self.embedder(hidden, training=training)
        for layer in self.encoder:
            hidden = layer(hidden, training=training)
        normalized = self.post_norm(hidden)
        return cast(tuple[torch.Tensor, dict[str, torch.Tensor]], self.predictor(normalized))
class FlashSACTemperature(nn.Module):
    """Learnable temperature parameterized in log space.

    The single trainable parameter ``log_temp`` holds ``log(initial_value)``;
    calling the module returns ``exp(log_temp)``, which is always positive.
    """

    def __init__(self, initial_value: float = 0.01):
        """Create the parameter so that the initial output equals *initial_value*.

        Raises:
            ValueError: If *initial_value* is not positive (from ``math.log``).
        """
        super().__init__()
        log_initial = math.log(initial_value)
        self.log_temp = nn.Parameter(torch.tensor([log_initial], dtype=torch.float32))

    def forward(self) -> torch.Tensor:
        """Return the current temperature as a one-element tensor."""
        return self.log_temp.exp()