Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for pyhazards.models.wildfire_aspp
from __future__ import annotations
import torch
import torch.nn as nn
from .cnn_aspp import WildfireCNNASPP , cnn_aspp_builder
[docs]
class WildfireASPP(WildfireCNNASPP):
    """Legacy alias of :class:`WildfireCNNASPP` (the CNN + ASPP wildfire model).

    This subclass adds no attributes or methods of its own. It exists only so
    that code referring to the model by this older name keeps working.
    """
[docs]
def wildfire_aspp_builder(*args, **kwargs) -> nn.Module:
    """Legacy builder name that delegates to ``cnn_aspp_builder``.

    Every positional and keyword argument is passed through untouched, and
    whatever ``cnn_aspp_builder`` returns is handed back as-is.
    """
    model = cnn_aspp_builder(*args, **kwargs)
    return model
[docs]
class TverskyLoss(nn.Module):
    """
    Tversky loss for binary segmentation.

    For each sample in the batch, all non-batch elements are flattened and the
    soft Tversky index is computed as::

        TI = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)

    where ``TP = sum(p * t)``, ``FP = sum(p * (1 - t))`` and
    ``FN = sum((1 - p) * t)``. The per-sample loss is ``1 - TI``.

    Args:
        alpha: Weight on soft false positives.
        beta: Weight on soft false negatives.
        smooth: Constant added to numerator and denominator to avoid division
            by zero (e.g. for samples with no positives and no predictions).
        from_logits: If True, ``forward`` applies ``torch.sigmoid`` to its
            first argument; otherwise that argument is used as probabilities.
        reduction: How per-sample losses are combined: ``"mean"`` (default,
            the original behavior), ``"sum"``, or ``"none"`` (returns a tensor
            of shape ``(batch,)``).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(
        self,
        alpha: float = 0.5,
        beta: float = 0.5,
        smooth: float = 1e-6,
        from_logits: bool = True,
        reduction: str = "mean",
    ):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.smooth = float(smooth)
        self.from_logits = bool(from_logits)
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the Tversky loss.

        Args:
            logits: Raw scores (if ``from_logits``) or probabilities, with the
                batch on dim 0.
            targets: Binary ground truth with the batch on dim 0 and the same
                number of elements per sample as ``logits``.

        Returns:
            A scalar tensor for ``"mean"``/``"sum"`` reduction, otherwise the
            per-sample losses of shape ``(batch,)``.
        """
        probs = torch.sigmoid(logits) if self.from_logits else logits
        # reshape (not view): inputs may be non-contiguous, e.g. after a
        # permute or in channels-last layout, where .view() would raise.
        probs = probs.reshape(probs.size(0), -1)
        targets = targets.float().reshape(targets.size(0), -1)

        tp = (probs * targets).sum(dim=1)
        fp = (probs * (1 - targets)).sum(dim=1)
        fn = ((1 - probs) * targets).sum(dim=1)

        tversky = (tp + self.smooth) / (
            tp + self.alpha * fp + self.beta * fn + self.smooth
        )
        loss = 1.0 - tversky

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

    def extra_repr(self) -> str:
        return (
            f"alpha={self.alpha}, beta={self.beta}, smooth={self.smooth}, "
            f"from_logits={self.from_logits}, reduction={self.reduction!r}"
        )