"""
RL Policy wrapper for Stable-Baselines3 models.
This module provides the RLPolicy class that wraps pre-trained SB3 models
to implement the ControlPolicy protocol, enabling unified policy handling
in the SDNOrchestrator.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
from fusion.utils.logging_config import get_logger
if TYPE_CHECKING:
from stable_baselines3.common.base_class import BaseAlgorithm
from fusion.domain.network_state import NetworkState
from fusion.domain.request import Request
from fusion.modules.rl.adapter import PathOption, RLSimulationAdapter
logger = get_logger(__name__)
class RLPolicy:
"""
Wrapper enabling SB3 models to implement ControlPolicy.
This adapter bridges the existing RL infrastructure with the
ControlPolicy protocol. It handles:
1. Observation building: Converts (request, options, state) to RL observation
2. Action masking: Enforces feasibility constraints during prediction
3. Action conversion: Converts SB3 action to path index
The wrapped model is pre-trained and does not learn online. Use
UnifiedSimEnv and SB3's learn() for online training.
:ivar model: Pre-trained SB3 model (PPO, DQN, A2C, etc.).
:vartype model: BaseAlgorithm
:ivar k_paths: Number of path options (for observation space).
:vartype k_paths: int
Example::
>>> from stable_baselines3 import PPO
>>> model = PPO.load("trained_model.zip")
>>> policy = RLPolicy(model)
>>> action = policy.select_action(request, options, network_state)
"""
def __init__(
self,
model: BaseAlgorithm,
adapter: RLSimulationAdapter | None = None,
k_paths: int = 5,
) -> None:
"""
Initialize RLPolicy with a trained SB3 model.
:param model: Pre-trained SB3 model with predict() method.
:type model: BaseAlgorithm
:param adapter: Optional adapter for observation building. If None,
uses internal observation construction.
:type adapter: RLSimulationAdapter | None
:param k_paths: Expected number of path options (for obs space size).
:type k_paths: int
:raises ValueError: If the model does not have a predict() method.
"""
self.model = model
self._adapter = adapter
self.k_paths = k_paths
# Validate model has predict method
if not hasattr(model, "predict"):
raise ValueError(f"Model {type(model).__name__} does not have predict() method")
logger.info(
"RLPolicy initialized with %s, k_paths=%d",
type(model).__name__,
k_paths,
)
def select_action(
self,
request: Request,
options: list[PathOption],
network_state: NetworkState,
) -> int:
"""
Select action using the trained SB3 model.
Builds an observation from the inputs, generates an action mask
from feasibility flags, and uses the model to predict an action.
:param request: The incoming request to serve.
:type request: Request
:param options: Available path options with feasibility information.
:type options: list[PathOption]
:param network_state: Current state of the network.
:type network_state: NetworkState
:return: Path index (0 to len(options)-1), or -1 if no valid action.
:rtype: int
.. note::
The model must support action masking. For models trained with
sb3-contrib's MaskablePPO, the action_masks parameter is used.
For standard models, masking is applied post-prediction.
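Example (sketch; handle_blocked is a hypothetical caller-side handler for
requests that cannot be served, shown only to illustrate the -1 return)::
>>> action = policy.select_action(request, options, network_state)
>>> if action == -1:
...     handle_blocked(request)  # no feasible path option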
"""
# Build observation
obs = self._build_observation(request, options, network_state)
# Build action mask
action_mask = self._build_action_mask(options)
# Check if any action is valid
if not any(action_mask):
logger.debug("No feasible actions available")
return -1
try:
# Try to use native action masking if available
if self._supports_action_masking():
# MaskablePPO supports action_masks parameter
raw_action, _ = self.model.predict(
obs,
deterministic=True,
action_masks=np.array(action_mask), # type: ignore[call-arg]
)
action: int = int(raw_action)
else:
# Predict without masking, then validate
raw_action, _ = self.model.predict(obs, deterministic=True)
action = int(raw_action)
# If predicted action is infeasible, find first feasible
if action >= len(options) or not action_mask[action]:
logger.debug(
"Model predicted infeasible action %d, selecting first feasible",
action,
)
action = self._find_first_feasible(action_mask)
return action if action >= 0 else -1
except Exception as e:
logger.warning("Model prediction failed: %s, returning -1", e)
return -1
def _build_observation(
self,
request: Request,
options: list[PathOption],
network_state: NetworkState,
) -> np.ndarray:
"""
Build observation array for model prediction.
If an adapter is available, delegates to adapter.build_observation().
Otherwise, constructs an observation matching the training format.
:param request: The incoming request.
:type request: Request
:param options: Available path options.
:type options: list[PathOption]
:param network_state: Current network state.
:type network_state: NetworkState
:return: Numpy array matching model's observation space.
:rtype: np.ndarray
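When no adapter is set, the fallback observation is a flat float32 vector
of length 1 + 4 * k_paths: the normalized request bandwidth followed by
four features per path (normalized length, congestion, feasibility flag,
normalized slots), padded with [0.0, 1.0, 0.0, 0.0] for missing paths.
Example (sketch, assuming the default k_paths=5)::
>>> obs = policy._build_observation(request, options, network_state)
>>> obs.shape
(21,)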
"""
if self._adapter is not None:
obs = self._adapter.build_observation(request, options, network_state)
# Adapter may return dict for Dict observation spaces; we expect array
if isinstance(obs, np.ndarray):
return obs
# Fallback: convert to array if possible
return np.asarray(obs, dtype=np.float32)
# Internal observation construction
features: list[float] = []
# Request features
features.append(request.bandwidth_gbps / 1000.0) # Normalized
# Per-path features (padded to k_paths)
for i in range(self.k_paths):
if i < len(options):
opt = options[i]
features.extend(
[
opt.weight_km / 10000.0, # Normalized length
opt.congestion, # Already 0-1
1.0 if opt.is_feasible else 0.0, # Feasibility
(opt.slots_needed or 0) / 100.0, # Normalized slots
]
)
else:
# Padding for missing paths: zero length/slots, full congestion, infeasible
features.extend([0.0, 1.0, 0.0, 0.0])
return np.array(features, dtype=np.float32)
def _build_action_mask(self, options: list[PathOption]) -> list[bool]:
"""
Build action mask from path options.
:param options: Available path options.
:type options: list[PathOption]
:return: List of booleans, True where action is valid (is_feasible).
:rtype: list[bool]
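Example (sketch, assuming k_paths=5 and three options of which only the
second is infeasible; the mask is padded to k_paths entries)::
>>> policy._build_action_mask(options)
[True, False, True, False, False]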
"""
mask = [opt.is_feasible for opt in options]
# Pad to k_paths if needed
while len(mask) < self.k_paths:
mask.append(False)
return mask[: self.k_paths]
def _supports_action_masking(self) -> bool:
"""
Check if model supports native action masking.
:return: True if model supports action_masks parameter.
:rtype: bool
"""
# MaskablePPO and similar algorithms support action_masks parameter
model_name = type(self.model).__name__
return model_name in ("MaskablePPO", "MaskableRecurrentPPO")
def _find_first_feasible(self, mask: list[bool]) -> int:
"""
Find index of first feasible action.
:param mask: Action mask with True for feasible actions.
:type mask: list[bool]
:return: Index of first feasible action, or -1 if none.
:rtype: int
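Example (mask values are illustrative)::
>>> policy._find_first_feasible([False, True, True])
1
>>> policy._find_first_feasible([False, False])
-1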
"""
for i, is_feasible in enumerate(mask):
if is_feasible:
return i
return -1
def update(self, request: Request, action: int, reward: float) -> None:
"""
Update policy based on experience.
RLPolicy wraps pre-trained models that do not learn online.
This method is a no-op to satisfy the ControlPolicy protocol.
For online RL training, use UnifiedSimEnv with SB3's learn() method.
:param request: The request that was served (ignored).
:type request: Request
:param action: The action taken (ignored).
:type action: int
:param reward: The reward received (ignored).
:type reward: float
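Example (sketch of the online-training workflow this method defers to;
UnifiedSimEnv construction arguments are project-specific and omitted)::
>>> # env = UnifiedSimEnv(...)                # training environment
>>> # model = PPO("MlpPolicy", env)           # standard SB3 setup
>>> # model.learn(total_timesteps=100_000)
>>> # policy = RLPolicy(model)                # wrap for inference afterwards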
"""
pass
def get_name(self) -> str:
"""
Return policy name for logging and metrics.
:return: String identifying this policy and underlying model.
:rtype: str
"""
model_name = type(self.model).__name__
return f"RLPolicy({model_name})"
def set_adapter(self, adapter: RLSimulationAdapter) -> None:
"""
Set the RL simulation adapter for observation building.
:param adapter: RL simulation adapter for observation building.
:type adapter: RLSimulationAdapter
"""
self._adapter = adapter
@classmethod
def from_file(
cls,
model_path: str,
algorithm: str = "PPO",
**kwargs: Any,
) -> RLPolicy:
"""
Load RLPolicy from a saved model file.
:param model_path: Path to saved model (e.g., "model.zip").
:type model_path: str
:param algorithm: SB3 algorithm name ("PPO", "DQN", "A2C", "MaskablePPO", etc.).
:type algorithm: str
:param kwargs: Additional arguments passed to RLPolicy.__init__.
:return: RLPolicy wrapping the loaded model.
:rtype: RLPolicy
:raises ValueError: If the algorithm cannot be resolved from stable_baselines3 or sb3_contrib.
Example::
>>> policy = RLPolicy.from_file("trained_ppo.zip", algorithm="PPO")
"""
import importlib
algorithm_class = None
# Try standard stable_baselines3 first
try:
sb3_module = importlib.import_module("stable_baselines3")
algorithm_class = getattr(sb3_module, algorithm, None)
except ImportError:
pass
# If not found, try sb3_contrib for maskable algorithms
if algorithm_class is None:
try:
sb3_contrib = importlib.import_module("sb3_contrib")
algorithm_class = getattr(sb3_contrib, algorithm, None)
except ImportError:
pass
if algorithm_class is None:
raise ValueError(f"Unknown algorithm: {algorithm}. Ensure stable_baselines3 or sb3_contrib is installed.")
model = algorithm_class.load(model_path)
return cls(model, **kwargs)
# Alias kept for backwards compatibility
RLControlPolicy = RLPolicy