Source code for pyhazards.datasets.earthquake
from __future__ import annotations
import math
import torch
from ..base import DataBundle, DataSplit, Dataset, FeatureSpec, LabelSpec
[docs]
class SyntheticEarthquakeWaveformDataset(Dataset):
"""Synthetic waveform dataset for earthquake phase-picking smoke runs."""
name = "earthquake_waveforms"
def __init__(
self,
cache_dir: str | None = None,
samples: int = 96,
channels: int = 3,
length: int = 256,
micro: bool = False,
):
super().__init__(cache_dir=cache_dir)
self.samples = 24 if micro else int(samples)
self.channels = int(channels)
self.length = int(length)
[docs]
def _load(self) -> DataBundle:
timeline = torch.linspace(0.0, 1.0, steps=self.length, dtype=torch.float32)
x = torch.zeros(self.samples, self.channels, self.length, dtype=torch.float32)
y = torch.zeros(self.samples, 2, dtype=torch.float32)
for idx in range(self.samples):
p_pick = 32 + (idx % 40)
s_pick = min(self.length - 12, p_pick + 24 + (idx % 24))
for channel in range(self.channels):
phase = 0.5 * channel
base = torch.sin(2.0 * math.pi * (channel + 1) * timeline + phase)
pulse_p = torch.exp(-0.5 * ((torch.arange(self.length) - p_pick) / 6.0) ** 2)
pulse_s = 0.8 * torch.exp(-0.5 * ((torch.arange(self.length) - s_pick) / 8.0) ** 2)
x[idx, channel] = base + pulse_p + pulse_s
y[idx, 0] = float(p_pick)
y[idx, 1] = float(s_pick)
train_end = max(1, int(0.7 * self.samples))
val_end = max(train_end + 1, int(0.85 * self.samples))
splits = {
"train": DataSplit(x[:train_end], y[:train_end]),
"val": DataSplit(x[train_end:val_end], y[train_end:val_end]),
"test": DataSplit(x[val_end:], y[val_end:]),
}
return DataBundle(
splits=splits,
feature_spec=FeatureSpec(
channels=self.channels,
description="Synthetic multichannel seismic waveforms with Gaussian phase arrivals.",
extra={"length": self.length},
),
label_spec=LabelSpec(
num_targets=2,
task_type="regression",
description="P- and S-arrival sample indices.",
),
metadata={
"dataset": self.name,
"source_dataset": self.name,
"hazard_task": "earthquake.picking",
},
)
[docs]
class SyntheticEarthquakeForecastDataset(Dataset):
"""Synthetic wavefield dataset for earthquake forecasting smoke runs."""
name = "earthquake_forecast_synthetic"
def __init__(
self,
cache_dir: str | None = None,
samples: int = 40,
channels: int = 3,
temporal_in: int = 5,
temporal_out: int = 4,
height: int = 12,
width: int = 10,
micro: bool = False,
):
super().__init__(cache_dir=cache_dir)
self.samples = 10 if micro else int(samples)
self.channels = int(channels)
self.temporal_in = int(temporal_in)
self.temporal_out = int(temporal_out)
self.height = int(height)
self.width = int(width)
[docs]
def _load(self) -> DataBundle:
grid_y = torch.linspace(-1.0, 1.0, steps=self.height, dtype=torch.float32).view(self.height, 1)
grid_x = torch.linspace(-1.0, 1.0, steps=self.width, dtype=torch.float32).view(1, self.width)
total_steps = self.temporal_in + self.temporal_out
x = torch.zeros(
self.samples,
self.channels,
self.temporal_in,
self.height,
self.width,
dtype=torch.float32,
)
y = torch.zeros(
self.samples,
self.channels,
self.temporal_out,
self.height,
self.width,
dtype=torch.float32,
)
row_index = torch.arange(self.height, dtype=torch.float32).view(self.height, 1)
col_index = torch.arange(self.width, dtype=torch.float32).view(1, self.width)
for idx in range(self.samples):
sequence = torch.zeros(
self.channels,
total_steps,
self.height,
self.width,
dtype=torch.float32,
)
for step in range(total_steps):
center_r = 2.0 + ((idx + step) % max(3, self.height - 2))
center_c = 1.0 + ((2 * idx + step) % max(2, self.width - 1))
gaussian = torch.exp(
-0.18 * ((row_index - center_r) ** 2 + (col_index - center_c) ** 2)
)
for channel in range(self.channels):
phase = 0.5 * channel + 0.2 * step
base = torch.sin(
math.pi * (channel + 1) * grid_y + phase
) + torch.cos(math.pi * (channel + 1) * grid_x - phase)
sequence[channel, step] = base + (0.6 + 0.1 * channel) * gaussian
x[idx] = sequence[:, : self.temporal_in]
y[idx] = sequence[:, self.temporal_in :]
train_end = max(1, int(0.7 * self.samples))
val_end = max(train_end + 1, int(0.85 * self.samples))
splits = {
"train": DataSplit(x[:train_end], y[:train_end]),
"val": DataSplit(x[train_end:val_end], y[train_end:val_end]),
"test": DataSplit(x[val_end:], y[val_end:]),
}
return DataBundle(
splits=splits,
feature_spec=FeatureSpec(
channels=self.channels,
description="Synthetic dense-grid wavefield history tensors for forecasting benchmarks.",
extra={
"temporal_in": self.temporal_in,
"temporal_out": self.temporal_out,
"height": self.height,
"width": self.width,
},
),
label_spec=LabelSpec(
num_targets=self.channels * self.temporal_out,
task_type="regression",
description="Future dense-grid wavefield frames over the forecast horizon.",
),
metadata={
"dataset": self.name,
"source_dataset": self.name,
"hazard_task": "earthquake.forecasting",
},
)
[docs]
class SeisBenchWaveformDataset(SyntheticEarthquakeWaveformDataset):
"""Synthetic-backed adapter with the SeisBench public dataset surface."""
name = "seisbench_waveforms"
[docs]
def _load(self) -> DataBundle:
bundle = super()._load()
bundle.metadata.update({"adapter": "SeisBench", "source_dataset": self.name})
return bundle
[docs]
class PickBenchmarkWaveformDataset(SyntheticEarthquakeWaveformDataset):
"""Synthetic-backed adapter with the pick-benchmark public dataset surface."""
name = "pick_benchmark_waveforms"
[docs]
def _load(self) -> DataBundle:
bundle = super()._load()
bundle.metadata.update({"adapter": "pick-benchmark", "source_dataset": self.name})
return bundle
[docs]
class AEFADataset(SyntheticEarthquakeForecastDataset):
"""Synthetic-backed adapter for AEFA-style earthquake forecasting inputs."""
name = "aefa_forecast"
[docs]
def _load(self) -> DataBundle:
bundle = super()._load()
bundle.metadata.update({"adapter": "AEFA", "source_dataset": self.name})
return bundle
__all__ = [
"AEFADataset",
"PickBenchmarkWaveformDataset",
"SeisBenchWaveformDataset",
"SyntheticEarthquakeForecastDataset",
"SyntheticEarthquakeWaveformDataset",
]