"""FlashSAC learner adapted to UniLab's off-policy contract."""
from __future__ import annotations
import copy
from dataclasses import dataclass
from typing import Any, cast
import torch
import torch.nn as nn
import torch.optim as optim
from unilab.algos.torch.common.normalization import EmpiricalNormalization
from unilab.algos.torch.flash_sac.network import (
FlashSACActor,
FlashSACDoubleCritic,
FlashSACTemperature,
)
from unilab.algos.torch.flash_sac.update import (
build_lr_lambda,
resolve_target_entropy,
select_min_q_log_probs,
)
@dataclass
class RunningMeanStd:
    """Streaming mean/variance of a scalar signal using a pairwise merge.

    ``mean`` and ``var`` are float32 tensors of shape ``(1,)``; ``count`` is a
    0-d float32 tensor seeded with a tiny pseudo-count so the first merge
    never divides by zero.
    """

    mean: torch.Tensor
    var: torch.Tensor
    count: torch.Tensor

    @classmethod
    def create(cls, device: torch.device) -> "RunningMeanStd":
        """Return fresh statistics (mean 0, variance 1) allocated on ``device``."""
        spec = {"device": device, "dtype": torch.float32}
        return cls(
            mean=torch.zeros(1, **spec),
            var=torch.ones(1, **spec),
            count=torch.tensor(1e-4, **spec),
        )

    def update(self, x: torch.Tensor) -> None:
        """Merge every element of ``x`` (any shape, flattened) into the statistics.

        An empty ``x`` leaves the statistics unchanged.
        """
        flat = x.reshape(-1).to(dtype=torch.float32)
        n_elems = flat.numel()
        if not n_elems:
            return
        mu_batch = flat.mean()
        var_batch = flat.var(unbiased=False)
        n_batch = torch.tensor(float(n_elems), device=flat.device, dtype=torch.float32)
        shift = mu_batch - self.mean
        n_total = self.count + n_batch
        # Combine the two populations' summed squared deviations, plus the
        # cross term that accounts for the gap between their means.
        combined_m2 = (
            self.var * self.count
            + var_batch * n_batch
            + shift.pow(2) * self.count * n_batch / n_total
        )
        self.mean = self.mean + shift * n_batch / n_total
        self.var = combined_m2 / n_total
        self.count = n_total

    def state_dict(self) -> dict[str, torch.Tensor]:
        """Return the current statistics keyed by field name (tensors not copied)."""
        return dict(mean=self.mean, var=self.var, count=self.count)

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        """Replace the statistics with the tensors stored in ``state_dict``."""
        for field_name in ("mean", "var", "count"):
            setattr(self, field_name, state_dict[field_name])
