# Source code for pyhazards.models.cnn_aspp

from __future__ import annotations

from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------
# Basic blocks
# ---------------------------------------------------------------------

class ConvBNReLU(nn.Module):
    """
    Conv2d -> BatchNorm2d -> ReLU, applied in that order.

    The convolution is created without a bias term. ``k``, ``s``, ``p`` and
    ``d`` are forwarded to ``nn.Conv2d`` as kernel_size, stride, padding and
    dilation respectively.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        k: int = 3,
        s: int = 1,
        p: int = 1,
        d: int = 1,
    ):
        super().__init__()
        conv_options = dict(
            kernel_size=k,
            stride=s,
            padding=p,
            dilation=d,
            bias=False,
        )
        # Submodule names (conv / bn / act) are part of the state_dict keys.
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_options)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the convolution, the batch norm and the ReLU."""
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.act(normalized)


# ---------------------------------------------------------------------
# ASPP
# ---------------------------------------------------------------------

class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling (ASPP).

    Five branches read the same input: a 1x1 convolution (``b1``), three 3x3
    dilated convolutions (``b2``-``b4``) and a globally average-pooled branch
    (``pool``) that is resized back to the input's height and width. The
    branch outputs are concatenated along the channel axis and reduced to
    ``out_ch`` channels by a 1x1 projection (``proj``).

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by every branch and by the
            final projection.
        dilations: exactly four dilation rates, one per ``b1``..``b4``.

    Raises:
        ValueError: if ``dilations`` does not contain exactly four entries.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        dilations: Sequence[int] = (1, 3, 6, 12),
    ):
        super().__init__()

        num_rates = len(dilations)
        if num_rates != 4:
            raise ValueError("ASPP expects exactly 4 dilation rates")

        rate1, rate2, rate3, rate4 = dilations

        # 1x1 branch; the first rate is still forwarded as its dilation.
        self.b1 = ConvBNReLU(in_ch, out_ch, k=1, p=0, d=rate1)
        # 3x3 branches: padding equal to the dilation keeps H and W unchanged.
        self.b2 = ConvBNReLU(in_ch, out_ch, k=3, p=rate2, d=rate2)
        self.b3 = ConvBNReLU(in_ch, out_ch, k=3, p=rate3, d=rate3)
        self.b4 = ConvBNReLU(in_ch, out_ch, k=3, p=rate4, d=rate4)

        # Image-level branch: squeeze to 1x1, then a 1x1 Conv-BN-ReLU.
        # NOTE(review): BatchNorm on a 1x1 map presumably fails in training
        # mode when the batch size is 1 -- confirm against callers.
        image_pool = nn.AdaptiveAvgPool2d(1)
        image_conv = ConvBNReLU(in_ch, out_ch, k=1, p=0)
        self.pool = nn.Sequential(image_pool, image_conv)

        # Five branches of out_ch channels each feed the projection.
        self.proj = ConvBNReLU(out_ch * 5, out_ch, k=1, p=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fuse the five branches for a 4-D input ``x`` and project them."""
        _, _, height, width = x.shape

        pooled = F.interpolate(
            self.pool(x),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

        branch_outputs = [
            branch(x) for branch in (self.b1, self.b2, self.b3, self.b4)
        ]
        branch_outputs.append(pooled)

        fused = torch.cat(branch_outputs, dim=1)
        return self.proj(fused)


# ---------------------------------------------------------------------
# CNN + ASPP model
# ---------------------------------------------------------------------

class WildfireCNNASPP(nn.Module):
    """
    CNN + ASPP wildfire segmentation model.

    Input:
        x : (B, C, H, W) float tensor

    Output:
        logits : (B, 1, H, W) float tensor (sigmoid applied externally)

    Args:
        in_channels: number of channels of the input tensor.
        base_channels: width of the two-layer convolutional stem.
        aspp_channels: number of channels produced by the ASPP module.
        dilations: dilation rates forwarded to ``ASPP``.
        dropout: ``Dropout2d`` probability applied after ASPP; values
            ``<= 0`` replace the dropout layer with ``nn.Identity``.
    """

    def __init__(
        self,
        in_channels: int = 12,
        base_channels: int = 32,
        aspp_channels: int = 32,
        dilations: Sequence[int] = (1, 3, 6, 12),
        dropout: float = 0.0,
    ):
        super().__init__()

        # Two 3x3 Conv-BN-ReLU layers with padding 1 (spatial size preserved).
        self.stem = nn.Sequential(
            ConvBNReLU(in_channels, base_channels, k=3, p=1),
            ConvBNReLU(base_channels, base_channels, k=3, p=1),
        )

        self.aspp = ASPP(
            in_ch=base_channels,
            out_ch=aspp_channels,
            dilations=dilations,
        )

        self.drop = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
        # 1x1 convolution mapping ASPP features to a single logit channel.
        self.head = nn.Conv2d(aspp_channels, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute per-pixel logits for ``x``.

        Raises:
            ValueError: if ``x`` is not a 4-D tensor.
        """
        if x.ndim != 4:
            raise ValueError(
                f"Expected input of shape (B,C,H,W), got {tuple(x.shape)}"
            )

        f = self.stem(x)
        y = self.aspp(f)
        y = self.drop(y)
        return self.head(y)
# ---------------------------------------------------------------------
# PyHazards model builder
# ---------------------------------------------------------------------
def cnn_aspp_builder(
    task: str,
    in_channels: int = 12,
    base_channels: int = 32,
    aspp_channels: int = 32,
    dilations: Sequence[int] = (1, 3, 6, 12),
    dropout: float = 0.0,
    **kwargs,
) -> nn.Module:
    """
    PyHazards-style model builder.

    Args:
        task: task identifier; must contain the substring ``"segmentation"``.
        in_channels: forwarded to ``WildfireCNNASPP``.
        base_channels: forwarded to ``WildfireCNNASPP``.
        aspp_channels: forwarded to ``WildfireCNNASPP``.
        dilations: forwarded to ``WildfireCNNASPP``.
        dropout: forwarded to ``WildfireCNNASPP``.
        **kwargs: accepted and ignored.

    Returns:
        A newly constructed ``WildfireCNNASPP``.

    Raises:
        ValueError: if ``task`` does not contain ``"segmentation"``.
    """
    _ = kwargs  # explicitly ignore unused builder args

    if "segmentation" not in task:
        raise ValueError(
            f"WildfireCNNASPP is segmentation-only. Got task='{task}'"
        )

    return WildfireCNNASPP(
        in_channels=in_channels,
        base_channels=base_channels,
        aspp_channels=aspp_channels,
        dilations=dilations,
        dropout=dropout,
    )