WGANの論文読んでTensorflowで実装する その2

WGANの論文読んでTensorflowで実装する その1 - 時給600円の続き

前回はEarth Mover DistanceもしくはWasserstein Distanceが他のJSダイバージェンスやTV距離と比べて優れてるというのをまとめた。

このEM距離をGANの目的関数として使いたいが、

{ \displaystyle W(\mathbb{P}_r\ ,\ \mathbb{P}_g) = \inf_{\gamma \in \Pi (\mathbb{P}_r\ ,\ \mathbb{P}_g)} \mathbb{E}_{(x,y) \sim \gamma} [ || x\ - \ y || ] }

このままでは使うことができないと書いてある。そもそも同時分布の集合を求めるのも大変だし、KLダイバージェンスと違って積分が閉じてないとかなんとか。

駄目じゃん。ってなるがその次のセクションに双対性というので以下の式を計算することでEM距離の値を求めることができると書いてある。双対問題はまだよくわかってない(´・ω・)
なんというかAの世界では掛け算で解くけど、Bの世界では足し算で解けるみたいな物事を別の世界で考えて計算を楽にする感じなんだろうか

{ \displaystyle W(\mathbb{P}_r\ ,\ \mathbb{P}_\theta) = \sup_{|| f ||_L \ \leq 1} \mathbb{E}_{x \sim \mathbb{P}_r} [ f(x) ] -  \mathbb{E}_{x \sim \mathbb{P}_\theta}[ f(x) ] \ \ \ (2) }

{ \displaystyle || f ||_L }という謎のものがある。まず関数{ \displaystyle f }リプシッツ関数( 1 - Lipshitz )であることが条件らしい。リプシッツ関数って何だよって思ったけど

f:id:Owatank:20180627153242p:plain

ある関数 { \displaystyle f }が存在して、その曲線が上の赤の三角形のように表せるならリプシッツ関数といえる認識でいいっぽい。間違ってるかもしれない。
要は増加量が線形というか緩やかなものがリプシッツ関数といえるのかな。そうなると{ \displaystyle x^{3} }とかはリプシッツ関数とはいえないはず(/・ω・)/
論文の端っこに、sigmoidtanhなどはリプシッツ関数の例であると説明があった。そういう認識で問題なさそう

問題は{ \displaystyle || f ||_L \leq 1}{ \displaystyle \leq 1}の部分。期待値の差分を取るから値はスカラーなはずで、その差が1以下のものって条件かな。ぬん・・・

GANなので識別器と生成器の2つがある。この式に2つをあてはめて目的関数として実用するには

{ \displaystyle \max_{w \in \mathcal{W} }  \mathbb{E}_{x \sim \mathbb{P}_r} [ f_w(x) ] -  \mathbb{E}_{z \sim p(z)}[ f_w(g_\theta(z) ) ]  }

とすればいいらしい。maxだから最大値を求めるのか。識別器はDiscriminatorだからよく d(x) みたいに表されるけど、なんで{ \displaystyle f_w(x)  }と書かれているのかというのは後にわかるから置いておいて、この式の条件下で最大値を取るとき、その時のパラメータ { \displaystyle \theta } で生成器{ \displaystyle  g_\theta(z) }は本物に近い生成データを得られる、つまり欲しい真の分布{ \displaystyle P_r }に近くなってるはず。前にEM距離は連続であると書かれていたから、微分ができて徐々に{ \displaystyle P_r }に近くなるように学習できるぜってことなんだな。すごい・・・すごくない・・・?

と思ったら次にこんな一文がある

Now comes the question of finding the function f that solves the maximization problem in equation (2).

うん?関数 { \displaystyle f} を見つける?
関数 { \displaystyle f} を表現するために使われる、集合 { \displaystyle \mathcal{W} }の中から取れるパラメータ { \displaystyle w } を先に最適化して式(2)が最大になるような関数 { \displaystyle f} もといパラメータ { \displaystyle w } 先に見つけないと駄目なのかな。
だからあくまでこの関数 { \displaystyle f}識別の役割とはいえないから Discriminator ではなく Critic という名前でこの論文では呼ばれている。

で、パラメータの更新によっては { \displaystyle w } の値が集合 { \displaystyle \mathcal{W} } にはない値を取る場合もある。これでは式(2)の条件を満たせないので、無理やり値を抑える(クリッピングと呼ばれてる)。式(2)の条件というより値がはみ出るとリプシッツ関数として成り立たないからとかだろうか。

値を抑える範囲としては論文では [ -0.01 , 0.01 ]が採用されている。この範囲が大きすぎると最適なパラメータ { \displaystyle w } を見つけるのに時間が掛かって、小さすぎると今度は勾配消失の問題が起きやすいと書かれている。難しい・・・

まとめると、先に Critic { \displaystyle f_w }のパラメータ { \displaystyle w } を式(2)

{ \displaystyle W(\mathbb{P}_r\ ,\ \mathbb{P}_\theta) = \sup_{|| f ||_L \ \leq 1} \mathbb{E}_{x \sim \mathbb{P}_r} [ f(x) ] -  \mathbb{E}_{x \sim \mathbb{P}_\theta}[ f(x) ] \ \ \ (2) }

が最大になるように訓練して、(パラメータの値がある範囲を超えたら抑える)

何回か訓練させた後、つまりEM距離になってくれているCritic { \displaystyle f_w }を使って、生成器 { \displaystyle  g_\theta(z) } のパラメータ { \displaystyle \theta } をより本物が作れるような方向に、つまり分布{ \displaystyle P_r }に近づくように更新していく。

生成器のパラメータの更新については数式ではこう表現されている。

