Source code for fusion.configs.validate

"""Schema validation for FUSION configuration files."""

import copy
import json
import os
from pathlib import Path
from typing import Any

from fusion.utils.logging_config import get_logger

# Module-level logger; this module uses it for non-fatal warnings
# (unreadable schema files, requested schemas that were not loaded).
logger = get_logger(__name__)


class ValidationError(Exception):
    """
    Raised when a configuration fails schema or logical validation.

    The message carries a human-readable description of what was wrong.
    """

    pass


[docs] class SchemaValidator: """Schema validator for configuration files."""
[docs] def __init__(self, schema_dir: str | None = None): """ Initialize schema validator. :param schema_dir: Directory containing schema files :type schema_dir: str | None """ self.schema_dir = schema_dir or os.path.join(os.path.dirname(__file__), "schemas") self.schemas: dict[str, dict[str, Any]] = {} self._load_schemas()
def _load_schemas(self) -> None: """Load all schema files from schema directory.""" if not os.path.exists(self.schema_dir): return for schema_file in Path(self.schema_dir).glob("*.json"): schema_name = schema_file.stem try: with open(schema_file, encoding="utf-8") as f: self.schemas[schema_name] = json.load(f) except (OSError, json.JSONDecodeError) as e: logger.warning(f"Could not load schema {schema_file}: {e}")
[docs] def validate(self, config: dict[str, Any], schema_name: str = "main") -> None: """ Validate configuration against schema. :param config: Configuration dictionary to validate :type config: dict[str, Any] :param schema_name: Name of schema to use for validation, defaults to 'main' :type schema_name: str :raises ValidationError: If configuration is invalid """ if schema_name not in self.schemas: logger.warning(f"Schema '{schema_name}' not found, skipping validation") return schema = self.schemas[schema_name] errors = self._validate_recursive(config, schema, "") if errors: raise ValidationError("Configuration validation failed:\n" + "\n".join(errors))
def _validate_recursive(self, config: Any, schema: Any, path: str) -> list[str]: errors_list: list[str] = [] if isinstance(schema, dict): if "type" in schema: errors_list.extend(self._validate_type(config, schema, path)) if "required" in schema and isinstance(config, dict): errors_list.extend(self._validate_required_fields(config, schema["required"], path)) if "properties" in schema and isinstance(config, dict): for prop, prop_schema in schema["properties"].items(): if prop in config: prop_path = f"{path}.{prop}" if path else prop errors_list.extend(self._validate_recursive(config[prop], prop_schema, prop_path)) if "items" in schema and isinstance(config, list): for i, item in enumerate(config): item_path = f"{path}[{i}]" if path else f"[{i}]" errors_list.extend(self._validate_recursive(item, schema["items"], item_path)) return errors_list def _validate_type(self, value: Any, schema: dict[str, Any], path: str) -> list[str]: errors_list: list[str] = [] expected_type = schema["type"] type_map: dict[str, type | tuple[type, ...]] = { "string": str, "number": (int, float), "integer": int, "boolean": bool, "object": dict, "array": list, "null": type(None), } if expected_type in type_map: expected_python_type = type_map[expected_type] if not isinstance(value, expected_python_type): actual_type = type(value).__name__ errors_list.append(f"{path}: Expected {expected_type}, got {actual_type}") # Validate numeric constraints if expected_type == "number" and "minimum" in schema: if value < schema["minimum"]: errors_list.append(f"{path}: Value {value} is below minimum {schema['minimum']}") if expected_type == "number" and "maximum" in schema: if value > schema["maximum"]: errors_list.append(f"{path}: Value {value} is above maximum {schema['maximum']}") # Validate string enumeration if expected_type == "string" and "enum" in schema: if value not in schema["enum"]: errors_list.append(f"{path}: Value '{value}' not in allowed values: {schema['enum']}") return errors_list def 
_validate_required_fields(self, config: dict[str, Any], required: list[str], path: str) -> list[str]: errors_list: list[str] = [] for field in required: if field not in config: field_path = f"{path}.{field}" if path else field errors_list.append(f"{field_path}: Required field missing") return errors_list
[docs] def get_default_config(self, schema_name: str = "main") -> dict[str, Any]: """ Generate default configuration from schema. :param schema_name: Name of schema to use, defaults to 'main' :type schema_name: str :return: Default configuration dictionary :rtype: dict[str, Any] """ if schema_name not in self.schemas: return {} defaults = self._generate_defaults(self.schemas[schema_name]) return defaults if isinstance(defaults, dict) else {}
def _generate_defaults(self, schema: dict[str, Any]) -> Any: """ Generate default values from schema. :param schema: Schema to generate defaults from :type schema: dict[str, Any] :return: Default value based on schema :rtype: Any """ if "default" in schema: return schema["default"] if "type" not in schema: return None schema_type = schema["type"] if schema_type == "object" and "properties" in schema: result = {} for prop, prop_schema in schema["properties"].items(): result[prop] = self._generate_defaults(prop_schema) return result # Type defaults mapping type_defaults = { "array": [], "string": "", "number": 0.0, "integer": 0, "boolean": False, "null": None, } return type_defaults.get(schema_type, None)
def validate_survivability_config(config: dict[str, Any]) -> None:
    """
    Validate survivability-specific configuration.

    Schema validation against the ``survivability`` schema is best-effort:
    problems are logged as a warning rather than raised, so the logical checks
    (failure settings, protection, RL policy) always run. The logical checks
    do raise.

    :param config: Configuration dictionary
    :type config: dict[str, Any]
    :raises ValidationError: If a logical validation fails

    Example:
        >>> config = load_config('survivability_experiment.ini')
        >>> validate_survivability_config(config)
    """
    validator = SchemaValidator()
    try:
        validator.validate(config, "survivability")
    except ValidationError as exc:
        # Deliberately non-fatal, but surface the problems instead of
        # discarding them silently.
        logger.warning(f"Survivability schema validation failed; continuing with logical checks: {exc}")

    _validate_failure_config(config)
    _validate_protection_config(config)
    _validate_rl_policy_config(config)


