こんにちは、DSOC R&Dグループ研究員の 奥田 です。最近はYouTubeでコーギーの動画ばかり見ているのですが、あの食パンみたいなお尻が最高です。
今回は大規模グラフに対するグラフ埋め込み(Graph Embedding)を計算するPytorch-BigGraphについて紹介いたします。また、記事の後半ではWikipediaの実データを対象に、約200万ノード1億エッジという大規模グラフに対するグラフ埋め込みの計算や類似記事検索の結果などをご報告できればと思います。
概要
グラフ埋め込み
グラフ埋め込みとは、ノードとエッジから構成されたグラフ構造から、ノードの埋め込み表現を得るための手法やその表現自体のことを指します。直感的には、自然言語処理における単語埋め込み(Word Embedding)のグラフ版だと考えると理解しやすいかもしれません。
単語埋め込みにおいては、ある単語の意味は周囲に出現する単語の関係性から決定されるという分布仮説のもと、Skip-gramやCBOWといった単語系列の文脈を利用した方法で計算されます。”Man”や”Woman”といった単語が同じ文脈で使われやすいということは、意味的にも似ているといった考え方ですね。
単語埋め込みと同様に、グラフ埋め込みでは単語の系列の代わりにグラフ構造の繋がりを利用し、周囲のノードから注目するノードの性質を特徴付けます。同じような繋がりの傾向を持つノードは意味的にも近くなるという性質は、グラフ構造においても直感的に成り立ちそうです。このように似た概念をもつ単語埋め込みとグラフ埋め込みですが、言語は基本的に系列としては分岐せずに一直線なのに対して、グラフは複数の繋がりによって様々に分岐するのが大きな違いと言えます。
グラフ埋め込みによって得られた分散表現は、グラフ構造におけるリンク予測タスクに用いられるほか、他の機械学習タスクの特徴量として追加することによる精度向上等に利用できます。
Pytorch-BigGraphとは
Pytorch-BigGraphは、大規模グラフに対するグラフ埋め込みを計算するOSSライブラリです。Facebookにより開発されており、バックエンドにはPyTorchが使われています。
Pytorch-BigGraphの大きな特徴としては、大規模グラフを現実的な計算機リソースで扱うためのグラフを分割する方法や、単一計算機上でCPUのマルチスレッドを利用した並列計算、複数の計算機を組み合わせた分散実行などが挙げられます。計算はGPUではなくCPUで行いますので、CPUとメモリをもりもり積んだ環境でお試し下さい。
また、グラフ埋め込みの手法として知識グラフの埋め込みの手法を採用しているのも特筆すべき点でしょう。知識グラフの世界では、グラフ構造のノードとエッジを表す際に(subject, relation, object)
の三つ組(タプル)として表現します。例えばSansanというノードを対象にした場合は、(Sansan, 業種, IT)
と(Sansan, 取引形態, BtoB)
といったように、異なる関係で異なるノードとの繋がりが表現されます。これにより、より複雑で様々な概念が入り混じったグラフ構造を表現できます。そして、グラフ埋め込みの計算においてはタプルに対するスコア関数を定義し、学習データ内のタプルに関するスコアが小さくなるよう学習をします。 スコア関数の違いにより様々な手法が提案されているので、Pytorch-BigGraphではいくつかのスコア関数が実装されています(参考)。
Pytorch-BigGraphを使ってみる
それでは実際にPytorch-BigGraphを使ってみましょう。後ほど解説するWikipediaを対象にしたグラフ埋め込みの節で具体的なデータやソースコードを例示するので、ここでは大まかな流れを紹介します。
インストール
まずはpipでtorchbiggraphをインストールします。
$ pip install torchbiggraph
データの準備
Pytorch-BigGraphに必要なデータはグラフのノード-エッジの関係を表したタプルです。以下のようにタブ区切りのファイルとして用意しておきます。
subject_1 relation_1 object_1 subject_1 relation_2 object_2 subject_2 relation_1 object_3
データのインポート
それでは準備したデータからPytorch-BigGraphが扱うデータ構造に変換します。なお、ここからはコマンドラインのインターフェイスを利用していきますが、Pythonのコマンドを利用することも可能です。
$ torchbiggraph_import_from_tsv --lhs-col=0 --rel-col=1 --rhs-col=2 \ src/config/jawiki_split.py \ data/jawiki-20190901.tsv
コマンドには先ほど作成したタブ区切りのファイルを指定するほか、どのカラムが何に対応するのかを示すパラメータを指定します。lhs
,rhs
はそれぞれタプルの左側と右側(fromとto)、relは関係(エッジの種類)を表します。
また、Pytorch-BigGraphの実行全体で必要になるのが設定ファイルです。利用するアルゴリズムや各種パラメータを指定したり、中間ファイルの一時保存ディレクトリなどを設定します。詳しくはPytorch-BigGraphのexampleやこの後実験で使用するソースコードを参照下さい。
グラフ埋め込み計算
準備はすべて整ったので、あとは計算を回すだけです!以下のようにtorchbiggraph_train
を実行します。パラメータのedge_paths
は前のステップで自動で作成されるディレクトリです。
$ torchbiggraph_train src/config/jawiki_split.py \ -p edge_paths=data/jawiki-20190901_partitioned
あとは気長に学習が終わるのを待ちましょう。毎エポック終了時にその時点でのモデルファイルが出力されるので、途中でグラフ埋め込みのベクトルを利用することも可能です。
Wikipediaのリンク構造を対象にしたグラフ埋め込み
せっかくなので具体的なグラフ構造のデータを対象に、Pytorch-BigGraphによるグラフ埋め込みを作成してみましょう。
手軽に利用できる大規模グラフのデータとして、今回はWikipediaの記事のリンク構造を利用します。リンクというのは、例えばSansanの記事の文中には東証マザーズ(上場した株式市場名)のリンクが貼られているように、Wikipedia内の別の記事への参照リンクのことです。記事中にはそのようなリンクが多く付与されているため、それらを使うことで記事間の関係性を記述できます。ではリンク構造からどうやってグラフ構造に変換するかですが、参照リンクが貼った/貼られたの関係を
(Sansan, has_link_to, 東証マザーズ)
というタプルで表現します。これにより「Sansan」と「東証マザーズ」がノードで、その間を「has_link_to」という関係のエッジで結んでいることを表現します。これをすべての記事においてリストアップすることで、Wikipediaのリンクからグラフ構造を抽出します。
既存事例における計算リソースと計算時間
日本語のWikipedia記事を対象にした埋め込み表現は既にいくつかの事例が公開されていますので、先にそれらについて触れておきましょう。今回の記事ではあまり精度評価に踏み込まないので、ここでは手法や計算リソース/計算時間について紹介します。
1. node2vec
Wikipediaのグラフ構造を直接扱うという点で類似した手法として、このブログ記事ではnode2vecという手法で埋め込み表現を計算しています。
ブログ記事中にもあるように著者実装のnode2vecは大量にメモリを必要とするため、CPU96スレッド / メモリ624GBのn1-highmem-96 (GCP)を使って約10時間ほどかかったようです。
2. Wikipedia2vec
グラフ埋め込みとは少し文脈が異なりますが、Wikipediaの記事内の文章に含まれる単語やそのリンク構造を同一空間上に埋め込むWikipedia2vecという方法があります。
Wikipedia2vecの論文 では、CPU36スレッド / メモリ72GBのc5d.9xlarge (AWS)を使って 約4.6時間 (276min)かかったと記載されています。
データセット/実験条件
Wikipediaは記事のメタ情報やリンク構造などを様々な形式で配布しています。今回は2019年9月1日時点におけるWikipedia日本語記事のSQLダンプデータからデータセットを作成しました。
jawiki dump progress on 20190901
- 記事数(# of nodes):1,879,504
- リンク数(# of edges):95,699,595
また、Pytorch-BigGraphは以下の設定で実行しました。
- グラフ分割数:1
- 次元数:400
- 手法:ComplEx (
complex_diagonal
)
ソースコード
今回実験で使用したコードは下記レポジトリに置いてありますので参照下さい。
yagays/wikipedia_graph_embedding
結果1. 類似記事検索
それでは結果です。いくつかの記事に対して、グラフ埋め込みを元に近似最近傍探索によって類似する記事5件を出力してみます。
順位 | Query: 新宿駅 | Query: 乃木坂46 | Query: イギリス |
---|---|---|---|
1 | 渋谷駅 | 欅坂46 | ロンドン |
2 | 中央線快速 | AKB48 | 欧州連合 |
3 | 新宿 | Sexy_Zone | スリランカ |
4 | 駅ナンバリング | KARA | マンチェスター |
5 | 東京メトロ東西線 | 梅田彩佳 | ウェールズ |
いい感じに分散表現が計算できているのではないでしょうか?ここで「記事が類似している」とは、あくまでWikipediaのリンク関係のみに基づいたものなので、多少人間の直感と違う部分があるかもしれません。またリンク数/被リンク数の少ない記事は、やはり精度が悪くなってしまう問題もあります。
順位 | Query: ガリレオ_(探査機) | Query: ガリレオ_(テレビドラマ) |
---|---|---|
1 | パイオニア計画 | 東京DOGS |
2 | サーベイヤー計画 | コンフィデンスマンJP |
3 | ジェミニ計画 | フジテレビ火曜9時枠の連続ドラマ |
4 | 小惑星の一覧_(134001-135000) | はじめの一歩 |
5 | 木星への天体衝突 | 金曜ナイトドラマ |
また、Wikipediaのリンク構造を使うことで、語義曖昧性に対応できるというメリットもあります。Wikipedia内では事象ごとにユニークなIDが割り当てられているため、文章上では区別が付かない単語においても、同一意味空間の別の点としてマッピングできます。
ここでは「ガリレオ」という単語に対して、宇宙探索機のガリレオとテレビドラマのガリレオの2つの記事の類似記事を並べています。これを見ると確かに意味的な違いを区別できていそうです。これはシンプルなword2vecのような手法では対応できない問題だと言えます。
結果2. 計算時間
今回の計算にはCPU 8スレッド / メモリ30GBの計算機を利用しました。CPUは全スレッド利用し、メモリ的には10~11GB程度利用していたようです。そして計算時間はなんと6.9日 (epochあたり約3.3時間 × 50 epochs) !この原稿に間に合うか少し不安でしたが、なんとか最後まで回ってくれました。
node2vecやwikipedia2vecと比較するとメモリ消費を抑えられている一方で、計算時間はかなり必要としたことがわかります。
結果3. CPUスレッド数を変化されたときの処理時間
では、どれくらいの計算機環境があれば効率よくグラフ埋め込みが計算できるのでしょうか?今回は利用するCPUを8スレッドから64スレッドまで増やしていった時の計算時間の変化を測定してみました。
処理時間はtorchbiggraph_train
が出力するログから取得できる処理時間に対して、開始5エポックの平均を用いました。また本実験はCPU数をカスタマイズできるGCP環境で実施しました。
これを見ると、ある程度のスレッド数で頭打ちになっていることがわかります。この規模のデータであれば、16~24スレッド程度で問題ないでしょう。これなら6.9日→4.5日程度で学習が終わります。
Pytorch-BigGraphは大規模データを扱う都合上ファイルI/Oを頻繁に行う実装のようです。htop
等でモニタリングしているとカーネルの利用率が高いことから、処理時間そのものよりもI/Oやメモリアクセスにボトルネックがあるのかもしれせん。
まとめ
今回の記事では、Pytorch-BigGraphについて紹介しつつ、Wikipediaにおけるグラフ埋め込みを計算してその実力を測ってみました。
約200万ノード1億エッジという巨大なグラフ構造に対してもメモリ消費を抑えながら分散表現を計算できるのは、Pytorch-BigGraphの大きな長所だと思います。一方で計算時間の長さが気になるポイントでしょうか。流石に結果がわかるのが数日〜1週間後というのは、学習に時間がかかるディープラーニングの世界と比較しても長い方でしょう。試行錯誤しながら実験を繰り返すのにはあまりに時間がかかりすぎるため、学習の途中でモデルの学習具合を見るなど工夫が必要そうです。あとは複数の計算機による分散計算で向上する余地はありそうなものの、今回は環境準備が間に合わず手が出せなかったので、その実力が気になるところです。
最後に
ちなみに、Sansanでは今回扱ったWikipediaのデータ規模以上のスケールのグラフを対象にした研究開発も行っています。ミリオン/ビリオンスケールに及ぶ大規模グラフの深淵を覗きたい方は、ぜひ弊チームで一緒に働きましょう。お待ちしております。