# Source code for pyhazards.models.google_flood_forecasting

from __future__ import annotations

from collections.abc import Mapping

import torch
import torch.nn as nn


class GoogleFloodForecasting(nn.Module):
    """Sequence baseline for streamflow-style flood forecasting.

    Each node's feature history is linearly projected to ``hidden_dim``,
    encoded independently along the time axis by a Transformer encoder, and
    the encoding of the final time step is mapped to ``out_dim`` outputs by
    a small MLP head.

    Args:
        input_dim: Number of input features per node and time step.
        hidden_dim: Width of the projection and Transformer layers. Must be
            divisible by ``num_heads``.
        out_dim: Number of predicted values per node.
        history: Number of time steps expected on axis 1 of the input.
        dropout: Dropout probability used in the encoder and head, in
            ``[0, 1)``.
        num_heads: Attention heads per encoder layer (default 4, the value
            that used to be hard-coded).
        num_layers: Number of stacked encoder layers (default 2, the value
            that used to be hard-coded).

    Raises:
        ValueError: If any size is non-positive, ``dropout`` is outside
            ``[0, 1)``, or ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        input_dim: int = 2,
        hidden_dim: int = 64,
        out_dim: int = 1,
        history: int = 4,
        dropout: float = 0.1,
        num_heads: int = 4,
        num_layers: int = 2,
    ):
        super().__init__()
        if input_dim <= 0:
            raise ValueError(f"input_dim must be positive, got {input_dim}")
        if hidden_dim <= 0:
            raise ValueError(f"hidden_dim must be positive, got {hidden_dim}")
        if out_dim <= 0:
            raise ValueError(f"out_dim must be positive, got {out_dim}")
        if history <= 0:
            raise ValueError(f"history must be positive, got {history}")
        if not 0.0 <= dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if num_layers <= 0:
            raise ValueError(f"num_layers must be positive, got {num_layers}")
        if hidden_dim % num_heads != 0:
            # Multi-head attention splits hidden_dim evenly across heads;
            # report this here instead of failing inside nn.MultiheadAttention.
            raise ValueError(
                "hidden_dim must be divisible by num_heads, got "
                f"hidden_dim={hidden_dim}, num_heads={num_heads}"
            )

        self.history = int(history)
        self.proj = nn.Linear(input_dim, hidden_dim)
        self.temporal = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 2,
                dropout=dropout,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=num_layers,
        )
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, batch: Mapping[str, torch.Tensor]) -> torch.Tensor:
        """Predict per-node outputs from each node's recent history.

        Args:
            batch: Any mapping whose ``"x"`` entry is a tensor of shape
                ``(batch, history, nodes, input_dim)``.

        Returns:
            Tensor of shape ``(batch, nodes, out_dim)``.

        Raises:
            ValueError: If ``batch`` is not a mapping containing ``"x"``, or
                ``x`` has the wrong rank, history length, or feature size.
        """
        if not isinstance(batch, Mapping) or "x" not in batch:
            raise ValueError("GoogleFloodForecasting expects a mapping input with key 'x'.")
        x = batch["x"]
        if x.ndim != 4:
            raise ValueError(
                "GoogleFloodForecasting expects input shape (batch, history, nodes, features), "
                f"got {tuple(x.shape)}."
            )
        if x.size(1) != self.history:
            raise ValueError(f"GoogleFloodForecasting expected history={self.history}, got {x.size(1)}.")
        if x.size(-1) != self.proj.in_features:
            # Caught here so the caller sees a clear message instead of a
            # matmul shape error from the projection layer.
            raise ValueError(
                f"GoogleFloodForecasting expected {self.proj.in_features} input features, "
                f"got {x.size(-1)}."
            )

        batch_size, _, num_nodes, _ = x.shape
        encoded = self.proj(x)  # (batch, history, nodes, hidden)
        # Fold nodes into the batch axis so every node's history is encoded as
        # an independent sequence: (batch * nodes, history, hidden). Explicit
        # sizes (rather than -1) keep the reshape well-defined for empty inputs.
        sequences = encoded.permute(0, 2, 1, 3).reshape(
            batch_size * num_nodes, self.history, encoded.size(-1)
        )
        # Only the encoding of the most recent time step feeds the head.
        hidden = self.temporal(sequences)[:, -1]
        preds = self.head(hidden)  # (batch * nodes, out_dim)
        return preds.view(batch_size, num_nodes, preds.size(-1))
def google_flood_forecasting_builder(
    task: str,
    input_dim: int = 2,
    hidden_dim: int = 64,
    out_dim: int = 1,
    history: int = 4,
    dropout: float = 0.1,
    **kwargs,
) -> nn.Module:
    """Construct a :class:`GoogleFloodForecasting` model for ``task``.

    Args:
        task: Task name; only ``"regression"`` (compared case-insensitively)
            is supported.
        input_dim, hidden_dim, out_dim, history, dropout: Passed through
            unchanged to :class:`GoogleFloodForecasting`.
        **kwargs: Accepted and ignored.

    Returns:
        A newly constructed :class:`GoogleFloodForecasting`.

    Raises:
        ValueError: If ``task`` is anything other than ``"regression"``.
    """
    del kwargs  # extra keyword arguments are deliberately ignored
    if task.lower() == "regression":
        return GoogleFloodForecasting(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            out_dim=out_dim,
            history=history,
            dropout=dropout,
        )
    raise ValueError(
        f"google_flood_forecasting only supports task='regression', got {task!r}."
    )
# Names exported by ``from ... import *``.
__all__ = [
    "GoogleFloodForecasting",
    "google_flood_forecasting_builder",
]