@dataclass
class FeatureSpec:
    """Specification of a dataset's input features (shapes, dtypes, normalization).

    Every attribute is optional so a spec can be filled in incrementally.
    """

    input_dim: Optional[int] = None    # flat feature dimensionality, if known
    channels: Optional[int] = None     # channel count, if applicable
    description: Optional[str] = None  # free-form human-readable note
    extra: Dict[str, Any] = field(default_factory=dict)  # per-instance dict; anything else
@dataclass
class LabelSpec:
    """Specification of the labels/targets for a downstream task."""

    num_targets: Optional[int] = None  # number of target values, if known
    # Expected values: classification | regression | segmentation
    task_type: str = "regression"
    description: Optional[str] = None  # free-form human-readable note
    extra: Dict[str, Any] = field(default_factory=dict)  # per-instance dict; anything else
@dataclass
class DataSplit:
    """Inputs and targets for one named split, plus free-form metadata."""

    inputs: Any
    targets: Any
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class DataBundle:
    """Train/val/test splits together with their feature/label specs.

    Carrying the specs alongside the splits makes downstream model
    construction straightforward.
    """

    splits: Dict[str, DataSplit]
    feature_spec: FeatureSpec
    label_spec: LabelSpec
    metadata: Dict[str, Any] = field(default_factory=dict)

    def get_split(self, name: str) -> DataSplit:
        """Return the split called ``name``.

        Raises:
            KeyError: if ``name`` is not one of the available splits;
                the message lists the valid split names.
        """
        # Guard-clause form: succeed fast, otherwise report what IS available.
        if name in self.splits:
            return self.splits[name]
        raise KeyError(f"Split '{name}' not found. Available: {list(self.splits.keys())}")
class Transform(Protocol):
    """Structural type for a callable that maps one DataBundle to another."""

    def __call__(self, bundle: DataBundle) -> DataBundle:
        ...
class Dataset:
    """Base class for hazard datasets.

    Subclasses implement :meth:`_load` to build a complete
    :class:`DataBundle`; callers go through :meth:`load`, which can also
    apply transforms and narrow the result to a single split.
    """

    name: str = "base"  # identifier for the dataset; subclasses override

    def __init__(self, cache_dir: Optional[str] = None):
        # Stored for subclasses; the base class itself does not use it.
        self.cache_dir = cache_dir

    def load(
        self,
        split: Optional[str] = None,
        transforms: Optional[List[Transform]] = None,
    ) -> DataBundle:
        """Build the bundle, apply ``transforms`` in order, and optionally
        narrow it to a single named split.

        Args:
            split: if truthy, the returned bundle contains only this split
                (KeyError from ``get_split`` if it does not exist).
            transforms: applied left-to-right to the full bundle before any
                split selection.
        """
        bundle = self._load()
        # ``or []`` skips the loop for both None and an empty list,
        # matching a plain truthiness check on ``transforms``.
        for transform in transforms or []:
            bundle = transform(bundle)
        if not split:
            return bundle
        return DataBundle(
            splits={split: bundle.get_split(split)},
            feature_spec=bundle.feature_spec,
            label_spec=bundle.label_spec,
            metadata=bundle.metadata,
        )

    def _load(self) -> DataBundle:
        """Hook for subclasses: produce the full DataBundle."""
        raise NotImplementedError("Subclasses must implement _load() to return a DataBundle.")