Source code for pyhazards.models.heads

import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """Simple classification head.

    Wraps a single fully connected layer that projects features of size
    ``in_dim`` to ``num_classes`` outputs. No activation is applied, so the
    outputs are unnormalized per-class scores.

    Args:
        in_dim: Size of the last dimension of the input features.
        num_classes: Number of output classes.
    """

    def __init__(self, in_dim: int, num_classes: int) -> None:
        super().__init__()
        self.fc = nn.Linear(in_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to per-class scores.

        Args:
            x: Tensor whose last dimension has size ``in_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``num_classes``.
        """
        return self.fc(x)
class RegressionHead(nn.Module):
    """Regression head for scalar or multi-target outputs.

    Wraps a single fully connected layer that projects features of size
    ``in_dim`` to ``out_dim`` regression targets. No activation is applied.

    Args:
        in_dim: Size of the last dimension of the input features.
        out_dim: Number of regression targets; defaults to 1 (scalar output).
    """

    def __init__(self, in_dim: int, out_dim: int = 1) -> None:
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to the regression targets.

        Args:
            x: Tensor whose last dimension has size ``in_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_dim``.
        """
        return self.fc(x)
class SegmentationHead(nn.Module):
    """Segmentation head for raster masks.

    Wraps a single 1x1 convolution that maps ``in_channels`` feature channels
    to ``num_classes`` output channels. With a 1x1 kernel and the default
    stride and padding, the spatial size of the input is preserved. No
    activation is applied.

    Args:
        in_channels: Number of channels in the input feature map.
        num_classes: Number of output channels (one per class).
    """

    def __init__(self, in_channels: int, num_classes: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a feature map to per-class score maps.

        Args:
            x: Feature map accepted by ``nn.Conv2d`` with ``in_channels``
                channels, e.g. shaped ``(batch, in_channels, H, W)``.

        Returns:
            Tensor with ``num_classes`` channels and the same spatial size
            as ``x``.
        """
        return self.conv(x)
# Names exported by ``from <module> import *``.
__all__ = [
    "ClassificationHead",
    "RegressionHead",
    "SegmentationHead",
]