import dgl
import numpy as np
import torch
from dgl import DGLGraph
from dgl.data import AmazonCoBuyComputerDataset # Amazon-Computer
from dgl.data import AmazonCoBuyPhotoDataset # Amazon-Photo
from dgl.data import CoauthorCSDataset, CoauthorPhysicsDataset
from dgl.data import FakeNewsDataset
from dgl.data import FlickrDataset
from dgl.data import GINDataset
from dgl.data import MUTAGDataset
from dgl.data import RedditDataset
from dgl.data import YelpDataset
from dgl.data import citation_graph # Cora, CiteSeer, PubMed
from sklearn.model_selection import StratifiedShuffleSplit
from torch_geometric.data import Data as PyGData
from torch_geometric.datasets import Amazon # Amazon Computers, Photo
from torch_geometric.datasets import Coauthor # cs, physics
from torch_geometric.datasets import FacebookPagePage
from torch_geometric.datasets import Flickr as FlickrPyG
from torch_geometric.datasets import LastFMAsia
from torch_geometric.datasets import Planetoid # Cora, CiteSeer, PubMed
from torch_geometric.datasets import PolBlogs as PolBlogsPyG
from torch_geometric.datasets import Reddit
from torch_geometric.datasets import TUDataset # ENZYMES
def dgl_to_tg(dgl_graph):
    """Convert a DGL graph into a ``torch_geometric.data.Data`` object.

    The edge list is stacked into a ``(2, num_edges)`` ``edge_index`` tensor.
    Node fields ``feat``, ``label``, ``train_mask``, ``val_mask`` and
    ``test_mask`` are looked up with ``ndata.get`` and forwarded as ``x``,
    ``y`` and the three masks; a missing field is forwarded as ``None``.

    Args:
        dgl_graph: Graph exposing ``edges()`` and ``ndata``.

    Returns:
        The assembled PyG ``Data`` object.
    """
    node_fields = dgl_graph.ndata
    return PyGData(
        x=node_fields.get('feat'),
        edge_index=torch.stack(dgl_graph.edges()),
        y=node_fields.get('label'),
        train_mask=node_fields.get('train_mask'),
        val_mask=node_fields.get('val_mask'),
        test_mask=node_fields.get('test_mask'),
    )
def tg_to_dgl(py_g_data):
    """Convert a PyG ``Data`` object into a DGL graph.

    ``x`` and ``y`` are copied to ``ndata['feat']`` / ``ndata['label']`` and
    the ``train_mask`` / ``val_mask`` / ``test_mask`` attributes are copied
    under the same names. Attributes that are absent or ``None`` are skipped.

    Args:
        py_g_data: Object exposing ``edge_index`` (row 0 = source ids,
            row 1 = destination ids) plus the optional node attributes above.

    Returns:
        The constructed DGL graph.
    """
    edge_index = py_g_data.edge_index
    # Pass the node count explicitly: inferring it from the edge list alone
    # drops trailing isolated nodes, after which the per-node tensors below
    # no longer match the graph size.
    dgl_graph = dgl.graph(
        (edge_index[0], edge_index[1]),
        num_nodes=getattr(py_g_data, 'num_nodes', None),
    )

    if py_g_data.x is not None:
        dgl_graph.ndata['feat'] = py_g_data.x
    if py_g_data.y is not None:
        dgl_graph.ndata['label'] = py_g_data.y

    for mask_name in ('train_mask', 'val_mask', 'test_mask'):
        mask = getattr(py_g_data, mask_name, None)
        if mask is not None:
            dgl_graph.ndata[mask_name] = mask
    return dgl_graph
