Source code for pygip.models.attack.DataFreeMEA

from abc import abstractmethod

import dgl
import networkx as nx
import torch
import torch.nn.functional as F
from tqdm import tqdm

from pygip.models.attack.base import BaseAttack
from pygip.models.nn import GCN, GraphSAGE  # Backbone architectures


class GraphGenerator:
    """Produce a random synthetic graph with Gaussian node features.

    Parameters
    ----------
    node_number :
        Number of nodes in the generated graph.
    feature_number :
        Width of the random feature vector attached to every node.
    label_number :
        Number of classes. Stored for callers; ``generate`` does not use it.
    edge_prob :
        Edge probability passed to ``networkx.erdos_renyi_graph``. Defaults
        to ``0.05``, the value that used to be hard-coded.
    """

    def __init__(self, node_number, feature_number, label_number, edge_prob=0.05):
        self.node_number = node_number
        self.feature_number = feature_number
        self.label_number = label_number
        self.edge_prob = edge_prob

    def generate(self):
        """Return ``(graph, features)`` for a freshly sampled random graph.

        ``graph`` is the DGL conversion of an Erdős–Rényi graph and
        ``features`` is a ``(node_number, feature_number)`` tensor drawn from
        a standard normal distribution. Both are created on the default
        device; the caller is responsible for moving them.
        """
        # NOTE(review): no self-loops are added, so sparse samples can contain
        # isolated (zero in-degree) nodes. Some DGL conv layers reject those
        # by default -- confirm the backbones used downstream tolerate it.
        g_nx = nx.erdos_renyi_graph(n=self.node_number, p=self.edge_prob)
        g_dgl = dgl.from_networkx(g_nx)

        # Random node features
        features = torch.randn((self.node_number, self.feature_number))
        return g_dgl, features
