from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.init as init
class ConvLEMCell(nn.Module):
    """
    Convolutional Long Expressive Memory (ConvLEM) cell used by WaveCastNet.

    One recurrent step couples a slow memory ``c`` and a hidden state ``h``
    through learned, sigmoid-bounded multiscale time steps. All states are
    4-D maps shaped (B, C, H, W).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dt: float = 1.0,
        activation: str = "tanh",
        use_reset_gate: bool = False,
    ):
        super().__init__()
        nonlinearities = {"tanh": torch.tanh, "relu": torch.relu}
        if activation not in nonlinearities:
            raise ValueError(
                f"Unsupported activation: {activation}. Use 'tanh' or 'relu'."
            )
        self.activation = nonlinearities[activation]
        self.dt = float(dt)
        self.use_reset_gate = bool(use_reset_gate)
        self.out_channels = int(out_channels)
        pad = (kernel_size - 1) // 2
        # The reset-gate variant carries one extra gate per conv stack.
        x_gates = 5 if self.use_reset_gate else 4
        h_gates = 4 if self.use_reset_gate else 3
        self.conv_x = nn.Conv2d(in_channels, x_gates * out_channels, kernel_size, padding=pad)
        self.conv_h = nn.Conv2d(out_channels, h_gates * out_channels, kernel_size, padding=pad)
        self.conv_c = nn.Conv2d(out_channels, out_channels, kernel_size, padding=pad)
        # Per-channel "peephole" weights on the memory c inside the gates.
        self.W_c1 = nn.Parameter(torch.empty(out_channels, 1, 1))
        self.W_c2 = nn.Parameter(torch.empty(out_channels, 1, 1))
        if self.use_reset_gate:
            self.W_c4 = nn.Parameter(torch.empty(out_channels, 1, 1))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Xavier-initialize conv kernels; zero biases and peephole weights."""
        for name, tensor in self.named_parameters():
            if "W_c" in name or tensor.ndim <= 1:
                init.zeros_(tensor)
            else:
                init.xavier_uniform_(tensor)

    def forward(
        self,
        x: torch.Tensor,
        h: torch.Tensor,
        c: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Advance the cell by one step and return the updated ``(h, c)``."""
        if x.ndim != 4 or h.ndim != 4 or c.ndim != 4:
            raise ValueError("ConvLEMCell expects x, h, c shaped (B, C, H, W).")
        from_x = self.conv_x(x)
        from_h = self.conv_h(h)
        if self.use_reset_gate:
            x_dt1, x_dt2, x_g2, x_c, x_h = from_x.chunk(5, dim=1)
            h_dt1, h_dt2, h_h, h_g2 = from_h.chunk(4, dim=1)
            # Slow gate reads the old memory, then c is updated in place.
            slow = self.dt * torch.sigmoid(x_dt2 + h_dt2 + self.W_c2 * c)
            c = (1.0 - slow) * c + slow * self.activation(x_h + h_h)
            # Reset gate and fast gate both read the *new* memory.
            reset = self.dt * torch.sigmoid(x_g2 + h_g2 + self.W_c4 * c)
            reset_c = reset * self.conv_c(c)
            fast = self.dt * torch.sigmoid(x_dt1 + h_dt1 + self.W_c1 * c)
            h = (1.0 - fast) * h + fast * self.activation(reset_c + x_c)
        else:
            x_dt1, x_dt2, x_c, x_h = from_x.chunk(4, dim=1)
            h_dt1, h_dt2, h_h = from_h.chunk(3, dim=1)
            slow = self.dt * torch.sigmoid(x_dt2 + h_dt2 + self.W_c2 * c)
            c = (1.0 - slow) * c + slow * self.activation(x_h + h_h)
            fast = self.dt * torch.sigmoid(x_dt1 + h_dt1 + self.W_c1 * c)
            h = (1.0 - fast) * h + fast * self.activation(self.conv_c(c) + x_c)
        return h, c
class WaveCastNet(nn.Module):
    """
    Sequence-to-sequence wavefield forecasting model based on ConvLEM cells.

    Input shape: (B, C, T_in, H, W)
    Output shape: (B, C, T_out, H, W)

    The encoder ingests ``temporal_in`` frames to build stacked ConvLEM
    states; the decoder clones those states and rolls forward
    autoregressively for ``temporal_out`` steps, projecting each top-layer
    hidden state back to ``in_channels`` output channels.
    """

    def __init__(
        self,
        in_channels: int,
        height: int,
        width: int,
        temporal_in: int,
        temporal_out: int,
        hidden_dim: int = 144,
        num_layers: int = 2,
        kernel_size: int = 3,
        dt: float = 1.0,
        activation: str = "tanh",
        dropout: float = 0.1,
    ):
        super().__init__()
        self.in_channels = int(in_channels)
        self.height = int(height)
        self.width = int(width)
        self.temporal_in = int(temporal_in)
        self.temporal_out = int(temporal_out)
        self.hidden_dim = int(hidden_dim)
        self.num_layers = int(num_layers)
        padding = (kernel_size - 1) // 2
        proj_dim = max(1, self.hidden_dim // 2)
        # Per-frame embedding from physical channels into the hidden space.
        self.input_embed = nn.Sequential(
            nn.Conv2d(self.in_channels, self.hidden_dim, kernel_size, padding=padding),
            nn.BatchNorm2d(self.hidden_dim),
            nn.ReLU(),
            nn.Dropout2d(dropout),
        )
        self.encoder_layers = nn.ModuleList(
            [
                ConvLEMCell(
                    in_channels=self.hidden_dim,
                    out_channels=self.hidden_dim,
                    kernel_size=kernel_size,
                    dt=dt,
                    activation=activation,
                    use_reset_gate=False,
                )
                for _ in range(self.num_layers)
            ]
        )
        self.decoder_layers = nn.ModuleList(
            [
                ConvLEMCell(
                    in_channels=self.hidden_dim,
                    out_channels=self.hidden_dim,
                    kernel_size=kernel_size,
                    dt=dt,
                    activation=activation,
                    use_reset_gate=False,
                )
                for _ in range(self.num_layers)
            ]
        )
        # Two-stage readout: hidden -> proj_dim -> physical channels.
        self.output_proj = nn.Sequential(
            nn.Conv2d(self.hidden_dim, proj_dim, kernel_size, padding=padding),
            nn.ReLU(),
            nn.Dropout2d(dropout),
            nn.Conv2d(proj_dim, self.in_channels, kernel_size, padding=padding),
        )
        self.dropout = nn.Dropout2d(dropout)

    def _init_states(self, x: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Return per-layer zero (hidden, memory) states on x's device/dtype."""
        hidden = [
            x.new_zeros(x.size(0), self.hidden_dim, self.height, self.width)
            for _ in range(self.num_layers)
        ]
        memory = [
            x.new_zeros(x.size(0), self.hidden_dim, self.height, self.width)
            for _ in range(self.num_layers)
        ]
        return hidden, memory

    def _check_input(self, x: torch.Tensor) -> None:
        """Validate that x matches the configured (B, C, T_in, H, W) layout.

        Raises:
            ValueError: on wrong rank, channel count, sequence length, or
                spatial size.
        """
        if x.ndim != 5:
            raise ValueError(
                "WaveCastNet expects x shaped (B, C, T, H, W), got {shape}".format(
                    shape=tuple(x.shape)
                )
            )
        _, channels, temporal_in, height, width = x.shape
        if channels != self.in_channels:
            raise ValueError(
                "Expected in_channels={expected}, got {actual}".format(
                    expected=self.in_channels,
                    actual=channels,
                )
            )
        if temporal_in != self.temporal_in:
            raise ValueError(
                "Expected temporal_in={expected}, got {actual}".format(
                    expected=self.temporal_in,
                    actual=temporal_in,
                )
            )
        if height != self.height or width != self.width:
            raise ValueError(
                "Expected spatial size ({h}, {w}), got ({actual_h}, {actual_w})".format(
                    h=self.height,
                    w=self.width,
                    actual_h=height,
                    actual_w=width,
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``temporal_in`` frames, then decode ``temporal_out`` frames.

        Args:
            x: Input sequence shaped (B, C, T_in, H, W).

        Returns:
            Forecast tensor shaped (B, C, T_out, H, W).

        Raises:
            ValueError: if x does not match the configured layout.
        """
        self._check_input(x)
        encoder_h, encoder_c = self._init_states(x)
        for t in range(self.temporal_in):
            encoded = self.input_embed(x[:, :, t, :, :])
            for i, layer in enumerate(self.encoder_layers):
                layer_input = encoded if i == 0 else encoder_h[i - 1]
                encoder_h[i], encoder_c[i] = layer(layer_input, encoder_h[i], encoder_c[i])
        # Decoder starts from copies of the final encoder states; at every
        # step its first layer consumes the previous top-layer hidden state
        # (initially the encoder's final state, since the clone is identical).
        decoder_h = [state.clone() for state in encoder_h]
        decoder_c = [state.clone() for state in encoder_c]
        outputs = []
        for _ in range(self.temporal_out):
            decoder_input = decoder_h[-1]
            for i, layer in enumerate(self.decoder_layers):
                layer_input = decoder_input if i == 0 else decoder_h[i - 1]
                decoder_h[i], decoder_c[i] = layer(layer_input, decoder_h[i], decoder_c[i])
            outputs.append(self.output_proj(self.dropout(decoder_h[-1])))
        # The loop above emits exactly temporal_out frames by construction,
        # so the former post-hoc length check was unreachable and is removed.
        return torch.stack(outputs, dim=2)
def wavecastnet_builder(
task: str,
in_channels: int,
height: int,
width: int,
temporal_in: int,
temporal_out: int,
**kwargs,
) -> WaveCastNet:
if task.lower() != "regression":
raise ValueError("WaveCastNet only supports regression tasks.")
return WaveCastNet(
in_channels=in_channels,
height=height,
width=width,
temporal_in=temporal_in,
temporal_out=temporal_out,
hidden_dim=kwargs.get("hidden_dim", 144),
num_layers=kwargs.get("num_layers", 2),
kernel_size=kwargs.get("kernel_size", 3),
dt=kwargs.get("dt", 1.0),
activation=kwargs.get("activation", "tanh"),
dropout=kwargs.get("dropout", 0.1),
)
class WaveCastNetLoss(nn.Module):
    """
    Huber loss used in the WaveCastNet paper.

    Quadratic within ``delta`` of the target and linear beyond it, reduced
    by taking the mean over every element.
    """

    def __init__(self, delta: float = 1.0):
        super().__init__()
        # Threshold between the quadratic and linear regimes.
        self.delta = float(delta)

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the scalar mean Huber loss between ``pred`` and ``target``."""
        delta = self.delta
        residual = pred - target
        magnitude = residual.abs()
        small_branch = 0.5 * residual.square()
        large_branch = delta * magnitude - 0.5 * delta * delta
        elementwise = torch.where(magnitude <= delta, small_branch, large_branch)
        return elementwise.mean()
class WavefieldMetrics:
    """
    ACC and RFNE metrics reported in the WaveCastNet paper.

    Each metric is computed per sample over the flattened field, then
    averaged across the batch dimension and returned as a Python float.
    """

    @staticmethod
    def accuracy(pred: torch.Tensor, target: torch.Tensor) -> float:
        """Batch-mean cosine similarity between predicted and true fields."""
        p = pred.reshape(pred.size(0), -1)
        t = target.reshape(target.size(0), -1)
        dot = (p * t).sum(dim=1)
        p_mag = p.square().sum(dim=1).sqrt()
        t_mag = t.square().sum(dim=1).sqrt()
        # Clamp guards against division by zero for all-zero fields.
        cosine = dot / (p_mag * t_mag).clamp(min=1e-8)
        return float(cosine.mean().detach().cpu())

    @staticmethod
    def rfne(pred: torch.Tensor, target: torch.Tensor) -> float:
        """Batch-mean relative Frobenius norm error ||pred - target|| / ||target||."""
        flat_err = (pred - target).reshape(pred.size(0), -1)
        err_mag = flat_err.square().sum(dim=1).sqrt()
        ref_mag = target.reshape(target.size(0), -1).square().sum(dim=1).sqrt()
        ratio = err_mag / ref_mag.clamp(min=1e-8)
        return float(ratio.mean().detach().cpu())

    @staticmethod
    def compute_all(pred: torch.Tensor, target: torch.Tensor) -> dict[str, float]:
        """Return both metrics keyed by their names in the paper."""
        return {
            "ACC": WavefieldMetrics.accuracy(pred, target),
            "RFNE": WavefieldMetrics.rfne(pred, target),
        }
# Public API of this module: the ConvLEM cell, the WaveCastNet model and
# its builder, plus the loss and evaluation metrics from the paper.
__all__ = [
    "ConvLEMCell",
    "WaveCastNet",
    "WaveCastNetLoss",
    "WavefieldMetrics",
    "wavecastnet_builder",
]