from __future__ import annotations
import math
import statistics
import time
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from tensordict import TensorDict
from unilab.algos.torch.common.normalization import EmpiricalNormalization
from unilab.algos.torch.hora.models import (
HoraActorModel,
HoraCoreOutput,
HoraSharedActorCritic,
ProprioAdaptTConv,
)
from unilab.algos.torch.hora.sac_models import HoraSACActor
class HoraSACDistillShared(nn.Module):
    """SAC-teacher-compatible HORA stage-2 shared actor.

    The module owns the privileged-information encoder, projection, actor
    trunk and action-mean head of a freshly constructed ``HoraSACActor`` and
    adds a ``ProprioAdaptTConv`` student encoder that produces a latent of the
    same width from a proprioceptive history.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        *,
        priv_info_dim: int,
        hidden_dim: int = 512,
        priv_info_embed_dim: int = 9,
        priv_mlp_hidden_dims: list[int] | tuple[int, ...] = (256, 128, 9),
        use_layer_norm: bool = True,
        proprio_hist_len: int = 30,
        proprio_frame_dim: int | None = None,
        device: torch.device | str = "cpu",
    ) -> None:
        """Build the teacher-shaped sub-modules and the student encoder.

        Args:
            obs_dim: Size of the ``"actor"`` observation vector.
            action_dim: Number of action dimensions.
            priv_info_dim: Size of the privileged-information vector.
            hidden_dim: Forwarded to ``HoraSACActor``.
            priv_info_embed_dim: Width of the privileged latent; also the
                output width requested from the student encoder.
            priv_mlp_hidden_dims: Forwarded to ``HoraSACActor``.
            use_layer_norm: Forwarded to ``HoraSACActor``.
            proprio_hist_len: Stored on the module as ``proprio_hist_len``.
            proprio_frame_dim: Per-frame width of the proprioceptive history.
                Falls back to ``obs_dim // 3`` when ``None``.
            device: Forwarded to ``HoraSACActor``.
        """
        super().__init__()
        # A throwaway teacher is built only to borrow its sub-modules.
        teacher = HoraSACActor(
            obs_dim=obs_dim,
            priv_info_dim=priv_info_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim,
            priv_info_embed_dim=priv_info_embed_dim,
            priv_mlp_hidden_dims=priv_mlp_hidden_dims,
            use_layer_norm=use_layer_norm,
            device=device,
        )
        self.obs_dim = int(obs_dim)
        self.action_dim = int(action_dim)
        self.priv_info_dim = int(priv_info_dim)
        self.priv_info_embed_dim = int(priv_info_embed_dim)
        self.proprio_hist_len = int(proprio_hist_len)
        self.proprio_frame_dim = (
            int(proprio_frame_dim) if proprio_frame_dim is not None else self.obs_dim // 3
        )
        self.obs_normalizer = nn.Identity()
        self.priv_encoder = teacher.priv_encoder
        self.priv_projection = teacher.priv_projection
        self.actor_trunk = teacher.actor_trunk
        self.action_mean_head = teacher.action_mean_head
        self.adapt_tconv = ProprioAdaptTConv(self.proprio_frame_dim, self.priv_info_embed_dim)

    def load_teacher_actor_state_dict(self, actor_state: dict[str, torch.Tensor]) -> None:
        """Copy teacher actor weights into every non-student parameter.

        Keys of ``actor_state`` that this module does not own are ignored, and
        ``adapt_tconv.*`` entries are never taken from the teacher.

        Args:
            actor_state: State dict of the teacher actor.

        Raises:
            ValueError: If any non-``adapt_tconv`` key of this module is absent
                from ``actor_state``.
        """
        own_state = self.state_dict()
        teacher_state = {
            key: value
            for key, value in actor_state.items()
            if key in own_state and not key.startswith("adapt_tconv.")
        }
        missing = sorted(
            key
            for key in own_state
            if not key.startswith("adapt_tconv.") and key not in teacher_state
        )
        if missing:
            raise ValueError(f"HORA-SAC teacher checkpoint is missing actor keys: {missing}")
        # strict=False because the student encoder keys are intentionally absent.
        self.load_state_dict(teacher_state, strict=False)

    def encode_privileged_info(self, priv_info: torch.Tensor) -> torch.Tensor:
        """Return the tanh-squashed latent produced from privileged info."""
        return torch.tanh(self.priv_projection(self.priv_encoder(priv_info)))

    def encode_proprio_history(self, proprio_hist: torch.Tensor) -> torch.Tensor:
        """Return the tanh-squashed latent produced from proprioceptive history."""
        return torch.tanh(self.adapt_tconv(proprio_hist))

    def _zero_privileged_latent(
        self,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Return a ``(batch_size, priv_info_embed_dim)`` tensor of zeros."""
        return torch.zeros((batch_size, self.priv_info_embed_dim), device=device, dtype=dtype)

    def policy_mean(
        self,
        obs: TensorDict,
        *,
        prefer_student: bool,
        require_privileged_target: bool = True,
    ) -> tuple[torch.Tensor, HoraCoreOutput]:
        """Compute the pre-squash action mean and the latents behind it.

        Args:
            obs: Observation mapping with an ``"actor"`` entry and, depending
                on the flags, ``"priv_info"`` and/or ``"proprio_hist"``.
            prefer_student: When true the trunk is conditioned on the latent
                from ``"proprio_hist"``; otherwise on the privileged latent.
            require_privileged_target: When true a missing ``"priv_info"`` is
                an error; when false (and ``prefer_student`` is true) a zero
                tensor is used as the privileged target instead.

        Returns:
            ``(mean, core_output)`` where ``mean`` is the action-mean head
            output and ``core_output`` carries the policy observation, trunk
            latent, the latent actually used and the privileged target.

        Raises:
            ValueError: If ``"priv_info"`` is needed but missing, or if
                ``prefer_student`` is true and ``"proprio_hist"`` is missing.
        """
        actor_obs = obs["actor"]
        # An explicit default keeps the missing-key path independent of the
        # container's own no-default behaviour for ``get``.
        priv_info = obs.get("priv_info", None)
        if priv_info is None:
            if require_privileged_target or not prefer_student:
                raise ValueError("priv_info is required for HORA-SAC distillation")
            privileged_target = self._zero_privileged_latent(
                actor_obs.shape[0],
                actor_obs.device,
                actor_obs.dtype,
            )
        else:
            privileged_target = self.encode_privileged_info(priv_info)

        if prefer_student:
            proprio_hist = obs.get("proprio_hist", None)
            if proprio_hist is None:
                raise ValueError("proprio_hist is required for HORA-SAC student inference")
            privileged_latent = self.encode_proprio_history(proprio_hist)
        else:
            privileged_latent = privileged_target

        trunk_latent = self.actor_trunk(torch.cat([actor_obs, privileged_latent], dim=-1))
        mean = self.action_mean_head(trunk_latent)
        return (
            mean,
            HoraCoreOutput(
                policy_obs=actor_obs,
                trunk_latent=trunk_latent,
                privileged_latent=privileged_latent,
                privileged_target=privileged_target,
            ),
        )
