Sansan Tech Blog

Sansanのものづくりを支えるメンバーの技術やデザイン、プロダクトマネジメントの情報を発信

Stein Variational Gradient Descentの理論と実装

今年の4月に新卒入社した、DSOC R&D Groupの橋本です。前回はこちらの記事で登場しました。

buildersbox.corp-sansan.com

今回は、通常の変分推論よりも高精度に事後分布を近似するアルゴリズムの1つである、Stein Variational Gradient Descentについて説明します。まず一般的な変分推論について説明します。

変分推論 (Variational Inference)

 変分推論は、真の事後分布に対しパラメトリックな確率分布を仮定して事後分布を近似する手法です。Jensenの不等式を用いることで、対数周辺尤度の下限である変分下限が導出できます。入力を {\bf x}、確率モデルのパラメータを {\bf \theta} とすると、

{ \displaystyle
\begin{eqnarray}
  \log p ( {\bf x} ) &=& \log \int p ( {\bf x} , {\bf \theta} ) d {\bf \theta} \tag {1} \\
  & = & \log \int \frac { q _ { \xi } ( {\bf \theta} | {\bf x} ) } { q _ { \xi } ( {\bf \theta} | {\bf x} ) } p ( {\bf x} , {\bf \theta} ) d {\bf \theta} \tag {2} \\
  &\ge& \int q _ { \xi } ( {\bf \theta} | {\bf x} ) \log \frac { p ( {\bf x} , {\bf \theta} ) } { q _ { \xi } ( {\bf \theta} | {\bf x} ) } d {\bf \theta}  \tag {3} \\ 
  &=& F \lbrack q _ { \xi } ( {\bf \theta} | {\bf x} ) \rbrack \tag {4}
\end{eqnarray}
}

として変分下限  F \lbrack q _ { \xi } ( {\bf \theta} | {\bf x} ) \rbrack が得られます。ここで  p ( {\bf x} , {\bf \theta} ) は確率モデルの同時分布、 q _ { \xi } ( {\bf \theta} | {\bf x} ) は変分パラメータ  \xi を有する近似事後分布を表します。さらに、対数周辺尤度   \log p ( {\bf x} ) と変分下限  q _ { \xi } ( {\bf \theta} | {\bf x} ) の差をとると、

{ \displaystyle
\begin{eqnarray}
  \log p ( {\bf x} ) - F \lbrack q _ { \xi } ( {\bf \theta} | {\bf x} ) \rbrack & = & \int  q _ { \xi } ( {\bf \theta} | {\bf x} ) \log p ( { \bf x} ) d {\bf \theta}  -  \int  q _ { \xi } ( {\bf \theta} | {\bf x} ) \log \frac { p ( {\bf x} , {\bf \theta} ) } { q _ { \xi } ( {\bf \theta} | {\bf x} ) } d {\bf \theta}  \tag {5} \\
 & = &  \int  q _ { \xi } ( {\bf \theta} | {\bf x} ) \log \frac { p ( { \bf x} ) q _ { \xi } ( {\bf \theta} | {\bf x} ) } { p ( {\bf \theta} | {\bf x} ) p ( { \bf x} ) } d {\bf \theta} \tag {6} \\
& = &  KL  \lbrack q _ { \xi } ( {\bf \theta} | {\bf x} ) || p ( {\bf \theta} | {\bf x} ) \rbrack \tag {7} 
\end{eqnarray}
}

となり、真の事後分布   p ( {\bf \theta} | {\bf x} ) と近似事後分布  q _ { \xi } ( {\bf \theta} | {\bf x} ) の近さを表す尺度であるKLダイバージェンス (Kullback-Leibler divergence) が得られます。したがって、対数周辺尤度は変分下限とKLダイバージェンスの和になることがわかります。ここで直接KLダイバージェンスを最小化すれば真の事後分布を最もよく近似する近似事後分布が得られるのですが、KLダイバージェンスにおいて真の事後分布は厳密な計算が不可能なので、直接最小化することは非常に困難です。一方で変分下限は計算が容易な確率モデルの同時分布と近似事後分布から導出できます。さらに、対数周辺尤度は  { \bf x } のみに依存する量であるため、 { \bf \theta } および  \xi を動かしてもKLダイバージェンスと変分下限の和は常に一定です。以上から、最もよい近似事後分布を得るには、KLダイバージェンスを最小化する代わりに変分下限を最大化すれば良いことがわかります。しかし、近似事後分布にはさらに因子分解可能であるという仮定 (平均場近似など)を課す必要があるため、得られる近似事後分布の表現力は必然的に大きく制限されることになります。

Stein Variational Gradient Descent (SVGD)

