Hi, I am XING LI, a researcher from Sansan DSOC Data Analysis Group.
This is the Day 9 article of the Sansan Advent Calendar 2020.
Last time, we talked about some common tasks in deep graph learning and built a toy network for the Node Classification task as a demo. We learned how to build, train and test a simple deep graph neural network with DGL. However, the graph dataset we used, the Cora dataset, has only 2708 nodes and 5429 edges, so the GNN's aggregation operation does not take long on such a tiny graph. In reality, we also face much larger graphs, for example, graphs with over a million nodes and a billion edges. Consider an $L$-layer GCN with hidden state size $H$, trained on an $N$-node graph. Storing the intermediate hidden states requires $O(NLH)$ memory, which easily exceeds one GPU's capacity when $N$ is large. So today we are going to talk about how to train a deep GNN on large graphs with the help of DGL.
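To make the $O(NLH)$ estimate concrete, here is a small back-of-the-envelope calculation. It is only a sketch: it assumes 32-bit floats (4 bytes per value) and counts nothing but the stored hidden states, ignoring weights, gradients, optimizer state and the graph structure itself. The 2-layer, hidden-size-16 setting for Cora and the 100-million-node configuration are hypothetical examples, not measurements.

```python
def hidden_state_bytes(num_nodes: int, num_layers: int, hidden_size: int,
                       bytes_per_value: int = 4) -> int:
    """Memory needed to keep every layer's hidden states for all nodes, i.e. O(NLH).

    Assumes float32 values by default; weights, gradients and optimizer
    state are deliberately not counted.
    """
    return num_nodes * num_layers * hidden_size * bytes_per_value


def to_gb(num_bytes: int) -> float:
    """Convert bytes to decimal gigabytes (10**9 bytes)."""
    return num_bytes / 1e9


# Cora-sized graph with a hypothetical 2-layer GCN and hidden size 16.
cora = hidden_state_bytes(num_nodes=2708, num_layers=2, hidden_size=16)
# Hypothetical large graph: 100 million nodes, 3 layers, hidden size 256.
large = hidden_state_bytes(num_nodes=100_000_000, num_layers=3, hidden_size=256)

print(f"Cora-sized: {cora} bytes")            # Cora-sized: 346624 bytes
print(f"Large graph: {to_gb(large):.1f} GB")  # Large graph: 307.2 GB
```

The first number is a few hundred kilobytes, while the second is far beyond the memory of a single GPU, which is exactly the situation that mini-batch training is meant to avoid.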
The answer is simple: borrow stochastic mini-batch training, the method widely used in CV and NLP to deal with large datasets. This is why we talked about the
NodeFlow data structure in Series 2.
Recall the neighbourhood sampling approaches
To understand how we generate this so-called "mini-batch" subgraph from one large graph, recall two figures from Series 2:
Neighbourhood sampling methods generally work as shown in the two figures above. For example, in each gradient descent step, we select a mini-batch of nodes whose final representations at the $L$-th layer are to be computed. We then take all or some of their neighbours at the $(L-1)$-th layer, depending on the sampling policy. This process continues until we reach the input layer. This iterative process builds the dependency graph starting from the output and working backwards to the input. As a result, each step only computes representations for the nodes the mini-batch actually depends on, which saves memory and computation when training a GNN on large graphs.
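The backward construction described above can be sketched in plain Python, without DGL. The sketch assumes the graph is given as a hypothetical in-neighbour adjacency dict (`in_neighbors[v]` lists the nodes that send messages to `v`) and a per-layer fanout list; `sample_blocks` and `minibatches` are illustrative helpers, not DGL's API. Starting from the seed nodes of the output layer, it samples up to `fanout` in-neighbours per node, then uses the frontier plus the sampled neighbours as the next frontier, until it reaches the input layer.

```python
import random
from typing import Dict, List, Sequence, Tuple

# One "block": (destination nodes, source nodes, sampled edges src -> dst).
Block = Tuple[List[int], List[int], List[Tuple[int, int]]]


def sample_blocks(in_neighbors: Dict[int, List[int]],
                  seeds: Sequence[int],
                  fanouts: Sequence[int],
                  rng: random.Random) -> List[Block]:
    """Build the per-layer dependency graph backwards from the output layer.

    fanouts[i] is the max number of neighbours sampled for layer i+1
    (fanouts[-1] belongs to the output layer). Returns blocks ordered from
    the input layer to the output layer, the order a GNN would consume them.
    """
    blocks: List[Block] = []
    frontier = list(dict.fromkeys(seeds))  # deduplicate, keep order
    for fanout in reversed(fanouts):
        edges: List[Tuple[int, int]] = []
        for dst in frontier:
            nbrs = in_neighbors.get(dst, [])
            # Take all neighbours if there are few, otherwise sample without replacement.
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            edges.extend((src, dst) for src in picked)
        # Source nodes: the destinations first, then newly reached neighbours.
        srcs = list(dict.fromkeys(frontier + [s for s, _ in edges]))
        blocks.append((frontier, srcs, edges))
        frontier = srcs  # the previous layer must compute these nodes
    blocks.reverse()
    return blocks


def minibatches(nodes: Sequence[int], batch_size: int, rng: random.Random):
    """Yield shuffled mini-batches of seed nodes for one epoch."""
    order = list(nodes)
    rng.shuffle(order)
    for i in range(0, len(order), batch_size):
        yield order[i:i + batch_size]


if __name__ == "__main__":
    # A hypothetical toy graph: in-neighbour lists for 6 nodes.
    graph = {0: [1, 2, 3], 1: [0, 4], 2: [0, 5], 3: [0], 4: [1, 5], 5: [2, 4]}
    rng = random.Random(0)
    for batch in minibatches([0, 1, 2, 3, 4, 5], batch_size=2, rng=rng):
        blocks = sample_blocks(graph, batch, fanouts=[2, 2], rng=rng)
        for layer, (dsts, srcs, edges) in enumerate(blocks, start=1):
            print(f"batch={batch} layer={layer} dst={dsts} src={srcs} edges={edges}")
```

Each mini-batch of seeds only touches the nodes inside its sampled blocks, so the memory cost depends on the batch size and fanouts rather than on the size of the whole graph.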
In Series 2, I also included sample code demonstrating the usage of the
NeighborSampler function. However, since DGL 0.5.x,
NeighborSampler has been deprecated, and DGL now offers a mature sampling package,
dgl.sampling. As the old saying goes, "Don't reinvent the wheel!", so let's explore this new package.
The new package dgl.sampling
Here is the package's link.
"The dgl.sampling package contains operators and utilities for sampling from a graph via random walks, neighbor sampling, etc."
What surprised me is that the new package implements not only the basic
random walk and
sample_neighbors operators, but also the PinSAGE-style neighbourhood sampling method.
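The idea behind PinSAGE-style neighbourhood sampling is to run several short random walks from a node and treat the most frequently visited nodes as its "important" neighbours, instead of taking direct neighbours uniformly. Below is a simplified pure-Python sketch of that idea on a homogeneous graph. It is not DGL's implementation: `pinsage_style_neighbors` and its parameters (`num_walks`, `walk_length`, `restart_prob`, `top_k`) are hypothetical names chosen for illustration, whereas PinSAGE itself was designed for bipartite graphs (for example, items and users).

```python
import random
from collections import Counter
from typing import Dict, List


def pinsage_style_neighbors(adj: Dict[int, List[int]], node: int,
                            num_walks: int, walk_length: int,
                            restart_prob: float, top_k: int,
                            rng: random.Random) -> List[int]:
    """Pick the top_k nodes most often visited by random walks from `node`.

    Each walk takes at most `walk_length` steps and stops early with
    probability `restart_prob` after every step (the next walk restarts at `node`).
    """
    visits: Counter = Counter()
    for _ in range(num_walks):
        current = node
        for _ in range(walk_length):
            nbrs = adj.get(current, [])
            if not nbrs:
                break  # dead end: end this walk
            current = rng.choice(nbrs)
            if current != node:
                visits[current] += 1  # the start node is not its own neighbour
            if rng.random() < restart_prob:
                break  # restart: the next walk begins again from `node`
    # most_common sorts by visit count; ties keep first-seen order.
    return [v for v, _ in visits.most_common(top_k)]


if __name__ == "__main__":
    # Hypothetical toy graph with undirected edges stored in both directions.
    graph = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1], 3: [1, 4], 4: [3]}
    rng = random.Random(42)
    print(pinsage_style_neighbors(graph, 0, num_walks=200, walk_length=3,
                                  restart_prob=0.5, top_k=2, rng=rng))
```

Compared with uniform sampling, visit counts let nodes that are reachable through many short paths rank higher, even if they are not direct neighbours, and the visit counts could also serve as weights during aggregation.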
This is essentially how we can train on large graphs with the new version of DGL.
▼【Hands-on guidance to DGL library】Series buildersbox.corp-sansan.com