"""
Multi-seed result aggregation for survivability experiments.
This module provides utilities for aggregating simulation results across
multiple random seeds, computing statistics like mean, standard deviation,
and 95% confidence intervals.
"""
from typing import Any
import numpy as np
from fusion.utils.logging_config import get_logger
logger = get_logger(__name__)
def aggregate_seed_results(
results: list[dict[str, Any]],
metric_keys: list[str] | None = None,
) -> dict[str, dict[str, float]]:
"""
Aggregate results across multiple seeds.
Computes mean, std, and 95% confidence intervals for specified metrics.
:param results: List of result dicts (one per seed)
:type results: list[dict[str, Any]]
:param metric_keys: Metrics to aggregate (if None, aggregates all numeric metrics)
:type metric_keys: list[str] | None
    :return: Per-metric stats with keys mean, std, ci95_lower, ci95_upper, and n
:rtype: dict[str, dict[str, float]]
Example:
>>> results = [
... {'bp_overall': 0.05, 'recovery_time_mean_ms': 52.3, 'seed': 42},
... {'bp_overall': 0.06, 'recovery_time_mean_ms': 51.8, 'seed': 43}
... ]
>>> metrics = ['bp_overall', 'recovery_time_mean_ms']
>>> agg = aggregate_seed_results(results, metrics)
        >>> print(agg['bp_overall'])  # values rounded for display
        {'mean': 0.055, 'std': 0.0071, 'ci95_lower': 0.0452, 'ci95_upper': 0.0648, 'n': 2}
"""
if not results:
logger.warning("No results to aggregate")
return {}
# Auto-detect numeric metrics if not specified
if metric_keys is None:
metric_keys = []
for key, value in results[0].items():
            # Exclude bools (a subclass of int) and the bookkeeping "seed" field
            if isinstance(value, (int, float)) and not isinstance(value, bool) and key != "seed":
                metric_keys.append(key)
aggregated = {}
for key in metric_keys:
        # Collect numeric values only; bools are excluded since isinstance(True, int) is True
        values = [r[key] for r in results
                  if key in r and isinstance(r[key], (int, float)) and not isinstance(r[key], bool)]
if not values:
logger.warning("No values found for metric: %s", key)
continue
mean_val = np.mean(values)
std_val = np.std(values, ddof=1) if len(values) > 1 else 0.0
        # Half-width of the 95% CI via the normal approximation (1.96 * standard error);
        # the empty-values case is already handled by the continue above
        ci95 = 1.96 * std_val / np.sqrt(len(values))
aggregated[key] = {
"mean": float(mean_val),
"std": float(std_val),
"ci95_lower": float(mean_val - ci95),
"ci95_upper": float(mean_val + ci95),
"n": len(values),
}
logger.info("Aggregated %d metrics across %d seeds", len(aggregated), len(results))
return aggregated
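
# The 1.96 multiplier used above is a normal (large-sample) approximation.
# With only a handful of seeds, a Student-t interval is wider and more
# accurate. A minimal sketch of that alternative, assuming scipy is
# installed (it is not otherwise required by this module); _t_ci95_halfwidth
# is a hypothetical helper, not part of the public API:
def _t_ci95_halfwidth(std: float, n: int) -> float:
    """Half-width of a 95% CI using Student's t-distribution (0.0 if n <= 1)."""
    if n <= 1:
        return 0.0
    from scipy.stats import t  # local import: scipy is optional here

    return float(t.ppf(0.975, df=n - 1) * std / np.sqrt(n))
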
def create_comparison_table(
baseline_results: list[dict[str, Any]],
rl_results: list[dict[str, Any]],
metrics: list[str],
) -> dict[str, dict[str, Any]]:
"""
    Create a per-metric comparison table of baseline vs. RL policy results.
:param baseline_results: Baseline results (one per seed)
:type baseline_results: list[dict[str, Any]]
:param rl_results: RL policy results (one per seed)
:type rl_results: list[dict[str, Any]]
:param metrics: Metrics to compare
:type metrics: list[str]
    :return: Per-metric comparison with baseline/RL means, stds, 95% CIs, improvement_pct, and sample counts
:rtype: dict[str, dict[str, Any]]
Example:
>>> baseline = [{'bp_overall': 0.10, 'seed': 42}]
>>> rl = [{'bp_overall': 0.08, 'seed': 42}]
>>> comp = create_comparison_table(baseline, rl, ['bp_overall'])
        >>> print(round(comp['bp_overall']['improvement_pct'], 1))
        20.0
"""
baseline_agg = aggregate_seed_results(baseline_results, metrics)
rl_agg = aggregate_seed_results(rl_results, metrics)
comparison = {}
for metric in metrics:
if metric not in baseline_agg or metric not in rl_agg:
logger.warning("Metric %s not found in both baseline and RL results", metric)
continue
baseline_val = baseline_agg[metric]["mean"]
rl_val = rl_agg[metric]["mean"]
        # Relative improvement: positive means RL lowered the metric, i.e. better
        # for lower-is-better metrics such as blocking probability (BP)
if baseline_val != 0:
improvement = ((baseline_val - rl_val) / baseline_val) * 100
else:
improvement = 0.0
comparison[metric] = {
"baseline_mean": baseline_val,
"baseline_std": baseline_agg[metric]["std"],
"baseline_ci95": (
baseline_agg[metric]["ci95_lower"],
baseline_agg[metric]["ci95_upper"],
),
"rl_mean": rl_val,
"rl_std": rl_agg[metric]["std"],
"rl_ci95": (rl_agg[metric]["ci95_lower"], rl_agg[metric]["ci95_upper"]),
"improvement_pct": improvement,
"n_baseline": baseline_agg[metric]["n"],
"n_rl": rl_agg[metric]["n"],
}
logger.info("Created comparison table for %d metrics", len(comparison))
return comparison
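
# Illustrative usage with made-up seeds and values (a sketch for demonstration
# only; real experiments would load per-seed result dicts from the simulation
# harness):
if __name__ == "__main__":
    baseline = [
        {"bp_overall": 0.10, "seed": 42},
        {"bp_overall": 0.12, "seed": 43},
        {"bp_overall": 0.11, "seed": 44},
    ]
    rl = [
        {"bp_overall": 0.08, "seed": 42},
        {"bp_overall": 0.09, "seed": 43},
        {"bp_overall": 0.07, "seed": 44},
    ]
    comp = create_comparison_table(baseline, rl, ["bp_overall"])
    for name, row in comp.items():
        print(
            f"{name}: baseline {row['baseline_mean']:.3f} +/- {row['baseline_std']:.3f}, "
            f"RL {row['rl_mean']:.3f} +/- {row['rl_std']:.3f}, "
            f"improvement {row['improvement_pct']:.1f}%"
        )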