Source code for pyhazards.models.wrf_sfire

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


[docs] class WRFSFireAdapter(nn.Module): """Lightweight raster adapter inspired by WRF-SFIRE style spread diffusion.""" def __init__( self, in_channels: int = 12, out_channels: int = 1, diffusion_steps: int = 3, ): super().__init__() if in_channels <= 0: raise ValueError(f"in_channels must be positive, got {in_channels}") if out_channels != 1: raise ValueError(f"WRFSFireAdapter only supports out_channels=1, got {out_channels}") if diffusion_steps <= 0: raise ValueError(f"diffusion_steps must be positive, got {diffusion_steps}") self.in_channels = int(in_channels) self.diffusion_steps = int(diffusion_steps) kernel = torch.tensor( [[0.02, 0.08, 0.02], [0.08, 0.60, 0.08], [0.02, 0.08, 0.02]], dtype=torch.float32, ).view(1, 1, 3, 3) self.register_buffer("transport_kernel", kernel)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: if x.ndim != 4: raise ValueError( "WRFSFireAdapter expects input shape (batch, channels, height, width), " f"got {tuple(x.shape)}." ) if x.size(1) != self.in_channels: raise ValueError(f"WRFSFireAdapter expected in_channels={self.in_channels}, got {x.size(1)}.") fireline = torch.sigmoid(x[:, :1]) terrain = torch.sigmoid(x[:, 1:2]) moisture = torch.sigmoid(x[:, 2:3]) for _ in range(self.diffusion_steps): fireline = F.conv2d(fireline, self.transport_kernel, padding=1) fireline = torch.clamp(fireline * (0.9 + 0.1 * terrain) * (1.0 - 0.15 * moisture), 0.0, 1.0) return fireline
[docs] def wrf_sfire_builder( task: str, in_channels: int = 12, out_channels: int = 1, diffusion_steps: int = 3, **kwargs, ) -> nn.Module: _ = kwargs if task.lower() not in {"segmentation", "regression"}: raise ValueError(f"wrf_sfire supports task='segmentation' or 'regression', got {task!r}.") return WRFSFireAdapter( in_channels=in_channels, out_channels=out_channels, diffusion_steps=diffusion_steps, )
__all__ = ["WRFSFireAdapter", "wrf_sfire_builder"]