Source code for pyhazards.datasets.wildfire

from __future__ import annotations

import torch

from ..base import DataBundle, DataSplit, Dataset, FeatureSpec, LabelSpec


class SyntheticWildfireSpreadDataset(Dataset):
    """Randomly generated raster dataset for wildfire spread smoke runs.

    Every sample pairs a ``(channels, height, width)`` float32 tensor of
    Gaussian noise with a ``(1, height, width)`` binary disc mask. The same
    disc, scaled by 2.5, is added onto feature channel 0 so the label is
    recoverable from the inputs.

    Args:
        cache_dir: Forwarded unchanged to the ``Dataset`` base initializer.
        samples: Number of samples to generate (ignored when ``micro`` is set).
        channels: Number of feature channels per sample.
        height: Raster height in cells.
        width: Raster width in cells.
        micro: When true, the sample count is fixed to 16.
    """

    name = "wildfire_spread_synthetic"

    def __init__(
        self,
        cache_dir: str | None = None,
        samples: int = 64,
        channels: int = 12,
        height: int = 32,
        width: int = 32,
        micro: bool = False,
    ):
        super().__init__(cache_dir=cache_dir)
        # ``micro`` takes precedence over whatever ``samples`` was passed.
        if micro:
            sample_count = 16
        else:
            sample_count = int(samples)
        self.samples = sample_count
        self.channels = int(channels)
        self.height = int(height)
        self.width = int(width)

    def _load(self) -> DataBundle:
        """Generate features and disc labels and split them 70/15/15.

        Features come from ``torch.randn`` with no generator passed, so the
        values depend on the global torch RNG state at call time.

        Returns:
            A ``DataBundle`` holding ``train``/``val``/``test`` splits plus the
            feature spec, label spec and dataset metadata.
        """
        n_samples, n_rows, n_cols = self.samples, self.height, self.width
        x = torch.randn(n_samples, self.channels, n_rows, n_cols, dtype=torch.float32)
        y = torch.zeros(n_samples, 1, n_rows, n_cols, dtype=torch.float32)

        # Column/row vectors that broadcast to a (height, width) index grid.
        row_index = torch.arange(n_rows).view(n_rows, 1)
        col_index = torch.arange(n_cols).view(1, n_cols)

        for i in range(n_samples):
            # Deterministic disc placement: the centre walks across the raster
            # and the radius cycles through 4..8 as the sample index grows.
            centre_row = (i * 3) % n_rows
            centre_col = (i * 5) % n_cols
            reach = 4 + (i % 5)

            row_term = (row_index - centre_row).float().pow(2)
            col_term = (col_index - centre_col).float().pow(2)
            disc = ((row_term + col_term) <= reach**2).float()

            y[i, 0] = disc
            # Imprint the disc on channel 0 on top of the noise.
            x[i, 0] += 2.5 * disc

        # Contiguous split: at least one training sample, and the validation
        # boundary always lies past the training boundary.
        train_end = max(1, int(0.7 * n_samples))
        val_end = max(train_end + 1, int(0.85 * n_samples))
        windows = {
            "train": slice(None, train_end),
            "val": slice(train_end, val_end),
            "test": slice(val_end, None),
        }
        splits = {
            split_name: DataSplit(x[window], y[window])
            for split_name, window in windows.items()
        }

        feature_spec = FeatureSpec(
            channels=self.channels,
            description="Synthetic raster weather and fuel covariates for wildfire spread.",
        )
        label_spec = LabelSpec(
            num_targets=1,
            task_type="segmentation",
            description="Binary spread mask for the next forecast horizon.",
        )
        metadata = {
            "dataset": self.name,
            "source_dataset": self.name,
            "hazard_task": "wildfire.spread",
        }
        return DataBundle(
            splits=splits,
            feature_spec=feature_spec,
            label_spec=label_spec,
            metadata=metadata,
        )