class HoraSACDistillActor(nn.Module):
    """Policy-facing wrapper around a HORA-SAC stage-2 shared module."""

    is_recurrent: bool = False

    def __init__(self, shared: HoraSACDistillShared) -> None:
        """Wrap ``shared`` and default to the student (history) latent."""
        super().__init__()
        self.prefer_student = True
        self.shared = shared

    def forward(
        self,
        obs: TensorDict,
        masks: torch.Tensor | None = None,
        hidden_state=None,
        stochastic_output: bool = False,
    ) -> torch.Tensor:
        """Return the tanh-squashed action mean for ``obs``.

        ``masks``, ``hidden_state`` and ``stochastic_output`` are accepted but
        not used. Privileged info is not required on this path.
        """
        policy_out = self.shared.policy_mean(
            obs,
            prefer_student=self.prefer_student,
            require_privileged_target=False,
        )
        action_mean = policy_out[0]
        return action_mean.tanh()

    def load_sac_teacher_actor_state_dict(self, actor_state: dict[str, torch.Tensor]) -> None:
        """Delegate teacher weight loading to the wrapped shared module."""
        self.shared.load_teacher_actor_state_dict(actor_state)
# Runtime union of the two actor classes that the stage-2 builder in this
# module can return (``HoraActorModel`` or ``HoraSACDistillActor``).
DistillActor = HoraActorModel | HoraSACDistillActor
@dataclass
class HoraDistillStats:
    """Running counters and episode statistics for a distillation run."""

    # Total environment steps taken so far.
    agent_steps: int = 0
    # Best reward seen so far; starts at negative infinity.
    best_reward: float = -math.inf
    # Mean reward and episode length stay NaN until a value is recorded.
    mean_reward: float = math.nan
    mean_episode_length: float = math.nan
