Source code for unilab.algos.torch.common.normalization
"""Observation normalization for RL training."""
from __future__ import annotations
import torch
import torch.nn as nn
[docs]
class EmpiricalNormalization(nn.Module):
"""Normalize mean and variance of observations using running statistics."""
_mean: torch.Tensor
_var: torch.Tensor
_std: torch.Tensor
count: torch.Tensor
[docs]
def __init__(self, shape, device, eps=1e-2):
super().__init__()
self.eps = eps
self.device = device
self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0).to(device))
self.register_buffer("_var", torch.ones(shape).unsqueeze(0).to(device))
self.register_buffer("_std", torch.ones(shape).unsqueeze(0).to(device))
self.register_buffer("count", torch.tensor(0, dtype=torch.long).to(device))
@property
def mean(self):
return self._mean.squeeze(0).clone()
@property
def std(self):
return self._std.squeeze(0).clone()
[docs]
@torch.no_grad()
def forward(self, x: torch.Tensor, center: bool = True, update: bool = True) -> torch.Tensor:
if self.training and update:
self.update(x)
if center:
return torch.as_tensor((x - self._mean) / (self._std + self.eps))
else:
return torch.as_tensor(x / (self._std + self.eps))
[docs]
def update(self, x):
batch_size = x.shape[0]
batch_mean = torch.mean(x, dim=0, keepdim=True)
batch_var = torch.var(x, dim=0, keepdim=True, unbiased=False)
new_count = self.count + batch_size
# Welford's online algorithm
delta = batch_mean - self._mean
self._mean.copy_(self._mean + delta * (batch_size / new_count))
delta2 = batch_mean - self._mean
m_a = self._var * self.count
m_b = batch_var * batch_size
M2 = m_a + m_b + delta2.pow(2) * (self.count * batch_size / new_count)
self._var.copy_(M2 / new_count)
self._std.copy_(self._var.sqrt())
self.count.copy_(new_count)
[docs]
def inverse(self, y):
return y * (self._std + self.eps) + self._mean