Models

Summary

PyHazards provides a lightweight, extensible model architecture with:

  • Backbones for common data types: MLP (tabular), CNN patch encoder (raster), temporal encoder (time-series).

  • Task heads: classification, regression, segmentation.

  • A registry-driven builder so you can construct built-ins by name or register your own.

Core modules

  • pyhazards.models.backbones — reusable feature extractors.

  • pyhazards.models.heads — task-specific heads.

  • pyhazards.models.builder — build_model(name, task, **kwargs) helper plus default_builder.

  • pyhazards.models.registry — register_model / available_models (see the discovery snippet after this list).

  • pyhazards.models — convenience re-exports and default registrations for mlp, cnn, temporal.
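
To discover what is currently registered, query the registry. A minimal sketch, assuming available_models returns the list of registered names:

from pyhazards.models import available_models

print(available_models())  # e.g. ['mlp', 'cnn', 'temporal', ...]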

Build a built-in model

from pyhazards.models import build_model

model = build_model(
    name="mlp",
    task="classification",
    in_dim=32,
    out_dim=5,
    hidden_dim=256,
    depth=3,
)
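
The result is a plain nn.Module, so a quick forward pass sanity-checks the wiring. A minimal sketch, assuming the MLP expects (batch, in_dim) inputs and returns (batch, out_dim) logits:

import torch

x = torch.randn(8, 32)  # batch of 8 tabular rows with in_dim=32 features
logits = model(x)
print(logits.shape)     # expected: torch.Size([8, 5])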

Register a custom model

Create a builder function that returns an nn.Module and register it with a name. The registry handles defaults and discoverability.

import torch.nn as nn
from pyhazards.models import register_model, build_model

def my_custom_builder(task: str, in_dim: int, out_dim: int, **kwargs) -> nn.Module:
    # `task` is passed through by build_model; this simple builder ignores it.
    hidden = kwargs.get("hidden_dim", 128)
    return nn.Sequential(
        nn.Linear(in_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, out_dim),
    )

register_model("my_mlp", my_custom_builder, defaults={"hidden_dim": 128})

model = build_model(name="my_mlp", task="regression", in_dim=16, out_dim=1)
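
Defaults stored at registration are overridable per build. A minimal sketch, assuming kwargs passed to build_model take precedence over registered defaults:

# hidden_dim=512 overrides the registered default of 128 for this build only
wide_model = build_model(name="my_mlp", task="regression", in_dim=16, out_dim=1, hidden_dim=512)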

Mamba-based wildfire model (spatio-temporal)

PyHazards ships a Mamba-style spatio-temporal model for county-day wildfire prediction using ERA5 features. It couples a selective state-space temporal encoder with a lightweight GCN to mix neighboring counties.

  • Input: (batch, past_days, num_counties, num_features) county-day ERA5 tensors.

  • Temporal: stacked selective SSM blocks plus a differential branch to highlight day-to-day changes.

  • Spatial: two-layer GCN over a provided adjacency (falls back to identity if none is given).

  • Output: per-county logits for the next day (apply torch.sigmoid for probabilities). Optional count head via with_count_head=True.

Toy usage with random ERA5-like tensors:

import torch
from pyhazards.datasets import DataBundle, DataSplit, FeatureSpec, LabelSpec
from pyhazards.engine import Trainer
from pyhazards.models import build_model

past_days = 12
num_counties = 5
num_features = 6  # e.g., t2m, d2m, u10, v10, tp, ssr
samples = 64

# Fake county-day ERA5 cube and binary fire labels
x = torch.randn(samples, past_days, num_counties, num_features)
y = torch.randint(0, 2, (samples, num_counties)).float()
adjacency = torch.eye(num_counties)  # replace with a distance or correlation matrix

bundle = DataBundle(
    splits={
        "train": DataSplit(x[:48], y[:48]),
        "val": DataSplit(x[48:], y[48:]),
    },
    feature_spec=FeatureSpec(input_dim=num_features, extra={"past_days": past_days, "counties": num_counties}),
    label_spec=LabelSpec(num_targets=num_counties, task_type="classification"),
)

model = build_model(
    name="wildfire_mamba",
    task="classification",
    in_dim=num_features,
    num_counties=num_counties,
    past_days=past_days,
    adjacency=adjacency,
)

trainer = Trainer(model=model, mixed_precision=False)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.BCEWithLogitsLoss()

# Fit on the toy data; Trainer works because inputs/targets are plain tensors
trainer.fit(bundle, optimizer=optimizer, loss_fn=loss_fn, max_epochs=2, batch_size=8)

# Predict probabilities for the next day
model.eval()  # switch off training-time behavior (e.g., dropout) for inference
with torch.no_grad():
    logits = model(x[:1])
    probs = torch.sigmoid(logits)
    print(probs.shape)  # (1, num_counties)

# For more complex batches (dicts with adjacency), wrap tensors in GraphTemporalDataset
# and pass graph_collate to Trainer.fit/evaluate/predict.
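
A hypothetical sketch of that pattern; the import path, GraphTemporalDataset's constructor, and Trainer's collate_fn keyword are all assumptions here, so check the actual signatures before relying on them:

from pyhazards.datasets import GraphTemporalDataset, graph_collate  # import path assumed

graph_ds = GraphTemporalDataset(x, y, adjacency=adjacency)  # assumed constructor
trainer.fit(
    graph_ds,
    optimizer=optimizer,
    loss_fn=loss_fn,
    max_epochs=2,
    batch_size=8,
    collate_fn=graph_collate,  # collate_fn kwarg is an assumption
)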

Design notes

  • Builders receive task plus any kwargs you pass; use this to switch heads internally if needed (see the sketch after this list).

  • register_model stores optional defaults so you can keep CLI/configs minimal.

  • Models are plain PyTorch modules, so you can compose them with the Trainer or your own loops.
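
A minimal sketch of head-switching inside a builder, as mentioned above; the backbone, widths, and the deeper regression head are illustrative choices, not library defaults:

import torch.nn as nn

def flexible_builder(task: str, in_dim: int, out_dim: int, **kwargs) -> nn.Module:
    hidden = kwargs.get("hidden_dim", 64)
    backbone = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
    if task == "classification":
        head = nn.Linear(hidden, out_dim)  # raw class logits
    elif task == "regression":
        head = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, out_dim),  # continuous outputs
        )
    else:
        raise ValueError(f"unsupported task: {task!r}")
    return nn.Sequential(backbone, head)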