Sansan Tech Blog


【Hands-on guidance to DGL library】(5) Training on large graphs


Hi, I am XING LI, a researcher from Sansan DSOC Data Analysis Group.

This is the Day 9 article of the Sansan Advent Calendar 2020.

Last time, we talked about some common tasks in deep graph learning and built a toy network for the node classification task as a demo. We have seen how to build, train and test a simple graph neural network (GNN) with DGL. However, the graph dataset we used, the Cora dataset, only has 2708 nodes and 5429 edges, so the GNN's aggregation operation does not take long on such a tiny graph. In reality, we also face much larger graphs, for example with millions of nodes and billions of edges. Consider an L-layer GCN with hidden state size H, trained on a graph with N nodes. Storing the intermediate hidden states requires O(NLH) memory, which easily exceeds a single GPU's capacity when N is large. So today we are going to talk about how to train a deep GNN on large graphs with the help of DGL.
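To get a feel for the O(NLH) term, here is a back-of-the-envelope sketch. It assumes float32 hidden states (4 bytes each) and counts only the stored hidden states, ignoring model weights, gradients and optimizer state. Apart from Cora's node count, the sizes (layers, hidden size, the 100-million-node graph) are hypothetical numbers picked for illustration.

```python
def hidden_state_bytes(num_nodes: int, num_layers: int, hidden_size: int,
                       bytes_per_value: int = 4) -> int:
    """Memory for keeping every node's hidden state at every layer: N * L * H values.

    Assumes float32 (4 bytes per value) by default; weights, gradients
    and optimizer state are deliberately left out.
    """
    return num_nodes * num_layers * hidden_size * bytes_per_value


# Hypothetical configurations, for illustration only.
cora = hidden_state_bytes(num_nodes=2708, num_layers=2, hidden_size=16)
large = hidden_state_bytes(num_nodes=100_000_000, num_layers=3, hidden_size=256)

print(f"Cora-sized graph: {cora / 1024**2:.2f} MiB")
print(f"100M-node graph:  {large / 1024**3:.1f} GiB")
```

Under these assumptions the Cora-sized case needs about 0.33 MiB, while the hypothetical 100-million-node graph with L = 3 and H = 256 needs roughly 286 GiB for the hidden states alone, far more than one GPU holds.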

The answer is simple: borrow the method widely used in CV and NLP to deal with large datasets, namely stochastic mini-batch training. This is why we talked about the NodeFlow data structure in Part 2 of this series.
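In the graph setting, the "samples" we batch over are the target nodes whose outputs we compute a loss on. A minimal sketch in plain Python, assuming training node IDs are plain integers (the function name seed_minibatches is hypothetical, not a DGL API):

```python
import random
from typing import Iterator, List


def seed_minibatches(train_nids: List[int], batch_size: int,
                     shuffle: bool = True, rng_seed: int = 0) -> Iterator[List[int]]:
    """Yield mini-batches of target (seed) node IDs, one batch per gradient step."""
    nids = list(train_nids)
    if shuffle:
        # Reshuffle the seed nodes so each epoch visits them in a random order.
        random.Random(rng_seed).shuffle(nids)
    for start in range(0, len(nids), batch_size):
        yield nids[start:start + batch_size]


# 10 training nodes split into batches of 4 gives batch sizes 4, 4 and 2.
batches = list(seed_minibatches(list(range(10)), batch_size=4))
print([len(b) for b in batches])
```

Unlike images or sentences, these seed nodes are not independent samples: each node's output depends on its neighbours, which is exactly what neighbourhood sampling has to handle.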

Recall the neighbourhood sampling approaches

To understand how we generate a so-called "mini-batch" subgraph from one large graph, let's recall two figures from Part 2 of this series:

[Figures: neighbourhood sampling illustrations from Part 2]

Neighbourhood sampling methods generally work as shown in the two figures above. In each gradient descent step, we first select a mini-batch of nodes whose final representations at layer L are to be computed. We then take all or some of their neighbours at layer L-1, depending on the sampling policy. This process continues until we reach the input layer. This iterative process builds the dependency graph starting from the output and working backwards to the input. As a result, we only compute representations for the nodes the mini-batch actually depends on, which saves memory and computation when training a GNN on large graphs.
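The backwards procedure can be sketched in plain Python. This is an illustrative re-implementation under simplifying assumptions (an adjacency dict of incoming neighbours, uniform sampling without replacement, a fixed fanout per layer), not DGL's code, and the function name is hypothetical. Roughly speaking, one iteration of the loop plays the role that dgl.sampling.sample_neighbors plays for a set of seed nodes.

```python
import random
from typing import Dict, List, Optional, Set, Tuple

# Adjacency dict: node -> list of nodes it receives messages from.
Graph = Dict[int, List[int]]
Layer = Tuple[Set[int], List[Tuple[int, int]]]  # (destination nodes, sampled (src, dst) edges)


def sample_dependency_layers(graph: Graph, seeds: List[int], fanouts: List[int],
                             rng: Optional[random.Random] = None) -> List[Layer]:
    """Build a sampled dependency graph from the output layer back to the input.

    fanouts[0] applies to the output layer L, fanouts[1] to layer L-1, and so on.
    The result is ordered input side first, i.e. the order a GNN consumes it.
    """
    rng = rng or random.Random(0)
    layers: List[Layer] = []
    frontier = set(seeds)                  # nodes whose representation we need
    for fanout in fanouts:                 # walk backwards: layer L, L-1, ...
        edges: List[Tuple[int, int]] = []
        next_frontier = set(frontier)      # keep each node's own previous state
        for dst in sorted(frontier):
            nbrs = graph.get(dst, [])
            # Take all neighbours if there are few, otherwise sample without replacement.
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            for src in picked:
                edges.append((src, dst))
                next_frontier.add(src)
        layers.append((frontier, edges))
        frontier = next_frontier
    layers.reverse()                       # input-side layer first
    return layers


# Hypothetical toy graph and a mini-batch containing a single seed node.
g = {0: [1, 2, 3], 1: [0, 4], 2: [0, 5], 3: [0], 4: [1], 5: [2]}
layers = sample_dependency_layers(g, seeds=[0], fanouts=[2, 2])
for depth, (dst_nodes, edges) in enumerate(layers):
    print(f"layer {depth}: dst={sorted(dst_nodes)} edges={edges}")
```

Each destination node is also kept as a source for the layer below it, since a GNN layer usually combines a node's own previous representation with its neighbours'. Choosing a fanout at least as large as the maximum degree recovers full-neighbourhood aggregation.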

In Part 2, I also included some sample code demonstrating the NeighborSampler function. However, since DGL 0.5.x, NeighborSampler has been deprecated, and DGL now offers a mature sampling package, dgl.sampling. As the old saying goes, "Don't reinvent the wheel!", so let's explore this new package.

The new package dgl.sampling

Here is the package's link.

"The dgl.sampling package contains operators and utilities for sampling from a graph via random walks, neighbor sampling, etc."
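The random walk idea is simple enough to sketch without DGL: starting from a node, repeatedly hop to a uniformly chosen neighbour. This sketch uses a hypothetical function name and an adjacency dict as input. It pads the trace with -1 once the walk reaches a node without neighbours, so every trace has the same length; treat that padding as a convention of this example.

```python
import random
from typing import Dict, List, Optional

Graph = Dict[int, List[int]]  # node -> list of out-neighbours


def random_walk(graph: Graph, start: int, length: int,
                rng: Optional[random.Random] = None) -> List[int]:
    """Uniform random walk of `length` steps starting at `start`.

    Returns length + 1 node IDs (the start node included). Once the walk
    reaches a node with no neighbours, the rest of the trace is padded with -1.
    """
    rng = rng or random.Random(0)
    trace = [start]
    current = start
    for _ in range(length):
        nbrs = graph.get(current, []) if current != -1 else []
        current = rng.choice(nbrs) if nbrs else -1   # -1 marks a stopped walk
        trace.append(current)
    return trace


g = {0: [1, 2], 1: [2], 2: [0], 3: []}
print(random_walk(g, start=0, length=5))
print(random_walk({0: [1], 1: []}, start=0, length=3))  # [0, 1, -1, -1]
```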

What surprised me is that the new package implements not only the basic random walk (dgl.sampling.random_walk) and neighbour sampling (dgl.sampling.sample_neighbors), but also PinSAGE-style neighbourhood sampling via dgl.sampling.PinSAGESampler!
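PinSAGE-style sampling uses random walks to decide which neighbours matter: run several short walks from a node, count how often each other node is visited, and keep the most-visited ones, whose counts can serve as importance weights. The sketch below is a simplified, homogeneous-graph version written for illustration. Its parameter names mirror those of dgl.sampling.PinSAGESampler (num_traversals, termination_prob, num_random_walks, num_neighbors), but DGL's sampler walks over a bipartite graph (for example item → user → item) and its exact behaviour may differ; the function name is hypothetical.

```python
import random
from collections import Counter
from typing import Dict, List, Optional, Tuple

Graph = Dict[int, List[int]]  # node -> list of neighbours


def pinsage_style_neighbors(graph: Graph, seed: int, num_traversals: int,
                            termination_prob: float, num_random_walks: int,
                            num_neighbors: int,
                            rng: Optional[random.Random] = None) -> List[Tuple[int, int]]:
    """Return up to `num_neighbors` (node, visit_count) pairs, most visited first.

    Each walk starts at `seed`, takes at most `num_traversals` steps and
    stops early with probability `termination_prob` after every step.
    Visits back to the seed itself are not counted.
    """
    rng = rng or random.Random(0)
    visits = Counter()
    for _ in range(num_random_walks):
        current = seed
        for _ in range(num_traversals):
            nbrs = graph.get(current, [])
            if not nbrs:
                break                      # dead end: this walk stops
            current = rng.choice(nbrs)
            if current != seed:
                visits[current] += 1
            if rng.random() < termination_prob:
                break                      # restart: begin the next walk
    return visits.most_common(num_neighbors)


g = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1]}
print(pinsage_style_neighbors(g, seed=0, num_traversals=3, termination_prob=0.5,
                              num_random_walks=100, num_neighbors=2))
```

Compared with uniform neighbour sampling, this picks neighbours by visit frequency, so a node can "borrow" important nodes that are several hops away, and the counts give a natural weighting for aggregation.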

This is basically how we can train on large graphs with the new version of DGL.

▼【Hands-on guidance to DGL library】Series

© Sansan, Inc.