# Source code for pyhazards.models.hydrographnet

from __future__ import annotations

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """
    Feed-forward network: two hidden ``Linear -> ReLU -> (Dropout)`` stages
    followed by a linear output projection.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        hidden_dim: Width of both hidden layers.
        dropout: Dropout probability applied after each ReLU; when it is not
            greater than zero an ``nn.Identity`` placeholder is used instead,
            so the positions inside ``self.layers`` stay the same either way.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 64, dropout: float = 0.0):
        super().__init__()

        stages = []
        # Two hidden stages; only the fan-in differs between them.
        for fan_in in (in_dim, hidden_dim):
            stages.append(nn.Linear(fan_in, hidden_dim))
            stages.append(nn.ReLU())
            stages.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())
        stages.append(nn.Linear(hidden_dim, out_dim))

        self.layers = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        return self.layers(x)


class KAN(nn.Module):
    """
    Lightweight KAN-style harmonic basis encoder for node features.

    Each input feature is expanded into the basis
    ``[1, sin(x), cos(x), ..., sin(Kx), cos(Kx)]`` with ``K = harmonics`` and
    projected by its own linear layer; the per-feature projections are summed.

    Args:
        in_dim: Number of input features to encode (one projection each).
        harmonics: Highest harmonic order ``K`` used in the basis.
        hidden_dim: Output width of every per-feature projection.
    """

    def __init__(self, in_dim: int, harmonics: int = 5, hidden_dim: int = 64):
        super().__init__()
        self.in_dim = in_dim
        self.harmonics = harmonics
        basis_width = 2 * harmonics + 1  # constant term + (sin, cos) per harmonic
        self.feature_proj = nn.ModuleList(
            [nn.Linear(basis_width, hidden_dim) for _ in range(in_dim)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` of shape (B, N, F) into a (B, N, hidden_dim) tensor."""
        encoded = []
        for idx, projection in enumerate(self.feature_proj):
            column = x.select(2, idx).unsqueeze(-1)  # (B, N, 1)
            terms = [torch.ones_like(column)]
            for order in range(1, self.harmonics + 1):
                phase = order * column
                terms.extend((torch.sin(phase), torch.cos(phase)))
            encoded.append(projection(torch.cat(terms, dim=-1)))
        # Sum the per-feature embeddings into one node embedding.
        return torch.stack(encoded, dim=0).sum(dim=0)