class DFEAAttack(BaseAttack):
    """Shared setup and evaluation for the data-free extraction attacks.

    The constructor moves the real graph to ``self.device``, samples a random
    synthetic graph sized by ``attack_node_fraction``, and obtains a victim
    GCN either by training one on the real graph or by loading a checkpoint.
    Subclasses implement :meth:`attack` to train a surrogate against the
    victim.

    NOTE(review): ``self.device`` is not assigned in this class; presumably
    ``BaseAttack.__init__`` provides it -- confirm.
    """

    supported_api_types = {"dgl"}

    def __init__(self, dataset, attack_node_fraction, model_path=None):
        """Prepare real data, synthetic data and the victim model.

        Parameters
        ----------
        dataset :
            Object exposing ``graph_data`` (a graph whose ``ndata`` holds
            ``feat``, ``label``, ``train_mask``, ``val_mask`` and
            ``test_mask``) plus ``num_features``, ``num_classes`` and
            ``num_nodes``.
        attack_node_fraction :
            Fraction of ``dataset.num_nodes`` used as the synthetic graph
            size (truncated to an int).
        model_path :
            Optional checkpoint of a victim GCN ``state_dict``. When ``None``
            a victim is trained from scratch.
        """
        super().__init__(dataset, attack_node_fraction, model_path)

        # load graph data
        self.graph = dataset.graph_data.to(self.device)
        self.features = self.graph.ndata['feat']
        self.labels = self.graph.ndata['label']
        self.train_mask = self.graph.ndata['train_mask']
        self.val_mask = self.graph.ndata['val_mask']
        self.test_mask = self.graph.ndata['test_mask']

        # meta data
        self.feature_number = dataset.num_features
        self.label_number = dataset.num_classes
        self.attack_node_number = int(dataset.num_nodes * attack_node_fraction)

        # Generate synthetic graph and features for surrogate training
        self.generator = GraphGenerator(
            node_number=self.attack_node_number,
            feature_number=self.feature_number,
            label_number=self.label_number
        )
        self.synthetic_graph, self.synthetic_features = self.generator.generate()
        self.synthetic_graph = self.synthetic_graph.to(self.device)
        self.synthetic_features = self.synthetic_features.to(self.device)

        if model_path is None:
            self._train_target_model()
        else:
            self._load_model(model_path)

    def _flat_labels(self):
        """Return the node labels as a 1-D tensor.

        Labels stored as an ``(N, 1)`` column (the layout the original code
        special-cased for ``ogb-arxiv``) are flattened to ``(N,)`` so that
        both the loss target and element-wise comparisons with ``argmax``
        predictions have matching shapes. Other layouts are returned as-is.
        """
        labels = self.labels
        if labels.dim() > 1 and labels.size(-1) == 1:
            labels = labels.squeeze(-1)
        return labels

    def _train_target_model(self):
        """Train the victim GCN on the real graph and store it in ``self.model``.

        Runs 200 full-batch Adam steps (lr 0.01, weight decay 5e-4) on the
        nodes selected by ``train_mask`` and leaves the model in eval mode.
        """
        model = GCN(self.feature_number, self.label_number).to(self.device)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=0.01, weight_decay=5e-4
        )
        model.train()

        # Labels do not change between epochs, so flatten them once.
        labels = self._flat_labels()

        epochs = 200
        for _ in range(epochs):
            optimizer.zero_grad()
            logits = model(self.graph, self.features)
            loss = F.nll_loss(
                F.log_softmax(logits[self.train_mask], dim=1),
                labels[self.train_mask]
            )
            loss.backward()
            optimizer.step()

        model.eval()
        self.model = model

    def _load_model(self, model_path):
        """Load a pretrained victim GCN ``state_dict`` into ``self.model``.

        The model is moved to ``self.device`` so it lives on the same device
        as the graphs and features it is later called with.
        """
        model = GCN(self.feature_number, self.label_number).to(self.device)
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source.
        state = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state)
        model.eval()
        self.model = model

    def _forward(self, model, graph, features):
        """Run ``model`` on ``graph``/``features`` and return its output.

        ``GraphSAGE`` instances are called with the graph repeated twice in a
        list; every other model is called with the graph directly.
        """
        if isinstance(model, GraphSAGE):
            # GraphSAGE expects two-block input list
            return model([graph, graph], features)
        return model(graph, features)

    def evaluate(self, surrogate):
        """Score victim and surrogate on the real graph's test split.

        Returns
        -------
        dict
            ``victim_acc`` and ``surrogate_acc`` are label accuracies on the
            nodes selected by ``test_mask``; ``fidelity`` is the fraction of
            those nodes on which surrogate and victim predict the same class.
        """
        surrogate.eval()
        self.model.eval()

        g = self.graph
        x = self.features
        mask = self.test_mask
        # Flattened labels keep the comparisons below element-wise; an
        # (N, 1) label column would otherwise broadcast against (M,) preds.
        y = self._flat_labels()

        with torch.no_grad():
            # victim predict
            preds_v = self._forward(self.model, g, x).argmax(dim=1)
            # surrogate predict
            preds_s = self._forward(surrogate, g, x).argmax(dim=1)

        # victim acc, surrogate acc
        victim_acc = (preds_v[mask] == y[mask]).float().mean().item()
        surrogate_acc = (preds_s[mask] == y[mask]).float().mean().item()
        # fidelity
        fidelity = (preds_s[mask] == preds_v[mask]).float().mean().item()

        return {
            "victim_acc": victim_acc,
            "surrogate_acc": surrogate_acc,
            "fidelity": fidelity,
        }

    @abstractmethod
    def attack(self):
        """Train a surrogate and return its evaluation metrics (subclass hook)."""
        pass
class DFEATypeI(DFEAAttack):
    """
    Type I: distils the victim's soft outputs into a GCN surrogate.

    The surrogate is trained on the synthetic graph to match the victim's
    softmax distribution under a KL-divergence loss.

    NOTE(review): the original summary mentioned victim "gradients", but the
    visible training loop only consumes victim outputs under ``no_grad``.
    """

    def attack(self):
        """Train the surrogate for 200 steps and return ``evaluate`` metrics."""
        surrogate = GCN(self.feature_number, self.label_number).to(self.device)
        optimizer = torch.optim.Adam(surrogate.parameters(), lr=0.01)

        # The victim is left in eval mode by _train_target_model/_load_model
        # and the synthetic inputs never change, so its target distribution
        # is loop-invariant: query it once instead of once per step.
        with torch.no_grad():
            logits_v = self._forward(
                self.model, self.synthetic_graph, self.synthetic_features
            )
            target_probs = F.softmax(logits_v, dim=1)

        surrogate.train()
        for _ in tqdm(range(200)):
            optimizer.zero_grad()
            logits_s = self._forward(
                surrogate, self.synthetic_graph, self.synthetic_features
            )
            loss = F.kl_div(
                F.log_softmax(logits_s, dim=1),
                target_probs,
                reduction='batchmean'
            )
            loss.backward()
            optimizer.step()

        metric = self.evaluate(surrogate)
        print('Agreement Acc: ', metric)
        return metric
