Source code for fusion.policies.ml_policy

"""
ML Control Policy for path selection using pre-trained models.

This module provides MLControlPolicy, which loads pre-trained ML models
(PyTorch, sklearn, ONNX) and uses them for path selection inference.

MLControlPolicy implements the ControlPolicy protocol with:
- Multi-framework support via file extension detection
- Robust fallback to heuristic policies on errors
- Action masking for feasibility constraints
- Feature engineering matching RL observation space

MLControlPolicy is deployment-only: no online training, update() is no-op.

Example:
    >>> from fusion.policies.ml_policy import MLControlPolicy
    >>> policy = MLControlPolicy("model.pt", fallback_type="first_feasible")
    >>> action = policy.select_action(request, options, network_state)
"""

from __future__ import annotations

import logging
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any, Protocol

import numpy as np

if TYPE_CHECKING:
    from fusion.domain.network_state import NetworkState
    from fusion.domain.request import Request
    from fusion.interfaces.control_policy import ControlPolicy
    from fusion.modules.rl.adapter import PathOption

logger = logging.getLogger(__name__)


[docs] class FeatureBuilder: """ Build feature vectors for ML model inference. Creates fixed-size feature vectors from Request and PathOption inputs, with padding for variable numbers of paths. The feature layout matches the RL training observation space, ensuring model compatibility. :ivar k_paths: Expected number of paths (for padding). :vartype k_paths: int :ivar features_per_path: Number of features extracted per path. :vartype features_per_path: int Example:: >>> builder = FeatureBuilder(k_paths=5) >>> features = builder.build(request, options, network_state) >>> features.shape (21,) # 1 + 5*4 """ FEATURES_PER_PATH = 4 # Normalization constants MAX_BANDWIDTH_GBPS = 1000.0 MAX_WEIGHT_KM = 10000.0 MAX_SLOTS = 100.0
[docs] def __init__(self, k_paths: int = 5) -> None: """ Initialize feature builder. :param k_paths: Expected number of path options. :type k_paths: int """ self.k_paths = k_paths self._feature_size = 1 + k_paths * self.FEATURES_PER_PATH
@property def feature_size(self) -> int: """ Total size of feature vector. :return: Size of the feature vector. :rtype: int """ return self._feature_size
[docs] def build( self, request: Request, options: list[PathOption], network_state: NetworkState, ) -> np.ndarray: """ Build feature vector from inputs. :param request: The request being processed. :type request: Request :param options: Available path options. :type options: list[PathOption] :param network_state: Current network state (for future extensions). :type network_state: NetworkState :return: Feature vector of shape (feature_size,). :rtype: np.ndarray """ features: list[float] = [] # Request-level features bandwidth = getattr(request, "bandwidth_gbps", 0.0) if request else 0.0 features.append(self._normalize_bandwidth(bandwidth)) # Per-path features for i in range(self.k_paths): if i < len(options): features.extend(self._extract_path_features(options[i])) else: features.extend(self._get_padding_features()) return np.array(features, dtype=np.float32)
def _normalize_bandwidth(self, bandwidth_gbps: float) -> float: """ Normalize bandwidth to [0, 1] range. :param bandwidth_gbps: Bandwidth in Gbps. :type bandwidth_gbps: float :return: Normalized bandwidth value. :rtype: float """ return bandwidth_gbps / self.MAX_BANDWIDTH_GBPS def _extract_path_features(self, opt: PathOption) -> list[float]: """ Extract features from a single path option. :param opt: Path option to extract features from. :type opt: PathOption :return: List of feature values. :rtype: list[float] """ return [ opt.weight_km / self.MAX_WEIGHT_KM, opt.congestion, 1.0 if opt.is_feasible else 0.0, (opt.slots_needed or 0) / self.MAX_SLOTS, ] def _get_padding_features(self) -> list[float]: """ Get padding features for missing paths. Padding values chosen to represent "worst case" path: - weight: 0.0 (no path) - congestion: 1.0 (fully congested) - feasible: 0.0 (not available) - slots: 0.0 (not needed) :return: List of padding feature values. :rtype: list[float] """ return [0.0, 1.0, 0.0, 0.0]
[docs] class ModelWrapper(Protocol): """Protocol for model wrappers providing predict interface."""
[docs] def predict(self, features: np.ndarray) -> np.ndarray: """ Predict action scores from features. :param features: Feature array of shape (feature_size,) or (batch, feature_size). :type features: np.ndarray :return: Scores/logits for each action of shape (k_paths,) or (batch, k_paths). :rtype: np.ndarray """ ...
[docs] class TorchModelWrapper: """Wrapper for PyTorch models."""
[docs] def __init__(self, model: Any, device: str = "cpu") -> None: """ Initialize torch model wrapper. :param model: PyTorch nn.Module. :type model: Any :param device: Device to run inference on. :type device: str """ import torch self.model = model self.device = torch.device(device) self.model.to(self.device) self.model.eval()
[docs] def predict(self, features: np.ndarray) -> np.ndarray: """ Run inference through PyTorch model. :param features: Input feature array. :type features: np.ndarray :return: Model output scores. :rtype: np.ndarray """ import torch # Ensure 2D input if features.ndim == 1: features = features.reshape(1, -1) with torch.no_grad(): tensor = torch.from_numpy(features).float().to(self.device) output = self.model(tensor) result: np.ndarray = output.cpu().numpy().squeeze() return result
[docs] class SklearnModelWrapper: """Wrapper for sklearn models."""
[docs] def __init__(self, model: Any) -> None: """ Initialize sklearn model wrapper. :param model: sklearn model with predict_proba or predict method. :type model: Any """ self.model = model self._has_proba = hasattr(model, "predict_proba")
[docs] def predict(self, features: np.ndarray) -> np.ndarray: """ Run inference through sklearn model. :param features: Input feature array. :type features: np.ndarray :return: Model output scores or predictions. :rtype: np.ndarray """ # Ensure 2D input if features.ndim == 1: features = features.reshape(1, -1) if self._has_proba: probs = self.model.predict_proba(features) result: np.ndarray = np.asarray(probs).squeeze() return result else: # For regressors, predict returns shape (n_samples,) or (n_samples, n_outputs) output = self.model.predict(features) result = np.atleast_1d(np.asarray(output).squeeze()) return result
[docs] class OnnxModelWrapper: """Wrapper for ONNX models."""
[docs] def __init__(self, session: Any) -> None: """ Initialize ONNX model wrapper. :param session: onnxruntime InferenceSession. :type session: Any """ self.session = session self._input_name = session.get_inputs()[0].name
[docs] def predict(self, features: np.ndarray) -> np.ndarray: """ Run inference through ONNX model. :param features: Input feature array. :type features: np.ndarray :return: Model output scores. :rtype: np.ndarray """ # Ensure 2D input and float32 if features.ndim == 1: features = features.reshape(1, -1) features = features.astype(np.float32) outputs = self.session.run(None, {self._input_name: features}) result: np.ndarray = np.asarray(outputs[0]).squeeze() return result
[docs] class CallableModelWrapper: """Wrapper for callable models (functions or objects with __call__)."""
[docs] def __init__(self, model: Callable[[np.ndarray], np.ndarray]) -> None: """ Initialize callable model wrapper. :param model: Any callable that takes features and returns scores. :type model: Callable[[np.ndarray], np.ndarray] """ self._callable = model
[docs] def predict(self, features: np.ndarray) -> np.ndarray: """ Run inference through callable. :param features: Input feature array. :type features: np.ndarray :return: Model output scores. :rtype: np.ndarray """ return self._callable(features)
def _load_torch_model(model_path: Path, device: str) -> TorchModelWrapper: """ Load a PyTorch model from file. :param model_path: Path to the model file. :type model_path: Path :param device: Device to load model onto. :type device: str :return: Wrapped PyTorch model. :rtype: TorchModelWrapper :raises ImportError: If PyTorch is not installed. :raises ValueError: If model file contains only state_dict. """ try: import torch except ImportError as e: raise ImportError("PyTorch not installed. Install with: pip install torch") from e model = torch.load( model_path, map_location=device, weights_only=False, ) # nosec B614 - Loading trusted model files from local filesystem # Handle state_dict or full model if isinstance(model, dict): raise ValueError( "Model file contains state_dict only. Please save the full model with torch.save(model, path) or provide model architecture." ) return TorchModelWrapper(model, device) def _load_sklearn_model(model_path: Path) -> SklearnModelWrapper: """ Load a sklearn model from joblib or pickle file. :param model_path: Path to the model file. :type model_path: Path :return: Wrapped sklearn model. :rtype: SklearnModelWrapper :raises ImportError: If joblib is not installed. """ try: import joblib except ImportError as e: raise ImportError("joblib not installed. Install with: pip install joblib") from e model = joblib.load(model_path) return SklearnModelWrapper(model) def _load_onnx_model(model_path: Path) -> OnnxModelWrapper: """ Load an ONNX model. :param model_path: Path to the ONNX model file. :type model_path: Path :return: Wrapped ONNX model. :rtype: OnnxModelWrapper :raises ImportError: If onnxruntime is not installed. """ try: import onnxruntime as ort except ImportError as e: raise ImportError("onnxruntime not installed. Install with: pip install onnxruntime") from e session = ort.InferenceSession(str(model_path)) return OnnxModelWrapper(session)
[docs] def load_model(model_path: str, device: str = "cpu") -> ModelWrapper: """ Load a model based on file extension. Supported formats: - .pt, .pth: PyTorch models - .joblib, .pkl: sklearn models (via joblib) - .onnx: ONNX models :param model_path: Path to model file. :type model_path: str :param device: Device for PyTorch models ("cpu", "cuda", "mps"). :type device: str :return: ModelWrapper with predict() method. :rtype: ModelWrapper :raises FileNotFoundError: If model file doesn't exist. :raises ValueError: If file extension not supported. :raises ImportError: If required framework not installed. """ path = Path(model_path) if not path.exists(): raise FileNotFoundError(f"Model file not found: {model_path}") suffix = path.suffix.lower() if suffix in (".pt", ".pth"): return _load_torch_model(path, device) elif suffix in (".joblib", ".pkl"): return _load_sklearn_model(path) elif suffix == ".onnx": return _load_onnx_model(path) else: raise ValueError(f"Unsupported model format: {suffix}. Supported: .pt, .pth, .joblib, .pkl, .onnx")
[docs] class MLControlPolicy: """ ML-based control policy for path selection. Loads pre-trained ML models and uses them for deterministic inference. Implements robust fallback to heuristic policies when model fails. This is a deployment-only policy: update() is a no-op. :ivar model: Wrapped model with predict() interface. :vartype model: ModelWrapper :ivar fallback: Fallback heuristic policy. :vartype fallback: HeuristicPolicy :ivar feature_builder: Feature vector constructor. :vartype feature_builder: FeatureBuilder Example:: >>> policy = MLControlPolicy("model.pt", fallback_type="first_feasible") >>> action = policy.select_action(request, options, network_state) >>> print(policy.get_stats()) # View fallback statistics """
[docs] def __init__( self, model_path: str | None = None, model: ModelWrapper | None = None, device: str = "cpu", k_paths: int = 5, fallback_policy: Any | None = None, fallback_type: str = "first_feasible", ) -> None: """ Initialize ML control policy. :param model_path: Path to model file. Mutually exclusive with model. :type model_path: str | None :param model: Pre-loaded model wrapper. Mutually exclusive with model_path. :type model: ModelWrapper | None :param device: Device for PyTorch models ("cpu", "cuda", "mps"). :type device: str :param k_paths: Expected number of path options (for feature builder). :type k_paths: int :param fallback_policy: Explicit fallback policy instance. :type fallback_policy: Any | None :param fallback_type: Fallback type if policy not provided: "first_feasible" (default), "shortest_feasible", "least_congested", "random". :type fallback_type: str :raises ValueError: If neither or both model_path and model provided. :raises FileNotFoundError: If model file doesn't exist. :raises ImportError: If required framework not installed. """ # Validate inputs if model_path is None and model is None: raise ValueError("Either model_path or model must be provided") if model_path is not None and model is not None: raise ValueError("Cannot provide both model_path and model") # Load model if model_path is not None: self._model: ModelWrapper = load_model(model_path, device) self._model_path = model_path else: # model is guaranteed non-None here due to validation above assert model is not None self._model = model self._model_path = "<provided>" # Setup feature builder self._feature_builder = FeatureBuilder(k_paths=k_paths) self._k_paths = k_paths # Setup fallback if fallback_policy is not None: self._fallback: ControlPolicy = fallback_policy else: self._fallback = self._create_fallback(fallback_type) # Statistics tracking self._total_calls = 0 self._fallback_calls = 0 self._error_types: dict[str, int] = {}
def _create_fallback(self, fallback_type: str) -> ControlPolicy: """ Create fallback policy from type string. :param fallback_type: Type of fallback policy. :type fallback_type: str :return: Instantiated fallback policy. :rtype: ControlPolicy :raises ValueError: If fallback_type is unknown. """ from fusion.policies.heuristic_policy import ( FirstFeasiblePolicy, LeastCongestedPolicy, RandomFeasiblePolicy, ShortestFeasiblePolicy, ) fallback_map: dict[str, type[ControlPolicy]] = { "first_feasible": FirstFeasiblePolicy, "shortest_feasible": ShortestFeasiblePolicy, "least_congested": LeastCongestedPolicy, "random": RandomFeasiblePolicy, } if fallback_type not in fallback_map: raise ValueError(f"Unknown fallback type: {fallback_type}. Options: {list(fallback_map.keys())}") return fallback_map[fallback_type]() @property def fallback(self) -> ControlPolicy: """ Current fallback policy. :return: The fallback policy instance. :rtype: ControlPolicy """ return self._fallback @property def fallback_rate(self) -> float: """ Percentage of calls that used fallback. :return: Fallback rate as a fraction (0.0 to 1.0). :rtype: float """ if self._total_calls == 0: return 0.0 return self._fallback_calls / self._total_calls
[docs] def get_stats(self) -> dict[str, Any]: """ Get fallback statistics. :return: Dictionary with total_calls, fallback_calls, fallback_rate, and error_types. :rtype: dict[str, Any] """ return { "total_calls": self._total_calls, "fallback_calls": self._fallback_calls, "fallback_rate": self.fallback_rate, "error_types": self._error_types.copy(), }
[docs] def reset_stats(self) -> None: """Reset fallback statistics.""" self._total_calls = 0 self._fallback_calls = 0 self._error_types.clear()
[docs] def select_action( self, request: Request, options: list[PathOption], network_state: NetworkState, ) -> int: """ Select an action using ML model with fallback. Flow: 1. Build features from inputs 2. Run model inference 3. Apply action masking (infeasible -> -inf) 4. Select argmax action 5. Validate action is feasible 6. On any error/invalid action: use fallback :param request: The request to serve. :type request: Request :param options: Available path options. :type options: list[PathOption] :param network_state: Current network state. :type network_state: NetworkState :return: Path index (0 to len(options)-1), or -1 if no valid action. :rtype: int """ self._total_calls += 1 # Early return for empty options if not options: return -1 # Check if any feasible options exist feasible_indices = [i for i, opt in enumerate(options) if opt.is_feasible] if not feasible_indices: return -1 try: # Step 1: Build features features = self._feature_builder.build(request, options, network_state) # Step 2: Run inference raw_output = self._model.predict(features) # Step 3: Validate output if not self._validate_output(raw_output, len(options)): return self._use_fallback(request, options, network_state, "invalid_output") # Step 4: Apply mask and select action = self._apply_mask_and_select(raw_output, options) # Step 5: Validate action if self._is_valid_action(action, options): return action # Invalid action - fallback return self._use_fallback(request, options, network_state, "infeasible_action") except ImportError as e: logger.error("ML framework import error: %s", e) return self._use_fallback(request, options, network_state, "import_error") except RuntimeError as e: logger.warning("Model runtime error: %s", e) return self._use_fallback(request, options, network_state, "runtime_error") except Exception as e: logger.warning("Unexpected ML error: %s", e) return self._use_fallback(request, options, network_state, "unknown_error")
def _validate_output(self, output: np.ndarray, expected_len: int) -> bool: """ Check if model output is valid. :param output: Model output array. :type output: np.ndarray :param expected_len: Expected minimum length of output. :type expected_len: int :return: True if output is valid. :rtype: bool """ if output is None: return False # Handle scalar output if output.ndim == 0: idx = int(output.item()) return 0 <= idx < expected_len # Handle vector output if output.ndim == 1: # Must have at least one score return len(output) > 0 # Multi-dimensional - unexpected for inference return False def _apply_mask_and_select(self, raw_output: np.ndarray, options: list[PathOption]) -> int: """ Apply feasibility mask and select best action. :param raw_output: Raw model output scores. :type raw_output: np.ndarray :param options: Available path options. :type options: list[PathOption] :return: Selected action index. :rtype: int """ # Handle scalar output (direct action index) if raw_output.ndim == 0: return int(raw_output.item()) # Build mask scores = raw_output.copy() for i, opt in enumerate(options): if i < len(scores) and not opt.is_feasible: scores[i] = float("-inf") # Also mask any scores beyond options length if len(scores) > len(options): scores[len(options) :] = float("-inf") # Select argmax if np.all(np.isinf(scores)): return -1 return int(np.argmax(scores)) def _is_valid_action(self, action: int, options: list[PathOption]) -> bool: """ Check if action is valid and feasible. :param action: Action index to validate. :type action: int :param options: Available path options. :type options: list[PathOption] :return: True if action is valid and feasible. :rtype: bool """ if action < 0 or action >= len(options): return False return options[action].is_feasible def _use_fallback( self, request: Request, options: list[PathOption], network_state: NetworkState, reason: str = "unknown", ) -> int: """ Use fallback and track statistics. :param request: The request to serve. :type request: Request :param options: Available path options. :type options: list[PathOption] :param network_state: Current network state. :type network_state: NetworkState :param reason: Reason for using fallback. :type reason: str :return: Selected action from fallback policy. :rtype: int """ self._fallback_calls += 1 self._error_types[reason] = self._error_types.get(reason, 0) + 1 return self._fallback.select_action(request, options, network_state)
[docs] def update(self, request: Request, action: int, reward: float) -> None: """ Update policy based on experience. MLControlPolicy is deployment-only and does not learn online. This method is a no-op. :param request: The request that was served (ignored). :type request: Request :param action: The action taken (ignored). :type action: int :param reward: The reward received (ignored). :type reward: float """ pass
[docs] def get_name(self) -> str: """ Return the policy name for logging. :return: Policy name with model path. :rtype: str """ return f"MLControlPolicy({self._model_path})"