Toggle Light / Dark / Auto color theme
Toggle table of contents sidebar
Source code for pyhazards.metrics
from abc import ABC , abstractmethod
from typing import Dict , List , Optional
import torch
import torch.nn.functional as F
[docs]
class MetricBase(ABC):
    """Abstract interface for metric objects.

    A concrete metric must implement :meth:`update`, :meth:`compute` and
    :meth:`reset`; the class cannot be instantiated until all three exist.
    """

    @abstractmethod
    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Consume one batch of predictions together with its targets."""

    @abstractmethod
    def compute(self) -> Dict[str, float]:
        """Return the metric values as a mapping from name to float."""

    @abstractmethod
    def reset(self) -> None:
        """Return the metric to its initial state."""
[docs]
class ClassificationMetrics(MetricBase):
    """Accumulate classification outputs over batches and report accuracy.

    Predictions are turned into class indices with ``argmax`` over the last
    dimension, so ``preds`` must carry per-class scores on that axis.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Drop every batch stored so far."""
        self._preds: List[torch.Tensor] = []
        self._targets: List[torch.Tensor] = []

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Store one batch, detached from autograd and copied to the CPU."""
        self._preds.append(preds.detach().cpu())
        self._targets.append(targets.detach().cpu())

    def compute(self) -> Dict[str, float]:
        """Return ``{"Acc": accuracy}`` over all stored batches.

        Targets with one extra trailing singleton axis (for example ``(N, 1)``
        against ``(N,)`` predicted labels) are squeezed before comparison.

        Raises:
            RuntimeError: If no batch has been stored via :meth:`update`.
        """
        if not self._preds:
            # torch.cat on an empty list also raises RuntimeError, but with a
            # message that does not point at the missing update() call.
            raise RuntimeError(
                "ClassificationMetrics.compute() called before any update()"
            )
        preds = torch.cat(self._preds)
        targets = torch.cat(self._targets)
        pred_labels = preds.argmax(dim=-1)
        # Without this, (N,) == (N, 1) broadcasts to an (N, N) comparison and
        # every prediction is scored against every target.
        if tuple(targets.shape) == tuple(pred_labels.shape) + (1,):
            targets = targets.squeeze(-1)
        acc = (pred_labels == targets).float().mean().item()
        return {"Acc": acc}
[docs]
class RegressionMetrics(MetricBase):
    """Accumulate regression outputs over batches and report MAE and RMSE."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Drop every batch stored so far."""
        self._preds: List[torch.Tensor] = []
        self._targets: List[torch.Tensor] = []

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Store one batch, detached from autograd and copied to the CPU."""
        self._preds.append(preds.detach().cpu())
        self._targets.append(targets.detach().cpu())

    def compute(self) -> Dict[str, float]:
        """Return ``{"MAE": ..., "RMSE": ...}`` over all stored batches.

        MAE is ``F.l1_loss`` and RMSE is the square root of ``F.mse_loss``,
        both with their default reduction. If one tensor has exactly one
        extra trailing singleton axis compared with the other (for example
        ``(N, 1)`` against ``(N,)``), that axis is squeezed first.

        Raises:
            RuntimeError: If no batch has been stored via :meth:`update`.
        """
        if not self._preds:
            # torch.cat on an empty list also raises RuntimeError, but with a
            # message that does not point at the missing update() call.
            raise RuntimeError(
                "RegressionMetrics.compute() called before any update()"
            )
        preds = torch.cat(self._preds)
        targets = torch.cat(self._targets)
        # The loss functions broadcast (N, 1) against (N,) into an (N, N)
        # grid of pairwise errors, which silently yields the wrong metric.
        if tuple(preds.shape) == tuple(targets.shape) + (1,):
            preds = preds.squeeze(-1)
        elif tuple(targets.shape) == tuple(preds.shape) + (1,):
            targets = targets.squeeze(-1)
        mae = F.l1_loss(preds, targets).item()
        rmse = torch.sqrt(F.mse_loss(preds, targets)).item()
        return {"MAE": mae, "RMSE": rmse}
[docs]
class SegmentationMetrics(MetricBase):
    """Accumulate segmentation outputs over batches and report pixel accuracy.

    Predictions are turned into class indices with ``argmax`` over dimension 1,
    so ``preds`` must carry per-class scores on that axis.

    Args:
        num_classes: Stored on the instance as ``self.num_classes``. It is not
            read by :meth:`compute`, which only reports pixel accuracy.
    """

    def __init__(self, num_classes: Optional[int] = None) -> None:
        self.num_classes = num_classes
        self.reset()

    def reset(self) -> None:
        """Drop every batch stored so far."""
        self._preds: List[torch.Tensor] = []
        self._targets: List[torch.Tensor] = []

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """Store one batch, detached from autograd and copied to the CPU."""
        self._preds.append(preds.detach().cpu())
        self._targets.append(targets.detach().cpu())

    def compute(self) -> Dict[str, float]:
        """Return ``{"PixelAcc": accuracy}`` over all stored batches.

        Targets that keep a singleton axis at dimension 1 (for example
        ``(N, 1, H, W)`` against ``(N, H, W)`` predicted labels) have that
        axis squeezed before comparison.

        Raises:
            RuntimeError: If no batch has been stored via :meth:`update`.
        """
        if not self._preds:
            # torch.cat on an empty list also raises RuntimeError, but with a
            # message that does not point at the missing update() call.
            raise RuntimeError(
                "SegmentationMetrics.compute() called before any update()"
            )
        preds = torch.cat(self._preds)
        targets = torch.cat(self._targets)
        pred_labels = preds.argmax(dim=1)
        # Without this, (N, ...) == (N, 1, ...) broadcasts across the batch
        # axis and compares every predicted mask with every target mask.
        label_shape = tuple(pred_labels.shape)
        if tuple(targets.shape) == label_shape[:1] + (1,) + label_shape[1:]:
            targets = targets.squeeze(1)
        # simple pixel accuracy; extend to IoU/Dice as needed
        acc = (pred_labels == targets).float().mean().item()
        return {"PixelAcc": acc}
# Names exported by a star-import of this module: the abstract base class
# followed by the three concrete metric classes.
__all__ = [ "MetricBase" , "ClassificationMetrics" , "RegressionMetrics" , "SegmentationMetrics" ]