class SyntheticWildfireSpreadTemporalDataset(Dataset):
    """Synthetic temporal wildfire spread dataset for sequence-based spread baselines.

    Every sample is a ``(history, channels, height, width)`` float32 tensor of
    Gaussian noise paired with a ``(1, height, width)`` binary disc label.
    Channel 0 of each history step additionally carries a disc around the
    same centre whose radius grows by one cell per step (never below 1) and
    equals the label radius at the final step.

    Args:
        cache_dir: Forwarded unchanged to the ``Dataset`` base initializer.
        samples: Number of samples to generate (ignored when ``micro`` is set).
        history: Number of past time steps per sample.
        channels: Number of feature channels per time step.
        height: Raster height in cells.
        width: Raster width in cells.
        micro: When true, the sample count is fixed to 12.
    """

    name = "wildfire_spread_temporal_synthetic"

    def __init__(
        self,
        cache_dir: str | None = None,
        samples: int = 48,
        history: int = 4,
        channels: int = 6,
        height: int = 16,
        width: int = 16,
        micro: bool = False,
    ):
        super().__init__(cache_dir=cache_dir)
        # ``micro`` overrides ``samples`` with a fixed, smaller count.
        self.samples = 12 if micro else int(samples)
        self.history = int(history)
        self.channels = int(channels)
        self.height = int(height)
        self.width = int(width)

    def _load(self) -> DataBundle:
        """Generate the temporal features and labels and split them 70/15/15.

        Features come from ``torch.randn`` with no generator passed, so the
        values depend on the global torch RNG state at call time.

        Returns:
            A ``DataBundle`` holding ``train``/``val``/``test`` splits, a
            feature spec that records ``history`` in ``extra``, the label
            spec, and dataset metadata.
        """
        x = torch.randn(
            self.samples,
            self.history,
            self.channels,
            self.height,
            self.width,
            dtype=torch.float32,
        )
        y = torch.zeros(self.samples, 1, self.height, self.width, dtype=torch.float32)

        # Column/row vectors that broadcast to a (height, width) index grid.
        rows = torch.arange(self.height).view(self.height, 1)
        cols = torch.arange(self.width).view(1, self.width)

        for idx in range(self.samples):
            center_r = (idx * 2 + 3) % self.height
            center_c = (idx * 3 + 5) % self.width
            radius = 3 + (idx % 4)

            # The squared distance to the centre depends only on the sample,
            # so compute it once and re-threshold it for the label and for
            # each history step instead of rebuilding the grid every time.
            dist_sq = (rows - center_r).float().pow(2) + (cols - center_c).float().pow(2)

            y[idx, 0] = (dist_sq <= radius**2).float()

            for step in range(self.history):
                # Earlier steps get a smaller disc; the floor of 1 keeps a
                # footprint visible even when history exceeds the radius.
                inner_radius = max(1, radius - (self.history - step - 1))
                x[idx, step, 0] += (dist_sq <= inner_radius**2).float()

        # Contiguous split: at least one training sample, and the validation
        # boundary always lies past the training boundary.
        train_end = max(1, int(0.7 * self.samples))
        val_end = max(train_end + 1, int(0.85 * self.samples))
        splits = {
            "train": DataSplit(x[:train_end], y[:train_end]),
            "val": DataSplit(x[train_end:val_end], y[train_end:val_end]),
            "test": DataSplit(x[val_end:], y[val_end:]),
        }
        return DataBundle(
            splits=splits,
            feature_spec=FeatureSpec(
                channels=self.channels,
                description="Synthetic temporal wildfire spread covariates over forecast history windows.",
                extra={"history": self.history},
            ),
            label_spec=LabelSpec(
                num_targets=1,
                task_type="segmentation",
                description="Binary spread mask for the next forecast horizon.",
            ),
            metadata={
                "dataset": self.name,
                "source_dataset": self.name,
                "hazard_task": "wildfire.spread",
            },
        )
# Public names exported by a star-import of this module.
__all__ = ["SyntheticWildfireSpreadDataset", "SyntheticWildfireSpreadTemporalDataset"]