.. _rl-feat-extrs: ========================= Feature Extractors Module ========================= .. warning:: **Status: BETA** This module is currently in **BETA** and is actively being developed. The API and functionality may change in future releases (v6.X). .. admonition:: At a Glance :class: tip :Purpose: GNN-based feature extraction for RL observations :Location: ``fusion/modules/rl/feat_extrs/`` :Key Files: ``path_gnn.py``, ``graphormer.py``, ``base_feature_extractor.py`` :Prerequisites: PyTorch, PyTorch Geometric, Stable-Baselines3 Overview ======== The feature extractors module transforms variable-sized graph observations into fixed-size feature vectors for reinforcement learning agents. This is essential because standard RL algorithms (PPO, DQN, etc.) require fixed-dimensional inputs, but optical network state is naturally represented as a graph with varying topology. **Why Use Feature Extractors?** - **Graph-to-Vector Transformation**: Converts network topology and state into dense vector representations suitable for neural network policies - **Topology-Aware Learning**: GNN architectures capture structural relationships between nodes and edges in the network - **SB3 Compatibility**: All extractors integrate seamlessly with Stable-Baselines3 policy networks **Current Extractors:** .. list-table:: :header-rows: 1 :widths: 25 25 50 * - Extractor - Architecture - Use Case * - ``PathGNN`` - GNN (GAT/SAGE/GraphConv) - Standard path-based feature extraction * - ``CachedPathGNN`` - Pre-computed embeddings - Static graphs, fast inference * - ``GraphTransformerExtractor`` - Transformer with attention - Experimental, attention-based processing Processing Pipeline =================== All feature extractors follow a common pipeline: .. code-block:: text +------------------+ +------------------+ +------------------+ | Graph Observation|---->| GNN Convolutions |---->| Node Embeddings | | (x, edge_index, | | (GAT/SAGE/Conv) | | [N, emb_dim] | | path_masks) | +------------------+ +------------------+ +------------------+ | v +------------------+ +------------------+ +------------------+ | Feature Vector |<----| Path Aggregation |<----| Edge Embeddings | | [batch, feat_dim]| | (mask @ edges) | | [E, emb_dim] | +------------------+ +------------------+ +------------------+ **Pipeline Steps:** 1. **Node Processing**: Graph convolution layers process node features 2. **Edge Embeddings**: Computed from source and destination node embeddings 3. **Path Aggregation**: Path masks select and aggregate edge embeddings 4. **Flattening**: Output flattened to fixed-size vector for RL policy Input/Output Format =================== Input Observation ----------------- Feature extractors expect a dictionary observation with: .. list-table:: :header-rows: 1 :widths: 20 40 40 * - Key - Shape - Description * - ``x`` - ``[batch, N, F]`` or ``[N, F]`` - Node features (N nodes, F features per node) * - ``edge_index`` - ``[batch, 2, E]`` or ``[2, E]`` - Edge connectivity (source/dest indices) * - ``path_masks`` - ``[batch, K, E]`` or ``[K, E]`` - Binary masks selecting edges for K paths Output ------ All extractors output a tensor of shape ``[batch_size, features_dim]`` where ``features_dim = emb_dim * num_paths``. Quick Start =========== Using PathGNN with Stable-Baselines3 ------------------------------------ .. code-block:: python from stable_baselines3 import PPO from fusion.modules.rl.feat_extrs import PathGNN # Create environment (provides observation space) env = make_your_env() # Configure PPO with PathGNN feature extractor model = PPO( "MultiInputPolicy", env, policy_kwargs={ "features_extractor_class": PathGNN, "features_extractor_kwargs": { "emb_dim": 64, "gnn_type": "gat", # Options: 'gat', 'sage', 'graphconv' "layers": 2, } } ) model.learn(total_timesteps=10000) Direct Usage ------------ .. code-block:: python from fusion.modules.rl.feat_extrs import PathGNN # Create extractor extractor = PathGNN( obs_space=env.observation_space, emb_dim=64, gnn_type="gat", layers=2, ) # Process observation observation = env.reset()[0] features = extractor(observation) # [1, features_dim] Available GNN Types =================== The ``PathGNN`` extractor supports multiple GNN convolution types via the ``gnn_type`` parameter: GAT (Graph Attention Network) ----------------------------- :Config Value: ``gat`` :PyG Class: ``GATv2Conv`` :Description: Uses attention mechanisms to weight neighbor contributions GAT learns attention weights for each edge, allowing the model to focus on more important connections. This is often effective for heterogeneous graphs where edge importance varies. .. code-block:: ini [rl_settings] feature_extractor = path_gnn gnn_type = gat SAGE (GraphSAGE) ---------------- :Config Value: ``sage`` :PyG Class: ``SAGEConv`` :Description: Samples and aggregates features from neighbors GraphSAGE uses a sampling-based approach that scales well to larger graphs. It aggregates neighbor features using mean, max, or LSTM aggregators. .. code-block:: ini [rl_settings] feature_extractor = path_gnn gnn_type = sage GraphConv --------- :Config Value: ``graphconv`` :PyG Class: ``GraphConv`` :Description: Standard graph convolution with sum aggregation A simpler architecture that sums neighbor features with learnable weights. Often faster than attention-based methods with comparable performance. .. code-block:: ini [rl_settings] feature_extractor = path_gnn gnn_type = graph_conv GraphTransformer (Experimental) ------------------------------- The ``GraphTransformerExtractor`` uses Transformer-style multi-head attention for graph processing. This is an experimental approach exploring attention mechanisms for network state representation. .. code-block:: python from fusion.modules.rl.feat_extrs import GraphTransformerExtractor extractor = GraphTransformerExtractor( obs_space=env.observation_space, emb_dim=64, # Must be divisible by heads heads=4, # Number of attention heads layers=2, ) .. note:: The GraphTransformer extractor has higher computational cost due to the attention mechanism. It is currently experimental and may be refined in future versions. Configuration Reference ======================= INI File Settings ----------------- Configure feature extractors in the ``[rl_settings]`` section: .. list-table:: :header-rows: 1 :widths: 25 15 60 * - Parameter - Default - Description * - ``feature_extractor`` - ``path_gnn`` - Extractor type (``path_gnn``, ``mlp``) * - ``gnn_type`` - ``gat`` - GNN architecture (``gat``, ``sage``, ``graphconv``) * - ``emb_dim`` - ``64`` - Embedding dimension for GNN layers * - ``layers`` - ``2`` - Number of GNN convolution layers * - ``heads`` - ``4`` - Attention heads (GraphTransformer only) Example Configuration --------------------- .. code-block:: ini [rl_settings] # Algorithm selection path_algorithm = ppo # Feature extractor configuration feature_extractor = path_gnn gnn_type = gat emb_dim = 64 layers = 2 # Training parameters is_training = True device = cuda Module Constants ---------------- Default values defined in ``constants.py``: .. code-block:: python DEFAULT_EMBEDDING_DIMENSION = 64 DEFAULT_NUM_LAYERS = 2 DEFAULT_GNN_TYPE = "gat" DEFAULT_NUM_HEADS = 4 EDGE_EMBEDDING_SCALE_FACTOR = 0.5 Architecture Details ==================== BaseGraphFeatureExtractor ------------------------- All extractors inherit from ``BaseGraphFeatureExtractor``, which extends SB3's ``BaseFeaturesExtractor`` with graph-specific utilities: .. code-block:: python from fusion.modules.rl.feat_extrs import BaseGraphFeatureExtractor class CustomExtractor(BaseGraphFeatureExtractor): def __init__(self, obs_space, features_dim): super().__init__(obs_space, features_dim) # Custom initialization def forward(self, observation): # Use inherited utilities x, edge_index, masks, batch_size = self._process_batch_dimensions( observation["x"], observation["edge_index"], observation["path_masks"], ) # Process graph... node_emb = self.process_nodes(x, edge_index) # Compute edge and path embeddings edge_emb = self._compute_edge_embeddings(node_emb, edge_index) path_emb = self._compute_path_embeddings(edge_emb, masks) return path_emb.flatten().unsqueeze(0) **Inherited Methods:** - ``_process_batch_dimensions()``: Normalizes batch dimensions - ``_compute_edge_embeddings()``: Creates edge representations from nodes - ``_compute_path_embeddings()``: Aggregates edges according to path masks Caching for Static Graphs ------------------------- For static network topologies, use ``CachedPathGNN`` to avoid redundant computation: .. code-block:: python from fusion.modules.rl.feat_extrs import PathGNNEncoder, CachedPathGNN # Pre-compute embeddings once encoder = PathGNNEncoder(obs_space, emb_dim=64, gnn_type="gat", layers=2) cached_embedding = encoder(x, edge_index, path_masks) # Use cached version for fast inference extractor = CachedPathGNN( obs_space=obs_space, cached_embedding=cached_embedding, ) Performance Considerations ========================== .. list-table:: :header-rows: 1 :widths: 25 25 50 * - Extractor - Speed - When to Use * - ``PathGNN`` - Standard - Dynamic graphs, general use * - ``CachedPathGNN`` - Fastest - Static topology, inference-heavy workloads * - ``GraphTransformerExtractor`` - Slowest - Experimental, attention analysis Future Development (v6.X) ========================= The following enhancements are planned: - **Optuna Integration**: Hyperparameter optimization for GNN architecture (embedding dimension, layer count, GNN type selection) - **Additional Architectures**: GIN (Graph Isomorphism Network), PNA (Principal Neighbourhood Aggregation) - **Edge Features**: Support for edge attributes in convolution layers See ``TODO.md`` in the module directory for the current development roadmap. File Reference ============== .. code-block:: text fusion/modules/rl/feat_extrs/ |-- __init__.py # Public exports |-- README.md # Module documentation |-- TODO.md # Development roadmap (BETA status) |-- constants.py # Default values and paths |-- base_feature_extractor.py # BaseGraphFeatureExtractor |-- path_gnn.py # PathGNN extractor |-- path_gnn_cached.py # CachedPathGNN, PathGNNEncoder `-- graphormer.py # GraphTransformerExtractor **What to Import:** .. code-block:: python # Feature extractors from fusion.modules.rl.feat_extrs import ( PathGNN, CachedPathGNN, PathGNNEncoder, GraphTransformerExtractor, BaseGraphFeatureExtractor, ) # Constants from fusion.modules.rl.feat_extrs import ( DEFAULT_EMBEDDING_DIMENSION, DEFAULT_NUM_LAYERS, DEFAULT_GNN_TYPE, DEFAULT_NUM_HEADS, ) Related Documentation ===================== - :ref:`rl-module` - Parent RL module documentation - :ref:`rl-algorithms` - RL algorithms that use feature extractors - :ref:`rl-environments` - Environments providing graph observations .. seealso:: - `PyTorch Geometric `_ - GNN library - `Stable-Baselines3 Custom Features `_