unilab.algos.torch.hora.distill.HoraDistillationTrainer

class unilab.algos.torch.hora.distill.HoraDistillationTrainer

Bases: object

Stage-2 HORA latent distillation trainer: distills the latent of a stage-1 teacher policy, loaded from teacher_checkpoint, into a student.
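This page gives no algorithmic detail, so what follows is only a hedged illustration of what "latent distillation" usually means in HORA-style training. A frozen stage-1 teacher encodes privileged parameters into a latent. A student then learns to predict that latent from proprioceptive history by minimizing squared error. The sketch is pure Python and is not HoraDistillationTrainer's code. The linear teacher, the linear student, the mixing matrix that produces the "history", and the MSE loss are all assumptions chosen to keep it small and runnable.

```python
import random


def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


# Hypothetical frozen stage-1 encoder: privileged params -> latent.
W_TEACHER = [[1.0, -0.5], [0.3, 0.8]]
# Hypothetical map showing how privileged params leak into proprioceptive history.
MIX = [[0.9, 0.1], [0.2, 1.1], [0.5, -0.4]]


def teacher_latent(priv):
    return matvec(W_TEACHER, priv)


def proprio_history(priv):
    return matvec(MIX, priv)


def distill(steps=2000, lr=0.05, seed=0):
    """Fit a linear student (history -> latent) to the frozen teacher latent by SGD on MSE."""
    rng = random.Random(seed)
    w_student = [[0.0] * 3 for _ in range(2)]  # latent_dim x history_dim
    losses = []
    for _ in range(steps):
        priv = [rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)]
        target = teacher_latent(priv)  # teacher is frozen: used only as a target
        hist = proprio_history(priv)  # the student never sees priv directly
        pred = matvec(w_student, hist)
        err = [p - t for p, t in zip(pred, target)]
        losses.append(sum(e * e for e in err) / len(err))
        # d(MSE)/dW[i][j] = (2 / latent_dim) * err[i] * hist[j]
        for i in range(len(w_student)):
            for j in range(len(hist)):
                w_student[i][j] -= lr * 2.0 / len(err) * err[i] * hist[j]
    return w_student, losses


w_student, losses = distill()
```

Here only the student's weights are updated, and the teacher serves purely as a regression target. That division of roles is the core of the stage-2 setup. In a real trainer the history would come from environment rollouts, not from a fixed linear map.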

Parameters:
  • env
  • cfg (DictConfig)
  • device (str)
  • log_dir (str | Path)
  • teacher_checkpoint (str | Path)
  • teacher_algo_family (str)
  • teacher_metadata (dict[str, Any] | None)
  • distill_runtime_cfg (DictConfig)
  • logger

Methods

__init__(env, cfg, *, device, log_dir, ...)

save(path)

train()

__init__(env, cfg, *, device, log_dir, teacher_checkpoint, teacher_algo_family, teacher_metadata=None, distill_runtime_cfg, logger)
Parameters:
  • env
  • cfg (DictConfig)
  • device (str)
  • log_dir (str | Path)
  • teacher_checkpoint (str | Path)
  • teacher_algo_family (str)
  • teacher_metadata (dict[str, Any] | None)
  • distill_runtime_cfg (DictConfig)
  • logger
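In the signature above, the bare `*` makes every argument after cfg keyword-only. Because of that, distill_runtime_cfg and logger are still required even though they follow the defaulted teacher_metadata=None. The sketch below demonstrates this with a hypothetical stand-in class that copies the documented signature and merely stores its arguments. It does not import unilab. The values "ppo", the paths, and the empty dicts are placeholders, not documented settings.

```python
from typing import Any


class TrainerSignatureStub:
    """Hypothetical stand-in with the documented __init__ signature; stores arguments only."""

    def __init__(
        self,
        env: Any,
        cfg: Any,
        *,
        device: str,
        log_dir: "str | Path",
        teacher_checkpoint: "str | Path",
        teacher_algo_family: str,
        teacher_metadata: "dict[str, Any] | None" = None,
        distill_runtime_cfg: Any,
        logger: Any,
    ) -> None:
        self.env = env
        self.cfg = cfg
        self.device = device
        self.log_dir = log_dir
        self.teacher_checkpoint = teacher_checkpoint
        self.teacher_algo_family = teacher_algo_family
        self.teacher_metadata = teacher_metadata
        self.distill_runtime_cfg = distill_runtime_cfg
        self.logger = logger


# Placeholder values; only env and cfg may be positional.
stub = TrainerSignatureStub(
    None,
    {},
    device="cpu",
    log_dir="runs/distill",
    teacher_checkpoint="runs/teacher/model.pt",
    teacher_algo_family="ppo",
    distill_runtime_cfg={},
    logger=None,
)


def raises_type_error(fn) -> bool:
    """Return True if calling fn raises TypeError."""
    try:
        fn()
    except TypeError:
        return True
    return False


# Passing device positionally is rejected because of the bare `*`.
positional_rejected = raises_type_error(lambda: TrainerSignatureStub(None, {}, "cpu"))

# Omitting logger (keyword-only, no default) is rejected too.
missing_logger_rejected = raises_type_error(
    lambda: TrainerSignatureStub(
        None, {}, device="cpu", log_dir=".", teacher_checkpoint="t.pt",
        teacher_algo_family="ppo", distill_runtime_cfg={},
    )
)
```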

train()
Return type:

None

save(path)
Parameters:

path (str | Path)

Return type:

None
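The documented interface consists of train() returning None and save(path) accepting str | Path. The sketch below is a caller-side driver for that train-then-save sequence. It is written against a typing.Protocol so that it runs without unilab. The function name run_and_checkpoint, the checkpoint filename, and the decision to create the parent directory before calling save are assumptions, not documented behavior. The page does not say whether save creates missing directories.

```python
import tempfile
from pathlib import Path
from typing import Protocol, Union

PathLike = Union[str, Path]


class SupportsTrainSave(Protocol):
    """Structural type matching the documented train()/save(path) methods."""

    def train(self) -> None: ...

    def save(self, path: PathLike) -> None: ...


def run_and_checkpoint(
    trainer: SupportsTrainSave, log_dir: PathLike, name: str = "student_final.pt"
) -> Path:
    """Hypothetical driver: run training, then save to <log_dir>/<name>."""
    out = Path(log_dir) / name
    out.parent.mkdir(parents=True, exist_ok=True)  # defensive; save() may or may not do this
    trainer.train()
    trainer.save(out)
    return out


class RecordingTrainer:
    """Fake trainer that records calls instead of training."""

    def __init__(self) -> None:
        self.calls = []

    def train(self) -> None:
        self.calls.append("train")

    def save(self, path: PathLike) -> None:
        self.calls.append(("save", Path(path)))


with tempfile.TemporaryDirectory() as tmp:
    fake = RecordingTrainer()
    ckpt = run_and_checkpoint(fake, Path(tmp) / "distill")
    dir_existed = ckpt.parent.is_dir()
```

Typing the driver against a Protocol keeps it decoupled from the concrete class. Any object exposing these two methods, HoraDistillationTrainer included, satisfies it structurally.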