from __future__ import annotations
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Two-hidden-layer perceptron with ReLU activations and optional dropout.

    Architecture: in_dim -> hidden_dim -> hidden_dim -> out_dim, with a
    dropout slot after each hidden activation (an ``nn.Identity`` when
    ``dropout == 0`` so the module graph stays stable).
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 64, dropout: float = 0.0):
        super().__init__()
        stages = []
        # Two identical hidden stages; layer creation order matches the
        # conventional Linear/ReLU/Dropout stack so seeded initialization
        # is reproducible.
        for fan_in in (in_dim, hidden_dim):
            stages.append(nn.Linear(fan_in, hidden_dim))
            stages.append(nn.ReLU())
            stages.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())
        stages.append(nn.Linear(hidden_dim, out_dim))
        self.layers = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (..., in_dim) to (..., out_dim)."""
        return self.layers(x)
class KAN(nn.Module):
    """
    Lightweight KAN-style harmonic basis encoder for node features.

    Each scalar input feature is expanded into a ``[1, sin(kx), cos(kx)]``
    basis for k = 1..harmonics, projected through a per-feature linear
    layer, and the per-feature embeddings are summed into a single hidden
    vector per node.
    """

    def __init__(self, in_dim: int, harmonics: int = 5, hidden_dim: int = 64):
        super().__init__()
        self.in_dim = in_dim
        self.harmonics = harmonics
        # One projection per input feature: (2 * harmonics + 1) basis terms -> hidden.
        self.feature_proj = nn.ModuleList(
            [nn.Linear(2 * harmonics + 1, hidden_dim) for _ in range(in_dim)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, F) with F == in_dim; returns (B, N, hidden_dim).
        embeddings = []
        for idx, proj in enumerate(self.feature_proj):
            col = x[..., idx : idx + 1]  # (B, N, 1) slice of one feature
            terms = [torch.ones_like(col)]
            for k in range(1, self.harmonics + 1):
                terms.append(torch.sin(k * col))
                terms.append(torch.cos(k * col))
            embeddings.append(proj(torch.cat(terms, dim=-1)))
        # Sum (not concat) keeps the output width independent of in_dim.
        return torch.stack(embeddings, dim=0).sum(dim=0)
class GNBlock(nn.Module):
    """
    Message-passing block with residual edge and node updates.
    """

    def __init__(self, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        # Edge update consumes [edge, sender, receiver]; node update
        # consumes [node, aggregated incoming edges].
        self.edge_mlp = MLP(3 * hidden_dim, hidden_dim, hidden_dim, dropout=dropout)
        self.node_mlp = MLP(2 * hidden_dim, hidden_dim, hidden_dim, dropout=dropout)

    def forward(
        self,
        node: torch.Tensor,
        edge: torch.Tensor,
        senders: torch.Tensor,
        receivers: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """One message-passing step.

        node: (B, N, H); edge: (B, E, H); senders/receivers: (E,) index
        tensors shared across the batch. Returns updated (node, edge).
        """
        src = node[:, senders, :]
        dst = node[:, receivers, :]
        edge = edge + self.edge_mlp(torch.cat([edge, src, dst], dim=-1))
        # Sum incoming edge messages per receiver node.
        pooled = torch.zeros_like(node)
        pooled.index_add_(1, receivers, edge)
        # Degree-normalized aggregation improves stability when graph density changes.
        degree = torch.zeros(node.size(1), device=node.device, dtype=node.dtype)
        degree.index_add_(0, receivers, torch.ones_like(receivers, dtype=node.dtype))
        pooled = pooled / degree.clamp(min=1.0).view(1, -1, 1)
        node = node + self.node_mlp(torch.cat([node, pooled], dim=-1))
        return node, edge
class HydroGraphNet(nn.Module):
    """
    PhysicsNeMo-inspired HydroGraphNet:
    encoder -> message-passing processor -> residual delta-state decoder.
    Supports one-step forward prediction and autoregressive rollout.

    Expected batch keys:
        x: (B, T, N, F) node feature history; the last frame seeds prediction.
        adj: (N, N) or identical (B, N, N) adjacency; nonzero entries are edges.
        edge_attr (optional): (E, F_e) or (B, E, F_e) raw edge features, where
            E is the number of edges derived from ``adj`` (self-loops included).
        predict_steps (optional): int overriding ``rollout_steps``.
    """

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        out_dim: int,
        hidden_dim: int = 64,
        harmonics: int = 5,
        num_gn_blocks: int = 5,
        state_dim: Optional[int] = None,
        rollout_steps: int = 1,
        enforce_nonnegative: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.node_in_dim = int(node_in_dim)
        self.edge_in_dim = int(edge_in_dim)
        self.out_dim = int(out_dim)
        # Residual state = the leading node features treated as the evolving
        # physical state; defaults to the first min(2, F) features.
        self.state_dim = int(state_dim) if state_dim is not None else min(2, self.node_in_dim)
        self.state_dim = max(1, min(self.state_dim, self.node_in_dim))
        if self.out_dim > self.state_dim:
            raise ValueError(
                f"out_dim={self.out_dim} cannot exceed residual state_dim={self.state_dim}."
            )
        self.rollout_steps = max(1, int(rollout_steps))
        self.enforce_nonnegative = bool(enforce_nonnegative)
        # Encoder
        self.node_encoder = KAN(
            in_dim=self.node_in_dim,
            hidden_dim=hidden_dim,
            harmonics=harmonics,
        )
        self.edge_encoder = MLP(
            in_dim=self.edge_in_dim,
            out_dim=hidden_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )
        # Processor
        self.processor = nn.ModuleList(
            [GNBlock(hidden_dim=hidden_dim, dropout=dropout) for _ in range(num_gn_blocks)]
        )
        # Decoder predicts delta of physically meaningful states.
        self.decoder = MLP(
            in_dim=hidden_dim,
            out_dim=self.state_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

    def _edge_index(self, adj: torch.Tensor, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Derive one shared (senders, receivers) edge list from ``adj``.

        Accepts (N, N), or (B, N, N) only when every sample's adjacency is
        identical. Self-loops are always added so isolated nodes still
        receive a message.
        """
        if adj.dim() == 2:
            a = adj
        elif adj.dim() == 3:
            if adj.size(0) != batch_size:
                raise ValueError(f"adj batch size mismatch: got {adj.size(0)}, expected {batch_size}")
            a = adj[0]
            for i in range(1, batch_size):
                if not torch.allclose(adj[i], a):
                    raise ValueError(
                        "Per-sample varying adjacency is not supported yet. "
                        "Provide a shared (N, N) adjacency or identical (B, N, N) adjacency."
                    )
        else:
            raise ValueError("adj must be shaped (N, N) or (B, N, N).")
        a = (a > 0).to(dtype=torch.bool)  # new tensor: input adj is not mutated
        a.fill_diagonal_(True)
        return a.nonzero(as_tuple=True)

    def _match_edge_dim(self, edge_feat: torch.Tensor) -> torch.Tensor:
        """Truncate or zero-pad raw edge features to ``edge_in_dim``.

        edge_feat: (B, E, F_edge_raw) -> (B, E, edge_in_dim).
        """
        f_raw = edge_feat.size(-1)
        if f_raw == self.edge_in_dim:
            return edge_feat
        if f_raw > self.edge_in_dim:
            return edge_feat[..., : self.edge_in_dim]
        pad = torch.zeros(
            edge_feat.size(0),
            edge_feat.size(1),
            self.edge_in_dim - f_raw,
            device=edge_feat.device,
            dtype=edge_feat.dtype,
        )
        return torch.cat([edge_feat, pad], dim=-1)

    def _prepare_edge_inputs(
        self,
        batch: Dict[str, torch.Tensor],
        senders: torch.Tensor,
        receivers: torch.Tensor,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Build (B, E, edge_in_dim) inputs for the edge encoder.

        Fix: this method was referenced by ``_one_step`` but missing, which
        made every forward pass raise AttributeError. It uses the previously
        dead ``_match_edge_dim`` helper. When ``batch["edge_attr"]`` is
        absent, zeros are used so the edge encoder still contributes a
        learned bias embedding per edge.
        """
        num_edges = senders.numel()
        edge_attr = batch.get("edge_attr")
        if edge_attr is None:
            return torch.zeros(batch_size, num_edges, self.edge_in_dim, device=device, dtype=dtype)
        edge_attr = edge_attr.to(device=device, dtype=dtype)
        if edge_attr.dim() == 2:
            # Shared (E, F_e) attributes: broadcast over the batch (no copy).
            edge_attr = edge_attr.unsqueeze(0).expand(batch_size, -1, -1)
        if edge_attr.dim() != 3 or edge_attr.size(1) != num_edges:
            raise ValueError(
                f"edge_attr must be (E, F_e) or (B, E, F_e) with E={num_edges}, "
                f"got {tuple(edge_attr.shape)}"
            )
        return self._match_edge_dim(edge_attr)

    def _one_step(
        self,
        node_x: torch.Tensor,
        batch: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Predict the next-step targets from one node frame.

        node_x: (B, N, F). Returns (B, N, out_dim).
        """
        if node_x.ndim != 3:
            raise ValueError(f"Expected node_x with shape (B,N,F), got {tuple(node_x.shape)}")
        if node_x.size(-1) < self.state_dim:
            raise ValueError(
                f"Input feature dim {node_x.size(-1)} is smaller than state_dim {self.state_dim}."
            )
        adj = batch.get("adj")
        if adj is None:
            raise ValueError("HydroGraphNet requires `adj` in the batch.")
        adj = adj.to(device=node_x.device)
        senders, receivers = self._edge_index(adj, batch_size=node_x.size(0))
        # ---- encoder ----
        node = self.node_encoder(node_x)
        edge_in = self._prepare_edge_inputs(
            batch=batch,
            senders=senders,
            receivers=receivers,
            batch_size=node.size(0),
            device=node.device,
            dtype=node.dtype,
        )
        edge = self.edge_encoder(edge_in)
        # ---- processor ----
        for gn in self.processor:
            node, edge = gn(node, edge, senders, receivers)
        # ---- decoder: residual state update ----
        delta_state = self.decoder(node)  # (B, N, state_dim)
        prev_state = node_x[..., : self.state_dim]
        next_state = prev_state + delta_state
        if self.enforce_nonnegative:
            next_state = next_state.clamp_min(0.0)
        # Return requested targets from the evolved state.
        return next_state[..., : self.out_dim]

    def rollout(self, batch: Dict[str, torch.Tensor], predict_steps: int) -> torch.Tensor:
        """Autoregressive rollout; always returns (B, S, N, out_dim).

        Fix: ``forward`` returns (B, N, out_dim) when predict_steps == 1,
        so the old code raised RuntimeError for a single-step rollout;
        the step axis is now added instead.
        """
        batch_roll = dict(batch)
        batch_roll["predict_steps"] = int(predict_steps)
        out = self.forward(batch_roll)
        if out.ndim == 3:
            out = out.unsqueeze(1)  # (B, N, out_dim) -> (B, 1, N, out_dim)
        if out.ndim != 4:
            raise RuntimeError("rollout expected stacked output with shape (B, S, N, out_dim).")
        return out

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """One-step or autoregressive prediction.

        Returns (B, N, out_dim) when predict_steps == 1, otherwise
        (B, S, N, out_dim) with S = predict_steps.
        """
        # batch["x"]: (B, T, N, F)
        x = batch["x"]
        if x.ndim != 4:
            raise ValueError(
                f"HydroGraphNet expects x shaped (B, T, N, F), got {tuple(x.shape)}"
            )
        predict_steps = int(batch.get("predict_steps", self.rollout_steps))
        predict_steps = max(1, predict_steps)
        history = x
        preds = []
        for _ in range(predict_steps):
            node_x = history[:, -1]
            y_next = self._one_step(node_x=node_x, batch=batch)  # (B, N, out_dim)
            preds.append(y_next)
            if predict_steps > 1:
                # Feed the prediction back: overwrite the predicted channels
                # of the latest frame and slide the history window one step.
                next_frame = history[:, -1].clone()
                next_frame[..., : self.out_dim] = y_next
                history = torch.cat([history[:, 1:], next_frame.unsqueeze(1)], dim=1)
        if predict_steps == 1:
            return preds[0]
        return torch.stack(preds, dim=1)
class HydroGraphNetLoss(nn.Module):
    """
    Supervised regression loss with optional continuity regularization.
    """

    def __init__(self, supervised_weight: float = 1.0, continuity_weight: float = 0.0):
        super().__init__()
        self.supervised_weight = float(supervised_weight)
        self.continuity_weight = float(continuity_weight)

    def forward(
        self,
        preds: torch.Tensor,
        targets: torch.Tensor,
        prev_state: Optional[torch.Tensor] = None,
        cell_area: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Return ``(total_loss, metrics)`` for predictions vs. targets.

        The continuity term is added only when its weight is positive and
        both ``prev_state`` and ``cell_area`` are provided with at least two
        state channels (channel 0 = depth, channel 1 = volume).
        """
        mse_term = F.mse_loss(preds, targets)
        total = self.supervised_weight * mse_term
        metrics: Dict[str, float] = {"mse": float(mse_term.detach().cpu())}
        continuity_applicable = (
            self.continuity_weight > 0
            and prev_state is not None
            and cell_area is not None
            and preds.size(-1) >= 2
            and prev_state.size(-1) >= 2
        )
        if continuity_applicable:
            # Approximate local continuity: depth-change * area ~= volume-change
            d_depth = preds[..., 0] - prev_state[..., 0]
            d_volume = preds[..., 1] - prev_state[..., 1]
            area = cell_area.to(device=preds.device, dtype=preds.dtype)
            if area.dim() == 1:
                area = area.unsqueeze(0)  # broadcast over the batch axis
            continuity_term = F.mse_loss(d_depth * area, d_volume)
            total = total + self.continuity_weight * continuity_term
            metrics["continuity"] = float(continuity_term.detach().cpu())
        metrics["total"] = float(total.detach().cpu())
        return total, metrics
def hydrographnet_builder(
    task: str,
    node_in_dim: int,
    edge_in_dim: int,
    out_dim: int,
    **kwargs,
) -> HydroGraphNet:
    """Factory for :class:`HydroGraphNet`.

    Only ``task == "regression"`` is supported; any other task raises
    ``ValueError``. Remaining hyperparameters are read from ``kwargs``
    with the model's documented defaults.
    """
    if task != "regression":
        raise ValueError("HydroGraphNet only supports regression")
    defaults = {
        "hidden_dim": 64,
        "harmonics": 5,
        "num_gn_blocks": 5,
        "state_dim": None,
        "rollout_steps": 1,
        "enforce_nonnegative": False,
        "dropout": 0.0,
    }
    options = {key: kwargs.get(key, fallback) for key, fallback in defaults.items()}
    return HydroGraphNet(
        node_in_dim=node_in_dim,
        edge_in_dim=edge_in_dim,
        out_dim=out_dim,
        **options,
    )
# Public API of this module; helper modules (MLP, KAN, GNBlock) stay internal.
__all__ = ["HydroGraphNet", "HydroGraphNetLoss", "hydrographnet_builder"]