from typing import Any, Dict, Iterable, Tuple
import dgl
import torch
import torch.nn.functional as F
from dgl.dataloading import NeighborSampler, NodeCollator
from torch.utils.data import DataLoader
from tqdm import tqdm
from pygip.models.defense.base import BaseDefense
from pygip.models.nn import GraphSAGE
from pygip.utils.metrics import DefenseCompMetric
class Revisiting(BaseDefense):
"""
A lightweight defense that 'revisits' node features via neighbor mixing.

    Idea (defense intuition)
------------------------
We pick a subset of nodes (size ~ attack_node_fraction * |V|) and *smoothly*
mix their features with their 1-hop / 2-hop neighborhoods using a mixing
factor `alpha`. This keeps utility (accuracy) largely intact while making
local feature structure less extractable for subgraph-based queries.
API shape follows RandomWM:
- lives under models/defense/
- inherits BaseDefense
- public entrypoint: .defend()

    Parameters
----------
dataset : Any
A dataset object providing a DGLGraph in `dataset.graph_data` and
ndata fields: 'feat', 'label', 'train_mask', 'test_mask'.
attack_node_fraction : float, default=0.2
Fraction of nodes used as the 'focus set' for our revisiting transform.
alpha : float, default=0.8
Mixing coefficient in [0,1]. Higher -> stronger neighbor mixing.
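
    Examples
    --------
    A minimal usage sketch, assuming ``dataset`` is a ``pygip`` dataset
    wrapper that provides ``graph_data``, ``num_nodes``, ``num_features``,
    ``num_classes`` and the ndata fields listed above:

    >>> defense = Revisiting(dataset, attack_node_fraction=0.2, alpha=0.8)
    >>> results, comp_metrics = defense.defend()
    >>> results["method"]
    'Revisiting'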
"""
supported_api_types = {"dgl"}
def __init__(
self,
dataset: Any,
attack_node_fraction: float = 0.2,
alpha: float = 0.8,
) -> None:
super().__init__(dataset, attack_node_fraction)
# knobs
self.alpha = float(alpha)
# cache handles similar to RandomWM for consistency
self.dataset = dataset
self.graph: dgl.DGLGraph = dataset.graph_data
self.num_nodes = dataset.num_nodes
self.num_features = dataset.num_features
self.num_classes = dataset.num_classes
self.num_focus_nodes = max(1, int(self.num_nodes * attack_node_fraction))
self.features: torch.Tensor = self.graph.ndata["feat"]
self.labels: torch.Tensor = self.graph.ndata["label"]
self.train_mask: torch.Tensor = self.graph.ndata["train_mask"]
self.test_mask: torch.Tensor = self.graph.ndata["test_mask"]
if self.device != "cpu":
self.graph = self.graph.to(self.device)
self.features = self.features.to(self.device)
self.labels = self.labels.to(self.device)
self.train_mask = self.train_mask.to(self.device)
self.test_mask = self.test_mask.to(self.device)
# --------------------------------------------------------------------- #
# Public entrypoint
# --------------------------------------------------------------------- #
    def defend(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
1) Train a baseline GraphSAGE on the original graph (utility baseline)
2) Apply revisiting feature-mixing on a subset of nodes
3) Train a defended GraphSAGE on the transformed features
4) Return accuracy metrics and basic metadata
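
        Returns
        -------
        res : dict
            Run summary: method name, ``alpha``, focus-set size, baseline and
            defended test accuracy, their difference, and a small sample of
            the picked node ids.
        metrics : dict
            Output of ``DefenseCompMetric.compute()``.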
"""
metric_comp = DefenseCompMetric()
metric_comp.start()
# ---- Baseline (no transform) ------------------------------------- #
baseline_acc = self._train_and_eval_graphsage(use_transformed_features=False)
# ---- Build transformed features (revisiting) --------------------- #
feat_defended, picked = self._build_revisiting_features()
# ---- Train with defended features -------------------------------- #
# Temporarily override graph features, then restore
orig_feat = self.graph.ndata["feat"]
try:
self.graph.ndata["feat"] = feat_defended
defense_acc = self._train_and_eval_graphsage(use_transformed_features=True)
finally:
self.graph.ndata["feat"] = orig_feat # restore
res = {
"ok": True,
"method": "Revisiting",
"alpha": self.alpha,
"focus_nodes": int(self.num_focus_nodes),
"baseline_test_acc": float(baseline_acc),
"defense_test_acc": float(defense_acc),
"acc_delta": float(defense_acc - baseline_acc),
# returning a small sample of picked nodes for debuggability
"sample_picked_nodes": picked[:10].tolist() if isinstance(picked, torch.Tensor) else [],
}
return res, metric_comp.compute()
# --------------------------------------------------------------------- #
# Core: feature revisiting (neighbor mixing)
# --------------------------------------------------------------------- #
    def _build_revisiting_features(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""
        Returns a new feature tensor in which a subset of picked nodes (and
        their 1-hop neighbors) is mixed with neighbor-mean features.
Mixing rule (simple & stable):
- For each picked node u:
x[u] <- (1 - alpha) * x[u] + alpha * mean(x[N(u)])
          - For each 1-hop neighbor v in N(u) that is not itself picked,
            apply a *lighter* mix:
x[v] <- (1 - 0.5*alpha) * x[v] + (0.5*alpha) * mean(x[N(v)])
This keeps the transform localized and smooth.
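
        For example, with ``alpha = 0.8`` a picked node keeps 20% of its own
        feature vector and takes 80% from its neighborhood mean, while its
        1-hop neighbors keep 60% of theirs and take 40% from their own
        neighborhood means. All means are computed from the *original*
        features, so the update order does not matter.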
"""
g = self.graph
x = self.features.clone()
# pick focus nodes
picked = torch.randperm(self.num_nodes, device=self.device)[: self.num_focus_nodes]
        # helper: undirected 1-hop neighborhood of a set of nodes,
        # built by combining in-neighbors (predecessors) and out-neighbors (successors)
def neighbors(nodes: Iterable[int]) -> torch.Tensor:
cols = []
for n in nodes:
# concatenate in- and out-neighbors to emulate undirected
nb = torch.unique(
torch.cat([g.successors(int(n)), g.predecessors(int(n))], dim=0)
)
if nb.numel() > 0:
cols.append(nb)
if not cols:
return torch.empty(0, dtype=torch.long, device=self.device)
return torch.unique(torch.cat(cols))
# 1) mix picked nodes with mean of their neighbors
for u in picked.tolist():
nb = neighbors([u])
if nb.numel() == 0:
continue
mean_nb = self.features[nb].mean(dim=0)
x[u] = (1.0 - self.alpha) * self.features[u] + self.alpha * mean_nb
        # 2) lightly mix 1-hop neighbors as well (half strength); skip nodes
        #    that are themselves picked so their full-strength mix is kept
        one_hop = neighbors(picked.tolist())
        one_hop = one_hop[~torch.isin(one_hop, picked)]
for v in one_hop.tolist():
nb = neighbors([v])
if nb.numel() == 0:
continue
mean_nb = self.features[nb].mean(dim=0)
x[v] = (1.0 - 0.5 * self.alpha) * self.features[v] + (0.5 * self.alpha) * mean_nb
return x, picked
# --------------------------------------------------------------------- #
# Training/Eval (same style as RandomWM)
# --------------------------------------------------------------------- #
    def _train_and_eval_graphsage(self, use_transformed_features: bool) -> float:
"""
        Train a GraphSAGE model for 50 epochs and return the best test
        accuracy observed across epochs.

        Uses NeighborSampler + NodeCollator (same pattern as RandomWM).
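
        Notes
        -----
        The ``[5, 5]`` fan-out tells the sampler to draw up to 5 neighbors
        per node at each of the two GraphSAGE layers; both loaders use a
        batch size of 32.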
"""
model = GraphSAGE(
in_channels=self.num_features,
hidden_channels=128,
out_channels=self.num_classes,
).to(self.device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
sampler = NeighborSampler([5, 5])
train_nids = self.train_mask.nonzero(as_tuple=True)[0].to(self.device)
test_nids = self.test_mask.nonzero(as_tuple=True)[0].to(self.device)
train_collator = NodeCollator(self.graph, train_nids, sampler)
test_collator = NodeCollator(self.graph, test_nids, sampler)
train_loader = DataLoader(
train_collator.dataset,
batch_size=32,
shuffle=True,
collate_fn=train_collator.collate,
drop_last=False,
)
test_loader = DataLoader(
test_collator.dataset,
batch_size=32,
shuffle=False,
collate_fn=test_collator.collate,
drop_last=False,
)
best_acc = 0.0
for _ in tqdm(range(1, 51),
desc=("GraphSAGE (defended)" if use_transformed_features else "GraphSAGE (baseline)")):
# ---- Train
model.train()
for _, _, blocks in train_loader:
blocks = [b.to(self.device) for b in blocks]
feats = blocks[0].srcdata["feat"]
labels = blocks[-1].dstdata["label"]
optimizer.zero_grad()
logits = model(blocks, feats)
loss = F.cross_entropy(logits, labels)
loss.backward()
optimizer.step()
# ---- Eval
model.eval()
correct = 0
total = 0
with torch.no_grad():
for _, _, blocks in test_loader:
blocks = [b.to(self.device) for b in blocks]
feats = blocks[0].srcdata["feat"]
labels = blocks[-1].dstdata["label"]
logits = model(blocks, feats)
pred = logits.argmax(dim=1)
correct += (pred == labels).sum().item()
total += labels.numel()
acc = correct / max(1, total)
best_acc = max(best_acc, acc)
return best_acc