# Source code for pyhazards.models.eqtransformer
from __future__ import annotations
import torch
import torch.nn as nn
# [docs] — Sphinx HTML extraction artifact
class EQTransformer(nn.Module):
    """Compact sequence model for joint earthquake phase picking.

    Pipeline: a small 1-D convolutional encoder, a bidirectional LSTM over
    the encoded sequence, an additive attention pooling step, and a two-unit
    regression head.
    """

    def __init__(
        self,
        in_channels: int = 3,
        hidden_dim: int = 48,
        num_layers: int = 2,
        dropout: float = 0.1,
    ):
        """Construct the sub-modules.

        Args:
            in_channels: number of waveform channels in the input.
            hidden_dim: width of the convolutional and LSTM features.
            num_layers: stacked LSTM layers.
            dropout: inter-layer LSTM dropout (only meaningful for >1 layer).
        """
        super().__init__()
        # Two same-padding conv layers keep the sequence length unchanged.
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, hidden_dim, kernel_size=11, padding=5),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3),
            nn.ReLU(),
        )
        # PyTorch only applies LSTM dropout between layers, hence the guard.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.temporal = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=lstm_dropout,
        )
        # Scalar attention score per time step over the 2*hidden BiLSTM output.
        self.attention = nn.Linear(2 * hidden_dim, 1)
        self.head = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a waveform batch to two regression outputs per sample.

        Args:
            x: tensor shaped (batch, channels, length).

        Returns:
            Tensor shaped (batch, 2).

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError("EQTransformer expects inputs shaped (batch, channels, length).")
        features = self.encoder(x)
        # (batch, dim, length) -> (batch, length, dim) for the batch-first LSTM.
        features = features.permute(0, 2, 1)
        seq_out, _ = self.temporal(features)
        # Attention-weighted pooling across the time axis.
        scores = self.attention(seq_out)
        attn = scores.softmax(dim=1)
        context = (attn * seq_out).sum(dim=1)
        return self.head(context)
# [docs] — Sphinx HTML extraction artifact
def eqtransformer_builder(
    task: str,
    in_channels: int = 3,
    hidden_dim: int = 48,
    num_layers: int = 2,
    dropout: float = 0.1,
    **kwargs,
) -> nn.Module:
    """Build an ``EQTransformer`` for the requested task.

    Args:
        task: must be ``"regression"`` (case-insensitive); anything else raises.
        in_channels: waveform channels forwarded to the model.
        hidden_dim: feature width forwarded to the model.
        num_layers: LSTM depth forwarded to the model.
        dropout: LSTM dropout forwarded to the model.
        **kwargs: accepted for builder-interface compatibility but ignored.

    Returns:
        A freshly constructed ``EQTransformer`` module.

    Raises:
        ValueError: if ``task`` is not a regression task.
    """
    del kwargs  # extra builder options are tolerated but unused
    if task.lower() != "regression":
        raise ValueError("EQTransformer only supports regression-style phase picking outputs.")
    model = EQTransformer(
        in_channels=in_channels,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        dropout=dropout,
    )
    return model
# Explicit public API of this module.
__all__ = ["EQTransformer", "eqtransformer_builder"]