Source code for unilab.algos.torch.flash_sac.update

"""FlashSAC update helpers."""

from __future__ import annotations

import math

import torch


[docs] def build_lr_lambda( init_lr: float, peak_lr: float, end_lr: float, warmup_steps: int, decay_steps: int, ): init_ratio = init_lr / peak_lr if peak_lr > 0 else 1.0 end_ratio = end_lr / peak_lr if peak_lr > 0 else 1.0 warmup_steps = max(int(warmup_steps), 0) decay_steps = max(int(decay_steps), warmup_steps + 1) def schedule(step: int) -> float: if warmup_steps > 0 and step < warmup_steps: progress = float(step) / float(max(warmup_steps, 1)) return init_ratio + (1.0 - init_ratio) * progress progress = min(max((step - warmup_steps) / float(decay_steps - warmup_steps), 0.0), 1.0) cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) return end_ratio + (1.0 - end_ratio) * cosine return schedule
[docs] def select_min_q_log_probs( next_q_values: torch.Tensor, next_q_log_probs: torch.Tensor, ) -> torch.Tensor: num_bins = next_q_log_probs.shape[-1] min_indices = next_q_values.argmin(dim=0) gather_index = min_indices[None, :, None].expand(1, -1, num_bins) return torch.gather(next_q_log_probs, dim=0, index=gather_index)[0]
[docs] def compute_categorical_td_target( support: torch.Tensor, target_log_probs: torch.Tensor, reward: torch.Tensor, dones: torch.Tensor, truncated: torch.Tensor, actor_entropy: torch.Tensor, gamma: float, ) -> torch.Tensor: batch_size, num_bins = target_log_probs.shape support = support.view(1, -1) reward = reward.view(-1, 1) dones = dones.view(-1, 1) truncated = truncated.view(-1, 1) actor_entropy = actor_entropy.view(-1, 1) bootstrap = torch.clamp(1.0 - dones + truncated, 0.0, 1.0) target_bin_values = reward + bootstrap * gamma * (support - actor_entropy) target_bin_values = torch.clamp(target_bin_values, float(support.min()), float(support.max())) bin_width = float(support[0, 1] - support[0, 0]) offsets = (target_bin_values - float(support.min())) / max(bin_width, 1e-8) lower = torch.floor(offsets).long().clamp(0, num_bins - 1) upper = torch.ceil(offsets).long().clamp(0, num_bins - 1) frac = offsets - lower.float() probs = target_log_probs.exp() target_probs = torch.zeros(batch_size, num_bins, dtype=probs.dtype, device=probs.device) target_probs.scatter_add_(1, lower, probs * (1.0 - frac)) target_probs.scatter_add_(1, upper, probs * frac) return target_probs
[docs] def resolve_target_entropy( action_dim: int, target_sigma: float, target_entropy: float | None, ) -> float: if target_entropy is not None: return float(target_entropy) sigma = max(float(target_sigma), 1e-6) per_dim_entropy = 0.5 * math.log(2.0 * math.pi * math.e * sigma * sigma) return float(action_dim) * per_dim_entropy