Source code for pyhazards.metrics

from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F


[docs] class MetricBase(ABC):
[docs] @abstractmethod def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: ...
[docs] @abstractmethod def compute(self) -> Dict[str, float]: ...
[docs] @abstractmethod def reset(self) -> None: ...
[docs] class ClassificationMetrics(MetricBase): def __init__(self): self.reset()
[docs] def reset(self) -> None: self._preds: List[torch.Tensor] = [] self._targets: List[torch.Tensor] = []
[docs] def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: self._preds.append(preds.detach().cpu()) self._targets.append(targets.detach().cpu())
[docs] def compute(self) -> Dict[str, float]: preds = torch.cat(self._preds) targets = torch.cat(self._targets) pred_labels = preds.argmax(dim=-1) acc = (pred_labels == targets).float().mean().item() return {"Acc": acc}
[docs] class RegressionMetrics(MetricBase): def __init__(self): self.reset()
[docs] def reset(self) -> None: self._preds: List[torch.Tensor] = [] self._targets: List[torch.Tensor] = []
[docs] def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: self._preds.append(preds.detach().cpu()) self._targets.append(targets.detach().cpu())
[docs] def compute(self) -> Dict[str, float]: preds = torch.cat(self._preds) targets = torch.cat(self._targets) mae = F.l1_loss(preds, targets).item() rmse = torch.sqrt(F.mse_loss(preds, targets)).item() return {"MAE": mae, "RMSE": rmse}
[docs] class SegmentationMetrics(MetricBase): def __init__(self, num_classes: Optional[int] = None): self.num_classes = num_classes self.reset()
[docs] def reset(self) -> None: self._preds: List[torch.Tensor] = [] self._targets: List[torch.Tensor] = []
[docs] def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: self._preds.append(preds.detach().cpu()) self._targets.append(targets.detach().cpu())
[docs] def compute(self) -> Dict[str, float]: preds = torch.cat(self._preds) targets = torch.cat(self._targets) pred_labels = preds.argmax(dim=1) # simple pixel accuracy; extend to IoU/Dice as needed acc = (pred_labels == targets).float().mean().item() return {"PixelAcc": acc}
__all__ = ["MetricBase", "ClassificationMetrics", "RegressionMetrics", "SegmentationMetrics"]