"""Imperceptible watermarking defense (ImperceptibleWM2) for pygip.models.defense."""

import os
import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import numpy as np
from tqdm import tqdm
from dgl.dataloading import NeighborSampler, NodeCollator
from torch.utils.data import DataLoader
from dgl.nn import GraphConv
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

from pygip.models.defense.base import BaseDefense
from pygip.models.nn import GraphSAGE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class TriggerGenerator(nn.Module):
    """Generate watermark trigger features and edge probabilities with a GCN.

    A small template graph is passed through three ``GraphConv`` layers to
    produce trigger node features; a sigmoid MLP (``edge_generator``) then
    scores every unordered node pair to decide trigger-graph edges.

    Parameters
    ----------
    feature_dim : int
        Dimension of node feature vectors.
    hidden_dim : int, optional
        Hidden width of the GCN layers and the edge MLP. Default is 64.
    output_nodes : int, optional
        Number of nodes in the generated trigger graph. Default is 50.
    """

    def __init__(self, feature_dim, hidden_dim=64, output_nodes=50):
        super(TriggerGenerator, self).__init__()
        self.feature_dim = feature_dim
        self.output_nodes = output_nodes
        self.gcn1 = GraphConv(feature_dim, hidden_dim)
        self.gcn2 = GraphConv(hidden_dim, hidden_dim)
        self.gcn3 = GraphConv(hidden_dim, feature_dim)
        # Scores a concatenated feature pair -> edge probability in (0, 1).
        self.edge_generator = nn.Sequential(
            nn.Linear(feature_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # Create a small template graph for GCN processing
        self.template_graph = self._create_template_graph()

    def _create_template_graph(self):
        """Build the fixed template graph used as GCN input.

        Returns
        -------
        dgl.DGLGraph
            A fully connected graph (with self-loops) over at most 10 nodes,
            moved to the module-level ``device``.
        """
        # Fully connect up to 10 nodes (both directions per pair).
        edges = []
        for i in range(min(10, self.output_nodes)):
            for j in range(i + 1, min(10, self.output_nodes)):
                edges.append((i, j))
                edges.append((j, i))
        if not edges:
            # Degenerate case (output_nodes <= 1): fall back to a single edge.
            edges = [(0, 1), (1, 0)]
        src, dst = zip(*edges) if edges else ([0], [1])
        g = dgl.graph((src, dst), num_nodes=min(10, self.output_nodes))
        g = dgl.add_self_loop(g)
        return g.to(device)

    def forward(self, clean_features, selected_nodes):
        """Generate trigger node features and pairwise edge probabilities.

        Parameters
        ----------
        clean_features : torch.Tensor
            Clean-graph feature matrix, shape [num_nodes, feature_dim].
        selected_nodes : list[int] or torch.Tensor
            Node indices used to build the prototype feature vector.

        Returns
        -------
        trigger_features : torch.Tensor
            Generated features, shape [output_nodes, feature_dim].
        edge_probs : torch.Tensor
            Probabilities for each unordered pair (i < j), in row-major
            upper-triangular order, shape [output_nodes*(output_nodes-1)/2, 1].
        """
        # Prototype = mean feature of (up to 10) selected nodes, or of all nodes.
        if len(selected_nodes) > 0:
            sample_features = clean_features[selected_nodes[:min(len(selected_nodes), 10)]]
            prototype = sample_features.mean(dim=0, keepdim=True)
        else:
            prototype = clean_features.mean(dim=0, keepdim=True)

        # Replicate prototype for template graph nodes
        template_size = self.template_graph.num_nodes()
        template_features = prototype.repeat(template_size, 1)

        # Apply GCN layers; final sigmoid keeps features in (0, 1) so the
        # regulation (owner-ID BCE) loss is well defined.
        h = F.relu(self.gcn1(self.template_graph, template_features))
        h = F.relu(self.gcn2(self.template_graph, h))
        h = torch.sigmoid(self.gcn3(self.template_graph, h))

        # Expand to the desired number of trigger nodes.
        if template_size < self.output_nodes:
            # Replicate the last template node and perturb with small noise.
            additional_nodes = self.output_nodes - template_size
            noise = torch.randn(additional_nodes, self.feature_dim, device=device) * 0.1
            additional_features = h[-1:].repeat(additional_nodes, 1) + noise
            trigger_features = torch.cat([h, additional_features], dim=0)
        else:
            trigger_features = h[:self.output_nodes]

        # Score all unordered pairs in one batched MLP pass instead of one
        # forward call per pair (the MLP acts row-wise, so the result is
        # identical to the per-pair loop but O(n^2) faster in Python terms).
        n_nodes = self.output_nodes
        if n_nodes > 1:
            idx_i, idx_j = torch.triu_indices(
                n_nodes, n_nodes, offset=1, device=trigger_features.device)
            pair_features = torch.cat(
                [trigger_features[idx_i], trigger_features[idx_j]], dim=1)
            edge_probs = self.edge_generator(pair_features)
        else:
            edge_probs = torch.empty(0, 1, device=device)

        return trigger_features, edge_probs
class ImperceptibleWM2(BaseDefense):
    """Watermark defense that embeds an imperceptible trigger graph via
    bilevel optimization (outer: trigger generator, inner: model embedding)."""

    def __init__(self, dataset, attack_node_fraction=0.2, wm_node=50, target_label=None,
                 N=50, M=5, epsilon1=1.0, epsilon2=0.5, epsilon3=1.0,
                 owner_id=None, beta=0.001, T_acc=0.8):
        """
        Initialize the watermark defense using bilevel optimization.

        Parameters
        ----------
        dataset : object
            Graph dataset containing features, labels, and graph structure.
        attack_node_fraction : float, default=0.2
            Fraction of nodes to consider for attack simulation.
        wm_node : int, default=50
            Number of nodes in the watermark/trigger graph.
        target_label : int, optional
            Target label for watermark classification. If None, randomly selected.
        N : int, default=50
            Number of bilevel optimization iterations.
        M : int, default=5
            Number of embedding-phase iterations per bilevel step.
        epsilon1 : float, default=1.0
            Weight for imperception loss in the generator objective.
        epsilon2 : float, default=0.5
            Weight for regulation loss in the generator objective.
        epsilon3 : float, default=1.0
            Weight for trigger loss in the generator objective.
        owner_id : array-like, optional
            Owner identifier for watermark regulation. If None, randomly generated.
        beta : float, default=0.001
            Learning rate for the main model optimizer.
        T_acc : float, default=0.8
            Accuracy threshold for ownership verification.
        """
        super().__init__(dataset, attack_node_fraction)
        self.dataset = dataset
        self.graph = dataset.graph
        # Fall back to reading counts straight off the DGL graph when the
        # dataset wrapper does not expose them as attributes.
        self.node_number = dataset.node_number if hasattr(dataset, 'node_number') else self.graph.num_nodes()
        self.feature_number = dataset.feature_number if hasattr(dataset, 'feature_number') else \
            self.graph.ndata['feat'].shape[1]
        # NOTE(review): assumes labels form a contiguous range [min, max] —
        # gaps in the label set would overestimate label_number; confirm with dataset.
        self.label_number = dataset.label_number if hasattr(dataset, 'label_number') else (
            int(max(self.graph.ndata['label']) - min(self.graph.ndata['label'])) + 1)
        self.attack_node_number = int(self.node_number * attack_node_fraction)
        self.wm_node = wm_node
        self.target_label = target_label if target_label is not None else np.random.randint(0, self.label_number)
        self.N = N
        self.M = M
        self.beta = beta
        self.T_acc = T_acc
        self.epsilon1 = epsilon1
        self.epsilon2 = epsilon2
        self.epsilon3 = epsilon3
        # Owner ID: accept list/ndarray/tensor; anything else is replaced by a
        # fresh random signature of length feature_number.
        self.owner_id = owner_id if owner_id is not None else torch.rand(self.feature_number, device=device)
        if isinstance(self.owner_id, (list, np.ndarray)):
            self.owner_id = torch.tensor(self.owner_id, dtype=torch.float32, device=device)
        elif not isinstance(self.owner_id, torch.Tensor):
            self.owner_id = torch.rand(self.feature_number, device=device)
        self.features = dataset.features if hasattr(dataset, 'features') else self.graph.ndata['feat']
        self.labels = dataset.labels if hasattr(dataset, 'labels') else self.graph.ndata['label']
        self.train_mask = dataset.train_mask if hasattr(dataset, 'train_mask') else self.graph.ndata['train_mask']
        self.test_mask = dataset.test_mask if hasattr(dataset, 'test_mask') else self.graph.ndata['test_mask']
        # Move everything to the selected device (no-op guard on CPU).
        if device != 'cpu':
            self.graph = self.graph.to(device)
            self.features = self.features.to(device)
            self.labels = self.labels.to(device)
            self.train_mask = self.train_mask.to(device)
            self.test_mask = self.test_mask.to(device)
            self.owner_id = self.owner_id.to(device)
    def _select_poisoning_nodes(self, clean_model):
        """
        Select nodes for watermark poisoning based on model predictions.

        Uses the clean model's confidence scores to pick, per class label,
        the nodes the model is most confident about; pads with random
        unselected nodes when fewer than ``wm_node`` are collected.

        Parameters
        ----------
        clean_model : torch.nn.Module
            Pre-trained clean model used for node selection.

        Returns
        -------
        torch.Tensor
            Tensor of ``wm_node`` selected node indices for poisoning.
        """
        clean_model.eval()
        with torch.no_grad():
            # Score every node in mini-batches via neighbor sampling.
            sampler = NeighborSampler([5, 5])
            all_nids = torch.arange(self.graph.num_nodes(), device=device)
            collator = NodeCollator(self.graph, all_nids, sampler)
            dataloader = DataLoader(
                collator.dataset,
                batch_size=64,
                shuffle=False,
                collate_fn=collator.collate,
                drop_last=False
            )
            all_predictions = []
            node_indices = []
            for input_nodes, output_nodes, blocks in dataloader:
                blocks = [b.to(device) for b in blocks]
                input_features = blocks[0].srcdata['feat']
                logits = clean_model(blocks, input_features)
                predictions = F.softmax(logits, dim=1)
                all_predictions.append(predictions)
                node_indices.append(output_nodes)
            all_predictions = torch.cat(all_predictions, dim=0)
            node_indices = torch.cat(node_indices, dim=0)
            # Take the top-confidence nodes for each label.
            # NOTE(review): a node can be top-ranked for several labels, so
            # poisoning_nodes may contain duplicates — confirm intended.
            poisoning_nodes = []
            nodes_per_label = max(1, self.wm_node // self.label_number)
            for label in range(self.label_number):
                label_probs = all_predictions[:, label]
                _, top_indices = torch.topk(label_probs, min(nodes_per_label, len(label_probs)))
                selected_nodes = node_indices[top_indices]
                poisoning_nodes.extend(selected_nodes.tolist())
            # Pad with random unselected nodes if too few were collected.
            if len(poisoning_nodes) < self.wm_node:
                remaining_nodes = set(range(self.graph.num_nodes())) - set(poisoning_nodes)
                additional_nodes = np.random.choice(
                    list(remaining_nodes),
                    size=min(self.wm_node - len(poisoning_nodes), len(remaining_nodes)),
                    replace=False
                )
                poisoning_nodes.extend(additional_nodes)
            poisoning_nodes = poisoning_nodes[:self.wm_node]
        return torch.tensor(poisoning_nodes, device=device)
    def _generate_trigger_graph(self, f_g, V_p):
        """
        Generate a watermark trigger graph using the generator network.

        Runs the trigger generator (without gradients), thresholds the
        predicted edge probabilities at 0.5 to materialize edges, and packs
        the result into a labeled DGL graph.

        Parameters
        ----------
        f_g : torch.nn.Module
            Trigger generator network.
        V_p : torch.Tensor
            Selected poisoning node indices.

        Returns
        -------
        dgl.DGLGraph
            Trigger graph with 'feat', 'label', 'train_mask', 'test_mask'
            node data; every node is labeled ``target_label``.
        """
        f_g.eval()
        with torch.no_grad():
            trigger_features, edge_probs = f_g(self.features, V_p)
            # edge_probs is in upper-triangular (i < j) order; walk pairs in
            # the same order and keep those above the threshold (both directions).
            edge_threshold = 0.5
            edges_src, edges_dst = [], []
            edge_idx = 0
            for i in range(self.wm_node):
                for j in range(i + 1, self.wm_node):
                    if edge_idx < len(edge_probs) and edge_probs[edge_idx] > edge_threshold:
                        edges_src.extend([i, j])
                        edges_dst.extend([j, i])
                    edge_idx += 1
            # Guarantee at least one edge so dgl.graph gets a non-empty edge list.
            if len(edges_src) == 0:
                edges_src = [0, 1]
                edges_dst = [1, 0]
            trigger_graph = dgl.graph((edges_src, edges_dst), num_nodes=self.wm_node)
            trigger_graph = trigger_graph.to(device)
            trigger_graph.ndata['feat'] = trigger_features.detach()
            # All trigger nodes carry the watermark target label.
            trigger_graph.ndata['label'] = torch.full((self.wm_node,), self.target_label,
                                                      dtype=torch.long, device=device)
            trigger_graph.ndata['train_mask'] = torch.ones(self.wm_node, dtype=torch.bool, device=device)
            trigger_graph.ndata['test_mask'] = torch.ones(self.wm_node, dtype=torch.bool, device=device)
            trigger_graph = dgl.add_self_loop(trigger_graph)
        return trigger_graph
    def _construct_backdoor_graph(self, clean_graph, trigger_graph, V_p):
        """
        Construct a backdoor graph by combining clean graph with trigger graph.

        Builds a block adjacency matrix [[clean, A_I^T], [A_I, trigger]] where
        A_I randomly attaches each trigger node to poisoning nodes (each
        candidate link kept with probability ~0.3), then rebuilds a DGL graph.

        NOTE(review): densifying both adjacency matrices is O(n^2) memory —
        fine for small graphs, but confirm scalability for larger datasets.

        Parameters
        ----------
        clean_graph : dgl.DGLGraph
            Original clean graph.
        trigger_graph : dgl.DGLGraph
            Generated trigger/watermark graph.
        V_p : torch.Tensor
            Poisoning node indices for connection.

        Returns
        -------
        dgl.DGLGraph
            Combined backdoor graph with embedded watermark (self-loops added).
        """
        clean_adj = clean_graph.adj().to_dense()
        trigger_adj = trigger_graph.adj().to_dense()
        clean_features = clean_graph.ndata['feat']
        trigger_features = trigger_graph.ndata['feat']
        clean_labels = clean_graph.ndata['label']
        trigger_labels = trigger_graph.ndata['label']
        n_clean = clean_graph.num_nodes()
        n_trigger = trigger_graph.num_nodes()
        # Interconnection matrix: trigger node i <-> poisoning node j with
        # probability ~0.3 (torch.rand(1) > 0.7).
        A_I = torch.zeros(n_trigger, n_clean, device=device)
        for i in range(n_trigger):
            for j in V_p:
                if torch.rand(1) > 0.7:
                    A_I[i, j] = 1
        # Assemble the block adjacency and concatenated node data.
        top_row = torch.cat([clean_adj, A_I.t()], dim=1)
        bottom_row = torch.cat([A_I, trigger_adj], dim=1)
        backdoor_adj = torch.cat([top_row, bottom_row], dim=0)
        backdoor_features = torch.cat([clean_features, trigger_features], dim=0)
        backdoor_labels = torch.cat([clean_labels, trigger_labels], dim=0)
        edges_src, edges_dst = torch.nonzero(backdoor_adj, as_tuple=True)
        backdoor_graph = dgl.graph((edges_src, edges_dst), num_nodes=n_clean + n_trigger)
        backdoor_graph = backdoor_graph.to(device)
        backdoor_graph.ndata['feat'] = backdoor_features
        backdoor_graph.ndata['label'] = backdoor_labels
        # Trigger nodes participate in both training and testing.
        clean_train_mask = clean_graph.ndata['train_mask']
        clean_test_mask = clean_graph.ndata['test_mask']
        trigger_train_mask = torch.ones(n_trigger, dtype=torch.bool, device=device)
        trigger_test_mask = torch.ones(n_trigger, dtype=torch.bool, device=device)
        backdoor_graph.ndata['train_mask'] = torch.cat([clean_train_mask, trigger_train_mask])
        backdoor_graph.ndata['test_mask'] = torch.cat([clean_test_mask, trigger_test_mask])
        backdoor_graph = dgl.add_self_loop(backdoor_graph)
        return backdoor_graph
[docs] def _calculate_imperception_loss(self, trigger_features, V_p): """ Calculate imperception loss to make watermark features similar to clean features. Measures cosine similarity between trigger features and poisoning node features to ensure the watermark remains hidden. Parameters ---------- trigger_features : torch.Tensor Generated trigger node features V_p : torch.Tensor Poisoning node indices Returns ------- torch.Tensor Imperception loss value """ if len(V_p) == 0: return torch.tensor(0.0, device=device) poisoning_features = self.features[V_p] total_similarity = 0 count = 0 for i, trigger_feat in enumerate(trigger_features): for poison_feat in poisoning_features: similarity = F.cosine_similarity(trigger_feat.unsqueeze(0), poison_feat.unsqueeze(0)) total_similarity += similarity count += 1 return -total_similarity / count if count > 0 else torch.tensor(0.0, device=device)
[docs] def _calculate_regulation_loss(self, trigger_features): """ Calculate regulation loss based on owner ID signature. Enforces the trigger features to embed owner identification information using cross-entropy loss with the owner ID as target. Parameters ---------- trigger_features : torch.Tensor Generated trigger node features Returns ------- torch.Tensor Regulation loss value """ total_loss = 0 for trigger_feat in trigger_features: loss = -(self.owner_id * torch.log(trigger_feat + 1e-8) + (1 - self.owner_id) * torch.log(1 - trigger_feat + 1e-8)) total_loss += loss.mean() return total_loss / len(trigger_features)
    def _calculate_trigger_loss(self, f_theta, trigger_features, trigger_graph):
        """
        Calculate trigger loss for watermark effectiveness.

        Measures how confidently the model assigns the target label to the
        trigger nodes (negative log-probability of the target class).

        NOTE(review): this runs under ``torch.no_grad()``, so the value it
        contributes to the generator objective carries no gradient back to
        the generator — confirm whether gradient flow was intended here.

        Parameters
        ----------
        f_theta : torch.nn.Module
            Main classification model.
        trigger_features : torch.Tensor
            Generated trigger node features (unused here; the graph's own
            'feat' data is read through the sampled blocks).
        trigger_graph : dgl.DGLGraph
            Trigger graph structure.

        Returns
        -------
        torch.Tensor
            Trigger loss value (0.0 if the dataloader yields no batch).
        """
        f_theta.eval()
        sampler = NeighborSampler([5, 5])
        trigger_nids = torch.arange(trigger_graph.number_of_nodes(), device=device)
        collator = NodeCollator(trigger_graph, trigger_nids, sampler)
        # Single batch covering all trigger nodes.
        dataloader = DataLoader(
            collator.dataset,
            batch_size=self.wm_node,
            shuffle=False,
            collate_fn=collator.collate,
            drop_last=False
        )
        total_loss = 0
        count = 0
        with torch.no_grad():
            for _, _, blocks in dataloader:
                blocks = [b.to(device) for b in blocks]
                input_features = blocks[0].srcdata['feat']
                logits = f_theta(blocks, input_features)
                probs = F.softmax(logits, dim=1)
                target_probs = probs[:, self.target_label]
                # Negative log-likelihood of the target label.
                loss = -torch.log(target_probs + 1e-8).mean()
                total_loss += loss
                count += 1
                break  # only the first (full) batch is needed
        return total_loss / count if count > 0 else torch.tensor(0.0, device=device)
    def _calculate_generation_loss_integrated(self, f_theta_s, f_g, V_p):
        """
        Calculate the integrated generator loss.

        Combines imperception, regulation, and trigger losses, weighted by
        epsilon1/epsilon2/epsilon3 respectively, for one generator update.

        Parameters
        ----------
        f_theta_s : torch.nn.Module
            Current state of the main model (kept in eval mode).
        f_g : torch.nn.Module
            Trigger generator network (put in train mode).
        V_p : torch.Tensor
            Poisoning node indices.

        Returns
        -------
        torch.Tensor
            Combined generation loss L_g.
        """
        f_g.train()
        f_theta_s.eval()
        # Fresh forward pass through the generator (with gradients).
        trigger_features, edge_probs = f_g(self.features, V_p)
        temp_trigger_graph = self._create_temp_trigger_graph(trigger_features, edge_probs)
        L_imperception = self._calculate_imperception_loss(trigger_features, V_p)
        L_regulation = self._calculate_regulation_loss(trigger_features)
        L_trigger = self._calculate_trigger_loss(f_theta_s, trigger_features, temp_trigger_graph)
        # Weighted combination of the three generator objectives.
        L_g = (self.epsilon1 * L_imperception +
               self.epsilon2 * L_regulation +
               self.epsilon3 * L_trigger)
        return L_g
[docs] def _create_temp_trigger_graph(self, trigger_features, edge_probs): """ Create a temporary trigger graph for loss calculation. Constructs a temporary graph structure using generated features and edge probabilities for intermediate computations. Parameters ---------- trigger_features : torch.Tensor Generated trigger node features edge_probs : torch.Tensor Edge existence probabilities Returns ------- dgl.DGLGraph Temporary trigger graph """ edge_threshold = 0.5 edges_src, edges_dst = [], [] edge_idx = 0 for i in range(self.wm_node): for j in range(i + 1, self.wm_node): if edge_idx < len(edge_probs) and edge_probs[edge_idx] > edge_threshold: edges_src.extend([i, j]) edges_dst.extend([j, i]) edge_idx += 1 if len(edges_src) == 0: edges_src = [0, 1] edges_dst = [1, 0] temp_graph = dgl.graph((edges_src, edges_dst), num_nodes=self.wm_node) temp_graph = temp_graph.to(device) temp_graph.ndata['feat'] = trigger_features temp_graph.ndata['label'] = torch.full((self.wm_node,), self.target_label, dtype=torch.long, device=device) temp_graph = dgl.add_self_loop(temp_graph) return temp_graph
    def _calculate_embedding_loss(self, f_theta, backdoor_graph):
        """
        Calculate embedding loss for model training on the backdoor graph.

        Averages cross-entropy over up to 10 shuffled mini-batches of the
        backdoor graph's training nodes. The returned tensor retains its
        graph so the caller can ``backward()`` through it.

        Parameters
        ----------
        f_theta : torch.nn.Module
            Main classification model (put in train mode).
        backdoor_graph : dgl.DGLGraph
            Combined graph with embedded watermark.

        Returns
        -------
        torch.Tensor
            Embedding loss value (0.0 if no batch was produced).
        """
        f_theta.train()
        backdoor_train_nids = backdoor_graph.ndata['train_mask'].nonzero(as_tuple=True)[0].to(device)
        sampler = NeighborSampler([5, 5])
        collator = NodeCollator(backdoor_graph, backdoor_train_nids, sampler)
        dataloader = DataLoader(
            collator.dataset,
            batch_size=32,
            shuffle=True,
            collate_fn=collator.collate,
            drop_last=False
        )
        total_loss = 0
        count = 0
        for _, _, blocks in dataloader:
            blocks = [b.to(device) for b in blocks]
            input_features = blocks[0].srcdata['feat']
            output_labels = blocks[-1].dstdata['label']
            output_predictions = f_theta(blocks, input_features)
            loss = F.cross_entropy(output_predictions, output_labels)
            total_loss += loss
            count += 1
            # Cap the work per call: at most 10 mini-batches.
            if count >= 10:
                break
        return total_loss / count if count > 0 else torch.tensor(0.0, device=device)
[docs] def _inner_optimization(self, f_theta, f_g, V_p, optimizer): """ Execute the watermark embedding phase of bilevel optimization. Performs M iterations of model training on the backdoor graph to embed the watermark into the model parameters. Parameters ---------- f_theta : torch.nn.Module Main classification model f_g : torch.nn.Module Trigger generator network V_p : torch.Tensor Poisoning node indices optimizer : torch.optim.Optimizer Optimizer for model parameters Returns ------- torch.nn.Module Updated model with embedded watermark """ trigger_graph = self._generate_trigger_graph(f_g, V_p) backdoor_graph = self._construct_backdoor_graph(self.graph, trigger_graph, V_p) for t in range(self.M): L_embed = self._calculate_embedding_loss(f_theta, backdoor_graph) optimizer.zero_grad() L_embed.backward() optimizer.step() return f_theta
    def defend(self):
        """
        Execute the complete watermark defense strategy.

        Trains a clean target model (baseline metrics), trains the
        watermarked defense model via bilevel optimization, and verifies
        ownership on the learned watermark graph.

        Returns
        -------
        dict
            Attack/defense accuracy, precision, recall, F1, ownership
            verification flag and accuracy, and the trained generator.
        """
        # Baseline: clean target model evaluated on the test split.
        attack_model = self._train_target_model()
        sampler = NeighborSampler([5, 5])
        test_nids = self.test_mask.nonzero(as_tuple=True)[0].to(device)
        test_collator = NodeCollator(self.graph, test_nids, sampler)
        test_dataloader = DataLoader(
            test_collator.dataset,
            batch_size=32,
            shuffle=False,
            collate_fn=test_collator.collate,
            drop_last=False
        )
        attack_acc, attack_prec, attack_rec, attack_f1 = self._evaluate_with_metrics(attack_model, test_dataloader)
        print("Target Model Metrics:")
        print(f"  Accuracy : {attack_acc * 100:.2f}%")
        print(f"  Precision: {attack_prec * 100:.2f}%")
        print(f"  Recall   : {attack_rec * 100:.2f}%")
        print(f"  F1 Score : {attack_f1 * 100:.2f}%")
        # Watermarked model trained with the bilevel procedure.
        defense_model, generator = self._train_defense_model()
        defense_acc, defense_prec, defense_rec, defense_f1 = self._evaluate_with_metrics(defense_model,
                                                                                         test_dataloader)
        print("Defense Model Metrics:")
        print(f"  Accuracy : {defense_acc * 100:.2f}%")
        print(f"  Precision: {defense_prec * 100:.2f}%")
        print(f"  Recall   : {defense_rec * 100:.2f}%")
        print(f"  F1 Score : {defense_f1 * 100:.2f}%")
        # Ownership check against the stored watermark graph.
        is_owner, ownership_acc = self.verify_ownership(defense_model)
        print(f"\nOwnership Verification: {is_owner}, Watermark Accuracy: {ownership_acc * 100:.2f}%")
        return {
            "attack_accuracy": attack_acc,
            "attack_precision": attack_prec,
            "attack_recall": attack_rec,
            "attack_f1": attack_f1,
            "defense_accuracy": defense_acc,
            "defense_precision": defense_prec,
            "defense_recall": defense_rec,
            "defense_f1": defense_f1,
            "ownership_verified": is_owner,
            "ownership_accuracy": ownership_acc,
            "generator": generator
        }
[docs] def _train_target_model(self): """ Train the target model on clean graph data. Creates and trains a GraphSAGE model on the original dataset without any watermark or defense mechanisms. Returns ------- torch.nn.Module Trained target model """ model = GraphSAGE(in_channels=self.feature_number, hidden_channels=128, out_channels=self.label_number) model = model.to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) sampler = NeighborSampler([5, 5]) train_nids = self.train_mask.nonzero(as_tuple=True)[0].to(device) test_nids = self.test_mask.nonzero(as_tuple=True)[0].to(device) train_collator = NodeCollator(self.graph, train_nids, sampler) test_collator = NodeCollator(self.graph, test_nids, sampler) train_dataloader = DataLoader( train_collator.dataset, batch_size=32, shuffle=True, collate_fn=train_collator.collate, drop_last=False ) test_dataloader = DataLoader( test_collator.dataset, batch_size=32, shuffle=False, collate_fn=test_collator.collate, drop_last=False ) for epoch in tqdm(range(1, 51), desc="========== Training Target Model =========="): model.train() for _, _, blocks in train_dataloader: blocks = [b.to(device) for b in blocks] input_features = blocks[0].srcdata['feat'] output_labels = blocks[-1].dstdata['label'] optimizer.zero_grad() output_predictions = model(blocks, input_features) loss = F.cross_entropy(output_predictions, output_labels) loss.backward() optimizer.step() return model
    def _train_defense_model(self):
        """
        Train the defense model with watermark embedding via bilevel optimization.

        Alternates for ``N`` rounds between (inner) embedding the watermark
        into the main model and (outer) updating the trigger generator on
        the integrated generation loss. Stores the final watermark graph in
        ``self.watermark_graph`` and the poisoning set in ``self.poisoning_nodes``.

        Returns
        -------
        tuple
            (trained_defense_model, trigger_generator)
        """
        f_theta = GraphSAGE(in_channels=self.feature_number, hidden_channels=128,
                            out_channels=self.label_number).to(device)
        f_g = TriggerGenerator(feature_dim=self.feature_number, output_nodes=self.wm_node).to(device)
        print("\n========== Training Defense Model ==========")
        # High confidence nodes from the target model will be used as trigger
        print("Retraining the target model to select poisoning nodes")
        clean_model = self._train_target_model()
        V_p = self._select_poisoning_nodes(clean_model)
        theta_optimizer = torch.optim.Adam(f_theta.parameters(), lr=self.beta, weight_decay=5e-4)
        g_optimizer = torch.optim.Adam(f_g.parameters(), lr=0.001, weight_decay=5e-4)
        for i in tqdm(range(self.N), desc="Starting BiLevelOptimization Process"):
            # Inner phase: embed the current watermark into f_theta.
            f_theta = self._inner_optimization(f_theta, f_g, V_p, theta_optimizer)
            # NOTE(review): f_theta_s is an alias of f_theta, not a frozen
            # snapshot — confirm a copy was not intended for the outer step.
            f_theta_s = f_theta
            # Outer phase: update the generator against the current model.
            L_g = self._calculate_generation_loss_integrated(f_theta_s, f_g, V_p)
            g_optimizer.zero_grad()
            L_g.backward()
            g_optimizer.step()
        # Persist the final watermark artifacts for later verification.
        self.watermark_graph = self._generate_trigger_graph(f_g, V_p)
        self.poisoning_nodes = V_p
        return f_theta, f_g
[docs] def _evaluate_with_metrics(self, model, dataloader): """ Evaluate model performance using multiple classification metrics. Parameters ---------- model : torch.nn.Module The neural network model to evaluate dataloader : torch.utils.data.DataLoader DataLoader containing evaluation data Returns ------- tuple of float (accuracy, precision, recall, f1_score) metrics """ model.eval() all_preds = [] all_labels = [] with torch.no_grad(): for _, _, blocks in dataloader: blocks = [b.to(device) for b in blocks] input_features = blocks[0].srcdata['feat'] output_labels = blocks[-1].dstdata['label'] output_predictions = model(blocks, input_features) pred = output_predictions.argmax(dim=1) all_preds.extend(pred.cpu().numpy()) all_labels.extend(output_labels.cpu().numpy()) if len(all_preds) == 0: return 0.0, 0.0, 0.0, 0.0 accuracy = accuracy_score(all_labels, all_preds) precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0) recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0) f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0) return accuracy, precision, recall, f1
[docs] def verify_ownership(self, suspicious_model): """ Verify ownership of a suspicious model using the watermark. Tests if the suspicious model correctly classifies the watermark trigger graph to determine if it contains the embedded watermark. Parameters ---------- suspicious_model : torch.nn.Module Model to test for ownership verification Returns ------- tuple (is_owner: bool, ownership_accuracy: float) """ if not hasattr(self, 'watermark_graph'): return False, 0.0 G_key_p = self.watermark_graph acc, _, _, _ = self._evaluate_model_on_graph(suspicious_model, G_key_p) is_owner = acc > self.T_acc return is_owner, acc
    def _evaluate_model_on_graph(self, model, graph):
        """
        Evaluate model performance on a specific graph.

        Only GraphSAGE models (identified by class name) are supported:
        the whole graph is evaluated in a single neighbor-sampled batch.
        Any other architecture returns all-zero metrics.

        Parameters
        ----------
        model : torch.nn.Module
            Model to evaluate.
        graph : dgl.DGLGraph
            Graph data for evaluation.

        Returns
        -------
        tuple of float
            (accuracy, precision, recall, f1_score) metrics.
        """
        model_name = model.__class__.__name__
        if model_name == 'GraphSAGE':
            sampler = NeighborSampler([5, 5])
            trigger_nids = torch.arange(graph.number_of_nodes(), device=device)
            trigger_collator = NodeCollator(graph, trigger_nids, sampler)
            # One batch covering every node in the graph.
            trigger_dataloader = DataLoader(
                trigger_collator.dataset,
                batch_size=graph.number_of_nodes(),
                shuffle=False,
                collate_fn=trigger_collator.collate,
                drop_last=False
            )
            return self._evaluate_with_metrics(model, trigger_dataloader)
        else:
            # Unsupported architecture: report zero metrics rather than crash.
            return 0.0, 0.0, 0.0, 0.0