Source code for pyhazards.datasets.tc

from __future__ import annotations

import torch

from ..base import DataBundle, DataSplit, Dataset, FeatureSpec, LabelSpec


[docs] class SyntheticTropicalCycloneDataset(Dataset): """Synthetic storm-history dataset for track/intensity smoke runs.""" name = "tc_tracks_synthetic" def __init__( self, cache_dir: str | None = None, samples: int = 64, history: int = 6, horizon: int = 5, features: int = 8, micro: bool = False, ): super().__init__(cache_dir=cache_dir) self.samples = 20 if micro else int(samples) self.history = int(history) self.horizon = int(horizon) self.features = int(features)
[docs] def _load(self) -> DataBundle: x = torch.randn(self.samples, self.history, self.features, dtype=torch.float32) last_state = x[:, -1, :3] deltas = torch.linspace(0.2, 1.0, steps=self.horizon, dtype=torch.float32).view(1, self.horizon, 1) direction = torch.tensor([0.4, 0.2, 1.5], dtype=torch.float32).view(1, 1, 3) y = last_state.unsqueeze(1) + deltas * direction train_end = max(1, int(0.7 * self.samples)) val_end = max(train_end + 1, int(0.85 * self.samples)) splits = { "train": DataSplit(x[:train_end], y[:train_end]), "val": DataSplit(x[train_end:val_end], y[train_end:val_end]), "test": DataSplit(x[val_end:], y[val_end:]), } return DataBundle( splits=splits, feature_spec=FeatureSpec( input_dim=self.features, description="Synthetic storm history with environmental context features.", extra={"history": self.history, "horizon": self.horizon}, ), label_spec=LabelSpec( num_targets=3, task_type="regression", description="Forecast track latitude/longitude and intensity trajectory.", ), metadata={ "dataset": self.name, "source_dataset": self.name, "hazard_task": "tc.track_intensity", }, )
[docs] class IBTrACSTropicalCycloneDataset(SyntheticTropicalCycloneDataset): """Synthetic-backed adapter for IBTrACS-style storm tracks.""" name = "ibtracs_tracks"
[docs] def _load(self) -> DataBundle: bundle = super()._load() bundle.metadata.update({"adapter": "IBTrACS", "source_dataset": self.name}) return bundle
[docs] class TCBenchAlphaDataset(SyntheticTropicalCycloneDataset): """Synthetic-backed adapter for TCBench Alpha evaluation runs.""" name = "tcbench_alpha"
[docs] def _load(self) -> DataBundle: bundle = super()._load() bundle.metadata.update({"adapter": "TCBench Alpha", "source_dataset": self.name}) return bundle
[docs] class TropiCycloneNetDataset(SyntheticTropicalCycloneDataset): """Synthetic-backed adapter for TropiCycloneNet-Dataset style smoke runs.""" name = "tropicyclonenet_dataset"
[docs] def _load(self) -> DataBundle: bundle = super()._load() bundle.metadata.update({"adapter": "TropiCycloneNet-Dataset", "source_dataset": self.name}) return bundle
__all__ = [ "IBTrACSTropicalCycloneDataset", "SyntheticTropicalCycloneDataset", "TCBenchAlphaDataset", "TropiCycloneNetDataset", ]