Source code for pyhazards.engine.runner
from __future__ import annotations
from typing import Optional, Union
import torch.nn as nn
from ..benchmarks import Benchmark, BenchmarkRunSummary, run_benchmark
from ..configs import ExperimentConfig
from ..datasets import load_dataset
from ..datasets.base import DataBundle
from ..models import build_model
[docs]
class BenchmarkRunner:
    """High-level runner that resolves datasets/models and executes a benchmark.

    Args:
        benchmark: Optional benchmark to run, given either as a name or as a
            ``Benchmark`` instance. When it is ``None`` (or an empty string),
            :meth:`run` falls back to ``experiment.benchmark.name``.
    """

    def __init__(self, benchmark: Optional[Union[str, Benchmark]] = None):
        self.benchmark = benchmark

    def run(
        self,
        experiment: ExperimentConfig,
        model: Optional[nn.Module] = None,
        data: Optional[DataBundle] = None,
        output_dir: Optional[str] = None,
    ) -> BenchmarkRunSummary:
        """Execute the benchmark for ``experiment`` and return its summary.

        Args:
            experiment: Experiment configuration. It is forwarded to
                ``run_benchmark`` as ``config`` and is also used to build the
                model, load the data, and pick the benchmark name whenever
                those are not supplied explicitly.
            model: Model to evaluate. When ``None``, a model is built from
                ``experiment.model``.
            data: Data bundle to evaluate on. When ``None``, it is loaded
                from ``experiment.dataset``.
            output_dir: Forwarded unchanged to ``run_benchmark``.

        Returns:
            Whatever ``run_benchmark`` returns for the resolved arguments.
        """
        # Test against None explicitly instead of relying on truthiness:
        # objects that define ``__len__`` (e.g. an empty ``nn.Sequential``)
        # are falsy, and a caller-supplied object must never be silently
        # swapped for one rebuilt from the config.
        built_model = model if model is not None else self._build_model(experiment)
        bundle = data if data is not None else self._load_data(experiment)

        benchmark = self.benchmark
        # An empty benchmark name keeps falling back to the experiment's
        # benchmark; a ``Benchmark`` instance is always honoured as given.
        if benchmark is None or (isinstance(benchmark, str) and not benchmark):
            benchmark = experiment.benchmark.name

        return run_benchmark(
            benchmark=benchmark,
            model=built_model,
            data=bundle,
            config=experiment,
            output_dir=output_dir,
        )

    def _build_model(self, experiment: ExperimentConfig) -> nn.Module:
        """Build a model from the ``name``, ``task`` and ``params`` of ``experiment.model``."""
        return build_model(
            name=experiment.model.name,
            task=experiment.model.task,
            **experiment.model.params,
        )

    def _load_data(self, experiment: ExperimentConfig) -> DataBundle:
        """Resolve the dataset named in ``experiment.dataset`` and return its loaded bundle."""
        return load_dataset(
            experiment.dataset.name,
            **experiment.dataset.params,
        ).load()