Feature Extractors Module
Warning (Status: BETA): This module is currently in BETA and under active development. The API and functionality may change in future releases (v6.X).
At a Glance
- Purpose: GNN-based feature extraction for RL observations
- Location: `fusion/modules/rl/feat_extrs/`
- Key Files: `path_gnn.py`, `graphormer.py`, `base_feature_extractor.py`
- Prerequisites: PyTorch, PyTorch Geometric, Stable-Baselines3
Overview
The feature extractors module transforms variable-sized graph observations into fixed-size feature vectors for reinforcement learning agents. This is essential because standard RL algorithms (PPO, DQN, etc.) require fixed-dimensional inputs, but optical network state is naturally represented as a graph with varying topology.
Why Use Feature Extractors?
- Graph-to-Vector Transformation: Converts network topology and state into dense vector representations suitable for neural network policies
- Topology-Aware Learning: GNN architectures capture structural relationships between nodes and edges in the network
- SB3 Compatibility: All extractors integrate seamlessly with Stable-Baselines3 policy networks
Current Extractors:

| Extractor | Architecture | Use Case |
|---|---|---|
| `PathGNN` | GNN (GAT/SAGE/GraphConv) | Standard path-based feature extraction |
| `CachedPathGNN` | Pre-computed embeddings | Static graphs, fast inference |
| `GraphTransformerExtractor` | Transformer with attention | Experimental, attention-based processing |
Processing Pipeline
All feature extractors follow a common pipeline:
```
+------------------+     +------------------+     +------------------+
| Graph Observation|---->| GNN Convolutions |---->|  Node Embeddings |
| (x, edge_index,  |     |  (GAT/SAGE/Conv) |     |    [N, emb_dim]  |
|  path_masks)     |     +------------------+     +------------------+
+------------------+                                        |
                                                            v
+------------------+     +------------------+     +------------------+
|  Feature Vector  |<----| Path Aggregation |<----|  Edge Embeddings |
| [batch, feat_dim]|     |  (mask @ edges)  |     |    [E, emb_dim]  |
+------------------+     +------------------+     +------------------+
```
Pipeline Steps (a minimal code sketch follows the list):

1. Node Processing: Graph convolution layers process node features
2. Edge Embeddings: Computed from source and destination node embeddings
3. Path Aggregation: Path masks select and aggregate edge embeddings
4. Flattening: Output is flattened to a fixed-size vector for the RL policy
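For intuition, here is a minimal sketch of steps 2-4 in plain PyTorch. The shapes follow the diagram above; averaging the source and destination embeddings is an assumption for illustration (suggested by the `EDGE_EMBEDDING_SCALE_FACTOR = 0.5` constant listed under Module Constants), and step 1 is stubbed out since the real extractors run PyG convolutions there.

```python
import torch

N, E, K, emb_dim = 8, 12, 3, 64            # nodes, edges, paths, embedding size

node_emb = torch.randn(N, emb_dim)         # step 1 output: GNN node embeddings
edge_index = torch.randint(0, N, (2, E))   # [2, E] source/dest node indices
path_masks = torch.randint(0, 2, (K, E)).float()  # [K, E] binary path masks

# Step 2: edge embeddings from source and destination node embeddings
# (averaging is an assumption; see EDGE_EMBEDDING_SCALE_FACTOR = 0.5)
src, dst = edge_index
edge_emb = 0.5 * (node_emb[src] + node_emb[dst])  # [E, emb_dim]

# Step 3: path aggregation -- each mask row selects and sums its edges
path_emb = path_masks @ edge_emb                  # [K, emb_dim]

# Step 4: flatten to a fixed-size feature vector for the RL policy
features = path_emb.flatten().unsqueeze(0)        # [1, K * emb_dim]
```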
Input/Output Format
Input Observation
Feature extractors expect a dictionary observation with:
| Key | Shape | Description |
|---|---|---|
| `x` | `[N, F]` | Node features (N nodes, F features per node) |
| `edge_index` | `[2, E]` | Edge connectivity (source/dest indices) |
| `path_masks` | `[K, E]` | Binary masks selecting edges for K paths |
Output
All extractors output a tensor of shape `[batch_size, features_dim]`, where `features_dim = emb_dim * num_paths`.
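As a concrete example, a hypothetical observation for a 6-node topology with 10 edges, 4 features per node, and 3 candidate paths could be built as follows (the sizes and dtypes here are illustrative assumptions, not requirements of the module):

```python
import numpy as np

# Hypothetical observation: 6 nodes with 4 features each, 10 edges, 3 paths
observation = {
    "x": np.zeros((6, 4), dtype=np.float32),           # [N, F] node features
    "edge_index": np.zeros((2, 10), dtype=np.int64),   # [2, E] connectivity
    "path_masks": np.zeros((3, 10), dtype=np.float32), # [K, E] binary masks
}

# With emb_dim = 64 and num_paths = 3, the extractor output has shape
# [batch_size, 64 * 3] = [batch_size, 192]
```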
Quick Start
Using PathGNN with Stable-Baselines3
```python
from stable_baselines3 import PPO

from fusion.modules.rl.feat_extrs import PathGNN

# Create environment (provides observation space)
env = make_your_env()

# Configure PPO with PathGNN feature extractor
model = PPO(
    "MultiInputPolicy",
    env,
    policy_kwargs={
        "features_extractor_class": PathGNN,
        "features_extractor_kwargs": {
            "emb_dim": 64,
            "gnn_type": "gat",  # Options: 'gat', 'sage', 'graphconv'
            "layers": 2,
        },
    },
)

model.learn(total_timesteps=10000)
```
Direct Usage
```python
from fusion.modules.rl.feat_extrs import PathGNN

# Create extractor
extractor = PathGNN(
    obs_space=env.observation_space,
    emb_dim=64,
    gnn_type="gat",
    layers=2,
)

# Process observation
observation = env.reset()[0]
features = extractor(observation)  # [1, features_dim]
```
Available GNN Types
The PathGNN extractor supports multiple GNN convolution types via the `gnn_type` parameter:
GAT (Graph Attention Network)
- Config Value: `gat`
- PyG Class: `GATv2Conv`
- Description: Uses attention mechanisms to weight neighbor contributions
GAT learns attention weights for each edge, allowing the model to focus on more important connections. This is often effective for heterogeneous graphs where edge importance varies.
```ini
[rl_settings]
feature_extractor = path_gnn
gnn_type = gat
```
SAGE (GraphSAGE)
- Config Value: `sage`
- PyG Class: `SAGEConv`
- Description: Samples and aggregates features from neighbors
GraphSAGE uses a sampling-based approach that scales well to larger graphs. It aggregates neighbor features using mean, max, or LSTM aggregators.
```ini
[rl_settings]
feature_extractor = path_gnn
gnn_type = sage
```
GraphConv
- Config Value: `graphconv`
- PyG Class: `GraphConv`
- Description: Standard graph convolution with sum aggregation
A simpler architecture that sums neighbor features with learnable weights. Often faster than attention-based methods with comparable performance.
```ini
[rl_settings]
feature_extractor = path_gnn
gnn_type = graphconv
```
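The three config values map one-to-one onto PyG convolution classes. The dispatch below is a hedged sketch of that mapping (the actual selection logic lives inside `PathGNN`; the dictionary name and channel sizes are illustrative):

```python
from torch_geometric.nn import GATv2Conv, GraphConv, SAGEConv

# Sketch: plausible dispatch from the config string to a PyG class
GNN_CONV_CLASSES = {
    "gat": GATv2Conv,
    "sage": SAGEConv,
    "graphconv": GraphConv,
}

# Build one convolution layer for the selected type
conv = GNN_CONV_CLASSES["gat"](in_channels=4, out_channels=64)
```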
GraphTransformer (Experimental)
The GraphTransformerExtractor uses Transformer-style multi-head attention
for graph processing. This is an experimental approach exploring attention
mechanisms for network state representation.
```python
from fusion.modules.rl.feat_extrs import GraphTransformerExtractor

extractor = GraphTransformerExtractor(
    obs_space=env.observation_space,
    emb_dim=64,  # Must be divisible by heads
    heads=4,     # Number of attention heads
    layers=2,
)
```
Note: The GraphTransformer extractor has a higher computational cost due to the attention mechanism. It is currently experimental and may be refined in future versions.
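Because the embedding is split evenly across attention heads, `emb_dim` must be a multiple of `heads`. A quick sanity check:

```python
emb_dim, heads = 64, 4
assert emb_dim % heads == 0, "emb_dim must be divisible by heads"
head_dim = emb_dim // heads  # 16 dimensions per attention head
```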
Configuration Reference
INI File Settings
Configure feature extractors in the [rl_settings] section:
| Parameter | Default | Description |
|---|---|---|
| `feature_extractor` | | Extractor type (e.g. `path_gnn`) |
| `gnn_type` | `gat` | GNN architecture (`gat`, `sage`, `graphconv`) |
| `emb_dim` | `64` | Embedding dimension for GNN layers |
| `layers` | `2` | Number of GNN convolution layers |
| `heads` | `4` | Attention heads (GraphTransformer only) |
Example Configuration
```ini
[rl_settings]
# Algorithm selection
path_algorithm = ppo

# Feature extractor configuration
feature_extractor = path_gnn
gnn_type = gat
emb_dim = 64
layers = 2

# Training parameters
is_training = True
device = cuda
```
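If you build extractor kwargs programmatically, the standard-library `configparser` can read these values back. A sketch, assuming the file is named `config.ini` (the file name and helper dict are illustrative):

```python
import configparser

config = configparser.ConfigParser()
config.read("config.ini")  # hypothetical file name
rl = config["rl_settings"]

# Map INI values onto feature-extractor keyword arguments
features_extractor_kwargs = {
    "emb_dim": rl.getint("emb_dim", fallback=64),
    "gnn_type": rl.get("gnn_type", fallback="gat"),
    "layers": rl.getint("layers", fallback=2),
}
```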
Module Constants
Default values defined in constants.py:
```python
DEFAULT_EMBEDDING_DIMENSION = 64
DEFAULT_NUM_LAYERS = 2
DEFAULT_GNN_TYPE = "gat"
DEFAULT_NUM_HEADS = 4
EDGE_EMBEDDING_SCALE_FACTOR = 0.5
```
Architecture Details
BaseGraphFeatureExtractor
All extractors inherit from BaseGraphFeatureExtractor, which extends
SB3’s BaseFeaturesExtractor with graph-specific utilities:
```python
from fusion.modules.rl.feat_extrs import BaseGraphFeatureExtractor


class CustomExtractor(BaseGraphFeatureExtractor):
    def __init__(self, obs_space, features_dim):
        super().__init__(obs_space, features_dim)
        # Custom initialization

    def forward(self, observation):
        # Use inherited utilities
        x, edge_index, masks, batch_size = self._process_batch_dimensions(
            observation["x"],
            observation["edge_index"],
            observation["path_masks"],
        )

        # Process graph...
        node_emb = self.process_nodes(x, edge_index)

        # Compute edge and path embeddings
        edge_emb = self._compute_edge_embeddings(node_emb, edge_index)
        path_emb = self._compute_path_embeddings(edge_emb, masks)

        # Flatten per sample so the batch dimension is preserved
        return path_emb.reshape(batch_size, -1)
```
Inherited Methods:

- `_process_batch_dimensions()`: Normalizes batch dimensions
- `_compute_edge_embeddings()`: Creates edge representations from nodes
- `_compute_path_embeddings()`: Aggregates edges according to path masks
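A custom extractor plugs into SB3 the same way as the built-in ones. A sketch, where the `features_dim` value is hypothetical:

```python
from stable_baselines3 import PPO

# Wire the custom extractor into an SB3 policy
model = PPO(
    "MultiInputPolicy",
    env,
    policy_kwargs={
        "features_extractor_class": CustomExtractor,
        "features_extractor_kwargs": {"features_dim": 192},  # hypothetical
    },
)
```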
Caching for Static Graphs
For static network topologies, use CachedPathGNN to avoid redundant
computation:
```python
from fusion.modules.rl.feat_extrs import PathGNNEncoder, CachedPathGNN

# Pre-compute embeddings once
encoder = PathGNNEncoder(obs_space, emb_dim=64, gnn_type="gat", layers=2)
cached_embedding = encoder(x, edge_index, path_masks)

# Use cached version for fast inference
extractor = CachedPathGNN(
    obs_space=obs_space,
    cached_embedding=cached_embedding,
)
```
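At inference time the cached extractor is called like any other; a brief usage sketch:

```python
# Returns the fixed-size feature vector without re-running the GNN
observation = env.reset()[0]
features = extractor(observation)  # [1, features_dim]
```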
Performance Considerations
| Extractor | Speed | When to Use |
|---|---|---|
| `PathGNN` | Standard | Dynamic graphs, general use |
| `CachedPathGNN` | Fastest | Static topology, inference-heavy workloads |
| `GraphTransformerExtractor` | Slowest | Experimental, attention analysis |
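To verify these trade-offs on your own topology, a simple micro-benchmark is enough. A sketch, assuming an extractor and observation built as in the examples above:

```python
import time


def time_extractor(extractor, observation, n=100):
    """Average seconds per forward pass over n calls."""
    start = time.perf_counter()
    for _ in range(n):
        extractor(observation)
    return (time.perf_counter() - start) / n

# e.g. compare time_extractor(path_gnn_extractor, obs)
#      against time_extractor(cached_extractor, obs)
```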
Future Development (v6.X)
The following enhancements are planned:
- Optuna Integration: Hyperparameter optimization for GNN architecture (embedding dimension, layer count, GNN type selection)
- Additional Architectures: GIN (Graph Isomorphism Network), PNA (Principal Neighbourhood Aggregation)
- Edge Features: Support for edge attributes in convolution layers
See TODO.md in the module directory for the current development roadmap.
File Reference
```
fusion/modules/rl/feat_extrs/
|-- __init__.py                  # Public exports
|-- README.md                    # Module documentation
|-- TODO.md                      # Development roadmap (BETA status)
|-- constants.py                 # Default values and paths
|-- base_feature_extractor.py    # BaseGraphFeatureExtractor
|-- path_gnn.py                  # PathGNN extractor
|-- path_gnn_cached.py           # CachedPathGNN, PathGNNEncoder
`-- graphormer.py                # GraphTransformerExtractor
```
What to Import:
```python
# Feature extractors
from fusion.modules.rl.feat_extrs import (
    PathGNN,
    CachedPathGNN,
    PathGNNEncoder,
    GraphTransformerExtractor,
    BaseGraphFeatureExtractor,
)

# Constants
from fusion.modules.rl.feat_extrs import (
    DEFAULT_EMBEDDING_DIMENSION,
    DEFAULT_NUM_LAYERS,
    DEFAULT_GNN_TYPE,
    DEFAULT_NUM_HEADS,
)
```
)