import time
import dgl
import networkx as nx
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from torch_geometric.data import Data
from tqdm import tqdm
from pygip.models.defense.base import BaseDefense
from pygip.models.nn import GCN
from pygip.utils.metrics import GraphNeuralNetworkMetric
class SurviveWM(BaseDefense):
    """SurviveWM watermarking defense for graph neural networks.

    Trains a GCN on the original graph augmented with a small random
    Erdos-Renyi "trigger" graph whose nodes carry random labels.  A Soft
    Nearest Neighbor Loss (SNNL) regularizer entangles the watermark with
    the task distribution so it survives model extraction.  Ownership is
    later verified by measuring the model's accuracy on the trigger nodes.
    """

    # Only the DGL graph API is supported by this implementation.
    supported_api_types = {"dgl"}

    def __init__(self, dataset, attack_node_fraction, model_path=None):
        """
        Parameters
        ----------
        dataset : object
            Wrapper exposing ``graph_data`` (a DGLGraph carrying ``feat``,
            ``label``, ``train_mask`` and ``test_mask`` in ``ndata``) as
            well as ``num_features`` and ``num_classes``.
        attack_node_fraction : float
            Fraction of nodes assumed available to an attacker.  Stored for
            interface compatibility; not used by :meth:`defend`.
        model_path : str, optional
            Filesystem path of a pre-trained model consumed by
            :meth:`_load_model`.
        """
        super().__init__(dataset, attack_node_fraction)
        # load graph data
        self.dataset = dataset
        self.graph_dataset = dataset.graph_data
        # self.device is provided by BaseDefense.
        self.graph_data = dataset.graph_data.to(device=self.device)
        self.model_path = model_path
        self.graph = self.graph_data
        self.features = self.graph_data.ndata['feat']
        self.labels = self.graph_data.ndata['label']
        self.train_mask = self.graph_data.ndata['train_mask']
        self.test_mask = self.graph_data.ndata['test_mask']
        # load meta data
        self.feature_number = dataset.num_features
        self.label_number = dataset.num_classes
        # params
        self.attack_node_fraction = attack_node_fraction

    def _load_model(self):
        """Load a pre-trained GCN from ``self.model_path`` into ``self.net1``.

        Raises
        ------
        AssertionError
            If no ``model_path`` was supplied at construction time.
        """
        assert self.model_path is not None, "Please provide a pre-trained model"
        # Create the model
        self.net1 = GCN(self.feature_number, self.label_number).to(self.device)
        # Load the saved state dict
        self.net1.load_state_dict(torch.load(self.model_path, map_location=self.device))
        # Set to evaluation mode
        self.net1.eval()

    def _to_cpu(self, tensor):
        """Return ``tensor`` on the CPU so it can feed NumPy/sklearn ops.

        Tensors already on the CPU are returned unchanged.
        """
        if tensor.is_cuda:
            return tensor.cpu()
        return tensor

    # === Soft Nearest Neighbor Loss ===
    def snn_loss(self, x, y, T=0.5):
        """Compute the Soft Nearest Neighbor Loss over embeddings ``x``.

        Measures how entangled same-class points are relative to all points
        under a Gaussian similarity with temperature ``T``.  Lower values
        mean same-class points dominate each point's neighborhood.

        Parameters
        ----------
        x : torch.Tensor
            Embeddings, shape ``(N, D)``; L2-normalized internally.
        y : torch.Tensor
            Integer class labels, shape ``(N,)``.
        T : float
            Temperature of the similarity kernel.
        """
        x = F.normalize(x, p=2, dim=1)
        # Squared Euclidean distances between all pairs.
        dist_matrix = torch.cdist(x, x, p=2) ** 2
        eye = torch.eye(len(x), device=self.device).bool()
        sim = torch.exp(-dist_matrix / T)
        # mask_same[i, j] is True when nodes i and j share a label.
        mask_same = y.unsqueeze(1) == y.unsqueeze(0)
        # Exclude self-similarity from both numerator and denominator.
        sim = sim.masked_fill(eye, 0)
        denom = sim.sum(1)
        nom = (sim * mask_same.float()).sum(1)
        # Epsilons guard against log(0) / division by zero.
        loss = -torch.log(nom / (denom + 1e-10) + 1e-10).mean()
        return loss

    # === Trigger Graph Generator ===
    def generate_key_graph(self, num_nodes=10, edge_prob=0.3):
        """Build a random trigger graph used as the watermark key.

        Samples an Erdos-Renyi graph with small-magnitude random features
        and uniformly random labels.

        Returns
        -------
        torch_geometric.data.Data
            Trigger graph with ``x``, ``edge_index`` and ``y``.
        """
        trigger = nx.erdos_renyi_graph(num_nodes, edge_prob)
        edge_index = torch.tensor(list(trigger.edges), dtype=torch.long).t().contiguous()
        # An edgeless sample must still yield a well-shaped (2, 0) tensor.
        if edge_index.numel() == 0:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        # Small-scale features keep trigger nodes close to the data manifold.
        x = torch.randn((num_nodes, self.feature_number)) * 0.1
        label = torch.randint(0, self.label_number, (num_nodes,))
        return Data(x=x, edge_index=edge_index, y=label)

    # === Combine base and trigger ===
    def combine_with_trigger(self, base_graph, base_features, base_labels, trigger_data):
        """Append the trigger graph to the base graph as disjoint extra nodes.

        Trigger node/edge indices are offset by the base node count so the
        two graphs do not overlap; no edges connect base and trigger nodes.

        Returns
        -------
        tuple
            ``(combined_graph, combined_labels)`` where ``combined_graph``
            is a DGLGraph (with self-loops) holding features in
            ``ndata['feat']``.
        """
        # Convert DGL graph to edge_index format
        src, dst = base_graph.edges()
        base_edge_index = torch.stack([src, dst], dim=0)
        x = torch.cat([base_features, trigger_data.x], dim=0)
        # Offset trigger edges past the base graph's node range.
        edge_index = torch.cat([base_edge_index, trigger_data.edge_index + base_features.size(0)], dim=1)
        y = torch.cat([base_labels, trigger_data.y], dim=0)
        # Create DGL graph from combined data
        src_combined, dst_combined = edge_index[0], edge_index[1]
        combined_graph = dgl.graph((src_combined, dst_combined), num_nodes=x.size(0)).to(self.device)
        # Self-loops prevent GCN failures on zero in-degree nodes.
        combined_graph = dgl.add_self_loop(combined_graph)
        combined_graph.ndata['feat'] = x.to(self.device)
        return combined_graph, y.to(self.device)

    def train_with_snnl(self, model, graph, features, labels, train_mask, optimizer, T=0.5, alpha=0.1):
        """Run one optimization step of NLL loss regularized by SNNL.

        The SNNL term is *subtracted* (maximized) so that watermark nodes
        become entangled with task data, following the entangled-watermark
        training objective.

        Returns
        -------
        float
            Scalar value of the combined loss for this step.
        """
        model.train()
        optimizer.zero_grad()
        out = model(graph, features)
        loss_nll = F.nll_loss(F.log_softmax(out, dim=1)[train_mask], labels[train_mask])
        snnl = self.snn_loss(out[train_mask], labels[train_mask], T=T)
        # Minimizing (nll - alpha * snnl) maximizes class entanglement.
        loss = loss_nll - alpha * snnl
        loss.backward()
        optimizer.step()
        return loss.item()

    @torch.no_grad()
    def verify_watermark(self, model, trigger_graph, trigger_labels):
        """Return the model's accuracy on the trigger (watermark) nodes."""
        model.eval()
        out = model(trigger_graph, trigger_graph.ndata['feat'])
        pred = out.argmax(dim=1)
        return (pred == trigger_labels).float().mean().item()

    def compute_metrics(self, y_true, y_pred, y_score=None):
        """Compute classification metrics on NumPy arrays.

        Parameters
        ----------
        y_true, y_pred : array-like
            Ground-truth and predicted class indices.
        y_score : array-like, optional
            Per-class probabilities; enables AUROC when provided.

        Returns
        -------
        dict
            accuracy / macro-f1 / macro-precision / macro-recall / auroc
            (``auroc`` is ``None`` when unavailable or undefined).
        """
        if y_score is not None:
            try:
                auroc = roc_auc_score(y_true, y_score, multi_class='ovo')
            except ValueError:
                # AUROC is undefined when a class is absent from y_true
                # (frequent on small test splits); degrade gracefully.
                auroc = None
        else:
            auroc = None
        return {
            'accuracy': accuracy_score(y_true, y_pred),
            'f1': f1_score(y_true, y_pred, average='macro'),
            'precision': precision_score(y_true, y_pred, average='macro'),
            'recall': recall_score(y_true, y_pred, average='macro'),
            'auroc': auroc
        }

    def defend(self):
        """Train a watermarked GCN and report clean/watermark performance.

        Builds a trigger graph, merges it with the base graph, trains with
        the SNNL-regularized objective for 200 epochs, and periodically
        evaluates both clean test accuracy and watermark accuracy.

        Returns
        -------
        tuple
            ``(final_metrics, watermarked_model)`` — a
            ``GraphNeuralNetworkMetric`` (``fidelity`` holds the watermark
            accuracy) and the trained model (also kept as ``self.net2``).
        """
        print("=========SurviveWM Attack==========================")
        # Generate trigger graph
        trigger_data = self.generate_key_graph().to(self.device)
        # Combine base graph with trigger
        combined_graph, combined_labels = self.combine_with_trigger(
            self.graph, self.features, self.labels, trigger_data)
        # Create train mask for combined data (include trigger nodes in training)
        base_train_mask = self.train_mask
        trigger_train_mask = torch.ones(trigger_data.num_nodes, dtype=torch.bool, device=self.device)
        combined_train_mask = torch.cat([base_train_mask, trigger_train_mask])
        # Create test mask for combined data (exclude trigger nodes from testing)
        base_test_mask = self.test_mask
        trigger_test_mask = torch.zeros(trigger_data.num_nodes, dtype=torch.bool, device=self.device)
        combined_test_mask = torch.cat([base_test_mask, trigger_test_mask])
        # Create watermarked model
        watermarked_model = GCN(self.feature_number, self.label_number).to(self.device)
        optimizer = torch.optim.Adam(watermarked_model.parameters(), lr=0.01, weight_decay=5e-4)
        # Create trigger graph for watermark verification
        trigger_src, trigger_dst = trigger_data.edge_index[0], trigger_data.edge_index[1]
        trigger_graph = dgl.graph((trigger_src, trigger_dst), num_nodes=trigger_data.num_nodes).to(self.device)
        # Self-loops guard against zero in-degree trigger nodes as well.
        trigger_graph = dgl.add_self_loop(trigger_graph)
        trigger_graph.ndata['feat'] = trigger_data.x.to(self.device)
        dur = []
        best_performance_metrics = GraphNeuralNetworkMetric()
        print("Training watermarked model...")
        for epoch in tqdm(range(200)):
            # Skip the first few epochs when timing (warm-up).
            if epoch >= 3:
                t0 = time.time()
            # Train with SNNL
            loss = self.train_with_snnl(
                watermarked_model, combined_graph, combined_graph.ndata['feat'],
                combined_labels, combined_train_mask, optimizer)
            if epoch >= 3:
                dur.append(time.time() - t0)
            # Evaluate periodically
            if epoch % 20 == 0:
                watermarked_model.eval()
                with torch.no_grad():
                    # Test on original graph (ensure it has self-loops)
                    test_graph = dgl.add_self_loop(self.graph)
                    logits = watermarked_model(test_graph, self.features)
                    pred = logits.argmax(dim=1)
                    test_acc = (pred[self.test_mask] == self.labels[self.test_mask]).float().mean()
                    # Verify watermark
                    wm_acc = self.verify_watermark(watermarked_model, trigger_graph, trigger_data.y)
                    print(f"Epoch {epoch}, Test Acc: {test_acc.item():.4f}, Watermark Acc: {wm_acc:.4f}")
        # Final evaluation
        watermarked_model.eval()
        with torch.no_grad():
            # Evaluate on test set (ensure graph has self-loops)
            test_graph = dgl.add_self_loop(self.graph)
            logits = watermarked_model(test_graph, self.features)
            pred = logits.argmax(dim=1)
            probs = F.softmax(logits, dim=1)
            test_metrics = self.compute_metrics(
                self._to_cpu(self.labels[self.test_mask]).numpy(),
                self._to_cpu(pred[self.test_mask]).numpy(),
                self._to_cpu(probs[self.test_mask]).numpy()
            )
            # Verify watermark
            wm_acc = self.verify_watermark(watermarked_model, trigger_graph, trigger_data.y)
        # Create custom metrics object
        final_metrics = GraphNeuralNetworkMetric()
        final_metrics.accuracy = test_metrics['accuracy']
        final_metrics.fidelity = wm_acc  # Use watermark accuracy as fidelity measure
        print("========================Final results:=========================================")
        print(f"Test Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"Test F1: {test_metrics['f1']:.4f}")
        print(f"Test Precision: {test_metrics['precision']:.4f}")
        print(f"Test Recall: {test_metrics['recall']:.4f}")
        print(f"Watermark Accuracy: {wm_acc:.4f}")
        print(final_metrics)
        self.net2 = watermarked_model
        return final_metrics, watermarked_model