class RegressionHead(nn.Module):
    """Regression head for scalar or multi-target outputs.

    Wraps a single fully connected layer, stored as ``self.fc``.

    Args:
        in_dim: Number of input features given to the linear layer.
        out_dim: Number of output features of the linear layer.
            Defaults to 1.
    """

    # NOTE(review): no forward() is visible in this excerpt -- confirm
    # whether it is defined elsewhere or callers use ``self.fc`` directly.

    def __init__(self, in_dim: int, out_dim: int = 1) -> None:
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
class SegmentationHead(nn.Module):
    """Segmentation head for raster masks.

    Wraps a single 2-D convolution with ``kernel_size=1``, stored as
    ``self.conv``.

    Args:
        in_channels: Number of input channels given to the convolution.
        num_classes: Number of output channels of the convolution.
    """

    # NOTE(review): no forward() is visible in this excerpt -- confirm
    # whether it is defined elsewhere or callers use ``self.conv`` directly.

    def __init__(self, in_channels: int, num_classes: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)