class GNBlock(nn.Module):
    """
    Message-passing block with residual edge and node updates.

    Edges are updated from ``[edge, sender node, receiver node]``; nodes are
    updated from ``[node, mean of incoming updated edges]``. Both updates are
    added to their inputs (residual connections).
    """

    def __init__(self, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.edge_mlp = MLP(3 * hidden_dim, hidden_dim, hidden_dim, dropout=dropout)
        self.node_mlp = MLP(2 * hidden_dim, hidden_dim, hidden_dim, dropout=dropout)

    def forward(
        self,
        node: torch.Tensor,
        edge: torch.Tensor,
        senders: torch.Tensor,
        receivers: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one round of message passing.

        Args:
            node: Node embeddings; dim 1 is indexed by ``senders``/``receivers``.
            edge: Edge embeddings; dim 1 is aligned with the edge index vectors.
            senders: Index of the source node of each edge.
            receivers: Index of the destination node of each edge.

        Returns:
            The updated ``(node, edge)`` embeddings.
        """
        num_nodes = node.size(1)

        # --- edge update (residual) ---
        endpoints = (node[:, senders, :], node[:, receivers, :])
        edge = edge + self.edge_mlp(torch.cat([edge, *endpoints], dim=-1))

        # --- gather incoming messages per receiving node ---
        incoming = torch.zeros_like(node).index_add_(1, receivers, edge)

        # Degree-normalized aggregation improves stability when graph density changes.
        in_degree = torch.zeros(num_nodes, device=node.device, dtype=node.dtype)
        in_degree.index_add_(0, receivers, torch.ones_like(receivers, dtype=node.dtype))
        # Nodes without incoming edges divide by 1 rather than 0.
        incoming = incoming / in_degree.clamp(min=1.0).view(1, -1, 1)

        # --- node update (residual) ---
        node = node + self.node_mlp(torch.cat([node, incoming], dim=-1))
        return node, edge


class HydroGraphNet(nn.Module):
    """
    PhysicsNeMo-inspired HydroGraphNet:
    encoder -> message-passing processor -> residual delta-state decoder.

    Supports one-step forward prediction and autoregressive rollout.

    Args:
        node_in_dim: Number of per-node input features.
        edge_in_dim: Number of per-edge input features expected by the edge
            encoder; raw edge features are truncated or zero-padded to it.
        out_dim: Number of leading state channels returned as predictions.
        hidden_dim: Width of all latent node/edge embeddings.
        harmonics: Highest harmonic order of the node encoder basis.
        num_gn_blocks: Number of message-passing blocks in the processor.
        state_dim: Number of leading node features treated as the evolving
            state. Defaults to ``min(2, node_in_dim)`` and is always clamped
            to ``[1, node_in_dim]``.
        rollout_steps: Default number of autoregressive steps in ``forward``
            (values below 1 are raised to 1).
        enforce_nonnegative: Clamp the predicted state at zero from below.
        dropout: Dropout probability used inside the MLPs.

    Raises:
        ValueError: If ``out_dim`` exceeds the (clamped) ``state_dim``.
    """

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        out_dim: int,
        hidden_dim: int = 64,
        harmonics: int = 5,
        num_gn_blocks: int = 5,
        state_dim: Optional[int] = None,
        rollout_steps: int = 1,
        enforce_nonnegative: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.node_in_dim = int(node_in_dim)
        self.edge_in_dim = int(edge_in_dim)
        self.out_dim = int(out_dim)
        self.state_dim = int(state_dim) if state_dim is not None else min(2, self.node_in_dim)
        self.state_dim = max(1, min(self.state_dim, self.node_in_dim))
        if self.out_dim > self.state_dim:
            raise ValueError(
                f"out_dim={self.out_dim} cannot exceed residual state_dim={self.state_dim}."
            )
        self.rollout_steps = max(1, int(rollout_steps))
        self.enforce_nonnegative = bool(enforce_nonnegative)

        # Encoder
        self.node_encoder = KAN(
            in_dim=self.node_in_dim,
            hidden_dim=hidden_dim,
            harmonics=harmonics,
        )
        self.edge_encoder = MLP(
            in_dim=self.edge_in_dim,
            out_dim=hidden_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

        # Processor
        self.processor = nn.ModuleList(
            [GNBlock(hidden_dim=hidden_dim, dropout=dropout) for _ in range(num_gn_blocks)]
        )

        # Decoder predicts delta of physically meaningful states.
        self.decoder = MLP(
            in_dim=hidden_dim,
            out_dim=self.state_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

    @staticmethod
    def _expand_to_batch(tensor: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Broadcast a 2-D tensor, or a 3-D tensor with batch size 1, to ``batch_size``."""
        if tensor.dim() == 2:
            return tensor.unsqueeze(0).expand(batch_size, -1, -1)
        if tensor.dim() == 3 and tensor.size(0) == 1 and batch_size > 1:
            return tensor.expand(batch_size, -1, -1)
        return tensor

    def _edge_index(self, adj: torch.Tensor, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert a dense adjacency into ``(senders, receivers)`` index vectors.

        Entries greater than zero are edges (row index = sender, column index
        = receiver), and a self-loop is always added for every node.

        Raises:
            ValueError: If ``adj`` is not (N, N) or (B, N, N), if its batch
                size differs from ``batch_size``, or if the per-sample
                adjacencies are not all (approximately) equal.
        """
        if adj.dim() == 2:
            a = adj
        elif adj.dim() == 3:
            if adj.size(0) != batch_size:
                raise ValueError(f"adj batch size mismatch: got {adj.size(0)}, expected {batch_size}")
            a = adj[0]
            if any(not torch.allclose(adj[i], a) for i in range(1, batch_size)):
                raise ValueError(
                    "Per-sample varying adjacency is not supported yet. "
                    "Provide a shared (N, N) adjacency or identical (B, N, N) adjacency."
                )
        else:
            raise ValueError("adj must be shaped (N, N) or (B, N, N).")

        # The comparison allocates a new tensor, so the caller's adjacency is
        # not modified by the in-place diagonal fill below.
        a = (a > 0).to(dtype=torch.bool)
        a.fill_diagonal_(True)
        return a.nonzero(as_tuple=True)

    def _match_edge_dim(self, edge_feat: torch.Tensor) -> torch.Tensor:
        """Truncate or zero-pad the last dim of (B, E, F_edge_raw) features to ``edge_in_dim``."""
        f_raw = edge_feat.size(-1)
        if f_raw == self.edge_in_dim:
            return edge_feat
        if f_raw > self.edge_in_dim:
            return edge_feat[..., : self.edge_in_dim]
        pad = torch.zeros(
            edge_feat.size(0),
            edge_feat.size(1),
            self.edge_in_dim - f_raw,
            device=edge_feat.device,
            dtype=edge_feat.dtype,
        )
        return torch.cat([edge_feat, pad], dim=-1)

    def _prepare_edge_inputs(
        self,
        batch: Dict[str, torch.Tensor],
        senders: torch.Tensor,
        receivers: torch.Tensor,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        Build the raw edge feature tensor of shape (B, E, edge_in_dim).

        Sources, in priority order: ``batch["edge_attr"]``; geometric features
        ``[delta, distance]`` derived from ``batch["coords"]``; all zeros.

        Raises:
            ValueError: If ``edge_attr``/``coords`` have an unsupported rank, or
                if the ``edge_attr`` edge count differs from ``senders.numel()``
                (which includes the self-loops added by ``_edge_index``).
        """
        edge_attr = batch.get("edge_attr")
        if edge_attr is not None:
            edge_attr = edge_attr.to(device=device, dtype=dtype)
            edge_attr = self._expand_to_batch(edge_attr, batch_size)
            if edge_attr.dim() != 3:
                raise ValueError("edge_attr must be shaped (E, F_edge) or (B, E, F_edge).")
            if edge_attr.size(1) != senders.numel():
                raise ValueError(
                    f"edge_attr edge count mismatch: got {edge_attr.size(1)}, expected {senders.numel()}."
                )
            return self._match_edge_dim(edge_attr)

        # Derive geometric edge features from coords: [dx, dy, distance]
        coords = batch.get("coords")
        if coords is None:
            edge_feat = torch.zeros(batch_size, senders.numel(), 3, device=device, dtype=dtype)
            return self._match_edge_dim(edge_feat)

        coords = coords.to(device=device, dtype=dtype)
        coords = self._expand_to_batch(coords, batch_size)
        if coords.dim() != 3:
            raise ValueError("coords must be shaped (N, 2) or (B, N, 2).")

        src = coords[:, senders, :]
        dst = coords[:, receivers, :]
        delta = src - dst
        dist = torch.norm(delta, dim=-1, keepdim=True)
        edge_feat = torch.cat([delta, dist], dim=-1)
        return self._match_edge_dim(edge_feat)

    def _one_step(
        self,
        node_x: torch.Tensor,
        batch: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Advance the state by one step from node features ``node_x`` (B, N, F).

        Returns:
            The first ``out_dim`` channels of ``previous state + decoded
            delta``, shaped (B, N, out_dim).

        Raises:
            ValueError: If ``node_x`` is not 3-D, has fewer than ``state_dim``
                features, or ``batch`` has no ``"adj"`` entry.
        """
        if node_x.ndim != 3:
            raise ValueError(f"Expected node_x with shape (B,N,F), got {tuple(node_x.shape)}")
        if node_x.size(-1) < self.state_dim:
            raise ValueError(
                f"Input feature dim {node_x.size(-1)} is smaller than state_dim {self.state_dim}."
            )

        adj = batch.get("adj")
        if adj is None:
            raise ValueError("HydroGraphNet requires `adj` in the batch.")
        adj = adj.to(device=node_x.device)
        senders, receivers = self._edge_index(adj, batch_size=node_x.size(0))

        # ---- encoder ----
        node = self.node_encoder(node_x)
        edge_in = self._prepare_edge_inputs(
            batch=batch,
            senders=senders,
            receivers=receivers,
            batch_size=node.size(0),
            device=node.device,
            dtype=node.dtype,
        )
        edge = self.edge_encoder(edge_in)

        # ---- processor ----
        for gn in self.processor:
            node, edge = gn(node, edge, senders, receivers)

        # ---- decoder: residual state update ----
        delta_state = self.decoder(node)  # (B, N, state_dim)
        prev_state = node_x[..., : self.state_dim]
        next_state = prev_state + delta_state
        if self.enforce_nonnegative:
            next_state = next_state.clamp_min(0.0)

        # Return requested targets from the evolved state.
        return next_state[..., : self.out_dim]

    def rollout(self, batch: Dict[str, torch.Tensor], predict_steps: int) -> torch.Tensor:
        """
        Run an autoregressive rollout and always return a stacked tensor.

        Args:
            batch: Same mapping accepted by ``forward``; it is shallow-copied,
                so the caller's dict is not modified.
            predict_steps: Number of steps to predict (values below 1 act as 1).

        Returns:
            Predictions shaped (B, S, N, out_dim), including when ``S == 1``.

        Raises:
            RuntimeError: If the stacked output is not 4-D.
        """
        batch_roll = dict(batch)
        batch_roll["predict_steps"] = int(predict_steps)
        out = self.forward(batch_roll)
        if out.ndim == 3:
            # forward() drops the step axis for a single step; restore it so a
            # one-step rollout has the same layout as a multi-step one.
            out = out.unsqueeze(1)
        if out.ndim != 4:
            raise RuntimeError("rollout expected stacked output with shape (B, S, N, out_dim).")
        return out

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Predict one or more future steps from ``batch["x"]`` of shape (B, T, N, F).

        Only the last time frame is fed to the model. For multi-step
        prediction each output overwrites the first ``out_dim`` features of
        that frame and the result becomes the next input; remaining features
        are carried over unchanged.

        The step count is ``batch["predict_steps"]`` if present, otherwise
        ``self.rollout_steps``, and is at least 1.

        Returns:
            (B, N, out_dim) for a single step, otherwise (B, S, N, out_dim).

        Raises:
            ValueError: If ``batch["x"]`` is not 4-D.
        """
        x = batch["x"]
        if x.ndim != 4:
            raise ValueError(
                f"HydroGraphNet expects x shaped (B, T, N, F), got {tuple(x.shape)}"
            )

        predict_steps = int(batch.get("predict_steps", self.rollout_steps))
        predict_steps = max(1, predict_steps)

        # Only the latest frame is ever consumed, so track it directly rather
        # than re-concatenating the whole history window each step.
        frame = x[:, -1]
        preds = []
        for step in range(predict_steps):
            y_next = self._one_step(node_x=frame, batch=batch)  # (B, N, out_dim)
            preds.append(y_next)

            if step + 1 < predict_steps:
                next_frame = frame.clone()
                next_frame[..., : self.out_dim] = y_next
                frame = next_frame

        if predict_steps == 1:
            return preds[0]
        return torch.stack(preds, dim=1)
class HydroGraphNetLoss(nn.Module):
    """
    Supervised regression loss with optional continuity regularization.

    Args:
        supervised_weight: Multiplier applied to the MSE between predictions
            and targets.
        continuity_weight: Multiplier applied to the continuity penalty; the
            penalty is only evaluated when this is greater than zero.
    """

    def __init__(self, supervised_weight: float = 1.0, continuity_weight: float = 0.0):
        super().__init__()
        self.supervised_weight = float(supervised_weight)
        self.continuity_weight = float(continuity_weight)

    def forward(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        prev_state: Optional[torch.Tensor] = None,
        cell_area: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the weighted loss and a dictionary of scalar metrics.

        The continuity term is added only when ``continuity_weight > 0``, both
        ``prev_state`` and ``cell_area`` are given, and both ``preds`` and
        ``prev_state`` have at least two channels in their last dimension.

        Returns:
            ``(total, metrics)`` where ``metrics`` holds ``"mse"``, ``"total"``
            and, when the continuity term was used, ``"continuity"``.
        """
        mse = F.mse_loss(preds, targets)
        loss = self.supervised_weight * mse
        report: Dict[str, float] = {"mse": mse.detach().item()}

        wants_continuity = (
            self.continuity_weight > 0
            and prev_state is not None
            and cell_area is not None
        )
        if wants_continuity and preds.size(-1) >= 2 and prev_state.size(-1) >= 2:
            # Approximate local continuity: depth-change * area ~= volume-change
            d_depth = preds[..., 0] - prev_state[..., 0]
            d_volume = preds[..., 1] - prev_state[..., 1]
            area = cell_area.to(device=preds.device, dtype=preds.dtype)
            if area.dim() == 1:
                area = area.unsqueeze(0)
            penalty = F.mse_loss(d_depth * area, d_volume)
            loss = loss + self.continuity_weight * penalty
            report["continuity"] = penalty.detach().item()

        report["total"] = loss.detach().item()
        return loss, report
def hydrographnet_builder(
    task: str,
    node_in_dim: int,
    edge_in_dim: int,
    out_dim: int,
    **kwargs,
) -> HydroGraphNet:
    """
    Build a ``HydroGraphNet`` for the given task.

    Args:
        task: Must be ``"regression"``.
        node_in_dim: Number of per-node input features.
        edge_in_dim: Number of per-edge input features.
        out_dim: Number of predicted channels.
        **kwargs: Optional overrides for ``hidden_dim``, ``harmonics``,
            ``num_gn_blocks``, ``state_dim``, ``rollout_steps``,
            ``enforce_nonnegative`` and ``dropout``; other keys are ignored.

    Raises:
        ValueError: If ``task`` is not ``"regression"``.
    """
    if task != "regression":
        raise ValueError("HydroGraphNet only supports regression")

    # Recognised options with the value used when the caller omits them.
    fallbacks = {
        "hidden_dim": 64,
        "harmonics": 5,
        "num_gn_blocks": 5,
        "state_dim": None,
        "rollout_steps": 1,
        "enforce_nonnegative": False,
        "dropout": 0.0,
    }
    options = {key: kwargs.get(key, default) for key, default in fallbacks.items()}

    return HydroGraphNet(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        out_dim=out_dim,
        **options,
    )
# Names exported by ``from <this module> import *``.
__all__ = ["HydroGraphNet", "HydroGraphNetLoss", "hydrographnet_builder"]