"""Flood benchmark for :mod:`pyhazards`: streamflow and inundation evaluation."""
from __future__ import annotations
from typing import Dict
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ..configs import ExperimentConfig
from ..datasets.base import DataBundle
from ..datasets.graph import graph_collate
from .base import Benchmark
from .registry import register_benchmark
from .schemas import BenchmarkResult
class FloodBenchmark(Benchmark):
    """Benchmark for flood hazards.

    Supports two tasks, selected via ``config.benchmark.hazard_task``:

    * ``"flood.streamflow"`` — regression metrics (MAE, RMSE, NSE, KGE).
    * ``"flood.inundation"`` — depth/mask metrics (pixel MAE, IoU, F1).
    """

    name = "flood"
    # Default task identifier; the task actually evaluated is read from
    # config.benchmark.hazard_task at call time.
    hazard_task = "flood.streamflow"
    # Metric names reported for each supported task.
    metric_names_by_task = {
        "flood.streamflow": ["mae", "rmse", "nse", "kge"],
        "flood.inundation": ["pixel_mae", "iou", "f1"],
    }

    def evaluate(self, model: nn.Module, data: DataBundle, config: ExperimentConfig) -> BenchmarkResult:
        """Run ``model`` on the configured eval split and compute task metrics.

        Args:
            model: Trained model mapping split inputs to predictions.
            data: Bundle providing ``get_split`` and dataset ``metadata``.
            config: Experiment configuration; ``config.benchmark`` supplies
                ``eval_split`` and ``hazard_task``.

        Returns:
            A :class:`BenchmarkResult` with per-task metrics plus split and
            dataset metadata.
        """
        split = data.get_split(config.benchmark.eval_split)
        preds, targets = self._predict(model, split, config)
        if config.benchmark.hazard_task == "flood.inundation":
            metrics = self._inundation_metrics(preds, targets)
        else:
            metrics = self._streamflow_metrics(preds, targets)
        return BenchmarkResult(
            benchmark_name=self.name,
            hazard_task=config.benchmark.hazard_task,
            metrics=metrics,
            metadata={
                "split": config.benchmark.eval_split,
                "dataset_name": data.metadata.get("dataset"),
                "source_dataset": data.metadata.get("source_dataset", data.metadata.get("dataset")),
            },
        )

    @staticmethod
    def _predict(model: nn.Module, split, config: ExperimentConfig):
        """Produce (preds, targets) for a split, batching graph-style datasets.

        Graph-style inputs (sized, non-Tensor — e.g. a dataset of graph
        samples) are iterated through a DataLoader with ``graph_collate``;
        plain tensors are fed to the model in one call.
        """
        is_graph_dataset = (
            config.benchmark.hazard_task == "flood.streamflow"
            and hasattr(split.inputs, "__len__")
            and not isinstance(split.inputs, torch.Tensor)
        )
        # Evaluation only: disable autograd on BOTH paths. The original code
        # wrapped only the graph branch, so the tensor branch tracked
        # gradients needlessly.
        with torch.no_grad():
            if is_graph_dataset:
                loader = DataLoader(
                    split.inputs, batch_size=4, shuffle=False, collate_fn=graph_collate
                )
                preds_all = []
                target_all = []
                for batch, target in loader:
                    preds_all.append(model(batch))
                    target_all.append(target)
                preds = torch.cat(preds_all, dim=0)
                targets = torch.cat(target_all, dim=0)
            else:
                preds = model(split.inputs)
                targets = split.targets
        return preds, targets

    @staticmethod
    def _inundation_metrics(preds: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        """Compute pixel MAE, IoU and F1 for inundation depth predictions.

        Masks are derived by thresholding: predicted wet where depth >= 0.5,
        observed wet where depth > 0. Denominators are clamped to 1 to avoid
        division by zero when both masks are empty.
        """
        pred_depth = preds.float()
        target_depth = targets.float()
        pred_mask = (pred_depth >= 0.5).float()
        target_mask = (target_depth > 0).float()
        intersection = (pred_mask * target_mask).sum()
        union = pred_mask.sum() + target_mask.sum() - intersection
        return {
            "pixel_mae": float(torch.mean(torch.abs(pred_depth - target_depth)).detach().cpu()),
            "iou": float((intersection / union.clamp(min=1.0)).detach().cpu()),
            "f1": float(
                (
                    2 * intersection
                    / (pred_mask.sum() + target_mask.sum()).clamp(min=1.0)
                ).detach().cpu()
            ),
        }

    @staticmethod
    def _streamflow_metrics(preds: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        """Compute MAE, RMSE, Nash-Sutcliffe (NSE) and Kling-Gupta (KGE).

        Denominators (target variance, stds, target mean) are clamped to
        small positive values to keep the metrics finite on degenerate data.
        """
        mae = torch.mean(torch.abs(preds - targets))
        rmse = torch.sqrt(torch.mean((preds - targets) ** 2))
        target_mean = torch.mean(targets)
        denominator = torch.sum((targets - target_mean) ** 2).clamp(min=1e-6)
        nse = 1.0 - torch.sum((preds - targets) ** 2) / denominator
        pred_std = torch.std(preds).clamp(min=1e-6)
        target_std = torch.std(targets).clamp(min=1e-6)
        covariance = torch.mean((preds - torch.mean(preds)) * (targets - target_mean))
        correlation = covariance / (pred_std * target_std)
        alpha = pred_std / target_std
        beta = torch.mean(preds) / target_mean.clamp(min=1e-6)
        kge = 1.0 - torch.sqrt((correlation - 1.0) ** 2 + (alpha - 1.0) ** 2 + (beta - 1.0) ** 2)
        return {
            "mae": float(mae.detach().cpu()),
            "rmse": float(rmse.detach().cpu()),
            "nse": float(nse.detach().cpu()),
            "kge": float(kge.detach().cpu()),
        }
# Register the benchmark under its canonical name so the package registry
# can resolve "flood" to this class at lookup time.
register_benchmark(FloodBenchmark.name, FloodBenchmark)
# Public API of this module.
__all__ = ["FloodBenchmark"]