"""pyhazards.benchmarks.earthquake -- earthquake hazard benchmark."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Dict

import torch
import torch.nn as nn

from ..configs import ExperimentConfig
from ..datasets.base import DataBundle
from .base import Benchmark
from .registry import register_benchmark
from .schemas import BenchmarkResult


class EarthquakeBenchmark(Benchmark):
    """Benchmark for earthquake hazard tasks.

    Supports two hazard tasks:

    * ``earthquake.picking`` -- P/S phase-arrival picking, scored with
      per-phase MAE plus tolerance-based detection rates.
    * ``earthquake.forecasting`` -- scored with plain MAE/MSE regression
      errors, and exported in an additional pyCSEP-style JSON report.
    """

    name = "earthquake"
    hazard_task = "earthquake.picking"
    # Metric names reported for each supported hazard task.
    metric_names_by_task = {
        "earthquake.picking": ["p_pick_mae", "s_pick_mae", "precision", "recall", "f1"],
        "earthquake.forecasting": ["mae", "mse"],
    }

    def evaluate(self, model: nn.Module, data: DataBundle, config: ExperimentConfig) -> BenchmarkResult:
        """Evaluate *model* on the configured split and compute task metrics.

        Args:
            model: Module mapping ``split.inputs`` to predictions with the
                same shape as ``split.targets``.
            data: Bundle providing the evaluation split and dataset metadata.
            config: Experiment configuration; ``config.benchmark`` selects
                the eval split, hazard task, and task parameters.

        Returns:
            A :class:`BenchmarkResult` carrying the task metrics and
            metadata (split name, threshold curve, dataset provenance).
        """
        split = data.get_split(config.benchmark.eval_split)
        x = split.inputs
        y = split.targets
        # Metric computation never backpropagates, so skip building the
        # autograd graph entirely (the original code relied on per-metric
        # .detach() calls instead, which still paid for graph construction).
        with torch.no_grad():
            preds = model(x)

        if config.benchmark.hazard_task == "earthquake.picking":
            # Column convention appears to be [:, 0] = P pick, [:, 1] = S pick
            # (inferred from the metric names -- confirm against the dataset).
            abs_err = (preds - y).abs()  # hoisted: reused for every tolerance
            tolerances = config.benchmark.params.get("detection_tolerances", [4.0, 8.0, 12.0])
            # Tolerance whose hit rate is reported as precision/recall/f1.
            # Previously hard-coded to 8.0; now overridable via params with
            # the same default, so existing configs behave identically.
            detection_tolerance = float(config.benchmark.params.get("detection_tolerance", 8.0))
            threshold_curve: Dict[str, float] = {}
            detection_rate = 0.0
            for tolerance in tolerances:
                # A sample counts as detected only if *all* of its picks
                # fall within the tolerance.
                hits = (abs_err <= float(tolerance)).all(dim=1).float()
                hit_rate = float(hits.mean().cpu())
                threshold_curve[str(tolerance)] = hit_rate
                if float(tolerance) == detection_tolerance:
                    detection_rate = hit_rate
            metrics = {
                "p_pick_mae": float(abs_err[:, 0].mean().cpu()),
                "s_pick_mae": float(abs_err[:, 1].mean().cpu()),
                "mean_pick_mae": float(abs_err.mean().cpu()),
                # NOTE(review): precision/recall/f1 are all reported as the
                # detection rate at ``detection_tolerance``; there is no
                # separate false-positive accounting in this implementation.
                "precision": detection_rate,
                "recall": detection_rate,
                "f1": detection_rate,
            }
        else:
            # Forecasting (regression) path: plain MAE/MSE, no curve.
            mse = torch.mean((preds - y) ** 2)
            mae = torch.mean(torch.abs(preds - y))
            threshold_curve = {}
            metrics = {
                "mae": float(mae.cpu()),
                "mse": float(mse.cpu()),
            }

        return BenchmarkResult(
            benchmark_name=self.name,
            hazard_task=config.benchmark.hazard_task,
            metrics=metrics,
            metadata={
                "split": config.benchmark.eval_split,
                "threshold_curve": threshold_curve,
                "dataset_name": data.metadata.get("dataset"),
                "source_dataset": data.metadata.get("source_dataset", data.metadata.get("dataset")),
            },
        )

    def export_report(self, result: BenchmarkResult, output_dir: str, formats) -> Dict[str, str]:
        """Write the standard report files, plus a pyCSEP-style JSON export
        for forecasting results.

        Args:
            result: Benchmark result to serialize.
            output_dir: Directory that receives the report files (created
                if missing).
            formats: Passed through unchanged to the base
                :meth:`Benchmark.export_report`.

        Returns:
            Mapping of report kind to written file path; forecasting tasks
            gain an extra ``"pycsep"`` entry.
        """
        paths = super().export_report(result, output_dir=output_dir, formats=formats)
        if result.hazard_task == "earthquake.forecasting":
            target = Path(output_dir)
            target.mkdir(parents=True, exist_ok=True)
            pycsep_path = target / "earthquake_pycsep.json"
            payload = {
                "adapter": "pyCSEP-style",
                "benchmark_name": result.benchmark_name,
                "hazard_task": result.hazard_task,
                "metrics": result.metrics,
                "metadata": result.metadata,
            }
            # sort_keys keeps the export byte-stable across runs for diffing.
            pycsep_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
            paths["pycsep"] = str(pycsep_path)
        return paths
# Register under the canonical name so the benchmark is discoverable
# through the shared registry.
register_benchmark(EarthquakeBenchmark.name, EarthquakeBenchmark)

__all__ = ["EarthquakeBenchmark"]