class DFEATypeII(DFEAAttack):
    """
    Type II: Uses victim outputs only (hard labels).

    A GraphSAGE surrogate (hidden size 16) is trained on the synthetic graph
    with cross-entropy against the victim's argmax predictions.
    """

    def attack(self):
        """Train the surrogate for 200 steps and return ``evaluate`` metrics."""
        surrogate = GraphSAGE(self.feature_number, 16, self.label_number).to(self.device)
        optimizer = torch.optim.Adam(surrogate.parameters(), lr=0.01)

        # The victim is left in eval mode by _train_target_model/_load_model
        # and the synthetic inputs never change, so the pseudo-labels are
        # loop-invariant: compute them once instead of once per step.
        with torch.no_grad():
            logits_v = self._forward(
                self.model, self.synthetic_graph, self.synthetic_features
            )
            pseudo = logits_v.argmax(dim=1)

        surrogate.train()
        for _ in tqdm(range(200)):
            optimizer.zero_grad()
            logits_s = self._forward(
                surrogate, self.synthetic_graph, self.synthetic_features
            )
            loss = F.cross_entropy(logits_s, pseudo)
            loss.backward()
            optimizer.step()

        metric = self.evaluate(surrogate)
        print('Agreement Acc: ', metric)
        return metric
class DFEATypeIII(DFEAAttack):
    """
    Type III: Two surrogates with victim supervision + consistency.

    A GCN and a GraphSAGE surrogate are trained jointly on the synthetic
    graph. Each is supervised by the victim's argmax labels, and an MSE term
    (weight 0.5) pulls their logits together. Only the GCN surrogate is
    evaluated and reported.
    """

    def attack(self):
        """Train both surrogates for 200 steps; return metrics for the GCN."""
        s1 = GCN(self.feature_number, self.label_number).to(self.device)
        s2 = GraphSAGE(self.feature_number, 16, self.label_number).to(self.device)
        opt1 = torch.optim.Adam(s1.parameters(), lr=0.01)
        opt2 = torch.optim.Adam(s2.parameters(), lr=0.01)

        # The victim is left in eval mode by _train_target_model/_load_model
        # and the synthetic inputs never change, so the pseudo-labels are
        # loop-invariant: compute them once instead of once per step.
        with torch.no_grad():
            logits_v = self._forward(
                self.model, self.synthetic_graph, self.synthetic_features
            )
            pseudo_v = logits_v.argmax(dim=1)

        s1.train()
        s2.train()
        for _ in tqdm(range(200)):
            opt1.zero_grad()
            opt2.zero_grad()

            # Surrogate predictions
            l1 = self._forward(s1, self.synthetic_graph, self.synthetic_features)
            l2 = self._forward(s2, self.synthetic_graph, self.synthetic_features)

            # Loss: supervised + consistency
            loss1 = F.cross_entropy(l1, pseudo_v)
            loss2 = F.cross_entropy(l2, pseudo_v)
            cons = F.mse_loss(l1, l2)
            total = loss1 + loss2 + 0.5 * cons

            # One backward pass feeds both optimizers; the consistency term
            # contributes gradients to s1 and s2 alike.
            total.backward()
            opt1.step()
            opt2.step()

        metric = self.evaluate(s1)
        print('Agreement Acc: ', metric)
        return metric
# Lookup table from public attack names to their implementing classes.
# "ModelExtractionAttack0" and "ModelExtractionAttack1" both resolve to
# DFEATypeI.
ATTACK_FACTORY = {
    "ModelExtractionAttack0": DFEATypeI,
    "ModelExtractionAttack1": DFEATypeI,
    "ModelExtractionAttack2": DFEATypeII,
    "ModelExtractionAttack3": DFEATypeIII,
}