Feature Extractors Module

Warning

Status: BETA

This module is currently in BETA and is actively being developed. The API and functionality may change in future releases (v6.X).

At a Glance

Purpose: GNN-based feature extraction for RL observations
Location: fusion/modules/rl/feat_extrs/
Key Files: path_gnn.py, graphormer.py, base_feature_extractor.py
Prerequisites: PyTorch, PyTorch Geometric, Stable-Baselines3

Overview

The feature extractors module transforms variable-sized graph observations into fixed-size feature vectors for reinforcement learning agents. This is essential because standard RL algorithms (PPO, DQN, etc.) require fixed-dimensional inputs, but optical network state is naturally represented as a graph with varying topology.

Why Use Feature Extractors?

  • Graph-to-Vector Transformation: Converts network topology and state into dense vector representations suitable for neural network policies

  • Topology-Aware Learning: GNN architectures capture structural relationships between nodes and edges in the network

  • SB3 Compatibility: All extractors build on SB3's BaseFeaturesExtractor (through BaseGraphFeatureExtractor), so they plug into Stable-Baselines3 policies via policy_kwargs

Current Extractors:

Extractor                   Architecture                 Use Case
--------------------------  ---------------------------  ----------------------------------------
PathGNN                     GNN (GAT/SAGE/GraphConv)     Standard path-based feature extraction
CachedPathGNN               Pre-computed embeddings      Static graphs, fast inference
GraphTransformerExtractor   Transformer with attention   Experimental, attention-based processing

Processing Pipeline

All feature extractors follow a common pipeline:

+------------------+     +------------------+     +------------------+
| Graph Observation|---->| GNN Convolutions |---->| Node Embeddings  |
| (x, edge_index,  |     | (GAT/SAGE/Conv)  |     | [N, emb_dim]     |
|  path_masks)     |     +------------------+     +------------------+
+------------------+                                      |
                                                          v
+------------------+     +------------------+     +------------------+
| Feature Vector   |<----| Path Aggregation |<----| Edge Embeddings  |
| [batch, feat_dim]|     | (mask @ edges)   |     | [E, emb_dim]     |
+------------------+     +------------------+     +------------------+

Pipeline Steps:

  1. Node Processing: Graph convolution layers process node features

  2. Edge Embeddings: Computed from source and destination node embeddings

  3. Path Aggregation: Path masks select and aggregate edge embeddings

  4. Flattening: Output flattened to fixed-size vector for RL policy
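Steps 3 and 4 reduce to a batched matrix product followed by a reshape. The NumPy sketch below shows only that shape arithmetic, assuming edge embeddings are already computed and that aggregation is a plain sum over the masked edges ("mask @ edges" in the diagram). The real extractors operate on PyTorch tensors and may normalize differently; aggregate_paths is a hypothetical name.

```python
import numpy as np

def aggregate_paths(edge_emb: np.ndarray, path_masks: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: sum edge embeddings along each path, then flatten.

    edge_emb:   [batch, E, emb_dim] edge embeddings
    path_masks: [batch, K, E] binary masks (1 = edge lies on path k)
    returns:    [batch, K * emb_dim] fixed-size feature vector
    """
    # [batch, K, E] @ [batch, E, emb_dim] -> [batch, K, emb_dim]
    path_emb = path_masks @ edge_emb
    batch = path_emb.shape[0]
    # Concatenate the K path embeddings into one vector per sample
    return path_emb.reshape(batch, -1)

# Toy example: 1 sample, 4 edges, 2 paths, emb_dim = 3
edge_emb = np.arange(12, dtype=np.float32).reshape(1, 4, 3)
path_masks = np.array([[[1, 1, 0, 0],    # path 0 uses edges 0 and 1
                        [0, 0, 1, 1]]],  # path 1 uses edges 2 and 3
                      dtype=np.float32)
features = aggregate_paths(edge_emb, path_masks)
print(features.shape)  # (1, 6)
print(features)
```

Because the mask selects a variable number of edges per path but always yields one emb_dim vector per path, the output size depends only on K and emb_dim, never on the topology.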

Input/Output Format

Input Observation

Feature extractors expect a dictionary observation with:

Key          Shape                      Description
-----------  -------------------------  ----------------------------------------------
x            [batch, N, F] or [N, F]    Node features (N nodes, F features per node)
edge_index   [batch, 2, E] or [2, E]    Edge connectivity (source/destination indices)
path_masks   [batch, K, E] or [K, E]    Binary masks selecting edges for K paths

Output

All extractors output a tensor of shape [batch_size, features_dim] where features_dim = emb_dim * num_paths.
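Because features_dim depends only on emb_dim and the number of paths K, it can be computed from a single unbatched observation before a policy is built. The sketch below checks that the three arrays agree with the shapes in the input table and returns the expected features_dim; expected_features_dim is a hypothetical helper, not part of the module.

```python
import numpy as np

def expected_features_dim(obs: dict, emb_dim: int) -> int:
    """Hypothetical helper: validate an unbatched observation, return features_dim."""
    x, edge_index, path_masks = obs["x"], obs["edge_index"], obs["path_masks"]
    num_nodes = x.shape[0]            # x: [N, F]
    num_edges = edge_index.shape[1]   # edge_index: [2, E]
    if edge_index.shape[0] != 2:
        raise ValueError("edge_index must have shape [2, E]")
    if path_masks.shape[1] != num_edges:
        raise ValueError("path_masks needs one column per edge")
    if edge_index.max() >= num_nodes:
        raise ValueError("edge_index refers to a node index >= N")
    num_paths = path_masks.shape[0]   # path_masks: [K, E]
    return emb_dim * num_paths

# Dummy observation: N=5 nodes, F=3 features, E=6 directed edges, K=3 paths
obs = {
    "x": np.zeros((5, 3), dtype=np.float32),
    "edge_index": np.array([[0, 1, 1, 2, 3, 4],
                            [1, 0, 2, 1, 4, 3]], dtype=np.int64),
    "path_masks": np.zeros((3, 6), dtype=np.float32),
}
print(expected_features_dim(obs, emb_dim=64))  # 192
```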

Quick Start

Using PathGNN with Stable-Baselines3

from stable_baselines3 import PPO
from fusion.modules.rl.feat_extrs import PathGNN

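# Create environment (provides the Dict observation space).
# make_your_env() is a placeholder for your own environment factory.
env = make_your_env()

# Configure PPO with PathGNN feature extractor.
# "MultiInputPolicy" is SB3's policy for Dict observation spaces.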
model = PPO(
    "MultiInputPolicy",
    env,
    policy_kwargs={
        "features_extractor_class": PathGNN,
        "features_extractor_kwargs": {
            "emb_dim": 64,
            "gnn_type": "gat",   # Options: 'gat', 'sage', 'graphconv'
            "layers": 2,
        }
    }
)

model.learn(total_timesteps=10000)

Direct Usage

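import torch
from fusion.modules.rl.feat_extrs import PathGNN

# `env` is your environment, created as in the previous example
extractor = PathGNN(
    obs_space=env.observation_space,
    emb_dim=64,
    gnn_type="gat",
    layers=2,
)

# reset() returns (observation, info); the observation holds NumPy arrays
observation = env.reset()[0]

# Convert to tensors before calling the extractor; PyG expects integer edge indices
obs_tensors = {
    "x": torch.as_tensor(observation["x"], dtype=torch.float32),
    "edge_index": torch.as_tensor(observation["edge_index"], dtype=torch.long),
    "path_masks": torch.as_tensor(observation["path_masks"], dtype=torch.float32),
}

with torch.no_grad():
    features = extractor(obs_tensors)  # [1, features_dim]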

Available GNN Types

The PathGNN extractor supports multiple GNN convolution types via the gnn_type parameter:

GAT (Graph Attention Network)

Config value: gat
PyG class: GATv2Conv
Description: Uses attention mechanisms to weight neighbor contributions

GAT learns an attention weight for each edge, allowing the model to focus on the most informative neighbors. This is often effective when the importance of connections varies across the network.

[rl_settings]
feature_extractor = path_gnn
gnn_type = gat
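To make "learned attention weights" concrete, here is a single-head NumPy sketch of the GATv2 scoring rule that GATv2Conv implements: score(i, j) = a^T LeakyReLU(W_r h_i + W_l h_j), a softmax over the neighbors of i, then a weighted sum of W_l h_j. The weights are random here (the real layer learns them), and the sketch omits biases, multiple heads, and the self-loop PyG adds by default; gatv2_node_update is a hypothetical name.

```python
import numpy as np

def leaky_relu(z, slope=0.2):
    return np.where(z > 0, z, slope * z)

def gatv2_node_update(h, neighbors, i, W_l, W_r, a):
    """Single-head GATv2-style update for node i (sketch, no bias/self-loop)."""
    src = h[neighbors] @ W_l.T            # [deg, out] transformed neighbors j
    dst = h[i] @ W_r.T                    # [out] transformed target node i
    scores = leaky_relu(src + dst) @ a    # [deg] unnormalized scores e_ij
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()                  # softmax over the neighbors of i
    return alpha, alpha @ src             # weights and aggregated embedding

rng = np.random.default_rng(0)
h = rng.normal(size=(4, 3))               # 4 nodes, 3 input features
W_l, W_r = rng.normal(size=(2, 8, 3))     # two projections 3 -> 8 dims
a = rng.normal(size=8)                    # attention vector
alpha, out = gatv2_node_update(h, neighbors=[1, 2, 3], i=0, W_l=W_l, W_r=W_r, a=a)
print(alpha.round(3))                     # one weight per neighbor
print(out.shape)                          # (8,)
```

Because the non-linearity is applied after combining both endpoints, the ranking of neighbors can change depending on the target node, which is the main difference from the original GAT scoring.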

SAGE (GraphSAGE)

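Config value: sage
PyG class: SAGEConv
Description: Aggregates neighbor features and combines them with each node's own features

GraphSAGE was designed to scale to large graphs by sampling neighbors before aggregation; in PyG, that sampling is done by the data loader, while SAGEConv itself aggregates over every neighbor it is given. Supported aggregators include mean (the PyG default), max, and LSTM.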

[rl_settings]
feature_extractor = path_gnn
gnn_type = sage
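With mean aggregation, a SAGEConv layer computes h_i' = W_1 h_i + W_2 * mean over j in N(i) of h_j. The NumPy sketch below implements one such layer over an edge list (row 0 = source, row 1 = target, the PyG convention), without bias or normalization; sage_mean_layer is a hypothetical name.

```python
import numpy as np

def sage_mean_layer(h, edge_index, W_self, W_neigh):
    """GraphSAGE-style layer with mean aggregation (sketch, no bias).

    h:          [N, F] node features
    edge_index: [2, E] with row 0 = source j, row 1 = target i
    """
    num_nodes = h.shape[0]
    src, dst = edge_index
    summed = np.zeros_like(h)
    np.add.at(summed, dst, h[src])              # sum incoming messages per node
    degree = np.bincount(dst, minlength=num_nodes).reshape(-1, 1)
    mean = summed / np.maximum(degree, 1)       # nodes without neighbors get 0
    return h @ W_self.T + mean @ W_neigh.T

h = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edge_index = np.array([[1, 2, 0],               # sources
                       [0, 0, 1]])              # targets: 1->0, 2->0, 0->1
eye = np.eye(2)
print(sage_mean_layer(h, edge_index, eye, eye))
```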

GraphConv

Config value: graphconv
PyG class: GraphConv
Description: Standard graph convolution with sum aggregation

A simpler architecture that sums neighbor features with learnable weights and adds a separately weighted copy of each node's own features. It skips the attention-score computation, so it is typically faster than GAT; whether it matches GAT's accuracy depends on the topology and task.

[rl_settings]
feature_extractor = path_gnn
gnn_type = graphconv
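GraphConv computes h_i' = W_1 h_i + W_2 * sum over j in N(i) of h_j. With a dense adjacency matrix the neighbor sum is a single matrix product, as in the NumPy sketch below (no bias, identity weights for readability); graph_conv_layer is a hypothetical name.

```python
import numpy as np

def graph_conv_layer(h, adj, W_root, W_rel):
    """GraphConv-style layer with sum aggregation (sketch, no bias).

    adj[i, j] = 1 if there is an edge j -> i, so adj @ h sums neighbor features.
    """
    neighbor_sum = adj @ h
    return h @ W_root.T + neighbor_sum @ W_rel.T

h = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
adj = np.array([[0, 1, 1],     # node 0 receives from nodes 1 and 2
                [1, 0, 0],     # node 1 receives from node 0
                [0, 0, 0]],    # node 2 has no incoming edges
               dtype=float)
eye = np.eye(2)
print(graph_conv_layer(h, adj, eye, eye))
```

Because the neighbor term is a sum rather than a mean, its magnitude grows with node degree; this is the main behavioral difference from mean aggregation.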

GraphTransformer (Experimental)

The GraphTransformerExtractor uses Transformer-style multi-head attention for graph processing. This is an experimental approach exploring attention mechanisms for network state representation.

from fusion.modules.rl.feat_extrs import GraphTransformerExtractor

extractor = GraphTransformerExtractor(
    obs_space=env.observation_space,
    emb_dim=64,   # Must be divisible by heads
    heads=4,      # Number of attention heads
    layers=2,
)

Note

The GraphTransformer extractor has higher computational cost due to the attention mechanism. It is currently experimental and may be refined in future versions.
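The requirement that emb_dim be divisible by heads comes from multi-head attention splitting each embedding into heads chunks of emb_dim // heads dimensions. The NumPy sketch below shows standard dense multi-head self-attention over all N nodes; its [heads, N, N] score matrix is one likely source of the higher cost noted above. It illustrates the general mechanism only and is not GraphTransformerExtractor's code, which may, for example, restrict attention to graph edges; multi_head_self_attention is a hypothetical name.

```python
import numpy as np

def multi_head_self_attention(x, Wq, Wk, Wv, heads):
    """Dense multi-head self-attention over nodes (sketch, no output projection).

    x: [N, emb_dim]; Wq, Wk, Wv: [emb_dim, emb_dim]
    """
    n, emb_dim = x.shape
    if emb_dim % heads != 0:
        raise ValueError("emb_dim must be divisible by heads")
    d = emb_dim // heads
    # Project, then split the last axis into heads: [heads, N, d]
    q, k, v = (
        (x @ W).reshape(n, heads, d).transpose(1, 0, 2) for W in (Wq, Wk, Wv)
    )
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)    # [heads, N, N]
    scores -= scores.max(axis=-1, keepdims=True)      # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)    # softmax per query node
    out = weights @ v                                 # [heads, N, d]
    # Concatenate the heads back to [N, emb_dim]
    return out.transpose(1, 0, 2).reshape(n, emb_dim)

rng = np.random.default_rng(0)
emb_dim, heads, n = 64, 4, 5
x = rng.normal(size=(n, emb_dim))
Wq, Wk, Wv = rng.normal(size=(3, emb_dim, emb_dim)) / np.sqrt(emb_dim)
print(multi_head_self_attention(x, Wq, Wk, Wv, heads).shape)  # (5, 64)
```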

Configuration Reference

INI File Settings

Configure feature extractors in the [rl_settings] section:

Parameter           Default     Description
------------------  ----------  ---------------------------------------
feature_extractor   path_gnn    Extractor type (path_gnn, mlp)
gnn_type            gat         GNN architecture (gat, sage, graphconv)
emb_dim             64          Embedding dimension for GNN layers
layers              2           Number of GNN convolution layers
heads               4           Attention heads (GraphTransformer only)

Example Configuration

[rl_settings]
# Algorithm selection
path_algorithm = ppo

# Feature extractor configuration
feature_extractor = path_gnn
gnn_type = gat
emb_dim = 64
layers = 2

# Training parameters
is_training = True
device = cuda
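FUSION's own loader turns these settings into extractor arguments. The stdlib sketch below shows one plausible mapping from [rl_settings] to the features_extractor_kwargs used in the Quick Start, assuming the key names and defaults from the configuration table; load_extractor_kwargs is hypothetical, and the real loader may validate or convert values differently.

```python
import configparser

# Defaults mirror the configuration table (assumed, not read from FUSION)
DEFAULTS = {"gnn_type": "gat", "emb_dim": "64", "layers": "2"}
VALID_GNN_TYPES = {"gat", "sage", "graphconv"}

def load_extractor_kwargs(ini_text: str) -> dict:
    """Hypothetical: read [rl_settings] and build PathGNN keyword arguments."""
    parser = configparser.ConfigParser()
    parser.read_string(ini_text)
    section = parser["rl_settings"]
    gnn_type = section.get("gnn_type", DEFAULTS["gnn_type"])
    if gnn_type not in VALID_GNN_TYPES:
        raise ValueError(f"unknown gnn_type {gnn_type!r}")
    return {
        "emb_dim": int(section.get("emb_dim", DEFAULTS["emb_dim"])),
        "gnn_type": gnn_type,
        "layers": int(section.get("layers", DEFAULTS["layers"])),
    }

example = """
[rl_settings]
path_algorithm = ppo
feature_extractor = path_gnn
gnn_type = gat
emb_dim = 64
layers = 2
"""
print(load_extractor_kwargs(example))
```

Validating gnn_type against an explicit set catches spelling variants such as graph_conv early, instead of failing later inside the policy constructor.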

Module Constants

Default values defined in constants.py:

DEFAULT_EMBEDDING_DIMENSION = 64
DEFAULT_NUM_LAYERS = 2
DEFAULT_GNN_TYPE = "gat"
DEFAULT_NUM_HEADS = 4
EDGE_EMBEDDING_SCALE_FACTOR = 0.5
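The name EDGE_EMBEDDING_SCALE_FACTOR = 0.5 suggests that an edge embedding is the scaled sum, i.e. the mean, of its two endpoint node embeddings. The sketch below assumes exactly that rule; check base_feature_extractor.py before relying on it. edge_embeddings is a hypothetical name.

```python
import numpy as np

EDGE_EMBEDDING_SCALE_FACTOR = 0.5  # value from constants.py

def edge_embeddings(node_emb, edge_index, scale=EDGE_EMBEDDING_SCALE_FACTOR):
    """Assumed rule: scale * (h_src + h_dst) for every edge.

    node_emb: [N, emb_dim]; edge_index: [2, E] -> returns [E, emb_dim]
    """
    src, dst = edge_index
    return scale * (node_emb[src] + node_emb[dst])

node_emb = np.array([[0.0, 2.0], [4.0, 6.0], [8.0, 10.0]])
edge_index = np.array([[0, 1],     # sources
                       [1, 2]])    # targets: edges 0->1 and 1->2
print(edge_embeddings(node_emb, edge_index))
```

Averaging rather than concatenating keeps each edge embedding at emb_dim (the [E, emb_dim] shape in the pipeline diagram) and makes it symmetric in source and destination.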

Architecture Details

BaseGraphFeatureExtractor

All extractors inherit from BaseGraphFeatureExtractor, which extends SB3’s BaseFeaturesExtractor with graph-specific utilities:

from fusion.modules.rl.feat_extrs import BaseGraphFeatureExtractor

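class CustomExtractor(BaseGraphFeatureExtractor):
    def __init__(self, obs_space, features_dim):
        super().__init__(obs_space, features_dim)
        # Custom initialization: define your GNN layers here

    def process_nodes(self, x, edge_index):
        # Not inherited: implement your own node-level GNN, returning [N, emb_dim]
        raise NotImplementedError

    def forward(self, observation):
        # Inherited: normalize inputs to a consistent batch layout
        x, edge_index, masks, batch_size = self._process_batch_dimensions(
            observation["x"],
            observation["edge_index"],
            observation["path_masks"],
        )

        # Node embeddings from your own GNN layers
        node_emb = self.process_nodes(x, edge_index)

        # Inherited: edge embeddings from endpoints, then per-path aggregation
        edge_emb = self._compute_edge_embeddings(node_emb, edge_index)
        path_emb = self._compute_path_embeddings(edge_emb, masks)

        # Flatten to [1, features_dim]; this simple version assumes batch_size == 1
        return path_emb.flatten().unsqueeze(0)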

Inherited Methods:

  • _process_batch_dimensions(): Normalizes batch dimensions

  • _compute_edge_embeddings(): Creates edge representations from nodes

  • _compute_path_embeddings(): Aggregates edges according to path masks
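These helpers are internal, so their exact behavior should be checked in base_feature_extractor.py. As an illustration of what "normalizes batch dimensions" typically means for the input table's two layouts, the hypothetical process_batch_dimensions below accepts batched or unbatched arrays and always returns batched arrays plus the batch size. The integer cast for edge_index is an added assumption (indices arriving as floats from a Box space).

```python
import numpy as np

def process_batch_dimensions(x, edge_index, path_masks):
    """Hypothetical stand-in for _process_batch_dimensions.

    Accepts x as [N, F] or [B, N, F] (with matching edge_index / path_masks)
    and always returns batched arrays plus B.
    """
    if x.ndim == 2:  # unbatched observation: add a leading batch axis
        x, edge_index, path_masks = x[None], edge_index[None], path_masks[None]
    batch_size = x.shape[0]
    # Assumption: indices may arrive as floats, but indexing needs integers
    edge_index = edge_index.astype(np.int64)
    return x, edge_index, path_masks, batch_size

x, ei, masks, b = process_batch_dimensions(
    np.zeros((5, 3)), np.zeros((2, 6)), np.zeros((3, 6))
)
print(x.shape, ei.shape, masks.shape, b)  # (1, 5, 3) (1, 2, 6) (1, 3, 6) 1
```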

Caching for Static Graphs

For static network topologies, use CachedPathGNN to avoid redundant computation:

from fusion.modules.rl.feat_extrs import PathGNNEncoder, CachedPathGNN

# obs_space is your environment's observation space; x, edge_index and
# path_masks are tensors taken from an observation of the static topology
encoder = PathGNNEncoder(obs_space, emb_dim=64, gnn_type="gat", layers=2)

# Pre-compute embeddings once
cached_embedding = encoder(x, edge_index, path_masks)

# Use cached version for fast inference
extractor = CachedPathGNN(
    obs_space=obs_space,
    cached_embedding=cached_embedding,
)
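Conceptually, caching pays off because, when the encoder's inputs do not change between steps, the GNN output does not change either, so the forward pass can skip the graph computation and reuse one stored vector for every sample in the batch. The sketch below illustrates that idea with a hypothetical CachedFeatures class; it is not CachedPathGNN's implementation, and it is only valid under that unchanged-input assumption.

```python
import numpy as np

class CachedFeatures:
    """Hypothetical sketch of the caching idea behind CachedPathGNN."""

    def __init__(self, cached_embedding: np.ndarray):
        # Store the pre-computed path embeddings once, flattened: [K * emb_dim]
        self.cached = cached_embedding.reshape(-1)
        self.features_dim = self.cached.shape[0]

    def __call__(self, batch_size: int) -> np.ndarray:
        # No graph computation: broadcast the stored vector to every sample
        # (returns a read-only view, so no copy is made)
        return np.broadcast_to(self.cached, (batch_size, self.features_dim))

cached_embedding = np.ones((3, 64))       # e.g. K=3 paths, emb_dim=64
extractor = CachedFeatures(cached_embedding)
print(extractor(batch_size=8).shape)      # (8, 192)
```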

Performance Considerations

Extractor                   Relative Speed   When to Use
--------------------------  ---------------  ------------------------------------------
PathGNN                     Standard         Dynamic graphs, general use
CachedPathGNN               Fastest          Static topology, inference-heavy workloads
GraphTransformerExtractor   Slowest          Experimental, attention analysis

Future Development (v6.X)

The following enhancements are planned:

  • Optuna Integration: Hyperparameter optimization for GNN architecture (embedding dimension, layer count, GNN type selection)

  • Additional Architectures: GIN (Graph Isomorphism Network), PNA (Principal Neighbourhood Aggregation)

  • Edge Features: Support for edge attributes in convolution layers

See TODO.md in the module directory for the current development roadmap.

File Reference

fusion/modules/rl/feat_extrs/
|-- __init__.py                  # Public exports
|-- README.md                    # Module documentation
|-- TODO.md                      # Development roadmap (BETA status)
|-- constants.py                 # Default values and paths
|-- base_feature_extractor.py    # BaseGraphFeatureExtractor
|-- path_gnn.py                  # PathGNN extractor
|-- path_gnn_cached.py           # CachedPathGNN, PathGNNEncoder
`-- graphormer.py                # GraphTransformerExtractor

What to Import:

# Feature extractors
from fusion.modules.rl.feat_extrs import (
    PathGNN,
    CachedPathGNN,
    PathGNNEncoder,
    GraphTransformerExtractor,
    BaseGraphFeatureExtractor,
)

# Constants
from fusion.modules.rl.feat_extrs import (
    DEFAULT_EMBEDDING_DIMENSION,
    DEFAULT_NUM_LAYERS,
    DEFAULT_GNN_TYPE,
    DEFAULT_NUM_HEADS,
)