# Source code for pyhazards.models.wildfire_mamba

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def _normalize_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """
    Row-normalize an adjacency matrix and add self-loops.
    Accepts (N, N) or (B, N, N); always returns a batched (B, N, N) tensor.
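
    Example (illustrative doctest-style shape check):

        >>> _normalize_adjacency(torch.zeros(4, 4)).shape
        torch.Size([1, 4, 4])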
    """
    if adj.dim() == 2:
        adj = adj.unsqueeze(0)
    eye = torch.eye(adj.size(-1), device=adj.device, dtype=adj.dtype)
    adj = adj.float() + eye.unsqueeze(0)
    return adj / adj.sum(-1, keepdim=True).clamp(min=1e-6)


class SelectiveSSMBlock(nn.Module):
    """
    Lightweight selective state-space block inspired by Mamba.

    Operates over a single temporal stream: (batch, time, features) -> (batch, time, hidden_dim).
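
    Example (illustrative, arbitrary sizes):

        >>> block = SelectiveSSMBlock(in_dim=8, hidden_dim=32)
        >>> block(torch.randn(2, 10, 8)).shape
        torch.Size([2, 10, 32])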
    """

    def __init__(self, in_dim: int, hidden_dim: int, state_dim: int = 64, conv_kernel: int = 5, dropout: float = 0.1):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim)
        # Depthwise conv mixes nearby timesteps before the recurrent scan.
        self.dwconv = nn.Conv1d(
            hidden_dim, hidden_dim, kernel_size=conv_kernel, padding=conv_kernel // 2, groups=hidden_dim
        )
        self.gate = nn.Linear(hidden_dim, hidden_dim)
        # Low-rank state transition: A @ B maps the hidden state back onto itself cheaply.
        self.A = nn.Parameter(torch.randn(hidden_dim, state_dim) * 0.02)
        self.B = nn.Parameter(torch.randn(state_dim, hidden_dim) * 0.02)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, F)
        h = self.in_proj(x)  # (B, T, H)
        h_conv = self.dwconv(h.transpose(1, 2)).transpose(1, 2)  # depthwise conv over time, still (B, T, H)
        g = torch.sigmoid(self.gate(h_conv))  # data-dependent gate in (0, 1)
        B, T, H = h_conv.shape
        state = torch.zeros(B, H, device=h_conv.device, dtype=h_conv.dtype)
        outputs = []
        for t in range(T):
            # selective update: gates decide how much new signal to mix into the running state
            state = g[:, t, :] * (state @ self.A @ self.B + h_conv[:, t, :]) + (1 - g[:, t, :]) * state
            outputs.append(state)
        y = torch.stack(outputs, dim=1)
        y = self.out_proj(self.drop(y)) + h_conv  # residual connection around the scan
        return self.norm(y)


class MambaTemporalEncoder(nn.Module):
    """Stack of selective SSM blocks; returns the last hidden state."""

    def __init__(self, in_dim: int, hidden_dim: int = 128, num_layers: int = 2, state_dim: int = 64, conv_kernel: int = 5, dropout: float = 0.1):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                SelectiveSSMBlock(
                    in_dim=in_dim if i == 0 else hidden_dim,
                    hidden_dim=hidden_dim,
                    state_dim=state_dim,
                    conv_kernel=conv_kernel,
                    dropout=dropout,
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x
        for block in self.blocks:
            h = block(h)
        return h[:, -1, :]
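
# Example (illustrative, arbitrary sizes): the encoder collapses the time axis,
# turning (batch, time, features) into (batch, hidden_dim):
#
#     enc = MambaTemporalEncoder(in_dim=8, hidden_dim=32, num_layers=2)
#     enc(torch.randn(4, 30, 8)).shape  # -> torch.Size([4, 32])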


class SimpleGCN(nn.Module):
    """Two-layer GCN that mixes counties with a fixed adjacency."""

    def __init__(self, in_dim: int, hidden_dim: int = 64, out_dim: int = 64, dropout: float = 0.1):
        super().__init__()
        self.lin1 = nn.Linear(in_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, H: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # H: (B, N, D); adj: (B, N, N)
        z = torch.matmul(adj, H)
        z = F.relu(self.lin1(z))
        z = self.drop(z)
        z = torch.matmul(adj, z)
        return F.relu(self.lin2(z))
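
# Example (illustrative): with a row-normalized adjacency, each matmul averages a
# node's features over its neighbors, so the two matmuls give a two-hop receptive field:
#
#     gcn = SimpleGCN(in_dim=32, hidden_dim=64, out_dim=64)
#     adj = _normalize_adjacency(torch.ones(5, 5))   # (1, 5, 5), rows sum to 1
#     out = gcn(torch.randn(1, 5, 32), adj)          # -> shape (1, 5, 64)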


class WildfireMamba(nn.Module):
    """
    Mamba-based spatio-temporal wildfire model for county-day ERA5 features.

    Input shape: (batch, past_days, num_counties, num_features)
    Output: logits per county for the next day (use sigmoid for probabilities)
    """

    def __init__(
        self,
        in_dim: int,
        num_counties: int,
        past_days: int,
        hidden_dim: int = 128,
        gcn_hidden: int = 64,
        mamba_layers: int = 2,
        state_dim: int = 64,
        conv_kernel: int = 5,
        dropout: float = 0.1,
        adjacency: Optional[torch.Tensor] = None,
        with_count_head: bool = False,
    ):
        super().__init__()
        self.num_counties = num_counties
        self.past_days = past_days
        self.with_count_head = with_count_head

        self.temporal = MambaTemporalEncoder(
            in_dim=in_dim,
            hidden_dim=hidden_dim,
            num_layers=mamba_layers,
            state_dim=state_dim,
            conv_kernel=conv_kernel,
            dropout=dropout,
        )
        # differential branch is shallower and gates how much change to inject
        self.delta_temporal = MambaTemporalEncoder(
            in_dim=in_dim,
            hidden_dim=hidden_dim,
            num_layers=max(1, mamba_layers - 1),
            state_dim=state_dim,
            conv_kernel=conv_kernel,
            dropout=dropout,
        )
        self.delta_gate = nn.Linear(hidden_dim, hidden_dim)

        self.gcn = SimpleGCN(hidden_dim, hidden_dim=gcn_hidden, out_dim=gcn_hidden, dropout=dropout)
        self.cls_head = nn.Linear(gcn_hidden, 1)
        if self.with_count_head:
            self.count_head = nn.Linear(gcn_hidden, 1)
        self.dropout = nn.Dropout(dropout)

        self.register_buffer("_adjacency", None)
        if adjacency is not None:
            self.set_adjacency(adjacency)

    def set_adjacency(self, adj: torch.Tensor) -> None:
        """Set/override the spatial adjacency."""
        adj = _normalize_adjacency(adj.detach())
        self._adjacency = adj

    def _get_adjacency(self, batch_size: int) -> torch.Tensor:
        if self._adjacency is None:
            eye = torch.eye(self.num_counties, device=self.cls_head.weight.device)
            adj = _normalize_adjacency(eye)
        else:
            adj = self._adjacency
        if adj.dim() == 2:
            adj = adj.unsqueeze(0)
        if adj.size(0) == 1 and batch_size > 1:
            adj = adj.expand(batch_size, -1, -1)
        return adj

    @staticmethod
    def _temporal_delta(x: torch.Tensor) -> torch.Tensor:
        # prepend zeros so delta has the same length as the input sequence
        zeros = torch.zeros(x.size(0), 1, x.size(2), device=x.device, dtype=x.dtype)
        return torch.cat([zeros, x[:, 1:] - x[:, :-1]], dim=1)

    def forward(self, x: torch.Tensor, adjacency: Optional[torch.Tensor] = None):
        """
        Args:
            x: Tensor shaped (batch, past_days, num_counties, in_dim)
            adjacency: Optional (N, N) or (B, N, N) adjacency override.

        Returns:
            - logits: (batch, num_counties)
            - optional counts: (batch, num_counties) if with_count_head is enabled.
        """
        # Name the feature dim D rather than F, which would shadow torch.nn.functional.
        B, T, N, D = x.shape
        if T != self.past_days:
            raise ValueError(f"Expected past_days={self.past_days}, got {T}.")
        if N != self.num_counties:
            raise ValueError(f"Expected num_counties={self.num_counties}, got {N}.")

        # flatten counties into the batch for temporal encoding
        x_flat = x.permute(0, 2, 1, 3).reshape(B * N, T, D)
        base = self.temporal(x_flat)
        delta = self.delta_temporal(self._temporal_delta(x_flat))
        gate = torch.sigmoid(self.delta_gate(delta))
        fused = base * gate + delta
        fused = fused.view(B, N, -1)

        adj = _normalize_adjacency(adjacency) if adjacency is not None else self._get_adjacency(B)
        spatial = self.gcn(fused, adj)
        spatial = self.dropout(spatial)

        logits = self.cls_head(spatial).squeeze(-1)
        if self.with_count_head:
            counts = F.relu(self.count_head(spatial)).squeeze(-1)
            return logits, counts
        return logits


def wildfire_mamba_builder(
    task: str,
    in_dim: int,
    num_counties: int,
    past_days: int,
    **kwargs,
) -> WildfireMamba:
    """Builder used by the model registry."""
    if task.lower() not in {"classification", "binary_classification"}:
        raise ValueError("WildfireMamba is designed for binary per-county classification.")
    return WildfireMamba(
        in_dim=in_dim,
        num_counties=num_counties,
        past_days=past_days,
        hidden_dim=kwargs.get("hidden_dim", 128),
        gcn_hidden=kwargs.get("gcn_hidden", 64),
        mamba_layers=kwargs.get("mamba_layers", 2),
        state_dim=kwargs.get("state_dim", 64),
        conv_kernel=kwargs.get("conv_kernel", 5),
        dropout=kwargs.get("dropout", 0.1),
        adjacency=kwargs.get("adjacency"),
        with_count_head=kwargs.get("with_count_head", False),
    )
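
# Example (illustrative; sizes are arbitrary): the model registry is expected to
# invoke the builder roughly like
#
#     model = wildfire_mamba_builder(
#         "binary_classification", in_dim=20, num_counties=58, past_days=30
#     )
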
__all__ = ["WildfireMamba", "wildfire_mamba_builder"]
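
if __name__ == "__main__":
    # Minimal smoke test (illustrative; all sizes here are arbitrary and not tied
    # to any real dataset). With no adjacency set, the model falls back to a
    # normalized identity graph.
    torch.manual_seed(0)
    model = WildfireMamba(in_dim=20, num_counties=10, past_days=30, with_count_head=True)
    x = torch.randn(2, 30, 10, 20)  # (batch, past_days, num_counties, in_dim)
    logits, counts = model(x)
    assert logits.shape == (2, 10) and counts.shape == (2, 10)
    probs = torch.sigmoid(logits)  # per-county next-day fire probabilities
    print(probs.shape, counts.shape)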