# Source code for pyhazards.models.wildfirespreadts

from __future__ import annotations

import torch
import torch.nn as nn


class WildfireSpreadTS(nn.Module):
    """Wildfire spread baseline built from 3D (time x space) convolutions.

    The model takes a short history of gridded feature maps. It mixes them
    jointly over time and space with two ``Conv3d`` layers, averages the
    encoded features over the time axis, and projects the result to
    ``out_channels`` maps with a 1x1 ``Conv2d``.

    Args:
        history: Number of time steps expected on axis 1 of the input.
        in_channels: Number of feature channels expected on axis 2 of the input.
        hidden_dim: Number of channels produced by both encoder convolutions.
        out_channels: Number of output maps produced by the head.
        dropout: Probability for the ``Dropout3d`` layer between the two
            encoder convolutions; ``0`` uses ``nn.Identity`` instead.

    Raises:
        ValueError: If any size argument is not positive or ``dropout`` is
            outside ``[0, 1)``.
    """

    def __init__(
        self,
        history: int = 4,
        in_channels: int = 6,
        hidden_dim: int = 32,
        out_channels: int = 1,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Check the size arguments in declaration order so the first
        # offending one is the one reported.
        for arg_name, arg_value in (
            ("history", history),
            ("in_channels", in_channels),
            ("hidden_dim", hidden_dim),
            ("out_channels", out_channels),
        ):
            if arg_value <= 0:
                raise ValueError(f"{arg_name} must be positive, got {arg_value}")
        # Written as a chained comparison so NaN is rejected as well.
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")

        self.history = int(history)
        self.in_channels = int(in_channels)

        # 3x3x3 kernel over (time, height, width); padding=1 keeps every extent.
        cube = (3, 3, 3)
        temporal_dropout = nn.Dropout3d(dropout) if dropout > 0 else nn.Identity()
        # Layer positions are part of the state_dict keys (encoder.0 and
        # encoder.3 are the convolutions), so keep this order fixed.
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, hidden_dim, kernel_size=cube, padding=1),
            nn.GELU(),
            temporal_dropout,
            nn.Conv3d(hidden_dim, hidden_dim, kernel_size=cube, padding=1),
            nn.GELU(),
        )
        self.head = nn.Conv2d(hidden_dim, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict spread maps from a history of feature grids.

        Args:
            x: Tensor laid out as ``(batch, history, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, out_channels, height, width)``.

        Raises:
            ValueError: If ``x`` is not 5-D, or if its history or channel size
                differs from the value given at construction.
        """
        if x.ndim != 5:
            raise ValueError(
                "WildfireSpreadTS expects input shape (batch, history, channels, height, width), "
                f"got {tuple(x.shape)}."
            )
        steps, channels = x.shape[1], x.shape[2]
        if steps != self.history:
            raise ValueError(f"WildfireSpreadTS expected history={self.history}, got {steps}.")
        if channels != self.in_channels:
            raise ValueError(f"WildfireSpreadTS expected in_channels={self.in_channels}, got {channels}.")
        # Conv3d consumes (batch, channels, time, height, width): swap axes 1 and 2.
        features = self.encoder(x.transpose(1, 2))
        # Average over the time axis before the per-pixel 1x1 projection.
        pooled = features.mean(dim=2)
        return self.head(pooled)
def wildfirespreadts_builder(
    task: str,
    history: int = 4,
    in_channels: int = 6,
    hidden_dim: int = 32,
    out_channels: int = 1,
    dropout: float = 0.1,
    **kwargs,
) -> nn.Module:
    """Build a ``WildfireSpreadTS`` model for a supported task.

    Args:
        task: Task name, compared case-insensitively. Only ``"segmentation"``
            and ``"regression"`` are accepted.
        history, in_channels, hidden_dim, out_channels, dropout: Passed
            unchanged to the ``WildfireSpreadTS`` constructor.
        **kwargs: Extra keyword arguments; accepted and ignored.

    Returns:
        A newly constructed ``WildfireSpreadTS`` module.

    Raises:
        ValueError: If ``task`` is not a supported name. The model constructor
            may also raise ``ValueError`` for invalid arguments.
    """
    del kwargs  # tolerated but unused by this model
    if task.lower() in {"segmentation", "regression"}:
        return WildfireSpreadTS(
            history=history,
            in_channels=in_channels,
            hidden_dim=hidden_dim,
            out_channels=out_channels,
            dropout=dropout,
        )
    raise ValueError(
        "wildfirespreadts supports task='segmentation' or 'regression', "
        f"got {task!r}."
    )
# Names exported by ``from pyhazards.models.wildfirespreadts import *``.
__all__ = [
    "WildfireSpreadTS",
    "wildfirespreadts_builder",
]