class Dataset(object):
    """Base class that loads a graph dataset through either DGL or PyG.

    Subclasses override :meth:`load_dgl_data` and/or :meth:`load_pyg_data`.
    The constructor calls the loader matching ``api_type`` and then
    :meth:`_load_meta_data`.

    Attributes set by the constructor:
        api_type: ``'dgl'`` or ``'pyg'``.
        path: Directory handed to the dataset loaders.
        dataset_name: Result of :meth:`get_name`.
        graph_dataset / graph_data: Filled in by the loaders (``None`` before).
        num_nodes / num_features / num_classes: Metadata, ``0`` until filled.
    """

    # Names under which the three split masks are stored on the graph.
    _MASK_KEYS = ('train_mask', 'val_mask', 'test_mask')

    def __init__(self, api_type='dgl', path='./data'):
        """Validate ``api_type``, run the matching loader, then load metadata."""
        assert api_type in {'dgl', 'pyg'}, 'API type must be dgl or pyg'
        self.api_type = api_type
        self.path = path
        self.dataset_name = self.get_name()

        # DGLGraph or PyGData
        self.graph_dataset = None
        self.graph_data = None

        # meta data
        self.num_nodes = 0
        self.num_features = 0
        self.num_classes = 0

        if self.api_type == 'dgl':
            self.load_dgl_data()
        elif self.api_type == 'pyg':
            self.load_pyg_data()
        else:
            # Only reachable when assertions are disabled (``python -O``).
            raise ValueError("Unsupported api_type.")

        self._load_meta_data()

    def get_name(self):
        """Return the name of the concrete class."""
        return self.__class__.__name__

    def load_dgl_data(self):
        """Load the dataset through DGL; must be overridden by subclasses."""
        raise NotImplementedError("load_dgl_data not implemented in subclasses.")

    def load_pyg_data(self):
        """Load the dataset through PyG; must be overridden by subclasses."""
        raise NotImplementedError("load_pyg_data not implemented in subclasses.")

    def _load_meta_data(self):
        """Best-effort fill of ``num_nodes``, ``num_features``, ``num_classes``.

        NOTE(review): the constructor called this method but no definition
        was present in the class; this is a defensive reconstruction --
        confirm the intended metadata sources. Values that cannot be
        determined are left at their initial ``0``.
        """
        graph = self.graph_data
        if graph is None:
            # Some loaders never set a single graph; nothing to inspect.
            return

        if self.api_type == 'dgl':
            num_nodes = graph.num_nodes()
            features = graph.ndata.get('feat')
            labels = graph.ndata.get('label')
        else:
            num_nodes = getattr(graph, 'num_nodes', None)
            features = getattr(graph, 'x', None)
            labels = getattr(graph, 'y', None)

        if num_nodes is not None:
            self.num_nodes = int(num_nodes)

        if isinstance(features, torch.Tensor) and features.dim() > 0:
            self.num_features = int(features.shape[-1]) if features.dim() > 1 else 1

        declared = getattr(self.graph_dataset, 'num_classes', None)
        if declared is not None:
            self.num_classes = int(declared)
        elif isinstance(labels, torch.Tensor) and labels.numel() > 0:
            if labels.dim() > 1:
                self.num_classes = int(labels.shape[-1])
            else:
                self.num_classes = int(labels.max().item()) + 1

    def _detect_backend(self):
        """Return ``(is_dgl, is_pyg)`` describing ``self.graph_data``.

        Raises:
            TypeError: If ``graph_data`` is neither a ``DGLGraph`` nor a
                ``torch_geometric.data.Data``.
        """
        try:
            import dgl
        except ImportError:
            dgl = None
        try:
            from torch_geometric.data import Data
        except ImportError:
            Data = None

        graph = self.graph_data
        is_dgl = dgl is not None and isinstance(graph, dgl.DGLGraph)
        is_pyg = Data is not None and isinstance(graph, Data)
        if not (is_dgl or is_pyg):
            raise TypeError("graph_data must be either DGLGraph or torch_geometric.data.Data.")
        return is_dgl, is_pyg

    def _masks_exist(self, is_dgl):
        """Return True (and print a notice) when all three masks are present."""
        if is_dgl:
            if all(key in self.graph_data.ndata for key in self._MASK_KEYS):
                print("Masks already exist in DGL graph. Skipping mask generation.")
                return True
        elif all(hasattr(self.graph_data, key) for key in self._MASK_KEYS):
            print("Masks already exist in PyG data. Skipping mask generation.")
            return True
        return False

    def _store_masks(self, is_dgl, train_mask, val_mask, test_mask):
        """Attach the three masks to ``graph_data`` (``ndata`` for DGL)."""
        masks = zip(self._MASK_KEYS, (train_mask, val_mask, test_mask))
        for key, mask in masks:
            if is_dgl:
                self.graph_data.ndata[key] = mask
            else:
                setattr(self.graph_data, key, mask)

    def _generate_masks_by_ratio(self, train_ratio=0.8, seed=None):
        """Create random train/val/test masks from a train ratio.

        ``int(train_ratio * num_nodes)`` nodes go to training; the rest is
        split into validation (floor half) and test (remainder). Nothing is
        done when all three masks already exist.

        Args:
            train_ratio: Fraction of nodes assigned to the training mask.
            seed: Optional seed for a dedicated generator; ``None`` uses the
                global torch RNG.

        Raises:
            ValueError: If ``graph_data`` has not been loaded.
            TypeError: If ``graph_data`` is not a DGL graph or PyG ``Data``.
        """
        if self.graph_data is None:
            raise ValueError("graph_data is not loaded.")

        is_dgl, _ = self._detect_backend()
        if self._masks_exist(is_dgl):
            return

        num_nodes = self.graph_data.num_nodes() if is_dgl else self.graph_data.num_nodes

        generator = None if seed is None else torch.Generator().manual_seed(seed)
        order = torch.randperm(num_nodes, generator=generator)
        train_size = int(train_ratio * num_nodes)
        val_size = (num_nodes - train_size) // 2
        val_end = train_size + val_size

        self._store_masks(
            is_dgl,
            self._index_to_mask(order[:train_size], num_nodes),
            self._index_to_mask(order[train_size:val_end], num_nodes),
            self._index_to_mask(order[val_end:], num_nodes),
        )
        print(f"Masks successfully generated and stored. (train_ratio={train_ratio})")

    def _generate_masks_by_classes(self, num_class_samples=100, val_count=500, test_count=1000, seed=42):
        """Create train/val/test masks with a fixed number of samples per class.

        - train: up to ``num_class_samples`` random nodes per class
        - val: up to ``val_count`` random nodes from the remaining ones
        - test: up to ``test_count`` random nodes from what is left after val

        Works for both DGL and PyG graphs via ``self.graph_data``; nothing is
        done when all three masks already exist.

        Raises:
            TypeError: If ``graph_data`` is not a DGL graph or PyG ``Data``.
            ValueError: If no training or no remaining samples are available.
        """
        is_dgl, _ = self._detect_backend()
        if self._masks_exist(is_dgl):
            return

        if is_dgl:
            num_nodes = self.graph_data.num_nodes()
            labels = self.graph_data.ndata['label']
        else:  # PyG
            num_nodes = self.graph_data.num_nodes
            labels = self.graph_data.y

        train_idx, val_idx, test_idx = self._sample_class_split(
            labels, num_nodes, num_class_samples, val_count, test_count, seed
        )
        self._store_masks(
            is_dgl,
            self._index_to_mask(train_idx, num_nodes),
            self._index_to_mask(val_idx, num_nodes),
            self._index_to_mask(test_idx, num_nodes),
        )

    def _sample_class_split(self, labels, num_nodes, num_class_samples, val_count, test_count, seed):
        """Sample disjoint train/val/test index tensors from ``labels``.

        Returns:
            Tuple ``(train_idx, val_idx, test_idx)`` of index tensors.

        Raises:
            ValueError: If no training or no remaining samples are available.
        """
        generator = torch.Generator().manual_seed(seed)
        taken = torch.zeros(num_nodes, dtype=torch.bool)
        num_classes = int(labels.max().item()) + 1

        # train set
        train_parts = []
        print("Training samples per class:")
        for c in range(num_classes):
            class_idx = (labels == c).nonzero(as_tuple=True)[0]
            if class_idx.numel() == 0:
                print(f" class {c}: no available samples")
                continue
            shuffled = class_idx[torch.randperm(class_idx.size(0), generator=generator)]
            n_select = min(num_class_samples, shuffled.size(0))
            chosen = shuffled[:n_select]
            train_parts.append(chosen)
            taken[chosen] = True
            print(f" class {c}: select {n_select} samples")
        if not train_parts:
            raise ValueError("no training samples available.")
        train_idx = torch.cat(train_parts, dim=0)

        # val set
        remaining = (~taken).nonzero(as_tuple=True)[0]
        if remaining.numel() == 0:
            raise ValueError("no remaining samples available.")
        shuffled_rest = remaining[torch.randperm(remaining.size(0), generator=generator)]
        val_take = min(val_count, shuffled_rest.size(0))
        val_idx = shuffled_rest[:val_take]

        # test set: continue along the same permutation so the test nodes are
        # a random sample too, instead of simply the lowest remaining ids.
        test_take = min(test_count, shuffled_rest.size(0) - val_take)
        test_idx = shuffled_rest[val_take:val_take + test_take]
        return train_idx, val_idx, test_idx

    def _index_to_mask(self, index: torch.Tensor, size: int):
        """Return a boolean mask of length ``size`` that is True at ``index``.

        The mask is created on the device of ``index`` when it is a tensor.
        """
        mask = torch.zeros(size, dtype=torch.bool, device=index.device if isinstance(index, torch.Tensor) else None)
        mask[index] = True
        return mask

    def __repr__(self):
        return (f"Dataset(name={self.dataset_name}, api_type={self.api_type}, "
                f"#Nodes={self.num_nodes}, #Features={self.num_features}, "
                f"#Classes={self.num_classes})")
