Source code for pyhazards.benchmarks.base
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Sequence
import torch.nn as nn
from ..configs import ExperimentConfig
from ..datasets.base import DataBundle
from ..reports import BenchmarkReport, export_report_bundle
from .schemas import BenchmarkResult
[docs]
class Benchmark(ABC):
    """Shared benchmark contract for hazard evaluators.

    Concrete benchmarks implement :meth:`evaluate`. Metric aggregation and
    report export are provided by this base class.
    """

    # Class-level defaults; no method of this base class reads them.
    name: str = "benchmark"
    hazard_task: str = ""

    @abstractmethod
    def evaluate(
        self,
        model: nn.Module,
        data: DataBundle,
        config: ExperimentConfig,
    ) -> BenchmarkResult:
        """Evaluate ``model`` on ``data`` under ``config``.

        Abstract: concrete benchmarks must override this method. The base
        implementation does nothing but raise.

        Raises:
            NotImplementedError: Always, when the base implementation runs.
        """
        raise NotImplementedError

    def aggregate_metrics(self, results: Sequence[BenchmarkResult]) -> Dict[str, float]:
        """Return the arithmetic mean of every metric found in ``results``.

        Each metric is averaged over the results that actually report it, so
        a metric missing from some results is not diluted by them. Values are
        converted with ``float()`` before summing.

        Args:
            results: Results whose ``metrics`` mappings are combined.

        Returns:
            Mapping of metric name to mean value, with keys inserted in
            sorted order. Empty when ``results`` is empty.
        """
        totals: Dict[str, float] = {}
        counts: Dict[str, int] = {}
        for result in results:
            for key, value in result.metrics.items():
                totals[key] = totals.get(key, 0.0) + float(value)
                counts[key] = counts.get(key, 0) + 1
        # Every key in ``totals`` has been counted at least once, so the
        # division needs no zero-count guard.
        return {key: totals[key] / counts[key] for key in sorted(totals)}

    @staticmethod
    def _normalize_formats(formats: Iterable[str]) -> list[str]:
        """Return ``formats`` as a list, treating a bare string as one name."""
        # A ``str`` is itself an iterable of one-character strings, so
        # ``list("json")`` would yield ``['j', 's', 'o', 'n']``.
        if isinstance(formats, str):
            return [formats]
        return list(formats)

    def export_report(
        self,
        result: BenchmarkResult,
        output_dir: str,
        formats: Iterable[str],
    ) -> Dict[str, str]:
        """Build a ``BenchmarkReport`` from ``result`` and export it.

        Args:
            result: Source of the report's benchmark name, hazard task,
                metrics, metadata and artifacts.
            output_dir: Passed through to ``export_report_bundle``.
            formats: Format names to export. A single string is accepted and
                treated as one format rather than as a sequence of characters.

        Returns:
            Whatever ``export_report_bundle`` returns for the report.
        """
        report = BenchmarkReport(
            benchmark_name=result.benchmark_name,
            hazard_task=result.hazard_task,
            metrics=result.metrics,
            metadata=result.metadata,
            artifacts=result.artifacts,
        )
        return export_report_bundle(
            report,
            output_dir=output_dir,
            formats=self._normalize_formats(formats),
        )