# Source code for pygip.models.defense.RandomWM

import importlib
import time

import dgl
import torch
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch_geometric.utils import erdos_renyi_graph
from dgl.dataloading import NeighborSampler, NodeCollator

from pygip.models.nn import GraphSAGE
from pygip.models.defense.base import BaseDefense
from pygip.utils.metrics import DefenseCompMetric, DefenseMetric


class RandomWM(BaseDefense):
    """
    Random-watermark defense for graph neural networks.

    The defense generates a random Erdos-Renyi "watermark" graph with
    Bernoulli node features and uniformly random labels, trains a GraphSAGE
    model on the original graph, and then fine-tunes it on the watermark
    graph. Ownership is later checked by measuring how well a model
    reproduces the random watermark labels.

    The class provides:

    - watermark graph generation (``_generate_watermark_graph``)
    - training with and without watermark (``_train_defense_model``,
      ``_train_target_model``)
    - downstream-task and watermark evaluation (``evaluate_model``,
      ``verify_watermark``, ``_evaluate_watermark``,
      ``_evaluate_attack_on_watermark``)
    - dynamic lookup of an attack class by name (``_get_attack_class``)
    """

    supported_api_types = {"dgl"}

    # Settings shared by every training / evaluation routine of this class.
    _FANOUTS = (5, 5)            # neighbours sampled per layer
    _BATCH_SIZE = 32             # mini-batch size on the original graph
    _HIDDEN_CHANNELS = 128       # GraphSAGE hidden width
    _FALLBACK_ATTACK = "ModelExtractionAttack0"

    def __init__(self, dataset, defense_ratio=0.1, wm_node=50, pr=0.2, pg=0.2,
                 attack_name=None):
        """
        Initialize the custom defense.

        Parameters
        ----------
        dataset : object
            Dataset wrapper; this constructor reads ``graph_data``,
            ``num_nodes``, ``num_features`` and ``num_classes`` from it, and
            ``graph_data.ndata`` must hold ``'feat'``, ``'label'``,
            ``'train_mask'`` and ``'test_mask'``.
        defense_ratio : float
            Defense strength (0-1): used to determine the number of watermark
            nodes (when ``wm_node`` is None) and the attack node sampling
            scale.
        wm_node : Optional[int]
            If specified, a fixed number of watermark nodes is used;
            otherwise ``max(10, int(num_nodes * defense_ratio))`` is used.
        pr : float
            Bernoulli probability of a watermark feature being 1.
        pg : float
            Edge probability of the watermark graph.
        attack_name : Optional[str]
            Attack class name (from models.attack). Defaults to
            ``"ModelExtractionAttack0"``.
        """
        super().__init__(dataset, defense_ratio)
        self.defense_ratio = defense_ratio
        self.attack_name = attack_name or self._FALLBACK_ATTACK
        self.dataset = dataset
        self.graph = dataset.graph_data

        # Extract dataset properties
        self.node_number = dataset.num_nodes
        self.feature_number = dataset.num_features
        self.label_number = dataset.num_classes
        self.attack_node_number = int(self.node_number * defense_ratio)

        # Watermark parameters
        if wm_node is not None:
            self.wm_node = int(wm_node)
        else:
            self.wm_node = max(10, int(dataset.num_nodes * defense_ratio))
        self.pr = pr
        self.pg = pg

        # Extract features, labels and masks
        self.features = self.graph.ndata['feat']
        self.labels = self.graph.ndata['label']
        self.train_mask = self.graph.ndata['train_mask']
        self.test_mask = self.graph.ndata['test_mask']

        # Move tensors to self.device.
        # NOTE(review): self.device is not assigned in this class; presumably
        # BaseDefense.__init__ sets it - confirm.
        if self.device != 'cpu':
            self.graph = self.graph.to(self.device)
            self.features = self.features.to(self.device)
            self.labels = self.labels.to(self.device)
            self.train_mask = self.train_mask.to(self.device)
            self.test_mask = self.test_mask.to(self.device)

    # ------------------------------------------------------------------
    # Private helpers shared by the training / evaluation routines
    # ------------------------------------------------------------------
    def _build_node_loader(self, graph, node_ids, batch_size, shuffle):
        """Return a DataLoader yielding ``(input_nodes, output_nodes, blocks)`` for ``node_ids``."""
        sampler = NeighborSampler(list(self._FANOUTS))
        collator = NodeCollator(graph, node_ids, sampler)
        return DataLoader(
            collator.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collator.collate,
            drop_last=False,
        )

    def _build_split_loader(self, mask, shuffle):
        """Return a loader over the nodes of ``self.graph`` selected by ``mask``."""
        node_ids = mask.nonzero(as_tuple=True)[0].to(self.device)
        return self._build_node_loader(self.graph, node_ids, self._BATCH_SIZE, shuffle)

    def _build_watermark_loader(self, wm_graph, shuffle):
        """Return a loader over every node of ``wm_graph`` (one batch of ``self.wm_node`` nodes)."""
        wm_nids = torch.arange(wm_graph.number_of_nodes(), device=self.device)
        return self._build_node_loader(wm_graph, wm_nids, self.wm_node, shuffle)

    def _init_model_and_optimizer(self):
        """Create a fresh GraphSAGE model on ``self.device`` and its Adam optimizer."""
        model = GraphSAGE(in_channels=self.feature_number,
                          hidden_channels=self._HIDDEN_CHANNELS,
                          out_channels=self.label_number)
        model = model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
        return model, optimizer

    def _fit_one_epoch(self, model, optimizer, loader):
        """Run one cross-entropy training pass of ``model`` over ``loader``."""
        model.train()
        for _, _, blocks in loader:
            blocks = [b.to(self.device) for b in blocks]
            input_features = blocks[0].srcdata['feat']
            output_labels = blocks[-1].dstdata['label']

            optimizer.zero_grad()
            loss = F.cross_entropy(model(blocks, input_features), output_labels)
            loss.backward()
            optimizer.step()

    def _loader_accuracy(self, model, loader):
        """Return the fraction of correctly classified output nodes (0.0 if the loader is empty)."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for _, _, blocks in loader:
                blocks = [b.to(self.device) for b in blocks]
                input_features = blocks[0].srcdata['feat']
                output_labels = blocks[-1].dstdata['label']
                pred = model(blocks, input_features).argmax(dim=1)
                correct += (pred == output_labels).sum().item()
                total += len(output_labels)
        # Guard against an empty node set instead of raising ZeroDivisionError.
        return correct / total if total else 0.0

    def _predict_nodes(self, model, loader):
        """Return the argmax predictions for every output node of ``loader`` as a CPU tensor."""
        model.eval()
        batch_preds = []
        with torch.no_grad():
            for _, _, blocks in loader:
                blocks = [b.to(self.device) for b in blocks]
                input_features = blocks[0].srcdata['feat']
                batch_preds.append(model(blocks, input_features).argmax(dim=1))
        if not batch_preds:
            # torch.cat rejects an empty list; report "no predictions" instead.
            return torch.empty(0, dtype=torch.long)
        return torch.cat(batch_preds, dim=0).cpu()

    # ------------------------------------------------------------------
    # Attack lookup
    # ------------------------------------------------------------------
    def _get_attack_class(self, attack_name):
        """
        Dynamically import and return the specified attack class.

        The attack module is looked up first as ``pygip.models.attack`` (the
        packaged layout this file's own imports use) and then as the bare
        ``models.attack`` path.

        Parameters
        ----------
        attack_name : str
            Name of the attack class to import

        Returns
        -------
        class
            The requested attack class, or ``ModelExtractionAttack0`` when
            ``attack_name`` is not defined in the attack module.

        Raises
        ------
        ImportError
            If the attack module cannot be imported under either path.
        AttributeError
            If neither ``attack_name`` nor the fallback class exists.
        """
        attack_module = None
        import_error = None
        for module_path in ("pygip.models.attack", "models.attack"):
            try:
                attack_module = importlib.import_module(module_path)
                break
            except ImportError as exc:
                import_error = exc

        if attack_module is None:
            # Re-importing the same module for the fallback class could only
            # fail again, so report the problem once and clearly.
            raise ImportError(
                f"Error loading attack class '{attack_name}': {import_error}"
            ) from import_error

        try:
            return getattr(attack_module, attack_name)
        except AttributeError as exc:
            print(f"Error loading attack class '{attack_name}': {exc}")
            print("Falling back to ModelExtractionAttack0")
            return getattr(attack_module, self._FALLBACK_ATTACK)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def defend(self, attack_name=None):
        """
        Execute the random watermark defense.

        Trains the watermarked model if that has not happened yet, evaluates
        it on the test split and on the watermark graph, and collects the
        metrics.

        Parameters
        ----------
        attack_name : Optional[str]
            Currently not used by this method.

        Returns
        -------
        tuple
            ``(res, res_comp)``: the results of ``DefenseMetric.compute()``
            and ``DefenseCompMetric.compute()``.
        """
        metric_comp = DefenseCompMetric()
        metric_comp.start()
        print("====================Random Watermark Defense====================")

        # If model wasn't trained yet, train it
        if not hasattr(self, 'model_trained'):
            self.train_target_model(metric_comp)

        # Evaluate the defended model
        preds = self.evaluate_model(self.defense_model, self.dataset)

        # Only the watermark verification is timed as defense inference.
        inference_s = time.time()
        wm_preds = self.verify_watermark(self.defense_model)
        inference_e = time.time()

        # metric
        metric = DefenseMetric()
        metric.update(preds, self.labels[self.test_mask])
        wm_labels = self.watermark_graph.ndata['label']
        metric.update_wm(wm_preds, wm_labels)

        metric_comp.end()

        print("====================Final Results====================")
        res = metric.compute()
        metric_comp.update(inference_defense_time=(inference_e - inference_s))
        res_comp = metric_comp.compute()

        return res, res_comp

    def train_target_model(self, metric_comp: DefenseCompMetric):
        """
        Train the target model with watermark injection.

        Parameters
        ----------
        metric_comp : DefenseCompMetric
            Receives the wall-clock training time as ``defense_time``.

        Returns
        -------
        torch.nn.Module
            The trained defense model (also stored as ``self.defense_model``).
        """
        defense_s = time.time()

        # Training and watermark generation (defense mechanism)
        self.defense_model = self._train_defense_model()
        self.model_trained = True

        defense_e = time.time()
        metric_comp.update(defense_time=(defense_e - defense_s))
        return self.defense_model

    def evaluate_model(self, model, dataset):
        """
        Evaluate model performance on downstream task.

        Parameters
        ----------
        model : torch.nn.Module
            The model to evaluate
        dataset : object
            Not read by this method; predictions are computed on the test
            nodes of ``self.graph``.

        Returns
        -------
        torch.Tensor
            Predicted class per test node, on CPU.
        """
        model.eval()
        test_loader = self._build_split_loader(self.test_mask, shuffle=False)
        return self._predict_nodes(model, test_loader)

    def verify_watermark(self, model):
        """
        Verify watermark success rate.

        Parameters
        ----------
        model : torch.nn.Module
            The model to check

        Returns
        -------
        torch.Tensor
            Predicted class per watermark-graph node, on CPU.
        """
        model.eval()
        wm_loader = self._build_watermark_loader(self.watermark_graph, shuffle=False)
        return self._predict_nodes(model, wm_loader)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def _train_target_model(self):
        """
        Helper function for training the target model on the original graph.

        Returns
        -------
        torch.nn.Module
            The trained target model
        """
        print("Training target model...")

        model, optimizer = self._init_model_and_optimizer()
        train_loader = self._build_split_loader(self.train_mask, shuffle=True)
        test_loader = self._build_split_loader(self.test_mask, shuffle=False)

        best_acc = 0
        for _ in tqdm(range(1, 51), desc="Target model training"):
            self._fit_one_epoch(model, optimizer, train_loader)
            best_acc = max(best_acc, self._loader_accuracy(model, test_loader))

        print(f"Target model trained. Test accuracy: {best_acc:.4f}")
        return model

    def _train_defense_model(self):
        """
        Helper function for training a defense model with watermarking.

        Stage 1 trains on the original graph for 50 epochs; stage 2
        fine-tunes on the watermark graph for 10 epochs. The watermark graph
        is stored as ``self.watermark_graph`` for later verification.

        Returns
        -------
        torch.nn.Module
            The trained defense model with embedded watermark
        """
        print("Training defense model with watermarking...")

        # Generate the watermark graph before the model so the order of
        # random draws matches the original implementation.
        wm_graph = self._generate_watermark_graph()

        model, optimizer = self._init_model_and_optimizer()
        train_loader = self._build_split_loader(self.train_mask, shuffle=True)
        test_loader = self._build_split_loader(self.test_mask, shuffle=False)
        wm_loader = self._build_watermark_loader(wm_graph, shuffle=True)

        # First stage: Train on original graph
        for _ in tqdm(range(1, 51), desc="Defense model - stage 1"):
            self._fit_one_epoch(model, optimizer, train_loader)
            # Per-epoch evaluation is kept from the original implementation;
            # its result is not used for model selection.
            self._loader_accuracy(model, test_loader)

        # Second stage: Fine-tune on watermark graph
        for _ in tqdm(range(1, 11), desc="Defense model - stage 2"):
            self._fit_one_epoch(model, optimizer, wm_loader)

        # Final evaluation
        final_acc = self._loader_accuracy(model, test_loader)
        wm_acc = self._test_on_watermark(model, wm_loader)

        print("Defense model trained.")
        print(f"Test accuracy on original data: {final_acc:.4f}")
        print(f"Test accuracy on watermark: {wm_acc:.4f}")

        # Store watermark graph for later verification
        self.watermark_graph = wm_graph

        return model

    def _generate_watermark_graph(self):
        """
        Generate a watermark graph using Erdos-Renyi random graph model.

        Node features are drawn from Bernoulli(``self.pr``), labels uniformly
        from ``[0, self.label_number)``, and edges with probability
        ``self.pg``. Self-loops are added to every node.

        Returns
        -------
        dgl.DGLGraph
            The generated watermark graph
        """
        # Generate random edges using Erdos-Renyi model
        wm_edge_index = erdos_renyi_graph(self.wm_node, self.pg, directed=False)

        # Generate random features with binomial distribution
        wm_features = torch.tensor(
            np.random.binomial(1, self.pr, size=(self.wm_node, self.feature_number)),
            dtype=torch.float32,
        ).to(self.device)

        # Generate random labels
        wm_labels = torch.tensor(
            np.random.randint(low=0, high=self.label_number, size=self.wm_node),
            dtype=torch.long,
        ).to(self.device)

        # Create DGL graph
        wm_graph = dgl.graph((wm_edge_index[0], wm_edge_index[1]), num_nodes=self.wm_node)
        wm_graph = wm_graph.to(self.device)

        # Add node features and labels
        wm_graph.ndata['feat'] = wm_features
        wm_graph.ndata['label'] = wm_labels

        # Add train and test masks (all True for simplicity)
        wm_graph.ndata['train_mask'] = torch.ones(self.wm_node, dtype=torch.bool, device=self.device)
        wm_graph.ndata['test_mask'] = torch.ones(self.wm_node, dtype=torch.bool, device=self.device)

        # Add self-loops
        wm_graph = dgl.add_self_loop(wm_graph)

        return wm_graph

    # ------------------------------------------------------------------
    # Watermark evaluation
    # ------------------------------------------------------------------
    def _test_on_watermark(self, model, wm_dataloader):
        """
        Test a model's accuracy on the watermark graph.

        Parameters
        ----------
        model : torch.nn.Module
            The model to test
        wm_dataloader : DataLoader
            DataLoader for the watermark graph

        Returns
        -------
        float
            Accuracy on the watermark graph (0.0 if the loader yields no
            nodes)
        """
        return self._loader_accuracy(model, wm_dataloader)

    def _evaluate_watermark(self, model):
        """
        Evaluate watermark detection effectiveness.

        Parameters
        ----------
        model : torch.nn.Module
            The model to evaluate

        Returns
        -------
        float
            Watermark detection accuracy, or 0.0 when no watermark graph has
            been generated yet
        """
        if not hasattr(self, 'watermark_graph'):
            print("Warning: No watermark graph found. Generate one first.")
            return 0.0

        wm_loader = self._build_watermark_loader(self.watermark_graph, shuffle=False)
        return self._test_on_watermark(model, wm_loader)

    def _evaluate_attack_on_watermark(self, attack_model):
        """
        Evaluate how well the attack model performs on the watermark graph.

        The evaluation path is chosen from the model's class name: ``'GCN'``
        models are called once with the full graph and its features,
        ``'GraphSAGE'`` models are evaluated through sampled blocks.

        Parameters
        ----------
        attack_model : torch.nn.Module
            The model obtained from the attack

        Returns
        -------
        float
            Attack model's accuracy on the watermark graph; 0.0 when no
            watermark graph exists or the model type is unsupported
        """
        if not hasattr(self, 'watermark_graph'):
            print("Warning: No watermark graph found. Generate one first.")
            return 0.0

        # Check the model type to determine the correct evaluation approach
        model_name = attack_model.__class__.__name__

        # For GCN models that expect (g, features) input format
        if model_name == 'GCN':
            attack_model.eval()
            with torch.no_grad():
                # Pass the entire graph and features at once
                output_predictions = attack_model(self.watermark_graph,
                                                  self.watermark_graph.ndata['feat'])
                pred = output_predictions.argmax(dim=1)
                correct = (pred == self.watermark_graph.ndata['label']).sum().item()
                total = self.watermark_graph.number_of_nodes()
            return correct / total if total else 0.0

        # For GraphSAGE models that expect blocks input format
        if model_name == 'GraphSAGE':
            wm_loader = self._build_watermark_loader(self.watermark_graph, shuffle=False)
            return self._loader_accuracy(attack_model, wm_loader)

        # For any other model type, print a warning and return 0
        print(f"Warning: Unsupported model type '{model_name}' for watermark evaluation")
        return 0.0