def build_student_actor_and_normalizer(
    env,
    cfg: DictConfig,
    *,
    device: torch.device,
) -> tuple[DistillActor, EmpiricalNormalization]:
    """Construct the stage-2 student actor and its history normalizer.

    Dimensions are read from one call to ``env.get_observations()``. The
    ``teacher_arch`` entry of ``cfg.algo.model`` selects the architecture:
    ``"hora_sac"`` builds a ``HoraSACDistillActor``; anything else builds a
    ``HoraActorModel`` on top of ``HoraSharedActorCritic``.

    Args:
        env: Environment exposing ``get_observations()`` and ``num_actions``.
        cfg: Run configuration; only ``cfg.algo.model`` is read here.
        device: Device on which both returned modules are placed.

    Returns:
        ``(actor, hist_normalizer)`` where the normalizer is an
        ``EmpiricalNormalization`` built from the per-sample shape of the
        ``"proprio_hist"`` observation.
    """
    sample_obs = env.get_observations()
    actor_dim = int(sample_obs["actor"].shape[-1])
    priv_info_dim = int(sample_obs["priv_info"].shape[-1])
    # Drop the leading batch axis; the rest is indexed as [0]=length, [1]=frame width.
    proprio_hist_shape = sample_obs["proprio_hist"].shape[1:]

    model_cfg = OmegaConf.to_container(cfg.algo.model, resolve=True)
    assert isinstance(model_cfg, dict)

    actor: DistillActor
    if model_cfg.get("teacher_arch") == "hora_sac":
        sac_core = HoraSACDistillShared(
            obs_dim=actor_dim,
            action_dim=int(env.num_actions),
            priv_info_dim=priv_info_dim,
            hidden_dim=int(model_cfg.get("actor_hidden_dim", 512)),
            priv_info_embed_dim=int(model_cfg.get("priv_info_embed_dim", priv_info_dim)),
            priv_mlp_hidden_dims=model_cfg.get("priv_mlp_hidden_dims", [256, 128, 9]),
            use_layer_norm=bool(model_cfg.get("use_layer_norm", True)),
            proprio_hist_len=int(proprio_hist_shape[0]),
            proprio_frame_dim=int(proprio_hist_shape[1]),
            device=device,
        ).to(device)
        actor = cast(HoraSACDistillActor, HoraSACDistillActor(sac_core).to(device))
    else:
        ppo_core = HoraSharedActorCritic(
            obs_dim=actor_dim,
            action_dim=int(env.num_actions),
            priv_info_dim=priv_info_dim,
            actor_hidden_dims=model_cfg.get("hidden_dims", [512, 256, 128]),
            activation=model_cfg.get("activation", "elu"),
            obs_normalization=model_cfg.get("obs_normalization", True),
            distribution_cfg=model_cfg.get(
                "distribution_cfg", {"init_std": 1.0, "std_type": "scalar"}
            ),
            priv_info_embed_dim=model_cfg.get("priv_info_embed_dim", priv_info_dim),
            priv_mlp_hidden_dims=model_cfg.get("priv_mlp_hidden_dims", [256, 128, 8]),
            use_student_encoder=True,
            proprio_hist_len=int(proprio_hist_shape[0]),
            proprio_frame_dim=int(proprio_hist_shape[1]),
        ).to(device)
        actor = cast(
            HoraActorModel,
            HoraActorModel(
                sample_obs,
                {"actor": ["actor"], "critic": ["actor"]},
                "actor",
                int(env.num_actions),
                shared_model=ppo_core,
                use_student_encoder=True,
            ).to(device),
        )

    # Both architectures use the same normalizer over the history tensor.
    return actor, EmpiricalNormalization(proprio_hist_shape, device=device)
