RelativisticGANの論文を読んでPytorchで実装した その2
その1の続き
Standardな方のRSGANを実装してみる。
WGANまでTensorflowで実装してて今更Pytorchに変えたのはGeneratorとCriticのアーキテクチャの部分とか訓練の部分の定義がめんちいから。自分が効率悪い書き方してるだけの向上心がクズなだけです・・・
訓練のデータセットはhiragana73なるものを使ってみた。某開始5分村焼きソシャゲとデータセットの名前が似てたからそれだけ。
文字画像データセット(平仮名73文字版)を試験公開しました | NDLラボ
コードはここ
Pytorch初めて触ったけどかなり良さげだった。
書いてて感動したのはまず最適化の部分
GANではgeneratorとcriticで別々に更新するパラメータを指定しないといけない。
tensorflowのときはパラメータを指定するとき
self.cri_vars = [x for x in tf.trainable_variables() if "cri_" in x.name] self.gen_vars = [x for x in tf.trainable_variables() if "gen_" in x.name]
こんな感じでパラメータのリスト用意してoptimizerに突っ込んだ。
pytorchだとこれで終わる
self.critic = Critic() self.opt_critic = torch.optim.RMSprop(self.critic.parameters(), lr=0.00002) self.generator = Generator() self.opt_generator = torch.optim.RMSprop(self.generator.parameters(), lr=0.00002)
訓練中に更新ステップでself.opt_critic.step()
とか呼び出せばいい。特に一部のパラメータは更新しないとかじゃなきゃ楽すぎる。嬉しい
あとDataLoader
なるものもよかった
今まで向上心がないので画像データを[-1,1]の範囲に正規化するのをミニバッチ(X_train)でデータ取得してから
(X_train, _) = train_generator.next() X_train = (X_train - 127.5) / 127.5
こんな感じで取得してそこから直してモデルに渡してた。
pytorchだとこんなのでいけた
transform = torchvision.transforms.Compose([ torchvision.transforms.Grayscale(), torchvision.transforms.ToTensor(), torchvision.transforms.Lambda(lambda x: (x*255. - 127.5)/127.5) ]) train_dataset = torchvision.datasets.ImageFolder( root='./hiragana73/', transform=transform )
それがtransform・・・僕の求めていた力・・・。RandomCropなりFlipなり色々あるけど特にtorchvision.transforms.Lambda
これすき
kerasにもkeras.preprocessing.image.ImageDataGenerator
みたいのあるけどrescaleしかなくて[-1,1]の範囲に正規化するのめんどかった(探せば楽な方法あるんだろうけど)
地味にWGANのクリッピングの操作がこれだけで終わるのも嬉しすぎた
for p in self.critic.parameters(): p.data = torch.clamp(p.data,-self.clip_value,self.clip_value)
元の定義されてたRSGANの数式
これに従ってRSGANの更新ステップは次のように書いた
generated_X = self.critic(self.generator(noise_z)) real_X = self.critic(X_train.detach()) cri_loss = -torch.mean(torch.log(torch.sigmoid(real_X - generated_X))) total_cri_loss += cri_loss.item() self.opt_critic.zero_grad() cri_loss.backward() self.opt_critic.step() for p in self.critic.parameters(): p.requires_grad = False generated_X = self.critic(self.generator(noise_z)) real_X = self.critic(X_train.detach()) gen_loss = -torch.mean(torch.log(torch.sigmoid(generated_X - real_X))) total_gen_loss += gen_loss.item() self.opt_generator.zero_grad() gen_loss.backward() self.opt_generator.step()
話は戻って論文で学習時間も改善できると書いてあったし折角なのでWGANもpytorchで書き直して1エポックあたりにかかる学習時間を比較してみた
1エポック 48x48 のグレースケール画像を2万枚として環境は貧乏なのでGoogle Colaboratoryを頼る
WGANはcriticの回数を5に固定して1エポック100秒ほどかかったのに対してRSGANは1エポック35秒ほどだった。criticの回数1と思えば早いのはそれもそう
学習時間が早くても精度が悪かったら駄目じゃん。50エポックまでの精度を比較する。
WGANのGeneratorが生成したひらがなと思わしき画像はこんな感じ
各エポックごとのCriticとGeneratorのロス(1ステップ100枚=200ステップの合計)についてはこんなん
これに対してRSGANのGeneratorが生成したひらがなと思わしき画像たち
RSGANのCriticとGeneratorのロス
ヤベ100エポックまで回してるのバレた
う、うーん・・どっちも魔界の王を決める戦いの魔本の文字みたいなのしか生成してない気がする。RSGANのがちょっとだけ良さそうに見えなくもない
早くてこの精度ならめっちゃいいじゃんRelativistic GAN
参考にしたもの
https://github.com/AlexiaJM/RelativisticGAN
実践Pytorch
論文ちょいちょい読んでいるけれど、自分の元の数学の知識が乏しいせいで大事でとても面白い部分をスルーして読んでいるのを毎回痛感する・・・。
高校のときに読んだ算数の小説に「ゼータ関数の自明でない零点の実数部は全て1/2である」といったのが載ってた。当時は全くわかんなくてそれでもすごい惹かれて大学入って数学の授業受ければわかるかなあとか思ってたし、Poincare embeddingという論文を読んで双曲幾何学をすげーと思ってせめて論文の述べてる仕組みわかるようになりたいとか、読んでいくたび論文に勉強足りなさすぎだハゲといった感じでボコボコにされつつ聞いたことのある数学の単語とかが出てきて、どうしてそれが出てきたのか意味が知りたかったり。わかればきっと楽しいはずで、多分・・・
まだまだ勉強足りないのに時間だけが迫ってくる。うーん。全く手をつけていない卒論頑張ろう