Source code for pygip.models.defense.RandomWM

import importlib

import dgl
import numpy as np
import torch
import torch.nn.functional as F
from dgl.dataloading import NeighborSampler, NodeCollator
from torch.utils.data import DataLoader
from torch_geometric.utils import erdos_renyi_graph
from tqdm import tqdm

from pygip.models.defense.base import BaseDefense
from pygip.models.nn import GraphSAGE


class RandomWM(BaseDefense):
    """
    Watermarking defense against model extraction attacks on graph neural networks.

    A GraphSAGE model is trained on the protected graph and then fine-tuned
    on a small random "watermark" graph (Erdos-Renyi edges, Bernoulli
    features, uniformly random labels). Accuracy on that watermark graph is
    then measured for the defense model and for surrogate models produced by
    an attack.

    This class combines the functionalities from the original watermark.py:

    - Generating watermark graphs
    - Training models on original and watermark graphs
    - Evaluating effectiveness against attacks
    - Dynamic selection of attack methods
    """

    supported_api_types = {"dgl"}

    # Hyper-parameters shared by the target and the defense model.
    _HIDDEN_CHANNELS = 128
    _LEARNING_RATE = 0.001
    _WEIGHT_DECAY = 5e-4
    _BATCH_SIZE = 32
    _FANOUTS = (5, 5)  # neighbours sampled per layer
    _TRAIN_EPOCHS = 50
    _WATERMARK_EPOCHS = 10

    _DEFAULT_ATTACK = "ModelExtractionAttack0"
    # Modules searched (in order) for attack classes. The package-qualified
    # name mirrors this file's own ``pygip.`` imports (presumed location of
    # the attack package -- TODO confirm); the bare name is the legacy lookup
    # kept for backward compatibility.
    _ATTACK_MODULES = ("pygip.models.attack", "models.attack")

    def __init__(self, dataset, attack_node_fraction=0.2, wm_node=50, pr=0.2, pg=0.2, attack_name=None):
        """
        Initialize the custom defense.

        Parameters
        ----------
        dataset : Dataset
            The original dataset containing the graph to defend. It must
            expose ``graph_data``, ``num_nodes``, ``num_features`` and
            ``num_classes``; ``graph_data.ndata`` must contain ``'feat'``,
            ``'label'``, ``'train_mask'`` and ``'test_mask'``.
        attack_node_fraction : float, optional
            Fraction of nodes to consider for attack (default: 0.2)
        wm_node : int, optional
            Number of nodes in the watermark graph (default: 50)
        pr : float, optional
            Probability that a watermark feature entry is 1 (default: 0.2)
        pg : float, optional
            Probability for edge creation in watermark (default: 0.2)
        attack_name : str, optional
            Name of the attack class to use
            (default: None, will use ModelExtractionAttack0)
        """
        super().__init__(dataset, attack_node_fraction)
        self.attack_name = attack_name or self._DEFAULT_ATTACK
        self.dataset = dataset
        self.graph = dataset.graph_data

        # Kept so that defend() runs the attack with the configured fraction
        # instead of a hard-coded constant.
        self._attack_node_fraction = attack_node_fraction

        # Extract dataset properties
        self.node_number = dataset.num_nodes
        self.feature_number = dataset.num_features
        self.label_number = dataset.num_classes
        self.attack_node_number = int(self.node_number * attack_node_fraction)

        # Watermark parameters
        self.wm_node = wm_node
        self.pr = pr
        self.pg = pg

        # Extract features, labels and masks
        self.features = self.graph.ndata['feat']
        self.labels = self.graph.ndata['label']
        self.train_mask = self.graph.ndata['train_mask']
        self.test_mask = self.graph.ndata['test_mask']

        # Move graph and tensors to self.device (self.device is not assigned
        # in this class; presumably set by BaseDefense.__init__).
        if self.device != 'cpu':
            self.graph = self.graph.to(self.device)
            self.features = self.features.to(self.device)
            self.labels = self.labels.to(self.device)
            self.train_mask = self.train_mask.to(self.device)
            self.test_mask = self.test_mask.to(self.device)

    # ------------------------------------------------------------------
    # Attack class lookup
    # ------------------------------------------------------------------
    @classmethod
    def _lookup_attack_class(cls, attack_name):
        """Search the candidate modules for ``attack_name``.

        Returns
        -------
        tuple
            ``(attack_class, None)`` on success, ``(None, last_error)`` when
            no candidate module provides the name.
        """
        last_error = None
        for module_name in cls._ATTACK_MODULES:
            try:
                module = importlib.import_module(module_name)
                return getattr(module, attack_name), None
            except (ImportError, AttributeError) as e:
                last_error = e
        return None, last_error

    def _get_attack_class(self, attack_name):
        """
        Dynamically import and return the specified attack class.

        Parameters
        ----------
        attack_name : str
            Name of the attack class to import

        Returns
        -------
        class
            The requested attack class, or ModelExtractionAttack0 when the
            requested one cannot be loaded.

        Raises
        ------
        ImportError
            If neither the requested class nor the fallback can be loaded.
        """
        attack_class, error = self._lookup_attack_class(attack_name)
        if attack_class is not None:
            return attack_class

        print(f"Error loading attack class '{attack_name}': {error}")
        print("Falling back to ModelExtractionAttack0")

        attack_class, error = self._lookup_attack_class(self._DEFAULT_ATTACK)
        if attack_class is None:
            raise ImportError(
                f"Could not load fallback attack class '{self._DEFAULT_ATTACK}' "
                f"from any of {self._ATTACK_MODULES}"
            ) from error
        return attack_class

    # ------------------------------------------------------------------
    # Main workflow
    # ------------------------------------------------------------------
    @staticmethod
    def _report_attack_results(results):
        """Print the known metrics of an attack result and normalise it.

        A dict is printed (``'success_rate'`` / ``'similarity'`` when
        present) and returned unchanged; anything else is replaced by
        ``{"completed": True}``.
        """
        if not isinstance(results, dict):
            print("Attack completed. Results structure varies by attack type.")
            return {"completed": True}
        if 'success_rate' in results:
            print(f"Attack success rate: {results['success_rate']:.4f}")
        if 'similarity' in results:
            print(f"Model similarity: {results['similarity']:.4f}")
        return results

    def defend(self, attack_name=None):
        """
        Main defense workflow:
        1. Train a target model on the original graph
        2. Run the attack to establish baseline vulnerability
        3. Train a defense model with watermarking
        4. Run the same attack again
        5. Print performance metrics

        Parameters
        ----------
        attack_name : str, optional
            Name of the attack class to use, overrides the one set in __init__

        Returns
        -------
        dict
            Dictionary containing performance metrics
        """
        # Use the provided attack_name or fall back to the one from __init__
        attack_name = attack_name or self.attack_name
        AttackClass = self._get_attack_class(attack_name)
        print(f"Using attack method: {attack_name}")

        # Step 1: Train target model.
        # NOTE(review): the trained model is not handed to the attack; the
        # attack object is constructed from the dataset only.
        self._train_target_model()

        # Step 2: Baseline attack
        attack = AttackClass(self.dataset, attack_node_fraction=self._attack_node_fraction)
        target_attack_results = attack.attack()
        print("Attack results on target model:")
        target_attack_results = self._report_attack_results(target_attack_results)
        # Some attacks expose their surrogate model as ``net2``.
        target_attack_model = getattr(attack, 'net2', None)

        # Step 3: Train defense model with watermarking (this also stores
        # self.watermark_graph, which the evaluations below require).
        defense_model = self._train_defense_model()

        # Step 4: Run the same attack again
        attack = AttackClass(self.dataset, attack_node_fraction=self._attack_node_fraction)
        defense_attack_results = attack.attack()
        defense_attack_model = getattr(attack, 'net2', None)

        watermark_accuracy_by_target_attack = 0
        if target_attack_model is not None:
            watermark_accuracy_by_target_attack = self._evaluate_attack_on_watermark(target_attack_model)
            print(f"Target attack model's accuracy on watermark: {watermark_accuracy_by_target_attack:.4f}")

        watermark_accuracy_by_defense_attack = 0
        if defense_attack_model is not None:
            watermark_accuracy_by_defense_attack = self._evaluate_attack_on_watermark(defense_attack_model)
            print(f"Defense attack model's accuracy on watermark: {watermark_accuracy_by_defense_attack:.4f}")

        # Step 5: Print performance metrics
        print("\nPerformance metrics:")
        print("Attack results on defense model:")
        defense_attack_results = self._report_attack_results(defense_attack_results)

        # Calculate defense effectiveness if metrics are available
        if 'success_rate' in target_attack_results and 'success_rate' in defense_attack_results:
            # The floor on the denominator avoids division by zero.
            baseline = max(target_attack_results['success_rate'], 1e-10)
            effectiveness = 1 - defense_attack_results['success_rate'] / baseline
            print(f"Defense effectiveness: {effectiveness:.4f}")

        # Evaluate watermark detection
        wm_detection = self._evaluate_watermark(defense_model)
        print(f"Watermark detection accuracy: {wm_detection:.4f}")

        return {
            "target_attack_results": target_attack_results,
            "defense_attack_results": defense_attack_results,
            "watermark_detection": wm_detection,
            "target_attack_watermark_accuracy": watermark_accuracy_by_target_attack,
            "defense_attack_watermark_accuracy": watermark_accuracy_by_defense_attack,
        }

    # ------------------------------------------------------------------
    # Shared training / evaluation helpers
    # ------------------------------------------------------------------
    def _build_model(self):
        """Create a fresh GraphSAGE model on self.device and its Adam optimizer."""
        model = GraphSAGE(in_channels=self.feature_number,
                          hidden_channels=self._HIDDEN_CHANNELS,
                          out_channels=self.label_number)
        model = model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=self._LEARNING_RATE,
                                     weight_decay=self._WEIGHT_DECAY)
        return model, optimizer

    def _new_sampler(self):
        """Return the neighbour sampler used for every data loader."""
        return NeighborSampler(list(self._FANOUTS))

    @staticmethod
    def _make_dataloader(graph, node_ids, sampler, batch_size, shuffle):
        """Wrap ``node_ids`` of ``graph`` in a DataLoader that yields sampled blocks."""
        collator = NodeCollator(graph, node_ids, sampler)
        return DataLoader(
            collator.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collator.collate,
            drop_last=False,
        )

    def _make_split_dataloaders(self, sampler):
        """Return ``(train_loader, test_loader)`` for the original graph."""
        train_nids = self.train_mask.nonzero(as_tuple=True)[0].to(self.device)
        test_nids = self.test_mask.nonzero(as_tuple=True)[0].to(self.device)
        train_loader = self._make_dataloader(self.graph, train_nids, sampler, self._BATCH_SIZE, shuffle=True)
        test_loader = self._make_dataloader(self.graph, test_nids, sampler, self._BATCH_SIZE, shuffle=False)
        return train_loader, test_loader

    def _make_watermark_dataloader(self, wm_graph, sampler, shuffle):
        """Return a loader over all nodes of ``wm_graph`` with batch size ``self.wm_node``."""
        wm_nids = torch.arange(wm_graph.number_of_nodes(), device=self.device)
        return self._make_dataloader(wm_graph, wm_nids, sampler, self.wm_node, shuffle=shuffle)

    def _train_epoch(self, model, optimizer, dataloader):
        """Run one cross-entropy training epoch and return the summed batch loss."""
        model.train()
        total_loss = 0.0
        for _, _, blocks in dataloader:
            blocks = [b.to(self.device) for b in blocks]
            input_features = blocks[0].srcdata['feat']
            output_labels = blocks[-1].dstdata['label']

            optimizer.zero_grad()
            loss = F.cross_entropy(model(blocks, input_features), output_labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss

    def _evaluate_accuracy(self, model, dataloader):
        """Return the accuracy of ``model`` over ``dataloader`` (0.0 if it yields no labels)."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for _, _, blocks in dataloader:
                blocks = [b.to(self.device) for b in blocks]
                input_features = blocks[0].srcdata['feat']
                output_labels = blocks[-1].dstdata['label']
                pred = model(blocks, input_features).argmax(dim=1)
                correct += (pred == output_labels).sum().item()
                total += len(output_labels)
        # An empty loader would otherwise raise ZeroDivisionError.
        return correct / total if total else 0.0

    # ------------------------------------------------------------------
    # Model training
    # ------------------------------------------------------------------
    def _train_target_model(self):
        """
        Helper function for training the target model on the original graph.

        The printed accuracy is the best per-epoch test accuracy; the
        returned model is the one after the final epoch (no checkpoint is
        restored).

        Returns
        -------
        torch.nn.Module
            The trained target model
        """
        print("Training target model...")

        model, optimizer = self._build_model()
        train_dataloader, test_dataloader = self._make_split_dataloaders(self._new_sampler())

        best_acc = 0
        for _ in tqdm(range(self._TRAIN_EPOCHS), desc="Target model training"):
            self._train_epoch(model, optimizer, train_dataloader)
            best_acc = max(best_acc, self._evaluate_accuracy(model, test_dataloader))

        print(f"Target model trained. Test accuracy: {best_acc:.4f}")
        return model

    def _train_defense_model(self):
        """
        Helper function for training a defense model with watermarking.

        Stage 1 trains on the original graph; stage 2 fine-tunes the same
        model (same optimizer) on a freshly generated watermark graph. The
        watermark graph is stored in ``self.watermark_graph`` for later
        verification.

        Returns
        -------
        torch.nn.Module
            The trained defense model with embedded watermark
        """
        print("Training defense model with watermarking...")

        # Generate watermark graph
        wm_graph = self._generate_watermark_graph()

        model, optimizer = self._build_model()

        sampler = self._new_sampler()
        train_dataloader, test_dataloader = self._make_split_dataloaders(sampler)
        wm_dataloader = self._make_watermark_dataloader(wm_graph, sampler, shuffle=True)

        # First stage: Train on original graph
        for _ in tqdm(range(self._TRAIN_EPOCHS), desc="Defense model - stage 1"):
            self._train_epoch(model, optimizer, train_dataloader)

        # Second stage: Fine-tune on watermark graph
        for _ in tqdm(range(self._WATERMARK_EPOCHS), desc="Defense model - stage 2"):
            self._train_epoch(model, optimizer, wm_dataloader)

        # Final evaluation on the original test split and on the watermark
        final_acc = self._evaluate_accuracy(model, test_dataloader)
        wm_acc = self._test_on_watermark(model, wm_dataloader)

        print("Defense model trained.")
        print(f"Test accuracy on original data: {final_acc:.4f}")
        print(f"Test accuracy on watermark: {wm_acc:.4f}")

        # Store watermark graph for later verification
        self.watermark_graph = wm_graph

        return model

    # ------------------------------------------------------------------
    # Watermark generation and evaluation
    # ------------------------------------------------------------------
    def _generate_watermark_graph(self):
        """
        Generate a watermark graph using Erdos-Renyi random graph model.

        Node features are i.i.d. Bernoulli(``self.pr``) values and labels are
        drawn uniformly from ``[0, self.label_number)``. Both use NumPy's
        global random state, so results are only reproducible if the caller
        seeds it.

        Returns
        -------
        dgl.DGLGraph
            The generated watermark graph (with self-loops added), on
            ``self.device``
        """
        # Generate random edges using Erdos-Renyi model
        wm_edge_index = erdos_renyi_graph(self.wm_node, self.pg, directed=False)

        # Generate random features with binomial distribution
        wm_features = torch.tensor(
            np.random.binomial(1, self.pr, size=(self.wm_node, self.feature_number)),
            dtype=torch.float32,
        ).to(self.device)

        # Generate random labels
        wm_labels = torch.tensor(
            np.random.randint(low=0, high=self.label_number, size=self.wm_node),
            dtype=torch.long,
        ).to(self.device)

        # Create DGL graph; num_nodes is passed explicitly so isolated nodes
        # are kept.
        wm_graph = dgl.graph((wm_edge_index[0], wm_edge_index[1]), num_nodes=self.wm_node)
        wm_graph = wm_graph.to(self.device)

        # Add node features and labels
        wm_graph.ndata['feat'] = wm_features
        wm_graph.ndata['label'] = wm_labels

        # Add train and test masks (all True for simplicity)
        wm_graph.ndata['train_mask'] = torch.ones(self.wm_node, dtype=torch.bool, device=self.device)
        wm_graph.ndata['test_mask'] = torch.ones(self.wm_node, dtype=torch.bool, device=self.device)

        # Add self-loops
        wm_graph = dgl.add_self_loop(wm_graph)

        return wm_graph

    def _test_on_watermark(self, model, wm_dataloader):
        """
        Test a model's accuracy on the watermark graph.

        Parameters
        ----------
        model : torch.nn.Module
            The model to test
        wm_dataloader : DataLoader
            DataLoader for the watermark graph

        Returns
        -------
        float
            Accuracy on the watermark graph (0.0 if the loader yields no labels)
        """
        return self._evaluate_accuracy(model, wm_dataloader)

    def _evaluate_watermark(self, model):
        """
        Evaluate watermark detection effectiveness.

        Parameters
        ----------
        model : torch.nn.Module
            The model to evaluate

        Returns
        -------
        float
            Watermark detection accuracy, or 0.0 when no watermark graph has
            been stored yet
        """
        if not hasattr(self, 'watermark_graph'):
            print("Warning: No watermark graph found. Generate one first.")
            return 0.0

        wm_dataloader = self._make_watermark_dataloader(
            self.watermark_graph, self._new_sampler(), shuffle=False)
        return self._test_on_watermark(model, wm_dataloader)

    def _evaluate_attack_on_watermark(self, attack_model):
        """
        Evaluate how well the attack model performs on the watermark graph.

        The calling convention is chosen from the model's class name:
        ``'GCN'`` is called as ``model(graph, features)`` on the whole
        watermark graph, ``'GraphSAGE'`` is evaluated on sampled blocks.

        Parameters
        ----------
        attack_model : torch.nn.Module
            The model obtained from the attack

        Returns
        -------
        float
            Attack model's accuracy on the watermark graph; 0.0 when no
            watermark graph is stored or the model type is unsupported
        """
        if not hasattr(self, 'watermark_graph'):
            print("Warning: No watermark graph found. Generate one first.")
            return 0.0

        model_name = attack_model.__class__.__name__

        # For GCN models that expect (g, features) input format
        if model_name == 'GCN':
            attack_model.eval()
            with torch.no_grad():
                # Pass the entire graph and features at once
                output_predictions = attack_model(self.watermark_graph,
                                                  self.watermark_graph.ndata['feat'])
                pred = output_predictions.argmax(dim=1)
                correct = (pred == self.watermark_graph.ndata['label']).sum().item()
                total = self.watermark_graph.number_of_nodes()
            return correct / total if total else 0.0

        # For GraphSAGE models that expect blocks input format
        if model_name == 'GraphSAGE':
            wm_dataloader = self._make_watermark_dataloader(
                self.watermark_graph, self._new_sampler(), shuffle=False)
            return self._evaluate_accuracy(attack_model, wm_dataloader)

        # For any other model type, print a warning and return 0
        print(f"Warning: Unsupported model type '{model_name}' for watermark evaluation")
        return 0.0