def _validate_failure_config(config: dict[str, Any]) -> None: """Validate failure settings.""" failure_settings = config.get("failure_settings", {}) failure_type = failure_settings.get("failure_type", "none") # Type-specific validations if failure_type == "link": if "failed_link_src" not in failure_settings: raise ValidationError("Link failure requires 'failed_link_src'") if "failed_link_dst" not in failure_settings: raise ValidationError("Link failure requires 'failed_link_dst'") elif failure_type == "node": if "failed_node_id" not in failure_settings: raise ValidationError("Node failure requires 'failed_node_id'") elif failure_type == "srlg": srlg_links = failure_settings.get("srlg_links", []) if not srlg_links: raise ValidationError("SRLG failure requires non-empty 'srlg_links'") elif failure_type == "geo": if "geo_center_node" not in failure_settings: raise ValidationError("Geographic failure requires 'geo_center_node'") if "geo_hop_radius" not in failure_settings: raise ValidationError("Geographic failure requires 'geo_hop_radius'") def _validate_protection_config(config: dict[str, Any]) -> None: """ Validate protection settings. Note: Protection is enabled via route_method=1plus1_protection. This function is kept for future protection-related validations. 
""" pass # No validation needed - protection controlled by route_method def _validate_rl_policy_config(config: dict[str, Any]) -> None: """Validate RL policy settings.""" rl_settings = config.get("offline_rl_settings", {}) policy_type = rl_settings.get("policy_type", "ksp_ff") # Validate model paths for RL policies if policy_type == "bc": if "bc_model_path" not in rl_settings: raise ValidationError("BC policy requires 'bc_model_path'") model_path = Path(rl_settings["bc_model_path"]) if not model_path.exists(): raise ValidationError(f"BC model not found: {model_path}") elif policy_type == "iql": if "iql_model_path" not in rl_settings: raise ValidationError("IQL policy requires 'iql_model_path'") model_path = Path(rl_settings["iql_model_path"]) if not model_path.exists(): raise ValidationError(f"IQL model not found: {model_path}")