pygip.models.defense.RandomWM
- class pygip.models.defense.RandomWM.RandomWM(dataset, defense_ratio=0.1, wm_node=50, pr=0.2, pg=0.2, attack_name=None)
Bases: BaseDefense
A flexible defense that uses watermarking to protect graph neural networks against model extraction attacks.
This class combines the functionality of the original watermark.py:
  - Generating watermark graphs
  - Training models on the original and watermark graphs
  - Merging graphs for testing
  - Evaluating effectiveness against attacks
  - Dynamically selecting attack methods
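One step listed above, merging graphs for testing, can be pictured as a disjoint union: the watermark graph's nodes are renumbered after the original graph's nodes so that both can be handed to a model as a single graph. The sketch below illustrates this on plain edge lists. `merge_graphs` is a hypothetical helper, not part of PyGIP, and this reference does not say exactly how `RandomWM` performs the merge.

```python
from typing import List, Tuple

Edge = Tuple[int, int]

def merge_graphs(num_orig: int, orig_edges: List[Edge],
                 num_wm: int, wm_edges: List[Edge]) -> Tuple[int, List[Edge]]:
    """Return (num_nodes, edges) of the disjoint union of two graphs.

    Hypothetical helper for illustration only.
    """
    # Watermark node i becomes node num_orig + i in the merged graph,
    # so its edges never collide with the original graph's node IDs.
    shifted = [(u + num_orig, v + num_orig) for u, v in wm_edges]
    return num_orig + num_wm, orig_edges + shifted

# Original graph: 3 nodes in a path; watermark graph: 2 nodes, 1 edge.
n, edges = merge_graphs(3, [(0, 1), (1, 2)], 2, [(0, 1)])
print(n, edges)  # 5 [(0, 1), (1, 2), (3, 4)]
```

On `DGLGraph` objects, `dgl.batch` builds the same kind of disjoint union; whether `RandomWM` relies on it is not stated in this reference.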
- _abc_impl = <_abc_data object>
- _evaluate_attack_on_watermark(attack_model)
Evaluate how well the attack model performs on the watermark graph.
- Parameters:
attack_model (torch.nn.Module) – The model obtained from the attack
- Returns:
Attack model’s accuracy on the watermark graph
- Return type:
float
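The accuracy returned here can be read as ordinary classification accuracy restricted to the watermark nodes. A minimal sketch, assuming the attack model produces per-node class logits and the watermark graph carries one label per node; NumPy arrays stand in for the model output, and `watermark_accuracy` is a hypothetical name.

```python
import numpy as np

def watermark_accuracy(logits: np.ndarray, wm_labels: np.ndarray) -> float:
    """Fraction of watermark nodes whose predicted class matches the watermark label."""
    preds = logits.argmax(axis=1)  # predicted class index per watermark node
    return float((preds == wm_labels).mean())

# Toy logits for 4 watermark nodes and 2 classes (made-up values).
logits = np.array([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]])
labels = np.array([0, 1, 1, 1])
print(watermark_accuracy(logits, labels))  # 0.75
```

A high value would suggest that the extracted model has inherited the watermark behaviour of the protected model.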
- _evaluate_watermark(model)
Evaluate how effectively the watermark can be detected in the given model.
- Parameters:
model (torch.nn.Module) – The model to evaluate
- Returns:
Watermark detection accuracy
- Return type:
float
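How the detection accuracy is turned into an ownership decision is not documented here. One common framing, shown as a hypothetical sketch, compares the watermark accuracy with the chance level `1 / num_classes`, which is roughly what a model never trained on the watermark would score if the watermark labels are assigned at random. The function name and the `margin` threshold are assumptions, not PyGIP parameters.

```python
def watermark_detected(wm_acc: float, num_classes: int, margin: float = 0.3) -> bool:
    """Hypothetical rule: flag the watermark if accuracy beats chance by `margin`."""
    # Expected accuracy of a model that never saw the (randomly labelled) watermark.
    chance = 1.0 / num_classes
    return wm_acc >= chance + margin

print(watermark_detected(0.95, num_classes=2))  # True
print(watermark_detected(0.55, num_classes=2))  # False
```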
- _generate_watermark_graph()
Generate a watermark graph using the Erdős-Rényi random graph model.
- Returns:
The generated watermark graph
- Return type:
dgl.DGLGraph
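In the Erdős-Rényi model G(n, p), every possible edge among n nodes is included independently with probability p. The sketch below builds such a graph with NumPy, using `wm_node` as n. This reference does not spell out what `pr` and `pg` control; the sketch assumes `pr` is the edge probability and `pg` is the probability that a binary node feature is 1, and it draws random class labels for the watermark nodes. `generate_er_watermark`, `feat_dim`, `num_classes`, and `seed` are hypothetical.

```python
import numpy as np

def generate_er_watermark(num_nodes: int, pr: float, pg: float,
                          feat_dim: int, num_classes: int, seed: int = 0):
    """Sketch of an Erdős-Rényi watermark graph with random features and labels.

    Assumptions (not confirmed by the docs): `pr` is the edge probability and
    `pg` is the probability that each binary feature entry is 1.
    """
    rng = np.random.default_rng(seed)
    # Sample each undirected edge (i < j) independently with probability pr.
    upper = np.triu(rng.random((num_nodes, num_nodes)) < pr, k=1)
    src, dst = np.nonzero(upper)
    # Store both directions, as message-passing GNNs usually expect.
    edges = np.concatenate([np.stack([src, dst]), np.stack([dst, src])], axis=1)
    # Binary features: each entry is 1 with probability pg.
    features = (rng.random((num_nodes, feat_dim)) < pg).astype(np.float32)
    # Random watermark labels, one per node.
    labels = rng.integers(0, num_classes, size=num_nodes)
    return edges, features, labels

edges, feats, labels = generate_er_watermark(50, pr=0.2, pg=0.2, feat_dim=16, num_classes=7)
print(edges.shape[0], feats.shape, labels.shape)  # 2 (50, 16) (50,)
```

The arrays could then be wrapped as a DGL graph with `dgl.graph((edges[0], edges[1]), num_nodes=50)`; in practice `feat_dim` would have to match the feature dimension of the protected dataset.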
- _get_attack_class(attack_name)
Dynamically import and return the specified attack class.
- Parameters:
attack_name (str) – Name of the attack class to import
- Returns:
The requested attack class
- Return type:
class
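Selecting an attack by name at runtime is commonly done with Python's `importlib`: import the module that defines the attacks, then fetch the class with `getattr`. A minimal sketch, assuming the class name matches `attack_name` exactly; the module path is a parameter here because this reference does not say which PyGIP module `_get_attack_class` imports from, and `get_attack_class` is a hypothetical name.

```python
import importlib

def get_attack_class(attack_name: str, module_path: str) -> type:
    """Look up a class named `attack_name` in the module at `module_path`."""
    module = importlib.import_module(module_path)
    try:
        return getattr(module, attack_name)
    except AttributeError as exc:
        # Re-raise with a message that names the missing attack.
        raise ValueError(f"Unknown attack {attack_name!r} in {module_path}") from exc

# Demonstrated with a standard-library class; a real call would point at the
# package that defines the attacks (not given in this reference).
cls = get_attack_class("OrderedDict", "collections")
print(cls.__name__)  # OrderedDict
```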
- _test_on_watermark(model, wm_dataloader)
Compute a model’s accuracy on the watermark graph.
- Parameters:
model (torch.nn.Module) – The model to test
wm_dataloader (DataLoader) – DataLoader for the watermark graph
- Returns:
Accuracy on the watermark graph
- Return type:
float
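Evaluating through `wm_dataloader` means accuracy is accumulated batch by batch. The sketch below shows that loop with plain Python objects: a `predict` callable stands in for the `torch.nn.Module`, and each batch is assumed to be an `(inputs, labels)` pair. `accuracy_over_batches` is a hypothetical name.

```python
from typing import Callable, Iterable, Sequence, Tuple

Batch = Tuple[Sequence, Sequence[int]]

def accuracy_over_batches(predict: Callable[[Sequence], Sequence[int]],
                          wm_dataloader: Iterable[Batch]) -> float:
    """Accumulate correct predictions over all batches, then divide once."""
    correct = total = 0
    for inputs, labels in wm_dataloader:
        preds = predict(inputs)
        correct += sum(int(p == y) for p, y in zip(preds, labels))
        total += len(labels)
    return correct / total if total else 0.0

# Toy "model" that predicts class 1 for every input, and two batches.
batches = [([0.1, 0.2], [1, 0]), ([0.3], [1])]
print(accuracy_over_batches(lambda xs: [1] * len(xs), batches))  # 0.6666666666666666
```

Summing correct predictions and dividing once weights every node equally; averaging per-batch accuracies instead would give 0.75 on the toy data, because the smaller final batch would count as much as the larger one. In a PyTorch version the loop would normally run under `torch.no_grad()` with the model in `eval()` mode.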
- _train_defense_model()
Helper function for training a defense model with watermarking.
- Returns:
The trained defense model with embedded watermark
- Return type:
torch.nn.Module
- _train_target_model()
Helper function for training the target model on the original graph.
- Returns:
The trained target model
- Return type:
torch.nn.Module
- supported_api_types = {'dgl'}