Sansan Builders Blog

Sansanのものづくりを支えるメンバーの技術やデザイン、プロダクトマネジメントの情報を発信

【ML Tech RPT. 】第21回 構造に関連する機械学習を学ぶ (7) ~グラフ その3~

f:id:ssatsuki040508:20181210005017p:plain
DSOC研究員の吉村です. あけましておめでとうございます. 今年もブログをたくさん書いていけたらいいなと思います. 今年は 01/04 に初めて餅を焼いて食べました. やはり, 餅を食べると正月を感じられて最高ですね.

さて, 今回は GNN (Graph Neural Network) の話です. 一通りの手法についてまとめていきます.

GNN をなぜ使うのか

そもそも, なぜ GNN を使うのでしょうか. その理由は, 単純に性能が良いというのが一つと, GNN のそれぞれのモデルの定式化の意図が明確であるという意味において解釈性があるというのが一つです. また, グラフの構造であれば容易に応用しやすいというのも, 理由として考えられそうです.

GNN の手法

各種手法を説明する前に, 簡単に記号を整理しておきます. まず, 対象となるグラフを  G(E, V) とします. この時,  E は枝集合で,  V は頂点集合とします. また, ある枝  e \in E についての特徴量を  X_e, ある頂点  v \in V についての特徴量を  X_v で表すこととします. GNN については, 基本的な考え方として, 頂点についての埋め込み表現をまずは獲得することを目標とするので, ある頂点  v に対応する埋め込み表現を  h_v \in \mathcal{R}^d で表し, これを全てまとめた行列を  H \in \mathcal{R}^{|V|\times d} で表すこととします. 加えて, ある頂点  v から生えている枝集合を  co[v] とし,  v が接続している頂点集合を  ne[v] で表します.

(Vanilla) GNN

まずは, 最も基本となる GNN について考えていきます. おそらく初出となる GNN は 2009 年に発表されています[*1]. この中では, ある頂点  v の埋め込み表現  h_v を次のようにモデリングしています.

 h_v =\displaystyle{ \sum_{u \in ne[v]}f_{\bf{w}} (X_v, X_{(v, u)}, h_u, X_u), \ \ \ \ v \in V}

ただし,  f_{\bf{w}} は, パラメータ  \bf{w} による関数とします.

この式を見るとGNN では, ある頂点の埋め込み表現は, その頂点自身の特徴量, その頂点から生える枝の特徴量, そして, その頂点に接続する頂点の特徴量と埋め込み表現にのみ依存することを仮定しています. この関数  f_{\bf{w}} のパラメータ  \bf{w} は学習するパラメータで, 埋め込み表現はどちらかというと, 上記の式を全体として満たすような解として捉えることができます. そこで, この各  v の埋め込み表現  h_v を得るために, 元の(方程)式を下記のような更新式に捉え直した反復法により計算します.

 \displaystyle{h_v (t+1) = \sum_{u \in ne[v]}f_{\bf{w}} (X_v, X_{(v, u)}, h_u(t), X_u), \ \ \ \ v \in V}

ただし,  t は更新回数を表し, 反復は収束するまで (収束条件を満たすまで) 繰り返します.

GCN

GCN (Graph Convolutional Network) [*2] は, グラフ構造に信号処理などで用いられている畳み込み演算を適用する手法です.

まず, グラフの畳み込み演算を下記のように表します.

 \displaystyle{g_{\theta}\star x = Ug_{\theta}U^\top x,\ \ \ x \in \mathcal{R}^d}

ただし,  g_{\theta} = \text{diag}(\theta),  U は正規化グラフラプラシアン L = I_N - D^{-\frac{1}{2}}AD^{-\frac{1}{2}}=U\Lambda U^\topの固有ベクトルを並べた行列を表します.  A は隣接行列で,  D は次数行列です.  N は頂点の数,  I_N N \times N の単位行列を表しています. この畳み込み演算では, まず  U を求める必要がありますが, この計算に  O(N^2) の計算量がかかります. そこで, GCN では複数の近似手法と, renormalization trick を用いることで最終的に下記の簡単な行列演算のみの式を計算することになります.

 \displaystyle{H^{(l+1)} =\sigma(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)})}

ただし,  l は層の番号を,  \sigma は活性化関数を, そして,  W^{(l)} は学習パラメータを表します. また,  \tilde{D}_{ii}=\sum_{j}\tilde{A}_{ij},  \tilde{A} = A + I_N です. 何層であっても,  \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} の部分は, 最初に一度前処理として計算しておくだけで, 残りの部分はただの行列演算二回とすることができることになります.

GCN には更に様々なバリエーションがあり, [*3] の中では, Neural FPS, PATCHY-SAN, DCNN, DGCN, LGCN, MoNet, GraphSAGE が紹介されています.

GAT

GAT (Graph Attention Network) [*4]は, 各頂点の埋め込み表現を取得するときに, Attention 構造を利用する手法です.

GAT では, まず Attention coefficient を次にように定義します.

 \displaystyle{\alpha_{ij}=\frac{\exp(\text{LeakyReLU}(\textbf{a}^\top[\textbf{W}h_i||\textbf{W}h_j]))}{\sum_{k\in ne[i]} \exp(\text{LeakyReLU}(\textbf{a}^\top[\textbf{W}h_i||\textbf{W}h_k]))}}

ただし,  \textbf{W} \textbf{a} は学習する重みパラメータを表し,  || は concatenate 演算 を表します. この coefficient を用いて, 最終的に埋め込み表現を次の形で得ることができます.

 \displaystyle{h_i'=\sigma(\sum_{j \in ne[i]} \alpha_{ij} \textbf{W}h_j)}

ここからさらに, self-attention の学習を安定化させるために, Attention 構造を増やした multi-head の場合も元の GAT の論文では提案されており, その場合には, いくつかの Attention 構造で得られた結果は最終的に, concatenate するか平均をとるかを行い, 埋め込みを得ています. GAT は, 実験的に GCN よりも高い性能が得られたことも示されています.

まとめ

今回はグラフに対する Deep Learning 手法である GNN について, 基本的なものをいくつか紹介しました. それぞれが, どう違っているのかがざっくりとわかってもらえていれば幸いです. 長くなりすぎるので名前だけ紹介した GCN から少し変化した手法や, そのほかにも書ききれなかったたくさんの手法が次々と出ており, これらは大半が実装されて公開されているので実際に試してみると, 構造がどのように結果に影響してくるかがわかって面白いかもしれません. 前回の最後で書いていたもののうち, 今回書ききれなかった部分は次回に回すことにします. また次もご期待ください.




▼【ML Tech RPT. 】シリーズ
buildersbox.corp-sansan.com

*1:F. Scarselli, AC. Tsoi, and M. Hagenbuchner, "The graph neural network model," in TNNLS, 2019.

*2:TN. Kipf, and M. Welling, "Semi-Supervised Classification with Graph Convolutional Networks," in ICLR, 2017.

*3:Zhiyuan Liu, and Jie Zhou, "Introduction to Graph Neural Networks," in MORGAN & CLAYPOOL, 2020.

*4:P. Veličković, G. Cucurull, A. Casanova, A. Romero, P. Liò, and Y. Bengio, "Graph Attention Networks," in ICLR, 2018.

© Sansan, Inc.