Source code for unilab.algos.torch.common.ane_inference

"""Simplified ANE backend using deterministic inference."""

import numpy as np


def create_ane_actor(actor_model, obs_dim, action_dim):
    """Create an ANE-compatible actor using deterministic inference.

    Wraps ``actor_model`` so that only the tanh-squashed mean action is
    produced (no sampling, to avoid CoreML limitations), traces the wrapper
    with TorchScript and converts the trace to a Core ML model.

    Args:
        actor_model: Actor whose ``forward(obs)`` returns a 2-tuple; the
            first element is used as the action mean and the second element
            is ignored.
        obs_dim: Size of the last dimension of the observation input.
        action_dim: Not used by this function; kept in the signature so
            existing callers keep working.

    Returns:
        The converted Core ML model, or ``None`` when the optional
        dependencies cannot be imported or tracing/conversion fails (the
        reason is printed).
    """
    try:
        # Imported lazily so a missing optional dependency is reported via
        # the ``None`` return instead of breaking module import.
        import coremltools as ct
        import torch

        class DeterministicActor(torch.nn.Module):
            """Wrapper that outputs ``tanh(mean)`` instead of a sampled action."""

            def __init__(self, actor):
                super().__init__()
                self.actor = actor

            def forward(self, obs):
                # Get mean action (deterministic); the second output of the
                # actor is discarded.
                with torch.no_grad():
                    mean, _ = self.actor.forward(obs)
                return torch.tanh(mean)

        det_actor = DeterministicActor(actor_model)
        det_actor.eval()

        # Trace the wrapper with a single dummy observation.
        example = torch.randn(1, obs_dim)
        traced = torch.jit.trace(det_actor, example)

        # Convert to CoreML. The input is named explicitly so that it always
        # matches the "obs" key that ANEInference.predict feeds the model,
        # rather than depending on the name the converter derives from the
        # traced graph. The first dimension is flexible (1 to 8192 rows).
        mlmodel = ct.convert(
            traced,
            inputs=[
                ct.TensorType(
                    name="obs",
                    shape=(ct.RangeDim(1, 8192), obs_dim),
                )
            ],
            compute_units=ct.ComputeUnit.ALL,
        )
        return mlmodel
    except Exception as e:
        # Deliberate best-effort: callers treat ``None`` as "ANE unavailable".
        print(f"ANE conversion failed: {e}")
        return None
class ANEInference:
    """Thin wrapper that runs inference through a converted Core ML model."""

    def __init__(self, coreml_model):
        """Store the model object whose ``predict`` method will be called."""
        self.model = coreml_model

    def predict(self, obs_np):
        """Run inference on ``obs_np`` and return the first output.

        The observation is passed to the model under the ``"obs"`` key, and
        the value stored under the first key of the returned mapping is
        handed back to the caller.
        """
        feed = {"obs": obs_np}
        outputs = self.model.predict(feed)
        output_names = list(outputs)
        return outputs[output_names[0]]