class RewardNormalizer:
    """Adaptive reward scaling with running discounted-return statistics.

    Each environment keeps a running discounted return ``g_r`` that is fed,
    step by step, into a ``RunningMeanStd``. :meth:`normalize` divides rewards
    by ``max(sqrt(var(g_r) + eps), max|g_r| / g_max)``, so the largest observed
    ``|g_r|`` maps to at most ``g_max`` after scaling.
    """

    def __init__(
        self,
        gamma: float,
        g_max: float,
        device: torch.device,
        eps: float = 1e-8,
    ):
        """
        Args:
            gamma: Discount applied to the running return between steps.
            g_max: Bound for the largest scaled return magnitude.
            device: Device holding all internal statistics.
            eps: Numerical floor for the variance and for ``g_max``.
        """
        self.gamma = gamma
        self.g_max = g_max
        self.eps = eps
        self.device = device
        self.rms = RunningMeanStd.create(device)
        # Per-env running return; sized lazily on the first update.
        self.g_r = torch.zeros(0, device=device, dtype=torch.float32)
        # Largest |g_r| ever observed (only ever grows).
        self.g_r_max = torch.tensor(0.0, device=device, dtype=torch.float32)

    def _ensure_g_r_shape(self, num_envs: int) -> None:
        """(Re)allocate ``g_r`` as zeros whenever the env count changes."""
        if self.g_r.shape == (num_envs,):
            return
        self.g_r = torch.zeros(num_envs, device=self.device, dtype=torch.float32)

    def update_from_transitions(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
    ) -> None:
        """Advance the running returns over a batch of transitions.

        Args:
            rewards: ``(num_envs,)`` for a single step, ``(T, num_envs)`` for
                ``T`` consecutive steps, or a 0-d scalar (one env, one step).
            dones: Episode-end flags indexed per step like ``rewards``; any
                dtype, clamped to ``[0, 1]``. A set flag at step ``t`` drops
                the carried-over return before ``r_t`` is added.
        """
        rewards = rewards.to(device=self.device, dtype=torch.float32)
        dones = dones.to(device=self.device, dtype=torch.float32)
        if rewards.ndim < 2:
            # Promote a single step (vector or scalar) to a length-1 time axis.
            rewards = torch.atleast_1d(rewards).unsqueeze(0)
            dones = torch.atleast_1d(dones).unsqueeze(0)
        if rewards.numel() == 0:
            return
        num_envs = int(rewards.shape[-1])
        self._ensure_g_r_shape(num_envs)
        done = torch.clamp(dones, min=0.0, max=1.0)
        for step in range(rewards.shape[0]):
            self.g_r = self.gamma * (1.0 - done[step]) * self.g_r + rewards[step]
            self.g_r_max = torch.maximum(self.g_r_max, self.g_r.abs().max())
            self.rms.update(self.g_r)

    def normalize(self, rewards: torch.Tensor) -> torch.Tensor:
        """Return ``rewards`` divided by the current adaptive scale."""
        denominator = torch.maximum(
            torch.sqrt(self.rms.var + self.eps),
            self.g_r_max / max(self.g_max, self.eps),
        )
        return rewards / denominator

    def state_dict(self) -> dict[str, Any]:
        """Return the running statistics for checkpointing."""
        return {
            "rms": self.rms.state_dict(),
            "g_r": self.g_r,
            "g_r_max": self.g_r_max,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore statistics, placing every value on this normalizer's device.

        Checkpoints may come from a learner on another device; without the
        move, the next update would mix devices and fail.
        """

        def on_device(value: Any) -> torch.Tensor:
            return torch.as_tensor(value, dtype=torch.float32, device=self.device)

        self.rms.load_state_dict(
            {key: on_device(value) for key, value in state_dict["rms"].items()}
        )
        self.g_r = on_device(state_dict["g_r"])
        self.g_r_max = on_device(state_dict["g_r_max"])
class FlashSACLearner:
    """FlashSAC learner exposing UniLab's off-policy update interface.

    Owns the actor, a double distributional critic plus its Polyak-averaged
    target copy, a learnable temperature, optional observation/reward
    normalizers, and one Adam optimizer with a ``LambdaLR`` schedule per
    trainable module. Training is driven through :meth:`update_critic`,
    :meth:`update_actor` and :meth:`soft_update_target`.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        critic_obs_dim: int,
        device: str = "cpu",
        gamma: float = 0.99,
        tau: float = 0.01,
        actor_lr: float = 3e-4,
        critic_lr: float = 3e-4,
        actor_hidden_dim: int = 128,
        critic_hidden_dim: int = 256,
        actor_num_blocks: int = 2,
        critic_num_blocks: int = 2,
        num_atoms: int = 101,
        critic_min_v: float = -5.0,
        critic_max_v: float = 5.0,
        temp_initial_value: float = 0.01,
        temp_target_sigma: float = 0.15,
        temp_target_entropy: float | None = None,
        actor_bc_alpha: float = 0.0,
        actor_noise_zeta_mu: float = 2.0,
        actor_noise_zeta_max: int = 16,
        learning_rate_init: float = 3e-4,
        learning_rate_peak: float = 3e-4,
        learning_rate_end: float = 1.5e-4,
        learning_rate_warmup_steps: int = 0,
        learning_rate_decay_steps: int = 500000,
        normalize_reward: bool = True,
        normalized_g_max: float = 5.0,
        n_step: int = 1,
        obs_normalization: bool = False,
        use_amp: bool = False,
        amp_dtype: str = "auto",
        use_compile: bool = False,
    ):
        """Build networks, normalizers, optimizers and LR schedules.

        Args:
            obs_dim: Actor observation size.
            action_dim: Action size.
            critic_obs_dim: Critic observation size; the critic input is
                ``critic_obs_dim + action_dim`` wide.
            device: Torch device string.
            gamma: Per-step discount; critic targets use ``gamma ** n_step``.
            tau: Polyak coefficient used by :meth:`soft_update_target`.
            actor_lr: Peak LR fallback when ``learning_rate_peak <= 0``.
            critic_lr: Accepted for config compatibility but not read.
            num_atoms / critic_min_v / critic_max_v: Bin count and value
                range handed to the distributional critic.
            actor_bc_alpha: Weight of the behaviour-cloning term (0 disables).
            learning_rate_*: Forwarded to ``build_lr_lambda``.
            normalize_reward / normalized_g_max: Enable ``RewardNormalizer``
                and set its ``g_max``.
            n_step: Exponent applied to ``gamma`` for the bootstrap.
            obs_normalization: Enable ``EmpiricalNormalization`` on actor obs.
            use_amp: Enable autocast; honoured only on CUDA/XPU devices.
            amp_dtype: ``"auto"`` (bf16), ``"fp16"`` or ``"bf16"``.
            use_compile: ``torch.compile`` hot methods; honoured only on CUDA.
        """
        self.device = torch.device(device)
        self.gamma = gamma
        self.tau = tau
        self.n_step = n_step
        self.actor_bc_alpha = actor_bc_alpha
        self.obs_dim = obs_dim
        self.critic_obs_dim = critic_obs_dim
        self.action_dim = action_dim
        self.update_count = 0
        self.use_amp = bool(use_amp and self.device.type in ("cuda", "xpu"))
        self.amp_dtype = amp_dtype
        self._amp_dtype = self._resolve_amp_dtype(amp_dtype, self.device.type)
        self.use_compile = bool(
            use_compile and hasattr(torch, "compile") and self.device.type == "cuda"
        )
        self.actor = FlashSACActor(
            num_blocks=actor_num_blocks,
            input_dim=obs_dim,
            hidden_dim=actor_hidden_dim,
            action_dim=action_dim,
            noise_zeta_mu=actor_noise_zeta_mu,
            noise_zeta_max=actor_noise_zeta_max,
            device=self.device,
        )
        self.critic = FlashSACDoubleCritic(
            num_blocks=critic_num_blocks,
            input_dim=self.critic_obs_dim + action_dim,
            hidden_dim=critic_hidden_dim,
            num_bins=num_atoms,
            min_v=critic_min_v,
            max_v=critic_max_v,
            device=self.device,
        )
        self.target_critic = copy.deepcopy(self.critic).to(self.device)
        self.target_critic.eval()
        self.temperature = FlashSACTemperature(temp_initial_value).to(self.device)
        self.target_entropy = resolve_target_entropy(
            action_dim=action_dim,
            target_sigma=temp_target_sigma,
            target_entropy=temp_target_entropy,
        )
        self.obs_normalizer: EmpiricalNormalization | nn.Identity
        if obs_normalization:
            self.obs_normalizer = EmpiricalNormalization(shape=obs_dim, device=self.device)
        else:
            self.obs_normalizer = nn.Identity()
        self.reward_normalizer = (
            RewardNormalizer(gamma=self.gamma, g_max=normalized_g_max, device=self.device)
            if normalize_reward
            else None
        )
        # GradScaler is only needed for fp16 (cuda); bf16 on xpu doesn't need it.
        self.scaler: Any | None = (
            getattr(torch.amp, "GradScaler")("cuda")
            if self._should_use_grad_scaler(self.use_amp, self.device.type, self._amp_dtype)
            else None
        )
        # NOTE(review): critic_lr is never read; all three optimizers share
        # lr_peak, falling back to actor_lr when learning_rate_peak <= 0.
        # Confirm this is intended.
        lr_peak = learning_rate_peak if learning_rate_peak > 0 else actor_lr
        fused = self.device.type == "cuda"
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr_peak, fused=fused)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr_peak, fused=fused)
        self.temperature_optimizer = optim.Adam(
            self.temperature.parameters(), lr=lr_peak, fused=fused
        )
        scheduler_fn = build_lr_lambda(
            init_lr=learning_rate_init,
            peak_lr=lr_peak,
            end_lr=learning_rate_end,
            warmup_steps=learning_rate_warmup_steps,
            decay_steps=learning_rate_decay_steps,
        )
        self.actor_scheduler = optim.lr_scheduler.LambdaLR(self.actor_optimizer, scheduler_fn)
        self.critic_scheduler = optim.lr_scheduler.LambdaLR(self.critic_optimizer, scheduler_fn)
        self.temperature_scheduler = optim.lr_scheduler.LambdaLR(
            self.temperature_optimizer, scheduler_fn
        )
        if self.use_compile:
            self._compile_training_methods()

    def _compile_training_methods(self) -> None:
        """Replace hot methods with ``torch.compile``d versions (CUDA only)."""
        compile_fn = getattr(torch, "compile", None)
        if compile_fn is None or self.device.type != "cuda":
            return
        compile_kwargs = {"options": {"triton.cudagraphs": False}}
        self.actor.get_mean_and_std = compile_fn(  # type: ignore[method-assign]
            self.actor.get_mean_and_std, **compile_kwargs
        )
        self._critic_loss_tensors = compile_fn(  # type: ignore[method-assign]
            self._critic_loss_tensors, **compile_kwargs
        )
        self._actor_loss_tensors = compile_fn(  # type: ignore[method-assign]
            self._actor_loss_tensors, **compile_kwargs
        )

    @staticmethod
    def _resolve_amp_dtype(amp_dtype: str, device_type: str) -> torch.dtype:
        """Map an ``amp_dtype`` string (case-insensitive) to a torch dtype.

        ``device_type`` is accepted but currently unused; ``"auto"`` always
        resolves to bfloat16.

        Raises:
            ValueError: for anything other than auto/fp16/bf16.
        """
        normalized = amp_dtype.lower()
        if normalized == "auto":
            return torch.bfloat16
        if normalized == "fp16":
            return torch.float16
        if normalized == "bf16":
            return torch.bfloat16
        raise ValueError("FlashSAC amp_dtype must be one of: auto, fp16, bf16")

    @staticmethod
    def _should_use_grad_scaler(
        use_amp: bool,
        device_type: str,
        amp_dtype: torch.dtype,
    ) -> bool:
        """True only for fp16 autocast on CUDA, the case that needs loss scaling."""
        return bool(use_amp) and device_type == "cuda" and amp_dtype == torch.float16

    def _maybe_normalize_obs(self, obs: torch.Tensor, *, update: bool) -> torch.Tensor:
        """Apply the observation normalizer (if any), optionally updating its stats."""
        if isinstance(self.obs_normalizer, nn.Identity):
            return obs
        return cast(torch.Tensor, self.obs_normalizer(obs, update=update))

    def _autocast(self):
        """Return an autocast context; it is a no-op when AMP is disabled."""
        return torch.autocast(
            device_type=self.device.type, dtype=self._amp_dtype, enabled=self.use_amp
        )

    def update_reward_stats(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
    ) -> None:
        """Feed transitions to the reward normalizer; no-op when it is disabled."""
        if self.reward_normalizer is None:
            return
        self.reward_normalizer.update_from_transitions(rewards, dones)

    @staticmethod
    def _set_requires_grad(module: nn.Module, requires_grad: bool) -> None:
        """Toggle ``requires_grad`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad_(requires_grad)

    def _critic_loss_tensors(
        self,
        next_q_values: torch.Tensor,
        next_q_log_probs_full: torch.Tensor,
        support: torch.Tensor,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        truncated: torch.Tensor,
        actor_entropy: torch.Tensor,
        pred_log_probs: torch.Tensor,
        gamma: float,
    ) -> torch.Tensor:
        """Cross-entropy between predicted and projected target value distributions.

        Each target bin value ``r + bootstrap * gamma * (z - actor_entropy)``
        is clipped to the support range and its probability mass is split
        linearly between the two neighbouring bins. ``bootstrap`` is 1 unless
        the transition is done and not truncated.

        Args:
            next_q_values / next_q_log_probs_full: Target-critic outputs for
                the next state; ``select_min_q_log_probs`` reduces them to one
                ``(batch, num_bins)`` log-distribution (presumably that of the
                critic with the smaller Q -- see that helper).
            support: Bin centres; assumed uniformly spaced, since only the
                first gap is used as the bin width.
            rewards / dones / truncated / actor_entropy: Per-sample values,
                reshaped to ``(batch, 1)``.
            pred_log_probs: Online-critic log-probs, broadcast against the
                ``(1, batch, num_bins)`` target.
            gamma: Discount for the bootstrap term.
        """
        next_q_log_probs = select_min_q_log_probs(next_q_values, next_q_log_probs_full)
        batch_size, num_bins = next_q_log_probs.shape
        support_view = support.view(1, -1)
        rewards = rewards.view(-1, 1)
        dones = dones.view(-1, 1)
        truncated = truncated.view(-1, 1)
        actor_entropy = actor_entropy.view(-1, 1)
        # Keep bootstrapping through time-limit truncations even if done is set.
        bootstrap = torch.clamp(1.0 - dones + truncated, 0.0, 1.0)
        support_min = support_view.min()
        support_max = support_view.max()
        target_bin_values = rewards + bootstrap * gamma * (support_view - actor_entropy)
        target_bin_values = torch.clamp(target_bin_values, support_min, support_max)
        bin_width = torch.clamp(support_view[0, 1] - support_view[0, 0], min=1e-8)
        offsets = (target_bin_values - support_min) / bin_width
        lower = torch.floor(offsets).long().clamp(0, num_bins - 1)
        upper = torch.ceil(offsets).long().clamp(0, num_bins - 1)
        # When lower == upper (exact hit or clamped edge) both scatters land on
        # the same bin, so the full mass is still preserved.
        frac = offsets - lower.float()
        probs = next_q_log_probs.exp()
        target_probs = torch.zeros(batch_size, num_bins, dtype=probs.dtype, device=probs.device)
        target_probs.scatter_add_(1, lower, probs * (1.0 - frac))
        target_probs.scatter_add_(1, upper, probs * frac)
        return cast(torch.Tensor, -(target_probs.unsqueeze(0) * pred_log_probs).sum(dim=-1).mean())

    def _actor_loss_tensors(
        self,
        log_probs: torch.Tensor,
        q_values: torch.Tensor,
        actions: torch.Tensor,
        expert_actions: torch.Tensor,
        temp_value: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(actor_loss, entropy_estimate)``.

        The loss is ``mean(alpha * log_pi - min(Q1, Q2))`` with ``alpha``
        detached. When ``actor_bc_alpha > 0``, it adds an MSE pull toward
        ``expert_actions``, scaled by the detached mean ``|min Q|``.
        """
        min_q = torch.min(q_values[0], q_values[1])
        actor_loss = (temp_value.detach() * log_probs - min_q).mean()
        if self.actor_bc_alpha > 0:
            bc_loss = torch.mean((actions - expert_actions) ** 2)
            actor_loss = actor_loss + self.actor_bc_alpha * min_q.abs().mean().detach() * bc_loss
        entropy = -log_probs.detach().mean()
        return actor_loss, entropy

    def update_critic(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
        """Run one distributional critic update on ``batch``.

        Reads keys ``obs``, ``actions``, ``rewards``, ``next_obs``, ``dones``,
        ``truncated``, ``critic`` and ``next_critic``. Also advances the
        observation-normalizer statistics (when enabled) with ``obs``.

        Returns:
            ``critic_loss`` and ``reward_scale_std`` (1.0 when reward
            normalization is disabled).
        """
        obs = batch["obs"].to(self.device)
        actions = batch["actions"].to(self.device)
        rewards = batch["rewards"].to(self.device)
        next_obs = batch["next_obs"].to(self.device)
        # Flags may be stored as bool; ``1.0 - bool_tensor`` raises inside the
        # target computation, so promote them to float up front.
        dones = batch["dones"].to(self.device).float()
        truncated = batch["truncated"].to(self.device).float()
        critic_obs = batch["critic"].to(self.device)
        critic_next_obs = batch["next_critic"].to(self.device)
        # Called only to update the running obs statistics; the critic consumes
        # the separate ``critic``/``next_critic`` observations.
        self._maybe_normalize_obs(obs, update=True)
        next_obs = self._maybe_normalize_obs(next_obs, update=False)
        if self.reward_normalizer is not None:
            rewards = self.reward_normalizer.normalize(rewards)
        gamma = self.gamma**self.n_step
        # Current and next observations go through each critic in one joint batch.
        obs_all = torch.cat([critic_obs, critic_next_obs], dim=0)
        with torch.no_grad():
            with self._autocast():
                next_actions, actor_info = self.actor(next_obs, training=False)
                actor_entropy = self.temperature().detach() * actor_info["log_prob"]
                act_all = torch.cat([actions, next_actions], dim=0)
                qs_all, q_info_all = self.target_critic(obs_all, act_all, training=True)
                next_q_values = qs_all.chunk(2, dim=1)[1]
                next_q_log_probs_full = q_info_all["log_prob"].chunk(2, dim=1)[1]
                support = cast(torch.Tensor, self.target_critic.predictor.support)
        with self._autocast():
            _, pred_info_all = self.critic(obs_all, act_all, training=True)
            pred_log_probs = pred_info_all["log_prob"].chunk(2, dim=1)[0]
            critic_loss = self._critic_loss_tensors(
                next_q_values,
                next_q_log_probs_full,
                support,
                rewards,
                dones,
                truncated,
                actor_entropy,
                pred_log_probs,
                gamma,
            )
        self.critic_optimizer.zero_grad(set_to_none=True)
        if self.scaler is not None:
            self.scaler.scale(critic_loss).backward()
            self.scaler.step(self.critic_optimizer)
            self.scaler.update()
        else:
            critic_loss.backward()
            self.critic_optimizer.step()
        self.critic_scheduler.step()
        self.critic.normalize_parameters()
        return {
            "critic_loss": float(critic_loss.detach().cpu()),
            "reward_scale_std": float(
                torch.sqrt(self.reward_normalizer.rms.var).detach().cpu()
                if self.reward_normalizer is not None
                else torch.tensor(1.0)
            ),
        }

    def update_actor(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
        """Run one actor step followed by one temperature step.

        Reads keys ``obs``, ``next_obs``, ``actions`` (used as the
        behaviour-cloning target) and ``critic``.

        Returns:
            ``actor_loss``, ``actor_entropy``, ``temperature`` and
            ``temperature_loss`` as Python floats.
        """
        obs = batch["obs"].to(self.device)
        next_obs = batch["next_obs"].to(self.device)
        expert_actions = batch["actions"].to(self.device)
        critic_obs = batch["critic"].to(self.device)
        obs = self._maybe_normalize_obs(obs, update=False)
        next_obs = self._maybe_normalize_obs(next_obs, update=False)
        # The actor runs on current + next obs jointly; only the first half is used.
        obs_all = torch.cat([obs, next_obs], dim=0)
        with self._autocast():
            actions_all, actor_info_all = self.actor(obs_all, training=True)
            actions = actions_all.chunk(2, dim=0)[0]
            log_probs = actor_info_all["log_prob"].chunk(2, dim=0)[0]
            # Freeze the critic for this forward so the actor's backward does
            # not accumulate gradients into it; always unfreeze, even on error.
            self._set_requires_grad(self.critic, False)
            try:
                q_values, _ = self.critic(critic_obs, actions, training=False)
            finally:
                self._set_requires_grad(self.critic, True)
            actor_loss, entropy = self._actor_loss_tensors(
                log_probs, q_values, actions, expert_actions, self.temperature()
            )
        self.actor_optimizer.zero_grad(set_to_none=True)
        if self.scaler is not None:
            self.scaler.scale(actor_loss).backward()
            self.scaler.step(self.actor_optimizer)
            self.scaler.update()
        else:
            actor_loss.backward()
            self.actor_optimizer.step()
        self.actor_scheduler.step()
        self.actor.normalize_parameters()
        # Temperature step: the gradient w.r.t. alpha is (entropy - target).
        temp_value = self.temperature()
        temp_loss = temp_value * (entropy - self.target_entropy)
        self.temperature_optimizer.zero_grad(set_to_none=True)
        temp_loss.backward()
        self.temperature_optimizer.step()
        self.temperature_scheduler.step()
        return {
            "actor_loss": float(actor_loss.detach().cpu()),
            "actor_entropy": float(entropy.detach().cpu()),
            "temperature": float(temp_value.detach().cpu()),
            "temperature_loss": float(temp_loss.detach().cpu()),
        }

    def soft_update_target(self) -> None:
        """Polyak-average critic parameters into the target critic.

        ``target <- (1 - tau) * target + tau * online``. Only parameters are
        blended; target-critic buffers are not touched here.
        """
        with torch.no_grad():
            for target_param, param in zip(
                self.target_critic.parameters(), self.critic.parameters()
            ):
                target_param.data.mul_(1.0 - self.tau).add_(param.data, alpha=self.tau)

    def get_state_dict(self) -> dict[str, Any]:
        """Collect everything needed to resume training into one dict.

        Includes the AMP ``GradScaler`` state (``None`` when no scaler is
        used), so the loss scale survives a resume.
        """
        return {
            "actor": self.actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
            "temperature": self.temperature.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "critic_optimizer": self.critic_optimizer.state_dict(),
            "temperature_optimizer": self.temperature_optimizer.state_dict(),
            "actor_scheduler": self.actor_scheduler.state_dict(),
            "critic_scheduler": self.critic_scheduler.state_dict(),
            "temperature_scheduler": self.temperature_scheduler.state_dict(),
            "obs_normalizer": (
                self.obs_normalizer.state_dict()
                if hasattr(self.obs_normalizer, "state_dict")
                else None
            ),
            "reward_normalizer": (
                self.reward_normalizer.state_dict() if self.reward_normalizer is not None else None
            ),
            "scaler": self.scaler.state_dict() if self.scaler is not None else None,
            "update_count": self.update_count,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore a dict produced by :meth:`get_state_dict`.

        The normalizer and scaler entries are optional (empty or missing
        values are skipped), so older checkpoints without ``"scaler"`` still
        load.
        """
        self.actor.load_state_dict(state_dict["actor"])
        self.critic.load_state_dict(state_dict["critic"])
        self.target_critic.load_state_dict(state_dict["target_critic"])
        self.temperature.load_state_dict(state_dict["temperature"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        self.critic_optimizer.load_state_dict(state_dict["critic_optimizer"])
        self.temperature_optimizer.load_state_dict(state_dict["temperature_optimizer"])
        self.actor_scheduler.load_state_dict(state_dict["actor_scheduler"])
        self.critic_scheduler.load_state_dict(state_dict["critic_scheduler"])
        self.temperature_scheduler.load_state_dict(state_dict["temperature_scheduler"])
        if state_dict.get("obs_normalizer") and hasattr(self.obs_normalizer, "load_state_dict"):
            self.obs_normalizer.load_state_dict(state_dict["obs_normalizer"])
        if self.reward_normalizer is not None and state_dict.get("reward_normalizer"):
            self.reward_normalizer.load_state_dict(state_dict["reward_normalizer"])
        scaler_state = state_dict.get("scaler")
        # GradScaler.load_state_dict rejects an empty dict, hence the truthiness check.
        if self.scaler is not None and scaler_state:
            self.scaler.load_state_dict(scaler_state)
        self.update_count = int(state_dict.get("update_count", 0))