Source code for pyhazards.models.graphcast_tc

from __future__ import annotations

import torch
import torch.nn as nn


[docs] class GraphCastTC(nn.Module): """Experimental wrapper-style GraphCast storm adapter.""" def __init__( self, input_dim: int = 8, hidden_dim: int = 96, horizon: int = 5, output_dim: int = 3, num_layers: int = 2, num_heads: int = 4, dropout: float = 0.1, ): super().__init__() self.horizon = int(horizon) self.output_dim = int(output_dim) self.proj = nn.Linear(input_dim, hidden_dim) encoder_layer = nn.TransformerEncoderLayer( d_model=hidden_dim, nhead=num_heads, dim_feedforward=2 * hidden_dim, dropout=dropout, batch_first=True, ) self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.head = nn.Linear(hidden_dim, self.horizon * self.output_dim)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: if x.ndim != 3: raise ValueError("GraphCastTC expects inputs shaped (batch, history, features).") encoded = self.encoder(self.proj(x)) preds = self.head(encoded.mean(dim=1)) return preds.view(x.size(0), self.horizon, self.output_dim)
[docs] def graphcast_tc_builder( task: str, input_dim: int = 8, hidden_dim: int = 96, horizon: int = 5, output_dim: int = 3, num_layers: int = 2, num_heads: int = 4, dropout: float = 0.1, **kwargs, ) -> nn.Module: _ = kwargs if task.lower() != "regression": raise ValueError("GraphCastTC only supports regression for track/intensity forecasting.") return GraphCastTC( input_dim=input_dim, hidden_dim=hidden_dim, horizon=horizon, output_dim=output_dim, num_layers=num_layers, num_heads=num_heads, dropout=dropout, )
__all__ = ["GraphCastTC", "graphcast_tc_builder"]