概要

 SVGDは、再生核ヒルベルト空間上における汎関数微分による勾配降下によって、近似事後分布と真の事後分布とのKLダイバージェンスを最小化するアルゴリズムです。*1 初めに紹介した通常の変分推論と比較すると、SVGDは真の事後分布に対してパラメトリックな確率分布や因子分解を仮定する必要がないため、近似事後分布の表現力が大きく向上しています。 本手法をはじめとする、サンプル集合をKLダイバージェンスを最小化するように更新する変分推論は Particle-based Variational Inference と呼ばれ、近年その近似精度の高さや最適輸送との繋がりが注目されています。*2

Kernelized Stein Discrepancy

まず、再生核ヒルベルト空間上での汎関数微分に基づく勾配降下とKLダイバージェンス最小化をつなぐ、Stein DiscrepancyおよびKernelized Stein Discrepancyについて説明します。

 x \chi \in \mathbb { R } ^ {d} における連続確率変数とすると、連続かつ微分可能な分布 p(x) について、以下の式が成り立つことが知られています。

{ \displaystyle
{ \mathbb{E}_{x \sim p} \lbrack \mathcal{A}_{p} \phi (x) \rbrack } = 0 \tag {8}
}

ここで  \mathcal{A} _ {p} \phi (x) = \phi (x) \nabla _ {x} \log p(x) + \nabla _ {x} \phi(x) であり、 \mathbb {E} _ { x \sim p} \lbrack \mathcal { A} _ {p} \phi (x) \rbrack をStein identityと言います。一方で p(x) と異なる連続かつ微分可能な分布 q(x) を考えると、p \neq q の場合には分布  q (x) で期待値をとってもゼロにはなりません。しかし、分布 p(x) との不一致の尺度として考えることができ、これをStein Discrepancyと呼びます。Stein Discrepancyは

{ \displaystyle
\mathbb{S} (q, p) = \max_{\phi \in \mathcal{F} } { \mathbb{E}_{x \sim q} \lbrack \mathcal{A}_{p} \phi (x) \rbrack } \tag{9}
}

として定義されます。ここで、  \mathcal {F} は 有界なLipschitz  normをもつ関数の集合です。Eq. (9) は \mathcal{F} において汎関数  \mathbb {E}_ { x \sim q } \lbrack \mathcal{ A} _ { p } \phi (x) \rbrack が最大となるように  \phi を選ぶことを意味しているのですが、 \mathbb{S} (q, p) 自体は小さいほうが近似事後分布  q (x)  と真の事後分布  p (x) が近いためよい、ということに注意してください。Eq. (9) の変分問題を解くにあたり、滑らかな関数  \phi (x) をどのように選ぶかが、 \mathbb{S} (q, p) やその計算コストに大きく影響します。しかし、  \mathcal {F} を十分に広くとるには無限個の基底関数を使用する必要があり、Stein discrepancyの変分問題を解くのは計算コスト上極めて非効率であることが知られています。*3
 そこで、再生核ヒルベルト空間  \mathcal{H}^{d} において変分問題を解くことを考えます。 \mathcal{H}^{d} 上における Stein Discrepancyである Kernelized Stein Discrepancy (KSD) は次の式で定義されます。*4

{ \displaystyle
\mathbb{S} (q, p) = \max_{\phi \in \mathcal{H}^{d} } { \mathbb{E}_{x \sim q} \lbrack \mathcal{A}_{p} \phi (x) \rbrack } \tag{10}
}

