Source code for unilab.algos.torch.common.stability
"""Numerical stability utilities for RL training."""
import torch
[docs]
def check_nan_loss(
loss: torch.Tensor, default_metrics: dict
) -> tuple[torch.Tensor | None, dict | None]:
"""Check if loss contains NaN or Inf values.
Args:
loss: Loss tensor to check
default_metrics: Default metric values to return if NaN detected
Returns:
(loss, None) if valid, (None, nan_metrics) if invalid
"""
if torch.isnan(loss) or torch.isinf(loss):
nan_metrics = {k: float("nan") for k in default_metrics}
return None, nan_metrics
return loss, None
[docs]
def clip_gradients(parameters, max_norm: float = 10.0):
    """Rescale gradients so their combined norm does not exceed a limit.

    Thin wrapper around ``torch.nn.utils.clip_grad_norm_``; nothing is
    returned to the caller.

    Args:
        parameters: Model parameters, forwarded unchanged to
            ``clip_grad_norm_``
        max_norm: Upper bound on the gradient norm
    """
    clip_fn = torch.nn.utils.clip_grad_norm_
    # The value returned by clip_grad_norm_ is intentionally discarded.
    clip_fn(parameters, max_norm=max_norm)
[docs]
def safe_tensor(
    tensor: torch.Tensor, nan_value: float = 0.0, clamp_range: tuple = (-10.0, 10.0)
) -> torch.Tensor:
    """Bound a tensor to a range and substitute a fixed value for NaNs.

    Clamping runs first, and NaN replacement is applied to the clamped
    result, so ``nan_value`` itself is not clamped.

    Args:
        tensor: Input tensor
        nan_value: Replacement written wherever the tensor holds NaN
        clamp_range: (min, max) bounds applied to the values

    Returns:
        The clamped tensor with NaN entries replaced
    """
    lower = clamp_range[0]
    upper = clamp_range[1]
    bounded = torch.clamp(tensor, lower, upper)
    # NaN replacement happens after clamping, on the clamped result.
    cleaned = torch.nan_to_num(bounded, nan=nan_value)
    return cleaned