Implementation
PyGIP is built to be modular and extensible, allowing contributors to implement their own attack and defense strategies. Below, we detail how to extend the framework by implementing custom attack and defense classes, with a focus on how to leverage the provided dataset structure.
Dataset
The Dataset class standardizes the data format across PyGIP. Here’s its structure:
class Dataset(object):
    def __init__(self, api_type='dgl', path='./data'):
        assert api_type in {'dgl', 'pyg'}, 'API type must be dgl or pyg'
        self.api_type = api_type
        self.path = path
        self.dataset_name = self.get_name()

        # DGLGraph or PyGData
        self.graph_dataset = None
        self.graph_data = None

        # meta data
        self.num_nodes = 0
        self.num_features = 0
        self.num_classes = 0
Important: We currently load data with api_type='pyg'. Note that the constructor shown above lists 'dgl' as the default, so pass api_type='pyg' explicitly if in doubt. When api_type='pyg', self.graph_data should be an instance of torch_geometric.data.Data. In your implementation, make sure to build your code on the Dataset class defined here.
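As a rough sketch of how a concrete dataset might fill in these fields, the snippet below uses a stand-in copy of the Dataset class so that it runs without PyGIP installed. The get_name() body, the ToyDataset subclass, and the SimpleNamespace object standing in for torch_geometric.data.Data are assumptions made for illustration, not PyGIP's actual code.

```python
from types import SimpleNamespace


class Dataset(object):
    """Stand-in mirroring the Dataset structure shown above."""

    def __init__(self, api_type='dgl', path='./data'):
        assert api_type in {'dgl', 'pyg'}, 'API type must be dgl or pyg'
        self.api_type = api_type
        self.path = path
        self.dataset_name = self.get_name()
        self.graph_dataset = None
        self.graph_data = None
        self.num_nodes = 0
        self.num_features = 0
        self.num_classes = 0

    def get_name(self):
        # Assumption: the dataset name defaults to the class name.
        return self.__class__.__name__


class ToyDataset(Dataset):
    """Hypothetical dataset: 4 nodes, 3 features per node, 2 classes."""

    def __init__(self, api_type='pyg', path='./data'):
        super().__init__(api_type, path)
        # Stand-in for torch_geometric.data.Data (x: features, y: labels).
        self.graph_data = SimpleNamespace(
            x=[[0.0, 1.0, 0.5]] * 4,
            y=[0, 1, 0, 1],
        )
        # Derive the metadata from the loaded graph data.
        self.num_nodes = len(self.graph_data.x)
        self.num_features = len(self.graph_data.x[0])
        self.num_classes = len(set(self.graph_data.y))


toy = ToyDataset()
print(toy.dataset_name, toy.num_nodes, toy.num_features, toy.num_classes)
```

Attack and defense classes read num_nodes, num_features, and num_classes from the dataset in their constructors, so a dataset should set them before it is passed in.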
Device
To ensure consistency and simplicity when managing CUDA devices across attacks and defenses, we follow the convention below:
- Both BaseAttack and BaseDefense define the device attribute self.device in their __init__() method.
- Subclasses should not manually redefine or modify the device logic.
- If you are implementing a custom attack or defense class, simply inherit from BaseAttack or BaseDefense.
- You can directly access the device using:
  x = x.to(self.device)
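A minimal sketch of this convention follows. It uses plain strings in place of torch.device objects and a stand-in get_device() so that it runs without PyTorch. Both stand-ins, and the class names BaseAttackDeviceSketch and MyAttack, are assumptions made for illustration. The point is that the subclass forwards device to the base constructor and only ever reads self.device.

```python
def get_device():
    # Stand-in: the real helper presumably prefers CUDA when it is available.
    return "cpu"


class BaseAttackDeviceSketch:
    """Stand-in for BaseAttack; only the device logic is reproduced."""

    def __init__(self, device=None):
        # Same pattern as the base classes: an explicit device wins,
        # otherwise fall back to the default from get_device().
        self.device = device if device else get_device()


class MyAttack(BaseAttackDeviceSketch):
    def __init__(self, device=None):
        # Forward the argument; do not assign self.device here.
        super().__init__(device=device)

    def where(self):
        # In real code this would be, for example, x = x.to(self.device).
        return self.device


default_attack = MyAttack()
explicit_attack = MyAttack(device="cuda:0")
print(default_attack.where(), explicit_attack.where())
```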
Implementing Attack
To create a custom attack, you need to extend the abstract base class BaseAttack. Here’s the structure of BaseAttack:
class BaseAttack(ABC):
    supported_api_types = set()
    supported_datasets = set()

    def __init__(self, dataset: Dataset, attack_node_fraction: float = None, model_path: str = None,
                 device: Optional[Union[str, torch.device]] = None):
        self.device = torch.device(device) if device else get_device()
        print(f"Using device: {self.device}")

        # graph data
        self.dataset = dataset
        self.graph_dataset = dataset.graph_dataset
        self.graph_data = dataset.graph_data

        # meta data
        self.num_nodes = dataset.num_nodes
        self.num_features = dataset.num_features
        self.num_classes = dataset.num_classes

        # params
        self.attack_node_fraction = attack_node_fraction
        self.model_path = model_path

        self._check_dataset_compatibility()
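The body of _check_dataset_compatibility() is not shown above. As a hedged sketch of what such a check could look like, the function below treats an empty set as "no restriction", which matches the advice in the example further down that supported_datasets may be left blank. The function name check_compatibility and the choice of ValueError are assumptions; PyGIP's real check may differ.

```python
def check_compatibility(api_type, dataset_name,
                        supported_api_types, supported_datasets):
    """Raise ValueError if the dataset does not match the declared support.

    An empty set means "no restriction" (an assumption of this sketch).
    """
    if supported_api_types and api_type not in supported_api_types:
        raise ValueError(
            f"API type {api_type!r} not in {sorted(supported_api_types)}")
    if supported_datasets and dataset_name not in supported_datasets:
        raise ValueError(
            f"Dataset {dataset_name!r} not in {sorted(supported_datasets)}")


# Passes: the API type matches and the empty set allows any dataset.
check_compatibility("pyg", "Cora", {"pyg"}, set())

# Fails: this hypothetical attack declares support for Cora only.
try:
    check_compatibility("pyg", "PubMed", {"pyg"}, {"Cora"})
except ValueError as err:
    print(err)
```

Running a check like this at the end of __init__() lets an unsupported combination fail at construction time rather than partway through training.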
To implement your own attack:
1. Inherit from BaseAttack: Create a new class that inherits from BaseAttack. The constructor takes the following parameters:
   - dataset: An instance of the Dataset class (see the Dataset section above).
   - attack_node_fraction: A float between 0 and 1 representing the fraction of nodes to attack.
   - model_path (optional): A string specifying the path to a pre-trained model (defaults to None).
   You need to implement the following methods:
   - attack(): Add the main attack logic here. If multiple attack types are supported, define the attack type as an optional argument to this function. For each specific attack type, implement a corresponding helper function such as _attack_type1() or _attack_type2(), and call the appropriate helper inside attack() based on the given method name.
   - _load_model(): Load the victim model.
   - _train_target_model(): Train the victim model.
   - _train_attack_model(): Train the attack model.
   - _helper_func() (optional): Add helper functions as needed, but keep them private.
2. Implement the attack() method: Override the abstract attack() method with your attack logic, and return a dict of results. For example:
class MyCustomAttack(BaseAttack):
    supported_api_types = {"pyg"}  # "pyg" or "dgl"
    supported_datasets = {"Cora"}  # you can leave this blank if your method supports all datasets

    def __init__(self, dataset: Dataset, attack_node_fraction: float, model_path: str = None):
        super().__init__(dataset, attack_node_fraction, model_path)
        # Additional initialization if needed

    def attack(self):
        # Example: Access the graph metadata and perform an attack
        print(f"Attacking {self.attack_node_fraction * 100}% of nodes")
        # BaseAttack stores the node count in self.num_nodes
        num_nodes = self.num_nodes
        print(f"Graph has {num_nodes} nodes")
        # Add your attack logic here
        return {
            'metric1': 'metric1 here',
            'metric2': 'metric2 here'
        }

    def _load_model(self):
        # add your logic here
        pass

    def _train_target_model(self):
        # add your logic here
        pass

    def _train_attack_model(self):
        # add your logic here
        pass
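When one class supports several attack types, the dispatch inside attack() could be sketched as follows. This version is self-contained: BaseAttackStub is a stand-in for BaseAttack, and the method argument, the helper names, and the returned keys are hypothetical.

```python
from abc import ABC, abstractmethod


class BaseAttackStub(ABC):
    """Stand-in for BaseAttack with only the abstract attack() hook."""

    @abstractmethod
    def attack(self, *args, **kwargs):
        raise NotImplementedError


class MultiTypeAttack(BaseAttackStub):
    def attack(self, method="type1"):
        # Map each public method name to its private helper.
        helpers = {
            "type1": self._attack_type1,
            "type2": self._attack_type2,
        }
        if method not in helpers:
            raise ValueError(f"Unknown attack method: {method!r}")
        return helpers[method]()

    def _attack_type1(self):
        # Placeholder result; real code would compute metrics here.
        return {"method": "type1"}

    def _attack_type2(self):
        # Placeholder result; real code would compute metrics here.
        return {"method": "type2"}


results = MultiTypeAttack().attack(method="type2")
print(results)
```

Keeping the lookup table in one place means that adding a new attack type only requires a new private helper and one new dictionary entry.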
Implementing Defense
To create a custom defense, you need to extend the abstract base class BaseDefense. Here’s the structure of BaseDefense:
class BaseDefense(ABC):
    supported_api_types = set()
    supported_datasets = set()

    def __init__(self, dataset: Dataset, attack_node_fraction: float,
                 device: Optional[Union[str, torch.device]] = None):
        self.device = torch.device(device) if device else get_device()
        print(f"Using device: {self.device}")

        # graph data
        self.dataset = dataset
        self.graph_dataset = dataset.graph_dataset
        self.graph_data = dataset.graph_data

        # meta data
        self.num_nodes = dataset.num_nodes
        self.num_features = dataset.num_features
        self.num_classes = dataset.num_classes

        # params
        self.attack_node_fraction = attack_node_fraction

        self._check_dataset_compatibility()
To implement your own defense:
1. Inherit from BaseDefense: Create a new class that inherits from BaseDefense. The constructor shown above takes the following parameters:
   - dataset: An instance of the Dataset class (see the Dataset section above).
   - attack_node_fraction: A float between 0 and 1 representing the fraction of nodes to attack.
   - device (optional): A string or torch.device selecting the device (defaults to None, in which case get_device() is used).
   You need to implement the following methods:
   - defend(): Add the main defense logic here. If multiple defense types are supported, define the defense type as an optional argument to this function. For each specific defense type, implement a corresponding helper function such as _defense_type1() or _defense_type2(), and call the appropriate helper inside defend().
   - _load_model(): Load the victim model.
   - _train_target_model(): Train the victim model.
   - _train_defense_model(): Train the defense model.
   - _train_surrogate_model(): Train the attack (surrogate) model.
   - _helper_func() (optional): Add helper functions as needed, but keep them private.
2. Implement the defend() method: Override the abstract defend() method with your defense logic, and return a dict of results. For example:
class MyCustomDefense(BaseDefense):
    supported_api_types = {"pyg"}  # "pyg" or "dgl"
    supported_datasets = {"Cora"}  # you can leave this blank if your method supports all datasets

    def defend(self):
        # Step 1: Train target model
        target_model = self._train_target_model()
        # Step 2: Attack target model
        # (assumes an attack whose attack() accepts the model to attack)
        attack = MyCustomAttack(self.dataset, attack_node_fraction=0.3)
        attack.attack(target_model)
        # Step 3: Train defense model
        defense_model = self._train_defense_model()
        # Step 4: Test defense against attack
        attack = MyCustomAttack(self.dataset, attack_node_fraction=0.3)
        attack.attack(defense_model)
        # Print performance metrics and return them as a dict
        return {
            'metric1': 'metric1 here',
            'metric2': 'metric2 here'
        }

    def _load_model(self):
        # add your logic here
        pass

    def _train_target_model(self):
        # add your logic here
        pass

    def _train_defense_model(self):
        # add your logic here
        pass

    def _train_surrogate_model(self):
        # add your logic here
        pass
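Steps 2 and 4 in the example above each run an attack, and a defense is typically judged by how the attack results change between the two runs. A small hypothetical helper for that comparison is sketched below. It assumes attack() returns a dict of numeric metrics, and the metric name and numbers in the usage lines are placeholders rather than PyGIP outputs.

```python
def compare_results(before, after):
    """Return the per-metric change (after - before) for shared metrics."""
    shared = before.keys() & after.keys()
    return {name: after[name] - before[name] for name in sorted(shared)}


# Placeholder numbers, for illustration only.
delta = compare_results({"fidelity": 0.75}, {"fidelity": 0.5})
print(delta)
```

Metrics that appear in only one of the two dicts are skipped, so the helper still works if the two attack runs report slightly different keys.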
Miscellaneous Tips
- Reference Implementation: The ModelExtractionAttack0 class is a fully implemented attack example. Study it for inspiration or as a template.
- Flexibility: Add as many helper functions as needed within your class to keep your code clean and modular.
- Backbone Models: We provide several basic backbone models such as GCN and GraphSAGE. You can use them, or add more, via from models.nn import GraphSAGE.
- Example Scripts: Please provide an example script in the examples/ folder demonstrating how to run your code. This will significantly speed up our code review process.
By following these guidelines, you can seamlessly integrate your custom attack or defense strategies into PyGIP. Happy coding!