import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_dense_adj, dense_to_sparse
from tqdm import tqdm
from pygip.models.nn.backbones import GCN_PyG
from .base import BaseDefense
class ImperceptibleWM(BaseDefense):
    """Watermarking defense that embeds imperceptible trigger nodes into the graph.

    A ``TriggerGenerator`` produces trigger-node features (carrying a fixed
    owner id in their last 5 feature dimensions) that are jointly optimized
    with the target GCN, so the model learns the watermark while keeping
    clean accuracy.
    """

    supported_api_types = {"pyg"}

    def __init__(self, dataset, attack_node_fraction=0.3, model_path=None):
        """
        Args:
            dataset: wrapper exposing ``graph_dataset``, ``graph_data``,
                ``num_features`` and ``num_classes`` (PyG-backed).
            attack_node_fraction (float): fraction of nodes assumed available
                to an attacker (kept for the BaseDefense contract).
            model_path (str | None): optional checkpoint to load into the
                target model instead of training from scratch.
        """
        super().__init__(dataset, attack_node_fraction)
        # load data (self.device is presumably set by BaseDefense — confirm)
        self.dataset = dataset
        self.graph_dataset = dataset.graph_dataset
        self.graph_data = dataset.graph_data.to(self.device)
        self.attack_node_fraction = attack_node_fraction
        self.model_path = model_path
        # Fixed 5-dimensional owner signature written into trigger features.
        self.owner_id = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9],
                                     dtype=torch.float32,
                                     device=self.graph_data.x.device)
        in_feats = dataset.num_features
        num_classes = dataset.num_classes
        self.generator = TriggerGenerator(in_feats, 64, self.owner_id).to(self.device)
        self.model = GCN_PyG(in_feats, 128, num_classes).to(self.device)

    def defend(self):
        """Run the bi-level watermark optimization and report final metrics."""
        pyg_data = self.graph_data
        bi_level_optimization(self.model, self.generator, pyg_data)
        trigger_data = generate_trigger_graph(pyg_data, self.generator, self.model)
        metrics = calculate_metrics(self.model, trigger_data)
        print("========================Final results:=========================================")
        for name, value in metrics.items():
            # wm_accuracy can be None when no trigger nodes are present;
            # guard the float format so reporting never crashes.
            if isinstance(value, float):
                print(f"{name}: {value:.4f}")
            else:
                print(f"{name}: {value}")
        return metrics

    def _load_model(self):
        """Load a pretrained target-model checkpoint, if one was given."""
        if self.model_path:
            # map_location keeps CPU-only hosts working with GPU-saved checkpoints.
            self.model.load_state_dict(
                torch.load(self.model_path, map_location=self.device))

    def _train_target_model(self):
        # optional if you split training from watermarking
        pass

    def _train_defense_model(self):
        return self.model

    def _train_surrogate_model(self):
        pass
class TriggerGenerator(nn.Module):
    """Two-layer GCN mapping node features to trigger features in [0, 1].

    The fixed owner id is stamped verbatim into the last five feature
    dimensions of every node's output.
    """

    def __init__(self, in_channels, hidden_channels, owner_id):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, in_channels)
        self.owner_id = owner_id

    def forward(self, x, edge_index):
        """Return per-node trigger features; last 5 dims carry the owner id."""
        hidden = self.conv1(x, edge_index).relu()
        features = self.conv2(hidden, edge_index).sigmoid()
        # Clone before stamping so the sigmoid output itself is untouched.
        stamped = features.clone()
        stamped[:, -5:] = self.owner_id
        return stamped
def generate_trigger_graph(data, generator, target_model, num_triggers=50):
    """Build a watermarked copy of ``data`` with ``num_triggers`` trigger nodes.

    One representative node per class (the node the target model is most
    confident about for that class) is selected, and every trigger node is
    connected to all representatives (both directions). Trigger nodes are
    labeled class 0 and excluded from the train/val/test masks.

    Args:
        data: PyG-style data object (x, edge_index, y, masks, num_nodes).
        generator: TriggerGenerator producing trigger-node features.
        target_model: model used only to pick the representative nodes.
        num_triggers (int): number of trigger nodes to append.

    Returns:
        A deep copy of ``data`` extended with trigger nodes plus bookkeeping
        attributes: ``trigger_nodes``, ``selected_nodes``, ``trigger_mask``,
        ``original_test_mask``.
    """
    # Node selection is a hard argmax over model confidences — no gradients.
    with torch.no_grad():
        probs = F.softmax(target_model(data.x, data.edge_index), dim=1)
        selected_nodes = []
        for class_idx in range(probs.size(1)):
            class_nodes = torch.where(data.y == class_idx)[0]
            if len(class_nodes) > 0:
                selected_nodes.append(
                    class_nodes[probs[class_nodes, class_idx].argmax()].item())

    # Keep the generator call differentiable: the bi-level generator update
    # backpropagates through these features.
    trigger_features = generator(data.x, data.edge_index)

    num_nodes = data.num_nodes
    trigger_nodes = list(range(num_nodes, num_nodes + num_triggers))
    total_nodes = num_nodes + num_triggers

    # Build the new edges directly instead of round-tripping through a dense
    # (N+T)^2 adjacency matrix, which costs O(N^2) memory on large graphs.
    device = data.edge_index.device
    sel = torch.tensor(selected_nodes, dtype=torch.long, device=device)
    trig = torch.tensor(trigger_nodes, dtype=torch.long, device=device)
    src = sel.repeat(num_triggers)                # s0 s1 s2 s0 s1 s2 ...
    dst = trig.repeat_interleave(sel.numel())     # t0 t0 t0 t1 t1 t1 ...
    # Undirected: add both (selected -> trigger) and (trigger -> selected).
    new_edges = torch.stack([torch.cat([src, dst]),
                             torch.cat([dst, src])], dim=0)

    new_data = copy.deepcopy(data)
    new_data.x = torch.cat([data.x, trigger_features[:num_triggers]], dim=0)
    new_data.edge_index = torch.cat([data.edge_index, new_edges], dim=1)
    new_data.y = torch.cat([
        data.y,
        # Watermark label: every trigger node is assigned class 0.
        torch.zeros(num_triggers, dtype=torch.long, device=data.y.device)
    ])
    # Trigger nodes never participate in ordinary train/val/test splits.
    mask_pad = torch.zeros(num_triggers, dtype=torch.bool, device=data.x.device)
    new_data.train_mask = torch.cat([data.train_mask, mask_pad])
    new_data.val_mask = torch.cat([data.val_mask, mask_pad])
    new_data.test_mask = torch.cat([data.test_mask, mask_pad])
    new_data.original_test_mask = data.test_mask.clone()
    # Bookkeeping for the watermark losses and verification.
    new_data.trigger_nodes = trigger_nodes
    new_data.selected_nodes = selected_nodes
    new_data.trigger_mask = torch.zeros(total_nodes, dtype=torch.bool,
                                        device=data.x.device)
    new_data.trigger_mask[trigger_nodes] = True
    return new_data
def bi_level_optimization(target_model, generator, data, epochs=100, inner_steps=5):
    """Jointly train the target model and the trigger generator.

    Each epoch alternates:
      1. ``inner_steps`` updates of the target model on the sum of a
         clean-node classification loss and a trigger-node classification loss.
      2. One generator update on a weighted sum of feature-similarity,
         trigger-classification, and owner-id losses.

    Args:
        target_model: GCN being watermarked (updated in place).
        generator: TriggerGenerator producing trigger features (updated in place).
        data: clean PyG-style graph.
        epochs (int): number of outer iterations.
        inner_steps (int): model updates per generator update.
    """
    optimizer_model = torch.optim.Adam(target_model.parameters(), lr=0.01)
    optimizer_gen = torch.optim.Adam(generator.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    for epoch in tqdm(range(epochs)):
        # --- inner loop: fit the target model to clean + trigger labels ---
        for _ in range(inner_steps):
            optimizer_model.zero_grad()
            # Rebuild the trigger graph each step so the model trains on the
            # generator's *current* trigger features.
            trigger_data = generate_trigger_graph(data, generator, target_model)
            out_clean = target_model(data.x, data.edge_index)
            out_trigger = target_model(trigger_data.x, trigger_data.edge_index)
            clean_loss = criterion(out_clean[data.train_mask], data.y[data.train_mask])
            trigger_loss = criterion(out_trigger[trigger_data.trigger_mask],
                                     trigger_data.y[trigger_data.trigger_mask])
            total_loss = clean_loss + trigger_loss
            # backward also deposits grads on the generator (trigger_data.x
            # depends on it); they are cleared by optimizer_gen.zero_grad() below.
            total_loss.backward()
            optimizer_model.step()
        # --- outer step: update the generator ---
        optimizer_gen.zero_grad()
        trigger_data = generate_trigger_graph(data, generator, target_model)
        orig_features = data.x[trigger_data.selected_nodes]
        trigger_features = trigger_data.x[trigger_data.trigger_nodes]
        # Negative mean pairwise cosine similarity: pull trigger features
        # toward the selected clean nodes' features (imperceptibility).
        sim_loss = -F.cosine_similarity(orig_features.unsqueeze(1),
                                        trigger_features.unsqueeze(0), dim=-1).mean()
        out = target_model(trigger_data.x, trigger_data.edge_index)
        trigger_loss = criterion(out[trigger_data.trigger_mask],
                                 trigger_data.y[trigger_data.trigger_mask])
        # Owner-id embedding loss over the last 5 feature dimensions.
        # NOTE(review): the generator writes owner_id into these dims verbatim,
        # so this term looks constant with zero gradient — confirm intent.
        owner_loss = F.binary_cross_entropy(
            trigger_data.x[trigger_data.trigger_nodes, -5:],
            generator.owner_id.expand(len(trigger_data.trigger_nodes), 5)
        )
        total_gen_loss = 0.4 * sim_loss + 0.4 * trigger_loss + 0.2 * owner_loss
        total_gen_loss.backward()
        optimizer_gen.step()
def calculate_metrics(model, data):
    """Evaluate ``model`` on ``data`` and, when present, on the watermark triggers.

    Returns:
        dict with ``accuracy``, macro ``precision``/``recall``/``f1`` over the
        test mask, and ``wm_accuracy`` over the trigger nodes (``None`` when
        ``data`` carries no triggers).
    """
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        true = data.y
    # Watermarked graphs carry the pre-expansion test mask; pad it with False
    # up to the expanded node count so it can index the full prediction tensor.
    if hasattr(data, 'original_test_mask'):
        test_mask = data.original_test_mask
        if test_mask.size(0) < pred.size(0):
            pad_len = pred.size(0) - test_mask.size(0)
            test_mask = torch.cat([test_mask,
                                   torch.zeros(pad_len, dtype=torch.bool,
                                               device=test_mask.device)])
    else:
        test_mask = data.test_mask
    y_true = true[test_mask].cpu()
    y_pred = pred[test_mask].cpu()
    metrics = {
        'accuracy': (pred[test_mask] == true[test_mask]).float().mean().item(),
        # zero_division=0 keeps macro averaging well-defined (and warning-free)
        # when some class is never predicted on the test set.
        'precision': precision_score(y_true, y_pred, average='macro', zero_division=0),
        'recall': recall_score(y_true, y_pred, average='macro', zero_division=0),
        'f1': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'wm_accuracy': None
    }
    if hasattr(data, 'trigger_nodes'):
        # Watermark accuracy: how often trigger nodes get their planted label.
        wm_mask = torch.zeros(data.x.size(0), dtype=torch.bool, device=data.x.device)
        wm_mask[data.trigger_nodes] = True
        wm_pred = pred[wm_mask]
        wm_true = true[wm_mask]
        metrics['wm_accuracy'] = (wm_pred == wm_true).float().mean().item()
    return metrics