From 6532d9d4720797bf8d6ed47f997470b6a1aeeaa4 Mon Sep 17 00:00:00 2001
From: Zhong Hui
Date: Tue, 16 Mar 2021 15:01:29 +0800
Subject: [PATCH] [PinSage] Add PinSage model.

---
 examples/pinsage/README.md  |  53 ++++++++++
 examples/pinsage/dataset.py |  58 +++++++++++
 examples/pinsage/model.py   |  50 +++++++++
 examples/pinsage/train.py   | 199 ++++++++++++++++++++++++++++++++++++
 pgl/dataset.py              |   4 +
 pgl/nn/conv.py              |  71 +++++++------
 pgl/sampling/custom.py      |  16 ++-
 pgl/sampling/sage.py        | 151 ++++++++++++++++++++++++++-
 tests/test_sample.py        |  10 +-
 9 files changed, 571 insertions(+), 41 deletions(-)
 create mode 100644 examples/pinsage/README.md
 create mode 100644 examples/pinsage/dataset.py
 create mode 100644 examples/pinsage/model.py
 create mode 100644 examples/pinsage/train.py

diff --git a/examples/pinsage/README.md b/examples/pinsage/README.md
new file mode 100644
index 00000000..e609f9a6
--- /dev/null
+++ b/examples/pinsage/README.md
@@ -0,0 +1,53 @@
+# PinSage
+
+[PinSage](https://arxiv.org/abs/1806.01973) combines efficient random walks and graph convolutions to generate embeddings of nodes (i.e., items) that incorporate both graph structure and node feature information.
+
+
+### Datasets
+The Reddit dataset should be downloaded from the links below and placed in the `pgl.data` directory. Details on the Reddit dataset can be found in the [GraphSAGE paper](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf).
+
+- reddit.npz: https://drive.google.com/open?id=19SphVl_Oe8SJ1r87Hr5a6znx3nJu1F2J
+- reddit_adj.npz: https://drive.google.com/open?id=174vb0Ws7Vxk_QTUtxqTgDHSQ4El4qDHt
+
+
+### Dependencies
+
+- paddlepaddle>=2.0
+- pgl
+
+### How to run
+
+To train a PinSage model on the Reddit dataset, run
+
+```
+python train.py --epoch 10 --normalize --symmetry
+```
+
+To train a PinSage model with multiple GPUs, launch the script with `fleetrun` and set `CUDA_VISIBLE_DEVICES`:
+
+```
+CUDA_VISIBLE_DEVICES=0,1 fleetrun train.py --epoch 10 --normalize --symmetry
+```
+
+
+#### Hyperparameters
+
+- epoch: Number of training epochs (default: 10).
+- normalize: Normalize the input features if set.
+- sample_workers: Number of workers for multiprocess subgraph sampling.
+- lr: Learning rate.
+- symmetry: Make the edges symmetric (i.e., treat the graph as undirected) if set.
+- batch_size: Batch size.
+- samples: Maximum number of neighbors sampled at each hop (default: [30, 20]).
+- top_k: Number of most-visited neighbors to keep when building subgraph edges.
+- hidden_size: Hidden size of the PinSage model.
+
+
+### Performance
+
+We train the model for 10 epochs and report the accuracy on the test dataset.
+
+
+| Aggregator | Accuracy |
+| --- | --- |
+| SUM | 91.36% |
diff --git a/examples/pinsage/dataset.py b/examples/pinsage/dataset.py
new file mode 100644
index 00000000..1de6a7cc
--- /dev/null
+++ b/examples/pinsage/dataset.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
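+"""Data pipeline for the PinSage example.
+
+``batch_fn`` turns a batch of (node id, label) pairs into a PinSage
+subgraph plus the index arrays used to gather input features and
+output predictions; ``ShardedDataset`` shards the training indices
+across dataloader workers.
+"""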
+
+import os
+
+import paddle
+import numpy as np
+from paddle.io import get_worker_info
+
+from pgl import graph_kernel
+from pgl.utils.logger import log
+from pgl.sampling import pinsage_sample
+from pgl.utils.data import Dataset
+
+
+def batch_fn(batch_ex, graph, samples, top_k=50):
+    # Split the batch of (node id, label) pairs.
+    batch_train_samples = []
+    batch_train_labels = []
+    for i, l in batch_ex:
+        batch_train_samples.append(i)
+        batch_train_labels.append(l)
+
+    # pinsage_sample returns one subgraph per layer; this example
+    # trains on the first one.
+    subgraphs = pinsage_sample(
+        graph, batch_train_samples, samples, top_k=top_k)
+    subgraph, sample_index, node_index = subgraphs[0]
+
+    node_label = np.array(batch_train_labels, dtype="int64").reshape([-1, 1])
+
+    return subgraph, sample_index, node_index, node_label
+
+
+class ShardedDataset(Dataset):
+    """Shard the training indices evenly across dataloader workers."""
+
+    def __init__(self, data_index, data_label, mode="train"):
+        worker_info = get_worker_info()
+        if worker_info is None or mode != "train":
+            self.data = [data_index, data_label]
+        else:
+            self.data = [
+                data_index[worker_info.id::worker_info.num_workers],
+                data_label[worker_info.id::worker_info.num_workers]
+            ]
+
+    def __getitem__(self, idx):
+        return [data[idx] for data in self.data]
+
+    def __len__(self):
+        return len(self.data[0])
diff --git a/examples/pinsage/model.py b/examples/pinsage/model.py
new file mode 100644
index 00000000..6b54b565
--- /dev/null
+++ b/examples/pinsage/model.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PinSage Model
+"""
+import pgl
+import paddle.nn as nn
+
+
+class PinSage(nn.Layer):
+    """Implementation of PinSage
+    """
+
+    def __init__(self,
+                 input_size,
+                 num_class,
+                 num_layers=1,
+                 hidden_size=64,
+                 dropout=0.5,
+                 aggr_func="sum"):
+        super(PinSage, self).__init__()
+        self.num_class = num_class
+        self.num_layers = num_layers
+        self.hidden_size = hidden_size
+        self.dropout = dropout
+        self.convs = nn.LayerList()
+        self.linear = nn.Linear(self.hidden_size, self.num_class)
+        for i in range(self.num_layers):
+            self.convs.append(
+                pgl.nn.PinSageConv(
+                    input_size if i == 0 else hidden_size,
+                    hidden_size,
+                    aggr_func=aggr_func))
+
+    def forward(self, graph, feature, weight):
+        for conv in self.convs:
+            feature = conv(graph, feature, weight)
+        feature = self.linear(feature)
+        return feature
diff --git a/examples/pinsage/train.py b/examples/pinsage/train.py
new file mode 100644
index 00000000..f3eeb23c
--- /dev/null
+++ b/examples/pinsage/train.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from functools import partial
+
+import numpy as np
+import tqdm
+import pgl
+import paddle
+from pgl.utils.logger import log
+from pgl.utils.data import Dataloader
+
+from model import PinSage
+from dataset import ShardedDataset, batch_fn
+
+
+def train(dataloader, model, feature, criterion, optim, log_per_step=100):
+    model.train()
+
+    batch = 0
+    total_loss = 0.
+    total_acc = 0.
+    total_sample = 0
+
+    for graph, sample_index, index, label in dataloader:
+        label = label.reshape([-1, 1])
+        batch += 1
+        num_samples = len(index)
+
+        graph.tensor()
+        sample_index = paddle.to_tensor(sample_index)
+        index = paddle.to_tensor(index)
+        label = paddle.to_tensor(label)
+
+        feat = paddle.gather(feature, sample_index)
+        weight = graph.edge_feat["weight"]
+        pred = model(graph, feat, weight)
+        pred = paddle.gather(pred, index)
+        loss = criterion(pred, label)
+        loss.backward()
+        acc = paddle.metric.accuracy(input=pred, label=label, k=1)
+        optim.step()
+        optim.clear_grad()
+
+        total_loss += loss.numpy() * num_samples
+        total_acc += acc.numpy() * num_samples
+        total_sample += num_samples
+
+        if batch % log_per_step == 0:
+            log.info("Batch %s train-Loss %s train-Acc %s", batch,
+                     loss.numpy(), acc.numpy())
+
+    return total_loss / total_sample, total_acc / total_sample
+
+
+@paddle.no_grad()
+def evaluate(dataloader, model, feature, criterion):
+    model.eval()
+    loss_all, acc_all = [], []
+    for graph, sample_index, index, label in dataloader:
+        graph.tensor()
+        sample_index = paddle.to_tensor(sample_index)
+        index = paddle.to_tensor(index)
+        label = paddle.to_tensor(label)
+
+        feat = paddle.gather(feature, sample_index)
+        weight = graph.edge_feat["weight"]
+        pred = model(graph, feat, weight)
+        pred = paddle.gather(pred, index)
+        loss = criterion(pred, label)
+        acc = paddle.metric.accuracy(input=pred, label=label, k=1)
+        loss_all.append(loss.numpy())
+        acc_all.append(acc.numpy())
+
+    return np.mean(loss_all), np.mean(acc_all)
+
+
+def main(args):
+    if paddle.distributed.get_world_size() > 1:
+        paddle.distributed.init_parallel_env()
+
+    data = pgl.dataset.RedditDataset(args.normalize, args.symmetry)
+    #data = pgl.dataset.CoraDataset(args.normalize, args.symmetry)
+    log.info("Preprocessing finished")
+    log.info("Train Examples: %s", len(data.train_index))
+    log.info("Val Examples: %s", len(data.val_index))
+    log.info("Test Examples: %s", len(data.test_index))
+    log.info("Num nodes %s", data.graph.num_nodes)
+    log.info("Num edges %s", data.graph.num_edges)
+    log.info("Average Degree %s", np.mean(data.graph.indegree()))
+
+    graph = data.graph
+    train_index = data.train_index
+    val_index = data.val_index
+    test_index = data.test_index
+
+    train_label = data.train_label
+    val_label = data.val_label
+    test_label = data.test_label
+
+    model = PinSage(
+        input_size=data.feature.shape[-1],
+        num_class=data.num_classes,
+        hidden_size=args.hidden_size,
+        num_layers=len(args.samples),
+        aggr_func=args.aggr_func)
+    if paddle.distributed.get_world_size() > 1:
+        model = paddle.DataParallel(model)
+
+    criterion = paddle.nn.loss.CrossEntropyLoss()
+
+    optim = paddle.optimizer.Adam(
+        learning_rate=args.lr,
+        parameters=model.parameters(),
+        weight_decay=0.001)
+
+    feature = paddle.to_tensor(data.feature)
+
+    train_ds = ShardedDataset(train_index, train_label)
+    val_ds = ShardedDataset(val_index, val_label)
+    test_ds = ShardedDataset(test_index, test_label)
+
+    collate_fn = partial(
+        batch_fn,
+        graph=graph, samples=args.samples, top_k=args.top_k)
+
+    train_loader = Dataloader(
+        train_ds,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.sample_workers,
+        collate_fn=collate_fn)
+    val_loader = Dataloader(
+        val_ds,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.sample_workers,
+        collate_fn=collate_fn)
+    test_loader = Dataloader(
+        test_ds,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.sample_workers,
+        collate_fn=collate_fn)
+
+    cal_val_acc = []
+    cal_test_acc = []
+    cal_val_loss = []
+    for epoch in tqdm.tqdm(range(args.epoch)):
+        train_loss, train_acc = train(train_loader, model, feature, criterion,
+                                      optim)
+        log.info("Running epoch:%s\t train_loss:%s\t train_acc:%s", epoch,
+                 train_loss, train_acc)
+        val_loss, val_acc = evaluate(val_loader, model, feature, criterion)
+        cal_val_acc.append(val_acc)
+        cal_val_loss.append(val_loss)
+        log.info("Running epoch:%s\t val_loss:%s\t val_acc:%s", epoch,
+                 val_loss, val_acc)
+        test_loss, test_acc = evaluate(test_loader, model, feature, criterion)
+        cal_test_acc.append(test_acc)
+        log.info("Running epoch:%s\t test_loss:%s\t test_acc:%s", epoch,
+                 test_loss, test_acc)
+
+    log.info("Runs %s: Model: %s Best Test Accuracy: %f", 0, "pinsage",
+             cal_test_acc[np.argmax(cal_val_acc)])
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='pinsage')
+    parser.add_argument(
+        "--normalize", action='store_true', help="normalize features")
+    parser.add_argument(
+        "--symmetry", action='store_true', help="make the graph undirected")
+    parser.add_argument(
+        "--aggr_func",
+        type=str,
+        default="sum",
+        help="aggregation function: sum, mean, max, or min.")
+    parser.add_argument("--sample_workers", type=int, default=8)
+    parser.add_argument("--epoch", type=int, default=10)
+    parser.add_argument("--hidden_size", type=int, default=128)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--lr", type=float, default=0.01)
+    parser.add_argument('--samples', nargs='+', type=int, default=[30, 20])
+    parser.add_argument("--top_k", type=int, default=200)
+    args = parser.parse_args()
+    log.info(args)
+    main(args)
diff --git a/pgl/dataset.py b/pgl/dataset.py
index 36c63196..cebdda59 100644
--- a/pgl/dataset.py
+++ b/pgl/dataset.py
@@ -235,7 +235,11 @@ def _load_data(self):
         self.val_index = perm[200:500]
         self.test_index = perm[500:1500]
         self.y = np.array(y, dtype="int64")
+        self.train_label = self.y[self.train_index]
+        self.val_label = self.y[self.val_index]
+        self.test_label = self.y[self.test_index]
         self.num_classes = len(y_dict)
+        self.feature = node_feature
 
 
 class BlogCatalogDataset(object):
diff --git a/pgl/nn/conv.py b/pgl/nn/conv.py
index 81aa4d7b..95b195c1 100644
--- a/pgl/nn/conv.py
+++ b/pgl/nn/conv.py
@@ -45,9 +45,9 @@ class GraphSageConv(nn.Layer):
        Advances in neural information processing systems. 2017.
 
     Args:
-
-        input_size: The size of the inputs. 
-
+
+        input_size: The size of the inputs.
+
         hidden_size: The size of outputs
 
         aggr_func: (default "sum") Aggregation function for GraphSage ["sum", "mean", "max", "min"].
@@ -64,16 +64,16 @@ def __init__(self, input_size, hidden_size, aggr_func="sum"):
 
     def forward(self, graph, feature, act=None):
         """
-
+
         Args:
-
+
             graph: `pgl.Graph` instance.
 
             feature: A tensor with shape (num_nodes, input_size)
-
+
             act: (default None) Activation for outputs and before normalize.
-
+
         Return:
 
             A tensor with shape (num_nodes, output_size)
@@ -107,13 +107,13 @@ class PinSageConv(nn.Layer):
 
     Paper reference:
     Ying, Rex, et al.
    "Graph convolutional neural networks for web-scale recommender systems."
-    Proceedings of the 24th ACM SIGKDD International Conference on Knowledge 
+    Proceedings of the 24th ACM SIGKDD International Conference on Knowledge
     Discovery & Data Mining. 2018.
 
     Args:
-
-        input_size: The size of the inputs. 
-
+
+        input_size: The size of the inputs.
+
         hidden_size: The size of outputs
 
         aggr_func: (default "sum") Aggregation function for GraphSage ["sum", "mean", "max", "min"].
@@ -132,16 +132,15 @@ def __init__(self, input_size, hidden_size, aggr_func="sum"):
 
     def forward(self, graph, nfeat, efeat, act=None):
         """
         Args:
-
+
             graph: `pgl.Graph` instance.
 
             nfeat: A tensor with shape (num_nodes, input_size)
 
             efeat: A tensor with shape (num_edges, 1) denotes edge weight.
-
+
             act: (default None) Activation for outputs and before normalize.
 
-
         Return:
 
             A tensor with shape (num_nodes, output_size)
@@ -157,7 +156,7 @@ def _recv_func(message):
             _send_func, src_feat={"h": nfeat}, edge_feat={"w": efeat})
         neigh_feature = graph.recv(reduce_func=_recv_func, msg=msg)
 
-        self_feature = self.self_linear(feature)
+        self_feature = self.self_linear(nfeat)
         neigh_feature = self.neigh_linear(neigh_feature)
         output = self_feature + neigh_feature
         if act is not None:
@@ -175,7 +174,7 @@ class GCNConv(nn.Layer):
 
     Args:
 
-        input_size: The size of the inputs. 
+        input_size: The size of the inputs.
 
         output_size: The size of outputs
@@ -198,15 +197,15 @@ def __init__(self, input_size, output_size, activation=None, norm=True):
 
     def forward(self, graph, feature, norm=None):
         """
-
+
         Args:
-
+
             graph: `pgl.Graph` instance.
 
             feature: A tensor with shape (num_nodes, input_size)
 
             norm: (default None). If :code:`norm` is not None, then the feature will be normalized by given norm. If :code:`norm` is None and :code:`self.norm` is `true`, then we use `laplacian degree norm`.
-
+
         Return:
 
             A tensor with shape (num_nodes, output_size)
@@ -243,7 +242,7 @@ class GATConv(nn.Layer):
 
     Args:
 
-        input_size: The size of the inputs. 
+        input_size: The size of the inputs.
 
         hidden_size: The hidden size for gat.
@@ -310,18 +309,18 @@ def _reduce_attention(self, msg):
 
     def forward(self, graph, feature):
         """
-
+
         Args:
-
+
             graph: `pgl.Graph` instance.
 
             feature: A tensor with shape (num_nodes, input_size)
-
+
         Return:
 
             If `concat=True` then return a tensor with shape (num_nodes, num_heads * hidden_size),
-            else return a tensor with shape (num_nodes, hidden_size) 
+            else return a tensor with shape (num_nodes, hidden_size)
 
         """
@@ -348,7 +347,7 @@ def forward(self, graph, feature):
 
 class APPNP(nn.Layer):
     """Implementation of APPNP of "Predict then Propagate: Graph Neural Networks
-    meet Personalized PageRank"  (ICLR 2019). 
+    meet Personalized PageRank"  (ICLR 2019).
 
     Args:
@@ -367,15 +366,15 @@ def __init__(self, alpha=0.2, k_hop=10):
 
     def forward(self, graph, feature, norm=None):
         """
-
+
         Args:
-
+
             graph: `pgl.Graph` instance.
 
             feature: A tensor with shape (num_nodes, input_size)
 
             norm: (default None). If :code:`norm` is not None, then the feature will be normalized by given norm. If :code:`norm` is None, then we use `laplacian degree norm`.
-
+
         Return:
 
             A tensor with shape (num_nodes, output_size)
@@ -395,7 +394,7 @@ def forward(self, graph, feature, norm=None):
 
 class GCNII(nn.Layer):
-    """Implementation of GCNII of "Simple and Deep Graph Convolutional Networks" 
+    """Implementation of GCNII of "Simple and Deep Graph Convolutional Networks"
 
     paper: https://arxiv.org/pdf/2007.02133.pdf
 
@@ -405,9 +404,9 @@ class GCNII(nn.Layer):
 
         activation: The activation for the output.
         k_hop: Number of layers for gcnii.
-
+
         lambda_l: The hyperparameter of lambda in the paper.
-
+
         alpha: The hyperparameter of alpha in the paper.
 
         dropout: Feature dropout rate.
@@ -438,13 +437,13 @@ def __init__(self,
 
     def forward(self, graph, feature, norm=None):
         """
         Args:
-
+
             graph: `pgl.Graph` instance.
 
             feature: A tensor with shape (num_nodes, input_size)
 
            norm: (default None). If :code:`norm` is not None, then the feature will be normalized by given norm. If :code:`norm` is None, then we use `laplacian degree norm`.
-
+
         Return:
 
             A tensor with shape (num_nodes, output_size)
@@ -656,13 +655,13 @@ def __init__(self,
 
     def forward(self, graph, feature):
         """
-
+
         Args:
-
+
             graph: `pgl.Graph` instance.
 
             feature: A tensor with shape (num_nodes, input_size)
-
+
         Return:
 
             A tensor with shape (num_nodes, output_size)
diff --git a/pgl/sampling/custom.py b/pgl/sampling/custom.py
index 5dfd6b99..8dda4286 100644
--- a/pgl/sampling/custom.py
+++ b/pgl/sampling/custom.py
@@ -24,6 +24,8 @@ def subgraph(graph,
              nodes,
              eid=None,
              edges=None,
+             append_nodes_feat=None,
+             append_edges_feat=None,
              with_node_feat=True,
              with_edge_feat=True):
     """Generate subgraph with nodes and edge ids.
@@ -36,7 +38,8 @@ def subgraph(graph,
         nodes: Node ids which will be included in the subgraph.
         eid (optional): Edge ids which will be included in the subgraph.
         edges (optional): Edge(src, dst) list which will be included in the subgraph.
-
+        append_nodes_feat: Additional node features to append, aligned with `nodes`.
+        append_edges_feat: Additional edge features to append, aligned with the edges.
         with_node_feat: Whether to inherit node features from parent graph.
         with_edge_feat: Whether to inherit edge features from parent graph.
@@ -53,12 +56,17 @@ def subgraph(graph,
     for ind, node in enumerate(nodes):
         reindex[node] = ind
 
-    sub_edge_feat = {}
     if edges is None:
         edges = graph._edges[eid]
     else:
         edges = np.array(edges, dtype="int64")
 
+    sub_edge_feat = {}
+    if append_edges_feat is not None:
+        for k, v in append_edges_feat.items():
+            assert v.shape[0] == edges.shape[0]
+            sub_edge_feat[k] = v
+
     if with_edge_feat:
         for key, value in graph._edge_feat.items():
             if eid is None:
@@ -70,6 +78,10 @@ def subgraph(graph,
         len(edges), dtype="int64"), edges, reindex)
 
     sub_node_feat = {}
+    if append_nodes_feat is not None:
+        for k, v in append_nodes_feat.items():
+            assert v.shape[0] == len(nodes)
+            sub_node_feat[k] = v
     if with_node_feat:
         for key, value in graph._node_feat.items():
             sub_node_feat[key] = value[nodes]
diff --git a/pgl/sampling/sage.py b/pgl/sampling/sage.py
index 0fa9df57..488e0618 100644
--- a/pgl/sampling/sage.py
+++ b/pgl/sampling/sage.py
@@ -22,7 +22,10 @@
 from pgl.sampling.custom import subgraph
 
-__all__ = ['graphsage_sample', ]
+__all__ = [
+    'graphsage_sample',
+    'pinsage_sample',
+]
 
 
 def traverse(item):
@@ -119,3 +122,149 @@ def graphsage_sample(graph, nodes, samples, ignore_edges=[]):
         graph_list.append((sg, sample_index, node_index))
 
     return graph_list
+
+
+def random_walk_with_start_prob(graph, nodes, max_depth, proba=0.5):
+    """Implementation of random walk with a probability of returning to the origin node.
+
+    This function generates random walk paths for the given nodes and depth.
+
+    Args:
+        nodes: Nodes to start the walks from.
+        max_depth: Max walking depth.
+        proba: The probability of returning to the origin node at each step.
+
+    Return:
+        A list of walks.
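+
+    Example (a minimal sketch; assumes ``g`` is a small ``pgl.Graph``):
+
+        walks = random_walk_with_start_prob(g, [0, 1], max_depth=3)
+        # Each walk starts at its origin node; at every step it either
+        # advances to a random successor or, with probability ``proba``,
+        # restarts from the origin before sampling the next successor.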
+    """
+    walk = []
+    # each walk starts with its origin node
+    for node in nodes:
+        walk.append([node])
+
+    walk_ids = np.arange(0, len(nodes))
+    cur_nodes = np.array(nodes)
+    nodes = np.array(nodes)
+    for l in range(max_depth):
+        # with probability `proba`, reset a walk to its origin node
+        if l >= 1:
+            return_proba = np.random.rand(cur_nodes.shape[0])
+            proba_mask = (return_proba < proba)
+            cur_nodes[proba_mask] = nodes[proba_mask]
+        outdegree = graph.outdegree(cur_nodes)
+        mask = (outdegree != 0)
+        if np.any(mask):
+            cur_walk_ids = walk_ids[mask]
+            outdegree = outdegree[mask]
+        else:
+            # no current node has a successor; skip this step (walks may
+            # be reset to their origin nodes on the next iteration)
+            continue
+        succ = graph.successor(cur_nodes[mask])
+        sample_index = np.floor(
+            np.random.rand(outdegree.shape[0]) * outdegree).astype("int64")
+
+        nxt_cur_nodes = cur_nodes
+        for s, ind, walk_id in zip(succ, sample_index, cur_walk_ids):
+            walk[walk_id].append(s[ind])
+            nxt_cur_nodes[walk_id] = s[ind]
+        cur_nodes = np.array(nxt_cur_nodes)
+    return walk
+
+
+def pinsage_sample(graph,
+                   nodes,
+                   samples,
+                   top_k=10,
+                   proba=0.5,
+                   norm_bias=1.0,
+                   ignore_edges=None):
+    """Implementation of PinSage sampling.
+
+    Reference paper: https://arxiv.org/abs/1806.01973
+
+    Args:
+        graph: A pgl graph instance
+        nodes: Sample starting from nodes
+        samples: A list, the number of random-walk steps for each layer
+        top_k: Keep the top_k most-visited nodes to construct the edges
+        proba: The probability of returning to the origin node
+        norm_bias: The normalization bias added to the visit counts
+        ignore_edges: List of edge(src, dst) that will be ignored.
+
+    Return:
+        A list of subgraphs
+    """
+    if ignore_edges is None:
+        ignore_edges = set()
+    num_layers = len(samples)
+    start_nodes = nodes
+    node_index = copy.deepcopy(nodes)
+    edges, weights = [], []
+    layer_nodes, layer_edges, layer_weights = [], [], []
+    ignore_edge_set = {edge_hash(src, dst) for src, dst in ignore_edges}
+
+    for layer_idx in reversed(range(num_layers)):
+        if len(start_nodes) == 0:
+            layer_nodes = [nodes] + layer_nodes
+            layer_edges = [edges] + layer_edges
+            layer_weights = [weights] + layer_weights
+            continue
+        walks = random_walk_with_start_prob(
+            graph, start_nodes, samples[layer_idx], proba=proba)
+        # drop the origin node itself from each walk
+        walks = [walk[1:] for walk in walks]
+        pred_edges = []
+        pred_weights = []
+        pred_nodes = []
+        for node, walk in zip(start_nodes, walks):
+            walk_nodes = []
+            walk_weights = []
+            count_sum = 0
+
+            for random_walk_node in walk:
+                if len(ignore_edge_set) > 0 and random_walk_node != node and \
+                        edge_hash(random_walk_node, node) in ignore_edge_set:
+                    continue
+                walk_nodes.append(random_walk_node)
+
+            # keep the top_k most frequently visited nodes as neighbors
+            unique, counts = np.unique(walk_nodes, return_counts=True)
+            frequencies = np.asarray((unique, counts)).T
+            frequencies = frequencies[np.argsort(frequencies[:, 1])]
+            frequencies = frequencies[-1 * top_k:, :]
+            for random_walk_node, random_count in zip(
+                    frequencies[:, 0].tolist(), frequencies[:, 1].tolist()):
+                pred_nodes.append(random_walk_node)
+                pred_edges.append((random_walk_node, node))
+                walk_weights.append(random_count)
+                count_sum += random_count
+            # normalize the biased visit counts into edge weights
+            count_sum += len(walk_weights) * norm_bias
+            walk_weights = (np.array(walk_weights) + norm_bias) / count_sum
+            pred_weights.extend(walk_weights.tolist())
+        last_node_set = set(nodes)
+        nodes, edges, weights = flat_node_and_edge([nodes, pred_nodes], \
+            [edges, pred_edges], [weights, pred_weights])
+
+        layer_edges = [edges] + layer_edges
+        layer_weights = [weights] + layer_weights
+        layer_nodes = [nodes] + layer_nodes
+
+        start_nodes = list(set(nodes) - last_node_set)
+
+    from_reindex = {x: i for i, x in enumerate(layer_nodes[0])}
+    node_index = graph_kernel.map_nodes(node_index, from_reindex)
+    sample_index = np.array(layer_nodes[0], dtype="int64")
+
+    subgraphs = []
+
+    for i in range(num_layers):
+        edge_feat_dict = {
+            "weight": np.array(
+                layer_weights[i], dtype='float32').reshape([-1, 1])
+        }
+        sg = subgraph(
+            graph,
+            nodes=layer_nodes[0],
+            edges=layer_edges[i],
+            append_edges_feat=edge_feat_dict,
+            with_edge_feat=False)
+        subgraphs.append((sg, sample_index, node_index))
+
+    return subgraphs
diff --git a/tests/test_sample.py b/tests/test_sample.py
index 8b1710c7..929a8e4a 100644
--- a/tests/test_sample.py
+++ b/tests/test_sample.py
@@ -18,7 +18,7 @@
 import numpy as np
 import paddle
 import pgl
-from pgl.sampling import graphsage_sample
+from pgl.sampling import graphsage_sample, pinsage_sample
 from pgl.sampling import random_walk
 from testsuite import create_random_graph
 
@@ -34,7 +34,13 @@ def test_graphsage_sample(self):
         graph = create_random_graph()
         nodes = [1, 2, 3]
         np.random.seed(1)
-        subgraphs = graphsage_sample(graph, nodes, [10, 10], [])
+        subgraphs = graphsage_sample(graph, nodes, [10, 10])
+
+    def test_pinsage_sample(self):
+        graph = create_random_graph()
+        nodes = [1, 2, 3]
+        np.random.seed(1)
+        subgraphs = pinsage_sample(graph, nodes, [10, 10], top_k=5)
 
     def test_random_walk(self):
         num_nodes = 5
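
A quick usage sketch of the sampling API added above (a minimal example, not part of the patch; it reuses the `create_random_graph` helper already imported in `tests/test_sample.py`):

```
import numpy as np
from pgl.sampling import pinsage_sample
from testsuite import create_random_graph

np.random.seed(1)
graph = create_random_graph()

# pinsage_sample returns one (subgraph, sample_index, node_index) tuple
# per entry in `samples`. All tuples share the same node set; each
# subgraph carries its normalized random-walk visit counts as a
# "weight" edge feature of shape (num_edges, 1), ready for PinSageConv.
subgraphs = pinsage_sample(graph, nodes=[1, 2, 3], samples=[10, 10], top_k=5)
sg, sample_index, node_index = subgraphs[0]
print(sg.edge_feat["weight"].shape)
```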