unilab.algos.torch.hora.distill.HoraDistillationTrainer
-
class unilab.algos.torch.hora.distill.HoraDistillationTrainer
Bases: object
Stage-2 HORA latent distillation trainer.
- Parameters:
cfg (DictConfig)
device (str)
log_dir (str | Path)
teacher_checkpoint (str | Path)
teacher_algo_family (str)
teacher_metadata (dict[str, Any] | None)
distill_runtime_cfg (DictConfig)
Methods
-
__init__(env, cfg, *, device, log_dir, teacher_checkpoint, teacher_algo_family, teacher_metadata=None, distill_runtime_cfg, logger)
- Parameters:
cfg (DictConfig)
device (str)
log_dir (str | Path)
teacher_checkpoint (str | Path)
teacher_algo_family (str)
teacher_metadata (dict[str, Any] | None)
distill_runtime_cfg (DictConfig)
-
train()
- Return type:
None
-
save(path)
- Parameters:
path (str | Path)
- Return type:
None
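Example

The signature above implies a construct, train, save sequence. The helper below is a minimal sketch of that sequence, not part of unilab: `run_distillation`, its `trainer_cls` argument, and the default file name `student.pt` are hypothetical names introduced for illustration. The types of `env` and `logger` are not documented on this page, so they are passed through untouched, and what `save` writes to `path` is likewise an assumption left to the trainer.

```python
from pathlib import Path
from typing import Any


def run_distillation(
    trainer_cls: Any,
    env: Any,
    cfg: Any,
    *,
    device: str,
    log_dir,
    teacher_checkpoint,
    teacher_algo_family: str,
    distill_runtime_cfg: Any,
    logger: Any,
    teacher_metadata=None,
    save_name: str = "student.pt",  # hypothetical file name
) -> Path:
    """Construct a trainer, run it, and save the result under log_dir."""
    log_dir = Path(log_dir)

    # Everything after `cfg` is keyword-only in the documented signature.
    trainer = trainer_cls(
        env,
        cfg,
        device=device,
        log_dir=log_dir,
        teacher_checkpoint=teacher_checkpoint,
        teacher_algo_family=teacher_algo_family,
        teacher_metadata=teacher_metadata,
        distill_runtime_cfg=distill_runtime_cfg,
        logger=logger,
    )

    trainer.train()  # documented to return None

    out_path = log_dir / save_name
    trainer.save(out_path)  # accepts str | Path
    return out_path
```

With unilab installed, `trainer_cls` would be `HoraDistillationTrainer`, imported from `unilab.algos.torch.hora.distill`. Taking the class as an argument keeps the sketch importable and testable without the library.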