# Source code for pygip.utils.dglTopyg

from torch_geometric.utils import from_networkx
import networkx as nx

def dgl_to_pyg_data(dgl_graph):
    """Convert a DGL graph into a PyTorch Geometric data object.

    The graph is exported to NetworkX together with its ``feat``,
    ``label``, ``train_mask``, ``val_mask`` and ``test_mask`` node
    attributes, then re-imported with ``from_networkx``.

    Args:
        dgl_graph: Graph object exposing ``to_networkx(node_attrs=...)``
            and carrying all five node attributes listed above
            (presumably a homogeneous ``dgl.DGLGraph`` -- confirm with
            callers).

    Returns:
        The object produced by ``from_networkx``. Besides the imported
        ``feat`` and ``label`` attributes it exposes the same values as
        ``x`` and ``y``; the three masks keep their original names.
    """
    node_attrs = ['feat', 'label', 'train_mask', 'val_mask', 'test_mask']
    nx_graph = dgl_graph.to_networkx(node_attrs=node_attrs)
    pyg_data = from_networkx(nx_graph)

    # Alias features/labels under the attribute names PyG models read.
    # The masks already arrive under their final names, so they need no
    # further assignment.
    pyg_data.x = pyg_data.feat
    pyg_data.y = pyg_data.label
    return pyg_data