Source code for pygip.models.attack.Realistic

import random
import warnings
import time

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics.pairwise import cosine_similarity

from pygip.models.nn.backbones import GCN
from .base import BaseAttack
from pygip.utils.metrics import AttackMetric, AttackCompMetric  # align with AdvMEA

# NOTE(review): importing this module installs a process-wide "ignore" filter
# for every warning category, which also hides warnings raised by torch, dgl
# and sklearn anywhere else in the importing program.
warnings.simplefilter('ignore')


class DGLEdgePredictor(nn.Module):
    """DGL version of edge prediction module.

    A GCN encoder shared by two heads: a linear node classifier and an MLP
    that scores a (src, dst) node pair as a potential edge.

    Args:
        input_dim: Dimensionality of the input node features.
        hidden_dim: Output size of the GCN encoder (embedding size).
        num_classes: Number of classes for the node-classification head.
        device: Device on which pair-index tensors are created.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, device):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.device = device
        # Use the same GCN backbone as the target model to obtain node embeddings.
        self.gnn = GCN(input_dim, hidden_dim)
        self.node_classifier = nn.Linear(hidden_dim, num_classes)
        # Edge prediction head: concatenated (src, dst) embeddings -> probability.
        self.edge_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, graph, features):
        """Return ``(node_embeddings, node_logits)`` for all nodes of *graph*."""
        node_embeddings = self.gnn(graph, features)
        node_logits = self.node_classifier(node_embeddings)
        return node_embeddings, node_logits

    def predict_edges(self, node_embeddings, node_pairs):
        """Predict edge-existence probabilities for node index pairs.

        Args:
            node_embeddings: Per-node embeddings indexed by node id, e.g. the
                first element returned by :meth:`forward`.
            node_pairs: A sequence of ``(src, dst)`` index pairs, or an integer
                tensor of shape ``(num_pairs, 2)``.

        Returns:
            A 1-D tensor with one probability per pair; an empty tensor when
            ``node_pairs`` is empty.
        """
        if len(node_pairs) == 0:
            return torch.tensor([], device=self.device)
        # as_tensor avoids a copy (and a copy-construct warning) when a
        # tensor of pairs is passed in.
        pairs = torch.as_tensor(node_pairs, dtype=torch.long, device=self.device)
        src_embeddings = node_embeddings[pairs[:, 0]]
        dst_embeddings = node_embeddings[pairs[:, 1]]
        edge_features = torch.cat([src_embeddings, dst_embeddings], dim=1)
        # squeeze(-1) rather than squeeze(): a single pair must still yield a
        # 1-D tensor so callers can index it as probs[i].
        return self.edge_predictor(edge_features).squeeze(-1)
class DGLSurrogateModel(nn.Module):
    """Surrogate GNN trained to imitate the target model (DGL backend).

    Args:
        input_dim: Dimensionality of the input node features.
        num_classes: Number of output classes.
        model_type: Recorded backbone name. Only ``'GCN'`` is implemented;
            every other value currently falls back to a GCN as well.
    """

    def __init__(self, input_dim, num_classes, model_type='GCN'):
        super().__init__()
        self.model_type = model_type
        # All model types map to GCN for now; add other backbones here.
        self.gnn = GCN(input_dim, num_classes)

    def forward(self, graph, features):
        """Return the backbone output for every node of *graph*."""
        return self.gnn(graph, features)
class RealisticAttack(BaseAttack):
    """DGL-based GNN model extraction attack with updated metrics API.

    The attack (see :meth:`attack`) samples a budget of query nodes, obtains
    noisy labels for them from ``self.net1``, proposes extra edges between
    unqueried and queried nodes by feature cosine similarity, keeps those an
    auxiliary edge predictor scores above ``threshold_a``, and trains a
    surrogate GCN on the resulting enhanced graph.
    """

    supported_api_types = {"dgl"}
    supported_datasets = {}

    def __init__(self, dataset, attack_x_ratio: float, attack_a_ratio: float,
                 model_path: str = None, hidden_dim: int = 64,
                 threshold_s: float = 0.7, threshold_a: float = 0.5):
        """Set up graph tensors, helper models and optimizers.

        Args:
            dataset: Dataset object forwarded to ``BaseAttack``; its
                ``dataset_name`` is printed.
            attack_x_ratio: Availability ratio forwarded to ``BaseAttack``;
                together with ``attack_a_ratio`` it sizes the query budget.
            attack_a_ratio: Second availability ratio used for the budget.
            model_path: Forwarded unchanged to ``BaseAttack``.
            hidden_dim: Hidden size of the auxiliary edge predictor.
            threshold_s: Cosine-similarity threshold for candidate edges.
            threshold_a: Edge-probability threshold for adding candidates.
        """
        # Keep BaseAttack initialization contract; store ratios for this attack.
        super().__init__(dataset, attack_x_ratio, model_path)
        self.attack_x_ratio = float(attack_x_ratio)
        self.attack_a_ratio = float(attack_a_ratio)
        self.hidden_dim = hidden_dim
        self.threshold_s = threshold_s  # Cosine similarity threshold
        self.threshold_a = threshold_a  # Edge prediction threshold

        # Query budget = num_nodes * max(ratio), at least 1 and at most
        # num_nodes (random.sample in attack() rejects a larger budget).
        ratio_budget = max(self.attack_x_ratio, self.attack_a_ratio)
        if ratio_budget <= 0.0:
            ratio_budget = 0.05  # small default to avoid zero queries
        self.attack_node_number = min(
            self.num_nodes, max(1, int(self.num_nodes * ratio_budget))
        )

        # Graph tensors. num_nodes, device, graph_data, num_features and
        # num_classes are read but never assigned in this class, so they are
        # expected to come from BaseAttack.
        self.graph_data = self.graph_data.to(self.device)
        self.graph = self.graph_data
        self.features = self.graph.ndata['feat']

        # Initialize edge predictor and surrogate model.
        self.edge_predictor = DGLEdgePredictor(
            self.num_features, hidden_dim, self.num_classes, self.device
        ).to(self.device)
        self.surrogate_model = DGLSurrogateModel(
            self.num_features, self.num_classes
        ).to(self.device)

        # Target model used to simulate black-box responses.
        # NOTE(review): net1 is freshly initialized here and is not trained or
        # loaded anywhere in this class; unless BaseAttack / model_path handling
        # replaces it elsewhere, queries go to a randomly initialized GCN.
        self.net1 = GCN(self.num_features, self.num_classes).to(self.device)

        # Optimizers
        self.optimizer_edge = optim.Adam(self.edge_predictor.parameters(), lr=0.01, weight_decay=5e-4)
        self.optimizer_surrogate = optim.Adam(self.surrogate_model.parameters(), lr=0.01, weight_decay=5e-4)

        print(f"Initialized attack on {dataset.dataset_name} dataset")
        print(f"Nodes: {self.num_nodes}, Features: {self.num_features}, Classes: {self.num_classes}")
        print(f"Attack nodes: {self.attack_node_number} (x_ratio={self.attack_x_ratio:.2f}, a_ratio={self.attack_a_ratio:.2f})")

    def simulate_target_model_queries(self, query_nodes, error_rate=0.15):
        """Query the target model for labels of ``query_nodes`` and add label noise.

        Args:
            query_nodes: Node indices to query.
            error_rate: Fraction of returned labels replaced by a *different*
                class to simulate noisy responses (clamped to all labels).

        Returns:
            A label tensor aligned with ``query_nodes``.
        """
        self.net1.eval()
        with torch.no_grad():
            logits = self.net1(self.graph, self.features)
            predictions = F.log_softmax(logits, dim=1).argmax(dim=1)
            predicted_labels = predictions[query_nodes].clone()

        # Flip a portion of labels to simulate noise in responses.
        num_labels = len(predicted_labels)
        num_errors = min(num_labels, int(num_labels * error_rate))
        if num_errors > 0 and self.num_classes > 1:
            for idx in random.sample(range(num_labels), num_errors):
                current = int(predicted_labels[idx])
                # Draw uniformly from the other num_classes - 1 classes so each
                # injected error really changes the label.
                wrong_label = random.randrange(self.num_classes - 1)
                if wrong_label >= current:
                    wrong_label += 1
                predicted_labels[idx] = wrong_label
        return predicted_labels

    def compute_cosine_similarity(self, features):
        """Return the pairwise cosine-similarity matrix of ``features`` as a float32 tensor."""
        features_np = features.cpu().detach().numpy()
        similarity_matrix = cosine_similarity(features_np)
        return torch.tensor(similarity_matrix, dtype=torch.float32, device=self.device)

    def generate_candidate_edges(self, labeled_nodes, unlabeled_nodes):
        """Propose ``[unlabeled, labeled]`` pairs whose feature similarity exceeds ``threshold_s``.

        Pairs are ordered by position in ``unlabeled_nodes`` first, then by
        position in ``labeled_nodes``.
        """
        similarity_matrix = self.compute_cosine_similarity(self.features)
        unlabeled_idx = torch.as_tensor(list(unlabeled_nodes), dtype=torch.long, device=self.device)
        labeled_idx = torch.as_tensor(list(labeled_nodes), dtype=torch.long, device=self.device)
        # Threshold the (unlabeled x labeled) sub-matrix in one vectorized step;
        # nonzero() returns indices in row-major order, i.e. the same order as a
        # "for u in unlabeled: for l in labeled" double loop.
        sub_matrix = similarity_matrix.index_select(0, unlabeled_idx).index_select(1, labeled_idx)
        rows, cols = (sub_matrix > self.threshold_s).nonzero(as_tuple=True)
        candidate_edges = torch.stack((unlabeled_idx[rows], labeled_idx[cols]), dim=1).tolist()
        print(f"Generated {len(candidate_edges)} candidate edges based on cosine similarity")
        return candidate_edges

    def _sample_negative_pairs(self, num_samples):
        """Draw ``num_samples`` random node pairs; keep those with src != dst that are not graph edges."""
        draws = [
            (random.randint(0, self.num_nodes - 1), random.randint(0, self.num_nodes - 1))
            for _ in range(num_samples)
        ]
        draws = [(src, dst) for src, dst in draws if src != dst]
        if not draws:
            return []
        src = torch.tensor([s for s, _ in draws], dtype=torch.long, device=self.device)
        dst = torch.tensor([d for _, d in draws], dtype=torch.long, device=self.device)
        # One batched membership query instead of one DGL call per sample.
        exists = torch.as_tensor(self.graph_data.has_edges_between(src, dst)).reshape(-1).tolist()
        return [[s, d] for (s, d), present in zip(draws, exists) if not present]

    def train_edge_predictor(self, labeled_nodes, predicted_labels, epochs=100):
        """Train the auxiliary edge predictor.

        The loss is node classification on the queried nodes plus
        ``0.5 * (positive + negative)`` edge log-losses, where positives are
        the graph's edges and negatives are random non-edges.
        """
        print("Training edge predictor...")
        self.edge_predictor.train()

        # Create node labels tensor; only queried nodes are labeled.
        train_labels = torch.full((self.num_nodes,), -1, dtype=torch.long, device=self.device)
        train_labels[labeled_nodes] = predicted_labels
        labeled_mask = train_labels != -1
        has_labels = bool(labeled_mask.any())

        # Observed edges do not change between epochs: build positives once.
        src_nodes, dst_nodes = self.graph.edges()
        positive_pairs = torch.stack((src_nodes, dst_nodes), dim=1)
        num_neg_samples = min(len(positive_pairs), 1000)

        for epoch in range(epochs):
            self.optimizer_edge.zero_grad()

            # Forward pass through edge predictor
            node_embeddings, node_logits = self.edge_predictor(self.graph, self.features)

            # Node classification loss (supervised on labeled nodes)
            if has_labels:
                node_loss = F.cross_entropy(node_logits[labeled_mask], train_labels[labeled_mask])
            else:
                node_loss = torch.tensor(0.0, device=self.device)

            # Positive edges; an edgeless graph would otherwise give mean([]) = NaN.
            if len(positive_pairs) > 0:
                pos_edge_probs = self.edge_predictor.predict_edges(node_embeddings, positive_pairs)
                pos_loss = -torch.log(pos_edge_probs + 1e-15).mean()
            else:
                pos_loss = torch.tensor(0.0, device=self.device)

            # Negative edges
            negative_pairs = self._sample_negative_pairs(num_neg_samples)
            if negative_pairs:
                neg_edge_probs = self.edge_predictor.predict_edges(node_embeddings, negative_pairs)
                neg_loss = -torch.log(1 - neg_edge_probs + 1e-15).mean()
            else:
                neg_loss = torch.tensor(0.0, device=self.device)

            total_loss = node_loss + 0.5 * (pos_loss + neg_loss)
            total_loss.backward()
            self.optimizer_edge.step()

            if epoch % 20 == 0:
                print(f"Epoch {epoch:3d}: total={total_loss.item():.4f}, "
                      f"node={node_loss.item():.4f}, edge={(pos_loss + neg_loss).item():.4f}")

    def add_potential_edges(self, candidate_edges, labeled_nodes):
        """Return a graph with the candidates scored above ``threshold_a`` added in both directions.

        ``labeled_nodes`` is accepted for interface compatibility but unused.
        Returns ``self.graph`` itself when nothing is added.
        """
        if not candidate_edges:
            return self.graph

        print("Predicting edge weights and adding potential edges...")
        self.edge_predictor.eval()
        with torch.no_grad():
            node_embeddings, _ = self.edge_predictor(self.graph, self.features)
            # reshape(-1) guards against a 0-d result for a single candidate.
            edge_probs = self.edge_predictor.predict_edges(node_embeddings, candidate_edges).reshape(-1)

        keep_flags = (edge_probs > self.threshold_a).tolist()
        kept = [(src, dst) for (src, dst), keep in zip(candidate_edges, keep_flags) if keep]
        print(f"Selected {len(kept)} potential edges to add")
        if not kept:
            return self.graph

        # Insert (src, dst) and (dst, src) for every kept pair (undirected).
        selected_edges = [edge for src, dst in kept for edge in ((src, dst), (dst, src))]
        return dgl.add_edges(
            self.graph,
            [e[0] for e in selected_edges],
            [e[1] for e in selected_edges],
        )

    def train_surrogate_model(self, enhanced_graph, labeled_nodes, predicted_labels, epochs=200):
        """Train the surrogate on ``enhanced_graph`` using the queried nodes' pseudo labels."""
        print("Training surrogate model...")
        self.surrogate_model.train()

        # Build training labels for queried nodes.
        train_labels = torch.full((self.num_nodes,), -1, dtype=torch.long, device=self.device)
        train_labels[labeled_nodes] = predicted_labels
        labeled_mask = train_labels != -1
        if not bool(labeled_mask.any()):
            # Nothing to fit (previously this crashed on an undefined `loss`).
            print("No labeled nodes; skipping surrogate training")
            return

        for epoch in range(epochs):
            self.optimizer_surrogate.zero_grad()
            logits = self.surrogate_model(enhanced_graph, self.features)
            logp = F.log_softmax(logits, dim=1)
            loss = F.nll_loss(logp[labeled_mask], train_labels[labeled_mask])
            loss.backward()
            self.optimizer_surrogate.step()
            if epoch % 50 == 0:
                print(f"Surrogate epoch {epoch:3d}, loss={loss.item():.4f}")

    def _evaluate_and_update_metrics(self, enhanced_graph, metric: AttackMetric, metric_comp: AttackCompMetric):
        """Evaluate surrogate against target on the real test set and update metric containers.

        The target runs on the original graph and the surrogate on
        ``enhanced_graph``. Inference times go to ``metric_comp``.
        ``metric.update`` receives, restricted to ``ndata['test_mask']``:
        (surrogate predictions, ``ndata['label']``, target predictions).
        """
        # Both models must be in inference mode (e.g. dropout off); the
        # surrogate is still in train mode after train_surrogate_model().
        self.net1.eval()
        self.surrogate_model.eval()

        # Target inference
        t0 = time.perf_counter()
        with torch.no_grad():
            logits_target = self.net1(self.graph, self.features)
        metric_comp.update(inference_target_time=(time.perf_counter() - t0))
        target_preds = F.log_softmax(logits_target, dim=1).argmax(dim=1)

        # Surrogate inference
        t0 = time.perf_counter()
        with torch.no_grad():
            logits_surrogate = self.surrogate_model(enhanced_graph, self.features)
        metric_comp.update(inference_surrogate_time=(time.perf_counter() - t0))
        surrogate_preds = F.log_softmax(logits_surrogate, dim=1).argmax(dim=1)

        # Update performance metrics with ground truth and target predictions on test split.
        mask = self.graph_data.ndata['test_mask']
        labels = self.graph_data.ndata['label']
        metric.update(surrogate_preds[mask], labels[mask], target_preds[mask])

    def attack(self):
        """Execute the attack.

        Returns:
            ``(res, res_comp)``: the results of ``AttackMetric.compute()`` and
            ``AttackCompMetric.compute()`` (performance and computation metrics).
        """
        metric = AttackMetric()
        metric_comp = AttackCompMetric()

        print("=" * 60)
        print("Starting GNN Model Extraction Attack (Realistic)")
        print("=" * 60)

        attack_start = time.perf_counter()

        # Step 1: Randomly select query nodes according to the budget.
        all_nodes = list(range(self.num_nodes))
        labeled_nodes = random.sample(all_nodes, self.attack_node_number)
        labeled_set = set(labeled_nodes)  # O(1) membership instead of list scans
        unlabeled_nodes = [n for n in all_nodes if n not in labeled_set]
        print(f"Selected {len(labeled_nodes)} nodes for querying")

        # Step 2: Query the target model once for pseudo labels.
        t_q = time.perf_counter()
        predicted_labels = self.simulate_target_model_queries(labeled_nodes)
        metric_comp.update(query_target_time=(time.perf_counter() - t_q))
        print("Finished querying the target model")

        # Step 3: Generate candidate edges (feature similarity).
        candidate_edges = self.generate_candidate_edges(labeled_nodes, unlabeled_nodes)

        # Step 4: Train the auxiliary edge predictor (included in total attack time).
        self.train_edge_predictor(labeled_nodes, predicted_labels)

        # Step 5: Add potential edges to obtain an enhanced graph.
        enhanced_graph = self.add_potential_edges(candidate_edges, labeled_nodes)
        original_edges = self.graph_data.num_edges()
        enhanced_edges = enhanced_graph.num_edges()
        print(f"Enhanced graph: {original_edges} -> {enhanced_edges} edges (+{enhanced_edges - original_edges})")

        # Step 6: Train the surrogate model and record its training time.
        t_train_surr = time.perf_counter()
        self.train_surrogate_model(enhanced_graph, labeled_nodes, predicted_labels)
        metric_comp.update(train_surrogate_time=(time.perf_counter() - t_train_surr))

        # Step 7: One-shot evaluation and metrics update.
        self._evaluate_and_update_metrics(enhanced_graph, metric, metric_comp)

        # Finalize computation stats.
        metric_comp.end()
        metric_comp.update(attack_time=(time.perf_counter() - attack_start))

        # Return two JSON-like dicts as required by the new API.
        res = metric.compute()
        res_comp = metric_comp.compute()
        return res, res_comp