Source code for unilab.algos.torch.common.networks

"""Neural network architectures for RL algorithms."""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class DistributionalQNetwork(nn.Module):
    """Single distributional Q-network (C51 variant).

    Architecture: Linear→ReLU → Linear→ReLU → Linear→ReLU → Linear

    The three hidden layers have widths ``hidden_dim``, ``hidden_dim // 2``
    and ``hidden_dim // 4``. The network consumes the concatenation of an
    observation and an action and outputs ``num_atoms`` logits over the
    value distribution.
    """

    def __init__(
        self,
        obs_dim: int,
        n_act: int,
        num_atoms: int,
        v_min: float,
        v_max: float,
        hidden_dim: int,
        device: Optional[torch.device] = None,
    ):
        """Build the MLP and remember the support bounds.

        Args:
            obs_dim: Size of the observation vector.
            n_act: Size of the action vector.
            num_atoms: Number of logits produced by the last layer. Must be
                at least 2 for :meth:`projection`, which divides by
                ``num_atoms - 1``.
            v_min: Lower clamp bound used by :meth:`projection`.
            v_max: Upper clamp bound used by :meth:`projection`.
            hidden_dim: Width of the first hidden layer.
            device: Device passed to every ``nn.Linear`` on construction.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + n_act, hidden_dim, device=device),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2, device=device),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4, device=device),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, num_atoms, device=device),
        )
        self.v_min = v_min
        self.v_max = v_max
        self.num_atoms = num_atoms

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Return the unnormalised atom logits for ``(obs, actions)``.

        ``obs`` and ``actions`` are concatenated along dimension 1 before
        being fed to the MLP, so both must agree on every other dimension.
        """
        x = torch.cat([obs, actions], dim=1)
        x = self.net(x)
        # as_tensor is a no-op on an existing tensor; presumably kept so that
        # static type checkers see a concrete Tensor return type.
        return torch.as_tensor(x)

    def projection(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        bootstrap: torch.Tensor,
        discount: torch.Tensor,
        q_support: torch.Tensor,
        device: torch.device,
    ) -> torch.Tensor:
        """Categorical projection (Bellman update on the distribution support).

        Every support atom is shifted to
        ``rewards + bootstrap * discount * q_support``, clamped to
        ``[v_min, v_max]``, and the probability this network assigns to that
        atom for ``(obs, actions)`` is split linearly between the two
        neighbouring atoms of the fixed support.

        Args:
            obs: Observations fed to :meth:`forward`.
            actions: Actions fed to :meth:`forward`.
            rewards: Per-sample rewards; its first dimension defines the
                batch size. Expected to be 1-D (it is broadcast against
                ``q_support`` after ``unsqueeze(1)``) -- confirm with caller.
            bootstrap: Per-sample multiplier on the discounted support,
                same layout as ``rewards``.
            discount: Per-sample discount factor, same layout as ``rewards``.
            q_support: Atom values; presumably
                ``linspace(v_min, v_max, num_atoms)`` -- verify against caller.
            device: Device on which the flat row offsets are created; it
                should be the device the other tensors live on.

        Returns:
            Tensor with the same shape as the softmaxed network output,
            holding the projected distribution. Each row keeps the total
            probability mass of the corresponding softmax row.
        """
        delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
        batch_size = rewards.shape[0]

        target_z = (
            rewards.unsqueeze(1)
            + bootstrap.unsqueeze(1) * discount.unsqueeze(1) * q_support
        )
        target_z = target_z.clamp(self.v_min, self.v_max)

        # Fractional atom index of each shifted atom and its two neighbours.
        b = (target_z - self.v_min) / delta_z
        lower = torch.floor(b).long()
        upper = torch.ceil(b).long()

        # When b lands exactly on an atom, floor == ceil and both
        # interpolation weights would be zero, dropping the mass. Widen the
        # pair by one atom (downwards, or upwards at index 0) so the weights
        # sum to one again.
        is_int = lower == upper
        l_mask = is_int & (lower > 0)
        u_mask = is_int & (lower == 0)
        lower = torch.where(l_mask, lower - 1, lower)
        upper = torch.where(u_mask, upper + 1, upper)

        next_dist = F.softmax(self.forward(obs, actions), dim=1)
        proj_dist = torch.zeros_like(next_dist)

        # Start index of every batch row in the flattened (batch * atoms)
        # view. Built with integer arithmetic: a float linspace truncated to
        # long loses exactness once batch_size * num_atoms exceeds 2**24 and
        # would scatter mass into the wrong row.
        offset = (
            torch.arange(batch_size, device=device, dtype=torch.long)
            * self.num_atoms
        ).unsqueeze(1)

        # view(-1) aliases proj_dist, so the in-place adds fill it directly.
        proj_dist.view(-1).index_add_(
            0, (lower + offset).view(-1), (next_dist * (upper.float() - b)).view(-1)
        )
        proj_dist.view(-1).index_add_(
            0, (upper + offset).view(-1), (next_dist * (b - lower.float())).view(-1)
        )
        return proj_dist
class Critic(nn.Module):
    """Twin distributional Q-networks for off-policy RL (SAC/TD3).

    Holds two independently initialised ``DistributionalQNetwork`` instances
    built from the same hyper-parameters, plus the shared atom support
    ``q_support`` registered as a buffer.
    """

    q_support: torch.Tensor

    def __init__(
        self,
        obs_dim: int,
        n_act: int,
        num_atoms: int,
        v_min: float,
        v_max: float,
        hidden_dim: int,
        device: Optional[torch.device] = None,
    ):
        """Create both Q-networks and the ``linspace(v_min, v_max, num_atoms)`` support."""
        super().__init__()

        # Both heads share one configuration; only their weights differ.
        qnet_kwargs = dict(
            obs_dim=obs_dim,
            n_act=n_act,
            num_atoms=num_atoms,
            v_min=v_min,
            v_max=v_max,
            hidden_dim=hidden_dim,
            device=device,
        )
        self.qnet1 = DistributionalQNetwork(**qnet_kwargs)
        self.qnet2 = DistributionalQNetwork(**qnet_kwargs)

        support = torch.linspace(v_min, v_max, num_atoms, device=device)
        self.register_buffer("q_support", support)
        # Construction-time device as passed in (may be None).
        self.device = device

    def forward(self, obs: torch.Tensor, actions: torch.Tensor):
        """Return the pair ``(qnet1 output, qnet2 output)`` for ``(obs, actions)``."""
        first = self.qnet1(obs, actions)
        second = self.qnet2(obs, actions)
        return first, second

    def projection(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        bootstrap: torch.Tensor,
        discount: torch.Tensor,
    ):
        """Projection operation using both Q-networks.

        Each network's ``projection`` is called with the shared support and
        the device that support currently lives on; the two results are
        returned as a ``(q1, q2)`` tuple.
        """
        support = self.q_support
        results = []
        for qnet in (self.qnet1, self.qnet2):
            results.append(
                qnet.projection(
                    obs,
                    actions,
                    rewards,
                    bootstrap,
                    discount,
                    support,
                    support.device,
                )
            )
        q1_proj, q2_proj = results
        return q1_proj, q2_proj

    def get_value(self, probs: torch.Tensor) -> torch.Tensor:
        """Calculate value from probability distribution using support.

        Multiplies ``probs`` by the support and sums over dimension 1.
        """
        weighted = probs * self.q_support
        return weighted.sum(dim=1)