# Source code for pyhazards.engine.inference
from __future__ import annotations
from typing import Any, Callable, Iterable, List
import torch
# [docs]
class SlidingWindowInference:
    """
    Run a model over windows generated from a larger input.

    This is a placeholder: the class implements no windowing or stitching
    logic itself. The caller supplies ``window_fn`` to produce the windows,
    and the per-window model outputs are returned as a plain list in the
    order the windows were produced.

    Args:
        model: Module that is called once per window.
        window_fn: Callable that receives the ``inputs`` passed to
            ``__call__`` and returns an iterable of windows. May be ``None``
            at construction time, in which case calling the instance raises
            ``NotImplementedError``.
    """

    def __init__(self, model: torch.nn.Module, window_fn: Callable[..., Iterable[Any]] | None = None):
        self.model = model
        self.window_fn = window_fn

    def __call__(self, inputs: Any) -> List[torch.Tensor]:
        """
        Apply the model to every window derived from ``inputs``.

        The model is switched to eval mode and gradient tracking is disabled
        while the windows are processed. The training/eval mode the model
        (and each of its sub-modules) had before the call is restored
        afterwards, including when the model or ``window_fn`` raises.

        Args:
            inputs: Object forwarded unchanged to ``window_fn``.

        Returns:
            One model output per window, in iteration order.

        Raises:
            NotImplementedError: If no ``window_fn`` was provided.
        """
        if self.window_fn is None:
            raise NotImplementedError("Provide a window_fn to generate windows from inputs.")

        # Snapshot per-module flags rather than only the root flag, so a model
        # with deliberately mixed modes (e.g. frozen sub-modules) is restored
        # exactly instead of being forced to a single mode.
        previous_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        try:
            with torch.no_grad():
                return [self.model(window) for window in self.window_fn(inputs)]
        finally:
            # Do not leave a training-mode model stuck in eval mode.
            for module, was_training in previous_modes:
                module.training = was_training