class Cora(Dataset):
    """Cora dataset, loaded via DGL ``citation_graph`` or PyG ``Planetoid``."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type, path)

    def load_dgl_data(self):
        """Load Cora through DGL and keep the first graph as ``graph_data``."""
        # Pass raw_dir so the DGL download honours ``self.path``, matching the
        # other DGL loaders in this module (previously the path was ignored).
        dataset = citation_graph.load_cora(raw_dir=self.path)
        self.graph_dataset = dataset
        self.graph_data = dataset[0]

    def load_pyg_data(self):
        """Load Cora through PyG ``Planetoid`` rooted at ``self.path``."""
        dataset = Planetoid(root=self.path, name='Cora')
        self.graph_dataset = dataset
        self.graph_data = dataset[0]
class CiteSeer(Dataset):
    """CiteSeer dataset, loaded via DGL ``citation_graph`` or PyG ``Planetoid``."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type, path)

    def load_dgl_data(self):
        """Load CiteSeer through DGL and keep the first graph as ``graph_data``."""
        # Pass raw_dir so the DGL download honours ``self.path``, matching the
        # other DGL loaders in this module (previously the path was ignored).
        dataset = citation_graph.load_citeseer(raw_dir=self.path)
        self.graph_dataset = dataset
        self.graph_data = dataset[0]

    def load_pyg_data(self):
        """Load CiteSeer through PyG ``Planetoid`` rooted at ``self.path``."""
        dataset = Planetoid(root=self.path, name='Citeseer')
        self.graph_dataset = dataset
        self.graph_data = dataset[0]
class PubMed(Dataset):
    """PubMed dataset, loaded via DGL ``citation_graph`` or PyG ``Planetoid``."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type, path)

    def load_dgl_data(self):
        """Load PubMed through DGL and keep the first graph as ``graph_data``."""
        # Pass raw_dir so the DGL download honours ``self.path``, matching the
        # other DGL loaders in this module (previously the path was ignored).
        dataset = citation_graph.load_pubmed(raw_dir=self.path)
        self.graph_dataset = dataset
        self.graph_data = dataset[0]

    def load_pyg_data(self):
        """Load PubMed through PyG ``Planetoid`` rooted at ``self.path``."""
        dataset = Planetoid(root=self.path, name='PubMed')
        self.graph_dataset = dataset
        self.graph_data = dataset[0]
class Computers(Dataset):
    """Amazon Computers dataset (DGL ``AmazonCoBuyComputerDataset`` / PyG ``Amazon``)."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Load via DGL, add self-loops to the first graph, then build masks."""
        self.graph_dataset = AmazonCoBuyComputerDataset(raw_dir=self.path)
        self.graph_data = dgl.add_self_loop(self.graph_dataset[0])
        self._generate_masks_by_classes()

    def load_pyg_data(self):
        """Load via PyG ``Amazon(name='Computers')``, then build masks."""
        self.graph_dataset = Amazon(root=self.path, name='Computers')
        self.graph_data = self.graph_dataset[0]
        self._generate_masks_by_classes()
class Photo(Dataset):
    """Amazon Photo dataset (DGL ``AmazonCoBuyPhotoDataset`` / PyG ``Amazon``)."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Load via DGL, add self-loops to the first graph, then build masks."""
        self.graph_dataset = AmazonCoBuyPhotoDataset(raw_dir=self.path)
        self.graph_data = dgl.add_self_loop(self.graph_dataset[0])
        self._generate_masks_by_classes()

    def load_pyg_data(self):
        """Load via PyG ``Amazon(name='Photo')``, then build masks."""
        self.graph_dataset = Amazon(root=self.path, name='Photo')
        self.graph_data = self.graph_dataset[0]
        self._generate_masks_by_classes()
class CoauthorCS(Dataset):
    """Coauthor CS dataset (DGL ``CoauthorCSDataset`` / PyG ``Coauthor``)."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Load via DGL, keep the first graph, then build per-class masks."""
        self.graph_dataset = CoauthorCSDataset(raw_dir=self.path)
        self.graph_data = self.graph_dataset[0]
        self._generate_masks_by_classes()

    def load_pyg_data(self):
        """Load via PyG ``Coauthor(name='CS')``, then build per-class masks."""
        self.graph_dataset = Coauthor(root=self.path, name='CS')
        self.graph_data = self.graph_dataset[0]
        self._generate_masks_by_classes()
class CoauthorPhysics(Dataset):
    """Coauthor Physics dataset (DGL ``CoauthorPhysicsDataset`` / PyG ``Coauthor``)."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Load via DGL, keep the first graph, then build per-class masks."""
        self.graph_dataset = CoauthorPhysicsDataset(raw_dir=self.path)
        self.graph_data = self.graph_dataset[0]
        self._generate_masks_by_classes()

    def load_pyg_data(self):
        """Load via PyG ``Coauthor(name='Physics')``, then build per-class masks."""
        self.graph_dataset = Coauthor(root=self.path, name='Physics')
        self.graph_data = self.graph_dataset[0]
        self._generate_masks_by_classes()
class ENZYMES(Dataset):
    """ENZYMES graph-classification dataset loaded through PyG ``TUDataset``.

    Only the PyG loader is implemented. It exposes ``train_data`` and
    ``test_data`` (lists of per-graph objects) instead of a single
    ``graph_data``.
    """

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type, path)

    def load_pyg_data(self):
        """Load ENZYMES, standardise node features, and split 80/20 by label.

        Features are normalised with the mean/std computed over the nodes of
        all graphs. The split is a single stratified shuffle split with
        ``test_size=0.2`` and ``random_state=42``.
        """
        dataset = TUDataset(root=self.path, name='ENZYMES')
        # Record the underlying dataset like every other loader does;
        # previously ``graph_dataset`` stayed ``None`` for this class.
        self.graph_dataset = dataset
        data_list = list(dataset)

        all_x = torch.cat([d.x for d in data_list], dim=0)
        mean, std = all_x.mean(0), all_x.std(0)
        for d in data_list:
            # Small epsilon avoids division by zero for constant features.
            d.x = (d.x - mean) / (std + 1e-6)

        all_labels = np.array([int(d.y) for d in data_list])
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
        # The splitter only needs the labels; the feature matrix is a dummy.
        train_idx, test_idx = next(splitter.split(np.zeros(len(all_labels)), all_labels))

        self.train_data = [data_list[i] for i in train_idx]
        self.test_data = [data_list[i] for i in test_idx]
class Facebook(Dataset):
    """Facebook Page-Page dataset; only the PyG loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_pyg_data(self):
        """Load PyG ``FacebookPagePage`` and keep its first graph."""
        self.graph_dataset = FacebookPagePage(root=self.path)
        self.graph_data = self.graph_dataset[0]
class Flickr(Dataset):
    """Flickr dataset (DGL ``FlickrDataset`` / PyG ``Flickr``)."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Load DGL ``FlickrDataset`` and keep its first graph."""
        flickr = FlickrDataset(raw_dir=self.path)
        self.graph_dataset, self.graph_data = flickr, flickr[0]

    def load_pyg_data(self):
        """Load PyG ``Flickr`` (imported as ``FlickrPyG``) and keep its first graph."""
        flickr = FlickrPyG(root=self.path)
        first_graph = flickr[0]
        self.graph_dataset, self.graph_data = flickr, first_graph
class PolBlogs(Dataset):
    """PolBlogs dataset; only the PyG loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_pyg_data(self):
        """Load PyG ``PolBlogs``, keep its first graph, then build per-class masks."""
        self.graph_dataset = PolBlogsPyG(root=self.path)
        self.graph_data = self.graph_dataset[0]
        self._generate_masks_by_classes()
class LastFM(Dataset):
    """LastFM Asia dataset; only the PyG loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_pyg_data(self):
        """Load PyG ``LastFMAsia`` and keep its first graph."""
        self.graph_dataset = LastFMAsia(root=self.path)
        self.graph_data = self.graph_dataset[0]