{ \displaystyle \nabla_\theta W(\mathbb{P}_r\ ,\ \mathbb{P}_\theta) = -  \mathbb{E}_{z \sim p(z)} [ \nabla_\theta f(g_\theta (z)) ] }

GANやDCGANでいえば critic を discriminatorと見立てたとき、{ \displaystyle f_w(x)} の返す値は 1(本物)と見れるので、上記の式でも { \displaystyle f(g_\theta (z)) }が 1(本物)の値を返すように学習するという見方で大丈夫だろうか。

学習の方法の流れが優しく書いてあったので、WGANを実装してみる。UnrolledGANのときと同じくガウス分布のデータを使って実験する。

1回の学習ステップの流れとしてはこうする

f:id:Owatank:20180627170709j:plain

先に k回 critic を更新してから、k回更新したcritic、もといEM距離を使って1回だけ生成器を更新する。なんかUnrolledGANと似てるな。

前回コード載せたけどここにも置く

github.com

論文で再三リプシッツ関数だからなと述べているのに、出力にtanh関数やsigmoid関数を使ってしまって、最初学習がうまくいかなかった。
出力を critic、generatorどちらとも

### Generator 
#fc = tf.tanh(tf.nn.xw_plus_b(h2, self.gen_w3, self.gen_b3))
fc = tf.nn.xw_plus_b(h2, self.gen_w3, self.gen_b3)

### Critic
#fc = tf.nn.sigmoid(tf.nn.xw_plus_b(h2, self.cri_w3, self.cri_b3))
fc = tf.nn.xw_plus_b(h2, self.cri_w3, self.cri_b3)

に直したらうまくいった。アレ・・・?ただの恒等関数ってリプシッツ関数でいいのかな?)`Д゚).・;'∴

UnrolledGANのときは、ガウス分布の入力の組が [ -1 , 1 ] の範囲の値をとるからGeneratorの出力関数をtanh関数に設定したけど、そうしないでただの恒等関数でもちゃんと [ -1 , 1 ] の値を取ってきてくれるのだろうか。不思議だなあ・・・

あとはWGANでめんどいクリッピングの操作はゴリ押ししか今のとこできなかったので次のようにオペレーションを定義した

clip_value = 0.01

clip_list = [
    self.critic.cri_w1.assign(
        tf.clip_by_value(self.critic.cri_w1, -clip_value,
                         clip_value)),
    self.critic.cri_w2.assign(
        tf.clip_by_value(self.critic.cri_w2, -clip_value,
                         clip_value)),
    self.critic.cri_w3.assign(
        tf.clip_by_value(self.critic.cri_w3, -clip_value,
                         clip_value)),
    self.critic.cri_b1.assign(
        tf.clip_by_value(self.critic.cri_b1, -clip_value,
                         clip_value)),
    self.critic.cri_b2.assign(
        tf.clip_by_value(self.critic.cri_b2, -clip_value,
                         clip_value)),
    self.critic.cri_b3.assign(
        tf.clip_by_value(self.critic.cri_b3, -clip_value,
                         clip_value))
]

self.clip_op = tf.group(*clip_list)

今回はtf.group()というのを使ってみた。これを使うと訓練時にいちいち各パラメータのクリッピングのオペレーションを呼ばずに、

# Train Critic
# Critic step for Critic
for k in range(critic_step):
    # ノイズ事前分布からノイズをミニバッチ分取得
    noise_z = np.random.uniform(
        -1, 1, size=[batch_size, 100]).astype(np.float32)
    # 訓練データのミニバッチ取得
    cri_perm = np.random.permutation(datanum)
    X_batch = X_train[cri_perm][:batch_size]

    sess.run(
        self.opt_cri,
        feed_dict={
            self.input_X: X_batch,
            self.is_train: False,
            self.gen_z: noise_z
        })

    # Clip Critic Parameter
    sess.run(self.clip_op)

一行でクリッピングの操作を済ませることができる。かしこい

論文通りcritic_stepの回数を5回に設定して実験する。最適化の手法としてAdamのようなmomentum based optimizerはWGANの訓練を不安定にさせるからRMSPropのがいいよと書いてあったけど、なんかこのガウス分布のデータにおける実験ではAdamのがよかったからこっちを採用した。RMSPropはよく知らなかったけど、

We therefore switched to RMSProp [21] which is known to perform well even on very nonstationary problems [13].

と書いてあって、RMSPropおもろいなと思った。いい情報だ(∩ ^ω^ ∩)

実験結果としてはこんな感じ

f:id:Owatank:20180627172915p:plain

マジでただの恒等関数の出力で [ -1 , 1 ]の範囲の値を取ってやがる・・・。ほんまか・・・

前のUnrolledGANの時の結果は次の通りだった
f:id:Owatank:20180601115051p:plain

うーん・・・。UnrolledGANのが綺麗に生成できている気がする。同じネットワーク構造、ハイパーパラメータじゃないから比較は難しいんだけども
でも大体入力データを真似てくれているから実装としてはあまり間違えはなさそう

ポアンカレ埋め込みの論文と同じで距離についてワクワクさせてくれる素敵な論文だった。Appendixのところとかまだ完璧に理解できないのでまた成長したら読みたいな。
どっちも同じFacebookのリサーチャーの人が関わっているらしくてFacebookすげえ・・・

おまけでこの論文では feedforward neural network のことを

By a feedforward neural network we mean a function composed by affine transformations and pointwise nonlinearities which are smooth Lipschitz functions (such as the sigmoid, tanh, elu, softplus, etc).

と下に小さい箇所で述べていた。なんかめっちゃカッコいい・・・?カッコよくない・・・?