Skip to content

[Model] [PinSage] Add pinsage model. #203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions examples/pinsage/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# PinSage

[PinSage](https://arxiv.org/abs/1806.01973) combines efficient random walks and graph convolutions to generate embeddings of nodes (i.e., items) that incorporate both graph structure as well as node feature information.


### Datasets
The reddit dataset should be downloaded from the following links and placed in the directory ```pgl.data```. The details for Reddit Dataset can be found [here](https://cs.stanford.edu/people/jure/pubs/pinsage-nips17.pdf).

- reddit.npz https://drive.google.com/open?id=19SphVl_Oe8SJ1r87Hr5a6znx3nJu1F2J
- reddit_adj.npz: https://drive.google.com/open?id=174vb0Ws7Vxk_QTUtxqTgDHSQ4El4qDHt


### Dependencies

- paddlepaddle>=2.0
- pgl

### How to run

To train a PinSage model on Reddit Dataset, you can just run

```
python train.py --epoch 10 --normalize --symmetry
```

If you want to train a PinSage model with multiple GPUs, you can just run with fleetrun API with `CUDA_VISIBLE_DEVICES`

```
CUDA_VISIBLE_DEVICES=0,1 fleetrun train.py --epoch 10 --normalize --symmetry
```


#### Hyperparameters

- epoch: Number of epochs default (10)
- normalize: Normalize the input feature if assign normalize.
- sample_workers: The number of workers for multiprocessing subgraph sample.
- lr: Learning rate.
- symmetry: Make the edges symmetric if assign symmetry.
- batch_size: Batch size.
- samples: The max neighbors for each layers hop neighbor sampling. (default: [30, 20])
- top_k: the top k nodes should be reseved.
- hidden_size: The hidden size of the PinSage models.


### Performance

We train our models for 10 epochs and report the accuracy on the test dataset.


| Aggregator | Accuracy |
| --- | --- |
| SUM | 91.36% |
58 changes: 58 additions & 0 deletions examples/pinsage/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import numpy as np
from paddle.io import get_worker_info

from pgl import graph_kernel
from pgl.utils.logger import log
from pgl.sampling import pinsage_sample
from pgl.utils.data import Dataset


def batch_fn(batch_ex, graph, samples, top_k=50):
batch_train_samples = []
batch_train_labels = []
for i, l in batch_ex:
batch_train_samples.append(i)
batch_train_labels.append(l)

subgraphs = pinsage_sample(
graph, batch_train_samples, samples, top_k=top_k)
subgraph, sample_index, node_index = subgraphs[0]

node_label = np.array(batch_train_labels, dtype="int64").reshape([-1, 1])

return subgraph, sample_index, node_index, node_label


class ShardedDataset(Dataset):
def __init__(self, data_index, data_label, mode="train"):
worker_info = get_worker_info()
if worker_info is None or mode != "train":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ShardedDataset这块好像有问题。

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_worker_info()之前测试一直为None

self.data = [data_index, data_label]
else:
self.data = [
data_index[worker_info.id::worker_info.num_workers],
data_label[worker_info.id::worker_info.num_workers]
]

def __getitem__(self, idx):
return [data[idx] for data in self.data]

def __len__(self):
return len(self.data[0])
50 changes: 50 additions & 0 deletions examples/pinsage/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PinSage Model
"""
import pgl
import paddle.nn as nn


class PinSage(nn.Layer):
"""Implement of PinSage
"""

def __init__(self,
input_size,
num_class,
num_layers=1,
hidden_size=64,
dropout=0.5,
aggr_func="sum"):
super(PinSage, self).__init__()
self.num_class = num_class
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout = dropout
self.convs = nn.LayerList()
self.linear = nn.Linear(self.hidden_size, self.num_class)
for i in range(self.num_layers):
self.convs.append(
pgl.nn.PinSageConv(
input_size if i == 0 else hidden_size,
hidden_size,
aggr_func=aggr_func))

def forward(self, graph, feature, weight):
for conv in self.convs:
feature = conv(graph, feature, weight)
feature = self.linear(feature)
return feature
199 changes: 199 additions & 0 deletions examples/pinsage/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from functools import partial

import numpy as np
import tqdm
import pgl
import paddle
from pgl.utils.logger import log
from pgl.utils.data import Dataloader

from model import PinSage
from dataset import ShardedDataset, batch_fn


def train(dataloader, model, feature, criterion, optim, log_per_step=100):
model.train()

batch = 0
total_loss = 0.
total_acc = 0.
total_sample = 0

for graph, sample_index, index, label in dataloader:
label = label.reshape([-1, 1])
batch += 1
num_samples = len(index)

graph.tensor()
sample_index = paddle.to_tensor(sample_index)
index = paddle.to_tensor(index)
label = paddle.to_tensor(label)

feat = paddle.gather(feature, sample_index)
weight = graph.edge_feat["weight"]
pred = model(graph, feat, weight)
pred = paddle.gather(pred, index)
loss = criterion(pred, label)
loss.backward()
acc = paddle.metric.accuracy(input=pred, label=label, k=1)
optim.step()
optim.clear_grad()

total_loss += loss.numpy() * num_samples
total_acc += acc.numpy() * num_samples
total_sample += num_samples

if batch % log_per_step == 0:
log.info("Batch %s %s-Loss %s %s-Acc %s", batch, "train",
loss.numpy(), "train", acc.numpy())

return total_loss / total_sample, total_acc / total_sample


@paddle.no_grad()
def evaluate(dataloader, model, feature, criterion):
model.eval()
loss_all, acc_all = [], []
for graph, sample_index, index, label in dataloader:
graph.tensor()
sample_index = paddle.to_tensor(sample_index)
index = paddle.to_tensor(index)
label = paddle.to_tensor(label)

feat = paddle.gather(feature, sample_index)
weight = graph.edge_feat["weight"]
pred = model(graph, feat, weight)
pred = paddle.gather(pred, index)
loss = criterion(pred, label)
acc = paddle.metric.accuracy(input=pred, label=label, k=1)
loss_all.append(loss.numpy())
acc_all.append(acc.numpy())

return np.mean(loss_all), np.mean(acc_all)


def main(args):
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()

data = pgl.dataset.RedditDataset(args.normalize, args.symmetry)
#data = pgl.dataset.CoraDataset(args.normalize, args.symmetry)
log.info("Preprocess finish")
log.info("Train Examples: %s", len(data.train_index))
log.info("Val Examples: %s", len(data.val_index))
log.info("Test Examples: %s", len(data.test_index))
log.info("Num nodes %s", data.graph.num_nodes)
log.info("Num edges %s", data.graph.num_edges)
log.info("Average Degree %s", np.mean(data.graph.indegree()))

graph = data.graph
train_index = data.train_index
val_index = data.val_index
test_index = data.test_index

train_label = data.train_label
val_label = data.val_label
test_label = data.test_label

model = PinSage(
input_size=data.feature.shape[-1],
num_class=data.num_classes,
hidden_size=args.hidden_size,
num_layers=len(args.samples),
aggr_func=args.aggr_func)
if paddle.distributed.get_world_size() > 1:
model = paddle.DataParallel(model)

criterion = paddle.nn.loss.CrossEntropyLoss()

optim = paddle.optimizer.Adam(
learning_rate=args.lr,
parameters=model.parameters(),
weight_decay=0.001)

feature = paddle.to_tensor(data.feature)

train_ds = ShardedDataset(train_index, train_label)
val_ds = ShardedDataset(val_index, val_label)
test_ds = ShardedDataset(test_index, test_label)

collate_fn = partial(
batch_fn, graph=graph, samples=args.samples, top_k=args.top_k)

train_loader = Dataloader(
train_ds,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.sample_workers,
collate_fn=collate_fn)
val_loader = Dataloader(
val_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.sample_workers,
collate_fn=collate_fn)
test_loader = Dataloader(
test_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.sample_workers,
collate_fn=collate_fn)

cal_val_acc = []
cal_test_acc = []
cal_val_loss = []
for epoch in tqdm.tqdm(range(args.epoch)):
train_loss, train_acc = train(train_loader, model, feature, criterion,
optim)
log.info("Runing epoch:%s\t train_loss:%s\t train_acc:%s", epoch,
train_loss, train_acc)
val_loss, val_acc = evaluate(val_loader, model, feature, criterion)
cal_val_acc.append(val_acc)
cal_val_loss.append(val_loss)
log.info("Runing epoch:%s\t val_loss:%s\t val_acc:%s", epoch, val_loss,
val_acc)
test_loss, test_acc = evaluate(test_loader, model, feature, criterion)
cal_test_acc.append(test_acc)
log.info("Runing epoch:%s\t test_loss:%s\t test_acc:%s", epoch,
test_loss, test_acc)

log.info("Runs %s: Model: %s Best Test Accuracy: %f", 0, "pinsage",
cal_test_acc[np.argmax(cal_val_acc)])


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='pinsage')
parser.add_argument(
"--normalize", action='store_true', help="normalize features")
parser.add_argument(
"--symmetry", action='store_true', help="undirect graph")
parser.add_argument(
"--aggr_func",
type=str,
default="sum",
help="aggregate function, sum, mean, max, min available.")
parser.add_argument("--sample_workers", type=int, default=8)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--hidden_size", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument('--samples', nargs='+', type=int, default=[30, 20])
parser.add_argument("--top_k", type=int, default=200)
args = parser.parse_args()
log.info(args)
main(args)
4 changes: 4 additions & 0 deletions pgl/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,11 @@ def _load_data(self):
self.val_index = perm[200:500]
self.test_index = perm[500:1500]
self.y = np.array(y, dtype="int64")
self.train_label = self.y[self.train_index]
self.val_label = self.y[self.val_index]
self.test_label = self.y[self.test_index]
self.num_classes = len(y_dict)
self.feature = node_feature


class BlogCatalogDataset(object):
Expand Down
Loading