def load_teacher_actor_weights(
    actor: nn.Module,
    teacher_checkpoint: str | Path,
    *,
    teacher_algo_family: str,
    device: torch.device,
) -> None:
    """Load stage-1 teacher actor weights into the stage-2 student actor.

    Args:
        actor: Student actor. For the ``"sac"`` family it must expose
            ``load_sac_teacher_actor_state_dict``; otherwise its own
            ``load_state_dict`` is used with ``strict=False``.
        teacher_checkpoint: Path of the teacher checkpoint file.
        teacher_algo_family: One of ``"ppo"``, ``"appo"`` or ``"sac"``.
        device: ``map_location`` used when reading the checkpoint.

    Raises:
        ValueError: If the family is unsupported, the checkpoint lacks the
            expected actor entry, or the actor cannot accept SAC weights.
    """
    family = str(teacher_algo_family)
    state_key_by_family = {
        "ppo": "actor_state_dict",
        "appo": "actor",
    }
    # Validate the family before touching the file so a configuration mistake
    # is reported directly instead of after (or instead of) a checkpoint read.
    if family != "sac" and family not in state_key_by_family:
        raise ValueError(
            "Unsupported HORA teacher algorithm family for distillation: "
            f"{teacher_algo_family!r}. Expected one of ['ppo', 'appo', 'sac']."
        )

    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints that come from a trusted source.
    checkpoint = torch.load(teacher_checkpoint, map_location=device, weights_only=False)

    if family == "sac":
        actor_state = checkpoint.get("actor")
        if actor_state is None:
            raise ValueError(
                "Checkpoint does not contain the expected HORA-SAC actor weights. "
                f"checkpoint={teacher_checkpoint}"
            )
        load_sac = getattr(actor, "load_sac_teacher_actor_state_dict", None)
        if load_sac is None:
            raise ValueError("Selected distillation actor does not support HORA-SAC weights.")
        load_sac(actor_state)
        return

    actor_state_key = state_key_by_family[family]
    actor_state = checkpoint.get(actor_state_key)
    if actor_state is None:
        raise ValueError(
            "Checkpoint does not contain the expected teacher actor weights. "
            f"algo_family={teacher_algo_family!r} expected_key={actor_state_key!r} "
            f"checkpoint={teacher_checkpoint}"
        )
    # strict=False: the student owns parameters the teacher never had.
    actor.load_state_dict(actor_state, strict=False)
