Source code for pyhazards.models.wildfire_forecasting

from __future__ import annotations

import torch
import torch.nn as nn


class WildfireForecasting(nn.Module):
    """Sequence forecaster for weekly wildfire size-group activity.

    A multi-layer GRU encodes the input window, a learned linear scorer
    produces softmax weights over the time axis, and the resulting weighted
    sum of GRU outputs is mapped to the forecast by a small MLP head
    (``Linear -> GELU -> Dropout -> Linear``).

    Args:
        input_dim: Number of features per time step (GRU input size).
        hidden_dim: GRU hidden size and width of the head's hidden layer.
        output_dim: Number of values produced per sample.
        lookback: Exact sequence length that ``forward`` accepts.
        num_layers: Number of stacked GRU layers.
        dropout: Dropout probability in ``[0, 1)``. Used inside the head and,
            when ``num_layers > 1``, between the stacked GRU layers.

    Raises:
        ValueError: If any size argument is not positive or ``dropout`` is
            outside ``[0, 1)``.
    """

    def __init__(
        self,
        input_dim: int = 7,
        hidden_dim: int = 64,
        output_dim: int = 5,
        lookback: int = 12,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Validate hyper-parameters up front so misconfiguration fails with a
        # clear message instead of an opaque error from the torch layers.
        if input_dim <= 0:
            raise ValueError(f"input_dim must be positive, got {input_dim}")
        if hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be positive, got {hidden_dim}")
        if output_dim <= 0:
            raise ValueError(f"output_dim must be positive, got {output_dim}")
        if lookback <= 0:
            raise ValueError(f"lookback must be positive, got {lookback}")
        if num_layers <= 0:
            raise ValueError(f"num_layers must be positive, got {num_layers}")
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")

        self.lookback = int(lookback)

        self.encoder = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # GRU dropout only acts between stacked layers; pass 0.0 for a
            # single layer rather than a value that would have no effect.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        # Scores each time step with a single scalar for attention pooling.
        self.attention = nn.Linear(hidden_dim, 1)
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forecast from a fixed-length input window.

        Args:
            x: Tensor of shape ``(batch, lookback, features)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional or its sequence length
                differs from the configured ``lookback``.
        """
        if x.ndim != 3:
            raise ValueError(
                "WildfireForecasting expects input shape (batch, lookback, features), "
                f"got {tuple(x.shape)}."
            )
        if x.size(1) != self.lookback:
            raise ValueError(
                f"WildfireForecasting expected lookback={self.lookback}, got sequence length {x.size(1)}."
            )

        # encoded: (batch, lookback, hidden_dim); the final hidden state is unused.
        encoded, _ = self.encoder(x)
        # Normalise the per-step scores across the time axis (dim=1).
        weights = torch.softmax(self.attention(encoded), dim=1)
        # Attention-weighted sum over time -> (batch, hidden_dim).
        pooled = torch.sum(weights * encoded, dim=1)
        return self.head(pooled)
def wildfire_forecasting_builder(
    task: str,
    input_dim: int = 7,
    hidden_dim: int = 64,
    output_dim: int = 5,
    lookback: int = 12,
    num_layers: int = 2,
    dropout: float = 0.1,
    **kwargs,
) -> nn.Module:
    """Build a :class:`WildfireForecasting` model for a supported task.

    Args:
        task: Task name; ``"forecasting"`` or ``"regression"``, compared
            case-insensitively.
        input_dim: Forwarded to ``WildfireForecasting``.
        hidden_dim: Forwarded to ``WildfireForecasting``.
        output_dim: Forwarded to ``WildfireForecasting``.
        lookback: Forwarded to ``WildfireForecasting``.
        num_layers: Forwarded to ``WildfireForecasting``.
        dropout: Forwarded to ``WildfireForecasting``.
        **kwargs: Accepted and ignored.

    Returns:
        A newly constructed ``WildfireForecasting`` instance.

    Raises:
        ValueError: If ``task`` is not one of the supported task names.
    """
    # Extra keyword arguments are deliberately discarded.
    _ = kwargs

    if task.lower() not in {"forecasting", "regression"}:
        raise ValueError(
            "wildfire_forecasting supports task='forecasting' or 'regression', "
            f"got {task!r}."
        )

    return WildfireForecasting(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        lookback=lookback,
        num_layers=num_layers,
        dropout=dropout,
    )
# Names exported by a wildcard import of this module.
__all__ = ["WildfireForecasting", "wildfire_forecasting_builder"]