pygip.models.defense.SurviveWM

class pygip.models.defense.SurviveWM.SurviveWM(dataset, defense_ratio=0.1, model_path=None)

Bases: BaseDefense

_abc_impl = <_abc_data object>
_load_model()

Load a pre-trained model.

_to_cpu(tensor)

Safely move a tensor to the CPU for NumPy operations.

_train_watermarked_model()

Helper function to train the watermarked model.

combine_with_trigger(base_graph, base_features, base_labels, trigger_data)
compute_metrics(y_true, y_pred, y_score=None)
defend()

Execute the SurviveWM defense.
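
A minimal usage sketch, based only on the constructor signature and the defend() entry listed here. The helper name run_survivewm is hypothetical. It assumes that dataset is a PyGIP dataset object prepared for the DGL backend (the only entry in supported_api_types), and it makes no assumption about what defend() returns, since this page does not document the return value.

```python
def run_survivewm(dataset, defense_ratio=0.1, model_path=None):
    """Hypothetical helper: construct the defense and run it once."""
    # Imported lazily so the helper can be defined without PyGIP installed.
    from pygip.models.defense.SurviveWM import SurviveWM

    # Arguments mirror the documented constructor signature and defaults.
    defense = SurviveWM(dataset, defense_ratio=defense_ratio, model_path=model_path)

    # defend() is the documented entry point; its return value is passed
    # through unchanged because this page does not describe it.
    return defense.defend()
```

Passing model_path presumably lets the defense reuse a pre-trained model through _load_model() instead of training from scratch, but that behavior is an assumption drawn from the member names, not something this page states.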

evaluate_model(model)

Evaluate model performance on the downstream task.

generate_key_graph(num_nodes=None, edge_prob=None)
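
The num_nodes and edge_prob parameters suggest that the key graph is a random graph in which each possible edge is included with a fixed probability. The sketch below illustrates that idea with an Erdős–Rényi generator in plain Python. This is an assumption about the general technique, not the library's implementation; the function name random_key_graph, the seed parameter, and the edge-list return type are all hypothetical.

```python
import random


def random_key_graph(num_nodes, edge_prob, seed=None):
    """Return an undirected edge list where each node pair is linked
    independently with probability edge_prob (Erdős–Rényi G(n, p))."""
    rng = random.Random(seed)  # local generator keeps the result reproducible
    edges = []
    for u in range(num_nodes):
        for v in range(u + 1, num_nodes):  # visit each unordered pair once
            if rng.random() < edge_prob:
                edges.append((u, v))
    return edges


# Same seed, same graph: useful if the key graph must be regenerated later.
key_a = random_key_graph(10, 0.3, seed=42)
key_b = random_key_graph(10, 0.3, seed=42)
```

A fixed seed matters for watermarking in general, because the owner must be able to reproduce the exact trigger graph at verification time.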
snn_loss(x, y, T=0.5)
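
The name and the temperature parameter T suggest the soft nearest neighbor loss (SNNL), which compares, for each sample, the similarity mass of its same-label neighbors against that of all other samples. The sketch below is a plain-Python reference version of the standard definition, assuming x is a batch of feature vectors and y the matching labels. The function name snn_loss_sketch and the eps guard are illustrative choices; the library method presumably operates on tensors and may differ in detail.

```python
import math


def snn_loss_sketch(x, y, T=0.5):
    """Soft nearest neighbor loss for feature vectors x with labels y."""
    eps = 1e-12  # avoids log(0) when a sample has no same-label neighbor
    n = len(x)
    total = 0.0
    for i in range(n):
        same = 0.0   # similarity mass over same-label neighbors of i
        every = 0.0  # similarity mass over all samples except i
        for j in range(n):
            if i == j:
                continue
            dist2 = sum((a - b) ** 2 for a, b in zip(x[i], x[j]))
            w = math.exp(-dist2 / T)  # smaller T sharpens the neighborhood
            every += w
            if y[i] == y[j]:
                same += w
        total += -math.log((same + eps) / (every + eps))
    return total / n


points = [[0.0, 0.0], [0.0, 0.1], [5.0, 5.0], [5.0, 5.1]]
clustered = snn_loss_sketch(points, [0, 0, 1, 1])  # labels follow clusters
entangled = snn_loss_sketch(points, [0, 1, 0, 1])  # labels cut across clusters
```

The loss is close to zero when each class forms its own cluster and grows as classes become entangled in feature space.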
supported_api_types = {'dgl'}
train_target_model(metric_comp)

Train the target model with watermark injection.

train_with_snnl(model, graph, features, labels, train_mask, optimizer, T=0.5, alpha=0.1)
verify_watermark(model, trigger_graph, trigger_labels)
verify_watermark_model(model)

Verify the watermark success rate.
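
A watermark success rate is commonly computed as the fraction of trigger samples for which the model predicts the expected trigger label. The sketch below shows that calculation on plain Python lists. It is an assumption about the metric, not the library's code; the function name watermark_success_rate and its list-based inputs are hypothetical.

```python
def watermark_success_rate(predicted_labels, trigger_labels):
    """Fraction of trigger samples whose prediction matches the trigger label."""
    if len(predicted_labels) != len(trigger_labels):
        raise ValueError("predictions and trigger labels must align one-to-one")
    if not trigger_labels:
        return 0.0  # no trigger samples, so nothing can be verified
    matches = sum(1 for p, t in zip(predicted_labels, trigger_labels) if p == t)
    return matches / len(trigger_labels)


# Three of the four trigger nodes are classified as the expected label.
rate = watermark_success_rate([1, 0, 1, 1], [1, 1, 1, 1])
```

In practice the rate is compared against a threshold to decide whether a suspect model carries the watermark; the threshold itself is a design choice that this page does not specify.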