"""Tropical cyclone benchmark (``pyhazards.benchmarks.tc``)."""
from __future__ import annotations
import torch
import torch.nn as nn
from ..configs import ExperimentConfig
from ..datasets.base import DataBundle
from .base import Benchmark
from .registry import register_benchmark
from .schemas import BenchmarkResult
class TropicalCycloneBenchmark(Benchmark):
    """Benchmark for tropical cyclone track and intensity forecasting.

    Predictions and targets must share a trailing dimension of at least 3:
    the first two components are track coordinates and the third is
    intensity.  # NOTE(review): component ordering inferred from the metric
    # names below — confirm against the dataset's feature spec.
    """

    name = "tc"
    hazard_task = "tc.track_intensity"
    # Metric names reported for the single supported task.
    metric_names_by_task = {
        "tc.track_intensity": ["track_error", "intensity_mae"],
    }

    def evaluate(self, model: nn.Module, data: DataBundle, config: ExperimentConfig) -> BenchmarkResult:
        """Run ``model`` on the configured eval split and compute TC metrics.

        Args:
            model: Module mapping ``split.inputs`` to predictions with a
                trailing dimension of at least 3.
            data: Bundle providing the eval split plus dataset metadata.
            config: Experiment configuration; ``config.benchmark`` selects
                the eval split and the hazard task recorded in the result.

        Returns:
            BenchmarkResult carrying ``track_error`` (mean Euclidean
            distance over the first two prediction components) and
            ``intensity_mae`` (mean absolute error of the third component),
            plus split/dataset metadata.
        """
        split = data.get_split(config.benchmark.eval_split)
        # Inference only — no gradient graph is needed to compute metrics.
        with torch.no_grad():
            preds = model(split.inputs)
        targets = split.targets
        # Mean Euclidean distance over the two leading (track) components.
        track_error = torch.norm(preds[..., :2] - targets[..., :2], dim=-1).mean()
        # Mean absolute error over the intensity component.
        intensity_mae = torch.mean(torch.abs(preds[..., 2] - targets[..., 2]))
        metrics = {
            "track_error": float(track_error.cpu()),
            "intensity_mae": float(intensity_mae.cpu()),
        }
        return BenchmarkResult(
            benchmark_name=self.name,
            hazard_task=config.benchmark.hazard_task,
            metrics=metrics,
            metadata={
                "split": config.benchmark.eval_split,
                "dataset_name": data.metadata.get("dataset"),
                # Fall back to the plain dataset name when no explicit
                # source dataset is recorded.
                "source_dataset": data.metadata.get("source_dataset", data.metadata.get("dataset")),
                "history": data.metadata.get("history"),
                "horizon": data.feature_spec.extra.get("horizon") if data.feature_spec.extra else None,
            },
        )
# Public API of this module.
__all__ = ["TropicalCycloneBenchmark"]

# Register the benchmark under its short name ("tc") so it can be looked
# up through the benchmark registry.
register_benchmark(TropicalCycloneBenchmark.name, TropicalCycloneBenchmark)