# Source code for pyhazards.benchmarks.wildfire

from __future__ import annotations

from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, average_precision_score, f1_score, roc_auc_score

from ..configs import ExperimentConfig
from ..datasets.base import DataBundle
from .base import Benchmark
from .registry import register_benchmark
from .schemas import BenchmarkResult


def _spread_metrics(logits: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
    probs = torch.sigmoid(logits)
    preds = (probs >= 0.5).float()
    targets = targets.float()
    intersection = (preds * targets).sum()
    union = preds.sum() + targets.sum() - intersection
    iou = float((intersection / union.clamp(min=1.0)).detach().cpu())
    f1 = float((2 * intersection / (preds.sum() + targets.sum()).clamp(min=1.0)).detach().cpu())
    burned_area_mae = float(
        torch.mean(torch.abs(preds.flatten(1).sum(dim=1) - targets.flatten(1).sum(dim=1))).detach().cpu()
    )
    return {"iou": iou, "f1": f1, "burned_area_mae": burned_area_mae}


def _danger_metrics(logits: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
    if targets.dtype in {torch.int32, torch.int64} or targets.ndim == 1:
        preds = logits.argmax(dim=1)
        probs = F.softmax(logits, dim=1)
        y_true = targets.detach().cpu().numpy()
        y_pred = preds.detach().cpu().numpy()
        y_score = probs.detach().cpu().numpy()
        one_hot = F.one_hot(targets.long(), num_classes=logits.size(1)).detach().cpu().numpy()
        try:
            auc = float(roc_auc_score(one_hot, y_score, average="macro", multi_class="ovr"))
        except ValueError:
            auc = 0.0
        try:
            pr_auc = float(average_precision_score(one_hot, y_score, average="macro"))
        except ValueError:
            pr_auc = 0.0
        return {
            "accuracy": float(accuracy_score(y_true, y_pred)),
            "macro_f1": float(f1_score(y_true, y_pred, average="macro")),
            "auc": auc,
            "pr_auc": pr_auc,
        }

    preds = logits.float()
    targets = targets.float()
    mae = torch.mean(torch.abs(preds - targets))
    rmse = torch.sqrt(torch.mean((preds - targets) ** 2))
    return {
        "mae": float(mae.detach().cpu()),
        "rmse": float(rmse.detach().cpu()),
    }


class WildfireBenchmark(Benchmark):
    """Benchmark for the wildfire danger and spread tasks.

    ``config.benchmark.hazard_task`` selects the metric family:
    ``"wildfire.danger"`` is scored by ``_danger_metrics``; every other task
    value falls through to ``_spread_metrics``.
    """

    name = "wildfire"
    hazard_task = "wildfire.danger"
    # Metric names reported per task. The danger list is the union of the
    # classification and regression keys that ``_danger_metrics`` can return.
    metric_names_by_task = {
        "wildfire.danger": ["accuracy", "macro_f1", "auc", "pr_auc", "mae", "rmse"],
        "wildfire.spread": ["iou", "f1", "burned_area_mae"],
    }

    def evaluate(self, model: nn.Module, data: DataBundle, config: ExperimentConfig) -> BenchmarkResult:
        """Run ``model`` on the configured evaluation split and score its outputs.

        The forward pass runs with the model in eval mode and autograd
        disabled, so dropout/batch-norm do not perturb the metrics and no
        graph is kept. The model's previous training mode is restored
        afterwards, even if the forward pass raises.

        Args:
            model: Module called once on the whole split's inputs.
            data: Bundle providing the split via ``get_split`` and a
                ``metadata`` mapping.
            config: Supplies ``benchmark.eval_split`` and ``benchmark.hazard_task``.

        Returns:
            A ``BenchmarkResult`` with the computed metrics and metadata
            recording the split name, ``dataset_name`` and ``source_dataset``
            (the latter falls back to the dataset name).
        """
        task = config.benchmark.hazard_task
        eval_split = config.benchmark.eval_split
        split = data.get_split(eval_split)
        x = split.inputs
        y = split.targets

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                logits = model(x)
        finally:
            # Do not leave the caller's model in a different mode than we found it.
            model.train(was_training)

        if task == "wildfire.danger":
            metrics = _danger_metrics(logits, y)
        else:
            metrics = _spread_metrics(logits, y)

        return BenchmarkResult(
            benchmark_name=self.name,
            hazard_task=task,
            metrics=metrics,
            metadata={
                "split": eval_split,
                "dataset_name": data.metadata.get("dataset"),
                "source_dataset": data.metadata.get("source_dataset", data.metadata.get("dataset")),
            },
        )
# Names exported by ``from ... import *``.
__all__ = ["WildfireBenchmark"]

# Register the benchmark class under its ``name`` attribute at import time.
register_benchmark(WildfireBenchmark.name, WildfireBenchmark)