Source code for pyhazards.utils.hardware
from __future__ import annotations
import os
from typing import Optional
import torch
# Initial default device string: a non-empty PYHAZARDS_DEVICE environment
# variable takes precedence; otherwise use the first CUDA device when CUDA is
# available, falling back to the CPU. CUDA is only probed when the variable
# is unset or empty.
_DEFAULT_DEVICE_STR = os.environ.get("PYHAZARDS_DEVICE", "")
if not _DEFAULT_DEVICE_STR:
    _DEFAULT_DEVICE_STR = "cuda:0" if torch.cuda.is_available() else "cpu"
# Module-wide default device; may later be replaced by set_device().
_default_device = torch.device(_DEFAULT_DEVICE_STR)
[docs]
def auto_device(prefer: str | None = None) -> torch.device:
"""
Choose a device automatically. Respects PYHAZARDS_DEVICE and prefer flag.
"""
if prefer:
return torch.device(prefer)
return _default_device
[docs]
def num_devices() -> int:
    """
    Return the number of CUDA devices torch reports, or 0 without CUDA.

    The count comes from ``torch.cuda.device_count()`` and is only queried
    when ``torch.cuda.is_available()`` is true.
    """
    return torch.cuda.device_count() if torch.cuda.is_available() else 0
[docs]
def get_device() -> torch.device:
    """
    Return the current module-wide default device.

    This is the device chosen at import time (PYHAZARDS_DEVICE, else
    ``cuda:0`` when CUDA is available, else ``cpu``) unless it has since been
    replaced via ``set_device``.
    """
    current = _default_device
    return current
[docs]
def set_device(device_str: str | torch.device) -> None:
global _default_device
_default_device = torch.device(device_str)