この場合変分問題の最適解を与える \mathcal{H}^{d} 上の関数  \phi ^{*} (x') が得られ、

{ \displaystyle
\phi ^{*} (x') =  \mathbb {E} _ { x \sim q} \lbrack k( x , x' ) \nabla _ { x } \log p( x' ) + \nabla _ { x } k( x , x' ) \rbrack \tag{11}
}

となります。 k にはRBF (radial basis function) カーネルのように、滑らかかつ正定値性を満たすカーネル関数が用いられます。

Kernelized Stein DiscrepancyとKLダイバージェンス

 先ほど示したKSDと、変分推論において最小化されるKLダイバージェンスにどのような関係があるのでしょうか。微小数を\epsilon とし、 恒等写像に摂動を与えた変換  {\bf T} (x) = x + \epsilon \phi (x) を考えると、次の定理が成立します。

Theorem 3.1. [Q. Liu et.al., 2017]

 x \sim q (x)  z = { \bf T } (x) である q _ { \lbrack { \bf T} \rbrack } (z) において

{ \displaystyle
\nabla _ { \epsilon } KL  \lbrack q _ { {  \lbrack \bf T \rbrack } } || p \rbrack \mathrel{\bigg|} _ { \epsilon = 0 } = - \mathbb{ E } _ {x \sim q} \lbrack \mathcal{A}_{p} \phi (x) \rbrack \tag {12}
}

が成り立つ。

証明は [Q. Liu et.al., 2017] のAppendixに書かれていますので、そちらを御覧ください。

 Eq. (12)において、右辺はEq.(9) で最大化される量と同じであることがわかります。右辺を最大化する場合、先程と同様に再生核ヒルベルト空間を考えるとKSDになることがわかります。つまり、近似事後分布と真の事後分布のKLダイバージェンスを最も大きく減少させるようにサンプルの位置が更新され、その減少幅はKSDと一致するようになります。*5

SVGDのアルゴリズム

 アルゴリズムとしては非常に簡単です。初期分布からサンプルを取得し、  x _ { i } ^ { l + 1 } \leftarrow x _ { i } ^ { l } + \epsilon \hat { \phi ^ { * } }  (x _ { i } ^ { l }) に従ってサンプルの位置を更新するだけです。ただ、  \hat { \phi ^ {*} }  (x _ {i} ^ {l}) は実際にはイテレーション  l 時点でのサンプルを用いた近似値

{ \displaystyle
\phi ^{*} (x) =  \frac {1} {n} \sum _ {j = 1} ^ {n} \lbrack k( x _ {j} ^ {l} , x) \nabla _ { x _ {i} ^ {l}  } \log p(x _ {j} ^ {l} ) + \nabla _ { x _ {i} ^ {l} } k(x _ {i} ^ {l} , x) \rbrack \tag{13}
}

を使います。

Algorithm 1: Stein Variational Gradient Descent [Q. Liu et.al., 2017]

入力: イテレーション数  L、初期分布  q_{0} (x) からの n 個のサンプル集合  \{ x _ {i} ^ {0} \} _ {i=1} ^ {n}
出力: 目的の分布を近似するサンプル集合  \{ x _ {i} \} _ {i=1} ^ {n}

  1.  {\bf for} \, \, l \leftarrow 1 \, \, to \, \, L \, \, {\bf do}
  2.   \,  \, \, \, x _ {i} ^ {l + 1} \leftarrow x _ {i} ^ {l} + \epsilon \hat { \phi ^ {*} }  (x _ {i} ^ {l}) where  \hat { \phi ^ { * } } (x)= \frac {1} {n} \sum _ {j = 1} ^ {n} \lbrack k( x _ {j} ^ {l} , x ) \nabla _ { x _ {i} ^ {l} } \log p(x _ {j} ^ {l} ) + \nabla _ { x _ {i} ^ {l} } k(x _ {i} ^ {l} , x) \rbrack
  3.  {\bf end \, for}

SVGDの実装と検証

 簡単な例ではありますが、実際にSVGDを用いて初期分布からのサンプルを混合ガウス分布に近づけてみます。初期分布からのサンプルが実際に混合ガウス分布を構成するように動き、真の分布に近づくにつれてサンプルの位置の変化が小さくなっていくのがわかりますね。

f:id:w-hash52:20190908232841g:plain
2次元の等方ガウス分布からのサンプルを混合ガウス分布に近づけています

SVGDおよびそのアニメーションを生成するコードはこちらのリポジトリに置いていますので、興味がある方は御覧ください。

github.com

最後に

 今回はSVGDという、再生核ヒルベルト空間上における汎関数微分に基づく勾配降下を用いた変分推論のアルゴリズムを紹介しました。ベイズ推論では一般的にMarkov Chain Monte Carlo (MCMC) と変分推論が使われるのですが、一般的にMCMCは理論的に真の事後分布が得られる一方で推論が遅く、一方で変分推論は近似精度はMCMCと比べると劣るのですが、高速に近似事後分布が得られるという特徴があります。実務の観点から言うと、近似精度が高くかつ高速に確率分布を推論する手法に興味があるため、MCMCをより高速にしたり変分推論の近似精度を向上させる試みについて、もっと調べてみたいですね。

1: Q. Liu, and D. Wang, "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm", in NIPS, 2017. Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm

2: C. Liu, J. Zhuo, P. Cheng, R. Zhang, and J. Zhu, "Understanding and Accelerating Particle-Based Variational Inference", in ICML, 2019. Understanding and Accelerating Particle-Based Variational Inference

3: J. Gorham, and L. Mackey, "Measuring Sample Quality with Stein's Method", in NIPS, 2015. Measuring Sample Quality with Stein's Method

4: Q. Liu, J. D. Lee, and M. I. Jordan,"A Kernelized Stein Discrepancy for Goodness-of-fit Tests and Model Evaluation", in ICML, 2016.

5: こちらの短いノートのFigure 2に、KSDと他のダイバージェンスとの関連がまとめられています。http://www.cs.utexas.edu/~lqiang/PDF/ksd_short.pdf

© Sansan, Inc.