"""Runner utilities for :mod:`pyhazards.benchmarks`."""
from __future__ import annotations
from typing import Union
import torch.nn as nn
from ..configs import ExperimentConfig
from ..datasets.base import DataBundle
from .base import Benchmark
from .registry import build_benchmark
from .schemas import BenchmarkRunSummary
def resolve_benchmark(benchmark: Union[str, Benchmark]) -> Benchmark:
    """Coerce *benchmark* into a :class:`Benchmark` instance.

    An existing :class:`Benchmark` is returned unchanged; a string is
    treated as a registry name and built via :func:`build_benchmark`.
    """
    if not isinstance(benchmark, Benchmark):
        benchmark = build_benchmark(benchmark)
    return benchmark
def run_benchmark(
    benchmark: Union[str, Benchmark],
    model: nn.Module,
    data: DataBundle,
    config: ExperimentConfig,
    output_dir: str | None = None,
) -> BenchmarkRunSummary:
    """Evaluate *model* on one benchmark, export its report, and summarize.

    Parameters
    ----------
    benchmark:
        A :class:`Benchmark` instance or a registry name resolvable by
        :func:`resolve_benchmark`.
    model:
        The model under evaluation.
    data:
        Dataset bundle handed to the benchmark's ``evaluate``.
    config:
        Experiment configuration; supplies the report output directory,
        report formats, and the default ``eval_split`` metadata entry.
    output_dir:
        Optional override for the report directory; falls back to
        ``config.report.output_dir`` when not given (or falsy).

    Returns
    -------
    BenchmarkRunSummary
        Aggregated metrics, paths of the exported report files, and
        result metadata (with ``eval_split`` filled in if absent).
    """
    bench = resolve_benchmark(benchmark)
    outcome = bench.evaluate(model=model, data=data, config=config)

    # Aggregate over the single result and write the metrics back onto it.
    outcome.metrics = bench.aggregate_metrics([outcome])

    # `or` (not `is None`) so an empty override also falls back to config.
    target_dir = output_dir or config.report.output_dir
    paths = bench.export_report(
        outcome,
        output_dir=target_dir,
        formats=config.report.formats,
    )

    # Copy metadata before mutating so the result object stays untouched.
    meta = dict(outcome.metadata)
    if "eval_split" not in meta:
        meta["eval_split"] = config.benchmark.eval_split

    return BenchmarkRunSummary(
        benchmark_name=outcome.benchmark_name,
        hazard_task=outcome.hazard_task,
        metrics=outcome.metrics,
        report_paths=paths,
        metadata=meta,
    )