import random
import warnings
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics.pairwise import cosine_similarity
from pygip.models.nn.backbones import GCN
from .base import BaseAttack
warnings.filterwarnings('ignore')
[docs]class DGLEdgePredictor(nn.Module):
"""DGL version of edge prediction module"""
def __init__(self, input_dim, hidden_dim, num_classes, device):
super(DGLEdgePredictor, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.num_classes = num_classes
self.device = device
# Use the same GCN backbone as the target model
self.gnn = GCN(input_dim, hidden_dim)
self.node_classifier = nn.Linear(hidden_dim, num_classes)
# Edge prediction layer
self.edge_predictor = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
[docs] def forward(self, graph, features):
# Get node embeddings
node_embeddings = self.gnn(graph, features)
# Node classification
node_logits = self.node_classifier(node_embeddings)
return node_embeddings, node_logits
[docs] def predict_edges(self, node_embeddings, node_pairs):
"""Predict edge existence probability"""
if len(node_pairs) == 0:
return torch.tensor([], device=self.device)
node_pairs = torch.tensor(node_pairs, device=self.device)
src_embeddings = node_embeddings[node_pairs[:, 0]]
dst_embeddings = node_embeddings[node_pairs[:, 1]]
# Concatenate source and destination node embeddings
edge_features = torch.cat([src_embeddings, dst_embeddings], dim=1)
edge_probs = self.edge_predictor(edge_features).squeeze()
return edge_probs
[docs]class DGLSurrogateModel(nn.Module):
"""DGL version of surrogate model"""
def __init__(self, input_dim, num_classes, model_type='GCN'):
super(DGLSurrogateModel, self).__init__()
self.model_type = model_type
if model_type == 'GCN':
self.gnn = GCN(input_dim, num_classes)
else:
# Can extend to other model types
self.gnn = GCN(input_dim, num_classes)
[docs] def forward(self, graph, features):
return self.gnn(graph, features)
[docs]class RealisticAttack(BaseAttack):
"""DGL-based GNN model extraction attack"""
supported_api_types = {"dgl"}
supported_datasets = {}
def __init__(self, dataset, attack_node_fraction: float, model_path: str = None,
hidden_dim: int = 64, threshold_s: float = 0.7, threshold_a: float = 0.5):
super().__init__(dataset, attack_node_fraction, model_path)
self.hidden_dim = hidden_dim
self.threshold_s = threshold_s # Cosine similarity threshold
self.threshold_a = threshold_a # Edge prediction threshold
self.attack_node_number = int(self.num_nodes * self.attack_node_fraction)
self.graph_data = self.graph_data.to(self.device)
self.graph = self.graph_data
self.features = self.graph.ndata['feat']
# Initialize edge predictor and surrogate model
self.edge_predictor = DGLEdgePredictor(
self.num_features, hidden_dim, self.num_classes, self.device
).to(self.device)
self.surrogate_model = DGLSurrogateModel(
self.num_features, self.num_classes
).to(self.device)
# net
self.net1 = GCN(self.num_features, self.num_classes).to(self.device)
# Optimizers
self.optimizer_edge = optim.Adam(self.edge_predictor.parameters(), lr=0.01, weight_decay=5e-4)
self.optimizer_surrogate = optim.Adam(self.surrogate_model.parameters(), lr=0.01, weight_decay=5e-4)
print(f"Initialized attack on {dataset.dataset_name} dataset")
print(f"Nodes: {self.num_nodes}, Features: {self.num_features}, Classes: {self.num_classes}")
print(f"Attack nodes: {self.attack_node_number} ({attack_node_fraction:.1%})")
[docs] def simulate_target_model_queries(self, query_nodes, error_rate=0.15):
"""Simulate target model queries with a certain proportion of incorrect labels"""
self.net1.eval()
with torch.no_grad():
logits = self.net1(self.graph, self.features)
predictions = F.log_softmax(logits, dim=1).argmax(dim=1)
# Get predicted labels for query nodes
predicted_labels = predictions[query_nodes].clone()
# Introduce incorrect labels
num_errors = int(len(predicted_labels) * error_rate)
if num_errors > 0:
error_indices = random.sample(range(len(predicted_labels)), num_errors)
for idx in error_indices:
# Randomly assign incorrect labels
wrong_label = random.randint(0, self.num_classes - 1)
predicted_labels[idx] = wrong_label
return predicted_labels
[docs] def compute_cosine_similarity(self, features):
"""Compute cosine similarity of node features"""
features_np = features.cpu().detach().numpy()
similarity_matrix = cosine_similarity(features_np)
return torch.tensor(similarity_matrix, dtype=torch.float32, device=self.device)
[docs] def generate_candidate_edges(self, labeled_nodes, unlabeled_nodes):
"""Generate candidate edge set"""
similarity_matrix = self.compute_cosine_similarity(self.features)
candidate_edges = []
for u_node in unlabeled_nodes:
for l_node in labeled_nodes:
if similarity_matrix[u_node, l_node] > self.threshold_s:
candidate_edges.append([u_node, l_node])
print(f"Generated {len(candidate_edges)} candidate edges based on cosine similarity")
return candidate_edges
[docs] def train_edge_predictor(self, labeled_nodes, predicted_labels, epochs=100):
"""Train edge prediction model"""
print("Training edge predictor...")
self.edge_predictor.train()
# Create training labels - only queried nodes have labels, others are -1
train_labels = torch.full((self.num_nodes,), -1, dtype=torch.long, device=self.device)
train_labels[labeled_nodes] = predicted_labels
for epoch in range(epochs):
self.optimizer_edge.zero_grad()
# Forward pass
node_embeddings, node_logits = self.edge_predictor(self.graph, self.features)
# Node classification loss (only for labeled nodes)
labeled_mask = train_labels != -1
if labeled_mask.sum() > 0:
node_loss = F.cross_entropy(node_logits[labeled_mask], train_labels[labeled_mask])
else:
node_loss = torch.tensor(0.0, device=self.device)
# Edge prediction loss
src_nodes, dst_nodes = self.graph.edges()
positive_pairs = list(zip(src_nodes.cpu().numpy(), dst_nodes.cpu().numpy()))
# Positive samples
pos_edge_probs = self.edge_predictor.predict_edges(node_embeddings, positive_pairs)
pos_loss = -torch.log(pos_edge_probs + 1e-15).mean()
# Negative samples
negative_pairs = []
num_neg_samples = min(len(positive_pairs), 1000) # Limit negative sample size
for _ in range(num_neg_samples):
src = random.randint(0, self.num_nodes - 1)
dst = random.randint(0, self.num_nodes - 1)
if src != dst and not self.graph_data.has_edges_between(src, dst):
negative_pairs.append([src, dst])
if negative_pairs:
neg_edge_probs = self.edge_predictor.predict_edges(node_embeddings, negative_pairs)
neg_loss = -torch.log(1 - neg_edge_probs + 1e-15).mean()
else:
neg_loss = torch.tensor(0.0, device=self.device)
# Total loss
total_loss = node_loss + 0.5 * (pos_loss + neg_loss)
total_loss.backward()
self.optimizer_edge.step()
if epoch % 20 == 0:
print(f"Epoch {epoch:3d}: Total Loss: {total_loss.item():.4f}, "
f"Node Loss: {node_loss.item():.4f}, Edge Loss: {(pos_loss + neg_loss).item():.4f}")
[docs] def add_potential_edges(self, candidate_edges, labeled_nodes):
"""Add potential edges based on edge prediction results"""
if not candidate_edges:
return self.graph
print("Predicting edge weights and adding potential edges...")
self.edge_predictor.eval()
with torch.no_grad():
node_embeddings, _ = self.edge_predictor(self.graph, self.features)
edge_probs = self.edge_predictor.predict_edges(node_embeddings, candidate_edges)
# Select edges with probability above threshold
selected_edges = []
for i, (src, dst) in enumerate(candidate_edges):
if edge_probs[i] > self.threshold_a:
selected_edges.extend([(src, dst), (dst, src)]) # Undirected graph
print(f"Selected {len(selected_edges) // 2} potential edges to add")
if selected_edges:
# Create new graph and add edges
enhanced_graph = dgl.add_edges(
self.graph,
[e[0] for e in selected_edges],
[e[1] for e in selected_edges]
)
return enhanced_graph
else:
return self.graph
[docs] def train_surrogate_model(self, enhanced_graph, labeled_nodes, predicted_labels, epochs=200):
"""Train surrogate model"""
print("Training surrogate model...")
self.surrogate_model.train()
# Create training labels
train_labels = torch.full((self.num_nodes,), -1, dtype=torch.long, device=self.device)
train_labels[labeled_nodes] = predicted_labels
labeled_mask = train_labels != -1
for epoch in range(epochs):
self.optimizer_surrogate.zero_grad()
# Forward pass
logits = self.surrogate_model(enhanced_graph, self.features)
if labeled_mask.sum() > 0:
logp = F.log_softmax(logits, dim=1)
loss = F.nll_loss(logp[labeled_mask], train_labels[labeled_mask])
loss.backward()
self.optimizer_surrogate.step()
if epoch % 50 == 0:
print(f"Surrogate Model Epoch {epoch:3d}, Loss: {loss.item():.4f}")
[docs] def evaluate_attack(self, enhanced_graph):
"""Evaluate attack performance"""
print("\nEvaluating attack performance...")
# Evaluate surrogate model
self.surrogate_model.eval()
with torch.no_grad():
surrogate_logits = self.surrogate_model(enhanced_graph, self.features)
surrogate_pred = F.log_softmax(surrogate_logits, dim=1).argmax(dim=1)
# Calculate accuracy on test set
test_acc = (surrogate_pred[self.graph_data.ndata['test_mask']] == self.graph_data.ndata['label'][
self.graph_data.ndata['test_mask']]).float().mean()
# Evaluate fidelity with target model
self.net1.eval()
with torch.no_grad():
target_logits = self.net1(self.graph, self.features)
target_pred = F.log_softmax(target_logits, dim=1).argmax(dim=1)
# Calculate fidelity (consistency on test set)
fidelity = (surrogate_pred[self.graph_data.ndata['test_mask']] == target_pred[
self.graph_data.ndata['test_mask']]).float().mean()
return test_acc.item(), fidelity.item(), surrogate_pred
[docs] def attack(self):
"""Execute model extraction attack"""
print("=" * 60)
print("Starting GNN Model Extraction Attack")
print("=" * 60)
# Step 1: Randomly select query nodes
all_nodes = list(range(self.num_nodes))
labeled_nodes = random.sample(all_nodes, self.attack_node_number)
unlabeled_nodes = [n for n in all_nodes if n not in labeled_nodes]
print(f"Selected {len(labeled_nodes)} nodes for querying")
# Step 2: Simulate target model queries
predicted_labels = self.simulate_target_model_queries(labeled_nodes)
print(f"Simulated target model queries with ~15% error rate")
# Step 3: Generate candidate edges
candidate_edges = self.generate_candidate_edges(labeled_nodes, unlabeled_nodes)
# Step 4: Train edge prediction model
self.train_edge_predictor(labeled_nodes, predicted_labels)
# Step 5: Add potential edges
enhanced_graph = self.add_potential_edges(candidate_edges, labeled_nodes)
original_edges = self.graph_data.num_edges()
enhanced_edges = enhanced_graph.num_edges()
print(f"Enhanced graph: {original_edges} -> {enhanced_edges} edges (+{enhanced_edges - original_edges})")
# Step 6: Train surrogate model
self.train_surrogate_model(enhanced_graph, labeled_nodes, predicted_labels)
# Step 7: Evaluate attack performance
test_accuracy, fidelity, predictions = self.evaluate_attack(enhanced_graph)
print("=" * 60)
print("Attack Results:")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Fidelity (vs Target Model): {fidelity:.4f}")
print("=" * 60)
return {
'test_accuracy': test_accuracy,
'fidelity': fidelity,
'enhanced_graph': enhanced_graph,
'predictions': predictions,
'labeled_nodes': labeled_nodes,
'predicted_labels': predicted_labels
}