pygip.models.defense.RandomWM¶
- class pygip.models.defense.RandomWM.RandomWM(dataset, attack_node_fraction=0.2, wm_node=50, pr=0.2, pg=0.2, attack_name=None)[source]¶
Bases:
BaseDefense
A flexible defense implementation using watermarking to protect against model extraction attacks on graph neural networks.
This class combines the functionality of the original watermark.py:
- Generating watermark graphs
- Training models on original and watermark graphs
- Merging graphs for testing
- Evaluating effectiveness against attacks
- Dynamic selection of attack methods
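The watermark defense idea can be illustrated with a self-contained toy, with plain Python lookup tables standing in for graph neural networks (this is an illustrative sketch, not pygip code):

```python
import random

random.seed(0)

# Original task: label an integer by its parity.
original = [(x, x % 2) for x in range(100)]

# Watermark: unseen inputs paired with random labels, analogous to the
# wm_node randomly generated watermark nodes in the graph setting.
watermark = [(1000 + i, random.randint(0, 1)) for i in range(20)]

def train(pairs):
    """'Train' a model by memorizing (input, label) pairs; unseen
    inputs fall back to the majority training label."""
    table = dict(pairs)
    labels = [y for _, y in pairs]
    default = max(set(labels), key=labels.count)
    return lambda x: table.get(x, default)

def accuracy(model, pairs):
    return sum(model(x) == y for x, y in pairs) / len(pairs)

defense_model = train(original + watermark)   # watermark embedded

# Extraction attack: a surrogate trained only on the defended model's
# answers to ordinary inputs, as a model thief would obtain them.
attack_model = train([(x, defense_model(x)) for x, _ in original])

print(accuracy(defense_model, watermark))  # 1.0: watermark fully detected
print(accuracy(attack_model, watermark))   # likely near chance: theft exposed
```

Because the watermark inputs never appear in ordinary queries, a stolen model cannot learn their arbitrary labels, so watermark accuracy separates the defended model from an extracted copy.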
- _evaluate_attack_on_watermark(attack_model)[source]¶
Evaluate how well the attack model performs on the watermark graph.
- Parameters:
attack_model (torch.nn.Module) – The model obtained from the attack
- Returns:
Attack model’s accuracy on the watermark graph
- Return type:
float
- _evaluate_watermark(model)[source]¶
Evaluate watermark detection effectiveness.
- Parameters:
model (torch.nn.Module) – The model to evaluate
- Returns:
Watermark detection accuracy
- Return type:
float
- _generate_watermark_graph()[source]¶
Generate a watermark graph using the Erdős–Rényi random graph model.
- Returns:
The generated watermark graph
- Return type:
dgl.DGLGraph
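A minimal sketch of the G(n, p) construction this method is based on, using only the standard library (an edge list stands in for the `dgl.DGLGraph`; treating `pr` as the edge probability is an assumption, since the docs do not specify it):

```python
import random
from itertools import combinations

def erdos_renyi(n, p, seed=None):
    """Return the edge list of a G(n, p) random graph: each of the
    n*(n-1)/2 possible undirected edges is included independently
    with probability p."""
    rng = random.Random(seed)
    return [(u, v) for u, v in combinations(range(n), 2) if rng.random() < p]

# Roughly the shape of the watermark graph with the class defaults
# (wm_node=50, pr=0.2): around 0.2 * 50*49/2 = 245 edges on average.
edges = erdos_renyi(50, 0.2, seed=0)
```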
- _get_attack_class(attack_name)[source]¶
Dynamically import and return the specified attack class.
- Parameters:
attack_name (str) – Name of the attack class to import
- Returns:
The requested attack class
- Return type:
class
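Dynamic class lookup of this kind is typically done with `importlib`; here is a generic sketch, demonstrated on a standard-library class since the attack module path is internal to pygip (the `"pygip.models.attack"` path mentioned in the comment is an assumption):

```python
import importlib

def get_class(module_path, class_name):
    """Dynamically import `module_path` and return its `class_name`
    attribute, mirroring what an attack-class resolver might do."""
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# Demonstrated on a standard-library class; pygip's own attack classes
# would live under a module path such as "pygip.models.attack" (assumed).
Counter = get_class("collections", "Counter")
print(Counter("aab"))  # Counter({'a': 2, 'b': 1})
```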
- _test_on_watermark(model, wm_dataloader)[source]¶
Test a model’s accuracy on the watermark graph.
- Parameters:
model (torch.nn.Module) – The model to test
wm_dataloader (DataLoader) – DataLoader for the watermark graph
- Returns:
Accuracy on the watermark graph
- Return type:
float
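Batched accuracy over a dataloader can be sketched as follows, with plain Python lists standing in for tensors and a list of batches standing in for the `DataLoader` (hypothetical shapes; the real method operates on DGL graphs):

```python
def argmax(scores):
    """Index of the largest score, i.e. the predicted class."""
    return max(range(len(scores)), key=scores.__getitem__)

def batched_accuracy(model, dataloader):
    correct = total = 0
    for inputs, labels in dataloader:          # one (inputs, labels) batch
        for x, y in zip(inputs, labels):
            correct += argmax(model(x)) == y   # predicted class vs. label
            total += 1
    return correct / total

# Toy model: two-class "logits" that favor class x % 2.
model = lambda x: [1.0, 0.0] if x % 2 == 0 else [0.0, 1.0]
loader = [([0, 1, 2, 3], [0, 1, 0, 1]), ([4, 5], [0, 0])]
print(batched_accuracy(model, loader))  # 5 correct of 6 -> 0.8333...
```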
- _train_defense_model()[source]¶
Helper function for training a defense model with watermarking.
- Returns:
The trained defense model with embedded watermark
- Return type:
torch.nn.Module
- _train_target_model()[source]¶
Helper function for training the target model on the original graph.
- Returns:
The trained target model
- Return type:
torch.nn.Module
- defend(attack_name=None)[source]¶
Main defense workflow:
1. Train a target model on the original graph
2. Attack the target model to establish baseline vulnerability
3. Train a defense model with watermarking
4. Test the defense model against the same attack
5. Print performance metrics
- Parameters:
attack_name (str, optional) – Name of the attack class to use, overrides the one set in __init__
- Returns:
Dictionary containing performance metrics
- Return type:
dict
- supported_api_types = {'dgl'}¶