RedditData

Implementation for the Reddit network dataset.

A graph dataset representing Reddit posts, where nodes are posts and edges indicate that the same user commented on both posts.
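The edge definition above can be illustrated on toy data. The sketch below is only an illustration of the "same user commented on both posts" rule. The function name `build_post_graph` and the sample comments are hypothetical, and this is not how the published dataset was constructed.

```python
from collections import defaultdict
from itertools import combinations


def build_post_graph(comments):
    """Return undirected post-post edges from (user, post) comment pairs.

    Hypothetical helper for illustration only.
    """
    posts_by_user = defaultdict(set)
    for user, post in comments:
        posts_by_user[user].add(post)

    edges = set()
    for posts in posts_by_user.values():
        # Every pair of posts that share a commenter becomes one edge.
        for a, b in combinations(sorted(posts), 2):
            edges.add((a, b))
    return sorted(edges)


# Toy (user, post) pairs, invented for this example.
comments = [
    ("alice", "p1"), ("alice", "p2"),
    ("bob", "p2"), ("bob", "p3"),
    ("carol", "p1"), ("carol", "p2"),
]
edges = build_post_graph(comments)
print(edges)
```

Two users commenting on the same pair of posts still yields a single edge, because the edges are collected in a set.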

Parameters

api_type: str, optional

Backend library used to load the dataset, either ‘dgl’ or ‘torch_geometric’. Defaults to ‘dgl’.

path: str, optional

Dataset directory path, defaults to ‘./downloads/’.
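As a rough sketch of how these two parameters might interact with the loader methods listed below, the following stand-in class selects a loader from `api_type`. Everything here is an assumption made for illustration: the class name `RedditDataSketch`, the `loaded_with` attribute, the validation error, and the choice to call the loader from the constructor are not taken from the actual implementation, and the loader bodies are placeholders that do not touch DGL or PyG.

```python
class RedditDataSketch:
    """Hypothetical stand-in for illustration; not the library's class."""

    SUPPORTED_APIS = ("dgl", "torch_geometric")

    def __init__(self, api_type="dgl", path="./downloads/"):
        # Assumed behaviour: reject unknown backends early.
        if api_type not in self.SUPPORTED_APIS:
            raise ValueError(
                f"api_type must be one of {self.SUPPORTED_APIS}, got {api_type!r}"
            )
        self.api_type = api_type
        self.path = path
        self.dataset_name = "reddit"

        # Assumed behaviour: pick the loader that matches the backend.
        loaders = {
            "dgl": self.load_dgl_data,
            "torch_geometric": self.load_tg_data,
        }
        loaders[api_type]()

    def load_dgl_data(self):
        # Placeholder: a real loader would read the dataset from self.path.
        self.loaded_with = "dgl"

    def load_tg_data(self):
        # Placeholder: a real loader would read the dataset from self.path.
        self.loaded_with = "torch_geometric"


default = RedditDataSketch()
pyg = RedditDataSketch(api_type="torch_geometric", path="./data/")
print(default.loaded_with, default.path)
print(pyg.loaded_with, pyg.path)
```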

Methods

load_dgl_data()

Load the Reddit dataset using DGL’s RedditDataset loader. Uses the predefined node features, labels, and training, validation, and test splits.

load_tg_data()

Load the Reddit dataset using the PyTorch Geometric (PyG) loader. Extracts the node features, node labels, and training, validation, and test masks.

Attributes

graph: dgl.DGLGraph

The graph data structure (when using DGL).

dataset_name: str

Name of the dataset (“reddit”).

node_number: int

Number of nodes in the graph.

feature_number: int

Dimension of node features.

label_number: int

Number of unique node labels.

features: torch.Tensor

Node feature matrix.

labels: torch.Tensor

Node label tensor.

train_mask: torch.Tensor

Boolean mask indicating training nodes.

val_mask: torch.Tensor

Boolean mask indicating validation nodes.

test_mask: torch.Tensor

Boolean mask indicating testing nodes.

dataset: torch_geometric.datasets.Reddit

The PyG dataset object (available only in the PyG implementation).

data: torch_geometric.data.Data

The PyG data object (available only in the PyG implementation).

edge_index: torch.Tensor

Edge list representation (set explicitly only in the PyG implementation).
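The relationship between these attributes can be sketched on a four-node toy graph. The example uses plain Python lists in place of `torch.Tensor` so that it runs without any dependencies. All values are invented, and reading the counts off the data this way is an assumption about how they are derived, not a statement about the implementation. The two-row layout of `edge_index` follows the PyG convention of one row of source nodes and one row of target nodes.

```python
# Toy stand-ins for the tensors described above (plain lists, not torch.Tensor).
features = [
    [0.1, 0.5, 0.0],
    [0.3, 0.2, 0.9],
    [0.7, 0.1, 0.4],
    [0.0, 0.8, 0.6],
]
labels = [2, 0, 2, 1]
train_mask = [True, True, False, False]
val_mask = [False, False, True, False]
test_mask = [False, False, False, True]
# PyG-style edge list: row 0 holds source nodes, row 1 holds target nodes.
edge_index = [
    [0, 1, 1, 2],
    [1, 0, 2, 1],
]

# Assumed derivation of the three counts from the data.
node_number = len(features)        # one feature row per node
feature_number = len(features[0])  # length of each feature row
label_number = len(set(labels))    # number of distinct labels

# A boolean mask selects the nodes (and their labels) of one split.
train_nodes = [i for i, keep in enumerate(train_mask) if keep]
train_labels = [labels[i] for i in train_nodes]

# In this toy data no node belongs to more than one split.
splits_are_disjoint = all(
    sum(flags) <= 1 for flags in zip(train_mask, val_mask, test_mask)
)

# Pair up the two rows to recover (source, target) edges.
edges = list(zip(edge_index[0], edge_index[1]))

print(node_number, feature_number, label_number)
print(train_nodes, train_labels, splits_are_disjoint)
print(edges)
```

With real tensors the same quantities would typically come from `features.shape` and from boolean indexing such as `labels[train_mask]`.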