RelativisticGANの論文を読んでPytorchで実装した その2

その1の続き

Standardな方のRSGANを実装してみる。
WGANまでTensorflowで実装してて今更Pytorchに変えたのはGeneratorとCriticのアーキテクチャの部分とか訓練の部分の定義がめんちいから。自分が効率悪い書き方してるだけの向上心がクズなだけです・・・

訓練のデータセットはhiragana73なるものを使ってみた。某開始5分村焼きソシャゲとデータセットの名前が似てたからそれだけ。
文字画像データセット(平仮名73文字版)を試験公開しました | NDLラボ

コードはここ

github.com

Pytorch初めて触ったけどかなり良さげだった。

書いてて感動したのはまず最適化の部分

GANではgeneratorとcriticで別々に更新するパラメータを指定しないといけない。
tensorflowのときはパラメータを指定するとき

self.cri_vars = [x for x in tf.trainable_variables() if "cri_" in x.name]
self.gen_vars = [x for x in tf.trainable_variables() if "gen_" in x.name]

こんな感じでパラメータのリスト用意してoptimizerに突っ込んだ。

pytorchだとこれで終わる

self.critic = Critic()
self.opt_critic = torch.optim.RMSprop(self.critic.parameters(), lr=0.00002)
        
self.generator = Generator()
self.opt_generator = torch.optim.RMSprop(self.generator.parameters(), lr=0.00002)

訓練中に更新ステップでself.opt_critic.step()とか呼び出せばいい。特に一部のパラメータは更新しないとかじゃなきゃ楽すぎる。嬉しい

あとDataLoaderなるものもよかった
今まで向上心がないので画像データを[-1,1]の範囲に正規化するのをミニバッチ(X_train)でデータ取得してから

(X_train, _) = train_generator.next()
X_train = (X_train - 127.5) / 127.5

こんな感じで取得してそこから直してモデルに渡してた。

pytorchだとこんなのでいけた

transform = torchvision.transforms.Compose([
                torchvision.transforms.Grayscale(),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Lambda(lambda x: (x*255. - 127.5)/127.5)
                 ])

train_dataset = torchvision.datasets.ImageFolder(
    root='./hiragana73/',
    transform=transform
)

それがtransform・・・僕の求めていた力・・・。RandomCropなりFlipなり色々あるけど特にtorchvision.transforms.Lambdaこれすき
kerasにもkeras.preprocessing.image.ImageDataGeneratorみたいのあるけどrescaleしかなくて[-1,1]の範囲に正規化するのめんどかった(探せば楽な方法あるんだろうけど)

地味にWGANのクリッピングの操作がこれだけで終わるのも嬉しすぎた

for p in self.critic.parameters():
    p.data = torch.clamp(p.data,-self.clip_value,self.clip_value)

元の定義されてたRSGANの数式

{ \displaystyle L^{RSGAN}_{D} = - \mathbb{E}_{ (x_{r}\ , \ x_{f}) \sim (\mathbb{P},\mathbb{Q}) } [log(sigmoid(C(x_r) - C_(x_f)))] }
{ \displaystyle L^{RSGAN}_{G} = - \mathbb{E}_{ (x_{r}\ , \ x_{f}) \sim (\mathbb{P},\mathbb{Q}) } [log(sigmoid(C(x_f) - C_(x_r)))] }


これに従ってRSGANの更新ステップは次のように書いた

generated_X = self.critic(self.generator(noise_z))
real_X = self.critic(X_train.detach())

cri_loss = -torch.mean(torch.log(torch.sigmoid(real_X - generated_X)))
total_cri_loss += cri_loss.item()
self.opt_critic.zero_grad()
cri_loss.backward()
self.opt_critic.step()

for p in self.critic.parameters():
  p.requires_grad = False

generated_X = self.critic(self.generator(noise_z))
real_X = self.critic(X_train.detach())

gen_loss = -torch.mean(torch.log(torch.sigmoid(generated_X - real_X)))
total_gen_loss += gen_loss.item()
self.opt_generator.zero_grad()
gen_loss.backward()
self.opt_generator.step()

話は戻って論文で学習時間も改善できると書いてあったし折角なのでWGANもpytorchで書き直して1エポックあたりにかかる学習時間を比較してみた

1エポック 48x48 のグレースケール画像を2万枚として環境は貧乏なのでGoogle Colaboratoryを頼る
WGANはcriticの回数を5に固定して1エポック100秒ほどかかったのに対してRSGANは1エポック35秒ほどだった。criticの回数1と思えば早いのはそれもそう

学習時間が早くても精度が悪かったら駄目じゃん。50エポックまでの精度を比較する。

WGANのGeneratorが生成したひらがなと思わしき画像はこんな感じ
f:id:Owatank:20180923224643p:plain

各エポックごとのCriticとGeneratorのロス(1ステップ100枚=200ステップの合計)についてはこんなん
f:id:Owatank:20180923224717p:plain

これに対してRSGANのGeneratorが生成したひらがなと思わしき画像たち
f:id:Owatank:20180923224831p:plain

RSGANのCriticとGeneratorのロス
f:id:Owatank:20180923225219p:plain ヤベ100エポックまで回してるのバレた

う、うーん・・どっちも魔界の王を決める戦いの魔本の文字みたいなのしか生成してない気がする。RSGANのがちょっとだけ良さそうに見えなくもない
早くてこの精度ならめっちゃいいじゃんRelativistic GAN

参考にしたもの

https://github.com/AlexiaJM/RelativisticGAN
実践Pytorch




論文ちょいちょい読んでいるけれど、自分の元の数学の知識が乏しいせいで大事でとても面白い部分をスルーして読んでいるのを毎回痛感する・・・。
高校のときに読んだ算数の小説に「ゼータ関数の自明でない零点の実数部は全て1/2である」といったのが載ってた。当時は全くわかんなくてそれでもすごい惹かれて大学入って数学の授業受ければわかるかなあとか思ってたし、Poincare embeddingという論文を読んで双曲幾何学をすげーと思ってせめて論文の述べてる仕組みわかるようになりたいとか、読んでいくたび論文に勉強足りなさすぎだハゲといった感じでボコボコにされつつ聞いたことのある数学の単語とかが出てきて、どうしてそれが出てきたのか意味が知りたかったり。わかればきっと楽しいはずで、多分・・・
まだまだ勉強足りないのに時間だけが迫ってくる。うーん。全く手をつけていない卒論頑張ろう