class BaseAttack(ABC):
    """Abstract base class for attack models.

    This class provides a common interface for various attack strategies on
    graph-based machine learning models. It handles device management,
    dataset loading, and compatibility checks to ensure that the attack can be
    executed on the given dataset and model API type.

    Subclasses restrict compatibility by assigning their own collections to
    ``supported_api_types`` and/or ``supported_datasets``. An empty collection
    (the default) disables the corresponding check.

    Attributes:
        supported_api_types (frozenset or set): Strings naming the supported
            API types (e.g., 'pyg', 'dgl'). Empty means any API type.
        supported_datasets (frozenset or set): Class names of supported
            dataset classes. Empty means any dataset class.
        device (torch.device): The computing device (CPU or GPU) to be used
            for the attack.
        dataset (Dataset): The dataset object containing graph data and
            metadata.
        graph_dataset: The raw graph dataset from the underlying library.
        graph_data: The primary graph data structure.
        num_nodes (int): The number of nodes in the graph.
        num_features (int): The number of features per node.
        num_classes (int): The number of classes for node classification.
        attack_node_fraction (float, optional): The fraction of nodes to be
            targeted by the attack.
        model_path (str, optional): The path to a pre-trained target model.
    """

    # Immutable defaults: a mutable set here would be shared by BaseAttack and
    # every subclass, so mutating the inherited object in one subclass would
    # silently change the others. Subclasses should assign their own set.
    supported_api_types = frozenset()
    supported_datasets = frozenset()

    def __init__(
        self,
        dataset: "Dataset",
        attack_node_fraction: Optional[float] = None,
        model_path: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Initializes the BaseAttack.

        Args:
            dataset (Dataset): The dataset to be attacked. Its
                ``graph_dataset``, ``graph_data``, ``num_nodes``,
                ``num_features`` and ``num_classes`` attributes are copied
                onto the attack instance.
            attack_node_fraction (float, optional): The fraction of nodes to
                target in the attack. Defaults to None.
            model_path (str, optional): The path to a pre-trained model file.
                Defaults to None.
            device (Union[str, torch.device], optional): The device to run the
                attack on. If falsy (e.g. None), it is selected automatically
                via ``get_device()``. Defaults to None.

        Raises:
            ValueError: If the dataset is not compatible with this attack
                (see ``_check_dataset_compatibility``).
        """
        # Any falsy value (None, "") falls back to automatic device selection.
        self.device = torch.device(device) if device else get_device()
        print(f"Using device: {self.device}")

        # graph data
        self.dataset = dataset
        self.graph_dataset = dataset.graph_dataset
        self.graph_data = dataset.graph_data

        # meta data
        self.num_nodes = dataset.num_nodes
        self.num_features = dataset.num_features
        self.num_classes = dataset.num_classes

        # params
        self.attack_node_fraction = attack_node_fraction
        self.model_path = model_path

        # Runs after every attribute above is set, so an overriding check in a
        # subclass may read them.
        self._check_dataset_compatibility()

    def _check_dataset_compatibility(self):
        """Checks if the dataset is compatible with the attack.

        Each check is skipped when the corresponding class-level collection
        (``supported_api_types`` / ``supported_datasets``) is empty.

        Raises:
            ValueError: If the dataset's API type (``None`` when the dataset
                has no ``api_type`` attribute) or class name is not in the
                corresponding supported collection.
        """
        cls_name = self.dataset.__class__.__name__
        # getattr keeps the documented ValueError contract for datasets that
        # do not expose ``api_type`` at all, instead of an AttributeError.
        api_type = getattr(self.dataset, "api_type", None)

        if self.supported_api_types and api_type not in self.supported_api_types:
            raise ValueError(
                f"API type '{api_type}' is not supported. "
                f"Supported: {self.supported_api_types}"
            )

        if self.supported_datasets and cls_name not in self.supported_datasets:
            raise ValueError(
                f"Dataset '{cls_name}' is not supported. "
                f"Supported: {self.supported_datasets}"
            )

    @abstractmethod
    def attack(self):
        """Execute the attack.

        Must be implemented by every concrete subclass.
        """
        raise NotImplementedError

    def _load_model(self, model_path):
        """Load a pre-trained model.

        Args:
            model_path: Location of the model to load.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def _train_target_model(self):
        """Train the target model if not provided.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def _train_attack_model(self):
        """Train the attack model.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError