import copy
import random
import dgl
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch_geometric.utils import to_networkx, from_networkx
from tqdm import tqdm
from pygip.models.nn import GCN
from .base import BaseDefense
class QueryBasedVerificationDefense(BaseDefense):
    """Query-based integrity verification of a GCN through fingerprint nodes.

    A clean target model is trained, fingerprints ``(graph, node_id, label)``
    are recorded from its predictions, an attack is applied, and the
    fingerprints are queried again; a changed prediction counts as a "flip".

    ``self.dataset``, ``self.graph_data`` and ``self.device`` are read
    throughout but never assigned here -- presumably ``BaseDefense.__init__``
    sets them (confirm against the base class).
    """

    supported_api_types = {"dgl"}
    supported_datasets = {}

    def __init__(self, dataset, attack_node_fraction, model_path=None):
        super().__init__(dataset, attack_node_fraction)
        self.model_path = model_path

    def defend(self, fingerprint_mode='inductive', knowledge='full', attack_type='bitflip',
               k=5, num_trials=10, use_edge_perturbation=False, verbose=True, **kwargs):
        """
        Main defense routine. Generates fingerprints, runs attacks, and verifies integrity.

        Args:
            fingerprint_mode: 'transductive' or 'inductive'.
            knowledge: 'full' or 'limited'; forwarded to fingerprinting and the attack.
            attack_type: 'bitflip', 'random' or 'mettack'.
            k: number of fingerprints to generate.
            num_trials: number of independent train/attack/verify repetitions.
            use_edge_perturbation: forwarded as ``perturb_fingerprints``.
            verbose: print per-trial metrics.
            **kwargs: extra options (``bit``, ``bfa_variant``, ``perturb_budget``, ...).

        Returns:
            dict with ``trial_results`` (list of per-trial dicts) and the
            averages of flip rate, accuracy drop and detection rate
            (0.0 when no trial was run).
        """
        # Resolve the option sets once, without mutating ``kwargs``: popping
        # inside the loop made every trial after the first fall back to the
        # default bit / BFA variant.
        bit = kwargs.get('bit', 30)
        bfa_variant = kwargs.get('bfa_variant', 'BFA')
        attack_kwargs = {name: value for name, value in kwargs.items()
                         if name not in ('bit', 'bfa_variant')}
        # Merge explicit options into one dict so a caller-supplied
        # ``perturb_budget`` cannot collide with the explicit keyword.
        fingerprint_kwargs = dict(kwargs)
        fingerprint_kwargs['perturb_fingerprints'] = use_edge_perturbation
        fingerprint_kwargs.setdefault('perturb_budget', 5)

        trial_results = []
        for trial in range(num_trials):
            if verbose:
                print(f"\n=== Trial {trial + 1}/{num_trials} ===")
            model_clean = self._train_target_model()
            acc_clean = self._evaluate_accuracy(model_clean, self.dataset)
            fingerprints = self._generate_fingerprints(
                model_clean, mode=fingerprint_mode, knowledge=knowledge, k=k,
                **fingerprint_kwargs)
            poisoned_model, attack_info = self._run_attack(
                model_clean,
                attack_type=attack_type,
                knowledge=knowledge,
                bit=bit,
                bfa_variant=bfa_variant,
                **attack_kwargs
            )
            # Graph-poisoning attacks report the poisoned graph; evaluate on it.
            poisoned_dataset = copy.deepcopy(self.dataset)
            if 'graph' in attack_info:
                poisoned_dataset.graph_data = attack_info['graph']
            acc_poisoned = self._evaluate_accuracy(poisoned_model, poisoned_dataset)
            flipped_info = self._evaluate_fingerprints(poisoned_model, fingerprints)
            flip_rate = flipped_info['flip_rate']
            acc_drop = acc_clean - acc_poisoned
            num_flipped = len(flipped_info['flipped'])
            num_total = len(fingerprints)
            detection_rate = num_flipped / num_total if num_total > 0 else 0.0
            if verbose:
                print(f"Clean Accuracy: {acc_clean:.4f}")
                print(f"Poisoned Accuracy: {acc_poisoned:.4f}")
                print(f"Accuracy Drop: {acc_drop:.4f}")
                print(f"Flip Rate: {flip_rate:.4f}")
                print(f"Detection Rate: {detection_rate:.4f}")
            trial_results.append({
                'flip_rate': flip_rate,
                'accuracy_drop': acc_drop,
                'detection_rate': detection_rate
            })

        completed = len(trial_results)

        def average(metric):
            # Guard against num_trials <= 0 (previously a ZeroDivisionError).
            if not completed:
                return 0.0
            return sum(result[metric] for result in trial_results) / completed

        return {
            'trial_results': trial_results,
            'average_flip_rate': average('flip_rate'),
            'average_accuracy_drop': average('accuracy_drop'),
            'average_detection_rate': average('detection_rate')
        }

    def _get_features(self):
        """Return node features: ``ndata['feat']`` when the graph exposes ``ndata``, else ``.x``."""
        return self.graph_data.ndata['feat'] if hasattr(self.graph_data, 'ndata') else self.graph_data.x

    def _train_target_model(self, epochs=200):
        """
        Trains target GCN model according to protocol in
        Wu et al. (2023), Section 6.1 for graph node classification.

        Args:
            epochs: number of full-batch training epochs.

        Returns:
            model : torch.nn.Module
                The trained GCN model.
        """
        model = GCN(
            feature_number=self.dataset.num_features,
            label_number=self.dataset.num_classes
        ).to(self.device)
        print(f"Training target model on device: {self.device} ...")
        optimizer = Adam(model.parameters(), lr=0.02)
        loss_fn = torch.nn.NLLLoss()

        # Move the graph once instead of once (or twice) per epoch.
        graph = self.graph_data.to(self.device)
        features = self._get_features().to(self.device)
        node_data = self.dataset.graph_data.ndata
        labels = node_data['label'].to(self.device)
        train_mask = node_data['train_mask'].to(self.device)
        # ``ndata`` is a mapping, so the key must be looked up with ``in``;
        # ``getattr(ndata, "val_mask", None)`` never found it and validation
        # silently always ran on the test mask.
        mask_key = 'val_mask' if 'val_mask' in node_data else 'test_mask'
        val_mask = node_data[mask_key].to(self.device)

        for epoch in range(epochs):
            model.train()
            log_probs = F.log_softmax(model(graph, features), dim=1)
            loss = loss_fn(log_probs[train_mask], labels[train_mask])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Report on the first epoch and then every tenth.
            if (epoch + 1) % 10 == 0 or epoch == 0:
                model.eval()
                with torch.no_grad():
                    val_log_probs = F.log_softmax(model(graph, features), dim=1)
                    val_pred = val_log_probs[val_mask].max(1)[1]
                    val_acc = (val_pred == labels[val_mask]).float().mean().item()
                print(f"Epoch {epoch + 1}: Loss={loss.item():.4f} | Val Acc={val_acc:.4f}")
        return model

    def _load_model(self, model_path):
        """Rebuild the GCN and load weights saved at ``model_path``.

        The model is constructed exactly like in ``_train_target_model`` so the
        state dict matches; the previous keyword names (``in_feats`` /
        ``hidden_feats`` / ``out_feats``) and dataset attributes did not agree
        with any other ``GCN(...)`` call in this module.
        """
        model = GCN(
            feature_number=self.dataset.num_features,
            label_number=self.dataset.num_classes
        )
        # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        return model.to(self.device)

    def _generate_fingerprints(self, model, mode='transductive', knowledge='full', k=5, **kwargs):
        """
        Wrapper for fingerprint generation based on mode and knowledge level.

        Args:
            model: model whose predictions define the fingerprint labels.
            mode: 'transductive' or 'inductive'.
            knowledge: 'full' or 'limited'.
            k: number of fingerprints.
            **kwargs: ``candidate_fraction``, ``random_seed``, ``randomize``,
                ``perturb_fingerprints``, ``perturb_budget``; others are ignored.

        Returns:
            List of ``(graph, node_id, label)`` fingerprints.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        randomize = kwargs.get('randomize', True)
        random_seed = kwargs.get('random_seed', None)

        if mode == 'transductive':
            generator = TransductiveFingerprintGenerator(
                model=model,
                dataset=self.dataset,
                candidate_fraction=kwargs.get('candidate_fraction', 1.0),
                random_seed=random_seed,
                device=self.device,
                randomize=randomize,
            )
            pairs = generator.generate_fingerprints(k=k, method=knowledge)
            # Transductive fingerprints all live on the original graph.
            return [(self.graph_data, node_id, label) for (node_id, label) in pairs]

        if mode == 'inductive':
            perturb = kwargs.get('perturb_fingerprints', False)
            budget = kwargs.get('perturb_budget', 5)
            generator = InductiveFingerprintGenerator(
                model=model,
                dataset=self.dataset,
                shadow_graph=self.dataset.graph_data,
                knowledge=knowledge,
                candidate_fraction=kwargs.get('candidate_fraction', 0.3),
                num_fingerprints=k,
                randomize=randomize,
                random_seed=random_seed,
                device=self.device,
                perturb_fingerprints=perturb,
                perturb_budget=budget,
            )
            fingerprints = generator.generate_fingerprints(method=knowledge)
            if perturb:
                # Each fingerprint gets its own edge-perturbed copy of its graph.
                for position, (graph, node_idx, label) in enumerate(fingerprints):
                    generator.shadow_graph = graph
                    generator.greedy_edge_perturbation(
                        node_idx=node_idx,
                        perturb_budget=budget,
                        knowledge=knowledge
                    )
                    fingerprints[position] = (generator.shadow_graph, node_idx, label)
            return fingerprints

        raise ValueError("Unknown fingerprinting mode. Use 'transductive' or 'inductive'.")

    def _evaluate_fingerprints(self, model, fingerprints):
        """
        Checks if fingerprinted nodes have changed labels under the given model.

        Args:
            model: The model to evaluate.
            fingerprints: List of (graph, node_id, label) tuples.

        Returns:
            results: {
                'flipped': List[Tuple[node_id, old_label, new_label]],
                'flip_rate': float
            }
        """
        model.eval()
        flipped = []
        with torch.no_grad():
            for graph, node_id, expected_label in fingerprints:
                x = graph.ndata['feat'] if hasattr(graph, 'ndata') else graph.x
                logits = model(graph.to(self.device), x.to(self.device))
                pred = logits[node_id].argmax().item()
                if pred != expected_label:
                    flipped.append((node_id, expected_label, pred))
        return {
            'flipped': flipped,
            'flip_rate': len(flipped) / len(fingerprints) if fingerprints else 0.0
        }

    def _run_attack(self, model, attack_type='mettack', knowledge='full', **kwargs):
        """
        Run the specified attack on the model.

        Args:
            model: clean model. The 'bitflip' attack mutates it in place.
            attack_type: 'bitflip', 'random' or 'mettack'.
            knowledge: accepted for signature symmetry; not used here.
            **kwargs: attack-specific options read with ``kwargs.get``.

        Returns:
            poisoned_model: torch.nn.Module
            metadata: dict with info about the attack ('graph' holds the
                poisoned graph for the graph-poisoning attacks).

        Raises:
            ValueError: for an unsupported ``attack_type``.
        """
        if attack_type == 'bitflip':
            bit = kwargs.get('bit', 30)
            bfa_variant = kwargs.get('bfa_variant', 'BFA')
            attacker = BitFlipAttack(model, attack_type=bfa_variant, bit=bit)
            attack_info = attacker.apply()
            return model, attack_info

        if attack_type == 'random':
            perturbed_graph = self._random_edge_addition_poisoning(
                node_fraction=kwargs.get('node_fraction', 0.1),
                edges_per_node=kwargs.get('edges_per_node', 5),
                random_seed=kwargs.get('random_seed', None),
            )
            poisoned_model = self._retrain_poisoned_model(
                poisoned_graph=perturbed_graph,
                epochs=kwargs.get('epochs', 200),
            )
            return poisoned_model, {'type': 'random_poison', 'graph': perturbed_graph}

        if attack_type == 'mettack':
            num_edges = self.graph_data.num_edges()
            poison_frac = kwargs.get('poison_frac', 0.05)
            n_perturbations = int(poison_frac * num_edges)
            helper = MettackHelper(
                graph=self.graph_data,
                features=self._get_features(),
                labels=self.dataset.labels,
                train_mask=self.dataset.train_mask,
                val_mask=getattr(self.dataset, 'val_mask', None),
                test_mask=self.dataset.test_mask,
                n_perturbations=n_perturbations,
                device=self.device,
                max_perturbations=kwargs.get('max_perturbations', 50),
                surrogate_epochs=kwargs.get('surrogate_epochs', 30),
                candidate_sample_size=kwargs.get('candidate_sample_size', 20),
            )
            poisoned_graph, attack_metrics = helper.run()
            poisoned_model = self._retrain_poisoned_model(
                poisoned_graph=poisoned_graph,
                epochs=kwargs.get('epochs', 200),
            )
            return poisoned_model, {'type': 'mettack', 'metrics': attack_metrics, 'graph': poisoned_graph}

        raise ValueError(f"Unsupported attack_type: {attack_type}")

    def _random_edge_addition_poisoning(self, node_fraction=0.1, edges_per_node=5, random_seed=None):
        """
        Poison a fraction of nodes by adding random edges.

        Args:
            node_fraction: Fraction of nodes to poison (e.g., 0.1 = 10%)
            edges_per_node: Number of random edge draws per poisoned node
            random_seed: Optional seed (seeds ``random`` and ``torch`` globally)

        Returns:
            poisoned_graph: deep copy of ``self.graph_data`` with the new edges
            added in both directions.
        """
        if random_seed is not None:
            random.seed(random_seed)
            torch.manual_seed(random_seed)
        poisoned_graph = copy.deepcopy(self.graph_data)
        num_nodes = poisoned_graph.num_nodes()
        num_poisoned_nodes = int(node_fraction * num_nodes)
        poisoned_nodes = random.sample(range(num_nodes), num_poisoned_nodes)

        new_edges = []
        # Edges are only inserted after the loop, so the graph lookup alone
        # cannot see pairs drawn earlier; track them to avoid parallel edges.
        scheduled = set()
        for src in poisoned_nodes:
            for _ in range(edges_per_node):
                dst = random.randint(0, num_nodes - 1)
                if src == dst:
                    continue
                pair = (src, dst) if src < dst else (dst, src)
                if pair in scheduled:
                    continue
                if poisoned_graph.has_edges_between(src, dst) or \
                        poisoned_graph.has_edges_between(dst, src):
                    continue
                scheduled.add(pair)
                new_edges.append((src, dst))
                new_edges.append((dst, src))
        if new_edges:
            src, dst = zip(*new_edges)
            poisoned_graph.add_edges(src, dst)
        return poisoned_graph

    def _retrain_poisoned_model(self, poisoned_graph, epochs=200):
        """
        Retrain target GCN using the poisoned graph structure.

        A deep copy of ``self.dataset`` gets ``poisoned_graph`` as its graph and
        a fresh defense instance trains on it, so ``self`` is left untouched.

        Args:
            poisoned_graph: graph to train on instead of the clean one.
            epochs: number of training epochs.

        Returns:
            model: Trained GCN model
        """
        dataset_poisoned = copy.deepcopy(self.dataset)
        dataset_poisoned.graph_data = poisoned_graph
        defense = QueryBasedVerificationDefense(dataset=dataset_poisoned, attack_node_fraction=0.1)
        return defense._train_target_model(epochs=epochs)

    def _evaluate_accuracy(self, model, dataset):
        """
        Evaluates test accuracy of the given model on the dataset.

        Args:
            model: Trained GCN model
            dataset: Dataset object (provides labels, test_mask, graph); the
                features come from ``self._get_features()``.

        Returns:
            accuracy: float (test accuracy, 0-1); 0.0 when the test mask
            selects no node.
        """
        model.eval()
        node_data = dataset.graph_data.ndata
        features = self._get_features().to(self.device)
        labels = node_data['label'].to(self.device)
        test_mask = node_data['test_mask'].to(self.device)
        with torch.no_grad():
            logits = model(dataset.graph_data.to(self.device), features)
        pred = logits.argmax(dim=1)
        num_test = test_mask.sum().item()
        if num_test == 0:
            # Avoid ZeroDivisionError on an empty test split.
            return 0.0
        correct = (pred[test_mask] == labels[test_mask]).float()
        return correct.sum().item() / num_test

    def run_full_pipeline(self, attack_type='random', mode='transductive', knowledge='full', k=5, trials=1, **kwargs):
        """
        Runs the full fingerprinting + attack + evaluation pipeline.

        Parameters:
            attack_type: 'random', 'bitflip', or 'mettack'
            mode: 'transductive' or 'inductive'
            knowledge: 'full' or 'limited'
            k: number of fingerprints
            trials: number of repeated trials
            kwargs: extra params for attack or fingerprinting

        Prints per-trial results and summary statistics; returns nothing.
        Both accuracies are measured on ``self.dataset``.
        """
        flip_rates = []
        acc_drops = []
        for trial in range(trials):
            print(f"\n=== Trial {trial + 1}/{trials} ===")
            model_clean = self._train_target_model()
            acc_clean = self._evaluate_accuracy(model_clean, self.dataset)
            print(f"Clean model accuracy: {acc_clean:.4f}")
            fingerprints = self._generate_fingerprints(model_clean, mode=mode, knowledge=knowledge, k=k, **kwargs)
            model_poisoned, attack_meta = self._run_attack(model_clean, attack_type=attack_type, knowledge=knowledge,
                                                           **kwargs)
            acc_poisoned = self._evaluate_accuracy(model_poisoned, self.dataset)
            print(f"Poisoned model accuracy: {acc_poisoned:.4f}")
            eval_result = self._evaluate_fingerprints(model_poisoned, fingerprints)
            flip_rate = eval_result['flip_rate']
            print(f"Fingerprint flip rate: {flip_rate:.4f}")
            for (nid, old, new) in eval_result['flipped']:
                print(f" Node {nid}: {old} → {new}")
            flip_rates.append(flip_rate)
            acc_drops.append(acc_clean - acc_poisoned)
        print("\n=== Summary ===")
        print(f"Avg Accuracy Drop: {np.mean(acc_drops):.4f}")
        print(f"Avg Fingerprint Flip Rate: {np.mean(flip_rates):.4f}")
class TransductiveFingerprintGenerator:
    """Selects fingerprint nodes directly on the dataset graph.

    Candidates are scored either by the squared gradient norm of the loss with
    respect to the model parameters ('full') or by ``1 - confidence``
    ('limited'); high-scoring nodes become ``(node_id, predicted_label)``
    fingerprints.
    """

    def __init__(self, model, dataset, candidate_fraction=0.3, random_seed=None, device='cpu', randomize=True):
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.dataset = dataset
        self.graph_data = dataset.graph_data
        self.candidate_fraction = candidate_fraction
        self.random_seed = random_seed
        self.randomize = randomize

    def _get_features(self):
        """Backend-agnostic feature getter (DGL or PyG)."""
        return self.graph_data.ndata['feat'] if hasattr(self.graph_data, 'ndata') else self.graph_data.x

    def get_candidate_nodes(self):
        """Randomly sample a subset of nodes as candidates.

        Returns every node when ``randomize`` is off or the fraction is >= 1.0;
        otherwise ``max(1, int(n * candidate_fraction))`` distinct node ids
        (a CPU tensor), reproducible when ``random_seed`` is set.
        """
        all_nodes = torch.arange(self.graph_data.num_nodes())
        if not (self.randomize and self.candidate_fraction < 1.0):
            return all_nodes
        num_candidates = max(1, int(len(all_nodes) * self.candidate_fraction))
        # randperm produces a CPU tensor here, so the generator must be a CPU
        # generator too; a CUDA generator raises a device-mismatch error.
        generator = torch.Generator()
        if self.random_seed is not None:
            generator.manual_seed(self.random_seed)
        else:
            # A fresh Generator starts from a fixed default seed; without this
            # every "random" draw would return the same subset.
            generator.seed()
        idx = torch.randperm(len(all_nodes), generator=generator)[:num_candidates]
        return all_nodes[idx]

    def compute_fingerprint_scores_full(self, candidate_nodes):
        """Full-knowledge fingerprint scores (gradient-based).

        For each candidate, the cross-entropy of its logits against its own
        predicted label is back-propagated and the squared norm of all
        parameter gradients is the score. Returns a tensor on ``self.device``.
        """
        self.model.eval()
        x = self._get_features().to(self.device)
        logits = self.model(self.graph_data.to(self.device), x)
        scores = []
        for node in candidate_nodes:
            self.model.zero_grad()
            node_logits = logits[node]
            label = node_logits.argmax().item()
            loss = F.cross_entropy(node_logits.unsqueeze(0), torch.tensor([label], device=self.device))
            # One forward pass is shared by all candidates, so its autograd
            # graph has to survive every backward call.
            loss.backward(retain_graph=True)
            grad_norm = sum((p.grad ** 2).sum().item() for p in self.model.parameters() if p.grad is not None)
            scores.append(grad_norm)
        return torch.tensor(scores, device=self.device)

    def compute_fingerprint_scores_limited(self, candidate_nodes):
        """Limited-knowledge fingerprint scores (confidence margin): ``1 - max softmax probability``."""
        self.model.eval()
        x = self._get_features().to(self.device)
        with torch.no_grad():
            logits = self.model(self.graph_data.to(self.device), x)
        probs = F.softmax(logits, dim=1)
        labels = probs.argmax(dim=1)
        return 1.0 - probs[candidate_nodes, labels[candidate_nodes]]

    def select_top_fingerprints(self, scores, candidate_nodes, k, method='full'):
        """Selects top-k fingerprint nodes after filtering out extreme score outliers.

        For 'full', scores above the 0.99 quantile are dropped first; other
        methods keep everything. ``k`` is clipped to the number of survivors.

        Returns:
            (selected candidate ids, their scores), highest score first.
        """
        q = 0.99 if method == 'full' else 1.0
        keep = scores <= torch.quantile(scores, q)
        kept_scores = scores[keep]
        kept_candidates = candidate_nodes[keep]
        k = min(k, kept_scores.size(0))
        topk = torch.topk(kept_scores, k)
        return kept_candidates[topk.indices], topk.values

    def generate_fingerprints(self, k=5, method='full'):
        """Return up to ``k`` fingerprints as ``(node_id, predicted_label)`` pairs.

        The best-scoring candidate of each predicted class is taken, visiting
        classes in a shuffled order (seeded by ``random_seed``). If that yields
        fewer than ``k`` fingerprints, the result is replaced by the global
        top-k from ``select_top_fingerprints``.

        Raises:
            ValueError: if ``method`` is neither 'full' nor 'limited'.
        """
        candidate_nodes = self.get_candidate_nodes().to(self.device)
        x = self._get_features().to(self.device)
        with torch.no_grad():
            logits = self.model(self.graph_data.to(self.device), x)
        labels = logits.argmax(dim=1)

        if method == 'full':
            scores = self.compute_fingerprint_scores_full(candidate_nodes)
        elif method == 'limited':
            scores = self.compute_fingerprint_scores_limited(candidate_nodes)
        else:
            raise ValueError("method must be 'full' or 'limited'")

        # Group (node, score) pairs by predicted class.
        class_to_candidates = {}
        for position, node in enumerate(candidate_nodes):
            cls = int(labels[node])
            class_to_candidates.setdefault(cls, []).append((node.item(), scores[position].item()))

        rng = random.Random(self.random_seed)
        class_list = list(class_to_candidates.keys())
        rng.shuffle(class_list)

        fingerprints = []
        for cls in class_list:
            # max() keeps the first of several equally scored nodes.
            top_node, _ = max(class_to_candidates[cls], key=lambda pair: pair[1])
            fingerprints.append((top_node, cls))
            if len(fingerprints) >= k:
                break

        if len(fingerprints) < k:
            # Fewer classes than requested fingerprints: fall back to global top-k.
            fingerprint_nodes, _ = self.select_top_fingerprints(scores, candidate_nodes, k, method=method)
            fingerprints = [(int(n), int(labels[n])) for n in fingerprint_nodes]
        return fingerprints
class InductiveFingerprintGenerator:
    """Generates fingerprints ``(graph, node_id, label)`` on a shadow graph.

    Nodes are ranked by a fingerprint score (input-gradient norm for 'full'
    knowledge, ``1 - confidence`` for 'limited'). Optionally the selected
    nodes' features, and via ``greedy_edge_perturbation`` their incident
    edges, are greedily perturbed to raise that score while the predicted
    label stays fixed.
    """

    def __init__(self, model, dataset, shadow_graph=None, knowledge='limited',
                 candidate_fraction=0.3, num_fingerprints=5,
                 randomize=True, random_seed=None, device='cpu',
                 perturb_fingerprints=False, perturb_budget=5):
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.dataset = dataset
        self.shadow_graph = shadow_graph if shadow_graph is not None else dataset.graph_data
        self.knowledge = knowledge
        self.candidate_fraction = candidate_fraction
        self.num_fingerprints = num_fingerprints
        self.randomize = randomize
        self.random_seed = random_seed
        self.perturb_fingerprints = perturb_fingerprints
        self.perturb_budget = perturb_budget
        if self.random_seed is not None:
            # Seeds the global torch / random generators as a side effect.
            torch.manual_seed(self.random_seed)
            random.seed(self.random_seed)

    def _get_features(self):
        """Node features of the shadow graph (``ndata['feat']`` or ``.x``)."""
        return self.shadow_graph.ndata['feat'] if hasattr(self.shadow_graph, 'ndata') else self.shadow_graph.x

    def _set_features(self, features):
        """Store ``features`` as the shadow graph's node features."""
        if hasattr(self.shadow_graph, 'ndata'):
            self.shadow_graph.ndata['feat'] = features
        else:
            self.shadow_graph.x = features

    def get_candidate_nodes(self):
        """Return candidate node ids: all nodes, or a random subset of
        ``max(1, int(n * candidate_fraction))`` nodes when ``randomize`` is on
        and the fraction is below 1.0."""
        all_nodes = torch.arange(self.shadow_graph.num_nodes())
        if not (self.randomize and self.candidate_fraction < 1.0):
            return all_nodes
        num_candidates = max(1, int(len(all_nodes) * self.candidate_fraction))
        # randperm yields a CPU tensor, so the generator must live on the CPU;
        # a generator created on a CUDA device is rejected.
        generator = torch.Generator()
        if self.random_seed is not None:
            generator.manual_seed(self.random_seed)
        else:
            # Fresh generators share one fixed default seed; reseed so that
            # unseeded runs actually differ.
            generator.seed()
        idx = torch.randperm(len(all_nodes), generator=generator)[:num_candidates]
        return all_nodes[idx]

    def compute_fingerprint_score(self, node_idx, graph_override=None):
        """
        Computes the fingerprint score for a given node according to knowledge mode.
        If graph_override is provided, scoring is done on that graph instead of shadow_graph.

        Returns:
            float: ``1 - max softmax probability`` ('limited'), or the squared
            norm of the loss gradient w.r.t. the node's own features ('full').

        Raises:
            ValueError: if ``self.knowledge`` is neither 'limited' nor 'full'.
        """
        graph = graph_override if graph_override is not None else self.shadow_graph
        features = graph.ndata['feat'] if hasattr(graph, 'ndata') else graph.x
        self.model.eval()

        if self.knowledge == 'limited':
            with torch.no_grad():
                logits = self.model(graph.to(self.device), features.to(self.device))
            probs = torch.softmax(logits[node_idx], dim=0)
            return 1 - probs.max().item()

        if self.knowledge == 'full':
            # Differentiate w.r.t. a private copy: flipping requires_grad on the
            # graph's own tensor would leak into every other user of it.
            x = features.detach().clone().to(self.device).requires_grad_(True)
            logits = self.model(graph.to(self.device), x)
            node_logits = logits[node_idx]
            label = node_logits.argmax().item()
            loss = F.nll_loss(
                torch.log_softmax(node_logits.unsqueeze(0), dim=1),
                torch.tensor([label], device=self.device)
            )
            # autograd.grad returns the input gradient directly and leaves the
            # model parameters' .grad fields untouched.
            (grad,) = torch.autograd.grad(loss, x)
            return (grad[node_idx] ** 2).sum().item()

        raise ValueError("knowledge must be 'limited' or 'full'")

    def generate_fingerprint_nodes(self):
        """
        Step 3: Identifies and returns the top-k (num_fingerprints) nodes with the highest
        fingerprint scores from the candidate set. (Section 4.2.2)

        Returns:
            List[int]: Indices of selected fingerprint nodes.
        """
        scored = [(self.compute_fingerprint_score(idx), int(idx))
                  for idx in self.get_candidate_nodes()]
        scored.sort(reverse=True)
        return [idx for (_, idx) in scored[:self.num_fingerprints]]

    def save_fingerprint_tuples(self, node_indices):
        """Pair each node with the shadow graph and the model's current prediction for it."""
        self.model.eval()
        x = self._get_features().to(self.device)
        with torch.no_grad():
            logits = self.model(self.shadow_graph.to(self.device), x)
        labels = logits.argmax(dim=1).cpu().numpy()
        return [(self.shadow_graph, int(idx), int(labels[idx])) for idx in node_indices]

    def generate_fingerprints(self, method='full'):
        """
        Generate inductive fingerprints for model watermarking.

        Parameters:
            method (str): 'full' for gradient-based or 'limited' for output-based

        Returns:
            List of fingerprints

        Raises:
            ValueError: for any other ``method``.
        """
        if method == 'full':
            return self._generate_full()
        if method == 'limited':
            return self._generate_limited()
        raise ValueError(f"Invalid fingerprinting method: '{method}'")

    def _generate_with(self, knowledge):
        """Shared body of ``_generate_full`` / ``_generate_limited``."""
        self.knowledge = knowledge
        tag = knowledge.upper()
        print(f"[Fingerprint] Generating {tag} knowledge fingerprints...")
        fingerprint_nodes = self.generate_fingerprint_nodes()
        if self.perturb_fingerprints:
            print(f"[Fingerprint] Applying greedy feature perturbation ({tag})...")
            self.greedy_perturb_fingerprints(fingerprint_nodes)
        return self.save_fingerprint_tuples(fingerprint_nodes)

    def _generate_full(self):
        """
        Implements full knowledge fingerprint generation (gradient-based).
        Based on Section 4.2.1 and 5.2 of Wu et al. (2023).
        """
        return self._generate_with('full')

    def _generate_limited(self):
        """
        Implements limited knowledge fingerprint generation (output-based).
        Based on Section 4.2.2 and 5.2 of Wu et al. (2023).
        """
        return self._generate_with('limited')

    def greedy_perturb_fingerprints(self, node_indices):
        """
        Greedily perturbs each fingerprint node's features (not edges) to increase its
        fingerprint score, without changing the predicted label.

        - For each node, for each feature dimension, try +epsilon then -epsilon.
        - Accept a step if the predicted label stays the same and the score increases.
        - Stop after ``perturb_budget`` accepted steps or a full pass without improvement.

        Every trial step is written into the shadow graph before it is scored.
        Previously the score was read from the graph's stored (unperturbed)
        features, so no step could ever be accepted.

        Returns:
            List[int]: Indices of perturbed fingerprint nodes (features in shadow_graph are updated in-place).
        """
        epsilon = 0.01
        self.shadow_graph = self.shadow_graph.to(self.device)
        features = self._get_features().clone().detach().to(self.device)
        self._set_features(features)
        self.model.eval()

        def predicted_label(idx):
            with torch.no_grad():
                return self.model(self.shadow_graph, features)[idx].argmax().item()

        for idx in node_indices:
            accepted = 0
            improved = True
            while accepted < self.perturb_budget and improved:
                improved = False
                current_score = self.compute_fingerprint_score(idx, graph_override=self.shadow_graph)
                pred_label = predicted_label(idx)
                for dim in range(features.shape[1]):
                    for direction in (+1, -1):
                        # Remember the value right before this step so that a
                        # rejection does not undo an earlier accepted step.
                        previous = features[idx, dim].item()
                        features[idx, dim] = previous + direction * epsilon
                        self._set_features(features)
                        keep = predicted_label(idx) == pred_label
                        if keep:
                            new_score = self.compute_fingerprint_score(
                                idx, graph_override=self.shadow_graph)
                            keep = new_score > current_score
                        if keep:
                            current_score = new_score
                            improved = True
                            accepted += 1
                        else:
                            features[idx, dim] = previous
                            self._set_features(features)
                        if accepted >= self.perturb_budget:
                            break
                    if accepted >= self.perturb_budget:
                        break
        return node_indices

    def greedy_edge_perturbation(self, node_idx, perturb_budget=5, knowledge='full'):
        """
        Dispatch to greedy edge perturbation strategy based on verifier knowledge level.

        Args:
            node_idx (int): Fingerprint node index.
            perturb_budget (int): Number of edge perturbations allowed.
            knowledge (str): 'full' or 'limited'

        Raises:
            ValueError: for any other ``knowledge``.
        """
        if knowledge == 'full':
            self._greedy_edge_perturbation_f(node_idx, perturb_budget)
        elif knowledge == 'limited':
            self._greedy_edge_perturbation_l(node_idx, perturb_budget)
        else:
            raise ValueError("knowledge must be 'full' or 'limited'")

    def _undirected_neighbors(self, graph, node_idx):
        """Nodes sharing an edge (either direction) with ``node_idx``, excluding itself."""
        if hasattr(graph, 'ndata'):
            # DGL graph: query adjacency directly instead of round-tripping
            # through networkx.
            neighbors = set(graph.successors(node_idx).tolist())
            neighbors |= set(graph.predecessors(node_idx).tolist())
        else:
            nx_graph = to_networkx(graph.to('cpu'), to_undirected=True)
            neighbors = {int(n) for n in nx_graph.neighbors(node_idx)}
        # The node's own self-loop is never offered as a perturbation.
        neighbors.discard(node_idx)
        return neighbors

    def _toggle_edge(self, graph, node_idx, nbr, action):
        """Return a new graph with the undirected edge ``node_idx -- nbr`` added or removed.

        ``graph`` itself is not modified and node features are carried over.
        """
        if hasattr(graph, 'ndata'):
            if action == 'add':
                return dgl.add_edges(graph, [node_idx, nbr], [nbr, node_idx])
            src, dst = graph.edges()
            hits = ((src == node_idx) & (dst == nbr)) | ((src == nbr) & (dst == node_idx))
            edge_ids = hits.nonzero(as_tuple=True)[0].to(graph.idtype)
            return dgl.remove_edges(graph, edge_ids)
        # PyG-style fallback: rebuild through networkx and re-attach the
        # features, which from_networkx does not copy on its own.
        nx_graph = to_networkx(graph.to('cpu'), to_undirected=True)
        if action == 'add':
            nx_graph.add_edge(node_idx, nbr)
        elif nx_graph.has_edge(node_idx, nbr):
            nx_graph.remove_edge(node_idx, nbr)
        rebuilt = from_networkx(nx_graph)
        rebuilt.x = graph.x
        return rebuilt.to(self.device)

    def _greedy_edge_search(self, node_idx, perturb_budget, score_fn):
        """Greedy loop shared by both knowledge levels.

        Each round tries every single-edge addition and removal around
        ``node_idx``, keeps the one that raises ``score_fn`` the most over the
        *current* graph, and stops when nothing improves or the budget is
        spent. ``score_fn(graph)`` returns ``None`` when the perturbation
        changes the node's prediction. The result replaces ``self.shadow_graph``.
        """
        graph = self.shadow_graph.to(self.device)
        total_nodes = graph.num_nodes() if callable(graph.num_nodes) else graph.num_nodes
        neighbors = self._undirected_neighbors(graph, node_idx)
        non_neighbors = set(range(total_nodes)) - neighbors - {node_idx}
        current_score = score_fn(graph)

        applied = 0
        while applied < perturb_budget:
            best_gain = 0
            best_choice = None
            for action, pool in (('add', non_neighbors), ('remove', neighbors)):
                for nbr in sorted(pool):
                    candidate = self._toggle_edge(graph, node_idx, nbr, action)
                    candidate_score = score_fn(candidate)
                    if candidate_score is None:
                        continue
                    gain = candidate_score - current_score
                    if gain > best_gain:
                        best_gain = gain
                        best_choice = (candidate, candidate_score, action, nbr)
            if best_choice is None:
                break
            graph, current_score, action, nbr = best_choice
            if action == 'add':
                non_neighbors.discard(nbr)
                neighbors.add(nbr)
            else:
                neighbors.discard(nbr)
                non_neighbors.add(nbr)
            applied += 1
        self.shadow_graph = graph

    def _greedy_edge_perturbation_f(self, node_idx, perturb_budget):
        """
        Full knowledge edge perturbation (Inductive-F).
        Increases fingerprint score using model gradients while preserving prediction.
        """
        x = self._get_features().to(self.device)
        self.model.eval()
        with torch.no_grad():
            original_pred = self.model(self.shadow_graph.to(self.device), x)[node_idx].argmax().item()

        def score_fn(graph):
            with torch.no_grad():
                pred = self.model(graph.to(self.device), x)[node_idx].argmax().item()
            if pred != original_pred:
                return None
            return self.compute_fingerprint_score(node_idx, graph_override=graph)

        self._greedy_edge_search(node_idx, perturb_budget, score_fn)

    def _greedy_edge_perturbation_l(self, node_idx, perturb_budget):
        """
        Limited knowledge edge perturbation (Inductive-L).
        Uses confidence margin (1 - confidence) as proxy for fingerprint sensitivity.
        """
        x = self._get_features().to(self.device)
        self.model.eval()
        with torch.no_grad():
            original_pred = self.model(self.shadow_graph.to(self.device), x)[node_idx].argmax().item()

        def score_fn(graph):
            with torch.no_grad():
                logits = self.model(graph.to(self.device), x)
            pred = logits[node_idx].argmax().item()
            if pred != original_pred:
                return None
            return 1 - F.softmax(logits[node_idx], dim=0)[pred].item()

        self._greedy_edge_search(node_idx, perturb_budget, score_fn)
class BitFlipAttack:
    """Flips a single bit in one randomly chosen float32 weight of a model.

    ``attack_type`` selects which parameter tensors are eligible:
    'random' / 'BFA' -> all trainable tensors, 'BFA-F' -> the first one,
    'BFA-L' -> the last one.
    """

    def __init__(self, model, attack_type='random', bit=0):
        self.model = model
        self.attack_type = attack_type
        self.bit = bit

    def _get_target_params(self):
        """Return the list of parameter tensors the attack may touch.

        Raises:
            ValueError: for an unknown ``attack_type``.
        """
        params = [p for p in self.model.parameters() if p.requires_grad and p.numel() > 0]
        if self.attack_type in ('random', 'BFA'):
            return params
        if self.attack_type == 'BFA-F':
            return [params[0]]
        if self.attack_type == 'BFA-L':
            return [params[-1]]
        raise ValueError(f"Unknown attack_type {self.attack_type}")

    def _true_bit_flip(self, tensor, index=None, bit=0):
        """Flip bit ``bit`` of the float32 element at flat position ``index``.

        Args:
            tensor: float32 tensor; its ``.data`` is replaced by the modified copy.
            index: flat element index, or ``None`` to draw one with numpy's RNG.
            bit: bit position in the IEEE-754 word, 0 (LSB) to 31 (sign).

        Returns:
            (old_val, new_val, index) with the values as numpy float32 scalars.

        Raises:
            ValueError: if ``bit`` is outside 0..31.
            TypeError: if the tensor is not float32 (the 32-bit reinterpretation
                would silently corrupt other dtypes).
        """
        if not 0 <= bit < 32:
            raise ValueError(f"bit must be in [0, 31] for float32 weights, got {bit}")
        values = tensor.detach().cpu().numpy().copy()
        if values.dtype != np.float32:
            raise TypeError(f"bit flip requires a float32 tensor, got {values.dtype}")
        flat = values.reshape(-1)  # view onto the contiguous copy
        if index is None:
            index = np.random.randint(0, flat.size)
        old_val = flat[index]
        # Reinterpret the same memory as unsigned words and XOR the chosen bit.
        flat.view(np.uint32)[index] ^= np.uint32(1 << bit)
        new_val = flat[index]
        tensor.data = torch.from_numpy(values).to(tensor.device)
        return old_val, new_val, index

    def apply(self):
        """Flip ``self.bit`` in one random element of one random eligible tensor.

        Uses the ``random`` module for both choices and modifies the model in
        place.

        Returns:
            dict describing the flip: layer index (within the eligible list),
            flat parameter index, old/new value, bit and attack type.
        """
        params = self._get_target_params()
        with torch.no_grad():
            layer_idx = random.randrange(len(params))
            param = params[layer_idx]
            idx = random.randrange(param.numel())
            old_val, new_val, actual_idx = self._true_bit_flip(param, index=idx, bit=self.bit)
        return {
            'layer': layer_idx,
            'param_idx': actual_idx,
            'old_val': old_val,
            'new_val': new_val,
            'bit': self.bit,
            'attack_type': self.attack_type
        }
class MettackHelper:
    """Greedy, sampling-based structure poisoning attack in the style of Mettack.

    A surrogate GCN is trained on the clean graph. Then, for each
    perturbation step, a random sample of candidate node pairs is scored by
    briefly fine-tuning a copy of the surrogate on the perturbed graph and
    measuring the validation loss; the highest-scoring edge flip is applied.

    Args:
        graph: DGL graph to attack.
        features: Node feature tensor; ``features.shape[1]`` is the input size.
        labels: Integer class labels; the class count is ``labels.max() + 1``.
        train_mask: Boolean mask of training nodes.
        val_mask: Boolean mask of validation nodes, or ``None`` to carve a
            validation split out of the training nodes.
        test_mask: Boolean mask of test nodes.
        n_perturbations: Number of edge flips to apply.
        device: Device on which the graph, tensors and surrogate are placed.
        max_perturbations: Accepted for interface compatibility; not used.
        surrogate_epochs: Training epochs for the surrogate model.
        candidate_sample_size: Candidate pairs scored per perturbation step.
    """

    def __init__(self, graph, features, labels, train_mask, val_mask, test_mask,
                 n_perturbations=5, device='cpu', max_perturbations=50,
                 surrogate_epochs=30, candidate_sample_size=20):
        self.device = device
        # add_self_loop appends loops unconditionally, so strip any existing
        # ones first to end up with exactly one self-loop per node.
        graph_no_self_loop = dgl.remove_self_loop(graph)
        self.graph = dgl.add_self_loop(graph_no_self_loop).to(self.device)
        self.features = features.to(self.device)
        self.labels = labels.to(self.device)
        self.train_mask = train_mask.to(self.device)
        self.surrogate_epochs = surrogate_epochs
        self.candidate_sample_size = candidate_sample_size
        if val_mask is not None:
            self.val_mask = val_mask.to(self.device)
        else:
            self.val_mask = self._create_val_mask_from_train(train_mask).to(self.device)
            # The helper replaces self.train_mask with a reduced copy that
            # still lives on the caller's device; move it alongside the rest.
            self.train_mask = self.train_mask.to(self.device)
        self.test_mask = test_mask.to(self.device)
        self.n_perturbations = n_perturbations
        in_feats = features.shape[1]
        n_classes = int(labels.max().item()) + 1
        self.surrogate = GCN(in_feats, n_classes).to(self.device)
        # NOTE(review): this reseeds the process-wide torch / NumPy RNGs as a
        # side effect, and runs after the surrogate was initialised and after
        # the validation split was drawn, so neither is covered by the seed.
        torch.manual_seed(42)
        np.random.seed(42)
        self.modified_edges = set()
        src, dst = graph_no_self_loop.edges()
        self.original_edges = set(zip(src.tolist(), dst.tolist()))
        self.candidate_edges = self._get_candidate_edges()

    def _create_val_mask_from_train(self, train_mask):
        """Carve a validation split out of the training nodes.

        Up to a quarter of the training nodes (capped at 500) are chosen at
        random. As a side effect ``self.train_mask`` is replaced by a copy of
        ``train_mask`` with those nodes cleared; the argument itself is not
        modified.

        Args:
            train_mask: Boolean mask of training nodes.

        Returns:
            Boolean mask, same shape as ``train_mask``, marking the
            validation nodes.
        """
        train_indices = torch.where(train_mask)[0]
        n_val = min(500, len(train_indices) // 4)
        perm = torch.randperm(len(train_indices))
        val_indices = train_indices[perm[:n_val]]
        val_mask = torch.zeros_like(train_mask, dtype=torch.bool)
        val_mask[val_indices] = True
        self.train_mask = train_mask.clone()
        self.train_mask[val_indices] = False
        return val_mask

    def run(self):
        """Train the surrogate, poison the graph and evaluate the result.

        Returns:
            tuple: ``(poisoned_graph, metrics)`` - the perturbed DGL graph and
            the dict produced by ``_evaluate``.
        """
        print("Starting Mettack attack...")
        print("Training surrogate model...")
        self._train_surrogate()
        print("Applying structure attack...")
        poisoned_graph = self._apply_structure_attack()
        print("Evaluating attack results...")
        metrics = self._evaluate(poisoned_graph)
        return poisoned_graph, metrics

    def _train_surrogate(self):
        """Train the surrogate GCN on the clean graph.

        Uses Adam (lr 0.01, weight decay 5e-4) for ``surrogate_epochs``
        epochs and prints the validation accuracy every 50th epoch.
        (The original comment cites Wu et al., Section 6.1 - not verified here.)
        """
        optimizer = optim.Adam(self.surrogate.parameters(), lr=0.01, weight_decay=5e-4)
        self.surrogate.train()
        for epoch in range(self.surrogate_epochs):
            optimizer.zero_grad()
            logits = self.surrogate(self.graph, self.features)
            loss = F.cross_entropy(logits[self.train_mask], self.labels[self.train_mask])
            loss.backward()
            optimizer.step()
            if epoch % 50 == 0:
                self.surrogate.eval()
                with torch.no_grad():
                    val_logits = self.surrogate(self.graph, self.features)
                    val_acc = self._compute_accuracy(val_logits[self.val_mask],
                                                     self.labels[self.val_mask])
                print(f"Surrogate epoch {epoch}: Val Acc = {val_acc:.4f}")
                self.surrogate.train()

    def _apply_structure_attack(self):
        """Greedily flip up to ``n_perturbations`` edges.

        Each step samples ``candidate_sample_size`` candidate pairs without
        replacement, scores the flip of each one (remove if the edge exists
        in the original graph in either direction, add otherwise) with
        ``_compute_attack_loss`` and applies the flip with the highest
        validation loss. Pairs already flipped are not revisited. Every
        applied flip is recorded in ``self.modified_edges``.

        Returns:
            A new DGL graph with the selected edges flipped.
        """
        current_graph = copy.deepcopy(self.graph)
        perturbed_edges = set()
        n_candidates = len(self.candidate_edges)
        sample_size = min(self.candidate_sample_size, n_candidates)
        for step in range(self.n_perturbations):
            print(f"Perturbation step {step + 1}/{self.n_perturbations}")
            best_edge = None
            best_loss = -float('inf')
            best_action = None
            candidate_sample = np.random.choice(n_candidates, sample_size, replace=False)
            for idx in tqdm(candidate_sample, desc="Evaluating candidates"):
                edge = self.candidate_edges[idx]
                reverse = (edge[1], edge[0])
                if edge in perturbed_edges or reverse in perturbed_edges:
                    continue
                # Candidates are stored as (i, j) with i < j, but the graph may
                # hold the edge in either orientation; treat it as undirected,
                # consistent with the perturbed-edge check and edge removal.
                exists = edge in self.original_edges or reverse in self.original_edges
                action = 'remove' if exists else 'add'
                temp_graph = self._apply_single_perturbation(current_graph, edge, action)
                attack_loss = self._compute_attack_loss(temp_graph)
                if attack_loss > best_loss:
                    best_loss = attack_loss
                    best_edge = edge
                    best_action = action
            if best_edge is not None:
                current_graph = self._apply_single_perturbation(current_graph, best_edge, best_action)
                perturbed_edges.add(best_edge)
                self.modified_edges.add((best_edge, best_action))
                print(f"Applied {best_action} edge {best_edge} with loss increase: {best_loss:.4f}")
            else:
                print("No beneficial perturbation found, stopping early.")
                break
        return current_graph

    def _get_candidate_edges(self):
        """Enumerate candidate node pairs ``(i, j)`` with ``i < j``.

        Pairs are produced in lexicographic order and enumeration stops once
        10,000 have been collected, so the full O(n^2) pair list is never
        built. Both existing and non-existing edges are included.

        NOTE(review): because of the cap, large graphs only get candidates
        touching the lowest-numbered nodes.

        Returns:
            list[tuple[int, int]]: At most 10,000 node pairs.
        """
        max_candidates = 10000
        n_nodes = self.graph.num_nodes()
        candidates = []
        for i in range(n_nodes):
            for j in range(i + 1, n_nodes):
                candidates.append((i, j))
                if len(candidates) >= max_candidates:
                    return candidates
        return candidates

    def _apply_single_perturbation(self, graph, edge, action):
        """Return a copy of ``graph`` with one undirected edge added or removed.

        Args:
            graph: DGL graph to perturb; it is deep-copied, not modified.
            edge: ``(u, v)`` node pair.
            action: ``'add'`` inserts both ``u->v`` and ``v->u``;
                ``'remove'`` deletes every edge between the two nodes in
                either direction. Any other value leaves the edges untouched.

        Returns:
            The perturbed graph with exactly one self-loop per node.
        """
        temp_graph = copy.deepcopy(graph)
        u, v = edge
        if action == 'add':
            temp_graph.add_edges([u, v], [v, u])
        elif action == 'remove':
            src, dst = temp_graph.edges()
            hit = ((src == u) & (dst == v)) | ((src == v) & (dst == u))
            # Pass plain ints so the ids are accepted whatever the graph's
            # index dtype or device is.
            edge_ids = torch.nonzero(hit, as_tuple=True)[0].tolist()
            if edge_ids:
                temp_graph.remove_edges(edge_ids)
        # The incoming graph already carries self-loops and add_self_loop
        # appends unconditionally; strip first so loops do not pile up with
        # every perturbation.
        return dgl.add_self_loop(dgl.remove_self_loop(temp_graph))

    def _compute_attack_loss(self, perturbed_graph):
        """Score a perturbed graph by the surrogate's validation loss.

        A copy of the surrogate is fine-tuned for 5 Adam steps (lr 0.01) on
        the perturbed graph's training nodes, then its cross-entropy on the
        validation nodes is returned. This is a short-horizon approximation
        of the inner retraining loop, not a meta-gradient computation.

        Args:
            perturbed_graph: DGL graph to evaluate.

        Returns:
            float: Validation loss; higher means a more damaging perturbation.
        """
        temp_surrogate = copy.deepcopy(self.surrogate)
        temp_surrogate.train()
        optimizer = optim.Adam(temp_surrogate.parameters(), lr=0.01)
        for _ in range(5):
            optimizer.zero_grad()
            logits = temp_surrogate(perturbed_graph, self.features)
            loss = F.cross_entropy(logits[self.train_mask], self.labels[self.train_mask])
            loss.backward()
            optimizer.step()
        temp_surrogate.eval()
        with torch.no_grad():
            val_logits = temp_surrogate(perturbed_graph, self.features)
            val_loss = F.cross_entropy(val_logits[self.val_mask], self.labels[self.val_mask])
        return val_loss.item()

    def _evaluate(self, poisoned_graph):
        """Compare test accuracy on the clean and the poisoned graph.

        The clean figure comes from the already-trained surrogate on the
        clean graph; the poisoned figure from a fresh GCN trained for 200
        epochs on ``poisoned_graph``.

        Args:
            poisoned_graph: DGL graph produced by the attack.

        Returns:
            dict: ``clean_test_acc``, ``poisoned_test_acc``, ``accuracy_drop``
            and ``num_perturbations``.
        """
        metrics = {}
        self.surrogate.eval()
        with torch.no_grad():
            clean_logits = self.surrogate(self.graph, self.features)
            clean_acc = self._compute_accuracy(clean_logits[self.test_mask],
                                               self.labels[self.test_mask])
        metrics['clean_test_acc'] = clean_acc
        poisoned_model = GCN(self.features.shape[1],
                             int(self.labels.max().item()) + 1).to(self.device)
        optimizer = optim.Adam(poisoned_model.parameters(), lr=0.01, weight_decay=5e-4)
        poisoned_model.train()
        for _ in range(200):
            optimizer.zero_grad()
            logits = poisoned_model(poisoned_graph, self.features)
            loss = F.cross_entropy(logits[self.train_mask], self.labels[self.train_mask])
            loss.backward()
            optimizer.step()
        poisoned_model.eval()
        with torch.no_grad():
            poisoned_logits = poisoned_model(poisoned_graph, self.features)
            poisoned_acc = self._compute_accuracy(poisoned_logits[self.test_mask],
                                                  self.labels[self.test_mask])
        metrics['poisoned_test_acc'] = poisoned_acc
        metrics['accuracy_drop'] = clean_acc - poisoned_acc
        metrics['num_perturbations'] = len(self.modified_edges)
        return metrics

    def _compute_accuracy(self, logits, labels):
        """Return the fraction of rows whose arg-max class equals the label.

        Args:
            logits: ``(N, C)`` score tensor.
            labels: ``(N,)`` integer label tensor.

        Returns:
            float: Accuracy in ``[0, 1]``; ``0.0`` when there are no samples
            (e.g. an empty mask) instead of dividing by zero.
        """
        if len(labels) == 0:
            return 0.0
        predicted = logits.argmax(dim=1)
        correct = (predicted == labels).sum().item()
        return correct / len(labels)