def load_distilled_checkpoint(
    actor: nn.Module,
    hist_normalizer: EmpiricalNormalization,
    checkpoint_path: str | Path,
    *,
    device: torch.device,
) -> dict[str, Any]:
    """Restore a stage-2 checkpoint into ``actor`` and ``hist_normalizer``.

    Args:
        actor: Module that receives ``model_state_dict`` (strict loading).
        hist_normalizer: Receives ``history_normalizer`` when that entry is
            present in the checkpoint; left untouched otherwise.
        checkpoint_path: Path of the checkpoint file.
        device: ``map_location`` used when reading the checkpoint.

    Returns:
        The full checkpoint mapping as loaded from disk.

    Raises:
        ValueError: If the checkpoint has no ``model_state_dict`` entry.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; trusted files only.
    payload = torch.load(checkpoint_path, map_location=device, weights_only=False)

    weights = payload.get("model_state_dict")
    if weights is None:
        raise ValueError(f"Checkpoint does not contain model_state_dict: {checkpoint_path}")
    actor.load_state_dict(weights, strict=True)

    normalizer_state = payload.get("history_normalizer")
    if normalizer_state is not None:
        hist_normalizer.load_state_dict(normalizer_state)

    return cast(dict[str, Any], payload)
class HoraDistillationTrainer:
    """Stage-2 HORA latent distillation trainer.

    Only parameters whose name contains ``adapt_tconv`` are optimized. The
    loss is the mean squared error between the student latent (computed from
    the normalized proprioceptive history) and the detached privileged latent.
    """

    def __init__(
        self,
        env,
        cfg: DictConfig,
        *,
        device: str,
        log_dir: str | Path,
        teacher_checkpoint: str | Path,
        teacher_algo_family: str,
        teacher_metadata: dict[str, Any] | None = None,
        distill_runtime_cfg: DictConfig,
        logger,
    ) -> None:
        """Build the student, optimizer and bookkeeping, then load the teacher.

        Args:
            env: Environment providing ``num_envs``, ``reset`` and ``step``.
            cfg: Run configuration (``cfg.algo`` and ``training.logger`` are read).
            device: Torch device string used for the model and buffers.
            log_dir: Directory for checkpoints and the ``tb`` sub-directory.
            teacher_checkpoint: Path of the stage-1 teacher checkpoint.
            teacher_algo_family: Teacher family passed to the weight loader.
            teacher_metadata: Optional metadata copied into saved checkpoints.
            distill_runtime_cfg: Resolved and stored in saved checkpoints.
            logger: Logger with ``info`` / ``warning`` methods.
        """
        self.env = env
        self.cfg = cfg
        self.device = torch.device(device)
        self.log_dir = Path(log_dir)
        self.logger = logger
        self.teacher_checkpoint = Path(teacher_checkpoint)
        self.teacher_algo_family = str(teacher_algo_family)
        self.teacher_metadata = dict(teacher_metadata or {})
        self.distill_runtime_cfg = OmegaConf.to_container(distill_runtime_cfg, resolve=True)
        self.actor, self.hist_normalizer = build_student_actor_and_normalizer(
            env,
            cfg,
            device=self.device,
        )
        self.optimizer = torch.optim.Adam(
            self._trainable_parameters(), lr=float(cfg.algo.learning_rate)
        )
        self.stats = HoraDistillStats()
        self._reward_buffer: deque[float] = deque(maxlen=100)
        self._episode_length_buffer: deque[float] = deque(maxlen=100)
        self._step_reward = torch.zeros((env.num_envs,), dtype=torch.float32, device=self.device)
        self._step_length = torch.zeros((env.num_envs,), dtype=torch.float32, device=self.device)
        # Load the teacher before opening the TensorBoard writer: the load can
        # raise, and a writer created first would then never be closed.
        self._tb_writer: Any | None = None
        self._load_teacher_checkpoint()
        self._tb_writer = self._build_tensorboard_writer()

    def _trainable_parameters(self) -> list[torch.nn.Parameter]:
        """Freeze everything except ``adapt_tconv`` and return what remains.

        Side effect: sets ``requires_grad`` on every actor parameter.
        """
        params: list[torch.nn.Parameter] = []
        for name, param in self.actor.named_parameters():
            requires_grad = "adapt_tconv" in name
            param.requires_grad = requires_grad
            if requires_grad:
                params.append(param)
        return params

    def _load_teacher_checkpoint(self) -> None:
        """Load teacher weights, then put the actor in train mode.

        The shared observation normalizer is switched back to eval mode.
        """
        load_teacher_actor_weights(
            self.actor,
            self.teacher_checkpoint,
            teacher_algo_family=self.teacher_algo_family,
            device=self.device,
        )
        self.actor.train()
        self.actor.shared.obs_normalizer.eval()

    def _build_tensorboard_writer(self) -> Any | None:
        """Create the stage-2 TensorBoard writer when the config requests it.

        Returns:
            Summary writer rooted at ``<log_dir>/tb``, or ``None`` when scalar
            backend logging is disabled or TensorBoard is unavailable.
        """
        logger_type = str(OmegaConf.select(self.cfg, "training.logger", default="tensorboard"))
        if logger_type.lower() != "tensorboard":
            return None
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            self.logger.warning(
                "tensorboard is not installed; disabling HORA distillation TensorBoard logging."
            )
            return None
        tb_dir = self.log_dir / "tb"
        tb_dir.mkdir(parents=True, exist_ok=True)
        self.logger.info("TensorBoard: %s", tb_dir)
        return SummaryWriter(log_dir=str(tb_dir))

    def _add_scalar_if_finite(self, tag: str, value: float, *, step: int) -> None:
        """Write a scalar only when a TensorBoard writer exists and the value is finite.

        Args:
            tag: TensorBoard metric name.
            value: Scalar value to record.
            step: Global step associated with the scalar.

        Returns:
            None. Invalid or disabled values are skipped silently so early NaNs
            from unfinished episodes do not pollute the event stream.
        """
        if self._tb_writer is None or not math.isfinite(value):
            return
        self._tb_writer.add_scalar(tag, value, step)

    def _log_tensorboard_step(self, *, loss: float, elapsed: float) -> None:
        """Record the latest distillation scalars to TensorBoard.

        Args:
            loss: Latest latent-distillation loss.
            elapsed: Wall-clock training time since the run started.

        Returns:
            None. Metrics are written at the current agent-step count.
        """
        if self._tb_writer is None:
            return
        step = self.stats.agent_steps
        self._add_scalar_if_finite("train/loss", loss, step=step)
        self._add_scalar_if_finite("reward/mean", self.stats.mean_reward, step=step)
        self._add_scalar_if_finite("reward/best", self.stats.best_reward, step=step)
        self._add_scalar_if_finite(
            "episode/length",
            self.stats.mean_episode_length,
            step=step,
        )
        self._add_scalar_if_finite("perf/fps", step / max(elapsed, 1e-6), step=step)
        self._add_scalar_if_finite("perf/training_time_sec", elapsed, step=step)
        self._tb_writer.flush()

    def _normalize_student_obs(self, obs_td) -> dict[str, torch.Tensor]:
        """Move observations to the trainer device and normalize the history.

        Only ``"proprio_hist"`` passes through the history normalizer;
        ``"actor"`` and ``"priv_info"`` are moved to the device unchanged.
        """
        actor_obs = obs_td["actor"].to(self.device)
        proprio_hist = obs_td["proprio_hist"].to(self.device)
        return {
            "actor": actor_obs,
            "priv_info": obs_td["priv_info"].to(self.device),
            "proprio_hist": self.hist_normalizer(proprio_hist),
        }

    @staticmethod
    def _next_interval_boundary(current_steps: int, interval_steps: int) -> int | None:
        """Return the next positive interval boundary after the current step count.

        Args:
            current_steps: Number of agent steps already completed.
            interval_steps: Positive interval in agent steps.

        Returns:
            The next interval boundary, or ``None`` when the interval is
            non-positive (periodic action disabled).
        """
        if interval_steps <= 0:
            return None
        return ((current_steps // interval_steps) + 1) * interval_steps

    def _update_episode_stats(self, rewards: torch.Tensor, dones: torch.Tensor) -> None:
        """Accumulate per-env returns and fold finished episodes into the stats."""
        self._step_reward += rewards
        self._step_length += 1
        done_idx = torch.nonzero(dones, as_tuple=False).flatten()
        if done_idx.numel() == 0:
            return
        completed_rewards = self._step_reward[done_idx]
        completed_lengths = self._step_length[done_idx]
        done_mean_reward = float(torch.mean(completed_rewards).item())
        # Tensor.tolist() already yields Python floats regardless of device.
        self._reward_buffer.extend(completed_rewards.tolist())
        self._episode_length_buffer.extend(completed_lengths.tolist())
        self.stats.mean_reward = float(statistics.mean(self._reward_buffer))
        self.stats.mean_episode_length = float(statistics.mean(self._episode_length_buffer))
        self.stats.best_reward = max(self.stats.best_reward, done_mean_reward)
        # Reset the accumulators of the environments that just finished.
        self._step_reward[done_idx] = 0.0
        self._step_length[done_idx] = 0.0

    def train(self) -> None:
        """Run distillation until ``cfg.algo.max_agent_steps`` is reached.

        Each iteration takes one optimizer step on the latent MSE loss, then
        steps the environment with the student's actions. Periodic logs and
        checkpoints follow ``log_interval_steps`` / ``save_interval_steps``.
        ``hora_stage2_last.pt`` is written on normal completion, and the
        TensorBoard writer is always closed on exit.
        """
        obs_td, _ = self.env.reset()
        max_agent_steps = int(self.cfg.algo.max_agent_steps)
        save_interval = int(self.cfg.algo.save_interval_steps)
        log_interval = int(self.cfg.algo.log_interval_steps)
        next_log_steps = self._next_interval_boundary(self.stats.agent_steps, log_interval)
        next_save_steps = self._next_interval_boundary(self.stats.agent_steps, save_interval)
        start_time = time.time()
        last_loss = float("nan")
        try:
            while self.stats.agent_steps < max_agent_steps:
                norm_obs = self._normalize_student_obs(obs_td)
                obs_batch = {
                    key: value.detach() if key == "actor" else value
                    for key, value in norm_obs.items()
                }
                td = TensorDict(obs_batch, batch_size=obs_td.batch_size, device=self.device)

                # Regress the student latent onto the detached teacher latent.
                _, core_output = self.actor.shared.policy_mean(td, prefer_student=True)
                loss = torch.mean(
                    (core_output.privileged_latent - core_output.privileged_target.detach()) ** 2
                )
                last_loss = float(loss.item())
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Act with the freshly updated student.
                with torch.no_grad():
                    actions = self.actor(td, stochastic_output=False).clamp_(-1.0, 1.0)
                obs_td, rewards, dones, _infos = self.env.step(actions)
                rewards = rewards.to(self.device)
                dones = dones.to(self.device)
                self.stats.agent_steps += int(self.env.num_envs)
                self._update_episode_stats(rewards, dones)

                if next_log_steps is not None and self.stats.agent_steps >= next_log_steps:
                    elapsed = max(time.time() - start_time, 1e-6)
                    self.logger.info(
                        "agent_steps=%d loss=%.6f mean_reward=%.4f best_reward=%.4f "
                        "mean_episode_length=%.2f training_time=%.2fs fps=%.1f",
                        self.stats.agent_steps,
                        last_loss,
                        self.stats.mean_reward,
                        self.stats.best_reward,
                        self.stats.mean_episode_length,
                        elapsed,
                        self.stats.agent_steps / elapsed,
                    )
                    self._log_tensorboard_step(loss=last_loss, elapsed=elapsed)
                    next_log_steps = self._next_interval_boundary(
                        self.stats.agent_steps, log_interval
                    )
                if next_save_steps is not None and self.stats.agent_steps >= next_save_steps:
                    self.save(self.log_dir / f"hora_stage2_{self.stats.agent_steps}.pt")
                    next_save_steps = self._next_interval_boundary(
                        self.stats.agent_steps, save_interval
                    )

            self.save(self.log_dir / "hora_stage2_last.pt")
            total_elapsed = max(time.time() - start_time, 1e-6)
            self.logger.info(
                "training_complete agent_steps=%d mean_reward=%.4f best_reward=%.4f "
                "mean_episode_length=%.2f training_time=%.2fs",
                self.stats.agent_steps,
                self.stats.mean_reward,
                self.stats.best_reward,
                self.stats.mean_episode_length,
                total_elapsed,
            )
            self._log_tensorboard_step(loss=last_loss, elapsed=total_elapsed)
        finally:
            if self._tb_writer is not None:
                self._tb_writer.close()
                self._tb_writer = None

    def save(self, path: str | Path) -> None:
        """Write the student, normalizer and run metadata to ``path``.

        The parent directory is created when missing, because ``log_dir`` is
        otherwise only created as a side effect of enabling TensorBoard.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        torch.save(
            {
                "model_state_dict": self.actor.state_dict(),
                "history_normalizer": self.hist_normalizer.state_dict(),
                "agent_steps": self.stats.agent_steps,
                "teacher_checkpoint": str(self.teacher_checkpoint),
                "teacher_algo_family": self.teacher_algo_family,
                "teacher_metadata": self.teacher_metadata,
                "distill_runtime_cfg": self.distill_runtime_cfg,
            },
            target,
        )