RedditData
Implementation of the Reddit network dataset.
A graph dataset of Reddit posts: each node is a post, and an edge connects two posts when the same user commented on both.
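The edge rule above can be sketched in a few lines of Python. This is a minimal illustration only, assuming a hypothetical input of `(user_id, post_id)` comment records; the class documented here loads a prebuilt graph through DGL or PyG rather than building edges from raw comments. The function name `build_post_graph` is hypothetical.

```python
from collections import defaultdict
from itertools import combinations


def build_post_graph(comments):
    """Build undirected post-post edges from (user_id, post_id) pairs.

    Hypothetical helper: two posts are linked when at least one user
    commented on both. Each undirected edge is stored once as (a, b), a < b.
    """
    posts_by_user = defaultdict(set)
    for user, post in comments:
        posts_by_user[user].add(post)

    edges = set()
    for posts in posts_by_user.values():
        # Every pair of posts sharing this commenter becomes an edge.
        for a, b in combinations(sorted(posts), 2):
            edges.add((a, b))
    return edges


comments = [("u1", "p1"), ("u1", "p2"), ("u2", "p2"), ("u2", "p3"), ("u3", "p1")]
print(sorted(build_post_graph(comments)))
```

User `u3` commented on a single post, so it contributes no edge.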
Parameters
- api_type: str, optional
Backend API to use, either ‘dgl’ or ‘torch_geometric’. Defaults to ‘dgl’.
- path: str, optional
Directory where the dataset is stored. Defaults to ‘./downloads/’.
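The `api_type` parameter selects between the two loader methods. The sketch below shows one way such a dispatch could work; the `select_loader` helper and the choice to raise `ValueError` for an unknown backend are assumptions, since the source does not say how invalid values are handled.

```python
# Hypothetical mapping from api_type to the loader method it selects.
LOADERS = {"dgl": "load_dgl_data", "torch_geometric": "load_tg_data"}


def select_loader(api_type="dgl"):
    """Return the name of the loader method for api_type (hypothetical helper).

    Assumption: unknown backends are rejected with ValueError.
    """
    try:
        return LOADERS[api_type]
    except KeyError:
        raise ValueError(
            f"api_type must be one of {sorted(LOADERS)}, got {api_type!r}"
        ) from None


print(select_loader())                   # default backend
print(select_loader("torch_geometric"))  # PyG backend
```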
Methods
- load_dgl_data()
Load the Reddit dataset with DGL’s RedditDataset loader, using its predefined node features, labels, and training, validation, and test splits.
- load_tg_data()
Load the Reddit dataset with the PyTorch Geometric (PyG) Reddit loader and extract the node features, node labels, and training, validation, and test masks.
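The two loaders can be sketched as standalone functions built on the public DGL and PyG dataset classes. This is a sketch under stated assumptions, not the class's actual code: the function names are hypothetical, returning a dict is an illustrative choice, and `label_number` is taken from each library's `num_classes`. Both downloads are large, and whether the real class passes `path` to the libraries unchanged or uses a subdirectory is not stated in the source. The backend imports sit inside each function so that only the selected library needs to be installed.

```python
def load_dgl_reddit(path="./downloads/"):
    """Hypothetical sketch of load_dgl_data() using dgl.data.RedditDataset."""
    from dgl.data import RedditDataset  # imported lazily: DGL is optional

    dataset = RedditDataset(raw_dir=path)  # downloads on first use
    g = dataset[0]                         # the dataset holds a single graph
    feat = g.ndata["feat"]
    return {
        "graph": g,
        "node_number": g.num_nodes(),
        "feature_number": feat.shape[1],
        "label_number": dataset.num_classes,
        "features": feat,
        "labels": g.ndata["label"],
        "train_mask": g.ndata["train_mask"],
        "val_mask": g.ndata["val_mask"],
        "test_mask": g.ndata["test_mask"],
    }


def load_pyg_reddit(path="./downloads/"):
    """Hypothetical sketch of load_tg_data() using torch_geometric.datasets.Reddit."""
    from torch_geometric.datasets import Reddit  # imported lazily: PyG is optional

    dataset = Reddit(root=path)  # downloads and processes on first use
    data = dataset[0]            # a single torch_geometric.data.Data object
    return {
        "dataset": dataset,
        "data": data,
        "node_number": data.num_nodes,
        "feature_number": dataset.num_features,
        "label_number": dataset.num_classes,
        "features": data.x,
        "labels": data.y,
        "train_mask": data.train_mask,
        "val_mask": data.val_mask,
        "test_mask": data.test_mask,
        "edge_index": data.edge_index,  # shape [2, num_edges]
    }
```

Only the PyG sketch returns `edge_index`; in DGL the edges live inside the `DGLGraph` itself.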
Attributes
- graph: dgl.DGLGraph
The graph data structure (when using DGL).
- dataset_name: str
Name of the dataset (“reddit”).
- node_number: int
Number of nodes in the graph.
- feature_number: int
Dimension of node features.
- label_number: int
Number of unique node labels.
- features: torch.Tensor
Node feature matrix.
- labels: torch.Tensor
Node label tensor.
- train_mask: torch.Tensor
Boolean mask indicating training nodes.
- val_mask: torch.Tensor
Boolean mask indicating validation nodes.
- test_mask: torch.Tensor
Boolean mask indicating testing nodes.
- dataset: torch_geometric.datasets.Reddit
The PyG dataset object (only available in the PyG implementation).
- data: torch_geometric.data.Data
The PyG data object (only available in the PyG implementation).
- edge_index: torch.Tensor
Edge list in COO format, a tensor of shape [2, num_edges] (only set explicitly in the PyG implementation).
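The mask and count attributes relate to each other in a simple way: each mask has one entry per node, and a boolean mask selects that split's nodes (in PyTorch, `labels[train_mask]`). The plain-Python sketch below mirrors that on toy lists so it runs without PyTorch; the helper names are hypothetical, and the check that the splits are disjoint is an illustrative assumption rather than something the source states.

```python
def select(values, mask):
    """List analogue of PyTorch boolean indexing, values[mask] (hypothetical helper)."""
    if len(values) != len(mask):
        raise ValueError("mask length must equal the number of nodes")
    return [v for v, keep in zip(values, mask) if keep]


def split_summary(labels, train_mask, val_mask, test_mask):
    """Derive node/label counts and split sizes from labels and masks (hypothetical)."""
    n = len(labels)
    masks = {"train": train_mask, "val": val_mask, "test": test_mask}
    for name, m in masks.items():
        if len(m) != n:
            raise ValueError(f"{name}_mask has {len(m)} entries, expected {n}")
    # A node is in more than one split if more than one mask is True at its index.
    overlap = any(sum(flags) > 1 for flags in zip(train_mask, val_mask, test_mask))
    summary = {"node_number": n, "label_number": len(set(labels))}
    summary.update({f"{name}_count": sum(m) for name, m in masks.items()})
    summary["disjoint"] = not overlap
    return summary


labels = [0, 1, 1, 2, 0]
train = [True, True, False, False, False]
val = [False, False, True, False, False]
test = [False, False, False, True, True]
print(select(labels, train))
print(split_summary(labels, train, val, test))
```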