# Source code for pyhazards.datasets.graph

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple

import torch
from torch.utils.data import Dataset


class GraphTemporalDataset(Dataset):
    """
    Simple container for county/day style tensors with an optional adjacency.

    Each sample is a window of shape (past_days, num_counties, num_features)
    and a label of shape (num_counties,) or (num_counties, targets).
    """

    def __init__(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        adjacency: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: Tensor (samples, past_days, num_counties, num_features)
            y: Tensor (samples, num_counties) or (samples, num_counties, targets)
            adjacency: Optional Tensor
                - (num_counties, num_counties) global adjacency
                - (samples, num_counties, num_counties) per-sample adjacency

        Raises:
            ValueError: if a tensor has the wrong rank, or if the sample or
                county dimensions of ``x``, ``y`` and ``adjacency`` disagree.
        """
        if x.ndim != 4:
            raise ValueError("x must be (samples, past_days, num_counties, num_features)")
        if y.ndim not in (2, 3):
            raise ValueError("y must be (samples, num_counties) or (samples, num_counties, targets)")

        num_samples = x.size(0)
        num_counties = x.size(2)

        # __len__ is driven by x alone, so a y with a different sample count
        # would silently truncate labels or fail with IndexError mid-epoch.
        if y.size(0) != num_samples:
            raise ValueError("y sample count mismatch with x")
        if y.size(1) != num_counties:
            raise ValueError("y size mismatch with num_counties")

        if adjacency is not None:
            if adjacency.ndim not in (2, 3):
                raise ValueError("adjacency must be None, (N,N), or (B,N,N)")
            # Both node axes must equal num_counties; checking only one of
            # them would let a non-square matrix through.
            if tuple(adjacency.shape[-2:]) != (num_counties, num_counties):
                raise ValueError("adjacency size mismatch with num_counties")
            # A per-sample adjacency is indexed with the same idx as x in
            # __getitem__, so it needs one matrix per sample.
            if adjacency.ndim == 3 and adjacency.size(0) != num_samples:
                raise ValueError("adjacency sample count mismatch with x")

        self.x = x
        self.y = y
        self.adj = adjacency

    def __len__(self) -> int:
        """Return the number of samples (the leading dimension of ``x``)."""
        return self.x.size(0)

    def __getitem__(self, idx: int) -> Tuple[Dict[str, Any], torch.Tensor]:
        """Return ``({"x": x[idx], "adj": adjacency_or_None}, y[idx])``."""
        adj = None
        if self.adj is not None:
            # A 2-D adjacency is shared by every sample; a 3-D one is per-sample.
            adj = self.adj if self.adj.ndim == 2 else self.adj[idx]
        return {"x": self.x[idx], "adj": adj}, self.y[idx]
def graph_collate(batch: List[Tuple[Dict[str, Any], torch.Tensor]]):
    """
    Merge a list of ``(inputs, label)`` samples into a single batch.

    ``inputs["x"]`` tensors and the labels are stacked along a new leading
    dimension. ``inputs["adj"]`` is stacked the same way when at least one
    sample carries an adjacency; samples whose adjacency is ``None`` borrow
    the first non-``None`` adjacency found in the batch. When every sample
    has ``adj=None`` the batched ``"adj"`` is ``None`` as well.
    """
    inputs, labels = zip(*batch)
    stacked_x = torch.stack([sample["x"] for sample in inputs], dim=0)

    adjacencies = [sample["adj"] for sample in inputs]
    fallback = next((a for a in adjacencies if a is not None), None)
    if fallback is None:
        stacked_adj = None
    else:
        # Fill gaps with the first available adjacency so torch.stack
        # receives a list of same-shaped tensors.
        stacked_adj = torch.stack(
            [fallback if a is None else a for a in adjacencies], dim=0
        )

    stacked_y = torch.stack(labels, dim=0)
    return {"x": stacked_x, "adj": stacked_adj}, stacked_y
# Names exported by ``from pyhazards.datasets.graph import *``.
__all__ = [
    "GraphTemporalDataset",
    "graph_collate",
]