# Source code for unilab.base.augmentation
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
import torch
# One layout entry: a (str, int) pair. Presumably (observation term name,
# term width) — NOTE(review): confirm against the adapters that build it.
_SymmetryObsEntry = tuple[str, int]
# Variable-length tuple of layout entries describing an observation vector.
SymmetryObsLayout = tuple[_SymmetryObsEntry, ...]
class SymmetryAugmentation(Protocol):
    """Runtime symmetry augmentation contract owned by env/backend adapters.

    This is a structural ``typing.Protocol`` interface: any object exposing
    the members below satisfies it; no inheritance is required.

    Attributes:
        batch_multiplier: Integer factor declared by the implementation.
            Presumably the ratio between the augmented batch size and the
            input batch size — NOTE(review): confirm against concrete adapters.
    """

    batch_multiplier: int

    def augment_obs_and_actions(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        *,
        obs_group: str = "obs",
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the adapter's symmetry augmentation to observations and actions.

        Args:
            obs: Observation tensor. NOTE(review): shape/layout is not fixed
                by this protocol — assumed ``(batch, obs_dim)``; confirm.
            actions: Action tensor paired with ``obs``.
            obs_group: Keyword-only name selecting which observation group is
                being augmented; defaults to ``"obs"``.

        Returns:
            A pair of tensors, presumably ``(augmented_obs, augmented_actions)``
            with the batch dimension scaled by ``batch_multiplier`` — confirm.
        """
        ...

    # NOTE(review): the ``def`` line of this method was missing from the source
    # as received; the name ``augment_obs`` is reconstructed from the sibling
    # ``augment_obs_and_actions`` — confirm against the upstream file.
    def augment_obs(
        self,
        obs: torch.Tensor,
        *,
        obs_group: str = "obs",
    ) -> torch.Tensor:
        """Apply the adapter's symmetry augmentation to observations only.

        Args:
            obs: Observation tensor (same layout assumptions as
                ``augment_obs_and_actions`` — confirm).
            obs_group: Keyword-only name selecting which observation group is
                being augmented; defaults to ``"obs"``.

        Returns:
            A single tensor, presumably the augmented observations.
        """
        ...