Source code for pyhazards.models.builder

from __future__ import annotations

from typing import Any, Dict

import torch.nn as nn

from .backbones import CNNPatchEncoder, MLPBackbone, TemporalEncoder
from .heads import ClassificationHead, RegressionHead, SegmentationHead
from .registry import get_model_config


def build_model(name: str, task: str, **kwargs: Any) -> nn.Module:
    """
    Build a model by registered name and task.

    Construction is delegated to the builder stored in the registry, so every
    model is created through the same interface.
    """
    cfg = get_model_config(name)
    if cfg is None:
        raise KeyError(f"Model '{name}' is not registered.")
    task = task.lower()
    builder = cfg["builder"]
    defaults: Dict[str, Any] = cfg.get("defaults", {})
    merged = {**defaults, **kwargs, "task": task}
    return builder(**merged)

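# Usage sketch, illustrative only: the model name "mlp" and the dimensions
# below are assumptions, not guaranteed registry contents. It assumes an entry
# named "mlp" has been registered with a builder that accepts these keyword
# arguments; the call simply mirrors build_model's signature and the nn.Module
# it returns.
#
#     import torch
#     from pyhazards.models.builder import build_model
#
#     model = build_model("mlp", "classification", in_dim=32, out_dim=4)
#     logits = model(torch.randn(8, 32))  # expected shape: (8, 4)
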
def default_builder(name: str, task: str, **kwargs: Any) -> nn.Module:
    """
    Generic builder for standard backbones + heads.
    """
    task = task.lower()
    if name == "mlp":
        backbone = MLPBackbone(
            kwargs["in_dim"],
            hidden_dim=kwargs.get("hidden_dim", 256),
            depth=kwargs.get("depth", 2),
        )
        head = _make_head(task, kwargs)
        return _combine(backbone, head)
    if name == "cnn":
        backbone = CNNPatchEncoder(
            kwargs.get("in_channels", 3),
            hidden_dim=kwargs.get("hidden_dim", 64),
        )
        head = _make_head(task, kwargs, backbone_out_dim=kwargs.get("hidden_dim", 64))
        return _combine(backbone, head)
    if name == "temporal":
        backbone = TemporalEncoder(
            kwargs["in_dim"],
            hidden_dim=kwargs.get("hidden_dim", 128),
            num_layers=kwargs.get("num_layers", 1),
        )
        head = _make_head(task, kwargs)
        return _combine(backbone, head)
    raise ValueError(f"Unknown backbone '{name}'.")

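# Sketch of calling the generic builder directly, bypassing the registry. The
# keyword names mirror the lookups above; the concrete values are made up:
#
#     net = default_builder("temporal", "regression", in_dim=16, hidden_dim=64)
#     # -> nn.Sequential(TemporalEncoder(...), RegressionHead(in_dim=64, out_dim=1))
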
def _make_head(
    task: str, kwargs: Dict[str, Any], backbone_out_dim: int | None = None
) -> nn.Module:
    if task == "classification":
        in_dim = backbone_out_dim or kwargs.get("hidden_dim") or kwargs["in_dim"]
        return ClassificationHead(in_dim=in_dim, num_classes=kwargs["out_dim"])
    if task == "regression":
        in_dim = backbone_out_dim or kwargs.get("hidden_dim") or kwargs["in_dim"]
        return RegressionHead(in_dim=in_dim, out_dim=kwargs.get("out_dim", 1))
    if task == "segmentation":
        in_channels = kwargs.get("hidden_dim") or backbone_out_dim or kwargs.get("in_channels", 1)
        return SegmentationHead(in_channels=in_channels, num_classes=kwargs["out_dim"])
    raise ValueError(f"Unsupported task '{task}'.")


def _combine(backbone: nn.Module, head: nn.Module) -> nn.Module:
    return nn.Sequential(backbone, head)


__all__ = ["build_model", "default_builder"]
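
# Head-dimension resolution sketch: for classification and regression,
# _make_head prefers an explicit backbone_out_dim, then "hidden_dim", then
# "in_dim" from kwargs; the segmentation branch checks "hidden_dim" first.
# Hypothetical values:
#
#     _make_head("classification", {"hidden_dim": 128, "out_dim": 3})
#     # -> ClassificationHead(in_dim=128, num_classes=3)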