WGANの論文読んでTensorflowで実装する その2
WGANの論文読んでTensorflowで実装する その1 - 時給600円の続き
前回はEarth Mover Distance
もしくはWasserstein Distance
が他のJSダイバージェンスやTV距離と比べて優れてるというのをまとめた。
このEM距離をGANの目的関数として使いたいが、
このままでは使うことができないと書いてある。そもそも同時分布の集合を求めるのも大変だし、KLダイバージェンスと違って積分が閉じてないとかなんとか。
駄目じゃん。ってなるがその次のセクションに双対性というので以下の式を計算することでEM距離の値を求めることができると書いてある。双対問題はまだよくわかってない(´・ω・)
なんというかAの世界では掛け算で解くけど、Bの世界では足し算で解けるみたいな物事を別の世界で考えて計算を楽にする感じなんだろうか
という謎のものがある。まず関数がリプシッツ関数( 1 - Lipshitz )であることが条件らしい。リプシッツ関数って何だよって思ったけど
ある関数 が存在して、その曲線が上の赤の三角形のように表せるならリプシッツ関数といえる認識でいいっぽい。間違ってるかもしれない。
要は増加量が線形というか緩やかなものがリプシッツ関数といえるのかな。そうなるととかはリプシッツ関数とはいえないはず(/・ω・)/
論文の端っこに、sigmoid
やtanh
などはリプシッツ関数の例であると説明があった。そういう認識で問題なさそう
問題はのの部分。期待値の差分を取るから値はスカラーなはずで、その差が1以下のものって条件かな。ぬん・・・
GANなので識別器と生成器の2つがある。この式に2つをあてはめて目的関数として実用するには
とすればいいらしい。maxだから最大値を求めるのか。識別器はDiscriminator
だからよく d(x) みたいに表されるけど、なんでと書かれているのかというのは後にわかるから置いておいて、この式の条件下で最大値を取るとき、その時のパラメータ で生成器は本物に近い生成データを得られる、つまり欲しい真の分布に近くなってるはず。前にEM距離は連続であると書かれていたから、微分ができて徐々にに近くなるように学習できるぜってことなんだな。すごい・・・すごくない・・・?
と思ったら次にこんな一文がある
Now comes the question of finding the function f that solves the maximization problem in equation (2).
うん?関数 を見つける?
関数 を表現するために使われる、集合 の中から取れるパラメータ を先に最適化して式(2)が最大になるような関数 もといパラメータ 先に見つけないと駄目なのかな。
だからあくまでこの関数 は識別の役割とはいえないから Discriminator ではなく Critic という名前でこの論文では呼ばれている。
で、パラメータの更新によっては の値が集合 にはない値を取る場合もある。これでは式(2)の条件を満たせないので、無理やり値を抑える(クリッピングと呼ばれてる)。式(2)の条件というより値がはみ出るとリプシッツ関数として成り立たないからとかだろうか。
値を抑える範囲としては論文では [ -0.01 , 0.01 ]が採用されている。この範囲が大きすぎると最適なパラメータ を見つけるのに時間が掛かって、小さすぎると今度は勾配消失の問題が起きやすいと書かれている。難しい・・・
まとめると、先に Critic のパラメータ を式(2)
が最大になるように訓練して、(パラメータの値がある範囲を超えたら抑える)
何回か訓練させた後、つまりEM距離になってくれているCritic を使って、生成器 のパラメータ をより本物が作れるような方向に、つまり分布に近づくように更新していく。
生成器のパラメータの更新については数式ではこう表現されている。
GANやDCGANでいえば critic を discriminatorと見立てたとき、 の返す値は 1(本物)と見れるので、上記の式でも が 1(本物)の値を返すように学習するという見方で大丈夫だろうか。
学習の方法の流れが優しく書いてあったので、WGANを実装してみる。UnrolledGANのときと同じくガウス分布のデータを使って実験する。
1回の学習ステップの流れとしてはこうする
先に k回 critic を更新してから、k回更新したcritic、もといEM距離を使って1回だけ生成器を更新する。なんかUnrolledGANと似てるな。
前回コード載せたけどここにも置く
論文で再三リプシッツ関数だからなと述べているのに、出力にtanh
関数やsigmoid
関数を使ってしまって、最初学習がうまくいかなかった。
出力を critic、generatorどちらとも
### Generator #fc = tf.tanh(tf.nn.xw_plus_b(h2, self.gen_w3, self.gen_b3)) fc = tf.nn.xw_plus_b(h2, self.gen_w3, self.gen_b3) ### Critic #fc = tf.nn.sigmoid(tf.nn.xw_plus_b(h2, self.cri_w3, self.cri_b3)) fc = tf.nn.xw_plus_b(h2, self.cri_w3, self.cri_b3)
に直したらうまくいった。アレ・・・?ただの恒等関数ってリプシッツ関数でいいのかな?)`Д゚).・;'∴
UnrolledGANのときは、ガウス分布の入力の組が [ -1 , 1 ] の範囲の値をとるからGeneratorの出力関数をtanh
関数に設定したけど、そうしないでただの恒等関数でもちゃんと [ -1 , 1 ] の値を取ってきてくれるのだろうか。不思議だなあ・・・
あとはWGANでめんどいクリッピングの操作はゴリ押ししか今のとこできなかったので次のようにオペレーションを定義した
clip_value = 0.01
clip_list = [
self.critic.cri_w1.assign(
tf.clip_by_value(self.critic.cri_w1, -clip_value,
clip_value)),
self.critic.cri_w2.assign(
tf.clip_by_value(self.critic.cri_w2, -clip_value,
clip_value)),
self.critic.cri_w3.assign(
tf.clip_by_value(self.critic.cri_w3, -clip_value,
clip_value)),
self.critic.cri_b1.assign(
tf.clip_by_value(self.critic.cri_b1, -clip_value,
clip_value)),
self.critic.cri_b2.assign(
tf.clip_by_value(self.critic.cri_b2, -clip_value,
clip_value)),
self.critic.cri_b3.assign(
tf.clip_by_value(self.critic.cri_b3, -clip_value,
clip_value))
]
self.clip_op = tf.group(*clip_list)
今回はtf.group()
というのを使ってみた。これを使うと訓練時にいちいち各パラメータのクリッピングのオペレーションを呼ばずに、
# Train Critic # Critic step for Critic for k in range(critic_step): # ノイズ事前分布からノイズをミニバッチ分取得 noise_z = np.random.uniform( -1, 1, size=[batch_size, 100]).astype(np.float32) # 訓練データのミニバッチ取得 cri_perm = np.random.permutation(datanum) X_batch = X_train[cri_perm][:batch_size] sess.run( self.opt_cri, feed_dict={ self.input_X: X_batch, self.is_train: False, self.gen_z: noise_z }) # Clip Critic Parameter sess.run(self.clip_op)
一行でクリッピングの操作を済ませることができる。かしこい
論文通りcritic_step
の回数を5回に設定して実験する。最適化の手法としてAdam
のようなmomentum based optimizer
はWGANの訓練を不安定にさせるからRMSProp
のがいいよと書いてあったけど、なんかこのガウス分布のデータにおける実験ではAdam
のがよかったからこっちを採用した。RMSPropはよく知らなかったけど、
We therefore switched to RMSProp [21] which is known to perform well even on very nonstationary problems [13].
と書いてあって、RMSPropおもろいなと思った。いい情報だ(∩ ^ω^ ∩)
実験結果としてはこんな感じ
マジでただの恒等関数の出力で [ -1 , 1 ]の範囲の値を取ってやがる・・・。ほんまか・・・
前のUnrolledGANの時の結果は次の通りだった
うーん・・・。UnrolledGANのが綺麗に生成できている気がする。同じネットワーク構造、ハイパーパラメータじゃないから比較は難しいんだけども
でも大体入力データを真似てくれているから実装としてはあまり間違えはなさそう
ポアンカレ埋め込みの論文と同じで距離についてワクワクさせてくれる素敵な論文だった。Appendixのところとかまだ完璧に理解できないのでまた成長したら読みたいな。
どっちも同じFacebookのリサーチャーの人が関わっているらしくてFacebookすげえ・・・
おまけでこの論文では feedforward neural network のことを
By a feedforward neural network we mean a function composed by affine transformations and pointwise nonlinearities which are smooth Lipschitz functions (such as the sigmoid, tanh, elu, softplus, etc).
と下に小さい箇所で述べていた。なんかめっちゃカッコいい・・・?カッコよくない・・・?