"""Configuration schema for pyhazards experiments (module ``pyhazards.configs._schema``)."""

from __future__ import annotations

from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List

import yaml

from ..tasks import get_hazard_task

# Report formats accepted by ``ReportConfig.formats``. Frozen so that no code
# path can accidentally add or remove an allowed format at runtime.
_REPORT_FORMATS = frozenset({"json", "md", "csv"})


@dataclass
class DatasetRef:
    """Names a dataset and carries the keyword parameters attached to it.

    This is a plain data holder: no validation is performed on construction.
    """

    # Identifier of the dataset (presumably resolved against a dataset
    # registry elsewhere in the package — confirm with the consumer).
    name: str
    # Free-form parameters for the dataset; each instance gets its own dict.
    params: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ModelRef:
    """Names a model, the task it targets, and its keyword parameters.

    A plain data holder; nothing is validated when an instance is built.
    """

    # Identifier of the model.
    name: str
    # Task label associated with the model (stored verbatim, not normalized).
    task: str
    # Free-form parameters for the model; each instance gets its own dict.
    params: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ReportConfig:
    """Where reports are written and in which formats.

    ``formats`` is normalized in ``__post_init__``:

    * a bare string (e.g. ``"json"``, as produced by ``formats: json`` in YAML)
      is treated as a one-element list;
    * every entry is converted to ``str``, stripped of surrounding whitespace
      and lower-cased;
    * duplicates are dropped, keeping the first occurrence's position.

    Raises:
        ValueError: if any normalized entry is not in ``_REPORT_FORMATS``.
    """

    output_dir: str = "reports"
    formats: List[str] = field(default_factory=lambda: ["json"])

    def __post_init__(self) -> None:
        requested = self.formats
        # Without this, a str would be iterated character by character and
        # rejected as the "formats" j, s, o, n.
        if isinstance(requested, str):
            requested = [requested]
        # dict.fromkeys de-duplicates while preserving first-seen order.
        normalized = list(dict.fromkeys(str(fmt).strip().lower() for fmt in requested))
        unknown = [fmt for fmt in normalized if fmt not in _REPORT_FORMATS]
        if unknown:
            raise ValueError(
                "Unknown report format(s): {unknown}. Known: {known}".format(
                    unknown=", ".join(sorted(set(unknown))),
                    known=", ".join(sorted(_REPORT_FORMATS)),
                )
            )
        self.formats = normalized
@dataclass
class BenchmarkConfig:
    """Benchmark section of an experiment configuration.

    On construction, ``hazard_task`` is canonicalized. The given identifier is
    looked up with ``get_hazard_task`` and replaced by the resolved task's
    ``name``. Any exception raised by that lookup propagates unchanged.
    """

    name: str
    # Replaced in __post_init__ by the canonical task name.
    hazard_task: str
    metrics: List[str] = field(default_factory=list)
    eval_split: str = "test"
    params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        resolved_task = get_hazard_task(self.hazard_task)
        self.hazard_task = resolved_task.name
@dataclass
class ExperimentConfig:
    """Complete experiment description: benchmark, dataset, model and reporting.

    ``report`` defaults to a fresh ``ReportConfig()``, ``seed`` to ``0`` and
    ``metadata`` to a fresh empty dict.
    """

    benchmark: BenchmarkConfig
    dataset: DatasetRef
    model: ModelRef
    report: ReportConfig = field(default_factory=ReportConfig)
    seed: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return this config as nested plain dicts.

        ``dataclasses.asdict`` recurses into the nested dataclasses and copies
        the contained values.
        """
        as_plain_dict = asdict(self)
        return as_plain_dict
def _config_section(raw: Dict[str, Any], key: str, *, required: bool) -> Dict[str, Any]:
    """Fetch section ``key`` of a parsed config document as a mapping.

    An optional section that is absent or explicitly ``null`` yields ``{}``.

    Raises:
        KeyError: if a required section is absent.
        TypeError: if the section is present but not a mapping (this includes
            an explicit ``null`` for a required section).
    """
    if key not in raw:
        if required:
            raise KeyError(f"Experiment config is missing required section {key!r}")
        return {}
    section = raw[key]
    if section is None and not required:
        return {}
    if not isinstance(section, dict):
        raise TypeError(
            f"Experiment config section {key!r} must be a mapping, "
            f"got {type(section).__name__}"
        )
    return section


def load_experiment_config(path: str | Path) -> ExperimentConfig:
    """Read a YAML experiment config from ``path`` and build an ``ExperimentConfig``.

    The file is read as UTF-8 and parsed with ``yaml.safe_load``, so arbitrary
    Python objects cannot be constructed. An empty file is treated as an empty
    mapping.

    * Required sections: ``benchmark``, ``dataset``, ``model``.
    * Optional sections: ``report``, ``seed`` and ``metadata``. Each falls
      back to its default when absent or ``null``.

    Raises:
        KeyError: if a required section is missing.
        TypeError: if the document or a section is not a mapping.
        ValueError: propagated from ``ReportConfig`` for unknown report formats.
    """
    raw = yaml.safe_load(Path(path).read_text(encoding="utf-8")) or {}
    if not isinstance(raw, dict):
        raise TypeError(
            f"Experiment config {str(path)!r} must be a YAML mapping, "
            f"got {type(raw).__name__}"
        )

    seed = raw.get("seed")
    metadata = raw.get("metadata")
    return ExperimentConfig(
        benchmark=BenchmarkConfig(**_config_section(raw, "benchmark", required=True)),
        dataset=DatasetRef(**_config_section(raw, "dataset", required=True)),
        model=ModelRef(**_config_section(raw, "model", required=True)),
        report=ReportConfig(**_config_section(raw, "report", required=False)),
        # ``seed:`` / ``metadata:`` left empty in YAML load as None; use defaults.
        seed=0 if seed is None else seed,
        metadata={} if metadata is None else metadata,
    )
def dump_experiment_config(config: ExperimentConfig, path: str | Path) -> None:
    """Write ``config`` to ``path`` as UTF-8 YAML, replacing any existing file.

    The payload comes from ``config.to_dict()``. It is emitted with
    ``yaml.safe_dump(..., sort_keys=False)``, so keys appear in dataclass field
    order rather than alphabetically.
    """
    payload = config.to_dict()
    destination = Path(path)
    document = yaml.safe_dump(payload, sort_keys=False)
    destination.write_text(document, encoding="utf-8")
# Public API of this module; underscore-prefixed names are not exported.
__all__ = [
    "BenchmarkConfig",
    "DatasetRef",
    "ExperimentConfig",
    "ModelRef",
    "ReportConfig",
    "dump_experiment_config",
    "load_experiment_config",
]