class BaseDefense(ABC):
    """Abstract base class for defenses that operate on a graph dataset.

    The constructor resolves the compute device, keeps references to the
    dataset and its graph objects, copies the dataset's size metadata onto
    the instance, and finally verifies that the dataset is one this defense
    declares support for.

    Subclasses restrict what they accept by overriding the two class-level
    sets below; an empty set means "no restriction" for that criterion.
    Subclasses must implement :meth:`defend`.
    """

    # API types accepted by this defense; empty disables the check.
    supported_api_types = set()
    # Dataset class names accepted by this defense; empty disables the check.
    supported_datasets = set()

    def __init__(
        self,
        dataset: Dataset,
        attack_node_fraction: float,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Store dataset handles and parameters, then validate the dataset.

        Args:
            dataset: Dataset object exposing ``graph_dataset``,
                ``graph_data``, ``num_nodes``, ``num_features`` and
                ``num_classes`` (and ``api_type`` when API-type checking
                is enabled).
            attack_node_fraction: Stored as-is on the instance; presumably
                the fraction of nodes available to the attacker -- confirm
                against the concrete subclasses.
            device: Device name or ``torch.device``. Any falsy value
                (including the default ``None``) defers to ``get_device()``.

        Raises:
            ValueError: If the dataset fails the compatibility check.
        """
        if device:
            self.device = torch.device(device)
        else:
            self.device = get_device()
        print(f"Using device: {self.device}")

        # Handles to the dataset and the graph objects it carries.
        self.dataset = dataset
        self.graph_dataset = dataset.graph_dataset
        self.graph_data = dataset.graph_data

        # Size metadata mirrored from the dataset.
        self.num_nodes = dataset.num_nodes
        self.num_features = dataset.num_features
        self.num_classes = dataset.num_classes

        # Defense parameters.
        self.attack_node_fraction = attack_node_fraction

        # Validate last: the check reads self.dataset, which is set above.
        self._check_dataset_compatibility()

    def _check_dataset_compatibility(self):
        """Reject datasets this defense does not declare support for.

        The API type is checked first, then the dataset's class name. Each
        check is skipped when the corresponding ``supported_*`` set is empty.

        Raises:
            ValueError: If ``self.dataset.api_type`` is not in a non-empty
                ``supported_api_types``, or the dataset's class name is not
                in a non-empty ``supported_datasets``.
        """
        if self.supported_api_types:
            if self.dataset.api_type not in self.supported_api_types:
                raise ValueError(
                    f"API type '{self.dataset.api_type}' is not supported. Supported: {self.supported_api_types}"
                )

        if self.supported_datasets:
            # Matching is by bare class name, not by isinstance.
            cls_name = self.dataset.__class__.__name__
            if cls_name not in self.supported_datasets:
                raise ValueError(
                    f"Dataset '{cls_name}' is not supported. Supported: {self.supported_datasets}"
                )

    @abstractmethod
    def defend(self):
        """Run the defense; concrete subclasses must override this.

        Raises:
            NotImplementedError: Always, if this base implementation is
                invoked directly (e.g. via ``super().defend()``).
        """
        raise NotImplementedError