class Reddit(Dataset):
    """Reddit dataset (DGL ``RedditDataset`` / PyG ``Reddit``)."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type, path)

    def load_dgl_data(self):
        """Load DGL ``RedditDataset`` and keep its first graph."""
        dataset = RedditDataset(raw_dir=self.path)
        self.graph_dataset = dataset
        self.graph_data = dataset[0]

    def load_pyg_data(self):
        """Load the PyG ``Reddit`` dataset and keep its first graph."""
        # This class is itself named ``Reddit`` and therefore shadows the
        # module-level ``torch_geometric.datasets.Reddit`` import: calling
        # ``Reddit(self.path)`` here would construct this wrapper again with
        # ``api_type='./data'``. Import the PyG class under an alias instead.
        from torch_geometric.datasets import Reddit as RedditPyG

        dataset = RedditPyG(self.path)
        self.graph_dataset = dataset
        self.graph_data = dataset[0]
class MUTAG(Dataset):
    """MUTAG dataset; only the DGL loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Load DGL ``MUTAGDataset`` and keep its first item as ``graph_data``."""
        mutag = MUTAGDataset(raw_dir=self.path)
        self.graph_dataset, self.graph_data = mutag, mutag[0]
class PTC(Dataset):
    """PTC dataset from DGL ``GINDataset``; only the DGL loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Batch the first 16 graphs of the dataset into ``graph_data``."""
        gin_data = GINDataset(name='PTC', raw_dir=self.path, self_loop=False)
        # Each item unpacks into a (graph, second element) pair; only the
        # graphs are batched.
        samples = [gin_data[i] for i in range(16)]
        graphs, _unused = zip(*samples)
        self.graph_dataset = gin_data
        self.graph_data = dgl.batch(graphs)
class NCI1(Dataset):
    """NCI1 dataset from DGL ``GINDataset``; only the DGL loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Batch the first 16 graphs of the dataset into ``graph_data``."""
        gin_data = GINDataset(name='NCI1', raw_dir=self.path, self_loop=False)
        # Each item unpacks into a (graph, second element) pair; only the
        # graphs are batched.
        samples = [gin_data[i] for i in range(16)]
        graphs, _unused = zip(*samples)
        self.graph_dataset = gin_data
        self.graph_data = dgl.batch(graphs)
class PROTEINS(Dataset):
    """PROTEINS dataset from DGL ``GINDataset``; only the DGL loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Batch the first 16 graphs of the dataset into ``graph_data``."""
        gin_data = GINDataset(name='PROTEINS', raw_dir=self.path, self_loop=False)
        # Each item unpacks into a (graph, second element) pair; only the
        # graphs are batched.
        samples = [gin_data[i] for i in range(16)]
        graphs, _unused = zip(*samples)
        self.graph_dataset = gin_data
        self.graph_data = dgl.batch(graphs)
class Collab(Dataset):
    """COLLAB dataset from DGL ``GINDataset``; only the DGL loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Batch the first 16 graphs of the dataset into ``graph_data``."""
        gin_data = GINDataset(name='COLLAB', raw_dir=self.path, self_loop=False)
        # Each item unpacks into a (graph, second element) pair; only the
        # graphs are batched.
        samples = [gin_data[i] for i in range(16)]
        graphs, _unused = zip(*samples)
        self.graph_dataset = gin_data
        self.graph_data = dgl.batch(graphs)
class IMDB(Dataset):
    """IMDB-BINARY dataset from DGL ``GINDataset``; only the DGL loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Batch the first 16 graphs of the dataset into ``graph_data``."""
        gin_data = GINDataset(name='IMDB-BINARY', raw_dir=self.path, self_loop=False)
        # Each item unpacks into a (graph, second element) pair; only the
        # graphs are batched.
        samples = [gin_data[i] for i in range(16)]
        graphs, _unused = zip(*samples)
        self.graph_dataset = gin_data
        self.graph_data = dgl.batch(graphs)
class YelpData(Dataset):
    """Yelp dataset; only the DGL loader is implemented."""

    def __init__(self, api_type='dgl', path='./data'):
        """Forward ``api_type`` and ``path`` to :class:`Dataset`."""
        super().__init__(api_type=api_type, path=path)

    def load_dgl_data(self):
        """Load DGL ``YelpDataset`` and keep its first graph."""
        yelp = YelpDataset(raw_dir=self.path)
        self.graph_dataset, self.graph_data = yelp, yelp[0]