# Source code for pyhazards.datasets.base
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Protocol
[docs]
@dataclass
class FeatureSpec:
    """Specification record for a dataset's input features.

    Every field has a default, so ``FeatureSpec()`` is valid and callers
    fill in only what they know.

    Attributes:
        input_dim: Optional integer; presumably the feature dimensionality
            (not interpreted anywhere in this class).
        channels: Optional integer; presumably the number of input channels
            (not interpreted anywhere in this class).
        description: Optional free-form text.
        extra: Open-ended string-keyed mapping for anything else.
    """

    input_dim: Optional[int] = None
    channels: Optional[int] = None
    description: Optional[str] = None
    # default_factory gives each instance its own dict instead of a shared one.
    extra: Dict[str, Any] = field(default_factory=dict)
[docs]
@dataclass
class LabelSpec:
    """Specification record for the labels/targets of a downstream task.

    Every field has a default, so ``LabelSpec()`` is valid.

    Attributes:
        num_targets: Optional integer; presumably the number of target
            values or classes (not interpreted anywhere in this class).
        task_type: Task kind as a plain string, ``"regression"`` by default.
            The value is not validated here.
        description: Optional free-form text.
        extra: Open-ended string-keyed mapping for anything else.
    """

    num_targets: Optional[int] = None
    # Intended values: classification | regression | segmentation
    task_type: str = "regression"
    description: Optional[str] = None
    # default_factory gives each instance its own dict instead of a shared one.
    extra: Dict[str, Any] = field(default_factory=dict)
[docs]
@dataclass
class DataSplit:
    """Holds the data belonging to one named split.

    Attributes:
        inputs: Input data for the split; required, any type.
        targets: Target data for the split; required, any type.
        metadata: Optional string-keyed mapping of split-level details.
    """

    inputs: Any
    targets: Any
    # default_factory gives each instance its own dict instead of a shared one.
    metadata: Dict[str, Any] = field(default_factory=dict)
[docs]
@dataclass
class DataBundle:
    """
    Named data splits together with the specs and metadata describing them.

    Attributes:
        splits: Mapping from split name to its ``DataSplit``.
        feature_spec: ``FeatureSpec`` carried alongside the splits.
        label_spec: ``LabelSpec`` carried alongside the splits.
        metadata: Optional string-keyed mapping of bundle-level details.
    """

    splits: Dict[str, DataSplit]
    feature_spec: FeatureSpec
    label_spec: LabelSpec
    # default_factory gives each instance its own dict instead of a shared one.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get_split(self, name: str) -> DataSplit:
        """Return the split stored under ``name``.

        Raises:
            KeyError: If ``name`` is not a key of ``splits``. The message
                lists the split names that are available.
        """
        if name in self.splits:
            return self.splits[name]
        raise KeyError(f"Split '{name}' not found. Available: {list(self.splits.keys())}")
[docs]
class Dataset:
    """
    Base class for hazard datasets.

    Subclasses implement :meth:`_load` to build a complete ``DataBundle``.
    Callers use :meth:`load`, which layers optional transforms and
    single-split selection on top of it.

    Attributes:
        name: Class-level dataset identifier; ``"base"`` on this class.
        cache_dir: Value given to the constructor, stored unchanged. This
            base class never reads it.
    """

    name: str = "base"

    def __init__(self, cache_dir: Optional[str] = None):
        self.cache_dir = cache_dir

    # Project types in the annotations below are string forward references.
    # ``Transform`` is not defined above this class in the module, and an
    # unquoted annotation is evaluated as soon as the ``def`` executes.
    def load(
        self,
        split: Optional[str] = None,
        transforms: Optional[List["Transform"]] = None,
    ) -> "DataBundle":
        """
        Return a DataBundle, optionally narrowed to a single split.

        Args:
            split: Name of the one split to keep. Any falsy value (``None``
                or ``""``) means "keep every split".
            transforms: Callables applied in order; each receives the
                current bundle and its return value becomes the new bundle.

        Returns:
            The (possibly transformed) bundle from :meth:`_load` when no
            split is requested. Otherwise a new ``DataBundle`` containing
            only the requested split; it reuses the same ``feature_spec``,
            ``label_spec`` and ``metadata`` objects rather than copying them.

        Raises:
            KeyError: Propagated from ``bundle.get_split`` when ``split`` is
                not present in the bundle.
            NotImplementedError: If the subclass does not override
                :meth:`_load`.
        """
        bundle = self._load()

        # ``or ()`` makes None and an empty list both mean "no transforms".
        for transform in transforms or ():
            bundle = transform(bundle)

        if not split:
            return bundle

        # The split is selected after transforms, so they see the full bundle.
        return DataBundle(
            splits={split: bundle.get_split(split)},
            feature_spec=bundle.feature_spec,
            label_spec=bundle.label_spec,
            metadata=bundle.metadata,
        )

    def _load(self) -> "DataBundle":
        """Build and return the full bundle; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement _load() to return a DataBundle.")