Source code for pyhazards.models.wildfire_aspp
from __future__ import annotations
import torch
import torch.nn as nn
from .cnn_aspp import WildfireCNNASPP, cnn_aspp_builder
[docs]
class WildfireASPP(WildfireCNNASPP):
"""
Backward-compatible name for the CNN + ASPP wildfire model.
"""
[docs]
def wildfire_aspp_builder(task: str, **kwargs) -> nn.Module:
return cnn_aspp_builder(task=task, **kwargs)
[docs]
class TverskyLoss(nn.Module):
"""
Tversky loss for binary segmentation.
"""
def __init__(
self,
alpha: float = 0.5,
beta: float = 0.5,
smooth: float = 1e-6,
from_logits: bool = True,
):
super().__init__()
self.alpha = float(alpha)
self.beta = float(beta)
self.smooth = float(smooth)
self.from_logits = bool(from_logits)
[docs]
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
if self.from_logits:
probs = torch.sigmoid(logits)
else:
probs = logits
targets = targets.float()
probs = probs.view(probs.size(0), -1)
targets = targets.view(targets.size(0), -1)
tp = (probs * targets).sum(dim=1)
fp = (probs * (1 - targets)).sum(dim=1)
fn = ((1 - probs) * targets).sum(dim=1)
tversky = (tp + self.smooth) / (
tp + self.alpha * fp + self.beta * fn + self.smooth
)
loss = 1.0 - tversky
return loss.mean()